summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOri Bernstein <ori@eigenstate.org>2018-04-08 00:59:24 -0700
committerOri Bernstein <ori@eigenstate.org>2018-04-08 00:59:24 -0700
commitf9f93d1e447873ca3e5fa6c542eb34e8dd8d4b71 (patch)
treeeb452e3f633041a440055108c89808eef5c86b41
parent185f780a03fbfbb4655b7c07b3ac147980cede2d (diff)
downloadmc-f9f93d1e447873ca3e5fa6c542eb34e8dd8d4b71.tar.gz
Constant time modpow.
-rw-r--r--lib/crypto/ct.myr2
-rw-r--r--lib/crypto/ctbig.myr247
-rw-r--r--lib/crypto/test/ctbig.myr142
-rw-r--r--lib/std/hashfuncs.myr6
4 files changed, 343 insertions, 54 deletions
diff --git a/lib/crypto/ct.myr b/lib/crypto/ct.myr
index 28be694..d6fe34d 100644
--- a/lib/crypto/ct.myr
+++ b/lib/crypto/ct.myr
@@ -53,7 +53,7 @@ generic le = {a, b
generic ne = {a, b
const nshift = 8*sizeof(@t) - 1
var q = a ^ b
- -> ((q | -q) >> nshift)^1
+ -> (q | -q) >> nshift
}
generic mux = {c, a, b
diff --git a/lib/crypto/ctbig.myr b/lib/crypto/ctbig.myr
index 64c6702..d9301b2 100644
--- a/lib/crypto/ctbig.myr
+++ b/lib/crypto/ctbig.myr
@@ -1,4 +1,5 @@
use std
+use iter
use "ct"
@@ -25,7 +26,7 @@ pkg crypto =
const ctsub : (r : ctbig#, a : ctbig#, b : ctbig# -> void)
const ctmul : (r : ctbig#, a : ctbig#, b : ctbig# -> void)
//const ctdivmod : (q : ctbig#, u : ctbig#, a : ctbig#, b : ctbig# -> void)
- //const ctmodpow : (r : ctbig#, a : ctbig#, b : ctbig# -> void)
+ const ctmodpow : (r : ctbig#, a : ctbig#, b : ctbig#, m : ctbig# -> void)
const ctiszero : (v : ctbig# -> bool)
const cteq : (a : ctbig#, b : ctbig# -> bool)
@@ -35,6 +36,9 @@ pkg crypto =
const ctlt : (a : ctbig#, b : ctbig# -> bool)
const ctle : (a : ctbig#, b : ctbig# -> bool)
+ /* for testing */
+ const growmod : (r : ctbig#, a : ctbig#, k : uint32, m : ctbig# -> void)
+
impl std.equatable ctbig#
;;
@@ -59,8 +63,8 @@ const ctfmt = {sb, ap, opts
var ct : ctbig#
ct = std.vanext(ap)
- for d : ct.dig
- std.sbfmt(sb, "{w=8,p=0,x}", d)
+ for d : iter.byreverse(ct.dig)
+ std.sbfmt(sb, "{w=8,p=0,x}.", d)
;;
}
@@ -89,6 +93,13 @@ const ctzero = {nbit
])
}
+const ctdup = {v
+ -> std.mk([
+ .nbit=v.nbit,
+ .dig=std.sldup(v.dig)
+ ])
+}
+
const ct2big = {ct
-> std.mk([
.sign=1,
@@ -155,6 +166,10 @@ const ctfree = {v
}
const ctadd = {r, a, b
+ ctaddcc(r, a, b, 1)
+}
+
+const ctaddcc = {r, a, b, ctl
var v, i, carry
checksz(a, b)
@@ -163,12 +178,16 @@ const ctadd = {r, a, b
carry = 0
for i = 0; i < a.dig.len; i++
v = (a.dig[i] : uint64) + (b.dig[i] : uint64) + carry;
- r.dig[i] = (v : uint32)
+ r.dig[i] = mux(ctl, (v : uint32), r.dig[i])
carry = v >> 32
;;
}
const ctsub = {r, a, b
+ ctsubcc(r, a, b, 1)
+}
+
+const ctsubcc = {r, a, b, ctl
var borrow, v, i
checksz(a, b)
@@ -178,10 +197,10 @@ const ctsub = {r, a, b
for i = 0; i < a.dig.len; i++
v = (a.dig[i] : uint64) - (b.dig[i] : uint64) - borrow
borrow = (v & (1<<63)) >> 63
- v = mux(borrow, v + Base, v)
- r.dig[i] = (v : uint32)
+ r.dig[i] = mux(ctl, (v : uint32), r.dig[i])
;;
clip(r)
+ -> borrow
}
const ctmul = {r, a, b
@@ -215,6 +234,186 @@ const ctmul = {r, a, b
clip(r)
}
+/*
+ * Returns the top digit in the number that has
+ * a bit set. This is useful for finding our division.
+ */
+ const topfull = {n : ctbig#
+ var top
+
+ top = 0
+ for var i = 0; i < n.dig.len; i++
+ top = mux(n.dig[i], i, top)
+ ;;
+ -> 0
+}
+
+/*
+ * Multiplies by 2**32 mod m
+ */
+const growmod = {r, a, k, m
+ var a0, a1, b0, hi, g, q, tb, e
+ var chf, clow, under, over
+ var cc : uint64
+
+ checksz(a, m)
+ std.assert(a.dig.len > 1, "bad modulus")
+ std.assert(a.nbit % 32 == 0, "ragged sizes not yet supported")
+ //std.assert(a.dig[a.dig.len - 1] & (1 << 31) != 0, "top of mod not set")
+
+ a0 = (a.dig[m.dig.len - 1] : uint64) << 32
+ a1 = (a.dig[m.dig.len - 2] : uint64) << 0
+ b0 = (m.dig[m.dig.len - 1] : uint64)
+
+ /*
+ * We hold the top digit here, so
+ * this keeps the number of digits the same, and
+ * as a result, keeps checksz() happy.
+ */
+ hi = a.dig[a.dig.len - 1]
+
+ /* Do the multiplication of x by 2**32 */
+ std.slcp(r.dig[1:], a.dig[:a.dig.len-1])
+ r.dig[0] = k
+ g = ((a0 + a1) / b0 : uint32)
+ e = eq(a0, b0)
+ q = mux((e : uint32), 0xffffffff, mux(eq(g, 0), 0, g - 1));
+
+ cc = 0;
+ tb = 1;
+ for var u = 0; u < r.dig.len; u++
+ var mw, zw, xw, nxw
+ var zl : uint64
+
+ mw = m.dig[u];
+ zl = (mw : uint64) * (q : uint64) + cc
+ cc = zl >> 32
+ zw = (zl : uint32)
+ xw = r.dig[u]
+ nxw = xw - zw;
+ cc += (gt(nxw, xw) : uint64)
+ r.dig[u] = nxw;
+ tb = mux(eq(nxw, mw), tb, gt(nxw, mw));
+ ;;
+
+ /*
+ * We can either underestimate or overestimate q,
+ * - If we overestimated, either cc < hi, or cc == hi && tb != 0.
+ * - If we overestimated, cc > hi.
+ * - Otherwise, we got it exactly right.
+ *
+ * If we overestimated, we need to subtract 'm' once. If we
+ * underestimated, we need to add it once.
+ */
+ chf = (cc >> 32 : uint32)
+ clow = (cc >> 0 : uint32)
+ over = chf | gt(clow, hi);
+ under = ~over & (tb | (~chf & lt(clow, hi)));
+ ctaddcc(r, r, m, over);
+ ctsubcc(r, r, m, under);
+
+}
+
+const tomonty = {r, x, m
+ checksz(x, r)
+ checksz(x, m)
+
+ std.slcp(r.dig, x.dig)
+ for var i = 0; i < m.dig.len; i++
+ growmod(r, r, 0, m)
+ ;;
+}
+
+const ccopy = {r, v, ctl
+ checksz(r, v)
+ for var i = 0; i < r.dig.len; i++
+ r.dig[i] = mux(ctl, v.dig[i], r.dig[i])
+ ;;
+}
+
+const muladd = {a, b, k
+ -> (a : uint64) * (b : uint64) + (k : uint64)
+}
+
+const montymul = {r : ctbig#, x : ctbig#, y : ctbig#, m : ctbig#, m0i : uint32
+ var dh : uint64
+ var s
+
+ checksz(x, y)
+ checksz(x, m)
+ checksz(x, r)
+
+ std.slfill(r.dig, 0)
+ dh = 0
+ for var u = 0; u < x.dig.len; u++
+ var f : uint32, xu : uint32
+ var r1 : uint64, r2 : uint64, zh : uint64
+
+ xu = x.dig[u]
+ f = (r.dig[0] + x.dig[u] * y.dig[0]) * m0i;
+ r1 = 0;
+ r2 = 0;
+ for var v = 0; v < y.dig.len; v++
+ var z : uint64
+ var t : uint32
+
+ z = muladd(xu, y.dig[v], r.dig[v]) + r1
+ r1 = z >> 32
+ t = (z : uint32)
+ z = muladd(f, m.dig[v], t) + r2
+ r2 = z >> 32
+ if v != 0
+ r.dig[v - 1] = (z : uint32)
+ ;;
+ ;;
+ zh = dh + r1 + r2;
+ r.dig[r.dig.len - 1] = (zh : uint32)
+ dh = zh >> 32;
+ ;;
+
+ /*
+ * r may still be greater than m at that point; notably, the
+ * 'dh' word may be non-zero.
+ */
+ s = ne(dh, 0) | (ctge(r, m) : uint64)
+ ctsubcc(r, r, m, (s : uint32))
+}
+
+const ninv32 = {x
+ var y
+
+ y = 2 - x
+ y *= 2 - y * x
+ y *= 2 - y * x
+ y *= 2 - y * x
+ y *= 2 - y * x
+ -> mux(x & 1, -y, 0)
+}
+
+const ctmodpow = {r, a, e, m
+ var t1, t2, m0i, ctl, k, d
+ var n = 0
+
+ t1 = ctdup(a)
+ t2 = ctzero(a.nbit)
+ m0i = ninv32(m.dig[0])
+
+ tomonty(t1, a, m);
+ std.slfill(r.dig, 0);
+ r.dig[0] = 1;
+ for var i = 0; i < e.nbit; i++
+ k = (i : uint32)
+ d = e.dig[e.dig.len - (k>>5) - 1]
+ ctl = (d >> (k & 0x1f)) & 1
+ montymul(t2, r, t1, m, m0i)
+ ccopy(r, t2, ctl);
+ montymul(t2, t1, t1, m, m0i);
+ std.slcp(t1.dig, t2.dig);
+ ;;
+ ctfree(t1)
+ ctfree(t2)
+}
+
const ctiszero = {a
var z, zz
@@ -227,18 +426,14 @@ const ctiszero = {a
}
const cteq = {a, b
- var z, d, e
+ var ne
checksz(a, b)
-
- e = 1
+ ne = 0
for var i = 0; i < a.dig.len; i++
- z = a.dig[i] - b.dig[i]
- /* z != 0 ? 0 : 1 */
- d = mux(z, 0, 1)
- e = mux(e, d, 0)
+ ne = ne | a.dig[i] - b.dig[i]
;;
- -> (e : bool)
+ -> (not(ne) : bool)
}
const ctne = {a, b
@@ -249,17 +444,7 @@ const ctne = {a, b
}
const ctgt = {a, b
- var e, d, g
-
- checksz(a, b)
-
- g = 0
- for var i = 0; i < a.dig.len; i++
- e = not(a.dig[i] - b.dig[i])
- d = gt(a.dig[i], b.dig[i])
- g = mux(e, g, d)
- ;;
- -> (g : bool)
+ -> (ctsubcc(b, b, a, 0) : bool)
}
const ctge = {a, b
@@ -270,17 +455,7 @@ const ctge = {a, b
}
const ctlt = {a, b
- var e, d, l
-
- checksz(a, b)
-
- l = 0
- for var i = 0; i < a.dig.len; i++
- e = not(a.dig[i] - b.dig[i])
- d = gt(a.dig[i], b.dig[i])
- l = mux(e, l, d)
- ;;
- -> (l : bool)
+ -> (ctsubcc(a, a, b, 0) : bool)
}
const ctle = {a, b
diff --git a/lib/crypto/test/ctbig.myr b/lib/crypto/test/ctbig.myr
index ec55381..89b9616 100644
--- a/lib/crypto/test/ctbig.myr
+++ b/lib/crypto/test/ctbig.myr
@@ -9,60 +9,118 @@ const main = {
testr.run([
/* normal */
[.name="add", .fn={ctx
- do(ctx, crypto.ctadd, Nbit,
+ do2(ctx, crypto.ctadd, Nbit,
"5192296858610368357189246603769160",
"5192296858534810493479828944327220",
"75557863709417659441940")
}],
[.name="sub", .fn={ctx
- do(ctx, crypto.ctsub, Nbit,
+ do2(ctx, crypto.ctsub, Nbit,
"5192296858459252629770411284885280",
"5192296858534810493479828944327220",
"75557863709417659441940")
}],
[.name="mul", .fn={ctx
- do(ctx, crypto.ctmul, Nbit,
+ do2(ctx, crypto.ctmul, Nbit,
"392318858376010676506814412592879878824393346033951606800",
"5192296858534810493479828944327220",
"75557863709417659441940")
}],
-
+ [.name="growmod", .fn={ctx
+ do2(ctx, growmod0, Nbit,
+ "259016584597313952181375284077740334036",
+ "137304361882109849168381018424069802644",
+ "279268927326277818181333274586733399084")
+ }
+ ],
+ /* comparisons */
+ [.name="lt-less", .fn={ctx
+ dobool(ctx, crypto.ctlt, Nbit,
+ true,
+ "137304361882109849168381018424069802644",
+ "279268927326277818181333274586733399084")
+ }
+ ],
+ [.name="lt-equal", .fn={ctx
+ dobool(ctx, crypto.ctlt, Nbit,
+ false,
+ "137304361882109849168381018424069802644",
+ "137304361882109849168381018424069802644")
+ }
+ ],
+ [.name="lt-greater", .fn={ctx
+ dobool(ctx, crypto.ctlt, Nbit,
+ false,
+ "279268927326277818181333274586733399084",
+ "137304361882109849168381018424069802644")
+ }
+ ],
+ [.name="gt-less", .fn={ctx
+ dobool(ctx, crypto.ctgt, Nbit,
+ false,
+ "137304361882109849168381018424069802644",
+ "279268927326277818181333274586733399084")
+ }
+ ],
+ [.name="gt-equal", .fn={ctx
+ dobool(ctx, crypto.ctgt, Nbit,
+ false,
+ "137304361882109849168381018424069802644",
+ "137304361882109849168381018424069802644")
+ }
+ ],
+ [.name="gt-greater", .fn={ctx
+ dobool(ctx, crypto.ctgt, Nbit,
+ true,
+ "279268927326277818181333274586733399084",
+ "137304361882109849168381018424069802644")
+ }
+ ],
+
+ [.name="growmodsmall", .fn={ctx
+ do2(ctx, growmod0, Nbit,
+ "30064771072",
+ "7",
+ "279268927326277818181333274586733399084")
+ }
+ ],
[.name="addfunky", .fn={ctx
- do(ctx, crypto.ctadd, Nfunky,
+ do2(ctx, crypto.ctadd, Nfunky,
"75540728658750274549064",
"5192296858534810493479828944327220",
"75557863709417659441940")
}],
[.name="subfunky", .fn={ctx
- do(ctx, crypto.ctsub, Nfunky,
+ do2(ctx, crypto.ctsub, Nfunky,
"528887911047229543018272",
"5192296858534810493479828944327220",
"75557863709417659441940")
}],
[.name="mulfunky", .fn={ctx
- do(ctx, crypto.ctmul, Nfunky,
+ do2(ctx, crypto.ctmul, Nfunky,
"434472066238453871708176",
"5192296858534810493479828944327220",
"75557863709417659441940")
}],
//[.name="div", .fn={ctx
- // do(ctx, div,
+ // do2(ctx, div,
// "75557863709417659441940",
// "392318858376010676506814412592879878824393346033951606800",
// "5192296858534810493479828944327220")
//}],
//[.name="mod", .fn={ctx
- // do(ctx, mod,
+ // do2(ctx, mod,
// "75557863709417659441940",
// "392318858376010676506814412592879878824393346033951606800",
// "5192296858534810493479828944327220")
//}],
- //[.name="modpow", .fn={ctx
- // r = do(ctx, crypto.ctsub,
- // "5192296858459252629770411284885280"
- // "5192296858534810493479828944327220",
- // "75557863709417659441940")
- //}],
+ [.name="modpow", .fn={ctx
+ do3(ctx, crypto.ctmodpow, Nbit,
+ "1231231254019581241243091223098123",
+ "1231231254019581241243091223098123",
+ "1",
+ "238513807008428752753137056878245001837")
+ }],
][:])
}
@@ -80,8 +138,29 @@ const main = {
// z = crypto.ctzero(a.nbit)
// crypto.ctdivmod(z, r, a, b)
//}
-//
-const do = {ctx, op, nbit, estr, astr, bstr
+
+const growmod0 = {r, a, b
+ crypto.growmod(r, a, 0, b)
+}
+
+const dobool : (ctx : testr.ctx#, op : (a : crypto.ctbig#, b : crypto.ctbig# -> bool), nbit : std.size, e : bool, astr : byte[:], bstr : byte[:] -> void) = {ctx, op, nbit, e, astr, bstr
+ var r, a, ai, b, bi
+
+ r = crypto.ctzero(nbit)
+ ai = std.get(std.bigparse(astr))
+ bi = std.get(std.bigparse(bstr))
+ a = crypto.big2ct(ai, nbit)
+ b = crypto.big2ct(bi, nbit)
+
+ std.bigfree(ai)
+ std.bigfree(bi)
+ testr.eq(ctx, op(a, b), e)
+
+ crypto.ctfree(a)
+ crypto.ctfree(b)
+}
+
+const do2 = {ctx, op, nbit, estr, astr, bstr
var r, a, ai, b, bi, e, ei
r = crypto.ctzero(nbit)
@@ -107,3 +186,32 @@ const do = {ctx, op, nbit, estr, astr, bstr
}
+const do3 = {ctx, op, nbit, estr, astr, bstr, cstr
+ var r, a, ai, b, bi, c, ci, e, ei
+
+ r = crypto.ctzero(nbit)
+ ei = std.get(std.bigparse(estr))
+ ai = std.get(std.bigparse(astr))
+ bi = std.get(std.bigparse(bstr))
+ ci = std.get(std.bigparse(cstr))
+ e = crypto.big2ct(ei, nbit)
+ a = crypto.big2ct(ai, nbit)
+ b = crypto.big2ct(bi, nbit)
+ c = crypto.big2ct(ci, nbit)
+
+ std.bigfree(ei)
+ std.bigfree(ai)
+ std.bigfree(bi)
+
+ op(r, a, b, c)
+
+ testr.eq(ctx, r, e)
+
+ crypto.ctfree(r)
+ crypto.ctfree(e)
+ crypto.ctfree(a)
+ crypto.ctfree(b)
+ crypto.ctfree(c)
+}
+
+
diff --git a/lib/std/hashfuncs.myr b/lib/std/hashfuncs.myr
index 013bfde..96eb91a 100644
--- a/lib/std/hashfuncs.myr
+++ b/lib/std/hashfuncs.myr
@@ -18,6 +18,12 @@ pkg std =
}
;;
+ impl equatable bool =
+ eq = {a, b
+ -> a == b
+ }
+ ;;
+
impl equatable @a :: integral,numeric @a =
eq = {a, b
-> a == b