summaryrefslogtreecommitdiff log msg author committer range
path: root/lib/math/pown-impl.myr
blob: 2feecb539f7bc26921f1f7ae1abd467d452acfd8 (plain)
 ```1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 ``` ``````use std use "fpmath" use "log-impl" use "log-overkill" use "sum-impl" use "util" /* This is an implementation of pown: computing x^n where n is an integer. We sort of follow [PEB04], but without their high-radix log_2. Instead, we use log-overkill, which should be good enough. */ pkg math = pkglocal const pown32 : (x : flt32, n : int32 -> flt32) pkglocal const pown64 : (x : flt64, n : int64 -> flt64) pkglocal const rootn32 : (x : flt32, q : uint32 -> flt32) pkglocal const rootn64 : (x : flt64, q : uint64 -> flt64) ;; type fltdesc(@f, @u, @i) = struct explode : (f : @f -> (bool, @i, @u)) assem : (n : bool, e : @i, s : @u -> @f) tobits : (f : @f -> @u) frombits : (u : @u -> @f) C : (@u, @u)[:] one_over_ln2_hi : @u one_over_ln2_lo : @u nan : @u inf : @u neginf : @u magcmp : (f : @f, g : @f -> std.order) two_by_two : (x : @f, y : @f -> (@f, @f)) log_overkill : (x : @f -> (@f, @f)) emin : @i emax : @i imax : @i imin : @i ;; const desc32 : fltdesc(flt32, uint32, int32) = [ .explode = std.flt32explode, .assem = std.flt32assem, .tobits = std.flt32bits, .frombits = std.flt32frombits, .C = accurate_logs32[0:130], /* See log-impl.myr */ .one_over_ln2_hi = 0x3fb8aa3b, /* 1/ln(2), top part */ .one_over_ln2_lo = 0x32a57060, /* 1/ln(2), bottom part */ .nan = 0x7fc00000, .inf = 0x7f800000, .neginf = 0xff800000, .magcmp = mag_cmp32, .two_by_two = two_by_two32, .log_overkill = logoverkill32, .emin = -126, .emax = 127, .imax = 2147483647, /* For detecting overflow in final exponent */ .imin = -2147483648, ] const desc64 : fltdesc(flt64, uint64, int64) = [ .explode = std.flt64explode, .assem = std.flt64assem, .tobits = std.flt64bits, .frombits = std.flt64frombits, .C = accurate_logs64[0:130], /* See log-impl.myr */ .one_over_ln2_hi = 0x3ff71547652b82fe, .one_over_ln2_lo = 0x3c7777d0ffda0d24, .nan = 0x7ff8000000000000, .inf = 0x7ff0000000000000, .neginf = 0xfff0000000000000, .magcmp = mag_cmp64, .two_by_two = two_by_two64, .log_overkill = logoverkill64, .emin = -1022, .emax = 1023, .imax = 9223372036854775807, .imin = -9223372036854775808, ] const pown32 = {x : flt32, n : int32 -> powngen(x, n, desc32) } const pown64 = {x : flt64, n : int64 -> powngen(x, n, desc64) } generic powngen = {x : @f, n : @i, d : fltdesc(@f, @u, @i) :: numeric,floating,std.equatable @f, numeric,integral @u, numeric,integral @i var xb xb = d.tobits(x) var xn : bool, xe : @i, xs : @u (xn, xe, xs) = d.explode(x) var nf : @f = (n : @f) /* Special cases. Note we do not follow IEEE exceptions. */ if n == 0 /* Anything^0 is 1. We're taking the view that x is a tiny range of reals, so a dense subset of them are 1, even if x is 0.0. */ -> 1.0 elif std.isnan(x) /* Propagate NaN (why doesn't this come first? Ask IEEE.) */ -> d.frombits(d.nan) elif (x == 0.0 || x == -0.0) if n < 0 && (n % 2 == 1) && xn /* (+/- 0)^n = +/- oo */ -> d.frombits(d.neginf) elif n < 0 -> d.frombits(d.inf) elif n % 2 == 1 /* (+/- 0)^n = +/- 0 (n odd) */ -> d.assem(xn, d.emin - 1, 0) else -> 0.0 ;; elif n == 1 /* Anything^1 is itself */ -> x ;; /* (-f)^n = (-1)^n * (f)^n. Figure this out now, then pretend f >= 0.0 */ var ult_sgn = 1.0 if xn && (n % 2 == 1 || n % 2 == -1) ult_sgn = -1.0 ;; /* Compute (with x = xs * 2^e) x^n = 2^(n*log2(xs)) * 2^(n*e) = 2^(I + F) * 2^(n*e) = 2^(F) * 2^(I+n*e) Since n and e, and I are all integers, we can get the last part from scale2. The hard part is computing I and F, and then computing 2^F. */ var ln_xs_hi, ln_xs_lo (ln_xs_hi, ln_xs_lo) = d.log_overkill(d.assem(false, 0, xs)) /* Now x^n = 2^(n * [ ln_xs / ln(2) ]) * 2^(n + e) */ var ls1 : @f (ls1, ls1) = d.two_by_two(ln_xs_hi, d.frombits(d.one_over_ln2_hi)) (ls1, ls1) = d.two_by_two(ln_xs_hi, d.frombits(d.one_over_ln2_lo)) (ls1, ls1) = d.two_by_two(ln_xs_lo, d.frombits(d.one_over_ln2_hi)) (ls1, ls1) = d.two_by_two(ln_xs_lo, d.frombits(d.one_over_ln2_lo)) /* Now log2(xs) = Sum(ls1), so x^n = 2^(n * Sum(ls1)) * 2^(n * e) */ var E1, E2 (E1, E2) = double_compensated_sum(ls1[0:8]) var ls2 : @f var ls2s : @f var I = 0 (ls2, ls2) = d.two_by_two(E1, nf) (ls2, ls2) = d.two_by_two(E2, nf) ls2 = 0.0 /* Now x^n = 2^(Sum(ls2)) * 2^(n + e) */ for var j = 0; j < 5; ++j var i = rn(ls2[j]) I += i ls2[j] -= (i : @f) ;; var F1, F2 std.slcp(ls2s[0:5], ls2[0:5]) std.sort(ls2s[0:5], d.magcmp) (F1, F2) = double_compensated_sum(ls2s[0:5]) if (F1 < 0.0 || F1 > 1.0) var i = rn(F1) I += i ls2 -= (i : @f) std.slcp(ls2s[0:5], ls2[0:5]) std.sort(ls2s[0:5], d.magcmp) (F1, F2) = double_compensated_sum(ls2s[0:5]) ;; /* Now, x^n = 2^(F1 + F2) * 2^(I + n*e). */ var ls3 : @f var log2_hi, log2_lo (log2_hi, log2_lo) = d.C (ls3, ls3) = d.two_by_two(F1, d.frombits(log2_hi)) (ls3, ls3) = d.two_by_two(F1, d.frombits(log2_lo)) (ls3, ls3) = d.two_by_two(F2, d.frombits(log2_hi)) var G1, G2 (G1, G2) = double_compensated_sum(ls3[0:6]) var base = exp(G1) + G2 var pow_xen = xe * n var pow = pow_xen + I if pow_xen / n != xe || (I > 0 && d.imax - I < pow_xen) || (I < 0 && d.imin - I > pow_xen) /* The exponent overflowed. There's no way this is representable. We need to at least recover the correct sign. If the overflow was from the multiplication, then the sign we want is the sign that pow_xen should have been. If the overflow was from the addition, then we still want the sign that pow_xen should have had. */ if (xe > 0) == (n > 0) pow = 2 * d.emax else pow = 2 * d.emin ;; ;; -> ult_sgn * scale2(base, pow) } /* Rootn is barely different enough from pown to justify being split out into an entirely separate function. */ const rootn32 = {x : flt32, q : uint32 -> rootngen(x, q, desc32) } const rootn64 = {x : flt64, q : uint64 -> rootngen(x, q, desc64) } generic rootngen = {x : @f, q : @u, d : fltdesc(@f, @u, @i) :: numeric,floating,std.equatable @f, numeric,integral @u, numeric,integral @i var xb xb = d.tobits(x) var xn : bool, xe : @i, xs : @u (xn, xe, xs) = d.explode(x) var qf : @f = (q : @f) /* Special cases. Note we do not follow IEEE exceptions. */ if q == 0 /* "for any x (even a zero, quiet NaN, or infinity" */ -> 1.0 elif std.isnan(x) -> d.frombits(d.nan) elif (x == 0.0 || x == -0.0) if xn && q % 2 == 1 /* (+/- 0)^1/q = +/- oo (q odd) */ -> d.assem(xn, d.emax, 0) else -> d.frombits(d.inf) ;; elif q == 1 /* Anything^1/1 is itself */ -> x ;; /* As in pown */ var ult_sgn = 1.0 if xn && (q % 2 == 1) ult_sgn = -1.0 ;; /* Similar to pown. Let e/q = E + psi, with E an integer. x^(1/q) = e^(log(xs)/q) * 2^(e/q) = e^(log(xs)/q) * 2^(psi) * 2^E = e^(log(xs)/q) * e^(log(2) * psi) * 2^E = e^( log(xs)/q + log(2) * psi ) * 2^E I've opted to do things just in terms of natural base here because we don't have an integer part, I, that we can slide over in infinite precision. */ /* Calculate 1/q in very high precision */ var r1 = 1.0 / qf var r2 = -math.fma(r1, qf, -1.0) / qf var ln_xs_hi, ln_xs_lo (ln_xs_hi, ln_xs_lo) = d.log_overkill(d.assem(false, 0, xs)) var ls1 : @f (ls1, ls1) = d.two_by_two(ln_xs_hi, r1) (ls1, ls1) = d.two_by_two(ln_xs_hi, r2) (ls1, ls1) = d.two_by_two(ln_xs_lo, r1) var E : @i if q > std.abs(xe) /* Don't cast q to @i unless we're sure it's in small range */ E = 0 else E = xe / (q : @i) ;; var qpsi = xe - q * E var psi_hi = (qpsi : @f) / qf var psi_lo = -math.fma(psi_hi, qf, -(qpsi : @f)) / qf var log2_hi, log2_lo (log2_hi, log2_lo) = d.C (ls1[ 6], ls1[ 7]) = d.two_by_two(psi_hi, d.frombits(log2_hi)) (ls1[ 8], ls1[ 9]) = d.two_by_two(psi_hi, d.frombits(log2_lo)) (ls1, ls1) = d.two_by_two(psi_lo, d.frombits(log2_hi)) var G1, G2 (G1, G2) = double_compensated_sum(ls1[0:12]) /* G1 + G2 approximates log(xs)/q + log(2)*psi */ var base = exp(G1) + G2 -> ult_sgn * scale2(base, E) } ``````