diff options
author | cinap_lenrek <cinap_lenrek@felloff.net> | 2015-11-21 09:39:59 +0100 |
---|---|---|
committer | cinap_lenrek <cinap_lenrek@felloff.net> | 2015-11-21 09:39:59 +0100 |
commit | 38e1e5272fc9c66a00d702246813135452819ffe (patch) | |
tree | b2d56b8f5e66a17daeb63693fc4dbd15c7308275 /sys/src/libmp | |
parent | b677ab0c5909942bf8946e9e9bd148dea7dae718 (diff) |
libmp: initial attempt at constant time code, faster reductions for special primes (for ecc)
introduce MPtimesafe flag to request time invariant computation
disables normalization so significant digits are not leaked.
Diffstat (limited to 'sys/src/libmp')
30 files changed, 566 insertions, 222 deletions
diff --git a/sys/src/libmp/port/betomp.c b/sys/src/libmp/port/betomp.c index 9197f3a14..0830704ef 100644 --- a/sys/src/libmp/port/betomp.c +++ b/sys/src/libmp/port/betomp.c @@ -13,19 +13,12 @@ betomp(uchar *p, uint n, mpint *b) b = mpnew(0); setmalloctag(b, getcallerpc(&p)); } - - // dump leading zeros - while(*p == 0 && n > 1){ - p++; - n--; - } - - // get the space mpbits(b, n*8); - b->top = DIGITS(n*8); - m = b->top-1; - // first digit might not be Dbytes long + m = DIGITS(n*8); + b->top = m--; + b->sign = 1; + s = ((n-1)*8)%Dbits; x = 0; for(; n > 0; n--){ @@ -37,6 +30,5 @@ betomp(uchar *p, uint n, mpint *b) x = 0; } } - - return b; + return mpnorm(b); } diff --git a/sys/src/libmp/port/letomp.c b/sys/src/libmp/port/letomp.c index e23fed21e..d5cca241b 100644 --- a/sys/src/libmp/port/letomp.c +++ b/sys/src/libmp/port/letomp.c @@ -9,8 +9,10 @@ letomp(uchar *s, uint n, mpint *b) int i=0, m = 0; mpdigit x=0; - if(b == nil) + if(b == nil){ b = mpnew(0); + setmalloctag(b, getcallerpc(&s)); + } mpbits(b, 8*n); for(; n > 0; n--){ x |= ((mpdigit)(*s++)) << i; @@ -24,5 +26,6 @@ letomp(uchar *s, uint n, mpint *b) if(i > 0) b->p[m++] = x; b->top = m; - return b; + b->sign = 1; + return mpnorm(b); } diff --git a/sys/src/libmp/port/mkfile b/sys/src/libmp/port/mkfile index 76fa25dd7..b0bdf67d5 100644 --- a/sys/src/libmp/port/mkfile +++ b/sys/src/libmp/port/mkfile @@ -6,12 +6,15 @@ FILES=\ mpfmt\ strtomp\ mptobe\ + mptober\ mptole\ + mptolel\ betomp\ letomp\ mpadd\ mpsub\ mpcmp\ + mpsel\ mpfactorial\ mpmul\ mpleft\ @@ -20,10 +23,12 @@ FILES=\ mpvecsub\ mpvecdigmuladd\ mpveccmp\ + mpvectscmp\ mpdigdiv\ mpdiv\ mpexp\ mpmod\ + mpmodop\ mpextendedgcd\ mpinvert\ mprand\ diff --git a/sys/src/libmp/port/mpadd.c b/sys/src/libmp/port/mpadd.c index 6022a64ef..9a1ccde66 100644 --- a/sys/src/libmp/port/mpadd.c +++ b/sys/src/libmp/port/mpadd.c @@ -9,6 +9,8 @@ mpmagadd(mpint *b1, mpint *b2, mpint *sum) int m, n; mpint *t; + sum->flags |= (b1->flags | b2->flags) & MPtimesafe; + // get the sizes right if(b2->top > b1->top){ t = b1; @@ -41,6 +43,7 @@ mpadd(mpint *b1, mpint *b2, mpint *sum) int sign; if(b1->sign != b2->sign){ + assert(((b1->flags | b2->flags | sum->flags) & MPtimesafe) == 0); if(b1->sign < 0) mpmagsub(b2, b1, sum); else diff --git a/sys/src/libmp/port/mpaux.c b/sys/src/libmp/port/mpaux.c index 66f1524f0..eb70a9364 100644 --- a/sys/src/libmp/port/mpaux.c +++ b/sys/src/libmp/port/mpaux.c @@ -5,33 +5,27 @@ static mpdigit _mptwodata[1] = { 2 }; static mpint _mptwo = { - 1, - 1, - 1, + 1, 1, 1, _mptwodata, - MPstatic + MPstatic|MPnorm }; mpint *mptwo = &_mptwo; static mpdigit _mponedata[1] = { 1 }; static mpint _mpone = { - 1, - 1, - 1, + 1, 1, 1, _mponedata, - MPstatic + MPstatic|MPnorm }; mpint *mpone = &_mpone; static mpdigit _mpzerodata[1] = { 0 }; static mpint _mpzero = { - 1, - 1, - 0, + 1, 1, 0, _mpzerodata, - MPstatic + MPstatic|MPnorm }; mpint *mpzero = &_mpzero; @@ -57,18 +51,17 @@ mpnew(int n) if(n < 0) sysfatal("mpsetminbits: n < 0"); - b = mallocz(sizeof(mpint), 1); - setmalloctag(b, getcallerpc(&n)); - if(b == nil) - sysfatal("mpnew: %r"); n = DIGITS(n); if(n < mpmindigits) n = mpmindigits; - b->p = (mpdigit*)mallocz(n*Dbytes, 1); - if(b->p == nil) + b = mallocz(sizeof(mpint) + n*Dbytes, 1); + if(b == nil) sysfatal("mpnew: %r"); + setmalloctag(b, getcallerpc(&n)); + b->p = (mpdigit*)&b[1]; b->size = n; b->sign = 1; + b->flags = MPnorm; return b; } @@ -83,16 +76,23 @@ mpbits(mpint *b, int m) if(b->size >= n){ if(b->top >= n) return; - memset(&b->p[b->top], 0, Dbytes*(n - b->top)); - b->top = n; - return; + } else { + if(b->p == (mpdigit*)&b[1]){ + b->p = (mpdigit*)mallocz(n*Dbytes, 0); + if(b->p == nil) + sysfatal("mpbits: %r"); + memmove(b->p, &b[1], Dbytes*b->top); + memset(&b[1], 0, Dbytes*b->size); + } else { + b->p = (mpdigit*)realloc(b->p, n*Dbytes); + if(b->p == nil) + sysfatal("mpbits: %r"); + } + b->size = n; } - b->p = (mpdigit*)realloc(b->p, n*Dbytes); - if(b->p == nil) - sysfatal("mpbits: %r"); memset(&b->p[b->top], 0, Dbytes*(n - b->top)); - b->size = n; b->top = n; + b->flags &= ~MPnorm; } void @@ -102,22 +102,30 @@ mpfree(mpint *b) return; if(b->flags & MPstatic) sysfatal("freeing mp constant"); - memset(b->p, 0, b->size*Dbytes); // information hiding - free(b->p); + memset(b->p, 0, b->size*Dbytes); + if(b->p != (mpdigit*)&b[1]) + free(b->p); free(b); } -void +mpint* mpnorm(mpint *b) { int i; + if(b->flags & MPtimesafe){ + assert(b->sign == 1); + b->flags &= ~MPnorm; + return b; + } for(i = b->top-1; i >= 0; i--) if(b->p[i] != 0) break; b->top = i+1; if(b->top == 0) b->sign = 1; + b->flags |= MPnorm; + return b; } mpint* @@ -126,8 +134,10 @@ mpcopy(mpint *old) mpint *new; new = mpnew(Dbits*old->size); - new->top = old->top; + setmalloctag(new, getcallerpc(&old)); new->sign = old->sign; + new->top = old->top; + new->flags = old->flags & ~MPstatic; memmove(new->p, old->p, Dbytes*old->top); return new; } @@ -135,9 +145,14 @@ mpcopy(mpint *old) void mpassign(mpint *old, mpint *new) { + if(new == nil || old == new) + return; + new->top = 0; mpbits(new, Dbits*old->top); new->sign = old->sign; new->top = old->top; + new->flags &= ~MPnorm; + new->flags |= old->flags & ~MPstatic; memmove(new->p, old->p, Dbytes*old->top); } @@ -167,6 +182,7 @@ mplowbits0(mpint *n) int k, bit, digit; mpdigit d; + assert(n->flags & MPnorm); if(n->top==0) return 0; k = 0; @@ -187,4 +203,3 @@ mplowbits0(mpint *n) } return k; } - diff --git a/sys/src/libmp/port/mpcmp.c b/sys/src/libmp/port/mpcmp.c index a2e3cf724..7ab5a16b6 100644 --- a/sys/src/libmp/port/mpcmp.c +++ b/sys/src/libmp/port/mpcmp.c @@ -8,10 +8,14 @@ mpmagcmp(mpint *b1, mpint *b2) { int i; - i = b1->top - b2->top; - if(i) - return i; - + i = b1->flags | b2->flags; + if(i & MPtimesafe) + return mpvectscmp(b1->p, b1->top, b2->p, b2->top); + if(i & MPnorm){ + i = b1->top - b2->top; + if(i) + return i; + } return mpveccmp(b1->p, b1->top, b2->p, b2->top); } @@ -19,10 +23,8 @@ mpmagcmp(mpint *b1, mpint *b2) int mpcmp(mpint *b1, mpint *b2) { - if(b1->sign != b2->sign) - return b1->sign - b2->sign; - if(b1->sign < 0) - return mpmagcmp(b2, b1); - else - return mpmagcmp(b1, b2); + int sign; + + sign = (b1->sign - b2->sign) >> 1; // -1, 0, 1 + return sign | (sign&1)-1 & mpmagcmp(b1, b2)*b1->sign; } diff --git a/sys/src/libmp/port/mpdiv.c b/sys/src/libmp/port/mpdiv.c index 92aee03f4..54b943862 100644 --- a/sys/src/libmp/port/mpdiv.c +++ b/sys/src/libmp/port/mpdiv.c @@ -13,10 +13,29 @@ mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder) mpdigit qd, *up, *vp, *qp; mpint *u, *v, *t; + assert(quotient != remainder); + assert(divisor->flags & MPnorm); + // divide bv zero if(divisor->top == 0) abort(); + // division by one or small powers of two + if(divisor->top == 1 && (divisor->p[0] & divisor->p[0]-1) == 0){ + vlong r = (vlong)dividend->sign * (dividend->p[0] & divisor->p[0]-1); + if(quotient != nil){ + for(s = 0; ((divisor->p[0] >> s) & 1) == 0; s++) + ; + mpright(dividend, s, quotient); + } + if(remainder != nil){ + remainder->flags |= dividend->flags & MPtimesafe; + vtomp(r, remainder); + } + return; + } + assert((dividend->flags & MPtimesafe) == 0); + // quick check if(mpmagcmp(dividend, divisor) < 0){ if(remainder != nil) @@ -95,12 +114,14 @@ mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder) *up-- = 0; } if(qp != nil){ + assert((quotient->flags & MPtimesafe) == 0); mpnorm(quotient); if(dividend->sign != divisor->sign) quotient->sign = -1; } if(remainder != nil){ + assert((remainder->flags & MPtimesafe) == 0); mpright(u, s, remainder); // u is the remainder shifted remainder->sign = dividend->sign; } diff --git a/sys/src/libmp/port/mpeuclid.c b/sys/src/libmp/port/mpeuclid.c index 80b5983bf..586b9cc22 100644 --- a/sys/src/libmp/port/mpeuclid.c +++ b/sys/src/libmp/port/mpeuclid.c @@ -13,6 +13,9 @@ mpeuclid(mpint *a, mpint *b, mpint *d, mpint *x, mpint *y) { mpint *tmp, *x0, *x1, *x2, *y0, *y1, *y2, *q, *r; + assert((a->flags&b->flags) & MPnorm); + assert(((a->flags|b->flags|d->flags|x->flags|y->flags) & MPtimesafe) == 0); + if(a->sign<0 || b->sign<0) sysfatal("mpeuclid: negative arg"); diff --git a/sys/src/libmp/port/mpexp.c b/sys/src/libmp/port/mpexp.c index 9ec067cb9..1ebabba93 100644 --- a/sys/src/libmp/port/mpexp.c +++ b/sys/src/libmp/port/mpexp.c @@ -22,6 +22,10 @@ mpexp(mpint *b, mpint *e, mpint *m, mpint *res) mpdigit d, bit; int i, j; + assert(m->flags & MPnorm); + assert((e->flags & MPtimesafe) == 0); + res->flags |= b->flags & MPtimesafe; + i = mpcmp(e,mpzero); if(i==0){ mpassign(mpone, res); diff --git a/sys/src/libmp/port/mpextendedgcd.c b/sys/src/libmp/port/mpextendedgcd.c index 413a05c2a..72e49bce1 100644 --- a/sys/src/libmp/port/mpextendedgcd.c +++ b/sys/src/libmp/port/mpextendedgcd.c @@ -5,7 +5,7 @@ // extended binary gcd // -// For a anv b it solves, v = gcd(a,b) and finds x and y s.t. +// For a and b it solves, v = gcd(a,b) and finds x and y s.t. // ax + by = v // // Handbook of Applied Cryptography, Menezes et al, 1997, pg 608. @@ -15,6 +15,9 @@ mpextendedgcd(mpint *a, mpint *b, mpint *v, mpint *x, mpint *y) mpint *u, *A, *B, *C, *D; int g; + assert((a->flags&b->flags) & MPnorm); + assert(((a->flags|b->flags|v->flags|x->flags|y->flags) & MPtimesafe) == 0); + if(a->sign < 0 || b->sign < 0){ mpassign(mpzero, v); mpassign(mpzero, y); diff --git a/sys/src/libmp/port/mpfmt.c b/sys/src/libmp/port/mpfmt.c index f7c42a7bc..676b64be0 100644 --- a/sys/src/libmp/port/mpfmt.c +++ b/sys/src/libmp/port/mpfmt.c @@ -102,6 +102,7 @@ to10(mpint *b, char *buf, int len) return -1; d = mpcopy(b); + mpnorm(d); r = mpnew(0); billion = uitomp(1000000000, nil); out = buf+len; @@ -128,15 +129,20 @@ int mpfmt(Fmt *fmt) { mpint *b; - char *p; + char *p, f; b = va_arg(fmt->args, mpint*); if(b == nil) return fmtstrcpy(fmt, "*"); - + + f = b->flags; + b->flags &= ~MPtimesafe; + p = mptoa(b, fmt->prec, nil, 0); fmt->flags &= ~FmtPrec; + b->flags = f; + if(p == nil) return fmtstrcpy(fmt, "*"); else{ diff --git a/sys/src/libmp/port/mpleft.c b/sys/src/libmp/port/mpleft.c index cdcdff740..38929b82e 100644 --- a/sys/src/libmp/port/mpleft.c +++ b/sys/src/libmp/port/mpleft.c @@ -15,8 +15,8 @@ mpleft(mpint *b, int shift, mpint *res) return; } - // a negative left shift is a right shift - if(shift < 0){ + // a zero or negative left shift is a right shift + if(shift <= 0){ mpright(b, -shift, res); return; } @@ -46,7 +46,6 @@ mpleft(mpint *b, int shift, mpint *res) for(i = 0; i < d; i++) res->p[i] = 0; - // normalize - while(res->top > 0 && res->p[res->top-1] == 0) - res->top--; + res->flags |= b->flags & MPtimesafe; + mpnorm(res); } diff --git a/sys/src/libmp/port/mpmod.c b/sys/src/libmp/port/mpmod.c index 91bebfa27..c053f5b7f 100644 --- a/sys/src/libmp/port/mpmod.c +++ b/sys/src/libmp/port/mpmod.c @@ -2,14 +2,100 @@ #include <mp.h> #include "dat.h" -// remainder = b mod m -// -// knuth, vol 2, pp 398-400 - void -mpmod(mpint *b, mpint *m, mpint *remainder) +mpmod(mpint *x, mpint *n, mpint *r) { - mpdiv(b, m, nil, remainder); - if(remainder->sign < 0) - mpadd(m, remainder, remainder); + static int busy; + static mpint *p, *m, *c, *v; + mpdigit q[32], t[64], d; + int sign, k, s, qn, tn; + + sign = x->sign; + + assert(n->flags & MPnorm); + if(n->top < 2 || n->top > nelem(q) || (x->top-n->top) > nelem(q)) + goto hard; + + /* + * check if n = 2**k - c where c has few power of two factors + * above the lowest digit. + */ + for(k = n->top-1; k > 0; k--){ + d = n->p[k] >> 1; + if((d+1 & d) != 0) + goto hard; + } + + d = n->p[n->top-1]; + for(s = 0; (d & (mpdigit)1<<Dbits-1) == 0; s++) + d <<= 1; + + /* lo(x) = x[0:k-1], hi(x) = x[k:xn-1] */ + k = n->top; + + while(_tas(&busy)) + ; + + if(p == nil || mpmagcmp(n, p) != 0){ + if(m == nil){ + m = mpnew(0); + c = mpnew(0); + p = mpnew(0); + } + mpassign(n, p); + + mpleft(n, s, m); + mpleft(mpone, k*Dbits, c); + mpsub(c, m, c); + } + + mpleft(x, s, r); + if(r->top <= k){ + mpbits(r, (k+1)*Dbits); + r->top = k+1; + } + + /* q = hi(r) */ + qn = r->top - k; + memmove(q, r->p+k, qn*Dbytes); + + /* r = lo(r) */ + r->top = k; + + do { + /* t = q*c */ + tn = qn + c->top; + memset(t, 0, tn*Dbytes); + mpvecmul(q, qn, c->p, c->top, t); + + /* q = hi(t) */ + qn = tn - k; + if(qn <= 0) qn = 0; + else memmove(q, t+k, qn*Dbytes); + + /* r += lo(t) */ + if(tn > k) + tn = k; + mpvecadd(r->p, k, t, tn, r->p); + + /* if(r >= m) r -= m */ + mpvecsub(r->p, k+1, m->p, k, t), d = t[k]; + for(tn = 0; tn < k; tn++) + r->p[tn] = (r->p[tn] & d) | (t[tn] & ~d); + } while(qn > 0); + + busy = 0; + + if(s != 0) + mpright(r, s, r); + else + mpnorm(r); + goto done; + +hard: + mpdiv(x, n, nil, r); + +done: + if(sign < 0) + mpmagsub(n, r, r); } diff --git a/sys/src/libmp/port/mpmodop.c b/sys/src/libmp/port/mpmodop.c new file mode 100644 index 000000000..8bc7cbb5a --- /dev/null +++ b/sys/src/libmp/port/mpmodop.c @@ -0,0 +1,96 @@ +#include <u.h> +#include <libc.h> +#include <mp.h> + +/* operands need to have m->top+1 digits of space and satisfy 0 ≤ a ≤ m-1 */ +static mpint* +modarg(mpint *a, mpint *m) +{ + if(a->size <= m->top || a->sign < 0 || mpmagcmp(a, m) >= 0){ + a = mpcopy(a); + mpmod(a, m, a); + mpbits(a, Dbits*(m->top+1)); + a->top = m->top; + } else if(a->top < m->top){ + memset(&a->p[a->top], 0, (m->top - a->top)*Dbytes); + } + return a; +} + +void +mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum) +{ + mpint *a, *b; + mpdigit d; + int i, j; + + a = modarg(b1, m); + b = modarg(b2, m); + + sum->flags |= (a->flags | b->flags) & MPtimesafe; + mpbits(sum, Dbits*2*(m->top+1)); + + mpvecadd(a->p, m->top, b->p, m->top, sum->p); + mpvecsub(sum->p, m->top+1, m->p, m->top, sum->p+m->top+1); + + d = sum->p[2*m->top+1]; + for(i = 0, j = m->top+1; i < m->top; i++, j++) + sum->p[i] = (sum->p[i] & d) | (sum->p[j] & ~d); + + sum->top = m->top; + sum->sign = 1; + mpnorm(sum); + + if(a != b1) + mpfree(a); + if(b != b2) + mpfree(b); +} + +void +mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff) +{ + mpint *a, *b; + mpdigit d; + int i, j; + + a = modarg(b1, m); + b = modarg(b2, m); + + diff->flags |= (a->flags | b->flags) & MPtimesafe; + mpbits(diff, Dbits*2*(m->top+1)); + + a->p[m->top] = 0; + mpvecsub(a->p, m->top+1, b->p, m->top, diff->p); + mpvecadd(diff->p, m->top, m->p, m->top, diff->p+m->top+1); + + d = ~diff->p[m->top]; + for(i = 0, j = m->top+1; i < m->top; i++, j++) + diff->p[i] = (diff->p[i] & d) | (diff->p[j] & ~d); + + diff->top = m->top; + diff->sign = 1; + mpnorm(diff); + + if(a != b1) + mpfree(a); + if(b != b2) + mpfree(b); +} + +void +mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod) +{ + mpint *a, *b; + + a = modarg(b1, m); + b = modarg(b2, m); + + mpmul(a, b, prod); + mpmod(prod, m, prod); + + if(a != b1) + mpfree(a); + if(b != b2) + mpfree(b); +} diff --git a/sys/src/libmp/port/mpmul.c b/sys/src/libmp/port/mpmul.c index dedd474a7..777adf307 100644 --- a/sys/src/libmp/port/mpmul.c +++ b/sys/src/libmp/port/mpmul.c @@ -113,10 +113,6 @@ mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) a = b; b = t; } - if(blen == 0){ - memset(p, 0, Dbytes*(alen+blen)); - return; - } if(alen >= KARATSUBAMIN && blen > 1){ // O(n^1.585) @@ -132,24 +128,48 @@ mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) } void +mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) +{ + int i; + mpdigit *t; + + if(alen < blen){ + i = alen; + alen = blen; + blen = i; + t = a; + a = b; + b = t; + } + if(blen == 0) + return; + for(i = 0; i < blen; i++) + mpvecdigmuladd(a, alen, b[i], &p[i]); +} + +void mpmul(mpint *b1, mpint *b2, mpint *prod) { mpint *oprod; - oprod = nil; + oprod = prod; if(prod == b1 || prod == b2){ - oprod = prod; prod = mpnew(0); + prod->flags = oprod->flags; } + prod->flags |= (b1->flags | b2->flags) & MPtimesafe; prod->top = 0; mpbits(prod, (b1->top+b2->top+1)*Dbits); - mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p); + if(prod->flags & MPtimesafe) + mpvectsmul(b1->p, b1->top, b2->p, b2->top, prod->p); + else + mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p); prod->top = b1->top+b2->top+1; prod->sign = b1->sign*b2->sign; mpnorm(prod); - if(oprod != nil){ + if(oprod != prod){ mpassign(prod, oprod); mpfree(prod); } diff --git a/sys/src/libmp/port/mpnrand.c b/sys/src/libmp/port/mpnrand.c index 600283d9d..ebbed5097 100644 --- a/sys/src/libmp/port/mpnrand.c +++ b/sys/src/libmp/port/mpnrand.c @@ -16,8 +16,10 @@ mpnrand(mpint *n, void (*gen)(uchar*, int), mpint *b) mpleft(mpone, bits, m); mpsub(m, mpone, m); - if(b == nil) + if(b == nil){ b = mpnew(bits); + setmalloctag(b, getcallerpc(&n)); + } /* m = m - (m % n) */ mpmod(m, n, b); diff --git a/sys/src/libmp/port/mprand.c b/sys/src/libmp/port/mprand.c index fd288f24e..29433b669 100644 --- a/sys/src/libmp/port/mprand.c +++ b/sys/src/libmp/port/mprand.c @@ -6,37 +6,32 @@ mpint* mprand(int bits, void (*gen)(uchar*, int), mpint *b) { - int n, m; mpdigit mask; + int n, m; uchar *p; n = DIGITS(bits); - if(b == nil) + if(b == nil){ b = mpnew(bits); - else + setmalloctag(b, getcallerpc(&bits)); + }else mpbits(b, bits); p = malloc(n*Dbytes); if(p == nil) - return nil; + sysfatal("mprand: %r"); (*gen)(p, n*Dbytes); betomp(p, n*Dbytes, b); free(p); // make sure we don't give too many bits m = bits%Dbits; - n--; - if(m > 0){ - mask = 1; - mask <<= m; - mask--; - b->p[n] &= mask; - } + if(m == 0) + return b; - for(; n >= 0; n--) - if(b->p[n] != 0) - break; - b->top = n+1; - b->sign = 1; - return b; + mask = 1; + mask <<= m; + mask--; + b->p[n-1] &= mask; + return mpnorm(b); } diff --git a/sys/src/libmp/port/mpright.c b/sys/src/libmp/port/mpright.c index 03039177b..dde7aeace 100644 --- a/sys/src/libmp/port/mpright.c +++ b/sys/src/libmp/port/mpright.c @@ -23,12 +23,16 @@ mpright(mpint *b, int shift, mpint *res) if(res != b) mpbits(res, b->top*Dbits - shift); + else if(shift == 0) + return; + d = shift/Dbits; r = shift - d*Dbits; l = Dbits - r; // shift all the bits out == zero if(d>=b->top){ + res->sign = 1; res->top = 0; return; } @@ -46,9 +50,8 @@ mpright(mpint *b, int shift, mpint *res) } res->p[i++] = last>>r; } - while(i > 0 && res->p[i-1] == 0) - i--; + res->top = i; - if(i==0) - res->sign = 1; + res->flags |= b->flags & MPtimesafe; + mpnorm(res); } diff --git a/sys/src/libmp/port/mpsel.c b/sys/src/libmp/port/mpsel.c new file mode 100644 index 000000000..a145b9d06 --- /dev/null +++ b/sys/src/libmp/port/mpsel.c @@ -0,0 +1,42 @@ +#include "os.h" +#include <mp.h> +#include "dat.h" + +// res = s != 0 ? b1 : b2 +void +mpsel(int s, mpint *b1, mpint *b2, mpint *res) +{ + mpdigit d; + int n, m, i; + + res->flags |= (b1->flags | b2->flags) & MPtimesafe; + if((res->flags & MPtimesafe) == 0){ + mpassign(s ? b1 : b2, res); + return; + } + res->flags &= ~MPnorm; + + n = b1->top; + m = b2->top; + mpbits(res, Dbits*(n >= m ? n : m)); + res->top = n >= m ? n : m; + + s = (-s^s|s)>>(sizeof(s)*8-1); + res->sign = (b1->sign & s) | (b2->sign & ~s); + + d = -((mpdigit)s & 1); + + i = 0; + while(i < n && i < m){ + res->p[i] = (b1->p[i] & d) | (b2->p[i] & ~d); + i++; + } + while(i < n){ + res->p[i] = b1->p[i] & d; + i++; + } + while(i < m){ + res->p[i] = b2->p[i] & ~d; + i++; + } +} diff --git a/sys/src/libmp/port/mpsub.c b/sys/src/libmp/port/mpsub.c index 3fe6ca095..292648f23 100644 --- a/sys/src/libmp/port/mpsub.c +++ b/sys/src/libmp/port/mpsub.c @@ -11,12 +11,15 @@ mpmagsub(mpint *b1, mpint *b2, mpint *diff) // get the sizes right if(mpmagcmp(b1, b2) < 0){ + assert(((b1->flags | b2->flags | diff->flags) & MPtimesafe) == 0); sign = -1; t = b1; b1 = b2; b2 = t; - } else + } else { + diff->flags |= (b1->flags | b2->flags) & MPtimesafe; sign = 1; + } n = b1->top; m = b2->top; if(m == 0){ @@ -39,6 +42,7 @@ mpsub(mpint *b1, mpint *b2, mpint *diff) int sign; if(b1->sign != b2->sign){ + assert(((b1->flags | b2->flags | diff->flags) & MPtimesafe) == 0); sign = b1->sign; mpmagadd(b1, b2, diff); diff->sign = sign; diff --git a/sys/src/libmp/port/mptobe.c b/sys/src/libmp/port/mptobe.c index ed527cc76..9ddea35ed 100644 --- a/sys/src/libmp/port/mptobe.c +++ b/sys/src/libmp/port/mptobe.c @@ -2,57 +2,31 @@ #include <mp.h> #include "dat.h" -// convert an mpint into a big endian byte array (most significant byte first) +// convert an mpint into a big endian byte array (most significant byte first; left adjusted) // return number of bytes converted // if p == nil, allocate and result array int mptobe(mpint *b, uchar *p, uint n, uchar **pp) { - int i, j, suppress; - mpdigit x; - uchar *e, *s, c; + int m; + m = (mpsignif(b)+7)/8; + if(m == 0) + m++; if(p == nil){ - n = (b->top+1)*Dbytes; + n = m; p = malloc(n); + if(p == nil) + sysfatal("mptobe: %r"); setmalloctag(p, getcallerpc(&b)); + } else { + if(n < m) + return -1; + if(n > m) + memset(p+m, 0, n-m); } - if(p == nil) - return -1; if(pp != nil) *pp = p; - memset(p, 0, n); - - // special case 0 - if(b->top == 0){ - if(n < 1) - return -1; - else - return 1; - } - - s = p; - e = s+n; - suppress = 1; - for(i = b->top-1; i >= 0; i--){ - x = b->p[i]; - for(j = Dbits-8; j >= 0; j -= 8){ - c = x>>j; - if(c == 0 && suppress) - continue; - if(p >= e) - return -1; - *p++ = c; - suppress = 0; - } - } - - // guarantee at least one byte - if(s == p){ - if(p >= e) - return -1; - *p++ = 0; - } - - return p - s; + mptober(b, p, m); + return m; } diff --git a/sys/src/libmp/port/mptober.c b/sys/src/libmp/port/mptober.c new file mode 100644 index 000000000..ce63d338d --- /dev/null +++ b/sys/src/libmp/port/mptober.c @@ -0,0 +1,34 @@ +#include "os.h" +#include <mp.h> +#include "dat.h" + +void +mptober(mpint *b, uchar *p, int n) +{ + int i, j, m; + mpdigit x; + + memset(p, 0, n); + + p += n; + m = b->top*Dbytes; + if(m < n) + n = m; + + i = 0; + while(n >= Dbytes){ + n -= Dbytes; + x = b->p[i++]; + for(j = 0; j < Dbytes; j++){ + *--p = x; + x >>= 8; + } + } + if(n > 0){ + x = b->p[i]; + for(j = 0; j < n; j++){ + *--p = x; + x >>= 8; + } + } +} diff --git a/sys/src/libmp/port/mptoi.c b/sys/src/libmp/port/mptoi.c index b3f22b424..6183fa7e5 100644 --- a/sys/src/libmp/port/mptoi.c +++ b/sys/src/libmp/port/mptoi.c @@ -10,17 +10,15 @@ mpint* itomp(int i, mpint *b) { - if(b == nil) + if(b == nil){ b = mpnew(0); - mpassign(mpzero, b); - if(i != 0) - b->top = 1; - if(i < 0){ - b->sign = -1; - *b->p = -i; - } else - *b->p = i; - return b; + setmalloctag(b, getcallerpc(&i)); + } + b->sign = (i >> (sizeof(i)*8 - 1)) | 1; + i *= b->sign; + *b->p = i; + b->top = 1; + return mpnorm(b); } int diff --git a/sys/src/libmp/port/mptole.c b/sys/src/libmp/port/mptole.c index 9421d5f66..3dd892401 100644 --- a/sys/src/libmp/port/mptole.c +++ b/sys/src/libmp/port/mptole.c @@ -3,52 +3,26 @@ #include "dat.h" // convert an mpint into a little endian byte array (least significant byte first) - // return number of bytes converted // if p == nil, allocate and result array int mptole(mpint *b, uchar *p, uint n, uchar **pp) { - int i, j; - mpdigit x; - uchar *e, *s; + int m; + m = (mpsignif(b)+7)/8; + if(m == 0) + m++; if(p == nil){ - n = (b->top+1)*Dbytes; + n = m; p = malloc(n); - } + if(p == nil) + sysfatal("mptole: %r"); + setmalloctag(p, getcallerpc(&b)); + } else if(n < m) + return -1; if(pp != nil) *pp = p; - if(p == nil) - return -1; - memset(p, 0, n); - - // special case 0 - if(b->top == 0){ - if(n < 1) - return -1; - else - return 0; - } - - s = p; - e = s+n; - for(i = 0; i < b->top-1; i++){ - x = b->p[i]; - for(j = 0; j < Dbytes; j++){ - if(p >= e) - return -1; - *p++ = x; - x >>= 8; - } - } - x = b->p[i]; - while(x > 0){ - if(p >= e) - return -1; - *p++ = x; - x >>= 8; - } - - return p - s; + mptolel(b, p, n); + return m; } diff --git a/sys/src/libmp/port/mptolel.c b/sys/src/libmp/port/mptolel.c new file mode 100644 index 000000000..4ee41971f --- /dev/null +++ b/sys/src/libmp/port/mptolel.c @@ -0,0 +1,33 @@ +#include "os.h" +#include <mp.h> +#include "dat.h" + +void +mptolel(mpint *b, uchar *p, int n) +{ + int i, j, m; + mpdigit x; + + memset(p, 0, n); + + m = b->top*Dbytes; + if(m < n) + n = m; + + i = 0; + while(n >= Dbytes){ + n -= Dbytes; + x = b->p[i++]; + for(j = 0; j < Dbytes; j++){ + *p++ = x; + x >>= 8; + } + } + if(n > 0){ + x = b->p[i]; + for(j = 0; j < n; j++){ + *p++ = x; + x >>= 8; + } + } +} diff --git a/sys/src/libmp/port/mptoui.c b/sys/src/libmp/port/mptoui.c index 41c0b0b67..2a963de0c 100644 --- a/sys/src/libmp/port/mptoui.c +++ b/sys/src/libmp/port/mptoui.c @@ -10,13 +10,14 @@ mpint* uitomp(uint i, mpint *b) { - if(b == nil) + if(b == nil){ b = mpnew(0); - mpassign(mpzero, b); - if(i != 0) - b->top = 1; + setmalloctag(b, getcallerpc(&i)); + } *b->p = i; - return b; + b->top = 1; + b->sign = 1; + return mpnorm(b); } uint diff --git a/sys/src/libmp/port/mptouv.c b/sys/src/libmp/port/mptouv.c index b2a7632d1..9e52a357f 100644 --- a/sys/src/libmp/port/mptouv.c +++ b/sys/src/libmp/port/mptouv.c @@ -13,19 +13,18 @@ uvtomp(uvlong v, mpint *b) { int s; - if(b == nil) + if(b == nil){ b = mpnew(VLDIGITS*sizeof(mpdigit)); - else + setmalloctag(b, getcallerpc(&v)); + }else mpbits(b, VLDIGITS*sizeof(mpdigit)); - mpassign(mpzero, b); - if(v == 0) - return b; - for(s = 0; s < VLDIGITS && v != 0; s++){ + b->sign = 1; + for(s = 0; s < VLDIGITS; s++){ b->p[s] = v; v >>= sizeof(mpdigit)*8; } b->top = s; - return b; + return mpnorm(b); } uvlong @@ -37,7 +36,6 @@ mptouv(mpint *b) if(b->top == 0) return 0LL; - mpnorm(b); if(b->top > VLDIGITS) return MAXVLONG; diff --git a/sys/src/libmp/port/mptov.c b/sys/src/libmp/port/mptov.c index b09718ef0..b1b3e93f7 100644 --- a/sys/src/libmp/port/mptov.c +++ b/sys/src/libmp/port/mptov.c @@ -14,24 +14,19 @@ vtomp(vlong v, mpint *b) int s; uvlong uv; - if(b == nil) + if(b == nil){ b = mpnew(VLDIGITS*sizeof(mpdigit)); - else + setmalloctag(b, getcallerpc(&v)); + }else mpbits(b, VLDIGITS*sizeof(mpdigit)); - mpassign(mpzero, b); - if(v == 0) - return b; - if(v < 0){ - b->sign = -1; - uv = -v; - } else - uv = v; - for(s = 0; s < VLDIGITS && uv != 0; s++){ + b->sign = (v >> (sizeof(v)*8 - 1)) | 1; + uv = v * b->sign; + for(s = 0; s < VLDIGITS; s++){ b->p[s] = uv; uv >>= sizeof(mpdigit)*8; } b->top = s; - return b; + return mpnorm(b); } vlong @@ -43,7 +38,6 @@ mptov(mpint *b) if(b->top == 0) return 0LL; - mpnorm(b); if(b->top > VLDIGITS){ if(b->sign > 0) return (vlong)MAXVLONG; diff --git a/sys/src/libmp/port/mpvectscmp.c b/sys/src/libmp/port/mpvectscmp.c new file mode 100644 index 000000000..ccad79b16 --- /dev/null +++ b/sys/src/libmp/port/mpvectscmp.c @@ -0,0 +1,34 @@ +#include "os.h" +#include <mp.h> +#include "dat.h" + +int +mpvectscmp(mpdigit *a, int alen, mpdigit *b, int blen) +{ + mpdigit x, y, z, v; + int m, p; + + if(alen > blen){ + v = 0; + while(alen > blen) + v |= a[--alen]; + m = p = (-v^v|v)>>Dbits-1; + } else if(blen > alen){ + v = 0; + while(blen > alen) + v |= b[--blen]; + m = (-v^v|v)>>Dbits-1; + p = m^1; + } else + m = p = 0; + while(alen-- > 0){ + x = a[alen]; + y = b[alen]; + z = x - y; + x = ~x; + v = ((-z^z|z)>>Dbits-1) & ~m; + p = ((~(x&y|x&z|y&z)>>Dbits-1) & v) | (p & ~v); + m |= v; + } + return (p-m) | m; +} diff --git a/sys/src/libmp/port/strtomp.c b/sys/src/libmp/port/strtomp.c index 2ef8c2109..0a9959692 100644 --- a/sys/src/libmp/port/strtomp.c +++ b/sys/src/libmp/port/strtomp.c @@ -50,7 +50,6 @@ from16(char *a, mpint *b) int i; mpdigit x; - b->top = 0; for(p = a; *p; p++) if(tab.t16[*(uchar*)p] == INVAL) break; @@ -157,8 +156,10 @@ strtomp(char *a, char **pp, int base, mpint *b) int sign; char *e; - if(b == nil) + if(b == nil){ b = mpnew(0); + setmalloctag(b, getcallerpc(&a)); + } if(tab.inited == 0) init(); @@ -196,10 +197,9 @@ strtomp(char *a, char **pp, int base, mpint *b) if(e == a) return nil; - mpnorm(b); - b->sign = sign; if(pp != nil) *pp = e; - return b; + b->sign = sign; + return mpnorm(b); } |