diff --git a/ext/mysql2/client.c b/ext/mysql2/client.c index 25a35029..c825f5e2 100644 --- a/ext/mysql2/client.c +++ b/ext/mysql2/client.c @@ -19,7 +19,7 @@ extern VALUE mMysql2, cMysql2Error, cMysql2TimeoutError; static VALUE sym_id, sym_version, sym_header_version, sym_async, sym_symbolize_keys, sym_as, sym_array, sym_stream; static VALUE sym_no_good_index_used, sym_no_index_used, sym_query_was_slow; static ID intern_brackets, intern_merge, intern_merge_bang, intern_new_with_args, - intern_current_query_options, intern_read_timeout; + intern_current_query_options, intern_read_timeout, intern_values; #define REQUIRE_INITIALIZED(wrapper) \ if (!wrapper->initialized) { \ @@ -221,6 +221,7 @@ static void rb_mysql_client_mark(void * wrapper) { if (w) { rb_gc_mark_movable(w->encoding); rb_gc_mark_movable(w->active_fiber); + rb_gc_mark_movable(w->prepared_statements); } } @@ -353,6 +354,14 @@ static VALUE invalidate_fd(int clientfd) } #endif /* _WIN32 */ +static int decr_mysql2_stmt_hash(VALUE key, VALUE val, VALUE arg) +{ + mysql_client_wrapper *wrapper = (mysql_client_wrapper *)arg; + VALUE stmt = rb_ivar_get(wrapper->prepared_statements, key); + // rb_funcall(stmt, rb_intern("close"), 0); + return 0; +} + static void *nogvl_close(void *ptr) { mysql_client_wrapper *wrapper = ptr; @@ -388,6 +397,8 @@ void decr_mysql2_client(mysql_client_wrapper *wrapper) } #endif + // rb_hash_foreach(wrapper->prepared_statements, decr_mysql2_stmt_hash, (VALUE)wrapper); + nogvl_close(wrapper); xfree(wrapper->client); xfree(wrapper); @@ -404,6 +415,7 @@ static VALUE allocate(VALUE klass) { #endif wrapper->encoding = Qnil; wrapper->active_fiber = Qnil; + wrapper->prepared_statements = rb_hash_new(); wrapper->automatic_close = 1; wrapper->server_version = 0; wrapper->reconnect_enabled = 0; @@ -1535,10 +1547,25 @@ static VALUE initialize_ext(VALUE self) { * Create a new prepared statement. */ static VALUE rb_mysql_client_prepare_statement(VALUE self, VALUE sql) { + VALUE stmt; GET_CLIENT(self); REQUIRE_CONNECTED(wrapper); - return rb_mysql_stmt_new(self, sql); + stmt = rb_mysql_stmt_new(self, sql); + + return stmt; +} + +/* call-seq: + * client.prepared_statements + * + * Returns an array of prepared statement objects. + */ +static VALUE rb_mysql_client_prepared_statements_read(VALUE self) { + unsigned long retVal; + GET_CLIENT(self); + + return rb_funcall(wrapper->prepared_statements, intern_values, 0); } void init_mysql2_client() { @@ -1588,6 +1615,7 @@ void init_mysql2_client() { rb_define_method(cMysql2Client, "last_id", rb_mysql_client_last_id, 0); rb_define_method(cMysql2Client, "affected_rows", rb_mysql_client_affected_rows, 0); rb_define_method(cMysql2Client, "prepare", rb_mysql_client_prepare_statement, 1); + rb_define_method(cMysql2Client, "prepared_statements", rb_mysql_client_prepared_statements_read, 0); rb_define_method(cMysql2Client, "thread_id", rb_mysql_client_thread_id, 0); rb_define_method(cMysql2Client, "ping", rb_mysql_client_ping, 0); rb_define_method(cMysql2Client, "select_db", rb_mysql_client_select_db, 1); @@ -1641,6 +1669,7 @@ void init_mysql2_client() { intern_new_with_args = rb_intern("new_with_args"); intern_current_query_options = rb_intern("@current_query_options"); intern_read_timeout = rb_intern("@read_timeout"); + intern_values = rb_intern("values"); #ifdef CLIENT_LONG_PASSWORD rb_const_set(cMysql2Client, rb_intern("LONG_PASSWORD"), diff --git a/ext/mysql2/client.h b/ext/mysql2/client.h index 6a8227bd..67bd35a5 100644 --- a/ext/mysql2/client.h +++ b/ext/mysql2/client.h @@ -4,6 +4,7 @@ typedef struct { VALUE encoding; VALUE active_fiber; /* rb_fiber_current() or Qnil */ + VALUE prepared_statements; long server_version; int reconnect_enabled; unsigned int connect_timeout; diff --git a/ext/mysql2/statement.c b/ext/mysql2/statement.c index fa3b660c..c71224ff 100644 --- a/ext/mysql2/statement.c +++ b/ext/mysql2/statement.c @@ -75,7 +75,15 @@ void decr_mysql2_stmt(mysql_stmt_wrapper *stmt_wrapper) { stmt_wrapper->refcount--; if (stmt_wrapper->refcount == 0) { + // If the GC get to client first it will be nil, and this cleanup won't matter + if (stmt_wrapper->client_wrapper && stmt_wrapper->client_wrapper->refcount > 0) { + // Remove the reference to this statement handle from the Client object. + rb_hash_delete(stmt_wrapper->client_wrapper->prepared_statements, + ULL2NUM((unsigned long long)stmt_wrapper)); + } + nogvl_stmt_close(stmt_wrapper); + decr_mysql2_client(stmt_wrapper->client_wrapper); xfree(stmt_wrapper); } } @@ -140,10 +148,18 @@ VALUE rb_mysql_stmt_new(VALUE rb_client, VALUE sql) { rb_stmt = Data_Make_Struct(cMysql2Statement, mysql_stmt_wrapper, rb_mysql_stmt_mark, rb_mysql_stmt_free, stmt_wrapper); #endif { - stmt_wrapper->client = rb_client; stmt_wrapper->refcount = 1; stmt_wrapper->closed = 0; stmt_wrapper->stmt = NULL; + + /* Keep a handle to the Client to ensure it doesn't get garbage collected first */ + stmt_wrapper->client = rb_client; + if (rb_client != Qnil) { + stmt_wrapper->client_wrapper = DATA_PTR(rb_client); + stmt_wrapper->client_wrapper->refcount++; + } else { + stmt_wrapper->client_wrapper = NULL; + } } // instantiate stmt @@ -178,6 +194,18 @@ VALUE rb_mysql_stmt_new(VALUE rb_client, VALUE sql) { } } + // Stash a reference to this statement handle into the Client to prevent + // premature garbage collection. + // + // A statement can either be free explicitly or when the client object is + // torn down. Freeing a statement handle at any other time causes protocol + // traffic that might happen while the connection state is set for another + // operation. + { + GET_CLIENT(rb_client); + rb_hash_aset(wrapper->prepared_statements, ULL2NUM((unsigned long long)stmt_wrapper), rb_stmt); + } + return rb_stmt; } @@ -609,7 +637,9 @@ static VALUE rb_mysql_stmt_close(VALUE self) { RAW_GET_STATEMENT(self); if (!stmt_wrapper->closed) { + GET_CLIENT(stmt_wrapper->client); stmt_wrapper->closed = 1; + rb_hash_delete(wrapper->prepared_statements, ULL2NUM((unsigned long long)stmt_wrapper)); rb_thread_call_without_gvl(nogvl_stmt_close, stmt_wrapper, RUBY_UBF_IO, 0); } diff --git a/ext/mysql2/statement.h b/ext/mysql2/statement.h index e4851067..78a2ef4d 100644 --- a/ext/mysql2/statement.h +++ b/ext/mysql2/statement.h @@ -2,10 +2,11 @@ #define MYSQL2_STATEMENT_H typedef struct { + int closed; + int refcount; VALUE client; + mysql_client_wrapper *client_wrapper; MYSQL_STMT *stmt; - int refcount; - int closed; } mysql_stmt_wrapper; void init_mysql2_statement(void);