Commit d74c478

updates

rileydrizzy committed Apr 22, 2024
1 parent 850ef81

Showing 7 changed files with 88 additions and 72 deletions.
27 changes: 0 additions & 27 deletions Makefile

This file was deleted.

8 changes: 4 additions & 4 deletions signa2text/Makefile
@@ -4,13 +4,13 @@ help:
@echo " setup Set up the environment with the required dependencies and environment variables"
@echo " freeze_reqs Save the dependencies onto the requirements text file"
@echo " precommit Runs precommit on all files"
@echo " download_data Download "
@echo " run_train Training"
@echo " download_data Download necessary data for training "
@echo " run_train Execute the training process"

setup:
@echo "Installing and setting up dependencies..."
. ./run_setup.sh

. ./run_setup.sh no-venv
@echo "Setting Enviroment Variables"
. ./set_environment_variables.sh

16 changes: 16 additions & 0 deletions signa2text/run_setup.sh
@@ -0,0 +1,16 @@
#!/usr/bin/env bash

# Exit immediately if any command exits with a non-zero status
set -e

venv="$1"

if [ "$venv" == "no-venv" ]; then
echo "Installing without creating a virtual environment."
pip install --no-cache-dir -r requirements.txt
else
echo "Installing with virtual environment."
python -m venv env
source env/bin/activate
pip install --no-cache-dir -r requirements.txt
fi
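The script branches on its first positional argument, so both setup styles are available; a quick usage sketch, taken from the targets above:

    # Create ./env and install into it (the default branch):
    . ./run_setup.sh

    # Install into the current environment, as the updated `make setup` now does:
    . ./run_setup.sh no-venv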
20 changes: 11 additions & 9 deletions signa2text/src/main.py
@@ -10,15 +10,16 @@
import hydra

from omegaconf import DictConfig
-from utils.tools import resume_training, set_seed
-from utils.logging import logger
-from models.model_loader import ModelLoader
-from dataset.dataset_loader import get_dataset, prepare_dataloader  # get_test_dataset
-from dataset.dataset_paths import get_dataset_paths
from lightning import Trainer
-from lightning.pytorch.loggers import WandbLogger
+from dataset.dataset_loader import get_dataset, prepare_dataloader
+from dataset.dataset_paths import get_dataset_paths
+from models.model_loader import ModelLoader
+from metrics import NormalizedLevenshteinDistance
+from trainer import LitModule, profiler
+from utils.logging import logger
+from utils.tools import resume_training, set_seed
+from config import PROJECT_NAME
+from lightning.pytorch.loggers import WandbLogger


MAX_TRAIN_TIME = "00:06:00:00"
@@ -57,8 +58,9 @@ def load_train_objs(model_name):
    # Optimizes given model/function using TorchDynamo and specified backend
    torch.compile(model)
    criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
+    acc_metric = None

-    return model, criterion
+    return model, criterion, acc_metric


@hydra.main(config_name="train", config_path="config", version_base="1.2")
@@ -80,7 +82,7 @@ def main(cfg: DictConfig):

    train_data_paths, valid_data_paths = get_dataset_paths(dev_mode=cfg.dev_mode)

-    model, criterion = load_train_objs(cfg.model_name)
+    model, criterion, _ = load_train_objs(cfg.model_name)

    logger.info("Initializing WANDB")

@@ -110,7 +112,7 @@
        model_name=cfg.model_name,
        model=model,
        loss_criterion=criterion,
-        metric=None,
+        acc_metric=NormalizedLevenshteinDistance,
        save_ckpt_every=cfg.params.save_every,
    )
    trainer = Trainer(
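MAX_TRAIN_TIME is in Lightning's DD:HH:MM:SS string format, i.e. a six-hour budget. The Trainer(...) call itself is cut off by the hunk, so the following is only a hedged sketch of how such a budget is typically passed — max_time and logger are Lightning's public arguments, but the arguments this commit actually uses are not shown:

    from lightning import Trainer
    from lightning.pytorch.loggers import WandbLogger

    MAX_TRAIN_TIME = "00:06:00:00"  # DD:HH:MM:SS

    # Hypothetical wiring; the commit's real Trainer arguments are not visible here.
    wandb_logger = WandbLogger(project="signa2text")  # project name assumed
    trainer = Trainer(
        max_time=MAX_TRAIN_TIME,  # stop training once the time budget elapses
        logger=wandb_logger,
    )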
35 changes: 29 additions & 6 deletions signa2text/src/metrics.py
@@ -2,22 +2,45 @@
"""

# TODO add loss/criterion for training
-# TODO add Leve distances metric for eval
+# TODO add Levenshtein distance metric for eval
# TODO add CTC Loss


import torch
from torchmetrics.text import EditDistance
+from torchmetrics import Metric


-class NormalizedLevenshteinDistance(torch.nn.Module):
+# impl loss
+"""
+The "normalized total Levenshtein distance" rescales the raw Levenshtein distance to a
+standardized range, typically 0 to 1, so that scores are comparable across string pairs
+of different lengths. The edit distance is the number of characters that must be
+substituted, inserted, or deleted to transform the predicted text into the reference
+text; the lower the distance, the more accurate the model is considered to be.
+"""


+"""class NormalizedLevenshteinDistance(torch.nn.Module):
    def __init__(self):
        super().__init__()
-        self.levenshte_indistance = EditDistance(reduction="sum")
+        self.levenshtein_distance = EditDistance(reduction="sum")
    def forward(self, predictions, targets):
-        total_chars = sum(len(char) for char in targets)
-        total_distance = self.levenshte_indistance(predictions, targets)
+        total_chars = sum(len(label) for label in targets)
+        total_distance = self.levenshtein_distance(predictions, targets)
        result = (total_chars - total_distance) / total_chars
        return result
+"""


+class NormalizedLevenshteinDistance(Metric):
+    def __init__(self, **kwargs: torch.Any):
+        super().__init__()
+        text = None
+
+    def update(self):
+        pass
+
+    def compute(self, predictions, targets):
+        total_chars = sum(len(label) for label in targets)
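The replacement Metric subclass is left unfinished in this commit: update() is a no-op, the **kwargs annotation references torch.Any (which does not exist; typing.Any was presumably meant), and compute() is cut off at the hunk boundary. For intuition about the metric itself: with target "cat" and prediction "cut" the edit distance is 1, so the normalized score is (3 - 1) / 3 ≈ 0.67. A minimal sketch of one way to finish it in torchmetrics style — assuming a recent torchmetrics with the functional edit_distance, not the commit's own code:

    import torch
    from torchmetrics import Metric
    from torchmetrics.functional.text import edit_distance


    class NormalizedLevenshtein(Metric):
        """Accumulates edit distance and reference length across batches."""

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.add_state("distance", default=torch.tensor(0.0), dist_reduce_fx="sum")
            self.add_state("total_chars", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, predictions, targets):
            # Sum the raw edit distance and the reference character count per batch.
            self.distance += edit_distance(predictions, targets, reduction="sum")
            self.total_chars += sum(len(label) for label in targets)

        def compute(self):
            # Normalized, accuracy-style score in [0, 1]; higher is better.
            return (self.total_chars - self.distance) / self.total_chars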
30 changes: 16 additions & 14 deletions signa2text/src/models/baseline_transformer.py
@@ -41,23 +41,25 @@ def forward(self, x):
        Parameters
        ----------
        x : tensors
-            input tensor with shape (batch_size, sequence_length)
+            input tensor with shape (batch_size, sequence_length, )
        Returns
        -------
        tensors
            embedded tensor with shape (batch_size, sequence_length, embedding_dim)
        """
-        batch_size, maxlen = x.size()
+        if len(x.size()) > 1:
+            max_len = x.size(1)
+        else:
+            max_len = x.size(0)
+
+        print(max_len)

        # Token embedding
        x = self.token_embed_layer(x)

        # Positional encoding
-        positions = torch.arange(0, maxlen).to(x.device)
-        positions = (
-            self.position_embed_layer(positions).unsqueeze(0).expand(batch_size, -1, -1)
-        )
+        positions = torch.arange(0, max_len).unsqueeze(0)
+        positions = self.position_embed_layer(positions)  # expand(batch_size, -1, -1)

        return x + positions

@@ -78,23 +80,25 @@ def __init__(self, embedding_dim):
            in_channels=64, out_channels=128, kernel_size=11, stride=2, padding=padding
        )
        self.conv3_layer = nn.Conv1d(
-            in_channels=128, out_channels=256, kernel_size=11, stride=2, padding=padding
+            in_channels=128, out_channels=64, kernel_size=11, stride=2, padding=padding
        )

        # Output embedding layer
-        self.embedding_layer = nn.Linear(256, embedding_dim)
+        self.embedding_layer = nn.Linear(44, embedding_dim)

    def forward(self, x):
        # Input x should have shape (batch_size, input_size, input_dim)
        # x = x.unsqueeze(1)  # Add a channel dimension for 1D convolution

        # Apply convolutional layers with ReLU activation and stride 2
        x = torch.relu(self.conv1_layer(x))
        x = torch.relu(self.conv2_layer(x))
        x = torch.relu(self.conv3_layer(x))

+        # print(x.size())

        # Global average pooling to reduce spatial dimensions
-        x = torch.mean(x, dim=2)
+        # x = torch.mean(x, dim=2)

+        # print(x.size())

        # Apply the linear embedding layer
        x = self.embedding_layer(x)
@@ -322,9 +326,7 @@ def generate(self, source, target_start_token_idx=60):
            _description_
        """
        encoder_out = self.encoder(source)
-        decoder_input = (
-            torch.ones((1), dtype=torch.long).to(source.device) * target_start_token_idx
-        )
+        decoder_input = torch.ones((1), dtype=torch.long) * target_start_token_idx
        dec_logits = []
        for _ in range(self.target_maxlen - 1):
            decoder_out = self._decoder_run(encoder_out, decoder_input)
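A side effect of these hunks: the position indices and the decoder's start token are no longer moved with .to(x.device) / .to(source.device), so a CUDA run would mix CPU- and GPU-resident tensors, and the leftover print(max_len) fires on every forward pass. A hedged sketch of a device-safe embedding forward — the class and attribute names are assumed from the hunk context, not taken from the full file:

    import torch
    import torch.nn as nn


    class TokenEmbedding(nn.Module):
        """Sketch of a device-safe token + positional embedding block."""

        def __init__(self, vocab_size, max_len, embedding_dim):
            super().__init__()
            self.token_embed_layer = nn.Embedding(vocab_size, embedding_dim)
            self.position_embed_layer = nn.Embedding(max_len, embedding_dim)

        def forward(self, x):
            # x: (batch_size, sequence_length) of token ids
            max_len = x.size(-1)
            tokens = self.token_embed_layer(x)
            # Build position ids on the input's device; broadcasting over the
            # batch dimension makes the dropped expand() unnecessary.
            positions = torch.arange(max_len, device=x.device).unsqueeze(0)
            return tokens + self.position_embed_layer(positions)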
24 changes: 12 additions & 12 deletions signa2text/src/trainer.py
@@ -5,8 +5,6 @@
"""

# TODO implement loss and metrics

import torch
import lightning as L
from lightning.pytorch.callbacks import (
Expand All @@ -17,7 +15,6 @@
from lightning.pytorch.profilers import SimpleProfiler

# from utils.logging import logger
-# from metrics import NormalizedLevenshteinDistance

# Checkpoint Filename Template
FILENAME_TEMPLATE = "NSL-2-AUDIO-{epoch}-{val_loss:.2f}"
@@ -29,31 +26,34 @@ class LitModule(L.LightningModule):
"""_summary_"""

def __init__(
self, model, loss_criterion, metric, save_ckpt_every=5, model_name="test"
self, model_name, model, loss_criterion, acc_metric, save_ckpt_every=5
):
super().__init__()
self.model = model
self.loss_criterion = loss_criterion
self.metric = metric
self.accuracy_metric = acc_metric
self.save_ckpt_every = save_ckpt_every
self.checkpoint_dir = f"artifacts/{model_name}/"
self.save_hyperparameters()

def _get_preds_loss_accuracy(self, batch):
sources, targets = batch
preds = self.model(sources, targets)
loss = self.loss_criterion(preds, sources)
# Levenshtein_dis = self.metric(oi)
return preds, loss
# loss = self.loss_criterion(preds, targets)

acc_loss = self.accuracy_metric()
return loss, acc_loss, preds

def training_step(self, batch, batch_idx):
preds, loss = self._get_preds_loss_accuracy(batch)
self.log("loss", loss)
loss, acc_loss, preds = self._get_preds_loss_accuracy(batch)

self.log("loss", loss, on_epoch=True, on_step=False)
return loss

def validation_step(self, batch, batch_idx):
preds, val_loss = self._get_preds_loss_accuracy(batch)
self.log("val_loss", val_loss)
val_loss, _, preds = self._get_preds_loss_accuracy(batch)

self.log("val_loss", val_loss, on_epoch=True, on_step=False)
return preds

def configure_optimizers(self):
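Two things stand out in the reworked LitModule: main.py passes the NormalizedLevenshteinDistance class itself, so self.accuracy_metric() constructs a fresh, empty metric on every step rather than scoring anything, and the loss is still computed against sources while the targets version sits commented out. The conventional Lightning pattern instantiates the metric once in __init__ and updates it per step; a hedged sketch under those assumptions (constructor arguments invented for illustration):

    import torch
    import lightning as L


    class LitModuleSketch(L.LightningModule):
        """Sketch: construct the metric once, score against targets."""

        def __init__(self, model, loss_criterion, acc_metric_cls, lr=1e-4):
            super().__init__()
            self.model = model
            self.loss_criterion = loss_criterion
            self.accuracy_metric = acc_metric_cls()  # an instance, not the class
            self.lr = lr

        def training_step(self, batch, batch_idx):
            sources, targets = batch
            preds = self.model(sources, targets)
            loss = self.loss_criterion(preds, targets)  # score against targets
            self.log("loss", loss, on_epoch=True, on_step=False)
            return loss

        def validation_step(self, batch, batch_idx):
            sources, targets = batch
            preds = self.model(sources, targets)
            self.log("val_loss", self.loss_criterion(preds, targets))
            # A text metric would be updated with decoded strings here, e.g.
            # self.accuracy_metric.update(decoded_preds, decoded_targets)
            return preds

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=self.lr)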
