diff --git a/Diabetes_Test.ipynb b/Diabetes_Test.ipynb new file mode 100644 index 0000000..07ad0d6 --- /dev/null +++ b/Diabetes_Test.ipynb @@ -0,0 +1,261 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "from sklearn import datasets, linear_model\n", + "from sklearn.metrics import mean_squared_error, r2_score\n", + "import torch.optim as optim" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [], + "source": [ + "# Helper functions\n", + "def predict(X, W):\n", + " YPred = X.dot(W)\n", + " return YPred" + ] + }, + { + "cell_type": "code", + "execution_count": 97, + "metadata": {}, + "outputs": [], + "source": [ + "num_iterations = 5\n", + "N = 10000\n", + "D = 10\n", + "\n", + "X = np.random.randn(N, D)\n", + "Y = np.random.randn(N, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "# Load the diabetes dataset\n", + "diabetes = datasets.load_diabetes()\n", + "\n", + "# Use only one feature\n", + "diabetes_X = diabetes.data[:, np.newaxis, 2]\n", + "\n", + "# Split the data into training/testing sets\n", + "diabetes_X_train = diabetes_X[:-20]\n", + "diabetes_X_test = diabetes_X[-20:]\n", + "\n", + "# Split the targets into training/testing sets\n", + "diabetes_y_train = diabetes.target[:-20]\n", + "diabetes_y_test = diabetes.target[-20:]" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "#print(diabetes)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "diabetes_X_train = X[:-20]\n", + "diabetes_X_test = X[-20:]\n", + "diabetes_y_train = Y[:-20]\n", + "diabetes_y_test = Y[-20:]" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "('Coefficients: \\n', array([938.23786125]))\n", + "Mean squared error: 2548.07\n", + "Variance score: 0.47\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAADuCAYAAAAOR30qAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEGRJREFUeJzt3W+MXFX9x/HPnf7RHaC1UFBjmXuRWKlFEFir8RcV/+H/JwY1cawx/pkHBEIkoUYm0WgyxOojIfgzQ41R9z5RiSZiTEqtxJhodCskFmEJkblbNJi2gm0zXfpnrw+Os9t2d+be2+6de+6571fSB52ebb6bhU++/Z5zz/XiOBYAoHi1ogsAABgEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgQwAliCQAcASq7Ms3rhxYxwEQU6lAICb9u3bdyiO48uT1mUK5CAIND09ff5VAUAFeZ4XpVnHyAIALEEgA4AlCGQAsASBDACWIJABwBIEMgCnhWGoIAhUq9UUBIHCMCy6pKEyHXsDgDIJw1CtVkv9fl+SFEWRWq2WJKnZbBZZ2rLokAE4q91uL4TxQL/fV7vdLqii0QhkAM6anZ3N9HnRCGQAzmo0Gpk+LxqBDMBZnU5H9Xr9rM/q9bo6nU5BFY1GIANwVrPZVLfble/78jxPvu+r2+1auaEnSV4cx6kXT05OxlwuBADZeJ63L47jyaR1dMgAYAkCGQAsQSADgCUIZACwBIEMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsASBDACWIJABwBIEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgQwAliCQAcASBDIAWIJABgBLEMgAYAkCGQAsQSADgCUIZACwBIEMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsASBDACWIJABwBIEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgQwAliCQAcASBDIAWIJABgBLEMgAYAkCGYCznn9euuEGyfOka66RpqeLrmg0AhmAlcIwVBAEqtVqCoJAYRim/tpf/tKE8KtfLT3+uPlsZkb60Y9yKnaFrC66AAA4VxiGarVa6vf7kqQoitRqtSRJzWZz2a85cUK67Tbp+98f/vcO+VJreHEcp148OTkZT9ve8wMovSAIFEXRks9931ev1zvrs6eekt72NumFF4b/fVdfLe3dKzUaK1xoSp7n7YvjeDJpHSMLANaZnZ1N/Px73zNjiS1bhofx3XdLp05JzzxTXBhnwcgCgHUajcayHfKmTVt1yy3SI4+M/vpHH5Xe+c58assTHTIA63Q6HdXr9TM++T9JsQ4c+OvQMH73u02nHMflDGOJDhmAhZrNpubnPX3hC1t14sT1I9fef790++1jKixnBDIAqzz5pPSGN0jSp4auWbdO+sMfBuvcwcgCgBW+/nWzSTcqZD/7WWluTvrPf9wLY4kOGUCBjh2TNm6UXnpp9LpvflP68pfHU1OR6JABC13IU2pl8JvfmG74kktGh/HMjNmkq0IYSwQyYJ3BU2pRFCmO44Wn1MoeynEsfeITJojf+97h697xDun0abN+8+bx1WcDntQDLJPlKbUy+Mc/pE2bktf99KfSrbfmX08ReFIPKKk0T6mVwa5dphtOCuNDh0w37GoYZ0EgA5ZpDHnGd9jnNjl50lxz6XnSF784fN1tt5kQjmPpssvGV5/tCGTAMkufUpPq9bo6nU5BFSV77DETwmvXmo24Yf74RxPCDzwwvtrKhEAGLNNsNtXtduX7vjzPk+/76na7Q6+dLNLdd5sgvvHG4WsaDXN2OI6lt7xlfLWVEZt6ADJ58UVpw4bkdffdJ91xR/71lEHaTT0eDAGQysMPSx/9aPK6Z5+VgiD3cpzEyALAUHEsffCDZiwxKow//GFpft6sJ4zPHx0ygCV6Pemqq5LXPfywCWOsDDpkAAvuu890w0lh/OKLphsmjFcWgQxU3LFjJoQ9T7rzzuHrduxYPDu8fv346qsSAhmoqB//ePGCn1Eee8yE8M6d46mrypghAxWzZo158ecoW7eaIF6zZjw1waBDBirg2WcXxxKjwnjXLtMN799PGBeBQAYcdtddJoRf+9rR6/bvN0H8+c+Ppy4sj5EF4JhTp9J3t/PzJrBhBzpkwBGPPmrCNSmMv/OdxdMShLFd6JCBktu2Tfrzn5PXHTrEVZe2I5CBEnrhBenSS5PXXX+99Pjj+deDlcHIAiiR737XjBmSwnjPHjOSIIzLhQ4ZsFwcS7WUrdPJk9Jq/q8uLTpkwFJPPmm64aQwvuOOxU06wrjc+PEBlrnqKnPbWpJnnpGuvjr3cjBGBDJggePHpXNeozdUhpf8oGQYWQAFGmzSJYXxD36wOJaAu+iQgQKkfSDj8OF0x9vgBjrkc4RhqCAIVKvVFASBwjAsuiQ4otdbvOAnyaAbJoyrhUA+QxiGarVaiqJIcRwriiK1Wi1CGRfkk59M9xaOX/yCsUTVeXGGn/7k5GQ8PT2dYznFCoJAURQt+dz3ffXSbHsD/5Pl7PCpU9KqVfnWg2J5nrcvjuPJpHV0yGeYnZ3N9DncdCFjq927050d/sAHFrthwhgDbOqdodFoLNshNxqNAqpBEQZjq36/L0kLYytJajabQ79uYkKam0v++2dmpM2bV6RUOIgO+QydTkf1c84f1et1dTqdgirCuLXb7YUwHuj3+2q320vWHjmyuEmXFMaDbpgwxigE8hmazaa63a5835fnefJ9X91ud2RnBLekGVvde68J4aQ3L+/cySYdsiGQz9FsNtXr9TQ/P69er0cYV8yw8VSj0Vjohpdpls9y9KgJ4R07cigwBxz1tAeBDJxh6djqGkmxoqg38ute8YrFbvjii/OscGVx1NMuHHsDzhGGoT73uS06ceLGxLV790rvetcYisoJRz3HI+2xN05ZAP+z+HLQ5DGVKy8H5ainXRhZoPIeeCDdy0G3b3fv5aCjZuYYPzpkVFbaUJ2dla68Mt9aitLpdM46dy1x1LNIdMiolH/+M/sFP66GscRRT9sQyKiEj3zEhPBrXjN63Ve/Wr2zwxz1tAcjCzgt7Vii3zePPwNFokOGc37+8+xjCcIYNqBDhjPSdsO7d0vve1++tQDng0BGqfX70kUXpVtbpbkwyomRBUqp1TIdcVIY+371NulQXnTIKJW0Y4m//z35lUmAbeiQYb0nnsi+SUcYo4wIZFhrEMLXXjt63Ve+wlgCbiCQC8Q9tEsN7olI0w2/9JJZf++9+dcFjAOBXBDuoT3bt76V7uWg0mI3vHZt/nUB48R9yAXhHloj7Sbdnj3Se96Tby1AXrgP2XJVvof24EHpiivSrWUujCphZFGQKt5D+8Y3mo44KYxf+Uo26VBNBHJBlr67zd17aAebdPv3j1733HMmhJ9/fjx1AbYhkAvi+j20e/ZkPzucdDUm4Do29bCi0m7S3XOP5OA/BoBlsamHsVl8OWi6tatW5VsPUFaMLHDe7ror3ctBpcWxBGEMDEeHjMzSjiV+9zvp7W/PtxbAJQQyUun10l/Yw3E14PwwssBIN9xgOuKkMN62jbPDwIWiQ8ay0o4l/v1vacOGfGsBqoIOGQt+/evsZ4cJY2DlEMhYCOEPfShp5Xb5fqCpqWreSAfkjZFFRc3NSRMT6dZOTFyk48f7kqQoklqtliQ581QhYAs65Ir50pdMN5wUxhs2mJGE7wcLYTzQ7/fVbrdzrBKoJjrkiki7STczI23evPj7Kl8TCowbHbLDnn46+ybdmWEsVfOaUKAoBLKDLrvMhPDrXz963Z13Jp8drtI1oUDRGFk4Io7TvY9Oko4fl17+8nRrBxt37XZbs7OzajQa6nQ6bOgBOeD6zZKbmpK2b0+3lqfogGJw/abj0m7S/epXac4XA7ABM+SSCMNQjca1mTfpCGOgPAjkEnjrWyN9+tNNHTgw+qV0113HBT9AmTGysNhiJ+yPXHfggLRpU+7lAMgZHbJl9u1Lf3bY82qKY8IYcAWBbIlBCE8m7sPeI8mT5PFwBuAYRhYFmp9P/465iYl1On786MLveTgDcA8dcgF27zbdcJowHmzSPfjg/8v3fXmeJ9/31e12eTgDcAyBPEYve5kJ4ve/f/S63/9+6WmJZrOpXq+n+fl59Xo9wjhBGIYKgkC1Wk1BECgMucMZ9mNkkbMjR6T169Ot5bjaygjDUK1WS/3+4A7niDucUQp0yDnpdEw3nBTG3/42Z4dXWrvdXgjjAe5wRhnQIa+wtI80Hz0qXXxxvrVUFXc4o6zokFfA3/6W7uzwpZcudsOEcX64wxllRSBfgJtvNiG8devodXv3mhA+fHgsZa24sm2QcYczyoqRRUanTklr1qRbOz+ffoRhqzJukHGHM8qK+5BT+tnPpI9/PHndZz4j/fCH+dczLkEQKIqiJZ/7vq9erzf+goAS4j7kFZK2w3X1gh82yIDxYYa8jIMHs78c1MUwltggA8aJQD7Dgw+aEL7iitHrdu2qztlhNsiA8WFkofRjibk58/hzlbBBBoxPZTf1/vUv6VWvSl63ZYs5ZwwA5yvtpl7lRhZTU6YjTgrjmRkzkrAtjMt2JhhAepUYWZw+LW3bJv3lL8lrbZ4Ll/FMMID0nO6Qn3jCdMOrV48O46mpYjfp0na9XJoDuM3JDvlrX5O+8Y3RazZulGZnpYmJ8dQ0TJaulzPBgNuc6ZCPHZPWrjUd8agw3rnTdMIHDxYfxlK2rpczwYDbSh/IjzxiQviSS6STJ4eve/ppE8Q7doyvtjSydL2cCQbcVspAjmPp1ltNEN9yy/B1N99sNvTiWHrd68ZWXiZZut5ms6lut8u79QBHlSqQn3vOhHCtJj300PB1Dz1kQvi3vzVrbZa16+XdeoC7LI8ro9s1QXzllaPXHT5sgvhjHxtPXSuBrhfAgNVP6s3NJW+83X67dP/946kHAM6HE9dv/uQnw//sT3+S3vzm8dUCAHmzOpDf9CZp3TrpyBHz+yCQnnqqehf8AKgGqwP5uuvMwxsnTkiXX150NQCQL6sDWZLWry+6AgAYj1KcsgCAKiCQAcASlQ5k7hYGYBPrZ8h54W5hALapbIfM3cIAbFPZQOZuYQC2qWwgc7dweTH7h6sqG8iu3C1ctXAazP6jKFIcxwuzf9e/b1REHMepf910002xS6ampmLf92PP82Lf9+OpqamiS8pkamoqrtfrsaSFX/V6feT3Ufbv2ff9s77fwS/f94suDRhK0nScImOtvu0NowVBoCiKlnzu+756vd6Sz889WSKZfxWU6brPWq2m5f6b9TxP8/PzBVQEJEt721tlRxYuyLox6cLJEmb/cBmBXGJZw8mFkyWuzP6B5RDIJZY1nFzoLnnDClxGIJdY1nBypbvkvYJwVSkCuWpHu7LIEk50l4DdrD9l4cLJAADV5swpCxdOBgBAGtYHsgsnAwAgDesD2YWTAQCQhvWB7MrJAABIYnUgh2G4MENetWqVJHEyoCI4WYMqsvaNIeeerjh9+vRCZ0wYu423uaCqrD32lvXiHLiDnz1cU/pjb5yuqC5+9qgqawOZ0xXVxc8eVWVtIHO6orr42aOqrA1k7l2oLn72qCprN/UAwBWl39QDgKohkAHAEgQyAFiCQAYASxDIAGCJTKcsPM87KGnpM60AgFH8OI4vT1qUKZABAPlhZAEAliCQAcASBDIAWIJABgBLEMgAYAkCGQAsQSADgCUIZACwBIEMAJb4L/4/ciktfwZ6AAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Create linear regression object\n", + "regr = linear_model.LinearRegression()\n", + "\n", + "# Train the model using the training sets\n", + "regr.fit(diabetes_X_train, diabetes_y_train)\n", + "\n", + "# Make predictions using the testing set\n", + "diabetes_y_pred = regr.predict(diabetes_X_test)\n", + "\n", + "# The coefficients\n", + "print('Coefficients: \\n', regr.coef_)\n", + "# The mean squared error\n", + "print(\"Mean squared error: %.2f\"\n", + " % mean_squared_error(diabetes_y_test, diabetes_y_pred))\n", + "# Explained variance score: 1 is perfect prediction\n", + "print('Variance score: %.2f' % r2_score(diabetes_y_test, diabetes_y_pred))\n", + "\n", + "# Plot outputs\n", + "\n", + "plt.scatter(diabetes_X_test, diabetes_y_test, color='black')\n", + "plt.plot(diabetes_X_test, diabetes_y_pred, color='blue', linewidth=3)\n", + "\n", + "plt.xticks(())\n", + "plt.yticks(())\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "2623942383.2329097\n", + "('Objective Value:', 2663156030.210076)\n", + "('Objective Value:', 2718233711.865551)\n", + "('Objective Value:', 2755549908.0834365)\n", + "('Objective Value:', 2777025493.9161234)\n", + "('Objective Value:', 2788632001.1960063)\n", + "('Weights: ', array([[936.25300989]]))\n" + ] + } + ], + "source": [ + "admm = optim.ADMM([diabetes_X_train, diabetes_y_train, 0.01], \"Lasso\") #, parallel = True)\n", + "\n", + "print(admm.getLoss())\n", + "for i in range(0, 5):\n", + " print('Objective Value:', admm.step())\n", + "\n", + "print('Weights: ',admm.getWeights())\n", + "W = admm.getWeights()" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [], + "source": [ + "# Comparison with Scikit - LR\n", + "admm_predict = predict(diabetes_X_test, W)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Mean squared error: 20747.88\n", + "Variance score: -3.29\n" + ] + } + ], + "source": [ + "print(\"Mean squared error: %.2f\"\n", + " % mean_squared_error(diabetes_y_test, admm_predict))\n", + "# Explained variance score: 1 is perfect prediction\n", + "print('Variance score: %.2f' % r2_score(diabetes_y_test, admm_predict))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAWQAAADuCAYAAAAOR30qAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAADdZJREFUeJzt3V+IVOUfx/HPWU1rpH9sglbsmYgUU8Fy7aKSjOjGiIwuUiYkyBa8KUzIiyXKaMCrIkqMjaLQSW+iC7OIoGjFm22tJJPIhJmF1PyJqNj4Z22e38Wwnp3VZs85O+ec55zzfsFeNM2Br45+/M73PM95HGOMAADJ60q6AABAE4EMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsMT0IG++7bbbTLFYjKgUAMim/fv3nzTGzJ7sfYECuVgsanh4OHxVAJBDjuPU/LyPkQUAWIJABgBLEMgAYAkCGQAsQSADgCUIZACZVqlUVCwW1dXVpWKxqEqlknRJ/ynQsjcASJNKpaK+vj7V63VJUq1WU19fnySpVColWdo10SEDyKz+/v4rYTymXq+rv78/oYraI5ABZNbIyEig15NGIAPIrJ6enkCvJ41ABpBZ5XJZhUKh5bVCoaByuZxQRe0RyAAyq1QqaWBgQK7rynEcua6rgYEBK2/oSZJjjPH95t7eXsPDhQAgGMdx9htjeid7Hx0yAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsASBDACWIJABwBIEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgQwAliCQAcASBDIAWIJABgBLEMgAYAkCGQAsQSADgCUIZACwBIEMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsASBDACWIJABwBIEMgArVSoVFYtFdXV1qVgsqlKpJF1S5KYnXQAATFSpVNTX16d6vS5JqtVq6uvrkySVSqUkS4sUHTIA6/T3918J4zH1el39/f0JVRQPAhmAdUZGRgK9nhUEMgDr9PT0BHo9KwhkANYpl8sqFAotrxUKBZXL5YQqigeBDMA6pVJJAwMDcl1XjuPIdV0NDAxk+oaeRCADVsrjkq+JSqWSqtWqGo2GqtVq5sNYYtkbYJ28LvkCHTJgnbwu+QKBDFgnr0u+QCAD1snrki8QyIB18rrkCwQyYJ28LvmC5BhjfL+5t7fXDA8PR1gOAGSP4zj7jTG9k72PDhkALEEgA4AlCGQAsASBDACWIJABwBIEMgBYgkAGAEsQyABgCQIZACxBIAOAJQhkALAEgTwBR+cASApHOI3D0TkAkkSHPA5H50DiWxKSQ4c8DkfngG9JSBId8jgcnQO+JSFJBPI4HJ2DPH5LYkRjDwJ5HI7OQd6+JY2NaGq1mowxV0Y0hHIyOMIJGGfiDFlqfkvK6j/MxWJRtVrtqtdd11W1Wo2/oIziCCcghLx9S8rjiMZmdMhAjtEhx4MOGcCkuJFtFwIZyLG8jWhsx8gCACLGyAIAUoZABgBLEMgAYAkCOUFsWQUwHk97SwhPFQMwER1yQniqGICJCOSEsGUVwEQEckLy9lQxAJMjkBPCllUAExHICWHLKoCJ2DoNABFj6zQApAyBDACWIJAxKXYUAvFgpx7aYkchEB86ZLTFjkIgPgQy2mJHIRAfAhltsaMQiA+BjLbYUQjEh0BGW+woBOJDIKdEkkvPSqWSqtWqGo2GqtUqYQxEhGVvKcDSMyAf6JBTgKVnQD4QyCnA0jMgHwjkFGDpWXBs90YaEcgpwNKzYMZm7rVaTcaYKzN3Qhm2I5BTgKVnwTBzR1rxgHpkTldXl67159pxHDUajQQqQt7xgHp0TNrmsczckVYEMtpK4zyWmTvSikBGW2mcxzJzR1oxQ0ZbzGOBqWOGjI5gHgvEh0BOmbhvsDGPBeJDIKdIEjfYmMcC8WGGbIFKpaL+/n6NjIyop6dH5XL5moFXLBZVq9Wuet11XVWr1RgqBRCG3xkyj99MWJBHa/KQISDbGFkkLMiyMm6wAdlGICcsSNfLDTYg2wjkhAXpernBBmQbgZywoF0v59sBwZw/L33wgbRvnxRgDUMiCOSE0fUCnXfggLRiheQ4UqEgrV8vPfyw9P77SVfWHqssLFAqlQhgYApGR6Vt26SXX27/vqNH46knLAIZQCodPixt3Cjt3u3v/XPmSK+9Fm1NU8XIAkAqNBrSxx83RxCOI82bN3kYd3dLO3c2rz12rHmtzXIdyGl78DqQN7/9Js2c2QzgadOkF15o3qRr59lnpWq1eQPv5Elp9erm9WmQ20BO44PXgawzRnrjjWaAOo60aJF06VL7a2bOlD78ULp8uXn9rl2S68ZSbsflNpDT+OB1NPHNJluqVen225sB3NUlbd48+TVPPCH9/nszgC9ckNata3bQaZfbQM7KcyHyFk58s8mG997zuuC77mrOdyfz9tvSxYvNEP7yS2n+/OjrjJ0xxvfP0qVLTVa4rmskXfXjum7Spfm2Y8cOUygUWuovFApmx44dba9xXdc4jmNc1237Xhtl4XPLo+PHjVm40JhmnPr/+eqrpCvvDEnDxkfG5jaQw4SZbYKGUxZ+zY7jXPPX7DhO0qVhgu3bgwfwY48Zc+ZM0pV3HoHsQ9q7xaDhlIXuMgu/hqw6fdqYRx4JHsI7dyZdefT8BnIqZshRzUnT/lyIoI/jzMLcnCfe2WXPHm8WfMst0g8/TH7NkiXSiRNeJK9eHX2dqeEntcd+kuiQs/A1OypBf2+y0l2m/ZtNmtXrxjz9dPAueOvWpCtPlrIysshKiEQlSDjxjxvCGBwMHsA9PcbUaklXbo/MBDI3cTqL7hKTuXTJmHXrgofwW28Z02gkXb2d/Aay9YeccrAnEL39+6XeSY/gbDVrlvTjj9KCBdHUlCV+Dzm1+qZepVLRuXPnrnqdmzjZl7cNL3H791/pjju8G3J+w3jDBm+L8rlzhHHH+Wmjx37iHFlca94pyXR3d/M1O+OYdUdj797gYwjJmKGhpCtPP6V9ZMGoIr/47DvDGGn58ubRRUGsXdt8WM+MGdHUlUd+RxbWPqA+C2tmEQ6ffXgHD0qLFwe/7t13pZde6nw9CMbaGXLQTQ/IDj77YEolbxYcJIxPnvQGE4SxHawNZHZk5ReffXsjI14AO4702Wf+rtu0qXU63N0dbZ0IztpA5jTm/OKzv9qmTV4AB3n4eq3mBfCWLdHVh86w9qYekGdHjzaXpQX13HPS9u2drwdTk4l1yECevPii1wUHCeODB70umDBON2tXWQBZd/asdPPNwa9bvrz5VLW0HNwJ/+iQgRiVy14XHCSM9+71uuDBQcI4q+iQgQhdutQ8FTmM0VFpOn9Dc4UOGeiwTz/1uuAgYfzOO63L0gjj/OEjB6ao0Qh/BP0//0gTllwjx+iQgRC++cbrgoOE8SuvtHbBhDHGo0MGfJo5szkTDurECWn27M7Xg+yhQwb+w08/tW5R9hvGq1a1dsGEMfyiQwbGCbuc7M8/pbvv7mwtyB86ZOTaoUOtXbBfCxa0dsGEMTqBQEbu3HOPF8ALF/q/bmjIC+BDh6KrD/nFyAKZd/y4NHduuGsDPHsLmDI6ZGTSggVeFxwkjD/5pHUUAcSJDhmZEPZBPVLzBOYuWhNYgD+GSK01a8I9qGfz5tYumDCGLeiQkRqXL0vXXRfu2npduuGGztYDdBq9Aaz25pteFxwkjJcta+2CCWOkAR0yrDKVEcKxY9KcOZ2tB4gTHTISt3Wr1wUHCeNp01q7YMIYaUeHjESE3aL866/SokWdrQWwBR0yYvH11+G2KEutXTBhjCwjkBGZ8QG8cqX/6z7/nM0ZyCdGFuiYgwelxYvDXUvwAnTImKLxXXCQMN6yhS4YmIgOGYH8/Xf41QyNBsfXA+3QIWNSjz/udcFBwrhUau2CCWOgPTpkXOXChfA7286fl66/vrP1AHlBhwxJ0oYNXhccJIzvvbe1CyaMgfDokHNqKluUOUUZiAYdco5s2xZui/LcuZyiDMSBDjnjOEUZSA865Izp1BZlwhiIHx1yBoTtgvftkx58sLO1AAiPQE6hAwekJUvCXcuuOMBejCxSYvwYIkgY79rFFmUgLeiQLfXXX9Kdd4a7luAF0okO2SL33ed1wUHCuFymCwaygA45QVPZojw6Kk3n0wMyhQ45ZgMD4bYor13b2gUTxkD28Nc6Yo1G8zDOMM6elW68sbP1ALAXHXIEdu/2uuAgYfzUU61dMGEM5AsdcoeE3Zxx6pR0662drQVAOtEhhzQ0FG6L8po1rV0wYQxgDB1yAPPnS3/8Efy6alVy3Y6XAyBj6JDbOHy4tQv2G8bLlrV2wYQxAD8I5Alef90L4Hnz/F/3yy9eAA8NRVcfgOzK/cji9Olwc9xZs6Rz5zpfD4D8ymWHPDjodcFBwvjbb70umDAG0Gm56JBHR6X166WPPgp+baPB8fUA4pHZDvnnn70ueMYM/2H83XetN+QIYwBxyUwgNxrSq696IXz//f6u27ix2UGPBfCjj0ZbJwD8l1SPLA4fbi4xO3Mm2HXDw9LSpdHUBABhpapDNkbasqV1WZqfMH7+eeniRa8LJowB2Mj6DvnUKemBB6QjR4Jd9/330ooVkZQEAJGwukM+ckTq7vYXxk8+2VyKNtYFE8YA0sbqDnlwsP3//+ILadWqeGoBgKhZ3SE/84y0cqX33w891BxhjHXBhDGALLG6Q77pJmnPnqSrAIB4WN0hA0CeEMgAYAkCGQAsQSADgCUIZACwBIEMAJYgkAHAEo4xxv+bHed/kmrRlQMAmeQaY2ZP9qZAgQwAiA4jCwCwBIEMAJYgkAHAEgQyAFiCQAYASxDIAGAJAhkALEEgA4AlCGQAsMT/AZ6ZQ8zfgirvAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "plt.scatter(diabetes_X_test, diabetes_y_test, color='black')\n", + "plt.plot(diabetes_X_test, admm_predict, color='blue', linewidth=3)\n", + "\n", + "plt.xticks(())\n", + "plt.yticks(())\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/Test_ADMM.ipynb b/Test_ADMM.ipynb new file mode 100644 index 0000000..c010dcb --- /dev/null +++ b/Test_ADMM.ipynb @@ -0,0 +1,105 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch.optim as optim" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "num_iterations = 5\n", + "N = 1000\n", + "D = 20\n", + "\n", + "A = np.random.randn(N, D)\n", + "b = np.random.randn(N, 1)\n", + "Y = np.random.rand(N, 1)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "13218.725722259838\n", + "('O Val:', 503.7030770044544)\n", + "('O Val:', 503.70307384103387)\n", + "('O Val:', 503.7030786237145)\n", + "('O Val:', 503.7030712532673)\n", + "('O Val:', 503.7030778652381)\n", + "('Weights: ', array([[-0.02279339],\n", + " [-0.03123076],\n", + " [ 0.03235469],\n", + " [-0.00769174],\n", + " [-0.03006917],\n", + " [-0.01935186],\n", + " [ 0.0200083 ],\n", + " [-0.0228672 ],\n", + " [ 0.03491097],\n", + " [-0.0216022 ],\n", + " [ 0.0118515 ],\n", + " [-0.03075605],\n", + " [-0.019314 ],\n", + " [ 0.05811015],\n", + " [ 0.02316836],\n", + " [ 0.02389774],\n", + " [-0.06124286],\n", + " [-0.02646492],\n", + " [-0.01167727],\n", + " [-0.03814794]]))\n" + ] + } + ], + "source": [ + "admm = optim.ADMM([A, b, 0.01], \"Lasso\") #, parallel = True)\n", + "\n", + "print(admm.getLoss())\n", + "for i in range(0, num_iterations):\n", + " print('O Val:', admm.step())\n", + "\n", + "print('Weights: ',admm.getWeights())\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/admm.py b/admm.py index 994f872..19eb5fb 100644 --- a/admm.py +++ b/admm.py @@ -2,7 +2,7 @@ from numpy.linalg import inv from numpy.linalg import norm from joblib import Parallel, delayed -from multiprocessing import Process, Manager, cpu_count +from multiprocessing import Process, Manager, cpu_count, Pool class SolveIndividual: @@ -42,12 +42,12 @@ def step(self): # Solve for X_t+1 self.X = inv(self.A.T.dot(self.A) + self.rho).dot(self.A.T.dot(self.b) + self.rho * self.Z - self.nu) - + # Solve for Z_t+1 self.Z = self.X + self.nu / self.rho - (self.alpha / self.rho) * np.sign(self.Z) # Combine self.nu = self.nu + self.rho * (self.X - self.Z) - + def solveIndividual(self, i): solve = SolveIndividual() return solve.solve(self.A[i], np.asscalar(self.b[i]), self.nuBar[i].reshape(-1, 1), self.rho, self.Z) @@ -58,30 +58,47 @@ def combineSolution(self, i): def step_parallel(self): # Solve for X_t+1 - Parallel(n_jobs = self.numberOfThreads, backend = "threading")( - delayed(self.solveIndividual)(i) for i in range(0, self.N-1)) - - self.X = np.average(self.XBar, axis = 0) + #Parallel(n_jobs = self.numberOfThreads, backend = "threading")( + # delayed(self.solveIndividual)(i) for i in range(0, self.N-1)) + process = [] + for i in range(0, self.N-1): + p = Process(target = self.solveIndividual, args= (i,)) + p.start() + process.append(p) + + for p in process: + p.join() + + self.X = np.average(self.XBar, axis = 0) self.nu = np.average(self.nuBar, axis = 0) - + self.X = self.X.reshape(-1, 1) self.nu = self.nu.reshape(-1, 1) - + # Solve for Z_t+1 self.Z = self.X + self.nu / self.rho - (self.alpha / self.rho) * np.sign(self.Z) # Combine - Parallel(n_jobs = self.numberOfThreads, backend = "threading")( - delayed(self.combineSolution)(i) for i in range(0, self.N-1)) + #Parallel(n_jobs = self.numberOfThreads, backend = "threading")( + # delayed(self.combineSolution)(i) for i in range(0, self.N-1)) + + process = [] + for i in range(0, self.N-1): + p = Process(target = self.combineSolution, args= (i,)) + p.start() + process.append(p) + + for p in process: + p.join() def step_iterative(self): # Solve for X_t+1 for i in range(0, self.N-1): t = self.solveIndividual(i) self.XBar[i] = t.T - - self.X = np.average(self.XBar, axis = 0) + + self.X = np.average(self.XBar, axis = 0) self.nu = np.average(self.nuBar, axis = 0) - + self.X = self.X.reshape(-1, 1) self.nu = self.nu.reshape(-1, 1) diff --git a/pytorch_test.py b/pytorch_test.py new file mode 100644 index 0000000..f9b2667 --- /dev/null +++ b/pytorch_test.py @@ -0,0 +1,18 @@ +import numpy as np +import torch.optim as optim + +num_iterations = 5 +N = 1000 +D = 20 + +A = np.random.randn(N, D) +b = np.random.randn(N, 1) +# Y = np.random.rand(N, 1) + +admm = optim.ADMM([A, b, 0.01], "Lasso") #, parallel = True) + +print(admm.getLoss()) +for i in range(0, num_iterations): + print('O Val:', admm.step()) + +print('Weights: ',admm.getWeights()) diff --git a/test.py b/test.py index e76227e..af2b057 100644 --- a/test.py +++ b/test.py @@ -3,14 +3,14 @@ num_iterations = 20 -N = 100000 -D = 200 +N = 100 +D = 20 A = np.random.randn(N, D) b = np.random.randn(N, 1) # Y = np.random.rand(N, 1) -admm = ADMM(A, b, parallel = True) +admm = ADMM(A, b, parallel = True) print(admm.LassoObjective()) for i in range(0, num_iterations):