Commit
chinchilla fitting working!
zhangir-azerbayev committed Oct 30, 2023
1 parent b7bff90 commit d62cf51
Showing 4 changed files with 138 additions and 73 deletions.
2 changes: 1 addition & 1 deletion analysis/hf_reanalysis/data-constrained-scaling-raw.html

Large diffs are not rendered by default.

181 changes: 109 additions & 72 deletions analysis/hf_reanalysis/reanalysis.ipynb
@@ -7,7 +7,7 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import List, Dict\n",
"from typing import List, Dict, Literal\n",
"from functools import partial, reduce\n",
"import operator\n",
"import itertools\n",
@@ -532,7 +532,9 @@
" x_key: str, \n",
" y_key: str, \n",
" z_key: str = None,\n",
" z_type: Literal['log', 'linear'] = 'log',\n",
" color_key: str = None,\n",
" fit_fn = None,\n",
" savepath: str = None,\n",
"):\n",
" x = runs[x_key]\n",
@@ -545,10 +547,30 @@
" scene=dict(\n",
" xaxis=dict(type='log', title=x_key),\n",
" yaxis=dict(type='log', title=y_key),\n",
" zaxis=dict(type='log', title=z_key),\n",
" zaxis=dict(type=z_type, title=z_key),\n",
" )\n",
" )\n",
" hovertemplate=f\"<b>{x_key}:%{{x:.2e}}</b><br><b>{y_key}:%{{y:.2e}}</b><br><b>{z_key}:%{{z:.2e}}</b>\"\n",
" \n",
" if fit_fn:\n",
" x_grid = np.logspace(np.log10(min(x)), np.log10(max(x)), 50)\n",
" y_grid = np.logspace(np.log10(min(y)), np.log10(max(y)), 50)\n",
" \n",
" x_grid, y_grid = np.meshgrid(x_grid, y_grid)\n",
" z_surface = fit_fn(x_grid, y_grid)\n",
" \n",
" \n",
" surface = go.Surface(\n",
" z=z_surface, \n",
" x=x_grid, \n",
" y=y_grid, \n",
" name='Surface', \n",
" opacity=0.6,\n",
" contours={\"z\": {\"show\": True}},\n",
" colorscale='Viridis'\n",
" )\n",
" else:\n",
" surface = None\n",
"\n",
" else:\n",
" z = None\n",
@@ -566,21 +588,26 @@
" color_variable = None\n",
" hovertemplate += \"<extra></extra>\"\n",
" \n",
" data = [\n",
" scatter(\n",
" **scatter_kwargs,\n",
" mode='markers',\n",
" marker = dict(\n",
" size=8,\n",
" color=color_variable,\n",
" colorscale='Viridis',\n",
" colorbar=dict(title=color_key),\n",
" opacity=0.8,\n",
" ),\n",
" hovertemplate=hovertemplate\n",
" ), \n",
" ]\n",
" \n",
" if fit_fn:\n",
" data.append(surface)\n",
" \n",
" fig = go.Figure(\n",
" data=[\n",
" scatter(\n",
" **scatter_kwargs,\n",
" mode='markers',\n",
" marker = dict(\n",
" size=8,\n",
" color=color_variable,\n",
" colorscale='Viridis',\n",
" colorbar=dict(title=color_key),\n",
" opacity=0.8,\n",
" ),\n",
" hovertemplate=hovertemplate\n",
" )\n",
" ]\n",
" data=data\n",
" )\n",
" \n",
" \n",
@@ -749,12 +776,11 @@
" arg3 = e + np.zeros(arg2.shape)\n",
" \n",
" huber_input = logsumexp(np.stack((arg1, arg2, arg3), axis=0), axis=0) - np.log(L)\n",
" \n",
" return np.sum(huber(huber_input))\n",
" return np.mean(huber(huber_input))\n",
"\n",
"def compute_loss(params, N, D, L):\n",
"def compute_loss(params, N, D):\n",
" a, b, e, alpha, beta = params\n",
" return e + np.exp(a)/N**alpha + np.exp(b)/D**beta\n",
" return float(np.exp(e)) + float(np.exp(a))/N**alpha + float(np.exp(b))/D**beta\n",
"\n",
"loss_grad = grad(parametric_loss, argnum=0)"
]
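The cell above fits the Chinchilla-style parametric form L(N, D) = exp(e) + exp(a)/N^alpha + exp(b)/D^beta by applying a Huber penalty to log-space residuals (computed stably with logsumexp) and averaging over runs. A minimal self-contained sketch of that objective follows; the notebook's huber helper and its delta are not visible in this hunk, so the delta below (1e-3, the value used in Hoffmann et al. 2022) is an assumption, and a finite-difference gradient stands in for the notebook's autograd-based loss_grad.

import numpy as np
from scipy.special import logsumexp
from scipy.optimize import minimize

def huber(r, delta=1e-3):
    # Quadratic near zero, linear in the tails; delta is assumed here.
    return np.where(np.abs(r) <= delta, 0.5 * r**2, delta * (np.abs(r) - 0.5 * delta))

def objective(params, N, D, L):
    a, b, e, alpha, beta = params
    # log(exp(e) + exp(a)/N^alpha + exp(b)/D^beta), evaluated in log space.
    pred_log = logsumexp(
        np.stack([
            a - alpha * np.log(N),
            b - beta * np.log(D),
            e + np.zeros_like(np.log(N)),
        ], axis=0),
        axis=0,
    )
    # Huber penalty on the log-space residuals, averaged over runs.
    return np.mean(huber(pred_log - np.log(L)))

# Example call, mirroring the grid-search cell further down:
# result = minimize(objective, x0=np.array([10., 20., 1., 0., 1.]),
#                   args=(N, D, L), method='L-BFGS-B')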
@@ -768,7 +794,7 @@
{
"data": {
"text/plain": [
"0.0007581840318605708"
"0.0001509896728510271"
]
},
"execution_count": 8,
@@ -778,7 +804,7 @@
],
"source": [
"x0 = np.array([6, 6, 0.8, 0.3, 0.3])\n",
"parametric_loss(x0, np.array([6e10]), np.array([1e12]), np.array([1.2]))"
"parametric_loss(x0, runs_single_epoch['N'], runs_single_epoch['D'], runs_single_epoch['L'])"
]
},
{
@@ -806,71 +832,32 @@
{
"cell_type": "code",
"execution_count": 10,
"id": "7f8b6357",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" fun: 0.0009536697844852684\n",
" hess_inv: <5x5 LbfgsInvHessProduct with dtype=float64>\n",
" jac: array([ 1.29267203e-05, -1.10545529e-06, 6.83015639e-06, -1.56147861e-05,\n",
" -4.19487472e-05])\n",
" message: 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'\n",
" nfev: 119\n",
" nit: 73\n",
" njev: 119\n",
" status: 0\n",
" success: True\n",
" x: array([6.65288417, 8.6388557 , 0.66973247, 0.37006023, 0.38972565])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"result = minimize(\n",
" parametric_loss, \n",
" x0,\n",
" args=(runs_single_epoch['N'], runs_single_epoch['D'], runs_single_epoch['L']),\n",
" method='L-BFGS-B',\n",
" jac=loss_grad\n",
")\n",
"\n",
"result"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "5a5d7753",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|███████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00, 4.26it/s]"
"100%|█████████████████████████████████████████████████████████████████| 4500/4500 [10:16<00:00, 7.30it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"x0: [ 0. 20. 1. 0.5 0. ]\n",
" fun: 0.0009496429007248201\n",
"x0: [10. 20. 1. 0. 1.]\n",
" fun: 2.3162021954760153e-05\n",
" hess_inv: <5x5 LbfgsInvHessProduct with dtype=float64>\n",
" jac: array([ 1.71088136e-07, 7.68931261e-07, 5.78982831e-07, -3.31521127e-06,\n",
" -1.46216045e-05])\n",
" jac: array([ 1.21138712e-11, -5.19463000e-11, 5.91681975e-11, -2.29992866e-10,\n",
" 9.12117586e-10])\n",
" message: 'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'\n",
" nfev: 213\n",
" nit: 133\n",
" njev: 213\n",
" nfev: 204\n",
" nit: 102\n",
" njev: 204\n",
" status: 0\n",
" success: True\n",
" x: array([6.11568046, 8.59511225, 0.64766353, 0.33900831, 0.38730405])\n"
" x: array([6.11564517, 8.59509349, 0.64766146, 0.33900636, 0.38730309])\n"
]
},
{
@@ -895,7 +882,7 @@
")\n",
"\n",
"# hardcode best init, comment out to do grid search\n",
"init_grid = dict(a=[0], b=[20], e=[1], alpha=[0.5], beta=[0])\n",
"init_grid = dict(a=[10], b=[20], e=[1], alpha=[0], beta=[1])\n",
"\n",
"num_inits = reduce(operator.mul, [len(v) for k,v in init_grid.items()], 1)\n",
"\n",
@@ -906,7 +893,8 @@
" x0,\n",
" args=(runs_single_epoch['N'], runs_single_epoch['D'], runs_single_epoch['L']),\n",
" method='L-BFGS-B',\n",
" jac=loss_grad\n",
" jac=loss_grad,\n",
" options={'gtol': 1e-12, 'ftol': 1e-12}\n",
" )\n",
"\n",
" if result.fun < best_loss:\n",
@@ -917,6 +905,55 @@
"print(\"x0: \", best_init)\n",
"print(best_result)"
]
},
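For readability, the optimizer's parameter vector can be unpacked into the constants of the parametric form; a small sketch, assuming best_result from the grid-search cell above.

import numpy as np

# x = [a, b, e, alpha, beta]; the scale parameters are stored in log space.
a, b, e, alpha, beta = best_result.x
A, B, E = np.exp(a), np.exp(b), np.exp(e)
print(f"L(N, D) ~ {E:.3f} + {A:.1f}/N^{alpha:.3f} + {B:.1f}/D^{beta:.3f}")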
{
"cell_type": "code",
"execution_count": 11,
"id": "b8158223",
"metadata": {},
"outputs": [],
"source": [
"single_epoch_fit = lambda N, D: compute_loss(best_result.x, N, D)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cf7a0b36",
"metadata": {},
"outputs": [],
"source": [
"scaling_scatter(\n",
" runs_single_epoch, \n",
" x_key='N', \n",
" y_key='D', \n",
" z_key='L', \n",
" fit_fn=single_epoch_fit,\n",
" savepath='single-epoch-runs-fitted.html'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "b44df3e9",
"metadata": {},
"outputs": [],
"source": [
"runs_single_epoch = {\n",
" **runs_single_epoch, \n",
" 'residuals': runs_single_epoch['L'] - single_epoch_fit(runs_single_epoch['N'], runs_single_epoch['D'])\n",
"}\n",
"\n",
"scaling_scatter(\n",
" runs_single_epoch,\n",
" x_key='N',\n",
" y_key='D',\n",
" z_key='residuals',\n",
" z_type='linear',\n",
" savepath='single-epoch-runs-residuals.html'\n",
")"
]
}
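Beyond the interactive residual plot, a quick numeric summary of the residuals gives a sanity check on the fit; a sketch, assuming the runs_single_epoch dictionary with the 'residuals' key added in the cell above.

import numpy as np

res = np.asarray(runs_single_epoch['residuals'])
print("mean abs residual:", np.abs(res).mean())
print("max abs residual: ", np.abs(res).max())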
],
"metadata": {
14 changes: 14 additions & 0 deletions analysis/hf_reanalysis/single-epoch-runs-fitted.html

Large diffs are not rendered by default.

14 changes: 14 additions & 0 deletions analysis/hf_reanalysis/single-epoch-runs-residuals.html

Large diffs are not rendered by default.
