diff --git a/ch/ddl.go b/ch/ddl.go index 245136e..67c6792 100644 --- a/ch/ddl.go +++ b/ch/ddl.go @@ -3,6 +3,7 @@ package ch import ( "context" "fmt" + "github.com/pingcap/go-tpc/pkg/util" ) var allTables []string @@ -15,6 +16,9 @@ func init() { func (w *Workloader) createTableDDL(ctx context.Context, query string, tableName string, action string) error { s := w.getState(ctx) fmt.Printf("%s %s\n", action, tableName) + if ctx.Value("risingwave") != nil && ctx.Value("risingwave").(bool) { + query = util.ConvertToRisingWaveDDL(query) + } if _, err := s.Conn.ExecContext(ctx, query); err != nil { return err } diff --git a/cmd/go-tpc/ch_benchmark.go b/cmd/go-tpc/ch_benchmark.go index be92693..37683d6 100644 --- a/cmd/go-tpc/ch_benchmark.go +++ b/cmd/go-tpc/ch_benchmark.go @@ -66,14 +66,6 @@ func registerCHBenchmark(root *cobra.Command) { "tidb_index_serial_scan_concurrency", 1, "tidb_index_serial_scan_concurrency param for analyze jobs") - cmdPrepare.PersistentFlags().BoolVar(&chConfig.OnlyDdl, - "only-ddl", - false, - "ch prepare only ddl (default false)") - cmdPrepare.PersistentFlags().BoolVar(&chConfig.SkipDdl, - "skip-ddl", - false, - "ch prepare skip ddl (default false)") var cmdRun = &cobra.Command{ Use: "run", @@ -134,9 +126,13 @@ func executeCH(action string, openAP func() (*sql.DB, error)) { tpccConfig.DBName = dbName tpccConfig.Threads = threads tpccConfig.Isolation = isolationLevel + tpccConfig.SkipDdl = skipDdl + tpccConfig.OnlyDdl = onlyDdl chConfig.OutputStyle = outputStyle chConfig.Driver = driver chConfig.DBName = dbName + chConfig.OnlyDdl = onlyDdl + chConfig.SkipDdl = skipDdl chConfig.QueryNames = strings.Split(chConfig.RawQueries, ",") if action == "run" { chConfig.PlanReplayerConfig.Host = apHosts[0] diff --git a/cmd/go-tpc/main.go b/cmd/go-tpc/main.go index 92763d5..7fab1bf 100644 --- a/cmd/go-tpc/main.go +++ b/cmd/go-tpc/main.go @@ -47,9 +47,11 @@ var ( connParams string outputStyle string targets []string - - globalDB *sql.DB - globalCtx context.Context + skipDdl bool + onlyDdl bool + risingwave bool + globalDB *sql.DB + globalCtx context.Context ) const ( @@ -209,6 +211,9 @@ func main() { rootCmd.PersistentFlags().StringVar(&outputStyle, "output", util.OutputStylePlain, "output style, valid values can be { plain | table | json }") rootCmd.PersistentFlags().StringSliceVar(&targets, "targets", nil, "Target database addresses") rootCmd.PersistentFlags().MarkHidden("targets") + rootCmd.PersistentFlags().BoolVar(&skipDdl, "skip-ddl", false, "Skip DDL operations") + rootCmd.PersistentFlags().BoolVar(&onlyDdl, "only-ddl", false, "Only DDL operations") + rootCmd.PersistentFlags().BoolVar(&risingwave, "risingwave", false, "Convert DDL to support RisingWave") cobra.EnablePrefixMatching = true diff --git a/cmd/go-tpc/misc.go b/cmd/go-tpc/misc.go index 3cf04a9..2dac723 100644 --- a/cmd/go-tpc/misc.go +++ b/cmd/go-tpc/misc.go @@ -41,6 +41,7 @@ func execute(timeoutCtx context.Context, w workload.Workloader, action string, t count := totalCount / threads ctx := w.InitThread(context.Background(), index) + ctx = context.WithValue(ctx, "risingwave", risingwave) defer w.CleanupThread(ctx, index) switch action { diff --git a/cmd/go-tpc/tpcc.go b/cmd/go-tpc/tpcc.go index 8892137..ffe02ae 100644 --- a/cmd/go-tpc/tpcc.go +++ b/cmd/go-tpc/tpcc.go @@ -49,6 +49,8 @@ func executeTpcc(action string) { tpccConfig.DBName = dbName tpccConfig.Threads = threads tpccConfig.Isolation = isolationLevel + tpccConfig.OnlyDdl = onlyDdl + tpccConfig.SkipDdl = skipDdl var ( w workload.Workloader err error @@ -102,8 +104,6 @@ func registerTpcc(root *cobra.Command) { "generating file, separated by ','. Valid only if output is set. If this flag is not set, generate all tables by default") cmdPrepare.PersistentFlags().IntVar(&tpccConfig.PrepareRetryCount, "retry-count", 50, "Retry count when errors occur") cmdPrepare.PersistentFlags().DurationVar(&tpccConfig.PrepareRetryInterval, "retry-interval", 10*time.Second, "The interval for each retry") - cmdPrepare.PersistentFlags().BoolVar(&tpccConfig.OnlyDdl, "only-ddl", false, "TPCC prepare ddl only (default false)") - cmdPrepare.PersistentFlags().BoolVar(&tpccConfig.SkipDdl, "skip-ddl", false, "TPCC prepare skip ddl (default false)") var cmdRun = &cobra.Command{ Use: "run", diff --git a/cmd/go-tpc/tpch.go b/cmd/go-tpc/tpch.go index 15549ed..b93ca11 100644 --- a/cmd/go-tpc/tpch.go +++ b/cmd/go-tpc/tpch.go @@ -30,6 +30,8 @@ func executeTpch(action string) { tpchConfig.Driver = driver tpchConfig.DBName = dbName tpchConfig.PrepareThreads = threads + tpchConfig.OnlyDdl = onlyDdl + tpchConfig.SkipDdl = skipDdl tpchConfig.QueryNames = strings.Split(tpchConfig.RawQueries, ",") if action == "prepare" && tpchConfig.OutputType == "kafka" { if dropData { @@ -132,16 +134,6 @@ func registerTpch(root *cobra.Command) { 20, "kafka flush timeout seconds", ) - cmdPrepare.PersistentFlags().BoolVar(&tpchConfig.SkipDdl, - "skip-ddl", - false, - "tpch prepare skip ddl (default false)", - ) - cmdPrepare.PersistentFlags().BoolVar(&tpchConfig.OnlyDdl, - "only-ddl", - false, - "tpch prepare only ddl (default false)", - ) var cmdRun = &cobra.Command{ Use: "run", diff --git a/pkg/util/risingwave.go b/pkg/util/risingwave.go new file mode 100644 index 0000000..0d949d5 --- /dev/null +++ b/pkg/util/risingwave.go @@ -0,0 +1,14 @@ +package util + +import ( + "regexp" +) + +func ConvertToRisingWaveDDL(query string) string { + query = regexp.MustCompile("(?i)varchar\\(\\d+\\)").ReplaceAllString(query, "VARCHAR") + query = regexp.MustCompile("(?i)numeric\\(.*?\\)").ReplaceAllString(query, "NUMERIC") + query = regexp.MustCompile("(?i)decimal\\(.*?\\)").ReplaceAllString(query, "DECIMAL") + query = regexp.MustCompile("(?i)char\\(\\d+\\)").ReplaceAllString(query, "VARCHAR") + query = regexp.MustCompile("(?i) not null").ReplaceAllString(query, "") + return query +} diff --git a/tpcc/ddl.go b/tpcc/ddl.go index ab50986..6ccd431 100644 --- a/tpcc/ddl.go +++ b/tpcc/ddl.go @@ -3,6 +3,7 @@ package tpcc import ( "context" "fmt" + "github.com/pingcap/go-tpc/pkg/util" ) const ( @@ -31,6 +32,9 @@ func newDDLManager(parts int, useFK bool, warehouses, partitionType int) *ddlMan func (w *ddlManager) createTableDDL(ctx context.Context, query string, tableName string) error { s := getTPCCState(ctx) fmt.Printf("creating table %s\n", tableName) + if ctx.Value("risingwave") != nil && ctx.Value("risingwave").(bool) { + query = util.ConvertToRisingWaveDDL(query) + } if _, err := s.Conn.ExecContext(ctx, query); err != nil { return err } diff --git a/tpch/ddl.go b/tpch/ddl.go index e14bfcb..0e222b6 100644 --- a/tpch/ddl.go +++ b/tpch/ddl.go @@ -3,6 +3,7 @@ package tpch import ( "context" "fmt" + "github.com/pingcap/go-tpc/pkg/util" ) var AllTables []string @@ -14,6 +15,9 @@ func init() { func (w *Workloader) createTableDDL(ctx context.Context, query string, tableName string, action string) error { s := w.getState(ctx) fmt.Printf("%s %s\n", action, tableName) + if ctx.Value("risingwave") != nil && ctx.Value("risingwave").(bool) { + query = util.ConvertToRisingWaveDDL(query) + } if _, err := s.Conn.ExecContext(ctx, query); err != nil { return err }