|
| 1 | +use pyo3::exceptions::PyTypeError; |
| 2 | +use pyo3::intern; |
| 3 | +use pyo3::prelude::*; |
| 4 | +use pyo3::types::{PyDict, PyTuple}; |
| 5 | + |
| 6 | +use crate::build_tools::SchemaDict; |
| 7 | +use crate::errors::ValResult; |
| 8 | +use crate::input::Input; |
| 9 | +use crate::recursion_guard::RecursionGuard; |
| 10 | + |
| 11 | +use super::{build_validator, BuildContext, BuildValidator, CombinedValidator, Extra, Validator}; |
| 12 | + |
| 13 | +#[derive(Debug, Clone)] |
| 14 | +pub struct CallValidator { |
| 15 | + function: PyObject, |
| 16 | + arguments_validator: Box<CombinedValidator>, |
| 17 | + return_validator: Option<Box<CombinedValidator>>, |
| 18 | + name: String, |
| 19 | +} |
| 20 | + |
| 21 | +impl BuildValidator for CallValidator { |
| 22 | + const EXPECTED_TYPE: &'static str = "call"; |
| 23 | + |
| 24 | + fn build( |
| 25 | + schema: &PyDict, |
| 26 | + config: Option<&PyDict>, |
| 27 | + build_context: &mut BuildContext, |
| 28 | + ) -> PyResult<CombinedValidator> { |
| 29 | + let py = schema.py(); |
| 30 | + |
| 31 | + let arguments_schema: &PyAny = schema.get_as_req(intern!(py, "arguments_schema"))?; |
| 32 | + let arguments_validator = Box::new(build_validator(arguments_schema, config, build_context)?); |
| 33 | + |
| 34 | + let return_schema = schema.get_item(intern!(py, "return_schema")); |
| 35 | + let return_validator = match return_schema { |
| 36 | + Some(return_schema) => Some(Box::new(build_validator(return_schema, config, build_context)?)), |
| 37 | + None => None, |
| 38 | + }; |
| 39 | + let function: &PyAny = schema.get_as_req(intern!(py, "function"))?; |
| 40 | + let function_name: &str = function.getattr(intern!(py, "__name__"))?.extract()?; |
| 41 | + let name = format!("{}[{}]", Self::EXPECTED_TYPE, function_name); |
| 42 | + |
| 43 | + Ok(Self { |
| 44 | + function: function.to_object(py), |
| 45 | + arguments_validator, |
| 46 | + return_validator, |
| 47 | + name, |
| 48 | + } |
| 49 | + .into()) |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +impl Validator for CallValidator { |
| 54 | + fn validate<'s, 'data>( |
| 55 | + &'s self, |
| 56 | + py: Python<'data>, |
| 57 | + input: &'data impl Input<'data>, |
| 58 | + extra: &Extra, |
| 59 | + slots: &'data [CombinedValidator], |
| 60 | + recursion_guard: &'s mut RecursionGuard, |
| 61 | + ) -> ValResult<'data, PyObject> { |
| 62 | + let args = self |
| 63 | + .arguments_validator |
| 64 | + .validate(py, input, extra, slots, recursion_guard) |
| 65 | + .map_err(|e| e.with_outer_location("arguments".into()))?; |
| 66 | + |
| 67 | + let return_value = if let Ok((args, kwargs)) = args.extract::<(&PyTuple, &PyDict)>(py) { |
| 68 | + self.function.call(py, args, Some(kwargs))? |
| 69 | + } else if let Ok(kwargs) = args.cast_as::<PyDict>(py) { |
| 70 | + self.function.call(py, (), Some(kwargs))? |
| 71 | + } else { |
| 72 | + let msg = "Arguments validator should return a tuple of (args, kwargs) or a dict of kwargs"; |
| 73 | + return Err(PyTypeError::new_err(msg).into()); |
| 74 | + }; |
| 75 | + |
| 76 | + if let Some(return_validator) = &self.return_validator { |
| 77 | + return_validator |
| 78 | + .validate(py, return_value.into_ref(py), extra, slots, recursion_guard) |
| 79 | + .map_err(|e| e.with_outer_location("return-value".into())) |
| 80 | + } else { |
| 81 | + Ok(return_value.to_object(py)) |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + fn get_name(&self) -> &str { |
| 86 | + &self.name |
| 87 | + } |
| 88 | +} |
0 commit comments