Linux-libre 5.3.12-gnu
[librecmc/linux-libre.git] / arch / arm64 / crypto / aes-glue.c
1 // SPDX-License-Identifier: GPL-2.0-only
2 /*
3  * linux/arch/arm64/crypto/aes-glue.c - wrapper code for ARMv8 AES
4  *
5  * Copyright (C) 2013 - 2017 Linaro Ltd <ard.biesheuvel@linaro.org>
6  */
7
8 #include <asm/neon.h>
9 #include <asm/hwcap.h>
10 #include <asm/simd.h>
11 #include <crypto/aes.h>
12 #include <crypto/internal/hash.h>
13 #include <crypto/internal/simd.h>
14 #include <crypto/internal/skcipher.h>
15 #include <crypto/scatterwalk.h>
16 #include <linux/module.h>
17 #include <linux/cpufeature.h>
18 #include <crypto/xts.h>
19
20 #include "aes-ce-setkey.h"
21 #include "aes-ctr-fallback.h"
22
23 #ifdef USE_V8_CRYPTO_EXTENSIONS
24 #define MODE                    "ce"
25 #define PRIO                    300
26 #define aes_setkey              ce_aes_setkey
27 #define aes_expandkey           ce_aes_expandkey
28 #define aes_ecb_encrypt         ce_aes_ecb_encrypt
29 #define aes_ecb_decrypt         ce_aes_ecb_decrypt
30 #define aes_cbc_encrypt         ce_aes_cbc_encrypt
31 #define aes_cbc_decrypt         ce_aes_cbc_decrypt
32 #define aes_cbc_cts_encrypt     ce_aes_cbc_cts_encrypt
33 #define aes_cbc_cts_decrypt     ce_aes_cbc_cts_decrypt
34 #define aes_ctr_encrypt         ce_aes_ctr_encrypt
35 #define aes_xts_encrypt         ce_aes_xts_encrypt
36 #define aes_xts_decrypt         ce_aes_xts_decrypt
37 #define aes_mac_update          ce_aes_mac_update
38 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
39 #else
40 #define MODE                    "neon"
41 #define PRIO                    200
42 #define aes_setkey              crypto_aes_set_key
43 #define aes_expandkey           crypto_aes_expand_key
44 #define aes_ecb_encrypt         neon_aes_ecb_encrypt
45 #define aes_ecb_decrypt         neon_aes_ecb_decrypt
46 #define aes_cbc_encrypt         neon_aes_cbc_encrypt
47 #define aes_cbc_decrypt         neon_aes_cbc_decrypt
48 #define aes_cbc_cts_encrypt     neon_aes_cbc_cts_encrypt
49 #define aes_cbc_cts_decrypt     neon_aes_cbc_cts_decrypt
50 #define aes_ctr_encrypt         neon_aes_ctr_encrypt
51 #define aes_xts_encrypt         neon_aes_xts_encrypt
52 #define aes_xts_decrypt         neon_aes_xts_decrypt
53 #define aes_mac_update          neon_aes_mac_update
54 MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 NEON");
55 MODULE_ALIAS_CRYPTO("ecb(aes)");
56 MODULE_ALIAS_CRYPTO("cbc(aes)");
57 MODULE_ALIAS_CRYPTO("ctr(aes)");
58 MODULE_ALIAS_CRYPTO("xts(aes)");
59 MODULE_ALIAS_CRYPTO("cmac(aes)");
60 MODULE_ALIAS_CRYPTO("xcbc(aes)");
61 MODULE_ALIAS_CRYPTO("cbcmac(aes)");
62 #endif
63
64 MODULE_AUTHOR("Ard Biesheuvel <ard.biesheuvel@linaro.org>");
65 MODULE_LICENSE("GPL v2");
66
67 /* defined in aes-modes.S */
68 asmlinkage void aes_ecb_encrypt(u8 out[], u8 const in[], u32 const rk[],
69                                 int rounds, int blocks);
70 asmlinkage void aes_ecb_decrypt(u8 out[], u8 const in[], u32 const rk[],
71                                 int rounds, int blocks);
72
73 asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
74                                 int rounds, int blocks, u8 iv[]);
75 asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
76                                 int rounds, int blocks, u8 iv[]);
77
78 asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
79                                 int rounds, int bytes, u8 const iv[]);
80 asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
81                                 int rounds, int bytes, u8 const iv[]);
82
83 asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
84                                 int rounds, int blocks, u8 ctr[]);
85
86 asmlinkage void aes_xts_encrypt(u8 out[], u8 const in[], u32 const rk1[],
87                                 int rounds, int blocks, u32 const rk2[], u8 iv[],
88                                 int first);
89 asmlinkage void aes_xts_decrypt(u8 out[], u8 const in[], u32 const rk1[],
90                                 int rounds, int blocks, u32 const rk2[], u8 iv[],
91                                 int first);
92
93 asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
94                                int blocks, u8 dg[], int enc_before,
95                                int enc_after);
96
97 struct cts_cbc_req_ctx {
98         struct scatterlist sg_src[2];
99         struct scatterlist sg_dst[2];
100         struct skcipher_request subreq;
101 };
102
103 struct crypto_aes_xts_ctx {
104         struct crypto_aes_ctx key1;
105         struct crypto_aes_ctx __aligned(8) key2;
106 };
107
108 struct mac_tfm_ctx {
109         struct crypto_aes_ctx key;
110         u8 __aligned(8) consts[];
111 };
112
113 struct mac_desc_ctx {
114         unsigned int len;
115         u8 dg[AES_BLOCK_SIZE];
116 };
117
118 static int skcipher_aes_setkey(struct crypto_skcipher *tfm, const u8 *in_key,
119                                unsigned int key_len)
120 {
121         return aes_setkey(crypto_skcipher_tfm(tfm), in_key, key_len);
122 }
123
124 static int xts_set_key(struct crypto_skcipher *tfm, const u8 *in_key,
125                        unsigned int key_len)
126 {
127         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
128         int ret;
129
130         ret = xts_verify_key(tfm, in_key, key_len);
131         if (ret)
132                 return ret;
133
134         ret = aes_expandkey(&ctx->key1, in_key, key_len / 2);
135         if (!ret)
136                 ret = aes_expandkey(&ctx->key2, &in_key[key_len / 2],
137                                     key_len / 2);
138         if (!ret)
139                 return 0;
140
141         crypto_skcipher_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
142         return -EINVAL;
143 }
144
145 static int ecb_encrypt(struct skcipher_request *req)
146 {
147         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
148         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
149         int err, rounds = 6 + ctx->key_length / 4;
150         struct skcipher_walk walk;
151         unsigned int blocks;
152
153         err = skcipher_walk_virt(&walk, req, false);
154
155         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
156                 kernel_neon_begin();
157                 aes_ecb_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
158                                 ctx->key_enc, rounds, blocks);
159                 kernel_neon_end();
160                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
161         }
162         return err;
163 }
164
165 static int ecb_decrypt(struct skcipher_request *req)
166 {
167         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
168         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
169         int err, rounds = 6 + ctx->key_length / 4;
170         struct skcipher_walk walk;
171         unsigned int blocks;
172
173         err = skcipher_walk_virt(&walk, req, false);
174
175         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
176                 kernel_neon_begin();
177                 aes_ecb_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
178                                 ctx->key_dec, rounds, blocks);
179                 kernel_neon_end();
180                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
181         }
182         return err;
183 }
184
185 static int cbc_encrypt(struct skcipher_request *req)
186 {
187         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
188         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
189         int err, rounds = 6 + ctx->key_length / 4;
190         struct skcipher_walk walk;
191         unsigned int blocks;
192
193         err = skcipher_walk_virt(&walk, req, false);
194
195         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
196                 kernel_neon_begin();
197                 aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
198                                 ctx->key_enc, rounds, blocks, walk.iv);
199                 kernel_neon_end();
200                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
201         }
202         return err;
203 }
204
205 static int cbc_decrypt(struct skcipher_request *req)
206 {
207         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
208         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
209         int err, rounds = 6 + ctx->key_length / 4;
210         struct skcipher_walk walk;
211         unsigned int blocks;
212
213         err = skcipher_walk_virt(&walk, req, false);
214
215         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
216                 kernel_neon_begin();
217                 aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
218                                 ctx->key_dec, rounds, blocks, walk.iv);
219                 kernel_neon_end();
220                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
221         }
222         return err;
223 }
224
225 static int cts_cbc_init_tfm(struct crypto_skcipher *tfm)
226 {
227         crypto_skcipher_set_reqsize(tfm, sizeof(struct cts_cbc_req_ctx));
228         return 0;
229 }
230
231 static int cts_cbc_encrypt(struct skcipher_request *req)
232 {
233         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
234         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
235         struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
236         int err, rounds = 6 + ctx->key_length / 4;
237         int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
238         struct scatterlist *src = req->src, *dst = req->dst;
239         struct skcipher_walk walk;
240
241         skcipher_request_set_tfm(&rctx->subreq, tfm);
242
243         if (req->cryptlen <= AES_BLOCK_SIZE) {
244                 if (req->cryptlen < AES_BLOCK_SIZE)
245                         return -EINVAL;
246                 cbc_blocks = 1;
247         }
248
249         if (cbc_blocks > 0) {
250                 unsigned int blocks;
251
252                 skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
253                                            cbc_blocks * AES_BLOCK_SIZE,
254                                            req->iv);
255
256                 err = skcipher_walk_virt(&walk, &rctx->subreq, false);
257
258                 while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
259                         kernel_neon_begin();
260                         aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
261                                         ctx->key_enc, rounds, blocks, walk.iv);
262                         kernel_neon_end();
263                         err = skcipher_walk_done(&walk,
264                                                  walk.nbytes % AES_BLOCK_SIZE);
265                 }
266                 if (err)
267                         return err;
268
269                 if (req->cryptlen == AES_BLOCK_SIZE)
270                         return 0;
271
272                 dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
273                                              rctx->subreq.cryptlen);
274                 if (req->dst != req->src)
275                         dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
276                                                rctx->subreq.cryptlen);
277         }
278
279         /* handle ciphertext stealing */
280         skcipher_request_set_crypt(&rctx->subreq, src, dst,
281                                    req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
282                                    req->iv);
283
284         err = skcipher_walk_virt(&walk, &rctx->subreq, false);
285         if (err)
286                 return err;
287
288         kernel_neon_begin();
289         aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
290                             ctx->key_enc, rounds, walk.nbytes, walk.iv);
291         kernel_neon_end();
292
293         return skcipher_walk_done(&walk, 0);
294 }
295
296 static int cts_cbc_decrypt(struct skcipher_request *req)
297 {
298         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
299         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
300         struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
301         int err, rounds = 6 + ctx->key_length / 4;
302         int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
303         struct scatterlist *src = req->src, *dst = req->dst;
304         struct skcipher_walk walk;
305
306         skcipher_request_set_tfm(&rctx->subreq, tfm);
307
308         if (req->cryptlen <= AES_BLOCK_SIZE) {
309                 if (req->cryptlen < AES_BLOCK_SIZE)
310                         return -EINVAL;
311                 cbc_blocks = 1;
312         }
313
314         if (cbc_blocks > 0) {
315                 unsigned int blocks;
316
317                 skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
318                                            cbc_blocks * AES_BLOCK_SIZE,
319                                            req->iv);
320
321                 err = skcipher_walk_virt(&walk, &rctx->subreq, false);
322
323                 while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
324                         kernel_neon_begin();
325                         aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
326                                         ctx->key_dec, rounds, blocks, walk.iv);
327                         kernel_neon_end();
328                         err = skcipher_walk_done(&walk,
329                                                  walk.nbytes % AES_BLOCK_SIZE);
330                 }
331                 if (err)
332                         return err;
333
334                 if (req->cryptlen == AES_BLOCK_SIZE)
335                         return 0;
336
337                 dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
338                                              rctx->subreq.cryptlen);
339                 if (req->dst != req->src)
340                         dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
341                                                rctx->subreq.cryptlen);
342         }
343
344         /* handle ciphertext stealing */
345         skcipher_request_set_crypt(&rctx->subreq, src, dst,
346                                    req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
347                                    req->iv);
348
349         err = skcipher_walk_virt(&walk, &rctx->subreq, false);
350         if (err)
351                 return err;
352
353         kernel_neon_begin();
354         aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
355                             ctx->key_dec, rounds, walk.nbytes, walk.iv);
356         kernel_neon_end();
357
358         return skcipher_walk_done(&walk, 0);
359 }
360
361 static int ctr_encrypt(struct skcipher_request *req)
362 {
363         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
364         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
365         int err, rounds = 6 + ctx->key_length / 4;
366         struct skcipher_walk walk;
367         int blocks;
368
369         err = skcipher_walk_virt(&walk, req, false);
370
371         while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
372                 kernel_neon_begin();
373                 aes_ctr_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
374                                 ctx->key_enc, rounds, blocks, walk.iv);
375                 kernel_neon_end();
376                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
377         }
378         if (walk.nbytes) {
379                 u8 __aligned(8) tail[AES_BLOCK_SIZE];
380                 unsigned int nbytes = walk.nbytes;
381                 u8 *tdst = walk.dst.virt.addr;
382                 u8 *tsrc = walk.src.virt.addr;
383
384                 /*
385                  * Tell aes_ctr_encrypt() to process a tail block.
386                  */
387                 blocks = -1;
388
389                 kernel_neon_begin();
390                 aes_ctr_encrypt(tail, NULL, ctx->key_enc, rounds,
391                                 blocks, walk.iv);
392                 kernel_neon_end();
393                 crypto_xor_cpy(tdst, tsrc, tail, nbytes);
394                 err = skcipher_walk_done(&walk, 0);
395         }
396
397         return err;
398 }
399
400 static int ctr_encrypt_sync(struct skcipher_request *req)
401 {
402         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
403         struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
404
405         if (!crypto_simd_usable())
406                 return aes_ctr_encrypt_fallback(ctx, req);
407
408         return ctr_encrypt(req);
409 }
410
411 static int xts_encrypt(struct skcipher_request *req)
412 {
413         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
414         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
415         int err, first, rounds = 6 + ctx->key1.key_length / 4;
416         struct skcipher_walk walk;
417         unsigned int blocks;
418
419         err = skcipher_walk_virt(&walk, req, false);
420
421         for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
422                 kernel_neon_begin();
423                 aes_xts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
424                                 ctx->key1.key_enc, rounds, blocks,
425                                 ctx->key2.key_enc, walk.iv, first);
426                 kernel_neon_end();
427                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
428         }
429
430         return err;
431 }
432
433 static int xts_decrypt(struct skcipher_request *req)
434 {
435         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
436         struct crypto_aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
437         int err, first, rounds = 6 + ctx->key1.key_length / 4;
438         struct skcipher_walk walk;
439         unsigned int blocks;
440
441         err = skcipher_walk_virt(&walk, req, false);
442
443         for (first = 1; (blocks = (walk.nbytes / AES_BLOCK_SIZE)); first = 0) {
444                 kernel_neon_begin();
445                 aes_xts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
446                                 ctx->key1.key_dec, rounds, blocks,
447                                 ctx->key2.key_enc, walk.iv, first);
448                 kernel_neon_end();
449                 err = skcipher_walk_done(&walk, walk.nbytes % AES_BLOCK_SIZE);
450         }
451
452         return err;
453 }
454
455 static struct skcipher_alg aes_algs[] = { {
456         .base = {
457                 .cra_name               = "__ecb(aes)",
458                 .cra_driver_name        = "__ecb-aes-" MODE,
459                 .cra_priority           = PRIO,
460                 .cra_flags              = CRYPTO_ALG_INTERNAL,
461                 .cra_blocksize          = AES_BLOCK_SIZE,
462                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
463                 .cra_module             = THIS_MODULE,
464         },
465         .min_keysize    = AES_MIN_KEY_SIZE,
466         .max_keysize    = AES_MAX_KEY_SIZE,
467         .setkey         = skcipher_aes_setkey,
468         .encrypt        = ecb_encrypt,
469         .decrypt        = ecb_decrypt,
470 }, {
471         .base = {
472                 .cra_name               = "__cbc(aes)",
473                 .cra_driver_name        = "__cbc-aes-" MODE,
474                 .cra_priority           = PRIO,
475                 .cra_flags              = CRYPTO_ALG_INTERNAL,
476                 .cra_blocksize          = AES_BLOCK_SIZE,
477                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
478                 .cra_module             = THIS_MODULE,
479         },
480         .min_keysize    = AES_MIN_KEY_SIZE,
481         .max_keysize    = AES_MAX_KEY_SIZE,
482         .ivsize         = AES_BLOCK_SIZE,
483         .setkey         = skcipher_aes_setkey,
484         .encrypt        = cbc_encrypt,
485         .decrypt        = cbc_decrypt,
486 }, {
487         .base = {
488                 .cra_name               = "__cts(cbc(aes))",
489                 .cra_driver_name        = "__cts-cbc-aes-" MODE,
490                 .cra_priority           = PRIO,
491                 .cra_flags              = CRYPTO_ALG_INTERNAL,
492                 .cra_blocksize          = AES_BLOCK_SIZE,
493                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
494                 .cra_module             = THIS_MODULE,
495         },
496         .min_keysize    = AES_MIN_KEY_SIZE,
497         .max_keysize    = AES_MAX_KEY_SIZE,
498         .ivsize         = AES_BLOCK_SIZE,
499         .walksize       = 2 * AES_BLOCK_SIZE,
500         .setkey         = skcipher_aes_setkey,
501         .encrypt        = cts_cbc_encrypt,
502         .decrypt        = cts_cbc_decrypt,
503         .init           = cts_cbc_init_tfm,
504 }, {
505         .base = {
506                 .cra_name               = "__ctr(aes)",
507                 .cra_driver_name        = "__ctr-aes-" MODE,
508                 .cra_priority           = PRIO,
509                 .cra_flags              = CRYPTO_ALG_INTERNAL,
510                 .cra_blocksize          = 1,
511                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
512                 .cra_module             = THIS_MODULE,
513         },
514         .min_keysize    = AES_MIN_KEY_SIZE,
515         .max_keysize    = AES_MAX_KEY_SIZE,
516         .ivsize         = AES_BLOCK_SIZE,
517         .chunksize      = AES_BLOCK_SIZE,
518         .setkey         = skcipher_aes_setkey,
519         .encrypt        = ctr_encrypt,
520         .decrypt        = ctr_encrypt,
521 }, {
522         .base = {
523                 .cra_name               = "ctr(aes)",
524                 .cra_driver_name        = "ctr-aes-" MODE,
525                 .cra_priority           = PRIO - 1,
526                 .cra_blocksize          = 1,
527                 .cra_ctxsize            = sizeof(struct crypto_aes_ctx),
528                 .cra_module             = THIS_MODULE,
529         },
530         .min_keysize    = AES_MIN_KEY_SIZE,
531         .max_keysize    = AES_MAX_KEY_SIZE,
532         .ivsize         = AES_BLOCK_SIZE,
533         .chunksize      = AES_BLOCK_SIZE,
534         .setkey         = skcipher_aes_setkey,
535         .encrypt        = ctr_encrypt_sync,
536         .decrypt        = ctr_encrypt_sync,
537 }, {
538         .base = {
539                 .cra_name               = "__xts(aes)",
540                 .cra_driver_name        = "__xts-aes-" MODE,
541                 .cra_priority           = PRIO,
542                 .cra_flags              = CRYPTO_ALG_INTERNAL,
543                 .cra_blocksize          = AES_BLOCK_SIZE,
544                 .cra_ctxsize            = sizeof(struct crypto_aes_xts_ctx),
545                 .cra_module             = THIS_MODULE,
546         },
547         .min_keysize    = 2 * AES_MIN_KEY_SIZE,
548         .max_keysize    = 2 * AES_MAX_KEY_SIZE,
549         .ivsize         = AES_BLOCK_SIZE,
550         .setkey         = xts_set_key,
551         .encrypt        = xts_encrypt,
552         .decrypt        = xts_decrypt,
553 } };
554
555 static int cbcmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
556                          unsigned int key_len)
557 {
558         struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
559         int err;
560
561         err = aes_expandkey(&ctx->key, in_key, key_len);
562         if (err)
563                 crypto_shash_set_flags(tfm, CRYPTO_TFM_RES_BAD_KEY_LEN);
564
565         return err;
566 }
567
568 static void cmac_gf128_mul_by_x(be128 *y, const be128 *x)
569 {
570         u64 a = be64_to_cpu(x->a);
571         u64 b = be64_to_cpu(x->b);
572
573         y->a = cpu_to_be64((a << 1) | (b >> 63));
574         y->b = cpu_to_be64((b << 1) ^ ((a >> 63) ? 0x87 : 0));
575 }
576
577 static int cmac_setkey(struct crypto_shash *tfm, const u8 *in_key,
578                        unsigned int key_len)
579 {
580         struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
581         be128 *consts = (be128 *)ctx->consts;
582         int rounds = 6 + key_len / 4;
583         int err;
584
585         err = cbcmac_setkey(tfm, in_key, key_len);
586         if (err)
587                 return err;
588
589         /* encrypt the zero vector */
590         kernel_neon_begin();
591         aes_ecb_encrypt(ctx->consts, (u8[AES_BLOCK_SIZE]){}, ctx->key.key_enc,
592                         rounds, 1);
593         kernel_neon_end();
594
595         cmac_gf128_mul_by_x(consts, consts);
596         cmac_gf128_mul_by_x(consts + 1, consts);
597
598         return 0;
599 }
600
601 static int xcbc_setkey(struct crypto_shash *tfm, const u8 *in_key,
602                        unsigned int key_len)
603 {
604         static u8 const ks[3][AES_BLOCK_SIZE] = {
605                 { [0 ... AES_BLOCK_SIZE - 1] = 0x1 },
606                 { [0 ... AES_BLOCK_SIZE - 1] = 0x2 },
607                 { [0 ... AES_BLOCK_SIZE - 1] = 0x3 },
608         };
609
610         struct mac_tfm_ctx *ctx = crypto_shash_ctx(tfm);
611         int rounds = 6 + key_len / 4;
612         u8 key[AES_BLOCK_SIZE];
613         int err;
614
615         err = cbcmac_setkey(tfm, in_key, key_len);
616         if (err)
617                 return err;
618
619         kernel_neon_begin();
620         aes_ecb_encrypt(key, ks[0], ctx->key.key_enc, rounds, 1);
621         aes_ecb_encrypt(ctx->consts, ks[1], ctx->key.key_enc, rounds, 2);
622         kernel_neon_end();
623
624         return cbcmac_setkey(tfm, key, sizeof(key));
625 }
626
627 static int mac_init(struct shash_desc *desc)
628 {
629         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
630
631         memset(ctx->dg, 0, AES_BLOCK_SIZE);
632         ctx->len = 0;
633
634         return 0;
635 }
636
637 static void mac_do_update(struct crypto_aes_ctx *ctx, u8 const in[], int blocks,
638                           u8 dg[], int enc_before, int enc_after)
639 {
640         int rounds = 6 + ctx->key_length / 4;
641
642         if (crypto_simd_usable()) {
643                 kernel_neon_begin();
644                 aes_mac_update(in, ctx->key_enc, rounds, blocks, dg, enc_before,
645                                enc_after);
646                 kernel_neon_end();
647         } else {
648                 if (enc_before)
649                         __aes_arm64_encrypt(ctx->key_enc, dg, dg, rounds);
650
651                 while (blocks--) {
652                         crypto_xor(dg, in, AES_BLOCK_SIZE);
653                         in += AES_BLOCK_SIZE;
654
655                         if (blocks || enc_after)
656                                 __aes_arm64_encrypt(ctx->key_enc, dg, dg,
657                                                     rounds);
658                 }
659         }
660 }
661
662 static int mac_update(struct shash_desc *desc, const u8 *p, unsigned int len)
663 {
664         struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
665         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
666
667         while (len > 0) {
668                 unsigned int l;
669
670                 if ((ctx->len % AES_BLOCK_SIZE) == 0 &&
671                     (ctx->len + len) > AES_BLOCK_SIZE) {
672
673                         int blocks = len / AES_BLOCK_SIZE;
674
675                         len %= AES_BLOCK_SIZE;
676
677                         mac_do_update(&tctx->key, p, blocks, ctx->dg,
678                                       (ctx->len != 0), (len != 0));
679
680                         p += blocks * AES_BLOCK_SIZE;
681
682                         if (!len) {
683                                 ctx->len = AES_BLOCK_SIZE;
684                                 break;
685                         }
686                         ctx->len = 0;
687                 }
688
689                 l = min(len, AES_BLOCK_SIZE - ctx->len);
690
691                 if (l <= AES_BLOCK_SIZE) {
692                         crypto_xor(ctx->dg + ctx->len, p, l);
693                         ctx->len += l;
694                         len -= l;
695                         p += l;
696                 }
697         }
698
699         return 0;
700 }
701
702 static int cbcmac_final(struct shash_desc *desc, u8 *out)
703 {
704         struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
705         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
706
707         mac_do_update(&tctx->key, NULL, 0, ctx->dg, (ctx->len != 0), 0);
708
709         memcpy(out, ctx->dg, AES_BLOCK_SIZE);
710
711         return 0;
712 }
713
714 static int cmac_final(struct shash_desc *desc, u8 *out)
715 {
716         struct mac_tfm_ctx *tctx = crypto_shash_ctx(desc->tfm);
717         struct mac_desc_ctx *ctx = shash_desc_ctx(desc);
718         u8 *consts = tctx->consts;
719
720         if (ctx->len != AES_BLOCK_SIZE) {
721                 ctx->dg[ctx->len] ^= 0x80;
722                 consts += AES_BLOCK_SIZE;
723         }
724
725         mac_do_update(&tctx->key, consts, 1, ctx->dg, 0, 1);
726
727         memcpy(out, ctx->dg, AES_BLOCK_SIZE);
728
729         return 0;
730 }
731
732 static struct shash_alg mac_algs[] = { {
733         .base.cra_name          = "cmac(aes)",
734         .base.cra_driver_name   = "cmac-aes-" MODE,
735         .base.cra_priority      = PRIO,
736         .base.cra_blocksize     = AES_BLOCK_SIZE,
737         .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
738                                   2 * AES_BLOCK_SIZE,
739         .base.cra_module        = THIS_MODULE,
740
741         .digestsize             = AES_BLOCK_SIZE,
742         .init                   = mac_init,
743         .update                 = mac_update,
744         .final                  = cmac_final,
745         .setkey                 = cmac_setkey,
746         .descsize               = sizeof(struct mac_desc_ctx),
747 }, {
748         .base.cra_name          = "xcbc(aes)",
749         .base.cra_driver_name   = "xcbc-aes-" MODE,
750         .base.cra_priority      = PRIO,
751         .base.cra_blocksize     = AES_BLOCK_SIZE,
752         .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx) +
753                                   2 * AES_BLOCK_SIZE,
754         .base.cra_module        = THIS_MODULE,
755
756         .digestsize             = AES_BLOCK_SIZE,
757         .init                   = mac_init,
758         .update                 = mac_update,
759         .final                  = cmac_final,
760         .setkey                 = xcbc_setkey,
761         .descsize               = sizeof(struct mac_desc_ctx),
762 }, {
763         .base.cra_name          = "cbcmac(aes)",
764         .base.cra_driver_name   = "cbcmac-aes-" MODE,
765         .base.cra_priority      = PRIO,
766         .base.cra_blocksize     = 1,
767         .base.cra_ctxsize       = sizeof(struct mac_tfm_ctx),
768         .base.cra_module        = THIS_MODULE,
769
770         .digestsize             = AES_BLOCK_SIZE,
771         .init                   = mac_init,
772         .update                 = mac_update,
773         .final                  = cbcmac_final,
774         .setkey                 = cbcmac_setkey,
775         .descsize               = sizeof(struct mac_desc_ctx),
776 } };
777
778 static struct simd_skcipher_alg *aes_simd_algs[ARRAY_SIZE(aes_algs)];
779
780 static void aes_exit(void)
781 {
782         int i;
783
784         for (i = 0; i < ARRAY_SIZE(aes_simd_algs); i++)
785                 if (aes_simd_algs[i])
786                         simd_skcipher_free(aes_simd_algs[i]);
787
788         crypto_unregister_shashes(mac_algs, ARRAY_SIZE(mac_algs));
789         crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
790 }
791
792 static int __init aes_init(void)
793 {
794         struct simd_skcipher_alg *simd;
795         const char *basename;
796         const char *algname;
797         const char *drvname;
798         int err;
799         int i;
800
801         err = crypto_register_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
802         if (err)
803                 return err;
804
805         err = crypto_register_shashes(mac_algs, ARRAY_SIZE(mac_algs));
806         if (err)
807                 goto unregister_ciphers;
808
809         for (i = 0; i < ARRAY_SIZE(aes_algs); i++) {
810                 if (!(aes_algs[i].base.cra_flags & CRYPTO_ALG_INTERNAL))
811                         continue;
812
813                 algname = aes_algs[i].base.cra_name + 2;
814                 drvname = aes_algs[i].base.cra_driver_name + 2;
815                 basename = aes_algs[i].base.cra_driver_name;
816                 simd = simd_skcipher_create_compat(algname, drvname, basename);
817                 err = PTR_ERR(simd);
818                 if (IS_ERR(simd))
819                         goto unregister_simds;
820
821                 aes_simd_algs[i] = simd;
822         }
823
824         return 0;
825
826 unregister_simds:
827         aes_exit();
828         return err;
829 unregister_ciphers:
830         crypto_unregister_skciphers(aes_algs, ARRAY_SIZE(aes_algs));
831         return err;
832 }
833
834 #ifdef USE_V8_CRYPTO_EXTENSIONS
835 module_cpu_feature_match(AES, aes_init);
836 #else
837 module_init(aes_init);
838 EXPORT_SYMBOL(neon_aes_ecb_encrypt);
839 EXPORT_SYMBOL(neon_aes_cbc_encrypt);
840 #endif
841 module_exit(aes_exit);