Skip to content

Commit 32a68a1

Browse files
committed
rewrite library type annotations and converters to use/support UnionType '|' symbol
1 parent 3db2ff4 commit 32a68a1

29 files changed

+558
-615
lines changed

pyproject.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[tool.black]
66
line-length = 88
7-
target_version = ['py39', 'py310']
7+
target_version = ['py310']
88
include = '\.pyi?$'
99
exclude = '''
1010
@@ -38,7 +38,7 @@ include = [
3838
exclude = [
3939
"**/__pycache__"
4040
]
41-
pythonVersion = "3.9"
41+
pythonVersion = "3.10"
4242

4343
[tool.pytest.ini_options]
4444
asyncio_mode = "strict"

snakecore/__init__.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030

3131

3232
async def init(
33-
global_client: Optional[discord.Client] = None,
33+
global_client: discord.Client | None = None,
3434
*,
3535
raise_module_exceptions: bool = False,
3636
) -> tuple[int, int]:
@@ -47,7 +47,7 @@ async def init(
4747
4848
Parameters
4949
----------
50-
global_client : Optional[discord.Client], optional
50+
global_client : discord.Client | None, optional
5151
The global `discord.Client` object to set for all modules to use.
5252
By default None.
5353
raise_module_exceptions : bool, optional
@@ -77,7 +77,7 @@ async def init(
7777

7878

7979
def init_sync(
80-
global_client: Optional[discord.Client] = None,
80+
global_client: discord.Client | None = None,
8181
*,
8282
raise_module_exceptions: bool = False,
8383
) -> tuple[int, int]:
@@ -88,7 +88,7 @@ def init_sync(
8888
8989
Parameters
9090
----------
91-
global_client : Optional[discord.Client], optional
91+
global_client : discord.Client | None, optional
9292
The global `discord.Client` object to set for all modules to use.
9393
Defaults to None.
9494
raise_module_exceptions : bool, optional
@@ -134,7 +134,7 @@ def init_sync(
134134

135135

136136
async def init_async(
137-
global_client: Optional[discord.Client] = None,
137+
global_client: discord.Client | None = None,
138138
*,
139139
raise_module_exceptions: bool = False,
140140
):
@@ -145,7 +145,7 @@ async def init_async(
145145
146146
Parameters
147147
----------
148-
global_client : Optional[discord.Client], optional
148+
global_client : discord.Client | None, optional
149149
The global `discord.Client` object to set for all modules to use.
150150
Defaults to None.
151151
raise_module_exceptions : bool, optional

snakecore/commands/__init__.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,12 @@
1414
from .bot import *
1515

1616

17-
def init(global_client: Optional[discord.Client] = None) -> None:
17+
def init(global_client: discord.Client | None = None) -> None:
1818
"""Initialize this module.
1919
2020
Parameters
2121
----------
22-
global_client : Optional[discord.Client], optional
22+
global_client : discord.Client | None, optional
2323
The global `discord.Client` object to set for all submodules to use.
2424
Defaults to None.
2525
"""

snakecore/commands/bot.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def delete_extension_config(self, qualified_name: str):
9999

100100
def get_extension_config(
101101
self, qualified_name: str, default: Any = UNSET, /
102-
) -> Union[Mapping[str, Any], Any]:
102+
) -> Mapping[str, Any] | Any:
103103
"""Get the configuration mapping that an extension should be loaded with, under
104104
the given qualified name.
105105
@@ -118,7 +118,7 @@ def get_extension_config(
118118
119119
Returns
120120
-------
121-
Union[Mapping[str, Any], Any]
121+
Mapping[str, Any] | Any
122122
The mapping or a default value.
123123
"""
124124
config = self._extension_configs.get(qualified_name, None)
@@ -137,8 +137,8 @@ async def load_extension_with_config(
137137
self,
138138
name: str,
139139
*,
140-
package: Optional[str] = None,
141-
config: Optional[Mapping[str, Any]] = None,
140+
package: str | None = None,
141+
config: Mapping[str, Any] | None = None,
142142
) -> None:
143143
"""A shorthand for calling `set_extension_config` followed by `load_extension`."""
144144
name = self._resolve_name(name, package)

snakecore/commands/converters.py

+25-13
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@
4141
from .bot import AutoShardedBot, Bot
4242

4343
_T = TypeVar("_T")
44-
BotT = Union[Bot, AutoShardedBot]
45-
DECBotT = Union[commands.Bot, commands.AutoShardedBot]
44+
BotT = Bot | AutoShardedBot
45+
DECBotT = commands.Bot | commands.AutoShardedBot
4646

4747
ellipsis = type(Ellipsis)
4848

@@ -432,7 +432,7 @@ class CodeBlock:
432432
_INLINE_PATTERN = re.compile(regex_patterns.INLINE_CODE_BLOCK)
433433

434434
def __init__(
435-
self, code: str, language: Optional[str] = None, inline: Optional[bool] = None
435+
self, code: str, language: str | None = None, inline: bool | None = None
436436
) -> None:
437437
self.code = code
438438
self.language = language
@@ -560,6 +560,12 @@ async def convert(self, ctx: commands.Context[DECBotT], argument: str) -> str:
560560
def __call__(self, *args: Any, **kwds: Any) -> Any:
561561
pass
562562

563+
def __or__(self, other): # Add support for UnionType
564+
return self.__class__.__class__.__or__(self.__class__, other) # type: ignore
565+
566+
def __ror__(self, other): # Add support for UnionType
567+
return self.__class__.__class__.__ror__(self.__class__, other) # type: ignore
568+
563569
@staticmethod
564570
def escape(string: str) -> str:
565571
"""Convert a "raw" string to one where characters are escaped."""
@@ -629,7 +635,7 @@ def __init__(self, size: Any = None) -> None:
629635
super().__init__()
630636
self.size: tuple = (..., ...) if size is None else size
631637

632-
def __class_getitem__(cls, size: Union[StringParams, StringParamsTuple]) -> Self:
638+
def __class_getitem__(cls, size: StringParams | StringParamsTuple) -> Self:
633639
size_tuple = (..., ...)
634640

635641
if getattr(size, "__origin__", None) is Literal:
@@ -775,7 +781,7 @@ def __init__(self, regex: str, examples: tuple[str, ...]) -> None:
775781
self.regex_pattern = re.compile(regex)
776782
self.examples = examples
777783

778-
def __class_getitem__(cls, regex_and_examples: Union[str, tuple[str, ...]]) -> Self:
784+
def __class_getitem__(cls, regex_and_examples: str | tuple[str, ...]) -> Self:
779785
regex = None
780786
examples = ()
781787
if isinstance(regex_and_examples, tuple) and regex_and_examples:
@@ -891,7 +897,7 @@ def __class_getitem__(cls, params: Any) -> Self:
891897

892898
args = getattr(converter, "__args__", ())
893899
if sys.version_info >= (3, 10) and converter.__class__ is types.UnionType: # type: ignore
894-
converter = Union[args] # type: ignore
900+
converter = args # type: ignore
895901

896902
origin = getattr(converter, "__origin__", None)
897903

@@ -1138,10 +1144,10 @@ async def convert(
11381144
ctx.command
11391145
and getattr(fake_parameter.annotation, "__origin__", None) is Union
11401146
and type(None)
1141-
in fake_parameter.annotation.__args__ # check for Optional[...] # type: ignore
1147+
in fake_parameter.annotation.__args__ # check for ... | None # type: ignore
11421148
and transformed is None
11431149
):
1144-
view.index = previous_index # view.undo() does not revert properly for Optional[...]
1150+
view.index = previous_index # view.undo() does not revert properly for ... | None
11451151
view.previous = previous_previous
11461152

11471153
ctx.current_parameter = original_parameter
@@ -1150,6 +1156,12 @@ async def convert(
11501156
def __call__(self, *args: Any, **kwds: Any) -> Any:
11511157
pass
11521158

1159+
def __or__(self, other): # Add support for UnionType
1160+
return self.__class__.__class__.__or__(self.__class__, other) # type: ignore
1161+
1162+
def __ror__(self, other): # Add support for UnionType
1163+
return self.__class__.__class__.__ror__(self.__class__, other) # type: ignore
1164+
11531165
def __repr__(self) -> str:
11541166
return f"{self.__class__.__name__}[{', '.join(self._repr_converter(conv) for conv in self.converters)}]"
11551167

@@ -1296,7 +1308,7 @@ class String(str): # type: ignore
12961308
- `'"ab\\"c"'` -> `'ab"c'`
12971309
"""
12981310

1299-
def __class_getitem__(cls, size: Union[StringParams, StringParamsTuple]):
1311+
def __class_getitem__(cls, size: StringParams | StringParamsTuple):
13001312
...
13011313

13021314
class StringExpr(str): # type: ignore
@@ -1312,15 +1324,15 @@ class StringExpr(str): # type: ignore
13121324
- `'"ab\\"c"'` -> `'ab"c'`
13131325
"""
13141326

1315-
def __class_getitem__(cls, regex_and_examples: Union[str, tuple[str, ...]]):
1327+
def __class_getitem__(cls, regex_and_examples: str | tuple[str, ...]):
13161328
...
13171329

13181330
class StringExprMatch(re.Match): # type: ignore
13191331
"""A subclass of the `StringExpr` converter, that converts inputs into
13201332
`re.Match` objects instead of strings.
13211333
"""
13221334

1323-
def __class_getitem__(cls, regex_and_examples: Union[str, tuple[str, ...]]):
1335+
def __class_getitem__(cls, regex_and_examples: str | tuple[str, ...]):
13241336
...
13251337

13261338
else:
@@ -1420,8 +1432,8 @@ async def convert_flag(
14201432
annotation = annotation.__args__[0]
14211433
return await convert_flag(ctx, argument, flag, annotation)
14221434
elif origin is Union and type(None) in annotation.__args__:
1423-
# typing.Optional[x]
1424-
annotation = Union[tuple(arg for arg in annotation.__args__ if arg is not type(None))] # type: ignore
1435+
# typing.x | None
1436+
annotation = tuple(arg for arg in annotation.__args__ if arg is not type(None)) # type: ignore
14251437
return await commands.run_converters(ctx, annotation, argument, param)
14261438
elif origin is dict:
14271439
# typing.Dict[K, V] -> typing.Tuple[K, V]

snakecore/commands/decorators.py

+15-13
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from discord.ext.commands.flags import FlagsMeta, Flag
1919
from snakecore.commands.bot import ExtAutoShardedBot, ExtBot
2020

21-
from snakecore.commands.converters import FlagConverter as _FlagConverter
21+
from snakecore.commands.converters import FlagConverter as _FlagConverter, Parens
2222
from snakecore.commands.parser import parse_command_str
2323
from ._types import AnyCommandType
2424

@@ -36,7 +36,7 @@
3636

3737
def flagconverter_kwargs(
3838
*,
39-
prefix: Optional[str] = "",
39+
prefix: str | None = "",
4040
delimiter: str = ":",
4141
cls: type[commands.FlagConverter] = _FlagConverter,
4242
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
@@ -65,9 +65,9 @@ def flagconverter_kwargs(
6565
6666
Parameters
6767
----------
68-
prefix : Optional[str], optional
68+
prefix : str | None, optional
6969
The prefix to pass to the `FlagConverter` subclass. Defaults to `""`.
70-
delimiter (Optional[str], optional): The delimiter to pass to the `FlagConverter`
70+
delimiter (str | None, optional): The delimiter to pass to the `FlagConverter`
7171
subclass. Defaults to `":"`.
7272
cls (type[commands.FlagConverter], optional): The class to use as a base class for
7373
the resulting FlagConverter class to generate. Useful for implementing custom flag
@@ -111,14 +111,16 @@ def flagconverter_kwargs_inner_deco(func: Callable[_P, _T]) -> Callable[_P, _T]:
111111
commands.Greedy,
112112
):
113113
if (
114-
isinstance(evaluated_anno, UnionGenericAlias)
114+
isinstance(
115+
evaluated_anno, (UnionGenericAlias, types.UnionType)
116+
)
115117
and type(None) in evaluated_anno.__args__
116118
or evaluated_anno in (None, type(None), str)
117119
):
118120
raise TypeError(
119-
"Cannot have None, NoneType or typing.Optional as an "
120-
f"annotation for '*{param.name}' when using "
121-
"flagconverter decorator"
121+
"Cannot have None, NoneType or typing.Optional/or "
122+
"UnionType with None as an annotation for "
123+
f"'*{param.name}' when using flagconverter decorator"
122124
)
123125
else:
124126
new_annotation = f"commands.Greedy[{param.annotation}]"
@@ -128,14 +130,14 @@ def flagconverter_kwargs_inner_deco(func: Callable[_P, _T]) -> Callable[_P, _T]:
128130
or isinstance(new_annotation, commands.Greedy)
129131
):
130132
if (
131-
isinstance(new_annotation, UnionGenericAlias)
133+
isinstance(new_annotation, (UnionGenericAlias, types.UnionType))
132134
and type(None) in new_annotation.__args__
133135
or new_annotation in (None, type(None), str)
134136
):
135137
raise TypeError(
136-
"Cannot have None, NoneType or typing.Optional as an "
137-
f"annotation for '*{param.name}' when using "
138-
"flagconverter decorator"
138+
"Cannot have None, NoneType or typing.Optional/or "
139+
"UnionType with None as an annotation for "
140+
f"'*{param.name}' when using flagconverter decorator"
139141
)
140142
new_annotation = commands.Greedy(converter=param.annotation)
141143

@@ -339,7 +341,7 @@ def with_config_kwargs(setup: Callable[_P, Any]) -> Callable[[commands.Bot], Any
339341
340342
Parameters
341343
----------
342-
func : Callable[[Union[ExtBot, ExtAutoShardedBot], ...], None]
344+
func : Callable[[ExtBot | ExtAutoShardedBot, ...], None]
343345
The `setup()` function.
344346
"""
345347

snakecore/commands/parser.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ class CodeBlock:
109109
protocol and can be used as a converter.
110110
"""
111111

112-
def __init__(self, text: str, lang: Optional[str] = None) -> None:
112+
def __init__(self, text: str, lang: str | None = None) -> None:
113113
"""Initialise codeblock object. The text argument here is the contents of
114114
the codeblock. If the optional argument lang is specified, it has to be
115115
the language type of the codeblock, if not provided, it is determined
@@ -244,7 +244,7 @@ def split_anno(anno: str):
244244
def strip_optional_anno(anno: str) -> str:
245245
"""Helper to strip "Optional" anno"""
246246
anno = anno.strip()
247-
if anno.startswith("Optional[") and anno.endswith("]"):
247+
if anno.startswith("") and anno.endswith(" | None"):
248248
# call recursively to split "Optional" chains
249249
return strip_optional_anno(anno[9:-1])
250250

@@ -254,7 +254,7 @@ def strip_optional_anno(anno: str) -> str:
254254
def split_union_anno(anno: str):
255255
"""Helper to split a 'Union' annotation. Returns a generator of strings."""
256256
anno = strip_optional_anno(anno)
257-
if anno.startswith("Union[") and anno.endswith("]"):
257+
if anno.startswith("") and anno.endswith(""):
258258
for anno in split_anno(anno[6:-1]):
259259
# use recursive splits to "flatten" unions
260260
yield from split_union_anno(anno)
@@ -392,7 +392,7 @@ def parse_args(cmd_str: str):
392392
"""
393393
args: list[Any] = []
394394
kwargs: dict[str, Any] = {}
395-
temp_list: Optional[list[Any]] = None # used to store the temporary tuple
395+
temp_list: list[Any] | None = None # used to store the temporary tuple
396396

397397
kwstart = False # used to make sure that keyword args come after args
398398
prevkey = None # temporarily store previous key name
@@ -522,13 +522,13 @@ def append_arg(arg: Any):
522522

523523

524524
async def cast_basic_arg(
525-
ctx: commands.Context[Union[commands.Bot, commands.AutoShardedBot]],
525+
ctx: commands.Context[commands.Bot | commands.AutoShardedBot],
526526
anno: str,
527527
arg: Any,
528528
) -> Any:
529529
"""Helper to cast an argument to the type mentioned by the parameter
530530
annotation. This casts an argument in its "basic" form, where both argument
531-
and typehint are "simple", that does not contain stuff like Union[...],
531+
and typehint are "simple", that does not contain stuff like ...,
532532
tuple[...], etc.
533533
Raises ValueError on failure to cast arguments
534534
"""
@@ -796,9 +796,9 @@ async def cast_basic_arg(
796796

797797
async def cast_arg(
798798
ctx: commands.Context,
799-
param: Union[inspect.Parameter, str],
799+
param: inspect.Parameter | str,
800800
arg: Any,
801-
key: Optional[str] = None,
801+
key: str | None = None,
802802
convert_error: bool = True,
803803
) -> Any:
804804
"""Cast an argument to the type mentioned by the paramenter annotation"""

snakecore/config.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -349,10 +349,10 @@ class CoreConfig(ConfigurationBase):
349349
"""Configuration variables used by `snakecore` itself."""
350350

351351
global_client = Field(
352-
var_type=Union[discord.Client, discord.AutoShardedClient], write_once=True
352+
var_type=discord.Client | discord.AutoShardedClient, write_once=True
353353
)
354354
init_mods = Field(init_constr=dict, var_type=dict[ModuleName, bool])
355-
storage_channel = Field(init_val=None, var_type=Optional[discord.abc.Messageable])
355+
storage_channel = Field(init_val=None, var_type=discord.abc.Messageable | None)
356356

357357

358358
conf = CoreConfig()

0 commit comments

Comments
 (0)