diff options
Diffstat (limited to 'parse/infer.c')
-rw-r--r-- | parse/infer.c | 100 |
1 files changed, 63 insertions, 37 deletions
diff --git a/parse/infer.c b/parse/infer.c index a05dd64..3d84dd1 100644 --- a/parse/infer.c +++ b/parse/infer.c @@ -33,7 +33,7 @@ struct Traitmap { static void infernode(Node **np, Type *ret, int *sawret); static void inferexpr(Node **np, Type *ret, int *sawret); static void inferdecl(Node *n); -static int tryconstrain(Type *ty, Trait *tr); +static int tryconstrain(Type *ty, Trait *tr, int update); static Type *tf(Type *t); @@ -353,6 +353,7 @@ occurs_rec(Type *sub, Bitset *bs) { size_t i; + sub = tf(sub); if (bshas(bs, sub->tid)) return 1; bsput(bs, sub->tid); @@ -390,8 +391,9 @@ occursin(Type *a, Type *b) int r; bs = mkbs(); - bsput(bs, b->tid); - r = occurs_rec(a, bs); + a = tf(a); + bsput(bs, a->tid); + r = occurs_rec(b, bs); bsfree(bs); return r; } @@ -413,6 +415,7 @@ needfreshenrec(Type *t, Bitset *visited) { size_t i; + t = tysearch(t); if (bshas(visited, t->tid)) return 0; bsput(visited, t->tid); @@ -480,10 +483,10 @@ static void tyresolve(Type *t) { size_t i; + Trait *tr; if (t->resolved) return; - /* type resolution should never throw errors about non-generic * showing up within a generic type, so we push and pop a generic * around resolution */ @@ -523,6 +526,15 @@ tyresolve(Type *t) break; } + for (i = 0; i < t->ntraits; i++) { + tr = gettrait(curstab(), t->traits[i]); + if (!tr) + lfatal(t->loc, "trait %s does not exist", ctxstr(t->traits[i])); + if (!t->trneed) + t->trneed = mkbs(); + bsput(t->trneed, tr->uid); + } + for (i = 0; i < t->nsub; i++) { t->sub[i] = tf(t->sub[i]); if (t->sub[i] == t) { @@ -629,8 +641,10 @@ tf(Type *orig) popenv(orig->env); } else if (orig->type == Typaram) { tt = boundtype(t); - if (tt) + if (tt) { + tyresolve(tt); t = tt; + } } ingeneric -= isgeneric; return t; @@ -768,7 +782,7 @@ tymatchrank(Type *pat, Type *to) if (!pat->trneed) return 0; for (i = 0; bsiter(pat->trneed, &i); i++) - if (!tryconstrain(to, traittab[i])) + if (!tryconstrain(to, traittab[i], 0)) return -1; return 0; } else if (pat->type == Tyvar) { @@ -848,7 +862,7 @@ tymatchrank(Type *pat, Type *to) } static int -tryconstrain(Type *base, Trait *tr) +tryconstrain(Type *base, Trait *tr, int update) { Traitmap *tm; Bitset *bs; @@ -859,12 +873,14 @@ tryconstrain(Type *base, Trait *tr) ty = base; tm = traitmap->sub[ty->type]; while (1) { - if (ty->type == Typaram && bshas(ty->trneed, tr->uid)) - return 1; + if (ty->type == Typaram) + if (ty->trneed && bshas(ty->trneed, tr->uid)) + return 1; if (ty->type == Tyvar) { if (!ty->trneed) ty->trneed = mkbs(); - bsput(ty->trneed, tr->uid); + if (update) + bsput(ty->trneed, tr->uid); return 1; } if (bshas(tm->traits, tr->uid)) @@ -880,11 +896,12 @@ tryconstrain(Type *base, Trait *tr) if (tymatchrank(tm->filter[i], ty) >= 0) return 1; } - if (!tm->sub[ty->type]) + if (!ty->sub || ty->nsub != 1) break; - assert(ty->nsub == 1); - tm = tm->sub[ty->type]; ty = ty->sub[0]; + tm = tm->sub[ty->type]; + if (!tm) + break; } if (base->type != Tyname) break; @@ -900,7 +917,7 @@ tryconstrain(Type *base, Trait *tr) static void constrain(Node *ctx, Type *base, Trait *tr) { - if (!tryconstrain(base, tr)) + if (!tryconstrain(base, tr, 1)) fatal(ctx, "%s needs trait %s near %s", tystr(base), namestr(tr->name), ctxstr(ctx)); } @@ -908,6 +925,7 @@ static void traitsfor(Type *base, Bitset *dst) { Traitmap *tm; + Bitset *bs; Type *ty; size_t i; @@ -917,7 +935,12 @@ traitsfor(Type *base, Bitset *dst) while (1) { if (ty->type == Tyvar) break; - bsunion(dst, tm->traits); + if (ty->type == Tyname && ty->ngparam == 0) + bs = htget(tm->name, ty->name); + else + bs = tm->traits; + if (bs) + bsunion(dst, bs); for (i = 0; i < tm->nfilter; i++) { if (tymatchrank(tm->filter[i], ty) >= 0) bsput(dst, tm->filtertr[i]->uid); @@ -1247,7 +1270,6 @@ unifycall(Node *n) Type *ft; ft = type(n->expr.args[0]); - if (ft->type == Tyvar) { /* the first arg is the function itself, so it shouldn't be counted */ ft = mktyfunc(n->loc, &n->expr.args[1], n->expr.nargs - 1, mktyvar(n->loc)); @@ -1643,7 +1665,7 @@ inferexpr(Node **np, Type *ret, int *sawret) case Odiveq: /* @a /= @a -> @a */ infersub(n, ret, sawret, &isconst); t = type(args[0]); - constrain(n, type(args[0]), traittab[Tcnum]); + constrain(n, t, traittab[Tcnum]); isconst = args[0]->expr.isconst; for (i = 1; i < nargs; i++) { isconst = isconst && args[i]->expr.isconst; @@ -1671,8 +1693,8 @@ inferexpr(Node **np, Type *ret, int *sawret) case Obsreq: /* @a >>= @a -> @a */ infersub(n, ret, sawret, &isconst); t = type(args[0]); - constrain(n, type(args[0]), traittab[Tcnum]); - constrain(n, type(args[0]), traittab[Tcint]); + constrain(n, t, traittab[Tcnum]); + constrain(n, t, traittab[Tcint]); isconst = args[0]->expr.isconst; for (i = 1; i < nargs; i++) { isconst = isconst && args[i]->expr.isconst; @@ -1901,19 +1923,21 @@ specializeimpl(Node *n) Node *dcl, *proto, *name, *sym; Tysubst *subst; Type *ty; - Trait *t; + Trait *tr; size_t i, j; int generic; + char *traitns; - t = gettrait(curstab(), n->impl.traitname); - if (!t) + tr = gettrait(curstab(), n->impl.traitname); + if (!tr) fatal(n, "no trait %s\n", namestr(n->impl.traitname)); - n->impl.trait = t; + n->impl.trait = tr; + traitns = tr->name->name.ns; dcl = NULL; - if (n->impl.naux != t->naux) + if (n->impl.naux != tr->naux) fatal(n, "%s incompatibly specialized with %zd types instead of %zd types", - namestr(n->impl.traitname), n->impl.naux, t->naux); + namestr(n->impl.traitname), n->impl.naux, tr->naux); n->impl.type = tf(n->impl.type); pushenv(n->impl.type->env); for (i = 0; i < n->impl.naux; i++) @@ -1931,25 +1955,27 @@ specializeimpl(Node *n) here. */ if (file->file.globls->name) - setns(dcl->decl.name, file->file.globls->name); - for (j = 0; j < t->nproto; j++) { - if (nsnameeq(dcl->decl.name, t->proto[j]->decl.name)) { - proto = t->proto[j]; + setns(dcl->decl.name, traitns); + for (j = 0; j < tr->nproto; j++) { + if (nsnameeq(dcl->decl.name, tr->proto[j]->decl.name)) { + proto = tr->proto[j]; break; } } if (!proto) fatal(n, "declaration %s missing in %s, near %s", namestr(dcl->decl.name), - namestr(t->name), ctxstr(n)); + namestr(tr->name), ctxstr(n)); /* infer and unify types */ - verifytraits(n, t->param, n->impl.type); + pushenv(proto->decl.env); + verifytraits(n, tr->param, n->impl.type); subst = mksubst(); - substput(subst, t->param, n->impl.type); - for (j = 0; j < t->naux; j++) - substput(subst, t->aux[j], n->impl.aux[j]); + substput(subst, tr->param, n->impl.type); + for (j = 0; j < tr->naux; j++) + substput(subst, tr->aux[j], n->impl.aux[j]); ty = tyspecialize(type(proto), subst, delayed, NULL); substfree(subst); + popenv(proto->decl.env); generic = hasparams(ty); if (generic) @@ -1963,7 +1989,7 @@ specializeimpl(Node *n) sym = getdcl(file->file.globls, name); if (sym) fatal(n, "trait %s already specialized with %s on %s:%d", - namestr(t->name), tystr(n->impl.type), + namestr(tr->name), tystr(n->impl.type), fname(sym->loc), lnum(sym->loc)); dcl->decl.name = name; putdcl(file->file.globls, dcl); @@ -1974,7 +2000,7 @@ specializeimpl(Node *n) lappend(&proto->decl.gimpl, &proto->decl.ngimpl, dcl); lappend(&proto->decl.gtype, &proto->decl.ngtype, ty); } - dcl->decl.vis = t->vis; + dcl->decl.vis = tr->vis; lappend(&impldecl, &nimpldecl, dcl); if (generic) @@ -2792,7 +2818,7 @@ addtraittab(Traitmap *m, Trait *tr, Type *ty) size_t i; if (!m->sub[ty->type]) - m = mktraitmap(); + m->sub[ty->type] = mktraitmap(); mm = m->sub[ty->type]; switch (ty->type) { case Tygeneric: |