diff --git a/array/bitstring.mbt b/array/bitstring.mbt index 48305d310..ae04a91d2 100644 --- a/array/bitstring.mbt +++ b/array/bitstring.mbt @@ -40,13 +40,9 @@ pub fn ArrayView::unsafe_extract_bit( _len : Int, ) -> UInt { let byte_index = offset >> 3 - let bit_mask = (1 << (7 - (offset & 7))).to_byte() - // TODO: branchless for performance - if (bs.unsafe_get(byte_index) & bit_mask) != 0 { - 1 - } else { - 0 - } + let bit_shift = 7 - (offset & 7) + let byte_val = bs.unsafe_get(byte_index).to_uint() + (byte_val >> bit_shift) & 1U } ///| @@ -59,13 +55,10 @@ pub fn ArrayView::unsafe_extract_bit_signed( _len : Int, ) -> Int { let byte_index = offset >> 3 - let bit_mask = (1 << (7 - (offset & 7))).to_byte() - // TODO: branchless for performance - if (bs.unsafe_get(byte_index) & bit_mask) != 0 { - -1 - } else { - 0 - } + let bit_shift = 7 - (offset & 7) + let byte_val = bs.unsafe_get(byte_index).to_int() + // Extract bit and convert to signed: 0 -> 0, 1 -> -1 + (((byte_val >> bit_shift) & 1) * -1) | 0 } ///| diff --git a/builtin/array.mbt b/builtin/array.mbt index d09201eee..835731cdc 100644 --- a/builtin/array.mbt +++ b/builtin/array.mbt @@ -735,7 +735,6 @@ pub fn[T] Array::rev(self : Array[T]) -> Array[T] { /// assert_eq(v1, [3]) /// assert_eq(v2, [4, 5]) /// ``` -/// TODO: perf could be optimized pub fn[T] Array::split_at(self : Array[T], index : Int) -> (Array[T], Array[T]) { if index < 0 || index > self.length() { let len = self.length() @@ -743,18 +742,11 @@ pub fn[T] Array::split_at(self : Array[T], index : Int) -> (Array[T], Array[T]) "index out of bounds: the len is from 0 to \{len} but the index is \{index}", ) } + let len2 = self.length() - index let v1 = Array::make_uninit(index) - let v2 = Array::make_uninit(self.length() - index) + let v2 = Array::make_uninit(len2) UninitializedArray::unsafe_blit(v1.buffer(), 0, self.buffer(), 0, index) - if index != self.length() { - UninitializedArray::unsafe_blit( - v2.buffer(), - 0, - self.buffer(), - index, - self.length() - index, - ) - } + UninitializedArray::unsafe_blit(v2.buffer(), 0, self.buffer(), index, len2) (v1, v2) } diff --git a/bytes/bitstring.mbt b/bytes/bitstring.mbt index c13c86355..c299a7ccc 100644 --- a/bytes/bitstring.mbt +++ b/bytes/bitstring.mbt @@ -40,13 +40,9 @@ pub fn BytesView::unsafe_extract_bit( _len : Int, ) -> UInt { let byte_index = offset >> 3 - let bit_mask = (1 << (7 - (offset & 7))).to_byte() - // TODO: branchless for performance - if (bs.unsafe_get(byte_index) & bit_mask) != 0 { - 1 - } else { - 0 - } + let bit_shift = 7 - (offset & 7) + let byte_val = bs.unsafe_get(byte_index).to_uint() + (byte_val >> bit_shift) & 1U } ///| @@ -59,13 +55,10 @@ pub fn BytesView::unsafe_extract_bit_signed( _len : Int, ) -> Int { let byte_index = offset >> 3 - let bit_mask = (1 << (7 - (offset & 7))).to_byte() - // TODO: branchless for performance - if (bs.unsafe_get(byte_index) & bit_mask) != 0 { - -1 - } else { - 0 - } + let bit_shift = 7 - (offset & 7) + let byte_val = bs.unsafe_get(byte_index).to_int() + // Extract bit and convert to signed: 0 -> 0, 1 -> -1 + (((byte_val >> bit_shift) & 1) * -1) | 0 } ///| diff --git a/sorted_set/set.mbt b/sorted_set/set.mbt index 9f53548db..d9b2a1375 100644 --- a/sorted_set/set.mbt +++ b/sorted_set/set.mbt @@ -141,15 +141,17 @@ pub fn[V : Compare] SortedSet::union( self : SortedSet[V], src : SortedSet[V], ) -> SortedSet[V] { - fn aux(a : Node[V]?, b : Node[V]?) -> Node[V]? { + fn aux(a : Node[V]?, b : Node[V]?) -> (Node[V]?, Int) { match (a, b) { - (Some(_), None) => a - (None, Some(_)) => b + (Some(_), None) => (a, count_nodes(a)) + (None, Some(_)) => (b, count_nodes(b)) (Some({ value: va, left: la, right: ra, .. }), Some(_)) => { let (l, r) = split(b, va) - Some(join(aux(la, l), va, aux(ra, r))) + let (left_tree, left_size) = aux(la, l) + let (right_tree, right_size) = aux(ra, r) + (Some(join(left_tree, va, right_tree)), left_size + 1 + right_size) } - (None, None) => None + (None, None) => (None, 0) } } @@ -157,13 +159,8 @@ pub fn[V : Compare] SortedSet::union( (Some(_), Some(_)) => { let t1 = copy_tree(self.root) let t2 = copy_tree(src.root) - let t = aux(t1, t2) - let mut ct = 0 - let ret = { root: t, size: 0 } - // TODO: optimize this. Avoid counting the size of the set. - ret.each(_x => ct = ct + 1) - ret.size = ct - ret + let (t, size) = aux(t1, t2) + { root: t, size } } (Some(_), None) => { root: copy_tree(self.root), size: self.size } (None, Some(_)) => { root: copy_tree(src.root), size: src.size } @@ -171,6 +168,14 @@ pub fn[V : Compare] SortedSet::union( } } +///| +fn[V] count_nodes(node : Node[V]?) -> Int { + match node { + None => 0 + Some({ left, right, .. }) => 1 + count_nodes(left) + count_nodes(right) + } +} + ///| fn[V : Compare] split(root : Node[V]?, value : V) -> (Node[V]?, Node[V]?) { match root { @@ -289,10 +294,10 @@ pub fn[V : Compare] SortedSet::symmetric_difference( self : SortedSet[V], other : SortedSet[V], ) -> SortedSet[V] { - // TODO: Optimize this function to avoid creating two intermediate sets. - let set1 = self.difference(other) - let set2 = other.difference(self) - set1.union(set2) + let ret = new() + self.each(x => if !other.contains(x) { ret.add(x) }) + other.each(x => if !self.contains(x) { ret.add(x) }) + ret } ///|