diff --git a/cmd/go-tpc/ch_benchmark.go b/cmd/go-tpc/ch_benchmark.go index 6ed5661..21f221a 100644 --- a/cmd/go-tpc/ch_benchmark.go +++ b/cmd/go-tpc/ch_benchmark.go @@ -20,8 +20,8 @@ import ( var chConfig ch.Config var ( apConnParams string - apHost string - apPort int + apHosts []string + apPorts []int ) func registerCHBenchmark(root *cobra.Command) { @@ -70,22 +70,20 @@ func registerCHBenchmark(root *cobra.Command) { var cmdRun = &cobra.Command{ Use: "run", Short: "Run workload", + PreRun: func(cmd *cobra.Command, args []string) { + if len(apConnParams) == 0 { + apConnParams = connParams + } + if len(apHosts) == 0 { + apHosts = hosts + } + if len(apPorts) == 0 { + apPorts = ports + } + }, Run: func(cmd *cobra.Command, _ []string) { - executeCH("run", func() string { - origConnParams, origHost, origPort := connParams, host, port - defer func() { - connParams, host, port = origConnParams, origHost, origPort - }() - if len(apConnParams) > 0 { - connParams = apConnParams - } - if len(apHost) > 0 { - host = apHost - } - if apPort > 0 { - port = apPort - } - return buildDSN(false) + executeCH("run", func() (*sql.DB, error) { + return newDB(makeTargets(apHosts, apPorts), driver, user, password, dbName, apConnParams) }) }, } @@ -106,13 +104,13 @@ func registerCHBenchmark(root *cobra.Command) { cmdRun.PersistentFlags().IntSliceVar(&tpccConfig.Weight, "weight", []int{45, 43, 4, 4, 4}, "Weight for NewOrder, Payment, OrderStatus, Delivery, StockLevel") cmdRun.Flags().StringVar(&apConnParams, "ap-conn-params", "", "Connection parameters for analytical processing") - cmdRun.Flags().StringVar(&apHost, "ap-host", "", "Database host for analytical processing") - cmdRun.Flags().IntVar(&apPort, "ap-port", 0, "Database port for analytical processing") + cmdRun.Flags().StringSliceVar(&apHosts, "ap-host", nil, "Database host for analytical processing") + cmdRun.Flags().IntSliceVar(&apPorts, "ap-port", nil, "Database port for analytical processing") cmd.AddCommand(cmdRun, cmdPrepare) root.AddCommand(cmd) } -func executeCH(action string, buildDSNForAP func() string) { +func executeCH(action string, openAP func() (*sql.DB, error)) { runtime.GOMAXPROCS(maxProcs) openDB() @@ -127,11 +125,7 @@ func executeCH(action string, buildDSNForAP func() string) { chConfig.Driver = driver chConfig.DBName = dbName chConfig.QueryNames = strings.Split(chConfig.RawQueries, ",") - if len(apHost) > 0 { - chConfig.PlanReplayerConfig.Host = apHost - } else { - chConfig.PlanReplayerConfig.Host = host - } + chConfig.PlanReplayerConfig.Host = apHosts[0] chConfig.PlanReplayerConfig.StatusPort = statusPort var ( @@ -143,10 +137,10 @@ func executeCH(action string, buildDSNForAP func() string) { fmt.Printf("Failed to init tp work loader: %v\n", err) os.Exit(1) } - if buildDSNForAP == nil { + if openAP == nil { ap = ch.NewWorkloader(globalDB, &chConfig) } else { - db, err := sql.Open(driver, buildDSNForAP()) + db, err := openAP() if err != nil { fmt.Printf("Failed to open db for analytical processing: %v\n", err) os.Exit(1) diff --git a/cmd/go-tpc/main.go b/cmd/go-tpc/main.go index 57b359a..92763d5 100644 --- a/cmd/go-tpc/main.go +++ b/cmd/go-tpc/main.go @@ -2,11 +2,16 @@ package main import ( "context" + "crypto/sha1" "database/sql" + sqldrv "database/sql/driver" + "encoding/hex" "fmt" "os" "os/signal" + "strconv" "strings" + "sync/atomic" "syscall" "time" @@ -14,15 +19,15 @@ import ( "github.com/spf13/cobra" // mysql package - _ "github.com/go-sql-driver/mysql" + "github.com/go-sql-driver/mysql" // pg - _ "github.com/lib/pq" + "github.com/lib/pq" ) var ( dbName string - host string - port int + hosts []string + ports []int statusPort int user string password string @@ -41,6 +46,7 @@ var ( maxProcs int connParams string outputStyle string + targets []string globalDB *sql.DB globalCtx context.Context @@ -52,51 +58,82 @@ const ( pgDriver = "postgres" ) -func closeDB() { - if globalDB != nil { - globalDB.Close() - } - globalDB = nil +type MuxDriver struct { + cursor uint64 + instances []string + internal sqldrv.Driver } -func buildDSN(tmp bool) string { - switch driver { - case mysqlDriver: - if tmp { - return fmt.Sprintf("%s:%s@tcp(%s:%d)/", user, password, host, port) - } - // allow multiple statements in one query to allow q15 on the TPC-H - dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?multiStatements=true", user, password, host, port, dbName) - if len(connParams) > 0 { - dsn = dsn + "&" + connParams +func (drv *MuxDriver) Open(name string) (sqldrv.Conn, error) { + k := atomic.AddUint64(&drv.cursor, 1) + return drv.internal.Open(drv.instances[int(k)%len(drv.instances)]) +} + +func makeTargets(hosts []string, ports []int) []string { + targets := make([]string, 0, len(hosts)*len(ports)) + for _, host := range hosts { + for _, port := range ports { + targets = append(targets, host+":"+strconv.Itoa(port)) } - return dsn - case pgDriver: - if tmp { - return fmt.Sprintf("postgres://%s:%s@%s:%d/?%s", user, password, host, port, connParams) + } + return targets +} + +func newDB(targets []string, driver string, user string, password string, dbName string, connParams string) (*sql.DB, error) { + if len(targets) == 0 { + panic(fmt.Errorf("empty targets")) + } + var ( + drv sqldrv.Driver + hash = sha1.New() + names = make([]string, len(targets)) + ) + hash.Write([]byte(driver)) + hash.Write([]byte(user)) + hash.Write([]byte(password)) + hash.Write([]byte(dbName)) + hash.Write([]byte(connParams)) + for i, addr := range targets { + hash.Write([]byte(addr)) + switch driver { + case mysqlDriver: + // allow multiple statements in one query to allow q15 on the TPC-H + dsn := fmt.Sprintf("%s:%s@tcp(%s)/%s?multiStatements=true", user, password, addr, dbName) + if len(connParams) > 0 { + dsn = dsn + "&" + connParams + } + names[i] = dsn + drv = &mysql.MySQLDriver{} + case pgDriver: + dsn := fmt.Sprintf("postgres://%s:%s@%s/%s", user, password, addr, dbName) + if len(connParams) > 0 { + dsn = dsn + "?" + connParams + } + names[i] = dsn + drv = &pq.Driver{} + default: + panic(fmt.Errorf("unknown driver: %q", driver)) } - dsn := fmt.Sprintf("postgres://%s:%s@%s:%d/%s", user, password, host, port, dbName) - if len(connParams) > 0 { - dsn = dsn + "?" + connParams + } + + if len(names) == 1 { + return sql.Open(driver, names[0]) + } + drvName := driver + "+" + hex.EncodeToString(hash.Sum(nil)) + for _, n := range sql.Drivers() { + if n == drvName { + return sql.Open(drvName, "") } - return dsn - default: - panic(fmt.Errorf("unknown driver: %q", driver)) } + sql.Register(drvName, &MuxDriver{instances: names, internal: drv}) + return sql.Open(drvName, "") } -func isDBNotExist(err error) bool { - if err == nil { - return false - } - switch driver { - case mysqlDriver: - return strings.Contains(err.Error(), "Unknown database") - case pgDriver: - msg := err.Error() - return strings.HasPrefix(msg, "pq: database") && strings.HasSuffix(msg, "does not exist") +func closeDB() { + if globalDB != nil { + globalDB.Close() } - return false + globalDB = nil } func openDB() { @@ -104,13 +141,13 @@ func openDB() { tmpDB *sql.DB err error ) - globalDB, err = sql.Open(driver, buildDSN(false)) + globalDB, err = newDB(targets, driver, user, password, dbName, connParams) if err != nil { panic(err) } if err := globalDB.Ping(); err != nil { if isDBNotExist(err) { - tmpDB, _ = sql.Open(driver, buildDSN(true)) + tmpDB, _ = newDB(targets, driver, user, password, "", connParams) defer tmpDB.Close() if _, err := tmpDB.Exec(createDBDDL + dbName); err != nil { panic(fmt.Errorf("failed to create database, err %v\n", err)) @@ -123,19 +160,38 @@ func openDB() { } } +func isDBNotExist(err error) bool { + if err == nil { + return false + } + switch driver { + case mysqlDriver: + return strings.Contains(err.Error(), "Unknown database") + case pgDriver: + msg := err.Error() + return strings.HasPrefix(msg, "pq: database") && strings.HasSuffix(msg, "does not exist") + } + return false +} + func main() { var rootCmd = &cobra.Command{ Use: "go-tpc", Short: "Benchmark database with different workloads", + PersistentPreRun: func(cmd *cobra.Command, args []string) { + if len(targets) == 0 { + targets = makeTargets(hosts, ports) + } + }, } rootCmd.PersistentFlags().IntVar(&maxProcs, "max-procs", 0, "runtime.GOMAXPROCS") rootCmd.PersistentFlags().StringVar(&pprofAddr, "pprof", "", "Address of pprof endpoint") rootCmd.PersistentFlags().StringVar(&metricsAddr, "metrics-addr", "", "Address of metrics endpoint") rootCmd.PersistentFlags().StringVarP(&dbName, "db", "D", "test", "Database name") - rootCmd.PersistentFlags().StringVarP(&host, "host", "H", "127.0.0.1", "Database host") + rootCmd.PersistentFlags().StringSliceVarP(&hosts, "host", "H", []string{"127.0.0.1"}, "Database host") rootCmd.PersistentFlags().StringVarP(&user, "user", "U", "root", "Database user") rootCmd.PersistentFlags().StringVarP(&password, "password", "p", "", "Database password") - rootCmd.PersistentFlags().IntVarP(&port, "port", "P", 4000, "Database port") + rootCmd.PersistentFlags().IntSliceVarP(&ports, "port", "P", []int{4000}, "Database port") rootCmd.PersistentFlags().IntVarP(&statusPort, "statusPort", "S", 10080, "Database status port") rootCmd.PersistentFlags().IntVarP(&threads, "threads", "T", 1, "Thread concurrency") rootCmd.PersistentFlags().IntVarP(&acThreads, "acThreads", "t", 1, "OLAP client concurrency, only for CH-benCHmark") @@ -151,6 +207,8 @@ func main() { 5: Snapshot, 6: Serializable, 7: Linerizable`) rootCmd.PersistentFlags().StringVar(&connParams, "conn-params", "", "session variables, e.g. for TiDB --conn-params tidb_isolation_read_engines='tiflash', For PostgreSQL: --conn-params sslmode=disable") rootCmd.PersistentFlags().StringVar(&outputStyle, "output", util.OutputStylePlain, "output style, valid values can be { plain | table | json }") + rootCmd.PersistentFlags().StringSliceVar(&targets, "targets", nil, "Target database addresses") + rootCmd.PersistentFlags().MarkHidden("targets") cobra.EnablePrefixMatching = true diff --git a/cmd/go-tpc/rawsql.go b/cmd/go-tpc/rawsql.go index 468d491..418fb93 100644 --- a/cmd/go-tpc/rawsql.go +++ b/cmd/go-tpc/rawsql.go @@ -83,7 +83,7 @@ func execRawsql(action string) { rawsqlConfig.QueryNames = strings.Split(queryFiles, ",") rawsqlConfig.Queries = make(map[string]string, len(rawsqlConfig.QueryNames)) rawsqlConfig.RefreshWait = refreshConnWait - rawsqlConfig.PlanReplayerConfig.Host = host + rawsqlConfig.PlanReplayerConfig.Host = hosts[0] rawsqlConfig.PlanReplayerConfig.StatusPort = statusPort for i, filename := range rawsqlConfig.QueryNames { diff --git a/cmd/go-tpc/tpch.go b/cmd/go-tpc/tpch.go index f8ac9e7..d0f8501 100644 --- a/cmd/go-tpc/tpch.go +++ b/cmd/go-tpc/tpch.go @@ -21,7 +21,7 @@ func executeTpch(action string) { os.Exit(1) } - tpchConfig.PlanReplayerConfig.Host = host + tpchConfig.PlanReplayerConfig.Host = hosts[0] tpchConfig.PlanReplayerConfig.StatusPort = statusPort tpchConfig.OutputStyle = outputStyle