diff --git a/.gitignore b/.gitignore index d2b2701..129b0e2 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ /code/data/models/* /code/data/full/* /code/data/reduced/* +/code/data/notebook/* /papers/* \ No newline at end of file diff --git a/code/.ipynb_checkpoints/HAR_system-checkpoint.ipynb b/code/.ipynb_checkpoints/HAR_system-checkpoint.ipynb new file mode 100644 index 0000000..9f6e426 --- /dev/null +++ b/code/.ipynb_checkpoints/HAR_system-checkpoint.ipynb @@ -0,0 +1,1056 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# HAR system - Lincetto Riccardo, Drago Matteo\n", + "This notebook runs:\n", + "- Classification with null class (One Shot classification);\n", + "- Binary classification for activity detection (Two Steps - detection);\n", + "- Classification without null class (Two Steps - classification);\n", + "- Cascade of the last to methods.\n", + "\n", + "The operations performed here are very similar to those execute in 'main.py', with the exception that here the program is executed for specified user and model.\n", + "\n", + "## Notebook setup\n", + "This first cell contains the parameters that can be tuned for code execution:\n", + "- subject: select the subject on which to test the model, between [1,4];\n", + "- task: choose \"A\" for locomotion classification or \"B\" for gesture recognition;\n", + "- model_name: choose between \"Convolutional\", \"Convolutional1DRecurrent\", \"Convolutional2DRecurrent\" and \"ConvolutionalDeepRecurrent\";\n", + "- data_folder: directory name where '.mat' files are stored;\n", + "- window_size: parameter that sets the length of temporal windows on which to perform the convolution;\n", + "- stride: step length to chose the next window;\n", + "- GPU: boolean flag indicatin wheter GPU is present on the machine that executes the code;\n", + "- epochs: number of complete sweeps of the data signals during training;\n", + "- batch_size: number of forward propagations in the networks between consecutives backpropagations." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "subject = 1\n", + "task = \"A\"\n", + "model_name = \"Convolutional\"\n", + "data_folder = \"./data/full/\"\n", + "window_size = 15\n", + "stride = 5\n", + "GPU = True\n", + "epochs = 10\n", + "batch_size = 32" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here the useful functions are imported." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import preprocessing\n", + "import models\n", + "import utils\n", + "import os\n", + "import numpy as np\n", + "from sklearn.metrics import classification_report, f1_score, confusion_matrix\n", + "from keras.models import load_model\n", + "from keras.optimizers import Adam\n", + "from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau\n", + "from keras.utils import to_categorical\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Differently from 'main.py', all results saved from this notebook are going to be stored in a dedicated folder: './data/notebook/'. For proper execution, this folder needs first to be created." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "if not(os.path.exists(\"./data\")):\n", + " os.mkdir(\"./data\")\n", + "if not(os.path.exists(\"./data/notebook\")):\n", + " os.mkdir(\"./data/notebook\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If task A is selected, calssifications in the following notebook are based on the labels of column 0; if instead it's task B, column 6 labels are used." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Task A uses labels column 0\n" + ] + } + ], + "source": [ + "if task == \"A\":\n", + " label = 0\n", + "elif task == \"B\":\n", + " label = 6\n", + "else:\n", + " print(\"Error: invalid task.\")\n", + "print(\"Task\", task, \"uses labels column\", label)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Classification with null class: One Shot classification\n", + "Here classification is performed considering inactivity as a class, alongside with the others. In the case of locomotion classification (task A), this becomes a 5-class problem, while in the case of gesture recognition (task B) the classes become 18.\n", + "### Preprocessing" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: '../data/full/S1-ADL1.mat'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\scipy\\io\\matlab\\mio.py\u001b[0m in \u001b[0;36m_open_file\u001b[1;34m(file_like, appendmat)\u001b[0m\n\u001b[0;32m 32\u001b[0m \u001b[1;32mtry\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 33\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mopen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfile_like\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'rb'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 34\u001b[0m \u001b[1;32mexcept\u001b[0m \u001b[0mIOError\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../data/full/S1-ADL1'", + "\nDuring handling of the above exception, another exception occurred:\n", + "\u001b[1;31mFileNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 6\u001b[0m \u001b[0mmake_binary\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mFalse\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 7\u001b[0m \u001b[0mnull_class\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;32mTrue\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 8\u001b[1;33m print_info=True)\n\u001b[0m", + "\u001b[1;32m~\\Desktop\\HDA-Project\\code\\preprocessing.py\u001b[0m in \u001b[0;36mloadData\u001b[1;34m(subject, label, folder, window_size, stride, make_binary, null_class, print_info)\u001b[0m\n\u001b[0;32m 100\u001b[0m data1, data2, data3, data4, data5, data6 = readData(subject=subject,\n\u001b[0;32m 101\u001b[0m \u001b[0mfolder\u001b[0m\u001b[1;33m=\u001b[0m\u001b[0mfolder\u001b[0m\u001b[1;33m,\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 102\u001b[1;33m print_info=print_info)\n\u001b[0m\u001b[0;32m 103\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 104\u001b[0m \u001b[1;31m# create training set and test set\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\Desktop\\HDA-Project\\code\\preprocessing.py\u001b[0m in \u001b[0;36mreadData\u001b[1;34m(subject, folder, print_info)\u001b[0m\n\u001b[0;32m 20\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 21\u001b[0m \u001b[1;31m# load into dictionaries of numpy arrays\u001b[0m\u001b[1;33m\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 22\u001b[1;33m \u001b[0mdata1\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mloadmat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfilename_1\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmdict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;34m'features_interp'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;34m'features'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'labels_cut'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;34m'labels'\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 23\u001b[0m \u001b[0mdata2\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mloadmat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfilename_2\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmdict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;34m'features_interp'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;34m'features'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'labels_cut'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;34m'labels'\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 24\u001b[0m \u001b[0mdata3\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mloadmat\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfilename_3\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmdict\u001b[0m\u001b[1;33m=\u001b[0m\u001b[1;33m{\u001b[0m\u001b[1;34m'features_interp'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;34m'features'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'labels_cut'\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;34m'labels'\u001b[0m\u001b[1;33m}\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\scipy\\io\\matlab\\mio.py\u001b[0m in \u001b[0;36mloadmat\u001b[1;34m(file_name, mdict, appendmat, **kwargs)\u001b[0m\n\u001b[0;32m 139\u001b[0m \"\"\"\n\u001b[0;32m 140\u001b[0m \u001b[0mvariable_names\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'variable_names'\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m--> 141\u001b[1;33m \u001b[0mMR\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfile_opened\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mmat_reader_factory\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfile_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mappendmat\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;33m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 142\u001b[0m \u001b[0mmatfile_dict\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mMR\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mget_variables\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mvariable_names\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 143\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mmdict\u001b[0m \u001b[1;32mis\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[1;32mNone\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\scipy\\io\\matlab\\mio.py\u001b[0m in \u001b[0;36mmat_reader_factory\u001b[1;34m(file_name, appendmat, **kwargs)\u001b[0m\n\u001b[0;32m 62\u001b[0m \u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 63\u001b[0m \"\"\"\n\u001b[1;32m---> 64\u001b[1;33m \u001b[0mbyte_stream\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mfile_opened\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0m_open_file\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfile_name\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mappendmat\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 65\u001b[0m \u001b[0mmjv\u001b[0m\u001b[1;33m,\u001b[0m \u001b[0mmnv\u001b[0m \u001b[1;33m=\u001b[0m \u001b[0mget_matfile_version\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mbyte_stream\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 66\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mmjv\u001b[0m \u001b[1;33m==\u001b[0m \u001b[1;36m0\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;32m~\\Anaconda3\\lib\\site-packages\\scipy\\io\\matlab\\mio.py\u001b[0m in \u001b[0;36m_open_file\u001b[1;34m(file_like, appendmat)\u001b[0m\n\u001b[0;32m 37\u001b[0m \u001b[1;32mif\u001b[0m \u001b[0mappendmat\u001b[0m \u001b[1;32mand\u001b[0m \u001b[1;32mnot\u001b[0m \u001b[0mfile_like\u001b[0m\u001b[1;33m.\u001b[0m\u001b[0mendswith\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'.mat'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 38\u001b[0m \u001b[0mfile_like\u001b[0m \u001b[1;33m+=\u001b[0m \u001b[1;34m'.mat'\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m---> 39\u001b[1;33m \u001b[1;32mreturn\u001b[0m \u001b[0mopen\u001b[0m\u001b[1;33m(\u001b[0m\u001b[0mfile_like\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;34m'rb'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m,\u001b[0m \u001b[1;32mTrue\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m\u001b[0;32m 40\u001b[0m \u001b[1;32melse\u001b[0m\u001b[1;33m:\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 41\u001b[0m \u001b[1;32mraise\u001b[0m \u001b[0mIOError\u001b[0m\u001b[1;33m(\u001b[0m\u001b[1;34m'Reader needs file name or open file-like object'\u001b[0m\u001b[1;33m)\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n", + "\u001b[1;31mFileNotFoundError\u001b[0m: [Errno 2] No such file or directory: '../data/full/S1-ADL1.mat'" + ] + } + ], + "source": [ + "X_train, Y_train, X_test, Y_test, n_features, n_classes, class_weights = preprocessing.loadData(subject=subject,\n", + " label=label,\n", + " folder=data_folder,\n", + " window_size=window_size,\n", + " stride=stride,\n", + " make_binary=False,\n", + " null_class=True,\n", + " print_info=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Y_train and Y_test contain the correct labels for each signals window. Y_test in particular will be used to evaluate predictions for both this (one-shot) and the two-steps models. For this reason it is here saved with a different name, to avoid having it being overwritten later." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "Y_test_true = Y_test" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "oneshot_model = models.ConvolutionalRecurrent((window_size, n_features), n_classes, print_info=False)\n", + "\n", + "oneshot_model.compile(optimizer = Adam(lr=0.001),\n", + " loss = \"categorical_crossentropy\", \n", + " metrics = [\"accuracy\"])\n", + "\n", + "checkpointer = ModelCheckpoint(filepath='./data/model_AOS_1.hdf5', verbose=1, save_best_only=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 31423 samples, validate on 11505 samples\n", + "Epoch 1/15\n", + "31423/31423 [==============================] - 10s 328us/step - loss: 0.4909 - acc: 0.8142 - val_loss: 0.3507 - val_acc: 0.8696\n", + "\n", + "Epoch 00001: val_loss improved from inf to 0.35067, saving model to ./data/model_AOS_1.hdf5\n", + "Epoch 2/15\n", + "31423/31423 [==============================] - 5s 172us/step - loss: 0.3290 - acc: 0.8692 - val_loss: 0.3179 - val_acc: 0.8860\n", + "\n", + "Epoch 00002: val_loss improved from 0.35067 to 0.31791, saving model to ./data/model_AOS_1.hdf5\n", + "Epoch 3/15\n", + "31423/31423 [==============================] - 5s 173us/step - loss: 0.2827 - acc: 0.8885 - val_loss: 0.3022 - val_acc: 0.9007\n", + "\n", + "Epoch 00003: val_loss improved from 0.31791 to 0.30218, saving model to ./data/model_AOS_1.hdf5\n", + "Epoch 4/15\n", + "31423/31423 [==============================] - 5s 171us/step - loss: 0.2584 - acc: 0.8982 - val_loss: 0.2984 - val_acc: 0.9000\n", + "\n", + "Epoch 00004: val_loss improved from 0.30218 to 0.29843, saving model to ./data/model_AOS_1.hdf5\n", + "Epoch 5/15\n", + "31423/31423 [==============================] - 5s 172us/step - loss: 0.2386 - acc: 0.9040 - val_loss: 0.3364 - val_acc: 0.8952\n", + "\n", + "Epoch 00005: val_loss did not improve\n", + "Epoch 6/15\n", + "31423/31423 [==============================] - 5s 173us/step - loss: 0.2284 - acc: 0.9105 - val_loss: 0.3045 - val_acc: 0.9030\n", + "\n", + "Epoch 00006: val_loss did not improve\n", + "Epoch 7/15\n", + "31423/31423 [==============================] - 5s 172us/step - loss: 0.2195 - acc: 0.9146 - val_loss: 0.2983 - val_acc: 0.9076\n", + "\n", + "Epoch 00007: val_loss improved from 0.29843 to 0.29826, saving model to ./data/model_AOS_1.hdf5\n", + "Epoch 8/15\n", + "31423/31423 [==============================] - 5s 173us/step - loss: 0.2056 - acc: 0.9196 - val_loss: 0.3087 - val_acc: 0.9009\n", + "\n", + "Epoch 00008: val_loss did not improve\n", + "Epoch 9/15\n", + "31423/31423 [==============================] - 6s 176us/step - loss: 0.1992 - acc: 0.9238 - val_loss: 0.3035 - val_acc: 0.9046\n", + "\n", + "Epoch 00009: val_loss did not improve\n", + "Epoch 10/15\n", + "31423/31423 [==============================] - 6s 175us/step - loss: 0.1883 - acc: 0.9265 - val_loss: 0.3125 - val_acc: 0.9079\n", + "\n", + "Epoch 00010: val_loss did not improve\n", + "Epoch 11/15\n", + "31423/31423 [==============================] - 6s 175us/step - loss: 0.1838 - acc: 0.9284 - val_loss: 0.3212 - val_acc: 0.9033\n", + "\n", + "Epoch 00011: val_loss did not improve\n", + "Epoch 12/15\n", + "31423/31423 [==============================] - 5s 175us/step - loss: 0.1804 - acc: 0.9292 - val_loss: 0.3223 - val_acc: 0.9045\n", + "\n", + "Epoch 00012: val_loss did not improve\n", + "Epoch 13/15\n", + "31423/31423 [==============================] - 6s 175us/step - loss: 0.1730 - acc: 0.9325 - val_loss: 0.3092 - val_acc: 0.9081\n", + "\n", + "Epoch 00013: val_loss did not improve\n", + "Epoch 14/15\n", + "31423/31423 [==============================] - 5s 175us/step - loss: 0.1700 - acc: 0.9323 - val_loss: 0.3072 - val_acc: 0.9094\n", + "\n", + "Epoch 00014: val_loss did not improve\n", + "Epoch 15/15\n", + "31423/31423 [==============================] - 6s 175us/step - loss: 0.1635 - acc: 0.9357 - val_loss: 0.3402 - val_acc: 0.8964\n", + "\n", + "Epoch 00015: val_loss did not improve\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "oneshot_model.fit(x = X_train, \n", + " y = to_categorical(Y_train),\n", + " epochs = 15,\n", + " batch_size = 128,\n", + " verbose = 1,\n", + " callbacks=[checkpointer],\n", + " validation_data=(X_test, to_categorical(Y_test)),\n", + " class_weight=class_weights)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluation - passare class_weights a class report" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 0.85 0.89 0.87 2039\n", + " 1 0.92 0.90 0.91 3958\n", + " 2 0.79 0.80 0.79 2333\n", + " 3 0.99 0.99 0.99 2733\n", + " 4 0.94 0.84 0.89 442\n", + "\n", + "avg / total 0.90 0.90 0.90 11505\n", + "\n", + "Weighted f1-score: 0.8965743574822785\n" + ] + } + ], + "source": [ + "Y_pred = oneshot_model.predict_classes(X_test)\n", + "print(classification_report(Y_test, Y_pred))\n", + "print(\"Weighted f1-score:\", f1_score(Y_test, Y_pred, average='weighted'))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 0.88 0.88 0.88 2039\n", + " 1 0.92 0.91 0.92 3958\n", + " 2 0.83 0.83 0.83 2333\n", + " 3 0.97 1.00 0.98 2733\n", + " 4 0.91 0.84 0.88 442\n", + "\n", + "avg / total 0.91 0.91 0.91 11505\n", + "\n", + "Weighted f1-score: 0.9073576036830062\n" + ] + } + ], + "source": [ + "oneshot_model_best = load_model('./data/model_AOS_1.hdf5')\n", + "\n", + "Y_pred = oneshot_model_best.predict_classes(X_test)\n", + "print(classification_report(Y_test, Y_pred))\n", + "print(\"Weighted f1-score:\", f1_score(Y_test, Y_pred, average='weighted'))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAU0AAAElCAYAAABgV7DzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3XdYlfX/x/HnOQcQZYg401y4ytQcuHGLuEHCxIGappmpoaYoCg5MTA0rydmwHGnkyMyvpTgQBxbmwtI0xZkLB0MZ59y/P/h5CgUPR8CbY++H17ku78Pnvu/3fc7Ni8+9NYqiKAghhMgVrdoFCCGEJZHQFEIIM0hoCiGEGSQ0hRDCDBKaQghhBglNIYQwg4SmEEKYwSJDU6/X8+WXX+Lt7Y2npyddu3Zl3rx5pKWl5Wmab7/9Nh4eHqxatcrs8Y8fP86YMWOeev75LTExkYEDB+b4c09PT+7du1dg8x82bBhnzpwBYMiQISQkJADQvn17jh8/nm/ziYiIYPXq1fk2vfwWExND9+7dAfj444/ZtGnTU03n0e+zoL8/kTMrtQt4GtOnT+fu3bt89dVXODg4kJKSwnvvvceUKVOYN2/eU03z2rVrREdHc+TIEXQ6ndnj161bl08++eSp5l0Q7t69+8Rw+v777wt0/suXLzf+f9++fQU2n9jYWGrUqFFg089P77777lOP++j3WdDfn8iZxfU0L126xA8//MDs2bNxcHAAoFixYsyYMYOOHTsCmX+V33vvPbp3706PHj2YO3cuGRkZQGa4LVy4EF9fX9q3b8+aNWtISkrizTffJCMjA29vby5cuECtWrWMvSPAOJycnMyYMWPw9PSkV69eTJ06FYPBkKVHYe78s1O3bl3CwsLw8fGha9eubN26lTFjxtC5c2cGDhxISkoKAN999x29e/fGy8uLdu3aGac3efJkHjx4gKenJ3q9njp16vDuu+/i4eHB8ePHjcsTHh6Or68ver2eGzdu4ObmxsGDB3P8/G/fvk2DBg2M8w8ODmbAgAHGn3fq1ImzZ88ae5STJ08GYNCgQVy9ehWAdevW4e3tTdu2bVmwYIFx3HXr1tG9e3d69uzJkCFDOHfuHACTJk3i888/N7Z7OLx9+3Z27tzJihUrHuttXrp0iY4dOxISEoKPjw+dOnVi+/btAKSnpxMSEkLXrl3p0aMHU6ZMISkpCcjsCfv7+9OlSxe2b99O+/btCQsL4/XXX8fDw4OIiAgmT55Mz5498fb25tq1awDs2rULX19f43J99NFHj312D+s+evQonp6exleTJk3o16+fWd/nv9fPTz/91LgsY8aM4caNGwD4+fnx4Ycf0r9/f9q3b8+UKVMwGAw5frcilxQLs23bNuW11157YpuJEycqISEhisFgUFJTU5UhQ4YoS5cuVRRFUWrWrKmsXLlSURRFOX78uFKnTh3lwYMHysWLF5X69esbp1GzZk3l1q1bjw1v3LhRGTJkiKIoipKRkaFMmTJFOX/+vHLw4EGlW7duTz3/R9WsWVP56quvFEVRlKVLlyoNGjRQ/v77b0Wv1yu9evVSNm/erCQlJSmvv/66kpCQoCiKovz222/GZchueTZu3PjY8mRkZCj9+/dXli5dqgwePFhZvHixye/Az89P2blzp6IoitKpUyelRYsWSlJSkvLnn38qXbp0URRFUdq1a6ccO3bssc+yXbt2ysyZMxVFUZTr168rderUUa5cuaLs379f6dixo7Hd+vXrlS5duigGg0EJCAhQPvvsM+P8/z386M8eunjxolKzZk1jndu2bVPatm2rKIqifPzxx8qoUaOUtLQ0Ra/XK5MmTVKCgoKM9YWHhxun065dO2X27NmKoijKjz/+qLz00kvK77//riiKoowcOVJZvHixYjAYlAEDBijnzp1TFEVR/v77b+Xll19Wbt26lWW9yK7Wo0ePKm3atFHOnDlj9vd569Yt5bvvvlP69OmjJCcnK4qiKJ988olx/RwwYIAyZswYRa/XK4mJiYqbm5ty4MCBJ321Ihcsrqep1WpN/rWMiopiwIABaDQabGxs8PX1JSoqyvjzDh06APDKK6+QlpZm7DXlRqNGjThz5gx+fn4sW7aMQYMGUbly5QKZv4eHBwCVKlWiZs2alC1bFq1Wy4svvsjdu3exs7NjyZIl7Nmzh48++oglS5Y8cVlcXV0fe0+n0zF//nyWL1+Ooii89dZbJj8Dd3d3oqKiOHv2LGXLlqVhw4b88ssvREZG0qlTJ5PjP+yRly5dmlKlSnHr1i327t1L165dcXZ2BjD24i5dumRyejmxtramTZs2ANSuXZs7d+4Amd+Pr68v1tbWaLVa/Pz82Lt3r3G8Rz+nh8tUsWJFSpUqxUsvvQRkfi93795Fo9GwZMkS4uLiCA8PZ86cOSiKwv37959YX3x8PKNHj2bu3LlUq1bN7O/z4bJ4e3tTrFgxAAYOHMjBgweN+/fbtWuHVqvF3t6eypUrc/fu3dx+fCIHFhea9erV46+//jJuTj107do1hg8fzoMHDzAYDGg0GuPPDAaDcfMYoEiRIgDGNoqJe5b8+wBTxYoV2b59O8OHDycpKYk33niDnTt3ZmmfX/O3trbO9v8P/f3333h5eXH58mUaNWqEv7//E5fj4S/Woy5fvkyRIkW4cOFCrn6pHoZmdHQ0LVu2pEWLFkRHR7Nz5046d+5scnwrq392pWs0GhRFyfYPoaIoZGRkGNs8lJ6ebnIegDEUH87noey+n39P89HPycbGJss0H5WSkkKvXr2Ii4ujdu3aTJw4ESsrqyeuV7du3WLYsGGMHz+eJk2aAOZ/nzkty7/XNVtbW+P/H/0cxdOxuNAsW7YsPXr0IDAw0BicSUlJTJ8+HScnJ2xtbXFzc2PVqlUoikJaWhrffvstLVq0MGs+zs7Oxh3vW7ZsMb6/Zs0aJk+ejJubGxMmTMDNzY2TJ09mGTc/5p8bJ06cwNnZmZEjR+Lm5sauXbuAzDMBrKys0Ov1Jn9J7t27x4QJE5gzZw7du3dnypQpJudbrlw5SpQowdq1a2nZsiVubm78/PPP3Llzx9gL+zedTpflFzk7rVq1YuvWrcb9dOvXr8fJyYnKlStTokQJTpw4AWT+cTx06JBZ085uXt988w3p6ekYDAZWr15Ny5YtzZrGv8XHx5OUlIS/vz/t27cnJiaGtLS0HLeIkpOTGT58OD4+PvTs2dP4/tN8n61atWL9+vXGHunKlStp3LhxlqAX+cviQhNg2rRpVK9eHV9fXzw9PenduzfVq1dn1qxZAEydOpWEhAR69OhBjx49qFq1KiNGjDBrHlOnTmXmzJn06tWLs2fPUrp0aQC8vLzQ6/V07doVb29vEhMT8fPze2zcvM4/N1q2bEnZsmXp3LkzXbp04erVqzg7OxMfH0/p0qWpV68e3bp14/bt209czrZt2+Lm5saoUaO4ePGi8aCKp6dnjkfg3d3dSUhIoHbt2lSsWBFbW1vjgbhHde7cGT8/P06fPv3EZRk8eDCDBg2iW7dubNq0iaVLlxo3n2/cuIGHhweBgYE0a9bMOF7r1q1Zu3YtS5cuzc1HBsDbb79NqVKl8PLyokuXLmRkZOTqj0VOatWqRdu2benSpQtdunRh165dVK9enfj4+Gzbr1q1ilOnTrF9+3a8vLyMB4Se5vv08fGhefPm9O7dmy5dunDy5Enmz5//1MsiTNMo0l8XOViwYAE9e/akWrVqapciRKFhkT1NUfAURaFChQoSmEI8QnqaQghhBulpCiGEGVS5jHLUxt/VmG2+mdvt8SPEliJdb9lXhFjpLPvvvMZ0k0KtmE3+LkHRBqPMan//t/B8nf/TsOw1UAghnjGLvGGHEOI5obG8fpuEphBCPRrL22EhoSmEUI/0NIUQwgzS0xRCCDNozb/h97/p9XqmTp3KuXPn0Ol0hIaGkpiYyIgRI6hSpQoAffv2pWvXroSHh7N7926srKwIDAykXr16xMfHM2nSJDQaDTVq1GDatGnGm7zkREJTCKGePG6eP7ypydq1a4mJiSE0NJT27dvzxhtvMGTIEGO7uLg4Dh06REREBFevXmX06NGsX7+e0NBQ/P39adq0KcHBwURGRuLu7v7EeUpoCiHUk8fN844dO9K2bVsArly5QqlSpThx4gTnzp0jMjKSypUrExgYSGxsLG5ubmg0GsqXL49erychIYG4uDjjrflat27Nvn37JDSFEIVYPhwIsrKyIiAggO3bt/PJJ59w7do1evfuTZ06dVi8eDGffvopDg4OODk5Gcexs7MjMTERRVGM9yN9+J4plnfoSgjx/NBozHvl4IMPPuCnn34iKCgINzc36tSpA2TewvDkyZPY29uTnJxsbJ+cnIyDg0OW/ZfJyck4OjqaLFlCUwihHo3WvNcjHt53FaBo0aJoNBpGjRrFsWPHADhw4ACvvPIKDRs2JDo6GoPBwJUrVzAYDDg7O1O7dm1iYmKAzEeHZPdImEfJ5rkQQj153KfZqVMnJk+eTP/+/cnIyCAwMJAXXniBkJAQrK2tKVWqFCEhIdjb2+Pq6kqfPn0wGAwEBwcDEBAQQFBQEGFhYbi4uBify/XEktW4NZzcsEM9csMOdVneWYlZ5fsNO1pPN6v9/Sjz2hcE6WkKIdQjVwQJIYQZtJbX95bQFEKoR3qaQghhBgu89tziYv7v00fZGDQIgBvnfue7AF82BA4gMnwKyv8/ZzpuewTfTujNdwG+nP91NwDJt2+wadobbJgygG3zx5Keel+tRchCr9czYvgQOrR1o1OHNvx19izXr1/n9de86NShDR3auvHX2bNql/mYXw/F0M2jfZb3Jk8cxxfL/3mU7sKPPqRNiya0c2vGD99vetYl5tr169epVa0Sp/74gyO/HaZG1Rfp7N6Ozu7t+C5indrl5Si7dWfQgL7G2l+uWZVBA/qqXeaT5fGUIzVYVE/z8MbPObVnM9ZFigLwy7pFuL7+NlUateHnBRM4H7uHMtXrcOzHVbw+L4KMtFQ2TBlAxVdbcHjDZ7zU1ouX2nlyaG04cT9/S/0eg1ReIti65QcAIndHE7VnN5MmjsepRAn69O3Haz6vs2f3Lk6f+gOXQvRUyI/D5rHum9UUK1YMgJs3bjDizcGcOfMnNfxrAXDnzh2WLgrn8IlTpCQn49asET08vdQsO1vp6emMeWcERW0z16kjvx1m9JixjBk7XuXKTMtu3fl2feYfp9u3b9OlU3vmzAtTs0TTpKdZsIqXq0iXiR8bh0u7vExq0l0URSH9fgpanRXX/zzOCy81QGdtQxE7B4qXq8TN86dwGzKJWm16oBgMJN36m2LFS6q4JP/o4elF+KJlAFy4EE+ZMmU4uH8fly9doltnd9atXU2rNm3VLfIRVVyqsfKbCONwUnISk6YE06dvf+N7dnZ2VKxUiZTkZJKTk03eOUYtgQHv8eawtyhXvjwAvx2OZdu2rXTq0IaRbw3N1WV1aslu3Xno/ZnTeHvkKF544QW1yssdrc68VyFQONfkHFRr3gmtlbVxuPgLldn7+WzWjOlOyt2bVKjThLT7SdgUczC2sSlqR1pKIhqNBoNBzzf+Pbl84hDlXmqgxiJky8rKimFDB/Pe2DF4efsQH38epxIl+HHbdipWrETY/A/ULjELTy9vrKz/+R6qVKmKa5Omj7Wr8GJFmjasS5uWjXlrpHkP0HoWVn29glKlS9Ox0z8nNLs2bsL7oXP5OXIPVaq6EDprhooVmvbougOZuxt279rJgIGDVa0tVyxw87xAqnh4xn2fPn3w8/MjPj6+IGZD9Oeh9Jq1kv4Lf+Sltp7sWzEXm6L2pN//5xrTtPvJFLHLvJ5UZ2VNv0+20HbEdCI/mVwgNT2t5Z+v4MiJU4waORwnJye6de8JQJduPTgcG6tydebb/tM2/v77Kkd/P8OJU+f48Yfvif3lkNplZfH1V1+yM3IHnd3bcfzoEYYPHUQnjy40aNgIgB6evTh65IjKVZr273UnOTmZTRu+43Xfvuh0haNn9kT5dO35s1Qgobljxw7S0tJYt24d48ePZ86cOQUxG4o4FMemmD0AdiXKkJp8jzI16nLl91gy0lJJTU7k9uW/cK5Ugz1LZ3LpeOY1ptZF7dAUks3FNatXMm9uKADFihVDq9Hi1qoNP23bCsC+vVG8XLu2miU+FacSThS1LUqRIkWwtbWleHEn7t69o3ZZWfwcuYefduxm2/Zd1H21Pss+/4o+Pl78+v/hvntnJA0aNlS5ypxlt+7odDp27YzE3aOLytXlkgX2NAvkQFBsbCytWrUCoH79+pw4caIgZkO7kTP5+cP30Op0aK2saTdyJnYlSlOv2wA2TvFDUQw06/cuVjZFqNdtALuXzuDXiMWg0dB6eFCB1GQuTy9vRgwbQqcObUhPT2fu/AXUe7U+I0cM47NlS3B0LM6XX69Wu0yztWjZit07I+nYpgVarZZmLVrSrsOT71NYGHy0cBHj/UdjbWND2bLlWLhoqemRVJLdumNra8ufp09RtaqL2uXlTiHpPZqjQK49nzJlCp06daJNmzYAtG3blh07dmBllZnRcu25euTac3VZXkRkle/XnncPN6v9/S3q7xsvkJ7mo/euMxgMxsAUQgijQrLJbY4Cqbhhw4ZERUUBcOTIEWrWrFkQsxFCWDoLPBBUIN0/d3d39u3bh6+vL4qiMHv27IKYjRDC0llgT7NAQlOr1TJz5syCmLQQ4nlSSHqP5pAdjUII9UhPUwghzCA9TSGEyD2NhKYQQuSehKYQQphBI4+7EEKI3JOephBCmEFCUwghzCChKYQQ5rC8zJTQFEKoJ689Tb1ez9SpUzl37hw6nY7Q0FAURWHSpEloNBpq1KjBtGnT0Gq1hIeHs3v3bqysrAgMDKRevXrEx8dn2/ZJLO90fCHEc0Oj0Zj1etSuXbsAWLt2LWPGjCE0NJTQ0FD8/f1Zs2YNiqIQGRlJXFwchw4dIiIigrCwMGbMyHyMSXZtTZHQFEKoJq+h2bFjR0JCQgC4cuUKpUqVIi4ujiZNmgDQunVr9u/fT2xsLG5ubmg0GsqXL49erychISHbtqZIaAohVJPX0ITMh8sFBAQQEhKCh4cHiqIY29rZ2ZGYmEhSUhL29vbGcR6+n11bUyQ0hRDq0Zj5ysEHH3zATz/9RFBQEKmpqcb3k5OTcXR0fOzG6MnJyTg4OGTZf/mwrSkSmkII1eS1p7lp0yaWLs18jlPRokXRaDTUqVOHmJjMhyhGRUXh6upKw4YNiY6OxmAwcOXKFQwGA87OztSuXfuxtqbI0XMhhGryevS8U6dOTJ48mf79+5ORkUFgYCDVqlUjKCiIsLAwXFxc8PDwQKfT4erqSp8+fYyPGAcICAh4rK3JmgviwWqmyIPV1CMPVlOXBZ6WmEV+P1itzJBvzWp//YvX83X+T0N6mkII1cgNO3LJkntqACWbjla7hKd2fs8CtUvIEwcL72lqLTAkCpJcRimEEGaQ0BRCCDNIaAohhDksLzMlNIUQ6pGephBCmEFCUwghzCChKYQQ5rC8zJTQFEKoR3qaQghhBglNIYQwg4SmEEKYQUJTCCHMYXmZKaEphFCPqSc/FkYSmkII1Vjg1rmEphBCPbJPUwghzGCBmSmhKYRQj/Q0VaDX63nn7WH8efo0Op2OJcu+4O69u/T27kn16jUAeHP4CHx691G1Tq1Ww6KgftSsUga9QWH4tFUkJT/g0+B+lHAshk6rYWjQSs5duglAqRL27FoxDtfes0lNy6CEYzG+eH8Qjna23LqbzDsz13DjdpIqy5Kens7Yd4Zx8UI8qampjJ0wmQ0Ra7l+/RoAFy/E08i1CUu/XE3ozCCidu9Eo9Ewa+4CGjZqrErN2clu3XGpVg2Aie+NpWbNWrw5fITKVeaOwWDg3VEjOXbsKEWKFGHx0s+oVr262mWZZIGZafmhuXXLDwBE7o4mas9uJk0cT9du3Rk9Zizvjh2vcnX/6Na6LgDt31hAq0Y1+GC8N3fupbBu6y+s3/4brV1rUKtKWc5duknH5i8TMqYnZZwdjONPHOrB/t/OMu+Ln2nXtBYzRvdk5Mw1qizLd+vWUMK5JOHLVpCQcAv3Vk2IjTsLwJ3bt3mthzszQ+dz/OhvxP5yiK2R0Vy8EM/gfq+xc1+sKjVnJ7t159Mlyxk2ZBBn/jxNzXG1VK4w9zZ/v4kHDx6wJ/oAMQcPMmnieCI2fK92WSZZ4uM/LO94/yN6eHoRvmgZABcuxFOmTBl+OxzLtm1b6dShDW+/NZTExESVq4Qfdh/jnVnfAFCpvDPXbyXSvL4LFcqW4Mclo/Dt2pioX/8EwGBQ6DYinNv3Uozjv+RSjp/3nQTgwJG/aFHf5dkvxP/r6fUaAVOmG4d1un/+9s4LncnQ4e9QttwL1H21AWs3/ohGo+HSxQuULl1WhWpzlt26k5yUxJSgafTtP0Dl6syzf1807h6dAWjarBmxsb+qXFHuaDTmvQqDAgvNo0eP4ufnV1CTz8LKyophQwfz3tgxeHn70KhxE2aHzuXnyD1UrerC7Fkznkkdpuj1BpbP9CNsog8bd/xG5RdKcvteCt1GhHPx7wTGv+EOwM6YP0i4m5xl3GOnLtGtTWZvtXubuhSztXnm9T9kZ2+PvYMDSYmJvDnQl0lB0wG4ceM6e/fspE//gca2VlZWhM4Mwq+PF716+6pUcc4eXXeqVK1K4yZN1S7LbIn37lG8eHHjsE6nIyMjQ8WKckej0Zj1KgwKJDSXL1/O1KlTSU1NLYjJZz/Pz1dw5MQpRo0cTseOnWjQsBEAPTx7cfTIkWdWhynDgldSz2smi4L7cScphR/3HAdg654TNKxdKcfx5n3xM5XLl+THJaN4sWwJLl27/axKztblSxfx7u6OT5/+ePfuC8CWTRvw9vFFp9NlaTs5OISjf8Sz6JMPOf/XWTXKfaJ/rzvJycmmRyiEHBwds2xRGQwGrKwK/9436Wn+v0qVKrFw4cKCmPRj1qxeyby5oQAUK1YMrUZL3z6v8esvhwDYvTOSBg0bPpNanqRvt8a8N6QTACkP0jEYDETHnsHDrTYAbg2r8/vZqzmO79awOqu3xNBtRDjnr9ziwJG/nknd2blx/Rq+vboSNGM2/fwGG9/fuzuS9u6djcPRe3YxafwYAIrY2mJtbY2mEF0Bkt2682jgW4rmLVry0/+2AhBz8CB16tRVuaLcyWtPMz09nQkTJtCvXz98fHyIjIwkLi6OVq1a4efnh5+fH1u3Zn4u4eHh+Pj44Ovry7FjxwCIj4+nb9++9OvXj2nTpmEwGEzWXCB/ijw8PLh06VJBTPoxnl7ejBg2hE4d2pCens7c+QuoULEi4/1HY2NjQ9my5Vi4aOkzqeVJvo88yrIZA9j+uT/WVjomzF/PsVOXWBTcn+G9W3E36T6DJ6/IcfzT8df4PCRzs/fK9TuMmKHOQSCAjz/8gDt37hA2bzZh82YDsOa7Hzhz5jSVq1Q1tmvu1prNm9bTo1Mb9Ho9b7z5dpafqy27dcfW1lbtsp6Kp1cvdu7YTttWLVAUhWWffal2SbmS103uzZs34+TkxLx587h9+za9evXinXfe4Y033mDIkCHGdnFxcRw6dIiIiAiuXr3K6NGjWb9+PaGhofj7+9O0aVOCg4OJjIzE3d39yTUriqLkqeocXLp0iXHjxvHtt98+9rOUtAKZ5TNTsulotUt4auf3LFC7hDxxsC38m5xPYolHi/8tvz/++tMjzWp/ZHqHLMPJyckoioK9vT23b9/Gx8cHNzc3zp07h16vp3LlygQGBrJhwwYePHjA8OHDAfDy8uKLL77A09OTqKgoNBoNO3bsYN++fUybNu2JNVj2GiiEsGh57Wna2dkBkJSUxJgxY/D39yctLY3evXtTp04dFi9ezKeffoqDgwNOTk5ZxktMTERRFGMND98zpfDsYBJC/OdotRqzXtm5evUqAwcOxNPTkx49euDu7k6dOnUAcHd35+TJk9jb22c5yJecnIyDg0OWuywlJyfj6OhouuY8LnOOXnzxxWw3zYUQ4qG8Hj2/efMmQ4YMYcKECfj4+AAwdOhQ44GeAwcO8Morr9CwYUOio6MxGAxcuXIFg8GAs7MztWvXJiYmBoCoqChcXV1N1iyb50II1eR183zJkiXcu3ePRYsWsWjRIgAmTZrE7Nmzsba2plSpUoSEhGBvb4+rqyt9+vTBYDAQHBwMQEBAAEFBQYSFheHi4oKHh4fpmgvqQNCTyIEg9ciBIHXJgaCsmszebVb7Q4Ft87eAp2DZa6AQwqIVlqt8zCGhKYRQjQVmpoSmEEI90tMUQggzWGBmSmgKIdQjPU0hhDCDBWamhKYQQj3S0xRCCDNIaAohhBksMDMlNIUQ6rHEK6QkNIUQqpHNcyGEMIMFZqaEphBCPVoLTE0JTSGEaiwwMyU0hRDqkX2aQghhBgs8eC6hKYRQj/Q0cyktw/QD2QszS777+aQff1e7hDxZ3Lue2iWIfGSBmSk9TSGEejRYXmpKaAohVCP7NIUQwgzP1T7NPn36PLZAiqKg0WhYu3ZtgRcmhHj+WWBm5hyaYWFhz7IOIcR/0HN1RVCFChUAuHbtGvPmzeP27dt4eHhQq1Yt48+EECIvLPEuR1pTDYKCgnjttddIS0vD1dWV999//1nUJYT4D9BozHsVBiYPBKWmptK8eXMWL16Mi4sLRYoUeRZ1CSH+A/K6eZ6enk5gYCCXL18mLS2Nt99+m+rVqzNp0iQ0Gg01atRg2rRpaLVawsPD2b17N1ZWVgQGBlKvXj3i4+OzbfvEmk0VZWNjw969ezEYDBw5cgQbG5s8LaQQQjykMfP1qM2bN+Pk5MSaNWtYvnw5ISEhhIaG4u/vz5o1a1AUhcjISOLi4jh06BARERGEhYUxY8YMgGzbmmIyNENCQtiwYQO3b9/miy++YPr06bn6MIQQwhSNRmPW61GdO3fm3XffNQ7rdDri4uJo0qQJAK1bt2b//v3Exsbi5uaGRqOhfPny6PV6EhISsm1risnN83LlyvHWW29x/vx5atSoQcWKFXP9gQghxJPk9TiQnZ0dAElJSYwZMwZ/f38++OADY8B0yekPAAAerElEQVTa2dmRmJhIUlISTk5OWcZLTEw0nkb57/dM1myqwaJFi5gxYwaHDx9mypQprFix4mmWTQghHpPXnibA1atXGThwIJ6envTo0SPLPsnk5GQcHR2xt7cnOTk5y/sODg7ZtjXFZGhGRUWxevVqAgMDWbVqFVu3bjU5USGEyI28Hj2/efMmQ4YMYcKECfj4+ABQu3ZtYmJigMz8cnV1pWHDhkRHR2MwGLhy5QoGgwFnZ+ds25picvPc2dmZ+/fvY2dnR3p6Os7OzuZ8JkIIkaO8Xka5ZMkS7t27x6JFi1i0aBEAU6ZMYdasWYSFheHi4oKHhwc6nQ5XV1f69OmDwWAgODgYgICAAIKCgrK0NVmzoihKdj94eBnlrVu3uH//PrVq1eLs2bM4OTmxadOmPC3onRR9nsZXW6oF39pObg0n8sI2n+9WMfibY2a1X9FX/e9fLqMUQqjGEm/YkeM+zQoVKlChQgUyMjLYsmULGzduZOPGjSxduvRZ1vdEv/4SQ/fO7bO8F7HuGzq1a2kc/ujDubRq1oiundqy7X9bnnWJ2UpPT2fU8MF4dm5H53Yt+GnrDxw78hv1X6pCr24d6dWtI5vWfwvA2tVf06V9Szq1aUbYXPWvxnIpWZSA9i4AVC5RlKBO1ZncoRr9G5U3nkf3ev0XmOJejeBO1WldLXN3jp2Njk+8axPQ3oWA9i641yyp0hJkT6/X89abQ2jXuiUd27Xmr7Nn1S7JLAaDgdEjR9DGrTmdOrTl7JkzapeUK3k9T1MNJjvbAQEBtGvXjsOHD1OmTBlSUlKeRV0mfRw2j3XfrMbOrpjxveNHj7Dqqy94uMch7sRxvvt2LTv2ZJ575dG+Fa3btKdYsWLZTvNZ+W7dGko4lyR82QoSEm7h3qoJ4yZO4a133uXt0WON7c7/dZavPl/Khh93UKRIEebNnkF6ejrW1taq1N3l5dK0qOJk3D0xqHEF1hy+wpmbKXjXLUuzKk7cTkmnjL0N728/i5VWw6yuNfn1wl0qlyhKTPwdVsdeUaV2U37c8gMAu6L2EbVnNwETxhGx4XuVq8q9zd9v4sGDB+yJPkDMwYNMmjjeIuq3xBt2mDx6bmtry1tvvUXZsmWZM2cON2/efBZ1mVTVpRorv4kwDifcusWM4EBmz/1nt8LpU3/g1qoNtra22NraUq16deJOmLcPpSD09HqNgCnTjcM6nRVHjxxmx0//w6tLe8a+M5ykxESidu/k1QaNGDNiCL26dqBxsxaqBSbA9cRUwvfGG4edi1lz5mbmH9E/b6ZQo5QdZ26m8EXMJQAUMn8p9IpCFeeiVC5RlIAOLoxsWYni+b1zLI96enrx6ZJlAFyIj6dMmbIqV2Se/fuicffoDEDTZs2Ijf1V5YpyxxKvPTcZmoqicOPGDVJSUkhJSeHu3bvPoi6Tenp5GwNEr9czeuQw3v/gQxwcHIxtar9Sh/379pKYmEjCrVvEHDxAyr/O1VKLnb099g4OJCUm8uZAXyYFTadBo8YEh8xh0/92UrlKVebPmUXCrZsc3L+XBeHL+GzlOqZMGMvdO3dUqzv20j0y/nXc8EZSGrVKZ55cXL+CA0WstGQYFFLS9eg08Gaziuw5e4vUDANX76Wy6fg1Poj8i8OX7tG/UXm1FiNHVlZWvPnGIMb5j6bXaz5ql2OWxHv3KF68uHFYp9ORkZGhYkW5o9VqzHoVBiZDc9SoUWzfvp2ePXvSoUMHWrdu/SzqMsuR32L56+wZxvu/w9BB/Tj1x+9MnjCOWi+9zLARI+ndqxtTJ0/AtXETnEuWUrtcAC5fuoh3d3d8+vTHu3dfunb35NUGDQHo0sOTE8eOUMK5JC3c2mDv4EDp0mWo+dJLnD37p8qV/+PzmEt0e6UM/q2rcO+BnsTUzF/SYtY6xrWtypW7D/jx5A0Afr+WxO/XkwA4fClzc70w+uzLrzh28jQjRwzLcjJ0Yefg6JjlahaDwYCVVeHqzWdHq9GY9SoMTH6qjRs3pnHjxgB06NChwAt6Go1cm3Dg18zN7gvx5xk6qB+h88K4eeMGt27eZNuOKO7evctrPTtT+5U6KlcLN65fw7dXV2bP+5hWbTMPZPl6d+P9eR/RsFFj9u7ZRb36DWjSrDlffraYBw8eoNfrOf3HH1StWk3l6v9Rr7wDX8Rc5M79DPo3Ks/xK4lY6zRMbO/Ctj9ucDD+n17xG01e5NeLd/nl4l1eLmvP+dv3Vaz8cWtWreTy5UtMCJhMsWLF0Gq16HQ6tcvKteYtWrJ1yw/49H6dmIMHqVOnrtol5UohyUGz5Biabm5uOY4UHR1dIMXkt5KlSnH+/Dnat2qGtY01M97/oFD8Inz84QfcuXOHsHmzCZs3G4AZs+cRPGk81jY2lClTlvkfL8bB0ZF+fm/Qs1MbFEVh7MTJlChEFxdcS0xlbJuqpGUY+P16EseuJtKpVilK29vQppozbf7/yPnnMReJOHqVIU0r0r5GSVIzDHx56JLK1Wfl2cub4W++Qcd2rUlPT2fehx9ha2urdlm55unVi507ttO2VQsURWHZZ1+qXVKuWOIpRzme3F6Q5OR29cjJ7SIv8vv43eiN5q2PC3u9nL8FPIXCv9NDCPHcssSepoSmEEI1heSAuFlMHj2HzHvVnTp1qtCc2C6EeD5oNea9CgOTPc1t27axZMkS9Ho9nTt3RqPRMHLkyGdRmxDiOWeJm+cme5orVqzg22+/xcnJiZEjR7Jjx45nUZcQ4j/guexparVabGxsjHdOLlq0cJ6ULISwPBbY0TQdmq6urowbN45r164RHBxM3bqWcdKsEKLwKyxX+ZjDZGiOGzeOqKgoateuTbVq1WjXrt2zqEsI8R+QqyPRhYzJmjdt2kRCQgKlSpXi7t27eb5ruxBCPKTTasx6FQYme5pn//9mrIqi8Pvvv+Pk5ISXl1eBFyaEeP5Z4Na56dAcP3688f+KovDWW28VaEFCiP+OQtJ5NIvJ0ExLSzP+/8aNG1y6VLhutCCEsFzP5YGghye0K4qCra0tQ4cOfRZ1CSH+AywwM02H5rvvvounp+ezqEUI8R9jiZvnJo+eR0REmGoihBBPRWPmv8IgV/s0vby8qFq1KlptZsZ++OGHBV6YEOL5Z4k9TZOh+d577z2LOoQQ/0H5FZpHjx5l/vz5rFy5kri4OEaMGEGVKlUA6Nu3L127diU8PJzdu3djZWVFYGAg9erVIz4+nkmTJqHRaKhRowbTpk0zdg5zkmNo+vv789FHH9GkSZP8WSohhHhEftzlaPny5WzevNl4X4yTJ0/yxhtvMGTIEGObuLg4Dh06REREBFevXmX06NGsX7+e0NBQ/P39adq0KcHBwURGRuLu7v7E+eUYqQkJCXleGCGEeJL8uMtRpUqVWLhwoXH4xIkT7N69m/79+xMYGEhSUhKxsbG4ubmh0WgoX748er2ehIQE4uLijB3D1q1bs3//fpM159jTvHjxImFhYdn+bNy4cSYnLIQQpuTHKUceHh5Zzh+vV68evXv3pk6dOixevJhPP/0UBwcHnJycjG3s7OxITExEURRjb/fhe6bkGJq2trZUrVo1L8sihBBPVBAnt7u7u+Po6Gj8f0hICB06dMjyHPvk5GQcHByy7L9MTk42jvckOYZmqVKl6NWrV15qF0KIJyqIo+dDhw4lKCiIevXqceDAAV555RUaNmzIvHnzGDp0KH///TcGgwFnZ2dq165NTEwMTZs2JSoqimbNmpmcfo6hWadOnXxdkH+ztVH/2eN5Ycn1W/ojcF+e8KPaJeTJ7/O6qV1CoaIrgJ7m9OnTCQkJwdramlKlShESEoK9vT2urq706dMHg8FAcHAwAAEBAQQFBREWFoaLiwseHh4mp6/Kc88fZDzrOYrnhYSmuvL7ueeL9p83q/3IFlXyt4CnII/wFUKo5rk8uV0IIQrKc3mXIyGEKCgWmJkSmkII9UhPUwghzGCBmSmhKYRQjyU+jVJCUwihmvy4YcezJqEphFCN5UWmhKYQQkVyIEgIIcxgeZEpoSmEUJEFdjQlNIUQ6pEDQUIIYYaCuMtRQZPQFEKoxvIiU0JTCKEi2TwXQggzyBVBKjMYDLw7aiTHjh2lSJEiLF76GdWqV1e7rFyz1PoPxcQwNTCAnyN3G9+bMH4sNWvWYthbI9Qr7P9ZaTXM7VuPF52LYaPTEr79T3o2rEBpxyIAvOhclN/O32HMyt+Y3OMlXF2c0Wk1rD1wgbUHL1LeyZYP+r6KlVaDBgj89jh/3Uh+8kyfMUtdd6SnqbLN32/iwYMH7Ik+QMzBg0yaOJ6IDd+rXVauWWL9H86fyzerVlLMzg6AGzdu8OYbA/nzz9PUHDdB5eoyeblW4HZyOuNWH8CpmDVb3muF28ydADgWteKbd5oR8v1JmlUvSeVSdrz28X5sdFp+CmjN1qNXGde1Fl/vPc/2E9doXasUE7q/xNtfxqq8VFlZ4roDsk9Tdfv3RePu0RmAps2aERv7q8oVmccS63dxqcbaiA0MGewHQHJSElOCpvPzT/9TubJ/bD1ylf8dvWoc1hv+ecLL2M41+WpvPDfupXI3JZ2Tl+8CoKCg02rI0Cu8//1JEu9nPqNFp9OSmq5/tguQC5a47oBlnqdpibsUcpR47x7Fixc3Dut0OjIyLOeBRJZYfy/v17C2tjYOV6lalSZNm6pY0eNS0vQkp+qxK6Jj0eBGfLj1FAAl7W1oUbMU3x26CEBahoF79zOw0mqY3+9VvjlwgZQ0PbeT08kwKLiUtiOw58t88tOfai5Otixx3QHQojHrVRg8V6Hp4OiY5WHvBoMBKyvL6Uxbev2F2QtOtqx5pxkbf73E5sNXAOjy6gtsjr3CvzqeOBa1YsVbTTjzdxKLI88a329WvSRLh7oybvWRQrc/Eyx33dFozHsVBs9VaDZv0ZKf/rcVgJiDB6lTp67KFZnH0usvrErZ2/D1iKZ88MMfRBy6ZHy/Zc1S7P7junG4iLWW1SObERFzkYXbzxjfb1a9JNN61Wbw0kMcv3j3mdaeW5a67mjM/FcYFP4/RWbw9OrFzh3baduqBYqisOyzL9UuySyWXn9hNdK9OsWLWjG6Uw1Gd6oBwOBlh3ApY8eFmynGdv1bVKZSyWL4Nq+Eb/NKAEz45ijBXrWx1mmZ3+9VAP66nsSUiBPPfkGewFLXncLSezSHPPdcWBR57rm68vu559vibpjVvvMrpfO3gKfwXPU0hRCWxRJ7ms/VPk0hhGXRajRmvXJy9OhR/PwyT3uLj4+nb9++9OvXj2nTpmEwGAAIDw/Hx8cHX19fjh079sS2T6w5H5ZbCCGeilZj3is7y5cvZ+rUqaSmpgIQGhqKv78/a9asQVEUIiMjiYuL49ChQ0RERBAWFsaMGTNybGuy5nxbeiGEMFN+HD2vVKkSCxcuNA7HxcXRpEkTAFq3bs3+/fuJjY3Fzc0NjUZD+fLl0ev1JCQkZNvWFAlNIYRq8uM8TQ8PjyznpCqKYrym3c7OjsTERJKSkrC3tze2efh+dm1NkQNBQgjVFMS5l1rtP33B5ORkHB0dsbe3Jzk5Ocv7Dg4O2bY1Of38LVcIIXIvP/ZpPqp27drExMQAEBUVhaurKw0bNiQ6OhqDwcCVK1cwGAw4Oztn29YU6WkKIVRTED3NgIAAgoKCCAsLw8XFBQ8PD3Q6Ha6urvTp0weDwUBwcHCObU3WLCe3C0siJ7erK79Pbo/+87ZZ7d1qlMjfAp6C9DSFEKqxwHPbJTSFEOp50gnrhZWEphBCNZYXmRKaQgg1WWBqSmgKIVRTWO6RaQ4JTSGEaixwl6aEphBCPRKaQghhBtk8F0IIM0hPUwghzGCBmSmhKYRQkQWmpoSmEEI1sk9TCCHMYIn7NFW5y1FSqumHFxVmVjq5Dal4OjfupapdQp5UdC6Sr9M7esH0ndL/7dVKDvk6/6chPU0hhHossKcpoSmEUI3s0xRCCDNY4j5NCU0hhGosMDMlNIUQKrLA1JTQFEKoRvZpCiGEGXL7WN7CREJTCKEeCU0hhMg92TwXQggzyClHQghhBgvMTAlNIYSK8iE1vby8cHDIvCb9xRdfpE+fPrz//vvodDrc3NwYNWoUBoOB6dOnc+rUKWxsbJg1axaVK1d+qvlJaAohVJPXfZqpqZk3QFm5cqXxPU9PTxYuXEjFihUZPnw4cXFxXL58mbS0NNatW8eRI0eYM2cOixcvfqp5SmgKIVST132af/zxB/fv32fIkCFkZGQwevRo0tLSqFSpEgBubm4cOHCAGzdu0KpVKwDq16/PiRMnnnqeEppCCNXkdevc1taWoUOH0rt3b86fP8+wYcNwdHQ0/tzOzo6LFy+SlJSEvb298X2dTkdGRgZWVuZHoISmEEI9eUzNqlWrUrlyZTQaDVWrVsXBwYE7d+4Yf56cnIyjoyMPHjwgOTnZ+L7BYHiqwASQu+kKIVSjMfPfo7777jvmzJkDwLVr17h//z7FihXjwoULKIpCdHQ0rq6uNGzYkKioKACOHDlCzZo1n7rm56KnOX/uHLb++APpaWm8OXwEXbv1YNTI4dy5fQe9Xs+yz1fgUq2a2mU+UXp6Om+9OYT4+POkpqYyKXAq3Xv0VLsssx2KiWFqYAA/R+5WuxSzGAwG3h01kmPHjlKkSBEWL/2MatWrq13WY/R6PQH+b3P2zGl0Oh3zFy5j3vvTuXH9bwAuXYingWtTPv1sJXNnBRO9ZxcajYYZoR9Sv1Fjlat/XF73afr4+DB58mT69u2LRqNh9uzZaLVa3nvvPfR6PW5ubrz66qvUrVuXffv24evri6IozJ49+6nnafGhuXfPbmIO7mfHrr2kpKTwyYIPmRoYQB/ffnj7vE7U7l2cPv1HoQ/Nb1avwrlkSb74aiW3bt2iWeMGFheaH86fyzerVlLMzk7tUsy2+ftNPHjwgD3RB4g5eJBJE8cTseF7tct6zI5tPwKw8X+7ORC9h5CpAXy++jsA7ty5ja+nB8Gz5nLi2BEO/3qI73+O4tLFeN4c0Jufon5Rs/Rs5XWfpo2NDR9++OFj73/77bdZhrVaLTNnzszj3P5/WvkyFRXt2P4zr7xSl76ve/O6tyedu3Yj5sB+Ll++TI8unVi3dg2tWrdVu0yTvH16M21GiHH4afe3qMnFpRprIzaoXcZT2b8vGnePzgA0bdaM2NhfVa4oex7dejJnwSIALl28QKkyZYw/C5sTwuBhb1O23AvUqVefVd9tQaPRcPniBUqVLpPTJNWlMfNVCFh8aN66dZPfDv/KyjXf8lH4IoYO9iM+/jxOTk788L+fqVixEgvmz1W7TJPs7e1xcHAgMTGRfn18mDZjltolma2X92tYW1urXcZTSbx3j+LFixuHHx5dLYysrKwYO3Io0yaNo1vPXgDcvHGdfVG76N13YJZ2c2cF80Zfb7xe66NWuU+k1WjMehUGFh+aziVL0sHdAxsbG2rWrIWtrS16vZ6u3TM3bbt0685vhwtnr+FRFy9epHPHdvTr74dv335ql/Of4uDoSGLiP09GzMvR1WdhwaLP2X3oOAH+I0lJTmbr5o14vdYHnU6Xpd3EqTP5Je4cS8IXcP7cWZWqzZkFdjQtPzSbt2jJjp+3oSgKV69cISU5mW49evLztq0A7Ivey0u1X1G5StOuXbtGj66dmBX6AYPeGKJ2Of85zVu05Kf/Za4zMQcPUqdOXZUryt76dasJX5C55VS0aDG0Wi1anY7oPTtp29HD2G5f1C6mTngXgCK2tlhbW6PVFr5fd43GvFdhUHj/lOZSl67d2Re9l7ZuzTAYDHz48UJq1nyJUW8P47PlSynu6MjnX61Wu0yT5s6ZzZ3btwl9P4TQ9zP3bX6/5X8ULVpU5cr+Gzy9erFzx3batmqBoigs++xLtUvKVpfuXowfPRyf7h3ISM9g2vvzsbW15eyZ01SqUtXYrlnL1vz4/QZ6dWmLQW9g4NC3qFS56hOmrJZCkoRm0CiKojzrmSalGp71LPOVla7w/cUWluHGvVS1S8iTis5F8nV6l++kmdW+gpNNvs7/aVh8T1MIYbksr58poSmEUFFh2U9pDglNIYRq5HEXQghhDsvLTAlNIYR6LDAzJTSFEOqRfZpCCGEG2acphBDmsLzMlNAUQqhHK6EphBC5J5vnQghhBks8ECQXUQshhBmkpymEUI0l9jQlNIUQqpF9mkIIYQbpaQohhBksMDMlNIUQKrLA1JTQFEKoRvZpCiGEGWSfphBCmCGvmWkwGJg+fTqnTp3CxsaGWbNmUbly5XypLSdycrsQQj15fPD5jh07SEtLY926dYwfP545c+YUeMnS0xRCqCav+zRjY2Np1aoVAPXr1+fEiRP5UdYTqRKa9kWkgyv+m/L7EbiWrqh13sZPSkrC3t7eOKzT6cjIyMDKquCiTdJLCGGx7O3tSU5ONg4bDIYCDUyQ0BRCWLCGDRsSFRUFwJEjR6hZs2aBz1OjKIpS4HMRQogC8PDo+enTp1EUhdmzZ1OtWrUCnaeEphBCmOG52DxXFIVTp05x+vRptUsRQjznLP6UI0VRePvttylRogQJCQlUqFCB4OBgtcsy21dffcWgQYPULkMIYYLF9zS//fZbSpYsSWhoKJ988gknT55kxowZapdlluTkZNasWUNYWJjapQghTLD40KxWrRoajYZr165RpEgRvv76a06ePGlRAXT8+HGcnZ25fPkygYGBapcjhHiC5yI0ixYtytGjR0lISMDGxoZPPvmE+/fvq11arlWtWpV+/foxZ84cUlNTCQoKUrskIUQOLD40S5QoQZ8+fdizZw979+7l8uXLHD58mLNnz5Kamqp2eblStmxZOnbsiLW1NVOmTCEjI4Nx48apXZYQIhvPzSlH586dY8uWLZw+fZoHDx4wceJEatSooXZZTyUhIYGPPvqIUaNGUaZMGbXLEUL8y3MTmgAZGRncu3cPAGdnZ5WryRuDwYBWa/EbAkI8d56r0BRCiIImXRkhhDCDhKYQQphBQlMIIcwgoSmEEGaQ0BRCCDNIaD5HYmJiaN68OX5+fvj5+fH666+zcuXKp5rW/Pnz2bBhA7///jvh4eE5ttu+fTvXrl3L1TSjoqKYNGnSYzWPHTs2x3E2bNjA/PnzczV9c9oK8bQs/i5HIqtmzZqxYMECANLS0ujcuTOenp44Ojo+1fRefvllXn755Rx//vXXXzN9+nTKli37VNMXwtJIaD7HkpKS0Gq16HQ6/Pz8KFGiBPfu3WPZsmVMnz6d+Ph4DAYD/v7+NG3alJ9++onFixfj7OxMeno6Li4uxMTEsHbtWhYsWEBERATffPMNBoOBDh06ULduXX7//XcCAgJYs2YN69atY8uWLWg0Grp27crAgQM5e/YsgYGBFC1alKJFi1K8ePEc6121ahU///wzGRkZODg4sHDhQiDzMQaDBg0iKSmJ0aNH07ZtWw4dOsSCBQvQ6XRUrFiRmTNnPquPVfzHSWg+Zw4ePIifnx8ajQZra2uCgoKws7MDoEePHri7u7NmzRpKlCjB7NmzuX37NgMGDODHH39k3rx5RERE4OTkxPDhw7NM99atWyxfvpzNmzdjY2PDnDlzaNy4MS+//DLTp0/nwoULbN26lTVr1qDRaBg8eDBubm58/PHHjBkzhpYtW7Js2TL++uuvbOs2GAzcuXOHFStWoNVqGTp0KMePHwegaNGiLFu2jISEBHr37k2rVq0ICgpizZo1lCxZko8++oiNGzcW+AO1hAAJzefOvzfPH1W1alUATp8+TWxsLMeOHQMyLz+9efMm9vb2lChRAoAGDRpkGffixYvUqFEDW1tbgMduYXf69GmuXLnC4MGDAbh79y4XLlzgzz//pF69ekDmQ7ByCk2tVou1tTXjxo2jWLFi/P3332RkZADQqFEjNBoNJUuWxMHBgdu3b3P9+nX8/f0BePDgAS1btqRSpUpmfVZCPA0Jzf8QjUYDgIuLC+XKlWPEiBE8ePCAxYsX4+joSGJiIgkJCTg7O3P8+HHKlStnHLdSpUr89ddfpKWlYWNjw5gxY5gyZQoajQZFUXBxcaF69ep89tlnaDQaVqxYQc2aNXFxceG3336jdevWnDhxIsfa/vjjD3bs2EFERAT379/H29ubh1f4Puxx3rhxg5SUFEqUKEG5cuVYtGgRDg4OREZGUqxYMa5evVqAn54QmSQ0/4N8fX2ZOnUqAwYMICkpiX79+mFjY0NoaChDhw6lePHij23qOjs7M2zYMAYMGIBGo6Fdu3aULVuWBg0aMHHiRL744guaN29O3759SUtLo169epQtW5Zp06YxduxYPv/8c5ydnSlSpEi2NVWuXJmiRYvi7e2NjY0NpUuX5vr160BmT3LgwIGkpKQwc+ZMdDodU6ZMYfjw4SiKgp2dHXPnzpXQFM+E3LBDCCHMIOdpCiGEGSQ0hRDCDBKaQghhBglNIYQwg4SmEEKYQUJTCCHMIKEphBBm+D+X6vG48CQUKgAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cnf_matrix = confusion_matrix(Y_test, Y_pred)\n", + "np.set_printoptions(precision=2)\n", + "\n", + "sns.set_style(\"dark\")\n", + "plt.figure()\n", + "utils.plot_confusion_matrix(cnf_matrix, classes=[0,1],\n", + " title='Confusion matrix, without normalization')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Two-steps classification\n", + "## Activity detection\n", + "This model performs a binary classification.\n", + "### Preprocessing" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "X_train, Y_train, X_test, Y_test, n_features, n_classes, class_weights = preprocessing.loadData(subject=subject,\n", + " label=label,\n", + " folder=folder,\n", + " window_size=window_size,\n", + " stride=stride,\n", + " make_binary=True,\n", + " null_class=True,\n", + " print_info=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "detection_model = models.ConvolutionalRecurrent((window_size, n_features), n_classes, print_info=False)\n", + "\n", + "detection_model.compile(optimizer = Adam(lr=0.001),\n", + " loss = \"categorical_crossentropy\", \n", + " metrics = [\"accuracy\"])\n", + "\n", + "checkpointer = ModelCheckpoint(filepath='./data/model_ATSD_1.hdf5', verbose=1, save_best_only=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 31423 samples, validate on 11505 samples\n", + "Epoch 1/15\n", + "31423/31423 [==============================] - 9s 272us/step - loss: 0.1381 - acc: 0.9541 - val_loss: 0.1605 - val_acc: 0.9485\n", + "\n", + "Epoch 00001: val_loss improved from inf to 0.16046, saving model to ./data/model_ATSD_1.hdf5\n", + "Epoch 2/15\n", + "31423/31423 [==============================] - 6s 177us/step - loss: 0.0763 - acc: 0.9738 - val_loss: 0.1470 - val_acc: 0.9538\n", + "\n", + "Epoch 00002: val_loss improved from 0.16046 to 0.14704, saving model to ./data/model_ATSD_1.hdf5\n", + "Epoch 3/15\n", + "31423/31423 [==============================] - 6s 177us/step - loss: 0.0624 - acc: 0.9780 - val_loss: 0.1562 - val_acc: 0.9537\n", + "\n", + "Epoch 00003: val_loss did not improve\n", + "Epoch 4/15\n", + "31423/31423 [==============================] - 6s 177us/step - loss: 0.0528 - acc: 0.9809 - val_loss: 0.1403 - val_acc: 0.9553\n", + "\n", + "Epoch 00004: val_loss improved from 0.14704 to 0.14033, saving model to ./data/model_ATSD_1.hdf5\n", + "Epoch 5/15\n", + "31423/31423 [==============================] - 6s 178us/step - loss: 0.0454 - acc: 0.9838 - val_loss: 0.1408 - val_acc: 0.9524\n", + "\n", + "Epoch 00005: val_loss did not improve\n", + "Epoch 6/15\n", + "31423/31423 [==============================] - 6s 178us/step - loss: 0.0416 - acc: 0.9851 - val_loss: 0.1236 - val_acc: 0.9581\n", + "\n", + "Epoch 00006: val_loss improved from 0.14033 to 0.12356, saving model to ./data/model_ATSD_1.hdf5\n", + "Epoch 7/15\n", + "31423/31423 [==============================] - 6s 179us/step - loss: 0.0397 - acc: 0.9859 - val_loss: 0.1558 - val_acc: 0.9491\n", + "\n", + "Epoch 00007: val_loss did not improve\n", + "Epoch 8/15\n", + "31423/31423 [==============================] - 6s 178us/step - loss: 0.0330 - acc: 0.9881 - val_loss: 0.1353 - val_acc: 0.9582\n", + "\n", + "Epoch 00008: val_loss did not improve\n", + "Epoch 9/15\n", + "31423/31423 [==============================] - 6s 179us/step - loss: 0.0311 - acc: 0.9889 - val_loss: 0.1495 - val_acc: 0.9558\n", + "\n", + "Epoch 00009: val_loss did not improve\n", + "Epoch 10/15\n", + "31423/31423 [==============================] - 6s 180us/step - loss: 0.0298 - acc: 0.9896 - val_loss: 0.1625 - val_acc: 0.9558\n", + "\n", + "Epoch 00010: val_loss did not improve\n", + "Epoch 11/15\n", + "31423/31423 [==============================] - 6s 181us/step - loss: 0.0264 - acc: 0.9910 - val_loss: 0.1253 - val_acc: 0.9631\n", + "\n", + "Epoch 00011: val_loss did not improve\n", + "Epoch 12/15\n", + "31423/31423 [==============================] - 6s 186us/step - loss: 0.0251 - acc: 0.9907 - val_loss: 0.1934 - val_acc: 0.9534\n", + "\n", + "Epoch 00012: val_loss did not improve\n", + "Epoch 13/15\n", + "31423/31423 [==============================] - 6s 180us/step - loss: 0.0220 - acc: 0.9921 - val_loss: 0.1624 - val_acc: 0.9631\n", + "\n", + "Epoch 00013: val_loss did not improve\n", + "Epoch 14/15\n", + "31423/31423 [==============================] - 6s 176us/step - loss: 0.0240 - acc: 0.9915 - val_loss: 0.1413 - val_acc: 0.9558\n", + "\n", + "Epoch 00014: val_loss did not improve\n", + "Epoch 15/15\n", + "31423/31423 [==============================] - 6s 177us/step - loss: 0.0216 - acc: 0.9922 - val_loss: 0.1587 - val_acc: 0.9540\n", + "\n", + "Epoch 00015: val_loss did not improve\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "detection_model.fit(x = X_train, \n", + " y = to_categorical(Y_train), \n", + " epochs = 15, \n", + " batch_size = 128,\n", + " verbose = 1,\n", + " callbacks=[checkpointer],\n", + " validation_data=(X_test, to_categorical(Y_test)),\n", + " class_weight=class_weights)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 0.92 0.81 0.86 2039\n", + " 1 0.96 0.98 0.97 9466\n", + "\n", + "avg / total 0.95 0.95 0.95 11505\n", + "\n", + "Weighted f1-score: 0.9528960217945623\n" + ] + } + ], + "source": [ + "Y_pred = detection_model.predict_classes(X_test)\n", + "print(classification_report(Y_test, Y_pred))\n", + "print(\"Weighted f1-score:\", f1_score(Y_test, Y_pred, average='weighted'))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 0.95 0.80 0.87 2039\n", + " 1 0.96 0.99 0.97 9466\n", + "\n", + "avg / total 0.96 0.96 0.96 11505\n", + "\n", + "Weighted f1-score: 0.9566994574560295\n" + ] + } + ], + "source": [ + "detection_model_best = load_model('./data/model_ATSD_1.hdf5')\n", + "\n", + "Y_pred = detection_model_best.predict_classes(X_test)\n", + "print(classification_report(Y_test, Y_pred))\n", + "print(\"Weighted f1-score:\", f1_score(Y_test, Y_pred, average='weighted'))" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAU0AAAElCAYAAABgV7DzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3XdYFOcaNvB7liJIETYa/DTqEYUoUU5UNBYkdsCGEgtRUWNONHZiRRCsR6zYSGxJjokGC7EcjzEmlihiwXxGjaInxoYFRQVEFlTY3ff7w8N+UcFlhGVYuH9ee13u7jszz+zCzTMzO7OSEEKAiIiKRKV0AURE5oShSUQkA0OTiEgGhiYRkQwMTSIiGRiaREQyMDSJiGQwy9DU6XT417/+hcDAQAQEBKBr165YtGgRcnNzizXPkSNHwtfXFxs3bpQ9/blz5zBu3LjXXn5Jy8rKwuDBgwt9PiAgAI8ePTLZ8j/55BNcvnwZADBs2DCkp6cDADp06IBz586V2HLi4uLw3Xffldj8SlpiYiK6d+8OAFi+fDl27tz5WvN58f009ftHhbNUuoDXMXPmTGRmZuKbb76Bg4MDcnJyMGnSJISHh2PRokWvNc/U1FQkJCTgzJkzsLCwkD1948aNsWLFitdatilkZma+Mpz+/e9/m3T569atM/z/6NGjJlvOqVOn4ObmZrL5l6Tx48e/9rQvvp+mfv+ocGbXad66dQv/+c9/MG/ePDg4OAAAKleujFmzZqFTp04Anv1VnjRpErp3744ePXpg4cKF0Gq1AJ6F28qVKxEUFIQOHTogNjYWGo0G//jHP6DVahEYGIgbN27g7bffNnRHAAz3s7OzMW7cOAQEBKB3796YPn069Hr9cx2F3OUXpHHjxoiOjkafPn3QtWtX7NmzB+PGjYOfnx8GDx6MnJwcAMD333+Pvn37olevXmjfvr1hftOmTcOTJ08QEBAAnU6HRo0aYfz48fD19cW5c+cM6xMTE4OgoCDodDrcv38f3t7eOHHiRKGvf0ZGBpo0aWJYfmRkJAYNGmR4vkuXLrhy5Yqho5w2bRoAYMiQIbhz5w4AYMuWLQgMDES7du2wdOlSw7RbtmxB9+7d0bNnTwwbNgzXrl0DAISGhuKrr74yjMu/v2/fPhw8eBDr169/qdu8desWOnXqhDlz5qBPnz7o0qUL9u3bBwDIy8vDnDlz0LVrV/To0QPh4eHQaDQAnnXCISEh8Pf3x759+9ChQwdER0ejX79+8PX1RVxcHKZNm4aePXsiMDAQqampAIBffvkFQUFBhvVatmzZS69dft1nz55FQECA4daiRQsMGDBA1vv515/Pzz//3LAu48aNw/379wEAwcHBWLJkCQYOHIgOHTogPDwcer2+0PeWikiYmb1794oPPvjglWOmTJki5syZI/R6vXj69KkYNmyYWLNmjRBCCHd3d7FhwwYhhBDnzp0TjRo1Ek+ePBE3b94U7777rmEe7u7uIi0t7aX7O3bsEMOGDRNCCKHVakV4eLi4fv26OHHihOjWrdtrL/9F7u7u4ptvvhFCCLFmzRrRpEkTcffuXaHT6UTv3r3Frl27hEajEf369RPp6elCCCFOnz5tWIeC1mfHjh0vrY9WqxUDBw4Ua9asEUOHDhWrVq0y+h4EBweLgwcPCiGE6NKli2jdurXQaDTizz//FP7+/kIIIdq3by9+//33l17L9u3bi9mzZwshhLh3755o1KiRSElJEceOHROdOnUyjNu2bZvw9/cXer1eTJ06VXz55ZeG5f/1/ovP5bt586Zwd3c31Ll3717Rrl07IYQQy5cvF2PGjBG5ublCp9OJ0NBQERERYagvJibGMJ/27duLefPmCSGE+OGHH0SDBg3ExYsXhRBCjBo1SqxatUro9XoxaNAgce3aNSGEEHfv3hUNGzYUaWlpz/1cFFTr2bNnxfvvvy8uX74s+/1MS0sT33//vejfv7/Izs4WQgixYsUKw8/noEGDxLhx44ROpxNZWVnC29tbHD9+/FVvLRWB2XWaKpXK6F/L+Ph4DBo0CJIkwdraGkFBQYiPjzc837FjRwDAO++8g9zcXEPXVBTNmjXD5cuXERwcjLVr12LIkCGoU6eOSZbv6+sLAKhduzbc3d3h4uIClUqFt956C5mZmbCzs8Pq1atx+PBhLFu2DKtXr37lunh5eb30mIWFBRYvXox169ZBCIERI0YYfQ06d+6M+Ph4XLlyBS4uLmjatCl+/fVXHDhwAF26dDE6fX5HXq1aNVStWhVpaWk4cuQIunbtCrVaDQCGLu7WrVtG51cYKysrvP/++wAADw8PPHz4EMCz9ycoKAhWVlZQqVQIDg7GkSNHDNO9+Drlr1OtWrVQtWpVNGjQAMCz9yUzMxOSJGH16tVISkpCTEwM5s+fDyEEHj9+/Mr6kpOTMXbsWCxcuBD16tWT/X7mr0tgYCAqV64MABg8eDBOnDhh2L/fvn17qFQq2Nvbo06dOsjMzCzqy0eFMLvQ9PT0xNWrVw2bU/lSU1MxfPhwPHnyBHq9HpIkGZ7T6/WGzWMAqFSpEgAYxggj1yz56wGmWrVqYd++fRg+fDg0Gg0++ugjHDx48LnxJbV8KyurAv+f7+7du+jVqxdu376NZs2aISQk5JXrkf+L9aLbt2+jUqVKuHHjRpF+qfJDMyEhAW3atEHr1q2RkJCAgwcPws/Pz+j0lpb/f1e6JEkQQhT4h1AIAa1WaxiTLy8vz+gyABhCMX85+Qp6f/46zxdfJ2tr6+fm+aKcnBz07t0bSUlJ8PDwwJQpU2BpafnKn6u0tDR88sknmDhxIlq0aAFA/vtZ2Lr89WfNxsbG8P8XX0d6PWYXmi4uLujRowfCwsIMwanRaDBz5kw4OTnBxsYG3t7e2LhxI4QQyM3NxdatW9G6dWtZy1Gr1YYd77t37zY8Hhsbi2nTpsHb2xuTJ0+Gt7c3Lly48Ny0JbH8ojh//jzUajVGjRoFb29v/PLLLwCefRLA0tISOp3O6C/Jo0ePMHnyZMyfPx/du3dHeHi40eVWr14dzs7O2Lx5M9q0aQNvb2/8/PPPePjwoaEL+ysLC4vnfpEL0rZtW+zZs8ewn27btm1wcnJCnTp14OzsjPPnzwN49sfx5MmTsuZd0LI2bdqEvLw86PV6fPfdd2jTpo2sefxVcnIyNBoNQkJC0KFDByQmJiI3N7fQLaLs7GwMHz4cffr0Qc+ePQ2Pv8772bZtW2zbts3QkW7YsAHNmzd/LuipZJldaALAjBkzUL9+fQQFBSEgIAB9+/ZF/fr1MXfuXADA9OnTkZ6ejh49eqBHjx6oW7cuPv30U1nLmD59OmbPno3evXvjypUrqFatGgCgV69e0Ol06Nq1KwIDA5GVlYXg4OCXpi3u8ouiTZs2cHFxgZ+fH/z9/XHnzh2o1WokJyejWrVq8PT0RLdu3ZCRkfHK9WzXrh28vb0xZswY3Lx503BQJSAgoNAj8J07d0Z6ejo8PDxQq1Yt2NjYGA7EvcjPzw/BwcG4dOnSK9dl6NChGDJkCLp164adO3dizZo1hs3n+/fvw9fXF2FhYWjZsqVhOh8fH2zevBlr1qwpyksGABg5ciSqVq2KXr16wd/fH1qttkh/LArz9ttvo127dvD394e/vz9++eUX1K9fH8nJyQWO37hxI/744w/s27cPvXr1MhwQep33s0+fPmjVqhX69u0Lf39/XLhwAYsXL37tdSHjJMF+nQqxdOlS9OzZE/Xq1VO6FKIywyw7TTI9IQRq1qzJwCR6ATtNIiIZ2GkSEcmgyGmUF1OylVgslYIazjbGB5HZqmIr/xTjV7FtMkbW+MenY0p0+a+DnSYRkQxmecEOIionJPPr2xiaRKScv5zNZC4YmkSkHHaaREQysNMkIpJBVbJH40sDQ5OIlMPNcyIiGbh5TkQkAztNIiIZ2GkSEcnATpOISAZ2mkREMrDTJCKSgaFJRCSDipvnRERFx06TiEgGHggiIpKBnSYRkQzsNImIZOBVjoiIZODmORGRDNw8JyKSgZ0mEZEM7DSJiGRgp0lEJANDk4hIBm6eExHJwE6TiEgGdppERDKw0yQikoGdJhFR0UkMTSKioituaObl5SE0NBS3b9+GSqXCnDlzYGlpidDQUEiSBDc3N8yYMQMqlQoxMTE4dOgQLC0tERYWBk9PTyQnJxc49lXMb4cCEZUbkkqSdXvR4cOHodVqsXnzZowePRrLli1DVFQUQkJCEBsbCyEEDhw4gKSkJJw8eRJxcXGIjo7GrFmzAKDAscYwNIlIMZIkybq9qG7dutDpdNDr9dBoNLC0tERSUhJatGgBAPDx8cGxY8dw6tQpeHt7Q5Ik1KhRAzqdDunp6QWONYab50SkmOJunleuXBm3b9+Gv78/MjIysHr1avz666+G+drZ2SErKwsajQZOTk6G6fIfF0K8NNYYhiYRKaa4obl+/Xp4e3tj4sSJuHPnDoYMGYK8vDzD89nZ2XB0dIS9vT2ys7Ofe9zBweG5/Zf5Y43h5jkRKUeSeXuBo6MjHBwcAABVqlSBVquFh4cHEhMTAQDx8fHw8vJC06ZNkZCQAL1ej5SUFOj1eqjV6gLHGi1ZCCGKudqyXUzJNj6IzFINZxulSyATqmJbsl9P4TRwo6zxD78b9Nz97OxshIWF4f79+8jLy8PgwYPRqFEjREREIC8vD66urpg7dy4sLCywcuVKxMfHQ6/XY9q0afDy8sK1a9cKHPsqDE0qUQzN8q2kQ9N50HeyxmdsHFiiy38d3KdJRIrhh9uJiGRgaBIRyWF+mcnQJCLlsNMkIpKBoUlEJANDk4hIhoIuwlHWMTSJSDHsNImIZGBoEhHJwNAkIpLD/DKToUlEymGnSUQkgzmGJq+nWUxnf/sVQ/r4AQDSHtzD6I/6IziwCwYGdMKN61cN4/R6PYYP6o3N33753PT7f9yFyaM/KtWaqeh0Oh1Gj/gHfDu0hX/ndrh29QrOnP4NHdq2hH+n9zF5wnjo9XoAQNjUiejQtiU6vd8aJ44fVbhy81Dcr7tQAjvNYvjqi6XYtW0TbG3tAABL5kage+9+8O/5ARKPHsa1y5dQ+2+uAIDlC2Yj82HGc9PPi5yMo4f2o8E7nqVeOxXNjz/sBgD8dPAIjsQfQtjUSbiTkoIFS5bivZatMXdmBOK2bILHO41w8sRxHIg/jqtXLmPY4IE4fOykwtWbgbKRg7Kw0yyGWnXqYvm6WMP93349gdQ7KRjWvzt279iK5q3bAgB+2r0DKpWEtu07Pzd9k2bvITJqWanWTPJ07xmA5Z+vBgDcvHEDb77pgpTbt/Bey9YAgPdatcaJY0dRo0ZN2FaujKdPnyLr0SNYWbEfKQpz7DRNEpp6vR6RkZHo378/goODkZycbIrFKK5Lt16wsrIy3E+5lQxHJyd8vWU3/k/Nt/DV59H4879J+GHnVoydHPHS9P4BfcrMDwIVztLSEp9+8hGmThyPgN6B+Fvdukg4chgAsHfPD8jOyYaFpSVUKhWav/sOArr7YmzIRIWrNg/mGJom+XO4f/9+5ObmYsuWLThz5gzmz5+PVatWmWJRZUoVZzU6dO4KAGjXuSuWL5iFp0+fIvXuHXzUrytu37wBKytr1KxV56Wuk8q21ev+hdQ5Uej4fmtsituBmdOnYUX0YjRp5gXrStbYHLsBLi7VsX3Xj8jKyoJ/p/fRvEVL1KhZU+nSy7SyEoRymCQ0T506hbZtn22avvvuuzh//rwpFlPmNGveCvEHf0bPPh/i1IkE1HdviEnT5xqej1nyT1St5sLANCObYzci5fYtTJgcCtvKlaFSqfDz3j2IWf0l/k+NGpg8YTw6d/FFeno67OzsYWFhAQcHB1hbV0J2tkbp8ss8hub/aDQa2NvbG+5bWFhAq9XC0rJ87+eZPCMKkZNGY/O3X8Le0RGLYr5WuiQqph4BvTF6xMfw79wO2rw8RC2MhkoloW/vHqhc2RZtfdqhi19X6HQ6JB4/hi7tvaHT6dA36EO4ub+tdPlln/llpmm+WC0qKgp///vf0bXrs01VHx8fxMfHG57nF6uVX/xitfKtpL9Yrd7EH2WNv7LEv0SX/zpMciCoadOmhpA8c+YM3N3dTbEYIjJzkiTvVhaYZHu5c+fOOHr0KIKCgiCEwLx580yxGCIyc9yn+T8qlQqzZ882xayJqBwxw8zkGUFEpBx2mkREMphhZjI0iUg5Kn5HEBFR0bHTJCKSgfs0iYhkMMPMZGgSkXLYaRIRycDQJCKSwQwzk6FJRMphp0lEJAM/p0lEJIMZNpoMTSJSDjfPiYhkMMPMZGgSkXLYaRIRyWCGmcnQJCLlsNMkIpLBDDOToUlEymGnSUQkQ0lk5po1a3Dw4EHk5eXhww8/RIsWLRAaGgpJkuDm5oYZM2ZApVIhJiYGhw4dgqWlJcLCwuDp6Ynk5OQCx76KSb7Cl4ioKCRJknV7UWJiIk6fPo1NmzZhw4YNuHv3LqKiohASEoLY2FgIIXDgwAEkJSXh5MmTiIuLQ3R0NGbNmgUABY41hqFJRIopbmgmJCTA3d0do0ePxqeffop27dohKSkJLVq0AAD4+Pjg2LFjOHXqFLy9vSFJEmrUqAGdTof09PQCxxrDzXMiUkxxN88zMjKQkpKC1atX49atWxg5ciSEEIaAtbOzQ1ZWFjQaDZycnAzT5T9e0FhjGJpEpJjiXrDDyckJrq6usLa2hqurKypVqoS7d+8ans/OzoajoyPs7e2RnZ393OMODg7P7b/MH2u05mJVTERUDMXdPG/WrBmOHDkCIQRSU1Px+PFjtGrVComJiQCA+Ph4eHl5oWnTpkhISIBer0dKSgr0ej3UajU8PDxeGmsMO00iUkxxN8/bt2+PX3/9FX369IEQApGRkXjrrbcQERGB6OhouLq6wtfXFxYWFvDy8kL//v2h1+sRGRkJAJg6depLY43WLIQQxStbvosp2cYHkVmq4WyjdAlkQlVsLUp0fp1jTsgav29MyxJd/utgp0lEijHDz7YzNIlIOTwjiIhIBjP8tguGJhEph50mEZEMZpiZDE0iUo4E80tNhiYRKYb7NImIZChX+zT79+//0grln9y+efNmkxdGROWfGWZm4aEZHR1dmnUQUQWkMsPULDQ0a9asCQBITU3FokWLkJGRAV9fX7z99tuG54iIiqO4VzlSgtGrHEVEROCDDz5Abm4uvLy88M9//rM06iKiCkCS5N3KAqOh+fTpU7Rq1QqSJBmuV0dEVBJUkiTrVhYYPXpubW2NI0eOQK/X48yZM7C2ti6NuoioAigbMSiP0U5zzpw52L59OzIyMvD1119j5syZpVAWEVUExb0IsRKMdprVq1fHiBEjcP36dbi5uaFWrVqlURcRVQBmeBzIeGh+8cUXOHLkCBo3boz169fDz88PQ4cOLYXSiKi8KyvdoxxGQzM+Ph6xsbFQqVTQarUYMGAAQ5OISoQZZqbxfZpqtRqPHz8GAOTl5UGtVpu8KCKqGMrVPs380yjT0tIMH2q/cuXKc98dTERUHOVqnyZPoyQiUysr3aMcRk+jTE5Oxt69e5GXlwcAuHfvHmbPnl061RFRuWZ+kVmEfZpTp04FAPz222+4desWHj58aPKiiKhiMMczgoyGpo2NDUaMGAEXFxfMnz8fDx48KI26iKgCMMdzz41+5EgIgfv37yMnJwc5OTnIzMwsjbqIqAIol1c5GjNmDPbt24eePXuiY8eO8PHxKY26iKgCMMfNc6OdZvPmzdG8eXMAQMeOHU1eEBFVHGUkB2UpNDS9vb0LnSghIaFYC637pl2xpqeyy7n5GKVLIBN6fDqmROdXrj5yVNxgJCIyxuj+wTKI30ZJRIopV50mEZGpmeHB86J1xxqNBn/88QdycnJMXQ8RVSAqSd6tLDDaae7duxerV6+GTqeDn58fJEnCqFGjSqM2IirnzHHz3GinuX79emzduhVOTk4YNWoU9u/fXxp1EVEFUC47TZVKBWtra8P17GxtbUujLiKqAMyw0TQeml5eXpgwYQJSU1MRGRmJxo0bl0ZdRFQBlJWzfOQwGpoTJkxAfHw8PDw8UK9ePbRv37406iKiCsAcP6dptOadO3ciPT0dVatWRWZmJnbu3FkadRFRBWChkmTdygKjneaVK1cAPLva0cWLF+Hk5IRevXqZvDAiKv/McOvceGhOnDjR8H8hBEaMGGHSgoio4igjzaMsRkMzNzfX8P/79+/j1q1bJi2IiCqOcnkgKP8D7UII2NjY4OOPPy6NuoioAjDDzDQemuPHj0dAQEBp1EJEFYw5bp4bPXoeFxdXGnUQUQUkyfxXFhRpn2avXr1Qt25dqFTPMnbJkiUmL4yIyr+S6jTT0tIQGBiIr7/+GpaWlggNDYUkSXBzc8OMGTOgUqkQExODQ4cOwdLSEmFhYfD09ERycnKBY1/FaGhOmjSpZNaKiOgFJRGaeXl5iIyMhI2NDQAgKioKISEheO+99xAZGYkDBw6gRo0aOHnyJOLi4nDnzh2MHTsW27ZtK3Bs586dX11zYU+EhIQAAFq0aPHSjYioJORf06Kot4IsWLAAQUFBePPNNwEASUlJhpzy8fHBsWPHcOrUKXh7e0OSJNSoUQM6nQ7p6ekFjjWm0NBMT0+X/QIQEclR3Kscbd++HWq1Gm3btjU8JoQwBKydnR2ysrKg0Whgb29vGJP/eEFjjSl08/zmzZuIjo4u8LkJEyYYnTERkTHF/cjRtm3bIEkSjh8/josXL2Lq1KnPNXzZ2dlwdHSEvb09srOzn3vcwcHhuf2X+WONKbTTtLGxQd26dQu8ERGVhOJ+7/l3332HjRs3YsOGDWjYsCEWLFgAHx8fJCYmAgDi4+Ph5eWFpk2bIiEhAXq9HikpKdDr9VCr1fDw8HhprDGFdppVq1ZF7969X/e1ICIyyhSf05w6dSoiIiIQHR0NV1dX+Pr6wsLCAl5eXujfvz/0ej0iIyMLHWuMJIQQBT2xYMECTJ06tWTX5n+eaE0yWyoD+L3n5VtJf+/550evyxo/us3fSnT5r6PQTtNUgUlElK9cnkZJRGQq5ngaJUOTiBRTLq9yRERkKmaYmQxNIlIOO00iIhnMMDMZmkSkHHP8NkqGJhEpprCLcJRlDE0iUoz5RSZDk4gUxANBREQymF9kMjSJSEFm2GgyNIlIOTwQREQkgwVDk4io6MwvMhmaRKQgbp4TEcnAM4KIiGRgp0lEJIP5RSZDk4gUZIaNJkOTiJSjMsNek6FJRIphp0lEJIPETpOIqOjYaRIRycB9mkREMrDTJCKSgRchJiKSQWV+mcnQJCLl8Og5IS8vD/8YNgTJ16/DwsICX6xeh7cbNAAAbN4Ui1Wfr8ThhOMKV0mvYm1libWzBqFuzTfwKPsJQuZvRfU3HBH1WW8IIbD3aBKi1u4FACyYGIjW77pCrxcIjd6B42evolZ1Z6yeORCWFhaQJGD0nE34M/mewmtVNpnh1jlDs6Tt/XEPtFotDh05hgP792FGZDg2b92Gs2fO4Jt/fQUhhNIlkhHDAltDk/MU7w9ZArc6b2Lp1H54w8kOAyZ/heSUNOxdOw57Dp+HXgi09KyLtsGLUa92NXwb9RHaDFyIyFHdsXpzPP5z6Hd0atUQc8b2RNCkL5VerTLJHDtNc7wyU5nm5uYOrVYLvV6PR48ewcrKCmlpaYgID8WiJcuULo+KoIFrdfx8NAkA8GfyPTSo6wKfwYuRnJIGO1trONrbIj0zGyn3HiLnSS4qWVvC0c4GWq0OABAavR0/JpwHAFhaqPAkV6vYupR1KknerSxgp1nC7OztceP6dfy9UQOkPXiA73f8B58O/xgLFy+Fra2t0uVREfz+x234+zTCrl9+R4vGf0ONN50ghECLxn/Dt/M/wsWrd3E/Q4NKVpbQC4GzOyLgaG+L0XNiAQBpD7MBAG513kTUZ73Rb8JaJVenTGOn+Rdnz55FcHCwqWZfZq1cvhSduvji3IVLSDx1Fh3beSPp/DmMGzMSwQOD8N+LFzBpQojSZdIrfPPv48jSPMFP68ajq09jnL54A3q9wMlz19Gg2wycuXgTkz7qjIE9WiD1wSN49JiJht1mIHxEV9SoVgUA4OPlhq3Rw/FxxLfcn/kKkiTvVhaYpNNct24ddu3aVSE7K2dnZ1haWQEA1Go1atepg19/+x12dnZIvn4dwQODsDiam+llmdc7dXDszFVMWbIdTT1qo16tqtj/VQj6hKzBw6zH0OQ8QSVrK2Q8yoHmcS70eoGsnCd4mquFXeVK8PFyw+LJfRAw5nPcuJOh9OqUaWUkB2UxSWjWrl0bK1euxJQpU0wx+zJt7PjPMOKTYejYri1yc3Mxa8482NnZKV0WyXD5xj1EjuqGkMEd8TArByNnxaLZO7Xx75hReJqnxd0HjzBy1nd4/DQPrd6th1/WT4CFSoUtP/5f/Jl8DxsXDoO1lQXWzR4MALh0PRVj/7lZ4bUqm8zxw+2SMNHh3Fu3bmHChAnYunXrS8894X7xcsu5+RilSyATenw6pkTnd+LyQ1njW9Z3KtHlvw4eCCIi5Zhfo8nQJCLlmOPRc4YmESnGDHdpmi4033rrrQL3ZxIR5WNoEhHJwM1zIiIZ2GkSEclQ3MzMy8tDWFgYbt++jdzcXIwcORL169dHaGgoJEmCm5sbZsyYAZVKhZiYGBw6dAiWlpYICwuDp6cnkpOTCxz7KrxgBxEpR5J5e8GuXbvg5OSE2NhYrFu3DnPmzEFUVBRCQkIQGxsLIQQOHDiApKQknDx5EnFxcYiOjsasWbMAoMCxxjA0iUgxksx/L/Lz88P48eMN9y0sLJCUlIQWLVoAAHx8fHDs2DGcOnUK3t7ekCQJNWrUgE6nQ3p6eoFjjWFoEpFiinvBDjs7O9jb20Oj0WDcuHEICQmBEALS/wbb2dkhKysLGo0G9vb2z02XlZVV4FhjGJpEpJhibp0DAO7cuYPBgwcjICAAPXr0eG6fZHZ2NhwdHWFvb4/s7OznHndwcChwrDEMTSJSTjFT88GDBxg2bBgmT56MPn36AAA8PDyQmJgIAIiPj4eXlxeaNm2KhIQE6PV6pKSkQK/XQ61WFziYKKEzAAAHj0lEQVTWaMmmumDHq/CCHeUXL9hRvpX0BTt+v6mRNd6zlv1z9+fOnYsff/wRrq6uhsfCw8Mxd+5c5OXlwdXVFXPnzoWFhQVWrlyJ+Ph46PV6TJs2DV5eXrh27RoiIiJeGvsqDE0qUQzN8q2kQ/PcLXmh2fgte+ODTIyf0yQixZjhZ9sZmkSkIDNMTYYmESmG554TEclQVr6WVw6GJhEph6FJRFR03DwnIpKBl4YjIpLBDDOToUlECjLD1GRoEpFiuE+TiEgG7tMkIpLBDDOToUlECjLD1GRoEpFiuE+TiEgG7tMkIpLBDDOToUlECjLD1GRoEpFiVGa4fc7QJCLFmF9kMjSJSEFm2GgyNIlISeaXmgxNIlIMO00iIhnMMDMZmkSkHHaaREQy8DRKIiI5zC8zGZpEpBwzzEyGJhEph/s0iYhk4D5NIiI5zC8zGZpEpBwVQ5OIqOi4eU5EJIM5HghSKV0AEZE5YadJRIoxx06ToUlEiuE+TSIiGdhpEhHJYIaZydAkIgWZYWoyNIlIMdynSUQkA/dpEhHJYIaZydAkIgWZYWoyNIlIMea4T1MSQgiliyAiMhc895yISAaGJhGRDAxNIiIZGJpERDIwNE1ACIE//vgDly5dUroUIiph/MhRCRNCYOTIkXB2dkZ6ejpq1qyJyMhIpcuiEvTNN99gyJAhSpdBCmGnWcK2bt2KN954A1FRUVixYgUuXLiAWbNmKV0WlZDs7GzExsYiOjpa6VJIIQzNElavXj1IkoTU1FRUqlQJ3377LS5cuMBfsnLi3LlzUKvVuH37NsLCwpQuhxTA0Cxh9erVg62tLc6ePYv09HRYW1tjxYoVePz4sdKlUQmoW7cuBgwYgPnz5+Pp06eIiIhQuiQqZQzNEubs7Iz+/fvj8OHDOHLkCG7fvo3ffvsNV65cwdOnT5Uuj4rJxcUFnTp1gpWVFcLDw6HVajFhwgSly6JSxNMoTeTatWvYvXs3Ll26hCdPnmDKlClwc3NTuiwqYenp6Vi2bBnGjBmDN998U+lyqBQwNE1Iq9Xi0aNHAAC1Wq1wNWQqer0eKhU32ioKhiYRkQz880hEJANDk4hIBoYmEZEMDE0iIhkYmkREMjA0y5HExES0atUKwcHBCA4ORr9+/bBhw4bXmtfixYuxfft2XLx4ETExMYWO27dvH1JTU4s0z/j4eISGhr5U82effVboNNu3b8fixYuLNH85Y4leF69yVM60bNkSS5cuBQDk5ubCz88PAQEBcHR0fK35NWzYEA0bNiz0+W+//RYzZ86Ei4vLa82fyNwwNMsxjUYDlUoFCwsLBAcHw9nZGY8ePcLatWsxc+ZMJCcnQ6/XIyQkBO+99x5++uknrFq1Cmq1Gnl5eXB1dUViYiI2b96MpUuXIi4uDps2bYJer0fHjh3RuHFjXLx4EVOnTkVsbCy2bNmC3bt3Q5IkdO3aFYMHD8aVK1cQFhYGW1tb2NraokqVKoXWu3HjRvz888/QarVwcHDAypUrAQBnzpzBkCFDoNFoMHbsWLRr1w4nT57E0qVLYWFhgVq1amH27Nml9bJSBcfQLGdOnDiB4OBgSJIEKysrREREwM7ODgDQo0cPdO7cGbGxsXB2dsa8efOQkZGBQYMG4YcffsCiRYsQFxcHJycnDB8+/Ln5pqWlYd26ddi1axesra0xf/58NG/eHA0bNsTMmTNx48YN7NmzB7GxsZAkCUOHDoW3tzeWL1+OcePGoU2bNli7di2uXr1aYN16vR4PHz7E+vXroVKp8PHHH+PcuXMAAFtbW6xduxbp6eno27cv2rZti4iICMTGxuKNN97AsmXLsGPHDlha8seZTI8/ZeXMXzfPX1S3bl0AwKVLl3Dq1Cn8/vvvAJ6d7vngwQPY29vD2dkZANCkSZPnpr158ybc3NxgY2MDAC9dFu3SpUtISUnB0KFDAQCZmZm4ceMG/vzzT3h6egIAmjZtWmhoqlQqWFlZYcKECahcuTLu3r0LrVYLAGjWrBkkScIbb7wBBwcHZGRk4N69ewgJCQEAPHnyBG3atEHt2rVlvVZEr4OhWYFIkgQAcHV1RfXq1fHpp5/iyZMnWLVqFRwdHZGVlYX09HSo1WqcO3cO1atXN0xbu3ZtXL16Fbm5ubC2tsa4ceMQHh4OSZIghICrqyvq16+PL7/8EpIkYf369XB3d4erqytOnz4NHx8fnD9/vtDa/vvf/2L//v2Ii4vD48ePERgYiPwzfPM7zvv37yMnJwfOzs6oXr06vvjiCzg4OODAgQOoXLky7ty5Y8JXj+gZhmYFFBQUhOnTp2PQoEHQaDQYMGAArK2tERUVhY8//hhVqlR5aVNXrVbjk08+waBBgyBJEtq3bw8XFxc0adIEU6ZMwddff41WrVrhww8/RG5uLjw9PeHi4oIZM2bgs88+w1dffQW1Wo1KlSoVWFOdOnVga2uLwMBAWFtbo1q1arh37x6AZ53k4MGDkZOTg9mzZ8PCwgLh4eEYPnw4hBCws7PDwoULGZpUKnjBDiIiGfg5TSIiGRiaREQyMDSJiGRgaBIRycDQJCKSgaFJRCQDQ5OISIb/BwKiKogYa/KpAAAAAElFTkSuQmCC\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cnf_matrix = confusion_matrix(Y_test, Y_pred)\n", + "np.set_printoptions(precision=2)\n", + "\n", + "sns.set_style(\"dark\")\n", + "plt.figure()\n", + "utils.plot_confusion_matrix(cnf_matrix, classes=[0,1],\n", + " title='Confusion matrix, without normalization')" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "Y_pred_d = Y_pred" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Activity classification" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "X_train, Y_train, X_test, Y_test, n_features, n_classes, class_weights = preprocessing.loadData(subject=subject,\n", + " label=label,\n", + " folder=folder,\n", + " window_size=window_size,\n", + " stride=stride,\n", + " make_binary=False,\n", + " null_class=False,\n", + " print_info=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Model" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "classification_model = models.ConvolutionalRecurrent((window_size, n_features), n_classes, print_info=False)\n", + "\n", + "classification_model.compile(optimizer = Adam(lr=0.001),\n", + " loss = \"categorical_crossentropy\", \n", + " metrics = [\"accuracy\"])\n", + "\n", + "checkpointer = ModelCheckpoint(filepath='./data/model_ATSC_1.hdf5', verbose=1, save_best_only=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Training" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train on 27949 samples, validate on 9466 samples\n", + "Epoch 1/15\n", + "27949/27949 [==============================] - 9s 309us/step - loss: 0.3618 - acc: 0.8527 - val_loss: 0.2699 - val_acc: 0.9145\n", + "\n", + "Epoch 00001: val_loss improved from inf to 0.26991, saving model to ./data/model_ATSC_1.hdf5\n", + "Epoch 2/15\n", + "27949/27949 [==============================] - 5s 182us/step - loss: 0.2441 - acc: 0.8977 - val_loss: 0.2743 - val_acc: 0.9184\n", + "\n", + "Epoch 00002: val_loss did not improve\n", + "Epoch 3/15\n", + "27949/27949 [==============================] - 5s 176us/step - loss: 0.2131 - acc: 0.9097 - val_loss: 0.2686 - val_acc: 0.9201\n", + "\n", + "Epoch 00003: val_loss improved from 0.26991 to 0.26859, saving model to ./data/model_ATSC_1.hdf5\n", + "Epoch 4/15\n", + "27949/27949 [==============================] - 5s 175us/step - loss: 0.1971 - acc: 0.9184 - val_loss: 0.2504 - val_acc: 0.9224\n", + "\n", + "Epoch 00004: val_loss improved from 0.26859 to 0.25044, saving model to ./data/model_ATSC_1.hdf5\n", + "Epoch 5/15\n", + "27949/27949 [==============================] - 5s 175us/step - loss: 0.1819 - acc: 0.9252 - val_loss: 0.2252 - val_acc: 0.9259\n", + "\n", + "Epoch 00005: val_loss improved from 0.25044 to 0.22523, saving model to ./data/model_ATSC_1.hdf5\n", + "Epoch 6/15\n", + "27949/27949 [==============================] - 5s 175us/step - loss: 0.1695 - acc: 0.9304 - val_loss: 0.2747 - val_acc: 0.9256\n", + "\n", + "Epoch 00006: val_loss did not improve\n", + "Epoch 7/15\n", + "27949/27949 [==============================] - 5s 176us/step - loss: 0.1651 - acc: 0.9326 - val_loss: 0.2376 - val_acc: 0.9225\n", + "\n", + "Epoch 00007: val_loss did not improve\n", + "Epoch 8/15\n", + "27949/27949 [==============================] - 5s 176us/step - loss: 0.1580 - acc: 0.9346 - val_loss: 0.2594 - val_acc: 0.9169\n", + "\n", + "Epoch 00008: val_loss did not improve\n", + "Epoch 9/15\n", + "27949/27949 [==============================] - 5s 176us/step - loss: 0.1531 - acc: 0.9375 - val_loss: 0.2125 - val_acc: 0.9304\n", + "\n", + "Epoch 00009: val_loss improved from 0.22523 to 0.21254, saving model to ./data/model_ATSC_1.hdf5\n", + "Epoch 10/15\n", + "27949/27949 [==============================] - 5s 177us/step - loss: 0.1467 - acc: 0.9414 - val_loss: 0.2676 - val_acc: 0.9220\n", + "\n", + "Epoch 00010: val_loss did not improve\n", + "Epoch 11/15\n", + "27949/27949 [==============================] - 5s 177us/step - loss: 0.1428 - acc: 0.9426 - val_loss: 0.2377 - val_acc: 0.9246\n", + "\n", + "Epoch 00011: val_loss did not improve\n", + "Epoch 12/15\n", + "27949/27949 [==============================] - 5s 177us/step - loss: 0.1353 - acc: 0.9449 - val_loss: 0.2008 - val_acc: 0.9237\n", + "\n", + "Epoch 00012: val_loss improved from 0.21254 to 0.20084, saving model to ./data/model_ATSC_1.hdf5\n", + "Epoch 13/15\n", + "27949/27949 [==============================] - 5s 177us/step - loss: 0.1343 - acc: 0.9452 - val_loss: 0.1970 - val_acc: 0.9305\n", + "\n", + "Epoch 00013: val_loss improved from 0.20084 to 0.19697, saving model to ./data/model_ATSC_1.hdf5\n", + "Epoch 14/15\n", + "27949/27949 [==============================] - 5s 179us/step - loss: 0.1287 - acc: 0.9479 - val_loss: 0.2205 - val_acc: 0.9274\n", + "\n", + "Epoch 00014: val_loss did not improve\n", + "Epoch 15/15\n", + "27949/27949 [==============================] - 5s 181us/step - loss: 0.1232 - acc: 0.9504 - val_loss: 0.2918 - val_acc: 0.9263\n", + "\n", + "Epoch 00015: val_loss did not improve\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "classification_model.fit(x = X_train,\n", + " y = to_categorical(Y_train), \n", + " epochs = 15, \n", + " batch_size = 128,\n", + " verbose = 1,\n", + " callbacks=[checkpointer],\n", + " validation_data=(X_test, to_categorical(Y_test)),\n", + " class_weight=class_weights)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 0.94 0.90 0.92 3958\n", + " 1 0.83 0.90 0.86 2333\n", + " 2 0.99 1.00 0.99 2733\n", + " 3 1.00 0.85 0.92 442\n", + "\n", + "avg / total 0.93 0.93 0.93 9466\n", + "\n", + "Weighted f1-score: 0.9267518071621716\n" + ] + } + ], + "source": [ + "Y_pred = classification_model.predict_classes(X_test)\n", + "print(classification_report(Y_test, Y_pred))\n", + "print(\"Weighted f1-score:\", f1_score(Y_test, Y_pred, average='weighted'))" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 0.92 0.93 0.93 3958\n", + " 1 0.86 0.87 0.87 2333\n", + " 2 0.99 1.00 0.99 2733\n", + " 3 1.00 0.86 0.92 442\n", + "\n", + "avg / total 0.93 0.93 0.93 9466\n", + "\n", + "Weighted f1-score: 0.9305368499675762\n" + ] + } + ], + "source": [ + "classification_model_best = load_model('./data/model_ATSC_1.hdf5')\n", + "\n", + "Y_pred = classification_model_best.predict_classes(X_test)\n", + "print(classification_report(Y_test, Y_pred))\n", + "print(\"Weighted f1-score:\", f1_score(Y_test, Y_pred, average='weighted'))" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAU0AAAElCAYAAABgV7DzAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAIABJREFUeJzt3XlYVOX///HnzLDKoiKGWW64laG/NNzR3BA1FSUNXDDDNPWjhkuiKGhigktYSe6VlZmKllmZpZYhKliWG5qWKS4okahsss2c3x9+ndJUHGA4DrwfXnNdzMx9znmfmfE19zlnzn00iqIoCCGEeCBatQsQQghLIqEphBAmkNAUQggTSGgKIYQJJDSFEMIEEppCCGECCU0hhDCBRYamXq/ngw8+wM/PD19fX3r16sXChQvJz88v0TzHjBmDj48Pa9euNXn6o0ePMmHChGIvv7RlZmYybNiwez7v6+tLRkaG2ZY/cuRI/vjjDwCCgoJIT08HoEuXLhw9erTUlhMbG8snn3xSavMrbYmJifTu3RuAt99+my1bthRrPne+n+Z+/8S9WaldQHHMnj2b69ev8+GHH+Lk5EROTg5TpkxhxowZLFy4sFjzTE1NJT4+nkOHDqHT6UyevmnTprzzzjvFWrY5XL9+/b7h9MUXX5h1+atWrTL+vXfvXrMt5+DBgzRs2NBs8y9Nr776arGnvfP9NPf7J+7N4nqaFy5c4Msvv2TevHk4OTkBUKlSJV5//XW6desG3PxWnjJlCr1796ZPnz4sWLCAwsJC4Ga4LVmyhICAALp06cK6devIysri5ZdfprCwED8/P86dO0fjxo2NvSPAeD87O5sJEybg6+tL//79mTlzJgaD4bYehanLv5umTZsSHR3NgAED6NWrF9u2bWPChAn06NGDYcOGkZOTA8CmTZsYOHAg/fr1o3Pnzsb5TZ8+ndzcXHx9fdHr9Xh4ePDqq6/i4+PD0aNHjesTExNDQEAAer2etLQ0vLy8SEhIuOfrf/XqVZo3b25cfnh4OEOHDjU+3717d06fPm3sUU6fPh2AF198kUuXLgGwYcMG/Pz86NSpE4sXLzZOu2HDBnr37k3fvn0JCgrizJkzAEybNo333nvP2O7W/R07dvD999+zZs2a//Q2L1y4QLdu3YiIiGDAgAF0796dHTt2AFBQUEBERAS9evWiT58+zJgxg6ysLOBmTzg4OJiePXuyY8cOunTpQnR0NC+88AI+Pj7ExsYyffp0+vbti5+fH6mpqQD88MMPBAQEGNfrrbfe+s9rd6vuw4cP4+vra7y1atWKwYMHm/R+/vvz+e677xrXZcKECaSlpQEQGBjIm2++yZAhQ+jSpQszZszAYDDc870VD0ixMNu3b1eef/75+7aZOnWqEhERoRgMBiUvL08JCgpSVqxYoSiKojRq1Ej5+OOPFUVRlKNHjyoeHh5Kbm6ucv78eeXpp582zqNRo0bKlStX/nP/888/V4KCghRFUZTCwkJlxowZytmzZ5WEhATlueeeK/by79SoUSPlww8/VBRFUVasWKE0b95cuXz5sqLX65X+/fsrW7duVbKyspQXXnhBSU9PVxRFUX799VfjOtxtfT7//PP/rE9hYaEyZMgQZcWKFcrw4cOVZcuWFfkeBAYGKt9//72iKIrSvXt3pV27dkpWVpby+++/Kz179lQURVE6d+6sHDly5D+vZefOnZU5c+YoiqIof/31l+Lh4aGkpKQo+/btU7p162Zst3nzZqVnz56KwWBQQkJClNWrVxuX/+/7dz53y/nz55VGjRoZ69y+fbvSqVMnRVEU5e2331bGjRun5OfnK3q9Xpk2bZoSFhZmrC8mJsY4n86dOyvz5s1TFEVRvv76a+WJJ55QTpw4oSiKoowdO1ZZtmyZYjAYlKFDhypnzpxRFEVRLl++rDz55JPKlStXbvtc3K3Ww4cPK88++6zyxx9/mPx+XrlyRdm0aZPi7++vZGdnK4qiKO+8847x8zl06FBlwoQJil6vVzIzMxUvLy9l//7993trxQOwuJ6mVqst8tsyLi6OoUOHotFosLGxISAggLi4OOPzXbt2BeCpp54iPz/f2Gt6EM888wx//PEHgYGBrFy5khdffJE6deqYZfk+Pj4A1K5dm0aNGuHm5oZWq+Xxxx/n+vXrODg4sHz5cn788Ufeeustli9fft918fT0/M9jOp2ORYsWsWrVKhRF4ZVXXinyNfD29iYuLo7Tp0/j5uZGixYt+Omnn9i1axfdu3cvcvpbPfLq1avj6urKlStX2LNnD7169cLFxQXA2Iu7cOFCkfO7F2tra5599lkAmjRpwrVr14Cb709AQADW1tZotVoCAwPZs2ePcbo7X6db61SrVi1cXV154okngJvvy/Xr19FoNCxfvpykpCRiYmKIiopCURRu3Lhx3/qSk5MZP348CxYsoH79+ia/n7fWxc/Pj0qVKgEwbNgwEhISjPv3O3fujFarxdHRkTp16nD9+vUHffnEPVhcaDZr1ow///zTuDl1S2pqKqNGjSI3NxeDwYBGozE+ZzAYjJvHALa2tgDGNkoRY5b8+wBTrVq12LFjB6NGjSIrK4uXXnqJ77///rb2pbV8a2vru/59y+XLl+nXrx8XL17kmWeeITg4+L7rces/1p0uXryIra0t586de6D/VLdCMz4+nvbt29OuXTvi4+P5/vvv6dGjR5HTW1n9sytdo9GgKMpdvwgVRaGwsNDY5paCgoIilwEYQ/HWcm652/vz73ne+TrZ2NjcNs875eTk0L9/f5KSkmjSpAlTp07Fysrqvp+rK1euMHLkSCZPnkyrVq0A09/Pe63Lvz9rdnZ2xr/vfB1F8VhcaLq5udGnTx9CQ0ONwZmVlcXs2bOpUqUKdnZ2eHl5sXbtWhRFIT8/n40bN9KuXTuTluPi4mLc8f7VV18ZH1+3bh3Tp0/Hy8uL1157DS8vL44fP37btKWx/Adx7NgxXFxcGDt2LF5eXvzwww/AzV8CWFlZodfri/xPkpGRwWuvvUZUVBS9e/dmxowZRS63Ro0aVK1alfXr19O+fXu8vLz47rvvuHbtmrEX9m86ne62/8h306FDB7Zt22bcT7d582aqVKlCnTp1qFq1KseOHQNufjkeOHDApHnfbVmffvopBQUFGAwGPvnkE9q3b2/SPP4tOTmZrKwsgoOD6dKlC4mJieTn599ziyg7O5tRo0YxYMAA+vbta3y8OO9nhw4d2Lx5s7FH+vHHH9OyZcvbgl6ULosLTYBZs2bRoEEDAgIC8PX1ZeDAgTRo0IC5c+cCMHPmTNLT0+nTpw99+vShXr16jB492qRlzJw5kzlz5tC/f39Onz5N9erVAejXrx96vZ5evXrh5+dHZmYmgYGB/5m2pMt/EO3bt8fNzY0ePXrQs2dPLl26hIuLC8nJyVSvXp1mzZrx3HPPcfXq1fuuZ6dOnfDy8mLcuHGcP3/eeFDF19f3nkfgvb29SU9Pp0mTJtSqVQs7Ozvjgbg79ejRg8DAQE6dOnXfdRk+fDgvvvgizz33HFu2bGHFihXGzee0tDR8fHwIDQ2lTZs2xuk6duzI+vXrWbFixYO8ZACMGTMGV1dX+vXrR8+ePSksLHygL4t7ady4MZ06daJnz5707NmTH374gQYNGpCcnHzX9mvXruXkyZPs2LGDfv36GQ8IFef9HDBgAG3btmXgwIH07NmT48ePs2jRomKviyiaRpH+uriHxYsX07dvX+rXr692KUI8NCyypynMT1EUHnvsMQlMIe4gPU0hhDCB9DSFEMIEqpxGad98nBqLVc3ZHxcX3aiccLa3yDNzi+3fP/epCOxK+e01NQtu/BpTugUUg/Q0hRDCBBWrWyCEeLhoLK/fJqEphFCPBe7ekNAUQqhHeppCCGEC6WkKIYQJtKYP+K02CU0hhHpk81wIIUwgm+dCCGGCEvY09Xo9M2fO5MyZM+h0OiIjI8nMzGT06NHUrVsXgEGDBtGrVy9iYmLYvXs3VlZWhIaG0qxZM5KTk5k2bRoajYaGDRsya9Ys4xis9yKhKYRQTwl7mrfGHF2/fj2JiYlERkbSpUsXXnrpJYKCgoztkpKSOHDgALGxsVy6dInx48ezefNmIiMjCQ4OpnXr1oSHh7Nr1y68vb3vu0wJTSGEekrY0+zWrRudOnUCICUlBVdXV44dO8aZM2fYtWsXderUITQ0lIMHD+Ll5YVGo6FmzZro9XrS09NJSkoyjpzfsWNH9u7dK6EphHiIlcI+TSsrK0JCQtixYwfvvPMOqampDBw4EA8PD5YtW8a7776Lk5MTVapUMU7j4OBAZmYmiqIYxw+49VhRLO/QlRCi/NBoTbvdw/z58/n2228JCwvDy8sLDw8P4OYVBo4fP46joyPZ2dnG9tnZ2Tg5Od22/zI7OxtnZ+ciS5bQFEKop4SheeuyKAD29vZoNBrGjRvHkSNHANi/fz9PPfUULVq0ID4+HoPBQEpKCgaDARcXF5o0aUJiYiJw88qed7ti651k81wIoR5tyTbPu3fvzvTp0xkyZAiFhYWEhoby6KOPEhERgbW1Na6urkRERODo6Iinpyf+/v4YDAbCw8MBCAkJISwsjOjoaNzd3Y2Xzb4fVUZul/E0yy8ZT7N8K/XxNLu8YVL7G98X/wJ4paVifcKFEA8XC/zSkdAUQqhHTqMUQggTSE9TCCFMIKMcCSGECWTzXAghTCCb50IIYQLpaZY9rVbD0rDBNKr7CHqDwqhZa8nKzuXd8MFUda6ETqthRNjHOFWyZeFrA4zTtWpalxcmreTnY8kc2RLO8dOXANj6/WHe/XS3SmtjmoKCAib+byTnzyWTl5fHxNem06Jla6ZMGM31a9fQ6/UsWf4+dd3rA/D332n08X6WH/b/gp2dncrVF19BQQGjR44gOfkseXl5hEyfQa1atZk8cQI6nQ5bW1tWvf8hbm5uapda6gwGA6+OG8uRI4extbVl2YrV1G/QQO2yik96mmXvuY5NAejy0mI6PNOQ+ZP9uJaRw4ZtP7F5x6909GxI47pubI9Pwmfk2wD4dWvOpbTr7Nh3gs6tGxP77UEmzY9VczWKZdOGdVR1qUbMyjWkp1/Bu0Mr2nfshN/AQfj6DSQ+bje//36Suu71+WHnd7wxewZpaalql11in65bi0s1F95b8xFXrlyhbasW1K1bjzcXv8P/e/ppVq9aQfSi+cxfGK12qaVu6xdbyM3N5cf4/SQmJDBt6mRiP/tC7bKKT3qaZe/L3UfYtucYALVruvDXlUyebdmQY7+n8PXycSSnpDNlwSZj+0p2Nswc0wvvEW8B0OLJWjz9xON8t/pV0tIzmbxgE5f/zlBlXUzVt9/z9PH1M97X6az4KWE/TZ5qysC+PahVuw4R828Gh1arZePW7XR/to1a5ZYav+cH0t/vn60GKysrPlz7KY8++igAhYWF2Npabk/6fvbtjcfbpwcArdu04eDBn1WuqIQsMDQtr+K70OsNrJoTSPTUAXy+81fqPFqNqxk5PDc6hvOX05n80j/j4w3v35bPdvzKlWs3Rzw5eTaViOXb6P7y23z5wxGiQwaqtRomc3B0xNHJiazMTF4eFsC0sNmcP3eWylWqErt1O4/VqkXMWwsBeLZLN1xcqqlccelwdHTEycmJzMxMhgQMZNbsCGNgJuzfx4ql7zL+1YkqV2kemRkZVK5c2Xhfp9NRWFioYkUlpNGYdnsImCU0b50Q7+/vT2BgIMnJyeZYzG1Ghn9Ms35zWBo+mGtZOXz941EAtv14jBZNahvbBfRsyZrP9xnv7z5wih9/OgXAFz8c5v81ftzstZamixfO49fbmwH+Q/AbOIiqLtXw6dUbgO49nuPwr7+oXKF5XDh/nh7eXRg0ZCj+gwYDsGnjBib8bwyfffEV1atXV7lC83Bydr5tzEeDwYCVlQVvMJbS0HBlySxV7Ny5k/z8fDZs2MDkyZOJiooyx2IAGPRcS6YEdQcgJ7cAg8FA/ME/8PFqAoBXiwac+L+DPM6OdtjYWHEh9Zpx+mXhg+nf9WkAOrdqzK8nzput1tKW9lcqAf17Efb6PAYHDgegVZt27PruGwAS9sXT+IkmKlZoHqmpqfTp5cPceVG8OPzmJQ0+/WQty5e9y/adP1DP3V3lCs2nbbv2fPvNNgASExLw8GiqckUlZIE9TbOMchQZGUmzZs147rnnAOjQoQN79uwxPl+aoxxVsrNh5etDcXN1xtpKx6IPvuPIyQssDR+Cg70N17NuMHz6Gq5l3uCZJrUJedmHFyatMk5fp2Y1VswegkajIftGHmPnrCv1fZrmGuVoZsgkvvgslgaNGhsfe2fZe0weP5qcnGycnCuzbPVHVKla1fi8Z9OGxP901GxHz8tilKMpk15lU+xGGjd+Arh5ca3jSceoVbuOcXRurw4dCZv1utlrKetRjm4dPT969AiKorBy9Qc0fuKJMlt+qY9y1H+1Se1vfP5y6RZQDGYJzRkzZtC9e3eeffZZADp16sTOnTuNmxEyNFz5JUPDlW+lHpp+75nU/sZnI0q3gGIwyyf8zqHlLX6/ixDCLCzxS8cs+zRbtGhBXFwcAIcOHaJRo0bmWIwQwsJpNBqTbg8Ds3T/vL292bt3LwEBASiKwrx588yxGCGEhdOU8HIXajBLaGq1WubMmWOOWQshypGHpfdoCtnRKIRQjYSmEEKYQEJTCCFMYXmZKaEphFCP9DSFEMIEEppCCGECCU0hhDBBSUNTr9czc+ZMzpw5g06nIzIyEkVRmDZtGhqNhoYNGzJr1iy0Wi0xMTHs3r0bKysrQkNDadasGcnJyXdtez8Px1hLQoiKSWPi7Q4//PADAOvXr2fChAlERkYSGRlJcHAw69atQ1EUdu3aRVJSEgcOHCA2Npbo6Ghef/3mYC53a1sUCU0hhGpKehplt27diIiIACAlJQVXV1eSkpJo1aoVAB07dmTfvn0cPHgQLy8vNBoNNWvWRK/Xk56efte2RZHQFEKopjTOPbeysiIkJISIiAh8fHxQFMXY1sHBgczMTLKysnB0dDROc+vxu7UtiuzTFEKoprQOBM2fP58pU6bwwgsvkJeXZ3w8OzsbZ2fn/4y8lp2djZOT0237L2+1LYr0NIUQqtFoNSbd7rRlyxZWrFgBgL29PRqNBg8PDxITEwGIi4vD09OTFi1aEB8fj8FgICUlBYPBgIuLC02aNPlP26JIT1MIoZqS9jS7d+/O9OnTGTJkCIWFhYSGhlK/fn3CwsKIjo7G3d0dHx8fdDodnp6e+Pv7G69hBhASEvKftkXWbI6R24siI7eXXzJye/lW2iO3Pzpqs0ntL618vnQLKIaK9QkXQjxULPFLR0JTCKEey8tMCU0hhHqkpymEECaQ0BRCCBNIaAohhCksLzMlNIUQ6pGephBCmEBCUwghTCChKYQQJpDQfEBndkersVjVTPwiSe0Sysz7g55WuwRhSSwvM6WnKYRQT1GXlngYSWgKIVRjgVvnEppCCPXIPk0hhDCBBWamhKYQQj3S0xRCCBNYYGZKaAoh1KO9y3V/HnYSmkII1UhPUwghTCD7NIUQwgQWmJkSmkII9UhPUwghTCChKYQQJrDAzJTQFEKop6Q9zYKCAkJDQ7l48SL5+fmMGTOGGjVqMHr0aOrWrQvAoEGD6NWrFzExMezevRsrKytCQ0Np1qwZycnJTJs2DY1GQ8OGDZk1a1aRg4hIaAohVFPS32lu3bqVKlWqsHDhQq5evUr//v353//+x0svvURQUJCxXVJSEgcOHCA2NpZLly4xfvx4Nm/eTGRkJMHBwbRu3Zrw8HB27dqFt7f3fZcpoSmEUE1JN8979OiBj4+P8b5Op+PYsWOcOXOGXbt2UadOHUJDQzl48CBeXl5oNBpq1qyJXq8nPT2dpKQkWrVqBUDHjh3Zu3evhKYQ4uFV0s1zBwcHALKyspgwYQLBwcHk5+czcOBAPDw8WLZsGe+++y5OTk5UqVLltukyMzNRFMVYw63HimJ5I4AKIcoNjca0291cunSJYcOG4evrS58+ffD29sbDwwMAb29vjh8/jqOjI9nZ2cZpsrOzcXJyum3/ZXZ2Ns7OzkXWLKEphFCNRqMx6Xanv//+m6CgIF577TUGDBgAwIgRIzhy5AgA+/fv56mnnqJFixbEx8djMBhISUnBYDDg4uJCkyZNSExMBCAuLg5PT88ia5bNcyGEakq6T3P58uVkZGSwdOlSli5dCsC0adOYN28e1tbWuLq6EhERgaOjI56envj7+2MwGAgPDwcgJCSEsLAwoqOjcXd3v23/6D1rVhRFKVnZprt8Pb+sF6mqSVuPq11CmZELq5VvdqXczWo7P86k9vtDOpZuAcUgPU0hhGrkx+1CCGECOY1SZQUFBUz83yjOn08mPy+P4CnTeLTm44RMGoeNrS0eHs2ImB+NVqtlVuhrHEjYd/PvufNp1aad2uU/EJ0GRrWrTXUHG6x1Gj4/msrF67mMblcbRYEL13L54MAFFOCFp2vg8agTigIf/XSR01dyqGxnxf+86mCl1XDtRgHL950jX1/me2hKTUFBAa+8HERy8lny8vKYFjqT3n36ql2W2RgMBl4dN5YjRw5ja2vLshWrqd+ggdplFZsFZmb5Cs3NG9ZR1cWFmJUfkJ5+Be+OrXF1rc7c+dG0bN2WqLmz+Cx2PU828eDnAwls2xXPmT//YHRQIN/9mKB2+Q/Ey92FrDw9y/b+gaONjnm9G5OcfoONhy5zIjWLoNaP80ytyqRl5dPA1YHwb37H1cGGyZ3qMf3rk/T1cGPPn+ns+fMqzzerQddGrnxzIk3t1Sq2Tz9Zi0u1arz/4cdcuXKFNi2bl+vQ3PrFFnJzc/kxfj+JCQlMmzqZ2M++ULusYpOepsr69Hue3r5+xvtWOh2XUi7SsnVbAFq2bsu3276kc7fu2NtXIi8vj8zMTKytrdUq2WQJyddITL5mvG8wKNSrZs+J1CwADl/MoGlNJ9YcuEjUrtMAuDpYcz23AICPf76IBtAALg7WXLqUW9arUKr8Bgyk//MDjPetrMrVR/o/9u2Nx9unBwCt27Th4MGfVa6oZCwxNM32O83Dhw8TGBhortnflYOjI45OTmRlZjJy2CBCZr5Onbr12Bd/8wjdju1fk5Odg5XOCq1WQ4eWzfD37cmY8RPLtM6SyCs0kFtowM5Ky6vP1mXjoUto+OeDd6PQQCVrHQAG5eYm+mtd3Nl35p+g1Wpgfp8neMrNkVNp2f9ZhiVxdHTEycmJzMxMBvsPYNbrc9UuyawyMzKoXLmy8b5Op6OwsFDFikqmNH7cXtbMEpqrVq1i5syZ5OXlmWP293Xxwnme79OdAf6D8RsYwOJ3V7Jk8QKGvuCLq+sjuFSrRuz6tVR3q0HCoRMkHj7JoqgILqVcLPNai8ulkjUzuzcg/s+r7Dt7DcO/fjVmb6UlJ19vvL/x0GX+tymJ3k9V5xFHGwD0Ckz98jdWJ5xnTPs6ZV5/aTt//jw9unVm8JBAAgYNVrscs3Jydr7tVD+DwWDRvWutVmPS7WFgltCsXbs2S5YsMces7yvtr1QC/J5j5utvMChwOAA7v/2GxTErWbvxC66mX+HZzl2pXKUqDg4O6HQ6HJ2csLGxJTs7q8zrLQ5nOyumd6vPp7+k8OPpdACSr97gSTdHAP7fY8789lc2TWo4MrzVYwAU6A0UGkABXmr1OE3+r+2NQsNtgWuJUlNT6dOrO3Mj5/PiS0FFT2Dh2rZrz7ffbAMgMSEBD4+mKldUMiU9I0gNZvmK8vHx4cKFC+aY9X29/eZ8rl+7xuIFkSxeEAnA6HGvMnSgL/aVKtGuw7N07d4TvV7PT4n76NP9WfR6PX4DA2jQsHGZ11sc/TzccLDR0b9pDfr/3/+Xj36+wIstH8dKq+Hi9VwSz93cFG9TuwqzfBqg1WjYcTKNtKx8vv0tjaDWteiPgqLAB4ll/z6VpgVR87h29SqRb0QQ+UYEAF989Q329vYqV2Yevv368/3OHXTq0A5FUVi5+gO1SyqRhyQHTWK2M4IuXLjApEmT2Lhx43+ekzOCyi85I6h8K+0zgrxjTPvVyo5xbUq3gGKw3J0hQgiLZ4k9TQlNIYRqHpb9lKYwW2g+/vjjd900F0KIWx6SA+ImkZ6mEEI10tMUQggTWGBmSmgKIdTz77PZLIWEphBCNbJPUwghTFCu9mn6+/v/Z4VuXe5y/fr1Zi9MCFH+WWBm3js0o6Ojy7IOIUQFpLXA1LxnaD722M3BHlJTU1m4cCFXr17Fx8eHxo0bG58TQoiSeFhGLjJFkaMchYWF8fzzz5Ofn4+npydvvPFGWdQlhKgAyuV4mnl5ebRt2xaNRoO7uzu2trZlUZcQogLQajQm3R4GRR49t7GxYc+ePRgMBg4dOoSNjU1Z1CWEqAAejhg0TZGhGRERwfz587l69Srvv/8+s2fPLoOyhBAVQUl/clRQUEBoaCgXL14kPz+fMWPG0KBBA6ZNm4ZGo6Fhw4bMmjULrVZLTEwMu3fvxsrKitDQUJo1a0ZycvJd295PkaFZo0YNXnnlFc6ePUvDhg2pVatWiVZSCCFuKelxoK1bt1KlShXjwer+/fvzxBNPEBwcTOvWrQkPD2fXrl3UrFmTAwcOEBsby6VLlxg/fjybN28mMjLyP229vb3vu8wiQ3Pp0qXs2bOHpk2bsmbNGnr06MHw4cNLtqZCCEHJe5o9evTAx8fHeF+n05GUlESrVq0A6NixI3v37qVevXp4eXmh0WioWbMmer2e9PT0u7YtKjSLPBAUFxfHJ598QmhoKGvXrmXbtm0lWUchhDAq6dFzBwcHHB0dycrKYsKECQQHBxtPwrn1fGZmJllZWTg6Ot42XWZm5l3bFqXI0HRxceHGjRvAzf0HLi4uD/RiCCFEUUrjwmqXLl1i2LBh+Pr60qdPn9v2SWZnZ+Ps7IyjoyPZ2dm3Pe7k5HTXtkW5Z2j6+/sTEBDA77//jo+PDyNGjKBnz55cvnz5gV4MIYQoilZj2u1Of//9N0FBQbz22msMGDAAgCZNmpCYmAjc3FL29PSkRYsWxMfHYzAYSElJwWAw4OLicte2RZHTKIUDprsjAAAXv0lEQVQQqinpPs3ly5eTkZHB0qVLWbp0KQAzZsxg7ty5REdH4+7ujo+PDzqdDk9PT/z9/TEYDISHhwMQEhJCWFjYbW2LrLmoq1EmJyezfft2CgoKAPjrr7+YM2dOiVZUrkZZfsnVKMu30r4aZdD6oya1fz9A/eu8F7lPMyQkBIBffvmFCxcucO3aNbMXJYSoGCzxjKAiQ9POzo5XXnkFNzc3oqKi+Pvvv8uiLiFEBWCJ554X2dlWFIW0tDRycnLIycnh+vXrZVGXEKICKJejHI0bN44dO3bQt29funbtSseOHcuiLiFEBWCJm+dF9jRbtmxJy5YtAejatavZCxJCVBwPSQ6a5J6h6eXldc+J4uPjS7TQKg4Va6SkinRE2WPaN2qXUKaORfVUuwSLVq6uEVTSYBRCiKIUuX/wISRXoxRCqKZc9TSFEMLcLPDg+YP1jrOysjh58iQ5OTnmrkcIUYGU9NxzNRTZ09y+fTvLly9Hr9fTo0cPNBoNY8eOLYvahBDlnCVunhfZ01yzZg0bN26kSpUqjB07lp07d5ZFXUKICqBc9jS1Wi02NjbG8ezs7e3Loi4hRAVggR3NokPT09OTSZMmkZqaSnh4OE2bqj/KiBCifHhYzvIxRZGhOWnSJOLi4mjSpAn169enc+fOZVGXEKICsMTfaRZZ85YtW0hPT8fV1ZXr16+zZcuWsqhLCFEB6LQak24PgyJ7mqdPnwZujnZ04sQJqlSpQr9+/cxemBCi/LPArfOiQ3Py5MnGvxVF4ZVXXjFrQUKIiuMh6TyapMjQzM//59IUaWlpXLhwwawFCSEqjnJ5IOjWD9oVRcHOzo4RI0aURV1CiArAAjOz6NB89dVX8fX1LYtahBAVjCVunhd59Dw2NrYs6hBCVEAaE/89DB5on2a/fv2oV68eWu3NjH3zzTfNXpgQovyzxJ5mkaE5ZcqUsqhDCFEBlavQDA4O5q233qJVq1ZlWY8QogIpV6Mcpaenl2UdQogKqLRGOTp8+DCBgYEAJCUl0aFDBwIDAwkMDGTbtm0AxMTEMGDAAAICAjhy5AgAycnJDBo0iMGDBzNr1iwMBkORNd+zp3n+/Hmio6Pv+tykSZOKnLEQQhSlNDqaq1atYuvWrcYR2I4fP85LL71EUFCQsU1SUhIHDhwgNjaWS5cuMX78eDZv3kxkZCTBwcG0bt2a8PBwdu3ahbe3932Xd8/QtLOzo169eiVfIyGEuIfS+HF77dq1WbJkCVOnTgXg2LFjnDlzhl27dlGnTh1CQ0M5ePAgXl5eaDQaatasiV6vJz09naSkJOMuyI4dO7J3797ih6arqyv9+/cv8QoJIcS9lMaBIB8fn9vOVGzWrBkDBw7Ew8ODZcuW8e677+Lk5ESVKlWMbRwcHMjMzERRFON+1VuPFeWeoenh4VGS9XioGAwGXh03liNHDmNra8uyFaup36CB2mWVugOJicwMDeG7XbuNj702eSKNGjVm5Cuj1SusBKy0GqL8m/JYVXtsrLQs3XmaPs0fpbqzLQCPVbXnUPI1gj85TEjvxnjWrYpOp2FDwnk2JF7g0Sp2RL3QFCvdzd/5zdh0jDNp2SqvVfGVt8+yzgwHgry9vXF2djb+HRERQdeuXcnO/ud9z87OxsnJyfgzyluP3Zrufu55ICgkJKQkdT9Utn6xhdzcXH6M30/EG1FMmzq56IkszJuLFjD2lZfJzc0Fbo4T4Nu7J19/tVXlykrG95maXM0uYNDSREas/plZ/ZsQ/Mlhhiw7wJg1v5Bxo5A3tv5Gm/ou1KlWiYExCQTEJDKqszvO9lZM7NGQj/cmM2TZAZbtOs2UXo3UXqUSKW+fZY3GtNuDGDFihPFAz/79+3nqqado0aIF8fHxGAwGUlJSMBgMuLi40KRJExITEwGIi4vD09OzyPlXiEv47tsbj7dPDwBat2nDwYM/q1xR6XN3r8/62M8IGn7zCGJ2VhYzwmbz3bffqFxZyXxz+DLbj1w23i80KMa/X+1+MxDTMvO4fqOA4ykZACgoaLUaCvUKkVt/IzO3EACdTkNeQdFHRx9m5e2zbI7fac6ePZuIiAisra1xdXUlIiICR0dHPD098ff3x2AwEB4eDtzsHIaFhREdHY27uzs+Pj5Fzr9ChGZmRgaVK1c23tfpdBQWFmJlVX5Wv7/f8ySfPWu8X7dePerWq2fxoZmTrwfAwVZHzLDmLN5+CgAXRxvaNazGG1tPAJBfaCC/0ICVVsOCgGZsSDhPTr7eOH296g5M6/0EY9b8os6KlJLy9lkurVGOHn/8cTZu3AjAU089xfr16//TZvz48YwfP/62x+rVq8fatWtNWpYljjZvMidn59t28BoMBov9kFVEj1a2Y+3o1mw5mMKXv14CoGezGmz99RL/6njibG/F+yM9+SM1i+Xf/2l8vE19F5YNb8GUTw9b9P5MKH+fZXNsnptbhQjNtu3a8+03N3/gmpiQgIeHXBzOUlRztOGDUS1Z8PVJNv30zxHSdg2rEfdbmvG+rZWWj19pxaYDF3l352nj423quzDT90mCVv/EsQsZZVq7OZS3z7JWozHp9jCw3K8oE/j268/3O3fQqUM7FEVh5eoP1C5JPKAxXetT2d6acd71GeddH4CgVT/jXt2Bc1dyjO0Gt61NrWqV8G/zOP5tHgcgZMNRZvg+ibWVloUBzQD4869swjYnlf2KlJLy9ll+SHLQJBpFUZSim5Wu/9svL8ohj2mWvQ/VVMeieqpdQpmyK+Vu1pqfzpnUfnjL2qVbQDFUiJ6mEOLhZIkDdkhoCiFUY3mRKaEphFDRw3JwxxQSmkII1VheZEpoCiFUZIEdTQlNIYR65ECQEEKYwByjHJmbhKYQQjWWF5kSmkIIFcnmuRBCmMASB7+Q0BRCqEZ6mkIIYQLLi0wJTSGEiiywoymhKYRQj9YC+5oSmkII1UhPUwghTKCRnqYQQjw46WkKIYQJZJ+mEEKYQHqaQghhAkschNgSz2ISQpQTWo1pt3s5fPgwgYGBACQnJzNo0CAGDx7MrFmzMBgMAMTExDBgwAACAgI4cuTIfdvet+aSr7YQQhSPxsR/d7Nq1SpmzpxJXl4eAJGRkQQHB7Nu3ToURWHXrl0kJSVx4MABYmNjiY6O5vXXX79n26JIaAohVKPRmHa7m9q1a7NkyRLj/aSkJFq1agVAx44d2bdvHwcPHsTLywuNRkPNmjXR6/Wkp6fftW1RJDSFEKopjZ6mj48PVlb/HJ5RFMU4EIiDgwOZmZlkZWXh6OhobHPr8bu1LYocCBJCqOZ++ymLPU/tP33B7OxsnJ2dcXR0JDs7+7bHnZyc7tq2yPmXbrlCCPHgSqOneacmTZqQmJgIQFxcHJ6enrRo0YL4+HgMBgMpKSkYDAZcXFzu2rYo0tMUQqjGHL84CgkJISwsjOjoaNzd3fHx8UGn0+Hp6Ym/vz8Gg4Hw8PB7ti2yZkVRlNIv+/5yC8t6iaKseEz7Ru0SytSxqJ5ql1Cm7Eq5m7X396smtW/fsGrpFlAM0tMsA4X6on/7VV5UtBBJy8hTu4QyVcvFtlTnZ4k/bpfQFEKoxvIiU0JTCKEmC0xNCU0hhGpkPE0hhDCBBe7SlNAUQqhHQlMIIUwgm+dCCGEC6WkKIYQJLDAzJTSFECqywNSU0BRCqEb2aQohhAlkn6YQQpjAAjNTQlMIoSILTE0JTSGEamSfphBCmED2aQohhAksMDMlNIUQKrLA1JTQFEKoRvZpCiGECcxxCV9zk9AUQqhHQlMIIR6cbJ4LIYQJ5CdHQghhAgvMzIoRmgUFBbzychDJyWfJy8tjWuhMevfpq3ZZpW7Rgii2ff0lBfn5vDxqNM1bPMOUia+i0+mwtbVh5Xsf8oibm9plliqDwcCr48Zy5MhhbG1tWbZiNfUbNFC7rBLT6/WEBI/h9B+n0Ol0LFqykoVvzCbtr8sAXDiXTHPP1ry7+mPmzJzKTwn70Gq1zIyIomXrdipXbwILTM0KEZqffrIWl2rVeP/Dj7ly5QptWjYvd6G558fdJCbsY+cPe8jJyeGdxW/y6SdrWbT4bZr9v6d5f9UKot9cQNSCN9UutVRt/WILubm5/Bi/n8SEBKZNnUzsZ1+oXVaJ7dz+NQCff7Ob/fE/EjEzhPc+2QTAtWtXCfD1IXzuAo4fO8LBAwls3bGHs3+e5n8jA9n2/X41SzdJaezT7NevH05OTgA8/vjj+Pv788Ybb6DT6fDy8mLcuHEYDAZmz57NyZMnsbGxYe7cudSpU6dYy6sQoek3YCD9nx9gvG9lVf5We+eO73jqqaYMesGPzIxM5kbOJ+jlUdR49FEACvV67GztVK6y9O3bG4+3Tw8AWrdpw8GDP6tcUenwea4vXX16AXDh/DlcH3nE+Fx0VATDR47BrcajWFtbY29vT15eHpmZGVhbWatVcrGUdJ9mXl4eAB9//LHxMV9fX5YsWUKtWrUYNWoUSUlJXLx4kfz8fDZs2MChQ4eIiopi2bJlxVpm+UuPu3B0dAQgMzOTwf4DmPX6XJUrKn1XrvzN+XPJxH7+JWfPnsH/+X78cuQ4AAn797Fi2bts37lb3SLNIDMjg8qVKxvv63Q6CgsLy8UXo5WVFRPHjuDbr7eyfM06AP5O+4u9cT8w642FAOisrNBotXRp04yMjAzmv7VUzZJNVtJ+5m+//caNGzcICgqisLCQ8ePHk5+fT+3atQHw8vJi//79pKWl0aFDBwCefvppjh07VuxlWv4n6wGdP3+egAH9GTV6LAGDBqtdTqlzqVaNRo2fwMbGhkaNGmNnZ8ffaWnE/fgDC+dHsunzL6levbraZZY6J2dnMjMzjfcNBkO5CMxbFi99j79SL+PbvQO79h1i29bP6fe8PzqdDoDN6z/hkUfcWLvpK7KyMnm+Vxee8WxNjZqPqVz5AyphatrZ2TFixAgGDhzI2bNnGTlyJM7OzsbnHRwcOH/+PFlZWcbOE5Tsy1VbspItQ2pqKn16dWdu5HxefClI7XLMom279uz8bjuKonApJYWc7Gy++/YbVixbyrbvvqeeu7vaJZpF23bt+fabbQAkJiTg4dFU5YpKx+YNnxCzeAEA9vaV0Gq1aHU64n/8nk7dfIztKlepQiVHR3Q6HY6OTtjY2pKdna1W2SbTmPjvTvXq1aNv375oNBrq1auHk5MT165dMz6fnZ2Ns7Mzjo6Ot70uJflyrRChuSBqHteuXiXyjQi6d+1E966duHHjhtpllaqevXrT7OnmdPJqwwvP+/Lm20uYPnUyWVmZDPEfQE/vLrwxZ7baZZY63379sbOzo1OHdkydMpEFixarXVKp6Nm7H0lHDzOgd1cCB/Zh1huLsLOz4/Qfp6hdt56xXb8BAQD079Hp5m1AAPUbNlKrbJNpNKbd7rRp0yaioqKAm52jGzduUKlSJc6dO4eiKMTHx+Pp6UmLFi2Ii4sD4NChQzRqVPzXSKMoilLsqYspt7Csl6iuQr1B7RLKjJWuQnwPG6Vl5KldQpmq5WJbqvM7+3euSe3rut5+MDM/P5/p06eTkpKCRqNhypQpaLVa5s2bh16vx8vLi4kTJxqPnp86dQpFUZg3bx7169cvVs0SmmVAQrP8ktAsmbNXTAzNaur/AqT87DEXQlgcrQWeRymhKYRQjeVFpoSmEEJFFtjRlNAUQqjJ8lJTQlMIoRrpaQohhAksMDMlNIUQ6pGephBCmEAudyGEEKawvMyU0BRCqMcCM1NCUwihHtmnKYQQJpB9mkIIYQrLy0wJTSGEerQSmkII8eBk81wIIUxgiQeCKtaIsUIIUULS0xRCqMYSe5oSmkII1cg+TSGEMIH0NIUQwgQWmJkSmkIIFVlgakpoCiFUI/s0hRDCBLJPUwghTFDSzDQYDMyePZuTJ09iY2PD3LlzqVOnTqnUdi/y43YhhHo0Jt7usHPnTvLz89mwYQOTJ08mKirK7CVLT1MIoZqS7tM8ePAgHTp0AODpp5/m2LFjpVHWfakSmnYVLaqtpENfXtVysVW7BItmb12y6bOysnB0dDTe1+l0FBYWYmVlvpCR/81CCIvl6OhIdna28b7BYDBrYIKEphDCgrVo0YK4uDgADh06RKNGjcy+TI2iKIrZlyKEEGZw6+j5qVOnUBSFefPmUb9+fbMuU0JTCCFMUK43zxVF4eTJk5w6dUrtUoQQ5US5PY6tKApjxoyhatWqpKen89hjjxEeHq52WWb34Ycf8uKLL6pdhhDlVrntaW7cuJFq1aoRGRnJO++8w/Hjx3n99dfVLsussrOzWbduHdHR0WqXIkS5VW5Ds379+mg0GlJTU7G1teWjjz7i+PHj5TpQjh49iouLCxcvXiQ0NFTtcoQol8p1aNrb23P48GHS09OxsbHhnXfe4caNG2qXZjb16tVj8ODBREVFkZeXR1hYmNolCVHulNvQrFq1Kv7+/vz444/s2bOHixcv8ssvv3D69Gny8vLULs8s3Nzc6NatG9bW1syYMYPCwkImTZqkdllClCvl/idHZ86c4auvvuLUqVPk5uYydepUGjZsqHZZZSI9PZ233nqLcePG8cgjj6hdjhDlQrkPTYDCwkIyMjIAcHFxUbmasmUwGNBqy+0GhRBlrkKEphBClBbpggghhAkkNIUQwgQSmkIIYQIJTSGEMIGEphBCmEBCsxxJTEykbdu2BAYGEhgYyAsvvMDHH39crHktWrSIzz77jBMnThATE3PPdjt27CA1NfWB5hkXF8e0adP+U/PEiRPvOc1nn33GokWLHmj+prQVorjK7ShHFVWbNm1YvHgxAPn5+fTo0QNfX1+cnZ2LNb8nn3ySJ5988p7Pf/TRR8yePRs3N7dizV8ISyOhWY5lZWWh1WrR6XQEBgZStWpVMjIyWLlyJbNnzyY5ORmDwUBwcDCtW7fm22+/ZdmyZbi4uFBQUIC7uzuJiYmsX7+exYsXExsby6efforBYKBr1640bdqUEydOEBISwrp169iwYQNfffUVGo2GXr16MWzYME6fPk1oaCj29vbY29tTuXLle9a7du1avvvuOwoLC3FycmLJkiXAzcsYvPjii2RlZTF+/Hg6derEgQMHWLx4MTqdjlq1ajFnzpyyellFBSehWc4kJCQQGBiIRqPB2tqasLAwHBwcAOjTpw/e3t6sW7eOqlWrMm/ePK5evcrQoUP5+uuvWbhwIbGxsVSpUoVRo0bdNt8rV66watUqtm7dio2NDVFRUbRs2ZInn3yS2bNnc+7cObZt28a6devQaDQMHz4cLy8v3n77bSZMmED79u1ZuXIlf/75513rNhgMXLt2jTVr1qDVahkxYgRHjx4FwN7enpUrV5Kens7AgQPp0KEDYWFhrFu3jmrVqvHWW2/x+eefm/2CWkKAhGa58+/N8zvVq1cPgFOnTnHw4EGOHDkC3DzN9O+//8bR0ZGqVasC0Lx589umPX/+PA0bNsTOzg7gP0PPnTp1ipSUFIYPHw7A9evXOXfuHL///jvNmjUDbl4E616hqdVqsba2ZtKkSVSqVInLly9TWFgIwDPPPINGo6FatWo4OTlx9epV/vrrL4KDgwHIzc2lffv21K5d26TXSojikNCsQDQaDQDu7u7UqFGD0aNHk5uby7Jly3B2diYzM5P09HRcXFw4evQoNWrUME5bu3Zt/vzzT/Lz87GxsWHChAnMmDEDjUaDoii4u7vToEEDVq9ejUajYc2aNTRq1Ah3d3d+/fVXOnbsyLFjx+5Z22+//cbOnTuJjY3lxo0b+Pn5cesM31s9zrS0NHJycqhatSo1atRg6dKlODk5sWvXLipVqsSlS5fM+OoJcZOEZgUUEBDAzJkzGTp0KFlZWQwePBgbGxsiIyMZMWIElStX/s+mrouLCyNHjmTo0KFoNBo6d+6Mm5sbzZs3Z+rUqbz//vu0bduWQYMGkZ+fT7NmzXBzc2PWrFlMnDiR9957DxcXF2xtbe9aU506dbC3t8fPzw8bGxuqV6/OX3/9BdzsSQ4bNoycnBzmzJmDTqdjxowZjBo1CkVRcHBwYMGCBRKaokzIgB1CCGEC+Z2mEEKYQEJTCCFMIKEphBAmkNAUQggTSGgKIYQJJDSFEMIEEppCCGGC/w+yKg6RERQHowAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "cnf_matrix = confusion_matrix(Y_test, Y_pred)\n", + "np.set_printoptions(precision=2)\n", + "\n", + "sns.set_style(\"dark\")\n", + "plt.figure()\n", + "utils.plot_confusion_matrix(cnf_matrix, classes=[0,1],\n", + " title='Confusion matrix, without normalization')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Cascade of detection and classification\n", + "The labels that have to be used for assessment are saved in Y_test_true. The labels predicted by the detection_model are saved instead in Y_pred_d." + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(11505,) (11505,)\n" + ] + } + ], + "source": [ + "print(Y_test_true.shape, Y_pred_d.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "X_train, Y_train, X_test, Y_test, n_features, n_classes, class_weights = preprocessing.loadData(subject=subject,\n", + " label=label,\n", + " folder=folder,\n", + " window_size=window_size,\n", + " stride=stride,\n", + " make_binary=True,\n", + " null_class=True,\n", + " print_info=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "mask = (Y_pred_d == 1)\n", + "X_detected = X_test[mask, :, :]\n", + "Y_pred_c = classification_model_best.predict_classes(X_detected)\n", + "Y_pred_d[mask] = Y_pred_c" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 0.30 0.84 0.44 2039\n", + " 1 0.10 0.06 0.08 3958\n", + " 2 0.00 0.00 0.00 2333\n", + " 3 0.00 0.00 0.00 2733\n", + " 4 0.00 0.00 0.00 442\n", + "\n", + "avg / total 0.09 0.17 0.11 11505\n", + "\n", + "Weighted f1-score: 0.1054545307463817\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "C:\\Users\\Riccardo\\Anaconda3\\lib\\site-packages\\sklearn\\metrics\\classification.py:1135: UndefinedMetricWarning: Precision and F-score are ill-defined and being set to 0.0 in labels with no predicted samples.\n", + " 'precision', 'predicted', average, warn_for)\n", + "C:\\Users\\Riccardo\\Anaconda3\\lib\\site-packages\\sklearn\\metrics\\classification.py:1135: UndefinedMetricWarning: F-score is ill-defined and being set to 0.0 in labels with no predicted samples.\n", + " 'precision', 'predicted', average, warn_for)\n" + ] + } + ], + "source": [ + "print(classification_report(Y_test_true, Y_pred_d))\n", + "print(\"Weighted f1-score:\", f1_score(Y_test_true, Y_pred_d, average='weighted'))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "One-shot classification instead had:" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "X_train, Y_train, X_test, Y_test, n_features, n_classes, class_weights = preprocessing.loadData(subject=subject,\n", + " label=label,\n", + " folder=folder,\n", + " window_size=window_size,\n", + " stride=stride,\n", + " make_binary=False,\n", + " null_class=True,\n", + " print_info=False)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " precision recall f1-score support\n", + "\n", + " 0 0.88 0.88 0.88 2039\n", + " 1 0.92 0.91 0.92 3958\n", + " 2 0.83 0.83 0.83 2333\n", + " 3 0.97 1.00 0.98 2733\n", + " 4 0.91 0.84 0.88 442\n", + "\n", + "avg / total 0.91 0.91 0.91 11505\n", + "\n", + "Weighted f1-score: 0.9073576036830062\n" + ] + } + ], + "source": [ + "oneshot_model_best = load_model('./data/model_AOS_1.hdf5')\n", + "\n", + "Y_pred = oneshot_model_best.predict_classes(X_test)\n", + "print(classification_report(Y_test, Y_pred))\n", + "print(\"Weighted f1-score:\", f1_score(Y_test, Y_pred, average='weighted'))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# end" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.4" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/code/TaskB.ipynb b/code/HAR_system.ipynb similarity index 94% rename from code/TaskB.ipynb rename to code/HAR_system.ipynb index 6673ead..ee24e6c 100644 --- a/code/TaskB.ipynb +++ b/code/HAR_system.ipynb @@ -4,15 +4,26 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Task B - single subject - model ?\n", + "# HAR system - Lincetto Riccardo, Drago Matteo\n", + "This notebook runs:\n", + "- Classification with null class (One Shot classification);\n", + "- Binary classification for activity detection (Two Steps - detection);\n", + "- Classification without null class (Two Steps - classification);\n", + "- Cascade of the last to methods.\n", + "\n", + "The operations performed here are very similar to those execute in 'main.py', with the exception that here the program is executed for specified user and model.\n", "\n", "## Notebook setup\n", "This first cell contains the parameters that can be tuned for code execution:\n", "- subject: select the subject on which to test the model, between [1,4];\n", - "- label: index of feature column to be selected to perform activity detection, between [0,6]. The default value for task B is 6;\n", - "- folder: directory name where '.mat' files are stored;\n", + "- task: choose \"A\" for locomotion classification or \"B\" for gesture recognition;\n", + "- model_name: choose between \"Convolutional\", \"Convolutional1DRecurrent\", \"Convolutional2DRecurrent\" and \"ConvolutionalDeepRecurrent\";\n", + "- data_folder: directory name where '.mat' files are stored;\n", "- window_size: parameter that sets the length of temporal windows on which to perform the convolution;\n", - "- stride: step length to chose the next window." + "- stride: step length to chose the next window;\n", + "- GPU: boolean flag indicatin wheter GPU is present on the machine that executes the code;\n", + "- epochs: number of complete sweeps of the data signals during training;\n", + "- batch_size: number of forward propagations in the networks between consecutives backpropagations." ] }, { @@ -22,10 +33,21 @@ "outputs": [], "source": [ "subject = 1\n", - "label = 0\n", - "folder = \"../data/full/\"\n", + "task = \"A\"\n", + "model_name = \"Convolutional\"\n", + "data_folder = \"./data/full/\"\n", "window_size = 15\n", - "stride = 5" + "stride = 5\n", + "GPU = True\n", + "epochs = 10\n", + "batch_size = 32" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here the useful functions are imported." ] }, { @@ -47,21 +69,22 @@ "import preprocessing\n", "import models\n", "import utils\n", + "import os\n", "import numpy as np\n", - "import seaborn as sns\n", - "import matplotlib.pyplot as plt\n", "from sklearn.metrics import classification_report, f1_score, confusion_matrix\n", "from keras.models import load_model\n", "from keras.optimizers import Adam\n", - "from keras.callbacks import ModelCheckpoint\n", - "from keras.utils import to_categorical" + "from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau\n", + "from keras.utils import to_categorical\n", + "import seaborn as sns\n", + "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "In the following cell, we make use of some functions of Keras which have been removed, but of which the code is still available at https://github.com/keras-team/keras/commit/a56b1a55182acf061b1eb2e2c86b48193a0e88f7. These are used to evaulate the f1 score during training on batches of data: this is only an approximation though, which is the reason why they have been removed." + "Differently from 'main.py', all results saved from this notebook are going to be stored in a dedicated folder: './data/notebook/'. For proper execution, this folder needs first to be created." ] }, { @@ -70,50 +93,104 @@ "metadata": {}, "outputs": [], "source": [ - "import os\n", "if not(os.path.exists(\"./data\")):\n", - " os.mkdir(\"./data\")" + " os.mkdir(\"./data\")\n", + "if not(os.path.exists(\"./data/notebook\")):\n", + " os.mkdir(\"./data/notebook\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "# One-shot classification\n", - "Here classification is performed with null class.\n", - "### Preprocessing" + "If task A is selected, calssifications in the following notebook are based on the labels of column 0; if instead it's task B, column 6 labels are used." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Task A uses labels column 0\n" + ] + } + ], "source": [ - "X_train, Y_train, X_test, Y_test, n_features, n_classes, class_weights = preprocessing.loadData(subject=subject,\n", - " label=label,\n", - " folder=folder,\n", - " window_size=window_size,\n", - " stride=stride,\n", - " make_binary=False,\n", - " null_class=True,\n", - " print_info=False)" + "if task == \"A\":\n", + " label = 0\n", + "elif task == \"B\":\n", + " label = 6\n", + "else:\n", + " print(\"Error: invalid task.\")\n", + "print(\"Task\", task, \"uses labels column\", label)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Y_train and Y_test contain the correct labels for each signals window. Y_test in particular will be used to evaluate predictions for both this (one-shot) and the two-steps models. For this reason it is here saved with a different name, to avoid having it being overwritten later." + "## Classification with null class: One Shot classification\n", + "Here classification is performed considering inactivity as a class, alongside with the others. In the case of locomotion classification (task A), this becomes a 5-class problem, while in the case of gesture recognition (task B) the classes become 18. In the following cell are perfomed in order:\n", + "- preprocessing;\n", + "- model selection;\n", + "- model compilation;\n", + "- training.\n", + "\n", + "Note that in case \"Convolutional2DRecurrent\" is the model selected, then the preprocessed data need to be reshaped, adding one dimension; this is automatically done by the code." ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "metadata": {}, "outputs": [], "source": [ - "Y_test_true = Y_test" + "# preprocessing\n", + "X_train, Y_train, X_test, Y_test, n_features, n_classes, class_weights = preprocessing.loadData(subject=subject,\n", + " label=label,\n", + " folder=data_folder,\n", + " window_size=window_size,\n", + " stride=stride,\n", + " make_binary=False,\n", + " null_class=True,\n", + " print_info=False)\n", + "\n", + "# model selection\n", + "if model_name == \"Convolutional\":\n", + " model = models.Convolutional((window_size, n_features), n_classes, print_info=False)\n", + "elif model_name == \"Convolutional1DRecurrent\":\n", + " model = models.Convolutional1DRecurrent((window_size, n_features), n_classes, GPU=GPU, print_info=False)\n", + "elif model_name == \"Convolutional2DRecurrent\":\n", + " model = models.Convolutional2DRecurrent((window_size, n_features, 1), n_classes, GPU=GPU, print_info=False)\n", + " # reshaping for 2D convolutional model\n", + " X_train = X_train.reshape(X_train.shape[0], window_size, n_features, 1)\n", + " X_test = X_test.reshape(X_test.shape[0], window_size, n_features, 1)\n", + "elif model_name == \"ConvolutionalDeepRecurrent\":\n", + " model = models.ConvolutionalDeepRecurrent((window_size, n_features), n_classes, GPU=GPU, print_info=False)\n", + "else:\n", + " print(\"Model not found.\")\n", + " break\n", + "\n", + "# model compilation\n", + "model.compile(optimizer = Adam(lr=0.001), loss = \"categorical_crossentropy\", metrics = [\"accuracy\"])\n", + "save_model_name = task + \"_\" + model_name + \"_OS_\" + str(s)\n", + "filepath = './data/notebook/'+save_model_name+'.hdf5'\n", + "print(\"Model:\", save_model_name, \"\\nLocation:\", filepath, \"\\n\")\n", + "\n", + "# training\n", + "checkpointer = ModelCheckpoint(filepath=filepath, verbose=1, save_best_only=True)\n", + "lr_reducer = ReduceLROnPlateau(factor=0.1, patience=5, min_lr=0.00001, verbose=1)\n", + "model.fit(x = X_train, \n", + " y = to_categorical(Y_train), \n", + " epochs = epochs, \n", + " batch_size = batch_size,\n", + " verbose = 1,\n", + " validation_data=(X_test, to_categorical(Y_test)),\n", + " callbacks=[checkpointer, lr_reducer])" ] }, { diff --git a/code/__pycache__/models.cpython-36.pyc b/code/__pycache__/models.cpython-36.pyc index 8741df1..92cee87 100644 Binary files a/code/__pycache__/models.cpython-36.pyc and b/code/__pycache__/models.cpython-36.pyc differ diff --git a/code/__pycache__/preprocessing.cpython-36.pyc b/code/__pycache__/preprocessing.cpython-36.pyc index 066d3f2..33f05bd 100644 Binary files a/code/__pycache__/preprocessing.cpython-36.pyc and b/code/__pycache__/preprocessing.cpython-36.pyc differ diff --git a/code/__pycache__/utils.cpython-36.pyc b/code/__pycache__/utils.cpython-36.pyc index c6df091..c9b450a 100644 Binary files a/code/__pycache__/utils.cpython-36.pyc and b/code/__pycache__/utils.cpython-36.pyc differ diff --git a/code/main.py b/code/main.py index 7df11ad..2b4cbe4 100644 --- a/code/main.py +++ b/code/main.py @@ -13,7 +13,7 @@ subject = [1,2,3,4] task = "A" # choose between "A" or "B" model_names = ["Convolutional", "Convolutional1DRecurrent", "Convolutional2DRecurrent", "ConvolutionalDeepRecurrent"] -data_folder = "../data/full/" +data_folder = "./data/full/" window_size = 15 stride = 5 GPU = True diff --git a/code/main_multiuser.py b/code/main_multiuser.py index d799f58..a930f5f 100644 --- a/code/main_multiuser.py +++ b/code/main_multiuser.py @@ -13,7 +13,7 @@ subject = [23] task = "A" # choose between "A" or "B" model_names = ["Convolutional", "Convolutional1DRecurrent", "Convolutional2DRecurrent", "ConvolutionalDeepRecurrent"] -data_folder = "../data/full/" +data_folder = "./data/full/" window_size = 15 stride = 5 GPU = True