diff --git a/3.test_cases/pytorch/cpu-ddp/README.md b/3.test_cases/pytorch/cpu-ddp/README.md deleted file mode 100644 index e0a6b8d0e..000000000 --- a/3.test_cases/pytorch/cpu-ddp/README.md +++ /dev/null @@ -1,9 +0,0 @@ -# PyTorch DDP on CPU - -Isolated environments are crucial for reproducible machine learning because they encapsulate specific software versions and dependencies, ensuring models are consistently retrainable, shareable, and deployable without compatibility issues. - -[Anaconda](https://www.anaconda.com/) leverages conda environments to create distinct spaces for projects, allowing different Python versions and libraries to coexist without conflicts by isolating updates to their respective environments. [Docker](https://www.docker.com/), a containerization platform, packages applications and their dependencies into containers, ensuring they run seamlessly across any Linux server by providing OS-level virtualization and encapsulating the entire runtime environment. - -This example showcases CPU [PyTorch DDP](https://pytorch.org/tutorials/beginner/ddp_series_theory.html) environment setup utilizing these approaches for efficient environment management. - - We provide guides for both Slurm and Kubernetes. However, please note that the Conda example is only compatible with Slurm. For detailed instructions, proceed to the [slurm](slurm) or [kubernetes](kubernetes) subdirectory. diff --git a/3.test_cases/pytorch/cpu-ddp/ddp.py b/3.test_cases/pytorch/cpu-ddp/ddp.py deleted file mode 100644 index 819e1739c..000000000 --- a/3.test_cases/pytorch/cpu-ddp/ddp.py +++ /dev/null @@ -1,124 +0,0 @@ -# Modified version of https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu_torchrun.py - -import torch -import torch.nn.functional as F -from torch.utils.data import Dataset, DataLoader - -import torch.multiprocessing as mp -from torch.utils.data.distributed import DistributedSampler -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.distributed import init_process_group, destroy_process_group -import os - -import torch -from torch.utils.data import Dataset - -class MyTrainDataset(Dataset): - def __init__(self, size): - self.size = size - self.data = [(torch.rand(20), torch.rand(1)) for _ in range(size)] - - def __len__(self): - return self.size - - def __getitem__(self, index): - return self.data[index] - -def ddp_setup(): - init_process_group(backend="gloo") - -class Trainer: - def __init__( - self, - model: torch.nn.Module, - train_data: DataLoader, - optimizer: torch.optim.Optimizer, - save_every: int, - snapshot_path: str, - ) -> None: - self.model = model - self.rank = os.environ["RANK"] - self.train_data = train_data - self.optimizer = optimizer - self.save_every = save_every - self.epochs_run = 0 - self.snapshot_path = snapshot_path - if os.path.exists(snapshot_path): - print("Loading snapshot") - self._load_snapshot(snapshot_path) - - self.model = DDP(self.model) - - def _load_snapshot(self, snapshot_path): - snapshot = torch.load(snapshot_path) - self.model.load_state_dict(snapshot["MODEL_STATE"]) - self.epochs_run = snapshot["EPOCHS_RUN"] - print(f"Resuming training from snapshot at Epoch {self.epochs_run}") - - def _run_batch(self, source, targets): - self.optimizer.zero_grad() - output = self.model(source) - loss = F.cross_entropy(output, targets) - loss.backward() - self.optimizer.step() - - def _run_epoch(self, epoch): - b_sz = len(next(iter(self.train_data))[0]) - print(f"[RANK {self.rank}] Epoch {epoch} | Batchsize: {b_sz} 
| Steps: {len(self.train_data)}") - self.train_data.sampler.set_epoch(epoch) - for source, targets in self.train_data: - source = source - targets = targets - self._run_batch(source, targets) - - def _save_snapshot(self, epoch): - snapshot = { - "MODEL_STATE": self.model.module.state_dict(), - "EPOCHS_RUN": epoch, - } - torch.save(snapshot, self.snapshot_path) - print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}") - - def train(self, max_epochs: int): - for epoch in range(self.epochs_run, max_epochs): - self._run_epoch(epoch) - if epoch % self.save_every == 0: - self._save_snapshot(epoch) - - -def load_train_objs(): - train_set = MyTrainDataset(2048) # load your dataset - model = torch.nn.Linear(20, 1) # load your model - optimizer = torch.optim.SGD(model.parameters(), lr=1e-3) - return train_set, model, optimizer - - -def prepare_dataloader(dataset: Dataset, batch_size: int): - return DataLoader( - dataset, - batch_size=batch_size, - pin_memory=True, - shuffle=False, - sampler=DistributedSampler(dataset) - ) - - -def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str): - ddp_setup() - dataset, model, optimizer = load_train_objs() - train_data = prepare_dataloader(dataset, batch_size) - trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path) - trainer.train(total_epochs) - destroy_process_group() - - -if __name__ == "__main__": - import argparse - parser = argparse.ArgumentParser(description='simple distributed training job') - parser.add_argument('total_epochs', type=int, help='Total epochs to train the model') - parser.add_argument('save_every', type=int, help='How often to save a snapshot') - parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)') - parser.add_argument('--checkpoint_path', default="./snapshot.pt", type=str, help='Full path to checkpoint file') - args = parser.parse_args() - - main(args.save_every, args.total_epochs, args.batch_size, args.checkpoint_path) diff --git a/3.test_cases/pytorch/cpu-ddp/kubernetes/fsdp-simple.yaml b/3.test_cases/pytorch/cpu-ddp/kubernetes/fsdp-simple.yaml deleted file mode 100644 index 80c85dcd5..000000000 --- a/3.test_cases/pytorch/cpu-ddp/kubernetes/fsdp-simple.yaml +++ /dev/null @@ -1,118 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: etcd -spec: - ports: - - name: etcd-client-port - port: 2379 - protocol: TCP - targetPort: 2379 - selector: - app: etcd - ---- -apiVersion: apps/v1 -kind: Deployment -metadata: - labels: - app: etcd - name: etcd -spec: - replicas: 1 - selector: - matchLabels: - app: etcd - template: - metadata: - labels: - app: etcd - spec: - containers: - - name: etcd - command: ["/usr/local/bin/etcd"] - args: - - "--data-dir" - - "/var/lib/etcd" - - "--enable-v2" - - "--listen-client-urls" - - "http://0.0.0.0:2379" - - "--advertise-client-urls" - - "http://0.0.0.0:2379" - - "--initial-cluster-state" - - "new" - image: quay.io/coreos/etcd:v3.5.19 - ports: - - containerPort: 2379 - name: client - protocol: TCP - - containerPort: 2380 - name: server - protocol: TCP - restartPolicy: Always ---- -apiVersion: "kubeflow.org/v1" -kind: PyTorchJob -metadata: - name: fsdp -spec: - elasticPolicy: - rdzvBackend: etcd - rdzvHost: etcd - rdzvPort: 2379 - minReplicas: 1 - maxReplicas: 64 - maxRestarts: 100 - metrics: - - type: Resource - resource: - name: cpu - target: - type: Utilization - averageUtilization: 90 - pytorchReplicaSpecs: - Worker: - replicas: 2 - restartPolicy: OnFailure - template: - metadata: 
-          labels:
-            app: fsdp
-        spec:
-          volumes:
-            - name: shmem
-              hostPath:
-                path: /dev/shm
-            - name: local
-              hostPath:
-                path: /mnt/k8s-disks/0
-          #nodeSelector:
-          #  node.kubernetes.io/instance-type: "${INSTANCE_TYPE}"
-          containers:
-            - name: pytorch
-              image: pytorch/pytorch:latest
-              imagePullPolicy: Always
-              command:
-                - /opt/conda/bin/torchrun
-                - --nproc_per_node=4
-                - --nnodes=2
-                - /local/ddp.py
-                - "5000"
-                - "10"
-                - --batch_size=32
-                - --checkpoint_path=/local/snapshot.pt
-
-              volumeMounts:
-                - name: shmem
-                  mountPath: /dev/shm
-                - name: local
-                  mountPath: /local
-          initContainers:
-            - name: script-downloader
-              image: public.ecr.aws/hpc-cloud/nccl-tests:latest
-              command: ["curl", "https://raw.githubusercontent.com/aws-samples/awsome-distributed-training/refs/heads/main/3.test_cases/16.pytorch-cpu-ddp/ddp.py", "-o", "/local/ddp.py"]
-              volumeMounts:
-                - name: shmem
-                  mountPath: /dev/shm
-                - name: local
-                  mountPath: /local
diff --git a/3.test_cases/pytorch/cpu-ddp/slurm/0.create-conda-env.sh b/3.test_cases/pytorch/cpu-ddp/slurm/0.create-conda-env.sh
deleted file mode 100644
index 280990b1f..000000000
--- a/3.test_cases/pytorch/cpu-ddp/slurm/0.create-conda-env.sh
+++ /dev/null
@@ -1,17 +0,0 @@
-#!/usr/bin/env bash
-set -ex
-
-# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
-# SPDX-License-Identifier: MIT-0
-
-wget https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh
-chmod +x Miniconda3-latest-Linux-x86_64.sh
-./Miniconda3-latest-Linux-x86_64.sh -b -f -p ./miniconda3
-
-source ./miniconda3/bin/activate
-
-conda create -y -p ./pt_cpu python=3.10 pytorch=2.0.1 -c pytorch -c nvidia -c conda-forge
-
-source activate ./pt_cpu/
-
-rm Miniconda3-latest-Linux-x86_64.sh*
diff --git a/3.test_cases/pytorch/cpu-ddp/.gitignore b/3.test_cases/pytorch/ddp/.gitignore
similarity index 64%
rename from 3.test_cases/pytorch/cpu-ddp/.gitignore
rename to 3.test_cases/pytorch/ddp/.gitignore
index 13919ec80..0df471a80 100644
--- a/3.test_cases/pytorch/cpu-ddp/.gitignore
+++ b/3.test_cases/pytorch/ddp/.gitignore
@@ -1,4 +1,7 @@
 Miniconda3-latest*
 miniconda3
-pt_cpu
+pt
 *.yaml
+data
+*.pt
+mlruns
diff --git a/3.test_cases/pytorch/cpu-ddp/Dockerfile b/3.test_cases/pytorch/ddp/Dockerfile
similarity index 66%
rename from 3.test_cases/pytorch/cpu-ddp/Dockerfile
rename to 3.test_cases/pytorch/ddp/Dockerfile
index 4d6b47ffc..dcd0dbfea 100644
--- a/3.test_cases/pytorch/cpu-ddp/Dockerfile
+++ b/3.test_cases/pytorch/ddp/Dockerfile
@@ -2,6 +2,6 @@
 FROM pytorch/pytorch:latest
 
 RUN apt update && apt upgrade -y
+RUN pip install mlflow==2.13.2 sagemaker-mlflow==0.1.0
 
 COPY ddp.py /workspace
-
diff --git a/3.test_cases/pytorch/ddp/README.md b/3.test_cases/pytorch/ddp/README.md
new file mode 100644
index 000000000..2edd9012d
--- /dev/null
+++ b/3.test_cases/pytorch/ddp/README.md
@@ -0,0 +1,71 @@
+# PyTorch DDP
+
+Isolated environments are crucial for reproducible machine learning because they encapsulate specific software versions and dependencies, ensuring models are consistently retrainable, shareable, and deployable without compatibility issues.
+
+[Anaconda](https://www.anaconda.com/) leverages conda environments to create distinct spaces for projects, allowing different Python versions and libraries to coexist without conflicts by isolating updates to their respective environments. [Docker](https://www.docker.com/), a containerization platform, packages applications and their dependencies into containers, ensuring they run seamlessly across any Linux server by providing OS-level virtualization and encapsulating the entire runtime environment.
+
+This example showcases a [PyTorch DDP](https://pytorch.org/tutorials/beginner/ddp_series_theory.html) environment setup using these approaches for efficient environment management. The implementation supports both CPU and GPU computation:
+
+- **CPU Training**: Uses the GLOO backend for distributed training on CPU nodes
+- **GPU Training**: Automatically switches to the NCCL backend when GPUs are available, providing optimized multi-GPU training
+
+## Training
+
+### Basic Usage
+
+To run the training with GPUs, use `torchrun` with the appropriate number of GPUs:
+```bash
+torchrun --nproc_per_node=N ddp.py --total_epochs=10 --save_every=1 --batch_size=32
+```
+where `N` is the number of GPUs you want to use per node. On CPU-only nodes, the script automatically falls back to the GLOO backend.
+
+## MLflow Integration
+
+This implementation includes [MLflow](https://mlflow.org/) integration for experiment tracking and model management. MLflow helps you track metrics, parameters, and artifacts during training, making it easier to compare different runs and manage model versions.
+
+### Setup
+
+1. Install MLflow:
+```bash
+pip install mlflow
+```
+
+2. Start the MLflow tracking server:
+```bash
+mlflow ui
+```
+
+### Usage
+
+To enable MLflow logging, add the `--use_mlflow` flag when running the training script:
+```bash
+torchrun --nproc_per_node=N ddp.py --total_epochs=10 --save_every=1 --batch_size=32 --use_mlflow
+```
+
+If no tracking URI is specified, runs are logged to a local file store under `~/mlruns`. To log to a tracking server instead, such as the one started by `mlflow ui` (which listens on `http://localhost:5000` by default), specify `--tracking_uri`:
+```bash
+torchrun --nproc_per_node=N ddp.py --total_epochs=10 --save_every=1 --batch_size=32 --use_mlflow --tracking_uri=http://localhost:5000
+```
+
+### What's Tracked
+
+MLflow will track:
+- Training metrics (loss per epoch)
+- Model hyperparameters
+- Model checkpoints
+- Training configuration
+
+### Viewing Results
+
+Open your browser and navigate to `http://localhost:5000` (or your specified tracking URI).
+
+The MLflow UI provides:
+- Experiment comparison
+- Metric visualization
+- Parameter tracking
+- Model artifact management
+- Run history
+
+## Deployment
+
+We provide guides for both Slurm and Kubernetes. However, please note that the virtual environment (venv) example is only compatible with Slurm. For detailed instructions, proceed to the [slurm](slurm) or [kubernetes](kubernetes) subdirectory.
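+
+### Multi-Node Quick Reference
+
+As a minimal sketch of what the Slurm and Kubernetes guides automate for you (the rendezvous endpoint `node-0:29500` and the flag values below are illustrative placeholders, not values mandated by this test case), a two-node launch uses `torchrun`'s c10d rendezvous and runs the same command on every node:
+
+```bash
+# Run this on every node; --rdzv_endpoint is a placeholder and should
+# point at one reachable host that all nodes agree on.
+torchrun \
+    --nproc_per_node=1 \
+    --nnodes=2 \
+    --rdzv_backend=c10d \
+    --rdzv_endpoint=node-0:29500 \
+    ddp.py --total_epochs=10 --save_every=1 --batch_size=32 --use_mlflow
+```
+
+When `--use_mlflow` is set without `--tracking_uri`, runs land in the local `~/mlruns` file store; one way to browse them afterwards is to point the MLflow UI at that store:
+
+```bash
+mlflow ui --backend-store-uri file://$HOME/mlruns
+```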
diff --git a/3.test_cases/pytorch/ddp/ddp.py b/3.test_cases/pytorch/ddp/ddp.py
new file mode 100644
index 000000000..3b52637f1
--- /dev/null
+++ b/3.test_cases/pytorch/ddp/ddp.py
@@ -0,0 +1,197 @@
+# Modified version of https://github.com/pytorch/examples/blob/main/distributed/ddp-tutorial-series/multigpu_torchrun.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+import mlflow
+import mlflow.pytorch
+
+import torch.multiprocessing as mp
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.distributed import init_process_group, destroy_process_group
+import os
+
+class MLP(nn.Module):
+    def __init__(self):
+        super().__init__()
+        self.flatten = nn.Flatten()
+        self.linear_relu_stack = nn.Sequential(
+            nn.Linear(28*28, 512),
+            nn.ReLU(),
+            nn.Linear(512, 512),
+            nn.ReLU(),
+            nn.Linear(512, 10)
+        )
+
+    def forward(self, x):
+        x = self.flatten(x)
+        logits = self.linear_relu_stack(x)
+        return logits
+
+def ddp_setup():
+    # Use NCCL backend for GPU training, fall back to GLOO for CPU
+    if torch.cuda.is_available():
+        print("Using NCCL backend for GPU training")
+        init_process_group(backend="nccl")
+    else:
+        print("Using GLOO backend for CPU training")
+        init_process_group(backend="gloo")
+
+class Trainer:
+    def __init__(
+        self,
+        model: torch.nn.Module,
+        train_data: DataLoader,
+        optimizer: torch.optim.Optimizer,
+        save_every: int,
+        snapshot_path: str,
+        use_mlflow: bool = False,
+        tracking_uri: str = None
+    ) -> None:
+        self.model = model
+        self.rank = int(os.environ["RANK"])
+        self.train_data = train_data
+        self.optimizer = optimizer
+        self.save_every = save_every
+        self.epochs_run = 0
+        self.snapshot_path = snapshot_path
+        self.use_mlflow = use_mlflow
+        self.tracking_uri = tracking_uri if tracking_uri else f"file://{os.environ['HOME']}/mlruns"
+        # Set device
+        self.device = torch.device(f"cuda:{os.environ['LOCAL_RANK']}" if torch.cuda.is_available() else "cpu")
+        self.model = self.model.to(self.device)
+
+        if os.path.exists(snapshot_path):
+            print("Loading snapshot")
+            self._load_snapshot(snapshot_path)
+
+        self.model = DDP(self.model, device_ids=[self.device.index] if torch.cuda.is_available() else None)
+
+    def _load_snapshot(self, snapshot_path):
+        snapshot = torch.load(snapshot_path, map_location=self.device)
+        self.model.load_state_dict(snapshot["MODEL_STATE"])
+        self.epochs_run = snapshot["EPOCHS_RUN"]
+        print(f"Resuming training from snapshot at Epoch {self.epochs_run}")
+
+    def _run_batch(self, source, targets):
+        source = source.to(self.device)
+        targets = targets.to(self.device)
+        self.optimizer.zero_grad()
+        output = self.model(source)
+        loss = F.cross_entropy(output, targets)
+        loss.backward()
+        self.optimizer.step()
+        return loss.item()
+
+    def _run_epoch(self, epoch):
+        b_sz = len(next(iter(self.train_data))[0])
+        self.train_data.sampler.set_epoch(epoch)
+        total_loss = 0
+        for source, targets in self.train_data:
+            loss = self._run_batch(source, targets)
+            total_loss += loss
+
+        avg_loss = total_loss / len(self.train_data)
+        if self.use_mlflow and self.rank == 0:  # Only log from rank 0
+            mlflow.log_metric("train_loss", avg_loss, step=epoch)
+        print(f"[RANK {self.rank}] Epoch {epoch} | Batchsize: {b_sz} | Steps: {len(self.train_data)} | Loss: {avg_loss}")
+        return avg_loss
+
+    def _save_snapshot(self, epoch):
+        snapshot = {
+            "MODEL_STATE": self.model.module.state_dict(),
+            "EPOCHS_RUN": epoch,
+        }
+        torch.save(snapshot, self.snapshot_path)
+        print(f"Epoch {epoch} | Training snapshot saved at {self.snapshot_path}")
+
+    def train(self, max_epochs: int):
+        if self.use_mlflow and self.rank == 0:
+            print(f"Setting tracking URI to {self.tracking_uri}")
+            # Set tracking URI first
+            if self.tracking_uri:
+                mlflow.set_tracking_uri(self.tracking_uri)
+
+            # Create or get experiment
+            experiment = mlflow.get_experiment_by_name("mnist_ddp")
+            if experiment is None:
+                experiment_id = mlflow.create_experiment("mnist_ddp")
+            else:
+                experiment_id = experiment.experiment_id
+
+            # Set the experiment
+            mlflow.set_experiment(experiment_id=experiment_id)
+
+            # Keep the run open for the whole training loop so the per-epoch
+            # metrics and checkpoints land in the same run as the parameters
+            mlflow.start_run()
+            mlflow.log_params({
+                "model": "MLP",
+                "optimizer": "Adam",
+                "learning_rate": self.optimizer.param_groups[0]['lr'],
+                "batch_size": len(next(iter(self.train_data))[0]),
+                "epochs": max_epochs,
+                "device": str(self.device)
+            })
+            mlflow.pytorch.log_model(self.model.module, "model")
+        else:
+            print("MLflow is disabled")
+
+        for epoch in range(self.epochs_run, max_epochs):
+            avg_loss = self._run_epoch(epoch)
+            if epoch % self.save_every == 0:
+                self._save_snapshot(epoch)
+                if self.use_mlflow and self.rank == 0:
+                    mlflow.pytorch.log_model(self.model.module, f"model_epoch_{epoch}")
+
+        if self.use_mlflow and self.rank == 0:
+            mlflow.end_run()
+
+def load_train_objs():
+    # Define data transforms
+    transform = transforms.Compose([
+        transforms.ToTensor(),
+        transforms.Normalize((0.1307,), (0.3081,))
+    ])
+
+    # Load MNIST dataset
+    train_set = datasets.MNIST(
+        root='./data',
+        train=True,
+        download=True,
+        transform=transform
+    )
+
+    # Create model and optimizer
+    model = MLP()
+    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+    return train_set, model, optimizer
+
+def prepare_dataloader(dataset: datasets.MNIST, batch_size: int):
+    return DataLoader(
+        dataset,
+        batch_size=batch_size,
+        pin_memory=True,
+        shuffle=False,
+        sampler=DistributedSampler(dataset)
+    )
+
+def main(save_every: int, total_epochs: int, batch_size: int, snapshot_path: str, use_mlflow: bool = False, tracking_uri: str = None):
+    ddp_setup()
+    dataset, model, optimizer = load_train_objs()
+    train_data = prepare_dataloader(dataset, batch_size)
+    trainer = Trainer(model, train_data, optimizer, save_every, snapshot_path, use_mlflow, tracking_uri)
+    trainer.train(total_epochs)
+    destroy_process_group()
+
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description='simple distributed training job')
+    parser.add_argument('--total_epochs', type=int, required=True, help='Total epochs to train the model')
+    parser.add_argument('--save_every', type=int, required=True, help='How often to save a snapshot')
+    parser.add_argument('--batch_size', default=32, type=int, help='Input batch size on each device (default: 32)')
+    parser.add_argument('--checkpoint_path', default="./snapshot.pt", type=str, help='Full path to checkpoint file')
+    parser.add_argument('--use_mlflow', action='store_true', help='Enable MLflow logging')
+    parser.add_argument('--tracking_uri', type=str, help='MLflow tracking URI', default=None)
+    args = parser.parse_args()
+    main(args.save_every, args.total_epochs, args.batch_size, args.checkpoint_path, args.use_mlflow, args.tracking_uri)
diff --git a/3.test_cases/pytorch/cpu-ddp/kubernetes/README.md b/3.test_cases/pytorch/ddp/kubernetes/README.md
similarity index 100%
rename from 3.test_cases/pytorch/cpu-ddp/kubernetes/README.md
rename to 3.test_cases/pytorch/ddp/kubernetes/README.md
diff --git a/3.test_cases/pytorch/cpu-ddp/kubernetes/fsdp.yaml-template b/3.test_cases/pytorch/ddp/kubernetes/ddp-custom-container.yaml-template
similarity index 100%
rename from 3.test_cases/pytorch/cpu-ddp/kubernetes/fsdp.yaml-template
rename to 3.test_cases/pytorch/ddp/kubernetes/ddp-custom-container.yaml-template
diff --git a/3.test_cases/pytorch/ddp/slurm/0.create-venv.sh b/3.test_cases/pytorch/ddp/slurm/0.create-venv.sh
new file mode 100644
index 000000000..a6660b361
--- /dev/null
+++ b/3.test_cases/pytorch/ddp/slurm/0.create-venv.sh
@@ -0,0 +1,25 @@
+#!/usr/bin/env bash
+set -ex
+
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: MIT-0
+
+PYTHON_V=python3
+OS_VERSION=$(cat /etc/os-release | grep VERSION_ID | awk -F '=' '{print $2}')
+OS_VERSION=${OS_VERSION//\"/}
+
+if [ "$OS_VERSION" = "20.04" ]; then
+    PYTHON_VERSION=$(python3.9 --version | awk '{print $2}' | awk -F'.' '{print $1"."$2}')
+    PYTHON_V=python3.9
+else
+    PYTHON_VERSION=$(python3 --version | awk '{print $2}' | awk -F'.' '{print $1"."$2}')
+fi
+
+sudo apt install -y python$PYTHON_VERSION-venv
+
+# Create and activate Python virtual environment
+$PYTHON_V -m venv pt
+source ./pt/bin/activate
+
+# Install required packages
+pip install torch==2.1.1 torchvision==0.16.1 numpy==1.* mlflow==2.13.2 sagemaker-mlflow==0.1.0
diff --git a/3.test_cases/pytorch/cpu-ddp/slurm/1.conda-train.sbatch b/3.test_cases/pytorch/ddp/slurm/1.venv-train.sbatch
similarity index 59%
rename from 3.test_cases/pytorch/cpu-ddp/slurm/1.conda-train.sbatch
rename to 3.test_cases/pytorch/ddp/slurm/1.venv-train.sbatch
index d1fb414da..e47000957 100644
--- a/3.test_cases/pytorch/cpu-ddp/slurm/1.conda-train.sbatch
+++ b/3.test_cases/pytorch/ddp/slurm/1.venv-train.sbatch
@@ -1,5 +1,5 @@
 #!/bin/bash
-#SBATCH --job-name=cpu-ddp-conda
+#SBATCH --job-name=ddp-venv
 #SBATCH --exclusive
 #SBATCH --wait-all-nodes=1
 #SBATCH --nodes 2
@@ -8,19 +8,27 @@
 export LOGLEVEL=INFO
 
 declare -a TORCHRUN_ARGS=(
-    --nproc_per_node=4
+    --nproc_per_node=1 # For GPU: set this to the number of GPUs per node
     --nnodes=$SLURM_JOB_NUM_NODES
    --rdzv_id=$SLURM_JOB_ID
     --rdzv_backend=c10d
     --rdzv_endpoint=$(hostname)
 )
 
+declare -a TRAIN_ARGS=(
+    --total_epochs 500
+    --save_every 1
+    --batch_size 32
+    --checkpoint_path ./snapshot.pt
+    --use_mlflow
+)
+
 AUTO_RESUME=""
 if [ -d "/opt/sagemaker_cluster" ]; then
     echo "Detected Hyperpod cluster.. enabling --auto-resume=1"
     AUTO_RESUME="--auto-resume=1"
 fi
 
-srun ${AUTO_RESUME} ./pt_cpu/bin/torchrun \
+srun ${AUTO_RESUME} ./pt/bin/torchrun \
     "${TORCHRUN_ARGS[@]}" \
-    $(dirname "$PWD")/ddp.py 5000000 10
+    $(dirname "$PWD")/ddp.py "${TRAIN_ARGS[@]}"
diff --git a/3.test_cases/pytorch/cpu-ddp/slurm/2.create-enroot-image.sh b/3.test_cases/pytorch/ddp/slurm/2.create-enroot-image.sh
similarity index 100%
rename from 3.test_cases/pytorch/cpu-ddp/slurm/2.create-enroot-image.sh
rename to 3.test_cases/pytorch/ddp/slurm/2.create-enroot-image.sh
diff --git a/3.test_cases/pytorch/cpu-ddp/slurm/3.container-train.sbatch b/3.test_cases/pytorch/ddp/slurm/3.container-train.sbatch
similarity index 71%
rename from 3.test_cases/pytorch/cpu-ddp/slurm/3.container-train.sbatch
rename to 3.test_cases/pytorch/ddp/slurm/3.container-train.sbatch
index 29206372d..9070fa29d 100644
--- a/3.test_cases/pytorch/cpu-ddp/slurm/3.container-train.sbatch
+++ b/3.test_cases/pytorch/ddp/slurm/3.container-train.sbatch
@@ -3,7 +3,7 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: MIT-0
 
-#SBATCH --job-name=cpu-ddp-container
+#SBATCH --job-name=ddp-container
 #SBATCH --exclusive
 #SBATCH --wait-all-nodes=1
 #SBATCH --nodes 2
@@ -18,11 +18,20 @@ declare -a ARGS=(
 )
 
 declare -a TORCHRUN_ARGS=(
-    --nproc_per_node=4
+    --nproc_per_node=4 # For GPU: set this to the number of GPUs per node
     --nnodes=$SLURM_JOB_NUM_NODES
     --rdzv_id=$SLURM_JOB_ID
     --rdzv_backend=c10d
     --rdzv_endpoint=$(hostname)
+)
+
+
+declare -a TRAIN_ARGS=(
+    --total_epochs 500
+    --save_every 1
+    --batch_size 32
+    --checkpoint_path ./snapshot.pt
+    --use_mlflow
 )
 
 AUTO_RESUME=""
@@ -33,4 +42,4 @@ fi
 
 srun ${AUTO_RESUME} -l "${ARGS[@]}" torchrun \
     "${TORCHRUN_ARGS[@]}" \
-    $(dirname "$PWD")/ddp.py 5000000 10
+    $(dirname "$PWD")/ddp.py "${TRAIN_ARGS[@]}"
diff --git a/3.test_cases/pytorch/cpu-ddp/slurm/README.md b/3.test_cases/pytorch/ddp/slurm/README.md
similarity index 99%
rename from 3.test_cases/pytorch/cpu-ddp/slurm/README.md
rename to 3.test_cases/pytorch/ddp/slurm/README.md
index 1efa98d17..83ae814a4 100644
--- a/3.test_cases/pytorch/cpu-ddp/slurm/README.md
+++ b/3.test_cases/pytorch/ddp/slurm/README.md
@@ -19,7 +19,7 @@ using a container.
 bash 0.create-conda-env.sh
 ```
 
-It will prepare `miniconda3` and `pt_cpu` `pt_cpu` includes `torchrun`
+It will prepare the `pt` environment, which includes `torchrun`.
 
 Submit DDP training job with: