Skip to content

Commit af07999

Browse files
authored
Implement frozen and extra_behavior for dataclasses (#505)
1 parent 8b2be89 commit af07999

File tree

8 files changed

+488
-166
lines changed

8 files changed

+488
-166
lines changed

pydantic_core/core_schema.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@ def dict_not_none(**kwargs: Any) -> Any:
1919
return {k: v for k, v in kwargs.items() if v is not None}
2020

2121

22+
ExtraBehavior = Literal['allow', 'forbid', 'ignore']
23+
24+
2225
class CoreConfig(TypedDict, total=False):
2326
title: str
2427
strict: bool
@@ -27,7 +30,7 @@ class CoreConfig(TypedDict, total=False):
2730
# if configs are merged, which should take precedence, default 0, default means child takes precedence
2831
config_merge_priority: int
2932
# `extra_fields_behavior` applies to typed_dicts and dataclasses; `typed_dict_total` applies to typed_dicts only
30-
typed_dict_extra_behavior: Literal['allow', 'forbid', 'ignore']
33+
extra_fields_behavior: ExtraBehavior
3134
typed_dict_total: bool # default: True
3235
# used on typed-dicts and tagged union keys
3336
from_attributes: bool
@@ -2494,7 +2497,7 @@ class TypedDictSchema(TypedDict, total=False):
24942497
extra_validator: CoreSchema
24952498
return_fields_set: bool
24962499
# all these values can be set via config; `extra_behavior` maps to `extra_fields_behavior` and `total` to `typed_dict_total`
2497-
extra_behavior: Literal['allow', 'forbid', 'ignore']
2500+
extra_behavior: ExtraBehavior
24982501
total: bool # default: True
24992502
populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1
25002503
from_attributes: bool
@@ -2509,7 +2512,7 @@ def typed_dict_schema(
25092512
strict: bool | None = None,
25102513
extra_validator: CoreSchema | None = None,
25112514
return_fields_set: bool | None = None,
2512-
extra_behavior: Literal['allow', 'forbid', 'ignore'] | None = None,
2515+
extra_behavior: ExtraBehavior | None = None,
25132516
total: bool | None = None,
25142517
populate_by_name: bool | None = None,
25152518
from_attributes: bool | None = None,
@@ -2645,6 +2648,7 @@ class DataclassField(TypedDict, total=False):
26452648
schema: Required[CoreSchema]
26462649
kw_only: bool # default: True
26472650
init_only: bool # default: False
2651+
frozen: bool # default: False
26482652
validation_alias: Union[str, List[Union[str, int]], List[List[Union[str, int]]]]
26492653
serialization_alias: str
26502654
serialization_exclude: bool # default: False
@@ -2661,6 +2665,7 @@ def dataclass_field(
26612665
serialization_alias: str | None = None,
26622666
serialization_exclude: bool | None = None,
26632667
metadata: Any = None,
2668+
frozen: bool | None = None,
26642669
) -> DataclassField:
26652670
"""
26662671
Returns a schema for a dataclass field, e.g.:
@@ -2696,6 +2701,7 @@ def dataclass_field(
26962701
serialization_alias=serialization_alias,
26972702
serialization_exclude=serialization_exclude,
26982703
metadata=metadata,
2704+
frozen=frozen,
26992705
)
27002706

27012707

@@ -2708,6 +2714,7 @@ class DataclassArgsSchema(TypedDict, total=False):
27082714
ref: str
27092715
metadata: Any
27102716
serialization: SerSchema
2717+
extra_behavior: ExtraBehavior
27112718

27122719

27132720
def dataclass_args_schema(
@@ -2718,6 +2725,7 @@ def dataclass_args_schema(
27182725
ref: str | None = None,
27192726
metadata: Any = None,
27202727
serialization: SerSchema | None = None,
2728+
extra_behavior: ExtraBehavior | None = None,
27212729
) -> DataclassArgsSchema:
27222730
"""
27232731
Returns a schema for validating dataclass arguments, e.g.:
@@ -2754,6 +2762,7 @@ def dataclass_args_schema(
27542762
ref=ref,
27552763
metadata=metadata,
27562764
serialization=serialization,
2765+
extra_behavior=extra_behavior,
27572766
)
27582767

27592768

@@ -2764,6 +2773,7 @@ class DataclassSchema(TypedDict, total=False):
27642773
post_init: bool # default: False
27652774
revalidate_instances: Literal['always', 'never', 'subclass-instances'] # default: 'never'
27662775
strict: bool # default: False
2776+
frozen: bool # default: False
27672777
ref: str
27682778
metadata: Any
27692779
serialization: SerSchema
@@ -2779,6 +2789,7 @@ def dataclass_schema(
27792789
ref: str | None = None,
27802790
metadata: Any = None,
27812791
serialization: SerSchema | None = None,
2792+
frozen: bool | None = None,
27822793
) -> DataclassSchema:
27832794
"""
27842795
Returns a schema for a dataclass. As with `ModelSchema`, this schema can only be used as a field within
@@ -2805,6 +2816,7 @@ def dataclass_schema(
28052816
ref=ref,
28062817
metadata=metadata,
28072818
serialization=serialization,
2819+
frozen=frozen,
28082820
)
28092821

28102822

src/build_tools.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,3 +248,35 @@ pub fn safe_repr(v: &PyAny) -> Cow<str> {
248248
"<unprintable object>".into()
249249
}
250250
}
251+
252+
#[derive(Debug, Clone)]
253+
pub(crate) enum ExtraBehavior {
254+
Allow,
255+
Forbid,
256+
Ignore,
257+
}
258+
259+
impl ExtraBehavior {
260+
pub fn from_schema_or_config(
261+
py: Python,
262+
schema: &PyDict,
263+
config: Option<&PyDict>,
264+
default: Self,
265+
) -> PyResult<Self> {
266+
let extra_behavior = schema_or_config::<Option<&str>>(
267+
schema,
268+
config,
269+
intern!(py, "extra_behavior"),
270+
intern!(py, "extra_fields_behavior"),
271+
)?
272+
.flatten();
273+
let res = match extra_behavior {
274+
Some("allow") => Self::Allow,
275+
Some("ignore") => Self::Ignore,
276+
Some("forbid") => Self::Forbid,
277+
Some(v) => return py_err!("Invalid extra_behavior: `{}`", v),
278+
None => default,
279+
};
280+
Ok(res)
281+
}
282+
}

src/serializers/type_serializers/typed_dict.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use ahash::{AHashMap, AHashSet};
88
use serde::ser::SerializeMap;
99

1010
use crate::build_context::BuildContext;
11-
use crate::build_tools::{py_error_type, schema_or_config, SchemaDict};
11+
use crate::build_tools::{py_error_type, schema_or_config, ExtraBehavior, SchemaDict};
1212
use crate::PydanticSerializationUnexpectedValue;
1313

1414
use super::{
@@ -80,16 +80,13 @@ impl BuildSerializer for TypedDictSerializer {
8080
) -> PyResult<CombinedSerializer> {
8181
let py = schema.py();
8282

83-
let extra_behavior = schema_or_config::<&str>(
84-
schema,
85-
config,
86-
intern!(py, "extra_behavior"),
87-
intern!(py, "typed_dict_extra_behavior"),
88-
)?;
8983
let total =
9084
schema_or_config(schema, config, intern!(py, "total"), intern!(py, "typed_dict_total"))?.unwrap_or(true);
9185

92-
let include_extra = extra_behavior == Some("allow");
86+
let include_extra = matches!(
87+
ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?,
88+
ExtraBehavior::Allow
89+
);
9390

9491
let fields_dict: &PyDict = schema.get_as_req(intern!(py, "fields"))?;
9592
let mut fields: AHashMap<String, TypedDictField> = AHashMap::with_capacity(fields_dict.len());

src/validators/dataclass.rs

Lines changed: 54 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use pyo3::types::{PyDict, PyList, PyString, PyTuple, PyType};
55

66
use ahash::AHashSet;
77

8-
use crate::build_tools::{is_strict, py_err, schema_or_config_same, SchemaDict};
8+
use crate::build_tools::{is_strict, py_err, schema_or_config_same, ExtraBehavior, SchemaDict};
99
use crate::errors::{ErrorType, ValError, ValLineError, ValResult};
1010
use crate::input::{GenericArguments, Input};
1111
use crate::lookup_key::LookupKey;
@@ -24,6 +24,7 @@ struct Field {
2424
init_only: bool,
2525
lookup_key: LookupKey,
2626
validator: CombinedValidator,
27+
frozen: bool,
2728
}
2829

2930
#[derive(Debug, Clone)]
@@ -33,6 +34,7 @@ pub struct DataclassArgsValidator {
3334
init_only_count: Option<usize>,
3435
dataclass_name: String,
3536
validator_name: String,
37+
extra_behavior: ExtraBehavior,
3638
}
3739

3840
impl BuildValidator for DataclassArgsValidator {
@@ -47,6 +49,8 @@ impl BuildValidator for DataclassArgsValidator {
4749

4850
let populate_by_name = schema_or_config_same(schema, config, intern!(py, "populate_by_name"))?.unwrap_or(false);
4951

52+
let extra_behavior = ExtraBehavior::from_schema_or_config(py, schema, config, ExtraBehavior::Ignore)?;
53+
5054
let fields_schema: &PyList = schema.get_as_req(intern!(py, "fields"))?;
5155
let mut fields: Vec<Field> = Vec::with_capacity(fields_schema.len());
5256

@@ -91,6 +95,7 @@ impl BuildValidator for DataclassArgsValidator {
9195
lookup_key,
9296
validator,
9397
init_only: field.get_as(intern!(py, "init_only"))?.unwrap_or(false),
98+
frozen: field.get_as::<bool>(intern!(py, "frozen"))?.unwrap_or(false),
9499
});
95100
}
96101

@@ -108,6 +113,7 @@ impl BuildValidator for DataclassArgsValidator {
108113
init_only_count,
109114
dataclass_name,
110115
validator_name,
116+
extra_behavior,
111117
}
112118
.into())
113119
}
@@ -254,11 +260,20 @@ impl Validator for DataclassArgsValidator {
254260
match raw_key.strict_str() {
255261
Ok(either_str) => {
256262
if !used_keys.contains(either_str.as_cow()?.as_ref()) {
257-
errors.push(ValLineError::new_with_loc(
258-
ErrorType::UnexpectedKeywordArgument,
259-
value,
260-
raw_key.as_loc_item(),
261-
));
263+
// Unknown / extra field
264+
match self.extra_behavior {
265+
ExtraBehavior::Forbid => {
266+
errors.push(ValLineError::new_with_loc(
267+
ErrorType::UnexpectedKeywordArgument,
268+
value,
269+
raw_key.as_loc_item(),
270+
));
271+
}
272+
ExtraBehavior::Ignore => {}
273+
ExtraBehavior::Allow => {
274+
output_dict.set_item(either_str.as_py_string(py), value)?
275+
}
276+
}
262277
}
263278
}
264279
Err(ValError::LineErrors(line_errors)) => {
@@ -303,7 +318,19 @@ impl Validator for DataclassArgsValidator {
303318
) -> ValResult<'data, PyObject> {
304319
let dict: &PyDict = obj.downcast()?;
305320

321+
let ok = |output: PyObject| {
322+
dict.set_item(field_name, output)?;
323+
Ok(dict.to_object(py))
324+
};
325+
306326
if let Some(field) = self.fields.iter().find(|f| f.name == field_name) {
327+
if field.frozen {
328+
return Err(ValError::new_with_loc(
329+
ErrorType::FrozenField,
330+
field_value,
331+
field.name.to_string(),
332+
));
333+
}
307334
// by using dict but removing the field in question, we match V1 behaviour
308335
let data_dict = dict.copy()?;
309336
if let Err(err) = data_dict.del_item(field_name) {
@@ -321,10 +348,7 @@ impl Validator for DataclassArgsValidator {
321348
.validator
322349
.validate(py, field_value, &next_extra, slots, recursion_guard)
323350
{
324-
Ok(output) => {
325-
dict.set_item(field_name, output)?;
326-
Ok(dict.to_object(py))
327-
}
351+
Ok(output) => ok(output),
328352
Err(ValError::LineErrors(line_errors)) => {
329353
let errors = line_errors
330354
.into_iter()
@@ -335,13 +359,21 @@ impl Validator for DataclassArgsValidator {
335359
Err(err) => Err(err),
336360
}
337361
} else {
338-
Err(ValError::new_with_loc(
339-
ErrorType::NoSuchAttribute {
340-
attribute: field_name.to_string(),
341-
},
342-
field_value,
343-
field_name.to_string(),
344-
))
362+
// Handle extra (unknown) field
363+
// Assignment only partially follows the extra_behavior used for initialization / validation:
364+
// `Allow` stores the unknown field, while `Ignore` and `Forbid` both reject it
365+
match self.extra_behavior {
366+
// For dataclasses we allow assigning unknown fields
367+
// to match stdlib dataclass behavior
368+
ExtraBehavior::Allow => ok(field_value.to_object(py)),
369+
_ => Err(ValError::new_with_loc(
370+
ErrorType::NoSuchAttribute {
371+
attribute: field_name.to_string(),
372+
},
373+
field_value,
374+
field_name.to_string(),
375+
)),
376+
}
345377
}
346378
}
347379

@@ -364,6 +396,7 @@ pub struct DataclassValidator {
364396
post_init: Option<Py<PyString>>,
365397
revalidate: Revalidate,
366398
name: String,
399+
frozen: bool,
367400
}
368401

369402
impl BuildValidator for DataclassValidator {
@@ -399,6 +432,7 @@ impl BuildValidator for DataclassValidator {
399432
// as with model, get the class's `__name__`, not using `class.name()` since it uses `__qualname__`
400433
// which is not what we want here
401434
name: class.getattr(intern!(py, "__name__"))?.extract()?,
435+
frozen: schema.get_as(intern!(py, "frozen"))?.unwrap_or(false),
402436
}
403437
.into())
404438
}
@@ -455,6 +489,9 @@ impl Validator for DataclassValidator {
455489
slots: &'data [CombinedValidator],
456490
recursion_guard: &'s mut RecursionGuard,
457491
) -> ValResult<'data, PyObject> {
492+
if self.frozen {
493+
return Err(ValError::new(ErrorType::FrozenInstance, field_value));
494+
}
458495
let dict_py_str = intern!(py, "__dict__");
459496
let dict: &PyDict = obj.getattr(dict_py_str)?.downcast()?;
460497

0 commit comments

Comments
 (0)