Skip to content

Commit b73381e

Browse files
oberblastmeistertreeowlsjakobi
authored
Make intersections much faster (#406)
Fixes #225. Co-authored-by: David Feuer <[email protected]> Co-authored-by: Simon Jakobi <[email protected]>
1 parent 807e3a4 commit b73381e

File tree

5 files changed

+178
-37
lines changed

5 files changed

+178
-37
lines changed

Data/HashMap/Internal.hs

+147-22
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ module Data.HashMap.Internal
7878
, intersection
7979
, intersectionWith
8080
, intersectionWithKey
81+
, intersectionWithKey#
8182

8283
-- * Folds
8384
, foldr'
@@ -150,9 +151,9 @@ import Data.Data (Constr, Data (..), DataType)
150151
import Data.Functor.Classes (Eq1 (..), Eq2 (..), Ord1 (..), Ord2 (..),
151152
Read1 (..), Show1 (..), Show2 (..))
152153
import Data.Functor.Identity (Identity (..))
153-
import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare)
154154
import Data.Hashable (Hashable)
155155
import Data.Hashable.Lifted (Hashable1, Hashable2)
156+
import Data.HashMap.Internal.List (isPermutationBy, unorderedCompare)
156157
import Data.Semigroup (Semigroup (..), stimesIdempotentMonoid)
157158
import GHC.Exts (Int (..), Int#, TYPE, (==#))
158159
import GHC.Stack (HasCallStack)
@@ -163,9 +164,9 @@ import Text.Read hiding (step)
163164
import qualified Data.Data as Data
164165
import qualified Data.Foldable as Foldable
165166
import qualified Data.Functor.Classes as FC
166-
import qualified Data.HashMap.Internal.Array as A
167167
import qualified Data.Hashable as H
168168
import qualified Data.Hashable.Lifted as H
169+
import qualified Data.HashMap.Internal.Array as A
169170
import qualified Data.List as List
170171
import qualified GHC.Exts as Exts
171172
import qualified Language.Haskell.TH.Syntax as TH
@@ -1627,7 +1628,7 @@ unionArrayBy f !b1 !b2 !ary1 !ary2 = A.run $ do
16271628
A.write mary i =<< A.indexM ary2 i2
16281629
go (i+1) i1 (i2+1) b'
16291630
where
1630-
m = 1 `unsafeShiftL` (countTrailingZeros b)
1631+
m = 1 `unsafeShiftL` countTrailingZeros b
16311632
testBit x = x .&. m /= 0
16321633
b' = b .&. complement m
16331634
go 0 0 0 bCombined
@@ -1759,37 +1760,161 @@ differenceWith f a b = foldlWithKey' go empty a
17591760
-- | \(O(n \log m)\) Intersection of two maps. Return elements of the first
17601761
-- map for keys existing in the second.
17611762
intersection :: (Eq k, Hashable k) => HashMap k v -> HashMap k w -> HashMap k v
1762-
intersection a b = foldlWithKey' go empty a
1763-
where
1764-
go m k v = case lookup k b of
1765-
Just _ -> unsafeInsert k v m
1766-
_ -> m
1763+
intersection = Exts.inline intersectionWith const
17671764
{-# INLINABLE intersection #-}
17681765

17691766
-- | \(O(n \log m)\) Intersection of two maps. If a key occurs in both maps
17701767
-- the provided function is used to combine the values from the two
17711768
-- maps.
1772-
intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1
1773-
-> HashMap k v2 -> HashMap k v3
1774-
intersectionWith f a b = foldlWithKey' go empty a
1775-
where
1776-
go m k v = case lookup k b of
1777-
Just w -> unsafeInsert k (f v w) m
1778-
_ -> m
1769+
intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
1770+
intersectionWith f = Exts.inline intersectionWithKey $ const f
17791771
{-# INLINABLE intersectionWith #-}
17801772

17811773
-- | \(O(n \log m)\) Intersection of two maps. If a key occurs in both maps
17821774
-- the provided function is used to combine the values from the two
17831775
-- maps.
1784-
intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3)
1785-
-> HashMap k v1 -> HashMap k v2 -> HashMap k v3
1786-
intersectionWithKey f a b = foldlWithKey' go empty a
1787-
where
1788-
go m k v = case lookup k b of
1789-
Just w -> unsafeInsert k (f k v w) m
1790-
_ -> m
1776+
intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
1777+
intersectionWithKey f = intersectionWithKey# $ \k v1 v2 -> (# f k v1 v2 #)
17911778
{-# INLINABLE intersectionWithKey #-}
17921779

1780+
intersectionWithKey# :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
1781+
intersectionWithKey# f = go 0
1782+
where
1783+
-- empty vs. anything
1784+
go !_ _ Empty = Empty
1785+
go _ Empty _ = Empty
1786+
-- leaf vs. anything
1787+
go s (Leaf h1 (L k1 v1)) t2 =
1788+
lookupCont
1789+
(\_ -> Empty)
1790+
(\v _ -> case f k1 v1 v of (# v' #) -> Leaf h1 $ L k1 v')
1791+
h1 k1 s t2
1792+
go s t1 (Leaf h2 (L k2 v2)) =
1793+
lookupCont
1794+
(\_ -> Empty)
1795+
(\v _ -> case f k2 v v2 of (# v' #) -> Leaf h2 $ L k2 v')
1796+
h2 k2 s t1
1797+
-- collision vs. collision
1798+
go _ (Collision h1 ls1) (Collision h2 ls2) = intersectionCollisions f h1 h2 ls1 ls2
1799+
-- branch vs. branch
1800+
go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) =
1801+
intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2
1802+
go s (BitmapIndexed b1 ary1) (Full ary2) =
1803+
intersectionArrayBy (go (s + bitsPerSubkey)) b1 fullNodeMask ary1 ary2
1804+
go s (Full ary1) (BitmapIndexed b2 ary2) =
1805+
intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask b2 ary1 ary2
1806+
go s (Full ary1) (Full ary2) =
1807+
intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask fullNodeMask ary1 ary2
1808+
-- collision vs. branch
1809+
go s (BitmapIndexed b1 ary1) t2@(Collision h2 _ls2)
1810+
| b1 .&. m2 == 0 = Empty
1811+
| otherwise = go (s + bitsPerSubkey) (A.index ary1 i) t2
1812+
where
1813+
m2 = mask h2 s
1814+
i = sparseIndex b1 m2
1815+
go s t1@(Collision h1 _ls1) (BitmapIndexed b2 ary2)
1816+
| b2 .&. m1 == 0 = Empty
1817+
| otherwise = go (s + bitsPerSubkey) t1 (A.index ary2 i)
1818+
where
1819+
m1 = mask h1 s
1820+
i = sparseIndex b2 m1
1821+
go s (Full ary1) t2@(Collision h2 _ls2) = go (s + bitsPerSubkey) (A.index ary1 i) t2
1822+
where
1823+
i = index h2 s
1824+
go s t1@(Collision h1 _ls1) (Full ary2) = go (s + bitsPerSubkey) t1 (A.index ary2 i)
1825+
where
1826+
i = index h1 s
1827+
{-# INLINE intersectionWithKey# #-}
1828+
1829+
intersectionArrayBy ::
1830+
( HashMap k v1 ->
1831+
HashMap k v2 ->
1832+
HashMap k v3
1833+
) ->
1834+
Bitmap ->
1835+
Bitmap ->
1836+
A.Array (HashMap k v1) ->
1837+
A.Array (HashMap k v2) ->
1838+
HashMap k v3
1839+
intersectionArrayBy f !b1 !b2 !ary1 !ary2
1840+
| b1 .&. b2 == 0 = Empty
1841+
| otherwise = runST $ do
1842+
mary <- A.new_ $ popCount bIntersect
1843+
-- iterate over nonzero bits of b1 .|. b2
1844+
let go !i !i1 !i2 !b !bFinal
1845+
| b == 0 = pure (i, bFinal)
1846+
| testBit $ b1 .&. b2 = do
1847+
x1 <- A.indexM ary1 i1
1848+
x2 <- A.indexM ary2 i2
1849+
case f x1 x2 of
1850+
Empty -> go i (i1 + 1) (i2 + 1) b' (bFinal .&. complement m)
1851+
_ -> do
1852+
A.write mary i $! f x1 x2
1853+
go (i + 1) (i1 + 1) (i2 + 1) b' bFinal
1854+
| testBit b1 = go i (i1 + 1) i2 b' bFinal
1855+
| otherwise = go i i1 (i2 + 1) b' bFinal
1856+
where
1857+
m = 1 `unsafeShiftL` countTrailingZeros b
1858+
testBit x = x .&. m /= 0
1859+
b' = b .&. complement m
1860+
(len, bFinal) <- go 0 0 0 bCombined bIntersect
1861+
case len of
1862+
0 -> pure Empty
1863+
1 -> A.read mary 0
1864+
_ -> bitmapIndexedOrFull bFinal <$> (A.unsafeFreeze =<< A.shrink mary len)
1865+
where
1866+
bCombined = b1 .|. b2
1867+
bIntersect = b1 .&. b2
1868+
{-# INLINE intersectionArrayBy #-}
1869+
1870+
intersectionCollisions :: Eq k => (k -> v1 -> v2 -> (# v3 #)) -> Hash -> Hash -> A.Array (Leaf k v1) -> A.Array (Leaf k v2) -> HashMap k v3
1871+
intersectionCollisions f h1 h2 ary1 ary2
1872+
| h1 == h2 = runST $ do
1873+
mary2 <- A.thaw ary2 0 $ A.length ary2
1874+
mary <- A.new_ $ min (A.length ary1) (A.length ary2)
1875+
let go i j
1876+
| i >= A.length ary1 || j >= A.lengthM mary2 = pure j
1877+
| otherwise = do
1878+
L k1 v1 <- A.indexM ary1 i
1879+
searchSwap k1 j mary2 >>= \case
1880+
Just (L _k2 v2) -> do
1881+
let !(# v3 #) = f k1 v1 v2
1882+
A.write mary j $ L k1 v3
1883+
go (i + 1) (j + 1)
1884+
Nothing -> do
1885+
go (i + 1) j
1886+
len <- go 0 0
1887+
case len of
1888+
0 -> pure Empty
1889+
1 -> Leaf h1 <$> A.read mary 0
1890+
_ -> Collision h1 <$> (A.unsafeFreeze =<< A.shrink mary len)
1891+
| otherwise = Empty
1892+
{-# INLINE intersectionCollisions #-}
1893+
1894+
-- | Say we have
1895+
-- @
1896+
-- 1 2 3 4
1897+
-- @
1898+
-- and we search for @3@. Then we can mutate the array to
1899+
-- @
1900+
-- undefined 2 1 4
1901+
-- @
1902+
-- We don't actually need to write undefined, we just have to make sure that the next search starts 1 after the current one.
1903+
searchSwap :: Eq k => k -> Int -> A.MArray s (Leaf k v) -> ST s (Maybe (Leaf k v))
1904+
searchSwap toFind start = go start toFind start
1905+
where
1906+
go i0 k i mary
1907+
| i >= A.lengthM mary = pure Nothing
1908+
| otherwise = do
1909+
l@(L k' _v) <- A.read mary i
1910+
if k == k'
1911+
then do
1912+
A.write mary i =<< A.read mary i0
1913+
pure $ Just l
1914+
else go i0 k (i + 1) mary
1915+
{-# INLINE searchSwap #-}
1916+
1917+
17931918
------------------------------------------------------------------------
17941919
-- * Folds
17951920

Data/HashMap/Internal/Array.hs

+16
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ module Data.HashMap.Internal.Array
7777
, toList
7878
, fromList
7979
, fromList'
80+
, shrink
8081
) where
8182

8283
import Control.Applicative (liftA2)
@@ -92,6 +93,7 @@ import GHC.Exts (Int (..), SmallArray#, SmallMutableArray#,
9293
thawSmallArray#, unsafeCoerce#,
9394
unsafeFreezeSmallArray#, unsafeThawSmallArray#,
9495
writeSmallArray#)
96+
import qualified GHC.Exts as Exts
9597
import GHC.ST (ST (..))
9698
import Prelude hiding (all, filter, foldMap, foldl, foldr, length,
9799
map, read, traverse)
@@ -205,6 +207,20 @@ new _n@(I# n#) b =
205207
new_ :: Int -> ST s (MArray s a)
206208
new_ n = new n undefinedElem
207209

210+
-- | When 'Exts.shrinkSmallMutableArray#' is available, the returned array is the same as the array given, as it is shrunk in place.
211+
-- Otherwise a copy is made.
212+
shrink :: MArray s a -> Int -> ST s (MArray s a)
213+
#if __GLASGOW_HASKELL__ >= 810
214+
shrink mary _n@(I# n#) =
215+
CHECK_GT("shrink", _n, (0 :: Int))
216+
CHECK_LE("shrink", _n, (lengthM mary))
217+
ST $ \s -> case Exts.shrinkSmallMutableArray# (unMArray mary) n# s of
218+
s' -> (# s', mary #)
219+
#else
220+
shrink mary n = cloneM mary 0 n
221+
#endif
222+
{-# INLINE shrink #-}
223+
208224
singleton :: a -> Array a
209225
singleton x = runST (singletonM x)
210226
{-# INLINE singleton #-}

Data/HashMap/Internal/Strict.hs

+4-11
Original file line numberDiff line numberDiff line change
@@ -128,16 +128,17 @@ import Data.Bits ((.&.), (.|.))
128128
import Data.Coerce (coerce)
129129
import Data.Functor.Identity (Identity (..))
130130
-- See Note [Imports from Data.HashMap.Internal]
131+
import Data.Hashable (Hashable)
131132
import Data.HashMap.Internal (Hash, HashMap (..), Leaf (..), LookupRes (..),
132133
bitsPerSubkey, fullNodeMask, hash, index, mask,
133134
ptrEq, sparseIndex)
134-
import Data.Hashable (Hashable)
135135
import Prelude hiding (lookup, map)
136136

137137
-- See Note [Imports from Data.HashMap.Internal]
138138
import qualified Data.HashMap.Internal as HM
139139
import qualified Data.HashMap.Internal.Array as A
140140
import qualified Data.List as List
141+
import qualified GHC.Exts as Exts
141142

142143
{-
143144
Note [Imports from Data.HashMap.Internal]
@@ -616,23 +617,15 @@ differenceWith f a b = HM.foldlWithKey' go HM.empty a
616617
-- maps.
617618
intersectionWith :: (Eq k, Hashable k) => (v1 -> v2 -> v3) -> HashMap k v1
618619
-> HashMap k v2 -> HashMap k v3
619-
intersectionWith f a b = HM.foldlWithKey' go HM.empty a
620-
where
621-
go m k v = case HM.lookup k b of
622-
Just w -> let !x = f v w in HM.unsafeInsert k x m
623-
_ -> m
620+
intersectionWith f = Exts.inline intersectionWithKey $ const f
624621
{-# INLINABLE intersectionWith #-}
625622

626623
-- | \(O(n+m)\) Intersection of two maps. If a key occurs in both maps
627624
-- the provided function is used to combine the values from the two
628625
-- maps.
629626
intersectionWithKey :: (Eq k, Hashable k) => (k -> v1 -> v2 -> v3)
630627
-> HashMap k v1 -> HashMap k v2 -> HashMap k v3
631-
intersectionWithKey f a b = HM.foldlWithKey' go HM.empty a
632-
where
633-
go m k v = case HM.lookup k b of
634-
Just w -> let !x = f k v w in HM.unsafeInsert k x m
635-
_ -> m
628+
intersectionWithKey f = HM.intersectionWithKey# $ \k v1 v2 -> let !v3 = f k v1 v2 in (# v3 #)
636629
{-# INLINABLE intersectionWithKey #-}
637630

638631
------------------------------------------------------------------------

benchmarks/Benchmarks.hs

+5-1
Original file line numberDiff line numberDiff line change
@@ -318,13 +318,17 @@ main = do
318318
[ bench "Int" $ whnf (HM.union hmi) hmi2
319319
, bench "ByteString" $ whnf (HM.union hmbs) hmbsSubset
320320
]
321+
322+
, bgroup "intersection"
323+
[ bench "Int" $ whnf (HM.intersection hmi) hmi2
324+
, bench "ByteString" $ whnf (HM.intersection hmbs) hmbsSubset
325+
]
321326

322327
-- Transformations
323328
, bench "map" $ whnf (HM.map (\ v -> v + 1)) hmi
324329

325330
-- * Difference and intersection
326331
, bench "difference" $ whnf (HM.difference hmi) hmi2
327-
, bench "intersection" $ whnf (HM.intersection hmi) hmi2
328332

329333
-- Folds
330334
, bench "foldl'" $ whnf (HM.foldl' (+) 0) hmi

tests/Properties/HashMapLazy.hs

+6-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
{-# LANGUAGE CPP #-}
22
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
33
{-# OPTIONS_GHC -fno-warn-orphans #-} -- because of Arbitrary (HashMap k v)
4+
{-# LANGUAGE BangPatterns #-}
45

56
-- | Tests for the 'Data.HashMap.Lazy' module. We test functions by
67
-- comparing them to @Map@ from @containers@.
@@ -42,7 +43,7 @@ import qualified Data.Map.Lazy as M
4243

4344
-- Key type that generates more hash collisions.
4445
newtype Key = K { unK :: Int }
45-
deriving (Arbitrary, Eq, Ord, Read, Show)
46+
deriving (Arbitrary, Eq, Ord, Read, Show, Num)
4647

4748
instance Hashable Key where
4849
hashWithSalt salt k = hashWithSalt salt (unK k) `mod` 20
@@ -318,8 +319,10 @@ pDifferenceWith xs ys = M.differenceWith f (M.fromList xs) `eq_`
318319
f x y = if x == 0 then Nothing else Just (x - y)
319320

320321
pIntersection :: [(Key, Int)] -> [(Key, Int)] -> Bool
321-
pIntersection xs ys = M.intersection (M.fromList xs) `eq_`
322-
HM.intersection (HM.fromList xs) $ ys
322+
pIntersection xs ys =
323+
M.intersection (M.fromList xs)
324+
`eq_` HM.intersection (HM.fromList xs)
325+
$ ys
323326

324327
pIntersectionWith :: [(Key, Int)] -> [(Key, Int)] -> Bool
325328
pIntersectionWith xs ys = M.intersectionWith (-) (M.fromList xs) `eq_`

0 commit comments

Comments
 (0)