diff --git a/immut/hashmap/HAMT.mbt b/immut/hashmap/HAMT.mbt
index 4a685aa8c..6f1b9039b 100644
--- a/immut/hashmap/HAMT.mbt
+++ b/immut/hashmap/HAMT.mbt
@@ -329,6 +329,163 @@ pub fn union[K : Eq + Hash, V](self : T[K, V], other : T[K, V]) -> T[K, V] {
   }
 }
 
+///|
+/// Union two hashmaps, resolving keys present in both maps with `f`.
+///
+/// `f` is called as `f(key, value_in_self, value_in_other)`. Keys present in
+/// only one of the maps keep their value.
+pub fn union_with[K : Eq + Hash, V](
+  self : T[K, V],
+  other : T[K, V],
+  f : (K, V, V) -> V
+) -> T[K, V] {
+  match (self, other) {
+    (_, Empty) => self
+    (Empty, _) => other
+    (_, Leaf(k, v)) =>
+      match self.get(k) {
+        Some(v1) => self.add(k, f(k, v1, v))
+        None => self.add(k, v)
+      }
+    (Leaf(k, v), _) =>
+      match other.get(k) {
+        Some(v2) => other.add(k, f(k, v, v2))
+        None => other.add(k, v)
+      }
+    (Branch(sa1), Branch(sa2)) =>
+      Branch(sa1.union(sa2, fn(m1, m2) { m1.union_with(m2, f) }))
+    // Any other pairing: insert the entries of `self` into `other` one by one.
+    (_, _) =>
+      self
+      .iter()
+      .fold(init=other, fn(m, kv) {
+        match m.get(kv.0) {
+          Some(v2) => m.add(kv.0, f(kv.0, kv.1, v2))
+          None => m.add(kv.0, kv.1)
+        }
+      })
+  }
+}
+
+///|
+/// Intersect two hashmaps: keep the keys present in both maps, with the
+/// values taken from `self`.
+pub fn intersection[K : Eq + Hash, V](
+  self : T[K, V],
+  other : T[K, V]
+) -> T[K, V] {
+  match (self, other) {
+    (_, Empty) => Empty
+    (Empty, _) => Empty
+    (Leaf(k, v), _) =>
+      match other.get(k) {
+        Some(_) => Leaf(k, v)
+        None => Empty
+      }
+    (_, Leaf(k, _)) =>
+      match self.get(k) {
+        Some(v) => Leaf(k, v)
+        None => Empty
+      }
+    (Branch(sa1), Branch(sa2)) => {
+      let res = sa1.intersection(sa2, fn(m1, m2) { m1.intersection(m2) })
+      // No slot in common: return the canonical `Empty` rather than a
+      // `Branch` without children.
+      // NOTE(review): a child that intersects to `Empty` still keeps its
+      // slot in `res`; confirm callers tolerate such nodes.
+      if res.size() == 0 {
+        Empty
+      } else {
+        Branch(res)
+      }
+    }
+    (_, _) =>
+      self
+      .iter()
+      .fold(init=Empty, fn(m, kv) {
+        if other.get(kv.0) is Some(_) {
+          m.add(kv.0, kv.1)
+        } else {
+          m
+        }
+      })
+  }
+}
+
+///|
+/// Intersect two hashmaps, combining the two values of every shared key
+/// with `f`, called as `f(key, value_in_self, value_in_other)`.
+pub fn intersection_with[K : Eq + Hash, V](
+  self : T[K, V],
+  other : T[K, V],
+  f : (K, V, V) -> V
+) -> T[K, V] {
+  match (self, other) {
+    (_, Empty) => Empty
+    (Empty, _) => Empty
+    (Leaf(k, v), _) =>
+      match other.get(k) {
+        Some(v2) => Leaf(k, f(k, v, v2))
+        None => Empty
+      }
+    (_, Leaf(k, v2)) =>
+      match self.get(k) {
+        Some(v1) => Leaf(k, f(k, v1, v2))
+        None => Empty
+      }
+    (Branch(sa1), Branch(sa2)) => {
+      let res = sa1.intersection(sa2, fn(m1, m2) { m1.intersection_with(m2, f) })
+      // No slot in common: return the canonical `Empty` rather than a
+      // `Branch` without children.
+      // NOTE(review): a child that intersects to `Empty` still keeps its
+      // slot in `res`; confirm callers tolerate such nodes.
+      if res.size() == 0 {
+        Empty
+      } else {
+        Branch(res)
+      }
+    }
+    (_, _) =>
+      self
+      .iter()
+      .fold(init=Empty, fn(m, kv) {
+        match other.get(kv.0) {
+          Some(v2) => m.add(kv.0, f(kv.0, kv.1, v2))
+          None => m
+        }
+      })
+  }
+}
+
+///|
+/// Difference of two hashmaps: the entries of `self` whose key is not in
+/// `other`.
+pub fn difference[K : Eq + Hash, V](self : T[K, V], other : T[K, V]) -> T[K, V] {
+  match (self, other) {
+    (Empty, _) => Empty
+    (_, Empty) => self
+    (Leaf(k, v), _) =>
+      match other.get(k) {
+        Some(_) => Empty
+        None => Leaf(k, v)
+      }
+    // Two branches are compared key by key as well. A slot-level
+    // `SparseArray::difference` would drop every slot occupied on both sides,
+    // losing keys of `self` that merely share a slot with different keys of
+    // `other`.
+    (_, _) =>
+      self
+      .iter()
+      .fold(init=Empty, fn(m, kv) {
+        if other.get(kv.0) is None {
+          m.add(kv.0, kv.1)
+        } else {
+          m
+        }
+      })
+  }
+}
+
 ///|
 /// Iterate through the elements in a hash map
 pub fn each[K, V](self : T[K, V], f : (K, V) -> Unit) -> Unit {
diff --git a/immut/hashmap/HAMT_test.mbt b/immut/hashmap/HAMT_test.mbt
index 63e88b6cc..e6752d1a2 100644
--- a/immut/hashmap/HAMT_test.mbt
+++ b/immut/hashmap/HAMT_test.mbt
@@ -580,6 +580,116 @@ test "HAMT::union leaf to non-overlapping map" {
   assert_eq!(u.get(2), Some(2))
 }
 
+///|
+test "HAMT::intersection with empty" {
+  let m1 = @hashmap.of([(1, 1)])
+  let m2 = @hashmap.new()
+  assert_eq!(m1.intersection(m2).size(), 0)
+  assert_eq!(m2.intersection(m1).size(), 0)
+}
+
+///|
+test "HAMT::intersection with leaf" {
+  let m1 = @hashmap.of([(1, 1), (2, 2)])
+  let m2 = @hashmap.singleton(2, 2)
+  let m3 = @hashmap.singleton(3, 3)
+  assert_eq!(m1.intersection(m2).get(2), Some(2))
+  assert_eq!(m1.intersection(m3).get(3), None)
+  assert_eq!(m2.intersection(m1).get(2), Some(2))
+}
+
+///|
+test "HAMT::intersection with branch" {
+  let m1 = @hashmap.of([(1, 1), (2, 2), (3, 3)])
+  let m2 = @hashmap.of([(2, 2), (3, 30), (4, 4)])
+  let inter = m1.intersection(m2)
+  assert_eq!(inter.get(1), None)
+  assert_eq!(inter.get(2), Some(2))
+  assert_eq!(inter.get(3), Some(3))
+  assert_eq!(inter.get(4), None)
+}
+
+///|
+test "HAMT::intersection_with with empty" {
+  let m1 = @hashmap.of([(1, 1)])
+  let m2 = @hashmap.new()
+  assert_eq!(m1.intersection_with(m2, fn(_k, v1, v2) { v1 + v2 }).size(), 0)
+  assert_eq!(m2.intersection_with(m1, fn(_k, v1, v2) { v1 + v2 }).size(), 0)
+}
+
+///|
+test "HAMT::intersection_with with leaf" {
+  let m1 = @hashmap.of([(1, 1), (2, 2)])
+  let m2 = @hashmap.singleton(2, 20)
+  let m3 = @hashmap.singleton(3, 30)
+  assert_eq!(
+    m1.intersection_with(m2, fn(_k, v1, v2) { v1 + v2 }).get(2),
+    Some(22),
+  )
+  assert_eq!(m1.intersection_with(m3, fn(_k, v1, v2) { v1 + v2 }).get(3), None)
+  assert_eq!(
+    m2.intersection_with(m1, fn(_k, v1, v2) { v1 + v2 }).get(2),
+    Some(22),
+  )
+}
+
+///|
+test "HAMT::intersection_with with branch" {
+  let m1 = @hashmap.of([(1, 1), (2, 2), (3, 3)])
+  let m2 = @hashmap.of([(2, 20), (3, 30), (4, 4)])
+  let inter = m1.intersection_with(m2, fn(_k, v1, v2) { v1 * v2 })
+  assert_eq!(inter.get(1), None)
+  assert_eq!(inter.get(2), Some(40))
+  assert_eq!(inter.get(3), Some(90))
+  assert_eq!(inter.get(4), None)
+}
+
+///|
+test "HAMT::difference with empty" {
+  let m1 = @hashmap.of([(1, 1)])
+  let m2 = @hashmap.new()
+  assert_eq!(m1.difference(m2), m1)
+  assert_eq!(m2.difference(m1).size(), 0)
+}
+
+///|
+test "HAMT::difference with leaf" {
+  let m1 = @hashmap.of([(1, 1), (2, 2)])
+  let m2 = @hashmap.singleton(2, 2)
+  let m3 = @hashmap.singleton(3, 3)
+  assert_eq!(m1.difference(m2).get(2), None)
+  assert_eq!(m1.difference(m3).get(1), Some(1))
+  assert_eq!(m2.difference(m1).size(), 0)
+}
+
+///|
+test "HAMT::difference with branch" {
+  let m1 = @hashmap.of([(1, 1), (2, 2), (3, 3)])
+  let m2 = @hashmap.of([(2, 2), (3, 30), (4, 4)])
+  let diff = m1.difference(m2)
+  assert_eq!(diff.get(1), Some(1))
+  assert_eq!(diff.get(2), None)
+  assert_eq!(diff.get(3), None)
+  assert_eq!(diff.get(4), None)
+}
+
+///|
+test "HAMT::difference with disjoint keys" {
+  // Disjoint key sets: every entry of `m1` must survive, including entries
+  // whose key falls under a branch slot that `m2` also occupies.
+  let mut m1 : @hashmap.T[Int, Int] = @hashmap.new()
+  let mut m2 : @hashmap.T[Int, Int] = @hashmap.new()
+  for i = 0; i < 100; i = i + 1 {
+    m1 = m1.add(i, i)
+    m2 = m2.add(i + 100, i)
+  }
+  let diff = m1.difference(m2)
+  assert_eq!(diff.size(), 100)
+  for i = 0; i < 100; i = i + 1 {
+    assert_eq!(diff.get(i), Some(i))
+  }
+}
+
 ///|
 test "HAMT::each" {
   let empty = @hashmap.new()
diff --git
a/immut/hashmap/hashmap.mbti b/immut/hashmap/hashmap.mbti
index 90af0df6d..3199f8965 100644
--- a/immut/hashmap/hashmap.mbti
+++ b/immut/hashmap/hashmap.mbti
@@ -9,6 +9,8 @@ fn add[K : Eq + Hash, V](T[K, V], K, V) -> T[K, V]
 
 fn contains[K : Eq + Hash, V](T[K, V], K) -> Bool
 
+fn difference[K : Eq + Hash, V](T[K, V], T[K, V]) -> T[K, V]
+
 fn each[K, V](T[K, V], (K, V) -> Unit) -> Unit
 
 fn elems[K, V](T[K, V]) -> Iter[V]
@@ -28,6 +30,10 @@ fn from_iter[K : Eq + Hash, V](Iter[(K, V)]) -> T[K, V]
 
 fn get[K : Eq + Hash, V](T[K, V], K) -> V?
 
+fn intersection[K : Eq + Hash, V](T[K, V], T[K, V]) -> T[K, V]
+
+fn intersection_with[K : Eq + Hash, V](T[K, V], T[K, V], (K, V, V) -> V) -> T[K, V]
+
 fn iter[K, V](T[K, V]) -> Iter[(K, V)]
 
 fn iter2[K, V](T[K, V]) -> Iter2[K, V]
@@ -55,11 +61,14 @@ fn to_array[K, V](T[K, V]) -> Array[(K, V)]
 
 fn union[K : Eq + Hash, V](T[K, V], T[K, V]) -> T[K, V]
 
+fn union_with[K : Eq + Hash, V](T[K, V], T[K, V], (K, V, V) -> V) -> T[K, V]
+
 // Types and methods
 type T[K, V]
 impl T {
   add[K : Eq + Hash, V](Self[K, V], K, V) -> Self[K, V]
   contains[K : Eq + Hash, V](Self[K, V], K) -> Bool
+  difference[K : Eq + Hash, V](Self[K, V], Self[K, V]) -> Self[K, V]
   each[K, V](Self[K, V], (K, V) -> Unit) -> Unit
   elems[K, V](Self[K, V]) -> Iter[V]
   filter[K : Eq + Hash, V](Self[K, V], (V) -> Bool) -> Self[K, V]
@@ -68,6 +77,8 @@ impl T {
   fold[K, V, A](Self[K, V], init~ : A, (A, V) -> A) -> A
   fold_with_key[K, V, A](Self[K, V], init~ : A, (A, K, V) -> A) -> A
   get[K : Eq + Hash, V](Self[K, V], K) -> V?
+  intersection[K : Eq + Hash, V](Self[K, V], Self[K, V]) -> Self[K, V]
+  intersection_with[K : Eq + Hash, V](Self[K, V], Self[K, V], (K, V, V) -> V) -> Self[K, V]
   iter[K, V](Self[K, V]) -> Iter[(K, V)]
   iter2[K, V](Self[K, V]) -> Iter2[K, V]
   keys[K, V](Self[K, V]) -> Iter[K]
@@ -79,6 +90,7 @@ impl T {
   size[K, V](Self[K, V]) -> Int
   to_array[K, V](Self[K, V]) -> Array[(K, V)]
   union[K : Eq + Hash, V](Self[K, V], Self[K, V]) -> Self[K, V]
+  union_with[K : Eq + Hash, V](Self[K, V], Self[K, V], (K, V, V) -> V) -> Self[K, V]
 }
 impl[K : Eq + Hash, V : Eq] Eq for T[K, V]
 impl[K : Hash, V : Hash] Hash for T[K, V]
diff --git a/immut/hashset/HAMT.mbt b/immut/hashset/HAMT.mbt
index 1f68d73fa..dbeb95a49 100644
--- a/immut/hashset/HAMT.mbt
+++ b/immut/hashset/HAMT.mbt
@@ -192,6 +192,63 @@ pub fn union[K : Eq + Hash](self : T[K], other : T[K]) -> T[K] {
   }
 }
 
+///|
+/// Intersect two hashsets: the elements contained in both `self` and
+/// `other`.
+pub fn intersection[K : Eq + Hash](self : T[K], other : T[K]) -> T[K] {
+  match (self, other) {
+    (_, Empty) => Empty
+    (Empty, _) => Empty
+    (Leaf(k), _) => if other.contains(k) { Leaf(k) } else { Empty }
+    (_, Leaf(k)) => if self.contains(k) { Leaf(k) } else { Empty }
+    (Branch(sa1), Branch(sa2)) => {
+      let res = sa1.intersection(sa2, fn(m1, m2) { m1.intersection(m2) })
+      // No slot in common: return the canonical `Empty` rather than a
+      // `Branch` without children.
+      if res.size() == 0 {
+        Empty
+      } else {
+        Branch(res)
+      }
+    }
+    (_, _) =>
+      self
+      .iter()
+      .fold(init=Empty, fn(m, k) {
+        if other.contains(k) {
+          m.add(k)
+        } else {
+          m
+        }
+      })
+  }
+}
+
+///|
+/// Difference of two hashsets: the elements of `self` that are not in
+/// `other`.
+pub fn difference[K : Eq + Hash](self : T[K], other : T[K]) -> T[K] {
+  match (self, other) {
+    (Empty, _) => Empty
+    (_, Empty) => self
+    (Leaf(k), _) => if other.contains(k) { Empty } else { Leaf(k) }
+    // Two branches are compared element by element as well. A slot-level
+    // `SparseArray::difference` would drop every slot occupied on both sides,
+    // losing elements of `self` that merely share a slot with different
+    // elements of `other`.
+    (_, _) =>
+      self
+      .iter()
+      .fold(init=Empty, fn(m, k) {
+        if other.contains(k) {
+          m
+        } else {
+          m.add(k)
+        }
+      })
+  }
+}
+
 ///|
 /// Returns true if the hash set is empty.
 pub fn is_empty[A](self : T[A]) -> Bool {
diff --git a/immut/hashset/HAMT_test.mbt b/immut/hashset/HAMT_test.mbt
index 5e7f3ab38..7e5bc5ede 100644
--- a/immut/hashset/HAMT_test.mbt
+++ b/immut/hashset/HAMT_test.mbt
@@ -217,3 +217,68 @@ test "union 2 hashsets" {
   let set4 = @hashset.of([1, 2, 3, 4, 5])
   inspect!(set3 == set4, content="true")
 }
+
+///|
+test "@hashset.intersection with overlapping sets" {
+  let set1 = @hashset.of([1, 2, 3, 4])
+  let set2 = @hashset.of([3, 4, 5, 6])
+  let result = set1.intersection(set2)
+  inspect!(result, content="@immut/hashset.of([3, 4])")
+}
+
+///|
+test "@hashset.intersection with disjoint sets" {
+  let set1 = @hashset.of([1, 2])
+  let set2 = @hashset.of([3, 4])
+  let result = set1.intersection(set2)
+  inspect!(result.is_empty(), content="true")
+}
+
+///|
+test "@hashset.intersection with one empty set" {
+  let set1 = @hashset.of([1, 2, 3])
+  let set2 = @hashset.new()
+  let result = set1.intersection(set2)
+  inspect!(result.is_empty(), content="true")
+}
+
+///|
+test "@hashset.intersection with identical sets" {
+  let set1 = @hashset.of([1, 2, 3])
+  let set2 = @hashset.of([1, 2, 3])
+  let result = set1.intersection(set2)
+  inspect!(result == set1, content="true")
+}
+
+///|
+test "@hashset.intersection with subset" {
+  let set1 = @hashset.of([1, 2, 3, 4])
+  let set2 = @hashset.of([2, 3])
+  let result = set1.intersection(set2)
+  inspect!(result == @hashset.of([2, 3]), content="true")
+}
+
+///|
+test "@hashset.difference with overlapping sets" {
+  let set1 = @hashset.of([1, 2, 3, 4])
+  let set2 = @hashset.of([3, 4, 5, 6])
+  let result = set1.difference(set2)
+  inspect!(result == @hashset.of([1, 2]), content="true")
+}
+
+///|
+test "@hashset.difference with large disjoint sets" {
+  // Disjoint sets: every element of `set1` must survive, including elements
+  // that fall under a branch slot `set2` also occupies.
+  let mut set1 : @hashset.T[Int] = @hashset.new()
+  let mut set2 : @hashset.T[Int] = @hashset.new()
+  for i = 0; i < 100; i = i + 1 {
+    set1 = set1.add(i)
+    set2 = set2.add(i + 100)
+  }
+  let result = set1.difference(set2)
+  for i = 0; i < 100; i = i + 1 {
+    inspect!(result.contains(i), content="true")
+    inspect!(result.contains(i + 100), content="false")
+  }
+}
diff --git a/immut/hashset/hashset.mbti b/immut/hashset/hashset.mbti
index 84204d6cd..466152207 100644
--- a/immut/hashset/hashset.mbti
+++ b/immut/hashset/hashset.mbti
@@ -9,12 +9,16 @@ fn add[A : Eq + Hash](T[A], A) -> T[A]
 
 fn contains[A : Eq + Hash](T[A], A) -> Bool
 
+fn difference[K : Eq + Hash](T[K], T[K]) -> T[K]
+
 fn each[A](T[A], (A) -> Unit) -> Unit
 
 fn from_array[A : Eq + Hash](Array[A]) -> T[A]
 
 fn from_iter[A : Eq + Hash](Iter[A]) -> T[A]
 
+fn intersection[K : Eq + Hash](T[K], T[K]) -> T[K]
+
 fn is_empty[A](T[A]) -> Bool
 
 fn iter[A](T[A]) -> Iter[A]
@@ -34,7 +38,9 @@ type T[A]
 impl T {
   add[A : Eq + Hash](Self[A], A) -> Self[A]
   contains[A : Eq + Hash](Self[A], A) -> Bool
+  difference[K : Eq + Hash](Self[K], Self[K]) -> Self[K]
   each[A](Self[A], (A) -> Unit) -> Unit
+  intersection[K : Eq + Hash](Self[K], Self[K]) -> Self[K]
   is_empty[A](Self[A]) -> Bool
   iter[A](Self[A]) -> Iter[A]
   remove[A : Eq + Hash](Self[A], A) -> Self[A]
diff --git a/immut/internal/sparse_array/sparse_array.mbt b/immut/internal/sparse_array/sparse_array.mbt
index 107650c28..3ca044909 100644
--- a/immut/internal/sparse_array/sparse_array.mbt
+++ b/immut/internal/sparse_array/sparse_array.mbt
@@ -102,6 +102,57 @@ pub fn union[X](
   }
 }
 
+///|
+/// `intersection(self: SparseArray[X], other: SparseArray[X], f: (X, X) -> X) -> SparseArray[X]`
+///
+/// Only keep indices present in both sparse arrays, and merge values with f.
+pub fn intersection[X](
+  self : SparseArray[X],
+  other : SparseArray[X],
+  f : (X, X) -> X
+) -> SparseArray[X] {
+  let inter = self.elem_info.intersection(other.elem_info)
+  let new_len = inter.size()
+  if new_len == 0 {
+    empty()
+  } else {
+    let init = self.data[0] // Both sides have elements, pick either for init
+    let new_data = FixedArray::make(new_len, init)
+    for i in inter.iter() {
+      new_data[inter.index_of(i)] = f(
+        self.data[self.elem_info.index_of(i)],
+        other.data[other.elem_info.index_of(i)],
+      )
+    }
+    { elem_info: inter, data: new_data }
+  }
+}
+
+///|
+/// `difference(self: SparseArray[X], other: SparseArray[X]) -> SparseArray[X]`
+///
+/// Keep indices and values only present in self but not in other.
+///
+/// Slots are compared by index alone: the values stored in `other` are
+/// never inspected, so a slot occupied on both sides is dropped entirely.
+pub fn difference[X](
+  self : SparseArray[X],
+  other : SparseArray[X]
+) -> SparseArray[X] {
+  let diff = self.elem_info.difference(other.elem_info)
+  let new_len = diff.size()
+  if new_len == 0 {
+    empty()
+  } else {
+    let init = self.data[0]
+    let new_data = FixedArray::make(new_len, init)
+    for i in diff.iter() {
+      new_data[diff.index_of(i)] = self.data[self.elem_info.index_of(i)]
+    }
+    { elem_info: diff, data: new_data }
+  }
+}
+
 ///|
 /// `replace(self: SparseArray[X], idx: Int, value: X)`
 ///
diff --git a/immut/internal/sparse_array/sparse_array.mbti b/immut/internal/sparse_array/sparse_array.mbti
index 9cfa9a248..de05586ca 100644
--- a/immut/internal/sparse_array/sparse_array.mbti
+++ b/immut/internal/sparse_array/sparse_array.mbti
@@ -3,12 +3,16 @@ package "moonbitlang/core/immut/internal/sparse_array"
 // Values
 fn add[X](SparseArray[X], Int, X) -> SparseArray[X]
 
+fn difference[X](SparseArray[X], SparseArray[X]) -> SparseArray[X]
+
 fn each[X](SparseArray[X], (X) -> Unit) -> Unit
 
 fn empty[X]() -> SparseArray[X]
 
 fn has[X](SparseArray[X], Int) -> Bool
 
+fn intersection[X](SparseArray[X], SparseArray[X], (X, X) -> X) -> SparseArray[X]
+
 fn op_get[X](SparseArray[X], Int) -> X?
 
 fn replace[X](SparseArray[X], Int, X) -> SparseArray[X]
@@ -40,8 +44,10 @@ pub(all) struct SparseArray[X] {
 }
 impl SparseArray {
   add[X](Self[X], Int, X) -> Self[X]
+  difference[X](Self[X], Self[X]) -> Self[X]
   each[X](Self[X], (X) -> Unit) -> Unit
   has[X](Self[X], Int) -> Bool
+  intersection[X](Self[X], Self[X], (X, X) -> X) -> Self[X]
   op_get[X](Self[X], Int) -> X?
   replace[X](Self[X], Int, X) -> Self[X]
   size[X](Self[X]) -> Int