feat(macros): add BitWrapper derive macro

This commit is contained in:
Kristofers Solo 2025-10-04 16:37:35 +03:00
parent 405112f0ad
commit d362a7df3a
Signed by: kristoferssolo
GPG Key ID: 8687F2D3EEE6F0ED
6 changed files with 548 additions and 0 deletions

14
bit-wrap/Cargo.toml Normal file
View File

@ -0,0 +1,14 @@
[package]
name = "bit-wrap"
version = "0.1.0"
edition = "2024"
[lib]
proc-macro = true
[dependencies]
quote = "1.0"
unsynn = "0.1"
[lints]
workspace = true

36
bit-wrap/src/ast.rs Normal file
View File

@ -0,0 +1,36 @@
use crate::grammar;
use unsynn::*;
pub struct Struct {
pub attr: Attribute,
pub name: Ident,
pub body: Ident,
}
impl Struct {
pub fn bit_width(&self) -> u128 {
self.attr.bit_width
}
}
pub struct Attribute {
pub bit_width: u128,
}
impl From<grammar::Attribute> for Attribute {
fn from(value: grammar::Attribute) -> Self {
Self {
bit_width: value.bit_width.content.bit_width.content.value(),
}
}
}
impl From<grammar::StructDef> for Struct {
fn from(value: grammar::StructDef) -> Self {
Self {
attr: value.attr.into(),
name: value.name,
body: value.body.content,
}
}
}

View File

@ -0,0 +1,7 @@
use crate::{codegen::generate_impl, grammar::StructDef};
use unsynn::*;
pub fn impl_bit_wrapper(input: &TokenStream) -> TokenStream {
let parsed = input.to_token_iter().parse::<StructDef>().unwrap();
generate_impl(&parsed.into())
}

431
bit-wrap/src/codegen.rs Normal file
View File

