diff --git a/mothernet/notebooks/regression.ipynb b/mothernet/notebooks/regression.ipynb index 02b16084..1bb08827 100644 --- a/mothernet/notebooks/regression.ipynb +++ b/mothernet/notebooks/regression.ipynb @@ -10,10 +10,11 @@ "outputs": [], "source": [ "from mothernet.prediction.mothernet_additive import MotherNetAdditiveClassifier, MotherNetAdditiveRegressor\n", + "from mothernet.utils import get_mn_model\n", "\n", "#model_path = \"../models_diff/baam_Daverage_maxnumclasses0_nsamples500_numfeatures10_yencoderlinear_04_30_2024_17_02_37_epoch_10.cpkt\"\n", "#model_path = \"../models_diff/baam_Daverage_l1e-05_maxnumclasses0_nsamples500_numfeatures10_yencoderlinear_05_01_2024_15_16_49_epoch_30.cpkt\"\n", - "model_path = \"../models_diff/baam_Daverage_l1e-05_maxnumclasses0_nsamples500_numfeatures10_yencoderlinear_05_08_2024_03_04_01_epoch_40.cpkt\"\n", + "model_path = get_mn_model(\"baam_Daverage_l1e-05_maxnumclasses0_nsamples500_numfeatures10_yencoderlinear_05_08_2024_03_04_01_epoch_40.cpkt\")\n", "reg = MotherNetAdditiveRegressor(device='cpu', path=model_path)" ] }, @@ -98,7 +99,9 @@ } ], "source": [ - "plt.plot(y_test, ss.inverse_transform(y_pred), 'o')" + "import matplotlib.pyplot as plt\n", + "\n", + "plt.plot(y_test, ss.inverse_transform(y_pred.reshape(-1, 1)), 'o')" ] }, {