diff --git a/diabetes_regression/training/train.py b/diabetes_regression/training/train.py index 22258042..e13c23db 100644 --- a/diabetes_regression/training/train.py +++ b/diabetes_regression/training/train.py @@ -53,7 +53,7 @@ def train_model(data, ridge_args): # Evaluate the metrics for the model def get_model_metrics(model, data): preds = model.predict(data["test"]["X"]) - mse = mean_squared_error(preds, data["test"]["y"]) + mse = mean_squared_error(data["test"]["y"],preds) metrics = {"mse": mse} return metrics