@@ -78,6 +78,7 @@ module Data.HashMap.Internal
78
78
, intersection
79
79
, intersectionWith
80
80
, intersectionWithKey
81
+ , intersectionWithKey #
81
82
82
83
-- * Folds
83
84
, foldr'
@@ -150,9 +151,9 @@ import Data.Data (Constr, Data (..), DataType)
150
151
import Data.Functor.Classes (Eq1 (.. ), Eq2 (.. ), Ord1 (.. ), Ord2 (.. ),
151
152
Read1 (.. ), Show1 (.. ), Show2 (.. ))
152
153
import Data.Functor.Identity (Identity (.. ))
153
- import Data.HashMap.Internal.List (isPermutationBy , unorderedCompare )
154
154
import Data.Hashable (Hashable )
155
155
import Data.Hashable.Lifted (Hashable1 , Hashable2 )
156
+ import Data.HashMap.Internal.List (isPermutationBy , unorderedCompare )
156
157
import Data.Semigroup (Semigroup (.. ), stimesIdempotentMonoid )
157
158
import GHC.Exts (Int (.. ), Int #, TYPE , (==#) )
158
159
import GHC.Stack (HasCallStack )
@@ -163,9 +164,9 @@ import Text.Read hiding (step)
163
164
import qualified Data.Data as Data
164
165
import qualified Data.Foldable as Foldable
165
166
import qualified Data.Functor.Classes as FC
166
- import qualified Data.HashMap.Internal.Array as A
167
167
import qualified Data.Hashable as H
168
168
import qualified Data.Hashable.Lifted as H
169
+ import qualified Data.HashMap.Internal.Array as A
169
170
import qualified Data.List as List
170
171
import qualified GHC.Exts as Exts
171
172
import qualified Language.Haskell.TH.Syntax as TH
@@ -1627,7 +1628,7 @@ unionArrayBy f !b1 !b2 !ary1 !ary2 = A.run $ do
1627
1628
A. write mary i =<< A. indexM ary2 i2
1628
1629
go (i+ 1 ) i1 (i2+ 1 ) b'
1629
1630
where
1630
- m = 1 `unsafeShiftL` ( countTrailingZeros b)
1631
+ m = 1 `unsafeShiftL` countTrailingZeros b
1631
1632
testBit x = x .&. m /= 0
1632
1633
b' = b .&. complement m
1633
1634
go 0 0 0 bCombined
@@ -1759,37 +1760,161 @@ differenceWith f a b = foldlWithKey' go empty a
1759
1760
-- | \(O(n \log m)\) Intersection of two maps. Return elements of the first
1760
1761
-- map for keys existing in the second.
1761
1762
intersection :: (Eq k , Hashable k ) => HashMap k v -> HashMap k w -> HashMap k v
1762
- intersection a b = foldlWithKey' go empty a
1763
- where
1764
- go m k v = case lookup k b of
1765
- Just _ -> unsafeInsert k v m
1766
- _ -> m
1763
+ intersection = Exts. inline intersectionWith const
1767
1764
{-# INLINABLE intersection #-}
1768
1765
1769
1766
-- | \(O(n \log m)\) Intersection of two maps. If a key occurs in both maps
1770
1767
-- the provided function is used to combine the values from the two
1771
1768
-- maps.
1772
- intersectionWith :: (Eq k , Hashable k ) => (v1 -> v2 -> v3 ) -> HashMap k v1
1773
- -> HashMap k v2 -> HashMap k v3
1774
- intersectionWith f a b = foldlWithKey' go empty a
1775
- where
1776
- go m k v = case lookup k b of
1777
- Just w -> unsafeInsert k (f v w) m
1778
- _ -> m
1769
+ intersectionWith :: (Eq k , Hashable k ) => (v1 -> v2 -> v3 ) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
1770
+ intersectionWith f = Exts. inline intersectionWithKey $ const f
1779
1771
{-# INLINABLE intersectionWith #-}
1780
1772
1781
1773
-- | \(O(n \log m)\) Intersection of two maps. If a key occurs in both maps
1782
1774
-- the provided function is used to combine the values from the two
1783
1775
-- maps.
1784
- intersectionWithKey :: (Eq k , Hashable k ) => (k -> v1 -> v2 -> v3 )
1785
- -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
1786
- intersectionWithKey f a b = foldlWithKey' go empty a
1787
- where
1788
- go m k v = case lookup k b of
1789
- Just w -> unsafeInsert k (f k v w) m
1790
- _ -> m
1776
+ intersectionWithKey :: (Eq k , Hashable k ) => (k -> v1 -> v2 -> v3 ) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
1777
+ intersectionWithKey f = intersectionWithKey# $ \ k v1 v2 -> (# f k v1 v2 # )
1791
1778
{-# INLINABLE intersectionWithKey #-}
1792
1779
1780
+ intersectionWithKey# :: Eq k => (k -> v1 -> v2 -> (# v3 # )) -> HashMap k v1 -> HashMap k v2 -> HashMap k v3
1781
+ intersectionWithKey# f = go 0
1782
+ where
1783
+ -- empty vs. anything
1784
+ go ! _ _ Empty = Empty
1785
+ go _ Empty _ = Empty
1786
+ -- leaf vs. anything
1787
+ go s (Leaf h1 (L k1 v1)) t2 =
1788
+ lookupCont
1789
+ (\ _ -> Empty )
1790
+ (\ v _ -> case f k1 v1 v of (# v' # ) -> Leaf h1 $ L k1 v')
1791
+ h1 k1 s t2
1792
+ go s t1 (Leaf h2 (L k2 v2)) =
1793
+ lookupCont
1794
+ (\ _ -> Empty )
1795
+ (\ v _ -> case f k2 v v2 of (# v' # ) -> Leaf h2 $ L k2 v')
1796
+ h2 k2 s t1
1797
+ -- collision vs. collision
1798
+ go _ (Collision h1 ls1) (Collision h2 ls2) = intersectionCollisions f h1 h2 ls1 ls2
1799
+ -- branch vs. branch
1800
+ go s (BitmapIndexed b1 ary1) (BitmapIndexed b2 ary2) =
1801
+ intersectionArrayBy (go (s + bitsPerSubkey)) b1 b2 ary1 ary2
1802
+ go s (BitmapIndexed b1 ary1) (Full ary2) =
1803
+ intersectionArrayBy (go (s + bitsPerSubkey)) b1 fullNodeMask ary1 ary2
1804
+ go s (Full ary1) (BitmapIndexed b2 ary2) =
1805
+ intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask b2 ary1 ary2
1806
+ go s (Full ary1) (Full ary2) =
1807
+ intersectionArrayBy (go (s + bitsPerSubkey)) fullNodeMask fullNodeMask ary1 ary2
1808
+ -- collision vs. branch
1809
+ go s (BitmapIndexed b1 ary1) t2@ (Collision h2 _ls2)
1810
+ | b1 .&. m2 == 0 = Empty
1811
+ | otherwise = go (s + bitsPerSubkey) (A. index ary1 i) t2
1812
+ where
1813
+ m2 = mask h2 s
1814
+ i = sparseIndex b1 m2
1815
+ go s t1@ (Collision h1 _ls1) (BitmapIndexed b2 ary2)
1816
+ | b2 .&. m1 == 0 = Empty
1817
+ | otherwise = go (s + bitsPerSubkey) t1 (A. index ary2 i)
1818
+ where
1819
+ m1 = mask h1 s
1820
+ i = sparseIndex b2 m1
1821
+ go s (Full ary1) t2@ (Collision h2 _ls2) = go (s + bitsPerSubkey) (A. index ary1 i) t2
1822
+ where
1823
+ i = index h2 s
1824
+ go s t1@ (Collision h1 _ls1) (Full ary2) = go (s + bitsPerSubkey) t1 (A. index ary2 i)
1825
+ where
1826
+ i = index h1 s
1827
+ {-# INLINE intersectionWithKey# #-}
1828
+
1829
+ intersectionArrayBy ::
1830
+ ( HashMap k v1 ->
1831
+ HashMap k v2 ->
1832
+ HashMap k v3
1833
+ ) ->
1834
+ Bitmap ->
1835
+ Bitmap ->
1836
+ A. Array (HashMap k v1 ) ->
1837
+ A. Array (HashMap k v2 ) ->
1838
+ HashMap k v3
1839
+ intersectionArrayBy f ! b1 ! b2 ! ary1 ! ary2
1840
+ | b1 .&. b2 == 0 = Empty
1841
+ | otherwise = runST $ do
1842
+ mary <- A. new_ $ popCount bIntersect
1843
+ -- iterate over nonzero bits of b1 .|. b2
1844
+ let go ! i ! i1 ! i2 ! b ! bFinal
1845
+ | b == 0 = pure (i, bFinal)
1846
+ | testBit $ b1 .&. b2 = do
1847
+ x1 <- A. indexM ary1 i1
1848
+ x2 <- A. indexM ary2 i2
1849
+ case f x1 x2 of
1850
+ Empty -> go i (i1 + 1 ) (i2 + 1 ) b' (bFinal .&. complement m)
1851
+ _ -> do
1852
+ A. write mary i $! f x1 x2
1853
+ go (i + 1 ) (i1 + 1 ) (i2 + 1 ) b' bFinal
1854
+ | testBit b1 = go i (i1 + 1 ) i2 b' bFinal
1855
+ | otherwise = go i i1 (i2 + 1 ) b' bFinal
1856
+ where
1857
+ m = 1 `unsafeShiftL` countTrailingZeros b
1858
+ testBit x = x .&. m /= 0
1859
+ b' = b .&. complement m
1860
+ (len, bFinal) <- go 0 0 0 bCombined bIntersect
1861
+ case len of
1862
+ 0 -> pure Empty
1863
+ 1 -> A. read mary 0
1864
+ _ -> bitmapIndexedOrFull bFinal <$> (A. unsafeFreeze =<< A. shrink mary len)
1865
+ where
1866
+ bCombined = b1 .|. b2
1867
+ bIntersect = b1 .&. b2
1868
+ {-# INLINE intersectionArrayBy #-}
1869
+
1870
+ intersectionCollisions :: Eq k => (k -> v1 -> v2 -> (# v3 # )) -> Hash -> Hash -> A. Array (Leaf k v1 ) -> A. Array (Leaf k v2 ) -> HashMap k v3
1871
+ intersectionCollisions f h1 h2 ary1 ary2
1872
+ | h1 == h2 = runST $ do
1873
+ mary2 <- A. thaw ary2 0 $ A. length ary2
1874
+ mary <- A. new_ $ min (A. length ary1) (A. length ary2)
1875
+ let go i j
1876
+ | i >= A. length ary1 || j >= A. lengthM mary2 = pure j
1877
+ | otherwise = do
1878
+ L k1 v1 <- A. indexM ary1 i
1879
+ searchSwap k1 j mary2 >>= \ case
1880
+ Just (L _k2 v2) -> do
1881
+ let ! (# v3 # ) = f k1 v1 v2
1882
+ A. write mary j $ L k1 v3
1883
+ go (i + 1 ) (j + 1 )
1884
+ Nothing -> do
1885
+ go (i + 1 ) j
1886
+ len <- go 0 0
1887
+ case len of
1888
+ 0 -> pure Empty
1889
+ 1 -> Leaf h1 <$> A. read mary 0
1890
+ _ -> Collision h1 <$> (A. unsafeFreeze =<< A. shrink mary len)
1891
+ | otherwise = Empty
1892
+ {-# INLINE intersectionCollisions #-}
1893
+
1894
+ -- | Say we have
1895
+ -- @
1896
+ -- 1 2 3 4
1897
+ -- @
1898
+ -- and we search for @3@. Then we can mutate the array to
1899
+ -- @
1900
+ -- undefined 2 1 4
1901
+ -- @
1902
+ -- We don't actually need to write undefined, we just have to make sure that the next search starts 1 after the current one.
1903
+ searchSwap :: Eq k => k -> Int -> A. MArray s (Leaf k v ) -> ST s (Maybe (Leaf k v ))
1904
+ searchSwap toFind start = go start toFind start
1905
+ where
1906
+ go i0 k i mary
1907
+ | i >= A. lengthM mary = pure Nothing
1908
+ | otherwise = do
1909
+ l@ (L k' _v) <- A. read mary i
1910
+ if k == k'
1911
+ then do
1912
+ A. write mary i =<< A. read mary i0
1913
+ pure $ Just l
1914
+ else go i0 k (i + 1 ) mary
1915
+ {-# INLINE searchSwap #-}
1916
+
1917
+
1793
1918
------------------------------------------------------------------------
1794
1919
-- * Folds
1795
1920
0 commit comments