|
1 |
| -"""Test the python functions from ./src/nn_iris.""" |
| 1 | +"""Test the python functions from ./src/nn_iris_solution.""" |
2 | 2 |
|
3 | 3 | import sys
|
| 4 | + |
4 | 5 | import numpy as np
|
5 |
| -from sklearn.model_selection import train_test_split |
6 | 6 | from sklearn.datasets import make_classification
|
7 |
| -from sklearn.model_selection import GridSearchCV |
| 7 | +from sklearn.model_selection import GridSearchCV, train_test_split |
8 | 8 |
|
9 | 9 | sys.path.insert(0, "./src/")
|
10 | 10 |
|
11 |
| -from src.nn_iris import compute_accuracy, cv_knearest_classifier |
| 11 | +from src.nn_iris_solution import compute_accuracy, cv_knearest_classifier |
12 | 12 |
|
13 | 13 |
|
14 | 14 | def test_compute_accuracy():
|
| 15 | + """Test the compute_accuracy function from the iris solution.""" |
15 | 16 | y = np.array([0, 1, 0, 1, 0])
|
16 | 17 | y_pred = np.array([0, 1, 1, 1, 0])
|
17 | 18 | acc = compute_accuracy(y, y_pred)
|
18 | 19 | assert np.allclose(acc, 0.8)
|
19 | 20 |
|
| 21 | + |
20 | 22 | def test_cv_knearest_classifier():
|
| 23 | + """Test cv_knearest_classifier.""" |
21 | 24 | # Create a dummy dataset
|
22 |
| - X, y = make_classification(n_samples=100, n_features=20, random_state=42) |
23 |
| - xtrain, xtest, ytrain, ytest = train_test_split(X, y, train_size=.75, random_state=29) |
| 25 | + in_x, y = make_classification(n_samples=100, n_features=20, random_state=42) |
| 26 | + xtrain, xtest, ytrain, ytest = train_test_split( |
| 27 | + in_x, y, train_size=0.75, random_state=29 |
| 28 | + ) |
24 | 29 |
|
25 | 30 | # Call the function
|
26 |
| - knn_cv = cv_knearest_classifier(ytrain, xtrain) |
| 31 | + knn_cv = cv_knearest_classifier(xtrain, ytrain) |
27 | 32 |
|
28 | 33 | # Check if the returned object is of the expected type
|
29 | 34 | assert isinstance(knn_cv, GridSearchCV)
|
30 | 35 |
|
31 | 36 | # Get the best score
|
32 | 37 | best_score = knn_cv.best_score_
|
33 | 38 | # Get the best score
|
34 |
| - best_params = knn_cv.best_params_['n_neighbors'] |
| 39 | + best_params = knn_cv.best_params_["n_neighbors"] |
35 | 40 |
|
36 | 41 | # Perform assertions on the best results
|
37 | 42 | assert np.allclose(best_score, 0.893333333)
|
38 | 43 | assert np.allclose(best_params, 4)
|
39 |
| - |
|
0 commit comments