Skip to content

Commit dc39fc9

Browse files
fix transform signatures just like fastcore did
1 parent c1cc529 commit dc39fc9

File tree

2 files changed

+154
-76
lines changed

2 files changed

+154
-76
lines changed

fasttransform/transform.py

+21-14
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
# %% ../nbs/01_transform.ipynb 1
1010
from typing import Any
11+
import inspect
1112

1213
from fastcore.imports import *
1314
from fastcore.foundation import *
@@ -51,7 +52,7 @@ def __setitem__(self, k, v):
5152
if k not in self: super().__setitem__(k, Function(v))
5253
self[k].dispatch(v)
5354

54-
# %% ../nbs/01_transform.ipynb 13
55+
# %% ../nbs/01_transform.ipynb 14
5556
class _TfmMeta(type):
5657
@classmethod
5758
def __prepare__(cls, name, bases): return _TfmDict()
@@ -63,7 +64,11 @@ def __call__(cls, *args, **kwargs):
6364
if not hasattr(cls, nm): setattr(cls, nm, Function(f).dispatch(f))
6465
else: getattr(cls,nm).dispatch(f)
6566
return cls
66-
return super().__call__(*args, **kwargs)
67+
obj = super().__call__(*args, **kwargs)
68+
# _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
69+
# instances of cls, fix it
70+
if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)
71+
return obj
6772

6873

6974
def __new__(cls, name, bases, namespace):
@@ -73,9 +78,11 @@ def __new__(cls, name, bases, namespace):
7378
funcs = [getattr(new_cls, nm)] + [getattr(b, nm,None) for b in bases]
7479
funcs = [f for f in funcs if f]
7580
if funcs: setattr(new_cls, nm, _merge_funcs(*funcs))
81+
# _TfmMeta.__call__ shadows the signature of inheriting classes, set it back
82+
new_cls.__signature__ = inspect.signature(new_cls.__init__)
7683
return new_cls
7784

78-
# %% ../nbs/01_transform.ipynb 14
85+
# %% ../nbs/01_transform.ipynb 15
7986
class Transform(metaclass=_TfmMeta):
8087
"Delegates (`__call__`,`decode`,`setup`) to (<code>encodes</code>,<code>decodes</code>,<code>setups</code>) if `split_idx` matches"
8188
split_idx,init_enc,order,train_setup = None,None,0,None
@@ -130,21 +137,21 @@ def _do_call(self, nm, *args, **kwargs):
130137

131138
add_docs(Transform, decode="Delegate to decodes to undo transform", setup="Delegate to setups to set up transform")
132139

133-
# %% ../nbs/01_transform.ipynb 155
140+
# %% ../nbs/01_transform.ipynb 157
134141
class InplaceTransform(Transform):
135142
"A `Transform` that modifies in-place and just returns whatever it's passed"
136143
def _call(self, fn, *args, split_idx=None, **kwargs):
137144
super()._call(fn,*args, split_idx=split_idx, **kwargs)
138145
return args[0]
139146

140-
# %% ../nbs/01_transform.ipynb 159
147+
# %% ../nbs/01_transform.ipynb 161
141148
class DisplayedTransform(Transform):
142149
"A transform with a `__repr__` that shows its attrs"
143150

144151
@property
145152
def name(self): return f"{super().name} -- {getattr(self,'__stored_args__',{})}\n"
146153

147-
# %% ../nbs/01_transform.ipynb 165
154+
# %% ../nbs/01_transform.ipynb 167
148155
class ItemTransform(Transform):
149156
"A transform that always take tuples as items"
150157
_retain = True
@@ -158,21 +165,21 @@ def _call1(self, x, name, **kwargs):
158165
return retain_type(y, x, Any)
159166

160167

161-
# %% ../nbs/01_transform.ipynb 174
168+
# %% ../nbs/01_transform.ipynb 176
162169
def get_func(t, name, *args, **kwargs):
163170
"Get the `t.name` (potentially partial-ized with `args` and `kwargs`) or `noop` if not defined"
164171
f = nested_callable(t, name)
165172
return f if not (args or kwargs) else partial(f, *args, **kwargs)
166173

167-
# %% ../nbs/01_transform.ipynb 178
174+
# %% ../nbs/01_transform.ipynb 180
168175
class Func():
169176
"Basic wrapper around a `name` with `args` and `kwargs` to call on a given type"
170177
def __init__(self, name, *args, **kwargs): self.name,self.args,self.kwargs = name,args,kwargs
171178
def __repr__(self): return f'sig: {self.name}({self.args}, {self.kwargs})'
172179
def _get(self, t): return get_func(t, self.name, *self.args, **self.kwargs)
173180
def __call__(self,t): return mapped(self._get, t)
174181

175-
# %% ../nbs/01_transform.ipynb 181
182+
# %% ../nbs/01_transform.ipynb 183
176183
class _Sig():
177184
def __getattr__(self,k):
178185
def _inner(*args, **kwargs): return Func(k, *args, **kwargs)
@@ -181,7 +188,7 @@ def _inner(*args, **kwargs): return Func(k, *args, **kwargs)
181188
Sig = _Sig()
182189

183190

184-
# %% ../nbs/01_transform.ipynb 187
191+
# %% ../nbs/01_transform.ipynb 189
185192
def compose_tfms(x, tfms, is_enc=True, reverse=False, **kwargs):
186193
"Apply all `func_nm` attribute of `tfms` on `x`, maybe in `reverse` order"
187194
if reverse: tfms = reversed(tfms)
@@ -191,13 +198,13 @@ def compose_tfms(x, tfms, is_enc=True, reverse=False, **kwargs):
191198
return x
192199

193200

194-
# %% ../nbs/01_transform.ipynb 192
201+
# %% ../nbs/01_transform.ipynb 194
195202
def mk_transform(f):
196203
"Convert function `f` to `Transform` if it isn't already one"
197204
f = instantiate(f)
198205
return f if isinstance(f,(Transform,Pipeline)) else Transform(f)
199206

200-
# %% ../nbs/01_transform.ipynb 193
207+
# %% ../nbs/01_transform.ipynb 195
201208
def gather_attrs(o, k, nm):
202209
"Used in __getattr__ to collect all attrs `k` from `self.{nm}`"
203210
if k.startswith('_') or k==nm: raise AttributeError(k)
@@ -206,12 +213,12 @@ def gather_attrs(o, k, nm):
206213
if not res: raise AttributeError(k)
207214
return res[0] if len(res)==1 else L(res)
208215

209-
# %% ../nbs/01_transform.ipynb 194
216+
# %% ../nbs/01_transform.ipynb 196
210217
def gather_attr_names(o, nm):
211218
"Used in __dir__ to collect all attrs `k` from `self.{nm}`"
212219
return L(getattr(o,nm)).map(dir).concat().unique()
213220

214-
# %% ../nbs/01_transform.ipynb 195
221+
# %% ../nbs/01_transform.ipynb 197
215222
class Pipeline:
216223
"A pipeline of composed (for encode/decode) transforms, setup with types"
217224
def __init__(self, funcs=None, split_idx=None):

nbs/01_transform.ipynb

+133-62
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)