summaryrefslogtreecommitdiff
path: root/lib/math/sqrt-impl.myr
blob: eec1b964af35a341c1398333440215752b213e06 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
use std

use "fpmath"

/*
   Square root for flt32 and flt64.

   See [Mul+10], sections 5.4 and 8.7.

   Both entry points are declared pkglocal and are thin wrappers
   around the generic sqrtgen defined later in this file.
 */
pkg math =
	pkglocal const sqrt32 : (x : flt32 -> flt32)
	pkglocal const sqrt64 : (x : flt64 -> flt64)
;;

/*
   Three-operand routines defined outside this file; they are stored
   in the .fma slot of desc32 / desc64 and called through it by sqrtgen.
   NOTE(review): presumably fused multiply-add, x*y + z with a single
   rounding -- confirm against the defining file.
 */
extern const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
extern const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)

/*
   Per-format parameters that let sqrtgen be written once for both
   floating point widths. @f is the float type, @u the unsigned
   integer type holding its bits, @i the signed type used for
   exponents.
 */
type fltdesc(@f, @u, @i) = struct
	/* split a float into (sign flag, exponent, significand) */
	explode : (f : @f -> (bool, @i, @u))
	/* build a float from sign flag, exponent and significand */
	assem : (n : bool, e : @i, s : @u -> @f)
	/* three-operand multiply-add used by the iteration and the final test */
	fma : (x : @f, y : @f, z : @f -> @f)
	/* conversions between a float and its bit pattern */
	tobits : (f : @f -> @u)
	frombits : (u : @u -> @f)
	/* bit pattern returned for negative or NaN input */
	nan : @u
	/* exponent that sqrtgen treats as zero/subnormal input */
	emin : @i
	/* exponent that sqrtgen treats as +inf (after NaN has been ruled out) */
	emax : @i
	/* significand bit that must be set for the significand to count as normalized */
	normmask : @u
	/* mask selecting the sign bit of a bit pattern */
	sgnmask : @u
	/* (divider, starting value) pairs, as bit patterns, for the initial guess */
	ab : (@u, @u)[:]
	/* number of Newton-Raphson steps; also the cap on final one-ulp adjustments */
	iterlim : int
;;

/*
   The starting point of the N-R iteration of 1/sqrt, after significand
   has been normalized to [1, 4).

   See [KM06] for the construction and notation. Case p = -2. The
   dividers (left values) are chosen roughly so that maximal error
   of N-R, after 3 iterations, starting with the right value, is
   less than an ulp (of the result). If g falls in [a_i, a_{i+1}),
   N-R should start with b_{i+1}.

   In the flt64 case, we need only one more iteration.

   Both members of each pair are raw bit patterns. sqrtgen walks the
   table in order and takes the right value of the first entry whose
   left value is >= the bit pattern of the normalized input, so the
   entries must stay sorted by ascending left value.
 */
const ab32 : (uint32, uint32)[:] = [
	(0x3f800000, 0x3f800000), /* Nothing should ever get normalized to < 1.0 */
	(0x3fa66666, 0x3f6f30ae), /* [1.0,  1.3 ) -> 0.9343365431 */
	(0x3fd9999a, 0x3f5173ca), /* [1.3,  1.7 ) -> 0.8181730509 */
	(0x40100000, 0x3f3691d3), /* [1.7,  2.25) -> 0.713162601  */
	(0x40333333, 0x3f215342), /* [2.25, 2.8 ) -> 0.6301766634 */
	(0x4059999a, 0x3f118e0e), /* [2.8,  3.4 ) -> 0.5685738325 */
	(0x40800000, 0x3f053049), /* [3.4,  4.0 ) -> 0.520268023  */
][:]

/*
   flt64 counterpart of ab32: (divider, starting value) pairs given as
   raw 64-bit patterns, sorted by ascending divider. Referenced by
   desc64 and scanned in order by sqrtgen.
 */
const ab64 : (uint64, uint64)[:] = [
	(0x3ff0000000000000, 0x3ff0000000000000), /* < 1.0 */
	(0x3ff3333333333333, 0x3fee892ce1608cbc), /* [1.0,  1.2)  -> 0.954245033445111356940060431953 */
	(0x3ff6666666666666, 0x3fec1513a2184094), /* [1.2,  1.4)  -> 0.877572838393478438234751592972 */
	(0x3ffc000000000000, 0x3fe9878eb3e9ba20), /* [1.4,  1.75) -> 0.797797538178034670863780775107 */
	(0x400199999999999a, 0x3fe6ccb14eeb238d), /* [1.75, 2.2)  -> 0.712486890924184046447464879748 */
	(0x400599999999999a, 0x3fe47717c17cd34f), /* [2.2,  2.7)  -> 0.639537694840876969060161627567 */
	(0x400b333333333333, 0x3fe258df212a8e9a), /* [2.7,  3.4)  -> 0.573348583963212421465982515656 */
	(0x4010000000000000, 0x3fe0a5989f2dc59a), /* [3.4,  4.0)  -> 0.520214377304159869552790951275 */
][:]

/*
   Format description handed to sqrtgen by sqrt32.
   NOTE(review): emin/emax are assumed to be the exponents that
   std.flt32explode reports for zero/subnormal and inf/NaN inputs
   respectively -- confirm against std.
 */
const desc32 : fltdesc(flt32, uint32, int32) =  [
	.explode = std.flt32explode,
	.assem = std.flt32assem,
	.fma = fma32,
	.tobits = std.flt32bits,
	.frombits = std.flt32frombits,
	.nan = 0x7fc00000,
	.emin = -127,
	.emax = 128,
	.normmask = 1 << 23,
	.sgnmask = 1 << 31,
	.ab = ab32,
	/* 3 N-R steps, and at most 3 one-ulp adjustments afterwards */
	.iterlim = 3,
]

