From 70a0e183b5b548926f4b7a99fd67a50bacb3233d Mon Sep 17 00:00:00 2001 From: Kristofers Solo Date: Sun, 23 Nov 2025 18:24:12 +0200 Subject: [PATCH] feat(aes): Implement ShiftRows and MixColumns transformations - Implements `shift_rows` performing cyclic shifts on the state matrix rows. - Implements `mix_columns` using Galois Field matrix multiplication. - Adds `gmul` and `xtime` const helpers for GF(2^8) arithmetic. - Adds unit tests verifying transformations against FIPS-197 vectors. --- aes/src/block/block128.rs | 135 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 135 insertions(+) diff --git a/aes/src/block/block128.rs b/aes/src/block/block128.rs index 33474a8..9e739d4 100644 --- a/aes/src/block/block128.rs +++ b/aes/src/block/block128.rs @@ -62,6 +62,91 @@ impl Block128 { Block32::from_u32(val as u32), ] } + #[must_use] + pub const fn shift_rows(self) -> Self { + let b = self.to_be_bytes(); + let mut out = [0u8; 16]; + + // Row 0: No shift (Indices 0, 4, 8, 12) + out[0] = b[0]; + out[4] = b[4]; + out[8] = b[8]; + out[12] = b[12]; + + // Row 1: Shift left 1 (Indices 1, 5, 9, 13 -> 5, 9, 13, 1) + out[1] = b[5]; + out[5] = b[9]; + out[9] = b[13]; + out[13] = b[1]; + + // Row 2: Shift left 2 (Indices 2, 6, 10, 14 -> 10, 14, 2, 6) + out[2] = b[10]; + out[6] = b[14]; + out[10] = b[2]; + out[14] = b[6]; + + // Row 3: Shift left 3 (Indices 3, 7, 11, 15 -> 15, 3, 7, 11) + out[3] = b[15]; + out[7] = b[3]; + out[11] = b[7]; + out[15] = b[11]; + + Self::from_be_bytes(out) + } + + #[must_use] + pub fn mix_columns(self) -> Self { + let mut bytes = self.to_be_bytes(); + + for col in 0..4 { + let offset = col * 4; + + let c0 = bytes[offset]; + let c1 = bytes[offset + 1]; + let c2 = bytes[offset + 2]; + let c3 = bytes[offset + 3]; + + // Matrix multiplication over GF(2^8): + // [d0] [2 3 1 1] [c0] + // [d1] = [1 2 3 1] [c1] + // [d2] [1 1 2 3] [c2] + // [d3] [3 1 1 2] [c3] + + bytes[offset] = gmul(c0, 2) ^ gmul(c1, 3) ^ c2 ^ c3; + bytes[offset + 1] = c0 ^ gmul(c1, 2) ^ gmul(c2, 3) ^ c3; + bytes[offset + 2] = c0 ^ c1 ^ gmul(c2, 2) ^ gmul(c3, 3); + bytes[offset + 3] = gmul(c0, 3) ^ c1 ^ c2 ^ gmul(c3, 2); + } + + Self::from_be_bytes(bytes) + } +} + +/// Galois Field multiplication by 2 (xtime). +/// If the high bit is set, XOR with the irreducible polynomial 0x1B. +const fn xtime(x: u8) -> u8 { + if x & 0x80 != 0 { + return (x << 1) ^ 0x1b; + } + x << 1 +} + +/// General Galois Field multiplication. +/// Implemented using "peasant's algorithm" (shift and add). +const fn gmul(mut a: u8, mut b: u8) -> u8 { + let mut p = 0; + let mut i = 0; + + // Unrolled loop for const context + while i < 8 { + if (b & 1) != 0 { + p ^= a; + } + a = xtime(a); + b >>= 1; + i += 1; + } + p } impl FromStr for Block128 { @@ -147,3 +232,53 @@ impl From<&Block128> for Vec { value.to_be_bytes().to_vec() } } + +#[cfg(test)] +mod tests { + use super::*; + use rstest::rstest; + + #[rstest] + #[case( + 0x63CA_B704_0953_D051_CD60_E0E7_BA70_E18C, + 0x6353_E08C_0960_E104_CD70_B751_BACA_D0E7 + )] + fn shift_rows(#[case] input: u128, #[case] expected: u128) { + let block = Block128::new(input); + let result = block.shift_rows().as_u128(); + assert_eq!( + result, expected, + "Shift Rows failed. Expected 0x{expected:032X}, got 0x{result:032X}", + ); + } + + #[rstest] + #[case( + 0x6353_E08C_0960_E104_CD70_B751_BACA_D0E7, + 0x5F72_6415_57F5_BC92_F7BE_3B29_1DB9_F91A + )] + #[case( + 0xD4BF_5D30_D4BF_5D30_D4BF_5D30_D4BF_5D30, + 0x0466_81E5_0466_81E5_0466_81E5_0466_81E5 + )] + fn mix_columns(#[case] input: u128, #[case] expected: u128) { + let block = Block128::new(input); + let result = block.mix_columns().as_u128(); + assert_eq!( + result, expected, + "Mix Columns failed. Expected 0x{expected:032X}, got 0x{result:032X}", + ); + } + + #[rstest] + #[case(0x57, 0x13, 0xFE)] // Example from FIPS-197 4.2.1 + #[case(0x57, 0x01, 0x57)] // Identity + #[case(0x57, 0x02, 0xAE)] // x2 (xtime) + #[case(0x57, 0x04, 0x47)] // x4 + #[case(0x57, 0x08, 0x8E)] // x8 + #[case(0x57, 0x10, 0x07)] // x16 + fn galois_multiplication(#[case] a: u8, #[case] b: u8, #[case] expected: u8) { + let res = gmul(a, b); + assert_eq!(res, expected, "gmul({a:02x}, {b:02x}) failed"); + } +}