// Code generated by "stringer -type=Accuracy"; DO NOT EDIT.

package big

import "strconv"

// _ is a compile-time guard: each index expression below is only valid while
// the Accuracy constants keep the values they had when this file was generated.
func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[Below - -1]
	_ = x[Exact-0]
	_ = x[Above-1]
}

// _Accuracy_name holds the names of all Accuracy constants, concatenated.
const _Accuracy_name = "BelowExactAbove"

// _Accuracy_index holds the offsets into _Accuracy_name; the k'th name spans
// _Accuracy_index[k]:_Accuracy_index[k+1].
var _Accuracy_index = [...]uint8{0, 5, 10, 15}

// String returns the constant name of i ("Below", "Exact", or "Above"),
// or "Accuracy(n)" for any other value n.
func (i Accuracy) String() string {
	i -= -1 // shift so that Below (-1) maps to table index 0
	if i < 0 || i >= Accuracy(len(_Accuracy_index)-1) {
		// Out of range: undo the shift and print the raw numeric value.
		return "Accuracy(" + strconv.FormatInt(int64(i+-1), 10) + ")"
	}
	return _Accuracy_name[_Accuracy_index[i]:_Accuracy_index[i+1]]
}
// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This file provides Go implementations of elementary multi-precision // arithmetic operations on word vectors. These have the suffix _g. // These are needed for platforms without assembly implementations of these routines. // This file also contains elementary operations that can be implemented // sufficiently efficiently in Go. package big import ( "math/bits" _ "unsafe" // for go:linkname ) // A Word represents a single digit of a multi-precision unsigned integer. type Word uint const ( _S = _W / 8 // word size in bytes _W = bits.UintSize // word size in bits _B = 1 << _W // digit base _M = _B - 1 // digit mask ) // In these routines, it is the caller's responsibility to arrange for // x, y, and z to all have the same length. We check this and panic. // The assembly versions of these routines do not include that check. // // The check+panic also has the effect of teaching the compiler that // “i in range for z” implies “i in range for x and y”, eliminating all // bounds checks in loops from 0 to len(z) and vice versa. // ---------------------------------------------------------------------------- // Elementary operations on words // // These operations are used by the vector operations below. // mulWW returns z1:z0 = x*y, where z1:z0 denotes z1<<_W + z0. // It is a wrapper for bits.Mul to avoid writing out type conversions. // z1<<_W + z0 = x*y func mulWW(x, y Word) (z1, z0 Word) { hi, lo := bits.Mul(uint(x), uint(y)) return Word(hi), Word(lo) } // mulAddWWW returns z1:z0 = x*y + c, where z1:z0 denotes z1<<_W + z0. // It is a wrapper for bits.Mul and bits.Add to avoid writing out type conversions. func mulAddWWW(x, y, c Word) (z1, z0 Word) { hi, lo := bits.Mul(uint(x), uint(y)) var cc uint lo, cc = bits.Add(lo, uint(c), 0) return Word(hi + cc), Word(lo) } // nlz returns the number of leading zeros in x. 
// It is a wrapper for bits.LeadingZeros to avoid writing out type conversions. func nlz(x Word) uint { return uint(bits.LeadingZeros(uint(x))) } // addWW returns z, cc = x + y + c. // It is a wrapper for bits.Add to avoid writing out type conversions. func addWW(x, y, c Word) (z, cc Word) { zu, cu := bits.Add(uint(x), uint(y), uint(c)) return Word(zu), Word(cu) } // subWW returns z, cc = x - y - c. // It is a wrapper for bits.Sub to avoid writing out type conversions. func subWW(x, y, c Word) (z, cc Word) { zu, cu := bits.Sub(uint(x), uint(y), uint(c)) return Word(zu), Word(cu) } // addVV_g sets z, c = x + y. // It requires len(z) == len(x) == len(y). // The resulting carry is either 0 or 1. func addVV_g(z, x, y []Word) (c Word) { if len(x) != len(z) || len(y) != len(z) { panic("addVV len") } for i := range z { z[i], c = addWW(x[i], y[i], c) } return c } // subVV_g sets z, c = x - y. // It requires len(z) == len(x) == len(y). // The resulting carry is either 0 or 1. func subVV_g(z, x, y []Word) (c Word) { if len(x) != len(z) || len(y) != len(z) { panic("subVV len") } for i := range z { z[i], c = subWW(x[i], y[i], c) } return c } // addSubVV_g sets za, ca = x+y and zs, cs = x-y. // It requires len(za) == len(zs) == len(x) == len(y). // The resulting carries are either 0 or 1. func addSubVV_g(za, zs, x, y []Word) (ca, cs Word) { if len(zs) != len(za) || len(x) != len(za) || len(y) != len(za) { panic("addSubVV len") } for i := range za { xi := x[i] yi := y[i] za[i], ca = addWW(xi, yi, ca) zs[i], cs = subWW(xi, yi, cs) } return ca, cs } // addVW sets z = x + y, returning the final carry c. // It requires len(z) == len(x). // If len(z) == 0, c = y; otherwise, c is 0 or 1. // // addVW should be an internal detail, // but widely used packages access it using linkname. // Notable members of the hall of shame include: // - github.com/remyoudompheng/bigfft // // Do not remove or change the type signature. // See go.dev/issue/67401. 
// //go:linkname addVW func addVW(z, x []Word, y Word) (c Word) { if len(x) != len(z) { panic("addVW len") } if len(z) == 0 { return y } z[0], c = addWW(x[0], y, 0) if c == 0 { if &z[0] != &x[0] { copy(z[1:], x[1:]) } return 0 } for i := 1; i < len(z); i++ { xi := x[i] if xi != ^Word(0) { z[i] = xi + 1 if &z[0] != &x[0] { copy(z[i+1:], x[i+1:]) } return 0 } z[i] = 0 } return 1 } // subVW sets z = x - y, returning the final carry c. // It requires len(z) == len(x). // If len(z) == 0, c = y; otherwise, c is 0 or 1. // // subVW should be an internal detail, // but widely used packages access it using linkname. // Notable members of the hall of shame include: // - github.com/remyoudompheng/bigfft // // Do not remove or change the type signature. // See go.dev/issue/67401. // //go:linkname subVW func subVW(z, x []Word, y Word) (c Word) { if len(x) != len(z) { panic("subVW len") } if len(z) == 0 { return y } z[0], c = subWW(x[0], y, 0) if c == 0 { if &z[0] != &x[0] { copy(z[1:], x[1:]) } return 0 } for i := 1; i < len(z); i++ { xi := x[i] if xi != 0 { z[i] = xi - 1 if &z[0] != &x[0] { copy(z[i+1:], x[i+1:]) } return 0 } z[i] = ^Word(0) } return 1 } // lshVU_g sets z = x << s, returning the final carry word. // It requires len(z) == len(x) and 1 ≤ s ≤ _W-1. func lshVU_g(z, x []Word, s uint) (c Word) { if len(x) != len(z) { panic("lshVU len") } if s <= 0 || s >= _W { panic("lshVU shift") } if len(z) == 0 { return } s &= _W - 1 // hint to the compiler that shifts by s don't need guard code ŝ := _W - s ŝ &= _W - 1 // ditto c = x[len(z)-1] >> ŝ for i := len(z) - 1; i > 0; i-- { z[i] = x[i]<<s | x[i-1]>>ŝ } z[0] = x[0] << s return } // rshVU_g sets z = x >> s, returning the final carry (underflow) word. // It requires len(z) == len(x) and 1 ≤ s ≤ _W-1. 
func rshVU_g(z, x []Word, s uint) (c Word) { if len(x) != len(z) { panic("rshVU len") } if s <= 0 || s >= _W { panic("lshVU shift") } if len(z) == 0 { return } s &= _W - 1 // hint to the compiler that shifts by s don't need guard code ŝ := _W - s ŝ &= _W - 1 // ditto c = x[0] << ŝ for i := 1; i < len(z); i++ { z[i-1] = x[i-1]>>s | x[i]<<ŝ } z[len(z)-1] = x[len(z)-1] >> s return } // addLshVVU_g sets z = x + y<<s, returning the final carry word. // It requires len(z) == len(x) == len(y) and 1 ≤ s ≤ _W-1. func addLshVVU_g(z, x, y []Word, s uint) (c Word) { if len(x) != len(z) || len(y) != len(z) { panic("addLshVVU len") } if s <= 0 || s >= _W { panic("addLshVVU shift") } s &= _W - 1 // hint to the compiler that shifts by s don't need guard code ŝ := _W - s ŝ &= _W - 1 // ditto var cs Word for i := range z { yi := y[i] cs, yi = yi>>ŝ, yi<<s|cs z[i], c = addWW(x[i], yi, c) } return c + cs } // subLshVVU_g sets z = x - y<<s, returning the final carry word. // It requires len(z) == len(x) == len(y) and 1 ≤ s ≤ _W-1. func subLshVVU_g(z, x, y []Word, s uint) (c Word) { if len(x) != len(z) || len(y) != len(z) { panic("subLshVVU len") } if s <= 0 || s >= _W { panic("subLshVVU shift") } s &= _W - 1 // hint to the compiler that shifts by s don't need guard code ŝ := _W - s ŝ &= _W - 1 // ditto var cs Word for i := range z { yi := y[i] cs, yi = yi>>ŝ, yi<<s|cs z[i], c = subWW(x[i], yi, c) } return c + cs } // mulAddVWW_g sets z = x*m + a, returning the final carry word. func mulAddVWW_g(z, x []Word, m, a Word) (c Word) { if len(x) != len(z) { panic("mulAddVWW len") } c = a for i := range z { c, z[i] = mulAddWWW(x[i], m, c) } return } // addMulVVWW_g sets z = x + y*m + a, returning the final carry word. 
// It requires len(z) == len(x) == len(y).
func addMulVVWW_g(z, x, y []Word, m, a Word) (c Word) {
	if len(x) != len(z) || len(y) != len(z) {
		panic("addMulVVWW len")
	}
	c = a
	for i := range z {
		// z1:z0 = y[i]*m + x[i]; the running carry c is then added to the low word.
		z1, z0 := mulAddWWW(y[i], m, x[i])
		z[i], c = addWW(z0, c, 0)
		c += z1 // no overflow: x[:n]+y[:n]*m+a always fits in n+1 words
	}
	return c
}

// subMulVVWW_g sets z = x - y*m - a, returning the final carry word.
// It requires len(z) == len(x) == len(y).
func subMulVVWW_g(z, x, y []Word, m, a Word) (c Word) {
	if len(x) != len(z) || len(y) != len(z) {
		panic("subMulVVWW len")
	}
	for i := range z {
		var lo Word
		// a carries the high word of y[i]*m + a into the next iteration;
		// lo is the word subtracted at position i.
		a, lo = mulAddWWW(y[i], m, a)
		z[i], c = subWW(x[i], lo, c)
	}
	// no overflow: y*m+a fits in n+1 words, so 0 - (y*m+a) does too.
	return a + c
}

// divWW divides the two-word value x1:x0 (that is, x1<<_W + x0) by y and
// returns the quotient q and remainder r, so that x1:x0 == q*y + r with r < y.
// It requires x1 < y, so that q fits in a single Word.
// m must be the reciprocal of the normalized divisor as computed by
// reciprocalWord(y): m = floor((_B^2 - 1) / d - _B) with d = y<<nlz(y).
// The method is the approximate-reciprocal division from "Improved Division
// by Invariant Integers" (IEEE Transactions on Computers, 11 Jun. 2010).
func divWW(x1, x0, y, m Word) (q, r Word) {
	// Normalize: shift y so that its msb is set, and shift x1:x0 by the same
	// amount. Because x1 < y, x1 has at least s leading zeros and loses no bits.
	s := nlz(y)
	if s != 0 {
		x1 = x1<<s | x0>>(_W-s)
		x0 <<= s
		y <<= s
	}
	d := uint(y)
	// We know that
	//   m = ⎣(B^2-1)/d⎦-B
	//   ⎣(B^2-1)/d⎦ = m+B
	//   (B^2-1)/d = m+B+delta1    0 <= delta1 <= (d-1)/d
	//   B^2/d = m+B+delta2        0 <= delta2 <= 1
	// The quotient we're trying to compute is
	//   quotient = ⎣(x1*B+x0)/d⎦
	//            = ⎣(x1*B*(B^2/d)+x0*(B^2/d))/B^2⎦
	//            = ⎣(x1*B*(m+B+delta2)+x0*(m+B+delta2))/B^2⎦
	//            = ⎣(x1*m+x1*B+x0)/B + x0*m/B^2 + delta2*(x1*B+x0)/B^2⎦
	// The latter two terms of this three-term sum are between 0 and 1.
	// So we can compute just the first term, and we will be low by at most 2.
	t1, t0 := bits.Mul(uint(m), uint(x1))
	_, c := bits.Add(t0, uint(x0), 0)
	t1, _ = bits.Add(t1, uint(x1), c)
	// The quotient is either t1, t1+1, or t1+2.
	// We'll try t1 and adjust if needed.
	qq := t1
	// compute remainder r=x-d*q.
	dq1, dq0 := bits.Mul(d, qq)
	r0, b := bits.Sub(uint(x0), dq0, 0)
	r1, _ := bits.Sub(uint(x1), dq1, b)
	// The remainder we just computed is bounded above by B+d:
	// r = x1*B + x0 - d*q.
	//   = x1*B + x0 - d*⎣(x1*m+x1*B+x0)/B⎦
	//   = x1*B + x0 - d*((x1*m+x1*B+x0)/B-alpha)                                   0 <= alpha < 1
	//   = x1*B + x0 - x1*d/B*m                         - x1*d - x0*d/B + d*alpha
	//   = x1*B + x0 - x1*d/B*⎣(B^2-1)/d-B⎦             - x1*d - x0*d/B + d*alpha
	//   = x1*B + x0 - x1*d/B*((B^2-1)/d-B-beta)        - x1*d - x0*d/B + d*alpha   0 <= beta < 1
	//   = x1*B + x0 - x1*B + x1/B + x1*d + x1*d/B*beta - x1*d - x0*d/B + d*alpha
	//   =        x0        + x1/B        + x1*d/B*beta        - x0*d/B + d*alpha
	//   = x0*(1-d/B) + x1*(1+d*beta)/B + d*alpha
	//   <  B*(1-d/B) +  d*B/B          + d          because x0<B (and 1-d/B>0), x1<d, 1+d*beta<=B, alpha<1
	//   =  B - d     +  d              + d
	//   = B+d
	// So r1 can only be 0 or 1. If r1 is 1, then we know q was too small.
	// Add 1 to q and subtract d from r. That guarantees that r is <B, so
	// we no longer need to keep track of r1.
	if r1 != 0 {
		qq++
		r0 -= d
	}
	// If the remainder is still too large, increment q one more time.
	if r0 >= d {
		qq++
		r0 -= d
	}
	// r0 is the remainder of the normalized division; undo the normalization shift.
	return Word(qq), Word(r0 >> s)
}

// reciprocalWord returns the reciprocal of the divisor d1 in the form divWW
// expects: rec = floor((_B^2 - 1) / u - _B), where u = d1 << nlz(d1) is d1
// normalized so that its most significant bit is set. d1 must be nonzero
// (otherwise u == 0 and bits.Div panics).
func reciprocalWord(d1 Word) Word {
	u := uint(d1 << nlz(d1))
	x1 := ^u // == _M - u; less than u because u's msb is set, as bits.Div requires
	x0 := uint(_M)
	rec, _ := bits.Div(x1, x0, u) // (_B^2-1)/U-_B = (_B*(_M-C)+_M)/U
	return Word(rec)
}
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

//go:build !math_big_pure_go

//go:generate go test ./internal/asmgen -generate

package big

import _ "unsafe" // for linkname

// implemented in arith_$GOARCH.s

// addVV sets z = x + y, returning the carry out bit.
//
// addVV should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
//   - github.com/remyoudompheng/bigfft
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname addVV
//go:noescape
func addVV(z, x, y []Word) (c Word)

// addSubVV sets za, ca = x + y; zs, cs = x - y.
//
//go:noescape
func addSubVV(za, zs, x, y []Word) (ca, cs Word)

// subVV sets z = x - y, returning the carry (borrow) out bit.
//
// subVV should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
//   - github.com/remyoudompheng/bigfft
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname subVV
//go:noescape
func subVV(z, x, y []Word) (c Word)

// shlVU sets z = x<<s, returning the high bits c. 0 ≤ s ≤ _W-1.
// Unlike lshVU it accepts s == 0, in which case it copies x into z
// and returns 0.
//
// shlVU should be an internal detail (and a stale one at that),
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
//   - github.com/remyoudompheng/bigfft
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname shlVU
func shlVU(z, x []Word, s uint) (c Word) {
	if s == 0 {
		copy(z, x)
		return 0
	}
	return lshVU(z, x, s)
}

// lshVU sets z = x<<s, returning the high bits c. 1 ≤ s ≤ _W-1.
//
//go:noescape
func lshVU(z, x []Word, s uint) (c Word)

// rshVU sets z = x>>s, returning the low bits c. 1 ≤ s ≤ _W-1.
//
//go:noescape
func rshVU(z, x []Word, s uint) (c Word)

// addLshVVU sets z = x + y<<s, returning the carry out word. 1 ≤ s ≤ _W-1.
func addLshVVU(z, x, y []Word, s uint) (c Word)

// subLshVVU sets z = x - y<<s, returning the carry (borrow) out word. 1 ≤ s ≤ _W-1.
func subLshVVU(z, x, y []Word, s uint) (c Word)

// mulAddVWW sets z = x*m + a.
//
// mulAddVWW should be an internal detail,
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
//   - github.com/remyoudompheng/bigfft
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname mulAddVWW
//go:noescape
func mulAddVWW(z, x []Word, m, a Word) (c Word)

// addMulVVW sets z = z + x*y, returning the carry out word.
//
// addMulVVW should be an internal detail (and a stale one at that),
// but widely used packages access it using linkname.
// Notable members of the hall of shame include:
//   - github.com/remyoudompheng/bigfft
//
// Do not remove or change the type signature.
// See go.dev/issue/67401.
//
//go:linkname addMulVVW
func addMulVVW(z, x []Word, y Word) (c Word) {
	return addMulVVWW(z, z, x, y, 0)
}

// addMulVVWW sets z = x+y*m+a.
//
//go:noescape
func addMulVVWW(z, x, y []Word, m, a Word) (c Word)

// subMulVVWW sets z = x-y*m-a.
//
//go:noescape
func subMulVVWW(z, x, y []Word, m, a Word) (c Word)

// mdivVWW sets z = x/d for odd d where x%d == 0 and d*inv == 1 mod _B.
func mdivVWW(z, x []Word, d, inv Word) (c Word)
// Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This file implements multi-precision decimal numbers. // The implementation is for float to decimal conversion only; // not general purpose use. // The only operations are precise conversion from binary to // decimal and rounding. // // The key observation and some code (shr) is borrowed from // strconv/decimal.go: conversion of binary fractional values can be done // precisely in multi-precision decimal because 2 divides 10 (required for // >> of mantissa); but conversion of decimal floating-point values cannot // be done precisely in binary representation. // // In contrast to strconv/decimal.go, only right shift is implemented in // decimal format - left shift can be done precisely in binary format. package big // A decimal represents an unsigned floating-point number in decimal representation. // The value of a non-zero decimal d is d.mant * 10**d.exp with 0.1 <= d.mant < 1, // with the most-significant mantissa digit at index 0. For the zero decimal, the // mantissa length and exponent are 0. // The zero value for decimal represents a ready-to-use 0.0. type decimal struct { mant []byte // mantissa ASCII digits, big-endian exp int // exponent } // at returns the i'th mantissa digit, starting with the most significant digit at 0. func (d *decimal) at(i int) byte { if 0 <= i && i < len(d.mant) { return d.mant[i] } return '0' } // Maximum shift amount that can be done in one pass without overflow. // A Word has _W bits and (1<<maxShift - 1)*10 + 9 must fit into Word. 
const maxShift = _W - 4

// TODO(gri) Since we know the desired decimal precision when converting
// a floating-point number, we may be able to limit the number of decimal
// digits that need to be computed by init by providing an additional
// precision argument and keeping track of when a number was truncated early
// (equivalent of "sticky bit" in binary rounding).

// TODO(gri) Along the same lines, enforce some limit to shift magnitudes
// to avoid "infinitely" long running conversions (until we run out of space).

// init initializes x to the decimal representation of m << shift (for
// shift >= 0), or m >> -shift (for shift < 0).
// A left shift is done in binary; whatever right shift remains after
// dropping m's trailing zero bits is done in decimal, at most maxShift
// bits per rsh call.
func (x *decimal) init(m nat, shift int) {
	// special case 0
	if len(m) == 0 {
		x.mant = x.mant[:0]
		x.exp = 0
		return
	}

	// Optimization: If we need to shift right, first remove any trailing
	// zero bits from m to reduce shift amount that needs to be done in
	// decimal format (since that is likely slower).
	if shift < 0 {
		ntz := m.trailingZeroBits()
		s := uint(-shift)
		if s >= ntz {
			s = ntz // shift at most ntz bits
		}
		m = nat(nil).rsh(m, s)
		shift += int(s)
	}

	// Do any shift left in binary representation.
	if shift > 0 {
		m = nat(nil).lsh(m, uint(shift))
		shift = 0
	}

	// Convert mantissa into decimal representation.
	s := m.utoa(10)
	n := len(s)
	// The integer value equals 0.s × 10**len(s), so the exponent is the
	// digit count.
	x.exp = n
	// Trim trailing zeros; instead the exponent is tracking
	// the decimal point independent of the number of digits.
	for n > 0 && s[n-1] == '0' {
		n--
	}
	x.mant = append(x.mant[:0], s[:n]...)

	// Do any (remaining) shift right in decimal representation.
	if shift < 0 {
		for shift < -maxShift {
			rsh(x, maxShift)
			shift += maxShift
		}
		rsh(x, uint(-shift))
	}
}

// rsh implements x >> s, for s <= maxShift.
// It rewrites x.mant in place (appending digits if the result is longer)
// and adjusts x.exp, then trims trailing zeros.
func rsh(x *decimal, s uint) {
	// Division by 1<<s using shift-and-subtract algorithm.

	// pick up enough leading digits to cover first shift
	r := 0 // read index
	var n Word
	for n>>s == 0 && r < len(x.mant) {
		ch := Word(x.mant[r])
		r++
		n = n*10 + ch - '0'
	}
	if n == 0 {
		// x == 0; shouldn't get here, but handle anyway
		x.mant = x.mant[:0]
		return
	}
	for n>>s == 0 {
		r++
		n *= 10
	}
	x.exp += 1 - r

	// read a digit, write a digit
	// From here on n < 10<<s, so each n>>s is a single decimal digit, and
	// 10<<maxShift fits in a Word. The write index w always trails the read
	// index r, so the in-place writes never clobber unread digits.
	w := 0 // write index
	mask := Word(1)<<s - 1
	for r < len(x.mant) {
		ch := Word(x.mant[r])
		r++
		d := n >> s
		n &= mask // n -= d << s
		x.mant[w] = byte(d + '0')
		w++
		n = n*10 + ch - '0'
	}

	// write extra digits that still fit
	for n > 0 && w < len(x.mant) {
		d := n >> s
		n &= mask
		x.mant[w] = byte(d + '0')
		w++
		n = n * 10
	}
	x.mant = x.mant[:w] // the number may be shorter (e.g. 1024 >> 10)

	// append additional digits that didn't fit
	for n > 0 {
		d := n >> s
		n &= mask
		x.mant = append(x.mant, byte(d+'0'))
		n = n * 10
	}

	trim(x)
}

// String returns x in plain positional decimal notation (no exponent),
// e.g. "0.0012", "12.5", or "1200"; the zero decimal yields "0".
func (x *decimal) String() string {
	if len(x.mant) == 0 {
		return "0"
	}

	var buf []byte
	switch {
	case x.exp <= 0:
		// 0.00ddd
		buf = make([]byte, 0, 2+(-x.exp)+len(x.mant))
		buf = append(buf, "0."...)
		buf = appendZeros(buf, -x.exp)
		buf = append(buf, x.mant...)

	case /* 0 < */ x.exp < len(x.mant):
		// dd.ddd
		buf = make([]byte, 0, 1+len(x.mant))
		buf = append(buf, x.mant[:x.exp]...)
		buf = append(buf, '.')
		buf = append(buf, x.mant[x.exp:]...)

	default: // len(x.mant) <= x.exp
		// ddd00
		buf = make([]byte, 0, x.exp)
		buf = append(buf, x.mant...)
		buf = appendZeros(buf, x.exp-len(x.mant))
	}

	return string(buf)
}

// appendZeros appends n 0 digits to buf and returns buf.
// A non-positive n appends nothing.
func appendZeros(buf []byte, n int) []byte {
	for ; n > 0; n-- {
		buf = append(buf, '0')
	}
	return buf
}

// shouldRoundUp reports if x should be rounded up
// if shortened to n digits. n must be a valid index
// for x.mant.
func shouldRoundUp(x *decimal, n int) bool { if x.mant[n] == '5' && n+1 == len(x.mant) { // exactly halfway - round to even return n > 0 && (x.mant[n-1]-'0')&1 != 0 } // not halfway - digit tells all (x.mant has no trailing zeros) return x.mant[n] >= '5' } // round sets x to (at most) n mantissa digits by rounding it // to the nearest even value with n (or fever) mantissa digits. // If n < 0, x remains unchanged. func (x *decimal) round(n int) { if n < 0 || n >= len(x.mant) { return // nothing to do } if shouldRoundUp(x, n) { x.roundUp(n) } else { x.roundDown(n) } } func (x *decimal) roundUp(n int) { if n < 0 || n >= len(x.mant) { return // nothing to do } // 0 <= n < len(x.mant) // find first digit < '9' for n > 0 && x.mant[n-1] >= '9' { n-- } if n == 0 { // all digits are '9's => round up to '1' and update exponent x.mant[0] = '1' // ok since len(x.mant) > n x.mant = x.mant[:1] x.exp++ return } // n > 0 && x.mant[n-1] < '9' x.mant[n-1]++ x.mant = x.mant[:n] // x already trimmed } func (x *decimal) roundDown(n int) { if n < 0 || n >= len(x.mant) { return // nothing to do } x.mant = x.mant[:n] trim(x) } // trim cuts off any trailing zeros from x's mantissa; // they are meaningless for the value of x. func trim(x *decimal) { i := len(x.mant) for i > 0 && x.mant[i-1] == '0' { i-- } x.mant = x.mant[:i] if i == 0 { x.exp = 0 } }
// Copyright 2014 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This file implements multi-precision floating-point numbers. // Like in the GNU MPFR library (https://www.mpfr.org/), operands // can be of mixed precision. Unlike MPFR, the rounding mode is // not specified with each operation, but with each operand. The // rounding mode of the result operand determines the rounding // mode of an operation. This is a from-scratch implementation. package big import ( "fmt" "math" "math/bits" ) const debugFloat = false // enable for debugging // A nonzero finite Float represents a multi-precision floating point number // // sign × mantissa × 2**exponent // // with 0.5 <= mantissa < 1.0, and MinExp <= exponent <= MaxExp. // A Float may also be zero (+0, -0) or infinite (+Inf, -Inf). // All Floats are ordered, and the ordering of two Floats x and y // is defined by x.Cmp(y). // // Each Float value also has a precision, rounding mode, and accuracy. // The precision is the maximum number of mantissa bits available to // represent the value. The rounding mode specifies how a result should // be rounded to fit into the mantissa bits, and accuracy describes the // rounding error with respect to the exact result. // // Unless specified otherwise, all operations (including setters) that // specify a *Float variable for the result (usually via the receiver // with the exception of [Float.MantExp]), round the numeric result according // to the precision and rounding mode of the result variable. // // If the provided result precision is 0 (see below), it is set to the // precision of the argument with the largest precision value before any // rounding takes place, and the rounding mode remains unchanged. 
Thus, // uninitialized Floats provided as result arguments will have their // precision set to a reasonable value determined by the operands, and // their mode is the zero value for RoundingMode (ToNearestEven). // // By setting the desired precision to 24 or 53 and using matching rounding // mode (typically [ToNearestEven]), Float operations produce the same results // as the corresponding float32 or float64 IEEE 754 arithmetic for operands // that correspond to normal (i.e., not denormal) float32 or float64 numbers. // Exponent underflow and overflow lead to a 0 or an Infinity for different // values than IEEE 754 because Float exponents have a much larger range. // // The zero (uninitialized) value for a Float is ready to use and represents // the number +0.0 exactly, with precision 0 and rounding mode [ToNearestEven]. // // Operations always take pointer arguments (*Float) rather // than Float values, and each unique Float value requires // its own unique *Float pointer. To "copy" a Float value, // an existing (or newly allocated) Float must be set to // a new value using the [Float.Set] method; shallow copies // of Floats are not supported and may lead to errors. type Float struct { prec uint32 mode RoundingMode acc Accuracy form form neg bool mant nat exp int32 } // An ErrNaN panic is raised by a [Float] operation that would lead to // a NaN under IEEE 754 rules. An ErrNaN implements the error interface. type ErrNaN struct { msg string } func (err ErrNaN) Error() string { return err.msg } // NewFloat allocates and returns a new [Float] set to x, // with precision 53 and rounding mode [ToNearestEven]. // NewFloat panics with [ErrNaN] if x is a NaN. func NewFloat(x float64) *Float { if math.IsNaN(x) { panic(ErrNaN{"NewFloat(NaN)"}) } return new(Float).SetFloat64(x) } // Exponent and precision limits. 
const (
	MaxExp  = math.MaxInt32  // largest supported exponent
	MinExp  = math.MinInt32  // smallest supported exponent
	MaxPrec = math.MaxUint32 // largest (theoretically) supported precision; likely memory-limited
)

// Internal representation: The mantissa bits x.mant of a nonzero finite
// Float x are stored in a nat slice long enough to hold up to x.prec bits;
// the slice may (but doesn't have to) be shorter if the mantissa contains
// trailing 0 bits. x.mant is normalized if the msb of x.mant == 1 (i.e.,
// the msb is shifted all the way "to the left"). Thus, if the mantissa has
// trailing 0 bits or x.prec is not a multiple of the Word size _W,
// x.mant[0] has trailing zero bits. The msb of the mantissa corresponds
// to the value 0.5; the exponent x.exp shifts the binary point as needed.
//
// A zero or non-finite Float x ignores x.mant and x.exp.
//
//	x                 form      neg      mant         exp
//	----------------------------------------------------------
//	±0                zero      sign     -            -
//	0 < |x| < +Inf    finite    sign     mantissa     exponent
//	±Inf              inf       sign     -            -

// A form value describes the internal representation.
type form byte

// The form value order is relevant - do not change!
// NOTE(review): presumably code elsewhere relies on zero < finite < inf
// when comparing forms; confirm before adding or reordering values.
const (
	zero   form = iota // ±0
	finite             // 0 < |x| < +Inf
	inf                // ±Inf
)

// RoundingMode determines how a [Float] value is rounded to the
// desired precision. Rounding may change the [Float] value; the
// rounding error is described by the [Float]'s [Accuracy].
type RoundingMode byte

// These constants define supported rounding modes.
// The zero value, ToNearestEven, is the mode of an uninitialized Float.
const (
	ToNearestEven RoundingMode = iota // == IEEE 754-2008 roundTiesToEven
	ToNearestAway                     // == IEEE 754-2008 roundTiesToAway
	ToZero                            // == IEEE 754-2008 roundTowardZero
	AwayFromZero                      // no IEEE 754-2008 equivalent
	ToNegativeInf                     // == IEEE 754-2008 roundTowardNegative
	ToPositiveInf                     // == IEEE 754-2008 roundTowardPositive
)

//go:generate stringer -type=RoundingMode

// Accuracy describes the rounding error produced by the most recent
// operation that generated a [Float] value, relative to the exact value.
type Accuracy int8 // Constants describing the [Accuracy] of a [Float]. const ( Below Accuracy = -1 Exact Accuracy = 0 Above Accuracy = +1 ) //go:generate stringer -type=Accuracy // SetPrec sets z's precision to prec and returns the (possibly) rounded // value of z. Rounding occurs according to z's rounding mode if the mantissa // cannot be represented in prec bits without loss of precision. // SetPrec(0) maps all finite values to ±0; infinite values remain unchanged. // If prec > [MaxPrec], it is set to [MaxPrec]. func (z *Float) SetPrec(prec uint) *Float { z.acc = Exact // optimistically assume no rounding is needed // special case if prec == 0 { z.prec = 0 if z.form == finite { // truncate z to 0 z.acc = makeAcc(z.neg) z.form = zero } return z } // general case if prec > MaxPrec { prec = MaxPrec } old := z.prec z.prec = uint32(prec) if z.prec < old { z.round(0) } return z } func makeAcc(above bool) Accuracy { if above { return Above } return Below } // SetMode sets z's rounding mode to mode and returns an exact z. // z remains unchanged otherwise. // z.SetMode(z.Mode()) is a cheap way to set z's accuracy to [Exact]. func (z *Float) SetMode(mode RoundingMode) *Float { z.mode = mode z.acc = Exact return z } // Prec returns the mantissa precision of x in bits. // The result may be 0 for |x| == 0 and |x| == Inf. func (x *Float) Prec() uint { return uint(x.prec) } // MinPrec returns the minimum precision required to represent x exactly // (i.e., the smallest prec before x.SetPrec(prec) would start rounding x). // The result is 0 for |x| == 0 and |x| == Inf. func (x *Float) MinPrec() uint { if x.form != finite { return 0 } return uint(len(x.mant))*_W - x.mant.trailingZeroBits() } // Mode returns the rounding mode of x. func (x *Float) Mode() RoundingMode { return x.mode } // Acc returns the accuracy of x produced by the most recent // operation, unless explicitly documented otherwise by that // operation. 
func (x *Float) Acc() Accuracy {
	return x.acc
}

// Sign returns:
//   - -1 if x < 0;
//   - 0 if x is ±0;
//   - +1 if x > 0.
func (x *Float) Sign() int {
	if debugFloat {
		x.validate()
	}
	switch {
	case x.form == zero:
		return 0
	case x.neg:
		return -1
	}
	return 1
}

// MantExp breaks x into its mantissa and exponent components
// and returns the exponent. If a non-nil mant argument is
// provided its value is set to the mantissa of x, with the
// same precision and rounding mode as x. The components
// satisfy x == mant × 2**exp, with 0.5 <= |mant| < 1.0.
// Calling MantExp with a nil argument is an efficient way to
// get the exponent of the receiver.
//
// Special cases are:
//
//	(  ±0).MantExp(mant) = 0, with mant set to   ±0
//	(±Inf).MantExp(mant) = 0, with mant set to ±Inf
//
// x and mant may be the same in which case x is set to its
// mantissa value.
func (x *Float) MantExp(mant *Float) (exp int) {
	if debugFloat {
		x.validate()
	}
	// Read the exponent before mant (which may be x) is modified.
	if x.form == finite {
		exp = int(x.exp)
	}
	if mant == nil {
		return exp
	}
	mant.Copy(x)
	if mant.form == finite {
		mant.exp = 0
	}
	return exp
}

// setExpAndRound sets z's exponent to exp and rounds z to its precision,
// using sbit as the incoming sticky bit. An exp below MinExp underflows z
// to ±0 and one above MaxExp overflows it to ±Inf; in both cases z.acc
// records the direction of the error.
func (z *Float) setExpAndRound(exp int64, sbit uint) {
	switch {
	case exp < MinExp:
		// underflow
		z.acc = makeAcc(z.neg)
		z.form = zero
	case exp > MaxExp:
		// overflow
		z.acc = makeAcc(!z.neg)
		z.form = inf
	default:
		z.form = finite
		z.exp = int32(exp)
		z.round(sbit)
	}
}

// SetMantExp sets z to mant × 2**exp and returns z.
// The result z has the same precision and rounding mode
// as mant. SetMantExp is an inverse of [Float.MantExp] but does
// not require 0.5 <= |mant| < 1.0. Specifically, for a
// given x of type *[Float], SetMantExp relates to [Float.MantExp]
// as follows:
//
//	mant := new(Float)
//	new(Float).SetMantExp(mant, x.MantExp(mant)).Cmp(x) == 0
//
// Special cases are:
//
//	z.SetMantExp(  ±0, exp) =   ±0
//	z.SetMantExp(±Inf, exp) = ±Inf
//
// z and mant may be the same in which case z's exponent
// is set to exp.
func (z *Float) SetMantExp(mant *Float, exp int) *Float { if debugFloat { z.validate() mant.validate() } z.Copy(mant) if z.form == finite { // 0 < |mant| < +Inf z.setExpAndRound(int64(z.exp)+int64(exp), 0) } return z } // Signbit reports whether x is negative or negative zero. func (x *Float) Signbit() bool { return x.neg } // IsInf reports whether x is +Inf or -Inf. func (x *Float) IsInf() bool { return x.form == inf } // IsInt reports whether x is an integer. // ±Inf values are not integers. func (x *Float) IsInt() bool { if debugFloat { x.validate() } // special cases if x.form != finite { return x.form == zero } // x.form == finite if x.exp <= 0 { return false } // x.exp > 0 return x.prec <= uint32(x.exp) || x.MinPrec() <= uint(x.exp) // not enough bits for fractional mantissa } // debugging support func (x *Float) validate() { if !debugFloat { // avoid performance bugs panic("validate called but debugFloat is not set") } if msg := x.validate0(); msg != "" { panic(msg) } } func (x *Float) validate0() string { if x.form != finite { return "" } m := len(x.mant) if m == 0 { return "nonzero finite number with empty mantissa" } const msb = 1 << (_W - 1) if x.mant[m-1]&msb == 0 { return fmt.Sprintf("msb not set in last word %#x of %s", x.mant[m-1], x.Text('p', 0)) } if x.prec == 0 { return "zero precision finite number" } return "" } // round rounds z according to z.mode to z.prec bits and sets z.acc accordingly. // sbit must be 0 or 1 and summarizes any "sticky bit" information one might // have before calling round. z's mantissa must be normalized (with the msb set) // or empty. // // CAUTION: The rounding modes [ToNegativeInf], [ToPositiveInf] are affected by the // sign of z. For correct rounding, the sign of z must be set correctly before // calling round. 
func (z *Float) round(sbit uint) {
	if debugFloat {
		z.validate()
	}

	z.acc = Exact
	if z.form != finite {
		// ±0 or ±Inf => nothing left to do
		return
	}
	// z.form == finite && len(z.mant) > 0
	// m > 0 implies z.prec > 0 (checked by validate)

	m := uint32(len(z.mant)) // present mantissa length in words
	bits := m * _W           // present mantissa bits; bits > 0
	if bits <= z.prec {
		// mantissa fits => nothing to do
		return
	}
	// bits > z.prec

	// Rounding is based on two bits: the rounding bit (rbit) and the
	// sticky bit (sbit). The rbit is the bit immediately below (less
	// significant than) the z.prec leading mantissa bits (the "0.5").
	// The sbit is set if any bit below the rbit is set (the "0.25",
	// "0.125", etc.):
	//
	//   rbit  sbit  => "fractional part"
	//
	//   0     0        == 0
	//   0     1        >  0  , < 0.5
	//   1     0        == 0.5
	//   1     1        >  0.5, < 1.0

	// bits > z.prec: mantissa too large => round
	r := uint(bits - z.prec - 1) // rounding bit position; r >= 0
	rbit := z.mant.bit(r) & 1    // rounding bit; be safe and ensure it's a single bit
	// The sticky bit is only needed for rounding ToNearestEven
	// or when the rounding bit is zero. Avoid computation otherwise.
	// (With rbit == 1 the result is inexact regardless, and only
	// ToNearestEven needs sbit to break the tie.)
	if sbit == 0 && (rbit == 0 || z.mode == ToNearestEven) {
		sbit = z.mant.sticky(r)
	}
	sbit &= 1 // be safe and ensure it's a single bit

	// cut off extra words
	n := (z.prec + (_W - 1)) / _W // mantissa length in words for desired precision
	if m > n {
		copy(z.mant, z.mant[m-n:]) // move n last words to front
		z.mant = z.mant[:n]
	}

	// determine number of trailing zero bits (ntz) and compute lsb mask of mantissa's least-significant word
	ntz := n*_W - z.prec // 0 <= ntz < _W
	lsb := Word(1) << ntz

	// round if result is inexact
	if rbit|sbit != 0 {
		// Make rounding decision: The result mantissa is truncated ("rounded down")
		// by default. Decide if we need to increment, or "round up", the (unsigned)
		// mantissa.
		inc := false
		switch z.mode {
		case ToNegativeInf:
			inc = z.neg
		case ToZero:
			// nothing to do
		case ToNearestEven:
			// Round up above the halfway point. On an exact tie (sbit == 0),
			// round up only if the kept lsb (z.mant[0]&lsb) is odd.
			inc = rbit != 0 && (sbit != 0 || z.mant[0]&lsb != 0)
		case ToNearestAway:
			inc = rbit != 0
		case AwayFromZero:
			inc = true
		case ToPositiveInf:
			inc = !z.neg
		default:
			panic("unreachable")
		}

		// A positive result (!z.neg) is Above the exact result if we increment,
		// and it's Below if we truncate (Exact results require no rounding).
		// For a negative result (z.neg) it is exactly the opposite.
		z.acc = makeAcc(inc != z.neg)

		if inc {
			// add 1 to mantissa
			if addVW(z.mant, z.mant, lsb) != 0 {
				// mantissa overflow => adjust exponent
				if z.exp >= MaxExp {
					// exponent overflow
					z.form = inf
					return
				}
				z.exp++
				// adjust mantissa: divide by 2 to compensate for exponent adjustment
				rshVU(z.mant, z.mant, 1)
				// set msb == carry == 1 from the mantissa overflow above
				const msb = 1 << (_W - 1)
				z.mant[n-1] |= msb
			}
		}
	}

	// zero out trailing bits in least-significant word
	z.mant[0] &^= lsb - 1

	if debugFloat {
		z.validate()
	}
}

// setBits64 sets z to x (negated if neg) and returns z. If z's precision
// is 0, it is set to 64. The mantissa is stored left-aligned (msb set), and
// rounding is only performed when z.prec < 64.
func (z *Float) setBits64(neg bool, x uint64) *Float {
	if z.prec == 0 {
		z.prec = 64
	}
	z.acc = Exact
	z.neg = neg
	if x == 0 {
		z.form = zero
		return z
	}
	// x != 0
	z.form = finite
	s := bits.LeadingZeros64(x)
	z.mant = z.mant.setUint64(x << uint(s)) // normalize: shift msb into the top bit
	z.exp = int32(64 - s)                   // always fits
	if z.prec < 64 {
		z.round(0)
	}
	return z
}

// SetUint64 sets z to the (possibly rounded) value of x and returns z.
// If z's precision is 0, it is changed to 64 (and rounding will have
// no effect).
func (z *Float) SetUint64(x uint64) *Float {
	return z.setBits64(false, x)
}

// SetInt64 sets z to the (possibly rounded) value of x and returns z.
// If z's precision is 0, it is changed to 64 (and rounding will have
// no effect).
func (z *Float) SetInt64(x int64) *Float {
	u := x
	if u < 0 {
		// For x == math.MinInt64, -u wraps around to math.MinInt64 again,
		// but uint64(u) is then 1<<63, which is still the correct magnitude.
		u = -u
	}
	// We cannot simply call z.SetUint64(uint64(u)) and change
	// the sign afterwards because the sign affects rounding.
	return z.setBits64(x < 0, uint64(u))
}

// SetFloat64 sets z to the (possibly rounded) value of x and returns z.
// If z's precision is 0, it is changed to 53 (and rounding will have
// no effect). SetFloat64 panics with [ErrNaN] if x is a NaN.
func (z *Float) SetFloat64(x float64) *Float {
	if z.prec == 0 {
		z.prec = 53
	}
	if math.IsNaN(x) {
		panic(ErrNaN{"Float.SetFloat64(NaN)"})
	}
	z.acc = Exact
	z.neg = math.Signbit(x) // handle -0, -Inf correctly
	if x == 0 {
		z.form = zero
		return z
	}
	if math.IsInf(x, 0) {
		z.form = inf
		return z
	}
	// normalized x != 0
	z.form = finite
	fmant, exp := math.Frexp(x) // get normalized mantissa
	// Shift the 52 explicit mantissa bits (and the exponent field bits of
	// fmant) up by 11 so they sit below the top bit, then force the top
	// bit to 1 (the implicit leading bit).
	z.mant = z.mant.setUint64(1<<63 | math.Float64bits(fmant)<<11)
	z.exp = int32(exp) // always fits
	if z.prec < 53 {
		z.round(0)
	}
	return z
}

// fnorm normalizes mantissa m by shifting it to the left
// such that the msb of the most-significant word (msw) is 1.
// It returns the shift amount. It assumes that len(m) != 0.
func fnorm(m nat) int64 {
	if debugFloat && (len(m) == 0 || m[len(m)-1] == 0) {
		panic("msw of mantissa is 0")
	}
	s := nlz(m[len(m)-1])
	if s > 0 {
		c := lshVU(m, m, s)
		if debugFloat && c != 0 {
			panic("nlz or lshVU incorrect")
		}
	}
	return int64(s)
}

// SetInt sets z to the (possibly rounded) value of x and returns z.
// If z's precision is 0, it is changed to the larger of x.BitLen()
// or 64 (and rounding will have no effect).
func (z *Float) SetInt(x *Int) *Float {
	// TODO(gri) can be more efficient if z.prec > 0
	// but small compared to the size of x, or if there
	// are many trailing 0's.
	bits := uint32(x.BitLen())
	if z.prec == 0 {
		z.prec = max(bits, 64)
	}
	z.acc = Exact
	z.neg = x.neg
	if len(x.abs) == 0 {
		z.form = zero
		return z
	}
	// x != 0
	z.mant = z.mant.set(x.abs)
	fnorm(z.mant)
	// The binary point sits right of the integer's bits, so the exponent
	// (for a mantissa in [0.5, 1)) equals the bit length.
	z.setExpAndRound(int64(bits), 0)
	return z
}

// SetRat sets z to the (possibly rounded) value of x and returns z.
// If z's precision is 0, it is changed to the largest of a.BitLen(),
// b.BitLen(), or 64; with x = a/b.
func (z *Float) SetRat(x *Rat) *Float { if x.IsInt() { return z.SetInt(x.Num()) } var a, b Float a.SetInt(x.Num()) b.SetInt(x.Denom()) if z.prec == 0 { z.prec = max(a.prec, b.prec) } return z.Quo(&a, &b) } // SetInf sets z to the infinite Float -Inf if signbit is // set, or +Inf if signbit is not set, and returns z. The // precision of z is unchanged and the result is always // [Exact]. func (z *Float) SetInf(signbit bool) *Float { z.acc = Exact z.form = inf z.neg = signbit return z } // Set sets z to the (possibly rounded) value of x and returns z. // If z's precision is 0, it is changed to the precision of x // before setting z (and rounding will have no effect). // Rounding is performed according to z's precision and rounding // mode; and z's accuracy reports the result error relative to the // exact (not rounded) result. func (z *Float) Set(x *Float) *Float { if debugFloat { x.validate() } z.acc = Exact if z != x { z.form = x.form z.neg = x.neg if x.form == finite { z.exp = x.exp z.mant = z.mant.set(x.mant) } if z.prec == 0 { z.prec = x.prec } else if z.prec < x.prec { z.round(0) } } return z } // Copy sets z to x, with the same precision, rounding mode, and accuracy as x. // Copy returns z. If x and z are identical, Copy is a no-op. func (z *Float) Copy(x *Float) *Float { if debugFloat { x.validate() } if z != x { z.prec = x.prec z.mode = x.mode z.acc = x.acc z.form = x.form z.neg = x.neg if z.form == finite { z.mant = z.mant.set(x.mant) z.exp = x.exp } } return z } // msb32 returns the 32 most significant bits of x. func msb32(x nat) uint32 { i := len(x) - 1 if i < 0 { return 0 } if debugFloat && x[i]&(1<<(_W-1)) == 0 { panic("x not normalized") } switch _W { case 32: return uint32(x[i]) case 64: return uint32(x[i] >> 32) } panic("unreachable") } // msb64 returns the 64 most significant bits of x. 
func msb64(x nat) uint64 { i := len(x) - 1 if i < 0 { return 0 } if debugFloat && x[i]&(1<<(_W-1)) == 0 { panic("x not normalized") } switch _W { case 32: v := uint64(x[i]) << 32 if i > 0 { v |= uint64(x[i-1]) } return v case 64: return uint64(x[i]) } panic("unreachable") } // Uint64 returns the unsigned integer resulting from truncating x // towards zero. If 0 <= x <= [math.MaxUint64], the result is [Exact] // if x is an integer and [Below] otherwise. // The result is (0, [Above]) for x < 0, and ([math.MaxUint64], [Below]) // for x > [math.MaxUint64]. func (x *Float) Uint64() (uint64, Accuracy) { if debugFloat { x.validate() } switch x.form { case finite: if x.neg { return 0, Above } // 0 < x < +Inf if x.exp <= 0 { // 0 < x < 1 return 0, Below } // 1 <= x < Inf if x.exp <= 64 { // u = trunc(x) fits into a uint64 u := msb64(x.mant) >> (64 - uint32(x.exp)) if x.MinPrec() <= 64 { return u, Exact } return u, Below // x truncated } // x too large return math.MaxUint64, Below case zero: return 0, Exact case inf: if x.neg { return 0, Above } return math.MaxUint64, Below } panic("unreachable") } // Int64 returns the integer resulting from truncating x towards zero. // If [math.MinInt64] <= x <= [math.MaxInt64], the result is [Exact] if x is // an integer, and [Above] (x < 0) or [Below] (x > 0) otherwise. // The result is ([math.MinInt64], [Above]) for x < [math.MinInt64], // and ([math.MaxInt64], [Below]) for x > [math.MaxInt64]. 
func (x *Float) Int64() (int64, Accuracy) {
	if debugFloat {
		x.validate()
	}

	switch x.form {
	case finite:
		// 0 < |x| < +Inf
		// acc is the accuracy of a result truncated toward zero:
		// Above for negative x, Below for positive x.
		acc := makeAcc(x.neg)
		if x.exp <= 0 {
			// 0 < |x| < 1
			return 0, acc
		}
		// x.exp > 0

		// 1 <= |x| < +Inf
		if x.exp <= 63 {
			// i = trunc(x) fits into an int64 (excluding math.MinInt64)
			i := int64(msb64(x.mant) >> (64 - uint32(x.exp)))
			if x.neg {
				i = -i
			}
			if x.MinPrec() <= uint(x.exp) {
				return i, Exact
			}
			return i, acc // x truncated
		}
		if x.neg {
			// check for special case x == math.MinInt64 (i.e., x == -(0.5 << 64))
			if x.exp == 64 && x.MinPrec() == 1 {
				acc = Exact
			}
			return math.MinInt64, acc
		}
		// x too large
		return math.MaxInt64, Below

	case zero:
		return 0, Exact

	case inf:
		if x.neg {
			return math.MinInt64, Above
		}
		return math.MaxInt64, Below
	}

	panic("unreachable")
}

// Float32 returns the float32 value nearest to x. If x is too small to be
// represented by a float32 (|x| < [math.SmallestNonzeroFloat32]), the result
// is (0, [Below]) or (-0, [Above]), respectively, depending on the sign of x.
// If x is too large to be represented by a float32 (|x| > [math.MaxFloat32]),
// the result is (+Inf, [Above]) or (-Inf, [Below]), depending on the sign of x.
func (x *Float) Float32() (float32, Accuracy) {
	if debugFloat {
		x.validate()
	}

	switch x.form {
	case finite:
		// 0 < |x| < +Inf

		const (
			fbits = 32                //        float size
			mbits = 23                //        mantissa size (excluding implicit msb)
			ebits = fbits - mbits - 1 //     8  exponent size
			bias  = 1<<(ebits-1) - 1  //   127  exponent bias
			dmin  = 1 - bias - mbits  //  -149  smallest unbiased exponent (denormal)
			emin  = 1 - bias          //  -126  smallest unbiased exponent (normal)
			emax  = bias              //   127  largest unbiased exponent (normal)
		)

		// Float mantissa m is 0.5 <= m < 1.0; compute exponent e for float32 mantissa.
		e := x.exp - 1 // exponent for normal mantissa m with 1.0 <= m < 2.0

		// Compute precision p for float32 mantissa.
		// If the exponent is too small, we have a denormal number before
		// rounding and fewer than p mantissa bits of precision available
		// (the exponent remains fixed but the mantissa gets shifted right).
		p := mbits + 1 // precision of normal float
		if e < emin {
			// recompute precision
			p = mbits + 1 - emin + int(e)
			// If p == 0, the mantissa of x is shifted so much to the right
			// that its msb falls immediately to the right of the float32
			// mantissa space. In other words, if the smallest denormal is
			// considered "1.0", for p == 0, the mantissa value m is >= 0.5.
			// If m > 0.5, it is rounded up to 1.0; i.e., the smallest denormal.
			// If m == 0.5, it is rounded down to even, i.e., 0.0.
			// If p < 0, the mantissa value m is <= "0.25" which is never rounded up.
			if p < 0 /* m <= 0.25 */ || p == 0 && x.mant.sticky(uint(len(x.mant))*_W-1) == 0 /* m == 0.5 */ {
				// underflow to ±0
				if x.neg {
					var z float32
					return -z, Above
				}
				return 0.0, Below
			}
			// otherwise, round up
			// We handle p == 0 explicitly because it's easy and because
			// Float.round doesn't support rounding to 0 bits of precision.
			if p == 0 {
				if x.neg {
					return -math.SmallestNonzeroFloat32, Below
				}
				return math.SmallestNonzeroFloat32, Above
			}
		}
		// p > 0

		// round
		var r Float
		r.prec = uint32(p)
		r.Set(x)
		e = r.exp - 1

		// Rounding may have caused r to overflow to ±Inf
		// (rounding never causes underflows to 0).
		// If the exponent is too large, also overflow to ±Inf.
		if r.form == inf || e > emax {
			// overflow
			if x.neg {
				return float32(math.Inf(-1)), Below
			}
			return float32(math.Inf(+1)), Above
		}
		// e <= emax

		// Determine sign, biased exponent, and mantissa.
		var sign, bexp, mant uint32
		if x.neg {
			sign = 1 << (fbits - 1)
		}

		// Rounding may have caused a denormal number to
		// become normal. Check again.
		if e < emin {
			// denormal number: recompute precision
			// Since rounding may have at best increased precision
			// and we have eliminated p <= 0 early, we know p > 0.
			// bexp == 0 for denormals
			p = mbits + 1 - emin + int(e)
			mant = msb32(r.mant) >> uint(fbits-p)
		} else {
			// normal number: emin <= e <= emax
			bexp = uint32(e+bias) << mbits
			mant = msb32(r.mant) >> ebits & (1<<mbits - 1) // cut off msb (implicit 1 bit)
		}

		return math.Float32frombits(sign | bexp | mant), r.acc

	case zero:
		if x.neg {
			var z float32
			return -z, Exact
		}
		return 0.0, Exact

	case inf:
		if x.neg {
			return float32(math.Inf(-1)), Exact
		}
		return float32(math.Inf(+1)), Exact
	}

	panic("unreachable")
}

// Float64 returns the float64 value nearest to x. If x is too small to be
// represented by a float64 (|x| < [math.SmallestNonzeroFloat64]), the result
// is (0, [Below]) or (-0, [Above]), respectively, depending on the sign of x.
// If x is too large to be represented by a float64 (|x| > [math.MaxFloat64]),
// the result is (+Inf, [Above]) or (-Inf, [Below]), depending on the sign of x.
func (x *Float) Float64() (float64, Accuracy) {
	if debugFloat {
		x.validate()
	}

	switch x.form {
	case finite:
		// 0 < |x| < +Inf

		const (
			fbits = 64                //        float size
			mbits = 52                //        mantissa size (excluding implicit msb)
			ebits = fbits - mbits - 1 //    11  exponent size
			bias  = 1<<(ebits-1) - 1  //  1023  exponent bias
			dmin  = 1 - bias - mbits  // -1074  smallest unbiased exponent (denormal)
			emin  = 1 - bias          // -1022  smallest unbiased exponent (normal)
			emax  = bias              //  1023  largest unbiased exponent (normal)
		)

		// Float mantissa m is 0.5 <= m < 1.0; compute exponent e for float64 mantissa.
		e := x.exp - 1 // exponent for normal mantissa m with 1.0 <= m < 2.0

		// Compute precision p for float64 mantissa.
		// If the exponent is too small, we have a denormal number before
		// rounding and fewer than p mantissa bits of precision available
		// (the exponent remains fixed but the mantissa gets shifted right).
		p := mbits + 1 // precision of normal float
		if e < emin {
			// recompute precision
			p = mbits + 1 - emin + int(e)
			// If p == 0, the mantissa of x is shifted so much to the right
			// that its msb falls immediately to the right of the float64
			// mantissa space. In other words, if the smallest denormal is
			// considered "1.0", for p == 0, the mantissa value m is >= 0.5.
			// If m > 0.5, it is rounded up to 1.0; i.e., the smallest denormal.
			// If m == 0.5, it is rounded down to even, i.e., 0.0.
			// If p < 0, the mantissa value m is <= "0.25" which is never rounded up.
			if p < 0 /* m <= 0.25 */ || p == 0 && x.mant.sticky(uint(len(x.mant))*_W-1) == 0 /* m == 0.5 */ {
				// underflow to ±0
				if x.neg {
					var z float64
					return -z, Above
				}
				return 0.0, Below
			}
			// otherwise, round up
			// We handle p == 0 explicitly because it's easy and because
			// Float.round doesn't support rounding to 0 bits of precision.
			if p == 0 {
				if x.neg {
					return -math.SmallestNonzeroFloat64, Below
				}
				return math.SmallestNonzeroFloat64, Above
			}
		}
		// p > 0

		// round
		var r Float
		r.prec = uint32(p)
		r.Set(x)
		e = r.exp - 1

		// Rounding may have caused r to overflow to ±Inf
		// (rounding never causes underflows to 0).
		// If the exponent is too large, also overflow to ±Inf.
		if r.form == inf || e > emax {
			// overflow
			if x.neg {
				return math.Inf(-1), Below
			}
			return math.Inf(+1), Above
		}
		// e <= emax

		// Determine sign, biased exponent, and mantissa.
		var sign, bexp, mant uint64
		if x.neg {
			sign = 1 << (fbits - 1)
		}

		// Rounding may have caused a denormal number to
		// become normal. Check again.
		if e < emin {
			// denormal number: recompute precision
			// Since rounding may have at best increased precision
			// and we have eliminated p <= 0 early, we know p > 0.
			// bexp == 0 for denormals
			p = mbits + 1 - emin + int(e)
			mant = msb64(r.mant) >> uint(fbits-p)
		} else {
			// normal number: emin <= e <= emax
			bexp = uint64(e+bias) << mbits
			mant = msb64(r.mant) >> ebits & (1<<mbits - 1) // cut off msb (implicit 1 bit)
		}

		return math.Float64frombits(sign | bexp | mant), r.acc

	case zero:
		if x.neg {
			var z float64
			return -z, Exact
		}
		return 0.0, Exact

	case inf:
		if x.neg {
			return math.Inf(-1), Exact
		}
		return math.Inf(+1), Exact
	}

	panic("unreachable")
}

// Int returns the result of truncating x towards zero;
// or nil if x is an infinity.
// The result is [Exact] if x.IsInt(); otherwise it is [Below]
// for x > 0, and [Above] for x < 0.
// If a non-nil *[Int] argument z is provided, [Int] stores
// the result in z instead of allocating a new [Int].
func (x *Float) Int(z *Int) (*Int, Accuracy) {
	if debugFloat {
		x.validate()
	}

	// Allocate z unless x is ±Inf (which returns nil below).
	// NOTE(review): relies on the form constants ordering zero and finite
	// before inf — confirm against the form declaration.
	if z == nil && x.form <= finite {
		z = new(Int)
	}

	switch x.form {
	case finite:
		// 0 < |x| < +Inf
		acc := makeAcc(x.neg)
		if x.exp <= 0 {
			// 0 < |x| < 1
			return z.SetInt64(0), acc
		}
		// x.exp > 0

		// 1 <= |x| < +Inf
		// determine minimum required precision for x
		allBits := uint(len(x.mant)) * _W
		exp := uint(x.exp)
		if x.MinPrec() <= exp {
			acc = Exact
		}
		// shift mantissa as needed
		if z == nil {
			z = new(Int)
		}
		z.neg = x.neg
		// The mantissa represents an integer of allBits bits; shift it so
		// that exactly exp bits remain left of the binary point.
		switch {
		case exp > allBits:
			z.abs = z.abs.lsh(x.mant, exp-allBits)
		default:
			z.abs = z.abs.set(x.mant)
		case exp < allBits:
			z.abs = z.abs.rsh(x.mant, allBits-exp)
		}
		return z, acc

	case zero:
		return z.SetInt64(0), Exact

	case inf:
		return nil, makeAcc(x.neg)
	}

	panic("unreachable")
}

// Rat returns the rational number corresponding to x;
// or nil if x is an infinity.
// The result is [Exact] if x is not an Inf.
// If a non-nil *[Rat] argument z is provided, [Rat] stores
// the result in z instead of allocating a new [Rat].
func (x *Float) Rat(z *Rat) (*Rat, Accuracy) {
	if debugFloat {
		x.validate()
	}

	// Allocate z unless x is ±Inf (which returns nil below).
	if z == nil && x.form <= finite {
		z = new(Rat)
	}

	switch x.form {
	case finite:
		// 0 < |x| < +Inf
		allBits := int32(len(x.mant)) * _W
		// build up numerator and denominator
		z.a.neg = x.neg
		switch {
		case x.exp > allBits:
			z.a.abs = z.a.abs.lsh(x.mant, uint(x.exp-allBits))
			z.b.abs = z.b.abs[:0] // == 1 (see Rat)
			// z already in normal form
		default:
			z.a.abs = z.a.abs.set(x.mant)
			z.b.abs = z.b.abs[:0] // == 1 (see Rat)
			// z already in normal form
		case x.exp < allBits:
			// x = mant / 2**(allBits-x.exp); the fraction may be reducible.
			z.a.abs = z.a.abs.set(x.mant)
			t := z.b.abs.setUint64(1)
			z.b.abs = t.lsh(t, uint(allBits-x.exp))
			z.norm()
		}
		return z, Exact

	case zero:
		return z.SetInt64(0), Exact

	case inf:
		return nil, makeAcc(x.neg)
	}

	panic("unreachable")
}

// Abs sets z to the (possibly rounded) value |x| (the absolute value of x)
// and returns z.
func (z *Float) Abs(x *Float) *Float {
	z.Set(x)
	z.neg = false
	return z
}

// Neg sets z to the (possibly rounded) value of x with its sign negated,
// and returns z.
func (z *Float) Neg(x *Float) *Float {
	z.Set(x)
	z.neg = !z.neg
	return z
}

// validateBinaryOperands panics unless debugFloat is set and both x and y
// have a non-empty mantissa.
func validateBinaryOperands(x, y *Float) {
	if !debugFloat {
		// avoid performance bugs
		panic("validateBinaryOperands called but debugFloat is not set")
	}
	if len(x.mant) == 0 {
		panic("empty mantissa for x")
	}
	if len(y.mant) == 0 {
		panic("empty mantissa for y")
	}
}

// z = x + y, ignoring signs of x and y for the addition
// but using the sign of z for rounding the result.
// x and y must have a non-empty mantissa and valid exponent.
func (z *Float) uadd(x, y *Float) {
	// Note: This implementation requires 2 shifts most of the
	// time. It is also inefficient if exponents or precisions
	// differ by wide margins. The following article describes
	// an efficient (but much more complicated) implementation
	// compatible with the internal representation used here:
	//
	// Vincent Lefèvre: "The Generic Multiple-Precision Floating-
	// Point Addition With Exact Rounding (as in the MPFR Library)"
	// http://www.vinc17.net/research/papers/rnc6.pdf

	if debugFloat {
		validateBinaryOperands(x, y)
	}

	// compute exponents ex, ey for mantissa with "binary point"
	// on the right (mantissa.0) - use int64 to avoid overflow
	ex := int64(x.exp) - int64(len(x.mant))*_W
	ey := int64(y.exp) - int64(len(y.mant))*_W

	// If z.mant shares storage with an operand, shifting into z.mant
	// would destroy that operand; use a fresh temporary instead.
	al := alias(z.mant, x.mant) || alias(z.mant, y.mant)

	// TODO(gri) having a combined add-and-shift primitive
	//           could make this code significantly faster
	switch {
	case ex < ey:
		if al {
			t := nat(nil).lsh(y.mant, uint(ey-ex))
			z.mant = z.mant.add(x.mant, t)
		} else {
			z.mant = z.mant.lsh(y.mant, uint(ey-ex))
			z.mant = z.mant.add(x.mant, z.mant)
		}
	default:
		// ex == ey, no shift needed
		z.mant = z.mant.add(x.mant, y.mant)
	case ex > ey:
		if al {
			t := nat(nil).lsh(x.mant, uint(ex-ey))
			z.mant = z.mant.add(t, y.mant)
		} else {
			z.mant = z.mant.lsh(x.mant, uint(ex-ey))
			z.mant = z.mant.add(z.mant, y.mant)
		}
		ex = ey
	}
	// len(z.mant) > 0

	z.setExpAndRound(ex+int64(len(z.mant))*_W-fnorm(z.mant), 0)
}

// z = x - y for |x| > |y|, ignoring signs of x and y for the subtraction
// but using the sign of z for rounding the result.
// x and y must have a non-empty mantissa and valid exponent.
func (z *Float) usub(x, y *Float) {
	// This code is symmetric to uadd.
	// We have not factored the common code out because
	// eventually uadd (and usub) should be optimized
	// by special-casing, and the code will diverge.

	if debugFloat {
		validateBinaryOperands(x, y)
	}

	ex := int64(x.exp) - int64(len(x.mant))*_W
	ey := int64(y.exp) - int64(len(y.mant))*_W

	al := alias(z.mant, x.mant) || alias(z.mant, y.mant)

	switch {
	case ex < ey:
		if al {
			t := nat(nil).lsh(y.mant, uint(ey-ex))
			z.mant = t.sub(x.mant, t)
		} else {
			z.mant = z.mant.lsh(y.mant, uint(ey-ex))
			z.mant = z.mant.sub(x.mant, z.mant)
		}
	default:
		// ex == ey, no shift needed
		z.mant = z.mant.sub(x.mant, y.mant)
	case ex > ey:
		if al {
			t := nat(nil).lsh(x.mant, uint(ex-ey))
			z.mant = t.sub(t, y.mant)
		} else {
			z.mant = z.mant.lsh(x.mant, uint(ex-ey))
			z.mant = z.mant.sub(z.mant, y.mant)
		}
		ex = ey
	}

	// operands may have canceled each other out
	if len(z.mant) == 0 {
		z.acc = Exact
		z.form = zero
		z.neg = false
		return
	}
	// len(z.mant) > 0

	z.setExpAndRound(ex+int64(len(z.mant))*_W-fnorm(z.mant), 0)
}

// z = x * y, ignoring signs of x and y for the multiplication
// but using the sign of z for rounding the result.
// x and y must have a non-empty mantissa and valid exponent.
func (z *Float) umul(x, y *Float) {
	if debugFloat {
		validateBinaryOperands(x, y)
	}

	// Note: This is doing too much work if the precision
	// of z is less than the sum of the precisions of x
	// and y which is often the case (e.g., if all floats
	// have the same precision).
	// TODO(gri) Optimize this for the common case.

	e := int64(x.exp) + int64(y.exp)
	if x == y {
		// squaring is cheaper than a general multiplication
		z.mant = z.mant.sqr(nil, x.mant)
	} else {
		z.mant = z.mant.mul(nil, x.mant, y.mant)
	}
	z.setExpAndRound(e-fnorm(z.mant), 0)
}

// z = x / y, ignoring signs of x and y for the division
// but using the sign of z for rounding the result.
// x and y must have a non-empty mantissa and valid exponent.
func (z *Float) uquo(x, y *Float) {
	if debugFloat {
		validateBinaryOperands(x, y)
	}

	// mantissa length in words for desired result precision + 1
	// (at least one extra bit so we get the rounding bit after
	// the division)
	n := int(z.prec/_W) + 1

	// compute adjusted x.mant such that we get enough result precision
	xadj := x.mant
	if d := n - len(x.mant) + len(y.mant); d > 0 {
		// d extra words needed => add d "0 digits" to x
		xadj = make(nat, len(x.mant)+d)
		copy(xadj[d:], x.mant)
	}
	// TODO(gri): If we have too many digits (d < 0), we should be able
	// to shorten x for faster division. But we must be extra careful
	// with rounding in that case.

	// Compute d before division since there may be aliasing of x.mant
	// (via xadj) or y.mant with z.mant.
	d := len(xadj) - len(y.mant)

	// divide
	stk := getStack()
	defer stk.free()
	var r nat
	z.mant, r = z.mant.div(stk, nil, xadj, y.mant)
	// Exponent of the quotient: difference of the operand exponents,
	// corrected for the d low zero words appended to x and for the
	// actual quotient length.
	e := int64(x.exp) - int64(y.exp) - int64(d-len(z.mant))*_W

	// The result is long enough to include (at least) the rounding bit.
	// If there's a non-zero remainder, the corresponding fractional part
	// (if it were computed), would have a non-zero sticky bit (if it were
	// zero, it couldn't have a non-zero remainder).
	var sbit uint
	if len(r) > 0 {
		sbit = 1
	}

	z.setExpAndRound(e-fnorm(z.mant), sbit)
}

// ucmp returns -1, 0, or +1, depending on whether
// |x| < |y|, |x| == |y|, or |x| > |y|.
// x and y must have a non-empty mantissa and valid exponent.
func (x *Float) ucmp(y *Float) int { if debugFloat { validateBinaryOperands(x, y) } switch { case x.exp < y.exp: return -1 case x.exp > y.exp: return +1 } // x.exp == y.exp // compare mantissas i := len(x.mant) j := len(y.mant) for i > 0 || j > 0 { var xm, ym Word if i > 0 { i-- xm = x.mant[i] } if j > 0 { j-- ym = y.mant[j] } switch { case xm < ym: return -1 case xm > ym: return +1 } } return 0 } // Handling of sign bit as defined by IEEE 754-2008, section 6.3: // // When neither the inputs nor result are NaN, the sign of a product or // quotient is the exclusive OR of the operands’ signs; the sign of a sum, // or of a difference x−y regarded as a sum x+(−y), differs from at most // one of the addends’ signs; and the sign of the result of conversions, // the quantize operation, the roundToIntegral operations, and the // roundToIntegralExact (see 5.3.1) is the sign of the first or only operand. // These rules shall apply even when operands or results are zero or infinite. // // When the sum of two operands with opposite signs (or the difference of // two operands with like signs) is exactly zero, the sign of that sum (or // difference) shall be +0 in all rounding-direction attributes except // roundTowardNegative; under that attribute, the sign of an exact zero // sum (or difference) shall be −0. However, x+x = x−(−x) retains the same // sign as x even when x is zero. // // See also: https://play.golang.org/p/RtH3UCt5IH // Add sets z to the rounded sum x+y and returns z. If z's precision is 0, // it is changed to the larger of x's or y's precision before the operation. // Rounding is performed according to z's precision and rounding mode; and // z's accuracy reports the result error relative to the exact (not rounded) // result. Add panics with [ErrNaN] if x and y are infinities with opposite // signs. The value of z is undefined in that case. 
func (z *Float) Add(x, y *Float) *Float {
	if debugFloat {
		x.validate()
		y.validate()
	}

	if z.prec == 0 {
		z.prec = max(x.prec, y.prec)
	}

	if x.form == finite && y.form == finite {
		// x + y (common case)

		// Below we set z.neg = x.neg, and when z aliases y this will
		// change the y operand's sign. This is fine, because if an
		// operand aliases the receiver it'll be overwritten, but we still
		// want the original x.neg and y.neg values when we evaluate
		// x.neg != y.neg, so we need to save y.neg before setting z.neg.
		yneg := y.neg

		// The sign must be set before uadd/usub, since rounding depends on it.
		z.neg = x.neg
		if x.neg == yneg {
			// x + y == x + y
			// (-x) + (-y) == -(x + y)
			z.uadd(x, y)
		} else {
			// x + (-y) == x - y == -(y - x)
			// (-x) + y == y - x == -(x - y)
			if x.ucmp(y) > 0 {
				z.usub(x, y)
			} else {
				z.neg = !z.neg
				z.usub(y, x)
			}
		}
		// An exact zero sum is -0 under ToNegativeInf (IEEE 754-2008, 6.3).
		if z.form == zero && z.mode == ToNegativeInf && z.acc == Exact {
			z.neg = true
		}
		return z
	}

	if x.form == inf && y.form == inf && x.neg != y.neg {
		// +Inf + -Inf
		// -Inf + +Inf
		// value of z is undefined but make sure it's valid
		z.acc = Exact
		z.form = zero
		z.neg = false
		panic(ErrNaN{"addition of infinities with opposite signs"})
	}

	if x.form == zero && y.form == zero {
		// ±0 + ±0
		z.acc = Exact
		z.form = zero
		z.neg = x.neg && y.neg // -0 + -0 == -0
		return z
	}

	if x.form == inf || y.form == zero {
		// ±Inf + y
		// x + ±0
		return z.Set(x)
	}

	// ±0 + y
	// x + ±Inf
	return z.Set(y)
}

// Sub sets z to the rounded difference x-y and returns z.
// Precision, rounding, and accuracy reporting are as for [Float.Add].
// Sub panics with [ErrNaN] if x and y are infinities with equal
// signs. The value of z is undefined in that case.
func (z *Float) Sub(x, y *Float) *Float {
	if debugFloat {
		x.validate()
		y.validate()
	}

	if z.prec == 0 {
		z.prec = max(x.prec, y.prec)
	}

	if x.form == finite && y.form == finite {
		// x - y (common case)
		// y.neg is saved before z.neg is written because z may alias y
		// (same reasoning as in Add).
		yneg := y.neg
		z.neg = x.neg
		if x.neg != yneg {
			// x - (-y) == x + y
			// (-x) - y == -(x + y)
			z.uadd(x, y)
		} else {
			// x - y == x - y == -(y - x)
			// (-x) - (-y) == y - x == -(x - y)
			if x.ucmp(y) > 0 {
				z.usub(x, y)
			} else {
				z.neg = !z.neg
				z.usub(y, x)
			}
		}
		// An exact zero difference is -0 under ToNegativeInf (IEEE 754-2008, 6.3).
		if z.form == zero && z.mode == ToNegativeInf && z.acc == Exact {
			z.neg = true
		}
		return z
	}

	if x.form == inf && y.form == inf && x.neg == y.neg {
		// +Inf - +Inf
		// -Inf - -Inf
		// value of z is undefined but make sure it's valid
		z.acc = Exact
		z.form = zero
		z.neg = false
		panic(ErrNaN{"subtraction of infinities with equal signs"})
	}

	if x.form == zero && y.form == zero {
		// ±0 - ±0
		z.acc = Exact
		z.form = zero
		z.neg = x.neg && !y.neg // -0 - +0 == -0
		return z
	}

	if x.form == inf || y.form == zero {
		// ±Inf - y
		// x - ±0
		return z.Set(x)
	}

	// ±0 - y
	// x - ±Inf
	return z.Neg(y)
}

// Mul sets z to the rounded product x*y and returns z.
// Precision, rounding, and accuracy reporting are as for [Float.Add].
// Mul panics with [ErrNaN] if one operand is zero and the other
// operand an infinity. The value of z is undefined in that case.
func (z *Float) Mul(x, y *Float) *Float {
	if debugFloat {
		x.validate()
		y.validate()
	}

	if z.prec == 0 {
		z.prec = max(x.prec, y.prec)
	}

	// The sign of a product is the XOR of the operand signs; it must be
	// set before umul because rounding depends on it.
	z.neg = x.neg != y.neg

	if x.form == finite && y.form == finite {
		// x * y (common case)
		z.umul(x, y)
		return z
	}

	z.acc = Exact
	if x.form == zero && y.form == inf || x.form == inf && y.form == zero {
		// ±0 * ±Inf
		// ±Inf * ±0
		// value of z is undefined but make sure it's valid
		z.form = zero
		z.neg = false
		panic(ErrNaN{"multiplication of zero with infinity"})
	}

	if x.form == inf || y.form == inf {
		// ±Inf * y
		// x * ±Inf
		z.form = inf
		return z
	}

	// ±0 * y
	// x * ±0
	z.form = zero
	return z
}

// Quo sets z to the rounded quotient x/y and returns z.
// Precision, rounding, and accuracy reporting are as for [Float.Add]. // Quo panics with [ErrNaN] if both operands are zero or infinities. // The value of z is undefined in that case. func (z *Float) Quo(x, y *Float) *Float { if debugFloat { x.validate() y.validate() } if z.prec == 0 { z.prec = max(x.prec, y.prec) } z.neg = x.neg != y.neg if x.form == finite && y.form == finite { // x / y (common case) z.uquo(x, y) return z } z.acc = Exact if x.form == zero && y.form == zero || x.form == inf && y.form == inf { // ±0 / ±0 // ±Inf / ±Inf // value of z is undefined but make sure it's valid z.form = zero z.neg = false panic(ErrNaN{"division of zero by zero or infinity by infinity"}) } if x.form == zero || y.form == inf { // ±0 / y // x / ±Inf z.form = zero return z } // x / ±0 // ±Inf / y z.form = inf return z } // Cmp compares x and y and returns: // - -1 if x < y; // - 0 if x == y (incl. -0 == 0, -Inf == -Inf, and +Inf == +Inf); // - +1 if x > y. func (x *Float) Cmp(y *Float) int { if debugFloat { x.validate() y.validate() } mx := x.ord() my := y.ord() switch { case mx < my: return -1 case mx > my: return +1 } // mx == my // only if |mx| == 1 we have to compare the mantissae switch mx { case -1: return y.ucmp(x) case +1: return x.ucmp(y) } return 0 } // ord classifies x and returns: // // -2 if -Inf == x // -1 if -Inf < x < 0 // 0 if x == 0 (signed or unsigned) // +1 if 0 < x < +Inf // +2 if x == +Inf func (x *Float) ord() int { var m int switch x.form { case finite: m = 1 case zero: return 0 case inf: m = 2 } if x.neg { m = -m } return m }
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// This file implements string-to-Float conversion functions.

package big

import (
	"fmt"
	"io"
	"strings"
)

// floatZero is a package-level zero-valued Float.
var floatZero Float

// SetString sets z to the value of s and returns z and a boolean indicating
// success. s must be a floating-point number of the same format as accepted
// by [Float.Parse], with base argument 0. The entire string (not just a prefix) must
// be valid for success. If the operation failed, the value of z is undefined
// but the returned value is nil.
func (z *Float) SetString(s string) (*Float, bool) {
	if f, _, err := z.Parse(s, 0); err == nil {
		return f, true
	}
	return nil, false
}

// scan is like Parse but reads the longest possible prefix representing a valid
// floating point number from an io.ByteScanner rather than a string. It serves
// as the implementation of Parse. It does not recognize ±Inf and does not expect
// EOF at the end.
//
// Results: f is set to z only once a value has been fully scanned (the
// zero special case, or after the binary exponent is found to be in
// range); on every error return f is nil. b is the mantissa base
// reported by the mantissa scanner.
func (z *Float) scan(r io.ByteScanner, base int) (f *Float, b int, err error) {
	// The result precision is z's precision, or 64 bits if z has none.
	prec := z.prec
	if prec == 0 {
		prec = 64
	}

	// A reasonable value in case of an error.
	z.form = zero

	// sign
	z.neg, err = scanSign(r)
	if err != nil {
		return
	}

	// mantissa
	var fcount int // fractional digit count; valid if <= 0
	z.mant, b, fcount, err = z.mant.scan(r, base, true)
	if err != nil {
		return
	}

	// exponent
	var exp int64
	var ebase int
	// Exponent separators accepted depend on base == 0 (see the Parse docs
	// for "p"/"P" vs. "e"/"E").
	exp, ebase, err = scanExponent(r, true, base == 0)
	if err != nil {
		return
	}

	// special-case 0
	if len(z.mant) == 0 {
		z.prec = prec
		z.acc = Exact
		z.form = zero
		f = z
		return
	}
	// len(z.mant) > 0

	// The mantissa may have a radix point (fcount <= 0) and there
	// may be a nonzero exponent exp. The radix point amounts to a
	// division by b**(-fcount). An exponent means multiplication by
	// ebase**exp. Finally, mantissa normalization (shift left) requires
	// a correcting multiplication by 2**(-shiftcount). Multiplications
	// are commutative, so we can apply them in any order as long as there
	// is no loss of precision. We only have powers of 2 and 10, and
	// we split powers of 10 into the product of the same powers of
	// 2 and 5. This reduces the size of the multiplication factor
	// needed for base-10 exponents.

	// normalize mantissa and determine initial exponent contributions
	// (exp2 starts as the mantissa's bit length after subtracting the
	// normalization shift count returned by fnorm.)
	exp2 := int64(len(z.mant))*_W - fnorm(z.mant)
	exp5 := int64(0)

	// determine binary or decimal exponent contribution of radix point
	if fcount < 0 {
		// The mantissa has a radix point ddd.dddd; and
		// -fcount is the number of digits to the right
		// of '.'. Adjust relevant exponent accordingly.
		d := int64(fcount)
		switch b {
		case 10:
			exp5 = d
			fallthrough // 10**e == 5**e * 2**e
		case 2:
			exp2 += d
		case 8:
			exp2 += d * 3 // octal digits are 3 bits each
		case 16:
			exp2 += d * 4 // hexadecimal digits are 4 bits each
		default:
			panic("unexpected mantissa base")
		}
		// fcount consumed - not needed anymore
	}

	// take actual exponent into account
	switch ebase {
	case 10:
		exp5 += exp
		fallthrough // see fallthrough above
	case 2:
		exp2 += exp
	default:
		panic("unexpected exponent base")
	}
	// exp consumed - not needed anymore

	// apply 2**exp2
	// Only the binary exponent is range-checked here; the 5**exp5
	// factor is applied afterwards via Mul/Quo.
	if MinExp <= exp2 && exp2 <= MaxExp {
		z.prec = prec
		z.form = finite
		z.exp = int32(exp2)
		f = z
	} else {
		err = fmt.Errorf("exponent overflow")
		return
	}

	if exp5 == 0 {
		// no decimal exponent contribution
		z.round(0)
		return
	}
	// exp5 != 0

	// apply 5**exp5
	p := new(Float).SetPrec(z.Prec() + 64) // use more bits for p -- TODO(gri) what is the right number?
	if exp5 < 0 {
		z.Quo(z, p.pow5(uint64(-exp5)))
	} else {
		z.Mul(z, p.pow5(uint64(exp5)))
	}

	return
}

// These powers of 5 fit into a uint64.
// // for p, q := uint64(0), uint64(1); p < q; p, q = q, q*5 { // fmt.Println(q) // } var pow5tab = [...]uint64{ 1, 5, 25, 125, 625, 3125, 15625, 78125, 390625, 1953125, 9765625, 48828125, 244140625, 1220703125, 6103515625, 30517578125, 152587890625, 762939453125, 3814697265625, 19073486328125, 95367431640625, 476837158203125, 2384185791015625, 11920928955078125, 59604644775390625, 298023223876953125, 1490116119384765625, 7450580596923828125, } // pow5 sets z to 5**n and returns z. // n must not be negative. func (z *Float) pow5(n uint64) *Float { const m = uint64(len(pow5tab) - 1) if n <= m { return z.SetUint64(pow5tab[n]) } // n > m z.SetUint64(pow5tab[m]) n -= m // use more bits for f than for z // TODO(gri) what is the right number? f := new(Float).SetPrec(z.Prec() + 64).SetUint64(5) for n > 0 { if n&1 != 0 { z.Mul(z, f) } f.Mul(f, f) n >>= 1 } return z } // Parse parses s which must contain a text representation of a floating- // point number with a mantissa in the given conversion base (the exponent // is always a decimal number), or a string representing an infinite value. // // For base 0, an underscore character “_” may appear between a base // prefix and an adjacent digit, and between successive digits; such // underscores do not change the value of the number, or the returned // digit count. Incorrect placement of underscores is reported as an // error if there are no other errors. If base != 0, underscores are // not recognized and thus terminate scanning like any other character // that is not a valid radix point or digit. // // It sets z to the (possibly rounded) value of the corresponding floating- // point value, and returns z, the actual base b, and an error err, if any. // The entire string (not just a prefix) must be consumed for success. // If z's precision is 0, it is changed to 64 before rounding takes effect. // The number must be of the form: // // number = [ sign ] ( float | "inf" | "Inf" ) . // sign = "+" | "-" . 
// float = ( mantissa | prefix pmantissa ) [ exponent ] . // prefix = "0" [ "b" | "B" | "o" | "O" | "x" | "X" ] . // mantissa = digits "." [ digits ] | digits | "." digits . // pmantissa = [ "_" ] digits "." [ digits ] | [ "_" ] digits | "." digits . // exponent = ( "e" | "E" | "p" | "P" ) [ sign ] digits . // digits = digit { [ "_" ] digit } . // digit = "0" ... "9" | "a" ... "z" | "A" ... "Z" . // // The base argument must be 0, 2, 8, 10, or 16. Providing an invalid base // argument will lead to a run-time panic. // // For base 0, the number prefix determines the actual base: A prefix of // “0b” or “0B” selects base 2, “0o” or “0O” selects base 8, and // “0x” or “0X” selects base 16. Otherwise, the actual base is 10 and // no prefix is accepted. The octal prefix "0" is not supported (a leading // "0" is simply considered a "0"). // // A "p" or "P" exponent indicates a base 2 (rather than base 10) exponent; // for instance, "0x1.fffffffffffffp1023" (using base 0) represents the // maximum float64 value. For hexadecimal mantissae, the exponent character // must be one of 'p' or 'P', if present (an "e" or "E" exponent indicator // cannot be distinguished from a mantissa digit). // // The returned *Float f is nil and the value of z is valid but not // defined if an error is reported. func (z *Float) Parse(s string, base int) (f *Float, b int, err error) { // scan doesn't handle ±Inf if len(s) == 3 && (s == "Inf" || s == "inf") { f = z.SetInf(false) return } if len(s) == 4 && (s[0] == '+' || s[0] == '-') && (s[1:] == "Inf" || s[1:] == "inf") { f = z.SetInf(s[0] == '-') return } r := strings.NewReader(s) if f, b, err = z.scan(r, base); err != nil { return } // entire string must have been consumed if ch, err2 := r.ReadByte(); err2 == nil { err = fmt.Errorf("expected end of string, found %q", ch) } else if err2 != io.EOF { err = err2 } return } // ParseFloat is like f.Parse(s, base) with f set to the given precision // and rounding mode. 
func ParseFloat(s string, base int, prec uint, mode RoundingMode) (f *Float, b int, err error) { return new(Float).SetPrec(prec).SetMode(mode).Parse(s, base) } var _ fmt.Scanner = (*Float)(nil) // *Float must implement fmt.Scanner // Scan is a support routine for [fmt.Scanner]; it sets z to the value of // the scanned number. It accepts formats whose verbs are supported by // [fmt.Scan] for floating point values, which are: // 'b' (binary), 'e', 'E', 'f', 'F', 'g' and 'G'. // Scan doesn't handle ±Inf. func (z *Float) Scan(s fmt.ScanState, ch rune) error { s.SkipSpace() _, _, err := z.scan(byteReader{s}, 0) return err }
// Copyright 2015 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// This file implements encoding/decoding of Floats.

package big

import (
	"errors"
	"fmt"
	"internal/byteorder"
)

// Gob codec version. Permits backward-compatible changes to the encoding.
const floatGobVersion byte = 1

// GobEncode implements the [encoding/gob.GobEncoder] interface.
// The [Float] value and all its attributes (precision,
// rounding mode, accuracy) are marshaled.
//
// Encoding layout (multi-byte integer fields are big-endian):
//
//	buf[0]     floatGobVersion
//	buf[1]     mode (3 bits) | acc+1 (2 bits) | form (2 bits) | neg (1 bit)
//	buf[2:6]   prec
//	buf[6:10]  exp                                (finite values only)
//	buf[10:]   top n mantissa words, via nat.bytes (finite values only)
//
// A nil x encodes as a nil slice.
func (x *Float) GobEncode() ([]byte, error) {
	if x == nil {
		return nil, nil
	}

	// determine max. space (bytes) required for encoding
	sz := 1 + 1 + 4 // version + mode|acc|form|neg (3+2+2+1bit) + prec
	n := 0          // number of mantissa words
	if x.form == finite {
		// add space for mantissa and exponent
		n = int((x.prec + (_W - 1)) / _W) // required mantissa length in words for given precision
		// actual mantissa slice could be shorter (trailing 0's) or longer (unused bits):
		// - if shorter, only encode the words present
		// - if longer, cut off unused words when encoding in bytes
		//   (in practice, this should never happen since rounding
		//   takes care of it, but be safe and do it always)
		if len(x.mant) < n {
			n = len(x.mant)
		}
		// len(x.mant) >= n
		sz += 4 + n*_S // exp + mant
	}
	buf := make([]byte, sz)

	buf[0] = floatGobVersion
	// Accuracy ranges over -1..+1, so it is stored biased by +1 to fit
	// into 2 unsigned bits; GobDecode subtracts the bias again.
	b := byte(x.mode&7)<<5 | byte((x.acc+1)&3)<<3 | byte(x.form&3)<<1
	if x.neg {
		b |= 1
	}
	buf[1] = b
	byteorder.BEPutUint32(buf[2:], x.prec)

	if x.form == finite {
		byteorder.BEPutUint32(buf[6:], uint32(x.exp))
		x.mant[len(x.mant)-n:].bytes(buf[10:]) // cut off unused trailing words
	}

	return buf, nil
}

// GobDecode implements the [encoding/gob.GobDecoder] interface.
// The result is rounded per the precision and rounding mode of
// z unless z's precision is 0, in which case z is set exactly
// to the decoded value.
//
// An empty buf resets z to the zero Float. On error, z may already
// have been partially overwritten.
func (z *Float) GobDecode(buf []byte) error {
	if len(buf) == 0 {
		// Other side sent a nil or default value.
		*z = Float{}
		return nil
	}
	if len(buf) < 6 {
		return errors.New("Float.GobDecode: buffer too small")
	}

	if buf[0] != floatGobVersion {
		return fmt.Errorf("Float.GobDecode: encoding version %d not supported", buf[0])
	}

	// Remember z's own precision and mode: if z already has a precision,
	// it overrides the encoded one (see below).
	oldPrec := z.prec
	oldMode := z.mode

	// Unpack the attribute byte (see GobEncode for the layout).
	b := buf[1]
	z.mode = RoundingMode((b >> 5) & 7)
	z.acc = Accuracy((b>>3)&3) - 1
	z.form = form((b >> 1) & 3)
	z.neg = b&1 != 0
	z.prec = byteorder.BEUint32(buf[2:])

	if z.form == finite {
		if len(buf) < 10 {
			return errors.New("Float.GobDecode: buffer too small for finite form float")
		}
		z.exp = int32(byteorder.BEUint32(buf[6:]))
		z.mant = z.mant.setBytes(buf[10:])
	}

	if oldPrec != 0 {
		// Round the decoded value to z's original precision and mode.
		z.mode = oldMode
		z.SetPrec(uint(oldPrec))
	}

	// Reject encodings that decode to an internally inconsistent Float.
	if msg := z.validate0(); msg != "" {
		return errors.New("Float.GobDecode: " + msg)
	}

	return nil
}

// AppendText implements the [encoding.TextAppender] interface.
// Only the [Float] value is marshaled (in full precision), other
// attributes such as precision or accuracy are ignored.
// A nil x is appended as "<nil>".
func (x *Float) AppendText(b []byte) ([]byte, error) {
	if x == nil {
		return append(b, "<nil>"...), nil
	}
	// 'g' with precision -1 selects the shortest representation.
	return x.Append(b, 'g', -1), nil
}

// MarshalText implements the [encoding.TextMarshaler] interface.
// Only the [Float] value is marshaled (in full precision), other
// attributes such as precision or accuracy are ignored.
func (x *Float) MarshalText() (text []byte, err error) {
	return x.AppendText(nil)
}

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
// The result is rounded per the precision and rounding mode of z.
// If z's precision is 0, it is changed to 64 before rounding takes
// effect.
func (z *Float) UnmarshalText(text []byte) error {
	// TODO(gri): get rid of the []byte/string conversion
	_, _, err := z.Parse(string(text), 0)
	if err != nil {
		err = fmt.Errorf("math/big: cannot unmarshal %q into a *big.Float (%v)", text, err)
	}
	return err
}
// Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This file implements Float-to-string conversion functions. // It is closely following the corresponding implementation // in strconv/ftoa.go, but modified and simplified for Float. package big import ( "bytes" "fmt" "strconv" ) // Text converts the floating-point number x to a string according // to the given format and precision prec. The format is one of: // // 'e' -d.dddde±dd, decimal exponent, at least two (possibly 0) exponent digits // 'E' -d.ddddE±dd, decimal exponent, at least two (possibly 0) exponent digits // 'f' -ddddd.dddd, no exponent // 'g' like 'e' for large exponents, like 'f' otherwise // 'G' like 'E' for large exponents, like 'f' otherwise // 'x' -0xd.dddddp±dd, hexadecimal mantissa, decimal power of two exponent // 'p' -0x.dddp±dd, hexadecimal mantissa, decimal power of two exponent (non-standard) // 'b' -ddddddp±dd, decimal mantissa, decimal power of two exponent (non-standard) // // For the power-of-two exponent formats, the mantissa is printed in normalized form: // // 'x' hexadecimal mantissa in [1, 2), or 0 // 'p' hexadecimal mantissa in [½, 1), or 0 // 'b' decimal integer mantissa using x.Prec() bits, or 0 // // Note that the 'x' form is the one used by most other languages and libraries. // // If format is a different character, Text returns a "%" followed by the // unrecognized format character. // // The precision prec controls the number of digits (excluding the exponent) // printed by the 'e', 'E', 'f', 'g', 'G', and 'x' formats. // For 'e', 'E', 'f', and 'x', it is the number of digits after the decimal point. // For 'g' and 'G' it is the total number of digits. A negative precision selects // the smallest number of decimal digits necessary to identify the value x uniquely // using x.Prec() mantissa bits. // The prec value is ignored for the 'b' and 'p' formats. 
func (x *Float) Text(format byte, prec int) string {
	// Initial buffer capacity: a small fixed overhead plus the requested
	// digit count. (The local name cap shadows the builtin here.)
	cap := 10 // TODO(gri) determine a good/better value here
	if prec > 0 {
		cap += prec
	}
	return string(x.Append(make([]byte, 0, cap), format, prec))
}

// String formats x like x.Text('g', 10).
// (String must be called explicitly, [Float.Format] does not support %s verb.)
func (x *Float) String() string {
	return x.Text('g', 10)
}

// Append appends to buf the string form of the floating-point number x,
// as generated by x.Text, and returns the extended buffer.
//
// Note: inside Append the parameter fmt (the format character) shadows
// the fmt package.
func (x *Float) Append(buf []byte, fmt byte, prec int) []byte {
	// sign
	if x.neg {
		buf = append(buf, '-')
	}

	// Inf
	if x.form == inf {
		if !x.neg {
			buf = append(buf, '+')
		}
		return append(buf, "Inf"...)
	}

	// pick off easy formats
	switch fmt {
	case 'b':
		return x.fmtB(buf)
	case 'p':
		return x.fmtP(buf)
	case 'x':
		return x.fmtX(buf, prec)
	}

	// Algorithm:
	//   1) convert Float to multiprecision decimal
	//   2) round to desired precision
	//   3) read digits out and format

	// 1) convert Float to multiprecision decimal
	var d decimal // == 0.0
	if x.form == finite {
		// x != 0
		// The mantissa is passed together with the exponent of its least
		// significant bit (x.exp minus the mantissa's bit length).
		d.init(x.mant, int(x.exp)-x.mant.bitLen())
	}

	// 2) round to desired precision
	shortest := false
	if prec < 0 {
		shortest = true
		roundShortest(&d, x)
		// Precision for shortest representation mode.
		switch fmt {
		case 'e', 'E':
			prec = len(d.mant) - 1
		case 'f':
			prec = max(len(d.mant)-d.exp, 0)
		case 'g', 'G':
			prec = len(d.mant)
		}
	} else {
		// round appropriately
		switch fmt {
		case 'e', 'E':
			// one digit before and number of digits after decimal point
			d.round(1 + prec)
		case 'f':
			// number of digits before and after decimal point
			d.round(d.exp + prec)
		case 'g', 'G':
			if prec == 0 {
				prec = 1
			}
			d.round(prec)
		}
	}

	// 3) read digits out and format
	switch fmt {
	case 'e', 'E':
		return fmtE(buf, fmt, prec, d)
	case 'f':
		return fmtF(buf, prec, d)
	case 'g', 'G':
		// trim trailing fractional zeros in %e format
		eprec := prec
		if eprec > len(d.mant) && len(d.mant) >= d.exp {
			eprec = len(d.mant)
		}
		// %e is used if the exponent from the conversion
		// is less than -4 or greater than or equal to the precision.
		// If precision was the shortest possible, use eprec = 6 for
		// this decision.
		if shortest {
			eprec = 6
		}
		exp := d.exp - 1
		if exp < -4 || exp >= eprec {
			if prec > len(d.mant) {
				prec = len(d.mant)
			}
			// fmt+'e'-'g' maps 'g' to 'e' and 'G' to 'E'.
			return fmtE(buf, fmt+'e'-'g', prec-1, d)
		}
		if prec > d.exp {
			prec = len(d.mant)
		}
		return fmtF(buf, max(prec-d.exp, 0), d)
	}

	// unknown format
	if x.neg {
		buf = buf[:len(buf)-1] // sign was added prematurely - remove it again
	}
	return append(buf, '%', fmt)
}

// roundShortest rounds d (the decimal form of x) to the fewest digits
// that still uniquely identify x at x's precision.
func roundShortest(d *decimal, x *Float) {
	// if the mantissa is zero, the number is zero - stop now
	if len(d.mant) == 0 {
		return
	}

	// Approach: All numbers in the interval [x - 1/2ulp, x + 1/2ulp]
	// (possibly exclusive) round to x for the given precision of x.
	// Compute the lower and upper bound in decimal form and find the
	// shortest decimal number d such that lower <= d <= upper.

	// TODO(gri) strconv/ftoa.go describes a shortcut in some cases.
	// See if we can use it (in adjusted form) here as well.

	// 1) Compute normalized mantissa mant and exponent exp for x such
	// that the lsb of mant corresponds to 1/2 ulp for the precision of
	// x (i.e., for mant we want x.prec + 1 bits).
	mant := nat(nil).set(x.mant)
	exp := int(x.exp) - mant.bitLen()
	s := mant.bitLen() - int(x.prec+1)
	switch {
	case s < 0:
		mant = mant.lsh(mant, uint(-s))
	case s > 0:
		mant = mant.rsh(mant, uint(+s))
	}
	exp += s
	// x = mant * 2**exp with lsb(mant) == 1/2 ulp of x.prec

	// 2) Compute lower bound by subtracting 1/2 ulp.
	// (tmp is reused for the upper bound below; this assumes lower.init
	// does not retain tmp's storage - TODO confirm in decimal.init.)
	var lower decimal
	var tmp nat
	lower.init(tmp.sub(mant, natOne), exp)

	// 3) Compute upper bound by adding 1/2 ulp.
	var upper decimal
	upper.init(tmp.add(mant, natOne), exp)

	// The upper and lower bounds are possible outputs only if
	// the original mantissa is even, so that ToNearestEven rounding
	// would round to the original mantissa and not the neighbors.
	inclusive := mant[0]&2 == 0 // test bit 1 since original mantissa was shifted by 1

	// Now we can figure out the minimum number of digits required.
	// Walk along until d has distinguished itself from upper and lower.
	for i, m := range d.mant {
		l := lower.at(i)
		u := upper.at(i)

		// Okay to round down (truncate) if lower has a different digit
		// or if lower is inclusive and is exactly the result of rounding
		// down (i.e., and we have reached the final digit of lower).
		okdown := l != m || inclusive && i+1 == len(lower.mant)

		// Okay to round up if upper has a different digit and either upper
		// is inclusive or upper is bigger than the result of rounding up.
		okup := m != u && (inclusive || m+1 < u || i+1 < len(upper.mant))

		// If it's okay to do either, then round to the nearest one.
		// If it's okay to do only one, do it.
		switch {
		case okdown && okup:
			d.round(i + 1)
			return
		case okdown:
			d.roundDown(i + 1)
			return
		case okup:
			d.roundUp(i + 1)
			return
		}
	}
}

// %e: d.ddddde±dd
//
// fmtE appends d in exponent form with prec digits after the decimal
// point (missing digits are padded with '0') and at least two exponent
// digits. A zero d (no digits) is written as 0 with exponent +00.
func fmtE(buf []byte, fmt byte, prec int, d decimal) []byte {
	// first digit
	ch := byte('0')
	if len(d.mant) > 0 {
		ch = d.mant[0]
	}
	buf = append(buf, ch)

	// .moredigits
	if prec > 0 {
		buf = append(buf, '.')
		i := 1
		m := min(len(d.mant), prec+1)
		if i < m {
			buf = append(buf, d.mant[i:m]...)
			i = m
		}
		for ; i <= prec; i++ {
			buf = append(buf, '0')
		}
	}

	// e±
	buf = append(buf, fmt)
	var exp int64
	if len(d.mant) > 0 {
		exp = int64(d.exp) - 1 // -1 because first digit was printed before '.'
	}
	if exp < 0 {
		ch = '-'
		exp = -exp
	} else {
		ch = '+'
	}
	buf = append(buf, ch)

	// dd...d
	if exp < 10 {
		buf = append(buf, '0') // at least 2 exponent digits
	}
	return strconv.AppendInt(buf, exp, 10)
}

// %f: ddddddd.ddddd
//
// fmtF appends d in positional form with prec digits after the decimal
// point. Fraction digits are read with d.at, which is assumed to yield
// '0' past the stored digits - TODO confirm in decimal.at.
func fmtF(buf []byte, prec int, d decimal) []byte {
	// integer, padded with zeros as needed
	if d.exp > 0 {
		m := min(len(d.mant), d.exp)
		buf = append(buf, d.mant[:m]...)
		for ; m < d.exp; m++ {
			buf = append(buf, '0')
		}
	} else {
		buf = append(buf, '0')
	}

	// fraction
	if prec > 0 {
		buf = append(buf, '.')
		for i := 0; i < prec; i++ {
			buf = append(buf, d.at(d.exp+i))
		}
	}

	return buf
}

// fmtB appends the string of x in the format mantissa "p" exponent
// with a decimal mantissa and a binary exponent, or "0" if x is zero,
// and returns the extended buffer.
// The mantissa is normalized such that it uses x.Prec() bits in binary
// representation.
// The sign of x is ignored, and x must not be an Inf.
// (The caller handles Inf before invoking fmtB.)
func (x *Float) fmtB(buf []byte) []byte {
	if x.form == zero {
		return append(buf, '0')
	}

	if debugFloat && x.form != finite {
		panic("non-finite float")
	}
	// x != 0

	// adjust mantissa to use exactly x.prec bits
	m := x.mant
	switch w := uint32(len(x.mant)) * _W; {
	case w < x.prec:
		m = nat(nil).lsh(m, uint(x.prec-w))
	case w > x.prec:
		m = nat(nil).rsh(m, uint(w-x.prec))
	}

	buf = append(buf, m.utoa(10)...)
	buf = append(buf, 'p')
	// The exponent refers to the lsb of the x.prec-bit integer mantissa.
	e := int64(x.exp) - int64(x.prec)
	if e >= 0 {
		buf = append(buf, '+')
	}
	return strconv.AppendInt(buf, e, 10)
}

// fmtX appends the string of x in the format "0x1." mantissa "p" exponent
// with a hexadecimal mantissa and a binary exponent, or "0x0p+00" if x is zero
// (with a "." and prec zero digits before the "p" when prec > 0),
// and returns the extended buffer.
// A non-zero mantissa is normalized such that 1.0 <= mantissa < 2.0.
// The sign of x is ignored, and x must not be an Inf. // (The caller handles Inf before invoking fmtX.) func (x *Float) fmtX(buf []byte, prec int) []byte { if x.form == zero { buf = append(buf, "0x0"...) if prec > 0 { buf = append(buf, '.') for i := 0; i < prec; i++ { buf = append(buf, '0') } } buf = append(buf, "p+00"...) return buf } if debugFloat && x.form != finite { panic("non-finite float") } // round mantissa to n bits var n uint if prec < 0 { n = 1 + (x.MinPrec()-1+3)/4*4 // round MinPrec up to 1 mod 4 } else { n = 1 + 4*uint(prec) } // n%4 == 1 x = new(Float).SetPrec(n).SetMode(x.mode).Set(x) // adjust mantissa to use exactly n bits m := x.mant switch w := uint(len(x.mant)) * _W; { case w < n: m = nat(nil).lsh(m, n-w) case w > n: m = nat(nil).rsh(m, w-n) } exp64 := int64(x.exp) - 1 // avoid wrap-around hm := m.utoa(16) if debugFloat && hm[0] != '1' { panic("incorrect mantissa: " + string(hm)) } buf = append(buf, "0x1"...) if len(hm) > 1 { buf = append(buf, '.') buf = append(buf, hm[1:]...) } buf = append(buf, 'p') if exp64 >= 0 { buf = append(buf, '+') } else { exp64 = -exp64 buf = append(buf, '-') } // Force at least two exponent digits, to match fmt. if exp64 < 10 { buf = append(buf, '0') } return strconv.AppendInt(buf, exp64, 10) } // fmtP appends the string of x in the format "0x." mantissa "p" exponent // with a hexadecimal mantissa and a binary exponent, or "0" if x is zero, // and returns the extended buffer. // The mantissa is normalized such that 0.5 <= 0.mantissa < 1.0. // The sign of x is ignored, and x must not be an Inf. // (The caller handles Inf before invoking fmtP.) func (x *Float) fmtP(buf []byte) []byte { if x.form == zero { return append(buf, '0') } if debugFloat && x.form != finite { panic("non-finite float") } // x != 0 // remove trailing 0 words early // (no need to convert to hex 0's and trim later) m := x.mant i := 0 for i < len(m) && m[i] == 0 { i++ } m = m[i:] buf = append(buf, "0x."...) 
buf = append(buf, bytes.TrimRight(m.utoa(16), "0")...) buf = append(buf, 'p') if x.exp >= 0 { buf = append(buf, '+') } return strconv.AppendInt(buf, int64(x.exp), 10) } var _ fmt.Formatter = &floatZero // *Float must implement fmt.Formatter // Format implements [fmt.Formatter]. It accepts all the regular // formats for floating-point numbers ('b', 'e', 'E', 'f', 'F', // 'g', 'G', 'x') as well as 'p' and 'v'. See (*Float).Text for the // interpretation of 'p'. The 'v' format is handled like 'g'. // Format also supports specification of the minimum precision // in digits, the output field width, as well as the format flags // '+' and ' ' for sign control, '0' for space or zero padding, // and '-' for left or right justification. See the fmt package // for details. func (x *Float) Format(s fmt.State, format rune) { prec, hasPrec := s.Precision() if !hasPrec { prec = 6 // default precision for 'e', 'f' } switch format { case 'e', 'E', 'f', 'b', 'p', 'x': // nothing to do case 'F': // (*Float).Text doesn't support 'F'; handle like 'f' format = 'f' case 'v': // handle like 'g' format = 'g' fallthrough case 'g', 'G': if !hasPrec { prec = -1 // default precision for 'g', 'G' } default: fmt.Fprintf(s, "%%!%c(*big.Float=%s)", format, x.String()) return } var buf []byte buf = x.Append(buf, byte(format), prec) if len(buf) == 0 { buf = []byte("?") // should never happen, but don't crash } // len(buf) > 0 var sign string switch { case buf[0] == '-': sign = "-" buf = buf[1:] case buf[0] == '+': // +Inf sign = "+" if s.Flag(' ') { sign = " " } buf = buf[1:] case s.Flag('+'): sign = "+" case s.Flag(' '): sign = " " } var padding int if width, hasWidth := s.Width(); hasWidth && width > len(sign)+len(buf) { padding = width - len(sign) - len(buf) } switch { case s.Flag('0') && !x.IsInf(): // 0-padding on left writeMultiple(s, sign, 1) writeMultiple(s, "0", padding) s.Write(buf) case s.Flag('-'): // padding on right writeMultiple(s, sign, 1) s.Write(buf) writeMultiple(s, " ", padding) 
default: // padding on left writeMultiple(s, " ", padding) writeMultiple(s, sign, 1) s.Write(buf) } }
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// This file implements signed multi-precision integers.

package big

import (
	"fmt"
	"io"
	"math/rand"
	"strings"
)

// An Int represents a signed multi-precision integer.
// The zero value for an Int represents the value 0.
//
// Operations always take pointer arguments (*Int) rather
// than Int values, and each unique Int value requires
// its own unique *Int pointer. To "copy" an Int value,
// an existing (or newly allocated) Int must be set to
// a new value using the [Int.Set] method; shallow copies
// of Ints are not supported and may lead to errors.
//
// Note that methods may leak the Int's value through timing side-channels.
// Because of this and because of the scope and complexity of the
// implementation, Int is not well-suited to implement cryptographic operations.
// The standard library avoids exposing non-trivial Int methods to
// attacker-controlled inputs and the determination of whether a bug in math/big
// is considered a security vulnerability might depend on the impact on the
// standard library.
type Int struct {
	neg bool // sign
	abs nat  // absolute value of the integer
}

// intOne is a shared Int with value 1. Because it is shared, it must only
// ever be used as an operand, never as a receiver.
var intOne = &Int{false, natOne}

// Sign returns:
//   - -1 if x < 0;
//   - 0 if x == 0;
//   - +1 if x > 0.
func (x *Int) Sign() int {
	// This function is used in cryptographic operations. It must not leak
	// anything but the Int's sign and bit size through side-channels. Any
	// changes must be reviewed by a security expert.
	if len(x.abs) == 0 {
		return 0
	}
	if x.neg {
		return -1
	}
	return 1
}

// SetInt64 sets z to x and returns z.
func (z *Int) SetInt64(x int64) *Int {
	neg := false
	if x < 0 {
		neg = true
		// For math.MinInt64 this negation wraps back to MinInt64, but the
		// uint64 conversion below still yields the correct magnitude 1<<63.
		x = -x
	}
	z.abs = z.abs.setUint64(uint64(x))
	z.neg = neg
	return z
}

// SetUint64 sets z to x and returns z.
func (z *Int) SetUint64(x uint64) *Int {
	z.abs = z.abs.setUint64(x)
	z.neg = false
	return z
}

// NewInt allocates and returns a new [Int] set to x.
func NewInt(x int64) *Int {
	// This code is arranged to be inlineable and produce
	// zero allocations when inlined. See issue 29951.
	u := uint64(x)
	if x < 0 {
		// Unsigned negation of the bit pattern yields |x|, including
		// for math.MinInt64.
		u = -u
	}
	var abs []Word
	if x == 0 {
		// zero: abs stays nil
	} else if _W == 32 && u>>32 != 0 {
		// 32-bit words: |x| needs two words, least significant first.
		abs = []Word{Word(u), Word(u >> 32)}
	} else {
		abs = []Word{Word(u)}
	}
	return &Int{neg: x < 0, abs: abs}
}

// Set sets z to x and returns z.
func (z *Int) Set(x *Int) *Int {
	// Setting an Int to itself is a no-op.
	if z != x {
		z.abs = z.abs.set(x.abs)
		z.neg = x.neg
	}
	return z
}

// Bits provides raw (unchecked but fast) access to x by returning its
// absolute value as a little-endian [Word] slice. The result and x share
// the same underlying array.
// Bits is intended to support implementation of missing low-level [Int]
// functionality outside this package; it should be avoided otherwise.
func (x *Int) Bits() []Word {
	// This function is used in cryptographic operations. It must not leak
	// anything but the Int's sign and bit size through side-channels. Any
	// changes must be reviewed by a security expert.
	return x.abs
}

// SetBits provides raw (unchecked but fast) access to z by setting its
// value to abs, interpreted as a little-endian [Word] slice, and returning
// z. The result and abs share the same underlying array.
// SetBits is intended to support implementation of missing low-level [Int]
// functionality outside this package; it should be avoided otherwise.
func (z *Int) SetBits(abs []Word) *Int {
	// norm presumably trims high-order zero words so the representation
	// is canonical - TODO confirm in nat.norm.
	z.abs = nat(abs).norm()
	z.neg = false
	return z
}

// Abs sets z to |x| (the absolute value of x) and returns z.
func (z *Int) Abs(x *Int) *Int {
	z.Set(x)
	z.neg = false
	return z
}

// Neg sets z to -x and returns z.
func (z *Int) Neg(x *Int) *Int {
	z.Set(x)
	z.neg = len(z.abs) > 0 && !z.neg // 0 has no sign
	return z
}

// Add sets z to the sum x+y and returns z.
func (z *Int) Add(x, y *Int) *Int { neg := x.neg if x.neg == y.neg { // x + y == x + y // (-x) + (-y) == -(x + y) z.abs = z.abs.add(x.abs, y.abs) } else { // x + (-y) == x - y == -(y - x) // (-x) + y == y - x == -(x - y) if x.abs.cmp(y.abs) >= 0 { z.abs = z.abs.sub(x.abs, y.abs) } else { neg = !neg z.abs = z.abs.sub(y.abs, x.abs) } } z.neg = len(z.abs) > 0 && neg // 0 has no sign return z } // Sub sets z to the difference x-y and returns z. func (z *Int) Sub(x, y *Int) *Int { neg := x.neg if x.neg != y.neg { // x - (-y) == x + y // (-x) - y == -(x + y) z.abs = z.abs.add(x.abs, y.abs) } else { // x - y == x - y == -(y - x) // (-x) - (-y) == y - x == -(x - y) if x.abs.cmp(y.abs) >= 0 { z.abs = z.abs.sub(x.abs, y.abs) } else { neg = !neg z.abs = z.abs.sub(y.abs, x.abs) } } z.neg = len(z.abs) > 0 && neg // 0 has no sign return z } // Mul sets z to the product x*y and returns z. func (z *Int) Mul(x, y *Int) *Int { z.mul(nil, x, y) return z } // mul is like Mul but takes an explicit stack to use, for internal use. // It does not return a *Int because doing so makes the stack-allocated Ints // used in natmul.go escape to the heap (even though the result is unused). func (z *Int) mul(stk *stack, x, y *Int) { // x * y == x * y // x * (-y) == -(x * y) // (-x) * y == -(x * y) // (-x) * (-y) == x * y if x == y { z.abs = z.abs.sqr(stk, x.abs) z.neg = false return } z.abs = z.abs.mul(stk, x.abs, y.abs) z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign } // MulRange sets z to the product of all integers // in the range [a, b] inclusively and returns z. // If a > b (empty range), the result is 1. 
func (z *Int) MulRange(a, b int64) *Int { switch { case a > b: return z.SetInt64(1) // empty range case a <= 0 && b >= 0: return z.SetInt64(0) // range includes 0 } // a <= b && (b < 0 || a > 0) neg := false if a < 0 { neg = (b-a)&1 == 0 a, b = -b, -a } z.abs = z.abs.mulRange(nil, uint64(a), uint64(b)) z.neg = neg return z } // Binomial sets z to the binomial coefficient C(n, k) and returns z. func (z *Int) Binomial(n, k int64) *Int { if k > n { return z.SetInt64(0) } // reduce the number of multiplications by reducing k if k > n-k { k = n - k // C(n, k) == C(n, n-k) } // C(n, k) == n * (n-1) * ... * (n-k+1) / k * (k-1) * ... * 1 // == n * (n-1) * ... * (n-k+1) / 1 * (1+1) * ... * k // // Using the multiplicative formula produces smaller values // at each step, requiring fewer allocations and computations: // // z = 1 // for i := 0; i < k; i = i+1 { // z *= n-i // z /= i+1 // } // // finally to avoid computing i+1 twice per loop: // // z = 1 // i := 0 // for i < k { // z *= n-i // i++ // z /= i // } var N, K, i, t Int N.SetInt64(n) K.SetInt64(k) z.Set(intOne) for i.Cmp(&K) < 0 { z.Mul(z, t.Sub(&N, &i)) i.Add(&i, intOne) z.Quo(z, &i) } return z } // Quo sets z to the quotient x/y for y != 0 and returns z. // If y == 0, a division-by-zero run-time panic occurs. // Quo implements truncated division (like Go); see [Int.QuoRem] for more details. func (z *Int) Quo(x, y *Int) *Int { z.abs, _ = z.abs.div(nil, nil, x.abs, y.abs) z.neg = len(z.abs) > 0 && x.neg != y.neg // 0 has no sign return z } // Rem sets z to the remainder x%y for y != 0 and returns z. // If y == 0, a division-by-zero run-time panic occurs. // Rem implements truncated modulus (like Go); see [Int.QuoRem] for more details. func (z *Int) Rem(x, y *Int) *Int { _, z.abs = nat(nil).div(nil, z.abs, x.abs, y.abs) z.neg = len(z.abs) > 0 && x.neg // 0 has no sign return z } // QuoRem sets z to the quotient x/y and r to the remainder x%y // and returns the pair (z, r) for y != 0. 
// If y == 0, a division-by-zero run-time panic occurs. // // QuoRem implements T-division and modulus (like Go): // // q = x/y with the result truncated to zero // r = x - y*q // // (See Daan Leijen, “Division and Modulus for Computer Scientists”.) // See [Int.DivMod] for Euclidean division and modulus (unlike Go). func (z *Int) QuoRem(x, y, r *Int) (*Int, *Int) { z.abs, r.abs = z.abs.div(nil, r.abs, x.abs, y.abs) z.neg, r.neg = len(z.abs) > 0 && x.neg != y.neg, len(r.abs) > 0 && x.neg // 0 has no sign return z, r } // Div sets z to the quotient x/y for y != 0 and returns z. // If y == 0, a division-by-zero run-time panic occurs. // Div implements Euclidean division (unlike Go); see [Int.DivMod] for more details. func (z *Int) Div(x, y *Int) *Int { y_neg := y.neg // z may be an alias for y var r Int z.QuoRem(x, y, &r) if r.neg { if y_neg { z.Add(z, intOne) } else { z.Sub(z, intOne) } } return z } // Mod sets z to the modulus x%y for y != 0 and returns z. // If y == 0, a division-by-zero run-time panic occurs. // Mod implements Euclidean modulus (unlike Go); see [Int.DivMod] for more details. func (z *Int) Mod(x, y *Int) *Int { y0 := y // save y if z == y || alias(z.abs, y.abs) { y0 = new(Int).Set(y) } var q Int q.QuoRem(x, y, z) if z.neg { if y0.neg { z.Sub(z, y0) } else { z.Add(z, y0) } } return z } // DivMod sets z to the quotient x div y and m to the modulus x mod y // and returns the pair (z, m) for y != 0. // If y == 0, a division-by-zero run-time panic occurs. // // DivMod implements Euclidean division and modulus (unlike Go): // // q = x div y such that // m = x - y*q with 0 <= m < |y| // // (See Raymond T. Boute, “The Euclidean definition of the functions // div and mod”. ACM Transactions on Programming Languages and // Systems (TOPLAS), 14(2):127-144, New York, NY, USA, 4/1992. // ACM press.) // See [Int.QuoRem] for T-division and modulus (like Go). 
func (z *Int) DivMod(x, y, m *Int) (*Int, *Int) { y0 := y // save y if z == y || alias(z.abs, y.abs) { y0 = new(Int).Set(y) } z.QuoRem(x, y, m) if m.neg { if y0.neg { z.Add(z, intOne) m.Sub(m, y0) } else { z.Sub(z, intOne) m.Add(m, y0) } } return z, m } // Cmp compares x and y and returns: // - -1 if x < y; // - 0 if x == y; // - +1 if x > y. func (x *Int) Cmp(y *Int) (r int) { // x cmp y == x cmp y // x cmp (-y) == x // (-x) cmp y == y // (-x) cmp (-y) == -(x cmp y) switch { case x == y: // nothing to do case x.neg == y.neg: r = x.abs.cmp(y.abs) if x.neg { r = -r } case x.neg: r = -1 default: r = 1 } return } // CmpAbs compares the absolute values of x and y and returns: // - -1 if |x| < |y|; // - 0 if |x| == |y|; // - +1 if |x| > |y|. func (x *Int) CmpAbs(y *Int) int { return x.abs.cmp(y.abs) } // low32 returns the least significant 32 bits of x. func low32(x nat) uint32 { if len(x) == 0 { return 0 } return uint32(x[0]) } // low64 returns the least significant 64 bits of x. func low64(x nat) uint64 { if len(x) == 0 { return 0 } v := uint64(x[0]) if _W == 32 && len(x) > 1 { return uint64(x[1])<<32 | v } return v } // Int64 returns the int64 representation of x. // If x cannot be represented in an int64, the result is undefined. func (x *Int) Int64() int64 { v := int64(low64(x.abs)) if x.neg { v = -v } return v } // Uint64 returns the uint64 representation of x. // If x cannot be represented in a uint64, the result is undefined. func (x *Int) Uint64() uint64 { return low64(x.abs) } // IsInt64 reports whether x can be represented as an int64. func (x *Int) IsInt64() bool { if len(x.abs) <= 64/_W { w := int64(low64(x.abs)) return w >= 0 || x.neg && w == -w } return false } // IsUint64 reports whether x can be represented as a uint64. func (x *Int) IsUint64() bool { return !x.neg && len(x.abs) <= 64/_W } // Float64 returns the float64 value nearest x, // and an indication of any rounding that occurred. 
func (x *Int) Float64() (float64, Accuracy) {
	bitLen := x.abs.bitLen() // NB: still uses slow crypto impl!
	if bitLen == 0 {
		return 0.0, Exact
	}

	// Fast path: x converts exactly when it has at most 53 significant bits,
	// i.e. its span from highest set bit down to lowest set bit fits a
	// float64 mantissa. The trailing-zero count is only consulted when the
	// value still fits in 64 bits, so that low64 sees all of it.
	exact := bitLen <= 53 || bitLen < 64 && bitLen-int(x.abs.trailingZeroBits()) <= 53
	if !exact {
		// Slow path: let Float perform correct rounding.
		return new(Float).SetInt(x).Float64()
	}

	f := float64(low64(x.abs))
	if x.neg {
		f = -f
	}
	return f, Exact
}

// SetString sets z to the value of s, interpreted in the given base,
// and returns z and a boolean indicating success. The entire string
// (not just a prefix) must be valid for success. If SetString fails,
// the value of z is undefined but the returned value is nil.
//
// The base argument must be 0 or a value between 2 and [MaxBase].
// For base 0, the number prefix determines the actual base: A prefix of
// “0b” or “0B” selects base 2, “0”, “0o” or “0O” selects base 8,
// and “0x” or “0X” selects base 16. Otherwise, the selected base is 10
// and no prefix is accepted.
//
// For bases <= 36, lower and upper case letters are considered the same:
// The letters 'a' to 'z' and 'A' to 'Z' represent digit values 10 to 35.
// For bases > 36, the upper case letters 'A' to 'Z' represent the digit
// values 36 to 61.
//
// For base 0, an underscore character “_” may appear between a base
// prefix and an adjacent digit, and between successive digits; such
// underscores do not change the value of the number.
// Incorrect placement of underscores is reported as an error if there
// are no other errors. If base != 0, underscores are not recognized
// and act like any other character that is not a valid digit.
func (z *Int) SetString(s string, base int) (*Int, bool) {
	src := strings.NewReader(s)
	return z.setFromScanner(src, base)
}

// setFromScanner implements SetString given an io.ByteScanner.
// For documentation see comments of SetString.
func (z *Int) setFromScanner(r io.ByteScanner, base int) (*Int, bool) { if _, _, err := z.scan(r, base); err != nil { return nil, false } // entire content must have been consumed if _, err := r.ReadByte(); err != io.EOF { return nil, false } return z, true // err == io.EOF => scan consumed all content of r } // SetBytes interprets buf as the bytes of a big-endian unsigned // integer, sets z to that value, and returns z. func (z *Int) SetBytes(buf []byte) *Int { z.abs = z.abs.setBytes(buf) z.neg = false return z } // Bytes returns the absolute value of x as a big-endian byte slice. // // To use a fixed length slice, or a preallocated one, use [Int.FillBytes]. func (x *Int) Bytes() []byte { // This function is used in cryptographic operations. It must not leak // anything but the Int's sign and bit size through side-channels. Any // changes must be reviewed by a security expert. buf := make([]byte, len(x.abs)*_S) return buf[x.abs.bytes(buf):] } // FillBytes sets buf to the absolute value of x, storing it as a zero-extended // big-endian byte slice, and returns buf. // // If the absolute value of x doesn't fit in buf, FillBytes will panic. func (x *Int) FillBytes(buf []byte) []byte { // Clear whole buffer. clear(buf) x.abs.bytes(buf) return buf } // BitLen returns the length of the absolute value of x in bits. // The bit length of 0 is 0. func (x *Int) BitLen() int { // This function is used in cryptographic operations. It must not leak // anything but the Int's sign and bit size through side-channels. Any // changes must be reviewed by a security expert. return x.abs.bitLen() } // TrailingZeroBits returns the number of consecutive least significant zero // bits of |x|. func (x *Int) TrailingZeroBits() uint { return x.abs.trailingZeroBits() } // Exp sets z = x**y mod |m| (i.e. the sign of m is ignored), and returns z. // If m == nil or m == 0, z = x**y unless y <= 0 then z = 1. 
If m != 0, y < 0, // and x and m are not relatively prime, z is unchanged and nil is returned. // // Modular exponentiation of inputs of a particular size is not a // cryptographically constant-time operation. func (z *Int) Exp(x, y, m *Int) *Int { return z.exp(x, y, m, false) } func (z *Int) expSlow(x, y, m *Int) *Int { return z.exp(x, y, m, true) } func (z *Int) exp(x, y, m *Int, slow bool) *Int { // See Knuth, volume 2, section 4.6.3. xWords := x.abs if y.neg { if m == nil || len(m.abs) == 0 { return z.SetInt64(1) } // for y < 0: x**y mod m == (x**(-1))**|y| mod m inverse := new(Int).ModInverse(x, m) if inverse == nil { return nil } xWords = inverse.abs } yWords := y.abs var mWords nat if m != nil { if z == m || alias(z.abs, m.abs) { m = new(Int).Set(m) } mWords = m.abs // m.abs may be nil for m == 0 } z.abs = z.abs.expNN(nil, xWords, yWords, mWords, slow) z.neg = len(z.abs) > 0 && x.neg && len(yWords) > 0 && yWords[0]&1 == 1 // 0 has no sign if z.neg && len(mWords) > 0 { // make modulus result positive z.abs = z.abs.sub(mWords, z.abs) // z == x**y mod |m| && 0 <= z < |m| z.neg = false } return z } // GCD sets z to the greatest common divisor of a and b and returns z. // If x or y are not nil, GCD sets their value such that z = a*x + b*y. // // a and b may be positive, zero or negative. (Before Go 1.14 both had // to be > 0.) Regardless of the signs of a and b, z is always >= 0. // // If a == b == 0, GCD sets z = x = y = 0. // // If a == 0 and b != 0, GCD sets z = |b|, x = 0, y = sign(b) * 1. // // If a != 0 and b == 0, GCD sets z = |a|, x = sign(a) * 1, y = 0. 
func (z *Int) GCD(x, y, a, b *Int) *Int {
	// Trivial case: at least one input is zero. The GCD is then the absolute
	// value of the other input, and the cofactors are 0 or ±1.
	if len(a.abs) == 0 || len(b.abs) == 0 {
		// Lengths and signs are snapshotted before any output is written,
		// so that outputs aliasing a or b do not affect the decisions below.
		lenA, lenB, negA, negB := len(a.abs), len(b.abs), a.neg, b.neg
		if lenA == 0 {
			z.Set(b)
		} else {
			z.Set(a)
		}
		z.neg = false
		if x != nil {
			if lenA == 0 {
				x.SetUint64(0)
			} else {
				x.SetUint64(1)
				x.neg = negA
			}
		}
		if y != nil {
			if lenB == 0 {
				y.SetUint64(0)
			} else {
				y.SetUint64(1)
				y.neg = negB
			}
		}
		return z
	}

	return z.lehmerGCD(x, y, a, b)
}

// lehmerSimulate attempts to simulate several Euclidean update steps
// using the leading digits of A and B. It returns u0, u1, v0, v1
// such that A and B can be updated as:
//
//	A = u0*A + v0*B
//	B = u1*A + v1*B
//
// Requirements: A >= B and len(B.abs) >= 2
// Since we are calculating with full words to avoid overflow,
// we use 'even' to track the sign of the cosequences.
// For even iterations: u0, v1 >= 0 && u1, v0 <= 0
// For odd  iterations: u0, v1 <= 0 && u1, v0 >= 0
//
// If no step can be simulated safely, v0 is returned as 0; lehmerGCD
// uses that as its signal to fall back to a full Euclidean step.
func lehmerSimulate(A, B *Int) (u0, u1, v0, v1 Word, even bool) {
	// initialize the digits
	var a1, a2, u2, v2 Word

	m := len(B.abs) // m >= 2
	n := len(A.abs) // n >= m >= 2

	// extract the top Word of bits from A and B
	// h normalizes A's top word (its highest bit becomes set). When h == 0 the
	// shift by _W yields 0 in Go, so a1 is just A's top word.
	h := nlz(A.abs[n-1])
	a1 = A.abs[n-1]<<h | A.abs[n-2]>>(_W-h)
	// B may have implicit zero words in the high bits if the lengths differ
	// (B is shifted by the same h so that a1 and a2 stay aligned).
	switch {
	case n == m:
		a2 = B.abs[n-1]<<h | B.abs[n-2]>>(_W-h)
	case n == m+1:
		a2 = B.abs[n-2] >> (_W - h)
	default:
		// B is at least two words shorter: its aligned top word is zero, so
		// the loop below does not run and v0 stays 0.
		a2 = 0
	}

	// Since we are calculating with full words to avoid overflow,
	// we use 'even' to track the sign of the cosequences.
	// For even iterations: u0, v1 >= 0 && u1, v0 <= 0
	// For odd  iterations: u0, v1 <= 0 && u1, v0 >= 0
	// The first iteration starts with k=1 (odd).
	even = false
	// variables to track the cosequences
	u0, u1, u2 = 0, 1, 0
	v0, v1, v2 = 0, 0, 1

	// Calculate the quotient and cosequences using Collins' stopping condition.
	// Note that overflow of a Word is not possible when computing the remainder
	// sequence and cosequences since the cosequence size is bounded by the input size.
	// See section 4.2 of Jebelean for details.
	for a2 >= v2 && a1-a2 >= v1+v2 {
		q, r := a1/a2, a1%a2
		a1, a2 = a2, r
		// Shift the cosequence windows forward by one term; magnitudes only,
		// the signs are implied by 'even'.
		u0, u1, u2 = u1, u2, u1+q*u2
		v0, v1, v2 = v1, v2, v1+q*v2
		even = !even
	}
	return
}

// lehmerUpdate updates the inputs A and B such that:
//
//	A = u0*A + v0*B
//	B = u1*A + v1*B
//
// where the signs of u0, u1, v0, v1 are given by even
// For even == true: u0, v1 >= 0 && u1, v0 <= 0
// For even == false: u0, v1 <= 0 && u1, v0 >= 0
// q and r are temporary variables, used to avoid allocations in the
// multiplication; their previous values are overwritten.
func lehmerUpdate(A, B, q, r *Int, u0, u1, v0, v1 Word, even bool) {
	// The cross terms are computed into the temporaries first, because the
	// next two lines overwrite A and B in place.
	mulW(q, B, even, v0)
	mulW(r, A, even, u1)
	mulW(A, A, !even, u0)
	mulW(B, B, !even, v1)
	A.Add(A, q)
	B.Add(B, r)
}

// mulW sets z = x * (-?)w
// where the minus sign is present when neg is true.
// z.neg is set without checking len(z.abs), so z may transiently be a
// "negative zero"; the callers in this file always pass such values on to
// Int.Add, which clears the sign of a zero result.
func mulW(z, x *Int, neg bool, w Word) {
	z.abs = z.abs.mulAddWW(x.abs, w, 0)
	z.neg = x.neg != neg
}

// euclidUpdate performs a single step of the Euclidean GCD algorithm
// if extended is true, it also updates the cosequence Ua, Ub.
// q and r are used as temporaries; the initial values are ignored.
// It returns (B, A mod B, old A, Ua', Ub'): the old A's storage is handed
// back as the caller's next temporary r.
func euclidUpdate(A, B, Ua, Ub, q, r *Int, extended bool) (nA, nB, nr, nUa, nUb *Int) {
	q.QuoRem(A, B, r)

	if extended {
		// Ua, Ub = Ub, Ua-q*Ub
		q.Mul(q, Ub)
		Ua, Ub = Ub, Ua
		Ub.Sub(Ub, q)
	}

	return B, r, A, Ua, Ub
}

// lehmerGCD sets z to the greatest common divisor of a and b,
// which both must be != 0, and returns z.
// If x or y are not nil, their values are set such that z = a*x + b*y.
// See Knuth, The Art of Computer Programming, Vol. 2, Section 4.5.2, Algorithm L.
// This implementation uses the improved condition by Collins requiring only one
// quotient and avoiding the possibility of single Word overflow.
// See Jebelean, "Improving the multiprecision Euclidean algorithm",
// Design and Implementation of Symbolic Computation Systems, pp 45-58.
// The cosequences are updated according to Algorithm 10.45 from
// Cohen et al. "Handbook of Elliptic and Hyperelliptic Curve Cryptography" pp 192.
func (z *Int) lehmerGCD(x, y, a, b *Int) *Int {
	var A, B, Ua, Ub *Int

	// Work on fresh copies of |a| and |b|; the inputs are not modified here.
	A = new(Int).Abs(a)
	B = new(Int).Abs(b)

	extended := x != nil || y != nil

	if extended {
		// Ua (Ub) tracks how many times input a has been accumulated into A (B).
		Ua = new(Int).SetInt64(1)
		Ub = new(Int)
	}

	// temp variables for multiprecision update
	q := new(Int)
	r := new(Int)

	// ensure A >= B
	if A.abs.cmp(B.abs) < 0 {
		A, B = B, A
		Ub, Ua = Ua, Ub
	}

	// loop invariant A >= B
	for len(B.abs) > 1 {
		// Attempt to calculate in single-precision using leading words of A and B.
		u0, u1, v0, v1, even := lehmerSimulate(A, B)

		// multiprecision Step
		if v0 != 0 {
			// Simulate the effect of the single-precision steps using the cosequences.
			// A = u0*A + v0*B
			// B = u1*A + v1*B
			lehmerUpdate(A, B, q, r, u0, u1, v0, v1, even)

			if extended {
				// Ua = u0*Ua + v0*Ub
				// Ub = u1*Ua + v1*Ub
				lehmerUpdate(Ua, Ub, q, r, u0, u1, v0, v1, even)
			}

		} else {
			// Single-digit calculations failed to simulate any quotients.
			// Do a standard Euclidean step.
			A, B, r, Ua, Ub = euclidUpdate(A, B, Ua, Ub, q, r, extended)
		}
	}

	if len(B.abs) > 0 {
		// extended Euclidean algorithm base case if B is a single Word
		if len(A.abs) > 1 {
			// A is longer than a single Word, so one update is needed.
			A, B, r, Ua, Ub = euclidUpdate(A, B, Ua, Ub, q, r, extended)
		}
		if len(B.abs) > 0 {
			// A and B are both a single Word.
			aWord, bWord := A.abs[0], B.abs[0]
			if extended {
				// Word-sized extended Euclid. As in lehmerSimulate, only the
				// cosequence magnitudes are tracked and 'even' carries the signs.
				var ua, ub, va, vb Word
				ua, ub = 1, 0
				va, vb = 0, 1
				even := true
				for bWord != 0 {
					q, r := aWord/bWord, aWord%bWord
					aWord, bWord = bWord, r
					ua, ub = ub, ua+q*ub
					va, vb = vb, va+q*vb
					even = !even
				}

				// Ua = ±ua*Ua ∓ va*Ub (signs chosen by even).
				mulW(Ua, Ua, !even, ua)
				mulW(Ub, Ub, even, va)
				Ua.Add(Ua, Ub)
			} else {
				for bWord != 0 {
					aWord, bWord = bWord, aWord%bWord
				}
			}
			A.abs[0] = aWord
		}
	}
	// Ua is the cofactor for |a|; restore the sign of a below.
	negA := a.neg
	if y != nil {
		// avoid aliasing b needed in the division below
		if y == b {
			B.Set(b)
		} else {
			B = b
		}
		// y = (z - a*x)/b
		y.Mul(a, Ua) // y can safely alias a
		if negA {
			y.neg = !y.neg
		}
		y.Sub(A, y)
		y.Div(y, B)
	}

	if x != nil {
		x.Set(Ua)
		if negA {
			x.neg = !x.neg
		}
	}

	z.Set(A)

	return z
}

// Rand sets z to a pseudo-random number in [0, n) and returns z.
// If n <= 0, z is set to 0.
//
// As this uses the [math/rand] package, it must not be used for
// security-sensitive work. Use [crypto/rand.Int] instead.
func (z *Int) Rand(rnd *rand.Rand, n *Int) *Int {
	// z.neg is not modified before the if check, because z and n might alias.
	if n.neg || len(n.abs) == 0 {
		z.neg = false
		z.abs = nil
		return z
	}
	z.neg = false
	z.abs = z.abs.random(rnd, n.abs, n.abs.bitLen())
	return z
}

// ModInverse sets z to the multiplicative inverse of g in the ring ℤ/nℤ
// and returns z. If g and n are not relatively prime, g has no multiplicative
// inverse in the ring ℤ/nℤ. In this case, z is unchanged and the return value
// is nil. If n == 0, a division-by-zero run-time panic occurs.
func (z *Int) ModInverse(g, n *Int) *Int {
	// Normalize to a non-negative modulus and a non-negative g (reduced mod n).
	// GCD no longer requires positive inputs (since Go 1.14), but the final
	// adjustment below (adding n to a negative x) relies on n being positive.
	if n.neg {
		var n2 Int
		n = n2.Neg(n)
	}
	if g.neg {
		var g2 Int
		g = g2.Mod(g, n)
	}
	var d, x Int
	d.GCD(&x, nil, g, n)

	// if and only if d==1, g and n are relatively prime
	if d.Cmp(intOne) != 0 {
		return nil
	}
	// x and y are such that g*x + n*y = 1, therefore x is the inverse element,
	// but it may be negative, so convert to the range 0 <= z < |n|
	if x.neg {
		z.Add(&x, n)
	} else {
		z.Set(&x)
	}
	return z
}

// modInverse is the nat counterpart of Int.ModInverse; it returns the
// magnitude of the inverse. NOTE(review): if no inverse exists, ModInverse
// returns nil and the .abs selector below would panic — callers presumably
// guarantee coprime inputs; confirm.
func (z nat) modInverse(g, n nat) nat {
	// TODO(rsc): ModInverse should be implemented in terms of this function.
	return (&Int{abs: z}).ModInverse(&Int{abs: g}, &Int{abs: n}).abs
}

// Jacobi returns the Jacobi symbol (x/y), either +1, -1, or 0.
// The y argument must be an odd integer.
func Jacobi(x, y *Int) int {
	if len(y.abs) == 0 || y.abs[0]&1 == 0 {
		panic(fmt.Sprintf("big: invalid 2nd argument to Int.Jacobi: need odd integer but got %s", y.String()))
	}

	// We use the formulation described in chapter 2, section 2.4,
	// "The Yacas Book of Algorithms":
	// http://yacas.sourceforge.net/Algo.book.pdf

	// Work on copies; x and y are not modified.
	var a, b, c Int
	a.Set(x)
	b.Set(y)
	j := 1

	// Make the denominator positive; the symbol flips sign only when both
	// inputs are negative.
	if b.neg {
		if a.neg {
			j = -1
		}
		b.neg = false
	}

	for {
		if b.Cmp(intOne) == 0 {
			return j
		}
		if len(a.abs) == 0 {
			return 0
		}
		a.Mod(&a, &b)
		if len(a.abs) == 0 {
			return 0
		}
		// a > 0

		// handle factors of 2 in 'a'
		// (each factor of 2 contributes -1 when b ≡ 3 or 5 mod 8; only the
		// parity of the count s matters)
		s := a.abs.trailingZeroBits()
		if s&1 != 0 {
			bmod8 := b.abs[0] & 7
			if bmod8 == 3 || bmod8 == 5 {
				j = -j
			}
		}
		c.Rsh(&a, s) // a = 2^s*c

		// swap numerator and denominator
		// (quadratic reciprocity: the sign flips when both are ≡ 3 mod 4)
		if b.abs[0]&3 == 3 && c.abs[0]&3 == 3 {
			j = -j
		}
		a.Set(&b)
		b.Set(&c)
	}
}

// modSqrt3Mod4Prime uses the identity
//
//	   (a^((p+1)/4))^2  mod p
//	== u^(p+1)          mod p
//	== u^2              mod p
//
// where a = u^2 mod p is a quadratic residue (the last step is Fermat's little
// theorem, u^(p-1) == 1), to calculate the square root of any quadratic
// residue mod p quickly for 3 mod 4 primes.
func (z *Int) modSqrt3Mod4Prime(x, p *Int) *Int {
	e := new(Int).Add(p, intOne) // e = p + 1
	e.Rsh(e, 2)                  // e = (p + 1) / 4
	z.Exp(x, e, p)               // z = x^e mod p
	return z
}

// modSqrt5Mod8Prime uses Atkin's observation that 2 is not a square mod p
//
//	alpha ==  (2*a)^((p-5)/8)    mod p
//	beta  ==  2*a*alpha^2        mod p  is a square root of -1
//	b     ==  a*alpha*(beta-1)   mod p  is a square root of a
//
// to calculate the square root of any quadratic residue mod p quickly for 5
// mod 8 primes.
func (z *Int) modSqrt5Mod8Prime(x, p *Int) *Int {
	// p == 5 mod 8 implies p = e*8 + 5
	// e is the quotient and 5 the remainder on division by 8
	e := new(Int).Rsh(p, 3)  // e = (p - 5) / 8
	tx := new(Int).Lsh(x, 1) // tx = 2*x
	alpha := new(Int).Exp(tx, e, p)
	// beta = 2*x*alpha^2 mod p
	beta := new(Int).Mul(alpha, alpha)
	beta.Mod(beta, p)
	beta.Mul(beta, tx)
	beta.Mod(beta, p)
	// The same variable is reused for the result: x*alpha*(beta-1) mod p.
	beta.Sub(beta, intOne)
	beta.Mul(beta, x)
	beta.Mod(beta, p)
	beta.Mul(beta, alpha)
	z.Mod(beta, p)
	return z
}

// modSqrtTonelliShanks uses the Tonelli-Shanks algorithm to find the square
// root of a quadratic residue modulo any prime.
func (z *Int) modSqrtTonelliShanks(x, p *Int) *Int {
	// Break p-1 into s*2^e such that s is odd.
	var s Int
	s.Sub(p, intOne)
	e := s.abs.trailingZeroBits()
	s.Rsh(&s, e)

	// find some non-square n
	// (linear search upward from 2 for the first n with Jacobi(n, p) == -1)
	var n Int
	n.SetInt64(2)
	for Jacobi(&n, p) != -1 {
		n.Add(&n, intOne)
	}

	// Core of the Tonelli-Shanks algorithm. Follows the description in
	// section 6 of "Square roots from 1; 24, 51, 10 to Dan Shanks" by Ezra
	// Brown:
	// https://www.maa.org/sites/default/files/pdf/upload_library/22/Polya/07468342.di020786.02p0470a.pdf
	//
	// Invariant: y^2 == x*b (mod p). It holds initially, since
	// y^2 = x^(s+1) = x*x^s, and the updates below (y *= t, b *= t^2)
	// preserve it. So once b == 1, y is a square root of x.
	var y, b, g, t Int
	y.Add(&s, intOne)
	y.Rsh(&y, 1)
	y.Exp(x, &y, p)  // y = x^((s+1)/2)
	b.Exp(x, &s, p)  // b = x^s
	g.Exp(&n, &s, p) // g = n^s
	r := e
	for {
		// find the least m such that ord_p(b) = 2^m
		var m uint
		t.Set(&b)
		for t.Cmp(intOne) != 0 {
			t.Mul(&t, &t).Mod(&t, p)
			m++
		}

		// b == 1: by the invariant, y is the answer.
		if m == 0 {
			return z.Set(&y)
		}

		t.SetInt64(0).SetBit(&t, int(r-m-1), 1).Exp(&g, &t, p)
		// t = g^(2^(r-m-1)) mod p
		g.Mul(&t, &t).Mod(&g, p) // g = g^(2^(r-m)) mod p
		y.Mul(&y, &t).Mod(&y, p)
		b.Mul(&b, &g).Mod(&b, p)
		r = m
	}
}

// ModSqrt sets z to a square root of x mod p if such a square root exists, and
// returns z. The modulus p must be an odd prime. If x is not a square mod p,
// ModSqrt leaves z unchanged and returns nil. This function panics if p is
// not an odd integer, its behavior is undefined if p is odd but not prime.
func (z *Int) ModSqrt(x, p *Int) *Int {
	// Jacobi panics for even or zero p, which provides the documented panic.
	switch Jacobi(x, p) {
	case -1:
		return nil // x is not a square mod p
	case 0:
		return z.SetInt64(0) // sqrt(0) mod p = 0
	case 1:
		break
	}
	if x.neg || x.Cmp(p) >= 0 { // ensure 0 <= x < p
		x = new(Int).Mod(x, p)
	}

	// Dispatch on the low bits of p (p.abs[0] holds its least significant word).
	switch {
	case p.abs[0]%4 == 3:
		// Check whether p is 3 mod 4, and if so, use the faster algorithm.
		return z.modSqrt3Mod4Prime(x, p)
	case p.abs[0]%8 == 5:
		// Check whether p is 5 mod 8, use Atkin's algorithm.
		return z.modSqrt5Mod8Prime(x, p)
	default:
		// Otherwise, use Tonelli-Shanks.
		return z.modSqrtTonelliShanks(x, p)
	}
}

// Lsh sets z = x << n and returns z.
func (z *Int) Lsh(x *Int, n uint) *Int {
	// Shifting the magnitude left never changes the sign (and zero stays zero).
	z.abs = z.abs.lsh(x.abs, n)
	z.neg = x.neg
	return z
}

// Rsh sets z = x >> n and returns z.
func (z *Int) Rsh(x *Int, n uint) *Int { if x.neg { // (-x) >> s == ^(x-1) >> s == ^((x-1) >> s) == -(((x-1) >> s) + 1) t := z.abs.sub(x.abs, natOne) // no underflow because |x| > 0 t = t.rsh(t, n) z.abs = t.add(t, natOne) z.neg = true // z cannot be zero if x is negative return z } z.abs = z.abs.rsh(x.abs, n) z.neg = false return z } // Bit returns the value of the i'th bit of x. That is, it // returns (x>>i)&1. The bit index i must be >= 0. func (x *Int) Bit(i int) uint { if i == 0 { // optimization for common case: odd/even test of x if len(x.abs) > 0 { return uint(x.abs[0] & 1) // bit 0 is same for -x } return 0 } if i < 0 { panic("negative bit index") } if x.neg { t := nat(nil).sub(x.abs, natOne) return t.bit(uint(i)) ^ 1 } return x.abs.bit(uint(i)) } // SetBit sets z to x, with x's i'th bit set to b (0 or 1). // That is, // - if b is 1, SetBit sets z = x | (1 << i); // - if b is 0, SetBit sets z = x &^ (1 << i); // - if b is not 0 or 1, SetBit will panic. func (z *Int) SetBit(x *Int, i int, b uint) *Int { if i < 0 { panic("negative bit index") } if x.neg { t := z.abs.sub(x.abs, natOne) t = t.setBit(t, uint(i), b^1) z.abs = t.add(t, natOne) z.neg = len(z.abs) > 0 return z } z.abs = z.abs.setBit(x.abs, uint(i), b) z.neg = false return z } // And sets z = x & y and returns z. func (z *Int) And(x, y *Int) *Int { if x.neg == y.neg { if x.neg { // (-x) & (-y) == ^(x-1) & ^(y-1) == ^((x-1) | (y-1)) == -(((x-1) | (y-1)) + 1) x1 := nat(nil).sub(x.abs, natOne) y1 := nat(nil).sub(y.abs, natOne) z.abs = z.abs.add(z.abs.or(x1, y1), natOne) z.neg = true // z cannot be zero if x and y are negative return z } // x & y == x & y z.abs = z.abs.and(x.abs, y.abs) z.neg = false return z } // x.neg != y.neg if x.neg { x, y = y, x // & is symmetric } // x & (-y) == x & ^(y-1) == x &^ (y-1) y1 := nat(nil).sub(y.abs, natOne) z.abs = z.abs.andNot(x.abs, y1) z.neg = false return z } // AndNot sets z = x &^ y and returns z. 
func (z *Int) AndNot(x, y *Int) *Int { if x.neg == y.neg { if x.neg { // (-x) &^ (-y) == ^(x-1) &^ ^(y-1) == ^(x-1) & (y-1) == (y-1) &^ (x-1) x1 := nat(nil).sub(x.abs, natOne) y1 := nat(nil).sub(y.abs, natOne) z.abs = z.abs.andNot(y1, x1) z.neg = false return z } // x &^ y == x &^ y z.abs = z.abs.andNot(x.abs, y.abs) z.neg = false return z } if x.neg { // (-x) &^ y == ^(x-1) &^ y == ^(x-1) & ^y == ^((x-1) | y) == -(((x-1) | y) + 1) x1 := nat(nil).sub(x.abs, natOne) z.abs = z.abs.add(z.abs.or(x1, y.abs), natOne) z.neg = true // z cannot be zero if x is negative and y is positive return z } // x &^ (-y) == x &^ ^(y-1) == x & (y-1) y1 := nat(nil).sub(y.abs, natOne) z.abs = z.abs.and(x.abs, y1) z.neg = false return z } // Or sets z = x | y and returns z. func (z *Int) Or(x, y *Int) *Int { if x.neg == y.neg { if x.neg { // (-x) | (-y) == ^(x-1) | ^(y-1) == ^((x-1) & (y-1)) == -(((x-1) & (y-1)) + 1) x1 := nat(nil).sub(x.abs, natOne) y1 := nat(nil).sub(y.abs, natOne) z.abs = z.abs.add(z.abs.and(x1, y1), natOne) z.neg = true // z cannot be zero if x and y are negative return z } // x | y == x | y z.abs = z.abs.or(x.abs, y.abs) z.neg = false return z } // x.neg != y.neg if x.neg { x, y = y, x // | is symmetric } // x | (-y) == x | ^(y-1) == ^((y-1) &^ x) == -(^((y-1) &^ x) + 1) y1 := nat(nil).sub(y.abs, natOne) z.abs = z.abs.add(z.abs.andNot(y1, x.abs), natOne) z.neg = true // z cannot be zero if one of x or y is negative return z } // Xor sets z = x ^ y and returns z. 
func (z *Int) Xor(x, y *Int) *Int { if x.neg == y.neg { if x.neg { // (-x) ^ (-y) == ^(x-1) ^ ^(y-1) == (x-1) ^ (y-1) x1 := nat(nil).sub(x.abs, natOne) y1 := nat(nil).sub(y.abs, natOne) z.abs = z.abs.xor(x1, y1) z.neg = false return z } // x ^ y == x ^ y z.abs = z.abs.xor(x.abs, y.abs) z.neg = false return z } // x.neg != y.neg if x.neg { x, y = y, x // ^ is symmetric } // x ^ (-y) == x ^ ^(y-1) == ^(x ^ (y-1)) == -((x ^ (y-1)) + 1) y1 := nat(nil).sub(y.abs, natOne) z.abs = z.abs.add(z.abs.xor(x.abs, y1), natOne) z.neg = true // z cannot be zero if only one of x or y is negative return z } // Not sets z = ^x and returns z. func (z *Int) Not(x *Int) *Int { if x.neg { // ^(-x) == ^(^(x-1)) == x-1 z.abs = z.abs.sub(x.abs, natOne) z.neg = false return z } // ^x == -x-1 == -(x+1) z.abs = z.abs.add(x.abs, natOne) z.neg = true // z cannot be zero if x is positive return z } // Sqrt sets z to ⌊√x⌋, the largest integer such that z² ≤ x, and returns z. // It panics if x is negative. func (z *Int) Sqrt(x *Int) *Int { if x.neg { panic("square root of negative number") } z.neg = false z.abs = z.abs.sqrt(nil, x.abs) return z }
// Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This file implements int-to-string conversion functions. package big import ( "errors" "fmt" "io" ) // Text returns the string representation of x in the given base. // Base must be between 2 and 62, inclusive. The result uses the // lower-case letters 'a' to 'z' for digit values 10 to 35, and // the upper-case letters 'A' to 'Z' for digit values 36 to 61. // No prefix (such as "0x") is added to the string. If x is a nil // pointer it returns "<nil>". func (x *Int) Text(base int) string { if x == nil { return "<nil>" } return string(x.abs.itoa(x.neg, base)) } // Append appends the string representation of x, as generated by // x.Text(base), to buf and returns the extended buffer. func (x *Int) Append(buf []byte, base int) []byte { if x == nil { return append(buf, "<nil>"...) } return append(buf, x.abs.itoa(x.neg, base)...) } // String returns the decimal representation of x as generated by // x.Text(10). func (x *Int) String() string { return x.Text(10) } // write count copies of text to s. func writeMultiple(s fmt.State, text string, count int) { if len(text) > 0 { b := []byte(text) for ; count > 0; count-- { s.Write(b) } } } var _ fmt.Formatter = intOne // *Int must implement fmt.Formatter // Format implements [fmt.Formatter]. It accepts the formats // 'b' (binary), 'o' (octal with 0 prefix), 'O' (octal with 0o prefix), // 'd' (decimal), 'x' (lowercase hexadecimal), and // 'X' (uppercase hexadecimal). // Also supported are the full suite of package fmt's format // flags for integral types, including '+' and ' ' for sign // control, '#' for leading zero in octal and for hexadecimal, // a leading "0x" or "0X" for "%#x" and "%#X" respectively, // specification of minimum digits precision, output field // width, space or zero padding, and '-' for left or right // justification. 
func (x *Int) Format(s fmt.State, ch rune) {
	// determine base
	var base int
	switch ch {
	case 'b':
		base = 2
	case 'o', 'O':
		base = 8
	case 'd', 's', 'v':
		base = 10
	case 'x', 'X':
		base = 16
	default:
		// unknown format
		// (reported in fmt's "%!verb(type=value)" style; x.String handles nil x)
		fmt.Fprintf(s, "%%!%c(big.Int=%s)", ch, x.String())
		return
	}

	if x == nil {
		fmt.Fprint(s, "<nil>")
		return
	}

	// determine sign character
	sign := ""
	switch {
	case x.neg:
		sign = "-"
	case s.Flag('+'): // supersedes ' ' when both specified
		sign = "+"
	case s.Flag(' '):
		sign = " "
	}

	// determine prefix characters for indicating output base
	prefix := ""
	if s.Flag('#') {
		switch ch {
		case 'b': // binary
			prefix = "0b"
		case 'o': // octal
			prefix = "0"
		case 'x': // hexadecimal
			prefix = "0x"
		case 'X':
			prefix = "0X"
		}
	}
	// 'O' always gets the "0o" prefix, with or without '#'.
	if ch == 'O' {
		prefix = "0o"
	}

	// digits holds |x| in the chosen base, without sign or prefix.
	digits := x.abs.utoa(base)
	if ch == 'X' {
		// faster than bytes.ToUpper
		for i, d := range digits {
			if 'a' <= d && d <= 'z' {
				digits[i] = 'A' + (d - 'a')
			}
		}
	}

	// number of characters for the three classes of number padding
	var left int  // space characters to left of digits for right justification ("%8d")
	var zeros int // zero characters ('0') as left-most digits ("%.8d")
	var right int // space characters to right of digits for left justification ("%-8d")

	// determine number padding from precision: the least number of digits to output
	precision, precisionSet := s.Precision()
	if precisionSet {
		switch {
		case len(digits) < precision:
			zeros = precision - len(digits) // count of zero padding
		case len(digits) == 1 && digits[0] == '0' && precision == 0:
			return // print nothing if zero value (x == 0) and zero precision ("." or ".0")
		}
	}

	// determine field pad from width: the least number of characters to output
	length := len(sign) + len(prefix) + zeros + len(digits)
	if width, widthSet := s.Width(); widthSet && length < width { // pad as specified
		switch d := width - length; {
		case s.Flag('-'):
			// pad on the right with spaces; supersedes '0' when both specified
			right = d
		case s.Flag('0') && !precisionSet:
			// pad with zeros unless precision also specified
			zeros = d
		default:
			// pad on the left with spaces
			left = d
		}
	}

	// print number as [left pad][sign][prefix][zero pad][digits][right pad]
	writeMultiple(s, " ", left)
	writeMultiple(s, sign, 1)
	writeMultiple(s, prefix, 1)
	writeMultiple(s, "0", zeros)
	s.Write(digits)
	writeMultiple(s, " ", right)
}

// scan sets z to the integer value corresponding to the longest possible prefix
// read from r representing a signed integer number in a given conversion base.
// It returns z, the actual conversion base used, and an error, if any. In the
// error case, the value of z is undefined but the returned value is nil. The
// syntax follows the syntax of integer literals in Go.
//
// The base argument must be 0 or a value from 2 through MaxBase. If the base
// is 0, the string prefix determines the actual conversion base. A prefix of
// “0b” or “0B” selects base 2; a “0”, “0o”, or “0O” prefix selects
// base 8, and a “0x” or “0X” prefix selects base 16. Otherwise the selected
// base is 10.
func (z *Int) scan(r io.ByteScanner, base int) (*Int, int, error) { // determine sign neg, err := scanSign(r) if err != nil { return nil, 0, err } // determine mantissa z.abs, base, _, err = z.abs.scan(r, base, false) if err != nil { return nil, base, err } z.neg = len(z.abs) > 0 && neg // 0 has no sign return z, base, nil } func scanSign(r io.ByteScanner) (neg bool, err error) { var ch byte if ch, err = r.ReadByte(); err != nil { return false, err } switch ch { case '-': neg = true case '+': // nothing to do default: r.UnreadByte() } return } // byteReader is a local wrapper around fmt.ScanState; // it implements the ByteReader interface. type byteReader struct { fmt.ScanState } func (r byteReader) ReadByte() (byte, error) { ch, size, err := r.ReadRune() if size != 1 && err == nil { err = fmt.Errorf("invalid rune %#U", ch) } return byte(ch), err } func (r byteReader) UnreadByte() error { return r.UnreadRune() } var _ fmt.Scanner = intOne // *Int must implement fmt.Scanner // Scan is a support routine for [fmt.Scanner]; it sets z to the value of // the scanned number. It accepts the formats 'b' (binary), 'o' (octal), // 'd' (decimal), 'x' (lowercase hexadecimal), and 'X' (uppercase hexadecimal). func (z *Int) Scan(s fmt.ScanState, ch rune) error { s.SkipSpace() // skip leading space characters base := 0 switch ch { case 'b': base = 2 case 'o': base = 8 case 'd': base = 10 case 'x', 'X': base = 16 case 's', 'v': // let scan determine the base default: return errors.New("Int.Scan: invalid verb") } _, _, err := z.scan(byteReader{s}, base) return err }
// Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This file implements encoding/decoding of Ints. package big import ( "bytes" "fmt" ) // Gob codec version. Permits backward-compatible changes to the encoding. const intGobVersion byte = 1 // GobEncode implements the [encoding/gob.GobEncoder] interface. func (x *Int) GobEncode() ([]byte, error) { if x == nil { return nil, nil } buf := make([]byte, 1+len(x.abs)*_S) // extra byte for version and sign bit i := x.abs.bytes(buf) - 1 // i >= 0 b := intGobVersion << 1 // make space for sign bit if x.neg { b |= 1 } buf[i] = b return buf[i:], nil } // GobDecode implements the [encoding/gob.GobDecoder] interface. func (z *Int) GobDecode(buf []byte) error { if len(buf) == 0 { // Other side sent a nil or default value. *z = Int{} return nil } b := buf[0] if b>>1 != intGobVersion { return fmt.Errorf("Int.GobDecode: encoding version %d not supported", b>>1) } z.neg = b&1 != 0 z.abs = z.abs.setBytes(buf[1:]) return nil } // AppendText implements the [encoding.TextAppender] interface. func (x *Int) AppendText(b []byte) (text []byte, err error) { return x.Append(b, 10), nil } // MarshalText implements the [encoding.TextMarshaler] interface. func (x *Int) MarshalText() (text []byte, err error) { return x.AppendText(nil) } // UnmarshalText implements the [encoding.TextUnmarshaler] interface. func (z *Int) UnmarshalText(text []byte) error { if _, ok := z.setFromScanner(bytes.NewReader(text), 0); !ok { return fmt.Errorf("math/big: cannot unmarshal %q into a *big.Int", text) } return nil } // The JSON marshalers are only here for API backward compatibility // (programs that explicitly look for these two methods). JSON works // fine with the TextMarshaler only. // MarshalJSON implements the [encoding/json.Marshaler] interface. 
func (x *Int) MarshalJSON() ([]byte, error) { if x == nil { return []byte("null"), nil } return x.abs.itoa(x.neg, 10), nil } // UnmarshalJSON implements the [encoding/json.Unmarshaler] interface. func (z *Int) UnmarshalJSON(text []byte) error { // Ignore null, like in the main JSON package. if string(text) == "null" { return nil } return z.UnmarshalText(text) }
// Copyright 2009 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// This file implements unsigned multi-precision integers (natural
// numbers). They are the building blocks for the implementation
// of signed integers, rationals, and floating-point numbers.
//
// Caution: This implementation relies on the function "alias"
// which assumes that (nat) slice capacities are never
// changed (no 3-operand slice expressions). If that
// changes, alias needs to be updated for correctness.

package big

import (
	"internal/byteorder"
	"math/bits"
	"math/rand"
	"slices"
	"sync"
)

// An unsigned integer x of the form
//
//	x = x[n-1]*_B^(n-1) + x[n-2]*_B^(n-2) + ... + x[1]*_B + x[0]
//
// with 0 <= x[i] < _B and 0 <= i < n is stored in a slice of length n,
// with the digits x[i] as the slice elements.
//
// A number is normalized if the slice contains no leading 0 digits.
// During arithmetic operations, denormalized values may occur but are
// always normalized before returning the final result. The normalized
// representation of 0 is the empty or nil slice (length = 0).
type nat []Word

// Small nat constants shared across the package.
// NOTE(review): these are package-level slices; code presumably never
// modifies them in place — confirm.
var (
	natOne  = nat{1}
	natTwo  = nat{2}
	natFive = nat{5}
	natTen  = nat{10}
)

// String returns z formatted in hexadecimal with a "0x" prefix.
func (z nat) String() string {
	return "0x" + string(z.itoa(false, 16))
}

// norm returns z with its most-significant zero words removed.
// The result shares z's underlying array.
func (z nat) norm() nat {
	i := len(z)
	for i > 0 && z[i-1] == 0 {
		i--
	}
	return z[0:i]
}

// make returns a nat of length n, reusing z's underlying array when its
// capacity suffices. A reused array is not cleared, so callers must
// initialize every word they read.
func (z nat) make(n int) nat {
	if n <= cap(z) {
		return z[:n] // reuse z
	}
	if n == 1 {
		// Most nats start small and stay that way; don't over-allocate.
		return make(nat, 1)
	}
	// Choosing a good value for e has significant performance impact
	// because it increases the chance that a value can be reused.
	const e = 4 // extra capacity
	return make(nat, n, n+e)
}

// setWord sets z to the single-word value x and returns z.
// A zero x yields the normalized (empty) representation.
func (z nat) setWord(x Word) nat {
	if x == 0 {
		return z[:0]
	}
	z = z.make(1)
	z[0] = x
	return z
}

// setUint64 sets z to x and returns z.
func (z nat) setUint64(x uint64) nat {
	// single-word value
	if w := Word(x); uint64(w) == x {
		return z.setWord(w)
	}
	// 2-word value
	// (only reachable when Word is narrower than 64 bits)
	z = z.make(2)
	z[1] = Word(x >> 32)
	z[0] = Word(x)
	return z
}

// set sets z to a copy of x and returns z.
func (z nat) set(x nat) nat {
	z = z.make(len(x))
	copy(z, x)
	return z
}

// add sets z to x + y and returns the normalized result.
func (z nat) add(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		// Make x the longer operand.
		return z.add(y, x)
	case m == 0:
		// n == 0 because m >= n; result is 0
		return z[:0]
	case n == 0:
		// result is x
		return z.set(x)
	}
	// m > 0

	// One extra word receives the final carry.
	z = z.make(m + 1)
	c := addVV(z[:n], x[:n], y[:n])
	if m > n {
		// Propagate the carry through the remaining words of x.
		c = addVW(z[n:m], x[n:], c)
	}
	z[m] = c

	return z.norm()
}

// sub sets z to x - y and returns the normalized result.
// It panics with "underflow" if y is longer than x or if the
// subtraction borrows out of the top word.
func (z nat) sub(x, y nat) nat {
	m := len(x)
	n := len(y)

	switch {
	case m < n:
		panic("underflow")
	case m == 0:
		// n == 0 because m >= n; result is 0
		return z[:0]
	case n == 0:
		// result is x
		return z.set(x)
	}
	// m > 0

	z = z.make(m)
	c := subVV(z[:n], x[:n], y[:n])
	if m > n {
		// Propagate the borrow through the remaining words of x.
		c = subVW(z[n:], x[n:], c)
	}
	if c != 0 {
		panic("underflow")
	}

	return z.norm()
}

// cmp compares x and y and returns -1, 0, or +1 for x < y, x == y, or x > y.
// When the lengths differ, the longer operand is normalized and the
// comparison retried, so operands with leading zero words are tolerated.
func (x nat) cmp(y nat) (r int) {
Retry:
	m := len(x)
	n := len(y)
	if m != n || m == 0 {
		switch {
		case m < n:
			if y[n-1] == 0 {
				// y has leading zero words; strip them and compare again.
				y = y.norm()
				goto Retry
			}
			r = -1
		case m > n:
			if x[m-1] == 0 {
				// x has leading zero words; strip them and compare again.
				x = x.norm()
				goto Retry
			}
			r = 1
		}
		return
	}

	// Equal non-zero lengths: scan down from the most significant word.
	i := m - 1
	for i > 0 && x[i] == y[i] {
		i--
	}

	switch {
	case x[i] < y[i]:
		r = -1
	case x[i] > y[i]:
		r = 1
	}
	return
}

// montgomery computes z mod m = x*y*2**(-n*_W) mod m,
// assuming k = -1/m mod 2**_W.
// z is used for storing the result which is returned;
// z must not alias x, y or m.
// See Gueron, "Efficient Software Implementations of Modular Exponentiation".
// https://eprint.iacr.org/2011/239.pdf
// In the terminology of that paper, this is an "Almost Montgomery Multiplication":
// x and y are required to satisfy 0 <= x, y < 2**(n*_W) and then the result
// z is guaranteed to satisfy 0 <= z < 2**(n*_W), but it may not be < m.
func (z nat) montgomery(x, y, m nat, k Word, n int) nat { // This code assumes x, y, m are all the same length, n. // (required by addMulVVW and the for loop). // It also assumes that x, y are already reduced mod m, // or else the result will not be properly reduced. if len(x) != n || len(y) != n || len(m) != n { panic("math/big: mismatched montgomery number lengths") } z = z.make(n * 2) clear(z) var c Word for i := 0; i < n; i++ { d := y[i] c2 := addMulVVWW(z[i:n+i], z[i:n+i], x, d, 0) t := z[i] * k c3 := addMulVVWW(z[i:n+i], z[i:n+i], m, t, 0) cx := c + c2 cy := cx + c3 z[n+i] = cy if cx < c2 || cy < c3 { c = 1 } else { c = 0 } } if c != 0 { subVV(z[:n], z[n:], m) } else { copy(z[:n], z[n:]) } return z[:n] } // alias reports whether x and y share the same base array. // // Note: alias assumes that the capacity of underlying arrays // is never changed for nat values; i.e. that there are // no 3-operand slice expressions in this code (or worse, // reflect-based operations to the same effect). func alias(x, y nat) bool { return cap(x) > 0 && cap(y) > 0 && &x[0:cap(x)][cap(x)-1] == &y[0:cap(y)][cap(y)-1] } // addTo implements z += x; z must be long enough. // (we don't use nat.add because we need z to stay the same // slice, and we don't need to normalize z after each addition) func addTo(z, x nat) { if n := len(x); n > 0 { if c := addVV(z[:n], z[:n], x[:n]); c != 0 { if n < len(z) { addVW(z[n:], z[n:], c) } } } } // subFrom implements z -= x; z must be long enough. // (we don't use nat.sub because we need z to stay the same // slice, and we don't need to normalize z after each subtraction) func subFrom(z, x nat) { if n := len(x); n > 0 { if c := subVV(z[:n], z, x); c != 0 { if n < len(z) { subVW(z[n:], z[n:], c) } } } } // mulRange computes the product of all the unsigned integers in the // range [a, b] inclusively. If a > b (empty range), the result is 1. // The caller may pass stk == nil to request that mulRange obtain and release one itself. 
func (z nat) mulRange(stk *stack, a, b uint64) nat { switch { case a == 0: // cut long ranges short (optimization) return z.setUint64(0) case a > b: return z.setUint64(1) case a == b: return z.setUint64(a) case a+1 == b: return z.mul(stk, nat(nil).setUint64(a), nat(nil).setUint64(b)) } if stk == nil { stk = getStack() defer stk.free() } m := a + (b-a)/2 // avoid overflow return z.mul(stk, nat(nil).mulRange(stk, a, m), nat(nil).mulRange(stk, m+1, b)) } // A stack provides temporary storage for complex calculations // such as multiplication and division. // The stack is a simple slice of words, extended as needed // to hold all the temporary storage for a calculation. // In general, if a function takes a *stack, it expects a non-nil *stack. // However, certain functions may allow passing a nil *stack instead, // so that they can handle trivial stack-free cases without forcing the // caller to obtain and free a stack that will be unused. These functions // document that they accept a nil *stack in their doc comments. type stack struct { w []Word } var stackPool sync.Pool // getStack returns a temporary stack. // The caller must call [stack.free] to give up use of the stack when finished. func getStack() *stack { s, _ := stackPool.Get().(*stack) if s == nil { s = new(stack) } return s } // free returns the stack for use by another calculation. func (s *stack) free() { s.w = s.w[:0] stackPool.Put(s) } // save returns the current stack pointer. // A future call to restore with the same value // frees any temporaries allocated on the stack after the call to save. func (s *stack) save() int { return len(s.w) } // restore restores the stack pointer to n. // It is almost always invoked as // // defer stk.restore(stk.save()) // // which makes sure to pop any temporaries allocated in the current function // from the stack before returning. func (s *stack) restore(n int) { s.w = s.w[:n] } // nat returns a nat of n words, allocated on the stack. 
func (s *stack) nat(n int) nat {
	nr := (n + 3) &^ 3 // round up to multiple of 4
	off := len(s.w)
	s.w = slices.Grow(s.w, nr)
	s.w = s.w[:off+nr]
	// Limit the result's capacity to n so that growing it (append or make
	// beyond n) reallocates instead of spilling into later stack words.
	x := s.w[off : off+n : off+n]
	if n > 0 {
		// NOTE(review): 0xfedcb appears to be a marker making the words
		// visibly non-zero, so callers cannot rely on cleared memory — confirm.
		x[0] = 0xfedcb
	}
	return x
}

// bitLen returns the length of x in bits.
// Unlike most methods, it works even if x is not normalized.
func (x nat) bitLen() int {
	// This function is used in cryptographic operations. It must not leak
	// anything but the Int's sign and bit size through side-channels. Any
	// changes must be reviewed by a security expert.
	if i := len(x) - 1; i >= 0 {
		// bits.Len uses a lookup table for the low-order bits on some
		// architectures. Neutralize any input-dependent behavior by setting all
		// bits after the first one bit.
		top := uint(x[i])
		top |= top >> 1
		top |= top >> 2
		top |= top >> 4
		top |= top >> 8
		top |= top >> 16
		top |= top >> 16 >> 16 // ">> 32" doesn't compile on 32-bit architectures
		return i*_W + bits.Len(top)
	}
	return 0
}

// trailingZeroBits returns the number of consecutive least significant zero
// bits of x.
// NOTE(review): a non-empty x whose words are all zero would make the scan
// below index past the end; callers presumably pass normalized x.
func (x nat) trailingZeroBits() uint {
	if len(x) == 0 {
		return 0
	}
	// Skip whole zero words first.
	var i uint
	for x[i] == 0 {
		i++
	}
	// x[i] != 0
	return i*_W + uint(bits.TrailingZeros(uint(x[i])))
}

// isPow2 returns i, true when x == 2**i and 0, false otherwise.
func (x nat) isPow2() (uint, bool) {
	// Locate the least significant non-zero word.
	// NOTE(review): assumes x != 0; an empty or all-zero x would index
	// out of range here.
	var i uint
	for x[i] == 0 {
		i++
	}
	// For normalized x, x is a power of two exactly when that word is also
	// the top word and has a single bit set.
	if i == uint(len(x))-1 && x[i]&(x[i]-1) == 0 {
		return i*_W + uint(bits.TrailingZeros(uint(x[i]))), true
	}
	return 0, false
}

// same reports whether x and y are the same non-empty slice: equal
// lengths and the same first element.
func same(x, y nat) bool {
	return len(x) == len(y) && len(x) > 0 && &x[0] == &y[0]
}

// z = x << s
//
// The result is normalized. For s == 0, z is returned as is when it is the
// same slice as x and is set to a copy of x when it does not overlap x. Any
// other overlap falls through to the general path.
func (z nat) lsh(x nat, s uint) nat {
	if s == 0 {
		if same(z, x) {
			return z
		}
		if !alias(z, x) {
			return z.set(x)
		}
	}

	m := len(x)
	if m == 0 {
		return z[:0]
	}
	// m > 0

	// The words of x land in z[n-m:n]; z[n] takes the bits shifted out of
	// the top (assuming lshVU returns them as its result).
	n := m + int(s/_W)
	z = z.make(n + 1)
	if s %= _W; s == 0 {
		copy(z[n-m:n], x)
		z[n] = 0
	} else {
		z[n] = lshVU(z[n-m:n], x, s)
	}
	// Zero the low words vacated by the whole-word part of the shift.
	clear(z[0 : n-m])

	return z.norm()
}

// z = x >> s
//
// The result is normalized; shifting by at least len(x) words yields 0.
func (z nat) rsh(x nat, s uint) nat {
	if s == 0 {
		if same(z, x) {
			return z
		}
		if !alias(z, x) {
			return z.set(x)
		}
	}

	m := len(x)
	n := m - int(s/_W)
	if n <= 0 {
		return z[:0]
	}
	// n > 0

	// Drop the low s/_W words of x, then shift the rest by the remainder.
	z = z.make(n)
	if s %= _W; s == 0 {
		copy(z, x[m-n:])
	} else {
		rshVU(z, x[m-n:], s)
	}

	return z.norm()
}

// setBit sets z to x with bit i set to b and returns z.
// b must be 0 or 1; any other value panics.
func (z nat) setBit(x nat, i uint, b uint) nat {
	j := int(i / _W)
	m := Word(1) << (i % _W)
	n := len(x)
	switch b {
	case 0:
		z = z.make(n)
		copy(z, x)
		if j >= n {
			// no need to grow
			return z
		}
		z[j] &^= m
		return z.norm()
	case 1:
		if j >= n {
			// Grow z to hold word j, zeroing the words beyond x.
			z = z.make(j + 1)
			clear(z[n:])
		} else {
			z = z.make(n)
		}
		copy(z, x)
		z[j] |= m
		// no need to normalize
		return z
	}
	panic("set bit is not 0 or 1")
}

// bit returns the value of the i'th bit, with lsb == bit 0.
func (x nat) bit(i uint) uint {
	j := i / _W
	if j >= uint(len(x)) {
		return 0
	}
	// 0 <= j < len(x)
	return uint(x[j] >> (i % _W) & 1)
}

// sticky returns 1 if there's a 1 bit within the
// i least significant bits, otherwise it returns 0.
func (x nat) sticky(i uint) uint { j := i / _W if j >= uint(len(x)) { if len(x) == 0 { return 0 } return 1 } // 0 <= j < len(x) for _, x := range x[:j] { if x != 0 { return 1 } } if x[j]<<(_W-i%_W) != 0 { return 1 } return 0 } func (z nat) and(x, y nat) nat { m := len(x) n := len(y) if m > n { m = n } // m <= n z = z.make(m) for i := 0; i < m; i++ { z[i] = x[i] & y[i] } return z.norm() } // trunc returns z = x mod 2ⁿ. func (z nat) trunc(x nat, n uint) nat { w := (n + _W - 1) / _W if uint(len(x)) < w { return z.set(x) } z = z.make(int(w)) copy(z, x) if n%_W != 0 { z[len(z)-1] &= 1<<(n%_W) - 1 } return z.norm() } func (z nat) andNot(x, y nat) nat { m := len(x) n := len(y) if n > m { n = m } // m >= n z = z.make(m) for i := 0; i < n; i++ { z[i] = x[i] &^ y[i] } copy(z[n:m], x[n:m]) return z.norm() } func (z nat) or(x, y nat) nat { m := len(x) n := len(y) s := x if m < n { n, m = m, n s = y } // m >= n z = z.make(m) for i := 0; i < n; i++ { z[i] = x[i] | y[i] } copy(z[n:m], s[n:m]) return z.norm() } func (z nat) xor(x, y nat) nat { m := len(x) n := len(y) s := x if m < n { n, m = m, n s = y } // m >= n z = z.make(m) for i := 0; i < n; i++ { z[i] = x[i] ^ y[i] } copy(z[n:m], s[n:m]) return z.norm() } // random creates a random integer in [0..limit), using the space in z if // possible. n is the bit length of limit. func (z nat) random(rand *rand.Rand, limit nat, n int) nat { if alias(z, limit) { z = nil // z is an alias for limit - cannot reuse } z = z.make(len(limit)) bitLengthOfMSW := uint(n % _W) if bitLengthOfMSW == 0 { bitLengthOfMSW = _W } mask := Word((1 << bitLengthOfMSW) - 1) for { switch _W { case 32: for i := range z { z[i] = Word(rand.Uint32()) } case 64: for i := range z { z[i] = Word(rand.Uint32()) | Word(rand.Uint32())<<32 } default: panic("unknown word size") } z[len(limit)-1] &= mask if z.cmp(limit) < 0 { break } } return z.norm() } // If m != 0 (i.e., len(m) != 0), expNN sets z to x**y mod m; // otherwise it sets z to x**y. 
// The result is the value of z.
// The caller may pass stk == nil to request that expNN obtain and release one itself.
// When slow is set, the Montgomery and windowed fast paths are bypassed in
// favor of the plain square-and-multiply loop.
func (z nat) expNN(stk *stack, x, y, m nat, slow bool) nat {
	if alias(z, x) || alias(z, y) {
		// We cannot allow in-place modification of x or y.
		z = nil
	}

	// x**y mod 1 == 0
	if len(m) == 1 && m[0] == 1 {
		return z.setWord(0)
	}
	// m == 0 || m > 1

	// x**0 == 1
	if len(y) == 0 {
		return z.setWord(1)
	}
	// y > 0

	// 0**y = 0
	if len(x) == 0 {
		return z.setWord(0)
	}
	// x > 0

	// 1**y = 1
	if len(x) == 1 && x[0] == 1 {
		return z.setWord(1)
	}
	// x > 1

	// x**1 == x
	if len(y) == 1 && y[0] == 1 && len(m) == 0 {
		return z.set(x)
	}
	// Every remaining path passes stk on, so make sure one exists.
	if stk == nil {
		stk = getStack()
		defer stk.free()
	}
	if len(y) == 1 && y[0] == 1 { // len(m) > 0
		return z.rem(stk, x, m)
	}
	// y > 1

	if len(m) != 0 {
		// We likely end up being as long as the modulus.
		z = z.make(len(m))

		// If the exponent is large, we use the Montgomery method for odd values,
		// and a 4-bit, windowed exponentiation for powers of two,
		// and a CRT-decomposed Montgomery method for the remaining values
		// (even values times non-trivial odd values, which decompose into one
		// instance of each of the first two cases).
		if len(y) > 1 && !slow {
			if m[0]&1 == 1 {
				return z.expNNMontgomery(stk, x, y, m)
			}
			if logM, ok := m.isPow2(); ok {
				return z.expNNWindowed(stk, x, y, logM)
			}
			return z.expNNMontgomeryEven(stk, x, y, m)
		}
	}

	z = z.set(x)
	v := y[len(y)-1] // v > 0 because y is normalized and y > 0
	// Shift out the leading zeros of v and its top 1 bit: z = x already
	// accounts for that bit.
	shift := nlz(v) + 1
	v <<= shift
	var q nat

	// mask selects the most significant bit of a Word.
	const mask = 1 << (_W - 1)

	// We walk through the bits of the exponent one by one. Each time we
	// see a bit, we square, thus doubling the power. If the bit is a one,
	// we also multiply by x, thus adding one to the power.

	// w is the number of bits left in the top word of y.
	w := _W - int(shift)
	// zz and r are used to avoid allocating in mul and div as
	// otherwise the arguments would alias.
	var zz, r nat
	for j := 0; j < w; j++ {
		zz = zz.sqr(stk, z)
		zz, z = z, zz

		if v&mask != 0 {
			zz = zz.mul(stk, z, x)
			zz, z = z, zz
		}

		if len(m) != 0 {
			// Reduce mod m; the rotation recycles the buffers.
			zz, r = zz.div(stk, r, z, m)
			zz, r, q, z = q, z, zz, r
		}

		v <<= 1
	}

	// Remaining (full) words of y, most significant first.
	for i := len(y) - 2; i >= 0; i-- {
		v = y[i]

		for j := 0; j < _W; j++ {
			zz = zz.sqr(stk, z)
			zz, z = z, zz

			if v&mask != 0 {
				zz = zz.mul(stk, z, x)
				zz, z = z, zz
			}

			if len(m) != 0 {
				zz, r = zz.div(stk, r, z, m)
				zz, r, q, z = q, z, zz, r
			}

			v <<= 1
		}
	}

	return z.norm()
}

// expNNMontgomeryEven calculates x**y mod m where m = m1 × m2 for m1 = 2ⁿ and m2 odd.
// It uses two recursive calls to expNN for x**y mod m1 and x**y mod m2
// and then uses the Chinese Remainder Theorem to combine the results.
// The recursive call using m1 will use expNNWindowed,
// while the recursive call using m2 will use expNNMontgomery.
// For more details, see Ç. K. Koç, “Montgomery Reduction with Even Modulus”,
// IEE Proceedings: Computers and Digital Techniques, 141(5) 314-316, September 1994.
// http://www.people.vcu.edu/~jwang3/CMSC691/j34monex.pdf
func (z nat) expNNMontgomeryEven(stk *stack, x, y, m nat) nat {
	// Split m = m₁ × m₂ where m₁ = 2ⁿ
	n := m.trailingZeroBits()
	m1 := nat(nil).lsh(natOne, n)
	m2 := nat(nil).rsh(m, n)

	// We want z = x**y mod m.
	// z₁ = x**y mod m1 = (x**y mod m) mod m1 = z mod m1
	// z₂ = x**y mod m2 = (x**y mod m) mod m2 = z mod m2
	// (We are using the math/big convention for names here,
	// where the computation is z = x**y mod m, so its parts are z1 and z2.
	// The paper is computing x = a**e mod n; it refers to these as x2 and z1.)
	z1 := nat(nil).expNN(stk, x, y, m1, false)
	z2 := nat(nil).expNN(stk, x, y, m2, false)

	// Reconstruct z from z₁, z₂ using CRT, using algorithm from paper,
	// which uses only a single modInverse (and an easy one at that).
	//	p = (z₁ - z₂) × m₂⁻¹ (mod m₁)
	//	z = z₂ + p × m₂
	// The final addition is in range because:
	//	z = z₂ + p × m₂
	//	  ≤ z₂ + (m₁-1) × m₂
	//	  < m₂ + (m₁-1) × m₂
	//	  = m₁ × m₂
	//	  = m.
	z = z.set(z2)

	// Compute (z₁ - z₂) mod m1 [m1 == 2**n] into z1.
	z1 = z1.subMod2N(z1, z2, n)

	// Reuse z2 for p = (z₁ - z₂) [in z1] * m2⁻¹ (mod m₁ [= 2ⁿ]).
	m2inv := nat(nil).modInverse(m2, m1)
	z2 = z2.mul(stk, z1, m2inv)
	z2 = z2.trunc(z2, n)

	// Reuse z1 for p * m2.
	z = z.add(z, z1.mul(stk, z2, m2))

	return z
}

// expNNWindowed calculates x**y mod m using a fixed, 4-bit window,
// where m = 2**logM.
// It panics if len(y) <= 1; x is expected to be non-empty.
func (z nat) expNNWindowed(stk *stack, x, y nat, logM uint) nat {
	if len(y) <= 1 {
		panic("big: misuse of expNNWindowed")
	}
	if x[0]&1 == 0 {
		// len(y) > 1, so y > logM.
		// x is even, so x**y is a multiple of 2**y which is a multiple of 2**logM.
		return z.setWord(0)
	}
	if logM == 1 {
		return z.setWord(1)
	}

	// zz is used to avoid allocating in mul as otherwise
	// the arguments would alias.
	defer stk.restore(stk.save())
	w := int((logM + _W - 1) / _W)
	zz := stk.nat(w)

	const n = 4
	// powers[i] contains x^i.
	var powers [1 << n]nat
	for i := range powers {
		powers[i] = stk.nat(w)
	}
	powers[0] = powers[0].set(natOne)
	powers[1] = powers[1].trunc(x, logM)
	// Fill the table pairwise: x^i = (x^(i/2))² and x^(i+1) = x^i · x.
	for i := 2; i < 1<<n; i += 2 {
		p2, p, p1 := &powers[i/2], &powers[i], &powers[i+1]
		*p = p.sqr(stk, *p2)
		*p = p.trunc(*p, logM)
		*p1 = p1.mul(stk, *p, x)
		*p1 = p1.trunc(*p1, logM)
	}

	// Because phi(2**logM) = 2**(logM-1), x**(2**(logM-1)) = 1,
	// so we can compute x**(y mod 2**(logM-1)) instead of x**y.
	// That is, we can throw away all but the bottom logM-1 bits of y.
	// Instead of allocating a new y, we start reading y at the right word
	// and truncate it appropriately at the start of the loop.
	i := len(y) - 1
	mtop := int((logM - 2) / _W) // -2 because the top word of N bits is the (N-1)/W'th word.
	mmask := ^Word(0)
	if mbits := (logM - 1) & (_W - 1); mbits != 0 {
		mmask = (1 << mbits) - 1
	}
	if i > mtop {
		i = mtop
	}
	// advance is false only before the very first window, when z is 1 and
	// squaring would be wasted work.
	advance := false
	z = z.setWord(1)
	for ; i >= 0; i-- {
		yi := y[i]
		if i == mtop {
			yi &= mmask
		}
		for j := 0; j < _W; j += n {
			if advance {
				// Account for use of 4 bits in previous iteration.
				// Unrolled loop for significant performance
				// gain. Use go test -bench=".*" in crypto/rsa
				// to check performance before making changes.
				zz = zz.sqr(stk, z)
				zz, z = z, zz
				z = z.trunc(z, logM)

				zz = zz.sqr(stk, z)
				zz, z = z, zz
				z = z.trunc(z, logM)

				zz = zz.sqr(stk, z)
				zz, z = z, zz
				z = z.trunc(z, logM)

				zz = zz.sqr(stk, z)
				zz, z = z, zz
				z = z.trunc(z, logM)
			}

			// Multiply by x^(top 4 bits of yi).
			zz = zz.mul(stk, z, powers[yi>>(_W-n)])
			zz, z = z, zz
			z = z.trunc(z, logM)

			yi <<= n
			advance = true
		}
	}

	return z.norm()
}

// expNNMontgomery calculates x**y mod m using a fixed, 4-bit window.
// Uses Montgomery representation.
// It reads m[0] below, so m must be non-empty.
func (z nat) expNNMontgomery(stk *stack, x, y, m nat) nat {
	numWords := len(m)

	// We want the lengths of x and m to be equal.
	// It is OK if x >= m as long as len(x) == len(m).
	if len(x) > numWords {
		_, x = nat(nil).div(stk, nil, x, m)
		// Note: now len(x) <= numWords, not guaranteed ==.
	}
	if len(x) < numWords {
		rr := make(nat, numWords)
		copy(rr, x)
		x = rr
	}

	// Ideally the precomputations would be performed outside, and reused
	// k0 = -m**-1 mod 2**_W. Algorithm from: Dumas, J.G. "On Newton–Raphson
	// Iteration for Multiplicative Inverses Modulo Prime Powers".
	k0 := 2 - m[0]
	t := m[0] - 1
	for i := 1; i < _W; i <<= 1 {
		t *= t
		k0 *= (t + 1)
	}
	k0 = -k0

	// RR = 2**(2*_W*len(m)) mod m
	RR := nat(nil).setWord(1)
	zz := nat(nil).lsh(RR, uint(2*numWords*_W))
	_, RR = nat(nil).div(stk, RR, zz, m)
	if len(RR) < numWords {
		// Pad RR to numWords, as montgomery requires equal lengths.
		zz = zz.make(numWords)
		copy(zz, RR)
		RR = zz
	}
	// one = 1, with equal length to that of m
	one := make(nat, numWords)
	one[0] = 1

	const n = 4
	// powers[i] contains x^i
	// (in Montgomery form: multiplying by RR converts into it).
	var powers [1 << n]nat
	powers[0] = powers[0].montgomery(one, RR, m, k0, numWords)
	powers[1] = powers[1].montgomery(x, RR, m, k0, numWords)
	for i := 2; i < 1<<n; i++ {
		powers[i] = powers[i].montgomery(powers[i-1], powers[1], m, k0, numWords)
	}

	// initialize z = 1 (Montgomery 1)
	z = z.make(numWords)
	copy(z, powers[0])

	zz = zz.make(numWords)

	// same windowed exponent, but with Montgomery multiplications
	for i := len(y) - 1; i >= 0; i-- {
		yi := y[i]
		for j := 0; j < _W; j += n {
			// Four squarings shift the accumulated exponent left by 4 bits;
			// skipped for the first window, where z is still 1.
			if i != len(y)-1 || j != 0 {
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
				zz = zz.montgomery(z, z, m, k0, numWords)
				z = z.montgomery(zz, zz, m, k0, numWords)
			}
			zz = zz.montgomery(z, powers[yi>>(_W-n)], m, k0, numWords)
			z, zz = zz, z
			yi <<= n
		}
	}
	// convert to regular number
	zz = zz.montgomery(z, one, m, k0, numWords)

	// One last reduction, just in case.
	// See golang.org/issue/13907.
	if zz.cmp(m) >= 0 {
		// Common case is m has high bit set; in that case,
		// since zz is the same length as m, there can be just
		// one multiple of m to remove. Just subtract.
		// We think that the subtract should be sufficient in general,
		// so do that unconditionally, but double-check,
		// in case our beliefs are wrong.
		// The div is not expected to be reached.
		zz = zz.sub(zz, m)
		if zz.cmp(m) >= 0 {
			_, zz = nat(nil).div(stk, nil, zz, m)
		}
	}

	return zz.norm()
}

// bytes writes the value of z into buf using big-endian encoding.
// The value of z is encoded in the slice buf[i:]. If the value of z
// cannot be represented in buf, bytes panics.
// The number i of unused
// bytes at the beginning of buf is returned as result.
func (z nat) bytes(buf []byte) (i int) {
	// This function is used in cryptographic operations. It must not leak
	// anything but the Int's sign and bit size through side-channels. Any
	// changes must be reviewed by a security expert.
	// Words are emitted least significant first, filling buf from the end.
	i = len(buf)
	for _, d := range z {
		for j := 0; j < _S; j++ {
			i--
			if i >= 0 {
				buf[i] = byte(d)
			} else if byte(d) != 0 {
				panic("math/big: buffer too small to fit value")
			}
			d >>= 8
		}
	}

	if i < 0 {
		i = 0
	}
	// Skip leading zero bytes so that buf[i:] has no leading zeros.
	for i < len(buf) && buf[i] == 0 {
		i++
	}

	return
}

// bigEndianWord returns the contents of buf interpreted as a big-endian encoded Word value.
// buf must hold at least _S bytes.
func bigEndianWord(buf []byte) Word {
	if _W == 64 {
		return Word(byteorder.BEUint64(buf))
	}
	return Word(byteorder.BEUint32(buf))
}

// setBytes interprets buf as the bytes of a big-endian unsigned
// integer, sets z to that value, and returns z.
func (z nat) setBytes(buf []byte) nat {
	z = z.make((len(buf) + _S - 1) / _S)

	// Consume full words from the end of buf (least significant first).
	i := len(buf)
	for k := 0; i >= _S; k++ {
		z[k] = bigEndianWord(buf[i-_S : i])
		i -= _S
	}
	if i > 0 {
		// Assemble the remaining i < _S leading bytes into the top word.
		var d Word
		for s := uint(0); i > 0; s += 8 {
			d |= Word(buf[i-1]) << s
			i--
		}
		z[len(z)-1] = d
	}

	return z.norm()
}

// sqrt sets z = ⌊√x⌋
// The caller may pass stk == nil to request that sqrt obtain and release one itself.
func (z nat) sqrt(stk *stack, x nat) nat {
	if x.cmp(natOne) <= 0 {
		return z.set(x)
	}
	if alias(z, x) {
		z = nil
	}

	if stk == nil {
		stk = getStack()
		defer stk.free()
	}

	// Start with value known to be too large and repeat "z = ⌊(z + ⌊x/z⌋)/2⌋" until it stops getting smaller.
	// See Brent and Zimmermann, Modern Computer Arithmetic, Algorithm 1.13 (SqrtInt).
	// https://members.loria.fr/PZimmermann/mca/pub226.html
	// If x is one less than a perfect square, the sequence oscillates between the correct z and z+1;
	// otherwise it converges to the correct z and stays there.
	var z1, z2 nat
	z1 = z
	z1 = z1.setUint64(1)
	z1 = z1.lsh(z1, uint(x.bitLen()+1)/2) // must be ≥ √x
	for n := 0; ; n++ {
		z2, _ = z2.div(stk, nil, x, z1)
		z2 = z2.add(z2, z1)
		z2 = z2.rsh(z2, 1)
		if z2.cmp(z1) >= 0 {
			// z1 is answer.
			// Figure out whether z1 or z2 is currently aliased to z by looking at loop count.
			if n&1 == 0 {
				return z1
			}
			return z.set(z1)
		}
		// Swap buffers so the new estimate becomes z1.
		z1, z2 = z2, z1
	}
}

// subMod2N returns z = (x - y) mod 2ⁿ.
func (z nat) subMod2N(x, y nat, n uint) nat {
	// Reduce each operand mod 2ⁿ first, truncating in place only when the
	// operand shares storage with z.
	if uint(x.bitLen()) > n {
		if alias(z, x) {
			// ok to overwrite x in place
			x = x.trunc(x, n)
		} else {
			x = nat(nil).trunc(x, n)
		}
	}
	if uint(y.bitLen()) > n {
		if alias(z, y) {
			// ok to overwrite y in place
			y = y.trunc(y, n)
		} else {
			y = nat(nil).trunc(y, n)
		}
	}
	if x.cmp(y) >= 0 {
		return z.sub(x, y)
	}
	// x - y < 0; x - y mod 2ⁿ = x - y + 2ⁿ = 2ⁿ - (y - x) = 1 + 2ⁿ-1 - (y - x) = 1 + ^(y - x).
	z = z.sub(y, x)
	// Extend with zero words so the complement covers all n bits.
	for uint(len(z))*_W < n {
		z = append(z, 0)
	}
	for i := range z {
		z[i] = ^z[i]
	}
	z = z.trunc(z, n)
	return z.add(z, natOne)
}
// Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This file implements nat-to-string conversion functions. package big import ( "errors" "fmt" "io" "math" "math/bits" "slices" "sync" ) const digits = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" // Note: MaxBase = len(digits), but it must remain an untyped rune constant // for API compatibility. // MaxBase is the largest number base accepted for string conversions. const MaxBase = 10 + ('z' - 'a' + 1) + ('Z' - 'A' + 1) const maxBaseSmall = 10 + ('z' - 'a' + 1) // maxPow returns (b**n, n) such that b**n is the largest power b**n <= _M. // For instance maxPow(10) == (1e19, 19) for 19 decimal digits in a 64bit Word. // In other words, at most n digits in base b fit into a Word. // TODO(gri) replace this with a table, generated at build time. func maxPow(b Word) (p Word, n int) { p, n = b, 1 // assuming b <= _M for max := _M / b; p <= max; { // p == b**n && p <= max p *= b n++ } // p == b**n && p <= _M return } // pow returns x**n for n > 0, and 1 otherwise. func pow(x Word, n int) (p Word) { // n == sum of bi * 2**i, for 0 <= i < imax, and bi is 0 or 1 // thus x**n == product of x**(2**i) for all i where bi == 1 // (Russian Peasant Method for exponentiation) p = 1 for n > 0 { if n&1 != 0 { p *= x } x *= x n >>= 1 } return } // scan errors var ( errNoDigits = errors.New("number has no digits") errInvalSep = errors.New("'_' must separate successive digits") ) // scan scans the number corresponding to the longest possible prefix // from r representing an unsigned number in a given conversion base. // scan returns the corresponding natural number res, the actual base b, // a digit count, and a read or syntax error err, if any. 
//
// For base 0, an underscore character “_” may appear between a base
// prefix and an adjacent digit, and between successive digits; such
// underscores do not change the value of the number, or the returned
// digit count. Incorrect placement of underscores is reported as an
// error if there are no other errors. If base != 0, underscores are
// not recognized and thus terminate scanning like any other character
// that is not a valid radix point or digit.
//
//	number    = mantissa | prefix pmantissa .
//	prefix    = "0" [ "b" | "B" | "o" | "O" | "x" | "X" ] .
//	mantissa  = digits "." [ digits ] | digits | "." digits .
//	pmantissa = [ "_" ] digits "." [ digits ] | [ "_" ] digits | "." digits .
//	digits    = digit { [ "_" ] digit } .
//	digit     = "0" ... "9" | "a" ... "z" | "A" ... "Z" .
//
// Unless fracOk is set, the base argument must be 0 or a value between
// 2 and MaxBase. If fracOk is set, the base argument must be one of
// 0, 2, 8, 10, or 16. Providing an invalid base argument leads to a run-
// time panic.
//
// For base 0, the number prefix determines the actual base: A prefix of
// “0b” or “0B” selects base 2, “0o” or “0O” selects base 8, and
// “0x” or “0X” selects base 16. If fracOk is false, a “0” prefix
// (immediately followed by digits) selects base 8 as well. Otherwise,
// the selected base is 10 and no prefix is accepted.
//
// If fracOk is set, a period followed by a fractional part is permitted.
// The result value is computed as if there were no period present; and
// the count value is used to determine the fractional part.
//
// For bases <= 36, lower and upper case letters are considered the same:
// The letters 'a' to 'z' and 'A' to 'Z' represent digit values 10 to 35.
// For bases > 36, the upper case letters 'A' to 'Z' represent the digit
// values 36 to 61.
//
// A result digit count > 0 corresponds to the number of (non-prefix) digits
// parsed. A digit count <= 0 indicates the presence of a period (if fracOk
// is set, only), and -count is the number of fractional digits found.
// In this case, the actual value of the scanned number is res * b**count.
func (z nat) scan(r io.ByteScanner, base int, fracOk bool) (res nat, b, count int, err error) {
	// Reject invalid bases.
	baseOk := base == 0 ||
		!fracOk && 2 <= base && base <= MaxBase ||
		fracOk && (base == 2 || base == 8 || base == 10 || base == 16)
	if !baseOk {
		panic(fmt.Sprintf("invalid number base %d", base))
	}

	// prev encodes the previously seen char: it is one
	// of '_', '0' (a digit), or '.' (anything else). A
	// valid separator '_' may only occur after a digit
	// and if base == 0.
	prev := '.'
	invalSep := false

	// one char look-ahead
	ch, err := r.ReadByte()

	// Determine actual base.
	b, prefix := base, 0
	if base == 0 {
		// Actual base is 10 unless there's a base prefix.
		b = 10
		if err == nil && ch == '0' {
			prev = '0'
			count = 1
			ch, err = r.ReadByte()
			if err == nil {
				// possibly one of 0b, 0B, 0o, 0O, 0x, 0X
				switch ch {
				case 'b', 'B':
					b, prefix = 2, 'b'
				case 'o', 'O':
					b, prefix = 8, 'o'
				case 'x', 'X':
					b, prefix = 16, 'x'
				default:
					if !fracOk {
						b, prefix = 8, '0'
					}
				}
				if prefix != 0 {
					count = 0 // prefix is not counted
					if prefix != '0' {
						// Consume the prefix letter; a bare '0' prefix has
						// no letter, so ch is already the next character.
						ch, err = r.ReadByte()
					}
				}
			}
		}
	}

	// Convert string.
	// Algorithm: Collect digits in groups of at most n digits in di.
	// For bases that pack exactly into words (2, 4, 16), append di's
	// directly to the int representation and then reverse at the end (bn==0 marks this case).
	// For other bases, use mulAddWW for every such group to shift
	// z up one group and add di to the result.
	// With more cleverness we could also handle binary bases like 8 and 32
	// (corresponding to 3-bit and 5-bit chunks) that don't pack nicely into
	// words, but those are not too important.
	z = z[:0]
	b1 := Word(b)
	var bn Word // b1**n (or 0 for the special bit-packing cases b=2,4,16)
	var n int   // max digits that fit into Word
	switch b {
	case 2: // 1 bit per digit
		n = _W
	case 4: // 2 bits per digit
		n = _W / 2
	case 16: // 4 bits per digit
		n = _W / 4
	default:
		bn, n = maxPow(b1)
	}
	di := Word(0) // 0 <= di < b1**i < bn
	i := 0        // 0 <= i < n
	dp := -1      // position of decimal point
	for err == nil {
		if ch == '.' && fracOk {
			fracOk = false
			if prev == '_' {
				invalSep = true
			}
			prev = '.'
			dp = count
		} else if ch == '_' && base == 0 {
			if prev != '0' {
				invalSep = true
			}
			prev = '_'
		} else {
			// convert rune into digit value d1
			var d1 Word
			switch {
			case '0' <= ch && ch <= '9':
				d1 = Word(ch - '0')
			case 'a' <= ch && ch <= 'z':
				d1 = Word(ch - 'a' + 10)
			case 'A' <= ch && ch <= 'Z':
				if b <= maxBaseSmall {
					d1 = Word(ch - 'A' + 10)
				} else {
					d1 = Word(ch - 'A' + maxBaseSmall)
				}
			default:
				d1 = MaxBase + 1
			}
			if d1 >= b1 {
				r.UnreadByte() // ch does not belong to number anymore
				break          // exits the enclosing for loop
			}
			prev = '0'
			count++

			// collect d1 in di
			di = di*b1 + d1
			i++

			// if di is "full", add it to the result
			if i == n {
				if bn == 0 {
					z = append(z, di)
				} else {
					z = z.mulAddWW(z, bn, di)
				}
				di = 0
				i = 0
			}
		}

		ch, err = r.ReadByte()
	}

	if err == io.EOF {
		err = nil
	}

	// other errors take precedence over invalid separators
	if err == nil && (invalSep || prev == '_') {
		err = errInvalSep
	}

	if count == 0 {
		// no digits found
		if prefix == '0' {
			// there was only the octal prefix 0 (possibly followed by separators and digits > 7);
			// interpret as decimal 0
			return z[:0], 10, 1, err
		}
		err = errNoDigits // fall through; result will be 0
	}

	if bn == 0 {
		if i > 0 {
			// Add remaining digit chunk to result.
			// Left-justify group's digits; will shift back down after reverse.
			z = append(z, di*pow(b1, n-i))
		}
		// Words were appended most significant first; restore nat order.
		slices.Reverse(z)
		z = z.norm()
		if i > 0 {
			z = z.rsh(z, uint(n-i)*uint(_W/n))
		}
	} else {
		if i > 0 {
			// Add remaining digit chunk to result.
			z = z.mulAddWW(z, pow(b1, i), di)
		}
	}
	res = z

	// adjust count for fraction, if any
	if dp >= 0 {
		// 0 <= dp <= count
		count = dp - count
	}

	return
}

// utoa converts x to an ASCII representation in the given base;
// base must be between 2 and MaxBase, inclusive.
func (x nat) utoa(base int) []byte {
	return x.itoa(false, base)
}

// itoa is like utoa but it prepends a '-' if neg && x != 0.
// It panics with "invalid base" if base is outside [2, MaxBase].
func (x nat) itoa(neg bool, base int) []byte {
	if base < 2 || base > MaxBase {
		panic("invalid base")
	}

	// x == 0
	if len(x) == 0 {
		return []byte("0")
	}
	// len(x) > 0

	// allocate buffer for conversion
	i := int(float64(x.bitLen())/math.Log2(float64(base))) + 1 // off by 1 at most
	if neg {
		i++
	}
	s := make([]byte, i)

	// convert power of two and non power of two bases separately
	if b := Word(base); b == b&-b {
		// shift is base b digit size in bits
		shift := uint(bits.TrailingZeros(uint(b))) // shift > 0 because b >= 2
		mask := Word(1<<shift - 1)
		w := x[0]         // current word
		nbits := uint(_W) // number of unprocessed bits in w

		// convert less-significant words (include leading zeros)
		for k := 1; k < len(x); k++ {
			// convert full digits
			for nbits >= shift {
				i--
				s[i] = digits[w&mask]
				w >>= shift
				nbits -= shift
			}

			// convert any partial leading digit and advance to next word
			if nbits == 0 {
				// no partial digit remaining, just advance
				w = x[k]
				nbits = _W
			} else {
				// partial digit in current word w (== x[k-1]) and next word x[k]
				w |= x[k] << nbits
				i--
				s[i] = digits[w&mask]

				// advance
				w = x[k] >> (shift - nbits)
				nbits = _W - (shift - nbits)
			}
		}

		// convert digits of most-significant word w (omit leading zeros)
		for w != 0 {
			i--
			s[i] = digits[w&mask]
			w >>= shift
		}

	} else {
		stk := getStack()
		defer stk.free()

		bb, ndigits := maxPow(b)

		// construct table of successive squares of bb*leafSize to use in subdivisions
		// result (table != nil) <=> (len(x) > leafSize > 0)
		table := divisors(stk, len(x), b, ndigits, bb)

		// preserve x, create local copy for use by convertWords
		q := nat(nil).set(x)

		// convert q to string s in base b
		q.convertWords(stk, s, b, ndigits, bb, table)

		// strip leading zeros
		// (x != 0; thus s must contain at least one non-zero digit
		// and the loop will terminate)
		i = 0
		for s[i] == '0' {
			i++
		}
	}

	if neg {
		i--
		s[i] = '-'
	}

	return s[i:]
}

// Convert words of q to base b digits in s. If q is large, it is recursively "split in half"
// by nat/nat division using tabulated divisors. Otherwise, it is converted iteratively using
// repeated nat/Word division.
//
// The iterative method processes n Words by n divW() calls, each of which visits every Word in the
// incrementally shortened q for a total of n + (n-1) + (n-2) ... + 2 + 1, or n(n+1)/2 divW()'s.
// Recursive conversion divides q by its approximate square root, yielding two parts, each half
// the size of q. Using the iterative method on both halves means 2 * (n/2)(n/2 + 1)/2 divW()'s
// plus the expensive long div(). Asymptotically, the ratio is favorable at 1/2 the divW()'s, and
// is made better by splitting the subblocks recursively. Best is to split blocks until one more
// split would take longer (because of the nat/nat div()) than the twice as many divW()'s of the
// iterative approach. This threshold is represented by leafSize. Benchmarking of leafSize in the
// range 2..64 shows that values of 8 and 16 work well, with a 4x speedup at medium lengths and
// ~30x for 20000 digits. Use nat_test.go's BenchmarkLeafSize tests to optimize leafSize for
// specific hardware.
func (q nat) convertWords(stk *stack, s []byte, b Word, ndigits int, bb Word, table []divisor) { // split larger blocks recursively if table != nil { // len(q) > leafSize > 0 var r nat index := len(table) - 1 for len(q) > leafSize { // find divisor close to sqrt(q) if possible, but in any case < q maxLength := q.bitLen() // ~= log2 q, or at of least largest possible q of this bit length minLength := maxLength >> 1 // ~= log2 sqrt(q) for index > 0 && table[index-1].nbits > minLength { index-- // desired } if table[index].nbits >= maxLength && table[index].bbb.cmp(q) >= 0 { index-- if index < 0 { panic("internal inconsistency") } } // split q into the two digit number (q'*bbb + r) to form independent subblocks q, r = q.div(stk, r, q, table[index].bbb) // convert subblocks and collect results in s[:h] and s[h:] h := len(s) - table[index].ndigits r.convertWords(stk, s[h:], b, ndigits, bb, table[0:index]) s = s[:h] // == q.convertWords(stk, s, b, ndigits, bb, table[0:index+1]) } } // having split any large blocks now process the remaining (small) block iteratively i := len(s) var r Word if b == 10 { // hard-coding for 10 here speeds this up by 1.25x (allows for / and % by constants) for len(q) > 0 { // extract least significant, base bb "digit" q, r = q.divW(q, bb) for j := 0; j < ndigits && i > 0; j++ { i-- // avoid % computation since r%10 == r - int(r/10)*10; // this appears to be faster for BenchmarkString10000Base10 // and smaller strings (but a bit slower for larger ones) t := r / 10 s[i] = '0' + byte(r-t*10) r = t } } } else { for len(q) > 0 { // extract least significant, base bb "digit" q, r = q.divW(q, bb) for j := 0; j < ndigits && i > 0; j++ { i-- s[i] = digits[r%b] r /= b } } } // prepend high-order zeros for i > 0 { // while need more leading zeros i-- s[i] = '0' } } // Split blocks greater than leafSize Words (or set to 0 to disable recursive conversion) // Benchmark and configure leafSize using: go test -bench="Leaf" // // 8 and 16 effective on 3.0 GHz 
Xeon "Clovertown" CPU (128 byte cache lines) // 8 and 16 effective on 2.66 GHz Core 2 Duo "Penryn" CPU var leafSize int = 8 // number of Word-size binary values treat as a monolithic block type divisor struct { bbb nat // divisor nbits int // bit length of divisor (discounting leading zeros) ~= log2(bbb) ndigits int // digit length of divisor in terms of output base digits } var cacheBase10 struct { sync.Mutex table [64]divisor // cached divisors for base 10 } // expWW computes x**y func (z nat) expWW(stk *stack, x, y Word) nat { return z.expNN(stk, nat(nil).setWord(x), nat(nil).setWord(y), nil, false) } // construct table of powers of bb*leafSize to use in subdivisions. func divisors(stk *stack, m int, b Word, ndigits int, bb Word) []divisor { // only compute table when recursive conversion is enabled and x is large if leafSize == 0 || m <= leafSize { return nil } // determine k where (bb**leafSize)**(2**k) >= sqrt(x) k := 1 for words := leafSize; words < m>>1 && k < len(cacheBase10.table); words <<= 1 { k++ } // reuse and extend existing table of divisors or create new table as appropriate var table []divisor // for b == 10, table overlaps with cacheBase10.table if b == 10 { cacheBase10.Lock() table = cacheBase10.table[0:k] // reuse old table for this conversion } else { table = make([]divisor, k) // create new table for this conversion } // extend table if table[k-1].ndigits == 0 { // add new entries as needed var larger nat for i := 0; i < k; i++ { if table[i].ndigits == 0 { if i == 0 { table[0].bbb = nat(nil).expWW(stk, bb, Word(leafSize)) table[0].ndigits = ndigits * leafSize } else { table[i].bbb = nat(nil).sqr(stk, table[i-1].bbb) table[i].ndigits = 2 * table[i-1].ndigits } // optimization: exploit aggregated extra bits in macro blocks larger = nat(nil).set(table[i].bbb) for mulAddVWW(larger, larger, b, 0) == 0 { table[i].bbb = table[i].bbb.set(larger) table[i].ndigits++ } table[i].nbits = table[i].bbb.bitLen() } } } if b == 10 { cacheBase10.Unlock() } 
return table }
// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. /* Multi-precision division. Here be dragons. Given u and v, where u is n+m digits, and v is n digits (with no leading zeros), the goal is to return quo, rem such that u = quo*v + rem, where 0 ≤ rem < v. That is, quo = ⌊u/v⌋ where ⌊x⌋ denotes the floor (truncation to integer) of x, and rem = u - quo·v. Long Division Division in a computer proceeds the same as long division in elementary school, but computers are not as good as schoolchildren at following vague directions, so we have to be much more precise about the actual steps and what can happen. We work from most to least significant digit of the quotient, doing: • Guess a digit q, the number of v to subtract from the current section of u to zero out the topmost digit. • Check the guess by multiplying q·v and comparing it against the current section of u, adjusting the guess as needed. • Subtract q·v from the current section of u. • Add q to the corresponding section of the result quo. When all digits have been processed, the final remainder is left in u and returned as rem. For example, here is a sketch of dividing 5 digits by 3 digits (n=3, m=2). q₂ q₁ q₀ _________________ v₂ v₁ v₀ ) u₄ u₃ u₂ u₁ u₀ ↓ ↓ ↓ | | [u₄ u₃ u₂]| | - [ q₂·v ]| | ----------- ↓ | [ rem | u₁]| - [ q₁·v ]| ----------- ↓ [ rem | u₀] - [ q₀·v ] ------------ [ rem ] Instead of creating new storage for the remainders and copying digits from u as indicated by the arrows, we use u's storage directly as both the source and destination of the subtractions, so that the remainders overwrite successive overlapping sections of u as the division proceeds, using a slice of u to identify the current section. This avoids all the copying as well as shifting of remainders. 
Division of u with n+m digits by v with n digits (in base B) can in general produce at most m+1 digits, because: • u < B^(n+m) [B^(n+m) has n+m+1 digits] • v ≥ B^(n-1) [B^(n-1) is the smallest n-digit number] • u/v < B^(n+m) / B^(n-1) [divide bounds for u, v] • u/v < B^(m+1) [simplify] The first step is special: it takes the top n digits of u and divides them by the n digits of v, producing the first quotient digit and an n-digit remainder. In the example, q₂ = ⌊u₄u₃u₂ / v⌋. The first step divides n digits by n digits to ensure that it produces only a single digit. Each subsequent step appends the next digit from u to the remainder and divides those n+1 digits by the n digits of v, producing another quotient digit and a new n-digit remainder. Subsequent steps divide n+1 digits by n digits, an operation that in general might produce two digits. However, as used in the algorithm, that division is guaranteed to produce only a single digit. The dividend is of the form rem·B + d, where rem is a remainder from the previous step and d is a single digit, so: • rem ≤ v - 1 [rem is a remainder from dividing by v] • rem·B ≤ v·B - B [multiply by B] • d ≤ B - 1 [d is a single digit] • rem·B + d ≤ v·B - 1 [add] • rem·B + d < v·B [change ≤ to <] • (rem·B + d)/v < B [divide by v] Guess and Check At each step we need to divide n+1 digits by n digits, but this is for the implementation of division by n digits, so we can't just invoke a division routine: we _are_ the division routine. Instead, we guess at the answer and then check it using multiplication. If the guess is wrong, we correct it. How can this guessing possibly be efficient? It turns out that the following statement (let's call it the Good Guess Guarantee) is true. If • q = ⌊u/v⌋ where u is n+1 digits and v is n digits, • q < B, and • the topmost digit of v = vₙ₋₁ ≥ B/2, then q̂ = ⌊uₙuₙ₋₁ / vₙ₋₁⌋ satisfies q ≤ q̂ ≤ q+2. (Proof below.) 
That is, if we know the answer has only a single digit and we guess an answer by ignoring the bottom n-1 digits of u and v, using a 2-by-1-digit division, then that guess is at least as large as the correct answer. It is also not too much larger: it is off by at most two from the correct answer. Note that in the first step of the overall division, which is an n-by-n-digit division, the 2-by-1 guess uses an implicit uₙ = 0. Note that using a 2-by-1-digit division here does not mean calling ourselves recursively. Instead, we use an efficient direct hardware implementation of that operation. Note that because q is u/v rounded down, q·v must not exceed u: u ≥ q·v. If a guess q̂ is too big, it will not satisfy this test. Viewed a different way, the remainder r̂ for a given q̂ is u - q̂·v, which must be non-negative. If it is negative, then the guess q̂ is too big. This gives us a way to compute q. First compute q̂ with 2-by-1-digit division. Then, while u < q̂·v, decrement q̂; this loop executes at most twice, because q̂ ≤ q+2. Scaling Inputs The Good Guess Guarantee requires that the top digit of v (vₙ₋₁) be at least B/2. For example in base 10, ⌊172/19⌋ = 9, but ⌊17/1⌋ = 17: the guess is wildly off because the first digit 1 is smaller than B/2 = 5. We can ensure that v has a large top digit by multiplying both u and v by the right amount. Continuing the example, if we multiply both 172 and 19 by 3, we now have ⌊516/57⌋, the leading digit of v is now ≥ 5, and sure enough ⌊51/5⌋ = 10 is much closer to the correct answer 9. It would be easier here to multiply by 4, because that can be done with a shift. Specifically, we can always count the number of leading zeros i in the first digit of v and then shift both u and v left by i bits. Having scaled u and v, the value ⌊u/v⌋ is unchanged, but the remainder will be scaled: 172 mod 19 is 1, but 516 mod 57 is 3. We have to divide the remainder by the scaling factor (shifting right i bits) when we finish.
Note that these shifts happen before and after the entire division algorithm, not at each step in the per-digit iteration. Note the effect of scaling inputs on the size of the possible quotient. In the scaled u/v, u can gain a digit from scaling; v never does, because we pick the scaling factor to make v's top digit larger but without overflowing. If u and v have n+m and n digits after scaling, then: • u < B^(n+m) [B^(n+m) has n+m+1 digits] • v ≥ B^n / 2 [vₙ₋₁ ≥ B/2, so vₙ₋₁·B^(n-1) ≥ B^n/2] • u/v < B^(n+m) / (B^n / 2) [divide bounds for u, v] • u/v < 2 B^m [simplify] The quotient can still have m+1 significant digits, but if so the top digit must be a 1. This provides a different way to handle the first digit of the result: compare the top n digits of u against v and fill in either a 0 or a 1. Refining Guesses Before we check whether u < q̂·v, we can adjust our guess to change it from q̂ = ⌊uₙuₙ₋₁ / vₙ₋₁⌋ into the refined guess ⌊uₙuₙ₋₁uₙ₋₂ / vₙ₋₁vₙ₋₂⌋. Although not mentioned above, the Good Guess Guarantee also promises that this 3-by-2-digit division guess is more precise and at most one away from the real answer q. The improvement from the 2-by-1 to the 3-by-2 guess can also be done without n-digit math. If we have a guess q̂ = ⌊uₙuₙ₋₁ / vₙ₋₁⌋ and we want to see if it is also equal to ⌊uₙuₙ₋₁uₙ₋₂ / vₙ₋₁vₙ₋₂⌋, we can use the same check we would for the full division: if uₙuₙ₋₁uₙ₋₂ < q̂·vₙ₋₁vₙ₋₂, then the guess is too large and should be reduced. Checking uₙuₙ₋₁uₙ₋₂ < q̂·vₙ₋₁vₙ₋₂ is the same as uₙuₙ₋₁uₙ₋₂ - q̂·vₙ₋₁vₙ₋₂ < 0, and uₙuₙ₋₁uₙ₋₂ - q̂·vₙ₋₁vₙ₋₂ = (uₙuₙ₋₁·B + uₙ₋₂) - q̂·(vₙ₋₁·B + vₙ₋₂) [splitting off the bottom digit] = (uₙuₙ₋₁ - q̂·vₙ₋₁)·B + uₙ₋₂ - q̂·vₙ₋₂ [regrouping] The expression (uₙuₙ₋₁ - q̂·vₙ₋₁) is the remainder of uₙuₙ₋₁ / vₙ₋₁. If the initial guess returns both q̂ and its remainder r̂, then checking whether uₙuₙ₋₁uₙ₋₂ < q̂·vₙ₋₁vₙ₋₂ is the same as checking r̂·B + uₙ₋₂ < q̂·vₙ₋₂.
If we find that r̂·B + uₙ₋₂ < q̂·vₙ₋₂, then we can adjust the guess by decrementing q̂ and adding vₙ₋₁ to r̂. We repeat until r̂·B + uₙ₋₂ ≥ q̂·vₙ₋₂. (As before, this fixup is only needed at most twice.) Now that q̂ = ⌊uₙuₙ₋₁uₙ₋₂ / vₙ₋₁vₙ₋₂⌋, as mentioned above it is at most one away from the correct q, and we've avoided doing any n-digit math. (If we need the new remainder, it can be computed as r̂·B + uₙ₋₂ - q̂·vₙ₋₂.) The final check u < q̂·v and the possible fixup must be done at full precision. For random inputs, a fixup at this step is exceedingly rare: the 3-by-2 guess is not often wrong at all. But still we must do the check. Note that since the 3-by-2 guess is off by at most 1, it can be convenient to perform the final u < q̂·v as part of the computation of the remainder r = u - q̂·v. If the subtraction underflows, decrementing q̂ and adding one v back to r is enough to arrive at the final q, r. That's the entirety of long division: scale the inputs, and then loop over each output position, guessing, checking, and correcting the next output digit. For a 2n-digit number divided by an n-digit number (the worst size-n case for division complexity), this algorithm uses n+1 iterations, each of which must do at least the 1-by-n-digit multiplication q̂·v. That's O(n) iterations of O(n) time each, so O(n²) time overall. Recursive Division For very large inputs, it is possible to improve on the O(n²) algorithm. Let's call a group of n/2 real digits a (very) “wide digit”. We can run the standard long division algorithm explained above over the wide digits instead of the actual digits. This will result in many fewer steps, but the math involved in each step is more work. Where basic long division uses a 2-by-1-digit division to guess the initial q̂, the new algorithm must use a 2-by-1-wide-digit division, which is of course really an n-by-n/2-digit division.
That's OK: if we implement n-digit division in terms of n/2-digit division, the recursion will terminate when the divisor becomes small enough to handle with standard long division or even with the 2-by-1 hardware instruction. For example, here is a sketch of dividing 10 digits by 4, proceeding with wide digits corresponding to two regular digits. The first step, still special, must leave off a (regular) digit, dividing 5 by 4 and producing a 4-digit remainder less than v. The middle steps divide 6 digits by 4, guaranteed to produce two output digits each (one wide digit) with 4-digit remainders. The final step must use what it has: the 4-digit remainder plus one more, 5 digits to divide by 4. q₆ q₅ q₄ q₃ q₂ q₁ q₀ _______________________________ v₃ v₂ v₁ v₀ ) u₉ u₈ u₇ u₆ u₅ u₄ u₃ u₂ u₁ u₀ ↓ ↓ ↓ ↓ ↓ | | | | | [u₉ u₈ u₇ u₆ u₅]| | | | | - [ q₆q₅·v ]| | | | | ----------------- ↓ ↓ | | | [ rem |u₄ u₃]| | | - [ q₄q₃·v ]| | | -------------------- ↓ ↓ | [ rem |u₂ u₁]| - [ q₂q₁·v ]| -------------------- ↓ [ rem |u₀] - [ q₀·v ] ------------------ [ rem ] An alternative would be to look ahead to how well n/2 divides into n+m and adjust the first step to use fewer digits as needed, making the first step more special to make the last step not special at all. For example, using the same input, we could choose to use only 4 digits in the first step, leaving a full wide digit for the last step: q₆ q₅ q₄ q₃ q₂ q₁ q₀ _______________________________ v₃ v₂ v₁ v₀ ) u₉ u₈ u₇ u₆ u₅ u₄ u₃ u₂ u₁ u₀ ↓ ↓ ↓ ↓ | | | | | | [u₉ u₈ u₇ u₆]| | | | | | - [ q₆·v ]| | | | | | -------------- ↓ ↓ | | | | [ rem |u₅ u₄]| | | | - [ q₅q₄·v ]| | | | -------------------- ↓ ↓ | | [ rem |u₃ u₂]| | - [ q₃q₂·v ]| | -------------------- ↓ ↓ [ rem |u₁ u₀] - [ q₁q₀·v ] --------------------- [ rem ] Today, the code in divRecursiveStep works like the first example. Perhaps in the future we will make it work like the alternative, to avoid a special case in the final iteration. 
Either way, each step is a 3-by-2-wide-digit division approximated first by a 2-by-1-wide-digit division, just as we did for regular digits in long division. Because the actual answer we want is a 3-by-2-wide-digit division, instead of multiplying q̂·v directly during the fixup, we can use the quick refinement from long division (an n/2-by-n/2 multiply) to correct q to its actual value and also compute the remainder (as mentioned above), and then stop after that, never doing a full n-by-n multiply. Instead of using an n-by-n/2-digit division to produce n/2 digits, we can add (not discard) one more real digit, doing an (n+1)-by-(n/2+1)-digit division that produces n/2+1 digits. That single extra digit tightens the Good Guess Guarantee to q ≤ q̂ ≤ q+1 and lets us drop long division's special treatment of the first digit. These benefits are discussed more after the Good Guess Guarantee proof below. How Fast is Recursive Division? For a 2n-by-n-digit division, this algorithm runs a 4-by-2 long division over wide digits, producing two wide digits plus a possible leading regular digit 1, which can be handled without a recursive call. That is, the algorithm uses two full iterations, each using an n-by-n/2-digit division and an n/2-by-n/2-digit multiplication, along with a few n-digit additions and subtractions. The standard n-by-n-digit multiplication algorithm requires O(n²) time, making the overall algorithm require time T(n) where T(n) = 2T(n/2) + O(n) + O(n²) which, by the Bentley-Haken-Saxe theorem, ends up reducing to T(n) = O(n²). This is not an improvement over regular long division. When the number of digits n becomes large enough, Karatsuba's algorithm for multiplication can be used instead, which takes O(n^log₂3) = O(n^1.6) time. (Karatsuba multiplication is implemented in func karatsuba in nat.go.) That makes the overall recursive division algorithm take O(n^1.6) time as well, which is an improvement, but again only for large enough numbers. 
It is not critical to make sure that every recursion does only two recursive calls. While in general the number of recursive calls can change the time analysis, in this case doing three calls does not change the analysis: T(n) = 3T(n/2) + O(n) + O(n^log₂3) ends up being T(n) = O(n^log₂3). Because the Karatsuba multiplication taking time O(n^log₂3) is itself doing 3 half-sized recursions, doing three for the division does not hurt the asymptotic performance. Of course, it is likely still faster in practice to do two. Proof of the Good Guess Guarantee Given numbers x, y, let us break them into the quotients and remainders when divided by some scaling factor S, with the added constraints that the quotient x/y and the high part of y are both less than some limit T, and that the high part of y is at least half as big as T. x₁ = ⌊x/S⌋ y₁ = ⌊y/S⌋ x₀ = x mod S y₀ = y mod S x = x₁·S + x₀ 0 ≤ x₀ < S x/y < T y = y₁·S + y₀ 0 ≤ y₀ < S T/2 ≤ y₁ < T And consider the two truncated quotients: q = ⌊x/y⌋ q̂ = ⌊x₁/y₁⌋ We will prove that q ≤ q̂ ≤ q+2. The guarantee makes no real demands on the scaling factor S: it is simply the magnitude of the digits cut from both x and y to produce x₁ and y₁. The guarantee makes only limited demands on T: it must be large enough to hold the quotient x/y, and y₁ must have roughly the same size. To apply to the earlier discussion of 2-by-1 guesses in long division, we would choose: S = Bⁿ⁻¹ T = B x = u x₁ = uₙuₙ₋₁ x₀ = uₙ₋₂...u₀ y = v y₁ = vₙ₋₁ y₀ = vₙ₋₂...v₀ These simpler variables avoid repeating those longer expressions in the proof. Note also that, by definition, truncating division ⌊x/y⌋ satisfies x/y - 1 < ⌊x/y⌋ ≤ x/y. This fact will be used a few times in the proofs. Proof that q ≤ q̂: q̂·y₁ = ⌊x₁/y₁⌋·y₁ [by definition, q̂ = ⌊x₁/y₁⌋] > (x₁/y₁ - 1)·y₁ [x₁/y₁ - 1 < ⌊x₁/y₁⌋] = x₁ - y₁ [distribute y₁] So q̂·y₁ > x₁ - y₁. Since q̂·y₁ is an integer, q̂·y₁ ≥ x₁ - y₁ + 1.
q̂ - q = q̂ - ⌊x/y⌋ [by definition, q = ⌊x/y⌋] ≥ q̂ - x/y [⌊x/y⌋ ≤ x/y] = (1/y)·(q̂·y - x) [factor out 1/y] ≥ (1/y)·(q̂·y₁·S - x) [y = y₁·S + y₀ ≥ y₁·S] ≥ (1/y)·((x₁ - y₁ + 1)·S - x) [above: q̂·y₁ ≥ x₁ - y₁ + 1] = (1/y)·(x₁·S - y₁·S + S - x) [distribute S] = (1/y)·(S - x₀ - y₁·S) [-x = -x₁·S - x₀] > -y₁·S / y [x₀ < S, so S - x₀ > 0; drop it] ≥ -1 [y₁·S ≤ y] So q̂ - q > -1. Since q̂ - q is an integer, q̂ - q ≥ 0, or equivalently q ≤ q̂. Proof that q̂ ≤ q+2: x₁/y₁ - x/y = x₁·S/y₁·S - x/y [multiply left term by S/S] ≤ x/y₁·S - x/y [x₁S ≤ x] = (x/y)·(y/y₁·S - 1) [factor out x/y] = (x/y)·((y - y₁·S)/y₁·S) [move -1 into y/y₁·S fraction] = (x/y)·(y₀/y₁·S) [y - y₁·S = y₀] = (x/y)·(1/y₁)·(y₀/S) [factor out 1/y₁] < (x/y)·(1/y₁) [y₀ < S, so y₀/S < 1] ≤ (x/y)·(2/T) [y₁ ≥ T/2, so 1/y₁ ≤ 2/T] < T·(2/T) [x/y < T] = 2 [T·(2/T) = 2] So x₁/y₁ - x/y < 2. q̂ - q = ⌊x₁/y₁⌋ - q [by definition, q̂ = ⌊x₁/y₁⌋] = ⌊x₁/y₁⌋ - ⌊x/y⌋ [by definition, q = ⌊x/y⌋] ≤ x₁/y₁ - ⌊x/y⌋ [⌊x₁/y₁⌋ ≤ x₁/y₁] < x₁/y₁ - (x/y - 1) [⌊x/y⌋ > x/y - 1] = (x₁/y₁ - x/y) + 1 [regrouping] < 2 + 1 [above: x₁/y₁ - x/y < 2] = 3 So q̂ - q < 3. Since q̂ - q is an integer, q̂ - q ≤ 2. Note that when x/y < T/2, the bounds tighten to x₁/y₁ - x/y < 1 and therefore q̂ - q ≤ 1. Note also that in the general case 2n-by-n division where we don't know that x/y < T, we do know that x/y < 2T, yielding the bound q̂ - q ≤ 4. So we could remove the special case first step of long division as long as we allow the first fixup loop to run up to four times. (Using a simple comparison to decide whether the first digit is 0 or 1 is still more efficient, though.) Finally, note that when dividing three leading base-B digits by two (scaled), we have T = B² and x/y < B = T/B, a much tighter bound than x/y < T. This in turn yields the much tighter bound x₁/y₁ - x/y < 2/B. This means that ⌊x₁/y₁⌋ and ⌊x/y⌋ can only differ when x/y is less than 2/B greater than an integer.
For random x and y, the chance of this is 2/B, or, for large B, approximately zero. This means that after we produce the 3-by-2 guess in the long division algorithm, the fixup loop essentially never runs. In the recursive algorithm, the extra digit in (2·⌊n/2⌋+1)-by-(⌊n/2⌋+1)-digit division has exactly the same effect: the probability of needing a fixup is the same 2/B. Even better, we can allow the general case x/y < 2T and the fixup probability only grows to 4/B, still essentially zero. References There are no great references for implementing long division; thus this comment. Here are some notes about what to expect from the obvious references. Knuth Volume 2 (Seminumerical Algorithms) section 4.3.1 is the usual canonical reference for long division, but that entire series is highly compressed, never repeating a necessary fact and leaving important insights to the exercises. For example, no rationale whatsoever is given for the calculation that extends q̂ from a 2-by-1 to a 3-by-2 guess, nor why it reduces the error bound. The proof that the calculation even has the desired effect is left to exercises. The solutions to those exercises provided at the back of the book are entirely calculations, still with no explanation as to what is going on or how you would arrive at the idea of doing those exact calculations. Nowhere is it mentioned that this test extends the 2-by-1 guess into a 3-by-2 guess. The proof of the Good Guess Guarantee is only for the 2-by-1 guess and argues by contradiction, making it difficult to understand how modifications like adding another digit or adjusting the quotient range affect the overall bound. All that said, Knuth remains the canonical reference. It is dense but packed full of information and references, and the proofs are simpler than many other presentations. The proofs above are reworkings of Knuth's to remove the arguments by contradiction and add explanations or steps that Knuth omitted. But beware of errors in older printings.
Take the published errata with you. Brinch Hansen's “Multiple-length Division Revisited: a Tour of the Minefield” starts with a blunt critique of Knuth's presentation (among others) and then presents a more detailed and easier to follow treatment of long division, including an implementation in Pascal. But the algorithm and implementation work entirely in terms of 3-by-2 division, which is much less useful on modern hardware than an algorithm using 2-by-1 division. The proofs are a bit too focused on digit counting and seem needlessly complex, especially compared to the ones given above. Burnikel and Ziegler's “Fast Recursive Division” introduced the key insight of implementing division by an n-digit divisor using recursive calls to division by an n/2-digit divisor, relying on Karatsuba multiplication to yield a sub-quadratic run time. However, the presentation decisions are made almost entirely for the purpose of simplifying the run-time analysis, rather than simplifying the presentation. Instead of a single algorithm that loops over quotient digits, the paper presents two mutually-recursive algorithms, for 2n-by-n and 3n-by-2n. The paper also does not present any general (n+m)-by-n algorithm. The proofs in the paper are remarkably complex, especially considering that the algorithm is at its core just long division on wide digits, so that the usual long division proofs apply essentially unaltered. */ package big import "math/bits" // rem returns r such that r = u%v. // It uses z as the storage for r. func (z nat) rem(stk *stack, u, v nat) (r nat) { if alias(z, u) { z = nil } defer stk.restore(stk.save()) q := stk.nat(max(1, len(u)-(len(v)-1))) _, r = q.div(stk, z, u, v) return r } // div returns q, r such that q = ⌊u/v⌋ and r = u%v = u - q·v. // It uses z and z2 as the storage for q and r. // The caller may pass stk == nil to request that div obtain and release one itself. 
func (z nat) div(stk *stack, z2, u, v nat) (q, r nat) {
	if len(v) == 0 {
		panic("division by zero")
	}

	if len(v) == 1 {
		// Short division: long division optimized for a single-word divisor.
		// In that case, the 2-by-1 guess is all we need at each step.
		var r2 Word
		q, r2 = z.divW(u, v[0])
		r = z2.setWord(r2)
		return
	}

	// u < v: the quotient is 0 and the remainder is u itself.
	if u.cmp(v) < 0 {
		q = z[:0]
		r = z2.set(u)
		return
	}

	// Only the multi-word path needs scratch space; borrow a stack
	// if the caller did not supply one.
	if stk == nil {
		stk = getStack()
		defer stk.free()
	}

	q, r = z.divLarge(stk, z2, u, v)
	return
}

// divW returns q, r such that q = ⌊x/y⌋ and r = x%y = x - q·y.
// It uses z as the storage for q.
// Note that y is a single digit (Word), not a big number.
// It panics if y == 0.
func (z nat) divW(x nat, y Word) (q nat, r Word) {
	m := len(x)
	switch {
	case y == 0:
		panic("division by zero")
	case y == 1:
		q = z.set(x) // result is x (r stays 0)
		return
	case m == 0:
		q = z[:0] // result is 0 (r stays 0)
		return
	}
	// m > 0
	z = z.make(m)
	r = divWVW(z, 0, x, y)
	q = z.norm()
	return
}

// modW returns x % d.
func (x nat) modW(d Word) (r Word) {
	// TODO(agl): we don't actually need to store the q value.
	// divWVW always writes a quotient, so scratch storage is allocated for it.
	var q nat
	q = q.make(len(x))
	return divWVW(q, 0, x, d)
}

// divWVW overwrites z with ⌊x/y⌋, returning the remainder r.
// The dividend is the word sequence xn:x, with xn prepended as an extra
// most-significant word.
// The caller must ensure that len(z) = len(x).
// NOTE(review): the single-word path uses bits.Div, which panics unless
// xn < y; presumably callers guarantee xn < y on both paths — confirm.
func divWVW(z []Word, xn Word, x []Word, y Word) (r Word) {
	r = xn
	if len(x) == 1 {
		// One word: divide directly instead of first computing a reciprocal.
		qq, rr := bits.Div(uint(r), uint(x[0]), uint(y))
		z[0] = Word(qq)
		return Word(rr)
	}
	rec := reciprocalWord(y)
	// Schoolbook division from the most significant word down; each step
	// divides the running remainder:x[i] by y.
	for i := len(z) - 1; i >= 0; i-- {
		z[i], r = divWW(r, x[i], y, rec)
	}
	return r
}

// divLarge returns q, r such that q = ⌊uIn/vIn⌋ and r = uIn%vIn = uIn - q·vIn.
// It uses z and u as the storage for q and r.
// The caller must ensure that len(vIn) ≥ 2 (use divW otherwise)
// and that len(uIn) ≥ len(vIn) (the answer is 0, uIn otherwise).
func (z nat) divLarge(stk *stack, u, uIn, vIn nat) (q, r nat) {
	n := len(vIn)
	m := len(uIn) - n

	// Scale the inputs so vIn's top bit is 1 (see “Scaling Inputs” above).
	// vIn is treated as a read-only input (it may be in use by another
	// goroutine), so we must make a copy.
	// uIn is copied to u.
	// u gets one extra word to hold the bits shifted out of uIn's top word.
	defer stk.restore(stk.save())
	shift := nlz(vIn[n-1])
	v := stk.nat(n)
	u = u.make(len(uIn) + 1)
	if shift == 0 {
		copy(v, vIn)
		copy(u[:len(uIn)], uIn)
		u[len(uIn)] = 0
	} else {
		lshVU(v, vIn, shift)
		u[len(uIn)] = lshVU(u[:len(uIn)], uIn, shift)
	}

	// The caller should not pass aliased z and u, since those are
	// the two different outputs, but correct just in case.
	if alias(z, u) {
		z = nil
	}
	q = z.make(m + 1)

	// Use basic or recursive long division depending on size.
	if n < divRecursiveThreshold {
		q.divBasic(stk, u, v)
	} else {
		q.divRecursive(stk, u, v)
	}

	q = q.norm()

	// Undo scaling of remainder.
	if shift != 0 {
		rshVU(u, u, shift)
	}
	r = u.norm()

	return q, r
}

// divBasic implements long division as described above.
// It overwrites q with ⌊u/v⌋ and overwrites u with the remainder r.
// q must be large enough to hold ⌊u/v⌋.
func (q nat) divBasic(stk *stack, u, v nat) {
	n := len(v)
	m := len(u) - n

	defer stk.restore(stk.save())
	qhatv := stk.nat(n + 1)

	// Set up for divWW below, precomputing reciprocal argument.
	vn1 := v[n-1]
	rec := reciprocalWord(vn1)

	// Invent a leading 0 for u, for the first iteration.
	// Invariant: ujn == u[j+n] in each iteration.
	ujn := Word(0)

	// Compute each digit of quotient.
	for j := m; j >= 0; j-- {
		// Compute the 2-by-1 guess q̂.
		qhat := Word(_M)

		// ujn ≤ vn1, or else q̂ would be more than one digit.
		// For ujn == vn1, we set q̂ to the max digit M above.
		// Otherwise, we compute the 2-by-1 guess.
		if ujn != vn1 {
			var rhat Word
			qhat, rhat = divWW(ujn, u[j+n-1], vn1, rec)

			// Refine q̂ to a 3-by-2 guess. See “Refining Guesses” above.
			vn2 := v[n-2]
			x1, x2 := mulWW(qhat, vn2)
			ujn2 := u[j+n-2]
			for greaterThan(x1, x2, rhat, ujn2) { // x1x2 > r̂ u[j+n-2]
				qhat--
				prevRhat := rhat
				rhat += vn1
				// If r̂  overflows, then
				// r̂ u[j+n-2]v[n-1] is now definitely > x1 x2.
				if rhat < prevRhat {
					break
				}
				// TODO(rsc): No need for a full mulWW.
				// x2 += vn2; if x2 overflows, x1++
				x1, x2 = mulWW(qhat, vn2)
			}
		}

		// Compute q̂·v.
		qhatv[n] = mulAddVWW(qhatv[0:n], v, qhat, 0)
		qhl := len(qhatv)
		// j+qhl > len(u) only when j == m: the product's top word has no
		// counterpart in u, so drop it when it is zero.
		if j+qhl > len(u) && qhatv[n] == 0 {
			qhl--
		}

		// Subtract q̂·v from the current section of u.
		// If it underflows, q̂·v > u, which we fix up
		// by decrementing q̂ and adding v back.
		c := subVV(u[j:j+qhl], u[j:j+qhl], qhatv[:qhl])
		if c != 0 {
			c := addVV(u[j:j+n], u[j:j+n], v)
			// If n == qhl, the carry from subVV and the carry from addVV
			// cancel out and don't affect u[j+n].
			if n < qhl {
				u[j+n] += c
			}
			qhat--
		}

		// The next iteration uses j-1, whose u[(j-1)+n] is this u[j+n-1];
		// caching it here maintains the invariant stated above.
		ujn = u[j+n-1]

		// Save quotient digit.
		// Caller may know the top digit is zero and not leave room for it.
		if j == m && m == len(q) && qhat == 0 {
			continue
		}
		q[j] = qhat
	}
}

// greaterThan reports whether the two digit numbers x1 x2 > y1 y2.
// TODO(rsc): In contradiction to most of this file, x1 is the high
// digit and x2 is the low digit. This should be fixed.
func greaterThan(x1, x2, y1, y2 Word) bool {
	return x1 > y1 || x1 == y1 && x2 > y2
}

// divRecursiveThreshold is the number of divisor digits
// at which point divRecursive is faster than divBasic.
var divRecursiveThreshold = 40 // see calibrate_test.go

// divRecursive implements recursive division as described above.
// It overwrites z with ⌊u/v⌋ and overwrites u with the remainder r.
// z must be large enough to hold ⌊u/v⌋.
// It zeroes z, because divRecursiveStep accumulates into z, and then
// delegates to divRecursiveStep, the real implementation; temporaries
// are drawn from stk.
func (z nat) divRecursive(stk *stack, u, v nat) {
	clear(z)
	z.divRecursiveStep(stk, u, v, 0)
}

// divRecursiveStep is the actual implementation of recursive division.
// It adds ⌊u/v⌋ to z and overwrites u with the remainder r.
// z must be large enough to hold ⌊u/v⌋.
//
// Every caller in this file passes a zeroed z (divRecursive clears z, and
// the recursive calls below clear qhat first). That matters because the
// early-return paths do not add: an all-zero u clears z, and the
// small-divisor fallback to divBasic overwrites z.
//
// Temporaries come from stk. qhat0 lives across the recursive calls and is
// released when this call returns; each loop iteration's q̂·v product is
// released at the end of that iteration (stk.restore(mark)).
// depth is the recursion depth; here it is only passed on, as depth+1,
// to the recursive calls.
func (z nat) divRecursiveStep(stk *stack, u, v nat, depth int) {
	// u is a subsection of the original and may have leading zeros.
	// TODO(rsc): The v = v.norm() is useless and should be removed.
	// We know (and require) that v's top digit is ≥ B/2.
	u = u.norm()
	v = v.norm()
	if len(u) == 0 {
		clear(z)
		return
	}

	// Fall back to basic division if the problem is now small enough.
	n := len(v)
	if n < divRecursiveThreshold {
		z.divBasic(stk, u, v)
		return
	}

	// Nothing to do if u is shorter than v (implies u < v).
	m := len(u) - n
	if m < 0 {
		return
	}

	// We consider B digits in a row as a single wide digit.
	// (See “Recursive Division” above.)
	//
	// TODO(rsc): rename B to Wide, to avoid confusion with _B,
	// which is something entirely different.
	// TODO(rsc): Look into whether using ⌈n/2⌉ is better than ⌊n/2⌋.
	B := n / 2

	// Allocate a nat for qhat below.
	defer stk.restore(stk.save())
	qhat0 := stk.nat(B + 1)

	// Compute each wide digit of the quotient.
	//
	// TODO(rsc): Change the loop to be
	//	for j := (m+B-1)/B*B; j > 0; j -= B {
	// which will make the final step a regular step, letting us
	// delete what amounts to an extra copy of the loop body below.
	j := m
	for j > B {
		// Divide u[j-B:j+n] (3 wide digits) by v (2 wide digits).
		// First make the 2-by-1-wide-digit guess using a recursive call.
		// Then extend the guess to the full 3-by-2 (see “Refining Guesses”).
		//
		// For the 2-by-1-wide-digit guess, instead of doing 2B-by-B-digit,
		// we use a (2B+1)-by-(B+1) digit, which handles the possibility that
		// the result has an extra leading 1 digit as well as guaranteeing
		// that the computed q̂ will be off by at most 1 instead of 2.

		// s is the number of digits to drop from the 3B- and 2B-digit chunks.
		// We drop B-1 to be left with 2B+1 and B+1.
		s := (B - 1)

		// uu is the up-to-3B-digit section of u we are working on.
		uu := u[j-B:]

		// Compute the 2-by-1 guess q̂, leaving r̂ in uu[s:B+n].
		qhat := qhat0
		clear(qhat)
		qhat.divRecursiveStep(stk, uu[s:B+n], v[s:], depth+1)
		qhat = qhat.norm()

		// Extend to a 3-by-2 quotient and remainder.
		// Because divRecursiveStep overwrote the top part of uu with
		// the remainder r̂, the full uu already contains the equivalent
		// of r̂·B + uₙ₋₂ from the “Refining Guesses” discussion.
		// Subtracting q̂·vₙ₋₂ from it will compute the full-length remainder.
		// If that subtraction underflows, q̂·v > u, which we fix up
		// by decrementing q̂ and adding v back, same as in long division.

		// TODO(rsc): Instead of subtract and fix-up, this code is computing
		// q̂·vₙ₋₂ and decrementing q̂ until that product is ≤ u.
		// But we can do the subtraction directly, as in the comment above
		// and in long division, because we know that q̂ is wrong by at most one.
		mark := stk.save()
		qhatv := stk.nat(3 * n)
		clear(qhatv)
		qhatv = qhatv.mul(stk, qhat, v[:s])
		for i := 0; i < 2; i++ {
			e := qhatv.cmp(uu.norm())
			if e <= 0 {
				break
			}
			subVW(qhat, qhat, 1)
			c := subVV(qhatv[:s], qhatv[:s], v[:s])
			if len(qhatv) > s {
				subVW(qhatv[s:], qhatv[s:], c)
			}
			addTo(uu[s:], v[s:])
		}
		if qhatv.cmp(uu.norm()) > 0 {
			panic("impossible")
		}
		c := subVV(uu[:len(qhatv)], uu[:len(qhatv)], qhatv)
		if c > 0 {
			subVW(uu[len(qhatv):], uu[len(qhatv):], c)
		}
		addTo(z[j-B:], qhat)
		j -= B
		stk.restore(mark)
	}

	// TODO(rsc): Rewrite loop as described above and delete all this code.

	// Now u < (v<<B), compute lower bits in the same way.
	// Choose shift = B-1 again.
	s := B - 1
	qhat := qhat0
	clear(qhat)
	qhat.divRecursiveStep(stk, u[s:].norm(), v[s:], depth+1)
	qhat = qhat.norm()
	qhatv := stk.nat(3 * n)
	clear(qhatv)
	qhatv = qhatv.mul(stk, qhat, v[:s])
	// Set the correct remainder as before.
	for i := 0; i < 2; i++ {
		if e := qhatv.cmp(u.norm()); e > 0 {
			subVW(qhat, qhat, 1)
			c := subVV(qhatv[:s], qhatv[:s], v[:s])
			if len(qhatv) > s {
				subVW(qhatv[s:], qhatv[s:], c)
			}
			addTo(u[s:], v[s:])
		}
	}
	if qhatv.cmp(u.norm()) > 0 {
		panic("impossible")
	}
	c := subVV(u[:len(qhatv)], u[:len(qhatv)], qhatv)
	if c > 0 {
		c = subVW(u[len(qhatv):], u[len(qhatv):], c)
	}
	if c > 0 {
		panic("impossible")
	}

	// Done!
	addTo(z, qhat.norm())
}
// Copyright 2025 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // FFT-based multiplication. // // References: // // Brent and Zimmermann, Modern Computer Arithmetic, // Cambridge University Press 2011. package big import ( "math" "math/bits" ) // A nega is a fixed-width nat of length 1+N/_B representing // a number modulo 1+2**N. The top word is either 0 or 1, // with a top 1 representing 2**N (= -1 mod 1+2**N). // Except for 2**N, there are two representations of each number, // using and not using the top bit; for example the two representations // of zero are 0,0,0,...,0,0,0 and 1,0,0,...,0,0,1. // nega is short for negacyclic, because the math can be viewed // as analogous to math on negacyclic polynomials. type nega nat // set sets z = x. func (z nega) set(x nega) { if len(z) != len(x) { panic("bad z") } copy(z, x) } // norm modifies z in place to normalize the representation // (clearing the top word except when the value is 2**N) // and then returns the equivalent nat. func (z nega) norm() nat { // TODO special case z[0] >= 1? if z[len(z)-1] == 1 && len(nat(z[:len(z)-1]).norm()) > 0 { z[len(z)-1] = 0 subVW(z, z, 1) } return nat(z).norm() } // debugNat allocates and returns a new nat with the value that x represents. // It is intended only for debugging. func (x nega) debugNat() nat { z := make(nat, len(x)) copy(z, x) if z[len(z)-1] == 1 && len(z[:len(z)-1].norm()) > 0 { z[len(z)-1] = 0 subVW(z, z, 1) } return z.norm() } // debugInt allocates and returns a new Int with the value that x represents. // It is intended only for debugging. func (x nega) debugInt() *Int { i := new(Int) i.abs = x.debugNat() return i } // add sets z = x + y (mod 1+2**N). // z and x must have the same length; // y can be shorter, in which case its top words are assumed to be zero. // Allowing a shorter y helps in the implementation of [nega.setNat]. 
func (z nega) add(x, y nega) { if len(z) != len(x) || len(z) < len(y) { panic("math/big: bad nega add") } if c := addVV(z[:len(y)], x[:len(y)], y); c != 0 { addVW(z[len(y):], x[len(y):], c) } n := len(z) - 1 if zn := z[n]; zn > 1 { z[n] = 1 subVW(z, z, zn-1) } } // sub sets z = x - y (mod 1+2**N). // Like with add, z and x must have the same length, // and y can be shorter, with implied high zero words. func (z nega) sub(x, y nega) { if len(z) != len(x) || len(z) < len(y) { panic("math/big: bad nega sub") } if c := subVV(z[:len(y)], x[:len(y)], y); c != 0 { subVW(z[len(y):], x[len(y):], c) } n := len(z) - 1 if zn := z[n]; zn > 1 { z[n] = 0 addVW(z, z, -zn) } } const useSubAdd = false func (z nega) subAdd(x, y nega) { if len(x) != len(y) { panic("math/big: bad nega addSub") } addSubVV(x, z, x, y) n := len(x) - 1 if xn := x[n]; xn > 1 { x[n] = 1 subVW(x, x, xn-1) } if zn := z[n]; zn > 1 { z[n] = 0 addVW(z, z, -zn) } } // halfShift sets z = x << k/2, using t as a temporary if needed. func (z nega) halfShift(x, t nega, k int) { if k&1 == 0 { z.shift(x, k>>1) return } k >>= 1 n := len(z) - 1 z.shift(x, k+3*(_W/4)*n) t.shift(x, k+(_W/4)*n) z.sub(z, t) } // shift sets z = x << k. 
func (z nega) shift(x nega, k int) {
	// NOTE: z must not alias x: the s == n and s > n cases clear z
	// before reading x.
	if len(z) != len(x) {
		panic("math/big: bad nega shift")
	}
	n := len(x) - 1
	// 2**(2N) ≡ 1 (mod 1+2**N), so reduce k into [0, 2N).
	k %= 2 * n * _W
	if k < 0 {
		k += 2 * n * _W
	}
	// Split k into a whole-word shift s and a remaining bit shift sbit.
	s := k / _W
	sbit := k % _W
	// TODO simplify
	switch {
	case s == 0:
		copy(z, x)
	case s == n: // s==n, negating entirely
		// x·2**N ≡ -x: compute 0 - x. A borrow leaves z[n] wrapped
		// (meaning -t·2**N ≡ t for t = -z[n]); clear it and add t back.
		clear(z)
		subVV(z, z, x)
		if zn := z[n]; zn > 0 {
			z[n] = 0
			addVW(z, z, -zn)
		}
	case s < n:
		// The low n-s words of x move up by s words. The high s words
		// pass 2**N ≡ -1, so they wrap to the bottom negated, as does
		// the top word x[n] (x[n]·2**N·2**(s*_W) ≡ -x[n]·2**(s*_W)).
		z[n] = 0
		copy(z[s:n], x[:n-s])
		clear(z[:s])
		if c := subVV(z[:s], z[:s], x[n-s:n]); c != 0 {
			subVW(z[s:], z[s:], c)
		}
		if x[n] != 0 {
			subVW(z[s:], z[s:], x[n])
		}
		// Fold a wrapped (borrowed) top word back in, as in nega.sub.
		if zn := z[n]; zn > 0 {
			z[n] = 0
			addVW(z, z, -zn)
		}
	default: // s > n, swapping what's negated
		// 2**(s*_W) = 2**N·2**((s-n)*_W) ≡ -2**((s-n)*_W): the high words of x
		// land at the bottom un-negated, the low words are subtracted, and
		// the top word x[n] (-x[n]·-1) is added.
		s -= n
		clear(z)
		copy(z[:s], x[n-s:n])
		z[n] = -subVV(z[s:n], z[s:n], x[:n-s])
		if x[n] != 0 {
			addVW(z[s:], z[s:], x[n])
		}
		// A top word of 0 or 1 is valid; a wrapped one is folded back in.
		if zn := z[n]; zn > 1 {
			z[n] = 0
			addVW(z, z, -zn)
		}
	}
	if sbit != 0 {
		// Shift the remaining bits. The top word (0 or 1) absorbs the bits
		// shifted out of z[n-1]; a top word t > 1 then means t·2**N ≡ -t,
		// which is folded back as -1 (top word 1) minus t-1.
		shlVU(z, z, uint(sbit))
		if zn := z[n]; zn > 1 {
			z[n] = 1
			subVW(z, z, zn-1)
		}
	}
}

// setNat sets z = x mod 1+2**N.
// x itself is not modified; only the local slice header advances.
func (z nega) setNat(x nat) {
	// Work in sections of m words: bottom m get added,
	// next m get subtracted, next m get added, and so on.
	// (Each section is worth 2**(m*_W) = 2**N ≡ -1 times the one below it.)
	m := len(z) - 1

	// Zero z and add first run of m words.
	n := copy(z[:m], x)
	x = x[n:]
	clear(z[n:])

	// Subtract and add successive chunks of x until x is done.
	for len(x) > 0 {
		// Subtract up to m words.
		n := min(len(x), m)
		z.sub(z, nega(x[:n]))
		x = x[n:]
		if len(x) == 0 {
			break
		}
		// Add up to m words.
		n = min(len(x), m)
		z.add(z, nega(x[:n]))
		x = x[n:]
	}
}

// A negaVec is a vector of negacyclic numbers, all of the same length.
// The numbers are stored back to back in words.
type negaVec struct {
	words []Word
	len   int // numbers in vector
	wid   int // width of each number (in Words)
}

// newNegaVec returns a negaVec of len numbers, each wid words wide,
// whose storage is taken from stk.
// NOTE(review): the parameter name len shadows the builtin in this function.
func newNegaVec(stk *stack, len, wid int) *negaVec {
	return &negaVec{stk.nat(len * wid), len, wid}
}

// at returns the i'th number in the vector v.
// The result aliases v's storage; its capacity is capped at its own width
// so that an append cannot spill into the next element.
func (v *negaVec) at(i int) nega {
	start := i * v.wid
	end := start + v.wid
	return nega(v.words[start:end:end])
}

// negaSplit splits x into v, copying successive n-word chunks
// from x into elements of v.
// v is cleared first. The last element of v receives n+1 words, so x may
// have up to v.len*n+1 words; negaSplit panics if x is longer.
// If twist != 0, each element v[i] with i > 0 is additionally multiplied
// by 2**(i*twist/2), via nega.halfShift by i*twist.
// t1 and t2 are available temporaries; they are overwritten when twist != 0.
func negaSplit(x nat, v *negaVec, t1, t2 nega, n, twist int) {
	clear(v.words)
	for i := range v.len {
		if i == v.len-1 {
			n++ // the last element also takes x's final word
		}
		z := v.at(i)
		if twist != 0 && i != 0 {
			// Stage the chunk in t1, then write the shifted value into z.
			j := copy(t1[:n], x)
			clear(t1[j:])
			x = x[j:]
			z.halfShift(t1, t2, twist*i)
		} else {
			j := copy(z[:n], x)
			clear(z[j:])
			x = x[j:]
		}
	}
	if len(x) != 0 {
		panic("math/big: bad negaSplit")
	}
}

// isZero reports whether every word of x is zero.
// It tests the representation, not the value: the alternate
// representation of zero (top word 1, bottom word 1) reports false.
func (x nega) isZero() bool {
	for _, w := range x {
		if w != 0 {
			return false
		}
	}
	return true
}

// forwardFFT implements a forward discrete FFT on the vector v,
// considering only the width elements at off, off+step, off+2*step, ..., off+(width-1)*step
// and using 2**(hshift/2) (applied with nega.halfShift) as the width'th root of unity.
// It leaves the elements in bit-reversed order.
// width must be a power of two ≥ 2; other values never reach the width == 2
// base case. t1 and t2 are scratch space.
//
// Brent and Zimmermann, Algorithm 2.2.
func forwardFFT(v *negaVec, t1, t2 nega, off, step, hshift, width int) {
	if width == 2 {
		// Base case butterfly: x, y = x+y, x-y.
		x := v.at(off)
		y := v.at(off + step)
		y.subAdd(x, y)
		return
	}
	width /= 2
	// Transform the even- and odd-indexed halves with the squared root.
	forwardFFT(v, t1, t2, off, step*2, hshift*2, width)
	forwardFFT(v, t1, t2, off+step, step*2, hshift*2, width)
	// Combine output j of the two halves. Those outputs are in bit-reversed
	// order, so the twiddle factor for pair j is ω**bitrev(j).
	b := bits.TrailingZeros(uint(width))
	for j := 0; j < width; j++ {
		i := off + 2*j*step
		x := v.at(i)
		y := v.at(i + step)
		t1.halfShift(y, t2, hshift*bitrev(j, b))
		y.subAdd(x, t1)
	}
}

// bitrev returns the low n bits of i in reverse order.
// It assumes 0 ≤ i < 1<<n.
func bitrev(i, n int) int {
	return int(bits.Reverse(uint(i)) >> uint(_W-n))
}

// backwardFFT implements a backward discrete FFT on the vector v,
// considering only the width elements at off, off+step, off+2*step, ..., off+(width-1)*step
// and using 2**(hshift/2) (applied with nega.halfShift) as the width'th root of unity.
// It expects the elements in bit-reversed order and un-reverses them.
// The result is not divided by width; callers undo that factor themselves.
//
// Brent and Zimmermann, Algorithm 2.3.
func backwardFFT(v *negaVec, t1, t2 nega, off, step, hshift, width int) {
	if width == 2 {
		// Base case butterfly: x, y = x+y, x-y.
		// t1 holds x+y so that y.sub still reads the original x.
		x := v.at(off)
		y := v.at(off + step)
		t1.add(x, y)
		y.sub(x, y)
		x.set(t1)
		return
	}
	width /= 2
	// Transform the two contiguous halves, each with the squared root.
	backwardFFT(v, t1, t2, off, step, hshift*2, width)
	backwardFFT(v, t1, t2, off+step*width, step, hshift*2, width)
	// Combine element j of the first half with element j of the second.
	// ω has order 2*width here, so ω**(2*width-j) = ω**(-j).
	for j := 0; j < width; j++ {
		x := v.at(off + j*step)
		y := v.at(off + (j+width)*step)
		t1.halfShift(y, t2, hshift*(2*width-j))
		y.sub(x, t1)
		x.add(x, t1)
	}
}

// addAt adds x·2**(i*_W) to z (mod 1+2**N), where N = (len(z)-1)*_W,
// normalizing z after each partial step.
// Words of x at or above word n = len(z)-1 are multiplied by 2**N ≡ -1,
// so they are subtracted from the bottom of z instead.
// NOTE(review): this relies on addTo/subFrom (defined elsewhere) propagating
// carries/borrows through the rest of the slice, and on the subtraction
// never borrowing out of the top word, which nega.norm would not repair.
// Confirm against callers.
func (z nega) addAt(x nat, i int) {
	n := len(z) - 1
	if i+len(x) <= n {
		// Entirely below 2**N: a plain add.
		addTo(nat(z[i:]), x)
		z.norm()
		return
	}
	if i < n {
		// Add the part of x below word n, then treat the rest as
		// starting at word n (the wrapped, negated region).
		addTo(nat(z[i:]), x[:n-i])
		z.norm()
		x = x[n-i:]
		i = n
	}
	subFrom(nat(z[i-n:]), x)
	z.norm()
}

// sub1At subtracts 2**(i*_W) from z (mod 1+2**N), where N = (len(z)-1)*_W.
// For i ≥ len(z)-1 this is the same as adding 2**((i-(len(z)-1))*_W),
// because 2**N ≡ -1.
func (z nega) sub1At(i int) {
	if i >= len(z)-1 {
		i -= len(z) - 1
		addVW(z[i:], z[i:], 1)
		z.norm()
		return
	}
	subVW(z[i:], z[i:], 1)
	z.norm()
}

// fftSize chooses FFT parameters for a result of Nw words (N = Nw*_W bits):
//
//   - k, K: the transform length K = 1<<k (k starts at max(2, ⌊log2 √N⌋));
//   - M, Mw: bits (words) per input piece, ⌈N/K⌉ rounded up to whole words;
//   - nw: the width in words of each coefficient, computed modulo 1+2**(nw*_W);
//   - ω: a half-shift count, so 2**(ω/2) = 2**(2n/K) (n = nw*_W bits)
//     is a K'th root of unity modulo 1+2**n.
//
// If mod is set, k is reduced until N is a multiple of _W*K, so that each
// piece is exactly N/K bits and word aligned (k may drop below 2 here).
// The coefficient width n starts at 2*M + k + 1 bits, presumably leaving room
// for a signed sum of K products of M-bit pieces. It is rounded up to a
// multiple of K/2 bits (mod) or K/4 bits (otherwise), so that ω = 4n/K is an
// integer, and even in the mod case.
func fftSize(Nw int, mod bool) (k, K, M, Mw, nw, ω int) {
	N := int64(Nw) * _W
	k = max(2, int(math.Log2(math.Sqrt(float64(N)))))
	K = 1 << k
	for mod && N%int64(_W*K) != 0 {
		// if mod then N must be exactly M*K and we want M to be word-aligned
		k--
		K = 1 << k
	}
	M = int((N + int64(K) - 1) >> k)
	M = (M + _W - 1) / _W * _W
	Mw = int(M / _W)
	n := (int(k) + 1) + 2*M
	n = (n + _W - 1) / _W * _W
	if mod {
		K2 := K / 2
		n = (n + K2 - 1) / K2 * K2
	} else {
		K4 := K / 4
		n = (n + K4 - 1) / K4 * K4
	}
	nw = int(n / _W)
	ω = int(4 * n / K)
	return
}

// fftModThreshold is the coefficient width in words at or above which
// fftMulMod computes the pointwise products with a recursive fftMulMod call
// (and only when K ≥ 1024) instead of with nat.mul and nega.setNat.
// With the current value that recursive path is effectively never taken.
// note: cannot find any way that this makes sense
var fftModThreshold = 100000000

// fftMulMod multiplies x and y with an FFT over the integers modulo 1+2**n
// (Brent and Zimmermann, section 2.3).
//
// If mod is false, it sets z = x*y after resizing z with
// z.make(len(x)+len(y)). Nothing is returned, so the caller must pass a z
// whose capacity lets that make reuse z's storage.
// NOTE(review): otherwise the product is written to a fresh slice and lost.
//
// If mod is true, z is treated as a nega of N = (len(z)-1)*_W bits and
// receives x*y mod 1+2**N.
//
// Temporaries come from stk and are released on return.
func fftMulMod(stk *stack, z, x, y nat, mod bool) {
	defer stk.restore(stk.save())

	const debug = false
	var debugXV, debugXF, debugYV, debugYF, debugZF, debugZV, debugZVK *negaVec
	if debug {
		// z may alias x,y; copy x,y so we can refer to them after z has been computed
		x = cloneNat(x)
		y = cloneNat(y)
	}

	// Nw is the size in words of the result being computed.
	var Nw int
	if mod {
		Nw = len(z) - 1
	} else {
		Nw = 2 * max(len(x), len(y))
	}
	k, K, _, Mw, nw, ω := fftSize(Nw, mod)
	wid := 1 + nw // words per coefficient, including the nega top word
	t1 := nega(stk.nat(wid))
	t2 := nega(stk.nat(wid))

	// Split x into K pieces of Mw words (the last also takes one extra word).
	// In the mod case, piece i is also weighted by 2**(i*twist/2), where
	// (2**(twist/2))**K = 2**(nw*_W) ≡ -1. This turns the cyclic convolution
	// computed by the FFT into the negacyclic one needed modulo 1+2**N.
	xv := newNegaVec(stk, K, wid)
	twist := 0
	if mod {
		twist = ω / 2
	}
	negaSplit(x, xv, t1, t2, Mw, twist)
	if debug {
		debugXV = cloneNegaVec(xv)
	}
	forwardFFT(xv, t1, t2, 0, 1, ω, K)
	if debug {
		debugXF = cloneNegaVec(xv)
	}

	// When squaring (x and y are the same slice), reuse x's transform.
	// NOTE(review): &x[0] and &y[0] panic if x or y is empty.
	var yv *negaVec
	if &x[0] == &y[0] {
		yv = xv
		if debug {
			debugYV = debugXV
			debugYF = debugXF
		}
	} else {
		yv = newNegaVec(stk, K, wid)
		negaSplit(y, yv, t1, t2, Mw, twist)
		if debug {
			debugYV = cloneNegaVec(yv)
		}
		forwardFFT(yv, t1, t2, 0, 1, ω, K)
		if debug {
			debugYF = cloneNegaVec(yv)
		}
	}

	// Pointwise products, written back into xv's storage (zv is xv).
	// The recursive call therefore passes a z that aliases its x; that works
	// because fftMulMod reads all of x (negaSplit) before it writes z.
	zv := xv
	t3 := stk.nat(2 * wid)
	for i := range K {
		if wid < fftModThreshold || K < 1024 {
			zv.at(i).setNat(t3.mul(stk, xv.at(i).norm(), yv.at(i).norm()))
		} else {
			fftMulMod(stk, nat(zv.at(i)), xv.at(i).norm(), yv.at(i).norm(), true)
		}
	}
	if debug {
		debugZF = cloneNegaVec(zv)
	}

	backwardFFT(zv, t1, t2, 0, 1, ω, K)
	if debug {
		debugZV = cloneNegaVec(zv)
		debugZVK = cloneNegaVec(zv)
	}

	if mod {
		z := nega(z)
		clear(z)
		for i := range K {
			// Undo the transform's factor K = 2**k and the twist weight
			// 2**(i*twist/2): a half-shift by -(2*k + twist*i).
			t1.halfShift(zv.at(i), t2, -(2*k + twist*i))
			t1.norm()
			if debug {
				debugZVK.at(i).set(t1)
			}
			z.addAt(nat(t1), i*Mw)
			// A non-negative coefficient i is a sum of i+1 products of pieces,
			// so its word 2*Mw should be at most i. A larger word is taken to
			// mean a negative coefficient represented as c + (1+2**n). Undo the
			// extra 1+2**n at offset i*Mw by subtracting 1 at word i*Mw and at
			// word i*Mw + nw (len(t1)-1 == nw).
			if t1[2*Mw] > Word(i) {
				z.sub1At(i * Mw)
				z.sub1At(i*Mw + len(t1) - 1)
			}
		}
		z.norm()
	} else {
		z = z.make(len(x) + len(y))
		clear(z)
		for i := range K {
			// Divide coefficient i by K (shift by -k) and add it at word i*Mw.
			// Coefficients starting at or beyond len(z), and words extending
			// past it, are dropped; debug mode still records the former.
			if j := i * int(Mw); j < len(z) || debug {
				t1.shift(zv.at(i), -k)
				zi := t1.norm()
				if debug {
					debugZVK.at(i).set(t1)
					if j >= len(z) {
						continue
					}
				}
				addTo(z[j:], zi[:min(len(zi), len(z)-j)])
			}
		}
	}

	// Debug mode: double-check answer and print trace on failure.
	if debug {
		z1 := make(nat, len(x)+len(y))
		basicMul(z1, x, y)
		z1 = z1.norm()
		if mod {
			t := make(nega, len(z))
			t.setNat(z1)
			z1 = t.norm()
		}
		if z1.cmp(z.norm()) != 0 {
			print("fft wrong\n")
			print("ivy -f natmul.ivy <<EOF\n")
			if mod {
				print("fftmod=1\n")
			} else {
				print("fftmod=0\n")
			}
			print("W=", _W, "\n")
			print("N=", int64(Nw)*_W, "\n")
			print("k=", k, "\n")
			print("K=", K, "\n")
			print("M=", int64(Mw)*_W, "\n")
			print("n=", int64(nw)*_W, "\n")
			print("ω=", ω, "/2\n")
			print("θ=", twist, "/2\n")
			print("wid=", wid, "\n")
			trace("x", &Int{abs: x})
			trace("y", &Int{abs: y})
			traceNegaVec("xv", debugXV)
			traceNegaVec("yv", debugYV)
			traceNegaVec("xf", debugXF)
			traceNegaVec("yf", debugYF)
			traceNegaVec("zf", debugZF)
			traceNegaVec("zv", debugZV)
			traceNegaVec("zvk", debugZVK)
			trace("z", &Int{abs: z})
			trace("z1", &Int{abs: z1})
			print("debugFFT 0\n")
			print("EOF\n")
			panic("fft")
		}
	}
}

// traceNegaVec prints a debugging trace line name = v.
func traceNegaVec(name string, v *negaVec) {
	print(name, "=(")
	for i := range v.len {
		x := v.at(i)
		if i > 0 {
			print(", ")
		}
		print(ifmt(x.debugInt()))
	}
	print(")\n")
}

// cloneNegaVec returns a copy of v for debugging.
// Its storage comes from a stack obtained with getStack that is never freed.
func cloneNegaVec(v *negaVec) *negaVec {
	out := newNegaVec(getStack(), v.len, v.wid)
	copy(out.words, v.words)
	return out
}

// cloneNat returns a newly allocated copy of x.
func cloneNat(x nat) nat {
	z := make(nat, len(x))
	copy(z, x)
	return z
}
// Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Multiplication. package big // Algorithmic thresholds; see calibrate_test.go. // //go:generate go run toomgen.go -o nattoom.go 3 4 8 var ( // Multiplication fftThreshold = 5000 toom12Threshold = 1000 toom8Threshold = 500 toom4Threshold = 250 toom3Threshold = 150 karatsubaThreshold = 40 // Squaring fftSqrThreshold = 4000 toom12SqrThreshold = 1000 toom8SqrThreshold = 500 toom4SqrThreshold = 250 toom3SqrThreshold = 150 karatsubaSqrThreshold = 80 basicSqrThreshold = 12 ) // mul sets z = x*y, using stk for temporary storage. // The caller may pass stk == nil to request that mul obtain and release one itself. func (z nat) mul(stk *stack, x, y nat) nat { m := len(x) n := len(y) switch { case m < n: return z.mul(stk, y, x) case m == 0 || n == 0: return z[:0] case n == 1: return z.mulAddWW(x, y[0], 0) case m == n && &x[0] == &y[0]: return z.sqr(stk, x) } // m >= n > 1 // determine if z can be reused if alias(z, x) || alias(z, y) { z = nil // z is an alias for x or y - cannot reuse } z = z.make(m + n) // use basic multiplication if the numbers are small if n < karatsubaThreshold { basicMul(z, x, y) return z.norm() } if stk == nil { stk = getStack() defer stk.free() } // Let x = x1:x0 where x0 is the same length as y. // Compute z = x0*y and then add in x1*y in sections // if needed. 
if n >= fftThreshold { fftMulMod(stk, z, x[:n], y, false) } else if n >= toom12Threshold { toom12(stk, z[:2*n], x[:n], y) } else if n >= toom8Threshold { toom8(stk, z[:2*n], x[:n], y) } else if n >= toom4Threshold { toom4(stk, z[:2*n], x[:n], y) } else if n >= toom3Threshold { toom3(stk, z[:2*n], x[:n], y) } else { karatsuba(stk, z[:2*n], x[:n], y) } if n < m { clear(z[2*n:]) defer stk.restore(stk.save()) t := stk.nat(2 * n) for i := n; i < m; i += n { t = t.mul(stk, x[i:min(i+n, len(x))], y) addTo(z[i:], t) } } return z.norm() } // sqr sets z = x*x, using stk for temporary storage. // The caller may pass stk == nil to request that sqr obtain and release one itself. func (z nat) sqr(stk *stack, x nat) nat { n := len(x) switch { case n == 0: return z[:0] case n == 1: d := x[0] z = z.make(2) z[1], z[0] = mulWW(d, d) return z.norm() } if alias(z, x) { z = nil // z is an alias for x - cannot reuse } z = z.make(2 * n) if n < basicSqrThreshold && n < karatsubaSqrThreshold { basicMul(z, x, x) return z.norm() } if stk == nil { stk = getStack() defer stk.free() } if n < karatsubaSqrThreshold { basicSqr(stk, z, x) return z.norm() } if n >= fftSqrThreshold { fftMulMod(stk, z, x, x, false) } else if n >= toom8SqrThreshold { toom8Sqr(stk, z, x) } else if n >= toom4SqrThreshold { toom4Sqr(stk, z, x) } else if n >= toom3SqrThreshold { toom3Sqr(stk, z, x) } else { karatsubaSqr(stk, z, x) } return z.norm() } // basicSqr sets z = x*x and is asymptotically faster than basicMul // by about a factor of 2, but slower for small arguments due to overhead. // Requirements: len(x) > 0, len(z) == 2*len(x) // The (non-normalized) result is placed in z. 
func basicSqr(stk *stack, z, x nat) {
	n := len(x)
	if n < basicSqrThreshold {
		// Too small to benefit; stk is not used on this path.
		basicMul(z, x, x)
		return
	}

	defer stk.restore(stk.save())
	t := stk.nat(2 * n)
	clear(t)
	z[1], z[0] = mulWW(x[0], x[0]) // the initial square
	for i := 1; i < n; i++ {
		d := x[i]
		// z collects the squares x[i] * x[i]
		z[2*i+1], z[2*i] = mulWW(d, d)
		// t collects the products x[i] * x[j] where j < i
		t[2*i] = addMulVVWW(t[i:2*i], t[i:2*i], x[0:i], d, 0)
	}
	t[2*n-1] = lshVU(t[1:2*n-1], t[1:2*n-1], 1) // double the j < i products
	addVV(z, z, t)                              // combine the result
}

// mulAddWW returns z = x*y + r.
func (z nat) mulAddWW(x nat, y, r Word) nat {
	m := len(x)
	if m == 0 || y == 0 {
		return z.setWord(r) // result is r
	}
	// m > 0

	z = z.make(m + 1)
	z[m] = mulAddVWW(z[0:m], x, y, r)

	return z.norm()
}

// basicMul multiplies x and y and leaves the result in z.
// The (non-normalized) result is placed in z[0 : len(x) + len(y)].
func basicMul(z, x, y nat) {
	clear(z[0 : len(x)+len(y)]) // initialize z
	for i, d := range y {
		if d != 0 {
			z[len(x)+i] = addMulVVWW(z[i:i+len(x)], z[i:i+len(x)], x, d, 0)
		}
	}
}

// karatsuba multiplies x and y,
// writing the (non-normalized) result to z.
// x and y must have the same length n,
// and z must have length twice that.
func karatsuba(stk *stack, z, x, y nat) {
	n := len(y)
	if len(x) != n || len(z) != 2*n {
		panic("bad karatsuba length")
	}

	// Fall back to basic algorithm if small enough.
	if n < karatsubaThreshold || n < 2 {
		basicMul(z, x, y)
		return
	}

	// Let the notation x1:x0 denote the nat (x1<<N)+x0 for some N,
	// and similarly z2:z1:z0 = (z2<<2N)+(z1<<N)+z0.
	//
	// (Note that z0, z1, z2 might be ≥ 2**N, in which case the high
	// bits of, say, z0 are being added to the low bits of z1 in this notation.)
	//
	// Karatsuba multiplication is based on the observation that
	//
	//	x1:x0 * y1:y0 = x1*y1:(x0*y1+y0*x1):x0*y0
	//	              = x1*y1:((x0-x1)*(y1-y0)+x1*y1+x0*y0):x0*y0
	//
	// The second form uses only three half-width multiplications
	// instead of the four that the straightforward first form does.
	//
	// We call the three pieces z0, z1, z2:
	//
	//	z0 = x0*y0
	//	z2 = x1*y1
	//	z1 = (x0-x1)*(y1-y0) + z0 + z2

	n2 := (n + 1) / 2
	x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()}
	y0, y1 := &Int{abs: y[:n2].norm()}, &Int{abs: y[n2:].norm()}
	z0 := &Int{abs: z[0 : 2*n2]}
	z2 := &Int{abs: z[2*n2:]}

	// Allocate temporary storage for z1. z is cleared below before z0 and z2
	// are written, so borrow its words as scratch: tx in z[n2:2*n2] (inside
	// z0's region) and ty in z[2*n2:3*n2] (inside z2's region).
	// 3*n2 ≤ 2*n holds for every n ≥ 2.
	tx := &Int{abs: z[n2 : 2*n2]}
	ty := &Int{abs: z[2*n2 : 3*n2]}

	defer stk.restore(stk.save())
	z1 := &Int{abs: stk.nat(2*n2 + 1)}
	tx.Sub(x0, x1)
	ty.Sub(y1, y0)
	z1.mul(stk, tx, ty)

	// tx and ty are dead now; reuse all of z for z0 and z2.
	clear(z)
	z0.mul(stk, x0, y0)
	z2.mul(stk, x1, y1)
	z1.Add(z1, z0)
	z1.Add(z1, z2)
	addTo(z[n2:], z1.abs)

	// Debug mode: double-check answer and print trace on failure.
	const debug = false
	if debug {
		zz := make(nat, len(z))
		basicMul(zz, x, y)
		if z.cmp(zz) != 0 {
			// All the temps were aliased to z and gone. Recompute.
			// (tx and ty are recomputed with both signs flipped;
			// their product is unchanged.)
			z0 = new(Int)
			z0.mul(stk, x0, y0)
			tx = new(Int).Sub(x1, x0)
			ty = new(Int).Sub(y0, y1)
			z2 = new(Int)
			z2.mul(stk, x1, y1)
			print("karatsuba wrong\n")
			trace("x ", &Int{abs: x})
			trace("y ", &Int{abs: y})
			trace("z ", &Int{abs: z})
			trace("zz", &Int{abs: zz})
			trace("x0", x0)
			trace("x1", x1)
			trace("y0", y0)
			trace("y1", y1)
			trace("tx", tx)
			trace("ty", ty)
			trace("z0", z0)
			trace("z1", z1)
			trace("z2", z2)
			panic("karatsuba")
		}
	}

}

// karatsubaSqr squares x,
// writing the (non-normalized) result to z.
// z must have length 2*len(x).
// It is analogous to [karatsuba] but can run faster
// knowing both multiplicands are the same value.
func karatsubaSqr(stk *stack, z, x nat) { n := len(x) if len(z) != 2*n { panic("bad karatsubaSqr length") } if n < karatsubaSqrThreshold || n < 2 { basicSqr(stk, z, x) return } // Recall that for karatsuba we want to compute: // // x1:x0 * y1:y0 = x1y1:(x0y1+y0x1):x0y0 // = x1y1:((x0-x1)*(y1-y0)+x1y1+x0y0):x0y0 // = z2:z1:z0 // where: // // z0 = x0y0 // z2 = x1y1 // z1 = (x0-x1)*(y1-y0) + z0 + z2 // // When x = y, these simplify to: // // z0 = x0² // z2 = x1² // z1 = z0 + z2 - (x0-x1)² n2 := (n + 1) / 2 x0, x1 := &Int{abs: x[:n2].norm()}, &Int{abs: x[n2:].norm()} z0 := &Int{abs: z[0 : 2*n2]} z2 := &Int{abs: z[2*n2:]} // Allocate temporary storage for z1; repurpose z0 to hold tx. defer stk.restore(stk.save()) z1 := &Int{abs: stk.nat(2*n2 + 1)} tx := &Int{abs: z[0:n2]} tx.Sub(x0, x1) z1.abs = z1.abs.sqr(stk, tx.abs) z1.neg = true clear(z) z0.abs = z0.abs.sqr(stk, x0.abs) z2.abs = z2.abs.sqr(stk, x1.abs) z1.Add(z1, z0) z1.Add(z1, z2) addTo(z[n2:], z1.abs) // Debug mode: double-check answer and print trace on failure. const debug = false if debug { zz := make(nat, len(z)) basicSqr(stk, zz, x) if z.cmp(zz) != 0 { // All the temps were aliased to z and gone. Recompute. tx = new(Int).Sub(x0, x1) z0 = new(Int).Mul(x0, x0) z2 = new(Int).Mul(x1, x1) z1 = new(Int).Mul(tx, tx) z1.Neg(z1) z1.Add(z1, z0) z1.Add(z1, z2) print("karatsubaSqr wrong\n") trace("x ", &Int{abs: x}) trace("z ", &Int{abs: z}) trace("zz", &Int{abs: zz}) trace("x0", x0) trace("x1", x1) trace("z0", z0) trace("z1", z1) trace("z2", z2) panic("karatsubaSqr") } } } // ifmt returns the debug formatting of the Int x: 0xHEX. func ifmt(x *Int) string { neg, s, t := "", x.Text(16), "" if s == "" { // happens for denormalized zero s = "0" } if s[0] == '-' { neg, s = "-", s[1:] } // Add _ between words. const D = _W / 4 // digits per chunk for len(s) > D { s, t = s[:len(s)-D], "_"+s[len(s)-D:]+t } return neg + "0x" + s + t } // trace prints a single debug value. 
func trace(name string, x *Int) { print(name, "=", ifmt(x), "\n") } // addSub sets za = x+y, zs = x-y. // It is permitted for za and zs to alias either of x or y (or both). func addSub(za, zs, x, y *Int) { // Copy x and y fields so that writing to za/zs does not modify them. // Still need to be careful about aliasing of the abs backing arrays. xNeg := x.neg yNeg := y.neg xAbs := x.abs yAbs := y.abs switch xAbs.cmp(yAbs) { case 0: // |x| == |y|, so one of za, zs is 0. if xNeg == yNeg { za.Add(x, y) zs.abs = zs.abs[:0] zs.neg = false } else { zs.Sub(x, y) za.abs = za.abs[:0] } return case +1: // |x| > |y|, so sign of x dominates za.neg = xNeg zs.neg = xNeg case -1: // |y| > |x|, so sign of ±y dominates za.neg = yNeg zs.neg = !yNeg xAbs, yAbs = yAbs, xAbs // now xAbs ≥ yAbs } if xNeg == yNeg { addSubAbs(za, zs, xAbs, yAbs) } else { addSubAbs(zs, za, xAbs, yAbs) } } // addSubAbs sets za.abs = x+y and zs.abs = x-y, // clearing neg when abs is 0. func addSubAbs(za, zs *Int, x, y nat) { zaa := za.abs.make(len(x) + 1) zsa := zs.abs.make(len(x)) yn := len(y) ca, cs := addSubVV(zaa[:yn], zsa[:yn], x[:yn], y) if &zaa[0] == &x[0] { subVW(zsa[yn:], x[yn:], cs) zaa[len(x)] = addVW(zaa[yn:len(x)], x[yn:], ca) } else { zaa[len(x)] = addVW(zaa[yn:len(x)], x[yn:], ca) subVW(zsa[yn:], x[yn:], cs) } if za.abs = zaa.norm(); len(za.abs) == 0 { za.neg = false } if zs.abs = zsa.norm(); len(zs.abs) == 0 { zs.neg = false } } // addLsh sets z = x + y<<s, 1 ≤ s ≤ _W-1. func addLsh(z, x, y *Int, s uint) { if x.neg || y.neg { panic("addLsh negative values") } if s <= 0 || s >= _W { panic("addLsh shift") } xa := x.abs ya := y.abs za := z.abs xn := len(x.abs) yn := len(y.abs) zn := max(xn, yn) + 1 za = za.make(zn) // Set z[0:] = x[0:] + y<<s until x[0:] or y runs out after k words (k = min(xn, yn)). // Then either z[k:] = x[k:] + c or z[k:] = y[k:]<<s + c. 
if xn >= yn { c := addLshVVU(za[:yn], xa[0:yn], ya, s) za[len(za)-1] = addVW(za[yn:len(za)-1], xa[yn:], c) } else { c0 := addLshVVU(za[:xn], xa[0:], ya[:xn], s) // Could fuse lshVU and addVW into lshAddVUW but not a hot spot, // especially since addVW will almost always stop early. c1 := lshVU(za[xn:len(za)-1], ya[xn:], s) c2 := addVW(za[xn:len(za)-1], za[xn:len(za)-1], c0) // c1+c2 does not overflow: c1 < 1<<(_W-1) = _B/2, and c2 ≤ 1. za[len(za)-1] = c1 + c2 } z.abs = za.norm() z.neg = false } // subLsh sets z = x - y<<s, 1 ≤ s ≤ _W-1, assuming z ≥ 0. func subLsh(z, x, y *Int, s uint) { if x.neg || y.neg { panic("subLsh negative values") } if s <= 0 || s >= _W { panic("subLsh shift") } xa := x.abs ya := y.abs za := z.abs xn := len(x.abs) yn := len(y.abs) if yn > xn { panic("subLsh out of range") } za = za.make(xn) c := subLshVVU(za[:yn], xa[:yn], ya, s) subVW(za[yn:], xa[yn:], c) z.abs = za.norm() z.neg = false } // addMul sets z = x + y*m. func addMul(z, x, y *Int, m Word) { if x.neg || y.neg { panic("addMul negative values") } xa := x.abs ya := y.abs za := z.abs xn := len(x.abs) yn := len(y.abs) if xn >= yn { za = za.make(xn + 1) c := addMulVVWW(za[:yn], xa[:yn], ya, m, 0) za[xn] = addVW(za[yn:xn], xa[yn:], c) } else { za = za.make(yn + 1) c := addMulVVWW(za[:xn], xa, ya[:xn], m, 0) za[yn] = mulAddVWW(za[xn:yn], ya[xn:], m, c) } z.abs = za.norm() z.neg = false } // subMul sets z = x - y*m, assuming z ≥ 0. 
func subMul(z, x, y *Int, m Word) { if x.neg || y.neg { z.abs = z.abs[:0] z.neg = false return panic("subMul negative values") } xa := x.abs ya := y.abs za := z.abs xn := len(x.abs) yn := len(y.abs) if yn > xn { z.abs = z.abs[:0] z.neg = false return panic("subMul out of range") } za = za.make(xn) c := subMulVVWW(za[:yn], xa[:yn], ya, m, 0) c = subVW(za[yn:], xa[yn:], c) if c != 0 { z.abs = z.abs[:0] z.neg = false return println("SM", len(za), xn, yn, m, c) panic("subMul out of range") } z.abs = za.norm() z.neg = false } // mdiv sets z = x/d when x%d==0 and d is odd. // mdiv does a modular division of x by d in the ring ℤ[_Bⁿ] for n=len(x). // When z is evenly divisible by d, this has the usual integer result. // The additional parameter inv must be the modular inverse of d in ℤ[_Bⁿ], // which can be computed using, for example: // // ivy -f natmul.ivy -e 'inverse 5' // // Note that: // // d * inv = 1 (mod 2⁶⁴) // ⇒ d * inv = 1 (mod 2³²) // ⇒ uint32(d) * uint32(inv) = 1 (mod 2³²) // // so for d < 2³² it works on both 32- and 64-bit platforms to pass inv & _M, // where inv is the 64-bit inverse; on 32-bit systems, the mask will truncate // to the 32-bit inverse. func mdiv(z, x *Int, d, inv Word) { za := z.abs.make(len(x.abs)) mdivVWW(za, x.abs, d, inv) z.abs = za.norm() z.neg = x.neg } // mdiv64 is like mdiv but accepts a 64-bit divisor and its 32-bit factorization d = p*q. // dinv is the inverse of d mod 2⁶⁴ and pinv, qinv are the inverses of p, q mod 2³². func mdiv64(z, x *Int, d, inv uint64, p, pinv, q, qinv uint32) { if _W == 64 { mdiv(z, x, Word(d), Word(inv)) } else { mdiv(z, x, Word(p), Word(pinv)) mdiv(z, z, Word(q), Word(qinv)) } } // mdivVV_g is the actual modular division in mdiv. // It requires len(z) == len(x) and d * inv = 1 mod _B. // It sets z and returns c such that x + c * _B**len(z) = d*z, // which is to say z = x/d (mod _B**len(z)). // See Brent and Zimmermann, Modern Computer Arithmetic, section 1.4.7. 
func mdivVWW_g(z, x []Word, d, inv Word) (c Word) { if len(z) != len(x) { panic("mdiv len") } if d*inv != 1 { panic("mdiv inv") } for i := 0; i < len(z); i++ { s, cs := subWW(x[i], c, 0) zi := s * inv z[i] = zi cm, _ := mulWW(zi, d) c = cs + cm } return c }
// Copyright 2025 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Code generated by 'go generate'. DO NOT EDIT. package big func toom3(stk *stack, z, x, y nat) { const debug = false // avoid accidental slicing beyond cap z = z[:len(z):len(z)] x = x[:len(x):len(x)] y = y[:len(y):len(y)] n := len(y) if len(x) != n || len(z) != 2*n { panic("bad toom3 len") } // Fall back to simpler algorithm if small enough or too small. if n < toom3Threshold || n < 4*3 { karatsuba(stk, z, x, y) return } defer stk.restore(stk.save()) p := (n + 3 - 1) / 3 x0 := &Int{abs: x[0*p : 1*p].norm()} x1 := &Int{abs: x[1*p : 2*p].norm()} x2 := &Int{abs: x[2*p:].norm()} y0 := &Int{abs: y[0*p : 1*p].norm()} y1 := &Int{abs: y[1*p : 2*p].norm()} y2 := &Int{abs: y[2*p:].norm()} z0 := &Int{abs: stk.nat(2*p + (5+_W-1)/_W + 1)} z1 := &Int{abs: stk.nat(2*p + (5+_W-1)/_W + 1)} z2 := &Int{abs: stk.nat(2*p + (5+_W-1)/_W + 1)} z3 := &Int{abs: stk.nat(2*p + (5+_W-1)/_W + 1)} z4 := &Int{abs: stk.nat(2*p + (5+_W-1)/_W + 1)} // z1, z2 = x(-1)*y(-1), x(1)*y(1) z0.Add(x0, x2) addSub(z0, z2, z0, x1) z4.Add(y0, y2) addSub(z4, z3, z4, y1) z1.mul(stk, z2, z3) z2.mul(stk, z0, z4) // z3 = x(2)*y(2) addLsh(z0, x1, x2, 1) addLsh(z0, x0, z0, 1) addLsh(z4, y1, y2, 1) addLsh(z4, y0, z4, 1) z3.mul(stk, z0, z4) // z0 = x(0)*y(0) z0.mul(stk, x0, y0) // z4 = x(∞)*y(∞) z4.mul(stk, x2, y2) var dz0, dz1, dz2, dz3, dz4 *Int if debug { dz0 = new(Int).Set(z0) dz1 = new(Int).Set(z1) dz2 = new(Int).Set(z2) dz3 = new(Int).Set(z3) dz4 = new(Int).Set(z4) } toom3Interp(z, p, z0, z1, z2, z3, z4) if debug { zz := make(nat, len(z)) karatsuba(stk, zz, x, y) if z.cmp(zz) != 0 { print("toom3 wrong\n") print("ivy -f natmul.ivy <<EOF\n") print("W=", _W, "\n") print("p=", p, "\n") trace("z", &Int{abs: z}) trace("zz", &Int{abs: zz}) trace("x", &Int{abs: x}) print("xv=(", ifmt(x0), ", ", ifmt(x1), ", ", ifmt(x2), ")\n") trace("y", &Int{abs: y}) 
print("yv=(", ifmt(y0), ", ", ifmt(y1), ", ", ifmt(y2), ")\n") print("zv=(", ifmt(dz0), ", ", ifmt(dz1), ", ", ifmt(dz2), ", ", ifmt(dz3), ", ", ifmt(dz4), ")\n") print("izv=(", ifmt(z0), ", ", ifmt(z1), ", ", ifmt(z2), ", ", ifmt(z3), ", ", ifmt(z4), ")\n") print("debugToom 3\n") print("EOF\n") panic("toom3") } } } func toom3Sqr(stk *stack, z, x nat) { const debug = false // avoid accidental slicing beyond cap z = z[:len(z):len(z)] x = x[:len(x):len(x)] n := len(x) if len(z) != 2*n { panic("bad toom3Sqr len") } // Fall back to simpler algorithm if small enough or too small. if n < toom3SqrThreshold || n < 4*3 { karatsubaSqr(stk, z, x) return } defer stk.restore(stk.save()) p := (n + 3 - 1) / 3 x0 := &Int{abs: x[0*p : 1*p].norm()} x1 := &Int{abs: x[1*p : 2*p].norm()} x2 := &Int{abs: x[2*p:].norm()} z0 := &Int{abs: stk.nat(2*p + (5+_W-1)/_W + 1)} z1 := &Int{abs: stk.nat(2*p + (5+_W-1)/_W + 1)} z2 := &Int{abs: stk.nat(2*p + (5+_W-1)/_W + 1)} z3 := &Int{abs: stk.nat(2*p + (5+_W-1)/_W + 1)} z4 := &Int{abs: stk.nat(2*p + (5+_W-1)/_W + 1)} // z1, z2 = x(-1)², x(1)² z0.Add(x0, x2) addSub(z0, z4, z0, x1) z1.mul(stk, z4, z4) z2.mul(stk, z0, z0) // z3 = x(2)² addLsh(z0, x1, x2, 1) addLsh(z0, x0, z0, 1) z3.mul(stk, z0, z0) // z0 = x(0)² z0.mul(stk, x0, x0) // z4 = x(∞)² z4.mul(stk, x2, x2) var dz0, dz1, dz2, dz3, dz4 *Int if debug { dz0 = new(Int).Set(z0) dz1 = new(Int).Set(z1) dz2 = new(Int).Set(z2) dz3 = new(Int).Set(z3) dz4 = new(Int).Set(z4) } toom3Interp(z, p, z0, z1, z2, z3, z4) if debug { zz := make(nat, len(z)) karatsuba(stk, zz, x, x) if z.cmp(zz) != 0 { print("toom3 wrong\n") print("ivy -f natmul.ivy <<EOF\n") print("W=", _W, "\n") print("p=", p, "\n") trace("z", &Int{abs: z}) trace("zz", &Int{abs: zz}) trace("x", &Int{abs: x}) print("xv=(", ifmt(x0), ", ", ifmt(x1), ", ", ifmt(x2), ")\n") print("zv=(", ifmt(dz0), ", ", ifmt(dz1), ", ", ifmt(dz2), ", ", ifmt(dz3), ", ", ifmt(dz4), ")\n") print("izv=(", ifmt(z0), ", ", ifmt(z1), ", ", ifmt(z2), ", ", ifmt(z3), ", ", 
ifmt(z4), ")\n") print("debugToomSqr 3\n") print("EOF\n") panic("toom3Sqr") } } } func toom3Interp(z nat, p int, z0, z1, z2, z3, z4 *Int) { addSub(z2, z1, z2, z1) subLsh(z2, z2, z0, 1) z3.Sub(z3, z0) z3.Sub(z3, z1) subLsh(z3, z3, z2, 1) z1.Rsh(z1, 1) z2.Rsh(z2, 1) z3.Rsh(z3, 1) mdiv(z3, z3, 3, 0xaaaaaaaaaaaaaaab&_M) subLsh(z3, z3, z4, 1) z2.Sub(z2, z4) z1.Sub(z1, z3) clear(z) addTo(z[0*p:], z0.abs) addTo(z[1*p:], z1.abs) addTo(z[2*p:], z2.abs) addTo(z[3*p:], z3.abs) addTo(z[4*p:], z4.abs) } func toom4(stk *stack, z, x, y nat) { const debug = false // avoid accidental slicing beyond cap z = z[:len(z):len(z)] x = x[:len(x):len(x)] y = y[:len(y):len(y)] n := len(y) if len(x) != n || len(z) != 2*n { panic("bad toom4 len") } // Fall back to simpler algorithm if small enough or too small. if n < toom4Threshold || n < 4*4 { toom3(stk, z, x, y) return } defer stk.restore(stk.save()) p := (n + 4 - 1) / 4 x0 := &Int{abs: x[0*p : 1*p].norm()} x1 := &Int{abs: x[1*p : 2*p].norm()} x2 := &Int{abs: x[2*p : 3*p].norm()} x3 := &Int{abs: x[3*p:].norm()} y0 := &Int{abs: y[0*p : 1*p].norm()} y1 := &Int{abs: y[1*p : 2*p].norm()} y2 := &Int{abs: y[2*p : 3*p].norm()} y3 := &Int{abs: y[3*p:].norm()} z0 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} z1 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} z2 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} z3 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} z4 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} z5 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} z6 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} // z1, z2 = x(-1)*y(-1), x(1)*y(1) z0.Add(x0, x2) z2.Add(x1, x3) addSub(z0, z2, z0, z2) z6.Add(y0, y2) z5.Add(y1, y3) addSub(z6, z5, z6, z5) z1.mul(stk, z2, z5) z2.mul(stk, z0, z6) // z3, z4 = x(-2)*y(-2), x(2)*y(2) addLsh(z0, x0, x2, 2) addLsh(z4, x1, x3, 2) z4.Lsh(z4, 1) addSub(z0, z4, z0, z4) addLsh(z6, y0, y2, 2) addLsh(z5, y1, y3, 2) z5.Lsh(z5, 1) addSub(z6, z5, z6, z5) z3.mul(stk, z4, z5) z4.mul(stk, z0, z6) // z5 = x(4)*y(4) addLsh(z0, x2, x3, 2) 
addLsh(z0, x1, z0, 2) addLsh(z0, x0, z0, 2) addLsh(z6, y2, y3, 2) addLsh(z6, y1, z6, 2) addLsh(z6, y0, z6, 2) z5.mul(stk, z0, z6) // z0 = x(0)*y(0) z0.mul(stk, x0, y0) // z6 = x(∞)*y(∞) z6.mul(stk, x3, y3) var dz0, dz1, dz2, dz3, dz4, dz5, dz6 *Int if debug { dz0 = new(Int).Set(z0) dz1 = new(Int).Set(z1) dz2 = new(Int).Set(z2) dz3 = new(Int).Set(z3) dz4 = new(Int).Set(z4) dz5 = new(Int).Set(z5) dz6 = new(Int).Set(z6) } toom4Interp(z, p, z0, z1, z2, z3, z4, z5, z6) if debug { zz := make(nat, len(z)) toom3(stk, zz, x, y) if z.cmp(zz) != 0 { print("toom4 wrong\n") print("ivy -f natmul.ivy <<EOF\n") print("W=", _W, "\n") print("p=", p, "\n") trace("z", &Int{abs: z}) trace("zz", &Int{abs: zz}) trace("x", &Int{abs: x}) print("xv=(", ifmt(x0), ", ", ifmt(x1), ", ", ifmt(x2), ", ", ifmt(x3), ")\n") trace("y", &Int{abs: y}) print("yv=(", ifmt(y0), ", ", ifmt(y1), ", ", ifmt(y2), ", ", ifmt(y3), ")\n") print("zv=(", ifmt(dz0), ", ", ifmt(dz1), ", ", ifmt(dz2), ", ", ifmt(dz3), ", ", ifmt(dz4), ", ", ifmt(dz5), ", ", ifmt(dz6), ")\n") print("izv=(", ifmt(z0), ", ", ifmt(z1), ", ", ifmt(z2), ", ", ifmt(z3), ", ", ifmt(z4), ", ", ifmt(z5), ", ", ifmt(z6), ")\n") print("debugToom 4\n") print("EOF\n") panic("toom4") } } } func toom4Sqr(stk *stack, z, x nat) { const debug = false // avoid accidental slicing beyond cap z = z[:len(z):len(z)] x = x[:len(x):len(x)] n := len(x) if len(z) != 2*n { panic("bad toom4Sqr len") } // Fall back to simpler algorithm if small enough or too small. 
if n < toom4SqrThreshold || n < 4*4 { toom3Sqr(stk, z, x) return } defer stk.restore(stk.save()) p := (n + 4 - 1) / 4 x0 := &Int{abs: x[0*p : 1*p].norm()} x1 := &Int{abs: x[1*p : 2*p].norm()} x2 := &Int{abs: x[2*p : 3*p].norm()} x3 := &Int{abs: x[3*p:].norm()} z0 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} z1 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} z2 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} z3 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} z4 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} z5 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} z6 := &Int{abs: stk.nat(2*p + (13+_W-1)/_W + 1)} // z1, z2 = x(-1)², x(1)² z0.Add(x0, x2) z6.Add(x1, x3) addSub(z0, z6, z0, z6) z1.mul(stk, z6, z6) z2.mul(stk, z0, z0) // z3, z4 = x(-2)², x(2)² addLsh(z0, x0, x2, 2) addLsh(z6, x1, x3, 2) z6.Lsh(z6, 1) addSub(z0, z6, z0, z6) z3.mul(stk, z6, z6) z4.mul(stk, z0, z0) // z5 = x(4)² addLsh(z0, x2, x3, 2) addLsh(z0, x1, z0, 2) addLsh(z0, x0, z0, 2) z5.mul(stk, z0, z0) // z0 = x(0)² z0.mul(stk, x0, x0) // z6 = x(∞)² z6.mul(stk, x3, x3) var dz0, dz1, dz2, dz3, dz4, dz5, dz6 *Int if debug { dz0 = new(Int).Set(z0) dz1 = new(Int).Set(z1) dz2 = new(Int).Set(z2) dz3 = new(Int).Set(z3) dz4 = new(Int).Set(z4) dz5 = new(Int).Set(z5) dz6 = new(Int).Set(z6) } toom4Interp(z, p, z0, z1, z2, z3, z4, z5, z6) if debug { zz := make(nat, len(z)) toom3(stk, zz, x, x) if z.cmp(zz) != 0 { print("toom4 wrong\n") print("ivy -f natmul.ivy <<EOF\n") print("W=", _W, "\n") print("p=", p, "\n") trace("z", &Int{abs: z}) trace("zz", &Int{abs: zz}) trace("x", &Int{abs: x}) print("xv=(", ifmt(x0), ", ", ifmt(x1), ", ", ifmt(x2), ", ", ifmt(x3), ")\n") print("zv=(", ifmt(dz0), ", ", ifmt(dz1), ", ", ifmt(dz2), ", ", ifmt(dz3), ", ", ifmt(dz4), ", ", ifmt(dz5), ", ", ifmt(dz6), ")\n") print("izv=(", ifmt(z0), ", ", ifmt(z1), ", ", ifmt(z2), ", ", ifmt(z3), ", ", ifmt(z4), ", ", ifmt(z5), ", ", ifmt(z6), ")\n") print("debugToomSqr 4\n") print("EOF\n") panic("toom4Sqr") } } } func toom4Interp(z nat, p int, z0, z1, 
z2, z3, z4, z5, z6 *Int) { addSub(z2, z1, z2, z1) addSub(z4, z3, z4, z3) subLsh(z2, z2, z0, 1) subLsh(z3, z3, z1, 1) subLsh(z4, z4, z0, 1) subLsh(z4, z4, z2, 2) z5.Sub(z5, z0) subLsh(z5, z5, z1, 1) subLsh(z5, z5, z2, 3) subMul(z5, z5, z3, 5) subMul(z5, z5, z4, 10) z1.Rsh(z1, 1) z2.Rsh(z2, 1) z3.Rsh(z3, 2) mdiv(z3, z3, 3, 0xaaaaaaaaaaaaaaab&_M) z4.Rsh(z4, 3) mdiv(z4, z4, 3, 0xaaaaaaaaaaaaaaab&_M) z5.Rsh(z5, 4) mdiv(z5, z5, 45, 0x4fa4fa4fa4fa4fa5&_M) subLsh(z5, z5, z6, 2) subMul(z4, z4, z6, 5) subMul(z3, z3, z5, 5) z2.Sub(z2, z6) z2.Sub(z2, z4) z1.Sub(z1, z5) z1.Sub(z1, z3) clear(z) addTo(z[0*p:], z0.abs) addTo(z[1*p:], z1.abs) addTo(z[2*p:], z2.abs) addTo(z[3*p:], z3.abs) addTo(z[4*p:], z4.abs) addTo(z[5*p:], z5.abs) addTo(z[6*p:], z6.abs) } func toom8(stk *stack, z, x, y nat) { const debug = false // avoid accidental slicing beyond cap z = z[:len(z):len(z)] x = x[:len(x):len(x)] y = y[:len(y):len(y)] n := len(y) if len(x) != n || len(z) != 2*n { panic("bad toom8 len") } // Fall back to simpler algorithm if small enough or too small. 
if n < toom8Threshold || n < 4*8 { toom4(stk, z, x, y) return } defer stk.restore(stk.save()) p := (n + 8 - 1) / 8 x0 := &Int{abs: x[0*p : 1*p].norm()} x1 := &Int{abs: x[1*p : 2*p].norm()} x2 := &Int{abs: x[2*p : 3*p].norm()} x3 := &Int{abs: x[3*p : 4*p].norm()} x4 := &Int{abs: x[4*p : 5*p].norm()} x5 := &Int{abs: x[5*p : 6*p].norm()} x6 := &Int{abs: x[6*p : 7*p].norm()} x7 := &Int{abs: x[7*p:].norm()} y0 := &Int{abs: y[0*p : 1*p].norm()} y1 := &Int{abs: y[1*p : 2*p].norm()} y2 := &Int{abs: y[2*p : 3*p].norm()} y3 := &Int{abs: y[3*p : 4*p].norm()} y4 := &Int{abs: y[4*p : 5*p].norm()} y5 := &Int{abs: y[5*p : 6*p].norm()} y6 := &Int{abs: y[6*p : 7*p].norm()} y7 := &Int{abs: y[7*p:].norm()} z0 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z1 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z2 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z3 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z4 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z5 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z6 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z7 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z8 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z9 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z10 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z11 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z12 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z13 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z14 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} // z1, z2 = x(-1)*y(-1), x(1)*y(1) z0.Add(x4, x6) z0.Add(z0, x2) z0.Add(z0, x0) z2.Add(x5, x7) z2.Add(z2, x3) z2.Add(z2, x1) addSub(z0, z2, z0, z2) z14.Add(y4, y6) z14.Add(z14, y2) z14.Add(z14, y0) z13.Add(y5, y7) z13.Add(z13, y3) z13.Add(z13, y1) addSub(z14, z13, z14, z13) z1.mul(stk, z2, z13) z2.mul(stk, z0, z14) // z3, z4 = x(-2)*y(-2), x(2)*y(2) addLsh(z0, x4, x6, 2) addLsh(z0, x2, z0, 2) addLsh(z0, x0, z0, 2) addLsh(z4, x5, x7, 2) addLsh(z4, x3, z4, 2) addLsh(z4, x1, z4, 2) z4.Lsh(z4, 1) addSub(z0, z4, z0, z4) addLsh(z14, y4, y6, 2) addLsh(z14, y2, 
z14, 2) addLsh(z14, y0, z14, 2) addLsh(z13, y5, y7, 2) addLsh(z13, y3, z13, 2) addLsh(z13, y1, z13, 2) z13.Lsh(z13, 1) addSub(z14, z13, z14, z13) z3.mul(stk, z4, z13) z4.mul(stk, z0, z14) // z5, z6 = x(-4)*y(-4), x(4)*y(4) addLsh(z0, x4, x6, 4) addLsh(z0, x2, z0, 4) addLsh(z0, x0, z0, 4) addLsh(z6, x5, x7, 4) addLsh(z6, x3, z6, 4) addLsh(z6, x1, z6, 4) z6.Lsh(z6, 2) addSub(z0, z6, z0, z6) addLsh(z14, y4, y6, 4) addLsh(z14, y2, z14, 4) addLsh(z14, y0, z14, 4) addLsh(z13, y5, y7, 4) addLsh(z13, y3, z13, 4) addLsh(z13, y1, z13, 4) z13.Lsh(z13, 2) addSub(z14, z13, z14, z13) z5.mul(stk, z6, z13) z6.mul(stk, z0, z14) // z7, z8 = x(-8)*y(-8), x(8)*y(8) addLsh(z0, x4, x6, 6) addLsh(z0, x2, z0, 6) addLsh(z0, x0, z0, 6) addLsh(z8, x5, x7, 6) addLsh(z8, x3, z8, 6) addLsh(z8, x1, z8, 6) z8.Lsh(z8, 3) addSub(z0, z8, z0, z8) addLsh(z14, y4, y6, 6) addLsh(z14, y2, z14, 6) addLsh(z14, y0, z14, 6) addLsh(z13, y5, y7, 6) addLsh(z13, y3, z13, 6) addLsh(z13, y1, z13, 6) z13.Lsh(z13, 3) addSub(z14, z13, z14, z13) z7.mul(stk, z8, z13) z8.mul(stk, z0, z14) // z9, z10 = x(-16)*y(-16), x(16)*y(16) addLsh(z0, x4, x6, 8) addLsh(z0, x2, z0, 8) addLsh(z0, x0, z0, 8) addLsh(z10, x5, x7, 8) addLsh(z10, x3, z10, 8) addLsh(z10, x1, z10, 8) z10.Lsh(z10, 4) addSub(z0, z10, z0, z10) addLsh(z14, y4, y6, 8) addLsh(z14, y2, z14, 8) addLsh(z14, y0, z14, 8) addLsh(z13, y5, y7, 8) addLsh(z13, y3, z13, 8) addLsh(z13, y1, z13, 8) z13.Lsh(z13, 4) addSub(z14, z13, z14, z13) z9.mul(stk, z10, z13) z10.mul(stk, z0, z14) // z11, z12 = x(-32)*y(-32), x(32)*y(32) addLsh(z0, x4, x6, 10) addLsh(z0, x2, z0, 10) addLsh(z0, x0, z0, 10) addLsh(z12, x5, x7, 10) addLsh(z12, x3, z12, 10) addLsh(z12, x1, z12, 10) z12.Lsh(z12, 5) addSub(z0, z12, z0, z12) addLsh(z14, y4, y6, 10) addLsh(z14, y2, z14, 10) addLsh(z14, y0, z14, 10) addLsh(z13, y5, y7, 10) addLsh(z13, y3, z13, 10) addLsh(z13, y1, z13, 10) z13.Lsh(z13, 5) addSub(z14, z13, z14, z13) z11.mul(stk, z12, z13) z12.mul(stk, z0, z14) // z13 = x(64)*y(64) addLsh(z0, x6, x7, 
6) addLsh(z0, x5, z0, 6) addLsh(z0, x4, z0, 6) addLsh(z0, x3, z0, 6) addLsh(z0, x2, z0, 6) addLsh(z0, x1, z0, 6) addLsh(z0, x0, z0, 6) addLsh(z14, y6, y7, 6) addLsh(z14, y5, z14, 6) addLsh(z14, y4, z14, 6) addLsh(z14, y3, z14, 6) addLsh(z14, y2, z14, 6) addLsh(z14, y1, z14, 6) addLsh(z14, y0, z14, 6) z13.mul(stk, z0, z14) // z0 = x(0)*y(0) z0.mul(stk, x0, y0) // z14 = x(∞)*y(∞) z14.mul(stk, x7, y7) var dz0, dz1, dz2, dz3, dz4, dz5, dz6, dz7, dz8, dz9, dz10, dz11, dz12, dz13, dz14 *Int if debug { dz0 = new(Int).Set(z0) dz1 = new(Int).Set(z1) dz2 = new(Int).Set(z2) dz3 = new(Int).Set(z3) dz4 = new(Int).Set(z4) dz5 = new(Int).Set(z5) dz6 = new(Int).Set(z6) dz7 = new(Int).Set(z7) dz8 = new(Int).Set(z8) dz9 = new(Int).Set(z9) dz10 = new(Int).Set(z10) dz11 = new(Int).Set(z11) dz12 = new(Int).Set(z12) dz13 = new(Int).Set(z13) dz14 = new(Int).Set(z14) } toom8Interp(z, p, z0, z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11, z12, z13, z14) if debug { zz := make(nat, len(z)) toom4(stk, zz, x, y) if z.cmp(zz) != 0 { print("toom8 wrong\n") print("ivy -f natmul.ivy <<EOF\n") print("W=", _W, "\n") print("p=", p, "\n") trace("z", &Int{abs: z}) trace("zz", &Int{abs: zz}) trace("x", &Int{abs: x}) print("xv=(", ifmt(x0), ", ", ifmt(x1), ", ", ifmt(x2), ", ", ifmt(x3), ", ", ifmt(x4), ", ", ifmt(x5), ", ", ifmt(x6), ", ", ifmt(x7), ")\n") trace("y", &Int{abs: y}) print("yv=(", ifmt(y0), ", ", ifmt(y1), ", ", ifmt(y2), ", ", ifmt(y3), ", ", ifmt(y4), ", ", ifmt(y5), ", ", ifmt(y6), ", ", ifmt(y7), ")\n") print("zv=(", ifmt(dz0), ", ", ifmt(dz1), ", ", ifmt(dz2), ", ", ifmt(dz3), ", ", ifmt(dz4), ", ", ifmt(dz5), ", ", ifmt(dz6), ", ", ifmt(dz7), ", ", ifmt(dz8), ", ", ifmt(dz9), ", ", ifmt(dz10), ", ", ifmt(dz11), ", ", ifmt(dz12), ", ", ifmt(dz13), ", ", ifmt(dz14), ")\n") print("izv=(", ifmt(z0), ", ", ifmt(z1), ", ", ifmt(z2), ", ", ifmt(z3), ", ", ifmt(z4), ", ", ifmt(z5), ", ", ifmt(z6), ", ", ifmt(z7), ", ", ifmt(z8), ", ", ifmt(z9), ", ", ifmt(z10), ", ", ifmt(z11), ", ", 
ifmt(z12), ", ", ifmt(z13), ", ", ifmt(z14), ")\n") print("debugToom 8\n") print("EOF\n") panic("toom8") } } } func toom8Sqr(stk *stack, z, x nat) { const debug = false // avoid accidental slicing beyond cap z = z[:len(z):len(z)] x = x[:len(x):len(x)] n := len(x) if len(z) != 2*n { panic("bad toom8Sqr len") } // Fall back to simpler algorithm if small enough or too small. if n < toom8SqrThreshold || n < 4*8 { toom4Sqr(stk, z, x) return } defer stk.restore(stk.save()) p := (n + 8 - 1) / 8 x0 := &Int{abs: x[0*p : 1*p].norm()} x1 := &Int{abs: x[1*p : 2*p].norm()} x2 := &Int{abs: x[2*p : 3*p].norm()} x3 := &Int{abs: x[3*p : 4*p].norm()} x4 := &Int{abs: x[4*p : 5*p].norm()} x5 := &Int{abs: x[5*p : 6*p].norm()} x6 := &Int{abs: x[6*p : 7*p].norm()} x7 := &Int{abs: x[7*p:].norm()} z0 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z1 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z2 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z3 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z4 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z5 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z6 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z7 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z8 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z9 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z10 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z11 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z12 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z13 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} z14 := &Int{abs: stk.nat(2*p + (85+_W-1)/_W + 1)} // z1, z2 = x(-1)², x(1)² z0.Add(x4, x6) z0.Add(z0, x2) z0.Add(z0, x0) z14.Add(x5, x7) z14.Add(z14, x3) z14.Add(z14, x1) addSub(z0, z14, z0, z14) z1.mul(stk, z14, z14) z2.mul(stk, z0, z0) // z3, z4 = x(-2)², x(2)² addLsh(z0, x4, x6, 2) addLsh(z0, x2, z0, 2) addLsh(z0, x0, z0, 2) addLsh(z14, x5, x7, 2) addLsh(z14, x3, z14, 2) addLsh(z14, x1, z14, 2) z14.Lsh(z14, 1) addSub(z0, z14, z0, z14) z3.mul(stk, z14, z14) z4.mul(stk, z0, z0) // z5, z6 = x(-4)², x(4)² addLsh(z0, x4, 
x6, 4) addLsh(z0, x2, z0, 4) addLsh(z0, x0, z0, 4) addLsh(z14, x5, x7, 4) addLsh(z14, x3, z14, 4) addLsh(z14, x1, z14, 4) z14.Lsh(z14, 2) addSub(z0, z14, z0, z14) z5.mul(stk, z14, z14) z6.mul(stk, z0, z0) // z7, z8 = x(-8)², x(8)² addLsh(z0, x4, x6, 6) addLsh(z0, x2, z0, 6) addLsh(z0, x0, z0, 6) addLsh(z14, x5, x7, 6) addLsh(z14, x3, z14, 6) addLsh(z14, x1, z14, 6) z14.Lsh(z14, 3) addSub(z0, z14, z0, z14) z7.mul(stk, z14, z14) z8.mul(stk, z0, z0) // z9, z10 = x(-16)², x(16)² addLsh(z0, x4, x6, 8) addLsh(z0, x2, z0, 8) addLsh(z0, x0, z0, 8) addLsh(z14, x5, x7, 8) addLsh(z14, x3, z14, 8) addLsh(z14, x1, z14, 8) z14.Lsh(z14, 4) addSub(z0, z14, z0, z14) z9.mul(stk, z14, z14) z10.mul(stk, z0, z0) // z11, z12 = x(-32)², x(32)² addLsh(z0, x4, x6, 10) addLsh(z0, x2, z0, 10) addLsh(z0, x0, z0, 10) addLsh(z14, x5, x7, 10) addLsh(z14, x3, z14, 10) addLsh(z14, x1, z14, 10) z14.Lsh(z14, 5) addSub(z0, z14, z0, z14) z11.mul(stk, z14, z14) z12.mul(stk, z0, z0) // z13 = x(64)² addLsh(z0, x6, x7, 6) addLsh(z0, x5, z0, 6) addLsh(z0, x4, z0, 6) addLsh(z0, x3, z0, 6) addLsh(z0, x2, z0, 6) addLsh(z0, x1, z0, 6) addLsh(z0, x0, z0, 6) z13.mul(stk, z0, z0) // z0 = x(0)² z0.mul(stk, x0, x0) // z14 = x(∞)² z14.mul(stk, x7, x7) var dz0, dz1, dz2, dz3, dz4, dz5, dz6, dz7, dz8, dz9, dz10, dz11, dz12, dz13, dz14 *Int if debug { dz0 = new(Int).Set(z0) dz1 = new(Int).Set(z1) dz2 = new(Int).Set(z2) dz3 = new(Int).Set(z3) dz4 = new(Int).Set(z4) dz5 = new(Int).Set(z5) dz6 = new(Int).Set(z6) dz7 = new(Int).Set(z7) dz8 = new(Int).Set(z8) dz9 = new(Int).Set(z9) dz10 = new(Int).Set(z10) dz11 = new(Int).Set(z11) dz12 = new(Int).Set(z12) dz13 = new(Int).Set(z13) dz14 = new(Int).Set(z14) } toom8Interp(z, p, z0, z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11, z12, z13, z14) if debug { zz := make(nat, len(z)) toom4(stk, zz, x, x) if z.cmp(zz) != 0 { print("toom8 wrong\n") print("ivy -f natmul.ivy <<EOF\n") print("W=", _W, "\n") print("p=", p, "\n") trace("z", &Int{abs: z}) trace("zz", &Int{abs: zz}) trace("x", 
&Int{abs: x}) print("xv=(", ifmt(x0), ", ", ifmt(x1), ", ", ifmt(x2), ", ", ifmt(x3), ", ", ifmt(x4), ", ", ifmt(x5), ", ", ifmt(x6), ", ", ifmt(x7), ")\n") print("zv=(", ifmt(dz0), ", ", ifmt(dz1), ", ", ifmt(dz2), ", ", ifmt(dz3), ", ", ifmt(dz4), ", ", ifmt(dz5), ", ", ifmt(dz6), ", ", ifmt(dz7), ", ", ifmt(dz8), ", ", ifmt(dz9), ", ", ifmt(dz10), ", ", ifmt(dz11), ", ", ifmt(dz12), ", ", ifmt(dz13), ", ", ifmt(dz14), ")\n") print("izv=(", ifmt(z0), ", ", ifmt(z1), ", ", ifmt(z2), ", ", ifmt(z3), ", ", ifmt(z4), ", ", ifmt(z5), ", ", ifmt(z6), ", ", ifmt(z7), ", ", ifmt(z8), ", ", ifmt(z9), ", ", ifmt(z10), ", ", ifmt(z11), ", ", ifmt(z12), ", ", ifmt(z13), ", ", ifmt(z14), ")\n") print("debugToomSqr 8\n") print("EOF\n") panic("toom8Sqr") } } } func toom8Interp(z nat, p int, z0, z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11, z12, z13, z14 *Int) { addSub(z2, z1, z2, z1) addSub(z4, z3, z4, z3) addSub(z6, z5, z6, z5) addSub(z8, z7, z8, z7) addSub(z10, z9, z10, z9) addSub(z12, z11, z12, z11) subLsh(z2, z2, z0, 1) subLsh(z3, z3, z1, 1) subLsh(z4, z4, z0, 1) subLsh(z4, z4, z2, 2) subLsh(z5, z5, z1, 2) subMul(z5, z5, z3, 10) subLsh(z6, z6, z0, 1) subLsh(z6, z6, z2, 4) subMul(z6, z6, z4, 20) subLsh(z7, z7, z1, 3) subMul(z7, z7, z3, 84) subMul(z7, z7, z5, 42) subLsh(z8, z8, z0, 1) subLsh(z8, z8, z2, 6) subMul(z8, z8, z4, 336) subMul(z8, z8, z6, 84) subLsh(z9, z9, z1, 4) subMul(z9, z9, z3, 680) subMul(z9, z9, z5, 1428) subMul(z9, z9, z7, 170) subLsh(z10, z10, z0, 1) subLsh(z10, z10, z2, 8) subMul(z10, z10, z4, 5440) subMul(z10, z10, z6, 5712) subMul(z10, z10, z8, 340) subLsh(z11, z11, z1, 5) subMul(z11, z11, z3, 5456) subMul(z11, z11, z5, 46376) subMul(z11, z11, z7, 23188) subMul(z11, z11, z9, 682) subLsh(z12, z12, z0, 1) subLsh(z12, z12, z2, 10) subMul(z12, z12, z4, 87296) subMul(z12, z12, z6, 371008) subMul(z12, z12, z8, 92752) subMul(z12, z12, z10, 1364) z13.Sub(z13, z0) subLsh(z13, z13, z1, 5) subLsh(z13, z13, z2, 11) subMul(z13, z13, z3, 21840) subMul(z13, z13, z4, 
698880) subMul(z13, z13, z5, 744744) subMul(z13, z13, z6, 11915904) subMul(z13, z13, z7, 1507220) subMul(z13, z13, z8, 12057760) subMul(z13, z13, z9, 186186) subMul(z13, z13, z10, 744744) subMul(z13, z13, z11, 1365) subMul(z13, z13, z12, 2730) z1.Rsh(z1, 1) z2.Rsh(z2, 1) z3.Rsh(z3, 2) mdiv(z3, z3, 3, 0xaaaaaaaaaaaaaaab&_M) z4.Rsh(z4, 3) mdiv(z4, z4, 3, 0xaaaaaaaaaaaaaaab&_M) z5.Rsh(z5, 5) mdiv(z5, z5, 45, 0x4fa4fa4fa4fa4fa5&_M) z6.Rsh(z6, 7) mdiv(z6, z6, 45, 0x4fa4fa4fa4fa4fa5&_M) z7.Rsh(z7, 10) mdiv(z7, z7, 2835, 0x938cc70553e3771b&_M) z8.Rsh(z8, 13) mdiv(z8, z8, 2835, 0x938cc70553e3771b&_M) z9.Rsh(z9, 17) mdiv(z9, z9, 722925, 0x49dd6a31368a6de5&_M) z10.Rsh(z10, 21) mdiv(z10, z10, 722925, 0x49dd6a31368a6de5&_M) z11.Rsh(z11, 26) mdiv(z11, z11, 739552275, 0xbdc1e7d4816dfe1b&_M) z12.Rsh(z12, 31) mdiv(z12, z12, 739552275, 0xbdc1e7d4816dfe1b&_M) z13.Rsh(z13, 36) mdiv64(z13, z13, 3028466566125, 0x8a44806683b051e5&_M, 461586125, 0x9e2de05, 6561, 0x37360a61) subLsh(z13, z13, z14, 6) z4.Sub(z4, z12) subMul(z12, z12, z14, 1365) z3.Sub(z3, z11) subMul(z11, z11, z13, 1365) z6.Sub(z6, z10) subMul(z10, z10, z14, 93093) subMul(z10, z10, z12, 341) z5.Sub(z5, z9) subMul(z9, z9, z13, 93093) subMul(z9, z9, z11, 341) subMul(z8, z8, z14, 376805) subMul(z8, z8, z12, 5797) z4.Sub(z4, z8) subMul(z8, z8, z10, 85) subMul(z7, z7, z13, 376805) subMul(z7, z7, z11, 5797) z3.Sub(z3, z7) subMul(z7, z7, z9, 85) subMul(z6, z6, z12, 5456) subMul(z6, z6, z10, 356) subMul(z6, z6, z8, 21) subMul(z5, z5, z11, 5456) subMul(z5, z5, z9, 356) subMul(z5, z5, z7, 21) subMul(z4, z4, z12, 340) subMul(z4, z4, z8, 20) subMul(z4, z4, z6, 5) subMul(z3, z3, z11, 340) subMul(z3, z3, z7, 20) subMul(z3, z3, z5, 5) z2.Sub(z2, z14) z2.Sub(z2, z12) z2.Sub(z2, z10) z2.Sub(z2, z8) z2.Sub(z2, z6) z2.Sub(z2, z4) z1.Sub(z1, z13) z1.Sub(z1, z11) z1.Sub(z1, z9) z1.Sub(z1, z7) z1.Sub(z1, z5) z1.Sub(z1, z3) clear(z) addTo(z[0*p:], z0.abs) addTo(z[1*p:], z1.abs) addTo(z[2*p:], z2.abs) addTo(z[3*p:], z3.abs) addTo(z[4*p:], z4.abs) 
addTo(z[5*p:], z5.abs) addTo(z[6*p:], z6.abs) addTo(z[7*p:], z7.abs) addTo(z[8*p:], z8.abs) addTo(z[9*p:], z9.abs) addTo(z[10*p:], z10.abs) addTo(z[11*p:], z11.abs) addTo(z[12*p:], z12.abs) addTo(z[13*p:], z13.abs) addTo(z[14*p:], z14.abs) } func toom12(stk *stack, z, x, y nat) { const debug = false // avoid accidental slicing beyond cap z = z[:len(z):len(z)] x = x[:len(x):len(x)] y = y[:len(y):len(y)] n := len(y) if len(x) != n || len(z) != 2*n { panic("bad toom12 len") } // Fall back to simpler algorithm if small enough or too small. if n < toom12Threshold || n < 4*12 { toom8(stk, z, x, y) return } defer stk.restore(stk.save()) p := (n + 12 - 1) / 12 x0 := &Int{abs: x[0*p : 1*p].norm()} x1 := &Int{abs: x[1*p : 2*p].norm()} x2 := &Int{abs: x[2*p : 3*p].norm()} x3 := &Int{abs: x[3*p : 4*p].norm()} x4 := &Int{abs: x[4*p : 5*p].norm()} x5 := &Int{abs: x[5*p : 6*p].norm()} x6 := &Int{abs: x[6*p : 7*p].norm()} x7 := &Int{abs: x[7*p : 8*p].norm()} x8 := &Int{abs: x[8*p : 9*p].norm()} x9 := &Int{abs: x[9*p : 10*p].norm()} x10 := &Int{abs: x[10*p : 11*p].norm()} x11 := &Int{abs: x[11*p:].norm()} y0 := &Int{abs: y[0*p : 1*p].norm()} y1 := &Int{abs: y[1*p : 2*p].norm()} y2 := &Int{abs: y[2*p : 3*p].norm()} y3 := &Int{abs: y[3*p : 4*p].norm()} y4 := &Int{abs: y[4*p : 5*p].norm()} y5 := &Int{abs: y[5*p : 6*p].norm()} y6 := &Int{abs: y[6*p : 7*p].norm()} y7 := &Int{abs: y[7*p : 8*p].norm()} y8 := &Int{abs: y[8*p : 9*p].norm()} y9 := &Int{abs: y[9*p : 10*p].norm()} y10 := &Int{abs: y[10*p : 11*p].norm()} y11 := &Int{abs: y[11*p:].norm()} z0 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z1 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z2 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z3 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z4 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z5 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z6 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z7 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z8 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 
1)} z9 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z10 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z11 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z12 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z13 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z14 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z15 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z16 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z17 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z18 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z19 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z20 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z21 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} z22 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)} // z1, z2 = x(-1)*y(-1), x(1)*y(1) z0.Add(x8, x10) z0.Add(z0, x6) z0.Add(z0, x4) z0.Add(z0, x2) z0.Add(z0, x0) z2.Add(x9, x11) z2.Add(z2, x7) z2.Add(z2, x5) z2.Add(z2, x3) z2.Add(z2, x1) addSub(z0, z2, z0, z2) z22.Add(y8, y10) z22.Add(z22, y6) z22.Add(z22, y4) z22.Add(z22, y2) z22.Add(z22, y0) z21.Add(y9, y11) z21.Add(z21, y7) z21.Add(z21, y5) z21.Add(z21, y3) z21.Add(z21, y1) addSub(z22, z21, z22, z21) z1.mul(stk, z2, z21) z2.mul(stk, z0, z22) // z3, z4 = x(-2)*y(-2), x(2)*y(2) addLsh(z0, x8, x10, 2) addLsh(z0, x6, z0, 2) addLsh(z0, x4, z0, 2) addLsh(z0, x2, z0, 2) addLsh(z0, x0, z0, 2) addLsh(z4, x9, x11, 2) addLsh(z4, x7, z4, 2) addLsh(z4, x5, z4, 2) addLsh(z4, x3, z4, 2) addLsh(z4, x1, z4, 2) z4.Lsh(z4, 1) addSub(z0, z4, z0, z4) addLsh(z22, y8, y10, 2) addLsh(z22, y6, z22, 2) addLsh(z22, y4, z22, 2) addLsh(z22, y2, z22, 2) addLsh(z22, y0, z22, 2) addLsh(z21, y9, y11, 2) addLsh(z21, y7, z21, 2) addLsh(z21, y5, z21, 2) addLsh(z21, y3, z21, 2) addLsh(z21, y1, z21, 2) z21.Lsh(z21, 1) addSub(z22, z21, z22, z21) z3.mul(stk, z4, z21) z4.mul(stk, z0, z22) // z5, z6 = x(-4)*y(-4), x(4)*y(4) addLsh(z0, x8, x10, 4) addLsh(z0, x6, z0, 4) addLsh(z0, x4, z0, 4) addLsh(z0, x2, z0, 4) addLsh(z0, x0, z0, 4) addLsh(z6, x9, x11, 4) addLsh(z6, x7, z6, 4) addLsh(z6, x5, z6, 4) 
addLsh(z6, x3, z6, 4) addLsh(z6, x1, z6, 4) z6.Lsh(z6, 2) addSub(z0, z6, z0, z6) addLsh(z22, y8, y10, 4) addLsh(z22, y6, z22, 4) addLsh(z22, y4, z22, 4) addLsh(z22, y2, z22, 4) addLsh(z22, y0, z22, 4) addLsh(z21, y9, y11, 4) addLsh(z21, y7, z21, 4) addLsh(z21, y5, z21, 4) addLsh(z21, y3, z21, 4) addLsh(z21, y1, z21, 4) z21.Lsh(z21, 2) addSub(z22, z21, z22, z21) z5.mul(stk, z6, z21) z6.mul(stk, z0, z22) // z7, z8 = x(-8)*y(-8), x(8)*y(8) addLsh(z0, x8, x10, 6) addLsh(z0, x6, z0, 6) addLsh(z0, x4, z0, 6) addLsh(z0, x2, z0, 6) addLsh(z0, x0, z0, 6) addLsh(z8, x9, x11, 6) addLsh(z8, x7, z8, 6) addLsh(z8, x5, z8, 6) addLsh(z8, x3, z8, 6) addLsh(z8, x1, z8, 6) z8.Lsh(z8, 3) addSub(z0, z8, z0, z8) addLsh(z22, y8, y10, 6) addLsh(z22, y6, z22, 6) addLsh(z22, y4, z22, 6) addLsh(z22, y2, z22, 6) addLsh(z22, y0, z22, 6) addLsh(z21, y9, y11, 6) addLsh(z21, y7, z21, 6) addLsh(z21, y5, z21, 6) addLsh(z21, y3, z21, 6) addLsh(z21, y1, z21, 6) z21.Lsh(z21, 3) addSub(z22, z21, z22, z21) z7.mul(stk, z8, z21) z8.mul(stk, z0, z22) // z9, z10 = x(-16)*y(-16), x(16)*y(16) addLsh(z0, x8, x10, 8) addLsh(z0, x6, z0, 8) addLsh(z0, x4, z0, 8) addLsh(z0, x2, z0, 8) addLsh(z0, x0, z0, 8) addLsh(z10, x9, x11, 8) addLsh(z10, x7, z10, 8) addLsh(z10, x5, z10, 8) addLsh(z10, x3, z10, 8) addLsh(z10, x1, z10, 8) z10.Lsh(z10, 4) addSub(z0, z10, z0, z10) addLsh(z22, y8, y10, 8) addLsh(z22, y6, z22, 8) addLsh(z22, y4, z22, 8) addLsh(z22, y2, z22, 8) addLsh(z22, y0, z22, 8) addLsh(z21, y9, y11, 8) addLsh(z21, y7, z21, 8) addLsh(z21, y5, z21, 8) addLsh(z21, y3, z21, 8) addLsh(z21, y1, z21, 8) z21.Lsh(z21, 4) addSub(z22, z21, z22, z21) z9.mul(stk, z10, z21) z10.mul(stk, z0, z22) // z11, z12 = x(-32)*y(-32), x(32)*y(32) addLsh(z0, x8, x10, 10) addLsh(z0, x6, z0, 10) addLsh(z0, x4, z0, 10) addLsh(z0, x2, z0, 10) addLsh(z0, x0, z0, 10) addLsh(z12, x9, x11, 10) addLsh(z12, x7, z12, 10) addLsh(z12, x5, z12, 10) addLsh(z12, x3, z12, 10) addLsh(z12, x1, z12, 10) z12.Lsh(z12, 5) addSub(z0, z12, z0, z12) addLsh(z22, 
y8, y10, 10) addLsh(z22, y6, z22, 10) addLsh(z22, y4, z22, 10) addLsh(z22, y2, z22, 10) addLsh(z22, y0, z22, 10) addLsh(z21, y9, y11, 10) addLsh(z21, y7, z21, 10) addLsh(z21, y5, z21, 10) addLsh(z21, y3, z21, 10) addLsh(z21, y1, z21, 10) z21.Lsh(z21, 5) addSub(z22, z21, z22, z21) z11.mul(stk, z12, z21) z12.mul(stk, z0, z22) // z13, z14 = x(-64)*y(-64), x(64)*y(64) addLsh(z0, x8, x10, 12) addLsh(z0, x6, z0, 12) addLsh(z0, x4, z0, 12) addLsh(z0, x2, z0, 12) addLsh(z0, x0, z0, 12) addLsh(z14, x9, x11, 12) addLsh(z14, x7, z14, 12) addLsh(z14, x5, z14, 12) addLsh(z14, x3, z14, 12) addLsh(z14, x1, z14, 12) z14.Lsh(z14, 6) addSub(z0, z14, z0, z14) addLsh(z22, y8, y10, 12) addLsh(z22, y6, z22, 12) addLsh(z22, y4, z22, 12) addLsh(z22, y2, z22, 12) addLsh(z22, y0, z22, 12) addLsh(z21, y9, y11, 12) addLsh(z21, y7, z21, 12) addLsh(z21, y5, z21, 12) addLsh(z21, y3, z21, 12) addLsh(z21, y1, z21, 12) z21.Lsh(z21, 6) addSub(z22, z21, z22, z21) z13.mul(stk, z14, z21) z14.mul(stk, z0, z22) // z15, z16 = x(-128)*y(-128), x(128)*y(128) addLsh(z0, x8, x10, 14) addLsh(z0, x6, z0, 14) addLsh(z0, x4, z0, 14) addLsh(z0, x2, z0, 14) addLsh(z0, x0, z0, 14) addLsh(z16, x9, x11, 14) addLsh(z16, x7, z16, 14) addLsh(z16, x5, z16, 14) addLsh(z16, x3, z16, 14) addLsh(z16, x1, z16, 14) z16.Lsh(z16, 7) addSub(z0, z16, z0, z16) addLsh(z22, y8, y10, 14) addLsh(z22, y6, z22, 14) addLsh(z22, y4, z22, 14) addLsh(z22, y2, z22, 14) addLsh(z22, y0, z22, 14) addLsh(z21, y9, y11, 14) addLsh(z21, y7, z21, 14) addLsh(z21, y5, z21, 14) addLsh(z21, y3, z21, 14) addLsh(z21, y1, z21, 14) z21.Lsh(z21, 7) addSub(z22, z21, z22, z21) z15.mul(stk, z16, z21) z16.mul(stk, z0, z22) // z17, z18 = x(-256)*y(-256), x(256)*y(256) addLsh(z0, x8, x10, 16) addLsh(z0, x6, z0, 16) addLsh(z0, x4, z0, 16) addLsh(z0, x2, z0, 16) addLsh(z0, x0, z0, 16) addLsh(z18, x9, x11, 16) addLsh(z18, x7, z18, 16) addLsh(z18, x5, z18, 16) addLsh(z18, x3, z18, 16) addLsh(z18, x1, z18, 16) z18.Lsh(z18, 8) addSub(z0, z18, z0, z18) addLsh(z22, y8, y10, 
16) addLsh(z22, y6, z22, 16) addLsh(z22, y4, z22, 16) addLsh(z22, y2, z22, 16) addLsh(z22, y0, z22, 16) addLsh(z21, y9, y11, 16) addLsh(z21, y7, z21, 16) addLsh(z21, y5, z21, 16) addLsh(z21, y3, z21, 16) addLsh(z21, y1, z21, 16) z21.Lsh(z21, 8) addSub(z22, z21, z22, z21) z17.mul(stk, z18, z21) z18.mul(stk, z0, z22) // z19, z20 = x(-512)*y(-512), x(512)*y(512) addLsh(z0, x8, x10, 18) addLsh(z0, x6, z0, 18) addLsh(z0, x4, z0, 18) addLsh(z0, x2, z0, 18) addLsh(z0, x0, z0, 18) addLsh(z20, x9, x11, 18) addLsh(z20, x7, z20, 18) addLsh(z20, x5, z20, 18) addLsh(z20, x3, z20, 18) addLsh(z20, x1, z20, 18) z20.Lsh(z20, 9) addSub(z0, z20, z0, z20) addLsh(z22, y8, y10, 18) addLsh(z22, y6, z22, 18) addLsh(z22, y4, z22, 18) addLsh(z22, y2, z22, 18) addLsh(z22, y0, z22, 18) addLsh(z21, y9, y11, 18) addLsh(z21, y7, z21, 18) addLsh(z21, y5, z21, 18) addLsh(z21, y3, z21, 18) addLsh(z21, y1, z21, 18) z21.Lsh(z21, 9) addSub(z22, z21, z22, z21) z19.mul(stk, z20, z21) z20.mul(stk, z0, z22) // z21 = x(1024)*y(1024) addLsh(z0, x10, x11, 10) addLsh(z0, x9, z0, 10) addLsh(z0, x8, z0, 10) addLsh(z0, x7, z0, 10) addLsh(z0, x6, z0, 10) addLsh(z0, x5, z0, 10) addLsh(z0, x4, z0, 10) addLsh(z0, x3, z0, 10) addLsh(z0, x2, z0, 10) addLsh(z0, x1, z0, 10) addLsh(z0, x0, z0, 10) addLsh(z22, y10, y11, 10) addLsh(z22, y9, z22, 10) addLsh(z22, y8, z22, 10) addLsh(z22, y7, z22, 10) addLsh(z22, y6, z22, 10) addLsh(z22, y5, z22, 10) addLsh(z22, y4, z22, 10) addLsh(z22, y3, z22, 10) addLsh(z22, y2, z22, 10) addLsh(z22, y1, z22, 10) addLsh(z22, y0, z22, 10) z21.mul(stk, z0, z22) // z0 = x(0)*y(0) z0.mul(stk, x0, y0) // z22 = x(∞)*y(∞) z22.mul(stk, x11, y11) var dz0, dz1, dz2, dz3, dz4, dz5, dz6, dz7, dz8, dz9, dz10, dz11, dz12, dz13, dz14, dz15, dz16, dz17, dz18, dz19, dz20, dz21, dz22 *Int if debug { dz0 = new(Int).Set(z0) dz1 = new(Int).Set(z1) dz2 = new(Int).Set(z2) dz3 = new(Int).Set(z3) dz4 = new(Int).Set(z4) dz5 = new(Int).Set(z5) dz6 = new(Int).Set(z6) dz7 = new(Int).Set(z7) dz8 = new(Int).Set(z8) dz9 = 
new(Int).Set(z9) dz10 = new(Int).Set(z10) dz11 = new(Int).Set(z11) dz12 = new(Int).Set(z12) dz13 = new(Int).Set(z13) dz14 = new(Int).Set(z14) dz15 = new(Int).Set(z15) dz16 = new(Int).Set(z16) dz17 = new(Int).Set(z17) dz18 = new(Int).Set(z18) dz19 = new(Int).Set(z19) dz20 = new(Int).Set(z20) dz21 = new(Int).Set(z21) dz22 = new(Int).Set(z22) } toom12Interp(z, p, z0, z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11, z12, z13, z14, z15, z16, z17, z18, z19, z20, z21, z22) if debug { zz := make(nat, len(z)) toom8(stk, zz, x, y) if z.cmp(zz) != 0 { print("toom12 wrong\n") print("ivy -f natmul.ivy <<EOF\n") print("W=", _W, "\n") print("p=", p, "\n") trace("z", &Int{abs: z}) trace("zz", &Int{abs: zz}) trace("x", &Int{abs: x}) print("xv=(", ifmt(x0), ", ", ifmt(x1), ", ", ifmt(x2), ", ", ifmt(x3), ", ", ifmt(x4), ", ", ifmt(x5), ", ", ifmt(x6), ", ", ifmt(x7), ", ", ifmt(x8), ", ", ifmt(x9), ", ", ifmt(x10), ", ", ifmt(x11), ")\n") trace("y", &Int{abs: y}) print("yv=(", ifmt(y0), ", ", ifmt(y1), ", ", ifmt(y2), ", ", ifmt(y3), ", ", ifmt(y4), ", ", ifmt(y5), ", ", ifmt(y6), ", ", ifmt(y7), ", ", ifmt(y8), ", ", ifmt(y9), ", ", ifmt(y10), ", ", ifmt(y11), ")\n") print("zv=(", ifmt(dz0), ", ", ifmt(dz1), ", ", ifmt(dz2), ", ", ifmt(dz3), ", ", ifmt(dz4), ", ", ifmt(dz5), ", ", ifmt(dz6), ", ", ifmt(dz7), ", ", ifmt(dz8), ", ", ifmt(dz9), ", ", ifmt(dz10), ", ", ifmt(dz11), ", ", ifmt(dz12), ", ", ifmt(dz13), ", ", ifmt(dz14), ", ", ifmt(dz15), ", ", ifmt(dz16), ", ", ifmt(dz17), ", ", ifmt(dz18), ", ", ifmt(dz19), ", ", ifmt(dz20), ", ", ifmt(dz21), ", ", ifmt(dz22), ")\n") print("izv=(", ifmt(z0), ", ", ifmt(z1), ", ", ifmt(z2), ", ", ifmt(z3), ", ", ifmt(z4), ", ", ifmt(z5), ", ", ifmt(z6), ", ", ifmt(z7), ", ", ifmt(z8), ", ", ifmt(z9), ", ", ifmt(z10), ", ", ifmt(z11), ", ", ifmt(z12), ", ", ifmt(z13), ", ", ifmt(z14), ", ", ifmt(z15), ", ", ifmt(z16), ", ", ifmt(z17), ", ", ifmt(z18), ", ", ifmt(z19), ", ", ifmt(z20), ", ", ifmt(z21), ", ", ifmt(z22), ")\n") print("debugToom 
12\n") print("EOF\n") panic("toom12") } } }

// toom12Sqr sets z = x² using 12-way Toom-Cook squaring.
//
// x is split into twelve p-word pieces x0…x11 (x11 takes whatever is left).
// The square of the piece polynomial is formed at 23 points:
// 0, ±1, ±2, ±4, …, ±512, 1024 and ∞.
// toom12Interp then recovers the coefficients and assembles them into z.
//
// len(z) must equal 2*len(x), otherwise toom12Sqr panics.
// Inputs shorter than toom12SqrThreshold, or shorter than 4*12 words, are
// delegated to toom8Sqr instead.
func toom12Sqr(stk *stack, z, x nat) {
	const debug = false

	// avoid accidental slicing beyond cap
	z = z[:len(z):len(z)]
	x = x[:len(x):len(x)]

	n := len(x)
	if len(z) != 2*n {
		panic("bad toom12Sqr len")
	}

	// Fall back to simpler algorithm if small enough or too small.
	if n < toom12SqrThreshold || n < 4*12 {
		toom8Sqr(stk, z, x)
		return
	}
	// Scratch space from stk is released when this function returns.
	defer stk.restore(stk.save())

	// p is the piece length in words (ceil(n/12)).
	p := (n + 12 - 1) / 12
	x0 := &Int{abs: x[0*p : 1*p].norm()}
	x1 := &Int{abs: x[1*p : 2*p].norm()}
	x2 := &Int{abs: x[2*p : 3*p].norm()}
	x3 := &Int{abs: x[3*p : 4*p].norm()}
	x4 := &Int{abs: x[4*p : 5*p].norm()}
	x5 := &Int{abs: x[5*p : 6*p].norm()}
	x6 := &Int{abs: x[6*p : 7*p].norm()}
	x7 := &Int{abs: x[7*p : 8*p].norm()}
	x8 := &Int{abs: x[8*p : 9*p].norm()}
	x9 := &Int{abs: x[9*p : 10*p].norm()}
	x10 := &Int{abs: x[10*p : 11*p].norm()}
	x11 := &Int{abs: x[11*p:].norm()}

	// Each point value gets 2p words plus ceil(221/_W)+1 words of headroom.
	// NOTE(review): presumably sized for the largest point, 1024:
	// 1024¹¹ = 2¹¹⁰ per factor, so about 2²²⁰ for the square — confirm.
	z0 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z1 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z2 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z3 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z4 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z5 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z6 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z7 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z8 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z9 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z10 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z11 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z12 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z13 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z14 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z15 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z16 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z17 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z18 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z19 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z20 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z21 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}
	z22 := &Int{abs: stk.nat(2*p + (221+_W-1)/_W + 1)}

	// Evaluation at each ±t (z0 and z22 are used as scratch until the end):
	//  - z0 accumulates the even-indexed pieces as a polynomial in t² (Horner);
	//  - z22 accumulates the odd-indexed pieces the same way, then is scaled by t;
	//  - addSub then forms x(t) in z0 and x(-t) in z22.
	// NOTE(review): assumes addLsh(d, a, b, s) computes d = a + b<<s and
	// addSub(a, b, x, y) sets a = x+y, b = x-y — confirm against their definitions.

	// z1, z2 = x(-1)², x(1)²
	z0.Add(x8, x10)
	z0.Add(z0, x6)
	z0.Add(z0, x4)
	z0.Add(z0, x2)
	z0.Add(z0, x0)
	z22.Add(x9, x11)
	z22.Add(z22, x7)
	z22.Add(z22, x5)
	z22.Add(z22, x3)
	z22.Add(z22, x1)
	addSub(z0, z22, z0, z22)
	z1.mul(stk, z22, z22)
	z2.mul(stk, z0, z0)

	// z3, z4 = x(-2)², x(2)²
	addLsh(z0, x8, x10, 2)
	addLsh(z0, x6, z0, 2)
	addLsh(z0, x4, z0, 2)
	addLsh(z0, x2, z0, 2)
	addLsh(z0, x0, z0, 2)
	addLsh(z22, x9, x11, 2)
	addLsh(z22, x7, z22, 2)
	addLsh(z22, x5, z22, 2)
	addLsh(z22, x3, z22, 2)
	addLsh(z22, x1, z22, 2)
	z22.Lsh(z22, 1)
	addSub(z0, z22, z0, z22)
	z3.mul(stk, z22, z22)
	z4.mul(stk, z0, z0)

	// z5, z6 = x(-4)², x(4)²
	addLsh(z0, x8, x10, 4)
	addLsh(z0, x6, z0, 4)
	addLsh(z0, x4, z0, 4)
	addLsh(z0, x2, z0, 4)
	addLsh(z0, x0, z0, 4)
	addLsh(z22, x9, x11, 4)
	addLsh(z22, x7, z22, 4)
	addLsh(z22, x5, z22, 4)
	addLsh(z22, x3, z22, 4)
	addLsh(z22, x1, z22, 4)
	z22.Lsh(z22, 2)
	addSub(z0, z22, z0, z22)
	z5.mul(stk, z22, z22)
	z6.mul(stk, z0, z0)

	// z7, z8 = x(-8)², x(8)²
	addLsh(z0, x8, x10, 6)
	addLsh(z0, x6, z0, 6)
	addLsh(z0, x4, z0, 6)
	addLsh(z0, x2, z0, 6)
	addLsh(z0, x0, z0, 6)
	addLsh(z22, x9, x11, 6)
	addLsh(z22, x7, z22, 6)
	addLsh(z22, x5, z22, 6)
	addLsh(z22, x3, z22, 6)
	addLsh(z22, x1, z22, 6)
	z22.Lsh(z22, 3)
	addSub(z0, z22, z0, z22)
	z7.mul(stk, z22, z22)
	z8.mul(stk, z0, z0)

	// z9, z10 = x(-16)², x(16)²
	addLsh(z0, x8, x10, 8)
	addLsh(z0, x6, z0, 8)
	addLsh(z0, x4, z0, 8)
	addLsh(z0, x2, z0, 8)
	addLsh(z0, x0, z0, 8)
	addLsh(z22, x9, x11, 8)
	addLsh(z22, x7, z22, 8)
	addLsh(z22, x5, z22, 8)
	addLsh(z22, x3, z22, 8)
	addLsh(z22, x1, z22, 8)
	z22.Lsh(z22, 4)
	addSub(z0, z22, z0, z22)
	z9.mul(stk, z22, z22)
	z10.mul(stk, z0, z0)

	// z11, z12 = x(-32)², x(32)²
	addLsh(z0, x8, x10, 10)
	addLsh(z0, x6, z0, 10)
	addLsh(z0, x4, z0, 10)
	addLsh(z0, x2, z0, 10)
	addLsh(z0, x0, z0, 10)
	addLsh(z22, x9, x11, 10)
	addLsh(z22, x7, z22, 10)
	addLsh(z22, x5, z22, 10)
	addLsh(z22, x3, z22, 10)
	addLsh(z22, x1, z22, 10)
	z22.Lsh(z22, 5)
	addSub(z0, z22, z0, z22)
	z11.mul(stk, z22, z22)
	z12.mul(stk, z0, z0)

	// z13, z14 = x(-64)², x(64)²
	addLsh(z0, x8, x10, 12)
	addLsh(z0, x6, z0, 12)
	addLsh(z0, x4, z0, 12)
	addLsh(z0, x2, z0, 12)
	addLsh(z0, x0, z0, 12)
	addLsh(z22, x9, x11, 12)
	addLsh(z22, x7, z22, 12)
	addLsh(z22, x5, z22, 12)
	addLsh(z22, x3, z22, 12)
	addLsh(z22, x1, z22, 12)
	z22.Lsh(z22, 6)
	addSub(z0, z22, z0, z22)
	z13.mul(stk, z22, z22)
	z14.mul(stk, z0, z0)

	// z15, z16 = x(-128)², x(128)²
	addLsh(z0, x8, x10, 14)
	addLsh(z0, x6, z0, 14)
	addLsh(z0, x4, z0, 14)
	addLsh(z0, x2, z0, 14)
	addLsh(z0, x0, z0, 14)
	addLsh(z22, x9, x11, 14)
	addLsh(z22, x7, z22, 14)
	addLsh(z22, x5, z22, 14)
	addLsh(z22, x3, z22, 14)
	addLsh(z22, x1, z22, 14)
	z22.Lsh(z22, 7)
	addSub(z0, z22, z0, z22)
	z15.mul(stk, z22, z22)
	z16.mul(stk, z0, z0)

	// z17, z18 = x(-256)², x(256)²
	addLsh(z0, x8, x10, 16)
	addLsh(z0, x6, z0, 16)
	addLsh(z0, x4, z0, 16)
	addLsh(z0, x2, z0, 16)
	addLsh(z0, x0, z0, 16)
	addLsh(z22, x9, x11, 16)
	addLsh(z22, x7, z22, 16)
	addLsh(z22, x5, z22, 16)
	addLsh(z22, x3, z22, 16)
	addLsh(z22, x1, z22, 16)
	z22.Lsh(z22, 8)
	addSub(z0, z22, z0, z22)
	z17.mul(stk, z22, z22)
	z18.mul(stk, z0, z0)

	// z19, z20 = x(-512)², x(512)²
	addLsh(z0, x8, x10, 18)
	addLsh(z0, x6, z0, 18)
	addLsh(z0, x4, z0, 18)
	addLsh(z0, x2, z0, 18)
	addLsh(z0, x0, z0, 18)
	addLsh(z22, x9, x11, 18)
	addLsh(z22, x7, z22, 18)
	addLsh(z22, x5, z22, 18)
	addLsh(z22, x3, z22, 18)
	addLsh(z22, x1, z22, 18)
	z22.Lsh(z22, 9)
	addSub(z0, z22, z0, z22)
	z19.mul(stk, z22, z22)
	z20.mul(stk, z0, z0)

	// z21 = x(1024)²
	// (Horner over all twelve pieces with t = 1024 = 1<<10.)
	addLsh(z0, x10, x11, 10)
	addLsh(z0, x9, z0, 10)
	addLsh(z0, x8, z0, 10)
	addLsh(z0, x7, z0, 10)
	addLsh(z0, x6, z0, 10)
	addLsh(z0, x5, z0, 10)
	addLsh(z0, x4, z0, 10)
	addLsh(z0, x3, z0, 10)
	addLsh(z0, x2, z0, 10)
	addLsh(z0, x1, z0, 10)
	addLsh(z0, x0, z0, 10)
	z21.mul(stk, z0, z0)

	// z0 = x(0)²
	z0.mul(stk, x0, x0)

	// z22 = x(∞)²
	z22.mul(stk, x11, x11)

	// In debug builds, keep copies of the point values: toom12Interp
	// overwrites z0…z22 in place.
	var dz0, dz1, dz2, dz3, dz4, dz5, dz6, dz7, dz8, dz9, dz10, dz11, dz12, dz13, dz14, dz15, dz16, dz17, dz18, dz19, dz20, dz21, dz22 *Int
	if debug {
		dz0 = new(Int).Set(z0)
		dz1 = new(Int).Set(z1)
		dz2 = new(Int).Set(z2)
		dz3 = new(Int).Set(z3)
		dz4 = new(Int).Set(z4)
		dz5 = new(Int).Set(z5)
		dz6 = new(Int).Set(z6)
		dz7 = new(Int).Set(z7)
		dz8 = new(Int).Set(z8)
		dz9 = new(Int).Set(z9)
		dz10 = new(Int).Set(z10)
		dz11 = new(Int).Set(z11)
		dz12 = new(Int).Set(z12)
		dz13 = new(Int).Set(z13)
		dz14 = new(Int).Set(z14)
		dz15 = new(Int).Set(z15)
		dz16 = new(Int).Set(z16)
		dz17 = new(Int).Set(z17)
		dz18 = new(Int).Set(z18)
		dz19 = new(Int).Set(z19)
		dz20 = new(Int).Set(z20)
		dz21 = new(Int).Set(z21)
		dz22 = new(Int).Set(z22)
	}

	toom12Interp(z, p, z0, z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11, z12, z13, z14, z15, z16, z17, z18, z19, z20, z21, z22)

	// Debug cross-check against toom8. On mismatch, dump an ivy script
	// reproducing the failure, then panic.
	if debug {
		zz := make(nat, len(z))
		toom8(stk, zz, x, x)
		if z.cmp(zz) != 0 {
			print("toom12 wrong\n")
			print("ivy -f natmul.ivy <<EOF\n")
			print("W=", _W, "\n")
			print("p=", p, "\n")
			trace("z", &Int{abs: z})
			trace("zz", &Int{abs: zz})
			trace("x", &Int{abs: x})
			print("xv=(", ifmt(x0), ", ", ifmt(x1), ", ", ifmt(x2), ", ", ifmt(x3), ", ", ifmt(x4), ", ", ifmt(x5), ", ", ifmt(x6), ", ", ifmt(x7), ", ", ifmt(x8), ", ", ifmt(x9), ", ", ifmt(x10), ", ", ifmt(x11), ")\n")
			print("zv=(", ifmt(dz0), ", ", ifmt(dz1), ", ", ifmt(dz2), ", ", ifmt(dz3), ", ", ifmt(dz4), ", ", ifmt(dz5), ", ", ifmt(dz6), ", ", ifmt(dz7), ", ", ifmt(dz8), ", ", ifmt(dz9), ", ", ifmt(dz10), ", ", ifmt(dz11), ", ", ifmt(dz12), ", ", ifmt(dz13), ", ", ifmt(dz14), ", ", ifmt(dz15), ", ", ifmt(dz16), ", ", ifmt(dz17), ", ", ifmt(dz18), ", ", ifmt(dz19), ", ", ifmt(dz20), ", ", ifmt(dz21), ", ", ifmt(dz22), ")\n")
			print("izv=(", ifmt(z0), ", ", ifmt(z1), ", ", ifmt(z2), ", ", ifmt(z3), ", ", ifmt(z4), ", ", ifmt(z5), ", ", ifmt(z6), ", ", ifmt(z7), ", ", ifmt(z8), ", ", ifmt(z9), ", ", ifmt(z10), ", ", ifmt(z11), ", ", ifmt(z12), ", ", ifmt(z13), ", ", ifmt(z14), ", ", ifmt(z15), ", ", ifmt(z16), ", ", ifmt(z17), ", ", ifmt(z18), ", ", ifmt(z19), ", ", ifmt(z20), ", ", ifmt(z21), ", ", ifmt(z22), ")\n")
			print("debugToomSqr 12\n")
			print("EOF\n")
			panic("toom12Sqr")
		}
	}
}
// toom12Interp finishes a 12-way Toom-Cook product.
//
// On entry, z0…z22 hold the product polynomial evaluated at the 23 points
// used by toom12 and toom12Sqr:
//   - z0 at 0;
//   - (z1, z2) at (-1, 1), (z3, z4) at (-2, 2), (z5, z6) at (-4, 4),
//     …, (z19, z20) at (-512, 512);
//   - z21 at 1024;
//   - z22 at ∞ (the leading coefficient).
//
// It solves for the 23 coefficients in place, overwriting z0…z22.
// It then clears z and writes the product into it as the sum of
// zi·2^(_W·i·p), where p is the piece length in words.
//
// Statement order matters throughout: every step consumes values produced
// by the steps before it.
func toom12Interp(z nat, p int, z0, z1, z2, z3, z4, z5, z6, z7, z8, z9, z10, z11, z12, z13, z14, z15, z16, z17, z18, z19, z20, z21, z22 *Int) {
	// Split each ± pair into sum and difference.
	// NOTE(review): judging by addSub's use in toom12/toom12Sqr,
	// addSub(a, b, x, y) sets a = x+y and b = x-y. Afterwards the even slot
	// holds v(t)+v(-t) (even coefficients) and the odd slot holds v(t)-v(-t)
	// (odd coefficients).
	addSub(z2, z1, z2, z1)
	addSub(z4, z3, z4, z3)
	addSub(z6, z5, z6, z5)
	addSub(z8, z7, z8, z7)
	addSub(z10, z9, z10, z9)
	addSub(z12, z11, z12, z11)
	addSub(z14, z13, z14, z13)
	addSub(z16, z15, z16, z15)
	addSub(z18, z17, z18, z17)
	addSub(z20, z19, z20, z19)

	// Forward elimination: from each value, subtract the scaled contributions
	// of z0 and of the lower, already-reduced entries.
	// NOTE(review): presumably subLsh(d, x, y, s) computes d = x - y<<s and
	// subMul(d, x, y, c) computes d = x - y*c — confirm against their definitions.
	subLsh(z2, z2, z0, 1)
	subLsh(z3, z3, z1, 1)
	subLsh(z4, z4, z0, 1)
	subLsh(z4, z4, z2, 2)
	subLsh(z5, z5, z1, 2)
	subMul(z5, z5, z3, 10)
	subLsh(z6, z6, z0, 1)
	subLsh(z6, z6, z2, 4)
	subMul(z6, z6, z4, 20)
	subLsh(z7, z7, z1, 3)
	subMul(z7, z7, z3, 84)
	subMul(z7, z7, z5, 42)
	subLsh(z8, z8, z0, 1)
	subLsh(z8, z8, z2, 6)
	subMul(z8, z8, z4, 336)
	subMul(z8, z8, z6, 84)
	subLsh(z9, z9, z1, 4)
	subMul(z9, z9, z3, 680)
	subMul(z9, z9, z5, 1428)
	subMul(z9, z9, z7, 170)
	subLsh(z10, z10, z0, 1)
	subLsh(z10, z10, z2, 8)
	subMul(z10, z10, z4, 5440)
	subMul(z10, z10, z6, 5712)
	subMul(z10, z10, z8, 340)
	subLsh(z11, z11, z1, 5)
	subMul(z11, z11, z3, 5456)
	subMul(z11, z11, z5, 46376)
	subMul(z11, z11, z7, 23188)
	subMul(z11, z11, z9, 682)
	subLsh(z12, z12, z0, 1)
	subLsh(z12, z12, z2, 10)
	subMul(z12, z12, z4, 87296)
	subMul(z12, z12, z6, 371008)
	subMul(z12, z12, z8, 92752)
	subMul(z12, z12, z10, 1364)
	subLsh(z13, z13, z1, 6)
	subMul(z13, z13, z3, 43680)
	subMul(z13, z13, z5, 1489488)
	subMul(z13, z13, z7, 3014440)
	subMul(z13, z13, z9, 372372)
	subMul(z13, z13, z11, 2730)
	subLsh(z14, z14, z0, 1)
	subLsh(z14, z14, z2, 12)
	subMul(z14, z14, z4, 1397760)
	subMul(z14, z14, z6, 23831808)
	subMul(z14, z14, z8, 24115520)
	subMul(z14, z14, z10, 1489488)
	subMul(z14, z14, z12, 5460)
	subLsh(z15, z15, z1, 7)
	subMul(z15, z15, z3, 349504)
	subMul(z15, z15, z5, 47707296)
	subMul(z15, z15, z7, 387337808)
	subMul(z15, z15, z9, 193668904)
	subMul(z15, z15, z11, 5963412)
	subMul(z15, z15, z13, 10922)
	subLsh(z16, z16, z0, 1)
	subLsh(z16, z16, z2, 14)
	subMul(z16, z16, z4, 22368256)
	subMul(z16, z16, z6, 1526633472)
	subMul(z16, z16, z8, 6197404928)
	subMul(z16, z16, z10, 1549351232)
	subMul(z16, z16, z12, 23853648)
	subMul(z16, z16, z14, 21844)
	subLsh(z17, z17, z1, 8)
	subMul(z17, z17, z3, 2796160)
	subMul(z17, z17, z5, 1526982976)
	subMul(z17, z17, z7, 49626946720)
	subMul(z17, z17, z9, 99545816656)
	subMul(z17, z17, z11, 12406736680)
	subMul(z17, z17, z13, 95436436)
	subMul(z17, z17, z15, 43690)
	subLsh(z18, z18, z0, 1)
	subLsh(z18, z18, z2, 16)
	subMul(z18, z18, z4, 357908480)
	subMul(z18, z18, z6, 97726910464)
	subMul(z18, z18, z8, 1588062295040)
	subMul(z18, z18, z10, 1592733066496)
	subMul(z18, z18, z12, 99253893440)
	subMul(z18, z18, z14, 381745744)
	subMul(z18, z18, z16, 87380)
	subLsh(z19, z19, z1, 9)
	subMul(z19, z19, z3, 22369536)
	subMul(z19, z19, z5, 48866251392)
	subMul(z19, z19, z7, 6353776163136)
	subMul(z19, z19, z9, 51017085074592)
	subMul(z19, z19, z11, 25508542537296)
	subMul(z19, z19, z13, 794222020392)
	subMul(z19, z19, z15, 1527070356)
	subMul(z19, z19, z17, 174762)
	subLsh(z20, z20, z0, 1)
	subLsh(z20, z20, z2, 18)
	subMul(z20, z20, z4, 5726601216)
	subMul(z20, z20, z6, 6254880178176)
	subMul(z20, z20, z8, 406641674440704)
	subMul(z20, z20, z10, 1632546722386944)
	subMul(z20, z20, z12, 408136680596736)
	subMul(z20, z20, z14, 6353776163136)
	subMul(z20, z20, z16, 6108281424)
	subMul(z20, z20, z18, 349524)

	// The single point 1024 has no ± partner, so it is reduced against every
	// lower entry.
	z21.Sub(z21, z0)
	subLsh(z21, z21, z1, 9)
	subLsh(z21, z21, z2, 19)
	subMul(z21, z21, z3, 89478400)
	subMul(z21, z21, z4, 45812940800)
	subMul(z21, z21, z5, 781871207040)
	subMul(z21, z21, z6, 200159029002240)
	subMul(z21, z21, z7, 406666107566400)
	subMul(z21, z21, z8, 52053261768499200)
	subMul(z21, z21, z9, 13063550667177120)
	subMul(z21, z21, z10, 836067242699335680)
	subMul(z21, z21, z11, 26146256100728400)
	subMul(z21, z21, z12, 836680195223308800)
	subMul(z21, z21, z13, 3265887666794280)
	subMul(z21, z21, z14, 52254202668708480)
	subMul(z21, z21, z15, 25416631722900)
	subMul(z21, z21, z16, 203333053783200)
	subMul(z21, z21, z17, 12216737610)
	subMul(z21, z21, z18, 48866950440)
	subMul(z21, z21, z19, 349525)
	subMul(z21, z21, z20, 699050)

	// Remove the accumulated scale factors.
	//  - Each Rsh strips a power of two.
	//  - Each mdiv/mdiv64 appears to be an exact division by an odd constant,
	//    using its multiplicative inverse mod 2⁶⁴ (e.g. 3·0xaaaaaaaaaaaaaaab ≡ 1),
	//    masked with _M for 32-bit words.
	//  - The trailing mdiv64 arguments are a factorization of the divisor
	//    (e.g. 461586125·6561 = 3028466566125) with what look like 32-bit
	//    inverses. NOTE(review): presumably for 32-bit builds — confirm in mdiv64.
	z1.Rsh(z1, 1)
	z2.Rsh(z2, 1)
	z3.Rsh(z3, 2)
	mdiv(z3, z3, 3, 0xaaaaaaaaaaaaaaab&_M)
	z4.Rsh(z4, 3)
	mdiv(z4, z4, 3, 0xaaaaaaaaaaaaaaab&_M)
	z5.Rsh(z5, 5)
	mdiv(z5, z5, 45, 0x4fa4fa4fa4fa4fa5&_M)
	z6.Rsh(z6, 7)
	mdiv(z6, z6, 45, 0x4fa4fa4fa4fa4fa5&_M)
	z7.Rsh(z7, 10)
	mdiv(z7, z7, 2835, 0x938cc70553e3771b&_M)
	z8.Rsh(z8, 13)
	mdiv(z8, z8, 2835, 0x938cc70553e3771b&_M)
	z9.Rsh(z9, 17)
	mdiv(z9, z9, 722925, 0x49dd6a31368a6de5&_M)
	z10.Rsh(z10, 21)
	mdiv(z10, z10, 722925, 0x49dd6a31368a6de5&_M)
	z11.Rsh(z11, 26)
	mdiv(z11, z11, 739552275, 0xbdc1e7d4816dfe1b&_M)
	z12.Rsh(z12, 31)
	mdiv(z12, z12, 739552275, 0xbdc1e7d4816dfe1b&_M)
	z13.Rsh(z13, 37)
	mdiv64(z13, z13, 3028466566125, 0x8a44806683b051e5&_M, 461586125, 0x9e2de05, 6561, 0x37360a61)
	z14.Rsh(z14, 43)
	mdiv64(z14, z14, 3028466566125, 0x8a44806683b051e5&_M, 461586125, 0x9e2de05, 6561, 0x37360a61)
	z15.Rsh(z15, 50)
	mdiv64(z15, z15, 49615367752825875, 0x471f458f17d66e1b&_M, 411546421, 0xb6d9471d, 120558375, 0x769eac97)
	z16.Rsh(z16, 57)
	mdiv64(z16, z16, 49615367752825875, 0x471f458f17d66e1b&_M, 411546421, 0xb6d9471d, 120558375, 0x769eac97)
	z17.Rsh(z17, 65)
	mdiv64(z17, z17, 49615367752825875, 0x471f458f17d66e1b&_M, 411546421, 0xb6d9471d, 120558375, 0x769eac97)
	mdiv(z17, z17, 65535, 0xfffefffefffeffff&_M)
	z18.Rsh(z18, 73)
	mdiv64(z18, z18, 49615367752825875, 0x471f458f17d66e1b&_M, 411546421, 0xb6d9471d, 120558375, 0x769eac97)
	mdiv(z18, z18, 65535, 0xfffefffefffeffff&_M)
	z19.Rsh(z19, 82)
	mdiv64(z19, z19, 49615367752825875, 0x471f458f17d66e1b&_M, 411546421, 0xb6d9471d, 120558375, 0x769eac97)
	mdiv64(z19, z19, 17179541505, 0x55001500050001&_M, 212093105, 0x1950051, 81, 0x781948b1)
	z20.Rsh(z20, 91)
	mdiv64(z20, z20, 49615367752825875, 0x471f458f17d66e1b&_M, 411546421, 0xb6d9471d, 120558375, 0x769eac97)
	mdiv64(z20, z20, 17179541505, 0x55001500050001&_M, 212093105, 0x1950051, 81, 0x781948b1)
	z21.Rsh(z21, 100)
	mdiv64(z21, z21, 49615367752825875, 0x471f458f17d66e1b&_M, 411546421, 0xb6d9471d, 120558375, 0x769eac97)
	mdiv64(z21, z21, 18014037733605375, 0xe95afe9affeaffff&_M, 453059389, 0x3b5a4c15, 39760875, 0xfddcc3)

	// Back substitution: starting from the top (z22 is already the leading
	// coefficient), eliminate each higher coefficient from the lower entries.
	subLsh(z21, z21, z22, 10)
	z4.Sub(z4, z20)
	subMul(z20, z20, z22, 349525)
	z3.Sub(z3, z19)
	subMul(z19, z19, z21, 349525)
	z6.Sub(z6, z18)
	subMul(z18, z18, z22, 6108368805)
	subMul(z18, z18, z20, 87381)
	z5.Sub(z5, z17)
	subMul(z17, z17, z21, 6108368805)
	subMul(z17, z17, z19, 87381)
	z8.Sub(z8, z16)
	subMul(z16, z16, z22, 6354157930725)
	subMul(z16, z16, z20, 381767589)
	z4.Sub(z4, z16)
	subMul(z16, z16, z18, 21845)
	z7.Sub(z7, z15)
	subMul(z15, z15, z21, 6354157930725)
	subMul(z15, z15, z19, 381767589)
	z3.Sub(z3, z15)
	subMul(z15, z15, z17, 21845)
	z10.Sub(z10, z14)
	subMul(z14, z14, z22, 408235958349285)
	subMul(z14, z14, z20, 99277752549)
	subMul(z14, z14, z18, 23859109)
	subMul(z14, z14, z16, 5461)
	z9.Sub(z9, z13)
	subMul(z13, z13, z21, 408235958349285)
	subMul(z13, z13, z19, 99277752549)
	subMul(z13, z13, z17, 23859109)
	subMul(z13, z13, z15, 5461)
	subMul(z12, z12, z22, 1634141006295525)
	subMul(z12, z12, z20, 1594283908581)
	subMul(z12, z12, z18, 1550842085)
	z6.Sub(z6, z12)
	subMul(z12, z12, z16, 1490853)
	z4.Sub(z4, z12)
	subMul(z12, z12, z14, 1365)
	subMul(z11, z11, z21, 1634141006295525)
	subMul(z11, z11, z19, 1594283908581)
	subMul(z11, z11, z17, 1550842085)
	z5.Sub(z5, z11)
	subMul(z11, z11, z15, 1490853)
	z3.Sub(z3, z11)
	subMul(z11, z11, z13, 1365)
	subMul(z10, z10, z20, 1495006156032)
	subMul(z10, z10, z18, 6197754432)
	subMul(z10, z10, z16, 24203152)
	subMul(z10, z10, z14, 93092)
	subMul(z10, z10, z12, 341)
	subMul(z9, z9, z19, 1495006156032)
	subMul(z9, z9, z17, 6197754432)
	subMul(z9, z9, z15, 24203152)
	subMul(z9, z9, z13, 93092)
	subMul(z9, z9, z11, 341)
	subMul(z8, z8, z20, 98895984960)
	subMul(z8, z8, z18, 1550820240)
	subMul(z8, z8, z16, 24208612)
	subMul(z8, z8, z14, 376805)
	subMul(z8, z8, z12, 5797)
	z4.Sub(z4, z8)
	subMul(z8, z8, z10, 85)
	subMul(z7, z7, z19, 98895984960)
	subMul(z7, z7, z17, 1550820240)
	subMul(z7, z7, z15, 24208612)
	subMul(z7, z7, z13, 376805)
	subMul(z7, z7, z11, 5797)
	z3.Sub(z3, z7)
	subMul(z7, z7, z9, 85)
	subMul(z6, z6, z20, 381680208)
	subMul(z6, z6, z18, 23859108)
	subMul(z6, z6, z14, 91728)
	subMul(z6, z6, z12, 5796)
	subMul(z6, z6, z10, 357)
	subMul(z6, z6, z8, 21)
	subMul(z5, z5, z19, 381680208)
	subMul(z5, z5, z17, 23859108)
	subMul(z5, z5, z13, 91728)
	subMul(z5, z5, z11, 5796)
	subMul(z5, z5, z9, 357)
	subMul(z5, z5, z7, 21)
	subMul(z4, z4, z20, 87380)
	subMul(z4, z4, z16, 5460)
	subMul(z4, z4, z12, 340)
	subMul(z4, z4, z8, 20)
	subMul(z4, z4, z6, 5)
	subMul(z3, z3, z19, 87380)
	subMul(z3, z3, z15, 5460)
	subMul(z3, z3, z11, 340)
	subMul(z3, z3, z7, 20)
	subMul(z3, z3, z5, 5)

	// z2 and z1 were derived from the points ±1 (halved above). They still
	// contain the sums of the higher even and odd coefficients; subtracting
	// those leaves the coefficients of degree 2 and 1.
	z2.Sub(z2, z22)
	z2.Sub(z2, z20)
	z2.Sub(z2, z18)
	z2.Sub(z2, z16)
	z2.Sub(z2, z14)
	z2.Sub(z2, z12)
	z2.Sub(z2, z10)
	z2.Sub(z2, z8)
	z2.Sub(z2, z6)
	z2.Sub(z2, z4)
	z1.Sub(z1, z21)
	z1.Sub(z1, z19)
	z1.Sub(z1, z17)
	z1.Sub(z1, z15)
	z1.Sub(z1, z13)
	z1.Sub(z1, z11)
	z1.Sub(z1, z9)
	z1.Sub(z1, z7)
	z1.Sub(z1, z5)
	z1.Sub(z1, z3)

	// Assemble the product: z = Σ zi · 2^(_W·i·p).
	// Coefficients may overlap adjacent p-word windows, hence addTo rather
	// than copy.
	clear(z)
	addTo(z[0*p:], z0.abs)
	addTo(z[1*p:], z1.abs)
	addTo(z[2*p:], z2.abs)
	addTo(z[3*p:], z3.abs)
	addTo(z[4*p:], z4.abs)
	addTo(z[5*p:], z5.abs)
	addTo(z[6*p:], z6.abs)
	addTo(z[7*p:], z7.abs)
	addTo(z[8*p:], z8.abs)
	addTo(z[9*p:], z9.abs)
	addTo(z[10*p:], z10.abs)
	addTo(z[11*p:], z11.abs)
	addTo(z[12*p:], z12.abs)
	addTo(z[13*p:], z13.abs)
	addTo(z[14*p:], z14.abs)
	addTo(z[15*p:], z15.abs)
	addTo(z[16*p:], z16.abs)
	addTo(z[17*p:], z17.abs)
	addTo(z[18*p:], z18.abs)
	addTo(z[19*p:], z19.abs)
	addTo(z[20*p:], z20.abs)
	addTo(z[21*p:], z21.abs)
	addTo(z[22*p:], z22.abs)
}
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

package big

import "math/rand"

// ProbablyPrime reports whether x is probably prime,
// applying the Miller-Rabin test with n pseudorandomly chosen bases
// as well as a Baillie-PSW test.
//
// If x is prime, ProbablyPrime returns true.
// If x is chosen randomly and not prime, ProbablyPrime probably returns false.
// The probability of returning true for a randomly chosen non-prime is at most ¼ⁿ.
//
// ProbablyPrime is 100% accurate for inputs less than 2⁶⁴.
// See Menezes et al., Handbook of Applied Cryptography, 1997, pp. 145-149,
// and FIPS 186-4 Appendix F for further discussion of the error probabilities.
//
// ProbablyPrime is not suitable for judging primes that an adversary may
// have crafted to fool the test.
//
// As of Go 1.8, ProbablyPrime(0) is allowed and applies only a Baillie-PSW test.
// Before Go 1.8, ProbablyPrime applied only the Miller-Rabin tests, and ProbablyPrime(0) panicked.
func (x *Int) ProbablyPrime(n int) bool {
	// Note regarding the doc comment above:
	// It would be more precise to say that the Baillie-PSW test uses the
	// extra strong Lucas test as its Lucas test, but since no one knows
	// how to tell any of the Lucas tests apart inside a Baillie-PSW test
	// (they all work equally well empirically), that detail need not be
	// documented or implicitly guaranteed.
	// The comment does avoid saying "the" Baillie-PSW test
	// because of this general ambiguity.

	if n < 0 {
		panic("negative n for ProbablyPrime")
	}
	// Negative numbers and zero are never prime.
	if x.neg || len(x.abs) == 0 {
		return false
	}

	// primeBitMask records the primes < 64.
	const primeBitMask uint64 = 1<<2 | 1<<3 | 1<<5 | 1<<7 |
		1<<11 | 1<<13 | 1<<17 | 1<<19 | 1<<23 | 1<<29 | 1<<31 |
		1<<37 | 1<<41 | 1<<43 | 1<<47 | 1<<53 | 1<<59 | 1<<61

	// Single-word values below 64 are answered exactly by table lookup.
	w := x.abs[0]
	if len(x.abs) == 1 && w < 64 {
		return primeBitMask&(1<<w) != 0
	}

	if w&1 == 0 {
		return false // x is even
	}

	// Trial division by the odd primes below 64 (except 59 and 61).
	// The primes are split into two groups whose products each fit in 32 bits:
	//   primesA = 4127218095
	//   primesB = 3948025067
	// Their combined product, about 1.63e19, still fits in 64 bits.
	const primesA = 3 * 5 * 7 * 11 * 13 * 17 * 19 * 23 * 37
	const primesB = 29 * 31 * 41 * 43 * 47 * 53

	var rA, rB uint32
	switch _W {
	case 32:
		// Two single-word remainders.
		rA = uint32(x.abs.modW(primesA))
		rB = uint32(x.abs.modW(primesB))
	case 64:
		// One remainder modulo the combined product, then split it.
		// The &_M keeps the constant expression representable on 32-bit
		// builds, where this case is compiled but never taken.
		r := x.abs.modW((primesA * primesB) & _M)
		rA = uint32(r % primesA)
		rB = uint32(r % primesB)
	default:
		panic("math/big: invalid word size")
	}

	// Here x >= 64, so divisibility by any of these primes (all < 64)
	// proves x composite.
	if rA%3 == 0 || rA%5 == 0 || rA%7 == 0 || rA%11 == 0 || rA%13 == 0 || rA%17 == 0 || rA%19 == 0 || rA%23 == 0 || rA%37 == 0 ||
		rB%29 == 0 || rB%31 == 0 || rB%41 == 0 || rB%43 == 0 || rB%47 == 0 || rB%53 == 0 {
		return false
	}

	stk := getStack()
	defer stk.free()

	// n+1 Miller-Rabin rounds with force2: n pseudo-random bases plus base 2.
	// Miller-Rabin base 2 combined with the Lucas test forms the Baillie-PSW
	// test, so ProbablyPrime(0) runs only that.
	return x.abs.probablyPrimeMillerRabin(stk, n+1, true) && x.abs.probablyPrimeLucas(stk)
}

// probablyPrimeMillerRabin reports whether n passes reps rounds of the
// Miller-Rabin primality test, using pseudo-randomly chosen bases.
// If force2 is true, one of the rounds is forced to use base 2.
// See Handbook of Applied Cryptography, p. 139, Algorithm 4.24.
// The number n is known to be non-zero.
func (n nat) probablyPrimeMillerRabin(stk *stack, reps int, force2 bool) bool { nm1 := nat(nil).sub(n, natOne) // determine q, k such that nm1 = q << k k := nm1.trailingZeroBits() q := nat(nil).rsh(nm1, k) nm3 := nat(nil).sub(nm1, natTwo) rand := rand.New(rand.NewSource(int64(n[0]))) var x, y, quotient nat nm3Len := nm3.bitLen() NextRandom: for i := 0; i < reps; i++ { if i == reps-1 && force2 { x = x.set(natTwo) } else { x = x.random(rand, nm3, nm3Len) x = x.add(x, natTwo) } y = y.expNN(stk, x, q, n, false) if y.cmp(natOne) == 0 || y.cmp(nm1) == 0 { continue } for j := uint(1); j < k; j++ { y = y.sqr(stk, y) quotient, y = quotient.div(stk, y, y, n) if y.cmp(nm1) == 0 { continue NextRandom } if y.cmp(natOne) == 0 { return false } } return false } return true } // probablyPrimeLucas reports whether n passes the "almost extra strong" Lucas probable prime test, // using Baillie-OEIS parameter selection. This corresponds to "AESLPSP" on Jacobsen's tables (link below). // The combination of this test and a Miller-Rabin/Fermat test with base 2 gives a Baillie-PSW test. // // References: // // Baillie and Wagstaff, "Lucas Pseudoprimes", Mathematics of Computation 35(152), // October 1980, pp. 1391-1417, especially page 1401. // https://www.ams.org/journals/mcom/1980-35-152/S0025-5718-1980-0583518-6/S0025-5718-1980-0583518-6.pdf // // Grantham, "Frobenius Pseudoprimes", Mathematics of Computation 70(234), // March 2000, pp. 873-891. // https://www.ams.org/journals/mcom/2001-70-234/S0025-5718-00-01197-2/S0025-5718-00-01197-2.pdf // // Baillie, "Extra strong Lucas pseudoprimes", OEIS A217719, https://oeis.org/A217719. // // Jacobsen, "Pseudoprime Statistics, Tables, and Data", http://ntheory.org/pseudoprimes.html. // // Nicely, "The Baillie-PSW Primality Test", https://web.archive.org/web/20191121062007/http://www.trnicely.net/misc/bpsw.html. // (Note that Nicely's definition of the "extra strong" test gives the wrong Jacobi condition, // as pointed out by Jacobsen.) 
// // Crandall and Pomerance, Prime Numbers: A Computational Perspective, 2nd ed. // Springer, 2005. func (n nat) probablyPrimeLucas(stk *stack) bool { // Discard 0, 1. if len(n) == 0 || n.cmp(natOne) == 0 { return false } // Two is the only even prime. // Already checked by caller, but here to allow testing in isolation. if n[0]&1 == 0 { return n.cmp(natTwo) == 0 } // Baillie-OEIS "method C" for choosing D, P, Q, // as in https://oeis.org/A217719/a217719.txt: // try increasing P ≥ 3 such that D = P² - 4 (so Q = 1) // until Jacobi(D, n) = -1. // The search is expected to succeed for non-square n after just a few trials. // After more than expected failures, check whether n is square // (which would cause Jacobi(D, n) = 1 for all D not dividing n). p := Word(3) d := nat{1} t1 := nat(nil) // temp intD := &Int{abs: d} intN := &Int{abs: n} for ; ; p++ { if p > 10000 { // This is widely believed to be impossible. // If we get a report, we'll want the exact number n. panic("math/big: internal error: cannot find (D/n) = -1 for " + intN.String()) } d[0] = p*p - 4 j := Jacobi(intD, intN) if j == -1 { break } if j == 0 { // d = p²-4 = (p-2)(p+2). // If (d/n) == 0 then d shares a prime factor with n. // Since the loop proceeds in increasing p and starts with p-2==1, // the shared prime factor must be p+2. // If p+2 == n, then n is prime; otherwise p+2 is a proper factor of n. return len(n) == 1 && n[0] == p+2 } if p == 40 { // We'll never find (d/n) = -1 if n is a square. // If n is a non-square we expect to find a d in just a few attempts on average. // After 40 attempts, take a moment to check if n is indeed a square. t1 = t1.sqrt(stk, n) t1 = t1.sqr(stk, t1) if t1.cmp(n) == 0 { return false } } } // Grantham definition of "extra strong Lucas pseudoprime", after Thm 2.3 on p. 876 // (D, P, Q above have become Δ, b, 1): // // Let U_n = U_n(b, 1), V_n = V_n(b, 1), and Δ = b²-4. 
// An extra strong Lucas pseudoprime to base b is a composite n = 2^r s + Jacobi(Δ, n), // where s is odd and gcd(n, 2*Δ) = 1, such that either (i) U_s ≡ 0 mod n and V_s ≡ ±2 mod n, // or (ii) V_{2^t s} ≡ 0 mod n for some 0 ≤ t < r-1. // // We know gcd(n, Δ) = 1 or else we'd have found Jacobi(d, n) == 0 above. // We know gcd(n, 2) = 1 because n is odd. // // Arrange s = (n - Jacobi(Δ, n)) / 2^r = (n+1) / 2^r. s := nat(nil).add(n, natOne) r := int(s.trailingZeroBits()) s = s.rsh(s, uint(r)) nm2 := nat(nil).sub(n, natTwo) // n-2 // We apply the "almost extra strong" test, which checks the above conditions // except for U_s ≡ 0 mod n, which allows us to avoid computing any U_k values. // Jacobsen points out that maybe we should just do the full extra strong test: // "It is also possible to recover U_n using Crandall and Pomerance equation 3.13: // U_n = D^-1 (2V_{n+1} - PV_n) allowing us to run the full extra-strong test // at the cost of a single modular inversion. This computation is easy and fast in GMP, // so we can get the full extra-strong test at essentially the same performance as the // almost extra strong test." // Compute Lucas sequence V_s(b, 1), where: // // V(0) = 2 // V(1) = P // V(k) = P V(k-1) - Q V(k-2). // // (Remember that due to method C above, P = b, Q = 1.) // // In general V(k) = α^k + β^k, where α and β are roots of x² - Px + Q. // Crandall and Pomerance (p.147) observe that for 0 ≤ j ≤ k, // // V(j+k) = V(j)V(k) - V(k-j). // // So in particular, to quickly double the subscript: // // V(2k) = V(k)² - 2 // V(2k+1) = V(k) V(k+1) - P // // We can therefore start with k=0 and build up to k=s in log₂(s) steps. natP := nat(nil).setWord(p) vk := nat(nil).setWord(2) vk1 := nat(nil).setWord(p) t2 := nat(nil) // temp for i := int(s.bitLen()); i >= 0; i-- { if s.bit(uint(i)) != 0 { // k' = 2k+1 // V(k') = V(2k+1) = V(k) V(k+1) - P. 
t1 = t1.mul(stk, vk, vk1) t1 = t1.add(t1, n) t1 = t1.sub(t1, natP) t2, vk = t2.div(stk, vk, t1, n) // V(k'+1) = V(2k+2) = V(k+1)² - 2. t1 = t1.sqr(stk, vk1) t1 = t1.add(t1, nm2) t2, vk1 = t2.div(stk, vk1, t1, n) } else { // k' = 2k // V(k'+1) = V(2k+1) = V(k) V(k+1) - P. t1 = t1.mul(stk, vk, vk1) t1 = t1.add(t1, n) t1 = t1.sub(t1, natP) t2, vk1 = t2.div(stk, vk1, t1, n) // V(k') = V(2k) = V(k)² - 2 t1 = t1.sqr(stk, vk) t1 = t1.add(t1, nm2) t2, vk = t2.div(stk, vk, t1, n) } } // Now k=s, so vk = V(s). Check V(s) ≡ ±2 (mod n). if vk.cmp(natTwo) == 0 || vk.cmp(nm2) == 0 { // Check U(s) ≡ 0. // As suggested by Jacobsen, apply Crandall and Pomerance equation 3.13: // // U(k) = D⁻¹ (2 V(k+1) - P V(k)) // // Since we are checking for U(k) == 0 it suffices to check 2 V(k+1) == P V(k) mod n, // or P V(k) - 2 V(k+1) == 0 mod n. t1 := t1.mul(stk, vk, natP) t2 := t2.lsh(vk1, 1) if t1.cmp(t2) < 0 { t1, t2 = t2, t1 } t1 = t1.sub(t1, t2) t3 := vk1 // steal vk1, no longer needed below vk1 = nil _ = vk1 t2, t3 = t2.div(stk, t3, t1, n) if len(t3) == 0 { return true } } // Check V(2^t s) ≡ 0 mod n for some 0 ≤ t < r-1. for t := 0; t < r-1; t++ { if len(vk) == 0 { // vk == 0 return true } // Optimization: V(k) = 2 is a fixed point for V(k') = V(k)² - 2, // so if V(k) = 2, we can stop: we will never find a future V(k) == 0. if len(vk) == 1 && vk[0] == 2 { // vk == 2 return false } // k' = 2k // V(k') = V(2k) = V(k)² - 2 t1 = t1.sqr(stk, vk) t1 = t1.sub(t1, natTwo) t2, vk = t2.div(stk, vk, t1, n) } return false }
// Copyright 2010 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// This file implements multi-precision rational numbers.

package big

import (
	"fmt"
	"math"
)

// A Rat represents a quotient a/b of arbitrary precision.
// The zero value for a Rat represents the value 0.
//
// Operations always take pointer arguments (*Rat) rather
// than Rat values, and each unique Rat value requires
// its own unique *Rat pointer. To "copy" a Rat value,
// an existing (or newly allocated) Rat must be set to
// a new value using the [Rat.Set] method; shallow copies
// of Rats are not supported and may lead to errors.
type Rat struct {
	// To make zero values for Rat work w/o initialization,
	// a zero value of b (len(b) == 0) acts like b == 1. At
	// the earliest opportunity (when an assignment to the Rat
	// is made), such uninitialized denominators are set to 1.
	// a.neg determines the sign of the Rat, b.neg is ignored.
	a, b Int
}

// NewRat creates a new [Rat] with numerator a and denominator b.
func NewRat(a, b int64) *Rat {
	return new(Rat).SetFrac64(a, b)
}

// SetFloat64 sets z to exactly f and returns z.
// If f is not finite, SetFloat64 returns nil.
func (z *Rat) SetFloat64(f float64) *Rat {
	// Decode the IEEE 754 binary64 fields: 52-bit fraction, 11-bit biased exponent.
	const expMask = 1<<11 - 1
	bits := math.Float64bits(f)
	mantissa := bits & (1<<52 - 1)
	exp := int((bits >> 52) & expMask)
	switch exp {
	case expMask: // non-finite
		return nil
	case 0: // denormal
		// No implicit leading 1; the effective exponent is 1-1023 = -1022.
		exp -= 1022
	default: // normal
		mantissa |= 1 << 52
		exp -= 1023
	}

	// |f| = mantissa · 2^(exp-52) = mantissa / 2^shift.
	shift := 52 - exp

	// Optimization (?): partially pre-normalise.
	// Cancel common factors of two between mantissa and the power-of-two
	// denominator before building the Ints.
	for mantissa&1 == 0 && shift > 0 {
		mantissa >>= 1
		shift--
	}

	z.a.SetUint64(mantissa)
	z.a.neg = f < 0
	z.b.Set(intOne)
	// A positive shift goes into the denominator; a non-positive shift scales
	// the numerator instead.
	if shift > 0 {
		z.b.Lsh(&z.b, uint(shift))
	} else {
		z.a.Lsh(&z.a, uint(-shift))
	}
	return z.norm()
}

// quotToFloat32 returns the non-negative float32 value
// nearest to the quotient a/b, using round-to-even in
// halfway cases. It does not mutate its arguments.
// Preconditions: b is non-zero; a and b have no common factors. func quotToFloat32(stk *stack, a, b nat) (f float32, exact bool) { const ( // float size in bits Fsize = 32 // mantissa Msize = 23 Msize1 = Msize + 1 // incl. implicit 1 Msize2 = Msize1 + 1 // exponent Esize = Fsize - Msize1 Ebias = 1<<(Esize-1) - 1 Emin = 1 - Ebias Emax = Ebias ) // TODO(adonovan): specialize common degenerate cases: 1.0, integers. alen := a.bitLen() if alen == 0 { return 0, true } blen := b.bitLen() if blen == 0 { panic("division by zero") } // 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1) // (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B). // This is 2 or 3 more than the float32 mantissa field width of Msize: // - the optional extra bit is shifted away in step 3 below. // - the high-order 1 is omitted in "normal" representation; // - the low-order 1 will be used during rounding then discarded. exp := alen - blen var a2, b2 nat a2 = a2.set(a) b2 = b2.set(b) if shift := Msize2 - exp; shift > 0 { a2 = a2.lsh(a2, uint(shift)) } else if shift < 0 { b2 = b2.lsh(b2, uint(-shift)) } // 2. Compute quotient and remainder (q, r). NB: due to the // extra shift, the low-order bit of q is logically the // high-order bit of r. var q nat q, r := q.div(stk, a2, a2, b2) // (recycle a2) mantissa := low32(q) haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half // 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1 // (in effect---we accomplish this incrementally). if mantissa>>Msize2 == 1 { if mantissa&1 == 1 { haveRem = true } mantissa >>= 1 exp++ } if mantissa>>Msize1 != 1 { panic(fmt.Sprintf("expected exactly %d bits of result", Msize2)) } // 4. Rounding. if Emin-Msize <= exp && exp <= Emin { // Denormal case; lose 'shift' bits of precision. 
shift := uint(Emin - (exp - 1)) // [1..Esize1) lostbits := mantissa & (1<<shift - 1) haveRem = haveRem || lostbits != 0 mantissa >>= shift exp = 2 - Ebias // == exp + shift } // Round q using round-half-to-even. exact = !haveRem if mantissa&1 != 0 { exact = false if haveRem || mantissa&2 != 0 { if mantissa++; mantissa >= 1<<Msize2 { // Complete rollover 11...1 => 100...0, so shift is safe mantissa >>= 1 exp++ } } } mantissa >>= 1 // discard rounding bit. Mantissa now scaled by 1<<Msize1. f = float32(math.Ldexp(float64(mantissa), exp-Msize1)) if math.IsInf(float64(f), 0) { exact = false } return } // quotToFloat64 returns the non-negative float64 value // nearest to the quotient a/b, using round-to-even in // halfway cases. It does not mutate its arguments. // Preconditions: b is non-zero; a and b have no common factors. func quotToFloat64(stk *stack, a, b nat) (f float64, exact bool) { const ( // float size in bits Fsize = 64 // mantissa Msize = 52 Msize1 = Msize + 1 // incl. implicit 1 Msize2 = Msize1 + 1 // exponent Esize = Fsize - Msize1 Ebias = 1<<(Esize-1) - 1 Emin = 1 - Ebias Emax = Ebias ) // TODO(adonovan): specialize common degenerate cases: 1.0, integers. alen := a.bitLen() if alen == 0 { return 0, true } blen := b.bitLen() if blen == 0 { panic("division by zero") } // 1. Left-shift A or B such that quotient A/B is in [1<<Msize1, 1<<(Msize2+1) // (Msize2 bits if A < B when they are left-aligned, Msize2+1 bits if A >= B). // This is 2 or 3 more than the float64 mantissa field width of Msize: // - the optional extra bit is shifted away in step 3 below. // - the high-order 1 is omitted in "normal" representation; // - the low-order 1 will be used during rounding then discarded. exp := alen - blen var a2, b2 nat a2 = a2.set(a) b2 = b2.set(b) if shift := Msize2 - exp; shift > 0 { a2 = a2.lsh(a2, uint(shift)) } else if shift < 0 { b2 = b2.lsh(b2, uint(-shift)) } // 2. Compute quotient and remainder (q, r). 
NB: due to the // extra shift, the low-order bit of q is logically the // high-order bit of r. var q nat q, r := q.div(stk, a2, a2, b2) // (recycle a2) mantissa := low64(q) haveRem := len(r) > 0 // mantissa&1 && !haveRem => remainder is exactly half // 3. If quotient didn't fit in Msize2 bits, redo division by b2<<1 // (in effect---we accomplish this incrementally). if mantissa>>Msize2 == 1 { if mantissa&1 == 1 { haveRem = true } mantissa >>= 1 exp++ } if mantissa>>Msize1 != 1 { panic(fmt.Sprintf("expected exactly %d bits of result", Msize2)) } // 4. Rounding. if Emin-Msize <= exp && exp <= Emin { // Denormal case; lose 'shift' bits of precision. shift := uint(Emin - (exp - 1)) // [1..Esize1) lostbits := mantissa & (1<<shift - 1) haveRem = haveRem || lostbits != 0 mantissa >>= shift exp = 2 - Ebias // == exp + shift } // Round q using round-half-to-even. exact = !haveRem if mantissa&1 != 0 { exact = false if haveRem || mantissa&2 != 0 { if mantissa++; mantissa >= 1<<Msize2 { // Complete rollover 11...1 => 100...0, so shift is safe mantissa >>= 1 exp++ } } } mantissa >>= 1 // discard rounding bit. Mantissa now scaled by 1<<Msize1. f = math.Ldexp(float64(mantissa), exp-Msize1) if math.IsInf(f, 0) { exact = false } return } // Float32 returns the nearest float32 value for x and a bool indicating // whether f represents x exactly. If the magnitude of x is too large to // be represented by a float32, f is an infinity and exact is false. // The sign of f always matches the sign of x, even if f == 0. func (x *Rat) Float32() (f float32, exact bool) { b := x.b.abs if len(b) == 0 { b = natOne } stk := getStack() defer stk.free() f, exact = quotToFloat32(stk, x.a.abs, b) if x.a.neg { f = -f } return } // Float64 returns the nearest float64 value for x and a bool indicating // whether f represents x exactly. If the magnitude of x is too large to // be represented by a float64, f is an infinity and exact is false. // The sign of f always matches the sign of x, even if f == 0. 
func (x *Rat) Float64() (f float64, exact bool) { b := x.b.abs if len(b) == 0 { b = natOne } stk := getStack() defer stk.free() f, exact = quotToFloat64(stk, x.a.abs, b) if x.a.neg { f = -f } return } // SetFrac sets z to a/b and returns z. // If b == 0, SetFrac panics. func (z *Rat) SetFrac(a, b *Int) *Rat { z.a.neg = a.neg != b.neg babs := b.abs if len(babs) == 0 { panic("division by zero") } if &z.a == b || alias(z.a.abs, babs) { babs = nat(nil).set(babs) // make a copy } z.a.abs = z.a.abs.set(a.abs) z.b.abs = z.b.abs.set(babs) return z.norm() } // SetFrac64 sets z to a/b and returns z. // If b == 0, SetFrac64 panics. func (z *Rat) SetFrac64(a, b int64) *Rat { if b == 0 { panic("division by zero") } z.a.SetInt64(a) if b < 0 { b = -b z.a.neg = !z.a.neg } z.b.abs = z.b.abs.setUint64(uint64(b)) return z.norm() } // SetInt sets z to x (by making a copy of x) and returns z. func (z *Rat) SetInt(x *Int) *Rat { z.a.Set(x) z.b.abs = z.b.abs.setWord(1) return z } // SetInt64 sets z to x and returns z. func (z *Rat) SetInt64(x int64) *Rat { z.a.SetInt64(x) z.b.abs = z.b.abs.setWord(1) return z } // SetUint64 sets z to x and returns z. func (z *Rat) SetUint64(x uint64) *Rat { z.a.SetUint64(x) z.b.abs = z.b.abs.setWord(1) return z } // Set sets z to x (by making a copy of x) and returns z. func (z *Rat) Set(x *Rat) *Rat { if z != x { z.a.Set(&x.a) z.b.Set(&x.b) } if len(z.b.abs) == 0 { z.b.abs = z.b.abs.setWord(1) } return z } // Abs sets z to |x| (the absolute value of x) and returns z. func (z *Rat) Abs(x *Rat) *Rat { z.Set(x) z.a.neg = false return z } // Neg sets z to -x and returns z. func (z *Rat) Neg(x *Rat) *Rat { z.Set(x) z.a.neg = len(z.a.abs) > 0 && !z.a.neg // 0 has no sign return z } // Inv sets z to 1/x and returns z. // If x == 0, Inv panics. func (z *Rat) Inv(x *Rat) *Rat { if len(x.a.abs) == 0 { panic("division by zero") } z.Set(x) z.a.abs, z.b.abs = z.b.abs, z.a.abs return z } // Sign returns: // - -1 if x < 0; // - 0 if x == 0; // - +1 if x > 0. 
func (x *Rat) Sign() int {
	return x.a.Sign()
}

// IsInt reports whether the denominator of x is 1.
func (x *Rat) IsInt() bool {
	if len(x.b.abs) == 0 {
		return true // a zero-value denominator stands for 1
	}
	return x.b.abs.cmp(natOne) == 0
}

// Num returns the numerator of x; it may be <= 0.
// The result is a reference to x's numerator; it
// may change if a new value is assigned to x, and vice versa.
// The sign of the numerator corresponds to the sign of x.
func (x *Rat) Num() *Int {
	return &x.a
}

// Denom returns the denominator of x; it is always > 0.
// The result is a reference to x's denominator, unless
// x is an uninitialized (zero value) [Rat], in which case
// the result is a new [Int] of value 1. (To initialize x,
// any operation that sets x will do, including x.Set(x).)
// If the result is a reference to x's denominator it
// may change if a new value is assigned to x, and vice versa.
func (x *Rat) Denom() *Int {
	// Note that x.b.neg is guaranteed false.
	if len(x.b.abs) != 0 {
		return &x.b
	}
	// Note: If this proves problematic, we could
	// panic instead and require the Rat to
	// be explicitly initialized.
	return &Int{abs: nat{1}}
}

// norm puts z into canonical form and returns z: a zero value gets a
// positive sign and denominator 1, an integer with an uninitialized
// denominator gets an explicit denominator 1, and a fraction has its
// numerator and denominator divided by their gcd (via lehmerGCD).
func (z *Rat) norm() *Rat {
	if len(z.a.abs) == 0 {
		// z == 0; normalize sign and denominator
		z.a.neg = false
		z.b.abs = z.b.abs.setWord(1)
		return z
	}
	if len(z.b.abs) == 0 {
		// z is integer; normalize denominator
		z.b.abs = z.b.abs.setWord(1)
		return z
	}

	// z is a fraction; reduce it to lowest terms.
	stk := getStack()
	defer stk.free()
	sign := z.a.neg
	// Clear both signs before calling lehmerGCD (presumably it expects
	// non-negative operands); the numerator's sign is restored below.
	z.a.neg = false
	z.b.neg = false
	if g := NewInt(0).lehmerGCD(nil, nil, &z.a, &z.b); g.Cmp(intOne) != 0 {
		z.a.abs, _ = z.a.abs.div(stk, nil, z.a.abs, g.abs)
		z.b.abs, _ = z.b.abs.div(stk, nil, z.b.abs, g.abs)
	}
	z.a.neg = sign
	return z
}

// mulDenom sets z to the denominator product x*y (by taking into
// account that 0 values for x or y must be interpreted as 1) and
// returns z.
func mulDenom(stk *stack, z, x, y nat) nat { switch { case len(x) == 0 && len(y) == 0: return z.setWord(1) case len(x) == 0: return z.set(y) case len(y) == 0: return z.set(x) } return z.mul(stk, x, y) } // scaleDenom sets z to the product x*f. // If f == 0 (zero value of denominator), z is set to (a copy of) x. func (z *Int) scaleDenom(stk *stack, x *Int, f nat) { if len(f) == 0 { z.Set(x) return } z.abs = z.abs.mul(stk, x.abs, f) z.neg = x.neg } // Cmp compares x and y and returns: // - -1 if x < y; // - 0 if x == y; // - +1 if x > y. func (x *Rat) Cmp(y *Rat) int { var a, b Int stk := getStack() defer stk.free() a.scaleDenom(stk, &x.a, y.b.abs) b.scaleDenom(stk, &y.a, x.b.abs) return a.Cmp(&b) } // Add sets z to the sum x+y and returns z. func (z *Rat) Add(x, y *Rat) *Rat { stk := getStack() defer stk.free() var a1, a2 Int a1.scaleDenom(stk, &x.a, y.b.abs) a2.scaleDenom(stk, &y.a, x.b.abs) z.a.Add(&a1, &a2) z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs) return z.norm() } // Sub sets z to the difference x-y and returns z. func (z *Rat) Sub(x, y *Rat) *Rat { stk := getStack() defer stk.free() var a1, a2 Int a1.scaleDenom(stk, &x.a, y.b.abs) a2.scaleDenom(stk, &y.a, x.b.abs) z.a.Sub(&a1, &a2) z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs) return z.norm() } // Mul sets z to the product x*y and returns z. func (z *Rat) Mul(x, y *Rat) *Rat { stk := getStack() defer stk.free() if x == y { // a squared Rat is positive and can't be reduced (no need to call norm()) z.a.neg = false z.a.abs = z.a.abs.sqr(stk, x.a.abs) if len(x.b.abs) == 0 { z.b.abs = z.b.abs.setWord(1) } else { z.b.abs = z.b.abs.sqr(stk, x.b.abs) } return z } z.a.mul(stk, &x.a, &y.a) z.b.abs = mulDenom(stk, z.b.abs, x.b.abs, y.b.abs) return z.norm() } // Quo sets z to the quotient x/y and returns z. // If y == 0, Quo panics. 
func (z *Rat) Quo(x, y *Rat) *Rat { stk := getStack() defer stk.free() if len(y.a.abs) == 0 { panic("division by zero") } var a, b Int a.scaleDenom(stk, &x.a, y.b.abs) b.scaleDenom(stk, &y.a, x.b.abs) z.a.abs = a.abs z.b.abs = b.abs z.a.neg = a.neg != b.neg return z.norm() }
// Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This file implements rat-to-string conversion functions. package big import ( "errors" "fmt" "io" "strconv" "strings" ) func ratTok(ch rune) bool { return strings.ContainsRune("+-/0123456789.eE", ch) } var ratZero Rat var _ fmt.Scanner = &ratZero // *Rat must implement fmt.Scanner // Scan is a support routine for fmt.Scanner. It accepts the formats // 'e', 'E', 'f', 'F', 'g', 'G', and 'v'. All formats are equivalent. func (z *Rat) Scan(s fmt.ScanState, ch rune) error { tok, err := s.Token(true, ratTok) if err != nil { return err } if !strings.ContainsRune("efgEFGv", ch) { return errors.New("Rat.Scan: invalid verb") } if _, ok := z.SetString(string(tok)); !ok { return errors.New("Rat.Scan: invalid syntax") } return nil } // SetString sets z to the value of s and returns z and a boolean indicating // success. s can be given as a (possibly signed) fraction "a/b", or as a // floating-point number optionally followed by an exponent. // If a fraction is provided, both the dividend and the divisor may be a // decimal integer or independently use a prefix of “0b”, “0” or “0o”, // or “0x” (or their upper-case variants) to denote a binary, octal, or // hexadecimal integer, respectively. The divisor may not be signed. // If a floating-point number is provided, it may be in decimal form or // use any of the same prefixes as above but for “0” to denote a non-decimal // mantissa. A leading “0” is considered a decimal leading 0; it does not // indicate octal representation in this case. // An optional base-10 “e” or base-2 “p” (or their upper-case variants) // exponent may be provided as well, except for hexadecimal floats which // only accept an (optional) “p” exponent (because an “e” or “E” cannot // be distinguished from a mantissa digit). If the exponent's absolute value // is too large, the operation may fail. 
// The entire string, not just a prefix, must be valid for success. If the
// operation failed, the value of z is undefined but the returned value is nil.
//
// Outline: input containing "/" is parsed as a numerator and a denominator.
// Otherwise a sign, a mantissa (possibly with a radix point) and an optional
// exponent are scanned. The scaling they imply is split into a power of 5,
// applied by multiplication or as the denominator, and a power of 2, applied
// by shifting. Finally the result is normalized.
func (z *Rat) SetString(s string) (*Rat, bool) {
	if len(s) == 0 {
		return nil, false
	}
	// len(s) > 0

	// parse fraction a/b, if any
	if sep := strings.Index(s, "/"); sep >= 0 {
		// Numerator: parsed by Int.SetString with base 0, so a sign and a
		// base prefix are handled there.
		if _, ok := z.a.SetString(s[:sep], 0); !ok {
			return nil, false
		}
		r := strings.NewReader(s[sep+1:])
		var err error
		// Denominator: scanned as an unsigned magnitude straight into z.b.abs
		// (the final argument is false here, unlike the mantissa scan below,
		// presumably disallowing a radix point — see nat.scan).
		if z.b.abs, _, _, err = z.b.abs.scan(r, 0, false); err != nil {
			return nil, false
		}
		// entire string must have been consumed
		if _, err = r.ReadByte(); err != io.EOF {
			return nil, false
		}
		// A zero denominator is reported as failure, not a panic.
		if len(z.b.abs) == 0 {
			return nil, false
		}
		return z.norm(), true
	}

	// parse floating-point number
	r := strings.NewReader(s)

	// sign
	neg, err := scanSign(r)
	if err != nil {
		return nil, false
	}

	// mantissa
	var base int
	var fcount int // fractional digit count; valid if <= 0
	z.a.abs, base, fcount, err = z.a.abs.scan(r, 0, true)
	if err != nil {
		return nil, false
	}

	// exponent
	var exp int64
	var ebase int
	exp, ebase, err = scanExponent(r, true, true)
	if err != nil {
		return nil, false
	}

	// there should be no unread characters left
	if _, err = r.ReadByte(); err != io.EOF {
		return nil, false
	}

	// special-case 0 (see also issue #16176)
	if len(z.a.abs) == 0 {
		return z.norm(), true
	}
	// len(z.a.abs) > 0

	// The mantissa may have a radix point (fcount <= 0) and there
	// may be a nonzero exponent exp. The radix point amounts to a
	// division by base**(-fcount), which equals a multiplication by
	// base**fcount. An exponent means multiplication by ebase**exp.
	// Multiplications are commutative, so we can apply them in any
	// order. We only have powers of 2 and 10, and we split powers
	// of 10 into the product of the same powers of 2 and 5. This
	// may reduce the size of shift/multiplication factors or
	// divisors required to create the final fraction, depending
	// on the actual floating-point value.

	// determine binary or decimal exponent contribution of radix point
	var exp2, exp5 int64
	if fcount < 0 {
		// The mantissa has a radix point ddd.dddd; and
		// -fcount is the number of digits to the right
		// of '.'. Adjust relevant exponent accordingly.
		d := int64(fcount)
		switch base {
		case 10:
			exp5 = d
			fallthrough // 10**e == 5**e * 2**e
		case 2:
			exp2 = d
		case 8:
			exp2 = d * 3 // octal digits are 3 bits each
		case 16:
			exp2 = d * 4 // hexadecimal digits are 4 bits each
		default:
			panic("unexpected mantissa base")
		}
		// fcount consumed - not needed anymore
	}

	// take actual exponent into account
	switch ebase {
	case 10:
		exp5 += exp
		fallthrough // see fallthrough above
	case 2:
		exp2 += exp
	default:
		panic("unexpected exponent base")
	}
	// exp consumed - not needed anymore

	stk := getStack()
	defer stk.free()

	// apply exp5 contributions
	// (start with exp5 so the numbers to multiply are smaller)
	if exp5 != 0 {
		n := exp5
		if n < 0 {
			n = -n
			if n < 0 {
				// This can occur if -n overflows. -(-1 << 63) would become
				// -1 << 63, which is still negative.
				return nil, false
			}
		}
		if n > 1e6 {
			return nil, false // avoid excessively large exponents
		}
		pow5 := z.b.abs.expNN(stk, natFive, nat(nil).setWord(Word(n)), nil, false) // use underlying array of z.b.abs
		if exp5 > 0 {
			// pow5 may share storage with z.b.abs, so it must be consumed
			// by the multiplication before z.b.abs is overwritten with 1.
			z.a.abs = z.a.abs.mul(stk, z.a.abs, pow5)
			z.b.abs = z.b.abs.setWord(1)
		} else {
			// Negative power of 5: pow5 becomes the denominator.
			z.b.abs = pow5
		}
	} else {
		z.b.abs = z.b.abs.setWord(1)
	}

	// apply exp2 contributions
	if exp2 < -1e7 || exp2 > 1e7 {
		return nil, false // avoid excessively large exponents
	}
	if exp2 > 0 {
		z.a.abs = z.a.abs.lsh(z.a.abs, uint(exp2))
	} else if exp2 < 0 {
		z.b.abs = z.b.abs.lsh(z.b.abs, uint(-exp2))
	}

	z.a.neg = neg && len(z.a.abs) > 0 // 0 has no sign

	return z.norm(), true
}

// scanExponent scans the longest possible prefix of r representing a base 10
// (“e”, “E”) or a base 2 (“p”, “P”) exponent, if any. It returns the
// exponent, the exponent base (10 or 2), or a read or syntax error, if any.
// // If sepOk is set, an underscore character “_” may appear between successive // exponent digits; such underscores do not change the value of the exponent. // Incorrect placement of underscores is reported as an error if there are no // other errors. If sepOk is not set, underscores are not recognized and thus // terminate scanning like any other character that is not a valid digit. // // exponent = ( "e" | "E" | "p" | "P" ) [ sign ] digits . // sign = "+" | "-" . // digits = digit { [ '_' ] digit } . // digit = "0" ... "9" . // // A base 2 exponent is only permitted if base2ok is set. func scanExponent(r io.ByteScanner, base2ok, sepOk bool) (exp int64, base int, err error) { // one char look-ahead ch, err := r.ReadByte() if err != nil { if err == io.EOF { err = nil } return 0, 10, err } // exponent char switch ch { case 'e', 'E': base = 10 case 'p', 'P': if base2ok { base = 2 break // ok } fallthrough // binary exponent not permitted default: r.UnreadByte() // ch does not belong to exponent anymore return 0, 10, nil } // sign var digits []byte ch, err = r.ReadByte() if err == nil && (ch == '+' || ch == '-') { if ch == '-' { digits = append(digits, '-') } ch, err = r.ReadByte() } // prev encodes the previously seen char: it is one // of '_', '0' (a digit), or '.' (anything else). A // valid separator '_' may only occur after a digit. prev := '.' 
invalSep := false // exponent value hasDigits := false for err == nil { if '0' <= ch && ch <= '9' { digits = append(digits, ch) prev = '0' hasDigits = true } else if ch == '_' && sepOk { if prev != '0' { invalSep = true } prev = '_' } else { r.UnreadByte() // ch does not belong to number anymore break } ch, err = r.ReadByte() } if err == io.EOF { err = nil } if err == nil && !hasDigits { err = errNoDigits } if err == nil { exp, err = strconv.ParseInt(string(digits), 10, 64) } // other errors take precedence over invalid separators if err == nil && (invalSep || prev == '_') { err = errInvalSep } return } // String returns a string representation of x in the form "a/b" (even if b == 1). func (x *Rat) String() string { return string(x.marshal(nil)) } // marshal implements [Rat.String] returning a slice of bytes. // It appends the string representation of x in the form "a/b" (even if b == 1) to buf, // and returns the extended buffer. func (x *Rat) marshal(buf []byte) []byte { buf = x.a.Append(buf, 10) buf = append(buf, '/') if len(x.b.abs) != 0 { buf = x.b.Append(buf, 10) } else { buf = append(buf, '1') } return buf } // RatString returns a string representation of x in the form "a/b" if b != 1, // and in the form "a" if b == 1. func (x *Rat) RatString() string { if x.IsInt() { return x.a.String() } return x.String() } // FloatString returns a string representation of x in decimal form with prec // digits of precision after the radix point. The last digit is rounded to // nearest, with halves rounded away from zero. 
func (x *Rat) FloatString(prec int) string { var buf []byte if x.IsInt() { buf = x.a.Append(buf, 10) if prec > 0 { buf = append(buf, '.') for i := prec; i > 0; i-- { buf = append(buf, '0') } } return string(buf) } // x.b.abs != 0 stk := getStack() defer stk.free() q, r := nat(nil).div(stk, nat(nil), x.a.abs, x.b.abs) p := natOne if prec > 0 { p = nat(nil).expNN(stk, natTen, nat(nil).setUint64(uint64(prec)), nil, false) } r = r.mul(stk, r, p) r, r2 := r.div(stk, nat(nil), r, x.b.abs) // see if we need to round up r2 = r2.add(r2, r2) if x.b.abs.cmp(r2) <= 0 { r = r.add(r, natOne) if r.cmp(p) >= 0 { q = nat(nil).add(q, natOne) r = nat(nil).sub(r, p) } } if x.a.neg { buf = append(buf, '-') } buf = append(buf, q.utoa(10)...) // itoa ignores sign if q == 0 if prec > 0 { buf = append(buf, '.') rs := r.utoa(10) for i := prec - len(rs); i > 0; i-- { buf = append(buf, '0') } buf = append(buf, rs...) } return string(buf) } // Note: FloatPrec (below) is in this file rather than rat.go because // its results are relevant for decimal representation/printing. // FloatPrec returns the number n of non-repeating digits immediately // following the decimal point of the decimal representation of x. // The boolean result indicates whether a decimal representation of x // with that many fractional digits is exact or rounded. // // Examples: // // x n exact decimal representation n fractional digits // 0 0 true 0 // 1 0 true 1 // 1/2 1 true 0.5 // 1/3 0 false 0 (0.333... rounded) // 1/4 2 true 0.25 // 1/6 1 false 0.2 (0.166... rounded) func (x *Rat) FloatPrec() (n int, exact bool) { stk := getStack() defer stk.free() // Determine q and largest p2, p5 such that d = q·2^p2·5^p5. // The results n, exact are: // // n = max(p2, p5) // exact = q == 1 // // For details see: // https://en.wikipedia.org/wiki/Repeating_decimal#Reciprocals_of_integers_not_coprime_to_10 d := x.Denom().abs // d >= 1 // Determine p2 by counting factors of 2. // p2 corresponds to the trailing zero bits in d. 
// Do this first to reduce q as much as possible. var q nat p2 := d.trailingZeroBits() q = q.rsh(d, p2) // Determine p5 by counting factors of 5. // Build a table starting with an initial power of 5, // and use repeated squaring until the factor doesn't // divide q anymore. Then use the table to determine // the power of 5 in q. const fp = 13 // f == 5^fp var tab []nat // tab[i] == (5^fp)^(2^i) == 5^(fp·2^i) f := nat{1220703125} // == 5^fp (must fit into a uint32 Word) var t, r nat // temporaries for { if _, r = t.div(stk, r, q, f); len(r) != 0 { break // f doesn't divide q evenly } tab = append(tab, f) f = nat(nil).sqr(stk, f) // nat(nil) to ensure a new f for each table entry } // Factor q using the table entries, if any. // We start with the largest factor f = tab[len(tab)-1] // that evenly divides q. It does so at most once because // otherwise f·f would also divide q. That can't be true // because f·f is the next higher table entry, contradicting // how f was chosen in the first place. // The same reasoning applies to the subsequent factors. var p5 uint for i := len(tab) - 1; i >= 0; i-- { if t, r = t.div(stk, r, q, tab[i]); len(r) == 0 { p5 += fp * (1 << i) // tab[i] == 5^(fp·2^i) q = q.set(t) } } // If fp != 1, we may still have multiples of 5 left. for { if t, r = t.div(stk, r, q, natFive); len(r) != 0 { break } p5++ q = q.set(t) } return int(max(p2, p5)), q.cmp(natOne) == 0 }
// Copyright 2015 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // This file implements encoding/decoding of Rats. package big import ( "errors" "fmt" "internal/byteorder" "math" ) // Gob codec version. Permits backward-compatible changes to the encoding. const ratGobVersion byte = 1 // GobEncode implements the [encoding/gob.GobEncoder] interface. func (x *Rat) GobEncode() ([]byte, error) { if x == nil { return nil, nil } buf := make([]byte, 1+4+(len(x.a.abs)+len(x.b.abs))*_S) // extra bytes for version and sign bit (1), and numerator length (4) i := x.b.abs.bytes(buf) j := x.a.abs.bytes(buf[:i]) n := i - j if int(uint32(n)) != n { // this should never happen return nil, errors.New("Rat.GobEncode: numerator too large") } byteorder.BEPutUint32(buf[j-4:j], uint32(n)) j -= 1 + 4 b := ratGobVersion << 1 // make space for sign bit if x.a.neg { b |= 1 } buf[j] = b return buf[j:], nil } // GobDecode implements the [encoding/gob.GobDecoder] interface. func (z *Rat) GobDecode(buf []byte) error { if len(buf) == 0 { // Other side sent a nil or default value. *z = Rat{} return nil } if len(buf) < 5 { return errors.New("Rat.GobDecode: buffer too small") } b := buf[0] if b>>1 != ratGobVersion { return fmt.Errorf("Rat.GobDecode: encoding version %d not supported", b>>1) } const j = 1 + 4 ln := byteorder.BEUint32(buf[j-4 : j]) if uint64(ln) > math.MaxInt-j { return errors.New("Rat.GobDecode: invalid length") } i := j + int(ln) if len(buf) < i { return errors.New("Rat.GobDecode: buffer too small") } z.a.neg = b&1 != 0 z.a.abs = z.a.abs.setBytes(buf[j:i]) z.b.abs = z.b.abs.setBytes(buf[i:]) return nil } // AppendText implements the [encoding.TextAppender] interface. func (x *Rat) AppendText(b []byte) ([]byte, error) { if x.IsInt() { return x.a.AppendText(b) } return x.marshal(b), nil } // MarshalText implements the [encoding.TextMarshaler] interface. 
func (x *Rat) MarshalText() (text []byte, err error) {
	text, err = x.AppendText(nil)
	return text, err
}

// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
func (z *Rat) UnmarshalText(text []byte) error {
	// TODO(gri): get rid of the []byte/string conversion
	if _, ok := z.SetString(string(text)); ok {
		return nil
	}
	return fmt.Errorf("math/big: cannot unmarshal %q into a *big.Rat", text)
}
// Code generated by "stringer -type=RoundingMode"; DO NOT EDIT.

package big

import "strconv"

func _() {
	// An "invalid array index" compiler error signifies that the constant values have changed.
	// Re-run the stringer command to generate them again.
	var x [1]struct{}
	_ = x[ToNearestEven-0]
	_ = x[ToNearestAway-1]
	_ = x[ToZero-2]
	_ = x[AwayFromZero-3]
	_ = x[ToNegativeInf-4]
	_ = x[ToPositiveInf-5]
}

// _RoundingMode_name holds all constant names back to back; the k-th name is
// _RoundingMode_name[_RoundingMode_index[k]:_RoundingMode_index[k+1]].
const _RoundingMode_name = "ToNearestEvenToNearestAwayToZeroAwayFromZeroToNegativeInfToPositiveInf"

var _RoundingMode_index = [...]uint8{0, 13, 26, 32, 44, 57, 70}

// String returns the constant name of i, or "RoundingMode(N)" when i is
// at or beyond the number of declared modes.
func (i RoundingMode) String() string {
	if i >= RoundingMode(len(_RoundingMode_index)-1) {
		return "RoundingMode(" + strconv.FormatInt(int64(i), 10) + ")"
	}
	return _RoundingMode_name[_RoundingMode_index[i]:_RoundingMode_index[i+1]]
}
// Copyright 2017 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package big import ( "math" "sync" ) var threeOnce struct { sync.Once v *Float } func three() *Float { threeOnce.Do(func() { threeOnce.v = NewFloat(3.0) }) return threeOnce.v } // Sqrt sets z to the rounded square root of x, and returns it. // // If z's precision is 0, it is changed to x's precision before the // operation. Rounding is performed according to z's precision and // rounding mode, but z's accuracy is not computed. Specifically, the // result of z.Acc() is undefined. // // The function panics if z < 0. The value of z is undefined in that // case. func (z *Float) Sqrt(x *Float) *Float { if debugFloat { x.validate() } if z.prec == 0 { z.prec = x.prec } if x.Sign() == -1 { // following IEEE754-2008 (section 7.2) panic(ErrNaN{"square root of negative operand"}) } // handle ±0 and +∞ if x.form != finite { z.acc = Exact z.form = x.form z.neg = x.neg // IEEE754-2008 requires √±0 = ±0 return z } // MantExp sets the argument's precision to the receiver's, and // when z.prec > x.prec this will lower z.prec. Restore it after // the MantExp call. prec := z.prec b := x.MantExp(z) z.prec = prec // Compute √(z·2**b) as // √( z)·2**(½b) if b is even // √(2z)·2**(⌊½b⌋) if b > 0 is odd // √(½z)·2**(⌈½b⌉) if b < 0 is odd switch b % 2 { case 0: // nothing to do case 1: z.exp++ case -1: z.exp-- } // 0.25 <= z < 2.0 // Solving 1/x² - z = 0 avoids Quo calls and is faster, especially // for high precisions. z.sqrtInverse(z) // re-attach halved exponent return z.SetMantExp(z, b/2) } // Compute √x (to z.prec precision) by solving // // 1/t² - x = 0 // // for t (using Newton's method), and then inverting. 
func (z *Float) sqrtInverse(x *Float) { // let // f(t) = 1/t² - x // then // g(t) = f(t)/f'(t) = -½t(1 - xt²) // and the next guess is given by // t2 = t - g(t) = ½t(3 - xt²) u := newFloat(z.prec) v := newFloat(z.prec) three := three() ng := func(t *Float) *Float { u.prec = t.prec v.prec = t.prec u.Mul(t, t) // u = t² u.Mul(x, u) // = xt² v.Sub(three, u) // v = 3 - xt² u.Mul(t, v) // u = t(3 - xt²) u.exp-- // = ½t(3 - xt²) return t.Set(u) } xf, _ := x.Float64() sqi := newFloat(z.prec) sqi.SetFloat64(1 / math.Sqrt(xf)) for prec := z.prec + 32; sqi.prec < prec; { sqi.prec *= 2 sqi = ng(sqi) } // sqi = 1/√x // x/√x = √x z.Mul(x, sqi) } // newFloat returns a new *Float with space for twice the given // precision. func newFloat(prec2 uint32) *Float { z := new(Float) // nat.make ensures the slice length is > 0 z.mant = z.mant.make(int(prec2/_W) * 2) return z }