Add support for PolarSSL 1.1.1
[umurmur.git] / src / crypt.c
1 /* Copyright (C) 2009-2011, Martin Johansson <martin@fatbob.nu>
2    Copyright (C) 2005-2011, Thorvald Natvig <thorvald@natvig.com>
3
4    All rights reserved.
5
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9
10    - Redistributions of source code must retain the above copyright notice,
11      this list of conditions and the following disclaimer.
12    - Redistributions in binary form must reproduce the above copyright notice,
13      this list of conditions and the following disclaimer in the documentation
14      and/or other materials provided with the distribution.
15    - Neither the name of the Developers nor the names of its contributors may
16      be used to endorse or promote products derived from this software without
17      specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /*
33  * This code implements OCB-AES128.
34  * In the US, OCB is covered by patents. The inventor has given a license
35  * to all programs distributed under the GPL.
36  * uMurmur is BSD (revised) licensed, meaning you can use the code in a
37  * closed-source program. If you do, you'll have to either replace
38  * OCB with something else or get yourself a license.
39  */
40
41 #include <string.h>
42 #include <arpa/inet.h>
43 #include "crypt.h"
44 #include "ssl.h"
45
46 #ifdef USE_POLARSSL
47 #include <polarssl/havege.h>
48 extern havege_state hs;
49 #endif
50
51 static void CryptState_ocb_encrypt(cryptState_t *cs, const unsigned char *plain, unsigned char *encrypted, unsigned int len, const unsigned char *nonce, unsigned char *tag);
52 static void CryptState_ocb_decrypt(cryptState_t *cs, const unsigned char *encrypted, unsigned char *plain, unsigned int len, const unsigned char *nonce, unsigned char *tag);
53
54 void CryptState_init(cryptState_t *cs)
55 {
56         memset(cs->decrypt_history, 0, 0xff);
57         memset(cs->raw_key, 0, AES_BLOCK_SIZE);
58         memset(cs->encrypt_iv, 0, AES_BLOCK_SIZE);
59         memset(cs->decrypt_iv, 0, AES_BLOCK_SIZE);
60         cs->bInit = false;
61         cs->uiGood = cs->uiLate = cs->uiLost = cs->uiResync = 0;
62         cs->uiRemoteGood = cs->uiRemoteLate = cs->uiRemoteLost = cs->uiRemoteResync = 0;
63         Timer_init(&cs->tLastGood);
64         Timer_init(&cs->tLastRequest);
65 }
66
67 bool_t CryptState_isValid(cryptState_t *cs)
68 {
69         return cs->bInit;
70 }
71
72 void CryptState_genKey(cryptState_t *cs) {
73         RAND_bytes(cs->raw_key, AES_BLOCK_SIZE);
74         RAND_bytes(cs->encrypt_iv, AES_BLOCK_SIZE);
75         RAND_bytes(cs->decrypt_iv, AES_BLOCK_SIZE);
76 #ifndef USE_POLARSSL
77         AES_set_encrypt_key(cs->raw_key, 128, &cs->encrypt_key);
78         AES_set_decrypt_key(cs->raw_key, 128, &cs->decrypt_key);
79 #else
80         aes_setkey_enc(&cs->aes_enc, cs->raw_key, 128);
81         aes_setkey_dec(&cs->aes_dec, cs->raw_key, 128);
82 #endif
83         cs->bInit = true;
84 }
85
86 void CryptState_setKey(cryptState_t *cs, const unsigned char *rkey, const unsigned char *eiv, const unsigned char *div)
87 {
88         memcpy(cs->raw_key, rkey, AES_BLOCK_SIZE);
89         memcpy(cs->encrypt_iv, eiv, AES_BLOCK_SIZE);
90         memcpy(cs->decrypt_iv, div, AES_BLOCK_SIZE);
91 #ifndef USE_POLARSSL
92         AES_set_encrypt_key(cs->decrypt_iv, 128, &cs->encrypt_key);
93         AES_set_decrypt_key(cs->raw_key, 128, &cs->decrypt_key);
94 #else
95         aes_setkey_enc(&cs->aes_enc, cs->decrypt_iv, 128);
96         aes_setkey_dec(&cs->aes_dec, cs->raw_key, 128);
97 #endif
98         cs->bInit = true;
99 }
100
101 void CryptState_setDecryptIV(cryptState_t *cs, const unsigned char *iv)
102 {
103         memcpy(cs->decrypt_iv, iv, AES_BLOCK_SIZE);
104 }
105
106 void CryptState_encrypt(cryptState_t *cs, const unsigned char *source, unsigned char *dst, unsigned int plain_length)
107 {
108         unsigned char tag[AES_BLOCK_SIZE];
109         int i;
110         // First, increase our IV.
111         for (i = 0; i < AES_BLOCK_SIZE; i++)
112                 if (++cs->encrypt_iv[i])
113                         break;
114
115         CryptState_ocb_encrypt(cs, source, dst+4, plain_length, cs->encrypt_iv, tag);
116
117         dst[0] = cs->encrypt_iv[0];
118         dst[1] = tag[0];
119         dst[2] = tag[1];
120         dst[3] = tag[2];
121 }
122
/*
 * Verify and decrypt one packet produced by the peer's CryptState_encrypt.
 *
 * Packet layout (as read below): source[0] is the low byte of the sender's
 * IV, source[1..3] are the first three bytes of the OCB tag, and
 * source[4..crypted_length-1] is the ciphertext. The plaintext
 * (crypted_length - 4 bytes) is written to dst.
 *
 * The full receive IV is rebuilt from cs->decrypt_iv plus the single
 * transmitted IV byte, tolerating loss and reordering (a packet may arrive
 * up to 29 IV steps late). cs->decrypt_history records, per low IV byte,
 * the decrypt_iv[1] value of the last accepted packet. A packet that does
 * not arrive exactly in order and whose rebuilt IV matches that record is
 * rejected as a replay.
 *
 * Returns true on success and updates uiGood/uiLate/uiLost and tLastGood.
 * Returns false when the packet is shorter than 4 bytes, the IV byte cannot
 * be placed, a replay is detected, or the tag does not match. On the replay
 * and tag-mismatch paths decrypt_iv is restored from saveiv first.
 * NOTE(review): on a tag mismatch dst has already been filled with
 * unauthenticated output; callers must ignore dst when false is returned.
 */
