colibri_imx6: fix video stdout in default environment
[oweals/u-boot.git] / lib / aes.c
index 05c97cd740982f38a57a069507494372adf28e96..c998aecb3c74503a8fb70c50f4234f9d57621b3e 100644 (file)
--- a/lib/aes.c
+++ b/lib/aes.c
@@ -1,8 +1,7 @@
+// SPDX-License-Identifier: GPL-2.0+
 /*
  * Copyright (c) 2011 The Chromium OS Authors.
  * (C) Copyright 2011 NVIDIA Corporation www.nvidia.com
- *
- * SPDX-License-Identifier:    GPL-2.0+
  */
 
 /*
 
 #ifndef USE_HOSTCC
 #include <common.h>
+#include <log.h>
 #else
 #include <string.h>
 #endif
-#include "aes.h"
+#include "uboot_aes.h"
 
 /* forward s-box */
 static const u8 sbox[256] = {
@@ -509,50 +509,79 @@ static u8 rcon[11] = {
        0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80, 0x1b, 0x36
 };
 
+static u32 aes_get_rounds(u32 key_len)
+{
+       u32 rounds = AES128_ROUNDS;
+
+       if (key_len == AES192_KEY_LENGTH)
+               rounds = AES192_ROUNDS;
+       else if (key_len == AES256_KEY_LENGTH)
+               rounds = AES256_ROUNDS;
+
+       return rounds;
+}
+
+static u32 aes_get_keycols(u32 key_len)
+{
+       u32 keycols = AES128_KEYCOLS;
+
+       if (key_len == AES192_KEY_LENGTH)
+               keycols = AES192_KEYCOLS;
+       else if (key_len == AES256_KEY_LENGTH)
+               keycols = AES256_KEYCOLS;
+
+       return keycols;
+}
+
 /* produce AES_STATECOLS bytes for each round */
-void aes_expand_key(u8 *key, u8 *expkey)
+void aes_expand_key(u8 *key, u32 key_len, u8 *expkey)
 {
        u8 tmp0, tmp1, tmp2, tmp3, tmp4;
-       u32 idx;
+       u32 idx, aes_rounds, aes_keycols;
+
+       aes_rounds = aes_get_rounds(key_len);
+       aes_keycols = aes_get_keycols(key_len);
 
-       memcpy(expkey, key, AES_KEYCOLS * 4);
+       memcpy(expkey, key, key_len);
 
-       for (idx = AES_KEYCOLS; idx < AES_STATECOLS * (AES_ROUNDS + 1); idx++) {
+       for (idx = aes_keycols; idx < AES_STATECOLS * (aes_rounds + 1); idx++) {
                tmp0 = expkey[4*idx - 4];
                tmp1 = expkey[4*idx - 3];
                tmp2 = expkey[4*idx - 2];
                tmp3 = expkey[4*idx - 1];
-               if (!(idx % AES_KEYCOLS)) {
+               if (!(idx % aes_keycols)) {
                        tmp4 = tmp3;
                        tmp3 = sbox[tmp0];
-                       tmp0 = sbox[tmp1] ^ rcon[idx / AES_KEYCOLS];
+                       tmp0 = sbox[tmp1] ^ rcon[idx / aes_keycols];
                        tmp1 = sbox[tmp2];
                        tmp2 = sbox[tmp4];
-               } else if ((AES_KEYCOLS > 6) && (idx % AES_KEYCOLS == 4)) {
+               } else if ((aes_keycols > 6) && (idx % aes_keycols == 4)) {
                        tmp0 = sbox[tmp0];
                        tmp1 = sbox[tmp1];
                        tmp2 = sbox[tmp2];
                        tmp3 = sbox[tmp3];
                }
 
-               expkey[4*idx+0] = expkey[4*idx - 4*AES_KEYCOLS + 0] ^ tmp0;
-               expkey[4*idx+1] = expkey[4*idx - 4*AES_KEYCOLS + 1] ^ tmp1;
-               expkey[4*idx+2] = expkey[4*idx - 4*AES_KEYCOLS + 2] ^ tmp2;
-               expkey[4*idx+3] = expkey[4*idx - 4*AES_KEYCOLS + 3] ^ tmp3;
+               expkey[4*idx+0] = expkey[4*idx - 4*aes_keycols + 0] ^ tmp0;
+               expkey[4*idx+1] = expkey[4*idx - 4*aes_keycols + 1] ^ tmp1;
+               expkey[4*idx+2] = expkey[4*idx - 4*aes_keycols + 2] ^ tmp2;
+               expkey[4*idx+3] = expkey[4*idx - 4*aes_keycols + 3] ^ tmp3;
        }
 }
 
 /* encrypt one 128 bit block */
-void aes_encrypt(u8 *in, u8 *expkey, u8 *out)
+void aes_encrypt(u32 key_len, u8 *in, u8 *expkey, u8 *out)
 {
        u8 state[AES_STATECOLS * 4];
-       u32 round;
+       u32 round, aes_rounds;
+
+       aes_rounds = aes_get_rounds(key_len);
 
        memcpy(state, in, AES_STATECOLS * 4);
        add_round_key((u32 *)state, (u32 *)expkey);
 
-       for (round = 1; round < AES_ROUNDS + 1; round++) {
-               if (round < AES_ROUNDS)
+       for (round = 1; round < aes_rounds + 1; round++) {
+               if (round < aes_rounds)
                        mix_sub_columns(state);
                else
                        shift_rows(state);
@@ -564,18 +593,20 @@ void aes_encrypt(u8 *in, u8 *expkey, u8 *out)
        memcpy(out, state, sizeof(state));
 }
 
-void aes_decrypt(u8 *in, u8 *expkey, u8 *out)
+void aes_decrypt(u32 key_len, u8 *in, u8 *expkey, u8 *out)
 {
        u8 state[AES_STATECOLS * 4];
-       int round;
+       int round, aes_rounds;
+
+       aes_rounds = aes_get_rounds(key_len);
 
        memcpy(state, in, sizeof(state));
 
        add_round_key((u32 *)state,
-                     (u32 *)expkey + AES_ROUNDS * AES_STATECOLS);
+                     (u32 *)expkey + aes_rounds * AES_STATECOLS);
        inv_shift_rows(state);
 
-       for (round = AES_ROUNDS; round--; ) {
+       for (round = aes_rounds; round--; ) {
                add_round_key((u32 *)state,
                              (u32 *)expkey + round * AES_STATECOLS);
                if (round)
@@ -593,74 +624,66 @@ static void debug_print_vector(char *name, u32 num_bytes, u8 *data)
 #endif
 }
 
-/**
- * Apply chain data to the destination using EOR
- *
- * Each array is of length AES_KEY_LENGTH.
- *
- * @cbc_chain_data     Chain data
- * @src                        Source data
- * @dst                        Destination data, which is modified here
- */
-static void apply_cbc_chain_data(u8 *cbc_chain_data, u8 *src, u8 *dst)
+void aes_apply_cbc_chain_data(u8 *cbc_chain_data, u8 *src, u8 *dst)
 {
        int i;
 
-       for (i = 0; i < AES_KEY_LENGTH; i++)
+       for (i = 0; i < AES_BLOCK_LENGTH; i++)
                *dst++ = *src++ ^ *cbc_chain_data++;
 }
 
-void aes_cbc_encrypt_blocks(u8 *key_exp, u8 *src, u8 *dst, u32 num_aes_blocks)
+void aes_cbc_encrypt_blocks(u32 key_len, u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
+                           u32 num_aes_blocks)
 {
-       u8 zero_key[AES_KEY_LENGTH] = { 0 };
-       u8 tmp_data[AES_KEY_LENGTH];
-       /* Convenient array of 0's for IV */
-       u8 *cbc_chain_data = zero_key;
+       u8 tmp_data[AES_BLOCK_LENGTH];
+       u8 *cbc_chain_data = iv;
        u32 i;
 
        for (i = 0; i < num_aes_blocks; i++) {
                debug("encrypt_object: block %d of %d\n", i, num_aes_blocks);
-               debug_print_vector("AES Src", AES_KEY_LENGTH, src);
+               debug_print_vector("AES Src", AES_BLOCK_LENGTH, src);
 
                /* Apply the chain data */
-               apply_cbc_chain_data(cbc_chain_data, src, tmp_data);
-               debug_print_vector("AES Xor", AES_KEY_LENGTH, tmp_data);
+               aes_apply_cbc_chain_data(cbc_chain_data, src, tmp_data);
+               debug_print_vector("AES Xor", AES_BLOCK_LENGTH, tmp_data);
 
                /* Encrypt the AES block */
-               aes_encrypt(tmp_data, key_exp, dst);
-               debug_print_vector("AES Dst", AES_KEY_LENGTH, dst);
+               aes_encrypt(key_len, tmp_data, key_exp, dst);
+               debug_print_vector("AES Dst", AES_BLOCK_LENGTH, dst);
 
                /* Update pointers for next loop. */
                cbc_chain_data = dst;
-               src += AES_KEY_LENGTH;
-               dst += AES_KEY_LENGTH;
+               src += AES_BLOCK_LENGTH;
+               dst += AES_BLOCK_LENGTH;
        }
 }
 
-void aes_cbc_decrypt_blocks(u8 *key_exp, u8 *src, u8 *dst, u32 num_aes_blocks)
+void aes_cbc_decrypt_blocks(u32 key_len, u8 *key_exp, u8 *iv, u8 *src, u8 *dst,
+                           u32 num_aes_blocks)
 {
-       u8 tmp_data[AES_KEY_LENGTH], tmp_block[AES_KEY_LENGTH];
+       u8 tmp_data[AES_BLOCK_LENGTH], tmp_block[AES_BLOCK_LENGTH];
        /* Convenient array of 0's for IV */
-       u8 cbc_chain_data[AES_KEY_LENGTH] = { 0 };
+       u8 cbc_chain_data[AES_BLOCK_LENGTH];
        u32 i;
 
+       memcpy(cbc_chain_data, iv, AES_BLOCK_LENGTH);
        for (i = 0; i < num_aes_blocks; i++) {
                debug("encrypt_object: block %d of %d\n", i, num_aes_blocks);
-               debug_print_vector("AES Src", AES_KEY_LENGTH, src);
+               debug_print_vector("AES Src", AES_BLOCK_LENGTH, src);
 
-               memcpy(tmp_block, src, AES_KEY_LENGTH);
+               memcpy(tmp_block, src, AES_BLOCK_LENGTH);
 
                /* Decrypt the AES block */
-               aes_decrypt(src, key_exp, tmp_data);
-               debug_print_vector("AES Xor", AES_KEY_LENGTH, tmp_data);
+               aes_decrypt(key_len, src, key_exp, tmp_data);
+               debug_print_vector("AES Xor", AES_BLOCK_LENGTH, tmp_data);
 
                /* Apply the chain data */
-               apply_cbc_chain_data(cbc_chain_data, tmp_data, dst);
-               debug_print_vector("AES Dst", AES_KEY_LENGTH, dst);
+               aes_apply_cbc_chain_data(cbc_chain_data, tmp_data, dst);
+               debug_print_vector("AES Dst", AES_BLOCK_LENGTH, dst);
 
                /* Update pointers for next loop. */
-               memcpy(cbc_chain_data, tmp_block, AES_KEY_LENGTH);
-               src += AES_KEY_LENGTH;
-               dst += AES_KEY_LENGTH;
+               memcpy(cbc_chain_data, tmp_block, AES_BLOCK_LENGTH);
+               src += AES_BLOCK_LENGTH;
+               dst += AES_BLOCK_LENGTH;
        }
 }