Skip to content

Sync v2 #2180

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open

Sync v2 #2180

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion spelling_dict.txt
Original file line number Diff line number Diff line change
Expand Up @@ -456,4 +456,5 @@ multipoint
multilinestring
multipolygon
geometrycollection
charlength
charlength
xmls
4 changes: 2 additions & 2 deletions sqle/api/controller/v1/sql_audit_record.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,8 +196,8 @@ func addSQLsFromFileToTasks(sqls getSQLFromFileResp, task *model.Task, plugin dr
return nil
}

func buildOnlineTaskForAudit(c echo.Context, s *model.Storage, userId uint64, instanceName, instanceSchema, projectUid string, sqls getSQLFromFileResp) (*model.Task, error) {
instance, exist, err := dms.GetInstanceInProjectByName(c.Request().Context(), projectUid, instanceName)
func buildOnlineTaskForAudit(c echo.Context, s *model.Storage, userId uint, instanceName, instanceSchema, projectName string, sqls getSQLFromFileResp) (*model.Task, error) {
instance, exist, err := s.GetInstanceByNameAndProjectName(instanceName, projectName)
if err != nil {
return nil, err
}
Expand Down
32 changes: 5 additions & 27 deletions sqle/driver/mysql/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,19 +493,19 @@ func getCreateTableAndOnCondition(input *RuleHandlerInput) (map[string]*ast.Crea
if stmt.From == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.From.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.From.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.From.TableRefs)
case *ast.UpdateStmt:
if stmt.TableRefs == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.TableRefs.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.TableRefs.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.TableRefs.TableRefs)
case *ast.DeleteStmt:
if stmt.TableRefs == nil {
return nil, nil
}
tableNameCreateTableStmtMap = getTableNameCreateTableStmtMap(input.Ctx, stmt.TableRefs.TableRefs)
tableNameCreateTableStmtMap = input.Ctx.GetTableNameCreateTableStmtMap(stmt.TableRefs.TableRefs)
onConditions = util.GetTableFromOnCondition(stmt.TableRefs.TableRefs)
default:
return nil, nil
Expand Down Expand Up @@ -696,28 +696,6 @@ func getTableNameCreateTableStmtMapForJoinType(sessionContext *session.Context,
return tableNameCreateTableStmtMap
}

func getTableNameCreateTableStmtMap(sessionContext *session.Context, joinStmt *ast.Join) map[string] /*table name or alias table name*/ *ast.CreateTableStmt {
tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt)
tableSources := util.GetTableSources(joinStmt)
for _, tableSource := range tableSources {
if tableNameStmt, ok := tableSource.Source.(*ast.TableName); ok {
tableName := tableNameStmt.Name.L
if tableSource.AsName.L != "" {
// 如果使用别名,则需要用别名引用
tableName = tableSource.AsName.L
}

createTableStmt, exist, err := sessionContext.GetCreateTableStmt(tableNameStmt)
if err != nil || !exist {
continue
}
// TODO: 跨库的 JOIN 无法区分
tableNameCreateTableStmtMap[tableName] = createTableStmt
}
}
return tableNameCreateTableStmtMap
}

func getOnConditionLeftAndRightType(onCondition *ast.OnCondition, createTableStmtMap map[string]*ast.CreateTableStmt) (byte, byte) {
var leftType, rightType byte
// onCondition在中的ColumnNameExpr.Refer为nil无法索引到原表名和表别名
Expand Down Expand Up @@ -3259,7 +3237,7 @@ func checkWhereConditionUseIndex(ctx *session.Context, whereVisitor *util.WhereW
continue
}

tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(ctx, whereExpr.TableRef)
tableNameCreateTableStmtMap := ctx.GetTableNameCreateTableStmtMap(whereExpr.TableRef)
util.ScanWhereStmt(func(expr ast.ExprNode) (skip bool) {
switch x := expr.(type) {
case *ast.ColumnNameExpr:
Expand Down Expand Up @@ -5465,7 +5443,7 @@ func judgeJoinFieldUseIndex(input *RuleHandlerInput) (bool, error) {
// 如果SQL没有JOIN多表,则不需要审核
return true, fmt.Errorf("sql have not join node")
}
tableNameCreateTableStmtMap := getTableNameCreateTableStmtMap(input.Ctx, joinNode)
tableNameCreateTableStmtMap := input.Ctx.GetTableNameCreateTableStmtMap(joinNode)
tableIndexes := make(map[string][]*ast.Constraint, len(tableNameCreateTableStmtMap))
for tableName, createTableStmt := range tableNameCreateTableStmtMap {
tableIndexes[tableName] = createTableStmt.Constraints
Expand Down
22 changes: 22 additions & 0 deletions sqle/driver/mysql/session/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -1121,3 +1121,25 @@ func (c *Context) GetExecutor() *executor.Executor {
func (c *Context) GetTableIndexesInfo(schema, tableName string) ([]*executor.TableIndexesInfo, error) {
return c.e.GetTableIndexesInfo(utils.SupplementalQuotationMarks(schema), utils.SupplementalQuotationMarks(tableName))
}

func (c *Context) GetTableNameCreateTableStmtMap(joinStmt *ast.Join) map[string] /*table name or alias table name*/ *ast.CreateTableStmt {
tableNameCreateTableStmtMap := make(map[string]*ast.CreateTableStmt)
tableSources := util.GetTableSources(joinStmt)
for _, tableSource := range tableSources {
if tableNameStmt, ok := tableSource.Source.(*ast.TableName); ok {
tableName := tableNameStmt.Name.L
if tableSource.AsName.L != "" {
// 如果使用别名,则需要用别名引用
tableName = tableSource.AsName.L
}

createTableStmt, exist, err := c.GetCreateTableStmt(tableNameStmt)
if err != nil || !exist {
continue
}
// TODO: 跨库的 JOIN 无法区分
tableNameCreateTableStmtMap[tableName] = createTableStmt
}
}
return tableNameCreateTableStmtMap
}