summaryrefslogtreecommitdiff
path: root/parse/infer.c
diff options
context:
space:
mode:
authorOri Bernstein <ori@eigenstate.org>2018-01-13 23:39:48 -0800
committerOri Bernstein <ori@eigenstate.org>2018-01-13 23:39:48 -0800
commit3d078d5439e93a3dfc4b808ab6bf02805d455ff8 (patch)
tree516b67f82dab1d801b71f48439b13353cc932288 /parse/infer.c
parent70f97fe9898b4852257a9268e6ea0592ee7e3a88 (diff)
downloadmc-3d078d5439e93a3dfc4b808ab6bf02805d455ff8.tar.gz
Add code to fix up iterators.
Diffstat (limited to 'parse/infer.c')
-rw-r--r--parse/infer.c54
1 files changed, 45 insertions, 9 deletions
diff --git a/parse/infer.c b/parse/infer.c
index 20120bd..391a29d 100644
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -2135,6 +2135,7 @@ infernode(Node **np, Type *ret, int *sawret)
unify(n, e, b);
else
htput(seqbase, t, e);
+ delayedcheck(n, curstab());
break;
case Nmatchstmt:
infernode(&n->matchstmt.val, NULL, sawret);
@@ -2379,6 +2380,35 @@ checkvar(Node *n, Node ***rem, size_t *nrem, Stab ***remscope, size_t *nremscope
}
static void
+fixiter(Node *n, Type *ty, Type *base)
+{
+ size_t i, bestidx;
+ int r, bestrank;
+ Type *b, *t;
+
+ ty = tysearch(ty);
+ b = htget(seqbase, ty);
+ if (!b)
+ return;
+ bestrank = -1;
+ bestidx = 0;
+ for (i = 0; i < nimpltab; i++) {
+ if (impltab[i]->impl.trait != traittab[Tciter])
+ continue;
+ r = tymatchrank(impltab[i]->impl.type, ty);
+ if (r > bestrank) {
+ bestrank = r;
+ bestidx = i;
+ }
+ }
+ if (bestrank >= 0) {
+ t = tf(impltab[bestidx]->impl.aux[0]);
+ t = tyfreshen(NULL, t);
+ unify(n, t, base);
+ }
+}
+
+static void
postcheckpass(Node ***rem, size_t *nrem, Stab ***remscope, size_t *nremscope)
{
size_t i;
@@ -2387,12 +2417,16 @@ postcheckpass(Node ***rem, size_t *nrem, Stab ***remscope, size_t *nremscope)
for (i = 0; i < npostcheck; i++) {
n = postcheck[i];
pushstab(postcheckscope[i]);
- switch (exprop(n)) {
- case Omemb: infercompn(n, rem, nrem, remscope, nremscope); break;
- case Ocast: checkcast(n, rem, nrem, remscope, nremscope); break;
- case Ostruct: checkstruct(n, rem, nrem, remscope, nremscope); break;
- case Ovar: checkvar(n, rem, nrem, remscope, nremscope); break;
- default: die("should not see %s in postcheck\n", opstr[exprop(n)]);
+ if (n->type == Nexpr) {
+ switch (exprop(n)) {
+ case Omemb: infercompn(n, rem, nrem, remscope, nremscope); break;
+ case Ocast: checkcast(n, rem, nrem, remscope, nremscope); break;
+ case Ostruct: checkstruct(n, rem, nrem, remscope, nremscope); break;
+ case Ovar: checkvar(n, rem, nrem, remscope, nremscope); break;
+ default: die("should not see %s in postcheck\n", opstr[exprop(n)]);
+ }
+ } else if (n->type == Niterstmt) {
+ fixiter(n, type(n->iterstmt.seq), type(n->iterstmt.elt));
}
popstab();
}
@@ -2419,7 +2453,6 @@ postinfer(void)
postcheckscope = remscope;
npostcheckscope = nremscope;
}
- postcheckpass(NULL, NULL, NULL, NULL);
}
/* After inference, replace all
@@ -2799,10 +2832,10 @@ findtrait(Node *impl)
tr = gettrait(ns, n);
if (!tr)
fatal(impl, "trait %s does not exist near %s",
- namestr(impl->impl.traitname), ctxstr(impl));
+ namestr(impl->impl.traitname), ctxstr(impl));
if (tr->naux != impl->impl.naux)
fatal(impl, "incompatible implementation of %s: mismatched aux types",
- namestr(impl->impl.traitname), ctxstr(impl));
+ namestr(impl->impl.traitname), ctxstr(impl));
}
return tr;
}
@@ -2876,6 +2909,9 @@ initimpl(void)
pushenv(impl->impl.env);
ty = tf(impl->impl.type);
addtraittab(traitmap, tr, ty);
+ if (tr->uid == Tciter) {
+ htput(seqbase, tf(impl->impl.type), tf(impl->impl.aux[0]));
+ }
popenv(impl->impl.env);
}
popstab();