Skip to content

Commit b3ea33e

Browse files
ezhulenevGoogle-ML-Automation
authored andcommitted
[xla:ffi] NFC: Use AttrTag<T> for tagging regular arguments
PiperOrigin-RevId: 825769996
1 parent 3d0e631 commit b3ea33e

File tree

2 files changed

+30
-48
lines changed

2 files changed

+30
-48
lines changed

xla/ffi/api/api.h

Lines changed: 21 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,10 @@ namespace internal {
488488
// parameter packs. We need this to be able to pattern match FFI handler
489489
// signature at compile time.
490490

491+
// A type tag for decoding argument.
492+
template <typename T>
493+
struct ArgTag {};
494+
491495
// A type tag for decoding optional argument.
492496
template <typename T>
493497
struct OptionalArgTag {};
@@ -524,23 +528,23 @@ template <typename T>
524528
struct CtxTag {};
525529

526530
//----------------------------------------------------------------------------//
527-
// A template for counting tagged arguments in the Ts pack (i.e. attributes).
531+
// A template for counting tagged arguments in the Ts pack.
528532
//----------------------------------------------------------------------------//
529533

530-
template <template <typename> class Tag, typename... Ts>
534+
template <template <typename> typename Tag, typename... Ts>
531535
struct NumTagged;
532536

533-
template <template <typename> class Tag>
537+
template <template <typename> typename Tag>
534538
struct NumTagged<Tag> {
535539
static constexpr int64_t value = 0;
536540
};
537541

538-
template <template <typename> class Tag, typename T, typename... Ts>
542+
template <template <typename> typename Tag, typename T, typename... Ts>
539543
struct NumTagged<Tag, Tag<T>, Ts...> {
540544
static constexpr int64_t value = 1 + NumTagged<Tag, Ts...>::value;
541545
};
542546

543-
template <template <typename> class Tag, typename T, typename... Ts>
547+
template <template <typename> typename Tag, typename T, typename... Ts>
544548
struct NumTagged<Tag, T, Ts...> {
545549
static constexpr int64_t value = 0 + NumTagged<Tag, Ts...>::value;
546550
};
@@ -622,7 +626,7 @@ template <ExecutionStage stage, typename... Ts>
622626
class Binding {
623627
public:
624628
template <typename T>
625-
Binding<stage, Ts..., T> Arg() && {
629+
Binding<stage, Ts..., internal::ArgTag<T>> Arg() && {
626630
static_assert(!internal::HasOptionalArgTag<Ts...>::value,
627631
"argument can't be passed after optional argument");
628632
static_assert(!internal::HasRemainingArgsTag<Ts...>::value,
@@ -1159,7 +1163,10 @@ struct DecodingContext {
11591163
};
11601164

11611165
template <typename T>
1162-
struct Decode {
1166+
struct Decode;
1167+
1168+
template <typename T>
1169+
struct Decode<ArgTag<T>> {
11631170
XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
11641171
static std::optional<T> call(DecodingOffsets& offsets, DecodingContext& ctx,
11651172
DiagnosticEngine& diagnostic) {
@@ -1178,7 +1185,7 @@ struct Decode<OptionalArgTag<T>> {
11781185
if (XLA_FFI_PREDICT_FALSE(offsets.args >= ctx.call_frame->args.size)) {
11791186
return std::optional<T>(std::nullopt);
11801187
}
1181-
return Decode<T>::call(offsets, ctx, diagnostic);
1188+
return Decode<ArgTag<T>>::call(offsets, ctx, diagnostic);
11821189
}
11831190
};
11841191

@@ -1464,7 +1471,10 @@ class RemainingRets;
14641471
namespace internal {
14651472
// A helper struct to extract the type of the handler argument.
14661473
template <typename T>
1467-
struct FnArgType {
1474+
struct FnArgType;
1475+
1476+
template <typename T>
1477+
struct FnArgType<internal::ArgTag<T>> {
14681478
using Type = T;
14691479
};
14701480

@@ -1508,44 +1518,6 @@ struct FnArgType<internal::CtxTag<T>> {
15081518
using Type = typename CtxDecoding<T>::Type;
15091519
};
15101520

1511-
// A template for checking if type in a parameter pack is a tagged one and has
1512-
// a special decoding rule defined by template specialization.
1513-
template <typename>
1514-
struct IsTagged : std::false_type {};
1515-
1516-
template <typename T>
1517-
struct IsTagged<OptionalArgTag<T>> : std::true_type {};
1518-
template <typename T>
1519-
struct IsTagged<RetTag<T>> : std::true_type {};
1520-
template <typename T>
1521-
struct IsTagged<OptionalRetTag<T>> : std::true_type {};
1522-
template <typename T>
1523-
struct IsTagged<AttrTag<T>> : std::true_type {};
1524-
template <typename T>
1525-
struct IsTagged<AttrsTag<T>> : std::true_type {};
1526-
template <typename T>
1527-
struct IsTagged<CtxTag<T>> : std::true_type {};
1528-
1529-
template <>
1530-
struct IsTagged<RemainingArgsTag> : std::true_type {};
1531-
template <>
1532-
struct IsTagged<RemainingRetsTag> : std::true_type {};
1533-
1534-
// A template for counting regular arguments in the Ts pack (arguments that are
1535-
// not wrapped into a special tag).
1536-
template <typename... Ts>
1537-
struct NumArgs;
1538-
1539-
template <>
1540-
struct NumArgs<> {
1541-
static constexpr int64_t value = 0;
1542-
};
1543-
1544-
template <typename T, typename... Ts>
1545-
struct NumArgs<T, Ts...> {
1546-
static constexpr int64_t value = !IsTagged<T>::value + NumArgs<Ts...>::value;
1547-
};
1548-
15491521
// A template to detect result encodings that are state constructors. We use
15501522
// this to report back the TypeId of the state as a part of the metadata.
15511523
template <typename ResultEnconding, typename = void>
@@ -1574,7 +1546,8 @@ template <ExecutionStage stage, typename Fn, typename... Ts>
15741546
class Handler : public Ffi {
15751547
static constexpr int64_t kSize = sizeof...(Ts);
15761548

1577-
static constexpr int64_t kNumArgs = internal::NumArgs<Ts...>::value;
1549+
static constexpr int64_t kNumArgs =
1550+
internal::NumTagged<internal::ArgTag, Ts...>::value;
15781551

15791552
static constexpr int64_t kNumOptionalArgs =
15801553
internal::NumTagged<internal::OptionalArgTag, Ts...>::value;

xla/ffi/api/ffi_test.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ limitations under the License.
6161

6262
namespace xla::ffi {
6363

64+
using xla::ffi::internal::ArgTag;
65+
using xla::ffi::internal::NumTagged;
66+
using xla::ffi::internal::RetTag;
67+
68+
// Compile-time test for the template metaprogramming for counting tags.
69+
static_assert(NumTagged<ArgTag, RetTag<int32_t>>::value == 0);
70+
static_assert(NumTagged<ArgTag, ArgTag<int32_t>>::value == 1);
71+
static_assert(NumTagged<ArgTag, ArgTag<int32_t>, RetTag<int32_t>>::value == 1);
72+
6473
enum class Int32BasedEnum : int32_t {
6574
kOne = 1,
6675
kTwo = 2,

0 commit comments

Comments
 (0)