bool_t CryptState_decrypt(cryptState_t *cs, const unsigned char *source, unsigned char *dst, unsigned int crypted_length)
{
	/* Need at least the 1 IV byte + 3 tag bytes header. */
	if (crypted_length < 4)
		return false;

	unsigned int plain_length = crypted_length - 4;

	unsigned char saveiv[AES_BLOCK_SIZE];
	unsigned char ivbyte = source[0];
	bool_t restore = false;
	unsigned char tag[AES_BLOCK_SIZE];

	/* Statistics deltas, applied only if the packet authenticates. */
	int lost = 0;
	int late = 0;

	/* Snapshot so every rejection / late packet can roll the IV back. */
	memcpy(saveiv, cs->decrypt_iv, AES_BLOCK_SIZE);

	if (((cs->decrypt_iv[0] + 1) & 0xFF) == ivbyte) {
		// In order as expected.
		if (ivbyte > cs->decrypt_iv[0]) {
			cs->decrypt_iv[0] = ivbyte;
		} else if (ivbyte < cs->decrypt_iv[0]) {
			/* Low byte wrapped 0xFF -> 0x00: carry into the higher IV bytes. */
			int i;
			cs->decrypt_iv[0] = ivbyte;
			for (i = 1; i < AES_BLOCK_SIZE; i++)
				if (++cs->decrypt_iv[i])
					break;
		} else {
			/* Unreachable: (x + 1) & 0xFF never equals x. */
			return false;
		}
	} else {
		// This is either out of order or a repeat.

		/* Signed distance from the current low IV byte, folded into [-128, 128]. */
		int diff = ivbyte - cs->decrypt_iv[0];
		if (diff > 128)
			diff = diff-256;
		else if (diff < -128)
			diff = diff+256;

		if ((ivbyte < cs->decrypt_iv[0]) && (diff > -30) && (diff < 0)) {
			// Late packet, but no wraparound.
			late = 1;
			lost = -1;
			cs->decrypt_iv[0] = ivbyte;
			restore = true;
		} else if ((ivbyte > cs->decrypt_iv[0]) && (diff > -30) && (diff < 0)) {
			int i;
			// Last was 0x02, here comes 0xff from last round
			/* Late packet from before the low byte wrapped: borrow from the higher IV bytes. */
			late = 1;
			lost = -1;
			cs->decrypt_iv[0] = ivbyte;
			for (i = 1; i < AES_BLOCK_SIZE; i++)
				if (cs->decrypt_iv[i]--)
					break;
			restore = true;
		} else if ((ivbyte > cs->decrypt_iv[0]) && (diff > 0)) {
			// Lost a few packets, but beyond that we're good.
			lost = ivbyte - cs->decrypt_iv[0] - 1;
			cs->decrypt_iv[0] = ivbyte;
		} else if ((ivbyte < cs->decrypt_iv[0]) && (diff > 0)) {
			int i;
			// Lost a few packets, and wrapped around
			lost = 256 - cs->decrypt_iv[0] + ivbyte - 1;
			cs->decrypt_iv[0] = ivbyte;
			for (i = 1; i < AES_BLOCK_SIZE; i++)
				if (++cs->decrypt_iv[i])
					break;
		} else {
			/* Too late (29+ steps back) or an exact repeat of the current IV byte. */
			return false;
		}

		/* Replay check: this (low byte, next byte) IV pair was already accepted. */
		if (cs->decrypt_history[cs->decrypt_iv[0]] == cs->decrypt_iv[1]) {
			memcpy(cs->decrypt_iv, saveiv, AES_BLOCK_SIZE);
			return false;
		}
	}

	CryptState_ocb_decrypt(cs, source+4, dst, plain_length, cs->decrypt_iv, tag);

	/* Only the first 3 bytes of the 16-byte tag are transmitted and compared. */
	if (memcmp(tag, source+1, 3) != 0) {
		memcpy(cs->decrypt_iv, saveiv, AES_BLOCK_SIZE);
		return false;
	}
	cs->decrypt_history[cs->decrypt_iv[0]] = cs->decrypt_iv[1];

	/* A late packet must not move the receive IV backwards. */
	if (restore)
		memcpy(cs->decrypt_iv, saveiv, AES_BLOCK_SIZE);

	cs->uiGood++;
	cs->uiLate += late;
	/* lost may be -1 for a late packet (it was counted as lost earlier). */
	cs->uiLost += lost;

	Timer_restart(&cs->tLastGood);
	return true;
}
218
/*
 * The OCB code handles a 128-bit block as BLOCKSIZE machine words
 * ("subblocks"). SWAPPED() converts a subblock between host byte order and
 * big-endian order, in which the OCB doubling (S2/S3) is defined.
 */
