// kinetic/sources/aws/sqs.rs

//! AWS SQS Source.
2
3use arrow_array::{RecordBatch, StringArray};
4use arrow_schema::{DataType, Field, Schema};
5use aws_common::{client::create_sqs_client, config::SqsConfig};
6use kinetic_buffers::BufferSender;
7use kinetic_core::{ComponentId, EventBatch, EventMetadata, ShutdownSignal};
8use metrics::{Label, counter};
9use std::sync::Arc;
10use tokio::time::{Duration, sleep};
11use tracing::{error, info, trace};
12
13pub struct SqsSourceTask {
14    config: SqsConfig,
15    component_id: String,
16    pipeline_id: String,
17    sender: BufferSender,
18    shutdown: ShutdownSignal,
19    labels: Arc<[Label]>,
20}
21
22impl SqsSourceTask {
23    pub fn new(
24        config: SqsConfig,
25        component_id: String,
26        pipeline_id: String,
27        sender: BufferSender,
28        shutdown: ShutdownSignal,
29    ) -> Self {
30        let labels: Arc<[Label]> = Arc::new([
31            Label::new("component_id", component_id.clone()),
32            Label::new("component_type", "source"),
33            Label::new("component_kind", "aws_sqs"),
34        ]);
35
36        Self {
37            config,
38            component_id,
39            pipeline_id,
40            sender,
41            shutdown,
42            labels,
43        }
44    }
45
46    pub async fn run(mut self) {
47        info!(
48            "Starting SQS source '{}' for queue: {}",
49            self.component_id, self.config.queue_url
50        );
51
52        let aws = self.config.auth.as_ref().cloned().unwrap_or_default();
53        let client = create_sqs_client(&aws).await;
54
55        let metadata = Arc::new(EventMetadata::new(
56            self.pipeline_id.clone(),
57            ComponentId(self.component_id.clone()),
58        ));
59
60        let schema = Arc::new(Schema::new(vec![
61            Field::new("message_id", DataType::Utf8, false),
62            Field::new("body", DataType::Utf8, true),
63        ]));
64
65        loop {
66            tokio::select! {
67                _ = self.shutdown.recv() => {
68                    info!("SQS source '{}' shutting down", self.component_id);
69                    break;
70                }
71                messages = client
72                    .receive_message()
73                    .queue_url(&self.config.queue_url)
74                    .max_number_of_messages(self.config.max_number_of_messages)
75                    .wait_time_seconds(i32::try_from(self.config.poll_secs).unwrap_or(20).min(20))
76                    .send() => {
77                        match messages {
78                            Ok(response) => {
79                                let messages = response.messages.unwrap_or_default();
80                                if messages.is_empty() {
81                                    trace!("SQS {} no messages, polling again", self.component_id);
82                                    continue;
83                                }
84
85                                info!(
86                                    "SQS {} received {} messages",
87                                    self.component_id,
88                                    messages.len()
89                                );
90
91                                let mut message_ids = Vec::with_capacity(messages.len());
92                                let mut bodies = Vec::with_capacity(messages.len());
93                                let mut receipt_handles = Vec::with_capacity(messages.len());
94
95                                for msg in messages {
96                                    let body = msg.body.unwrap_or_default();
97                                    counter!("component_received_network_bytes_total", self.labels.iter()).increment(body.len() as u64);
98                                    message_ids.push(msg.message_id.unwrap_or_default());
99                                    bodies.push(Some(body));
100                                    if let Some(handle) = msg.receipt_handle {
101                                        receipt_handles.push(handle);
102                                    }
103                                }
104
105                                // For now, treat body as string
106                                let id_array = StringArray::from(message_ids);
107                                let body_array = StringArray::from(bodies);
108
109                                if let Ok(rb) = RecordBatch::try_new(
110                                    schema.clone(),
111                                    vec![Arc::new(id_array), Arc::new(body_array)],
112                                ) {
113                                    match EventBatch::new_with_xid(rb, metadata.clone()) {
114                                        Ok(batch) => {
115                                            let row_count = batch.num_rows();
116                                            let byte_size = batch.estimated_size();
117                                            counter!("component_received_events_total", self.labels.iter()).increment(row_count as u64);
118                                            counter!("component_received_event_bytes_total", self.labels.iter()).increment(byte_size as u64);
119
120                                            if let Err(e) = self.sender.send(batch).await {
121                                                counter!("component_errors_total", self.labels.iter()).increment(1);
122                                                error!("Failed to send SQS messages downstream: {:?}", e);
123                                            } else {
124                                                counter!("component_sent_events_total", self.labels.iter()).increment(row_count as u64);
125                                                // Delete messages after successful send
126                                                if self.config.delete_message {
127                                                    for handle in receipt_handles {
128                                                        if let Err(e) = client
129                                                            .delete_message()
130                                                            .queue_url(&self.config.queue_url)
131                                                            .receipt_handle(handle)
132                                                            .send()
133                                                            .await
134                                                        {
135                                                            error!(
136                                                                "Failed to delete SQS message after processing in '{}': {}",
137                                                                self.component_id, e
138                                                            );
139                                                        }
140                                                    }
141                                                }
142                                            }
143                                        }
144                                        Err(e) => {
145                                            counter!("component_errors_total", self.labels.iter()).increment(1);
146                                            error!("Failed to create EventBatch from SQS: {}", e);
147                                        }
148                                    }
149                                }
150                            }
151                            Err(e) => {
152                                error!(
153                                    "SQS receive error in component '{}': {}",
154                                    self.component_id, e
155                                );
156                                sleep(Duration::from_secs(self.config.poll_secs)).await;
157                            }
158                        }
159                    }
160            }
161        }
162    }
163}