Skip to content

Commit d1feaa3

Browse files
authored
Allow Sequence and tuple to is_instance validator (#299)
* allow Sequence and Tuple to is_instance validator * rename to input_is_instance * repr tests and type hints * add cls_repr argument to isinstance validator
1 parent a8a1f1a commit d1feaa3

File tree

8 files changed

+84
-31
lines changed

8 files changed

+84
-31
lines changed

pydantic_core/core_schema.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -350,23 +350,31 @@ def literal_schema(*expected: Any, ref: str | None = None, extra: Any = None) ->
350350

351351
class IsInstanceSchema(TypedDict, total=False):
352352
type: Required[Literal['is-instance']]
353-
cls: Required[Type[Any]]
353+
cls: Required[Any]
354+
cls_repr: str
354355
json_types: Set[JsonType]
355356
json_function: Callable[[Any], Any]
356357
ref: str
357358
extra: Any
358359

359360

360361
def is_instance_schema(
361-
cls: Type[Any],
362+
cls: Any,
362363
*,
363364
json_types: Set[JsonType] | None = None,
364365
json_function: Callable[[Any], Any] | None = None,
366+
cls_repr: str | None = None,
365367
ref: str | None = None,
366368
extra: Any = None,
367369
) -> IsInstanceSchema:
368370
return dict_not_none(
369-
type='is-instance', cls=cls, json_types=json_types, json_function=json_function, ref=ref, extra=extra
371+
type='is-instance',
372+
cls=cls,
373+
json_types=json_types,
374+
json_function=json_function,
375+
cls_repr=cls_repr,
376+
ref=ref,
377+
extra=extra,
370378
)
371379

372380

src/input/input_abstract.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ pub trait Input<'a>: fmt::Debug + ToPyObject {
4747
None
4848
}
4949

50-
fn is_instance(&self, class: &PyType, json_mask: u8) -> PyResult<bool>;
50+
// input_ prefix to differentiate from the function on PyAny
51+
fn input_is_instance(&self, class: &PyAny, json_mask: u8) -> PyResult<bool>;
5152

