diff --git a/ojp-jdbc-driver/src/main/java/org/openjproxy/jdbc/xa/OjpXALogicalConnection.java b/ojp-jdbc-driver/src/main/java/org/openjproxy/jdbc/xa/OjpXALogicalConnection.java index fce53d984..f624cdd37 100644 --- a/ojp-jdbc-driver/src/main/java/org/openjproxy/jdbc/xa/OjpXALogicalConnection.java +++ b/ojp-jdbc-driver/src/main/java/org/openjproxy/jdbc/xa/OjpXALogicalConnection.java @@ -19,13 +19,11 @@ class OjpXALogicalConnection extends Connection { private final OjpXAConnection xaConnection; - private boolean closed = false; OjpXALogicalConnection(OjpXAConnection xaConnection, SessionInfo sessionInfo, String url, String boundServerAddress) throws SQLException { // Pass the statementService and dbName to the parent Connection class super(sessionInfo, xaConnection.getStatementService(), DatabaseUtils.resolveDbName(url)); this.xaConnection = xaConnection; - // Register with ConnectionTracker if using multinode - this ensures XAConnectionRedistributor // can find and invalidate this connection when the bound server fails if (xaConnection.getStatementService() instanceof MultinodeStatementService) { @@ -42,18 +40,16 @@ class OjpXALogicalConnection extends Connection { } } } - log.debug("Created logical connection using XA session: {}", sessionInfo.getSessionUUID()); } - /** * Find the ServerEndpoint matching the bound server address. */ private ServerEndpoint findServerEndpoint(MultinodeConnectionManager connectionManager, String serverAddress) { try { log.debug("Finding server endpoint for address: {}", serverAddress); - ServerEndpoint serverEndpoint = connectionManager.getServerEndpoints().stream().filter(se -> - se.getAddress().equalsIgnoreCase(serverAddress) + ServerEndpoint serverEndpoint = connectionManager.getServerEndpoints().stream().filter(se -> + se.getAddress().equalsIgnoreCase(serverAddress) ).findFirst().orElse(null); log.debug("Server endpoint for address {} found {}", serverAddress, serverEndpoint != null ? "successfully" : "not found"); return serverEndpoint; @@ -66,16 +62,7 @@ private ServerEndpoint findServerEndpoint(MultinodeConnectionManager connectionM @Override public void close() throws SQLException { log.debug("Logical connection close called"); - if (!closed) { - closed = true; - // Don't close the underlying XA connection - just mark this logical connection as closed - // The actual XA connection will be closed when XAConnection.close() is called - } - } - - @Override - public boolean isClosed() throws SQLException { - return closed; + super.close(); } @Override @@ -103,4 +90,14 @@ public boolean getAutoCommit() throws SQLException { // XA connections are always non-auto-commit return false; } + + @Override + public org.openjproxy.jdbc.Savepoint setSavepoint() throws SQLException { + throw new java.sql.SQLFeatureNotSupportedException("Savepoints are not supported in XA transactions."); + } + + @Override + public org.openjproxy.jdbc.Savepoint setSavepoint(String name) throws SQLException { + throw new java.sql.SQLFeatureNotSupportedException("Savepoints are not supported in XA transactions."); + } } diff --git a/ojp-jdbc-driver/src/test/java/openjproxy/jdbc/MariaDBXAIntegrationTest.java b/ojp-jdbc-driver/src/test/java/openjproxy/jdbc/MariaDBXAIntegrationTest.java new file mode 100644 index 000000000..5ba13327b --- /dev/null +++ b/ojp-jdbc-driver/src/test/java/openjproxy/jdbc/MariaDBXAIntegrationTest.java @@ -0,0 +1,200 @@ +package openjproxy.jdbc; + +import lombok.extern.slf4j.Slf4j; +import openjproxy.jdbc.testutil.TestDBUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvFileSource; +import org.openjproxy.jdbc.xa.OjpXADataSource; + +import javax.sql.XAConnection; +import javax.transaction.xa.XAResource; +import javax.transaction.xa.Xid; +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; +import java.sql.SQLException; +import java.sql.Statement; + +import static org.junit.jupiter.api.Assertions.*; +import static org.junit.jupiter.api.Assumptions.assumeFalse; + +/** + * Integration tests for XA transaction support with MariaDB. + * These tests require: + * 1. A running OJP server (localhost:1059) + * 2. A MariaDB database with XA support + */ +@Slf4j +class MariaDBXAIntegrationTest { + + private static boolean isTestEnabled; + private XAConnection xaConnection; + private Connection connection; + + @BeforeAll + static void checkTestConfiguration() { + isTestEnabled = Boolean.parseBoolean(System.getProperty("enableMariaDBTests", "false")); + } + + private void setUp(String url, String user, String password) throws SQLException { + assumeFalse(!isTestEnabled, "MariaDB XA tests are disabled. Enable with -DenableMariaDBTests=true"); + + // Create XA DataSource + OjpXADataSource xaDataSource = new OjpXADataSource(); + xaDataSource.setUrl(url); + xaDataSource.setUser(user); + xaDataSource.setPassword(password); + + // Get XA Connection + xaConnection = xaDataSource.getXAConnection(user, password); + connection = xaConnection.getConnection(); + } + + @AfterEach + void tearDown() { + TestDBUtils.closeQuietly(connection); + if (xaConnection != null) { + try { + xaConnection.close(); + } catch (Exception e) { + log.warn("Error closing XA connection: {}", e.getMessage()); + } + } + } + + @ParameterizedTest + @CsvFileSource(resources = "/mariadb_xa_connection.csv") + void testXAConnectionBasics(String driverClass, String url, String user, String password) throws Exception { + setUp(url, user, password); + + assertNotNull(xaConnection, "XA connection should be created"); + assertNotNull(connection, "Logical connection should be created"); + assertFalse(connection.isClosed(), "Connection should not be closed"); + + // Get XA Resource + XAResource xaResource = xaConnection.getXAResource(); + assertNotNull(xaResource, "XA resource should not be null"); + + // Verify connection is not auto-commit + assertFalse(connection.getAutoCommit(), "XA connection should not be auto-commit"); + } + + @ParameterizedTest + @CsvFileSource(resources = "/mariadb_xa_connection.csv") + void testXATransactionWithCRUD(String driverClass, String url, String user, String password) throws Exception { + setUp(url, user, password); + + XAResource xaResource = xaConnection.getXAResource(); + String tableName = "mariadb_xa_crud_test"; + + // Clean up and create table + try (Connection setupConn = TestDBUtils.createConnection(url, user, password, false).getConnection()) { + try (Statement stmt = setupConn.createStatement()) { + stmt.execute("DROP TABLE IF EXISTS " + tableName); + stmt.execute("CREATE TABLE " + tableName + " (id INT PRIMARY KEY, val VARCHAR(50))"); + } + } + + // 1. Start XA Transaction + Xid xid = createXid(1); + xaResource.start(xid, XAResource.TMNOFLAGS); + + // 2. Execute DML + try (PreparedStatement ps = connection.prepareStatement("INSERT INTO " + tableName + " VALUES (?, ?)")) { + ps.setInt(1, 1); + ps.setString(2, "XA_VAL_1"); + ps.executeUpdate(); + } + + // 3. End and Commit + xaResource.end(xid, XAResource.TMSUCCESS); + int prepareResult = xaResource.prepare(xid); + assertEquals(XAResource.XA_OK, prepareResult); + xaResource.commit(xid, false); + + // 4. Verify data + try (Connection verifyConn = TestDBUtils.createConnection(url, user, password, false).getConnection()) { + try (PreparedStatement ps = verifyConn.prepareStatement("SELECT val FROM " + tableName + " WHERE id = 1")) { + ResultSet rs = ps.executeQuery(); + assertTrue(rs.next()); + assertEquals("XA_VAL_1", rs.getString(1)); + } + } + } + + @ParameterizedTest + @CsvFileSource(resources = "/mariadb_xa_connection.csv") + void testXARollback(String driverClass, String url, String user, String password) throws Exception { + setUp(url, user, password); + + XAResource xaResource = xaConnection.getXAResource(); + String tableName = "mariadb_xa_rollback_test"; + + // Clean up and create table + try (Connection setupConn = TestDBUtils.createConnection(url, user, password, false).getConnection()) { + try (Statement stmt = setupConn.createStatement()) { + stmt.execute("DROP TABLE IF EXISTS " + tableName); + stmt.execute("CREATE TABLE " + tableName + " (id INT PRIMARY KEY, val VARCHAR(50))"); + } + } + + // 1. Start XA Transaction + Xid xid = createXid(2); + xaResource.start(xid, XAResource.TMNOFLAGS); + + // 2. Execute DML + try (PreparedStatement ps = connection.prepareStatement("INSERT INTO " + tableName + " VALUES (?, ?)")) { + ps.setInt(1, 1); + ps.setString(2, "SHOULD_ROLLBACK"); + ps.executeUpdate(); + } + + // 3. End and Rollback + xaResource.end(xid, XAResource.TMSUCCESS); + xaResource.rollback(xid); + + // 4. Verify absence of data + try (Connection verifyConn = TestDBUtils.createConnection(url, user, password, false).getConnection()) { + try (PreparedStatement ps = verifyConn.prepareStatement("SELECT COUNT(*) FROM " + tableName)) { + ResultSet rs = ps.executeQuery(); + assertTrue(rs.next()); + assertEquals(0, rs.getInt(1)); + } + } + } + + private Xid createXid(int id) { + byte[] gtrid = new byte[] { (byte) (0x10 + id), 0x01, 0x02 }; + byte[] bqual = new byte[] { (byte) (0x20 + id), 0x03, 0x04 }; + return new TestXid(0x1234, gtrid, bqual); + } + + private static class TestXid implements Xid { + private final int formatId; + private final byte[] globalTransactionId; + private final byte[] branchQualifier; + + TestXid(int formatId, byte[] globalTransactionId, byte[] branchQualifier) { + this.formatId = formatId; + this.globalTransactionId = globalTransactionId; + this.branchQualifier = branchQualifier; + } + + @Override + public int getFormatId() { + return formatId; + } + + @Override + public byte[] getGlobalTransactionId() { + return globalTransactionId; + } + + @Override + public byte[] getBranchQualifier() { + return branchQualifier; + } + } +} diff --git a/ojp-jdbc-driver/src/test/java/openjproxy/jdbc/MySQLMariaDBConnectionExtensiveTests.java b/ojp-jdbc-driver/src/test/java/openjproxy/jdbc/MySQLMariaDBConnectionExtensiveTests.java index 9fde178ef..d8796c9ec 100644 --- a/ojp-jdbc-driver/src/test/java/openjproxy/jdbc/MySQLMariaDBConnectionExtensiveTests.java +++ b/ojp-jdbc-driver/src/test/java/openjproxy/jdbc/MySQLMariaDBConnectionExtensiveTests.java @@ -2,6 +2,7 @@ import lombok.SneakyThrows; import openjproxy.jdbc.testutil.TestDBUtils; +import openjproxy.jdbc.testutil.TestDBUtils.ConnectionResult; import org.junit.Assert; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeAll; @@ -11,7 +12,6 @@ import java.sql.CallableStatement; import java.sql.Connection; import java.sql.DatabaseMetaData; -import java.sql.DriverManager; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLClientInfoException; @@ -25,11 +25,12 @@ import static org.junit.jupiter.api.Assertions.*; import static org.junit.jupiter.api.Assumptions.assumeFalse; - class MySQLMariaDBConnectionExtensiveTests { +class MySQLMariaDBConnectionExtensiveTests { private static boolean isMySQLTestEnabled; private static boolean isMariaDBTestEnabled; private Connection connection; + private ConnectionResult connectionResult; @BeforeAll static void checkTestConfiguration() { @@ -38,21 +39,25 @@ static void checkTestConfiguration() { } @SneakyThrows - public void setUp(String driverClass, String url, String user, String password) throws SQLException { - assumeFalse(!isMySQLTestEnabled, "MySQL tests are not enabled"); - assumeFalse(!isMariaDBTestEnabled, "MariaDB tests are not enabled"); - connection = DriverManager.getConnection(url, user, password); + public void setUp(String driverClass, String url, String user, String password, boolean isXA) throws SQLException { + assumeFalse(!isMySQLTestEnabled && url.toLowerCase().contains("mysql"), "MySQL tests are not enabled"); + assumeFalse(!isMariaDBTestEnabled && url.toLowerCase().contains("mariadb"), "MariaDB tests are not enabled"); + connectionResult = TestDBUtils.createConnection(url, user, password, isXA); + connection = connectionResult.getConnection(); } @AfterEach void tearDown() { - TestDBUtils.closeQuietly(connection); + if (connectionResult != null) { + connectionResult.close(); + } } @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testCreateStatement(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testCreateStatement(String driverClass, String url, String user, String password, boolean isXA) + throws SQLException { + setUp(driverClass, url, user, password, isXA); Statement statement = connection.createStatement(); assertNotNull(statement); @@ -61,8 +66,9 @@ void testCreateStatement(String driverClass, String url, String user, String pas @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testPrepareStatement(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testPrepareStatement(String driverClass, String url, String user, String password, boolean isXA) + throws SQLException { + setUp(driverClass, url, user, password, isXA); PreparedStatement preparedStatement = connection.prepareStatement("SELECT 1"); assertNotNull(preparedStatement); @@ -71,8 +77,9 @@ void testPrepareStatement(String driverClass, String url, String user, String pa @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testPrepareCall(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testPrepareCall(String driverClass, String url, String user, String password, boolean isXA) + throws SQLException { + setUp(driverClass, url, user, password, isXA); // MySQL supports callable statements, though syntax may differ try { @@ -87,8 +94,8 @@ void testPrepareCall(String driverClass, String url, String user, String passwor @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testNativeSQL(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testNativeSQL(String driverClass, String url, String user, String password, boolean isXA) throws SQLException { + setUp(driverClass, url, user, password, isXA); String nativeSQL = connection.nativeSQL("SELECT {fn NOW()}"); assertNotNull(nativeSQL); @@ -98,8 +105,9 @@ void testNativeSQL(String driverClass, String url, String user, String password) @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testAutoCommit(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testAutoCommit(String driverClass, String url, String user, String password, boolean isXA) + throws SQLException { + setUp(driverClass, url, user, password, isXA); // Test getting and setting auto-commit boolean originalAutoCommit = connection.getAutoCommit(); @@ -108,7 +116,12 @@ void testAutoCommit(String driverClass, String url, String user, String password assertFalse(connection.getAutoCommit()); connection.setAutoCommit(true); - assertTrue(connection.getAutoCommit()); + if (!isXA) { + assertTrue(connection.getAutoCommit()); + } else { + // XA connections ignore setAutoCommit(true) and always return false + assertFalse(connection.getAutoCommit()); + } // Restore original state connection.setAutoCommit(originalAutoCommit); @@ -116,23 +129,29 @@ void testAutoCommit(String driverClass, String url, String user, String password @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testCommitAndRollback(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testCommitAndRollback(String driverClass, String url, String user, String password, boolean isXA) + throws SQLException { + setUp(driverClass, url, user, password, isXA); // Test commit and rollback operations connection.setAutoCommit(false); - // These should not throw exceptions - connection.commit(); - connection.rollback(); + // In XA mode, connection.commit() and rollback() are not allowed + if (isXA) { + assertThrows(SQLException.class, () -> connection.commit()); + assertThrows(SQLException.class, () -> connection.rollback()); + } else { + connection.commit(); + connection.rollback(); + } connection.setAutoCommit(true); } @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testIsClosed(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testIsClosed(String driverClass, String url, String user, String password, boolean isXA) throws SQLException { + setUp(driverClass, url, user, password, isXA); assertFalse(connection.isClosed()); @@ -142,8 +161,9 @@ void testIsClosed(String driverClass, String url, String user, String password) @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testGetMetaData(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testGetMetaData(String driverClass, String url, String user, String password, boolean isXA) + throws SQLException { + setUp(driverClass, url, user, password, isXA); DatabaseMetaData metaData = connection.getMetaData(); assertNotNull(metaData); @@ -158,8 +178,8 @@ void testGetMetaData(String driverClass, String url, String user, String passwor @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testReadOnly(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testReadOnly(String driverClass, String url, String user, String password, boolean isXA) throws SQLException { + setUp(driverClass, url, user, password, isXA); // Test read-only mode boolean originalReadOnly = connection.isReadOnly(); @@ -179,8 +199,8 @@ void testReadOnly(String driverClass, String url, String user, String password) @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testCatalog(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testCatalog(String driverClass, String url, String user, String password, boolean isXA) throws SQLException { + setUp(driverClass, url, user, password, isXA); String catalog = connection.getCatalog(); // Catalog might be null or the database name @@ -194,11 +214,13 @@ void testCatalog(String driverClass, String url, String user, String password) t @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testTransactionIsolation(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testTransactionIsolation(String driverClass, String url, String user, String password, boolean isXA) + throws SQLException { + setUp(driverClass, url, user, password, isXA); int isolationLevel = connection.getTransactionIsolation(); - assertTrue(isolationLevel >= Connection.TRANSACTION_NONE && isolationLevel <= Connection.TRANSACTION_SERIALIZABLE); + assertTrue( + isolationLevel >= Connection.TRANSACTION_NONE && isolationLevel <= Connection.TRANSACTION_SERIALIZABLE); // Test setting transaction isolation level connection.setTransactionIsolation(Connection.TRANSACTION_READ_COMMITTED); @@ -210,8 +232,8 @@ void testTransactionIsolation(String driverClass, String url, String user, Strin @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testWarnings(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testWarnings(String driverClass, String url, String user, String password, boolean isXA) throws SQLException { + setUp(driverClass, url, user, password, isXA); // Test warning operations SQLWarning warnings = connection.getWarnings(); @@ -223,8 +245,9 @@ void testWarnings(String driverClass, String url, String user, String password) @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testCreateStatementWithParameters(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testCreateStatementWithParameters(String driverClass, String url, String user, String password, boolean isXA) + throws SQLException { + setUp(driverClass, url, user, password, isXA); Statement statement = connection.createStatement(ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); assertNotNull(statement); @@ -237,10 +260,12 @@ void testCreateStatementWithParameters(String driverClass, String url, String us @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testPrepareStatementWithParameters(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testPrepareStatementWithParameters(String driverClass, String url, String user, String password, boolean isXA) + throws SQLException { + setUp(driverClass, url, user, password, isXA); - PreparedStatement ps = connection.prepareStatement("SELECT 1", ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY); + PreparedStatement ps = connection.prepareStatement("SELECT 1", ResultSet.TYPE_FORWARD_ONLY, + ResultSet.CONCUR_READ_ONLY); assertNotNull(ps); ps.close(); @@ -251,11 +276,13 @@ void testPrepareStatementWithParameters(String driverClass, String url, String u @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testHoldability(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testHoldability(String driverClass, String url, String user, String password, boolean isXA) + throws SQLException { + setUp(driverClass, url, user, password, isXA); int holdability = connection.getHoldability(); - assertTrue(holdability == ResultSet.HOLD_CURSORS_OVER_COMMIT || holdability == ResultSet.CLOSE_CURSORS_AT_COMMIT); + assertTrue( + holdability == ResultSet.HOLD_CURSORS_OVER_COMMIT || holdability == ResultSet.CLOSE_CURSORS_AT_COMMIT); // Test setting holdability connection.setHoldability(ResultSet.HOLD_CURSORS_OVER_COMMIT); @@ -264,11 +291,18 @@ void testHoldability(String driverClass, String url, String user, String passwor @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testSavepoints(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testSavepoints(String driverClass, String url, String user, String password, boolean isXA) + throws SQLException { + setUp(driverClass, url, user, password, isXA); connection.setAutoCommit(false); + if (isXA) { + assertThrows(SQLFeatureNotSupportedException.class, () -> connection.setSavepoint()); + assertThrows(SQLFeatureNotSupportedException.class, () -> connection.setSavepoint("test_savepoint")); + return; + } + // Test unnamed savepoint Savepoint savepoint1 = connection.setSavepoint(); assertNotNull(savepoint1); @@ -290,8 +324,9 @@ void testSavepoints(String driverClass, String url, String user, String password @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testClientInfo(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testClientInfo(String driverClass, String url, String user, String password, boolean isXA) + throws SQLException { + setUp(driverClass, url, user, password, isXA); Properties clientInfo = connection.getClientInfo(); assertNotNull(clientInfo); @@ -307,8 +342,8 @@ void testClientInfo(String driverClass, String url, String user, String password @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testValid(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testValid(String driverClass, String url, String user, String password, boolean isXA) throws SQLException { + setUp(driverClass, url, user, password, isXA); boolean isValid = connection.isValid(5); assertTrue(isValid); @@ -321,16 +356,17 @@ void testValid(String driverClass, String url, String user, String password) thr @ParameterizedTest @CsvFileSource(resources = "/mysql_mariadb_connection.csv") - void testUnsupportedOperations(String driverClass, String url, String user, String password) throws SQLException { - setUp(driverClass, url, user, password); + void testUnsupportedOperations(String driverClass, String url, String user, String password, boolean isXA) + throws SQLException { + setUp(driverClass, url, user, password, isXA); // Test operations that might not be supported Assert.assertThrows(SQLException.class, () -> { - connection.createArrayOf("VARCHAR", new String[]{"test"}); + connection.createArrayOf("VARCHAR", new String[] { "test" }); }); Assert.assertThrows(SQLFeatureNotSupportedException.class, () -> { - connection.createStruct("test_type", new Object[]{}); + connection.createStruct("test_type", new Object[] {}); }); } } \ No newline at end of file diff --git a/ojp-jdbc-driver/src/test/resources/mariadb_xa_connection.csv b/ojp-jdbc-driver/src/test/resources/mariadb_xa_connection.csv new file mode 100644 index 000000000..097ff5cb1 --- /dev/null +++ b/ojp-jdbc-driver/src/test/resources/mariadb_xa_connection.csv @@ -0,0 +1 @@ +org.openjproxy.jdbc.Driver,jdbc:ojp[localhost:1059]_mariadb://localhost:3307/defaultdb,testuser,testpassword diff --git a/ojp-jdbc-driver/src/test/resources/mysql_mariadb_connection.csv b/ojp-jdbc-driver/src/test/resources/mysql_mariadb_connection.csv index 8bda7086d..565e250c7 100644 --- a/ojp-jdbc-driver/src/test/resources/mysql_mariadb_connection.csv +++ b/ojp-jdbc-driver/src/test/resources/mysql_mariadb_connection.csv @@ -1,2 +1,4 @@ org.openjproxy.jdbc.Driver,jdbc:ojp[localhost:1059]_mysql://localhost:3306/defaultdb,testuser,testpassword,false -org.openjproxy.jdbc.Driver,jdbc:ojp[localhost:1059]_mariadb://localhost:3307/defaultdb,testuser,testpassword,false \ No newline at end of file +org.openjproxy.jdbc.Driver,jdbc:ojp[localhost:1059]_mariadb://localhost:3307/defaultdb,testuser,testpassword,false +org.openjproxy.jdbc.Driver,jdbc:ojp[localhost:1059]_mysql://localhost:3306/defaultdb,testuser,testpassword,true +org.openjproxy.jdbc.Driver,jdbc:ojp[localhost:1059]_mariadb://localhost:3307/defaultdb,testuser,testpassword,true \ No newline at end of file diff --git a/ojp-server/src/main/java/org/openjproxy/grpc/server/action/connection/HandleXAConnectionWithPoolingAction.java b/ojp-server/src/main/java/org/openjproxy/grpc/server/action/connection/HandleXAConnectionWithPoolingAction.java index fa64411d5..2c2a60b7c 100644 --- a/ojp-server/src/main/java/org/openjproxy/grpc/server/action/connection/HandleXAConnectionWithPoolingAction.java +++ b/ojp-server/src/main/java/org/openjproxy/grpc/server/action/connection/HandleXAConnectionWithPoolingAction.java @@ -36,34 +36,34 @@ */ @Slf4j public class HandleXAConnectionWithPoolingAction { - + private static final HandleXAConnectionWithPoolingAction INSTANCE = new HandleXAConnectionWithPoolingAction(); - + // Lock objects for synchronizing registry creation per connection hash // Using ReentrantLock for virtual thread compatibility // Note: In practice, connection hashes are bounded by the finite set of database credentials // used by the application, so this map won't grow indefinitely private final Map registryLocks = new ConcurrentHashMap<>(); - + private HandleXAConnectionWithPoolingAction() { // Private constructor prevents external instantiation } - + public static HandleXAConnectionWithPoolingAction getInstance() { return INSTANCE; } - + public void execute(ActionContext context, ConnectionDetails connectionDetails, String connHash, - int actualMaxXaTransactions, long xaStartTimeoutMillis, - StreamObserver responseObserver) { + int actualMaxXaTransactions, long xaStartTimeoutMillis, + StreamObserver responseObserver) { log.info("Using XA Pool Provider SPI for connHash: {}", connHash); - + // Use ReentrantLock for virtual thread compatibility // Lock ONLY during registry creation/check, not during session borrowing // This prevents race conditions in registry creation while avoiding deadlocks during pool borrowing Lock lock = getRegistryLock(connHash); XATransactionRegistry registry; - + lock.lock(); try { registry = getOrCreateRegistry(context, connectionDetails, connHash, actualMaxXaTransactions, xaStartTimeoutMillis, responseObserver); @@ -74,12 +74,10 @@ public void execute(ActionContext context, ConnectionDetails connectionDetails, } finally { lock.unlock(); } - // Now borrow from pool WITHOUT holding the lock // This allows multiple threads to borrow concurrently from the same pool borrowSessionAndRespond(context, connectionDetails, connHash, actualMaxXaTransactions, xaStartTimeoutMillis, registry, responseObserver); } - /** * Get a lock for synchronizing registry operations for a specific connection hash. * This prevents race conditions where multiple threads try to create registries simultaneously. @@ -88,28 +86,26 @@ public void execute(ActionContext context, ConnectionDetails connectionDetails, private Lock getRegistryLock(String connHash) { return registryLocks.computeIfAbsent(connHash, k -> new ReentrantLock()); } - /** * Get or create the XA registry for the given connection hash. * This method is called while holding the registry lock to prevent race conditions. * Returns null if unpooled mode or if an error occurred (error already sent to client). */ private XATransactionRegistry getOrCreateRegistry(ActionContext context, ConnectionDetails connectionDetails, String connHash, - int actualMaxXaTransactions, long xaStartTimeoutMillis, - StreamObserver responseObserver) { - + int actualMaxXaTransactions, long xaStartTimeoutMillis, + StreamObserver responseObserver) { + // Get current serverEndpoints configuration List currentServerEndpoints = connectionDetails.getServerEndpointsList(); - String currentEndpointsHash = (currentServerEndpoints == null || currentServerEndpoints.isEmpty()) - ? "NONE" + String currentEndpointsHash = (currentServerEndpoints == null || currentServerEndpoints.isEmpty()) + ? "NONE" : String.join(",", currentServerEndpoints); - + // Check if we already have an XA registry for this connection hash // NOTE: This method is called while holding the registry lock to prevent race conditions XATransactionRegistry registry = context.getXaRegistries().get(connHash); log.info("XA registry cache lookup for {}: exists={}, current serverEndpoints hash: {}", - connHash, registry != null, currentEndpointsHash); - + connHash, registry != null, currentEndpointsHash); // Calculate what the pool sizes SHOULD be based on current configuration int expectedMaxPoolSize; int expectedMinIdle; @@ -117,7 +113,7 @@ private XATransactionRegistry getOrCreateRegistry(ActionContext context, Connect try { Properties clientProperties = ConnectionPoolConfigurer.extractClientProperties(connectionDetails); DataSourceConfigurationManager.XADataSourceConfiguration xaConfig = - DataSourceConfigurationManager.getXAConfiguration(clientProperties); + DataSourceConfigurationManager.getXAConfiguration(clientProperties); expectedMaxPoolSize = xaConfig.getMaximumPoolSize(); expectedMinIdle = xaConfig.getMinimumIdle(); poolEnabled = xaConfig.isPoolEnabled(); @@ -139,17 +135,17 @@ private XATransactionRegistry getOrCreateRegistry(ActionContext context, Connect expectedMinIdle = -1; poolEnabled = true; // Default to pooled mode if config fails } - + // Check if registry exists and needs recreation due to configuration mismatch boolean needsRecreation = false; if (registry != null) { String registryEndpointsHash = registry.getServerEndpointsHash(); int registryMaxPool = registry.getMaxPoolSize(); int registryMinIdle = registry.getMinIdle(); - + // Check if serverEndpoints changed if (registryEndpointsHash == null || !registryEndpointsHash.equals(currentEndpointsHash)) { - log.warn("XA registry for {} has serverEndpoints mismatch: registry='{}' vs current='{}'. Will recreate.", + log.warn("XA registry for {} has serverEndpoints mismatch: registry='{}' vs current='{}'. Will recreate.", connHash, registryEndpointsHash, currentEndpointsHash); needsRecreation = true; } @@ -164,7 +160,7 @@ else if (expectedMinIdle > 0 && registryMinIdle != expectedMinIdle) { connHash, registryMinIdle, expectedMinIdle); needsRecreation = true; } - + if (needsRecreation) { // Close and remove old registry try { @@ -179,62 +175,60 @@ else if (expectedMinIdle > 0 && registryMinIdle != expectedMinIdle) { context.getClusterHealthTracker().removeTracking(connHash); } } - + if (registry == null) { log.info("Creating NEW XA registry for connHash: {} with serverEndpoints: {}", connHash, currentEndpointsHash); - + // Check if XA pooling is enabled if (!poolEnabled) { log.info("XA unpooled mode enabled for connHash: {}", connHash); - + // Handle unpooled XA connection (releases lock before handling) // Return null to signal unpooled mode was handled HandleUnpooledXAConnectionAction.getInstance().execute(context, connectionDetails, connHash, responseObserver); return null; } - + try { // Parse URL to remove OJP-specific prefix (same as non-XA path) String parsedUrl = UrlParser.parseUrl(connectionDetails.getUrl()); - + // Get XA datasource configuration from client properties (uses XA-specific properties) Properties clientProperties = ConnectionPoolConfigurer.extractClientProperties(connectionDetails); DataSourceConfigurationManager.XADataSourceConfiguration xaConfig = - DataSourceConfigurationManager.getXAConfiguration(clientProperties); - + DataSourceConfigurationManager.getXAConfiguration(clientProperties); + // Get default pool sizes from XA configuration int maxPoolSize = xaConfig.getMaximumPoolSize(); int minIdle = xaConfig.getMinimumIdle(); - - log.info("XA pool BEFORE multinode coordination for {}: requested max={}, min={}", + + log.info("XA pool BEFORE multinode coordination for {}: requested max={}, min={}", connHash, maxPoolSize, minIdle); - + // Apply multinode pool coordination if server endpoints provided List serverEndpoints = connectionDetails.getServerEndpointsList(); - log.info("XA serverEndpoints list: null={}, size={}, endpoints={}", - serverEndpoints == null, + log.info("XA serverEndpoints list: null={}, size={}, endpoints={}", + serverEndpoints == null, serverEndpoints == null ? 0 : serverEndpoints.size(), serverEndpoints); - + if (serverEndpoints != null && !serverEndpoints.isEmpty()) { // Multinode: divide pool sizes among servers MultinodePoolCoordinator.PoolAllocation allocation = - ConnectionPoolConfigurer.getPoolCoordinator().calculatePoolSizes( + ConnectionPoolConfigurer.getPoolCoordinator().calculatePoolSizes( connHash, maxPoolSize, minIdle, serverEndpoints); - + maxPoolSize = allocation.getCurrentMaxPoolSize(); minIdle = allocation.getCurrentMinIdle(); - - log.info("XA multinode pool coordination for {}: {} servers, divided sizes: max={}, min={}", + + log.info("XA multinode pool coordination for {}: {} servers, divided sizes: max={}, min={}", connHash, serverEndpoints.size(), maxPoolSize, minIdle); } else { log.info("XA multinode coordination SKIPPED for {}: serverEndpoints null or empty", connHash); } - - log.info("XA pool AFTER multinode coordination for {}: final max={}, min={}", - connHash, maxPoolSize, minIdle); - + log.info("XA pool AFTER multinode coordination for {}: final max={}, min={}", + connHash, maxPoolSize, minIdle); // Build configuration map for XA Pool Provider Map xaPoolConfig = new HashMap<>(); xaPoolConfig.put("xa.datasource.className", getXADataSourceClassName(parsedUrl)); @@ -251,30 +245,30 @@ else if (expectedMinIdle > 0 && registryMinIdle != expectedMinIdle) { xaPoolConfig.put("xa.timeBetweenEvictionRunsMs", String.valueOf(xaConfig.getTimeBetweenEvictionRuns())); xaPoolConfig.put("xa.numTestsPerEvictionRun", String.valueOf(xaConfig.getNumTestsPerEvictionRun())); xaPoolConfig.put("xa.softMinEvictableIdleTimeMs", String.valueOf(xaConfig.getSoftMinEvictableIdleTime())); - + // Create pooled XA DataSource via provider log.info("[XA-POOL-CREATE] Creating XA pool for connHash={}, serverEndpointsHash={}, config=(max={}, min={})", connHash, currentEndpointsHash, maxPoolSize, minIdle); Object pooledXADataSource = context.getXaPoolProvider().createXADataSource(xaPoolConfig); - + // Create XA Transaction Registry with serverEndpoints hash and pool sizes for validation registry = new XATransactionRegistry(context.getXaPoolProvider(), pooledXADataSource, currentEndpointsHash, maxPoolSize, minIdle); context.getXaRegistries().put(connHash, registry); - + // Initialize pool with minIdle connections immediately after creation // Without this, the pool starts empty and only creates connections on demand log.info("[XA-POOL-INIT] Initializing XA pool with minIdle={} connections for connHash={}", minIdle, connHash); registry.resizeBackendPool(maxPoolSize, minIdle); - + // Create slow query segregation manager for XA CreateSlowQuerySegregationManagerAction.getInstance().execute(context, connHash, actualMaxXaTransactions, true, xaStartTimeoutMillis); - - log.info("[XA-POOL-CREATE] Successfully created XA pool for connHash={} - maxPoolSize={}, minIdle={}, multinode={}, poolObject={}", - connHash, maxPoolSize, minIdle, serverEndpoints != null && !serverEndpoints.isEmpty(), + + log.info("[XA-POOL-CREATE] Successfully created XA pool for connHash={} - maxPoolSize={}, minIdle={}, multinode={}, poolObject={}", + connHash, maxPoolSize, minIdle, serverEndpoints != null && !serverEndpoints.isEmpty(), pooledXADataSource.getClass().getSimpleName()); - + } catch (Exception e) { - log.error("[XA-POOL-CREATE] FAILED to create XA Pool Provider registry for connHash={}, serverEndpointsHash={}: {}", + log.error("[XA-POOL-CREATE] FAILED to create XA Pool Provider registry for connHash={}, serverEndpointsHash={}: {}", connHash, currentEndpointsHash, e.getMessage(), e); SQLException sqlException = new SQLException("Failed to create XA pool: " + e.getMessage(), e); sendSQLExceptionMetadata(sqlException, responseObserver); @@ -284,20 +278,20 @@ else if (expectedMinIdle > 0 && registryMinIdle != expectedMinIdle) { log.info("[XA-POOL-REUSE] Reusing EXISTING XA registry for connHash={} (pool already created, cached sizes: max={}, min={})", connHash, registry.getMaxPoolSize(), registry.getMinIdle()); } - + return registry; } - + /** * Borrow a session from the pool and send response to client. * This method is called WITHOUT holding the registry lock to allow concurrent borrowing. */ private void borrowSessionAndRespond(ActionContext context, ConnectionDetails connectionDetails, String connHash, - int actualMaxXaTransactions, long xaStartTimeoutMillis, XATransactionRegistry registry, - StreamObserver responseObserver) { - + int actualMaxXaTransactions, long xaStartTimeoutMillis, XATransactionRegistry registry, + StreamObserver responseObserver) { + context.getSessionManager().registerClientUUID(connHash, connectionDetails.getClientUUID()); - + // CRITICAL FIX: Call processClusterHealth() BEFORE borrowing session // This ensures pool rebalancing happens even when server 1 fails before any XA operations execute // Without this, pool exhaustion prevents cluster health propagation and pool never expands @@ -305,7 +299,7 @@ private void borrowSessionAndRespond(ActionContext context, ConnectionDetails co // Use the ACTUAL cluster health from the client (not synthetic) // The client sends the current health status of all servers String actualClusterHealth = connectionDetails.getClusterHealth(); - + // Create a temporary SessionInfo with cluster health for processing // We don't have the actual sessionInfo yet since we haven't borrowed from the pool SessionInfo tempSessionInfo = SessionInfo.newBuilder() @@ -313,52 +307,52 @@ private void borrowSessionAndRespond(ActionContext context, ConnectionDetails co .setConnHash(connHash) .setClusterHealth(actualClusterHealth) .build(); - - log.info("[XA-CONNECT-REBALANCE] Calling processClusterHealth BEFORE borrow for connHash={}, clusterHealth={}", + + log.info("[XA-CONNECT-REBALANCE] Calling processClusterHealth BEFORE borrow for connHash={}, clusterHealth={}", connHash, actualClusterHealth); - + // Process cluster health to trigger pool rebalancing if needed ProcessClusterHealthAction.getInstance().execute(context, tempSessionInfo); } else { - log.warn("[XA-CONNECT-REBALANCE] No cluster health provided in ConnectionDetails for connHash={}, pool rebalancing may be delayed", + log.warn("[XA-CONNECT-REBALANCE] No cluster health provided in ConnectionDetails for connHash={}, pool rebalancing may be delayed", connHash); } - + // Borrow a XABackendSession from the pool for immediate use // Note: Unlike the original "deferred" approach, we allocate eagerly because // XA applications expect getConnection() to work immediately, before xaStart() XABackendSession backendSession = null; try { - backendSession = + backendSession = (XABackendSession) context.getXaPoolProvider().borrowSession(registry.getPooledXADataSource()); - + XAConnection xaConnection = backendSession.getXAConnection(); Connection connection = backendSession.getConnection(); - + // Create XA session with the pooled XAConnection SessionInfo sessionInfo = context.getSessionManager().createXASession( connectionDetails.getClientUUID(), connection, xaConnection); - + // Store the XABackendSession reference in the session for later lifecycle management Session session = context.getSessionManager().getSession(sessionInfo); if (session != null) { session.setBackendSession(backendSession); } - - log.info("Created XA session (pooled, eager allocation) with client UUID: {} for connHash: {}", + + log.info("Created XA session (pooled, eager allocation) with client UUID: {} for connHash: {}", connectionDetails.getClientUUID(), connHash); - + // Note: processClusterHealth() already called BEFORE borrowing session (see above) // This ensures pool is resized before we try to borrow, preventing exhaustion - + responseObserver.onNext(sessionInfo); context.getDbNameMap().put(connHash, DatabaseUtils.resolveDbName(connectionDetails.getUrl())); responseObserver.onCompleted(); - + } catch (Exception e) { - log.error("Failed to borrow XABackendSession from pool for connection hash {}: {}", + log.error("Failed to borrow XABackendSession from pool for connection hash {}: {}", connHash, e.getMessage(), e); - + // CRITICAL FIX: Return the borrowed session back to pool on failure to prevent session leaks // This was causing PostgreSQL "too many clients" errors as leaked sessions bypassed pool limits if (backendSession != null) { @@ -376,13 +370,13 @@ private void borrowSessionAndRespond(ActionContext context, ConnectionDetails co } } } - + SQLException sqlException = new SQLException("Failed to allocate XA session from pool: " + e.getMessage(), e); sendSQLExceptionMetadata(sqlException, responseObserver); return; } } - + /** * Determine XADataSource class name based on database URL. */ @@ -396,7 +390,9 @@ private String getXADataSourceClassName(String url) { return "com.microsoft.sqlserver.jdbc.SQLServerXADataSource"; } else if (lowerUrl.contains(":db2:")) { return "com.ibm.db2.jcc.DB2XADataSource"; - } else if (lowerUrl.contains(":mysql:") || lowerUrl.contains(":mariadb:")) { + } else if (lowerUrl.contains(":mariadb:")) { + return "org.mariadb.jdbc.MariaDbDataSource"; + } else if (lowerUrl.contains(":mysql:")) { return "com.mysql.cj.jdbc.MysqlXADataSource"; } else { throw new IllegalArgumentException("Unsupported database for XA: " + url); diff --git a/ojp-server/src/main/java/org/openjproxy/grpc/server/utils/DriverUtils.java b/ojp-server/src/main/java/org/openjproxy/grpc/server/utils/DriverUtils.java index a1750b490..8c83febd2 100644 --- a/ojp-server/src/main/java/org/openjproxy/grpc/server/utils/DriverUtils.java +++ b/ojp-server/src/main/java/org/openjproxy/grpc/server/utils/DriverUtils.java @@ -20,48 +20,48 @@ @Slf4j @UtilityClass public class DriverUtils { - + /** * Register all JDBC drivers supported and report their availability status. - * This checks if the driver can be loaded via Class.forName() OR if it's registered with DriverManager. + * This checks if the driver can be loaded via Class.forName() OR if it's registered with DriverManager. * @param driversPath Optional path to external libraries directory for user guidance in error messages */ public void registerDrivers(String driversPath) { - String driverPathMessage = (driversPath != null && !driversPath.trim().isEmpty()) - ? driversPath - : "./ojp-libs"; - + String driverPathMessage = (driversPath != null && !driversPath.trim().isEmpty()) + ? driversPath + : "./ojp-libs"; + //Check open source drivers - checkDriver(H2_DRIVER_CLASS, "H2", - "https://mvnrepository.com/artifact/com.h2database/h2", "h2-*.jar", driverPathMessage); - checkDriver(POSTGRES_DRIVER_CLASS, "PostgreSQL", - "https://mvnrepository.com/artifact/org.postgresql/postgresql", "postgresql-*.jar", driverPathMessage); - checkDriver(MYSQL_DRIVER_CLASS, "MySQL", - "https://mvnrepository.com/artifact/com.mysql/mysql-connector-j", "mysql-connector-j-*.jar", driverPathMessage); - checkDriver(MARIADB_DRIVER_CLASS, "MariaDB", - "https://mvnrepository.com/artifact/org.mariadb.jdbc/mariadb-java-client", "mariadb-java-client-*.jar", driverPathMessage); - + checkDriver(H2_DRIVER_CLASS, "H2", + "https://mvnrepository.com/artifact/com.h2database/h2", "h2-*.jar", driverPathMessage); + checkDriver(POSTGRES_DRIVER_CLASS, "PostgreSQL", + "https://mvnrepository.com/artifact/org.postgresql/postgresql", "postgresql-*.jar", driverPathMessage); + checkDriver(MYSQL_DRIVER_CLASS, "MySQL", + "https://mvnrepository.com/artifact/com.mysql/mysql-connector-j", "mysql-connector-j-*.jar", driverPathMessage); + checkDriver(MARIADB_DRIVER_CLASS, "MariaDB", + "https://mvnrepository.com/artifact/org.mariadb.jdbc/mariadb-java-client", "mariadb-java-client-*.jar", driverPathMessage); + //Check proprietary drivers (if present) - checkDriver(ORACLE_DRIVER_CLASS, "Oracle", - "https://www.oracle.com/database/technologies/jdbc-downloads.html", "ojdbc*.jar", driverPathMessage); - checkDriver(SQLSERVER_DRIVER_CLASS, "SQL Server", - "https://learn.microsoft.com/en-us/sql/connect/jdbc/download-microsoft-jdbc-driver-for-sql-server", "mssql-jdbc-*.jar", driverPathMessage); - checkDriver(DB2_DRIVER_CLASS, "DB2", - "IBM website", "db2jcc*.jar", driverPathMessage); + checkDriver(ORACLE_DRIVER_CLASS, "Oracle", + "https://www.oracle.com/database/technologies/jdbc-downloads.html", "ojdbc*.jar", driverPathMessage); + checkDriver(SQLSERVER_DRIVER_CLASS, "SQL Server", + "https://learn.microsoft.com/en-us/sql/connect/jdbc/download-microsoft-jdbc-driver-for-sql-server", "mssql-jdbc-*.jar", driverPathMessage); + checkDriver(DB2_DRIVER_CLASS, "DB2", + "IBM website", "db2jcc*.jar", driverPathMessage); } - + /** * Check if a driver is available either via Class.forName() or DriverManager. * This method checks both the main classpath and drivers registered with DriverManager * (which includes drivers loaded via URLClassLoader and wrapped in DriverShim). */ - private void checkDriver(String driverClass, String driverName, - String downloadUrl, String jarName, String driverPath) { + private void checkDriver(String driverClass, String driverName, + String downloadUrl, String jarName, String driverPath) { boolean found = false; - + // First try Class.forName() - works for drivers in the main classpath try { - Class.forName(driverClass); + Class.forName(driverClass, true, Thread.currentThread().getContextClassLoader()); found = true; } catch (ClassNotFoundException e) { // Driver not in main classpath, check if it's registered with DriverManager @@ -83,7 +83,7 @@ private void checkDriver(String driverClass, String driverName, } } } - + if (found) { log.info("{} JDBC driver loaded successfully", driverName); } else { @@ -93,9 +93,9 @@ private void checkDriver(String driverClass, String driverName, log.info(" 3. Restart OJP Server"); } } - + /** - * Register all JDBC drivers supported without path information. + * Register all JDBC drivers supported without path information. * @deprecated Use {@link #registerDrivers(String)} instead to provide better error messages */ @Deprecated diff --git a/ojp-server/src/main/java/org/openjproxy/grpc/server/xa/XADataSourceFactory.java b/ojp-server/src/main/java/org/openjproxy/grpc/server/xa/XADataSourceFactory.java index 4a6e46e9c..fb55ed515 100644 --- a/ojp-server/src/main/java/org/openjproxy/grpc/server/xa/XADataSourceFactory.java +++ b/ojp-server/src/main/java/org/openjproxy/grpc/server/xa/XADataSourceFactory.java @@ -14,6 +14,8 @@ @Slf4j public class XADataSourceFactory { + public static final String POSTGRESQL_XA_DATASOURCE = "org.postgresql.xa.PGXADataSource"; + /** * Creates an XADataSource for the specified database type based on the URL. * @@ -24,10 +26,11 @@ public class XADataSourceFactory { */ public static XADataSource createXADataSource(String url, ConnectionDetails connectionDetails) throws SQLException { String lowerUrl = url.toLowerCase(); - try { if (lowerUrl.contains("postgresql")) { return createPostgreSQLXADataSource(url, connectionDetails); + } else if (lowerUrl.contains("mariadb")) { + return createMariaDBXADataSource(url, connectionDetails); } else if (lowerUrl.contains("mysql")) { return createMySQLXADataSource(url, connectionDetails); } else if (lowerUrl.contains("oracle")) { @@ -55,14 +58,13 @@ public static XADataSource createXADataSource(String url, ConnectionDetails conn */ private static XADataSource createPostgreSQLXADataSource(String url, ConnectionDetails connectionDetails) throws SQLException { try { + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); // Check if PostgreSQL driver is available - Class.forName("org.postgresql.xa.PGXADataSource"); - + Class.forName(POSTGRESQL_XA_DATASOURCE, true, classLoader); // Use reflection to create and configure PGXADataSource - XADataSource xaDS = (XADataSource) Class.forName("org.postgresql.xa.PGXADataSource") + XADataSource xaDS = (XADataSource) Class.forName(POSTGRESQL_XA_DATASOURCE, true, classLoader) .getDeclaredConstructor() .newInstance(); - // Parse connection URL to extract host, port, database // Format: jdbc:postgresql://host:port/database or ojp[...]:host:port/database String cleanUrl = url; @@ -71,7 +73,6 @@ private static XADataSource createPostgreSQLXADataSource(String url, ConnectionD } else if (cleanUrl.toLowerCase().startsWith("jdbc:postgresql:")) { cleanUrl = cleanUrl.substring("jdbc:".length()); } - // Parse postgresql://host:port/database if (cleanUrl.startsWith("postgresql://")) { cleanUrl = cleanUrl.substring("postgresql://".length()); @@ -79,27 +80,22 @@ private static XADataSource createPostgreSQLXADataSource(String url, ConnectionD if (parts.length >= 2) { String hostPort = parts[0]; String database = parts[1].split("\\?")[0]; // Remove query params - String[] hostPortParts = hostPort.split(":"); String host = hostPortParts[0]; int port = hostPortParts.length > 1 ? Integer.parseInt(hostPortParts[1]) : 5432; - // Set properties using reflection xaDS.getClass().getMethod("setServerNames", String[].class).invoke(xaDS, (Object) new String[]{host}); xaDS.getClass().getMethod("setPortNumbers", int[].class).invoke(xaDS, (Object) new int[]{port}); xaDS.getClass().getMethod("setDatabaseName", String.class).invoke(xaDS, database); } } - xaDS.getClass().getMethod("setUser", String.class).invoke(xaDS, connectionDetails.getUser()); xaDS.getClass().getMethod("setPassword", String.class).invoke(xaDS, connectionDetails.getPassword()); - // Get server names for logging String[] serverNames = (String[]) xaDS.getClass().getMethod("getServerNames").invoke(xaDS); String host = (serverNames != null && serverNames.length > 0) ? serverNames[0] : "unknown"; log.info("Created PostgreSQL XADataSource for host: {}", host); return xaDS; - } catch (ClassNotFoundException e) { throw new SQLException("PostgreSQL JDBC driver not found. Add postgresql JDBC driver to classpath.", e); } catch (Exception e) { @@ -112,21 +108,18 @@ private static XADataSource createPostgreSQLXADataSource(String url, ConnectionD */ private static XADataSource createMySQLXADataSource(String url, ConnectionDetails connectionDetails) throws SQLException { try { + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); // Check if MySQL driver is available - Class.forName("com.mysql.cj.jdbc.MysqlXADataSource"); - + Class.forName("com.mysql.cj.jdbc.MysqlXADataSource", true, classLoader); // Use reflection to create and configure MysqlXADataSource - XADataSource xaDS = (XADataSource) Class.forName("com.mysql.cj.jdbc.MysqlXADataSource") + XADataSource xaDS = (XADataSource) Class.forName("com.mysql.cj.jdbc.MysqlXADataSource", true, classLoader) .getDeclaredConstructor() .newInstance(); - xaDS.getClass().getMethod("setUrl", String.class).invoke(xaDS, url); xaDS.getClass().getMethod("setUser", String.class).invoke(xaDS, connectionDetails.getUser()); xaDS.getClass().getMethod("setPassword", String.class).invoke(xaDS, connectionDetails.getPassword()); - log.info("Created MySQL XADataSource for URL: {}", url); return xaDS; - } catch (ClassNotFoundException e) { throw new SQLException("MySQL JDBC driver not found. Add mysql-connector-j to classpath.", e); } catch (Exception e) { @@ -134,6 +127,35 @@ private static XADataSource createMySQLXADataSource(String url, ConnectionDetail } } + /** + * Creates a MariaDB XADataSource. + */ + private static XADataSource createMariaDBXADataSource(String url, ConnectionDetails connectionDetails) + throws SQLException { + try { + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); + // Check if MariaDB driver is available + Class.forName("org.mariadb.jdbc.MariaDbDataSource", true, classLoader); + + // Use reflection to create and configure MariaDbDataSource + XADataSource xaDS = (XADataSource) Class.forName("org.mariadb.jdbc.MariaDbDataSource", true, classLoader) + .getDeclaredConstructor() + .newInstance(); + + xaDS.getClass().getMethod("setUrl", String.class).invoke(xaDS, url); + xaDS.getClass().getMethod("setUser", String.class).invoke(xaDS, connectionDetails.getUser()); + xaDS.getClass().getMethod("setPassword", String.class).invoke(xaDS, connectionDetails.getPassword()); + + log.info("Created MariaDB XADataSource for URL: {}", url); + return xaDS; + + } catch (ClassNotFoundException e) { + throw new SQLException("MariaDB JDBC driver not found. Add mariadb-java-client to classpath.", e); + } catch (Exception e) { + throw new SQLException("Failed to create MariaDB XADataSource: " + e.getMessage(), e); + } + } + /** * Creates an Oracle XADataSource. * @@ -149,33 +171,28 @@ private static XADataSource createMySQLXADataSource(String url, ConnectionDetail */ private static XADataSource createOracleXADataSource(String url, ConnectionDetails connectionDetails) throws SQLException { try { + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); // Check if Oracle driver is available - Class.forName("oracle.jdbc.xa.client.OracleXADataSource"); - + Class.forName("oracle.jdbc.xa.client.OracleXADataSource", true, classLoader); // Use reflection to create and configure OracleXADataSource - XADataSource xaDS = (XADataSource) Class.forName("oracle.jdbc.xa.client.OracleXADataSource") + XADataSource xaDS = (XADataSource) Class.forName("oracle.jdbc.xa.client.OracleXADataSource", true, classLoader) .getDeclaredConstructor() .newInstance(); - // Clean the URL - remove OJP wrapper if present String cleanUrl = url; if (cleanUrl.toLowerCase().contains("_oracle:")) { cleanUrl = "jdbc:oracle:" + cleanUrl.substring(cleanUrl.toLowerCase().indexOf("_oracle:") + 8); } - // Parse Oracle connection URL to extract components // Format: jdbc:oracle:thin:@host:port/service or jdbc:oracle:thin:@host:port:sid if (cleanUrl.toLowerCase().startsWith("jdbc:oracle:thin:@")) { String connectionPart = cleanUrl.substring("jdbc:oracle:thin:@".length()); - // Parse host:port/service or host:port:sid String host = "localhost"; int port = 1521; String serviceName = null; - // Set driver type first - required for Oracle to construct proper URL internally xaDS.getClass().getMethod("setDriverType", String.class).invoke(xaDS, "thin"); - if (connectionPart.contains("/")) { // Service name format: host:port/service String[] parts = connectionPart.split("/"); @@ -185,12 +202,10 @@ private static XADataSource createOracleXADataSource(String url, ConnectionDetai port = Integer.parseInt(hostPort[1]); } serviceName = parts[1]; - // Set properties using reflection xaDS.getClass().getMethod("setServerName", String.class).invoke(xaDS, host); xaDS.getClass().getMethod("setPortNumber", int.class).invoke(xaDS, port); xaDS.getClass().getMethod("setServiceName", String.class).invoke(xaDS, serviceName); - } else if (connectionPart.contains(":")) { // SID format: host:port:sid String[] parts = connectionPart.split(":"); @@ -217,10 +232,8 @@ private static XADataSource createOracleXADataSource(String url, ConnectionDetai xaDS.getClass().getMethod("setServerName", String.class).invoke(xaDS, "localhost"); xaDS.getClass().getMethod("setPortNumber", int.class).invoke(xaDS, 1521); } - xaDS.getClass().getMethod("setUser", String.class).invoke(xaDS, connectionDetails.getUser()); xaDS.getClass().getMethod("setPassword", String.class).invoke(xaDS, connectionDetails.getPassword()); - // Oracle XA requires specific properties to work correctly // Set connection properties that enable XA support try { @@ -230,15 +243,12 @@ private static XADataSource createOracleXADataSource(String url, ConnectionDetai props.setProperty("password", connectionDetails.getPassword()); // Oracle XA specific properties props.setProperty("v$session.program", "OJP-XA"); - xaDS.getClass().getMethod("setConnectionProperties", java.util.Properties.class).invoke(xaDS, props); } catch (Exception e) { log.warn("Could not set connection properties on Oracle XADataSource: {}", e.getMessage()); } - log.info("Created Oracle XADataSource for URL: {}", url); return xaDS; - } catch (ClassNotFoundException e) { throw new SQLException("Oracle JDBC driver not found. Add ojdbc (ojdbc8 or ojdbc11) to classpath.", e); } catch (Exception e) { @@ -251,22 +261,19 @@ private static XADataSource createOracleXADataSource(String url, ConnectionDetai */ private static XADataSource createSQLServerXADataSource(String url, ConnectionDetails connectionDetails) throws SQLException { try { + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); // Check if SQL Server driver is available - Class.forName("com.microsoft.sqlserver.jdbc.SQLServerXADataSource"); - + Class.forName("com.microsoft.sqlserver.jdbc.SQLServerXADataSource", true, classLoader); // Use reflection to create and configure SQLServerXADataSource - XADataSource xaDS = (XADataSource) Class.forName("com.microsoft.sqlserver.jdbc.SQLServerXADataSource") + XADataSource xaDS = (XADataSource) Class.forName("com.microsoft.sqlserver.jdbc.SQLServerXADataSource", true, classLoader) .getDeclaredConstructor() .newInstance(); - // Set URL using reflection xaDS.getClass().getMethod("setURL", String.class).invoke(xaDS, url); xaDS.getClass().getMethod("setUser", String.class).invoke(xaDS, connectionDetails.getUser()); xaDS.getClass().getMethod("setPassword", String.class).invoke(xaDS, connectionDetails.getPassword()); - log.info("Created SQL Server XADataSource for URL: {}", url); return xaDS; - } catch (ClassNotFoundException e) { throw new SQLException("SQL Server JDBC driver not found. Add mssql-jdbc to classpath.", e); } catch (Exception e) { @@ -279,14 +286,13 @@ private static XADataSource createSQLServerXADataSource(String url, ConnectionDe */ private static XADataSource createDB2XADataSource(String url, ConnectionDetails connectionDetails) throws SQLException { try { + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); // Check if DB2 driver is available - Class.forName("com.ibm.db2.jcc.DB2XADataSource"); - + Class.forName("com.ibm.db2.jcc.DB2XADataSource", true, classLoader); // Use reflection to create and configure DB2XADataSource - XADataSource xaDS = (XADataSource) Class.forName("com.ibm.db2.jcc.DB2XADataSource") + XADataSource xaDS = (XADataSource) Class.forName("com.ibm.db2.jcc.DB2XADataSource", true, classLoader) .getDeclaredConstructor() .newInstance(); - // Parse DB2 URL: jdbc:db2://host:port/database String cleanUrl = url; if (cleanUrl.toLowerCase().contains("_db2:")) { @@ -294,7 +300,6 @@ private static XADataSource createDB2XADataSource(String url, ConnectionDetails } else if (cleanUrl.toLowerCase().startsWith("jdbc:db2:")) { cleanUrl = cleanUrl.substring("jdbc:".length()); } - // Parse db2://host:port/database if (cleanUrl.startsWith("db2://")) { cleanUrl = cleanUrl.substring("db2://".length()); @@ -302,11 +307,9 @@ private static XADataSource createDB2XADataSource(String url, ConnectionDetails if (parts.length >= 2) { String hostPort = parts[0]; String database = parts[1].split("\\?")[0]; // Remove query params - String[] hostPortParts = hostPort.split(":"); String host = hostPortParts[0]; int port = hostPortParts.length > 1 ? Integer.parseInt(hostPortParts[1]) : 50000; - // Set properties using reflection xaDS.getClass().getMethod("setServerName", String.class).invoke(xaDS, host); xaDS.getClass().getMethod("setPortNumber", int.class).invoke(xaDS, port); @@ -314,13 +317,10 @@ private static XADataSource createDB2XADataSource(String url, ConnectionDetails xaDS.getClass().getMethod("setDriverType", int.class).invoke(xaDS, 4); // Type 4 driver } } - xaDS.getClass().getMethod("setUser", String.class).invoke(xaDS, connectionDetails.getUser()); xaDS.getClass().getMethod("setPassword", String.class).invoke(xaDS, connectionDetails.getPassword()); - log.info("Created DB2 XADataSource for URL: {}", url); return xaDS; - } catch (ClassNotFoundException e) { throw new SQLException("DB2 JDBC driver not found. Add db2jcc or db2jcc4 to classpath.", e); } catch (Exception e) { @@ -334,14 +334,13 @@ private static XADataSource createDB2XADataSource(String url, ConnectionDetails */ private static XADataSource createCockroachDBXADataSource(String url, ConnectionDetails connectionDetails) throws SQLException { try { + ClassLoader classLoader = Thread.currentThread().getContextClassLoader(); // Check if PostgreSQL driver is available (CockroachDB uses PostgreSQL protocol) - Class.forName("org.postgresql.xa.PGXADataSource"); - + Class.forName(POSTGRESQL_XA_DATASOURCE, true, classLoader); // Use reflection to create and configure PGXADataSource - XADataSource xaDS = (XADataSource) Class.forName("org.postgresql.xa.PGXADataSource") + XADataSource xaDS = (XADataSource) Class.forName(POSTGRESQL_XA_DATASOURCE, true, classLoader) .getDeclaredConstructor() .newInstance(); - // Parse connection URL to extract host, port, database // CockroachDB URL format: jdbc:postgresql://host:port/database String cleanUrl = url; @@ -359,7 +358,6 @@ private static XADataSource createCockroachDBXADataSource(String url, Connection } else if (cleanUrl.toLowerCase().startsWith("jdbc:cockroachdb:")) { cleanUrl = cleanUrl.substring("jdbc:".length()).replace("cockroachdb:", "postgresql:"); } - // Parse postgresql://host:port/database if (cleanUrl.startsWith("postgresql://")) { cleanUrl = cleanUrl.substring("postgresql://".length()); @@ -367,27 +365,22 @@ private static XADataSource createCockroachDBXADataSource(String url, Connection if (parts.length >= 2) { String hostPort = parts[0]; String database = parts[1].split("\\?")[0]; // Remove query params - String[] hostPortParts = hostPort.split(":"); String host = hostPortParts[0]; int port = hostPortParts.length > 1 ? Integer.parseInt(hostPortParts[1]) : 26257; // CockroachDB default port - // Set properties using reflection xaDS.getClass().getMethod("setServerNames", String[].class).invoke(xaDS, (Object) new String[]{host}); xaDS.getClass().getMethod("setPortNumbers", int[].class).invoke(xaDS, (Object) new int[]{port}); xaDS.getClass().getMethod("setDatabaseName", String.class).invoke(xaDS, database); } } - xaDS.getClass().getMethod("setUser", String.class).invoke(xaDS, connectionDetails.getUser()); xaDS.getClass().getMethod("setPassword", String.class).invoke(xaDS, connectionDetails.getPassword()); - // Get server names for logging String[] serverNames = (String[]) xaDS.getClass().getMethod("getServerNames").invoke(xaDS); String host = (serverNames != null && serverNames.length > 0) ? serverNames[0] : "unknown"; log.info("Created CockroachDB XADataSource (using PostgreSQL driver) for host: {}", host); return xaDS; - } catch (ClassNotFoundException e) { throw new SQLException("PostgreSQL JDBC driver not found (required for CockroachDB). Add postgresql JDBC driver to classpath.", e); } catch (Exception e) {