|
2572 | 2572 | "id": "GxLAr-7p0ATX"
|
2573 | 2573 | },
|
2574 | 2574 | "source": [
|
2575 |
| - "To use `tf.data.experimental.sample_from_datasets` pass the datasets, and the weight for each:" |
| 2575 | + "To use `tf.data.Dataset.sample_from_datasets` pass the datasets, and the weight for each:" |
2576 | 2576 | ]
|
2577 | 2577 | },
|
2578 | 2578 | {
|
|
2583 | 2583 | },
|
2584 | 2584 | "outputs": [],
|
2585 | 2585 | "source": [
|
2586 |
| - "balanced_ds = tf.data.experimental.sample_from_datasets(\n", |
| 2586 | + "balanced_ds = tf.data.Dataset.sample_from_datasets(\n", |
2587 | 2587 | " [negative_ds, positive_ds], [0.5, 0.5]).batch(10)"
|
2588 | 2588 | ]
|
2589 | 2589 | },
|
|
2623 | 2623 | "id": "kZ9ezkK6irMD"
|
2624 | 2624 | },
|
2625 | 2625 | "source": [
|
2626 |
| - "One problem with the above `experimental.sample_from_datasets` approach is that\n", |
2627 |
| - "it needs a separate `tf.data.Dataset` per class. Using `Dataset.filter`\n", |
2628 |
| - "works, but results in all the data being loaded twice.\n", |
| 2626 | + "One problem with the above `Dataset.sample_from_datasets` approach is that\n", |
| 2627 | + "it needs a separate `tf.data.Dataset` per class. You could use `Dataset.filter`\n", |
| 2628 | + "to create those two datasets, but that results in all the data being loaded twice.\n", |
2629 | 2629 | "\n",
|
2630 |
| - "The `data.experimental.rejection_resample` function can be applied to a dataset to rebalance it, while only loading it once. Elements will be dropped from the dataset to achieve balance.\n", |
| 2630 | + "The `data.Dataset.rejection_resample` method can be applied to a dataset to rebalance it, while only loading it once. Elements will be dropped from the dataset to achieve balance.\n", |
2631 | 2631 | "\n",
|
2632 |
| - "`data.experimental.rejection_resample` takes a `class_func` argument. This `class_func` is applied to each dataset element, and is used to determine which class an example belongs to for the purposes of balancing.\n", |
| 2632 | + "`data.Dataset.rejection_resample` takes a `class_func` argument. This `class_func` is applied to each dataset element, and is used to determine which class an example belongs to for the purposes of balancing.\n", |
2633 | 2633 | "\n",
|
2634 |
| - "The elements of `creditcard_ds` are already `(features, label)` pairs. So the `class_func` just needs to return those labels:" |
| 2634 | + "The goal here is to balance the label distribution, and the elements of `creditcard_ds` are already `(features, label)` pairs. So the `class_func` just needs to return those labels:" |
2635 | 2635 | ]
|
2636 | 2636 | },
|
2637 | 2637 | {
|
|
2646 | 2646 | " return label"
|
2647 | 2647 | ]
|
2648 | 2648 | },
|
2649 |
| - { |
2650 |
| - "cell_type": "markdown", |
2651 |
| - "metadata": { |
2652 |
| - "id": "DdKmE8Jumlp0" |
2653 |
| - }, |
2654 |
| - "source": [ |
2655 |
| - "The resampler also needs a target distribution, and optionally an initial distribution estimate:" |
2656 |
| - ] |
2657 |
| - }, |
2658 |
| - { |
2659 |
| - "cell_type": "code", |
2660 |
| - "execution_count": null, |
2661 |
| - "metadata": { |
2662 |
| - "id": "9tv0tWNxmkzM" |
2663 |
| - }, |
2664 |
| - "outputs": [], |
2665 |
| - "source": [ |
2666 |
| - "resampler = tf.data.experimental.rejection_resample(\n", |
2667 |
| - " class_func, target_dist=[0.5, 0.5], initial_dist=fractions)" |
2668 |
| - ] |
2669 |
| - }, |
2670 | 2649 | {
|
2671 | 2650 | "cell_type": "markdown",
|
2672 | 2651 | "metadata": {
|
2673 | 2652 | "id": "YxJrOZVToGuE"
|
2674 | 2653 | },
|
2675 | 2654 | "source": [
|
2676 |
| - "The resampler deals with individual examples, so you must `unbatch` the dataset before applying the resampler:" |
| 2655 | + "The resampling method deals with individual examples, so in this case you must `unbatch` the dataset before applying that method.\n", |
| 2656 | + "\n", |
| 2657 | + "The method needs a target distribution, and optionally an initial distribution estimate as inputs." |
2677 | 2658 | ]
|
2678 | 2659 | },
|
2679 | 2660 | {
|
|
2684 | 2665 | },
|
2685 | 2666 | "outputs": [],
|
2686 | 2667 | "source": [
|
2687 |
| - "resample_ds = creditcard_ds.unbatch().apply(resampler).batch(10)" |
| 2668 | + "resample_ds = (\n", |
| 2669 | + " creditcard_ds\n", |
| 2670 | + " .unbatch()\n", |
| 2671 | + " .rejection_resample(class_func, target_dist=[0.5,0.5],\n", |
| 2672 | + " initial_dist=fractions)\n", |
| 2673 | + " .batch(10))" |
2688 | 2674 | ]
|
2689 | 2675 | },
|
2690 | 2676 | {
|
|
2693 | 2679 | "id": "L-HnC1s8idqV"
|
2694 | 2680 | },
|
2695 | 2681 | "source": [
|
2696 |
| - "The resampler returns creates `(class, example)` pairs from the output of the `class_func`. In this case, the `example` was already a `(feature, label)` pair, so use `map` to drop the extra copy of the labels:" |
| 2682 | + "The `rejection_resample` method returns `(class, example)` pairs where the `class` is the output of the `class_func`. In this case, the `example` was already a `(feature, label)` pair, so use `map` to drop the extra copy of the labels:" |
2697 | 2683 | ]
|
2698 | 2684 | },
|
2699 | 2685 | {
|
|
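Note on the changed text: the markdown cells describe `rejection_resample` as dropping elements to reach a target label distribution, and as returning `(class, example)` pairs. The sketch below illustrates that idea in plain Python using only the standard library. It is the textbook rejection-sampling scheme, not TensorFlow's actual implementation; the helper `rejection_resample`, its `seed` parameter, and the toy data are assumptions made up for illustration, while `class_func`, `target_dist`, `initial_dist`, and `fractions` mirror the names used in the diff.

```python
import random


def rejection_resample(examples, class_func, target_dist, initial_dist, seed=0):
    """Yield (class, example) pairs, dropping examples to approach target_dist.

    Hypothetical pure-Python helper; not TensorFlow's implementation.
    Assumes every entry of initial_dist is positive.
    """
    rng = random.Random(seed)
    # How strongly each class must be favoured relative to its current share.
    ratios = [t / i for t, i in zip(target_dist, initial_dist)]
    max_ratio = max(ratios)
    # Scale so the most under-represented class is always kept.
    accept_prob = [r / max_ratio for r in ratios]
    for example in examples:
        cls = class_func(example)
        if rng.random() < accept_prob[cls]:
            yield cls, example


def class_func(example):
    # Elements are already (features, label) pairs, so return the label.
    features, label = example
    return label


# Toy data (made up): 9,900 negatives (label 0) and 100 positives (label 1).
examples = [(i, 0) for i in range(9900)] + [(i, 1) for i in range(100)]
fractions = [0.99, 0.01]

resampled = list(rejection_resample(
    examples, class_func, target_dist=[0.5, 0.5], initial_dist=fractions))

# Drop the extra copy of the label, as the notebook does with `map`.
balanced = [example for _, example in resampled]
counts = [sum(1 for _, label in balanced if label == c) for c in (0, 1)]
print(counts)  # roughly equal counts for the two labels
```

In this sketch every positive example is kept and only about 1% of the negatives survive, which is why the approach reads the data once but discards most of the majority class.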