1use arrow_array::{ArrayRef, RecordBatch, StringArray};
2use arrow_schema::{DataType, Field, Schema};
3use std::sync::Arc;
4use xid;
5
6pub fn ensure_xid_column(batch: RecordBatch) -> Result<RecordBatch, arrow_schema::ArrowError> {
7 let schema = batch.schema();
8 if schema.column_with_name("_xid").is_some() {
9 return Ok(batch);
10 }
11
12 let num_rows = batch.num_rows();
13 let mut xids = Vec::with_capacity(num_rows);
14 for _ in 0..num_rows {
15 xids.push(xid::new().to_string());
16 }
17 let xid_array = StringArray::from(xids);
18
19 let mut new_fields = schema.fields().to_vec();
20 new_fields.push(Arc::new(Field::new("_xid", DataType::Utf8, false)));
21
22 let new_schema = Arc::new(Schema::new(new_fields));
23
24 let mut new_columns: Vec<ArrayRef> = batch.columns().to_vec();
25 new_columns.push(Arc::new(xid_array) as ArrayRef);
26
27 RecordBatch::try_new(new_schema, new_columns)
28}
29
30pub fn batch_xids(batch: &arrow_array::RecordBatch) -> Vec<String> {
32 if let Some(column) = batch.column_by_name("_xid")
33 && let Some(array) = column.as_any().downcast_ref::<StringArray>()
34 {
35 return array
36 .iter()
37 .map(|value| value.unwrap_or("unknown_xid").to_string())
38 .collect();
39 }
40
41 (0..batch.num_rows())
42 .map(|_| "unknown_xid".to_string())
43 .collect()
44}
45
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::Int32Array;

    /// Builds a three-row batch with a single non-nullable `Int32` column
    /// named `value` and no `_xid` column.
    fn base_batch() -> RecordBatch {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "value",
            DataType::Int32,
            false,
        )]));
        let col = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
        RecordBatch::try_new(schema, vec![col]).expect("test batch should build")
    }

    #[test]
    fn ensure_xid_adds_column() {
        let input = base_batch();
        let output = ensure_xid_column(input).expect("xid column append should succeed");

        assert_eq!(output.num_columns(), 2);
        assert!(output.schema().column_with_name("_xid").is_some());
        assert_eq!(output.num_rows(), 3);
    }

    #[test]
    fn ensure_xid_is_idempotent_when_present() {
        let once = ensure_xid_column(base_batch()).expect("first append should succeed");
        let twice = ensure_xid_column(once.clone()).expect("second append should succeed");

        assert_eq!(twice.num_columns(), once.num_columns());
        assert!(twice.schema().column_with_name("_xid").is_some());
        // A second pass must keep the identifiers that were already assigned.
        assert_eq!(batch_xids(&twice), batch_xids(&once));
    }

    #[test]
    fn test_batch_xids() {
        let batch = ensure_xid_column(base_batch()).expect("xid column append should succeed");
        let xids = batch_xids(&batch);

        assert_eq!(xids.len(), 3);
        // Every row, not just the first, must carry a real identifier.
        assert!(xids.iter().all(|id| !id.is_empty() && id != "unknown_xid"));
    }

    #[test]
    fn batch_xids_falls_back_when_column_missing() {
        // `base_batch` has no `_xid` column, so each row gets the placeholder.
        let xids = batch_xids(&base_batch());

        assert_eq!(xids, vec!["unknown_xid".to_string(); 3]);
    }
}