Skip to content

Commit 84e82e6

Browse files
blaginin and iffyio authored
Add #[recursive] (#1522)
Co-authored-by: Ifeanyi Ubah <[email protected]>
1 parent c973df3 commit 84e82e6

File tree

8 files changed

+93
-2
lines changed

8 files changed

+93
-2
lines changed

Cargo.toml

+4-1
Original file line numberDiff line numberDiff line change
@@ -37,15 +37,18 @@ name = "sqlparser"
3737
path = "src/lib.rs"
3838

3939
[features]
40-
default = ["std"]
40+
default = ["std", "recursive-protection"]
4141
std = []
42+
recursive-protection = ["std", "recursive"]
4243
# Enable JSON output in the `cli` example:
4344
json_example = ["serde_json", "serde"]
4445
visitor = ["sqlparser_derive"]
4546

4647
[dependencies]
4748
bigdecimal = { version = "0.4.1", features = ["serde"], optional = true }
4849
log = "0.4"
50+
recursive = { version = "0.1.1", optional = true}
51+
4952
serde = { version = "1.0", features = ["derive"], optional = true }
5053
# serde_json is only used in examples/cli, but we have to put it outside
5154
# of dev-dependencies because of

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ The following optional [crate features](https://doc.rust-lang.org/cargo/referen
6363

6464
* `serde`: Adds [Serde](https://serde.rs/) support by implementing `Serialize` and `Deserialize` for all AST nodes.
6565
* `visitor`: Adds a `Visitor` capable of recursively walking the AST tree.
66-
66+
* `recursive-protection` (enabled by default): Uses [recursive](https://docs.rs/recursive/latest/recursive/) for stack overflow protection.
6767

6868
## Syntax vs Semantics
6969

derive/src/lib.rs

+3
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,10 @@ fn derive_visit(input: proc_macro::TokenStream, visit_type: &VisitType) -> proc_
7878

7979
let expanded = quote! {
8080
// The generated impl.
81+
// Note that it uses [`recursive::recursive`] to protect from stack overflow.
82+
// See tests in https://github.com/apache/datafusion-sqlparser-rs/pull/1522/ for more info.
8183
impl #impl_generics sqlparser::ast::#visit_trait for #name #ty_generics #where_clause {
84+
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
8285
fn visit<V: sqlparser::ast::#visitor_trait>(
8386
&#modifier self,
8487
visitor: &mut V

sqlparser_bench/benches/sqlparser_bench.rs

+40
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,46 @@ fn basic_queries(c: &mut Criterion) {
4242
group.bench_function("sqlparser::with_select", |b| {
4343
b.iter(|| Parser::parse_sql(&dialect, with_query).unwrap());
4444
});
45+
46+
let large_statement = {
47+
let expressions = (0..1000)
48+
.map(|n| format!("FN_{}(COL_{})", n, n))
49+
.collect::<Vec<_>>()
50+
.join(", ");
51+
let tables = (0..1000)
52+
.map(|n| format!("TABLE_{}", n))
53+
.collect::<Vec<_>>()
54+
.join(" JOIN ");
55+
let where_condition = (0..1000)
56+
.map(|n| format!("COL_{} = {}", n, n))
57+
.collect::<Vec<_>>()
58+
.join(" OR ");
59+
let order_condition = (0..1000)
60+
.map(|n| format!("COL_{} DESC", n))
61+
.collect::<Vec<_>>()
62+
.join(", ");
63+
64+
format!(
65+
"SELECT {} FROM {} WHERE {} ORDER BY {}",
66+
expressions, tables, where_condition, order_condition
67+
)
68+
};
69+
70+
group.bench_function("parse_large_statement", |b| {
71+
b.iter(|| Parser::parse_sql(&dialect, criterion::black_box(large_statement.as_str())));
72+
});
73+
74+
let large_statement = Parser::parse_sql(&dialect, large_statement.as_str())
75+
.unwrap()
76+
.pop()
77+
.unwrap();
78+
79+
group.bench_function("format_large_statement", |b| {
80+
b.iter(|| {
81+
let formatted_query = large_statement.to_string();
82+
assert_eq!(formatted_query, large_statement);
83+
});
84+
});
4585
}
4686

4787
criterion_group!(benches, basic_queries);

src/ast/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,7 @@ impl fmt::Display for CastFormat {
12911291
}
12921292

12931293
impl fmt::Display for Expr {
1294+
#[cfg_attr(feature = "recursive-protection", recursive::recursive)]
12941295
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
12951296
match self {
12961297
Expr::Identifier(s) => write!(f, "{s}"),

src/ast/visitor.rs

+25
Original file line numberDiff line numberDiff line change
@@ -894,4 +894,29 @@ mod tests {
894894
assert_eq!(actual, expected)
895895
}
896896
}
897+
898+
struct QuickVisitor; // [`TestVisitor`] is too slow to iterate over thousands of nodes
899+
900+
impl Visitor for QuickVisitor {
901+
type Break = ();
902+
}
903+
904+
#[test]
905+
fn overflow() {
906+
let cond = (0..1000)
907+
.map(|n| format!("X = {}", n))
908+
.collect::<Vec<_>>()
909+
.join(" OR ");
910+
let sql = format!("SELECT x where {0}", cond);
911+
912+
let dialect = GenericDialect {};
913+
let tokens = Tokenizer::new(&dialect, sql.as_str()).tokenize().unwrap();
914+
let s = Parser::new(&dialect)
915+
.with_tokens(tokens)
916+
.parse_statement()
917+
.unwrap();
918+
919+
let mut visitor = QuickVisitor {};
920+
s.visit(&mut visitor);
921+
}
897922
}

src/parser/mod.rs

+6
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,9 @@ mod recursion {
7373
/// Note: Uses an [`std::rc::Rc`] and [`std::cell::Cell`] in order to satisfy the Rust
7474
/// borrow checker so the automatic [`DepthGuard`] decrement a
7575
/// reference to the counter.
76+
///
77+
/// Note: when the `recursive-protection` feature is enabled, this crate uses additional stack overflow protection
78+
/// for some of its recursive methods. See [`recursive::recursive`] for more information.
7679
pub(crate) struct RecursionCounter {
7780
remaining_depth: Rc<Cell<usize>>,
7881
}
@@ -326,6 +329,9 @@ impl<'a> Parser<'a> {
326329
/// # Ok(())
327330
/// # }
328331
/// ```
332+
///
333+
/// Note: when the `recursive-protection` feature is enabled, this crate uses additional stack overflow protection
334+
/// for some of its recursive methods. See [`recursive::recursive`] for more information.
329335
pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self {
330336
self.recursion_counter = RecursionCounter::new(recursion_limit);
331337
self

tests/sqlparser_common.rs

+13
Original file line numberDiff line numberDiff line change
@@ -12433,3 +12433,16 @@ fn test_table_sample() {
1243312433
dialects.verified_stmt("SELECT * FROM tbl AS t TABLESAMPLE SYSTEM (50)");
1243412434
dialects.verified_stmt("SELECT * FROM tbl AS t TABLESAMPLE SYSTEM (50) REPEATABLE (10)");
1243512435
}
12436+
12437+
#[test]
12438+
fn overflow() {
12439+
let expr = std::iter::repeat("1")
12440+
.take(1000)
12441+
.collect::<Vec<_>>()
12442+
.join(" + ");
12443+
let sql = format!("SELECT {}", expr);
12444+
12445+
let mut statements = Parser::parse_sql(&GenericDialect {}, sql.as_str()).unwrap();
12446+
let statement = statements.pop().unwrap();
12447+
assert_eq!(statement.to_string(), sql);
12448+
}

0 commit comments

Comments
 (0)