@@ -488,6 +488,10 @@ namespace internal {
488488// parameter packs. We need this to be able to pattern match FFI handler
489489// signature at compile time.
490490
491+ // A type tag for decoding argument.
492+ template <typename T>
493+ struct ArgTag {};
494+
491495// A type tag for decoding optional argument.
492496template <typename T>
493497struct OptionalArgTag {};
@@ -524,23 +528,23 @@ template <typename T>
524528struct CtxTag {};
525529
526530// ----------------------------------------------------------------------------//
527- // A template for counting tagged arguments in the Ts pack (i.e. attributes) .
531+ // A template for counting tagged arguments in the Ts pack.
528532// ----------------------------------------------------------------------------//
529533
530- template <template <typename > class Tag , typename ... Ts>
534+ template <template <typename > typename Tag, typename ... Ts>
531535struct NumTagged ;
532536
533- template <template <typename > class Tag >
537+ template <template <typename > typename Tag>
534538struct NumTagged <Tag> {
535539 static constexpr int64_t value = 0 ;
536540};
537541
538- template <template <typename > class Tag , typename T, typename ... Ts>
542+ template <template <typename > typename Tag, typename T, typename ... Ts>
539543struct NumTagged <Tag, Tag<T>, Ts...> {
540544 static constexpr int64_t value = 1 + NumTagged<Tag, Ts...>::value;
541545};
542546
543- template <template <typename > class Tag , typename T, typename ... Ts>
547+ template <template <typename > typename Tag, typename T, typename ... Ts>
544548struct NumTagged <Tag, T, Ts...> {
545549 static constexpr int64_t value = 0 + NumTagged<Tag, Ts...>::value;
546550};
@@ -622,7 +626,7 @@ template <ExecutionStage stage, typename... Ts>
622626class Binding {
623627 public:
624628 template <typename T>
625- Binding<stage, Ts..., T > Arg () && {
629+ Binding<stage, Ts..., internal::ArgTag<T> > Arg () && {
626630 static_assert (!internal::HasOptionalArgTag<Ts...>::value,
627631 " argument can't be passed after optional argument" );
628632 static_assert (!internal::HasRemainingArgsTag<Ts...>::value,
@@ -1159,7 +1163,10 @@ struct DecodingContext {
11591163};
11601164
11611165template <typename T>
1162- struct Decode {
1166+ struct Decode ;
1167+
1168+ template <typename T>
1169+ struct Decode <ArgTag<T>> {
11631170 XLA_FFI_ATTRIBUTE_ALWAYS_INLINE
11641171 static std::optional<T> call (DecodingOffsets& offsets, DecodingContext& ctx,
11651172 DiagnosticEngine& diagnostic) {
@@ -1178,7 +1185,7 @@ struct Decode<OptionalArgTag<T>> {
11781185 if (XLA_FFI_PREDICT_FALSE (offsets.args >= ctx.call_frame ->args .size )) {
11791186 return std::optional<T>(std::nullopt );
11801187 }
1181- return Decode<T >::call (offsets, ctx, diagnostic);
1188+ return Decode<ArgTag<T> >::call (offsets, ctx, diagnostic);
11821189 }
11831190};
11841191
@@ -1464,7 +1471,10 @@ class RemainingRets;
14641471namespace internal {
14651472// A helper struct to extract the type of the handler argument.
14661473template <typename T>
1467- struct FnArgType {
1474+ struct FnArgType ;
1475+
1476+ template <typename T>
1477+ struct FnArgType <internal::ArgTag<T>> {
14681478 using Type = T;
14691479};
14701480
@@ -1508,44 +1518,6 @@ struct FnArgType<internal::CtxTag<T>> {
15081518 using Type = typename CtxDecoding<T>::Type;
15091519};
15101520
1511- // A template for checking if type in a parameter pack is a tagged one and has
1512- // a special decoding rule defined by template specialization.
1513- template <typename >
1514- struct IsTagged : std::false_type {};
1515-
1516- template <typename T>
1517- struct IsTagged <OptionalArgTag<T>> : std::true_type {};
1518- template <typename T>
1519- struct IsTagged <RetTag<T>> : std::true_type {};
1520- template <typename T>
1521- struct IsTagged <OptionalRetTag<T>> : std::true_type {};
1522- template <typename T>
1523- struct IsTagged <AttrTag<T>> : std::true_type {};
1524- template <typename T>
1525- struct IsTagged <AttrsTag<T>> : std::true_type {};
1526- template <typename T>
1527- struct IsTagged <CtxTag<T>> : std::true_type {};
1528-
1529- template <>
1530- struct IsTagged <RemainingArgsTag> : std::true_type {};
1531- template <>
1532- struct IsTagged <RemainingRetsTag> : std::true_type {};
1533-
1534- // A template for counting regular arguments in the Ts pack (arguments that are
1535- // not wrapped into a special tag).
1536- template <typename ... Ts>
1537- struct NumArgs ;
1538-
1539- template <>
1540- struct NumArgs <> {
1541- static constexpr int64_t value = 0 ;
1542- };
1543-
1544- template <typename T, typename ... Ts>
1545- struct NumArgs <T, Ts...> {
1546- static constexpr int64_t value = !IsTagged<T>::value + NumArgs<Ts...>::value;
1547- };
1548-
15491521// A template to detect result encodings that are state constructors. We use
15501522// this to report back the TypeId of the state as a part of the metadata.
15511523template <typename ResultEnconding, typename = void >
@@ -1574,7 +1546,8 @@ template <ExecutionStage stage, typename Fn, typename... Ts>
15741546class Handler : public Ffi {
15751547 static constexpr int64_t kSize = sizeof ...(Ts);
15761548
1577- static constexpr int64_t kNumArgs = internal::NumArgs<Ts...>::value;
1549+ static constexpr int64_t kNumArgs =
1550+ internal::NumTagged<internal::ArgTag, Ts...>::value;
15781551
15791552 static constexpr int64_t kNumOptionalArgs =
15801553 internal::NumTagged<internal::OptionalArgTag, Ts...>::value;
0 commit comments