8
8
9
9
# %% ../nbs/01_transform.ipynb 1
10
10
from typing import Any
11
+ import inspect
11
12
12
13
from fastcore .imports import *
13
14
from fastcore .foundation import *
@@ -51,7 +52,7 @@ def __setitem__(self, k, v):
51
52
if k not in self : super ().__setitem__ (k , Function (v ))
52
53
self [k ].dispatch (v )
53
54
54
- # %% ../nbs/01_transform.ipynb 13
55
+ # %% ../nbs/01_transform.ipynb 14
55
56
class _TfmMeta (type ):
56
57
@classmethod
57
58
def __prepare__ (cls , name , bases ): return _TfmDict ()
@@ -63,7 +64,11 @@ def __call__(cls, *args, **kwargs):
63
64
if not hasattr (cls , nm ): setattr (cls , nm , Function (f ).dispatch (f ))
64
65
else : getattr (cls ,nm ).dispatch (f )
65
66
return cls
66
- return super ().__call__ (* args , ** kwargs )
67
+ obj = super ().__call__ (* args , ** kwargs )
68
+ # _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
69
+ # instances of cls, fix it
70
+ if hasattr (obj , '__call__' ): obj .__signature__ = inspect .signature (obj .__call__ )
71
+ return obj
67
72
68
73
69
74
def __new__ (cls , name , bases , namespace ):
@@ -73,9 +78,11 @@ def __new__(cls, name, bases, namespace):
73
78
funcs = [getattr (new_cls , nm )] + [getattr (b , nm ,None ) for b in bases ]
74
79
funcs = [f for f in funcs if f ]
75
80
if funcs : setattr (new_cls , nm , _merge_funcs (* funcs ))
81
+ # _TfmMeta.__call__ shadows the signature of inheriting classes, set it back
82
+ new_cls .__signature__ = inspect .signature (new_cls .__init__ )
76
83
return new_cls
77
84
78
- # %% ../nbs/01_transform.ipynb 14
85
+ # %% ../nbs/01_transform.ipynb 15
79
86
class Transform (metaclass = _TfmMeta ):
80
87
"Delegates (`__call__`,`decode`,`setup`) to (<code>encodes</code>,<code>decodes</code>,<code>setups</code>) if `split_idx` matches"
81
88
split_idx ,init_enc ,order ,train_setup = None ,None ,0 ,None
@@ -130,21 +137,21 @@ def _do_call(self, nm, *args, **kwargs):
130
137
131
138
add_docs (Transform , decode = "Delegate to decodes to undo transform" , setup = "Delegate to setups to set up transform" )
132
139
133
- # %% ../nbs/01_transform.ipynb 155
140
+ # %% ../nbs/01_transform.ipynb 157
134
141
class InplaceTransform (Transform ):
135
142
"A `Transform` that modifies in-place and just returns whatever it's passed"
136
143
def _call (self , fn , * args , split_idx = None , ** kwargs ):
137
144
super ()._call (fn ,* args , split_idx = split_idx , ** kwargs )
138
145
return args [0 ]
139
146
140
- # %% ../nbs/01_transform.ipynb 159
147
+ # %% ../nbs/01_transform.ipynb 161
141
148
class DisplayedTransform (Transform ):
142
149
"A transform with a `__repr__` that shows its attrs"
143
150
144
151
@property
145
152
def name (self ): return f"{ super ().name } -- { getattr (self ,'__stored_args__' ,{})} \n "
146
153
147
- # %% ../nbs/01_transform.ipynb 165
154
+ # %% ../nbs/01_transform.ipynb 167
148
155
class ItemTransform (Transform ):
149
156
"A transform that always take tuples as items"
150
157
_retain = True
@@ -158,21 +165,21 @@ def _call1(self, x, name, **kwargs):
158
165
return retain_type (y , x , Any )
159
166
160
167
161
- # %% ../nbs/01_transform.ipynb 174
168
+ # %% ../nbs/01_transform.ipynb 176
162
169
def get_func (t , name , * args , ** kwargs ):
163
170
"Get the `t.name` (potentially partial-ized with `args` and `kwargs`) or `noop` if not defined"
164
171
f = nested_callable (t , name )
165
172
return f if not (args or kwargs ) else partial (f , * args , ** kwargs )
166
173
167
- # %% ../nbs/01_transform.ipynb 178
174
+ # %% ../nbs/01_transform.ipynb 180
168
175
class Func ():
169
176
"Basic wrapper around a `name` with `args` and `kwargs` to call on a given type"
170
177
def __init__ (self , name , * args , ** kwargs ): self .name ,self .args ,self .kwargs = name ,args ,kwargs
171
178
def __repr__ (self ): return f'sig: { self .name } ({ self .args } , { self .kwargs } )'
172
179
def _get (self , t ): return get_func (t , self .name , * self .args , ** self .kwargs )
173
180
def __call__ (self ,t ): return mapped (self ._get , t )
174
181
175
- # %% ../nbs/01_transform.ipynb 181
182
+ # %% ../nbs/01_transform.ipynb 183
176
183
class _Sig ():
177
184
def __getattr__ (self ,k ):
178
185
def _inner (* args , ** kwargs ): return Func (k , * args , ** kwargs )
@@ -181,7 +188,7 @@ def _inner(*args, **kwargs): return Func(k, *args, **kwargs)
181
188
Sig = _Sig ()
182
189
183
190
184
- # %% ../nbs/01_transform.ipynb 187
191
+ # %% ../nbs/01_transform.ipynb 189
185
192
def compose_tfms (x , tfms , is_enc = True , reverse = False , ** kwargs ):
186
193
"Apply all `func_nm` attribute of `tfms` on `x`, maybe in `reverse` order"
187
194
if reverse : tfms = reversed (tfms )
@@ -191,13 +198,13 @@ def compose_tfms(x, tfms, is_enc=True, reverse=False, **kwargs):
191
198
return x
192
199
193
200
194
- # %% ../nbs/01_transform.ipynb 192
201
+ # %% ../nbs/01_transform.ipynb 194
195
202
def mk_transform (f ):
196
203
"Convert function `f` to `Transform` if it isn't already one"
197
204
f = instantiate (f )
198
205
return f if isinstance (f ,(Transform ,Pipeline )) else Transform (f )
199
206
200
- # %% ../nbs/01_transform.ipynb 193
207
+ # %% ../nbs/01_transform.ipynb 195
201
208
def gather_attrs (o , k , nm ):
202
209
"Used in __getattr__ to collect all attrs `k` from `self.{nm}`"
203
210
if k .startswith ('_' ) or k == nm : raise AttributeError (k )
@@ -206,12 +213,12 @@ def gather_attrs(o, k, nm):
206
213
if not res : raise AttributeError (k )
207
214
return res [0 ] if len (res )== 1 else L (res )
208
215
209
- # %% ../nbs/01_transform.ipynb 194
216
+ # %% ../nbs/01_transform.ipynb 196
210
217
def gather_attr_names (o , nm ):
211
218
"Used in __dir__ to collect all attrs `k` from `self.{nm}`"
212
219
return L (getattr (o ,nm )).map (dir ).concat ().unique ()
213
220
214
- # %% ../nbs/01_transform.ipynb 195
221
+ # %% ../nbs/01_transform.ipynb 197
215
222
class Pipeline :
216
223
"A pipeline of composed (for encode/decode) transforms, setup with types"
217
224
def __init__ (self , funcs = None , split_idx = None ):
0 commit comments