Skip to content

Commit

Permalink
Scikit- LR vs ADMM Lasso
Browse files Browse the repository at this point in the history
  • Loading branch information
bhushan23 committed May 1, 2018
1 parent c42f398 commit e0c30b0
Show file tree
Hide file tree
Showing 5 changed files with 418 additions and 17 deletions.
261 changes: 261 additions & 0 deletions Diabetes_Test.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,261 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from sklearn import datasets, linear_model\n",
"from sklearn.metrics import mean_squared_error, r2_score\n",
"import torch.optim as optim"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [],
"source": [
"# Helper functions\n",
"def predict(X, W):\n",
" YPred = X.dot(W)\n",
" return YPred"
]
},
{
"cell_type": "code",
"execution_count": 97,
"metadata": {},
"outputs": [],
"source": [
"num_iterations = 5\n",
"N = 10000\n",
"D = 10\n",
"\n",
"X = np.random.randn(N, D)\n",
"Y = np.random.randn(N, 1)"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"# Load the diabetes dataset\n",
"diabetes = datasets.load_diabetes()\n",
"\n",
"# Use only one feature\n",
"diabetes_X = diabetes.data[:, np.newaxis, 2]\n",
"\n",
"# Split the data into training/testing sets\n",
"diabetes_X_train = diabetes_X[:-20]\n",
"diabetes_X_test = diabetes_X[-20:]\n",
"\n",
"# Split the targets into training/testing sets\n",
"diabetes_y_train = diabetes.target[:-20]\n",
"diabetes_y_test = diabetes.target[-20:]"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [],
"source": [
"#print(diabetes)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"diabetes_X_train = X[:-20]\n",
"diabetes_X_test = X[-20:]\n",
"diabetes_y_train = Y[:-20]\n",
"diabetes_y_test = Y[-20:]"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"('Coefficients: \\n', array([938.23786125]))\n",
"Mean squared error: 2548.07\n",
"Variance score: 0.47\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAADuCAYAAAAOR30qAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEGRJREFUeJzt3W+MXFX9x/HPnf7RHaC1UFBjmXuRWKlFEFir8RcV/+H/JwY1cawx/pkHBEIkoUYm0WgyxOojIfgzQ41R9z5RiSZiTEqtxJhodCskFmEJkblbNJi2gm0zXfpnrw+Os9t2d+be2+6de+6571fSB52ebb6bhU++/Z5zz/XiOBYAoHi1ogsAABgEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgQwAliCQAcASq7Ms3rhxYxwEQU6lAICb9u3bdyiO48uT1mUK5CAIND09ff5VAUAFeZ4XpVnHyAIALEEgA4AlCGQAsASBDACWIJABwBIEMgCnhWGoIAhUq9UUBIHCMCy6pKEyHXsDgDIJw1CtVkv9fl+SFEWRWq2WJKnZbBZZ2rLokAE4q91uL4TxQL/fV7vdLqii0QhkAM6anZ3N9HnRCGQAzmo0Gpk+LxqBDMBZnU5H9Xr9rM/q9bo6nU5BFY1GIANwVrPZVLfble/78jxPvu+r2+1auaEnSV4cx6kXT05OxlwuBADZeJ63L47jyaR1dMgAYAkCGQAsQSADgCUIZACwBIEMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsASBDACWIJABwBIEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgQwAliCQAcASBDIAWIJABgBLEMgAYAkCGQAsQSADgCUIZACwBIEMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsASBDACWIJABwBIEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgQwAliCQAcASBDIAWIJABgBLEMgAYAkCGYCznn9euuEGyfOka66RpqeLrmg0AhmAlcIwVBAEqtVqCoJAYRim/tpf/tKE8KtfLT3+uPlsZkb60Y9yKnaFrC66AAA4VxiGarVa6vf7kqQoitRqtSRJzWZz2a85cUK67Tbp+98f/vcO+VJreHEcp148OTkZT9ve8wMovSAIFEXRks9931ev1zvrs6eekt72NumFF4b/fVdfLe3dKzUaK1xoSp7n7YvjeDJpHSMLANaZnZ1N/Px73zNjiS1bhofx3XdLp05JzzxTXBhnwcgCgHUajcayHfKmTVt1yy3SI4+M/vpHH5Xe+c58assTHTIA63Q6HdXr9TM++T9JsQ4c+OvQMH73u02nHMflDGOJDhmAhZrNpubnPX3hC1t14sT1I9fef790++1jKixnBDIAqzz5pPSGN0jSp4auWbdO+sMfBuvcwcgCgBW+/nWzSTcqZD/7WWluTvrPf9wLY4kOGUCBjh2TNm6UXnpp9LpvflP68pfHU1OR6JABC13IU2pl8JvfmG74kktGh/HMjNmkq0IYSwQyYJ3BU2pRFCmO44Wn1MoeynEsfeITJojf+97h697xDun0abN+8+bx1WcDntQDLJPlKbUy+Mc/pE2bktf99KfSrbfmX08ReFIPKKk0T6mVwa5dphtOCuNDh0w37GoYZ0EgA5ZpDHnGd9jnNjl50lxz6XnSF784fN1tt5kQjmPpssvGV5/tCGTAMkufUpPq9bo6nU5BFSV77DETwmvXmo24Yf74RxPCDzwwvtrKhEAGLNNsNtXtduX7vjzPk+/76na7Q6+dLNLdd5sgvvHG4WsaDXN2OI6lt7xlfLWVEZt6ADJ58UVpw4bkdffdJ91xR/71lEHaTT0eDAGQysMPSx/9aPK6Z5+VgiD3cpzEyALAUHEsffCDZiwxKow//GFpft6sJ4zPHx0ygCV6Pemqq5LXPfywCWOsDDpkAAvuu890w0lh/OKLphsmjFcWgQxU3LFjJoQ9T7rzzuHrduxYPDu8fv346qsSAhmoqB//ePGCn1Eee8yE8M6d46mrypghAxWzZo158ecoW7eaIF6zZjw1waBDBirg2WcXxxKjwnjXLtMN799PGBeBQAYcdtddJoRf+9rR6/bvN0H8+c+Ppy4sj5EF4JhTp9J3t/PzJrBhBzpkwBGPPmrCNSmMv/OdxdMShLFd6JCBktu2Tfrzn5PXHTrEVZe2I5CBEnrhBenSS5PXXX+99Pjj+deDlcHIAiiR737XjBmSwnjPHjOSIIzLhQ4ZsFwcS7WUrdPJk9Jq/q8uLTpkwFJPPmm64aQwvuOOxU06wrjc+PEBlrnqKnPbWpJnnpGuvjr3cjBGBDJggePHpXNeozdUhpf8oGQYWQAFGmzSJYXxD36wOJaAu+iQgQKkfSDj8OF0x9vgBjrkc4RhqCAIVKvVFASBwjAsuiQ4otdbvOAnyaAbJoyrhUA+QxiGarVaiqJIcRwriiK1Wi1CGRfkk59M9xaOX/yCsUTVeXGGn/7k5GQ8PT2dYznFCoJAURQt+dz3ffXSbHsD/5Pl7PCpU9KqVfnWg2J5nrcvjuPJpHV0yGeYnZ3N9DncdCFjq927050d/sAHFrthwhgDbOqdodFoLNshNxqNAqpBEQZjq36/L0kLYytJajabQ79uYkKam0v++2dmpM2bV6RUOIgO+QydTkf1c84f1et1dTqdgirCuLXb7YUwHuj3+2q320vWHjmyuEmXFMaDbpgwxigE8hmazaa63a5835fnefJ9X91ud2RnBLekGVvde68J4aQ3L+/cySYdsiGQz9FsNtXr9TQ/P69er0cYV8yw8VSj0Vjohpdpls9y9KgJ4R07cigwBxz1tAeBDJxh6djqGkmxoqg38ute8YrFbvjii/OscGVx1NMuHHsDzhGGoT73uS06ceLGxLV790rvetcYisoJRz3HI+2xN05ZAP+z+HLQ5DGVKy8H5ainXRhZoPIeeCDdy0G3b3fv5aCjZuYYPzpkVFbaUJ2dla68Mt9aitLpdM46dy1x1LNIdMiolH/+M/sFP66GscRRT9sQyKiEj3zEhPBrXjN63Ve/Wr2zwxz1tAcjCzgt7Vii3zePPwNFokOGc37+8+xjCcIYNqBDhjPSdsO7d0vve1++tQDng0BGqfX70kUXpVtbpbkwyomRBUqp1TIdcVIY+371NulQXnTIKJW0Y4m//z35lUmAbeiQYb0nnsi+SUcYo4wIZFhrEMLXXjt63Ve+wlgCbiCQC8Q9tEsN7olI0w2/9JJZf++9+dcFjAOBXBDuoT3bt76V7uWg0mI3vHZt/nUB48R9yAXhHloj7Sbdnj3Se96Tby1AXrgP2XJVvof24EHpiivSrWUujCphZFGQKt5D+8Y3mo44KYxf+Uo26VBNBHJBlr67zd17aAebdPv3j1733HMmhJ9/fjx1AbYhkAvi+j20e/ZkPzucdDUm4Do29bCi0m7S3XOP5OA/BoBlsamHsVl8OWi6tatW5VsPUFaMLHDe7ror3ctBpcWxBGEMDEeHjMzSjiV+9zvp7W/PtxbAJQQyUun10l/Yw3E14PwwssBIN9xgOuKkMN62jbPDwIWiQ8ay0o4l/v1vacOGfGsBqoIOGQt+/evsZ4cJY2DlEMhYCOEPfShp5Xb5fqCpqWreSAfkjZFFRc3NSRMT6dZOTFyk48f7kqQoklqtliQ581QhYAs65Ir50pdMN5wUxhs2mJGE7wcLYTzQ7/fVbrdzrBKoJjrkiki7STczI23evPj7Kl8TCowbHbLDnn46+ybdmWEsVfOaUKAoBLKDLrvMhPDrXz963Z13Jp8drtI1oUDRGFk4Io7TvY9Oko4fl17+8nRrBxt37XZbs7OzajQa6nQ6bOgBOeD6zZKbmpK2b0+3lqfogGJw/abj0m7S/epXac4XA7ABM+SSCMNQjca1mTfpCGOgPAjkEnjrWyN9+tNNHTgw+qV0113HBT9AmTGysNhiJ+yPXHfggLRpU+7lAMgZHbJl9u1Lf3bY82qKY8IYcAWBbIlBCE8m7sPeI8mT5PFwBuAYRhYFmp9P/465iYl1On786MLveTgDcA8dcgF27zbdcJowHmzSPfjg/8v3fXmeJ9/31e12eTgDcAyBPEYve5kJ4ve/f/S63/9+6WmJZrOpXq+n+fl59Xo9wjhBGIYKgkC1Wk1BECgMucMZ9mNkkbMjR6T169Ot5bjaygjDUK1WS/3+4A7niDucUQp0yDnpdEw3nBTG3/42Z4dXWrvdXgjjAe5wRhnQIa+wtI80Hz0qXXxxvrVUFXc4o6zokFfA3/6W7uzwpZcudsOEcX64wxllRSBfgJtvNiG8devodXv3mhA+fHgsZa24sm2QcYczyoqRRUanTklr1qRbOz+ffoRhqzJukHGHM8qK+5BT+tnPpI9/PHndZz4j/fCH+dczLkEQKIqiJZ/7vq9erzf+goAS4j7kFZK2w3X1gh82yIDxYYa8jIMHs78c1MUwltggA8aJQD7Dgw+aEL7iitHrdu2qztlhNsiA8WFkofRjibk58/hzlbBBBoxPZTf1/vUv6VWvSl63ZYs5ZwwA5yvtpl7lRhZTU6YjTgrjmRkzkrAtjMt2JhhAepUYWZw+LW3bJv3lL8lrbZ4Ll/FMMID0nO6Qn3jCdMOrV48O46mpYjfp0na9XJoDuM3JDvlrX5O+8Y3RazZulGZnpYmJ8dQ0TJaulzPBgNuc6ZCPHZPWrjUd8agw3rnTdMIHDxYfxlK2rpczwYDbSh/IjzxiQviSS6STJ4eve/ppE8Q7doyvtjSydL2cCQbcVspAjmPp1ltNEN9yy/B1N99sNvTiWHrd68ZWXiZZut5ms6lut8u79QBHlSqQn3vOhHCtJj300PB1Dz1kQvi3vzVrbZa16+XdeoC7LI8ro9s1QXzllaPXHT5sgvhjHxtPXSuBrhfAgNVP6s3NJW+83X67dP/946kHAM6HE9dv/uQnw//sT3+S3vzm8dUCAHmzOpDf9CZp3TrpyBHz+yCQnnqqehf8AKgGqwP5uuvMwxsnTkiXX150NQCQL6sDWZLWry+6AgAYj1KcsgCAKiCQAcASlQ5k7hYGYBPrZ8h54W5hALapbIfM3cIAbFPZQOZuYQC2qWwgc7dweTH7h6sqG8iu3C1ctXAazP6jKFIcxwuzf9e/b1REHMepf910002xS6ampmLf92PP82Lf9+OpqamiS8pkamoqrtfrsaSFX/V6feT3Ufbv2ff9s77fwS/f94suDRhK0nScImOtvu0NowVBoCiKlnzu+756vd6Sz889WSKZfxWU6brPWq2m5f6b9TxP8/PzBVQEJEt721tlRxYuyLox6cLJEmb/cBmBXGJZw8mFkyWuzP6B5RDIJZY1nFzoLnnDClxGIJdY1nBypbvkvYJwVSkCuWpHu7LIEk50l4DdrD9l4cLJAADV5swpCxdOBgBAGtYHsgsnAwAgDesD2YWTAQCQhvWB7MrJAABIYnUgh2G4MENetWqVJHEyoCI4WYMqsvaNIeeerjh9+vRCZ0wYu423uaCqrD32lvXiHLiDnz1cU/pjb5yuqC5+9qgqawOZ0xXVxc8eVWVtIHO6orr42aOqrA1k7l2oLn72qCprN/UAwBWl39QDgKohkAHAEgQyAFiCQAYASxDIAGCJTKcsPM87KGnpM60AgFH8OI4vT1qUKZABAPlhZAEAliCQAcASBDIAWIJABgBLEMgAYAkCGQAsQSADgCUIZACwBIEMAJb4L/4/ciktfwZ6AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Create linear regression object\n",
"regr = linear_model.LinearRegression()\n",
"\n",
"# Train the model using the training sets\n",
"regr.fit(diabetes_X_train, diabetes_y_train)\n",
"\n",
"# Make predictions using the testing set\n",
"diabetes_y_pred = regr.predict(diabetes_X_test)\n",
"\n",
"# The coefficients\n",
"print('Coefficients: \\n', regr.coef_)\n",
"# The mean squared error\n",
"print(\"Mean squared error: %.2f\"\n",
" % mean_squared_error(diabetes_y_test, diabetes_y_pred))\n",
"# Explained variance score: 1 is perfect prediction\n",
"print('Variance score: %.2f' % r2_score(diabetes_y_test, diabetes_y_pred))\n",
"\n",
"# Plot outputs\n",
"\n",
"plt.scatter(diabetes_X_test, diabetes_y_test, color='black')\n",
"plt.plot(diabetes_X_test, diabetes_y_pred, color='blue', linewidth=3)\n",
"\n",
"plt.xticks(())\n",
"plt.yticks(())\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2623942383.2329097\n",
"('Objective Value:', 2663156030.210076)\n",
"('Objective Value:', 2718233711.865551)\n",
"('Objective Value:', 2755549908.0834365)\n",
"('Objective Value:', 2777025493.9161234)\n",
"('Objective Value:', 2788632001.1960063)\n",
"('Weights: ', array([[936.25300989]]))\n"
]
}
],
"source": [
"admm = optim.ADMM([diabetes_X_train, diabetes_y_train, 0.01], \"Lasso\") #, parallel = True)\n",
"\n",
"print(admm.getLoss())\n",
"for i in range(0, 5):\n",
" print('Objective Value:', admm.step())\n",
"\n",
"print('Weights: ',admm.getWeights())\n",
"W = admm.getWeights()"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"# Comparison with Scikit - LR\n",
"admm_predict = predict(diabetes_X_test, W)"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Mean squared error: 20747.88\n",
"Variance score: -3.29\n"
]
}
],
"source": [
"print(\"Mean squared error: %.2f\"\n",
" % mean_squared_error(diabetes_y_test, admm_predict))\n",
"# Explained variance score: 1 is perfect prediction\n",
"print('Variance score: %.2f' % r2_score(diabetes_y_test, admm_predict))\n"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAADuCAYAAAAOR30qAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADdZJREFUeJzt3V+IVOUfx/HPWU1rpH9sglbsmYgUU8Fy7aKSjOjGiIwuUiYkyBa8KUzIiyXKaMCrIkqMjaLQSW+iC7OIoGjFm22tJJPIhJmF1PyJqNj4Z22e38Wwnp3VZs85O+ec55zzfsFeNM2Br45+/M73PM95HGOMAADJ60q6AABAE4EMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsMT0IG++7bbbTLFYjKgUAMim/fv3nzTGzJ7sfYECuVgsanh4OHxVAJBDjuPU/LyPkQUAWIJABgBLEMgAYAkCGQAsQSADgCUIZACZVqlUVCwW1dXVpWKxqEqlknRJ/ynQsjcASJNKpaK+vj7V63VJUq1WU19fnySpVColWdo10SEDyKz+/v4rYTymXq+rv78/oYraI5ABZNbIyEig15NGIAPIrJ6enkCvJ41ABpBZ5XJZhUKh5bVCoaByuZxQRe0RyAAyq1QqaWBgQK7rynEcua6rgYEBK2/oSZJjjPH95t7eXsPDhQAgGMdx9htjeid7Hx0yAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsASBDACWIJABwBIEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgQwAliCQAcASBDIAWIJABgBLEMgAYAkCGQAsQSADgCUIZACwBIEMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsASBDACWIJABwBIEMgArVSoVFYtFdXV1qVgsqlKpJF1S5KYnXQAATFSpVNTX16d6vS5JqtVq6uvrkySVSqUkS4sUHTIA6/T3918J4zH1el39/f0JVRQPAhmAdUZGRgK9nhUEMgDr9PT0BHo9KwhkANYpl8sqFAotrxUKBZXL5YQqigeBDMA6pVJJAwMDcl1XjuPIdV0NDAxk+oaeRCADVsrjkq+JSqWSqtWqGo2GqtVq5sNYYtkbYJ28LvkCHTJgnbwu+QKBDFgnr0u+QCAD1snrki8QyIB18rrkCwQyYJ28LvmC5BhjfL+5t7fXDA8PR1gOAGSP4zj7jTG9k72PDhkALEEgA4AlCGQAsASBDACWIJABwBIEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgTwBR+cASApHOI3D0TkAkkSHPA5H50DiWxKSQ4c8DkfngG9JSBId8jgcnQO+JSFJBPI4HJ2DPH5LYkRjDwJ5HI7OQd6+JY2NaGq1mowxV0Y0hHIyOMIJGGfiDFlqfkvK6j/MxWJRtVrtqtdd11W1Wo2/oIziCCcghLx9S8rjiMZmdMhAjtEhx4MOGcCkuJFtFwIZyLG8jWhsx8gCACLGyAIAUoZABgBLEMgAYAkCOUFsWQUwHk97SwhPFQMwER1yQniqGICJCOSEsGUVwEQEckLy9lQxAJMjkBPCllUAExHICWHLKoCJ2DoNABFj6zQApAyBDACWIJAxKXYUAvFgpx7aYkchEB86ZLTFjkIgPgQy2mJHIRAfAhltsaMQiA+BjLbYUQjEh0BGW+woBOJDIKdEkkvPSqWSqtWqGo2GqtUqYQxEhGVvKcDSMyAf6JBTgKVnQD4QyCnA0jMgHwjkFGDpWXBs90YaEcgpwNKzYMZm7rVaTcaYKzN3Qhm2I5BTgKVnwTBzR1rxgHpkTldXl67159pxHDUajQQqQt7xgHp0TNrmsczckVYEMtpK4zyWmTvSikBGW2mcxzJzR1oxQ0ZbzGOBqWOGjI5gHgvEh0BOmbhvsDGPBeJDIKdIEjfYmMcC8WGGbIFKpaL+/n6NjIyop6dH5XL5moFXLBZVq9Wuet11XVWr1RgqBRCG3xkyj99MWJBHa/KQISDbGFkkLMiyMm6wAdlGICcsSNfLDTYg2wjkhAXpernBBmQbgZywoF0v59sBwZw/L33wgbRvnxRgDUMiCOSE0fUCnXfggLRiheQ4UqEgrV8vPfyw9P77SVfWHqssLFAqlQhgYApGR6Vt26SXX27/vqNH46knLAIZQCodPixt3Cjt3u3v/XPmSK+9Fm1NU8XIAkAqNBrSxx83RxCOI82bN3kYd3dLO3c2rz12rHmtzXIdyGl78DqQN7/9Js2c2QzgadOkF15o3qRr59lnpWq1eQPv5Elp9erm9WmQ20BO44PXgawzRnrjjWaAOo60aJF06VL7a2bOlD78ULp8uXn9rl2S68ZSbsflNpDT+OB1NPHNJluqVen225sB3NUlbd48+TVPPCH9/nszgC9ckNata3bQaZfbQM7KcyHyFk58s8mG997zuuC77mrOdyfz9tvSxYvNEP7yS2n+/OjrjJ0xxvfP0qVLTVa4rmskXfXjum7Spfm2Y8cOUygUWuovFApmx44dba9xXdc4jmNc1237Xhtl4XPLo+PHjVm40JhmnPr/+eqrpCvvDEnDxkfG5jaQw4SZbYKGUxZ+zY7jXPPX7DhO0qVhgu3bgwfwY48Zc+ZM0pV3HoHsQ9q7xaDhlIXuMgu/hqw6fdqYRx4JHsI7dyZdefT8BnIqZshRzUnT/lyIoI/jzMLcnCfe2WXPHm8WfMst0g8/TH7NkiXSiRNeJK9eHX2dqeEntcd+kuiQs/A1OypBf2+y0l2m/ZtNmtXrxjz9dPAueOvWpCtPlrIysshKiEQlSDjxjxvCGBwMHsA9PcbUaklXbo/MBDI3cTqL7hKTuXTJmHXrgofwW28Z02gkXb2d/Aay9YeccrAnEL39+6XeSY/gbDVrlvTjj9KCBdHUlCV+Dzm1+qZepVLRuXPnrnqdmzjZl7cNL3H791/pjju8G3J+w3jDBm+L8rlzhHHH+Wmjx37iHFlca94pyXR3d/M1O+OYdUdj797gYwjJmKGhpCtPP6V9ZMGoIr/47DvDGGn58ubRRUGsXdt8WM+MGdHUlUd+RxbWPqA+C2tmEQ6ffXgHD0qLFwe/7t13pZde6nw9CMbaGXLQTQ/IDj77YEolbxYcJIxPnvQGE4SxHawNZHZk5ReffXsjI14AO4702Wf+rtu0qXU63N0dbZ0IztpA5jTm/OKzv9qmTV4AB3n4eq3mBfCWLdHVh86w9qYekGdHjzaXpQX13HPS9u2drwdTk4l1yECevPii1wUHCeODB70umDBON2tXWQBZd/asdPPNwa9bvrz5VLW0HNwJ/+iQgRiVy14XHCSM9+71uuDBQcI4q+iQgQhdutQ8FTmM0VFpOn9Dc4UOGeiwTz/1uuAgYfzOO63L0gjj/OEjB6ao0Qh/BP0//0gTllwjx+iQgRC++cbrgoOE8SuvtHbBhDHGo0MGfJo5szkTDurECWn27M7Xg+yhQwb+w08/tW5R9hvGq1a1dsGEMfyiQwbGCbuc7M8/pbvv7mwtyB86ZOTaoUOtXbBfCxa0dsGEMTqBQEbu3HOPF8ALF/q/bmjIC+BDh6KrD/nFyAKZd/y4NHduuGsDPHsLmDI6ZGTSggVeFxwkjD/5pHUUAcSJDhmZEPZBPVLzBOYuWhNYgD+GSK01a8I9qGfz5tYumDCGLeiQkRqXL0vXXRfu2npduuGGztYDdBq9Aaz25pteFxwkjJcta+2CCWOkAR0yrDKVEcKxY9KcOZ2tB4gTHTISt3Wr1wUHCeNp01q7YMIYaUeHjESE3aL866/SokWdrQWwBR0yYvH11+G2KEutXTBhjCwjkBGZ8QG8cqX/6z7/nM0ZyCdGFuiYgwelxYvDXUvwAnTImKLxXXCQMN6yhS4YmIgOGYH8/Xf41QyNBsfXA+3QIWNSjz/udcFBwrhUau2CCWOgPTpkXOXChfA7286fl66/vrP1AHlBhwxJ0oYNXhccJIzvvbe1CyaMgfDokHNqKluUOUUZiAYdco5s2xZui/LcuZyiDMSBDjnjOEUZSA865Izp1BZlwhiIHx1yBoTtgvftkx58sLO1AAiPQE6hAwekJUvCXcuuOMBejCxSYvwYIkgY79rFFmUgLeiQLfXXX9Kdd4a7luAF0okO2SL33ed1wUHCuFymCwaygA45QVPZojw6Kk3n0wMyhQ45ZgMD4bYor13b2gUTxkD28Nc6Yo1G8zDOMM6elW68sbP1ALAXHXIEdu/2uuAgYfzUU61dMGEM5AsdcoeE3Zxx6pR0662drQVAOtEhhzQ0FG6L8po1rV0wYQxgDB1yAPPnS3/8Efy6alVy3Y6XAyBj6JDbOHy4tQv2G8bLlrV2wYQxAD8I5Alef90L4Hnz/F/3yy9eAA8NRVcfgOzK/cji9Olwc9xZs6Rz5zpfD4D8ymWHPDjodcFBwvjbb70umDAG0Gm56JBHR6X166WPPgp+baPB8fUA4pHZDvnnn70ueMYM/2H83XetN+QIYwBxyUwgNxrSq696IXz//f6u27ix2UGPBfCjj0ZbJwD8l1SPLA4fbi4xO3Mm2HXDw9LSpdHUBABhpapDNkbasqV1WZqfMH7+eeniRa8LJowB2Mj6DvnUKemBB6QjR4Jd9/330ooVkZQEAJGwukM+ckTq7vYXxk8+2VyKNtYFE8YA0sbqDnlwsP3//+ILadWqeGoBgKhZ3SE/84y0cqX33w891BxhjHXBhDGALLG6Q77pJmnPnqSrAIB4WN0hA0CeEMgAYAkCGQAsQSADgCUIZACwBIEMAJYgkAHAEo4xxv+bHed/kmrRlQMAmeQaY2ZP9qZAgQwAiA4jCwCwBIEMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsMT/AZ6ZQ8zfgirvAAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(diabetes_X_test, diabetes_y_test, color='black')\n",
"plt.plot(diabetes_X_test, admm_predict, color='blue', linewidth=3)\n",
"\n",
"plt.xticks(())\n",
"plt.yticks(())\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
105 changes: 105 additions & 0 deletions Test_ADMM.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import torch.optim as optim"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"num_iterations = 5\n",
"N = 1000\n",
"D = 20\n",
"\n",
"A = np.random.randn(N, D)\n",
"b = np.random.randn(N, 1)\n",
"Y = np.random.rand(N, 1)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"13218.725722259838\n",
"('O Val:', 503.7030770044544)\n",
"('O Val:', 503.70307384103387)\n",
"('O Val:', 503.7030786237145)\n",
"('O Val:', 503.7030712532673)\n",
"('O Val:', 503.7030778652381)\n",
"('Weights: ', array([[-0.02279339],\n",
" [-0.03123076],\n",
" [ 0.03235469],\n",
" [-0.00769174],\n",
" [-0.03006917],\n",
" [-0.01935186],\n",
" [ 0.0200083 ],\n",
" [-0.0228672 ],\n",
" [ 0.03491097],\n",
" [-0.0216022 ],\n",
" [ 0.0118515 ],\n",
" [-0.03075605],\n",
" [-0.019314 ],\n",
" [ 0.05811015],\n",
" [ 0.02316836],\n",
" [ 0.02389774],\n",
" [-0.06124286],\n",
" [-0.02646492],\n",
" [-0.01167727],\n",
" [-0.03814794]]))\n"
]
}
],
"source": [
"admm = optim.ADMM([A, b, 0.01], \"Lasso\") #, parallel = True)\n",
"\n",
"print(admm.getLoss())\n",
"for i in range(0, num_iterations):\n",
" print('O Val:', admm.step())\n",
"\n",
"print('Weights: ',admm.getWeights())\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.14"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading

0 comments on commit e0c30b0

Please sign in to comment.