Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(torii-sqlite): support enum upgrade of variants #2930

Merged
merged 16 commits into from
Feb 3, 2025
53 changes: 37 additions & 16 deletions crates/dojo/types/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,13 +242,19 @@ impl Ty {
let diff_children: Vec<Member> = s1
.children
.iter()
.filter(|m1| {
s2.children
.iter()
.find(|m2| m2.name == m1.name)
.map_or(true, |m2| *m1 != m2)
.filter_map(|m1| {
if let Some(m2) = s2.children.iter().find(|m2| m2.name == m1.name) {
// Member exists in both - check if types are different
m1.ty.diff(&m2.ty).map(|diff_ty| Member {
name: m1.name.clone(),
ty: diff_ty,
key: m1.key,
})
} else {
// Member doesn't exist in s2
Some(m1.clone())
}
Comment on lines +245 to +256
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion

Ohayo sensei! Consider diffing changes to the key property.

Currently, this logic only compares and diffs member types. If a member changes from non-key to key (or vice versa), it won’t be reflected. This might lead to schema discrepancies where primary keys are misaligned.

Here's a small illustrative snippet to demonstrate checking for key changes:

 if let Some(m2) = s2.children.iter().find(|m2| m2.name == m1.name) {
     // Member exists in both - check for differences
     // ...
-    m1.ty.diff(&m2.ty).map(|diff_ty| Member {
+    let diff_ty = m1.ty.diff(&m2.ty);
+    let diff_key = (m1.key != m2.key);
+    if diff_ty.is_some() || diff_key {
         // Return a new member capturing both potential type & key changes
         Some(Member {
             name: m1.name.clone(),
             ty: diff_ty.unwrap_or_else(|| m1.ty.clone()),
             key: m1.key,
         })
     } else {
         None
     }
 } else {
     Some(m1.clone())
 }

Committable suggestion skipped: line range outside the PR's diff.

})
.cloned()
.collect();

if diff_children.is_empty() {
Expand All @@ -262,10 +268,17 @@ impl Ty {
let diff_options: Vec<EnumOption> = e1
.options
.iter()
.filter(|o1| {
e2.options.iter().find(|o2| o2.name == o1.name).map_or(true, |o2| *o1 != o2)
.filter_map(|o1| {
if let Some(o2) = e2.options.iter().find(|o2| o2.name == o1.name) {
// Option exists in both - check if types are different
o1.ty
.diff(&o2.ty)
.map(|diff_ty| EnumOption { name: o1.name.clone(), ty: diff_ty })
} else {
// Option doesn't exist in e2
Some(o1.clone())
}
})
.cloned()
.collect();

if diff_options.is_empty() {
Expand All @@ -278,18 +291,26 @@ impl Ty {
}))
}
}
(Ty::Array(a1), Ty::Array(a2)) => {
if a1 == a2 {
None
(Ty::Tuple(t1), Ty::Tuple(t2)) => {
if t1.len() != t2.len() {
Some(Ty::Tuple(
t1.iter()
.filter_map(|ty| if !t2.contains(ty) { Some(ty.clone()) } else { None })
.collect(),
))
} else {
Some(Ty::Array(a1.clone()))
// Compare each tuple element recursively
let diff_elements: Vec<Ty> =
t1.iter().zip(t2.iter()).filter_map(|(ty1, ty2)| ty1.diff(ty2)).collect();

if diff_elements.is_empty() { None } else { Some(Ty::Tuple(diff_elements)) }
}
}
(Ty::Tuple(t1), Ty::Tuple(t2)) => {
if t1 == t2 {
(Ty::Array(a1), Ty::Array(a2)) => {
if a1 == a2 {
None
} else {
Some(Ty::Tuple(t1.clone()))
Some(Ty::Array(a1.clone()))
}
}
(Ty::ByteArray(b1), Ty::ByteArray(b2)) => {
Expand Down
101 changes: 84 additions & 17 deletions crates/torii/sqlite/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -872,20 +872,44 @@ fn add_columns_recursive(
));
};

let modify_column =
|alter_table_queries: &mut Vec<String>, name: &str, sql_type: &str, sql_value: &str| {
// SQLite doesn't support ALTER COLUMN directly, so we need to:
// 1. Create a temporary table to store the current values
// 2. Drop the old column & index
// 3. Create new column with new type/constraint
// 4. Copy values back & create new index
alter_table_queries.push(format!(
"CREATE TEMPORARY TABLE tmp_values_{name} AS SELECT internal_id, [{name}] FROM \
[{table_id}]"
));
alter_table_queries.push(format!("DROP INDEX IF EXISTS [idx_{table_id}_{name}]"));
alter_table_queries.push(format!("ALTER TABLE [{table_id}] DROP COLUMN [{name}]"));
alter_table_queries
.push(format!("ALTER TABLE [{table_id}] ADD COLUMN [{name}] {sql_type}"));
alter_table_queries.push(format!("UPDATE [{table_id}] SET [{name}] = {sql_value}"));
alter_table_queries.push(format!("DROP TABLE tmp_values_{name}"));
alter_table_queries.push(format!(
"CREATE INDEX IF NOT EXISTS [idx_{table_id}_{name}] ON [{table_id}] ([{name}]);"
));
};
Comment on lines +875 to +895
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue

Ohayo sensei! Fix potential data loss in column modification.

The modify_column closure creates a temporary table but doesn't copy the values back to the new column. The UPDATE statement on line 865 sets a static sql_value instead of restoring the original values.

Apply this diff to fix the data loss:

             alter_table_queries.push(format!(
                 "CREATE TEMPORARY TABLE tmp_values_{name} AS SELECT internal_id, [{name}] FROM \
                  [{table_id}]"
             ));
             alter_table_queries.push(format!("DROP INDEX IF EXISTS [idx_{table_id}_{name}]"));
             alter_table_queries.push(format!("ALTER TABLE [{table_id}] DROP COLUMN [{name}]"));
             alter_table_queries
                 .push(format!("ALTER TABLE [{table_id}] ADD COLUMN [{name}] {sql_type}"));
-            alter_table_queries.push(format!("UPDATE [{table_id}] SET [{name}] = {sql_value}"));
+            alter_table_queries.push(format!(
+                "UPDATE [{table_id}] SET [{name}] = (SELECT [{name}] FROM tmp_values_{name} \
+                 WHERE tmp_values_{name}.internal_id = [{table_id}].internal_id)"
+            ));
             alter_table_queries.push(format!("DROP TABLE tmp_values_{name}"));
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
let modify_column =
|alter_table_queries: &mut Vec<String>, name: &str, sql_type: &str, sql_value: &str| {
// SQLite doesn't support ALTER COLUMN directly, so we need to:
// 1. Create a temporary table to store the current values
// 2. Drop the old column & index
// 3. Create new column with new type/constraint
// 4. Copy values back & create new index
alter_table_queries.push(format!(
"CREATE TEMPORARY TABLE tmp_values_{name} AS SELECT internal_id, [{name}] FROM \
[{table_id}]"
));
alter_table_queries.push(format!("DROP INDEX IF EXISTS [idx_{table_id}_{name}]"));
alter_table_queries.push(format!("ALTER TABLE [{table_id}] DROP COLUMN [{name}]"));
alter_table_queries
.push(format!("ALTER TABLE [{table_id}] ADD COLUMN [{name}] {sql_type}"));
alter_table_queries.push(format!("UPDATE [{table_id}] SET [{name}] = {sql_value}"));
alter_table_queries.push(format!("DROP TABLE tmp_values_{name}"));
alter_table_queries.push(format!(
"CREATE INDEX IF NOT EXISTS [idx_{table_id}_{name}] ON [{table_id}] ([{name}]);"
));
};
let modify_column =
|alter_table_queries: &mut Vec<String>, name: &str, sql_type: &str, sql_value: &str| {
// SQLite doesn't support ALTER COLUMN directly, so we need to:
// 1. Create a temporary table to store the current values
// 2. Drop the old column & index
// 3. Create new column with new type/constraint
// 4. Copy values back & create new index
alter_table_queries.push(format!(
"CREATE TEMPORARY TABLE tmp_values_{name} AS SELECT internal_id, [{name}] FROM \
[{table_id}]"
));
alter_table_queries.push(format!("DROP INDEX IF EXISTS [idx_{table_id}_{name}]"));
alter_table_queries.push(format!("ALTER TABLE [{table_id}] DROP COLUMN [{name}]"));
alter_table_queries
.push(format!("ALTER TABLE [{table_id}] ADD COLUMN [{name}] {sql_type}"));
alter_table_queries.push(format!(
"UPDATE [{table_id}] SET [{name}] = (SELECT [{name}] FROM tmp_values_{name} \
WHERE tmp_values_{name}.internal_id = [{table_id}].internal_id)"
));
alter_table_queries.push(format!("DROP TABLE tmp_values_{name}"));
alter_table_queries.push(format!(
"CREATE INDEX IF NOT EXISTS [idx_{table_id}_{name}] ON [{table_id}] ([{name}]);"
));
};


match ty {
Ty::Struct(s) => {
let struct_diff =
if let Some(upgrade_diff) = upgrade_diff { upgrade_diff.as_struct() } else { None };

for member in &s.children {
if let Some(upgrade_diff) = upgrade_diff {
if !upgrade_diff
.as_struct()
.unwrap()
.children
.iter()
.any(|m| m.name == member.name)
{
let member_diff = if let Some(diff) = struct_diff {
if let Some(m) = diff.children.iter().find(|m| m.name == member.name) {
Some(&m.ty)
} else {
// If the member is not in the diff, skip it
continue;
}
}
} else {
None
};

let mut new_path = path.to_vec();
new_path.push(member.name.clone());
Expand All @@ -897,23 +921,38 @@ fn add_columns_recursive(
alter_table_queries,
indices,
table_id,
None,
member_diff,
)?;
}
}
Ty::Tuple(tuple) => {
for (idx, member) in tuple.iter().enumerate() {
let elements_to_process = if let Some(diff) = upgrade_diff.and_then(|d| d.as_tuple()) {
// Only process elements from the diff
diff.iter()
.filter_map(|m| {
tuple.iter().position(|member| member == m).map(|idx| (idx, m, Some(m)))
})
.collect()
} else {
// Process all elements
tuple
.iter()
.enumerate()
.map(|(idx, member)| (idx, member, None))
.collect::<Vec<_>>()
};

for (idx, member, member_diff) in elements_to_process {
let mut new_path = path.to_vec();
new_path.push(idx.to_string());

add_columns_recursive(
&new_path,
member,
columns,
alter_table_queries,
indices,
table_id,
None,
member_diff,
)?;
}
}
Expand All @@ -924,17 +963,45 @@ fn add_columns_recursive(
add_column(&column_name, "TEXT");
}
Ty::Enum(e) => {
// The variant of the enum
let enum_diff =
if let Some(upgrade_diff) = upgrade_diff { upgrade_diff.as_enum() } else { None };

let column_name =
if column_prefix.is_empty() { "option".to_string() } else { column_prefix };

let all_options =
e.options.iter().map(|c| format!("'{}'", c.name)).collect::<Vec<_>>().join(", ");

let sql_type = format!("TEXT CHECK([{column_name}] IN ({all_options}))");
add_column(&column_name, &sql_type);
let sql_type = format!(
"TEXT CONSTRAINT [{column_name}_check] CHECK([{column_name}] IN ({all_options}))"
);
if enum_diff.is_some_and(|diff| diff != e) {
// For upgrades, modify the existing option column to add the new options to the
// CHECK constraint We need to drop the old column and create a new
// one with the new CHECK constraint
modify_column(
alter_table_queries,
&column_name,
&sql_type,
&format!("[{column_name}]"),
);
} else {
// For new tables, create the column directly
add_column(&column_name, &sql_type);
}

for child in &e.options {
// If we have a diff, only process new variants that aren't in the original enum
let variant_diff = if let Some(diff) = enum_diff {
if let Some(v) = diff.options.iter().find(|v| v.name == child.name) {
Some(&v.ty)
} else {
continue;
}
} else {
None
};

if let Ty::Tuple(tuple) = &child.ty {
if tuple.is_empty() {
continue;
Expand All @@ -951,7 +1018,7 @@ fn add_columns_recursive(
alter_table_queries,
indices,
table_id,
None,
variant_diff,
)?;
}
}
Expand Down
Loading