diff --git a/crates/query/src/scan/array_predicate.rs b/crates/query/src/scan/array_predicate.rs index f9097b5..ca4cd50 100644 --- a/crates/query/src/scan/array_predicate.rs +++ b/crates/query/src/scan/array_predicate.rs @@ -4,7 +4,6 @@ use arrow::array::{Array, ArrayRef, AsArray, BooleanArray, Datum, PrimitiveArray use arrow::buffer::{BooleanBuffer, Buffer}; use arrow::compute::{cast_with_options, CastOptions}; use arrow::datatypes::{ArrowNativeType, ArrowNativeTypeOp, ArrowPrimitiveType, DataType, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type}; -use std::collections::HashSet; use std::hash::Hash; use std::ops::BitAnd; use std::sync::Arc; @@ -532,14 +531,11 @@ impl ArrayPredicate for BloomFilter { pub struct PrimitiveListContainsAny { - values: HashSet, + values: Vec, } -impl PrimitiveListContainsAny -where - T::Native: std::cmp::Eq + Hash, -{ +impl PrimitiveListContainsAny { pub fn new(values: &[T::Native]) -> Self { Self { values: values.iter().copied().collect(), @@ -548,10 +544,7 @@ where } -impl ArrayPredicate for PrimitiveListContainsAny -where - T::Native: std::cmp::Eq + Hash, -{ +impl ArrayPredicate for PrimitiveListContainsAny { fn evaluate(&self, arr: &dyn Array) -> anyhow::Result { let list_array = arr.as_list_opt::().ok_or_else(|| { anyhow!("expected List array, but got {}", arr.data_type()) @@ -585,7 +578,7 @@ where pub struct StringListContainsAny { - values: HashSet, + values: Vec, } @@ -619,7 +612,7 @@ impl ArrayPredicate for StringListContainsAny { let start = offsets[i].as_usize(); let end = offsets[i + 1].as_usize(); for j in start..end { - if self.values.contains(values_array.value(j)) { + if self.values.iter().any(|v| v.as_str() == values_array.value(j)) { return true; } } @@ -633,9 +626,10 @@ impl ArrayPredicate for StringListContainsAny { #[cfg(feature = "_bench")] mod bench { - use crate::scan::array_predicate::{ArrayPredicate, BloomFilter}; - use arrow::array::FixedSizeBinaryArray; + use crate::scan::array_predicate::{ArrayPredicate, BloomFilter, PrimitiveListContainsAny, StringListContainsAny}; + use arrow::array::{FixedSizeBinaryArray, ListArray, ListBuilder, StringBuilder}; use arrow::buffer::MutableBuffer; + use arrow::datatypes::UInt32Type; #[divan::bench] @@ -650,4 +644,48 @@ mod bench { pred.evaluate(&array).unwrap() }) } + + + #[divan::bench] + fn primitive_list_contains(bench: divan::Bencher) { + let pred: PrimitiveListContainsAny = PrimitiveListContainsAny::new(&vec![0]); + let array = ListArray::from_iter_primitive::(vec![ + Some(vec![Some(1), Some(2), Some(3)]), + Some(vec![Some(0), Some(1)]), + Some(vec![Some(100), Some(200), Some(300), Some(400), Some(500)]), + ]); + bench.bench(|| { + pred.evaluate(&array).unwrap() + }) + } + + + #[divan::bench] + fn string_list_contains(bench: divan::Bencher) { + let pred = StringListContainsAny::new(&vec!["0x00000000000000000000000000000000"]); + + let mut list_builder = ListBuilder::new(StringBuilder::new()); + + list_builder.values().append_value("0x11111111111111111111111111111111"); + list_builder.values().append_value("0x22222222222222222222222222222222"); + list_builder.values().append_value("0x33333333333333333333333333333333"); + list_builder.append(true); + + list_builder.values().append_value("0x00000000000000000000000000000000"); + list_builder.values().append_value("0x11111111111111111111111111111111"); + list_builder.append(true); + + list_builder.values().append_value("0x11111111111111111111111111111111"); + list_builder.values().append_value("0x22222222222222222222222222222222"); + list_builder.values().append_value("0x33333333333333333333333333333333"); + list_builder.values().append_value("0x44444444444444444444444444444444"); + list_builder.values().append_value("0x55555555555555555555555555555555"); + list_builder.append(true); + + let array = list_builder.finish(); + + bench.bench(|| { + pred.evaluate(&array).unwrap() + }) + } } \ No newline at end of file