diff --git a/.github/ISSUE_TEMPLATE/bug-report---.md b/.github/ISSUE_TEMPLATE/bug-report---.md new file mode 100644 index 00000000..1aadaf7f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report---.md @@ -0,0 +1,27 @@ +--- +name: "Bug report \U0001F41E" +about: Create a bug report +labels: bug + +--- + +## Describe the bug + +A clear and concise description of what the bug is. + +### Steps to reproduce + +Steps to reproduce the behavior. + +### Expected behavior + +A clear and concise description of what you expected to happen. + +### Environment + +- OS: [e.g. Arch Linux] +- Other details that you think may affect. + +### Additional context + +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/feature-request---.md b/.github/ISSUE_TEMPLATE/feature-request---.md new file mode 100644 index 00000000..355cf62b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request---.md @@ -0,0 +1,18 @@ +--- +name: "Feature request \U0001F680" +about: Suggest an idea +labels: enhancement + +--- + +## Summary + +Brief explanation of the feature. + +### Basic example + +Include a basic example or links here. + +### Motivation + +Why are we doing this? What use cases does it support? What is the expected outcome? diff --git a/.github/workflows/run_units_test.yml b/.github/workflows/run_units_test.yml index 034dd036..1404fdfa 100644 --- a/.github/workflows/run_units_test.yml +++ b/.github/workflows/run_units_test.yml @@ -2,6 +2,7 @@ on: push: branches: - main + - dev jobs: units-test: @@ -19,14 +20,18 @@ jobs: - name: Set up Python uses: actions/setup-python@v2 with: - python-version: 3.10 + python-version: 3.11.2 + + - name: Activate Python virtual environment run: | python -m venv venv source venv/bin/activate + + - name: Upgrade pip and install requirements + run: | python -m pip install --upgrade pip python -m pip install -r requirements.txt - name: Set up Pytest run: | - cd signa2text pytest diff --git a/.gitignore b/.gitignore index 7b0f4fb0..4ce27fed 100644 --- a/.gitignore +++ b/.gitignore @@ -90,10 +90,6 @@ target/ # pytest cache .pytest_cache/ -#mics -.gitpod.yml -poetry.lock - # Data and models data/*/* models/* @@ -120,7 +116,11 @@ yb2audio/data/*/* # Development Enviroment dev.py dev_env.txt +**/development/ # Keys set_environment_variables.sh -model_artifacts \ No newline at end of file + +#miscellaneous +.gitpod.yml +poetry.lock \ No newline at end of file diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index e69de29b..00000000 diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..fa59e118 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2024 Ladipo Ezekiel Ipadeola + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile index fc07bc95..191e07ed 100644 --- a/Makefile +++ b/Makefile @@ -6,14 +6,8 @@ help: @echo " precommit runs precommit on all files" setup: - @echo "Installing..." - curl -sSL https://install.python-poetry.org | python - - @echo "Activating virtual environment" - poetry shell - poetry install - poetry add pre-commit - python pre-commit install - @echo "Environment setup complete" + @echo "Running setup..." + . ./run_setup.sh precommit: @echo "Running precommit on all files" @@ -25,4 +19,4 @@ export_: run_container: @echo "Running Docker Contain" - + run_container diff --git a/README.md b/README.md index 589c9dab..be5c60a8 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,63 @@ # NSL-2-AUDIO -[![LICENSE](https://img.shields.io/badge/license-MIT-green?style=flat-square)](LICENSE) -[![Python](https://img.shields.io/badge/python-3.10-blue.svg?style=flat-square)](https://www.python.org/) -[![PyTorch](https://img.shields.io/badge/PyTorch-2.7.0-orange)](https://pytorch.org/) - -![image/gif]() - -## Project description + + +[![Contributors][contributors-shield]][contributors-url] +[![Forks][forks-shield]][forks-url] +[![Stargazers][stars-shield]][stars-url] +[![Issues][issues-shield]][issues-url] +[![MIT License][license-shield]][license-url] + + +
+
+ + Logo + + +

NSL-2-AUDIO

+ +

+ NSL-2-AUDIO is an open-source Automatic Sign Language Translation system, specifically designed to translate Nigerian Sign Language (NSL) into one of the Low-Resource Languages (LRLs) spoken in Nigeria." +
+ Explore the docs » +
+
+ View Demo + · + Report Bug + · + Request Feature +

+
+ + +
+ Table of Contents +
    +
  1. + About The Project + +
  2. +
  3. + Getting Started + +
  4. +
  5. Usage
  6. +
  7. Roadmap
  8. +
  9. Contributing
  10. +
  11. License
  12. +
  13. Contact
  14. +
  15. Acknowledgments
  16. +
+
+ + +## About The Project ***Overview:*** \ This project is dedicated to the development of an Automatic Sign Language Translation system, with a specific focus on translating Nigerian Sign Language (NSL) into one of the Low-Resource Languages (LRLs) spoken in Nigeria. The primary objective is to address the communication challenges faced by the NSL community and contribute to inclusivity and employment opportunities in education and the workforce. @@ -28,53 +79,82 @@ Effective communication is a cornerstone of societal cohesion, and this project You can read the project proposal here, [Project Proposal](https://github.com/AISaturdaysLagos/Cohort8-Ransome-Kuti-Ladipo/blob/main/project-proposal.pdf) -## PROJECT TIMELINE +[![DEMO][product-screenshot]](https://example.com) +![sign_lang_gif](images/sign_lang.gif) + +

(back to top)

+ +### Built With -![PROJECT TIMELINE](https://github.com/AISaturdaysLagos/Cohort8-Ransom-Kuti-Ladipo/blob/main/images/Project%20Timeline.png) +- [![Python][Python]][Python-url] +- [![Pytorch][Pytorch]][Pytorch-url] +- [![HuggingFace][HuggingFace]][HuggingFace-url] -### SYSTEM DESIGN -![PLANNED SYSTEM DESIGN]() +

(back to top)

-## Setup + +## Getting Started -## Configuration +The project is structured into three distinct parts, each housed in separate directories as outlined in the project proposal. The initial phase involves translating sign language into English text, followed by the second phase, which focuses on translating the English text into Yoruba text. The final segment entails taking the translated Yoruba text and converting it into generated Yoruba audio. + +The `signa2text` directory is dedicated to the process of translating sign language into English text. Meanwhile, the `linguify_yb` directory serves the purpose of transforming English text into Yoruba text. Finally, the `yb2audio` directory is designated for utilizing the translated audio to generate Yoruba audio. + +To access any of the three directories, adhere to the specified prerequisites below and navigate into the respective directory. + +### Prerequisites ```bash # Clone this repository -$ git clone +$ git clone https://github.com/rileydrizzy/NSL_2_AUDIO # Go into the repository -$ cd +$ cd NSL_2_AUDIO # Install dependencies -$ . ./run_setup.sh - +$ make setup ``` -### Project Roadmap +

(back to top)

-Here's a glimpse of the exciting features we plan to implement in the coming weeks: + +## Usage -| Feature | Description | Status | -| ------------------------- | ---------------------------------------------------------- | ----------- | -| SignText Model | Implement the training of the SignText model | In Progress | -| Deployement of the System| Develop and Deploy the system to Google Cloud. | Planned | -| User Interface | Developing a friendly and functionaly User Interface| Planned | -| Static Transformer | Implementing SOTA model for the translation of Sign to Text| Planned | +

(back to top)

-## How to Contribute + +## Roadmap -We welcome contributions from the community. If you're interested in contributing, please refer to the [Contributing Guidelines](CONTRIBUTING.md). +- [ ] Feature 1 -## Acknowledgments +See the [open issues](https://github.com/rileydrizzy/NSL_2_AUDIO/issues) for a full list of proposed features (and known issues). -I would like to acknowledge the outstanding contributions of : +

(back to top)

