Skip to content

Commit b03a82f

Browse files
committed
refactor to make it dry
1 parent 9986624 commit b03a82f

File tree

3 files changed

+176
-210
lines changed

3 files changed

+176
-210
lines changed

src/serializers/fields.rs

Lines changed: 150 additions & 184 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ use super::extra::Extra;
1717
use super::filter::SchemaFilter;
1818
use super::infer::{infer_json_key, infer_serialize, infer_to_python, SerializeInfer};
1919
use super::shared::PydanticSerializer;
20-
use super::shared::{sort_dict_recursive, CombinedSerializer, TypeSerializer};
20+
use super::shared::{CombinedSerializer, TypeSerializer};
21+
use super::type_serializers::dict::sort_dict_recursive;
2122

2223
/// representation of a field for serialization
2324
#[derive(Debug)]
@@ -167,118 +168,31 @@ impl GeneralFieldsSerializer {
167168
items.sort_by(|(a, _, _), (b, _, _)| a.cmp(b));
168169

169170
for (key_str, key, value) in items {
170-
let op_field = self.fields.get(&key_str);
171-
if extra.exclude_none && value.is_none() {
172-
if let Some(field) = op_field {
173-
if field.required {
174-
used_req_fields += 1;
175-
}
176-
}
177-
continue;
178-
}
179-
let field_extra = Extra {
180-
field_name: Some(&key_str),
181-
..extra
182-
};
183-
if let Some((next_include, next_exclude)) = self.filter.key_filter(&key, include, exclude)? {
184-
if let Some(field) = op_field {
185-
if let Some(ref serializer) = field.serializer {
186-
if !exclude_default(&value, &field_extra, serializer)? {
187-
let value = serializer.to_python(
188-
&value,
189-
next_include.as_ref(),
190-
next_exclude.as_ref(),
191-
&field_extra,
192-
)?;
193-
let output_key = field.get_key_py(output_dict.py(), &field_extra);
194-
output_dict.set_item(output_key, value)?;
195-
}
196-
}
197-
198-
if field.required {
199-
used_req_fields += 1;
200-
}
201-
} else if self.mode == FieldsMode::TypedDictAllow {
202-
let value = match &self.extra_serializer {
203-
Some(serializer) => serializer.to_python(
204-
&value,
205-
next_include.as_ref(),
206-
next_exclude.as_ref(),
207-
&field_extra,
208-
)?,
209-
None => {
210-
infer_to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?
211-
}
212-
};
213-
output_dict.set_item(key, value)?;
214-
} else if field_extra.check == SerCheck::Strict {
215-
return Err(PydanticSerializationUnexpectedValue::new(
216-
Some(format!("Unexpected field `{key}`")),
217-
field_extra.model_type_name().map(|bound| bound.to_string()),
218-
None,
219-
)
220-
.to_py_err());
221-
}
222-
}
171+
self.process_field(
172+
&key_str,
173+
&key,
174+
value,
175+
&output_dict,
176+
include,
177+
exclude,
178+
&extra,
179+
&mut used_req_fields,
180+
)?;
223181
}
224182
} else {
225-
// NOTE! we maintain the order of the input dict assuming that's right
226183
for result in main_iter {
227184
let (key, value) = result?;
228185
let key_str = key_str(&key)?;
229-
let op_field = self.fields.get(key_str);
230-
if extra.exclude_none && value.is_none() {
231-
if let Some(field) = op_field {
232-
if field.required {
233-
used_req_fields += 1;
234-
}
235-
}
236-
continue;
237-
}
238-
let field_extra = Extra {
239-
field_name: Some(key_str),
240-
..extra
241-
};
242-
if let Some((next_include, next_exclude)) = self.filter.key_filter(&key, include, exclude)? {
243-
if let Some(field) = op_field {
244-
if let Some(ref serializer) = field.serializer {
245-
if !exclude_default(&value, &field_extra, serializer)? {
246-
let value = serializer.to_python(
247-
&value,
248-
next_include.as_ref(),
249-
next_exclude.as_ref(),
250-
&field_extra,
251-
)?;
252-
let output_key = field.get_key_py(output_dict.py(), &field_extra);
253-
output_dict.set_item(output_key, value)?;
254-
}
255-
}
256-
257-
if field.required {
258-
used_req_fields += 1;
259-
}
260-
} else if self.mode == FieldsMode::TypedDictAllow {
261-
let value = match &self.extra_serializer {
262-
Some(serializer) => serializer.to_python(
263-
&value,
264-
next_include.as_ref(),
265-
next_exclude.as_ref(),
266-
&field_extra,
267-
)?,
268-
None => {
269-
infer_to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?
270-
}
271-
};
272-
output_dict.set_item(key, value)?;
273-
} else if field_extra.check == SerCheck::Strict {
274-
return Err(PydanticSerializationUnexpectedValue::new(
275-
Some(format!("Unexpected field `{key}`")),
276-
field_extra.model_type_name().map(|bound| bound.to_string()),
277-
None,
278-
)
279-
.to_py_err());
280-
}
281-
}
186+
self.process_field(
187+
key_str,
188+
&key,
189+
value,
190+
&output_dict,
191+
include,
192+
exclude,
193+
&extra,
194+
&mut used_req_fields,
195+
)?;
282196
}
283197
}
284198

