summaryrefslogtreecommitdiff
path: root/parse/infer.c
diff options
context:
space:
mode:
Diffstat (limited to 'parse/infer.c')
-rw-r--r--parse/infer.c378
1 files changed, 261 insertions, 117 deletions
diff --git a/parse/infer.c b/parse/infer.c
index 29b3006..6d15240 100644
--- a/parse/infer.c
+++ b/parse/infer.c
@@ -15,6 +15,21 @@
#include "util.h"
#include "parse.h"
+typedef struct Traitmap Traitmap;
+typedef struct Enttype Enttype;
+
+struct Traitmap {
+ Bitset *traits;
+ Traitmap *sub[Ntypes];
+
+ Htab *name; /* name => bitset(traits) */
+ Type **filter;
+ size_t nfilter;
+ Trait **filtertr;
+ size_t nfiltertr;
+};
+
+
static void infernode(Node **np, Type *ret, int *sawret);
static void inferexpr(Node **np, Type *ret, int *sawret);
static void inferdecl(Node *n);
@@ -51,6 +66,7 @@ static size_t nspecializations;
static Stab **specializationscope;
static size_t nspecializationscope;
static Htab *seqbase;
+static Traitmap *traitmap;
static void
ctxstrcall(char *buf, size_t sz, Node *n)
@@ -239,8 +255,8 @@ adddispspecialization(Node *n, Stab *stab)
tr = traittab[Tcdisp];
ty = decltype(n);
- if (!ty->traits || !bshas(ty->traits, Tcdisp))
- return;
+ //if (!ty->traits || !bshas(ty->traits, Tcdisp))
+ // return;
assert(tr->nproto == 1);
if (hthas(tr->proto[0]->decl.impls, ty))
return;
@@ -250,7 +266,7 @@ adddispspecialization(Node *n, Stab *stab)
}
static void
-additerspecializations(Node *n, Stab *stab)
+additerspecialization(Node *n, Stab *stab)
{
Trait *tr;
Type *ty;
@@ -258,8 +274,8 @@ additerspecializations(Node *n, Stab *stab)
tr = traittab[Tciter];
ty = exprtype(n->iterstmt.seq);
- if (!ty->traits || !bshas(ty->traits, Tciter))
- return;
+ //if (!ty->traits || !bshas(ty->traits, Tciter))
+ // return;
if (ty->type == Tyslice || ty->type == Tyarray || ty->type == Typtr)
return;
for (i = 0; i < tr->nproto; i++) {
@@ -516,11 +532,6 @@ tyresolve(Type *t)
}
}
base = tybase(t);
- /* no-ops if base == t */
- if (t->traits && base->traits)
- bsunion(t->traits, base->traits);
- else if (base->traits)
- t->traits = bsdup(base->traits);
if (occurs(t))
lfatal(t->loc, "type %s includes itself", tystr(t));
popenv(t->env);
@@ -642,6 +653,17 @@ settype(Node *n, Type *t)
marksrc(t, n->loc);
}
+static Type*
+mktylike(Srcloc l, Ty other)
+{
+ Type *t;
+
+ t = mktyvar(l);
+ /* not perfect in general, but good enough for all places mktylike is used. */
+ t->trneed = bsdup(traitmap->sub[other]->traits);
+ return t;
+}
+
/* Gets the type of a literal value */
static Type *
littype(Node *n)
@@ -669,19 +691,11 @@ static Type *
delayeducon(Type *fallback)
{
Type *t;
- char *from, *to;
if (fallback->type != Tyunion)
return fallback;
t = mktylike(fallback->loc, fallback->type);
htput(delayed, t, fallback);
- if (debugopt['u']) {
- from = tystr(t);
- to = tystr(fallback);
- indentf(indentdepth, "Delay %s -> %s\n", from, to);
- free(from);
- free(to);
- }
return t;
}
@@ -746,43 +760,108 @@ _bind(Tyenv *e, Node *n)
* constraint list. Otherwise, the type is checked to see
* if it has the required constraint */
static void
-constrain(Node *ctx, Type *a, Trait *c)
+constrain(Node *ctx, Type *base, Trait *tr)
{
- if (a->type == Tyvar) {
- if (!a->traits)
- a->traits = mkbs();
- settrait(a, c);
- } else if (!a->traits || !bshas(a->traits, c->uid)) {
- fatal(ctx, "%s needs %s near %s", tystr(a), namestr(c->name), ctxstr(ctx));
+ Traitmap *tm;
+ Bitset *bs;
+ Type *ty;
+ size_t i;
+
+ while(1) {
+ ty = base;
+ tm = traitmap->sub[ty->type];
+ while (1) {
+ if (ty->type == Typaram && bshas(ty->trneed, tr->uid))
+ return;
+ if (ty->type == Tyvar) {
+ if (!ty->trneed)
+ ty->trneed = mkbs();
+ bsput(ty->trneed, tr->uid);
+ return;
+ }
+ if (bshas(tm->traits, tr->uid))
+ return;
+ if (tm->name && ty->type == Tyname) {
+ bs = htget(tm->name, ty->name);
+ if (bs && bshas(bs, tr->uid))
+ return;
+ }
+ for (i = 0; i < tm->nfilter; i++) {
+ if (tm->filtertr[i]->uid != tr->uid)
+ continue;
+ if (tymatchrank(tm->filter[i], ty) >= 0)
+ return;
+ }
+ if (!tm->sub[ty->type])
+ break;
+ assert(ty->nsub == 1);
+ tm = tm->sub[ty->type];
+ ty = ty->sub[0];
+ }
+ if (base->type != Tyname)
+ break;
+ base = base->sub[0];
}
+ fatal(ctx, "%s needs trait %s near %s", tystr(ty), namestr(tr->name), ctxstr(ctx));
}
-static int
-satisfiestraits(Type *a, Type *b)
+static void
+traitsfor(Type *base, Bitset *dst)
{
- if (!a->traits || bscount(a->traits) == 0)
- return 1;
- if (b->traits)
- return bsissubset(a->traits, b->traits);
- return 0;
+ Traitmap *tm;
+ Type *ty;
+ size_t i;
+
+ while(1) {
+ ty = base;
+ tm = traitmap->sub[ty->type];
+ while (1) {
+ if (ty->type == Tyvar)
+ break;
+ bsunion(dst, tm->traits);
+ for (i = 0; i < tm->nfilter; i++) {
+ if (tymatchrank(tm->filter[i], ty) >= 0)
+ bsput(dst, tm->filtertr[i]->uid);
+ }
+ if (!tm->sub[ty->type] || ty->nsub != 1)
+ break;
+ tm = tm->sub[ty->type];
+ ty = ty->sub[0];
+ }
+ if (base->type != Tyname)
+ break;
+ base = base->sub[0];
+ }
}
static void
verifytraits(Node *ctx, Type *a, Type *b)
{
+ char traitbuf[64], abuf[64], bbuf[64];
+ char asrc[64], bsrc[64];
+ Bitset *abs, *bbs;
size_t i, n;
Srcloc l;
char *sep;
- char traitbuf[64], abuf[64], bbuf[64];
- char asrc[64], bsrc[64];
- if (!satisfiestraits(a, b)) {
+ abs = a->trneed;
+ if (!abs) {
+ abs = mkbs();
+ traitsfor(a, abs);
+ }
+ bbs = b->trneed;
+ if (!bbs) {
+ bbs = mkbs();
+ traitsfor(b, bbs);
+ }
+ if (!bsissubset(abs, bbs)) {
sep = "";
n = 0;
- for (i = 0; bsiter(a->traits, &i); i++) {
- if (!b->traits || !bshas(b->traits, i))
+ *traitbuf = 0;
+ for (i = 0; bsiter(abs, &i); i++) {
+ if (!bshas(bbs, i))
n += bprintf(traitbuf + n, sizeof(traitbuf) - n, "%s%s", sep,
- namestr(traittab[i]->name));
+ namestr(traittab[i]->name));
sep = ",";
}
tyfmt(abuf, sizeof abuf, a);
@@ -792,24 +871,28 @@ verifytraits(Node *ctx, Type *a, Type *b)
l = unifysrc[b->tid];
snprintf(bsrc, sizeof asrc, "\n\t%s from %s:%d", bbuf, fname(l), lnum(l));
}
- fatal(ctx, "%s missing traits %s for %s near %s%s%s",
- bbuf, traitbuf, abuf, ctxstr(ctx),
- srcstr(a), srcstr(b));
+ fatal(ctx, "%s needs trait %s near %s%s%s",
+ bbuf, traitbuf, ctxstr(ctx), srcstr(a), srcstr(b));
}
+ if (!a->trneed)
+ bsfree(abs);
+ if (!b->trneed)
+ bsfree(bbs);
}
/* Merges the constraints on types */
static void
mergetraits(Node *ctx, Type *a, Type *b)
{
+// TRFIX
if (b->type == Tyvar) {
/* make sure that if a = b, both have same traits */
- if (a->traits && b->traits)
- bsunion(b->traits, a->traits);
- else if (a->traits)
- b->traits = bsdup(a->traits);
- else if (b->traits)
- a->traits = bsdup(b->traits);
+ if (a->trneed && b->trneed)
+ bsunion(b->trneed, a->trneed);
+ else if (a->trneed)
+ b->trneed = bsdup(a->trneed);
+ else if (b->trneed)
+ a->trneed = bsdup(b->trneed);
} else {
verifytraits(ctx, a, b);
}
@@ -963,7 +1046,6 @@ unify(Node *ctx, Type *u, Type *v)
Type *t, *r;
Type *a, *b;
Type *ea, *eb;
- char *from, *to;
size_t i;
/* a ==> b */
@@ -979,16 +1061,6 @@ unify(Node *ctx, Type *u, Type *v)
b = t;
}
- if (debugopt['u']) {
- from = tystr(a);
- to = tystr(b);
- indentf(indentdepth, "Unify %s => %s\n", from, to);
- indentf(indentdepth + 1, "indexes: %s => %s\n",
- tystr(htget(seqbase, a)), tystr(htget(seqbase, b)));
- free(from);
- free(to);
- }
-
/* Disallow recursive types */
if (a->type == Tyvar && b->type != Tyvar) {
if (occursin(a, b))
@@ -1075,7 +1147,6 @@ unifycall(Node *n)
{
size_t i;
Type *ft;
- char *ret, *ctx;
ft = type(n->expr.args[0]);
@@ -1100,14 +1171,6 @@ unifycall(Node *n)
if (i < ft->nsub && ft->sub[i]->type != Tyvalist)
fatal(n, "%s arity mismatch (expected %zd args, got %zd)",
ctxstr(n->expr.args[0]), ft->nsub - 1, i - 1);
- if (debugopt['u']) {
- ret = tystr(ft->sub[0]);
- ctx = ctxstr(n->expr.args[0]);
- indentf(indentdepth, "Call of %s returns %s\n", ctx, ret);
- free(ctx);
- free(ret);
- }
-
settype(n, ft->sub[0]);
}
@@ -1811,10 +1874,6 @@ specializeimpl(Node *n)
lappend(&proto->decl.gimpl, &proto->decl.ngimpl, dcl);
lappend(&proto->decl.gtype, &proto->decl.ngtype, ty);
}
- if (debugopt['S'])
- printf("specializing trait [%d]%s:%s => %s:%s\n", n->loc.line,
- namestr(proto->decl.name), tystr(type(proto)), namestr(name),
- tystr(ty));
dcl->decl.vis = t->vis;
lappend(&impldecl, &nimpldecl, dcl);
@@ -1901,8 +1960,6 @@ infernode(Node **np, Type *ret, int *sawret)
popstab();
break;
case Ndecl:
- if (debugopt['u'])
- indentf(indentdepth, "--- infer %s ---\n", declname(n));
if (n->decl.isgeneric)
ingeneric++;
indentdepth++;
@@ -1915,8 +1972,6 @@ infernode(Node **np, Type *ret, int *sawret)
constrain(n, type(n), traittab[Tcdisp]);
popenv(n->decl.env);
indentdepth--;
- if (debugopt['u'])
- indentf(indentdepth, "--- done ---\n");
if (n->decl.isgeneric)
ingeneric--;
break;
@@ -2011,7 +2066,6 @@ tyfix(Node *ctx, Type *orig, int noerr)
static Type *tyint, *tyflt;
Type *t, *d, *base;
Tyenv *env;
- char *from, *to;
size_t i;
char buf[1024];
@@ -2036,10 +2090,10 @@ tyfix(Node *ctx, Type *orig, int noerr)
tystr(d), ctxstr(ctx));
}
}
- if (t->type == Tyvar) {
- if (hastrait(t, traittab[Tcint]) && satisfiestraits(t, tyint))
+ if (t->type == Tyvar && t->trneed) {
+ if (bshas(t->trneed, Tcint) && bshas(t->trneed, Tcnum))
t = tyint;
- if (hastrait(t, traittab[Tcfloat]) && satisfiestraits(t, tyflt))
+ else if (bshas(t->trneed, Tcflt) && bshas(t->trneed, Tcnum))
t = tyflt;
} else if (!t->fixed) {
t->fixed = 1;
@@ -2066,19 +2120,8 @@ tyfix(Node *ctx, Type *orig, int noerr)
t->sub[i] = tyfix(ctx, t->sub[i], noerr);
}
- if (t->type == Tyvar && !noerr) {
- if (debugopt['T'])
- dump(file, stdout);
+ if (t->type == Tyvar && !noerr)
fatal(ctx, "underconstrained type %s near %s", tyfmt(buf, 1024, t), ctxstr(ctx));
- }
-
- if (debugopt['u'] && !tyeq(orig, t)) {
- from = tystr(orig);
- to = tystr(t);
- indentf(indentdepth, "subst %s => %s\n", from, to);
- free(from);
- free(to);
- }
if (base)
htput(seqbase, t, base);
if (env)
@@ -2439,7 +2482,7 @@ typesub(Node *n, int noerr)
typesub(n->iterstmt.elt, noerr);
typesub(n->iterstmt.seq, noerr);
typesub(n->iterstmt.body, noerr);
- additerspecializations(n, curstab());
+ additerspecialization(n, curstab());
break;
case Nmatchstmt:
typesub(n->matchstmt.val, noerr);
@@ -2560,39 +2603,141 @@ specialize(void)
popstab();
}
}
+static void
+builtintraits(void)
+{
+ size_t i;
+
+ /* char::(numeric,integral) */
+ for (i = 0; i < Ntypes; i++) {
+ traitmap->sub[i] = zalloc(sizeof(Traitmap));
+ traitmap->sub[i]->traits = mkbs();
+ traitmap->sub[i]->name = mkht(namehash, nameeq);
+ }
+
+ bsput(traitmap->sub[Tychar]->traits, Tcnum);
+ bsput(traitmap->sub[Tychar]->traits, Tcint);
+
+ bsput(traitmap->sub[Tybyte]->traits, Tcnum);
+ bsput(traitmap->sub[Tybyte]->traits, Tcint);
+
+ /* <integer types>::(numeric,integral) */
+ for (i = Tyint8; i < Tyflt32; i++) {
+ bsput(traitmap->sub[i]->traits, Tcnum);
+ bsput(traitmap->sub[i]->traits, Tcint);
+ }
+
+ /* <floats>::(numeric,floating) */
+ bsput(traitmap->sub[Tyflt32]->traits, Tcnum);
+ bsput(traitmap->sub[Tyflt32]->traits, Tcflt);
+ bsput(traitmap->sub[Tyflt64]->traits, Tcnum);
+ bsput(traitmap->sub[Tyflt64]->traits, Tcflt);
+
+ /* @a*::(sliceable) */
+ bsput(traitmap->sub[Typtr]->traits, Tcslice);
+
+ /* @a[:]::(indexable,sliceable) */
+ bsput(traitmap->sub[Tyslice]->traits, Tcidx);
+ bsput(traitmap->sub[Tyslice]->traits, Tcslice);
+ bsput(traitmap->sub[Tyslice]->traits, Tciter);
+
+ /* @a[SZ]::(indexable,sliceable) */
+ bsput(traitmap->sub[Tyarray]->traits, Tcidx);
+ bsput(traitmap->sub[Tyarray]->traits, Tcslice);
+ bsput(traitmap->sub[Tyarray]->traits, Tciter);
+
+ /* @a::function */
+ bsput(traitmap->sub[Tyfunc]->traits, Tcfunc);
+}
+
+static Trait*
+findtrait(Node *impl)
+{
+ Trait *tr;
+ Node *n;
+ Stab *ns;
+
+ tr = impl->impl.trait;
+ if (!tr) {
+ n = impl->impl.traitname;
+ ns = file->file.globls;
+ if (n->name.ns)
+ ns = getns(file, n->name.ns);
+ if (ns)
+ tr = gettrait(ns, n);
+ if (!tr)
+ fatal(impl, "trait %s does not exist near %s",
+ namestr(impl->impl.traitname), ctxstr(impl));
+ if (tr->naux != impl->impl.naux)
+ fatal(impl, "incompatible implementation of %s: mismatched aux types",
+ namestr(impl->impl.traitname), ctxstr(impl));
+ }
+ return tr;
+}
static void
-applytraits(void)
+addtraittab(Traitmap *m, Trait *tr, Type *ty)
+{
+ Bitset *bs;
+ Traitmap *mm;
+
+ if (!m->sub[ty->type]) {
+ m->sub[ty->type] = zalloc(sizeof(Traitmap));
+ m->sub[ty->type]->traits = mkbs();
+ m->sub[ty->type]->name = mkht(namehash, nameeq);
+ }
+ mm = m->sub[ty->type];
+ switch (ty->type) {
+ case Tygeneric:
+ case Typaram:
+ lappend(&mm->filter, &m->nfilter, ty);
+ lappend(&mm->filtertr, &m->nfiltertr, tr);
+ break;
+ case Tyname:
+ if (ty->ngparam == 0) {
+ bs = htget(mm->name, ty->name);
+ if (!bs) {
+ bs = mkbs();
+ htput(mm->name, ty->name, bs);
+ }
+ bsput(bs, tr->uid);
+ } else {
+ lappend(&mm->filter, &m->nfilter, ty);
+ lappend(&mm->filtertr, &m->nfiltertr, tr);
+ }
+ break;
+ case Typtr:
+ case Tyarray:
+ addtraittab(mm, tr, ty->sub[0]);
+ break;
+ default:
+ if (istyprimitive(ty)) {
+ bsput(mm->traits, tr->uid);
+ } else {
+ lappend(&mm->filter, &m->nfilter, ty);
+ lappend(&mm->filtertr, &m->nfiltertr, tr);
+ }
+ }
+}
+
+static void
+initimpl(void)
{
size_t i;
- Node *impl, *n;
+ Node *impl;
Trait *tr;
Type *ty;
- Stab *ns;
- tr = NULL;
pushstab(file->file.globls);
- /* for now, traits can only be declared globally */
+ traitmap = zalloc(sizeof(Traitmap));
+ builtintraits();
for (i = 0; i < nimpltab; i++) {
impl = impltab[i];
- tr = impl->impl.trait;
- if (!tr) {
- n = impl->impl.traitname;
- ns = file->file.globls;
- if (n->name.ns)
- ns = getns(file, n->name.ns);
- if (ns)
- tr = gettrait(ns, n);
- if (!tr)
- fatal(impl, "trait %s does not exist near %s",
- namestr(impl->impl.traitname), ctxstr(impl));
- if (tr->naux != impl->impl.naux)
- fatal(impl, "incompatible implementation of %s: mismatched aux types",
- namestr(impl->impl.traitname), ctxstr(impl));
- }
+ tr = findtrait(impl);
+
pushenv(impl->impl.env);
ty = tf(impl->impl.type);
- settrait(ty, tr);
+ addtraittab(traitmap, tr, ty);
if (tr->uid == Tciter) {
htput(seqbase, tf(impl->impl.type), tf(impl->impl.aux[0]));
}
@@ -2601,7 +2746,7 @@ applytraits(void)
popstab();
}
-void
+static void
verify(void)
{
Type *t;
@@ -2630,15 +2775,14 @@ verify(void)
}
void
-infer()
+infer(void)
{
delayed = mkht(tyhash, tyeq);
seqbase = mkht(tyhash, tyeq);
- /* set up the symtabs */
loaduses();
+ initimpl();
/* do the inference */
- applytraits();
infernode(&file, NULL, NULL);
postinfer();