Skip to main content

kinetic_encoder_ocsf/
lib.rs

1#![deny(clippy::unwrap_used)]
2
3pub mod generated {
4    include!(concat!(env!("OUT_DIR"), "/generated_schema.rs"));
5}
6
7use arrow_array::builder::*;
8use arrow_schema::{DataType, Schema};
9use bytes::Bytes;
10use kinetic_core::encode::{
11    Decoder, DecoderConfig, Encoder, EncoderConfig, Error as EncodeError, Result as EncodeResult,
12};
13use kinetic_core::event::EventBatch;
14use kinetic_core::metadata::ArcEventMetadata;
15use kinetic_doc_derive::ComponentDoc;
16use serde::{Deserialize, Serialize};
17use snafu::Snafu;
18use sonic_rs::{JsonContainerTrait, JsonValueMutTrait, JsonValueTrait};
19use std::sync::Arc;
20
21#[derive(Debug, Snafu)]
22pub enum Error {
23    #[snafu(display("Failed to parse OCSF JSON: {}", source))]
24    SonicParse { source: sonic_rs::Error },
25
26    #[snafu(display("Arrow IPC encoding error: {}", source))]
27    ArrowEncode { source: arrow_schema::ArrowError },
28
29    #[snafu(display("Kinetic core error: {}", source))]
30    KineticCore { source: kinetic_core::Error },
31
32    #[snafu(display("Unsupported OCSF version: {}", version))]
33    UnsupportedVersion { version: String },
34
35    #[snafu(display("Schema mismatch: {}", message))]
36    SchemaMismatch { message: String },
37}
38
39impl From<Error> for EncodeError {
40    fn from(err: Error) -> Self {
41        match err {
42            Error::SonicParse { source } => EncodeError::Decode {
43                source: Box::new(source),
44            },
45            Error::ArrowEncode { source } => EncodeError::Encode {
46                source: Box::new(source),
47            },
48            Error::KineticCore { source } => EncodeError::Decode {
49                source: Box::new(source),
50            },
51            Error::UnsupportedVersion { version } => EncodeError::Decode {
52                source: format!("Unsupported OCSF version: {}", version).into(),
53            },
54            Error::SchemaMismatch { message } => EncodeError::Decode {
55                source: message.into(),
56            },
57        }
58    }
59}
60
61pub type Result<T, E = Error> = std::result::Result<T, E>;
62
63/// Configuration for the OCSF Decoder.
64#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, ComponentDoc)]
65#[component(type = "decoder", name = "ocsf")]
66pub struct OcsfDecoderConfig {
67    /// OCSF schema version to use (e.g., "1.0").
68    pub version: Option<String>,
69    /// Maximum size of a single message in bytes.
70    pub max_size: Option<usize>,
71}
72
73impl Default for OcsfDecoderConfig {
74    fn default() -> Self {
75        Self {
76            version: None,
77            max_size: Some(10 * 1024 * 1024), // 10MB default
78        }
79    }
80}
81
82impl DecoderConfig for OcsfDecoderConfig {
83    fn build(&self, _schema: Arc<Schema>) -> EncodeResult<Arc<dyn Decoder>> {
84        // Note: For OCSF we use the internally generated schemas.
85        // We can use the version provided in config or detect it per-payload.
86        Ok(Arc::new(OcsfDecoder {
87            config: self.clone(),
88        }))
89    }
90}
91
92/// The OCSF Decoder.
93pub struct OcsfDecoder {
94    pub config: OcsfDecoderConfig,
95}
96
97impl Decoder for OcsfDecoder {
98    fn decode(&self, data: &[u8], metadata: ArcEventMetadata) -> EncodeResult<EventBatch> {
99        self.decode(data, metadata)
100    }
101}
102
103fn append_value(
104    builder: &mut dyn arrow_array::builder::ArrayBuilder,
105    data_type: &DataType,
106    value: Option<&sonic_rs::Value>,
107) -> Result<()> {
108    match data_type {
109        DataType::Utf8 => {
110            let b = builder
111                .as_any_mut()
112                .downcast_mut::<StringBuilder>()
113                .ok_or_else(|| Error::SchemaMismatch {
114                    message: "Expected StringBuilder".to_string(),
115                })?;
116            match value {
117                Some(v) => {
118                    if let Some(s) = v.as_str() {
119                        b.append_value(s);
120                    } else if !v.is_null() {
121                        if let Ok(s) = sonic_rs::to_string(v) {
122                            b.append_value(s);
123                        } else {
124                            b.append_null();
125                        }
126                    } else {
127                        b.append_null();
128                    }
129                }
130                None => b.append_null(),
131            }
132        }
133        DataType::Int64 => {
134            let b = builder
135                .as_any_mut()
136                .downcast_mut::<Int64Builder>()
137                .ok_or_else(|| Error::SchemaMismatch {
138                    message: "Expected Int64Builder".to_string(),
139                })?;
140            match value {
141                Some(v) => {
142                    if let Some(i) = v.as_i64() {
143                        b.append_value(i);
144                    } else if let Some(u) = v.as_u64() {
145                        b.append_value(u as i64);
146                    } else {
147                        b.append_null();
148                    }
149                }
150                None => b.append_null(),
151            }
152        }
153        DataType::Float64 => {
154            let b = builder
155                .as_any_mut()
156                .downcast_mut::<Float64Builder>()
157                .ok_or_else(|| Error::SchemaMismatch {
158                    message: "Expected Float64Builder".to_string(),
159                })?;
160            match value {
161                Some(v) => {
162                    if let Some(f) = v.as_f64() {
163                        b.append_value(f);
164                    } else if let Some(i) = v.as_i64() {
165                        b.append_value(i as f64);
166                    } else if let Some(u) = v.as_u64() {
167                        b.append_value(u as f64);
168                    } else {
169                        b.append_null();
170                    }
171                }
172                None => b.append_null(),
173            }
174        }
175        DataType::Boolean => {
176            let b = builder
177                .as_any_mut()
178                .downcast_mut::<BooleanBuilder>()
179                .ok_or_else(|| Error::SchemaMismatch {
180                    message: "Expected BooleanBuilder".to_string(),
181                })?;
182            match value {
183                Some(v) => {
184                    if let Some(b_val) = v.as_bool() {
185                        b.append_value(b_val);
186                    } else {
187                        b.append_null();
188                    }
189                }
190                None => b.append_null(),
191            }
192        }
193        DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _) => {
194            let b = builder
195                .as_any_mut()
196                .downcast_mut::<TimestampMillisecondBuilder>()
197                .ok_or_else(|| Error::SchemaMismatch {
198                    message: "Expected TimestampMillisecondBuilder".to_string(),
199                })?;
200            match value {
201                Some(v) => {
202                    if let Some(i) = v.as_i64() {
203                        b.append_value(i);
204                    } else if let Some(u) = v.as_u64() {
205                        b.append_value(u as i64);
206                    } else {
207                        b.append_null();
208                    }
209                }
210                None => b.append_null(),
211            }
212        }
213        DataType::List(field) => {
214            if let Some(b) = builder
215                .as_any_mut()
216                .downcast_mut::<ListBuilder<Box<dyn ArrayBuilder>>>()
217            {
218                if let Some(v) = value {
219                    if let Some(arr) = v.as_array() {
220                        for item in arr {
221                            append_value(b.values(), field.data_type(), Some(item))?;
222                        }
223                        b.append(true);
224                    } else {
225                        b.append(false);
226                    }
227                } else {
228                    b.append(false);
229                }
230            }
231        }
232        DataType::Struct(fields) => {
233            if let Some(b) = builder.as_any_mut().downcast_mut::<StructBuilder>() {
234                if let Some(v) = value.and_then(|v| v.as_object()) {
235                    for (i, field) in fields.iter().enumerate() {
236                        let field_val = v.get(field.name());
237                        let dt = field.data_type();
238
239                        match dt {
240                            DataType::Utf8 => append_value(
241                                b.field_builder::<StringBuilder>(i).ok_or_else(|| {
242                                    Error::SchemaMismatch {
243                                        message: format!("Field {} builder missing", field.name()),
244                                    }
245                                })?,
246                                dt,
247                                field_val,
248                            )?,
249                            DataType::Int64 => append_value(
250                                b.field_builder::<Int64Builder>(i).ok_or_else(|| {
251                                    Error::SchemaMismatch {
252                                        message: format!("Field {} builder missing", field.name()),
253                                    }
254                                })?,
255                                dt,
256                                field_val,
257                            )?,
258                            DataType::Float64 => append_value(
259                                b.field_builder::<Float64Builder>(i).ok_or_else(|| {
260                                    Error::SchemaMismatch {
261                                        message: format!("Field {} builder missing", field.name()),
262                                    }
263                                })?,
264                                dt,
265                                field_val,
266                            )?,
267                            DataType::Boolean => append_value(
268                                b.field_builder::<BooleanBuilder>(i).ok_or_else(|| {
269                                    Error::SchemaMismatch {
270                                        message: format!("Field {} builder missing", field.name()),
271                                    }
272                                })?,
273                                dt,
274                                field_val,
275                            )?,
276                            DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _) => {
277                                append_value(
278                                    b.field_builder::<TimestampMillisecondBuilder>(i)
279                                        .ok_or_else(|| Error::SchemaMismatch {
280                                            message: format!(
281                                                "Field {} builder missing",
282                                                field.name()
283                                            ),
284                                        })?,
285                                    dt,
286                                    field_val,
287                                )?
288                            }
289                            DataType::List(_) => append_value(
290                                b.field_builder::<ListBuilder<Box<dyn ArrayBuilder>>>(i)
291                                    .ok_or_else(|| Error::SchemaMismatch {
292                                        message: format!("Field {} builder missing", field.name()),
293                                    })?,
294                                dt,
295                                field_val,
296                            )?,
297                            DataType::Struct(inner_fields) => append_value(
298                                b.field_builder::<StructBuilder>(i).ok_or_else(|| {
299                                    Error::SchemaMismatch {
300                                        message: format!("Field {} builder missing", field.name()),
301                                    }
302                                })?,
303                                &DataType::Struct(inner_fields.clone()),
304                                field_val,
305                            )?,
306                            _ => {}
307                        }
308                    }
309                    b.append(true);
310                } else {
311                    // Recursive null append for all fields
312                    for (i, field) in fields.iter().enumerate() {
313                        let dt = field.data_type();
314                        match dt {
315                            DataType::Utf8 => append_value(
316                                b.field_builder::<StringBuilder>(i).ok_or_else(|| {
317                                    Error::SchemaMismatch {
318                                        message: format!("Field {} builder missing", field.name()),
319                                    }
320                                })?,
321                                dt,
322                                None,
323                            )?,
324                            DataType::Int64 => append_value(
325                                b.field_builder::<Int64Builder>(i).ok_or_else(|| {
326                                    Error::SchemaMismatch {
327                                        message: format!("Field {} builder missing", field.name()),
328                                    }
329                                })?,
330                                dt,
331                                None,
332                            )?,
333                            DataType::Float64 => append_value(
334                                b.field_builder::<Float64Builder>(i).ok_or_else(|| {
335                                    Error::SchemaMismatch {
336                                        message: format!("Field {} builder missing", field.name()),
337                                    }
338                                })?,
339                                dt,
340                                None,
341                            )?,
342                            DataType::Boolean => append_value(
343                                b.field_builder::<BooleanBuilder>(i).ok_or_else(|| {
344                                    Error::SchemaMismatch {
345                                        message: format!("Field {} builder missing", field.name()),
346                                    }
347                                })?,
348                                dt,
349                                None,
350                            )?,
351                            DataType::Timestamp(arrow_schema::TimeUnit::Millisecond, _) => {
352                                append_value(
353                                    b.field_builder::<TimestampMillisecondBuilder>(i)
354                                        .ok_or_else(|| Error::SchemaMismatch {
355                                            message: format!(
356                                                "Field {} builder missing",
357                                                field.name()
358                                            ),
359                                        })?,
360                                    dt,
361                                    None,
362                                )?
363                            }
364                            DataType::List(_) => append_value(
365                                b.field_builder::<ListBuilder<Box<dyn ArrayBuilder>>>(i)
366                                    .ok_or_else(|| Error::SchemaMismatch {
367                                        message: format!("Field {} builder missing", field.name()),
368                                    })?,
369                                dt,
370                                None,
371                            )?,
372                            DataType::Struct(inner_fields) => append_value(
373                                b.field_builder::<StructBuilder>(i).ok_or_else(|| {
374                                    Error::SchemaMismatch {
375                                        message: format!("Field {} builder missing", field.name()),
376                                    }
377                                })?,
378                                &DataType::Struct(inner_fields.clone()),
379                                None,
380                            )?,
381                            _ => {}
382                        }
383                    }
384                    b.append(false);
385                }
386            }
387        }
388        _ => {
389            // Ignored types
390        }
391    }
392    Ok(())
393}
394
395impl OcsfDecoder {
396    fn decode(&self, data: &[u8], metadata: ArcEventMetadata) -> EncodeResult<EventBatch> {
397        if let Some(limit) = self.config.max_size
398            && data.len() > limit
399        {
400            return Err(EncodeError::MessageTooLarge {
401                size: data.len(),
402                limit,
403            });
404        }
405
406        let document = sonic_rs::from_slice::<sonic_rs::Value>(data)
407            .map_err(|e| Error::SonicParse { source: e })?;
408
409        // Determine which schema to use
410        let schema = if let Some(config_version) = &self.config.version {
411            generated::get_schema_by_version(config_version).ok_or_else(|| {
412                Error::UnsupportedVersion {
413                    version: config_version.clone(),
414                }
415            })?
416        } else {
417            // Inference: try to get version from metadata.version in payload
418            let inferred_version = document
419                .get("metadata")
420                .and_then(|m| m.get("version"))
421                .and_then(|v| v.as_str())
422                .unwrap_or("v1.8.0"); // Default to latest if not found
423
424            generated::get_schema_by_version(inferred_version)
425                .unwrap_or_else(generated::get_latest_schema)
426        };
427
428        let mut builders: Vec<Box<dyn ArrayBuilder>> = schema
429            .fields()
430            .iter()
431            .map(|f| arrow_array::builder::make_builder(f.data_type(), 1))
432            .collect();
433
434        let mut unmapped = sonic_rs::json!({});
435
436        if let Some(obj) = document.as_object() {
437            for (i, field) in schema.fields().iter().enumerate() {
438                let name = field.name();
439                if name == "_unmapped" {
440                    continue;
441                }
442
443                let value = obj.get(name);
444                append_value(&mut builders[i], field.data_type(), value)?;
445            }
446
447            if let Some(unmapped_obj) = unmapped.as_object_mut() {
448                for (k, v) in obj.iter() {
449                    if schema.column_with_name(k).is_none() {
450                        unmapped_obj.insert(k, v.clone());
451                    }
452                }
453            }
454        } else {
455            for (i, field) in schema.fields().iter().enumerate() {
456                if field.name() != "_unmapped" {
457                    append_value(&mut builders[i], field.data_type(), None)?;
458                }
459            }
460        }
461
462        if let Ok(unmapped_idx) = schema.index_of("_unmapped") {
463            let unmapped_str = sonic_rs::to_string(&unmapped).unwrap_or_else(|_| "{}".to_string());
464            let b = builders[unmapped_idx]
465                .as_any_mut()
466                .downcast_mut::<StringBuilder>()
467                .ok_or_else(|| Error::SchemaMismatch {
468                    message: "Expected StringBuilder for _unmapped field".to_string(),
469                })?;
470            b.append_value(unmapped_str);
471        }
472
473        let arrays: Vec<std::sync::Arc<dyn arrow_array::Array>> =
474            builders.into_iter().map(|mut b| b.finish()).collect();
475
476        let record_batch = arrow_array::RecordBatch::try_new(schema, arrays)
477            .map_err(|e| Error::ArrowEncode { source: e })?;
478        let event_batch = EventBatch::new_with_xid(record_batch, metadata)
479            .map_err(|e| Error::KineticCore { source: e })?;
480
481        Ok(event_batch)
482    }
483}
484
485/// Configuration for the OCSF Encoder.
486#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default, ComponentDoc)]
487#[component(type = "encoder", name = "ocsf")]
488pub struct OcsfEncoderConfig {}
489
490impl EncoderConfig for OcsfEncoderConfig {
491    fn build(&self) -> EncodeResult<Arc<dyn Encoder>> {
492        Ok(Arc::new(OcsfEncoder::default()))
493    }
494}
495
496/// The OCSF Encoder.
497#[derive(Default)]
498pub struct OcsfEncoder {}
499
500impl Encoder for OcsfEncoder {
501    fn encode(&self, batch: &EventBatch) -> EncodeResult<Bytes> {
502        self.encode_inner(batch).map_err(Into::into)
503    }
504}
505
506impl OcsfEncoder {
507    fn encode_inner(&self, batch: &EventBatch) -> Result<Bytes> {
508        // Output Arrow IPC stream for performance.
509        let mut buf = Vec::new();
510        {
511            let mut writer =
512                arrow_ipc::writer::StreamWriter::try_new(&mut buf, batch.payload.schema().as_ref())
513                    .map_err(|e| Error::ArrowEncode { source: e })?;
514
515            writer
516                .write(&batch.payload)
517                .map_err(|e| Error::ArrowEncode { source: e })?;
518
519            writer
520                .finish()
521                .map_err(|e| Error::ArrowEncode { source: e })?;
522        }
523
524        Ok(Bytes::from(buf))
525    }
526}
527
#[cfg(test)]
// A panic on an unexpected `None`/`Err` is exactly how a test should fail, so the
// crate-wide `deny(clippy::unwrap_used)` is relaxed for this module only.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use kinetic_core::metadata::{ComponentId, EventMetadata};
    use std::sync::Arc;

    /// Metadata attached to every batch built in these tests.
    fn test_metadata() -> Arc<EventMetadata> {
        Arc::new(EventMetadata::new(
            "test_pipeline",
            ComponentId::from("test_source"),
        ))
    }

    /// Builds a decoder through the public config path.
    fn build_decoder(version: Option<&str>, max_size: Option<usize>) -> Arc<dyn Decoder> {
        let config = OcsfDecoderConfig {
            version: version.map(str::to_string),
            max_size,
        };
        config.build(generated::get_latest_schema()).unwrap()
    }

    #[test]
    fn test_schema_generation_v1_8_0() {
        let schema = generated::get_schema_v1_8_0();
        for column in ["message", "_unmapped"] {
            assert!(
                schema.column_with_name(column).is_some(),
                "missing column {column}"
            );
        }
    }

    #[test]
    fn test_schema_lookup() {
        for known in ["v1.3.0", "1.8.0", "v1.8.0"] {
            assert!(
                generated::get_schema_by_version(known).is_some(),
                "{known} should resolve"
            );
        }
        assert!(generated::get_schema_by_version("nonexistent").is_none());
    }

    #[test]
    fn test_ocsf_decoder_inference() {
        let decoder = build_decoder(None, None);
        let batch = decoder
            .decode(
                b"{\"message\": \"test event\", \"metadata\": {\"version\": \"1.3.0\"}, \"unknown_field\": 123}",
                test_metadata(),
            )
            .unwrap();

        let schema = batch.payload.schema();
        let msg_col = batch.payload.column_by_name("message").unwrap_or_else(|| {
            panic!(
                "Column 'message' not found in batch. Available columns: {:?}",
                schema.fields().iter().map(|f| f.name()).collect::<Vec<_>>()
            )
        });
        let msg_arr = msg_col
            .as_any()
            .downcast_ref::<arrow_array::StringArray>()
            .unwrap();
        assert_eq!(msg_arr.value(0), "test event");

        let unmapped_arr = batch
            .payload
            .column_by_name("_unmapped")
            .unwrap()
            .as_any()
            .downcast_ref::<arrow_array::StringArray>()
            .unwrap();
        assert!(unmapped_arr.value(0).contains("\"unknown_field\":123"));
    }

    #[test]
    fn test_ocsf_decoder_rejects_oversized_payload() {
        let decoder = build_decoder(None, Some(8));
        // 23 bytes, over the 8-byte limit; rejected before any JSON parsing.
        let result = decoder.decode(b"{\"message\": \"too long\"}", test_metadata());
        assert!(matches!(
            result,
            Err(EncodeError::MessageTooLarge {
                size: 23,
                limit: 8,
                ..
            })
        ));
    }

    #[test]
    fn test_ocsf_decoder_unknown_configured_version() {
        let decoder = build_decoder(Some("nonexistent"), None);
        let result = decoder.decode(b"{\"message\": \"x\"}", test_metadata());
        assert!(matches!(result, Err(EncodeError::Decode { .. })));
    }

    #[test]
    fn test_ocsf_encoder() {
        let encoder = OcsfEncoderConfig {}.build().unwrap();

        let empty_record_batch =
            arrow_array::RecordBatch::new_empty(generated::get_latest_schema());
        let event_batch = EventBatch::new(empty_record_batch, test_metadata()).unwrap();

        let bytes = encoder.encode(&event_batch).unwrap();
        assert!(!bytes.is_empty());
    }
}