Skip to content

Commit be7b07e

Browse files
authored
Change how arguments are defined (#304)
* change how arguments are defined * tests for missing and extra dict keys with __args__/__kwargs__
1 parent fc42f4f commit be7b07e

File tree

11 files changed

+353
-239
lines changed

11 files changed

+353
-239
lines changed

pydantic_core/_pydantic_core.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,14 @@ else:
1111

1212
__all__ = (
1313
'__version__',
14+
'build_profile',
1415
'SchemaValidator',
1516
'SchemaError',
1617
'ValidationError',
1718
'PydanticCustomError',
1819
'PydanticKindError',
1920
'PydanticOmit',
21+
'list_all_errors',
2022
)
2123
__version__: str
2224
build_profile: str

pydantic_core/core_schema.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1138,6 +1138,8 @@ def json_schema(schema: CoreSchema | None = None, *, ref: str | None = None, ext
11381138
'union_tag_invalid',
11391139
'union_tag_not_found',
11401140
'arguments_type',
1141+
'positional_arguments_type',
1142+
'keyword_arguments_type',
11411143
'unexpected_keyword_argument',
11421144
'missing_keyword_argument',
11431145
'unexpected_positional_argument',

src/errors/kinds.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,8 +301,12 @@ pub enum ErrorKind {
301301
},
302302
// ---------------------
303303
// argument errors
304-
#[strum(message = "Arguments must be a tuple of (positional arguments, keyword arguments) or a plain dict")]
304+
#[strum(message = "Arguments must be a tuple, list or a dictionary")]
305305
ArgumentsType,
306+
#[strum(message = "Positional arguments must be a list or tuple")]
307+
PositionalArgumentsType,
308+
#[strum(message = "Keyword arguments must be a dictionary")]
309+
KeywordArgumentsType,
306310
#[strum(message = "Unexpected keyword argument")]
307311
UnexpectedKeywordArgument,
308312
#[strum(message = "Missing required keyword argument")]

src/input/input_json.rs

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use pyo3::prelude::*;
22

3-
use crate::errors::{ErrorKind, InputValue, LocItem, ValError, ValResult};
3+
use crate::errors::{ErrorKind, InputValue, LocItem, ValError, ValLineError, ValResult};
44

55
use super::datetime::{
66
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, float_as_datetime, float_as_duration,
@@ -56,24 +56,44 @@ impl<'a> Input<'a> for JsonInput {
5656

5757
fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
5858
match self {
59-
JsonInput::Object(kwargs) => Ok(JsonArgs::new(None, Some(kwargs)).into()),
60-
JsonInput::Array(array) => {
61-
if array.len() != 2 {
62-
Err(ValError::new(ErrorKind::ArgumentsType, self))
63-
} else {
64-
let args = match unsafe { array.get_unchecked(0) } {
65-
JsonInput::Null => None,
66-
JsonInput::Array(args) => Some(args.as_slice()),
67-
_ => return Err(ValError::new(ErrorKind::ArgumentsType, self)),
68-
};
69-
let kwargs = match unsafe { array.get_unchecked(1) } {
70-
JsonInput::Null => None,
71-
JsonInput::Object(kwargs) => Some(kwargs),
72-
_ => return Err(ValError::new(ErrorKind::ArgumentsType, self)),
73-
};
74-
Ok(JsonArgs::new(args, kwargs).into())
59+
JsonInput::Object(object) => {
60+
if let Some(args) = object.get("__args__") {
61+
if let Some(kwargs) = object.get("__kwargs__") {
62+
// we only try this logic if there are only these two items in the dict
63+
if object.len() == 2 {
64+
let args = match args {
65+
JsonInput::Null => Ok(None),
66+
JsonInput::Array(args) => Ok(Some(args.as_slice())),
67+
_ => Err(ValLineError::new_with_loc(
68+
ErrorKind::PositionalArgumentsType,
69+
args,
70+
"__args__",
71+
)),
72+
};
73+
let kwargs = match kwargs {
74+
JsonInput::Null => Ok(None),
75+
JsonInput::Object(kwargs) => Ok(Some(kwargs)),
76+
_ => Err(ValLineError::new_with_loc(
77+
ErrorKind::KeywordArgumentsType,
78+
kwargs,
79+
"__kwargs__",
80+
)),
81+
};
82+
83+
return match (args, kwargs) {
84+
(Ok(args), Ok(kwargs)) => Ok(JsonArgs::new(args, kwargs).into()),
85+
(Err(args_error), Err(kwargs_error)) => {
86+
return Err(ValError::LineErrors(vec![args_error, kwargs_error]))
87+
}
88+
(Err(error), _) => Err(ValError::LineErrors(vec![error])),
89+
(_, Err(error)) => Err(ValError::LineErrors(vec![error])),
90+
};
91+
}
92+
}
7593
}
94+
Ok(JsonArgs::new(None, Some(object)).into())
7695
}
96+
JsonInput::Array(array) => Ok(JsonArgs::new(Some(array), None).into()),
7797
_ => Err(ValError::new(ErrorKind::ArgumentsType, self)),
7898
}
7999
}

src/input/input_python.rs

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use pyo3::types::{
1212
use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues};
1313
use pyo3::{ffi, intern, AsPyPointer, PyTypeInfo};
1414

15-
use crate::errors::{py_err_string, ErrorKind, InputValue, LocItem, ValError, ValResult};
15+
use crate::errors::{py_err_string, ErrorKind, InputValue, LocItem, ValError, ValLineError, ValResult};
1616

1717
use super::datetime::{
1818
bytes_as_date, bytes_as_datetime, bytes_as_time, bytes_as_timedelta, date_as_datetime, float_as_datetime,
@@ -114,26 +114,54 @@ impl<'a> Input<'a> for PyAny {
114114
}
115115

116116
fn validate_args(&'a self) -> ValResult<'a, GenericArguments<'a>> {
117-
if let Ok(kwargs) = self.cast_as::<PyDict>() {
118-
Ok(PyArgs::new(None, Some(kwargs)).into())
119-
} else if let Ok((args, kwargs)) = self.extract::<(&PyAny, &PyAny)>() {
120-
let args = if let Ok(tuple) = args.cast_as::<PyTuple>() {
121-
Some(tuple)
122-
} else if args.is_none() {
123-
None
124-
} else if let Ok(list) = args.cast_as::<PyList>() {
125-
Some(PyTuple::new(self.py(), list.iter().collect::<Vec<_>>()))
126-
} else {
127-
return Err(ValError::new(ErrorKind::ArgumentsType, self));
128-
};
129-
let kwargs = if let Ok(dict) = kwargs.cast_as::<PyDict>() {
130-
Some(dict)
131-
} else if kwargs.is_none() {
132-
None
133-
} else {
134-
return Err(ValError::new(ErrorKind::ArgumentsType, self));
135-
};
136-
Ok(PyArgs::new(args, kwargs).into())
117+
if let Ok(dict) = self.cast_as::<PyDict>() {
118+
if let Some(args) = dict.get_item("__args__") {
119+
if let Some(kwargs) = dict.get_item("__kwargs__") {
120+
// we only try this logic if there are only these two items in the dict
121+
if dict.len() == 2 {
122+
let args = if let Ok(tuple) = args.cast_as::<PyTuple>() {
123+
Ok(Some(tuple))
124+
} else if args.is_none() {
125+
Ok(None)
126+
} else if let Ok(list) = args.cast_as::<PyList>() {
127+
Ok(Some(PyTuple::new(self.py(), list.iter().collect::<Vec<_>>())))
128+
} else {
129+
Err(ValLineError::new_with_loc(
130+
ErrorKind::PositionalArgumentsType,
131+
args,
132+
"__args__",
133+
))
134+
};
135+
136+
let kwargs = if let Ok(dict) = kwargs.cast_as::<PyDict>() {
137+
Ok(Some(dict))
138+
} else if kwargs.is_none() {
139+
Ok(None)
140+
} else {
141+
Err(ValLineError::new_with_loc(
142+
ErrorKind::KeywordArgumentsType,
143+
kwargs,
144+
"__kwargs__",
145+
))
146+
};
147+
148+
return match (args, kwargs) {
149+
(Ok(args), Ok(kwargs)) => Ok(PyArgs::new(args, kwargs).into()),
150+
(Err(args_error), Err(kwargs_error)) => {
151+
Err(ValError::LineErrors(vec![args_error, kwargs_error]))
152+
}
153+
(Err(error), _) => Err(ValError::LineErrors(vec![error])),
154+
(_, Err(error)) => Err(ValError::LineErrors(vec![error])),
155+
};
156+
}
157+
}
158+
}
159+
Ok(PyArgs::new(None, Some(dict)).into())
160+
} else if let Ok(tuple) = self.cast_as::<PyTuple>() {
161+
Ok(PyArgs::new(Some(tuple), None).into())
162+
} else if let Ok(list) = self.cast_as::<PyList>() {
163+
let tuple = PyTuple::new(self.py(), list.iter().collect::<Vec<_>>());
164+
Ok(PyArgs::new(Some(tuple), None).into())
137165
} else {
138166
Err(ValError::new(ErrorKind::ArgumentsType, self))
139167
}

tests/benchmarks/test_micro_benchmarks.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -889,9 +889,12 @@ def test_arguments(benchmark):
889889
],
890890
}
891891
)
892-
assert v.validate_python(((1, 'a', 'true'), {'b': 'bb', 'c': 3})) == ((1, 'a', True), {'b': 'bb', 'c': 3})
892+
assert v.validate_python({'__args__': (1, 'a', 'true'), '__kwargs__': {'b': 'bb', 'c': 3}}) == (
893+
(1, 'a', True),
894+
{'b': 'bb', 'c': 3},
895+
)
893896

894-
benchmark(v.validate_python, ((1, 'a', 'true'), {'b': 'bb', 'c': 3}))
897+
benchmark(v.validate_python, {'__args__': (1, 'a', 'true'), '__kwargs__': {'b': 'bb', 'c': 3}})
895898

896899

897900
@pytest.mark.benchmark(group='defaults')

0 commit comments

Comments
 (0)