summaryrefslogtreecommitdiff
path: root/parse/infer.c
diff options
context:
space:
mode:
Diffstat (limited to 'parse/infer.c')
-rw-r--r--parse/infer.c127
1 files changed, 92 insertions, 35 deletions
diff --git a/parse/infer.c b/parse/infer.c
index 55628a5..daa071e 100644
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -35,7 +35,9 @@ static void inferexpr(Node **np, Type *ret, int *sawret);
static void inferdecl(Node *n);
static int tryconstrain(Type *ty, Trait *tr, int update);
+static Type *tyfreshen(Tysubst *subst, Type *orig);
static Type *tf(Type *t);
+static Type *basetype(Type *a);
static Type *unify(Node *ctx, Type *a, Type *b);
static Type *tyfix(Node *ctx, Type *orig, int noerr);
@@ -66,7 +68,6 @@ static Node **specializations;
static size_t nspecializations;
static Stab **specializationscope;
static size_t nspecializationscope;
-static Htab *seqbase;
static Traitmap *traitmap;
static void
@@ -279,6 +280,7 @@ additerspecialization(Node *n, Stab *stab)
ty = exprtype(n->iterstmt.seq);
if (ty->type == Tyslice || ty->type == Tyarray || ty->type == Typtr)
return;
+ ty = tyfreshen(NULL, ty);
for (i = 0; i < tr->nproto; i++) {
ty = exprtype(n->iterstmt.seq);
if (hthas(tr->proto[i]->decl.impls, ty))
@@ -462,27 +464,32 @@ needfreshen(Type *t)
static Type *
tyfreshen(Tysubst *subst, Type *orig)
{
- Type *t;
+ Type *ty, *base;
if (!needfreshen(orig))
return orig;
pushenv(orig->env);
if (!subst) {
subst = mksubst();
- t = tyspecialize(orig, subst, delayed, seqbase);
+ ty = tyspecialize(orig, subst, delayed, seqbase);
substfree(subst);
} else {
- t = tyspecialize(orig, subst, delayed, seqbase);
+ ty = tyspecialize(orig, subst, delayed, seqbase);
}
+ ty->spec = orig->spec;
+ ty->nspec = orig->nspec;
+ base = basetype(ty);
+ if (base)
+ htput(seqbase, ty, base);
popenv(orig->env);
- return t;
+ return ty;
}
/* Resolves a type and all its subtypes recursively. */
static void
tyresolve(Type *t)
{
- size_t i;
+ size_t i, j;
Trait *tr;
if (t->resolved)
@@ -526,13 +533,16 @@ tyresolve(Type *t)
break;
}
- for (i = 0; i < t->ntraits; i++) {
- tr = gettrait(curstab(), t->traits[i]);
- if (!tr)
- lfatal(t->loc, "trait %s does not exist", ctxstr(t->traits[i]));
- if (!t->trneed)
- t->trneed = mkbs();
- bsput(t->trneed, tr->uid);
+ for (i = 0; i < t->nspec; i++) {
+ for (j = 0; j < t->spec[i]->ntrait; j++) {
+ tr = gettrait(curstab(), t->spec[i]->trait[j]);
+ if (!tr)
+ lfatal(t->loc, "trait %s does not exist", ctxstr(t->spec[i]->trait[j]));
+ if (!t->trneed)
+ t->trneed = mkbs();
+ bsput(t->trneed, tr->uid);
+ htput(seqbase, t, t->spec[i]->aux);
+ }
}
for (i = 0; i < t->nsub; i++) {
@@ -597,9 +607,8 @@ tysubstmap(Tysubst *subst, Type *t, Type *orig)
{
size_t i;
- for (i = 0; i < t->ngparam; i++) {
+ for (i = 0; i < t->ngparam; i++)
substput(subst, t->gparam[i], tf(orig->arg[i]));
- }
t = tyfreshen(subst, t);
return t;
}
@@ -628,7 +637,9 @@ tf(Type *orig)
t = tylookup(orig);
isgeneric = t->type == Tygeneric;
ingeneric += isgeneric;
+ pushenv(orig->env);
tyresolve(t);
+ popenv(orig->env);
/* If this is an instantiation of a generic type, we want the params to
* match the instantiation */
if (orig->type == Tyunres && t->type == Tygeneric) {
@@ -1919,7 +1930,7 @@ specializeimpl(Node *n)
fatal(n, "%s incompatibly specialized with %zd types instead of %zd types",
namestr(n->impl.traitname), n->impl.naux, tr->naux);
n->impl.type = tf(n->impl.type);
- pushenv(n->impl.type->env);
+ pushenv(n->impl.env);
for (i = 0; i < n->impl.naux; i++)
n->impl.aux[i] = tf(n->impl.aux[i]);
for (i = 0; i < n->impl.ndecls; i++) {
@@ -1986,7 +1997,7 @@ specializeimpl(Node *n)
if (generic)
ingeneric--;
}
- popenv(n->impl.type->env);
+ popenv(n->impl.env);
}
static void
@@ -2049,7 +2060,7 @@ infernode(Node **np, Type *ret, int *sawret)
{
size_t i, nbound;
Node **bound, *n, *pat;
- Type *t, *b;
+ Type *t, *b, *e;
n = *np;
if (!n)
@@ -2115,12 +2126,15 @@ infernode(Node **np, Type *ret, int *sawret)
infernode(&n->iterstmt.seq, NULL, sawret);
infernode(&n->iterstmt.body, ret, sawret);
- b = mktyvar(n->loc);
- t = mktyvar(n->loc);
- htput(seqbase, t, b);
- constrain(n, type(n->iterstmt.seq), traittab[Tciter]);
- unify(n, type(n->iterstmt.seq), t);
- unify(n, type(n->iterstmt.elt), b);
+ e = type(n->iterstmt.elt);
+ t = type(n->iterstmt.seq);
+ constrain(n, t, traittab[Tciter]);
+ b = basetype(t);
+ if (b)
+ unify(n, e, b);
+ else
+ htput(seqbase, t, e);
+ delayedcheck(n, curstab());
break;
case Nmatchstmt:
infernode(&n->matchstmt.val, NULL, sawret);
@@ -2365,6 +2379,43 @@ checkvar(Node *n, Node ***rem, size_t *nrem, Stab ***remscope, size_t *nremscope
}
static void
+fixiter(Node *n, Type *ty, Type *base)
+{
+ size_t i, bestidx;
+ int r, bestrank;
+ Type *b, *t, *orig;
+ Tysubst *ts;
+ Node *impl;
+
+ ty = tysearch(ty);
+ b = htget(seqbase, ty);
+ if (!b)
+ return;
+ bestrank = -1;
+ bestidx = 0;
+ for (i = 0; i < nimpltab; i++) {
+ if (impltab[i]->impl.trait != traittab[Tciter])
+ continue;
+ r = tymatchrank(impltab[i]->impl.type, ty);
+ if (r > bestrank) {
+ bestrank = r;
+ bestidx = i;
+ }
+ }
+ if (bestrank >= 0) {
+ impl = impltab[bestidx];
+ orig = impl->impl.type;
+ t = tf(impl->impl.aux[0]);
+ ts = mksubst();
+ for (i = 0; i < ty->narg; i++)
+ substput(ts, tf(orig->arg[i]), ty->arg[i]);
+ t = tyfreshen(ts, t);
+ substfree(ts);
+ unify(n, t, base);
+ }
+}
+
+static void
postcheckpass(Node ***rem, size_t *nrem, Stab ***remscope, size_t *nremscope)
{
size_t i;
@@ -2373,12 +2424,16 @@ postcheckpass(Node ***rem, size_t *nrem, Stab ***remscope, size_t *nremscope)
for (i = 0; i < npostcheck; i++) {
n = postcheck[i];
pushstab(postcheckscope[i]);
- switch (exprop(n)) {
- case Omemb: infercompn(n, rem, nrem, remscope, nremscope); break;
- case Ocast: checkcast(n, rem, nrem, remscope, nremscope); break;
- case Ostruct: checkstruct(n, rem, nrem, remscope, nremscope); break;
- case Ovar: checkvar(n, rem, nrem, remscope, nremscope); break;
- default: die("should not see %s in postcheck\n", opstr[exprop(n)]);
+ if (n->type == Nexpr) {
+ switch (exprop(n)) {
+ case Omemb: infercompn(n, rem, nrem, remscope, nremscope); break;
+ case Ocast: checkcast(n, rem, nrem, remscope, nremscope); break;
+ case Ostruct: checkstruct(n, rem, nrem, remscope, nremscope); break;
+ case Ovar: checkvar(n, rem, nrem, remscope, nremscope); break;
+ default: die("should not see %s in postcheck\n", opstr[exprop(n)]);
+ }
+ } else if (n->type == Niterstmt) {
+ fixiter(n, type(n->iterstmt.seq), type(n->iterstmt.elt));
}
popstab();
}
@@ -2405,7 +2460,6 @@ postinfer(void)
postcheckscope = remscope;
npostcheckscope = nremscope;
}
- postcheckpass(NULL, NULL, NULL, NULL);
}
/* After inference, replace all
@@ -2588,7 +2642,8 @@ typesub(Node *n, int noerr)
typesub(n->iterstmt.elt, noerr);
typesub(n->iterstmt.seq, noerr);
typesub(n->iterstmt.body, noerr);
- additerspecialization(n, curstab());
+ if (!ingeneric)
+ additerspecialization(n, curstab());
break;
case Nmatchstmt:
typesub(n->matchstmt.val, noerr);
@@ -2687,6 +2742,8 @@ specialize(void)
tr = traittab[Tciter];
assert(tr->nproto == 2);
ty = exprtype(n->iterstmt.seq);
+ if (ty->type == Typaram)
+ continue;
it = itertype(n->iterstmt.seq, mktype(n->loc, Tybool));
d = specializedcl(tr->proto[0], ty, it, &name);
@@ -2782,10 +2839,10 @@ findtrait(Node *impl)
tr = gettrait(ns, n);
if (!tr)
fatal(impl, "trait %s does not exist near %s",
- namestr(impl->impl.traitname), ctxstr(impl));
+ namestr(impl->impl.traitname), ctxstr(impl));
if (tr->naux != impl->impl.naux)
fatal(impl, "incompatible implementation of %s: mismatched aux types",
- namestr(impl->impl.traitname), ctxstr(impl));
+ namestr(impl->impl.traitname), ctxstr(impl));
}
return tr;
}
@@ -2849,6 +2906,7 @@ initimpl(void)
Type *ty;
pushstab(file->file.globls);
+ seqbase = mkht(tyhash, tyeq);
traitmap = zalloc(sizeof(Traitmap));
builtintraits();
for (i = 0; i < nimpltab; i++) {
@@ -2898,7 +2956,6 @@ void
infer(void)
{
delayed = mkht(tyhash, tyeq);
- seqbase = mkht(tyhash, tyeq);
loaduses();
initimpl();