-
Notifications
You must be signed in to change notification settings - Fork 27
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
418 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,261 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as plt\n", | ||
"import numpy as np\n", | ||
"from sklearn import datasets, linear_model\n", | ||
"from sklearn.metrics import mean_squared_error, r2_score\n", | ||
"import torch.optim as optim" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 51, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Helper functions\n", | ||
"def predict(X, W):\n", | ||
" YPred = X.dot(W)\n", | ||
" return YPred" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 97, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"num_iterations = 5\n", | ||
"N = 10000\n", | ||
"D = 10\n", | ||
"\n", | ||
"X = np.random.randn(N, D)\n", | ||
"Y = np.random.randn(N, 1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 33, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load the diabetes dataset\n", | ||
"diabetes = datasets.load_diabetes()\n", | ||
"\n", | ||
"# Use only one feature\n", | ||
"diabetes_X = diabetes.data[:, np.newaxis, 2]\n", | ||
"\n", | ||
"# Split the data into training/testing sets\n", | ||
"diabetes_X_train = diabetes_X[:-20]\n", | ||
"diabetes_X_test = diabetes_X[-20:]\n", | ||
"\n", | ||
"# Split the targets into training/testing sets\n", | ||
"diabetes_y_train = diabetes.target[:-20]\n", | ||
"diabetes_y_test = diabetes.target[-20:]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 34, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"#print(diabetes)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 29, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"diabetes_X_train = X[:-20]\n", | ||
"diabetes_X_test = X[-20:]\n", | ||
"diabetes_y_train = Y[:-20]\n", | ||
"diabetes_y_test = Y[-20:]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 40, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"('Coefficients: \\n', array([938.23786125]))\n", | ||
"Mean squared error: 2548.07\n", | ||
"Variance score: 0.47\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAADuCAYAAAAOR30qAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEGRJREFUeJzt3W+MXFX9x/HPnf7RHaC1UFBjmXuRWKlFEFir8RcV/+H/JwY1cawx/pkHBEIkoUYm0WgyxOojIfgzQ41R9z5RiSZiTEqtxJhodCskFmEJkblbNJi2gm0zXfpnrw+Os9t2d+be2+6de+6571fSB52ebb6bhU++/Z5zz/XiOBYAoHi1ogsAABgEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgQwAliCQAcASq7Ms3rhxYxwEQU6lAICb9u3bdyiO48uT1mUK5CAIND09ff5VAUAFeZ4XpVnHyAIALEEgA4AlCGQAsASBDACWIJABwBIEMgCnhWGoIAhUq9UUBIHCMCy6pKEyHXsDgDIJw1CtVkv9fl+SFEWRWq2WJKnZbBZZ2rLokAE4q91uL4TxQL/fV7vdLqii0QhkAM6anZ3N9HnRCGQAzmo0Gpk+LxqBDMBZnU5H9Xr9rM/q9bo6nU5BFY1GIANwVrPZVLfble/78jxPvu+r2+1auaEnSV4cx6kXT05OxlwuBADZeJ63L47jyaR1dMgAYAkCGQAsQSADgCUIZACwBIEMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsASBDACWIJABwBIEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgQwAliCQAcASBDIAWIJABgBLEMgAYAkCGQAsQSADgCUIZACwBIEMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsASBDACWIJABwBIEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgQwAliCQAcASBDIAWIJABgBLEMgAYAkCGYCznn9euuEGyfOka66RpqeLrmg0AhmAlcIwVBAEqtVqCoJAYRim/tpf/tKE8KtfLT3+uPlsZkb60Y9yKnaFrC66AAA4VxiGarVa6vf7kqQoitRqtSRJzWZz2a85cUK67Tbp+98f/vcO+VJreHEcp148OTkZT9ve8wMovSAIFEXRks9931ev1zvrs6eekt72NumFF4b/fVdfLe3dKzUaK1xoSp7n7YvjeDJpHSMLANaZnZ1N/Px73zNjiS1bhofx3XdLp05JzzxTXBhnwcgCgHUajcayHfKmTVt1yy3SI4+M/vpHH5Xe+c58assTHTIA63Q6HdXr9TM++T9JsQ4c+OvQMH73u02nHMflDGOJDhmAhZrNpubnPX3hC1t14sT1I9fef790++1jKixnBDIAqzz5pPSGN0jSp4auWbdO+sMfBuvcwcgCgBW+/nWzSTcqZD/7WWluTvrPf9wLY4kOGUCBjh2TNm6UXnpp9LpvflP68pfHU1OR6JABC13IU2pl8JvfmG74kktGh/HMjNmkq0IYSwQyYJ3BU2pRFCmO44Wn1MoeynEsfeITJojf+97h697xDun0abN+8+bx1WcDntQDLJPlKbUy+Mc/pE2bktf99KfSrbfmX08ReFIPKKk0T6mVwa5dphtOCuNDh0w37GoYZ0EgA5ZpDHnGd9jnNjl50lxz6XnSF784fN1tt5kQjmPpssvGV5/tCGTAMkufUpPq9bo6nU5BFSV77DETwmvXmo24Yf74RxPCDzwwvtrKhEAGLNNsNtXtduX7vjzPk+/76na7Q6+dLNLdd5sgvvHG4WsaDXN2OI6lt7xlfLWVEZt6ADJ58UVpw4bkdffdJ91xR/71lEHaTT0eDAGQysMPSx/9aPK6Z5+VgiD3cpzEyALAUHEsffCDZiwxKow//GFpft6sJ4zPHx0ygCV6Pemqq5LXPfywCWOsDDpkAAvuu890w0lh/OKLphsmjFcWgQxU3LFjJoQ9T7rzzuHrduxYPDu8fv346qsSAhmoqB//ePGCn1Eee8yE8M6d46mrypghAxWzZo158ecoW7eaIF6zZjw1waBDBirg2WcXxxKjwnjXLtMN799PGBeBQAYcdtddJoRf+9rR6/bvN0H8+c+Ppy4sj5EF4JhTp9J3t/PzJrBhBzpkwBGPPmrCNSmMv/OdxdMShLFd6JCBktu2Tfrzn5PXHTrEVZe2I5CBEnrhBenSS5PXXX+99Pjj+deDlcHIAiiR737XjBmSwnjPHjOSIIzLhQ4ZsFwcS7WUrdPJk9Jq/q8uLTpkwFJPPmm64aQwvuOOxU06wrjc+PEBlrnqKnPbWpJnnpGuvjr3cjBGBDJggePHpXNeozdUhpf8oGQYWQAFGmzSJYXxD36wOJaAu+iQgQKkfSDj8OF0x9vgBjrkc4RhqCAIVKvVFASBwjAsuiQ4otdbvOAnyaAbJoyrhUA+QxiGarVaiqJIcRwriiK1Wi1CGRfkk59M9xaOX/yCsUTVeXGGn/7k5GQ8PT2dYznFCoJAURQt+dz3ffXSbHsD/5Pl7PCpU9KqVfnWg2J5nrcvjuPJpHV0yGeYnZ3N9DncdCFjq927050d/sAHFrthwhgDbOqdodFoLNshNxqNAqpBEQZjq36/L0kLYytJajabQ79uYkKam0v++2dmpM2bV6RUOIgO+QydTkf1c84f1et1dTqdgirCuLXb7YUwHuj3+2q320vWHjmyuEmXFMaDbpgwxigE8hmazaa63a5835fnefJ9X91ud2RnBLekGVvde68J4aQ3L+/cySYdsiGQz9FsNtXr9TQ/P69er0cYV8yw8VSj0Vjohpdpls9y9KgJ4R07cigwBxz1tAeBDJxh6djqGkmxoqg38ute8YrFbvjii/OscGVx1NMuHHsDzhGGoT73uS06ceLGxLV790rvetcYisoJRz3HI+2xN05ZAP+z+HLQ5DGVKy8H5ainXRhZoPIeeCDdy0G3b3fv5aCjZuYYPzpkVFbaUJ2dla68Mt9aitLpdM46dy1x1LNIdMiolH/+M/sFP66GscRRT9sQyKiEj3zEhPBrXjN63Ve/Wr2zwxz1tAcjCzgt7Vii3zePPwNFokOGc37+8+xjCcIYNqBDhjPSdsO7d0vve1++tQDng0BGqfX70kUXpVtbpbkwyomRBUqp1TIdcVIY+371NulQXnTIKJW0Y4m//z35lUmAbeiQYb0nnsi+SUcYo4wIZFhrEMLXXjt63Ve+wlgCbiCQC8Q9tEsN7olI0w2/9JJZf++9+dcFjAOBXBDuoT3bt76V7uWg0mI3vHZt/nUB48R9yAXhHloj7Sbdnj3Se96Tby1AXrgP2XJVvof24EHpiivSrWUujCphZFGQKt5D+8Y3mo44KYxf+Uo26VBNBHJBlr67zd17aAebdPv3j1733HMmhJ9/fjx1AbYhkAvi+j20e/ZkPzucdDUm4Do29bCi0m7S3XOP5OA/BoBlsamHsVl8OWi6tatW5VsPUFaMLHDe7ror3ctBpcWxBGEMDEeHjMzSjiV+9zvp7W/PtxbAJQQyUun10l/Yw3E14PwwssBIN9xgOuKkMN62jbPDwIWiQ8ay0o4l/v1vacOGfGsBqoIOGQt+/evsZ4cJY2DlEMhYCOEPfShp5Xb5fqCpqWreSAfkjZFFRc3NSRMT6dZOTFyk48f7kqQoklqtliQ581QhYAs65Ir50pdMN5wUxhs2mJGE7wcLYTzQ7/fVbrdzrBKoJjrkiki7STczI23evPj7Kl8TCowbHbLDnn46+ybdmWEsVfOaUKAoBLKDLrvMhPDrXz963Z13Jp8drtI1oUDRGFk4Io7TvY9Oko4fl17+8nRrBxt37XZbs7OzajQa6nQ6bOgBOeD6zZKbmpK2b0+3lqfogGJw/abj0m7S/epXac4XA7ABM+SSCMNQjca1mTfpCGOgPAjkEnjrWyN9+tNNHTgw+qV0113HBT9AmTGysNhiJ+yPXHfggLRpU+7lAMgZHbJl9u1Lf3bY82qKY8IYcAWBbIlBCE8m7sPeI8mT5PFwBuAYRhYFmp9P/465iYl1On786MLveTgDcA8dcgF27zbdcJowHmzSPfjg/8v3fXmeJ9/31e12eTgDcAyBPEYve5kJ4ve/f/S63/9+6WmJZrOpXq+n+fl59Xo9wjhBGIYKgkC1Wk1BECgMucMZ9mNkkbMjR6T169Ot5bjaygjDUK1WS/3+4A7niDucUQp0yDnpdEw3nBTG3/42Z4dXWrvdXgjjAe5wRhnQIa+wtI80Hz0qXXxxvrVUFXc4o6zokFfA3/6W7uzwpZcudsOEcX64wxllRSBfgJtvNiG8devodXv3mhA+fHgsZa24sm2QcYczyoqRRUanTklr1qRbOz+ffoRhqzJukHGHM8qK+5BT+tnPpI9/PHndZz4j/fCH+dczLkEQKIqiJZ/7vq9erzf+goAS4j7kFZK2w3X1gh82yIDxYYa8jIMHs78c1MUwltggA8aJQD7Dgw+aEL7iitHrdu2qztlhNsiA8WFkofRjibk58/hzlbBBBoxPZTf1/vUv6VWvSl63ZYs5ZwwA5yvtpl7lRhZTU6YjTgrjmRkzkrAtjMt2JhhAepUYWZw+LW3bJv3lL8lrbZ4Ll/FMMID0nO6Qn3jCdMOrV48O46mpYjfp0na9XJoDuM3JDvlrX5O+8Y3RazZulGZnpYmJ8dQ0TJaulzPBgNuc6ZCPHZPWrjUd8agw3rnTdMIHDxYfxlK2rpczwYDbSh/IjzxiQviSS6STJ4eve/ppE8Q7doyvtjSydL2cCQbcVspAjmPp1ltNEN9yy/B1N99sNvTiWHrd68ZWXiZZut5ms6lut8u79QBHlSqQn3vOhHCtJj300PB1Dz1kQvi3vzVrbZa16+XdeoC7LI8ro9s1QXzllaPXHT5sgvhjHxtPXSuBrhfAgNVP6s3NJW+83X67dP/946kHAM6HE9dv/uQnw//sT3+S3vzm8dUCAHmzOpDf9CZp3TrpyBHz+yCQnnqqehf8AKgGqwP5uuvMwxsnTkiXX150NQCQL6sDWZLWry+6AgAYj1KcsgCAKiCQAcASlQ5k7hYGYBPrZ8h54W5hALapbIfM3cIAbFPZQOZuYQC2qWwgc7dweTH7h6sqG8iu3C1ctXAazP6jKFIcxwuzf9e/b1REHMepf910002xS6ampmLf92PP82Lf9+OpqamiS8pkamoqrtfrsaSFX/V6feT3Ufbv2ff9s77fwS/f94suDRhK0nScImOtvu0NowVBoCiKlnzu+756vd6Sz889WSKZfxWU6brPWq2m5f6b9TxP8/PzBVQEJEt721tlRxYuyLox6cLJEmb/cBmBXGJZw8mFkyWuzP6B5RDIJZY1nFzoLnnDClxGIJdY1nBypbvkvYJwVSkCuWpHu7LIEk50l4DdrD9l4cLJAADV5swpCxdOBgBAGtYHsgsnAwAgDesD2YWTAQCQhvWB7MrJAABIYnUgh2G4MENetWqVJHEyoCI4WYMqsvaNIeeerjh9+vRCZ0wYu423uaCqrD32lvXiHLiDnz1cU/pjb5yuqC5+9qgqawOZ0xXVxc8eVWVtIHO6orr42aOqrA1k7l2oLn72qCprN/UAwBWl39QDgKohkAHAEgQyAFiCQAYASxDIAGCJTKcsPM87KGnpM60AgFH8OI4vT1qUKZABAPlhZAEAliCQAcASBDIAWIJABgBLEMgAYAkCGQAsQSADgCUIZACwBIEMAJb4L/4/ciktfwZ6AAAAAElFTkSuQmCC\n", | ||
"text/plain": [ | ||
"<Figure size 432x288 with 1 Axes>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"# Create linear regression object\n", | ||
"regr = linear_model.LinearRegression()\n", | ||
"\n", | ||
"# Train the model using the training sets\n", | ||
"regr.fit(diabetes_X_train, diabetes_y_train)\n", | ||
"\n", | ||
"# Make predictions using the testing set\n", | ||
"diabetes_y_pred = regr.predict(diabetes_X_test)\n", | ||
"\n", | ||
"# The coefficients\n", | ||
"print('Coefficients: \\n', regr.coef_)\n", | ||
"# The mean squared error\n", | ||
"print(\"Mean squared error: %.2f\"\n", | ||
" % mean_squared_error(diabetes_y_test, diabetes_y_pred))\n", | ||
"# Explained variance score: 1 is perfect prediction\n", | ||
"print('Variance score: %.2f' % r2_score(diabetes_y_test, diabetes_y_pred))\n", | ||
"\n", | ||
"# Plot outputs\n", | ||
"\n", | ||
"plt.scatter(diabetes_X_test, diabetes_y_test, color='black')\n", | ||
"plt.plot(diabetes_X_test, diabetes_y_pred, color='blue', linewidth=3)\n", | ||
"\n", | ||
"plt.xticks(())\n", | ||
"plt.yticks(())\n", | ||
"\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 47, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"2623942383.2329097\n", | ||
"('Objective Value:', 2663156030.210076)\n", | ||
"('Objective Value:', 2718233711.865551)\n", | ||
"('Objective Value:', 2755549908.0834365)\n", | ||
"('Objective Value:', 2777025493.9161234)\n", | ||
"('Objective Value:', 2788632001.1960063)\n", | ||
"('Weights: ', array([[936.25300989]]))\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"admm = optim.ADMM([diabetes_X_train, diabetes_y_train, 0.01], \"Lasso\") #, parallel = True)\n", | ||
"\n", | ||
"print(admm.getLoss())\n", | ||
"for i in range(0, 5):\n", | ||
" print('Objective Value:', admm.step())\n", | ||
"\n", | ||
"print('Weights: ',admm.getWeights())\n", | ||
"W = admm.getWeights()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 52, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Comparison with Scikit - LR\n", | ||
"admm_predict = predict(diabetes_X_test, W)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 53, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Mean squared error: 20747.88\n", | ||
"Variance score: -3.29\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"print(\"Mean squared error: %.2f\"\n", | ||
" % mean_squared_error(diabetes_y_test, admm_predict))\n", | ||
"# Explained variance score: 1 is perfect prediction\n", | ||
"print('Variance score: %.2f' % r2_score(diabetes_y_test, admm_predict))\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 54, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"data": { | ||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAADuCAYAAAAOR30qAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADdZJREFUeJzt3V+IVOUfx/HPWU1rpH9sglbsmYgUU8Fy7aKSjOjGiIwuUiYkyBa8KUzIiyXKaMCrIkqMjaLQSW+iC7OIoGjFm22tJJPIhJmF1PyJqNj4Z22e38Wwnp3VZs85O+ec55zzfsFeNM2Br45+/M73PM95HGOMAADJ60q6AABAE4EMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsMT0IG++7bbbTLFYjKgUAMim/fv3nzTGzJ7sfYECuVgsanh4OHxVAJBDjuPU/LyPkQUAWIJABgBLEMgAYAkCGQAsQSADgCUIZACZVqlUVCwW1dXVpWKxqEqlknRJ/ynQsjcASJNKpaK+vj7V63VJUq1WU19fnySpVColWdo10SEDyKz+/v4rYTymXq+rv78/oYraI5ABZNbIyEig15NGIAPIrJ6enkCvJ41ABpBZ5XJZhUKh5bVCoaByuZxQRe0RyAAyq1QqaWBgQK7rynEcua6rgYEBK2/oSZJjjPH95t7eXsPDhQAgGMdx9htjeid7Hx0yAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsASBDACWIJABwBIEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgQwAliCQAcASBDIAWIJABgBLEMgAYAkCGQAsQSADgCUIZACwBIEMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsASBDACWIJABwBIEMgArVSoVFYtFdXV1qVgsqlKpJF1S5KYnXQAATFSpVNTX16d6vS5JqtVq6uvrkySVSqUkS4sUHTIA6/T3918J4zH1el39/f0JVRQPAhmAdUZGRgK9nhUEMgDr9PT0BHo9KwhkANYpl8sqFAotrxUKBZXL5YQqigeBDMA6pVJJAwMDcl1XjuPIdV0NDAxk+oaeRCADVsrjkq+JSqWSqtWqGo2GqtVq5sNYYtkbYJ28LvkCHTJgnbwu+QKBDFgnr0u+QCAD1snrki8QyIB18rrkCwQyYJ28LvmC5BhjfL+5t7fXDA8PR1gOAGSP4zj7jTG9k72PDhkALEEgA4AlCGQAsASBDACWIJABwBIEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgTwBR+cASApHOI3D0TkAkkSHPA5H50DiWxKSQ4c8DkfngG9JSBId8jgcnQO+JSFJBPI4HJ2DPH5LYkRjDwJ5HI7OQd6+JY2NaGq1mowxV0Y0hHIyOMIJGGfiDFlqfkvK6j/MxWJRtVrtqtdd11W1Wo2/oIziCCcghLx9S8rjiMZmdMhAjtEhx4MOGcCkuJFtFwIZyLG8jWhsx8gCACLGyAIAUoZABgBLEMgAYAkCOUFsWQUwHk97SwhPFQMwER1yQniqGICJCOSEsGUVwEQEckLy9lQxAJMjkBPCllUAExHICWHLKoCJ2DoNABFj6zQApAyBDACWIJAxKXYUAvFgpx7aYkchEB86ZLTFjkIgPgQy2mJHIRAfAhltsaMQiA+BjLbYUQjEh0BGW+woBOJDIKdEkkvPSqWSqtWqGo2GqtUqYQxEhGVvKcDSMyAf6JBTgKVnQD4QyCnA0jMgHwjkFGDpWXBs90YaEcgpwNKzYMZm7rVaTcaYKzN3Qhm2I5BTgKVnwTBzR1rxgHpkTldXl67159pxHDUajQQqQt7xgHp0TNrmsczckVYEMtpK4zyWmTvSikBGW2mcxzJzR1oxQ0ZbzGOBqWOGjI5gHgvEh0BOmbhvsDGPBeJDIKdIEjfYmMcC8WGGbIFKpaL+/n6NjIyop6dH5XL5moFXLBZVq9Wuet11XVWr1RgqBRCG3xkyj99MWJBHa/KQISDbGFkkLMiyMm6wAdlGICcsSNfLDTYg2wjkhAXpernBBmQbgZywoF0v59sBwZw/L33wgbRvnxRgDUMiCOSE0fUCnXfggLRiheQ4UqEgrV8vPfyw9P77SVfWHqssLFAqlQhgYApGR6Vt26SXX27/vqNH46knLAIZQCodPixt3Cjt3u3v/XPmSK+9Fm1NU8XIAkAqNBrSxx83RxCOI82bN3kYd3dLO3c2rz12rHmtzXIdyGl78DqQN7/9Js2c2QzgadOkF15o3qRr59lnpWq1eQPv5Elp9erm9WmQ20BO44PXgawzRnrjjWaAOo60aJF06VL7a2bOlD78ULp8uXn9rl2S68ZSbsflNpDT+OB1NPHNJluqVen225sB3NUlbd48+TVPPCH9/nszgC9ckNata3bQaZfbQM7KcyHyFk58s8mG997zuuC77mrOdyfz9tvSxYvNEP7yS2n+/OjrjJ0xxvfP0qVLTVa4rmskXfXjum7Spfm2Y8cOUygUWuovFApmx44dba9xXdc4jmNc1237Xhtl4XPLo+PHjVm40JhmnPr/+eqrpCvvDEnDxkfG5jaQw4SZbYKGUxZ+zY7jXPPX7DhO0qVhgu3bgwfwY48Zc+ZM0pV3HoHsQ9q7xaDhlIXuMgu/hqw6fdqYRx4JHsI7dyZdefT8BnIqZshRzUnT/lyIoI/jzMLcnCfe2WXPHm8WfMst0g8/TH7NkiXSiRNeJK9eHX2dqeEntcd+kuiQs/A1OypBf2+y0l2m/ZtNmtXrxjz9dPAueOvWpCtPlrIysshKiEQlSDjxjxvCGBwMHsA9PcbUaklXbo/MBDI3cTqL7hKTuXTJmHXrgofwW28Z02gkXb2d/Aay9YeccrAnEL39+6XeSY/gbDVrlvTjj9KCBdHUlCV+Dzm1+qZepVLRuXPnrnqdmzjZl7cNL3H791/pjju8G3J+w3jDBm+L8rlzhHHH+Wmjx37iHFlca94pyXR3d/M1O+OYdUdj797gYwjJmKGhpCtPP6V9ZMGoIr/47DvDGGn58ubRRUGsXdt8WM+MGdHUlUd+RxbWPqA+C2tmEQ6ffXgHD0qLFwe/7t13pZde6nw9CMbaGXLQTQ/IDj77YEolbxYcJIxPnvQGE4SxHawNZHZk5ReffXsjI14AO4702Wf+rtu0qXU63N0dbZ0IztpA5jTm/OKzv9qmTV4AB3n4eq3mBfCWLdHVh86w9qYekGdHjzaXpQX13HPS9u2drwdTk4l1yECevPii1wUHCeODB70umDBON2tXWQBZd/asdPPNwa9bvrz5VLW0HNwJ/+iQgRiVy14XHCSM9+71uuDBQcI4q+iQgQhdutQ8FTmM0VFpOn9Dc4UOGeiwTz/1uuAgYfzOO63L0gjj/OEjB6ao0Qh/BP0//0gTllwjx+iQgRC++cbrgoOE8SuvtHbBhDHGo0MGfJo5szkTDurECWn27M7Xg+yhQwb+w08/tW5R9hvGq1a1dsGEMfyiQwbGCbuc7M8/pbvv7mwtyB86ZOTaoUOtXbBfCxa0dsGEMTqBQEbu3HOPF8ALF/q/bmjIC+BDh6KrD/nFyAKZd/y4NHduuGsDPHsLmDI6ZGTSggVeFxwkjD/5pHUUAcSJDhmZEPZBPVLzBOYuWhNYgD+GSK01a8I9qGfz5tYumDCGLeiQkRqXL0vXXRfu2npduuGGztYDdBq9Aaz25pteFxwkjJcta+2CCWOkAR0yrDKVEcKxY9KcOZ2tB4gTHTISt3Wr1wUHCeNp01q7YMIYaUeHjESE3aL866/SokWdrQWwBR0yYvH11+G2KEutXTBhjCwjkBGZ8QG8cqX/6z7/nM0ZyCdGFuiYgwelxYvDXUvwAnTImKLxXXCQMN6yhS4YmIgOGYH8/Xf41QyNBsfXA+3QIWNSjz/udcFBwrhUau2CCWOgPTpkXOXChfA7286fl66/vrP1AHlBhwxJ0oYNXhccJIzvvbe1CyaMgfDokHNqKluUOUUZiAYdco5s2xZui/LcuZyiDMSBDjnjOEUZSA865Izp1BZlwhiIHx1yBoTtgvftkx58sLO1AAiPQE6hAwekJUvCXcuuOMBejCxSYvwYIkgY79rFFmUgLeiQLfXXX9Kdd4a7luAF0okO2SL33ed1wUHCuFymCwaygA45QVPZojw6Kk3n0wMyhQ45ZgMD4bYor13b2gUTxkD28Nc6Yo1G8zDOMM6elW68sbP1ALAXHXIEdu/2uuAgYfzUU61dMGEM5AsdcoeE3Zxx6pR0662drQVAOtEhhzQ0FG6L8po1rV0wYQxgDB1yAPPnS3/8Efy6alVy3Y6XAyBj6JDbOHy4tQv2G8bLlrV2wYQxAD8I5Alef90L4Hnz/F/3yy9eAA8NRVcfgOzK/cji9Olwc9xZs6Rz5zpfD4D8ymWHPDjodcFBwvjbb70umDAG0Gm56JBHR6X166WPPgp+baPB8fUA4pHZDvnnn70ueMYM/2H83XetN+QIYwBxyUwgNxrSq696IXz//f6u27ix2UGPBfCjj0ZbJwD8l1SPLA4fbi4xO3Mm2HXDw9LSpdHUBABhpapDNkbasqV1WZqfMH7+eeniRa8LJowB2Mj6DvnUKemBB6QjR4Jd9/330ooVkZQEAJGwukM+ckTq7vYXxk8+2VyKNtYFE8YA0sbqDnlwsP3//+ILadWqeGoBgKhZ3SE/84y0cqX33w891BxhjHXBhDGALLG6Q77pJmnPnqSrAIB4WN0hA0CeEMgAYAkCGQAsQSADgCUIZACwBIEMAJYgkAHAEo4xxv+bHed/kmrRlQMAmeQaY2ZP9qZAgQwAiA4jCwCwBIEMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsMT/AZ6ZQ8zfgirvAAAAAElFTkSuQmCC\n", | ||
"text/plain": [ | ||
"<Figure size 432x288 with 1 Axes>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"plt.scatter(diabetes_X_test, diabetes_y_test, color='black')\n", | ||
"plt.plot(diabetes_X_test, admm_predict, color='blue', linewidth=3)\n", | ||
"\n", | ||
"plt.xticks(())\n", | ||
"plt.yticks(())\n", | ||
"\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import torch.optim as optim" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"num_iterations = 5\n", | ||
"N = 1000\n", | ||
"D = 20\n", | ||
"\n", | ||
"A = np.random.randn(N, D)\n", | ||
"b = np.random.randn(N, 1)\n", | ||
"Y = np.random.rand(N, 1)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"13218.725722259838\n", | ||
"('O Val:', 503.7030770044544)\n", | ||
"('O Val:', 503.70307384103387)\n", | ||
"('O Val:', 503.7030786237145)\n", | ||
"('O Val:', 503.7030712532673)\n", | ||
"('O Val:', 503.7030778652381)\n", | ||
"('Weights: ', array([[-0.02279339],\n", | ||
" [-0.03123076],\n", | ||
" [ 0.03235469],\n", | ||
" [-0.00769174],\n", | ||
" [-0.03006917],\n", | ||
" [-0.01935186],\n", | ||
" [ 0.0200083 ],\n", | ||
" [-0.0228672 ],\n", | ||
" [ 0.03491097],\n", | ||
" [-0.0216022 ],\n", | ||
" [ 0.0118515 ],\n", | ||
" [-0.03075605],\n", | ||
" [-0.019314 ],\n", | ||
" [ 0.05811015],\n", | ||
" [ 0.02316836],\n", | ||
" [ 0.02389774],\n", | ||
" [-0.06124286],\n", | ||
" [-0.02646492],\n", | ||
" [-0.01167727],\n", | ||
" [-0.03814794]]))\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"admm = optim.ADMM([A, b, 0.01], \"Lasso\") #, parallel = True)\n", | ||
"\n", | ||
"print(admm.getLoss())\n", | ||
"for i in range(0, num_iterations):\n", | ||
" print('O Val:', admm.step())\n", | ||
"\n", | ||
"print('Weights: ',admm.getWeights())\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 2 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython2", | ||
"version": "2.7.14" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.