summaryrefslogtreecommitdiff
path: root/lib/math/fma-impl.myr
blob: 8dfecb52a279bc5a944a2d21e2bbee70cf90719a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
use std

use "util"

pkg math =
	pkglocal const fma32 : (x : flt32, y : flt32, z : flt32 -> flt32)
	pkglocal const fma64 : (x : flt64, y : flt64, z : flt64 -> flt64)
;;

/*
   Masks selecting the exponent field of an IEEE 754 bit pattern:
   8 bits starting at bit 23 for flt32, 11 bits starting at bit 52
   for flt64. A value whose exponent field equals the full mask is
   an infinity or a NaN.
 */
const exp_mask32 : uint32 = 0xff << 23
const exp_mask64 : uint64 = 0x7ff << 52

/*
   Fused multiply-add for flt32: x * y + z with a single rounding
   to flt32.

   The work is done in flt64. A flt32 significand has 24 bits, so the
   48-bit product x * y is exact in flt64, and only the flt64 addition
   prod + zd rounds. Converting that flt64 sum to flt32 rounds a second
   time. That second rounding can give the wrong answer when the flt64
   sum lands exactly halfway between two flt32 values. Those cases are
   detected below, and the flt64 significand is nudged by one unit so
   that the final conversion rounds in the correct direction.
 */
pkglocal const fma32 = {x : flt32, y : flt32, z : flt32
	/* Signs of x and y, used below to repair the sign of the product */
	var xn, yn
	(xn, _, _) = std.flt32explode(x)
	(yn, _, _) = std.flt32explode(y)
	var xd : flt64 = flt64fromflt32(x)
	var yd : flt64 = flt64fromflt32(y)
	var zd : flt64 = flt64fromflt32(z)
	var prod : flt64 = xd * yd
	var pn, pe, ps
	(pn, pe, ps) = std.flt64explode(prod)
	/*
	   NOTE(review): presumably flt64explode reports exponent -1023
	   for zero/subnormal values; this clamps it to -1022, the same
	   clamp fma64 applies to its operands.
	 */
	if pe == -1023
		pe = -1022
	;;
	if pn != (xn != yn)
		/* In case of NaNs, sign might not have been preserved */
		pn = (xn != yn)
		prod = std.flt64assem(pn, pe, ps)
	;;

	/* The only rounding performed in flt64 */
	var r : flt64 = prod + zd
	var rn, re, rs
	(rn, re, rs) = std.flt64explode(r)

	/*
	   At this point, r is probably the correct answer. The
	   only issue is the rounding.

	   Ex 1: If x*y > 0 and z is a tiny, negative number, then
	   adding z probably does no rounding. However, if
	   truncating to 23 bits of precision would cause round-to-even,
	   and that round would be upwards, then we need to remember
	   those trailing bits of z and cancel the rounding.

	   Ex 2: If x, y, z > 0, and z is small, with
	                 last bit in flt64 |
	          last bit in flt32 v      v
	   x * y = ...............101011..11
	       z =                          10000...,
	   then x * y + z will be rounded to
	           ...............101100..00,
	   and then as a flt32 it will become
	           ...............110,
	   Even though, looking at the original bits, it doesn't
	   "deserve" the final rounding.

	   These can only happen if r is non-inf, non-NaN, and the
	   lower 29 bits correspond to "exactly halfway".
	 */
	/*
	   re == 1024 means r is inf or NaN. The low 29 bits of rs are
	   the bits that the flt32 conversion discards (52 - 23), and
	   0x10000000 is exactly half of a flt32 ulp. Anything other
	   than an exact halfway pattern converts correctly as-is.

	   NOTE(review): this assumes the flt32 result is normal; if it
	   is subnormal in flt32, the conversion discards more than 29
	   bits and this halfway test does not describe its rounding.
	 */
	if re == 1024 || rs & 0x1fffffff != 0x10000000
		-> flt32fromflt64(r)
	;;

	/*
	   At this point, a rounding is about to happen. We need
	   to know what direction that rounding is, so that we can
	   tell if it's wrong. +1 means "away from 0", -1 means
	   "towards 0".
	 */
	var zn, ze, zs
	(zn, ze, zs) = std.flt64explode(zd)
	var round_direction = 0
	/*
	   Bit 29 of rs is the last bit kept by flt32: round-to-even
	   goes towards 0 if it is clear, away from 0 if it is set.
	 */
	if rs & 0x20000000 == 0
		round_direction = -1
	else
		round_direction = 1
	;;

	/* Order the two addends by magnitude (exponent first, then significand) */
	var smaller, larger, smaller_e, larger_e
	if pe > ze || (pe == ze && ps > zs)
		(smaller, larger, smaller_e, larger_e) = (zs, ps, ze, pe)
	else
		(smaller, larger, smaller_e, larger_e) = (ps, zs, pe, ze)
	;;
	/*
	   Select the low (larger_e - smaller_e) bits of smaller, capped
	   at 64: the bits of smaller that sit below larger's last bit.
	 */
	var mask = shr((-1 : uint64), 64 - std.min(64, larger_e - smaller_e))
	var prevent_rounding = false
	if (round_direction > 0 && pn != zn) || (round_direction < 0 && pn == zn)
		/*
		   The prospective rounding disagrees with the signs
		   of the addends. We are potentially in the case of
		   Ex 1.

		   Look at the bits (of the smaller flt64) that are
		   outside the range of r. If there are any such
		   bits, we need to cancel the rounding.

		   We certainly need to consider bits very far to
		   the right, but there's an awkwardness concerning
		   the bit just outside the flt64 range: it governed
		   round-to-even, so it might have had an effect.
		   We only care about bits which did not have an
		   effect. Therefore, we perform the subtraction
		   using only the bits from smaller that lie in
		   larger's range, then check whether the result
		   is susceptible to round-to-even.

		   (Since we only care about the last bit, and the
		   base is 2, subtraction or addition are equally
		   useful.)
		*/
		if (larger ^ shr(smaller, larger_e - smaller_e)) & 0x1 == 0
			prevent_rounding = smaller & mask != 0
		;;
	else
		/*
		   The prospective rounding agrees with the signs of
		   the addends. We are potentially in the case of
		   Ex 2.

		   We just need to check if r was obtained by
		   rounding in the addition step. In this case, we
		   still check the smaller/larger, and we only
		   care about round-to-even. Any
		   rounding that happened previously is enough
		   reason to disqualify this next rounding.
		*/
		prevent_rounding = (larger ^ shr(smaller, larger_e - smaller_e)) & 0x1 != 0
	;;

	/*
	   Move r one flt64 unit off the halfway point, against the
	   undeserved rounding direction, so the flt32 conversion
	   rounds the other way.
	 */
	if prevent_rounding
		if round_direction > 0
			rs--
		else
			rs++
		;;
	;;

	-> flt32fromflt64(std.flt64assem(rn, re, rs))
}

