Skip to main content

kinetic/transforms/aggregate/
sql_rewrite.rs

//! SQL rewriting utilities for the aggregate transform.
//!
//! Rewrites user queries so that references to the `active_table` placeholder point at the
//! sealed table, and checks that user queries are read-only before they are executed.
4
5use sqlparser::ast::{Statement, TableFactor};
6use sqlparser::dialect::GenericDialect;
7use sqlparser::parser::Parser;
8
/// SQL rewriter for aggregate transform queries.
///
/// A field-less unit type used as a namespace: its functions are called as
/// `SqlRewrite::rewrite(..)` and `SqlRewrite::is_read_only(..)`.
#[derive(Debug, Clone, Copy, Default)]
pub struct SqlRewrite;
11
12impl SqlRewrite {
13    /// Rewrite a SQL query to use the sealed table name
14    ///
15    /// The user's SQL should reference `active_table` which will be replaced
16    /// with the actual sealed table name (e.g., `sealed_abc123`).
17    pub fn rewrite(sql: &str, sealed_table_name: &str) -> String {
18        let dialect = GenericDialect {};
19        let mut ast = match Parser::parse_sql(&dialect, sql) {
20            Ok(ast) => ast,
21            Err(_) => return sql.replace("active_table", sealed_table_name),
22        };
23
24        for statement in &mut ast {
25            if let Statement::Query(query) = statement
26                && let sqlparser::ast::SetExpr::Select(select) = query.body.as_mut()
27            {
28                for from in &mut select.from {
29                    if let TableFactor::Table { name, .. } = &mut from.relation
30                        && name.to_string() == "active_table"
31                    {
32                        *name = sqlparser::ast::ObjectName(vec![
33                            sqlparser::ast::ObjectNamePart::Identifier(sqlparser::ast::Ident::new(
34                                sealed_table_name,
35                            )),
36                        ]);
37                    }
38                    for join in &mut from.joins {
39                        if let TableFactor::Table { name, .. } = &mut join.relation
40                            && name.to_string() == "active_table"
41                        {
42                            *name = sqlparser::ast::ObjectName(vec![
43                                sqlparser::ast::ObjectNamePart::Identifier(
44                                    sqlparser::ast::Ident::new(sealed_table_name),
45                                ),
46                            ]);
47                        }
48                    }
49                }
50            }
51        }
52
53        ast.iter()
54            .map(|s| s.to_string())
55            .collect::<Vec<_>>()
56            .join("; ")
57    }
58
59    /// Validate that the SQL query is read-only (doesn't modify data)
60    ///
61    /// Returns true if the query appears to be a single SELECT statement
62    pub fn is_read_only(sql: &str) -> bool {
63        let dialect = GenericDialect {};
64        let ast = match Parser::parse_sql(&dialect, sql) {
65            Ok(ast) => ast,
66            Err(_) => return false,
67        };
68
69        if ast.is_empty() {
70            return false;
71        }
72
73        // Must be a single statement (prevent multi-statement attacks implicitly by checking length or strictly examining all)
74        for statement in ast {
75            match statement {
76                Statement::Query(_) | Statement::Explain { .. } => {}
77                _ => return false, // If there's any non-query statement, it's not read-only
78            }
79        }
80
81        true
82    }
83}
84
#[cfg(test)]
// unwrap/expect are acceptable in test code, so these clippy restriction lints are relaxed here.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// Assert that rewriting `sql` against `sealed_table` yields exactly `expected`.
    fn assert_rewrite(sql: &str, sealed_table: &str, expected: &str) {
        let actual = SqlRewrite::rewrite(sql, sealed_table);
        assert_eq!(actual, expected, "unexpected rewrite of {sql:?}");
    }

    #[test]
    fn test_rewrite_active_table() {
        assert_rewrite(
            "SELECT * FROM active_table WHERE value > 10",
            "sealed_abc123",
            "SELECT * FROM sealed_abc123 WHERE value > 10",
        );
    }

    #[test]
    fn test_rewrite_no_active_table() {
        // A query that never mentions the placeholder must come back unchanged.
        let sql = "SELECT * FROM some_other_table";
        assert_rewrite(sql, "sealed_abc123", sql);
    }

    #[test]
    fn test_is_read_only() {
        // Plain queries, including lower-case keywords, leading whitespace, and a column
        // name that happens to contain `create`.
        let accepted = [
            "SELECT * FROM active_table",
            "  select count(*) from t",
            "SELECT created_at FROM t",
        ];
        // Data-modifying statements, a stacked DROP, and a malformed subquery.
        let rejected = [
            "SELECT * FROM t; DROP TABLE t",
            "INSERT INTO t VALUES (1)",
            "DELETE FROM t",
            "UPDATE t SET x = 1",
            "SELECT * FROM (INSERT INTO t ...)",
        ];

        for sql in accepted {
            assert!(SqlRewrite::is_read_only(sql), "expected read-only: {sql:?}");
        }
        for sql in rejected {
            assert!(!SqlRewrite::is_read_only(sql), "expected rejection: {sql:?}");
        }
    }

    #[test]
    fn test_rewrite_multiple_occurrences() {
        assert_rewrite(
            "SELECT * FROM active_table AS a JOIN active_table AS b ON a.id = b.id",
            "sealed_xyz789",
            "SELECT * FROM sealed_xyz789 AS a JOIN sealed_xyz789 AS b ON a.id = b.id",
        );
    }
}