Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ private CreateTableStatement copy(final CreateTableStatement sqlStatement,
.table(boundTable)
.selectStatement(boundSelectStatement)
.ifNotExists(sqlStatement.isIfNotExists())
.temporary(sqlStatement.isTemporary())
.likeTable(sqlStatement.getLikeTable().orElse(null))
.createTableOption(sqlStatement.getCreateTableOption().orElse(null))
.columnDefinitions(boundColumnDefinitions)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ public DropTableStatement bind(final DropTableStatement sqlStatement, final SQLS
}

private DropTableStatement copy(final DropTableStatement sqlStatement, final Collection<SimpleTableSegment> boundTables) {
DropTableStatement result = new DropTableStatement(sqlStatement.getDatabaseType(), boundTables, sqlStatement.isIfExists(), sqlStatement.isContainsCascade());
DropTableStatement result = new DropTableStatement(sqlStatement.getDatabaseType(), boundTables, sqlStatement.isIfExists(), sqlStatement.isTemporary(), sqlStatement.isContainsCascade());
SQLStatementCopyUtils.copyAttributes(sqlStatement, result);
return result;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

package org.apache.shardingsphere.infra.connection.kernel;

import org.apache.shardingsphere.database.connector.core.type.DatabaseType;
import org.apache.shardingsphere.infra.annotation.HighFrequencyInvocation;
import org.apache.shardingsphere.infra.checker.SupportedSQLCheckEngine;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
import org.apache.shardingsphere.infra.exception.generic.UnsupportedSQLOperationException;
import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContextBuilder;
import org.apache.shardingsphere.infra.executor.sql.log.SQLLogger;
Expand All @@ -31,6 +33,9 @@
import org.apache.shardingsphere.infra.route.context.RouteContext;
import org.apache.shardingsphere.infra.route.engine.SQLRouteEngine;
import org.apache.shardingsphere.infra.session.query.QueryContext;
import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.statement.core.statement.type.ddl.table.CreateTableStatement;
import org.apache.shardingsphere.sql.parser.statement.core.statement.type.ddl.table.DropTableStatement;

/**
* Kernel processor.
Expand All @@ -56,13 +61,32 @@ public ExecutionContext generateExecutionContext(final QueryContext queryContext
}

private void check(final QueryContext queryContext) {
checkUnsupportedTemporaryTableDDL(queryContext);
if (queryContext.getHintValueContext().isSkipMetadataValidate()) {
return;
}
ShardingSphereDatabase database = queryContext.getUsedDatabase();
new SupportedSQLCheckEngine().checkSQL(database.getRuleMetaData().getRules(), queryContext.getSqlStatementContext(), database);
}

private void checkUnsupportedTemporaryTableDDL(final QueryContext queryContext) {
ShardingSphereDatabase database = queryContext.getUsedDatabase();
if (!isMySQL(database.getProtocolType())) {
return;
}
SQLStatement sqlStatement = queryContext.getSqlStatementContext().getSqlStatement();
if (sqlStatement instanceof CreateTableStatement && ((CreateTableStatement) sqlStatement).isTemporary()) {
throw new UnsupportedSQLOperationException("CREATE TEMPORARY TABLE");
}
if (sqlStatement instanceof DropTableStatement && ((DropTableStatement) sqlStatement).isTemporary()) {
throw new UnsupportedSQLOperationException("DROP TEMPORARY TABLE");
}
}

private boolean isMySQL(final DatabaseType databaseType) {
return "MySQL".equalsIgnoreCase(databaseType.getType());
}

private RouteContext route(final QueryContext queryContext, final RuleMetaData globalRuleMetaData, final ConfigurationProperties props) {
ShardingSphereDatabase database = queryContext.getUsedDatabase();
return new SQLRouteEngine(database.getRuleMetaData().getRules(), props).route(queryContext, globalRuleMetaData, database);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,12 @@
package org.apache.shardingsphere.infra.connection.kernel;

import org.apache.shardingsphere.database.connector.core.type.DatabaseType;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.type.CommonSQLStatementContext;
import org.apache.shardingsphere.infra.binder.context.statement.SQLStatementContext;
import org.apache.shardingsphere.infra.checker.SupportedSQLCheckEngine;
import org.apache.shardingsphere.infra.config.props.ConfigurationProperties;
import org.apache.shardingsphere.infra.config.props.ConfigurationPropertyKey;
import org.apache.shardingsphere.infra.exception.generic.UnsupportedSQLOperationException;
import org.apache.shardingsphere.infra.executor.sql.context.ExecutionContext;
import org.apache.shardingsphere.infra.executor.sql.log.SQLLogger;
import org.apache.shardingsphere.infra.hint.HintValueContext;
Expand All @@ -39,7 +40,13 @@
import org.apache.shardingsphere.infra.spi.type.typed.TypedSPILoader;
import org.apache.shardingsphere.infra.util.props.PropertiesBuilder;
import org.apache.shardingsphere.infra.util.props.PropertiesBuilder.Property;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.statement.core.segment.generic.table.TableNameSegment;
import org.apache.shardingsphere.sql.parser.statement.core.statement.SQLStatement;
import org.apache.shardingsphere.sql.parser.statement.core.statement.type.ddl.table.CreateTableStatement;
import org.apache.shardingsphere.sql.parser.statement.core.statement.type.ddl.table.DropTableStatement;
import org.apache.shardingsphere.sql.parser.statement.core.statement.type.dml.SelectStatement;
import org.apache.shardingsphere.sql.parser.statement.core.value.identifier.IdentifierValue;
import org.apache.shardingsphere.sqltranslator.context.SQLTranslatorContext;
import org.apache.shardingsphere.sqltranslator.rule.SQLTranslatorRule;
import org.junit.jupiter.api.Test;
Expand All @@ -55,6 +62,7 @@
import static org.hamcrest.Matchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
Expand Down Expand Up @@ -96,6 +104,33 @@ void assertGenerateExecutionContextWithoutMetadataValidationAndSQLLogging() {
}
}

@Test
void assertGenerateExecutionContextWithTemporaryCreateTable() {
QueryContext queryContext = createQueryContext(createTemporaryCreateTableStatement(), false, createMySQLDatabaseType());
assertThrows(UnsupportedSQLOperationException.class,
() -> new KernelProcessor().generateExecutionContext(queryContext, new RuleMetaData(Arrays.asList(mockSQLTranslatorRule(), mockAggregatedDataSourceRule())), createProps(false)));
}

@Test
void assertGenerateExecutionContextWithTemporaryDropTable() {
QueryContext queryContext = createQueryContext(createTemporaryDropTableStatement(), false, createMySQLDatabaseType());
assertThrows(UnsupportedSQLOperationException.class,
() -> new KernelProcessor().generateExecutionContext(queryContext, new RuleMetaData(Arrays.asList(mockSQLTranslatorRule(), mockAggregatedDataSourceRule())), createProps(false)));
}

@Test
void assertGenerateExecutionContextWithTemporaryCreateTableWhenSkipMetadataValidate() {
QueryContext queryContext = createQueryContext(createTemporaryCreateTableStatement(), true, createMySQLDatabaseType());
try (
MockedConstruction<SupportedSQLCheckEngine> mockedCheckEngines = mockConstruction(SupportedSQLCheckEngine.class);
MockedStatic<SQLLogger> mockedSQLLogger = mockStatic(SQLLogger.class)) {
assertThrows(UnsupportedSQLOperationException.class,
() -> new KernelProcessor().generateExecutionContext(queryContext, new RuleMetaData(Arrays.asList(mockSQLTranslatorRule(), mockAggregatedDataSourceRule())), createProps(false)));
assertTrue(mockedCheckEngines.constructed().isEmpty());
mockedSQLLogger.verifyNoInteractions();
}
}

private SQLTranslatorRule mockSQLTranslatorRule() {
SQLTranslatorRule result = mock(SQLTranslatorRule.class);
when(result.getAttributes()).thenReturn(new RuleAttributes());
Expand All @@ -112,6 +147,10 @@ private ShardingSphereRule mockAggregatedDataSourceRule() {
}

private QueryContext createQueryContext(final boolean skipMetadataValidate) {
return createQueryContext(SelectStatement.builder().databaseType(databaseType).build(), skipMetadataValidate, databaseType);
}

private QueryContext createQueryContext(final SQLStatement sqlStatement, final boolean skipMetadataValidate, final DatabaseType databaseType) {
HintValueContext hintValueContext = new HintValueContext();
hintValueContext.setSkipMetadataValidate(skipMetadataValidate);
ConnectionContext connectionContext = mock(ConnectionContext.class);
Expand All @@ -125,11 +164,30 @@ private QueryContext createQueryContext(final boolean skipMetadataValidate) {
new ConfigurationProperties(new Properties()));
when(metaData.getDatabase("foo_db")).thenReturn(database);
when(metaData.getProps()).thenReturn(new ConfigurationProperties(new Properties()));
SQLStatementContext sqlStatementContext = new CommonSQLStatementContext(SelectStatement.builder().databaseType(databaseType).build());
SQLStatementContext sqlStatementContext = new CommonSQLStatementContext(sqlStatement);
return new QueryContext(sqlStatementContext, "SELECT * FROM tbl", Collections.emptyList(), hintValueContext, connectionContext, metaData);
}

private ConfigurationProperties createProps(final boolean sqlShow) {
return new ConfigurationProperties(PropertiesBuilder.build(new Property(ConfigurationPropertyKey.SQL_SHOW.getKey(), Boolean.toString(sqlShow))));
}

private CreateTableStatement createTemporaryCreateTableStatement() {
return CreateTableStatement.builder()
.databaseType(createMySQLDatabaseType())
.table(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order"))))
.temporary(true)
.build();
}

private DropTableStatement createTemporaryDropTableStatement() {
return new DropTableStatement(createMySQLDatabaseType(),
Collections.singletonList(new SimpleTableSegment(new TableNameSegment(0, 0, new IdentifierValue("t_order")))), false, true, false);
}

private DatabaseType createMySQLDatabaseType() {
DatabaseType result = mock(DatabaseType.class);
when(result.getType()).thenReturn("MySQL");
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ public ASTNode visitCreateTable(final CreateTableContext ctx) {
.databaseType(getDatabaseType())
.table((SimpleTableSegment) visit(ctx.tableName()))
.ifNotExists(null != ctx.ifNotExists())
.temporary(null != ctx.TEMPORARY())
.likeTable(null == ctx.createLikeClause() ? null : (SimpleTableSegment) visit(ctx.createLikeClause()))
.createTableOption(null == ctx.createTableOptions() ? null : (CreateTableOptionSegment) visit(ctx.createTableOptions()))
.columnDefinitions(columnDefinitions)
Expand Down Expand Up @@ -690,7 +691,7 @@ public ASTNode visitPlace(final PlaceContext ctx) {
@SuppressWarnings("unchecked")
@Override
public ASTNode visitDropTable(final DropTableContext ctx) {
return new DropTableStatement(getDatabaseType(), ((CollectionValue<SimpleTableSegment>) visit(ctx.tableList())).getValue(), null != ctx.ifExists(), false);
return new DropTableStatement(getDatabaseType(), ((CollectionValue<SimpleTableSegment>) visit(ctx.tableList())).getValue(), null != ctx.ifExists(), null != ctx.TEMPORARY(), false);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.shardingsphere.sql.parser.engine.mysql.visitor.statement;

import org.apache.shardingsphere.sql.parser.engine.api.CacheOption;
import org.apache.shardingsphere.sql.parser.engine.api.SQLParserEngine;
import org.apache.shardingsphere.sql.parser.engine.api.SQLStatementVisitorEngine;
import org.apache.shardingsphere.sql.parser.engine.core.ParseASTNode;
import org.apache.shardingsphere.sql.parser.statement.core.statement.type.ddl.table.CreateTableStatement;
import org.apache.shardingsphere.sql.parser.statement.core.statement.type.ddl.table.DropTableStatement;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

class MySQLStatementVisitorTest {

private static final CacheOption CACHE_OPTION = new CacheOption(128, 1024L);

@Test
void assertVisitCreateTemporaryTable() {
CreateTableStatement statement = (CreateTableStatement) parse("CREATE TEMPORARY TABLE t_order (order_id INT)");
assertTrue(statement.isTemporary());
}

@Test
void assertVisitCreateTableWithoutTemporary() {
CreateTableStatement statement = (CreateTableStatement) parse("CREATE TABLE t_order (order_id INT)");
assertFalse(statement.isTemporary());
}

@Test
void assertVisitDropTemporaryTable() {
DropTableStatement statement = (DropTableStatement) parse("DROP TEMPORARY TABLE t_order");
assertTrue(statement.isTemporary());
}

@Test
void assertVisitDropTableWithoutTemporary() {
DropTableStatement statement = (DropTableStatement) parse("DROP TABLE t_order");
assertFalse(statement.isTemporary());
}

private Object parse(final String sql) {
ParseASTNode parseASTNode = new SQLParserEngine("MySQL", CACHE_OPTION).parse(sql, false);
return new SQLStatementVisitorEngine("MySQL").visit(parseASTNode);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ public final class CreateTableStatement extends DDLStatement {

private final boolean ifNotExists;

private final boolean temporary;

private final SimpleTableSegment likeTable;

private final CreateTableOptionSegment createTableOption;
Expand All @@ -70,13 +72,14 @@ public final class CreateTableStatement extends DDLStatement {

@Builder
private CreateTableStatement(final DatabaseType databaseType, final SimpleTableSegment table, final SelectStatement selectStatement,
final boolean ifNotExists, final SimpleTableSegment likeTable, final CreateTableOptionSegment createTableOption,
final boolean ifNotExists, final boolean temporary, final SimpleTableSegment likeTable, final CreateTableOptionSegment createTableOption,
final Collection<ColumnDefinitionSegment> columnDefinitions, final Collection<ConstraintDefinitionSegment> constraintDefinitions,
final List<ColumnSegment> columns, final Collection<RollupSegment> rollups) {
super(databaseType);
this.table = table;
this.selectStatement = selectStatement;
this.ifNotExists = ifNotExists;
this.temporary = temporary;
this.likeTable = likeTable;
this.createTableOption = createTableOption;
this.columnDefinitions = null == columnDefinitions ? Collections.emptyList() : columnDefinitions;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,21 @@ public final class DropTableStatement extends DDLStatement {

private final boolean ifExists;

private final boolean temporary;

private final boolean containsCascade;

private final SQLStatementAttributes attributes;

public DropTableStatement(final DatabaseType databaseType, final Collection<SimpleTableSegment> tables, final boolean ifExists, final boolean containsCascade) {
this(databaseType, tables, ifExists, false, containsCascade);
}

public DropTableStatement(final DatabaseType databaseType, final Collection<SimpleTableSegment> tables, final boolean ifExists, final boolean temporary, final boolean containsCascade) {
super(databaseType);
this.tables = tables;
this.ifExists = ifExists;
this.temporary = temporary;
this.containsCascade = containsCascade;
attributes = new SQLStatementAttributes(new TableSQLStatementAttribute(tables));
}
Expand Down
Loading