
Commit 008bd7a

Author: daihao
Message: single controller: add train controller
Parent: 6138e3a

File tree: 6 files changed (+340, -25 lines)

areal/api/controller_api.py

Lines changed: 4 additions & 4 deletions
@@ -315,7 +315,7 @@ def set_version(self, version: int):
         """
         raise NotImplementedError()
 
-    def get_version(self) -> int:
+    def get_version(self) -> List[int]:
         """Get the current weight version in the training engine.
 
         Returns
@@ -359,7 +359,7 @@ def train_batch(
         input_: DistributedBatch,
         loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
         loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
-    ) -> Dict[str, float]:
+    ) -> List[Dict[str, float]]:
         """Update the model with a batch of data and a loss function.
 
         Note
@@ -382,7 +382,7 @@ def train_batch(
 
         Returns
        -------
-        Dict[str, float]
+        List[Dict[str, float]]
            Scalar statistics after training, e.g., the current learning rate,
            gradient norm, etc.
        """
@@ -394,7 +394,7 @@ def eval_batch(
         input_: DistributedBatch,
         loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
         loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
-    ) -> torch.Tensor | None:
+    ) -> List[torch.Tensor]:
         """Evaluate the model using the forward pass and loss function.
 
         Note
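
A minimal sketch of how a caller might consume the new list-valued return types. The controller instance, batch, loss callables, and the "grad_norm" key are illustrative assumptions, not part of this commit:

from typing import Dict, List

from areal.api.controller_api import DistributedBatch, TrainController


def mean_grad_norm(
    controller: TrainController,
    batch: DistributedBatch,
    loss_fn,
    loss_weight_fn,
) -> float:
    # train_batch now returns one stats dict per engine replica instead of a
    # single dict, so results must be aggregated across the list.
    stats: List[Dict[str, float]] = controller.train_batch(batch, loss_fn, loss_weight_fn)
    grad_norms = [s["grad_norm"] for s in stats if "grad_norm" in s]  # key name is assumed
    return sum(grad_norms) / max(len(grad_norms), 1)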

areal/api/engine_api.py

Lines changed: 3 additions & 1 deletion
@@ -25,6 +25,8 @@ class Scheduling:
     cpu: int
     gpu: int
     mem: int
+    port_count: int
+    cmd: str | None = None
     nodelist: str | None = None
     exclude: str | None = None
     partition: str | None = None
@@ -138,7 +140,7 @@ def parallelism_group(self) -> dist.ProcessGroup:
         """
         raise NotImplementedError()
 
-    def get_scheduling_config(self) -> Scheduling:
+    def get_scheduling_config(self) -> List[Scheduling]:
         """Get the scheduling configuration for the engine.
 
         This includes configuration such as container image, CPU/GPU/memory size.
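
For illustration only (the engine instance and the aggregation are assumptions), callers now receive one Scheduling entry per task rather than a single spec:

from areal.api.engine_api import Scheduling, TrainEngine


def total_resources(engine: TrainEngine) -> tuple[int, int]:
    # get_scheduling_config now returns List[Scheduling]; each entry also
    # carries the new port_count field and an optional cmd.
    specs: list[Scheduling] = engine.get_scheduling_config()
    gpus = sum(s.gpu for s in specs)
    ports = sum(s.port_count for s in specs)
    return gpus, ports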

areal/api/scheduler_api.py

Lines changed: 12 additions & 19 deletions
@@ -1,47 +1,40 @@
 import abc
 from dataclasses import dataclass, field
-from typing import Dict, List
+from typing import List, Literal
+
+from areal.api.engine_api import Scheduling
 
 
 @dataclass
 class Worker:
     id: str
     ip: str
-    ports: List[str] = field(default_factory=list)
-
-
-@dataclass
-class ContainerSpec:
-    cpu: int = 0
-    gpu: int = 0
-    mem: int = 0
-    container_image: str = ""
-    cmd: str = ""
-    env_vars: Dict[str, str] = field(default_factory=dict)
-    port_count: int = 2
+    serve_port: str
+    extra_ports: List[str] = field(default_factory=list)
 
 
 @dataclass
 class ScheduleStrategy:
-    type: str = ""
+    type: Literal["colocation", "separation", ""] = ""
     uid: str = ""
 
 
 @dataclass
-class SchedulingConfig:
+class Job:
     replicas: int = 0
-    specs: List[ContainerSpec] = field(default_factory=list)
+    tasks: List[Scheduling] = field(default_factory=list)
     schedule_strategy: ScheduleStrategy | None = None
     role: str = ""
 
 
 class Scheduler(abc.ABC):
-    def create_workers(self, worker_key, scheduler_config, *args, **kwargs) -> str:
+    def create_workers(self, job: Job, *args, **kwargs):
         """
-        Start workers, return job id
+        Start workers
         """
+        raise NotImplementedError()
 
-    def get_workers(self, worker_key, timeout=None) -> List[Worker]:
+    def get_workers(self, role: str, timeout=None) -> List[Worker]:
         """
         Wait and return worker list, including scheduling results such as ip and engine ports
         (worker id, ip, ports)
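
A hedged sketch of the Job-based flow that replaces SchedulingConfig/ContainerSpec. The concrete scheduler and engine objects, replica count, and strategy values are illustrative assumptions:

from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy


def launch_train_workers(scheduler: Scheduler, engine) -> None:
    # A Job now bundles the replica count, per-task Scheduling specs,
    # schedule strategy, and role.
    job = Job(
        replicas=4,
        tasks=engine.get_scheduling_config(),  # List[Scheduling] from the engine
        schedule_strategy=ScheduleStrategy(type="separation", uid="demo"),
        role="train",
    )
    scheduler.create_workers(job)

    # get_workers is keyed by role and returns Worker objects exposing
    # serve_port/extra_ports instead of a flat ports list.
    for worker in scheduler.get_workers("train", timeout=1800):
        print(worker.id, worker.ip, worker.serve_port, worker.extra_ports)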
Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from typing import Any, Callable, Dict, List
+
+import torch
+
+from areal.api.alloc_mode import ParallelStrategy
+from areal.api.cli_args import TrainEngineConfig
+from areal.api.controller_api import DistributedBatch, TrainController
+from areal.api.engine_api import TrainEngine
+from areal.api.io_struct import (
+    AllocationMode,
+    FinetuneSpec,
+    ParamSpec,
+    SaveLoadMeta,
+    WeightUpdateMeta,
+)
+from areal.api.scheduler_api import Job, Scheduler, ScheduleStrategy, Worker
+from areal.controller.utils import create_engine_with_retry, rpc_call
+from areal.utils import logging
+from areal.utils.http import wait_future_ordered
+
+logger = logging.getLogger("DistributedTrainController")
+
+
+class DistributedTrainController(TrainController):
+    def __init__(
+        self, train_engine: TrainEngine, config: TrainEngineConfig, scheduler: Scheduler
+    ):
+        super().__init__(train_engine, config, scheduler)
+
+        self.role: str = "train"
+        self.group_size: int
+        self.alloc_mode: AllocationMode
+        self.workers: List[Worker]
+        self.engine_dp_ranks: List[int]
+
+    def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
+        assert self.workers is not None, "Workers are not created"
+        self.custom_function_call("create_process_group", parallel_strategy)
+
+    def initialize(
+        self,
+        alloc_mode_str: str,
+        ft_spec: FinetuneSpec,
+        schedule_strategy: ScheduleStrategy,
+        group_size: int = 1,
+    ):
+        """Initialize environments for distributed training and load models."""
+        self.alloc_mode = AllocationMode.from_str(alloc_mode_str)
+        self.ft_spec = ft_spec
+        self.group_size = group_size
+
+        job = Job(
+            replicas=self.alloc_mode.train.world_size,
+            tasks=self.train_engine.get_scheduling_config(),
+            schedule_strategy=schedule_strategy,
+            role=self.role,
+        )
+        logger.info(f"Start to create job: {job}")
+        self.scheduler.create_workers(job)
+        # after get workers, all rpc server is ready
+        self.workers = self.scheduler.get_workers(self.role, timeout=1800)
+
+        logger.info(f"Start to create process group")
+        self.create_process_group(self.alloc_mode.train)
+
+        logger.info(f"Start to initialize engine")
+        with ThreadPoolExecutor(max_workers=len(self.workers)) as executor:
+            futures = [
+                executor.submit(
+                    partial(
+                        create_engine_with_retry,
+                        self.scheduler.create_engine,
+                        worker.id,
+                        self.train_engine,
+                        None,
+                        self.ft_spec,
+                    )
+                )
+                for worker in self.workers
+            ]
+
+            wait_future_ordered(futures, exit_on_exception=True)
+
+        logger.info(f"Start to get rank info from engine")
+        self.engine_dp_ranks = rpc_call(
+            self.scheduler, self.workers, "data_parallel_rank"
+        )
+        logger.info(f"Initialize train engines succeeded!")
+
+    def destroy(self):
+        self.scheduler.delete_workers()
+
+    def train(self, mode: bool = True):
+        self.custom_function_call("train", mode)
+
+    def upload_weights(self, meta: WeightUpdateMeta):
+        self.custom_function_call("upload_weights", meta)
+
+    def get_param_specs(
+        self, weight_chunked_mem_mb: int = 1024
+    ) -> List[List[ParamSpec]]:
+        ret: List[List[List[ParamSpec]]] = self.custom_function_call(
+            "get_param_specs", weight_chunked_mem_mb
+        )
+        flattened = [inner for outer in ret for inner in outer]
+        return flattened
+
+    def set_version(self, version: int):
+        return self.custom_function_call("set_version", version)
+
+    def get_version(self) -> List[int]:
+        return self.custom_function_call("get_version")
+
+    def save(self, meta: SaveLoadMeta):
+        self.custom_function_call("save", meta)
+
+    def load(self, meta: SaveLoadMeta):
+        self.custom_function_call("load", meta)
+
+    def step_lr_scheduler(self):
+        self.custom_function_call("step_lr_scheduler")
+
+    def custom_function_call(self, method: str, *args, **kwargs):
+        return rpc_call(self.scheduler, self.workers, method, None, args, kwargs)
+
+    def _align_batches_with_dp(
+        self, input_: DistributedBatch, rebalance=True
+    ) -> List[DistributedBatch]:
+        if rebalance:
+            inputs = input_.chunk_by_ffd(self.group_size, self.alloc_mode.train.dp_size)
+        else:
+            inputs = input_.chunk(self.alloc_mode.train.dp_size)
+
+        batches = []
+        for dp_rank in self.engine_dp_ranks:
+            batches.append(inputs[dp_rank])
+
+        return batches
+
+    def train_batch(
+        self,
+        input_: DistributedBatch,
+        loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
+        loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
+    ) -> List[Dict[str, float]]:
+
+        batches = self._align_batches_with_dp(input_, True)
+        train_stats = rpc_call(
+            self.scheduler,
+            self.workers,
+            "train_batch",
+            batches,
+            loss_fn,
+            loss_weight_fn,
+        )
+
+        return train_stats
+
+    def eval_batch(
+        self,
+        input_: DistributedBatch,
+        loss_fn: Callable[[torch.Tensor, Dict[str, Any]], torch.Tensor],
+        loss_weight_fn: Callable[[Dict[str, Any]], torch.Tensor],
+    ) -> List[torch.Tensor]:
+
+        batches = self._align_batches_with_dp(input_, True)
+        eval_stats = rpc_call(
+            self.scheduler, self.workers, "eval_batch", batches, loss_fn, loss_weight_fn
+        )
+
+        return eval_stats
+
+    def forward(
+        self,
+        input_: DistributedBatch,
+        output_seqlens: List[int] | None = None,
+        post_hook: Callable[[torch.Tensor, Dict[str, Any]], Any] | None = None,
+        aggregate_fn: Callable[[List[Any]], Any] = torch.cat,
+    ) -> List[Any]:
+        batches = self._align_batches_with_dp(input_, False)
+        forward_stats = rpc_call(
+            self.scheduler,
+            self.workers,
+            "forward",
+            batches,
+            output_seqlens,
+            post_hook,
+            aggregate_fn,
+        )
+
+        return forward_stats
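
An end-to-end usage sketch of the new controller. It assumes a DistributedTrainController has already been constructed as DistributedTrainController(engine, config, scheduler); the ft_spec, batch, loss callables, and the allocation-mode string are placeholders, not values taken from this commit:

from areal.api.io_struct import FinetuneSpec
from areal.api.scheduler_api import ScheduleStrategy


def run_one_step(controller, ft_spec: FinetuneSpec, batch, loss_fn, loss_weight_fn):
    # initialize creates the workers via the scheduler, builds the process
    # group, and brings up one engine per worker before training starts.
    controller.initialize(
        alloc_mode_str="d4t2p1",  # illustrative; format is defined by AllocationMode.from_str
        ft_spec=ft_spec,
        schedule_strategy=ScheduleStrategy(type="colocation", uid="demo"),
        group_size=8,
    )

    # train_batch rebalances the DistributedBatch across data-parallel ranks
    # (_align_batches_with_dp) and returns one stats dict per worker.
    stats = controller.train_batch(batch, loss_fn, loss_weight_fn)

    controller.step_lr_scheduler()
    controller.destroy()
    return stats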
