diff --git a/detect.go b/detect.go index c3695510b..1485aaa97 100644 --- a/detect.go +++ b/detect.go @@ -23,6 +23,7 @@ var Detectors []Detector func init() { Detectors = []Detector{ new(GitHubDetector), + new(GitDetector), new(BitBucketDetector), new(S3Detector), new(FileDetector), diff --git a/detect_git.go b/detect_git.go new file mode 100644 index 000000000..eeb8a04c5 --- /dev/null +++ b/detect_git.go @@ -0,0 +1,26 @@ +package getter + +// GitDetector implements Detector to detect Git SSH URLs such as +// git@host.com:dir1/dir2 and converts them to proper URLs. +type GitDetector struct{} + +func (d *GitDetector) Detect(src, _ string) (string, bool, error) { + if len(src) == 0 { + return "", false, nil + } + + u, err := detectSSH(src) + if err != nil { + return "", true, err + } + if u == nil { + return "", false, nil + } + + // We require the username to be "git" to assume that this is a Git URL + if u.User.Username() != "git" { + return "", false, nil + } + + return "git::" + u.String(), true, nil +} diff --git a/detect_git_test.go b/detect_git_test.go new file mode 100644 index 000000000..c899be0b8 --- /dev/null +++ b/detect_git_test.go @@ -0,0 +1,38 @@ +package getter + +import ( + "testing" +) + +func TestGitDetector(t *testing.T) { + cases := []struct { + Input string + Output string + }{ + {"git@github.com:hashicorp/foo.git", "git::ssh://git@github.com/hashicorp/foo.git"}, + { + "git@github.com:hashicorp/foo.git//bar", + "git::ssh://git@github.com/hashicorp/foo.git//bar", + }, + { + "git@github.com:hashicorp/foo.git?foo=bar", + "git::ssh://git@github.com/hashicorp/foo.git?foo=bar", + }, + } + + pwd := "/pwd" + f := new(GitDetector) + for i, tc := range cases { + output, ok, err := f.Detect(tc.Input, pwd) + if err != nil { + t.Fatalf("err: %s", err) + } + if !ok { + t.Fatal("not ok") + } + + if output != tc.Output { + t.Fatalf("%d: bad: %#v", i, output) + } + } +} diff --git a/detect_github.go b/detect_github.go index c084ad9ac..4bf4daf23 100644 --- a/detect_github.go +++ b/detect_github.go @@ -17,8 +17,6 @@ func (d *GitHubDetector) Detect(src, _ string) (string, bool, error) { if strings.HasPrefix(src, "github.com/") { return d.detectHTTP(src) - } else if strings.HasPrefix(src, "git@github.com:") { - return d.detectSSH(src) } return "", false, nil @@ -47,27 +45,3 @@ func (d *GitHubDetector) detectHTTP(src string) (string, bool, error) { return "git::" + url.String(), true, nil } - -func (d *GitHubDetector) detectSSH(src string) (string, bool, error) { - idx := strings.Index(src, ":") - qidx := strings.Index(src, "?") - if qidx == -1 { - qidx = len(src) - } - - var u url.URL - u.Scheme = "ssh" - u.User = url.User("git") - u.Host = "github.com" - u.Path = src[idx+1 : qidx] - if qidx < len(src) { - q, err := url.ParseQuery(src[qidx+1:]) - if err != nil { - return "", true, fmt.Errorf("error parsing GitHub SSH URL: %s", err) - } - - u.RawQuery = q.Encode() - } - - return "git::" + u.String(), true, nil -} diff --git a/detect_github_test.go b/detect_github_test.go index 43ed9fcc6..70f1c8329 100644 --- a/detect_github_test.go +++ b/detect_github_test.go @@ -24,17 +24,6 @@ func TestGitHubDetector(t *testing.T) { "github.com/hashicorp/foo.git?foo=bar", "git::https://github.com/hashicorp/foo.git?foo=bar", }, - - // SSH - {"git@github.com:hashicorp/foo.git", "git::ssh://git@github.com/hashicorp/foo.git"}, - { - "git@github.com:hashicorp/foo.git//bar", - "git::ssh://git@github.com/hashicorp/foo.git//bar", - }, - { - "git@github.com:hashicorp/foo.git?foo=bar", - "git::ssh://git@github.com/hashicorp/foo.git?foo=bar", - }, } pwd := "/pwd" diff --git a/detect_ssh.go b/detect_ssh.go new file mode 100644 index 000000000..c0dbe9d47 --- /dev/null +++ b/detect_ssh.go @@ -0,0 +1,49 @@ +package getter + +import ( + "fmt" + "net/url" + "regexp" + "strings" +) + +// Note that we do not have an SSH-getter currently so this file serves +// only to hold the detectSSH helper that is used by other detectors. + +// sshPattern matches SCP-like SSH patterns (user@host:path) +var sshPattern = regexp.MustCompile("^(?:([^@]+)@)?([^:]+):/?(.+)$") + +// detectSSH determines if the src string matches an SSH-like URL and +// converts it into a net.URL compatible string. This returns nil if the +// string doesn't match the SSH pattern. +// +// This function is tested indirectly via detect_git_test.go +func detectSSH(src string) (*url.URL, error) { + matched := sshPattern.FindStringSubmatch(src) + if matched == nil { + return nil, nil + } + + user := matched[1] + host := matched[2] + path := matched[3] + qidx := strings.Index(path, "?") + if qidx == -1 { + qidx = len(path) + } + + var u url.URL + u.Scheme = "ssh" + u.User = url.User(user) + u.Host = host + u.Path = path[0:qidx] + if qidx < len(path) { + q, err := url.ParseQuery(path[qidx+1:]) + if err != nil { + return nil, fmt.Errorf("error parsing GitHub SSH URL: %s", err) + } + u.RawQuery = q.Encode() + } + + return &u, nil +} diff --git a/detect_test.go b/detect_test.go index 2a1c20205..9bef662a7 100644 --- a/detect_test.go +++ b/detect_test.go @@ -1,6 +1,7 @@ package getter import ( + "fmt" "testing" ) @@ -37,21 +38,55 @@ func TestDetect(t *testing.T) { "git::https://github.com/hashicorp/consul.git", false, }, + { + "git::https://person@someothergit.com/foo/bar", + "", + "git::https://person@someothergit.com/foo/bar", + false, + }, + { + "git::https://person@someothergit.com/foo/bar", + "/bar", + "git::https://person@someothergit.com/foo/bar", + false, + }, { "./foo/archive//*", "/bar", "file:///bar/foo/archive//*", false, }, + + // https://github.com/hashicorp/go-getter/pull/124 + { + "git::ssh://git@my.custom.git/dir1/dir2", + "", + "git::ssh://git@my.custom.git/dir1/dir2", + false, + }, + { + "git::git@my.custom.git:dir1/dir2", + "/foo", + "git::ssh://git@my.custom.git/dir1/dir2", + false, + }, + { + "git::git@my.custom.git:dir1/dir2", + "", + "git::ssh://git@my.custom.git/dir1/dir2", + false, + }, } for i, tc := range cases { - output, err := Detect(tc.Input, tc.Pwd, Detectors) - if err != nil != tc.Err { - t.Fatalf("%d: bad err: %s", i, err) - } - if output != tc.Output { - t.Fatalf("%d: bad output: %s\nexpected: %s", i, output, tc.Output) - } + t.Run(fmt.Sprintf("%d %s", i, tc.Input), func(t *testing.T) { + output, err := Detect(tc.Input, tc.Pwd, Detectors) + if err != nil != tc.Err { + t.Fatalf("%d: bad err: %s", i, err) + } + if output != tc.Output { + t.Fatalf("%d: bad output: %s\nexpected: %s", i, output, tc.Output) + } + }) } }