Merge of r85:88 from branch polarssl into trunk.
[umurmur.git] / src / crypt.c
1 /* Copyright (C) 2009-2010, Martin Johansson <martin@fatbob.nu>
2    Copyright (C) 2005-2010, Thorvald Natvig <thorvald@natvig.com>
3
4    All rights reserved.
5
6    Redistribution and use in source and binary forms, with or without
7    modification, are permitted provided that the following conditions
8    are met:
9
10    - Redistributions of source code must retain the above copyright notice,
11      this list of conditions and the following disclaimer.
12    - Redistributions in binary form must reproduce the above copyright notice,
13      this list of conditions and the following disclaimer in the documentation
14      and/or other materials provided with the distribution.
15    - Neither the name of the Developers nor the names of its contributors may
16      be used to endorse or promote products derived from this software without
17      specific prior written permission.
18
19    THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
20    ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
21    LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
22    A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE FOUNDATION OR
23    CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
24    EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
25    PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
26    PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
27    LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
28    NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
29    SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
30 */
31
32 /*
33  * This code implements OCB-AES128.
34  * In the US, OCB is covered by patents. The inventor has given a license
35  * to all programs distributed under the GPL.
36  * uMurmur is BSD (revised) licensed, meaning you can use the code in a
37  * closed-source program. If you do, you'll have to either replace
38  * OCB with something else or get yourself a license.
39  */
40
41 #include <string.h>
42 #include <arpa/inet.h>
43 #include "crypt.h"
44
#ifdef USE_POLARSSL
#include <polarssl/havege.h>

/*
 * Stand-in for OpenSSL's RAND_bytes() when building against PolarSSL:
 * stores one havege_rand(&hs) result into each of _dst_[0] .. _dst_[_size_ - 1].
 *
 * Unlike OpenSSL's RAND_bytes() this expands to a statement, not an
 * expression, so there is no return value to check. The expansion has no
 * trailing semicolon after while (0): the call site supplies it, which keeps
 * the macro usable as the body of an unbraced if/else.
 */
#define RAND_bytes(_dst_, _size_) do { \
	int rand_idx_; \
	for (rand_idx_ = 0; rand_idx_ < (int)(_size_); rand_idx_++) \
		(_dst_)[rand_idx_] = havege_rand(&hs); \
} while (0)

/* HAVEGE generator state used by RAND_bytes(); defined and (presumably)
 * initialised elsewhere in the program — TODO confirm. */
