diff options
Diffstat (limited to 'parse/infer.c')
-rw-r--r-- | parse/infer.c | 378 |
1 files changed, 261 insertions, 117 deletions
diff --git a/parse/infer.c b/parse/infer.c index 29b3006..6d15240 100644 --- a/parse/infer.c +++ b/parse/infer.c @@ -15,6 +15,21 @@ #include "util.h" #include "parse.h" +typedef struct Traitmap Traitmap; +typedef struct Enttype Enttype; + +struct Traitmap { + Bitset *traits; + Traitmap *sub[Ntypes]; + + Htab *name; /* name => bitset(traits) */ + Type **filter; + size_t nfilter; + Trait **filtertr; + size_t nfiltertr; +}; + + static void infernode(Node **np, Type *ret, int *sawret); static void inferexpr(Node **np, Type *ret, int *sawret); static void inferdecl(Node *n); @@ -51,6 +66,7 @@ static size_t nspecializations; static Stab **specializationscope; static size_t nspecializationscope; static Htab *seqbase; +static Traitmap *traitmap; static void ctxstrcall(char *buf, size_t sz, Node *n) @@ -239,8 +255,8 @@ adddispspecialization(Node *n, Stab *stab) tr = traittab[Tcdisp]; ty = decltype(n); - if (!ty->traits || !bshas(ty->traits, Tcdisp)) - return; + //if (!ty->traits || !bshas(ty->traits, Tcdisp)) + // return; assert(tr->nproto == 1); if (hthas(tr->proto[0]->decl.impls, ty)) return; @@ -250,7 +266,7 @@ adddispspecialization(Node *n, Stab *stab) } static void -additerspecializations(Node *n, Stab *stab) +additerspecialization(Node *n, Stab *stab) { Trait *tr; Type *ty; @@ -258,8 +274,8 @@ additerspecializations(Node *n, Stab *stab) tr = traittab[Tciter]; ty = exprtype(n->iterstmt.seq); - if (!ty->traits || !bshas(ty->traits, Tciter)) - return; + //if (!ty->traits || !bshas(ty->traits, Tciter)) + // return; if (ty->type == Tyslice || ty->type == Tyarray || ty->type == Typtr) return; for (i = 0; i < tr->nproto; i++) { @@ -516,11 +532,6 @@ tyresolve(Type *t) } } base = tybase(t); - /* no-ops if base == t */ - if (t->traits && base->traits) - bsunion(t->traits, base->traits); - else if (base->traits) - t->traits = bsdup(base->traits); if (occurs(t)) lfatal(t->loc, "type %s includes itself", tystr(t)); popenv(t->env); @@ -642,6 +653,17 @@ settype(Node *n, Type *t) marksrc(t, n->loc); } +static Type* +mktylike(Srcloc l, Ty other) +{ + Type *t; + + t = mktyvar(l); + /* not perfect in general, but good enough for all places mktylike is used. */ + t->trneed = bsdup(traitmap->sub[other]->traits); + return t; +} + /* Gets the type of a literal value */ static Type * littype(Node *n) @@ -669,19 +691,11 @@ static Type * delayeducon(Type *fallback) { Type *t; - char *from, *to; if (fallback->type != Tyunion) return fallback; t = mktylike(fallback->loc, fallback->type); htput(delayed, t, fallback); - if (debugopt['u']) { - from = tystr(t); - to = tystr(fallback); - indentf(indentdepth, "Delay %s -> %s\n", from, to); - free(from); - free(to); - } return t; } @@ -746,43 +760,108 @@ _bind(Tyenv *e, Node *n) * constraint list. Otherwise, the type is checked to see * if it has the required constraint */ static void -constrain(Node *ctx, Type *a, Trait *c) +constrain(Node *ctx, Type *base, Trait *tr) { - if (a->type == Tyvar) { - if (!a->traits) - a->traits = mkbs(); - settrait(a, c); - } else if (!a->traits || !bshas(a->traits, c->uid)) { - fatal(ctx, "%s needs %s near %s", tystr(a), namestr(c->name), ctxstr(ctx)); + Traitmap *tm; + Bitset *bs; + Type *ty; + size_t i; + + while(1) { + ty = base; + tm = traitmap->sub[ty->type]; + while (1) { + if (ty->type == Typaram && bshas(ty->trneed, tr->uid)) + return; + if (ty->type == Tyvar) { + if (!ty->trneed) + ty->trneed = mkbs(); + bsput(ty->trneed, tr->uid); + return; + } + if (bshas(tm->traits, tr->uid)) + return; + if (tm->name && ty->type == Tyname) { + bs = htget(tm->name, ty->name); + if (bs && bshas(bs, tr->uid)) + return; + } + for (i = 0; i < tm->nfilter; i++) { + if (tm->filtertr[i]->uid != tr->uid) + continue; + if (tymatchrank(tm->filter[i], ty) >= 0) + return; + } + if (!tm->sub[ty->type]) + break; + assert(ty->nsub == 1); + tm = tm->sub[ty->type]; + ty = ty->sub[0]; + } + if (base->type != Tyname) + break; + base = base->sub[0]; } + fatal(ctx, "%s needs trait %s near %s", tystr(ty), namestr(tr->name), ctxstr(ctx)); } -static int -satisfiestraits(Type *a, Type *b) +static void +traitsfor(Type *base, Bitset *dst) { - if (!a->traits || bscount(a->traits) == 0) - return 1; - if (b->traits) - return bsissubset(a->traits, b->traits); - return 0; + Traitmap *tm; + Type *ty; + size_t i; + + while(1) { + ty = base; + tm = traitmap->sub[ty->type]; + while (1) { + if (ty->type == Tyvar) + break; + bsunion(dst, tm->traits); + for (i = 0; i < tm->nfilter; i++) { + if (tymatchrank(tm->filter[i], ty) >= 0) + bsput(dst, tm->filtertr[i]->uid); + } + if (!tm->sub[ty->type] || ty->nsub != 1) + break; + tm = tm->sub[ty->type]; + ty = ty->sub[0]; + } + if (base->type != Tyname) + break; + base = base->sub[0]; + } } static void verifytraits(Node *ctx, Type *a, Type *b) { + char traitbuf[64], abuf[64], bbuf[64]; + char asrc[64], bsrc[64]; + Bitset *abs, *bbs; size_t i, n; Srcloc l; char *sep; - char traitbuf[64], abuf[64], bbuf[64]; - char asrc[64], bsrc[64]; - if (!satisfiestraits(a, b)) { + abs = a->trneed; + if (!abs) { + abs = mkbs(); + traitsfor(a, abs); + } + bbs = b->trneed; + if (!bbs) { + bbs = mkbs(); + traitsfor(b, bbs); + } + if (!bsissubset(abs, bbs)) { sep = ""; n = 0; - for (i = 0; bsiter(a->traits, &i); i++) { - if (!b->traits || !bshas(b->traits, i)) + *traitbuf = 0; + for (i = 0; bsiter(abs, &i); i++) { + if (!bshas(bbs, i)) n += bprintf(traitbuf + n, sizeof(traitbuf) - n, "%s%s", sep, - namestr(traittab[i]->name)); + namestr(traittab[i]->name)); sep = ","; } tyfmt(abuf, sizeof abuf, a); @@ -792,24 +871,28 @@ verifytraits(Node *ctx, Type *a, Type *b) l = unifysrc[b->tid]; snprintf(bsrc, sizeof asrc, "\n\t%s from %s:%d", bbuf, fname(l), lnum(l)); } - fatal(ctx, "%s missing traits %s for %s near %s%s%s", - bbuf, traitbuf, abuf, ctxstr(ctx), - srcstr(a), srcstr(b)); + fatal(ctx, "%s needs trait %s near %s%s%s", + bbuf, traitbuf, ctxstr(ctx), srcstr(a), srcstr(b)); } + if (!a->trneed) + bsfree(abs); + if (!b->trneed) + bsfree(bbs); } /* Merges the constraints on types */ static void mergetraits(Node *ctx, Type *a, Type *b) { +// TRFIX if (b->type == Tyvar) { /* make sure that if a = b, both have same traits */ - if (a->traits && b->traits) - bsunion(b->traits, a->traits); - else if (a->traits) - b->traits = bsdup(a->traits); - else if (b->traits) - a->traits = bsdup(b->traits); + if (a->trneed && b->trneed) + bsunion(b->trneed, a->trneed); + else if (a->trneed) + b->trneed = bsdup(a->trneed); + else if (b->trneed) + a->trneed = bsdup(b->trneed); } else { verifytraits(ctx, a, b); } @@ -963,7 +1046,6 @@ unify(Node *ctx, Type *u, Type *v) Type *t, *r; Type *a, *b; Type *ea, *eb; - char *from, *to; size_t i; /* a ==> b */ @@ -979,16 +1061,6 @@ unify(Node *ctx, Type *u, Type *v) b = t; } - if (debugopt['u']) { - from = tystr(a); - to = tystr(b); - indentf(indentdepth, "Unify %s => %s\n", from, to); - indentf(indentdepth + 1, "indexes: %s => %s\n", - tystr(htget(seqbase, a)), tystr(htget(seqbase, b))); - free(from); - free(to); - } - /* Disallow recursive types */ if (a->type == Tyvar && b->type != Tyvar) { if (occursin(a, b)) @@ -1075,7 +1147,6 @@ unifycall(Node *n) { size_t i; Type *ft; - char *ret, *ctx; ft = type(n->expr.args[0]); @@ -1100,14 +1171,6 @@ unifycall(Node *n) if (i < ft->nsub && ft->sub[i]->type != Tyvalist) fatal(n, "%s arity mismatch (expected %zd args, got %zd)", ctxstr(n->expr.args[0]), ft->nsub - 1, i - 1); - if (debugopt['u']) { - ret = tystr(ft->sub[0]); - ctx = ctxstr(n->expr.args[0]); - indentf(indentdepth, "Call of %s returns %s\n", ctx, ret); - free(ctx); - free(ret); - } - settype(n, ft->sub[0]); } @@ -1811,10 +1874,6 @@ specializeimpl(Node *n) lappend(&proto->decl.gimpl, &proto->decl.ngimpl, dcl); lappend(&proto->decl.gtype, &proto->decl.ngtype, ty); } - if (debugopt['S']) - printf("specializing trait [%d]%s:%s => %s:%s\n", n->loc.line, - namestr(proto->decl.name), tystr(type(proto)), namestr(name), - tystr(ty)); dcl->decl.vis = t->vis; lappend(&impldecl, &nimpldecl, dcl); @@ -1901,8 +1960,6 @@ infernode(Node **np, Type *ret, int *sawret) popstab(); break; case Ndecl: - if (debugopt['u']) - indentf(indentdepth, "--- infer %s ---\n", declname(n)); if (n->decl.isgeneric) ingeneric++; indentdepth++; @@ -1915,8 +1972,6 @@ infernode(Node **np, Type *ret, int *sawret) constrain(n, type(n), traittab[Tcdisp]); popenv(n->decl.env); indentdepth--; - if (debugopt['u']) - indentf(indentdepth, "--- done ---\n"); if (n->decl.isgeneric) ingeneric--; break; @@ -2011,7 +2066,6 @@ tyfix(Node *ctx, Type *orig, int noerr) static Type *tyint, *tyflt; Type *t, *d, *base; Tyenv *env; - char *from, *to; size_t i; char buf[1024]; @@ -2036,10 +2090,10 @@ tyfix(Node *ctx, Type *orig, int noerr) tystr(d), ctxstr(ctx)); } } - if (t->type == Tyvar) { - if (hastrait(t, traittab[Tcint]) && satisfiestraits(t, tyint)) + if (t->type == Tyvar && t->trneed) { + if (bshas(t->trneed, Tcint) && bshas(t->trneed, Tcnum)) t = tyint; - if (hastrait(t, traittab[Tcfloat]) && satisfiestraits(t, tyflt)) + else if (bshas(t->trneed, Tcflt) && bshas(t->trneed, Tcnum)) t = tyflt; } else if (!t->fixed) { t->fixed = 1; @@ -2066,19 +2120,8 @@ tyfix(Node *ctx, Type *orig, int noerr) t->sub[i] = tyfix(ctx, t->sub[i], noerr); } - if (t->type == Tyvar && !noerr) { - if (debugopt['T']) - dump(file, stdout); + if (t->type == Tyvar && !noerr) fatal(ctx, "underconstrained type %s near %s", tyfmt(buf, 1024, t), ctxstr(ctx)); - } - - if (debugopt['u'] && !tyeq(orig, t)) { - from = tystr(orig); - to = tystr(t); - indentf(indentdepth, "subst %s => %s\n", from, to); - free(from); - free(to); - } if (base) htput(seqbase, t, base); if (env) @@ -2439,7 +2482,7 @@ typesub(Node *n, int noerr) typesub(n->iterstmt.elt, noerr); typesub(n->iterstmt.seq, noerr); typesub(n->iterstmt.body, noerr); - additerspecializations(n, curstab()); + additerspecialization(n, curstab()); break; case Nmatchstmt: typesub(n->matchstmt.val, noerr); @@ -2560,39 +2603,141 @@ specialize(void) popstab(); } } +static void +builtintraits(void) +{ + size_t i; + + /* char::(numeric,integral) */ + for (i = 0; i < Ntypes; i++) { + traitmap->sub[i] = zalloc(sizeof(Traitmap)); + traitmap->sub[i]->traits = mkbs(); + traitmap->sub[i]->name = mkht(namehash, nameeq); + } + + bsput(traitmap->sub[Tychar]->traits, Tcnum); + bsput(traitmap->sub[Tychar]->traits, Tcint); + + bsput(traitmap->sub[Tybyte]->traits, Tcnum); + bsput(traitmap->sub[Tybyte]->traits, Tcint); + + /* <integer types>::(numeric,integral) */ + for (i = Tyint8; i < Tyflt32; i++) { + bsput(traitmap->sub[i]->traits, Tcnum); + bsput(traitmap->sub[i]->traits, Tcint); + } + + /* <floats>::(numeric,floating) */ + bsput(traitmap->sub[Tyflt32]->traits, Tcnum); + bsput(traitmap->sub[Tyflt32]->traits, Tcflt); + bsput(traitmap->sub[Tyflt64]->traits, Tcnum); + bsput(traitmap->sub[Tyflt64]->traits, Tcflt); + + /* @a*::(sliceable) */ + bsput(traitmap->sub[Typtr]->traits, Tcslice); + + /* @a[:]::(indexable,sliceable) */ + bsput(traitmap->sub[Tyslice]->traits, Tcidx); + bsput(traitmap->sub[Tyslice]->traits, Tcslice); + bsput(traitmap->sub[Tyslice]->traits, Tciter); + + /* @a[SZ]::(indexable,sliceable) */ + bsput(traitmap->sub[Tyarray]->traits, Tcidx); + bsput(traitmap->sub[Tyarray]->traits, Tcslice); + bsput(traitmap->sub[Tyarray]->traits, Tciter); + + /* @a::function */ + bsput(traitmap->sub[Tyfunc]->traits, Tcfunc); +} + +static Trait* +findtrait(Node *impl) +{ + Trait *tr; + Node *n; + Stab *ns; + + tr = impl->impl.trait; + if (!tr) { + n = impl->impl.traitname; + ns = file->file.globls; + if (n->name.ns) + ns = getns(file, n->name.ns); + if (ns) + tr = gettrait(ns, n); + if (!tr) + fatal(impl, "trait %s does not exist near %s", + namestr(impl->impl.traitname), ctxstr(impl)); + if (tr->naux != impl->impl.naux) + fatal(impl, "incompatible implementation of %s: mismatched aux types", + namestr(impl->impl.traitname), ctxstr(impl)); + } + return tr; +} static void -applytraits(void) +addtraittab(Traitmap *m, Trait *tr, Type *ty) +{ + Bitset *bs; + Traitmap *mm; + + if (!m->sub[ty->type]) { + m->sub[ty->type] = zalloc(sizeof(Traitmap)); + m->sub[ty->type]->traits = mkbs(); + m->sub[ty->type]->name = mkht(namehash, nameeq); + } + mm = m->sub[ty->type]; + switch (ty->type) { + case Tygeneric: + case Typaram: + lappend(&mm->filter, &m->nfilter, ty); + lappend(&mm->filtertr, &m->nfiltertr, tr); + break; + case Tyname: + if (ty->ngparam == 0) { + bs = htget(mm->name, ty->name); + if (!bs) { + bs = mkbs(); + htput(mm->name, ty->name, bs); + } + bsput(bs, tr->uid); + } else { + lappend(&mm->filter, &m->nfilter, ty); + lappend(&mm->filtertr, &m->nfiltertr, tr); + } + break; + case Typtr: + case Tyarray: + addtraittab(mm, tr, ty->sub[0]); + break; + default: + if (istyprimitive(ty)) { + bsput(mm->traits, tr->uid); + } else { + lappend(&mm->filter, &m->nfilter, ty); + lappend(&mm->filtertr, &m->nfiltertr, tr); + } + } +} + +static void +initimpl(void) { size_t i; - Node *impl, *n; + Node *impl; Trait *tr; Type *ty; - Stab *ns; - tr = NULL; pushstab(file->file.globls); - /* for now, traits can only be declared globally */ + traitmap = zalloc(sizeof(Traitmap)); + builtintraits(); for (i = 0; i < nimpltab; i++) { impl = impltab[i]; - tr = impl->impl.trait; - if (!tr) { - n = impl->impl.traitname; - ns = file->file.globls; - if (n->name.ns) - ns = getns(file, n->name.ns); - if (ns) - tr = gettrait(ns, n); - if (!tr) - fatal(impl, "trait %s does not exist near %s", - namestr(impl->impl.traitname), ctxstr(impl)); - if (tr->naux != impl->impl.naux) - fatal(impl, "incompatible implementation of %s: mismatched aux types", - namestr(impl->impl.traitname), ctxstr(impl)); - } + tr = findtrait(impl); + pushenv(impl->impl.env); ty = tf(impl->impl.type); - settrait(ty, tr); + addtraittab(traitmap, tr, ty); if (tr->uid == Tciter) { htput(seqbase, tf(impl->impl.type), tf(impl->impl.aux[0])); } @@ -2601,7 +2746,7 @@ applytraits(void) popstab(); } -void +static void verify(void) { Type *t; @@ -2630,15 +2775,14 @@ verify(void) } void -infer() +infer(void) { delayed = mkht(tyhash, tyeq); seqbase = mkht(tyhash, tyeq); - /* set up the symtabs */ loaduses(); + initimpl(); /* do the inference */ - applytraits(); infernode(&file, NULL, NULL); postinfer(); |