#if defined(__LP64__)
#define BLOCKSIZE 2
#define SHIFTBITS 63
typedef uint64_t subblock;

/*
 * Determine host byte order. Prefer the compiler's predefined macros and
 * fall back to glibc's <endian.h> ones. In "#if", undefined identifiers
 * evaluate to 0, so the former bare "__BYTE_ORDER == __BIG_ENDIAN" silently
 * chose the no-swap path on any host lacking those macros, breaking OCB on
 * little-endian machines. Refuse to build instead of miscomputing.
 */
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__)
#define CRYPT_HOST_IS_BIG_ENDIAN (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
#elif defined(__BYTE_ORDER) && defined(__BIG_ENDIAN)
#define CRYPT_HOST_IS_BIG_ENDIAN (__BYTE_ORDER == __BIG_ENDIAN)
#else
#error "crypt.c: cannot determine host byte order"
#endif

#if CRYPT_HOST_IS_BIG_ENDIAN
#define SWAPPED(x) (x)
#else
#ifdef __x86_64__
#define SWAPPED(x) ({ register uint64_t swp_out_, swp_in_ = (x); __asm__("bswap %q0" : "=r"(swp_out_) : "0"(swp_in_)); swp_out_; })
#else
#include <byteswap.h>
#define SWAPPED(x) bswap_64(x)
#endif
#endif

#else

#define BLOCKSIZE 4
#define SHIFTBITS 31
typedef uint32_t subblock;
#define SWAPPED(x) htonl(x)

#endif

/*
 * Mask selecting the most significant bit of a subblock. Computed in the
 * subblock type because 1 << 63 on a plain int is undefined. No trailing ';'.
 */
#define HIGHBIT ((subblock)1 << SHIFTBITS)
245
246
247 static void inline XOR(subblock *dst, const subblock *a, const subblock *b) {
248         int i;
249         for (i=0;i<BLOCKSIZE;i++) {
250                 dst[i] = a[i] ^ b[i];
251         }
252 }
253
254 static void inline S2(subblock *block) {
255         subblock carry = SWAPPED(block[0]) >> SHIFTBITS;
256         int i;
257         for (i=0;i<BLOCKSIZE-1;i++)
258                 block[i] = SWAPPED((SWAPPED(block[i]) << 1) | (SWAPPED(block[i+1]) >> SHIFTBITS));
259         block[BLOCKSIZE-1] = SWAPPED((SWAPPED(block[BLOCKSIZE-1]) << 1) ^(carry * 0x87));
260 }
261
262 static void inline S3(subblock *block) {
263         subblock carry = SWAPPED(block[0]) >> SHIFTBITS;
264         int i;
265         for (i=0;i<BLOCKSIZE-1;i++)
266                 block[i] ^= SWAPPED((SWAPPED(block[i]) << 1) | (SWAPPED(block[i+1]) >> SHIFTBITS));
267         block[BLOCKSIZE-1] ^= SWAPPED((SWAPPED(block[BLOCKSIZE-1]) << 1) ^(carry * 0x87));
268 }
269
270 static void inline ZERO(subblock *block) {
271         int i;
272         for (i=0;i<BLOCKSIZE;i++)
273                 block[i]=0;
274 }
275
/*
 * Single-block AES-128 primitives used by the OCB code below. src and dst
 * each address AES_BLOCK_SIZE bytes; the OCB loops call them with
 * src == dst. The input is passed as const (no const cast-away), and the
 * expansions carry no trailing ';' so each use behaves like an ordinary
 * function call, e.g. inside an unbraced if/else.
 */
