summaryrefslogtreecommitdiff
path: root/sys/include/mp.h
blob: 0de6db39c8e67f90c4c02cb8a8dc580ec407ee02 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
#pragma	src	"/sys/src/libmp"
#pragma	lib	"libmp.a"

#define _MPINT 1

/*
 * mpint: a multiple-precision integer kept as a sign plus
 * an array of digits.
 *
 * The code assumes mpdigit is at least as wide as an int.
 * mpdigit must be an atomic (non-aggregate) type; it is
 * defined in the architecture-specific u.h.
 */
typedef struct mpint mpint;

struct mpint
{
	int	sign;	/* +1 or -1 */
	int	size;	/* number of digits allocated at p */
	int	top;	/* number of significant digits in use */
	mpdigit	*p;	/* digit storage; NOTE(review): presumably least-significant digit first — confirm */
	char	flags;	/* MPstatic/MPnorm/MPtimesafe/MPfield bits (see enum below) */
};

/*
 * Flag bits (presumably stored in mpint.flags) and the
 * digit geometry derived from the architecture's mpdigit.
 */
enum
{
	/* flag bits: one bit each */
	MPstatic=	1<<0,	/* static constant */
	MPnorm=		1<<1,	/* normalization status */
	MPtimesafe=	1<<2,	/* request time invariant computation */
	MPfield=	1<<3,	/* this mpint is a field modulus */

	/* digit geometry */
	Dbytes=		sizeof(mpdigit),	/* bytes per digit */
	Dbits=		8*Dbytes		/* bits per digit */
};

/* allocation and copying */
void	mpsetminbits(int n);	/* newly created mpints get at least n bits */
mpint*	mpnew(int n);		/* create a new mpint with at least n bits */
void	mpfree(mpint *b);	/* free b */
void	mpbits(mpint *b, int n);	/* ensure that b has room for at least n bits */
mpint*	mpnorm(mpint *b);		/* strip leading zero digits; NOTE(review): presumably returns b — confirm */
mpint*	mpcopy(mpint *b);	/* return a new mpint with the same value as b */
void	mpassign(mpint *old, mpint *new);	/* new = old (copies the value into an existing mpint) */

/* random bits */
/* fill an mpint with `bits` random bits obtained from gen(buf, nbytes) */
/* NOTE(review): presumably a new mpint is allocated when b is nil — confirm */
mpint*	mprand(int bits, void (*gen)(uchar*, int), mpint *b);
/* return a uniformly distributed random value in [0..n-1], using gen as the byte source */
mpint*	mpnrand(mpint *n, void (*gen)(uchar*, int), mpint *b);

/*
 * conversion between mpint and other representations.
 * The X-to-mp functions take the destination mpint as their last
 * argument and return an mpint*.
 * NOTE(review): presumably a new mpint is allocated when that
 * argument is nil — confirm against libmp.
 */
mpint*	strtomp(char*, char**, int, mpint*);	/* ascii; presumably (string, end pointer, base, result) */
int	mpfmt(Fmt*);	/* presumably the fmt(2) handler for %B (see the varargck pragma at end of file) */
char*	mptoa(mpint*, int, char*, int);	/* ascii; presumably (value, base, buffer, buffer length) */
mpint*	letomp(uchar*, uint, mpint*);	/* byte array, little-endian */
int	mptole(mpint*, uchar*, uint, uchar**);
void	mptolel(mpint *b, uchar *p, int n);	/* little-endian into exactly n bytes at p (presumably zero-padded) */
mpint*	betomp(uchar*, uint, mpint*);	/* byte array, big-endian */
int	mptobe(mpint*, uchar*, uint, uchar**);
void	mptober(mpint *b, uchar *p, int n);	/* big-endian into exactly n bytes at p (presumably zero-padded) */
uint	mptoui(mpint*);			/* unsigned int */
mpint*	uitomp(uint, mpint*);
int	mptoi(mpint*);			/* int */
mpint*	itomp(int, mpint*);
uvlong	mptouv(mpint*);			/* unsigned vlong */
mpint*	uvtomp(uvlong, mpint*);
vlong	mptov(mpint*);			/* vlong */
mpint*	vtomp(vlong, mpint*);
double	mptod(mpint*);			/* double */
mpint*	dtomp(double, mpint*);

