Skip to content

Commit 3a6dc23

Browse files
authored
Merge pull request #6 from AnswerDotAI/fix-transform-sigs
Fix transform signatures
2 parents c1cc529 + 154aacf commit 3a6dc23

File tree

3 files changed

+223
-17
lines changed

3 files changed

+223
-17
lines changed

fasttransform/transform.py

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
# %% ../nbs/01_transform.ipynb 1
1010
from typing import Any
11+
import inspect
1112

1213
from fastcore.imports import *
1314
from fastcore.foundation import *
@@ -63,7 +64,11 @@ def __call__(cls, *args, **kwargs):
6364
if not hasattr(cls, nm): setattr(cls, nm, Function(f).dispatch(f))
6465
else: getattr(cls,nm).dispatch(f)
6566
return cls
66-
return super().__call__(*args, **kwargs)
67+
obj = super().__call__(*args, **kwargs)
68+
# _TfmMeta.__new__ replaces cls.__signature__ which breaks the signature of a callable
69+
# instances of cls, fix it
70+
if hasattr(obj, '__call__'): obj.__signature__ = inspect.signature(obj.__call__)
71+
return obj
6772

6873

6974
def __new__(cls, name, bases, namespace):
@@ -73,6 +78,8 @@ def __new__(cls, name, bases, namespace):
7378
funcs = [getattr(new_cls, nm)] + [getattr(b, nm,None) for b in bases]
7479
funcs = [f for f in funcs if f]
7580
if funcs: setattr(new_cls, nm, _merge_funcs(*funcs))
81+
# _TfmMeta.__call__ shadows the signature of inheriting classes, set it back
82+
new_cls.__signature__ = inspect.signature(new_cls.__init__)
7683
return new_cls
7784

7885
# %% ../nbs/01_transform.ipynb 14
@@ -130,21 +137,21 @@ def _do_call(self, nm, *args, **kwargs):
130137

131138
add_docs(Transform, decode="Delegate to decodes to undo transform", setup="Delegate to setups to set up transform")
132139

133-
# %% ../nbs/01_transform.ipynb 155
140+
# %% ../nbs/01_transform.ipynb 156
134141
class InplaceTransform(Transform):
135142
"A `Transform` that modifies in-place and just returns whatever it's passed"
136143
def _call(self, fn, *args, split_idx=None, **kwargs):
137144
super()._call(fn,*args, split_idx=split_idx, **kwargs)
138145
return args[0]
139146

140-
# %% ../nbs/01_transform.ipynb 159
147+
# %% ../nbs/01_transform.ipynb 160
141148
class DisplayedTransform(Transform):
142149
"A transform with a `__repr__` that shows its attrs"
143150

144151
@property
145152
def name(self): return f"{super().name} -- {getattr(self,'__stored_args__',{})}\n"
146153

147-
# %% ../nbs/01_transform.ipynb 165
154+
# %% ../nbs/01_transform.ipynb 166
148155
class ItemTransform(Transform):
149156
"A transform that always take tuples as items"
150157
_retain = True
@@ -158,21 +165,21 @@ def _call1(self, x, name, **kwargs):
158165
return retain_type(y, x, Any)
159166

160167

161-
# %% ../nbs/01_transform.ipynb 174
168+
# %% ../nbs/01_transform.ipynb 175
162169
def get_func(t, name, *args, **kwargs):
163170
"Get the `t.name` (potentially partial-ized with `args` and `kwargs`) or `noop` if not defined"
164171
f = nested_callable(t, name)
165172
return f if not (args or kwargs) else partial(f, *args, **kwargs)
166173

167-
# %% ../nbs/01_transform.ipynb 178
174+
# %% ../nbs/01_transform.ipynb 179
168175
class Func():
169176
"Basic wrapper around a `name` with `args` and `kwargs` to call on a given type"
170177
def __init__(self, name, *args, **kwargs): self.name,self.args,self.kwargs = name,args,kwargs
171178
def __repr__(self): return f'sig: {self.name}({self.args}, {self.kwargs})'
172179
def _get(self, t): return get_func(t, self.name, *self.args, **self.kwargs)
173180
def __call__(self,t): return mapped(self._get, t)
174181

175-
# %% ../nbs/01_transform.ipynb 181
182+
# %% ../nbs/01_transform.ipynb 182
176183
class _Sig():
177184
def __getattr__(self,k):
178185
def _inner(*args, **kwargs): return Func(k, *args, **kwargs)
@@ -181,7 +188,7 @@ def _inner(*args, **kwargs): return Func(k, *args, **kwargs)
181188
Sig = _Sig()
182189

