2
2
3
3
if t .TYPE_CHECKING :
4
4
from .symbol import Symbol
5
+ from ..typing import ATTR_TYPES
5
6
6
7
T1 = t .TypeVar ("T1" )
7
8
T2 = t .TypeVar ("T2" )
10
11
11
12
ARGS = t .ParamSpec ("ARGS" )
12
13
14
class HashableList(t.Generic[T1]):
    """Wrapper that lets a list be used as a dict key or set member.

    Two wrappers are equal when their wrapped lists are equal. The hash is
    computed from the list contents every time it is requested, so the
    wrapped list must not be mutated while the wrapper is stored in a dict
    or a set.
    """

    def __init__(self, lst: 'list[T1]'):
        self.lst = lst

    @staticmethod
    def _freeze(value):
        """Return a hashable stand-in for ``value``.

        Raw lists, tuples and dicts are converted recursively so that a list
        holding nested containers can still be hashed; every other value is
        returned unchanged.
        """
        if isinstance(value, (list, tuple)):
            return tuple(HashableList._freeze(item) for item in value)
        if isinstance(value, dict):
            # frozenset keeps the result independent of insertion order,
            # matching dict equality.
            return frozenset(
                (key, HashableList._freeze(item)) for key, item in value.items()
            )
        return value

    def __hash__(self):
        # Equal lists freeze to equal tuples, so equal wrappers hash equally.
        return hash(self._freeze(self.lst))

    def __eq__(self, other):
        if not isinstance(other, HashableList):
            # Let Python try the reflected comparison before deciding "not equal".
            return NotImplemented
        return self.lst == other.lst

    def __repr__(self):
        return f"HashableList({self.lst})"
29
+
30
+
31
class HashableDict(t.Generic[T1, T2]):
    """Wrapper that lets a dict be used as a dict key or set member.

    Two wrappers are equal when their wrapped dicts are equal. The hash is
    computed from the dict contents every time it is requested, so the
    wrapped dict must not be mutated while the wrapper is stored in a dict
    or a set.
    """

    def __init__(self, dct: 'dict[T1, T2]'):
        self.dct = dct

    @staticmethod
    def _freeze(value):
        """Return a hashable stand-in for ``value``.

        Raw lists, tuples and dicts are converted recursively so that nested
        containers used as values can still be hashed; every other value is
        returned unchanged.
        """
        if isinstance(value, (list, tuple)):
            return tuple(HashableDict._freeze(item) for item in value)
        if isinstance(value, dict):
            return frozenset(
                (key, HashableDict._freeze(item)) for key, item in value.items()
            )
        return value

    def __hash__(self):
        # Dict equality ignores insertion order, so the hash must as well:
        # a frozenset of items gives equal hashes for equal dicts, whereas a
        # tuple of items would not.
        return hash(self._freeze(self.dct))

    def __eq__(self, other):
        if not isinstance(other, HashableDict):
            # Let Python try the reflected comparison before deciding "not equal".
            return NotImplemented
        return self.dct == other.dct

    def __repr__(self):
        return f"HashableDict({self.dct})"
46
+
47
+
13
48
class GetProtocol (t .Protocol , t .Generic [T1 , T2 ]):
14
49
def get (self , key : 'T1' , ) -> 'T2' : ...
15
50
@@ -20,18 +55,21 @@ def copy(self) -> 'T1': ...
20
55
class Copy:
    """Minimal holder exposing a ``copy()`` method for the value it wraps."""

    def __init__(self, data):
        # The wrapped value is stored as-is.
        self.data = data

    def copy(self):
        """Return the wrapped value itself; no actual copy is made."""
        wrapped = self.data
        return wrapped