/* divide the 2-digit value at dividend by the single digit divisor, storing a one-digit *quotient */
/* (redeclared identically with the digit-vector primitives below; duplicate prototypes are harmless) */
void	mpdigdiv(mpdigit *dividend, mpdigit divisor, mpdigit *quotient);

/* in the following, the result mpint may be */
/* the same as one of the inputs. */
/* ** denotes exponentiation, not a C operator. */
void	mpadd(mpint *b1, mpint *b2, mpint *sum);	/* sum = b1+b2 */
void	mpsub(mpint *b1, mpint *b2, mpint *diff);	/* diff = b1-b2 */
void	mpleft(mpint *b, int shift, mpint *res);	/* res = b<<shift (shift in bits) */
void	mpright(mpint *b, int shift, mpint *res);	/* res = b>>shift (shift in bits) */
void	mpmul(mpint *b1, mpint *b2, mpint *prod);	/* prod = b1*b2 */
void	mpexp(mpint *b, mpint *e, mpint *m, mpint *res);	/* res = b**e mod m */
void	mpmod(mpint *b, mpint *m, mpint *remainder);	/* remainder = b mod m */

/* logical operations */
/* NOTE(review): the per-function semantics below are inferred from the names; */
/* confirm against libmp, especially how negative operands are treated. */
void	mpand(mpint *b1, mpint *b2, mpint *res);	/* res = b1 & b2 */
void	mpbic(mpint *b1, mpint *b2, mpint *res);	/* bit clear: res = b1 & ~b2 */
void	mpor(mpint *b1, mpint *b2, mpint *res);	/* res = b1 | b2 */
void	mpnot(mpint *b, mpint *res);	/* res = ~b */
void	mpxor(mpint *b1, mpint *b2, mpint *res);	/* res = b1 ^ b2 */
void	mptrunc(mpint *b, int n, mpint *res);	/* presumably res = low n bits of b */
void	mpxtend(mpint *b, int n, mpint *res);	/* presumably res = b sign-extended from n bits */
void	mpasr(mpint *b, int shift, mpint *res);	/* presumably arithmetic shift right: res = b>>shift */

/* modular arithmetic, time invariant when 0≤b1≤m-1 and 0≤b2≤m-1 */
/* (the reduction applies to the whole result, not just to b2) */
void	mpmodadd(mpint *b1, mpint *b2, mpint *m, mpint *sum);	/* sum = (b1+b2) mod m */
void	mpmodsub(mpint *b1, mpint *b2, mpint *m, mpint *diff);	/* diff = (b1-b2) mod m */
void	mpmodmul(mpint *b1, mpint *b2, mpint *m, mpint *prod);	/* prod = (b1*b2) mod m */

/* quotient = dividend/divisor, remainder = dividend % divisor */
/* NOTE(review): presumably quotient or remainder may be nil when not wanted — confirm */
void	mpdiv(mpint *dividend, mpint *divisor,  mpint *quotient, mpint *remainder);

/* return neg, 0, pos as b1-b2 is neg, 0, pos */
int	mpcmp(mpint *b1, mpint *b2);

/* res = s != 0 ? b1 : b2 */
void	mpsel(int s, mpint *b1, mpint *b2, mpint *res);

/* return n! as an mpint */
mpint*	mpfactorial(ulong n);

/* extended gcd: compute d, x, and y such that d = gcd(a,b) and a*x + b*y = d */
void	mpextendedgcd(mpint *a, mpint *b, mpint *d, mpint *x, mpint *y);

/* res = b**-1 mod m, the modular inverse: (b*res) mod m = 1 */
void	mpinvert(mpint *b, mpint *m, mpint *res);

/* bit counting */
int	mpsignif(mpint*);	/* number of significant bits in mantissa */
int	mplowbits0(mpint*);	/* trailing zero bits: k, where n = 2**k * q for odd q */

/* well known constants: the values 0, 1 and 2 */
/* NOTE(review): presumably flagged MPstatic and must not be freed or modified — confirm */
extern mpint	*mpzero, *mpone, *mptwo;

/* digit-vector primitives. x[i:j] names digits x[i] through x[j] inclusive, */
/* so e.g. sum[0:alen] is alen+1 digits. */

