Skip to content

Commit f968289

Browse files
committed
ck-builder: test system initial prototype
1 parent 381929b commit f968289

File tree

5 files changed

+274
-5
lines changed

5 files changed

+274
-5
lines changed

experimental/builder/include/ck_tile/builder/testing/conv_args.hpp

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#pragma once
55

66
#include "ck_tile/builder/conv_signature_concepts.hpp"
7+
#include "ck/library/utility/convolution_parameter.hpp"
78

89
namespace ck_tile::builder::test {
910

@@ -12,8 +13,53 @@ struct FilterExtent
1213
ck::index_t width = 1;
1314
ck::index_t height = 1;
1415
ck::index_t depth = 1;
16+
17+
template <int SPATIAL_DIM>
18+
std::vector<ck::index_t> to_vector() const
19+
{
20+
if constexpr(SPATIAL_DIM == 1)
21+
{
22+
return {std::initializer_list<ck::index_t>{this->width}};
23+
}
24+
else if constexpr(SPATIAL_DIM == 2)
25+
{
26+
return {{this->height, this->width}};
27+
}
28+
else if constexpr(SPATIAL_DIM == 3)
29+
{
30+
return {{this->depth, this->height, this->width}};
31+
}
32+
}
1533
};
1634

35+
template <int SPATIAL_DIM>
36+
std::array<ck::index_t, SPATIAL_DIM + 3> to_ck_lengths(const std::array<ck::index_t, 3>& gnc,
37+
const FilterExtent& whd)
38+
{
39+
std::array<ck::index_t, SPATIAL_DIM + 3> result = {0};
40+
result[0] = gnc[0];
41+
result[1] = gnc[1];
42+
result[2] = gnc[2];
43+
44+
if constexpr(SPATIAL_DIM == 1)
45+
{
46+
result[3] = whd.width;
47+
}
48+
else if constexpr(SPATIAL_DIM == 2)
49+
{
50+
result[3] = whd.height;
51+
result[4] = whd.width;
52+
}
53+
else if constexpr(SPATIAL_DIM == 3)
54+
{
55+
result[3] = whd.depth;
56+
result[4] = whd.height;
57+
result[5] = whd.width;
58+
}
59+
60+
return result;
61+
}
62+
1763
struct TensorExtent
1864
{
1965
ck::index_t batch_size = 1; // N
@@ -39,6 +85,21 @@ struct ConvArgs
3985
FilterExtent filter_dilation;
4086
FilterExtent input_left_pad;
4187
FilterExtent input_right_pad;
88+
89+
ck::utils::conv::ConvParam to_conv_param() const
90+
{
91+
return ck::utils::conv::ConvParam(SPATIAL_DIM,
92+
this->lengths.groups,
93+
this->lengths.batch_size,
94+
this->lengths.output_channels,
95+
this->lengths.input_channels,
96+
this->lengths.filter.to_vector<SPATIAL_DIM>(),
97+
this->lengths.image.to_vector<SPATIAL_DIM>(),
98+
this->filter_strides.to_vector<SPATIAL_DIM>(),
99+
this->filter_dilation.to_vector<SPATIAL_DIM>(),
100+
this->input_left_pad.to_vector<SPATIAL_DIM>(),
101+
this->input_right_pad.to_vector<SPATIAL_DIM>(), );
102+
}
42103
};
43104

44105
} // namespace ck_tile::builder::test

experimental/builder/include/ck_tile/builder/testing/tensor_memory_manager.hpp

Lines changed: 50 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,13 @@
44
#pragma once
55

66
#include <memory>
7+
#include <numeric>
78
#include <hip/hip_runtime.h>
89
#include "ck_tile/builder/conv_signature_concepts.hpp"
10+
#include "ck_tile/builder/testing/conv_args.hpp"
11+
#include "ck_tile/builder/testing/type_traits.hpp"
12+
#include "ck_tile/builder/conv_factory.hpp"
13+
#include "ck/library/utility/convolution_host_tensor_descriptor_helper.hpp"
914

1015
namespace ck_tile::builder::test {
1116

@@ -20,17 +25,58 @@ struct DeviceMemoryDeleter
2025

2126
using DeviceBuffer = std::unique_ptr<std::byte[], DeviceMemoryDeleter>;
2227

28+
template <DataType DT>
29+
DeviceBuffer alloc_tensor(ck::HostTensorDescriptor descriptor)
30+
{
31+
const auto total_elements = descriptor.GetElementSpaceSize();
32+
const auto total_size = total_elements * sizeof(typename DataTypeTraits<DT>::Type);
33+
34+
std::byte* d_buf = nullptr;
35+
const auto status = hipMalloc(&d_buf, total_size);
36+
// TODO(Robin): How to check error without relying on google test?
37+
// Ideally we get some sort of trace here, but thats not possible until c++23.
38+
// For now just throw a runtime error.
39+
if(status != hipSuccess)
40+
{
41+
throw std::runtime_error("failed to load hip memory");
42+
}
43+
return DeviceBuffer(d_buf);
44+
}
45+
2346
template <auto SIGNATURE>
2447
requires ValidConvSignature<SIGNATURE>
2548
struct TensorMemoryManager
2649
{
2750
// Type aliases for tensor data types
2851
// For now, all tensors use the same data type from the signature
29-
using InputDataType = decltype(SIGNATURE.data_type);
30-
using WeightDataType = decltype(SIGNATURE.data_type);
31-
using OutputDataType = decltype(SIGNATURE.data_type);
52+
using InputDataType = DataTypeTraits<SIGNATURE.data_type>::Type;
53+
using WeightDataType = DataTypeTraits<SIGNATURE.data_type>::Type;
54+
using OutputDataType = DataTypeTraits<SIGNATURE.data_type>::Type;
55+
56+
using Layouts =
57+
decltype(ck_tile::builder::factory_internal::GetTensorLayout<SIGNATURE.layout,
58+
SIGNATURE.spatial_dim,
59+
ConvDirection::FORWARD>());
60+
61+
TensorMemoryManager(const ConvArgs<SIGNATURE>& args)
62+
: param(args.to_conv_param()),
63+
input_descriptor(ck::utils::conv::make_input_host_tensor_descriptor_g_n_c_wis_packed<
64+
typename Layouts::ALayout>(this->param)),
65+
weight_descriptor(ck::utils::conv::make_weight_host_tensor_descriptor_g_k_c_xs_packed<
66+
typename Layouts::BLayout>(this->param)),
67+
output_descriptor(ck::utils::conv::make_output_host_tensor_descriptor_g_n_k_wos_packed<
68+
typename Layouts::ELayout>(this->param)),
69+
input_buf(alloc_tensor<SIGNATURE.data_type>(this->input_descriptor)),
70+
weight_buf(alloc_tensor<SIGNATURE.data_type>(this->weight_descriptor)),
71+
output_buf(alloc_tensor<SIGNATURE.data_type>(this->output_descriptor))
72+
{
73+
}
74+
75+
ck::utils::conv::ConvParam param;
3276

33-
TensorMemoryManager() = default;
77+
ck::HostTensorDescriptor input_descriptor;
78+
ck::HostTensorDescriptor weight_descriptor;
79+
ck::HostTensorDescriptor output_descriptor;
3480

3581
// Device memory buffers
3682
DeviceBuffer input_buf = nullptr;
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// Copyright (C) Advanced Micro Devices, Inc. All rights reserved.
2+
// SPDX-License-Identifier: MIT
3+
4+
#pragma once
5+
6+
#include <cstddef>
7+
#include "ck_tile/builder/types.hpp"
8+
#include "ck_tile/ops/common/tensor_layout.hpp"
9+
10+
// TODO(Robin): Test to check that all DataType variants are covered?
11+
// TODO(Robin): Put this file somewhere else?
12+
13+
namespace ck_tile::builder::test {
14+
15+
/// This structure contains some useful traits for CK-Builder's DataType
16+
/// type. Its main usecase is to convert a CK-Builder DataType into an
17+
/// equivalent C++ type.
18+
template <DataType DT>
19+
struct DataTypeTraits;
20+
21+
template <>
22+
struct DataTypeTraits<DataType::FP32>
23+
{
24+
using Type = float;
25+
};
26+
27+
template <>
28+
struct DataTypeTraits<DataType::FP16>
29+
{
30+
using Type = ck::half_t;
31+
};
32+
33+
template <>
34+
struct DataTypeTraits<DataType::BF16>
35+
{
36+
using Type = ck::bhalf_t;
37+
};
38+
39+
template <>
40+
struct DataTypeTraits<DataType::FP8>
41+
{
42+
using Type = ck::f8_t;
43+
};
44+
45+
template <>
46+
struct DataTypeTraits<DataType::I8>
47+
{
48+
using Type = int8_t;
49+
};
50+
51+
template <>
52+
struct DataTypeTraits<DataType::U8>
53+
{
54+
using Type = uint8_t;
55+
};
56+
57+
} // namespace ck_tile::builder::test

experimental/builder/test/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ add_ck_builder_test(test_ckb_conv_builder
2323
test_bwd_weight_instance_traits.cpp
2424
test_instance_traits_util.cpp
2525
test_memory_manager.cpp)
26-
2726
add_ck_builder_test(test_ckb_inline_diff test_inline_diff.cpp)
2827

2928
# Testing the virtual GetInstanceString methods requires kernel compilation.
@@ -51,6 +50,7 @@ add_ck_builder_test(test_ckb_build_fwd_instances
5150
conv/test_ckb_conv_fwd_3d_fp16.cpp
5251
conv/test_ckb_conv_fwd_3d_fp32.cpp
5352
)
53+
target_link_libraries(test_ckb_build_fwd_instances PRIVATE utility)
5454

5555
function(add_ck_factory_test test_name)
5656
add_ck_builder_test(${test_name} ${ARGN})

experimental/builder/test/conv/test_ckb_conv_fwd_2d_dl_fp16.cpp

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,16 @@
33

44
#include "utils/ckb_conv_test_configs.hpp"
55
#include "utils/ckb_conv_test_utils.hpp"
6+
#include "ck_tile/builder/testing/tensor_memory_manager.hpp"
7+
#include "ck_tile/builder/testing/conv_args.hpp"
68

79
namespace {
810

911
using namespace ck_tile::builder::test_utils;
1012

13+
namespace ckb = ck_tile::builder;
14+
namespace ckt = ck_tile::builder::test;
15+
1116
TEST(FwdConvInstances, Create_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK_Instance_2D_FP16_GNHWC)
1217
{
1318
constexpr ConvSignature FwdConvSignature{.spatial_dim = 2,
@@ -54,4 +59,104 @@ TEST(FwdConvInstances,
5459
{"DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK", "256, 128, 128, 16", "Filter1x1Pad0"});
5560
}
5661

62+
TEST(FwdConvInstances, Fp16_2D_DL_EndToEndBasic)
63+
{
64+
constexpr ConvSignature signature{.spatial_dim = 2,
65+
.direction = ConvDirection::FORWARD,
66+
.layout = GroupConvLayout2D::GNHWC_GKYXC_GNHWK,
67+
.data_type = DataType::FP16,
68+
.elementwise_operation = ElementwiseOperation::PASS_THROUGH};
69+
70+
constexpr auto algorithm =
71+
ConvAlgorithm_DeviceGroupedConvFwdDlMultipleD_NHWC_KYXC_NHWK{}
72+
.with_thread_block(FwdThreadBlock_256_128x128x16)
73+
.with_specializations(ConvFwdSpecialization::DEFAULT, GemmSpecialization::MNKPadding)
74+
.with_dl_thread_config(DlThreadConfig_16x2x4x4x1)
75+
.with_dl_thread_cluster(DlThreadCluster_8x2)
76+
.with_dl_transfer(DlFwdTransfer);
77+
78+
auto args = ckt::ConvArgs<signature>{
79+
.lengths =
80+
{
81+
.batch_size = 16,
82+
.groups = 1,
83+
.input_channels = 32,
84+
.output_channels = 16,
85+
.image =
86+
{
87+
.width = 56,
88+
.height = 56,
89+
},
90+
.filter =
91+
{
92+
.width = 3,
93+
.height = 3,
94+
},
95+
},
96+
.filter_strides = {.width = 1, .height = 1},
97+
.filter_dilation = {.width = 1, .height = 1},
98+
.input_left_pad = {.width = 0, .height = 0},
99+
.input_right_pad = {.width = 0, .height = 0},
100+
};
101+
102+
auto tmm = ckt::TensorMemoryManager<signature>(args);
103+
104+
auto conv = ConvBuilder<signature, algorithm>::Instance{};
105+
106+
auto invoker = conv.MakeInvoker();
107+
108+
const auto input_desc = tmm.input_descriptor;
109+
const auto weight_desc = tmm.weight_descriptor;
110+
const auto output_desc = tmm.output_descriptor;
111+
112+
std::array<ck::index_t, 2 + 3> input_lengths;
113+
std::array<ck::index_t, 2 + 3> input_strides;
114+
std::array<ck::index_t, 2 + 3> weight_lengths;
115+
std::array<ck::index_t, 2 + 3> weight_strides;
116+
std::array<ck::index_t, 2 + 3> output_lengths;
117+
std::array<ck::index_t, 2 + 3> output_strides;
118+
std::array<ck::index_t, 2> conv_filter_strides;
119+
std::array<ck::index_t, 2> conv_filter_dilations;
120+
std::array<ck::index_t, 2> input_left_pads;
121+
std::array<ck::index_t, 2> input_right_pads;
122+
123+
auto copy = [](auto& src, auto& dst) { std::copy(src.begin(), src.end(), dst.begin()); };
124+
125+
copy(input_desc.GetLengths(), input_lengths);
126+
copy(input_desc.GetStrides(), input_strides);
127+
copy(weight_desc.GetLengths(), weight_lengths);
128+
copy(weight_desc.GetStrides(), weight_strides);
129+
copy(output_desc.GetLengths(), output_lengths);
130+
copy(output_desc.GetStrides(), output_strides);
131+
132+
copy(tmm.param.conv_filter_strides_, conv_filter_strides);
133+
copy(tmm.param.conv_filter_dilations_, conv_filter_dilations);
134+
copy(tmm.param.input_left_pads_, input_left_pads);
135+
copy(tmm.param.input_right_pads_, input_right_pads);
136+
137+
auto argument = conv.MakeArgument(tmm.input_buf.get(),
138+
tmm.weight_buf.get(),
139+
{},
140+
tmm.output_buf.get(),
141+
input_lengths,
142+
input_strides,
143+
weight_lengths,
144+
weight_strides,
145+
{},
146+
{},
147+
output_lengths,
148+
output_strides,
149+
conv_filter_strides,
150+
conv_filter_dilations,
151+
input_left_pads,
152+
input_right_pads,
153+
ck::tensor_operation::element_wise::PassThrough{},
154+
ck::tensor_operation::element_wise::PassThrough{},
155+
ck::tensor_operation::element_wise::PassThrough{});
156+
157+
ASSERT_THAT(conv.IsSupportedArgument(argument), testing::IsTrue());
158+
159+
invoker.Run(argument, {});
160+
}
161+
57162
} // namespace

0 commit comments

Comments
 (0)