Skip to main content

kinetic/transforms/aggregate/
task.rs

1//! Generic SQL aggregate transform task.
2
3use crate::transforms::Transform;
4use crate::transforms::aggregate::sql_rewrite::SqlRewrite;
5use crate::transforms::util::TransformErrorHandler;
6use async_trait::async_trait;
7use duckdb_engine::instance::Config as DuckDbConfig;
8use duckdb_engine::{DuckDbInstance, GenerationalSwap, Result, WindowTimer};
9use kinetic_buffers::{BufferReceiver, BufferSender};
10use kinetic_config::AggregateConfig;
11use kinetic_config::model::ErrorPolicy;
12use kinetic_core::{ArcEventMetadata, ComponentId, EventBatch, EventMetadata};
13use metrics::{Label, counter};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::time::Duration;
17use tracing::{debug, error, info, warn};
18
19use kinetic_common::register;
20use kinetic_common::telemetry::EventDuration;
21use std::time::Instant;
22
/// Windowed SQL aggregate transform task.
///
/// Bundles the input channel, the named output channels and the settings that
/// `run_task` and its helpers read while the task is running.
23pub struct AggregateTask {
    /// Input channel the run loop receives event batches from.
24    receiver: BufferReceiver,
    /// Output channels keyed by name; this file looks up `"raw"`,
    /// `"aggregated"` and `"default"`.
25    senders: HashMap<String, BufferSender>,
    /// Aggregate settings; the SQL text, window duration string, memory limit
    /// and temp directory are read from it.
26    config: AggregateConfig,
    /// Identifier used in log lines, metric labels and emitted event metadata.
27    component_id: String,
    /// Pipeline identifier stamped into the metadata of emitted batches.
28    pipeline_id: String,
    /// Window length parsed from `config.window_duration`; drives the window timer.
29    window_duration: Duration,
    /// Receives processing failures; its boolean answer decides whether the
    /// run loop keeps going.
30    error_handler: TransformErrorHandler,
    /// Labels attached to every counter this task increments.
31    labels: Arc<[Label]>,
    /// Telemetry handle that is given the elapsed time of each handled tick or batch.
32    event_duration: EventDuration,
33}
34
35#[async_trait]
36impl Transform for AggregateTask {
37    async fn run(self: Box<Self>) {
38        self.run_task().await;
39    }
40}
41
42impl AggregateTask {
43    pub fn new(
44        component_id: String,
45        pipeline_id: String,
46        receiver: BufferReceiver,
47        senders: HashMap<String, BufferSender>,
48        config: AggregateConfig,
49        error_policy: ErrorPolicy,
50    ) -> anyhow::Result<Self> {
51        let window_duration = parse_duration(&config.window_duration).unwrap_or_else(|e| {
52            warn!(
53                "Invalid window duration '{}', defaulting to 60s: {}",
54                config.window_duration, e
55            );
56            Duration::from_secs(60)
57        });
58
59        if !SqlRewrite::is_read_only(&config.sql) {
60            anyhow::bail!(
61                "Aggregate transform '{}' configured with non-SELECT SQL. Only read-only SELECT queries are allowed.",
62                component_id
63            );
64        }
65
66        let error_handler = TransformErrorHandler::new(component_id.clone(), error_policy);
67        let labels: Arc<[Label]> = Arc::new([
68            Label::new("component_id", component_id.clone()),
69            Label::new("component_type", "transform"),
70            Label::new("component_kind", "aggregate"),
71        ]);
72        let event_duration = register!(EventDuration::new(component_id.clone(), "transform"));
73
74        Ok(Self {
75            receiver,
76            senders,
77            config,
78            component_id,
79            pipeline_id,
80            window_duration,
81            error_handler,
82            labels,
83            event_duration,
84        })
85    }
86
87    pub async fn run_task(mut self) {
88        info!(
89            "Starting Aggregate transform task: {} with SQL: {}",
90            self.component_id, self.config.sql
91        );
92
93        let duckdb_config = DuckDbConfig {
94            memory_limit: self
95                .config
96                .memory_limit
97                .clone()
98                .unwrap_or_else(|| "8GB".to_string()),
99            temp_directory: self.config.temp_directory.clone(),
100            enable_external_access: false,
101        };
102
103        let mut instance = match DuckDbInstance::new(duckdb_config) {
104            Ok(inst) => inst,
105            Err(e) => {
106                error!(
107                    "Failed to create DuckDB instance for {}: {}",
108                    self.component_id, e
109                );
110                return;
111            }
112        };
113
114        let mut generational_swap = GenerationalSwap::new();
115        let mut window_timer = WindowTimer::from_duration(self.window_duration);
116
117        loop {
118            tokio::select! {
119                biased;
120
121                _ = window_timer.tick() => {
122                    let start = Instant::now();
123                    if let Err(e) = self.seal_and_aggregate(&mut instance, &mut generational_swap).await
124                        && !self.error_handler.handle_error(&self.senders, format!("seal_and_aggregate failed: {}", e), None).await
125                    {
126                        break;
127                    }
128                    self.event_duration.emit(start.elapsed());
129                }
130
131                maybe_batch = self.receiver.recv() => {
132                    match maybe_batch {
133                        Some(batch) => {
134                            let start = Instant::now();
135                            counter!("component_received_events_total", self.labels.iter()).increment(batch.num_rows() as u64);
136                            counter!("component_received_event_bytes_total", self.labels.iter()).increment(batch.estimated_size() as u64);
137
138                            // Fork mode: send raw data to 'raw' output if it exists
139                            if let Some(raw_sender) = self.senders.get("raw")
140                                && let Err(e) = raw_sender.send(batch.clone()).await
141                            {
142                                error!(
143                                    "Aggregate {} failed to send raw batch: {:?}",
144                                    self.component_id, e
145                                );
146                            }
147
148                            if let Err(e) = self.process_batch(&mut instance, &mut generational_swap, &batch).await
149                                && !self
150                                    .error_handler
151                                    .handle_error(&self.senders, format!("process_batch failed: {}", e), Some(&batch))
152                                    .await
153                            {
154                                break;
155                            }
156
157                            if let Some(token) = batch.ack_token {
158                                token.ack();
159                            }
160                            self.event_duration.emit(start.elapsed());
161                        }
162                        None => {
163                            info!("Aggregate transform task {} input channel closed", self.component_id);
164                            // Final aggregation before shutting down
165                            if let Err(e) = self.seal_and_aggregate(&mut instance, &mut generational_swap).await {
166                                let _ = self
167                                    .error_handler
168                                    .handle_error(
169                                        &self.senders,
170                                        format!("seal_and_aggregate on shutdown failed: {}", e),
171                                        None,
172                                    )
173                                    .await;
174                            }
175                            break;
176                        }
177                    }
178                }
179            }
180        }
181
182        info!(
183            "Aggregate transform task {} shutting down",
184            self.component_id
185        );
186    }
187
188    async fn process_batch(
189        &self,
190        instance: &mut DuckDbInstance,
191        swap: &mut GenerationalSwap,
192        batch: &EventBatch,
193    ) -> Result<()> {
194        debug!(
195            "Aggregate {} received batch of {} rows",
196            self.component_id,
197            batch.num_rows()
198        );
199
200        if !swap.is_initialized() {
201            swap.create_active_table_from_batch(instance.conn_mut(), &batch.payload)?;
202        }
203
204        swap.append_batch(instance.conn_mut(), &batch.payload)?;
205
206        Ok(())
207    }
208
209    async fn seal_and_aggregate(
210        &mut self,
211        instance: &mut DuckDbInstance,
212        swap: &mut GenerationalSwap,
213    ) -> Result<()> {
214        if !swap.has_data() {
215            debug!(
216                "Aggregate {} window closed but no data to aggregate",
217                self.component_id
218            );
219            return Ok(());
220        }
221
222        info!(
223            "Aggregate {} sealing window and running SQL",
224            self.component_id
225        );
226
227        let sealed_uuid = swap.seal(instance.conn_mut())?;
228        let sealed_table = format!("sealed_{}", sealed_uuid.simple());
229
230        let sql = SqlRewrite::rewrite(&self.config.sql, &sealed_table);
231
232        debug!("Aggregate {} executing SQL: {}", self.component_id, sql);
233
234        let batches = instance.query_arrow(&sql)?;
235
236        let sender = self
237            .senders
238            .get("aggregated")
239            .or_else(|| self.senders.get("default"));
240
241        if let Some(sender) = sender {
242            for batch in batches {
243                if batch.num_rows() == 0 {
244                    continue;
245                }
246
247                let metadata = EventMetadata::new(
248                    self.pipeline_id.clone(),
249                    ComponentId(self.component_id.clone()),
250                );
251
252                match EventBatch::new(batch, ArcEventMetadata::new(metadata)) {
253                    Ok(event_batch) => {
254                        let rows = event_batch.num_rows();
255                        let bytes = event_batch.estimated_size();
256                        if let Err(e) = sender.send(event_batch).await {
257                            counter!("component_errors_total", self.labels.iter()).increment(1);
258                            error!(
259                                "Aggregate {} failed to send aggregated batch: {:?}",
260                                self.component_id, e
261                            );
262                        } else {
263                            counter!("component_sent_events_total", self.labels.iter())
264                                .increment(rows as u64);
265                            counter!("component_sent_event_bytes_total", self.labels.iter())
266                                .increment(bytes as u64);
267                        }
268                    }
269                    Err(e) => {
270                        counter!("component_errors_total", self.labels.iter()).increment(1);
271                        error!("Failed to create EventBatch in aggregate: {}", e);
272                    }
273                }
274            }
275        }
276        duckdb_engine::generational::cleanup_sealed(instance.conn_mut(), &sealed_uuid)?;
277
278        info!(
279            "Aggregate {} completed window aggregation",
280            self.component_id
281        );
282
283        Ok(())
284    }
285}
286
287fn parse_duration(s: &str) -> std::result::Result<Duration, String> {
288    kinetic_common::parse_duration(s).map_err(|e| e.to_string())
289}
290
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Well-formed inputs parse to the expected durations.
    #[test]
    fn test_parse_duration() {
        let cases = [
            ("15s", Duration::from_secs(15)),
            ("5m", Duration::from_secs(300)),
            ("100ms", Duration::from_millis(100)),
        ];

        for (input, expected) in cases {
            assert_eq!(parse_duration(input).unwrap(), expected);
        }
    }

    /// Empty, unit-only and unknown-unit inputs are rejected.
    #[test]
    fn test_parse_duration_errors() {
        for input in ["", "s", "5x"] {
            assert!(parse_duration(input).is_err());
        }
    }
}