183190

184-
# %% ../nbs/01_transform.ipynb 187
191+
# %% ../nbs/01_transform.ipynb 188
185192
def compose_tfms(x, tfms, is_enc=True, reverse=False, **kwargs):
186193
"Apply all `func_nm` attribute of `tfms` on `x`, maybe in `reverse` order"
187194
if reverse: tfms = reversed(tfms)
@@ -191,13 +198,13 @@ def compose_tfms(x, tfms, is_enc=True, reverse=False, **kwargs):
191198
return x
192199

193200

194-
# %% ../nbs/01_transform.ipynb 192
201+
# %% ../nbs/01_transform.ipynb 193
195202
def mk_transform(f):
196203
"Convert function `f` to `Transform` if it isn't already one"
197204
f = instantiate(f)
198205
return f if isinstance(f,(Transform,Pipeline)) else Transform(f)
199206

200-
# %% ../nbs/01_transform.ipynb 193
207+
# %% ../nbs/01_transform.ipynb 194
201208
def gather_attrs(o, k, nm):
202209
"Used in __getattr__ to collect all attrs `k` from `self.{nm}`"
203210
if k.startswith('_') or k==nm: raise AttributeError(k)
@@ -206,12 +213,12 @@ def gather_attrs(o, k, nm):
206213
if not res: raise AttributeError(k)
207214
return res[0] if len(res)==1 else L(res)
208215

209-
# %% ../nbs/01_transform.ipynb 194
216+
# %% ../nbs/01_transform.ipynb 195
210217
def gather_attr_names(o, nm):
211218
"Used in __dir__ to collect all attrs `k` from `self.{nm}`"
212219
return L(getattr(o,nm)).map(dir).concat().unique()
213220

214-
# %% ../nbs/01_transform.ipynb 195
221+
# %% ../nbs/01_transform.ipynb 196
215222
class Pipeline:
216223
"A pipeline of composed (for encode/decode) transforms, setup with types"
217224
def __init__(self, funcs=None, split_idx=None):

nbs/01_transform.ipynb

Lines changed: 186 additions & 3 deletions
Large diffs are not rendered by default.

nbs/fastcore_migration_guide.ipynb

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@
7070
"metadata": {},
7171
"outputs": [],
7272
"source": [
73+
"#|eval: false\n",
7374
"from fastcore.dispatch import typedispatch\n",
7475
"\n",
7576
"@typedispatch \n",
@@ -186,6 +187,7 @@
186187
}
187188
],
188189
"source": [
190+
"#|eval: false\n",
189191
"from fastcore.dispatch import TypeDispatch\n",
190192
"t_fc = TypeDispatch(fs)\n",
191193
"t_fc"
@@ -265,6 +267,7 @@
265267
}
266268
],
267269
"source": [
270+
"#|eval: false\n",
268271
"t_fc.add(lambda x: x**2)\n",
269272
"t_fc"
270273
]
@@ -373,6 +376,16 @@
373376
"Before:"
374377
]
375378
},
379+
{
380+
"cell_type": "code",
381+
"execution_count": null,
382+
"id": "ff9599da",
383+
"metadata": {},
384+
"outputs": [],
385+
"source": [
386+
"def f_str(x:str): return x+'1'"
387+
]
388+
},
376389
{
377390
"cell_type": "code",
378391
"execution_count": null,
@@ -399,8 +412,7 @@
399412
}
400413
],
401414
"source": [
402-
"def f_str(x:str): return x+'1'\n",
403-
"\n",
415+
"#|eval: false\n",
404416
"t_fc2 = TypeDispatch(f_str, bases=t_fc)\n",
405417
"t_fc2"
406418
]
@@ -539,6 +551,7 @@
539551
}
540552
],
541553
"source": [
554+
"#|eval: false\n",
542555
"t_fc[int]"
543556
]
544557
},
@@ -559,6 +572,7 @@
559572
"metadata": {},
560573
"outputs": [],
561574
"source": [
575+
"#|eval: false\n",
562576
"t_fc.returns(5)"
563577
]
564578
},
@@ -661,6 +675,7 @@
661675
}
662676
],
663677
"source": [
678+
"#|eval: false\n",
664679
"@typedispatch\n",
665680
"def f2_fc(x:int|float): return x+2\n",
666681
"@typedispatch\n",
@@ -776,6 +791,7 @@
776791
}
777792
],
778793
"source": [
794+
"#|eval: false\n",
779795
"# Before (subclassing required)\n",
780796
"from fastcore.transform import Transform as FCTransform\n",
781797
"\n",

0 commit comments

Comments
 (0)