Skip to content

Commit

Permalink
Flip order of operands to Screma.
Browse files Browse the repository at this point in the history
Now they resemble the logical execution order: first the map function,
then reduce/scan operations if applicable.
  • Loading branch information
athas committed Dec 4, 2024
1 parent fa90e67 commit 267c866
Show file tree
Hide file tree
Showing 15 changed files with 93 additions and 96 deletions.
6 changes: 3 additions & 3 deletions src/Futhark/AD/Fwd.hs
Original file line number Diff line number Diff line change
Expand Up @@ -264,13 +264,13 @@ zeroFromSubExp (Var v) = do
letExp "zero" $ zeroExp t

fwdSOAC :: Pat Type -> StmAux () -> SOAC SOACS -> ADM ()
fwdSOAC pat aux (Screma size xs (ScremaForm scs reds f)) = do
fwdSOAC pat aux (Screma size xs (ScremaForm f scs reds)) = do
pat' <- bundleNewPat pat
xs' <- bundleTangents xs
f' <- fwdLambda f
scs' <- mapM fwdScan scs
reds' <- mapM fwdRed reds
f' <- fwdLambda f
addStm $ Let pat' aux $ Op $ Screma size xs' $ ScremaForm scs' reds' f'
addStm $ Let pat' aux $ Op $ Screma size xs' $ ScremaForm f' scs' reds'
where
fwdScan :: Scan SOACS -> ADM (Scan SOACS)
fwdScan sc = do
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/AD/Rev/SOAC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ mapOp (Lambda [pa1, pa2] _ lam_body)
cs == mempty,
[map_stm] <- stmsToList (bodyStms lam_body),
(Let (Pat [pe]) _ (Op scrm)) <- map_stm,
(Screma _ [a1, a2] (ScremaForm [] [] map_lam)) <- scrm,
(Screma _ [a1, a2] (ScremaForm map_lam [] [])) <- scrm,
(a1 == paramName pa1 && a2 == paramName pa2) || (a1 == paramName pa2 && a2 == paramName pa1),
r == Var (patElemName pe) =
Just map_lam
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/Analysis/HORep/MapNest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ fromSOAC' ::
[Ident] ->
SOAC rep ->
m (Maybe (MapNest rep))
fromSOAC' bound (SOAC.Screma w inps (SOAC.ScremaForm [] [] lam)) = do
fromSOAC' bound (SOAC.Screma w inps (SOAC.ScremaForm lam [] [])) = do
maybenest <- case ( stmsToList $ bodyStms $ lambdaBody lam,
bodyResult $ lambdaBody lam
) of
Expand Down
6 changes: 3 additions & 3 deletions src/Futhark/Analysis/HORep/SOAC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ newWidth (inp : _) _ = arraySize 0 $ inputType inp
lambda :: SOAC rep -> Lambda rep
lambda (Stream _ _ _ lam) = lam
lambda (Scatter _len _ivs _spec lam) = lam
lambda (Screma _ _ (ScremaForm _ _ lam)) = lam
lambda (Screma _ _ (ScremaForm lam _ _)) = lam
lambda (Hist _ _ _ lam) = lam

-- | Set the lambda used in the SOAC.
Expand All @@ -444,8 +444,8 @@ setLambda lam (Stream w arrs nes _) =
Stream w arrs nes lam
setLambda lam (Scatter len arrs spec _lam) =
Scatter len arrs spec lam
setLambda lam (Screma w arrs (ScremaForm scan red _)) =
Screma w arrs (ScremaForm scan red lam)
setLambda lam (Screma w arrs (ScremaForm _ scan red)) =
Screma w arrs (ScremaForm lam scan red)
setLambda lam (Hist w ops inps _) =
Hist w ops inps lam

