Skip to content

Commit 837a6ce

Browse files
authored
Core schema improvements (#265)
* improvements to a few core_schema.py functions * few more tiny tweaks * uprev * fix default=None in with_default_schema * fix with_default_schema and union custom_error * overload union_schema and tagged_union_schema * fix pyright
1 parent 97cdbcf commit 837a6ce

File tree

5 files changed

+178
-60
lines changed

5 files changed

+178
-60
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "pydantic-core"
3-
version = "0.3.0"
3+
version = "0.3.1"
44
edition = "2021"
55
license = "MIT"
66
homepage = "https://github.com/pydantic/pydantic-core"

pydantic_core/core_schema.py

Lines changed: 111 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import sys
44
from datetime import date, datetime, time, timedelta
5-
from typing import Any, Callable, Dict, List, Optional, Type, Union
5+
from typing import Any, Callable, Dict, List, Optional, Type, Union, overload
66

77
if sys.version_info < (3, 11):
88
from typing_extensions import NotRequired, Required
@@ -275,7 +275,7 @@ class LiteralSchema(TypedDict):
275275

276276

277277
def literal_schema(*expected: Any, ref: str | None = None) -> LiteralSchema:
278-
return dict_not_none(type='literal', expected=list(expected), ref=ref)
278+
return dict_not_none(type='literal', expected=expected, ref=ref)
279279

280280

281281
class IsInstanceSchema(TypedDict):
@@ -327,8 +327,7 @@ class TuplePositionalSchema(TypedDict, total=False):
327327

328328

329329
def tuple_positional_schema(
330-
items_schema: List[CoreSchema],
331-
*,
330+
*items_schema: CoreSchema,
332331
extra_schema: CoreSchema | None = None,
333332
strict: bool | None = None,
334333
ref: str | None = None,
@@ -349,7 +348,7 @@ class TupleVariableSchema(TypedDict, total=False):
349348

350349

351350
def tuple_variable_schema(
352-
items_schema: CoreSchema,
351+
items_schema: CoreSchema | None = None,
353352
*,
354353
min_length: int | None = None,
355354
max_length: int | None = None,
@@ -427,9 +426,9 @@ class DictSchema(TypedDict, total=False):
427426

428427

429428
def dict_schema(
430-
*,
431429
keys_schema: CoreSchema | None = None,
432430
values_schema: CoreSchema | None = None,
431+
*,
433432
min_length: int | None = None,
434433
max_length: int | None = None,
435434
strict: bool | None = None,
@@ -456,16 +455,27 @@ class FunctionSchema(TypedDict):
456455
ref: NotRequired[str]
457456

458457

459-
def function_schema(
460-
mode: Literal['before', 'after', 'wrap'],
461-
function: Callable[..., Any],
462-
schema: CoreSchema,
463-
*,
464-
validator_instance: Any | None = None,
465-
ref: str | None = None,
458+
def function_before_schema(
459+
function: Callable[..., Any], schema: CoreSchema, *, validator_instance: Any | None = None, ref: str | None = None
466460
) -> FunctionSchema:
467461
return dict_not_none(
468-
type='function', mode=mode, function=function, schema=schema, validator_instance=validator_instance, ref=ref
462+
type='function', mode='before', function=function, schema=schema, validator_instance=validator_instance, ref=ref
463+
)
464+
465+
466+
def function_after_schema(
467+
function: Callable[..., Any], schema: CoreSchema, *, validator_instance: Any | None = None, ref: str | None = None
468+
) -> FunctionSchema:
469+
return dict_not_none(
470+
type='function', mode='after', function=function, schema=schema, validator_instance=validator_instance, ref=ref
471+
)
472+
473+
474+
def function_wrap_schema(
475+
function: Callable[..., Any], schema: CoreSchema, *, validator_instance: Any | None = None, ref: str | None = None
476+
) -> FunctionSchema:
477+
return dict_not_none(
478+
type='function', mode='wrap', function=function, schema=schema, validator_instance=validator_instance, ref=ref
469479
)
470480

471481

@@ -496,24 +506,24 @@ class WithDefaultSchema(TypedDict, total=False):
496506
ref: str
497507

498508

509+
Omitted = object()
510+
511+
499512
def with_default_schema(
500513
schema: CoreSchema,
501514
*,
502-
default: Any | None = None,
515+
default: Any = Omitted,
503516
default_factory: Callable[[], Any] | None = None,
504517
on_error: Literal['raise', 'omit', 'default'] | None = None,
505518
strict: bool | None = None,
506519
ref: str | None = None,
507520
) -> WithDefaultSchema:
508-
return dict_not_none(
509-
type='default',
510-
schema=schema,
511-
default=default,
512-
default_factory=default_factory,
513-
on_error=on_error,
514-
strict=strict,
515-
ref=ref,
521+
s = dict_not_none(
522+
type='default', schema=schema, default_factory=default_factory, on_error=on_error, strict=strict, ref=ref
516523
)
524+
if default is not Omitted:
525+
s['default'] = default
526+
return s
517527

518528

519529
class NullableSchema(TypedDict, total=False):
@@ -532,6 +542,14 @@ class CustomError(TypedDict):
532542
message: str
533543

534544

545+
def _custom_error(kind: str | None, message: str | None) -> CustomError | None:
546+
if kind is None and message is None:
547+
return None
548+
else:
549+
# let schema validation raise the error
550+
return CustomError(kind=kind, message=message) # type: ignore
551+
552+
535553
class UnionSchema(TypedDict, total=False):
536554
type: Required[Literal['union']]
537555
choices: Required[List[CoreSchema]]
@@ -540,8 +558,36 @@ class UnionSchema(TypedDict, total=False):
540558
ref: str
541559

542560

543-
def union_schema(choices: List[CoreSchema], *, strict: bool | None = None, ref: str | None = None) -> UnionSchema:
544-
return dict_not_none(type='union', choices=choices, strict=strict, ref=ref)
561+
@overload
562+
def union_schema(
563+
*choices: CoreSchema,
564+
custom_error_kind: str,
565+
custom_error_message: str,
566+
strict: bool | None = None,
567+
ref: str | None = None,
568+
) -> UnionSchema:
569+
...
570+
571+
572+
@overload
573+
def union_schema(*choices: CoreSchema, strict: bool | None = None, ref: str | None = None) -> UnionSchema:
574+
...
575+
576+
577+
def union_schema(
578+
*choices: CoreSchema,
579+
custom_error_kind: str | None = None,
580+
custom_error_message: str | None = None,
581+
strict: bool | None = None,
582+
ref: str | None = None,
583+
) -> UnionSchema:
584+
return dict_not_none(
585+
type='union',
586+
choices=choices,
587+
custom_error=_custom_error(custom_error_kind, custom_error_message),
588+
strict=strict,
589+
ref=ref,
590+
)
545591

546592

547593
class TaggedUnionSchema(TypedDict):
@@ -553,14 +599,47 @@ class TaggedUnionSchema(TypedDict):
553599
ref: NotRequired[str]
554600

555601

602+
@overload
603+
def tagged_union_schema(
604+
choices: Dict[str, CoreSchema],
605+
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], str | None],
606+
*,
607+
custom_error_kind: str,
608+
custom_error_message: str,
609+
strict: bool | None = None,
610+
ref: str | None = None,
611+
) -> TaggedUnionSchema:
612+
...
613+
614+
615+
@overload
556616
def tagged_union_schema(
557617
choices: Dict[str, CoreSchema],
558618
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], str | None],
559619
*,
560620
strict: bool | None = None,
561621
ref: str | None = None,
562622
) -> TaggedUnionSchema:
563-
return dict_not_none(type='tagged-union', choices=choices, discriminator=discriminator, strict=strict, ref=ref)
623+
...
624+
625+
626+
def tagged_union_schema(
627+
choices: Dict[str, CoreSchema],
628+
discriminator: str | list[str | int] | list[list[str | int]] | Callable[[Any], str | None],
629+
*,
630+
custom_error_kind: str | None = None,
631+
custom_error_message: str | None = None,
632+
strict: bool | None = None,
633+
ref: str | None = None,
634+
) -> TaggedUnionSchema:
635+
return dict_not_none(
636+
type='tagged-union',
637+
choices=choices,
638+
discriminator=discriminator,
639+
custom_error=_custom_error(custom_error_kind, custom_error_message),
640+
strict=strict,
641+
ref=ref,
642+
)
564643

565644

566645
class ChainSchema(TypedDict):
@@ -569,7 +648,7 @@ class ChainSchema(TypedDict):
569648
ref: NotRequired[str]
570649

571650

572-
def chain_schema(steps: List[CoreSchema], *, ref: str | None = None) -> ChainSchema:
651+
def chain_schema(*steps: CoreSchema, ref: str | None = None) -> ChainSchema:
573652
return dict_not_none(type='chain', steps=steps, ref=ref)
574653

575654

@@ -681,16 +760,15 @@ class ArgumentsSchema(TypedDict, total=False):
681760

682761

683762
def arguments_schema(
684-
arguments_schema: List[ArgumentsParameter],
685-
*,
763+
*arguments: ArgumentsParameter,
686764
populate_by_name: bool | None = None,
687765
var_args_schema: CoreSchema | None = None,
688766
var_kwargs_schema: CoreSchema | None = None,
689767
ref: str | None = None,
690768
) -> ArgumentsSchema:
691769
return dict_not_none(
692770
type='arguments',
693-
arguments_schema=arguments_schema,
771+
arguments_schema=arguments,
694772
populate_by_name=populate_by_name,
695773
var_args_schema=var_args_schema,
696774
var_kwargs_schema=var_kwargs_schema,
@@ -700,21 +778,21 @@ def arguments_schema(
700778

701779
class CallSchema(TypedDict):
702780
type: Literal['call']
703-
function: Callable[..., Any]
704781
arguments_schema: CoreSchema
782+
function: Callable[..., Any]
705783
return_schema: NotRequired[CoreSchema]
706784
ref: NotRequired[str]
707785

708786

709787
def call_schema(
788+
arguments: CoreSchema,
710789
function: Callable[..., Any],
711-
arguments_schema: CoreSchema,
712790
*,
713791
return_schema: CoreSchema | None = None,
714792
ref: str | None = None,
715793
) -> CallSchema:
716794
return dict_not_none(
717-
type='call', function=function, arguments_schema=arguments_schema, return_schema=return_schema, ref=ref
795+
type='call', arguments_schema=arguments, function=function, return_schema=return_schema, ref=ref
718796
)
719797

720798

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,6 @@ precision = 2
7070
exclude_lines = [
7171
'pragma: no cover',
7272
'raise NotImplementedError',
73-
'raise NotImplemented',
7473
'if TYPE_CHECKING:',
7574
'@overload',
7675
]

0 commit comments

Comments
 (0)