Skip to main content

kinetic_encoder_protobuf/
lib.rs

1//! Protobuf encoder and decoder for Kinetic using prost-reflect and arrow-json.
2
3use arrow_json::ReaderBuilder;
4use arrow_schema::Schema;
5use bytes::Bytes;
6use kinetic_core::encode::{
7    Decoder, DecoderConfig, Encoder, EncoderConfig, Error as EncodeError, Result as EncodeResult,
8};
9use kinetic_core::{ArcEventMetadata, EventBatch};
10use prost::Message;
11use prost_reflect::{DescriptorPool, DynamicMessage, MessageDescriptor};
12use serde::{Deserialize, Serialize};
13use snafu::{OptionExt, Snafu};
14use std::fs;
15use std::sync::Arc;
16
17#[derive(Debug, Snafu)]
18pub enum Error {
19    #[snafu(display("Protobuf descriptor error: {}", message))]
20    Descriptor { message: String },
21
22    #[snafu(display("Protobuf encoding error: {}", message))]
23    ProtobufEncode { message: String },
24
25    #[snafu(display("Protobuf decoding error: {}", message))]
26    ProtobufDecode { message: String },
27
28    #[snafu(display("Arrow error: {}", source))]
29    Arrow { source: arrow_schema::ArrowError },
30
31    #[snafu(display("Kinetic core error: {}", source))]
32    KineticCore { source: kinetic_core::Error },
33
34    #[snafu(display("IO error: {}", source))]
35    Io { source: std::io::Error },
36
37    #[snafu(display("JSON error: {}", source))]
38    Json { source: serde_json::Error },
39}
40
41impl From<Error> for EncodeError {
42    fn from(err: Error) -> Self {
43        match err {
44            Error::Descriptor { message } => EncodeError::Encode {
45                source: message.into(),
46            },
47            Error::ProtobufEncode { message } => EncodeError::Encode {
48                source: message.into(),
49            },
50            Error::ProtobufDecode { message } => EncodeError::Decode {
51                source: message.into(),
52            },
53            Error::Arrow { source } => EncodeError::Encode {
54                source: Box::new(source),
55            },
56            Error::KineticCore { source } => EncodeError::Encode {
57                source: Box::new(source),
58            },
59            Error::Io { source } => EncodeError::Encode {
60                source: Box::new(source),
61            },
62            Error::Json { source } => EncodeError::Encode {
63                source: Box::new(source),
64            },
65        }
66    }
67}
68
69pub type Result<T, E = Error> = std::result::Result<T, E>;
70
71/// Options for configuring Protobuf encoding.
72#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
73pub struct ProtobufEncoderOptions {
74    /// Path to the Protobuf descriptor file.
75    pub descriptor_path: String,
76    /// Fully qualified name of the Protobuf message.
77    pub message_name: String,
78}
79
80impl EncoderConfig for ProtobufEncoderOptions {
81    fn build(&self) -> EncodeResult<Arc<dyn Encoder>> {
82        let descriptor_bytes =
83            fs::read(&self.descriptor_path).map_err(|e| Error::Io { source: e })?;
84        let pool = DescriptorPool::decode(Bytes::from(descriptor_bytes)).map_err(|e| {
85            Error::Descriptor {
86                message: e.to_string(),
87            }
88        })?;
89        let message_descriptor =
90            pool.get_message_by_name(&self.message_name)
91                .context(DescriptorSnafu {
92                    message: format!("Message {} not found in descriptor", self.message_name),
93                })?;
94
95        Ok(Arc::new(ProtobufEncoder::new(
96            self.clone(),
97            message_descriptor,
98        )))
99    }
100}
101
102/// Protobuf Encoder
103pub struct ProtobufEncoder {
104    _options: ProtobufEncoderOptions,
105    message_descriptor: MessageDescriptor,
106}
107
108impl ProtobufEncoder {
109    pub fn new(options: ProtobufEncoderOptions, message_descriptor: MessageDescriptor) -> Self {
110        Self {
111            _options: options,
112            message_descriptor,
113        }
114    }
115}
116
117impl Encoder for ProtobufEncoder {
118    fn encode(&self, batch: &EventBatch) -> EncodeResult<Bytes> {
119        // Step 1: RecordBatch -> JSON
120        let mut buffer = Vec::new();
121        {
122            let mut writer = arrow_json::LineDelimitedWriter::new(&mut buffer);
123            writer
124                .write(&batch.payload)
125                .map_err(|e| Error::Arrow { source: e })?;
126            writer.finish().map_err(|e| Error::Arrow { source: e })?;
127        }
128
129        // Step 2: JSON -> Protobuf binary
130        let mut final_buffer = Vec::new();
131        let reader = std::io::BufReader::new(&buffer[..]);
132        use std::io::BufRead;
133        for line in reader.lines() {
134            let line = line.map_err(|e| Error::Io { source: e })?;
135            if line.trim().is_empty() {
136                continue;
137            }
138            let mut deserializer = serde_json::Deserializer::from_str(&line);
139            let dynamic_message =
140                DynamicMessage::deserialize(self.message_descriptor.clone(), &mut deserializer)
141                    .map_err(|e| Error::ProtobufEncode {
142                        message: e.to_string(),
143                    })?;
144
145            let proto_bytes = dynamic_message.encode_to_vec();
146
147            // Length prefixing
148            let mut len_buf = Vec::new();
149            prost::encode_length_delimiter(proto_bytes.len(), &mut len_buf).map_err(|e| {
150                Error::ProtobufEncode {
151                    message: e.to_string(),
152                }
153            })?;
154            final_buffer.extend_from_slice(&len_buf);
155            final_buffer.extend_from_slice(&proto_bytes);
156        }
157
158        Ok(Bytes::from(final_buffer))
159    }
160}
161
162/// Options for configuring Protobuf decoding.
163#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
164pub struct ProtobufDecoderOptions {
165    /// Path to the Protobuf descriptor file.
166    pub descriptor_path: String,
167    /// Fully qualified name of the Protobuf message.
168    pub message_name: String,
169    /// Maximum size of a single message in bytes.
170    pub max_size: Option<usize>,
171}
172
173impl Default for ProtobufDecoderOptions {
174    fn default() -> Self {
175        Self {
176            descriptor_path: String::new(),
177            message_name: String::new(),
178            max_size: Some(10 * 1024 * 1024), // 10MB default
179        }
180    }
181}
182
183impl DecoderConfig for ProtobufDecoderOptions {
184    fn build(&self, schema: Arc<Schema>) -> EncodeResult<Arc<dyn Decoder>> {
185        let descriptor_bytes =
186            fs::read(&self.descriptor_path).map_err(|e| Error::Io { source: e })?;
187        let pool = DescriptorPool::decode(Bytes::from(descriptor_bytes)).map_err(|e| {
188            Error::Descriptor {
189                message: e.to_string(),
190            }
191        })?;
192        let message_descriptor =
193            pool.get_message_by_name(&self.message_name)
194                .context(DescriptorSnafu {
195                    message: format!("Message {} not found in descriptor", self.message_name),
196                })?;
197
198        Ok(Arc::new(ProtobufDecoder::new(
199            self.clone(),
200            message_descriptor,
201            schema,
202        )))
203    }
204}
205
206/// Protobuf Decoder
207pub struct ProtobufDecoder {
208    options: ProtobufDecoderOptions,
209    message_descriptor: MessageDescriptor,
210    schema: Arc<Schema>,
211}
212
213impl ProtobufDecoder {
214    pub fn new(
215        options: ProtobufDecoderOptions,
216        message_descriptor: MessageDescriptor,
217        schema: Arc<Schema>,
218    ) -> Self {
219        Self {
220            options,
221            message_descriptor,
222            schema,
223        }
224    }
225}
226impl Decoder for ProtobufDecoder {
227    fn decode(&self, data: &[u8], metadata: ArcEventMetadata) -> EncodeResult<EventBatch> {
228        if let Some(limit) = self.options.max_size
229            && data.len() > limit
230        {
231            return Err(kinetic_core::encode::Error::MessageTooLarge {
232                size: data.len(),
233                limit,
234            });
235        }
236
237        let mut cursor = std::io::Cursor::new(data);
238        let mut rows = Vec::new();
239
240        while cursor.position() < data.len() as u64 {
241            let len_result = prost::decode_length_delimiter(&mut cursor);
242            match len_result {
243                Ok(len) => {
244                    let start = cursor.position() as usize;
245                    let end = start + len;
246                    if end > data.len() {
247                        return Err(Error::ProtobufDecode {
248                            message: "Truncated protobuf message".to_string(),
249                        }
250                        .into());
251                    }
252                    let message_bytes = &data[start..end];
253                    cursor.set_position(end as u64);
254
255                    let dynamic_message =
256                        DynamicMessage::decode(self.message_descriptor.clone(), message_bytes)
257                            .map_err(|e| Error::ProtobufDecode {
258                                message: e.to_string(),
259                            })?;
260
261                    rows.push(dynamic_message);
262                }
263                Err(e) => {
264                    // If decoding length delimiter fails, it might be a single message without prefix
265                    if cursor.position() == 0 {
266                        let dynamic_message =
267                            DynamicMessage::decode(self.message_descriptor.clone(), data).map_err(
268                                |_| Error::ProtobufDecode {
269                                    message: e.to_string(),
270                                },
271                            )?;
272                        rows.push(dynamic_message);
273                        break;
274                    } else {
275                        return Err(Error::ProtobufDecode {
276                            message: e.to_string(),
277                        }
278                        .into());
279                    }
280                }
281            }
282        }
283
284        if rows.is_empty() {
285            return Err(EncodeError::Decode {
286                source: "No messages found in Protobuf data".into(),
287            });
288        }
289
290        let mut decoder = ReaderBuilder::new(self.schema.clone())
291            .build_decoder()
292            .map_err(|e| Error::Arrow { source: e })?;
293
294        decoder
295            .serialize(&rows)
296            .map_err(|e| Error::Arrow { source: e })?;
297
298        let record_batch = decoder
299            .flush()
300            .map_err(|e| Error::Arrow { source: e })?
301            .context(ProtobufDecodeSnafu {
302                message: "Failed to flush Arrow decoder",
303            })?;
304
305        EventBatch::new_with_xid(record_batch, metadata)
306            .map_err(|e| Error::KineticCore { source: e })
307            .map_err(Into::<EncodeError>::into)
308    }
309}
310
#[cfg(test)]
mod tests {
    //! Unit tests for the parts of this module that need no compiled descriptor set.
    //! End-to-end encode/decode round trips still need a generated descriptor fixture and
    //! are not covered here.

