Skip to content

Commit 1fcef01

Browse files
committed
add memo trait
This commit adds a `Memo` trait and a first draft of an implementation of the `Memo` trait via the backing ORM-mapped database.
1 parent b96ee5a commit 1fcef01

18 files changed

+806
-31
lines changed

optd-cost-model/src/cost/agg.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

optd-persistent/src/cost_model/orm.rs

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -238,10 +238,13 @@ impl CostModelStorageLayer for BackendManager {
238238
match res {
239239
Ok(insert_res) => insert_res.last_insert_id,
240240
Err(_) => {
241-
return Err(BackendError::BackendError(format!(
242-
"failed to insert statistic {:?} into statistic table",
243-
stat
244-
)))
241+
return Err(BackendError::CostModel(
242+
format!(
243+
"failed to insert statistic {:?} into statistic table",
244+
stat
245+
)
246+
.into(),
247+
))
245248
}
246249
}
247250
}
@@ -450,10 +453,13 @@ impl CostModelStorageLayer for BackendManager {
450453
.collect::<Vec<_>>();
451454

452455
if attr_ids.len() != attr_base_indices.len() {
453-
return Err(BackendError::BackendError(format!(
454-
"Not all attributes found for table_id {} and base indices {:?}",
455-
table_id, attr_base_indices
456-
)));
456+
return Err(BackendError::CostModel(
457+
format!(
458+
"Not all attributes found for table_id {} and base indices {:?}",
459+
table_id, attr_base_indices
460+
)
461+
.into(),
462+
));
457463
}
458464

459465
self.get_stats_for_attr(attr_ids, stat_type, epoch_id).await
@@ -505,10 +511,13 @@ impl CostModelStorageLayer for BackendManager {
505511
.one(&self.db)
506512
.await?;
507513
if expr_exists.is_none() {
508-
return Err(BackendError::BackendError(format!(
509-
"physical expression id {} not found when storing cost",
510-
physical_expression_id
511-
)));
514+
return Err(BackendError::CostModel(
515+
format!(
516+
"physical expression id {} not found when storing cost",
517+
physical_expression_id
518+
)
519+
.into(),
520+
));
512521
}
513522

514523
// Check if epoch_id exists in Event table
@@ -518,10 +527,9 @@ impl CostModelStorageLayer for BackendManager {
518527
.await
519528
.unwrap();
520529
if epoch_exists.is_none() {
521-
return Err(BackendError::BackendError(format!(
522-
"epoch id {} not found when storing cost",
523-
epoch_id
524-
)));
530+
return Err(BackendError::CostModel(
531+
format!("epoch id {} not found when storing cost", epoch_id).into(),
532+
));
525533
}
526534

527535
let new_cost = plan_cost::ActiveModel {

optd-persistent/src/lib.rs

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@ mod migrator;
1313
pub mod cost_model;
1414
pub use cost_model::interface::CostModelStorageLayer;
1515

16+
mod memo;
17+
pub use memo::interface::Memo;
18+
1619
/// The filename of the SQLite database for migration.
1720
pub const DATABASE_FILENAME: &str = "sqlite.db";
1821
/// The URL of the SQLite database for migration.
@@ -39,17 +42,48 @@ fn get_sqlite_url(file: &str) -> String {
3942
format!("sqlite:{}?mode=rwc", file)
4043
}
4144

42-
pub type StorageResult<T> = Result<T, BackendError>;
45+
#[derive(Debug)]
46+
pub enum CostModelError {
47+
// TODO: Add more error types
48+
UnknownStatisticType,
49+
VersionedStatisticNotFound,
50+
CustomError(String),
51+
}
4352

53+
// TODO: convert this to `thiserror`.
54+
#[derive(Debug)]
55+
/// The different kinds of errors that might occur while running operations on a memo table.
56+
pub enum MemoError {
57+
UnknownGroup,
58+
UnknownLogicalExpression,
59+
UnknownPhysicalExpression,
60+
InvalidExpression,
61+
}
62+
63+
// TODO: convert this to `thiserror`.
4464
#[derive(Debug)]
4565
pub enum BackendError {
66+
Memo(MemoError),
4667
DatabaseError(DbErr),
68+
CostModel(CostModelError),
4769
BackendError(String),
4870
}
4971

50-
impl From<String> for BackendError {
72+
impl From<String> for CostModelError {
5173
fn from(value: String) -> Self {
52-
BackendError::BackendError(value)
74+
CostModelError::CustomError(value)
75+
}
76+
}
77+
78+
impl From<CostModelError> for BackendError {
79+
fn from(value: CostModelError) -> Self {
80+
BackendError::CostModel(value)
81+
}
82+
}
83+
84+
impl From<MemoError> for BackendError {
85+
fn from(value: MemoError) -> Self {
86+
BackendError::Memo(value)
5387
}
5488
}
5589

@@ -59,6 +93,9 @@ impl From<DbErr> for BackendError {
5993
}
6094
}
6195

96+
/// A type alias for a result with [`BackendError`] as the error type.
97+
pub type StorageResult<T> = Result<T, BackendError>;
98+
6299
pub struct BackendManager {
63100
db: DatabaseConnection,
64101
}

optd-persistent/src/main.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@ use optd_persistent::DATABASE_URL;
1717

1818
#[tokio::main]
1919
async fn main() {
20+
basic_demo().await;
21+
memo_demo().await;
22+
}
23+
24+
async fn memo_demo() {
25+
let _db = Database::connect(DATABASE_URL).await.unwrap();
26+
27+
todo!()
28+
}
29+
30+
async fn basic_demo() {
2031
let db = Database::connect(DATABASE_URL).await.unwrap();
2132

2233
// Create a new `CascadesGroup`.
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
use crate::entities::*;
2+
use std::hash::{DefaultHasher, Hash, Hasher};
3+
4+
/// All of the different types of fixed logical operators.
5+
///
6+
/// Note that there could be more operators that the memo table must support that are not enumerated
7+
/// in this enum, as there can be up to `2^16` different types of operators.
8+
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
9+
#[non_exhaustive]
10+
#[repr(i16)]
11+
pub enum LogicalOperator {
12+
Scan,
13+
Join,
14+
}
15+
16+
/// All of the different types of fixed physical operators.
17+
///
18+
/// Note that there could be more operators that the memo table must support that are not enumerated
19+
/// in this enum, as there can be up to `2^16` different types of operators.
20+
#[derive(Debug, Clone, Copy, Hash, Eq, PartialEq)]
21+
#[non_exhaustive]
22+
#[repr(i16)]
23+
pub enum PhysicalOperator {
24+
TableScan,
25+
IndexScan,
26+
NestedLoopJoin,
27+
HashJoin,
28+
}
29+
30+
/// A function that generates a fingerprint used to efficiently check if two
31+
/// expressions are equivalent.
32+
///
33+
/// TODO make efficient and stable; `DefaultHasher` output is not guaranteed across Rust releases.
34+
fn fingerprint(variant_tag: i16, data: &serde_json::Value) -> i64 {
35+
let mut hasher = DefaultHasher::new();
36+
37+
variant_tag.hash(&mut hasher);
38+
data.hash(&mut hasher);
39+
40+
hasher.finish() as i64
41+
}
42+
43+
impl logical_expression::Model {
44+
/// Creates a new logical expression with an unset `id` and `group_id`.
45+
pub fn new(variant_tag: LogicalOperator, data: serde_json::Value) -> Self {
46+
let tag = variant_tag as i16;
47+
let fingerprint = fingerprint(tag, &data);
48+
49+
Self {
50+
id: 0,
51+
group_id: 0,
52+
fingerprint,
53+
variant_tag: tag,
54+
data,
55+
}
56+
}
57+
}
58+
59+
impl physical_expression::Model {
60+
/// Creates a new physical expression with an unset `id` and `group_id`.
61+
pub fn new(variant_tag: PhysicalOperator, data: serde_json::Value) -> Self {
62+
let tag = variant_tag as i16;
63+
let fingerprint = fingerprint(tag, &data);
64+
65+
Self {
66+
id: 0,
67+
group_id: 0,
68+
fingerprint,
69+
variant_tag: tag,
70+
data,
71+
}
72+
}
73+
}

optd-persistent/src/memo/interface.rs

Lines changed: 141 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,141 @@
1+
use crate::StorageResult;
2+
3+
/// A trait representing an implementation of a memoization table.
4+
///
5+
/// Note that we use [`trait_variant`] here in order to add `Send` bounds on every method.
6+
/// See this [blog post](
7+
/// https://blog.rust-lang.org/2023/12/21/async-fn-rpit-in-traits.html#async-fn-in-public-traits)
8+
/// for more information.
9+
///
10+
/// TODO Figure out for each when to get the ID of a record or the entire record itself.
11+
#[trait_variant::make(Send)]
12+
pub trait Memo {
13+
/// A type representing a group in the Cascades framework.
14+
type Group;
15+
/// A type representing a unique identifier for a group.
16+
type GroupId;
17+
/// A type representing a logical expression.
18+
type LogicalExpression;
19+
/// A type representing a unique identifier for a logical expression.
20+
type LogicalExpressionId;
21+
/// A type representing a physical expression.
22+
type PhysicalExpression;
23+
/// A type representing a unique identifier for a physical expression.
24+
type PhysicalExpressionId;
25+
26+
/// Retrieves a [`Self::Group`] given a [`Self::GroupId`].
27+
///
28+
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
29+
async fn get_group(&self, group_id: Self::GroupId) -> StorageResult<Self::Group>;
30+
31+
/// Retrieves all groups that are stored in the memo table.
32+
async fn get_all_groups(&self) -> StorageResult<Vec<Self::Group>>;
33+
34+
/// Retrieves a [`Self::LogicalExpression`] given a [`Self::LogicalExpressionId`].
35+
///
36+
/// If the logical expression does not exist, returns a [`MemoError::UnknownLogicalExpression`]
37+
/// error.
38+
async fn get_logical_expression(
39+
&self,
40+
logical_expression_id: Self::LogicalExpressionId,
41+
) -> StorageResult<Self::LogicalExpression>;
42+
43+
/// Retrieves a [`Self::PhysicalExpression`] given a [`Self::PhysicalExpressionId`].
44+
///
45+
/// If the physical expression does not exist, returns a
46+
/// [`MemoError::UnknownPhysicalExpression`] error.
47+
async fn get_physical_expression(
48+
&self,
49+
physical_expression_id: Self::PhysicalExpressionId,
50+
) -> StorageResult<Self::PhysicalExpression>;
51+
52+
/// Retrieves the parent group ID of a logical expression given its expression ID.
53+
///
54+
/// If the logical expression does not exist, returns a [`MemoError::UnknownLogicalExpression`]
55+
/// error.
56+
async fn get_group_from_logical_expression(
57+
&self,
58+
logical_expression_id: Self::LogicalExpressionId,
59+
) -> StorageResult<Self::GroupId>;
60+
61+
/// Retrieves the parent group ID of a physical expression given its expression ID.
62+
///
63+
/// If the physical expression does not exist, returns a
64+
/// [`MemoError::UnknownPhysicalExpression`] error.
65+
async fn get_group_from_physical_expression(
66+
&self,
67+
physical_expression_id: Self::PhysicalExpressionId,
68+
) -> StorageResult<Self::GroupId>;
69+
70+
/// Retrieves all of the logical expression "children" of a group.
71+
///
72+
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
73+
async fn get_group_logical_expressions(
74+
&self,
75+
group_id: Self::GroupId,
76+
) -> StorageResult<Vec<Self::LogicalExpression>>;
77+
78+
/// Retrieves all of the physical expression "children" of a group.
79+
///
80+
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
81+
async fn get_group_physical_expressions(
82+
&self,
83+
group_id: Self::GroupId,
84+
) -> StorageResult<Vec<Self::PhysicalExpression>>;
85+
86+
/// Retrieves the best physical query plan (winner) for a given group.
87+
///
88+
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
89+
async fn get_winner(
90+
&self,
91+
group_id: Self::GroupId,
92+
) -> StorageResult<Option<Self::PhysicalExpressionId>>;
93+
94+
/// Updates / replaces a group's best physical plan (winner). Optionally returns the previous
95+
/// winner's physical expression ID.
96+
///
97+
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
98+
async fn update_group_winner(
99+
&self,
100+
group_id: Self::GroupId,
101+
physical_expression_id: Self::PhysicalExpressionId,
102+
) -> StorageResult<Option<Self::PhysicalExpressionId>>;
103+
104+
/// Adds a logical expression to an existing group via its [`Self::GroupId`].
105+
///
106+
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
107+
async fn add_logical_expression_to_group(
108+
&self,
109+
group_id: Self::GroupId,
110+
logical_expression: Self::LogicalExpression,
111+
children: Vec<Self::LogicalExpressionId>,
112+
) -> StorageResult<()>;
113+
114+
/// Adds a physical expression to an existing group via its [`Self::GroupId`].
115+
///
116+
/// If the group does not exist, returns a [`MemoError::UnknownGroup`] error.
117+
async fn add_physical_expression_to_group(
118+
&self,
119+
group_id: Self::GroupId,
120+
physical_expression: Self::PhysicalExpression,
121+
children: Vec<Self::LogicalExpressionId>,
122+
) -> StorageResult<()>;
123+
124+
/// Adds a new logical expression into the memo table, creating a new group if the expression
125+
/// does not already exist.
126+
///
127+
/// The [`Self::LogicalExpression`] type should have some sort of mechanism for checking if
128+
/// the expression has been seen before, and if it has already been created, then the parent
129+
/// group ID should also be retrievable.
130+
///
131+
/// If the expression already exists, then this function will return the [`Self::GroupId`] of
132+
/// the parent group and the corresponding (already existing) [`Self::LogicalExpressionId`].
133+
///
134+
/// If the expression does not exist, this function will create a new group and a new
135+
/// expression, returning brand new IDs for both.
136+
async fn add_logical_expression(
137+
&self,
138+
expression: Self::LogicalExpression,
139+
children: Vec<Self::LogicalExpressionId>,
140+
) -> StorageResult<(Self::GroupId, Self::LogicalExpressionId)>;
141+
}

optd-persistent/src/memo/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
mod expression;
2+
3+
pub mod interface;
4+
pub mod orm;

0 commit comments

Comments
 (0)