Skip to content

Commit c597b82

Browse files
committed
fix: batch transform
1 parent df89640 commit c597b82

4 files changed

Lines changed: 33 additions & 3 deletions

File tree

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ A comprehensive, production-ready PyTorch project template with modular architec
1010
- **📊 Experiment Tracking**: MLflow and Weights & Biases integration with auto-visualization
1111
- **🔧 Modern Tooling**: uv package management, pre-commit hooks, Docker support
1212
- **💾 Resume Training**: Automatic checkpoint saving and loading with state preservation
13-
- **🌐 Cross-Platform**: Development support on macOS and Linux with optimized builds
13+
- **🌐 Cross-Platform**: Development support on macOS (Apple Silicon MPS), Linux with optimized builds
1414
- **🐳 Development Environment**: Devcontainer and Jupyter Lab integration
1515
- **⚡ Performance Optimization**: RAM caching, mixed precision, torch.compile support
1616
- **📚 Auto Documentation**: Sphinx-based API docs with live reloading

src/transform/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .build import (
66
BATCHED_TRANSFORM_REGISTRY,
77
TRANSFORM_REGISTRY,
8+
BatchedTransformCompose,
89
build_batched_transform,
910
build_transform,
1011
build_transforms,

src/transform/batch_compose.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import torch
2+
import torch.nn as nn
3+
4+
from ..dataloaders import DatasetOutput
5+
6+
7+
class BatchedTransformCompose:
8+
def __init__(self, transforms: list[nn.Module]) -> None:
9+
self.transforms = transforms
10+
11+
def to(self, device: torch.device) -> "BatchedTransformCompose":
12+
for t in self.transforms:
13+
if isinstance(t, torch.nn.Module):
14+
t.to(device)
15+
return self
16+
17+
def __call__(self, data: DatasetOutput) -> DatasetOutput:
18+
for t in self.transforms:
19+
data = t(data)
20+
return data
21+
22+
def __repr__(self) -> str:
23+
format_string = self.__class__.__name__ + "("
24+
for t in self.transforms:
25+
format_string += "\n"
26+
format_string += f" {t}"
27+
format_string += "\n)"
28+
return format_string

src/transform/build.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from ..config import TransformConfig
44
from ..utils import Registry
55
from .base import BaseTransform
6+
from .batch_compose import BatchedTransformCompose
67

78
TRANSFORM_REGISTRY = Registry("TRANSFORM")
89
"""Registry for data transformation classes."""
@@ -54,7 +55,7 @@ def build_transforms(cfg: list[TransformConfig]) -> T.Compose:
5455
return T.Compose(transforms)
5556

5657

57-
def build_batched_transform(cfg: list[TransformConfig]) -> T.Compose:
58+
def build_batched_transform(cfg: list[TransformConfig]) -> BatchedTransformCompose:
5859
"""Build a composition of batch-level transformations.
5960
6061
Args:
@@ -66,4 +67,4 @@ def build_batched_transform(cfg: list[TransformConfig]) -> T.Compose:
6667
batched_transforms = []
6768
for cfg_transform in cfg:
6869
batched_transforms.append(build_batch_transform(cfg_transform))
69-
return T.Compose(batched_transforms)
70+
return BatchedTransformCompose(batched_transforms)

0 commit comments

Comments
 (0)