Skip to content

Commit

Permalink
Preserve null dictionary values in interleave and concat kernels (#…
Browse files Browse the repository at this point in the history
…7144)

* fix(select): preserve null values in `merge_dictionary_values`

This function internally computes value masks describing which values
from input dictionaries should remain in the output. Values never
referenced by keys are considered redundant. Null values were
considered redundant, but they are now preserved as of this commit.

This change is necessary because keys can reference null values. Before
this commit, the entries of `MergedDictionaries::key_mappings`
corresponding to null values were left unset. This caused `concat` and
`interleave` to remap all elements referencing them to whatever value
at index 0, producing an erroneous result.

* test(select): add test case `concat::test_string_dictionary_array_nulls_in_values`

This test case passes dictionary arrays containing null values (but no
null keys) to `concat`.

* test(select): add test case `interleave::test_interleave_dictionary_nulls`

This test case passes two dictionary arrays each containing null values
or keys to `interleave`.

* refactor(select): add type alias for `Interner` bucket

Addresses `clippy::type-complexity`.
  • Loading branch information
kawadakk authored Feb 19, 2025
1 parent d0a2301 commit 46afdd8
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 12 deletions.
18 changes: 18 additions & 0 deletions arrow-select/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,24 @@ mod tests {
)
}

#[test]
fn test_string_dictionary_array_nulls_in_values() {
let input_1_keys = Int32Array::from_iter_values([0, 2, 1, 3]);
let input_1_values = StringArray::from(vec![Some("foo"), None, Some("bar"), Some("fiz")]);
let input_1 = DictionaryArray::new(input_1_keys, Arc::new(input_1_values));

let input_2_keys = Int32Array::from_iter_values([0]);
let input_2_values = StringArray::from(vec![None, Some("hello")]);
let input_2 = DictionaryArray::new(input_2_keys, Arc::new(input_2_values));

let expected = vec![Some("foo"), Some("bar"), None, Some("fiz"), None];

let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
let dictionary = concat.as_dictionary::<Int32Type>();
let actual = collect_string_dictionary(dictionary);
assert_eq!(actual, expected);
}

#[test]
fn test_string_dictionary_merge() {
let mut builder = StringDictionaryBuilder::<Int32Type>::new();
Expand Down
41 changes: 29 additions & 12 deletions arrow-select/src/dictionary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ use arrow_schema::{ArrowError, DataType};
/// Hash collisions will result in replacement
struct Interner<'a, V> {
state: RandomState,
buckets: Vec<Option<(&'a [u8], V)>>,
buckets: Vec<Option<InternerBucket<'a, V>>>,
shift: u32,
}

/// A single bucket in [`Interner`].
type InternerBucket<'a, V> = (Option<&'a [u8]>, V);

impl<'a, V> Interner<'a, V> {
/// Capacity controls the number of unique buckets allocated within the Interner
///
Expand All @@ -54,7 +57,11 @@ impl<'a, V> Interner<'a, V> {
}
}

fn intern<F: FnOnce() -> Result<V, E>, E>(&mut self, new: &'a [u8], f: F) -> Result<&V, E> {
fn intern<F: FnOnce() -> Result<V, E>, E>(
&mut self,
new: Option<&'a [u8]>,
f: F,
) -> Result<&V, E> {
let hash = self.state.hash_one(new);
let bucket_idx = hash >> self.shift;
Ok(match &mut self.buckets[bucket_idx as usize] {
Expand Down Expand Up @@ -151,15 +158,19 @@ pub fn merge_dictionary_values<K: ArrowDictionaryKeyType>(

for (idx, dictionary) in dictionaries.iter().enumerate() {
let mask = masks.and_then(|m| m.get(idx));
let key_mask = match (dictionary.logical_nulls(), mask) {
(Some(n), None) => Some(n.into_inner()),
(None, Some(n)) => Some(n.clone()),
(Some(n), Some(m)) => Some(n.inner() & m),
let key_mask_owned;
let key_mask = match (dictionary.nulls(), mask) {
(Some(n), None) => Some(n.inner()),
(None, Some(n)) => Some(n),
(Some(n), Some(m)) => {
key_mask_owned = n.inner() & m;
Some(&key_mask_owned)
}
(None, None) => None,
};
let keys = dictionary.keys().values();
let values = dictionary.values().as_ref();
let values_mask = compute_values_mask(keys, key_mask.as_ref(), values.len());
let values_mask = compute_values_mask(keys, key_mask, values.len());

let masked_values = get_masked_values(values, &values_mask);
num_values += masked_values.len();
Expand Down Expand Up @@ -223,7 +234,10 @@ fn compute_values_mask<K: ArrowNativeType>(
}

/// Return a Vec containing for each set index in `mask`, the index and byte value of that index
fn get_masked_values<'a>(array: &'a dyn Array, mask: &BooleanBuffer) -> Vec<(usize, &'a [u8])> {
fn get_masked_values<'a>(
array: &'a dyn Array,
mask: &BooleanBuffer,
) -> Vec<(usize, Option<&'a [u8]>)> {
match array.data_type() {
DataType::Utf8 => masked_bytes(array.as_string::<i32>(), mask),
DataType::LargeUtf8 => masked_bytes(array.as_string::<i64>(), mask),
Expand All @@ -239,10 +253,13 @@ fn get_masked_values<'a>(array: &'a dyn Array, mask: &BooleanBuffer) -> Vec<(usi
fn masked_bytes<'a, T: ByteArrayType>(
array: &'a GenericByteArray<T>,
mask: &BooleanBuffer,
) -> Vec<(usize, &'a [u8])> {
) -> Vec<(usize, Option<&'a [u8]>)> {
let mut out = Vec::with_capacity(mask.count_set_bits());
for idx in mask.set_indices() {
out.push((idx, array.value(idx).as_ref()))
out.push((
idx,
array.is_valid(idx).then_some(array.value(idx).as_ref()),
))
}
out
}
Expand Down Expand Up @@ -311,10 +328,10 @@ mod tests {
let b = DictionaryArray::new(Int32Array::new_null(10), Arc::new(StringArray::new_null(0)));

let merged = merge_dictionary_values(&[&a, &b], None).unwrap();
let expected = StringArray::from(vec!["bingo", "hello"]);
let expected = StringArray::from(vec![None, Some("bingo"), Some("hello")]);
assert_eq!(merged.values.as_ref(), &expected);
assert_eq!(merged.key_mappings.len(), 2);
assert_eq!(&merged.key_mappings[0], &[0, 0, 0, 1, 0]);
assert_eq!(&merged.key_mappings[0], &[0, 0, 1, 2, 0]);
assert_eq!(&merged.key_mappings[1], &[] as &[i32; 0]);
}

Expand Down
24 changes: 24 additions & 0 deletions arrow-select/src/interleave.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,30 @@ mod tests {
assert_eq!(&collected, &["c", "c", "c"]);
}

#[test]
fn test_interleave_dictionary_nulls() {
let input_1_keys = Int32Array::from_iter_values([0, 2, 1, 3]);
let input_1_values = StringArray::from(vec![Some("foo"), None, Some("bar"), Some("fiz")]);
let input_1 = DictionaryArray::new(input_1_keys, Arc::new(input_1_values));
let input_2: DictionaryArray<Int32Type> = vec![None].into_iter().collect();

let expected = vec![Some("fiz"), None, None, Some("foo")];

let values = interleave(
&[&input_1 as _, &input_2 as _],
&[(0, 3), (0, 2), (1, 0), (0, 0)],
)
.unwrap();
let dictionary = values.as_dictionary::<Int32Type>();
let actual: Vec<Option<&str>> = dictionary
.downcast_dict::<StringArray>()
.unwrap()
.into_iter()
.collect();

assert_eq!(actual, expected);
}

#[test]
fn test_lists() {
// [[1, 2], null, [3]]
Expand Down

0 comments on commit 46afdd8

Please sign in to comment.