    use super::*;

    const MISSING_DESCRIPTOR: &str = "/nonexistent/kinetic-encoder-protobuf/descriptor.pb";

    #[test]
    fn decoder_options_default_limits_payloads_to_10_mib() {
        let options = ProtobufDecoderOptions::default();
        assert_eq!(options.max_size, Some(10 * 1024 * 1024));
        assert!(options.descriptor_path.is_empty());
        assert!(options.message_name.is_empty());
    }

    #[test]
    fn encoder_build_fails_for_missing_descriptor_file() {
        let options = ProtobufEncoderOptions {
            descriptor_path: MISSING_DESCRIPTOR.to_string(),
            message_name: "example.Message".to_string(),
        };
        assert!(options.build().is_err());
    }

    #[test]
    fn decoder_build_fails_for_missing_descriptor_file() {
        let options = ProtobufDecoderOptions {
            descriptor_path: MISSING_DESCRIPTOR.to_string(),
            message_name: "example.Message".to_string(),
            ..Default::default()
        };
        assert!(options.build(Arc::new(Schema::empty())).is_err());
    }

    #[test]
    fn decoder_build_rejects_invalid_descriptor_bytes() {
        let path = std::env::temp_dir().join(format!(
            "kinetic-encoder-protobuf-invalid-{}.pb",
            std::process::id()
        ));
        // A lone 0xFF is an unterminated varint, so it cannot decode as a descriptor set.
        fs::write(&path, [0xFF_u8]).expect("write temporary descriptor file");
        let options = ProtobufDecoderOptions {
            descriptor_path: path.to_string_lossy().into_owned(),
            message_name: "example.Message".to_string(),
            ..Default::default()
        };
        let result = options.build(Arc::new(Schema::empty()));
        let _ = fs::remove_file(&path);
        assert!(result.is_err());
    }

    #[test]
    fn only_protobuf_decode_errors_map_to_decode_variant() {
        let err = EncodeError::from(Error::ProtobufDecode {
            message: "bad frame".to_string(),
        });
        assert!(matches!(err, EncodeError::Decode { .. }));

        let err = EncodeError::from(Error::Descriptor {
            message: "bad descriptor".to_string(),
        });
        assert!(matches!(err, EncodeError::Encode { .. }));
    }
}