
Commit 5d87bad

Authored Nov 30, 2023
Add files via upload
0 parents, commit 5d87bad

2 files changed: +2, -0 lines changed
```python
from google.colab import drive
drive.mount('/content/drive')
```

Output:

```
Mounted at /content/drive
```

# Library

```python
import pandas as pd
import os
import numpy as np
import matplotlib.pyplot as plt
from scipy import signal
from scipy.stats import pearsonr
from scipy.signal import find_peaks
import seaborn as sns
%matplotlib inline

font_size = 16
# Path to the folder containing the data CSV files
path = '/content/drive/MyDrive/Gait_phase_analysis_DH803/healthy gait/Moin'
fs = 104                # sampling frequency (Hz)
cutoff_frequency = 2    # filter cutoff (Hz)
filter_types = ['butterworth']
# Path to the folder where the result CSV files are saved
result_path = '/content/drive/MyDrive/Gait_phase_analysis_DH803/healthy gait/Result'
```

# Functions

```python
def apply_lowpass_filters(data, sampling_frequency, cutoff_frequency, filter_types, filter_order=4):
    """
    Apply multiple lowpass filters (including moving average) to the input data
    and calculate correlation coefficients against the raw signal.

    Parameters:
    - data (array-like): The input data to filter.
    - sampling_frequency (float): The sampling frequency of the input data (in Hz).
    - cutoff_frequency (float): The cutoff frequency of the lowpass filters (in Hz).
    - filter_types (list of str): Filter types to use ('butterworth', 'chebyshev1',
      'chebyshev2', 'moving_average').
    - filter_order (int): The order of the filters (only applicable to certain filter types).

    Returns:
    - results (dict): A dictionary containing filtered data and correlation
      coefficients for each filter type.
    """
    nyquist_frequency = 0.5 * sampling_frequency
    results = {}

    for filter_type in filter_types:
        if filter_type == 'butterworth':
            b, a = signal.butter(filter_order, cutoff_frequency / nyquist_frequency, btype='low')
            filtered_data = signal.filtfilt(b, a, data)
        elif filter_type == 'chebyshev1':
            b, a = signal.cheby1(filter_order, 1, cutoff_frequency / nyquist_frequency, btype='low')
            filtered_data = signal.filtfilt(b, a, data)
        elif filter_type == 'chebyshev2':
            b, a = signal.cheby2(filter_order, 30, cutoff_frequency / nyquist_frequency, btype='low')
            filtered_data = signal.filtfilt(b, a, data)
        elif filter_type == 'moving_average':
            filter_window = int(sampling_frequency / cutoff_frequency)  # window size for the moving average
            filtered_data = np.convolve(data, np.ones(filter_window) / filter_window, mode='same')
        else:
            raise ValueError("Invalid filter_type. Supported types: 'butterworth', "
                             "'chebyshev1', 'chebyshev2', 'moving_average'")

        correlation_coefficient, _ = pearsonr(data, filtered_data)

        results[filter_type] = {
            'filtered_data': filtered_data,
            'correlation_coefficient': correlation_coefficient
        }

    return results
```
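A minimal usage sketch, not in the original notebook, that exercises `apply_lowpass_filters` on a synthetic noisy sine; `t` and `noisy` are illustrative names, and `fs` and `cutoff_frequency` come from the Library cell above.

```python
# Hypothetical smoke test: filter a noisy 1 Hz sine sampled at fs.
t = np.arange(0, 5, 1 / fs)
noisy = np.sin(2 * np.pi * 1.0 * t) + 0.3 * np.random.randn(len(t))

out = apply_lowpass_filters(noisy, fs, cutoff_frequency, ['butterworth', 'moving_average'])
for name, res in out.items():
    # Correlation against the raw signal, per filter type.
    print(name, round(res['correlation_coefficient'], 3))
```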
# Data Read

```python
csv_files = [file for file in os.listdir(path) if file.endswith('.csv')]
dataframes = []
# Loop through each CSV file and read it into a separate dataframe
for csv_file in csv_files:
    # Construct the full file path
    file_path = os.path.join(path, csv_file)

    # Read the CSV file into a dataframe
    data = pd.read_csv(file_path)

    HeelADCRaw = data['Heel ADC0']
    MidFootADCRaw = data['MidFoot ADC1']
    ToeADCRaw = data['Toe ADC2']

    GyroXRaw = data[' GyroX (deg/s)']
    GyroXRaw = GyroXRaw - np.min(GyroXRaw)
    GyroYRaw = data[' GyroY (deg/s)']
    GyroYRaw = GyroYRaw - np.min(GyroYRaw)
    GyroZRaw = data[' GyroZ (deg/s)']
    GyroZRaw = GyroZRaw - np.min(GyroZRaw)

    AccXRaw = data[' AccX (g)']
    AccYRaw = data[' AccY (g)']
    AccZRaw = data[' AccZ (g)']

    start, end = 0, len(GyroZRaw)
    GyroXRawSeg = GyroXRaw[start:end]
    GyroYRawSeg = GyroYRaw[start:end]
    GyroZRawSeg = GyroZRaw[start:end]

    # Low-pass filter every channel (filter order 2)
    results1 = apply_lowpass_filters(AccXRaw, fs, cutoff_frequency, filter_types, 2)
    for filter_type, result in results1.items():
        AccXFiltSeg = result['filtered_data']
    results1 = apply_lowpass_filters(AccYRaw, fs, cutoff_frequency, filter_types, 2)
    for filter_type, result in results1.items():
        AccYFiltSeg = result['filtered_data']
    results1 = apply_lowpass_filters(AccZRaw, fs, cutoff_frequency, filter_types, 2)
    for filter_type, result in results1.items():
        AccZFiltSeg = result['filtered_data']

    results1 = apply_lowpass_filters(GyroXRawSeg, fs, cutoff_frequency, filter_types, 2)
    for filter_type, result in results1.items():
        GyroXFiltSeg = result['filtered_data']
    results1 = apply_lowpass_filters(GyroYRawSeg, fs, cutoff_frequency, filter_types, 2)
    for filter_type, result in results1.items():
        GyroYFiltSeg = result['filtered_data']
    results1 = apply_lowpass_filters(GyroZRawSeg, fs, cutoff_frequency, filter_types, 2)
    for filter_type, result in results1.items():
        GyroZFiltSeg = result['filtered_data']

    ## Find the maximum of the given window
    max_window = max(GyroZFiltSeg)

    ## Find the lower threshold based on the maximum found
    thr = 0.32
    low_thr = int(max_window - (thr * max_window))

    ## Begin MSw finding based on max peak detection and localized peak detection
    ## Find the local peaks based on the lower threshold
    GyroZPksPos, _ = find_peaks(GyroZFiltSeg, height=low_thr, distance=fs)
    GyroZPksVal = GyroZFiltSeg[GyroZPksPos]

    ## Begin finding the TO (towards left of MSw) and IC (towards right of MSw)
    GyroZPksPosLeftMin, GyroZPksPosRightMin = [], []
    for i in range(len(GyroZPksPos)):
        if i == 0:
            low, high = GyroZPksPos[i], GyroZPksPos[i] + 50
            b = GyroZPksPos[i] + np.argmin(GyroZFiltSeg[low:high])
            GyroZPksPosRightMin.append(b)
        else:
            low, high = GyroZPksPos[i] - 50, GyroZPksPos[i]
            a = GyroZPksPos[i] - 50 + np.argmin(GyroZFiltSeg[low:high])
            GyroZPksPosLeftMin.append(a)

            if (GyroZPksPos[i] + 50) > len(GyroZFiltSeg):
                continue
            else:
                low, high = GyroZPksPos[i], GyroZPksPos[i] + 50
                b = GyroZPksPos[i] + np.argmin(GyroZFiltSeg[low:high])
                GyroZPksPosRightMin.append(b)

    ## Begin finding MSt (between IC and TO)
    # Use the shorter of the two lists so that indexing stays in bounds
    loop_len = min(len(GyroZPksPosLeftMin), len(GyroZPksPosRightMin))

    GyroZPksPosMSt = []
    for i in range(loop_len):
        low, high = GyroZPksPosRightMin[i], GyroZPksPosLeftMin[i]
        a = GyroZPksPosRightMin[i] + np.argmax(GyroZFiltSeg[low:high])
        GyroZPksPosMSt.append(a)

    # Build the per-sample label sequence from the detected events
    label = []
    for n in range(0, GyroZPksPos[0]):
        label.append(4)

    for idx in range(len(GyroZPksPos) - 1):
        for m in range(GyroZPksPos[idx], GyroZPksPosRightMin[idx]):
            label.append(3)
        for j in range(GyroZPksPosRightMin[idx], GyroZPksPosMSt[idx]):
            label.append(0)
        for k in range(GyroZPksPosMSt[idx], GyroZPksPosLeftMin[idx]):
            label.append(1)
        for i in range(GyroZPksPosLeftMin[idx], GyroZPksPos[idx + 1]):
            label.append(2)

    for o in range(GyroZPksPos[len(GyroZPksPos) - 1], len(GyroZFiltSeg)):
        label.append(4)

    fig, ax1 = plt.subplots(figsize=(20, 5))
    plt.title('Peak Finding [Major-Swing (MSw), Toe-Off (TO), Initial Contact (IC), Major-Stance (MSt)]', fontsize=font_size)
    # Plotting on the first y-axis
    ax1.plot(GyroZFiltSeg, label='Filtered Data', alpha=0.6, linewidth=2)
    ax1.plot(GyroZPksPos, GyroZFiltSeg[GyroZPksPos], "x", label='MSw')
    ax1.plot(GyroZPksPosLeftMin, GyroZFiltSeg[GyroZPksPosLeftMin], "x", label='TO')
    ax1.plot(GyroZPksPosRightMin, GyroZFiltSeg[GyroZPksPosRightMin], "x", label='IC')
    ax1.plot(GyroZPksPosMSt, GyroZFiltSeg[GyroZPksPosMSt], "x", label='MSt')
    ax1.set_xlabel('Samples', fontsize=font_size)
    ax1.set_ylabel('Amplitude', fontsize=font_size)
    ax1.legend(loc='upper left', fontsize=font_size)
    # Creating a second y-axis
    ax2 = ax1.twinx()
    ax2.plot(label, "--", label='Label', color='orange')
    ax2.set_ylabel('Label', fontsize=font_size)
    # Manually setting tick labels
    tick_labels = ['Heel Strike', 'Mid-Stance', 'Toe-Off', 'Mid Swing', 'N/A']
    ax2.set_yticks([0, 1, 2, 3, 4])
    ax2.set_yticklabels(tick_labels, fontsize=font_size)
    ax2.legend(loc='upper right', fontsize=font_size)
    plt.grid(True)
    plt.tight_layout()
    plt.show()

    # Create a dataframe with the specified columns
    df = pd.DataFrame({
        'AccX (g)': AccXFiltSeg,
        'AccY (g)': AccYFiltSeg,
        'AccZ (g)': AccZFiltSeg,
        'GyroX (deg/s)': GyroXFiltSeg,
        'GyroY (deg/s)': GyroYFiltSeg,
        'GyroZ (deg/s)': GyroZFiltSeg,
        'Heel ADC0': HeelADCRaw,
        'MidFoot ADC1': MidFootADCRaw,
        'Toe ADC2': ToeADCRaw,
        'LABEL': label
    })
    df['LABEL'].value_counts().plot(kind='bar')  # Plot class frequency

    dataframes.append(df)
```

Output hidden; open in https://colab.research.google.com to view the per-recording figures.
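For readability, a small illustrative mapping, not in the original notebook, of the integer codes written into `label`; the phase names are taken from the `tick_labels` used in the plot above.

```python
# Illustrative mapping of the integer labels written into the LABEL column.
LABEL_NAMES = {
    0: 'Heel Strike',  # IC  -> MSt
    1: 'Mid-Stance',   # MSt -> TO
    2: 'Toe-Off',      # TO  -> next MSw
    3: 'Mid Swing',    # MSw -> IC
    4: 'N/A',          # before the first / after the last detected MSw
}
print([LABEL_NAMES[v] for v in [4, 3, 0, 1, 2]])
```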
# Save Labelled data into CSV file

```python
# # This code will store the result in the same columns.
# # Ensure the result directory exists
# os.makedirs(result_path, exist_ok=True)
#
# # Concatenate all dataframes in the list
# result_df = pd.concat(dataframes, ignore_index=True)
#
# # Save the concatenated dataframe to a CSV file
# result_file_path = os.path.join(result_path, 'result.csv')
# result_df.to_csv(result_file_path, index=False)
#
# print(f"Result CSV file saved at: {result_file_path}")
```

```python
# # This code will store the result in different columns.
# # Set the path to save the result CSV file
# result_file = 'result.csv'
#
# # Create the result directory if it doesn't exist
# if not os.path.exists(result_path):
#     os.makedirs(result_path)
#
# # 'dataframes' is the list of dataframes from the previous code
# # Create a dictionary to store dataframes as columns
# df_dict = {'df{}'.format(i): df for i, df in enumerate(dataframes)}
#
# # Create a single dataframe with columns as dataframes
# result_df = pd.concat(df_dict, axis=1)
#
# # Save the result dataframe to a CSV file
# result_df.to_csv(os.path.join(result_path, result_file), index=False)
```

```python
# This code will store the result in different CSV files.

# Create the result directory if it doesn't exist
if not os.path.exists(result_path):
    os.makedirs(result_path)
else:
    # Delete all files in the result_path directory
    files_to_delete = os.listdir(result_path)
    for file_to_delete in files_to_delete:
        file_path = os.path.join(result_path, file_to_delete)
        os.remove(file_path)

# 'dataframes' is the list of dataframes from the previous code
# Save each dataframe to a separate CSV file
for i, df in enumerate(dataframes):
    csv_file = f'df_{i}.csv'
    csv_path = os.path.join(result_path, csv_file)
    df.to_csv(csv_path, index=False)
```

# Machine Learning Section

```python
import tensorflow as tf
print(tf.version.VERSION)
```

Output:

```
2.12.0
```
## Loading the dataset

```python
new_csv_files = [file for file in os.listdir(result_path) if file.endswith('.csv')]
print(new_csv_files)

recordings = []
for csv_file in new_csv_files:
    file_path = os.path.join(result_path, csv_file)
    df_data = pd.read_csv(file_path)
    recordings.append(df_data.iloc[:].values)

recordings = np.array(recordings).reshape(len(recordings), -1, 10)
print("recordings shape:", recordings.shape)
```

Output:

```
['df_0.csv', 'df_1.csv', 'df_2.csv', 'df_3.csv', 'df_4.csv', 'df_5.csv', 'df_6.csv', 'df_7.csv', 'df_8.csv', 'df_9.csv', 'df_10.csv', 'df_11.csv', 'df_12.csv', 'df_13.csv', 'df_14.csv']
recordings shape: (15, 5000, 10)
```

## Frame data

```python
def frame(x, frame_len, hop_len):
    ''' Slice a 2D data array into (overlapping) frames. '''

    assert x.shape == (len(x), 10)
    assert x.shape[0] >= frame_len
    assert hop_len >= 1

    n_frames = 1 + (x.shape[0] - frame_len) // hop_len
    shape = (n_frames, frame_len, x.shape[1])
    strides_x = (hop_len * x.strides[0],) + x.strides

    return np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides_x)

all_frames = []
for i in range(recordings.shape[0]):
    # frames = frame(recordings[i], 26, 26)  # no overlap
    frames = frame(recordings[i], 24, 12)    # 50% overlap
    all_frames.append(frames)

print(np.array(all_frames).shape)
all_frames = np.concatenate(all_frames)
print(all_frames.shape)
```

Output:

```
(15, 415, 24, 10)
(6225, 24, 10)
```
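A toy sanity check, not in the original notebook, showing what `frame` produces for a small array with 50% overlap; the shapes here are illustrative.

```python
# Toy check: 100 samples x 10 channels, frame length 24, hop 12 (50% overlap).
toy = np.arange(100 * 10).reshape(100, 10)
f = frame(toy, 24, 12)
print(f.shape)                            # (7, 24, 10): 1 + (100 - 24) // 12 = 7 frames
print(np.array_equal(f[1], toy[12:36]))   # True: frame 1 starts at sample 12
```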
```python
# # Split the last dimension of 'all_frames' into 'x_frames' and 'y_frames'
# x_frames = all_frames[:, :, :9]
# y_frames = all_frames[:, :, 9:]
#
# # Print the shapes of the resulting arrays
# print("x_frames shape:", x_frames.shape)
# print("y_frames shape:", y_frames.shape)

# Split the 'recordings' array into inputs and labels
x_frames = recordings[:, :, :-1]  # first 9 columns: sensor channels
y_frames = recordings[:, :, -1]   # last column: LABEL

# Flatten to per-sample arrays: 15 recordings x 5000 samples = 75000 rows
x_reshaped = np.reshape(x_frames, (-1, 9))
y_reshaped = np.reshape(y_frames, (-1, 1))
print("x_recording shape:", x_frames.shape)
print("y_recording shape:", y_frames.shape)
```

Output:

```
x_recording shape: (15, 5000, 9)
y_recording shape: (15, 5000)
```

## Preprocessing the dataset

```python
# Normalize x_frames between [-1, 1]
min_val = np.min(x_reshaped)
max_val = np.max(x_reshaped)
x_frames_normed = -1 + 2 * (x_reshaped - min_val) / (max_val - min_val)

print("Normalized x_frame shape:", x_frames_normed.shape)
```

Output:

```
Normalized x_frame shape: (75000, 9)
```

## Preparing the dataset (train, test, split)

```python
from sklearn.model_selection import train_test_split

x_train, x_test, y_train, y_test = train_test_split(x_frames_normed, y_reshaped, test_size=0.25, shuffle=True)

print("X-Training samples:", x_train.shape)
print("X-Testing samples:", x_test.shape)
print("Y-Training samples:", y_train.shape)
print("Y-Testing samples:", y_test.shape)

# Filter the 'N/A' class (label 4) out of the training set
indices_to_remove = [id for id in range(y_train.shape[0]) if y_train[id] == 4]
x_train_filtered = np.delete(x_train, indices_to_remove, axis=0)
y_train_filtered = np.delete(y_train, indices_to_remove)

# Filter the 'N/A' class (label 4) out of the test set
indices_to_remove = [id for id in range(y_test.shape[0]) if y_test[id] == 4]
x_test_filtered = np.delete(x_test, indices_to_remove, axis=0)
y_test_filtered = np.delete(y_test, indices_to_remove)
print("X-Training samples (filtered):", x_train_filtered.shape)
print("X-Testing samples (filtered):", x_test_filtered.shape)
print("Y-Training samples (filtered):", y_train_filtered.shape)
print("Y-Testing samples (filtered):", y_test_filtered.shape)
```

Output:

```
X-Training samples: (56250, 9)
X-Testing samples: (18750, 9)
Y-Training samples: (56250, 1)
Y-Testing samples: (18750, 1)
X-Training samples (filtered): (52972, 9)
X-Testing samples (filtered): (17648, 9)
Y-Training samples (filtered): (52972,)
Y-Testing samples (filtered): (17648,)
```
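Not part of the original notebook: a quick NumPy check, assuming the arrays defined in the cell above, that label 4 is present before filtering and absent after.

```python
# Sanity check: class 4 should appear before filtering but not after.
print(np.unique(y_train, return_counts=True))           # includes label 4
print(np.unique(y_train_filtered, return_counts=True))  # labels 0-3 only
```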
## Creating the model

```python
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.losses import SparseCategoricalCrossentropy

def make_model(input_shape):
    # Reshape input to add a time dimension
    input_layer = tf.keras.layers.Input(input_shape)
    reshape = tf.keras.layers.Reshape((input_shape[0], 1))(input_layer)

    conv1 = tf.keras.layers.Conv1D(filters=64, kernel_size=3, padding="same")(reshape)
    conv1 = tf.keras.layers.BatchNormalization()(conv1)
    conv1 = tf.keras.layers.ReLU()(conv1)

    conv2 = tf.keras.layers.Conv1D(filters=64, kernel_size=3, padding="same")(conv1)
    conv2 = tf.keras.layers.BatchNormalization()(conv2)
    conv2 = tf.keras.layers.ReLU()(conv2)

    # Apply Batch Normalization after the last convolutional layer
    conv2 = tf.keras.layers.BatchNormalization()(conv2)

    gap = tf.keras.layers.GlobalAveragePooling1D()(conv2)
    num_classes = 4
    output_layer = tf.keras.layers.Dense(num_classes, activation="softmax")(gap)

    return tf.keras.models.Model(inputs=input_layer, outputs=output_layer)

# x_train_filtered, y_train_filtered, x_test_filtered, y_test_filtered come from the previous cell
input_shape = x_train_filtered.shape[1:]

# Adjust learning rate
optimizer = Adam(learning_rate=0.001)

model = make_model(input_shape)

# Compile the model with the optimizer
model.compile(optimizer=optimizer, loss=SparseCategoricalCrossentropy(), metrics=['accuracy'])

# Reshape data to have a third dimension
x_train_reshaped = x_train_filtered[:, :, tf.newaxis]
x_test_reshaped = x_test_filtered[:, :, tf.newaxis]

# Custom data augmentation function for time series
def time_series_augmentation(x):
    # Apply a random horizontal flip to each time series independently
    flipped_x = tf.image.random_flip_left_right(x)
    return flipped_x

# Create a data augmentation layer using the custom function
data_augmentation = tf.keras.layers.Lambda(time_series_augmentation)

# Combine data augmentation with the model
augmented_model = tf.keras.Sequential([
    data_augmentation,
    model
])

# Compile the augmented model before training
augmented_model.compile(optimizer=optimizer, loss=SparseCategoricalCrossentropy(), metrics=['accuracy'])

# Train the model with augmented data
history = augmented_model.fit(x_train_reshaped, y_train_filtered, epochs=30, batch_size=128, validation_data=(x_test_reshaped, y_test_filtered))

# Evaluate the model on the original test set
test_loss, test_accuracy = model.evaluate(x_test_reshaped, y_test_filtered)

print(f'Test Loss: {test_loss:.4f}')
print(f'Test Accuracy: {test_accuracy:.4f}')
```

Output (progress bars and per-step timings omitted):

```
Epoch  1/30 - loss: 1.0811 - accuracy: 0.5062 - val_loss: 5.9558 - val_accuracy: 0.2301
Epoch  2/30 - loss: 1.0022 - accuracy: 0.5582 - val_loss: 1.5939 - val_accuracy: 0.4618
Epoch  3/30 - loss: 0.9392 - accuracy: 0.5934 - val_loss: 2.4014 - val_accuracy: 0.2454
Epoch  4/30 - loss: 0.9022 - accuracy: 0.6116 - val_loss: 1.6009 - val_accuracy: 0.4005
Epoch  5/30 - loss: 0.8597 - accuracy: 0.6316 - val_loss: 1.2356 - val_accuracy: 0.4950
Epoch  6/30 - loss: 0.8160 - accuracy: 0.6555 - val_loss: 2.4190 - val_accuracy: 0.2916
Epoch  7/30 - loss: 0.7896 - accuracy: 0.6685 - val_loss: 1.0599 - val_accuracy: 0.5615
Epoch  8/30 - loss: 0.7724 - accuracy: 0.6762 - val_loss: 1.1161 - val_accuracy: 0.4854
Epoch  9/30 - loss: 0.7573 - accuracy: 0.6840 - val_loss: 1.2476 - val_accuracy: 0.4683
Epoch 10/30 - loss: 0.7296 - accuracy: 0.6955 - val_loss: 5.3159 - val_accuracy: 0.2368
Epoch 11/30 - loss: 0.7268 - accuracy: 0.6935 - val_loss: 1.3398 - val_accuracy: 0.4539
Epoch 12/30 - loss: 0.7026 - accuracy: 0.7075 - val_loss: 1.9300 - val_accuracy: 0.4803
Epoch 13/30 - loss: 0.6823 - accuracy: 0.7176 - val_loss: 8.5063 - val_accuracy: 0.2493
Epoch 14/30 - loss: 0.6591 - accuracy: 0.7292 - val_loss: 4.8412 - val_accuracy: 0.2189
Epoch 15/30 - loss: 0.6223 - accuracy: 0.7479 - val_loss: 2.1135 - val_accuracy: 0.2985
Epoch 16/30 - loss: 0.6044 - accuracy: 0.7581 - val_loss: 5.6274 - val_accuracy: 0.2208
Epoch 17/30 - loss: 0.5508 - accuracy: 0.7813 - val_loss: 59.5307 - val_accuracy: 0.2070
Epoch 18/30 - loss: 0.5374 - accuracy: 0.7857 - val_loss: 26.7358 - val_accuracy: 0.2070
Epoch 19/30 - loss: 0.5058 - accuracy: 0.7956 - val_loss: 6.5008 - val_accuracy: 0.2359
Epoch 20/30 - loss: 0.5324 - accuracy: 0.7876 - val_loss: 58.4683 - val_accuracy: 0.2070
Epoch 21/30 - loss: 0.5013 - accuracy: 0.7981 - val_loss: 52.4306 - val_accuracy: 0.2070
Epoch 22/30 - loss: 0.4670 - accuracy: 0.8152 - val_loss: 59.7063 - val_accuracy: 0.2070
Epoch 23/30 - loss: 0.5411 - accuracy: 0.7779 - val_loss: 2.7812 - val_accuracy: 0.4062
Epoch 24/30 - loss: 0.5211 - accuracy: 0.7923 - val_loss: 3.8692 - val_accuracy: 0.3328
Epoch 25/30 - loss: 0.4109 - accuracy: 0.8340 - val_loss: 15.7089 - val_accuracy: 0.2094
Epoch 26/30 - loss: 0.3990 - accuracy: 0.8417 - val_loss: 77.2156 - val_accuracy: 0.2070
Epoch 27/30 - loss: 0.3934 - accuracy: 0.8437 - val_loss: 5.3983 - val_accuracy: 0.2607
Epoch 28/30 - loss: 0.3956 - accuracy: 0.8427 - val_loss: 3.0488 - val_accuracy: 0.3485
Epoch 29/30 - loss: 0.3522 - accuracy: 0.8592 - val_loss: 154.8326 - val_accuracy: 0.2070
Epoch 30/30 - loss: 0.3655 - accuracy: 0.8543 - val_loss: 7.1014 - val_accuracy: 0.3247
552/552 - loss: 7.9734 - accuracy: 0.2774
Test Loss: 7.9734
Test Accuracy: 0.2774
```

```python
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)

print("Test loss:", test_loss)
print("Test acc:", test_acc)
model.summary()
```

Output (error, traceback abridged):

```
ValueError: Input 0 of layer "model" is incompatible with the layer:
expected shape=(None, 24, 9), found shape=(None, 9)
```
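The recorded ValueError appears to come from a stale `model` in that Colab session, one built earlier for framed `(24, 9)` windows, while the rows of `x_test` are flat 9-vectors; in addition, `y_test` still contains the removed class 4. A hedged sketch of the evaluation call that matches the per-sample model defined above (an assumption, not the author's recorded fix):

```python
# Evaluate with the channel-expanded inputs and the class-filtered labels,
# i.e. the same arrays used as validation data during training above.
test_loss, test_acc = model.evaluate(x_test_reshaped, y_test_filtered, verbose=2)

print("Test loss:", test_loss)
print("Test acc:", test_acc)
model.summary()
```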
```python
# Class names for the four phases kept after filtering
# (taken from the tick labels in the plotting cell above; 'labels' was
# previously undefined in this cell)
labels = ['Heel Strike', 'Mid-Stance', 'Toe-Off', 'Mid Swing']

Y_pred = model.predict(x_test_reshaped)  # filtered, channel-expanded test set
y_pred = np.argmax(Y_pred, axis=1)
confusion_matrix = tf.math.confusion_matrix(y_test_filtered, y_pred)

plt.figure()
sns.heatmap(confusion_matrix,
            annot=True,
            xticklabels=labels,
            yticklabels=labels,
            cmap=plt.cm.Blues,
            fmt='d', cbar=False)
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
```

```python
# Save the model into an HDF5 file 'model.h5'
model.save('model.h5')
```
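Not in the original notebook: reloading the saved file later is a one-liner with the standard Keras API.

```python
import tensorflow as tf

# Load the model back from the HDF5 file written by model.save('model.h5').
reloaded = tf.keras.models.load_model('model.h5')
reloaded.summary()
```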

Gait_phase_analysis.ipynb

+1 line; large diffs are not rendered by default.
