Skip to content

Commit 20e415d

Browse files
author
Rens Dimmendaal
committed
add retain_types from fastcore.dispatch to fasttransform.utils
1 parent 2659f09 commit 20e415d

File tree

3 files changed

+75
-2
lines changed

3 files changed

+75
-2
lines changed

fasttransform/_modidx.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -100,4 +100,5 @@
100100
'fasttransform.utils.get_name': ('utils.html#get_name', 'fasttransform/utils.py'),
101101
'fasttransform.utils.is_tuple': ('utils.html#is_tuple', 'fasttransform/utils.py'),
102102
'fasttransform.utils.retain_meta': ('utils.html#retain_meta', 'fasttransform/utils.py'),
103-
'fasttransform.utils.retain_type': ('utils.html#retain_type', 'fasttransform/utils.py')}}}
103+
'fasttransform.utils.retain_type': ('utils.html#retain_type', 'fasttransform/utils.py'),
104+
'fasttransform.utils.retain_types': ('utils.html#retain_types', 'fasttransform/utils.py')}}}

fasttransform/utils.py

+16-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/00_utils.ipynb.
22

33
# %% auto 0
4-
__all__ = ['get_name', 'is_tuple', 'retain_meta', 'default_set_meta', 'cast', 'retain_type']
4+
__all__ = ['get_name', 'is_tuple', 'retain_meta', 'default_set_meta', 'cast', 'retain_type', 'retain_types']
55

66
# %% ../nbs/00_utils.ipynb 1
77
from typing import Any
@@ -60,3 +60,18 @@ def retain_type(new, old, ret_type=Any,as_copy=False):
6060
ret_type = old if isinstance(old,type) else type(old)
6161
if ret_type is NoneType or isinstance(new,ret_type): return new
6262
return retain_meta(old, cast(new, ret_type), as_copy=as_copy)
63+
64+
# %% ../nbs/00_utils.ipynb 40
65+
def retain_types(new, old=None, typs=None):
66+
"Cast each item of `new` to type of matching item in `old` if it's a superclass"
67+
if not is_listy(new):
68+
typs = Any if typs is None else typs # make fasttransform.utils.retain_type compatible
69+
return retain_type(new, old,typs)
70+
if typs is not None:
71+
if isinstance(typs, dict):
72+
t = first(typs.keys())
73+
typs = typs[t]
74+
else: t,typs = typs,None
75+
else: t = type(old) if old is not None and isinstance(old,type(new)) else type(new)
76+
return t(L(new, old, typs).map_zip(retain_types, cycled=True))
77+

nbs/00_utils.ipynb

+57
Original file line numberDiff line numberDiff line change
@@ -432,6 +432,63 @@
432432
"test_eq(retain_type(FS(1.), None, Any), FS(1.))"
433433
]
434434
},
435+
{
436+
"cell_type": "markdown",
437+
"id": "fe4ee917",
438+
"metadata": {},
439+
"source": [
440+
"## Retain types"
441+
]
442+
},
443+
{
444+
"cell_type": "markdown",
445+
"id": "17849108",
446+
"metadata": {},
447+
"source": [
448+
"Copied from fastcore.dispatch, Used in fastai."
449+
]
450+
},
451+
{
452+
"cell_type": "code",
453+
"execution_count": null,
454+
"id": "f3cd995c",
455+
"metadata": {},
456+
"outputs": [],
457+
"source": [
458+
"#|export\n",
459+
"def retain_types(new, old=None, typs=None):\n",
460+
" \"Cast each item of `new` to type of matching item in `old` if it's a superclass\"\n",
461+
" if not is_listy(new): \n",
462+
" typs = Any if typs is None else typs # make fasttransform.utils.retain_type compatible\n",
463+
" return retain_type(new, old,typs)\n",
464+
" if typs is not None:\n",
465+
" if isinstance(typs, dict):\n",
466+
" t = first(typs.keys())\n",
467+
" typs = typs[t]\n",
468+
" else: t,typs = typs,None\n",
469+
" else: t = type(old) if old is not None and isinstance(old,type(new)) else type(new)\n",
470+
" return t(L(new, old, typs).map_zip(retain_types, cycled=True))\n",
471+
" "
472+
]
473+
},
474+
{
475+
"cell_type": "code",
476+
"execution_count": null,
477+
"id": "37c1d02c",
478+
"metadata": {},
479+
"outputs": [],
480+
"source": [
481+
"class T(tuple): pass\n",
482+
"\n",
483+
"t1,t2 = retain_types((1,(1,(1,1))), (2,T((2,T((3,4))))))\n",
484+
"test_eq_type(t1, 1)\n",
485+
"test_eq_type(t2, T((1,T((1,1)))))\n",
486+
"\n",
487+
"t1,t2 = retain_types((1,(1,(1,1))), typs = {tuple: [int, {T: [int, {T: [int,int]}]}]})\n",
488+
"test_eq_type(t1, 1)\n",
489+
"test_eq_type(t2, T((1,T((1,1)))))"
490+
]
491+
},
435492
{
436493
"cell_type": "markdown",
437494
"id": "b712c700",

0 commit comments

Comments
 (0)