From d482b2382e78983ffef75f9465e11c7ab8b02fad Mon Sep 17 00:00:00 2001 From: Wataru Ishida Date: Mon, 25 Mar 2024 03:56:31 +0000 Subject: [PATCH 1/3] fix: recv/send_thread default value handling in do_server() Signed-off-by: Wataru Ishida --- reduction_server/src/main.rs | 10 +--------- reduction_server/src/server.rs | 10 ++++++++++ 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/reduction_server/src/main.rs b/reduction_server/src/main.rs index 6731597..1ac7b4e 100644 --- a/reduction_server/src/main.rs +++ b/reduction_server/src/main.rs @@ -32,15 +32,7 @@ fn main() { .init(); nccl_net::init(); - let mut args = Args::parse(); - - if args.recv_threads == 0 { - args.recv_threads = args.nrank - } - - if args.send_threads == 0 { - args.send_threads = args.nrank - } + let args = Args::parse(); if args.client { client(args); diff --git a/reduction_server/src/server.rs b/reduction_server/src/server.rs index 05d211e..5f80dbf 100644 --- a/reduction_server/src/server.rs +++ b/reduction_server/src/server.rs @@ -447,6 +447,16 @@ fn recv_loop( } fn do_server(args: Args) { + let mut args = args; + + if args.recv_threads == 0 { + args.recv_threads = args.nrank + } + + if args.send_threads == 0 { + args.send_threads = args.nrank + } + let listener = TcpListener::bind(format!("{}:{}", args.address, args.port)).expect("failed to bind"); From 88ccde3dadb4230f61a42e6d27e6d357a72c22aa Mon Sep 17 00:00:00 2001 From: Wataru Ishida Date: Mon, 25 Mar 2024 03:58:02 +0000 Subject: [PATCH 2/3] fix(server): fix working memory size Signed-off-by: Wataru Ishida --- reduction_server/src/server.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/reduction_server/src/server.rs b/reduction_server/src/server.rs index 5f80dbf..e246acd 100644 --- a/reduction_server/src/server.rs +++ b/reduction_server/src/server.rs @@ -94,7 +94,7 @@ fn reduce_loop( info!("reduce thread({}) all ranks get connected!", i); let mut mems = (0..jobs.len()) - .map(|_| WorkingMemory::new(args.count, args.recv_threads)) + .map(|_| WorkingMemory::new(args.count / args.reduce_threads, args.recv_threads)) .collect::>(); loop { From b8f9e0d340bfc1bd13fe496648ad2cdcdee081a4 Mon Sep 17 00:00:00 2001 From: Wataru Ishida Date: Mon, 25 Mar 2024 03:58:44 +0000 Subject: [PATCH 3/3] chore(test): test server() Signed-off-by: Wataru Ishida --- reduction_server/src/server.rs | 68 ++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/reduction_server/src/server.rs b/reduction_server/src/server.rs index e246acd..0cc8f14 100644 --- a/reduction_server/src/server.rs +++ b/reduction_server/src/server.rs @@ -593,3 +593,71 @@ pub(crate) fn server(args: Args) { do_server::(args); } } + +// test +#[cfg(test)] +mod tests { + use super::*; + use crate::client::client; + use crate::utils::tests::initialize; + use clap::Parser; + + fn do_test(dt: &str) { + initialize(); + let nrank = 4; + let server = { + let dt = dt.to_string(); + std::thread::spawn(move || { + let nrank = format!("{}", nrank); + let args = Args::parse_from([ + "--verbose", // doesn't work without specifying a flag that doesn't take an argument + "--port", + "8080", + "--data-type", + &dt, + "--nrank", + &nrank, + "--nreq", + "1", // when using socket plugin, concurrent recv/send requests doesn't work + ]); + server(args); + }) + }; + (0..nrank) + .map(|_| { + let dt = dt.to_string(); + std::thread::spawn(move || { + std::thread::sleep(std::time::Duration::from_millis(100)); + let args = Args::parse_from([ + "--client", + "--address", + "127.0.0.1:8080", + "--data-type", + &dt, + "--nreq", + "1", // when using socket plugin, concurrent recv/send requests doesn't work + ]); + client(args); + }) + }) + .collect::>() + .into_iter() + .for_each(|h| h.join().unwrap()); + server.join().unwrap(); + } + + #[test] + fn test_server_f32() { + do_test("f32"); + } + + #[test] + fn test_server_f16() { + do_test("f16"); + } + + #[test] + fn test_server_bf16() { + do_test("bf16"); + } +}