Skip to content

Commit

Permalink
fix: in PQ post processing, revert sort columns just before propagati…
Browse files Browse the repository at this point in the history
…ng down pipeline (#5098)
  • Loading branch information
kgutwin authored Jan 23, 2025
1 parent a54b815 commit 1c404a8
Show file tree
Hide file tree
Showing 10 changed files with 761 additions and 35 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ repos:
hooks:
- id: actionlint
- repo: https://github.com/tcort/markdown-link-check
rev: v3.13.6
rev: v3.12.2
hooks:
- id: markdown-link-check
name: markdown-link-check-local
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

- Sort steps in sub-pipelines no longer cause a column lookup error
(@lukapeschke, #5066)
- Dereferencing of sort columns when rendering SQL is now done in the context of
the main pipeline (@kgutwin, #5098)

**Documentation**:

Expand Down
31 changes: 30 additions & 1 deletion prqlc/prqlc/src/sql/pq/anchor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -643,15 +643,44 @@ impl<'a> CidRedirector<'a> {
ctx: &'a mut AnchorContext,
) -> Vec<ColumnSort<CId>> {
let cid_redirects = ctx.relation_instances[riid].cid_redirects.clone();
log::debug!("redirect sorts {sorts:?} {riid:?} cid_redirects {cid_redirects:?}");
let mut redirector = CidRedirector { ctx, cid_redirects };

fold_column_sorts(&mut redirector, sorts).unwrap()
}

/// Reverts sort columns back to their original, pre-split columns.
///
/// Each sort's column is looked up in `ctx.column_decls`. When the declaration is a
/// `ColumnDecl::RelationColumn`, the `cid_redirects` of that relation instance are searched
/// for an entry whose target equals the declared column id; if one is found, the sort is
/// re-pointed at that entry's source and keeps its direction. Every other sort is returned
/// unchanged, and the order of `sorts` is preserved.
///
/// # Panics
///
/// Panics if a sort column has no entry in `ctx.column_decls`.
pub fn revert_sorts(
    sorts: Vec<ColumnSort<CId>>,
    ctx: &'a mut AnchorContext,
) -> Vec<ColumnSort<CId>> {
    let mut reverted = Vec::with_capacity(sorts.len());
    for sort in sorts {
        // Only relation-column declarations are checked against the redirect table.
        let original = match ctx.column_decls.get(&sort.column).unwrap() {
            ColumnDecl::RelationColumn(riid, cid, _) => ctx.relation_instances[riid]
                .cid_redirects
                .iter()
                .find(|(_, target)| *target == cid)
                .map(|(source, target)| {
                    log::debug!("reverting {target:?} back to {source:?}");
                    *source
                }),
            _ => None,
        };
        reverted.push(match original {
            Some(column) => ColumnSort {
                direction: sort.direction,
                column,
            },
            None => sort,
        });
    }
    reverted
}
}

impl RqFold for CidRedirector<'_> {
fn fold_cid(&mut self, cid: CId) -> Result<CId> {
Ok(self.cid_redirects.get(&cid).cloned().unwrap_or(cid))
let v = self.cid_redirects.get(&cid).cloned().unwrap_or(cid);
log::debug!("mapping {cid:?} via {0:?} to {v:?}", self.cid_redirects);
Ok(v)
}

fn fold_transform(&mut self, transform: Transform) -> Result<Transform> {
Expand Down
68 changes: 35 additions & 33 deletions prqlc/prqlc/src/sql/pq/postprocess.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ fn infer_sorts(query: SqlQuery, ctx: &mut Context) -> SqlQuery {
let mut s = SortingInference {
last_sorting: Vec::new(),
ctes_sorting: HashMap::new(),
main_relation: false,
ctx,
};

Expand All @@ -36,12 +37,12 @@ fn infer_sorts(query: SqlQuery, ctx: &mut Context) -> SqlQuery {
// State carried while folding a `SqlQuery` to infer the sorting of each pipeline.
struct SortingInference<'a> {
// Sorting remembered from the most recently folded pipeline; drained after each
// CTE is folded and again when the final Sort is pushed onto the main pipeline.
last_sorting: Sorting,
// Sorting recorded per CTE, keyed by the CTE's tid, and consulted when a later
// pipeline references that tid.
ctes_sorting: HashMap<TId, CteSorting>,
// Starts out false and is set to true just before the main relation is folded;
// while false, the CTE-only handling of SELECT and sort columns applies.
main_relation: bool,
// Shared compilation context; its `anchor` is used when reverting sort columns.
ctx: &'a mut Context,
}

struct CteSorting {
sorting: Sorting,
has_been_used: bool,
}

impl RqFold for SortingInference<'_> {}
Expand All @@ -50,51 +51,29 @@ impl PqFold for SortingInference<'_> {
fn fold_sql_query(&mut self, query: SqlQuery) -> Result<SqlQuery> {
let mut ctes = Vec::with_capacity(query.ctes.len());
for cte in query.ctes {
log::debug!("infer_sorts: {0:?}", cte.tid);
let cte = self.fold_cte(cte)?;

// store sorting to be used later in From references
let sorting = self.last_sorting.drain(..).collect();
let sorting = CteSorting {
sorting,
has_been_used: false,
};
log::debug!("--- sorting {sorting:?}");
let sorting = CteSorting { sorting };
self.ctes_sorting.insert(cte.tid, sorting);

ctes.push(cte);
}

// fold main_relation using a made-up tid
// fold main_relation
log::debug!("infer_sorts: main relation");
self.main_relation = true;
let mut main_relation = self.fold_sql_relation(query.main_relation)?;
log::debug!("--== last_sorting {0:?}", self.last_sorting);

// push a sort at the back of the main pipeline
if let SqlRelation::AtomicPipeline(pipeline) = &mut main_relation {
pipeline.push(SqlTransform::Sort(self.last_sorting.drain(..).collect()));
}

// make sure that all CTEs whose sorting was used actually SELECT it
for cte in &mut ctes {
let sorting = self.ctes_sorting.get(&cte.tid).unwrap();
if !sorting.has_been_used {
continue;
}

let CteKind::Normal(sql_relation) = &mut cte.kind else {
continue;
};
let Some(pipeline) = sql_relation.as_atomic_pipeline_mut() else {
continue;
};
let select = pipeline.iter_mut().find_map(|x| x.as_select_mut()).unwrap();

for column_sort in &sorting.sorting {
let cid = column_sort.column;
let is_selected = select.contains(&cid);
if !is_selected {
select.push(cid);
}
}
}

Ok(SqlQuery {
ctes,
main_relation,
Expand All @@ -116,6 +95,7 @@ impl PqMapper<RelationExpr, RelationExpr, (), ()> for SortingInference<'_> {
transforms: Vec<SqlTransform<RelationExpr, ()>>,
) -> Result<Vec<SqlTransform<RelationExpr, ()>>> {
let mut sorting = Vec::new();
let mut has_sort_transform = false;

let mut result = Vec::with_capacity(transforms.len() + 1);

Expand All @@ -126,7 +106,6 @@ impl PqMapper<RelationExpr, RelationExpr, (), ()> for SortingInference<'_> {
RelationExprKind::Ref(ref tid) => {
// infer sorting from referenced pipeline
if let Some(cte_sorting) = self.ctes_sorting.get_mut(tid) {
cte_sorting.has_been_used = true;
sorting.clone_from(&cte_sorting.sorting);
} else {
sorting = Vec::new();
Expand All @@ -147,8 +126,9 @@ impl PqMapper<RelationExpr, RelationExpr, (), ()> for SortingInference<'_> {
}

// just store sorting and don't emit Sort
SqlTransform::Sort(s) => {
sorting.clone_from(&s);
SqlTransform::Sort(expr) => {
sorting.clone_from(&expr);
has_sort_transform = true;
continue;
}

Expand All @@ -166,6 +146,28 @@ impl PqMapper<RelationExpr, RelationExpr, (), ()> for SortingInference<'_> {
result.push(transform)
}

if !self.main_relation {
// if this is a CTE, make sure that its SELECT includes the
// columns from the sort
let select = result.iter_mut().find_map(|x| x.as_select_mut()).unwrap();
for column_sort in &sorting {
let cid = column_sort.column;
let is_selected = select.contains(&cid);
if !is_selected {
log::debug!("adding {cid:?} to {select:?}");
select.push(cid);
}
}

if has_sort_transform {
// now revert the sort columns so that the output
// sorting reflects the input column cids, needed to
// ensure proper column reference lookup in the final
// steps
sorting = CidRedirector::revert_sorts(sorting, &mut self.ctx.anchor);
}
}

// remember sorting for this pipeline
self.last_sorting = sorting;

Expand Down
5 changes: 5 additions & 0 deletions prqlc/prqlc/tests/integration/queries/sort_2.prql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from albums
select { AA=album_id, artist_id }
sort AA
filter AA >= 25
join artists (==artist_id)
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
---
source: prqlc/prqlc/tests/integration/queries.rs
expression: "from albums\nselect { AA=album_id, artist_id }\nsort AA\nfilter AA >= 25\njoin artists (==artist_id)\n"
input_file: prqlc/prqlc/tests/integration/queries/sort_2.prql
---
WITH table_1 AS (
SELECT
album_id AS "AA",
artist_id
FROM
albums
),
table_0 AS (
SELECT
"AA",
artist_id
FROM
table_1
WHERE
"AA" >= 25
)
SELECT
table_0."AA",
table_0.artist_id,
artists.*
FROM
table_0
JOIN artists ON table_0.artist_id = artists.artist_id
ORDER BY
table_0."AA"
Loading

0 comments on commit 1c404a8

Please sign in to comment.