/*
   Fused multiply-add for flt64: x * y + z with a single rounding.

   The exact product of the two 53-bit significands is formed as a
   128-bit integer [ prod_h ][ prod_l ]. z is aligned against it and
   added or subtracted in 128-bit arithmetic. The 128-bit result
   [ res_h ][ res_l ] is then rounded once into a flt64, which may be
   normal, subnormal, or infinite.

   Exponent bookkeeping used throughout:
     *_lastbit_e  - exponent of bit 0 of the corresponding integer
     *_firstbit_e - exponent of its leading 1 bit
     *_first1     - bit index of that leading 1 bit
 */
pkglocal const fma64 = {x : flt64, y : flt64, z : flt64
	var xn : bool, yn : bool, zn : bool
	var xe : int64, ye : int64, ze : int64
	var xs : uint64, ys : uint64, zs : uint64

	var xb : uint64 = std.flt64bits(x)
	var yb : uint64 = std.flt64bits(y)
	var zb : uint64 = std.flt64bits(z)

	/*
	   Special cases, left to ordinary flt64 arithmetic:
	   - x or y is infinite or NaN (exponent field all ones),
	   - z or the product is zero,
	   - z is infinite or NaN while x and y are finite.

	   NOTE(review): x * y == 0.0 is also true when a nonzero exact
	   product underflows to zero; confirm that case is acceptable.
	 */
	if xb & exp_mask64 == exp_mask64 || \
	   yb & exp_mask64 == exp_mask64
		-> x * y + z
	elif z == 0.0 || z == -0.0 || x * y == 0.0 || x * y == -0.0
		-> x * y + z
	elif zb & exp_mask64 == exp_mask64
		-> z
	;;

	(xn, xe, xs) = std.flt64explode(x)
	(yn, ye, ys) = std.flt64explode(y)
	(zn, ze, zs) = std.flt64explode(z)
	/*
	   Operands whose exponent comes back as -1023 are treated as
	   having exponent -1022, where subnormal significand bits sit.
	   (Presumably flt64explode reports -1023 for subnormals.)
	 */
	if xe == -1023
		xe = -1022
	;;
	if ye == -1023
		ye = -1022
	;;
	if ze == -1023
		ze = -1022
	;;

	/*
	   Exact product xs * ys, kept in high/low uint64s. Splitting each
	   significand into 32-bit halves keeps every partial product
	   within a uint64.
	 */
	var xs_h : uint64 = xs >> 32
	var ys_h : uint64 = ys >> 32
	var xs_l : uint64 = xs & 0xffffffff
	var ys_l : uint64 = ys & 0xffffffff

	var t_l : uint64 = xs_l * ys_l
	var t_m : uint64 = xs_l * ys_h + xs_h * ys_l
	var t_h : uint64 = xs_h * ys_h

	var prod_l : uint64 = t_l + (t_m << 32)
	var prod_h : uint64 = t_h + (t_m >> 32)
	/* Unsigned wraparound of the low word is a carry into the high word */
	if t_l > prod_l
		prod_h++
	;;

	/* Bit 0 of a significand has exponent e - 52, so bit 0 of the product has their sum */
	var prod_n = xn != yn
	var prod_lastbit_e = (xe - 52) + (ye - 52)
	var prod_first1 = find_first1_64_hl(prod_h, prod_l, 105)
	var prod_firstbit_e = prod_lastbit_e + prod_first1

	var z_firstbit_e = ze
	var z_lastbit_e = ze - 52
	var z_first1 = 52

	/* subnormals could throw firstbit_e calculations out of whack */
	if (zb & exp_mask64 == 0)
		z_first1 = find_first1_64(zs, z_first1)
		z_firstbit_e = z_lastbit_e + z_first1
	;;

	var res_n
	var res_h = 0
	var res_l = 0
	var res_first1
	var res_lastbit_e
	var res_firstbit_e

	/*
	   NOTE(review): bits of the smaller operand that fall below
	   res_lastbit_e during alignment are discarded rather than folded
	   into a sticky bit, which may affect ties in the final rounding.
	 */
	if prod_n == zn
		res_n = prod_n

		/*
		   Same signs: the magnitudes add. The operand with the
		   higher leading bit fixes the alignment. Either prod keeps
		   its own position, or z is placed in res_h with 64 bits of
		   room below it for prod.
		 */
		if prod_firstbit_e >= z_firstbit_e
			/*
			    [ prod_h ][ prod_l ]
			         [ z...
			 */
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = (prod_h, prod_l)
			(res_h, res_l) = add_shifted(res_h, res_l, zs, z_lastbit_e - prod_lastbit_e)
		else
			/*
			        [ prod_h ][ prod_l ]
			    [ z...
			 */
			res_lastbit_e = z_lastbit_e - 64
			res_h = zs
			res_l = 0
			if prod_lastbit_e >= res_lastbit_e + 64
				/* In this situation, prod must be extremely subnormal */
				res_h += shl(prod_l, prod_lastbit_e - res_lastbit_e - 64)
			elif prod_lastbit_e >= res_lastbit_e
				res_h += shl(prod_h, prod_lastbit_e - res_lastbit_e)
				res_h += shr(prod_l, res_lastbit_e + 64 - prod_lastbit_e)
				res_l += shl(prod_l, prod_lastbit_e - res_lastbit_e)
			elif prod_lastbit_e + 64 >= res_lastbit_e
				res_h += shr(prod_h, res_lastbit_e - prod_lastbit_e)
				var l1 = shl(prod_h, prod_lastbit_e + 64 - res_lastbit_e)
				var l2 = shr(prod_l, res_lastbit_e - prod_lastbit_e)
				res_l = l1 + l2
				if res_l < l1
					res_h++
				;;
			elif prod_lastbit_e + 128 >= res_lastbit_e
				res_l += shr(prod_h, res_lastbit_e - prod_lastbit_e - 64)
			;;
		;;
	else
		/*
		   Opposite signs: subtract the smaller magnitude from the
		   larger. The result takes the sign of the larger. An exact
		   cancellation yields +0.0.
		 */
		match compare_hl_z(prod_h, prod_l, prod_firstbit_e, prod_lastbit_e, zs, z_firstbit_e, z_lastbit_e)
		| `std.Equal: -> 0.0
		| `std.Before:
			/* prod > z */
			res_n = prod_n
			res_lastbit_e = prod_lastbit_e
			(res_h, res_l) = sub_shifted(prod_h, prod_l, zs, z_lastbit_e - prod_lastbit_e)
		| `std.After:
			/* z > prod */
			res_n = zn
			res_lastbit_e = z_lastbit_e - 64
			(res_h, res_l) = sub_shifted(zs, 0, prod_h, prod_lastbit_e + 64 - (z_lastbit_e - 64))
			(res_h, res_l) = sub_shifted(res_h, res_l, prod_l, prod_lastbit_e - (z_lastbit_e - 64))
		;;
	;;

	/*
	   Locate the leading 1 of [ res_h ][ res_l ]. res_first1 == 63
	   means nothing was found in res_h (presumably find_first1_64
	   returns -1 for a zero word), so the search continues in res_l.
	 */
	res_first1 = 64 + find_first1_64(res_h, 55)
	if res_first1 == 63
		res_first1 = find_first1_64(res_l, 63)
	;;
	res_firstbit_e = res_first1 + res_lastbit_e

	/*
	   Finally, res_h and res_l are the high and low bits of
	   the result. They now need to be assembled into a flt64.
	   Subnormals and infinities could be a problem.
	 */
	var res_s = 0
	if res_firstbit_e <= -1023
		/*
		   Subnormal case: bit j of the significand has exponent
		   -1074 + j, so [ res_h ][ res_l ] is shifted by
		   res_lastbit_e + 1074 (12 - 1022 == -1074 + 64).
		 */
		if res_lastbit_e + 128 < 12 - 1022
			res_s = shr(res_h, 12 - 1022 - (res_lastbit_e + 128))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		elif res_lastbit_e + 64 < 12 - 1022
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shr(res_l, 12 - 1022 - (res_lastbit_e + 64))
		else
			res_s = shl(res_h, -12 + (res_lastbit_e + 128) - (-1022))
			res_s |= shl(res_l, -12 + (res_lastbit_e + 64) - (-1022))
		;;

		if need_round_away(res_h, res_l, res_first1 + (-1074 - res_firstbit_e))
			res_s++
		;;

		/*
		   No need for exponents, they are all zero. A rounding
		   carry into bit 52 lands in the exponent field, giving
		   the smallest normal number.
		 */
		var res = res_s
		if res_n
			res |= (1 << 63)
		;;
		-> std.flt64frombits(res)
	;;

	if res_firstbit_e >= 1024
		/* Infinity case */
		if res_n
			-> std.flt64frombits(0xfff0000000000000)
		else
			-> std.flt64frombits(0x7ff0000000000000)
		;;
	;;

	/*
	   Normal case: take the 53 bits res_first1 - 52 .. res_first1
	   into res_s (bit 52 is the leading 1) and round the rest.
	 */
	if res_first1 - 52 >= 64
		/* The 53-bit window lies entirely in res_h */
		res_s = shr(res_h, (res_first1 : int64) - 64 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	elif res_first1 - 52 >= 0
		/* The window straddles res_h and res_l */
		res_s = shl(res_h, 64 - (res_first1 - 52))
		res_s |= shr(res_l, res_first1 - 52)
		if need_round_away(res_h, res_l, res_first1 - 52)
			res_s++
		;;
	else
		/*
		   Fewer than 53 significant bits: res_h is zero, all bits
		   sit in res_l, and the value is exact, so no rounding is
		   needed. (Shifting res_h here, as before, always gave 0.)
		 */
		res_s = shl(res_l, 52 - res_first1)
	;;

	/*
	   The res_s++s might have carried into bit 53. Renormalize;
	   the bit shifted out is zero because res_s is then exactly 2^53.
	 */
	if res_s & (1 << 53) != 0
		res_s = res_s >> 1
		res_firstbit_e++
		if res_firstbit_e >= 1024
			if res_n
				-> std.flt64frombits(0xfff0000000000000)
			else
				-> std.flt64frombits(0x7ff0000000000000)
			;;
		;;
	;;

	-> std.flt64assem(res_n, res_firstbit_e, res_s)
}

/*
   Return [ h ][ l ] + (a << s) as a (high, low) pair of uint64s.
   A negative s shifts a right instead, discarding the bits shifted
   out. With s == 0 the result is [ h ][ l + a ], carrying into h on
   overflow of the low word.
 */
const add_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	/* The shifted addend lies entirely in the high word */
	if s >= 64
		-> (h + shl(a, s - 64), l)
	;;

	var hi : uint64, lo : uint64, addend : uint64
	if s >= 0
		/* The top s bits of a spill over into the high word */
		hi = h + shr(a, 64 - s)
		addend = shl(a, s)
	else
		hi = h
		addend = shr(a, -s)
	;;
	lo = l + addend
	/* Unsigned wraparound of the low word is a carry into the high word */
	if lo < l
		hi++
	;;
	-> (hi, lo)
}

/*
   Return [ h ][ l ] - (a << s) as a (high, low) pair of uint64s,
   using the same shift convention as add_shifted: a negative s
   shifts a right, and s == 0 subtracts a from the low word.
 */
const sub_shifted = {h : uint64, l : uint64, a : uint64, s : int64
	/* The shifted subtrahend lies entirely in the high word */
	if s >= 64
		-> (h - shl(a, s - 64), l)
	;;

	var hi : uint64, subtrahend : uint64
	if s >= 0
		hi = h - shr(a, 64 - s)
		subtrahend = shl(a, s)
	else
		hi = h
		subtrahend = shr(a, -s)
	;;
	/* Borrow from the high word when the low word is too small */
	if subtrahend > l
		hi--
	;;
	-> (hi, l - subtrahend)
}

/*
   Compare the magnitude of the 128-bit integer [ h ][ l ] with the
   64-bit integer z. Bit 0 of l has exponent hl_lastbit_e, and bit 0
   of z has exponent z_lastbit_e. The *_firstbit_e arguments are the
   exponents of each value's leading 1 bit.

   Returns `std.Before if [ h ][ l ] > z, `std.After if z > [ h ][ l ],
   and `std.Equal if the two values are equal.
 */
const compare_hl_z = {h : uint64, l : uint64, hl_firstbit_e : int64, hl_lastbit_e : int64, z : uint64, z_firstbit_e : int64, z_lastbit_e : int64
	/* Different leading-bit exponents decide it immediately */
	if hl_firstbit_e > z_firstbit_e
		-> `std.Before
	elif hl_firstbit_e < z_firstbit_e
		-> `std.After
	;;

	/*
	   Walk both values downwards from their (aligned) leading 1.
	   h_k and z_k are the indices of the next unchecked bit of h
	   and z. h_k starts negative when the leading 1 of [ h ][ l ]
	   is in l.
	 */
	var h_k : int64 = (hl_firstbit_e - hl_lastbit_e - 64)
	var z_k : int64 = (z_firstbit_e - z_lastbit_e)
	while h_k >= 0 && z_k >= 0
		var h1 = h & shl(1, h_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if h1 && !z1
			-> `std.Before
		elif !h1 && z1
			-> `std.After
		;;
		h_k--
		z_k--
	;;

	/*
	   z is exhausted: any 1 among the unchecked bits 0..h_k of h
	   (inclusive, hence 63 - h_k) or anywhere in l makes [ h ][ l ]
	   the larger value.
	 */
	if z_k < 0
		if (h & shr((-1 : uint64), 63 - h_k) != 0) || (l != 0)
			-> `std.Before
		else
			-> `std.Equal
		;;
	;;

	/*
	   Continue into l. If the leading 1 was already in l (h was
	   never scanned), start from that bit rather than from bit 63,
	   so that l stays aligned with z.
	 */
	var l_k : int64 = std.min(63, hl_firstbit_e - hl_lastbit_e)
	while l_k >= 0 && z_k >= 0
		var l1 = l & shl(1, l_k) != 0
		var z1 = z & shl(1, z_k) != 0
		if l1 && !z1
			-> `std.Before
		elif !l1 && z1
			-> `std.After
		;;
		l_k--
		z_k--
	;;

	/* Whichever side still has a 1 among its unchecked bits 0..k is larger */
	if (z_k < 0) && (l & shr((-1 : uint64), 63 - l_k) != 0)
		-> `std.Before
	elif (l_k < 0) && (z & shr((-1 : uint64), 63 - z_k) != 0)
		-> `std.After
	;;

	-> `std.Equal
}