From f98b2522309cf25216e631cc0fe527e23219aa57 Mon Sep 17 00:00:00 2001 From: Wataru Ishida Date: Fri, 15 Mar 2024 08:24:57 +0000 Subject: [PATCH] feat: initial bf16 support no optimization yet Signed-off-by: Wataru Ishida --- Dockerfile | 2 +- reduction_server/src/client.rs | 77 ++++++++++++++++++++++++---------- reduction_server/src/reduce.rs | 17 +++++++- reduction_server/src/ring.rs | 8 ++-- reduction_server/src/server.rs | 6 ++- reduction_server/src/utils.rs | 4 +- 6 files changed, 84 insertions(+), 30 deletions(-) diff --git a/Dockerfile b/Dockerfile index 473282d..7544194 100644 --- a/Dockerfile +++ b/Dockerfile @@ -48,7 +48,7 @@ FROM optcast AS unittest ENV RUST_LOG=info ENV NCCL_SOCKET_IFNAME=lo -RUN cd reduction_server && cargo test --all -- --nocapture +RUN cd reduction_server && cargo test --all -- --nocapture --test-threads=1 FROM nvcr.io/nvidia/cuda:12.3.1-devel-ubuntu22.04 AS final diff --git a/reduction_server/src/client.rs b/reduction_server/src/client.rs index 015d4dc..c24b22d 100644 --- a/reduction_server/src/client.rs +++ b/reduction_server/src/client.rs @@ -8,7 +8,7 @@ use std::io::{Read, Write}; use std::net::{TcpListener, TcpStream}; use std::sync::Arc; -use half::f16; +use half::{bf16, f16}; use log::{info, trace}; use crate::utils::*; @@ -216,6 +216,8 @@ pub(crate) fn client(args: Args) { do_client::(args.as_ref(), comm); } else if args.data_type == DataType::F16 { do_client::(args.as_ref(), comm); + } else if args.data_type == DataType::BF16 { + do_client::(args.as_ref(), comm); } }) }) @@ -286,6 +288,8 @@ pub(crate) fn bench(args: Args) { do_client::(args.as_ref(), comm); } else if args.data_type == DataType::F16 { do_client::(args.as_ref(), comm); + } else if args.data_type == DataType::BF16 { + do_client::(args.as_ref(), comm); } }) }) @@ -301,29 +305,58 @@ mod tests { use crate::utils::tests::initialize; use clap::Parser; - #[test] - fn test_bench() { + fn do_bench(dt: &str) { initialize(); - let b = std::thread::spawn(|| { - let count = format!("{}", 1024 * 1024); - let args = Args::parse_from([ - "--bench", - "--address", - "127.0.0.1", - "--port", - "8080", - "--count", - &count, - ]); - bench(args); - }); - let c = std::thread::spawn(|| { - let count = format!("{}", 1024 * 1024); - let args = - Args::parse_from(["--client", "--address", "127.0.0.1:8080", "--count", &count]); - client(args); - }); + let b = { + let dt = dt.to_string(); + std::thread::spawn(move || { + let count = format!("{}", 1024 * 1024); + let args = Args::parse_from([ + "--bench", + "--address", + "127.0.0.1", + "--port", + "8080", + "--count", + &count, + "--data-type", + &dt, + ]); + bench(args); + }) + }; + let c = { + let dt = dt.to_string(); + std::thread::spawn(move || { + let count = format!("{}", 1024 * 1024); + let args = Args::parse_from([ + "--client", + "--address", + "127.0.0.1:8080", + "--count", + &count, + "--data-type", + &dt, + ]); + client(args); + }) + }; b.join().unwrap(); c.join().unwrap(); } + + #[test] + fn test_bench_f32() { + do_bench("f32"); + } + + #[test] + fn test_bench_f16() { + do_bench("f16"); + } + + #[test] + fn test_bench_bf16() { + do_bench("bf16"); + } } diff --git a/reduction_server/src/reduce.rs b/reduction_server/src/reduce.rs index 491293f..3260743 100644 --- a/reduction_server/src/reduce.rs +++ b/reduction_server/src/reduce.rs @@ -5,7 +5,7 @@ */ use aligned_box::AlignedBox; -use half::f16; +use half::{f16, bf16}; use crate::utils::{alignment, Float}; @@ -115,6 +115,21 @@ impl Reduce for [f32] { } } +impl Reduce for [bf16] { + fn reduce(&mut self, recv_bufs: &Vec<&[bf16]>, _: Option<&mut WorkingMemory>) -> Result<(), ()> { + for (i, recv) in recv_bufs.iter().enumerate() { + if i == 0 { + self.copy_from_slice(&recv.as_ref()); + } else { + for j in 0..self.len() { + self[j] += recv[j]; + } + } + } + Ok(()) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/reduction_server/src/ring.rs b/reduction_server/src/ring.rs index fbf8d02..296dbd7 100644 --- a/reduction_server/src/ring.rs +++ b/reduction_server/src/ring.rs @@ -11,11 +11,11 @@ use std::sync::atomic::AtomicUsize; use std::sync::{Arc, Mutex}; use aligned_box::AlignedBox; -use half::f16; +use half::{bf16, f16}; use log::{info, trace}; -use crate::utils::*; use crate::reduce::{Reduce, WorkingMemory}; +use crate::utils::*; use crate::nccl_net; use crate::nccl_net::Comm; @@ -581,8 +581,10 @@ pub(crate) fn ring(args: Args) { let (recvs, sends) = comm.into_iter().unzip(); if args.data_type == DataType::F32 { do_ring::(args, ch, recvs, sends); - } else { + } else if args.data_type == DataType::F16 { do_ring::(args, ch, recvs, sends); + } else if args.data_type == DataType::BF16 { + do_ring::(args, ch, recvs, sends); } }) }) diff --git a/reduction_server/src/server.rs b/reduction_server/src/server.rs index aae13ee..05d211e 100644 --- a/reduction_server/src/server.rs +++ b/reduction_server/src/server.rs @@ -11,11 +11,11 @@ use std::net::TcpListener; use std::sync::atomic::AtomicUsize; use std::sync::Arc; +use half::{bf16, f16}; use log::{info, trace, warn}; -use half::f16; -use crate::utils::*; use crate::reduce::{Reduce, WorkingMemory}; +use crate::utils::*; use crate::nccl_net; use crate::nccl_net::Comm; @@ -579,5 +579,7 @@ pub(crate) fn server(args: Args) { do_server::(args); } else if args.data_type == DataType::F16 { do_server::(args); + } else if args.data_type == DataType::BF16 { + do_server::(args); } } diff --git a/reduction_server/src/utils.rs b/reduction_server/src/utils.rs index 5f36c8c..829c4a8 100644 --- a/reduction_server/src/utils.rs +++ b/reduction_server/src/utils.rs @@ -8,7 +8,7 @@ use std::fmt::Debug; use std::time::Duration; use clap::{Parser, ValueEnum}; -use half::f16; +use half::{f16, bf16}; use log::info; use num_traits::FromPrimitive; @@ -30,6 +30,7 @@ pub(crate) fn transpose(v: Vec>) -> Vec> { pub(crate) enum DataType { F32, F16, + BF16, } #[derive(Parser, Debug, Clone)] @@ -90,6 +91,7 @@ pub(crate) trait Float: impl Float for f32 {} impl Float for f16 {} +impl Float for bf16 {} pub(crate) fn alignment(size: usize) -> usize { let page = unsafe { libc::sysconf(libc::_SC_PAGESIZE) as usize };