
8. Create a machine learning model with the Iris dataset

Katie House edited this page Sep 15, 2020 · 3 revisions

In this tutorial, we will build a simple machine learning model and deploy it in our Django web app.

We will use the Iris dataset and a decision tree to build a multiclass classifier.
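To see why this is a multiclass problem, here is a quick look at the copy of the Iris dataset that ships with scikit-learn (a small sketch; it needs scikit-learn, which is installed in the next step). The commented values are those of the standard Iris dataset.

```python
import numpy as np
from sklearn import datasets

iris = datasets.load_iris()

# 150 flowers, each described by 4 measurements in centimetres
print(iris.data.shape)             # (150, 4)
print(iris.feature_names)
# ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']

# Three species encoded as 0, 1, 2, so this is a 3-class (multiclass) problem
print(iris.target_names.tolist())         # ['setosa', 'versicolor', 'virginica']
print(np.bincount(iris.target).tolist())  # [50, 50, 50]
```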

To create the model, first install scikit-learn:

```bash
pip3 install -U scikit-learn
```

Create a new folder named ml_model in the project root directory and add the following script to it as ml_model/iris_model.py. Note: this is not a machine learning tutorial, and the model is deliberately basic. To make it more robust, you would add steps such as a train/test split, cross-validation, and hyperparameter tuning.
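As an optional aside, here is a minimal sketch of what the improvements in the note could look like. It is separate from the tutorial script that follows and is not needed for the later steps. It uses scikit-learn's `train_test_split`, `cross_val_score`, and `GridSearchCV`; the 80/20 split, the 5 folds, and the `max_depth` values searched are arbitrary choices made for illustration.

```python
from sklearn import datasets
from sklearn.model_selection import GridSearchCV, cross_val_score, train_test_split
from sklearn.tree import DecisionTreeClassifier

iris = datasets.load_iris()
X, y = iris.data, iris.target

# Hold out 20% of the flowers for a final test; stratify keeps the
# three species in equal proportion in both parts
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

# 5-fold cross-validation on the training portion only
cv_scores = cross_val_score(DecisionTreeClassifier(random_state=0), X_train, y_train, cv=5)
print(f'CV accuracy: {cv_scores.mean():.3f} +/- {cv_scores.std():.3f}')

# Simple hyperparameter tuning: choose max_depth by cross-validated accuracy
search = GridSearchCV(
    DecisionTreeClassifier(random_state=0),
    param_grid={'max_depth': [2, 3, 4, None]},
    cv=5,
)
search.fit(X_train, y_train)
print('Best params:', search.best_params_)

# Evaluate the tuned model (refit on all training data) once on the held-out set
test_accuracy = search.score(X_test, y_test)
print(f'Test accuracy: {test_accuracy:.3f}')
```

The test set is touched only once, at the end, so its accuracy is not inflated by the tuning step.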

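```python
import pickle

from sklearn import datasets
from sklearn.tree import DecisionTreeClassifier

# Load the Iris dataset: 150 samples, 4 features, 3 classes
iris = datasets.load_iris()
X = iris.data    # feature matrix, shape (150, 4)
y = iris.target  # class labels 0, 1, 2

# Train a Decision Tree Classifier on the full dataset
clf = DecisionTreeClassifier(random_state=0)
clf.fit(X, y)

# Save the trained model as a pkl file.
# The path is relative, so run this script from the project root.
filename = 'ml_model/iris_model.pkl'
with open(filename, 'wb') as f:  # closes the file after writing
    pickle.dump(clf, f)
```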

Run the script from the project root directory (the script saves to a relative path):

```bash
python3 ml_model/iris_model.py
```

This should create ml_model/iris_model.pkl, a serialized (pickled) copy of the trained classifier.
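Before the web app can make predictions, it has to load this file again. A minimal sketch of that step, assuming the pickle produced above and measurements given in the Iris feature order (sepal length, sepal width, petal length, petal width, in cm). The helper names `load_model`, `predict_species`, and `IRIS_CLASSES` are hypothetical, not part of this tutorial.

```python
import pickle

# Written by ml_model/iris_model.py; relative to the project root
MODEL_PATH = 'ml_model/iris_model.pkl'

# Class labels in the same order as iris.target_names (targets 0, 1, 2)
IRIS_CLASSES = ('setosa', 'versicolor', 'virginica')


def load_model(path=MODEL_PATH):
    """Unpickle the trained classifier. Only unpickle files you trust."""
    with open(path, 'rb') as f:
        return pickle.load(f)


def predict_species(model, measurements):
    """Predict the species of one flower.

    measurements: [sepal length, sepal width, petal length, petal width] in cm
    """
    # predict() expects a 2-D input: one row per sample
    class_index = int(model.predict([measurements])[0])
    return IRIS_CLASSES[class_index]
```

For example, run from the project root, `predict_species(load_model(), [5.1, 3.5, 1.4, 0.2])` should return `'setosa'`, since those are the measurements of the first flower in the dataset. Keep in mind that `pickle.load` can execute arbitrary code from a malicious file, and a model pickled with one scikit-learn version may not load cleanly under another.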