Skip to main content

kinetic/transforms/filter/
task.rs

1//! Filter transform task.
2
3use crate::transforms::Transform;
4use crate::transforms::util::TransformErrorHandler;
5use arrow_array::{BooleanArray, RecordBatch, StringArray};
6use arrow_select::filter::filter_record_batch;
7use async_trait::async_trait;
8use kinetic_buffers::{BufferReceiver, BufferSender};
9use kinetic_config::model::ErrorPolicy;
10use kinetic_config::{FilterCondition, FilterType};
11use kinetic_core::EventBatch;
12use std::collections::HashMap;
13use tracing::{debug, error, info};
14
15use metrics::{Label, counter};
16use std::sync::Arc;
17
18use kinetic_common::register;
19use kinetic_common::telemetry::EventDuration;
20use std::time::Instant;
21
22pub struct FilterTask {
23    receiver: BufferReceiver,
24    senders: HashMap<String, BufferSender>,
25    condition: FilterCondition,
26    component_id: String,
27    labels: Arc<[Label]>,
28    error_handler: TransformErrorHandler,
29    event_duration: EventDuration,
30}
31
32#[async_trait]
33impl Transform for FilterTask {
34    async fn run(self: Box<Self>) {
35        self.run_task().await;
36    }
37}
38
39impl FilterTask {
40    pub fn new(
41        component_id: String,
42        _pipeline_id: String,
43        receiver: BufferReceiver,
44        senders: HashMap<String, BufferSender>,
45        condition: FilterCondition,
46        error_policy: ErrorPolicy,
47    ) -> Self {
48        let labels: Arc<[Label]> = Arc::new([
49            Label::new("component_id", component_id.clone()),
50            Label::new("component_type", "transform"),
51            Label::new("component_kind", "filter"),
52        ]);
53
54        let error_handler = TransformErrorHandler::new(component_id.clone(), error_policy);
55        let event_duration = register!(EventDuration::new(component_id.clone(), "transform"));
56
57        Self {
58            receiver,
59            senders,
60            condition,
61            component_id,
62            labels,
63            error_handler,
64            event_duration,
65        }
66    }
67
68    pub async fn run_task(mut self) {
69        info!("Starting Filter transform task: {}", self.component_id);
70
71        while let Some(mut batch) = self.receiver.recv().await {
72            let start = Instant::now();
73            let received_rows = batch.num_rows();
74            counter!("component_received_events_total", self.labels.iter())
75                .increment(received_rows as u64);
76            counter!("component_received_event_bytes_total", self.labels.iter())
77                .increment(batch.estimated_size() as u64);
78
79            debug!(
80                "Filter {} received batch of {} rows",
81                self.component_id, received_rows
82            );
83
84            let filtered_batch = match self.apply_filter(&batch.payload) {
85                Ok(Some(b)) => b,
86                Ok(None) => {
87                    // All rows filtered out
88                    counter!("component_discarded_events_total", self.labels.iter())
89                        .increment(received_rows as u64);
90                    if let Some(token) = batch.ack_token.take() {
91                        token.ack();
92                    }
93                    self.event_duration.emit(start.elapsed());
94                    continue;
95                }
96                Err(e) => {
97                    counter!("component_errors_total", self.labels.iter()).increment(1);
98                    if !self
99                        .error_handler
100                        .handle_error(
101                            &self.senders,
102                            format!("Failed to apply filter to batch: {}", e),
103                            Some(&batch),
104                        )
105                        .await
106                    {
107                        break;
108                    }
109                    self.event_duration.emit(start.elapsed());
110                    continue;
111                }
112            };
113
114            let sent_rows = filtered_batch.num_rows();
115            if sent_rows == 0 {
116                counter!("component_discarded_events_total", self.labels.iter())
117                    .increment(received_rows as u64);
118                if let Some(token) = batch.ack_token.take() {
119                    token.ack();
120                }
121                self.event_duration.emit(start.elapsed());
122                continue;
123            }
124
125            if received_rows > sent_rows {
126                counter!("component_discarded_events_total", self.labels.iter())
127                    .increment((received_rows - sent_rows) as u64);
128            }
129
130            let mut new_batch = match EventBatch::new(filtered_batch, batch.metadata.clone()) {
131                Ok(b) => b,
132                Err(e) => {
133                    counter!("component_errors_total", self.labels.iter()).increment(1);
134                    if !self
135                        .error_handler
136                        .handle_error(
137                            &self.senders,
138                            format!("Failed to create EventBatch: {}", e),
139                            Some(&batch),
140                        )
141                        .await
142                    {
143                        break;
144                    }
145                    self.event_duration.emit(start.elapsed());
146                    continue;
147                }
148            };
149            new_batch.ack_token = batch.ack_token.take();
150
151            if let Some(sender) = self.senders.get("default") {
152                let bytes_sent = new_batch.estimated_size();
153                match sender.send(new_batch).await {
154                    Ok(_) => {
155                        counter!("component_sent_events_total", self.labels.iter())
156                            .increment(sent_rows as u64);
157                        counter!("component_sent_event_bytes_total", self.labels.iter())
158                            .increment(bytes_sent as u64);
159                    }
160                    Err(e) => {
161                        counter!("component_errors_total", self.labels.iter()).increment(1);
162                        error!(
163                            "Filter {} failed to send batch to default output: {:?}",
164                            self.component_id, e
165                        );
166                    }
167                }
168            } else {
169                // No default sender, events are effectively dropped
170                counter!("component_discarded_events_total", self.labels.iter())
171                    .increment(sent_rows as u64);
172                if let Some(token) = new_batch.ack_token.take() {
173                    token.ack();
174                }
175            }
176            self.event_duration.emit(start.elapsed());
177        }
178
179        info!("Filter transform task {} shutting down", self.component_id);
180    }
181
182    pub fn apply_filter(
183        &self,
184        batch: &RecordBatch,
185    ) -> std::result::Result<Option<RecordBatch>, String> {
186        let filter_array = self.evaluate_condition(batch)?;
187        let filtered_batch = filter_record_batch(batch, &filter_array)
188            .map_err(|e| format!("Failed to filter record batch: {}", e))?;
189
190        if filtered_batch.num_rows() > 0 {
191            Ok(Some(filtered_batch))
192        } else {
193            Ok(None)
194        }
195    }
196
197    pub fn evaluate_condition(
198        &self,
199        batch: &RecordBatch,
200    ) -> std::result::Result<BooleanArray, String> {
201        let column = batch
202            .column_by_name(&self.condition.source)
203            .ok_or_else(|| format!("Column '{}' not found", self.condition.source))?;
204
205        let array = column
206            .as_any()
207            .downcast_ref::<StringArray>()
208            .ok_or_else(|| "Filter condition only supports StringArray for now".to_string())?;
209
210        let pattern = &self.condition.pattern;
211        let filter = array
212            .iter()
213            .map(|val| match self.condition.kind {
214                FilterType::Pass => val.map(|s| s.contains(pattern)).unwrap_or(false),
215                FilterType::Drop => val.map(|s| !s.contains(pattern)).unwrap_or(true),
216            })
217            .collect::<BooleanArray>();
218
219        Ok(filter)
220    }
221}