Skip to content

Commit ae1c5d7

Browse files
Fix #8
1 parent 8004180 commit ae1c5d7

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

extra_keras_datasets/iris.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,10 @@ def load_data(path="iris.npz", test_split=0.2):
8080
testing_data = samples[:num_test_samples]
8181

8282
# Split into inputs and targets
83-
input_train = [i[0:4] for i in training_data]
84-
input_test = [i[0:4] for i in testing_data]
85-
target_train = [i[4] for i in training_data]
86-
target_test = [i[4] for i in testing_data]
83+
input_train = np.array([i[0:4] for i in training_data])
84+
input_test = np.array([i[0:4] for i in testing_data])
85+
target_train = np.array([i[4] for i in training_data])
86+
target_test = np.array([i[4] for i in testing_data])
8787

8888
# Warn about citation
8989
warn_citation()

0 commit comments

Comments
 (0)