Skip to content

Commit 7c9f31b

Browse files
committed
✨ version 0.3.2
is_direct/public_message
1 parent 5a2ba13 commit 7c9f31b

6 files changed

Lines changed: 83 additions & 38 deletions

File tree

arclet/entari/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@
4343
from .core import Entari as Entari
4444
from .event import MessageCreatedEvent as MessageCreatedEvent
4545
from .event import MessageEvent as MessageEvent
46+
from .filter import is_direct_message as is_direct_message
47+
from .filter import is_public_message as is_public_message
4648
from .message import MessageChain as MessageChain
4749
from .plugin import Plugin as Plugin
4850
from .plugin import load_plugin as load_plugin

arclet/entari/event.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def parse(cls, account: Account, origin: SatoriEvent):
3333
continue
3434
if attr := getattr(origin, fd.name, None):
3535
attrs[fd.name] = attr
36-
res = cls(**attrs)
36+
res = cls(**attrs) # type: ignore
3737
res._origin = origin
3838
return res
3939

@@ -56,22 +56,32 @@ def validate(self, param: Param):
5656
return param.name == "operator" and super().validate(param)
5757

5858
async def __call__(self, context: Contexts):
59+
if "operator" in context:
60+
return context["operator"]
5961
return context["$origin_event"].operator
6062

6163
class UserProvider(Provider[User]):
6264
async def __call__(self, context: Contexts):
65+
if "user" in context:
66+
return context["user"]
6367
return context["$origin_event"].user
6468

6569
class MessageProvider(Provider[MessageObject]):
6670
async def __call__(self, context: Contexts):
71+
if "$message_origin" in context:
72+
return context["$message_origin"]
6773
return context["$origin_event"].message
6874

6975
class ChannelProvider(Provider[Channel]):
7076
async def __call__(self, context: Contexts):
77+
if "channel" in context:
78+
return context["channel"]
7179
return context["$origin_event"].channel
7280

7381
class GuildProvider(Provider[Guild]):
7482
async def __call__(self, context: Contexts):
83+
if "guild" in context:
84+
return context["guild"]
7585
return context["$origin_event"].guild
7686

7787

arclet/entari/filter.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from typing import Optional
2+
3+
from arclet.letoderea import Contexts, JudgeAuxiliary, Scope
4+
from satori import Channel, ChannelType
5+
6+
7+
class DirectMessageJudger(JudgeAuxiliary):
8+
async def __call__(self, scope: Scope, context: Contexts) -> Optional[bool]:
9+
if "channel" not in context:
10+
return False
11+
channel: Channel = context["channel"]
12+
return channel.type == ChannelType.DIRECT
13+
14+
@property
15+
def scopes(self) -> set[Scope]:
16+
return {Scope.prepare}
17+
18+
19+
is_direct_message = DirectMessageJudger()
20+
21+
22+
class PublicMessageJudger(JudgeAuxiliary):
23+
async def __call__(self, scope: Scope, context: Contexts) -> Optional[bool]:
24+
if "channel" not in context:
25+
return False
26+
channel: Channel = context["channel"]
27+
return channel.type != ChannelType.DIRECT
28+
29+
@property
30+
def scopes(self) -> set[Scope]:
31+
return {Scope.prepare}
32+
33+
34+
is_public_message = PublicMessageJudger()

arclet/entari/message.py

Lines changed: 31 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class MessageChain(list[TE]):
1616
"""消息序列
1717
18-
参数:
18+
Args:
1919
message: 消息内容
2020
"""
2121

@@ -101,54 +101,54 @@ def __iadd__(self, other: str | TE | Iterable[TE]) -> Self:
101101
def __getitem__(self, args: type[TE1]) -> MessageChain[TE1]:
102102
"""获取仅包含指定消息段类型的消息
103103
104-
参数:
104+
Args:
105105
args: 消息段类型
106106
107-
返回:
107+
Return:
108108
所有类型为 `args` 的消息段
109109
"""
110110

111111
@overload
112112
def __getitem__(self, args: tuple[type[TE1], int]) -> TE1:
113113
"""索引指定类型的消息段
114114
115-
参数:
115+
Args:
116116
args: 消息段类型和索引
117117
118-
返回:
118+
Return:
119119
类型为 `args[0]` 的消息段第 `args[1]` 个
120120
"""
121121

122122
@overload
123123
def __getitem__(self, args: tuple[type[TE1], slice]) -> MessageChain[TE1]:
124124
"""切片指定类型的消息段
125125
126-
参数:
126+
Args:
127127
args: 消息段类型和切片
128128
129-
返回:
129+
Return:
130130
类型为 `args[0]` 的消息段切片 `args[1]`
131131
"""
132132

133133
@overload
134134
def __getitem__(self, args: int) -> TE:
135135
"""索引消息段
136136
137-
参数:
137+
Args:
138138
args: 索引
139139
140-
返回:
140+
Return:
141141
第 `args` 个消息段
142142
"""
143143

144144
@overload
145145
def __getitem__(self, args: slice) -> Self:
146146
"""切片消息段
147147
148-
参数:
148+
Args:
149149
args: 切片
150150
151-
返回:
151+
Return:
152152
消息切片 `args`
153153
"""
154154

