fix: nested structure

This commit is contained in:
2025-07-15 16:11:28 +03:00
parent db1dab2aa1
commit 60488d364e
4 changed files with 171 additions and 119 deletions

View File

@@ -2,7 +2,7 @@ use proc_macro2::TokenStream;
use quote::{format_ident, quote};
use syn::{
Attribute, Data, DeriveInput, Error, Expr, Field, Fields, FieldsNamed, GenericParam, Generics,
Lit, Meta, MetaList, Result, Type, TypePath, WhereClause, WherePredicate, parse_quote, parse2,
Ident, Lit, Meta, MetaList, Result, Type, parse_quote,
};
const WITH_MERGE: bool = cfg!(feature = "merge");
@@ -17,9 +17,8 @@ pub fn impl_from_file(input: &DeriveInput) -> Result<TokenStream> {
let file_ident = format_ident!("{name}File");
let fields = extract_named_fields(input)?;
let (field_assignments, file_fields, default_bounds) = process_fields(fields)?;
let (field_assignments, file_fields) = process_fields(fields)?;
let where_clause = build_where_clause(where_clause.cloned(), default_bounds)?;
let derive_clause = build_derive_clause();
Ok(quote! {
@@ -71,13 +70,15 @@ fn extract_named_fields(input: &DeriveInput) -> Result<&FieldsNamed> {
}
/// Build the shadow field + assignment for one original field
fn build_file_field(field: &Field) -> Result<(TokenStream, TokenStream, Option<TokenStream>)> {
fn build_file_field(field: &Field) -> Result<(TokenStream, TokenStream)> {
let ident = field
.ident
.as_ref()
.ok_or_else(|| Error::new_spanned(field, "Expected named fields"))?;
let ty = &field.ty;
let default_override = parse_from_file_default_attr(&field.attrs)?;
let field_attrs = if WITH_MERGE {
quote! { #[merge(strategy = merge::option::overwrite_none)] }
} else {
@@ -85,59 +86,41 @@ fn build_file_field(field: &Field) -> Result<(TokenStream, TokenStream, Option<T
};
// Nested struct -> delegate to its own `FromFile` impl
let shadow_ty = quote! { <#ty as filecaster::FromFile>::Shadow };
let field_decl = quote! {
#field_attrs
pub #ident: Option<#ty>
};
let assign = quote! {
#ident: <#ty as filecaster::FromFile>::from_file(file.#ident)
pub #ident: Option<#shadow_ty>
};
let default_bound = Some(quote! { #ty: Default });
let assign = build_file_assing(ident, ty, default_override);
Ok((field_decl, assign, default_bound))
Ok((field_decl, assign))
}
fn build_file_assing(ident: &Ident, ty: &Type, default_override: Option<Expr>) -> TokenStream {
if let Some(expr) = default_override {
return quote! {
#ident: file.#ident.map(|inner| <#ty as filecaster::FromFile>::from_file(Some(inner))).unwrap_or(#expr)
};
}
quote! {
#ident: <#ty as filecaster::FromFile>::from_file(file.#ident)
}
}
/// Process all fields
fn process_fields(
fields: &FieldsNamed,
) -> Result<(Vec<TokenStream>, Vec<TokenStream>, Vec<TokenStream>)> {
fn process_fields(fields: &FieldsNamed) -> Result<(Vec<TokenStream>, Vec<TokenStream>)> {
fields.named.iter().try_fold(
(Vec::new(), Vec::new(), Vec::new()),
|(mut assignments, mut file_fields, mut defaults), field| {
let (file_field, assignment, default_value) = build_file_field(field)?;
(Vec::new(), Vec::new()),
|(mut assignments, mut file_fields), field| {
let (file_field, assignment) = build_file_field(field)?;
file_fields.push(file_field);
assignments.push(assignment);
if let Some(value) = default_value {
defaults.push(value);
}
Ok((assignments, file_fields, defaults))
Ok((assignments, file_fields))
},
)
}
/// Where-clause helpers
fn build_where_clause(
where_clause: Option<WhereClause>,
default_bounds: Vec<TokenStream>,
) -> Result<Option<WhereClause>> {
if default_bounds.is_empty() {
return Ok(where_clause);
}
let mut where_clause = where_clause;
if let Some(wc) = &mut where_clause {
for bound in default_bounds {
let predicate = parse2::<WherePredicate>(bound.clone())
.map_err(|_| Error::new_spanned(&bound, "Failed to parse where predicate"))?;
wc.predicates.push(predicate);
}
} else {
where_clause = Some(parse_quote!(where #(#default_bounds),*));
}
Ok(where_clause)
}
/// Derive clause for the shadow struct
fn build_derive_clause() -> TokenStream {
quote! {
@@ -150,8 +133,8 @@ fn build_derive_clause() -> TokenStream {
/// Add Default bound to every generic parameter
fn add_trait_bounds(mut generics: Generics) -> Generics {
for param in &mut generics.params {
if let GenericParam::Type(type_param) = param {
type_param.bounds.push(parse_quote!(Default));
if let GenericParam::Type(ty) = param {
ty.bounds.push(parse_quote!(Default));
}
}
generics
@@ -250,63 +233,10 @@ mod tests {
b: String,
}
};
let (assign, file_fields, bounds) = process_fields(&fields).unwrap();
let (assign, file_fields) = process_fields(&fields).unwrap();
// two fields
assert_eq!(assign.len(), 2);
assert_eq!(file_fields.len(), 2);
// a uses unwrap_or_else
assert!(
assign[0]
.to_string()
.contains("a : file . a . unwrap_or_else")
);
// b uses unwrap_or_default
assert!(
assign[1]
.to_string()
.contains("b : file . b . unwrap_or_default")
);
// default-bound should only mention String
assert_eq!(bounds.len(), 1);
assert!(bounds[0].to_string().contains("String : Default"));
}
#[test]
fn build_where_clause_to_new() {
let bounds = vec![quote! { A: Default }, quote! { B: Default }];
let wc = build_where_clause(None, bounds).unwrap().unwrap();
let s = wc.to_token_stream().to_string();
assert!(s.contains("where A : Default , B : Default"));
}
#[test]
fn build_where_clause_append_existing() {
let orig: WhereClause = parse_quote!(where X: Clone);
let bounds = vec![quote! { Y: Default }];
let wc = build_where_clause(Some(orig.clone()), bounds)
.unwrap()
.unwrap();
let preds: Vec<_> = wc
.predicates
.iter()
.map(|p| p.to_token_stream().to_string())
.collect();
assert!(preds.contains(&"X : Clone".to_string()));
assert!(preds.contains(&"Y : Default".to_string()));
}
#[test]
fn build_where_clause_no_bounds_keeps_original() {
let orig: WhereClause = parse_quote!(where Z: Eq);
let wc = build_where_clause(Some(orig.clone()), vec![])
.unwrap()
.unwrap();
let preds: Vec<_> = wc
.predicates
.iter()
.map(|p| p.to_token_stream().to_string())
.collect();
assert_eq!(preds, vec!["Z : Eq".to_string()]);
}
#[test]

View File

@@ -83,7 +83,7 @@ use proc_macro::TokenStream;
use proc_macro_error2::proc_macro_error;
use syn::{DeriveInput, parse_macro_input};
/// Implements the `FromFile` derive macro.
/// Implements the [`FromFile`] trait.
///
/// This macro processes the `#[from_file]` attribute on structs to generate
/// code for loading data from files.