diff --git a/connection.go b/connection.go index 4b16921..674b7d6 100644 --- a/connection.go +++ b/connection.go @@ -129,7 +129,7 @@ func (conn *odpsConn) wait(query string, args []driver.Value) (string, error) { query = fmt.Sprintf(query, args) } - ins, err := conn.createInstance(newSQLJob(query)) + ins, err := conn.createInstance(newSQLJob(query, conn.QueryHints)) if err != nil { return "", err } diff --git a/driver_test.go b/driver_test.go index f728290..b9d963d 100644 --- a/driver_test.go +++ b/driver_test.go @@ -9,10 +9,11 @@ import ( ) var cfg4test = &Config{ - AccessID: os.Getenv("ODPS_ACCESS_ID"), - AccessKey: os.Getenv("ODPS_ACCESS_KEY"), - Project: os.Getenv("ODPS_PROJECT"), - Endpoint: os.Getenv("ODPS_ENDPOINT"), + AccessID: os.Getenv("ODPS_ACCESS_ID"), + AccessKey: os.Getenv("ODPS_ACCESS_KEY"), + Project: os.Getenv("ODPS_PROJECT"), + Endpoint: os.Getenv("ODPS_ENDPOINT"), + QueryHints: map[string]string{"odps.sql.mapper.split_size": "16"}, } func TestSQLOpen(t *testing.T) { @@ -21,3 +22,11 @@ func TestSQLOpen(t *testing.T) { defer db.Close() a.NoError(err) } + +func TestQuerySettings(t *testing.T) { + a := assert.New(t) + db, err := sql.Open("maxcompute", cfg4test.FormatDSN()) + a.NoError(err) + _, err = db.Query("SELECT * FROM gomaxcompute_test LIMIT;") + a.NoError(err) +} diff --git a/dsn.go b/dsn.go index 399d98d..13db749 100644 --- a/dsn.go +++ b/dsn.go @@ -2,6 +2,7 @@ package gomaxcompute import ( "fmt" + "net/url" "regexp" "strings" ) @@ -12,11 +13,14 @@ var ( reQuery = regexp.MustCompile(`^([a-zA-Z0-9_-]+)=([a-zA-Z0-9_-]*)$`) ) +const HINT_PREFIX = "hint_" + type Config struct { - AccessID string - AccessKey string - Project string - Endpoint string + AccessID string + AccessKey string + Project string + Endpoint string + QueryHints map[string]string } func ParseDSN(dsn string) (*Config, error) { @@ -26,32 +30,41 @@ func ParseDSN(dsn string) (*Config, error) { } id, key, endpointURL := sub[1], sub[2], sub[3] - query := make(map[string]string) - 
for _, s := range strings.Split(sub[4], "&") { - pair := reQuery.FindStringSubmatch(s) - if len(pair) != 3 { - return nil, fmt.Errorf("dsn %s doesn't match access_id:access_key@url?curr_project=project&scheme=http|https", dsn) - } - if pair[1] != "scheme" && pair[1] != currentProject { - return nil, fmt.Errorf("dsn %s 's query is neither scheme or %s", dsn, currentProject) - } - query[pair[1]] = pair[2] + var schemeArgs []string + var currProjArgs []string + var ok bool + queryHints := make(map[string]string) + + querys, err := url.ParseQuery(sub[4]) + if err != nil { + return nil, err } - if _, ok := query[currentProject]; !ok { - return nil, fmt.Errorf("dsn %s doesn't have curr_project", dsn) + + if schemeArgs, ok = querys["scheme"]; !ok || len(schemeArgs) != 1 { + return nil, fmt.Errorf("dsn %s should have one scheme argument", dsn) } - if _, ok := query["scheme"]; !ok { - return nil, fmt.Errorf("dsn %s doesn't have scheme", dsn) + if currProjArgs, ok = querys[currentProject]; !ok || len(currProjArgs) != 1 { + return nil, fmt.Errorf("dsn %s should have one current_project argument", dsn) } - if query["scheme"] != "http" && query["scheme"] != "https" { + + for k, v := range querys { + // The query args such as hint_odps.sql.mapper.split_size=16 + // would be converted to the maxcompute query hints: {"odps.sql.mapper.split_size": "16"} + if strings.HasPrefix(k, HINT_PREFIX) { + queryHints[k[5:]] = v[0] + } + } + + if schemeArgs[0] != "http" && schemeArgs[0] != "https" { return nil, fmt.Errorf("dsn %s 's scheme is neither http nor https", dsn) } config := &Config{ - AccessID: id, - AccessKey: key, - Project: query[currentProject], - Endpoint: query["scheme"] + "://" + endpointURL} + AccessID: id, + AccessKey: key, + Project: currProjArgs[0], + Endpoint: schemeArgs[0] + "://" + endpointURL, + QueryHints: queryHints} return config, nil } @@ -62,6 +75,12 @@ func (cfg *Config) FormatDSN() string { return "" } scheme, endpointURL := pair[0], pair[1] - return 
fmt.Sprintf("%s:%s@%s?curr_project=%s&scheme=%s", + dsnFormt := fmt.Sprintf("%s:%s@%s?curr_project=%s&scheme=%s", cfg.AccessID, cfg.AccessKey, endpointURL, cfg.Project, scheme) + if len(cfg.QueryHints) != 0 { + for k, v := range cfg.QueryHints { + dsnFormt = fmt.Sprintf("%s&hint_%s=%v", dsnFormt, k, v) + } + } + return dsnFormt } diff --git a/dsn_test.go b/dsn_test.go index 788cc6b..e6f5553 100644 --- a/dsn_test.go +++ b/dsn_test.go @@ -9,13 +9,14 @@ import ( func TestConfig_ParseDSN(t *testing.T) { a := assert.New(t) - correct := "access_id:access_key@service.com/api?curr_project=test_ci&scheme=http" + correct := "access_id:access_key@service.com/api?curr_project=test_ci&scheme=http&hint_odps.sql.mapper.split_size=16" config, err := ParseDSN(correct) a.NoError(err) a.Equal("access_id", config.AccessID) a.Equal("access_key", config.AccessKey) a.Equal("test_ci", config.Project) a.Equal("http://service.com/api", config.Endpoint) + a.Equal("16", config.QueryHints["odps.sql.mapper.split_size"]) badDSN := []string{ "", // empty @@ -46,11 +47,13 @@ func TestConfig_ParseDSN(t *testing.T) { func TestConfig_FormatDSN(t *testing.T) { a := assert.New(t) config := Config{ - AccessID: "access_id", - AccessKey: "access_key", - Project: "test_ci", - Endpoint: "http://service.com/api"} - a.Equal("access_id:access_key@service.com/api?curr_project=test_ci&scheme=http", config.FormatDSN()) + AccessID: "access_id", + AccessKey: "access_key", + Project: "test_ci", + Endpoint: "http://service.com/api", + QueryHints: map[string]string{"odps.sql.mapper.split_size": "16"}} + a.Equal("access_id:access_key@service.com/api?curr_project="+ + "test_ci&scheme=http&hint_odps.sql.mapper.split_size=16", config.FormatDSN()) } func TestConfig_ParseAndFormatRoundTrip(t *testing.T) { diff --git a/job.go b/job.go index 8fbb63b..e56d42d 100644 --- a/job.go +++ b/job.go @@ -25,8 +25,8 @@ func newJob(tasks ...odpsTask) *odpsJob { } } -func newSQLJob(sql string) *odpsJob { - return 
newJob(newAnonymousSQLTask(sql, nil)) +func newSQLJob(sql string, hints map[string]string) *odpsJob { + return newJob(newAnonymousSQLTask(sql, hints)) } func (j *odpsJob) MarshalXML(e *xml.Encoder, start xml.StartElement) (err error) { diff --git a/task.go b/task.go index 139e63d..15c8bcb 100644 --- a/task.go +++ b/task.go @@ -73,6 +73,10 @@ func newSQLTask(name, query string, config map[string]string) odpsTask { "uuid": uuid.NewV4().String(), "settings": `{"odps.sql.udf.strict.mode": "true"}`, } + } else { + if _, ok := config["uuid"]; !ok { + config["uuid"] = uuid.NewV4().String() + } } // maxcompute sql ends with a ';' query = strings.TrimSpace(query)