diff --git a/include/dlaf/eigensolver/reduction_to_band/impl.h b/include/dlaf/eigensolver/reduction_to_band/impl.h
index 66d6e110a0..efdb55ee55 100644
--- a/include/dlaf/eigensolver/reduction_to_band/impl.h
+++ b/include/dlaf/eigensolver/reduction_to_band/impl.h
@@ -31,6 +31,7 @@
 #include "dlaf/lapack/tile.h"
 #include "dlaf/matrix/copy_tile.h"
 #include "dlaf/matrix/distribution.h"
+#include "dlaf/matrix/extra_buffers.h"
 #include "dlaf/matrix/index.h"
 #include "dlaf/matrix/matrix.h"
 #include "dlaf/matrix/panel.h"
@@ -455,20 +456,24 @@ void gemmComputeW2(matrix::Matrix<T, D>& w2, matrix::Panel<Coord::Col, const T,
 
   namespace ex = pika::execution::experimental;
 
-  // Note:
-  // Not all ranks in the column always hold at least a tile in the panel Ai, but all ranks in
-  // the column are going to participate to the reduce. For them, it is important to set the
-  // partial result W2 to zero.
-  ex::start_detached(w2.readwrite_sender(LocalTileIndex(0, 0)) |
-                     tile::set0(dlaf::internal::Policy<B>(thread_priority::high)));
+  ExtraBuffers<T, D> buffers(w2.blockSize(), 6);
+
+  //// Note:
+  //// Not all ranks in the column always hold at least a tile in the panel Ai, but all ranks in
+  //// the column are going to participate to the reduce. For them, it is important to set the
+  //// partial result W2 to zero.
 
   using namespace blas;
   // GEMM W2 = W* . X
-  for (const auto& index_tile : w.iteratorLocal())
+  for (const auto& index_tile : w.iteratorLocal()) {
     ex::start_detached(dlaf::internal::whenAllLift(Op::ConjTrans, Op::NoTrans, T(1),
                                                    w.read_sender(index_tile), x.read_sender(index_tile),
-                                                   T(1), w2.readwrite_sender(LocalTileIndex(0, 0))) |
+                                                   T(1), buffers.readwrite_sender(index_tile.row())) |
                        tile::gemm(dlaf::internal::Policy<B>(thread_priority::high)));
+  }
+
+  ex::start_detached(tile::set0(dlaf::internal::Policy<B>(), w2.readwrite_sender(LocalTileIndex(0, 0))));
+  ex::start_detached(buffers.reduce(w2.readwrite_sender(LocalTileIndex(0, 0))));
 }
 
 template <Backend B, Device D, class T>
@@ -959,7 +964,7 @@ common::internal::vector<pika::shared_future<common::internal::vector<T>>> Reduc
     const LocalTileIndex t_idx(0, 0);
     // TODO used just by the column, maybe we can re-use a panel tile?
    // TODO probably the first one in any panel is ok?
-    Matrix<T, D> t({nrefls_block, nrefls_block}, dist.blockSize());
+    Matrix<T, D> t({nrefls_block, nrefls_block}, {nrefls_block, nrefls_block});
 
     computeTFactor(v, taus.back(), t.readwrite_sender(t_idx));
 
@@ -1107,7 +1112,7 @@ common::internal::vector<pika::shared_future<common::internal::vector<T>>> Reduc
     const LocalTileIndex t_idx(0, 0);
     // TODO used just by the column, maybe we can re-use a panel tile?
     // TODO or we can keep just the sh_future and allocate just inside if (is_panel_rank_col)
-    matrix::Matrix<T, D> t({nrefls_block, nrefls_block}, dist.blockSize());
+    matrix::Matrix<T, D> t({nrefls_block, nrefls_block}, {nrefls_block, nrefls_block});
 
     // PANEL
     const matrix::SubPanelView panel_view(dist, ij_offset, band_size);
diff --git a/include/dlaf/matrix/extra_buffers.h b/include/dlaf/matrix/extra_buffers.h
new file mode 100644
index 0000000000..f2169da399
--- /dev/null
+++ b/include/dlaf/matrix/extra_buffers.h
@@ -0,0 +1,66 @@
+//
+// Distributed Linear Algebra with Future (DLAF)
+//
+// Copyright (c) 2018-2023, ETH Zurich
+// All rights reserved.
+//
+// Please, refer to the LICENSE file in the root directory.
+// SPDX-License-Identifier: BSD-3-Clause
+//
+
+#pragma once
+
+#include <pika/execution.hpp>
+
+#include "dlaf/blas/tile_extensions.h"
+#include "dlaf/matrix/matrix.h"
+#include "dlaf/types.h"
+
+namespace dlaf {
+
+template <class T, Device D>
+struct ExtraBuffers : protected Matrix<T, D> {
+  ExtraBuffers(const TileElementSize bs, const SizeType size)
+      : Matrix<T, D>{{bs.rows() * size, bs.cols()}, bs}, nbuffers_(size) {
+    namespace ex = pika::execution::experimental;
+    for (const auto& i : common::iterate_range2d(Matrix<T, D>::distribution().localNrTiles()))
+      ex::start_detached(Matrix<T, D>::readwrite_sender(i) |
+                         tile::set0(dlaf::internal::Policy<DefaultBackend_v<D>>(
+                             pika::execution::thread_priority::high)));
+  }
+
+  auto read_sender(SizeType index) {
+    return Matrix<T, D>::read_sender(internalIndex(index));
+  }
+
+  auto readwrite_sender(SizeType index) {
+    return Matrix<T, D>::readwrite_sender(internalIndex(index));
+  }
+
+  template <class TileSender>
+  [[nodiscard]] auto reduce(TileSender tile) {
+    namespace di = dlaf::internal;
+    namespace ex = pika::execution::experimental;
+
+    std::vector<ex::any_sender<pika::shared_future<matrix::Tile<const T, D>>>> buffers;
+    for (SizeType index = 0; index < nbuffers_; ++index)
+      buffers.emplace_back(read_sender(index));
+
+    return ex::when_all(std::move(tile), ex::when_all_vector(std::move(buffers))) |
+           di::transform(di::Policy<DefaultBackend_v<D>>(),
+                         [](const matrix::Tile<T, D>& tile,
+                            const std::vector<pika::shared_future<matrix::Tile<const T, D>>>& buffers,
+                            auto&&... ts) {
+                           for (const auto& buffer : buffers)
+                             dlaf::tile::internal::add(T(1), buffer.get(), tile, ts...);
+                         });
+  }
+
+protected:
+  LocalTileIndex internalIndex(SizeType index) const noexcept {
+    return LocalTileIndex{index % nbuffers_, 0};
+  }
+
+  SizeType nbuffers_;
+};
+}
diff --git a/test/unit/matrix/CMakeLists.txt b/test/unit/matrix/CMakeLists.txt
index 0899724581..1d3d70d449 100644
--- a/test/unit/matrix/CMakeLists.txt
+++ b/test/unit/matrix/CMakeLists.txt
@@ -99,3 +99,10 @@ DLAF_addTest(
   USE_MAIN MPIPIKA
   MPIRANKS 6
 )
+
+DLAF_addTest(
+  test_extra_buffers
+  SOURCES test_extra_buffers.cpp
+  LIBRARIES dlaf.core
+  USE_MAIN PIKA
+)
diff --git a/test/unit/matrix/test_extra_buffers.cpp b/test/unit/matrix/test_extra_buffers.cpp
new file mode 100644
index 0000000000..a35f85377d
--- /dev/null
+++ b/test/unit/matrix/test_extra_buffers.cpp
@@ -0,0 +1,42 @@
+//
+// Distributed Linear Algebra with Future (DLAF)
+//
+// Copyright (c) 2018-2023, ETH Zurich
+// All rights reserved.
+//
+// Please, refer to the LICENSE file in the root directory.
+// SPDX-License-Identifier: BSD-3-Clause
+//
+
+#include "dlaf/matrix/extra_buffers.h"
+
+#include <gtest/gtest.h>
+
+#include "dlaf/common/range2d.h"
+#include "dlaf/matrix/print_numpy.h"
+
+#include "dlaf_test/matrix/util_tile.h"
+
+using namespace dlaf;
+
+TEST(ExtraBuffersTest, Basic) {
+  using T = float;
+  constexpr auto D = Device::CPU;
+
+  namespace ex = pika::execution::experimental;
+  namespace tt = pika::this_thread::experimental;
+
+  TileElementSize tile_size(2, 2);
+  Matrix<T, D> tile({tile_size.rows(), tile_size.cols()}, tile_size);
+  constexpr SizeType nbuffers = 10;
+  ExtraBuffers<T, D> buffers(tile_size, nbuffers);
+
+  for (SizeType i = 0; i < nbuffers; ++i) {
+    tt::sync_wait(ex::when_all(buffers.readwrite_sender(i), ex::just(T(1))) |
+                  ex::then([](const auto& tile, const T value) { matrix::test::set(tile, value); }));
+  }
+
+  ex::start_detached(buffers.reduce(tile.readwrite_sender(LocalTileIndex{0, 0})));
+
+  print(format::numpy{}, tile.read(LocalTileIndex(0, 0)).get());
+}
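For reference, a minimal usage sketch of the new ExtraBuffers helper (not part of the patch), mirroring the pattern introduced in gemmComputeW2 above: partial results land in a small pool of round-robin tiles instead of a single shared tile, the destination tile is zeroed, and reduce() folds every buffer into it. The function name accumulate_w2_like, the buffer count of 4, and the choice of the CPU backend (Backend::MC) are illustrative assumptions, not part of the patch.

#include "dlaf/matrix/extra_buffers.h"
#include "dlaf/matrix/matrix.h"

using namespace dlaf;

// Usage sketch only: names and the buffer count are illustrative; CPU backend assumed.
void accumulate_w2_like(Matrix<float, Device::CPU>& w2) {
  namespace ex = pika::execution::experimental;

  // A small pool of per-worker tiles avoids serializing every accumulation on the single
  // w2(0, 0) tile; producers target buffers.readwrite_sender(i) with i chosen round-robin.
  ExtraBuffers<float, Device::CPU> buffers(w2.blockSize(), 4);

  // ... producers accumulate their partial results into buffers.readwrite_sender(i % 4) ...

  // Zero the destination tile, then sum all extra buffers into it (as done in gemmComputeW2).
  ex::start_detached(tile::set0(dlaf::internal::Policy<Backend::MC>(),
                                w2.readwrite_sender(LocalTileIndex(0, 0))));
  ex::start_detached(buffers.reduce(w2.readwrite_sender(LocalTileIndex(0, 0))));
}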