2
2
3
3
import sys
4
4
from datetime import date , datetime , time , timedelta
5
- from typing import Any , Callable , Dict , List , Optional , Type , Union
5
+ from typing import Any , Callable , Dict , List , Optional , Type , Union , overload
6
6
7
7
if sys .version_info < (3 , 11 ):
8
8
from typing_extensions import NotRequired , Required
@@ -275,7 +275,7 @@ class LiteralSchema(TypedDict):
275
275
276
276
277
277
def literal_schema (* expected : Any , ref : str | None = None ) -> LiteralSchema :
278
- return dict_not_none (type = 'literal' , expected = list ( expected ) , ref = ref )
278
+ return dict_not_none (type = 'literal' , expected = expected , ref = ref )
279
279
280
280
281
281
class IsInstanceSchema (TypedDict ):
@@ -327,8 +327,7 @@ class TuplePositionalSchema(TypedDict, total=False):
327
327
328
328
329
329
def tuple_positional_schema (
330
- items_schema : List [CoreSchema ],
331
- * ,
330
+ * items_schema : CoreSchema ,
332
331
extra_schema : CoreSchema | None = None ,
333
332
strict : bool | None = None ,
334
333
ref : str | None = None ,
@@ -349,7 +348,7 @@ class TupleVariableSchema(TypedDict, total=False):
349
348
350
349
351
350
def tuple_variable_schema (
352
- items_schema : CoreSchema ,
351
+ items_schema : CoreSchema | None = None ,
353
352
* ,
354
353
min_length : int | None = None ,
355
354
max_length : int | None = None ,
@@ -427,9 +426,9 @@ class DictSchema(TypedDict, total=False):
427
426
428
427
429
428
def dict_schema (
430
- * ,
431
429
keys_schema : CoreSchema | None = None ,
432
430
values_schema : CoreSchema | None = None ,
431
+ * ,
433
432
min_length : int | None = None ,
434
433
max_length : int | None = None ,
435
434
strict : bool | None = None ,
@@ -456,16 +455,27 @@ class FunctionSchema(TypedDict):
456
455
ref : NotRequired [str ]
457
456
458
457
459
- def function_schema (
460
- mode : Literal ['before' , 'after' , 'wrap' ],
461
- function : Callable [..., Any ],
462
- schema : CoreSchema ,
463
- * ,
464
- validator_instance : Any | None = None ,
465
- ref : str | None = None ,
458
+ def function_before_schema (
459
+ function : Callable [..., Any ], schema : CoreSchema , * , validator_instance : Any | None = None , ref : str | None = None
466
460
) -> FunctionSchema :
467
461
return dict_not_none (
468
- type = 'function' , mode = mode , function = function , schema = schema , validator_instance = validator_instance , ref = ref
462
+ type = 'function' , mode = 'before' , function = function , schema = schema , validator_instance = validator_instance , ref = ref
463
+ )
464
+
465
+
466
+ def function_after_schema (
467
+ function : Callable [..., Any ], schema : CoreSchema , * , validator_instance : Any | None = None , ref : str | None = None
468
+ ) -> FunctionSchema :
469
+ return dict_not_none (
470
+ type = 'function' , mode = 'after' , function = function , schema = schema , validator_instance = validator_instance , ref = ref
471
+ )
472
+
473
+
474
+ def function_wrap_schema (
475
+ function : Callable [..., Any ], schema : CoreSchema , * , validator_instance : Any | None = None , ref : str | None = None
476
+ ) -> FunctionSchema :
477
+ return dict_not_none (
478
+ type = 'function' , mode = 'wrap' , function = function , schema = schema , validator_instance = validator_instance , ref = ref
469
479
)
470
480
471
481
@@ -496,24 +506,24 @@ class WithDefaultSchema(TypedDict, total=False):
496
506
ref : str
497
507
498
508
509
+ Omitted = object ()
510
+
511
+
499
512
def with_default_schema (
500
513
schema : CoreSchema ,
501
514
* ,
502
- default : Any | None = None ,
515
+ default : Any = Omitted ,
503
516
default_factory : Callable [[], Any ] | None = None ,
504
517
on_error : Literal ['raise' , 'omit' , 'default' ] | None = None ,
505
518
strict : bool | None = None ,
506
519
ref : str | None = None ,
507
520
) -> WithDefaultSchema :
508
- return dict_not_none (
509
- type = 'default' ,
510
- schema = schema ,
511
- default = default ,
512
- default_factory = default_factory ,
513
- on_error = on_error ,
514
- strict = strict ,
515
- ref = ref ,
521
+ s = dict_not_none (
522
+ type = 'default' , schema = schema , default_factory = default_factory , on_error = on_error , strict = strict , ref = ref
516
523
)
524
+ if default is not Omitted :
525
+ s ['default' ] = default
526
+ return s
517
527
518
528
519
529
class NullableSchema (TypedDict , total = False ):
@@ -532,6 +542,14 @@ class CustomError(TypedDict):
532
542
message : str
533
543
534
544
545
+ def _custom_error (kind : str | None , message : str | None ) -> CustomError | None :
546
+ if kind is None and message is None :
547
+ return None
548
+ else :
549
+ # let schema validation raise the error
550
+ return CustomError (kind = kind , message = message ) # type: ignore
551
+
552
+
535
553
class UnionSchema (TypedDict , total = False ):
536
554
type : Required [Literal ['union' ]]
537
555
choices : Required [List [CoreSchema ]]
@@ -540,8 +558,36 @@ class UnionSchema(TypedDict, total=False):
540
558
ref : str
541
559
542
560
543
- def union_schema (choices : List [CoreSchema ], * , strict : bool | None = None , ref : str | None = None ) -> UnionSchema :
544
- return dict_not_none (type = 'union' , choices = choices , strict = strict , ref = ref )
561
+ @overload
562
+ def union_schema (
563
+ * choices : CoreSchema ,
564
+ custom_error_kind : str ,
565
+ custom_error_message : str ,
566
+ strict : bool | None = None ,
567
+ ref : str | None = None ,
568
+ ) -> UnionSchema :
569
+ ...
570
+
571
+
572
+ @overload
573
+ def union_schema (* choices : CoreSchema , strict : bool | None = None , ref : str | None = None ) -> UnionSchema :
574
+ ...
575
+
576
+
577
+ def union_schema (
578
+ * choices : CoreSchema ,
579
+ custom_error_kind : str | None = None ,
580
+ custom_error_message : str | None = None ,
581
+ strict : bool | None = None ,
582
+ ref : str | None = None ,
583
+ ) -> UnionSchema :
584
+ return dict_not_none (
585
+ type = 'union' ,
586
+ choices = choices ,
587
+ custom_error = _custom_error (custom_error_kind , custom_error_message ),
588
+ strict = strict ,
589
+ ref = ref ,
590
+ )
545
591
546
592
547
593
class TaggedUnionSchema (TypedDict ):
@@ -553,14 +599,47 @@ class TaggedUnionSchema(TypedDict):
553
599
ref : NotRequired [str ]
554
600
555
601
602
+ @overload
603
+ def tagged_union_schema (
604
+ choices : Dict [str , CoreSchema ],
605
+ discriminator : str | list [str | int ] | list [list [str | int ]] | Callable [[Any ], str | None ],
606
+ * ,
607
+ custom_error_kind : str ,
608
+ custom_error_message : str ,
609
+ strict : bool | None = None ,
610
+ ref : str | None = None ,
611
+ ) -> TaggedUnionSchema :
612
+ ...
613
+
614
+
615
+ @overload
556
616
def tagged_union_schema (
557
617
choices : Dict [str , CoreSchema ],
558
618
discriminator : str | list [str | int ] | list [list [str | int ]] | Callable [[Any ], str | None ],
559
619
* ,
560
620
strict : bool | None = None ,
561
621
ref : str | None = None ,
562
622
) -> TaggedUnionSchema :
563
- return dict_not_none (type = 'tagged-union' , choices = choices , discriminator = discriminator , strict = strict , ref = ref )
623
+ ...
624
+
625
+
626
+ def tagged_union_schema (
627
+ choices : Dict [str , CoreSchema ],
628
+ discriminator : str | list [str | int ] | list [list [str | int ]] | Callable [[Any ], str | None ],
629
+ * ,
630
+ custom_error_kind : str | None = None ,
631
+ custom_error_message : str | None = None ,
632
+ strict : bool | None = None ,
633
+ ref : str | None = None ,
634
+ ) -> TaggedUnionSchema :
635
+ return dict_not_none (
636
+ type = 'tagged-union' ,
637
+ choices = choices ,
638
+ discriminator = discriminator ,
639
+ custom_error = _custom_error (custom_error_kind , custom_error_message ),
640
+ strict = strict ,
641
+ ref = ref ,
642
+ )
564
643
565
644
566
645
class ChainSchema (TypedDict ):
@@ -569,7 +648,7 @@ class ChainSchema(TypedDict):
569
648
ref : NotRequired [str ]
570
649
571
650
572
- def chain_schema (steps : List [ CoreSchema ], * , ref : str | None = None ) -> ChainSchema :
651
+ def chain_schema (* steps : CoreSchema , ref : str | None = None ) -> ChainSchema :
573
652
return dict_not_none (type = 'chain' , steps = steps , ref = ref )
574
653
575
654
@@ -681,16 +760,15 @@ class ArgumentsSchema(TypedDict, total=False):
681
760
682
761
683
762
def arguments_schema (
684
- arguments_schema : List [ArgumentsParameter ],
685
- * ,
763
+ * arguments : ArgumentsParameter ,
686
764
populate_by_name : bool | None = None ,
687
765
var_args_schema : CoreSchema | None = None ,
688
766
var_kwargs_schema : CoreSchema | None = None ,
689
767
ref : str | None = None ,
690
768
) -> ArgumentsSchema :
691
769
return dict_not_none (
692
770
type = 'arguments' ,
693
- arguments_schema = arguments_schema ,
771
+ arguments_schema = arguments ,
694
772
populate_by_name = populate_by_name ,
695
773
var_args_schema = var_args_schema ,
696
774
var_kwargs_schema = var_kwargs_schema ,
@@ -700,21 +778,21 @@ def arguments_schema(
700
778
701
779
class CallSchema (TypedDict ):
702
780
type : Literal ['call' ]
703
- function : Callable [..., Any ]
704
781
arguments_schema : CoreSchema
782
+ function : Callable [..., Any ]
705
783
return_schema : NotRequired [CoreSchema ]
706
784
ref : NotRequired [str ]
707
785
708
786
709
787
def call_schema (
788
+ arguments : CoreSchema ,
710
789
function : Callable [..., Any ],
711
- arguments_schema : CoreSchema ,
712
790
* ,
713
791
return_schema : CoreSchema | None = None ,
714
792
ref : str | None = None ,
715
793
) -> CallSchema :
716
794
return dict_not_none (
717
- type = 'call' , function = function , arguments_schema = arguments_schema , return_schema = return_schema , ref = ref
795
+ type = 'call' , arguments_schema = arguments , function = function , return_schema = return_schema , ref = ref
718
796
)
719
797
720
798
0 commit comments