summaryrefslogtreecommitdiff
path: root/parse/infer.c
diff options
context:
space:
mode:
Diffstat (limited to 'parse/infer.c')
-rw-r--r--parse/infer.c100
1 files changed, 63 insertions, 37 deletions
diff --git a/parse/infer.c b/parse/infer.c
index a05dd64..3d84dd1 100644
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -33,7 +33,7 @@ struct Traitmap {
static void infernode(Node **np, Type *ret, int *sawret);
static void inferexpr(Node **np, Type *ret, int *sawret);
static void inferdecl(Node *n);
-static int tryconstrain(Type *ty, Trait *tr);
+static int tryconstrain(Type *ty, Trait *tr, int update);
static Type *tf(Type *t);
@@ -353,6 +353,7 @@ occurs_rec(Type *sub, Bitset *bs)
{
size_t i;
+ sub = tf(sub);
if (bshas(bs, sub->tid))
return 1;
bsput(bs, sub->tid);
@@ -390,8 +391,9 @@ occursin(Type *a, Type *b)
int r;
bs = mkbs();
- bsput(bs, b->tid);
- r = occurs_rec(a, bs);
+ a = tf(a);
+ bsput(bs, a->tid);
+ r = occurs_rec(b, bs);
bsfree(bs);
return r;
}
@@ -413,6 +415,7 @@ needfreshenrec(Type *t, Bitset *visited)
{
size_t i;
+ t = tysearch(t);
if (bshas(visited, t->tid))
return 0;
bsput(visited, t->tid);
@@ -480,10 +483,10 @@ static void
tyresolve(Type *t)
{
size_t i;
+ Trait *tr;
if (t->resolved)
return;
-
/* type resolution should never throw errors about non-generic
* showing up within a generic type, so we push and pop a generic
* around resolution */
@@ -523,6 +526,15 @@ tyresolve(Type *t)
break;
}
+ for (i = 0; i < t->ntraits; i++) {
+ tr = gettrait(curstab(), t->traits[i]);
+ if (!tr)
+ lfatal(t->loc, "trait %s does not exist", ctxstr(t->traits[i]));
+ if (!t->trneed)
+ t->trneed = mkbs();
+ bsput(t->trneed, tr->uid);
+ }
+
for (i = 0; i < t->nsub; i++) {
t->sub[i] = tf(t->sub[i]);
if (t->sub[i] == t) {
@@ -629,8 +641,10 @@ tf(Type *orig)
popenv(orig->env);
} else if (orig->type == Typaram) {
tt = boundtype(t);
- if (tt)
+ if (tt) {
+ tyresolve(tt);
t = tt;
+ }
}
ingeneric -= isgeneric;
return t;
@@ -768,7 +782,7 @@ tymatchrank(Type *pat, Type *to)
if (!pat->trneed)
return 0;
for (i = 0; bsiter(pat->trneed, &i); i++)
- if (!tryconstrain(to, traittab[i]))
+ if (!tryconstrain(to, traittab[i], 0))
return -1;
return 0;
} else if (pat->type == Tyvar) {
@@ -848,7 +862,7 @@ tymatchrank(Type *pat, Type *to)
}
static int
-tryconstrain(Type *base, Trait *tr)
+tryconstrain(Type *base, Trait *tr, int update)
{
Traitmap *tm;
Bitset *bs;
@@ -859,12 +873,14 @@ tryconstrain(Type *base, Trait *tr)
ty = base;
tm = traitmap->sub[ty->type];
while (1) {
- if (ty->type == Typaram && bshas(ty->trneed, tr->uid))
- return 1;
+ if (ty->type == Typaram)
+ if (ty->trneed && bshas(ty->trneed, tr->uid))
+ return 1;
if (ty->type == Tyvar) {
if (!ty->trneed)
ty->trneed = mkbs();
- bsput(ty->trneed, tr->uid);
+ if (update)
+ bsput(ty->trneed, tr->uid);
return 1;
}
if (bshas(tm->traits, tr->uid))
@@ -880,11 +896,12 @@ tryconstrain(Type *base, Trait *tr)
if (tymatchrank(tm->filter[i], ty) >= 0)
return 1;
}
- if (!tm->sub[ty->type])
+ if (!ty->sub || ty->nsub != 1)
break;
- assert(ty->nsub == 1);
- tm = tm->sub[ty->type];
ty = ty->sub[0];
+ tm = tm->sub[ty->type];
+ if (!tm)
+ break;
}
if (base->type != Tyname)
break;
@@ -900,7 +917,7 @@ tryconstrain(Type *base, Trait *tr)
static void
constrain(Node *ctx, Type *base, Trait *tr)
{
- if (!tryconstrain(base, tr))
+ if (!tryconstrain(base, tr, 1))
fatal(ctx, "%s needs trait %s near %s", tystr(base), namestr(tr->name), ctxstr(ctx));
}
@@ -908,6 +925,7 @@ static void
traitsfor(Type *base, Bitset *dst)
{
Traitmap *tm;
+ Bitset *bs;
Type *ty;
size_t i;
@@ -917,7 +935,12 @@ traitsfor(Type *base, Bitset *dst)
while (1) {
if (ty->type == Tyvar)
break;
- bsunion(dst, tm->traits);
+ if (ty->type == Tyname && ty->ngparam == 0)
+ bs = htget(tm->name, ty->name);
+ else
+ bs = tm->traits;
+ if (bs)
+ bsunion(dst, bs);
for (i = 0; i < tm->nfilter; i++) {
if (tymatchrank(tm->filter[i], ty) >= 0)
bsput(dst, tm->filtertr[i]->uid);
@@ -1247,7 +1270,6 @@ unifycall(Node *n)
Type *ft;
ft = type(n->expr.args[0]);
-
if (ft->type == Tyvar) {
/* the first arg is the function itself, so it shouldn't be counted */
ft = mktyfunc(n->loc, &n->expr.args[1], n->expr.nargs - 1, mktyvar(n->loc));
@@ -1643,7 +1665,7 @@ inferexpr(Node **np, Type *ret, int *sawret)
case Odiveq: /* @a /= @a -> @a */
infersub(n, ret, sawret, &isconst);
t = type(args[0]);
- constrain(n, type(args[0]), traittab[Tcnum]);
+ constrain(n, t, traittab[Tcnum]);
isconst = args[0]->expr.isconst;
for (i = 1; i < nargs; i++) {
isconst = isconst && args[i]->expr.isconst;
@@ -1671,8 +1693,8 @@ inferexpr(Node **np, Type *ret, int *sawret)
case Obsreq: /* @a >>= @a -> @a */
infersub(n, ret, sawret, &isconst);
t = type(args[0]);
- constrain(n, type(args[0]), traittab[Tcnum]);
- constrain(n, type(args[0]), traittab[Tcint]);
+ constrain(n, t, traittab[Tcnum]);
+ constrain(n, t, traittab[Tcint]);
isconst = args[0]->expr.isconst;
for (i = 1; i < nargs; i++) {
isconst = isconst && args[i]->expr.isconst;
@@ -1901,19 +1923,21 @@ specializeimpl(Node *n)
Node *dcl, *proto, *name, *sym;
Tysubst *subst;
Type *ty;
- Trait *t;
+ Trait *tr;
size_t i, j;
int generic;
+ char *traitns;
- t = gettrait(curstab(), n->impl.traitname);
- if (!t)
+ tr = gettrait(curstab(), n->impl.traitname);
+ if (!tr)
fatal(n, "no trait %s\n", namestr(n->impl.traitname));
- n->impl.trait = t;
+ n->impl.trait = tr;
+ traitns = tr->name->name.ns;
dcl = NULL;
- if (n->impl.naux != t->naux)
+ if (n->impl.naux != tr->naux)
fatal(n, "%s incompatibly specialized with %zd types instead of %zd types",
- namestr(n->impl.traitname), n->impl.naux, t->naux);
+ namestr(n->impl.traitname), n->impl.naux, tr->naux);
n->impl.type = tf(n->impl.type);
pushenv(n->impl.type->env);
for (i = 0; i < n->impl.naux; i++)
@@ -1931,25 +1955,27 @@ specializeimpl(Node *n)
here.
*/
if (file->file.globls->name)
- setns(dcl->decl.name, file->file.globls->name);
- for (j = 0; j < t->nproto; j++) {
- if (nsnameeq(dcl->decl.name, t->proto[j]->decl.name)) {
- proto = t->proto[j];
+ setns(dcl->decl.name, traitns);
+ for (j = 0; j < tr->nproto; j++) {
+ if (nsnameeq(dcl->decl.name, tr->proto[j]->decl.name)) {
+ proto = tr->proto[j];
break;
}
}
if (!proto)
fatal(n, "declaration %s missing in %s, near %s", namestr(dcl->decl.name),
- namestr(t->name), ctxstr(n));
+ namestr(tr->name), ctxstr(n));
/* infer and unify types */
- verifytraits(n, t->param, n->impl.type);
+ pushenv(proto->decl.env);
+ verifytraits(n, tr->param, n->impl.type);
subst = mksubst();
- substput(subst, t->param, n->impl.type);
- for (j = 0; j < t->naux; j++)
- substput(subst, t->aux[j], n->impl.aux[j]);
+ substput(subst, tr->param, n->impl.type);
+ for (j = 0; j < tr->naux; j++)
+ substput(subst, tr->aux[j], n->impl.aux[j]);
ty = tyspecialize(type(proto), subst, delayed, NULL);
substfree(subst);
+ popenv(proto->decl.env);
generic = hasparams(ty);
if (generic)
@@ -1963,7 +1989,7 @@ specializeimpl(Node *n)
sym = getdcl(file->file.globls, name);
if (sym)
fatal(n, "trait %s already specialized with %s on %s:%d",
- namestr(t->name), tystr(n->impl.type),
+ namestr(tr->name), tystr(n->impl.type),
fname(sym->loc), lnum(sym->loc));
dcl->decl.name = name;
putdcl(file->file.globls, dcl);
@@ -1974,7 +2000,7 @@ specializeimpl(Node *n)
lappend(&proto->decl.gimpl, &proto->decl.ngimpl, dcl);
lappend(&proto->decl.gtype, &proto->decl.ngtype, ty);
}
- dcl->decl.vis = t->vis;
+ dcl->decl.vis = tr->vis;
lappend(&impldecl, &nimpldecl, dcl);
if (generic)
@@ -2792,7 +2818,7 @@ addtraittab(Traitmap *m, Trait *tr, Type *ty)
size_t i;
if (!m->sub[ty->type])
- m = mktraitmap();
+ m->sub[ty->type] = mktraitmap();
mm = m->sub[ty->type];
switch (ty->type) {
case Tygeneric: