Skip to content

Commit

Permalink
feat: initial bf16 support
Browse files Browse the repository at this point in the history
no optimization yet

Signed-off-by: Wataru Ishida <[email protected]>
  • Loading branch information
ishidawataru committed Mar 15, 2024
1 parent 17e0a47 commit f98b252
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 30 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ FROM optcast AS unittest

ENV RUST_LOG=info
ENV NCCL_SOCKET_IFNAME=lo
RUN cd reduction_server && cargo test --all -- --nocapture
RUN cd reduction_server && cargo test --all -- --nocapture --test-threads=1

FROM nvcr.io/nvidia/cuda:12.3.1-devel-ubuntu22.04 AS final

Expand Down
77 changes: 55 additions & 22 deletions reduction_server/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::io::{Read, Write};
use std::net::{TcpListener, TcpStream};
use std::sync::Arc;

use half::f16;
use half::{bf16, f16};
use log::{info, trace};

use crate::utils::*;
Expand Down Expand Up @@ -216,6 +216,8 @@ pub(crate) fn client(args: Args) {
do_client::<f32>(args.as_ref(), comm);
} else if args.data_type == DataType::F16 {
do_client::<f16>(args.as_ref(), comm);
} else if args.data_type == DataType::BF16 {
do_client::<bf16>(args.as_ref(), comm);
}
})
})
Expand Down Expand Up @@ -286,6 +288,8 @@ pub(crate) fn bench(args: Args) {
do_client::<f32>(args.as_ref(), comm);
} else if args.data_type == DataType::F16 {
do_client::<f16>(args.as_ref(), comm);
} else if args.data_type == DataType::BF16 {
do_client::<bf16>(args.as_ref(), comm);
}
})
})
Expand All @@ -301,29 +305,58 @@ mod tests {
use crate::utils::tests::initialize;
use clap::Parser;

#[test]
fn test_bench() {
fn do_bench(dt: &str) {
initialize();
let b = std::thread::spawn(|| {
let count = format!("{}", 1024 * 1024);
let args = Args::parse_from([
"--bench",
"--address",
"127.0.0.1",
"--port",
"8080",
"--count",
&count,
]);
bench(args);
});
let c = std::thread::spawn(|| {
let count = format!("{}", 1024 * 1024);
let args =
Args::parse_from(["--client", "--address", "127.0.0.1:8080", "--count", &count]);
client(args);
});
let b = {
let dt = dt.to_string();
std::thread::spawn(move || {
let count = format!("{}", 1024 * 1024);
let args = Args::parse_from([
"--bench",
"--address",
"127.0.0.1",
"--port",
"8080",
"--count",
&count,
"--data-type",
&dt,
]);
bench(args);
})
};
let c = {
let dt = dt.to_string();
std::thread::spawn(move || {
let count = format!("{}", 1024 * 1024);
let args = Args::parse_from([
"--client",
"--address",
"127.0.0.1:8080",
"--count",
&count,
"--data-type",
&dt,
]);
client(args);
})
};
b.join().unwrap();
c.join().unwrap();
}

// Runs the `do_bench` bench/client pair with `--data-type f32`.
// NOTE(review): `do_bench` uses the fixed address 127.0.0.1:8080, so this test presumably
// cannot run concurrently with the other `test_bench_*` tests — confirm.
#[test]
fn test_bench_f32() {
do_bench("f32");
}

// Runs the `do_bench` bench/client pair with `--data-type f16`.
// NOTE(review): `do_bench` uses the fixed address 127.0.0.1:8080, so this test presumably
// cannot run concurrently with the other `test_bench_*` tests — confirm.
#[test]
fn test_bench_f16() {
do_bench("f16");
}

// Runs the `do_bench` bench/client pair with `--data-type bf16`.
// NOTE(review): `do_bench` uses the fixed address 127.0.0.1:8080, so this test presumably
// cannot run concurrently with the other `test_bench_*` tests — confirm.
#[test]
fn test_bench_bf16() {
do_bench("bf16");
}
}
17 changes: 16 additions & 1 deletion reduction_server/src/reduce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
*/

use aligned_box::AlignedBox;
use half::f16;
use half::{f16, bf16};

use crate::utils::{alignment, Float};

Expand Down Expand Up @@ -115,6 +115,21 @@ impl Reduce<f32> for [f32] {
}
}

