|
432 | 432 | "test_eq(retain_type(FS(1.), None, Any), FS(1.))"
|
433 | 433 | ]
|
434 | 434 | },
|
| 435 | + { |
| 436 | + "cell_type": "markdown", |
| 437 | + "id": "fe4ee917", |
| 438 | + "metadata": {}, |
| 439 | + "source": [ |
| 440 | + "## Retain types" |
| 441 | + ] |
| 442 | + }, |
| 443 | + { |
| 444 | + "cell_type": "markdown", |
| 445 | + "id": "17849108", |
| 446 | + "metadata": {}, |
| 447 | + "source": [ |
| 448 | + "Copied from fastcore.dispatch, Used in fastai." |
| 449 | + ] |
| 450 | + }, |
| 451 | + { |
| 452 | + "cell_type": "code", |
| 453 | + "execution_count": null, |
| 454 | + "id": "f3cd995c", |
| 455 | + "metadata": {}, |
| 456 | + "outputs": [], |
| 457 | + "source": [ |
| 458 | + "#|export\n", |
| 459 | + "def retain_types(new, old=None, typs=None):\n", |
| 460 | + " \"Cast each item of `new` to type of matching item in `old` if it's a superclass\"\n", |
| 461 | + " if not is_listy(new): \n", |
| 462 | + " typs = Any if typs is None else typs # make fasttransform.utils.retain_type compatible\n", |
| 463 | + " return retain_type(new, old,typs)\n", |
| 464 | + " if typs is not None:\n", |
| 465 | + " if isinstance(typs, dict):\n", |
| 466 | + " t = first(typs.keys())\n", |
| 467 | + " typs = typs[t]\n", |
| 468 | + " else: t,typs = typs,None\n", |
| 469 | + " else: t = type(old) if old is not None and isinstance(old,type(new)) else type(new)\n", |
| 470 | + " return t(L(new, old, typs).map_zip(retain_types, cycled=True))\n", |
| 471 | + " " |
| 472 | + ] |
| 473 | + }, |
| 474 | + { |
| 475 | + "cell_type": "code", |
| 476 | + "execution_count": null, |
| 477 | + "id": "37c1d02c", |
| 478 | + "metadata": {}, |
| 479 | + "outputs": [], |
| 480 | + "source": [ |
| 481 | + "class T(tuple): pass\n", |
| 482 | + "\n", |
| 483 | + "t1,t2 = retain_types((1,(1,(1,1))), (2,T((2,T((3,4))))))\n", |
| 484 | + "test_eq_type(t1, 1)\n", |
| 485 | + "test_eq_type(t2, T((1,T((1,1)))))\n", |
| 486 | + "\n", |
| 487 | + "t1,t2 = retain_types((1,(1,(1,1))), typs = {tuple: [int, {T: [int, {T: [int,int]}]}]})\n", |
| 488 | + "test_eq_type(t1, 1)\n", |
| 489 | + "test_eq_type(t2, T((1,T((1,1)))))" |
| 490 | + ] |
| 491 | + }, |
435 | 492 | {
|
436 | 493 | "cell_type": "markdown",
|
437 | 494 | "id": "b712c700",
|
|
0 commit comments