diff --git a/.travis.yml b/.travis.yml index 87cc8bc..89228d5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,6 +1,6 @@ language: crystal script: - - crystal spec + - crystal spec -D $DB_TYPE - crystal tool format --check services: - mysql @@ -18,11 +18,9 @@ before_script: - psql -c 'create database crecto_test;' -U postgres - psql $PG_URL < spec/migrations/pg_migrations.sql - sqlite3 ./crecto_test.db < spec/migrations/sqlite3_migrations.sql - - if [ ! -z "$PG_URL" ]; then cp ./spec/travis_pg_repo.cr ./spec/repo.cr; fi - - if [ ! -z "$MYSQL_URL" ]; then cp ./spec/travis_mysql_repo.cr ./spec/repo.cr; fi - - if [ ! -z "$SQLITE3_PATH" ]; then cp ./spec/travis_sqlite_repo.cr ./spec/repo.cr; fi + - cp ./spec/travis_${DB_TYPE}_repo.cr ./spec/repo.cr env: matrix: - - PG_URL=postgres://postgres@localhost:5432/crecto_test - - MYSQL_URL=mysql://root@localhost/crecto_test - - SQLITE3_PATH=sqlite3://./crecto_test.db + - PG_URL=postgres://postgres@localhost:5432/crecto_test DB_TYPE=pg + - MYSQL_URL=mysql://root@localhost/crecto_test DB_TYPE=mysql + - SQLITE3_PATH=sqlite3://./crecto_test.db DB_TYPE=sqlite diff --git a/spec/transactions_spec.cr b/spec/transactions_spec.cr index 95d7bbf..a7167de 100644 --- a/spec/transactions_spec.cr +++ b/spec/transactions_spec.cr @@ -183,8 +183,8 @@ describe Crecto do user = User.new expect_raises Crecto::InvalidChangeset do - Repo.transaction! do - Repo.insert!(user) + Repo.transaction! do |tx| + tx.insert!(user) end end end @@ -193,8 +193,8 @@ describe Crecto do user = User.new user.name = "this should insert in the transaction" - Repo.transaction! do - Repo.insert(user) + Repo.transaction! do |tx| + tx.insert(user) end users = Repo.all(User, Query.where(name: "this should insert in the transaction")) @@ -207,8 +207,8 @@ describe Crecto do user = quick_create_user("this should delete") - Repo.transaction! do - Repo.delete!(user) + Repo.transaction! do |tx| + tx.delete!(user) end users = Repo.all(User, Query.where(id: user.id)) @@ -222,8 +222,8 @@ describe Crecto do Repo.delete_all(Post) - Repo.transaction! do - Repo.delete_all(User) + Repo.transaction! do |tx| + tx.delete_all(User) end users = Repo.all(User) @@ -235,8 +235,8 @@ describe Crecto do user.name = "this should have changed 89ffsf" - Repo.transaction! do - Repo.update(user) + Repo.transaction! do |tx| + tx.update(user) end user = Repo.get!(User, user.id) @@ -248,8 +248,8 @@ describe Crecto do quick_create_user_with_things("testing_update_all", 123) quick_create_user_with_things("testing_update_all", 123) - Repo.transaction! do - Repo.update_all(User, Query.where(name: "testing_update_all"), {things: 9494}) + Repo.transaction! do |tx| + tx.update_all(User, Query.where(name: "testing_update_all"), {things: 9494}) end Repo.all(User, Query.where(things: 123)).size.should eq 0 @@ -268,12 +268,12 @@ describe Crecto do insert_user = User.new insert_user.name = "all_transactions_insert_user" - Repo.transaction! do - Repo.insert!(insert_user) - Repo.delete!(delete_user) - Repo.delete_all(Post) - Repo.update!(update_user) - Repo.update_all(User, Query.where(name: "perform_all"), {name: "perform_all_io2oj999"}) + Repo.transaction! do |tx| + tx.insert!(insert_user) + tx.delete!(delete_user) + tx.delete_all(Post) + tx.update!(update_user) + tx.update_all(User, Query.where(name: "perform_all"), {name: "perform_all_io2oj999"}) end # check insert happened @@ -336,6 +336,60 @@ describe Crecto do Repo.all(User, Query.where(name: "perform_all")).size.should eq 2 Repo.all(User, Query.where(name: "perform_all_io2oj999")).size.should eq 0 end + + # This only works for postgres for now + {% begin %} + {{ flag?(:pg) ? :it.id : :pending.id }} "allows reading records inserted inside the transaction" do + insert_user = User.new + insert_user.name = "insert_user" + + Repo.transaction! do |tx| + id = tx.insert!(insert_user).instance.id + tx.get(User, id).should_not eq(nil) + tx.get!(User, id).should_not eq(nil) + tx.get(User, id, Query.new).should_not eq(nil) + tx.get!(User, id, Query.new).should_not eq(nil) + tx.get_by(User, id: id).should_not eq(nil) + tx.get_by!(User, id: id).should_not eq(nil) + tx.get_by(User, id: id).should_not eq(nil) + tx.get_by!(User, id: id).should_not eq(nil) + tx.get_by(User, Query.where(id: id)).should_not eq(nil) + tx.get_by!(User, Query.where(id: id)).should_not eq(nil) + tx.all(User, Query.where(id: id)).first.should_not eq(nil) + tx.all(User, Query.where(id: id), preload: [] of Symbol).first.should_not eq(nil) + end + end + {% end %} + + # Sqlite doesn't support nesting transactions + {% unless flag?(:sqlite) %} + it "allows nesting transactions" do + Repo.delete_all(Post) + Repo.delete_all(User) + + insert_user = User.new + insert_user.name = "nested_transactions_insert_user" + invalid_user = User.new + delete_user = quick_create_user("nested_transactions_delete_user") + + Repo.transaction! do |tx| + tx.insert!(insert_user) + + expect_raises Crecto::InvalidChangeset do + Repo.transaction! do |inner_tx| + inner_tx.delete!(delete_user) + inner_tx.insert!(invalid_user) + end + end + end + + # check insert happened + Repo.all(User, Query.where(name: "nested_transactions_insert_user")).size.should eq 1 + + # check delete didn't happen + Repo.all(User, Query.where(name: "nested_transactions_delete_user")).size.should eq 1 + end + {% end %} end end end diff --git a/src/crecto/live_transaction.cr b/src/crecto/live_transaction.cr index f40aff3..217bb7c 100644 --- a/src/crecto/live_transaction.cr +++ b/src/crecto/live_transaction.cr @@ -1,13 +1,89 @@ -require "./multi" +require "./repo/query" module Crecto class LiveTransaction(T) + alias Query = Repo::Query + def initialize(@tx : DB::Transaction, @repo : T) end + def raw_exec(args : Array) + @repo.raw_exec(args, tx: @tx) + end + + def raw_exec(*args) + @repo.raw_exec(*args, tx: @tx) + end + + def raw_query(query, *args) + @repo.raw_query(query, *args, tx: @tx) do |rs| + yield rs + end + end + + def raw_query(query, args : Array) + @repo.raw_query(query, args, tx: @tx) + end + + def raw_query(query, *args) + @repo.raw_query(query, *args) + end + + def raw_scalar(*args) + @repo.raw_scalar(*args, tx: @tx) + end + + def all(queryable, query : Query, *, preload = [] of Symbol) + @repo.all(queryable, query, tx: @tx, preload: preload) + end + + def all(queryable, query = Query.new) + @repo.all(queryable, query, tx: @tx) + end + + def get(queryable, id) + @repo.get(queryable, id, tx: @tx) + end + + def get!(queryable, id) + @repo.get!(queryable, id, tx: @tx) + end + + def get(queryable, id, query : Query) + @repo.get(queryable, id, query, tx: @tx) + end + + def get!(queryable, id, query : Query) + @repo.get!(queryable, id, query, tx: @tx) + end + + def get_by(queryable, **opts) + @repo.get_by(queryable, @tx, **opts) + end + + def get_by(queryable, query) + @repo.get_by(queryable, query, tx: @tx) + end + + def get_by!(queryable, **opts) + @repo.get_by!(queryable, @tx, **opts) + end + + def get_by!(queryable, query) + @repo.get_by!(queryable, query, tx: @tx) + end + + def get_association(queryable_instance, association_name : Symbol) + @repo.get_association(queryable_instance, association_name, tx: @tx) + end + + def get_association!(queryable_instance, association_name : Symbol) + @repo.get_association!(queryable_instance, association_name, tx: @tx) + end + {% for type in %w[insert insert! delete delete! update update!] %} def {{type.id}}(queryable : Crecto::Model) - @repo.{{type.id}}(queryable, @tx) + @repo.{{type.id}}(queryable, tx: @tx) end def {{type.id}}(changeset : Crecto::Changeset::Changeset) @@ -16,15 +92,29 @@ module Crecto {% end %} def delete_all(queryable, query = Crecto::Repo::Query.new) - @repo.delete_all(queryable, query, @tx) + @repo.delete_all(queryable, query, tx: @tx) end def update_all(queryable, query, update_hash : Multi::UpdateHash) - @repo.update_all(queryable, query, update_hash, @tx) + @repo.update_all(queryable, query, update_hash, tx: @tx) end def update_all(queryable, query, update_tuple : NamedTuple) update_all(queryable, query, update_tuple.to_h) end + + def aggregate(queryable, aggregate_function : Symbol, field : Symbol) + @repo.aggregate(queryable, aggregate_function, field, tx: @tx) + end + + def aggregate(queryable, aggregate_function : Symbol, field : Symbol, query : Crecto::Repo::Query) + @repo.aggregate(queryable, aggregate_function, field, query, tx: @tx) + end + + def transaction! + @repo.transaction!(@tx) do |tx| + yield tx + end + end end end diff --git a/src/crecto/repo.cr b/src/crecto/repo.cr index 8b6639c..2282405 100644 --- a/src/crecto/repo.cr +++ b/src/crecto/repo.cr @@ -28,81 +28,69 @@ module Crecto end # Run a raw `exec` query directly on the adapter connection - def raw_exec(args : Array) - config.get_connection.exec(args) + def raw_exec(args : Array, tx : DB::Transaction? = nil) + (tx || config.get_connection).exec(args) end # Run a raw `exec` query directly on the adapter connection - def raw_exec(*args) - config.get_connection.exec(*args) + def raw_exec(*args, tx : DB::Transaction? = nil) + (tx || config.get_connection).exec(*args) end # Run a raw `query` query directly on the adapter connection - def raw_query(query, *args) - config.get_connection.query(query, *args) do |rs| + def raw_query(query, *args, tx : DB::Transaction? = nil) + (tx || config.get_connection).query(query, *args) do |rs| yield rs end end # Run a raw `query` query directly on the adapter connection - def raw_query(query, args : Array) - config.get_connection.query(args) + def raw_query(query, args : Array, tx : DB::Transaction? = nil) + (tx || config.get_connection).query(args) end # Run a raw `query` query directly on the adapter connection - def raw_query(query, *args) - config.get_connection.query(*args) + def raw_query(query, *args, tx : DB::Transaction? = nil) + (tx || config.get_connection).query(*args) end # Run a raw `scalar` query directly on the adapter connection - def raw_scalar(*args) - config.get_connection.scalar(*args) + def raw_scalar(*args, tx : DB::Transaction? = nil) + (tx || config.get_connection).scalar(*args) end - # Return a list of *queryable* instances using *query* + # Return a list of *queryable* instances (optionally) using *query* # # ``` # query = Query.where(name: "fred") # users = Repo.all(User, query) # ``` - def all(queryable, query : Query? = Query.new, **opts) - q = config.adapter.run(config.get_connection, :all, queryable, query).as(DB::ResultSet) - - results = queryable.from_rs(q.as(DB::ResultSet)) - - opt_preloads = opts.fetch(:preload, [] of Symbol) - preloads = query.preloads + opt_preloads.map { |a| {symbol: a, query: nil} } - if preloads.any? - add_preloads(results, queryable, preloads) - end - - results - end - - # Returns a list of *queryable* instances. Accepts an optional `query` # # ``` - # users = Crecto::Repo.all(User) + # users = Repo.all(User) # ``` - def all(queryable, query = Query.new) - q = config.adapter.run(config.get_connection, :all, queryable, query).as(DB::ResultSet) - results = queryable.from_rs(q) + def all(queryable, query : Query? = Query.new, *, tx : DB::Transaction? = nil, preload = [] of Symbol) + q = config.adapter.run(tx || config.get_connection, :all, queryable, query).as(DB::ResultSet) + + results = queryable.from_rs(q.as(DB::ResultSet)) - preloads = query.preloads - if preloads.any? - add_preloads(results, queryable, preloads) + query_preloads = query.try(&.preloads) || [] of NamedTuple(symbol: Symbol, query: Query?) + combined_preloads = query_preloads + preload.map { |a| {symbol: a, query: nil} } + + if combined_preloads.any? + add_preloads(results, queryable, combined_preloads, tx) end - results - end + results + end # Return a single nilable insance of *queryable* by primary key with *id*. # # ``` # user = Repo.get(User, 1) # ``` - def get(queryable, id) - q = config.adapter.run(config.get_connection, :get, queryable, id).as(DB::ResultSet) + def get(queryable, id, *, tx : DB::Transaction? = nil) + q = config.adapter.run(tx || config.get_connection, :get, queryable, id).as(DB::ResultSet) results = queryable.from_rs(q) results.first if results.any? end @@ -113,8 +101,8 @@ module Crecto # ``` # user = Repo.get(User, 1) # ``` - def get!(queryable, id) - if result = get(queryable, id) + def get!(queryable, id, *, tx : DB::Transaction? = nil) + if result = get(queryable, id, tx: tx) result else raise NoResults.new("No Results") @@ -128,13 +116,13 @@ module Crecto # query = Query.preload(:posts) # user = Repo.get(User, 1, query) # ``` - def get(queryable, id, query : Query) - q = config.adapter.run(config.get_connection, :get, queryable, id).as(DB::ResultSet) + def get(queryable, id, query : Query, *, tx : DB::Transaction? = nil) + q = config.adapter.run(tx || config.get_connection, :get, queryable, id).as(DB::ResultSet) results = queryable.from_rs(q) if results.any? if query.preloads.any? - add_preloads(results, queryable, query.preloads) + add_preloads(results, queryable, query.preloads, tx) end results.first @@ -149,8 +137,8 @@ module Crecto # query = Query.preload(:posts) # user = Repo.get(User, 1, query) # ``` - def get!(queryable, id, query : Query) - if result = get(queryable, id, query) + def get!(queryable, id, query : Query, *, tx : DB::Transaction? = nil) + if result = get(queryable, id, query, tx: tx) result else raise NoResults.new("No Results") @@ -162,8 +150,8 @@ module Crecto # ``` # user = Repo.get_by(User, name: "fred", age: 21) # ``` - def get_by(queryable, **opts) - get_by(queryable, Query.where(**opts)) + def get_by(queryable, tx : DB::Transaction? = nil, **opts) + get_by(queryable, Query.where(**opts), tx: tx) end # Return a single nilable instance of *queryable* using the *query* param @@ -172,8 +160,8 @@ module Crecto # ``` # user = Repo.get_by(User, Query.where(name: "fred", age: 21)) # ``` - def get_by(queryable, query) - results = all(queryable, query.limit(1)) + def get_by(queryable, query, *, tx : DB::Transaction? = nil) + results = all(queryable, query.limit(1), tx: tx) results.first if results.any? end @@ -183,8 +171,8 @@ module Crecto # ``` # user = Repo.get_by(User, name: "fred", age: 21) # ``` - def get_by!(queryable, **opts) - get_by!(queryable, Query.where(**opts)) + def get_by!(queryable, tx : DB::Transaction? = nil, **opts) + get_by!(queryable, Query.where(**opts), tx: tx) end # Return a single instance of *queryable* using the *query* param @@ -193,8 +181,8 @@ module Crecto # ``` # user = Repo.get_by(User, Query.where(name: "fred", age: 21)) # ``` - def get_by!(queryable, query) - if result = get_by(queryable, query) + def get_by!(queryable, query, *, tx : DB::Transaction? = nil) + if result = get_by(queryable, query, tx: tx) result else raise NoResults.new("No Results") @@ -207,14 +195,14 @@ module Crecto # user = Crecto::Repo.get(User, 1) # post = Repo.get_association(user, :post) # ``` - def get_association(queryable_instance, association_name : Symbol, query : Query = Query.new) + def get_association(queryable_instance, association_name : Symbol, query : Query = Query.new, *, tx : DB::Transaction? = nil) case queryable_instance.class.association_type_for_association(association_name) when :has_many - get_has_many_association(queryable_instance, association_name, query) + get_has_many_association(queryable_instance, association_name, query, tx) when :has_one - get_has_one_association(queryable_instance, association_name, query) + get_has_one_association(queryable_instance, association_name, query, tx) when :belongs_to - get_belongs_to_association(queryable_instance, association_name, query) + get_belongs_to_association(queryable_instance, association_name, query, tx) else raise Exception.new("invalid operation passed to get_association") end @@ -228,8 +216,8 @@ module Crecto # user = Crecto::Repo.get(User, 1) # post = Repo.get_association!(user, :post) # ``` - def get_association!(queryable_instance, association_name : Symbol, query : Query = Query.new) - if result = get_association(queryable_instance, association_name, query) + def get_association!(queryable_instance, association_name : Symbol, query : Query = Query.new, *, tx : DB::Transaction? = nil) + if result = get_association(queryable_instance, association_name, query, tx: tx) result else raise NoResults.new("No Results") @@ -242,7 +230,7 @@ module Crecto # user = User.new # Repo.insert(user) # ``` - def insert(queryable_instance, tx : DB::Transaction?) + def insert(queryable_instance, *, tx : DB::Transaction? = nil) changeset = queryable_instance.class.changeset(queryable_instance) return changeset unless changeset.valid? @@ -271,10 +259,6 @@ module Crecto changeset end - def insert(queryable_instance) - insert(queryable_instance, nil) - end - # Insert a changeset instance into the data store. # # ``` @@ -282,8 +266,8 @@ module Crecto # changeset = User.changeset(user) # Repo.insert(changeset) # ``` - def insert(changeset : Crecto::Changeset::Changeset) - insert(changeset.instance) + def insert(changeset : Crecto::Changeset::Changeset, *, tx : DB::Transaction? = nil) + insert(changeset.instance, tx: tx) end # Insert a schema instance into the data store or raise if the resulting @@ -293,8 +277,8 @@ module Crecto # user = User.new # Repo.insert!(user) # ``` - def insert!(queryable_instance, tx : DB::Transaction? = nil) - insert(queryable_instance, tx).tap do |changeset| + def insert!(queryable_instance, *, tx : DB::Transaction? = nil) + insert(queryable_instance, tx: tx).tap do |changeset| raise InvalidChangeset.new(changeset) unless changeset.valid? end end @@ -307,8 +291,8 @@ module Crecto # changeset = User.changeset(user) # Repo.insert!(changeset) # ``` - def insert!(changeset : Crecto::Changeset::Changeset) - insert!(changeset.instance) + def insert!(changeset : Crecto::Changeset::Changeset, *, tx : DB::Transaction? = nil) + insert!(changeset.instance, tx: tx) end # Update a shema instance in the data store. @@ -316,7 +300,7 @@ module Crecto # ``` # Repo.update(user) # ``` - def update(queryable_instance, tx : DB::Transaction?) + def update(queryable_instance, *, tx : DB::Transaction? = nil) changeset = queryable_instance.class.changeset(queryable_instance) return changeset unless changeset.valid? @@ -339,17 +323,13 @@ module Crecto changeset end - def update(queryable_instance) - update(queryable_instance, nil) - end - # Update a changeset instance in the data store. # # ``` # Repo.update(changeset) # ``` - def update(changeset : Crecto::Changeset::Changeset) - update(changeset.instance) + def update(changeset : Crecto::Changeset::Changeset, *, tx : DB::Transaction? = nil) + update(changeset.instance, tx: tx) end # Update a schema instance in the data store or raise if the resulting @@ -358,8 +338,8 @@ module Crecto # ``` # Repo.update!(user) # ``` - def update!(queryable_instance, tx : DB::Transaction? = nil) - update(queryable_instance, tx).tap do |changeset| + def update!(queryable_instance, *, tx : DB::Transaction? = nil) + update(queryable_instance, tx: tx).tap do |changeset| raise InvalidChangeset.new(changeset) unless changeset.valid? end end @@ -370,8 +350,8 @@ module Crecto # ``` # Repo.update(changeset) # ``` - def update!(changeset : Crecto::Changeset::Changeset) - update!(changeset.instance) + def update!(changeset : Crecto::Changeset::Changeset, *, tx : DB::Transaction? = nil) + update!(changeset.instance, tx: tx) end # Update multipile records with a single query @@ -380,20 +360,12 @@ module Crecto # query = Crecto::Repo::Query.where(name: "Ted", count: 0) # Repo.update_all(User, query, {count: 1, date: Time.local}) # ``` - def update_all(queryable, query, update_hash : Hash, tx : DB::Transaction?) + def update_all(queryable, query, update_hash : Hash, *, tx : DB::Transaction? = nil) config.adapter.run(tx || config.get_connection, :update_all, queryable, query, update_hash) end - def update_all(queryable, query, update_hash : Hash) - update_all(queryable, query, update_hash, nil) - end - - def update_all(queryable, query, update_hash : NamedTuple, tx : DB::Transaction?) - update_all(queryable, query, update_hash.to_h, tx) - end - - def update_all(queryable, query, update_hash : NamedTuple) - update_all(queryable, query, update_hash, nil) + def update_all(queryable, query, update_hash : NamedTuple, *, tx : DB::Transaction? = nil) + update_all(queryable, query, update_hash.to_h, tx: tx) end # Delete a shema instance from the data store. @@ -401,7 +373,7 @@ module Crecto # ``` # Repo.delete(user) # ``` - def delete(queryable_instance, tx : DB::Transaction?) + def delete(queryable_instance, *, tx : DB::Transaction? = nil) changeset = queryable_instance.class.changeset(queryable_instance) return changeset unless changeset.valid? @@ -418,17 +390,13 @@ module Crecto changeset end - def delete(queryable_instance) - delete(queryable_instance, nil) - end - # Delete a changeset instance from the data store. # # ``` # Repo.delete(changeset) # ``` - def delete(changeset : Crecto::Changeset::Changeset) - delete(changeset.instance) + def delete(changeset : Crecto::Changeset::Changeset, *, tx : DB::Transaction? = nil) + delete(changeset.instance, tx: tx) end # Delete a schema instance from the data store or raise if the resulting @@ -437,8 +405,8 @@ module Crecto # ``` # Repo.delete!(user) # ``` - def delete!(queryable_instance, tx : DB::Transaction? = nil) - delete(queryable_instance, tx).tap do |changeset| + def delete!(queryable_instance, *, tx : DB::Transaction? = nil) + delete(queryable_instance, tx: tx).tap do |changeset| raise InvalidChangeset.new(changeset) unless changeset.valid? end end @@ -459,19 +427,14 @@ module Crecto # query = Crecto::Repo::Query.where(name: "Fred") # Repo.delete_all(User, query) # ``` - def delete_all(queryable, query : Query?, tx : DB::Transaction?) - query = Query.new if query.nil? - check_dependents(queryable, query, tx) + def delete_all(queryable, query : Query = Query.new, *, tx : DB::Transaction? = nil) + check_dependents(queryable, query, tx: tx) result = config.adapter.run(tx || config.get_connection, :delete_all, queryable, query) if tx.nil? && config.adapter == Crecto::Adapters::Postgres result.as(DB::ResultSet).close if result.is_a?(DB::ResultSet) end end - def delete_all(queryable, query = Query.new) - delete_all(queryable, query, nil) - end - # Run aribtrary sql queries. `query` will cast the output as that # object. In this example, `query` will try to cast the # output as `User`. If query results happen to error nil is @@ -480,8 +443,8 @@ module Crecto # ``` # Repo.query(User, "select * from users where id > ?", [30]) # ``` - def query(queryable, sql : String, params = [] of DbValue) : Array - q = config.adapter.run(config.get_connection, :sql, sql, params).as(DB::ResultSet) + def query(queryable, sql : String, params = [] of DbValue, *, tx : DB::Transaction? = nil) : Array + q = config.adapter.run(tx || config.get_connection, :sql, sql, params).as(DB::ResultSet) results = queryable.from_rs(q) results end @@ -496,8 +459,8 @@ module Crecto # ``` # query = Crecto::Repo.query("select * from users where id = ?", [30]) # ``` - def query(sql : String, params = [] of DbValue) : DB::ResultSet - config.adapter.run(config.get_connection, :sql, sql, params).as(DB::ResultSet) + def query(sql : String, params = [] of DbValue, *, tx : DB::Transaction? = nil) : DB::ResultSet + config.adapter.run(tx || config.get_connection, :sql, sql, params).as(DB::ResultSet) end def transaction(multi : Crecto::Multi) @@ -529,8 +492,8 @@ module Crecto # tx.insert!(post) # end # ``` - def transaction! - config.get_connection.transaction do |tx| + def transaction!(outer_tx : DB::Transaction? = nil) + (outer_tx || config.get_connection).transaction do |tx| begin yield LiveTransaction.new(tx, self) rescue error : Exception @@ -541,38 +504,38 @@ module Crecto end {% for operation in %w[insert update delete] %} - private def run_operation(operation : Multi::{{operation.camelcase.id}}, tx) - cs = {{operation.id}}(operation.instance, tx) + private def run_operation(operation : Multi::{{operation.camelcase.id}}, tx : DB::Transaction?) + cs = {{operation.id}}(operation.instance, tx: tx) raise cs.errors.first[:message] if !cs.valid? rescue ex : Exception raise OperationError.new(ex, operation.instance.class, {{operation}}) end {% end %} - private def run_operation(operation : Multi::UpdateAll, tx) - update_all(operation.queryable, operation.query, operation.update_hash, tx) + private def run_operation(operation : Multi::UpdateAll, tx : DB::Transaction?) + update_all(operation.queryable, operation.query, operation.update_hash, tx: tx) rescue ex : Exception raise OperationError.new(ex, operation.queryable, "update_all") end - private def run_operation(operation : Multi::DeleteAll, tx) - delete_all(operation.queryable, operation.query, tx) + private def run_operation(operation : Multi::DeleteAll, tx : DB::Transaction?) + delete_all(operation.queryable, operation.query, tx: tx) rescue ex : Exception raise OperationError.new(ex, operation.queryable, "delete_all") end # Calculate the given aggregate `aggregate_function` over the given `field` # Aggregate `aggregate_function` must be one of (:avg, :count, :max, :min:, :sum) - def aggregate(queryable, aggregate_function : Symbol, field : Symbol) + def aggregate(queryable, aggregate_function : Symbol, field : Symbol, *, tx : DB::Transaction? = nil) raise InvalidOption.new("Aggregate must be one of :avg, :count, :max, :min:, :sum") unless [:avg, :count, :max, :min, :sum].includes?(aggregate_function) - config.adapter.aggregate(config.get_connection, queryable, aggregate_function, field) + config.adapter.aggregate(tx || config.get_connection, queryable, aggregate_function, field) end - def aggregate(queryable, aggregate_function : Symbol, field : Symbol, query : Crecto::Repo::Query) + def aggregate(queryable, aggregate_function : Symbol, field : Symbol, query : Crecto::Repo::Query, *, tx : DB::Transaction? = nil) raise InvalidOption.new("Aggregate must be one of :avg, :count, :max, :min:, :sum") unless [:avg, :count, :max, :min, :sum].includes?(aggregate_function) - config.adapter.aggregate(config.get_connection, queryable, aggregate_function, field, query) + config.adapter.aggregate(tx || config.get_connection, queryable, aggregate_function, field, query) end private def check_dependents(changeset, tx : DB::Transaction?) : Nil @@ -604,7 +567,7 @@ module Crecto end end - private def delete_dependents(queryable, destroy_assoc, ids, tx) + private def delete_dependents(queryable, destroy_assoc, ids, tx : DB::Transaction?) through_key = queryable.through_key_for_association(destroy_assoc) if through_key.nil? foreign_key = queryable.foreign_key_for_association(destroy_assoc) @@ -612,7 +575,7 @@ module Crecto association_klass = queryable.klass_for_association(destroy_assoc) return if association_klass.nil? q = Crecto::Repo::Query.where(foreign_key, ids) - delete_all(association_klass, q, tx) + delete_all(association_klass, q, tx: tx) else outer_klass = queryable.klass_for_association(destroy_assoc) # Project join_klass = queryable.klass_for_association(through_key) # UserProject @@ -623,54 +586,54 @@ module Crecto return if join_key.nil? query = Query.select([outer_key.to_s]) query = query.where(join_key, ids) - join_associations = all(join_klass, query) + join_associations = all(join_klass, query, tx: tx) outer_klass_ids = join_associations.map { |ja| outer_klass.foreign_key_value_for_association(through_key, ja) } return if join_associations.empty? - delete_all(join_klass, Query.where(join_key, ids), tx) + delete_all(join_klass, Query.where(join_key, ids), tx: tx) outer_klass_pk_field = outer_klass.primary_key_field_symbol - delete_all(outer_klass, Query.where(outer_klass_pk_field, outer_klass_ids), tx) + delete_all(outer_klass, Query.where(outer_klass_pk_field, outer_klass_ids), tx: tx) end end - private def nullify_dependents(queryable, nullify_assoc, ids, tx) + private def nullify_dependents(queryable, nullify_assoc, ids, tx : DB::Transaction?) through_key = queryable.through_key_for_association(nullify_assoc) if through_key.nil? foreign_key = queryable.foreign_key_for_association(nullify_assoc) association_klass = queryable.klass_for_association(nullify_assoc) return if foreign_key.nil? || association_klass.nil? q = Crecto::Repo::Query.where(foreign_key, ids) - update_all(association_klass, q, {foreign_key => nil}, tx) + update_all(association_klass, q, {foreign_key => nil}, tx: tx) end end - private def add_preloads(results, queryable, preloads) + private def add_preloads(results, queryable, preloads, tx : DB::Transaction?) preloads.each do |preload| case queryable.association_type_for_association(preload[:symbol]) when :has_many - has_many_preload(results, queryable, preload) + has_many_preload(results, queryable, preload, tx) when :has_one - has_one_preload(results, queryable, preload) + has_one_preload(results, queryable, preload, tx) when :belongs_to - belongs_to_preload(results, queryable, preload) + belongs_to_preload(results, queryable, preload, tx) else raise Exception.new("invalid operation passed to add_preloads") end end end - private def has_one_preload(results, queryable, preload) - join_direct(results, queryable, preload, singular: true) + private def has_one_preload(results, queryable, preload, tx : DB::Transaction?) + join_direct(results, queryable, preload, tx, singular: true) end - private def has_many_preload(results, queryable, preload) + private def has_many_preload(results, queryable, preload, tx : DB::Transaction?) if queryable.through_key_for_association(preload[:symbol]) - join_through(results, queryable, preload) + join_through(results, queryable, preload, tx) else - join_direct(results, queryable, preload) + join_direct(results, queryable, preload, tx) end end - private def join_direct(results, queryable, preload, singular = false) + private def join_direct(results, queryable, preload, tx : DB::Transaction?, singular = false) ids = results.map(&.pkey_value.as(PkeyValue)) foreign_key = queryable.foreign_key_for_association(preload[:symbol]) return if foreign_key.nil? @@ -680,7 +643,7 @@ module Crecto end association_klass = queryable.klass_for_association(preload[:symbol]) return if association_klass.nil? - relation_items = all(association_klass, query) + relation_items = all(association_klass, query, tx: tx) relation_items = relation_items.group_by { |t| queryable.foreign_key_value_for_association(preload[:symbol], t) } results.each do |result| @@ -691,7 +654,7 @@ module Crecto end end - private def join_through(results, queryable, preload) + private def join_through(results, queryable, preload, tx : DB::Transaction?) ids = results.map(&.pkey_value.as(PkeyValue)) foreign_key = queryable.foreign_key_for_association(preload[:symbol]) return if foreign_key.nil? @@ -699,7 +662,7 @@ module Crecto # UserProjects association_klass = queryable.klass_for_association(queryable.through_key_for_association(preload[:symbol]).as(Symbol)) return if association_klass.nil? - join_table_items = all(association_klass, join_query) + join_table_items = all(association_klass, join_query, tx: tx) # array of Project id's if join_table_items.empty? @@ -718,7 +681,7 @@ module Crecto association_query = association_query.combine(preload_query) end # Projects - relation_items = all(association_klass, association_query) + relation_items = all(association_klass, association_query, tx: tx) # UserProject grouped by user_id join_table_items = join_table_items.group_by { |t| queryable.foreign_key_value_for_association(queryable.through_key_for_association(preload[:symbol]).as(Symbol), t) } @@ -733,7 +696,7 @@ module Crecto end end - private def belongs_to_preload(results, queryable, preload) + private def belongs_to_preload(results, queryable, preload, tx : DB::Transaction?) ids = results.map { |r| queryable.foreign_key_value_for_association(preload[:symbol], r).as(PkeyValue) } ids.compact! return if ids.empty? @@ -748,7 +711,7 @@ module Crecto end association_klass = queryable.klass_for_association(preload[:symbol]) return if association_klass.nil? - relation_items = all(association_klass, query) + relation_items = all(association_klass, query, tx: tx) unless relation_items.nil? relation_items = relation_items.group_by { |t| t.pkey_value.as(PkeyValue) } @@ -763,28 +726,28 @@ module Crecto end end - private def get_has_many_association(instance, association : Symbol, query : Query) + private def get_has_many_association(instance, association : Symbol, query : Query, tx : DB::Transaction?) queryable = instance.class foreign_key = queryable.foreign_key_for_association(association) return if foreign_key.nil? query = query.where(foreign_key, instance.pkey_value) association_klass = queryable.klass_for_association(association) return if association_klass.nil? - all(association_klass, query) + all(association_klass, query, tx: tx) end - private def get_has_one_association(instance, association : Symbol, query : Query) - many = get_has_many_association(instance, association, query) + private def get_has_one_association(instance, association : Symbol, query : Query, tx : DB::Transaction?) + many = get_has_many_association(instance, association, query, tx) return if many.nil? many.first? end - private def get_belongs_to_association(instance, association : Symbol, query : Query) + private def get_belongs_to_association(instance, association : Symbol, query : Query, tx : DB::Transaction?) queryable = instance.class klass_for_association = queryable.klass_for_association(association) return if klass_for_association.nil? key_for_association = queryable.foreign_key_value_for_association(association, instance) - get(klass_for_association, key_for_association, query) + get(klass_for_association, key_for_association, query, tx: tx) end end end