#ifdef USE_POLARSSL
#define AESencrypt(src, dst, cryptstate) aes_crypt_ecb(&(cryptstate)->aes_enc, AES_ENCRYPT, (const unsigned char *)(src), (unsigned char *)(dst))
#define AESdecrypt(src, dst, cryptstate) aes_crypt_ecb(&(cryptstate)->aes_dec, AES_DECRYPT, (const unsigned char *)(src), (unsigned char *)(dst))
#else
#define AESencrypt(src, dst, cryptstate) AES_encrypt((const unsigned char *)(src), (unsigned char *)(dst), &(cryptstate)->encrypt_key)
#define AESdecrypt(src, dst, cryptstate) AES_decrypt((const unsigned char *)(src), (unsigned char *)(dst), &(cryptstate)->decrypt_key)
#endif
283
284 void CryptState_ocb_encrypt(cryptState_t *cs, const unsigned char *plain, unsigned char *encrypted, unsigned int len, const unsigned char *nonce, unsigned char *tag) {
285         subblock checksum[BLOCKSIZE], delta[BLOCKSIZE], tmp[BLOCKSIZE], pad[BLOCKSIZE];
286
287         // Initialize
288         AESencrypt(nonce, delta, cs);
289         ZERO(checksum);
290
291         while (len > AES_BLOCK_SIZE) {
292                 S2(delta);
293                 XOR(tmp, delta, (const subblock *)(plain));
294                 AESencrypt(tmp, tmp, cs);
295                 XOR((subblock *)(encrypted), delta, tmp);
296                 XOR(checksum, checksum, (subblock *)(plain));
297                 len -= AES_BLOCK_SIZE;
298                 plain += AES_BLOCK_SIZE;
299                 encrypted += AES_BLOCK_SIZE;
300         }
301
302         S2(delta);
303         ZERO(tmp);
304         tmp[BLOCKSIZE - 1] = SWAPPED(len * 8);
305         XOR(tmp, tmp, delta);
306         AESencrypt(tmp, pad, cs);
307         memcpy(tmp, plain, len);
308         memcpy((unsigned char *)tmp + len, (unsigned char *)pad + len, AES_BLOCK_SIZE - len);
309         XOR(checksum, checksum, tmp);
310         XOR(tmp, pad, tmp);
311         memcpy(encrypted, tmp, len);
312
313         S3(delta);
314         XOR(tmp, delta, checksum);
315         AESencrypt(tmp, tag, cs);
316 }
317
318 void CryptState_ocb_decrypt(cryptState_t *cs, const unsigned char *encrypted, unsigned char *plain, unsigned int len, const unsigned char *nonce, unsigned char *tag) {
319         subblock checksum[BLOCKSIZE], delta[BLOCKSIZE], tmp[BLOCKSIZE], pad[BLOCKSIZE];
320         // Initialize
321         AESencrypt(nonce, delta, cs);
322         ZERO(checksum);
323
324         while (len > AES_BLOCK_SIZE) {
325                 S2(delta);
326                 XOR(tmp, delta, (const subblock *)(encrypted));
327                 AESdecrypt(tmp, tmp, cs);
328                 XOR((subblock *)(plain), delta, tmp);
329                 XOR(checksum, checksum, (const subblock *)(plain));
330                 len -= AES_BLOCK_SIZE;
331                 plain += AES_BLOCK_SIZE;
332                 encrypted += AES_BLOCK_SIZE;
333         }
334
335         S2(delta);
336         ZERO(tmp);
337         tmp[BLOCKSIZE - 1] = SWAPPED(len * 8);
338         XOR(tmp, tmp, delta);
339         AESencrypt(tmp, pad, cs);
340         memset(tmp, 0, AES_BLOCK_SIZE);
341         memcpy(tmp, encrypted, len);
342         XOR(tmp, tmp, pad);
343         XOR(checksum, checksum, tmp);
344         memcpy(plain, tmp, len);
345
346         S3(delta);
347         XOR(tmp, delta, checksum);
348         AESencrypt(tmp, tag, cs);
349 }