extern havege_state hs;
#endif
56
57 static void CryptState_ocb_encrypt(cryptState_t *cs, const unsigned char *plain, unsigned char *encrypted, unsigned int len, const unsigned char *nonce, unsigned char *tag);
58 static void CryptState_ocb_decrypt(cryptState_t *cs, const unsigned char *encrypted, unsigned char *plain, unsigned int len, const unsigned char *nonce, unsigned char *tag);
59
60 void CryptState_init(cryptState_t *cs)
61 {
62         memset(cs->decrypt_history, 0, 0xff);
63         memset(cs->raw_key, 0, AES_BLOCK_SIZE);
64         memset(cs->encrypt_iv, 0, AES_BLOCK_SIZE);
65         memset(cs->decrypt_iv, 0, AES_BLOCK_SIZE);
66         cs->bInit = false;
67         cs->uiGood = cs->uiLate = cs->uiLost = cs->uiResync = 0;
68         cs->uiRemoteGood = cs->uiRemoteLate = cs->uiRemoteLost = cs->uiRemoteResync = 0;
69         Timer_init(&cs->tLastGood);
70         Timer_init(&cs->tLastRequest);
71 }
72
73 bool_t CryptState_isValid(cryptState_t *cs)
74 {
75         return cs->bInit;
76 }
77
78 void CryptState_genKey(cryptState_t *cs) {
79         RAND_bytes(cs->raw_key, AES_BLOCK_SIZE);
80         RAND_bytes(cs->encrypt_iv, AES_BLOCK_SIZE);
81         RAND_bytes(cs->decrypt_iv, AES_BLOCK_SIZE);
82 #ifndef USE_POLARSSL
83         AES_set_encrypt_key(cs->raw_key, 128, &cs->encrypt_key);
84         AES_set_decrypt_key(cs->raw_key, 128, &cs->decrypt_key);
85 #else
86         aes_setkey_enc(&cs->aes_enc, cs->raw_key, 128);
87         aes_setkey_dec(&cs->aes_dec, cs->raw_key, 128);
88 #endif
89         cs->bInit = true;
90 }
91
92 void CryptState_setKey(cryptState_t *cs, const unsigned char *rkey, const unsigned char *eiv, const unsigned char *div)
93 {
94         memcpy(cs->raw_key, rkey, AES_BLOCK_SIZE);
95         memcpy(cs->encrypt_iv, eiv, AES_BLOCK_SIZE);
96         memcpy(cs->decrypt_iv, div, AES_BLOCK_SIZE);
97 #ifndef USE_POLARSSL
98         AES_set_encrypt_key(cs->decrypt_iv, 128, &cs->encrypt_key);
99         AES_set_decrypt_key(cs->raw_key, 128, &cs->decrypt_key);
100 #else
101         aes_setkey_enc(&cs->aes_enc, cs->decrypt_iv, 128);
102         aes_setkey_dec(&cs->aes_dec, cs->raw_key, 128);
103 #endif
104         cs->bInit = true;
105 }
106
107 void CryptState_setDecryptIV(cryptState_t *cs, const unsigned char *iv)
108 {
109         memcpy(cs->decrypt_iv, iv, AES_BLOCK_SIZE);
110 }
111
112 void CryptState_encrypt(cryptState_t *cs, const unsigned char *source, unsigned char *dst, unsigned int plain_length)
113 {
114         unsigned char tag[AES_BLOCK_SIZE];
115         int i;
116         // First, increase our IV.
117         for (i = 0; i < AES_BLOCK_SIZE; i++)
118                 if (++cs->encrypt_iv[i])
119                         break;
120
121         CryptState_ocb_encrypt(cs, source, dst+4, plain_length, cs->encrypt_iv, tag);
122
123         dst[0] = cs->encrypt_iv[0];
124         dst[1] = tag[0];
125         dst[2] = tag[1];
126         dst[3] = tag[2];
127 }
128
129 bool_t CryptState_decrypt(cryptState_t *cs, const unsigned char *source, unsigned char *dst, unsigned int crypted_length)
130 {
131         if (crypted_length < 4)
132                 return false;
133
134         unsigned int plain_length = crypted_length - 4;
135
136         unsigned char saveiv[AES_BLOCK_SIZE];
137         unsigned char ivbyte = source[0];
138         bool_t restore = false;
139         unsigned char tag[AES_BLOCK_SIZE];
140
141         int lost = 0;
142         int late = 0;
143
144         memcpy(saveiv, cs->decrypt_iv, AES_BLOCK_SIZE);
145
146         if (((cs->decrypt_iv[0] + 1) & 0xFF) == ivbyte) {
147                 // In order as expected.
148                 if (ivbyte > cs->decrypt_iv[0]) {
149                         cs->decrypt_iv[0] = ivbyte;
150                 } else if (ivbyte < cs->decrypt_iv[0]) {
151                         int i;
152                         cs->decrypt_iv[0] = ivbyte;
153                         for (i = 1; i < AES_BLOCK_SIZE; i++)
154                                 if (++cs->decrypt_iv[i])
155                                         break;
156                 } else {
157                         return false;
158                 }
159         } else {
160                 // This is either out of order or a repeat.
161
162                 int diff = ivbyte - cs->decrypt_iv[0];
163                 if (diff > 128)
164                         diff = diff-256;
165                 else if (diff < -128)
166                         diff = diff+256;
167
168                 if ((ivbyte < cs->decrypt_iv[0]) && (diff > -30) && (diff < 0)) {
169                         // Late packet, but no wraparound.
170                         late = 1;
171                         lost = -1;
172                         cs->decrypt_iv[0] = ivbyte;
173                         restore = true;
174                 } else if ((ivbyte > cs->decrypt_iv[0]) && (diff > -30) && (diff < 0)) {
175                         int i;
176                         // Last was 0x02, here comes 0xff from last round
177                         late = 1;
178                         lost = -1;
179                         cs->decrypt_iv[0] = ivbyte;
180                         for (i = 1; i < AES_BLOCK_SIZE; i++)
181                                 if (cs->decrypt_iv[i]--)
182                                         break;
183                         restore = true;
184                 } else if ((ivbyte > cs->decrypt_iv[0]) && (diff > 0)) {
185                         // Lost a few packets, but beyond that we're good.
186                         lost = ivbyte - cs->decrypt_iv[0] - 1;
187                         cs->decrypt_iv[0] = ivbyte;
188                 } else if ((ivbyte < cs->decrypt_iv[0]) && (diff > 0)) {
189                         int i;
190                         // Lost a few packets, and wrapped around
191                         lost = 256 - cs->decrypt_iv[0] + ivbyte - 1;
192                         cs->decrypt_iv[0] = ivbyte;
193                         for (i = 1; i < AES_BLOCK_SIZE; i++)
194                                 if (++cs->decrypt_iv[i])
195                                         break;
196                 } else {
197                         return false;
198                 }
199
200                 if (cs->decrypt_history[cs->decrypt_iv[0]] == cs->decrypt_iv[1]) {
201                         memcpy(cs->decrypt_iv, saveiv, AES_BLOCK_SIZE);
202                         return false;
203                 }
204         }
205
206         CryptState_ocb_decrypt(cs, source+4, dst, plain_length, cs->decrypt_iv, tag);
207
208         if (memcmp(tag, source+1, 3) != 0) {
209                 memcpy(cs->decrypt_iv, saveiv, AES_BLOCK_SIZE);         
210                 return false;
211         }
212         cs->decrypt_history[cs->decrypt_iv[0]] = cs->decrypt_iv[1];
213
214         if (restore)
215                 memcpy(cs->decrypt_iv, saveiv, AES_BLOCK_SIZE);
216
217         cs->uiGood++;
218         cs->uiLate += late;
219         cs->uiLost += lost;
220
221         Timer_restart(&cs->tLastGood);
222         return true;
223 }
224
/*
 * Word-oriented view of a 128-bit AES block for the OCB helpers below.
 *
 * A block is BLOCKSIZE words of type subblock: 2 x 64-bit on LP64 targets
 * whose byte order is known at compile time, otherwise 4 x 32-bit.
 * SWAPPED() converts a word between host order and big-endian (network)
 * order and is its own inverse.
 *
 * Byte order is taken from the compiler's predefined macros, or from
 * glibc's <endian.h> macros when those are visible. When neither is
 * available we fall back to the portable 32-bit path. Testing
 * __BYTE_ORDER == __BIG_ENDIAN with both undefined would evaluate 0 == 0
 * and wrongly pick the identity swap on little-endian hosts.
 */
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__)
#define OCB_HOST_BIG_ENDIAN (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
#define OCB_BYTE_ORDER_KNOWN 1
#elif defined(__BYTE_ORDER) && defined(__BIG_ENDIAN)
#define OCB_HOST_BIG_ENDIAN (__BYTE_ORDER == __BIG_ENDIAN)
#define OCB_BYTE_ORDER_KNOWN 1
#endif

