Skip to content

Commit

Permalink
Dev changes
Browse files Browse the repository at this point in the history
  • Loading branch information
rileydrizzy committed Mar 8, 2024
1 parent c7c37ee commit 240a047
Show file tree
Hide file tree
Showing 11 changed files with 1,124 additions and 448 deletions.
5 changes: 2 additions & 3 deletions signa2text/run_train.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

# Display a header with script information
echo "=== Running Train Script ==="

# Launch training once via torchrun (single node, one process).
# The old plain `python src/main.py` invocation was removed: keeping it
# would run the training entry point twice per script invocation.
torchrun --standalone --nproc_per_node=1 src/main.py --model_name test_model --epoch 2
# Alternative invocations kept for reference:
#torchrun --standalone --nproc_per_node=1 src/main.py
#--epochs 10 --batch 512
echo "=== Completed Train Script Run ==="
12 changes: 12 additions & 0 deletions signa2text/src/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
"""doc
"""

from pydantic import BaseModel
import hydra


class Arg_type(BaseModel):
    """Pydantic schema for validated training arguments.

    NOTE(review): only ``save_every`` is modeled so far — presumably the
    remaining keys of the ``params`` section in config/train.yaml are
    intended to be added here; confirm.
    """

    # Checkpoint interval — presumably "save every N epochs"; TODO confirm units
    save_every: int


@hydra.main(config_name="train", config_path="config", version_base="1.2")
def main(cfg):
    """Compose the train config with Hydra and inspect one value.

    Parameters
    ----------
    cfg : DictConfig
        Hydra-composed configuration loaded from ``config/train.yaml``.
    """
    # Debug probe: show what Python type Hydra/OmegaConf yields for this key.
    run_id = cfg.wandb_params.model_run_id
    print(type(run_id))


if __name__ == "__main__":
    # Guard the entry point so importing this module no longer
    # triggers a full Hydra run as a side effect.
    main()
19 changes: 12 additions & 7 deletions signa2text/src/config/train.yaml
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# Training configuration consumed via @hydra.main(config_name="train").
args:
  model_name: "asl_baseline_transformer"

params:
  save_every: 2        # checkpoint interval (epochs)
  tpu: False
  total_epochs: 2
  batch_size: 128
  valid_epoch: 2       # run validation every N epochs

wandb_params:
  resume_checkpoint: False
  wandb_: False        # enable Weights & Biases logging
  model_run_id: "None" # NOTE(review): the string "None", not a YAML null — confirm intended
  model_version: "latest"

dev_mode: True
49 changes: 7 additions & 42 deletions signa2text/src/dataset/dataset_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,21 @@
- get_dataset(file_path): Creates a dataset with a token-to-index mapping.
- prepare_dataloader(dataset, batch_size, num_workers_= 1): Prepares a dataloader with\
distributed sampling.
- prepare_dataloader(dataset, batch_size, num_workers_= 1): Prepares a dataloader
"""


import json
import pandas as pd
import pyarrow.parquet as pq
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from dataset.frames_config import FEATURE_COLUMNS
from dataset.preprocess import preprocess_frames

# File paths for metadata and phrase-to-index mapping
PHRASE_PATH = "/kaggle/input/asl-fingerspelling/character_to_prediction_index.json"
METADATA = "/kaggle/input/asl-fingerspelling/train.csv"
PHRASE_PATH = "kaggle/input/asl-fingerspelling/character_to_prediction_index.json"
METADATA = "kaggle/input/asl-fingerspelling/train.csv"

# Load phrase-to-index mapping
with open(PHRASE_PATH, "r", encoding="utf-8") as f:
Expand Down Expand Up @@ -271,9 +268,9 @@ def get_dataset(file_path):
return dataset


def prepare_dataloader(dataset, batch_size, num_workers_=0):
    """
    Prepare a DataLoader.

    Parameters
    ----------
    dataset : Dataset
        The dataset to wrap.
    batch_size : int
        Number of samples per batch.
    num_workers_ : int, optional
        Number of workers for data loading, by default 0.

    Returns
    -------
    DataLoader
        A DataLoader instance for the specified dataset.
    """
    # No DistributedSampler here: this loader is for single-process runs;
    # batches are yielded in dataset order (shuffle defaults to False).
    return DataLoader(
        dataset,
        batch_size=batch_size,
        pin_memory=True,
        num_workers=num_workers_,
    )


#! Synthetic dataset used only to debug the train pipeline
class TestDataset(Dataset):
    """Fixed-size dataset of random (feature, target) pairs.

    Each item is a tuple of a 20-dim feature tensor and a 1-dim target
    tensor, pre-generated at construction time.
    """

    def __init__(self, size):
        self.size = size
        # Materialize all samples up front so __getitem__ is a plain lookup.
        self.data = [self._make_sample() for _ in range(size)]

    @staticmethod
    def _make_sample():
        # One synthetic sample: 20 random features and a single random target.
        return torch.rand(20), torch.rand(1)

    def __len__(self):
        return self.size

    def __getitem__(self, index):
        return self.data[index]


#! Function to get a test dataset for debugging train pipeline
def get_test_dataset():
    """Return the TestDataset *class* (not an instance) for debug runs.

    Returns
    -------
    type
        The ``TestDataset`` class itself; the caller must instantiate it
        with a size, e.g. ``get_test_dataset()(100)``.
        NOTE(review): returning the class rather than an instance looks
        unintentional — confirm callers expect a class here.
    """
    dataset = TestDataset
    return dataset
Loading

0 comments on commit 240a047

Please sign in to comment.