Skip to content

Commit e495252

Browse files
FFFrogpytorchmergebot
authored andcommitted
Make TraceUtils.h to be device-agnostic (pytorch#126969)
Some features of third-party devices depend on TraceUtils.h, so some of the CUDA code was removed and split into NCCLUtils files. In addition, some common functions still remain in TraceUtils.h since I'm not sure if other devices will use them later. Pull Request resolved: pytorch#126969 Approved by: https://github.com/c-p-i-o
1 parent 7fac03a commit e495252

File tree

3 files changed

+494
-498
lines changed

3 files changed

+494
-498
lines changed

torch/csrc/distributed/c10d/NCCLUtils.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,56 @@ control_plane::RegisterHandler dumpHandler{
289289
"application/octet-stream");
290290
}};
291291

292+
void DebugInfoWriter::write(const std::string& ncclTrace) {
293+
// Open a file for writing. The ios::binary flag is used to write data as
294+
// binary.
295+
std::ofstream file(filename_, std::ios::binary);
296+
297+
// Check if the file was opened successfully.
298+
if (!file.is_open()) {
299+
LOG(ERROR) << "Error opening file for writing NCCLPG debug info: "
300+
<< filename_;
301+
return;
302+
}
303+
304+
file.write(ncclTrace.data(), ncclTrace.size());
305+
LOG(INFO) << "Finished writing NCCLPG debug info to " << filename_;
306+
}
307+
308+
DebugInfoWriter& DebugInfoWriter::getWriter(int rank) {
309+
if (writer_ == nullptr) {
310+
std::string fileNamePrefix = getCvarString(
311+
{"TORCH_NCCL_DEBUG_INFO_TEMP_FILE"}, "/tmp/nccl_trace_rank_");
312+
// Using std::unique_ptr here to auto-delete the writer object
313+
// when the pointer itself is destroyed.
314+
std::unique_ptr<DebugInfoWriter> writerPtr(
315+
new DebugInfoWriter(fileNamePrefix, rank));
316+
DebugInfoWriter::registerWriter(std::move(writerPtr));
317+
}
318+
return *writer_;
319+
}
320+
321+
void DebugInfoWriter::registerWriter(std::unique_ptr<DebugInfoWriter> writer) {
322+
TORCH_CHECK_WITH(
323+
DistBackendError,
324+
hasWriterRegistered_.load() == false,
325+
"debugInfoWriter already registered");
326+
hasWriterRegistered_.store(true);
327+
writer_ = std::move(writer);
328+
}
329+
330+
std::unique_ptr<DebugInfoWriter> DebugInfoWriter::writer_ = nullptr;
331+
std::atomic<bool> DebugInfoWriter::hasWriterRegistered_(false);
332+
333+
float getDurationFromEvent(
334+
at::cuda::CUDAEvent& ncclStartEvent,
335+
at::cuda::CUDAEvent& ncclEndEvent) {
336+
TORCH_CHECK(
337+
ncclEndEvent.query(),
338+
"getDuration can only be called after work is succeeded.")
339+
return ncclStartEvent.elapsed_time(ncclEndEvent);
340+
}
341+
292342
} // namespace c10d
293343

294344
#endif // USE_C10D_NCCL

0 commit comments

Comments
 (0)