/* sum[0:alen] = a[0:alen-1] + b[0:blen-1] */
/* prereq: alen >= blen, sum has room for alen+1 digits */
void	mpvecadd(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *sum);

/* diff[0:alen-1] = a[0:alen-1] - b[0:blen-1] */
/* prereq: alen >= blen, diff has room for alen digits */
void	mpvecsub(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *diff);

/* p[0:n] += m * b[0:n-1] */
/* prereq: p has room for n+1 digits */
void	mpvecdigmuladd(mpdigit *b, int n, mpdigit m, mpdigit *p);

/* p[0:n] -= m * b[0:n-1] */
/* prereq: p has room for n+1 digits */
/* NOTE(review): the int result is presumably the sign of the result (borrow indicator) — confirm */
int	mpvecdigmulsub(mpdigit *b, int n, mpdigit m, mpdigit *p);

/* p[0:alen+blen-1] = a[0:alen-1] * b[0:blen-1] */
/* prereq: alen >= blen, p has room for alen+blen digits */
/* mpvectsmul: presumably the time-invariant variant (cf. MPtimesafe) — confirm */
void	mpvecmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p);
void	mpvectsmul(mpdigit *a, int alen, mpdigit *b, int blen, mpdigit *p);

/* sign of a - b or zero if the same */
/* mpvectscmp: presumably the time-invariant variant (cf. MPtimesafe) — confirm */
int	mpveccmp(mpdigit *a, int alen, mpdigit *b, int blen);
int	mpvectscmp(mpdigit *a, int alen, mpdigit *b, int blen);

/* divide the 2 digit dividend by the one digit divisor and stick in quotient */
/* we assume that the result is one digit - overflow is all 1's */
/* (identical to the earlier mpdigdiv prototype; the redeclaration is harmless) */
void	mpdigdiv(mpdigit *dividend, mpdigit divisor, mpdigit *quotient);

/* playing with magnitudes: these operate on |b1| and |b2|, ignoring the input signs */
int	mpmagcmp(mpint *b1, mpint *b2);	/* neg, 0, pos as |b1|-|b2| is neg, 0, pos */
void	mpmagadd(mpint *b1, mpint *b2, mpint *sum);	/* sum = |b1|+|b2| */
void	mpmagsub(mpint *b1, mpint *b2, mpint *sum);	/* sum = |b1|-|b2| (a difference, despite the parameter name) */

/* chinese remainder theorem: convert between an mpint and its residues */
/* modulo a fixed set of moduli */
typedef struct CRTpre	CRTpre;		/* precomputed values for converting */
					/*  between residues and mpint */
typedef struct CRTres	CRTres;		/* residue form of an mpint */

/* CRTpre is never defined in this header; clients use it only through pointers */
#pragma incomplete CRTpre

struct CRTres
{
	int	n;		/* number of residues */
	mpint	*r[1];		/* residues; NOTE(review): declared with one element, presumably allocated with room for n — confirm in crtin */
};

CRTpre*	crtpre(int, mpint**);			/* precompute conversion values; presumably (number of moduli, moduli) */
CRTres*	crtin(CRTpre*, mpint*);			/* convert mpint to residues */
void	crtout(CRTpre*, CRTres*, mpint*);	/* convert residues to mpint */
void	crtprefree(CRTpre*);			/* free a CRTpre returned by crtpre */
void	crtresfree(CRTres*);			/* free a CRTres returned by crtin */

/* fast field arithmetic */
typedef struct Mfield	Mfield;

/*
 * An Mfield is an mpint modulus with a specialised reduction routine.
 * The unnamed mpint member (a Plan 9 C extension) makes the mpint
 * fields directly accessible and lets an Mfield* be used as an mpint*.
 */
struct Mfield
{
	mpint;
	int	(*reduce)(Mfield*, mpint*, mpint*);	/* NOTE(review): presumably reduces the 2nd argument into the 3rd modulo this field; meaning of the int result unknown — confirm */
};

/* NOTE(review): presumably returns a field-specialised modulus (flagged MPfield) for m, or m itself — confirm */
mpint *mpfield(mpint*);

/* constructors for specific modulus forms; presumably generalized-Mersenne and Crandall-number moduli — confirm */
Mfield *gmfield(mpint*);
Mfield *cnfield(mpint*);

#pragma	varargck	type	"B"	mpint*