Skip to content

Commit 73891ee

Browse files
committed
Add #[recursive]
1 parent 2bb8144 commit 73891ee

File tree

5 files changed

+41
-1
lines changed

5 files changed

+41
-1
lines changed

Cargo.toml

+3-1
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,16 @@ path = "src/lib.rs"
3838

3939
[features]
4040
default = ["std"]
41-
std = []
41+
std = ["recursive"]
4242
# Enable JSON output in the `cli` example:
4343
json_example = ["serde_json", "serde"]
4444
visitor = ["sqlparser_derive"]
4545

4646
[dependencies]
4747
bigdecimal = { version = "0.4.1", features = ["serde"], optional = true }
4848
log = "0.4"
49+
recursive = { version = "0.1.1", optional = true}
50+
4951
serde = { version = "1.0", features = ["derive"], optional = true }
5052
# serde_json is only used in examples/cli, but we have to put it outside
5153
# of dev-dependencies because of

derive/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_
7878
let expanded = quote! {
7979
// The generated impl.
8080
impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
81+
#[cfg_attr(feature = "std", recursive::recursive)]
8182
fn visit<V: sqlparser::ast::#visitor_trait>(
8283
&#modifier self,
8384
visitor: &mut V

src/ast/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1188,6 +1188,7 @@ impl fmt::Display for CastFormat {
11881188
}
11891189

11901190
impl fmt::Display for Expr {
1191+
#[cfg_attr(feature = "std", recursive::recursive)]
11911192
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
11921193
match self {
11931194
Expr::Identifier(s) => write!(f, "{s}"),

src/ast/visitor.rs

+26
Original file line numberDiff line numberDiff line change
@@ -884,4 +884,30 @@ mod tests {
884884
assert_eq!(actual, expected)
885885
}
886886
}
887+
888+
889+
struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes
890+
891+
impl Visitor for QuickVisitor {
892+
type Break = ();
893+
}
894+
895+
#[test]
896+
fn overflow() {
897+
let cond = (0..1000)
898+
.map(|n| format!("X = {}", n))
899+
.collect::<Vec<_>>()
900+
.join(" OR ");
901+
let sql = format!("SELECT x where {0}", cond);
902+
903+
let dialect = GenericDialect {};
904+
let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap();
905+
let s = Parser::new(&dialect)
906+
.with_tokens(tokens)
907+
.parse_statement()
908+
.unwrap();
909+
910+
let mut visitor = QuickVisitor {} ;
911+
s.visit(&mut visitor);
912+
}
887913
}

tests/sqlparser_common.rs

+10
Original file line numberDiff line numberDiff line change
@@ -11748,3 +11748,13 @@ fn parse_create_table_select() {
1174811748
);
1174911749
}
1175011750
}
11751+
11752+
#[test]
11753+
fn overflow() {
11754+
let expr = std::iter::repeat("1").take(1000).collect::<Vec<_>>().join(" + ");
11755+
let sql = format!("SELECT {}", expr);
11756+
11757+
let mut statements = Parser::parse_sql(&GenericDialect {}, sql.as_str()).unwrap();
11758+
let statement = statements.pop().unwrap();
11759+
assert_eq!(statement.to_string(), sql);
11760+
}

0 commit comments

Comments
 (0)