Diffstat (limited to 'lib/std/bigint.myr')
-rw-r--r--  lib/std/bigint.myr  742
1 file changed, 742 insertions, 0 deletions
diff --git a/lib/std/bigint.myr b/lib/std/bigint.myr
new file mode 100644
index 0000000..719fd69
--- /dev/null
+++ b/lib/std/bigint.myr
@@ -0,0 +1,742 @@
+use "alloc.use"
+use "chartype.use"
+use "cmp.use"
+use "die.use"
+use "extremum.use"
+use "hasprefix.use"
+use "option.use"
+use "slcp.use"
+use "sldup.use"
+use "slfill.use"
+use "slpush.use"
+use "types.use"
+use "utf.use"
+use "errno.use"
+
+pkg std =
+ type bigint = struct
+ dig : uint32[:] /* little endian, no leading zeros. */
+ sign : int /* -1 for -ve, 0 for zero, 1 for +ve. */
+ ;;
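+
+ /*
+ examples of the representation: 2^32 + 7 is stored as
+ dig = [7, 1] with sign = 1; -1 is stored as dig = [1] with
+ sign = -1; zero is always the empty digit slice with sign = 0.
+ */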
+
+ /* administrivia */
+ generic mkbigint : (v : @a::(numeric,integral) -> bigint#)
+ const bigfree : (a : bigint# -> void)
+ const bigdup : (a : bigint# -> bigint#)
+ const bigassign : (d : bigint#, s : bigint# -> bigint#)
+ const bigmove : (d : bigint#, s : bigint# -> bigint#)
+ const bigparse : (s : byte[:] -> option(bigint#))
+ const bigclear : (a : bigint# -> bigint#)
+ const bigbfmt : (b : byte[:], a : bigint#, base : int -> size)
+ /*
+ const bigtoint : (a : bigint# -> @a::(numeric,integral))
+ */
+
+ /* some useful predicates */
+ const bigiszero : (a : bigint# -> bool)
+ const bigeq : (a : bigint#, b : bigint# -> bool)
+ generic bigeqi : (a : bigint#, b : @a::(numeric,integral) -> bool)
+ const bigcmp : (a : bigint#, b : bigint# -> order)
+
+ /* bigint*bigint -> bigint ops */
+ const bigadd : (a : bigint#, b : bigint# -> bigint#)
+ const bigsub : (a : bigint#, b : bigint# -> bigint#)
+ const bigmul : (a : bigint#, b : bigint# -> bigint#)
+ const bigdiv : (a : bigint#, b : bigint# -> bigint#)
+ const bigmod : (a : bigint#, b : bigint# -> bigint#)
+ const bigdivmod : (a : bigint#, b : bigint# -> (bigint#, bigint#))
+ const bigshl : (a : bigint#, b : bigint# -> bigint#)
+ const bigshr : (a : bigint#, b : bigint# -> bigint#)
+ const bigmodpow : (b : bigint#, e : bigint#, m : bigint# -> bigint#)
+ /*
+ const bigpow : (a : bigint#, b : bigint# -> bigint#)
+ */
+
+ /* bigint*int -> bigint ops */
+ generic bigaddi : (a : bigint#, b : @a::(integral,numeric) -> bigint#)
+ generic bigsubi : (a : bigint#, b : @a::(integral,numeric) -> bigint#)
+ generic bigmuli : (a : bigint#, b : @a::(integral,numeric) -> bigint#)
+ generic bigdivi : (a : bigint#, b : @a::(integral,numeric) -> bigint#)
+ generic bigshli : (a : bigint#, b : @a::(integral,numeric) -> bigint#)
+ generic bigshri : (a : bigint#, b : @a::(integral,numeric) -> bigint#)
+ /*
+ const bigpowi : (a : bigint#, b : uint64 -> bigint#)
+ */
+;;
+
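+/*
+ each digit is a base 2^32 value: a single digit fits in a uint32, and
+ the product of two digits plus a pair of carries still fits in a uint64.
+*/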
+const Base = 0x100000000ul
+generic mkbigint = {v : @a::(integral,numeric)
+ var a
+ var val
+
+ a = zalloc()
+
+ if v < 0
+ a.sign = -1
+ v = -v
+ elif v > 0
+ a.sign = 1
+ ;;
+ val = v castto(uint64)
+ a.dig = slpush([][:], val castto(uint32))
+ if val >= Base
+ a.dig = slpush(a.dig, (val/Base) castto(uint32))
+ ;;
+ -> trim(a)
+}
+
+const bigfree = {a
+ slfree(a.dig)
+ free(a)
+}
+
+const bigdup = {a
+ -> bigassign(zalloc(), a)
+}
+
+const bigassign = {d, s
+ slfree(d.dig)
+ d# = s#
+ d.dig = sldup(s.dig)
+ -> d
+}
+
+const bigmove = {d, s
+ slfree(d.dig)
+ d# = s#
+ s.dig = [][:]
+ s.sign = 0
+ -> d
+}
+
+const bigclear = {v
+ slfree(v.dig)
+ v.sign = 0
+ v.dig = [][:]
+ -> v
+}
+
+/* for now, just dump out something for debugging... */
+const bigbfmt = {buf, x, base
+ const digitchars = [
+ '0','1','2','3','4','5','6','7','8','9',
+ 'a','b','c','d','e','f', 'g', 'h', 'i', 'j',
+ 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's',
+ 't', 'u', 'v', 'w', 'x', 'y', 'z']
+ var v, val
+ var n, i
+ var tmp, rem, b
+
+ if base < 0 || base > 36
+ die("invalid base in bigbfmt\n")
+ ;;
+
+ if bigiszero(x)
+ -> encode(buf, '0')
+ ;;
+
+ if base == 0
+ b = mkbigint(10)
+ else
+ b = mkbigint(base)
+ ;;
+ n = 0
+ val = bigdup(x)
+ /* generate the digits in reverse order */
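+ /*
+ e.g. formatting 1234 in base 10 produces the remainders
+ 4, 3, 2, 1; the swap loop below puts them back into
+ most-significant-first order.
+ */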
+ while !bigiszero(val)
+ (v, rem) = bigdivmod(val, b)
+ if rem.dig.len > 0
+ n += encode(buf[n:], digitchars[rem.dig[0]])
+ else
+ n += encode(buf[n:], '0')
+ ;;
+ bigfree(val)
+ bigfree(rem)
+ val = v
+ ;;
+ bigfree(val)
+ bigfree(b)
+
+ /* this is done last, so we get things right when we reverse the string */
+ if x.sign == 0
+ n += encode(buf[n:], '0')
+ elif x.sign == -1
+ n += encode(buf[n:], '-')
+ ;;
+
+ /* we only generated ascii digits, so this works for reversing. */
+ for i = 0; i < n/2; i++
+ tmp = buf[i]
+ buf[i] = buf[n - i - 1]
+ buf[n - i - 1] = tmp
+ ;;
+ -> n
+}
+
+const bigparse = {str
+ var c, val : int, base
+ var v, b
+ var a
+
+ if hasprefix(str, "0x") || hasprefix(str, "0X")
+ base = 16
+ elif hasprefix(str, "0o") || hasprefix(str, "0O")
+ base = 8
+ elif hasprefix(str, "0b") || hasprefix(str, "0B")
+ base = 2
+ else
+ base = 10
+ ;;
+ if base != 10
+ str = str[2:]
+ ;;
+
+ a = mkbigint(0)
+ b = mkbigint(base)
+ /*
+ efficiency hack: to save allocations,
+ just mutate v.dig[0]. The value will always
+ fit in one digit.
+ */
+ v = mkbigint(1)
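+ /* accumulate digits by Horner's rule: a = a*base + digit */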
+ while str.len != 0
+ (c, str) = striter(str)
+ if c == '_'
+ continue
+ ;;
+ val = charval(c, base)
+ if val < 0 || val >= base
+ bigfree(a)
+ bigfree(b)
+ bigfree(v)
+ -> `None
+ ;;
+ v.dig[0] = val castto(uint32)
+ if val == 0
+ v.sign = 0
+ else
+ v.sign = 1
+ ;;
+ bigmul(a, b)
+ bigadd(a, v)
+ ;;
+ bigfree(b)
+ bigfree(v)
+ -> `Some a
+}
+
+const bigiszero = {v
+ -> v.dig.len == 0
+}
+
+const bigeq = {a, b
+ var i
+
+ if a.sign != b.sign || a.dig.len != b.dig.len
+ -> false
+ ;;
+ for i = 0; i < a.dig.len; i++
+ if a.dig[i] != b.dig[i]
+ -> false
+ ;;
+ ;;
+ -> true
+}
+
+generic bigeqi = {a, b
+ var v
+ var dig : uint32[2]
+
+ bigdigit(&v, b < 0, b castto(uint64), dig[:])
+ -> bigeq(a, &v)
+}
+
+const bigcmp = {a, b
+ var i
+ var da, db, sa, sb
+
+ sa = a.sign castto(int64)
+ sb = b.sign castto(int64)
+ if sa < sb
+ -> `Before
+ elif sa > sb
+ -> `After
+ elif a.dig.len < b.dig.len
+ -> signedorder(-sa)
+ elif a.dig.len > b.dig.len
+ -> signedorder(sa)
+ else
+ /* otherwise, the one with the first larger digit is bigger */
+ for i = a.dig.len; i > 0; i--
+ da = a.dig[i - 1] castto(int64)
+ db = b.dig[i - 1] castto(int64)
+ if da != db
+ -> signedorder(sa * (da - db))
+ ;;
+ ;;
+ ;;
+ -> `Equal
+}
+
+const signedorder = {sign
+ if sign < 0
+ -> `Before
+ elif sign == 0
+ -> `Equal
+ else
+ -> `After
+ ;;
+}
+
+/* a += b */
+const bigadd = {a, b
+ if a.sign == b.sign || a.sign == 0
+ a.sign = b.sign
+ -> uadd(a, b)
+ elif b.sign == 0
+ -> a
+ else
+ match bigcmp(a, b)
+ | `Before: /* a is negative */
+ a.sign = b.sign
+ -> usub(b, a)
+ | `After: /* b is negative */
+ -> usub(a, b)
+ | `Equal:
+ die("Impossible. Equal vals with different sign.")
+ ;;
+ ;;
+}
+
+/* adds two unsigned values together. */
+const uadd = {a, b
+ var v, i
+ var carry
+ var n
+
+ carry = 0
+ n = max(a.dig.len, b.dig.len)
+ /* the sum can grow by at most one extra digit */
+ a.dig = slzgrow(a.dig, n + 1)
+ for i = 0; i < n; i++
+ v = (a.dig[i] castto(uint64)) + carry;
+ if i < b.dig.len
+ v += (b.dig[i] castto(uint64))
+ ;;
+
+ if v >= Base
+ carry = 1
+ else
+ carry = 0
+ ;;
+ a.dig[i] = v castto(uint32)
+ ;;
+ a.dig[i] += carry castto(uint32)
+ -> trim(a)
+}
+
+/* a -= b */
+const bigsub = {a, b
+ /* 0 - x = -x */
+ if a.sign == 0
+ bigassign(a, b)
+ a.sign = -b.sign
+ -> a
+ /* x - 0 = x */
+ elif b.sign == 0
+ -> a
+ elif a.sign != b.sign
+ -> uadd(a, b)
+ else
+ match bigcmp(a, b)
+ | `Before: /* a is negative */
+ a.sign = b.sign
+ -> usub(b, a)
+ | `After: /* b is negative */
+ -> usub(a, b)
+ | `Equal:
+ -> bigclear(a)
+ ;;
+ ;;
+ -> a
+}
+
+/* subtracts two unsigned values, where 'a' is strictly greater than 'b' */
+const usub = {a, b
+ var carry
+ var v, i
+
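+ /* 'carry' is the borrow out of the previous digit, either 0 or 1 */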
+ carry = 0
+ for i = 0; i < a.dig.len; i++
+ v = (a.dig[i] castto(int64)) - carry
+ if i < b.dig.len
+ v -= (b.dig[i] castto(int64))
+ ;;
+ if v < 0
+ carry = 1
+ else
+ carry = 0
+ ;;
+ a.dig[i] = v castto(uint32)
+ ;;
+ -> trim(a)
+}
+
+/* a *= b */
+const bigmul = {a, b
+ var i, j
+ var ai, bj, wij
+ var carry, t
+ var w
+
+ if a.sign == 0 || b.sign == 0
+ a.sign = 0
+ slfree(a.dig)
+ a.dig = [][:]
+ -> a
+ elif a.sign != b.sign
+ a.sign = -1
+ else
+ a.sign = 1
+ ;;
+ w = slzalloc(a.dig.len + b.dig.len)
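+ /*
+ schoolbook long multiplication: every digit of b is multiplied
+ against every digit of a, accumulating into w[i + j] with a
+ running carry that is at most one digit wide.
+ */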
+ for j = 0; j < b.dig.len; j++
+ carry = 0
+ for i = 0; i < a.dig.len; i++
+ ai = a.dig[i] castto(uint64)
+ bj = b.dig[j] castto(uint64)
+ wij = w[i+j] castto(uint64)
+ t = ai * bj + wij + carry
+ w[i + j] = (t castto(uint32))
+ carry = t >> 32
+ ;;
+ w[i+j] = carry castto(uint32)
+ ;;
+ slfree(a.dig)
+ a.dig = w
+ -> trim(a)
+}
+
+const bigdiv = {a : bigint#, b : bigint# -> bigint#
+ var q, r
+
+ (q, r) = bigdivmod(a, b)
+ bigfree(r)
+ -> bigmove(a, q)
+}
+const bigmod = {a : bigint#, b : bigint# -> bigint#
+ var q, r
+
+ (q, r) = bigdivmod(a, b)
+ bigfree(q)
+ -> bigmove(a, r)
+}
+
+/* a /= b */
+const bigdivmod = {a : bigint#, b : bigint# -> (bigint#, bigint#)
+ /*
+ Implements bigint division using Algorithm D from
+ Knuth: Seminumerical algorithms, Section 4.3.1.
+ */
+ var m : int64, n : int64
+ var qhat, rhat, carry, shift
+ var x, y, z, w, p, t /* temporaries */
+ var b0, aj
+ var u, v
+ var i, j : int64
+ var q
+
+ if bigiszero(b)
+ die("divide by zero\n")
+ ;;
+ /* if b has more digits than a, the quotient truncates to 0 with remainder a */
+ if a.dig.len < b.dig.len
+ -> (mkbigint(0), bigdup(a))
+ ;;
+
+ q = zalloc()
+ q.dig = slzalloc(max(a.dig.len, b.dig.len) + 1)
+ if a.sign != b.sign
+ q.sign = -1
+ else
+ q.sign = 1
+ ;;
+
+ /* handle a single digit divisor separately: the Knuth algorithm needs at least 2 digits. */
+ if b.dig.len == 1
+ carry = 0
+ b0 = (b.dig[0] castto(uint64))
+ for j = a.dig.len; j > 0; j--
+ aj = (a.dig[j - 1] castto(uint64))
+ q.dig[j - 1] = (((carry << 32) + aj)/b0) castto(uint32)
+ carry = (carry << 32) + aj - (q.dig[j-1] castto(uint64))*b0
+ ;;
+ q = trim(q)
+ -> (q, trim(mkbigint(carry castto(int64))))
+ ;;
+
+ u = bigdup(a)
+ v = bigdup(b)
+ m = u.dig.len
+ n = v.dig.len
+
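+ /*
+ normalize (step D1 of Knuth's algorithm): shift both operands left
+ so the divisor's top digit has its high bit set; this is what keeps
+ the qhat estimate below within a couple of units of the true digit.
+ */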
+ shift = nlz(v.dig[n - 1])
+ bigshli(u, shift)
+ bigshli(v, shift)
+ u.dig = slzgrow(u.dig, u.dig.len + 1)
+
+ for j = m - n; j >= 0; j--
+ /* load a few temps */
+ x = u.dig[j + n] castto(uint64)
+ y = u.dig[j + n - 1] castto(uint64)
+ z = v.dig[n - 1] castto(uint64)
+ w = v.dig[n - 2] castto(uint64)
+ t = u.dig[j + n - 2] castto(uint64)
+
+ /* estimate qhat */
+ qhat = (x*Base + y)/z
+ rhat = (x*Base + y) - qhat*z
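+ /* refine the estimate: it can only be a little too large, never too small */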
+:divagain
+ if qhat >= Base || (qhat * w) > (rhat*Base + t)
+ qhat--
+ rhat += z
+ if rhat < Base
+ goto divagain
+ ;;
+ ;;
+
+ /* multiply and subtract */
+ carry = 0
+ for i = 0; i < n; i++
+ p = qhat * (v.dig[i] castto(uint64))
+
+ t = (u.dig[i+j] castto(uint64)) - carry - (p % Base)
+ u.dig[i+j] = t castto(uint32)
+
+ carry = (((p castto(int64)) >> 32) - ((t castto(int64)) >> 32)) castto(uint64);
+ ;;
+ t = (u.dig[j + n] castto(uint64)) - carry
+ u.dig[j + n] = t castto(uint32)
+ q.dig[j] = qhat castto(uint32)
+ /* adjust: qhat was one too large, so add the divisor back and decrement the quotient digit */
+ if t castto(int64) < 0
+ q.dig[j]--
+ carry = 0
+ for i = 0; i < n; i++
+ t = (u.dig[i+j] castto(uint64)) + (v.dig[i] castto(uint64)) + carry
+ u.dig[i+j] = t castto(uint32)
+ carry = t >> 32
+ ;;
+ u.dig[j+n] = u.dig[j+n] + (carry castto(uint32));
+ ;;
+
+ ;;
+ /* undo the biasing for remainder */
+ u = bigshri(u, shift)
+ -> (trim(q), trim(u))
+}
+
+/* computes b^e % m */
+const bigmodpow = {base, exp, mod
+ var r
+
+ r = mkbigint(1)
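+ /*
+ right-to-left binary exponentiation: each pass multiplies the
+ result in when the low bit of the exponent is set, then squares
+ the base and halves the exponent, reducing mod m throughout.
+ note that exp is consumed (shifted down to zero) along the way.
+ */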
+ while !bigiszero(exp)
+ if (exp.dig[0] & 1) != 0
+ bigmul(r, base)
+ bigmod(r, mod)
+ ;;
+ bigshri(exp, 1)
+ bigmul(base, base)
+ bigmod(base, mod)
+ ;;
+ -> bigmove(base, r)
+}
+
+/* returns the number of leading zeros */
+const nlz = {a : uint32
+ var n
+
+ if a == 0
+ -> 32
+ ;;
+ n = 0
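+ /* binary search for the highest set bit, halving the search width at each test */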
+ if a <= 0x0000ffff
+ n += 16
+ a <<= 16
+ ;;
+ if a <= 0x00ffffff
+ n += 8
+ a <<= 8
+ ;;
+ if a <= 0x0fffffff
+ n += 4
+ a <<= 4
+ ;;
+ if a <= 0x3fffffff
+ n += 2
+ a <<= 2
+ ;;
+ if a <= 0x7fffffff
+ n += 1
+ a <<= 1
+ ;;
+ -> n
+}
+
+
+/* a <<= b */
+const bigshl = {a, b
+ match b.dig.len
+ | 0: -> a
+ | 1: -> bigshli(a, b.dig[0] castto(uint64))
+ | n: die("shift by way too much\n")
+ ;;
+}
+
+/* a >>= b, unsigned */
+const bigshr = {a, b
+ match b.dig.len
+ | 0: -> a
+ | 1: -> bigshri(a, b.dig[0] castto(uint64))
+ | n: die("shift by way too much\n")
+ ;;
+}
+
+/* a + b, b is integer.
+FIXME: actually make this a performance improvement
+*/
+generic bigaddi = {a, b
+ var bigb : bigint
+ var dig : uint32[2]
+
+ bigdigit(&bigb, b < 0, b castto(uint64), dig[:])
+ bigadd(a, &bigb)
+ -> a
+}
+
+generic bigsubi = {a, b : @a::(numeric,integral)
+ var bigb : bigint
+ var dig : uint32[2]
+
+ bigdigit(&bigb, b < 0, b castto(uint64), dig[:])
+ bigsub(a, &bigb)
+ -> a
+}
+
+generic bigmuli = {a, b
+ var bigb : bigint
+ var dig : uint32[2]
+
+ bigdigit(&bigb, b < 0, b castto(uint64), dig[:])
+ bigmul(a, &bigb)
+ -> a
+}
+
+generic bigdivi = {a, b
+ var bigb : bigint
+ var dig : uint32[2]
+
+ bigdigit(&bigb, b < 0, b castto(uint64), dig[:])
+ bigdiv(a, &bigb)
+ -> a
+}
+
+/*
+ a << s, with integer arg.
+ logical left shift. any other type would be illogical.
+ */
+generic bigshli = {a, s : @a::(numeric,integral)
+ var off, shift
+ var t, carry
+ var i
+
+ assert(s >= 0, "shift amount must be non-negative")
+ off = (s castto(uint64)) / 32
+ shift = (s castto(uint64)) % 32
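+ /* a shift by s bits is a move by s/32 whole digits plus a bit shift by s%32 */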
+
+ /* zero shifted by anything is zero */
+ if a.sign == 0
+ -> a
+ ;;
+ a.dig = slzgrow(a.dig, 1 + a.dig.len + off castto(size))
+ /* blit over the base values */
+ for i = a.dig.len; i > off; i--
+ a.dig[i - 1] = a.dig[i - 1 - off]
+ ;;
+ for i = 0; i < off; i++
+ a.dig[i] = 0
+ ;;
+ /* and shift over by the remainder */
+ carry = 0
+ for i = 0; i < a.dig.len; i++
+ t = (a.dig[i] castto(uint64)) << shift
+ a.dig[i] = (t | carry) castto(uint32)
+ carry = t >> 32
+ ;;
+ -> trim(a)
+}
+
+/* logical shift right, zero fills. sign remains untouched. */
+generic bigshri = {a, s
+ var off, shift
+ var t, carry
+ var i
+
+ assert(s >= 0, "shift amount must be non-negative")
+ off = (s castto(uint64)) / 32
+ shift = (s castto(uint64)) % 32
+
+ /* blit over the base values */
+ for i = 0; i < a.dig.len - off; i++
+ a.dig[i] = a.dig[i + off]
+ ;;
+ for i = a.dig.len - off; i < a.dig.len; i++
+ a.dig[i] = 0
+ ;;
+ /* and shift over by the remainder */
+ carry = 0
+ for i = a.dig.len; i > 0; i--
+ t = (a.dig[i - 1] castto(uint64))
+ a.dig[i - 1] = (carry | (t >> shift)) castto(uint32)
+ carry = t << (32 - shift)
+ ;;
+ -> trim(a)
+}
+
+/* creates a bigint on the stack; should not be modified. */
+const bigdigit = {v, isneg : bool, val : uint64, dig
+ v.sign = 1
+ if isneg
+ val = -val
+ v.sign = -1
+ ;;
+ if val == 0
+ v.sign = 0
+ v.dig = [][:]
+ elif val < Base
+ v.dig = dig[:1]
+ v.dig[0] = val castto(uint32)
+ else
+ v.dig = dig
+ v.dig[0] = val castto(uint32)
+ v.dig[1] = (val >> 32) castto(uint32)
+ ;;
+}
+
+/* trims leading zeros */
+const trim = {a
+ var i
+
+ for i = a.dig.len; i > 0; i--
+ if a.dig[i - 1] != 0
+ break
+ ;;
+ ;;
+ a.dig = slgrow(a.dig, i)
+ if i == 0
+ a.sign = 0
+ elif a.sign == 0
+ a.sign = 1
+ ;;
+ -> a
+}
+