summaryrefslogtreecommitdiff
path: root/lib/math/sqrt-impl.myr
diff options
context:
space:
mode:
Diffstat (limited to 'lib/math/sqrt-impl.myr')
-rw-r--r--lib/math/sqrt-impl.myr188
1 files changed, 188 insertions, 0 deletions
diff --git a/lib/math/sqrt-impl.myr b/lib/math/sqrt-impl.myr
new file mode 100644
index 0000000..f56f996
--- /dev/null
+++ b/lib/math/sqrt-impl.myr
@@ -0,0 +1,188 @@
+use std
+
+use "fpmath"
+
+/* See [Mul+10], sections 5.4 and 8.7 */
+pkg math =
+ pkglocal const sqrt32 : (f : flt32 -> flt32)
+ pkglocal const sqrt64 : (f : flt64 -> flt64)
+;;
+
+extern const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
+extern const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
+
+type fltdesc(@f, @u, @i) = struct
+ explode : (f : @f -> (bool, @i, @u))
+ assem : (n : bool, e : @i, s : @u -> @f)
+ fma : (x : @f, y : @f, z : @f -> @f)
+ tobits : (f : @f -> @u)
+ frombits : (u : @u -> @f)
+ nan : @f
+ emin : @i
+ emax : @i
+ normmask : @u
+ sgnmask : @u
+ ab : (@u, @u)[:]
+ iterlim : int
+;;
+
+/*
+ The starting point of the N-R iteration of 1/sqrt, after significand
+ has been normalized to [1, 4).
+
+ See [KM06] for the construction and notation. Case p = -2. The
+ dividers (left values) are chosen roughly so that maximal error
+ of N-R, after 3 iterations, starting with the right value, is
+ less than an ulp (of the result). If g falls in [a_i, a_{i+1}),
+ N-R should start with b_{i+1}.
+
+ In the flt64 case, we need only one more iteration.
+ */
+const ab32 : (uint32, uint32)[:] = [
+ (0x3f800000, 0x3f800000), /* Nothing should ever get normalized to < 1.0 */
+ (0x3fa66666, 0x3f6f30ae), /* [1.0, 1.3 ) -> 0.9343365431 */
+ (0x3fd9999a, 0x3f5173ca), /* [1.3, 1.7 ) -> 0.8181730509 */
+ (0x40100000, 0x3f3691d3), /* [1.7, 2.25) -> 0.713162601 */
+ (0x40333333, 0x3f215342), /* [2.25, 2.8 ) -> 0.6301766634 */
+ (0x4059999a, 0x3f118e0e), /* [2.8, 3.4 ) -> 0.5685738325 */
+ (0x40800000, 0x3f053049), /* [3.4, 4.0 ) -> 0.520268023 */
+][:]
+
+const ab64 : (uint64, uint64)[:] = [
+ (0x3ff0000000000000, 0x3ff0000000000000), /* < 1.0 */
+ (0x3ff3333333333333, 0x3fee892ce1608cbc), /* [1.0, 1.2) -> 0.954245033445111356940060431953 */
+ (0x3ff6666666666666, 0x3fec1513a2184094), /* [1.2, 1.4) -> 0.877572838393478438234751592972 */
+ (0x3ffc000000000000, 0x3fe9878eb3e9ba20), /* [1.4, 1.75) -> 0.797797538178034670863780775107 */
+ (0x400199999999999a, 0x3fe6ccb14eeb238d), /* [1.75, 2.2) -> 0.712486890924184046447464879748 */
+ (0x400599999999999a, 0x3fe47717c17cd34f), /* [2.2, 2.7) -> 0.639537694840876969060161627567 */
+ (0x400b333333333333, 0x3fe258df212a8e9a), /* [2.7, 3.4) -> 0.573348583963212421465982515656 */
+ (0x4010000000000000, 0x3fe0a5989f2dc59a), /* [3.4, 4.0) -> 0.520214377304159869552790951275 */
+][:]
+
+const sqrt32 = {f : flt32
+ var d : fltdesc(flt32, uint32, int32) = [
+ .explode = std.flt32explode,
+ .assem = std.flt32assem,
+ .fma = fma32,
+ .tobits = std.flt32bits,
+ .frombits = std.flt32frombits,
+ .nan = std.flt32nan(),
+ .emin = -127,
+ .emax = 128,
+ .normmask = 1 << 23,
+ .sgnmask = 1 << 31,
+ .ab = ab32,
+ .iterlim = 3,
+ ]
+ -> sqrtgen(f, d)
+}
+
+const sqrt64 = {f : flt64
+ var d : fltdesc(flt64, uint64, int64) = [
+ .explode = std.flt64explode,
+ .assem = std.flt64assem,
+ .fma = fma64,
+ .tobits = std.flt64bits,
+ .frombits = std.flt64frombits,
+ .nan = std.flt64nan(),
+ .emin = -1023,
+ .emax = 1024,
+ .normmask = 1 << 52,
+ .sgnmask = 1 << 63,
+ .ab = ab64,
+ .iterlim = 4,
+ ]
+ -> sqrtgen(f, d)
+}
+
+generic sqrtgen = {f : @f, d : fltdesc(@f, @u, @i) :: numeric,floating,std.equatable,fpmath @f, numeric,integral @u, numeric,integral @i
+ var n, e, s, e2
+ (n, e, s) = d.explode(f)
+
+ /* Special cases: +/- 0.0, negative, NaN, and +inf */
+ if e == d.emin && s == 0
+ -> f
+ elif n || std.isnan(f)
+ -> d.nan
+ elif e == d.emax
+ -> f
+ ;;
+
+ /*
+ Remove a factor of 2^(even) to normalize significand.
+ */
+ if e == d.emin
+ e = d.emin + 1
+ while s & d.normmask == 0
+ s <<= 1
+ e--
+ ;;
+ ;;
+ if e % 2 != 0
+ e2 = e - 1
+ e = 1
+ else
+ e2 = e
+ e = 0
+ ;;
+
+ var a = d.assem(false, e, s)
+ var au = d.tobits(a)
+
+ /*
+ We shall perform iterated Newton-Raphson in order to
+ compute 1/sqrt(g), then multiply by g to obtain sqrt(g).
+ This is faster than calculating sqrt(g) directly because
+ it avoids division. (The multiplication by g is built
+ into Markstein's r, g, n variables.)
+ */
+ var xn = d.frombits(0)
+ for (ai, beta) : d.ab
+ if au <= ai
+ xn = d.frombits(beta)
+ break
+ ;;
+ ;;
+
+ /* split up "x_{n+1} = x_n (3 - ax_n^2)/2" */
+ var epsn = fma(-1.0 * a, xn * xn, 1.0)
+ var rn = 0.5 * epsn
+ var gn = a * xn
+ var hn = 0.5 * xn
+ for var j = 0; j < d.iterlim; ++j
+ rn = d.fma(-1.0 * gn, hn, 0.5)
+ gn = d.fma(gn, rn, gn)
+ hn = d.fma(hn, rn, hn)
+ ;;
+
+ /*
+ gn is almost what we want, except that we might want to
+ adjust by an ulp in one direction or the other. This is
+ the Tuckerman test.
+
+ Exhaustive testing has shown that we need only 3 adjustments
+ in the flt32 case (and it should be 4 in the flt64 case).
+ */
+ (_, e, s) = d.explode(gn)
+ e += (e2 / 2)
+ var r = d.assem(false, e, s)
+
+ for var j = 0; j < d.iterlim; ++j
+ var r_plus_ulp = d.frombits(d.tobits(r) + 1)
+ var r_minus_ulp = d.frombits(d.tobits(r) - 1)
+
+ var delta_1 = d.fma(r, r_minus_ulp, -1.0 * f)
+ if d.tobits(delta_1) & d.sgnmask == 0
+ r = r_minus_ulp
+ else
+ var delta_2 = d.fma(r, r_plus_ulp, -1.0 * f)
+ if d.tobits(delta_2) & d.sgnmask != 0
+ r = r_plus_ulp
+ else
+ -> r
+ ;;
+ ;;
+ ;;
+
+ -> r
+}