1
1
use pyo3:: exceptions:: { PyTypeError , PyValueError } ;
2
2
use pyo3:: intern;
3
3
use pyo3:: sync:: GILOnceCell ;
4
- use pyo3:: types:: { IntoPyDict , PyDict , PyTuple , PyType } ;
4
+ use pyo3:: types:: { IntoPyDict , PyDict , PyString , PyTuple , PyType } ;
5
5
use pyo3:: { prelude:: * , PyTypeInfo } ;
6
6
7
7
use crate :: build_tools:: { is_strict, schema_or_config_same} ;
@@ -28,6 +28,18 @@ pub fn get_decimal_type(py: Python) -> &Bound<'_, PyType> {
28
28
. bind ( py)
29
29
}
30
30
31
+ fn validate_as_decimal ( py : Python , schema : & Bound < ' _ , PyDict > , key : & str ) -> PyResult < Option < Py < PyAny > > > {
32
+ match schema. get_as :: < Bound < ' _ , PyAny > > ( & PyString :: new ( py, key) ) ? {
33
+ Some ( value) => match value. validate_decimal ( false , py) {
34
+ Ok ( v) => Ok ( Some ( v. into_inner ( ) . unbind ( ) ) ) ,
35
+ Err ( _) => Err ( PyValueError :: new_err ( format ! (
36
+ "'{key}' must be coercible to a Decimal instance" ,
37
+ ) ) ) ,
38
+ } ,
39
+ None => Ok ( None ) ,
40
+ }
41
+ }
42
+
31
43
#[ derive( Debug , Clone ) ]
32
44
pub struct DecimalValidator {
33
45
strict : bool ,
@@ -50,6 +62,7 @@ impl BuildValidator for DecimalValidator {
50
62
_definitions : & mut DefinitionsBuilder < CombinedValidator > ,
51
63
) -> PyResult < CombinedValidator > {
52
64
let py = schema. py ( ) ;
65
+
53
66
let allow_inf_nan = schema_or_config_same ( schema, config, intern ! ( py, "allow_inf_nan" ) ) ?. unwrap_or ( false ) ;
54
67
let decimal_places = schema. get_as ( intern ! ( py, "decimal_places" ) ) ?;
55
68
let max_digits = schema. get_as ( intern ! ( py, "max_digits" ) ) ?;
@@ -58,16 +71,17 @@ impl BuildValidator for DecimalValidator {
58
71
"allow_inf_nan=True cannot be used with max_digits or decimal_places" ,
59
72
) ) ;
60
73
}
74
+
61
75
Ok ( Self {
62
76
strict : is_strict ( schema, config) ?,
63
77
allow_inf_nan,
64
78
check_digits : decimal_places. is_some ( ) || max_digits. is_some ( ) ,
65
79
decimal_places,
66
- multiple_of : schema . get_as ( intern ! ( py, "multiple_of" ) ) ?,
67
- le : schema . get_as ( intern ! ( py, "le" ) ) ?,
68
- lt : schema . get_as ( intern ! ( py, "lt" ) ) ?,
69
- ge : schema . get_as ( intern ! ( py, "ge" ) ) ?,
70
- gt : schema . get_as ( intern ! ( py, "gt" ) ) ?,
80
+ multiple_of : validate_as_decimal ( py, schema , "multiple_of" ) ?,
81
+ le : validate_as_decimal ( py, schema , "le" ) ?,
82
+ lt : validate_as_decimal ( py, schema , "lt" ) ?,
83
+ ge : validate_as_decimal ( py, schema , "ge" ) ?,
84
+ gt : validate_as_decimal ( py, schema , "gt" ) ?,
71
85
max_digits,
72
86
}
73
87
. into ( ) )
0 commit comments