@ -0,0 +1,431 @@
use crate::ast::Struct;
use quote::quote;
use unsynn::*;
#[ignore = "too_many_lines"]
pub fn generate_impl(info: &Struct) -> TokenStream {
let name = &info.name;
let inner = &info.body;
let ops = generate_bitwise_ops(info);
let fmt = generate_bitwise_fmt(info);
let wrapper = generate_bitwise_wrapper(info);
quote! {
#ops
#fmt
#wrapper
impl std::convert::From<#name> for #inner {
fn from(value: #name) -> Self {
value.0
}
}
impl AsRef<#inner> for #name {
fn as_ref(&self) -> &#inner {
&self.0
}
}
impl std::ops::Deref for #name {
type Target = #inner;
fn deref(&self) -> &Self::Target {
&self.0
}
}
}
}
fn generate_bitwise_fmt(info: &Struct) -> TokenStream {
let name = &info.name;
quote! {
impl std::fmt::LowerHex for #name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::LowerHex::fmt(&self.0, f)
}
}
impl std::fmt::UpperHex for #name {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
std::fmt::UpperHex::fmt(&self.0, f)
}
}
impl std::fmt::Debug for Subkey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Subkey(0x{:012X})", self.0)
}
}
impl std::fmt::Display for Subkey {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "0x{:012X}", self.0)
}
}
}
}
fn generate_bitwise_ops(info: &Struct) -> TokenStream {
let name = &info.name;
let inner = &info.body;
let bit_width = u8::try_from(info.bit_width()).expect("8-bit value");
let hex_width = usize::from(bit_width.div_ceil(4));
let bin_width = usize::from(bit_width);
let max_value = (1u64 << (bit_width)) - 1;
quote! {
impl #name {
/// The bit width of this type
pub const BIT_WIDTH: u8 = #bit_width;
/// Minimum value for this bit width
pub const MIN: #inner = 0;
/// Maximum value for this bit width
pub const MAX: #inner = #max_value;
/// Create a new [`Self`] from a key value
#[inline]
#[macro_use]
pub fn new(key: #inner) -> Result<Self, String> {
key.try_into()
}
/// Convert to hex string (formatted for the bit width)
#[macro_use]
pub fn to_hex(self) -> String {
let value = self.0 & Self::MAX;
format!("{:0width$X}", value, width = #hex_width)
}
/// Convert to binary string (full bit width with leading zeros)
#[macro_use]
pub fn to_bin(self) -> String {
let value = self.0 & Self::MAX;
format!("{:0width$b}", value, width = #bin_width)
}
/// Check if all bits are set
#[inline]
#[macro_use]
pub const fn all(self) -> bool {
self.0 == Self::MAX
}
/// Check if any bit is set
#[inline]
#[macro_use]
pub const fn any(self) -> bool {
self.0 != 0
}
/// Count the number of set bits
#[inline]
#[macro_use]
pub const fn count_ones(self) -> u32 {
self.0.count_ones()
}
/// Count the number of zero bits within the constrained bit width
#[macro_use]
pub const fn count_zeros(self) -> u32 {
let value = self.0 as #inner & Self::MAX;
let ones_count = self.count_ones();
(Self::BIT_WIDTH as u32) - ones_count
}
/// Reverse the bit pattern within the constrained bit width
#[macro_use]
pub const fn reverse_bits(self) -> Self {
let value = self.0 as #inner & Self::MAX;
let reversed = value.reverse_bits();
// Mask and shift to keep only the relevant bits
let result = ((reversed >> (64 - Self::BIT_WIDTH)) & Self::MAX).reverse_bits()
>> (64 - Self::BIT_WIDTH);
Self((result & Self::MAX) as #inner)
}
/// Rotate the bits left by `n` positions within the bit width
#[macro_use]
pub const fn rotate_left(self, n: u8) -> Self {
let n = n % Self::BIT_WIDTH;
if n == 0 {
return self;
}
let masked = self.0 & Self::MAX;
let rotated = ((masked << n) | (masked >> Self::BIT_WIDTH.saturating_sub(n))) & Self::MAX;
#name(rotated)
}
/// Rotate the bits right by `n` positions within the bit width
#[macro_use]
pub const fn rotate_right(self, n: u8) -> Self {
let n = n % Self::BIT_WIDTH;
if n == 0 {
return self;
}
let masked = self.0 & Self::MAX;
let rotated = ((masked >> n) | (masked << Self::BIT_WIDTH.saturating_sub(n))) & Self::MAX;
#name(rotated)
}
/// Find the number of leading zero bits within the bit width
#[macro_use]
pub fn leading_zeros(self) -> u32 {
let value = self.0 as #inner & Self::MAX;
let bit_width = Self::BIT_WIDTH as u32;
if value == 0 {
return bit_width;
}
let full_leading_zeros = value.leading_zeros();
let adjustment = 64u32.saturating_sub(bit_width);
full_leading_zeros.saturating_sub(adjustment).min(bit_width)
}
/// Find the number of trailing zero bits within the bit width
#[macro_use]
pub fn trailing_zeros(self) -> u32 {
let value = self.0 as #inner & Self::MAX;
let reversed = self.reverse_bits();
reversed.leading_zeros()
}
/// Find the number of leading one bits from the MSB within the bit width
#[macro_use]
pub fn leading_ones(self) -> u32 {
let value = self.0 as #inner & Self::MAX;
let not_value = !value & Self::MAX;
let bit_width = Self::BIT_WIDTH as u32;
if not_value == 0 {
return bit_width;
}
bit_width - not_value.trailing_zeros().min(bit_width)
}
/// Find the number of trailing one bits from the LSB within the bit width
#[macro_use]
pub fn trailing_ones(self) -> u32 {
let value = self.0 as #inner & Self::MAX;
let reversed = self.reverse_bits();
reversed.leading_ones()
}
/// Check if the value is zero within the bit width
#[macro_use]
pub const fn is_zero(self) -> bool {
(self.0 & Self::MAX) == 0
}
/// Set a specific bit to 1 (clamped to bit width)
#[macro_use]
pub const fn bit_is_set(self, bit: u8) -> bool {
let clamped = self.clamp_bit_positions(bit);
let mask = 1<< clamped;
(self.0 & mask) == mask
}
/// Set a specific bit to 1 (clamped to bit width)
pub const fn set_bit(&mut self, bit: u8) {
let clamped = self.clamp_bit_positions(bit);
self.0 |= 1 << clamped;
}
/// Clear a specific bit to 0 (clamped to bit width)
pub const fn clear_bit(&mut self, bit: u8) {
let clamped = self.clamp_bit_positions(bit);
let mask = !(1 << clamped);
self.0 &= (mask & Self::MAX);
}
/// Toggle a specific bit (clamped to bit width)
pub const fn toggle_bit(&mut self, bit: u8) {
let clamped = self.clamp_bit_positions(bit);
self.0 ^= 1 << clamped;
}
/// Check if bits in the given mask are set (within bit width)
#[macro_use]
pub const fn contains_mask(self, mask: Self) -> bool {
let self_value = self.0 & Self::MAX;
let mask_value = mask.0 & Self::MAX;
(self_value & mask_value) == mask.0
}
/// Create from a hex string with bit width validation
pub fn from_hex(hex: &str) -> Result<Self, String> {
let value = #inner::from_str_radix(hex, 16).map_err(|e| format!("Invalid hex string: {e}"))?;
let masked = value & Self::MAX;
if value != masked {
return Err(
format!(
"Hex value 0x{value:X} exceeds {}-bit limit (masked to 0x{masked:0width$X})",
Self::BIT_WIDTH, width = #hex_width
)
)
}
Ok(Self(value))
}
/// Extract a bitfield from the constrained range
#[macro_use]
pub const fn bitfield(self, low: u8, high: u8) -> u64 {
let low = self.clamp_bit_positions(low);
let high = self.clamp_bit_positions(high);
let start= if low <= high { low }else{ high };
let end = if low <= high { high }else{ low };
let width = high.wrapping_sub(low).saturating_sub(1);
if width == 0 {
return 0;
}
let value = self.0 & Self::MAX;
(value >> low) & ((1 << width) - 1)
}
/// Set bits from a constrained range [low, high]
pub const fn set_bitfield(&mut self, value: u64, low: u32, high: u32) {
let mask = ((1 << (high - low + 1)) - 1) << low;
self.0 = (self.0 & !mask) | ((value << low) & mask);
}
#[macro_use]
const fn clamp_bit_positions(self, bit: u8) -> u8 {
if bit >= Self::BIT_WIDTH {
return Self::BIT_WIDTH - 1;
}
bit
}
}
}
}
fn generate_bitwise_wrapper(_info: &Struct) -> TokenStream {
quote! {
impl std::ops::Not for Subkey {
type Output = Self;
fn not(self) -> Self::Output {
Self(!self.0)
}
}
impl std::ops::BitAnd for Subkey {
type Output = Self;
fn bitand(self, rhs: Self) -> Self::Output {
Self(self.0 & rhs.0)
}
}
impl std::ops::BitAnd<&Self> for Subkey {
type Output = Self;
fn bitand(self, rhs: &Self) -> Self::Output {
Self(self.0 & rhs.0)
}
}
impl std::ops::BitAndAssign for Subkey {
fn bitand_assign(&mut self, rhs: Self) {
self.0 &= rhs.0;
}
}
impl std::ops::BitAndAssign<&Self> for Subkey {
fn bitand_assign(&mut self, rhs: &Self) {
self.0 &= rhs.0;
}
}
impl std::ops::BitOr for Subkey {
type Output = Self;
fn bitor(self, rhs: Self) -> Self::Output {
Self(self.0 | rhs.0)
}
}
impl std::ops::BitOr<&Self> for Subkey {
type Output = Self;
fn bitor(self, rhs: &Self) -> Self::Output {
Self(self.0 | rhs.0)
}
}
impl std::ops::BitOrAssign for Subkey {
fn bitor_assign(&mut self, rhs: Self) {
self.0 |= rhs.0;
}
}
impl std::ops::BitOrAssign<&Self> for Subkey {
fn bitor_assign(&mut self, rhs: &Self) {
self.0 |= rhs.0;
}
}
impl std::ops::BitXor for Subkey {
type Output = Self;
fn bitxor(self, rhs: Self) -> Self {
Self(self.0 ^ rhs.0)
}
}
impl std::ops::BitXor<&Self> for Subkey {
type Output = Self;
fn bitxor(self, rhs: &Self) -> Self {
Self(self.0 ^ rhs.0)
}
}
impl std::ops::BitXorAssign for Subkey {
fn bitxor_assign(&mut self, rhs: Self) {
self.0 ^= rhs.0;
}
}
impl std::ops::BitXorAssign<&Self> for Subkey {
fn bitxor_assign(&mut self, rhs: &Self) {
self.0 ^= rhs.0;
}
}
impl std::ops::Shl<u32> for Subkey {
type Output = Self;
fn shl(self, rhs: u32) -> Self {
Self(self.0 << rhs)
}
}
impl std::ops::ShlAssign<u32> for Subkey {
fn shl_assign(&mut self, rhs: u32) {
self.0 <<= rhs;
}
}
impl std::ops::Shr<u32> for Subkey {
type Output = Self;
fn shr(self, rhs: u32) -> Self::Output {
Self(self.0 >> rhs)
}
}
impl std::ops::ShrAssign<u32> for Subkey {
fn shr_assign(&mut self, rhs: u32) {
self.0 >>= rhs;
}
}
}
}

48
bit-wrap/src/grammar.rs Normal file
View File

@ -0,0 +1,48 @@
use unsynn::*;
keyword! {
pub KwStruct = "struct";
pub KwPub = "pub";
pub KwBitWidth = "bit_width";
}
unsynn! {
pub struct BitWidth{
pub kw_bit_width: KwBitWidth,
pub bit_width: ParenthesisGroupContaining<LiteralInteger>,
}
pub struct Attribute {
pub pound: Pound,
pub bit_width: BracketGroupContaining<BitWidth>,
}
pub struct StructDef {
pub attr: Attribute,
pub vis: Option<KwPub>,
pub kw_struct: KwStruct,
pub name: Ident,
pub body: ParenthesisGroupContaining<Ident>,
pub semi: Semicolon,
}
}
#[cfg(test)]
mod tests {
use super::*;
const SAMPLE: &str = r#"
#[bit_width(48)]
pub struct Subkey(u64);
"#;
#[test]
fn parse_attribute() {
let mut iter = SAMPLE.to_token_iter();
let sdef = iter
.parse::<StructDef>()
.expect("failed to parse StructDef");
dbg!(sdef);
assert!(false);
}
}

12
bit-wrap/src/lib.rs Normal file
View File

@ -0,0 +1,12 @@
mod ast;
mod bit_wrapper;
mod codegen;
mod grammar;
use crate::bit_wrapper::impl_bit_wrapper;
use proc_macro::TokenStream;
#[proc_macro_derive(BitWrapper, attributes(bit_width))]
pub fn derive_bit_wrapper(input: TokenStream) -> TokenStream {
impl_bit_wrapper(&input.into()).into()
}