From c5a998460607e168d0ad0d063f075b8ff41c627e Mon Sep 17 00:00:00 2001 From: Valerii Zuev Date: Tue, 29 Oct 2024 14:19:34 +0300 Subject: [PATCH] fix regression.ipynb --- mothernet/notebooks/regression.ipynb | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mothernet/notebooks/regression.ipynb b/mothernet/notebooks/regression.ipynb index 02b16084..1bb08827 100644 --- a/mothernet/notebooks/regression.ipynb +++ b/mothernet/notebooks/regression.ipynb @@ -10,10 +10,11 @@ "outputs": [], "source": [ "from mothernet.prediction.mothernet_additive import MotherNetAdditiveClassifier, MotherNetAdditiveRegressor\n", + "from mothernet.utils import get_mn_model\n", "\n", "#model_path = \"../models_diff/baam_Daverage_maxnumclasses0_nsamples500_numfeatures10_yencoderlinear_04_30_2024_17_02_37_epoch_10.cpkt\"\n", "#model_path = \"../models_diff/baam_Daverage_l1e-05_maxnumclasses0_nsamples500_numfeatures10_yencoderlinear_05_01_2024_15_16_49_epoch_30.cpkt\"\n", - "model_path = \"../models_diff/baam_Daverage_l1e-05_maxnumclasses0_nsamples500_numfeatures10_yencoderlinear_05_08_2024_03_04_01_epoch_40.cpkt\"\n", + "model_path = get_mn_model(\"baam_Daverage_l1e-05_maxnumclasses0_nsamples500_numfeatures10_yencoderlinear_05_08_2024_03_04_01_epoch_40.cpkt\")\n", "reg = MotherNetAdditiveRegressor(device='cpu', path=model_path)" ] }, @@ -98,7 +99,9 @@ } ], "source": [ - "plt.plot(y_test, ss.inverse_transform(y_pred), 'o')" + "import matplotlib.pyplot as plt\n", + "\n", + "plt.plot(y_test, ss.inverse_transform(y_pred.reshape(-1, 1)), 'o')" ] }, {