Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add references from Client and Statement to each other #1048

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 31 additions & 2 deletions ext/mysql2/client.c
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ extern VALUE mMysql2, cMysql2Error, cMysql2TimeoutError;
static VALUE sym_id, sym_version, sym_header_version, sym_async, sym_symbolize_keys, sym_as, sym_array, sym_stream;
static VALUE sym_no_good_index_used, sym_no_index_used, sym_query_was_slow;
static ID intern_brackets, intern_merge, intern_merge_bang, intern_new_with_args,
intern_current_query_options, intern_read_timeout;
intern_current_query_options, intern_read_timeout, intern_values;

#define REQUIRE_INITIALIZED(wrapper) \
if (!wrapper->initialized) { \
Expand Down Expand Up @@ -166,6 +166,7 @@ static void rb_mysql_client_mark(void * wrapper) {
if (w) {
rb_gc_mark(w->encoding);
rb_gc_mark(w->active_thread);
rb_gc_mark(w->prepared_statements);
}
}

Expand Down Expand Up @@ -262,6 +263,14 @@ static VALUE invalidate_fd(int clientfd)
}
#endif /* _WIN32 */

static int decr_mysql2_stmt_hash(VALUE key, VALUE val, VALUE arg)
{
mysql_client_wrapper *wrapper = (mysql_client_wrapper *)arg;
VALUE stmt = rb_ivar_get(wrapper->prepared_statements, key);
// rb_funcall(stmt, rb_intern("close"), 0);
return 0;
}

static void *nogvl_close(void *ptr) {
mysql_client_wrapper *wrapper = ptr;

Expand Down Expand Up @@ -303,6 +312,8 @@ void decr_mysql2_client(mysql_client_wrapper *wrapper)
}
#endif

// rb_hash_foreach(wrapper->prepared_statements, decr_mysql2_stmt_hash, (VALUE)wrapper);

