kinetic/sources/aws/
sqs.rs1use arrow_array::{RecordBatch, StringArray};
4use arrow_schema::{DataType, Field, Schema};
5use aws_common::{client::create_sqs_client, config::SqsConfig};
6use kinetic_buffers::BufferSender;
7use kinetic_core::{ComponentId, EventBatch, EventMetadata, ShutdownSignal};
8use metrics::{Label, counter};
9use std::sync::Arc;
10use tokio::time::{Duration, sleep};
11use tracing::{error, info, trace};
12
13pub struct SqsSourceTask {
14 config: SqsConfig,
15 component_id: String,
16 pipeline_id: String,
17 sender: BufferSender,
18 shutdown: ShutdownSignal,
19 labels: Arc<[Label]>,
20}
21
22impl SqsSourceTask {
23 pub fn new(
24 config: SqsConfig,
25 component_id: String,
26 pipeline_id: String,
27 sender: BufferSender,
28 shutdown: ShutdownSignal,
29 ) -> Self {
30 let labels: Arc<[Label]> = Arc::new([
31 Label::new("component_id", component_id.clone()),
32 Label::new("component_type", "source"),
33 Label::new("component_kind", "aws_sqs"),
34 ]);
35
36 Self {
37 config,
38 component_id,
39 pipeline_id,
40 sender,
41 shutdown,
42 labels,
43 }
44 }
45
46 pub async fn run(mut self) {
47 info!(
48 "Starting SQS source '{}' for queue: {}",
49 self.component_id, self.config.queue_url
50 );
51
52 let aws = self.config.auth.as_ref().cloned().unwrap_or_default();
53 let client = create_sqs_client(&aws).await;
54
55 let metadata = Arc::new(EventMetadata::new(
56 self.pipeline_id.clone(),
57 ComponentId(self.component_id.clone()),
58 ));
59
60 let schema = Arc::new(Schema::new(vec![
61 Field::new("message_id", DataType::Utf8, false),
62 Field::new("body", DataType::Utf8, true),
63 ]));
64
65 loop {
66 tokio::select! {
67 _ = self.shutdown.recv() => {
68 info!("SQS source '{}' shutting down", self.component_id);
69 break;
70 }
71 messages = client
72 .receive_message()
73 .queue_url(&self.config.queue_url)
74 .max_number_of_messages(self.config.max_number_of_messages)
75 .wait_time_seconds(i32::try_from(self.config.poll_secs).unwrap_or(20).min(20))
76 .send() => {
77 match messages {
78 Ok(response) => {
79 let messages = response.messages.unwrap_or_default();
80 if messages.is_empty() {
81 trace!("SQS {} no messages, polling again", self.component_id);
82 continue;
83 }
84
85 info!(
86 "SQS {} received {} messages",
87 self.component_id,
88 messages.len()
89 );
90
91 let mut message_ids = Vec::with_capacity(messages.len());
92 let mut bodies = Vec::with_capacity(messages.len());
93 let mut receipt_handles = Vec::with_capacity(messages.len());
94
95 for msg in messages {
96 let body = msg.body.unwrap_or_default();
97 counter!("component_received_network_bytes_total", self.labels.iter()).increment(body.len() as u64);
98 message_ids.push(msg.message_id.unwrap_or_default());
99 bodies.push(Some(body));
100 if let Some(handle) = msg.receipt_handle {
101 receipt_handles.push(handle);
102 }
103 }
104
105 let id_array = StringArray::from(message_ids);
107 let body_array = StringArray::from(bodies);
108
109 if let Ok(rb) = RecordBatch::try_new(
110 schema.clone(),
111 vec![Arc::new(id_array), Arc::new(body_array)],
112 ) {
113 match EventBatch::new_with_xid(rb, metadata.clone()) {
114 Ok(batch) => {
115 let row_count = batch.num_rows();
116 let byte_size = batch.estimated_size();
117 counter!("component_received_events_total", self.labels.iter()).increment(row_count as u64);
118 counter!("component_received_event_bytes_total", self.labels.iter()).increment(byte_size as u64);
119
120 if let Err(e) = self.sender.send(batch).await {
121 counter!("component_errors_total", self.labels.iter()).increment(1);
122 error!("Failed to send SQS messages downstream: {:?}", e);
123 } else {
124 counter!("component_sent_events_total", self.labels.iter()).increment(row_count as u64);
125 if self.config.delete_message {
127 for handle in receipt_handles {
128 if let Err(e) = client
129 .delete_message()
130 .queue_url(&self.config.queue_url)
131 .receipt_handle(handle)
132 .send()
133 .await
134 {
135 error!(
136 "Failed to delete SQS message after processing in '{}': {}",
137 self.component_id, e
138 );
139 }
140 }
141 }
142 }
143 }
144 Err(e) => {
145 counter!("component_errors_total", self.labels.iter()).increment(1);
146 error!("Failed to create EventBatch from SQS: {}", e);
147 }
148 }
149 }
150 }
151 Err(e) => {
152 error!(
153 "SQS receive error in component '{}': {}",
154 self.component_id, e
155 );
156 sleep(Duration::from_secs(self.config.poll_secs)).await;
157 }
158 }
159 }
160 }
161 }
162 }
163}