Skip to main content

kinetic/sources/aws/
s3.rs

1//! AWS S3 Source.
2
3use aws_common::{
4    client::create_s3_client,
5    client::create_sqs_client,
6    config::{AwsConfig, S3SourceConfig, SqsConfig},
7};
8use aws_sdk_s3::Client as S3Client;
9use kinetic_buffers::BufferSender;
10use kinetic_common::config::SourceContext;
11use kinetic_core::encode::Decoder;
12use kinetic_core::healthcheck::Healthcheck;
13use kinetic_core::state::StateStore;
14use kinetic_core::{ArcEventMetadata, ComponentId, EventMetadata, ShutdownSignal};
15use metrics::{Label, counter};
16use std::sync::Arc;
17use tokio::time::{Duration, interval, sleep};
18use tracing::{debug, error, info, trace, warn};
19
20pub struct S3SourceTask {
21    config: S3SourceConfig,
22    component_id: String,
23    pipeline_id: String,
24    sender: BufferSender,
25    #[allow(dead_code)]
26    error_sender: BufferSender,
27    shutdown: ShutdownSignal,
28    decoder: Arc<dyn Decoder>,
29    labels: Arc<[Label]>,
30    state_store: Option<StateStore>,
31}
32
33use serde::Deserialize;
34
35#[derive(Debug, Deserialize)]
36struct S3EventNotification {
37    #[serde(rename = "Records")]
38    records: Vec<S3EventRecord>,
39}
40
41#[derive(Debug, Deserialize)]
42struct S3EventRecord {
43    s3: S3Entity,
44}
45
46#[derive(Debug, Deserialize)]
47struct S3Entity {
48    bucket: S3Bucket,
49    object: S3Object,
50}
51
52#[derive(Debug, Deserialize)]
53struct S3Bucket {
54    name: String,
55}
56
57#[derive(Debug, Deserialize)]
58struct S3Object {
59    key: String,
60}
61
62impl S3SourceTask {
63    pub fn new(config: S3SourceConfig, cx: SourceContext, decoder: Arc<dyn Decoder>) -> Self {
64        let labels: Arc<[Label]> = Arc::new([
65            Label::new("component_id", cx.id.0.clone()),
66            Label::new("component_type", "source"),
67            Label::new("component_kind", "aws_s3"),
68        ]);
69
70        let state_store = cx.data_dir.map(StateStore::new);
71
72        Self {
73            config,
74            component_id: cx.id.0,
75            pipeline_id: cx.pipeline_id,
76            sender: cx.out,
77            error_sender: cx.error_out,
78            shutdown: cx.shutdown,
79            decoder,
80            labels,
81            state_store,
82        }
83    }
84
85    pub async fn run(self) {
86        match &self.config {
87            S3SourceConfig::List {
88                auth,
89                bucket,
90                prefix,
91                interval_secs,
92                delete_after_read,
93            } => {
94                let aws = auth.as_ref().cloned().unwrap_or_default();
95                self.run_list_mode(&aws, bucket, prefix, *interval_secs, *delete_after_read)
96                    .await;
97            }
98            S3SourceConfig::EventStream { sqs, bucket } => {
99                self.run_event_stream_mode(sqs, bucket).await;
100            }
101        }
102    }
103
104    async fn run_list_mode(
105        &self,
106        aws: &AwsConfig,
107        bucket: &str,
108        prefix: &Option<String>,
109        interval_secs: u64,
110        delete_after_read: bool,
111    ) {
112        info!(
113            "Starting S3 source '{}' in LIST mode for bucket: {}",
114            self.component_id, bucket
115        );
116
117        let s3_client = create_s3_client(aws).await;
118        let mut interval = interval(Duration::from_secs(interval_secs));
119        let metadata = Arc::new(EventMetadata::new(
120            self.pipeline_id.clone(),
121            ComponentId(self.component_id.clone()),
122        ));
123
124        let mut shutdown = self.shutdown.clone();
125        let mut processed_keys: std::collections::HashSet<String> = if delete_after_read {
126            std::collections::HashSet::new()
127        } else {
128            self.state_store
129                .as_ref()
130                .and_then(|s| s.load().ok().flatten())
131                .unwrap_or_default()
132        };
133
134        let mut continuation_token: Option<String> = None;
135
136        loop {
137            tokio::select! {
138                _ = shutdown.recv() => {
139                    info!("S3 source '{}' received shutdown signal", self.component_id);
140                    if let Some(s) = &self.state_store {
141                        let _ = s.save(&processed_keys);
142                    }
143                    break;
144                }
145                _ = interval.tick() => {
146                    loop {
147                        debug!("S3 source '{}' listing bucket: {}", self.component_id, bucket);
148                        let mut list_objects = s3_client.list_objects_v2().bucket(bucket);
149                        if let Some(p) = prefix {
150                            list_objects = list_objects.prefix(p);
151                        }
152                        if let Some(token) = &continuation_token {
153                            list_objects = list_objects.continuation_token(token);
154                        }
155
156                        match list_objects.send().await {
157                            Ok(resp) => {
158                                let mut any_new = false;
159                                for obj in resp.contents.unwrap_or_default() {
160                                    if let Some(key) = obj.key
161                                        && !processed_keys.contains(&key)
162                                    {
163                                        if let Err(e) = self
164                                            .process_object(&s3_client, bucket, &key, &metadata)
165                                            .await
166                                        {
167                                            error!("Failed to process S3 object {}: {}", key, e);
168                                        } else {
169                                            if delete_after_read {
170                                                if let Err(e) = s3_client.delete_object().bucket(bucket).key(&key).send().await {
171                                                    error!("Failed to delete processed S3 object {}: {}", key, e);
172                                                }
173                                            } else {
174                                                processed_keys.insert(key);
175                                            }
176                                            any_new = true;
177                                        }
178                                    }
179                                }
180
181                                if any_new
182                                    && !delete_after_read
183                                    && let Some(s) = &self.state_store
184                                {
185                                    let _ = s.save(&processed_keys);
186                                }
187
188                                continuation_token = resp.next_continuation_token;
189                                if continuation_token.is_none() {
190                                    break;
191                                }
192                            }
193                            Err(e) => {
194                                error!("Failed to list S3 objects in bucket {}: {}", bucket, e);
195                                break;
196                            }
197                        }
198                    }
199                }
200            }
201        }
202    }
203
204    async fn run_event_stream_mode(&self, sqs_config: &SqsConfig, bucket: &str) {
205        info!(
206            "Starting S3 source '{}' in EVENT mode for bucket: {} using SQS: {}",
207            self.component_id, bucket, sqs_config.queue_url
208        );
209
210        let aws = sqs_config.auth.as_ref().cloned().unwrap_or_default();
211        let s3_client = create_s3_client(&aws).await;
212        let sqs_client = create_sqs_client(&aws).await;
213        let metadata = Arc::new(EventMetadata::new(
214            self.pipeline_id.clone(),
215            ComponentId(self.component_id.clone()),
216        ));
217
218        let mut shutdown = self.shutdown.clone();
219
220        loop {
221            tokio::select! {
222                _ = shutdown.recv() => {
223                    info!("S3 source '{}' received shutdown signal", self.component_id);
224                    break;
225                }
226                _ = async {
227                        loop {
228                            match sqs_client
229                                .receive_message()
230                                .queue_url(&sqs_config.queue_url)
231                                .max_number_of_messages(sqs_config.max_number_of_messages)
232                                .wait_time_seconds(20)
233                                .send()
234                                .await
235                            {
236                                Ok(resp) => {
237                                    for msg in resp.messages.unwrap_or_default() {
238                                        let mut all_success = true;
239                                        if let Some(body) = msg.body {
240                                            // Decode S3 event notification
241                                            match serde_json::from_str::<S3EventNotification>(&body) {
242                                                Ok(notif) => {
243                                                    for record in notif.records {
244                                                        let event_bucket = &record.s3.bucket.name;
245                                                        let key = &record.s3.object.key;
246
247                                                        // URL decode key if necessary (S3 events are URL encoded)
248                                                        let decoded_key = match urlencoding::decode(key) {
249                                                            Ok(k) => k.into_owned(),
250                                                            Err(_) => key.clone(),
251                                                        };
252
253                                                        if event_bucket == bucket {
254                                                            if let Err(e) = self.process_object(&s3_client, bucket, &decoded_key, &metadata).await {
255                                                                error!("Failed to process object {} from event: {}", decoded_key, e);
256                                                                all_success = false;
257                                                            }
258                                                        } else {
259                                                            warn!("Received event for bucket {}, but configured for {}. Skipping.", event_bucket, bucket);
260                                                        }
261                                                    }
262                                                }
263                                                Err(e) => {
264                                                    // Try parsing as SNS-wrapped SQS message if simple parse fails
265                                                    #[derive(Deserialize)]
266                                                    struct SnsMessage {
267                                                        #[serde(rename = "Message")]
268                                                        message: String,
269                                                    }
270
271                                                    match serde_json::from_str::<SnsMessage>(&body) {
272                                                        Ok(sns) => {
273                                                            match serde_json::from_str::<S3EventNotification>(&sns.message) {
274                                                                Ok(notif) => {
275                                                                    for record in notif.records {
276                                                                        let event_bucket = &record.s3.bucket.name;
277                                                                        let key = &record.s3.object.key;
278                                                                        let decoded_key = match urlencoding::decode(key) {
279                                                                            Ok(k) => k.into_owned(),
280                                                                            Err(_) => key.clone(),
281                                                                        };
282                                                                        if event_bucket == bucket
283                                                                            && let Err(e2) = self.process_object(&s3_client, bucket, &decoded_key, &metadata).await
284                                                                        {
285                                                                            error!("Failed to process object {} from SNS event: {}", decoded_key, e2);
286                                                                            all_success = false;
287                                                                        }
288                                                                    }
289                                                                }
290                                                                Err(e2) => {
291                                                                    error!("Failed to parse S3 event from SNS message: {}. Inner error: {}", e, e2);
292                                                                    all_success = false;
293                                                                }
294                                                            }
295                                                        }
296                                                        Err(_) => {
297                                                            error!("Failed to parse SQS message body as S3 event or SNS message: {}", e);
298                                                            all_success = false;
299                                                        }
300                                                    }
301                                                }
302                                            }
303                                        }
304
305                                        if all_success
306                                            && let Some(handle) = msg.receipt_handle
307                                            && let Err(e) = sqs_client
308                                                .delete_message()
309                                                .queue_url(&sqs_config.queue_url)
310                                                .receipt_handle(handle)
311                                                .send()
312                                                .await
313                                        {
314                                            warn!("Failed to delete SQS message: {}", e);
315                                        }
316                                    }
317                                }
318                                Err(e) => {
319                                    error!("SQS receive error in '{}': {}", self.component_id, e);
320                                    sleep(Duration::from_secs(5)).await;
321                                }
322                            }
323                        }
324                    } => {}
325            }
326        }
327    }
328
329    async fn process_object(
330        &self,
331        client: &S3Client,
332        bucket: &str,
333        key: &str,
334        metadata: &ArcEventMetadata,
335    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
336        trace!(
337            "S3 source {} processing object: {}/{}",
338            self.component_id, bucket, key
339        );
340        let resp = client.get_object().bucket(bucket).key(key).send().await?;
341        let data = resp.body.collect().await?.into_bytes();
342        let data_len = data.len();
343
344        counter!("component_received_network_bytes_total", self.labels.iter())
345            .increment(data_len as u64);
346
347        let batch = match self.decoder.decode(&data, metadata.clone()) {
348            Ok(b) => b,
349            Err(e) => {
350                counter!("component_errors_total", self.labels.iter()).increment(1);
351                return Err(e.into());
352            }
353        };
354
355        let row_count = batch.num_rows();
356        let byte_size = batch.estimated_size();
357
358        counter!("component_received_events_total", self.labels.iter()).increment(row_count as u64);
359        counter!("component_received_event_bytes_total", self.labels.iter())
360            .increment(byte_size as u64);
361
362        if let Err(e) = self.sender.send(batch).await {
363            counter!("component_errors_total", self.labels.iter()).increment(1);
364            return Err(format!("Send error: {:?}", e).into());
365        }
366
367        counter!("component_sent_events_total", self.labels.iter()).increment(row_count as u64);
368
369        Ok(())
370    }
371}
372
373#[async_trait::async_trait]
374impl Healthcheck for S3SourceTask {
375    async fn check(&self) -> anyhow::Result<()> {
376        let (aws, bucket) = match &self.config {
377            S3SourceConfig::List { auth, bucket, .. } => (auth, bucket),
378            S3SourceConfig::EventStream { sqs, bucket } => (&sqs.auth, bucket),
379        };
380
381        let aws = aws.as_ref().cloned().unwrap_or_default();
382        let s3_client = create_s3_client(&aws).await;
383        s3_client
384            .list_objects_v2()
385            .bucket(bucket)
386            .max_keys(1)
387            .send()
388            .await
389            .map(|_| ())
390            .map_err(|e| {
391                anyhow::anyhow!(
392                    "S3 healthcheck failed for bucket '{}' in component '{}': {}",
393                    bucket,
394                    self.component_id,
395                    e
396                )
397            })
398    }
399}