@@ -301,6 +215,65 @@ impl GeneralFieldsSerializer {
301215
}
302216
}
303217

218+
#[allow(clippy::too_many_arguments)]
219+
fn process_field<'py>(
220+
&self,
221+
key_str: &str,
222+
key: &Bound<'py, PyAny>,
223+
value: Bound<'py, PyAny>,
224+
output_dict: &Bound<'py, PyDict>,
225+
include: Option<&Bound<'py, PyAny>>,
226+
exclude: Option<&Bound<'py, PyAny>>,
227+
extra: &Extra,
228+
used_req_fields: &mut usize,
229+
) -> PyResult<()> {
230+
let op_field = self.fields.get(key_str);
231+
if extra.exclude_none && value.is_none() {
232+
if let Some(field) = op_field {
233+
if field.required {
234+
*used_req_fields += 1;
235+
}
236+
}
237+
return Ok(());
238+
}
239+
let field_extra = Extra {
240+
field_name: Some(key_str),
241+
..*extra
242+
};
243+
if let Some((next_include, next_exclude)) = self.filter.key_filter(key, include, exclude)? {
244+
if let Some(field) = op_field {
245+
if let Some(ref serializer) = field.serializer {
246+
if !exclude_default(&value, &field_extra, serializer)? {
247+
let value =
248+
serializer.to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?;
249+
let output_key = field.get_key_py(output_dict.py(), &field_extra);
250+
output_dict.set_item(output_key, value)?;
251+
}
252+
}
253+
254+
if field.required {
255+
*used_req_fields += 1;
256+
}
257+
} else if self.mode == FieldsMode::TypedDictAllow {
258+
let value = match &self.extra_serializer {
259+
Some(serializer) => {
260+
serializer.to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?
261+
}
262+
None => infer_to_python(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra)?,
263+
};
264+
output_dict.set_item(key, value)?;
265+
} else if field_extra.check == SerCheck::Strict {
266+
return Err(PydanticSerializationUnexpectedValue::new(
267+
Some(format!("Unexpected field `{key}`")),
268+
field_extra.model_type_name().map(|bound| bound.to_string()),
269+
None,
270+
)
271+
.to_py_err());
272+
}
273+
}
274+
Ok(())
275+
}
276+
304277
pub(crate) fn main_serde_serialize<'py, S: serde::ser::Serializer>(
305278
&self,
306279
main_iter: impl Iterator<Item = PyResult<(Bound<'py, PyAny>, Bound<'py, PyAny>)>>,
@@ -324,97 +297,90 @@ impl GeneralFieldsSerializer {
324297
.map_err(py_err_se_err)?;
325298
items.sort_by(|(a, _, _), (b, _, _)| a.cmp(b));
326299
for (key_str, key, value) in items {
327-
let field_extra = Extra {
328-
field_name: Some(&key_str),
329-
..extra
330-
};
300+
self.process_serde_field::<S>(&key_str, &key, &value, &mut map, include, exclude, &extra)?;
301+
}
302+
} else {
303+
for result in main_iter {
304+
let (key, value) = result.map_err(py_err_se_err)?;
305+
if extra.exclude_none && value.is_none() {
306+
continue;
307+
}
308+
let key_str = key_str(&key).map_err(py_err_se_err)?;
309+
self.process_serde_field::<S>(key_str, &key, &value, &mut map, include, exclude, &extra)?;
310+
}
311+
}
312+
Ok(map)
313+
}
331314