Expand Down
21 changes: 11 additions & 10 deletions src/Futhark/IR/Parse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -727,24 +727,25 @@ pSOAC pr =
<*> p
pScremaForm =
SOAC.ScremaForm
<$> braces (pScan pr `sepBy` pComma)
<$> pLambda pr
<* pComma
<*> braces (pReduce pr `sepBy` pComma)
<*> braces (pScan pr `sepBy` pComma)
<* pComma
<*> pLambda pr
<*> braces (pReduce pr `sepBy` pComma)
pRedomapForm =
SOAC.ScremaForm mempty
<$> braces (pReduce pr `sepBy` pComma)
SOAC.ScremaForm
<$> pLambda pr
<*> pure []
<* pComma
<*> pLambda pr
<*> braces (pReduce pr `sepBy` pComma)
pScanomapForm =
SOAC.ScremaForm
<$> braces (pScan pr `sepBy` pComma)
<$> pLambda pr
<* pComma
<*> pure mempty
<*> pLambda pr
<*> braces (pScan pr `sepBy` pComma)
<*> pure []
pMapForm =
SOAC.ScremaForm mempty mempty <$> pLambda pr
SOAC.ScremaForm <$> pLambda pr <*> pure mempty <*> pure mempty
pScatter =
keyword "scatter"
*> parens
Expand Down
2 changes: 1 addition & 1 deletion src/Futhark/IR/SOACS.hs
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ usesAD prog = any stmUsesAD (progConsts prog) || any funUsesAD (progFuns prog)
expUsesAD (Op JVP {}) = True
expUsesAD (Op VJP {}) = True
expUsesAD (Op (Stream _ _ _ lam)) = lamUsesAD lam
expUsesAD (Op (Screma _ _ (ScremaForm scans reds lam))) =
expUsesAD (Op (Screma _ _ (ScremaForm lam scans reds))) =
lamUsesAD lam
|| any (lamUsesAD . scanLambda) scans
|| any (lamUsesAD . redLambda) reds
Expand Down
81 changes: 37 additions & 44 deletions src/Futhark/IR/SOACS/SOAC.hs
Original file line number Diff line number Diff line change
Expand Up @@ -148,14 +148,14 @@ data HistOp rep = HistOp
-- | The essential parts of a 'Screma' factored out (everything
-- except the input arrays).
data ScremaForm rep = ScremaForm
{ scremaScans :: [Scan rep],
scremaReduces :: [Reduce rep],
-- | The "main" lambda of the Screma. For a map, this is
{ -- | The "main" lambda of the Screma. For a map, this is
-- equivalent to 'isMapSOAC'. Note that the meaning of the return
-- value of this lambda depends crucially on exactly which Screma
-- this is. The parameters will correspond exactly to elements of
-- the input arrays, however.
scremaLambda :: Lambda rep
scremaLambda :: Lambda rep,
scremaScans :: [Scan rep],
scremaReduces :: [Reduce rep]
}
deriving (Eq, Ord, Show)

Expand Down Expand Up @@ -221,7 +221,7 @@ singleReduce reds =
-- | The types produced by a single 'Screma', given the size of the
-- input array.
scremaType :: SubExp -> ScremaForm rep -> [Type]
scremaType w (ScremaForm scans reds map_lam) =
scremaType w (ScremaForm map_lam scans reds) =
scan_tps ++ red_tps ++ map (`arrayOfRow` w) map_tps
where
scan_tps =
Expand Down Expand Up @@ -258,12 +258,12 @@ nilFn = Lambda mempty mempty (mkBody mempty mempty)
-- | Construct a Screma with possibly multiple scans, and
-- the given map function.
scanomapSOAC :: [Scan rep] -> Lambda rep -> ScremaForm rep
scanomapSOAC scans = ScremaForm scans []
scanomapSOAC scans lam = ScremaForm lam scans []

-- | Construct a Screma with possibly multiple reductions, and
-- the given map function.
redomapSOAC :: [Reduce rep] -> Lambda rep -> ScremaForm rep
redomapSOAC = ScremaForm []
redomapSOAC reds lam = ScremaForm lam [] reds

-- | Construct a Screma with possibly multiple scans, and identity map
-- function.
Expand All @@ -287,11 +287,11 @@ reduceSOAC reds = redomapSOAC reds <$> mkIdentityLambda ts

-- | Construct a Screma corresponding to a map.
mapSOAC :: Lambda rep -> ScremaForm rep
mapSOAC = ScremaForm [] []
mapSOAC lam = ScremaForm lam [] []

-- | Does this Screma correspond to a scan-map composition?
isScanomapSOAC :: ScremaForm rep -> Maybe ([Scan rep], Lambda rep)
isScanomapSOAC (ScremaForm scans reds map_lam) = do
isScanomapSOAC (ScremaForm map_lam scans reds) = do
guard $ null reds
guard $ not $ null scans
pure (scans, map_lam)
Expand All @@ -305,7 +305,7 @@ isScanSOAC form = do

-- | Does this Screma correspond to a reduce-map composition?
isRedomapSOAC :: ScremaForm rep -> Maybe ([Reduce rep], Lambda rep)
isRedomapSOAC (ScremaForm scans reds map_lam) = do
isRedomapSOAC (ScremaForm map_lam scans reds) = do
guard $ null scans
guard $ not $ null reds
pure (reds, map_lam)
Expand All @@ -320,7 +320,7 @@ isReduceSOAC form = do
-- | Does this Screma correspond to a simple map, without any
-- reduction or scan results?
isMapSOAC :: ScremaForm rep -> Maybe (Lambda rep)
isMapSOAC (ScremaForm scans reds map_lam) = do
isMapSOAC (ScremaForm map_lam scans reds) = do
guard $ null scans
guard $ null reds
pure map_lam
Expand Down Expand Up @@ -443,12 +443,13 @@ mapSOACM tv (Hist w arrs ops bucket_fun) =
)
ops
<*> mapOnSOACLambda tv bucket_fun
mapSOACM tv (Screma w arrs (ScremaForm scans reds map_lam)) =
mapSOACM tv (Screma w arrs (ScremaForm map_lam scans reds)) =
Screma
<$> mapOnSOACSubExp tv w
<*> mapM (mapOnSOACVName tv) arrs
<*> ( ScremaForm
<$> forM
<$> mapOnSOACLambda tv map_lam
<*> forM
scans
( \(Scan red_lam red_nes) ->
Scan
Expand All @@ -462,7 +463,6 @@ mapSOACM tv (Screma w arrs (ScremaForm scans reds map_lam)) =
<$> mapOnSOACLambda tv red_lam
<*> mapM (mapOnSOACSubExp tv) red_nes
)
<*> mapOnSOACLambda tv map_lam
)

