 #include "infini_train/include/nn/modules/module.h"
 #include "infini_train/include/nn/parallel/distributed_data_parallel.h"
 #include "infini_train/include/nn/parallel/parallel_functional.h"
+#include "infini_train/include/nn/parallel/rank.h"
 #include "infini_train/include/nn/parallel/reduce_op_type.h"
 #include "infini_train/include/optimizer.h"
 #ifdef PROFILE_MODE
 #include "infini_train/include/profiler.h"
 #endif
 #include "infini_train/include/nn/parallel/global.h"
 #include "infini_train/include/nn/parallel/process_group.h"
+#include "infini_train/include/nn/parallel/utils.h"
 
 #include "example/common/tiny_shakespeare_dataset.h"
 #include "example/common/tokenizer.h"
@@ -50,11 +52,11 @@ DEFINE_uint32(sample_every, 0, "how often to sample from the model?");
 // debugging
 DEFINE_bool(overfit_single_batch, true, "overfit just one batch of data");
 // memory management
-DEFINE_string(device, "cuda", "device type (cpu/cuda), useless if data_parallel=true");
+DEFINE_string(device, "cuda", "device type (cpu/cuda), useless if using parallel training mode");
 // parallel
 DEFINE_int32(
-    data_parallel, 1,
-    "Number of GPUs to use for data parallel training. "
+    nthread_per_process, 1,
+    "Number of threads to use for each process. "
     "When set > 1, enables data parallelism with device=cuda on the specified number of visible CUDA devices.");
 DEFINE_int32(tensor_parallel, 1, "");
 // precision
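A note on the renamed flag: the data-parallel degree is no longer passed in directly, so presumably global::InitAllEnv (called later in main) derives it from the per-process thread count and --tensor_parallel. The sketch below shows that assumed derivation only; the helper name and layout are illustrative, not part of the library.

```cpp
// Assumption: one device per worker thread, and the data-parallel degree is
// whatever is left of the world size after carving out tensor-parallel groups.
inline int DeriveDataParallelSize(int nproc_per_node, int nthread_per_process, int tensor_parallel) {
    const int world_size = nproc_per_node * nthread_per_process;
    return world_size / tensor_parallel;
}
```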
@@ -69,38 +71,42 @@ constexpr char kDeviceCPU[] = "cpu";
 constexpr char kDeviceCUDA[] = "cuda";
 constexpr char kDtypeFP32[] = "float32";
 constexpr char kDtypeBF16[] = "bfloat16";
-
-bool IsTensorParallelMainRank(int tp_size, int rank) {
-    // tp size: 2, world size: 8, rank: #
-    return rank % tp_size == 0;
-}
 } // namespace
 
 DEFINE_validator(model, [](const char *, const std::string &value) { return kSupportedModels.contains(value); });
 DEFINE_validator(device,
                  [](const char *, const std::string &value) { return value == kDeviceCPU || value == kDeviceCUDA; });
 
-void Train(const nn::parallel::DistributedDataParallel::Rank &rank) {
+void Train(const nn::parallel::Rank &rank) {
+    using namespace nn::parallel;
+
     // select the device
     const Device *device;
-    if (rank.IsDDP()) {
+
+    int ddp_world_size = global::GetDataParallelSize();
+    int ddp_rank = 0;
+    const ProcessGroup *ddp_pg = nullptr;
+
+    if (rank.IsParallel()) {
         device = DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, rank.thread_rank());
 
-        if (FLAGS_tensor_parallel > 1) {
-            // tensor parallel enabled
-            if (IsTensorParallelMainRank(FLAGS_tensor_parallel, rank.thread_rank())) {
-                infini_train::nn::parallel::ProcessGroupFactory::Instance()->Create(
-                    infini_train::nn::parallel::GetTensorParallelProcessFactoryName(rank, FLAGS_tensor_parallel),
-                    FLAGS_tensor_parallel);
-            }
+        if (ddp_world_size > 1) {
+            ddp_pg = ProcessGroupFactory::Instance()->GetOrCreate(GetDataParallelProcessGroupName(rank.thread_rank()),
+                                                                  GetDataParallelGroupRanks(rank.thread_rank()));
+            ddp_rank = ddp_pg->GetGroupRank(rank.thread_rank());
+        }
+
+        if (global::GetTensorParallelSize() > 1) {
+            ProcessGroupFactory::Instance()->GetOrCreate(GetTensorParallelProcessGroupName(rank.thread_rank()),
+                                                         GetTensorParallelGroupRanks(rank.thread_rank()));
         }
     } else {
         device = FLAGS_device == kDeviceCPU ? DeviceManager::Instance()->GetDefaultDevice()
                                             : DeviceManager::Instance()->GetDevice(DeviceType::kCUDA, 0);
     }
 
     // calculate gradient accumulation from the desired total batch size and the current run configuration
-    const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * rank.WorldSize();
+    const auto tokens_per_fwdbwd = FLAGS_batch_size * FLAGS_sequence_length * ddp_world_size;
     CHECK_EQ(FLAGS_total_batch_size % tokens_per_fwdbwd, 0);
     const auto grad_accum_steps = FLAGS_total_batch_size / tokens_per_fwdbwd;
     if (rank.IsMainRank()) {
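For reference, the gradient-accumulation arithmetic above works out as in this standalone sketch (the concrete values are illustrative, not defaults taken from this example):

```cpp
#include <cassert>
#include <cstdint>

int main() {
    const int64_t batch_size = 4;           // per-rank micro-batch (--batch_size)
    const int64_t sequence_length = 1024;   // tokens per sample (--sequence_length)
    const int64_t ddp_world_size = 2;       // data-parallel replicas
    const int64_t total_batch_size = 32768; // desired tokens per optimizer step (--total_batch_size)

    // Tokens processed by one forward/backward pass across all data-parallel ranks.
    const int64_t tokens_per_fwdbwd = batch_size * sequence_length * ddp_world_size; // 8192
    assert(total_batch_size % tokens_per_fwdbwd == 0);
    const int64_t grad_accum_steps = total_batch_size / tokens_per_fwdbwd;           // 4
    return grad_accum_steps == 4 ? 0 : 1;
}
```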
@@ -138,18 +144,17 @@ void Train(const nn::parallel::DistributedDataParallel::Rank &rank) {
     // before wrapping the model with DistributedDataParallel (DDP).
     // Otherwise, DDP’s gradient hooks may be lost because new parameter tensors
     // are created during the conversion.
-    if (rank.IsDDP()) {
-        model = std::make_shared<nn::parallel::DistributedDataParallel>(
-            nn::parallel::DistributedDataParallel(model, rank.thread_rank()));
+    if (ddp_world_size > 1) {
+        model = std::make_shared<DistributedDataParallel>(DistributedDataParallel(model, rank.thread_rank()));
     }
 
     DistributedDataLoader train_loader(std::make_shared<TinyShakespeareDataset>(FLAGS_input_bin, FLAGS_sequence_length),
-                                       FLAGS_batch_size, rank.thread_rank(), rank.WorldSize());
+                                       FLAGS_batch_size, ddp_rank, ddp_world_size);
     std::optional<DistributedDataLoader> val_loader = std::nullopt;
     if (!FLAGS_input_val_bin.empty()) {
         val_loader = DistributedDataLoader(
             std::make_shared<TinyShakespeareDataset>(FLAGS_input_val_bin, FLAGS_sequence_length), FLAGS_batch_size,
-            rank.thread_rank(), rank.WorldSize());
+            ddp_rank, ddp_world_size);
     }
 
     //
@@ -215,8 +220,8 @@ void Train(const nn::parallel::DistributedDataParallel::Rank &rank) {
             auto loss = loss_fn.Forward({logits, y})[0];
             loss = loss / grad_accum_steps;
             LOG(INFO) << "Rank " << rank.thread_rank() << ": finish loss forward";
-            if (rank.IsDDP()) {
-                nn::parallel::function::AllReduce(loss, nn::parallel::function::ReduceOpType::kAvg);
+            if (ddp_world_size > 1) {
+                function::AllReduce(loss, function::ReduceOpType::kAvg, ddp_pg);
             }
             auto loss_cpu = loss->To(DeviceManager::Instance()->GetDefaultDevice());
             if (FLAGS_dtype == kDtypeFP32) {
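Each micro-step loss is scaled by 1/grad_accum_steps and then averaged across the data-parallel group on its own process group; an average reduction amounts to a sum over the group followed by division by the group size. A framework-free sketch of that reduction semantics, for illustration only:

```cpp
#include <numeric>
#include <vector>

// What an average all-reduce over a scalar loss boils down to, assuming one
// value per data-parallel rank: sum across the group, divide by the group size.
double AllReduceAvg(const std::vector<double> &per_rank_loss) {
    const double sum = std::accumulate(per_rank_loss.begin(), per_rank_loss.end(), 0.0);
    return sum / static_cast<double>(per_rank_loss.size());
}
```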
@@ -258,13 +263,14 @@ int main(int argc, char *argv[]) {
     gflags::ParseCommandLineFlags(&argc, &argv, true);
     google::InitGoogleLogging(argv[0]);
 
-    infini_train::global::InitAllEnv(FLAGS_data_parallel, FLAGS_tensor_parallel);
+    nn::parallel::global::InitAllEnv(FLAGS_nthread_per_process, FLAGS_tensor_parallel);
 
     // NOTE(dcj): currently we only support single process
-    if (FLAGS_data_parallel > 1) {
+    if (FLAGS_nthread_per_process > 1) {
         std::vector<std::thread> threads;
-        for (int idx = 0; idx < FLAGS_data_parallel; ++idx) {
-            nn::parallel::DistributedDataParallel::Rank rank(0, idx, 1, FLAGS_data_parallel);
+        for (int idx = 0; idx < FLAGS_nthread_per_process; ++idx) {
+            nn::parallel::Rank rank(nn::parallel::global::GetLocalProcRank(), idx,
+                                    nn::parallel::global::GetNprocPerNode(), FLAGS_nthread_per_process);
             threads.emplace_back(Train, rank);
         }
 
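The launch pattern in main is one worker thread per local device rank inside a single process. A self-contained sketch of the same pattern follows; TrainFn stands in for Train here, and the join step falls outside the hunk shown above, so both are illustrative:

```cpp
#include <thread>
#include <vector>

void TrainFn(int thread_rank) { /* per-device training loop */ }

void LaunchPerDeviceThreads(int nthread_per_process) {
    std::vector<std::thread> threads;
    threads.reserve(nthread_per_process);
    for (int idx = 0; idx < nthread_per_process; ++idx) {
        threads.emplace_back(TrainFn, idx); // one worker thread per local device rank
    }
    for (auto &t : threads) {
        t.join(); // wait for every rank to finish before exiting
    }
}
```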