// kinetic/transforms/sql_map/task.rs

1use crate::transforms::Transform;
2use crate::transforms::util::TransformErrorHandler;
3use async_trait::async_trait;
4use duckdb_engine::{
5    DuckDbInstance,
6    generational::{create_table_from_batch_schema, insert_batch},
7};
8use kinetic_buffers::{BufferReceiver, BufferSender};
9use kinetic_config::model::ErrorPolicy;
10use kinetic_config::model::SqlMapConfig;
11use kinetic_core::{ArcEventMetadata, ComponentId, EventBatch, EventMetadata};
12use std::collections::HashMap;
13use tracing::{debug, error, info, trace};
14
15use metrics::{Label, counter};
16use std::sync::Arc;
17
18use kinetic_common::register;
19use kinetic_common::telemetry::EventDuration;
20use std::time::Instant;
21
22pub struct SqlMapTask {
23    receiver: BufferReceiver,
24    senders: HashMap<String, BufferSender>,
25    config: SqlMapConfig,
26    pipeline_id: String,
27    component_id: String,
28    labels: Arc<[Label]>,
29    error_handler: TransformErrorHandler,
30    event_duration: EventDuration,
31}
32
33#[async_trait]
34impl Transform for SqlMapTask {
35    async fn run(self: Box<Self>) {
36        self.run_task().await;
37    }
38}
39
40impl SqlMapTask {
41    pub fn new(
42        component_id: String,
43        pipeline_id: String,
44        receiver: BufferReceiver,
45        senders: HashMap<String, BufferSender>,
46        config: SqlMapConfig,
47        error_policy: ErrorPolicy,
48    ) -> Self {
49        let labels: Arc<[Label]> = Arc::new([
50            Label::new("component_id", component_id.clone()),
51            Label::new("component_type", "transform"),
52            Label::new("component_kind", "sql_map"),
53        ]);
54
55        let error_handler = TransformErrorHandler::new(component_id.clone(), error_policy);
56        let event_duration = register!(EventDuration::new(component_id.clone(), "transform"));
57
58        Self {
59            receiver,
60            senders,
61            config,
62            pipeline_id,
63            component_id,
64            labels,
65            error_handler,
66            event_duration,
67        }
68    }
69
70    pub async fn run_task(mut self) {
71        info!("Starting SqlMap transform task: {}", self.component_id);
72
73        // Security: Ensure external access is disabled in DuckDB
74        let duckdb_config = duckdb_engine::instance::Config {
75            enable_external_access: false,
76            ..duckdb_engine::instance::Config::default()
77        };
78
79        let mut instance = match DuckDbInstance::new(duckdb_config) {
80            Ok(i) => i,
81            Err(e) => {
82                error!(
83                    "Failed to initialize DuckDB instance for SqlMap transform {}: {}",
84                    self.component_id, e
85                );
86                return;
87            }
88        };
89
90        // Cache metadata to avoid redundant allocations in the inner loop
91        let metadata = ArcEventMetadata::new(EventMetadata::new(
92            self.pipeline_id.clone(),
93            ComponentId(self.component_id.clone()),
94        ));
95
96        while let Some(mut batch) = self.receiver.recv().await {
97            let start = Instant::now();
98            let received_rows = batch.num_rows();
99            counter!("component_received_events_total", self.labels.iter())
100                .increment(received_rows as u64);
101            counter!("component_received_event_bytes_total", self.labels.iter())
102                .increment(batch.estimated_size() as u64);
103
104            debug!(
105                "SqlMap {} received batch of {} rows",
106                self.component_id, received_rows
107            );
108
109            // Execute SQL logic against the incoming batch payload using DuckDB
110            if let Err(e) = self
111                .process_and_send(&mut instance, &mut batch, metadata.clone())
112                .await
113            {
114                counter!("component_errors_total", self.labels.iter()).increment(1);
115                error!(
116                    "SqlMap {} failed to process batch: {:?}",
117                    self.component_id, e
118                );
119                if !self
120                    .error_handler
121                    .handle_error(
122                        &self.senders,
123                        format!("Failed to project batch: {:?}", e),
124                        Some(&batch),
125                    )
126                    .await
127                {
128                    self.event_duration.emit(start.elapsed());
129                    break;
130                }
131            }
132            self.event_duration.emit(start.elapsed());
133        }
134
135        info!("SqlMap transform task {} shutting down", self.component_id);
136    }
137
138    async fn process_and_send(
139        &mut self,
140        instance: &mut DuckDbInstance,
141        batch: &mut EventBatch,
142        metadata: ArcEventMetadata,
143    ) -> anyhow::Result<()> {
144        let table_name = "payload";
145        let conn = instance.conn_mut();
146
147        // Truncate the table if it already exists, otherwise create it.
148        // Doing this per batch is very efficient for in-memory embedded DBs like DuckDB.
149        // SAFETY: table_name is hardcoded to "payload" above and is not derived from user input,
150        // preventing SQL injection through string formatting.
151        let _ = conn.execute(&format!("DROP TABLE IF EXISTS {}", table_name), []);
152        create_table_from_batch_schema(conn, table_name, &batch.payload)?;
153        insert_batch(conn, table_name, &batch.payload)?;
154
155        // Execute user query over the 'payload' table
156        let batches = instance.query_arrow(&self.config.query)?;
157
158        let sender = self.senders.get("default");
159
160        if let Some(sender) = sender {
161            let num_batches = batches.len();
162
163            // Cascade AckToken to all outgoing batches if available
164            let mut ack_tokens = if let Some(token) = batch.ack_token.take() {
165                token.split(num_batches)
166            } else {
167                Vec::new()
168            };
169
170            for new_batch in batches {
171                if new_batch.num_rows() == 0 {
172                    // If we have a split token for an empty batch, we must ack it now
173                    if !ack_tokens.is_empty() {
174                        ack_tokens.remove(0).ack();
175                    }
176                    continue;
177                }
178
179                let rows = new_batch.num_rows();
180                let mut event_batch = EventBatch::new(new_batch, metadata.clone())?;
181
182                // Assign the split token to this sub-batch
183                if !ack_tokens.is_empty() {
184                    event_batch.ack_token = Some(ack_tokens.remove(0));
185                }
186
187                let bytes = event_batch.estimated_size();
188
189                match sender.send(event_batch).await {
190                    Ok(_) => {
191                        counter!("component_sent_events_total", self.labels.iter())
192                            .increment(rows as u64);
193                        counter!("component_sent_event_bytes_total", self.labels.iter())
194                            .increment(bytes as u64);
195                    }
196                    Err(e) => {
197                        error!(
198                            "SqlMap {} failed to send projected batch to default output: {:?}",
199                            self.component_id, e
200                        );
201                        // The AckToken will be dropped with the event_batch,
202                        // eventually causing a nack/replay (intended fallback).
203                        return Err(anyhow::anyhow!("Failed to send to downstream: {:?}", e));
204                    }
205                }
206            }
207        } else {
208            // No default sender, events are dropped
209            let num_rows = batch.num_rows();
210            counter!("component_discarded_events_total", self.labels.iter())
211                .increment(num_rows as u64);
212
213            if let Some(token) = batch.ack_token.take() {
214                token.ack();
215            }
216        }
217
218        trace!("SqlMap cleaning up payload table");
219        // SAFETY: table_name is hardcoded to "payload" as defined in process_and_send.
220        let _ = instance.execute(&format!("DROP TABLE IF EXISTS {}", table_name));
221
222        Ok(())
223    }
224}