@@ -8,7 +8,7 @@ use pyo3::types::{PyDict, PyList, PyString};
8
8
use ahash:: AHashMap ;
9
9
10
10
use crate :: build_tools:: { is_strict, schema_or_config, SchemaDict } ;
11
- use crate :: errors:: { ErrorKind , ValError , ValLineError , ValResult } ;
11
+ use crate :: errors:: { ErrorKind , PydanticValueError , ValError , ValLineError , ValResult } ;
12
12
use crate :: input:: { GenericMapping , Input } ;
13
13
use crate :: lookup_key:: LookupKey ;
14
14
use crate :: questions:: Question ;
@@ -19,6 +19,7 @@ use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Ex
19
19
#[ derive( Debug , Clone ) ]
20
20
pub struct UnionValidator {
21
21
choices : Vec < CombinedValidator > ,
22
+ custom_error : Option < PydanticValueError > ,
22
23
strict : bool ,
23
24
name : String ,
24
25
}
@@ -31,8 +32,9 @@ impl BuildValidator for UnionValidator {
31
32
config : Option < & PyDict > ,
32
33
build_context : & mut BuildContext ,
33
34
) -> PyResult < CombinedValidator > {
35
+ let py = schema. py ( ) ;
34
36
let choices: Vec < CombinedValidator > = schema
35
- . get_as_req :: < & PyList > ( intern ! ( schema . py ( ) , "choices" ) ) ?
37
+ . get_as_req :: < & PyList > ( intern ! ( py , "choices" ) ) ?
36
38
. iter ( )
37
39
. map ( |choice| build_validator ( choice, config, build_context) )
38
40
. collect :: < PyResult < Vec < CombinedValidator > > > ( ) ?;
@@ -41,13 +43,41 @@ impl BuildValidator for UnionValidator {
41
43
42
44
Ok ( Self {
43
45
choices,
46
+ custom_error : get_custom_error ( py, schema) ?,
44
47
strict : is_strict ( schema, config) ?,
45
48
name : format ! ( "{}[{}]" , Self :: EXPECTED_TYPE , descr) ,
46
49
}
47
50
. into ( ) )
48
51
}
49
52
}
50
53
54
+ fn get_custom_error ( py : Python , schema : & PyDict ) -> PyResult < Option < PydanticValueError > > {
55
+ match schema. get_as :: < & PyDict > ( intern ! ( py, "custom_error" ) ) ? {
56
+ Some ( custom_error) => Ok ( Some ( PydanticValueError :: py_new (
57
+ py,
58
+ custom_error. get_as_req :: < String > ( intern ! ( py, "kind" ) ) ?,
59
+ custom_error. get_as_req :: < String > ( intern ! ( py, "message" ) ) ?,
60
+ None ,
61
+ ) ) ) ,
62
+ None => Ok ( None ) ,
63
+ }
64
+ }
65
+
66
+ impl UnionValidator {
67
+ fn or_custom_error < ' s , ' data > (
68
+ & ' s self ,
69
+ errors : Option < Vec < ValLineError < ' data > > > ,
70
+ input : & ' data impl Input < ' data > ,
71
+ ) -> ValError < ' data > {
72
+ if let Some ( errors) = errors {
73
+ ValError :: LineErrors ( errors)
74
+ } else {
75
+ let value_error = self . custom_error . as_ref ( ) . unwrap ( ) ;
76
+ value_error. clone ( ) . into_val_error ( input)
77
+ }
78
+ }
79
+ }
80
+
51
81
impl Validator for UnionValidator {
52
82
fn validate < ' s , ' data > (
53
83
& ' s self ,
@@ -58,7 +88,10 @@ impl Validator for UnionValidator {
58
88
recursion_guard : & ' s mut RecursionGuard ,
59
89
) -> ValResult < ' data , PyObject > {
60
90
if extra. strict . unwrap_or ( self . strict ) {
61
- let mut errors: Vec < ValLineError > = Vec :: with_capacity ( self . choices . len ( ) ) ;
91
+ let mut errors: Option < Vec < ValLineError > > = match self . custom_error {
92
+ Some ( _) => None ,
93
+ None => Some ( Vec :: with_capacity ( self . choices . len ( ) ) ) ,
94
+ } ;
62
95
let strict_extra = extra. as_strict ( ) ;
63
96
64
97
for validator in & self . choices {
@@ -67,14 +100,16 @@ impl Validator for UnionValidator {
67
100
otherwise => return otherwise,
68
101
} ;
69
102
70
- errors. extend (
71
- line_errors
72
- . into_iter ( )
73
- . map ( |err| err. with_outer_location ( validator. get_name ( ) . into ( ) ) ) ,
74
- ) ;
103
+ if let Some ( ref mut errors) = errors {
104
+ errors. extend (
105
+ line_errors
106
+ . into_iter ( )
107
+ . map ( |err| err. with_outer_location ( validator. get_name ( ) . into ( ) ) ) ,
108
+ ) ;
109
+ }
75
110
}
76
111
77
- Err ( ValError :: LineErrors ( errors) )
112
+ Err ( self . or_custom_error ( errors, input ) )
78
113
} else {
79
114
// 1st pass: check if the value is an exact instance of one of the Union types,
80
115
// e.g. use validate in strict mode
@@ -88,7 +123,10 @@ impl Validator for UnionValidator {
88
123
return res;
89
124
}
90
125
91
- let mut errors: Vec < ValLineError > = Vec :: with_capacity ( self . choices . len ( ) ) ;
126
+ let mut errors: Option < Vec < ValLineError > > = match self . custom_error {
127
+ Some ( _) => None ,
128
+ None => Some ( Vec :: with_capacity ( self . choices . len ( ) ) ) ,
129
+ } ;
92
130
93
131
// 2nd pass: check if the value can be coerced into one of the Union types, e.g. use validate
94
132
for validator in & self . choices {
@@ -97,14 +135,16 @@ impl Validator for UnionValidator {
97
135
success => return success,
98
136
} ;
99
137
100
- errors. extend (
101
- line_errors
102
- . into_iter ( )
103
- . map ( |err| err. with_outer_location ( validator. get_name ( ) . into ( ) ) ) ,
104
- ) ;
138
+ if let Some ( ref mut errors) = errors {
139
+ errors. extend (
140
+ line_errors
141
+ . into_iter ( )
142
+ . map ( |err| err. with_outer_location ( validator. get_name ( ) . into ( ) ) ) ,
143
+ ) ;
144
+ }
105
145
}
106
146
107
- Err ( ValError :: LineErrors ( errors) )
147
+ Err ( self . or_custom_error ( errors, input ) )
108
148
}
109
149
}
110
150
@@ -160,6 +200,7 @@ pub struct TaggedUnionValidator {
160
200
discriminator : Discriminator ,
161
201
from_attributes : bool ,
162
202
strict : bool ,
203
+ custom_error : Option < PydanticValueError > ,
163
204
tags_repr : String ,
164
205
discriminator_repr : String ,
165
206
name : String ,
@@ -206,6 +247,7 @@ impl BuildValidator for TaggedUnionValidator {
206
247
discriminator,
207
248
from_attributes,
208
249
strict : is_strict ( schema, config) ?,
250
+ custom_error : get_custom_error ( py, schema) ?,
209
251
tags_repr,
210
252
discriminator_repr,
211
253
name : format ! ( "{}[{}]" , Self :: EXPECTED_TYPE , descr) ,
@@ -341,6 +383,8 @@ impl TaggedUnionValidator {
341
383
Ok ( res) => Ok ( res) ,
342
384
Err ( err) => Err ( err. with_outer_location ( tag. as_ref ( ) . into ( ) ) ) ,
343
385
}
386
+ } else if let Some ( ref custom_error) = self . custom_error {
387
+ Err ( custom_error. clone ( ) . into_val_error ( input) )
344
388
} else {
345
389
Err ( ValError :: new (
346
390
ErrorKind :: UnionTagInvalid {
@@ -354,11 +398,15 @@ impl TaggedUnionValidator {
354
398
}
355
399
356
400
fn tag_not_found < ' s , ' data > ( & ' s self , input : & ' data impl Input < ' data > ) -> ValError < ' data > {
357
- ValError :: new (
358
- ErrorKind :: UnionTagNotFound {
359
- discriminator : self . discriminator_repr . clone ( ) ,
360
- } ,
361
- input,
362
- )
401
+ if let Some ( ref custom_error) = self . custom_error {
402
+ custom_error. clone ( ) . into_val_error ( input)
403
+ } else {
404
+ ValError :: new (
405
+ ErrorKind :: UnionTagNotFound {
406
+ discriminator : self . discriminator_repr . clone ( ) ,
407
+ } ,
408
+ input,
409
+ )
410
+ }
363
411
}
364
412
}
0 commit comments