kinetic/transforms/aggregate/
task.rs1use crate::transforms::Transform;
4use crate::transforms::aggregate::sql_rewrite::SqlRewrite;
5use crate::transforms::util::TransformErrorHandler;
6use async_trait::async_trait;
7use duckdb_engine::instance::Config as DuckDbConfig;
8use duckdb_engine::{DuckDbInstance, GenerationalSwap, Result, WindowTimer};
9use kinetic_buffers::{BufferReceiver, BufferSender};
10use kinetic_config::AggregateConfig;
11use kinetic_config::model::ErrorPolicy;
12use kinetic_core::{ArcEventMetadata, ComponentId, EventBatch, EventMetadata};
13use metrics::{Label, counter};
14use std::collections::HashMap;
15use std::sync::Arc;
16use std::time::Duration;
17use tracing::{debug, error, info, warn};
18
19use kinetic_common::register;
20use kinetic_common::telemetry::EventDuration;
21use std::time::Instant;
22
23pub struct AggregateTask {
24 receiver: BufferReceiver,
25 senders: HashMap<String, BufferSender>,
26 config: AggregateConfig,
27 component_id: String,
28 pipeline_id: String,
29 window_duration: Duration,
30 error_handler: TransformErrorHandler,
31 labels: Arc<[Label]>,
32 event_duration: EventDuration,
33}
34
35#[async_trait]
36impl Transform for AggregateTask {
37 async fn run(self: Box<Self>) {
38 self.run_task().await;
39 }
40}
41
42impl AggregateTask {
43 pub fn new(
44 component_id: String,
45 pipeline_id: String,
46 receiver: BufferReceiver,
47 senders: HashMap<String, BufferSender>,
48 config: AggregateConfig,
49 error_policy: ErrorPolicy,
50 ) -> anyhow::Result<Self> {
51 let window_duration = parse_duration(&config.window_duration).unwrap_or_else(|e| {
52 warn!(
53 "Invalid window duration '{}', defaulting to 60s: {}",
54 config.window_duration, e
55 );
56 Duration::from_secs(60)
57 });
58
59 if !SqlRewrite::is_read_only(&config.sql) {
60 anyhow::bail!(
61 "Aggregate transform '{}' configured with non-SELECT SQL. Only read-only SELECT queries are allowed.",
62 component_id
63 );
64 }
65
66 let error_handler = TransformErrorHandler::new(component_id.clone(), error_policy);
67 let labels: Arc<[Label]> = Arc::new([
68 Label::new("component_id", component_id.clone()),
69 Label::new("component_type", "transform"),
70 Label::new("component_kind", "aggregate"),
71 ]);
72 let event_duration = register!(EventDuration::new(component_id.clone(), "transform"));
73
74 Ok(Self {
75 receiver,
76 senders,
77 config,
78 component_id,
79 pipeline_id,
80 window_duration,
81 error_handler,
82 labels,
83 event_duration,
84 })
85 }
86
87 pub async fn run_task(mut self) {
88 info!(
89 "Starting Aggregate transform task: {} with SQL: {}",
90 self.component_id, self.config.sql
91 );
92
93 let duckdb_config = DuckDbConfig {
94 memory_limit: self
95 .config
96 .memory_limit
97 .clone()
98 .unwrap_or_else(|| "8GB".to_string()),
99 temp_directory: self.config.temp_directory.clone(),
100 enable_external_access: false,
101 };
102
103 let mut instance = match DuckDbInstance::new(duckdb_config) {
104 Ok(inst) => inst,
105 Err(e) => {
106 error!(
107 "Failed to create DuckDB instance for {}: {}",
108 self.component_id, e
109 );
110 return;
111 }
112 };
113
114 let mut generational_swap = GenerationalSwap::new();
115 let mut window_timer = WindowTimer::from_duration(self.window_duration);
116
117 loop {
118 tokio::select! {
119 biased;
120
121 _ = window_timer.tick() => {
122 let start = Instant::now();
123 if let Err(e) = self.seal_and_aggregate(&mut instance, &mut generational_swap).await
124 && !self.error_handler.handle_error(&self.senders, format!("seal_and_aggregate failed: {}", e), None).await
125 {
126 break;
127 }
128 self.event_duration.emit(start.elapsed());
129 }
130
131 maybe_batch = self.receiver.recv() => {
132 match maybe_batch {
133 Some(batch) => {
134 let start = Instant::now();
135 counter!("component_received_events_total", self.labels.iter()).increment(batch.num_rows() as u64);
136 counter!("component_received_event_bytes_total", self.labels.iter()).increment(batch.estimated_size() as u64);
137
138 if let Some(raw_sender) = self.senders.get("raw")
140 && let Err(e) = raw_sender.send(batch.clone()).await
141 {
142 error!(
143 "Aggregate {} failed to send raw batch: {:?}",
144 self.component_id, e
145 );
146 }
147
148 if let Err(e) = self.process_batch(&mut instance, &mut generational_swap, &batch).await
149 && !self
150 .error_handler
151 .handle_error(&self.senders, format!("process_batch failed: {}", e), Some(&batch))
152 .await
153 {
154 break;
155 }
156
157 if let Some(token) = batch.ack_token {
158 token.ack();
159 }
160 self.event_duration.emit(start.elapsed());
161 }
162 None => {
163 info!("Aggregate transform task {} input channel closed", self.component_id);
164 if let Err(e) = self.seal_and_aggregate(&mut instance, &mut generational_swap).await {
166 let _ = self
167 .error_handler
168 .handle_error(
169 &self.senders,
170 format!("seal_and_aggregate on shutdown failed: {}", e),
171 None,
172 )
173 .await;
174 }
175 break;
176 }
177 }
178 }
179 }
180 }
181
182 info!(
183 "Aggregate transform task {} shutting down",
184 self.component_id
185 );
186 }
187
188 async fn process_batch(
189 &self,
190 instance: &mut DuckDbInstance,
191 swap: &mut GenerationalSwap,
192 batch: &EventBatch,
193 ) -> Result<()> {
194 debug!(
195 "Aggregate {} received batch of {} rows",
196 self.component_id,
197 batch.num_rows()
198 );
199
200 if !swap.is_initialized() {
201 swap.create_active_table_from_batch(instance.conn_mut(), &batch.payload)?;
202 }
203
204 swap.append_batch(instance.conn_mut(), &batch.payload)?;
205
206 Ok(())
207 }
208
209 async fn seal_and_aggregate(
210 &mut self,
211 instance: &mut DuckDbInstance,
212 swap: &mut GenerationalSwap,
213 ) -> Result<()> {
214 if !swap.has_data() {
215 debug!(
216 "Aggregate {} window closed but no data to aggregate",
217 self.component_id
218 );
219 return Ok(());
220 }
221
222 info!(
223 "Aggregate {} sealing window and running SQL",
224 self.component_id
225 );
226
227 let sealed_uuid = swap.seal(instance.conn_mut())?;
228 let sealed_table = format!("sealed_{}", sealed_uuid.simple());
229
230 let sql = SqlRewrite::rewrite(&self.config.sql, &sealed_table);
231
232 debug!("Aggregate {} executing SQL: {}", self.component_id, sql);
233
234 let batches = instance.query_arrow(&sql)?;
235
236 let sender = self
237 .senders
238 .get("aggregated")
239 .or_else(|| self.senders.get("default"));
240
241 if let Some(sender) = sender {
242 for batch in batches {
243 if batch.num_rows() == 0 {
244 continue;
245 }
246
247 let metadata = EventMetadata::new(
248 self.pipeline_id.clone(),
249 ComponentId(self.component_id.clone()),
250 );
251
252 match EventBatch::new(batch, ArcEventMetadata::new(metadata)) {
253 Ok(event_batch) => {
254 let rows = event_batch.num_rows();
255 let bytes = event_batch.estimated_size();
256 if let Err(e) = sender.send(event_batch).await {
257 counter!("component_errors_total", self.labels.iter()).increment(1);
258 error!(
259 "Aggregate {} failed to send aggregated batch: {:?}",
260 self.component_id, e
261 );
262 } else {
263 counter!("component_sent_events_total", self.labels.iter())
264 .increment(rows as u64);
265 counter!("component_sent_event_bytes_total", self.labels.iter())
266 .increment(bytes as u64);
267 }
268 }
269 Err(e) => {
270 counter!("component_errors_total", self.labels.iter()).increment(1);
271 error!("Failed to create EventBatch in aggregate: {}", e);
272 }
273 }
274 }
275 }
276 duckdb_engine::generational::cleanup_sealed(instance.conn_mut(), &sealed_uuid)?;
277
278 info!(
279 "Aggregate {} completed window aggregation",
280 self.component_id
281 );
282
283 Ok(())
284 }
285}
286
287fn parse_duration(s: &str) -> std::result::Result<Duration, String> {
288 kinetic_common::parse_duration(s).map_err(|e| e.to_string())
289}
290
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Well-formed inputs map to the expected `Duration`.
    #[test]
    fn test_parse_duration() {
        let cases = [
            ("15s", Duration::from_secs(15)),
            ("5m", Duration::from_secs(300)),
            ("100ms", Duration::from_millis(100)),
        ];

        for (input, expected) in cases {
            assert_eq!(parse_duration(input).unwrap(), expected);
        }
    }

    /// Empty, unit-only and unknown-unit inputs are rejected.
    #[test]
    fn test_parse_duration_errors() {
        for input in ["", "s", "5x"] {
            assert!(parse_duration(input).is_err());
        }
    }
}
309}