diff options
author | cinap_lenrek <cinap_lenrek@felloff.net> | 2015-11-21 09:39:59 +0100 |
---|---|---|
committer | cinap_lenrek <cinap_lenrek@felloff.net> | 2015-11-21 09:39:59 +0100 |
commit | 38e1e5272fc9c66a00d702246813135452819ffe (patch) | |
tree | b2d56b8f5e66a17daeb63693fc4dbd15c7308275 | |
parent | b677ab0c5909942bf8946e9e9bd148dea7dae718 (diff) |
libmp: initial attempt at constant time code, faster reductions for special primes (for ecc)
introduce MPtimesafe flag to request time invariant computation
disables normalization so significant digits are not leaked.
32 files changed, 660 insertions, 229 deletions
diff --git a/sys/include/mp.h b/sys/include/mp.h index 14061adc7..b17df619c 100644 --- a/sys/include/mp.h +++ b/sys/include/mp.h @@ -22,7 +22,10 @@ struct mpint enum { - MPstatic= 0x01, + MPstatic= 0x01, /* static constant */ + MPnorm= 0x02, /* normalization status */ + MPtimesafe= 0x04, /* request time invariant computation */ + Dbytes= sizeof(mpdigit), /* bytes per digit */ Dbits= Dbytes*8 /* bits per digit */ }; @@ -32,7 +35,7 @@ void mpsetminbits(int n); /* newly created mpint's get at least n bits */ mpint* mpnew(int n); /* create a new mpint with at least n bits */ void mpfree(mpint *b); void mpbits(mpint *b, int n); /* ensure that b has at least n bits */ -void mpnorm(mpint *b); /* dump leading zeros */ +mpint* mpnorm(mpint *b); /* dump leading zeros */ mpint* mpcopy(mpint *b); void mpassign(mpint *old, mpint *new); @@ -47,8 +50,10 @@ int mpfmt(Fmt*); char* mptoa(mpint*, int, char*, int); mpint* letomp(uchar*, uint, mpint*); /* byte array, little-endian */ int mptole(mpint*, uchar*, uint, uchar**); +void mptolel(mpint *b, uchar *p, int n); mpint* betomp(uchar*, uint, mpint*); /* byte array, big-endian */ int mptobe(mpint*, uchar*, uint, uchar**); +void mptober(mpint *b, uchar *p, int n); uint mptoui(mpint*); /* unsigned int */ mpint* uitomp(uint, mpint*); int mptoi(mpint*); /* int */ @@ -71,12 +76,20 @@ void mpmul(mpint *b1, mpint *b2, mpint *prod); /* prod = b1*b2 */ void mpexp(mpint *b, mpint *e, mpint *m, mpint *res); /* res = b**e mod m */ void mpmod(mpint *b, mpint *m, mpint *remainder); /* remainder = b mod m */ +/* modular arithmetic, time invariant when 0≤b1≤m-1 and 0≤b2≤m-1 */ +void mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum); /* sum = b1+b2 % m */ +void mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff); /* diff = b1-b2 % m */ +void mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod); /* prod = b1*b2 % m */ + /* quotient = dividend/divisor, remainder = dividend % divisor */ void mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder); /* return neg, 0, pos as b1-b2 is neg, 0, pos */ int mpcmp(mpint *b1, mpint *b2); +/* res = s != 0 ? b1 : b2 */ +void mpsel(int s, mpint *b1, mpint *b2, mpint *res); + /* extended gcd return d, x, and y, s.t. d = gcd(a,b) and ax+by = d */ void mpextendedgcd(mpint *a, mpint *b, mpint *d, mpint *x, mpint *y); @@ -106,12 +119,14 @@ void mpvecdigmuladd(mpdigit *b, int n, mpdigit m, mpdigit *p); /* prereq: p has room for n+1 digits */ int mpvecdigmulsub(mpdigit *b, int n, mpdigit m, mpdigit *p); -/* p[0:alen*blen-1] = a[0:alen-1] * b[0:blen-1] */ +/* p[0:alen+blen-1] = a[0:alen-1] * b[0:blen-1] */ /* prereq: alen >= blen, p has room for m*n digits */ void mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p); +void mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p); /* sign of a - b or zero if the same */ int mpveccmp(mpdigit *a, int alen, mpdigit *b, int blen); +int mpvectscmp(mpdigit *a, int alen, mpdigit *b, int blen); /* divide the 2 digit dividend by the one digit divisor and stick in quotient */ /* we assume that the result is one digit - overflow is all 1's */ diff --git a/sys/man/2/mp b/sys/man/2/mp index 5be4246c8..c562ccab4 100644 --- a/sys/man/2/mp +++ b/sys/man/2/mp @@ -1,6 +1,6 @@ .TH MP 2 .SH NAME -mpsetminbits, mpnew, mpfree, mpbits, mpnorm, mpcopy, mpassign, mprand, mpnrand, strtomp, mpfmt,mptoa, betomp, mptobe, letomp, mptole, mptoui, uitomp, mptoi, itomp, uvtomp, mptouv, vtomp, mptov, mpdigdiv, mpadd, mpsub, mpleft, mpright, mpmul, mpexp, mpmod, mpdiv, mpcmp, mpextendedgcd, mpinvert, mpsignif, mplowbits0, mpvecdigmuladd, mpvecdigmulsub, mpvecadd, mpvecsub, mpveccmp, mpvecmul, mpmagcmp, mpmagadd, mpmagsub, crtpre, crtin, crtout, crtprefree, crtresfree \- extended precision arithmetic +mpsetminbits, mpnew, mpfree, mpbits, mpnorm, mpcopy, mpassign, mprand, mpnrand, strtomp, mpfmt,mptoa, betomp, mptobe, mptober, letomp, mptole, mptolel, mptoui, uitomp, mptoi, itomp, uvtomp, mptouv, vtomp, mptov, mpdigdiv, mpadd, mpsub, mpleft, mpright, mpmul, mpexp, mpmod, mpmodadd, mpmodsub, mpmodmul, mpdiv, mpcmp, mpsel, mpextendedgcd, mpinvert, mpsignif, mplowbits0, mpvecdigmuladd, mpvecdigmulsub, mpvecadd, mpvecsub, mpveccmp, mpvecmul, mpmagcmp, mpmagadd, mpmagsub, crtpre, crtin, crtout, crtprefree, crtresfree \- extended precision arithmetic .SH SYNOPSIS .B #include <u.h> .br @@ -22,7 +22,7 @@ void mpsetminbits(int n) void mpbits(mpint *b, int n) .PP .B -void mpnorm(mpint *b) +mpint* mpnorm(mpint *b) .PP .B mpint* mpcopy(mpint *b) @@ -52,12 +52,18 @@ mpint* betomp(uchar *buf, uint blen, mpint *b) int mptobe(mpint *b, uchar *buf, uint blen, uchar **bufp) .PP .B +void mptober(mpint *b, uchar *buf, int blen) +.PP +.B mpint* letomp(uchar *buf, uint blen, mpint *b) .PP .B int mptole(mpint *b, uchar *buf, uint blen, uchar **bufp) .PP .B +void mptolel(mpint *b, uchar *buf, int blen) +.PP +.B uint mptoui(mpint*) .PP .B @@ -115,12 +121,24 @@ void mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder) .PP .B +void mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum) +.PP +.B +void mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff) +.PP +.B +void mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod) +.PP +.B int mpcmp(mpint *b1, mpint *b2) .PP .B int mpmagcmp(mpint *b1, mpint *b2) .PP .B +void mpsel(int s, mpint *b1, mpint *b2, mpint *res) +.PP +.B void mpextendedgcd(mpint *a, mpint *b, mpint *d, mpint *x, .br .B @@ -383,6 +401,24 @@ deposited in the location pointed to by Sign is ignored in these conversions, i.e., the byte array version is always positive. .PP +.I Mptober +and +.I mptolel +fill +.I blen +lower bytes of an +.I mpint +into a fixed length byte array. +.I Mptober +fills the bytes right adjusted in big endian order so that the least +significant byte is at +.I buf[blen-1] +while +.I mptolel +fills in little endian order; left adjusted; so that the least +significat byte is filled into +.IR buf[0] . +.PP .IR Betomp , and .I letomp @@ -486,6 +522,31 @@ is less than, equal to, or greater than the same as .I mpcmp but ignores the sign and just compares magnitudes. +.TP +.I mpsel +assigns +.I b1 +to +.I res +when +.I s +is not zero, otherwise +.I b2 +is assigned to +.IR res . +.PD +.PP +Modular arithmetic: +.TF mpmodmul_ +.TP +.I mpmodadd +.BR "sum = b1+b2 mod m" . +.TP +.I mpmodsub +.BR "diff = b1-b2 mod m" . +.TP +.I mpmodmul +.BR "prod = b1*b2 mod m" . .PD .PP .I Mpextendedgcd @@ -564,8 +625,8 @@ We assume p has room for n+1 digits. It returns +1 is the result is positive an -1 if negative. .TP .I mpvecmul -.BR "p[0:alen*blen] = a[0:alen-1] * b[0:blen-1]" . -We assume that p has room for alen*blen+1 digits. +.BR "p[0:alen+blen] = a[0:alen-1] * b[0:blen-1]" . +We assume that p has room for alen+blen+1 digits. .TP .I mpveccmp This returns -1, 0, or +1 as a - b is negative, 0, or positive. @@ -576,6 +637,17 @@ This returns -1, 0, or +1 as a - b is negative, 0, or positive. and .I mpzero are the constants 2, 1 and 0. These cannot be freed. +.SS "Time invariant computation" +.PP +In the field of cryptography, it is sometimes neccesary to implement +algorithms such that the runtime of the algorithm is not depdenent on +the input data. This library provides partial support for time +invariant computation with the +.I MPtimesafe +flag that can be set on input or destination operands to request timing +safe operation. The result of a timing safe operation will also have the +.I MPtimesafe +flag set and is not normalized. .SS "Chinese remainder theorem .PP When computing in a non-prime modulus, diff --git a/sys/src/libmp/port/betomp.c b/sys/src/libmp/port/betomp.c index 9197f3a14..0830704ef 100644 --- a/sys/src/libmp/port/betomp.c +++ b/sys/src/libmp/port/betomp.c @@ -13,19 +13,12 @@ betomp(uchar *p, uint n, mpint *b) b = mpnew(0); setmalloctag(b, getcallerpc(&p)); } - - // dump leading zeros - while(*p == 0 && n > 1){ - p++; - n--; - } - - // get the space mpbits(b, n*8); - b->top = DIGITS(n*8); - m = b->top-1; - // first digit might not be Dbytes long + m = DIGITS(n*8); + b->top = m--; + b->sign = 1; + s = ((n-1)*8)%Dbits; x = 0; for(; n > 0; n--){ @@ -37,6 +30,5 @@ betomp(uchar *p, uint n, mpint *b) x = 0; } } - - return b; + return mpnorm(b); } diff --git a/sys/src/libmp/port/letomp.c b/sys/src/libmp/port/letomp.c index e23fed21e..d5cca241b 100644 --- a/sys/src/libmp/port/letomp.c +++ b/sys/src/libmp/port/letomp.c @@ -9,8 +9,10 @@ letomp(uchar *s, uint n, mpint *b) int i=0, m = 0; mpdigit x=0; - if(b == nil) + if(b == nil){ b = mpnew(0); + setmalloctag(b, getcallerpc(&s)); + } mpbits(b, 8*n); for(; n > 0; n--){ x |= ((mpdigit)(*s++)) << i; @@ -24,5 +26,6 @@ letomp(uchar *s, uint n, mpint *b) if(i > 0) b->p[m++] = x; b->top = m; - return b; + b->sign = 1; + return mpnorm(b); } diff --git a/sys/src/libmp/port/mkfile b/sys/src/libmp/port/mkfile index 76fa25dd7..b0bdf67d5 100644 --- a/sys/src/libmp/port/mkfile +++ b/sys/src/libmp/port/mkfile @@ -6,12 +6,15 @@ FILES=\ mpfmt\ strtomp\ mptobe\ + mptober\ mptole\ + mptolel\ betomp\ letomp\ mpadd\ mpsub\ mpcmp\ + mpsel\ mpfactorial\ mpmul\ mpleft\ @@ -20,10 +23,12 @@ FILES=\ mpvecsub\ mpvecdigmuladd\ mpveccmp\ + mpvectscmp\ mpdigdiv\ mpdiv\ mpexp\ mpmod\ + mpmodop\ mpextendedgcd\ mpinvert\ mprand\ diff --git a/sys/src/libmp/port/mpadd.c b/sys/src/libmp/port/mpadd.c index 6022a64ef..9a1ccde66 100644 --- a/sys/src/libmp/port/mpadd.c +++ b/sys/src/libmp/port/mpadd.c @@ -9,6 +9,8 @@ mpmagadd(mpint *b1, mpint *b2, mpint *sum) int m, n; mpint *t; + sum->flags |= (b1->flags | b2->flags) & MPtimesafe; + // get the sizes right if(b2->top > b1->top){ t = b1; @@ -41,6 +43,7 @@ mpadd(mpint *b1, mpint *b2, mpint *sum) int sign; if(b1->sign != b2->sign){ + assert(((b1->flags | b2->flags | sum->flags) & MPtimesafe) == 0); if(b1->sign < 0) mpmagsub(b2, b1, sum); else diff --git a/sys/src/libmp/port/mpaux.c b/sys/src/libmp/port/mpaux.c index 66f1524f0..eb70a9364 100644 --- a/sys/src/libmp/port/mpaux.c +++ b/sys/src/libmp/port/mpaux.c @@ -5,33 +5,27 @@ static mpdigit _mptwodata[1] = { 2 }; static mpint _mptwo = { - 1, - 1, - 1, + 1, 1, 1, _mptwodata, - MPstatic + MPstatic|MPnorm }; mpint *mptwo = &_mptwo; static mpdigit _mponedata[1] = { 1 }; static mpint _mpone = { - 1, - 1, - 1, + 1, 1, 1, _mponedata, - MPstatic + MPstatic|MPnorm }; mpint *mpone = &_mpone; static mpdigit _mpzerodata[1] = { 0 }; static mpint _mpzero = { - 1, - 1, - 0, + 1, 1, 0, _mpzerodata, - MPstatic + MPstatic|MPnorm }; mpint *mpzero = &_mpzero; @@ -57,18 +51,17 @@ mpnew(int n) if(n < 0) sysfatal("mpsetminbits: n < 0"); - b = mallocz(sizeof(mpint), 1); - setmalloctag(b, getcallerpc(&n)); - if(b == nil) - sysfatal("mpnew: %r"); n = DIGITS(n); if(n < mpmindigits) n = mpmindigits; - b->p = (mpdigit*)mallocz(n*Dbytes, 1); - if(b->p == nil) + b = mallocz(sizeof(mpint) + n*Dbytes, 1); + if(b == nil) sysfatal("mpnew: %r"); + setmalloctag(b, getcallerpc(&n)); + b->p = (mpdigit*)&b[1]; b->size = n; b->sign = 1; + b->flags = MPnorm; return b; } @@ -83,16 +76,23 @@ mpbits(mpint *b, int m) if(b->size >= n){ if(b->top >= n) return; - memset(&b->p[b->top], 0, Dbytes*(n - b->top)); - b->top = n; - return; + } else { + if(b->p == (mpdigit*)&b[1]){ + b->p = (mpdigit*)mallocz(n*Dbytes, 0); + if(b->p == nil) + sysfatal("mpbits: %r"); + memmove(b->p, &b[1], Dbytes*b->top); + memset(&b[1], 0, Dbytes*b->size); + } else { + b->p = (mpdigit*)realloc(b->p, n*Dbytes); + if(b->p == nil) + sysfatal("mpbits: %r"); + } + b->size = n; } - b->p = (mpdigit*)realloc(b->p, n*Dbytes); - if(b->p == nil) - sysfatal("mpbits: %r"); memset(&b->p[b->top], 0, Dbytes*(n - b->top)); - b->size = n; b->top = n; + b->flags &= ~MPnorm; } void @@ -102,22 +102,30 @@ mpfree(mpint *b) return; if(b->flags & MPstatic) sysfatal("freeing mp constant"); - memset(b->p, 0, b->size*Dbytes); // information hiding - free(b->p); + memset(b->p, 0, b->size*Dbytes); + if(b->p != (mpdigit*)&b[1]) + free(b->p); free(b); } -void +mpint* mpnorm(mpint *b) { int i; + if(b->flags & MPtimesafe){ + assert(b->sign == 1); + b->flags &= ~MPnorm; + return b; + } for(i = b->top-1; i >= 0; i--) if(b->p[i] != 0) break; b->top = i+1; if(b->top == 0) b->sign = 1; + b->flags |= MPnorm; + return b; } mpint* @@ -126,8 +134,10 @@ mpcopy(mpint *old) mpint *new; new = mpnew(Dbits*old->size); - new->top = old->top; + setmalloctag(new, getcallerpc(&old)); new->sign = old->sign; + new->top = old->top; + new->flags = old->flags & ~MPstatic; memmove(new->p, old->p, Dbytes*old->top); return new; } @@ -135,9 +145,14 @@ mpcopy(mpint *old) void mpassign(mpint *old, mpint *new) { + if(new == nil || old == new) + return; + new->top = 0; mpbits(new, Dbits*old->top); new->sign = old->sign; new->top = old->top; + new->flags &= ~MPnorm; + new->flags |= old->flags & ~MPstatic; memmove(new->p, old->p, Dbytes*old->top); } @@ -167,6 +182,7 @@ mplowbits0(mpint *n) int k, bit, digit; mpdigit d; + assert(n->flags & MPnorm); if(n->top==0) return 0; k = 0; @@ -187,4 +203,3 @@ mplowbits0(mpint *n) } return k; } - diff --git a/sys/src/libmp/port/mpcmp.c b/sys/src/libmp/port/mpcmp.c index a2e3cf724..7ab5a16b6 100644 --- a/sys/src/libmp/port/mpcmp.c +++ b/sys/src/libmp/port/mpcmp.c @@ -8,10 +8,14 @@ mpmagcmp(mpint *b1, mpint *b2) { int i; - i = b1->top - b2->top; - if(i) - return i; - + i = b1->flags | b2->flags; + if(i & MPtimesafe) + return mpvectscmp(b1->p, b1->top, b2->p, b2->top); + if(i & MPnorm){ + i = b1->top - b2->top; + if(i) + return i; + } return mpveccmp(b1->p, b1->top, b2->p, b2->top); } @@ -19,10 +23,8 @@ mpmagcmp(mpint *b1, mpint *b2) int mpcmp(mpint *b1, mpint *b2) { - if(b1->sign != b2->sign) - return b1->sign - b2->sign; - if(b1->sign < 0) - return mpmagcmp(b2, b1); - else - return mpmagcmp(b1, b2); + int sign; + + sign = (b1->sign - b2->sign) >> 1; // -1, 0, 1 + return sign | (sign&1)-1 & mpmagcmp(b1, b2)*b1->sign; } diff --git a/sys/src/libmp/port/mpdiv.c b/sys/src/libmp/port/mpdiv.c index 92aee03f4..54b943862 100644 --- a/sys/src/libmp/port/mpdiv.c +++ b/sys/src/libmp/port/mpdiv.c @@ -13,10 +13,29 @@ mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder) mpdigit qd, *up, *vp, *qp; mpint *u, *v, *t; + assert(quotient != remainder); + assert(divisor->flags & MPnorm); + // divide bv zero if(divisor->top == 0) abort(); + // division by one or small powers of two + if(divisor->top == 1 && (divisor->p[0] & divisor->p[0]-1) == 0){ + vlong r = (vlong)dividend->sign * (dividend->p[0] & divisor->p[0]-1); + if(quotient != nil){ + for(s = 0; ((divisor->p[0] >> s) & 1) == 0; s++) + ; + mpright(dividend, s, quotient); + } + if(remainder != nil){ + remainder->flags |= dividend->flags & MPtimesafe; + vtomp(r, remainder); + } + return; + } + assert((dividend->flags & MPtimesafe) == 0); + // quick check if(mpmagcmp(dividend, divisor) < 0){ if(remainder != nil) @@ -95,12 +114,14 @@ mpdiv(mpint *dividend, mpint *divisor, mpint *quotient, mpint *remainder) *up-- = 0; } if(qp != nil){ + assert((quotient->flags & MPtimesafe) == 0); mpnorm(quotient); if(dividend->sign != divisor->sign) quotient->sign = -1; } if(remainder != nil){ + assert((remainder->flags & MPtimesafe) == 0); mpright(u, s, remainder); // u is the remainder shifted remainder->sign = dividend->sign; } diff --git a/sys/src/libmp/port/mpeuclid.c b/sys/src/libmp/port/mpeuclid.c index 80b5983bf..586b9cc22 100644 --- a/sys/src/libmp/port/mpeuclid.c +++ b/sys/src/libmp/port/mpeuclid.c @@ -13,6 +13,9 @@ mpeuclid(mpint *a, mpint *b, mpint *d, mpint *x, mpint *y) { mpint *tmp, *x0, *x1, *x2, *y0, *y1, *y2, *q, *r; + assert((a->flags&b->flags) & MPnorm); + assert(((a->flags|b->flags|d->flags|x->flags|y->flags) & MPtimesafe) == 0); + if(a->sign<0 || b->sign<0) sysfatal("mpeuclid: negative arg"); diff --git a/sys/src/libmp/port/mpexp.c b/sys/src/libmp/port/mpexp.c index 9ec067cb9..1ebabba93 100644 --- a/sys/src/libmp/port/mpexp.c +++ b/sys/src/libmp/port/mpexp.c @@ -22,6 +22,10 @@ mpexp(mpint *b, mpint *e, mpint *m, mpint *res) mpdigit d, bit; int i, j; + assert(m->flags & MPnorm); + assert((e->flags & MPtimesafe) == 0); + res->flags |= b->flags & MPtimesafe; + i = mpcmp(e,mpzero); if(i==0){ mpassign(mpone, res); diff --git a/sys/src/libmp/port/mpextendedgcd.c b/sys/src/libmp/port/mpextendedgcd.c index 413a05c2a..72e49bce1 100644 --- a/sys/src/libmp/port/mpextendedgcd.c +++ b/sys/src/libmp/port/mpextendedgcd.c @@ -5,7 +5,7 @@ // extended binary gcd // -// For a anv b it solves, v = gcd(a,b) and finds x and y s.t. +// For a and b it solves, v = gcd(a,b) and finds x and y s.t. // ax + by = v // // Handbook of Applied Cryptography, Menezes et al, 1997, pg 608. @@ -15,6 +15,9 @@ mpextendedgcd(mpint *a, mpint *b, mpint *v, mpint *x, mpint *y) mpint *u, *A, *B, *C, *D; int g; + assert((a->flags&b->flags) & MPnorm); + assert(((a->flags|b->flags|v->flags|x->flags|y->flags) & MPtimesafe) == 0); + if(a->sign < 0 || b->sign < 0){ mpassign(mpzero, v); mpassign(mpzero, y); diff --git a/sys/src/libmp/port/mpfmt.c b/sys/src/libmp/port/mpfmt.c index f7c42a7bc..676b64be0 100644 --- a/sys/src/libmp/port/mpfmt.c +++ b/sys/src/libmp/port/mpfmt.c @@ -102,6 +102,7 @@ to10(mpint *b, char *buf, int len) return -1; d = mpcopy(b); + mpnorm(d); r = mpnew(0); billion = uitomp(1000000000, nil); out = buf+len; @@ -128,15 +129,20 @@ int mpfmt(Fmt *fmt) { mpint *b; - char *p; + char *p, f; b = va_arg(fmt->args, mpint*); if(b == nil) return fmtstrcpy(fmt, "*"); - + + f = b->flags; + b->flags &= ~MPtimesafe; + p = mptoa(b, fmt->prec, nil, 0); fmt->flags &= ~FmtPrec; + b->flags = f; + if(p == nil) return fmtstrcpy(fmt, "*"); else{ diff --git a/sys/src/libmp/port/mpleft.c b/sys/src/libmp/port/mpleft.c index cdcdff740..38929b82e 100644 --- a/sys/src/libmp/port/mpleft.c +++ b/sys/src/libmp/port/mpleft.c @@ -15,8 +15,8 @@ mpleft(mpint *b, int shift, mpint *res) return; } - // a negative left shift is a right shift - if(shift < 0){ + // a zero or negative left shift is a right shift + if(shift <= 0){ mpright(b, -shift, res); return; } @@ -46,7 +46,6 @@ mpleft(mpint *b, int shift, mpint *res) for(i = 0; i < d; i++) res->p[i] = 0; - // normalize - while(res->top > 0 && res->p[res->top-1] == 0) - res->top--; + res->flags |= b->flags & MPtimesafe; + mpnorm(res); } diff --git a/sys/src/libmp/port/mpmod.c b/sys/src/libmp/port/mpmod.c index 91bebfa27..c053f5b7f 100644 --- a/sys/src/libmp/port/mpmod.c +++ b/sys/src/libmp/port/mpmod.c @@ -2,14 +2,100 @@ #include <mp.h> #include "dat.h" -// remainder = b mod m -// -// knuth, vol 2, pp 398-400 - void -mpmod(mpint *b, mpint *m, mpint *remainder) +mpmod(mpint *x, mpint *n, mpint *r) { - mpdiv(b, m, nil, remainder); - if(remainder->sign < 0) - mpadd(m, remainder, remainder); + static int busy; + static mpint *p, *m, *c, *v; + mpdigit q[32], t[64], d; + int sign, k, s, qn, tn; + + sign = x->sign; + + assert(n->flags & MPnorm); + if(n->top < 2 || n->top > nelem(q) || (x->top-n->top) > nelem(q)) + goto hard; + + /* + * check if n = 2**k - c where c has few power of two factors + * above the lowest digit. + */ + for(k = n->top-1; k > 0; k--){ + d = n->p[k] >> 1; + if((d+1 & d) != 0) + goto hard; + } + + d = n->p[n->top-1]; + for(s = 0; (d & (mpdigit)1<<Dbits-1) == 0; s++) + d <<= 1; + + /* lo(x) = x[0:k-1], hi(x) = x[k:xn-1] */ + k = n->top; + + while(_tas(&busy)) + ; + + if(p == nil || mpmagcmp(n, p) != 0){ + if(m == nil){ + m = mpnew(0); + c = mpnew(0); + p = mpnew(0); + } + mpassign(n, p); + + mpleft(n, s, m); + mpleft(mpone, k*Dbits, c); + mpsub(c, m, c); + } + + mpleft(x, s, r); + if(r->top <= k){ + mpbits(r, (k+1)*Dbits); + r->top = k+1; + } + + /* q = hi(r) */ + qn = r->top - k; + memmove(q, r->p+k, qn*Dbytes); + + /* r = lo(r) */ + r->top = k; + + do { + /* t = q*c */ + tn = qn + c->top; + memset(t, 0, tn*Dbytes); + mpvecmul(q, qn, c->p, c->top, t); + + /* q = hi(t) */ + qn = tn - k; + if(qn <= 0) qn = 0; + else memmove(q, t+k, qn*Dbytes); + + /* r += lo(t) */ + if(tn > k) + tn = k; + mpvecadd(r->p, k, t, tn, r->p); + + /* if(r >= m) r -= m */ + mpvecsub(r->p, k+1, m->p, k, t), d = t[k]; + for(tn = 0; tn < k; tn++) + r->p[tn] = (r->p[tn] & d) | (t[tn] & ~d); + } while(qn > 0); + + busy = 0; + + if(s != 0) + mpright(r, s, r); + else + mpnorm(r); + goto done; + +hard: + mpdiv(x, n, nil, r); + +done: + if(sign < 0) + mpmagsub(n, r, r); } diff --git a/sys/src/libmp/port/mpmodop.c b/sys/src/libmp/port/mpmodop.c new file mode 100644 index 000000000..8bc7cbb5a --- /dev/null +++ b/sys/src/libmp/port/mpmodop.c @@ -0,0 +1,96 @@ +#include <u.h> +#include <libc.h> +#include <mp.h> + +/* operands need to have m->top+1 digits of space and satisfy 0 ≤ a ≤ m-1 */ +static mpint* +modarg(mpint *a, mpint *m) +{ + if(a->size <= m->top || a->sign < 0 || mpmagcmp(a, m) >= 0){ + a = mpcopy(a); + mpmod(a, m, a); + mpbits(a, Dbits*(m->top+1)); + a->top = m->top; + } else if(a->top < m->top){ + memset(&a->p[a->top], 0, (m->top - a->top)*Dbytes); + } + return a; +} + +void +mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum) +{ + mpint *a, *b; + mpdigit d; + int i, j; + + a = modarg(b1, m); + b = modarg(b2, m); + + sum->flags |= (a->flags | b->flags) & MPtimesafe; + mpbits(sum, Dbits*2*(m->top+1)); + + mpvecadd(a->p, m->top, b->p, m->top, sum->p); + mpvecsub(sum->p, m->top+1, m->p, m->top, sum->p+m->top+1); + + d = sum->p[2*m->top+1]; + for(i = 0, j = m->top+1; i < m->top; i++, j++) + sum->p[i] = (sum->p[i] & d) | (sum->p[j] & ~d); + + sum->top = m->top; + sum->sign = 1; + mpnorm(sum); + + if(a != b1) + mpfree(a); + if(b != b2) + mpfree(b); +} + +void +mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff) +{ + mpint *a, *b; + mpdigit d; + int i, j; + + a = modarg(b1, m); + b = modarg(b2, m); + + diff->flags |= (a->flags | b->flags) & MPtimesafe; + mpbits(diff, Dbits*2*(m->top+1)); + + a->p[m->top] = 0; + mpvecsub(a->p, m->top+1, b->p, m->top, diff->p); + mpvecadd(diff->p, m->top, m->p, m->top, diff->p+m->top+1); + + d = ~diff->p[m->top]; + for(i = 0, j = m->top+1; i < m->top; i++, j++) + diff->p[i] = (diff->p[i] & d) | (diff->p[j] & ~d); + + diff->top = m->top; + diff->sign = 1; + mpnorm(diff); + + if(a != b1) + mpfree(a); + if(b != b2) + mpfree(b); +} + +void +mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod) +{ + mpint *a, *b; + + a = modarg(b1, m); + b = modarg(b2, m); + + mpmul(a, b, prod); + mpmod(prod, m, prod); + + if(a != b1) + mpfree(a); + if(b != b2) + mpfree(b); +} diff --git a/sys/src/libmp/port/mpmul.c b/sys/src/libmp/port/mpmul.c index dedd474a7..777adf307 100644 --- a/sys/src/libmp/port/mpmul.c +++ b/sys/src/libmp/port/mpmul.c @@ -113,10 +113,6 @@ mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) a = b; b = t; } - if(blen == 0){ - memset(p, 0, Dbytes*(alen+blen)); - return; - } if(alen >= KARATSUBAMIN && blen > 1){ // O(n^1.585) @@ -132,24 +128,48 @@ mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) } void +mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p) +{ + int i; + mpdigit *t; + + if(alen < blen){ + i = alen; + alen = blen; + blen = i; + t = a; + a = b; + b = t; + } + if(blen == 0) + return; + for(i = 0; i < blen; i++) + mpvecdigmuladd(a, alen, b[i], &p[i]); +} + +void mpmul(mpint *b1, mpint *b2, mpint *prod) { mpint *oprod; - oprod = nil; + oprod = prod; if(prod == b1 || prod == b2){ - oprod = prod; prod = mpnew(0); + prod->flags = oprod->flags; } + prod->flags |= (b1->flags | b2->flags) & MPtimesafe; prod->top = 0; mpbits(prod, (b1->top+b2->top+1)*Dbits); - mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p); + if(prod->flags & MPtimesafe) + mpvectsmul(b1->p, b1->top, b2->p, b2->top, prod->p); + else + mpvecmul(b1->p, b1->top, b2->p, b2->top, prod->p); prod->top = b1->top+b2->top+1; prod->sign = b1->sign*b2->sign; mpnorm(prod); - if(oprod != nil){ + if(oprod != prod){ mpassign(prod, oprod); mpfree(prod); } diff --git a/sys/src/libmp/port/mpnrand.c b/sys/src/libmp/port/mpnrand.c index 600283d9d..ebbed5097 100644 --- a/sys/src/libmp/port/mpnrand.c +++ b/sys/src/libmp/port/mpnrand.c @@ -16,8 +16,10 @@ mpnrand(mpint *n, void (*gen)(uchar*, int), mpint *b) mpleft(mpone, bits, m); mpsub(m, mpone, m); - if(b == nil) + if(b == nil){ b = mpnew(bits); + setmalloctag(b, getcallerpc(&n)); + } /* m = m - (m % n) */ mpmod(m, n, b); diff --git a/sys/src/libmp/port/mprand.c b/sys/src/libmp/port/mprand.c index fd288f24e..29433b669 100644 --- a/sys/src/libmp/port/mprand.c +++ b/sys/src/libmp/port/mprand.c @@ -6,37 +6,32 @@ mpint* mprand(int bits, void (*gen)(uchar*, int), mpint *b) { - int n, m; mpdigit mask; + int n, m; uchar *p; n = DIGITS(bits); - if(b == nil) + if(b == nil){ b = mpnew(bits); - else + setmalloctag(b, getcallerpc(&bits)); + }else mpbits(b, bits); p = malloc(n*Dbytes); if(p == nil) - return nil; + sysfatal("mprand: %r"); (*gen)(p, n*Dbytes); betomp(p, n*Dbytes, b); free(p); // make sure we don't give too many bits m = bits%Dbits; - n--; - if(m > 0){ - mask = 1; - mask <<= m; - mask--; - b->p[n] &= mask; - } + if(m == 0) + return b; - for(; n >= 0; n--) - if(b->p[n] != 0) - break; - b->top = n+1; - b->sign = 1; - return b; + mask = 1; + mask <<= m; + mask--; + b->p[n-1] &= mask; + return mpnorm(b); } diff --git a/sys/src/libmp/port/mpright.c b/sys/src/libmp/port/mpright.c index 03039177b..dde7aeace 100644 --- a/sys/src/libmp/port/mpright.c +++ b/sys/src/libmp/port/mpright.c @@ -23,12 +23,16 @@ mpright(mpint *b, int shift, mpint *res) if(res != b) mpbits(res, b->top*Dbits - shift); + else if(shift == 0) + return; + d = shift/Dbits; r = shift - d*Dbits; l = Dbits - r; // shift all the bits out == zero if(d>=b->top){ + res->sign = 1; res->top = 0; return; } @@ -46,9 +50,8 @@ mpright(mpint *b, int shift, mpint *res) } res->p[i++] = last>>r; } - while(i > 0 && res->p[i-1] == 0) - i--; + res->top = i; - if(i==0) - res->sign = 1; + res->flags |= b->flags & MPtimesafe; + mpnorm(res); } diff --git a/sys/src/libmp/port/mpsel.c b/sys/src/libmp/port/mpsel.c new file mode 100644 index 000000000..a145b9d06 --- /dev/null +++ b/sys/src/libmp/port/mpsel.c @@ -0,0 +1,42 @@ +#include "os.h" +#include <mp.h> +#include "dat.h" + +// res = s != 0 ? b1 : b2 +void +mpsel(int s, mpint *b1, mpint *b2, mpint *res) +{ + mpdigit d; + int n, m, i; + + res->flags |= (b1->flags | b2->flags) & MPtimesafe; + if((res->flags & MPtimesafe) == 0){ + mpassign(s ? b1 : b2, res); + return; + } + res->flags &= ~MPnorm; + + n = b1->top; + m = b2->top; + mpbits(res, Dbits*(n >= m ? n : m)); + res->top = n >= m ? n : m; + + s = (-s^s|s)>>(sizeof(s)*8-1); + res->sign = (b1->sign & s) | (b2->sign & ~s); + + d = -((mpdigit)s & 1); + + i = 0; + while(i < n && i < m){ + res->p[i] = (b1->p[i] & d) | (b2->p[i] & ~d); + i++; + } + while(i < n){ + res->p[i] = b1->p[i] & d; + i++; + } + while(i < m){ + res->p[i] = b2->p[i] & ~d; + i++; + } +} diff --git a/sys/src/libmp/port/mpsub.c b/sys/src/libmp/port/mpsub.c index 3fe6ca095..292648f23 100644 --- a/sys/src/libmp/port/mpsub.c +++ b/sys/src/libmp/port/mpsub.c @@ -11,12 +11,15 @@ mpmagsub(mpint *b1, mpint *b2, mpint *diff) // get the sizes right if(mpmagcmp(b1, b2) < 0){ + assert(((b1->flags | b2->flags | diff->flags) & MPtimesafe) == 0); sign = -1; t = b1; b1 = b2; b2 = t; - } else + } else { + diff->flags |= (b1->flags | b2->flags) & MPtimesafe; sign = 1; + } n = b1->top; m = b2->top; if(m == 0){ @@ -39,6 +42,7 @@ mpsub(mpint *b1, mpint *b2, mpint *diff) int sign; if(b1->sign != b2->sign){ + assert(((b1->flags | b2->flags | diff->flags) & MPtimesafe) == 0); sign = b1->sign; mpmagadd(b1, b2, diff); diff->sign = sign; diff --git a/sys/src/libmp/port/mptobe.c b/sys/src/libmp/port/mptobe.c index ed527cc76..9ddea35ed 100644 --- a/sys/src/libmp/port/mptobe.c +++ b/sys/src/libmp/port/mptobe.c @@ -2,57 +2,31 @@ #include <mp.h> #include "dat.h" -// convert an mpint into a big endian byte array (most significant byte first) +// convert an mpint into a big endian byte array (most significant byte first; left adjusted) // return number of bytes converted // if p == nil, allocate and result array int mptobe(mpint *b, uchar *p, uint n, uchar **pp) { - int i, j, suppress; - mpdigit x; - uchar *e, *s, c; + int m; + m = (mpsignif(b)+7)/8; + if(m == 0) + m++; if(p == nil){ - n = (b->top+1)*Dbytes; + n = m; p = malloc(n); + if(p == nil) + sysfatal("mptobe: %r"); setmalloctag(p, getcallerpc(&b)); + } else { + if(n < m) + return -1; + if(n > m) + memset(p+m, 0, n-m); } - if(p == nil) - return -1; if(pp != nil) *pp = p; - memset(p, 0, n); - - // special case 0 - if(b->top == 0){ - if(n < 1) - return -1; - else - return 1; - } - - s = p; - e = s+n; - suppress = 1; - for(i = b->top-1; i >= 0; i--){ - x = b->p[i]; - for(j = Dbits-8; j >= 0; j -= 8){ - c = x>>j; - if(c == 0 && suppress) - continue; - if(p >= e) - return -1; - *p++ = c; - suppress = 0; - } - } - - // guarantee at least one byte - if(s == p){ - if(p >= e) - return -1; - *p++ = 0; - } - - return p - s; + mptober(b, p, m); + return m; } diff --git a/sys/src/libmp/port/mptober.c b/sys/src/libmp/port/mptober.c new file mode 100644 index 000000000..ce63d338d --- /dev/null +++ b/sys/src/libmp/port/mptober.c @@ -0,0 +1,34 @@ +#include "os.h" +#include <mp.h> +#include "dat.h" + +void +mptober(mpint *b, uchar *p, int n) +{ + int i, j, m; + mpdigit x; + + memset(p, 0, n); + + p += n; + m = b->top*Dbytes; + if(m < n) + n = m; + + i = 0; + while(n >= Dbytes){ + n -= Dbytes; + x = b->p[i++]; + for(j = 0; j < Dbytes; j++){ + *--p = x; + x >>= 8; + } + } + if(n > 0){ + x = b->p[i]; + for(j = 0; j < n; j++){ + *--p = x; + x >>= 8; + } + } +} diff --git a/sys/src/libmp/port/mptoi.c b/sys/src/libmp/port/mptoi.c index b3f22b424..6183fa7e5 100644 --- a/sys/src/libmp/port/mptoi.c +++ b/sys/src/libmp/port/mptoi.c @@ -10,17 +10,15 @@ mpint* itomp(int i, mpint *b) { - if(b == nil) + if(b == nil){ b = mpnew(0); - mpassign(mpzero, b); - if(i != 0) - b->top = 1; - if(i < 0){ - b->sign = -1; - *b->p = -i; - } else - *b->p = i; - return b; + setmalloctag(b, getcallerpc(&i)); + } + b->sign = (i >> (sizeof(i)*8 - 1)) | 1; + i *= b->sign; + *b->p = i; + b->top = 1; + return mpnorm(b); } int diff --git a/sys/src/libmp/port/mptole.c b/sys/src/libmp/port/mptole.c index 9421d5f66..3dd892401 100644 --- a/sys/src/libmp/port/mptole.c +++ b/sys/src/libmp/port/mptole.c @@ -3,52 +3,26 @@ #include "dat.h" // convert an mpint into a little endian byte array (least significant byte first) - // return number of bytes converted // if p == nil, allocate and result array int mptole(mpint *b, uchar *p, uint n, uchar **pp) { - int i, j; - mpdigit x; - uchar *e, *s; + int m; + m = (mpsignif(b)+7)/8; + if(m == 0) + m++; if(p == nil){ - n = (b->top+1)*Dbytes; + n = m; p = malloc(n); - } + if(p == nil) + sysfatal("mptole: %r"); + setmalloctag(p, getcallerpc(&b)); + } else if(n < m) + return -1; if(pp != nil) *pp = p; - if(p == nil) - return -1; - memset(p, 0, n); - - // special case 0 - if(b->top == 0){ - if(n < 1) - return -1; - else - return 0; - } - - s = p; - e = s+n; - for(i = 0; i < b->top-1; i++){ - x = b->p[i]; - for(j = 0; j < Dbytes; j++){ - if(p >= e) - return -1; - *p++ = x; - x >>= 8; - } - } - x = b->p[i]; - while(x > 0){ - if(p >= e) - return -1; - *p++ = x; - x >>= 8; - } - - return p - s; + mptolel(b, p, n); + return m; } diff --git a/sys/src/libmp/port/mptolel.c b/sys/src/libmp/port/mptolel.c new file mode 100644 index 000000000..4ee41971f --- /dev/null +++ b/sys/src/libmp/port/mptolel.c @@ -0,0 +1,33 @@ +#include "os.h" +#include <mp.h> +#include "dat.h" + +void +mptolel(mpint *b, uchar *p, int n) +{ + int i, j, m; + mpdigit x; + + memset(p, 0, n); + + m = b->top*Dbytes; + if(m < n) + n = m; + + i = 0; + while(n >= Dbytes){ + n -= Dbytes; + x = b->p[i++]; + for(j = 0; j < Dbytes; j++){ + *p++ = x; + x >>= 8; + } + } + if(n > 0){ + x = b->p[i]; + for(j = 0; j < n; j++){ + *p++ = x; + x >>= 8; + } + } +} diff --git a/sys/src/libmp/port/mptoui.c b/sys/src/libmp/port/mptoui.c index 41c0b0b67..2a963de0c 100644 --- a/sys/src/libmp/port/mptoui.c +++ b/sys/src/libmp/port/mptoui.c @@ -10,13 +10,14 @@ mpint* uitomp(uint i, mpint *b) { - if(b == nil) + if(b == nil){ b = mpnew(0); - mpassign(mpzero, b); - if(i != 0) - b->top = 1; + setmalloctag(b, getcallerpc(&i)); + } *b->p = i; - return b; + b->top = 1; + b->sign = 1; + return mpnorm(b); } uint diff --git a/sys/src/libmp/port/mptouv.c b/sys/src/libmp/port/mptouv.c index b2a7632d1..9e52a357f 100644 --- a/sys/src/libmp/port/mptouv.c +++ b/sys/src/libmp/port/mptouv.c @@ -13,19 +13,18 @@ uvtomp(uvlong v, mpint *b) { int s; - if(b == nil) + if(b == nil){ b = mpnew(VLDIGITS*sizeof(mpdigit)); - else + setmalloctag(b, getcallerpc(&v)); + }else mpbits(b, VLDIGITS*sizeof(mpdigit)); - mpassign(mpzero, b); - if(v == 0) - return b; - for(s = 0; s < VLDIGITS && v != 0; s++){ + b->sign = 1; + for(s = 0; s < VLDIGITS; s++){ b->p[s] = v; v >>= sizeof(mpdigit)*8; } b->top = s; - return b; + return mpnorm(b); } uvlong @@ -37,7 +36,6 @@ mptouv(mpint *b) if(b->top == 0) return 0LL; - mpnorm(b); if(b->top > VLDIGITS) return MAXVLONG; diff --git a/sys/src/libmp/port/mptov.c b/sys/src/libmp/port/mptov.c index b09718ef0..b1b3e93f7 100644 --- a/sys/src/libmp/port/mptov.c +++ b/sys/src/libmp/port/mptov.c @@ -14,24 +14,19 @@ vtomp(vlong v, mpint *b) int s; uvlong uv; - if(b == nil) + if(b == nil){ b = mpnew(VLDIGITS*sizeof(mpdigit)); - else + setmalloctag(b, getcallerpc(&v)); + }else mpbits(b, VLDIGITS*sizeof(mpdigit)); - mpassign(mpzero, b); - if(v == 0) - return b; - if(v < 0){ - b->sign = -1; - uv = -v; - } else - uv = v; - for(s = 0; s < VLDIGITS && uv != 0; s++){ + b->sign = (v >> (sizeof(v)*8 - 1)) | 1; + uv = v * b->sign; + for(s = 0; s < VLDIGITS; s++){ b->p[s] = uv; uv >>= sizeof(mpdigit)*8; } b->top = s; - return b; + return mpnorm(b); } vlong @@ -43,7 +38,6 @@ mptov(mpint *b) if(b->top == 0) return 0LL; - mpnorm(b); if(b->top > VLDIGITS){ if(b->sign > 0) return (vlong)MAXVLONG; diff --git a/sys/src/libmp/port/mpvectscmp.c b/sys/src/libmp/port/mpvectscmp.c new file mode 100644 index 000000000..ccad79b16 --- /dev/null +++ b/sys/src/libmp/port/mpvectscmp.c @@ -0,0 +1,34 @@ +#include "os.h" +#include <mp.h> +#include "dat.h" + +int +mpvectscmp(mpdigit *a, int alen, mpdigit *b, int blen) +{ + mpdigit x, y, z, v; + int m, p; + + if(alen > blen){ + v = 0; + while(alen > blen) + v |= a[--alen]; + m = p = (-v^v|v)>>Dbits-1; + } else if(blen > alen){ + v = 0; + while(blen > alen) + v |= b[--blen]; + m = (-v^v|v)>>Dbits-1; + p = m^1; + } else + m = p = 0; + while(alen-- > 0){ + x = a[alen]; + y = b[alen]; + z = x - y; + x = ~x; + v = ((-z^z|z)>>Dbits-1) & ~m; + p = ((~(x&y|x&z|y&z)>>Dbits-1) & v) | (p & ~v); + m |= v; + } + return (p-m) | m; +} diff --git a/sys/src/libmp/port/strtomp.c b/sys/src/libmp/port/strtomp.c index 2ef8c2109..0a9959692 100644 --- a/sys/src/libmp/port/strtomp.c +++ b/sys/src/libmp/port/strtomp.c @@ -50,7 +50,6 @@ from16(char *a, mpint *b) int i; mpdigit x; - b->top = 0; for(p = a; *p; p++) if(tab.t16[*(uchar*)p] == INVAL) break; @@ -157,8 +156,10 @@ strtomp(char *a, char **pp, int base, mpint *b) int sign; char *e; - if(b == nil) + if(b == nil){ b = mpnew(0); + setmalloctag(b, getcallerpc(&a)); + } if(tab.inited == 0) init(); @@ -196,10 +197,9 @@ strtomp(char *a, char **pp, int base, mpint *b) if(e == a) return nil; - mpnorm(b); - b->sign = sign; if(pp != nil) *pp = e; - return b; + b->sign = sign; + return mpnorm(b); } |