#if defined(__LP64__) && defined(OCB_BYTE_ORDER_KNOWN)
#define BLOCKSIZE 2
#define SHIFTBITS 63
typedef uint64_t subblock;

#if OCB_HOST_BIG_ENDIAN
#define SWAPPED(x) (x)
#else
#ifdef __x86_64__
#define SWAPPED(x) ({register uint64_t swp_out_, swp_in_ = (x); __asm__("bswap %q0" : "=r"(swp_out_) : "0"(swp_in_)); swp_out_;})
#else
#include <byteswap.h>
#define SWAPPED(x) bswap_64(x)
#endif
#endif

#else

/* Portable path: htonl() is correct for either host byte order. */
#define BLOCKSIZE 4
#define SHIFTBITS 31
typedef uint32_t subblock;
#define SWAPPED(x) htonl(x)

#endif

/* Most significant bit of a host-order subblock. Computed in the subblock
 * type: an int shifted by 63 would be undefined behaviour. No trailing
 * semicolon, so it can be used inside expressions. */
#define HIGHBIT ((subblock)1 << SHIFTBITS)
251
252
253 static void inline XOR(subblock *dst, const subblock *a, const subblock *b) {
254         int i;
255         for (i=0;i<BLOCKSIZE;i++) {
256                 dst[i] = a[i] ^ b[i];
257         }
258 }
259
260 static void inline S2(subblock *block) {
261         subblock carry = SWAPPED(block[0]) >> SHIFTBITS;
262         int i;
263         for (i=0;i<BLOCKSIZE-1;i++)
264                 block[i] = SWAPPED((SWAPPED(block[i]) << 1) | (SWAPPED(block[i+1]) >> SHIFTBITS));
265         block[BLOCKSIZE-1] = SWAPPED((SWAPPED(block[BLOCKSIZE-1]) << 1) ^(carry * 0x87));
266 }
267
268 static void inline S3(subblock *block) {
269         subblock carry = SWAPPED(block[0]) >> SHIFTBITS;
270         int i;
271         for (i=0;i<BLOCKSIZE-1;i++)
272                 block[i] ^= SWAPPED((SWAPPED(block[i]) << 1) | (SWAPPED(block[i+1]) >> SHIFTBITS));
273         block[BLOCKSIZE-1] ^= SWAPPED((SWAPPED(block[BLOCKSIZE-1]) << 1) ^(carry * 0x87));
274 }
275
276 static void inline ZERO(subblock *block) {
277         int i;
278         for (i=0;i<BLOCKSIZE;i++)
279                 block[i]=0;
280 }
281
/*
 * Single-block AES ECB primitives over the key schedules held in the crypt
 * state. They hide the difference between the PolarSSL and the
 * OpenSSL-style API.
 *
 * src and dst each point to one AES block; the OCB code below passes the
 * same buffer for both. The expansions carry no trailing semicolon, so each
 * macro behaves like a single function call (the call site supplies the ';').
 */
