File: include/nbl/builtin/hlsl/bitonic_sort/common.hlsl (new file, 19 additions)
#ifndef _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_
#define _NBL_BUILTIN_HLSL_BITONIC_SORT_COMMON_INCLUDED_

#include <nbl/builtin/hlsl/cpp_compat.hlsl>
#include <nbl/builtin/hlsl/concepts.hlsl>
#include <nbl/builtin/hlsl/math/intutil.hlsl>

namespace nbl
{
namespace hlsl
{
namespace bitonic_sort
{

}
}
}

#endif
File: include/nbl/builtin/hlsl/subgroup/bitonic_sort.hlsl (new file, 136 additions)
#ifndef NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED
#define NBL_BUILTIN_HLSL_SUBGROUP_BITONIC_SORT_INCLUDED
#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_basic.hlsl"
#include "nbl/builtin/hlsl/glsl_compat/subgroup_shuffle.hlsl"
#include "nbl/builtin/hlsl/functional.hlsl"
namespace nbl
{
namespace hlsl
{
namespace subgroup
{

template<typename KeyType, typename ValueType, typename Comparator = less<KeyType> >
struct bitonic_sort_config
{
using key_t = KeyType;
using value_t = ValueType;
using comparator_t = Comparator;
};
template<bool Ascending, typename Config, class device_capabilities = void>
struct bitonic_sort;
template<bool Ascending, typename KeyType, typename ValueType, typename Comparator, class device_capabilities>
Review comment (Contributor):
I get that Ascending is used because when moving on to workgroup you're going to need to call alternating subgroup sorts. However, as a front-facing API, if I wanted a single subgroup sort I'd usually want it in the order specified by the Comparator. Maybe push it after the Config and give it a default value of true. Or better yet, since Ascending can be confusing, consider calling it ReverseOrder or something simpler that conveys the intent better.
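A sketch of the suggested signature (the ReverseOrder name is taken from the review and is hypothetical; defaulting it to false makes the Comparator's order the default):

template<typename Config, bool ReverseOrder = false, class device_capabilities = void>
struct bitonic_sort;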

Review comment (Contributor):
Ascending and later names like takeLarger implicitly assume the comparator is less (lo and hi don't; those relate to the "lane" order in the bitonic sort diagram). That's fine on its own, since it makes the code more readable versus more generic names. However, there should be comments mentioning that the names make this assumption implicitly, so there's no confusion.
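For example, a note along these lines above the struct (wording hypothetical):

// NOTE: names like Ascending and takeLarger implicitly assume Comparator behaves like less<KeyType>;
//       lo/hi refer to lane order in the bitonic network diagram, not to key order.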

struct bitonic_sort<Ascending, bitonic_sort_config<KeyType, ValueType, Comparator>, device_capabilities>
{
using config_t = bitonic_sort_config<KeyType, ValueType, Comparator>;
using key_t = typename config_t::key_t;
using value_t = typename config_t::value_t;
using comparator_t = typename config_t::comparator_t;
// Thread-level compare and swap (operates on lo/hi in registers)
static void compareAndSwap(bool ascending, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
comparator_t comp;
const bool shouldSwap = ascending ? comp(hiKey, loKey) : comp(loKey, hiKey);
Review comment (Contributor):
The compiler is probably dumb and might not realize the right term is the negation of the left term. Ternaries in SPIR-V usually get compiled to an OpSelect which treats both terms after the ? not as branches to conditionally execute, but as operands whose result must be evaluated before the select operation runs. That is to say, if the compiler is stupid you're going to run two comparisons. If you make the right term the negation of the left one, CSE is likely to kick in and evaluate the comparison only once.
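A sketch of the suggested rewrite. Note that a == ascending is the branchless form of ascending ? a : !a, so only one comparison runs; the one behavioral difference is that the descending case now also swaps equal keys, which the original did not (harmless for the sort, just an extra swap of equals):

const bool hiBeforeLo = comp(hiKey, loKey);
const bool shouldSwap = (hiBeforeLo == ascending);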

if (shouldSwap)
{
// Swap keys
key_t tempKey = loKey;
loKey = hiKey;
hiKey = tempKey;
// Swap values
value_t tempVal = loVal;
loVal = hiVal;
hiVal = tempVal;
}
Review comment (Contributor) on lines +39 to +49:
Make this branchless like you did the swaps in the subgroup branch
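A minimal branchless sketch for the swap block above; whole-value ternaries typically lower to OpSelect, so no divergent branch is needed:

const key_t newLoKey = shouldSwap ? hiKey : loKey;
const key_t newHiKey = shouldSwap ? loKey : hiKey;
const value_t newLoVal = shouldSwap ? hiVal : loVal;
const value_t newHiVal = shouldSwap ? loVal : hiVal;
loKey = newLoKey; hiKey = newHiKey;
loVal = newLoVal; hiVal = newHiVal;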

}


static void lastMergeStage(uint32_t stage, uint32_t invocationID, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
Review comment (Contributor):
In the end this is just mergeStage with bitonicAscending = true, right? I think you can just have mergeStage and avoid having this function duplicated

NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
[unroll]
for (uint32_t pass = 0; pass <= stage; pass++)
{
const uint32_t stride = 1u << (stage - pass); // Element stride
const uint32_t threadStride = stride >> 1;
if (threadStride == 0)
{
// Local compare and swap for stage 0
compareAndSwap(Ascending, loKey, hiKey, loVal, hiVal);
}
else
{
// Shuffle from partner using XOR
const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loKey, threadStride);
const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiKey, threadStride);
const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loVal, threadStride);
const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiVal, threadStride);
comparator_t comp;
if (comp(loKey, pLoKey)) { loKey = pLoKey; loVal = pLoVal; }
Review comment (Contributor):
unlike the other method, both threads keep the min elements? Like upperHalf is not being considered here, so I'm inclined to believe this function is going to fail. I'd delete this method and just use mergeStage, since this is just that method but with a forced bitonicAscending = true.

Review comment (Contributor):
Test this code to make sure it's right, but it feels wrong. Either way, just use mergeStage and avoid having this duped.

if (comp(hiKey, pHiKey)) { hiKey = pHiKey; hiVal = pHiVal; }

}

}
}
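Per the two comments above, a hedged sketch of how lastMergeStage could simply forward to mergeStage instead of duplicating it (this assumes mergeStage is visible at this point in the struct, and that the final merge direction should follow the struct's Ascending parameter):

static void lastMergeStage(uint32_t stage, uint32_t invocationID, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
    NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
    // Same network as mergeStage, just with the direction pinned for the final merge
    mergeStage(stage, Ascending, invocationID, loKey, hiKey, loVal, hiVal);
}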

static void mergeStage(uint32_t stage, bool bitonicAscending, uint32_t invocationID, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
[unroll]
for (uint32_t pass = 0; pass <= stage; pass++)
{
const uint32_t stride = 1u << (stage - pass); // Element stride
const uint32_t threadStride = stride >> 1;
if (threadStride == 0)
{
// Local compare and swap for stage 0
compareAndSwap(bitonicAscending, loKey, hiKey, loVal, hiVal);
}
else
{
// Shuffle from partner using XOR
const key_t pLoKey = glsl::subgroupShuffleXor<key_t>(loKey, threadStride);
const key_t pHiKey = glsl::subgroupShuffleXor<key_t>(hiKey, threadStride);
const value_t pLoVal = glsl::subgroupShuffleXor<value_t>(loVal, threadStride);
const value_t pHiVal = glsl::subgroupShuffleXor<value_t>(hiVal, threadStride);
// Determine if we're upper or lower half
const bool upperHalf = bool(invocationID & threadStride);
const bool takeLarger = upperHalf == bitonicAscending;
comparator_t comp;
if (takeLarger)
{
if (comp(loKey, pLoKey)) { loKey = pLoKey; loVal = pLoVal; }
if (comp(hiKey, pHiKey)) { hiKey = pHiKey; hiVal = pHiVal; }
}
else
{
if (comp(pLoKey, loKey)) { loKey = pLoKey; loVal = pLoVal; }
if (comp(pHiKey, hiKey)) { hiKey = pHiKey; hiVal = pHiVal; }
}
}
}
}

static void __call(NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
const uint32_t invocationID = glsl::gl_SubgroupInvocationID();
const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();
[unroll]
for (uint32_t stage = 0; stage < subgroupSizeLog2; stage++)
{
const bool bitonicAscending = (stage == subgroupSizeLog2) ? Ascending : !bool(invocationID & (1u << stage));
Review comment (Contributor):
stage == subgroupSizeLog2 is never true in this loop, so just assign the term from the false clause.
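The simplified sketch, keeping only the false clause (the loop condition is stage < subgroupSizeLog2, so the ternary's true arm can never fire):

const bool bitonicAscending = !bool(invocationID & (1u << stage));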

mergeStage(stage, bitonicAscending, invocationID, loKey, hiKey, loVal, hiVal);
}
lastMergeStage(subgroupSizeLog2, invocationID, loKey, hiKey, loVal, hiVal);

}
};

}
}
}
#endif
File: include/nbl/builtin/hlsl/workgroup/bitonic_sort.hlsl (new file, 74 additions)
#ifndef NBL_BUILTIN_HLSL_WORKGROUP_BITONIC_SORT_INCLUDED
#define NBL_BUILTIN_HLSL_WORKGROUP_BITONIC_SORT_INCLUDED
#include "nbl/builtin/hlsl/bitonic_sort/common.hlsl"
#include "nbl/builtin/hlsl/functional.hlsl"
namespace nbl
{
namespace hlsl
{
namespace workgroup
{
namespace bitonic_sort
{
template<typename KeyType, typename ValueType, typename Comparator = less<KeyType> >
struct bitonic_sort_config
{
using key_t = KeyType;
using value_t = ValueType;
using comparator_t = Comparator;
};

template<bool Ascending, typename Config, class device_capabilities = void>
struct bitonic_sort;
template<bool Ascending, typename KeyType, typename ValueType, typename Comparator, class device_capabilities>
struct bitonic_sort<Ascending, bitonic_sort_config<KeyType, ValueType, Comparator>, device_capabilities>
{
using config_t = bitonic_sort_config<KeyType, ValueType, Comparator>;
using key_t = typename config_t::key_t;
using value_t = typename config_t::value_t;
using comparator_t = typename config_t::comparator_t;

using SortConfig = subgroup::bitonic_sort_config<key_t, value_t, comparator_t>;


static void mergeWGStage(uint32_t stage, bool bitonicAscending, uint32_t invocationID, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
Review comment (Contributor):
This is already in the workgroup namespace, you can just call it mergeStage

NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
[unroll]
for (uint32_t pass = 0; pass <= stage; pass++)
{
const uint32_t stride = 1u << ((stage - pass) + subgroupSizeLog2); // Element stride, scaled up to inter-subgroup distances
// TODO: shuffle from partner using a workgroup XOR shuffle (not yet implemented)
Review comment (Contributor):
We already have a workgroup shuffle:

void shuffleXor(NBL_REF_ARG(T) value, uint32_t mask, NBL_REF_ARG(SharedMemoryAdaptor) sharedmemAdaptor)

You need to template this bitonic sort workgroup struct on a shared memory accessor. You can either do the shuffle in two rounds (shuffle keys -> barrier -> shuffle values -> barrier) or in a single round, by shuffling a pair of both key and value together. Ideally this would be a setting you can choose as well. For now settle on one and leave a comment on how we can also consider the other case down the line (two rounds is probably easier at this stage, and useful for the bigger array sizes)


}
}
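A hedged sketch of the two-round approach the reviewer describes. It assumes this struct gains a SharedMemoryAdaptor template parameter and receives an instance named adaptor (both assumptions, not part of this PR); the shuffleXor signature is the one quoted above, and the explicit barrier placement is an assumption that depends on whether shuffleXor synchronizes internally:

// Round 1: keys. Copies are kept because the merge needs both local and partner values.
key_t pLoKey = loKey;
key_t pHiKey = hiKey;
shuffleXor(pLoKey, threadStride, adaptor); // threadStride = XOR mask in thread space
shuffleXor(pHiKey, threadStride, adaptor);
adaptor.workgroupExecutionAndMemoryBarrier(); // assumed barrier helper; otherwise glsl::barrier()
// Round 2: values, reusing the same shared memory.
value_t pLoVal = loVal;
value_t pHiVal = hiVal;
shuffleXor(pLoVal, threadStride, adaptor);
shuffleXor(pHiVal, threadStride, adaptor);
adaptor.workgroupExecutionAndMemoryBarrier();
// ...then compare-and-select exactly as in the subgroup mergeStage.
// TODO (per review): the single-round alternative shuffles a packed (key, value) pair instead.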


static void __call(NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal)
{
const uint32_t invocationID = glsl::gl_SubgroupInvocationID();
const uint32_t subgroupSizeLog2 = glsl::gl_SubgroupSizeLog2();

//first sort all subgroup inside wg
subgroup::bitonic_sort<true, SortConfig>::__call(loKey, hiKey, loVal, hiVal);
Review comment (Contributor):
Ascending should be a parameter you pass to __call and not a template parameter. That way, you can control whether this starting subgroup sort is ascending or descending (notice that whenever you do a bigger-than-subgroup sort, some subgroup sorts are descending and some ascending, depending on the parity of the subgroupID). This is a condition you can't control at compile time, since the subgroupID is only known at runtime.

Review comment (Contributor):
This is also true of the workgroup struct, btw. When you try to do virtual threading down the line, you will have to do the same thing based on a "virtual workgroup ID", which will only be known at runtime. So nbl::hlsl::workgroup::bitonic_sort::bitonic_sort should not be templated on Ascending either; it should instead be a parameter you pass to __call.
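A hedged sketch of the API both comments point at (the signature and the caller-side names are hypothetical):

static void __call(bool ascending, NBL_REF_ARG(key_t) loKey, NBL_REF_ARG(key_t) hiKey,
    NBL_REF_ARG(value_t) loVal, NBL_REF_ARG(value_t) hiVal);

// The caller can then alternate direction per subgroup at runtime:
const bool subgroupAscending = !bool(glsl::gl_SubgroupID() & 1u);
subgroup::bitonic_sort<SortConfig>::__call(subgroupAscending, loKey, hiKey, loVal, hiVal);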

Review comment (Contributor):
On that note, nbl::hlsl::workgroup::bitonic_sort::bitonic_sort is an ugly namespace. The struct you use to run a bitonic sort should be nbl::hlsl::workgroup::BitonicSort (and similarly for subgroup). nbl::hlsl::workgroup::bitonic_sort should have structs related to the bitonic sort, but not the functional struct itself.

In the FFT, for example, we have the struct nbl::hlsl::workgroup::FFT to run the FFT, and nbl::hlsl::workgroup::fft is a namespace that has structs useful for running an FFT, such as the config struct it's templated on.
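A sketch of that layout applied here (names per the review; everything else unchanged):

namespace workgroup
{
namespace bitonic_sort
{
// configs and helper structs only, e.g. bitonic_sort_config
}
template<typename Config, class device_capabilities = void>
struct BitonicSort; // the functional struct lives directly in workgroup::
}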

//then we go over first work group shuffle
//we have n = log2(x), where n is how many wgshuffle we have to do on x(subgroup num)

[unroll]
for (uint32_t stage = 1; stage <= n; ++stage)
{
mergeWGStage(stage, Ascending, invocationID, loKey, hiKey, loVal, hiVal);
subgroup::bitonic_sort<true, SortConfig>::lastMergeStage(subgroupSizeLog2, invocationID, loKey, hiKey, loVal, hiVal);
workgroupExecutionAndMemoryBarrier();
}


}
};

}
}
}
}
#endif