diff --git a/run_train.py b/run_train.py
index b33231f4..70305b1c 100644
--- a/run_train.py
+++ b/run_train.py
@@ -8,6 +8,7 @@
 ```
 """
 import argparse
+import subprocess
 from typing import Dict, cast
 
 import numpy as np
@@ -41,6 +42,16 @@
 logger = logging.get_logger(__name__)
 
 
+def run_comms_benchmark():
+    result = subprocess.run(
+        ["mpirun", "-np", "2", "python", "scripts/comms_benchmark.py"], capture_output=True, text=True
+    )
+    output_lines = result.stdout.strip().split("\n")
+    # Print only the first two lines of the output
+    for line in output_lines[:2]:
+        print(line)
+
+
 def get_dataloader_from_data_stage(
     trainer: DistributedTrainer,
     data: DataArgs,
@@ -229,6 +240,8 @@ def get_args():
     args = get_args()
     config_file = args.config_file
 
+    run_comms_benchmark()
+
     # Load trainer and data
     trainer = DistributedTrainer(config_file)
     dataloader = get_dataloader(trainer)
diff --git a/scripts/comms_benchmark.py b/scripts/comms_benchmark.py
new file mode 100644
index 00000000..b31e5d8f
--- /dev/null
+++ b/scripts/comms_benchmark.py
@@ -0,0 +1,90 @@
+import time
+
+import numpy as np
+from mpi4py import MPI
+
+
+def bandwidth_test(size_in_mb=100, trials=10, warmup_trials=2):
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+    size = comm.Get_size()
+    msg_size = size_in_mb * 1024 * 1024  # Convert MB to bytes
+    tot_time = 0
+
+    if size < 2:
+        raise ValueError("The benchmark requires at least 2 MPI processes.")
+
+    comm.Barrier()  # Synchronize before warmup
+
+    for _ in range(warmup_trials):
+        if rank == 0:
+            data = np.ones(msg_size, dtype="b")  # Array of bytes
+            comm.Send([data, MPI.BYTE], dest=1, tag=0)
+            comm.Recv([data, MPI.BYTE], source=1, tag=1)
+        elif rank == 1:
+            data = np.empty(msg_size, dtype="b")
+            comm.Recv([data, MPI.BYTE], source=0, tag=0)
+            comm.Send([data, MPI.BYTE], dest=0, tag=1)
+
+    comm.Barrier()  # Synchronize the processes before benchmark
+
+    for _ in range(trials):
+        if rank == 0:
+            data = np.ones(msg_size, dtype="b")  # Array of bytes
+            start_time = time.perf_counter()
+            comm.Send([data, MPI.BYTE], dest=1, tag=0)
+            comm.Recv([data, MPI.BYTE], source=1, tag=1)
+            end_time = time.perf_counter()
+            elapsed_time = end_time - start_time
+            tot_time += elapsed_time
+        elif rank == 1:
+            data = np.empty(msg_size, dtype="b")
+            comm.Recv([data, MPI.BYTE], source=0, tag=0)
+            comm.Send([data, MPI.BYTE], dest=0, tag=1)
+
+    comm.Barrier()  # Synchronize the processes after benchmark
+
+    if rank == 0:
+        avg_time = tot_time / trials
+        bandwidth = 2 * size_in_mb / avg_time  # Round trip moves the payload twice
+        print(f"Average Bandwidth: {bandwidth:.2f} MB/s")
+
+
+def latency_test(num_trials=1000):
+    comm = MPI.COMM_WORLD
+    rank = comm.Get_rank()
+    size = comm.Get_size()
+    msg_size = 1  # 1 byte
+    tot_time = 0
+
+    if size < 2:
+        raise ValueError("The benchmark requires at least 2 MPI processes.")
+
+    # Synchronize all processes before starting the benchmark
+    comm.Barrier()
+
+    for _ in range(num_trials):
+        if rank == 0:
+            data = np.ones(msg_size, dtype="b")
+            start_time = time.perf_counter()
+            comm.Send([data, MPI.BYTE], dest=1, tag=0)
+            comm.Recv([data, MPI.BYTE], source=1, tag=1)
+            end_time = time.perf_counter()
+            elapsed_time = end_time - start_time
+            tot_time += elapsed_time
+        elif rank == 1:
+            data = np.empty(msg_size, dtype="b")
+            comm.Recv([data, MPI.BYTE], source=0, tag=0)
+            comm.Send([data, MPI.BYTE], dest=0, tag=1)
+
+    # Synchronize all processes after finishing the benchmark
+    comm.Barrier()
+
+    if rank == 0:
+        avg_latency = (tot_time / num_trials) * 1e6  # Convert to microseconds
+        print(f"Average Latency: {avg_latency:.2f} µs")
+
+
+if __name__ == "__main__":
+    bandwidth_test()
+    latency_test()
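
For reference, the same blocking `Send`/`Recv` ping-pong pattern generalizes to a sweep over message sizes, which makes it easier to see where latency stops dominating and the link approaches peak bandwidth. Below is a minimal standalone sketch under the same assumptions as the script above (mpi4py installed, at least two ranks launched via `mpirun`); `pingpong_sweep` and its defaults are illustrative and not part of this diff:

```python
import time

import numpy as np
from mpi4py import MPI


def pingpong_sweep(sizes_in_mb=(1, 10, 100), trials=10):
    # Illustrative sketch, not part of the PR: times the same rank-0 <-> rank-1
    # round trip as bandwidth_test(), once per message size.
    comm = MPI.COMM_WORLD
    rank = comm.Get_rank()
    if comm.Get_size() < 2:
        raise ValueError("Run with at least 2 MPI processes, e.g. mpirun -np 2.")

    for size_in_mb in sizes_in_mb:
        msg_size = size_in_mb * 1024 * 1024  # Convert MB to bytes
        if rank == 0:
            buf = np.ones(msg_size, dtype="b")
        else:
            buf = np.empty(msg_size, dtype="b")
        comm.Barrier()  # Align ranks before timing each size
        tot_time = 0.0
        for _ in range(trials):
            if rank == 0:
                start = time.perf_counter()
                comm.Send([buf, MPI.BYTE], dest=1, tag=0)
                comm.Recv([buf, MPI.BYTE], source=1, tag=1)
                tot_time += time.perf_counter() - start
            elif rank == 1:
                comm.Recv([buf, MPI.BYTE], source=0, tag=0)
                comm.Send([buf, MPI.BYTE], dest=0, tag=1)
        if rank == 0:
            # Each trial moves the payload out and back, hence the factor of 2
            print(f"{size_in_mb} MB: {2 * size_in_mb / (tot_time / trials):.2f} MB/s")


if __name__ == "__main__":
    pingpong_sweep()
```

Launched the same way as the benchmark script (e.g. `mpirun -np 2 python pingpong_sweep.py`, filename illustrative), rank 0 prints one throughput line per message size.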