Skip to content
54 changes: 38 additions & 16 deletions mssql_python/pybind/connection/connection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,13 @@ Connection::Connection(const std::wstring& conn_str, bool use_pool)
}

Connection::~Connection() {
disconnect(); // fallback if user forgets to disconnect
try {
disconnect(); // fallback if user forgets to disconnect
} catch (...) {
Comment thread
subrata-ms marked this conversation as resolved.
Outdated
// Never throw from a destructor — doing so during stack unwinding
// causes std::terminate(). Log and swallow.
LOG_ERROR("Exception suppressed in ~Connection destructor");
}
Comment thread
subrata-ms marked this conversation as resolved.
Outdated
}

// Allocates connection handle
Expand Down Expand Up @@ -99,23 +105,22 @@ void Connection::disconnect() {
// When we free the DBC handle below, the ODBC driver will automatically free
// all child STMT handles. We need to tell the SqlHandle objects about this
// so they don't try to free the handles again during their destruction.

// THREAD-SAFETY: Lock mutex to safely access _childStatementHandles
// This protects against concurrent allocStatementHandle() calls or GC finalizers
{
std::lock_guard<std::mutex> lock(_childHandlesMutex);

// First compact: remove expired weak_ptrs (they're already destroyed)
size_t originalSize = _childStatementHandles.size();
_childStatementHandles.erase(
std::remove_if(_childStatementHandles.begin(), _childStatementHandles.end(),
[](const std::weak_ptr<SqlHandle>& wp) { return wp.expired(); }),
_childStatementHandles.end());

LOG("Compacted child handles: %zu -> %zu (removed %zu expired)",
originalSize, _childStatementHandles.size(),
originalSize - _childStatementHandles.size());


LOG("Compacted child handles: %zu -> %zu (removed %zu expired)", originalSize,
_childStatementHandles.size(), originalSize - _childStatementHandles.size());

LOG("Marking %zu child statement handles as implicitly freed",
_childStatementHandles.size());
for (auto& weakHandle : _childStatementHandles) {
Expand All @@ -124,8 +129,10 @@ void Connection::disconnect() {
// This is guaranteed by allocStatementHandle() which only creates STMT handles
// If this assertion fails, it indicates a serious bug in handle tracking
if (handle->type() != SQL_HANDLE_STMT) {
LOG_ERROR("CRITICAL: Non-STMT handle (type=%d) found in _childStatementHandles. "
"This will cause a handle leak!", handle->type());
LOG_ERROR(
"CRITICAL: Non-STMT handle (type=%d) found in _childStatementHandles. "
"This will cause a handle leak!",
handle->type());
continue; // Skip marking to prevent leak
}
handle->markImplicitlyFreed();
Expand All @@ -136,8 +143,24 @@ void Connection::disconnect() {
} // Release lock before potentially slow SQLDisconnect call

SQLRETURN ret = SQLDisconnect_ptr(_dbcHandle->get());
checkError(ret);
// triggers SQLFreeHandle via destructor, if last owner
if (!SQL_SUCCEEDED(ret)) {
// Log the error but do NOT throw — disconnect must be safe to call
// from destructors, reset() failure paths, and pool cleanup.
// Throwing here during stack unwinding causes std::terminate().
LOG_ERROR("SQLDisconnect failed (ret=%d), forcing handle cleanup", ret);
Comment thread
subrata-ms marked this conversation as resolved.
Outdated

// Best-effort: retrieve and log ODBC diagnostics for debuggability.
// This must not throw, to keep disconnect noexcept-safe.
try {
ErrorInfo err = SQLCheckError_Wrap(SQL_HANDLE_DBC, _dbcHandle, ret);
std::string diagMsg = WideToUTF8(err.ddbcErrorMsg);
LOG_ERROR("SQLDisconnect diagnostics: %s", diagMsg.c_str());
} catch (...) {
// Swallow all exceptions: cleanup paths must not throw.
LOG_ERROR("SQLDisconnect: failed to retrieve ODBC diagnostics");
}
}
// Always free the handle regardless of SQLDisconnect result
_dbcHandle.reset();
} else {
LOG("No connection handle to disconnect");
Expand Down Expand Up @@ -221,7 +244,7 @@ SqlHandlePtr Connection::allocStatementHandle() {
// or GC finalizers running from different threads
{
std::lock_guard<std::mutex> lock(_childHandlesMutex);

// Track this child handle so we can mark it as implicitly freed when connection closes
// Use weak_ptr to avoid circular references and allow normal cleanup
_childStatementHandles.push_back(stmtHandle);
Expand All @@ -237,9 +260,8 @@ SqlHandlePtr Connection::allocStatementHandle() {
[](const std::weak_ptr<SqlHandle>& wp) { return wp.expired(); }),
_childStatementHandles.end());
_allocationsSinceCompaction = 0;
LOG("Periodic compaction: %zu -> %zu handles (removed %zu expired)",
originalSize, _childStatementHandles.size(),
originalSize - _childStatementHandles.size());
LOG("Periodic compaction: %zu -> %zu handles (removed %zu expired)", originalSize,
_childStatementHandles.size(), originalSize - _childStatementHandles.size());
}
} // Release lock

Expand Down
30 changes: 17 additions & 13 deletions mssql_python/pybind/connection/connection_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,24 @@ std::shared_ptr<Connection> ConnectionPool::acquire(const std::wstring& connStr,
auto now = std::chrono::steady_clock::now();
size_t before = _pool.size();

LOG("ConnectionPool::acquire: pool_size=%zu, max_size=%zu, idle_timeout=%d", before,
_max_size, _idle_timeout_secs);

// Phase 1: Remove stale connections, collect for later disconnect
_pool.erase(std::remove_if(_pool.begin(), _pool.end(),
[&](const std::shared_ptr<Connection>& conn) {
auto idle_time =
std::chrono::duration_cast<std::chrono::seconds>(
now - conn->lastUsed())
.count();
if (idle_time > _idle_timeout_secs) {
to_disconnect.push_back(conn);
return true;
}
return false;
}),
_pool.end());
_pool.erase(
std::remove_if(
_pool.begin(), _pool.end(),
[&](const std::shared_ptr<Connection>& conn) {
auto idle_time =
std::chrono::duration_cast<std::chrono::seconds>(now - conn->lastUsed())
.count();
if (idle_time > _idle_timeout_secs) {
to_disconnect.push_back(conn);
return true;
}
return false;
}),
_pool.end());

size_t pruned = before - _pool.size();
_current_size = (_current_size >= pruned) ? (_current_size - pruned) : 0;
Expand Down
4 changes: 1 addition & 3 deletions mssql_python/pybind/ddbc_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4514,10 +4514,8 @@ SQLLEN SQLRowCount_wrap(SqlHandlePtr StatementHandle) {
return rowCount;
}

static std::once_flag pooling_init_flag;
void enable_pooling(int maxSize, int idleTimeout) {
std::call_once(pooling_init_flag,
Comment thread
subrata-ms marked this conversation as resolved.
[&]() { ConnectionPoolManager::getInstance().configure(maxSize, idleTimeout); });
ConnectionPoolManager::getInstance().configure(maxSize, idleTimeout);
}

// Thread-safe decimal separator setting
Expand Down
132 changes: 67 additions & 65 deletions tests/test_009_pooling.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,20 +104,16 @@ def test_connection_pooling_isolation_level_reset(conn_str):
# Set isolation level to SERIALIZABLE (non-default)
conn1.set_attr(mssql_python.SQL_ATTR_TXN_ISOLATION, mssql_python.SQL_TXN_SERIALIZABLE)

# Verify the isolation level was set
# Verify the isolation level was set (use DBCC USEROPTIONS to avoid
# requiring VIEW SERVER PERFORMANCE STATE permission for sys.dm_exec_sessions)
cursor1 = conn1.cursor()
cursor1.execute(
"SELECT CASE transaction_isolation_level "
"WHEN 0 THEN 'Unspecified' "
"WHEN 1 THEN 'ReadUncommitted' "
"WHEN 2 THEN 'ReadCommitted' "
"WHEN 3 THEN 'RepeatableRead' "
"WHEN 4 THEN 'Serializable' "
"WHEN 5 THEN 'Snapshot' END AS isolation_level "
"FROM sys.dm_exec_sessions WHERE session_id = @@SPID"
)
isolation_level_1 = cursor1.fetchone()[0]
assert isolation_level_1 == "Serializable", f"Expected Serializable, got {isolation_level_1}"
cursor1.execute("DBCC USEROPTIONS WITH NO_INFOMSGS")
isolation_level_1 = None
for row in cursor1.fetchall():
if row[0] == "isolation level":
isolation_level_1 = row[1]
break
assert isolation_level_1 == "serializable", f"Expected serializable, got {isolation_level_1}"

# Get SPID for verification of connection reuse
cursor1.execute("SELECT @@SPID")
Expand All @@ -138,24 +134,20 @@ def test_connection_pooling_isolation_level_reset(conn_str):
# Verify connection was reused
assert spid1 == spid2, "Connection was not reused from pool"

# Check if isolation level is reset to default
cursor2.execute(
"SELECT CASE transaction_isolation_level "
"WHEN 0 THEN 'Unspecified' "
"WHEN 1 THEN 'ReadUncommitted' "
"WHEN 2 THEN 'ReadCommitted' "
"WHEN 3 THEN 'RepeatableRead' "
"WHEN 4 THEN 'Serializable' "
"WHEN 5 THEN 'Snapshot' END AS isolation_level "
"FROM sys.dm_exec_sessions WHERE session_id = @@SPID"
)
isolation_level_2 = cursor2.fetchone()[0]
# Check if isolation level is reset to default (use DBCC USEROPTIONS to avoid
# requiring VIEW SERVER PERFORMANCE STATE permission for sys.dm_exec_sessions)
cursor2.execute("DBCC USEROPTIONS WITH NO_INFOMSGS")
isolation_level_2 = None
for row in cursor2.fetchall():
if row[0] == "isolation level":
isolation_level_2 = row[1]
break

# Verify isolation level is reset to default (READ COMMITTED)
# This is the CORRECT behavior for connection pooling - we should reset
# session state to prevent settings from one usage affecting the next
assert isolation_level_2 == "ReadCommitted", (
f"Isolation level was not reset! Expected 'ReadCommitted', got '{isolation_level_2}'. "
assert isolation_level_2 == "read committed", (
f"Isolation level was not reset! Expected 'read committed', got '{isolation_level_2}'. "
f"This indicates session state leaked from the previous connection usage."
)

Expand Down Expand Up @@ -278,82 +270,92 @@ def try_overflow():
c.close()


@pytest.mark.skip("Flaky test - idle timeout behavior needs investigation")
def test_pool_idle_timeout_removes_connections(conn_str):
"""Test that idle_timeout removes connections from the pool after the timeout."""
pooling(max_size=2, idle_timeout=1)
conn1 = connect(conn_str)
spid_list = []
cursor1 = conn1.cursor()
# Use @@SPID to identify the connection without requiring
# VIEW SERVER PERFORMANCE STATE permission for sys.dm_exec_connections.
cursor1.execute("SELECT @@SPID")
spid1 = cursor1.fetchone()[0]
spid_list.append(spid1)
conn1.close()

# Wait for longer than idle_timeout
time.sleep(3)
# Wait well beyond the idle_timeout to account for slow CI and integer-second granularity
time.sleep(5)

# Get a new connection, which should not reuse the previous SPID
# Get a new connection — the idle one should have been evicted during acquire()
conn2 = connect(conn_str)
cursor2 = conn2.cursor()
cursor2.execute("SELECT @@SPID")
spid2 = cursor2.fetchone()[0]
spid_list.append(spid2)
conn2.close()

assert spid1 != spid2, "Idle timeout did not remove connection from pool"
assert spid1 != spid2, "Idle timeout did not remove connection from pool — same SPID reused"


# =============================================================================
# Error Handling and Recovery Tests
# =============================================================================


@pytest.mark.skip(
"Test causes fatal crash - forcibly closing underlying connection leads to undefined behavior"
)
def test_pool_removes_invalid_connections(conn_str):
"""Test that the pool removes connections that become invalid (simulate by closing underlying connection)."""
"""Test that the pool removes connections that become invalid and recovers gracefully.

This test simulates a connection being returned to the pool in a dirty state
(with an open transaction) by calling _conn.close() directly, bypassing the
normal Python close() which does a rollback. The pool's acquire() should detect
the bad connection during reset(), discard it, and create a fresh one.
"""
pooling(max_size=1, idle_timeout=30)
conn = connect(conn_str)
cursor = conn.cursor()
cursor.execute("SELECT 1")
# Simulate invalidation by forcibly closing the connection at the driver level
try:
# Try to access a private attribute or method to forcibly close the underlying connection
# This is implementation-specific; if not possible, skip
if hasattr(conn, "_conn") and hasattr(conn._conn, "close"):
conn._conn.close()
else:
pytest.skip("Cannot forcibly close underlying connection for this driver")
except Exception:
pass
# Safely close the connection, ignoring errors due to forced invalidation
cursor.fetchone()

# Record the SPID of the original connection (avoids requiring
# VIEW SERVER PERFORMANCE STATE permission for sys.dm_exec_connections)
cursor.execute("SELECT @@SPID")
original_spid = cursor.fetchone()[0]

# Force-return the connection to the pool WITHOUT rollback.
# This leaves the pooled connection in a dirty state (open implicit transaction)
# which will cause reset() to fail on next acquire().
conn._conn.close()

# Python close() will fail since the underlying handle is already gone
try:
conn.close()
except RuntimeError as e:
if "not initialized" not in str(e):
raise
# Now, get a new connection from the pool and ensure it works
except RuntimeError:
pass

# Now get a new connection the pool should discard the dirty one and create fresh
new_conn = connect(conn_str)
new_cursor = new_conn.cursor()
try:
new_cursor.execute("SELECT 1")
result = new_cursor.fetchone()
assert result is not None and result[0] == 1, "Pool did not remove invalid connection"
finally:
new_conn.close()
new_cursor.execute("SELECT 1")
result = new_cursor.fetchone()
assert result is not None and result[0] == 1, "Pool did not recover from invalid connection"

# Verify it's a different physical connection
new_cursor.execute("SELECT @@SPID")
new_spid = new_cursor.fetchone()[0]
assert (
original_spid != new_spid
), "Expected a new physical connection after pool discarded the dirty one"
Comment thread
subrata-ms marked this conversation as resolved.

new_conn.close()


def test_pool_recovery_after_failed_connection(conn_str):
"""Test that the pool recovers after a failed connection attempt."""
pooling(max_size=1, idle_timeout=30)
# First, try to connect with a bad password (should fail)
if "Pwd=" in conn_str:
bad_conn_str = conn_str.replace("Pwd=", "Pwd=wrongpassword")
elif "Password=" in conn_str:
bad_conn_str = conn_str.replace("Password=", "Password=wrongpassword")
else:
import re

# Replace the value of the first Pwd/Password key-value pair with "wrongpassword"
pattern = re.compile(r"(?i)(Pwd|Password\s*=\s*)([^;]*)")
Comment thread
subrata-ms marked this conversation as resolved.
Outdated
bad_conn_str, num_subs = pattern.subn(lambda m: m.group(1) + "wrongpassword", conn_str, count=1)
if num_subs == 0:
pytest.skip("No password found in connection string to modify")
with pytest.raises(Exception):
connect(bad_conn_str)
Expand Down
Loading