"cells": [
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "%%HTML\n<style type="text/css">\ntable.dataframe td, table.dataframe th {\n border: 1px black solid !important;\n color: black !important;\n}\n</style>\n",
"execution_count": 37,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<IPython.core.display.HTML object>",
"text/html": "<style type="text/css">\ntable.dataframe td, table.dataframe th {\n border: 1px black solid !important;\n color: black !important;\n}\n</style>\n"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "# Predicting heart disease using machine learning\n\nThis notebook will introduce some foundational machine learning and data science concepts by exploring the problem of heart disease classification.\n\nIt is intended to be an end-to-end example of what a data science and machine learning proof of concept might look like.\n\n## What is classification?\n\nClassification involves deciding whether a sample is part of one class or another. If there are only two class options, it's referred to as binary classification; if there are more than two, it's referred to as multi-class classification.\n\n## What we'll end up with\n\nSince we already have a dataset, we'll approach the problem with the following machine learning modelling framework.\n\n6 Step Machine Learning Modelling Framework\n\nMore specifically, we'll look at the following topics.\n\n* Exploratory data analysis (EDA) - the process of going through a dataset and finding out more about it.\n* Model training - create model(s) to learn to predict a target variable based on other variables.\n* Model evaluation - evaluating a model's predictions using problem-specific evaluation metrics.\n* Model comparison - comparing several different models to find the best one.\n* Model fine-tuning - once we've found a good model, how can we improve it?\n* Feature importance - since we're predicting the presence of heart disease, are there some things which are more important for prediction?\n* Cross-validation - if we do build a good model, can we be sure it will work on unseen data?\n* Reporting what we've found - if we had to present our work, what would we show someone?\n\nTo work through these topics, we'll use pandas, Matplotlib and NumPy for data analysis, as well as Scikit-Learn for machine learning and modelling tasks.\n\nTools which can be used for each step of the machine learning modelling process.\n\nWe'll work through each step and by the end of the notebook, we'll have a handful of models, all of which can predict whether or not a person has heart disease based on a number of different parameters with considerable accuracy.\n\nYou'll also be able to describe which parameters are more indicative than others; for example, sex may be more important than age.\n"
},
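{
"metadata": {},
"cell_type": "markdown",
"source": "To make the binary vs. multi-class distinction concrete, here is a minimal sketch that counts the distinct labels in a target column. The helper name classification_kind and the toy label lists are hypothetical and only serve as illustration. The heart disease target has labels 0 and 1, so it falls in the binary case.

```python
import pandas as pd


def classification_kind(labels):
    # Hypothetical helper: 'binary' for exactly two distinct labels, 'multi-class' for more
    n_classes = pd.Series(labels).nunique()
    if n_classes < 2:
        raise ValueError('A classification target needs at least two classes')
    return 'binary' if n_classes == 2 else 'multi-class'


# Toy targets: heart disease (0 = no, 1 = yes) vs. a made-up three-class label
print(classification_kind([1, 0, 1, 1, 0]))
print(classification_kind(['cat', 'dog', 'bird']))
```

The number of distinct target values is what separates the two cases. The choice of model and evaluation metric often follows from it."
},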
{
"metadata": {},
"cell_type": "markdown",
"source": "## 1. Problem Definition\n\nIn our case, the problem we will be exploring is binary classification (a sample can only be one of two things).\n\nThis is because we're going to be using a number of different features (pieces of information) about a person to predict whether they have heart disease or not.\n\nIn a statement,\n> Given clinical parameters about a patient, can we predict whether or not they have heart disease?\n\n## 2. Data\n\nWhat you'll want to do here is dive into the data your problem definition is based on. This may involve sourcing the data, defining different parameters, talking to experts about it and finding out what you should expect.\n\nThe original data came from the Cleveland database from the UCI Machine Learning Repository.\n\nThere is also a version of it available on Kaggle.\n\nThe original database contains 76 attributes, but here only 14 attributes will be used. Attributes (also called features) are the variables that we'll use to predict our target variable.\n\nAttributes and features are also referred to as independent variables and a target variable can be referred to as a dependent variable.\n\n> We use the independent variables to predict our dependent variable.\n\nOr in our case, the independent variables are a patient's different medical attributes and the dependent variable is whether or not they have heart disease.\n\n## 3. Evaluation\n\nThe evaluation metric is something you might define at the start of a project.\n\nSince machine learning is very experimental, you might say something like,\n\n> If we can reach 95% accuracy at predicting whether or not a patient has heart disease during the proof of concept, we'll pursue this project.\n\nThe reason this is helpful is it provides a rough goal for a machine learning engineer or data scientist to work towards.\n\nHowever, due to the nature of experimentation, the evaluation metric may change over time.\n\n## 4. Features\n\nFeatures are different parts of the data. During this step, you'll want to start finding out what you can about the data.\n\nOne of the most common ways to do this is to create a data dictionary.\n\n### Heart Disease Data Dictionary\n\nA data dictionary describes the data you're dealing with. Not all datasets come with them so this is where you may have to do your research or ask a subject matter expert (someone who knows about the data) for more.\n\nThe following are the features we'll use to predict our target variable (heart disease or no heart disease).\n\n1. age - age in years\n2. sex - 1 = male; 0 = female\n3. cp - chest pain type\n * 0: Typical angina: chest pain related to decreased blood supply to the heart\n * 1: Atypical angina: chest pain not related to heart\n * 2: Non-anginal pain: typically esophageal spasms (non heart related)\n * 3: Asymptomatic: chest pain not showing signs of disease\n4. trestbps - resting blood pressure (in mm Hg on admission to the hospital)\n5. chol - serum cholesterol in mg/dl\n * serum = LDL + HDL + .2 * triglycerides\n * above 200 is cause for concern\n6. fbs - fasting blood sugar > 120 mg/dl (1 = true; 0 = false)\n * '>126' mg/dL signals diabetes\n7. restecg - resting electrocardiographic results\n * 0: Nothing to note\n * 1: ST-T Wave abnormality\n * can range from mild symptoms to severe problems\n * signals non-normal heart beat\n * 2: Possible or definite left ventricular hypertrophy\n * Enlarged heart's main pumping chamber\n8. thalach - maximum heart rate achieved\n9. exang - exercise induced angina (1 = yes; 0 = no)\n10. oldpeak - ST depression induced by exercise relative to rest\n * looks at stress of heart during exercise\n * unhealthy heart will stress more\n11. slope - the slope of the peak exercise ST segment\n * 0: Upsloping: better heart rate with exercise (uncommon)\n * 1: Flatsloping: minimal change (typical healthy heart)\n * 2: Downsloping: signs of unhealthy heart\n12. ca - number of major vessels (0-3) colored by fluoroscopy\n * colored vessel means the doctor can see the blood passing through\n * the more blood movement the better (no clots)\n13. thal - thallium stress result\n * 1,3 = normal\n * 6 = fixed defect: used to be defect but ok now\n * 7 = reversible defect: no proper blood movement when exercising\n14. target - have disease or not (1 = yes, 0 = no) (= the predicted attribute)\n\nNote: No personally identifiable information (PII) can be found in the dataset.\n\nIt's a good idea to save these to a Python dictionary or in an external file, so we can look at them later without coming back here."
},
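{
"metadata": {},
"cell_type": "markdown",
"source": "One way to keep the data dictionary close at hand is a plain Python dict, saved to a JSON file. This is a minimal sketch. It covers only a few of the 14 columns, with descriptions condensed from the list above, and the file name heart_disease_dictionary.json is an assumption. The file is written to a temporary directory so the example runs anywhere.

```python
import json
import os
import tempfile

# Partial data dictionary (a few of the 14 columns); extend it with the remaining columns
heart_disease_dict = {
    'age': 'age in years',
    'sex': '1 = male; 0 = female',
    'cp': 'chest pain type (0-3)',
    'chol': 'serum cholesterol in mg/dl',
    'thalach': 'maximum heart rate achieved',
    'target': 'have disease or not (1 = yes, 0 = no)',
}

# Save to an external JSON file (hypothetical file name, temporary directory)
path = os.path.join(tempfile.mkdtemp(), 'heart_disease_dictionary.json')
with open(path, 'w') as f:
    json.dump(heart_disease_dict, f, indent=2)

# Later: load it back and look up what a column means
with open(path) as f:
    loaded = json.load(f)
print(loaded['thalach'])
```

JSON keeps the dictionary readable outside Python too, for example for a subject matter expert reviewing the descriptions."
},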
{
"metadata": {},
"cell_type": "markdown",
"source": "## Preparing the tools\n\nAt the start of any project, it's customary to see the required libraries imported in a big chunk like you can see below.\n\nHowever, in practice, your projects may import libraries as you go. After you've spent a couple of hours working on your problem, you'll probably want to do some tidying up. This is where you may want to consolidate every library you've used at the top of your notebook (like the cell below).\n\nThe libraries you use will differ from project to project. But there are a few which you'll likely take advantage of during almost every structured data project.\n\n* pandas for data analysis.\n* NumPy for numerical operations.\n* Matplotlib/seaborn for plotting or data visualization.\n* Scikit-Learn for machine learning modelling and evaluation."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "# Import all the tools we need\n\n# Regular EDA (exploratory data analysis) and plotting libraries\nimport numpy as np\nimport pandas as pd\nimport matplotlib.pyplot as plt\nimport seaborn as sns\n\n# we want our plots to appear inside the notebook\n%matplotlib inline\n\n# Models from Scikit-Learn\nfrom sklearn.linear_model import LogisticRegression\nfrom sklearn.neighbors import KNeighborsClassifier\nfrom sklearn.ensemble import RandomForestClassifier\n\n# Model Evaluations\nfrom sklearn.model_selection import train_test_split, cross_val_score\nfrom sklearn.model_selection import RandomizedSearchCV, GridSearchCV\nfrom sklearn.metrics import confusion_matrix, classification_report\nfrom sklearn.metrics import precision_score, recall_score, f1_score\n# Note: plot_roc_curve was removed in scikit-learn 1.2; on newer versions use RocCurveDisplay.from_estimator instead\nfrom sklearn.metrics import plot_roc_curve",
"execution_count": 38,
"outputs": []
},
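{
"metadata": {},
"cell_type": "markdown",
"source": "The APIs of the libraries imported above can change between releases, so it can help to record which versions a notebook was run with. Here is a minimal sketch using importlib.metadata (Python 3.8+). The package list is an assumption that matches the imports above, using PyPI distribution names (scikit-learn is imported as sklearn).

```python
from importlib.metadata import PackageNotFoundError, version

# PyPI distribution names for the libraries used in this notebook
packages = ['numpy', 'pandas', 'matplotlib', 'seaborn', 'scikit-learn']

versions = {}
for name in packages:
    try:
        versions[name] = version(name)
    except PackageNotFoundError:
        versions[name] = 'not installed'

for name, ver in versions.items():
    print(f'{name}: {ver}')
```

Printing this at the top of a notebook makes it easier to work out later why an import or function call stopped working."
},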
{
"metadata": {},
"cell_type": "markdown",
"source": "## Load Data\n\nThere are many different ways to store data. A typical way of storing tabular data (data similar to what you'd see in an Excel file) is in .csv format. .csv stands for comma-separated values.\n\nPandas has a built-in function to read .csv files called read_csv() which takes the file pathname of your .csv file. You'll likely use this a lot."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "df = pd.read_csv("heart-disease.csv")\ndf.shape",
"execution_count": 39,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 39,
"data": {
"text/plain": "(303, 14)"
},
"metadata": {}
}
]
},
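{
"metadata": {},
"cell_type": "markdown",
"source": "read_csv() accepts a file-like object as well as a path, so you can try it on a small inline sample without creating a file. This is a minimal sketch. The three rows are the first rows of this dataset, trimmed to four of the 14 columns, and serve only as illustration.

```python
from io import StringIO

import pandas as pd

# A tiny CSV sample: first three rows, four of the 14 columns
csv_text = '''age,sex,cp,target
63,1,3,1
37,1,2,1
41,0,1,1
'''

sample = pd.read_csv(StringIO(csv_text))
print(sample.shape)  # (3, 4)
print(sample.dtypes)
```

As with the full file, shape returns (rows, columns), and read_csv() infers an integer dtype for each of these numeric columns."
},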
{
"metadata": {},
"cell_type": "markdown",
"source": "## Data Exploration (Exploratory data analysis or EDA)\n\nOnce you've imported a dataset, the next step is to explore. There's no set way of doing this. But what you should be trying to do is become more and more familiar with the dataset.\n\nCompare different columns to each other, compare them to the target variable. Refer back to your data dictionary and remind yourself of what different columns mean.\n\nYour goal is to become a subject matter expert on the dataset you're working with. So if someone asks you a question about it, you can give them an explanation, and when you start building models, you can sanity check them to make sure they're not performing too well (overfitting) or too poorly (underfitting).\n\nSince EDA has no real set methodology, the following is a short checklist you might want to walk through.\n\n1. What question(s) are you trying to solve (or prove wrong)?\n2. What kind of data do you have and how do you treat different types?\n3. What's missing from the data and how do you deal with it?\n4. Where are the outliers and why should you care about them?\n5. How can you add, change or remove features to get more out of your data?\n\nOne of the quickest and easiest ways to check your data is with the head() function. Calling it on any dataframe will print the top 5 rows; tail() prints the bottom 5. You can also pass a number to them like head(10) to show the top 10 rows."
},
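{
"metadata": {},
"cell_type": "markdown",
"source": "Here is a minimal sketch of head(), tail() and head(n) on a hypothetical 12-row toy frame (not the heart disease data). It shows which rows each call returns.

```python
import pandas as pd

# Hypothetical toy data: 12 rows
toy = pd.DataFrame({'age': range(40, 52), 'target': [0, 1] * 6})

print(toy.head())    # first 5 rows (ages 40-44)
print(toy.tail())    # last 5 rows (ages 47-51)
print(toy.head(10))  # first 10 rows
```

Both methods keep the original index labels, so tail() on this frame shows rows 7 to 11 rather than renumbering them."
},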
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "df.head()",
"execution_count": 40,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 40,
"data": {
"text/plain": " age sex cp trestbps chol fbs restecg thalach exang oldpeak slope \\n0 63 1 3 145 233 1 0 150 0 2.3 0 \n1 37 1 2 130 250 0 1 187 0 3.5 0 \n2 41 0 1 130 204 0 0 172 0 1.4 2 \n3 56 1 1 120 236 0 1 178 0 0.8 2 \n4 57 0 0 120 354 0 1 163 1 0.6 2 \n\n ca thal target \n0 0 1 1 \n1 0 2 1 \n2 0 2 1 \n3 0 2 1 \n4 0 2 1 ",
"text/html": "
value_counts() allows you to show how many times each value of a categorical column appears."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "# Let's see how many positive (1) and negative (0) samples we have in our dataframe\ndf.target.value_counts()",
"execution_count": 42,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 42,
"data": {
"text/plain": "1 165\n0 138\nName: target, dtype: int64"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Since these two values are close to even, our target column can be considered balanced. An unbalanced target column, meaning some classes have far more samples than others, can be harder to model than a balanced one, where all of your target classes have roughly the same number of samples.\n\nIf you'd prefer these values as proportions, value_counts() takes a parameter, normalize, which can be set to True."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "# Normalized value counts\ndf.target.value_counts(normalize=True)",
"execution_count": 43,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 43,
"data": {
"text/plain": "1 0.544554\n0 0.455446\nName: target, dtype: float64"
},
"metadata": {}
}
]
},
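{
"metadata": {},
"cell_type": "markdown",
"source": "The idea of balance can be turned into a quick check on the normalized counts. This is a minimal sketch. The helper looks_balanced and its 0.35 threshold for the smallest class share are hypothetical rules of thumb, not a standard. The toy target rebuilds the 165/138 split shown above.

```python
import pandas as pd

# Rebuild a target column with the same counts as above: 165 ones and 138 zeros
target = pd.Series([1] * 165 + [0] * 138, name='target')

# normalize=True turns counts into proportions that sum to 1
proportions = target.value_counts(normalize=True)
print(proportions)


def looks_balanced(series, min_share=0.35):
    # Hypothetical rule of thumb: the rarest class should make up at least min_share of the rows
    return series.value_counts(normalize=True).min() >= min_share


print(looks_balanced(target))  # True
```

With a 90/10 split the same check would return False. In that case, metrics such as precision and recall are usually more telling than accuracy."
},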
{
"metadata": {},
"cell_type": "markdown",
"source": "We can plot the target column value counts by calling the plot() function and telling it what kind of plot we'd like; in this case, a bar graph is good."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "# Plot the value_counts with a bar graph\ndf.target.value_counts().plot(kind="bar", color=["salmon", "lightblue"]);",
"execution_count": 44,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD1CAYAAACrz7WZAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAOMElEQVR4nO3dbYxmZ13H8e/PLq0CMS3stJZ9cBfdgoVgaIZSJRqkCq0StjGQbIO6wSYTtSAIhrbyovqiCfgASFSSla5dkqalqdVuCKJ1bW2MtmXKQ+l2Kd200A67stMU8IGksPD3xZzqeHPPzsx97nuGvfb7eXPf539d55z/i9nfnlxzzpxUFZKktvzAejcgSRo/w12SGmS4S1KDDHdJapDhLkkNMtwlqUEb1rsBgI0bN9a2bdvWuw1JOqncf//9T1bV1LCx74tw37ZtG7Ozs+vdhiSdVJJ8eakxl2UkqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDfq+eIjpZPHtP3jXerfQlGdd+yfr3YLULK/cJalBhrskNWjZcE+yN8mxJA8O1N+W5OEkB5P84aL6NUkOd2Ovm0TTkqQTW8ma+w3AnwEffaaQ5OeAncDLqurpJGd39fOBXcBLgBcA/5jkvKr6zrgblyQtbdkr96q6G3hqoPybwHur6uluzrGuvhO4uaqerqrHgMPAhWPsV5K0AqOuuZ8H/EySe5P8c5JXdPVNwBOL5s11NUnSGhr1VsgNwFnARcArgFuSvBDIkLk17ABJZoAZgK1bt47YhiRpmFGv3OeA22rBfcB3gY1dfcuieZuBI8MOUFV7qmq6qqanpoa+SESSNKJRw/1vgdcAJDkPOB14EtgP7EpyRpLtwA7gvnE0KklauWWXZZLcBLwa2JhkDrgW2Avs7W6P/Bawu6oKOJjkFuAh4DhwpXfKSNLaWzbcq+ryJYZ+ZYn51wHX9WlKktSPT6hKUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhq0bLgn2ZvkWPfWpcGx301SSTZ220nyoSSHkzyQ5IJJNC1JOrGVXLnfAFwyWEyyBfgF4PFF5UtZeG/qDmAG+HD/FiVJq7VsuFfV3cBTQ4Y+ALwbqEW1ncBHa8E9wJlJzh1Lp5KkFRtpzT3JG4CvVNXnBoY2AU8s2p7rapKkNbTsC7IHJXk28B7gtcOGh9RqSI0kMyws3bB169bVtiFJOoFRrtx/DNgOfC7Jl4DNwKeT/AgLV+pbFs3dDBwZdpCq2lNV01U1PTU1NUIbkqSlrPrKvao+D5z9zHYX8NNV9WSS/cBbk9wMvBL4RlUdHVezkoa77WH/mY3TL7/o5P9V4UpuhbwJ+DfgRUnmklxxgumfAB4FDgN/CfzWWLqUJK3KslfuVXX5MuPbFn0v4Mr+bUmS+vAJVUlqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSg1byJqa9SY4leXBR7Y+SfCHJA0n+JsmZi8auSXI4ycNJXjepxiVJS1vJlfsNwCUDtTuAl1bVy4AvAtcAJDkf2AW8pNvnL5KcNrZuJUkrsmy4V9XdwFMDtX+oquPd5j3A5u77TuDmqnq6qh5j4V2qF46xX0nSCoxjzf3Xgb/rvm8Cnlg0NtfVJElrqFe4J3kPcBy48ZnSkGm1xL4zSWaTzM7Pz/dpQ5I0YORwT7IbeD3w5qp6JsDngC2Lpm0Gjgzbv6r2VNV0VU1PTU2N2oYkaYiRwj3JJcBVwBuq6puLhvYDu5KckWQ7sAO4r3+bkqTV2LDchCQ3Aa8GNiaZA65l4e6YM4A7kgDcU1W/UVUHk9wCPMTCcs2VVfWdSTUvSRpu2X
CvqsuHlK8/wfzrgOv6NCVJ6scnVCWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDVo23JPsTXIsyYOLas9LckeSR7rPs7p6knwoyeEkDyS5YJLNS5KGW8mV+w3AJQO1q4EDVbUDONBtA1zKwntTdwAzwIfH06YkaTWWDfequht4aqC8E9jXfd8HXLao/tFacA9wZpJzx9WsJGllRl1zP6eqjgJ0n2d39U3AE4vmzXU1SdIaGvcvVDOkVkMnJjNJZpPMzs/Pj7kNSTq1jRruX31muaX7PNbV54Ati+ZtBo4MO0BV7amq6aqanpqaGrENSdIwo4b7fmB39303cPui+q91d81cBHzjmeUbSdLa2bDchCQ3Aa8GNiaZA64F3gvckuQK4HHgTd30TwC/CBwGvgm8ZQI9S5KWsWy4V9XlSwxdPGRuAVf2bUqS1I9PqEpSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGtQr3JP8TpKDSR5MclOSH0yyPcm9SR5J8rEkp4+rWUnSyowc7kk2Ab8NTFfVS4HTgF3A+4APVNUO4GvAFeNoVJK0cn2XZTYAP5RkA/Bs4CjwGuDWbnwfcFnPc0iSVmnkcK+qrwB/zMILso8C3wDuB75eVce7aXPApr5NSpJWp8+yzFnATmA78ALgOcClQ6bWEvvPJJlNMjs/Pz9qG5KkIfosy/w88FhVzVfVt4HbgJ8GzuyWaQA2A0eG7VxVe6pquqqmp6amerQhSRrUJ9wfBy5K8uwkAS4GHgLuBN7YzdkN3N6vRUnSavVZc7+XhV+cfhr4fHesPcBVwDuTHAaeD1w/hj4lSauwYfkpS6uqa4FrB8qPAhf2Oa4kqR+fUJWkBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNahXuCc5M8mtSb6Q5FCSn0ryvCR3JHmk+zxrXM1Kklam75X7nwKfrKoXAz8JHAKuBg5U1Q7gQLctSVpDI4d7kh8GfpbuHalV9a2q+jqwE9jXTdsHXNa3SUnS6vS5cn8hMA/8VZLPJPlIkucA51TVUYDu8+wx9ClJWoU+4b4BuAD4cFW9HPhvVrEEk2QmyWyS2fn5+R5tSJIG9Qn3OWCuqu7ttm9lIey/muRcgO7z2LCdq2pPVU1X1fTU1FSPNiRJg0YO96r6d+CJJC/qShcDDwH7gd1dbTdwe68OJUmrtqHn/m8DbkxyOvAo8BYW/sO4JckVwOPAm3qeQ5K0Sr3Cvao+C0wPGbq4z3ElSf34hKokNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUG9wz3JaUk+k+Tj3fb2JPcmeSTJx7q3NEmS1tA4rtzfDhxatP0+4ANVtQP4GnDFGM4hSVqFXuGeZDPwS8BHuu0ArwFu7absAy7rcw5J0ur1vXL/IPBu4Lvd9vOBr1fV8W57DtjU8xySpFUaOdyTvB44VlX3Ly4PmVpL7D+TZDbJ7Pz8/KhtSJKG6HPl/irgDUm+BNzMwnLMB4Ezk2zo5mwGjgzbuar2VNV0VU1PTU31aEOSNGjkcK+qa6pqc1VtA3YB/1RVbwbuBN7YTdsN3N67S0nSqkziPvergHcmOczCGvz1EziHJOkENiw/ZXlVdRdwV/f9UeDCcRxXkjQan1CVpAYZ7pLUIMNdkhpkuEtSgwx3SWqQ4S5JDTLcJalBhrskNchwl6QGGe6S1CDDXZIaZLhLUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBo0c7km2JLkzyaEkB5O8vas/L8
kdSR7pPs8aX7uSpJXoc+V+HHhXVf0EcBFwZZLzgauBA1W1AzjQbUuS1tDI4V5VR6vq0933/wQOAZuAncC+bto+4LK+TUqSVmcsa+5JtgEvB+4Fzqmqo7DwHwBw9hL7zCSZTTI7Pz8/jjYkSZ3e4Z7kucBfA++oqv9Y6X5Vtaeqpqtqempqqm8bkqRFeoV7kmexEOw3VtVtXfmrSc7txs8FjvVrUZK0Wn3ulglwPXCoqt6/aGg/sLv7vhu4ffT2JEmj2NBj31cBvwp8Pslnu9rvAe8FbklyBfA48KZ+LUqSVmvkcK+qfwGyxPDFox5XktSfT6hKUoMMd0lqkOEuSQ0y3CWpQYa7JDXIcJekBhnuktQgw12SGmS4S1KDDHdJapDhLkkNMtwlqUGGuyQ1yHCXpAYZ7pLUIMNdkho0sXBPckmSh5McTnL1pM4jSfpeEwn3JKcBfw5cCpwPXJ7k/EmcS5L0vSZ15X4hcLiqHq2qbwE3AzsndC5J0oA+L8g+kU3AE4u254BXLp6QZAaY6Tb/K8nDE+rlVLQReHK9m1jW779/vTvQ2js5fjZPHj+61MCkwn3Yi7Pr/21U7QH2TOj8p7Qks1U1vd59SIP82Vw7k1qWmQO2LNreDByZ0LkkSQMmFe6fAnYk2Z7kdGAXsH9C55IkDZjIskxVHU/yVuDvgdOAvVV1cBLn0lAud+n7lT+bayRVtfwsSdJJxSdUJalBhrskNchwl6QGTeo+d0kiyYtZeDp9EwvPuhwB9lfVoXVt7BTglXvDkrxlvXvQqSvJVSz86ZEA97Fwi3SAm/xjgpPn3TINS/J4VW1d7z50akryReAlVfXtgfrpwMGq2rE+nZ0aXJY5ySV5YKkh4Jy17EUa8F3gBcCXB+rndmOaIMP95HcO8DrgawP1AP+69u1I/+sdwIEkj/B/f0hwK/DjwFvXratThOF+8vs48Nyq+uzgQJK71r4daUFVfTLJeSz8CfBNLFxwzAGfqqrvrGtzpwDX3CWpQd4tI0kNMtwlqUGGuyQ1yHCXpAYZ7pLUoP8B4304fv57Ts0AAAAASUVORK5CYII=\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
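{
"metadata": {},
"cell_type": "markdown",
"source": "plot(kind='bar') is a thin wrapper around Matplotlib. It returns the Axes it drew on, so you can inspect or adjust the chart afterwards. This is a minimal sketch using the counts shown above. It assumes the non-interactive Agg backend so it also runs outside a notebook (inside Jupyter that line is not needed).

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend; not needed inside Jupyter
import matplotlib.pyplot as plt
import pandas as pd

# Same counts as the target column above
counts = pd.Series({1: 165, 0: 138}, name='target')

# plot() returns a matplotlib Axes object
ax = counts.plot(kind='bar', color=['salmon', 'lightblue'])
ax.set_xlabel('target (1 = disease, 0 = no disease)')
ax.set_ylabel('count')

# Each bar is drawn as one rectangle (patch) on the Axes
bar_heights = [patch.get_height() for patch in ax.patches]
print(bar_heights)
plt.close(ax.figure)
```

Because you have the Axes object, you can label the chart without relying on the current-figure state."
},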
{
"metadata": {},
"cell_type": "markdown",
"source": "df.info() gives quick insight into the number of missing values you have and what type of data you're working with.\n\nIn our case, there are no missing values and all our columns are numerical in nature."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "df.info()",
"execution_count": 45,
"outputs": [
{
"output_type": "stream",
"text": "<class 'pandas.core.frame.DataFrame'>\nRangeIndex: 303 entries, 0 to 302\nData columns (total 14 columns):\n # Column Non-Null Count Dtype \n--- ------ -------------- ----- \n 0 age 303 non-null int64 \n 1 sex 303 non-null int64 \n 2 cp 303 non-null int64 \n 3 trestbps 303 non-null int64 \n 4 chol 303 non-null int64 \n 5 fbs 303 non-null int64 \n 6 restecg 303 non-null int64 \n 7 thalach 303 non-null int64 \n 8 exang 303 non-null int64 \n 9 oldpeak 303 non-null float64\n 10 slope 303 non-null int64 \n 11 ca 303 non-null int64 \n 12 thal 303 non-null int64 \n 13 target 303 non-null int64 \ndtypes: float64(1), int64(13)\nmemory usage: 33.3 KB\n",
"name": "stdout"
}
]
},
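{
"metadata": {},
"cell_type": "markdown",
"source": "info() reports the non-null count per column. If you want the number of missing values directly, isna().sum() gives it per column. This is a minimal sketch on a hypothetical toy frame with two gaps. On the heart disease data every count would be 0.

```python
import numpy as np
import pandas as pd

# Hypothetical toy data with two missing values
toy = pd.DataFrame({
    'age': [63, 37, np.nan, 56],
    'chol': [233, np.nan, 204, 236],
    'target': [1, 1, 1, 1],
})

missing = toy.isna().sum()  # missing values per column
print(missing)
print(missing.sum() == 0)   # False: this toy frame has gaps
```

Note that a column with a NaN becomes float64, which is another hint from info() that values may be missing."
},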
{
"metadata": {},
"cell_type": "markdown",
"source": "Another way to get some quick insights on your dataframe is to use df.describe().\ndescribe() shows a range of different metrics about your numerical columns such as mean, max and standard deviation."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "df.describe()",
"execution_count": 46,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 46,
"data": {
"text/plain": " age sex cp trestbps chol fbs \\ncount 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 \nmean 54.366337 0.683168 0.966997 131.623762 246.264026 0.148515 \nstd 9.082101 0.466011 1.032052 17.538143 51.830751 0.356198 \nmin 29.000000 0.000000 0.000000 94.000000 126.000000 0.000000 \n25% 47.500000 0.000000 0.000000 120.000000 211.000000 0.000000 \n50% 55.000000 1.000000 1.000000 130.000000 240.000000 0.000000 \n75% 61.000000 1.000000 2.000000 140.000000 274.500000 0.000000 \nmax 77.000000 1.000000 3.000000 200.000000 564.000000 1.000000 \n\n restecg thalach exang oldpeak slope ca \\ncount 303.000000 303.000000 303.000000 303.000000 303.000000 303.000000 \nmean 0.528053 149.646865 0.326733 1.039604 1.399340 0.729373 \nstd 0.525860 22.905161 0.469794 1.161075 0.616226 1.022606 \nmin 0.000000 71.000000 0.000000 0.000000 0.000000 0.000000 \n25% 0.000000 133.500000 0.000000 0.000000 1.000000 0.000000 \n50% 1.000000 153.000000 0.000000 0.800000 1.000000 0.000000 \n75% 1.000000 166.000000 1.000000 1.600000 2.000000 1.000000 \nmax 2.000000 202.000000 1.000000 6.200000 2.000000 4.000000 \n\n thal target \ncount 303.000000 303.000000 \nmean 2.313531 0.544554 \nstd 0.612277 0.498835 \nmin 0.000000 0.000000 \n25% 2.000000 0.000000 \n50% 2.000000 1.000000 \n75% 3.000000 1.000000 \nmax 3.000000 1.000000 ",
"text/html": "pd.crosstab(column_1, column_2).\n\nThis is helpful if you want to start gaining an intuition about how your independent variables interact with your dependent variables.\n\nLet's compare our target column with the sex column.\n\nRemember from our data dictionary, for the target column, 1 = heart disease present, 0 = no heart disease. And for sex, 1 = male, 0 = female."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "df.sex.value_counts()",
"execution_count": 48,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 48,
"data": {
"text/plain": "1 207\n0 96\nName: sex, dtype: int64"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "There are 207 males and 96 females in our study."
},
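{
"metadata": {},
"cell_type": "markdown",
"source": "pd.crosstab() counts how often each combination of two columns occurs. Its normalize argument turns those counts into proportions. This is a minimal sketch that rebuilds a toy frame with this dataset's sex/target counts (24 and 72 females without/with heart disease, 114 and 93 males). It then uses normalize='columns' to get the share of each sex with heart disease.

```python
import pandas as pd

# Rebuild rows with the dataset's sex/target counts
# (sex: 1 = male, 0 = female; target: 1 = heart disease, 0 = none)
rows = (
    [(0, 0)] * 24 + [(0, 1)] * 72 +   # 96 females
    [(1, 0)] * 114 + [(1, 1)] * 93    # 207 males
)
toy = pd.DataFrame(rows, columns=['sex', 'target'])

counts = pd.crosstab(toy.target, toy.sex)
print(counts)

# normalize='columns' divides each column by its total, giving the share per sex
share = pd.crosstab(toy.target, toy.sex, normalize='columns')
print(share.round(3))
```

With these counts, 75% of the women (72 of 96) and about 45% of the men (93 of 207) in the sample have heart disease. This suggests sex may be an informative feature here."
},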
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "# Compare target column with sex column\npd.crosstab(df.target, df.sex)",
"execution_count": 49,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 49,
"data": {
"text/plain": "sex 0 1\ntarget \n0 24 114\n1 72 93",
"text/html": "plot() function and passing it a few parameters such as kind (type of plot you want), figsize=(width, height) (how big you want it to be, in inches) and color=[color_1, color_2] (the different colors you'd like to use).\n\nDifferent metrics are represented best with different kinds of plots. In our case, a bar graph is great. We'll see examples of more later. And with a bit of practice, you'll gain an intuition of which plot to use with different variables."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "# Create a plot of crosstab\npd.crosstab(df.target, df.sex).plot(kind="bar",\n figsize=(10, 6),\n color=["salmon", "lightblue"]);",
"execution_count": 50,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 720x432 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAlYAAAFvCAYAAACIOIXnAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAATzUlEQVR4nO3df6zd9X3f8debXFteSjIbuCDDhdpRgDYmS5sYFlKNZCXIgZDAMiKBGkYCEZoGBdaQYcaklH+qVurWLUrHZDUsntQCaUZnRhJ3mRuPdSNQu2XBMQUjWOCCBxcvoSkbw3Y/++OeZBbYA+75XJ9zrx8PCd17vj/f5o+rp77f7/3eaq0FAIDhHTXqAQAAFgthBQDQibACAOhEWAEAdCKsAAA6mRj1AEly3HHHtVWrVo16DACA17V9+/YXWmuTB1s3FmG1atWqbNu2bdRjAAC8rqr6/qHWuRUIANCJsAIA6ERYAQB0MhbPWAEAR5a9e/dmeno6L7/88qhHOaRly5ZlamoqS5YsecP7CCsA4LCbnp7O2972tqxatSpVNepxXqO1lj179mR6ejqrV69+w/u5FQgAHHYvv/xyjj322LGMqiSpqhx77LFv+oqasAIARmJco+rH5jKfsAIA6ERYAQB0IqwAADoRVgDAgvDSSy/lox/9aN7znvfkjDPOyF133ZXt27fngx/8YN73vvdl3bp12b17d/bt25czzzwzW7duTZLcfPPNueWWWw7LjF63AAAsCJs3b86JJ56Yr3/960mSF198Meeff342bdqUycnJ3HXXXbnlllty++235ytf+UouueSSfPGLX8zmzZvzwAMPHJYZhRUAsCC8+93vzo033pibbropF154YVasWJEdO3bkvPPOS5Ls378/K1euTJKsWbMml19+eT72sY/l/vvvz9KlSw/LjMIKAFgQTjvttGzfvj3f+MY3cvPNN+e8887LmjVrcv/99x90+4cffjjLly/Pc889d9hmFFYsOHc/unvUI4ydT5y+ctQjAMy7Z599Nsccc0w+9alP5eijj86GDRsyMzOT+++/P2effXb27t2bxx57LGvWrMndd9+dPXv25L777suFF16YBx98MMuXL5/3GYUVALAgPPzww/n85z+fo446KkuWLMltt92WiYmJXHfddXnxxRezb9++3HDDDTnhhBOyfv36bNmyJSeffHKuvfbaXH/99dm4ceO8zyisAIAFYd26dVm3bt1rlt93332vWfbYY4/95PvrrrtuXuc6kNctAAB0IqwAADoRVgAAnQgrAIBOhBUAQCfCCgCgE69bAABGbu+tn+t6vCVf+Kevu83mzZtz/fXXZ//+/fnsZz+b9evXD31eV6wAgCPO/v37c8011+Sb3/xmdu7cmTvuuCM7d+4c+rjCCgA44jz44IN55zvfmXe84x1ZunRpLr300mzatGno4worAOCI88wzz+Tkk0/+yeepqak888wzQx9XWAEAR5zW2muWVdXQxxVWAMARZ2pqKk8//fRPPk9PT+fEE08c+rjCCgA44px55pnZtWtXnnzyybzyyiu588478/GPf3zo43rdAgAwcm/k9Qg9TUxM5Etf+lLWrVuX/fv358orr8yaNWuGP26H2QAAFpwLLrggF1xwQddjuhUIANCJsAIA6ERYAQB0IqwAADoRVgAAnQgrAIBOvG4BABi5ux/d3fV4nzh95etuc+WVV+bee+/N8ccfnx07dnQ5rytWAMAR6dOf/nQ2b97c9ZjCCgA4Ip1zzjk55phjuh5TWAEAdCKsAAA6EVYAAJ0IKwCATrxuAQAYuTfyeoTeLrvssmzdujUvvPBCpqamcuutt+aqq64a6pjCCgA4It1xxx3dj+lWIABAJ8IKAKCT1w2rqrq9qp6vqh0HLDumqr5VVbsGX1cMlldVfbGqHq+q71bVe+dzeABg4WqtjXqE/6+5zPdGrlh9JclHXrVsfZ
ItrbVTk2wZfE6S85OcOvjv6iS3vemJAIBFb9myZdmzZ8/YxlVrLXv27MmyZcve1H6v+/B6a+2+qlr1qsUXJfnQ4PuNSbYmuWmw/N+02f9L36mq5VW1srXW9y8rAgAL2tTUVKanpzMzMzPqUQ5p2bJlmZqaelP7zPW3Ak/4cSy11nZX1fGD5SclefqA7aYHy14TVlV1dWavauWUU06Z4xgAwEK0ZMmSrF69etRjdNf74fU6yLKDXuNrrW1ora1tra2dnJzsPAYAwOE317B6rqpWJsng6/OD5dNJTj5gu6kkz859PACAhWOuYXVPkisG31+RZNMBy//e4LcD35/kRc9XAQBHitd9xqqq7sjsg+rHVdV0ki8k+fUkX62qq5I8leSTg82/keSCJI8n+V9JPjMPMwMAjKU38luBlx1i1bkH2bYluWbYoQAAFiJvXgcA6ERYAQB0IqwAADoRVgAAnQgrAIBOhBUAQCfCCgCgE2EFANCJsAIA6ERYAQB0IqwAADoRVgAAnQgrAIBOJkY9AACMg7sf3T3qEcbOJ05fOeoRFhxXrAAAOhFWAACdCCsAgE6EFQBAJ8IKAKATYQUA0ImwAgDoRFgBAHQirAAAOhFWAACdCCsAgE6EFQBAJ8IKAKATYQUA0ImwAgDoRFgBAHQirAAAOhFWAACdCCsAgE6EFQBAJ8IKAKATYQUA0ImwAgDoRFgBAHQirAAAOhFWAACdCCsAgE6EFQBAJ8IKAKATYQUA0ImwAgDoZKiwqqp/WFXfq6odVXVHVS2rqtVV9UBV7aqqu6pqaa9hAQDG2ZzDqqpOSnJdkrWttTOSvCXJpUl+I8lvtdZOTfKDJFf1GBQAYNwNeytwIslfq6qJJG9NsjvJLyb52mD9xiQXD3kOAIAFYc5h1Vp7JslvJnkqs0H1YpLtSX7YWts32Gw6yUkH27+qrq6qbVW1bWZmZq5jAACMjWFuBa5IclGS1UlOTPJTSc4/yKbtYPu31ja01ta21tZOTk7OdQwAgLExzK3ADyd5srU201rbm+TuJB9IsnxwazBJppI8O+SMAAALwjBh9VSS91fVW6uqkpybZGeSbye5ZLDNFUk2DTciAMDCMMwzVg9k9iH1P03y8OBYG5LclORXqurxJMcm+XKHOQEAxt7E629yaK21LyT5wqsWP5HkrGGOCwCwEHnzOgBAJ8IKAKATYQUA0ImwAgDoRFgBAHQirAAAOhFWAACdCCsAgE6EFQBAJ8IKAKATYQUA0ImwAgDoRFgBAHQirAAAOhFWAACdCCsAgE4mRj0AAIfX3ls/N+oRxtOlN456AhYBV6wAADoRVgAAnQgrAIBOhBUAQCfCCgCgE2EFANCJsAIA6ERYAQB0IqwAADoRVgAAnQgrAIBOhBUAQCfCCgCgE2EFANCJsAIA6ERYAQB0IqwAADoRVgAAnQgrAIBOhBUAQCfCCgCgE2EFANCJsAIA6ERYAQB0IqwAADoRVgAAnQgrAIBOhBUAQCfCCgCgk6HCqqqWV9XXqurPq+qRqjq7qo6pqm9V1a7B1xW9hgUAGGfDXrH6F0k2t9Z+Jsl7kjySZH2SLa21U5NsGXwGAFj05hxWVfX2JOck+XKStNZeaa39MMlFSTYONtuY5OJhhwQAWAiGuWL1jiQzSf51Vf1ZVf1OVf1UkhNaa7uTZPD1+IPtXFVXV9W2qto2MzMzxBgAAONhmLCaSPLeJLe11n4+yUt5E7f9WmsbWmtrW2trJycnhxgDAGA8DBNW00mmW2sPDD5/LbOh9VxVrUySwdfnhxsRAGBhmHNYtdb+R5Knq+r0waJzk+xMck+SKwbLrkiyaagJAQAWiIkh9//lJL9bVUuTPJHkM5mNta9W1VVJnkryySHPAQCwIAwVVq21h5KsPciqc4c5LgDAQuTN6wAAnQgrAIBOhBUAQCfCCgCgE2EFANCJsAIA6ERYAQB0IqwAADoRVgAAnQgrAIBOhBUAQCfCCgCgE2EFANCJsAIA6E
RYAQB0IqwAADoRVgAAnQgrAIBOhBUAQCfCCgCgE2EFANCJsAIA6ERYAQB0IqwAADoRVgAAnQgrAIBOhBUAQCfCCgCgE2EFANCJsAIA6ERYAQB0IqwAADoRVgAAnQgrAIBOhBUAQCfCCgCgE2EFANCJsAIA6ERYAQB0IqwAADoRVgAAnQgrAIBOhBUAQCfCCgCgE2EFANDJ0GFVVW+pqj+rqnsHn1dX1QNVtauq7qqqpcOPCQAw/npcsbo+ySMHfP6NJL/VWjs1yQ+SXNXhHAAAY2+osKqqqSQfTfI7g8+V5BeTfG2wycYkFw9zDgCAhWLYK1b/PMk/SvJXg8/HJvlha23f4PN0kpMOtmNVXV1V26pq28zMzJBjAACM3pzDqqouTPJ8a237gYsPsmk72P6ttQ2ttbWttbWTk5NzHQMAYGxMDLHvLyT5eFVdkGRZkrdn9grW8qqaGFy1mkry7PBjAgCMvzlfsWqt3dxam2qtrUpyaZI/aq39UpJvJ7lksNkVSTYNPSUAwAIwH++xuinJr1TV45l95urL83AOAICxM8ytwJ9orW1NsnXw/RNJzupxXACAhcSb1wEAOhFWAACdCCsAgE6EFQBAJ8IKAKATYQUA0ImwAgDoRFgBAHQirAAAOhFWAACdCCsAgE6EFQBAJ8IKAKATYQUA0ImwAgDoRFgBAHQirAAAOhFWAACdCCsAgE6EFQBAJ8IKAKATYQUA0ImwAgDoRFgBAHQirAAAOhFWAACdCCsAgE6EFQBAJ8IKAKATYQUA0ImwAgDoRFgBAHQirAAAOhFWAACdCCsAgE6EFQBAJ8IKAKATYQUA0MnEqAfg0Pbe+rlRjzCeLr1x1BMAwEG5YgUA0ImwAgDoRFgBAHQirAAAOhFWAACdCCsAgE6EFQBAJ3MOq6o6uaq+XVWPVNX3qur6wfJjqupbVbVr8HVFv3EBAMbXMFes9iX5XGvtZ5O8P8k1VfWuJOuTbGmtnZpky+AzAMCiN+ewaq3tbq396eD7HyV5JMlJSS5KsnGw2cYkFw87JADAQtDlGauqWpXk55M8kOSE1truZDa+khx/iH2urqptVbVtZmamxxgAACM1dFhV1dFJ/m2SG1prf/FG92utbWitrW2trZ2cnBx2DACAkRsqrKpqSWaj6ndba3cPFj9XVSsH61cmeX64EQEAFoZhfiuwknw5ySOttX92wKp7klwx+P6KJJvmPh4AwMIxMcS+v5Dk8iQPV9VDg2X/OMmvJ/lqVV2V5KkknxxuRACAhWHOYdVa++MkdYjV5871uAAAC5U3rwMAdCKsAAA6EVYAAJ0IKwCAToQVAEAnwgoAoBNhBQDQibACAOhEWAEAdCKsAAA6EVYAAJ0IKwCAToQVAEAnwgoAoBNhBQDQibACAOhEWAEAdCKsAAA6EVYAAJ0IKwCAToQVAEAnwgoAoBNhBQDQibACAOhEWAEAdCKsAAA6EVYAAJ0IKwCAToQVAEAnwgoAoBNhBQDQibACAOhEWAEAdCKsAAA6EVYAAJ0IKwCAToQVAEAnwgoAoBNhBQDQibACAOhEWAEAdCKsAAA6EVYAAJ0IKwCAToQVAEAn8xJWVfWRqnq0qh6vqvXzcQ4AgHHTPayq6i1JfjvJ+UneleSyqnpX7/MAAIyb+bhidVaSx1trT7TWXklyZ5KL5uE8AABjpVprfQ9YdUmSj7TWPjv4fHmSv9lau/ZV212d5OrBx9OTPNp1EBaz45K8MOohgEXHzxbeqJ9urU0ebMXEPJysDrLsNfXWWtuQZMM8nJ9Frqq2tdbWjnoOYHHxs4Ue5uNW4HSSkw/4PJXk2Xk4DwDAWJmPsPqTJKdW1eqqWprk0iT3zMN5AADGSvdbga21fVV1bZI/TPKWJLe31r7X+zwc0dxCBuaDny0MrfvD6wAARypvXgcA6ERYAQB0IqwAADoRVgAAnczHC0Khm6r6mcz+SaSTMvui2WeT3NNae2SkgwHAQbhixdiqqpsy+7cmK8
mDmX1HWiW5o6rWj3I2YPGqqs+MegYWLq9bYGxV1WNJ1rTW9r5q+dIk32utnTqayYDFrKqeaq2dMuo5WJjcCmSc/VWSE5N8/1XLVw7WAcxJVX33UKuSnHA4Z2FxEVaMsxuSbKmqXUmeHiw7Jck7k1w7sqmAxeCEJOuS/OBVyyvJfz3847BYCCvGVmttc1WdluSszD68Xpn9I99/0lrbP9LhgIXu3iRHt9YeevWKqtp6+MdhsfCMFQBAJ34rEACgE2EFANCJsALGWlUtr6p/cBjO86Gq+sB8nwdY3IQVMO6WJ3nDYVWz5vKz7UNJhBUwFA+vA2Otqu7M7J81ejTJt5P8jSQrkixJ8k9aa5uqalWSbw7Wn53k4iQfTnJTZv8M0q4k/6e1dm1VTSb5V5l9dUcy+1qPZ5J8J8n+JDNJfrm19p8Px78PWFyEFTDWBtF0b2vtjKqaSPLW1tpfVNVxmY2hU5P8dJInknygtfadqjoxs+8iem+SHyX5oyT/bRBWv5fkX7bW/riqTknyh621n62qX03yl6213zzc/0Zg8fAeK2AhqSS/VlXnZPbt+yfl/70l+/utte8Mvj8ryX9qrf3PJKmq309y2mDdh5O8q6p+fMy3V9XbDsfwwOInrICF5JeSTCZ5X2ttb1X99yTLButeOmC7evWOBzgqydmttf994MIDQgtgzjy8Doy7HyX58RWlv57k+UFU/e3M3gI8mAeTfLCqVgxuH/7dA9b9hxzwJ5Gq6ucOch6AORFWwFhrre1J8l+qakeSn0uytqq2Zfbq1Z8fYp9nkvxakgeS/MckO5O8OFh93eAY362qnUn+/mD5v0/yd6rqoar6W/P2DwIWNQ+vA4tSVR3dWvvLwRWrP0hye2vtD0Y9F7C4uWIFLFa/WlUPJdmR5Mkk/27E8wBHAFesAAA6ccUKAKATYQUA0ImwAgDoRFgBAHQirAAAOvm/egLRRdRturYAAAAASUVORK5CYII=\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Nice! But our plot is looking pretty bare. Let's add some attributes.\n\nWe'll create the plot again with crosstab() and plot(), then add some helpful labels to it with plt.title(), plt.xlabel() and more.\n\nTo add the attributes, you call them on plt within the same cell as where you create the graph."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "# Create a plot of crosstab\npd.crosstab(df.target, df.sex).plot(kind="bar",\n figsize=(10, 6),\n color=["salmon", "lightblue"])\n\nplt.title("Heart Disease Frequency for Sex")\nplt.xlabel("0 = No disease, 1 = disease")\nplt.ylabel("Amount")\nplt.legend(["Female", "Male"])\nplt.xticks(rotation=0)",
"execution_count": 51,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 51,
"data": {
"text/plain": "(array([0, 1]), <a list of 2 Text xticklabel objects>)"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 720x432 with 1 Axes>",
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
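{
"metadata": {},
"cell_type": "markdown",
"source": "Matplotlib also has an object-oriented interface. Instead of calling plt.title() and friends on the current figure, you call setter methods on the Axes object that you pass to pandas' plot(). This is a minimal sketch of the same chart built that way. It assumes the Agg backend and uses the crosstab counts of this dataset. Either style works; the Axes version is more explicit when a figure has several subplots.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend; not needed inside Jupyter
import matplotlib.pyplot as plt
import pandas as pd

# Crosstab counts: rows = target, columns = sex
counts = pd.DataFrame({0: [24, 72], 1: [114, 93]}, index=[0, 1])
counts.index.name = 'target'
counts.columns.name = 'sex'

fig, ax = plt.subplots(figsize=(10, 6))
counts.plot(kind='bar', color=['salmon', 'lightblue'], ax=ax)

# Same labels as the plt-based cell, set on the Axes instead
ax.set_title('Heart Disease Frequency for Sex')
ax.set_xlabel('0 = No disease, 1 = disease')
ax.set_ylabel('Amount')
ax.legend(['Female', 'Male'])
ax.tick_params(axis='x', labelrotation=0)

plt.close(fig)
```

Passing ax= to plot() tells pandas which Axes to draw on. Later calls then target that Axes explicitly."
},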
{
"metadata": {},
"cell_type": "markdown",
"source": "### Age vs. Max Heart Rate for Heart Disease\n\nLet's try combining a couple of independent variables, such as age and thalach (maximum heart rate), and then comparing them to our target variable, heart disease.\n\nBecause there are so many different values for age and thalach, we'll use a scatter plot."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "# Create another figure\nplt.figure(figsize=(10, 6))\n\n# Scatter with positive examples\nplt.scatter(df.age[df.target == 1], df.thalach[df.target == 1], c="salmon")\n\n# Scatter with negative examples\nplt.scatter(df.age[df.target == 0], df.thalach[df.target == 0], c="lightblue")\n\n# Add some helpful info\nplt.title("Heart Disease as a Function of Age and Max Heart Rate")\nplt.xlabel("Age")\nplt.ylabel("Max Heart Rate")\nplt.legend(["Disease", "No Disease"])",
"execution_count": 52,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 52,
"data": {
"text/plain": "<matplotlib.legend.Legend at 0x1f4bf44be80>"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 720x432 with 1 Axes>",
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "What can we infer from this?\n\nIt seems the younger someone is, the higher their max heart rate (the dots are higher on the left of the graph), and the older someone is, the more blue dots there are. But this may be because there are more dots altogether on the right side of the graph (older participants).\n\nBoth of these are only observations, of course, but this is what we're trying to do: build an understanding of the data.\n\nLet's check the age distribution."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "# Check the distribution of the age column with a histogram\ndf.age.plot.hist();",
"execution_count": 53,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX4AAAD4CAYAAADrRI2NAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAQv0lEQVR4nO3dfYxldX3H8ffHXSkPha7IQjesONBsEGLkwSnF0JoKalAoYCJWa83GULcP1GpsY1djqm1qAkkr0rSxbkG7Wh94UIRKq64raJo04PDQoiwGpStuF9lRoYhaKPjtH/eMDLuzu3d299y7M7/3K5nc8/vNOfd8f1n43DO/ex5SVUiS2vGMcRcgSRotg1+SGmPwS1JjDH5JaozBL0mNWTruAoZxxBFH1MTExLjLkKQF5bbbbvteVS3fvn9BBP/ExARTU1PjLkOSFpQk356r36keSWqMwS9JjTH4JakxBr8kNcbgl6TGGPyS1BiDX5IaY/BLUmMMfklqzIK4clfSjibW3jiW/W6+5Jyx7Ff7jkf8ktQYg1+SGmPwS1JjDH5JaozBL0mN6TX4kyxLcm2Se5JsSvKiJIcn2ZDk3u71WX3WIEl6ur6P+C8HPldVzwNOAjYBa4GNVbUK2Ni1JUkj0lvwJzkMeDFwJUBVPV5VDwPnA+u71dYDF/RVgyRpR30e8R8HTAMfTnJHkiuSHAIcVVUPAHSvR/ZYgyRpO31eubsUOBV4c1XdkuRy5jGtk2QNsAbgmGOO6adCaS+N6+pZaW/0ecS/BdhSVbd07WsZfBA8mGQFQPe6ba6Nq2pdVU1W1eTy5Ts8JF6StId6C/6q+i7wnSTHd11nAXcDNwCru77VwPV91SBJ2lHfN2l7M/CxJAcA9wFvZPBhc3WSi4D7gQt7rkGSNEuvwV9VdwKTc/zqrD73K0naOa/claTGGPyS1BiDX5IaY/BLUmMMfklqjMEvSY0x+CWpMQa/JDXG4Jekxhj8ktQYg1+SGmPwS1JjDH5JaozBL0mNMfglqTEGvyQ1xuCXpMYY/JLUGINfkhpj8EtSYwx+SWqMwS9JjTH4JakxBr8kNWZpn2+eZDPwQ+BJ4ImqmkxyOHAVMAFsBl5TVQ/1WYck6SmjOOJ/SVWdXFWTXXstsLGqVgEbu7YkaUTGMdVzPrC+W14PXDCGGiSpWX0HfwFfSHJbkjVd31FV9QBA93rkXBsmWZNkKsnU9PR0z2VKUjt6neMHzqiqrUmOBDYkuWfYDatqHbAOYHJysvoqUJJa0+sRf1Vt7V63AdcBpwEPJlkB0L1u67MGSdLT9Rb8SQ5JcujMMvBy4GvADcDqbrXVwPV91SBJ2lGfUz1HAdclmdnPx6vqc0m+Clyd5CLgfuDCHmuQJG2nt+CvqvuAk+bo/z5wVl/7lSTtmlfuSlJjDH5JaozBL0mNMfglqTEGvyQ1xuCXpMYY/JLUGINfkhpj8EtSYwx+SWqMwS9JjTH4JakxBr8kNcbgl6TGGPyS1BiDX5IaY/BLUmP6fPSipEVoYu2NY9v35kvOGdu+FxOP+CWpMQa/JDXG4Jekxhj8ktQYg1+SGmPwS1Jjeg/+JEuS3JHks1372CS3JLk3yVVJDui7BknSU0ZxxP8WYNOs9qXAZVW1CngIuGgENUiSOkMFf5Ln78mbJ1kJnANc0bUDnAlc262yHrhgT95bkrRnhj3i//sktyb5gyTL5vH+7wfeDvy0az8beLiqnujaW4Cj59owyZokU0mmpqen57FLSdKuDBX8VfWrwOuB5wBTST6e5GW72ibJucC2qrptdvdcb7+Tfa6rqsmqmly+fPkwZUqShjD0vXqq6t4k7wKmgL8BTummbt5ZVZ+eY5MzgPOSvBI4EDiMwV8Ay5Is7Y76VwJb93YQkqThDTvH/4IklzH4kvZM4Deq6oRu+bK5tqmqd1TVyqqaAF4LfKmqXg/cBLy6W201cP3eDUGSNB
/DzvH/LXA7cFJVXVxVtwNU1VbgXfPc558Cb0vyTQZz/lfOc3tJ0l4YdqrnlcBPqupJgCTPAA6sqh9X1Ud3t3FV3Qzc3C3fB5y2R9VKkvbasEf8XwQOmtU+uOuTJC0wwx7xH1hVj840qurRJAf3VJMWKB/QIS0Mwx7x/yjJqTONJC8EftJPSZKkPg17xP9W4JokM6dergB+s5+SJEl9Gir4q+qrSZ4HHM/gIqx7qur/eq1MmodxTjNJC818Hrb+y8BEt80pSaiqj/RSlSSpN0MFf5KPAr8E3Ak82XUXYPBL0gIz7BH/JHBiVc15Xx1J0sIx7Fk9XwN+sc9CJEmjMewR/xHA3UluBR6b6ayq83qpSpLUm2GD/z19FiFJGp1hT+f8cpLnAquq6ovdVbtL+i1NktSHYW/L/CYGj0v8YNd1NPCZvoqSJPVn2C93L2bwYJVHYPBQFuDIvoqSJPVn2OB/rKoen2kkWcpOHpkoSdq/DRv8X07yTuCg7lm71wD/3F9ZkqS+DBv8a4Fp4C7gd4F/Yf5P3pIk7QeGPavnp8A/dD/az3nDMkm7Muy9ev6LOeb0q+q4fV6RJKlX87lXz4wDgQuBw/d9OZKkvg01x19V35/1899V9X7gzJ5rkyT1YNipnlNnNZ/B4C+AQ3upSJLUq2Gnev561vITwGbgNfu8GklS74Y9q+clfRciSRqNYad63rar31fV++bY5kDgK8DPdfu5tqreneRY4JMMvhy+HXjD7KuCJUn9GvYCrkng9xncnO1o4PeAExnM8+9srv8x4MyqOgk4GTg7yenApcBlVbUKeAi4aM/LlyTN13wexHJqVf0QIMl7gGuq6nd2tkH3mMZHu+Yzu59icDbQb3X96xnc6/8D8y1ckrRnhj3iPwaYPR3zODCxu42SLElyJ7AN2AB8C3i4qp7oVtnC4C+IubZdk2QqydT09PSQZUqSdmfYI/6PArcmuY7BUfurgI/sbqOqehI4Ocky4DrghLlW28m264B1AJOTk94JVJL2kWHP6nlvkn8Ffq3remNV3THsTqrq4SQ3A6cDy5Is7Y76VwJb51mzJGkvDDvVA3Aw8EhVXQ5s6c7O2akky7sjfZIcBLwU2ATcBLy6W201cP28q5Yk7bFhT+d8N4Mze44HPszgi9p/YvBUrp1ZAaxPsoTBB8zVVfXZJHcDn0zyl8AdwJV7Ub8kaZ6GneN/FXAKg/PuqaqtSXZ5y4aq+s9um+377wNOm2edkqR9ZNipnse70zMLIMkh/ZUkSerTsMF/dZIPMvhi9k3AF/GhLJK0IA17Vs9fdc/afYTBPP+fVdWGXiuTJPVit8HffTn7+ap6KYOLsCRJC9hup3q6i7B+nOQXRlCPJKlnw57V87/AXUk2AD+a6ayqP+qlKklSb4YN/hu7H0nSArfL4E9yTFXdX1XrR1WQJKlfu5vj/8zMQpJP9VyLJGkEdhf8mbV8XJ+FSJJGY3fBXztZliQtULv7cvekJI8wOPI/qFuma1dVHdZrdZKkfW6XwV9VS0ZViCRpNOZzP35J0iJg8EtSYwx+SWqMwS9JjTH4JakxBr8kNcbgl6TGGPyS1BiDX5IaY/BLUmMMfklqTG/Bn+Q5SW5KsinJ15O8pes/PMmGJPd2r8/qqwZJ0o76POJ/AvjjqjoBOB24OMmJwFpgY1WtAjZ2bUnSiPQW/FX1QFXd3i3/ENgEHA2cD8w8ynE9cEFfNUiSdjSSOf4kE8ApwC3AUVX1AAw+HIAjd7LNmiRTSaamp6dHUaYkNaH34E/y88CngLdW1SO7W39GVa2rqsmqmly+fHl/BUpSY3oN/iTPZBD6H6uqT3fdDyZZ0f1+BbCtzxokSU/X51k9Aa4ENlXV+2b96gZgdbe8Gri+rxokSTva3TN398YZwBuAu5Lc2fW9E7gEuDrJRcD9wIU91iBJ2k5vwV9V/8bgoexzOauv/UqSds0rdyWpMQa/JDXG4Jekxhj8ktQYg1+SGmPwS1JjDH
5JaozBL0mNMfglqTEGvyQ1xuCXpMYY/JLUGINfkhpj8EtSY/q8H78k7VMTa28cy343X3LOWPbbF4/4JakxBr8kNcbgl6TGGPyS1BiDX5IaY/BLUmMMfklqjMEvSY0x+CWpMb1duZvkQ8C5wLaqen7XdzhwFTABbAZeU1UP9VXDOI3rCkNJ2p0+j/j/ETh7u761wMaqWgVs7NqSpBHqLfir6ivAD7brPh9Y3y2vBy7oa/+SpLmNeo7/qKp6AKB7PXJnKyZZk2QqydT09PTICpSkxW6//XK3qtZV1WRVTS5fvnzc5UjSojHq4H8wyQqA7nXbiPcvSc0bdfDfAKzullcD1494/5LUvN6CP8kngH8Hjk+yJclFwCXAy5LcC7ysa0uSRqi38/ir6nU7+dVZfe1TkrR7++2Xu5Kkfhj8ktQYg1+SGmPwS1JjDH5JaozBL0mNMfglqTEGvyQ1xuCXpMYY/JLUGINfkhpj8EtSYwx+SWqMwS9JjTH4JakxBr8kNcbgl6TG9PYELklaLCbW3jiW/W6+5Jxe3tcjfklqjMEvSY0x+CWpMQa/JDXG4Jekxhj8ktSYsZzOmeRs4HJgCXBFVV3S177GdRqWJO2vRn7En2QJ8HfAK4ATgdclOXHUdUhSq8Yx1XMa8M2quq+qHgc+CZw/hjokqUnjmOo5GvjOrPYW4Fe2XynJGmBN13w0yTf2cR1HAN/bx++5kLQ8/pbHDm2Pf0GNPZfu9Vs8d67OcQR/5uirHTqq1gHreisimaqqyb7ef3/X8vhbHju0Pf6Wxz7bOKZ6tgDPmdVeCWwdQx2S1KRxBP9XgVVJjk1yAPBa4IYx1CFJTRr5VE9VPZHkD4HPMzid80NV9fVR10GP00gLRMvjb3ns0Pb4Wx77z6Rqh+l1SdIi5pW7ktQYg1+SGtNE8Cc5MMmtSf4jydeT/HnXf2ySW5Lcm+Sq7svmRSnJkiR3JPls125p7JuT3JXkziRTXd/hSTZ049+Q5FnjrrMPSZYluTbJPUk2JXlRQ2M/vvs3n/l5JMlbWxn/rjQR/MBjwJlVdRJwMnB2ktOBS4HLqmoV8BBw0Rhr7NtbgE2z2i2NHeAlVXXyrHO41wIbu/Fv7NqL0eXA56rqecBJDP4baGLsVfWN7t/8ZOCFwI+B62hk/LvSRPDXwKNd85ndTwFnAtd2/euBC8ZQXu+SrATOAa7o2qGRse/C+QzGDYt0/EkOA14MXAlQVY9X1cM0MPY5nAV8q6q+TZvjf5omgh9+NtVxJ7AN2AB8C3i4qp7oVtnC4HYSi9H7gbcDP+3az6adscPgQ/4LSW7rbgUCcFRVPQDQvR45tur6cxwwDXy4m+a7IskhtDH27b0W+ES33OL4n6aZ4K+qJ7s/+VYyuFHcCXOtNtqq+pfkXGBbVd02u3uOVRfd2Gc5o6pOZXBH2IuTvHjcBY3IUuBU4ANVdQrwIxqc1ui+vzoPuGbctewvmgn+Gd2fujcDpwPLksxcxLZYbx1xBnBeks0M7oR6JoO/AFoYOwBVtbV73cZgjvc04MEkKwC6123jq7A3W4AtVXVL176WwQdBC2Of7RXA7VX1YNdubfw7aCL4kyxPsqxbPgh4KYMvuW4CXt2tthq4fjwV9qeq3lFVK6tqgsGfu1+qqtfTwNgBkhyS5NCZZeDlwNcY3CZkdbfaohx/VX0X+E6S47uus4C7aWDs23kdT03zQHvj30ETV+4meQGDL3GWMPiwu7qq/iLJcQyOgg8H7gB+u6oeG1+l/Ury68CfVNW5rYy9G+d1XXMp8PGqem+SZwNXA8cA9wMXVtUPxlRmb5KczOBL/QOA+4A30v0/wCIfO0CSgxncBv64qvqfrq+Jf/tdaSL4JUlPaWKqR5L0FINfkhpj8EtSYwx+SWqMwS9JjTH4JakxBr8kNeb/AbFph2EDrMb1AAAAAElFTkSuQmCC\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "The distribution looks roughly normal, but with the bulk of the values shifted to the right, toward older ages, which is consistent with the scatter plot above.\n\nLet's keep going."
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Heart Disease Frequency per Chest Pain Type\n\nLet's try another independent variable. This time, `cp` (chest pain).\n\nWe'll use the same process as we did before with sex."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "pd.crosstab(df.cp, df.target)",
"execution_count": 54,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 54,
"data": {
"text/plain": "target 0 1\ncp \n0 104 39\n1 9 41\n2 18 69\n3 7 16",
"text/html": "`df.corr()`, which will create a correlation matrix for us; in other words, a big table of numbers telling us how related each variable is to the others."
},
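{
"metadata": {},
"cell_type": "markdown",
"source": "Each number in that table is a correlation coefficient. By default, `df.corr()` uses the Pearson correlation, which ranges from -1 (a perfect negative linear relationship) to 1 (a perfect positive one), with 0 meaning no linear relationship. As a rough sketch of what gets computed for each pair of columns, here it is in plain Python. `pearson` is a hypothetical helper, the toy numbers are made up, and pandas' own implementation also handles things this sketch ignores, such as missing values.

```python
import math


def pearson(xs, ys):
    # Covariance of xs and ys divided by the product of their standard
    # deviations (the 1/n factors cancel, so plain sums are enough)
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


# Made-up ages and max heart rates (not from the heart disease dataset)
ages = [29, 40, 51, 62, 73]
max_heart_rates = [190, 178, 165, 150, 140]
print(round(pearson(ages, max_heart_rates), 3))
```

A strongly negative result here is the same kind of relationship the real matrix shows, more weakly, between `age` and `thalach` (about -0.40)."
},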
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "# Find the correlation between our independent variables\ncorr_matrix = df.corr()\ncorr_matrix",
"execution_count": 56,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 56,
"data": {
"text/plain": " age sex cp trestbps chol fbs \\nage 1.000000 -0.098447 -0.068653 0.279351 0.213678 0.121308 \nsex -0.098447 1.000000 -0.049353 -0.056769 -0.197912 0.045032 \ncp -0.068653 -0.049353 1.000000 0.047608 -0.076904 0.094444 \ntrestbps 0.279351 -0.056769 0.047608 1.000000 0.123174 0.177531 \nchol 0.213678 -0.197912 -0.076904 0.123174 1.000000 0.013294 \nfbs 0.121308 0.045032 0.094444 0.177531 0.013294 1.000000 \nrestecg -0.116211 -0.058196 0.044421 -0.114103 -0.151040 -0.084189 \nthalach -0.398522 -0.044020 0.295762 -0.046698 -0.009940 -0.008567 \nexang 0.096801 0.141664 -0.394280 0.067616 0.067023 0.025665 \noldpeak 0.210013 0.096093 -0.149230 0.193216 0.053952 0.005747 \nslope -0.168814 -0.030711 0.119717 -0.121475 -0.004038 -0.059894 \nca 0.276326 0.118261 -0.181053 0.101389 0.070511 0.137979 \nthal 0.068001 0.210041 -0.161736 0.062210 0.098803 -0.032019 \ntarget -0.225439 -0.280937 0.433798 -0.144931 -0.085239 -0.028046 \n\n restecg thalach exang oldpeak slope ca \\nage -0.116211 -0.398522 0.096801 0.210013 -0.168814 0.276326 \nsex -0.058196 -0.044020 0.141664 0.096093 -0.030711 0.118261 \ncp 0.044421 0.295762 -0.394280 -0.149230 0.119717 -0.181053 \ntrestbps -0.114103 -0.046698 0.067616 0.193216 -0.121475 0.101389 \nchol -0.151040 -0.009940 0.067023 0.053952 -0.004038 0.070511 \nfbs -0.084189 -0.008567 0.025665 0.005747 -0.059894 0.137979 \nrestecg 1.000000 0.044123 -0.070733 -0.058770 0.093045 -0.072042 \nthalach 0.044123 1.000000 -0.378812 -0.344187 0.386784 -0.213177 \nexang -0.070733 -0.378812 1.000000 0.288223 -0.257748 0.115739 \noldpeak -0.058770 -0.344187 0.288223 1.000000 -0.577537 0.222682 \nslope 0.093045 0.386784 -0.257748 -0.577537 1.000000 -0.080155 \nca -0.072042 -0.213177 0.115739 0.222682 -0.080155 1.000000 \nthal -0.011981 -0.096439 0.206754 0.210244 -0.104764 0.151832 \ntarget 0.137230 0.421741 -0.436757 -0.430696 0.345877 -0.391724 \n\n thal target \nage 0.068001 -0.225439 \nsex 0.210041 -0.280937 \ncp -0.161736 0.433798 \ntrestbps 
0.062210 -0.144931 \nchol 0.098803 -0.085239 \nfbs -0.032019 -0.028046 \nrestecg -0.011981 0.137230 \nthalach -0.096439 0.421741 \nexang 0.206754 -0.436757 \noldpeak 0.210244 -0.430696 \nslope -0.104764 0.345877 \nca 0.151832 -0.391724 \nthal 1.000000 -0.344029 \ntarget -0.344029 1.000000 ",
"text/html": "`df.column.hist()`)\n * Missing values (`df.info()`)\n * Outliers\n\nLet's build some models."
},
{
"metadata": {},
"cell_type": "markdown",
"source": "## 5. Modeling"
},
{
"metadata": {},
"cell_type": "markdown",
"source": "We've explored the data; now we'll try to use machine learning to predict our target variable based on the 13 independent variables.\n\nRemember our problem?\n>Given clinical parameters about a patient, can we predict whether or not they have heart disease?\n\nThat's what we'll be trying to answer.\n\nAnd remember our evaluation metric?\n\n>If we can reach 95% accuracy at predicting whether or not a patient has heart disease during the proof of concept, we'll pursue this project.\n\nThat's what we'll be aiming for.\n\nBut before we build a model, we have to get our dataset ready.\n\nLet's look at it again."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "df.head()",
"execution_count": 58,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 58,
"data": {
"text/plain": " age sex cp trestbps chol fbs restecg thalach exang oldpeak slope \\n0 63 1 3 145 233 1 0 150 0 2.3 0 \n1 37 1 2 130 250 0 1 187 0 3.5 0 \n2 41 0 1 130 204 0 0 172 0 1.4 2 \n3 56 1 1 120 236 0 1 178 0 0.8 2 \n4 57 0 0 120 354 0 1 163 1 0.6 2 \n\n ca thal target \n0 0 1 1 \n1 0 2 1 \n2 0 2 1 \n3 0 2 1 \n4 0 2 1 ",
"text/html": "303 rows × 13 columns\n\nThe `test_size` parameter is used to tell the `train_test_split()` function how much of our data we want in the test set.\n\nA rule of thumb is to use 80% of your data to train on and the other 20% to test on.\n\nFor our problem, a train and test set are enough. But for other problems, you could also use a validation set (train/validation/test) or cross-validation (we'll see this in a second).\n\nBut again, each problem will differ. The post How (and why) to create a good validation set by Rachel Thomas is a good place to go to learn more.\n\nLet's look at our training data."
},
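{
"metadata": {},
"cell_type": "markdown",
"source": "To make the 80/20 idea concrete, here's a minimal sketch of a shuffled split over row indices in plain Python. It isn't how Scikit-Learn implements `train_test_split`, and `split_indices` is a hypothetical helper. The one detail it borrows is rounding the test size up for a fractional `test_size`, which is why 303 rows give a test set of 61 samples.

```python
import math
import random


def split_indices(n_samples, test_size=0.2, seed=42):
    # Shuffle the row indices with a fixed seed so the split is reproducible
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    # Round the number of test samples up (0.2 * 303 = 60.6 -> 61)
    n_test = math.ceil(test_size * n_samples)
    # First n_test shuffled indices go to the test set, the rest to training
    return indices[n_test:], indices[:n_test]


train_idx, test_idx = split_indices(303)
print(len(train_idx), len(test_idx))  # 242 61
```

With a different seed you'd get different rows in each set, which is why fixing the random state (as `train_test_split` lets you do with `random_state`) matters for reproducible results."
},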
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "X_train.head()",
"execution_count": 63,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 63,
"data": {
"text/plain": " age sex cp trestbps chol fbs restecg thalach exang oldpeak \\n132 42 1 1 120 295 0 1 162 0 0.0 \n202 58 1 0 150 270 0 0 111 1 0.8 \n196 46 1 2 150 231 0 1 147 0 3.6 \n75 55 0 1 135 250 0 0 161 0 1.4 \n176 60 1 0 117 230 1 1 160 1 1.4 \n\n slope ca thal \n132 2 0 2 \n202 2 0 3 \n196 1 0 2 \n75 1 0 2 \n176 2 2 3 ",
"text/html": "`model.fit(X_train, y_train)` and for scoring a model `model.score(X_test, y_test)`. For a classifier, `score()` returns the fraction of correct predictions, i.e. accuracy (1.0 = 100% correct).\n\nSince the algorithms we've chosen implement the same methods for fitting them to the data as well as evaluating them, let's put them in a dictionary and create a function which fits and scores them.",
"attachments": {}
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "# Put models in a dictionary\nmodels = {\n "KNN": KNeighborsClassifier(),\n "Logistic Regression": LogisticRegression(),\n "Random Forest": RandomForestClassifier()\n}\n\n\n# Create function to fit and score models\ndef fit_and_score(models, X_train, X_test, y_train, y_test):\n """\n Fits and evaluates given machine learning models.\n models : a dict of different Scikit-Learn machine learning models\n X_train : training data\n X_test : testing data\n y_train : labels associated with training data\n y_test : labels associated with test data\n """\n # Random seed for reproducible results\n np.random.seed(42)\n # Make a dictionary to keep model scores\n model_scores = {}\n # Loop through models\n for name, model in models.items():\n # Fit the model to the data\n model.fit(X_train, y_train)\n # Evaluate the model and store its score in model_scores\n model_scores[name] = model.score(X_test, y_test)\n return model_scores",
"execution_count": 66,
"outputs": []
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "model_scores = fit_and_score(models=models,\n X_train=X_train,\n X_test=X_test,\n y_train=y_train,\n y_test=y_test)\n\nmodel_scores",
"execution_count": 67,
"outputs": [
{
"output_type": "stream",
"text": "C:\Users\eyeso\Desktop\ml_course\heart-disease-project\env\lib\site-packages\sklearn\linear_model\logistic.py:938: ConvergenceWarning: lbfgs failed to converge (status=1):\nSTOP: TOTAL NO. of ITERATIONS REACHED LIMIT.\n\nIncrease the number of iterations (max_iter) or scale the data as shown in:\n https://scikit-learn.org/stable/modules/preprocessing.html\nPlease also refer to the documentation for alternative solver options:\n https://scikit-learn.org/stable/modules/linear_model.html#logistic-regression\n n_iter_i = check_optimize_result(\n",
"name": "stderr"
},
{
"output_type": "execute_result",
"execution_count": 67,
"data": {
"text/plain": "{'KNN': 0.6885245901639344,\n 'Logistic Regression': 0.8852459016393442,\n 'Random Forest': 0.8360655737704918}"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Beautiful! Now that our models are fitted and scored, let's compare them visually.\n\n### Model Comparison\n\nSince we've saved our models' scores to a dictionary, we can plot them by first converting them to a DataFrame."
},
{
"metadata": {
"init_cell": true,
"trusted": true
},
"cell_type": "code",
"source": "model_compare = pd.DataFrame(model_scores, index=["accuracy"])\nmodel_compare.T.plot.bar();",
"execution_count": 68,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Beautiful! We can't really see it from the graph, but looking at the dictionary, the `LogisticRegression()` model performs best.\n\nSince you've found the best model, let's take it to the boss and show her what we've found.\n\n> You: I've found it!\n>\n> Her: Nice one! What did you find?\n>\n> You: The best algorithm for predicting heart disease is a LogisticRegression!\n>\n> Her: Excellent. I'm surprised the hyperparameter tuning is finished by now.\n>\n> You: *wonders what hyperparameter tuning is*\n>\n> You: Ummm yeah, me too, it went pretty quick.\n>\n> Her: I'm very proud. How about you put together a classification report to show the team, and be sure to include a confusion matrix, and the cross-validated precision, recall and F1 scores. I'd also be curious to see what features are most important. Oh, and don't forget to include a **ROC curve**.\n>\n> You: *asks self, 'what are those?'*\n>\n> You: Of course! I'll have it to you by tomorrow.\n\nAlright, there were a few words in there which could sound made up to someone who's not a budding data scientist like yourself. But being the budding data scientist you are, you know data scientists make up words all the time.\n\nLet's briefly go through each before we see them in action.\n\n* Hyperparameter tuning - Each model you use has a series of dials you can turn to dictate how they perform. Changing these values may increase or decrease model performance.\n* Feature importance - If there are a large number of features we're using to make predictions, do some have more importance than others? For example, for predicting heart disease, which is more important, sex or age?\n* Confusion matrix - Compares the predicted values with the true values in a tabular way. If the model is 100% correct, all the counts in the matrix will sit on the diagonal from top left to bottom right.\n* Cross-validation - Splits your dataset into multiple parts, trains your model on all but one part and tests it on the remaining one, repeats this so each part is used for testing once, and evaluates performance as an average.\n* Precision - Proportion of true positives over the total number of predicted positives (true positives plus false positives). Higher precision leads to fewer false positives.\n* Recall - Proportion of true positives over the total number of true positives and false negatives. Higher recall leads to fewer false negatives.\n* F1 score - Combines precision and recall into one metric (their harmonic mean). 1 is best, 0 is worst.\n* Classification report - Scikit-Learn has a built-in function called `classification_report()` which returns some of the main classification metrics such as precision, recall and F1-score.\n* ROC curve - Receiver Operating Characteristic is a plot of true positive rate versus false positive rate.\n* Area Under Curve (AUC) - the area underneath the ROC curve. A perfect model achieves a score of 1.0.\n"
},
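{
"metadata": {},
"cell_type": "markdown",
"source": "Precision, recall and F1 can be written out directly in terms of confusion-matrix counts: precision divides the true positives by everything the model predicted as positive, recall divides them by everything that is actually positive, and F1 is their harmonic mean. A minimal sketch in plain Python; `precision_recall_f1` is a hypothetical helper and the counts are made up, not taken from our model.

```python
def precision_recall_f1(tp, fp, fn):
    # Precision: of everything predicted positive, how much was right?
    precision = tp / (tp + fp)
    # Recall: of everything actually positive, how much did we find?
    recall = tp / (tp + fn)
    # F1: harmonic mean of precision and recall
    f1 = 2 * precision * recall / (precision + recall)
    return precision, recall, f1


# Hypothetical counts: 8 true positives, 2 false positives, 4 false negatives
p, r, f1 = precision_recall_f1(tp=8, fp=2, fn=4)
print(round(p, 3), round(r, 3), round(f1, 3))  # 0.8 0.667 0.727
```

This sketch doesn't guard against division by zero (for example, a model that never predicts the positive class); Scikit-Learn's `precision_score`, `recall_score` and `f1_score` handle that case through their `zero_division` parameter."
},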
{
"metadata": {},
"cell_type": "markdown",
"source": "## Hyperparameter tuning and cross-validation\n\nTo cook your favourite dish, you know to set the oven to 180 degrees and turn the grill on. But when your roommate cooks their favourite dish, they use 200 degrees and the fan-forced mode. Same oven, different settings, different outcomes.\n\nThe same can be done for machine learning algorithms. You can use the same algorithms but change the settings (hyperparameters) and get different results.\n\nBut just like turning the oven up too high can burn your food, the same can happen for machine learning algorithms. You change the settings and it works so well on the training data that it overfits: it learns the training data too closely and does worse on new data.\n\nWe're looking for the goldilocks model: one which does well on our dataset but also does well on unseen examples.\n\nTo test different hyperparameters, you could use a validation set, but since we don't have much data, we'll use cross-validation.\n\nThe most common type of cross-validation is k-fold. It involves splitting your data into k folds, then training a model on all but one fold and testing it on the remaining one, repeating until each fold has been used for testing once. For example, let's say we had 5 folds (k = 5). This is what it might look like.\n\n\n
`LogisticRegression` or the `RandomForestClassifier` did.\n\nBecause of this, we'll discard KNN and focus on the other two.\n\nWe've tuned KNN by hand, but let's see how we can tune `LogisticRegression` and `RandomForestClassifier` using `RandomizedSearchCV`.\n\nInstead of us having to try different hyperparameters by hand, `RandomizedSearchCV` tries a number of different combinations, evaluates them and saves the best.\n\n### Tuning models with RandomizedSearchCV\n\nReading the Scikit-Learn documentation for `LogisticRegression`, we find there are a number of different hyperparameters we can tune.\n\nThe same goes for `RandomForestClassifier`.\n\nLet's create a hyperparameter grid (a dictionary of different hyperparameters) for each and then test them out."
},
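{
"metadata": {},
"cell_type": "markdown",
"source": "Here's a minimal sketch of how 5-fold cross-validation could carve up our 242 training samples, assuming contiguous folds with no shuffling. `kfold_indices` is a hypothetical helper; like Scikit-Learn's `KFold`, it gives the first `n_samples % k` folds one extra sample. Note that when you pass `cv=5` with a classifier, Scikit-Learn actually uses stratified folds, which also keep the proportion of each class similar across folds; this sketch skips that.

```python
def kfold_indices(n_samples, k=5):
    # Split range(n_samples) into k contiguous folds; the first
    # n_samples % k folds get one extra sample
    fold_sizes = [n_samples // k + (1 if i < n_samples % k else 0) for i in range(k)]
    folds = []
    start = 0
    for size in fold_sizes:
        # The current fold is held out for testing
        test = list(range(start, start + size))
        # Everything outside the current fold is used for training
        train = list(range(0, start)) + list(range(start + size, n_samples))
        folds.append((train, test))
        start += size
    return folds


for i, (train, test) in enumerate(kfold_indices(242, k=5)):
    print(f'fold {i}: train={len(train)}, test={len(test)}')
```

Each model is then trained k times, once per fold, and the k test scores are averaged."
},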
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Create a hyperparameter grid for LogisticRegression\nlog_reg_grid = {"C": np.logspace(-4, 4, 20),\n "solver": ["liblinear"]}\n\n# Create a hyperparameter grid for RandomForestClassifier\nrf_grid = { "n_estimators": np.arange(10, 1000, 50),\n "max_depth": [None, 3, 5, 10],\n "min_samples_split": np.arange(2, 20, 2),\n "min_samples_leaf": np.arange(1, 20, 2)}",
"execution_count": 73,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Now let's use `RandomizedSearchCV` to try and tune our `LogisticRegression` model.\n\nWe'll pass it the different hyperparameters from `log_reg_grid` as well as set `n_iter=20`. This means `RandomizedSearchCV` will try 20 different combinations of hyperparameters from `log_reg_grid`, evaluate each with 5-fold cross-validation (`cv=5`) and save the best ones."
},
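{
"metadata": {},
"cell_type": "markdown",
"source": "A rough sketch of the sampling idea: when every hyperparameter is given as a list of values, you can think of `RandomizedSearchCV` as picking `n_iter` distinct combinations out of all the possible ones and cross-validating each. The helper below is hypothetical and only covers the sampling step, and the `C` values are hand-picked for illustration rather than taken from `log_reg_grid`.

```python
import itertools
import random


def sample_combinations(grid, n_iter, seed=42):
    # Build every combination of hyperparameter values as a dict
    keys = list(grid)
    combos = [dict(zip(keys, values)) for values in itertools.product(*grid.values())]
    # Pick n_iter distinct combinations at random (capped at the grid size)
    rng = random.Random(seed)
    return rng.sample(combos, min(n_iter, len(combos)))


# Hypothetical grid with the same keys as log_reg_grid
grid = {'C': [0.01, 0.1, 1.0, 10.0, 100.0], 'solver': ['liblinear']}
for params in sample_combinations(grid, n_iter=3):
    print(params)
```

Scikit-Learn samples without replacement when every hyperparameter is a list, and `log_reg_grid` has 20 values of `C` and one solver (20 combinations), so with `n_iter=20` every combination gets tried."
},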
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Tune LogisticRegression\n\n# Setup random seed\nnp.random.seed(42)\n\n# Setup random hyperparameter search for LogisticRegression\nrs_log_reg = RandomizedSearchCV(LogisticRegression(),\n param_distributions=log_reg_grid,\n cv=5,\n n_iter=20,\n verbose=True)\n\n# Fit random hyperparameter search model\nrs_log_reg.fit(X_train, y_train);",
"execution_count": 74,
"outputs": [
{
"output_type": "stream",
"text": "Fitting 5 folds for each of 20 candidates, totalling 100 fits\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n[Parallel(n_jobs=1)]: Done 100 out of 100 | elapsed: 0.6s finished\n",
"name": "stderr"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "rs_log_reg.best_params_",
"execution_count": 75,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 75,
"data": {
"text/plain": "{'solver': 'liblinear', 'C': 0.23357214690901212}"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "rs_log_reg.score(X_test, y_test)",
"execution_count": 76,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 76,
"data": {
"text/plain": "0.8852459016393442"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Now that we've tuned `LogisticRegression` using `RandomizedSearchCV`, we'll do the same for `RandomForestClassifier`."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Tune RandomForestClassifier\n\n# Setup random seed\nnp.random.seed(42)\n\n# Setup random hyperparameter search for RandomForestClassifier\nrs_rf = RandomizedSearchCV(RandomForestClassifier(),\n param_distributions=rf_grid,\n cv=5,\n n_iter=20,\n verbose=True)\n\n# Fit random hyperparameter search model\nrs_rf.fit(X_train, y_train);",
"execution_count": 77,
"outputs": [
{
"output_type": "stream",
"text": "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n",
"name": "stderr"
},
{
"output_type": "stream",
"text": "Fitting 5 folds for each of 20 candidates, totalling 100 fits\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "[Parallel(n_jobs=1)]: Done 100 out of 100 | elapsed: 2.2min finished\n",
"name": "stderr"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "rs_rf.best_params_",
"execution_count": 78,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 78,
"data": {
"text/plain": "{'n_estimators': 60,\n 'min_samples_split': 2,\n 'min_samples_leaf': 11,\n 'max_depth': 3}"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "rs_rf.score(X_test, y_test)",
"execution_count": 79,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 79,
"data": {
"text/plain": "0.8852459016393442"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Excellent! Tuning the hyperparameters gave the `RandomForestClassifier` a noticeable boost (from about 0.836 to 0.885 on the test set), while `LogisticRegression` stayed at about 0.885.\n\nThis is akin to tuning the settings on your oven and getting it to cook your favourite dish just right.\n\nBut since `LogisticRegression` has been in front from the start (and takes a fraction of the time to tune), we'll try tuning it further with `GridSearchCV`."
},
{
"metadata": {},
"cell_type": "markdown",
"source": "### Tuning a model with GridSearchCV\n\nThe difference between `RandomizedSearchCV` and `GridSearchCV` is that `RandomizedSearchCV` samples `n_iter` combinations from a grid of hyperparameters, while `GridSearchCV` tests every single possible combination.\n\nIn short:\n* `RandomizedSearchCV` - tries `n_iter` combinations of hyperparameters and saves the best.\n* `GridSearchCV` - tries every single combination of hyperparameters and saves the best.\n\nLet's see it in action."
},
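{
"metadata": {},
"cell_type": "markdown",
"source": "To see what every single combination means in practice, here's a small sketch that expands a grid into the full list of candidate settings, similar in spirit to what `GridSearchCV` loops over (Scikit-Learn has its own `ParameterGrid` class for this). `expand_grid` is a hypothetical helper, and the `C` values are computed by hand to mimic `np.logspace(-4, 4, 30)`.

```python
import itertools


def expand_grid(grid):
    # Every combination of the listed values, one dict per candidate
    keys = list(grid)
    return [dict(zip(keys, values)) for values in itertools.product(*grid.values())]


# Mimic np.logspace(-4, 4, 30): 30 values evenly spaced on a log10 scale
c_values = [10 ** (-4 + 8 * i / 29) for i in range(30)]
candidates = expand_grid({'C': c_values, 'solver': ['liblinear']})

cv = 5
print(len(candidates), 'candidates,', len(candidates) * cv, 'fits')  # 30 candidates, 150 fits
```

Thirty candidates times five folds matches the 'Fitting 5 folds for each of 30 candidates, totalling 150 fits' message in the output of the grid search cell."
},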
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Different hyperparameters for our LogisticRegression model\nlog_reg_grid = {"C": np.logspace(-4, 4, 30),\n "solver": ["liblinear"]}\n\n# Setup grid hyperparameter search for LogisticRegression\ngs_log_reg = GridSearchCV(LogisticRegression(),\n param_grid=log_reg_grid,\n cv=5,\n verbose=True)\n\n# Fit grid hyperparameter search model\ngs_log_reg.fit(X_train, y_train);",
"execution_count": 80,
"outputs": [
{
"output_type": "stream",
"text": "Fitting 5 folds for each of 30 candidates, totalling 150 fits\n",
"name": "stdout"
},
{
"output_type": "stream",
"text": "[Parallel(n_jobs=1)]: Using backend SequentialBackend with 1 concurrent workers.\n[Parallel(n_jobs=1)]: Done 150 out of 150 | elapsed: 0.9s finished\n",
"name": "stderr"
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Check the best hyperparameters\ngs_log_reg.best_params_",
"execution_count": 81,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 81,
"data": {
"text/plain": "{'C': 0.20433597178569418, 'solver': 'liblinear'}"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Evaluate the grid search LogisticRegression model\ngs_log_reg.score(X_test, y_test)",
"execution_count": 82,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 82,
"data": {
"text/plain": "0.8852459016393442"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "In this case, we get the same test score as before (about 0.885). `GridSearchCV` tried all 30 values of `C` in the new grid, and the best one it found (about 0.204) is close to the one `RandomizedSearchCV` picked (about 0.234), so the test accuracy is unchanged.\n\nNote: If there are a large number of hyperparameter combinations in your grid, `GridSearchCV` may take a long time to try them all out. This is why it's a good idea to start with `RandomizedSearchCV`, try a certain number of combinations and then use `GridSearchCV` to refine them."
},
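{
"metadata": {},
"cell_type": "markdown",
"source": "To put a number on how expensive a full grid search can get, here's a quick count of the combinations in `rf_grid` from earlier. The `range` calls are written to produce the same values as the `np.arange` calls in that grid (both stop before the end value).

```python
import math

# Same values as rf_grid, written with range() instead of np.arange()
rf_grid_sizes = {
    'n_estimators': len(range(10, 1000, 50)),   # 10, 60, ..., 960
    'max_depth': len([None, 3, 5, 10]),
    'min_samples_split': len(range(2, 20, 2)),  # 2, 4, ..., 18
    'min_samples_leaf': len(range(1, 20, 2)),   # 1, 3, ..., 19
}

# The number of combinations is the product of the number of values per hyperparameter
n_combinations = math.prod(rf_grid_sizes.values())
cv = 5
print(n_combinations, 'combinations,', n_combinations * cv, 'fits')  # 7200 combinations, 36000 fits
```

Trying all of them with `GridSearchCV` would mean 36,000 model fits, compared with the 100 fits `RandomizedSearchCV` ran (20 candidates x 5 folds), which already took about 2.2 minutes."
},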
{
"metadata": {},
"cell_type": "markdown",
"source": "## Evaluating a classification model, beyond accuracy\n\nNow that we've got a tuned model, let's get some of the metrics we discussed before.\n\nWe want:\n* ROC curve and AUC score - `plot_roc_curve()`.\n* Confusion matrix - `confusion_matrix()`.\n* Classification report - `classification_report()`.\n* Precision - `precision_score()`.\n* Recall - `recall_score()`.\n* F1-score - `f1_score()`.\n\nLuckily, Scikit-Learn has these all built-in.\n\nTo access them, we'll have to use our model to make predictions on the test set. You can make predictions by calling `predict()` on a trained model and passing it the data you'd like to predict on.\n\nWe'll make predictions on the test data."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Make predictions with tuned model\ny_preds = gs_log_reg.predict(X_test)",
"execution_count": 83,
"outputs": []
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Let's see them."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "y_preds",
"execution_count": 84,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 84,
"data": {
"text/plain": "array([0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0,\n 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n 1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0], dtype=int64)"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "They look like our original test data labels, except where the model has predicted incorrectly."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "np.array(y_test)",
"execution_count": 85,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 85,
"data": {
"text/plain": "array([0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0,\n 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,\n 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0], dtype=int64)"
},
"metadata": {}
}
]
},
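{
"metadata": {},
"cell_type": "markdown",
"source": "For a classifier, `score()` reports accuracy, which is just the fraction of positions where the prediction matches the true label. A minimal sketch in plain Python, with the two arrays above copied in by hand, that computes the accuracy and the four confusion-matrix counts (treating 1, heart disease, as the positive class):

```python
# Predictions and true labels copied from the two outputs above
y_preds = [0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0,
           0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,
           1, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0]
y_test = [0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0,
          0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1,
          1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 0]

pairs = list(zip(y_test, y_preds))
# Accuracy: share of positions where prediction and label agree
accuracy = sum(t == p for t, p in pairs) / len(pairs)

# Confusion-matrix counts, with 1 (heart disease) as the positive class
tn = sum(t == 0 and p == 0 for t, p in pairs)
fp = sum(t == 0 and p == 1 for t, p in pairs)
fn = sum(t == 1 and p == 0 for t, p in pairs)
tp = sum(t == 1 and p == 1 for t, p in pairs)

print(accuracy)        # 0.8852459016393442
print(tn, fp, fn, tp)  # 25 4 3 29
```

54 of the 61 predictions are correct, which gives the same 0.885 that `gs_log_reg.score(X_test, y_test)` reported. Scikit-Learn's `confusion_matrix(y_test, y_preds)` lays these counts out as `[[tn, fp], [fn, tp]]`."
},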
{
"metadata": {},
"cell_type": "markdown",
"source": "Since we've got our prediction values, we can find the metrics we want.\n\nLet's start with the ROC curve and AUC scores.\n\n### ROC Curve and AUC Scores\n\nWhat's a ROC curve?\n\nIt's a way of understanding how your model is performing by comparing the true positive rate to the false positive rate.\n\nIn our case...\n> To get an appropriate example in a real-world problem, consider a diagnostic test that seeks to determine whether a person has a certain disease. A false positive in this case occurs when the person tests positive, but does not actually have the disease.\nA false negative, on the other hand, occurs when the person tests negative, suggesting they are healthy, when they actually do have the disease.\n\nScikit-Learn implements a function `plot_roc_curve` which can help us create a ROC curve as well as calculate the area under the curve (AUC) metric. (In newer versions of Scikit-Learn, `plot_roc_curve` has been removed in favour of `RocCurveDisplay.from_estimator`, which takes the same arguments.)\n\nReading the documentation on the `plot_roc_curve` function, we can see it takes `(estimator, X, y)` as inputs, where `estimator` is a fitted machine learning model and `X` and `y` are the data you'd like to test on.\n\nIn our case, we'll use the `GridSearchCV` version of our `LogisticRegression` estimator, `gs_log_reg`, as well as the test data, `X_test` and `y_test`."
},
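{
"metadata": {},
"cell_type": "markdown",
"source": "Behind the scenes, a ROC curve comes from sweeping a decision threshold over the model's predicted probabilities and recording the false positive rate and true positive rate at each threshold; the AUC is the area under the resulting line. The sketch below does this in plain Python on a tiny made-up example. `roc_points` and `trapezoid_auc` are hypothetical helpers, and Scikit-Learn's `roc_curve` can drop intermediate points that don't change the shape of the curve, so its output may be shorter.

```python
def roc_points(y_true, scores):
    # Sweep thresholds from the highest score down; a sample is predicted
    # positive when its score is >= the threshold
    pos = sum(1 for y in y_true if y == 1)
    neg = len(y_true) - pos
    points = [(0.0, 0.0)]
    for t in sorted(set(scores), reverse=True):
        tp = sum(1 for y, s in zip(y_true, scores) if s >= t and y == 1)
        fp = sum(1 for y, s in zip(y_true, scores) if s >= t and y == 0)
        # (false positive rate, true positive rate) at this threshold
        points.append((fp / neg, tp / pos))
    return points


def trapezoid_auc(points):
    # Trapezoidal rule over (FPR, TPR) points, already ordered by FPR
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x1 - x0) * (y0 + y1) / 2
    return area


# Made-up labels and predicted probabilities
y_true = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
points = roc_points(y_true, scores)
print(points)
print(trapezoid_auc(points))  # 0.75
```

A model that ranks every positive sample above every negative one would reach an AUC of 1.0, while random guessing hovers around 0.5."
},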
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Plot ROC curve and calculate AUC metric\nplot_roc_curve(gs_log_reg, X_test, y_test)",
"execution_count": 86,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 86,
"data": {
"text/plain": "<sklearn.metrics.plot.roc_curve.RocCurveDisplay at 0x1f4beaa6ee0>"
},
"metadata": {}
},
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYIAAAEGCAYAAABo25JHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3de3wV1fnv8c8jikAlWkE5aIQEQSFAQIgRbFVA6Q+shXqHeuUnUqm3n5Qe7bGt1F6OpbbeSqFUraLctFUBRTlVQesFBQQRQrWIqIG8MCKIVe48548Z0k2ys7NDMnuTzPf9euXlnpm1Z54V4n72WmtmLXN3REQkvg7KdgAiIpJdSgQiIjGnRCAiEnNKBCIiMadEICIScwdnO4Daat26tefl5WU7DBGRBmXJkiWfuvtRyY41uESQl5fH4sWLsx2GiEiDYmYfVndMXUMiIjGnRCAiEnNKBCIiMadEICISc0oEIiIxF1kiMLMHzewTM1tRzXEzs3vNbLWZLTezXlHFIiIi1YuyRfAQMCjF8cFAp/BnFDAxwlhERKQakT1H4O4vm1leiiJDgSkezIO90MyOMLO27l4WVUwSP9Pe+IhZy9ZlOwyRelFwTA63fadrvZ83m2MExwIfJ2yXhvuqMLNRZrbYzBaXl5dnJDhpHGYtW0dJ2ZZshyFyQMvmk8WWZF/SVXLcfTIwGaCoqEgr6UitFLTNYeb3+2Y7DJEDVjZbBKXAcQnbucD6LMUiIhJb2UwEs4HLw7uH+gCfa3xARCTzIusaMrPpQD+gtZmVArcBhwC4+yRgLnA2sBr4ChgRVSwiIlK9KO8aGl7DcQeujer6cuDI5p07JWVbKGibk5VrizQUerJYIpfNO3cK2uYwtGfSm9FEJNTg1iOQhkl37ogcuNQiEBGJOSUCEZGYUyIQEYk5JQIRkZjTYHED1lAmVNMtnCIHNrUIGrCGMqGabuEUObCpRdDA6bZMEakrtQhERGJOiUBEJOaUCEREYk5jBAeY2twJpLtxRKQ+qEVwgKnNnUC6G0dE6oNaBAcg3QkkIpmkFoGISMwpEYiIxJwSgYhIzCkRiIjEnBKBiEjMKRGIiMScEoGISMwpEYiIxJwSgYhIzCkRiIjEnBKBiEjMKRGIiMScEoGISMwpEYiIxJwSgYhIzCkRiIjEXKSJwMwGmdm7ZrbazG5Jcrydmc03s6VmttzMzo4yHhERqSqyRGBmTYAJwGCgABhuZgWViv0EeMzdTwKGAX+MKh4REUkuyhZBMbDa3de4+w5gBjC0UhkH9q6+fjiwPsJ4REQkiSgTwbHAxwnbpeG+ROOAS82sFJgLXJ/sRGY2yswWm9ni8vLyKGIVEYmtKBOBJdnnlbaHAw+5ey5wNvCImVWJyd0nu3uRuxcdddRREYQqIhJfUSaCUuC4hO1cqnb9XAU8BuDurwPNgNYRxiQiIpVEmQgWAZ3MLN/MmhIMBs+uVOYj4EwAM+tCkAjU9yMikkGRJQJ33wVcB8wDVhHcHbTSzG43syFhsR8CV5vZ28B04Ep3r9x9JCIiETo4ypO7+1yCQeDEfT9LeF0CfCPKGEREJDU9WSwiEnNKBCIiMadEICISc0oEIiIxp0QgIhJzSgQiIjGnRCAiEnNKBCIiMadEICISc2k9WRzOFdTO3VdHHE+jNe2Nj5i1bF2N5UrKtlDQNqfGciIi9aXGFoGZfRt4B/h7uN3TzJ6MOrDGZtaydZSUbamxXEHbHIb2rLxsg4hIdNJpEdwOnALMB3D3ZWbWMdKoGqmCtjnM/H7fbIchIrKPdMYIdrr75kr7NEOoiEgjkU6LYJWZXQQcZGb5wI3AwmjDEhGRTEmnRXAd0BvYAzwBbCNIBiIi0gik0yL4L3e/Gbh57w4zO48gKYiISAOXTovgJ0n23VrfgYiISHZU2yIws/8CBgHHmtnvEw7lEHQTiYhII5Cqa+gTYAXBmMDKhP1fAL
dEGZSIiGROtYnA3ZcCS81sqrtvy2BMIiKSQekMFh9rZr8CCoBme3e6+wmRRSUiIhmTzmDxQ8BfAAMGA48BMyKMSUREMiidRNDC3ecBuPv77v4ToH+0YYmISKak0zW03cwMeN/MrgHWAUdHG5aIiGRKOongJuAw4AbgV8DhwH9HGZSIiGROjYnA3d8IX34BXAZgZrlRBiUiIpmTcozAzE42s++aWetwu6uZTUGTzomINBrVJgIz+7/AVOAS4Dkzu5VgTYK3Ad06KiLSSKTqGhoK9HD3rWZ2JLA+3H43M6GJiEgmpOoa2ubuWwHc/TPgn0oCIiKNT6oWQQcz2zvVtAF5Cdu4+3k1ndzMBgH3AE2A+939jiRlLgLGEax69ra7fy/98EVEpK5SJYLzK23/oTYnNrMmwARgIFAKLDKz2e5eklCmE/Bj4BvuvsnM9HyCiEiGpZp07oU6nrsYWO3uawDMbAbBuENJQpmrgQnuvim85id1vKaIiNRSOlNM7K9jgY8TtkvDfYlOAE4ws1fNbGHYlVSFmY0ys8Vmtri8vDyicEVE4imdJ4v3lyXZ50mu3wnoB+QC/zCzbu6+eZ83uU8GJgMUFRVVPke9m/bGR8xatq5ez1lStoWCtjn1ek4RkfqQdovAzA6t5blLgeMStnMJbkGtXGaWu+909w+AdwkSQ1bNWraOkrIt9XrOgrY5DO1ZuUEkIpJ9NbYIzKwYeIBgjqF2ZtYDGOnu19fw1kVAJzPLJ5iobhhQ+Y6gp4DhwEPh08snAGtqV4VoFLTNYeb3+2Y7DBGRyKXTIrgXOAfYCODub5PGNNTuvgu4DpgHrAIec/eVZna7mQ0Ji80DNppZCcFTyz9y9421r4aIiOyvdMYIDnL3D4OZqCvsTufk7j4XmFtp388SXjswJvwREZEsSCcRfBx2D3n4bMD1wHvRhiUiIpmSTtfQaIJv7O2ADUCfcJ+IiDQC6bQIdrn7sMgjERGRrEinRbDIzOaa2RVm1jLyiEREJKNqTATufjzwS6A38I6ZPWVmaiGIiDQSaT1Q5u6vufsNQC9gC8GCNSIi0gjUmAjM7DAzu8TM5gBvAuXAqZFHJiIiGZHOYPEKYA4w3t3/EXE8IiKSYekkgg7uvifySEREJCuqTQRm9jt3/yHwNzOrMuNnOiuUiYjIgS9Vi2Bm+N9arUwmIiINS6oVyt4MX3Zx932SgZldB9R1BTMRETkApHP76H8n2XdVfQciIiLZkWqM4GKCNQTyzeyJhEMtgc3J3yUiIg1NqjGCNwnWIMgFJiTs/wJYGmVQIiKSOanGCD4APgCez1w4IiKSaam6hl5y9zPMbBP7LjpvBGvKHBl5dCIiErlUXUN7l6NsnYlAREQkO6q9ayjhaeLjgCbuvhvoC3wf+FoGYhMRkQxI5/bRpwiWqTwemAJ0AaZFGpWIiGRMOolgj7vvBM4D7nb364Fjow1LREQyJZ1EsMvMLgQuA54O9x0SXUgiIpJJ6T5Z3J9gGuo1ZpYPTI82LBERyZQap6F29xVmdgPQ0cw6A6vd/VfRhyYiIplQYyIws9OAR4B1BM8Q/C8zu8zdX406OBERiV46C9PcBZzt7iUAZtaFIDEURRmYiIhkRjpjBE33JgEAd18FNI0uJBERyaR0WgRvmdmfCFoBAJegSedERBqNdBLBNcANwP8mGCN4GbgvyqBERCRzUiYCM+sOHA886e7jMxOSiIhkUrVjBGb2fwiml7gE+LuZJVupTEREGrhUg8WXAIXufiFwMjC6tic3s0Fm9q6ZrTazW1KUu8DM3Mx0J5KISIalSgTb3f1LAHcvr6FsFWbWhGBls8FAATDczAqSlGtJMAbxRm3OLyIi9SPVGEGHhLWKDTg+ce1idz+vhnMXEzyFvAbAzGYAQ4GSSuV+AYwHxtYmcBERqR+pEsH5lbb/UMtzHwt8nLBdCpySWMDMTgKOc/enzazaRGBmo4BRAO3atatlGCIikkqqNYtfqOO5LdlpKw
6aHUTw1PKVNZ3I3ScDkwGKioq8huIiIlILter3r6VSgtXN9soF1idstwS6AQvMbC3QB5itAWMRkcyKMhEsAjqZWb6ZNQWGAbP3HnT3z929tbvnuXsesBAY4u6LI4xJREQqSefJYgDM7FB3355ueXffZWbXAfOAJsCD7r7SzG4HFrv77NRnqF/T3viIWcvWpVW2pGwLBW1zIo5IROTAkM401MXAA8DhQDsz6wGMDJesTMnd5wJzK+37WTVl+6UT8P6atWxd2h/wBW1zGNpTq3GKSDyk0yK4FziH4Clj3P1tM+sfaVQRKWibw8zv9812GCIiB5R0xggOcvcPK+3bHUUwIiKSeem0CD4Ou4c8fFr4euC9aMMSEZFMSadFMBoYA7QDNhDc5lnreYdEROTAlM7i9Z8Q3PopIiKNUDp3Df2ZhCeC93L3UZFEJCIiGZXOGMHzCa+bAeey7xxCIiLSgKXTNTQzcdvMHgH+HllEIiKSUfszxUQ+0L6+AxERkexIZ4xgE/8ZIzgI+AyodrUxERFpWGpavN6AHsDeSXr2uLumgRYRaURSdg2FH/pPuvvu8EdJQESkkUlnjOBNM+sVeSQiIpIV1XYNmdnB7r4L+CZwtZm9D3xJsPKYu7uSg4hII5BqjOBNoBfw3QzFIiIiWZAqERiAu7+foVhERCQLUiWCo8xsTHUH3f33EcQjIiIZlioRNAEOI2wZiIhI45QqEZS5++0Zi0RERLIi1e2jagmIiMRAqkRwZsaiEBGRrKk2Ebj7Z5kMREREsmN/Zh8VEZFGRIlARCTmlAhERGJOiUBEJOaUCEREYk6JQEQk5pQIRERiTolARCTmIk0EZjbIzN41s9VmVmXBezMbY2YlZrbczF4ws/ZRxiMiIlVFlgjMrAkwARgMFADDzaygUrGlQJG7FwJ/BcZHFY+IiCQXZYugGFjt7mvcfQcwAxiaWMDd57v7V+HmQiA3wnhERCSJKBPBscDHCdul4b7qXAU8m+yAmY0ys8Vmtri8vLweQxQRkSgTQbJprD1pQbNLgSLgt8mOu/tkdy9y96KjjjqqHkMUEZFUC9PUVSlwXMJ2LrC+ciEzOwu4FTjD3bdHGI+IiCQRZYtgEdDJzPLNrCkwDJidWMDMTgL+BAxx908ijEVERKoRWSJw913AdcA8YBXwmLuvNLPbzWxIWOy3BOsiP25my8xsdjWnExGRiETZNYS7zwXmVtr3s4TXZ0V5fRERqZmeLBYRiTklAhGRmFMiEBGJOSUCEZGYUyIQEYk5JQIRkZhTIhARiTklAhGRmFMiEBGJOSUCEZGYUyIQEYk5JQIRkZhTIhARiTklAhGRmFMiEBGJOSUCEZGYUyIQEYk5JQIRkZhTIhARiTklAhGRmFMiEBGJuYOzHYDIgWznzp2Ulpaybdu2bIcikpZmzZqRm5vLIYcckvZ7lAhEUigtLaVly5bk5eVhZtkORyQld2fjxo2UlpaSn5+f9vvUNSSSwrZt22jVqpWSgDQIZkarVq1q3YJVIhCpgZKANCT78/eqRCAiEnNKBCIHuA0bNvC9732PDh060Lt3b/r27cuTTz6ZtOz69eu54IILkh7r168fixcvBuDBBx+ke/fuFBYW0q1bN2bNmhVZ/AB5eXl8+umnSY89++yzFBUV0aVLFzp37szYsWNZsGABffv23afcrl27aNOmDWVlZVXOcffddzNlypR9yrZu3Zof//jHKeNYsGAB55xzTspY6mrJkiV0796djh07csMNN+DuVcps2rSJc889l8LCQoqLi1mxYgUQdE0WFxfTo0cPunbtym233VbxnmHDhvGvf/2rzvEBweBCQ/rp3bu374+LJr3mF016bb/eK/FVUlKS1evv2bPH+/Tp4xMnTqzYt3btWr/33nurlN25c2fKc51xxhm+aNEi//jjj71Dhw6+efNmd3f/4osvfM2aNXWONdX127dv7+Xl5VX2v/POO96hQwdftWpVxTkmTJjgu3fv9tzcXP
/ggw8qyj777LM+YMCApNft3r37Ptd/5pln/NRTT/UOHTr4nj17qo1j/vz5/u1vfztlLHV18skn+2uvveZ79uzxQYMG+dy5c6uUGTt2rI8bN87d3VetWlVRzz179vgXX3zh7u47duzw4uJif/31193dfcGCBT5y5Mik10z2dwss9mo+V3XXkEiafj5nJSXrt9TrOQuOyeG273St9viLL75I06ZNueaaayr2tW/fnuuvvx6Ahx56iGeeeYZt27bx5Zdf8uCDD3LOOeewYsUKtm7dyogRIygpKaFLly5s3boVgE8++YSWLVty2GGHAXDYYYdVvH7//fe59tprKS8vp0WLFvz5z3+mc+fOzJkzh1/+8pfs2LGDVq1aMXXqVNq0acO4ceNYv349a9eupXXr1jzyyCPcfPPNzJs3DzPj6quvroj1vvvuY86cOezcuZPHH3+czp07M378eG699VY6d+4MwMEHH8wPfvADAC688EJmzpzJzTffDMCMGTMYPnx40t9Rr169OPjg/3ycTZ8+nRtvvJGJEyeycOHCKq2LZFLFsr/KysrYsmVLxfUvv/xynnrqKQYPHrxPuZKSkorWS+fOnVm7di0bNmygTZs2Ff82O3fuZOfOnRVjAKeddhpXXnklu3bt2qfu+0NdQyIHsJUrV9KrV6+UZV5//XUefvhhXnzxxX32T5w4kRYtWrB8+XJuvfVWlixZAkCPHj1o06YN+fn5jBgxgjlz5lS8Z9SoUdx3330sWbKEO++8s+KD8Jvf/CYLFy5k6dKlDBs2jPHjx1e8Z8mSJcyaNYtp06YxefJkPvjgA5YuXcry5cu55JJLKsq1bt2at956i9GjR3PnnXcCsGLFCnr37p20XsOHD2fGjBkAbN++nblz53L++edXKffqq6/uc46tW7fywgsvcM455zB8+HCmT5+e8ve3V6pYEs2fP5+ePXtW+Tn11FOrlF23bh25ubkV27m5uaxbt65KuR49evDEE08A8Oabb/Lhhx9SWloKwO7du+nZsydHH300AwcO5JRTTgHgoIMOomPHjrz99ttp1S8VtQhE0pTqm3umXHvttbzyyis0bdqURYsWATBw4ECOPPLIKmVffvllbrjhBgAKCwspLCwEoEmTJjz33HMsWrSIF154gZtuuoklS5YwduxYXnvtNS688MKKc2zfvh0Inqe4+OKLKSsrY8eOHfvcoz5kyBCaN28OwPPPP88111xT8Q01Ma7zzjsPgN69e1d86KVy8skn8+9//5t3332XVatW0adPH77+9a9XKVdWVkaXLl0qtp9++mn69+9PixYtOP/88/nFL37BXXfdRZMmTZLeUVPbu2z69+/PsmXL0irrScYDkl3vlltu4cYbb6Rnz550796dk046qeJ32KRJE5YtW8bmzZs599xzWbFiBd26dQPg6KOPZv369WklsFQiTQRmNgi4B2gC3O/ud1Q6figwBegNbAQudve1UcYk0pB07dqVv/3tbxXbEyZM4NNPP6WoqKhi39e+9rVq31/dh5yZUVxcTHFxMQMHDmTEiBGMGTOGI444IumH3PXXX8+YMWMYMmQICxYsYNy4cUmv7+7VXvPQQw8Fgg+2Xbt2VdRvyZIl9OjRI+l7hg0bxowZM1i1alXSbiGA5s2b73Pf/PTp03n11VfJy8sDYOPGjcyfP5+zzjqLVq1asWnTJlq3bg3AZ599VvG6plj2mj9/PjfddFOV/S1atOC1117bZ19ubm7FN3sIEuoxxxxT5b05OTn85S9/AYLfYX5+fpUHwo444gj69evHc889V5EItm3bVpGE6yKyriEzawJMAAYDBcBwMyuoVOwqYJO7dwTuAn4TVTwiDdGAAQPYtm0bEydOrNj31VdfpfXe008/nalTpwJBt8fy5cuB4M6it956q6LcsmXLaN++PTk5OeTn5/P4448DwQfS3m6Hzz//nGOPPRaAhx9+uNprfutb32LSpEkVH/SfffZZyhh/9KMf8etf/5
r33nsPgD179vD73/++4vjw4cN59NFHefHFFxkyZEjSc3Tp0oXVq1cDsGXLFl555RU++ugj1q5dy9q1a5kwYUJF91C/fv145JFHgKDL5dFHH6V///5pxbLX3hZB5Z/KSQCgbdu2tGzZkoULF+LuTJkyhaFDh1Ypt3nzZnbs2AHA/fffz+mnn05OTg7l5eVs3rwZCLq8nn/++YoxDID33nuPrl3r3lKNcoygGFjt7mvcfQcwA6j8GxgK7P2r+itwpunpHZEKZsZTTz3FSy+9RH5+PsXFxVxxxRX85jc1f2caPXo0//73vyksLGT8+PEUFxcDwaDj2LFj6dy5Mz179mTmzJncc889AEydOpUHHnig4nbFvbeVjhs3jgsvvJDTTjut4ht0MiNHjqRdu3YUFhbSo0cPpk2bljLGwsJC7r77boYPH06XLl3o1q3bPreHFhQU0KJFCwYMGFBty2fw4MG8/PLLADzxxBMMGDCgovUBMHToUGbPns327dv56U9/yurVq+nRowcnnXQSHTt25NJLL00rlv01ceJERo4cSceOHTn++OMrBoonTZrEpEmTAFi1ahVdu3alc+fOPPvssxX/HmVlZfTv35/CwkJOPvlkBg4cWHG764YNG2jevDlt27atc4yWrA+rPpjZBcAgdx8Zbl8GnOLu1yWUWRGWKQ233w/LfFrpXKOAUQDt2rXr/eGHH9Y6np/PWQkcGP280nCsWrVqn/5nOTCde+65jB8/nk6dOmU7lIy56667yMnJ4aqrrqpyLNnfrZktcfeiKoWJdowg2Tf7ylknnTK4+2RgMkBRUdF+ZS4lAJHG64477qCsrCxWieCII47gsssuq5dzRZkISoHjErZzgfXVlCk1s4OBw4HUnYoiIpWceOKJnHjiidkOI6NGjBhRb+eKcoxgEdDJzPLNrCkwDJhdqcxs4Irw9QXAix5VX5XIftKfpDQk+/P3GlkicPddwHXAPGAV8Ji7rzSz281s7/D/A0ArM1sNjAFuiSoekf3RrFkzNm7cqGQgDYKH6xE0a9asVu+LbLA4KkVFRb534iyRqGmFMmloqluhLFuDxSIN3iGHHFKrlZ5EGiLNNSQiEnNKBCIiMadEICIScw1usNjMyoHaP1ocaA0kXyap8VKd40F1joe61Lm9ux+V7ECDSwR1YWaLqxs1b6xU53hQneMhqjqra0hEJOaUCEREYi5uiWBytgPIAtU5HlTneIikzrEaIxARkari1iIQEZFKlAhERGKuUSYCMxtkZu+a2WozqzKjqZkdamYzw+NvmFle5qOsX2nUeYyZlZjZcjN7wczaZyPO+lRTnRPKXWBmbmYN/lbDdOpsZheF/9YrzSz1WpENQBp/2+3MbL6ZLQ3/vs/ORpz1xcweNLNPwhUckx03M7s3/H0sN7Nedb6ouzeqH6AJ8D7QAWgKvA0UVCrzA2BS+HoYMDPbcWegzv2BFuHr0XGoc1iuJfAysBAoynbcGfh37gQsBb4ebh+d7bgzUOfJwOjwdQGwNttx17HOpwO9gBXVHD8beJZghcc+wBt1vWZjbBEUA6vdfY277wBmAEMrlRkKPBy+/itwppklWzazoaixzu4+392/CjcXEqwY15Cl8+8M8AtgPNAY5pFOp85XAxPcfROAu3+S4RjrWzp1diAnfH04VVdCbFDc/WVSr9Q4FJjigYXAEWZWpxXsG2MiOBb4OGG7NNyXtIwHC+h8DrTKSHTRSKfOia4i+EbRkNVYZzM7CTjO3Z/OZGARSuff+QTgBDN71cwWmtmgjEUXjXTqPA641MxKgbnA9ZkJLWtq+/97jRrjegTJvtlXvkc2nTINSdr1MbNLgSLgjEgjil7KOpvZQcBdwJWZCigD0vl3Ppige6gfQavvH2bWzd03RxxbVNKp83DgIXf/nZn1BR4J67wn+vCyot4/vxpji6AUOC5hO5eqTcWKMmZ2MEFzMlVT7ECXTp0xs7OAW4Eh7r49Q7FFpaY6twS6AQvMbC
1BX+rsBj5gnO7f9ix33+nuHwDvEiSGhiqdOl8FPAbg7q8DzQgmZ2us0vr/vTYaYyJYBHQys3wza0owGDy7UpnZwBXh6wuAFz0chWmgaqxz2E3yJ4Ik0ND7jaGGOrv75+7e2t3z3D2PYFxkiLs35HVO0/nbforgxgDMrDVBV9GajEZZv9Kp80fAmQBm1oUgEZRnNMrMmg1cHt491Af43N3L6nLCRtc15O67zOw6YB7BHQcPuvtKM7sdWOzus4EHCJqPqwlaAsOyF3HdpVnn3wKHAY+H4+IfufuQrAVdR2nWuVFJs87zgG+ZWQmwG/iRu2/MXtR1k2adfwj82cxuIugiubIhf7Ezs+kEXXutw3GP24BDANx9EsE4yNnAauArYESdr9mAf18iIlIPGmPXkIiI1IISgYhIzCkRiIjEnBKBiEjMKRGIiMScEoEccMxst5ktS/jJS1E2r7pZGmt5zQXhDJdvh9MznLgf57jGzC4PX19pZsckHLvfzArqOc5FZtYzjff8j5m1qOu1pfFSIpAD0VZ375nwszZD173E3XsQTEj429q+2d0nufuUcPNK4JiEYyPdvaReovxPnH8kvTj/B1AikGopEUiDEH7z/4eZvRX+nJqkTFczezNsRSw3s07h/ksT9v/JzJrUcLmXgY7he88M57l/J5wn/tBw/x32n/Ud7gz3jTOzsWZ2AcF8TlPDazYPv8kXmdloMxufEPOVZnbffsb5OgmTjZnZRDNbbME6BD8P991AkJDmm9n8cN+3zOz18Pf4uJkdVsN1pJFTIpADUfOEbqEnw32fAAPdvRdwMXBvkvddA9zj7j0JPohLwykHLga+Ee7fDVxSw/W/A7xjZs2Ah4CL3b07wZP4o83sSOBcoKu7FwK/THyzu/8VWEzwzb2nu29NOPxX4LyE7YuBmfsZ5yCCKSX2utXdi4BC4AwzK3T3ewnmoenv7v3DaSd+ApwV/i4XA2NquI40co1uiglpFLaGH4aJDgH+EPaJ7yaYQ6ey14FbzSwXeMLd/2VmZwK9gUXh1BrNCZJKMlPNbCuwlmAq4xOBD9z9vfD4w8C1wB8I1je438yeAdKe5trdy81sTThHzL/Ca7wanrc2cX6NYMqFxNWpLjKzUaASv38AAAGySURBVAT/X7clWKRleaX39gn3vxpepynB701iTIlAGoqbgA1AD4KWbJWFZtx9mpm9AXwbmGdmIwmm7H3Y3X+cxjUuSZyUzsySrlERzn9TTDDR2TDgOmBALeoyE7gI+CfwpLu7BZ/KacdJsFLXHcAE4DwzywfGAie7+yYze4hg8rXKDPi7uw+vRbzSyKlrSBqKw4GycI75ywi+De/DzDoAa8LukNkEXSQvABeY2dFhmSMt/fWa/wnkmVnHcPsy4KWwT/1wd59LMBCb7M6dLwimwk7mCeC7BPPozwz31SpOd99J0MXTJ+xWygG+BD43szbA4GpiWQh8Y2+dzKyFmSVrXUmMKBFIQ/FH4AozW0jQLfRlkjIXAyvMbBnQmWA5vxKCD8z/Z2bLgb8TdJvUyN23Eczs+LiZvQPsASYRfKg+HZ7vJYLWSmUPAZP2DhZXOu8moARo7+5vhvtqHWc49vA7YKy7v02wVvFK4EGC7qa9JgPPmtl8dy8nuKNpenidhQS/K4kxzT4qIhJzahGIiMScEoGISMwpEYiIxJwSgYhIzCkRiIjEnBKBiEjMKRGIiMTc/wcqA8pb5yvsOQAAAABJRU5ErkJggg==\n"
},
"metadata": {
"needs_background": "light"
}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "This is great, our model does far better than guessing, which would be a line going from the bottom left corner to the top right corner (AUC = 0.5). But a perfect model would achieve an AUC score of 1.0, so there's still room for improvement.\n\nLet's move on to the next evaluation request, a confusion matrix.\n\n### Confusion matrix\n\nA confusion matrix is a visual way to show where your model made the right predictions and where it made the wrong predictions (or in other words, got confused).\n\nScikit-Learn allows us to create a confusion matrix using confusion_matrix() and passing it the true labels and predicted labels."
},
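{
"metadata": {},
"cell_type": "markdown",
"source": "Before building the confusion matrix, a quick way to see where the AUC reference points of 0.5 (guessing) and 1.0 (perfect) come from is Scikit-Learn's roc_auc_score(). The sketch below uses a handful of made-up labels and scores rather than the notebook's data: a scorer that gives every sample the same score can't rank anyone, so it lands on 0.5, while one that ranks every positive above every negative reaches 1.0.

```python
# Minimal sketch: where the AUC reference points of 0.5 and 1.0 come from.
# The labels and scores below are made-up toy values, not the notebook's data.
from sklearn.metrics import roc_auc_score

y_true = [0, 0, 1, 1]

# A scorer that gives every sample the same score can't rank positives above
# negatives; ties count as half, so the AUC lands exactly on 0.5.
constant_scores = [0.5, 0.5, 0.5, 0.5]
auc_guess = roc_auc_score(y_true, constant_scores)

# A scorer that ranks every positive above every negative gets the maximum AUC.
perfect_scores = [0.1, 0.2, 0.8, 0.9]
auc_perfect = roc_auc_score(y_true, perfect_scores)

print(auc_guess, auc_perfect)  # 0.5 1.0
```

Our model's curve sits between these two extremes, which is why there's still room for improvement."
},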
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Confusion matrix\nprint(confusion_matrix(y_test, y_preds))",
"execution_count": 87,
"outputs": [
{
"output_type": "stream",
"text": "[[25 4]\n [ 3 29]]\n",
"name": "stdout"
}
]
},
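{
"metadata": {},
"cell_type": "markdown",
"source": "It helps to know how Scikit-Learn lays this matrix out before reading it: rows are the true labels and columns are the predicted labels, so for labels 0 and 1 the flattened order is true negatives, false positives, false negatives, true positives. A minimal sketch, where unpack_confusion() is a hypothetical helper name and the labels are toy values:

```python
# Minimal sketch: reading Scikit-Learn's confusion matrix layout.
# Rows are true labels and columns are predicted labels, so for binary
# labels 0/1 the flattened order is [tn, fp, fn, tp].
from sklearn.metrics import confusion_matrix

def unpack_confusion(y_true, y_pred):
    """Return the four cells of a binary confusion matrix as a dict."""
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return {"tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp)}

# Toy labels (hypothetical, not the notebook's y_test / y_preds)
print(unpack_confusion([0, 0, 0, 1, 1], [0, 1, 1, 0, 1]))
# {'tn': 1, 'fp': 2, 'fn': 1, 'tp': 1}
```

Read the same way, the matrix printed above ([[25 4] [3 29]]) has 25 true negatives, 4 false positives, 3 false negatives and 29 true positives."
},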
{
"metadata": {},
"cell_type": "markdown",
"source": "As you can see, Scikit-Learn's built-in confusion matrix is a bit bland. For a presentation you'd probably want to make it visual.\n\nLet's create a function which uses Seaborn's heatmap() for doing so."
},
{
"metadata": {
"run_control": {
"marked": false
},
"trusted": true
},
"cell_type": "code",
"source": "sns.set(font_scale=1.5)\n\ndef plot_conf_mat(y_test, y_preds):\n """\n Plots a nice looking confusion matrix using seaborn's heatmap().\n Rows (y-axis) are true labels, columns (x-axis) are predicted labels.\n """\n fig, ax = plt.subplots(figsize=(3, 3))\n ax = sns.heatmap(confusion_matrix(y_test, y_preds),\n annot=True,\n cbar=False)\n plt.xlabel("Predicted label")\n plt.ylabel("True label")\n \nplot_conf_mat(y_test, y_preds)",
"execution_count": 88,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 216x216 with 1 Axes>",
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Beautiful! That looks much better.\n\nYou can see the model gets confused (predicts the wrong label) at a similar rate for both classes. Since Scikit-Learn puts the true labels on the rows and the predicted labels on the columns, there are 4 occasions where the model predicted 1 when it should've been 0 (false positives) and 3 occasions where the model predicted 0 instead of 1 (false negatives).\n\n### Classification report\n\nWe can make a classification report using classification_report() and passing it the true labels as well as our model's predicted labels.\n\nA classification report will also give us information on the precision and recall of our model for each class."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "print(classification_report(y_test, y_preds))",
"execution_count": 89,
"outputs": [
{
"output_type": "stream",
"text": " precision recall f1-score support\n\n 0 0.89 0.86 0.88 29\n 1 0.88 0.91 0.89 32\n\n accuracy 0.89 61\n macro avg 0.89 0.88 0.88 61\nweighted avg 0.89 0.89 0.89 61\n\n",
"name": "stdout"
}
]
},
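{
"metadata": {},
"cell_type": "markdown",
"source": "The class 1 row of this report can be reproduced by hand from the confusion matrix counts printed earlier (25 true negatives, 4 false positives, 3 false negatives, 29 true positives). The sketch below rebuilds label arrays with exactly those counts, applies the precision, recall and F1 formulas, and checks them against Scikit-Learn's metric functions.

```python
# Minimal sketch: reproducing the class-1 row of the classification report
# from the confusion matrix counts printed earlier ([[25 4] [3 29]]).
import numpy as np
from sklearn.metrics import f1_score, precision_score, recall_score

tn, fp, fn, tp = 25, 4, 3, 29

# Rebuild label arrays with exactly those counts (order doesn't matter)
y_true = np.array([0] * (tn + fp) + [1] * (fn + tp))
y_pred = np.array([0] * tn + [1] * fp + [0] * fn + [1] * tp)

precision = tp / (tp + fp)  # of the predicted 1s, how many were right
recall = tp / (tp + fn)     # of the actual 1s, how many were found
f1 = 2 * precision * recall / (precision + recall)  # harmonic mean

print(round(precision, 2), round(recall, 2), round(f1, 2))  # 0.88 0.91 0.89

# The hand-computed values agree with Scikit-Learn's functions
assert np.isclose(precision, precision_score(y_true, y_pred))
assert np.isclose(recall, recall_score(y_true, y_pred))
assert np.isclose(f1, f1_score(y_true, y_pred))
```

Swapping the roles of the classes (treating 0 as the positive class) gives the class 0 row in the same way."
},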
{
"metadata": {},
"cell_type": "markdown",
"source": "What's going on here?\n\nLet's get a refresher.\n* Precision - Indicates the proportion of positive identifications (model predicted class 1) which were actually correct. A model which produces no false positives has a precision of 1.0.\n* Recall - Indicates the proportion of actual positives which were correctly classified. A model which produces no false negatives has a recall of 1.0.\n* F1 score - The harmonic mean of precision and recall. A perfect model achieves an F1 score of 1.0.\n* Support - The number of samples each metric was calculated on.\n* Accuracy - The accuracy of the model in decimal form. Perfect accuracy is equal to 1.0.\n* Macro avg - Short for macro average, the unweighted average precision, recall and F1 score across classes. Macro avg doesn't take class imbalance into account, so if you do have class imbalances, pay attention to this metric.\n* Weighted avg - Short for weighted average, the weighted average precision, recall and F1 score across classes. Weighted means each class's metric counts in proportion to how many samples there are in that class. This metric will favour the majority class (e.g. it can give a high value when the larger class performs well, even if a smaller class does not).\n\nOk, now we've got a few deeper insights on our model. But these were all calculated using a single training and test set.\n\nTo make them more solid, we'll calculate them using cross-validation.\n\nHow?\n\nWe'll take the best model along with the best hyperparameters and use cross_val_score() along with various scoring parameter values.\n\ncross_val_score() works by taking an estimator (machine learning model) along with data and labels. It then evaluates the machine learning model on the data and labels using cross-validation and a defined scoring parameter.\n\nLet's remind ourselves of the best hyperparameters and then see them in action."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Check best hyperparameters\ngs_log_reg.best_params_",
"execution_count": 90,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 90,
"data": {
"text/plain": "{'C': 0.20433597178569418, 'solver': 'liblinear'}"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Create a new classifier with best parameters\nclf = LogisticRegression(C=0.20433597178569418,\n solver="liblinear")",
"execution_count": 91,
"outputs": []
},
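{
"metadata": {},
"cell_type": "markdown",
"source": "Copying the tuned values by hand works, but a fitted search object's best_params_ dictionary can also be unpacked straight into a new model with **. A minimal sketch, assuming synthetic data from make_classification() in place of the heart disease data and a small made-up grid for C:

```python
# Minimal sketch on synthetic data (make_classification stands in for the
# heart disease data): passing a grid search's best_params_ straight into a
# new model instead of copying the values by hand.
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X_demo, y_demo = make_classification(n_samples=200, n_features=5, random_state=42)

grid = GridSearchCV(LogisticRegression(),
                    param_grid={"C": [0.01, 0.1, 1.0, 10.0],
                                "solver": ["liblinear"]},
                    cv=5)
grid.fit(X_demo, y_demo)

# ** unpacks the dict into keyword arguments, e.g. C=..., solver="liblinear"
clf_demo = LogisticRegression(**grid.best_params_)
print(grid.best_params_)
```

If the search was run with the default refit=True, best_estimator_ also holds a copy of the best model already refit on the full data passed to fit()."
},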
{
"metadata": {},
"cell_type": "markdown",
"source": "Now we've got an instantiated classifier, let's find some cross-validated metrics."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Cross-validated accuracy\ncv_acc = cross_val_score(clf,\n X,\n y,\n cv=5,\n scoring="accuracy")\ncv_acc",
"execution_count": 94,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 94,
"data": {
"text/plain": "array([0.81967213, 0.90163934, 0.8852459 , 0.88333333, 0.75 ])"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Since there are 5 scores here (one for each fold), we'll take their average."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "cv_acc = np.mean(cv_acc)\ncv_acc",
"execution_count": 95,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 95,
"data": {
"text/plain": "0.8479781420765027"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Now we'll do the same for other classification metrics."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Cross-validated precision\ncv_precision = cross_val_score(clf,\n X,\n y,\n cv=5,\n scoring="precision")\ncv_precision = np.mean(cv_precision)\ncv_precision",
"execution_count": 98,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 98,
"data": {
"text/plain": "0.8215873015873015"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Cross-validated recall\ncv_recall = cross_val_score(clf, \n X,\n y,\n cv=5,\n scoring="recall")\ncv_recall = np.mean(cv_recall)\ncv_recall",
"execution_count": 99,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 99,
"data": {
"text/plain": "0.9272727272727274"
},
"metadata": {}
}
]
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Cross-validated F1-score\ncv_f1 = cross_val_score(clf,\n X,\n y,\n cv=5,\n scoring="f1")\ncv_f1 = np.mean(cv_f1)\ncv_f1",
"execution_count": 100,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 100,
"data": {
"text/plain": "0.8705403543192143"
},
"metadata": {}
}
]
},
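{
"metadata": {},
"cell_type": "markdown",
"source": "Calling cross_val_score() once per metric re-runs the whole cross-validation each time. As an alternative, cross_validate() accepts a list of scorers and returns one array of per-fold scores for each, under keys like test_accuracy. A minimal sketch on synthetic data from make_classification() (standing in for X and y), reusing the tuned hyperparameters:

```python
# Minimal sketch on synthetic data: cross_validate can score several metrics
# in a single pass over the folds, instead of one cross_val_score call per metric.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_validate

X_demo, y_demo = make_classification(n_samples=300, n_features=13, random_state=0)
clf_demo = LogisticRegression(C=0.20433597178569418, solver="liblinear")

metric_names = ["accuracy", "precision", "recall", "f1"]
scores = cross_validate(clf_demo, X_demo, y_demo, cv=5, scoring=metric_names)

# Each metric comes back as an array with one score per fold ("test_<name>")
cv_summary = {name: (np.mean(scores[f"test_{name}"]), np.std(scores[f"test_{name}"]))
              for name in metric_names}
for name, (mean, std) in cv_summary.items():
    print(f"{name}: {mean:.3f} +/- {std:.3f}")
```

Reporting the standard deviation alongside the mean shows how much a score moves between folds; the notebook's own accuracy folds above range from 0.75 to about 0.90."
},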
{
"metadata": {},
"cell_type": "markdown",
"source": "Okay, we've got cross-validated metrics, now what?\n\nLet's visualize them."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Visualize cross-validated metrics\ncv_metrics = pd.DataFrame({"Accuracy": cv_acc,\n "Precision": cv_precision,\n "Recall": cv_recall,\n "F1": cv_f1},\n index=[0])\n\ncv_metrics.T.plot.bar(title="Cross-validated classification metrics",\n legend=False);",
"execution_count": 103,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "\n"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Great! This looks like something we could share. An extension might be adding the metrics on top of each bar so someone can quickly tell what they were.\n\nWhat now?\n\nThe final item on our list of model evaluation techniques is feature importance.\n\n### Feature importance\n\nFeature importance is another way of asking, "Which features contributed most to the outcomes of the model?"\n\nOr for our problem, trying to predict heart disease using a patient's medical characteristics, which characteristics contribute most to a model predicting whether someone has heart disease or not?\n\nBecause each model finds patterns in data slightly differently, each model also judges how important those patterns are differently. This means for each model, there's a slightly different way of finding which features were most important.\n\nYou can usually find an example via the Scikit-Learn documentation or via searching for something like "[MODEL TYPE] feature importance", such as, "random forest feature importance".\n\nSince we're using LogisticRegression, we'll look at one way we can calculate feature importance for it.\n\nTo do so, we'll use the coef_ attribute. Looking at the Scikit-Learn documentation for LogisticRegression, the coef_ attribute is the coefficient of the features in the decision function.\n\nWe can access the coef_ attribute after we've fit an instance of LogisticRegression."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Fit an instance of LogisticRegression\nclf = LogisticRegression(C=0.20433597178569418,\n solver="liblinear")\n\nclf.fit(X_train, y_train);",
"execution_count": 104,
"outputs": []
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Check coef_\nclf.coef_",
"execution_count": 105,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 105,
"data": {
"text/plain": "array([[ 0.00316728, -0.86044636, 0.66067051, -0.01156993, -0.00166374,\n 0.04386116, 0.31275829, 0.02459361, -0.60413071, -0.56862818,\n 0.45051626, -0.63609888, -0.67663381]])"
},
"metadata": {}
}
]
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Looking at this it might not make much sense. But these values are how much each feature contributes to how a model decides whether the patterns in a sample of a patient's health data lean more towards having heart disease or not.\n\nEven knowing this, in its current form, this coef_ array still doesn't mean much. But it will if we combine it with the columns (features) of our dataframe."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Match coef's of feature to columns\nfeature_dict = dict(zip(df.columns, list(clf.coef_[0])))\nfeature_dict",
"execution_count": 106,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 106,
"data": {
"text/plain": "{'age': 0.003167278990517721,\n 'sex': -0.8604463644626569,\n 'cp': 0.6606705054012705,\n 'trestbps': -0.011569931335912368,\n 'chol': -0.0016637438070300692,\n 'fbs': 0.04386116284216185,\n 'restecg': 0.31275829369964075,\n 'thalach': 0.02459361297137234,\n 'exang': -0.6041307139378419,\n 'oldpeak': -0.5686281825180214,\n 'slope': 0.450516263738603,\n 'ca': -0.6360988840661891,\n 'thal': -0.6766338062111971}"
},
"metadata": {}
}
]
},
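{
"metadata": {},
"cell_type": "markdown",
"source": "One caveat when reading these coefficients: the features here are on very different scales (chol is in the hundreds while sex is 0 or 1), so a coefficient's size reflects each feature's units as well as its importance. A common approach is to standardize the features before fitting, then rank the coefficients by absolute value. A minimal sketch on synthetic data, where f0 to f3 are hypothetical feature names and one feature is deliberately inflated by a factor of 1000:

```python
# Minimal sketch on synthetic data: scaling features before reading
# LogisticRegression coefficients, then ranking them by absolute size.
# The feature names are hypothetical placeholders.
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

X_demo, y_demo = make_classification(n_samples=300, n_features=4, n_informative=2,
                                     n_redundant=0, random_state=1)
X_demo[:, 0] *= 1000  # put one feature on a much larger scale than the others
feature_names = ["f0", "f1", "f2", "f3"]

pipe = make_pipeline(StandardScaler(), LogisticRegression(solver="liblinear"))
pipe.fit(X_demo, y_demo)

# pipe[-1] is the fitted LogisticRegression step of the pipeline
coefs = pd.Series(pipe[-1].coef_[0], index=feature_names)
# Sort by magnitude; the sign still tells you the direction of the effect
ranked = coefs.reindex(coefs.abs().sort_values(ascending=False).index)
print(ranked)
```

Because StandardScaler removes the factor of 1000 before the model sees the data, f0's coefficient here is comparable with the others."
},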
{
"metadata": {},
"cell_type": "markdown",
"source": "Now that we've matched the feature coefficients to the different features, let's visualize them."
},
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "# Visualize feature importance\nfeature_df = pd.DataFrame(feature_dict, index=[0])\nfeature_df.T.plot.bar(title="Feature Importance", legend=False);",
"execution_count": 107,
"outputs": [
{
"output_type": "display_data",
"data": {
"text/plain": "<Figure size 432x288 with 1 Axes>",
"image/png": "\n"
},
"metadata": {}
}
]
},
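{
"metadata": {},
"cell_type": "markdown",
"source": "Picking up the earlier idea of writing values on top of each bar, the same trick makes a feature importance chart easier to read, especially when the bars are also sorted. A minimal sketch using matplotlib's Axes.bar_label() (available from matplotlib 3.4), with the coefficients from feature_dict above copied in and rounded to three decimals; plot_labelled_coefs() is a hypothetical helper name:

```python
# Minimal sketch: a sorted horizontal bar chart of coefficients with each
# value written on its bar. Axes.bar_label needs matplotlib 3.4 or newer.
# The coefficient values are copied from feature_dict above (rounded).
import matplotlib.pyplot as plt
import pandas as pd

coefs = pd.Series({"age": 0.003, "sex": -0.860, "cp": 0.661, "trestbps": -0.012,
                   "chol": -0.002, "fbs": 0.044, "restecg": 0.313, "thalach": 0.025,
                   "exang": -0.604, "oldpeak": -0.569, "slope": 0.451,
                   "ca": -0.636, "thal": -0.677})

def plot_labelled_coefs(coefs):
    """Plot coefficients sorted by value, annotating each bar with its number."""
    sorted_coefs = coefs.sort_values()
    fig, ax = plt.subplots(figsize=(6, 5))
    bars = ax.barh(sorted_coefs.index, sorted_coefs.values)
    labels = ax.bar_label(bars, fmt="%.2f", padding=3)
    ax.set_title("Feature Importance")
    ax.set_xlabel("Coefficient")
    return ax, labels

ax, labels = plot_labelled_coefs(coefs)
```

Sorting puts the strongest negative coefficients (sex, thal, ca) at one end and the strongest positive ones (cp, slope) at the other, which is often easier to present than the column order."
},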
{
"metadata": {},
"cell_type": "markdown",
"source": "You'll notice some are negative and some are positive.\n\nThe larger the absolute value (bigger bar), the more the feature contributes to the model's decision. Keep in mind these features weren't scaled before fitting, so a coefficient's size also depends on the units of its feature.\n\nIf the value is negative, it means there's a negative relationship between the feature and the target. And vice versa for positive values.\n\nFor example, the sex attribute has a negative value of -0.860, which means as the value for sex increases, the target value tends to decrease.\n\nWe can see this by comparing the sex column to the target column."
},
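{
"metadata": {},
"cell_type": "markdown",
"source": "Raw counts can be hard to compare when the groups differ in size (96 patients with sex 0 versus 207 with sex 1). Passing normalize="index" to pd.crosstab() divides each row by its total. The sketch below rebuilds a small DataFrame from the counts that pd.crosstab(df["sex"], df["target"]) reports in this notebook (24, 72, 114 and 93), so the original df isn't needed:

```python
# Minimal sketch: turning the sex/target counts into row proportions.
# The counts are copied from the crosstab output in this notebook; the
# small DataFrame below just rebuilds one row per patient from them.
import pandas as pd

counts = {(0, 0): 24, (0, 1): 72, (1, 0): 114, (1, 1): 93}  # (sex, target): patients
rows = [(sex, target) for (sex, target), n in counts.items() for _ in range(n)]
demo = pd.DataFrame(rows, columns=["sex", "target"])

# normalize="index" divides each row by its total, so each row sums to 1
proportions = pd.crosstab(demo["sex"], demo["target"], normalize="index")
print(proportions.round(2))
```

Roughly 75% of patients with sex 0 have target 1, against roughly 45% of patients with sex 1, which lines up with the negative coefficient for sex."
},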
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "pd.crosstab(df["sex"], df["target"])",
"execution_count": 108,
"outputs": [
{
"output_type": "execute_result",
"execution_count": 108,
"data": {
"text/plain": "target 0 1\nsex \n0 24 72\n1 114 93",
"text/html": "(pd.crosstab(df[\"slope\"], df[\"target\"])
it is. As slope goes up, so does target.\n\nWhat can you do with this information?\n\nThis is something you might want to talk to a subject matter expert about. They may be interested in seeing where the machine learning model is finding the most patterns (highest correlation) as well as where it's not (lowest correlation).\n\nDoing this has a few benefits:\n\n1. Finding out more - If some of the correlations and feature importances are confusing, a subject matter expert may be able to shed some light on the situation and help you figure out more.\n2. Redirecting efforts - If some features offer far more value than others, this may change how you collect data for different problems. See point 3.\n3. Less but better - Similar to above, if some features are offering far more value than others, you could reduce the number of features your model tries to find patterns in as well as improve the ones which offer the most. This could potentially lead to saving on computation, by having a model find patterns across fewer features, whilst still achieving the same performance levels.\n\n## 6. Experimentation\n\nWell, we've completed all the metrics your boss requested. You should be able to put together a great report containing a confusion matrix, a handful of cross-validated metrics such as precision, recall and F1, as well as which features contribute most to the model making a decision.\n\nBut after all this you might be wondering where step 6 in the framework is, experimentation.\n\nWell the secret here is, as you might've guessed, the whole thing is experimentation.\n\nFrom trying different models, to tuning different models, to figuring out which hyperparameters were best.\n\nWhat we've worked through so far has been a series of experiments.\n\nAnd the truth is, we could keep going. But of course, things can't go on forever.\n\nSo by this stage, after trying a few different things, we'd ask ourselves: did we meet the evaluation metric?\n\nRemember we defined one in step 3.\n\n> If we can reach 95% accuracy at predicting whether or not a patient has heart disease during the proof of concept, we'll pursue this project.\n\nIn this case, we didn't. The highest accuracy our model achieved was below 90%.\n\nWhat next?\n\nYou might be wondering, what happens when the evaluation metric doesn't get hit?\n\nIs everything we've done wasted?\n\nNo.\n\nIt means we know what doesn't work. In this case, we know the current model we're using (a tuned version of LogisticRegression) along with our specific data set doesn't hit the target we set ourselves.\n\nThis is where step 6 comes into its own.\n\nA good next step would be to discuss different options for going forward with your team, or to research them on your own.\n\n* Could you collect more data?\n* Could you try a better model? If you're working with structured data, you might want to look into CatBoost or XGBoost.\n* Could you improve the current models (beyond what we've done so far)?\n* If your model is good enough, how would you export it and share it with others? (Hint: check out Scikit-Learn's documentation on model persistence.)\n\nThe key here is to remember, your biggest restriction will be time. Hence, why it's paramount to minimise the time between experiments.\n\nThe more you try, the more you figure out what doesn't work, the more you'll start to get a hang of what does."
},
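{
"metadata": {},
"cell_type": "markdown",
"source": "For the model persistence hint above, one option Scikit-Learn's documentation describes is Python's built-in pickle module. A minimal sketch, assuming a LogisticRegression fit on synthetic data and a hypothetical file name in a temporary directory:

```python
# Minimal sketch of model persistence with Python's built-in pickle module,
# one of the options described in Scikit-Learn's model persistence docs.
# Trained on synthetic data; the file name is a hypothetical example.
import os
import pickle
import tempfile

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X_demo, y_demo = make_classification(n_samples=200, n_features=13, random_state=0)
model = LogisticRegression(C=0.20433597178569418, solver="liblinear").fit(X_demo, y_demo)

path = os.path.join(tempfile.mkdtemp(), "heart_disease_model.pkl")
with open(path, "wb") as f:
    pickle.dump(model, f)

with open(path, "rb") as f:
    loaded_model = pickle.load(f)

# The reloaded model should make exactly the same predictions
print((loaded_model.predict(X_demo) == model.predict(X_demo)).all())
```

Only unpickle files from sources you trust, and load the model with the same Scikit-Learn version it was saved with; joblib.dump() and joblib.load() are a common alternative for models holding large NumPy arrays."
},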
{
"metadata": {
"trusted": true
},
"cell_type": "code",
"source": "",
"execution_count": null,
"outputs": []
}
],
"metadata": {
"_draft": {
"nbviewer_url": "https://gist.github.com/c676d6d237931116001b7280dc8ba5c4"
},
"celltoolbar": "Initialization Cell",
"gist": {
"id": "c676d6d237931116001b7280dc8ba5c4",
"data": {
"description": "end-to-end-heart-disease-classification.ipynb",
"public": true
}
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3",
"language": "python"
},
"language_info": {
"name": "python",
"version": "3.8.2",
"mimetype": "text/x-python",
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"pygments_lexer": "ipython3",