From 9863a97c77906427a4627e9632469e66740fbc6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Ma=C4=87kowski?= Date: Mon, 2 Dec 2024 07:17:25 +0100 Subject: [PATCH 1/9] feat(orm): add foreign key support --- Cargo.lock | 2 + flareon-cli/Cargo.toml | 2 +- flareon-cli/src/migration_generator.rs | 558 +++--------------- flareon-cli/tests/migration_generator.rs | 65 +- .../tests/migration_generator/create_model.rs | 3 +- flareon-codegen/Cargo.toml | 5 + flareon-codegen/src/expr.rs | 280 +++++++-- flareon-codegen/src/lib.rs | 16 + flareon-codegen/src/model.rs | 203 ++++++- flareon-codegen/src/symbol_resolver.rs | 479 +++++++++++++++ flareon-macros/src/model.rs | 61 +- flareon-macros/src/query.rs | 38 +- flareon-macros/tests/compile_tests.rs | 3 + .../tests/ui/attr_model_multiple_pks.rs | 11 + .../tests/ui/attr_model_multiple_pks.stderr | 5 + flareon-macros/tests/ui/attr_model_no_pk.rs | 8 + .../tests/ui/attr_model_no_pk.stderr | 5 + .../ui/func_query_method_call_on_db_field.rs | 14 + .../func_query_method_call_on_db_field.stderr | 5 + flareon/Cargo.toml | 4 + flareon/src/auth.rs | 3 +- flareon/src/db.rs | 299 +++++++--- flareon/src/db/fields.rs | 162 ++++- flareon/src/db/impl_mysql.rs | 8 + flareon/src/db/impl_postgres.rs | 8 + flareon/src/db/impl_sqlite.rs | 19 +- flareon/src/db/migrations.rs | 52 +- flareon/src/db/query.rs | 68 ++- flareon/src/db/relations.rs | 109 ++++ flareon/src/db/sea_query_db.rs | 5 +- flareon/tests/db.rs | 178 +++++- 31 files changed, 1969 insertions(+), 709 deletions(-) create mode 100644 flareon-codegen/src/symbol_resolver.rs create mode 100644 flareon-macros/tests/ui/attr_model_multiple_pks.rs create mode 100644 flareon-macros/tests/ui/attr_model_multiple_pks.stderr create mode 100644 flareon-macros/tests/ui/attr_model_no_pk.rs create mode 100644 flareon-macros/tests/ui/attr_model_no_pk.stderr create mode 100644 flareon-macros/tests/ui/func_query_method_call_on_db_field.rs create mode 100644 flareon-macros/tests/ui/func_query_method_call_on_db_field.stderr create mode 100644 flareon/src/db/relations.rs diff --git a/Cargo.lock b/Cargo.lock index 3565c940..3037f8f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -808,6 +808,7 @@ dependencies = [ "chrono", "derive_builder", "derive_more", + "env_logger", "fake", "flareon_macros", "form_urlencoded", @@ -870,6 +871,7 @@ version = "0.1.0" dependencies = [ "convert_case", "darling", + "log", "proc-macro2", "quote", "syn", diff --git a/flareon-cli/Cargo.toml b/flareon-cli/Cargo.toml index e25a8d8a..677f554d 100644 --- a/flareon-cli/Cargo.toml +++ b/flareon-cli/Cargo.toml @@ -16,7 +16,7 @@ clap = { workspace = true, features = ["derive", "env"] } clap-verbosity-flag = { workspace = true, features = ["tracing"] } darling.workspace = true flareon.workspace = true -flareon_codegen.workspace = true +flareon_codegen = { workspace = true, features = ["symbol-resolver"] } glob.workspace = true prettyplease.workspace = true proc-macro2 = { workspace = true, features = ["span-locations"] } diff --git a/flareon-cli/src/migration_generator.rs b/flareon-cli/src/migration_generator.rs index d799cb15..01f0bda6 100644 --- a/flareon-cli/src/migration_generator.rs +++ b/flareon-cli/src/migration_generator.rs @@ -10,6 +10,7 @@ use cargo_toml::Manifest; use darling::FromMeta; use flareon::db::migrations::{DynMigration, MigrationEngine}; use flareon_codegen::model::{Field, Model, ModelArgs, ModelOpts, ModelType}; +use flareon_codegen::symbol_resolver::{ModulePath, SymbolResolver, VisibleSymbol}; use proc_macro2::TokenStream; use quote::{format_ident, quote}; use syn::{parse_quote, Attribute, Meta, UseTree}; @@ -73,34 +74,47 @@ impl MigrationGenerator { fn generate_and_write_migrations(&mut self) -> anyhow::Result<()> { let source_files = self.get_source_files()?; - if let Some(migration) = self.generate_migrations(source_files)? { + if let Some(migration) = self.generate_migrations_to_write(source_files)? { self.write_migration(migration)?; } Ok(()) } + pub fn generate_migrations_to_write( + &mut self, + source_files: Vec, + ) -> anyhow::Result> { + if let Some(migration) = self.generate_migrations(source_files)? { + let migration_name = migration.migration_name.clone(); + let content = self.generate_migration_file_content(migration); + Ok(Some(MigrationAsSource::new(migration_name, content))) + } else { + Ok(None) + } + } + pub fn generate_migrations( &mut self, source_files: Vec, - ) -> anyhow::Result> { + ) -> anyhow::Result> { let AppState { models, migrations } = self.process_source_files(source_files)?; let migration_processor = MigrationProcessor::new(migrations)?; let migration_models = migration_processor.latest_models(); - let (modified_models, operations) = self.generate_operations(&models, &migration_models); + let (modified_models, operations) = self.generate_operations(&models, &migration_models); if operations.is_empty() { Ok(None) } else { let migration_name = migration_processor.next_migration_name()?; - let dependencies = migration_processor.dependencies(); - let content = self.generate_migration_file_content( - &migration_name, - &modified_models, + let dependencies = migration_processor.base_dependencies(); + + Ok(Some(GeneratedMigration { + migration_name, + modified_models, dependencies, operations, - ); - Ok(Some(MigrationToWrite::new(migration_name, content))) + })) } } @@ -173,18 +187,18 @@ impl MigrationGenerator { }: SourceFile, app_state: &mut AppState, ) -> anyhow::Result<()> { - let imports = Self::get_imports(&file, &ModulePath::from_fs_path(&path)); - let import_resolver = SymbolResolver::new(imports); + let symbol_resolver = SymbolResolver::from_file(&file, &path); let mut migration_models = Vec::new(); for item in file.items { if let syn::Item::Struct(mut item) = item { for attr in &item.attrs.clone() { if is_model_attr(attr) { - import_resolver.resolve_struct(&mut item); + symbol_resolver.resolve_struct(&mut item); let args = Self::args_from_attr(&path, attr)?; - let model_in_source = ModelInSource::from_item(item, &args)?; + let model_in_source = + ModelInSource::from_item(item, &args, &symbol_resolver)?; match args.model_type { ModelType::Application => app_state.models.push(model_in_source), @@ -214,29 +228,6 @@ impl MigrationGenerator { Ok(()) } - /// Return the list of top-level `use` statements, structs, and constants as - /// a list of [`VisibleSymbol`]s from the file. - fn get_imports(file: &syn::File, module_path: &ModulePath) -> Vec { - let mut imports = Vec::new(); - - for item in &file.items { - match item { - syn::Item::Use(item) => { - imports.append(&mut VisibleSymbol::from_item_use(item, module_path)); - } - syn::Item::Struct(item_struct) => { - imports.push(VisibleSymbol::from_item_struct(item_struct, module_path)); - } - syn::Item::Const(item_const) => { - imports.push(VisibleSymbol::from_item_const(item_const, module_path)); - } - _ => {} - } - } - - imports - } - fn args_from_attr(path: &Path, attr: &Attribute) -> Result { match attr.meta { Meta::Path(_) => { @@ -398,23 +389,20 @@ impl MigrationGenerator { todo!() } - fn generate_migration_file_content( - &self, - migration_name: &str, - modified_models: &[ModelInSource], - dependencies: Vec, - operations: Vec, - ) -> String { - let operations: Vec<_> = operations + fn generate_migration_file_content(&self, migration: GeneratedMigration) -> String { + let operations: Vec<_> = migration + .operations .into_iter() .map(|operation| operation.repr()) .collect(); - let dependencies: Vec<_> = dependencies + let dependencies: Vec<_> = migration + .dependencies .into_iter() .map(|dependency| dependency.repr()) .collect(); let app_name = self.options.app_name.as_ref().unwrap_or(&self.crate_name); + let migration_name = &migration.migration_name; let migration_def = quote! { #[derive(Debug, Copy, Clone)] pub(super) struct Migration; @@ -431,7 +419,8 @@ impl MigrationGenerator { } }; - let models = modified_models + let models = migration + .modified_models .iter() .map(Self::model_to_migration_model) .collect::>(); @@ -442,7 +431,7 @@ impl MigrationGenerator { Self::generate_migration(migration_def, models_def) } - fn write_migration(&self, migration: MigrationToWrite) -> anyhow::Result<()> { + fn write_migration(&self, migration: MigrationAsSource) -> anyhow::Result<()> { let src_path = self .options .output_dir @@ -547,290 +536,6 @@ impl AppState { } } -/// Represents a symbol visible in the current module. This might mean there is -/// a `use` statement for a given type, but also, for instance, the type is -/// defined in the current module. -/// -/// For instance, for `use std::collections::HashMap;` the `VisibleSymbol ` -/// would be: -/// ```ignore -/// # /* -/// VisibleSymbol { -/// alias: "HashMap", -/// full_path: "std::collections::HashMap", -/// kind: VisibleSymbolKind::Use, -/// } -/// # */ -/// ``` -#[derive(Debug, Clone, PartialEq, Eq)] -struct VisibleSymbol { - alias: String, - full_path: String, - kind: VisibleSymbolKind, -} - -#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] -enum VisibleSymbolKind { - Use, - Struct, - Const, -} - -impl VisibleSymbol { - #[must_use] - fn new(alias: &str, full_path: &str, kind: VisibleSymbolKind) -> Self { - Self { - alias: alias.to_string(), - full_path: full_path.to_string(), - kind, - } - } - - fn full_path_parts(&self) -> impl Iterator { - self.full_path.split("::") - } - - fn new_use(alias: &str, full_path: &str) -> Self { - Self::new(alias, full_path, VisibleSymbolKind::Use) - } - - fn from_item_use(item: &syn::ItemUse, module_path: &ModulePath) -> Vec { - Self::from_tree(&item.tree, module_path) - } - - fn from_item_struct(item: &syn::ItemStruct, module_path: &ModulePath) -> Self { - let ident = item.ident.to_string(); - let full_path = Self::module_path(module_path, &ident); - - Self { - alias: ident, - full_path, - kind: VisibleSymbolKind::Struct, - } - } - - fn from_item_const(item: &syn::ItemConst, module_path: &ModulePath) -> Self { - let ident = item.ident.to_string(); - let full_path = Self::module_path(module_path, &ident); - - Self { - alias: ident, - full_path, - kind: VisibleSymbolKind::Const, - } - } - - fn module_path(module_path: &ModulePath, ident: &str) -> String { - format!("{module_path}::{ident}") - } - - fn from_tree(tree: &UseTree, current_module: &ModulePath) -> Vec { - match tree { - UseTree::Path(path) => { - let ident = path.ident.to_string(); - let resolved_path = if ident == "crate" { - current_module.crate_name().to_string() - } else if ident == "self" { - current_module.to_string() - } else if ident == "super" { - current_module.parent().to_string() - } else { - ident - }; - - return Self::from_tree(&path.tree, current_module) - .into_iter() - .map(|import| { - Self::new_use( - &import.alias, - &format!("{}::{}", resolved_path, import.full_path), - ) - }) - .collect(); - } - UseTree::Name(name) => { - let ident = name.ident.to_string(); - return vec![Self::new_use(&ident, &ident)]; - } - UseTree::Rename(rename) => { - return vec![Self::new_use( - &rename.rename.to_string(), - &rename.ident.to_string(), - )]; - } - UseTree::Glob(_) => { - warn!("Glob imports are not supported"); - } - UseTree::Group(group) => { - return group - .items - .iter() - .flat_map(|tree| Self::from_tree(tree, current_module)) - .collect(); - } - } - - vec![] - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -struct ModulePath { - parts: Vec, -} - -impl ModulePath { - #[must_use] - fn from_fs_path(path: &Path) -> Self { - let mut parts = vec![String::from("crate")]; - - if path == Path::new("lib.rs") || path == Path::new("main.rs") { - return Self { parts }; - } - - parts.append( - &mut path - .components() - .map(|c| { - let component_str = c.as_os_str().to_string_lossy(); - component_str - .strip_suffix(".rs") - .unwrap_or(&component_str) - .to_string() - }) - .collect::>(), - ); - - if parts - .last() - .expect("parts must have at least one component") - == "mod" - { - parts.pop(); - } - - Self { parts } - } - - #[must_use] - fn parent(&self) -> Self { - let mut parts = self.parts.clone(); - parts.pop(); - Self { parts } - } - - #[must_use] - fn crate_name(&self) -> &str { - &self.parts[0] - } -} - -impl Display for ModulePath { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.parts.join("::")) - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -struct SymbolResolver { - /// List of imports in the format `"HashMap" -> VisibleSymbol` - symbols: HashMap, -} - -impl SymbolResolver { - #[must_use] - fn new(symbols: Vec) -> Self { - let mut symbol_map = HashMap::new(); - for symbol in symbols { - symbol_map.insert(symbol.alias.clone(), symbol); - } - - Self { - symbols: symbol_map, - } - } - - fn resolve_struct(&self, item: &mut syn::ItemStruct) { - for field in &mut item.fields { - if let syn::Type::Path(path) = &mut field.ty { - self.resolve(path); - } - } - } - - /// Checks the provided `TypePath` and resolves the full type path, if - /// available. - fn resolve(&self, path: &mut syn::TypePath) { - let first_segment = path.path.segments.first(); - - if let Some(first_segment) = first_segment { - if let Some(symbol) = self.symbols.get(&first_segment.ident.to_string()) { - let mut new_segments: Vec<_> = symbol - .full_path_parts() - .map(|s| syn::PathSegment { - ident: syn::Ident::new(s, first_segment.ident.span()), - arguments: syn::PathArguments::None, - }) - .collect(); - - let first_arguments = first_segment.arguments.clone(); - new_segments - .last_mut() - .expect("new_segments must have at least one element") - .arguments = first_arguments; - - new_segments.extend(path.path.segments.iter().skip(1).cloned()); - path.path.segments = syn::punctuated::Punctuated::from_iter(new_segments); - } - - for segment in &mut path.path.segments { - self.resolve_path_arguments(&mut segment.arguments); - } - } - } - - fn resolve_path_arguments(&self, arguments: &mut syn::PathArguments) { - if let syn::PathArguments::AngleBracketed(args) = arguments { - for arg in &mut args.args { - self.resolve_generic_argument(arg); - } - } - } - - fn resolve_generic_argument(&self, arg: &mut syn::GenericArgument) { - if let syn::GenericArgument::Type(syn::Type::Path(path)) = arg { - if let Some(new_arg) = self.try_resolve_generic_const(path) { - *arg = new_arg; - } else { - self.resolve(path); - } - } - } - - fn try_resolve_generic_const(&self, path: &syn::TypePath) -> Option { - if path.qself.is_none() && path.path.segments.len() == 1 { - let segment = path - .path - .segments - .first() - .expect("segments have exactly one element"); - if segment.arguments.is_none() { - let ident = segment.ident.to_string(); - if let Some(symbol) = self.symbols.get(&ident) { - if symbol.kind == VisibleSymbolKind::Const { - let path = &symbol.full_path; - return Some(syn::GenericArgument::Const( - syn::parse_str(path).expect("full_path should be a valid path"), - )); - } - } - } - } - - None - } -} - /// Helper struct to process already existing migrations. #[derive(Debug, Clone)] struct MigrationProcessor { @@ -886,7 +591,9 @@ impl MigrationProcessor { Ok(format!("m_{migration_number:04}_auto_{date_time}")) } - fn dependencies(&self) -> Vec { + /// Returns the list of dependencies for the next migration, based on the + /// already existing and processed migrations. + fn base_dependencies(&self) -> Vec { if self.migrations.is_empty() { return Vec::new(); } @@ -899,18 +606,22 @@ impl MigrationProcessor { } } -#[derive(Debug, Clone, PartialEq, Eq)] -struct ModelInSource { +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ModelInSource { model_item: syn::ItemStruct, model: Model, } impl ModelInSource { - fn from_item(item: syn::ItemStruct, args: &ModelArgs) -> anyhow::Result { + fn from_item( + item: syn::ItemStruct, + args: &ModelArgs, + symbol_resolver: &SymbolResolver, + ) -> anyhow::Result { let input: syn::DeriveInput = item.clone().into(); let opts = ModelOpts::new_from_derive_input(&input) .map_err(|e| anyhow::anyhow!("cannot parse model: {}", e))?; - let model = opts.as_model(args)?; + let model = opts.as_model(args, Some(symbol_resolver))?; Ok(Self { model_item: item, @@ -919,13 +630,24 @@ impl ModelInSource { } } +/// A migration generated by the CLI and before converting to a Rust +/// source code and writing to a file. #[derive(Debug, Clone)] -pub struct MigrationToWrite { +pub struct GeneratedMigration { + pub migration_name: String, + pub modified_models: Vec, + pub dependencies: Vec, + pub operations: Vec, +} + +/// A migration represented as a generated and ready to write Rust source code. +#[derive(Debug, Clone)] +pub struct MigrationAsSource { pub name: String, pub content: String, } -impl MigrationToWrite { +impl MigrationAsSource { #[must_use] pub fn new(name: String, content: String) -> Self { Self { name, content } @@ -954,7 +676,10 @@ impl Repr for Field { let mut tokens = quote! { ::flareon::db::migrations::Field::new(::flareon::db::Identifier::new(#column_name), <#ty as ::flareon::db::DatabaseField>::TYPE) }; - if self.auto_value { + if self + .auto_value + .expect("auto_value is expected to be present when parsing the entire file") + { tokens = quote! { #tokens.auto() } } if self.primary_key { @@ -998,7 +723,7 @@ impl DynMigration for Migration { /// /// This is used to generate migration files. #[derive(Debug, Clone, PartialEq, Eq, Hash)] -enum DynDependency { +pub enum DynDependency { Migration { app: String, migration: String }, Model { app: String, model_name: String }, } @@ -1024,8 +749,8 @@ impl Repr for DynDependency { /// runtime and is using codegen types. /// /// This is used to generate migration files. -#[derive(Debug, Clone)] -enum DynOperation { +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum DynOperation { CreateModel { table_name: String, fields: Vec, @@ -1101,8 +826,6 @@ impl Error for ParsingError {} #[cfg(test)] mod tests { - use quote::ToTokens; - use super::*; #[test] @@ -1119,7 +842,7 @@ mod tests { let migrations = vec![]; let processor = MigrationProcessor::new(migrations).unwrap(); - let next_migration_name = processor.dependencies(); + let next_migration_name = processor.base_dependencies(); assert_eq!(next_migration_name, vec![]); } @@ -1132,7 +855,7 @@ mod tests { }]; let processor = MigrationProcessor::new(migrations).unwrap(); - let next_migration_name = processor.dependencies(); + let next_migration_name = processor.base_dependencies(); assert_eq!( next_migration_name, vec![DynDependency::Migration { @@ -1141,147 +864,4 @@ mod tests { }] ); } - - #[test] - fn imports() { - let source = r" -use std::collections::HashMap; -use std::error::Error as StdError; -use std::fmt::{Debug, Display, Formatter}; -use std::fs::*; -use rand as r; -use super::MyModel; -use crate::MyOtherModel; -use self::MyThirdModel; - -struct MyFourthModel {} - -const MY_CONSTANT: u8 = 42; - "; - - let file = SourceFile::parse(PathBuf::from("foo/bar.rs").clone(), source).unwrap(); - let imports = - MigrationGenerator::get_imports(&file.content, &ModulePath::from_fs_path(&file.path)); - - let expected = vec![ - VisibleSymbol { - alias: "HashMap".to_string(), - full_path: "std::collections::HashMap".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "StdError".to_string(), - full_path: "std::error::Error".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "Debug".to_string(), - full_path: "std::fmt::Debug".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "Display".to_string(), - full_path: "std::fmt::Display".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "Formatter".to_string(), - full_path: "std::fmt::Formatter".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "r".to_string(), - full_path: "rand".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "MyModel".to_string(), - full_path: "crate::foo::MyModel".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "MyOtherModel".to_string(), - full_path: "crate::MyOtherModel".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "MyThirdModel".to_string(), - full_path: "crate::foo::bar::MyThirdModel".to_string(), - kind: VisibleSymbolKind::Use, - }, - VisibleSymbol { - alias: "MyFourthModel".to_string(), - full_path: "crate::foo::bar::MyFourthModel".to_string(), - kind: VisibleSymbolKind::Struct, - }, - VisibleSymbol { - alias: "MY_CONSTANT".to_string(), - full_path: "crate::foo::bar::MY_CONSTANT".to_string(), - kind: VisibleSymbolKind::Const, - }, - ]; - assert_eq!(imports, expected); - } - - #[test] - fn import_resolver() { - let resolver = SymbolResolver::new(vec![ - VisibleSymbol::new_use("MyType", "crate::models::MyType"), - VisibleSymbol::new_use("HashMap", "std::collections::HashMap"), - ]); - - let path = &mut parse_quote!(MyType); - resolver.resolve(path); - assert_eq!( - quote!(crate::models::MyType).to_string(), - path.into_token_stream().to_string() - ); - - let path = &mut parse_quote!(HashMap); - resolver.resolve(path); - assert_eq!( - quote!(std::collections::HashMap).to_string(), - path.into_token_stream().to_string() - ); - - let path = &mut parse_quote!(Option); - resolver.resolve(path); - assert_eq!( - quote!(Option).to_string(), - path.into_token_stream().to_string() - ); - } - - #[test] - fn import_resolver_resolve_struct() { - let resolver = SymbolResolver::new(vec![ - VisibleSymbol::new_use("MyType", "crate::models::MyType"), - VisibleSymbol::new_use("HashMap", "std::collections::HashMap"), - VisibleSymbol::new_use("LimitedString", "flareon::db::LimitedString"), - VisibleSymbol::new( - "MY_CONSTANT", - "crate::constants::MY_CONSTANT", - VisibleSymbolKind::Const, - ), - ]); - - let mut actual = parse_quote! { - struct Example { - field_1: MyType, - field_2: HashMap, - field_3: Option, - field_4: LimitedString, - } - }; - resolver.resolve_struct(&mut actual); - let expected = quote! { - struct Example { - field_1: crate::models::MyType, - field_2: std::collections::HashMap, - field_3: Option, - field_4: flareon::db::LimitedString<{ crate::constants::MY_CONSTANT }>, - } - }; - assert_eq!(actual.into_token_stream().to_string(), expected.to_string()); - } } diff --git a/flareon-cli/tests/migration_generator.rs b/flareon-cli/tests/migration_generator.rs index fadbc338..4efefffc 100644 --- a/flareon-cli/tests/migration_generator.rs +++ b/flareon-cli/tests/migration_generator.rs @@ -1,24 +1,61 @@ use std::path::PathBuf; use flareon_cli::migration_generator::{ - MigrationGenerator, MigrationGeneratorOptions, MigrationToWrite, SourceFile, + DynOperation, MigrationAsSource, MigrationGenerator, MigrationGeneratorOptions, SourceFile, }; -/// Test that the migration generator can generate a create model migration for -/// a given model which compiles successfully. +/// Test that the migration generator can generate a "create model" migration +/// for a given model that has an expected state. +#[test] +fn create_model_state_test() { + let mut generator = test_generator(); + let src = include_str!("migration_generator/create_model.rs"); + let source_files = vec![SourceFile::parse(PathBuf::from("main.rs"), src).unwrap()]; + + let migration = generator + .generate_migrations(source_files) + .unwrap() + .unwrap(); + + assert_eq!(migration.migration_name, "m_0001_initial"); + assert!(migration.dependencies.is_empty()); + if let DynOperation::CreateModel { table_name, fields } = &migration.operations[0] { + assert_eq!(table_name, "my_model"); + assert_eq!(fields.len(), 3); + + let field = &fields[0]; + assert_eq!(field.column_name, "id"); + assert!(field.primary_key); + assert!(field.auto_value.unwrap()); + assert!(!field.foreign_key.unwrap()); + + let field = &fields[1]; + assert_eq!(field.column_name, "field_1"); + assert!(!field.primary_key); + assert!(!field.auto_value.unwrap()); + assert!(!field.foreign_key.unwrap()); + + let field = &fields[2]; + assert_eq!(field.column_name, "field_2"); + assert!(!field.primary_key); + assert!(!field.auto_value.unwrap()); + assert!(!field.foreign_key.unwrap()); + } +} + +/// Test that the migration generator can generate a "create model" migration +/// for a given model which compiles successfully. #[test] #[cfg_attr(miri, ignore)] // unsupported operation: extern static `pidfd_spawnp` is not supported by Miri fn create_model_compile_test() { - let mut generator = MigrationGenerator::new( - PathBuf::from("Cargo.toml"), - String::from("my_crate"), - MigrationGeneratorOptions::default(), - ); + let mut generator = test_generator(); let src = include_str!("migration_generator/create_model.rs"); let source_files = vec![SourceFile::parse(PathBuf::from("main.rs"), src).unwrap()]; - let migration_opt = generator.generate_migrations(source_files).unwrap(); - let MigrationToWrite { + let migration_opt = generator + .generate_migrations_to_write(source_files) + .unwrap(); + let MigrationAsSource { name: migration_name, content: migration_content, } = migration_opt.unwrap(); @@ -41,3 +78,11 @@ mod migrations {{ let t = trybuild::TestCases::new(); t.pass(&test_path); } + +fn test_generator() -> MigrationGenerator { + MigrationGenerator::new( + PathBuf::from("Cargo.toml"), + String::from("my_crate"), + MigrationGeneratorOptions::default(), + ) +} diff --git a/flareon-cli/tests/migration_generator/create_model.rs b/flareon-cli/tests/migration_generator/create_model.rs index a249d4d1..fd19eab5 100644 --- a/flareon-cli/tests/migration_generator/create_model.rs +++ b/flareon-cli/tests/migration_generator/create_model.rs @@ -1,9 +1,10 @@ -use flareon::db::{model, LimitedString}; +use flareon::db::{model, Auto, LimitedString}; pub const FIELD_LEN: u32 = 64; #[model] struct MyModel { + id: Auto, field_1: String, field_2: LimitedString, } diff --git a/flareon-codegen/Cargo.toml b/flareon-codegen/Cargo.toml index 84f0ca48..97e7437e 100644 --- a/flareon-codegen/Cargo.toml +++ b/flareon-codegen/Cargo.toml @@ -11,9 +11,14 @@ workspace = true [dependencies] convert_case.workspace = true darling.workspace = true +log = { workspace = true, optional = true } proc-macro2.workspace = true quote.workspace = true syn.workspace = true [dev-dependencies] proc-macro2 = { workspace = true, features = ["span-locations"] } + +[features] +default = [] +symbol-resolver = ["dep:log"] diff --git a/flareon-codegen/src/expr.rs b/flareon-codegen/src/expr.rs index f946b6eb..806142f9 100644 --- a/flareon-codegen/src/expr.rs +++ b/flareon-codegen/src/expr.rs @@ -9,7 +9,9 @@ enum ItemToken { Field(FieldParser), Literal(syn::Lit), Ident(syn::Ident), - MethodCall(MethodCallParser), + MemberAccess(MemberAccessParser), + FunctionCall(FunctionCallParser), + Reference(ReferenceParser), Op(OpParser), } @@ -23,8 +25,12 @@ impl Parse for ItemToken { if lookahead.peek(Token![$]) { input.parse().map(ItemToken::Field) + } else if lookahead.peek(Token![&]) { + input.parse().map(ItemToken::Reference) } else if lookahead.peek(Token![.]) { - input.parse().map(ItemToken::MethodCall) + input.parse().map(ItemToken::MemberAccess) + } else if lookahead.peek(syn::token::Paren) { + input.parse().map(ItemToken::FunctionCall) } else if lookahead.peek(syn::Lit) { input.parse().map(ItemToken::Literal) } else if lookahead.peek(syn::Ident) { @@ -41,7 +47,9 @@ impl ItemToken { ItemToken::Field(field) => field.span(), ItemToken::Literal(lit) => lit.span(), ItemToken::Ident(ident) => ident.span(), - ItemToken::MethodCall(method_call) => method_call.span(), + ItemToken::MemberAccess(member_access) => member_access.span(), + ItemToken::FunctionCall(function_call) => function_call.span(), + ItemToken::Reference(reference) => reference.span(), ItemToken::Op(op) => op.span(), } } @@ -49,7 +57,7 @@ impl ItemToken { #[derive(Debug)] struct FieldParser { - _field_token: Token![$], + field_token: Token![$], name: syn::Ident, } @@ -63,34 +71,74 @@ impl FieldParser { impl Parse for FieldParser { fn parse(input: ParseStream) -> syn::Result { Ok(FieldParser { - _field_token: input.parse()?, + field_token: input.parse()?, name: input.parse()?, }) } } #[derive(Debug)] -struct MethodCallParser { - _dot: Token![.], - method_name: syn::Ident, - _paren_token: syn::token::Paren, +struct ReferenceParser { + reference_token: Token![&], + expr: syn::Expr, +} + +impl ReferenceParser { + #[must_use] + fn span(&self) -> proc_macro2::Span { + self.expr.span() + } +} + +impl Parse for ReferenceParser { + fn parse(input: ParseStream) -> syn::Result { + Ok(ReferenceParser { + reference_token: input.parse()?, + expr: input.parse()?, + }) + } +} + +#[derive(Debug)] +struct MemberAccessParser { + dot: Token![.], + member_name: syn::Ident, +} + +impl MemberAccessParser { + #[must_use] + fn span(&self) -> proc_macro2::Span { + self.member_name.span() + } +} + +impl Parse for MemberAccessParser { + fn parse(input: ParseStream) -> syn::Result { + Ok(Self { + dot: input.parse()?, + member_name: input.parse()?, + }) + } +} + +#[derive(Debug)] +struct FunctionCallParser { + paren_token: syn::token::Paren, args: syn::punctuated::Punctuated, } -impl MethodCallParser { +impl FunctionCallParser { #[must_use] fn span(&self) -> proc_macro2::Span { - self.method_name.span() + self.args.span() } } -impl Parse for MethodCallParser { +impl Parse for FunctionCallParser { fn parse(input: ParseStream) -> syn::Result { let args_content; Ok(Self { - _dot: input.parse()?, - method_name: input.parse()?, - _paren_token: syn::parenthesized!(args_content in input), + paren_token: syn::parenthesized!(args_content in input), args: args_content.parse_terminated(syn::Expr::parse, Token![,])?, }) } @@ -202,18 +250,25 @@ type InfixBindingPriority = BindingPriority; /// assert_eq!( /// expr, /// Expr::Eq( -/// Box::new(Expr::FieldRef(parse_quote!(field))), +/// Box::new(Expr::FieldRef { field_name: parse_quote!(field), field_token: parse_quote!($)}), /// Box::new(Expr::Value(parse_quote!(42))) /// ) /// ); /// ``` #[derive(Debug, PartialEq, Eq)] pub enum Expr { - FieldRef(syn::Ident), + FieldRef { + field_name: syn::Ident, + field_token: Token![$], + }, Value(syn::Expr), - MethodCall { - called_on: Box, - method_name: syn::Ident, + MemberAccess { + parent: Box, + member_name: syn::Ident, + member_access_token: Token![.], + }, + FunctionCall { + function: Box, args: Vec, }, And(Box, Box), @@ -247,7 +302,18 @@ impl Expr { let lhs_item = input.parse::()?; match lhs_item { - ItemToken::Field(field) => Expr::FieldRef(field.name), + ItemToken::Field(field) => Expr::FieldRef { + field_name: field.name, + field_token: field.field_token, + }, + ItemToken::Reference(reference) => { + Expr::Value(syn::Expr::Reference(syn::ExprReference { + attrs: Vec::new(), + and_token: reference.reference_token, + mutability: None, + expr: Box::new(reference.expr), + })) + } ItemToken::Ident(ident) => Expr::Value(syn::Expr::Path(syn::ExprPath { attrs: Vec::new(), qself: None, @@ -273,12 +339,19 @@ impl Expr { let op_item = input.fork().parse::()?; match op_item { - ItemToken::MethodCall(call) => { + ItemToken::MemberAccess(member_access) => { + input.parse::()?; + lhs = Expr::MemberAccess { + parent: Box::new(lhs), + member_name: member_access.member_name, + member_access_token: member_access.dot, + }; + } + ItemToken::FunctionCall(call) => { input.parse::()?; let args = call.args.into_iter().collect::>(); - lhs = Expr::MethodCall { - called_on: Box::new(lhs), - method_name: call.method_name, + lhs = Expr::FunctionCall { + function: Box::new(lhs), args, }; } @@ -321,61 +394,88 @@ impl Expr { #[must_use] pub fn as_tokens(&self) -> Option { + self.as_tokens_impl(ExprAsTokensMode::FieldRefAsNone) + } + + #[must_use] + pub fn as_tokens_full(&self) -> TokenStream { + self.as_tokens_impl(ExprAsTokensMode::Full) + .expect("Full mode should never return None") + } + + #[must_use] + fn as_tokens_impl(&self, mode: ExprAsTokensMode) -> Option { match self { - Expr::FieldRef(_) => None, + Expr::FieldRef { + field_name, + field_token, + } => match mode { + ExprAsTokensMode::FieldRefAsNone => None, + ExprAsTokensMode::Full => Some(quote! {#field_token #field_name}), + }, Expr::Value(expr) => Some(quote! {#expr}), - Expr::MethodCall { - called_on, - method_name, - args, + Expr::MemberAccess { + parent, + member_name, + member_access_token, } => { - let called_on_tokens = called_on.as_tokens()?; - Some(quote! {#called_on_tokens.#method_name(#(#args),*)}) + let parent_tokens = parent.as_tokens_impl(mode)?; + Some(quote! {#parent_tokens #member_access_token #member_name}) + } + Expr::FunctionCall { function, args } => { + let function_tokens = function.as_tokens_impl(mode)?; + Some(quote! {#function_tokens(#(#args),*)}) } Expr::And(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens && #rhs_tokens}) } Expr::Or(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens || #rhs_tokens}) } Expr::Eq(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens == #rhs_tokens}) } Expr::Ne(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens != #rhs_tokens}) } Expr::Add(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens + #rhs_tokens}) } Expr::Sub(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens - #rhs_tokens}) } Expr::Mul(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens * #rhs_tokens}) } Expr::Div(lhs, rhs) => { - let lhs_tokens = lhs.as_tokens()?; - let rhs_tokens = rhs.as_tokens()?; + let lhs_tokens = lhs.as_tokens_impl(mode)?; + let rhs_tokens = rhs.as_tokens_impl(mode)?; Some(quote! {#lhs_tokens / #rhs_tokens}) } } } } +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +enum ExprAsTokensMode { + FieldRefAsNone, + Full, +} + impl Parse for Expr { fn parse(input: ParseStream) -> syn::Result { Self::parse_impl(input, 0) @@ -393,7 +493,7 @@ mod tests { #[test] fn field_ref() { let input = quote! { $field }; - let expected = Expr::FieldRef(syn::Ident::new("field", span())); + let expected = field("field"); assert_eq!(expected, unwrap_syn(Expr::parse(input))); } @@ -410,7 +510,7 @@ mod tests { fn field_eq() { let input = quote! { $field == 42 }; let expected = Expr::Eq( - Box::new(Expr::FieldRef(syn::Ident::new("field", span()))), + Box::new(field("field")), Box::new(Expr::Value(parse_quote!(42))), ); @@ -439,11 +539,11 @@ mod tests { let input = quote! { $field == 42 && $field != 42 }; let expected = Expr::And( Box::new(Expr::Eq( - Box::new(Expr::FieldRef(syn::Ident::new("field", span()))), + Box::new(field("field")), Box::new(Expr::Value(parse_quote!(42))), )), Box::new(Expr::Ne( - Box::new(Expr::FieldRef(syn::Ident::new("field", span()))), + Box::new(field("field")), Box::new(Expr::Value(parse_quote!(42))), )), ); @@ -470,18 +570,52 @@ mod tests { assert_eq!(expected, unwrap_syn(Expr::parse(input))); } + #[test] + fn function_call() { + let input = quote! { $a == bar() }; + let expected = Expr::Eq( + Box::new(field("a")), + Box::new(Expr::FunctionCall { + function: Box::new(value("bar")), + args: Vec::new(), + }), + ); + + assert_eq!(expected, unwrap_syn(Expr::parse(input))); + } + + #[test] + fn parse_member_access() { + let input = quote! { $a == foo.bar }; + let expected = Expr::Eq( + Box::new(field("a")), + Box::new(member_access(value("foo"), "bar")), + ); + + assert_eq!(expected, unwrap_syn(Expr::parse(input))); + } + + #[test] + fn parse_reference() { + let input = quote! { &foo }; + let expected = reference("foo"); + + assert_eq!(expected, unwrap_syn(Expr::parse(input))); + } + #[test] fn method_call() { let input = quote! { $a == foo.bar().baz() }; let expected = Expr::Eq( Box::new(field("a")), - Box::new(Expr::MethodCall { - called_on: Box::new(Expr::MethodCall { - called_on: Box::new(value("foo")), - method_name: syn::Ident::new("bar", span()), - args: Vec::new(), - }), - method_name: syn::Ident::new("baz", span()), + Box::new(Expr::FunctionCall { + function: Box::new(member_access( + Expr::FunctionCall { + function: Box::new(member_access(value("foo"), "bar")), + args: Vec::new(), + }, + "baz", + )), args: Vec::new(), }), ); @@ -569,9 +703,35 @@ mod tests { assert_eq!(input.to_string(), expr.as_tokens().unwrap().to_string()); } + #[test] + fn tokens_full() { + let input = quote! { $name.len() }; + let expr = unwrap_syn(Expr::parse(input.clone())); + + assert_eq!(input.to_string(), expr.as_tokens_full().to_string()); + } + #[must_use] fn field(name: &str) -> Expr { - Expr::FieldRef(syn::Ident::new(name, span())) + Expr::FieldRef { + field_name: syn::Ident::new(name, span()), + field_token: Token![$](span()), + } + } + + #[must_use] + fn member_access(parent: Expr, member_name: &str) -> Expr { + Expr::MemberAccess { + parent: Box::new(parent), + member_name: syn::Ident::new(member_name, span()), + member_access_token: Token![.](span()), + } + } + + #[must_use] + fn reference(ident: &str) -> Expr { + let ident = syn::Ident::new(ident, span()); + Expr::Value(parse_quote!(&#ident)) } #[must_use] diff --git a/flareon-codegen/src/lib.rs b/flareon-codegen/src/lib.rs index db4772ea..8668603a 100644 --- a/flareon-codegen/src/lib.rs +++ b/flareon-codegen/src/lib.rs @@ -2,3 +2,19 @@ extern crate self as flareon_codegen; pub mod expr; pub mod model; +#[cfg(feature = "symbol-resolver")] +pub mod symbol_resolver; +#[cfg(not(feature = "symbol-resolver"))] +pub mod symbol_resolver { + /// Dummy SymbolResolver for use in contexts when it's not useful (e.g. + /// macros which do not have access to the entire source tree to look + /// for `use` statements anyway). + /// + /// This is defined as an empty enum so that it's entirely optimized out by + /// the compiler, along with all functions that reference it. + pub enum SymbolResolver {} + + impl SymbolResolver { + pub fn resolve(&self, _: &mut syn::Type) {} + } +} diff --git a/flareon-codegen/src/model.rs b/flareon-codegen/src/model.rs index a9e4e8fa..f9adb399 100644 --- a/flareon-codegen/src/model.rs +++ b/flareon-codegen/src/model.rs @@ -1,6 +1,8 @@ use convert_case::{Case, Casing}; use darling::{FromDeriveInput, FromField, FromMeta}; +use crate::symbol_resolver::SymbolResolver; + #[allow(clippy::module_name_repetitions)] #[derive(Debug, Default, FromMeta)] pub struct ModelArgs { @@ -59,8 +61,16 @@ impl ModelOpts { /// /// Returns an error if the model name does not start with an underscore /// when the model type is [`ModelType::Migration`]. - pub fn as_model(&self, args: &ModelArgs) -> Result { - let fields = self.fields().iter().map(|field| field.as_field()).collect(); + pub fn as_model( + &self, + args: &ModelArgs, + symbol_resolver: Option<&SymbolResolver>, + ) -> Result { + let fields: Vec<_> = self + .fields() + .iter() + .map(|field| field.as_field(symbol_resolver)) + .collect(); let mut original_name = self.ident.to_string(); if args.model_type == ModelType::Migration { @@ -80,14 +90,36 @@ impl ModelOpts { original_name.to_string().to_case(Case::Snake) }; + let primary_key_field = self.get_primary_key_field(&fields)?; + Ok(Model { name: self.ident.clone(), original_name, model_type: args.model_type, table_name, + pk_field: primary_key_field.clone(), fields, }) } + + fn get_primary_key_field<'a>(&self, fields: &'a [Field]) -> Result<&'a Field, syn::Error> { + let pks: Vec<_> = fields.iter().filter(|field| field.primary_key).collect(); + if pks.is_empty() { + return Err(syn::Error::new( + self.ident.span(), + "models must have a primary key field, either named `id` \ + or annotated with the `#[model(primary_key)]` attribute", + )); + } + if pks.len() > 1 { + return Err(syn::Error::new( + pks[1].field_name.span(), + "composite primary keys are not supported; only one primary key field is allowed", + )); + } + + Ok(pks[0]) + } } #[derive(Debug, Clone, FromField)] @@ -95,10 +127,50 @@ impl ModelOpts { pub struct FieldOpts { pub ident: Option, pub ty: syn::Type, + pub primary_key: darling::util::Flag, pub unique: darling::util::Flag, } impl FieldOpts { + #[must_use] + fn find_type(&self, type_to_check: &str, symbol_resolver: &SymbolResolver) -> bool { + let mut ty = self.ty.clone(); + symbol_resolver.resolve(&mut ty); + Self::inner_type_names(&ty) + .iter() + .any(|name| name == type_to_check) + } + + #[must_use] + fn inner_type_names(ty: &syn::Type) -> Vec { + let mut names = Vec::new(); + Self::inner_type_names_impl(ty, &mut names); + names + } + + fn inner_type_names_impl(ty: &syn::Type, names: &mut Vec) { + if let syn::Type::Path(type_path) = ty { + let name = type_path + .path + .segments + .iter() + .map(|s| s.ident.to_string()) + .collect::>() + .join("::"); + names.push(name); + + for arg in &type_path.path.segments { + if let syn::PathArguments::AngleBracketed(arg) = &arg.arguments { + for arg in &arg.args { + if let syn::GenericArgument::Type(ty) = arg { + Self::inner_type_names_impl(ty, names); + } + } + } + } + } + } + /// Convert the field options into a field. /// /// # Panics @@ -106,32 +178,37 @@ impl FieldOpts { /// Panics if the field does not have an identifier (i.e. it is a tuple /// struct). #[must_use] - pub fn as_field(&self) -> Field { + pub fn as_field(&self, symbol_resolver: Option<&SymbolResolver>) -> Field { let name = self.ident.as_ref().unwrap(); let column_name = name.to_string(); - // TODO define a separate type for auto fields - let is_auto = column_name == "id"; - // TODO define #[model(primary_key)] attribute - let is_primary_key = column_name == "id"; + let (auto_value, foreign_key) = match symbol_resolver { + Some(resolver) => ( + Some(self.find_type("flareon::db::Auto", resolver)), + Some(self.find_type("flareon::db::ForeignKey", resolver)), + ), + None => (None, None), + }; + let is_primary_key = column_name == "id" || self.primary_key.is_present(); Field { field_name: name.clone(), column_name, ty: self.ty.clone(), - auto_value: is_auto, + auto_value, primary_key: is_primary_key, - null: false, + foreign_key, unique: self.unique.is_present(), } } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone, PartialEq, Eq, Hash)] pub struct Model { pub name: syn::Ident, pub original_name: String, pub model_type: ModelType, pub table_name: String, + pub pk_field: Field, pub fields: Vec, } @@ -147,9 +224,13 @@ pub struct Field { pub field_name: syn::Ident, pub column_name: String, pub ty: syn::Type, - pub auto_value: bool, + /// Whether the field is an auto field (e.g. `id`); `None` if it could not + /// be determined. + pub auto_value: Option, pub primary_key: bool, - pub null: bool, + /// Whether the field is a foreign key; `None` if it could not be + /// determined. + pub foreign_key: Option, pub unique: bool, } @@ -158,6 +239,8 @@ mod tests { use syn::parse_quote; use super::*; + #[cfg(feature = "symbol-resolver")] + use crate::symbol_resolver::{VisibleSymbol, VisibleSymbolKind}; #[test] fn model_args_default() { @@ -197,7 +280,7 @@ mod tests { }; let opts = ModelOpts::new_from_derive_input(&input).unwrap(); let args = ModelArgs::default(); - let model = opts.as_model(&args).unwrap(); + let model = opts.as_model(&args, None).unwrap(); assert_eq!(model.name.to_string(), "TestModel"); assert_eq!(model.table_name, "test_model"); assert_eq!(model.fields.len(), 2); @@ -215,13 +298,67 @@ mod tests { }; let opts = ModelOpts::new_from_derive_input(&input).unwrap(); let args = ModelArgs::from_meta(&input.attrs.first().unwrap().meta).unwrap(); - let err = opts.as_model(&args).unwrap_err(); + let err = opts.as_model(&args, None).unwrap_err(); assert_eq!( err.to_string(), "migration model names must start with an underscore" ); } + #[test] + fn model_opts_as_model_pk_attr() { + let input: syn::DeriveInput = parse_quote! { + #[model] + struct TestModel { + #[model(primary_key)] + name: i32, + } + }; + let opts = ModelOpts::new_from_derive_input(&input).unwrap(); + let args = ModelArgs::default(); + let model = opts.as_model(&args, None).unwrap(); + assert_eq!(model.fields.len(), 1); + assert!(model.fields[0].primary_key); + } + + #[test] + fn model_opts_as_model_no_pk() { + let input: syn::DeriveInput = parse_quote! { + #[model] + struct TestModel { + name: String, + } + }; + let opts = ModelOpts::new_from_derive_input(&input).unwrap(); + let args = ModelArgs::default(); + let err = opts.as_model(&args, None).unwrap_err(); + assert_eq!( + err.to_string(), + "models must have a primary key field, either named `id` \ + or annotated with the `#[model(primary_key)]` attribute" + ); + } + + #[test] + fn model_opts_as_model_multiple_pks() { + let input: syn::DeriveInput = parse_quote! { + #[model] + struct TestModel { + id: i64, + #[model(primary_key)] + id_2: i64, + name: String, + } + }; + let opts = ModelOpts::new_from_derive_input(&input).unwrap(); + let args = ModelArgs::default(); + let err = opts.as_model(&args, None).unwrap_err(); + assert_eq!( + err.to_string(), + "composite primary keys are not supported; only one primary key field is allowed" + ); + } + #[test] fn field_opts_as_field() { let input: syn::Field = parse_quote! { @@ -229,10 +366,46 @@ mod tests { name: String }; let field_opts = FieldOpts::from_field(&input).unwrap(); - let field = field_opts.as_field(); + let field = field_opts.as_field(None); assert_eq!(field.field_name.to_string(), "name"); assert_eq!(field.column_name, "name"); assert_eq!(field.ty, parse_quote!(String)); assert!(field.unique); + assert_eq!(field.auto_value, None); + assert_eq!(field.foreign_key, None); + } + + #[test] + fn inner_type_names() { + let input: syn::Type = + parse_quote! { ::my_crate::MyContainer<'a, Vec> }; + let names = FieldOpts::inner_type_names(&input); + assert_eq!( + names, + vec!["my_crate::MyContainer", "Vec", "std::string::String"] + ); + } + + #[cfg(feature = "symbol-resolver")] + #[test] + fn contains_type() { + let symbols = vec![VisibleSymbol::new( + "MyContainer", + "my_crate::MyContainer", + VisibleSymbolKind::Use, + )]; + let resolver = SymbolResolver::new(symbols); + + let opts = FieldOpts { + ident: None, + ty: parse_quote! { MyContainer }, + primary_key: Default::default(), + unique: Default::default(), + }; + + assert!(opts.find_type("my_crate::MyContainer", &resolver)); + assert!(opts.find_type("std::string::String", &resolver)); + assert!(!opts.find_type("MyContainer", &resolver)); + assert!(!opts.find_type("String", &resolver)); } } diff --git a/flareon-codegen/src/symbol_resolver.rs b/flareon-codegen/src/symbol_resolver.rs new file mode 100644 index 00000000..94343204 --- /dev/null +++ b/flareon-codegen/src/symbol_resolver.rs @@ -0,0 +1,479 @@ +#![cfg(feature = "symbol-resolver")] + +use std::collections::HashMap; +use std::fmt::Display; +use std::iter::FromIterator; +use std::path::{Path, PathBuf}; + +use log::warn; +use quote::quote; +use syn::{parse_quote, UseTree}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct SymbolResolver { + /// List of imports in the format `"HashMap" -> VisibleSymbol` + symbols: HashMap, +} + +impl SymbolResolver { + #[must_use] + pub fn new(symbols: Vec) -> Self { + let mut symbol_map = HashMap::new(); + for symbol in symbols { + symbol_map.insert(symbol.alias.clone(), symbol); + } + + Self { + symbols: symbol_map, + } + } + + pub fn from_file(file: &syn::File, module_path: &Path) -> Self { + let imports = Self::get_imports(file, &ModulePath::from_fs_path(module_path)); + Self::new(imports) + } + + /// Return the list of top-level `use` statements, structs, and constants as + /// a list of [`VisibleSymbol`]s from the file. + fn get_imports(file: &syn::File, module_path: &ModulePath) -> Vec { + let mut imports = Vec::new(); + + for item in &file.items { + match item { + syn::Item::Use(item) => { + imports.append(&mut VisibleSymbol::from_item_use(item, module_path)); + } + syn::Item::Struct(item_struct) => { + imports.push(VisibleSymbol::from_item_struct(item_struct, module_path)); + } + syn::Item::Const(item_const) => { + imports.push(VisibleSymbol::from_item_const(item_const, module_path)); + } + _ => {} + } + } + + imports + } + + pub fn resolve_struct(&self, item: &mut syn::ItemStruct) { + for field in &mut item.fields { + self.resolve(&mut field.ty); + } + } + + pub fn resolve(&self, ty: &mut syn::Type) { + if let syn::Type::Path(path) = ty { + self.resolve_type_path(path); + } + } + + /// Checks the provided `TypePath` and resolves the full type path, if + /// available. + fn resolve_type_path(&self, path: &mut syn::TypePath) { + let first_segment = path.path.segments.first(); + + if let Some(first_segment) = first_segment { + if let Some(symbol) = self.symbols.get(&first_segment.ident.to_string()) { + let mut new_segments: Vec<_> = symbol + .full_path_parts() + .map(|s| syn::PathSegment { + ident: syn::Ident::new(s, first_segment.ident.span()), + arguments: syn::PathArguments::None, + }) + .collect(); + + let first_arguments = first_segment.arguments.clone(); + new_segments + .last_mut() + .expect("new_segments must have at least one element") + .arguments = first_arguments; + + new_segments.extend(path.path.segments.iter().skip(1).cloned()); + path.path.segments = syn::punctuated::Punctuated::from_iter(new_segments); + } + + for segment in &mut path.path.segments { + self.resolve_path_arguments(&mut segment.arguments); + } + } + } + + fn resolve_path_arguments(&self, arguments: &mut syn::PathArguments) { + if let syn::PathArguments::AngleBracketed(args) = arguments { + for arg in &mut args.args { + self.resolve_generic_argument(arg); + } + } + } + + fn resolve_generic_argument(&self, arg: &mut syn::GenericArgument) { + if let syn::GenericArgument::Type(syn::Type::Path(path)) = arg { + if let Some(new_arg) = self.try_resolve_generic_const(path) { + *arg = new_arg; + } else { + self.resolve_type_path(path); + } + } + } + + fn try_resolve_generic_const(&self, path: &syn::TypePath) -> Option { + if path.qself.is_none() && path.path.segments.len() == 1 { + let segment = path + .path + .segments + .first() + .expect("segments have exactly one element"); + if segment.arguments.is_none() { + let ident = segment.ident.to_string(); + if let Some(symbol) = self.symbols.get(&ident) { + if symbol.kind == VisibleSymbolKind::Const { + let path = &symbol.full_path; + return Some(syn::GenericArgument::Const( + syn::parse_str(path).expect("full_path should be a valid path"), + )); + } + } + } + } + + None + } +} + +/// Represents a symbol visible in the current module. This might mean there is +/// a `use` statement for a given type, but also, for instance, the type is +/// defined in the current module. +/// +/// For instance, for `use std::collections::HashMap;` the `VisibleSymbol ` +/// would be: +/// ``` +/// use flareon_codegen::symbol_resolver::{VisibleSymbol, VisibleSymbolKind}; +/// +/// let _ = VisibleSymbol { +/// alias: String::from("HashMap"), +/// full_path: String::from("std::collections::HashMap"), +/// kind: VisibleSymbolKind::Use, +/// }; +/// ``` +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct VisibleSymbol { + pub alias: String, + pub full_path: String, + pub kind: VisibleSymbolKind, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum VisibleSymbolKind { + Use, + Struct, + Const, +} + +impl VisibleSymbol { + #[must_use] + pub fn new(alias: &str, full_path: &str, kind: VisibleSymbolKind) -> Self { + assert_ne!(alias, "", "alias must not be empty"); + assert!(!alias.contains("::"), "alias must not contain '::'"); + Self { + alias: alias.to_string(), + full_path: full_path.to_string(), + kind, + } + } + + fn full_path_parts(&self) -> impl Iterator { + self.full_path.split("::") + } + + fn new_use(alias: &str, full_path: &str) -> Self { + Self::new(alias, full_path, VisibleSymbolKind::Use) + } + + fn from_item_use(item: &syn::ItemUse, module_path: &ModulePath) -> Vec { + Self::from_tree(&item.tree, module_path) + } + + fn from_item_struct(item: &syn::ItemStruct, module_path: &ModulePath) -> Self { + let ident = item.ident.to_string(); + let full_path = Self::module_path(module_path, &ident); + + Self { + alias: ident, + full_path, + kind: VisibleSymbolKind::Struct, + } + } + + fn from_item_const(item: &syn::ItemConst, module_path: &ModulePath) -> Self { + let ident = item.ident.to_string(); + let full_path = Self::module_path(module_path, &ident); + + Self { + alias: ident, + full_path, + kind: VisibleSymbolKind::Const, + } + } + + fn module_path(module_path: &ModulePath, ident: &str) -> String { + format!("{module_path}::{ident}") + } + + fn from_tree(tree: &UseTree, current_module: &ModulePath) -> Vec { + match tree { + UseTree::Path(path) => { + let ident = path.ident.to_string(); + let resolved_path = if ident == "crate" { + current_module.crate_name().to_string() + } else if ident == "self" { + current_module.to_string() + } else if ident == "super" { + current_module.parent().to_string() + } else { + ident + }; + + return Self::from_tree(&path.tree, current_module) + .into_iter() + .map(|import| { + Self::new_use( + &import.alias, + &format!("{}::{}", resolved_path, import.full_path), + ) + }) + .collect(); + } + UseTree::Name(name) => { + let ident = name.ident.to_string(); + return vec![Self::new_use(&ident, &ident)]; + } + UseTree::Rename(rename) => { + return vec![Self::new_use( + &rename.rename.to_string(), + &rename.ident.to_string(), + )]; + } + UseTree::Glob(_) => { + warn!("Glob imports are not supported"); + } + UseTree::Group(group) => { + return group + .items + .iter() + .flat_map(|tree| Self::from_tree(tree, current_module)) + .collect(); + } + } + + vec![] + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct ModulePath { + parts: Vec, +} + +impl ModulePath { + #[must_use] + pub fn from_fs_path(path: &Path) -> Self { + let mut parts = vec![String::from("crate")]; + + if path == Path::new("lib.rs") || path == Path::new("main.rs") { + return Self { parts }; + } + + parts.append( + &mut path + .components() + .map(|c| { + let component_str = c.as_os_str().to_string_lossy(); + component_str + .strip_suffix(".rs") + .unwrap_or(&component_str) + .to_string() + }) + .collect::>(), + ); + + if parts + .last() + .expect("parts must have at least one component") + == "mod" + { + parts.pop(); + } + + Self { parts } + } + + #[must_use] + fn parent(&self) -> Self { + let mut parts = self.parts.clone(); + parts.pop(); + Self { parts } + } + + #[must_use] + fn crate_name(&self) -> &str { + &self.parts[0] + } +} + +impl Display for ModulePath { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.parts.join("::")) + } +} + +#[cfg(test)] +mod tests { + use flareon_codegen::symbol_resolver::VisibleSymbolKind; + use quote::ToTokens; + + use super::*; + + #[test] + fn imports() { + let source = r" +use std::collections::HashMap; +use std::error::Error as StdError; +use std::fmt::{Debug, Display, Formatter}; +use std::fs::*; +use rand as r; +use super::MyModel; +use crate::MyOtherModel; +use self::MyThirdModel; + +struct MyFourthModel {} + +const MY_CONSTANT: u8 = 42; + "; + + let file = syn::parse_file(source).unwrap(); + let imports = + SymbolResolver::get_imports(&file, &ModulePath::from_fs_path(Path::new("foo/bar.rs"))); + + let expected = vec![ + VisibleSymbol { + alias: "HashMap".to_string(), + full_path: "std::collections::HashMap".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "StdError".to_string(), + full_path: "std::error::Error".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "Debug".to_string(), + full_path: "std::fmt::Debug".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "Display".to_string(), + full_path: "std::fmt::Display".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "Formatter".to_string(), + full_path: "std::fmt::Formatter".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "r".to_string(), + full_path: "rand".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "MyModel".to_string(), + full_path: "crate::foo::MyModel".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "MyOtherModel".to_string(), + full_path: "crate::MyOtherModel".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "MyThirdModel".to_string(), + full_path: "crate::foo::bar::MyThirdModel".to_string(), + kind: VisibleSymbolKind::Use, + }, + VisibleSymbol { + alias: "MyFourthModel".to_string(), + full_path: "crate::foo::bar::MyFourthModel".to_string(), + kind: VisibleSymbolKind::Struct, + }, + VisibleSymbol { + alias: "MY_CONSTANT".to_string(), + full_path: "crate::foo::bar::MY_CONSTANT".to_string(), + kind: VisibleSymbolKind::Const, + }, + ]; + assert_eq!(imports, expected); + } + + #[test] + fn import_resolver() { + let resolver = SymbolResolver::new(vec![ + VisibleSymbol::new_use("MyType", "crate::models::MyType"), + VisibleSymbol::new_use("HashMap", "std::collections::HashMap"), + ]); + + let path = &mut parse_quote!(MyType); + resolver.resolve_type_path(path); + assert_eq!( + quote!(crate::models::MyType).to_string(), + path.into_token_stream().to_string() + ); + + let path = &mut parse_quote!(HashMap); + resolver.resolve_type_path(path); + assert_eq!( + quote!(std::collections::HashMap).to_string(), + path.into_token_stream().to_string() + ); + + let path = &mut parse_quote!(Option); + resolver.resolve_type_path(path); + assert_eq!( + quote!(Option).to_string(), + path.into_token_stream().to_string() + ); + } + + #[test] + fn import_resolver_resolve_struct() { + let resolver = SymbolResolver::new(vec![ + VisibleSymbol::new_use("MyType", "crate::models::MyType"), + VisibleSymbol::new_use("HashMap", "std::collections::HashMap"), + VisibleSymbol::new_use("LimitedString", "flareon::db::LimitedString"), + VisibleSymbol::new( + "MY_CONSTANT", + "crate::constants::MY_CONSTANT", + VisibleSymbolKind::Const, + ), + ]); + + let mut actual = parse_quote! { + struct Example { + field_1: MyType, + field_2: HashMap, + field_3: Option, + field_4: LimitedString, + } + }; + resolver.resolve_struct(&mut actual); + let expected = quote! { + struct Example { + field_1: crate::models::MyType, + field_2: std::collections::HashMap, + field_3: Option, + field_4: flareon::db::LimitedString<{ crate::constants::MY_CONSTANT }>, + } + }; + assert_eq!(actual.into_token_stream().to_string(), expected.to_string()); + } +} diff --git a/flareon-macros/src/model.rs b/flareon-macros/src/model.rs index fc0ebc92..ce328290 100644 --- a/flareon-macros/src/model.rs +++ b/flareon-macros/src/model.rs @@ -27,7 +27,7 @@ pub(super) fn impl_model_for_struct( } }; - let model = match opts.as_model(&args) { + let model = match opts.as_model(&args, None) { Ok(val) => val, Err(err) => { return err.to_compile_error(); @@ -71,9 +71,11 @@ fn remove_helper_field_attributes(fields: &mut syn::Fields) -> &Punctuated, fields_as_from_db: Vec, + fields_as_update_from_db: Vec, fields_as_get_values: Vec, fields_as_field_refs: Vec, } @@ -91,9 +93,11 @@ impl ModelBuilder { let mut model_builder = Self { name: model.name.clone(), table_name: model.table_name, + pk_field: model.pk_field.clone(), fields_struct_name: format_ident!("{}Fields", model.name), fields_as_columns: Vec::with_capacity(field_count), fields_as_from_db: Vec::with_capacity(field_count), + fields_as_update_from_db: Vec::with_capacity(field_count), fields_as_get_values: Vec::with_capacity(field_count), fields_as_field_refs: Vec::with_capacity(field_count), }; @@ -113,18 +117,9 @@ impl ModelBuilder { let column_name = &field.column_name; { - let mut field_as_column = quote!(#orm_ident::Column::new( + let field_as_column = quote!(#orm_ident::Column::new( #orm_ident::Identifier::new(#column_name) )); - if field.auto_value { - field_as_column.append_all(quote!(.auto())); - } - if field.null { - field_as_column.append_all(quote!(.null())); - } - if field.unique { - field_as_column.append_all(quote!(.unique())); - } self.fields_as_columns.push(field_as_column); } @@ -132,8 +127,12 @@ impl ModelBuilder { #name: db_row.get::<#ty>(#index)? )); + self.fields_as_update_from_db.push(quote!( + #index => { self.#name = db_row.get::<#ty>(row_field_id)?; } + )); + self.fields_as_get_values.push(quote!( - #index => &self.#name as &dyn #orm_ident::ToDbValue + #index => &self.#name as &dyn #orm_ident::ToDbFieldValue )); self.fields_as_field_refs.push(quote!( @@ -144,24 +143,40 @@ impl ModelBuilder { #[must_use] fn build_model_impl(&self) -> TokenStream { + let crate_ident = flareon_ident(); let orm_ident = orm_ident(); let name = &self.name; let table_name = &self.table_name; let fields_struct_name = &self.fields_struct_name; let fields_as_columns = &self.fields_as_columns; + let pk_field_name = &self.pk_field.field_name; + let pk_column_name = &self.pk_field.column_name; + let pk_type = &self.pk_field.ty; let fields_as_from_db = &self.fields_as_from_db; + let fields_as_update_from_db = &self.fields_as_update_from_db; let fields_as_get_values = &self.fields_as_get_values; quote! { + #[#crate_ident::__private::async_trait] #[automatically_derived] impl #orm_ident::Model for #name { type Fields = #fields_struct_name; + type PrimaryKey = #pk_type; const COLUMNS: &'static [#orm_ident::Column] = &[ #(#fields_as_columns,)* ]; const TABLE_NAME: #orm_ident::Identifier = #orm_ident::Identifier::new(#table_name); + const PRIMARY_KEY_NAME: #orm_ident::Identifier = #orm_ident::Identifier::new(#pk_column_name); + + fn primary_key(&self) -> &Self::PrimaryKey { + &self.#pk_field_name + } + + fn set_primary_key(&mut self, primary_key: Self::PrimaryKey) { + self.#pk_field_name = primary_key; + } fn from_db(db_row: #orm_ident::Row) -> #orm_ident::Result { Ok(Self { @@ -169,7 +184,18 @@ impl ModelBuilder { }) } - fn get_values(&self, columns: &[usize]) -> Vec<&dyn #orm_ident::ToDbValue> { + fn update_from_db(&mut self, db_row: #orm_ident::Row, columns: &[usize]) -> #orm_ident::Result<()> { + for (row_field_id, column_id) in columns.into_iter().enumerate() { + match *column_id { + #(#fields_as_update_from_db,)* + _ => panic!("Unknown column index: {}", column_id), + } + } + + Ok(()) + } + + fn get_values(&self, columns: &[usize]) -> Vec<&dyn #orm_ident::ToDbFieldValue> { columns .iter() .map(|&column| match column { @@ -178,6 +204,15 @@ impl ModelBuilder { }) .collect() } + + async fn get_by_primary_key( + db: &DB, + pk: Self::PrimaryKey, + ) -> #orm_ident::Result> { + #orm_ident::query!(Self, $#pk_field_name == pk) + .get(db) + .await + } } } } diff --git a/flareon-macros/src/query.rs b/flareon-macros/src/query.rs index 32a3f1b6..5e0b455b 100644 --- a/flareon-macros/src/query.rs +++ b/flareon-macros/src/query.rs @@ -40,23 +40,33 @@ pub(super) fn expr_to_tokens(model_name: &syn::Type, expr: Expr) -> TokenStream let crate_name = flareon_ident(); match expr { - Expr::FieldRef(name) => { - quote!(<#model_name as #crate_name::db::Model>::Fields::#name.as_expr()) + Expr::FieldRef { field_name, .. } => { + quote!(<#model_name as #crate_name::db::Model>::Fields::#field_name.as_expr()) } Expr::Value(value) => { quote!(#crate_name::db::query::Expr::value(#value)) } - Expr::MethodCall { - called_on, - method_name, - args, - } => match *called_on { - Expr::Value(syn_expr) => { - quote!(#crate_name::db::query::Expr::value(#syn_expr.#method_name(#(#args),*))) + Expr::MemberAccess { + parent, + member_name, + .. + } => match parent.as_tokens() { + Some(tokens) => { + quote!(#crate_name::db::query::Expr::value(#tokens.#member_name)) } - _ => syn::Error::new( - method_name.span(), - "only method calls on values are supported", + None => syn::Error::new_spanned( + parent.as_tokens_full(), + "accessing members of values that reference database fields is unsupported", + ) + .to_compile_error(), + }, + Expr::FunctionCall { function, args } => match function.as_tokens() { + Some(tokens) => { + quote!(#crate_name::db::query::Expr::value(#tokens(#(#args),*))) + } + None => syn::Error::new_spanned( + function.as_tokens_full(), + "calling functions that reference database fields is unsupported", ) .to_compile_error(), }, @@ -90,9 +100,9 @@ fn handle_binary_comparison( let bin_fn = format_ident!("{}", bin_fn); let bin_trait = format_ident!("{}", bin_trait); - if let Expr::FieldRef(ref field) = lhs { + if let Expr::FieldRef { ref field_name, .. } = lhs { if let Some(rhs_tokens) = rhs.as_tokens() { - return quote!(#crate_name::db::query::#bin_trait::#bin_fn(<#model_name as #crate_name::db::Model>::Fields::#field, #rhs_tokens)); + return quote!(#crate_name::db::query::#bin_trait::#bin_fn(<#model_name as #crate_name::db::Model>::Fields::#field_name, #rhs_tokens)); } } diff --git a/flareon-macros/tests/compile_tests.rs b/flareon-macros/tests/compile_tests.rs index e51a8092..782b47ad 100644 --- a/flareon-macros/tests/compile_tests.rs +++ b/flareon-macros/tests/compile_tests.rs @@ -16,6 +16,8 @@ fn attr_model() { t.compile_fail("tests/ui/attr_model_tuple.rs"); t.compile_fail("tests/ui/attr_model_enum.rs"); t.compile_fail("tests/ui/attr_model_generic.rs"); + t.compile_fail("tests/ui/attr_model_no_pk.rs"); + t.compile_fail("tests/ui/attr_model_multiple_pks.rs"); } #[rustversion::attr(not(nightly), ignore)] @@ -28,6 +30,7 @@ fn func_query() { t.compile_fail("tests/ui/func_query_starting_op.rs"); t.compile_fail("tests/ui/func_query_double_field.rs"); t.compile_fail("tests/ui/func_query_invalid_field.rs"); + t.compile_fail("tests/ui/func_query_method_call_on_db_field.rs"); } #[rustversion::attr(not(nightly), ignore)] diff --git a/flareon-macros/tests/ui/attr_model_multiple_pks.rs b/flareon-macros/tests/ui/attr_model_multiple_pks.rs new file mode 100644 index 00000000..614ca7f6 --- /dev/null +++ b/flareon-macros/tests/ui/attr_model_multiple_pks.rs @@ -0,0 +1,11 @@ +use flareon::db::model; + +#[model] +struct MyModel { + id: i64, + #[model(primary_key)] + id_2: i64, + name: String, +} + +fn main() {} diff --git a/flareon-macros/tests/ui/attr_model_multiple_pks.stderr b/flareon-macros/tests/ui/attr_model_multiple_pks.stderr new file mode 100644 index 00000000..c21fb84a --- /dev/null +++ b/flareon-macros/tests/ui/attr_model_multiple_pks.stderr @@ -0,0 +1,5 @@ +error: composite primary keys are not supported; only one primary key field is allowed + --> tests/ui/attr_model_multiple_pks.rs:7:5 + | +7 | id_2: i64, + | ^^^^ diff --git a/flareon-macros/tests/ui/attr_model_no_pk.rs b/flareon-macros/tests/ui/attr_model_no_pk.rs new file mode 100644 index 00000000..4c8114d6 --- /dev/null +++ b/flareon-macros/tests/ui/attr_model_no_pk.rs @@ -0,0 +1,8 @@ +use flareon::db::model; + +#[model] +struct MyModel { + name: std::string::String, +} + +fn main() {} diff --git a/flareon-macros/tests/ui/attr_model_no_pk.stderr b/flareon-macros/tests/ui/attr_model_no_pk.stderr new file mode 100644 index 00000000..528251ab --- /dev/null +++ b/flareon-macros/tests/ui/attr_model_no_pk.stderr @@ -0,0 +1,5 @@ +error: models must have a primary key field, either named `id` or annotated with the `#[model(primary_key)]` attribute + --> tests/ui/attr_model_no_pk.rs:4:8 + | +4 | struct MyModel { + | ^^^^^^^ diff --git a/flareon-macros/tests/ui/func_query_method_call_on_db_field.rs b/flareon-macros/tests/ui/func_query_method_call_on_db_field.rs new file mode 100644 index 00000000..51dc7a7a --- /dev/null +++ b/flareon-macros/tests/ui/func_query_method_call_on_db_field.rs @@ -0,0 +1,14 @@ +use flareon::db::{model, query}; + +#[derive(Debug)] +#[model] +struct MyModel { + id: i32, + name: std::string::String, + description: String, + visits: i32, +} + +fn main() { + query!(MyModel, $name.len); +} diff --git a/flareon-macros/tests/ui/func_query_method_call_on_db_field.stderr b/flareon-macros/tests/ui/func_query_method_call_on_db_field.stderr new file mode 100644 index 00000000..c784f04a --- /dev/null +++ b/flareon-macros/tests/ui/func_query_method_call_on_db_field.stderr @@ -0,0 +1,5 @@ +error: accessing members of values that reference database fields is unsupported + --> tests/ui/func_query_method_call_on_db_field.rs:13:21 + | +13 | query!(MyModel, $name.len); + | ^^^^^ diff --git a/flareon/Cargo.toml b/flareon/Cargo.toml index 536f5fd1..22f5475e 100644 --- a/flareon/Cargo.toml +++ b/flareon/Cargo.toml @@ -44,7 +44,11 @@ time.workspace = true tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } tower = { workspace = true, features = ["util"] } tower-sessions = { workspace = true, features = ["memory-store"] } +<<<<<<< HEAD tracing.workspace = true +======= +env_logger = "0.11.5" +>>>>>>> 4db9303 (feat(orm): add foreign key support) [dev-dependencies] async-stream.workspace = true diff --git a/flareon/src/auth.rs b/flareon/src/auth.rs index 28efb146..6fe88d71 100644 --- a/flareon/src/auth.rs +++ b/flareon/src/auth.rs @@ -23,6 +23,7 @@ use subtle::ConstantTimeEq; use thiserror::Error; use crate::config::SecretKey; +use crate::db::DbValue; #[cfg(feature = "db")] use crate::db::{ColumnType, DatabaseField, FromDbValue, SqlxValueRef, ToDbValue}; use crate::request::{Request, RequestExt}; @@ -433,7 +434,7 @@ impl FromDbValue for PasswordHash { #[cfg(feature = "db")] impl ToDbValue for PasswordHash { - fn to_sea_query_value(&self) -> sea_query::Value { + fn to_db_value(&self) -> DbValue { self.0.clone().into() } } diff --git a/flareon/src/db.rs b/flareon/src/db.rs index 700a2f43..3611ae22 100644 --- a/flareon/src/db.rs +++ b/flareon/src/db.rs @@ -12,6 +12,7 @@ pub mod impl_postgres; pub mod impl_sqlite; pub mod migrations; pub mod query; +mod relations; mod sea_query_db; use std::fmt::Write; @@ -23,7 +24,8 @@ pub use flareon_macros::{model, query}; #[cfg(test)] use mockall::automock; use query::Query; -use sea_query::{Iden, SchemaStatementBuilder, SimpleExpr}; +pub use relations::{ForeignKey, ForeignKeyOnDeletePolicy, ForeignKeyOnUpdatePolicy}; +use sea_query::{Iden, IntoColumnRef, ReturningClause, SchemaStatementBuilder, SimpleExpr}; use sea_query_binder::{SqlxBinder, SqlxValues}; use sqlx::{Type, TypeInfo}; use thiserror::Error; @@ -57,6 +59,13 @@ pub enum DatabaseError { /// Error when applying migrations. #[error("Error when applying migrations: {0}")] MigrationError(#[from] migrations::MigrationEngineError), + /// Foreign Key could not be retrieved from the database because the record + /// was not found. + #[error("Error retrieving a Foreign Key from the database: record not found")] + ForeignKeyNotFound, + /// Primary key could not be converted from i64 using [`TryFromI64`] trait. + #[error("Primary key could not be converted from i64")] + PrimaryKeyFromI64Error, } impl DatabaseError { @@ -100,9 +109,15 @@ pub trait Model: Sized + Send + 'static { /// Rust. type Fields; + /// The primary key type of the model. + type PrimaryKey: PrimaryKey; + /// The name of the table in the database. const TABLE_NAME: Identifier; + /// The name of the primary key column in the database. + const PRIMARY_KEY_NAME: Identifier; + /// The columns of the model. const COLUMNS: &'static [Column]; @@ -114,8 +129,17 @@ pub trait Model: Sized + Send + 'static { /// with the model. fn from_db(db_row: Row) -> Result; + fn update_from_db(&mut self, db_row: Row, columns: &[usize]) -> Result<()>; + + /// Returns the primary key of the model. + fn primary_key(&self) -> &Self::PrimaryKey; + + /// Used by the ORM to set the primary key of the model after it has been + /// saved to the database. + fn set_primary_key(&mut self, primary_key: Self::PrimaryKey); + /// Gets the values of the model for the given columns. - fn get_values(&self, columns: &[usize]) -> Vec<&dyn ToDbValue>; + fn get_values(&self, columns: &[usize]) -> Vec<&dyn ToDbFieldValue>; /// Returns a query for all objects of this model. #[must_use] @@ -123,6 +147,11 @@ pub trait Model: Sized + Send + 'static { Query::new() } + async fn get_by_primary_key( + db: &DB, + pk: Self::PrimaryKey, + ) -> Result>; + /// Saves the model to the database. /// /// # Errors @@ -175,45 +204,18 @@ impl Iden for &Identifier { #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct Column { name: Identifier, - auto_value: bool, - unique: bool, - null: bool, } impl Column { /// Creates a new column with the given name. #[must_use] pub const fn new(name: Identifier) -> Self { - Self { - name, - auto_value: false, - unique: false, - null: false, - } - } - - /// Marks the column as auto-increment. - #[must_use] - pub const fn auto(mut self) -> Self { - self.auto_value = true; - self - } - - /// Marks the column unique. - #[must_use] - pub const fn unique(mut self) -> Self { - self.unique = true; - self - } - - /// Marks the column as nullable. - #[must_use] - pub const fn null(mut self) -> Self { - self.null = true; - self + Self { name } } } +pub trait PrimaryKey: DatabaseField + Clone {} + /// A row structure that holds the data of a single row retrieved from the /// database. #[non_exhaustive] @@ -259,7 +261,7 @@ impl Row { } /// A trait denoting that some type can be used as a field in a database. -pub trait DatabaseField: FromDbValue + ToDbValue { +pub trait DatabaseField: FromDbValue + ToDbFieldValue { const NULLABLE: bool = false; /// The type of the column in the database as one of the variants of @@ -318,18 +320,70 @@ pub trait FromDbValue { Self: Sized; } +pub type DbValue = sea_query::Value; + /// A trait for converting a Rust value to a database value. pub trait ToDbValue: Send + Sync { /// Converts the Rust value to a `sea_query` value. /// /// This method is used to convert the Rust value to a value that can be /// used in a query. - fn to_sea_query_value(&self) -> sea_query::Value; + fn to_db_value(&self) -> DbValue; +} + +pub trait ToDbFieldValue { + fn to_db_field_value(&self) -> DbFieldValue; +} + +#[derive(Debug, Clone, PartialEq)] +pub enum DbFieldValue { + /// The value should be automatically generated by the database and not + /// included in the query. + Auto, + /// A value that should be included in the query. + Value(DbValue), +} + +impl DbFieldValue { + #[must_use] + pub fn is_auto(&self) -> bool { + matches!(self, Self::Auto) + } + + #[must_use] + pub fn is_value(&self) -> bool { + matches!(self, Self::Value(_)) + } + + #[must_use] + pub fn unwrap_value(self) -> sea_query::Value { + self.expect_value("called DbValue::unwrap_value() on a wrong DbValue variant") + } + + #[must_use] + pub fn expect_value(self, message: &str) -> sea_query::Value { + match self { + Self::Value(value) => value, + _ => panic!("{message}"), + } + } +} + +impl ToDbFieldValue for T { + fn to_db_field_value(&self) -> DbFieldValue { + DbFieldValue::Value(self.to_db_value()) + } +} + +impl> From for DbFieldValue { + fn from(value: T) -> Self { + Self::Value(value.into()) + } } impl ToDbValue for &T { - fn to_sea_query_value(&self) -> sea_query::Value { - (*self).to_sea_query_value() + fn to_db_value(&self) -> DbValue { + (*self).to_db_value() } } @@ -483,40 +537,75 @@ impl Database { /// the database, for instance because the migrations haven't been /// applied, or there was a problem with the database connection. pub async fn insert(&self, data: &mut T) -> Result<()> { - let non_auto_column_identifiers = T::COLUMNS + let column_identifiers = T::COLUMNS .iter() - .filter_map(|column| { - if column.auto_value { - None - } else { - Some(Identifier::from(column.name.as_str())) - } - }) - .collect::>(); - let value_indices = T::COLUMNS + .map(|column| Identifier::from(column.name.as_str())); + let value_indices: Vec<_> = T::COLUMNS .iter() .enumerate() - .filter_map(|(i, column)| if column.auto_value { None } else { Some(i) }) - .collect::>(); - let values = data.get_values(&value_indices); + .map(|(i, _column)| i) + .collect(); + let values = data + .get_values(&value_indices) + .into_iter() + .map(ToDbFieldValue::to_db_field_value); + + let mut auto_col_ids = Vec::new(); + let mut auto_col_identifiers = Vec::new(); + let mut value_identifiers = Vec::new(); + let mut filtered_values = Vec::new(); + std::iter::zip(std::iter::zip(value_indices, column_identifiers), values).for_each( + |((index, identifier), value)| match value { + DbFieldValue::Auto => { + auto_col_ids.push(index); + auto_col_identifiers.push(identifier.into_column_ref()); + } + DbFieldValue::Value(value) => { + value_identifiers.push(identifier); + filtered_values.push(value); + } + }, + ); - let insert_statement = sea_query::Query::insert() + let mut insert_statement = sea_query::Query::insert() .into_table(T::TABLE_NAME) - .columns(non_auto_column_identifiers) + .columns(value_identifiers) .values( - values + filtered_values .into_iter() - .map(|value| SimpleExpr::Value(value.to_sea_query_value())) + .map(|value| SimpleExpr::Value(value)) .collect::>(), )? + .or_default_values() .to_owned(); - let statement_result = self.execute_statement(&insert_statement).await?; + if !auto_col_ids.is_empty() { + let row = if self.supports_returning() { + insert_statement.returning(ReturningClause::Columns(auto_col_identifiers)); + + self.fetch_option(&insert_statement) + .await? + .expect("query should return the primary key") + } else { + let result = self.execute_statement(&insert_statement).await?; + let row_id = result + .last_inserted_row_id + .expect("expected last inserted row ID if RETURNING clause is not supported"); + let query = sea_query::Query::select() + .from(T::TABLE_NAME) + .columns(auto_col_identifiers) + .and_where(sea_query::Expr::col(T::PRIMARY_KEY_NAME).eq(row_id)) + .to_owned(); + self.fetch_option(&query).await?.expect( + "expected a row returned from a SELECT if RETURNING clause is not supported", + ) + }; + data.update_from_db(row, &auto_col_ids)?; + } else { + self.execute_statement(&insert_statement).await?; + } - debug!( - "Inserted row; rows affected: {}", - statement_result.rows_affected() - ); + debug!("Inserted row"); Ok(()) } @@ -625,7 +714,7 @@ impl Database { ) -> Result { let values = values .iter() - .map(ToDbValue::to_sea_query_value) + .map(ToDbValue::to_db_value) .collect::>(); let values = SqlxValues(sea_query::Values(values)); @@ -659,6 +748,14 @@ impl Database { Ok(result) } + fn supports_returning(&self) -> bool { + match self.inner { + DatabaseImpl::Sqlite(_) => true, + DatabaseImpl::Postgres(_) => true, + DatabaseImpl::MySql(_) => false, + } + } + async fn fetch_all(&self, statement: &T) -> Result> where T: SqlxBinder, @@ -777,14 +874,30 @@ impl DatabaseBackend for Database { #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] pub struct StatementResult { rows_affected: RowsNum, + last_inserted_row_id: Option, } impl StatementResult { /// Creates a new statement result with the given number of rows affected. - #[cfg(test)] #[must_use] pub(crate) fn new(rows_affected: RowsNum) -> Self { - Self { rows_affected } + Self { + rows_affected, + last_inserted_row_id: None, + } + } + + /// Creates a new statement result with the given number of rows affected + /// and last inserted row ID. + #[must_use] + pub(crate) fn new_with_last_inserted_row_id( + rows_affected: RowsNum, + last_inserted_row_id: u64, + ) -> Self { + Self { + rows_affected, + last_inserted_row_id: Some(last_inserted_row_id), + } } /// Returns the number of rows affected by the query. @@ -792,12 +905,68 @@ impl StatementResult { pub fn rows_affected(&self) -> RowsNum { self.rows_affected } + + /// Returns the ID of the last inserted row. + #[must_use] + pub fn last_inserted_row_id(&self) -> Option { + self.last_inserted_row_id + } } /// A structure that holds the number of rows. #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deref, Display)] pub struct RowsNum(pub u64); +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum Auto { + Fixed(T), + Auto, +} + +impl Auto { + #[must_use] + pub const fn auto() -> Self { + Self::Auto + } + + #[must_use] + pub const fn fixed(value: T) -> Self { + Self::Fixed(value) + } +} + +impl Default for Auto { + fn default() -> Self { + Self::Auto + } +} + +impl From for Auto { + fn from(value: T) -> Self { + Self::fixed(value) + } +} + +trait TryFromI64 { + fn try_from_i64(value: i64) -> Result + where + Self: Sized; +} + +impl TryFromI64 for i64 { + fn try_from_i64(value: i64) -> Result { + Ok(value) + } +} + +impl TryFromI64 for i32 { + fn try_from_i64(value: i64) -> Result { + value + .try_into() + .map_err(|_| DatabaseError::PrimaryKeyFromI64Error) + } +} + /// A wrapper over a string that has a limited length. /// /// This type is used to represent a string that has a limited length in the @@ -935,14 +1104,6 @@ mod tests { fn column() { let column = Column::new(Identifier::new("test")); assert_eq!(column.name.as_str(), "test"); - assert!(!column.auto_value); - assert!(!column.null); - - let column_auto = column.auto(); - assert!(column_auto.auto_value); - - let column_null = column.null(); - assert!(column_null.null); } #[test] diff --git a/flareon/src/db/fields.rs b/flareon/src/db/fields.rs index 47abc69f..be504093 100644 --- a/flareon/src/db/fields.rs +++ b/flareon/src/db/fields.rs @@ -1,5 +1,7 @@ +//! `DatabaseField` implementations for common types. + use flareon::db::DatabaseField; -use sea_query::Value; +use log::debug; #[cfg(feature = "mysql")] use crate::db::impl_mysql::MySqlValueRef; @@ -8,7 +10,8 @@ use crate::db::impl_postgres::PostgresValueRef; #[cfg(feature = "sqlite")] use crate::db::impl_sqlite::SqliteValueRef; use crate::db::{ - ColumnType, DatabaseError, FromDbValue, LimitedString, Result, SqlxValueRef, ToDbValue, + Auto, ColumnType, DatabaseError, DbFieldValue, DbValue, ForeignKey, FromDbValue, LimitedString, + Model, PrimaryKey, Result, SqlxValueRef, ToDbFieldValue, ToDbValue, }; macro_rules! impl_from_sqlite_default { @@ -41,13 +44,13 @@ macro_rules! impl_from_mysql_default { macro_rules! impl_to_db_value_default { ($ty:ty) => { impl ToDbValue for $ty { - fn to_sea_query_value(&self) -> Value { + fn to_db_value(&self) -> DbValue { self.clone().into() } } impl ToDbValue for Option<$ty> { - fn to_sea_query_value(&self) -> Value { + fn to_db_value(&self) -> DbValue { self.clone().into() } } @@ -136,7 +139,7 @@ impl_db_field!(String, Text); impl_db_field!(Vec, Blob); impl ToDbValue for &str { - fn to_sea_query_value(&self) -> Value { + fn to_db_value(&self) -> DbValue { (*self).to_string().into() } } @@ -171,14 +174,14 @@ impl FromDbValue for Option> { impl_to_db_value_default!(chrono::DateTime); impl ToDbValue for Option<&str> { - fn to_sea_query_value(&self) -> Value { + fn to_db_value(&self) -> DbValue { self.map(ToString::to_string).into() } } impl DatabaseField for Option where - Option: ToDbValue + FromDbValue, + Option: ToDbFieldValue + FromDbValue, { const NULLABLE: bool = true; const TYPE: ColumnType = T::TYPE; @@ -209,13 +212,154 @@ impl FromDbValue for LimitedString { } impl ToDbValue for LimitedString { - fn to_sea_query_value(&self) -> Value { + fn to_db_value(&self) -> DbValue { self.0.clone().into() } } impl ToDbValue for Option> { - fn to_sea_query_value(&self) -> Value { + fn to_db_value(&self) -> DbValue { self.clone().map(|s| s.0).into() } } + +impl DatabaseField for ForeignKey { + const NULLABLE: bool = T::PrimaryKey::NULLABLE; + const TYPE: ColumnType = T::PrimaryKey::TYPE; +} + +impl FromDbValue for ForeignKey { + #[cfg(feature = "sqlite")] + fn from_sqlite(value: SqliteValueRef) -> Result { + T::PrimaryKey::from_sqlite(value).map(ForeignKey::PrimaryKey) + } + + #[cfg(feature = "postgres")] + fn from_postgres(value: PostgresValueRef) -> Result { + T::PrimaryKey::from_postgres(value).map(ForeignKey::PrimaryKey) + } + + #[cfg(feature = "mysql")] + fn from_mysql(value: MySqlValueRef) -> Result { + T::PrimaryKey::from_mysql(value).map(ForeignKey::PrimaryKey) + } +} + +impl ToDbFieldValue for ForeignKey { + fn to_db_field_value(&self) -> DbFieldValue { + self.primary_key().to_db_field_value() + } +} + +impl FromDbValue for Option> +where + Option: FromDbValue, +{ + #[cfg(feature = "sqlite")] + fn from_sqlite(value: SqliteValueRef) -> Result { + Ok(>::from_sqlite(value)?.map(ForeignKey::PrimaryKey)) + } + + #[cfg(feature = "postgres")] + fn from_postgres(value: PostgresValueRef) -> Result { + Ok(>::from_postgres(value)?.map(ForeignKey::PrimaryKey)) + } + + #[cfg(feature = "mysql")] + fn from_mysql(value: MySqlValueRef) -> Result { + Ok(>::from_mysql(value)?.map(ForeignKey::PrimaryKey)) + } +} + +impl ToDbFieldValue for Option> +where + Option: ToDbFieldValue, +{ + fn to_db_field_value(&self) -> DbFieldValue { + match self { + Some(foreign_key) => foreign_key.to_db_field_value(), + None => >::None.to_db_field_value(), + } + } +} + +impl DatabaseField for Auto { + const NULLABLE: bool = T::NULLABLE; + const TYPE: ColumnType = T::TYPE; +} + +impl FromDbValue for Auto { + fn from_sqlite(value: SqliteValueRef) -> Result + where + Self: Sized, + { + Ok(Self::fixed(T::from_sqlite(value)?)) + } + + fn from_postgres(value: PostgresValueRef) -> Result + where + Self: Sized, + { + Ok(Self::fixed(T::from_postgres(value)?)) + } + + fn from_mysql(value: MySqlValueRef) -> Result + where + Self: Sized, + { + Ok(Self::fixed(T::from_mysql(value)?)) + } +} + +impl ToDbFieldValue for Auto { + fn to_db_field_value(&self) -> DbFieldValue { + match self { + Self::Fixed(value) => value.to_db_field_value(), + Self::Auto => DbFieldValue::Auto, + } + } +} + +impl FromDbValue for Option> +where + Option: FromDbValue, +{ + fn from_sqlite(value: SqliteValueRef) -> Result + where + Self: Sized, + { + >::from_sqlite(value).map(|value| value.map(Auto::fixed)) + } + + fn from_postgres(value: PostgresValueRef) -> Result + where + Self: Sized, + { + >::from_postgres(value).map(|value| value.map(Auto::fixed)) + } + + fn from_mysql(value: MySqlValueRef) -> Result + where + Self: Sized, + { + >::from_mysql(value).map(|value| value.map(Auto::fixed)) + } +} + +impl ToDbFieldValue for Option> +where + Option: ToDbFieldValue, +{ + fn to_db_field_value(&self) -> DbFieldValue { + match self { + Some(auto) => auto.to_db_field_value(), + None => >::None.to_db_field_value(), + } + } +} + +impl PrimaryKey for Auto {} + +impl PrimaryKey for i32 {} + +impl PrimaryKey for i64 {} diff --git a/flareon/src/db/impl_mysql.rs b/flareon/src/db/impl_mysql.rs index 23141049..d63aef9d 100644 --- a/flareon/src/db/impl_mysql.rs +++ b/flareon/src/db/impl_mysql.rs @@ -4,10 +4,18 @@ use crate::db::ColumnType; impl_sea_query_db_backend!(DatabaseMySql: sqlx::mysql::MySql, sqlx::mysql::MySqlPool, MySqlRow, MySqlValueRef, sea_query::MysqlQueryBuilder); impl DatabaseMySql { + async fn init(&self) -> crate::db::Result<()> { + Ok(()) + } + fn prepare_values(_values: &mut sea_query_binder::SqlxValues) { // No changes are needed for MySQL } + fn last_inserted_row_id_for(result: &sqlx::mysql::MySqlQueryResult) -> Option { + Some(result.last_insert_id()) + } + pub(super) fn sea_query_column_type_for( &self, column_type: ColumnType, diff --git a/flareon/src/db/impl_postgres.rs b/flareon/src/db/impl_postgres.rs index 5ade4646..eca4550e 100644 --- a/flareon/src/db/impl_postgres.rs +++ b/flareon/src/db/impl_postgres.rs @@ -3,6 +3,10 @@ use crate::db::sea_query_db::impl_sea_query_db_backend; impl_sea_query_db_backend!(DatabasePostgres: sqlx::postgres::Postgres, sqlx::postgres::PgPool, PostgresRow, PostgresValueRef, sea_query::PostgresQueryBuilder); impl DatabasePostgres { + async fn init(&self) -> crate::db::Result<()> { + Ok(()) + } + fn prepare_values(values: &mut sea_query_binder::SqlxValues) { for value in &mut values.0 .0 { Self::tinyint_to_smallint(value); @@ -34,6 +38,10 @@ impl DatabasePostgres { } } + fn last_inserted_row_id_for(result: &sqlx::postgres::PgQueryResult) -> Option { + None + } + pub(super) fn sea_query_column_type_for( &self, column_type: crate::db::ColumnType, diff --git a/flareon/src/db/impl_sqlite.rs b/flareon/src/db/impl_sqlite.rs index 5f228b12..705f517f 100644 --- a/flareon/src/db/impl_sqlite.rs +++ b/flareon/src/db/impl_sqlite.rs @@ -1,12 +1,29 @@ +use sea_query_binder::SqlxValues; +use sqlx::Executor; + use crate::db::sea_query_db::impl_sea_query_db_backend; impl_sea_query_db_backend!(DatabaseSqlite: sqlx::sqlite::Sqlite, sqlx::sqlite::SqlitePool, SqliteRow, SqliteValueRef, sea_query::SqliteQueryBuilder); impl DatabaseSqlite { - fn prepare_values(_values: &mut sea_query_binder::SqlxValues) { + async fn init(&self) -> crate::db::Result<()> { + self.raw("PRAGMA foreign_keys = ON").await?; + Ok(()) + } + + async fn raw(&self, sql: &str) -> crate::db::Result { + self.raw_with(sql, SqlxValues(sea_query::Values(Vec::new()))) + .await + } + + fn prepare_values(_values: &mut SqlxValues) { // No changes are needed for SQLite } + fn last_inserted_row_id_for(result: &sqlx::sqlite::SqliteQueryResult) -> Option { + Some(result.last_insert_rowid() as u64) + } + pub(super) fn sea_query_column_type_for( &self, column_type: crate::db::ColumnType, diff --git a/flareon/src/db/migrations.rs b/flareon/src/db/migrations.rs index 16d851f9..72560348 100644 --- a/flareon/src/db/migrations.rs +++ b/flareon/src/db/migrations.rs @@ -3,13 +3,14 @@ mod sorter; use std::fmt; use std::fmt::{Debug, Formatter}; -use flareon_macros::{model, query}; +use flareon::db::relations::ForeignKeyOnUpdatePolicy; use sea_query::{ColumnDef, StringLen}; use thiserror::Error; use tracing::info; use crate::db::migrations::sorter::{MigrationSorter, MigrationSorterError}; -use crate::db::{ColumnType, Database, DatabaseField, Identifier, Result}; +use crate::db::relations::ForeignKeyOnDeletePolicy; +use crate::db::{model, query, ColumnType, Database, DatabaseField, Identifier, Model, Result}; #[derive(Debug, Clone, Error)] #[non_exhaustive] @@ -244,6 +245,17 @@ impl Operation { let mut query = sea_query::Table::create().table(*table_name).to_owned(); for field in *fields { query.col(field.as_column_def(database)); + if let Some(foreign_key) = field.foreign_key { + query.foreign_key( + sea_query::ForeignKeyCreateStatement::new() + .from_tbl(*table_name) + .from_col(field.name) + .to_tbl(foreign_key.model) + .to_col(foreign_key.field) + .on_delete(foreign_key.on_delete.into()) + .on_update(foreign_key.on_update.into()), + ); + } } if *if_not_exists { query.if_not_exists(); @@ -345,6 +357,7 @@ pub struct Field { pub null: bool, /// Whether the column has a unique constraint pub unique: bool, + foreign_key: Option, } impl Field { @@ -357,9 +370,36 @@ impl Field { auto_value: false, null: false, unique: false, + foreign_key: None, } } + #[must_use] + pub const fn foreign_key( + mut self, + to_model: Identifier, + to_field: Identifier, + on_delete: ForeignKeyOnDeletePolicy, + on_update: ForeignKeyOnUpdatePolicy, + ) -> Self { + assert!( + self.null || !matches!(on_delete, ForeignKeyOnDeletePolicy::SetNone), + "`ForeignKey` must be inside `Option` if `on_delete` is set to `SetNone`" + ); + assert!( + self.null || !matches!(on_update, ForeignKeyOnUpdatePolicy::SetNone), + "`ForeignKey` must be inside `Option` if `on_update` is set to `SetNone`" + ); + + self.foreign_key = Some(ForeignKeyReference { + model: to_model, + field: to_field, + on_delete, + on_update, + }); + self + } + #[must_use] pub const fn primary_key(mut self) -> Self { self.primary_key = true; @@ -411,6 +451,14 @@ impl Field { } } +#[derive(Debug, Copy, Clone)] +struct ForeignKeyReference { + model: Identifier, + field: Identifier, + on_delete: ForeignKeyOnDeletePolicy, + on_update: ForeignKeyOnUpdatePolicy, +} + #[cfg_attr(test, mockall::automock)] pub(super) trait ColumnTypeMapper { fn sea_query_column_type_for(&self, column_type: ColumnType) -> sea_query::ColumnType; diff --git a/flareon/src/db/query.rs b/flareon/src/db/query.rs index 72ec7b1a..2f097a7f 100644 --- a/flareon/src/db/query.rs +++ b/flareon/src/db/query.rs @@ -4,7 +4,10 @@ use derive_more::Debug; use sea_query::IntoColumnRef; use crate::db; -use crate::db::{DatabaseBackend, FromDbValue, Identifier, Model, StatementResult, ToDbValue}; +use crate::db::{ + Auto, DatabaseBackend, DbFieldValue, DbValue, ForeignKey, FromDbValue, Identifier, Model, + StatementResult, ToDbFieldValue, +}; /// A query that can be executed on a database. Can be used to filter, update, /// or delete rows. @@ -131,7 +134,7 @@ impl Query { #[derive(Debug)] pub enum Expr { Field(Identifier), - Value(#[debug("{}", _0.to_sea_query_value())] Box), + Value(DbValue), And(Box, Box), Or(Box, Box), Eq(Box, Box), @@ -169,8 +172,11 @@ impl Expr { /// let expr = Expr::value(30); /// ``` #[must_use] - pub fn value(value: T) -> Self { - Self::Value(Box::new(value)) + pub fn value(value: T) -> Self { + match value.to_db_field_value() { + DbFieldValue::Value(value) => Self::Value(value), + _ => panic!("Cannot create query with a non-value field"), + } } /// Create a new `AND` expression. @@ -299,7 +305,7 @@ impl Expr { pub fn as_sea_query_expr(&self) -> sea_query::SimpleExpr { match self { Self::Field(identifier) => (*identifier).into_column_ref().into(), - Self::Value(value) => value.to_sea_query_value().into(), + Self::Value(value) => (*value).clone().into(), Self::And(lhs, rhs) => lhs.as_sea_query_expr().and(rhs.as_sea_query_expr()), Self::Or(lhs, rhs) => lhs.as_sea_query_expr().or(rhs.as_sea_query_expr()), Self::Eq(lhs, rhs) => lhs.as_sea_query_expr().eq(rhs.as_sea_query_expr()), @@ -323,7 +329,7 @@ pub struct FieldRef { phantom_data: PhantomData, } -impl FieldRef { +impl FieldRef { /// Create a new field reference. #[must_use] pub const fn new(identifier: Identifier) -> Self { @@ -344,18 +350,18 @@ impl FieldRef { /// A trait for types that can be compared in database expressions. pub trait ExprEq { - fn eq>(self, other: V) -> Expr; + fn eq>(self, other: V) -> Expr; - fn ne>(self, other: V) -> Expr; + fn ne>(self, other: V) -> Expr; } -impl ExprEq for FieldRef { - fn eq>(self, other: V) -> Expr { - Expr::eq(self.as_expr(), Expr::value(other.into())) +impl ExprEq for FieldRef { + fn eq>(self, other: V) -> Expr { + Expr::eq(self.as_expr(), Expr::value(other.into_field())) } - fn ne>(self, other: V) -> Expr { - Expr::ne(self.as_expr(), Expr::value(other.into())) + fn ne>(self, other: V) -> Expr { + Expr::ne(self.as_expr(), Expr::value(other.into_field())) } } @@ -409,6 +415,40 @@ impl_num_expr!(u64); impl_num_expr!(f32); impl_num_expr!(f64); +trait IntoField { + fn into_field(self) -> T; +} + +impl IntoField for T { + fn into_field(self) -> T { + self + } +} + +impl IntoField> for T { + fn into_field(self) -> Auto { + Auto::fixed(self) + } +} + +impl IntoField for &str { + fn into_field(self) -> String { + self.to_string() + } +} + +impl IntoField> for T { + fn into_field(self) -> ForeignKey { + ForeignKey::from(self) + } +} + +impl IntoField> for &T { + fn into_field(self) -> ForeignKey { + ForeignKey::from(self) + } +} + #[cfg(test)] mod tests { use flareon_macros::model; @@ -505,7 +545,7 @@ mod tests { fn test_expr_value() { let expr = Expr::value(30); if let Expr::Value(value) = expr { - assert_eq!(value.to_sea_query_value().to_string(), "30"); + assert_eq!(value.to_string(), "30"); } else { panic!("Expected Expr::Value"); } diff --git a/flareon/src/db/relations.rs b/flareon/src/db/relations.rs new file mode 100644 index 00000000..2dd9a678 --- /dev/null +++ b/flareon/src/db/relations.rs @@ -0,0 +1,109 @@ +use flareon::db::DatabaseError; + +use crate::db::{DatabaseBackend, Model, Result}; + +#[derive(Debug, Clone)] +pub enum ForeignKey { + PrimaryKey(T::PrimaryKey), + Model(Box), +} + +impl ForeignKey { + pub fn primary_key(&self) -> &T::PrimaryKey { + match self { + Self::PrimaryKey(pk) => pk, + Self::Model(model) => model.primary_key(), + } + } + + pub fn model(&self) -> Option<&T> { + match self { + Self::Model(model) => Some(model), + _ => None, + } + } + + pub fn unwrap(self) -> T { + match self { + Self::Model(model) => *model, + _ => panic!("object has not been retrieved from the database"), + } + } + + /// Retrieve the model from the database, if needed, and return it. + pub async fn get(&mut self, db: &DB) -> Result<&T> { + match self { + Self::Model(model) => Ok(model), + Self::PrimaryKey(pk) => { + let model = T::get_by_primary_key(db, pk.clone()) + .await? + .ok_or(DatabaseError::ForeignKeyNotFound)?; + *self = Self::Model(Box::new(model)); + Ok(self.model().expect("model was just set")) + } + } + } +} + +impl PartialEq for ForeignKey +where + T::PrimaryKey: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.primary_key() == other.primary_key() + } +} + +impl Eq for ForeignKey where T::PrimaryKey: Eq {} + +impl From for ForeignKey { + fn from(model: T) -> Self { + Self::Model(Box::new(model)) + } +} + +impl From<&T> for ForeignKey { + fn from(model: &T) -> Self { + Self::PrimaryKey(model.primary_key().clone()) + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)] +pub enum ForeignKeyOnDeletePolicy { + NoAction, + #[default] + Restrict, + Cascade, + SetNone, +} + +impl From for sea_query::ForeignKeyAction { + fn from(value: ForeignKeyOnDeletePolicy) -> Self { + match value { + ForeignKeyOnDeletePolicy::NoAction => Self::NoAction, + ForeignKeyOnDeletePolicy::Restrict => Self::Restrict, + ForeignKeyOnDeletePolicy::Cascade => Self::Cascade, + ForeignKeyOnDeletePolicy::SetNone => Self::SetNull, + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)] +pub enum ForeignKeyOnUpdatePolicy { + NoAction, + Restrict, + #[default] + Cascade, + SetNone, +} + +impl From for sea_query::ForeignKeyAction { + fn from(value: ForeignKeyOnUpdatePolicy) -> Self { + match value { + ForeignKeyOnUpdatePolicy::NoAction => Self::NoAction, + ForeignKeyOnUpdatePolicy::Restrict => Self::Restrict, + ForeignKeyOnUpdatePolicy::Cascade => Self::Cascade, + ForeignKeyOnUpdatePolicy::SetNone => Self::SetNull, + } + } +} diff --git a/flareon/src/db/sea_query_db.rs b/flareon/src/db/sea_query_db.rs index 2f24939a..f5897753 100644 --- a/flareon/src/db/sea_query_db.rs +++ b/flareon/src/db/sea_query_db.rs @@ -15,7 +15,9 @@ macro_rules! impl_sea_query_db_backend { pub(super) async fn new(url: &str) -> crate::db::Result { let db_connection = <$pool_ty>::connect(url).await?; - Ok(Self { db_connection }) + let db = Self { db_connection }; + db.init(); + Ok(db) } pub(super) async fn close(&self) -> crate::db::Result<()> { @@ -88,6 +90,7 @@ macro_rules! impl_sea_query_db_backend { let result = sqlx_statement.execute(&self.db_connection).await?; let result = crate::db::StatementResult { rows_affected: crate::db::RowsNum(result.rows_affected()), + last_inserted_row_id: Self::last_inserted_row_id_for(&result), }; tracing::debug!("Rows affected: {}", result.rows_affected.0); diff --git a/flareon/tests/db.rs b/flareon/tests/db.rs index d365a2a7..9489d801 100644 --- a/flareon/tests/db.rs +++ b/flareon/tests/db.rs @@ -6,7 +6,10 @@ use fake::rand::SeedableRng; use fake::{Dummy, Fake, Faker}; use flareon::db::migrations::{Field, Operation}; use flareon::db::query::ExprEq; -use flareon::db::{model, query, Database, DatabaseField, Identifier, LimitedString, Model}; +use flareon::db::{ + model, query, Auto, Database, DatabaseError, DatabaseField, ForeignKey, + ForeignKeyOnDeletePolicy, ForeignKeyOnUpdatePolicy, Identifier, LimitedString, Model, +}; use flareon::test::TestDatabase; #[flareon_macros::dbtest] @@ -16,7 +19,7 @@ async fn model_crud(test_db: &mut TestDatabase) { assert_eq!(TestModel::objects().all(&**test_db).await.unwrap(), vec![]); let mut model = TestModel { - id: 0, + id: Auto::fixed(1), name: "test".to_owned(), }; model.save(&**test_db).await.unwrap(); @@ -40,7 +43,7 @@ async fn model_macro_filtering(test_db: &mut TestDatabase) { assert_eq!(TestModel::objects().all(&**test_db).await.unwrap(), vec![]); let mut model = TestModel { - id: 0, + id: Auto::auto(), name: "test".to_owned(), }; model.save(&**test_db).await.unwrap(); @@ -61,7 +64,7 @@ async fn model_macro_filtering(test_db: &mut TestDatabase) { #[derive(Debug, PartialEq)] #[model] struct TestModel { - id: i32, + id: Auto, name: String, } @@ -72,7 +75,7 @@ async fn migrate_test_model(db: &Database) { const CREATE_TEST_MODEL: Operation = Operation::create_model() .table_name(Identifier::new("test_model")) .fields(&[ - Field::new(Identifier::new("id"), ::TYPE) + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) .primary_key() .auto(), Field::new(Identifier::new("name"), ::TYPE), @@ -99,8 +102,8 @@ macro_rules! all_fields_migration_field { #[derive(Debug, PartialEq, Dummy)] #[model] struct AllFieldsModel { - #[dummy(expr = "0i32")] - id: i32, + #[dummy(expr = "Auto::auto()")] + id: Auto, field_bool: bool, field_i8: i8, field_i16: i16, @@ -134,7 +137,7 @@ async fn migrate_all_fields_model(db: &Database) { const CREATE_ALL_FIELDS_MODEL: Operation = Operation::create_model() .table_name(Identifier::new("all_fields_model")) .fields(&[ - Field::new(Identifier::new("id"), ::TYPE) + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) .primary_key() .auto(), all_fields_migration_field!(bool), @@ -174,7 +177,6 @@ async fn all_fields_model(db: &mut TestDatabase) { } let mut models_from_db: Vec<_> = AllFieldsModel::objects().all(&**db).await.unwrap(); - models_from_db.iter_mut().for_each(|model| model.id = 0); normalize_datetimes(&mut models); normalize_datetimes(&mut models_from_db); @@ -197,3 +199,161 @@ fn normalize_datetimes(data: &mut Vec) { ); } } + +#[flareon_macros::dbtest] +async fn foreign_keys(db: &mut TestDatabase) { + #[derive(Debug, Clone, PartialEq)] + #[model] + struct Artist { + id: Auto, + name: String, + } + + #[derive(Debug, Clone, PartialEq)] + #[model] + struct Track { + id: Auto, + artist: ForeignKey, + name: String, + } + + const CREATE_ARTIST: Operation = Operation::create_model() + .table_name(Identifier::new("artist")) + .fields(&[ + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) + .primary_key() + .auto(), + Field::new(Identifier::new("name"), ::TYPE), + ]) + .build(); + const CREATE_TRACK: Operation = Operation::create_model() + .table_name(Identifier::new("track")) + .fields(&[ + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) + .primary_key() + .auto(), + Field::new( + Identifier::new("artist"), + as DatabaseField>::TYPE, + ) + .foreign_key( + ::TABLE_NAME, + ::PRIMARY_KEY_NAME, + ForeignKeyOnDeletePolicy::Restrict, + ForeignKeyOnUpdatePolicy::Restrict, + ), + Field::new(Identifier::new("name"), ::TYPE), + ]) + .build(); + + CREATE_ARTIST.forwards(db).await.unwrap(); + CREATE_TRACK.forwards(db).await.unwrap(); + + let mut artist = Artist { + id: Auto::auto(), + name: "artist".to_owned(), + }; + artist.save(&**db).await.unwrap(); + + let mut track = Track { + id: Auto::auto(), + artist: ForeignKey::from(&artist), + name: "track".to_owned(), + }; + track.save(&**db).await.unwrap(); + + let mut track = Track::objects().all(&**db).await.unwrap()[0].clone(); + let artist_from_db = track.artist.get(&**db).await.unwrap(); + assert_eq!(artist_from_db, &artist); + + let error = query!(Artist, $id == artist.id) + .delete(&**db) + .await + .unwrap_err(); + // expected foreign key violation + assert!(matches!(error, DatabaseError::DatabaseEngineError(_))); + + query!(Track, $artist == &artist) + .delete(&**db) + .await + .unwrap(); + query!(Artist, $id == artist.id) + .delete(&**db) + .await + .unwrap(); + // no error should be thrown +} + +#[flareon_macros::dbtest] +async fn foreign_keys_option(db: &mut TestDatabase) { + #[derive(Debug, Clone, PartialEq)] + #[model] + struct Parent { + id: Auto, + } + + #[derive(Debug, Clone, PartialEq)] + #[model] + struct Child { + id: Auto, + parent: Option>, + } + + const CREATE_PARENT: Operation = Operation::create_model() + .table_name(Identifier::new("parent")) + .fields(&[ + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) + .primary_key() + .auto(), + ]) + .build(); + const CREATE_CHILD: Operation = Operation::create_model() + .table_name(Identifier::new("child")) + .fields(&[ + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) + .primary_key() + .auto(), + Field::new( + Identifier::new("parent"), + > as DatabaseField>::TYPE, + ) + .foreign_key( + ::TABLE_NAME, + ::PRIMARY_KEY_NAME, + ForeignKeyOnDeletePolicy::Restrict, + ForeignKeyOnUpdatePolicy::Restrict, + ) + .set_null(> as DatabaseField>::NULLABLE), + ]) + .build(); + + CREATE_PARENT.forwards(db).await.unwrap(); + CREATE_CHILD.forwards(db).await.unwrap(); + + // no parent + let mut child = Child { + id: Auto::auto(), + parent: None, + }; + child.save(&**db).await.unwrap(); + + let mut child = Child::objects().all(&**db).await.unwrap()[0].clone(); + assert_eq!(child.parent, None); + + query!(Child, $id == child.id).delete(&**db).await.unwrap(); + + // with parent + let mut parent = Parent { id: Auto::auto() }; + parent.save(&**db).await.unwrap(); + + let mut child = Child { + id: Auto::auto(), + parent: Some(ForeignKey::from(&parent)), + }; + child.save(&**db).await.unwrap(); + + let mut child = Child::objects().all(&**db).await.unwrap()[0].clone(); + let mut parent_fk = child.parent.unwrap(); + let parent_from_db = parent_fk.get(&**db).await.unwrap(); + assert_eq!(parent_from_db, &parent); +} From e57f089b755926d2d80f7be6685ae688a6303642 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Ma=C4=87kowski?= Date: Tue, 24 Dec 2024 14:40:23 +0100 Subject: [PATCH 2/9] feat: address review comments, more work on migration generation --- flareon-cli/tests/migration_generator.rs | 8 +- flareon-codegen/src/model.rs | 129 ++++++++++++++++++++--- 2 files changed, 118 insertions(+), 19 deletions(-) diff --git a/flareon-cli/tests/migration_generator.rs b/flareon-cli/tests/migration_generator.rs index 4efefffc..41ce6e5a 100644 --- a/flareon-cli/tests/migration_generator.rs +++ b/flareon-cli/tests/migration_generator.rs @@ -27,19 +27,21 @@ fn create_model_state_test() { assert_eq!(field.column_name, "id"); assert!(field.primary_key); assert!(field.auto_value.unwrap()); - assert!(!field.foreign_key.unwrap()); + assert!(field.foreign_key.clone().unwrap().is_none()); let field = &fields[1]; assert_eq!(field.column_name, "field_1"); assert!(!field.primary_key); assert!(!field.auto_value.unwrap()); - assert!(!field.foreign_key.unwrap()); + assert!(field.foreign_key.clone().unwrap().is_none()); let field = &fields[2]; assert_eq!(field.column_name, "field_2"); assert!(!field.primary_key); assert!(!field.auto_value.unwrap()); - assert!(!field.foreign_key.unwrap()); + assert!(field.foreign_key.clone().unwrap().is_none()); + } else { + panic!("Expected a create model operation"); } } diff --git a/flareon-codegen/src/model.rs b/flareon-codegen/src/model.rs index f9adb399..db0e7f70 100644 --- a/flareon-codegen/src/model.rs +++ b/flareon-codegen/src/model.rs @@ -133,7 +133,7 @@ pub struct FieldOpts { impl FieldOpts { #[must_use] - fn find_type(&self, type_to_check: &str, symbol_resolver: &SymbolResolver) -> bool { + fn has_type(&self, type_to_check: &str, symbol_resolver: &SymbolResolver) -> bool { let mut ty = self.ty.clone(); symbol_resolver.resolve(&mut ty); Self::inner_type_names(&ty) @@ -171,6 +171,54 @@ impl FieldOpts { } } + fn find_type(&self, type_to_find: &str, symbol_resolver: &SymbolResolver) -> Option { + let mut ty = self.ty.clone(); + symbol_resolver.resolve(&mut ty); + Self::find_type_resolved(&ty, type_to_find) + } + + fn find_type_resolved(ty: &syn::Type, type_to_find: &str) -> Option { + if let syn::Type::Path(type_path) = ty { + let name = type_path + .path + .segments + .iter() + .map(|s| s.ident.to_string()) + .collect::>() + .join("::"); + + if name == type_to_find { + return Some(ty.clone()); + } + + for arg in &type_path.path.segments { + if let syn::PathArguments::AngleBracketed(arg) = &arg.arguments { + if let Some(ty) = Self::find_type_in_generics(arg, type_to_find) { + return Some(ty); + } + } + } + } + + None + } + + fn find_type_in_generics( + arg: &syn::AngleBracketedGenericArguments, + type_to_find: &str, + ) -> Option { + arg.args + .iter() + .filter_map(|arg| { + if let syn::GenericArgument::Type(ty) = arg { + Self::find_type_resolved(ty, type_to_find) + } else { + None + } + }) + .next() + } + /// Convert the field options into a field. /// /// # Panics @@ -183,10 +231,13 @@ impl FieldOpts { let column_name = name.to_string(); let (auto_value, foreign_key) = match symbol_resolver { Some(resolver) => ( - Some(self.find_type("flareon::db::Auto", resolver)), - Some(self.find_type("flareon::db::ForeignKey", resolver)), + MaybeUnknown::Determined(self.find_type("flareon::db::Auto", resolver).is_some()), + MaybeUnknown::Determined( + self.find_type("flareon::db::ForeignKey", resolver) + .map(ForeignKeySpec::from), + ), ), - None => (None, None), + None => (MaybeUnknown::Unknown, MaybeUnknown::Unknown), }; let is_primary_key = column_name == "id" || self.primary_key.is_present(); @@ -224,19 +275,65 @@ pub struct Field { pub field_name: syn::Ident, pub column_name: String, pub ty: syn::Type, - /// Whether the field is an auto field (e.g. `id`); `None` if it could not - /// be determined. - pub auto_value: Option, + /// Whether the field is an auto field (e.g. `id`); + /// [`MaybeUnknown::Unknown`] if this `Field` instance was not resolved with + /// a [`SymbolResolver`]. + pub auto_value: MaybeUnknown, pub primary_key: bool, - /// Whether the field is a foreign key; `None` if it could not be - /// determined. - pub foreign_key: Option, + /// [`Some`] wrapped in [`MaybeUnknown::Determined`] if this field is a + /// foreign key; [`None`] wrapped in [`MaybeUnknown::Determined`] if this + /// field is determined not to be a foreign key; [`MaybeUnknown::Unknown`] + /// if this `Field` instance was not resolved with a [`SymbolResolver`]. + pub foreign_key: MaybeUnknown>, pub unique: bool, } +/// Wraps a type whose value may or may not be possible to be determined using +/// the information available. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum MaybeUnknown { + /// Indicates that this instance is determined to be a certain value + /// (possibly [`None`] if wrapping an [`Option`]). + Determined(T), + /// Indicates that the value is unknown. + Unknown, +} + +impl MaybeUnknown { + pub fn unwrap(self) -> T { + match self { + MaybeUnknown::Determined(value) => value, + MaybeUnknown::Unknown => { + panic!("called `MaybeUnknown::unwrap()` on an `Unknown` value") + } + } + } + + pub fn expect(self, msg: &str) -> T { + match self { + MaybeUnknown::Determined(value) => value, + MaybeUnknown::Unknown => { + panic!("{}", msg) + } + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ForeignKeySpec { + pub ty: syn::Type, +} + +impl From for ForeignKeySpec { + fn from(value: syn::Type) -> Self { + todo!() + } +} + #[cfg(test)] mod tests { use syn::parse_quote; + use syn::TraitBoundModifier::Maybe; use super::*; #[cfg(feature = "symbol-resolver")] @@ -371,8 +468,8 @@ mod tests { assert_eq!(field.column_name, "name"); assert_eq!(field.ty, parse_quote!(String)); assert!(field.unique); - assert_eq!(field.auto_value, None); - assert_eq!(field.foreign_key, None); + assert_eq!(field.auto_value, MaybeUnknown::Unknown); + assert_eq!(field.foreign_key, MaybeUnknown::Unknown); } #[test] @@ -403,9 +500,9 @@ mod tests { unique: Default::default(), }; - assert!(opts.find_type("my_crate::MyContainer", &resolver)); - assert!(opts.find_type("std::string::String", &resolver)); - assert!(!opts.find_type("MyContainer", &resolver)); - assert!(!opts.find_type("String", &resolver)); + assert!(opts.has_type("my_crate::MyContainer", &resolver)); + assert!(opts.has_type("std::string::String", &resolver)); + assert!(!opts.has_type("MyContainer", &resolver)); + assert!(!opts.has_type("String", &resolver)); } } From 736d1d3f0ed930b6a98768e5c43d732910cd588c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Ma=C4=87kowski?= Date: Mon, 6 Jan 2025 13:03:09 +0100 Subject: [PATCH 3/9] feat: store foreign key type in Field --- flareon-cli/src/migration_generator.rs | 9 ++ flareon-codegen/src/lib.rs | 2 + flareon-codegen/src/maybe_unknown.rs | 56 +++++++++ flareon-codegen/src/model.rs | 157 ++++++++++--------------- 4 files changed, 131 insertions(+), 93 deletions(-) create mode 100644 flareon-codegen/src/maybe_unknown.rs diff --git a/flareon-cli/src/migration_generator.rs b/flareon-cli/src/migration_generator.rs index 01f0bda6..a8c13f7c 100644 --- a/flareon-cli/src/migration_generator.rs +++ b/flareon-cli/src/migration_generator.rs @@ -761,6 +761,15 @@ pub enum DynOperation { }, } +impl DynOperation { + fn foreign_keys_added(&self) -> Vec<&syn::Type> { + match self { + DynOperation::CreateModel { fields, .. } => {} + DynOperation::AddField { field, .. } => {} + } + } +} + impl Repr for DynOperation { fn repr(&self) -> TokenStream { match self { diff --git a/flareon-codegen/src/lib.rs b/flareon-codegen/src/lib.rs index 8668603a..a55bb032 100644 --- a/flareon-codegen/src/lib.rs +++ b/flareon-codegen/src/lib.rs @@ -1,9 +1,11 @@ extern crate self as flareon_codegen; pub mod expr; +mod maybe_unknown; pub mod model; #[cfg(feature = "symbol-resolver")] pub mod symbol_resolver; + #[cfg(not(feature = "symbol-resolver"))] pub mod symbol_resolver { /// Dummy SymbolResolver for use in contexts when it's not useful (e.g. diff --git a/flareon-codegen/src/maybe_unknown.rs b/flareon-codegen/src/maybe_unknown.rs new file mode 100644 index 00000000..3b200b3e --- /dev/null +++ b/flareon-codegen/src/maybe_unknown.rs @@ -0,0 +1,56 @@ +/// Wraps a type whose value may or may not be possible to be determined using +/// the information available. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum MaybeUnknown { + /// Indicates that this instance is determined to be a certain value + /// (possibly [`None`] if wrapping an [`Option`]). + Determined(T), + /// Indicates that the value is unknown. + Unknown, +} + +impl MaybeUnknown { + pub fn unwrap(self) -> T { + self.expect("called `MaybeUnknown::unwrap()` on an `Unknown` value") + } + + pub fn expect(self, msg: &str) -> T { + match self { + MaybeUnknown::Determined(value) => value, + MaybeUnknown::Unknown => { + panic!("{}", msg) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn maybe_unknown_determined() { + let value = MaybeUnknown::Determined(42); + assert_eq!(value.unwrap(), 42); + } + + #[test] + #[should_panic(expected = "called `MaybeUnknown::unwrap()` on an `Unknown` value")] + fn maybe_unknown_unknown_unwrap() { + let value: MaybeUnknown = MaybeUnknown::Unknown; + assert_eq!(value.unwrap(), 42); + } + + #[test] + fn maybe_unknown_expect() { + let value = MaybeUnknown::Determined(42); + assert_eq!(value.expect("value should be determined"), 42); + } + + #[test] + #[should_panic(expected = "value should be determined")] + fn maybe_unknown_unknown_expect() { + let value: MaybeUnknown = MaybeUnknown::Unknown; + value.expect("value should be determined"); + } +} diff --git a/flareon-codegen/src/model.rs b/flareon-codegen/src/model.rs index db0e7f70..0215a0e2 100644 --- a/flareon-codegen/src/model.rs +++ b/flareon-codegen/src/model.rs @@ -1,6 +1,8 @@ use convert_case::{Case, Casing}; use darling::{FromDeriveInput, FromField, FromMeta}; +use syn::spanned::Spanned; +use crate::maybe_unknown::MaybeUnknown; use crate::symbol_resolver::SymbolResolver; #[allow(clippy::module_name_repetitions)] @@ -66,11 +68,11 @@ impl ModelOpts { args: &ModelArgs, symbol_resolver: Option<&SymbolResolver>, ) -> Result { - let fields: Vec<_> = self + let fields = self .fields() .iter() .map(|field| field.as_field(symbol_resolver)) - .collect(); + .collect::, _>>()?; let mut original_name = self.ident.to_string(); if args.model_type == ModelType::Migration { @@ -132,45 +134,6 @@ pub struct FieldOpts { } impl FieldOpts { - #[must_use] - fn has_type(&self, type_to_check: &str, symbol_resolver: &SymbolResolver) -> bool { - let mut ty = self.ty.clone(); - symbol_resolver.resolve(&mut ty); - Self::inner_type_names(&ty) - .iter() - .any(|name| name == type_to_check) - } - - #[must_use] - fn inner_type_names(ty: &syn::Type) -> Vec { - let mut names = Vec::new(); - Self::inner_type_names_impl(ty, &mut names); - names - } - - fn inner_type_names_impl(ty: &syn::Type, names: &mut Vec) { - if let syn::Type::Path(type_path) = ty { - let name = type_path - .path - .segments - .iter() - .map(|s| s.ident.to_string()) - .collect::>() - .join("::"); - names.push(name); - - for arg in &type_path.path.segments { - if let syn::PathArguments::AngleBracketed(arg) = &arg.arguments { - for arg in &arg.args { - if let syn::GenericArgument::Type(ty) = arg { - Self::inner_type_names_impl(ty, names); - } - } - } - } - } - } - fn find_type(&self, type_to_find: &str, symbol_resolver: &SymbolResolver) -> Option { let mut ty = self.ty.clone(); symbol_resolver.resolve(&mut ty); @@ -225,23 +188,24 @@ impl FieldOpts { /// /// Panics if the field does not have an identifier (i.e. it is a tuple /// struct). - #[must_use] - pub fn as_field(&self, symbol_resolver: Option<&SymbolResolver>) -> Field { + pub fn as_field(&self, symbol_resolver: Option<&SymbolResolver>) -> Result { let name = self.ident.as_ref().unwrap(); let column_name = name.to_string(); + let (auto_value, foreign_key) = match symbol_resolver { Some(resolver) => ( MaybeUnknown::Determined(self.find_type("flareon::db::Auto", resolver).is_some()), MaybeUnknown::Determined( self.find_type("flareon::db::ForeignKey", resolver) - .map(ForeignKeySpec::from), + .map(ForeignKeySpec::try_from) + .transpose()?, ), ), None => (MaybeUnknown::Unknown, MaybeUnknown::Unknown), }; let is_primary_key = column_name == "id" || self.primary_key.is_present(); - Field { + Ok(Field { field_name: name.clone(), column_name, ty: self.ty.clone(), @@ -249,7 +213,7 @@ impl FieldOpts { primary_key: is_primary_key, foreign_key, unique: self.unique.is_present(), - } + }) } } @@ -288,52 +252,60 @@ pub struct Field { pub unique: bool, } -/// Wraps a type whose value may or may not be possible to be determined using -/// the information available. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub enum MaybeUnknown { - /// Indicates that this instance is determined to be a certain value - /// (possibly [`None`] if wrapping an [`Option`]). - Determined(T), - /// Indicates that the value is unknown. - Unknown, +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct ForeignKeySpec { + pub to_model: syn::Type, } -impl MaybeUnknown { - pub fn unwrap(self) -> T { - match self { - MaybeUnknown::Determined(value) => value, - MaybeUnknown::Unknown => { - panic!("called `MaybeUnknown::unwrap()` on an `Unknown` value") - } - } - } +impl TryFrom for ForeignKeySpec { + type Error = syn::Error; - pub fn expect(self, msg: &str) -> T { - match self { - MaybeUnknown::Determined(value) => value, - MaybeUnknown::Unknown => { - panic!("{}", msg) - } - } - } -} + fn try_from(ty: syn::Type) -> Result { + let type_path = if let syn::Type::Path(type_path) = &ty { + type_path + } else { + panic!("Expected a path type for a foreign key"); + }; -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ForeignKeySpec { - pub ty: syn::Type, -} + let args = if let syn::PathArguments::AngleBracketed(args) = &type_path + .path + .segments + .last() + .expect("type path must have at least one segment") + .arguments + { + args + } else { + return Err(syn::Error::new( + ty.span(), + "expected ForeignKey to have angle-bracketed generic arguments", + )); + }; + + if args.args.len() != 1 { + return Err(syn::Error::new( + ty.span(), + "expected ForeignKey to have only one generic parameter", + )); + } -impl From for ForeignKeySpec { - fn from(value: syn::Type) -> Self { - todo!() + let inner = &args.args[0]; + if let syn::GenericArgument::Type(ty) = inner { + Ok(Self { + to_model: ty.clone(), + }) + } else { + Err(syn::Error::new( + ty.span(), + "expected ForeignKey to have a type generic argument", + )) + } } } #[cfg(test)] mod tests { use syn::parse_quote; - use syn::TraitBoundModifier::Maybe; use super::*; #[cfg(feature = "symbol-resolver")] @@ -463,7 +435,7 @@ mod tests { name: String }; let field_opts = FieldOpts::from_field(&input).unwrap(); - let field = field_opts.as_field(None); + let field = field_opts.as_field(None).unwrap(); assert_eq!(field.field_name.to_string(), "name"); assert_eq!(field.column_name, "name"); assert_eq!(field.ty, parse_quote!(String)); @@ -473,19 +445,18 @@ mod tests { } #[test] - fn inner_type_names() { + fn find_type_resolved() { let input: syn::Type = parse_quote! { ::my_crate::MyContainer<'a, Vec> }; - let names = FieldOpts::inner_type_names(&input); - assert_eq!( - names, - vec!["my_crate::MyContainer", "Vec", "std::string::String"] - ); + assert!(FieldOpts::find_type_resolved(&input, "my_crate::MyContainer").is_some()); + assert!(FieldOpts::find_type_resolved(&input, "Vec").is_some()); + assert!(FieldOpts::find_type_resolved(&input, "std::string::String").is_some()); + assert!(!FieldOpts::find_type_resolved(&input, "OtherType").is_some()); } #[cfg(feature = "symbol-resolver")] #[test] - fn contains_type() { + fn find_type() { let symbols = vec![VisibleSymbol::new( "MyContainer", "my_crate::MyContainer", @@ -500,9 +471,9 @@ mod tests { unique: Default::default(), }; - assert!(opts.has_type("my_crate::MyContainer", &resolver)); - assert!(opts.has_type("std::string::String", &resolver)); - assert!(!opts.has_type("MyContainer", &resolver)); - assert!(!opts.has_type("String", &resolver)); + assert!(opts.find_type("my_crate::MyContainer", &resolver).is_some()); + assert!(opts.find_type("std::string::String", &resolver).is_some()); + assert!(!opts.find_type("MyContainer", &resolver).is_none()); + assert!(!opts.find_type("String", &resolver).is_none()); } } From 94f862a01aa65fcdab4a2d07f12d1b0bc12f1235 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Ma=C4=87kowski?= Date: Tue, 7 Jan 2025 19:47:59 +0100 Subject: [PATCH 4/9] feat: more work on ForeignKeys, add migration generation --- Cargo.lock | 261 ++++----- Cargo.toml | 1 + flareon-cli/Cargo.toml | 1 + flareon-cli/src/main.rs | 16 +- flareon-cli/src/migration_generator.rs | 526 +++++++++++++++++- flareon-cli/tests/migration_generator.rs | 182 +++++- .../tests/migration_generator/create_model.rs | 10 +- .../tests/migration_generator/foreign_key.rs | 14 + .../migration_generator/foreign_key_cycle.rs | 15 + .../foreign_key_two_migrations/step_1.rs | 8 + .../foreign_key_two_migrations/step_2.rs | 14 + flareon-codegen/Cargo.toml | 4 +- flareon-codegen/src/expr.rs | 3 +- flareon-codegen/src/lib.rs | 4 +- flareon-codegen/src/maybe_unknown.rs | 14 +- flareon-codegen/src/model.rs | 30 +- flareon-codegen/src/symbol_resolver.rs | 13 +- flareon/Cargo.toml | 4 - .../src/auth/db/migrations/m_0001_initial.rs | 2 +- flareon/src/db.rs | 99 ++-- flareon/src/db/fields.rs | 1 - flareon/src/db/impl_mysql.rs | 1 + flareon/src/db/impl_postgres.rs | 3 +- flareon/src/db/impl_sqlite.rs | 1 - flareon/src/db/migrations.rs | 10 +- flareon/src/db/migrations/sorter.rs | 151 +---- flareon/src/db/query.rs | 4 +- flareon/src/db/relations.rs | 28 +- flareon/src/db/sea_query_db.rs | 2 +- flareon/src/lib.rs | 1 + flareon/src/private.rs | 6 +- flareon/src/utils.rs | 1 + flareon/src/utils/graph.rs | 134 +++++ flareon/tests/db.rs | 89 ++- 34 files changed, 1228 insertions(+), 425 deletions(-) create mode 100644 flareon-cli/tests/migration_generator/foreign_key.rs create mode 100644 flareon-cli/tests/migration_generator/foreign_key_cycle.rs create mode 100644 flareon-cli/tests/migration_generator/foreign_key_two_migrations/step_1.rs create mode 100644 flareon-cli/tests/migration_generator/foreign_key_two_migrations/step_2.rs create mode 100644 flareon/src/utils.rs create mode 100644 flareon/src/utils/graph.rs diff --git a/Cargo.lock b/Cargo.lock index 3037f8f9..b7dc8338 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -17,18 +17,6 @@ version = "2.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627" -[[package]] -name = "ahash" -version = "0.8.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" -dependencies = [ - "cfg-if", - "once_cell", - "version_check", - "zerocopy", -] - [[package]] name = "aho-corasick" version = "1.1.3" @@ -150,9 +138,9 @@ dependencies = [ [[package]] name = "async-trait" -version = "0.1.83" +version = "0.1.85" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721cae7de5c34fbb2acd27e21e6d2cf7b886dce0c27388d46c4e6c47ea4318dd" +checksum = "3f934833b4b7233644e5848f235df3f57ed8c80f1528a26c3dfa13d2147fa056" dependencies = [ "proc-macro2", "quote", @@ -326,9 +314,9 @@ dependencies = [ [[package]] name = "cc" -version = "1.2.3" +version = "1.2.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "27f657647bcff5394bf56c7317665bbf790a137a50eaaa5c6bfbb9e27a518f2d" +checksum = "a012a0df96dd6d06ba9a1b29d6402d1a5d77c6befd2566afdc26e10603dc93d7" dependencies = [ "shlex", ] @@ -353,9 +341,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.23" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" +checksum = "9560b07a799281c7e0958b9296854d6fafd4c5f31444a7e5bb1ad6dde5ccf1bd" dependencies = [ "clap_builder", "clap_derive", @@ -373,9 +361,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.23" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" +checksum = "874e0dd3eb68bf99058751ac9712f622e61e6f393a94f7128fa26e3f02f5c7cd" dependencies = [ "anstream", "anstyle", @@ -385,9 +373,9 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.5.18" +version = "4.5.24" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +checksum = "54b755194d6389280185988721fffba69495eed5ee9feeee9a599b53db80318c" dependencies = [ "heck", "proc-macro2", @@ -480,18 +468,18 @@ checksum = "19d374276b40fb8bbdee95aef7c7fa6b5316ec764510eb64b8dd0e2ed0d7e7f5" [[package]] name = "crossbeam-queue" -version = "0.3.11" +version = "0.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +checksum = "0f58bbc28f91df819d0aa2a2c00cd19754769c2fad90579b3592b1c9ba7a3115" dependencies = [ "crossbeam-utils", ] [[package]] name = "crossbeam-utils" -version = "0.8.20" +version = "0.8.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crypto-common" @@ -731,9 +719,9 @@ dependencies = [ [[package]] name = "event-listener" -version = "5.3.1" +version = "5.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6032be9bd27023a771701cc49f9f053c751055f71efb2e0ae5c15809093675ba" +checksum = "3492acde4c3fc54c845eaab3eed8bd00c7a7d881f78bfc801e43a93dec1331ae" dependencies = [ "concurrent-queue", "parking", @@ -792,9 +780,15 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.2.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "486f806e73c5707928240ddc295403b1b93c96a02038563881c4a2fd84b81ac4" +checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" + +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" [[package]] name = "flareon" @@ -808,7 +802,6 @@ dependencies = [ "chrono", "derive_builder", "derive_more", - "env_logger", "fake", "flareon_macros", "form_urlencoded", @@ -855,6 +848,7 @@ dependencies = [ "flareon", "flareon_codegen", "glob", + "petgraph", "prettyplease", "proc-macro2", "quote", @@ -871,10 +865,10 @@ version = "0.1.0" dependencies = [ "convert_case", "darling", - "log", "proc-macro2", "quote", "syn", + "tracing", ] [[package]] @@ -909,6 +903,12 @@ version = "1.0.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1" +[[package]] +name = "foldhash" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a0d2fde1f7b3d48b8395d5f2de76c18a528bd6a9cdde438df747bfcba3e05d6f" + [[package]] name = "form_urlencoded" version = "1.2.1" @@ -1059,27 +1059,22 @@ checksum = "a8d1add55171497b4705a648c6b583acafb01d58050a51727785f0b2c8e0a2b2" [[package]] name = "hashbrown" -version = "0.14.5" +version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" dependencies = [ - "ahash", "allocator-api2", + "equivalent", + "foldhash", ] -[[package]] -name = "hashbrown" -version = "0.15.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" - [[package]] name = "hashlink" -version = "0.9.1" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +checksum = "7382cf6263419f2d8df38c55d7da83da5c18aef87fc7a7fc1fb1e344edfe14c1" dependencies = [ - "hashbrown 0.14.5", + "hashbrown", ] [[package]] @@ -1123,11 +1118,11 @@ dependencies = [ [[package]] name = "home" -version = "0.5.9" +version = "0.5.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +checksum = "589533453244b0995c858700322199b2becb13b627df2851f64a2775d024abcf" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -1187,18 +1182,18 @@ dependencies = [ [[package]] name = "hybrid-array" -version = "0.2.1" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45a9a965bb102c1c891fb017c09a05c965186b1265a207640f323ddd009f9deb" +checksum = "f2d35805454dc9f8662a98d6d61886ffe26bd465f5960e0e55345c70d5c0d2a9" dependencies = [ "typenum", ] [[package]] name = "hyper" -version = "1.5.1" +version = "1.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97818827ef4f364230e16705d4706e2897df2bb60617d6ca15d598025a3c481f" +checksum = "256fb8d4bd6413123cc9d91832d78325c48ff41677595be797d90f42969beae0" dependencies = [ "bytes", "futures-channel", @@ -1404,7 +1399,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "62f822373a4fe84d4bb149bf54e584a7f4abec90e072ed49cda0edea5b95471f" dependencies = [ "equivalent", - "hashbrown 0.15.2", + "hashbrown", ] [[package]] @@ -1451,9 +1446,9 @@ dependencies = [ [[package]] name = "libc" -version = "0.2.167" +version = "0.2.169" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09d6582e104315a817dff97f75133544b2e094ee22447d2acf4a74e189ba06fc" +checksum = "b5aba8db14291edd000dfcc4d620c7ebfb122c613afb886ca8803fa4e128a20a" [[package]] name = "libm" @@ -1474,9 +1469,9 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.4.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" [[package]] name = "litemap" @@ -1556,9 +1551,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.8.0" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1" +checksum = "4ffbe83022cedc1d264172192511ae958937694cd57ce297164951b8b3568394" dependencies = [ "adler2", ] @@ -1675,9 +1670,9 @@ dependencies = [ [[package]] name = "object" -version = "0.36.5" +version = "0.36.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aedf0a2d09c573ed1d8d85b30c119153926a2b36dce0ab28322c09a117a4683e" +checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87" dependencies = [ "memchr", ] @@ -1746,12 +1741,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "paste" -version = "1.0.15" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" - [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -1767,11 +1756,21 @@ version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" +[[package]] +name = "petgraph" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3672b37090dbd86368a4145bc067582552b29c27377cad4e0a306c97f9bd7772" +dependencies = [ + "fixedbitset", + "indexmap", +] + [[package]] name = "pin-project-lite" -version = "0.2.15" +version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "915a1e146535de9163f3987b8944ed8cf49a18bb0056bcebcdcece385cece4ff" +checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" [[package]] name = "pin-utils" @@ -1823,9 +1822,9 @@ dependencies = [ [[package]] name = "predicates" -version = "3.1.2" +version = "3.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e9086cc7640c29a356d1a29fd134380bee9d8f79a17410aa76e7ad295f42c97" +checksum = "a5d19ee57562043d37e82899fade9a22ebab7be9cef5026b07fda9cdd4293573" dependencies = [ "anstyle", "predicates-core", @@ -1833,15 +1832,15 @@ dependencies = [ [[package]] name = "predicates-core" -version = "1.0.8" +version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae8177bee8e75d6846599c6b9ff679ed51e882816914eec639944d7c9aa11931" +checksum = "727e462b119fe9c93fd0eb1429a5f7647394014cf3c04ab2c0350eeb09095ffa" [[package]] name = "predicates-tree" -version = "1.0.11" +version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41b740d195ed3166cd147c8047ec98db0e22ec019eb8eeb76d343b795304fb13" +checksum = "72dd2d6d381dfb73a193c7fca536518d7caee39fc8503f74e7dc0be0531b425c" dependencies = [ "predicates-core", "termtree", @@ -1849,9 +1848,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.25" +version = "0.2.27" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64d1ec885c64d0457d564db4ec299b2dae3f9c02808b8ad9c3a089c591b18033" +checksum = "483f8c21f64f3ea09fe0f30f5d48c3e8eefe5dac9129f0075f76593b4c1da705" dependencies = [ "proc-macro2", "syn", @@ -1916,9 +1915,9 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.7" +version = "0.5.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b6dfecf2c74bce2466cabf93f6664d6998a69eb21e39f4207930065b27b771f" +checksum = "03a862b389f93e68874fbf580b9de08dd02facb9a788ebadaf4a3fd33cf58834" dependencies = [ "bitflags", ] @@ -2042,15 +2041,15 @@ checksum = "c7fb8039b3032c191086b10f11f319a6e99e1e82889c5cc6046f515c9db1d497" [[package]] name = "rustix" -version = "0.38.41" +version = "0.38.43" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7f649912bc1495e167a6edee79151c84b1bad49748cb4f1f1167f459f6224f6" +checksum = "a78891ee6bf2340288408954ac787aa063d8e8817e9f53abb37c695c6d834ef6" dependencies = [ "bitflags", "errno", "libc", "linux-raw-sys", - "windows-sys 0.52.0", + "windows-sys 0.59.0", ] [[package]] @@ -2114,9 +2113,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.134" +version = "1.0.135" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d00f4175c42ee48b15416f6193a959ba3a0d67fc699a0db9ad12df9f83991c7d" +checksum = "2b0d7ba2887406110130a978386c4e1befb98c674b4fba677954e4db976630d9" dependencies = [ "itoa", "memchr", @@ -2250,21 +2249,11 @@ dependencies = [ "der", ] -[[package]] -name = "sqlformat" -version = "0.2.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bba3a93db0cc4f7bdece8bb09e77e2e785c20bfebf79eb8340ed80708048790" -dependencies = [ - "nom", - "unicode_categories", -] - [[package]] name = "sqlx" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "93334716a037193fac19df402f8571269c84a00852f6a7066b5d2616dcd64d3e" +checksum = "4410e73b3c0d8442c5f99b425d7a435b5ee0ae4167b3196771dd3f7a01be745f" dependencies = [ "sqlx-core", "sqlx-macros", @@ -2275,38 +2264,32 @@ dependencies = [ [[package]] name = "sqlx-core" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4d8060b456358185f7d50c55d9b5066ad956956fddec42ee2e8567134a8936e" +checksum = "6a007b6936676aa9ab40207cde35daab0a04b823be8ae004368c0793b96a61e0" dependencies = [ - "atoi", - "byteorder", "bytes", "chrono", "crc", "crossbeam-queue", "either", "event-listener", - "futures-channel", "futures-core", "futures-intrusive", "futures-io", "futures-util", - "hashbrown 0.14.5", + "hashbrown", "hashlink", - "hex", "indexmap", "log", "memchr", "once_cell", - "paste", "percent-encoding", "serde", "serde_json", "sha2 0.10.8", "smallvec", - "sqlformat", - "thiserror 1.0.69", + "thiserror 2.0.9", "tokio", "tokio-stream", "tracing", @@ -2315,9 +2298,9 @@ dependencies = [ [[package]] name = "sqlx-macros" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cac0692bcc9de3b073e8d747391827297e075c7710ff6276d9f7a1f3d58c6657" +checksum = "3112e2ad78643fef903618d78cf0aec1cb3134b019730edb039b69eaf531f310" dependencies = [ "proc-macro2", "quote", @@ -2328,9 +2311,9 @@ dependencies = [ [[package]] name = "sqlx-macros-core" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1804e8a7c7865599c9c79be146dc8a9fd8cc86935fa641d3ea58e5f0688abaa5" +checksum = "4e9f90acc5ab146a99bf5061a7eb4976b573f560bc898ef3bf8435448dd5e7ad" dependencies = [ "dotenvy", "either", @@ -2354,9 +2337,9 @@ dependencies = [ [[package]] name = "sqlx-mysql" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "64bb4714269afa44aef2755150a0fc19d756fb580a67db8885608cf02f47d06a" +checksum = "4560278f0e00ce64938540546f59f590d60beee33fffbd3b9cd47851e5fff233" dependencies = [ "atoi", "base64", @@ -2390,16 +2373,16 @@ dependencies = [ "smallvec", "sqlx-core", "stringprep", - "thiserror 1.0.69", + "thiserror 2.0.9", "tracing", "whoami", ] [[package]] name = "sqlx-postgres" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6fa91a732d854c5d7726349bb4bb879bb9478993ceb764247660aee25f67c2f8" +checksum = "c5b98a57f363ed6764d5b3a12bfedf62f07aa16e1856a7ddc2a0bb190a959613" dependencies = [ "atoi", "base64", @@ -2411,7 +2394,6 @@ dependencies = [ "etcetera", "futures-channel", "futures-core", - "futures-io", "futures-util", "hex", "hkdf", @@ -2429,16 +2411,16 @@ dependencies = [ "smallvec", "sqlx-core", "stringprep", - "thiserror 1.0.69", + "thiserror 2.0.9", "tracing", "whoami", ] [[package]] name = "sqlx-sqlite" -version = "0.8.2" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5b2cf34a45953bfd3daaf3db0f7a7878ab9b7a6b91b422d24a7a9e4c857b680" +checksum = "f85ca71d3a5b24e64e1d08dd8fe36c6c95c339a896cc33068148906784620540" dependencies = [ "atoi", "chrono", @@ -2489,9 +2471,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.93" +version = "2.0.95" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c786062daee0d6db1132800e623df74274a0a87322d8e183338e01b3d98d058" +checksum = "46f71c0377baf4ef1cc3e3402ded576dccc315800fbc62dfc7fe04b009773b4a" dependencies = [ "proc-macro2", "quote", @@ -2523,12 +2505,13 @@ checksum = "42a4d50cdb458045afc8131fd91b64904da29548bcb63c7236e0844936c13078" [[package]] name = "tempfile" -version = "3.14.0" +version = "3.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28cce251fcbc87fac86a866eeb0d6c2d536fc16d06f184bb61aeae11aa4cee0c" +checksum = "9a8a559c81686f576e8cd0290cd2a24a2a9ad80c98b3478856500fcbd7acd704" dependencies = [ "cfg-if", "fastrand", + "getrandom", "once_cell", "rustix", "windows-sys 0.59.0", @@ -2545,9 +2528,9 @@ dependencies = [ [[package]] name = "termtree" -version = "0.4.1" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" +checksum = "8f50febec83f5ee1df3015341d8bd429f2d1cc62bcba7ea2076759d315084683" [[package]] name = "thiserror" @@ -2642,9 +2625,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.8.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" +checksum = "022db8904dfa342efe721985167e9fcd16c29b226db4397ed752a761cfce81e8" dependencies = [ "tinyvec_macros", ] @@ -2657,9 +2640,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.42.0" +version = "1.43.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5cec9b21b0450273377fc97bd4c33a8acffc8c996c987a7c5b319a0083707551" +checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e" dependencies = [ "backtrace", "bytes", @@ -2673,9 +2656,9 @@ dependencies = [ [[package]] name = "tokio-macros" -version = "2.4.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" +checksum = "6e06d43f1345a3bcd39f6a56dbb7dcab2ba47e68e8ac134855e7e2bdbaf8cab8" dependencies = [ "proc-macro2", "quote", @@ -2907,15 +2890,15 @@ checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "unicase" -version = "2.8.0" +version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e51b68083f157f853b6379db119d1c1be0e6e4dec98101079dec41f6f5cf6df" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" [[package]] name = "unicode-bidi" -version = "0.3.17" +version = "0.3.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ab17db44d7388991a428b2ee655ce0c212e862eff1768a455c58f9aad6e7893" +checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" [[package]] name = "unicode-ident" @@ -2950,12 +2933,6 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" -[[package]] -name = "unicode_categories" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "39ec24b3121d976906ece63c9daad25b85969647682eee313cb5779fdd69e14e" - [[package]] name = "url" version = "2.5.4" @@ -3269,9 +3246,9 @@ checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" -version = "0.6.20" +version = "0.6.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" +checksum = "39281189af81c07ec09db316b302a3e67bf9bd7cbf6c820b50e35fee9c2fa980" dependencies = [ "memchr", ] diff --git a/Cargo.toml b/Cargo.toml index cb567feb..ba1f7950 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -54,6 +54,7 @@ indexmap = "2" mime_guess = { version = "2", default-features = false } mockall = "0.13" password-auth = { version = "1.1.0-pre.1", default-features = false } +petgraph = { version = "0.7", default-features = false } pin-project-lite = "0.2" prettyplease = "0.2" proc-macro-crate = "3" diff --git a/flareon-cli/Cargo.toml b/flareon-cli/Cargo.toml index 677f554d..6b64b5ed 100644 --- a/flareon-cli/Cargo.toml +++ b/flareon-cli/Cargo.toml @@ -18,6 +18,7 @@ darling.workspace = true flareon.workspace = true flareon_codegen = { workspace = true, features = ["symbol-resolver"] } glob.workspace = true +petgraph.workspace = true prettyplease.workspace = true proc-macro2 = { workspace = true, features = ["span-locations"] } quote.workspace = true diff --git a/flareon-cli/src/main.rs b/flareon-cli/src/main.rs index 3e1f8e66..02c6c038 100644 --- a/flareon-cli/src/main.rs +++ b/flareon-cli/src/main.rs @@ -1,3 +1,5 @@ +extern crate core; + mod migration_generator; mod utils; @@ -25,9 +27,6 @@ enum Commands { /// Path to the crate directory to generate migrations for (default: /// current directory) path: Option, - /// Name of the app to use in the migration (default: crate name) - #[arg(long)] - app_name: Option, /// Directory to write the migrations to (default: migrations/ directory /// in the crate's src/ directory) #[arg(long)] @@ -47,16 +46,9 @@ fn main() -> anyhow::Result<()> { .init(); match cli.command { - Commands::MakeMigrations { - path, - app_name, - output_dir, - } => { + Commands::MakeMigrations { path, output_dir } => { let path = path.unwrap_or_else(|| PathBuf::from(".")); - let options = MigrationGeneratorOptions { - app_name, - output_dir, - }; + let options = MigrationGeneratorOptions { output_dir }; make_migrations(&path, options).with_context(|| "unable to create migrations")?; } } diff --git a/flareon-cli/src/migration_generator.rs b/flareon-cli/src/migration_generator.rs index a8c13f7c..bb8cb471 100644 --- a/flareon-cli/src/migration_generator.rs +++ b/flareon-cli/src/migration_generator.rs @@ -10,11 +10,13 @@ use cargo_toml::Manifest; use darling::FromMeta; use flareon::db::migrations::{DynMigration, MigrationEngine}; use flareon_codegen::model::{Field, Model, ModelArgs, ModelOpts, ModelType}; -use flareon_codegen::symbol_resolver::{ModulePath, SymbolResolver, VisibleSymbol}; +use flareon_codegen::symbol_resolver::SymbolResolver; +use petgraph::graph::{DiGraph, NodeIndex}; +use petgraph::visit::EdgeRef; use proc_macro2::TokenStream; -use quote::{format_ident, quote}; -use syn::{parse_quote, Attribute, Meta, UseTree}; -use tracing::{debug, info, warn}; +use quote::{format_ident, quote, ToTokens}; +use syn::{parse_quote, Attribute, Meta}; +use tracing::{debug, info, trace}; use crate::utils::find_cargo_toml; @@ -46,7 +48,6 @@ pub fn make_migrations(path: &Path, options: MigrationGeneratorOptions) -> anyho #[derive(Debug, Clone, Default)] pub struct MigrationGeneratorOptions { - pub app_name: Option, pub output_dir: Option, } @@ -75,7 +76,7 @@ impl MigrationGenerator { let source_files = self.get_source_files()?; if let Some(migration) = self.generate_migrations_to_write(source_files)? { - self.write_migration(migration)?; + self.write_migration(&migration)?; } Ok(()) @@ -109,12 +110,19 @@ impl MigrationGenerator { let migration_name = migration_processor.next_migration_name()?; let dependencies = migration_processor.base_dependencies(); - Ok(Some(GeneratedMigration { + let mut migration = GeneratedMigration { migration_name, modified_models, dependencies, operations, - })) + }; + migration.remove_cycles(); + migration.toposort_operations(); + migration + .dependencies + .extend(migration.get_foreign_key_dependencies(&self.crate_name)); + + Ok(Some(migration)) } } @@ -187,6 +195,8 @@ impl MigrationGenerator { }: SourceFile, app_state: &mut AppState, ) -> anyhow::Result<()> { + trace!("Processing file: {:?}", &path); + let symbol_resolver = SymbolResolver::from_file(&file, &path); let mut migration_models = Vec::new(); @@ -201,8 +211,20 @@ impl MigrationGenerator { ModelInSource::from_item(item, &args, &symbol_resolver)?; match args.model_type { - ModelType::Application => app_state.models.push(model_in_source), - ModelType::Migration => migration_models.push(model_in_source), + ModelType::Application => { + trace!( + "Found an Application model: {}", + model_in_source.model.name.to_string() + ); + app_state.models.push(model_in_source); + } + ModelType::Migration => { + trace!( + "Found a Migration model: {}", + model_in_source.model.name.to_string() + ); + migration_models.push(model_in_source); + } ModelType::Internal => {} } @@ -297,6 +319,7 @@ impl MigrationGenerator { fn make_create_model_operation(app_model: &ModelInSource) -> DynOperation { DynOperation::CreateModel { table_name: app_model.model.table_name.clone(), + model_ty: app_model.model.resolved_ty.clone().expect("resolved_ty is expected to be present when parsing the entire file with symbol resolver"), fields: app_model.model.fields.clone(), } } @@ -320,6 +343,7 @@ impl MigrationGenerator { } let mut all_field_names: Vec<_> = all_field_names.into_iter().collect(); + // sort to ensure deterministic order all_field_names.sort(); let mut operations = Vec::new(); @@ -357,6 +381,7 @@ impl MigrationGenerator { fn make_add_field_operation(app_model: &ModelInSource, field: &Field) -> DynOperation { DynOperation::AddField { table_name: app_model.model.table_name.clone(), + model_ty: app_model.model.resolved_ty.clone().expect("resolved_ty is expected to be present when parsing the entire file with symbol resolver"), field: field.clone(), } } @@ -401,7 +426,7 @@ impl MigrationGenerator { .map(|dependency| dependency.repr()) .collect(); - let app_name = self.options.app_name.as_ref().unwrap_or(&self.crate_name); + let app_name = &self.crate_name; let migration_name = &migration.migration_name; let migration_def = quote! { #[derive(Debug, Copy, Clone)] @@ -431,7 +456,7 @@ impl MigrationGenerator { Self::generate_migration(migration_def, models_def) } - fn write_migration(&self, migration: MigrationAsSource) -> anyhow::Result<()> { + fn write_migration(&self, migration: &MigrationAsSource) -> anyhow::Result<()> { let src_path = self .options .output_dir @@ -640,6 +665,160 @@ pub struct GeneratedMigration { pub operations: Vec, } +impl GeneratedMigration { + fn get_foreign_key_dependencies(&self, crate_name: &str) -> Vec { + let create_ops = self.get_create_ops_map(); + let ops_adding_foreign_keys = self.get_ops_adding_foreign_keys(); + + let mut dependencies = Vec::new(); + for (_index, dependency_ty) in &ops_adding_foreign_keys { + if !create_ops.contains_key(dependency_ty) { + dependencies.push(DynDependency::Model { + model_type: dependency_ty.clone(), + }); + } + } + + dependencies + } + + fn remove_cycles(&mut self) { + let graph = self.construct_dependency_graph(); + + let cycle_edges = petgraph::algo::feedback_arc_set::greedy_feedback_arc_set(&graph); + for edge_id in cycle_edges { + let (from, to) = graph.edge_endpoints(edge_id.id()).unwrap(); + + let to_op = self.operations[to.index()].clone(); + let from_op = &mut self.operations[from.index()]; + debug!( + "Removing cycle by removing operation {:?} that depends on {:?}", + from_op, to_op + ); + + let to_add = Self::remove_dependency(from_op, &to_op); + self.operations.extend(to_add); + } + } + + #[must_use] + fn remove_dependency(from: &mut DynOperation, to: &DynOperation) -> Vec { + match from { + DynOperation::CreateModel { + table_name, + model_ty, + fields, + } => { + let to_type = match to { + DynOperation::CreateModel { model_ty, .. } => model_ty, + DynOperation::AddField { .. } => { + unreachable!("AddField operation shouldn't be a dependency of CreateModel because it doesn't create a new model") + } + }; + trace!( + "Removing foreign keys from {} to {}", + model_ty.to_token_stream().to_string(), + to_type.into_token_stream().to_string() + ); + + let mut result = Vec::new(); + let (fields_to_remove, fields_to_retain): (Vec<_>, Vec<_>) = std::mem::take(fields) + .into_iter() + .partition(|field| is_field_foreign_key_to(field, to_type)); + *fields = fields_to_retain; + + for field in fields_to_remove { + result.push(DynOperation::AddField { + table_name: table_name.clone(), + model_ty: model_ty.clone(), + field, + }); + } + + result + } + DynOperation::AddField { .. } => { + // AddField only links two already existing models together, so + // removing it shouldn't ever affect whether a graph is cyclic + unreachable!("AddField operation should never create cycles") + } + } + } + + fn toposort_operations(&mut self) { + let graph = self.construct_dependency_graph(); + + let sorted = petgraph::algo::toposort(&graph, None) + .expect("cycles shouldn't exist after removing them"); + let mut sorted = sorted + .into_iter() + .map(petgraph::prelude::NodeIndex::index) + .collect::>(); + flareon::__private::apply_permutation(&mut self.operations, &mut sorted); + } + + #[must_use] + fn construct_dependency_graph(&mut self) -> DiGraph { + let create_ops = self.get_create_ops_map(); + let ops_adding_foreign_keys = self.get_ops_adding_foreign_keys(); + + let mut graph = DiGraph::with_capacity(self.operations.len(), 0); + + for i in 0..self.operations.len() { + graph.add_node(i); + } + for (i, dependency_ty) in &ops_adding_foreign_keys { + if let Some(&dependency) = create_ops.get(dependency_ty) { + graph.add_edge(NodeIndex::new(dependency), NodeIndex::new(*i), ()); + } + } + + graph + } + + /// Return a map of (resolved) model types to the index of the + /// operation that creates given model. + #[must_use] + fn get_create_ops_map(&self) -> HashMap { + self.operations + .iter() + .enumerate() + .filter_map(|(i, op)| match op { + DynOperation::CreateModel { model_ty, .. } => Some((model_ty.clone(), i)), + _ => None, + }) + .collect() + } + + /// Return a list of operations that add foreign keys as tuples of + /// operation index and the type of the model that foreign key points to. + #[must_use] + fn get_ops_adding_foreign_keys(&self) -> Vec<(usize, syn::Type)> { + self.operations + .iter() + .enumerate() + .flat_map(|(i, op)| match op { + DynOperation::CreateModel { fields, .. } => fields + .iter() + .filter_map(foreign_key_for_field) + .map(|to_model| (i, to_model)) + .collect::>(), + DynOperation::AddField { + field, model_ty, .. + } => { + let mut ops = vec![(i, model_ty.clone())]; + + if let Some(to_type) = foreign_key_for_field(field) { + ops.push((i, to_type)); + } + + ops + } + }) + .collect() + } +} + /// A migration represented as a generated and ready to write Rust source code. #[derive(Debug, Clone)] pub struct MigrationAsSource { @@ -678,13 +857,25 @@ impl Repr for Field { }; if self .auto_value - .expect("auto_value is expected to be present when parsing the entire file") + .expect("auto_value is expected to be present when parsing the entire file with symbol resolver") { tokens = quote! { #tokens.auto() } } if self.primary_key { tokens = quote! { #tokens.primary_key() } } + if let Some(fk_spec) = self.foreign_key.clone().expect("foreign_key is expected to be present when parsing the entire file with symbol resolver") { + let to_model = &fk_spec.to_model; + + tokens = quote! { + #tokens.foreign_key( + <#to_model as ::flareon::db::Model>::TABLE_NAME, + <#to_model as ::flareon::db::Model>::PRIMARY_KEY_NAME, + ::flareon::db::ForeignKeyOnDeletePolicy::Restrict, + ::flareon::db::ForeignKeyOnUpdatePolicy::Restrict, + ) + } + } tokens = quote! { #tokens.set_null(<#ty as ::flareon::db::DatabaseField>::NULLABLE) }; if self.unique { tokens = quote! { #tokens.unique() } @@ -725,7 +916,7 @@ impl DynMigration for Migration { #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum DynDependency { Migration { app: String, migration: String }, - Model { app: String, model_name: String }, + Model { model_type: syn::Type }, } impl Repr for DynDependency { @@ -736,9 +927,12 @@ impl Repr for DynDependency { ::flareon::db::migrations::MigrationDependency::migration(#app, #migration) } } - Self::Model { app, model_name } => { + Self::Model { model_type } => { quote! { - ::flareon::db::migrations::MigrationDependency::model(#app, #model_name) + ::flareon::db::migrations::MigrationDependency::model( + <#model_type as ::flareon::db::Model>::APP_NAME, + <#model_type as ::flareon::db::Model>::TABLE_NAME + ) } } } @@ -753,27 +947,50 @@ impl Repr for DynDependency { pub enum DynOperation { CreateModel { table_name: String, + model_ty: syn::Type, fields: Vec, }, AddField { table_name: String, + model_ty: syn::Type, field: Field, }, } -impl DynOperation { - fn foreign_keys_added(&self) -> Vec<&syn::Type> { - match self { - DynOperation::CreateModel { fields, .. } => {} - DynOperation::AddField { field, .. } => {} - } +/// Returns whether given [`Field`] is a foreign key to given type. +fn is_field_foreign_key_to(field: &Field, ty: &syn::Type) -> bool { + foreign_key_for_field(field).is_some_and(|to_model| &to_model == ty) +} + +/// Returns the type of the model that the given field is a foreign key to. +/// Returns [`None`] if the field is not a foreign key. +fn foreign_key_for_field(field: &Field) -> Option { + match field.foreign_key.clone().expect( + "foreign_key is expected to be present when parsing the entire file with symbol resolver", + ) { + None => None, + Some(foreign_key_spec) => Some(foreign_key_spec.to_model), } } impl Repr for DynOperation { fn repr(&self) -> TokenStream { match self { - Self::CreateModel { table_name, fields } => { + Self::CreateModel { + table_name, + model_ty, + fields, + .. + } => { + let model_name = match model_ty { + syn::Type::Path(syn::TypePath { path, .. }) => path + .segments + .last() + .expect("TypePath must have at least one segment") + .ident + .to_string(), + _ => unreachable!("model_ty is expected to be a TypePath"), + }; let fields = fields.iter().map(Repr::repr).collect::>(); quote! { ::flareon::db::migrations::Operation::create_model() @@ -784,7 +1001,9 @@ impl Repr for DynOperation { .build() } } - Self::AddField { table_name, field } => { + Self::AddField { + table_name, field, .. + } => { let field = field.repr(); quote! { ::flareon::db::migrations::Operation::add_field() @@ -835,6 +1054,9 @@ impl Error for ParsingError {} #[cfg(test)] mod tests { + use flareon_codegen::maybe_unknown::MaybeUnknown; + use flareon_codegen::model::ForeignKeySpec; + use super::*; #[test] @@ -873,4 +1095,260 @@ mod tests { }] ); } + + #[test] + fn toposort_operations() { + let mut migration = GeneratedMigration { + migration_name: "test_migration".to_string(), + modified_models: vec![], + dependencies: vec![], + operations: vec![ + DynOperation::AddField { + table_name: "table2".to_string(), + model_ty: parse_quote!(Table2), + field: Field { + field_name: format_ident!("field1"), + column_name: "field1".to_string(), + ty: parse_quote!(i32), + auto_value: MaybeUnknown::Known(false), + primary_key: false, + unique: false, + foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { + to_model: parse_quote!(Table1), + })), + }, + }, + DynOperation::CreateModel { + table_name: "table1".to_string(), + model_ty: parse_quote!(Table1), + fields: vec![], + }, + ], + }; + + migration.toposort_operations(); + + assert_eq!(migration.operations.len(), 2); + if let DynOperation::CreateModel { table_name, .. } = &migration.operations[0] { + assert_eq!(table_name, "table1"); + } else { + panic!("Expected CreateModel operation"); + } + if let DynOperation::AddField { table_name, .. } = &migration.operations[1] { + assert_eq!(table_name, "table2"); + } else { + panic!("Expected AddField operation"); + } + } + + #[test] + fn remove_cycles() { + let mut migration = GeneratedMigration { + migration_name: "test_migration".to_string(), + modified_models: vec![], + dependencies: vec![], + operations: vec![ + DynOperation::CreateModel { + table_name: "table1".to_string(), + model_ty: parse_quote!(Table1), + fields: vec![Field { + field_name: format_ident!("field1"), + column_name: "field1".to_string(), + ty: parse_quote!(ForeignKey), + auto_value: MaybeUnknown::Known(false), + primary_key: false, + unique: false, + foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { + to_model: parse_quote!(Table2), + })), + }], + }, + DynOperation::CreateModel { + table_name: "table2".to_string(), + model_ty: parse_quote!(Table2), + fields: vec![Field { + field_name: format_ident!("field1"), + column_name: "field1".to_string(), + ty: parse_quote!(ForeignKey), + auto_value: MaybeUnknown::Known(false), + primary_key: false, + unique: false, + foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { + to_model: parse_quote!(Table1), + })), + }], + }, + ], + }; + + migration.remove_cycles(); + + assert_eq!(migration.operations.len(), 3); + if let DynOperation::CreateModel { + table_name, fields, .. + } = &migration.operations[0] + { + assert_eq!(table_name, "table1"); + assert!(!fields.is_empty()); + } else { + panic!("Expected CreateModel operation"); + } + if let DynOperation::CreateModel { + table_name, fields, .. + } = &migration.operations[1] + { + assert_eq!(table_name, "table2"); + assert!(fields.is_empty()); + } else { + panic!("Expected CreateModel operation"); + } + if let DynOperation::AddField { table_name, .. } = &migration.operations[2] { + assert_eq!(table_name, "table2"); + } else { + panic!("Expected AddField operation"); + } + } + + #[test] + fn remove_dependency() { + let mut create_model_op = DynOperation::CreateModel { + table_name: "table1".to_string(), + model_ty: parse_quote!(Table1), + fields: vec![Field { + field_name: format_ident!("field1"), + column_name: "field1".to_string(), + ty: parse_quote!(ForeignKey), + auto_value: MaybeUnknown::Known(false), + primary_key: false, + unique: false, + foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { + to_model: parse_quote!(Table2), + })), + }], + }; + + let add_field_op = DynOperation::CreateModel { + table_name: "table2".to_string(), + model_ty: parse_quote!(Table2), + fields: vec![], + }; + + let additional_ops = + GeneratedMigration::remove_dependency(&mut create_model_op, &add_field_op); + + match create_model_op { + DynOperation::CreateModel { fields, .. } => { + assert_eq!(fields.len(), 0); + } + _ => { + panic!("Expected from operation not to change type"); + } + } + assert_eq!(additional_ops.len(), 1); + if let DynOperation::AddField { table_name, .. } = &additional_ops[0] { + assert_eq!(table_name, "table1"); + } else { + panic!("Expected AddField operation"); + } + } + + #[test] + fn get_foreign_key_dependencies_no_foreign_keys() { + let migration = GeneratedMigration { + migration_name: "test_migration".to_string(), + modified_models: vec![], + dependencies: vec![], + operations: vec![DynOperation::CreateModel { + table_name: "table1".to_string(), + model_ty: parse_quote!(Table1), + fields: vec![], + }], + }; + + let external_dependencies = migration.get_foreign_key_dependencies("my_crate"); + assert!(external_dependencies.is_empty()); + } + + #[test] + fn get_foreign_key_dependencies_with_foreign_keys() { + let migration = GeneratedMigration { + migration_name: "test_migration".to_string(), + modified_models: vec![], + dependencies: vec![], + operations: vec![DynOperation::CreateModel { + table_name: "table1".to_string(), + model_ty: parse_quote!(Table1), + fields: vec![Field { + field_name: format_ident!("field1"), + column_name: "field1".to_string(), + ty: parse_quote!(ForeignKey), + auto_value: MaybeUnknown::Known(false), + primary_key: false, + unique: false, + foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { + to_model: parse_quote!(crate::Table2), + })), + }], + }], + }; + + let external_dependencies = migration.get_foreign_key_dependencies("my_crate"); + assert_eq!(external_dependencies.len(), 1); + assert_eq!( + external_dependencies[0], + DynDependency::Model { + model_type: parse_quote!(crate::Table2), + } + ); + } + + #[test] + fn get_foreign_key_dependencies_with_multiple_foreign_keys() { + let migration = GeneratedMigration { + migration_name: "test_migration".to_string(), + modified_models: vec![], + dependencies: vec![], + operations: vec![ + DynOperation::CreateModel { + table_name: "table1".to_string(), + model_ty: parse_quote!(Table1), + fields: vec![Field { + field_name: format_ident!("field1"), + column_name: "field1".to_string(), + ty: parse_quote!(ForeignKey), + auto_value: MaybeUnknown::Known(false), + primary_key: false, + unique: false, + foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { + to_model: parse_quote!(my_crate::Table2), + })), + }], + }, + DynOperation::CreateModel { + table_name: "table3".to_string(), + model_ty: parse_quote!(Table3), + fields: vec![Field { + field_name: format_ident!("field2"), + column_name: "field2".to_string(), + ty: parse_quote!(ForeignKey), + auto_value: MaybeUnknown::Known(false), + primary_key: false, + unique: false, + foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { + to_model: parse_quote!(crate::Table4), + })), + }], + }, + ], + }; + + let external_dependencies = migration.get_foreign_key_dependencies("my_crate"); + assert_eq!(external_dependencies.len(), 2); + assert!(external_dependencies.contains(&DynDependency::Model { + model_type: parse_quote!(my_crate::Table2), + })); + assert!(external_dependencies.contains(&DynDependency::Model { + model_type: parse_quote!(crate::Table4), + })); + } } diff --git a/flareon-cli/tests/migration_generator.rs b/flareon-cli/tests/migration_generator.rs index 41ce6e5a..43509e76 100644 --- a/flareon-cli/tests/migration_generator.rs +++ b/flareon-cli/tests/migration_generator.rs @@ -1,8 +1,10 @@ use std::path::PathBuf; use flareon_cli::migration_generator::{ - DynOperation, MigrationAsSource, MigrationGenerator, MigrationGeneratorOptions, SourceFile, + DynDependency, DynOperation, MigrationAsSource, MigrationGenerator, MigrationGeneratorOptions, + SourceFile, }; +use syn::parse_quote; /// Test that the migration generator can generate a "create model" migration /// for a given model that has an expected state. @@ -19,30 +21,138 @@ fn create_model_state_test() { assert_eq!(migration.migration_name, "m_0001_initial"); assert!(migration.dependencies.is_empty()); - if let DynOperation::CreateModel { table_name, fields } = &migration.operations[0] { - assert_eq!(table_name, "my_model"); - assert_eq!(fields.len(), 3); - - let field = &fields[0]; - assert_eq!(field.column_name, "id"); - assert!(field.primary_key); - assert!(field.auto_value.unwrap()); - assert!(field.foreign_key.clone().unwrap().is_none()); - - let field = &fields[1]; - assert_eq!(field.column_name, "field_1"); - assert!(!field.primary_key); - assert!(!field.auto_value.unwrap()); - assert!(field.foreign_key.clone().unwrap().is_none()); - - let field = &fields[2]; - assert_eq!(field.column_name, "field_2"); - assert!(!field.primary_key); - assert!(!field.auto_value.unwrap()); - assert!(field.foreign_key.clone().unwrap().is_none()); - } else { - panic!("Expected a create model operation"); - } + + let (table_name, fields) = unwrap_create_model(&migration.operations[0]); + assert_eq!(table_name, "parent"); + assert_eq!(fields.len(), 1); + + let (table_name, fields) = unwrap_create_model(&migration.operations[1]); + assert_eq!(table_name, "my_model"); + assert_eq!(fields.len(), 4); + + let field = &fields[0]; + assert_eq!(field.column_name, "id"); + assert!(field.primary_key); + assert!(field.auto_value.unwrap()); + assert!(field.foreign_key.clone().unwrap().is_none()); + + let field = &fields[1]; + assert_eq!(field.column_name, "field_1"); + assert!(!field.primary_key); + assert!(!field.auto_value.unwrap()); + assert!(field.foreign_key.clone().unwrap().is_none()); + + let field = &fields[2]; + assert_eq!(field.column_name, "field_2"); + assert!(!field.primary_key); + assert!(!field.auto_value.unwrap()); + assert!(field.foreign_key.clone().unwrap().is_none()); + + let field = &fields[3]; + assert_eq!(field.column_name, "parent"); + assert!(!field.primary_key); + assert!(!field.auto_value.unwrap()); + assert!(field.foreign_key.clone().unwrap().is_some()); +} + +#[test] +fn create_models_foreign_key() { + let mut generator = test_generator(); + let src = include_str!("migration_generator/foreign_key.rs"); + let source_files = vec![SourceFile::parse(PathBuf::from("main.rs"), src).unwrap()]; + + let migration = generator + .generate_migrations(source_files) + .unwrap() + .unwrap(); + + assert_eq!(migration.dependencies.len(), 0); + assert_eq!(migration.operations.len(), 2); + + // Parent must be created before Child + let (table_name, fields) = unwrap_create_model(&migration.operations[0]); + assert_eq!(table_name, "parent"); + assert_eq!(fields.len(), 1); + + let (table_name, fields) = unwrap_create_model(&migration.operations[1]); + assert_eq!(table_name, "child"); + assert_eq!(fields.len(), 2); + + let field = &fields[0]; + assert_eq!(field.column_name, "id"); + assert!(field.primary_key); + assert!(field.auto_value.unwrap()); + assert!(field.foreign_key.clone().unwrap().is_none()); + + let field = &fields[1]; + assert_eq!(field.column_name, "parent"); + assert!(!field.primary_key); + assert!(!field.auto_value.unwrap()); + assert!(field.foreign_key.clone().unwrap().is_some()); +} + +#[test] +fn create_models_foreign_key_cycle() { + let mut generator = test_generator(); + let src = include_str!("migration_generator/foreign_key_cycle.rs"); + let source_files = vec![SourceFile::parse(PathBuf::from("main.rs"), src).unwrap()]; + + let migration = generator + .generate_migrations(source_files) + .unwrap() + .unwrap(); + + assert_eq!(migration.dependencies.len(), 0); + assert_eq!(migration.operations.len(), 3); + + // Parent must be created before Child + let (table_name, fields) = unwrap_create_model(&migration.operations[0]); + assert_eq!(table_name, "parent"); + assert_eq!(fields.len(), 1); + + let (table_name, fields) = unwrap_create_model(&migration.operations[1]); + assert_eq!(table_name, "child"); + assert_eq!(fields.len(), 2); + + let (table_name, field) = unwrap_add_field(&migration.operations[2]); + assert_eq!(table_name, "parent"); + assert_eq!(field.field_name, "child"); +} + +#[test] +fn create_models_foreign_two_migrations() { + let mut generator = test_generator(); + + let src = include_str!("migration_generator/foreign_key_two_migrations/step_1.rs"); + let source_files = vec![SourceFile::parse(PathBuf::from("main.rs"), src).unwrap()]; + let migration_file = generator + .generate_migrations_to_write(source_files) + .unwrap() + .unwrap(); + + let src = include_str!("migration_generator/foreign_key_two_migrations/step_2.rs"); + let source_files = vec![ + SourceFile::parse(PathBuf::from("main.rs"), src).unwrap(), + SourceFile::parse(PathBuf::from(&migration_file.name), &migration_file.content).unwrap(), + ]; + let migration = generator + .generate_migrations(source_files) + .unwrap() + .unwrap(); + + assert_eq!(migration.dependencies.len(), 2); + assert!(migration.dependencies.contains(&DynDependency::Migration { + app: "my_crate".to_string(), + migration: "m_0001_initial".to_string() + })); + assert!(migration.dependencies.contains(&DynDependency::Model { + model_type: parse_quote!(crate::Parent), + })); + + assert_eq!(migration.operations.len(), 1); + + let (table_name, _fields) = unwrap_create_model(&migration.operations[0]); + assert_eq!(table_name, "child"); } /// Test that the migration generator can generate a "create model" migration @@ -88,3 +198,25 @@ fn test_generator() -> MigrationGenerator { MigrationGeneratorOptions::default(), ) } + +fn unwrap_create_model(op: &DynOperation) -> (&str, Vec) { + if let DynOperation::CreateModel { + table_name, fields, .. + } = op + { + (table_name, fields.clone()) + } else { + panic!("expected create model operation"); + } +} + +fn unwrap_add_field(op: &DynOperation) -> (&str, flareon_codegen::model::Field) { + if let DynOperation::AddField { + table_name, field, .. + } = op + { + (table_name, field.clone()) + } else { + panic!("expected create model operation"); + } +} diff --git a/flareon-cli/tests/migration_generator/create_model.rs b/flareon-cli/tests/migration_generator/create_model.rs index fd19eab5..707503e5 100644 --- a/flareon-cli/tests/migration_generator/create_model.rs +++ b/flareon-cli/tests/migration_generator/create_model.rs @@ -1,12 +1,20 @@ -use flareon::db::{model, Auto, LimitedString}; +use flareon::db::{model, Auto, ForeignKey, LimitedString}; pub const FIELD_LEN: u32 = 64; +#[derive(Debug)] +#[model] +struct Parent { + id: Auto, +} + +#[derive(Debug)] #[model] struct MyModel { id: Auto, field_1: String, field_2: LimitedString, + parent: ForeignKey, } fn main() {} diff --git a/flareon-cli/tests/migration_generator/foreign_key.rs b/flareon-cli/tests/migration_generator/foreign_key.rs new file mode 100644 index 00000000..c6c7def0 --- /dev/null +++ b/flareon-cli/tests/migration_generator/foreign_key.rs @@ -0,0 +1,14 @@ +use flareon::db::{model, Auto, ForeignKey}; + +#[model] +struct Child { + id: Auto, + parent: ForeignKey, +} + +#[model] +struct Parent { + id: Auto, +} + +fn main() {} diff --git a/flareon-cli/tests/migration_generator/foreign_key_cycle.rs b/flareon-cli/tests/migration_generator/foreign_key_cycle.rs new file mode 100644 index 00000000..e197ef34 --- /dev/null +++ b/flareon-cli/tests/migration_generator/foreign_key_cycle.rs @@ -0,0 +1,15 @@ +use flareon::db::{model, Auto, ForeignKey}; + +#[model] +struct Child { + id: Auto, + parent: ForeignKey, +} + +#[model] +struct Parent { + id: Auto, + child: ForeignKey, +} + +fn main() {} diff --git a/flareon-cli/tests/migration_generator/foreign_key_two_migrations/step_1.rs b/flareon-cli/tests/migration_generator/foreign_key_two_migrations/step_1.rs new file mode 100644 index 00000000..ba5fa83e --- /dev/null +++ b/flareon-cli/tests/migration_generator/foreign_key_two_migrations/step_1.rs @@ -0,0 +1,8 @@ +use flareon::db::{model, Auto, ForeignKey}; + +#[model] +struct Parent { + id: Auto, +} + +fn main() {} diff --git a/flareon-cli/tests/migration_generator/foreign_key_two_migrations/step_2.rs b/flareon-cli/tests/migration_generator/foreign_key_two_migrations/step_2.rs new file mode 100644 index 00000000..c6c7def0 --- /dev/null +++ b/flareon-cli/tests/migration_generator/foreign_key_two_migrations/step_2.rs @@ -0,0 +1,14 @@ +use flareon::db::{model, Auto, ForeignKey}; + +#[model] +struct Child { + id: Auto, + parent: ForeignKey, +} + +#[model] +struct Parent { + id: Auto, +} + +fn main() {} diff --git a/flareon-codegen/Cargo.toml b/flareon-codegen/Cargo.toml index 97e7437e..88335930 100644 --- a/flareon-codegen/Cargo.toml +++ b/flareon-codegen/Cargo.toml @@ -11,14 +11,14 @@ workspace = true [dependencies] convert_case.workspace = true darling.workspace = true -log = { workspace = true, optional = true } proc-macro2.workspace = true quote.workspace = true syn.workspace = true +tracing = { workspace = true, optional = true } [dev-dependencies] proc-macro2 = { workspace = true, features = ["span-locations"] } [features] default = [] -symbol-resolver = ["dep:log"] +symbol-resolver = ["dep:tracing"] diff --git a/flareon-codegen/src/expr.rs b/flareon-codegen/src/expr.rs index 806142f9..53187f38 100644 --- a/flareon-codegen/src/expr.rs +++ b/flareon-codegen/src/expr.rs @@ -123,7 +123,6 @@ impl Parse for MemberAccessParser { #[derive(Debug)] struct FunctionCallParser { - paren_token: syn::token::Paren, args: syn::punctuated::Punctuated, } @@ -137,8 +136,8 @@ impl FunctionCallParser { impl Parse for FunctionCallParser { fn parse(input: ParseStream) -> syn::Result { let args_content; + syn::parenthesized!(args_content in input); Ok(Self { - paren_token: syn::parenthesized!(args_content in input), args: args_content.parse_terminated(syn::Expr::parse, Token![,])?, }) } diff --git a/flareon-codegen/src/lib.rs b/flareon-codegen/src/lib.rs index a55bb032..4e9c1cc0 100644 --- a/flareon-codegen/src/lib.rs +++ b/flareon-codegen/src/lib.rs @@ -1,14 +1,14 @@ extern crate self as flareon_codegen; pub mod expr; -mod maybe_unknown; +pub mod maybe_unknown; pub mod model; #[cfg(feature = "symbol-resolver")] pub mod symbol_resolver; #[cfg(not(feature = "symbol-resolver"))] pub mod symbol_resolver { - /// Dummy SymbolResolver for use in contexts when it's not useful (e.g. + /// Dummy `SymbolResolver` for use in contexts when it's not useful (e.g. /// macros which do not have access to the entire source tree to look /// for `use` statements anyway). /// diff --git a/flareon-codegen/src/maybe_unknown.rs b/flareon-codegen/src/maybe_unknown.rs index 3b200b3e..23784356 100644 --- a/flareon-codegen/src/maybe_unknown.rs +++ b/flareon-codegen/src/maybe_unknown.rs @@ -4,7 +4,7 @@ pub enum MaybeUnknown { /// Indicates that this instance is determined to be a certain value /// (possibly [`None`] if wrapping an [`Option`]). - Determined(T), + Known(T), /// Indicates that the value is unknown. Unknown, } @@ -16,7 +16,7 @@ impl MaybeUnknown { pub fn expect(self, msg: &str) -> T { match self { - MaybeUnknown::Determined(value) => value, + MaybeUnknown::Known(value) => value, MaybeUnknown::Unknown => { panic!("{}", msg) } @@ -30,10 +30,16 @@ mod tests { #[test] fn maybe_unknown_determined() { - let value = MaybeUnknown::Determined(42); + let value = MaybeUnknown::Known(42); assert_eq!(value.unwrap(), 42); } + #[test] + fn maybe_unknown_known_none() { + let value = MaybeUnknown::Known(None::<()>); + assert!(value.unwrap().is_none()); + } + #[test] #[should_panic(expected = "called `MaybeUnknown::unwrap()` on an `Unknown` value")] fn maybe_unknown_unknown_unwrap() { @@ -43,7 +49,7 @@ mod tests { #[test] fn maybe_unknown_expect() { - let value = MaybeUnknown::Determined(42); + let value = MaybeUnknown::Known(42); assert_eq!(value.expect("value should be determined"), 42); } diff --git a/flareon-codegen/src/model.rs b/flareon-codegen/src/model.rs index 0215a0e2..d820f2bc 100644 --- a/flareon-codegen/src/model.rs +++ b/flareon-codegen/src/model.rs @@ -94,9 +94,22 @@ impl ModelOpts { let primary_key_field = self.get_primary_key_field(&fields)?; + let ty = match symbol_resolver { + Some(symbol_resolver) => { + let mut ty = syn::Type::Path(syn::TypePath { + qself: None, + path: syn::Path::from(self.ident.clone()), + }); + symbol_resolver.resolve(&mut ty); + Some(ty) + } + None => None, + }; + Ok(Model { name: self.ident.clone(), original_name, + resolved_ty: ty, model_type: args.model_type, table_name, pk_field: primary_key_field.clone(), @@ -194,8 +207,8 @@ impl FieldOpts { let (auto_value, foreign_key) = match symbol_resolver { Some(resolver) => ( - MaybeUnknown::Determined(self.find_type("flareon::db::Auto", resolver).is_some()), - MaybeUnknown::Determined( + MaybeUnknown::Known(self.find_type("flareon::db::Auto", resolver).is_some()), + MaybeUnknown::Known( self.find_type("flareon::db::ForeignKey", resolver) .map(ForeignKeySpec::try_from) .transpose()?, @@ -221,6 +234,9 @@ impl FieldOpts { pub struct Model { pub name: syn::Ident, pub original_name: String, + /// The type of the model, or [`None`] if the symbol resolver was not + /// enabled. + pub resolved_ty: Option, pub model_type: ModelType, pub table_name: String, pub pk_field: Field, @@ -244,8 +260,8 @@ pub struct Field { /// a [`SymbolResolver`]. pub auto_value: MaybeUnknown, pub primary_key: bool, - /// [`Some`] wrapped in [`MaybeUnknown::Determined`] if this field is a - /// foreign key; [`None`] wrapped in [`MaybeUnknown::Determined`] if this + /// [`Some`] wrapped in [`MaybeUnknown::Known`] if this field is a + /// foreign key; [`None`] wrapped in [`MaybeUnknown::Known`] if this /// field is determined not to be a foreign key; [`MaybeUnknown::Unknown`] /// if this `Field` instance was not resolved with a [`SymbolResolver`]. pub foreign_key: MaybeUnknown>, @@ -451,7 +467,7 @@ mod tests { assert!(FieldOpts::find_type_resolved(&input, "my_crate::MyContainer").is_some()); assert!(FieldOpts::find_type_resolved(&input, "Vec").is_some()); assert!(FieldOpts::find_type_resolved(&input, "std::string::String").is_some()); - assert!(!FieldOpts::find_type_resolved(&input, "OtherType").is_some()); + assert!(FieldOpts::find_type_resolved(&input, "OtherType").is_none()); } #[cfg(feature = "symbol-resolver")] @@ -473,7 +489,7 @@ mod tests { assert!(opts.find_type("my_crate::MyContainer", &resolver).is_some()); assert!(opts.find_type("std::string::String", &resolver).is_some()); - assert!(!opts.find_type("MyContainer", &resolver).is_none()); - assert!(!opts.find_type("String", &resolver).is_none()); + assert!(opts.find_type("MyContainer", &resolver).is_none()); + assert!(opts.find_type("String", &resolver).is_none()); } } diff --git a/flareon-codegen/src/symbol_resolver.rs b/flareon-codegen/src/symbol_resolver.rs index 94343204..1eaf5db7 100644 --- a/flareon-codegen/src/symbol_resolver.rs +++ b/flareon-codegen/src/symbol_resolver.rs @@ -1,13 +1,10 @@ -#![cfg(feature = "symbol-resolver")] - use std::collections::HashMap; use std::fmt::Display; use std::iter::FromIterator; -use std::path::{Path, PathBuf}; +use std::path::Path; -use log::warn; -use quote::quote; -use syn::{parse_quote, UseTree}; +use syn::UseTree; +use tracing::warn; #[derive(Debug, Clone, PartialEq, Eq)] pub struct SymbolResolver { @@ -28,6 +25,7 @@ impl SymbolResolver { } } + #[must_use] pub fn from_file(file: &syn::File, module_path: &Path) -> Self { let imports = Self::get_imports(file, &ModulePath::from_fs_path(module_path)); Self::new(imports) @@ -330,7 +328,8 @@ impl Display for ModulePath { #[cfg(test)] mod tests { use flareon_codegen::symbol_resolver::VisibleSymbolKind; - use quote::ToTokens; + use quote::{quote, ToTokens}; + use syn::parse_quote; use super::*; diff --git a/flareon/Cargo.toml b/flareon/Cargo.toml index 22f5475e..536f5fd1 100644 --- a/flareon/Cargo.toml +++ b/flareon/Cargo.toml @@ -44,11 +44,7 @@ time.workspace = true tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } tower = { workspace = true, features = ["util"] } tower-sessions = { workspace = true, features = ["memory-store"] } -<<<<<<< HEAD tracing.workspace = true -======= -env_logger = "0.11.5" ->>>>>>> 4db9303 (feat(orm): add foreign key support) [dev-dependencies] async-stream.workspace = true diff --git a/flareon/src/auth/db/migrations/m_0001_initial.rs b/flareon/src/auth/db/migrations/m_0001_initial.rs index 821822de..8809c085 100644 --- a/flareon/src/auth/db/migrations/m_0001_initial.rs +++ b/flareon/src/auth/db/migrations/m_0001_initial.rs @@ -3,7 +3,7 @@ #[derive(Debug, Copy, Clone)] pub(super) struct Migration; impl ::flareon::db::migrations::Migration for Migration { - const APP_NAME: &'static str = "flareon_auth"; + const APP_NAME: &'static str = "flareon"; const MIGRATION_NAME: &'static str = "m_0001_initial"; const DEPENDENCIES: &'static [::flareon::db::migrations::MigrationDependency] = &[]; const OPERATIONS: &'static [::flareon::db::migrations::Operation] = &[ diff --git a/flareon/src/db.rs b/flareon/src/db.rs index 3611ae22..3fee822a 100644 --- a/flareon/src/db.rs +++ b/flareon/src/db.rs @@ -63,9 +63,6 @@ pub enum DatabaseError { /// was not found. #[error("Error retrieving a Foreign Key from the database: record not found")] ForeignKeyNotFound, - /// Primary key could not be converted from i64 using [`TryFromI64`] trait. - #[error("Primary key could not be converted from i64")] - PrimaryKeyFromI64Error, } impl DatabaseError { @@ -331,6 +328,8 @@ pub trait ToDbValue: Send + Sync { fn to_db_value(&self) -> DbValue; } +/// A generalization of [`ToDbValue`] that can also return a marker that means a +/// value should be automatically generated by the database. pub trait ToDbFieldValue { fn to_db_field_value(&self) -> DbFieldValue; } @@ -364,7 +363,7 @@ impl DbFieldValue { pub fn expect_value(self, message: &str) -> sea_query::Value { match self { Self::Value(value) => value, - _ => panic!("{message}"), + Self::Auto => panic!("{message}"), } } } @@ -573,13 +572,15 @@ impl Database { .values( filtered_values .into_iter() - .map(|value| SimpleExpr::Value(value)) + .map(SimpleExpr::Value) .collect::>(), )? .or_default_values() .to_owned(); - if !auto_col_ids.is_empty() { + if auto_col_ids.is_empty() { + self.execute_statement(&insert_statement).await?; + } else { let row = if self.supports_returning() { insert_statement.returning(ReturningClause::Columns(auto_col_identifiers)); @@ -601,8 +602,6 @@ impl Database { ) }; data.update_from_db(row, &auto_col_ids)?; - } else { - self.execute_statement(&insert_statement).await?; } debug!("Inserted row"); @@ -750,8 +749,7 @@ impl Database { fn supports_returning(&self) -> bool { match self.inner { - DatabaseImpl::Sqlite(_) => true, - DatabaseImpl::Postgres(_) => true, + DatabaseImpl::Sqlite(_) | DatabaseImpl::Postgres(_) => true, DatabaseImpl::MySql(_) => false, } } @@ -880,26 +878,14 @@ pub struct StatementResult { impl StatementResult { /// Creates a new statement result with the given number of rows affected. #[must_use] - pub(crate) fn new(rows_affected: RowsNum) -> Self { + #[cfg(test)] + fn new(rows_affected: RowsNum) -> Self { Self { rows_affected, last_inserted_row_id: None, } } - /// Creates a new statement result with the given number of rows affected - /// and last inserted row ID. - #[must_use] - pub(crate) fn new_with_last_inserted_row_id( - rows_affected: RowsNum, - last_inserted_row_id: u64, - ) -> Self { - Self { - rows_affected, - last_inserted_row_id: Some(last_inserted_row_id), - } - } - /// Returns the number of rows affected by the query. #[must_use] pub fn rows_affected(&self) -> RowsNum { @@ -917,18 +903,63 @@ impl StatementResult { #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Deref, Display)] pub struct RowsNum(pub u64); +/// A wrapper over a value that can be either a fixed value or be automatically +/// generated by the database. +/// +/// This is primarily used for auto-incrementing primary keys. +/// +/// # Examples +/// +/// ``` +/// use flareon::db::{model, Auto, Model}; +/// # use flareon::db::migrations::{Field, Operation}; +/// # use flareon::db::{Database, Identifier}; +/// # use flareon::Result; +/// +/// #[model] +/// struct MyModel { +/// id: Auto, +/// } +/// +/// # #[tokio::main] +/// # async fn main() -> Result<()> { +/// +/// # const OPERATION: Operation = Operation::create_model() +/// # .table_name(Identifier::new("todoapp__my_model")) +/// # .fields(&[ +/// # Field::new(Identifier::new("id"), ::TYPE) +/// # .primary_key() +/// # .auto(), +/// # ]) +/// # .build(); +/// +/// let database = Database::new("sqlite::memory:").await?; +/// # OPERATION.forwards(&database).await?; +/// +/// let mut my_model = MyModel { id: Auto::auto() }; +/// my_model.save(&database).await?; +/// assert!(matches!(my_model.id, Auto::Fixed(_))); +/// +/// # Ok(()) +/// # } +/// ``` #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum Auto { + /// A fixed value. Fixed(T), + /// A value that will be automatically generated by the database. Auto, } impl Auto { + /// Creates a new `Auto` instance that is automatically generated by the + /// database. #[must_use] pub const fn auto() -> Self { Self::Auto } + /// Creates a new `Auto` instance with a fixed value. #[must_use] pub const fn fixed(value: T) -> Self { Self::Fixed(value) @@ -947,26 +978,6 @@ impl From for Auto { } } -trait TryFromI64 { - fn try_from_i64(value: i64) -> Result - where - Self: Sized; -} - -impl TryFromI64 for i64 { - fn try_from_i64(value: i64) -> Result { - Ok(value) - } -} - -impl TryFromI64 for i32 { - fn try_from_i64(value: i64) -> Result { - value - .try_into() - .map_err(|_| DatabaseError::PrimaryKeyFromI64Error) - } -} - /// A wrapper over a string that has a limited length. /// /// This type is used to represent a string that has a limited length in the diff --git a/flareon/src/db/fields.rs b/flareon/src/db/fields.rs index be504093..667180a9 100644 --- a/flareon/src/db/fields.rs +++ b/flareon/src/db/fields.rs @@ -1,7 +1,6 @@ //! `DatabaseField` implementations for common types. use flareon::db::DatabaseField; -use log::debug; #[cfg(feature = "mysql")] use crate::db::impl_mysql::MySqlValueRef; diff --git a/flareon/src/db/impl_mysql.rs b/flareon/src/db/impl_mysql.rs index d63aef9d..43aa6fbd 100644 --- a/flareon/src/db/impl_mysql.rs +++ b/flareon/src/db/impl_mysql.rs @@ -4,6 +4,7 @@ use crate::db::ColumnType; impl_sea_query_db_backend!(DatabaseMySql: sqlx::mysql::MySql, sqlx::mysql::MySqlPool, MySqlRow, MySqlValueRef, sea_query::MysqlQueryBuilder); impl DatabaseMySql { + #[allow(clippy::unused_async)] async fn init(&self) -> crate::db::Result<()> { Ok(()) } diff --git a/flareon/src/db/impl_postgres.rs b/flareon/src/db/impl_postgres.rs index eca4550e..a4ffa8bf 100644 --- a/flareon/src/db/impl_postgres.rs +++ b/flareon/src/db/impl_postgres.rs @@ -3,6 +3,7 @@ use crate::db::sea_query_db::impl_sea_query_db_backend; impl_sea_query_db_backend!(DatabasePostgres: sqlx::postgres::Postgres, sqlx::postgres::PgPool, PostgresRow, PostgresValueRef, sea_query::PostgresQueryBuilder); impl DatabasePostgres { + #[allow(clippy::unused_async)] async fn init(&self) -> crate::db::Result<()> { Ok(()) } @@ -38,7 +39,7 @@ impl DatabasePostgres { } } - fn last_inserted_row_id_for(result: &sqlx::postgres::PgQueryResult) -> Option { + fn last_inserted_row_id_for(_result: &sqlx::postgres::PgQueryResult) -> Option { None } diff --git a/flareon/src/db/impl_sqlite.rs b/flareon/src/db/impl_sqlite.rs index 705f517f..b1392058 100644 --- a/flareon/src/db/impl_sqlite.rs +++ b/flareon/src/db/impl_sqlite.rs @@ -1,5 +1,4 @@ use sea_query_binder::SqlxValues; -use sqlx::Executor; use crate::db::sea_query_db::impl_sea_query_db_backend; diff --git a/flareon/src/db/migrations.rs b/flareon/src/db/migrations.rs index 72560348..0c3422ce 100644 --- a/flareon/src/db/migrations.rs +++ b/flareon/src/db/migrations.rs @@ -10,7 +10,7 @@ use tracing::info; use crate::db::migrations::sorter::{MigrationSorter, MigrationSorterError}; use crate::db::relations::ForeignKeyOnDeletePolicy; -use crate::db::{model, query, ColumnType, Database, DatabaseField, Identifier, Model, Result}; +use crate::db::{model, query, ColumnType, Database, DatabaseField, Identifier, Result}; #[derive(Debug, Clone, Error)] #[non_exhaustive] @@ -712,7 +712,7 @@ enum MigrationDependencyInner { }, Model { app: &'static str, - model_name: &'static str, + table_name: &'static str, }, } @@ -734,11 +734,11 @@ impl MigrationDependency { /// Creates a dependency on a model. /// /// This ensures that the migration engine will apply the migration that - /// creates the model with the given app and model name before the current + /// creates the model with the given app and table name before the current /// migration. #[must_use] - pub const fn model(app: &'static str, model_name: &'static str) -> Self { - Self::new(MigrationDependencyInner::Model { app, model_name }) + pub const fn model(app: &'static str, table_name: &'static str) -> Self { + Self::new(MigrationDependencyInner::Model { app, table_name }) } } diff --git a/flareon/src/db/migrations/sorter.rs b/flareon/src/db/migrations/sorter.rs index dc32bfc3..b3f7ace6 100644 --- a/flareon/src/db/migrations/sorter.rs +++ b/flareon/src/db/migrations/sorter.rs @@ -4,12 +4,13 @@ use flareon::db::migrations::MigrationDependency; use thiserror::Error; use crate::db::migrations::{DynMigration, MigrationDependencyInner, OperationInner}; +use crate::utils::graph::{apply_permutation, Graph}; #[derive(Debug, Clone, PartialEq, Eq, Error)] #[non_exhaustive] pub enum MigrationSorterError { #[error("Cycle detected in migrations")] - CycleDetected, + CycleDetected(#[from] flareon::utils::graph::CycleDetected), #[error("Dependency not found: {}", format_migration_dependency(.0))] InvalidDependency(MigrationDependency), #[error("Migration defined twice: {app_name}::{migration_name}")] @@ -17,10 +18,10 @@ pub enum MigrationSorterError { app_name: String, migration_name: String, }, - #[error("Migration creating model defined twice: {app_name}::{model_name}")] + #[error("Migration creating model defined twice: {app_name}::{table_name}")] DuplicateModel { app_name: String, - model_name: String, + table_name: String, }, } @@ -31,8 +32,8 @@ fn format_migration_dependency(dependency: &MigrationDependency) -> String { MigrationDependencyInner::Migration { app, migration } => { format!("migration {app}::{migration}") } - MigrationDependencyInner::Model { app, model_name } => { - format!("model {app}::{model_name}") + MigrationDependencyInner::Model { app, table_name } => { + format!("model {app}::{table_name}") } } } @@ -52,7 +53,7 @@ impl<'a, T: DynMigration> MigrationSorter<'a, T> { pub(super) fn sort(&mut self) -> Result<()> { // Sort by names to ensure that the order is deterministic self.migrations - .sort_by(|a, b| (b.app_name(), b.name()).cmp(&(a.app_name(), a.name()))); + .sort_by(|a, b| (a.app_name(), a.name()).cmp(&(b.app_name(), b.name()))); self.toposort()?; Ok(()) @@ -96,12 +97,12 @@ impl<'a, T: DynMigration> MigrationSorter<'a, T> { if let OperationInner::CreateModel { table_name, .. } = operation.inner { let app_and_model = MigrationLookup::ByAppAndModel { app: migration.app_name(), - model: table_name.0, + table_name: table_name.0, }; if map.insert(app_and_model, index).is_some() { return Err(MigrationSorterError::DuplicateModel { app_name: migration.app_name().to_owned(), - model_name: table_name.0.to_owned(), + table_name: table_name.0.to_owned(), }); } } @@ -112,28 +113,10 @@ impl<'a, T: DynMigration> MigrationSorter<'a, T> { } } -fn apply_permutation(migrations: &mut [T], order: &mut [usize]) { - for i in 0..order.len() { - let mut current = i; - let mut next = order[current]; - - while next != i { - // process the cycle - migrations.swap(current, next); - order[current] = current; - - current = next; - next = order[current]; - } - - order[current] = current; - } -} - #[derive(Debug, Clone, Eq, PartialEq, Hash)] enum MigrationLookup<'a> { ByAppAndName { app: &'a str, name: &'a str }, - ByAppAndModel { app: &'a str, model: &'a str }, + ByAppAndModel { app: &'a str, table_name: &'a str }, } impl From<&MigrationDependency> for MigrationLookup<'_> { @@ -145,84 +128,13 @@ impl From<&MigrationDependency> for MigrationLookup<'_> { name: migration, } } - MigrationDependencyInner::Model { app, model_name } => MigrationLookup::ByAppAndModel { - app, - model: model_name, - }, - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -struct Graph { - vertex_edges: Vec>, -} - -impl Graph { - #[must_use] - fn new(vertex_num: usize) -> Self { - Self { - vertex_edges: vec![Vec::new(); vertex_num], - } - } - - fn add_edge(&mut self, from: usize, to: usize) { - self.vertex_edges[from].push(to); - } - - #[must_use] - fn vertex_num(&self) -> usize { - self.vertex_edges.len() - } - - fn toposort(&mut self) -> Result> { - let mut visited = vec![VisitedStatus::NotVisited; self.vertex_num()]; - let mut sorted_indices_stack = Vec::with_capacity(self.vertex_num()); - - for index in 0..self.vertex_num() { - self.visit(index, &mut visited, &mut sorted_indices_stack)?; - } - - assert_eq!(sorted_indices_stack.len(), self.vertex_num()); - - sorted_indices_stack.reverse(); - Ok(sorted_indices_stack) - } - - fn visit( - &self, - index: usize, - visited: &mut Vec, - sorted_indices_stack: &mut Vec, - ) -> Result<()> { - match visited[index] { - VisitedStatus::Visited => return Ok(()), - VisitedStatus::Visiting => { - return Err(MigrationSorterError::CycleDetected); + MigrationDependencyInner::Model { app, table_name } => { + MigrationLookup::ByAppAndModel { app, table_name } } - VisitedStatus::NotVisited => {} } - - visited[index] = VisitedStatus::Visiting; - - for &neighbor in &self.vertex_edges[index] { - self.visit(neighbor, visited, sorted_indices_stack)?; - } - - visited[index] = VisitedStatus::Visited; - sorted_indices_stack.push(index); - - Ok(()) } } -#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] -enum VisitedStatus { - NotVisited, - Visiting, - Visited, -} - #[cfg(test)] mod tests { use super::*; @@ -230,24 +142,6 @@ mod tests { use crate::db::Identifier; use crate::test::TestMigration; - #[test] - fn graph_toposort() { - let mut graph = Graph::new(8); - graph.add_edge(0, 3); - graph.add_edge(1, 3); - graph.add_edge(1, 4); - graph.add_edge(2, 4); - graph.add_edge(2, 7); - graph.add_edge(3, 5); - graph.add_edge(3, 6); - graph.add_edge(3, 7); - graph.add_edge(4, 6); - - let sorted_indices = graph.toposort().unwrap(); - - assert_eq!(sorted_indices, vec![2, 1, 4, 0, 3, 7, 6, 5]); - } - #[test] fn create_lookup_table() { let migrations = vec![ @@ -284,11 +178,11 @@ mod tests { })); assert!(lookup.contains_key(&MigrationLookup::ByAppAndModel { app: "app1", - model: "model1" + table_name: "model1" })); assert!(lookup.contains_key(&MigrationLookup::ByAppAndModel { app: "app1", - model: "model2" + table_name: "model2" })); } @@ -341,12 +235,12 @@ mod tests { ("app2", "migration_before") ); assert_eq!( - (migrations[1].app_name(), migrations[1].name()), - ("app1", "migration_before") + (migrations[1].app_name(), migrations[2].name()), + ("app2", "migration_before") ); assert_eq!( - (migrations[2].app_name(), migrations[2].name()), - ("app2", "migration_after") + (migrations[2].app_name(), migrations[1].name()), + ("app1", "migration_after") ); assert_eq!( (migrations[3].app_name(), migrations[3].name()), @@ -371,6 +265,7 @@ mod tests { const MIGRATION_NUM: usize = 100; let mut migrations = Vec::new(); + #[allow(clippy::needless_range_loop)] for i in 0..MIGRATION_NUM { let deps = (0..i) .map(|i| MigrationDependency::migration("app1", MIGRATION_NAMES[i])) @@ -411,10 +306,10 @@ mod tests { ]; let mut sorter = MigrationSorter::new(&mut migrations); - assert_eq!( + assert!(matches!( sorter.toposort().unwrap_err(), - MigrationSorterError::CycleDetected - ); + MigrationSorterError::CycleDetected(_) + )); } #[test] @@ -462,7 +357,7 @@ mod tests { sorter.toposort().unwrap_err(), MigrationSorterError::DuplicateModel { app_name: "app1".to_owned(), - model_name: "model1".to_owned() + table_name: "model1".to_owned() } ); } diff --git a/flareon/src/db/query.rs b/flareon/src/db/query.rs index 2f097a7f..0bd58494 100644 --- a/flareon/src/db/query.rs +++ b/flareon/src/db/query.rs @@ -175,7 +175,7 @@ impl Expr { pub fn value(value: T) -> Self { match value.to_db_field_value() { DbFieldValue::Value(value) => Self::Value(value), - _ => panic!("Cannot create query with a non-value field"), + DbFieldValue::Auto => panic!("Cannot create query with a non-value field"), } } @@ -415,7 +415,7 @@ impl_num_expr!(u64); impl_num_expr!(f32); impl_num_expr!(f64); -trait IntoField { +pub trait IntoField { fn into_field(self) -> T; } diff --git a/flareon/src/db/relations.rs b/flareon/src/db/relations.rs index 2dd9a678..e819a600 100644 --- a/flareon/src/db/relations.rs +++ b/flareon/src/db/relations.rs @@ -2,9 +2,17 @@ use flareon::db::DatabaseError; use crate::db::{DatabaseBackend, Model, Result}; +/// A foreign key to another model. +/// +/// Internally, this is represented either as a primary key (in case the +/// model has not been retrieved from the database) or as the model itself. #[derive(Debug, Clone)] pub enum ForeignKey { + /// The primary key of the referenced model; used when the model has not + /// been retrieved from the database yet or when it's unnecessary to + /// store the entire model instance. PrimaryKey(T::PrimaryKey), + /// The referenced model. Model(Box), } @@ -19,14 +27,14 @@ impl ForeignKey { pub fn model(&self) -> Option<&T> { match self { Self::Model(model) => Some(model), - _ => None, + Self::PrimaryKey(_) => None, } } pub fn unwrap(self) -> T { match self { Self::Model(model) => *model, - _ => panic!("object has not been retrieved from the database"), + Self::PrimaryKey(_) => panic!("object has not been retrieved from the database"), } } @@ -68,6 +76,14 @@ impl From<&T> for ForeignKey { } } +/// A foreign key on delete constraint. +/// +/// This is used to define the behavior of a foreign key when the referenced row +/// is deleted. +/// +/// # See also +/// +/// - [`ForeignKeyOnUpdatePolicy`] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)] pub enum ForeignKeyOnDeletePolicy { NoAction, @@ -88,6 +104,14 @@ impl From for sea_query::ForeignKeyAction { } } +/// A foreign key on delete constraint. +/// +/// This is used to define the behavior of a foreign key when the referenced row +/// is updated. +/// +/// # See also +/// +/// - [`ForeignKeyOnDeletePolicy`] #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Default)] pub enum ForeignKeyOnUpdatePolicy { NoAction, diff --git a/flareon/src/db/sea_query_db.rs b/flareon/src/db/sea_query_db.rs index f5897753..227d6785 100644 --- a/flareon/src/db/sea_query_db.rs +++ b/flareon/src/db/sea_query_db.rs @@ -16,7 +16,7 @@ macro_rules! impl_sea_query_db_backend { let db_connection = <$pool_ty>::connect(url).await?; let db = Self { db_connection }; - db.init(); + db.init().await?; Ok(db) } diff --git a/flareon/src/lib.rs b/flareon/src/lib.rs index ffacf5bb..7060813e 100644 --- a/flareon/src/lib.rs +++ b/flareon/src/lib.rs @@ -64,6 +64,7 @@ pub mod response; pub mod router; pub mod static_files; pub mod test; +pub(crate) mod utils; use std::fmt::Formatter; use std::future::{poll_fn, Future}; diff --git a/flareon/src/private.rs b/flareon/src/private.rs index 3ee22cc0..e02a578a 100644 --- a/flareon/src/private.rs +++ b/flareon/src/private.rs @@ -1,4 +1,5 @@ -//! Re-exports of some of the Flareon dependencies that are used in the macros. +//! Re-exports of some of the Flareon dependencies that are used in the macros +//! and the CLI. //! //! This is to avoid the need to add them as dependencies to the crate that uses //! the macros. @@ -8,3 +9,6 @@ pub use async_trait::async_trait; pub use bytes::Bytes; pub use tokio; + +// used in the CLI +pub use crate::utils::graph::apply_permutation; diff --git a/flareon/src/utils.rs b/flareon/src/utils.rs new file mode 100644 index 00000000..07ebd317 --- /dev/null +++ b/flareon/src/utils.rs @@ -0,0 +1 @@ +pub(crate) mod graph; diff --git a/flareon/src/utils/graph.rs b/flareon/src/utils/graph.rs new file mode 100644 index 00000000..e0cfdec6 --- /dev/null +++ b/flareon/src/utils/graph.rs @@ -0,0 +1,134 @@ +use thiserror::Error; + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Error)] +#[error("Cycle detected in the graph")] +pub struct CycleDetected; + +pub fn apply_permutation(items: &mut [T], order: &mut [usize]) { + for i in 0..order.len() { + let mut current = i; + let mut next = order[current]; + + while next != i { + // process the cycle + items.swap(current, next); + order[current] = current; + + current = next; + next = order[current]; + } + + order[current] = current; + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub(crate) struct Graph { + vertex_edges: Vec>, +} + +impl Graph { + #[must_use] + pub(crate) fn new(vertex_num: usize) -> Self { + Self { + vertex_edges: vec![Vec::new(); vertex_num], + } + } + + pub(crate) fn add_edge(&mut self, from: usize, to: usize) { + self.vertex_edges[from].push(to); + } + + #[must_use] + pub(crate) fn vertex_num(&self) -> usize { + self.vertex_edges.len() + } + + pub(crate) fn toposort(&mut self) -> Result, CycleDetected> { + let mut visited = vec![VisitedStatus::NotVisited; self.vertex_num()]; + let mut sorted_indices_stack = Vec::with_capacity(self.vertex_num()); + + for index in (0..self.vertex_num()).rev() { + self.toposort_visit(index, &mut visited, &mut sorted_indices_stack)?; + } + + assert_eq!(sorted_indices_stack.len(), self.vertex_num()); + + sorted_indices_stack.reverse(); + Ok(sorted_indices_stack) + } + + fn toposort_visit( + &self, + index: usize, + visited: &mut Vec, + sorted_indices_stack: &mut Vec, + ) -> Result<(), CycleDetected> { + match visited[index] { + VisitedStatus::Visited => return Ok(()), + VisitedStatus::Visiting => { + return Err(CycleDetected); + } + VisitedStatus::NotVisited => {} + } + + visited[index] = VisitedStatus::Visiting; + + for &neighbor in &self.vertex_edges[index] { + self.toposort_visit(neighbor, visited, sorted_indices_stack)?; + } + + visited[index] = VisitedStatus::Visited; + sorted_indices_stack.push(index); + + Ok(()) + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)] +enum VisitedStatus { + NotVisited, + Visiting, + Visited, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn graph_toposort_stable() { + let mut graph = Graph::new(8); + let sorted_indices = graph.toposort().unwrap(); + assert_eq!(sorted_indices, vec![0, 1, 2, 3, 4, 5, 6, 7]); + } + + #[test] + fn graph_toposort() { + let mut graph = Graph::new(8); + graph.add_edge(5, 3); + graph.add_edge(1, 3); + graph.add_edge(1, 2); + graph.add_edge(4, 2); + graph.add_edge(4, 6); + graph.add_edge(3, 0); + graph.add_edge(3, 7); + graph.add_edge(3, 6); + graph.add_edge(2, 7); + + let sorted_indices = graph.toposort().unwrap(); + + assert_eq!(sorted_indices, vec![1, 4, 2, 5, 3, 0, 6, 7]); + } + + #[test] + fn graph_toposort_with_cycle() { + let mut graph = Graph::new(4); + graph.add_edge(0, 1); + graph.add_edge(1, 2); + graph.add_edge(2, 3); + graph.add_edge(3, 0); + + assert!(matches!(graph.toposort(), Err(CycleDetected))); + } +} diff --git a/flareon/tests/db.rs b/flareon/tests/db.rs index 9489d801..334354bc 100644 --- a/flareon/tests/db.rs +++ b/flareon/tests/db.rs @@ -317,13 +317,13 @@ async fn foreign_keys_option(db: &mut TestDatabase) { Identifier::new("parent"), > as DatabaseField>::TYPE, ) + .set_null(> as DatabaseField>::NULLABLE) .foreign_key( ::TABLE_NAME, ::PRIMARY_KEY_NAME, - ForeignKeyOnDeletePolicy::Restrict, - ForeignKeyOnUpdatePolicy::Restrict, - ) - .set_null(> as DatabaseField>::NULLABLE), + ForeignKeyOnDeletePolicy::SetNone, + ForeignKeyOnUpdatePolicy::SetNone, + ), ]) .build(); @@ -337,7 +337,7 @@ async fn foreign_keys_option(db: &mut TestDatabase) { }; child.save(&**db).await.unwrap(); - let mut child = Child::objects().all(&**db).await.unwrap()[0].clone(); + let child = Child::objects().all(&**db).await.unwrap()[0].clone(); assert_eq!(child.parent, None); query!(Child, $id == child.id).delete(&**db).await.unwrap(); @@ -352,8 +352,85 @@ async fn foreign_keys_option(db: &mut TestDatabase) { }; child.save(&**db).await.unwrap(); - let mut child = Child::objects().all(&**db).await.unwrap()[0].clone(); + let child = Child::objects().all(&**db).await.unwrap()[0].clone(); + let mut parent_fk = child.parent.unwrap(); + let parent_from_db = parent_fk.get(&**db).await.unwrap(); + assert_eq!(parent_from_db, &parent); + + // Check none policy + query!(Parent, $id == parent.id) + .delete(&**db) + .await + .unwrap(); + let child = Child::objects().all(&**db).await.unwrap()[0].clone(); + assert_eq!(child.parent, None); +} + +#[flareon_macros::dbtest] +async fn foreign_keys_cascade(db: &mut TestDatabase) { + #[derive(Debug, Clone, PartialEq)] + #[model] + struct Parent { + id: Auto, + } + + #[derive(Debug, Clone, PartialEq)] + #[model] + struct Child { + id: Auto, + parent: Option>, + } + + const CREATE_PARENT: Operation = Operation::create_model() + .table_name(Identifier::new("parent")) + .fields(&[ + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) + .primary_key() + .auto(), + ]) + .build(); + const CREATE_CHILD: Operation = Operation::create_model() + .table_name(Identifier::new("child")) + .fields(&[ + Field::new(Identifier::new("id"), as DatabaseField>::TYPE) + .primary_key() + .auto(), + Field::new( + Identifier::new("parent"), + > as DatabaseField>::TYPE, + ) + .set_null(> as DatabaseField>::NULLABLE) + .foreign_key( + ::TABLE_NAME, + ::PRIMARY_KEY_NAME, + ForeignKeyOnDeletePolicy::Cascade, + ForeignKeyOnUpdatePolicy::Cascade, + ), + ]) + .build(); + + CREATE_PARENT.forwards(db).await.unwrap(); + CREATE_CHILD.forwards(db).await.unwrap(); + + // with parent + let mut parent = Parent { id: Auto::auto() }; + parent.save(&**db).await.unwrap(); + + let mut child = Child { + id: Auto::auto(), + parent: Some(ForeignKey::from(&parent)), + }; + child.save(&**db).await.unwrap(); + + let child = Child::objects().all(&**db).await.unwrap()[0].clone(); let mut parent_fk = child.parent.unwrap(); let parent_from_db = parent_fk.get(&**db).await.unwrap(); assert_eq!(parent_from_db, &parent); + + // Check cascade policy + query!(Parent, $id == parent.id) + .delete(&**db) + .await + .unwrap(); + assert!(Child::objects().all(&**db).await.unwrap().is_empty()); } From 0d818726aa26f93f41ab9e8e1181ff1099d1c947 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Ma=C4=87kowski?= Date: Tue, 7 Jan 2025 23:17:02 +0100 Subject: [PATCH 5/9] ci: run CI on Rust Beta until new trait solver is there --- .github/workflows/rust.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 7bd94ba3..5e0123a0 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -23,7 +23,7 @@ jobs: if: github.event_name == 'push' || github.event_name == 'schedule' || github.event.pull_request.head.repo.full_name != github.repository strategy: matrix: - rust: [stable, nightly] + rust: [beta, nightly] os: [ubuntu-latest, macos-latest, windows-latest] name: Build & test @@ -70,7 +70,7 @@ jobs: - name: Install Rust toolchain uses: dtolnay/rust-toolchain@master with: - toolchain: stable + toolchain: beta components: clippy - name: Cache Cargo registry @@ -194,7 +194,7 @@ jobs: - name: Install Rust toolchain uses: dtolnay/rust-toolchain@master with: - toolchain: stable + toolchain: beta - name: Cache Cargo registry uses: Swatinem/rust-cache@v2 From 4081a59e2ecff8182be653bc388f11e7ba4cfb2e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Ma=C4=87kowski?= Date: Tue, 7 Jan 2025 23:25:12 +0100 Subject: [PATCH 6/9] chore: more fixes, more docs --- flareon-cli/src/migration_generator.rs | 26 +++++++------------------- flareon/src/auth.rs | 3 +-- flareon/src/db.rs | 16 +++++++++++++--- flareon/src/db/fields.rs | 6 ++++++ flareon/src/db/relations.rs | 2 +- 5 files changed, 28 insertions(+), 25 deletions(-) diff --git a/flareon-cli/src/migration_generator.rs b/flareon-cli/src/migration_generator.rs index bb8cb471..f4015b34 100644 --- a/flareon-cli/src/migration_generator.rs +++ b/flareon-cli/src/migration_generator.rs @@ -120,7 +120,7 @@ impl MigrationGenerator { migration.toposort_operations(); migration .dependencies - .extend(migration.get_foreign_key_dependencies(&self.crate_name)); + .extend(migration.get_foreign_key_dependencies()); Ok(Some(migration)) } @@ -527,7 +527,7 @@ pub struct SourceFile { impl SourceFile { #[must_use] - pub fn new(path: PathBuf, content: syn::File) -> Self { + fn new(path: PathBuf, content: syn::File) -> Self { assert!( path.is_relative(), "path must be relative to the src directory" @@ -666,7 +666,7 @@ pub struct GeneratedMigration { } impl GeneratedMigration { - fn get_foreign_key_dependencies(&self, crate_name: &str) -> Vec { + fn get_foreign_key_dependencies(&self) -> Vec { let create_ops = self.get_create_ops_map(); let ops_adding_foreign_keys = self.get_ops_adding_foreign_keys(); @@ -977,20 +977,8 @@ impl Repr for DynOperation { fn repr(&self) -> TokenStream { match self { Self::CreateModel { - table_name, - model_ty, - fields, - .. + table_name, fields, .. } => { - let model_name = match model_ty { - syn::Type::Path(syn::TypePath { path, .. }) => path - .segments - .last() - .expect("TypePath must have at least one segment") - .ident - .to_string(), - _ => unreachable!("model_ty is expected to be a TypePath"), - }; let fields = fields.iter().map(Repr::repr).collect::>(); quote! { ::flareon::db::migrations::Operation::create_model() @@ -1265,7 +1253,7 @@ mod tests { }], }; - let external_dependencies = migration.get_foreign_key_dependencies("my_crate"); + let external_dependencies = migration.get_foreign_key_dependencies(); assert!(external_dependencies.is_empty()); } @@ -1292,7 +1280,7 @@ mod tests { }], }; - let external_dependencies = migration.get_foreign_key_dependencies("my_crate"); + let external_dependencies = migration.get_foreign_key_dependencies(); assert_eq!(external_dependencies.len(), 1); assert_eq!( external_dependencies[0], @@ -1342,7 +1330,7 @@ mod tests { ], }; - let external_dependencies = migration.get_foreign_key_dependencies("my_crate"); + let external_dependencies = migration.get_foreign_key_dependencies(); assert_eq!(external_dependencies.len(), 2); assert!(external_dependencies.contains(&DynDependency::Model { model_type: parse_quote!(my_crate::Table2), diff --git a/flareon/src/auth.rs b/flareon/src/auth.rs index 6fe88d71..f67e6d1b 100644 --- a/flareon/src/auth.rs +++ b/flareon/src/auth.rs @@ -23,9 +23,8 @@ use subtle::ConstantTimeEq; use thiserror::Error; use crate::config::SecretKey; -use crate::db::DbValue; #[cfg(feature = "db")] -use crate::db::{ColumnType, DatabaseField, FromDbValue, SqlxValueRef, ToDbValue}; +use crate::db::{ColumnType, DatabaseField, DbValue, FromDbValue, SqlxValueRef, ToDbValue}; use crate::request::{Request, RequestExt}; #[derive(Debug, Error)] diff --git a/flareon/src/db.rs b/flareon/src/db.rs index 3fee822a..161085e8 100644 --- a/flareon/src/db.rs +++ b/flareon/src/db.rs @@ -73,6 +73,7 @@ impl DatabaseError { } } +/// An alias for [`Result`] that uses [`DatabaseError`] as the error type. pub type Result = std::result::Result; /// A model trait for database models. @@ -211,6 +212,8 @@ impl Column { } } +/// A marker trait that denotes that a type can be used as a primary key in a +/// database. pub trait PrimaryKey: DatabaseField + Clone {} /// A row structure that holds the data of a single row retrieved from the @@ -749,7 +752,11 @@ impl Database { fn supports_returning(&self) -> bool { match self.inner { - DatabaseImpl::Sqlite(_) | DatabaseImpl::Postgres(_) => true, + #[cfg(feature = "sqlite")] + DatabaseImpl::Sqlite(_) => true, + #[cfg(feature = "postgres")] + DatabaseImpl::Postgres(_) => true, + #[cfg(feature = "mysql")] DatabaseImpl::MySql(_) => false, } } @@ -913,7 +920,7 @@ pub struct RowsNum(pub u64); /// ``` /// use flareon::db::{model, Auto, Model}; /// # use flareon::db::migrations::{Field, Operation}; -/// # use flareon::db::{Database, Identifier}; +/// # use flareon::db::{Database, Identifier, DatabaseField}; /// # use flareon::Result; /// /// #[model] @@ -925,7 +932,7 @@ pub struct RowsNum(pub u64); /// # async fn main() -> Result<()> { /// /// # const OPERATION: Operation = Operation::create_model() -/// # .table_name(Identifier::new("todoapp__my_model")) +/// # .table_name(Identifier::new("my_model")) /// # .fields(&[ /// # Field::new(Identifier::new("id"), ::TYPE) /// # .primary_key() @@ -955,6 +962,7 @@ impl Auto { /// Creates a new `Auto` instance that is automatically generated by the /// database. #[must_use] + #[allow(clippy::self_named_constructors)] pub const fn auto() -> Self { Self::Auto } @@ -1025,6 +1033,8 @@ impl PartialEq> for String { } } +/// An error returned by [`LimitedString::new`] when the string is longer than +/// the specified limit. #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Error)] #[error("string is too long ({length} > {LIMIT})")] pub struct NewLimitedStringError { diff --git a/flareon/src/db/fields.rs b/flareon/src/db/fields.rs index 667180a9..cc845f5f 100644 --- a/flareon/src/db/fields.rs +++ b/flareon/src/db/fields.rs @@ -288,6 +288,7 @@ impl DatabaseField for Auto { } impl FromDbValue for Auto { + #[cfg(feature = "sqlite")] fn from_sqlite(value: SqliteValueRef) -> Result where Self: Sized, @@ -295,6 +296,7 @@ impl FromDbValue for Auto { Ok(Self::fixed(T::from_sqlite(value)?)) } + #[cfg(feature = "postgres")] fn from_postgres(value: PostgresValueRef) -> Result where Self: Sized, @@ -302,6 +304,7 @@ impl FromDbValue for Auto { Ok(Self::fixed(T::from_postgres(value)?)) } + #[cfg(feature = "mysql")] fn from_mysql(value: MySqlValueRef) -> Result where Self: Sized, @@ -323,6 +326,7 @@ impl FromDbValue for Option> where Option: FromDbValue, { + #[cfg(feature = "sqlite")] fn from_sqlite(value: SqliteValueRef) -> Result where Self: Sized, @@ -330,6 +334,7 @@ where >::from_sqlite(value).map(|value| value.map(Auto::fixed)) } + #[cfg(feature = "postgres")] fn from_postgres(value: PostgresValueRef) -> Result where Self: Sized, @@ -337,6 +342,7 @@ where >::from_postgres(value).map(|value| value.map(Auto::fixed)) } + #[cfg(feature = "mysql")] fn from_mysql(value: MySqlValueRef) -> Result where Self: Sized, diff --git a/flareon/src/db/relations.rs b/flareon/src/db/relations.rs index e819a600..95e9e3c1 100644 --- a/flareon/src/db/relations.rs +++ b/flareon/src/db/relations.rs @@ -104,7 +104,7 @@ impl From for sea_query::ForeignKeyAction { } } -/// A foreign key on delete constraint. +/// A foreign key on update constraint. /// /// This is used to define the behavior of a foreign key when the referenced row /// is updated. From 692c1a174275998ee60cc374234ddddc9af943c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Ma=C4=87kowski?= Date: Wed, 8 Jan 2025 09:49:28 +0100 Subject: [PATCH 7/9] chore: more fixes, tests, docs --- flareon-cli/src/main.rs | 16 ++- flareon-cli/src/migration_generator.rs | 135 +++++++++++++++++++++-- flareon-cli/tests/migration_generator.rs | 2 +- flareon/src/db.rs | 40 +++++++ flareon/src/db/migrations.rs | 24 +++- flareon/src/db/relations.rs | 79 +++++++++++++ 6 files changed, 277 insertions(+), 19 deletions(-) diff --git a/flareon-cli/src/main.rs b/flareon-cli/src/main.rs index 02c6c038..3e1f8e66 100644 --- a/flareon-cli/src/main.rs +++ b/flareon-cli/src/main.rs @@ -1,5 +1,3 @@ -extern crate core; - mod migration_generator; mod utils; @@ -27,6 +25,9 @@ enum Commands { /// Path to the crate directory to generate migrations for (default: /// current directory) path: Option, + /// Name of the app to use in the migration (default: crate name) + #[arg(long)] + app_name: Option, /// Directory to write the migrations to (default: migrations/ directory /// in the crate's src/ directory) #[arg(long)] @@ -46,9 +47,16 @@ fn main() -> anyhow::Result<()> { .init(); match cli.command { - Commands::MakeMigrations { path, output_dir } => { + Commands::MakeMigrations { + path, + app_name, + output_dir, + } => { let path = path.unwrap_or_else(|| PathBuf::from(".")); - let options = MigrationGeneratorOptions { output_dir }; + let options = MigrationGeneratorOptions { + app_name, + output_dir, + }; make_migrations(&path, options).with_context(|| "unable to create migrations")?; } } diff --git a/flareon-cli/src/migration_generator.rs b/flareon-cli/src/migration_generator.rs index f4015b34..00b68205 100644 --- a/flareon-cli/src/migration_generator.rs +++ b/flareon-cli/src/migration_generator.rs @@ -48,6 +48,7 @@ pub fn make_migrations(path: &Path, options: MigrationGeneratorOptions) -> anyho #[derive(Debug, Clone, Default)] pub struct MigrationGeneratorOptions { + pub app_name: Option, pub output_dir: Option, } @@ -82,6 +83,7 @@ impl MigrationGenerator { Ok(()) } + /// Generate migrations as a ready-to-write source code. pub fn generate_migrations_to_write( &mut self, source_files: Vec, @@ -95,6 +97,8 @@ impl MigrationGenerator { } } + /// Generate migrations and return internal structures that can be used to + /// generate source code. pub fn generate_migrations( &mut self, source_files: Vec, @@ -319,7 +323,10 @@ impl MigrationGenerator { fn make_create_model_operation(app_model: &ModelInSource) -> DynOperation { DynOperation::CreateModel { table_name: app_model.model.table_name.clone(), - model_ty: app_model.model.resolved_ty.clone().expect("resolved_ty is expected to be present when parsing the entire file with symbol resolver"), + model_ty: app_model.model.resolved_ty.clone().expect( + "resolved_ty is expected to be present when \ + parsing the entire file with symbol resolver", + ), fields: app_model.model.fields.clone(), } } @@ -381,7 +388,10 @@ impl MigrationGenerator { fn make_add_field_operation(app_model: &ModelInSource, field: &Field) -> DynOperation { DynOperation::AddField { table_name: app_model.model.table_name.clone(), - model_ty: app_model.model.resolved_ty.clone().expect("resolved_ty is expected to be present when parsing the entire file with symbol resolver"), + model_ty: app_model.model.resolved_ty.clone().expect( + "resolved_ty is expected to be present \ + when parsing the entire file with symbol resolver", + ), field: field.clone(), } } @@ -426,7 +436,7 @@ impl MigrationGenerator { .map(|dependency| dependency.repr()) .collect(); - let app_name = &self.crate_name; + let app_name = self.options.app_name.as_ref().unwrap_or(&self.crate_name); let migration_name = &migration.migration_name; let migration_def = quote! { #[derive(Debug, Copy, Clone)] @@ -666,6 +676,8 @@ pub struct GeneratedMigration { } impl GeneratedMigration { + /// Get the list of [`DynDependency`] for all foreign keys that point + /// to models that are **not** created in this migration. fn get_foreign_key_dependencies(&self) -> Vec { let create_ops = self.get_create_ops_map(); let ops_adding_foreign_keys = self.get_ops_adding_foreign_keys(); @@ -682,6 +694,16 @@ impl GeneratedMigration { dependencies } + /// Removes dependency cycles by removing operations that create cycles. + /// + /// This method tries to minimize the number of operations added by + /// calculating the minimum feedback arc set of the dependency graph. + /// + /// This method modifies the `operations` field in place. + /// + /// # See also + /// + /// * [`Self::remove_dependency`] fn remove_cycles(&mut self) { let graph = self.construct_dependency_graph(); @@ -701,6 +723,11 @@ impl GeneratedMigration { } } + /// Remove a dependency between two operations. + /// + /// This is done by removing foreign keys from the `from` operation that + /// point to the model created by `to` operation, and creating a new + /// `AddField` operation for each removed foreign key. #[must_use] fn remove_dependency(from: &mut DynOperation, to: &DynOperation) -> Vec { match from { @@ -712,7 +739,10 @@ impl GeneratedMigration { let to_type = match to { DynOperation::CreateModel { model_ty, .. } => model_ty, DynOperation::AddField { .. } => { - unreachable!("AddField operation shouldn't be a dependency of CreateModel because it doesn't create a new model") + unreachable!( + "AddField operation shouldn't be a dependency of CreateModel \ + because it doesn't create a new model" + ) } }; trace!( @@ -745,6 +775,18 @@ impl GeneratedMigration { } } + /// Topologically sort operations in this migration. + /// + /// This is to ensure that operations will be applied in the correct order. + /// If there are no dependencies between operations, the order of operations + /// will not be modified. + /// + /// This method modifies the `operations` field in place. + /// + /// # Panics + /// + /// This method should be called after removing cycles; otherwise it will + /// panic. fn toposort_operations(&mut self) { let graph = self.construct_dependency_graph(); @@ -752,11 +794,17 @@ impl GeneratedMigration { .expect("cycles shouldn't exist after removing them"); let mut sorted = sorted .into_iter() - .map(petgraph::prelude::NodeIndex::index) + .map(petgraph::graph::NodeIndex::index) .collect::>(); flareon::__private::apply_permutation(&mut self.operations, &mut sorted); } + /// Construct a graph that represents reverse dependencies between + /// operations in this migration. + /// + /// The graph is directed and has an edge from operation A to operation B + /// if operation B creates a foreign key that points to a model created by + /// operation A. #[must_use] fn construct_dependency_graph(&mut self) -> DiGraph { let create_ops = self.get_create_ops_map(); @@ -769,7 +817,11 @@ impl GeneratedMigration { } for (i, dependency_ty) in &ops_adding_foreign_keys { if let Some(&dependency) = create_ops.get(dependency_ty) { - graph.add_edge(NodeIndex::new(dependency), NodeIndex::new(*i), ()); + graph.add_edge( + petgraph::graph::NodeIndex::new(dependency), + petgraph::graph::NodeIndex::new(*i), + (), + ); } } @@ -855,16 +907,19 @@ impl Repr for Field { let mut tokens = quote! { ::flareon::db::migrations::Field::new(::flareon::db::Identifier::new(#column_name), <#ty as ::flareon::db::DatabaseField>::TYPE) }; - if self - .auto_value - .expect("auto_value is expected to be present when parsing the entire file with symbol resolver") - { + if self.auto_value.expect( + "auto_value is expected to be present \ + when parsing the entire file with symbol resolver", + ) { tokens = quote! { #tokens.auto() } } if self.primary_key { tokens = quote! { #tokens.primary_key() } } - if let Some(fk_spec) = self.foreign_key.clone().expect("foreign_key is expected to be present when parsing the entire file with symbol resolver") { + if let Some(fk_spec) = self.foreign_key.clone().expect( + "foreign_key is expected to be present \ + when parsing the entire file with symbol resolver", + ) { let to_model = &fk_spec.to_model; tokens = quote! { @@ -966,7 +1021,8 @@ fn is_field_foreign_key_to(field: &Field, ty: &syn::Type) -> bool { /// Returns [`None`] if the field is not a foreign key. fn foreign_key_for_field(field: &Field) -> Option { match field.foreign_key.clone().expect( - "foreign_key is expected to be present when parsing the entire file with symbol resolver", + "foreign_key is expected to be present \ + when parsing the entire file with symbol resolver", ) { None => None, Some(foreign_key_spec) => Some(foreign_key_spec.to_model), @@ -1339,4 +1395,59 @@ mod tests { model_type: parse_quote!(crate::Table4), })); } + + #[test] + fn make_add_field_operation() { + let app_model = ModelInSource { + model_item: parse_quote! { + struct TestModel { + id: i32, + field1: i32, + } + }, + model: Model { + name: format_ident!("TestModel"), + original_name: "TestModel".to_string(), + resolved_ty: Some(parse_quote!(TestModel)), + model_type: Default::default(), + table_name: "test_model".to_string(), + pk_field: Field { + field_name: format_ident!("id"), + column_name: "id".to_string(), + ty: parse_quote!(i32), + auto_value: MaybeUnknown::Known(true), + primary_key: true, + unique: false, + foreign_key: MaybeUnknown::Known(None), + }, + fields: vec![], + }, + }; + + let field = Field { + field_name: format_ident!("new_field"), + column_name: "new_field".to_string(), + ty: parse_quote!(i32), + auto_value: MaybeUnknown::Known(false), + primary_key: false, + unique: false, + foreign_key: MaybeUnknown::Known(None), + }; + + let operation = MigrationGenerator::make_add_field_operation(&app_model, &field); + + match operation { + DynOperation::AddField { + table_name, + model_ty, + field: op_field, + } => { + assert_eq!(table_name, "test_model"); + assert_eq!(model_ty, parse_quote!(TestModel)); + assert_eq!(op_field.column_name, "new_field"); + assert_eq!(op_field.ty, parse_quote!(i32)); + } + _ => panic!("Expected AddField operation"), + } + } } diff --git a/flareon-cli/tests/migration_generator.rs b/flareon-cli/tests/migration_generator.rs index 43509e76..e5ba5a78 100644 --- a/flareon-cli/tests/migration_generator.rs +++ b/flareon-cli/tests/migration_generator.rs @@ -120,7 +120,7 @@ fn create_models_foreign_key_cycle() { } #[test] -fn create_models_foreign_two_migrations() { +fn create_models_foreign_key_two_migrations() { let mut generator = test_generator(); let src = include_str!("migration_generator/foreign_key_two_migrations/step_1.rs"); diff --git a/flareon/src/db.rs b/flareon/src/db.rs index 161085e8..35ddca49 100644 --- a/flareon/src/db.rs +++ b/flareon/src/db.rs @@ -1163,4 +1163,44 @@ mod tests { LimitedString::<5>::new("test").unwrap(), ); } + + #[test] + fn db_field_value_is_auto() { + let auto_value = DbFieldValue::Auto; + assert!(auto_value.is_auto()); + assert!(!auto_value.is_value()); + } + + #[test] + fn db_field_value_is_value() { + let value = DbFieldValue::Value(42.into()); + assert!(value.is_value()); + assert!(!value.is_auto()); + } + + #[test] + fn db_field_value_unwrap() { + let value = DbFieldValue::Value(42.into()); + assert_eq!(value.unwrap_value(), 42.into()); + } + + #[test] + #[should_panic(expected = "called DbValue::unwrap_value() on a wrong DbValue variant")] + fn db_field_value_unwrap_panic() { + let auto_value = DbFieldValue::Auto; + let _ = auto_value.unwrap_value(); + } + + #[test] + fn db_field_value_expect() { + let value = DbFieldValue::Value(42.into()); + assert_eq!(value.expect_value("expected a value"), 42.into()); + } + + #[test] + #[should_panic(expected = "expected a value")] + fn db_field_value_expect_panic() { + let auto_value = DbFieldValue::Auto; + let _ = auto_value.expect_value("expected a value"); + } } diff --git a/flareon/src/db/migrations.rs b/flareon/src/db/migrations.rs index 0c3422ce..f307dfd9 100644 --- a/flareon/src/db/migrations.rs +++ b/flareon/src/db/migrations.rs @@ -451,7 +451,7 @@ impl Field { } } -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] struct ForeignKeyReference { model: Identifier, field: Identifier, @@ -848,7 +848,7 @@ mod tests { } #[test] - fn test_field_new() { + fn field_new() { let field = Field::new(Identifier::new("id"), ColumnType::Integer) .primary_key() .auto() @@ -861,6 +861,26 @@ mod tests { assert!(field.null); } + #[test] + fn field_foreign_key() { + let field = Field::new(Identifier::new("parent"), ColumnType::Integer).foreign_key( + Identifier::new("testapp__parent"), + Identifier::new("id"), + ForeignKeyOnDeletePolicy::Restrict, + ForeignKeyOnUpdatePolicy::Restrict, + ); + + assert_eq!( + field.foreign_key, + Some(ForeignKeyReference { + model: Identifier::new("testapp__parent"), + field: Identifier::new("id"), + on_delete: ForeignKeyOnDeletePolicy::Restrict, + on_update: ForeignKeyOnUpdatePolicy::Restrict, + }) + ); + } + #[test] fn test_migration_wrapper() { let migration = MigrationWrapper::new(TestMigration); diff --git a/flareon/src/db/relations.rs b/flareon/src/db/relations.rs index 95e9e3c1..b634079f 100644 --- a/flareon/src/db/relations.rs +++ b/flareon/src/db/relations.rs @@ -17,6 +17,7 @@ pub enum ForeignKey { } impl ForeignKey { + /// Returns the primary key of the referenced model. pub fn primary_key(&self) -> &T::PrimaryKey { match self { Self::PrimaryKey(pk) => pk, @@ -24,6 +25,8 @@ impl ForeignKey { } } + /// Returns the model, if it has been stored in this [`ForeignKey`] + /// instance, or [`None`] otherwise. pub fn model(&self) -> Option<&T> { match self { Self::Model(model) => Some(model), @@ -31,6 +34,11 @@ impl ForeignKey { } } + /// Unwrap the foreign key, returning the model. + /// + /// # Panics + /// + /// Panics if the model has not been stored in this [`ForeignKey`] instance. pub fn unwrap(self) -> T { match self { Self::Model(model) => *model, @@ -39,6 +47,11 @@ impl ForeignKey { } /// Retrieve the model from the database, if needed, and return it. + /// + /// If the model has already been retrieved, this method will return it. + /// + /// This method will replace the primary key with the model instance if + /// the primary key is stored in this [`ForeignKey`] instance. pub async fn get(&mut self, db: &DB) -> Result<&T> { match self { Self::Model(model) => Ok(model), @@ -131,3 +144,69 @@ impl From for sea_query::ForeignKeyAction { } } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::db::{model, Auto}; + + #[derive(Debug, Clone, PartialEq)] + #[model] + struct TestModel { + id: Auto, + } + + #[test] + fn test_primary_key() { + let fk = ForeignKey::::PrimaryKey(Auto::fixed(1)); + + assert_eq!(fk.primary_key(), &Auto::fixed(1)); + } + + #[test] + fn test_model() { + let model = TestModel { id: Auto::fixed(1) }; + let fk = ForeignKey::Model(Box::new(model.clone())); + + assert_eq!(fk.model().unwrap(), &model); + assert_eq!(fk.primary_key(), &Auto::fixed(1)); + } + + #[test] + fn test_unwrap_model() { + let model = TestModel { id: Auto::fixed(1) }; + let fk = ForeignKey::Model(Box::new(model.clone())); + + assert_eq!(fk.unwrap(), model); + } + + #[should_panic(expected = "object has not been retrieved from the database")] + fn test_unwrap_primary_key() { + let fk = ForeignKey::::PrimaryKey(Auto::fixed(1)); + fk.unwrap(); + } + + #[test] + fn test_partial_eq() { + let fk1 = ForeignKey::::PrimaryKey(Auto::fixed(1)); + let fk2 = ForeignKey::::PrimaryKey(Auto::fixed(1)); + + assert_eq!(fk1, fk2); + } + + #[test] + fn test_from_model() { + let model = TestModel { id: Auto::fixed(1) }; + let fk: ForeignKey = ForeignKey::from(model.clone()); + + assert_eq!(fk.model().unwrap(), &model); + } + + #[test] + fn test_from_model_ref() { + let model = TestModel { id: Auto::fixed(1) }; + let fk: ForeignKey = ForeignKey::from(&model); + + assert_eq!(fk.primary_key(), &Auto::fixed(1)); + } +} From ba575c92abd05f25e92396843958978dbf2b749b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Ma=C4=87kowski?= Date: Thu, 9 Jan 2025 17:30:41 +0100 Subject: [PATCH 8/9] chore: go back to stable since 1.84 is here --- .github/workflows/rust.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 5e0123a0..7bd94ba3 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -23,7 +23,7 @@ jobs: if: github.event_name == 'push' || github.event_name == 'schedule' || github.event.pull_request.head.repo.full_name != github.repository strategy: matrix: - rust: [beta, nightly] + rust: [stable, nightly] os: [ubuntu-latest, macos-latest, windows-latest] name: Build & test @@ -70,7 +70,7 @@ jobs: - name: Install Rust toolchain uses: dtolnay/rust-toolchain@master with: - toolchain: beta + toolchain: stable components: clippy - name: Cache Cargo registry @@ -194,7 +194,7 @@ jobs: - name: Install Rust toolchain uses: dtolnay/rust-toolchain@master with: - toolchain: beta + toolchain: stable - name: Cache Cargo registry uses: Swatinem/rust-cache@v2 From 93116c2a7f4e977eb2be35f6ac24a0a659096b57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Ma=C4=87kowski?= Date: Fri, 10 Jan 2025 15:51:39 +0100 Subject: [PATCH 9/9] chore: address review comments --- flareon-cli/src/migration_generator.rs | 384 ++++++++---------- flareon-cli/tests/migration_generator.rs | 24 +- flareon-codegen/src/expr.rs | 25 ++ flareon-codegen/src/lib.rs | 16 - flareon-codegen/src/maybe_unknown.rs | 62 --- flareon-codegen/src/model.rs | 107 +++-- flareon-macros/src/model.rs | 2 +- .../src/auth/db/migrations/m_0001_initial.rs | 2 +- flareon/src/db.rs | 4 +- flareon/src/db/relations.rs | 3 +- flareon/tests/db.rs | 4 +- 11 files changed, 284 insertions(+), 349 deletions(-) delete mode 100644 flareon-codegen/src/maybe_unknown.rs diff --git a/flareon-cli/src/migration_generator.rs b/flareon-cli/src/migration_generator.rs index 00b68205..45b43979 100644 --- a/flareon-cli/src/migration_generator.rs +++ b/flareon-cli/src/migration_generator.rs @@ -11,7 +11,7 @@ use darling::FromMeta; use flareon::db::migrations::{DynMigration, MigrationEngine}; use flareon_codegen::model::{Field, Model, ModelArgs, ModelOpts, ModelType}; use flareon_codegen::symbol_resolver::SymbolResolver; -use petgraph::graph::{DiGraph, NodeIndex}; +use petgraph::graph::DiGraph; use petgraph::visit::EdgeRef; use proc_macro2::TokenStream; use quote::{format_ident, quote, ToTokens}; @@ -114,18 +114,8 @@ impl MigrationGenerator { let migration_name = migration_processor.next_migration_name()?; let dependencies = migration_processor.base_dependencies(); - let mut migration = GeneratedMigration { - migration_name, - modified_models, - dependencies, - operations, - }; - migration.remove_cycles(); - migration.toposort_operations(); - migration - .dependencies - .extend(migration.get_foreign_key_dependencies()); - + let migration = + GeneratedMigration::new(migration_name, modified_models, dependencies, operations); Ok(Some(migration)) } } @@ -323,10 +313,7 @@ impl MigrationGenerator { fn make_create_model_operation(app_model: &ModelInSource) -> DynOperation { DynOperation::CreateModel { table_name: app_model.model.table_name.clone(), - model_ty: app_model.model.resolved_ty.clone().expect( - "resolved_ty is expected to be present when \ - parsing the entire file with symbol resolver", - ), + model_ty: app_model.model.resolved_ty.clone(), fields: app_model.model.fields.clone(), } } @@ -388,10 +375,7 @@ impl MigrationGenerator { fn make_add_field_operation(app_model: &ModelInSource, field: &Field) -> DynOperation { DynOperation::AddField { table_name: app_model.model.table_name.clone(), - model_ty: app_model.model.resolved_ty.clone().expect( - "resolved_ty is expected to be present \ - when parsing the entire file with symbol resolver", - ), + model_ty: app_model.model.resolved_ty.clone(), field: field.clone(), } } @@ -656,7 +640,7 @@ impl ModelInSource { let input: syn::DeriveInput = item.clone().into(); let opts = ModelOpts::new_from_derive_input(&input) .map_err(|e| anyhow::anyhow!("cannot parse model: {}", e))?; - let model = opts.as_model(args, Some(symbol_resolver))?; + let model = opts.as_model(args, symbol_resolver)?; Ok(Self { model_item: item, @@ -676,11 +660,30 @@ pub struct GeneratedMigration { } impl GeneratedMigration { + #[must_use] + fn new( + migration_name: String, + modified_models: Vec, + mut dependencies: Vec, + mut operations: Vec, + ) -> Self { + Self::remove_cycles(&mut operations); + Self::toposort_operations(&mut operations); + dependencies.extend(Self::get_foreign_key_dependencies(&operations)); + + Self { + migration_name, + modified_models, + dependencies, + operations, + } + } + /// Get the list of [`DynDependency`] for all foreign keys that point /// to models that are **not** created in this migration. - fn get_foreign_key_dependencies(&self) -> Vec { - let create_ops = self.get_create_ops_map(); - let ops_adding_foreign_keys = self.get_ops_adding_foreign_keys(); + fn get_foreign_key_dependencies(operations: &[DynOperation]) -> Vec { + let create_ops = Self::get_create_ops_map(operations); + let ops_adding_foreign_keys = Self::get_ops_adding_foreign_keys(operations); let mut dependencies = Vec::new(); for (_index, dependency_ty) in &ops_adding_foreign_keys { @@ -699,27 +702,29 @@ impl GeneratedMigration { /// This method tries to minimize the number of operations added by /// calculating the minimum feedback arc set of the dependency graph. /// - /// This method modifies the `operations` field in place. + /// This method modifies the `operations` parameter in place. /// /// # See also /// /// * [`Self::remove_dependency`] - fn remove_cycles(&mut self) { - let graph = self.construct_dependency_graph(); + fn remove_cycles(operations: &mut Vec) { + let graph = Self::construct_dependency_graph(operations); let cycle_edges = petgraph::algo::feedback_arc_set::greedy_feedback_arc_set(&graph); for edge_id in cycle_edges { - let (from, to) = graph.edge_endpoints(edge_id.id()).unwrap(); + let (from, to) = graph + .edge_endpoints(edge_id.id()) + .expect("greedy_feedback_arc_set should always return valid edge refs"); - let to_op = self.operations[to.index()].clone(); - let from_op = &mut self.operations[from.index()]; + let to_op = operations[to.index()].clone(); + let from_op = &mut operations[from.index()]; debug!( "Removing cycle by removing operation {:?} that depends on {:?}", from_op, to_op ); let to_add = Self::remove_dependency(from_op, &to_op); - self.operations.extend(to_add); + operations.extend(to_add); } } @@ -787,8 +792,8 @@ impl GeneratedMigration { /// /// This method should be called after removing cycles; otherwise it will /// panic. - fn toposort_operations(&mut self) { - let graph = self.construct_dependency_graph(); + fn toposort_operations(operations: &mut [DynOperation]) { + let graph = Self::construct_dependency_graph(operations); let sorted = petgraph::algo::toposort(&graph, None) .expect("cycles shouldn't exist after removing them"); @@ -796,23 +801,23 @@ impl GeneratedMigration { .into_iter() .map(petgraph::graph::NodeIndex::index) .collect::>(); - flareon::__private::apply_permutation(&mut self.operations, &mut sorted); + flareon::__private::apply_permutation(operations, &mut sorted); } /// Construct a graph that represents reverse dependencies between - /// operations in this migration. + /// given operations. /// /// The graph is directed and has an edge from operation A to operation B /// if operation B creates a foreign key that points to a model created by /// operation A. #[must_use] - fn construct_dependency_graph(&mut self) -> DiGraph { - let create_ops = self.get_create_ops_map(); - let ops_adding_foreign_keys = self.get_ops_adding_foreign_keys(); + fn construct_dependency_graph(operations: &[DynOperation]) -> DiGraph { + let create_ops = Self::get_create_ops_map(operations); + let ops_adding_foreign_keys = Self::get_ops_adding_foreign_keys(operations); - let mut graph = DiGraph::with_capacity(self.operations.len(), 0); + let mut graph = DiGraph::with_capacity(operations.len(), 0); - for i in 0..self.operations.len() { + for i in 0..operations.len() { graph.add_node(i); } for (i, dependency_ty) in &ops_adding_foreign_keys { @@ -831,8 +836,8 @@ impl GeneratedMigration { /// Return a map of (resolved) model types to the index of the /// operation that creates given model. #[must_use] - fn get_create_ops_map(&self) -> HashMap { - self.operations + fn get_create_ops_map(operations: &[DynOperation]) -> HashMap { + operations .iter() .enumerate() .filter_map(|(i, op)| match op { @@ -845,8 +850,8 @@ impl GeneratedMigration { /// Return a list of operations that add foreign keys as tuples of /// operation index and the type of the model that foreign key points to. #[must_use] - fn get_ops_adding_foreign_keys(&self) -> Vec<(usize, syn::Type)> { - self.operations + fn get_ops_adding_foreign_keys(operations: &[DynOperation]) -> Vec<(usize, syn::Type)> { + operations .iter() .enumerate() .flat_map(|(i, op)| match op { @@ -907,19 +912,13 @@ impl Repr for Field { let mut tokens = quote! { ::flareon::db::migrations::Field::new(::flareon::db::Identifier::new(#column_name), <#ty as ::flareon::db::DatabaseField>::TYPE) }; - if self.auto_value.expect( - "auto_value is expected to be present \ - when parsing the entire file with symbol resolver", - ) { + if self.auto_value { tokens = quote! { #tokens.auto() } } if self.primary_key { tokens = quote! { #tokens.primary_key() } } - if let Some(fk_spec) = self.foreign_key.clone().expect( - "foreign_key is expected to be present \ - when parsing the entire file with symbol resolver", - ) { + if let Some(fk_spec) = self.foreign_key.clone() { let to_model = &fk_spec.to_model; tokens = quote! { @@ -1020,10 +1019,7 @@ fn is_field_foreign_key_to(field: &Field, ty: &syn::Type) -> bool { /// Returns the type of the model that the given field is a foreign key to. /// Returns [`None`] if the field is not a foreign key. fn foreign_key_for_field(field: &Field) -> Option { - match field.foreign_key.clone().expect( - "foreign_key is expected to be present \ - when parsing the entire file with symbol resolver", - ) { + match field.foreign_key.clone() { None => None, Some(foreign_key_spec) => Some(foreign_key_spec.to_model), } @@ -1098,7 +1094,6 @@ impl Error for ParsingError {} #[cfg(test)] mod tests { - use flareon_codegen::maybe_unknown::MaybeUnknown; use flareon_codegen::model::ForeignKeySpec; use super::*; @@ -1142,43 +1137,38 @@ mod tests { #[test] fn toposort_operations() { - let mut migration = GeneratedMigration { - migration_name: "test_migration".to_string(), - modified_models: vec![], - dependencies: vec![], - operations: vec![ - DynOperation::AddField { - table_name: "table2".to_string(), - model_ty: parse_quote!(Table2), - field: Field { - field_name: format_ident!("field1"), - column_name: "field1".to_string(), - ty: parse_quote!(i32), - auto_value: MaybeUnknown::Known(false), - primary_key: false, - unique: false, - foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { - to_model: parse_quote!(Table1), - })), - }, - }, - DynOperation::CreateModel { - table_name: "table1".to_string(), - model_ty: parse_quote!(Table1), - fields: vec![], + let mut operations = vec![ + DynOperation::AddField { + table_name: "table2".to_string(), + model_ty: parse_quote!(Table2), + field: Field { + field_name: format_ident!("field1"), + column_name: "field1".to_string(), + ty: parse_quote!(i32), + auto_value: false, + primary_key: false, + unique: false, + foreign_key: Some(ForeignKeySpec { + to_model: parse_quote!(Table1), + }), }, - ], - }; + }, + DynOperation::CreateModel { + table_name: "table1".to_string(), + model_ty: parse_quote!(Table1), + fields: vec![], + }, + ]; - migration.toposort_operations(); + GeneratedMigration::toposort_operations(&mut operations); - assert_eq!(migration.operations.len(), 2); - if let DynOperation::CreateModel { table_name, .. } = &migration.operations[0] { + assert_eq!(operations.len(), 2); + if let DynOperation::CreateModel { table_name, .. } = &operations[0] { assert_eq!(table_name, "table1"); } else { panic!("Expected CreateModel operation"); } - if let DynOperation::AddField { table_name, .. } = &migration.operations[1] { + if let DynOperation::AddField { table_name, .. } = &operations[1] { assert_eq!(table_name, "table2"); } else { panic!("Expected AddField operation"); @@ -1187,50 +1177,45 @@ mod tests { #[test] fn remove_cycles() { - let mut migration = GeneratedMigration { - migration_name: "test_migration".to_string(), - modified_models: vec![], - dependencies: vec![], - operations: vec![ - DynOperation::CreateModel { - table_name: "table1".to_string(), - model_ty: parse_quote!(Table1), - fields: vec![Field { - field_name: format_ident!("field1"), - column_name: "field1".to_string(), - ty: parse_quote!(ForeignKey), - auto_value: MaybeUnknown::Known(false), - primary_key: false, - unique: false, - foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { - to_model: parse_quote!(Table2), - })), - }], - }, - DynOperation::CreateModel { - table_name: "table2".to_string(), - model_ty: parse_quote!(Table2), - fields: vec![Field { - field_name: format_ident!("field1"), - column_name: "field1".to_string(), - ty: parse_quote!(ForeignKey), - auto_value: MaybeUnknown::Known(false), - primary_key: false, - unique: false, - foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { - to_model: parse_quote!(Table1), - })), - }], - }, - ], - }; + let mut operations = vec![ + DynOperation::CreateModel { + table_name: "table1".to_string(), + model_ty: parse_quote!(Table1), + fields: vec![Field { + field_name: format_ident!("field1"), + column_name: "field1".to_string(), + ty: parse_quote!(ForeignKey), + auto_value: false, + primary_key: false, + unique: false, + foreign_key: Some(ForeignKeySpec { + to_model: parse_quote!(Table2), + }), + }], + }, + DynOperation::CreateModel { + table_name: "table2".to_string(), + model_ty: parse_quote!(Table2), + fields: vec![Field { + field_name: format_ident!("field1"), + column_name: "field1".to_string(), + ty: parse_quote!(ForeignKey), + auto_value: false, + primary_key: false, + unique: false, + foreign_key: Some(ForeignKeySpec { + to_model: parse_quote!(Table1), + }), + }], + }, + ]; - migration.remove_cycles(); + GeneratedMigration::remove_cycles(&mut operations); - assert_eq!(migration.operations.len(), 3); + assert_eq!(operations.len(), 3); if let DynOperation::CreateModel { table_name, fields, .. - } = &migration.operations[0] + } = &operations[0] { assert_eq!(table_name, "table1"); assert!(!fields.is_empty()); @@ -1239,14 +1224,14 @@ mod tests { } if let DynOperation::CreateModel { table_name, fields, .. - } = &migration.operations[1] + } = &operations[1] { assert_eq!(table_name, "table2"); assert!(fields.is_empty()); } else { panic!("Expected CreateModel operation"); } - if let DynOperation::AddField { table_name, .. } = &migration.operations[2] { + if let DynOperation::AddField { table_name, .. } = &operations[2] { assert_eq!(table_name, "table2"); } else { panic!("Expected AddField operation"); @@ -1262,12 +1247,12 @@ mod tests { field_name: format_ident!("field1"), column_name: "field1".to_string(), ty: parse_quote!(ForeignKey), - auto_value: MaybeUnknown::Known(false), + auto_value: false, primary_key: false, unique: false, - foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { + foreign_key: Some(ForeignKeySpec { to_model: parse_quote!(Table2), - })), + }), }], }; @@ -1298,45 +1283,35 @@ mod tests { #[test] fn get_foreign_key_dependencies_no_foreign_keys() { - let migration = GeneratedMigration { - migration_name: "test_migration".to_string(), - modified_models: vec![], - dependencies: vec![], - operations: vec![DynOperation::CreateModel { - table_name: "table1".to_string(), - model_ty: parse_quote!(Table1), - fields: vec![], - }], - }; + let operations = vec![DynOperation::CreateModel { + table_name: "table1".to_string(), + model_ty: parse_quote!(Table1), + fields: vec![], + }]; - let external_dependencies = migration.get_foreign_key_dependencies(); + let external_dependencies = GeneratedMigration::get_foreign_key_dependencies(&operations); assert!(external_dependencies.is_empty()); } #[test] fn get_foreign_key_dependencies_with_foreign_keys() { - let migration = GeneratedMigration { - migration_name: "test_migration".to_string(), - modified_models: vec![], - dependencies: vec![], - operations: vec![DynOperation::CreateModel { - table_name: "table1".to_string(), - model_ty: parse_quote!(Table1), - fields: vec![Field { - field_name: format_ident!("field1"), - column_name: "field1".to_string(), - ty: parse_quote!(ForeignKey), - auto_value: MaybeUnknown::Known(false), - primary_key: false, - unique: false, - foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { - to_model: parse_quote!(crate::Table2), - })), - }], + let operations = vec![DynOperation::CreateModel { + table_name: "table1".to_string(), + model_ty: parse_quote!(Table1), + fields: vec![Field { + field_name: format_ident!("field1"), + column_name: "field1".to_string(), + ty: parse_quote!(ForeignKey), + auto_value: false, + primary_key: false, + unique: false, + foreign_key: Some(ForeignKeySpec { + to_model: parse_quote!(crate::Table2), + }), }], - }; + }]; - let external_dependencies = migration.get_foreign_key_dependencies(); + let external_dependencies = GeneratedMigration::get_foreign_key_dependencies(&operations); assert_eq!(external_dependencies.len(), 1); assert_eq!( external_dependencies[0], @@ -1348,45 +1323,40 @@ mod tests { #[test] fn get_foreign_key_dependencies_with_multiple_foreign_keys() { - let migration = GeneratedMigration { - migration_name: "test_migration".to_string(), - modified_models: vec![], - dependencies: vec![], - operations: vec![ - DynOperation::CreateModel { - table_name: "table1".to_string(), - model_ty: parse_quote!(Table1), - fields: vec![Field { - field_name: format_ident!("field1"), - column_name: "field1".to_string(), - ty: parse_quote!(ForeignKey), - auto_value: MaybeUnknown::Known(false), - primary_key: false, - unique: false, - foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { - to_model: parse_quote!(my_crate::Table2), - })), - }], - }, - DynOperation::CreateModel { - table_name: "table3".to_string(), - model_ty: parse_quote!(Table3), - fields: vec![Field { - field_name: format_ident!("field2"), - column_name: "field2".to_string(), - ty: parse_quote!(ForeignKey), - auto_value: MaybeUnknown::Known(false), - primary_key: false, - unique: false, - foreign_key: MaybeUnknown::Known(Some(ForeignKeySpec { - to_model: parse_quote!(crate::Table4), - })), - }], - }, - ], - }; + let operations = vec![ + DynOperation::CreateModel { + table_name: "table1".to_string(), + model_ty: parse_quote!(Table1), + fields: vec![Field { + field_name: format_ident!("field1"), + column_name: "field1".to_string(), + ty: parse_quote!(ForeignKey), + auto_value: false, + primary_key: false, + unique: false, + foreign_key: Some(ForeignKeySpec { + to_model: parse_quote!(my_crate::Table2), + }), + }], + }, + DynOperation::CreateModel { + table_name: "table3".to_string(), + model_ty: parse_quote!(Table3), + fields: vec![Field { + field_name: format_ident!("field2"), + column_name: "field2".to_string(), + ty: parse_quote!(ForeignKey), + auto_value: false, + primary_key: false, + unique: false, + foreign_key: Some(ForeignKeySpec { + to_model: parse_quote!(crate::Table4), + }), + }], + }, + ]; - let external_dependencies = migration.get_foreign_key_dependencies(); + let external_dependencies = GeneratedMigration::get_foreign_key_dependencies(&operations); assert_eq!(external_dependencies.len(), 2); assert!(external_dependencies.contains(&DynDependency::Model { model_type: parse_quote!(my_crate::Table2), @@ -1408,17 +1378,17 @@ mod tests { model: Model { name: format_ident!("TestModel"), original_name: "TestModel".to_string(), - resolved_ty: Some(parse_quote!(TestModel)), + resolved_ty: parse_quote!(TestModel), model_type: Default::default(), table_name: "test_model".to_string(), pk_field: Field { field_name: format_ident!("id"), column_name: "id".to_string(), ty: parse_quote!(i32), - auto_value: MaybeUnknown::Known(true), + auto_value: true, primary_key: true, unique: false, - foreign_key: MaybeUnknown::Known(None), + foreign_key: None, }, fields: vec![], }, @@ -1428,10 +1398,10 @@ mod tests { field_name: format_ident!("new_field"), column_name: "new_field".to_string(), ty: parse_quote!(i32), - auto_value: MaybeUnknown::Known(false), + auto_value: false, primary_key: false, unique: false, - foreign_key: MaybeUnknown::Known(None), + foreign_key: None, }; let operation = MigrationGenerator::make_add_field_operation(&app_model, &field); diff --git a/flareon-cli/tests/migration_generator.rs b/flareon-cli/tests/migration_generator.rs index e5ba5a78..85be342a 100644 --- a/flareon-cli/tests/migration_generator.rs +++ b/flareon-cli/tests/migration_generator.rs @@ -33,26 +33,26 @@ fn create_model_state_test() { let field = &fields[0]; assert_eq!(field.column_name, "id"); assert!(field.primary_key); - assert!(field.auto_value.unwrap()); - assert!(field.foreign_key.clone().unwrap().is_none()); + assert!(field.auto_value); + assert!(field.foreign_key.clone().is_none()); let field = &fields[1]; assert_eq!(field.column_name, "field_1"); assert!(!field.primary_key); - assert!(!field.auto_value.unwrap()); - assert!(field.foreign_key.clone().unwrap().is_none()); + assert!(!field.auto_value); + assert!(field.foreign_key.clone().is_none()); let field = &fields[2]; assert_eq!(field.column_name, "field_2"); assert!(!field.primary_key); - assert!(!field.auto_value.unwrap()); - assert!(field.foreign_key.clone().unwrap().is_none()); + assert!(!field.auto_value); + assert!(field.foreign_key.clone().is_none()); let field = &fields[3]; assert_eq!(field.column_name, "parent"); assert!(!field.primary_key); - assert!(!field.auto_value.unwrap()); - assert!(field.foreign_key.clone().unwrap().is_some()); + assert!(!field.auto_value); + assert!(field.foreign_key.clone().is_some()); } #[test] @@ -81,14 +81,14 @@ fn create_models_foreign_key() { let field = &fields[0]; assert_eq!(field.column_name, "id"); assert!(field.primary_key); - assert!(field.auto_value.unwrap()); - assert!(field.foreign_key.clone().unwrap().is_none()); + assert!(field.auto_value); + assert!(field.foreign_key.clone().is_none()); let field = &fields[1]; assert_eq!(field.column_name, "parent"); assert!(!field.primary_key); - assert!(!field.auto_value.unwrap()); - assert!(field.foreign_key.clone().unwrap().is_some()); + assert!(!field.auto_value); + assert!(field.foreign_key.clone().is_some()); } #[test] diff --git a/flareon-codegen/src/expr.rs b/flareon-codegen/src/expr.rs index 53187f38..b8f2846d 100644 --- a/flareon-codegen/src/expr.rs +++ b/flareon-codegen/src/expr.rs @@ -583,6 +583,20 @@ mod tests { assert_eq!(expected, unwrap_syn(Expr::parse(input))); } + #[test] + fn function_call_with_args() { + let input = quote! { $a == bar(42, "baz") }; + let expected = Expr::Eq( + Box::new(field("a")), + Box::new(Expr::FunctionCall { + function: Box::new(value("bar")), + args: vec![parse_quote!(42), parse_quote!("baz")], + }), + ); + + assert_eq!(expected, unwrap_syn(Expr::parse(input))); + } + #[test] fn parse_member_access() { let input = quote! { $a == foo.bar }; @@ -594,6 +608,17 @@ mod tests { assert_eq!(expected, unwrap_syn(Expr::parse(input))); } + #[test] + fn parse_member_access_multiple() { + let input = quote! { $a == foo.bar.baz }; + let expected = Expr::Eq( + Box::new(field("a")), + Box::new(member_access(member_access(value("foo"), "bar"), "baz")), + ); + + assert_eq!(expected, unwrap_syn(Expr::parse(input))); + } + #[test] fn parse_reference() { let input = quote! { &foo }; diff --git a/flareon-codegen/src/lib.rs b/flareon-codegen/src/lib.rs index 4e9c1cc0..6b9bcdd3 100644 --- a/flareon-codegen/src/lib.rs +++ b/flareon-codegen/src/lib.rs @@ -1,22 +1,6 @@ extern crate self as flareon_codegen; pub mod expr; -pub mod maybe_unknown; pub mod model; #[cfg(feature = "symbol-resolver")] pub mod symbol_resolver; - -#[cfg(not(feature = "symbol-resolver"))] -pub mod symbol_resolver { - /// Dummy `SymbolResolver` for use in contexts when it's not useful (e.g. - /// macros which do not have access to the entire source tree to look - /// for `use` statements anyway). - /// - /// This is defined as an empty enum so that it's entirely optimized out by - /// the compiler, along with all functions that reference it. - pub enum SymbolResolver {} - - impl SymbolResolver { - pub fn resolve(&self, _: &mut syn::Type) {} - } -} diff --git a/flareon-codegen/src/maybe_unknown.rs b/flareon-codegen/src/maybe_unknown.rs deleted file mode 100644 index 23784356..00000000 --- a/flareon-codegen/src/maybe_unknown.rs +++ /dev/null @@ -1,62 +0,0 @@ -/// Wraps a type whose value may or may not be possible to be determined using -/// the information available. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] -pub enum MaybeUnknown { - /// Indicates that this instance is determined to be a certain value - /// (possibly [`None`] if wrapping an [`Option`]). - Known(T), - /// Indicates that the value is unknown. - Unknown, -} - -impl MaybeUnknown { - pub fn unwrap(self) -> T { - self.expect("called `MaybeUnknown::unwrap()` on an `Unknown` value") - } - - pub fn expect(self, msg: &str) -> T { - match self { - MaybeUnknown::Known(value) => value, - MaybeUnknown::Unknown => { - panic!("{}", msg) - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn maybe_unknown_determined() { - let value = MaybeUnknown::Known(42); - assert_eq!(value.unwrap(), 42); - } - - #[test] - fn maybe_unknown_known_none() { - let value = MaybeUnknown::Known(None::<()>); - assert!(value.unwrap().is_none()); - } - - #[test] - #[should_panic(expected = "called `MaybeUnknown::unwrap()` on an `Unknown` value")] - fn maybe_unknown_unknown_unwrap() { - let value: MaybeUnknown = MaybeUnknown::Unknown; - assert_eq!(value.unwrap(), 42); - } - - #[test] - fn maybe_unknown_expect() { - let value = MaybeUnknown::Known(42); - assert_eq!(value.expect("value should be determined"), 42); - } - - #[test] - #[should_panic(expected = "value should be determined")] - fn maybe_unknown_unknown_expect() { - let value: MaybeUnknown = MaybeUnknown::Unknown; - value.expect("value should be determined"); - } -} diff --git a/flareon-codegen/src/model.rs b/flareon-codegen/src/model.rs index d820f2bc..66800567 100644 --- a/flareon-codegen/src/model.rs +++ b/flareon-codegen/src/model.rs @@ -2,7 +2,7 @@ use convert_case::{Case, Casing}; use darling::{FromDeriveInput, FromField, FromMeta}; use syn::spanned::Spanned; -use crate::maybe_unknown::MaybeUnknown; +#[cfg(feature = "symbol-resolver")] use crate::symbol_resolver::SymbolResolver; #[allow(clippy::module_name_repetitions)] @@ -66,12 +66,17 @@ impl ModelOpts { pub fn as_model( &self, args: &ModelArgs, - symbol_resolver: Option<&SymbolResolver>, + #[cfg(feature = "symbol-resolver")] symbol_resolver: &SymbolResolver, ) -> Result { + #[cfg(feature = "symbol-resolver")] + let as_field = |field: &&FieldOpts| field.as_field(symbol_resolver); + #[cfg(not(feature = "symbol-resolver"))] + let as_field = |field: &&FieldOpts| field.as_field(); + let fields = self .fields() .iter() - .map(|field| field.as_field(symbol_resolver)) + .map(as_field) .collect::, _>>()?; let mut original_name = self.ident.to_string(); @@ -94,21 +99,20 @@ impl ModelOpts { let primary_key_field = self.get_primary_key_field(&fields)?; - let ty = match symbol_resolver { - Some(symbol_resolver) => { - let mut ty = syn::Type::Path(syn::TypePath { - qself: None, - path: syn::Path::from(self.ident.clone()), - }); - symbol_resolver.resolve(&mut ty); - Some(ty) - } - None => None, + #[cfg(feature = "symbol-resolver")] + let ty = { + let mut ty = syn::Type::Path(syn::TypePath { + qself: None, + path: syn::Path::from(self.ident.clone()), + }); + symbol_resolver.resolve(&mut ty); + ty }; Ok(Model { name: self.ident.clone(), original_name, + #[cfg(feature = "symbol-resolver")] resolved_ty: ty, model_type: args.model_type, table_name, @@ -147,12 +151,14 @@ pub struct FieldOpts { } impl FieldOpts { + #[cfg(feature = "symbol-resolver")] fn find_type(&self, type_to_find: &str, symbol_resolver: &SymbolResolver) -> Option { let mut ty = self.ty.clone(); symbol_resolver.resolve(&mut ty); Self::find_type_resolved(&ty, type_to_find) } + #[cfg(feature = "symbol-resolver")] fn find_type_resolved(ty: &syn::Type, type_to_find: &str) -> Option { if let syn::Type::Path(type_path) = ty { let name = type_path @@ -179,6 +185,7 @@ impl FieldOpts { None } + #[cfg(feature = "symbol-resolver")] fn find_type_in_generics( arg: &syn::AngleBracketedGenericArguments, type_to_find: &str, @@ -201,29 +208,31 @@ impl FieldOpts { /// /// Panics if the field does not have an identifier (i.e. it is a tuple /// struct). - pub fn as_field(&self, symbol_resolver: Option<&SymbolResolver>) -> Result { + pub fn as_field( + &self, + #[cfg(feature = "symbol-resolver")] symbol_resolver: &SymbolResolver, + ) -> Result { let name = self.ident.as_ref().unwrap(); let column_name = name.to_string(); - let (auto_value, foreign_key) = match symbol_resolver { - Some(resolver) => ( - MaybeUnknown::Known(self.find_type("flareon::db::Auto", resolver).is_some()), - MaybeUnknown::Known( - self.find_type("flareon::db::ForeignKey", resolver) - .map(ForeignKeySpec::try_from) - .transpose()?, - ), - ), - None => (MaybeUnknown::Unknown, MaybeUnknown::Unknown), - }; + #[cfg(feature = "symbol-resolver")] + let (auto_value, foreign_key) = ( + self.find_type("flareon::db::Auto", symbol_resolver) + .is_some(), + self.find_type("flareon::db::ForeignKey", symbol_resolver) + .map(ForeignKeySpec::try_from) + .transpose()?, + ); let is_primary_key = column_name == "id" || self.primary_key.is_present(); Ok(Field { field_name: name.clone(), column_name, ty: self.ty.clone(), + #[cfg(feature = "symbol-resolver")] auto_value, primary_key: is_primary_key, + #[cfg(feature = "symbol-resolver")] foreign_key, unique: self.unique.is_present(), }) @@ -234,9 +243,9 @@ impl FieldOpts { pub struct Model { pub name: syn::Ident, pub original_name: String, - /// The type of the model, or [`None`] if the symbol resolver was not - /// enabled. - pub resolved_ty: Option, + /// The type of the model resolved by symbol resolver. + #[cfg(feature = "symbol-resolver")] + pub resolved_ty: syn::Type, pub model_type: ModelType, pub table_name: String, pub pk_field: Field, @@ -255,16 +264,14 @@ pub struct Field { pub field_name: syn::Ident, pub column_name: String, pub ty: syn::Type, - /// Whether the field is an auto field (e.g. `id`); - /// [`MaybeUnknown::Unknown`] if this `Field` instance was not resolved with - /// a [`SymbolResolver`]. - pub auto_value: MaybeUnknown, + /// Whether the field is an auto field (e.g. `id`). + #[cfg(feature = "symbol-resolver")] + pub auto_value: bool, pub primary_key: bool, - /// [`Some`] wrapped in [`MaybeUnknown::Known`] if this field is a - /// foreign key; [`None`] wrapped in [`MaybeUnknown::Known`] if this - /// field is determined not to be a foreign key; [`MaybeUnknown::Unknown`] - /// if this `Field` instance was not resolved with a [`SymbolResolver`]. - pub foreign_key: MaybeUnknown>, + /// [`Some`] if this field is a foreign key; [`None`] if this field is + /// determined not to be a foreign key. + #[cfg(feature = "symbol-resolver")] + pub foreign_key: Option, pub unique: bool, } @@ -355,6 +362,7 @@ mod tests { assert_eq!(fields[1].ident.as_ref().unwrap().to_string(), "name"); } + #[cfg(feature = "symbol-resolver")] #[test] fn model_opts_as_model() { let input: syn::DeriveInput = parse_quote! { @@ -365,13 +373,14 @@ mod tests { }; let opts = ModelOpts::new_from_derive_input(&input).unwrap(); let args = ModelArgs::default(); - let model = opts.as_model(&args, None).unwrap(); + let model = opts.as_model(&args, &SymbolResolver::new(vec![])).unwrap(); assert_eq!(model.name.to_string(), "TestModel"); assert_eq!(model.table_name, "test_model"); assert_eq!(model.fields.len(), 2); assert_eq!(model.field_count(), 2); } + #[cfg(feature = "symbol-resolver")] #[test] fn model_opts_as_model_migration() { let input: syn::DeriveInput = parse_quote! { @@ -383,13 +392,16 @@ mod tests { }; let opts = ModelOpts::new_from_derive_input(&input).unwrap(); let args = ModelArgs::from_meta(&input.attrs.first().unwrap().meta).unwrap(); - let err = opts.as_model(&args, None).unwrap_err(); + let err = opts + .as_model(&args, &SymbolResolver::new(vec![])) + .unwrap_err(); assert_eq!( err.to_string(), "migration model names must start with an underscore" ); } + #[cfg(feature = "symbol-resolver")] #[test] fn model_opts_as_model_pk_attr() { let input: syn::DeriveInput = parse_quote! { @@ -401,11 +413,12 @@ mod tests { }; let opts = ModelOpts::new_from_derive_input(&input).unwrap(); let args = ModelArgs::default(); - let model = opts.as_model(&args, None).unwrap(); + let model = opts.as_model(&args, &SymbolResolver::new(vec![])).unwrap(); assert_eq!(model.fields.len(), 1); assert!(model.fields[0].primary_key); } + #[cfg(feature = "symbol-resolver")] #[test] fn model_opts_as_model_no_pk() { let input: syn::DeriveInput = parse_quote! { @@ -416,7 +429,9 @@ mod tests { }; let opts = ModelOpts::new_from_derive_input(&input).unwrap(); let args = ModelArgs::default(); - let err = opts.as_model(&args, None).unwrap_err(); + let err = opts + .as_model(&args, &SymbolResolver::new(vec![])) + .unwrap_err(); assert_eq!( err.to_string(), "models must have a primary key field, either named `id` \ @@ -424,6 +439,7 @@ mod tests { ); } + #[cfg(feature = "symbol-resolver")] #[test] fn model_opts_as_model_multiple_pks() { let input: syn::DeriveInput = parse_quote! { @@ -437,13 +453,16 @@ mod tests { }; let opts = ModelOpts::new_from_derive_input(&input).unwrap(); let args = ModelArgs::default(); - let err = opts.as_model(&args, None).unwrap_err(); + let err = opts + .as_model(&args, &SymbolResolver::new(vec![])) + .unwrap_err(); assert_eq!( err.to_string(), "composite primary keys are not supported; only one primary key field is allowed" ); } + #[cfg(feature = "symbol-resolver")] #[test] fn field_opts_as_field() { let input: syn::Field = parse_quote! { @@ -451,13 +470,11 @@ mod tests { name: String }; let field_opts = FieldOpts::from_field(&input).unwrap(); - let field = field_opts.as_field(None).unwrap(); + let field = field_opts.as_field(&SymbolResolver::new(vec![])).unwrap(); assert_eq!(field.field_name.to_string(), "name"); assert_eq!(field.column_name, "name"); assert_eq!(field.ty, parse_quote!(String)); assert!(field.unique); - assert_eq!(field.auto_value, MaybeUnknown::Unknown); - assert_eq!(field.foreign_key, MaybeUnknown::Unknown); } #[test] diff --git a/flareon-macros/src/model.rs b/flareon-macros/src/model.rs index ce328290..288868da 100644 --- a/flareon-macros/src/model.rs +++ b/flareon-macros/src/model.rs @@ -27,7 +27,7 @@ pub(super) fn impl_model_for_struct( } }; - let model = match opts.as_model(&args, None) { + let model = match opts.as_model(&args) { Ok(val) => val, Err(err) => { return err.to_compile_error(); diff --git a/flareon/src/auth/db/migrations/m_0001_initial.rs b/flareon/src/auth/db/migrations/m_0001_initial.rs index 8809c085..821822de 100644 --- a/flareon/src/auth/db/migrations/m_0001_initial.rs +++ b/flareon/src/auth/db/migrations/m_0001_initial.rs @@ -3,7 +3,7 @@ #[derive(Debug, Copy, Clone)] pub(super) struct Migration; impl ::flareon::db::migrations::Migration for Migration { - const APP_NAME: &'static str = "flareon"; + const APP_NAME: &'static str = "flareon_auth"; const MIGRATION_NAME: &'static str = "m_0001_initial"; const DEPENDENCIES: &'static [::flareon::db::migrations::MigrationDependency] = &[]; const OPERATIONS: &'static [::flareon::db::migrations::Operation] = &[ diff --git a/flareon/src/db.rs b/flareon/src/db.rs index 35ddca49..bb52b718 100644 --- a/flareon/src/db.rs +++ b/flareon/src/db.rs @@ -358,12 +358,12 @@ impl DbFieldValue { } #[must_use] - pub fn unwrap_value(self) -> sea_query::Value { + pub fn unwrap_value(self) -> DbValue { self.expect_value("called DbValue::unwrap_value() on a wrong DbValue variant") } #[must_use] - pub fn expect_value(self, message: &str) -> sea_query::Value { + pub fn expect_value(self, message: &str) -> DbValue { match self { Self::Value(value) => value, Self::Auto => panic!("{message}"), diff --git a/flareon/src/db/relations.rs b/flareon/src/db/relations.rs index b634079f..e166b390 100644 --- a/flareon/src/db/relations.rs +++ b/flareon/src/db/relations.rs @@ -181,7 +181,8 @@ mod tests { } #[should_panic(expected = "object has not been retrieved from the database")] - fn test_unwrap_primary_key() { + #[test] + fn unwrap_primary_key() { let fk = ForeignKey::::PrimaryKey(Auto::fixed(1)); fk.unwrap(); } diff --git a/flareon/tests/db.rs b/flareon/tests/db.rs index 334354bc..7ce9785d 100644 --- a/flareon/tests/db.rs +++ b/flareon/tests/db.rs @@ -330,7 +330,7 @@ async fn foreign_keys_option(db: &mut TestDatabase) { CREATE_PARENT.forwards(db).await.unwrap(); CREATE_CHILD.forwards(db).await.unwrap(); - // no parent + // Test child with `None` parent let mut child = Child { id: Auto::auto(), parent: None, @@ -342,7 +342,7 @@ async fn foreign_keys_option(db: &mut TestDatabase) { query!(Child, $id == child.id).delete(&**db).await.unwrap(); - // with parent + // Test child with `Some` parent let mut parent = Parent { id: Auto::auto() }; parent.save(&**db).await.unwrap();