Skip to main content

kinetic/sources/gcp/
gcs.rs

1//! Google Cloud Storage (GCS) Source.
2
3use gcloud_pubsub::client::Client as PubSubClient;
4use gcloud_pubsub::subscriber::ReceivedMessage;
5use gcloud_storage::client::Client as StorageClient;
6use gcloud_storage::http::objects::download::Range;
7use gcloud_storage::http::objects::get::GetObjectRequest;
8use gcloud_storage::http::objects::list::ListObjectsRequest;
9use gcp_common::{
10    client::{create_pubsub_client, create_storage_client},
11    config::{GcpConfig, GcsSourceConfig, PubSubConfig},
12};
13use kinetic_buffers::BufferSender;
14use kinetic_common::config::SourceContext;
15use kinetic_core::encode::Decoder;
16use kinetic_core::healthcheck::Healthcheck;
17use kinetic_core::state::StateStore;
18use kinetic_core::{ArcEventMetadata, ComponentId, EventMetadata, ShutdownSignal};
19use metrics::{Label, counter};
20use serde::Deserialize;
21use std::sync::Arc;
22use tokio::time::{Duration, interval, sleep};
23use tracing::{debug, error, info, trace, warn};
24
/// Source task that ingests objects from a Google Cloud Storage bucket, either by
/// periodically listing the bucket (LIST mode) or by consuming Pub/Sub object
/// notifications (EventStream mode).
pub struct GcsSourceTask {
    /// Mode-specific configuration (LIST or EventStream).
    config: GcsSourceConfig,
    /// Identifier of this component; used in log messages and event metadata.
    component_id: String,
    /// Identifier of the owning pipeline; used in event metadata.
    pipeline_id: String,
    /// Output channel that decoded batches are sent to.
    sender: BufferSender,
    // Populated from `cx.error_out` but not read by any method of this source yet;
    // the allow silences the resulting "field is never read" warning.
    #[allow(dead_code)]
    error_sender: BufferSender,
    /// Shutdown signal observed by both run loops.
    shutdown: ShutdownSignal,
    /// Decodes the raw bytes of a downloaded object into a batch.
    decoder: Arc<dyn Decoder>,
    /// Metric labels attached to every counter this source emits.
    labels: Arc<[Label]>,
    /// Optional persistence for LIST mode's set of processed object keys; `None` when the
    /// source context provided no data directory.
    state_store: Option<StateStore>,
}
37
38#[derive(Debug, Deserialize)]
39#[serde(rename_all = "camelCase")]
40#[allow(dead_code)]
41struct GcsNotification {
42    kind: String,
43    id: String,
44    self_link: String,
45    name: String,
46    bucket: String,
47    generation: String,
48    metageneration: String,
49    content_type: Option<String>,
50    time_created: String,
51    updated: String,
52    size: String,
53    md5_hash: String,
54    media_link: String,
55    #[serde(default)]
56    metadata: std::collections::HashMap<String, String>,
57}
58
59impl GcsSourceTask {
60    pub fn new(config: GcsSourceConfig, cx: SourceContext, decoder: Arc<dyn Decoder>) -> Self {
61        let labels: Arc<[Label]> = Arc::new([
62            Label::new("component_id", cx.id.0.clone()),
63            Label::new("component_type", "source"),
64            Label::new("component_kind", "gcp_cloud_storage"),
65        ]);
66
67        let state_store = cx.data_dir.map(StateStore::new);
68
69        Self {
70            config,
71            component_id: cx.id.0,
72            pipeline_id: cx.pipeline_id,
73            sender: cx.out,
74            error_sender: cx.error_out,
75            shutdown: cx.shutdown,
76            decoder,
77            labels,
78            state_store,
79        }
80    }
81
82    pub async fn run(self) {
83        match &self.config {
84            GcsSourceConfig::List {
85                auth,
86                bucket,
87                prefix,
88                interval_secs,
89                delete_after_read,
90            } => {
91                let gcp: GcpConfig = auth.as_ref().cloned().unwrap_or_default();
92                self.run_list_mode(&gcp, bucket, prefix, *interval_secs, *delete_after_read)
93                    .await;
94            }
95            GcsSourceConfig::EventStream {
96                pubsub,
97                bucket,
98                include_metadata_updates,
99            } => {
100                self.run_event_stream_mode(pubsub, bucket, *include_metadata_updates)
101                    .await;
102            }
103        }
104    }
105
106    async fn run_list_mode(
107        &self,
108        gcp: &GcpConfig,
109        bucket: &str,
110        prefix: &Option<String>,
111        interval_secs: u64,
112        delete_after_read: bool,
113    ) {
114        info!(
115            "Starting GCS source '{}' in LIST mode for bucket: {}",
116            self.component_id, bucket
117        );
118
119        let storage_client: StorageClient = match create_storage_client(gcp).await {
120            Ok(c) => c,
121            Err(e) => {
122                error!("Failed to create GCS client: {}", e);
123                return;
124            }
125        };
126
127        let mut interval = interval(Duration::from_secs(interval_secs));
128        let metadata = Arc::new(EventMetadata::new(
129            self.pipeline_id.clone(),
130            ComponentId(self.component_id.clone()),
131        ));
132
133        let mut shutdown = self.shutdown.clone();
134        let mut processed_objects: std::collections::HashSet<String> = self
135            .state_store
136            .as_ref()
137            .and_then(|s| s.load().ok().flatten())
138            .unwrap_or_default();
139
140        loop {
141            tokio::select! {
142                _ = shutdown.recv() => {
143                    info!("GCS source '{}' received shutdown signal", self.component_id);
144                    if let Some(s) = &self.state_store {
145                        let _ = s.save(&processed_objects);
146                    }
147                    break;
148                }
149                _ = interval.tick() => {
150                    debug!("GCS source '{}' listing bucket: {}", self.component_id, bucket);
151                    let req = ListObjectsRequest {
152                        bucket: bucket.to_string(),
153                        prefix: prefix.clone(),
154                        ..Default::default()
155                    };
156
157                    match storage_client.list_objects(&req).await {
158                        Ok(resp) => {
159                            let mut any_new = false;
160                            if let Some(items) = resp.items {
161                                for obj in items {
162                                    let key = obj.name.clone();
163                                    if !processed_objects.contains(&key) {
164                                        if let Err(e) = self
165                                            .process_object(&storage_client, bucket, &key, &metadata)
166                                            .await
167                                        {
168                                            error!("Failed to process GCS object {}: {}", key, e);
169                                        } else if delete_after_read {
170                                            debug!("GCS source '{}' deleting processed object: {}/{}", self.component_id, bucket, key);
171                                            let del_req = gcloud_storage::http::objects::delete::DeleteObjectRequest {
172                                                bucket: bucket.to_string(),
173                                                object: key.to_string(),
174                                                ..Default::default()
175                                            };
176                                            if let Err(e) = storage_client.delete_object(&del_req).await {
177                                                error!("Failed to delete GCS object {} after processing: {}", key, e);
178                                                // If deletion fails, we insert into processed_objects to avoid immediate re-processing
179                                                // in the next interval if it's still there.
180                                                processed_objects.insert(key);
181                                                any_new = true;
182                                            }
183                                        } else {
184                                            processed_objects.insert(key);
185                                            any_new = true;
186                                        }
187                                    }
188                                }
189                            }
190
191                            if any_new && let Some(s) = &self.state_store {
192                                let _ = s.save(&processed_objects);
193                            }
194                        }
195                        Err(e) => {
196                            error!("Failed to list GCS objects in bucket {}: {}", bucket, e);
197                        }
198                    }
199                }
200            }
201        }
202    }
203
204    async fn run_event_stream_mode(
205        &self,
206        pubsub_config: &PubSubConfig,
207        bucket: &str,
208        include_metadata_updates: bool,
209    ) {
210        let sub_name: String = match &pubsub_config.subscription {
211            Some(s) => s.clone(),
212            None => {
213                error!(
214                    "Subscription name is required for GCS source EventStream mode in '{}'",
215                    self.component_id
216                );
217                return;
218            }
219        };
220
221        info!(
222            "Starting GCS source '{}' in EVENT mode for bucket: {} using Pub/Sub sub: {}",
223            self.component_id, bucket, sub_name
224        );
225
226        let gcp: GcpConfig = pubsub_config.auth.as_ref().cloned().unwrap_or_default();
227        let storage_client: StorageClient = match create_storage_client(&gcp).await {
228            Ok(c) => c,
229            Err(e) => {
230                error!("Failed to create GCS client: {}", e);
231                return;
232            }
233        };
234        let pubsub_client: PubSubClient = match create_pubsub_client(&gcp).await {
235            Ok(c) => c,
236            Err(e) => {
237                error!("Failed to create Pub/Sub client: {}", e);
238                return;
239            }
240        };
241
242        let metadata = Arc::new(EventMetadata::new(
243            self.pipeline_id.clone(),
244            ComponentId(self.component_id.clone()),
245        ));
246
247        let subscription = pubsub_client.subscription(&sub_name);
248
249        let mut shutdown = self.shutdown.clone();
250
251        loop {
252            tokio::select! {
253                _ = shutdown.recv() => {
254                    info!("GCS source '{}' received shutdown signal", self.component_id);
255                    break;
256                }
257                messages = subscription.pull(10, None) => {
258                    match messages {
259                        Ok(msgs) => {
260                            let msgs: Vec<ReceivedMessage> = msgs;
261                            for msg in msgs {
262                                let mut ack = true;
263                                if let Some(event_type) = msg.message.attributes.get("eventType").map(|s| s.as_str()) {
264                                    let should_process = match event_type {
265                                        "OBJECT_FINALIZE" => true,
266                                        "OBJECT_METADATA_UPDATE" if include_metadata_updates => true,
267                                        _ => false,
268                                    };
269
270                                    if should_process {
271                                        if let Ok(notif) = serde_json::from_slice::<GcsNotification>(&msg.message.data) {
272                                            if notif.bucket == bucket {
273                                                if let Err(e) = self.process_object(&storage_client, bucket, &notif.name, &metadata).await {
274                                                    error!("Failed to process object {} from GCS notification: {}", notif.name, e);
275                                                    ack = false;
276                                                }
277                                            } else {
278                                                warn!("Received GCS notification for bucket {}, but configured for {}. Skipping.", notif.bucket, bucket);
279                                            }
280                                        } else {
281                                            error!("Failed to parse GCS notification payload");
282                                        }
283                                    }
284                                }
285
286                                if ack {
287                                    let _ = msg.ack().await;
288                                } else {
289                                    let _ = msg.nack().await;
290                                }
291                            }
292                        }
293                        Err(e) => {
294                            error!("Pub/Sub pull error in '{}': {}", self.component_id, e);
295                            sleep(Duration::from_secs(5)).await;
296                        }
297                    }
298                }
299            }
300        }
301    }
302
303    async fn process_object(
304        &self,
305        client: &StorageClient,
306        bucket: &str,
307        key: &str,
308        metadata: &ArcEventMetadata,
309    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
310        trace!(
311            "GCS source {} processing object: {}/{}",
312            self.component_id, bucket, key
313        );
314
315        let req = GetObjectRequest {
316            bucket: bucket.to_string(),
317            object: key.to_string(),
318            ..Default::default()
319        };
320
321        let data: Vec<u8> = client.download_object(&req, &Range::default()).await?;
322        let data_len = data.len();
323
324        counter!("component_received_network_bytes_total", self.labels.iter())
325            .increment(data_len as u64);
326
327        let batch = match self.decoder.decode(&data, metadata.clone()) {
328            Ok(b) => b,
329            Err(e) => {
330                counter!("component_errors_total", self.labels.iter()).increment(1);
331                return Err(e.into());
332            }
333        };
334
335        let row_count = batch.num_rows();
336        let byte_size = batch.estimated_size();
337
338        counter!("component_received_events_total", self.labels.iter()).increment(row_count as u64);
339        counter!("component_received_event_bytes_total", self.labels.iter())
340            .increment(byte_size as u64);
341
342        if let Err(e) = self.sender.send(batch).await {
343            counter!("component_errors_total", self.labels.iter()).increment(1);
344            return Err(format!("Send error: {:?}", e).into());
345        }
346
347        counter!("component_sent_events_total", self.labels.iter()).increment(row_count as u64);
348
349        Ok(())
350    }
351}
352
353#[async_trait::async_trait]
354impl Healthcheck for GcsSourceTask {
355    async fn check(&self) -> anyhow::Result<()> {
356        let (auth, bucket) = match &self.config {
357            GcsSourceConfig::List { auth, bucket, .. } => (auth, bucket),
358            GcsSourceConfig::EventStream { pubsub, bucket, .. } => (&pubsub.auth, bucket),
359        };
360
361        let gcp: GcpConfig = auth.as_ref().cloned().unwrap_or_default();
362        let client: StorageClient = create_storage_client(&gcp).await?;
363
364        let req = ListObjectsRequest {
365            bucket: bucket.to_string(),
366            max_results: Some(1),
367            ..Default::default()
368        };
369
370        client.list_objects(&req).await.map(|_| ()).map_err(|e| {
371            anyhow::anyhow!(
372                "GCS healthcheck failed for bucket '{}' in component '{}': {}",
373                bucket,
374                self.component_id,
375                e
376            )
377        })
378    }
379}