Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DDP support for training loop #110

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cloud/google/workers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@ def GenerateWorkerStartupScript(context, hostname_nfs_server, env_variables, cmd
startup_script = f'''
#!/bin/bash
set -e

echo "net.ipv6.conf.all.disable_ipv6 = 1" >> /etc/sysctl.conf
echo "net.ipv6.conf.default.disable_ipv6 = 1" >> /etc/sysctl.conf
echo "net.ipv6.conf.lo.disable_ipv6 = 1" >> /etc/sysctl.conf

sysctl -p

mount -t tmpfs -o size=80%,noatime tmpfs /tmp
mkdir -p /var/log/airflow/logs
chmod 777 /var/log/airflow/logs
Expand Down
14 changes: 14 additions & 0 deletions common/docker_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ def health_check_info(image_name):
return False


def has_custom_entrypoint(image_name):
image = pull_image(image_name)
try:
entrypoint = image.attrs.get("Config", {}).get("Entrypoint", None)

if entrypoint:
return True
else:
return False
except Exception as e:
print(f"Error: {e}")
return True


def pull_image(image_name):
import docker
import traceback
Expand Down
56 changes: 47 additions & 9 deletions dags/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@

import json
import os
import uuid
from datetime import datetime

from airflow import DAG
from airflow.utils.weight_rule import WeightRule
from airflow.operators.python import PythonOperator
from airflow.models import Variable, BaseOperator as Operator
from airflow.hooks.base_hook import BaseHook

from worker_op import worker_op
from helper_ops import scale_up_cluster_op, scale_down_cluster_op, collect_metrics_op
Expand All @@ -19,8 +21,22 @@

PARAM = Variable.get("training_param", {}, deserialize_json=True)
DEEPEM_IMAGE = PARAM.get("deepem_image", "zettaai/deepem")
SKIP_EXPORT = PARAM.pop("skip_export", False)
cluster_info = json.loads(BaseHook.get_connection("InstanceGroups").extra)
training_cluster = "deepem-gpu"

max_trainers = sum(c['max_size'] for c in cluster_info[training_cluster])

if "rdzv_id" not in PARAM:
PARAM["rdzv_id"] = str(uuid.uuid4())
Variable.set("training_param", PARAM, serialize_json=True)

if "gpu_ids" not in PARAM:
num_gpus = cluster_info[training_cluster][0]['gpuWorkerAcceleratorCount']
PARAM["gpu_ids"] = list(range(num_gpus))
Variable.set("training_param", PARAM, serialize_json=True)


SKIP_EXPORT = PARAM.pop("skip_export", False)

default_args = dict(
owner="seuronbot",
Expand All @@ -31,6 +47,13 @@
)


def reset_rdzv_id(context):
from airflow.models import Variable
param = Variable.get("training_param", {}, deserialize_json=True)
param["rdzv_id"] = str(uuid.uuid4())
Variable.set("training_param", param, serialize_json=True)


def prep_parameters() -> dict:
"""Modify the user-supplied parameters to be used as a command for DeepEM."""
param = PARAM.copy()
Expand Down Expand Up @@ -59,7 +82,16 @@ def prep_parameters() -> dict:
return param


def make_argstr(param: dict) -> str:
def make_argstr(param: dict, num_trainers: int, rank: int, rdzv_id: str) -> str:

torchrun_launcher = param.pop("TORCHRUN_LAUNCHER", None)
if torchrun_launcher:
launch_command = ["torchrun", f"--nproc_per_node={len(param['gpu_ids'])}",
f"--nnodes={num_trainers}", f"--node_rank={rank}", f"--rdzv_id={rdzv_id}",
"--rdzv_backend=etcd-v2", f"--rdzv_endpoint={os.environ['REDIS_SERVER']}:2379",
"/DeepEM/deepem/train/run.py"]
else:
launch_command = []

def format_arg(item) -> str:
k, v = item
Expand All @@ -72,15 +104,17 @@ def format_arg(item) -> str:
else:
return f"--{k} {v}"

return " ".join(map(format_arg, param.items()))
return " ".join(launch_command + list(map(format_arg, param.items())))


def training_op(dag: DAG, queue="deepem-gpu") -> Operator:
def training_op(dag: DAG, rank=0, queue=training_cluster) -> Operator:
param = prep_parameters()

wandb_api_key = param.pop("WANDB_API_KEY", None)
environment = {"WANDB_API_KEY": wandb_api_key} if wandb_api_key else None

num_trainers = min(param.pop("NUM_TRAINERS", 1), max_trainers)
rdzv_id = param.pop("rdzv_id", None)
# these variables will be mounted in the containers
mount_secrets = param.pop("MOUNT_SECRETS", [])
variables = []
Expand All @@ -90,11 +124,12 @@ def training_op(dag: DAG, queue="deepem-gpu") -> Operator:
return worker_op(
variables=variables,
mount_point=param.pop("MOUNT_PATH", default_mount_path),
task_id="training",
command=make_argstr(param),
task_id=f"training_{rank}",
command=make_argstr(param, num_trainers, rank, rdzv_id),
use_gpus=True,
environment=environment,
force_pull=True,
on_retry_callback=reset_rdzv_id if rank == 0 else None,
on_failure_callback=task_failure_alert,
on_success_callback=task_done_alert,
image=DEEPEM_IMAGE,
Expand All @@ -104,6 +139,7 @@ def training_op(dag: DAG, queue="deepem-gpu") -> Operator:
dag=dag,
qos=False,
shm_size=4 * (2 ** 30), # 4 GB
network_mode="host",
)


Expand Down Expand Up @@ -155,11 +191,13 @@ def report_model() -> None:
)

collect_metrics = collect_metrics_op(training_dag)
scale_up = scale_up_cluster_op(training_dag, "training", "deepem-gpu", 1, 1, "cluster")
num_trainers = min(PARAM.get("NUM_TRAINERS", 1), max_trainers)

scale_up = scale_up_cluster_op(training_dag, "training", training_cluster, num_trainers, num_trainers, "cluster")
scale_down = scale_down_cluster_op(
training_dag, "training", "deepem-gpu", 0, "cluster", trigger_rule="all_done"
training_dag, "training", training_cluster, 0, "cluster", trigger_rule="all_done"
)
training = training_op(training_dag)
training = [training_op(training_dag, i) for i in range(num_trainers)]
report_training = PythonOperator(
task_id="report_model",
python_callable=report_model,
Expand Down
1 change: 1 addition & 0 deletions dags/worker_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,5 @@ def worker_op(**kwargs):
retry_delay=kwargs.get("retry_delay", default_args.get("retry_delay", 60)),
retry_exponential_backoff=kwargs.get("retry_exponential_backoff", default_args.get("retry_exponential_backoff", False)),
shm_size=kwargs.get("shm_size", None),
network_mode=kwargs.get("network_mode", None),
)
12 changes: 12 additions & 0 deletions deploy/docker-compose-CeleryExecutor.yml
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,18 @@ services:
<<: *airflow-common-env
command: python utils/memory_monitor.py amqp://rabbitmq worker-client-queue

etcd:
image: quay.io/coreos/etcd:v3.5.17
ports:
- "2379:2379/tcp"
command: >
/usr/local/bin/etcd
--data-dir /var/lib/etcd
--enable-v2
--listen-client-urls http://0.0.0.0:2379
--advertise-client-urls http://0.0.0.0:2379
--initial-cluster-state new

proxy:
image: nginx:1.23.0-alpine
environment:
Expand Down
2 changes: 2 additions & 0 deletions pipeline/init_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ def parse_metadata():
worker_setting['workerConcurrencies'] = c['workerConcurrencies']
else:
worker_setting['concurrency'] = c.get('concurrency', 1)
if c['type'] == 'deepem-gpu':
worker_setting['gpuWorkerAcceleratorCount'] = c.get('gpuWorkerAcceleratorCount', 1)
instance_groups[c['type']].append(worker_setting)
elif item["key"] == "easyseg-worker":
worker = json.loads(item["value"])
Expand Down
11 changes: 11 additions & 0 deletions slackbot/training_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from airflow_api import run_dag
from bot_utils import replyto, download_json
from airflow_api import get_variable, set_variable
from common import docker_helper


@SeuronBot.on_message("update training parameters",
Expand All @@ -25,6 +26,16 @@ def update_training_parameters(msg: dict) -> None:
except Exception as e:
replyto(msg, f"Error parsing parameters: {e}")

replyto(msg, "Download deepem image and check for custom entrypoint")
deepem_image = json_obj.get("deepem_image", "zettaai/deepem")
if docker_helper.has_custom_entrypoint(deepem_image):
replyto(msg, ":disappointed:Custom entrypoint found, disable DDP")
json_obj["TORCHRUN_LAUNCHER"] = False
json_obj["NUM_TRAINERS"] = 1
else:
replyto(msg, ":cool:Launch training script with torchrun")
json_obj["TORCHRUN_LAUNCHER"] = True

set_variable("training_param", json_obj, serialize_json=True)
replyto(msg, "Parameters successfully updated")

Expand Down