22from copy import deepcopy
33from dataclasses import dataclass , field
44import inspect
5- from typing import Any , Callable , Generic , Optional , TypeVar , Union , cast , get_args , overload
5+ from typing import Any , Callable , Generic , Literal , Optional , TypeVar , Union , cast , get_args , overload
66
77from arclet .alconna import (
88 Alconna ,
3434)
3535from arclet .letoderea .handler import depend_handler
3636from arclet .letoderea .provider import ProviderFactory
37- from nepattern import DirectPattern , main
37+ from nepattern import DirectPattern
38+ from nepattern .util import CUnionType
3839from pygtrie import CharTrie
3940from satori .client import Account
4041from satori .element import At , Text
@@ -174,7 +175,7 @@ async def __call__(self, scope: Scope, context: Contexts) -> Optional[Union[bool
174175 context ["alc_result" ] = CommandResult (self .cmd , _res , may_help_text )
175176 return context
176177 elif may_help_text :
177- await account .send (context ["$event" ], may_help_text )
178+ await account .send (context ["$event" ], MessageChain ( may_help_text ) )
178179 return False
179180
180181 @property
@@ -183,9 +184,9 @@ def scopes(self) -> set[Scope]:
183184
184185
185186class AlconnaProvider (Provider [Any ]):
186- def __init__ (self , type : str , extra : Optional [dict ] = None ):
187+ def __init__ (self , type_ : str , extra : Optional [dict ] = None ):
187188 super ().__init__ ()
188- self .type = type
189+ self .type = type_
189190 self .extra = extra or {}
190191
191192 async def __call__ (self , context : Contexts ):
@@ -221,7 +222,7 @@ async def __call__(self, context: Contexts):
221222class AlconnaProviderFactory (ProviderFactory ):
222223 def validate (self , param : Param ):
223224 annotation = get_origin (param .annotation )
224- if annotation in main . _Contents :
225+ if annotation in ( Union , CUnionType , Literal ) :
225226 annotation = get_origin (get_args (param .annotation )[0 ])
226227 if annotation is CommandResult :
227228 return AlconnaProvider ("result" )
@@ -254,7 +255,11 @@ def __init__(self, need_tome: bool = False, remove_tome: bool = False):
254255 plugins ["~command.EntariCommands" ] = self .publisher
255256 self .need_tome = need_tome
256257 self .remove_tome = remove_tome
257- config .namespaces ["Entari" ] = Namespace (self .__namespace__ )
258+ config .namespaces ["Entari" ] = Namespace (
259+ self .__namespace__ ,
260+ to_text = lambda x : x .text if x .__class__ is Text else None ,
261+ converter = lambda x : MessageChain (x ),
262+ )
258263
259264 @self .publisher .register (auxiliaries = [MessageJudger ()])
260265 async def listener (event : MessageEvent ):
@@ -340,8 +345,7 @@ def on(
340345 remove_tome : bool = False ,
341346 auxiliaries : Optional [list [BaseAuxiliary ]] = None ,
342347 providers : Optional [list [Provider , type [Provider ], ProviderFactory , type [ProviderFactory ]]] = None ,
343- ) -> Callable [[TCallable ], TCallable ]:
344- ...
348+ ) -> Callable [[TCallable ], TCallable ]: ...
345349
346350 @overload
347351 def on (
@@ -354,8 +358,7 @@ def on(
354358 * ,
355359 args : Optional [dict [str , Union [TAValue , Args , Arg ]]] = None ,
356360 meta : Optional [CommandMeta ] = None ,
357- ) -> Callable [[TCallable ], TCallable ]:
358- ...
361+ ) -> Callable [[TCallable ], TCallable ]: ...
359362
360363 def on (
361364 self ,
0 commit comments