Skip to content

Commit

Permalink
avoid breaking the Getter interface
Browse files Browse the repository at this point in the history
* this 'unbreaks' the Getter interface and uses the context from the client that was set with Getter.SetClient
  • Loading branch information
azr committed Jan 9, 2019
1 parent a8e824b commit 90f6e2c
Show file tree
Hide file tree
Showing 15 changed files with 88 additions and 76 deletions.
4 changes: 2 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ func (c *Client) Get() error {
}
}
if getFile {
err := g.GetFile(c.ctx, dst, u)
err := g.GetFile(dst, u)
if err != nil {
return err
}
Expand Down Expand Up @@ -269,7 +269,7 @@ func (c *Client) Get() error {

// We're downloading a directory, which might require a bit more work
// if we're specifying a subdir.
err := g.Get(c.ctx, dst, u)
err := g.Get(dst, u)
if err != nil {
err = fmt.Errorf("error downloading '%s': %s", src, err)
return err
Expand Down
5 changes: 2 additions & 3 deletions get.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ package getter

import (
"bytes"
"context"
"fmt"
"net/url"
"os/exec"
Expand All @@ -32,12 +31,12 @@ type Getter interface {
// The directory may already exist (if we're updating). If it is in a
// format that isn't understood, an error should be returned. Get shouldn't
// simply nuke the directory.
Get(context.Context, string, *url.URL) error
Get(string, *url.URL) error

// GetFile downloads the give URL into the given path. The URL must
// reference a single file. If possible, the Getter should check if
// the remote end contains the same file and no-op this operation.
GetFile(context.Context, string, *url.URL) error
GetFile(string, *url.URL) error

// ClientMode returns the mode based on the given URL. This is used to
// allow clients to let the getters decide which mode to use.
Expand Down
11 changes: 11 additions & 0 deletions get_base.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
package getter

import "context"

// getter is our base getter; it regroups
// fields all getters have in common.
type getter struct {
client *Client
}

func (g *getter) SetClient(c *Client) { g.client = c }

// Context tries to returns the Contex from the getter's
// client. otherwise context.Background() is returned.
func (g *getter) Context() context.Context {
if g == nil || g.client == nil {
return context.Background()
}
return g.client.ctx
}
17 changes: 8 additions & 9 deletions get_file_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package getter

import (
"context"
"os"
"path/filepath"
"testing"
Expand All @@ -16,7 +15,7 @@ func TestFileGetter(t *testing.T) {
dst := tempDir(t)

// With a dir that doesn't exist
if err := g.Get(context.Background(), dst, testModuleURL("basic")); err != nil {
if err := g.Get(dst, testModuleURL("basic")); err != nil {
t.Fatalf("err: %s", err)
}

Expand All @@ -43,7 +42,7 @@ func TestFileGetter_sourceFile(t *testing.T) {
// With a source URL that is a path to a file
u := testModuleURL("basic")
u.Path += "/main.tf"
if err := g.Get(context.Background(), dst, u); err == nil {
if err := g.Get(dst, u); err == nil {
t.Fatal("should error")
}
}
Expand All @@ -55,7 +54,7 @@ func TestFileGetter_sourceNoExist(t *testing.T) {
// With a source URL that doesn't exist
u := testModuleURL("basic")
u.Path += "/main"
if err := g.Get(context.Background(), dst, u); err == nil {
if err := g.Get(dst, u); err == nil {
t.Fatal("should error")
}
}
Expand All @@ -69,7 +68,7 @@ func TestFileGetter_dir(t *testing.T) {
}

// With a dir that exists that isn't a symlink
if err := g.Get(context.Background(), dst, testModuleURL("basic")); err == nil {
if err := g.Get(dst, testModuleURL("basic")); err == nil {
t.Fatal("should error")
}
}
Expand All @@ -93,7 +92,7 @@ func TestFileGetter_dirSymlink(t *testing.T) {
}

// With a dir that exists that isn't a symlink
if err := g.Get(context.Background(), dst, testModuleURL("basic")); err != nil {
if err := g.Get(dst, testModuleURL("basic")); err != nil {
t.Fatalf("err: %s", err)
}

Expand All @@ -110,7 +109,7 @@ func TestFileGetter_GetFile(t *testing.T) {
defer os.RemoveAll(filepath.Dir(dst))

// With a dir that doesn't exist
if err := g.GetFile(context.Background(), dst, testModuleURL("basic-file/foo.txt")); err != nil {
if err := g.GetFile(dst, testModuleURL("basic-file/foo.txt")); err != nil {
t.Fatalf("err: %s", err)
}

Expand All @@ -135,7 +134,7 @@ func TestFileGetter_GetFile_Copy(t *testing.T) {
defer os.RemoveAll(filepath.Dir(dst))

// With a dir that doesn't exist
if err := g.GetFile(context.Background(), dst, testModuleURL("basic-file/foo.txt")); err != nil {
if err := g.GetFile(dst, testModuleURL("basic-file/foo.txt")); err != nil {
t.Fatalf("err: %s", err)
}

Expand All @@ -158,7 +157,7 @@ func TestFileGetter_percent2F(t *testing.T) {
dst := tempDir(t)

// With a dir that doesn't exist
if err := g.Get(context.Background(), dst, testModuleURL("basic%2Ftest")); err != nil {
if err := g.Get(dst, testModuleURL("basic%2Ftest")); err != nil {
t.Fatalf("err: %s", err)
}

Expand Down
6 changes: 3 additions & 3 deletions get_file_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,13 @@
package getter

import (
"context"
"fmt"
"net/url"
"os"
"path/filepath"
)

func (g *FileGetter) Get(_ context.Context, dst string, u *url.URL) error {
func (g *FileGetter) Get(dst string, u *url.URL) error {
path := u.Path
if u.RawPath != "" {
path = u.RawPath
Expand Down Expand Up @@ -49,7 +48,8 @@ func (g *FileGetter) Get(_ context.Context, dst string, u *url.URL) error {
return os.Symlink(path, dst)
}

func (g *FileGetter) GetFile(ctx context.Context, dst string, u *url.URL) error {
func (g *FileGetter) GetFile(dst string, u *url.URL) error {
ctx := g.Context()
path := u.Path
if u.RawPath != "" {
path = u.RawPath
Expand Down
7 changes: 4 additions & 3 deletions get_file_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
package getter

import (
"context"
"fmt"
"net/url"
"os"
Expand All @@ -13,6 +12,7 @@ import (
)

func (g *FileGetter) Get(dst string, u *url.URL) error {
ctx := g.Context()
path := u.Path
if u.RawPath != "" {
path = u.RawPath
Expand Down Expand Up @@ -51,15 +51,16 @@ func (g *FileGetter) Get(dst string, u *url.URL) error {
sourcePath := toBackslash(path)

// Use mklink to create a junction point
output, err := exec.Command("cmd", "/c", "mklink", "/J", dst, sourcePath).CombinedOutput()
output, err := exec.CommandContext(ctx, "cmd", "/c", "mklink", "/J", dst, sourcePath).CombinedOutput()
if err != nil {
return fmt.Errorf("failed to run mklink %v %v: %v %q", dst, sourcePath, err, output)
}

return nil
}

func (g *FileGetter) GetFile(ctx context.Context, dst string, u *url.URL) error {
func (g *FileGetter) GetFile(dst string, u *url.URL) error {
ctx := g.Context()
path := u.Path
if u.RawPath != "" {
path = u.RawPath
Expand Down
13 changes: 7 additions & 6 deletions get_git.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ import (
"strings"

urlhelper "github.com/hashicorp/go-getter/helper/url"
"github.com/hashicorp/go-safetemp"
"github.com/hashicorp/go-version"
safetemp "github.com/hashicorp/go-safetemp"
version "github.com/hashicorp/go-version"
)

// GitGetter is a Getter implementation that will download a module from
Expand All @@ -26,7 +26,8 @@ func (g *GitGetter) ClientMode(_ *url.URL) (ClientMode, error) {
return ClientModeDir, nil
}

func (g *GitGetter) Get(ctx context.Context, dst string, u *url.URL) error {
func (g *GitGetter) Get(dst string, u *url.URL) error {
ctx := g.Context()
if _, err := exec.LookPath("git"); err != nil {
return fmt.Errorf("git must be available and on the PATH")
}
Expand Down Expand Up @@ -128,7 +129,7 @@ func (g *GitGetter) Get(ctx context.Context, dst string, u *url.URL) error {

// GetFile for Git doesn't support updating at this time. It will download
// the file every time.
func (g *GitGetter) GetFile(ctx context.Context, dst string, u *url.URL) error {
func (g *GitGetter) GetFile(dst string, u *url.URL) error {
td, tdcloser, err := safetemp.Dir("", "getter")
if err != nil {
return err
Expand All @@ -141,7 +142,7 @@ func (g *GitGetter) GetFile(ctx context.Context, dst string, u *url.URL) error {
u.Path = filepath.Dir(u.Path)

// Get the full repository
if err := g.Get(ctx, td, u); err != nil {
if err := g.Get(td, u); err != nil {
return err
}

Expand All @@ -152,7 +153,7 @@ func (g *GitGetter) GetFile(ctx context.Context, dst string, u *url.URL) error {
}

fg := &FileGetter{Copy: true}
return fg.GetFile(ctx, dst, u)
return fg.GetFile(dst, u)
}

func (g *GitGetter) checkout(dst string, ref string) error {
Expand Down
21 changes: 10 additions & 11 deletions get_git_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package getter

import (
"context"
"encoding/base64"
"io/ioutil"
"net/url"
Expand Down Expand Up @@ -38,7 +37,7 @@ func TestGitGetter(t *testing.T) {
repo.commitFile("foo.txt", "hello")

// With a dir that doesn't exist
if err := g.Get(context.Background(), dst, repo.url); err != nil {
if err := g.Get(dst, repo.url); err != nil {
t.Fatalf("err: %s", err)
}

Expand Down Expand Up @@ -66,7 +65,7 @@ func TestGitGetter_branch(t *testing.T) {
q.Add("ref", "test-branch")
repo.url.RawQuery = q.Encode()

if err := g.Get(context.Background(), dst, repo.url); err != nil {
if err := g.Get(dst, repo.url); err != nil {
t.Fatalf("err: %s", err)
}

Expand All @@ -77,7 +76,7 @@ func TestGitGetter_branch(t *testing.T) {
}

// Get again should work
if err := g.Get(context.Background(), dst, repo.url); err != nil {
if err := g.Get(dst, repo.url); err != nil {
t.Fatalf("err: %s", err)
}

Expand Down Expand Up @@ -106,7 +105,7 @@ func TestGitGetter_branchUpdate(t *testing.T) {
q := repo.url.Query()
q.Add("ref", "test-branch")
repo.url.RawQuery = q.Encode()
if err := g.Get(context.Background(), dst, repo.url); err != nil {
if err := g.Get(dst, repo.url); err != nil {
t.Fatalf("err: %s", err)
}

Expand All @@ -120,7 +119,7 @@ func TestGitGetter_branchUpdate(t *testing.T) {
repo.commitFile("branch-update.txt", "branch-update")

// Get again should work
if err := g.Get(context.Background(), dst, repo.url); err != nil {
if err := g.Get(dst, repo.url); err != nil {
t.Fatalf("err: %s", err)
}

Expand Down Expand Up @@ -148,7 +147,7 @@ func TestGitGetter_tag(t *testing.T) {
q.Add("ref", "v1.0")
repo.url.RawQuery = q.Encode()

if err := g.Get(context.Background(), dst, repo.url); err != nil {
if err := g.Get(dst, repo.url); err != nil {
t.Fatalf("err: %s", err)
}

Expand All @@ -159,7 +158,7 @@ func TestGitGetter_tag(t *testing.T) {
}

// Get again should work
if err := g.Get(context.Background(), dst, repo.url); err != nil {
if err := g.Get(dst, repo.url); err != nil {
t.Fatalf("err: %s", err)
}

Expand All @@ -185,7 +184,7 @@ func TestGitGetter_GetFile(t *testing.T) {

// Download the file
repo.url.Path = filepath.Join(repo.url.Path, "file.txt")
if err := g.GetFile(context.Background(), dst, repo.url); err != nil {
if err := g.GetFile(dst, repo.url); err != nil {
t.Fatalf("err: %s", err)
}

Expand Down Expand Up @@ -246,7 +245,7 @@ func TestGitGetter_sshKey(t *testing.T) {
t.Fatal(err)
}

if err := g.Get(context.Background(), dst, u); err != nil {
if err := g.Get(dst, u); err != nil {
t.Fatalf("err: %s", err)
}

Expand Down Expand Up @@ -282,7 +281,7 @@ func TestGitGetter_submodule(t *testing.T) {
p.git("commit", "-m", "Add child submodule")

// Clone the root repository
if err := g.Get(context.Background(), dst, p.url); err != nil {
if err := g.Get(dst, p.url); err != nil {
t.Fatalf("err: %s", err)
}

Expand Down
13 changes: 7 additions & 6 deletions get_hg.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
"runtime"

urlhelper "github.com/hashicorp/go-getter/helper/url"
"github.com/hashicorp/go-safetemp"
safetemp "github.com/hashicorp/go-safetemp"
)

// HgGetter is a Getter implementation that will download a module from
Expand All @@ -23,7 +23,8 @@ func (g *HgGetter) ClientMode(_ *url.URL) (ClientMode, error) {
return ClientModeDir, nil
}

func (g *HgGetter) Get(ctx context.Context, dst string, u *url.URL) error {
func (g *HgGetter) Get(dst string, u *url.URL) error {
ctx := g.Context()
if _, err := exec.LookPath("hg"); err != nil {
return fmt.Errorf("hg must be available and on the PATH")
}
Expand Down Expand Up @@ -66,7 +67,7 @@ func (g *HgGetter) Get(ctx context.Context, dst string, u *url.URL) error {

// GetFile for Hg doesn't support updating at this time. It will download
// the file every time.
func (g *HgGetter) GetFile(ctx context.Context, dst string, u *url.URL) error {
func (g *HgGetter) GetFile(dst string, u *url.URL) error {
// Create a temporary directory to store the full source. This has to be
// a non-existent directory.
td, tdcloser, err := safetemp.Dir("", "getter")
Expand All @@ -86,7 +87,7 @@ func (g *HgGetter) GetFile(ctx context.Context, dst string, u *url.URL) error {
}

// Get the full repository
if err := g.Get(ctx, td, u); err != nil {
if err := g.Get(td, u); err != nil {
return err
}

Expand All @@ -96,8 +97,8 @@ func (g *HgGetter) GetFile(ctx context.Context, dst string, u *url.URL) error {
return err
}

fg := &FileGetter{Copy: true}
return fg.GetFile(ctx, dst, u)
fg := &FileGetter{Copy: true, getter: g.getter}
return fg.GetFile(dst, u)
}

func (g *HgGetter) clone(dst string, u *url.URL) error {
Expand Down
Loading

0 comments on commit 90f6e2c

Please sign in to comment.