/*
   Format description handed to sqrtgen by sqrt64.
   NOTE(review): emin/emax are assumed to be the exponents that
   std.flt64explode reports for zero/subnormal and inf/NaN inputs
   respectively -- confirm against std.
 */
const desc64 : fltdesc(flt64, uint64, int64) =  [
	.explode = std.flt64explode,
	.assem = std.flt64assem,
	.fma = fma64,
	.tobits = std.flt64bits,
	.frombits = std.flt64frombits,
	.nan = 0x7ff8000000000000,
	.emin = -1023,
	.emax = 1024,
	.normmask = 1 << 52,
	.sgnmask = 1 << 63,
	.ab = ab64,
	/* one more N-R step than flt32, and at most 4 one-ulp adjustments */
	.iterlim = 4,
]

/* Square root of a flt32: runs the generic sqrtgen with the flt32 description. */
const sqrt32 = {x : flt32
	-> sqrtgen(x, desc32)
}

/* Square root of a flt64: runs the generic sqrtgen with the flt64 description. */
const sqrt64 = {x : flt64
	-> sqrtgen(x, desc64)
}

/*
   Generic square root, shared by sqrt32 and sqrt64.

   x is the input; d supplies the format-specific helpers and
   constants (see fltdesc).

   Returns x itself for +/- 0.0 and +inf, the quiet NaN d.nan for
   negative or NaN input, and otherwise an approximation of sqrt(x)
   obtained by Newton-Raphson on 1/sqrt followed by up to d.iterlim
   one-ulp corrections.
 */
generic sqrtgen = {x : @f, d : fltdesc(@f, @u, @i) :: numeric,floating,std.equatable @f, numeric,integral @u, numeric,integral @i
	var n : bool, e : @i, s : @u, e2 : @i
	(n, e, s) = d.explode(x)

	/* Special cases: +/- 0.0, negative, NaN, and +inf */
	if e == d.emin && s == 0
		-> x
	elif n || std.isnan(x)
		/* Make sure to return a quiet NaN */
		-> d.frombits(d.nan)
	elif e == d.emax
		-> x
	;;

	/*
	   Remove a factor of 2^(even) to normalize significand.

	   Subnormal input is first shifted left until the normalization
	   bit is set, lowering the exponent by one per shift.
	 */
	if e == d.emin
		e = d.emin + 1
		while s & d.normmask == 0
			s <<= 1
			e--
		;;
	;;
	/*
	   e2 is the even part of the exponent, restored (halved) at the
	   end; the leftover 0 or 1 stays in e so that a lies in [1, 4).
	 */
	if e % 2 != 0
		e2 = e - 1
		e = 1
	else
		e2 = e
		e = 0
	;;

	var a : @f = d.assem(false, e, s)
	var au : @u = d.tobits(a)

	/*
	   We shall perform iterated Newton-Raphson in order to
	   compute 1/sqrt(a), then multiply by a to obtain sqrt(a).
	   This is faster than calculating sqrt(a) directly because
	   it avoids division. (The multiplication by a is built
	   into Markstein's r, g, n variables.)

	   The starting value is the right member of the first table
	   entry whose left member is >= the bits of a.
	 */
	var xn : @f = d.frombits(0)
	for (ai, beta) : d.ab
		if au <= ai
			xn = d.frombits(beta)
			break
		;;
	;;

	/*
	   split up "x_{n+1} = x_n (3 - ax_n^2)/2"

	   gn tracks the approximation of sqrt(a), hn half the
	   approximation of 1/sqrt(a). The initial rn is reassigned at
	   the top of the first iteration before it is read.
	 */
	var epsn = d.fma(-1.0 * a, xn * xn, 1.0)
	var rn = 0.5 * epsn
	var gn = a * xn
	var hn = 0.5 * xn
	for var j = 0; j < d.iterlim; ++j
		rn = d.fma(-1.0 * gn, hn, 0.5)
		gn = d.fma(gn, rn, gn)
		hn = d.fma(hn, rn, hn)
	;;

	/*
	   gn is almost what we want, except that we might want to
	   adjust by an ulp in one direction or the other. This is
	   the Tuckerman test.

	   Exhaustive testing has shown that we need only 3 adjustments
	   in the flt32 case (and it should be 4 in the flt64 case).
	 */
	/* put back the even exponent removed above, halved by the root */
	(_, e, s) = d.explode(gn)
	e += (e2 / 2)
	var r : @f = d.assem(false, e, s)

	for var j = 0; j < d.iterlim; ++j
		/* neighbours of r: adjacent bit patterns */
		var r_plus_ulp : @f = d.frombits(d.tobits(r) + 1)
		var r_minus_ulp : @f = d.frombits(d.tobits(r) - 1)

		/*
		   The test is against the original input x:
		   r is accepted once r*(r - ulp) < x <= r*(r + ulp),
		   judged by the sign bit of the fma results.
		 */
		var delta_1 = d.fma(r, r_minus_ulp, -1.0 * x)
		if d.tobits(delta_1) & d.sgnmask == 0
			/* r*(r - ulp) - x is not negative: r is too large */
			r = r_minus_ulp
		else
			var delta_2 = d.fma(r, r_plus_ulp, -1.0 * x)
			if d.tobits(delta_2) & d.sgnmask != 0
				/* r*(r + ulp) - x is negative: r is too small */
				r = r_plus_ulp
			else
				-> r
			;;
		;;
	;;

	-> r
}