kinetic/transforms/aggregate/
sql_rewrite.rs1use sqlparser::ast::{Statement, TableFactor};
6use sqlparser::dialect::GenericDialect;
7use sqlparser::parser::Parser;
8
9pub struct SqlRewrite;
11
12impl SqlRewrite {
13 pub fn rewrite(sql: &str, sealed_table_name: &str) -> String {
18 let dialect = GenericDialect {};
19 let mut ast = match Parser::parse_sql(&dialect, sql) {
20 Ok(ast) => ast,
21 Err(_) => return sql.replace("active_table", sealed_table_name),
22 };
23
24 for statement in &mut ast {
25 if let Statement::Query(query) = statement
26 && let sqlparser::ast::SetExpr::Select(select) = query.body.as_mut()
27 {
28 for from in &mut select.from {
29 if let TableFactor::Table { name, .. } = &mut from.relation
30 && name.to_string() == "active_table"
31 {
32 *name = sqlparser::ast::ObjectName(vec![
33 sqlparser::ast::ObjectNamePart::Identifier(sqlparser::ast::Ident::new(
34 sealed_table_name,
35 )),
36 ]);
37 }
38 for join in &mut from.joins {
39 if let TableFactor::Table { name, .. } = &mut join.relation
40 && name.to_string() == "active_table"
41 {
42 *name = sqlparser::ast::ObjectName(vec![
43 sqlparser::ast::ObjectNamePart::Identifier(
44 sqlparser::ast::Ident::new(sealed_table_name),
45 ),
46 ]);
47 }
48 }
49 }
50 }
51 }
52
53 ast.iter()
54 .map(|s| s.to_string())
55 .collect::<Vec<_>>()
56 .join("; ")
57 }
58
59 pub fn is_read_only(sql: &str) -> bool {
63 let dialect = GenericDialect {};
64 let ast = match Parser::parse_sql(&dialect, sql) {
65 Ok(ast) => ast,
66 Err(_) => return false,
67 };
68
69 if ast.is_empty() {
70 return false;
71 }
72
73 for statement in ast {
75 match statement {
76 Statement::Query(_) | Statement::Explain { .. } => {}
77 _ => return false, }
79 }
80
81 true
82 }
83}
84
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// A single placeholder reference in FROM is swapped for the sealed table name.
    #[test]
    fn test_rewrite_active_table() {
        let output = SqlRewrite::rewrite(
            "SELECT * FROM active_table WHERE value > 10",
            "sealed_abc123",
        );
        assert_eq!(output, "SELECT * FROM sealed_abc123 WHERE value > 10");
    }

    /// Queries that never mention the placeholder come back unchanged.
    #[test]
    fn test_rewrite_no_active_table() {
        let output = SqlRewrite::rewrite("SELECT * FROM some_other_table", "sealed_abc123");
        assert_eq!(output, "SELECT * FROM some_other_table");
    }

    /// Table-driven check of which inputs count as read-only.
    #[test]
    fn test_is_read_only() {
        let cases = [
            ("SELECT * FROM active_table", true),
            (" select count(*) from t", true),
            ("SELECT * FROM t; DROP TABLE t", false),
            ("INSERT INTO t VALUES (1)", false),
            ("DELETE FROM t", false),
            ("UPDATE t SET x = 1", false),
            ("SELECT * FROM (INSERT INTO t ...)", false),
            ("SELECT created_at FROM t", true),
        ];

        for (sql, expected) in cases {
            assert_eq!(SqlRewrite::is_read_only(sql), expected, "{sql}");
        }
    }

    /// Both the FROM relation and the JOIN relation are rewritten.
    #[test]
    fn test_rewrite_multiple_occurrences() {
        let output = SqlRewrite::rewrite(
            "SELECT * FROM active_table AS a JOIN active_table AS b ON a.id = b.id",
            "sealed_xyz789",
        );
        let expected = "SELECT * FROM sealed_xyz789 AS a JOIN sealed_xyz789 AS b ON a.id = b.id";
        assert_eq!(output, expected);
    }
}