Skip to main content

kinetic/sources/gcp/
pubsub.rs

1//! Google Cloud Pub/Sub Source.
2
3use arrow_array::{RecordBatch, StringArray};
4use arrow_schema::{DataType, Field, Schema};
5use gcloud_pubsub::client::Client as PubSubClient;
6use gcloud_pubsub::subscriber::ReceivedMessage;
7use gcp_common::{
8    client::create_pubsub_client,
9    config::{GcpConfig, PubSubSourceConfig},
10};
11use kinetic_buffers::BufferSender;
12use kinetic_core::healthcheck::Healthcheck;
13use kinetic_core::{ComponentId, EventBatch, EventMetadata, ShutdownSignal};
14use metrics::{Label, counter};
15use std::sync::Arc;
16use tokio::time::{Duration, sleep};
17use tracing::{error, info, trace};
18
19pub struct PubSubSourceTask {
20    config: PubSubSourceConfig,
21    component_id: String,
22    pipeline_id: String,
23    sender: BufferSender,
24    shutdown: ShutdownSignal,
25    labels: Arc<[Label]>,
26}
27
28impl PubSubSourceTask {
29    pub fn new(
30        config: PubSubSourceConfig,
31        component_id: String,
32        pipeline_id: String,
33        sender: BufferSender,
34        shutdown: ShutdownSignal,
35    ) -> Self {
36        let labels: Arc<[Label]> = Arc::new([
37            Label::new("component_id", component_id.clone()),
38            Label::new("component_type", "source"),
39            Label::new("component_kind", "gcp_pubsub"),
40        ]);
41
42        Self {
43            config,
44            component_id,
45            pipeline_id,
46            sender,
47            shutdown,
48            labels,
49        }
50    }
51
52    pub async fn run(self) {
53        let sub_name: String = match &self.config.config.subscription {
54            Some(s) => s.clone(),
55            None => {
56                error!(
57                    "Subscription name is required for Pub/Sub source in '{}'",
58                    self.component_id
59                );
60                return;
61            }
62        };
63
64        info!(
65            "Starting Pub/Sub source '{}' for subscription: {}",
66            self.component_id, sub_name
67        );
68
69        let gcp: GcpConfig = self
70            .config
71            .config
72            .auth
73            .as_ref()
74            .cloned()
75            .unwrap_or_default();
76        let client: PubSubClient = match create_pubsub_client(&gcp).await {
77            Ok(c) => c,
78            Err(e) => {
79                error!("Failed to create Pub/Sub client: {}", e);
80                return;
81            }
82        };
83
84        let metadata = Arc::new(EventMetadata::new(
85            self.pipeline_id.clone(),
86            ComponentId(self.component_id.clone()),
87        ));
88
89        let schema = Arc::new(Schema::new(vec![
90            Field::new("message_id", DataType::Utf8, false),
91            Field::new("body", DataType::Utf8, true),
92            Field::new("publish_time", DataType::Utf8, true),
93        ]));
94
95        let subscription = client.subscription(&sub_name);
96
97        let mut shutdown = self.shutdown.clone();
98
99        loop {
100            tokio::select! {
101                _ = shutdown.recv() => {
102                    info!("Pub/Sub source '{}' shutting down", self.component_id);
103                    break;
104                }
105                messages = subscription.pull(10, None) => {
106                    match messages {
107                        Ok(msgs) => {
108                            let msgs: Vec<ReceivedMessage> = msgs;
109                            if msgs.is_empty() {
110                                trace!("Pub/Sub {} no messages, polling again", self.component_id);
111                                continue;
112                            }
113
114                            info!(
115                                "Pub/Sub {} received {} messages",
116                                self.component_id,
117                                msgs.len()
118                            );
119
120                            let mut message_ids = Vec::with_capacity(msgs.len());
121                            let mut bodies = Vec::with_capacity(msgs.len());
122                            let mut publish_times = Vec::with_capacity(msgs.len());
123                            let mut ack_messages = Vec::with_capacity(msgs.len());
124
125                            for msg in msgs {
126                                let body = String::from_utf8_lossy(&msg.message.data).into_owned();
127                                counter!("component_received_network_bytes_total", self.labels.iter()).increment(body.len() as u64);
128                                message_ids.push(msg.message.message_id.clone());
129                                bodies.push(Some(body));
130                                publish_times.push(msg.message.publish_time.as_ref().map(|t| t.to_string()));
131                                ack_messages.push(msg);
132                            }
133
134                            let id_array = StringArray::from(message_ids);
135                            let body_array = StringArray::from(bodies);
136                            let time_array = StringArray::from(publish_times);
137
138                            if let Ok(rb) = RecordBatch::try_new(
139                                schema.clone(),
140                                vec![Arc::new(id_array), Arc::new(body_array), Arc::new(time_array)],
141                            ) {
142                                match EventBatch::new_with_xid(rb, metadata.clone()) {
143                                    Ok(batch) => {
144                                        let row_count = batch.num_rows();
145                                        let byte_size = batch.estimated_size();
146                                        counter!("component_received_events_total", self.labels.iter()).increment(row_count as u64);
147                                        counter!("component_received_event_bytes_total", self.labels.iter()).increment(byte_size as u64);
148
149                                        if let Err(e) = self.sender.send(batch).await {
150                                            counter!("component_errors_total", self.labels.iter()).increment(1);
151                                            error!("Failed to send Pub/Sub messages downstream: {:?}", e);
152                                            for msg in ack_messages {
153                                                let _ = msg.nack().await;
154                                            }
155                                        } else {
156                                            counter!("component_sent_events_total", self.labels.iter()).increment(row_count as u64);
157                                            for msg in ack_messages {
158                                                let _ = msg.ack().await;
159                                            }
160                                        }
161                                    }
162                                    Err(e) => {
163                                        counter!("component_errors_total", self.labels.iter()).increment(1);
164                                        error!("Failed to create EventBatch from Pub/Sub: {}", e);
165                                        for msg in ack_messages {
166                                            let _ = msg.nack().await;
167                                        }
168                                    }
169                                }
170                            }
171                        }
172                        Err(e) => {
173                            error!(
174                                "Pub/Sub receive error in component '{}': {}",
175                                self.component_id, e
176                            );
177                            sleep(Duration::from_secs(5)).await;
178                        }
179                    }
180                }
181            }
182        }
183    }
184}
185
186#[async_trait::async_trait]
187impl Healthcheck for PubSubSourceTask {
188    async fn check(&self) -> anyhow::Result<()> {
189        let gcp: GcpConfig = self
190            .config
191            .config
192            .auth
193            .as_ref()
194            .cloned()
195            .unwrap_or_default();
196        let client: PubSubClient = create_pubsub_client(&gcp).await?;
197
198        let sub_name = self
199            .config
200            .config
201            .subscription
202            .as_ref()
203            .ok_or_else(|| anyhow::anyhow!("Subscription name missing"))?;
204        let subscription = client.subscription(sub_name);
205
206        if subscription.exists(None).await? {
207            Ok(())
208        } else {
209            anyhow::bail!("Pub/Sub subscription '{}' does not exist", sub_name)
210        }
211    }
212}