feat: impl Subkey and Key types

This commit is contained in:
Kristofers Solo 2025-10-04 19:55:35 +03:00
parent ab1c0e4ad9
commit f0b9acbc9d
Signed by: kristoferssolo
GPG Key ID: 8687F2D3EEE6F0ED
8 changed files with 183 additions and 67 deletions

View File

@ -175,7 +175,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
quote! {
/// Create a new [`Self`] from a key value
#[inline]
#[macro_use]
#[must_use]
pub fn new(key: #inner) -> Self {
key.into()
}
@ -184,7 +184,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
quote! {
/// Create a new [`Self`] from a key value
#[inline]
#[macro_use]
#[must_use]
pub fn new(key: #inner) -> Result<Self, #error_type> {
key.try_into()
}
@ -205,14 +205,14 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
#new_method
/// Convert to hex string (formatted for the bit width)
#[macro_use]
#[must_use]
pub fn to_hex(self) -> String {
let value = self.0 & Self::MAX;
format!("{:0width$X}", value, width = #hex_width)
}
/// Convert to binary string (full bit width with leading zeros)
#[macro_use]
#[must_use]
pub fn to_bin(self) -> String {
let value = self.0 & Self::MAX;
format!("{:0width$b}", value, width = #bin_width)
@ -220,27 +220,27 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
/// Check if all bits are set
#[inline]
#[macro_use]
#[must_use]
pub const fn all(self) -> bool {
self.0 == Self::MAX
}
/// Check if any bit is set
#[inline]
#[macro_use]
#[must_use]
pub const fn any(self) -> bool {
self.0 != 0
}
/// Count the number of set bits
#[inline]
#[macro_use]
#[must_use]
pub const fn count_ones(self) -> u32 {
self.0.count_ones()
}
/// Count the number of zero bits within the constrained bit width
#[macro_use]
#[must_use]
pub const fn count_zeros(self) -> u32 {
let value = self.0 as #inner & Self::MAX;
let ones_count = self.count_ones();
@ -248,7 +248,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
}
/// Reverse the bit pattern within the constrained bit width
#[macro_use]
#[must_use]
pub const fn reverse_bits(self) -> Self {
let value = self.0 as #inner & Self::MAX;
let reversed = value.reverse_bits();
@ -260,7 +260,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
/// Rotate the bits left by `n` positions within the bit width
#[macro_use]
#[must_use]
pub const fn rotate_left(self, n: u8) -> Self {
let n = n % Self::BIT_WIDTH;
if n == 0 {
@ -274,7 +274,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
}
/// Rotate the bits right by `n` positions within the bit width
#[macro_use]
#[must_use]
pub const fn rotate_right(self, n: u8) -> Self {
let n = n % Self::BIT_WIDTH;
if n == 0 {
@ -288,7 +288,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
}
/// Find the number of leading zero bits within the bit width
#[macro_use]
#[must_use]
pub fn leading_zeros(self) -> u32 {
let value = self.0 as #inner & Self::MAX;
let bit_width = Self::BIT_WIDTH as u32;
@ -301,7 +301,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
}
/// Find the number of trailing zero bits within the bit width
#[macro_use]
#[must_use]
pub fn trailing_zeros(self) -> u32 {
let value = self.0 as #inner & Self::MAX;
let reversed = self.reverse_bits();
@ -309,7 +309,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
}
/// Find the number of leading one bits from the MSB within the bit width
#[macro_use]
#[must_use]
pub fn leading_ones(self) -> u32 {
let value = self.0 as #inner & Self::MAX;
let not_value = !value & Self::MAX;
@ -323,7 +323,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
}
/// Find the number of trailing one bits from the LSB within the bit width
#[macro_use]
#[must_use]
pub fn trailing_ones(self) -> u32 {
let value = self.0 as #inner & Self::MAX;
let reversed = self.reverse_bits();
@ -331,13 +331,13 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
}
/// Check if the value is zero within the bit width
#[macro_use]
#[must_use]
pub const fn is_zero(self) -> bool {
(self.0 & Self::MAX) == 0
}
/// Set a specific bit to 1 (clamped to bit width)
#[macro_use]
#[must_use]
pub const fn bit_is_set(self, bit: u8) -> bool {
let clamped = self.clamp_bit_positions(bit);
let mask = 1<< clamped;
@ -364,7 +364,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
}
/// Check if bits in the given mask are set (within bit width)
#[macro_use]
#[must_use]
pub const fn contains_mask(self, mask: Self) -> bool {
let self_value = self.0 & Self::MAX;
let mask_value = mask.0 & Self::MAX;
@ -384,7 +384,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
}
/// Extract a bitfield from the constrained range
#[macro_use]
#[must_use]
pub const fn bitfield(self, low: u8, high: u8) -> u64 {
let low = self.clamp_bit_positions(low);
let high = self.clamp_bit_positions(high);
@ -408,7 +408,7 @@ fn generate_bitwise_ops(info: &Struct) -> TokenStream {
self.0 = (self.0 & !mask) | ((value << low) & mask);
}
#[macro_use]
#[must_use]
const fn clamp_bit_positions(self, bit: u8) -> u8 {
if bit >= Self::BIT_WIDTH {
return Self::BIT_WIDTH - 1;

View File

@ -0,0 +1,5 @@
use bit_wrap::BitWrapper;
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, BitWrapper)]
#[bit_width(64)]
pub struct Block(u64);

View File

@ -0,0 +1,61 @@
use thiserror::Error;
use crate::keys::{key::KeyError, subkey::SubkeyError};
#[derive(Debug, Error)]
pub enum DesError {
/// Key value exceeds the maximum allowed value for the bit width
#[error("Key value {value} exceeds maximum {max} for {width}-bit type")]
KeyOutOfRange {
value: u64, // Raw value that was too large
max: u64, // Maximum allowed value (2^bit_width - 1)
width: u8, // Bit width of the key type
},
/// Failed to parse a hex or binary string representation
#[error("Failed to parse key string: {0}")]
ParseError(#[from] std::num::ParseIntError),
/// Failed to parse from a string with invalid format
#[error("Invalid key format: {0}")]
InvalidFormat(String),
/// Bitfield operation with invalid range (high < low)
#[error("Invalid bitfield range: low={low}, high={high} (must have low <= high)")]
InvalidBitfieldRange { low: u8, high: u8 },
/// Attempted to set a bit beyond the valid bit width
#[error("Bit index {bit} out of range for {width}-bit type")]
InvalidBitIndex { bit: u8, width: u8 },
#[error("Unknown error: {0}")]
Unknown(String),
}
impl DesError {
pub fn unknown(input: impl Into<String>) -> Self {
Self::Unknown(input.into())
}
}
macro_rules! impl_from_key_error_for_des {
($error_type:ty) => {
impl From<$error_type> for DesError {
fn from(error: $error_type) -> Self {
type Input = $error_type;
match error {
Input::ValueOutOfRange { value, max, width } => Self::KeyOutOfRange {
value: value as u64,
max: max as u64,
width,
},
Input::ParseError(err) => Self::ParseError(err),
Input::Unknown(msg) => Self::Unknown(msg),
}
}
}
};
}
impl_from_key_error_for_des!(SubkeyError);
impl_from_key_error_for_des!(KeyError);

View File

@ -1,5 +1,2 @@
mod key;
mod subkey;
pub use key::Key;
pub use subkey::Subkey;
pub mod key;
pub mod subkey;

View File

@ -1,5 +1,47 @@
use crate::keys::key::Key;
use bit_wrap::BitWrapper;
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, BitWrapper)]
#[bit_width(48)]
pub struct Subkey(u64);
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord)]
pub struct Subkeys([Subkey; 16]);
impl From<Key> for Subkey {
fn from(_key: Key) -> Self {
todo!("when other functions are moved to type methods, imlmenet this");
}
}
impl TryFrom<[u64; 16]> for Subkeys {
type Error = SubkeyError;
fn try_from(keys: [u64; 16]) -> Result<Self, Self::Error> {
let mut subkeys = [Subkey::default(); 16];
for (idx, &key) in keys.iter().enumerate() {
let subkey = Subkey::try_from(key)?;
subkeys[idx] = subkey;
}
Ok(Subkeys(subkeys))
}
}
impl From<Subkeys> for [Subkey; 16] {
fn from(subkeys: Subkeys) -> Self {
subkeys.0
}
}
impl AsRef<[Subkey; 16]> for Subkeys {
fn as_ref(&self) -> &[Subkey; 16] {
&self.0
}
}
impl PartialEq<u64> for Subkey {
fn eq(&self, other: &u64) -> bool {
&self.0 == other
}
}

View File

@ -1,39 +1,43 @@
mod blocks;
mod constants;
mod error;
mod keys;
pub mod blocks;
pub mod constants;
pub mod error;
pub mod keys;
use crate::constants::{E_BOX, FP, IP, P_BOX, PC1_TABLE, PC2_TABLE, ROUND_ROTATIONS, S_BOXES};
pub use keys::{Key, Subkey};
use crate::{
blocks::Block,
constants::{E_BOX, FP, IP, P_BOX, PC1_TABLE, PC2_TABLE, ROUND_ROTATIONS, S_BOXES},
error::DesError,
keys::{key::Key, subkey::Subkey},
};
#[derive(Debug)]
pub struct Des {
pub subkeys: [u64; 16],
pub subkeys: [Subkey; 16],
}
impl Des {
/// Create a new DES instance from a 64-bit key (8 bytes).
#[must_use]
pub fn new(key: u64) -> Self {
let subkeys = generate_subkeys(key);
pub fn new(key: impl Into<Key>) -> Self {
let subkeys = generate_subkeys(key.into());
Self { subkeys }
}
/// Encrypt a 64-bit block.
#[must_use]
pub fn encrypt(&self, block: u64) -> u64 {
self.des(block, true)
pub fn encrypt(&self, block: impl Into<Block>) -> u64 {
self.des(block.into(), true)
}
/// Decrypt a 64-bit block.
#[must_use]
pub fn decrypt(&self, block: u64) -> u64 {
self.des(block, false)
pub fn decrypt(&self, block: impl Into<Block>) -> u64 {
self.des(block.into(), false)
}
/// Core DES function: encrypt if forward=true, else decrypt.
#[must_use]
fn des(&self, block: u64, forward: bool) -> u64 {
fn des(&self, block: Block, forward: bool) -> u64 {
let permutated_block = ip(block);
let (left, right) = if forward {
@ -54,16 +58,18 @@ impl Des {
/// Accounts for DES specification's big-endian bit numbering (1-64, MSB first)
/// versus Rust u64's little-endian bit numbering (0-63, LSB first).
#[must_use]
pub fn pc1(key: u64) -> u64 {
permutate(key, 64, 56, &PC1_TABLE)
pub fn pc1(key: impl Into<Key>) -> u64 {
permutate(key.into(), 64, 56, &PC1_TABLE)
}
/// Compression permuation
/// Reduces 56-bits to 48-bit key
#[must_use]
pub fn pc2(key: u64) -> u64 {
pub fn pc2(key: u64) -> Result<Subkey, DesError> {
let key_56 = key & 0x00FF_FFFF_FFFF_FFFF;
permutate(key_56, 56, 48, &PC2_TABLE)
.try_into()
.map_err(DesError::from)
}
#[must_use]
@ -108,7 +114,7 @@ fn concatenate_halves(left: u32, right: u32, bit_offset: u8) -> u64 {
}
/// Generate 16 subkeys from the 64-bit key.
fn generate_subkeys(key: u64) -> [u64; 16] {
fn generate_subkeys(key: Key) -> [Subkey; 16] {
let reduced_key = pc1(key); // C_0, D_0
let (mut left, mut right) = split_block(reduced_key);
@ -120,7 +126,8 @@ fn generate_subkeys(key: u64) -> [u64; 16] {
let combined = concatenate_halves(left, right, 28);
pc2(combined)
})
.collect::<Vec<_>>()
.collect::<Result<Vec<_>, DesError>>()
.expect("no errors")
.try_into()
.expect("Exactly 16 subkeys expected")
}
@ -133,7 +140,13 @@ fn generate_subkeys(key: u64) -> [u64; 16] {
/// - `output_bits` - Number of bits in the output (1-64)
/// - `position_table` - 1-based positions (1 to `input_bits`) where each output bit comes from
#[must_use]
fn permutate(input: u64, input_bits: u32, output_bits: u32, position_table: &[u8]) -> u64 {
fn permutate(
input: impl Into<u64>,
input_bits: u32,
output_bits: u32,
position_table: &[u8],
) -> u64 {
let input = input.into();
position_table
.iter()
.enumerate()
@ -159,8 +172,8 @@ fn permutate(input: u64, input_bits: u32, output_bits: u32, position_table: &[u8
#[inline]
#[must_use]
fn ip(message: u64) -> u64 {
permutate(message, 64, 64, &IP)
fn ip(message: impl Into<Block>) -> u64 {
permutate(message.into(), 64, 64, &IP)
}
/// Expand the right side of the data from 32 bits to 48.
@ -202,7 +215,7 @@ pub fn fp(block: u64) -> u64 {
/// Process 16 Feistel rounds for ECB encryption/decryption.
#[must_use]
fn process_feistel_rounds(initial_block: u64, subkeys: &[u64]) -> (u32, u32) {
fn process_feistel_rounds(initial_block: u64, subkeys: &[Subkey]) -> (u32, u32) {
let (mut left, mut right) = split_block(initial_block);
for &subkey in subkeys {
(left, right) = feistel(left, right, subkey);
@ -213,7 +226,7 @@ fn process_feistel_rounds(initial_block: u64, subkeys: &[u64]) -> (u32, u32) {
/// Feistel function: Expand, XOR with subkey, S-box, permute.
/// `R_i` = `L_(i-1)` XOR f(`R_(i-1)`, `K_1`)
#[must_use]
fn feistel(left: u32, right: u32, subkey: u64) -> (u32, u32) {
fn feistel(left: u32, right: u32, subkey: Subkey) -> (u32, u32) {
let function_output = f_function(right, subkey);
let new_right = left ^ function_output;
// L_i = R_(i-1)
@ -222,9 +235,9 @@ fn feistel(left: u32, right: u32, subkey: u64) -> (u32, u32) {
}
#[must_use]
fn f_function(right: u32, subkey: u64) -> u32 {
fn f_function(right: u32, subkey: Subkey) -> u32 {
let expanded = expansion_permutation(right);
let xored = expanded ^ subkey;
let xored = expanded ^ subkey.as_ref();
let sboxed = s_box_substitution(xored);
p_box_permutation(sboxed)
}
@ -232,7 +245,7 @@ fn f_function(right: u32, subkey: u64) -> u32 {
#[cfg(test)]
mod tests {
use super::*;
use claims::assert_ge;
use claims::{assert_ge, assert_ok};
use rstest::rstest;
const TEST_KEY: u64 = 0x1334_5779_9BBC_DFF1;
@ -287,14 +300,9 @@ mod tests {
#[case(0x00F8_6655_7AAB_33C7, 0xBF91_8D3D_3F0A)] // K_15
#[case(0x00F0_CCAA_F556_678F, 0xCB3D_8B0E_17F5)] // K_16
fn pc2_permutaion_correct(#[case] before: u64, #[case] after: u64) {
let result = pc2(before);
let result = assert_ok!(pc2(before));
assert_eq!(result, after, "PC2 permutation failed");
assert_ge!(
result.leading_zeros(),
16,
"PC2 result should have leading 16 bits as 0"
);
assert_eq!(*result.as_ref(), after, "PC2 permutation failed");
}
#[test]

View File

@ -1,7 +1,6 @@
use des_lib::Des;
use rstest::rstest;
use des::Des;
const TEST_KEY: u64 = 0x1334_5779_9BBC_DFF1;
const TEST_PLAINTEXT: u64 = 0x0123_4567_89AB_CDEF;
const TEST_CIPHERTEXT: u64 = 0x85E8_1354_0F0A_B405;
@ -14,9 +13,9 @@ fn des_instance() -> Des {
#[test]
fn test_ecb_mode_equivalence() {
// If you implement ECB mode, test it matches single block
let key = 0x1334_5779_9BBC_DFF1;
let key = 0x1334_5779_9BBC_DFF1u64;
let des = Des::new(key);
let plain = 0x0123_4567_89AB_CDEF;
let plain = 0x0123_4567_89AB_CDEFu64;
let _single_block = des.encrypt(plain);
// let ecb_result = encrypt_ecb(&[plain]);
@ -52,7 +51,11 @@ fn encrypt_decrypt_roundtrip(
#[test]
fn weak_keys_rejected() {
let weak_keys = [0x0101010101010101, 0xFEFEFEFEFEFEFEFE, 0xE001E001E001E001];
let weak_keys = [
0x0101010101010101u64,
0xFEFEFEFEFEFEFEFE,
0xE001E001E001E001,
];
for key in weak_keys {
let des = Des::new(key);
@ -90,8 +93,8 @@ fn all_one_paintext() {
fn different_inputs() {
let des = des_instance();
let plain1 = 1;
let plain2 = 2;
let plain1 = 1u64;
let plain2 = 2u64;
let enc1 = des.encrypt(plain1);
let enc2 = des.encrypt(plain2);
assert_ne!(

View File

@ -1,4 +1,4 @@
use des::Des;
use des_lib::Des;
// Full expected subkeys for TEST_KEY (48 bits each, from FIPS spec)
const EXPECTED_SUBKEYS: [u64; 16] = [