5253
fn callable(&self) -> bool {
5354
false

src/input/input_json.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use pyo3::prelude::*;
2-
use pyo3::types::PyType;
32

43
use crate::errors::{ErrorKind, InputValue, LocItem, ValError, ValResult};
54

@@ -38,7 +37,7 @@ impl<'a> Input<'a> for JsonInput {
3837
matches!(self, JsonInput::Null)
3938
}
4039

41-
fn is_instance(&self, _class: &PyType, json_mask: u8) -> PyResult<bool> {
40+
fn input_is_instance(&self, _class: &PyAny, json_mask: u8) -> PyResult<bool> {
4241
if json_mask == 0 {
4342
Ok(false)
4443
} else {
@@ -320,7 +319,7 @@ impl<'a> Input<'a> for String {
320319
false
321320
}
322321

323-
fn is_instance(&self, _class: &PyType, json_mask: u8) -> PyResult<bool> {
322+
fn input_is_instance(&self, _class: &PyAny, json_mask: u8) -> PyResult<bool> {
324323
if json_mask == 0 {
325324
Ok(false)
326325
} else {

src/input/input_python.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use pyo3::types::{
1010
};
1111
#[cfg(not(PyPy))]
1212
use pyo3::types::{PyDictItems, PyDictKeys, PyDictValues};
13-
use pyo3::{intern, AsPyPointer, PyTypeInfo};
13+
use pyo3::{ffi, intern, AsPyPointer, PyTypeInfo};
1414

1515
use crate::errors::{py_err_string, ErrorKind, InputValue, LocItem, ValError, ValResult};
1616

@@ -22,8 +22,8 @@ use super::datetime::{
2222
use super::input_abstract::InputType;
2323
use super::shared::{float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_int};
2424
use super::{
25-
py_string_str, repr_string, EitherBytes, EitherString, EitherTimedelta, GenericArguments, GenericCollection,
26-
GenericIterator, GenericMapping, Input, JsonInput, PyArgs,
25+
py_error_on_minusone, py_string_str, repr_string, EitherBytes, EitherString, EitherTimedelta, GenericArguments,
26+
GenericCollection, GenericIterator, GenericMapping, Input, JsonInput, PyArgs,
2727
};
2828

2929
/// Extract generators and deques into a `GenericCollection`
@@ -93,8 +93,12 @@ impl<'a> Input<'a> for PyAny {
9393
self.getattr(name).ok()
9494
}
9595

96-
fn is_instance(&self, class: &PyType, _json_mask: u8) -> PyResult<bool> {
97-
self.is_instance(class)
96+
fn input_is_instance(&self, class: &PyAny, _json_mask: u8) -> PyResult<bool> {
97+
// See PyO3/pyo3#2694 - we can't use `is_instance` here since it requires PyType,
98+
// and some check objects are not types, this logic is lifted from `is_instance` in PyO3
99+
let result = unsafe { ffi::PyObject_IsInstance(self.as_ptr(), class.as_ptr()) };
100+
py_error_on_minusone(self.py(), result)?;
101+
Ok(result == 1)
98102
}
99103

100104
fn callable(&self) -> bool {

src/input/mod.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::os::raw::c_int;
2+
13
use pyo3::prelude::*;
24

35
mod datetime;
@@ -19,3 +21,12 @@ pub use return_enums::{
1921
pub fn repr_string(v: &PyAny) -> PyResult<String> {
2022
v.repr()?.extract()
2123
}
24+
25+
// Defined here as it's not exported by pyo3
26+
pub fn py_error_on_minusone(py: Python<'_>, result: c_int) -> PyResult<()> {
27+
if result != -1 {
28+
Ok(())
29+
} else {
30+
Err(PyErr::fetch(py))
31+
}
32+
}

src/validators/is_instance.rs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use pyo3::intern;
22
use pyo3::prelude::*;
33
use pyo3::types::{PyDict, PySet, PyType};
44

5-
use crate::build_tools::SchemaDict;
5+
use crate::build_tools::{py_err, SchemaDict};
66
use crate::errors::{ErrorKind, ValError, ValResult};
77
use crate::input::{Input, JsonType};
88
use crate::recursion_guard::RecursionGuard;
@@ -12,7 +12,7 @@ use super::{BuildContext, BuildValidator, CombinedValidator, Extra, Validator};
1212

1313
#[derive(Debug, Clone)]
1414
pub struct IsInstanceValidator {
15-
class: Py<PyType>,
15+
class: PyObject,
1616
json_types: u8,
1717
json_function: Option<PyObject>,
1818
class_repr: String,
@@ -28,8 +28,23 @@ impl BuildValidator for IsInstanceValidator {
2828
_build_context: &mut BuildContext,
2929
) -> PyResult<CombinedValidator> {
3030
let py = schema.py();
31-
let class: &PyType = schema.get_as_req(intern!(py, "cls"))?;
32-
let class_repr = class.name()?.to_string();
31+
let cls_key = intern!(py, "cls");
32+
let class: &PyAny = schema.get_as_req(cls_key)?;
33+
34+
// test that class works with isinstance to avoid errors at call time, reuse cls_key since it doesn't
35+
// matter what object is being checked
36+
let test_value: &PyAny = cls_key.as_ref();
37+
if test_value.input_is_instance(class, 0).is_err() {
38+
return py_err!("'cls' must be valid as the first argument to 'isinstance'");
39+
}
40+
41+
let class_repr = match schema.get_as(intern!(py, "cls_repr"))? {
42+
Some(s) => s,
43+
None => match class.extract::<&PyType>() {
44+
Ok(t) => t.name()?.to_string(),
45+
Err(_) => class.repr()?.extract()?,
46+
},
47+
};
3348
let name = format!("{}[{}]", Self::EXPECTED_TYPE, class_repr);
3449
let json_types = match schema.get_as::<&PySet>(intern!(py, "json_types"))? {
3550
Some(s) => JsonType::combine(s)?,
@@ -55,7 +70,7 @@ impl Validator for IsInstanceValidator {
5570
_slots: &'data [CombinedValidator],
5671
_recursion_guard: &'s mut RecursionGuard,
5772
) -> ValResult<'data, PyObject> {
58-
match input.is_instance(self.class.as_ref(py), self.json_types)? {
73+
match input.input_is_instance(self.class.as_ref(py), self.json_types)? {
5974
true => {
6075
if input.get_type().is_json() {
6176
if let Some(ref json_function) = self.json_function {

src/validators/new_class.rs

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use std::cmp::Ordering;
2-
use std::os::raw::c_int;
32
use std::ptr::null_mut;
43

54
use pyo3::conversion::AsPyPointer;
@@ -10,7 +9,7 @@ use pyo3::{ffi, intern};
109

1110
use crate::build_tools::{py_err, SchemaDict};
1211
use crate::errors::{ErrorKind, ValError, ValResult};
13-
use crate::input::Input;
12+
use crate::input::{py_error_on_minusone, Input};
1413
use crate::questions::Question;
1514
use crate::recursion_guard::RecursionGuard;
1615

@@ -159,23 +158,13 @@ where
159158
let attr_name = attr_name.to_object(py);
160159
let value = value.to_object(py);
161160
unsafe {
162-
error_on_minusone(
161+
py_error_on_minusone(
163162
py,
164163
ffi::PyObject_GenericSetAttr(obj.as_ptr(), attr_name.as_ptr(), value.as_ptr()),
165164
)
166165
}
167166
}
168167

169-
// Defined here as it's not exported by pyo3
170-
#[inline]
171-
fn error_on_minusone(py: Python<'_>, result: c_int) -> PyResult<()> {
172-
if result != -1 {
173-
Ok(())
174-
} else {
175-
Err(PyErr::fetch(py))
176-
}
177-
}
178-
179168
fn build_config<'a>(
180169
py: Python<'a>,
181170
schema: &'a PyDict,

tests/validators/test_is_instance.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import typing
12
from collections import deque
23

34
import pytest
@@ -69,7 +70,7 @@ def test_is_instance_cases(schema_class, input_val, value):
6970

7071
@pytest.mark.parametrize('input_cls', [123, 'foo', Foo(), [], {1: 2}])
7172
def test_is_instance_invalid(input_cls):
72-
with pytest.raises(SchemaError, match="object cannot be converted to 'PyType'"):
73+
with pytest.raises(SchemaError, match="SchemaError: 'cls' must be valid as the first argument to 'isinstance'"):
7374
SchemaValidator({'type': 'is-instance', 'cls': input_cls})
7475

7576

@@ -197,3 +198,28 @@ def test_json_function():
197198
v.validate_python([1, 2, 3])
198199
with pytest.raises(ValidationError, match=r'Input should be an instance of deque \[kind=is_instance_of,'):
199200
v.validate_json('{"1": 2}')
201+
202+
203+
def test_is_instance_sequence():
204+
v = SchemaValidator(core_schema.is_instance_schema(typing.Sequence))
205+
assert v.isinstance_python(1) is False
206+
assert v.isinstance_python([1]) is True
207+
208+
with pytest.raises(ValidationError, match=r'Input should be an instance of typing.Sequence \[kind=is_instance_of,'):
209+
v.validate_python(1)
210+
211+
212+
def test_is_instance_tuple():
213+
v = SchemaValidator(core_schema.is_instance_schema((int, str)))
214+
assert v.isinstance_python(1) is True
215+
assert v.isinstance_python('foobar') is True
216+
assert v.isinstance_python([1]) is False
217+
with pytest.raises(ValidationError, match=r"Input should be an instance of \(<class 'int'>, <class 'str'>\)"):
218+
v.validate_python([1])
219+
220+
221+
def test_class_repr():
222+
v = SchemaValidator(core_schema.is_instance_schema(int, cls_repr='Foobar'))
223+
assert v.validate_python(1) == 1
224+
with pytest.raises(ValidationError, match=r'Input should be an instance of Foobar \[kind=is_instance_of,'):
225+
v.validate_python('1')

0 commit comments

Comments
 (0)