332-
let filter = self.filter.key_filter(&key, include, exclude).map_err(py_err_se_err)?;
333-
if let Some((next_include, next_exclude)) = filter {
334-
if let Some(field) = self.fields.get(&key_str) {
335-
if let Some(ref serializer) = field.serializer {
336-
if !exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
337-
if extra.sort_keys {
338-
let sorted_dict = sort_dict_recursive(value.py(), &value).map_err(py_err_se_err)?;
339-
let s = PydanticSerializer::new(
340-
sorted_dict.as_ref(),
341-
serializer,
342-
next_include.as_ref(),
343-
next_exclude.as_ref(),
344-
&field_extra,
345-
);
346-
let output_key = field.get_key_json(&key_str, &field_extra);
347-
map.serialize_entry(&output_key, &s)?;
348-
} else {
349-
let s = PydanticSerializer::new(
350-
&value,
351-
serializer,
352-
next_include.as_ref(),
353-
next_exclude.as_ref(),
354-
&field_extra,
355-
);
356-
let output_key = field.get_key_json(&key_str, &field_extra);
357-
map.serialize_entry(&output_key, &s)?;
358-
}
359-
}
360-
}
361-
} else if self.mode == FieldsMode::TypedDictAllow {
362-
let output_key = infer_json_key(&key, &field_extra).map_err(py_err_se_err)?;
315+
#[allow(clippy::too_many_arguments)]
316+
fn process_serde_field<'py, S: serde::ser::Serializer>(
317+
&self,
318+
key_str: &str,
319+
key: &Bound<'py, PyAny>,
320+
value: &Bound<'py, PyAny>,
321+
map: &mut S::SerializeMap,
322+
include: Option<&Bound<'py, PyAny>>,
323+
exclude: Option<&Bound<'py, PyAny>>,
324+
extra: &Extra,
325+
) -> Result<(), S::Error> {
326+
if extra.exclude_none && value.is_none() {
327+
return Ok(());
328+
}
329+
330+
let field_extra = Extra {
331+
field_name: Some(key_str),
332+
..*extra
333+
};
334+
335+
let filter = self.filter.key_filter(key, include, exclude).map_err(py_err_se_err)?;
336+
if let Some((next_include, next_exclude)) = filter {
337+
if let Some(field) = self.fields.get(key_str) {
338+
if let Some(ref serializer) = field.serializer {
339+
if !exclude_default(value, &field_extra, serializer).map_err(py_err_se_err)? {
340+
// Get potentially sorted value
363341
if extra.sort_keys {
364-
let sorted_dict = sort_dict_recursive(value.py(), &value).map_err(py_err_se_err)?;
365-
let s = SerializeInfer::new(
342+
let sorted_dict = sort_dict_recursive(value.py(), value).map_err(py_err_se_err)?;
343+
let s = PydanticSerializer::new(
366344
sorted_dict.as_ref(),
345+
serializer,
367346
next_include.as_ref(),
368347
next_exclude.as_ref(),
369348
&field_extra,
370349
);
350+
let output_key = field.get_key_json(key_str, &field_extra);
371351
map.serialize_entry(&output_key, &s)?;
372352
} else {
373-
let s =
374-
SerializeInfer::new(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra);
353+
let s = PydanticSerializer::new(
354+
value,
355+
serializer,
356+
next_include.as_ref(),
357+
next_exclude.as_ref(),
358+
&field_extra,
359+
);
360+
let output_key = field.get_key_json(key_str, &field_extra);
375361
map.serialize_entry(&output_key, &s)?;
376-
}
362+
};
377363
}
378364
}
379-
}
380-
} else {
381-
for result in main_iter {
382-
let (key, value) = result.map_err(py_err_se_err)?;
383-
if extra.exclude_none && value.is_none() {
384-
continue;
385-
}
386-
let key_str = key_str(&key).map_err(py_err_se_err)?;
387-
let field_extra = Extra {
388-
field_name: Some(key_str),
389-
..extra
365+
} else if self.mode == FieldsMode::TypedDictAllow {
366+
let output_key = infer_json_key(key, &field_extra).map_err(py_err_se_err)?;
367+
// Get potentially sorted value
368+
if extra.sort_keys {
369+
let sorted_dict = sort_dict_recursive(value.py(), value).map_err(py_err_se_err)?;
370+
let s = SerializeInfer::new(
371+
sorted_dict.as_ref(),
372+
next_include.as_ref(),
373+
next_exclude.as_ref(),
374+
&field_extra,
375+
);
376+
map.serialize_entry(&output_key, &s)?;
377+
} else {
378+
let s = SerializeInfer::new(value, next_include.as_ref(), next_exclude.as_ref(), &field_extra);
379+
map.serialize_entry(&output_key, &s)?;
390380
};
391-
392-
let filter = self.filter.key_filter(&key, include, exclude).map_err(py_err_se_err)?;
393-
if let Some((next_include, next_exclude)) = filter {
394-
if let Some(field) = self.fields.get(key_str) {
395-
if let Some(ref serializer) = field.serializer {
396-
if !exclude_default(&value, &field_extra, serializer).map_err(py_err_se_err)? {
397-
let s = PydanticSerializer::new(
398-
&value,
399-
serializer,
400-
next_include.as_ref(),
401-
next_exclude.as_ref(),
402-
&field_extra,
403-
);
404-
let output_key = field.get_key_json(key_str, &field_extra);
405-
map.serialize_entry(&output_key, &s)?;
406-
}
407-
}
408-
} else if self.mode == FieldsMode::TypedDictAllow {
409-
let output_key = infer_json_key(&key, &field_extra).map_err(py_err_se_err)?;
410-
let s = SerializeInfer::new(&value, next_include.as_ref(), next_exclude.as_ref(), &field_extra);
411-
map.serialize_entry(&output_key, &s)?;
412-
}
413-
// no error case here since unions (which need the error case) use `to_python(..., mode='json')`
414-
}
415381
}
416382
}
417-
Ok(map)
383+
Ok(())
418384
}
419385

420386
pub(crate) fn add_computed_fields_python(

0 commit comments

Comments
 (0)