-- | A helper for defining 'TraverseOpStms'.
Expand Down Expand Up @@ -547,7 +547,7 @@ instance AliasedOp SOAC where
consumedInOp VJP {} = mempty
-- Only map functions can consume anything. The operands to scan
-- and reduce functions are always considered "fresh".
consumedInOp (Screma _ arrs (ScremaForm _ _ map_lam)) =
consumedInOp (Screma _ arrs (ScremaForm map_lam _ _)) =
mapNames consumedArray $ consumedByLambda map_lam
where
consumedArray v = fromMaybe v $ lookup v params_to_arrs
Expand Down Expand Up @@ -586,12 +586,12 @@ instance CanBeAliased SOAC where
arrs
(map (mapHistOp (Alias.analyseLambda aliases)) ops)
(Alias.analyseLambda aliases bucket_fun)
addOpAliases aliases (Screma w arrs (ScremaForm scans reds map_lam)) =
addOpAliases aliases (Screma w arrs (ScremaForm map_lam scans reds)) =
Screma w arrs $
ScremaForm
(Alias.analyseLambda aliases map_lam)
(map onScan scans)
(map onRed reds)
(Alias.analyseLambda aliases map_lam)
where
onRed red = red {redLambda = Alias.analyseLambda aliases $ redLambda red}
onScan scan = scan {scanLambda = Alias.analyseLambda aliases $ scanLambda scan}
Expand Down Expand Up @@ -642,7 +642,7 @@ instance IsOp SOAC where
lam
(zipWith (<>) (map depsOf' args) (map depsOf' vec))
<> map (const $ freeIn args <> freeIn lam) (lambdaParams lam)
opDependencies (Screma w arrs (ScremaForm scans reds map_lam)) =
opDependencies (Screma w arrs (ScremaForm map_lam scans reds)) =
let (scans_in, reds_in, map_deps) =
splitAt3 (scanResults scans) (redResults reds) $
lambdaDependencies mempty map_lam (depsOfArrays w arrs)
Expand Down Expand Up @@ -682,7 +682,7 @@ instance (RepTypes rep) => ST.IndexOp (SOAC rep) where
SubExpRes _ (Var v) -> uncurry (flip ST.Indexed) <$> M.lookup v arr_indexes'
_ -> Nothing
where
lambdaAndSubExp (Screma _ arrs (ScremaForm scans reds map_lam)) =
lambdaAndSubExp (Screma _ arrs (ScremaForm map_lam scans reds)) =
nthMapOut (scanResults scans + redResults reds) map_lam arrs
lambdaAndSubExp _ =
Nothing
Expand Down Expand Up @@ -849,7 +849,7 @@ typeCheckSOAC (Hist w arrs ops bucket_fun) = do
<> prettyTuple (lambdaReturnType bucket_fun)
<> " but should have type "
<> prettyTuple bucket_ret_t
typeCheckSOAC (Screma w arrs (ScremaForm scans reds map_lam)) = do
typeCheckSOAC (Screma w arrs (ScremaForm map_lam scans reds)) = do
TC.require [Prim int64] w
arrs' <- TC.checkSOACArrayArgs w arrs
TC.checkLambda map_lam arrs'
Expand Down Expand Up @@ -906,12 +906,12 @@ instance RephraseOp SOAC where
where
onOp (HistOp dest_shape rf dests nes op) =
HistOp dest_shape rf dests nes <$> rephraseLambda r op
rephraseInOp r (Screma w arrs (ScremaForm scans red lam)) =
rephraseInOp r (Screma w arrs (ScremaForm lam scans red)) =
Screma w arrs
<$> ( ScremaForm
<$> mapM onScan scans
<$> rephraseLambda r lam
<*> mapM onScan scans
<*> mapM onRed red
<*> rephraseLambda r lam
)
where
onScan (Scan op nes) = Scan <$> rephraseLambda r op <*> pure nes
Expand All @@ -928,11 +928,11 @@ instance (OpMetrics (Op rep)) => OpMetrics (SOAC rep) where
inside "Scatter" $ lambdaMetrics lam
opMetrics (Hist _ _ ops bucket_fun) =
inside "Hist" $ mapM_ (lambdaMetrics . histOp) ops >> lambdaMetrics bucket_fun
opMetrics (Screma _ _ (ScremaForm scans reds map_lam)) =
opMetrics (Screma _ _ (ScremaForm map_lam scans reds)) =
inside "Screma" $ do
lambdaMetrics map_lam
mapM_ (lambdaMetrics . scanLambda) scans
mapM_ (lambdaMetrics . redLambda) reds
lambdaMetrics map_lam