/// Element-wise sum reduction for `bf16` buffers.
///
/// This is a plain scalar implementation: it does not use the working memory and makes no
/// attempt at explicit vectorisation.
impl Reduce<bf16> for [bf16] {
    /// Overwrites `self` with the element-wise sum of every buffer in `recv_bufs`.
    ///
    /// The first buffer is copied into `self`; each following buffer is then added on top,
    /// in order. When `recv_bufs` is empty, `self` is left untouched. The working-memory
    /// argument is ignored. Always returns `Ok(())`.
    ///
    /// # Panics
    ///
    /// Panics if the first buffer's length differs from `self.len()`, or if any later buffer
    /// is shorter than `self`. Trailing elements of later buffers beyond `self.len()` are
    /// ignored.
    fn reduce(
        &mut self,
        recv_bufs: &Vec<&[bf16]>,
        _: Option<&mut WorkingMemory>,
    ) -> Result<(), ()> {
        if let Some((first, rest)) = recv_bufs.split_first() {
            // Seed the accumulator with the first contribution.
            self.copy_from_slice(first);

            let len = self.len();
            for recv in rest {
                // One range check per buffer instead of one bounds check per element; this
                // still panics when a buffer is too short, but before touching `self`.
                let src = &recv[..len];
                for (acc, &x) in self.iter_mut().zip(src) {
                    *acc += x;
                }
            }
        }
        Ok(())
    }
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
8 changes: 5 additions & 3 deletions reduction_server/src/ring.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ use std::sync::atomic::AtomicUsize;
use std::sync::{Arc, Mutex};

use aligned_box::AlignedBox;
use half::f16;
use half::{bf16, f16};
use log::{info, trace};

use crate::utils::*;
use crate::reduce::{Reduce, WorkingMemory};
use crate::utils::*;

use crate::nccl_net;
use crate::nccl_net::Comm;
Expand Down Expand Up @@ -581,8 +581,10 @@ pub(crate) fn ring(args: Args) {
let (recvs, sends) = comm.into_iter().unzip();
if args.data_type == DataType::F32 {
do_ring::<f32>(args, ch, recvs, sends);
} else {
} else if args.data_type == DataType::F16 {
do_ring::<f16>(args, ch, recvs, sends);
} else if args.data_type == DataType::BF16 {
do_ring::<bf16>(args, ch, recvs, sends);
}
})
})
Expand Down
6 changes: 4 additions & 2 deletions reduction_server/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@ use std::net::TcpListener;
use std::sync::atomic::AtomicUsize;
use std::sync::Arc;

use half::{bf16, f16};
use log::{info, trace, warn};
use half::f16;

use crate::utils::*;
use crate::reduce::{Reduce, WorkingMemory};
use crate::utils::*;

use crate::nccl_net;
use crate::nccl_net::Comm;
Expand Down Expand Up @@ -579,5 +579,7 @@ pub(crate) fn server(args: Args) {
do_server::<f32>(args);
} else if args.data_type == DataType::F16 {
do_server::<f16>(args);
} else if args.data_type == DataType::BF16 {
do_server::<bf16>(args);
}
}
4 changes: 3 additions & 1 deletion reduction_server/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use std::fmt::Debug;
use std::time::Duration;

use clap::{Parser, ValueEnum};
use half::f16;
use half::{f16, bf16};
use log::info;
use num_traits::FromPrimitive;

Expand All @@ -30,6 +30,7 @@ pub(crate) fn transpose<T>(v: Vec<Vec<T>>) -> Vec<Vec<T>> {
// Element type of the buffers being reduced. The client, bench, ring and server entry points
// compare `args.data_type` against these variants to choose which generic instantiation runs.
// NOTE(review): plain `//` comments are used on purpose — if this enum derives clap's
// `ValueEnum` (the derive line is not visible here), `///` on variants would show up in the
// CLI help text. Confirm before converting these to doc comments.
pub(crate) enum DataType {
// Buffers of `f32`.
F32,
// Buffers of `half::f16`.
F16,
// Buffers of `half::bf16`.
BF16,
}

#[derive(Parser, Debug, Clone)]
Expand Down Expand Up @@ -90,6 +91,7 @@ pub(crate) trait Float:

// Element types usable wherever a `Float` bound is required. The impl bodies are empty:
// nothing is defined here beyond what the trait declaration itself provides or requires.
impl Float for f32 {}
impl Float for f16 {}
impl Float for bf16 {}

pub(crate) fn alignment(size: usize) -> usize {
let page = unsafe { libc::sysconf(libc::_SC_PAGESIZE) as usize };
Expand Down

0 comments on commit f98b252

Please sign in to comment.