Skip to content

Commit 0deea1b

Browse files
committed
Use constant-time (faster) padding decoding also for OAEP
1 parent 519e7ae commit 0deea1b

File tree

5 files changed

+145
-66
lines changed

5 files changed

+145
-66
lines changed

lib/Crypto/Cipher/PKCS1_OAEP.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,13 @@
2323
from Crypto.Signature.pss import MGF1
2424
import Crypto.Hash.SHA1
2525

26-
from Crypto.Util.py3compat import bord, _copy_bytes
26+
from Crypto.Util.py3compat import _copy_bytes
2727
import Crypto.Util.number
28-
from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
29-
from Crypto.Util.strxor import strxor
28+
from Crypto.Util.number import ceil_div, bytes_to_long, long_to_bytes
29+
from Crypto.Util.strxor import strxor
3030
from Crypto import Random
31+
from ._pkcs1_oaep_decode import oaep_decode
32+
3133

3234
class PKCS1OAEP_Cipher:
3335
"""Cipher object for PKCS#1 v1.5 OAEP.
@@ -68,7 +70,7 @@ def __init__(self, key, hashAlgo, mgfunc, label, randfunc):
6870
if mgfunc:
6971
self._mgf = mgfunc
7072
else:
71-
self._mgf = lambda x,y: MGF1(x,y,self._hashObj)
73+
self._mgf = lambda x, y: MGF1(x, y, self._hashObj)
7274

7375
self._label = _copy_bytes(None, None, label)
7476
self._randfunc = randfunc
@@ -105,7 +107,7 @@ def encrypt(self, message):
105107

106108
# See 7.1.1 in RFC3447
107109
modBits = Crypto.Util.number.size(self._key.n)
108-
k = ceil_div(modBits, 8) # Convert from bits to bytes
110+
k = ceil_div(modBits, 8) # Convert from bits to bytes
109111
hLen = self._hashObj.digest_size
110112
mLen = len(message)
111113

@@ -159,20 +161,18 @@ def decrypt(self, ciphertext):
159161

160162
# See 7.1.2 in RFC3447
161163
modBits = Crypto.Util.number.size(self._key.n)
162-
k = ceil_div(modBits,8) # Convert from bits to bytes
164+
k = ceil_div(modBits, 8) # Convert from bits to bytes
163165
hLen = self._hashObj.digest_size
164166

165167
# Step 1b and 1c
166-
if len(ciphertext) != k or k<hLen+2:
168+
if len(ciphertext) != k or k < hLen+2:
167169
raise ValueError("Ciphertext with incorrect length.")
168170
# Step 2a (O2SIP)
169171
ct_int = bytes_to_long(ciphertext)
170172
# Step 2b (RSADP) and step 2c (I2OSP)
171173
em = self._key._decrypt_to_bytes(ct_int)
172174
# Step 3a
173175
lHash = self._hashObj.new(self._label).digest()
174-
# Step 3b
175-
y = em[0]
176176
# y must be 0, but we MUST NOT check it here in order not to
177177
# allow attacks like Manger's (http://dl.acm.org/citation.cfm?id=704143)
178178
maskedSeed = em[1:hLen+1]
@@ -185,22 +185,17 @@ def decrypt(self, ciphertext):
185185
dbMask = self._mgf(seed, k-hLen-1)
186186
# Step 3f
187187
db = strxor(maskedDB, dbMask)
188-
# Step 3g
189-
one_pos = hLen + db[hLen:].find(b'\x01')
190-
lHash1 = db[:hLen]
191-
invalid = bord(y) | int(one_pos < hLen)
192-
hash_compare = strxor(lHash1, lHash)
193-
for x in hash_compare:
194-
invalid |= bord(x)
195-
for x in db[hLen:one_pos]:
196-
invalid |= bord(x)
197-
if invalid != 0:
188+
# Step 3b + 3g
189+
res = oaep_decode(em, lHash, db)
190+
if res <= 0:
198191
raise ValueError("Incorrect decryption.")
199192
# Step 4
200-
return db[one_pos + 1:]
193+
return db[res:]
194+
201195

202196
def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):
203-
"""Return a cipher object :class:`PKCS1OAEP_Cipher` that can be used to perform PKCS#1 OAEP encryption or decryption.
197+
"""Return a cipher object :class:`PKCS1OAEP_Cipher`
198+
that can be used to perform PKCS#1 OAEP encryption or decryption.
204199
205200
:param key:
206201
The key object to use to encrypt or decrypt the message.
@@ -234,4 +229,3 @@ def new(key, hashAlgo=None, mgfunc=None, label=b'', randfunc=None):
234229
if randfunc is None:
235230
randfunc = Random.get_random_bytes
236231
return PKCS1OAEP_Cipher(key, hashAlgo, mgfunc, label, randfunc)
237-

lib/Crypto/Cipher/PKCS1_v1_5.py

Lines changed: 3 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,31 +25,7 @@
2525
from Crypto import Random
2626
from Crypto.Util.number import bytes_to_long, long_to_bytes
2727
from Crypto.Util.py3compat import bord, is_bytes, _copy_bytes
28-
29-
from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, c_size_t,
30-
c_uint8_ptr)
31-
32-
33-
_raw_pkcs1_decode = load_pycryptodome_raw_lib("Crypto.Cipher._pkcs1_decode",
34-
"""
35-
int pkcs1_decode(const uint8_t *em, size_t len_em,
36-
const uint8_t *sentinel, size_t len_sentinel,
37-
size_t expected_pt_len,
38-
uint8_t *output);
39-
""")
40-
41-
42-
def _pkcs1_decode(em, sentinel, expected_pt_len, output):
43-
if len(em) != len(output):
44-
raise ValueError("Incorrect output length")
45-
46-
ret = _raw_pkcs1_decode.pkcs1_decode(c_uint8_ptr(em),
47-
c_size_t(len(em)),
48-
c_uint8_ptr(sentinel),
49-
c_size_t(len(sentinel)),
50-
c_size_t(expected_pt_len),
51-
c_uint8_ptr(output))
52-
return ret
28+
from ._pkcs1_oaep_decode import pkcs1_decode
5329

5430

5531
class PKCS115_Cipher:
@@ -113,7 +89,6 @@ def encrypt(self, message):
11389
continue
11490
ps.append(new_byte)
11591
ps = b"".join(ps)
116-
assert(len(ps) == k - mLen - 3)
11792
# Step 2b
11893
em = b'\x00\x02' + ps + b'\x00' + _copy_bytes(None, None, message)
11994
# Step 3a (OS2IP)
@@ -182,14 +157,14 @@ def decrypt(self, ciphertext, sentinel, expected_pt_len=0):
182157
# Step 3 (not constant time when the sentinel is not a byte string)
183158
output = bytes(bytearray(k))
184159
if not is_bytes(sentinel) or len(sentinel) > k:
185-
size = _pkcs1_decode(em, b'', expected_pt_len, output)
160+
size = pkcs1_decode(em, b'', expected_pt_len, output)
186161
if size < 0:
187162
return sentinel
188163
else:
189164
return output[size:]
190165

191166
# Step 3 (somewhat constant time)
192-
size = _pkcs1_decode(em, sentinel, expected_pt_len, output)
167+
size = pkcs1_decode(em, sentinel, expected_pt_len, output)
193168
return output[size:]
194169

195170

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from Crypto.Util._raw_api import (load_pycryptodome_raw_lib, c_size_t,
2+
c_uint8_ptr)
3+
4+
5+
_raw_pkcs1_decode = load_pycryptodome_raw_lib("Crypto.Cipher._pkcs1_decode",
6+
"""
7+
int pkcs1_decode(const uint8_t *em, size_t len_em,
8+
const uint8_t *sentinel, size_t len_sentinel,
9+
size_t expected_pt_len,
10+
uint8_t *output);
11+
12+
int oaep_decode(const uint8_t *em,
13+
size_t em_len,
14+
const uint8_t *lHash,
15+
size_t hLen,
16+
const uint8_t *db,
17+
size_t db_len);
18+
""")
19+
20+
21+
def pkcs1_decode(em, sentinel, expected_pt_len, output):
22+
if len(em) != len(output):
23+
raise ValueError("Incorrect output length")
24+
25+
ret = _raw_pkcs1_decode.pkcs1_decode(c_uint8_ptr(em),
26+
c_size_t(len(em)),
27+
c_uint8_ptr(sentinel),
28+
c_size_t(len(sentinel)),
29+
c_size_t(expected_pt_len),
30+
c_uint8_ptr(output))
31+
return ret
32+
33+
34+
def oaep_decode(em, lHash, db):
35+
ret = _raw_pkcs1_decode.oaep_decode(c_uint8_ptr(em),
36+
c_size_t(len(em)),
37+
c_uint8_ptr(lHash),
38+
c_size_t(len(lHash)),
39+
c_uint8_ptr(db),
40+
c_size_t(len(db)))
41+
return ret

src/pkcs1_decode.c

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ STATIC size_t safe_select_idx(size_t in1, size_t in2, uint8_t choice)
130130
* - in1[] is NOT equal to in2[] where neq_mask[] is 0xFF.
131131
* Return non-zero otherwise.
132132
*/
133-
STATIC uint8_t safe_cmp(const uint8_t *in1, const uint8_t *in2,
133+
STATIC uint8_t safe_cmp_masks(const uint8_t *in1, const uint8_t *in2,
134134
const uint8_t *eq_mask, const uint8_t *neq_mask,
135135
size_t len)
136136
{
@@ -187,7 +187,7 @@ STATIC size_t safe_search(const uint8_t *in1, uint8_t c, size_t len)
187187
return result;
188188
}
189189

190-
#define EM_PREFIX_LEN 10
190+
#define PKCS1_PREFIX_LEN 10
191191

192192
/*
193193
* Decode and verify the PKCS#1 padding, then put either the plaintext
@@ -222,13 +222,13 @@ EXPORT_SYM int pkcs1_decode(const uint8_t *em, size_t len_em_output,
222222
if (NULL == em || NULL == output || NULL == sentinel) {
223223
return -1;
224224
}
225-
if (len_em_output < (EM_PREFIX_LEN + 2)) {
225+
if (len_em_output < (PKCS1_PREFIX_LEN + 2)) {
226226
return -1;
227227
}
228228
if (len_sentinel > len_em_output) {
229229
return -1;
230230
}
231-
if (expected_pt_len > 0 && expected_pt_len > (len_em_output - EM_PREFIX_LEN - 1)) {
231+
if (expected_pt_len > 0 && expected_pt_len > (len_em_output - PKCS1_PREFIX_LEN - 1)) {
232232
return -1;
233233
}
234234

@@ -240,7 +240,7 @@ EXPORT_SYM int pkcs1_decode(const uint8_t *em, size_t len_em_output,
240240
memcpy(padded_sentinel + (len_em_output - len_sentinel), sentinel, len_sentinel);
241241

242242
/** The first 10 bytes must follow the pattern **/
243-
match = safe_cmp(em,
243+
match = safe_cmp_masks(em,
244244
(const uint8_t*)"\x00\x02" "\x00\x00\x00\x00\x00\x00\x00\x00",
245245
(const uint8_t*)"\xFF\xFF" "\x00\x00\x00\x00\x00\x00\x00\x00",
246246
(const uint8_t*)"\x00\x00" "\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF",
@@ -283,3 +283,72 @@ EXPORT_SYM int pkcs1_decode(const uint8_t *em, size_t len_em_output,
283283
free(padded_sentinel);
284284
return result;
285285
}
286+
287+
/*
288+
* Decode and verify the OAEP padding in constant time.
289+
*
290+
* The function returns the number of bytes to ignore at the beginning
291+
* of db (the rest is the plaintext), or -1 in case of problems.
292+
*/
293+
294+
EXPORT_SYM int oaep_decode(const uint8_t *em,
295+
size_t em_len,
296+
const uint8_t *lHash,
297+
size_t hLen,
298+
const uint8_t *db,
299+
size_t db_len) /* em_len - 1 - hLen */
300+
{
301+
int result;
302+
size_t one_pos, search_len, i;
303+
uint8_t wrong_padding;
304+
uint8_t *eq_mask = NULL;
305+
uint8_t *neq_mask = NULL;
306+
uint8_t *target_db = NULL;
307+
308+
if (NULL == em || NULL == lHash || NULL == db) {
309+
return -1;
310+
}
311+
312+
if (em_len < 2*hLen+2 || db_len != em_len-1-hLen) {
313+
return -1;
314+
}
315+
316+
/* Allocate */
317+
eq_mask = (uint8_t*) calloc(1, db_len);
318+
neq_mask = (uint8_t*) calloc(1, db_len);
319+
target_db = (uint8_t*) calloc(1, db_len);
320+
if (NULL == eq_mask || NULL == neq_mask || NULL == target_db) {
321+
result = -1;
322+
goto cleanup;
323+
}
324+
325+
/* Step 3g */
326+
search_len = db_len - hLen;
327+
328+
one_pos = safe_search(db + hLen, 0x01, search_len);
329+
if (SIZE_T_MAX == one_pos) {
330+
result = -1;
331+
goto cleanup;
332+
}
333+
334+
memset(eq_mask, 0xAA, db_len);
335+
memcpy(target_db, lHash, hLen);
336+
memset(eq_mask, 0xFF, hLen);
337+
338+
for (i=0; i<search_len; i++) {
339+
eq_mask[hLen + i] = propagate_ones(i < one_pos);
340+
}
341+
342+
wrong_padding = em[0];
343+
wrong_padding |= safe_cmp_masks(db, target_db, eq_mask, neq_mask, db_len);
344+
set_if_match(&wrong_padding, one_pos, search_len);
345+
346+
result = wrong_padding ? -1 : (int)(hLen + 1 + one_pos);
347+
348+
cleanup:
349+
free(eq_mask);
350+
free(neq_mask);
351+
free(target_db);
352+
353+
return result;
354+
}

src/test/test_pkcs1.c

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ void set_if_match(uint8_t *flag, size_t term1, size_t term2);
55
void set_if_no_match(uint8_t *flag, size_t term1, size_t term2);
66
void safe_select(const uint8_t *in1, const uint8_t *in2, uint8_t *out, uint8_t choice, size_t len);
77
size_t safe_select_idx(size_t in1, size_t in2, uint8_t choice);
8-
uint8_t safe_cmp(const uint8_t *in1, const uint8_t *in2,
8+
uint8_t safe_cmp_masks(const uint8_t *in1, const uint8_t *in2,
99
const uint8_t *eq_mask, const uint8_t *neq_mask,
1010
size_t len);
1111
size_t safe_search(const uint8_t *in1, uint8_t c, size_t len);
@@ -80,57 +80,57 @@ void test_safe_select_idx(void)
8080
assert(safe_select_idx(0x100004, 0x223344, 1) == 0x223344);
8181
}
8282

83-
void test_safe_cmp(void)
83+
void test_safe_cmp_masks(void)
8484
{
8585
uint8_t res;
8686

87-
res = safe_cmp(onezero, onezero,
87+
res = safe_cmp_masks(onezero, onezero,
8888
(uint8_t*)"\xFF\xFF",
8989
(uint8_t*)"\x00\x00",
9090
2);
9191
assert(res == 0);
9292

93-
res = safe_cmp(onezero, zerozero,
93+
res = safe_cmp_masks(onezero, zerozero,
9494
(uint8_t*)"\xFF\xFF",
9595
(uint8_t*)"\x00\x00",
9696
2);
9797
assert(res != 0);
9898

99-
res = safe_cmp(onezero, oneone,
99+
res = safe_cmp_masks(onezero, oneone,
100100
(uint8_t*)"\xFF\xFF",
101101
(uint8_t*)"\x00\x00",
102102
2);
103103
assert(res != 0);
104104

105-
res = safe_cmp(onezero, oneone,
105+
res = safe_cmp_masks(onezero, oneone,
106106
(uint8_t*)"\xFF\x00",
107107
(uint8_t*)"\x00\x00",
108108
2);
109109
assert(res == 0);
110110

111111
/** -- **/
112112

113-
res = safe_cmp(onezero, onezero,
113+
res = safe_cmp_masks(onezero, onezero,
114114
(uint8_t*)"\x00\x00",
115115
(uint8_t*)"\xFF\xFF",
116116
2);
117117
assert(res != 0);
118118

119-
res = safe_cmp(oneone, zerozero,
119+
res = safe_cmp_masks(oneone, zerozero,
120120
(uint8_t*)"\x00\x00",
121121
(uint8_t*)"\xFF\xFF",
122122
2);
123123
assert(res == 0);
124124

125-
res = safe_cmp(onezero, oneone,
125+
res = safe_cmp_masks(onezero, oneone,
126126
(uint8_t*)"\x00\x00",
127127
(uint8_t*)"\x00\xFF",
128128
2);
129129
assert(res == 0);
130130

131131
/** -- **/
132132

133-
res = safe_cmp(onezero, oneone,
133+
res = safe_cmp_masks(onezero, oneone,
134134
(uint8_t*)"\xFF\x00",
135135
(uint8_t*)"\x00\xFF",
136136
2);
@@ -158,7 +158,7 @@ int main(void)
158158
test_set_if_no_match();
159159
test_safe_select();
160160
test_safe_select_idx();
161-
test_safe_cmp();
161+
test_safe_cmp_masks();
162162
test_safe_search();
163163
return 0;
164164
}

0 commit comments

Comments
 (0)