-**Name:** Afonja Tejumade ***(```Mentor```)*** -**Email:** -**GitHub:** [@tejuafonja](https://github.com/tejuafonja) + +## Contributing + +Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are **greatly appreciated**. + +If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement". +Don't forget to give the project a star! Thanks again! + +1. Fork the Project +2. Create your Feature Branch (`git checkout -b feature/AmazingFeature`) +3. Commit your Changes (`git commit -m 'Add some AmazingFeature'`) +4. Push to the Branch (`git push origin feature/AmazingFeature`) +5. Open a Pull Request +

(back to top)

+ + +## License + +Distributed under the MIT License. See `LICENSE` for more information. + +

(back to top)

+ + ## Support and Contact If you have questions or need assistance, feel free to reach out to: @@ -84,4 +164,39 @@ If you have questions or need assistance, feel free to reach out to: **GitHub:** [@rileydrizzy](https://github.com/rileydrizzy) **Linkdeln:** [Ipadeola Ladipo](https://www.linkedin.com/in/ladipo-ipadeola/) +

(back to top)

+ + +## Acknowledgments + +I would like to acknowledge the outstanding contributions of : + +**Name:** Afonja Tejumade ***(```Mentor```)*** +**Email:** +**GitHub:** [@tejuafonja](https://github.com/tejuafonja) + +

(back to top)

+ --- + + + +[contributors-shield]: https://img.shields.io/github/contributors/rileydrizzy/NSL_2_AUDIO.svg?style=for-the-badge +[contributors-url]: https://github.com/rileydrizzy/NSL_2_AUDIO/graphs/contributors +[forks-shield]: https://img.shields.io/github/forks/rileydrizzy/NSL_2_AUDIO.svg?style=for-the-badge +[forks-url]: https://github.com/rileydrizzy/NSL_2_AUDIO/network/members +[stars-shield]: https://img.shields.io/github/stars/rileydrizzy/NSL_2_AUDIO.svg?style=for-the-badge +[stars-url]: https://github.com/rileydrizzy/NSL_2_AUDIO/stargazers +[issues-shield]: https://img.shields.io/github/issues/rileydrizzy/NSL_2_AUDIO.svg?style=for-the-badge +[issues-url]: https://github.com/rileydrizzy/NSL_2_AUDIO/issues +[license-shield]: https://img.shields.io/github/license/rileydrizzy/NSL_2_AUDIO.svg?style=for-the-badge +[license-url]: https://github.com/rileydrizzy/NSL_2_AUDIO/blob/master/LICENSE.txt +[product-screenshot]: images/screenshot.png +[Python-url]: +[Python]: +[Pytorch-url]: +[Pytorch]: +[HuggingFace-url]: +[HuggingFace]: +[GCP-url]: +[GCP]: <> diff --git a/app.py b/app.py deleted file mode 100644 index e69de29b..00000000 diff --git a/images/Project Timeline.png b/images/Project Timeline.png deleted file mode 100644 index 7a0fdbda..00000000 Binary files a/images/Project Timeline.png and /dev/null differ diff --git a/images/logo.png b/images/logo.png new file mode 100644 index 00000000..e7e886e7 Binary files /dev/null and b/images/logo.png differ diff --git a/images/sign lang.gif b/images/sign_lang.gif similarity index 100% rename from images/sign lang.gif rename to images/sign_lang.gif diff --git a/inference.py b/inference.py deleted file mode 100644 index c0383c3a..00000000 --- a/inference.py +++ /dev/null @@ -1,171 +0,0 @@ -"""doc -""" -import os -import cv2 - -import numpy as np -import pandas as pd -import mediapipe as mp -import torch -from IPython.display import Audio -from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, VitsModel -from linguify_yb.src.models import baseline_transfomer - -nllb_model_name = "facebook/nllb-200-distilled-600M" -mms_model_name = "facebook/mms-tts-yor" -youruba_lang = "yor_Latn" - - -# NLLB Model -nllb_tokenizer = AutoTokenizer.from_pretrained(nllb_model_name) -nllb_model = AutoModelForSeq2SeqLM.from_pretrained(nllb_model_name) - -# MMS Model -mms_model = VitsModel.from_pretrained(mms_model_name) -mms_tokenizer = AutoTokenizer.from_pretrained(mms_model_name) - - -def NLLB_infer(eng_text): - inputs = nllb_tokenizer(eng_text, return_tensors="pt") - translate_token = nllb_model.generate( - **inputs, - forced_bos_token_id=nllb_tokenizer.lang_code_to_id[youruba_lang], - max_length=50 - ) - outputs = nllb_tokenizer.batch_decode(translate_token, skip_special_tokens=True)[0] - return outputs - - -def MMS_model_infer(text): - inputs = mms_tokenizer(text, return_tensors="pt") - with torch.no_grad(): - output = mms_model(**inputs).waveform - return Audio(output, rate=mms_model.config.sampling_rate) - - -# TODO Debug -def extract_landmarks(path, start_frame=0): - mp_holistic = mp.solutions.holistic - # Initialize variables - frame_number = 0 - frame = [] - type_ = [] - index = [] - x = [] - y = [] - z = [] - - # Open the video file - cap = cv2.VideoCapture(path) - - # Get the total number of frames in the video - end_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) - - # Get the frames per second (FPS) of the video - # fps = cap.get(cv2.CAP_PROP_FPS) - # cap.set(cv2.CAP_PROP_FPS, fps) - - # Initialize holistic model for landmark detection - with mp_holistic.Holistic( - min_detection_confidence=0.5, min_tracking_confidence=0.5 - ) as holistic: - while cap.isOpened(): - success, image = cap.read() - - # Break if video is finished - if not success: - break - - # Increment frame number - frame_number += 1 - - # Skip frames until the start_frame is reached - if frame_number < start_frame: - continue - - # Break if end_frame is reached - if end_frames != -1 and frame_number > end_frames: - break - - # Prepare image for landmark detection - image.flags.writeable = False - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - results = holistic.process(image) - - # Process face landmarks - if results.face_landmarks is None: - for i in range(478): - frame.append(frame_number) - type_.append("face") - index.append(i) - x.append(0) - y.append(0) - z.append(0) - else: - for ind, val in enumerate(results.face_landmarks.landmark): - frame.append(frame_number) - type_.append("face") - index.append(ind) - x.append(val.x) - y.append(val.y) - z.append(val.z) - - # Process pose landmarks - if results.pose_landmarks is None: - for i in range(32): - frame.append(frame_number) - type_.append("pose") - index.append(i) - x.append(0) - y.append(0) - z.append(0) - else: - for ind, val in enumerate(results.pose_landmarks.landmark): - frame.append(frame_number) - type_.append("pose") - index.append(ind) - x.append(val.x) - y.append(val.y) - z.append(val.z) - - # Process left hand landmarks - if results.left_hand_landmarks is None: - for i in range(20): - frame.append(frame_number) - type_.append("left_hand") - index.append(i) - x.append(0) - y.append(0) - z.append(0) - else: - for ind, val in enumerate(results.left_hand_landmarks.landmark): - frame.append(frame_number) - type_.append("left_hand") - index.append(ind) - x.append(val.x) - y.append(val.y) - z.append(val.z) - - # Process right hand landmarks - if results.right_hand_landmarks is None: - for i in range(20): - frame.append(frame_number) - type_.append("right_hand") - index.append(i) - x.append(0) - y.append(0) - z.append(0) - else: - for ind, val in enumerate(results.right_hand_landmarks.landmark): - frame.append(frame_number) - type_.append("right_hand") - index.append(ind) - x.append(val.x) - y.append(val.y) - z.append(val.z) - # TODO rearrange dataframe to account for just the frames in sequential manner - # TODO consider to use numpy instead of a dataframe - # Create a DataFrame from the collected data - return pd.DataFrame( - {"frame": frame, "type": type_, "landmark_index": index, "x": x, "y": y, "z": z} - ) diff --git a/linguify_yb/README.md b/linguify_yb/README.md index d427164d..e69de29b 100644 --- a/linguify_yb/README.md +++ b/linguify_yb/README.md @@ -1,11 +0,0 @@ -# Signa-Text - -[![LICENSE](https://img.shields.io/badge/license-MIT-green?style=flat-square)](LICENSE) -[![Python](https://img.shields.io/badge/python-3.6-blue.svg?style=flat-square)](https://www.python.org/) -[![PyTorch](https://img.shields.io/badge/PyTorch-1.7.0-orange)](https://pytorch.org/) - -![image/gif]() - -## Project description - -***Overview:*** \ diff --git a/run_setup.sh b/run_setup.sh index 2744c0b9..111fcab1 100644 --- a/run_setup.sh +++ b/run_setup.sh @@ -3,5 +3,4 @@ curl -sSL https://install.python-poetry.org | python - echo "Activating virtual environment" poetry install poetry shell -python pre-commit install echo "Environment setup complete" diff --git a/set_environment_variables_template.sh b/set_environment_variables_template.sh index 488d881a..30087585 100644 --- a/set_environment_variables_template.sh +++ b/set_environment_variables_template.sh @@ -1,9 +1,9 @@ +#!/bin/bash # replace placeholders and rename this file to `set_environment_variables.sh` + #replace with Kaggle username and key -#!/bin/bash export KAGGLE_USERNAME=username export KAGGLE_KEY=xxxxxxxxxxxxxx + #replace with WANDB key export WANDB_API_KEY=xxxxxxxxxxxxxx -# -export GOOGLE= diff --git a/signa2text/README.md b/signa2text/README.md index 35049fc2..858fd819 100644 --- a/signa2text/README.md +++ b/signa2text/README.md @@ -1,30 +1,68 @@ # Signa2Text [![LICENSE](https://img.shields.io/badge/license-MIT-green?style=flat-square)](LICENSE) -[![Python](https://img.shields.io/badge/python-3.6-blue.svg?style=flat-square)](https://www.python.org/) -[![PyTorch](https://img.shields.io/badge/PyTorch-1.7.0-orange)](https://pytorch.org/) +[![Python](https://img.shields.io/badge/python-3.11.2-blue.svg?style=flat-square)](https://www.python.org/) +[![PyTorch](https://img.shields.io/badge/PyTorch-2.1.0-orange)](https://pytorch.org/) -![image/gif](https://github.com/rileydrizzy/Cohort8-Ransom-Kuti-Ladipo/blob/main/images/sign%20lang.gif) + +
+ Table of Contents +
    +
  1. Project description
  2. +
  3. + Getting Started + +
  4. +
  5. Contact
  6. +
  7. Acknowledgments
  8. +
+
## Project description -***Overview:*** +h +## Getting Started -## Project Roadmap +h +### Prerequisites -## How to Contribute +```bash +# Clone the main repository +$ git clone https://github.com/rileydrizzy/NSL_2_AUDIO -We welcome contributions from the community. If you're interested in contributing, please refer to the [Contributing Guidelines](CONTRIBUTING.md). +# Go into the main repository +$ cd NSL_2_AUDIO -## Support and Contact +# Install dependencies +$ make setup -If you have questions or need assistance, feel free to reach out to: +# To go into this directory +$ cd signa2Text +``` + +### Data + +```bash +# Tun to download the dataset +$ python src/download_dev_data.py +``` + +## Contact **Name:** **Ipadeola Ezekiel Ladipo** **Email:** **GitHub:** [@rileydrizzy](https://github.com/rileydrizzy) **Linkdeln:** [Ipadeola Ladipo](https://www.linkedin.com/in/ladipo-ipadeola/) ---- +## Acknowledgments + +I would like to acknowledge the outstanding contributions of : + +**Name:** Afonja Tejumade ***(```Mentor```)*** +**Email:** +**GitHub:** [@tejuafonja](https://github.com/tejuafonja) diff --git a/signa2text/src/__init__.py b/signa2text/src/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/signa2text/src/benchmark.py b/signa2text/src/benchmark.py deleted file mode 100644 index 02a70d52..00000000 --- a/signa2text/src/benchmark.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -Module for benchmarking a PyTorch model. - -This module provides a `BenchMarker` class for analyzing model metrics such as -Multiply-Accumulates(MACs), sparsity, the number of parameters, and model size. - -Classes: -- BenchMarker: A class for benchmarking a PyTorch model. - -Functions: -- get_model_macs: Calculate the MACs (Multiply-Accumulates) of a model. -- get_model_sparsity: Calculate the sparsity of a model. -- get_num_parameters: Calculate the total number of parameters of a model. -- get_model_size: Calculate the size of a model in bits. - - -""" -from torchprofile import profile_macs -from torch import nn - -Byte = 8 -KiB = 1024 * Byte -MiB = 1024 * KiB -GiB = 1024 * MiB - - -class BenchMarker: - """ - Benchmarking class to analyze model metrics such as MACs, - sparsity, number of parameters, and model size. - """ - - def __init__(self) -> None: - pass - - def get_model_macs(self, model, inputs=None) -> int: - """ - Calculate the Multiply-Accumulates (MACs) of a model. - - Parameters: - - model: The PyTorch model. - - inputs: The input tensor to the model. - - Returns: - - int: The number of MACs. - """ - return profile_macs(model, inputs) - - def get_model_sparsity(self, model: nn.Module) -> float: - """ - Calculate the sparsity of the given model. - - Sparsity is defined as 1 - (number of non-zeros / total number of elements). - - Parameters: - - model: The PyTorch model. - - Returns: - - float: The sparsity of the model. - """ - num_nonzeros, num_elements = 0, 0 - for param in model.parameters(): - num_nonzeros += param.count_nonzero() - num_elements += param.numel() - return 1 - float(num_nonzeros) / num_elements - - def get_num_parameters(self, model: nn.Module, count_nonzero_only=False) -> int: - """ - Calculate the total number of parameters of the model. - - Parameters: - - model: The PyTorch model. - - count_nonzero_only: If True, count only nonzero weights. - - Returns: - - int: The total number of parameters. - """ - num_counted_elements = 0 - for param in model.parameters(): - if count_nonzero_only: - num_counted_elements += param.count_nonzero() - else: - num_counted_elements += param.numel() - return num_counted_elements - - def get_model_size( - self, model: nn.Module, data_width=32, count_nonzero_only=False - ) -> int: - """ - Calculate the model size in bits. - - Parameters: - - model: The PyTorch model. - - data_width: Number of bits per element. - - count_nonzero_only: If True, count only nonzero weights. - - Returns: - - int: The model size in bits. - """ - return self.get_num_parameters(model, count_nonzero_only) * data_width - - def runner(self, model): - """ - Run the benchmark on the given model. - - Parameters: - - model: The PyTorch model. - - Returns: - - tuple: A tuple containing the model metrics - """ - model_macs = self.get_model_macs(model) - model_sparsity = self.get_model_sparsity(model) - model_num_params = self.get_num_parameters(model) - model_size = self.get_model_size(model) - - return model_macs, model_sparsity, model_num_params, model_size diff --git a/signa2text/src/dataset/dataset_loader.py b/signa2text/src/dataset/dataset_loader.py index 69577216..c53db1d6 100644 --- a/signa2text/src/dataset/dataset_loader.py +++ b/signa2text/src/dataset/dataset_loader.py @@ -1,15 +1,23 @@ """ -Module to define datasets and dataloaders for ASL Fingerspelling project. +ASL Fingerspelling Dataset Module + +This module defines classes and functions for handling datasets and dataloaders for the +ASL Fingerspelling project. Classes: -- TokenHashTable: A class for handling token-to-index and index-to-token mappings. -- LandmarkDataset: A dataset class for ASL Fingerspelling frames,\ - including methods for processing and cleaning frames. +- TokenHashTable: Handles token-to-index and index-to-token mappings. + +- LandmarkDataset: Dataset class for ASL Fingerspelling frames.\ + Includes methods for processing and cleaning frames and phrase. Functions: -- read_file: Read data from file based on file_id_list and landmarks_metadata_path. -- get_dataset: Create a dataset with token-to-index mapping. -- prepare_dataloader: Prepare a dataloader with distributed sampling. +- read_file(file_id_list, landmarks_metadata_path): Reads data from a file based on file IDs\ + and landmarks metadata path. + +- get_dataset(file_path): Creates a dataset with a token-to-index mapping. + +- prepare_dataloader(dataset, batch_size, num_workers_= 1): Prepares a dataloader with\ + distributed sampling. """ @@ -21,7 +29,7 @@ from torch.utils.data import DataLoader, Dataset from torch.utils.data.distributed import DistributedSampler from dataset.frames_config import FEATURE_COLUMNS -from dataset.preprocess import clean_frames_process +from dataset.preprocess import preprocess_frames # File paths for metadata and phrase-to-index mapping PHRASE_PATH = "/kaggle/input/asl-fingerspelling/character_to_prediction_index.json" @@ -49,53 +57,72 @@ class TokenHashTable: + """ + TokenHashTable handles token-to-index and index-to-token mappings for sequence data. + + This class is designed to facilitate the conversion between sequences of tokens + and their corresponding indices, providing methods for transforming sentences + to tensors and vice versa. + """ + def __init__( self, word2index_mapping=character_to_num, index2word_mapping=num_to_character ): - """ - Initialize a TokenHashTable to handle token-to-index and index-to-token mapping. - - Parameters: - word2index_mapping (dict): Mapping from word to index. - index2word_mapping (dict): Mapping from index to word. + """Initialize a TokenHashTable. + + Parameters + ---------- + word2index_mapping : dict, optional + Mapping from word to index, by default character_to_num. + index2word_mapping : dict, optional + Mapping from index to word, by default num_to_character. """ self.word2index = word2index_mapping self.index2word = index2word_mapping - def _indexesfromsentence(self, sentence): - """ - Convert a sentence into a list of corresponding indices. + def _indexes_from_sentence(self, sentence): + """Convert a sentence into a list of corresponding indices. - Parameters: - sentence (list): List of words in a sentence. + Parameters + ---------- + sentence : list + List of words in a sentence. - Returns: - list: List of indices corresponding to words in the sentence. + Returns + ------- + list + List of indices corresponding to words in the sentence. """ return [self.word2index[word] for word in sentence] - def tensorfromsentence(self, sentence): - """ - Convert a sentence into a tensor of indices. + def tensor_from_sentence(self, sentence): + """Convert a sentence into a tensor of indices. - Parameters: - sentence (list): List of words in a sentence. + Parameters + ---------- + sentence : list + List of words in a sentence. - Returns: - torch.Tensor: Tensor of indices. + Returns + ------- + torch.Tensor + Tensor of indices. """ - indexes = self._indexesfromsentence(sentence) + indexes = self._indexes_from_sentence(sentence) return torch.tensor(indexes, dtype=torch.long) def indexes_to_sentence(self, indexes_list): - """ - Convert a list of indices into a list of corresponding words. + """Convert a list of indices into a list of corresponding words. - Parameters: - indexes_list (list or torch.Tensor): List or tensor of indices. + Parameters + ---------- + indexes_list : list or torch.Tensor + List or tensor of indices. - Returns: - list: List of words corresponding to the indices. + Returns + ------- + list + List of words corresponding to the indices. """ if torch.is_tensor(indexes_list): indexes_list = indexes_list.tolist() @@ -105,80 +132,121 @@ def indexes_to_sentence(self, indexes_list): def read_file(file_id_list, landmarks_metadata_path): """ - Read data from file based on file_id_list and landmarks_metadata_path. - - Parameters: - file_id_list (list): List of tuples containing file paths and corresponding file_ids. - landmarks_metadata_path (str): Path to the metadata file. - - Returns: - tuple: A tuple containing lists of frames and phrases. + Read data from files based on file IDs and landmarks metadata. + + Parameters + ---------- + file_id_list : list + List of tuples containing file paths and corresponding file IDs. + landmarks_metadata_path : str + Path to the metadata file. + + Returns + ------- + tuple + A tuple containing lists of frames and phrases. """ phrase_list = [] frames_list = [] - for file, file_id in file_id_list: + + for file_path, file_id in file_id_list: metadata_train_dataframe = pd.read_csv(landmarks_metadata_path) file_id_df = metadata_train_dataframe.loc[ metadata_train_dataframe["file_id"] == file_id ] + saved_parquet_df = pq.read_table( - file, columns=["sequence_id"] + FEATURE_COLUMNS + file_path, columns=["sequence_id"] + FEATURE_COLUMNS ).to_pandas() + for seq_id, phrase in zip(file_id_df.sequence_id, file_id_df.phrase): frames = saved_parquet_df[saved_parquet_df.index == seq_id].to_numpy() - # Handle NaN values frames_list.append(torch.tensor(frames)) phrase_list.append(phrase) + return frames_list, phrase_list class LandmarkDataset(Dataset): - def __init__(self, file_path, table, transform=True): + """ + LandmarkDataset represents a dataset of landmarks for sequence processing tasks. + """ + + def __init__(self, file_path, token_table, transform_=True): """ Initialize a LandmarkDataset. - Parameters: - - file_path (str, pathr): _description_ - - table (object): _description_ - - transform (bool, optional): _description_, by default True + Parameters + ---------- + file_path : str or path-like + Path to the dataset file. + token_table : object + An object representing a token table for phrase preprocessing. + transform_ : bool, optional + Indicates whether to apply transformations, by default True. """ self.landmarks_metadata_path = METADATA self.frames, self.labels = read_file(file_path, self.landmarks_metadata_path) - self.trans = transform - self.table = table + self.transform = transform_ + self.token_lookup_table = token_table - def _label_pre(self, label_sample): + def _phrase_preprocess(self, phrase_): """ - Preprocess label samples. + Tokenizes the input phrase. - Parameters: - - label_sample (_type_): _description_ + Parameters + ---------- + phrase_ : str + The original phrase - Returns: - - _type_: _description_ + Returns + ------- + List[int] + A list containing ints representing strings in the tokenized phrase. """ - sample = START_TOKEN + label_sample + END_TOKEN - new_phrase = self.table.tensorfromsentence(list(sample)) - ans = F.pad( - input=new_phrase, - pad=[0, 64 - new_phrase.shape[0]], + phrase = START_TOKEN + phrase_ + END_TOKEN + tokenize_phrase = self.token_lookup_table.tensor_from_sentence(list(phrase)) + tokenzie_phrase = F.pad( + input=tokenize_phrase, + pad=[0, 64 - tokenize_phrase.shape[0]], mode="constant", value=PAD_TOKEN_IDX, ) - return ans + return tokenzie_phrase def __len__(self): + """ + Returns the length of the dataset. + + Returns + ------- + int + Length of the dataset. + """ return len(self.labels) def __getitem__(self, idx): + """ + Returns a tuple containing frames and corresponding preprocessed phrase for a given index. + + Parameters + ---------- + idx : int or slice + Index or slice to retrieve from the dataset. + + Returns + ------- + tuple + A tuple containing frames and preprocessed labels. + """ if torch.is_tensor(idx): idx = idx.tolist() phrase = self.labels[idx] frames = self.frames[idx] - if self.trans: - phrase = self._label_pre(phrase) - frames = clean_frames_process(frames) + if self.transform: + phrase = self._phrase_preprocess(phrase) + frames = preprocess_frames(frames) return frames, phrase @@ -186,28 +254,45 @@ def get_dataset(file_path): """ Create a dataset with token-to-index mapping. - Parameters: - - file_path (_type_): _description_ + Parameters + ---------- + file_path : str or path-like + Path to the file containing the dataset. - Returns: - - _type_: _description_ + Returns + ------- + dataset : LandmarkDataset + An instance of LandmarkDataset with token-to-index mapping and frames. """ + lookup_table = TokenHashTable(character_to_num, num_to_character) - dataset = LandmarkDataset(file_path, lookup_table, transform=True) + dataset = LandmarkDataset(file_path, lookup_table, transform_=True) + return dataset def prepare_dataloader(dataset: Dataset, batch_size: int, num_workers_: int = 1): """ - Prepare a dataloader with distributed sampling. + Prepare a DataLoader with distributed sampling. + + Parameters + ---------- + dataset : Dataset + The dataset to load. + + batch_size : int + Number of samples per batch. - Parameters: - dataset (Dataset): The dataset to load. - batch_size (int): Number of samples per batch. - num_workers_ (int, optional): Number of workers for data loading, by default 1. + num_workers_ : int, optional + Number of workers for data loading, by default 1. - Returns: - DataLoader: A DataLoader instance for the specified dataset. + Returns + ------- + DataLoader + A DataLoader instance for the specified dataset. + + Notes + Utilize distributed sampling for better training efficiency. """ return DataLoader( dataset, @@ -233,5 +318,17 @@ def __getitem__(self, index): #! Function to get a test dataset for debugging train pipeline def get_test_dataset(): + """_summary_ + + Parameters + ---------- + pass_ : _type_ + _description_ + + Returns + ------- + _type_ + _description_ + """ dataset = TestDataset return dataset diff --git a/signa2text/src/dataset/dataset_paths.py b/signa2text/src/dataset/dataset_paths.py index 3ea6fb0f..8ff2b2be 100644 --- a/signa2text/src/dataset/dataset_paths.py +++ b/signa2text/src/dataset/dataset_paths.py @@ -1,46 +1,64 @@ -"""doc """ +Dataset Paths Module + +This module provides functions to retrieve file paths for training and validation datasets. + +Functions: +- get_dataset_paths(dev_mode=True): Retrieves paths for either development mode or the full dataset. + +""" + import os import json -from utils.logger_util import logger -def get_dataset_paths(): - """_summary_ +def get_dataset_paths(dev_mode=True): + """Get paths for training and validation datasets. + + Parameters + ---------- + dev_mode : bool, optional + If True, returns paths for development mode, else for full data. Returns ------- - _type_ - _description_ + list of tuple + List of tuples containing file paths and corresponding file IDs for training + and validation datasets. + + Raises + ------ + AssertionError + If the number of files retrieved does not match the expected count. """ - try: - # On kaggle replace with "data/dataset_paths.json" to train on full data + if dev_mode: dataset_paths = "data/dev_samples.json" - with open(dataset_paths, "r", encoding="utf-8") as json_file: - dataset_paths_dict = json.load(json_file) - - # Training dataset - train_dataset_dict = dataset_paths_dict["train_files"] - train_file_ids = [os.path.basename(file) for file in train_dataset_dict] - train_file_ids = [ - int(file_name.replace(".parquet", "")) for file_name in train_file_ids - ] - assert len(train_dataset_dict) == len( - train_file_ids - ), "Failed getting Train files path" - train_ds_files = list(zip(train_dataset_dict, train_file_ids)) - - # Validation dataset - valid_dataset_dict = dataset_paths_dict["valid_files"] - valid_file_ids = [os.path.basename(file) for file in valid_dataset_dict] - valid_file_ids = [ - int(file_name.replace(".parquet", "")) for file_name in valid_file_ids - ] - assert len(train_dataset_dict) == len( - train_file_ids - ), "Failed getting of Valid Files path" - valid_ds_files = list(zip(valid_dataset_dict, valid_file_ids)) - - return train_ds_files, valid_ds_files - except AssertionError as asset_error: - logger.exception(f"Failed due to {asset_error}") + else: + dataset_paths = "data/dataset_paths.json" + + with open(dataset_paths, "r", encoding="utf-8") as json_file: + dataset_paths_dict = json.load(json_file) + + # Training dataset + train_dataset_dict = dataset_paths_dict["train_files"] + train_file_ids = [os.path.basename(file_path) for file_path in train_dataset_dict] + train_file_ids = [ + int(file_name.replace(".parquet", "")) for file_name in train_file_ids + ] + assert len(train_dataset_dict) == len( + train_file_ids + ), "Failed getting Train files path" + train_ds_files = list(zip(train_dataset_dict, train_file_ids)) + + # Validation dataset + valid_dataset_dict = dataset_paths_dict["valid_files"] + valid_file_ids = [os.path.basename(file_path) for file_path in valid_dataset_dict] + valid_file_ids = [ + int(file_name.replace(".parquet", "")) for file_name in valid_file_ids + ] + assert len(valid_dataset_dict) == len( + valid_file_ids + ), "Failed getting Valid Files path" + valid_ds_files = list(zip(valid_dataset_dict, valid_file_ids)) + + return train_ds_files, valid_ds_files diff --git a/signa2text/src/dataset/frames_config.py b/signa2text/src/dataset/frames_config.py index a60ef371..45216387 100644 --- a/signa2text/src/dataset/frames_config.py +++ b/signa2text/src/dataset/frames_config.py @@ -1,9 +1,12 @@ """ -Module to define constants and lists related to ASL Fingerspelling frame features. +ASL Fingerspelling Frame Features Module + +This module defines constants and lists related to ASL Fingerspelling frame features. + Variables: -- FRAME_LEN: -- LIP: -- FEATURE_COLOUMNS: +- FRAME_LEN: Length of each frame. +- LIP: Indices corresponding to lip features. +- FEATURE_COLUMNS: Combined list of feature columns, including frame, hand, pose, and face features. """ # Length of each frame diff --git a/signa2text/src/dataset/preprocess.py b/signa2text/src/dataset/preprocess.py index 647ac534..e065caec 100644 --- a/signa2text/src/dataset/preprocess.py +++ b/signa2text/src/dataset/preprocess.py @@ -1,15 +1,17 @@ """ -Module to define a function for cleaning and processing ASL Fingerspelling frames. -Functions: -- clean_frames_process: +ASL Fingerspelling Frame Cleaning and Processing Module -""" +This module defines a function for cleaning and processing ASL Fingerspelling frames. +Functions: +- preprocess_frames(frames, max_frame_len, n_hand_landmarks, n_pose_landmarks, n_face_landmarks): + Cleans and processes ASL Fingerspelling frames. +""" import torch from torch.nn import functional as F -def clean_frames_process( +def preprocess_frames( frames, max_frame_len=128, n_hand_landmarks=21, @@ -20,21 +22,21 @@ def clean_frames_process( Parameters ---------- - frames : (torch.Tensor) + frames : torch.Tensor Input tensor containing frames. max_frame_len : int, optional - Maximum length of frames, by default 128 + Maximum length of frames, by default 128. n_hand_landmarks : int, optional - Number of hand landmarks, by default 21 + Number of hand landmarks, by default 21. n_pose_landmarks : int, optional - Number of pose landmarks, by default 33 + Number of pose landmarks, by default 33. n_face_landmarks : int, optional - Number of face landmarks, by default 40 + Number of face landmarks, by default 40. Returns ------- - frames - torch.Tensor: Cleaned and processed frames tensor. + torch.Tensor + Cleaned and processed frames tensor. """ # Clip frames to the maximum length frames = frames[:max_frame_len] diff --git a/signa2text/src/dev_data.py b/signa2text/src/dev_data.py deleted file mode 100644 index 6f15731a..00000000 --- a/signa2text/src/dev_data.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Dataset Download Module - -This module provides functions to download the a subsample of Google ASL dataset. - -Functions: -- download_dataset(url: str, destination: str, path): - Downloads a dataset from the given URL to the specified destination directory. -- main - the main function to run the script -""" - - -import os -import shutil -import subprocess -import zipfile - -from utils.logger_util import logger - -DATA_DIR = "kaggle/input/asl-fingerspelling/" -data_files = ["train.csv", "character_to_prediction_index.json"] -train_landmarks = ["1019715464.parquet", "1021040628.parquet", "105143404.parquet"] -TRAIN_LANDMARKS_DIR = "train_landmarks/" - -COMMAND = [ - "kaggle", - "competitions", - "download", - "-c", - "asl-fingerspelling", - "-f", - "FILE", - "-p", - f"{DATA_DIR}", -] - - -def check_storage(project_dir=os.getcwd()): - """check and return availabe storage space - - Parameters - ---------- - directory_path : str, Path - current working directory/directory path - - Returns - ------- - int - the size of available storage space (GB) - - Raises - ------ - StorageFullError - exception for when storage is full. - """ - total, used, free = shutil.disk_usage(project_dir) - total_size_gb = round(total / (2**30), 2) - used_size_gb = round(used / (2**30), 2) - free_size_gb = round(free / (2**30), 2) - if used_size_gb / total_size_gb >= 0.8: - raise StorageFullError - return free_size_gb - - -class StorageFullError(Exception): - """Custom exception for when storage is full.""" - - pass - - -def downlaod_file(cmd, unzipped_file_path, data_dir): - """Download file using kaggle API - - Parameters - ---------- - cmd : list - Kaggle API Commands - unzipped_file : str, Path - path of the unzipped file - data_dir : str, Path - the directory where the data should be downloaded into - """ - subprocess.run(cmd, check=True, text=True) - if ( - os.path.exists(unzipped_file_path) - and os.path.splitext(unzipped_file_path)[1].lower() == ".zip" - ): - # Unzipping and delete the zipped file to free storage - with zipfile.ZipFile(unzipped_file_path, "r") as zip_ref: - zip_ref.extractall(data_dir) - os.remove(unzipped_file_path) - else: - pass - - -def main(): - """the main function to run the script""" - logger.info("Commencing downloading the dataset") - try: - logger.info(f"Current Available space {check_storage()}GB") - for file in data_files: - logger.info(f"Downloading {file} in {DATA_DIR}") - COMMAND[6] = file - unzipfile_path = DATA_DIR + file + ".zip" - downlaod_file(COMMAND, unzipfile_path, DATA_DIR) - logger.info(f" {file} downloaded succesful") - # Downloading the LANDMARKS files - for parquet_file in train_landmarks: - logger.info(f"Current Available space {check_storage()}GB") - file_path = TRAIN_LANDMARKS_DIR + parquet_file - COMMAND[6] = file_path - COMMAND[8] = DATA_DIR + TRAIN_LANDMARKS_DIR - unzipfile_path = DATA_DIR + file_path + ".zip" - downlaod_file(COMMAND, unzipfile_path, DATA_DIR + TRAIN_LANDMARKS_DIR) - logger.info(f"{parquet_file} downloaded succesfully") - - logger.success("All files downloaded succesfully") - - except Exception as error: - logger.exception(f"Data unloading was unsuccesfully due to {error}") - - -if __name__ == "__main__": - main() diff --git a/signa2text/src/download_dev_data.py b/signa2text/src/download_dev_data.py new file mode 100644 index 00000000..07dc0169 --- /dev/null +++ b/signa2text/src/download_dev_data.py @@ -0,0 +1,165 @@ +"""Dataset Download Module + +This module provides functionality to download a subsample of the Google ASL dataset through +the Kaggle API. It includes functions to download specific files, check available storage space, +and a main script to orchestrate the downloading process. + +Functions: +- download_dataset(url: str, destination: str, path: str): + Downloads a dataset from the given URL to the specified destination directory. + +- check_storage(project_dir: str = os.getcwd()) -> float: + Checks and returns the available storage space in gigabytes (GB) for a specified directory. + +- download_file(cmd: List[str], unzipped_file_path: str, data_dir: str): + Downloads a file using Kaggle API commands and unzips it. + +- main(): + The main function to execute the script. It orchestrates the download of various dataset files\ + and logs progress. + +Constants: +- DATA_DIR: Default directory for storing downloaded data. +- data_files: List of files to download from the Kaggle competition. +- train_landmarks: List of additional files related to landmarks for training data. +- TRAIN_LANDMARKS_DIR: Directory for storing downloaded landmark files. + +COMMAND: Kaggle API command template used for downloading files. + +Custom Exception: +- StorageFullError: Raised when the available storage space is insufficient. + +Usage: +- Ensure the Kaggle API is configured. +- Execute the script to download the specified dataset files. +""" + + +import os +import shutil +import subprocess +import zipfile +from utils.logging import logger + +DATA_DIR = "kaggle/input/asl-fingerspelling/" +data_files = ["train.csv", "character_to_prediction_index.json"] +train_landmarks = ["1019715464.parquet", "1021040628.parquet", "105143404.parquet"] +TRAIN_LANDMARKS_DIR = "train_landmarks/" + +COMMAND = [ + "kaggle", + "competitions", + "download", + "-c", + "asl-fingerspelling", + "-f", + "FILE_NAME", + "-p", + f"{DATA_DIR}", +] + + +class StorageFullError(Exception): + """Custom exception for when storage is full.""" + + pass + + +def check_storage(project_dir=os.getcwd()): + """Check and return available storage space. + + Parameters + ---------- + project_dir : str, Path + Current working directory or directory path. + + Returns + ------- + int + The size of available storage space (GB). + + Raises + ------ + StorageFullError + Exception for when storage is full. + """ + total, used, free = shutil.disk_usage(project_dir) + total_size_gb = round(total / (2**30), 2) + used_size_gb = round(used / (2**30), 2) + free_size_gb = round(free / (2**30), 2) + + if used_size_gb / total_size_gb >= 0.8: + raise StorageFullError("Storage is full. Cannot perform the operation.") + return free_size_gb + + +def download_file(cmd, unzipped_file_path, data_dir): + """Download file using Kaggle API. + + Parameters + ---------- + cmd : list + Kaggle API Commands. + unzipped_file_path : str, Path + Path of the unzipped file. + data_dir : str, Path + The directory where the data should be downloaded into. + """ + subprocess.run(cmd, check=True, text=True) + if ( + os.path.exists(unzipped_file_path) + and os.path.splitext(unzipped_file_path)[1].lower() == ".zip" + ): + # Unzipping and delete the zipped file to free storage + with zipfile.ZipFile(unzipped_file_path, "r") as zip_ref: + zip_ref.extractall(data_dir) + os.remove(unzipped_file_path) + + +def main(): + """ + Orchestrates the dataset download using the Kaggle API. + + Downloads specified dataset and landmark files, logging progress and checking storage space. + + Raises + ------ + Exception + If an error occurs during the download process. + """ + logger.info("Commencing downloading the dataset") + try: + logger.info(f"Current Available space {check_storage()}GB") + + # Downloading the metadata files + for file in data_files: + logger.info(f"Downloading {file} in {DATA_DIR}") + + # Swtiching "FILE_NAME" in cmd list with the actual file name in kaggle + COMMAND[6] = file + unzipfile_path = DATA_DIR + file + ".zip" + download_file(COMMAND, unzipfile_path, DATA_DIR) + logger.info(f"{file} downloaded successfully") + + # Swtiching the directotry to download the landmarks into + COMMAND[8] = DATA_DIR + TRAIN_LANDMARKS_DIR + + # Downloading the LANDMARKS files + for parquet_file in train_landmarks: + logger.info(f"Current Available space {check_storage()}GB") + file_path = TRAIN_LANDMARKS_DIR + parquet_file + + # Swtiching "FILE_NAME" in cmd list with the actual file name in kaggle + COMMAND[6] = file_path + unzipfile_path = DATA_DIR + file_path + ".zip" + download_file(COMMAND, unzipfile_path, DATA_DIR + TRAIN_LANDMARKS_DIR) + logger.info(f"{parquet_file} downloaded successfully") + + logger.success("All files downloaded successfully") + + except Exception as error: + logger.exception(f"Data unloading was unsuccessful due to {error}") + + +if __name__ == "__main__": + main() diff --git a/signa2text/src/main.py b/signa2text/src/main.py index bdfeb63b..67a770dd 100644 --- a/signa2text/src/main.py +++ b/signa2text/src/main.py @@ -5,14 +5,14 @@ # TODO cleanup and complete documentation # TODO Complete and refactor code for distributed training -# TODO remove test model and test data +# TODO remove test model and test data\ +# TODO add wandb for monitoring and saving model state import torch -from torch import nn -from utils.util import parse_args, set_seed -from utils.logger_util import logger +from utils.tools import parse_args, set_seed +from utils.logging import logger from models.model_loader import ModelLoader from dataset.dataset_loader import get_dataset, prepare_dataloader, get_test_dataset from dataset.dataset_paths import get_dataset_paths @@ -20,7 +20,7 @@ from torch.distributed import destroy_process_group -def load_train_objs(model_name, files=None): +def load_train_objs(model_name, files_paths): """ Load training objects, including the model, optimizer, dataset, and criterion. @@ -39,9 +39,9 @@ def load_train_objs(model_name, files=None): # Optimizes given model/function using TorchDynamo and specified backend torch.compile(model) optimizer_ = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) - criterion = nn.CrossEntropyLoss(label_smoothing=0.1) - dataset = get_test_dataset() # get_dataset(files) - return model, optimizer_, dataset, criterion + criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1) + dataset = get_test_dataset() # get_dataset(files_paths) + return model, optimizer_, criterion, dataset def main(model_name: str, save_every: int, total_epochs: int, batch_size: int): @@ -55,7 +55,9 @@ def main(model_name: str, save_every: int, total_epochs: int, batch_size: int): - batch_size (int): Batch size for training. """ logger.info(f"Starting training on {model_name}, epoch -> {total_epochs}") - logger.info(f"Batch Size -> {batch_size}, model saved every -> {save_every} epoch") + logger.info( + f"Batch Size -> {batch_size}, model to be saved every -> {save_every} epoch" + ) # To ensure reproducibility of the training process set_seed() @@ -63,7 +65,11 @@ def main(model_name: str, save_every: int, total_epochs: int, batch_size: int): try: # train, valid = get_dataset_paths() ddp_setup() - dataset, model, optimizer, criterion = load_train_objs(model_name) + model, optimizer, criterion, dataset = load_train_objs( + model_name, files_paths=None + ) + + #! DEBUG and add validation process train_dataset = prepare_dataloader( dataset, batch_size, @@ -79,14 +85,15 @@ def main(model_name: str, save_every: int, total_epochs: int, batch_size: int): trainer.train(total_epochs) destroy_process_group() - logger.success(f"Training completed: {total_epochs} epochs on.") + logger.success(f"Training completed: {total_epochs} epochs .") + except Exception as error: - logger.exception(f"Training failed due to {error}.") + logger.exception(f"Training failed due to -> {error}.") if __name__ == "__main__": arg = parse_args() - logger.info(f"{arg.model_name}") + logger.info(f"Model to be trained is: {arg.model_name}") main( model_name=arg.model_name, save_every=arg.save_every, diff --git a/signa2text/src/models/model_loader.py b/signa2text/src/models/model_loader.py index 529690f8..13f45455 100644 --- a/signa2text/src/models/model_loader.py +++ b/signa2text/src/models/model_loader.py @@ -1,11 +1,19 @@ -"""doc - """ +Model Loader Module -from models.baseline_transformer import ASLTransformer -import torch +This module defines a class for loading different models. +Classes: +- ModelLoader: Loads various models. +Methods: +- get_model(model_name): Builds and retrieves a specific model instance. + +""" +import torch +from models.baseline_transformer import ASLTransformer + +# TODO remove test model def test_model(): model = torch.nn.Sequential( [torch.nn.Linear(20, 100), torch.nn.Linear(100, 10), torch.nn.Linear(10, 5)] @@ -17,22 +25,21 @@ class ModelLoader: """Model Loader""" def __init__(self): - self.models = {"asl_transfomer": ASLTransformer(), "test_model": test_model()} + self.models = {"asl_transformer": ASLTransformer(), "test_model": test_model()} def get_model(self, model_name): - """build and retrieve the model instance + """Build and retrieve the model instance. Parameters ---------- model_name : str - model name + Name of the model. Returns ------- object - return built model instance + Built model instance. """ - if model_name in self.models: return self.models[model_name] else: diff --git a/signa2text/src/tests/test_data_ingestion.py b/signa2text/src/tests/test_data_ingestion.py index ab9d97f5..b5adacf9 100644 --- a/signa2text/src/tests/test_data_ingestion.py +++ b/signa2text/src/tests/test_data_ingestion.py @@ -1,16 +1,20 @@ -"doc" +""" +Test Module for Dataset Ingestion and Preprocessing -import pytest +This module contains test functions for dataset ingestion, preprocessing and tokenization. -import torch -from src.dataset.frames_config import FRAME_LEN -from src.dataset.preprocess import clean_frames_process -from src.dataset.dataset_loader import TokenHashTable +Functions: +- test_frames_preprocess(frames): Tests the preprocessing of frames. -# TODO test for frames in right shapes, in tensor, frames are normalize -# TODO test for frames dont contain NAN +- test_token_hash_table(): Tests the TokenHashTable for sentence tokenization. +""" + +import pytest -# TODO test for labels are tokensize +import torch +from dataset.frames_config import FRAME_LEN +from dataset.preprocess import preprocess_frames +from dataset.dataset_loader import TokenHashTable @pytest.mark.parametrize( @@ -18,19 +22,28 @@ [torch.randn(num_frames, 345) for num_frames in [10, 108, 128, 156, 750, 420]], ) def test_frames_preprocess(frames): - """doc""" - frames = clean_frames_process(frames) + """Tests the preprocessing of frames. + + Parameters + ---------- + frames : torch.Tensor + Input tensor containing frames for preprocessing. + """ + frames = preprocess_frames(frames) expected_output_shape = (128, 345) assert ( expected_output_shape == frames.shape - ), f"frames shape should be {expected_output_shape}" + ), f"Frames shape should be {expected_output_shape}, got {frames.shape}" def test_token_hash_table(): - token_table = TokenHashTable() + """ + Tests the TokenHashTable for sentence tokenization. + """ + token_lookup_table = TokenHashTable() sample_sentence = "this is a test run" sample_sentence_len = len(sample_sentence) - sample_sentence_token = [ + sample_sentence_tokens = [ 51, 39, 40, @@ -50,14 +63,20 @@ def test_token_hash_table(): 52, 45, ] - sample_sentence_token = torch.tensor(sample_sentence_token, dtype=torch.long) - tokenize_result = token_table.sentence_to_tensor(sample_sentence) + sample_sentence_tokens = torch.tensor(sample_sentence_tokens, dtype=torch.long) + tokenize_result = token_lookup_table.sentence_to_tensor(sample_sentence) + + # Assert the length of tokenize text + assert sample_sentence_len == len( + tokenize_result + ), f"Expexted length of tokenize text to be {sample_sentence_len}, got {len(tokenize_result)}" is_same = all( torch.equal(idx1, idx2) - for idx1, idx2 in zip(sample_sentence_token, tokenize_result) + for idx1, idx2 in zip(sample_sentence_tokens, tokenize_result) ) - assert sample_sentence_len == len(tokenize_result) - assert is_same == True + # Assert tokens match the expected value + assert is_same == True, "Tokens do not match the expected value" + # Assert that clean_frames is a PyTorch tensor - assert torch.is_tensor(tokenize_result), "is not PyTorch tensor" + assert torch.is_tensor(tokenize_result), "Tokens are not PyTorch tensor" diff --git a/signa2text/src/utils/logger_util.py b/signa2text/src/utils/logging.py similarity index 90% rename from signa2text/src/utils/logger_util.py rename to signa2text/src/utils/logging.py index f81d88f9..e895af67 100644 --- a/signa2text/src/utils/logger_util.py +++ b/signa2text/src/utils/logging.py @@ -26,8 +26,4 @@ log_filepath = Path(LOG_DIR, "running_logs.log") Path.mkdir(LOG_DIR, exist_ok=True) -logger.add( - log_filepath, - format=FORMAT_STYLE, - level="INFO", -) +logger.add(log_filepath, format=FORMAT_STYLE, level="INFO", retention="3 days") diff --git a/signa2text/src/utils/util.py b/signa2text/src/utils/tools.py similarity index 55% rename from signa2text/src/utils/util.py rename to signa2text/src/utils/tools.py index cd27952f..06eb5cbc 100644 --- a/signa2text/src/utils/util.py +++ b/signa2text/src/utils/tools.py @@ -1,3 +1,17 @@ +""" +Utility Module for Training + +This module provides utility functions for training, including setting random seeds, +device strategy, and argument parsing. + +Functions: +- set_seed(seed: int = 42) -> None: Sets random seeds for reproducibility. +- get_device_strategy(tpu: bool = False): Returns the device strategy based \ + on CPU/GPU/TPU availability. +- parse_args(): Parses arguments for the training script. + +""" + import os import random import argparse @@ -5,9 +19,17 @@ import numpy as np import torch -#import torch_xla.core.xla_model as xm +# import torch_xla.core.xla_model as xm + def set_seed(seed: int = 42) -> None: + """Sets random seeds for reproducibility. + + Parameters + ---------- + seed : int, optional + Seed value for random number generation, by default 42. + """ np.random.seed(seed) random.seed(seed) torch.manual_seed(seed) @@ -19,20 +41,32 @@ def set_seed(seed: int = 42) -> None: os.environ["PYTHONHASHSEED"] = str(seed) -def get_device_strategy(tpu=False): +def get_device_strategy(tpu: bool = False) -> torch.device: + """Returns the device strategy based on CPU/GPU/TPU availability. + + Parameters + ---------- + tpu : bool, optional + Flag indicating whether to train on TPU, by default False. + + Returns + ------- + torch.device + Device (CPU or GPU) for training. + """ if tpu: - device = None #xm.xla_device() + device = None # xm.xla_device() else: - device = torch.device("cuda" if torch.cuda.is_availabe() else "cpu") + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") return device -def parse_args(): +def parse_args() -> argparse.Namespace: """ Parse arguments given to the script. Returns: - The parsed argument object. + argparse.Namespace: Parsed argument object. """ parser = argparse.ArgumentParser( description="Run distributed data-parallel training and log with wandb." @@ -40,7 +74,7 @@ def parse_args(): parser.add_argument( "--model_name", - default="baseline_transfomer", + default="baseline_transformer", type=str, metavar="N", help="name of model to train", @@ -74,7 +108,7 @@ def parse_args(): ) parser.add_argument( "--save_every", - default= 2, + default=2, type=int, help="", ) diff --git a/yb2audio/README.md b/yb2audio/README.md index 63bbd98f..5f2795ae 100644 --- a/yb2audio/README.md +++ b/yb2audio/README.md @@ -1,9 +1,26 @@ -# Yoruba to Audio-Loom (***yb2audio***) +# Yb2Audio [![LICENSE](https://img.shields.io/badge/license-MIT-green?style=flat-square)](LICENSE) -[![Python](https://img.shields.io/badge/python-3.10-blue.svg?style=flat-square)](https://www.python.org/) +[![Python](https://img.shields.io/badge/python-3.11.2-blue.svg?style=flat-square)](https://www.python.org/) [![PyTorch](https://img.shields.io/badge/PyTorch-2.1.0-orange)](https://pytorch.org/) + +
+ Table of Contents +
    +
  1. Project description
  2. +
  3. + Getting Started + +
  4. +
  5. Contact
  6. +
  7. Acknowledgments
  8. +
+
+ ## Project description ***Overview:*** \ @@ -28,21 +45,21 @@ In our ongoing pursuit of advancing communication accessibility for Nigerian Sig **Future Implications:** The successful implementation of this generative AI approach holds promising implications for various applications not only to our main project, including but not limited to improved accessibility in education, enhanced communication tools for diverse linguistic communities, and a more inclusive digital landscape. -## Setup +## Getting Started -### Configuration +### Prerequisites ```bash -# Clone this repository -$ git clone +# Clone the main repository +$ git clone https://github.com/rileydrizzy/NSL_2_AUDIO + +# Go into the main repository +$ cd NSL_2_AUDIO # Install dependencies $ make setup -# activate virtual enviroment -$ source $(poetry env info --path)/bin/activate - -# Go into the repository +# To go into this directory $ cd yb2audio ``` @@ -51,8 +68,16 @@ $ cd yb2audio ```bash # Tun to download the dataset $ python src/dataset/download_dataset.py + ``` +## Contact + +**Name:** **Ipadeola Ezekiel Ladipo** +**Email:** +**GitHub:** [@rileydrizzy](https://github.com/rileydrizzy) +**Linkdeln:** [Ipadeola Ladipo](https://www.linkedin.com/in/ladipo-ipadeola/) + ## Acknowledgments I would like to acknowledge the outstanding contributions of : @@ -60,10 +85,3 @@ I would like to acknowledge the outstanding contributions of : **Name:** Afonja Tejumade ***(```Mentor```)*** **Email:** **GitHub:** [@tejuafonja](https://github.com/tejuafonja) - -## Contact - -**Name:** **Ipadeola Ezekiel Ladipo** -**Email:** -**GitHub:** [@rileydrizzy](https://github.com/rileydrizzy) -**Linkdeln:** [Ipadeola Ladipo](https://www.linkedin.com/in/ladipo-ipadeola/) diff --git a/yb2audio/src/__init__.py b/yb2audio/src/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/yb2audio/src/download_dataset.py b/yb2audio/src/download_dataset.py index 2c88f0c0..63852ea0 100644 --- a/yb2audio/src/download_dataset.py +++ b/yb2audio/src/download_dataset.py @@ -3,9 +3,16 @@ This module provides functions to download datasets the IroyinSpeech dataset. Functions: -- download_dataset(url: str, destination: str) -> bool: +- download_dataset(url: str, destination: str): Downloads a dataset from the given URL to the specified destination directory. - main - the main function to run the script + +Usage: + To download the dataset, run the script directly. + +Example: + $ python src/download_dataset.py + """ import os @@ -34,13 +41,15 @@ def download_dataset_(url, destination_dir): def main(): - """main function to run the script""" + """ + main function to run the script + """ logger.info(f"Commencing downloading the dataset into {DATA_DIR}") try: download_dataset_(url=URL_, destination_dir=DATA_DIR) logger.success(f"Dataset downloaded to {DATA_DIR} successfully.") except Exception as error: - logger.error(f"Dataset download failed due to: {error}") + logger.exception(f"Dataset download failed due to: {error}") if __name__ == "__main__": diff --git a/yb2audio/src/utils/__init__.py b/yb2audio/src/utils/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/yb2audio/src/utils/benchmark.py b/yb2audio/src/utils/benchmark.py deleted file mode 100644 index e69de29b..00000000