summaryrefslogtreecommitdiff
path: root/lib/crypto/rsa.myr
blob: c05ced19c5e46b977046179f5c1e7563bde8ffbf (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
use std

use "ct"
use "ctbig"
use "rand"

pkg crypto =
	/*
	 * RSA public-key operation with PKCS#1 v1.5 type-2 padding.
	 * msg is padded with random non-zero bytes up to the byte
	 * length of mod. The result msg^exp mod mod is returned as a
	 * big-endian byte slice. exp and mod are big-endian byte
	 * strings.
	 */
	const rsapub_pkcs15	: (msg : byte[:], exp : byte[:], mod : byte[:] -> byte[:])

	/*
	 * Same as rsapub_pkcs15, but with caller-controlled padding so
	 * that unit tests get deterministic output. If seed is empty,
	 * random padding is used. Otherwise seed is copied verbatim
	 * into the padding string.
	 *
	 * NOTE(review): seed is presumably required to be exactly
	 * mod-bytes - msg.len - 3 long. It is not checked for zero
	 * bytes, so it is for testing only.
	 */
	pkglocal const rsapubseed_pkcs15 : (\
		msg : byte[:],
		exp : byte[:],
		mod : byte[:],
		seed : byte[:] -> byte[:])
;;

/*
 * Public entry point: an empty padding seed tells
 * rsapubseed_pkcs15 to generate fresh random padding.
 */
const rsapub_pkcs15 = {msgbuf, expbuf, modbuf
	var noseed = ""

	-> rsapubseed_pkcs15(msgbuf, expbuf, modbuf, noseed)
}

/*
 * Pad msgbuf (using padbuf as the padding string when non-empty),
 * then compute msg^exp mod mod in constant-time bignums sized to
 * the modulus bit length. Returns the result as big-endian bytes.
 */
const rsapubseed_pkcs15 = {msgbuf, expbuf, modbuf, padbuf
	var nbits = bitcount(modbuf)
	var c = ctzero(nbits)
	var m = decodepad(msgbuf, nbits, padbuf)
	var e = decode(expbuf, nbits)
	var n = decode(modbuf, nbits)

	ctmodpow(c, m, e, n)
	var out = ctbytesbe(c)

	/* all bignums are scratch; only the encoded bytes are returned */
	ctfree(c)
	ctfree(m)
	ctfree(e)
	ctfree(n)
	-> out
}

/*
 * PKCS#1 v1.5-pad msg to the byte width of a len-bit modulus,
 * then load the padded block as a big-endian constant-time bignum.
 */
const decodepad = {msg, len, padbuf
	var nbytes = (len + 7) / 8
	var padded = pad(msg, nbytes, padbuf)
	var big = mkctbigbe(padded, len)

	/* mkctbigbe presumably copies the bytes, since the block is freed here */
	std.slfree(padded)
	-> big
}

/*
 * Load a big-endian byte buffer as a constant-time bignum, using
 * the len bit width the callers pass (the modulus bit count).
 */
const decode = {buf, len
	var v = mkctbigbe(buf, len)

	-> v
}

/*
 * Build a PKCS#1 v1.5 encryption block (block type 2) of exactly
 * nbytes bytes:
 *
 *	0x00 || 0x02 || PS || 0x00 || msg
 *
 * PS is nbytes - msg.len - 3 bytes long. RFC 8017 requires PS to
 * be at least 8 bytes, so msg may be at most nbytes - 11 bytes.
 *
 * If padbuf is non-empty, it is copied verbatim into PS. This gives
 * deterministic output for tests; padbuf is presumably expected to
 * be exactly PS-sized and free of zero bytes. Otherwise PS is
 * filled with random non-zero bytes.
 */
const pad = {msg, nbytes, padbuf
	var buf, pslen

	/*
	 * RFC 8017 7.2.1 allows mLen <= k - 11. The check is written as
	 * an addition so it cannot wrap around when nbytes < 11.
	 */
	std.assert(msg.len + 11 <= nbytes, "overlong message")
	buf = std.slalloc(nbytes)

	buf[0] = 0
	buf[1] = 2
	pslen = nbytes - msg.len - 3
	if padbuf.len > 0
		std.slcp(buf[2:pslen+2], padbuf)
	else
		randbytes(buf[2:pslen+2])
		/* PS must not contain zeros: redraw any zero byte */
		for var i = 0; i < pslen; i++
			while buf[i + 2] == 0
				randbytes(buf[i+2:i+3])
			;;
		;;
	;;
	/* zero separator between the padding and the message */
	buf[pslen + 2] = 0
	std.slcp(buf[pslen+3:], msg)

	-> buf
}

/*
 * Count the significant bits in a pkcs15 modulus stored as a
 * big-endian byte string. Leading zero bytes and the leading zero
 * bits of the first non-zero byte are not counted. An empty or
 * all-zero buffer yields 0.
 */
const bitcount = {buf
	const bits = [
		0x80, 0xc0, 0xe0, 0xf0,
		0xf8, 0xfc, 0xfe, 0xff,
	]
	var i, top, nbit

	nbit = 8*buf.len
	/*
	 * Skip leading zero bytes, dropping 8 bits for each one. The
	 * bounds test comes first so an all-zero buffer is never
	 * indexed past its end.
	 */
	for i = 0; i < buf.len && buf[i] == 0; i++
		nbit -= 8
	;;
	if i == buf.len
		-> 0
	;;
	/* drop the leading zero bits of the first non-zero byte */
	top = buf[i]
	for b : bits[:]
		if (top & b) != 0
			break
		;;
		nbit--
	;;
	-> nbit
}