26
61
27
62
T5 = t .TypeVar ("T5" , bound = CopyProtocol )
63
+ HASHABLE_ATTRS = str | bool | int | float | HashableList ['HASHABLE_ATTRS' ] | HashableDict [str , 'HASHABLE_ATTRS' ]
28
64
29
65
class Fetcher(t.Generic[T1, T2, T5]):
    """Subscript-style reader over an object with ``get``, with a fallback value.

    ``fetcher[name]`` forwards to ``data.get(name, default)``.
    """

    def __init__(self, data: 't.Union[ GetProtocol[T1, T2], dict[T1, T2] ]', default: 'T5' = Copy(None)):
        self.data = data
        # A default that satisfies CopyProtocol is unwrapped through its
        # copy() method; anything else is stored unchanged.
        # NOTE(review): isinstance() against a Protocol needs it to be
        # runtime-checkable - CopyProtocol is not visible here, confirm.
        if isinstance(default, CopyProtocol):
            self.default = default.copy()
        else:
            self.default = default

    def __getitem__(self, name: 'T1') -> 'T2|T5':
        # Plain dicts and GetProtocol objects are queried the same way.
        source = self.data
        fallback = self.default
        return source.get(name, fallback)
36
74
class InnerHTML :
37
75
def __init__ (self , inner ):
@@ -40,35 +78,54 @@ def __init__(self, inner):
40
78
self .ids : 'dict[str|None, list[Symbol]]' = {}
41
79
self .classes : 'dict[str, list[Symbol]]' = {}
42
80
self .tags : 'dict[type[Symbol], list[Symbol]]' = {}
81
+ self .attrs : 'dict[str, dict[HASHABLE_ATTRS, list[Symbol]]]' = {}
82
+ self .text : 'dict[str, list[Symbol]]' = {}
43
83
44
84
self .children_ids : 'dict[str|None, list[Symbol]]' = {}
45
85
self .children_classes : 'dict[str, list[Symbol]]' = {}
46
86
self .children_tags : 'dict[type[Symbol], list[Symbol]]' = {}
87
+ self .children_attrs : 'dict[str, dict[str, list[Symbol]]]' = {}
88
+ self .children_text : 'dict[str, list[Symbol]]' = {}
47
89
48
def add_elm(self, elm: 'Symbol'):
    """
    Add an element to the children indexes and merge the element's own indexes
    into the aggregate indexes.

    The ``children_*`` indexes record ``elm`` itself. The aggregate indexes
    (``ids``, ``classes``, ``tags``, ``attrs``, ``text``) are rebuilt from the
    current aggregate, the indexes of ``elm.inner_html`` and ``elm`` itself.

    Args:
        elm: Symbol element to add to the indexes.
    """
    def make_hashable(v):
        # Lists and dicts cannot be dict keys; wrap them so they can.
        if isinstance(v, list):
            return HashableList(v)
        if isinstance(v, dict):
            return HashableDict(v)
        return v

    def concat(d1: 'dict', *d2: 'dict'):
        # Merge flat ``{key: [elements]}`` indexes into a fresh dict.
        # Entries of ``d2`` are placed before those of ``d1`` (the ordering
        # this method has always produced).
        ret = {}
        for index in (*d2, d1):
            for k, v in index.items():
                ret.setdefault(k, []).extend(v)
        return ret

    def concat_attrs(d1: 'dict', *d2: 'dict'):
        # Same as ``concat`` for the two-level ``{prop: {value: [elements]}}``
        # indexes: the inner dicts are merged per value instead of being fed
        # to ``list.extend`` (which would store their keys).
        ret = {}
        for index in (*d2, d1):
            for prop, by_value in index.items():
                bucket = ret.setdefault(prop, {})
                for value, elements in by_value.items():
                    bucket.setdefault(value, []).extend(elements)
        return ret

    elm_id = elm.get_prop("id", None)

    # --- direct-children indexes -------------------------------------
    self.children_ids.setdefault(elm_id, []).append(elm)
    for c in elm.classes:
        self.children_classes.setdefault(c, []).append(elm)
    self.children_tags.setdefault(type(elm), []).append(elm)

    # One entry per prop, keyed by the whole (hashable) prop value.
    own_attrs = {}
    for prop, value in elm.props.items():
        key = make_hashable(value)
        self.children_attrs.setdefault(prop, {}).setdefault(key, []).append(elm)
        own_attrs[prop] = {key: [elm]}

    self.children_text.setdefault(elm.text, []).append(elm)

    # --- aggregate indexes -------------------------------------------
    inner = elm.inner_html
    self.ids = concat(self.ids, inner.ids, {elm_id: [elm]})
    self.classes = concat(self.classes, inner.classes, {c: [elm] for c in elm.classes})
    self.tags = concat(self.tags, inner.tags, {type(elm): [elm]})
    self.attrs = concat_attrs(self.attrs, inner.attrs, own_attrs)
    self.text = concat(self.text, inner.text, {elm.text: [elm]})
