Skip to content

Commit bb13fda

Browse files
authored
custom errors on union and tagged-union (#262)
* custom errors on union and tagged-union * use PydanticValueError directly * formatting tweak
1 parent 2e43d6c commit bb13fda

File tree

5 files changed

+142
-25
lines changed

5 files changed

+142
-25
lines changed

pydantic_core/core_schema.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,9 +527,15 @@ def nullable_schema(schema: CoreSchema, *, strict: bool | None = None, ref: str
527527
return dict_not_none(type='nullable', schema=schema, strict=strict, ref=ref)
528528

529529

530+
class CustomError(TypedDict):
531+
kind: str
532+
message: str
533+
534+
530535
class UnionSchema(TypedDict, total=False):
531536
type: Required[Literal['union']]
532537
choices: Required[List[CoreSchema]]
538+
custom_error: CustomError
533539
strict: bool
534540
ref: str
535541

@@ -542,6 +548,7 @@ class TaggedUnionSchema(TypedDict):
542548
type: Literal['tagged-union']
543549
choices: Dict[str, CoreSchema]
544550
discriminator: Union[str, List[Union[str, int]], List[List[Union[str, int]]], Callable[[Any], Optional[str]]]
551+
custom_error: NotRequired[CustomError]
545552
strict: NotRequired[bool]
546553
ref: NotRequired[str]
547554

src/errors/value_exception.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@ use crate::input::Input;
77
use super::{ErrorKind, ValError};
88

99
#[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")]
10-
#[derive(Clone)]
11-
#[cfg_attr(debug_assertions, derive(Debug))]
10+
#[derive(Debug, Clone)]
1211
pub struct PydanticValueError {
1312
kind: String,
1413
message_template: String,
@@ -18,7 +17,7 @@ pub struct PydanticValueError {
1817
#[pymethods]
1918
impl PydanticValueError {
2019
#[new]
21-
fn py_new(py: Python, kind: String, message_template: String, context: Option<&PyDict>) -> Self {
20+
pub fn py_new(py: Python, kind: String, message_template: String, context: Option<&PyDict>) -> Self {
2221
Self {
2322
kind,
2423
message_template,

src/validators/union.rs

Lines changed: 70 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use pyo3::types::{PyDict, PyList, PyString};
88
use ahash::AHashMap;
99

1010
use crate::build_tools::{is_strict, schema_or_config, SchemaDict};
11-
use crate::errors::{ErrorKind, ValError, ValLineError, ValResult};
11+
use crate::errors::{ErrorKind, PydanticValueError, ValError, ValLineError, ValResult};
1212
use crate::input::{GenericMapping, Input};
1313
use crate::lookup_key::LookupKey;
1414
use crate::questions::Question;
@@ -19,6 +19,7 @@ use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Ex
1919
#[derive(Debug, Clone)]
2020
pub struct UnionValidator {
2121
choices: Vec<CombinedValidator>,
22+
custom_error: Option<PydanticValueError>,
2223
strict: bool,
2324
name: String,
2425
}
@@ -31,8 +32,9 @@ impl BuildValidator for UnionValidator {
3132
config: Option<&PyDict>,
3233
build_context: &mut BuildContext,
3334
) -> PyResult<CombinedValidator> {
35+
let py = schema.py();
3436
let choices: Vec<CombinedValidator> = schema
35-
.get_as_req::<&PyList>(intern!(schema.py(), "choices"))?
37+
.get_as_req::<&PyList>(intern!(py, "choices"))?
3638
.iter()
3739
.map(|choice| build_validator(choice, config, build_context))
3840
.collect::<PyResult<Vec<CombinedValidator>>>()?;
@@ -41,13 +43,41 @@ impl BuildValidator for UnionValidator {
4143

4244
Ok(Self {
4345
choices,
46+
custom_error: get_custom_error(py, schema)?,
4447
strict: is_strict(schema, config)?,
4548
name: format!("{}[{}]", Self::EXPECTED_TYPE, descr),
4649
}
4750
.into())
4851
}
4952
}
5053

54+
fn get_custom_error(py: Python, schema: &PyDict) -> PyResult<Option<PydanticValueError>> {
55+
match schema.get_as::<&PyDict>(intern!(py, "custom_error"))? {
56+
Some(custom_error) => Ok(Some(PydanticValueError::py_new(
57+
py,
58+
custom_error.get_as_req::<String>(intern!(py, "kind"))?,
59+
custom_error.get_as_req::<String>(intern!(py, "message"))?,
60+
None,
61+
))),
62+
None => Ok(None),
63+
}
64+
}
65+
66+
impl UnionValidator {
67+
fn or_custom_error<'s, 'data>(
68+
&'s self,
69+
errors: Option<Vec<ValLineError<'data>>>,
70+
input: &'data impl Input<'data>,
71+
) -> ValError<'data> {
72+
if let Some(errors) = errors {
73+
ValError::LineErrors(errors)
74+
} else {
75+
let value_error = self.custom_error.as_ref().unwrap();
76+
value_error.clone().into_val_error(input)
77+
}
78+
}
79+
}
80+
5181
impl Validator for UnionValidator {
5282
fn validate<'s, 'data>(
5383
&'s self,
@@ -58,7 +88,10 @@ impl Validator for UnionValidator {
5888
recursion_guard: &'s mut RecursionGuard,
5989
) -> ValResult<'data, PyObject> {
6090
if extra.strict.unwrap_or(self.strict) {
61-
let mut errors: Vec<ValLineError> = Vec::with_capacity(self.choices.len());
91+
let mut errors: Option<Vec<ValLineError>> = match self.custom_error {
92+
Some(_) => None,
93+
None => Some(Vec::with_capacity(self.choices.len())),
94+
};
6295
let strict_extra = extra.as_strict();
6396

6497
for validator in &self.choices {
@@ -67,14 +100,16 @@ impl Validator for UnionValidator {
67100
otherwise => return otherwise,
68101
};
69102

70-
errors.extend(
71-
line_errors
72-
.into_iter()
73-
.map(|err| err.with_outer_location(validator.get_name().into())),
74-
);
103+
if let Some(ref mut errors) = errors {
104+
errors.extend(
105+
line_errors
106+
.into_iter()
107+
.map(|err| err.with_outer_location(validator.get_name().into())),
108+
);
109+
}
75110
}
76111

77-
Err(ValError::LineErrors(errors))
112+
Err(self.or_custom_error(errors, input))
78113
} else {
79114
// 1st pass: check if the value is an exact instance of one of the Union types,
80115
// e.g. use validate in strict mode
@@ -88,7 +123,10 @@ impl Validator for UnionValidator {
88123
return res;
89124
}
90125

91-
let mut errors: Vec<ValLineError> = Vec::with_capacity(self.choices.len());
126+
let mut errors: Option<Vec<ValLineError>> = match self.custom_error {
127+
Some(_) => None,
128+
None => Some(Vec::with_capacity(self.choices.len())),
129+
};
92130

93131
// 2nd pass: check if the value can be coerced into one of the Union types, e.g. use validate
94132
for validator in &self.choices {
@@ -97,14 +135,16 @@ impl Validator for UnionValidator {
97135
success => return success,
98136
};
99137

100-
errors.extend(
101-
line_errors
102-
.into_iter()
103-
.map(|err| err.with_outer_location(validator.get_name().into())),
104-
);
138+
if let Some(ref mut errors) = errors {
139+
errors.extend(
140+
line_errors
141+
.into_iter()
142+
.map(|err| err.with_outer_location(validator.get_name().into())),
143+
);
144+
}
105145
}
106146

107-
Err(ValError::LineErrors(errors))
147+
Err(self.or_custom_error(errors, input))
108148
}
109149
}
110150

@@ -160,6 +200,7 @@ pub struct TaggedUnionValidator {
160200
discriminator: Discriminator,
161201
from_attributes: bool,
162202
strict: bool,
203+
custom_error: Option<PydanticValueError>,
163204
tags_repr: String,
164205
discriminator_repr: String,
165206
name: String,
@@ -206,6 +247,7 @@ impl BuildValidator for TaggedUnionValidator {
206247
discriminator,
207248
from_attributes,
208249
strict: is_strict(schema, config)?,
250+
custom_error: get_custom_error(py, schema)?,
209251
tags_repr,
210252
discriminator_repr,
211253
name: format!("{}[{}]", Self::EXPECTED_TYPE, descr),
@@ -341,6 +383,8 @@ impl TaggedUnionValidator {
341383
Ok(res) => Ok(res),
342384
Err(err) => Err(err.with_outer_location(tag.as_ref().into())),
343385
}
386+
} else if let Some(ref custom_error) = self.custom_error {
387+
Err(custom_error.clone().into_val_error(input))
344388
} else {
345389
Err(ValError::new(
346390
ErrorKind::UnionTagInvalid {
@@ -354,11 +398,15 @@ impl TaggedUnionValidator {
354398
}
355399

356400
fn tag_not_found<'s, 'data>(&'s self, input: &'data impl Input<'data>) -> ValError<'data> {
357-
ValError::new(
358-
ErrorKind::UnionTagNotFound {
359-
discriminator: self.discriminator_repr.clone(),
360-
},
361-
input,
362-
)
401+
if let Some(ref custom_error) = self.custom_error {
402+
custom_error.clone().into_val_error(input)
403+
} else {
404+
ValError::new(
405+
ErrorKind::UnionTagNotFound {
406+
discriminator: self.discriminator_repr.clone(),
407+
},
408+
input,
409+
)
410+
}
363411
}
364412
}

tests/validators/test_tagged_union.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,3 +287,48 @@ def test_downcast_error():
287287
v = SchemaValidator({'type': 'tagged-union', 'discriminator': lambda x: 123, 'choices': {'str': {'type': 'str'}}})
288288
with pytest.raises(TypeError, match="'int' object cannot be converted to 'PyString'"):
289289
v.validate_python('x')
290+
291+
292+
def test_custom_error():
293+
v = SchemaValidator(
294+
{
295+
'type': 'tagged-union',
296+
'discriminator': 'foo',
297+
'custom_error': {'kind': 'snap', 'message': 'Input should be a foo or bar'},
298+
'choices': {
299+
'apple': {
300+
'type': 'typed-dict',
301+
'fields': {'foo': {'schema': {'type': 'str'}}, 'bar': {'schema': {'type': 'int'}}},
302+
},
303+
'banana': {
304+
'type': 'typed-dict',
305+
'fields': {
306+
'foo': {'schema': {'type': 'str'}},
307+
'spam': {'schema': {'type': 'list', 'items_schema': {'type': 'int'}}},
308+
},
309+
},
310+
},
311+
}
312+
)
313+
assert v.validate_python({'foo': 'apple', 'bar': '123'}) == {'foo': 'apple', 'bar': 123}
314+
with pytest.raises(ValidationError) as exc_info:
315+
v.validate_python({'spam': 'apple', 'bar': 'Bar'})
316+
# insert_assert(exc_info.value.errors())
317+
assert exc_info.value.errors() == [
318+
{
319+
'kind': 'snap',
320+
'loc': [],
321+
'message': 'Input should be a foo or bar',
322+
'input_value': {'spam': 'apple', 'bar': 'Bar'},
323+
}
324+
]
325+
with pytest.raises(ValidationError) as exc_info:
326+
v.validate_python({'foo': 'other', 'bar': 'Bar'})
327+
assert exc_info.value.errors() == [
328+
{
329+
'kind': 'snap',
330+
'loc': [],
331+
'message': 'Input should be a foo or bar',
332+
'input_value': {'foo': 'other', 'bar': 'Bar'},
333+
}
334+
]

tests/validators/test_union.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,3 +244,21 @@ def test_strict_union():
244244
{'kind': 'bool_type', 'loc': ['bool'], 'message': 'Input should be a valid boolean', 'input_value': '123'},
245245
{'kind': 'int_type', 'loc': ['int'], 'message': 'Input should be a valid integer', 'input_value': '123'},
246246
]
247+
248+
249+
def test_custom_error():
250+
v = SchemaValidator(
251+
{
252+
'type': 'union',
253+
'choices': [{'type': 'str'}, {'type': 'bytes'}],
254+
'custom_error': {'kind': 'my_error', 'message': 'Input should be a string or bytes'},
255+
}
256+
)
257+
assert v.validate_python('hello') == 'hello'
258+
assert v.validate_python(b'hello') == b'hello'
259+
with pytest.raises(ValidationError) as exc_info:
260+
v.validate_python(123)
261+
# insert_assert(exc_info.value.errors())
262+
assert exc_info.value.errors() == [
263+
{'kind': 'my_error', 'loc': [], 'message': 'Input should be a string or bytes', 'input_value': 123}
264+
]

0 commit comments

Comments
 (0)