diff --git a/bit-wrap/src/codegen.rs b/bit-wrap/src/codegen.rs index b2b136a..ae7e700 100644 --- a/bit-wrap/src/codegen.rs +++ b/bit-wrap/src/codegen.rs @@ -175,7 +175,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { quote! { /// Create a new [`Self`] from a key value #[inline] - #[macro_use] + #[must_use] pub fn new(key: #inner) -> Self { key.into() } @@ -184,7 +184,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { quote! { /// Create a new [`Self`] from a key value #[inline] - #[macro_use] + #[must_use] pub fn new(key: #inner) -> Result { key.try_into() } @@ -205,14 +205,14 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { #new_method /// Convert to hex string (formatted for the bit width) - #[macro_use] + #[must_use] pub fn to_hex(self) -> String { let value = self.0 & Self::MAX; format!("{:0width$X}", value, width = #hex_width) } /// Convert to binary string (full bit width with leading zeros) - #[macro_use] + #[must_use] pub fn to_bin(self) -> String { let value = self.0 & Self::MAX; format!("{:0width$b}", value, width = #bin_width) @@ -220,27 +220,27 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { /// Check if all bits are set #[inline] - #[macro_use] + #[must_use] pub const fn all(self) -> bool { self.0 == Self::MAX } /// Check if any bit is set #[inline] - #[macro_use] + #[must_use] pub const fn any(self) -> bool { self.0 != 0 } /// Count the number of set bits #[inline] - #[macro_use] + #[must_use] pub const fn count_ones(self) -> u32 { self.0.count_ones() } /// Count the number of zero bits within the constrained bit width - #[macro_use] + #[must_use] pub const fn count_zeros(self) -> u32 { let value = self.0 as #inner & Self::MAX; let ones_count = self.count_ones(); @@ -248,7 +248,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { } /// Reverse the bit pattern within the constrained bit width - #[macro_use] + #[must_use] pub const fn reverse_bits(self) -> Self { let value = self.0 as #inner & Self::MAX; let reversed = value.reverse_bits(); @@ -260,7 +260,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { /// Rotate the bits left by `n` positions within the bit width - #[macro_use] + #[must_use] pub const fn rotate_left(self, n: u8) -> Self { let n = n % Self::BIT_WIDTH; if n == 0 { @@ -274,7 +274,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { } /// Rotate the bits right by `n` positions within the bit width - #[macro_use] + #[must_use] pub const fn rotate_right(self, n: u8) -> Self { let n = n % Self::BIT_WIDTH; if n == 0 { @@ -288,7 +288,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { } /// Find the number of leading zero bits within the bit width - #[macro_use] + #[must_use] pub fn leading_zeros(self) -> u32 { let value = self.0 as #inner & Self::MAX; let bit_width = Self::BIT_WIDTH as u32; @@ -301,7 +301,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { } /// Find the number of trailing zero bits within the bit width - #[macro_use] + #[must_use] pub fn trailing_zeros(self) -> u32 { let value = self.0 as #inner & Self::MAX; let reversed = self.reverse_bits(); @@ -309,7 +309,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { } /// Find the number of leading one bits from the MSB within the bit width - #[macro_use] + #[must_use] pub fn leading_ones(self) -> u32 { let value = self.0 as #inner & Self::MAX; let not_value = !value & Self::MAX; @@ -323,7 +323,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { } /// Find the number of trailing one bits from the LSB within the bit width - #[macro_use] + #[must_use] pub fn trailing_ones(self) -> u32 { let value = self.0 as #inner & Self::MAX; let reversed = self.reverse_bits(); @@ -331,13 +331,13 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { } /// Check if the value is zero within the bit width - #[macro_use] + #[must_use] pub const fn is_zero(self) -> bool { (self.0 & Self::MAX) == 0 } /// Set a specific bit to 1 (clamped to bit width) - #[macro_use] + #[must_use] pub const fn bit_is_set(self, bit: u8) -> bool { let clamped = self.clamp_bit_positions(bit); let mask = 1<< clamped; @@ -364,7 +364,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { } /// Check if bits in the given mask are set (within bit width) - #[macro_use] + #[must_use] pub const fn contains_mask(self, mask: Self) -> bool { let self_value = self.0 & Self::MAX; let mask_value = mask.0 & Self::MAX; @@ -384,7 +384,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { } /// Extract a bitfield from the constrained range - #[macro_use] + #[must_use] pub const fn bitfield(self, low: u8, high: u8) -> u64 { let low = self.clamp_bit_positions(low); let high = self.clamp_bit_positions(high); @@ -408,7 +408,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream { self.0 = (self.0 & !mask) | ((value << low) & mask); } - #[macro_use] + #[must_use] const fn clamp_bit_positions(self, bit: u8) -> u8 { if bit >= Self::BIT_WIDTH { return Self::BIT_WIDTH - 1; diff --git a/des-lib/src/blocks/mod.rs b/des-lib/src/blocks/mod.rs index e69de29..916ac45 100644 --- a/des-lib/src/blocks/mod.rs +++ b/des-lib/src/blocks/mod.rs @@ -0,0 +1,5 @@ +use bit_wrap::BitWrapper; + +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, BitWrapper)] +#[bit_width(64)] +pub struct Block(u64); diff --git a/des-lib/src/error.rs b/des-lib/src/error.rs index e69de29..7a5c58d 100644 --- a/des-lib/src/error.rs +++ b/des-lib/src/error.rs @@ -0,0 +1,61 @@ +use thiserror::Error; + +use crate::keys::{key::KeyError, subkey::SubkeyError}; + +#[derive(Debug, Error)] +pub enum DesError { + /// Key value exceeds the maximum allowed value for the bit width + #[error("Key value {value} exceeds maximum {max} for {width}-bit type")] + KeyOutOfRange { + value: u64, // Raw value that was too large + max: u64, // Maximum allowed value (2^bit_width - 1) + width: u8, // Bit width of the key type + }, + + /// Failed to parse a hex or binary string representation + #[error("Failed to parse key string: {0}")] + ParseError(#[from] std::num::ParseIntError), + + /// Failed to parse from a string with invalid format + #[error("Invalid key format: {0}")] + InvalidFormat(String), + + /// Bitfield operation with invalid range (high < low) + #[error("Invalid bitfield range: low={low}, high={high} (must have low <= high)")] + InvalidBitfieldRange { low: u8, high: u8 }, + + /// Attempted to set a bit beyond the valid bit width + #[error("Bit index {bit} out of range for {width}-bit type")] + InvalidBitIndex { bit: u8, width: u8 }, + + #[error("Unknown error: {0}")] + Unknown(String), +} + +impl DesError { + pub fn unknown(input: impl Into) -> Self { + Self::Unknown(input.into()) + } +} + +macro_rules! impl_from_key_error_for_des { + ($error_type:ty) => { + impl From<$error_type> for DesError { + fn from(error: $error_type) -> Self { + type Input = $error_type; + match error { + Input::ValueOutOfRange { value, max, width } => Self::KeyOutOfRange { + value: value as u64, + max: max as u64, + width, + }, + Input::ParseError(err) => Self::ParseError(err), + Input::Unknown(msg) => Self::Unknown(msg), + } + } + } + }; +} + +impl_from_key_error_for_des!(SubkeyError); +impl_from_key_error_for_des!(KeyError); diff --git a/des-lib/src/keys/mod.rs b/des-lib/src/keys/mod.rs index 2d7bd0c..43ab313 100644 --- a/des-lib/src/keys/mod.rs +++ b/des-lib/src/keys/mod.rs @@ -1,5 +1,2 @@ -mod key; -mod subkey; - -pub use key::Key; -pub use subkey::Subkey; +pub mod key; +pub mod subkey; diff --git a/des-lib/src/keys/subkey.rs b/des-lib/src/keys/subkey.rs index 0198fab..00e2776 100644 --- a/des-lib/src/keys/subkey.rs +++ b/des-lib/src/keys/subkey.rs @@ -1,5 +1,47 @@ +use crate::keys::key::Key; use bit_wrap::BitWrapper; #[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, BitWrapper)] #[bit_width(48)] pub struct Subkey(u64); + +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)] +pub struct Subkeys([Subkey; 16]); + +impl From for Subkey { + fn from(_key: Key) -> Self { + todo!("when other functions are moved to type methods, imlmenet this"); + } +} + +impl TryFrom<[u64; 16]> for Subkeys { + type Error = SubkeyError; + fn try_from(keys: [u64; 16]) -> Result { + let mut subkeys = [Subkey::default(); 16]; + + for (idx, &key) in keys.iter().enumerate() { + let subkey = Subkey::try_from(key)?; + subkeys[idx] = subkey; + } + + Ok(Subkeys(subkeys)) + } +} + +impl From for [Subkey; 16] { + fn from(subkeys: Subkeys) -> Self { + subkeys.0 + } +} + +impl AsRef<[Subkey; 16]> for Subkeys { + fn as_ref(&self) -> &[Subkey; 16] { + &self.0 + } +} + +impl PartialEq for Subkey { + fn eq(&self, other: &u64) -> bool { + &self.0 == other + } +} diff --git a/des-lib/src/lib.rs b/des-lib/src/lib.rs index c9e445b..74f59ea 100644 --- a/des-lib/src/lib.rs +++ b/des-lib/src/lib.rs @@ -1,39 +1,43 @@ -mod blocks; -mod constants; -mod error; -mod keys; +pub mod blocks; +pub mod constants; +pub mod error; +pub mod keys; -use crate::constants::{E_BOX, FP, IP, P_BOX, PC1_TABLE, PC2_TABLE, ROUND_ROTATIONS, S_BOXES}; -pub use keys::{Key, Subkey}; +use crate::{ + blocks::Block, + constants::{E_BOX, FP, IP, P_BOX, PC1_TABLE, PC2_TABLE, ROUND_ROTATIONS, S_BOXES}, + error::DesError, + keys::{key::Key, subkey::Subkey}, +}; #[derive(Debug)] pub struct Des { - pub subkeys: [u64; 16], + pub subkeys: [Subkey; 16], } impl Des { /// Create a new DES instance from a 64-bit key (8 bytes). #[must_use] - pub fn new(key: u64) -> Self { - let subkeys = generate_subkeys(key); + pub fn new(key: impl Into) -> Self { + let subkeys = generate_subkeys(key.into()); Self { subkeys } } /// Encrypt a 64-bit block. #[must_use] - pub fn encrypt(&self, block: u64) -> u64 { - self.des(block, true) + pub fn encrypt(&self, block: impl Into) -> u64 { + self.des(block.into(), true) } /// Decrypt a 64-bit block. #[must_use] - pub fn decrypt(&self, block: u64) -> u64 { - self.des(block, false) + pub fn decrypt(&self, block: impl Into) -> u64 { + self.des(block.into(), false) } /// Core DES function: encrypt if forward=true, else decrypt. #[must_use] - fn des(&self, block: u64, forward: bool) -> u64 { + fn des(&self, block: Block, forward: bool) -> u64 { let permutated_block = ip(block); let (left, right) = if forward { @@ -54,16 +58,18 @@ impl Des { /// Accounts for DES specification's big-endian bit numbering (1-64, MSB first) /// versus Rust u64's little-endian bit numbering (0-63, LSB first). #[must_use] -pub fn pc1(key: u64) -> u64 { - permutate(key, 64, 56, &PC1_TABLE) +pub fn pc1(key: impl Into) -> u64 { + permutate(key.into(), 64, 56, &PC1_TABLE) } /// Compression permuation /// Reduces 56-bits to 48-bit key #[must_use] -pub fn pc2(key: u64) -> u64 { +pub fn pc2(key: u64) -> Result { let key_56 = key & 0x00FF_FFFF_FFFF_FFFF; permutate(key_56, 56, 48, &PC2_TABLE) + .try_into() + .map_err(DesError::from) } #[must_use] @@ -108,7 +114,7 @@ fn concatenate_halves(left: u32, right: u32, bit_offset: u8) -> u64 { } /// Generate 16 subkeys from the 64-bit key. -fn generate_subkeys(key: u64) -> [u64; 16] { +fn generate_subkeys(key: Key) -> [Subkey; 16] { let reduced_key = pc1(key); // C_0, D_0 let (mut left, mut right) = split_block(reduced_key); @@ -120,7 +126,8 @@ fn generate_subkeys(key: u64) -> [u64; 16] { let combined = concatenate_halves(left, right, 28); pc2(combined) }) - .collect::>() + .collect::, DesError>>() + .expect("no errors") .try_into() .expect("Exactly 16 subkeys expected") } @@ -133,7 +140,13 @@ fn generate_subkeys(key: u64) -> [u64; 16] { /// - `output_bits` - Number of bits in the output (1-64) /// - `position_table` - 1-based positions (1 to `input_bits`) where each output bit comes from #[must_use] -fn permutate(input: u64, input_bits: u32, output_bits: u32, position_table: &[u8]) -> u64 { +fn permutate( + input: impl Into, + input_bits: u32, + output_bits: u32, + position_table: &[u8], +) -> u64 { + let input = input.into(); position_table .iter() .enumerate() @@ -159,8 +172,8 @@ fn permutate(input: u64, input_bits: u32, output_bits: u32, position_table: &[u8 #[inline] #[must_use] -fn ip(message: u64) -> u64 { - permutate(message, 64, 64, &IP) +fn ip(message: impl Into) -> u64 { + permutate(message.into(), 64, 64, &IP) } /// Expand the right side of the data from 32 bits to 48. @@ -202,7 +215,7 @@ pub fn fp(block: u64) -> u64 { /// Process 16 Feistel rounds for ECB encryption/decryption. #[must_use] -fn process_feistel_rounds(initial_block: u64, subkeys: &[u64]) -> (u32, u32) { +fn process_feistel_rounds(initial_block: u64, subkeys: &[Subkey]) -> (u32, u32) { let (mut left, mut right) = split_block(initial_block); for &subkey in subkeys { (left, right) = feistel(left, right, subkey); @@ -213,7 +226,7 @@ fn process_feistel_rounds(initial_block: u64, subkeys: &[u64]) -> (u32, u32) { /// Feistel function: Expand, XOR with subkey, S-box, permute. /// `R_i` = `L_(i-1)` XOR f(`R_(i-1)`, `K_1`) #[must_use] -fn feistel(left: u32, right: u32, subkey: u64) -> (u32, u32) { +fn feistel(left: u32, right: u32, subkey: Subkey) -> (u32, u32) { let function_output = f_function(right, subkey); let new_right = left ^ function_output; // L_i = R_(i-1) @@ -222,9 +235,9 @@ fn feistel(left: u32, right: u32, subkey: u64) -> (u32, u32) { } #[must_use] -fn f_function(right: u32, subkey: u64) -> u32 { +fn f_function(right: u32, subkey: Subkey) -> u32 { let expanded = expansion_permutation(right); - let xored = expanded ^ subkey; + let xored = expanded ^ subkey.as_ref(); let sboxed = s_box_substitution(xored); p_box_permutation(sboxed) } @@ -232,7 +245,7 @@ fn f_function(right: u32, subkey: u64) -> u32 { #[cfg(test)] mod tests { use super::*; - use claims::assert_ge; + use claims::{assert_ge, assert_ok}; use rstest::rstest; const TEST_KEY: u64 = 0x1334_5779_9BBC_DFF1; @@ -287,14 +300,9 @@ mod tests { #[case(0x00F8_6655_7AAB_33C7, 0xBF91_8D3D_3F0A)] // K_15 #[case(0x00F0_CCAA_F556_678F, 0xCB3D_8B0E_17F5)] // K_16 fn pc2_permutaion_correct(#[case] before: u64, #[case] after: u64) { - let result = pc2(before); + let result = assert_ok!(pc2(before)); - assert_eq!(result, after, "PC2 permutation failed"); - assert_ge!( - result.leading_zeros(), - 16, - "PC2 result should have leading 16 bits as 0" - ); + assert_eq!(*result.as_ref(), after, "PC2 permutation failed"); } #[test] diff --git a/des-lib/tests/des.rs b/des-lib/tests/des.rs index e9fb38a..1fe90bc 100644 --- a/des-lib/tests/des.rs +++ b/des-lib/tests/des.rs @@ -1,7 +1,6 @@ +use des_lib::Des; use rstest::rstest; -use des::Des; - const TEST_KEY: u64 = 0x1334_5779_9BBC_DFF1; const TEST_PLAINTEXT: u64 = 0x0123_4567_89AB_CDEF; const TEST_CIPHERTEXT: u64 = 0x85E8_1354_0F0A_B405; @@ -14,9 +13,9 @@ fn des_instance() -> Des { #[test] fn test_ecb_mode_equivalence() { // If you implement ECB mode, test it matches single block - let key = 0x1334_5779_9BBC_DFF1; + let key = 0x1334_5779_9BBC_DFF1u64; let des = Des::new(key); - let plain = 0x0123_4567_89AB_CDEF; + let plain = 0x0123_4567_89AB_CDEFu64; let _single_block = des.encrypt(plain); // let ecb_result = encrypt_ecb(&[plain]); @@ -52,7 +51,11 @@ fn encrypt_decrypt_roundtrip( #[test] fn weak_keys_rejected() { - let weak_keys = [0x0101010101010101, 0xFEFEFEFEFEFEFEFE, 0xE001E001E001E001]; + let weak_keys = [ + 0x0101010101010101u64, + 0xFEFEFEFEFEFEFEFE, + 0xE001E001E001E001, + ]; for key in weak_keys { let des = Des::new(key); @@ -90,8 +93,8 @@ fn all_one_paintext() { fn different_inputs() { let des = des_instance(); - let plain1 = 1; - let plain2 = 2; + let plain1 = 1u64; + let plain2 = 2u64; let enc1 = des.encrypt(plain1); let enc2 = des.encrypt(plain2); assert_ne!( diff --git a/des-lib/tests/key_schedule.rs b/des-lib/tests/key_schedule.rs index 08daa64..04ba2b6 100644 --- a/des-lib/tests/key_schedule.rs +++ b/des-lib/tests/key_schedule.rs @@ -1,4 +1,4 @@ -use des::Des; +use des_lib::Des; // Full expected subkeys for TEST_KEY (48 bits each, from FIPS spec) const EXPECTED_SUBKEYS: [u64; 16] = [