single pass scan group virtualization #2041

Merged
4 commits, merged Nov 17, 2023
Changes from 1 commit
single pass scan group virtualization
sortraev committed Nov 3, 2023
commit 6c280fd6463667f06ee531c1d9a1afe70daef84a
52 changes: 21 additions & 31 deletions src/Futhark/CodeGen/ImpGen/GPU/SegScan.hs
@@ -14,8 +14,8 @@ import Futhark.IR.GPUMem

-- The single-pass scan does not support multiple operators, so jam
-- them together here.
combineScans :: [SegBinOp GPUMem] -> SegBinOp GPUMem
combineScans ops =
combineScanOps :: [SegBinOp GPUMem] -> SegBinOp GPUMem
combineScanOps ops =
SegBinOp
{ segBinOpComm = mconcat (map segBinOpComm ops),
segBinOpLambda = lam',
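
Note: the "jamming" is just componentwise combination of the operators. A minimal plain-Haskell sketch of the idea (names invented here, not from the patch):

combineTwo :: (a -> a -> a) -> (b -> b -> b) -> (a, b) -> (a, b) -> (a, b)
combineTwo f g (x1, y1) (x2, y2) = (f x1 x2, g y1 y2)

-- e.g. scanl1 (combineTwo (+) max) [(x, x) | x <- xs] runs a (+)-scan
-- and a max-scan of xs in a single pass.
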
@@ -48,30 +48,18 @@ bodyHas f = any (f' . stmExp) . bodyStms
{ walkOnBody = const $ guard . not . bodyHas f
}

canBeSinglePass :: [SegBinOp GPUMem] -> KernelBody GPUMem -> Maybe (SegBinOp GPUMem)
canBeSinglePass ops kbody
| all ok ops,
not $ bodyHas freshArray (Body () (kernelBodyStms kbody) []) =
Just $ combineScans ops
| otherwise =
Nothing
canBeSinglePass :: [SegBinOp GPUMem] -> Maybe (SegBinOp GPUMem)
canBeSinglePass scan_ops =
if all ok scan_ops
then Just $ combineScanOps scan_ops
else Nothing
where
ok op =
segBinOpShape op == mempty
&& all primType (lambdaReturnType (segBinOpLambda op))
&& not (bodyHas isAssert (lambdaBody (segBinOpLambda op)))
isAssert (BasicOp Assert {}) = True
isAssert _ = False
-- XXX: Currently single pass scans cannot handle construction of
-- arrays in the kernel body (#2013), because of insufficient
-- memory expansion. This can in principle be fixed.
freshArray (BasicOp Manifest {}) = True
freshArray (BasicOp Iota {}) = True
freshArray (BasicOp Replicate {}) = True
freshArray (BasicOp Scratch {}) = True
freshArray (BasicOp Concat {}) = True
freshArray (BasicOp ArrayLit {}) = True
freshArray _ = False
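
Note: in list terms, the ok predicate admits exactly the scalar-only operators. A hedged illustration (plain Haskell, not Futhark IR):

-- Qualifies: single primitive accumulator.
qualifies1 :: [Integer] -> [Integer]
qualifies1 = scanl1 (+)

-- Qualifies: tuple of primitives (empty segBinOpShape, prim return types).
qualifies2 :: [(Integer, Integer)] -> [(Integer, Integer)]
qualifies2 = scanl1 (\(a, b) (c, d) -> (a + c, b * d))

-- An operator carrying an array-shaped accumulator, or one whose body
-- contains an assert, falls back to the two-pass kernel.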

-- | Compile 'SegScan' instance to host-level code with calls to
-- various kernels.
@@ -82,17 +70,19 @@ compileSegScan ::
[SegBinOp GPUMem] ->
KernelBody GPUMem ->
CallKernelGen ()
compileSegScan pat lvl space scans kbody = sWhen (0 .<. n) $ do
emit $ Imp.DebugPrint "\n# SegScan" Nothing
target <- hostTarget <$> askEnv
case target of
CUDA
| Just scan' <- canBeSinglePass scans kbody ->
SinglePass.compileSegScan pat lvl space scan' kbody
HIP
| Just scan' <- canBeSinglePass scans kbody ->
SinglePass.compileSegScan pat lvl space scan' kbody
_ -> TwoPass.compileSegScan pat lvl space scans kbody
emit $ Imp.DebugPrint "" Nothing
compileSegScan pat lvl space scan_ops map_kbody =
sWhen (0 .<. n) $ do
emit $ Imp.DebugPrint "\n# SegScan" Nothing
target <- hostTarget <$> askEnv

case (targetSupportsSinglePass target, canBeSinglePass scan_ops) of
(True, Just scan_ops') ->
SinglePass.compileSegScan pat lvl space scan_ops' map_kbody
_ ->
TwoPass.compileSegScan pat lvl space scan_ops map_kbody
where
n = product $ map pe64 $ segSpaceDims space
targetSupportsSinglePass HIP = True
targetSupportsSinglePass CUDA = True
targetSupportsSinglePass _ = False

820 changes: 422 additions & 398 deletions src/Futhark/CodeGen/ImpGen/GPU/SegScan/SinglePass.hs
@@ -6,7 +6,7 @@
module Futhark.CodeGen.ImpGen.GPU.SegScan.SinglePass (compileSegScan) where

import Control.Monad
import Data.List (zip4)
import Data.List (zip4, zip7)
import Data.Maybe
import Futhark.CodeGen.ImpCode.GPU qualified as Imp
import Futhark.CodeGen.ImpGen
@@ -18,6 +18,7 @@ import Futhark.Util (mapAccumLM, takeLast)
import Futhark.Util.IntegralExp (IntegralExp (mod, rem), divUp, quot)
import Prelude hiding (mod, quot, rem)


xParams, yParams :: SegBinOp GPUMem -> [LParam GPUMem]
xParams scan =
take (length (segBinOpNeutral scan)) (lambdaParams (segBinOpLambda scan))
@@ -32,9 +33,9 @@ createLocalArrays ::
SubExp ->
[PrimType] ->
InKernelGen (VName, [VName], [VName], VName, [VName])
createLocalArrays (Count groupSize) m types = do
createLocalArrays (Count groupSize) chunk types = do
let groupSizeE = pe64 groupSize
workSize = pe64 m * groupSizeE
workSize = pe64 chunk * groupSizeE
prefixArraysSize =
foldl (\acc tySize -> alignTo acc tySize + tySize * groupSizeE) 0 $
map primByteSize types
@@ -67,7 +68,7 @@ createLocalArrays (Count groupSize) m types = do
warpSize
$ map primByteSize types

sComment "Allocate reused shared memeory" $ pure ()
sComment "Allocate reusable shared memory" $ pure ()

localMem <- sAlloc "local_mem" size (Space "local")
transposeArrayLength <- dPrimV "trans_arr_len" workSize
@@ -105,6 +106,12 @@ createLocalArrays (Count groupSize) m types = do

pure (sharedId, transposedArrays, prefixArrays, warpscan, warpExchanges)

statusX, statusA, statusP :: Num a => a
statusX = 0
statusA = 1
statusP = 2
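
Note: these are the block statuses of Merrill and Garland's decoupled-lookback single-pass scan. A hedged CPU-side model of how a lookback consumes them, reusing the three constants above (the list representation is invented for illustration):

import Data.Int (Int8, Int64)

-- Walk from the nearest predecessor backwards, accumulating statusA
-- aggregates until a statusP inclusive prefix is found. A statusX entry
-- means that predecessor has published nothing yet; the GPU kernel
-- spins in that case, here we just give up.
lookback :: [(Int8, Int64)] -> Maybe Int64
lookback preds = go 0 (reverse preds)
  where
    go acc ((s, v) : rest)
      | s == statusP = Just (acc + v)
      | s == statusA = go (acc + v) rest
      | otherwise = Nothing
    go acc [] = Just acc -- no predecessors: prefix is the neutral element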


inBlockScanLookback ::
KernelConstants ->
Imp.TExp Int64 ->
@@ -118,8 +125,8 @@ inBlockScanLookback constants arrs_full_size flag_arr arrs scan_lam = everything
let flg_param_x = Param mempty (tvVar flg_x) (MemPrim p_int8)
flg_param_y = Param mempty (tvVar flg_y) (MemPrim p_int8)
flg_y_exp = tvExp flg_y
statusP = (2 :: Imp.TExp Int8)
statusX = (0 :: Imp.TExp Int8)
statusP_e = statusP :: Imp.TExp Int8
statusX_e = statusX :: Imp.TExp Int8

dLParams (lambdaParams scan_lam)

@@ -149,7 +156,7 @@ inBlockScanLookback constants arrs_full_size flag_arr arrs scan_lam = everything

let op_to_x = do
sIf
(flg_y_exp .==. statusP .||. flg_y_exp .==. statusX)
(flg_y_exp .==. statusP_e .||. flg_y_exp .==. statusX_e)
( do
y_to_x_flg
y_to_x
@@ -194,22 +201,23 @@ inBlockScanLookback constants arrs_full_size flag_arr arrs scan_lam = everything

readInitial p arr
| primType $ paramType p =
copyDWIM (paramName p) [] (Var arr) [DimFix ltid]
copyDWIMFix (paramName p) [] (Var arr) [ltid]
| otherwise =
copyDWIM (paramName p) [] (Var arr) [DimFix gtid]
copyDWIMFix (paramName p) [] (Var arr) [gtid]
readParam behind p arr
| primType $ paramType p =
copyDWIM (paramName p) [] (Var arr) [DimFix $ ltid - behind]
copyDWIMFix (paramName p) [] (Var arr) [ltid - behind]
| otherwise =
copyDWIM (paramName p) [] (Var arr) [DimFix $ gtid - behind + arrs_full_size]
copyDWIMFix (paramName p) [] (Var arr) [gtid - behind + arrs_full_size]

writeResult x y arr
| primType $ paramType x = do
copyDWIM arr [DimFix ltid] (Var $ paramName x) []
copyDWIMFix arr [ltid] (Var $ paramName x) []
copyDWIM (paramName y) [] (Var $ paramName x) []
| otherwise =
copyDWIM (paramName y) [] (Var $ paramName x) []
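
Note: the copyDWIM-to-copyDWIMFix rewrites here and below are mechanical; as far as I can tell, copyDWIMFix is simply copyDWIM with every index pre-wrapped, roughly:

-- copyDWIMFix dest dest_is src src_is =
--   copyDWIM dest (map DimFix dest_is) src (map DimFix src_is)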


-- | Compile 'SegScan' instance to host-level code with calls to a
-- single-pass kernel.
compileSegScan ::
@@ -219,431 +227,447 @@ compileSegScan ::
SegBinOp GPUMem ->
KernelBody GPUMem ->
CallKernelGen ()
compileSegScan pat lvl space scanOp kbody = do
compileSegScan pat lvl space scan_op map_kbody = do
attrs <- lvlKernelAttrs lvl
let Pat all_pes = pat
scanOpNe = segBinOpNeutral scanOp
tys = map (\(Prim pt) -> pt) $ lambdaReturnType $ segBinOpLambda scanOp

scanop_nes = segBinOpNeutral scan_op

n = product $ map pe64 $ segSpaceDims space
sumT :: Integer
maxT :: Integer
sumT = foldl (\bytes typ -> bytes + primByteSize typ) 0 tys
primByteSize' = max 4 . primByteSize
sumT' = foldl (\bytes typ -> bytes + primByteSize' typ) 0 tys `div` 4
maxT = maximum (map primByteSize tys)
m :: (Num a) => a
m = fromIntegral $ max 1 $ min mem_constraint reg_constraint

tys = map (\(Prim pt) -> pt) $ lambdaReturnType $ segBinOpLambda scan_op
tys_sizes = map primByteSize tys

sumT, maxT :: Integer
sumT = sum tys_sizes
sumT' = (sum $ map (max 4 . primByteSize) tys) `div` 4
maxT = maximum tys_sizes

-- TODO: Make these constants dynamic by querying device
k_reg = 64
k_mem = 95

mem_constraint = max k_mem sumT `div` maxT
reg_constraint = (k_reg - 1 - sumT') `div` (2 * sumT')

group_size = kAttrGroupSize attrs
group_size' = pe64 $ unCount group_size
chunk :: (Num a) => a
chunk = fromIntegral $ max 1 $ min mem_constraint reg_constraint
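
Note: a hedged standalone rendering of the constraint arithmetic above, convenient for sanity-checking chunk sizes (the argument is the list of primByteSize values of the scanned types; k_reg = 64 and k_mem = 95 hard-coded as in the patch):

chunkFor :: [Integer] -> Integer
chunkFor sizes = max 1 (min mem_constraint reg_constraint)
  where
    sumT' = sum (map (max 4) sizes) `div` 4
    maxT = maximum sizes
    mem_constraint = max 95 (sum sizes) `div` maxT
    reg_constraint = (64 - 1 - sumT') `div` (2 * sumT')

-- chunkFor [4] == 23 (a single i32 scan); chunkFor [8] == 11 (i64).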

num_groups <-
Count . tvSize <$> dPrimV "num_groups" (n `divUp` (group_size' * m))
let num_groups' = pe64 (unCount num_groups)
group_size_e = pe64 $ unCount $ kAttrGroupSize attrs
num_physgroups_e = pe64 $ unCount $ kAttrNumGroups attrs

num_threads <-
dPrimVE "num_threads" $ num_groups' * group_size'
num_virtgroups <-
tvSize <$> dPrimV "num_virtgroups" (n `divUp` (group_size_e * chunk))
let num_virtgroups_e = pe64 num_virtgroups

-- TODO: what is num_threads, and should it be dependent on number of physical
-- or virtual groups?
num_threads <- dPrimVE "num_threads" $ num_virtgroups_e * group_size_e

let (gtids, dims) = unzip $ unSegSpace space
dims' = map pe64 dims
segmented = length dims' > 1
not_segmented_e = if segmented then false else true
not_segmented_e = fromBool $ not segmented
segment_size = last dims'

statusX, statusA, statusP :: (Num a) => a
statusX = 0
statusA = 1
statusP = 2

emit $ Imp.DebugPrint "Sequential elements per thread (m) " $ Just $ untyped (m :: Imp.TExp Int32)
emit $ Imp.DebugPrint "Memory constraint" $ Just $ untyped (fromIntegral mem_constraint :: Imp.TExp Int32)
emit $ Imp.DebugPrint "Register constraint" $ Just $ untyped (fromIntegral reg_constraint :: Imp.TExp Int32)
emit $ Imp.DebugPrint "sumT'" $ Just $ untyped (fromIntegral sumT' :: Imp.TExp Int32)
let debug_ s v = emit $ Imp.DebugPrint s $ Just $ untyped (v :: Imp.TExp Int32)
debug_ "Sequential elements per thread (chunk) " chunk
debug_ "Memory constraint" $ fromIntegral mem_constraint
debug_ "Register constraint" $ fromIntegral reg_constraint
debug_ "sumT'" $ fromIntegral sumT'

globalId <- genZeroes "id_counter" 1
statusFlags <- sAllocArray "status_flags" int8 (Shape [unCount num_groups]) (Space "device")
statusFlags <- sAllocArray "status_flags" int8 (Shape [num_virtgroups]) (Space "device")
(aggregateArrays, incprefixArrays) <-
fmap unzip $
forM tys $ \ty ->
(,)
<$> sAllocArray "aggregates" ty (Shape [unCount num_groups]) (Space "device")
<*> sAllocArray "incprefixes" ty (Shape [unCount num_groups]) (Space "device")
<$> sAllocArray "aggregates" ty (Shape [num_virtgroups]) (Space "device")
<*> sAllocArray "incprefixes" ty (Shape [num_virtgroups]) (Space "device")

sReplicate statusFlags $ intConst Int8 statusX
global_id <- genZeroes "global_dynid" 1

sKernelThread "segscan" (segFlat space) attrs $ do

sKernelThread "segscan" (segFlat space) (defKernelAttrs num_groups group_size) $ do
constants <- kernelConstants <$> askEnv

(sharedId, transposedArrays, prefixArrays, warpscan, exchanges) <-
createLocalArrays (kAttrGroupSize attrs) (intConst Int64 m) tys

dynamicId <- dPrim "dynamic_id" int32
sWhen (kernelLocalThreadId constants .==. 0) $ do
(globalIdMem, _, globalIdOff) <- fullyIndexArray globalId [0]
sOp $
Imp.Atomic DefaultSpace $
Imp.AtomicAdd
Int32
(tvVar dynamicId)
globalIdMem
(Count $ unCount globalIdOff)
(untyped (1 :: Imp.TExp Int32))
copyDWIMFix sharedId [0] (tvSize dynamicId) []

let localBarrier = Imp.Barrier Imp.FenceLocal
localFence = Imp.MemFence Imp.FenceLocal
globalFence = Imp.MemFence Imp.FenceGlobal

sOp localBarrier
copyDWIMFix (tvVar dynamicId) [] (Var sharedId) [0]
sOp localBarrier

blockOff <-
dPrimV "blockOff" $
sExt64 (tvExp dynamicId) * m * kernelGroupSize constants
sgmIdx <- dPrimVE "sgm_idx" $ tvExp blockOff `mod` segment_size
boundary <-
dPrimVE "boundary" $
sExt32 $
sMin64 (m * group_size') (segment_size - sgmIdx)
segsize_compact <-
dPrimVE "segsize_compact" $
sExt32 $
sMin64 (m * group_size') segment_size
privateArrays <-
forM tys $ \ty ->
sAllocArray
"private"
ty
(Shape [intConst Int64 m])
(ScalarSpace [intConst Int64 m] ty)

sComment "Load and map" $
sFor "i" m $ \i -> do
-- The map's input index
phys_tid <-
dPrimVE "phys_tid" $
tvExp blockOff
+ sExt64 (kernelLocalThreadId constants)
+ i * kernelGroupSize constants
dIndexSpace (zip gtids dims') phys_tid
-- Perform the map
let in_bounds =
compileStms mempty (kernelBodyStms kbody) $ do
let (all_scan_res, map_res) = splitAt (segBinOpResults [scanOp]) $ kernelBodyResult kbody

-- Write map results to their global memory destinations
forM_ (zip (takeLast (length map_res) all_pes) map_res) $ \(dest, src) ->
copyDWIMFix (patElemName dest) (map Imp.le64 gtids) (kernelResultSubExp src) []

-- Write to-scan results to private memory.
forM_ (zip privateArrays $ map kernelResultSubExp all_scan_res) $ \(dest, src) ->
copyDWIMFix dest [i] src []

out_of_bounds =
forM_ (zip privateArrays scanOpNe) $ \(dest, ne) ->
copyDWIMFix dest [i] ne []

sIf (phys_tid .<. n) in_bounds out_of_bounds

sOp $ Imp.ErrorSync Imp.FenceLocal
sComment "Transpose scan inputs" $ do
forM_ (zip transposedArrays privateArrays) $ \(trans, priv) -> do
sFor "i" m $ \i -> do
sharedIdx <-
dPrimVE "sharedIdx" $
sExt64 (kernelLocalThreadId constants)
-- TODO: we would use virtualiseGroups instead of the below couple of lines,
-- but it adds a redundant barrier. why?
physgroup_id <- dPrim "physgroup_id" int32
sOp $ Imp.GetGroupId (tvVar physgroup_id) 0
iters <- dPrimVE "virtloop_bound" $ (num_virtgroups_e - tvExp physgroup_id)
`divUp` num_physgroups_e
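
Note: the iteration count can be sanity-checked in plain Haskell; physical group g runs virtual groups g, g + p, g + 2p, ..., so across all p physical groups every virtual group is covered exactly once (the counts below are invented):

itersFor :: Integer -> Integer -> Integer -> Integer
itersFor v p g = (v - g + p - 1) `div` p -- divUp (v - g) p

-- sum [itersFor 10 4 g | g <- [0 .. 3]] == 10
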
sFor "virtloop_i" iters $ const $ do

(sharedId, transposedArrays, prefixArrays, warpscan, exchanges) <-
createLocalArrays (kAttrGroupSize attrs) (intConst Int64 chunk) tys

dyn_id <- dPrim "dynamic_id" int32
sComment "First thread in block fetches this block's dynamic_id" $ do
sWhen (kernelLocalThreadId constants .==. 0) $ do
(globalIdMem, _, globalIdOff) <- fullyIndexArray global_id [0]
sOp $ Imp.Atomic DefaultSpace $
Imp.AtomicAdd
Int32
(tvVar dyn_id)
globalIdMem
(Count $ unCount globalIdOff)
(untyped (1 :: Imp.TExp Int32))
sComment "Set dynamic id and reset status flag for this block" $ do
copyDWIMFix sharedId [0] (tvSize dyn_id) []
copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusX) []

sComment "First thread in last (virtual) block resets global dynamic_id" $ do
sWhen (tvExp dyn_id .==. num_virtgroups_e - 1) $
copyDWIMFix global_id [0] (intConst Int32 0) []
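
Note: a hedged CPU model of the dynamic-id scheme above (names invented): each group's leader fetch-and-adds a global counter, so ids reflect actual execution order, and whichever group draws the final id resets the counter for the next launch.

import Data.IORef
import Control.Monad (when)

fetchDynId :: IORef Int -> Int -> IO Int
fetchDynId counter numVirtgroups = do
  -- atomic fetch-and-add, like the AtomicAdd above
  i <- atomicModifyIORef' counter (\n -> (n + 1, n))
  when (i == numVirtgroups - 1) $ writeIORef counter 0
  pure i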

let local_barrier = Imp.Barrier Imp.FenceLocal
local_fence = Imp.MemFence Imp.FenceLocal
global_fence = Imp.MemFence Imp.FenceGlobal

sOp local_barrier
copyDWIMFix (tvVar dyn_id) [] (Var sharedId) [0]
sOp local_barrier -- TODO: necessary? don't think so, but ignore it for now.

blockOff <-
dPrimV "blockOff" $
sExt64 (tvExp dyn_id) * chunk * group_size_e -- kernelGroupSize constants
sgmIdx <- dPrimVE "sgm_idx" $ tvExp blockOff `mod` segment_size
boundary <-
dPrimVE "boundary" $
sExt32 $
sMin64 (chunk * group_size_e) (segment_size - sgmIdx)
segsize_compact <-
dPrimVE "segsize_compact" $
sExt32 $
sMin64 (chunk * group_size_e) segment_size
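
Note: in a segmented scan, boundary is the number of elements this block processes before hitting the first segment boundary. A hedged worked example:

boundaryFor :: Integer -> Integer -> Integer -> Integer
boundaryFor block_elems segment_size sgm_idx =
  min block_elems (segment_size - sgm_idx)

-- e.g. a block of chunk * group_size = 8 elements starting at flat
-- offset 6 with segment_size = 5: sgm_idx = 6 `mod` 5 = 1, and
-- boundaryFor 8 5 1 == 4 elements remain before the next boundary.
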
privateArrays <-
forM tys $ \ty ->
sAllocArray
"private"
ty
(Shape [intConst Int64 chunk])
(ScalarSpace [intConst Int64 chunk] ty)

sComment "Load and map" $
sFor "i" chunk $ \i -> do
-- The map's input index
phys_tid <-
dPrimVE "phys_tid" $
tvExp blockOff
+ sExt64 (kernelLocalThreadId constants)
+ i * kernelGroupSize constants
copyDWIMFix trans [sharedIdx] (Var priv) [i]
sOp localBarrier
sFor "i" m $ \i -> do
sharedIdx <- dPrimV "sharedIdx" $ kernelLocalThreadId constants * m + i
copyDWIMFix priv [sExt64 i] (Var trans) [sExt64 $ tvExp sharedIdx]
sOp localBarrier

sComment "Per thread scan" $ do
-- We don't need to touch the first element, so only m-1
-- iterations here.
globalIdx <-
dPrimVE "gidx" $
(kernelLocalThreadId constants * m) + 1
sFor "i" (m - 1) $ \i -> do
let xs = map paramName $ xParams scanOp
ys = map paramName $ yParams scanOp
-- determine if start of segment
new_sgm <-
if segmented
then dPrimVE "new_sgm" $ (globalIdx + sExt32 i - boundary) `mod` segsize_compact .==. 0
else pure false
-- skip scan of first element in segment
sUnless new_sgm $ do
forM_ (zip privateArrays $ zip3 xs ys tys) $ \(src, (x, y, ty)) -> do
dPrim_ x ty
dPrim_ y ty
copyDWIMFix x [] (Var src) [i]
copyDWIMFix y [] (Var src) [i + 1]

compileStms mempty (bodyStms $ lambdaBody $ segBinOpLambda scanOp) $
forM_ (zip privateArrays $ map resSubExp $ bodyResult $ lambdaBody $ segBinOpLambda scanOp) $ \(dest, res) ->
copyDWIMFix dest [i + 1] res []

sComment "Publish results in shared memory" $ do
forM_ (zip prefixArrays privateArrays) $ \(dest, src) ->
copyDWIMFix dest [sExt64 $ kernelLocalThreadId constants] (Var src) [m - 1]
sOp localBarrier

let crossesSegment = do
guard segmented
Just $ \from to ->
let from' = (from + 1) * m - 1
to' = (to + 1) * m - 1
in (to' - from') .>. (to' + segsize_compact - boundary) `mod` segsize_compact

scanOp' <- renameLambda $ segBinOpLambda scanOp

accs <- mapM (dPrim "acc") tys
sComment "Scan results (with warp scan)" $ do
groupScan
crossesSegment
num_threads
(kernelGroupSize constants)
scanOp'
prefixArrays
dIndexSpace (zip gtids dims') phys_tid
-- Perform the map
let in_bounds =
compileStms mempty (kernelBodyStms map_kbody) $ do
let (all_scan_res, map_res) =
splitAt (segBinOpResults [scan_op]) $ kernelBodyResult map_kbody

-- Write map results to their global memory destinations
forM_ (zip (takeLast (length map_res) all_pes) map_res) $ \(dest, src) ->
copyDWIMFix (patElemName dest) (map Imp.le64 gtids) (kernelResultSubExp src) []

-- Write to-scan results to private memory.
forM_ (zip privateArrays $ map kernelResultSubExp all_scan_res) $ \(dest, src) ->
copyDWIMFix dest [i] src []

out_of_bounds =
forM_ (zip privateArrays scanop_nes) $ \(dest, ne) ->
copyDWIMFix dest [i] ne []

sIf (phys_tid .<. n) in_bounds out_of_bounds

sOp $ Imp.ErrorSync Imp.FenceLocal
sComment "Transpose scan inputs" $ do
forM_ (zip transposedArrays privateArrays) $ \(trans, priv) -> do
sFor "i" chunk $ \i -> do
sharedIdx <-
dPrimVE "sharedIdx" $
sExt64 (kernelLocalThreadId constants)
+ i * kernelGroupSize constants
copyDWIMFix trans [sharedIdx] (Var priv) [i]
sOp local_barrier
sFor "i" chunk $ \i -> do
sharedIdx <- dPrimV "sharedIdx" $ kernelLocalThreadId constants * chunk + i
copyDWIMFix priv [sExt64 i] (Var trans) [sExt64 $ tvExp sharedIdx]
sOp local_barrier
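
Note: the two loops above are a shared-memory transpose. Thread t loads block offsets t + i * group_size (coalesced, matching the strided global loads in "Load and map") and ends up holding a contiguous chunk, which the sequential per-thread scan below requires. A plain-Haskell index check (parameters invented):

loaded, held :: Int -> Int -> Int -> [Int]
loaded gs chunk t = [t + i * gs | i <- [0 .. chunk - 1]]
held _gs chunk t = [t * chunk + i | i <- [0 .. chunk - 1]]

-- e.g. loaded 4 2 1 == [1,5] but held 4 2 1 == [2,3].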

sComment "Per thread scan" $ do
-- We don't need to touch the first element, so only chunk-1
-- iterations here.
globalIdx <-
dPrimVE "gidx" $
(kernelLocalThreadId constants * chunk) + 1
sFor "i" (chunk - 1) $ \i -> do
let xs = map paramName $ xParams scan_op
ys = map paramName $ yParams scan_op
-- determine if start of segment
new_sgm <-
if segmented
then dPrimVE "new_sgm" $ (globalIdx + sExt32 i - boundary) `mod` segsize_compact .==. 0
else pure false
-- skip scan of first element in segment
sUnless new_sgm $ do
forM_ (zip4 privateArrays xs ys tys) $ \(src, x, y, ty) -> do
dPrim_ x ty
dPrim_ y ty
copyDWIMFix x [] (Var src) [i]
copyDWIMFix y [] (Var src) [i + 1]

compileStms mempty (bodyStms $ lambdaBody $ segBinOpLambda scan_op) $
forM_ (zip privateArrays $ map resSubExp $ bodyResult $ lambdaBody $ segBinOpLambda scan_op) $ \(dest, res) ->
copyDWIMFix dest [i + 1] res []


sComment "Publish results in shared memory" $ do
forM_ (zip prefixArrays privateArrays) $ \(dest, src) ->
copyDWIMFix dest [sExt64 $ kernelLocalThreadId constants] (Var src) [chunk - 1]
sOp local_barrier

let crossesSegment = do
guard segmented
Just $ \from to ->
let from' = (from + 1) * chunk - 1
to' = (to + 1) * chunk - 1
in (to' - from') .>. (to' + segsize_compact - boundary) `mod` segsize_compact

scan_op1 <- renameLambda $ segBinOpLambda scan_op

accs <- mapM (dPrim "acc") tys
sComment "Scan results (with warp scan)" $ do
groupScan
crossesSegment
num_threads
(kernelGroupSize constants)
scan_op1
prefixArrays

sOp $ Imp.ErrorSync Imp.FenceLocal

let firstThread acc prefixes =
copyDWIMFix (tvVar acc) [] (Var prefixes) [sExt64 (kernelGroupSize constants) - 1]
notFirstThread acc prefixes =
copyDWIMFix (tvVar acc) [] (Var prefixes) [sExt64 (kernelLocalThreadId constants) - 1]
sIf
(kernelLocalThreadId constants .==. 0)
(zipWithM_ firstThread accs prefixArrays)
(zipWithM_ notFirstThread accs prefixArrays)

let firstThread acc prefixes =
copyDWIMFix (tvVar acc) [] (Var prefixes) [sExt64 (kernelGroupSize constants) - 1]
notFirstThread acc prefixes =
copyDWIMFix (tvVar acc) [] (Var prefixes) [sExt64 (kernelLocalThreadId constants) - 1]
sIf
(kernelLocalThreadId constants .==. 0)
(zipWithM_ firstThread accs prefixArrays)
(zipWithM_ notFirstThread accs prefixArrays)
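
Note: after groupScan, prefixArrays holds inclusive per-thread prefixes. Thread 0 takes the last entry, the whole block's aggregate, which it publishes during the lookback below (and then resets its acc to the neutral element); every other thread takes its exclusive prefix. A hedged list model:

accFor :: [Integer] -> Int -> Integer
accFor inclusive 0 = last inclusive
accFor inclusive t = inclusive !! (t - 1)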

sOp localBarrier

prefixes <- forM (zip scanOpNe tys) $ \(ne, ty) ->
dPrimV "prefix" $ TPrimExp $ toExp' ty ne
blockNewSgm <- dPrimVE "block_new_sgm" $ sgmIdx .==. 0
sComment "Perform lookback" $ do
sWhen (blockNewSgm .&&. kernelLocalThreadId constants .==. 0) $ do
everythingVolatile $
forM_ (zip accs incprefixArrays) $ \(acc, incprefixArray) ->
copyDWIMFix incprefixArray [tvExp dynamicId] (tvSize acc) []
sOp globalFence
everythingVolatile $
copyDWIMFix statusFlags [tvExp dynamicId] (intConst Int8 statusP) []
forM_ (zip scanOpNe accs) $ \(ne, acc) ->
copyDWIMFix (tvVar acc) [] ne []
-- end sWhen
sOp local_barrier

prefixes <- forM (zip scanop_nes tys) $ \(ne, ty) ->
dPrimV "prefix" $ TPrimExp $ toExp' ty ne
blockNewSgm <- dPrimVE "block_new_sgm" $ sgmIdx .==. 0
sComment "Perform lookback" $ do
sWhen (blockNewSgm .&&. kernelLocalThreadId constants .==. 0) $ do
everythingVolatile $
forM_ (zip accs incprefixArrays) $ \(acc, incprefixArray) ->
copyDWIMFix incprefixArray [tvExp dyn_id] (tvSize acc) []
sOp global_fence
everythingVolatile $
copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusP) []
forM_ (zip scanop_nes accs) $ \(ne, acc) ->
copyDWIMFix (tvVar acc) [] ne []
-- end sWhen

let warpSize = kernelWaveSize constants
sWhen (bNot blockNewSgm .&&. kernelLocalThreadId constants .<. warpSize) $ do
sWhen (kernelLocalThreadId constants .==. 0) $ do
sIf
(not_segmented_e .||. boundary .==. sExt32 (group_size_e * chunk))
( do
everythingVolatile $
forM_ (zip aggregateArrays accs) $ \(aggregateArray, acc) ->
copyDWIMFix aggregateArray [tvExp dyn_id] (tvSize acc) []
sOp global_fence
everythingVolatile $
copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusA) []
)
( do
everythingVolatile $
forM_ (zip incprefixArrays accs) $ \(incprefixArray, acc) ->
copyDWIMFix incprefixArray [tvExp dyn_id] (tvSize acc) []
sOp global_fence
everythingVolatile $
copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusP) []
)
everythingVolatile $
copyDWIMFix warpscan [0] (Var statusFlags) [tvExp dyn_id - 1]
-- sWhen
sOp local_fence

status <- dPrim "status" int8 :: InKernelGen (TV Int8)
copyDWIMFix (tvVar status) [] (Var warpscan) [0]

let warpSize = kernelWaveSize constants
sWhen (bNot blockNewSgm .&&. kernelLocalThreadId constants .<. warpSize) $ do
sWhen (kernelLocalThreadId constants .==. 0) $ do
sIf
(not_segmented_e .||. boundary .==. sExt32 (group_size' * m))
( do
everythingVolatile $
forM_ (zip aggregateArrays accs) $ \(aggregateArray, acc) ->
copyDWIMFix aggregateArray [tvExp dynamicId] (tvSize acc) []
sOp globalFence
(tvExp status .==. statusP)
( sWhen (kernelLocalThreadId constants .==. 0) $
everythingVolatile $
copyDWIMFix statusFlags [tvExp dynamicId] (intConst Int8 statusA) []
forM_ (zip prefixes incprefixArrays) $ \(prefix, incprefixArray) ->
copyDWIMFix (tvVar prefix) [] (Var incprefixArray) [tvExp dyn_id - 1]
)
( do
everythingVolatile $
forM_ (zip incprefixArrays accs) $ \(incprefixArray, acc) ->
copyDWIMFix incprefixArray [tvExp dynamicId] (tvSize acc) []
sOp globalFence
everythingVolatile $
copyDWIMFix statusFlags [tvExp dynamicId] (intConst Int8 statusP) []
)
everythingVolatile $
copyDWIMFix warpscan [0] (Var statusFlags) [tvExp dynamicId - 1]
-- sWhen
sOp localFence

status <- dPrim "status" int8 :: InKernelGen (TV Int8)
copyDWIMFix (tvVar status) [] (Var warpscan) [0]

sIf
(tvExp status .==. statusP)
( sWhen (kernelLocalThreadId constants .==. 0) $
everythingVolatile $
forM_ (zip prefixes incprefixArrays) $ \(prefix, incprefixArray) ->
copyDWIMFix (tvVar prefix) [] (Var incprefixArray) [tvExp dynamicId - 1]
)
( do
readOffset <-
dPrimV "readOffset" $
sExt32 $
tvExp dynamicId - sExt64 (kernelWaveSize constants)
let loopStop = warpSize * (-1)
sameSegment readIdx
| segmented =
let startIdx = sExt64 (tvExp readIdx + 1) * kernelGroupSize constants * m - 1
in tvExp blockOff - startIdx .<=. sgmIdx
| otherwise = true
sWhile (tvExp readOffset .>. loopStop) $ do
readI <- dPrimV "read_i" $ tvExp readOffset + kernelLocalThreadId constants
aggrs <- forM (zip scanOpNe tys) $ \(ne, ty) ->
dPrimV "aggr" $ TPrimExp $ toExp' ty ne
flag <- dPrimV "flag" (statusX :: Imp.TExp Int8)
everythingVolatile . sWhen (tvExp readI .>=. 0) $ do
readOffset <-
dPrimV "readOffset" $
sExt32 $
tvExp dyn_id - sExt64 (kernelWaveSize constants)
let loopStop = warpSize * (-1)
sameSegment readIdx
| segmented =
let startIdx = sExt64 (tvExp readIdx + 1) * kernelGroupSize constants * chunk - 1
in tvExp blockOff - startIdx .<=. sgmIdx
| otherwise = true
sWhile (tvExp readOffset .>. loopStop) $ do
readI <- dPrimV "read_i" $ tvExp readOffset + kernelLocalThreadId constants
aggrs <- forM (zip scanop_nes tys) $ \(ne, ty) ->
dPrimV "aggr" $ TPrimExp $ toExp' ty ne
flag <- dPrimV "flag" (statusX :: Imp.TExp Int8)
everythingVolatile . sWhen (tvExp readI .>=. 0) $ do
sIf
(sameSegment readI)
( do
copyDWIMFix (tvVar flag) [] (Var statusFlags) [sExt64 $ tvExp readI]
sIf
(tvExp flag .==. statusP)
( forM_ (zip incprefixArrays aggrs) $ \(incprefix, aggr) ->
copyDWIMFix (tvVar aggr) [] (Var incprefix) [sExt64 $ tvExp readI]
)
( sWhen (tvExp flag .==. statusA) $ do
forM_ (zip aggrs aggregateArrays) $ \(aggr, aggregate) ->
copyDWIMFix (tvVar aggr) [] (Var aggregate) [sExt64 $ tvExp readI]
)
)
(copyDWIMFix (tvVar flag) [] (intConst Int8 statusP) [])
-- end sIf
-- end sWhen

forM_ (zip exchanges aggrs) $ \(exchange, aggr) ->
copyDWIMFix exchange [sExt64 $ kernelLocalThreadId constants] (tvSize aggr) []
copyDWIMFix warpscan [sExt64 $ kernelLocalThreadId constants] (tvSize flag) []

-- execute warp-parallel reduction but only if the last read flag is not STATUS_P
copyDWIMFix (tvVar flag) [] (Var warpscan) [sExt64 warpSize - 1]
sWhen (tvExp flag .<. (2 :: Imp.TExp Int8)) $ do
lam' <- renameLambda scan_op1
inBlockScanLookback
constants
num_threads
warpscan
exchanges
lam'

-- all threads of the warp read the result of reduction
copyDWIMFix (tvVar flag) [] (Var warpscan) [sExt64 warpSize - 1]
forM_ (zip aggrs exchanges) $ \(aggr, exchange) ->
copyDWIMFix (tvVar aggr) [] (Var exchange) [sExt64 warpSize - 1]
-- update read offset
sIf
(sameSegment readI)
( do
copyDWIMFix (tvVar flag) [] (Var statusFlags) [sExt64 $ tvExp readI]
sIf
(tvExp flag .==. statusP)
( forM_ (zip incprefixArrays aggrs) $ \(incprefix, aggr) ->
copyDWIMFix (tvVar aggr) [] (Var incprefix) [sExt64 $ tvExp readI]
)
( sWhen (tvExp flag .==. statusA) $ do
forM_ (zip aggrs aggregateArrays) $ \(aggr, aggregate) ->
copyDWIMFix (tvVar aggr) [] (Var aggregate) [sExt64 $ tvExp readI]
)
(tvExp flag .==. statusP)
(readOffset <-- loopStop)
( sWhen (tvExp flag .==. statusA) $ do
readOffset <-- tvExp readOffset - zExt32 warpSize
)
(copyDWIMFix (tvVar flag) [] (intConst Int8 statusP) [])
-- end sIf
-- end sWhen

forM_ (zip exchanges aggrs) $ \(exchange, aggr) ->
copyDWIMFix exchange [sExt64 $ kernelLocalThreadId constants] (tvSize aggr) []
copyDWIMFix warpscan [sExt64 $ kernelLocalThreadId constants] (tvSize flag) []

-- execute warp-parallel reduction but only if the last read flag is not STATUS_P
copyDWIMFix (tvVar flag) [] (Var warpscan) [sExt64 warpSize - 1]
sWhen (tvExp flag .<. (2 :: Imp.TExp Int8)) $ do
lam' <- renameLambda scanOp'
inBlockScanLookback
constants
num_threads
warpscan
exchanges
lam'

-- all threads of the warp read the result of reduction
copyDWIMFix (tvVar flag) [] (Var warpscan) [sExt64 warpSize - 1]
forM_ (zip aggrs exchanges) $ \(aggr, exchange) ->
copyDWIMFix (tvVar aggr) [] (Var exchange) [sExt64 warpSize - 1]
-- update read offset
sIf
(tvExp flag .==. statusP)
(readOffset <-- loopStop)
( sWhen (tvExp flag .==. statusA) $ do
readOffset <-- tvExp readOffset - zExt32 warpSize
)

-- update prefix if flag different than STATUS_X:
sWhen (tvExp flag .>. (statusX :: Imp.TExp Int8)) $ do
lam <- renameLambda scanOp'
let (xs, ys) = splitAt (length tys) $ map paramName $ lambdaParams lam
forM_ (zip xs aggrs) $ \(x, aggr) -> dPrimV_ x (tvExp aggr)
forM_ (zip ys prefixes) $ \(y, prefix) -> dPrimV_ y (tvExp prefix)
compileStms mempty (bodyStms $ lambdaBody lam) $
forM_ (zip3 prefixes tys $ map resSubExp $ bodyResult $ lambdaBody lam) $
\(prefix, ty, res) -> prefix <-- TPrimExp (toExp' ty res)
sOp localFence
)

-- end sWhile
-- end sIf
sWhen (kernelLocalThreadId constants .==. 0) $ do
scanOp'''' <- renameLambda scanOp'
let xs = map paramName $ take (length tys) $ lambdaParams scanOp''''
ys = map paramName $ drop (length tys) $ lambdaParams scanOp''''
sWhen (boundary .==. sExt32 (group_size' * m)) $ do
forM_ (zip xs prefixes) $ \(x, prefix) -> dPrimV_ x $ tvExp prefix
forM_ (zip ys accs) $ \(y, acc) -> dPrimV_ y $ tvExp acc
compileStms mempty (bodyStms $ lambdaBody scanOp'''') $
everythingVolatile $
forM_ (zip incprefixArrays $ map resSubExp $ bodyResult $ lambdaBody scanOp'''') $
\(incprefixArray, res) -> copyDWIMFix incprefixArray [tvExp dynamicId] res []
sOp globalFence
everythingVolatile $ copyDWIMFix statusFlags [tvExp dynamicId] (intConst Int8 statusP) []
-- update prefix if flag different than STATUS_X:
sWhen (tvExp flag .>. (statusX :: Imp.TExp Int8)) $ do
lam <- renameLambda scan_op1
let (xs, ys) = splitAt (length tys) $ map paramName $ lambdaParams lam
forM_ (zip xs aggrs) $ \(x, aggr) -> dPrimV_ x (tvExp aggr)
forM_ (zip ys prefixes) $ \(y, prefix) -> dPrimV_ y (tvExp prefix)
compileStms mempty (bodyStms $ lambdaBody lam) $
forM_ (zip3 prefixes tys $ map resSubExp $ bodyResult $ lambdaBody lam) $
\(prefix, ty, res) -> prefix <-- TPrimExp (toExp' ty res)
sOp local_fence
)

-- end sWhile
-- end sIf
sWhen (kernelLocalThreadId constants .==. 0) $ do
scan_op2 <- renameLambda scan_op1
let xs = map paramName $ take (length tys) $ lambdaParams scan_op2
ys = map paramName $ drop (length tys) $ lambdaParams scan_op2
sWhen (boundary .==. sExt32 (group_size_e * chunk)) $ do
forM_ (zip xs prefixes) $ \(x, prefix) -> dPrimV_ x $ tvExp prefix
forM_ (zip ys accs) $ \(y, acc) -> dPrimV_ y $ tvExp acc
compileStms mempty (bodyStms $ lambdaBody scan_op2) $
everythingVolatile $
forM_ (zip incprefixArrays $ map resSubExp $ bodyResult $ lambdaBody scan_op2) $
\(incprefixArray, res) -> copyDWIMFix incprefixArray [tvExp dyn_id] res []
sOp global_fence
everythingVolatile $ copyDWIMFix statusFlags [tvExp dyn_id] (intConst Int8 statusP) []
forM_ (zip exchanges prefixes) $ \(exchange, prefix) ->
copyDWIMFix exchange [0] (tvSize prefix) []
forM_ (zip3 accs tys scanop_nes) $ \(acc, ty, ne) ->
tvVar acc <~~ toExp' ty ne
-- end sWhen
-- end sWhen

sWhen (bNot $ tvExp dyn_id .==. 0) $ do
sOp local_barrier
forM_ (zip exchanges prefixes) $ \(exchange, prefix) ->
copyDWIMFix exchange [0] (tvSize prefix) []
forM_ (zip3 accs tys scanOpNe) $ \(acc, ty, ne) ->
tvVar acc <~~ toExp' ty ne
-- end sWhen
copyDWIMFix (tvVar prefix) [] (Var exchange) [0]
sOp local_barrier
-- end sWhen
-- end sComment

scan_op3 <- renameLambda scan_op1
scan_op4 <- renameLambda scan_op1

sComment "Distribute results" $ do
let (xs, ys) = splitAt (length tys) $ map paramName $ lambdaParams scan_op3
(xs', ys') = splitAt (length tys) $ map paramName $ lambdaParams scan_op4

forM_ (zip7 prefixes accs xs xs' ys ys' tys) $
\(prefix, acc, x, x', y, y', ty) -> do
dPrim_ x ty
dPrim_ y ty
dPrimV_ x' $ tvExp prefix
dPrimV_ y' $ tvExp acc

sIf
(kernelLocalThreadId constants * chunk .<. boundary .&&. bNot blockNewSgm)
( compileStms mempty (bodyStms $ lambdaBody scan_op4) $
forM_ (zip3 xs tys $ map resSubExp $ bodyResult $ lambdaBody scan_op4) $
\(x, ty, res) -> x <~~ toExp' ty res
)
(forM_ (zip xs accs) $ \(x, acc) -> copyDWIMFix x [] (Var $ tvVar acc) [])
-- calculate where previous thread stopped, to determine number of
-- elements left before new segment.
stop <-
dPrimVE "stopping_point" $
segsize_compact - (kernelLocalThreadId constants * chunk - 1 + segsize_compact - boundary) `rem` segsize_compact
sFor "i" chunk $ \i -> do
sWhen (sExt32 i .<. stop - 1) $ do
forM_ (zip privateArrays ys) $ \(src, y) ->
-- only include prefix for the first segment part per thread
copyDWIMFix y [] (Var src) [i]
compileStms mempty (bodyStms $ lambdaBody scan_op3) $
forM_ (zip privateArrays $ map resSubExp $ bodyResult $ lambdaBody scan_op3) $
\(dest, res) ->
copyDWIMFix dest [i] res []

sComment "Transpose scan output and Write it to global memory in coalesced fashion" $ do
forM_ (zip3 transposedArrays privateArrays $ map patElemName all_pes) $ \(locmem, priv, dest) -> do
-- sOp local_barrier
sFor "i" chunk $ \i -> do
sharedIdx <-
dPrimV "sharedIdx" $
sExt64 (kernelLocalThreadId constants * chunk) + i
copyDWIMFix locmem [tvExp sharedIdx] (Var priv) [i]
sOp local_barrier
sFor "i" chunk $ \i -> do
flat_idx <-
dPrimVE "flat_idx" $
tvExp blockOff
+ kernelGroupSize constants * i
+ sExt64 (kernelLocalThreadId constants)
dIndexSpace (zip gtids dims') flat_idx
sWhen (flat_idx .<. n) $ do
copyDWIMFix
dest
(map Imp.le64 gtids)
(Var locmem)
[sExt64 $ flat_idx - tvExp blockOff]
sOp local_barrier

sWhen (bNot $ tvExp dynamicId .==. 0) $ do
sOp localBarrier
forM_ (zip exchanges prefixes) $ \(exchange, prefix) ->
copyDWIMFix (tvVar prefix) [] (Var exchange) [0]
sOp localBarrier
-- end sWhen
-- end sComment

scanOp''''' <- renameLambda scanOp'
scanOp'''''' <- renameLambda scanOp'

sComment "Distribute results" $ do
let (xs, ys) = splitAt (length tys) $ map paramName $ lambdaParams scanOp'''''
(xs', ys') = splitAt (length tys) $ map paramName $ lambdaParams scanOp''''''

forM_ (zip4 (zip prefixes accs) (zip xs xs') (zip ys ys') tys) $
\((prefix, acc), (x, x'), (y, y'), ty) -> do
dPrim_ x ty
dPrim_ y ty
dPrimV_ x' $ tvExp prefix
dPrimV_ y' $ tvExp acc

sIf
(kernelLocalThreadId constants * m .<. boundary .&&. bNot blockNewSgm)
( compileStms mempty (bodyStms $ lambdaBody scanOp'''''') $
forM_ (zip3 xs tys $ map resSubExp $ bodyResult $ lambdaBody scanOp'''''') $
\(x, ty, res) -> x <~~ toExp' ty res
)
(forM_ (zip xs accs) $ \(x, acc) -> copyDWIMFix x [] (Var $ tvVar acc) [])
-- calculate where previous thread stopped, to determine number of
-- elements left before new segment.
stop <-
dPrimVE "stopping_point" $
segsize_compact - (kernelLocalThreadId constants * m - 1 + segsize_compact - boundary) `rem` segsize_compact
sFor "i" m $ \i -> do
sWhen (sExt32 i .<. stop - 1) $ do
forM_ (zip privateArrays ys) $ \(src, y) ->
-- only include prefix for the first segment part per thread
copyDWIMFix y [] (Var src) [i]
compileStms mempty (bodyStms $ lambdaBody scanOp''''') $
forM_ (zip privateArrays $ map resSubExp $ bodyResult $ lambdaBody scanOp''''') $
\(dest, res) ->
copyDWIMFix dest [i] res []

sComment "Transpose scan output and Write it to global memory in coalesced fashion" $ do
forM_ (zip3 transposedArrays privateArrays $ map patElemName all_pes) $ \(locmem, priv, dest) -> do
-- sOp localBarrier
sFor "i" m $ \i -> do
sharedIdx <-
dPrimV "sharedIdx" $
sExt64 (kernelLocalThreadId constants * m) + i
copyDWIMFix locmem [tvExp sharedIdx] (Var priv) [i]
sOp localBarrier
sFor "i" m $ \i -> do
flat_idx <-
dPrimVE "flat_idx" $
tvExp blockOff
+ kernelGroupSize constants * i
+ sExt64 (kernelLocalThreadId constants)
dIndexSpace (zip gtids dims') flat_idx
sWhen (flat_idx .<. n) $ do
copyDWIMFix
dest
(map Imp.le64 gtids)
(Var locmem)
[sExt64 $ flat_idx - tvExp blockOff]
sOp localBarrier

sComment "If this is the last block, reset the dynamicId" $
sWhen (tvExp dynamicId .==. num_groups' - 1) $
copyDWIMFix globalId [0] (constant (0 :: Int32)) []
{-# NOINLINE compileSegScan #-}