1use aws_common::{
4 client::create_s3_client,
5 client::create_sqs_client,
6 config::{AwsConfig, S3SourceConfig, SqsConfig},
7};
8use aws_sdk_s3::Client as S3Client;
9use kinetic_buffers::BufferSender;
10use kinetic_common::config::SourceContext;
11use kinetic_core::encode::Decoder;
12use kinetic_core::healthcheck::Healthcheck;
13use kinetic_core::state::StateStore;
14use kinetic_core::{ArcEventMetadata, ComponentId, EventMetadata, ShutdownSignal};
15use metrics::{Label, counter};
16use std::sync::Arc;
17use tokio::time::{Duration, interval, sleep};
18use tracing::{debug, error, info, trace, warn};
19
20pub struct S3SourceTask {
21 config: S3SourceConfig,
22 component_id: String,
23 pipeline_id: String,
24 sender: BufferSender,
25 #[allow(dead_code)]
26 error_sender: BufferSender,
27 shutdown: ShutdownSignal,
28 decoder: Arc<dyn Decoder>,
29 labels: Arc<[Label]>,
30 state_store: Option<StateStore>,
31}
32
33use serde::Deserialize;
34
35#[derive(Debug, Deserialize)]
36struct S3EventNotification {
37 #[serde(rename = "Records")]
38 records: Vec<S3EventRecord>,
39}
40
41#[derive(Debug, Deserialize)]
42struct S3EventRecord {
43 s3: S3Entity,
44}
45
46#[derive(Debug, Deserialize)]
47struct S3Entity {
48 bucket: S3Bucket,
49 object: S3Object,
50}
51
52#[derive(Debug, Deserialize)]
53struct S3Bucket {
54 name: String,
55}
56
57#[derive(Debug, Deserialize)]
58struct S3Object {
59 key: String,
60}
61
62impl S3SourceTask {
63 pub fn new(config: S3SourceConfig, cx: SourceContext, decoder: Arc<dyn Decoder>) -> Self {
64 let labels: Arc<[Label]> = Arc::new([
65 Label::new("component_id", cx.id.0.clone()),
66 Label::new("component_type", "source"),
67 Label::new("component_kind", "aws_s3"),
68 ]);
69
70 let state_store = cx.data_dir.map(StateStore::new);
71
72 Self {
73 config,
74 component_id: cx.id.0,
75 pipeline_id: cx.pipeline_id,
76 sender: cx.out,
77 error_sender: cx.error_out,
78 shutdown: cx.shutdown,
79 decoder,
80 labels,
81 state_store,
82 }
83 }
84
85 pub async fn run(self) {
86 match &self.config {
87 S3SourceConfig::List {
88 auth,
89 bucket,
90 prefix,
91 interval_secs,
92 delete_after_read,
93 } => {
94 let aws = auth.as_ref().cloned().unwrap_or_default();
95 self.run_list_mode(&aws, bucket, prefix, *interval_secs, *delete_after_read)
96 .await;
97 }
98 S3SourceConfig::EventStream { sqs, bucket } => {
99 self.run_event_stream_mode(sqs, bucket).await;
100 }
101 }
102 }
103
104 async fn run_list_mode(
105 &self,
106 aws: &AwsConfig,
107 bucket: &str,
108 prefix: &Option<String>,
109 interval_secs: u64,
110 delete_after_read: bool,
111 ) {
112 info!(
113 "Starting S3 source '{}' in LIST mode for bucket: {}",
114 self.component_id, bucket
115 );
116
117 let s3_client = create_s3_client(aws).await;
118 let mut interval = interval(Duration::from_secs(interval_secs));
119 let metadata = Arc::new(EventMetadata::new(
120 self.pipeline_id.clone(),
121 ComponentId(self.component_id.clone()),
122 ));
123
124 let mut shutdown = self.shutdown.clone();
125 let mut processed_keys: std::collections::HashSet<String> = if delete_after_read {
126 std::collections::HashSet::new()
127 } else {
128 self.state_store
129 .as_ref()
130 .and_then(|s| s.load().ok().flatten())
131 .unwrap_or_default()
132 };
133
134 let mut continuation_token: Option<String> = None;
135
136 loop {
137 tokio::select! {
138 _ = shutdown.recv() => {
139 info!("S3 source '{}' received shutdown signal", self.component_id);
140 if let Some(s) = &self.state_store {
141 let _ = s.save(&processed_keys);
142 }
143 break;
144 }
145 _ = interval.tick() => {
146 loop {
147 debug!("S3 source '{}' listing bucket: {}", self.component_id, bucket);
148 let mut list_objects = s3_client.list_objects_v2().bucket(bucket);
149 if let Some(p) = prefix {
150 list_objects = list_objects.prefix(p);
151 }
152 if let Some(token) = &continuation_token {
153 list_objects = list_objects.continuation_token(token);
154 }
155
156 match list_objects.send().await {
157 Ok(resp) => {
158 let mut any_new = false;
159 for obj in resp.contents.unwrap_or_default() {
160 if let Some(key) = obj.key
161 && !processed_keys.contains(&key)
162 {
163 if let Err(e) = self
164 .process_object(&s3_client, bucket, &key, &metadata)
165 .await
166 {
167 error!("Failed to process S3 object {}: {}", key, e);
168 } else {
169 if delete_after_read {
170 if let Err(e) = s3_client.delete_object().bucket(bucket).key(&key).send().await {
171 error!("Failed to delete processed S3 object {}: {}", key, e);
172 }
173 } else {
174 processed_keys.insert(key);
175 }
176 any_new = true;
177 }
178 }
179 }
180
181 if any_new
182 && !delete_after_read
183 && let Some(s) = &self.state_store
184 {
185 let _ = s.save(&processed_keys);
186 }
187
188 continuation_token = resp.next_continuation_token;
189 if continuation_token.is_none() {
190 break;
191 }
192 }
193 Err(e) => {
194 error!("Failed to list S3 objects in bucket {}: {}", bucket, e);
195 break;
196 }
197 }
198 }
199 }
200 }
201 }
202 }
203
204 async fn run_event_stream_mode(&self, sqs_config: &SqsConfig, bucket: &str) {
205 info!(
206 "Starting S3 source '{}' in EVENT mode for bucket: {} using SQS: {}",
207 self.component_id, bucket, sqs_config.queue_url
208 );
209
210 let aws = sqs_config.auth.as_ref().cloned().unwrap_or_default();
211 let s3_client = create_s3_client(&aws).await;
212 let sqs_client = create_sqs_client(&aws).await;
213 let metadata = Arc::new(EventMetadata::new(
214 self.pipeline_id.clone(),
215 ComponentId(self.component_id.clone()),
216 ));
217
218 let mut shutdown = self.shutdown.clone();
219
220 loop {
221 tokio::select! {
222 _ = shutdown.recv() => {
223 info!("S3 source '{}' received shutdown signal", self.component_id);
224 break;
225 }
226 _ = async {
227 loop {
228 match sqs_client
229 .receive_message()
230 .queue_url(&sqs_config.queue_url)
231 .max_number_of_messages(sqs_config.max_number_of_messages)
232 .wait_time_seconds(20)
233 .send()
234 .await
235 {
236 Ok(resp) => {
237 for msg in resp.messages.unwrap_or_default() {
238 let mut all_success = true;
239 if let Some(body) = msg.body {
240 match serde_json::from_str::<S3EventNotification>(&body) {
242 Ok(notif) => {
243 for record in notif.records {
244 let event_bucket = &record.s3.bucket.name;
245 let key = &record.s3.object.key;
246
247 let decoded_key = match urlencoding::decode(key) {
249 Ok(k) => k.into_owned(),
250 Err(_) => key.clone(),
251 };
252
253 if event_bucket == bucket {
254 if let Err(e) = self.process_object(&s3_client, bucket, &decoded_key, &metadata).await {
255 error!("Failed to process object {} from event: {}", decoded_key, e);
256 all_success = false;
257 }
258 } else {
259 warn!("Received event for bucket {}, but configured for {}. Skipping.", event_bucket, bucket);
260 }
261 }
262 }
263 Err(e) => {
264 #[derive(Deserialize)]
266 struct SnsMessage {
267 #[serde(rename = "Message")]
268 message: String,
269 }
270
271 match serde_json::from_str::<SnsMessage>(&body) {
272 Ok(sns) => {
273 match serde_json::from_str::<S3EventNotification>(&sns.message) {
274 Ok(notif) => {
275 for record in notif.records {
276 let event_bucket = &record.s3.bucket.name;
277 let key = &record.s3.object.key;
278 let decoded_key = match urlencoding::decode(key) {
279 Ok(k) => k.into_owned(),
280 Err(_) => key.clone(),
281 };
282 if event_bucket == bucket
283 && let Err(e2) = self.process_object(&s3_client, bucket, &decoded_key, &metadata).await
284 {
285 error!("Failed to process object {} from SNS event: {}", decoded_key, e2);
286 all_success = false;
287 }
288 }
289 }
290 Err(e2) => {
291 error!("Failed to parse S3 event from SNS message: {}. Inner error: {}", e, e2);
292 all_success = false;
293 }
294 }
295 }
296 Err(_) => {
297 error!("Failed to parse SQS message body as S3 event or SNS message: {}", e);
298 all_success = false;
299 }
300 }
301 }
302 }
303 }
304
305 if all_success
306 && let Some(handle) = msg.receipt_handle
307 && let Err(e) = sqs_client
308 .delete_message()
309 .queue_url(&sqs_config.queue_url)
310 .receipt_handle(handle)
311 .send()
312 .await
313 {
314 warn!("Failed to delete SQS message: {}", e);
315 }
316 }
317 }
318 Err(e) => {
319 error!("SQS receive error in '{}': {}", self.component_id, e);
320 sleep(Duration::from_secs(5)).await;
321 }
322 }
323 }
324 } => {}
325 }
326 }
327 }
328
329 async fn process_object(
330 &self,
331 client: &S3Client,
332 bucket: &str,
333 key: &str,
334 metadata: &ArcEventMetadata,
335 ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
336 trace!(
337 "S3 source {} processing object: {}/{}",
338 self.component_id, bucket, key
339 );
340 let resp = client.get_object().bucket(bucket).key(key).send().await?;
341 let data = resp.body.collect().await?.into_bytes();
342 let data_len = data.len();
343
344 counter!("component_received_network_bytes_total", self.labels.iter())
345 .increment(data_len as u64);
346
347 let batch = match self.decoder.decode(&data, metadata.clone()) {
348 Ok(b) => b,
349 Err(e) => {
350 counter!("component_errors_total", self.labels.iter()).increment(1);
351 return Err(e.into());
352 }
353 };
354
355 let row_count = batch.num_rows();
356 let byte_size = batch.estimated_size();
357
358 counter!("component_received_events_total", self.labels.iter()).increment(row_count as u64);
359 counter!("component_received_event_bytes_total", self.labels.iter())
360 .increment(byte_size as u64);
361
362 if let Err(e) = self.sender.send(batch).await {
363 counter!("component_errors_total", self.labels.iter()).increment(1);
364 return Err(format!("Send error: {:?}", e).into());
365 }
366
367 counter!("component_sent_events_total", self.labels.iter()).increment(row_count as u64);
368
369 Ok(())
370 }
371}
372
373#[async_trait::async_trait]
374impl Healthcheck for S3SourceTask {
375 async fn check(&self) -> anyhow::Result<()> {
376 let (aws, bucket) = match &self.config {
377 S3SourceConfig::List { auth, bucket, .. } => (auth, bucket),
378 S3SourceConfig::EventStream { sqs, bucket } => (&sqs.auth, bucket),
379 };
380
381 let aws = aws.as_ref().cloned().unwrap_or_default();
382 let s3_client = create_s3_client(&aws).await;
383 s3_client
384 .list_objects_v2()
385 .bucket(bucket)
386 .max_keys(1)
387 .send()
388 .await
389 .map(|_| ())
390 .map_err(|e| {
391 anyhow::anyhow!(
392 "S3 healthcheck failed for bucket '{}' in component '{}': {}",
393 bucket,
394 self.component_id,
395 e
396 )
397 })
398 }
399}