Commit c7d0ebb

Merge pull request #9 from trocco-io/feature/bug-fix-for-nonascii-column-name
Fix SQL syntax error in COPY INTO when a column name contains non-ASCII characters
2 parents 54844dc + 2e976a9 commit c7d0ebb
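
For context on the bug: buildCopySQL previously interpolated raw column names into the COPY INTO projection, so a non-ASCII (or backtick-containing) name produced a SQL syntax error. A minimal sketch of the before/after projection; the column name テスト列 and the standalone class are illustrative, not taken from the commit:

import java.util.Arrays;
import java.util.List;

// Sketch only: shows why unquoted non-ASCII identifiers broke COPY INTO,
// and what the backtick-quoted form looks like. Names are hypothetical.
public class CopyIntoProjectionSketch {
  public static void main(String[] args) {
    List<String> columnNames = Arrays.asList("テスト列");
    StringBuilder before = new StringBuilder("SELECT ");
    StringBuilder after = new StringBuilder("SELECT ");
    for (int i = 0; i < columnNames.size(); i++) {
      String name = columnNames.get(i);
      before.append(String.format("_c%d::string %s", i, name)); // unquoted: rejected by the parser
      after.append(String.format("_c%d::string `%s`", i, name.replace("`", "``"))); // quoted: valid
    }
    System.out.println(before); // SELECT _c0::string テスト列
    System.out.println(after);  // SELECT _c0::string `テスト列`
  }
}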

File tree

8 files changed: +439 −28 lines

example/test.yml.example

Lines changed: 3 additions & 1 deletion

@@ -1,9 +1,11 @@
-# The catalog_name and schema_name must be created in advance.
+# The catalog_name.schema_name, catalog_name.non_ascii_schema_name, non_ascii_catalog_name.non_ascii_schema_name and non_ascii_catalog_name.schema_name must be created in advance.
 
 server_hostname:
 http_path:
 personal_access_token:
 catalog_name:
 schema_name:
+non_ascii_schema_name:
+non_ascii_catalog_name:
 table_prefix:
 staging_volume_name_prefix:

src/main/java/org/embulk/output/databricks/DatabricksOutputConnection.java

Lines changed: 11 additions & 1 deletion

@@ -107,7 +107,8 @@ protected String buildCopySQL(TableIdentifier table, String filePath, JdbcSchema
       if (i != 0) {
         sb.append(" , ");
       }
-      sb.append(String.format("_c%d::%s %s", i, getCreateTableTypeName(column), column.getName()));
+      String quotedColumnName = quoteIdentifierString(column.getName());
+      sb.append(String.format("_c%d::%s %s", i, getCreateTableTypeName(column), quotedColumnName));
     }
     sb.append(" FROM ");
     sb.append(quoteIdentifierString(filePath, "\""));
@@ -120,6 +121,15 @@ protected String buildCopySQL(TableIdentifier table, String filePath, JdbcSchema
     return sb.toString();
   }
 
+  @Override
+  protected String quoteIdentifierString(String str, String quoteString) {
+    // https://docs.databricks.com/en/sql/language-manual/sql-ref-identifiers.html
+    if (quoteString.equals("`")) {
+      return quoteString + str.replaceAll(quoteString, quoteString + quoteString) + quoteString;
+    }
+    return super.quoteIdentifierString(str, quoteString);
+  }
+
   // This is almost a copy of JdbcOutputConnection except for aggregating fromTables to first from
   // table,
   // because Databricks MERGE INTO source can only specify a single table.
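
The override above implements the Databricks identifier rule linked in its comment: a backtick-delimited identifier escapes an embedded backtick by doubling it (replaceAll is safe here because a backtick is not a regex metacharacter). A standalone sketch of the same rule, with illustrative inputs:

// Minimal sketch of the backtick-quoting rule used by quoteIdentifierString above.
public class BacktickQuotingSketch {
  static String quote(String identifier) {
    // Wrap in backticks; double any embedded backtick.
    return "`" + identifier.replace("`", "``") + "`";
  }

  public static void main(String[] args) {
    System.out.println(quote("col0")); // `col0`
    System.out.println(quote("あ"));    // `あ`
    System.out.println(quote("`"));    // ```` (the doubled backtick plus the wrapping pair)
  }
}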
src/test/java/org/embulk/output/databricks/TestDatabaseMetadata.java (new file)

Lines changed: 130 additions & 0 deletions

@@ -0,0 +1,130 @@
package org.embulk.output.databricks;

import static java.lang.String.format;
import static org.embulk.output.databricks.util.ConnectionUtil.*;
import static org.junit.Assert.assertEquals;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;
import org.embulk.output.databricks.util.ConfigUtil;
import org.embulk.output.jdbc.JdbcUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;

// The purpose of this class is to understand the behavior of DatabaseMetaData,
// so if this test fails due to a library update, please update the expected results.
public class TestDatabaseMetadata {
  private DatabaseMetaData dbm;
  private Connection conn;

  ConfigUtil.TestTask t = ConfigUtil.createTestTask();
  String catalog = t.getCatalogName();
  String schema = t.getSchemaName();
  String table = t.getTablePrefix() + "_test";
  String nonAsciiCatalog = t.getNonAsciiCatalogName();
  String nonAsciiSchema = t.getNonAsciiSchemaName();
  String nonAsciiTable = t.getTablePrefix() + "_テスト";

  @Before
  public void setup() throws SQLException, ClassNotFoundException {
    conn = connectByTestTask();
    dbm = conn.getMetaData();
    run(conn, "USE CATALOG " + catalog);
    run(conn, "USE SCHEMA " + schema);
    createTables();
  }

  @After
  public void cleanup() {
    try {
      conn.close();
    } catch (SQLException ignored) {
    }
    dropAllTemporaryTables();
  }

  @Test
  public void testGetPrimaryKeys() throws SQLException {
    assertEquals(1, countPrimaryKeys(catalog, schema, table, "a0"));
    assertEquals(1, countPrimaryKeys(null, schema, table, "a0"));
    assertEquals(1, countPrimaryKeys(nonAsciiCatalog, nonAsciiSchema, nonAsciiTable, "h0"));
    assertEquals(1, countPrimaryKeys(null, nonAsciiSchema, nonAsciiTable, "d0"));
  }

  @Test
  public void testGetTables() throws SQLException {
    assertEquals(1, countTablesResult(catalog, schema, table));
    assertEquals(2, countTablesResult(null, schema, table));
    assertEquals(1, countTablesResult(nonAsciiCatalog, nonAsciiSchema, nonAsciiTable));
    assertEquals(0, countTablesResult(null, nonAsciiSchema, nonAsciiTable)); // expected 2
  }

  @Test
  public void testGetColumns() throws SQLException {
    assertEquals(2, countColumnsResult(catalog, schema, table));
    assertEquals(4, countColumnsResult(null, schema, table));
    assertEquals(2, countColumnsResult(nonAsciiCatalog, nonAsciiSchema, nonAsciiTable));
    assertEquals(0, countColumnsResult(null, nonAsciiSchema, nonAsciiTable)); // expected 2
  }

  private void createTables() {
    String queryFormat =
        "CREATE TABLE IF NOT EXISTS `%s`.`%s`.`%s` (%s String PRIMARY KEY, %s INTEGER)";
    run(conn, format(queryFormat, catalog, schema, table, "a0", "a1"));
    run(conn, format(queryFormat, catalog, schema, nonAsciiTable, "b0", "b1"));
    run(conn, format(queryFormat, catalog, nonAsciiSchema, table, "c0", "c1"));
    run(conn, format(queryFormat, catalog, nonAsciiSchema, nonAsciiTable, "d0", "d1"));
    run(conn, format(queryFormat, nonAsciiCatalog, schema, table, "e0", "e1"));
    run(conn, format(queryFormat, nonAsciiCatalog, schema, nonAsciiTable, "f0", "f1"));
    run(conn, format(queryFormat, nonAsciiCatalog, nonAsciiSchema, table, "g0", "g1"));
    run(conn, format(queryFormat, nonAsciiCatalog, nonAsciiSchema, nonAsciiTable, "h0", "h1"));
  }

  private int countPrimaryKeys(
      String catalogName, String schemaName, String tableName, String primaryKey)
      throws SQLException {
    try (ResultSet rs = dbm.getPrimaryKeys(catalogName, schemaName, tableName)) {
      int count = 0;
      while (rs.next()) {
        String columnName = rs.getString("COLUMN_NAME");
        assertEquals(primaryKey, columnName);
        count += 1;
      }
      return count;
    }
  }

  private int countTablesResult(String catalogName, String schemaName, String tableName)
      throws SQLException {
    String e = dbm.getSearchStringEscape();
    String c = JdbcUtils.escapeSearchString(catalogName, e);
    String s = JdbcUtils.escapeSearchString(schemaName, e);
    String t = JdbcUtils.escapeSearchString(tableName, e);
    try (ResultSet rs = dbm.getTables(c, s, t, null)) {
      return countResultSet(rs);
    }
  }

  private int countColumnsResult(String catalogName, String schemaName, String tableName)
      throws SQLException {
    String e = dbm.getSearchStringEscape();
    String c = JdbcUtils.escapeSearchString(catalogName, e);
    String s = JdbcUtils.escapeSearchString(schemaName, e);
    String t = JdbcUtils.escapeSearchString(tableName, e);
    try (ResultSet rs = dbm.getColumns(c, s, t, null)) {
      return countResultSet(rs);
    }
  }

  private int countResultSet(ResultSet rs) throws SQLException {
    int count = 0;
    while (rs.next()) {
      count += 1;
    }
    return count;
  }
}
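
One thing this test exercises is that the JDBC metadata calls treat their schema and table arguments as LIKE patterns, where `_` and `%` are wildcards; the test's table names contain `_`, hence the JdbcUtils.escapeSearchString calls. A sketch of the idea against plain JDBC; the inline escapePattern helper is hypothetical and mirrors what escapeSearchString is assumed to do (the test also escapes the catalog argument, which the JDBC spec otherwise treats as an exact name):

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;

// Sketch: exact-name lookups via DatabaseMetaData need "_" and "%" escaped,
// because getTables()/getColumns() interpret schema and table as LIKE patterns.
public class MetadataPatternSketch {
  // Assumed behavior: escape the escape string itself first, then the two LIKE wildcards.
  static String escapePattern(String name, String escape) {
    if (name == null) {
      return null;
    }
    return name
        .replace(escape, escape + escape)
        .replace("_", escape + "_")
        .replace("%", escape + "%");
  }

  static int countTables(Connection conn, String catalog, String schema, String table)
      throws SQLException {
    DatabaseMetaData dbm = conn.getMetaData();
    String esc = dbm.getSearchStringEscape(); // driver-specific, often "\"
    int count = 0;
    try (ResultSet rs =
        dbm.getTables(catalog, escapePattern(schema, esc), escapePattern(table, esc), null)) {
      while (rs.next()) {
        count++;
      }
    }
    return count;
  }
}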

src/test/java/org/embulk/output/databricks/TestDatabricksOutputConnection.java

Lines changed: 49 additions & 13 deletions

@@ -1,8 +1,13 @@
 package org.embulk.output.databricks;
 
+import static org.embulk.output.databricks.util.ConnectionUtil.*;
+import static org.junit.Assert.assertTrue;
+
 import java.sql.*;
 import java.util.*;
 import java.util.concurrent.Executor;
+import org.embulk.output.databricks.util.ConfigUtil;
+import org.embulk.output.databricks.util.ConnectionUtil;
 import org.embulk.output.jdbc.JdbcColumn;
 import org.embulk.output.jdbc.JdbcSchema;
 import org.embulk.output.jdbc.MergeConfig;
@@ -11,21 +16,46 @@
 import org.junit.Test;
 
 public class TestDatabricksOutputConnection {
+  @Test
+  public void testTableExists() throws SQLException, ClassNotFoundException {
+    ConfigUtil.TestTask t = ConfigUtil.createTestTask();
+    String asciiTableName = t.getTablePrefix() + "_test";
+    String nonAsciiTableName = t.getTablePrefix() + "_テスト";
+    testTableExists(t.getCatalogName(), t.getSchemaName(), asciiTableName);
+    testTableExists(t.getNonAsciiCatalogName(), t.getSchemaName(), asciiTableName);
+    testTableExists(t.getCatalogName(), t.getNonAsciiSchemaName(), asciiTableName);
+    testTableExists(t.getCatalogName(), t.getSchemaName(), nonAsciiTableName);
+    testTableExists(t.getNonAsciiCatalogName(), t.getNonAsciiSchemaName(), nonAsciiTableName);
+  }
+
+  private void testTableExists(String catalogName, String schemaName, String tableName)
+      throws SQLException, ClassNotFoundException {
+    String fullTableName = String.format("`%s`.`%s`.`%s`", catalogName, schemaName, tableName);
+    try (Connection conn = ConnectionUtil.connectByTestTask()) {
+      run(conn, "CREATE TABLE IF NOT EXISTS " + fullTableName);
+      try (DatabricksOutputConnection outputConn =
+          buildOutputConnection(conn, catalogName, schemaName)) {
+        assertTrue(outputConn.tableExists(new TableIdentifier(null, null, tableName)));
+      }
+    } finally {
+      run("DROP TABLE IF EXISTS " + fullTableName);
+    }
+  }
 
   @Test
-  public void TestBuildCopySQL() throws SQLException {
-    try (DatabricksOutputConnection conn = buildOutputConnection()) {
+  public void testBuildCopySQL() throws SQLException {
+    try (DatabricksOutputConnection conn = buildDummyOutputConnection()) {
       TableIdentifier tableIdentifier = new TableIdentifier("database", "schemaName", "tableName");
       String actual = conn.buildCopySQL(tableIdentifier, "filePath", buildJdbcSchema());
       String expected =
-          "COPY INTO `database`.`schemaName`.`tableName` FROM ( SELECT _c0::string col0 , _c1::bigint col1 FROM \"filePath\" ) FILEFORMAT = CSV FORMAT_OPTIONS ( 'nullValue' = '\\\\N' , 'delimiter' = '\\t' )";
+          "COPY INTO `database`.`schemaName`.`tableName` FROM ( SELECT _c0::string `あ` , _c1::bigint ```` FROM \"filePath\" ) FILEFORMAT = CSV FORMAT_OPTIONS ( 'nullValue' = '\\\\N' , 'delimiter' = '\\t' )";
       Assert.assertEquals(expected, actual);
     }
   }
 
   @Test
-  public void TestBuildAggregateSQL() throws SQLException {
-    try (DatabricksOutputConnection conn = buildOutputConnection()) {
+  public void testBuildAggregateSQL() throws SQLException {
+    try (DatabricksOutputConnection conn = buildDummyOutputConnection()) {
       List<TableIdentifier> fromTableIdentifiers = new ArrayList<>();
       fromTableIdentifiers.add(new TableIdentifier("database", "schemaName", "tableName0"));
       fromTableIdentifiers.add(new TableIdentifier("database", "schemaName", "tableName1"));
@@ -39,28 +69,28 @@ public void TestBuildAggregateSQL() throws SQLException {
   }
 
   @Test
-  public void TestMergeConfigSQLWithMergeRules() throws SQLException {
+  public void testMergeConfigSQLWithMergeRules() throws SQLException {
     List<String> mergeKeys = buildMergeKeys("col0", "col1");
     Optional<List<String>> mergeRules =
         buildMergeRules("col0 = CONCAT(T.col0, 'test')", "col1 = T.col1 + S.col1");
     String actual = mergeConfigSQL(new MergeConfig(mergeKeys, mergeRules));
     String expected =
-        "MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET col0 = CONCAT(T.col0, 'test'), col1 = T.col1 + S.col1 WHEN NOT MATCHED THEN INSERT (`col0`, `col1`) VALUES (S.`col0`, S.`col1`);";
+        "MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET col0 = CONCAT(T.col0, 'test'), col1 = T.col1 + S.col1 WHEN NOT MATCHED THEN INSERT (`あ`, ````) VALUES (S.`あ`, S.````);";
     Assert.assertEquals(expected, actual);
   }
 
   @Test
-  public void TestMergeConfigSQLWithNoMergeRules() throws SQLException {
+  public void testMergeConfigSQLWithNoMergeRules() throws SQLException {
     List<String> mergeKeys = buildMergeKeys("col0", "col1");
     Optional<List<String>> mergeRules = Optional.empty();
     String actual = mergeConfigSQL(new MergeConfig(mergeKeys, mergeRules));
     String expected =
-        "MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET `col0` = S.`col0`, `col1` = S.`col1` WHEN NOT MATCHED THEN INSERT (`col0`, `col1`) VALUES (S.`col0`, S.`col1`);";
+        "MERGE INTO `database`.`schemaName`.`tableName100` T USING `database`.`schemaName`.`tableName9` S ON (T.`col0` = S.`col0` AND T.`col1` = S.`col1`) WHEN MATCHED THEN UPDATE SET `あ` = S.`あ`, ```` = S.```` WHEN NOT MATCHED THEN INSERT (`あ`, ````) VALUES (S.`あ`, S.````);";
     Assert.assertEquals(expected, actual);
   }
 
   private String mergeConfigSQL(MergeConfig mergeConfig) throws SQLException {
-    try (DatabricksOutputConnection conn = buildOutputConnection()) {
+    try (DatabricksOutputConnection conn = buildDummyOutputConnection()) {
       TableIdentifier aggregateToTable =
           new TableIdentifier("database", "schemaName", "tableName9");
       TableIdentifier toTable = new TableIdentifier("database", "schemaName", "tableName100");
@@ -76,15 +106,21 @@ private Optional<List<String>> buildMergeRules(String... keys) {
     return keys.length > 0 ? Optional.of(Arrays.asList(keys)) : Optional.empty();
   }
 
-  private DatabricksOutputConnection buildOutputConnection() throws SQLException {
+  private DatabricksOutputConnection buildOutputConnection(
+      Connection conn, String catalogName, String schemaName)
+      throws SQLException, ClassNotFoundException {
+    return new DatabricksOutputConnection(conn, catalogName, schemaName);
+  }
+
+  private DatabricksOutputConnection buildDummyOutputConnection() throws SQLException {
    return new DatabricksOutputConnection(
        buildDummyConnection(), "defaultCatalogName", "defaultSchemaName");
   }
 
   private JdbcSchema buildJdbcSchema() {
     List<JdbcColumn> jdbcColumns = new ArrayList<>();
-    jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("col0", Types.VARCHAR, "string", true, false));
-    jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("col1", Types.BIGINT, "bigint", true, false));
+    jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("あ", Types.VARCHAR, "string", true, false));
+    jdbcColumns.add(JdbcColumn.newTypeDeclaredColumn("`", Types.BIGINT, "bigint", true, false));
     return new JdbcSchema(jdbcColumns);
   }