
Commit 18b2aec

refactor:
- Add helper functions for ProcessGroup management
- Update main.cc to support ProcessGroup-based training workflow
- Move NCCL communicator from Device to ProcessGroup; maintain rank information in Device
- Update communication operators to call ProcessGroup member functions; pg parameter is optional (uses default group if null) — see the sketch below
- Introduce a generic Rank class for rank abstraction
- Add global class to manage and access environment variables
1 parent 5ada1d7 commit 18b2aec
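To make the "pg parameter is optional (uses default group if null)" point concrete, here is a minimal C++ sketch of that dispatch pattern. It is not code from this commit: the Tensor placeholder, GetDefault(), and the ProcessGroup internals are hypothetical stand-ins; only the calling shape (collectives as ProcessGroup member functions, callers passing an explicit group or nullptr) mirrors what the diffs below do with ProcessGroupFactory and function::AllReduce.

// Illustrative sketch only (not infini_train code): a communication op with an
// optional ProcessGroup that falls back to a default group when pg == nullptr.
#include <utility>
#include <vector>

struct Tensor {}; // placeholder for the framework's tensor type

enum class ReduceOpType { kSum, kAvg };

class ProcessGroup {
public:
    explicit ProcessGroup(std::vector<int> ranks) : ranks_(std::move(ranks)) {}
    // In the described refactor the NCCL communicator lives here and collectives
    // are member functions; this stub only records the member ranks.
    void AllReduce(Tensor & /*t*/, ReduceOpType /*op*/) const {}

private:
    std::vector<int> ranks_;
};

class ProcessGroupFactory {
public:
    static ProcessGroupFactory *Instance() {
        static ProcessGroupFactory factory;
        return &factory;
    }
    // Hypothetical accessor for the default (world) group.
    const ProcessGroup *GetDefault() const { return &default_group_; }

private:
    ProcessGroup default_group_{std::vector<int>{0}};
};

// pg is optional: nullptr means "use the default group".
void AllReduce(Tensor &t, ReduceOpType op, const ProcessGroup *pg = nullptr) {
    if (pg == nullptr) {
        pg = ProcessGroupFactory::Instance()->GetDefault();
    }
    pg->AllReduce(t, op); // collective dispatched through the group
}

int main() {
    Tensor loss;
    AllReduce(loss, ReduceOpType::kAvg);          // falls back to the default group
    ProcessGroup ddp_pg({0, 2, 4, 6});            // explicit data-parallel group
    AllReduce(loss, ReduceOpType::kAvg, &ddp_pg); // targets that group's communicator
}

The appeal of this shape is that single-group callers never have to name a group, while DDP or TP code paths can target their own communicators explicitly.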

File tree

23 files changed: +637 -543 lines changed


example/common/utils.cc

Lines changed: 41 additions & 1 deletion
@@ -1,7 +1,9 @@
+#include "utils.h"
+
 #include <cstdint>
 #include <cstring>
 
-#include "utils.h"
+#include "infini_train/include/nn/parallel/global.h"
 
 namespace infini_train {
 
@@ -12,4 +14,42 @@ float ConvertBF16ToFloat(void *ptr) {
     std::memcpy(&f, &f32_bits, sizeof(f));
     return f;
 }
+
+std::vector<int> GetDataParallelGroupRanks(int rank) {
+    std::vector<int> ranks;
+
+    int world_size = nn::parallel::global::GetWorldSize();
+    int tp_size = nn::parallel::global::GetTensorParallelSize();
+    int dp_size = nn::parallel::global::GetDataParallelSize();
+
+    ranks.reserve(dp_size);
+    int dp_group_id = rank % tp_size;
+
+    for (int r = 0; r < world_size; ++r) {
+        if (r % tp_size == dp_group_id) {
+            ranks.push_back(r);
+        }
+    }
+
+    return ranks;
+}
+
+std::vector<int> GetTensorParallelGroupRanks(int rank) {
+    std::vector<int> ranks;
+
+    int world_size = nn::parallel::global::GetWorldSize();
+    int tp_size = nn::parallel::global::GetTensorParallelSize();
+
+    ranks.reserve(tp_size);
+    int tp_group_id = rank / tp_size;
+
+    for (int r = 0; r < world_size; ++r) {
+        if (r / tp_size == tp_group_id) {
+            ranks.push_back(r);
+        }
+    }
+
+    return ranks;
+}
+
 } // namespace infini_train
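A quick worked check of the grouping arithmetic added above: ranks that share rank % tp_size form a data-parallel group, and ranks that share rank / tp_size form a tensor-parallel group. The standalone sketch below mirrors the two helpers but takes world_size and tp_size as parameters instead of reading them from nn::parallel::global, so it compiles on its own; the configuration (world_size = 8, tp_size = 2, rank = 3) is illustrative, not taken from the repository.

// Standalone mirror of GetDataParallelGroupRanks / GetTensorParallelGroupRanks.
// For world_size = 8, tp_size = 2, rank = 3 it prints:
//   DP group of rank 3: 1 3 5 7
//   TP group of rank 3: 2 3
#include <iostream>
#include <vector>

std::vector<int> DataParallelGroupRanks(int rank, int world_size, int tp_size) {
    std::vector<int> ranks;
    int dp_group_id = rank % tp_size; // ranks in the same "column" across TP groups
    for (int r = 0; r < world_size; ++r) {
        if (r % tp_size == dp_group_id) {
            ranks.push_back(r);
        }
    }
    return ranks;
}

std::vector<int> TensorParallelGroupRanks(int rank, int world_size, int tp_size) {
    std::vector<int> ranks;
    int tp_group_id = rank / tp_size; // consecutive block of tp_size ranks
    for (int r = 0; r < world_size; ++r) {
        if (r / tp_size == tp_group_id) {
            ranks.push_back(r);
        }
    }
    return ranks;
}

int main() {
    const int world_size = 8, tp_size = 2, rank = 3;
    std::cout << "DP group of rank " << rank << ":";
    for (int r : DataParallelGroupRanks(rank, world_size, tp_size)) { std::cout << ' ' << r; }
    std::cout << "\nTP group of rank " << rank << ":";
    for (int r : TensorParallelGroupRanks(rank, world_size, tp_size)) { std::cout << ' ' << r; }
    std::cout << '\n';
}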

example/common/utils.h

Lines changed: 7 additions & 0 deletions
@@ -1,6 +1,13 @@
 #pragma once
 
+#include <vector>
+
 namespace infini_train {
 
 float ConvertBF16ToFloat(void *ptr);
+
+std::vector<int> GetDataParallelGroupRanks(int rank);
+
+std::vector<int> GetTensorParallelGroupRanks(int rank);
+
 } // namespace infini_train

example/gpt2/main.cc

Lines changed: 41 additions & 19 deletions
@@ -14,12 +14,16 @@
 #include "infini_train/include/nn/modules/loss.h"
 #include "infini_train/include/nn/modules/module.h"
 #include "infini_train/include/nn/parallel/distributed_data_parallel.h"
+#include "infini_train/include/nn/parallel/global.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
+#include "infini_train/include/nn/parallel/rank.h"
 #include "infini_train/include/nn/parallel/reduce_op_type.h"
 #include "infini_train/include/optimizer.h"
 #ifdef PROFILE_MODE
 #include "infini_train/include/profiler.h"
 #endif
+#include "infini_train/include/nn/parallel/utils.h"
+
 #include "example/common/tiny_shakespeare_dataset.h"
 #include "example/common/tokenizer.h"
 #include "example/common/utils.h"
@@ -49,12 +53,13 @@ DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
 // debugging
 DEFINE_bool(overfit_single_batch, true, "overfit just one batch of data");
 // memory management
-DEFINE_string(device, "cuda", "device type (cpu/cuda), useless if data_parallel=true");
+DEFINE_string(device, "cuda", "device type (cpu/cuda), useless if using parallel training mode");
 // parallel
 DEFINE_int32(
-    data_parallel, 1,
-    "Number of GPUs to use for data parallel training. "
+    nthread_per_process, 1,
+    "Number of threads to use for each process. "
     "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices.");
+DEFINE_int32(tensor_parallel, 1, "");
 // precision
 DEFINE_string(dtype, "float32", "precision used in training (float32/bfloat16)");
 
@@ -83,25 +88,42 @@ const std::unordered_map<std::string, GPT2::ModelType> kStrToModelType = {
     {"gpt2-xl", GPT2::ModelType::kGPT2XL},
 };
 
-std::string GetDataParallelFactoryName(const nn::parallel::DistributedDataParallel::Rank &rank) { return "DDP"; }
 } // namespace
 
 DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
 DEFINE_validator(device,
                  [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
 
-void Train(const nn::parallel::DistributedDataParallel::Rank &rank) {
+void Train(const nn::parallel::Rank &rank) {
+    using namespace nn::parallel;
+
     // select the device
     const Device *device;
-    if (rank.IsDDP()) {
+
+    int ddp_world_size = global::GetDataParallelSize();
+    int ddp_rank = 0;
+    const ProcessGroup *ddp_pg = nullptr;
+
+    if (rank.IsParallel()) {
         device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
+
+        if (ddp_world_size > 1) {
+            ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
+                                                                  GetDataParallelGroupRanks(rank.thread_rank()));
+            ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
+        }
+
+        if (global::GetTensorParallelSize() > 1) {
+            ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()),
+                                                         GetTensorParallelGroupRanks(rank.thread_rank()));
+        }
     } else {
        device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
                                            : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0);
    }
 
     // calculate gradient accumulation from the desired total batch size and the current run configuration
-    const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * rank.WorldSize();
+    const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size;
     CHECK_EQ(FLAGS_total_batch_size % tokens_per_fwdbwd, 0);
     const auto grad_accum_steps = FLAGS_total_batch_size / tokens_per_fwdbwd;
     LOG(INFO) << "total desired batch size: " << FLAGS_total_batch_size
@@ -140,18 +162,17 @@ void Train(const nn::parallel::DistributedDataParallel::Rank &rank) {
     // before wrapping the model with DistributedDataParallel (DDP).
     // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
     // are created during the conversion.
-    if (rank.IsDDP()) {
-        model = std::make_shared<nn::parallel::DistributedDataParallel>(
-            nn::parallel::DistributedDataParallel(model, rank.thread_rank()));
+    if (ddp_world_size > 1) {
+        model = std::make_shared<DistributedDataParallel>(DistributedDataParallel(model, rank.thread_rank()));
     }
 
     DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
-                                       FLAGS_batch_size, rank.thread_rank(), rank.WorldSize());
+                                       FLAGS_batch_size, ddp_rank, ddp_world_size);
     std::optional<DistributedDataLoader> val_loader = std::nullopt;
     if (!FLAGS_input_val_bin.empty()) {
         val_loader = DistributedDataLoader(
             std::make_shared<TinyShakespeareDataset>(FLAGS_input_val_bin, FLAGS_sequence_length), FLAGS_batch_size,
-            rank.thread_rank(), rank.WorldSize());
+            ddp_rank, ddp_world_size);
     }
 
     //
@@ -218,10 +239,8 @@ void Train(const nn::parallel::DistributedDataParallel::Rank &rank) {
         auto loss = loss_fn.Forward({logits, y})[0];
         loss = loss / grad_accum_steps;
         LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
-        if (rank.IsDDP()) {
-            auto pg = infini_train::nn::parallel::ProcessGroupFactory::Instance()->Get(
-                GetDataParallelFactoryName(rank));
-            nn::parallel::function::AllReduce(loss, nn::parallel::function::ReduceOpType::kAvg, pg);
+        if (ddp_world_size > 1) {
+            function::AllReduce(loss, function::ReduceOpType::kAvg, ddp_pg);
         }
         auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
         if (FLAGS_dtype == kDtypeFP32) {
@@ -262,11 +281,14 @@ int main(int argc, char *argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, true);
     google::InitGoogleLogging(argv[0]);
 
+    nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel);
+
     // NOTE(dcj): currently we only support single process
-    if (FLAGS_data_parallel > 1) {
+    if (FLAGS_nthread_per_process > 1) {
         std::vector<std::thread> threads;
-        for (int idx = 0; idx < FLAGS_data_parallel; ++idx) {
-            nn::parallel::DistributedDataParallel::Rank rank(0, idx, 1, FLAGS_data_parallel);
+        for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
+            nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx,
+                                    nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process);
             threads.emplace_back(Train, rank);
         }
 
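One consequence of the hunks above: tokens_per_fwdbwd is now scaled by the data-parallel world size rather than rank.WorldSize(), so adding DP replicas shrinks the number of gradient-accumulation micro-steps per optimizer step. The sketch below works through that arithmetic with made-up flag values (batch_size = 4, sequence_length = 1024, total_batch_size = 32768), not the repository's defaults, and mirrors the CHECK_EQ divisibility requirement with an assert.

// Gradient-accumulation arithmetic from Train(), with illustrative values.
#include <cassert>
#include <initializer_list>
#include <iostream>

int main() {
    const long batch_size = 4;
    const long sequence_length = 1024;
    const long total_batch_size = 32768; // desired tokens per optimizer step

    for (long ddp_world_size : {1, 2, 4}) {
        const long tokens_per_fwdbwd = batch_size * sequence_length * ddp_world_size;
        assert(total_batch_size % tokens_per_fwdbwd == 0); // mirrors the CHECK_EQ above
        const long grad_accum_steps = total_batch_size / tokens_per_fwdbwd;
        std::cout << "ddp_world_size=" << ddp_world_size
                  << " -> grad_accum_steps=" << grad_accum_steps << '\n';
        // prints 8, 4, 2 respectively
    }
}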
example/llama3/main.cc

Lines changed: 35 additions & 29 deletions
@@ -13,13 +13,15 @@
 #include "infini_train/include/nn/modules/module.h"
 #include "infini_train/include/nn/parallel/distributed_data_parallel.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
+#include "infini_train/include/nn/parallel/rank.h"
 #include "infini_train/include/nn/parallel/reduce_op_type.h"
 #include "infini_train/include/optimizer.h"
 #ifdef PROFILE_MODE
 #include "infini_train/include/profiler.h"
 #endif
 #include "infini_train/include/nn/parallel/global.h"
 #include "infini_train/include/nn/parallel/process_group.h"
+#include "infini_train/include/nn/parallel/utils.h"
 
 #include "example/common/tiny_shakespeare_dataset.h"
 #include "example/common/tokenizer.h"
@@ -50,11 +52,11 @@ DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
 // debugging
 DEFINE_bool(overfit_single_batch, true, "overfit just one batch of data");
 // memory management
-DEFINE_string(device, "cuda", "device type (cpu/cuda), useless if data_parallel=true");
+DEFINE_string(device, "cuda", "device type (cpu/cuda), useless if using parallel training mode");
 // parallel
 DEFINE_int32(
-    data_parallel, 1,
-    "Number of GPUs to use for data parallel training. "
+    nthread_per_process, 1,
+    "Number of threads to use for each process. "
     "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices.");
 DEFINE_int32(tensor_parallel, 1, "");
 // precision
@@ -69,38 +71,42 @@ constexpr char kDeviceCPU[] = "cpu";
 constexpr char kDeviceCUDA[] = "cuda";
 constexpr char kDtypeFP32[] = "float32";
 constexpr char kDtypeBF16[] = "bfloat16";
-
-bool IsTensorParallelMainRank(int tp_size, int rank) {
-    // tp size: 2, world size: 8, rank: #
-    return rank % tp_size == 0;
-}
 } // namespace
 
 DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
 DEFINE_validator(device,
                  [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
 
-void Train(const nn::parallel::DistributedDataParallel::Rank &rank) {
+void Train(const nn::parallel::Rank &rank) {
+    using namespace nn::parallel;
+
     // select the device
     const Device *device;
-    if (rank.IsDDP()) {
+
+    int ddp_world_size = global::GetDataParallelSize();
+    int ddp_rank = 0;
+    const ProcessGroup *ddp_pg = nullptr;
+
+    if (rank.IsParallel()) {
         device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
 
-        if (FLAGS_tensor_parallel > 1) {
-            // tensor parallel enabled
-            if (IsTensorParallelMainRank(FLAGS_tensor_parallel, rank.thread_rank())) {
-                infini_train::nn::parallel::ProcessGroupFactory::Instance()->Create(
-                    infini_train::nn::parallel::GetTensorParallelProcessFactoryName(rank, FLAGS_tensor_parallel),
-                    FLAGS_tensor_parallel);
-            }
+        if (ddp_world_size > 1) {
+            ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
+                                                                  GetDataParallelGroupRanks(rank.thread_rank()));
+            ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
+        }
+
+        if (global::GetTensorParallelSize() > 1) {
+            ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()),
+                                                         GetTensorParallelGroupRanks(rank.thread_rank()));
        }
     } else {
         device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
                                             : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0);
     }
 
     // calculate gradient accumulation from the desired total batch size and the current run configuration
-    const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * rank.WorldSize();
+    const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size;
     CHECK_EQ(FLAGS_total_batch_size % tokens_per_fwdbwd, 0);
     const auto grad_accum_steps = FLAGS_total_batch_size / tokens_per_fwdbwd;
     if (rank.IsMainRank()) {
@@ -138,18 +144,17 @@ void Train(const nn::parallel::DistributedDataParallel::Rank &rank) {
     // before wrapping the model with DistributedDataParallel (DDP).
     // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
     // are created during the conversion.
-    if (rank.IsDDP()) {
-        model = std::make_shared<nn::parallel::DistributedDataParallel>(
-            nn::parallel::DistributedDataParallel(model, rank.thread_rank()));
+    if (ddp_world_size > 1) {
+        model = std::make_shared<DistributedDataParallel>(DistributedDataParallel(model, rank.thread_rank()));
     }
 
     DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
-                                       FLAGS_batch_size, rank.thread_rank(), rank.WorldSize());
+                                       FLAGS_batch_size, ddp_rank, ddp_world_size);
     std::optional<DistributedDataLoader> val_loader = std::nullopt;
     if (!FLAGS_input_val_bin.empty()) {
         val_loader = DistributedDataLoader(
             std::make_shared<TinyShakespeareDataset>(FLAGS_input_val_bin, FLAGS_sequence_length), FLAGS_batch_size,
-            rank.thread_rank(), rank.WorldSize());
+            ddp_rank, ddp_world_size);
     }
 
     //
@@ -215,8 +220,8 @@ void Train(const nn::parallel::DistributedDataParallel::Rank &rank) {
         auto loss = loss_fn.Forward({logits, y})[0];
         loss = loss / grad_accum_steps;
         LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
-        if (rank.IsDDP()) {
-            nn::parallel::function::AllReduce(loss, nn::parallel::function::ReduceOpType::kAvg);
+        if (ddp_world_size > 1) {
+            function::AllReduce(loss, function::ReduceOpType::kAvg, ddp_pg);
         }
         auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
         if (FLAGS_dtype == kDtypeFP32) {
@@ -258,13 +263,14 @@ int main(int argc, char *argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, true);
     google::InitGoogleLogging(argv[0]);
 
-    infini_train::global::InitAllEnv(FLAGS_data_parallel, FLAGS_tensor_parallel);
+    nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel);
 
     // NOTE(dcj): currently we only support single process
-    if (FLAGS_data_parallel > 1) {
+    if (FLAGS_nthread_per_process > 1) {
         std::vector<std::thread> threads;
-        for (int idx = 0; idx < FLAGS_data_parallel; ++idx) {
-            nn::parallel::DistributedDataParallel::Rank rank(0, idx, 1, FLAGS_data_parallel);
+        for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
+            nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx,
+                                    nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process);
             threads.emplace_back(Train, rank);
         }
 
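The commit message mentions a generic Rank class (included above via infini_train/include/nn/parallel/rank.h), but its header is not part of the hunks shown here. The sketch below is only a guess at a shape consistent with the call sites: the four-argument constructor Rank(process_rank, thread_rank, nproc_per_node, nthread_per_process) and the thread_rank(), IsParallel(), and IsMainRank() accessors. The member names and the exact semantics of IsParallel() and IsMainRank() are assumptions, not the repository's definition.

// Speculative sketch of a Rank abstraction matching the call sites above;
// the real class in infini_train may differ.
class Rank {
public:
    Rank(int process_rank, int thread_rank, int nproc_per_node, int nthread_per_process)
        : process_rank_(process_rank), thread_rank_(thread_rank), nproc_per_node_(nproc_per_node),
          nthread_per_process_(nthread_per_process) {}

    int thread_rank() const { return thread_rank_; }
    int process_rank() const { return process_rank_; }

    // Global rank across all processes and threads (assumption: threads are
    // numbered consecutively within each process).
    int GlobalRank() const { return process_rank_ * nthread_per_process_ + thread_rank_; }

    // Parallel mode whenever more than one worker exists in total (assumption).
    bool IsParallel() const { return nproc_per_node_ * nthread_per_process_ > 1; }

    // Main rank taken to be global rank 0 (assumption).
    bool IsMainRank() const { return GlobalRank() == 0; }

private:
    int process_rank_ = 0;
    int thread_rank_ = 0;
    int nproc_per_node_ = 1;
    int nthread_per_process_ = 1;
};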
0 commit comments