@@ -174,9 +174,9 @@ def __getitem__(
174174
def __contains__(self, value: str | Element | type[Element]) -> bool:
175175
"""检查消息段是否存在
176176
177-
参数:
177+
Args:
178178
value: 消息段或消息段类型
179-
返回:
179+
Return:
180180
消息内是否存在给定消息段或给定类型的消息段
181181
"""
182182
if isinstance(value, type):
@@ -186,20 +186,19 @@ def __contains__(self, value: str | Element | type[Element]) -> bool:
186186
return super().__contains__(value)
187187

188188
def has(self, value: str | Element | type[Element]) -> bool:
189-
"""与 {ref}``__contains__` <nonebot.adapters.Message.__contains__>` 相同"""
190189
return value in self
191190

192191
def index(self, value: str | Element | type[Element], *args: SupportsIndex) -> int:
193192
"""索引消息段
194193
195-
参数:
194+
Args:
196195
value: 消息段或者消息段类型
197-
arg: start 与 end
196+
args: start 与 end
198197
199-
返回:
198+
Return:
200199
索引 index
201200
202-
异常:
201+
Raise:
203202
ValueError: 消息段不存在
204203
"""
205204
if isinstance(value, type):
@@ -214,11 +213,11 @@ def index(self, value: str | Element | type[Element], *args: SupportsIndex) -> i
214213
def get(self, type_: type[TE], count: int | None = None) -> MessageChain[TE]:
215214
"""获取指定类型的消息段
216215
217-
参数:
216+
Args:
218217
type_: 消息段类型
219218
count: 获取个数
220219
221-
返回:
220+
Return:
222221
构建的新消息
223222
"""
224223
if count is None:
@@ -235,10 +234,10 @@ def get(self, type_: type[TE], count: int | None = None) -> MessageChain[TE]:
235234
def count(self, value: type[Element] | str | Element) -> int:
236235
"""计算指定消息段的个数
237236
238-
参数:
237+
Args:
239238
value: 消息段或消息段类型
240239
241-
返回:
240+
Return:
242241
个数
243242
"""
244243
if isinstance(value, str):
@@ -252,10 +251,10 @@ def count(self, value: type[Element] | str | Element) -> int:
252251
def only(self, value: type[Element] | str | Element) -> bool:
253252
"""检查消息中是否仅包含指定消息段
254253
255-
参数:
254+
Args:
256255
value: 指定消息段或消息段类型
257256
258-
返回:
257+
Return:
259258
是否仅包含指定消息段
260259
"""
261260
if isinstance(value, type):
@@ -267,10 +266,10 @@ def only(self, value: type[Element] | str | Element) -> bool:
267266
def join(self, iterable: Iterable[TE1 | MessageChain[TE1]]) -> MessageChain[TE | TE1]:
268267
"""将多个消息连接并将自身作为分割
269268
270-
参数:
269+
Args:
271270
iterable: 要连接的消息
272271
273-
返回:
272+
Return:
274273
连接后的消息
275274
"""
276275
ret = MessageChain()
@@ -287,24 +286,24 @@ def copy(self) -> MessageChain[TE]:
287286
"""深拷贝消息"""
288287
return deepcopy(self)
289288

290-
def include(self, *types: type[Element]) -> Self:
289+
def include(self, *types: type[Element]) -> MessageChain:
291290
"""过滤消息
292291
293-
参数:
292+
Args:
294293
types: 包含的消息段类型
295294
296-
返回:
295+
Return:
297296
新构造的消息
298297
"""
299298
return MessageChain(seg for seg in self if seg.__class__ in types)
300299

301-
def exclude(self, *types: type[Element]) -> Self:
300+
def exclude(self, *types: type[Element]) -> MessageChain:
302301
"""过滤消息
303302
304-
参数:
303+
Args:
305304
types: 不包含的消息段类型
306305
307-
返回:
306+
Return:
308307
新构造的消息
309308
"""
310309
return MessageChain(seg for seg in self if seg.__class__ not in types)

example_plugin.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
1-
from arclet.entari import MessageCreatedEvent, Plugin, EntariCommands, ContextSession, AlconnaDispatcher
2-
from arclet.alconna import Alconna
1+
from arclet.entari import MessageCreatedEvent, Plugin, EntariCommands, ContextSession, AlconnaDispatcher, is_direct_message
2+
from arclet.alconna import Alconna, Args, AllParam
33

44
plug = Plugin()
55

66
disp_message = plug.dispatch(MessageCreatedEvent)
77

88

9-
@disp_message.on()
9+
@disp_message.on(auxiliaries=[is_direct_message])
1010
async def _(event: MessageCreatedEvent):
1111
print(event.content)
1212

1313

14-
on_alconna = plug.mount(AlconnaDispatcher(Alconna("test")))
14+
on_alconna = plug.mount(AlconnaDispatcher(Alconna("chat", Args["content", AllParam])))
1515

1616

1717
@on_alconna.on()

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "arclet-entari"
3-
version = "0.3.1"
3+
version = "0.3.2"
44
description = "Simple IM Framework based on satori-python"
55
authors = [
66
{name = "RF-Tar-Railt",email = "rf_tar_railt@qq.com"},

0 commit comments

Comments
 (0)