Skip to content

Commit 3cbde4c

Browse files
authored
Merge pull request #1334 from google-research/primitive-name-map-e
Redefine NameMapE as the primitive type
2 parents c7373b2 + c372cba commit 3cbde4c

File tree

7 files changed

+78
-98
lines changed

7 files changed

+78
-98
lines changed

src/lib/CheckType.hs

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,13 @@ liftTyperM cont =
5656
affineUsed :: AtomName r o -> TyperM r i o ()
5757
affineUsed name = TyperM $ do
5858
affines <- get
59-
case lookupNameMap name affines of
60-
Just n -> if n > 0 then
61-
throw TypeErr $ "Affine name " ++ pprint name ++ " used " ++ show (n + 1) ++ " times."
62-
else
63-
put $ insertNameMap name (n + 1) affines
64-
Nothing -> put $ insertNameMap name 1 affines
59+
case lookupNameMapE name affines of
60+
Just (LiftE n) ->
61+
if n > 0 then
62+
throw TypeErr $ "Affine name " ++ pprint name ++ " used " ++ show (n + 1) ++ " times."
63+
else
64+
put $ insertNameMapE name (LiftE $ n + 1) affines
65+
Nothing -> put $ insertNameMapE name (LiftE 1) affines
6566

6667
parallelAffines :: [TyperM r i o a] -> TyperM r i o [a]
6768
parallelAffines actions = TyperM $ do
@@ -77,7 +78,7 @@ parallelAffines actions = TyperM $ do
7778
result <- runTyperT' act
7879
(result,) <$> get
7980
put affines
80-
forM_ (toListNameMap $ unionsWithNameMap max isolateds) \(name, ct) ->
81+
forM_ (toListNameMapE $ unionsWithNameMapE max isolateds) \(name, (LiftE ct)) ->
8182
case ct of
8283
0 -> return ()
8384
1 -> runTyperT' $ affineUsed name

