Skip to content

Commit 09427c7

Browse files
authored
Custom error kind (#284)
* allow PydanticKindError with unions * improve coverage * fix conflicting tests
1 parent dbe128a commit 09427c7

File tree

13 files changed

+185
-58
lines changed

13 files changed

+185
-58
lines changed

pydantic_core/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from ._pydantic_core import (
22
PydanticCustomError,
3-
PydanticErrorKind,
3+
PydanticKindError,
44
PydanticOmit,
55
SchemaError,
66
SchemaValidator,
@@ -17,6 +17,6 @@
1717
'SchemaError',
1818
'ValidationError',
1919
'PydanticCustomError',
20-
'PydanticErrorKind',
20+
'PydanticKindError',
2121
'PydanticOmit',
2222
)

pydantic_core/_pydantic_core.pyi

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ __all__ = (
1414
'SchemaError',
1515
'ValidationError',
1616
'PydanticCustomError',
17-
'PydanticErrorKind',
17+
'PydanticKindError',
1818
'PydanticOmit',
1919
)
2020
__version__: str
@@ -59,7 +59,7 @@ class PydanticCustomError(ValueError):
5959
def __init__(self, kind: str, message_template: str, context: 'dict[str, str | int] | None' = None) -> None: ...
6060
def message(self) -> str: ...
6161

62-
class PydanticErrorKind(ValueError):
62+
class PydanticKindError(ValueError):
6363
kind: str
6464
message_template: str
6565
context: 'dict[str, str | int] | None'

pydantic_core/core_schema.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -618,17 +618,20 @@ def nullable_schema(schema: CoreSchema, *, strict: bool | None = None, ref: str
618618
return dict_not_none(type='nullable', schema=schema, strict=strict, ref=ref)
619619

620620

621-
class CustomError(TypedDict):
622-
kind: str
621+
class CustomError(TypedDict, total=False):
622+
kind: Required[str]
623623
message: str
624+
context: Dict[str, Union[str, int]]
624625

625626

626-
def _custom_error(kind: str | None, message: str | None) -> CustomError | None:
627-
if kind is None and message is None:
627+
def _custom_error(
628+
kind: str | None, message: str | None, context: dict[str, str | int] | None = None
629+
) -> CustomError | None:
630+
if kind is None and message is None and context is None:
628631
return None
629632
else:
630633
# let schema validation raise the error
631-
return CustomError(kind=kind, message=message) # type: ignore
634+
return dict_not_none(kind=kind, message=message, context=context)
632635

633636

634637
class UnionSchema(TypedDict, total=False):
@@ -659,13 +662,14 @@ def union_schema(
659662
*choices: CoreSchema,
660663
custom_error_kind: str | None = None,
661664
custom_error_message: str | None = None,
665+
custom_error_context: dict[str, str | int] | None = None,
662666
strict: bool | None = None,
663667
ref: str | None = None,
664668
) -> UnionSchema:
665669
return dict_not_none(
666670
type='union',
667671
choices=choices,
668-
custom_error=_custom_error(custom_error_kind, custom_error_message),
672+
custom_error=_custom_error(custom_error_kind, custom_error_message, custom_error_context),
669673
strict=strict,
670674
ref=ref,
671675
)
@@ -710,14 +714,15 @@ def tagged_union_schema(
710714
*,
711715
custom_error_kind: str | None = None,
712716
custom_error_message: str | None = None,
717+
custom_error_context: dict[str, int | str] | None = None,
713718
strict: bool | None = None,
714719
ref: str | None = None,
715720
) -> TaggedUnionSchema:
716721
return dict_not_none(
717722
type='tagged-union',
718723
choices=choices,
719724
discriminator=discriminator,
720-
custom_error=_custom_error(custom_error_kind, custom_error_message),
725+
custom_error=_custom_error(custom_error_kind, custom_error_message, custom_error_context),
721726
strict=strict,
722727
ref=ref,
723728
)

src/errors/kinds.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,11 @@ impl ErrorKind {
441441
}
442442
}
443443

444+
pub fn valid_kind(py: Python, kind: &str) -> bool {
445+
let lookup = ERROR_KIND_LOOKUP.get_or_init(py, Self::build_lookup);
446+
lookup.contains_key(kind)
447+
}
448+
444449
fn build_lookup() -> AHashMap<String, Self> {
445450
let mut lookup = AHashMap::new();
446451
for error_kind in Self::iter() {

src/errors/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ pub use self::kinds::{list_all_errors, ErrorKind};
1010
pub use self::line_error::{pretty_line_errors, InputValue, ValError, ValLineError, ValResult};
1111
pub use self::location::LocItem;
1212
pub use self::validation_exception::ValidationError;
13-
pub use self::value_exception::{PydanticCustomError, PydanticErrorKind, PydanticOmit};
13+
pub use self::value_exception::{PydanticCustomError, PydanticKindError, PydanticOmit};
1414

1515
pub fn py_err_string(py: Python, err: PyErr) -> String {
1616
let value = err.value(py);

src/errors/value_exception.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,12 @@ impl PydanticCustomError {
100100

101101
#[pyclass(extends=PyValueError, module="pydantic_core._pydantic_core")]
102102
#[derive(Debug, Clone)]
103-
pub struct PydanticErrorKind {
103+
pub struct PydanticKindError {
104104
kind: ErrorKind,
105105
}
106106

107107
#[pymethods]
108-
impl PydanticErrorKind {
108+
impl PydanticKindError {
109109
#[new]
110110
pub fn py_new(py: Python, kind: &str, context: Option<&PyDict>) -> PyResult<Self> {
111111
let kind = ErrorKind::new(py, kind, context).map_err(PyTypeError::new_err)?;
@@ -144,7 +144,7 @@ impl PydanticErrorKind {
144144
}
145145
}
146146

147-
impl PydanticErrorKind {
147+
impl PydanticKindError {
148148
pub fn into_val_error<'a>(self, input: &'a impl Input<'a>) -> ValError<'a> {
149149
ValError::new(self.kind, input)
150150
}

src/lib.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ mod validators;
1818

1919
// required for benchmarks
2020
pub use build_tools::SchemaError;
21-
pub use errors::{list_all_errors, PydanticCustomError, PydanticErrorKind, PydanticOmit, ValidationError};
21+
pub use errors::{list_all_errors, PydanticCustomError, PydanticKindError, PydanticOmit, ValidationError};
2222
pub use validators::SchemaValidator;
2323

2424
pub fn get_version() -> String {
@@ -39,7 +39,7 @@ fn _pydantic_core(_py: Python, m: &PyModule) -> PyResult<()> {
3939
m.add_class::<ValidationError>()?;
4040
m.add_class::<SchemaError>()?;
4141
m.add_class::<PydanticCustomError>()?;
42-
m.add_class::<PydanticErrorKind>()?;
42+
m.add_class::<PydanticKindError>()?;
4343
m.add_class::<PydanticOmit>()?;
4444
m.add_function(wrap_pyfunction!(list_all_errors, m)?)?;
4545
Ok(())

src/validators/function.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ use pyo3::types::{PyAny, PyDict};
55

66
use crate::build_tools::{py_error, SchemaDict};
77
use crate::errors::{
8-
ErrorKind, LocItem, PydanticCustomError, PydanticErrorKind, PydanticOmit, ValError, ValResult, ValidationError,
8+
ErrorKind, LocItem, PydanticCustomError, PydanticKindError, PydanticOmit, ValError, ValResult, ValidationError,
99
};
1010
use crate::input::Input;
1111
use crate::questions::Question;
@@ -285,7 +285,7 @@ pub fn convert_err<'a>(py: Python<'a>, err: PyErr, input: &'a impl Input<'a>) ->
285285
if err.is_instance_of::<PyValueError>(py) {
286286
if let Ok(pydantic_value_error) = err.value(py).extract::<PydanticCustomError>() {
287287
pydantic_value_error.into_val_error(input)
288-
} else if let Ok(pydantic_error_kind) = err.value(py).extract::<PydanticErrorKind>() {
288+
} else if let Ok(pydantic_error_kind) = err.value(py).extract::<PydanticKindError>() {
289289
pydantic_error_kind.into_val_error(input)
290290
} else if let Ok(validation_error) = err.value(py).extract::<ValidationError>() {
291291
validation_error.into_py(py)

src/validators/union.rs

Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,26 @@ use pyo3::types::{PyDict, PyList, PyString};
77

88
use ahash::AHashMap;
99

10-
use crate::build_tools::{is_strict, schema_or_config, SchemaDict};
11-
use crate::errors::{ErrorKind, PydanticCustomError, ValError, ValLineError, ValResult};
10+
use crate::build_tools::{is_strict, py_error, schema_or_config, SchemaDict};
11+
use crate::errors::{ErrorKind, PydanticCustomError, PydanticKindError, ValError, ValLineError, ValResult};
1212
use crate::input::{GenericMapping, Input};
1313
use crate::lookup_key::LookupKey;
1414
use crate::questions::Question;
1515
use crate::recursion_guard::RecursionGuard;
1616

1717
use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
1818

19+
#[derive(Debug, Clone)]
20+
enum CustomError {
21+
Custom(PydanticCustomError),
22+
Kind(PydanticKindError),
23+
None,
24+
}
25+
1926
#[derive(Debug, Clone)]
2027
pub struct UnionValidator {
2128
choices: Vec<CombinedValidator>,
22-
custom_error: Option<PydanticCustomError>,
29+
custom_error: CustomError,
2330
strict: bool,
2431
name: String,
2532
}
@@ -51,15 +58,29 @@ impl BuildValidator for UnionValidator {
5158
}
5259
}
5360

54-
fn get_custom_error(py: Python, schema: &PyDict) -> PyResult<Option<PydanticCustomError>> {
55-
match schema.get_as::<&PyDict>(intern!(py, "custom_error"))? {
56-
Some(custom_error) => Ok(Some(PydanticCustomError::py_new(
61+
fn get_custom_error(py: Python, schema: &PyDict) -> PyResult<CustomError> {
62+
let custom_error: &PyDict = match schema.get_as(intern!(py, "custom_error"))? {
63+
Some(ce) => ce,
64+
None => return Ok(CustomError::None),
65+
};
66+
let kind: String = custom_error.get_as_req(intern!(py, "kind"))?;
67+
let context: Option<&PyDict> = custom_error.get_as(intern!(py, "context"))?;
68+
69+
if ErrorKind::valid_kind(py, &kind) {
70+
if custom_error.contains(intern!(py, "message"))? {
71+
py_error!("custom_error.message should not be provided if kind matches a known error")
72+
} else {
73+
let error = PydanticKindError::py_new(py, &kind, context)?;
74+
Ok(CustomError::Kind(error))
75+
}
76+
} else {
77+
let error = PydanticCustomError::py_new(
5778
py,
58-
custom_error.get_as_req::<String>(intern!(py, "kind"))?,
79+
kind,
5980
custom_error.get_as_req::<String>(intern!(py, "message"))?,
60-
None,
61-
))),
62-
None => Ok(None),
81+
context,
82+
);
83+
Ok(CustomError::Custom(error))
6384
}
6485
}
6586

@@ -72,8 +93,11 @@ impl UnionValidator {
7293
if let Some(errors) = errors {
7394
ValError::LineErrors(errors)
7495
} else {
75-
let value_error = self.custom_error.as_ref().unwrap();
76-
value_error.clone().into_val_error(input)
96+
match self.custom_error {
97+
CustomError::Kind(ref kind_error) => kind_error.clone().into_val_error(input),
98+
CustomError::Custom(ref custom_error) => custom_error.clone().into_val_error(input),
99+
CustomError::None => unreachable!(),
100+
}
77101
}
78102
}
79103
}
@@ -89,8 +113,8 @@ impl Validator for UnionValidator {
89113
) -> ValResult<'data, PyObject> {
90114
if extra.strict.unwrap_or(self.strict) {
91115
let mut errors: Option<Vec<ValLineError>> = match self.custom_error {
92-
Some(_) => None,
93-
None => Some(Vec::with_capacity(self.choices.len())),
116+
CustomError::None => Some(Vec::with_capacity(self.choices.len())),
117+
_ => None,
94118
};
95119
let strict_extra = extra.as_strict();
96120

@@ -124,8 +148,8 @@ impl Validator for UnionValidator {
124148
}
125149

126150
let mut errors: Option<Vec<ValLineError>> = match self.custom_error {
127-
Some(_) => None,
128-
None => Some(Vec::with_capacity(self.choices.len())),
151+
CustomError::None => Some(Vec::with_capacity(self.choices.len())),
152+
_ => None,
129153
};
130154

131155
// 2nd pass: check if the value can be coerced into one of the Union types, e.g. use validate
@@ -200,7 +224,7 @@ pub struct TaggedUnionValidator {
200224
discriminator: Discriminator,
201225
from_attributes: bool,
202226
strict: bool,
203-
custom_error: Option<PydanticCustomError>,
227+
custom_error: CustomError,
204228
tags_repr: String,
205229
discriminator_repr: String,
206230
name: String,
@@ -386,30 +410,32 @@ impl TaggedUnionValidator {
386410
Ok(res) => Ok(res),
387411
Err(err) => Err(err.with_outer_location(tag.as_ref().into())),
388412
}
389-
} else if let Some(ref custom_error) = self.custom_error {
390-
Err(custom_error.clone().into_val_error(input))
391413
} else {
392-
Err(ValError::new(
393-
ErrorKind::UnionTagInvalid {
394-
discriminator: self.discriminator_repr.clone(),
395-
tag: tag.to_string(),
396-
expected_tags: self.tags_repr.clone(),
397-
},
398-
input,
399-
))
414+
match self.custom_error {
415+
CustomError::Kind(ref kind_error) => Err(kind_error.clone().into_val_error(input)),
416+
CustomError::Custom(ref custom_error) => Err(custom_error.clone().into_val_error(input)),
417+
CustomError::None => Err(ValError::new(
418+
ErrorKind::UnionTagInvalid {
419+
discriminator: self.discriminator_repr.clone(),
420+
tag: tag.to_string(),
421+
expected_tags: self.tags_repr.clone(),
422+
},
423+
input,
424+
)),
425+
}
400426
}
401427
}
402428

403429
fn tag_not_found<'s, 'data>(&'s self, input: &'data impl Input<'data>) -> ValError<'data> {
404-
if let Some(ref custom_error) = self.custom_error {
405-
custom_error.clone().into_val_error(input)
406-
} else {
407-
ValError::new(
430+
match self.custom_error {
431+
CustomError::Kind(ref kind_error) => kind_error.clone().into_val_error(input),
432+
CustomError::Custom(ref custom_error) => custom_error.clone().into_val_error(input),
433+
CustomError::None => ValError::new(
408434
ErrorKind::UnionTagNotFound {
409435
discriminator: self.discriminator_repr.clone(),
410436
},
411437
input,
412-
)
438+
),
413439
}
414440
}
415441
}

tests/test_schema_functions.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,14 @@ def test_schema_functions(function, args_kwargs, expected_schema):
213213

214214

215215
def test_invalid_custom_error():
216-
s = core_schema.union_schema({'type': 'int'}, {'type': 'str'}, custom_error_kind='foobar')
217-
with pytest.raises(SchemaError, match=r'custom_error \-> message\s+Input should be a valid string'):
216+
s = core_schema.union_schema({'type': 'int'}, {'type': 'str'}, custom_error_message='foobar')
217+
with pytest.raises(SchemaError, match=r'custom_error \-> kind\s+Field required'):
218+
SchemaValidator(s)
219+
220+
221+
def test_invalid_custom_error_kind():
222+
s = core_schema.union_schema(
223+
{'type': 'int'}, {'type': 'str'}, custom_error_kind='finite_number', custom_error_message='x'
224+
)
225+
with pytest.raises(SchemaError, match='custom_error.message should not be provided if kind matches a known error'):
218226
SchemaValidator(s)

0 commit comments

Comments
 (0)