@@ -7,19 +7,26 @@ use pyo3::types::{PyDict, PyList, PyString};
77
88use ahash:: AHashMap ;
99
10- use crate :: build_tools:: { is_strict, schema_or_config, SchemaDict } ;
11- use crate :: errors:: { ErrorKind , PydanticCustomError , ValError , ValLineError , ValResult } ;
10+ use crate :: build_tools:: { is_strict, py_error , schema_or_config, SchemaDict } ;
11+ use crate :: errors:: { ErrorKind , PydanticCustomError , PydanticKindError , ValError , ValLineError , ValResult } ;
1212use crate :: input:: { GenericMapping , Input } ;
1313use crate :: lookup_key:: LookupKey ;
1414use crate :: questions:: Question ;
1515use crate :: recursion_guard:: RecursionGuard ;
1616
1717use super :: { build_validator, BuildContext , BuildValidator , CombinedValidator , Extra , Validator } ;
1818
19+ #[ derive( Debug , Clone ) ]
20+ enum CustomError {
21+ Custom ( PydanticCustomError ) ,
22+ Kind ( PydanticKindError ) ,
23+ None ,
24+ }
25+
1926#[ derive( Debug , Clone ) ]
2027pub struct UnionValidator {
2128 choices : Vec < CombinedValidator > ,
22- custom_error : Option < PydanticCustomError > ,
29+ custom_error : CustomError ,
2330 strict : bool ,
2431 name : String ,
2532}
@@ -51,15 +58,29 @@ impl BuildValidator for UnionValidator {
5158 }
5259}
5360
54- fn get_custom_error ( py : Python , schema : & PyDict ) -> PyResult < Option < PydanticCustomError > > {
55- match schema. get_as :: < & PyDict > ( intern ! ( py, "custom_error" ) ) ? {
56- Some ( custom_error) => Ok ( Some ( PydanticCustomError :: py_new (
61+ fn get_custom_error ( py : Python , schema : & PyDict ) -> PyResult < CustomError > {
62+ let custom_error: & PyDict = match schema. get_as ( intern ! ( py, "custom_error" ) ) ? {
63+ Some ( ce) => ce,
64+ None => return Ok ( CustomError :: None ) ,
65+ } ;
66+ let kind: String = custom_error. get_as_req ( intern ! ( py, "kind" ) ) ?;
67+ let context: Option < & PyDict > = custom_error. get_as ( intern ! ( py, "context" ) ) ?;
68+
69+ if ErrorKind :: valid_kind ( py, & kind) {
70+ if custom_error. contains ( intern ! ( py, "message" ) ) ? {
71+ py_error ! ( "custom_error.message should not be provided if kind matches a known error" )
72+ } else {
73+ let error = PydanticKindError :: py_new ( py, & kind, context) ?;
74+ Ok ( CustomError :: Kind ( error) )
75+ }
76+ } else {
77+ let error = PydanticCustomError :: py_new (
5778 py,
58- custom_error . get_as_req :: < String > ( intern ! ( py , " kind" ) ) ? ,
79+ kind,
5980 custom_error. get_as_req :: < String > ( intern ! ( py, "message" ) ) ?,
60- None ,
61- ) ) ) ,
62- None => Ok ( None ) ,
81+ context ,
82+ ) ;
83+ Ok ( CustomError :: Custom ( error ) )
6384 }
6485}
6586
@@ -72,8 +93,11 @@ impl UnionValidator {
7293 if let Some ( errors) = errors {
7394 ValError :: LineErrors ( errors)
7495 } else {
75- let value_error = self . custom_error . as_ref ( ) . unwrap ( ) ;
76- value_error. clone ( ) . into_val_error ( input)
96+ match self . custom_error {
97+ CustomError :: Kind ( ref kind_error) => kind_error. clone ( ) . into_val_error ( input) ,
98+ CustomError :: Custom ( ref custom_error) => custom_error. clone ( ) . into_val_error ( input) ,
99+ CustomError :: None => unreachable ! ( ) ,
100+ }
77101 }
78102 }
79103}
@@ -89,8 +113,8 @@ impl Validator for UnionValidator {
89113 ) -> ValResult < ' data , PyObject > {
90114 if extra. strict . unwrap_or ( self . strict ) {
91115 let mut errors: Option < Vec < ValLineError > > = match self . custom_error {
92- Some ( _ ) => None ,
93- None => Some ( Vec :: with_capacity ( self . choices . len ( ) ) ) ,
116+ CustomError :: None => Some ( Vec :: with_capacity ( self . choices . len ( ) ) ) ,
117+ _ => None ,
94118 } ;
95119 let strict_extra = extra. as_strict ( ) ;
96120
@@ -124,8 +148,8 @@ impl Validator for UnionValidator {
124148 }
125149
126150 let mut errors: Option < Vec < ValLineError > > = match self . custom_error {
127- Some ( _ ) => None ,
128- None => Some ( Vec :: with_capacity ( self . choices . len ( ) ) ) ,
151+ CustomError :: None => Some ( Vec :: with_capacity ( self . choices . len ( ) ) ) ,
152+ _ => None ,
129153 } ;
130154
131155 // 2nd pass: check if the value can be coerced into one of the Union types, e.g. use validate
@@ -200,7 +224,7 @@ pub struct TaggedUnionValidator {
200224 discriminator : Discriminator ,
201225 from_attributes : bool ,
202226 strict : bool ,
203- custom_error : Option < PydanticCustomError > ,
227+ custom_error : CustomError ,
204228 tags_repr : String ,
205229 discriminator_repr : String ,
206230 name : String ,
@@ -386,30 +410,32 @@ impl TaggedUnionValidator {
386410 Ok ( res) => Ok ( res) ,
387411 Err ( err) => Err ( err. with_outer_location ( tag. as_ref ( ) . into ( ) ) ) ,
388412 }
389- } else if let Some ( ref custom_error) = self . custom_error {
390- Err ( custom_error. clone ( ) . into_val_error ( input) )
391413 } else {
392- Err ( ValError :: new (
393- ErrorKind :: UnionTagInvalid {
394- discriminator : self . discriminator_repr . clone ( ) ,
395- tag : tag. to_string ( ) ,
396- expected_tags : self . tags_repr . clone ( ) ,
397- } ,
398- input,
399- ) )
414+ match self . custom_error {
415+ CustomError :: Kind ( ref kind_error) => Err ( kind_error. clone ( ) . into_val_error ( input) ) ,
416+ CustomError :: Custom ( ref custom_error) => Err ( custom_error. clone ( ) . into_val_error ( input) ) ,
417+ CustomError :: None => Err ( ValError :: new (
418+ ErrorKind :: UnionTagInvalid {
419+ discriminator : self . discriminator_repr . clone ( ) ,
420+ tag : tag. to_string ( ) ,
421+ expected_tags : self . tags_repr . clone ( ) ,
422+ } ,
423+ input,
424+ ) ) ,
425+ }
400426 }
401427 }
402428
403429 fn tag_not_found < ' s , ' data > ( & ' s self , input : & ' data impl Input < ' data > ) -> ValError < ' data > {
404- if let Some ( ref custom_error ) = self . custom_error {
405- custom_error . clone ( ) . into_val_error ( input)
406- } else {
407- ValError :: new (
430+ match self . custom_error {
431+ CustomError :: Kind ( ref kind_error ) => kind_error . clone ( ) . into_val_error ( input) ,
432+ CustomError :: Custom ( ref custom_error ) => custom_error . clone ( ) . into_val_error ( input ) ,
433+ CustomError :: None => ValError :: new (
408434 ErrorKind :: UnionTagNotFound {
409435 discriminator : self . discriminator_repr . clone ( ) ,
410436 } ,
411437 input,
412- )
438+ ) ,
413439 }
414440 }
415441}
0 commit comments