diff --git a/analysis/hf_reanalysis/reanalysis.ipynb b/analysis/hf_reanalysis/reanalysis.ipynb index 609cc6be1..0230536c5 100644 --- a/analysis/hf_reanalysis/reanalysis.ipynb +++ b/analysis/hf_reanalysis/reanalysis.ipynb @@ -522,7 +522,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 19, "id": "f850551f", "metadata": {}, "outputs": [], @@ -534,6 +534,7 @@ " z_key: str = None,\n", " z_type: Literal['log', 'linear'] = 'log',\n", " color_key: str = None,\n", + " color_type: Literal['log', 'log2', 'linear'] = 'linear',\n", " fit_fn = None,\n", " savepath: str = None,\n", "):\n", @@ -582,7 +583,13 @@ " )\n", " hovertemplate=f\"{x_key}:%{{x:.2e}}
{y_key}:%{{y:.2e}}\"\n", " if color_key:\n", - " color_variable = runs[color_key]\n", + " if color_type==\"log\":\n", + " color_variable = np.log(runs[color_key])\n", + " elif color_type==\"log2\":\n", + " color_variable = np.log2(runs[color_key])\n", + " else:\n", + " color_variable = runs[color_key]\n", + " \n", " hovertemplate += f\"
{color_key}:%{{marker.color:.2e}}\"\n", " else:\n", " color_variable = None\n", @@ -618,11 +625,9 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 17, "id": "969c0e98", - "metadata": { - "scrolled": false - }, + "metadata": {}, "outputs": [], "source": [ "model_params = []\n", @@ -683,6 +688,7 @@ "runs = {\n", " 'N': np.array(model_params),\n", " 'D': np.array(unique_tokens),\n", + " 'D_total': np.array(tokens),\n", " 'L': np.array(losses),\n", " 'R': np.array([x/y for x,y in zip(tokens, unique_tokens)])\n", "}\n", @@ -839,14 +845,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|█████████████████████████████████████████████████████████████████| 4500/4500 [10:16<00:00, 7.30it/s]" + "100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4.02it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "x0: [10. 20. 1. 0. 1.]\n", + "x0: [10 20 1 0 1]\n", " fun: 2.3162021954760153e-05\n", " hess_inv: <5x5 LbfgsInvHessProduct with dtype=float64>\n", " jac: array([ 1.21138712e-11, -5.19463000e-11, 5.91681975e-11, -2.29992866e-10,\n", @@ -937,7 +943,9 @@ "cell_type": "code", "execution_count": 13, "id": "b44df3e9", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [], "source": [ "runs_single_epoch = {\n", @@ -954,6 +962,43 @@ " savepath='single-epoch-runs-residuals.html'\n", ")" ] + }, + { + "cell_type": "markdown", + "id": "97430390-c275-432c-b968-8d442772ccd6", + "metadata": {}, + "source": [ + "# Effective Data" + ] + }, + { + "cell_type": "markdown", + "id": "e1e91ab4-ae5e-40ff-982e-713925b46922", + "metadata": {}, + "source": [ + "Visualize how epoch number affects the Chinchilla Scaling Law" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "dd301104-be42-4d08-9028-6a2a24927161", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "scaling_scatter(\n", + " runs, \n", + " x_key='N', \n", + " y_key='D_total', \n", + " z_key='L', \n", + " color_key='R',\n", + " color_type='log',\n", + " fit_fn=single_epoch_fit,\n", + " savepath='single-epoch-runs-fitted-multiepoch-color.html'\n", + ")" + ] } ], "metadata": { diff --git a/analysis/hf_reanalysis/single-epoch-runs-fitted-multiepoch-color.html b/analysis/hf_reanalysis/single-epoch-runs-fitted-multiepoch-color.html new file mode 100644 index 000000000..34733b640 --- /dev/null +++ b/analysis/hf_reanalysis/single-epoch-runs-fitted-multiepoch-color.html @@ -0,0 +1,14 @@ + + + +
+
+ + \ No newline at end of file