instance (PrettyRep rep) => PP.Pretty (SOAC rep) where
pretty (VJP lam args vec) =
Expand Down Expand Up @@ -961,56 +961,49 @@ instance (PrettyRep rep) => PP.Pretty (SOAC rep) where
ppScatter w arrs dests lam
pretty (Hist w arrs ops bucket_fun) =
ppHist w arrs ops bucket_fun
pretty (Screma w arrs (ScremaForm scans reds map_lam))
pretty (Screma w arrs (ScremaForm map_lam scans reds))
| null scans,
null reds =
"map"
<> (parens . align)
( pretty w
<> comma
</> ppTuple' (map pretty arrs)
<> comma
</> pretty map_lam
<> comma </> ppTuple' (map pretty arrs)
<> comma </> pretty map_lam
)
| null scans =
"redomap"
<> (parens . align)
( pretty w
<> comma
</> ppTuple' (map pretty arrs)
<> comma </> ppTuple' (map pretty arrs)
<> comma </> pretty map_lam
<> comma
</> PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty reds)
<> comma
</> pretty map_lam
)
| null reds =
"scanomap"
<> (parens . align)
( pretty w
<> comma </> ppTuple' (map pretty arrs)
<> comma </> pretty map_lam
<> comma
</> ppTuple' (map pretty arrs)
<> comma
</> PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty scans)
<> comma
</> pretty map_lam
</> PP.braces
(mconcat $ intersperse (comma <> PP.line) $ map pretty scans)
)
pretty (Screma w arrs form) = ppScrema w arrs form

-- | Prettyprint the given Screma.
ppScrema ::
(PrettyRep rep, Pretty inp) => SubExp -> [inp] -> ScremaForm rep -> Doc ann
ppScrema w arrs (ScremaForm scans reds map_lam) =
ppScrema w arrs (ScremaForm map_lam scans reds) =
"screma"
<> (parens . align)
( pretty w
<> comma
</> ppTuple' (map pretty arrs)
<> comma </> ppTuple' (map pretty arrs)
<> comma </> pretty map_lam
<> comma
</> PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty scans)
<> comma
</> PP.braces (mconcat $ intersperse (comma <> PP.line) $ map pretty reds)
<> comma
</> pretty map_lam
)

-- | Prettyprint the given Stream.
Expand Down
Loading

0 comments on commit 267c866

Please sign in to comment.