diff --git a/README.md b/README.md index b603ba72..4698c0fe 100644 --- a/README.md +++ b/README.md @@ -113,7 +113,29 @@ If you are looking for the **plain necessities**, you should use the [ejabberd][ #### Transaction -This driver currently does not support transactions. +This driver supports transaction in this way: + +Usage: + +```erlang +Transaction = fun(Q) -> + R1 = Q(<<"SELECT * from some_table">>, []), + R2 = Q(<<"SELECT * from some_other_table">>, []), + R3 = Q(<"INSERT INTO ...">>, []), + .... + + {ok, SomeResult} +end, + +{ok, Result} = emysql:transaction(Pool, Transaction, Timeout). + +The transaction fun receive one argument, wich is also a function. +That is used to perform queries. Only textual queries allowed +(no prepared statments with arguments). + +The result is the return value of the transaction function. +If the function crashes, transaction is aborted. That is the only +way to abort the transaction: throwing an error. For **mnesia-style transactions**, one of the multiple '[erlang-mysql-driver][22]s' may suite you best. There are [quite many][16] branches of it out there, and they are based on the same project as the ejabberd driver. To learn more about out the differences between the drivers, see the [mysql driver history][History]. diff --git a/src/emysql.erl b/src/emysql.erl index a373de9a..a0112710 100644 --- a/src/emysql.erl +++ b/src/emysql.erl @@ -110,7 +110,7 @@ %% Used to interact with the database. -export([ prepare/2, - execute/2, execute/3, execute/4, execute/5, + execute/2, execute/3, execute/4, execute/5, transaction/3, default_timeout/0 ]). @@ -533,6 +533,11 @@ execute(PoolId, Query, Timeout) when (is_list(Query) orelse is_binary(Query)) an execute(PoolId, StmtName, Timeout) when is_atom(StmtName), (is_integer(Timeout) orelse Timeout == infinity) -> execute(PoolId, StmtName, [], Timeout). +transaction(PoolId, Fun, Timeout) when is_function(Fun) -> + Connection = emysql_conn_mgr:wait_for_connection(PoolId), + monitor_work(Connection, Timeout, {transaction, Connection, Fun}). + + %% @spec execute(PoolId, Query|StmtName, Args, Timeout) -> Result | [Result] %% PoolId = atom() %% Query = binary() | string() @@ -759,7 +764,14 @@ monitor_work(Connection0, Timeout, Args) when is_record(Connection0, emysql_conn {Pid, Mref} = spawn_monitor( fun() -> put(query_arguments, Args), - Parent ! {self(), apply(fun emysql_conn:execute/3, Args)} + %Parent ! {self(), apply(fun emysql_conn:execute/3, Args)} + case Args of + {transaction, _MaybeOtherConnection, Fun} -> + Result = execute_transaction(Connection, Fun), + Parent ! {self(), Result}; + _ -> + Parent ! {self(), apply(fun emysql_conn:execute/3, Args)} + end end), receive {'DOWN', Mref, process, Pid, tcp_connection_closed} -> @@ -795,3 +807,43 @@ monitor_work(Connection0, Timeout, Args) when is_record(Connection0, emysql_conn emysql_conn:reset_connection(emysql_conn_mgr:pools(), Connection, pass), exit(mysql_timeout) end. + +%% @spec execute_transaction(Connection, Fun) -> Result +%% Connection = pid() +%% Fun = fun() + +%% Result = ok_packet() | result_packet() | error_packet() +%% +%% @doc Execute a transaction function +%% +%% @private +%% @end doc: jfjalburquerque july 15 +%% +execute_transaction(Connection, Fun) -> + Result = try + case emysql_conn:execute(Connection, <<"START TRANSACTION;">>, []) of + #ok_packet{} -> + %% execute transaction function + Fun(fun(Query, Params) -> emysql_conn:execute(Connection, Query, Params) end); + MysqlError -> MysqlError + end + catch + _:Error -> + %% in case of error in the transaction function + %% we send #error_packet to do rollback + #error_packet{msg = Error} + end, + case Result of + #error_packet{} -> + %% in case of #error_packet we execute rollback + case emysql_conn:execute(Connection, <<"ROLLBACK;">>, []) of + #ok_packet{} -> Result#error_packet{code = 1402, status = <<"XA100">>}; + RollbackError -> RollbackError + end; + _ -> + %% in case of #ok_packet we execute commit + case emysql_conn:execute(Connection, <<"COMMIT;">>, []) of + #ok_packet{} -> Result; + CommitError -> CommitError + end + end. \ No newline at end of file diff --git a/src/emysql_conn.erl b/src/emysql_conn.erl index 01a5c74b..564b86f8 100644 --- a/src/emysql_conn.erl +++ b/src/emysql_conn.erl @@ -427,7 +427,7 @@ encode(Val, binary) when is_integer(Val) -> list_to_binary(integer_to_list(Val)); encode(Val, list) when is_float(Val) -> [Res] = io_lib:format("~w", [Val]), - Res; + list_to_binary(Res); encode(Val, binary) when is_float(Val) -> iolist_to_binary(io_lib:format("~w", [Val])); encode({datetime, Val}, ReturnType) ->