1use aws_common::client::create_s3_client;
4use aws_common::config::AwsConfig;
5use kinetic_buffers::BufferSender;
6use kinetic_core::encode::Decoder;
7use kinetic_core::{ComponentId, EventMetadata, ShutdownSignal};
8use kinetic_encoder_parquet::ParquetDecoder;
9use metrics::{Label, counter};
10use std::sync::Arc;
11use tracing::{error, info};
12
13#[allow(clippy::large_enum_variant)]
14pub enum ReplayFragment {
15 File(String),
16 S3 {
17 bucket: String,
18 key: String,
19 aws: Option<AwsConfig>,
20 },
21}
22
23pub struct ReplaySource {
24 component_id: String,
25 pipeline_id: String,
26 out: BufferSender,
27 shutdown: ShutdownSignal,
28 fragments: Vec<ReplayFragment>,
29 labels: Arc<[Label]>,
30 decoder: ParquetDecoder,
31}
32
33impl ReplaySource {
34 pub fn new(
35 component_id: String,
36 pipeline_id: String,
37 out: BufferSender,
38 shutdown: ShutdownSignal,
39 fragments: Vec<ReplayFragment>,
40 ) -> Self {
41 let labels: Arc<[Label]> = Arc::new([
42 Label::new("component_id", component_id.clone()),
43 Label::new("component_type", "source"),
44 Label::new("component_kind", "replay"),
45 ]);
46
47 Self {
48 component_id,
49 pipeline_id,
50 out,
51 shutdown,
52 fragments,
53 labels,
54 decoder: ParquetDecoder::default(),
55 }
56 }
57
58 pub async fn run(self) {
59 info!(
60 "Starting Replay source '{}' for {} fragments",
61 self.component_id,
62 self.fragments.len()
63 );
64
65 for fragment in &self.fragments {
66 if self.shutdown.is_shutdown() {
67 break;
68 }
69
70 let data = match fragment {
71 ReplayFragment::File(path) => {
72 info!("Replaying local fragment: {}", path);
73 match tokio::fs::read(path).await {
74 Ok(d) => d,
75 Err(e) => {
76 error!("Failed to read local fragment {}: {}", path, e);
77 continue;
78 }
79 }
80 }
81 ReplayFragment::S3 { bucket, key, aws } => {
82 info!("Replaying S3 fragment: s3://{}/{}", bucket, key);
83 let aws_config = aws.as_ref().cloned().unwrap_or_default();
84 let client = create_s3_client(&aws_config).await;
85
86 match client.get_object().bucket(bucket).key(key).send().await {
87 Ok(output) => match output.body.collect().await {
88 Ok(bytes) => bytes.into_bytes().to_vec(),
89 Err(e) => {
90 error!(
91 "Failed to collect S3 body for s3://{}/{}: {}",
92 bucket, key, e
93 );
94 continue;
95 }
96 },
97 Err(e) => {
98 error!("Failed to get S3 object s3://{}/{}: {}", bucket, key, e);
99 continue;
100 }
101 }
102 }
103 };
104
105 let metadata = Arc::new(EventMetadata::new(
106 self.pipeline_id.clone(),
107 ComponentId(self.component_id.clone()),
108 ));
109
110 match self.decoder.decode(&data, metadata) {
113 Ok(batch) => {
114 let rows = batch.num_rows();
115 if let Err(e) = self.out.send(batch).await {
116 error!("Replay source failed to send batch: {:?}", e);
117 break;
118 }
119 counter!("component_sent_events_total", self.labels.iter())
120 .increment(rows as u64);
121 }
122 Err(e) => {
123 error!("Failed to decode Parquet fragment: {}", e);
124 }
125 }
126 }
127
128 info!("Replay source '{}' completed", self.component_id);
129 }
130}
131
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use kinetic_buffers::channel;
    use kinetic_core::EventBatch;
    use kinetic_core::encode::Encoder;
    use kinetic_encoder_parquet::ParquetEncoder;
    use std::path::Path;
    use tempfile::tempdir;

    /// Writes a Parquet fragment with one non-nullable `f1: Int32` column
    /// holding `values` into `dir`, and returns the file's path as a string.
    async fn write_fragment(dir: &Path, values: Vec<i32>) -> String {
        let file_path = dir.join("test.parquet");

        let schema = Arc::new(arrow_schema::Schema::new(vec![arrow_schema::Field::new(
            "f1",
            arrow_schema::DataType::Int32,
            false,
        )]));
        let rb = arrow_array::RecordBatch::try_new(
            schema,
            vec![Arc::new(arrow_array::Int32Array::from(values))],
        )
        .unwrap();
        let metadata = Arc::new(EventMetadata::new("test", ComponentId::from("source")));
        let batch = EventBatch::new(rb, metadata).unwrap();

        let encoder = ParquetEncoder::default();
        let encoded = encoder.encode(&batch).unwrap();
        tokio::fs::write(&file_path, encoded.as_ref())
            .await
            .unwrap();

        file_path.to_str().unwrap().to_string()
    }

    /// A readable fragment is decoded and forwarded as one batch.
    #[tokio::test]
    async fn test_replay_source_emits_batches() {
        let tmp = tempdir().unwrap();
        let fragment_path = write_fragment(tmp.path(), vec![1, 2, 3]).await;

        let (tx, mut rx) = channel(10, kinetic_buffers::WhenFull::Block, "test".to_string());
        let coord = kinetic_core::ShutdownCoordinator::new();
        let fragments = vec![ReplayFragment::File(fragment_path)];

        let source = ReplaySource::new(
            "replay-test".to_string(),
            "test-pipeline".to_string(),
            tx,
            coord.register(),
            fragments,
        );

        let handle = tokio::spawn(source.run());

        let batch_out = rx.recv().await;
        assert!(batch_out.is_some());
        assert_eq!(batch_out.unwrap().num_rows(), 3);

        // Await the task so a panic inside `run` fails the test instead of
        // being silently dropped with the handle.
        handle.await.expect("replay task panicked");
    }

    /// When shutdown is signalled before `run` starts, nothing is replayed.
    #[tokio::test]
    async fn test_replay_source_shutdown() {
        // Use a valid fragment: with an unreadable one the source would emit
        // nothing whether or not it honours the shutdown signal.
        let tmp = tempdir().unwrap();
        let fragment_path = write_fragment(tmp.path(), vec![1, 2, 3]).await;

        let (tx, mut rx) = channel(10, kinetic_buffers::WhenFull::Block, "test".to_string());
        let coord = kinetic_core::ShutdownCoordinator::new();
        let fragments = vec![ReplayFragment::File(fragment_path)];

        let source = ReplaySource::new(
            "replay-test".to_string(),
            "test-pipeline".to_string(),
            tx,
            coord.register(),
            fragments,
        );

        coord.broadcast(kinetic_core::ShutdownReason::Graceful);

        source.run().await;

        // `run` took the source (and its sender) by value, so any batch it had
        // sent would already be buffered here.
        // NOTE(review): assumes `recv` yields `None` once the sender is dropped
        // and the buffer is empty — confirm against `kinetic_buffers`.
        assert!(
            rx.recv().await.is_none(),
            "no batch should be emitted after shutdown"
        );
    }
}