src/lib/Core.hs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -482,5 +482,3 @@ freshNameM hint = do
482482
Distinct <- getDistinct
483483
return $ withFresh hint scope \b -> Abs b (binderName b)
484484
{-# INLINE freshNameM #-}
485-
486-
type AtomNameMap r = NameMap (AtomNameC r)

src/lib/Lower.hs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,17 +217,18 @@ lowerCase maybeDest scrut alts resultTy = do
217217
-- so that it never allocates scratch space for its result, but will put it directly in
218218
-- the corresponding slice of the full 2D buffer.
219219

220-
type DestAssignment (i'::S) (o::S) = AtomNameMap SimpIR (ProjDest o) i'
220+
type DestAssignment (i'::S) (o::S) = NameMap (AtomNameC SimpIR) (ProjDest o) i'
221221

222222
data ProjDest o
223223
= FullDest (Dest SimpIR o)
224224
| ProjDest (NE.NonEmpty Projection) (Dest SimpIR o) -- dest corresponds to the projection applied to name
225+
deriving (Show)
225226

226227
instance SinkableE ProjDest where
227228
sinkingProofE = todoSinkableProof
228229

229230
lookupDest :: DestAssignment i' o -> SAtomName i' -> Maybe (ProjDest o)
230-
lookupDest = flip lookupNameMap
231+
lookupDest dests = fmap fromLiftE . flip lookupNameMapE dests
231232

232233
-- Matches up the free variables of the atom, with the given dest. For example, if the
233234
-- atom is a pair of two variables, the dest might be split into per-component dests,
@@ -238,10 +239,10 @@ lookupDest = flip lookupNameMap
238239
-- XXX: When adding more cases, be careful about potentially repeated vars in the output!
239240
decomposeDest :: Emits o => Dest SimpIR o -> SAtom i' -> LowerM i o (Maybe (DestAssignment i' o))
240241
decomposeDest dest = \case
241-
Var v -> return $ Just $ singletonNameMap (atomVarName v) $ FullDest dest
242+
Var v -> return $ Just $ singletonNameMapE (atomVarName v) $ LiftE $ FullDest dest
242243
ProjectElt _ p x -> do
243244
(ps, v) <- return $ asNaryProj p x
244-
return $ Just $ singletonNameMap (atomVarName v) $ ProjDest ps dest
245+
return $ Just $ singletonNameMapE (atomVarName v) $ LiftE $ ProjDest ps dest
245246
_ -> return Nothing
246247

247248
lowerBlockWithDest :: Emits o => Dest SimpIR o -> SBlock i -> LowerM i o (SAtom o)
@@ -258,7 +259,7 @@ lowerBlockWithDest dest (Abs decls ans) = do
258259
Just DistinctBetween -> do
259260
s' <- traverseDeclNestWithDestS destMap s decls
260261
-- But we have to emit explicit writes, for all the vars that are not defined in decls!
261-
forM_ (toListNameMap $ hoistFilterNameMap decls destMap) \(n, d) -> do
262+
forM_ (toListNameMapE $ hoistNameMap decls destMap) \(n, (LiftE d)) -> do
262263
x <- case s ! n of
263264
Rename v -> Var <$> toAtomVar v
264265
SubstVal a -> return a

src/lib/MTL1.hs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,8 @@ instance HoistableState UnitE where
223223
hoistState _ _ UnitE = UnitE
224224
{-# INLINE hoistState #-}
225225

226-
instance HoistableState (NameMap c a) where
227-
hoistState _ b m = hoistFilterNameMap b m
226+
instance Show a => HoistableState (NameMap c a) where
227+
hoistState _ b m = hoistNameMap b m
228228
{-# INLINE hoistState #-}
229229

230230
-------------------- ScopedT1 --------------------

src/lib/Name.hs

Lines changed: 55 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -481,7 +481,7 @@ newtype NonEmptyListE (e::E) (n::S) = NonEmptyListE { fromNonEmptyListE :: NonEm
481481
deriving (Show, Eq, Generic)
482482

483483
newtype LiftE (a:: *) (n::S) = LiftE { fromLiftE :: a }
484-
deriving (Show, Eq, Generic, Monoid, Semigroup)
484+
deriving (Show, Eq, Ord, Generic, Monoid, Semigroup)
485485

486486
newtype ComposeE (f :: * -> *) (e::E) (n::S) =
487487
ComposeE { fromComposeE :: (f (e n)) }
@@ -3256,113 +3256,85 @@ instance HoistableB b => HoistableB (WithAttrB a b) where
32563256

32573257
-- === extra data structures ===
32583258

3259-
-- A map from names in some scope to values that do not contain names. This is
3260-
-- not trying to enforce completeness -- a name in the scope can fail to be in
3261-
-- the map.
3262-
3263-
-- Hoisting the map removes entries that are no longer in scope.
3264-
3265-
newtype NameMap (c::C) (a:: *) (n::S) = UnsafeNameMap (RawNameMap a)
3266-
deriving (Eq, Semigroup, Monoid, Store)
3267-
3268-
hoistFilterNameMap :: BindsNames b => b n l -> NameMap c a l -> NameMap c a n
3269-
hoistFilterNameMap b (UnsafeNameMap raw) =
3270-
UnsafeNameMap $ raw `R.difference` frag
3271-
where UnsafeMakeScopeFrag frag = toScopeFrag b
3272-
{-# INLINE hoistFilterNameMap #-}
3273-
3274-
insertNameMap :: Name c n -> a -> NameMap c a n -> NameMap c a n
3275-
insertNameMap (UnsafeMakeName n) x (UnsafeNameMap raw) = UnsafeNameMap $ R.insert n x raw
3276-
{-# INLINE insertNameMap #-}
3277-
3278-
lookupNameMap :: Name c n -> NameMap c a n -> Maybe a
3279-
lookupNameMap (UnsafeMakeName n) (UnsafeNameMap raw) = R.lookup n raw
3280-
{-# INLINE lookupNameMap #-}
3281-
3282-
singletonNameMap :: Name c n -> a -> NameMap c a n
3283-
singletonNameMap (UnsafeMakeName n) x = UnsafeNameMap $ R.singleton n x
3284-
{-# INLINE singletonNameMap #-}
3285-
3286-
toListNameMap :: NameMap c a n -> [(Name c n, a)]
3287-
toListNameMap (UnsafeNameMap raw) = R.toList raw <&> \(r, x) -> (UnsafeMakeName r, x)
3288-
{-# INLINE toListNameMap #-}
3289-
3290-
unionWithNameMap :: (a -> a -> a) -> NameMap c a n -> NameMap c a n -> NameMap c a n
3291-
unionWithNameMap f (UnsafeNameMap raw1) (UnsafeNameMap raw2) =
3292-
UnsafeNameMap $ R.unionWith f raw1 raw2
3293-
{-# INLINE unionWithNameMap #-}
3294-
3295-
unionsWithNameMap :: (Foldable f) => (a -> a -> a) -> f (NameMap c a n) -> NameMap c a n
3296-
unionsWithNameMap func maps =
3297-
foldl' (unionWithNameMap func) mempty maps
3298-
{-# INLINE unionsWithNameMap #-}
3299-
3300-
traverseNameMap :: (Applicative f) => (a -> f b)
3301-
-> NameMap c a n -> f (NameMap c b n)
3302-
traverseNameMap f (UnsafeNameMap raw) = UnsafeNameMap <$> traverse f raw
3303-
{-# INLINE traverseNameMap #-}
3304-
3305-
mapNameMap :: (a -> b) -> NameMap c a n -> (NameMap c b n)
3306-
mapNameMap f (UnsafeNameMap raw) = UnsafeNameMap $ fmap f raw
3307-
{-# INLINE mapNameMap #-}
3308-
3309-
keysNameMap :: NameMap c a n -> [Name c n]
3310-
keysNameMap = map fst . toListNameMap
3311-
{-# INLINE keysNameMap #-}
3312-
3313-
keySetNameMap :: (Color c) => NameMap c a n -> NameSet n
3314-
keySetNameMap nmap = freeVarsE $ ListE $ keysNameMap nmap
3315-
3316-
instance SinkableE (NameMap c a) where
3317-
sinkingProofE = undefined
3259+
-- A map from names in some scope to values that may contain names
3260+
-- from the same scope. This is not trying to enforce completeness --
3261+
-- a name in the scope can fail to be in the map.
3262+
3263+
-- Hoisting the map removes entries for names that are no longer in
3264+
-- scope, and then attempts to hoist the remaining values.
3265+
3266+
-- This structure is useful for bottom-up code traversals. Once one
3267+
-- has traversed some term in scope n, one may be carrying information
3268+
-- associated with (some of) the free variables of the term. These
3269+
-- free variables are necessarily in the scope n, though they need by
3270+
-- no means be all the names in the scope n (that's what a Subst is
3271+
-- for). But, if the traversal is alpha-invariant, it cannot be
3272+
-- carrying any information about names bound within the term, only
3273+
-- the free ones.
3274+
--
3275+
-- Further, if the information being carried is E-kinded, the names
3276+
-- therein should be resolvable in the same scope n, since those are
3277+
-- the only names that are given meaning by the context of the term
3278+
-- being traversed.
33183279

3319-
newtype NameMapE (c::C) (e:: E) (n::S) = NameMapE (NameMap c (e n) n)
3280+
newtype NameMapE (c::C) (e:: E) (n::S) = UnsafeNameMapE (RawNameMap (e n))
33203281
deriving (Eq, Semigroup, Monoid, Store)
33213282

33223283
-- Filters out the entry(ies) for the binder being hoisted above,
33233284
-- and hoists the values of the remaining entries.
33243285
hoistNameMapE :: (BindsNames b, HoistableE e, ShowE e)
33253286
=> b n l -> NameMapE c e l -> HoistExcept (NameMapE c e n)
3326-
hoistNameMapE b (NameMapE nmap) =
3327-
NameMapE <$> (traverseNameMap (hoist b) $ hoistFilterNameMap b nmap) where
3287+
hoistNameMapE b (UnsafeNameMapE raw) =
3288+
UnsafeNameMapE <$> traverse (hoist b) diff
3289+
where
3290+
diff = raw `R.difference` frag
3291+
UnsafeMakeScopeFrag frag = toScopeFrag b
33283292
{-# INLINE hoistNameMapE #-}
33293293

33303294
insertNameMapE :: Name c n -> e n -> NameMapE c e n -> NameMapE c e n
3331-
insertNameMapE n x (NameMapE nmap) = NameMapE $ insertNameMap n x nmap
3295+
insertNameMapE (UnsafeMakeName n) x (UnsafeNameMapE raw)
3296+
= UnsafeNameMapE $ R.insert n x raw
33323297
{-# INLINE insertNameMapE #-}
33333298

33343299
lookupNameMapE :: Name c n -> NameMapE c e n -> Maybe (e n)
3335-
lookupNameMapE n (NameMapE nmap) = lookupNameMap n nmap
3300+
lookupNameMapE (UnsafeMakeName n) (UnsafeNameMapE raw) = R.lookup n raw
33363301
{-# INLINE lookupNameMapE #-}
33373302

33383303
singletonNameMapE :: Name c n -> e n -> NameMapE c e n
3339-
singletonNameMapE n x = NameMapE $ singletonNameMap n x
3304+
singletonNameMapE (UnsafeMakeName n) x = UnsafeNameMapE $ R.singleton n x
33403305
{-# INLINE singletonNameMapE #-}
33413306

33423307
toListNameMapE :: NameMapE c e n -> [(Name c n, (e n))]
3343-
toListNameMapE (NameMapE nmap) = toListNameMap nmap
3308+
toListNameMapE (UnsafeNameMapE raw) =
3309+
R.toList raw <&> \(r, x) -> (UnsafeMakeName r, x)
33443310
{-# INLINE toListNameMapE #-}
33453311

33463312
unionWithNameMapE :: (e n -> e n -> e n) -> NameMapE c e n -> NameMapE c e n -> NameMapE c e n
3347-
unionWithNameMapE f (NameMapE nmap1) (NameMapE nmap2) =
3348-
NameMapE $ unionWithNameMap f nmap1 nmap2
3313+
unionWithNameMapE f (UnsafeNameMapE raw1) (UnsafeNameMapE raw2) =
3314+
UnsafeNameMapE $ R.unionWith f raw1 raw2
33493315
{-# INLINE unionWithNameMapE #-}
33503316

3317+
unionsWithNameMapE :: (Foldable f) => (e n -> e n -> e n) -> f (NameMapE c e n) -> NameMapE c e n
3318+
unionsWithNameMapE func maps =
3319+
foldl' (unionWithNameMapE func) mempty maps
3320+
{-# INLINE unionsWithNameMapE #-}
3321+
33513322
traverseNameMapE :: (Applicative f) => (e1 n -> f (e2 n))
33523323
-> NameMapE c e1 n -> f (NameMapE c e2 n)
3353-
traverseNameMapE f (NameMapE nmap) = NameMapE <$> traverseNameMap f nmap
3324+
traverseNameMapE f (UnsafeNameMapE raw) = UnsafeNameMapE <$> traverse f raw
33543325
{-# INLINE traverseNameMapE #-}
33553326

33563327
mapNameMapE :: (e1 n -> e2 n)
33573328
-> NameMapE c e1 n -> NameMapE c e2 n
3358-
mapNameMapE f (NameMapE nmap) = NameMapE $ mapNameMap f nmap
3329+
mapNameMapE f (UnsafeNameMapE raw) = UnsafeNameMapE $ fmap f raw
33593330
{-# INLINE mapNameMapE #-}
33603331

33613332
keysNameMapE :: NameMapE c e n -> [Name c n]
3362-
keysNameMapE (NameMapE nmap) = keysNameMap nmap
3333+
keysNameMapE = map fst . toListNameMapE
3334+
{-# INLINE keysNameMapE #-}
33633335

33643336
keySetNameMapE :: (Color c) => NameMapE c e n -> NameSet n
3365-
keySetNameMapE (NameMapE nmap) = keySetNameMap nmap
3337+
keySetNameMapE nmap = freeVarsE $ ListE $ keysNameMapE nmap
33663338

33673339
instance SinkableE e => SinkableE (NameMapE c e) where
33683340
sinkingProofE = undefined
@@ -3373,6 +3345,16 @@ instance RenameE e => RenameE (NameMapE c e) where
33733345
instance HoistableE e => HoistableE (NameMapE c e) where
33743346
freeVarsE = undefined
33753347

3348+
-- A small short-cut: When the information in a NameMapE does not, in
3349+
-- fact, reference any names, hoisting the entries cannot fail.
3350+
3351+
type NameMap (c::C) (a:: *) = NameMapE c (LiftE a)
3352+
3353+
hoistNameMap :: (BindsNames b, Show a)
3354+
=> b n l -> NameMap c a l -> (NameMap c a n)
3355+
hoistNameMap b = ignoreHoistFailure . hoistNameMapE b
3356+
{-# INLINE hoistNameMap #-}
3357+
33763358
-- === E-kinded IR coercions ===
33773359

33783360
-- XXX: the intention is that we won't have to use this much

src/lib/Occurrence.hs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,12 +88,10 @@ class MaxPlus a where
8888
max :: a -> a -> a
8989
plus :: a -> a -> a
9090

91-
instance (MaxPlus a) => MaxPlus (NameMap c a n) where
91+
instance (MaxPlus (e n)) => MaxPlus (NameMapE c e n) where
9292
zero = mempty
93-
max = unionWithNameMap max
94-
plus = unionWithNameMap plus
95-
96-
deriving instance (MaxPlus (e n)) => MaxPlus (NameMapE c e n)
93+
max = unionWithNameMapE max
94+
plus = unionWithNameMapE plus
9795

9896
-- === Access ===
9997

src/lib/Vectorize.hs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ askVectorByteWidth :: TopVectorizeM i o Word32
131131
askVectorByteWidth = TopVectorizeM $ SubstReaderT $ lift $ lift11 (fromLiftE <$> ask)
132132

133133
extendCommuteMap :: AtomName SimpIR o -> MonoidCommutes -> TopVectorizeM i o a -> TopVectorizeM i o a
134-
extendCommuteMap name commutativity = local $ insertNameMap name commutativity
134+
extendCommuteMap name commutativity = local $ insertNameMapE name $ LiftE commutativity
135135

136136
vectorizeLoopsDestBlock :: DestBlock i
137137
-> TopVectorizeM i o (DestBlock o)
@@ -309,9 +309,9 @@ vectorSafeEffect (EffectRow effs NoTail) = allM safe $ eSetToList effs where
309309
safe (RWSEffect Writer (Var h)) = do
310310
h' <- renameM $ atomVarName h
311311
commuteMap <- ask
312-
case lookupNameMap h' commuteMap of
313-
Just Commutes -> return True
314-
Just DoesNotCommute -> return False
312+
case lookupNameMapE h' commuteMap of
313+
Just (LiftE Commutes) -> return True
314+
Just (LiftE DoesNotCommute) -> return False
315315
Nothing -> error $ "Handle " ++ pprint h ++ " not present in commute map?"
316316
safe _ = return False
317317

0 commit comments

Comments
 (0)