#ifdef USE_POLARSSL
#define AESencrypt(src, dst, cryptstate) aes_crypt_ecb(&(cryptstate)->aes_enc, AES_ENCRYPT, (unsigned char *)(src), (unsigned char *)(dst))
#define AESdecrypt(src, dst, cryptstate) aes_crypt_ecb(&(cryptstate)->aes_dec, AES_DECRYPT, (unsigned char *)(src), (unsigned char *)(dst))
#else
#define AESencrypt(src, dst, cryptstate) AES_encrypt((unsigned char *)(src), (unsigned char *)(dst), &(cryptstate)->encrypt_key)
#define AESdecrypt(src, dst, cryptstate) AES_decrypt((unsigned char *)(src), (unsigned char *)(dst), &(cryptstate)->decrypt_key)
#endif
289
290 void CryptState_ocb_encrypt(cryptState_t *cs, const unsigned char *plain, unsigned char *encrypted, unsigned int len, const unsigned char *nonce, unsigned char *tag) {
291         subblock checksum[BLOCKSIZE], delta[BLOCKSIZE], tmp[BLOCKSIZE], pad[BLOCKSIZE];
292
293         // Initialize
294         AESencrypt(nonce, delta, cs);
295         ZERO(checksum);
296
297         while (len > AES_BLOCK_SIZE) {
298                 S2(delta);
299                 XOR(tmp, delta, (const subblock *)(plain));
300                 AESencrypt(tmp, tmp, cs);
301                 XOR((subblock *)(encrypted), delta, tmp);
302                 XOR(checksum, checksum, (subblock *)(plain));
303                 len -= AES_BLOCK_SIZE;
304                 plain += AES_BLOCK_SIZE;
305                 encrypted += AES_BLOCK_SIZE;
306         }
307
308         S2(delta);
309         ZERO(tmp);
310         tmp[BLOCKSIZE - 1] = SWAPPED(len * 8);
311         XOR(tmp, tmp, delta);
312         AESencrypt(tmp, pad, cs);
313         memcpy(tmp, plain, len);
314         memcpy((unsigned char *)tmp + len, (unsigned char *)pad + len, AES_BLOCK_SIZE - len);
315         XOR(checksum, checksum, tmp);
316         XOR(tmp, pad, tmp);
317         memcpy(encrypted, tmp, len);
318
319         S3(delta);
320         XOR(tmp, delta, checksum);
321         AESencrypt(tmp, tag, cs);
322 }
323
324 void CryptState_ocb_decrypt(cryptState_t *cs, const unsigned char *encrypted, unsigned char *plain, unsigned int len, const unsigned char *nonce, unsigned char *tag) {
325         subblock checksum[BLOCKSIZE], delta[BLOCKSIZE], tmp[BLOCKSIZE], pad[BLOCKSIZE];
326         // Initialize
327         AESencrypt(nonce, delta, cs);
328         ZERO(checksum);
329
330         while (len > AES_BLOCK_SIZE) {
331                 S2(delta);
332                 XOR(tmp, delta, (const subblock *)(encrypted));
333                 AESdecrypt(tmp, tmp, cs);
334                 XOR((subblock *)(plain), delta, tmp);
335                 XOR(checksum, checksum, (const subblock *)(plain));
336                 len -= AES_BLOCK_SIZE;
337                 plain += AES_BLOCK_SIZE;
338                 encrypted += AES_BLOCK_SIZE;
339         }
340
341         S2(delta);
342         ZERO(tmp);
343         tmp[BLOCKSIZE - 1] = SWAPPED(len * 8);
344         XOR(tmp, tmp, delta);
345         AESencrypt(tmp, pad, cs);
346         memset(tmp, 0, AES_BLOCK_SIZE);
347         memcpy(tmp, encrypted, len);
348         XOR(tmp, tmp, pad);
349         XOR(checksum, checksum, tmp);
350         memcpy(plain, tmp, len);
351
352         S3(delta);
353         XOR(tmp, delta, checksum);
354         AESencrypt(tmp, tag, cs);
355 }