nogvl_close(wrapper);
xfree(wrapper->client);
xfree(wrapper);
Expand All @@ -315,6 +326,7 @@ static VALUE allocate(VALUE klass) {
obj = Data_Make_Struct(klass, mysql_client_wrapper, rb_mysql_client_mark, rb_mysql_client_free, wrapper);
wrapper->encoding = Qnil;
wrapper->active_thread = Qnil;
wrapper->prepared_statements = rb_hash_new();
wrapper->automatic_close = 1;
wrapper->server_version = 0;
wrapper->reconnect_enabled = 0;
Expand Down Expand Up @@ -1371,10 +1383,25 @@ static VALUE initialize_ext(VALUE self) {
* Create a new prepared statement.
*/
static VALUE rb_mysql_client_prepare_statement(VALUE self, VALUE sql) {
VALUE stmt;
GET_CLIENT(self);
REQUIRE_CONNECTED(wrapper);

return rb_mysql_stmt_new(self, sql);
stmt = rb_mysql_stmt_new(self, sql);

return stmt;
}

/* call-seq:
* client.prepared_statements
*
* Returns an array of prepared statement objects.
*/
static VALUE rb_mysql_client_prepared_statements_read(VALUE self) {
unsigned long retVal;
GET_CLIENT(self);

return rb_funcall(wrapper->prepared_statements, intern_values, 0);
}

void init_mysql2_client() {
Expand Down Expand Up @@ -1423,6 +1450,7 @@ void init_mysql2_client() {
rb_define_method(cMysql2Client, "last_id", rb_mysql_client_last_id, 0);
rb_define_method(cMysql2Client, "affected_rows", rb_mysql_client_affected_rows, 0);
rb_define_method(cMysql2Client, "prepare", rb_mysql_client_prepare_statement, 1);
rb_define_method(cMysql2Client, "prepared_statements", rb_mysql_client_prepared_statements_read, 0);
rb_define_method(cMysql2Client, "thread_id", rb_mysql_client_thread_id, 0);
rb_define_method(cMysql2Client, "ping", rb_mysql_client_ping, 0);
rb_define_method(cMysql2Client, "select_db", rb_mysql_client_select_db, 1);
Expand Down Expand Up @@ -1474,6 +1502,7 @@ void init_mysql2_client() {
intern_new_with_args = rb_intern("new_with_args");
intern_current_query_options = rb_intern("@current_query_options");
intern_read_timeout = rb_intern("@read_timeout");
intern_values = rb_intern("values");

#ifdef CLIENT_LONG_PASSWORD
rb_const_set(cMysql2Client, rb_intern("LONG_PASSWORD"),
Expand Down
1 change: 1 addition & 0 deletions ext/mysql2/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
typedef struct {
VALUE encoding;
VALUE active_thread; /* rb_thread_current() or Qnil */
VALUE prepared_statements;
long server_version;
int reconnect_enabled;
unsigned int connect_timeout;
Expand Down
32 changes: 31 additions & 1 deletion ext/mysql2/statement.c
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,15 @@ void decr_mysql2_stmt(mysql_stmt_wrapper *stmt_wrapper) {
stmt_wrapper->refcount--;

if (stmt_wrapper->refcount == 0) {
// If the GC get to client first it will be nil, and this cleanup won't matter
if (stmt_wrapper->client_wrapper && stmt_wrapper->client_wrapper->refcount > 0) {
// Remove the reference to this statement handle from the Client object.
rb_hash_delete(stmt_wrapper->client_wrapper->prepared_statements,
ULL2NUM((unsigned long long)stmt_wrapper));
}

nogvl_stmt_close(stmt_wrapper);
decr_mysql2_client(stmt_wrapper->client_wrapper);
xfree(stmt_wrapper);
}
}
Expand Down Expand Up @@ -98,10 +106,18 @@ VALUE rb_mysql_stmt_new(VALUE rb_client, VALUE sql) {

rb_stmt = Data_Make_Struct(cMysql2Statement, mysql_stmt_wrapper, rb_mysql_stmt_mark, rb_mysql_stmt_free, stmt_wrapper);
{
stmt_wrapper->client = rb_client;
stmt_wrapper->refcount = 1;
stmt_wrapper->closed = 0;
stmt_wrapper->stmt = NULL;

/* Keep a handle to the Client to ensure it doesn't get garbage collected first */
stmt_wrapper->client = rb_client;
if (rb_client != Qnil) {
stmt_wrapper->client_wrapper = DATA_PTR(rb_client);
stmt_wrapper->client_wrapper->refcount++;
} else {
stmt_wrapper->client_wrapper = NULL;
}
}

// instantiate stmt
Expand Down Expand Up @@ -136,6 +152,18 @@ VALUE rb_mysql_stmt_new(VALUE rb_client, VALUE sql) {
}
}

// Stash a reference to this statement handle into the Client to prevent
// premature garbage collection.
//
// A statement can either be free explicitly or when the client object is
// torn down. Freeing a statement handle at any other time causes protocol
// traffic that might happen while the connection state is set for another
// operation.
{
GET_CLIENT(rb_client);
rb_hash_aset(wrapper->prepared_statements, ULL2NUM((unsigned long long)stmt_wrapper), rb_stmt);
}

return rb_stmt;
}

Expand Down Expand Up @@ -565,7 +593,9 @@ static VALUE rb_mysql_stmt_affected_rows(VALUE self) {
*/
static VALUE rb_mysql_stmt_close(VALUE self) {
GET_STATEMENT(self);
GET_CLIENT(stmt_wrapper->client);
stmt_wrapper->closed = 1;
rb_hash_delete(wrapper->prepared_statements, ULL2NUM((unsigned long long)stmt_wrapper));
rb_thread_call_without_gvl(nogvl_stmt_close, stmt_wrapper, RUBY_UBF_IO, 0);
return Qnil;
}
Expand Down
5 changes: 3 additions & 2 deletions ext/mysql2/statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
#define MYSQL2_STATEMENT_H

typedef struct {
int closed;
int refcount;
VALUE client;
mysql_client_wrapper *client_wrapper;
MYSQL_STMT *stmt;
int refcount;
int closed;
} mysql_stmt_wrapper;

void init_mysql2_statement(void);
Expand Down