72
129
73
130
def get_elements_by_id(self, id: 'str'):
    """Return the elements indexed under ``id``, or an empty list if there are none."""
    if id in self.ids:
        return self.ids[id]
    return []
@@ -77,7 +134,45 @@ def get_elements_by_class_name(self, class_name: 'str'):
77
134
return self .classes .get (class_name , [])
78
135
79
136
def get_elements_by_tag_name(self, tag: 'str'):
    """Return the elements whose class matches ``tag``.

    Args:
        tag: Tag name, compared case-insensitively with the ``__name__`` of
            every indexed class. A class object is also accepted and looked
            up directly, since the index is keyed by class.

    Returns:
        The indexed list for the first matching class, or an empty list.
    """
    if isinstance(tag, type):
        # Direct key lookup keeps callers that pass the class itself working.
        return self.tags.get(tag, [])

    wanted = tag.lower()  # hoisted: invariant across the loop
    for tag_class, elements in self.tags.items():
        if tag_class.__name__.lower() == wanted:
            return elements
    return []
142
+
143
def find(self, key: 'str'):
    """Look elements up with a minimal selector.

    ``#name`` searches by id, ``.name`` by class name, and anything else is
    treated as a tag name.
    """
    # Guard clauses for the two prefixed forms; tag lookup is the fallback.
    if key.startswith("#"):
        element_id = key[1:]
        return self.get_elements_by_id(element_id)
    if key.startswith("."):
        class_name = key[1:]
        return self.get_elements_by_class_name(class_name)
    return self.get_elements_by_tag_name(key)
150
+
151
def get_by_text(self, text: 'str'):
    """Return the elements indexed under ``text``, or an empty list if there are none."""
    if text in self.text:
        return self.text[text]
    return []
153
+
154
def get_by_attr(self, attr: 'str', value: 'str'):
    """Return the elements whose ``attr`` property equals ``value``.

    Args:
        attr: Property name to look up in the attribute index.
        value: Property value to match. Lists and dicts are wrapped in
            ``HashableList`` / ``HashableDict`` before the lookup, because
            that is how such values are keyed in the index; passing them raw
            would raise ``TypeError`` (unhashable).

    Returns:
        The indexed list of elements, or an empty list when nothing matches.
    """
    if isinstance(value, list):
        value = HashableList(value)
    elif isinstance(value, dict):
        value = HashableDict(value)
    return self.attrs.get(attr, {}).get(value, [])
156
+
157
def advanced_find(self, tag: 'str', attrs: 'dict[t.Literal["text"] | str, str | bool | int | float | tuple[str, str | bool | int | float] | list[str | bool | int | float | tuple[str, str | bool | int | float]]]' = None):
    """Find elements by selector and narrow them down by text and properties.

    Args:
        tag: Selector handed to ``self.find``.
        attrs: Optional criteria. The special key ``"text"`` is compared with
            the element's ``text``; every other key is a property name. A
            list value means every item of the list must match. The mapping
            is never modified.

    Returns:
        A list of the elements that satisfy every criterion.
    """
    def check_attr(e: 'Symbol', k: 'str', v: 'str | bool | int | float | tuple[str, str | bool | int | float]'):
        prop = e.get_prop(k)
        if isinstance(prop, list):
            # List-valued property: the wanted value must be one of its items.
            return v in prop
        if isinstance(prop, dict):
            # Dict-valued property: the wanted value is a (key, value) pair.
            return v in list(prop.items())
        return prop == v

    # Private copy: popping "text" must not leak into the caller's dict.
    criteria = dict(attrs) if attrs else {}
    matches = list(self.find(tag))

    if "text" in criteria:
        text = criteria.pop("text")
        matches = [e for e in matches if e.text == text]

    # Filter eagerly, one criterion at a time. A lazy filter() with a lambda
    # would read k and v only when consumed, i.e. after the loop has ended,
    # so every stage would test the last criterion.
    for k, v in criteria.items():
        wanted = v if isinstance(v, list) else [v]
        matches = [e for e in matches if all(check_attr(e, k, i) for i in wanted)]
    return matches
81
176
82
177
@property
83
178
def id (self ):
0 commit comments