diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..22d0d82f8 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +vendor diff --git a/checksum.go b/checksum.go index eeccfea9d..6a2ed5d16 100644 --- a/checksum.go +++ b/checksum.go @@ -3,6 +3,7 @@ package getter import ( "bufio" "bytes" + "context" "crypto/md5" "crypto/sha1" "crypto/sha256" @@ -93,7 +94,7 @@ func (c *FileChecksum) checksum(source string) error { // *file2 // // see parseChecksumLine for more detail on checksum file parsing -func (c *Client) extractChecksum(u *url.URL) (*FileChecksum, error) { +func (c *Client) extractChecksum(ctx context.Context, u *url.URL) (*FileChecksum, error) { q := u.Query() v := q.Get("checksum") @@ -115,7 +116,7 @@ func (c *Client) extractChecksum(u *url.URL) (*FileChecksum, error) { switch checksumType { case "file": - return c.ChecksumFromFile(checksumValue, u) + return c.ChecksumFromFile(ctx, checksumValue, u.Path) default: return newChecksumFromType(checksumType, checksumValue, filepath.Base(u.EscapedPath())) } @@ -183,15 +184,16 @@ func newChecksumFromValue(checksumValue, filename string) (*FileChecksum, error) return c, nil } -// ChecksumFromFile will return all the FileChecksums found in file +// ChecksumFromFile will return the first file checksum found in the +// `checksumURL` file that corresponds to the `checksummedURL` path. // -// ChecksumFromFile will try to guess the hashing algorithm based on content -// of checksum file +// ChecksumFromFile will infer the hashing algorithm based on the checksumURL +// file content. // -// ChecksumFromFile will only return checksums for files that match file -// behind src -func (c *Client) ChecksumFromFile(checksumFile string, src *url.URL) (*FileChecksum, error) { - checksumFileURL, err := urlhelper.Parse(checksumFile) +// ChecksumFromFile will only return checksums for files that match +// checksummedURL, which is the object being checksummed. +func (c *Client) ChecksumFromFile(ctx context.Context, checksumURL, checksummedURL string) (*FileChecksum, error) { + checksumFileURL, err := urlhelper.Parse(checksumURL) if err != nil { return nil, err } @@ -202,24 +204,20 @@ func (c *Client) ChecksumFromFile(checksumFile string, src *url.URL) (*FileCheck } defer os.Remove(tempfile) - c2 := &Client{ - Ctx: c.Ctx, - Getters: c.Getters, - Decompressors: c.Decompressors, - Detectors: c.Detectors, - Pwd: c.Pwd, - Dir: false, - Src: checksumFile, - Dst: tempfile, - ProgressListener: c.ProgressListener, + req := &Request{ + // Pwd: c.Pwd, TODO(adrien): pass pwd ? + Dir: false, + Src: checksumURL, + Dst: tempfile, + // ProgressListener: c.ProgressListener, TODO(adrien): pass progress bar ? } - if err = c2.Get(); err != nil { + if err = c.Get(ctx, req); err != nil { return nil, fmt.Errorf( "Error downloading checksum file: %s", err) } - filename := filepath.Base(src.Path) - absPath, err := filepath.Abs(src.Path) + filename := filepath.Base(checksummedURL) + absPath, err := filepath.Abs(checksummedURL) if err != nil { return nil, err } @@ -277,7 +275,7 @@ func (c *Client) ChecksumFromFile(checksumFile string, src *url.URL) (*FileCheck } } } - return nil, fmt.Errorf("no checksum found in: %s", checksumFile) + return nil, fmt.Errorf("no checksum found in: %s", checksumURL) } // parseChecksumLine takes a line from a checksum file and returns diff --git a/client.go b/client.go index 38fb43b8f..82216f2df 100644 --- a/client.go +++ b/client.go @@ -19,25 +19,6 @@ import ( // Using a client directly allows more fine-grained control over how downloading // is done, as well as customizing the protocols supported. type Client struct { - // Ctx for cancellation - Ctx context.Context - - // Src is the source URL to get. - // - // Dst is the path to save the downloaded thing as. If Dir is set to - // true, then this should be a directory. If the directory doesn't exist, - // it will be created for you. - // - // Pwd is the working directory for detection. If this isn't set, some - // detection may fail. Client will not default pwd to the current - // working directory for security reasons. - Src string - Dst string - Pwd string - - // Mode is the method of download the client will use. See ClientMode - // for documentation. - Mode ClientMode // Detectors is the list of detectors that are tried on the source. // If this is nil, then the default Detectors will be used. @@ -50,51 +31,37 @@ type Client struct { // Getters is the map of protocols supported by this client. If this // is nil, then the default Getters variable will be used. Getters map[string]Getter - - // Dir, if true, tells the Client it is downloading a directory (versus - // a single file). This distinction is necessary since filenames and - // directory names follow the same format so disambiguating is impossible - // without knowing ahead of time. - // - // WARNING: deprecated. If Mode is set, that will take precedence. - Dir bool - - // ProgressListener allows to track file downloads. - // By default a no op progress listener is used. - ProgressListener ProgressTracker - - Options []ClientOption } // Get downloads the configured source to the destination. -func (c *Client) Get() error { - if err := c.Configure(c.Options...); err != nil { +func (c *Client) Get(ctx context.Context, req *Request) error { + if err := c.configure(); err != nil { return err } // Store this locally since there are cases we swap this - mode := c.Mode - if mode == ClientModeInvalid { - if c.Dir { - mode = ClientModeDir + if req.Mode == ClientModeInvalid { + if req.Dir { + req.Mode = ClientModeDir } else { - mode = ClientModeFile + req.Mode = ClientModeFile } } - src, err := Detect(c.Src, c.Pwd, c.Detectors) + var err error + req.Src, err = Detect(req.Src, req.Pwd, c.Detectors) if err != nil { return err } + var force string // Determine if we have a forced protocol, i.e. "git::http://..." - force, src := getForcedGetter(src) + force, req.Src = getForcedGetter(req.Src) // If there is a subdir component, then we download the root separately // and then copy over the proper subdir. - var realDst string - dst := c.Dst - src, subDir := SourceDirSubdir(src) + var realDst, subDir string + req.Src, subDir = SourceDirSubdir(req.Src) if subDir != "" { td, tdcloser, err := safetemp.Dir("", "getter") if err != nil { @@ -102,16 +69,16 @@ func (c *Client) Get() error { } defer tdcloser.Close() - realDst = dst - dst = td + realDst = req.Dst + req.Dst = td } - u, err := urlhelper.Parse(src) + req.u, err = urlhelper.Parse(req.Src) if err != nil { return err } if force == "" { - force = u.Scheme + force = req.u.Scheme } g, ok := c.Getters[force] @@ -121,7 +88,7 @@ func (c *Client) Get() error { } // We have magic query parameters that we use to signal different features - q := u.Query() + q := req.u.Query() // Determine if we have an archive type archiveV := q.Get("archive") @@ -129,7 +96,7 @@ func (c *Client) Get() error { // Delete the paramter since it is a magic parameter we don't // want to pass on to the Getter q.Del("archive") - u.RawQuery = q.Encode() + req.u.RawQuery = q.Encode() // If we can parse the value as a bool and it is false, then // set the archive to "-" which should never map to a decompressor @@ -141,7 +108,7 @@ func (c *Client) Get() error { // We don't appear to... but is it part of the filename? matchingLen := 0 for k := range c.Decompressors { - if strings.HasSuffix(u.Path, "."+k) && len(k) > matchingLen { + if strings.HasSuffix(req.u.Path, "."+k) && len(k) > matchingLen { archiveV = k matchingLen = len(k) } @@ -166,65 +133,65 @@ func (c *Client) Get() error { // Swap the download directory to be our temporary path and // store the old values. - decompressDst = dst - decompressDir = mode != ClientModeFile - dst = filepath.Join(td, "archive") - mode = ClientModeFile + decompressDst = req.Dst + decompressDir = req.Mode != ClientModeFile + req.Dst = filepath.Join(td, "archive") + req.Mode = ClientModeFile } // Determine checksum if we have one - checksum, err := c.extractChecksum(u) + checksum, err := c.extractChecksum(ctx, req.u) if err != nil { return fmt.Errorf("invalid checksum: %s", err) } // Delete the query parameter if we have it. q.Del("checksum") - u.RawQuery = q.Encode() + req.u.RawQuery = q.Encode() - if mode == ClientModeAny { + if req.Mode == ClientModeAny { // Ask the getter which client mode to use - mode, err = g.ClientMode(u) + req.Mode, err = g.ClientMode(ctx, req.u) if err != nil { return err } // Destination is the base name of the URL path in "any" mode when // a file source is detected. - if mode == ClientModeFile { - filename := filepath.Base(u.Path) + if req.Mode == ClientModeFile { + filename := filepath.Base(req.u.Path) // Determine if we have a custom file name if v := q.Get("filename"); v != "" { // Delete the query parameter if we have it. q.Del("filename") - u.RawQuery = q.Encode() + req.u.RawQuery = q.Encode() filename = v } - dst = filepath.Join(dst, filename) + req.Dst = filepath.Join(req.Dst, filename) } } // If we're not downloading a directory, then just download the file // and return. - if mode == ClientModeFile { + if req.Mode == ClientModeFile { getFile := true if checksum != nil { - if err := checksum.checksum(dst); err == nil { + if err := checksum.checksum(req.Dst); err == nil { // don't get the file if the checksum of dst is correct getFile = false } } if getFile { - err := g.GetFile(dst, u) + err := g.GetFile(ctx, req) if err != nil { return err } if checksum != nil { - if err := checksum.checksum(dst); err != nil { + if err := checksum.checksum(req.Dst); err != nil { return err } } @@ -233,24 +200,24 @@ func (c *Client) Get() error { if decompressor != nil { // We have a decompressor, so decompress the current destination // into the final destination with the proper mode. - err := decompressor.Decompress(decompressDst, dst, decompressDir) + err := decompressor.Decompress(decompressDst, req.Dst, decompressDir) if err != nil { return err } // Swap the information back - dst = decompressDst + req.Dst = decompressDst if decompressDir { - mode = ClientModeAny + req.Mode = ClientModeAny } else { - mode = ClientModeFile + req.Mode = ClientModeFile } } // We check the dir value again because it can be switched back // if we were unarchiving. If we're still only Get-ing a file, then // we're done. - if mode == ClientModeFile { + if req.Mode == ClientModeFile { return nil } } @@ -269,9 +236,9 @@ func (c *Client) Get() error { // We're downloading a directory, which might require a bit more work // if we're specifying a subdir. - err := g.Get(dst, u) + err := g.Get(ctx, req) if err != nil { - err = fmt.Errorf("error downloading '%s': %s", src, err) + err = fmt.Errorf("error downloading '%s': %s", req.Src, err) return err } } @@ -286,12 +253,12 @@ func (c *Client) Get() error { } // Process any globs - subDir, err := SubdirGlob(dst, subDir) + subDir, err := SubdirGlob(req.Dst, subDir) if err != nil { return err } - return copyDir(c.Ctx, realDst, subDir, false) + return copyDir(ctx, realDst, subDir, false) } return nil diff --git a/client_option.go b/client_option.go index c1ee413b0..567f3c8ac 100644 --- a/client_option.go +++ b/client_option.go @@ -1,22 +1,7 @@ package getter -import "context" - -// A ClientOption allows to configure a client -type ClientOption func(*Client) error - -// Configure configures a client with options. -func (c *Client) Configure(opts ...ClientOption) error { - if c.Ctx == nil { - c.Ctx = context.Background() - } - c.Options = opts - for _, opt := range opts { - err := opt(c) - if err != nil { - return err - } - } +// configure configures a client with options. +func (c *Client) configure() error { // Default decompressor values if c.Decompressors == nil { c.Decompressors = Decompressors @@ -35,12 +20,3 @@ func (c *Client) Configure(opts ...ClientOption) error { } return nil } - -// WithContext allows to pass a context to operation -// in order to be able to cancel a download in progress. -func WithContext(ctx context.Context) func(*Client) error { - return func(c *Client) error { - c.Ctx = ctx - return nil - } -} diff --git a/client_option_progress.go b/client_option_progress.go index 9b185f71d..1ec9aa1e9 100644 --- a/client_option_progress.go +++ b/client_option_progress.go @@ -4,18 +4,6 @@ import ( "io" ) -// WithProgress allows for a user to track -// the progress of a download. -// For example by displaying a progress bar with -// current download. -// Not all getters have progress support yet. -func WithProgress(pl ProgressTracker) func(*Client) error { - return func(c *Client) error { - c.ProgressListener = pl - return nil - } -} - // ProgressTracker allows to track the progress of downloads. type ProgressTracker interface { // TrackProgress should be called when diff --git a/client_option_progress_test.go b/client_option_progress_test.go index a578fed67..ba9556354 100644 --- a/client_option_progress_test.go +++ b/client_option_progress_test.go @@ -1,6 +1,7 @@ package getter import ( + "context" "io" "net/http" "net/http/httptest" @@ -35,11 +36,12 @@ func TestGet_progress(t *testing.T) { rw.Header().Add("X-Terraform-Get", "something") })) defer s.Close() + ctx := context.Background() { // dl without tracking dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) - if err := GetFile(dst, s.URL+"/file?thig=this&that"); err != nil { + if err := GetFile(ctx, dst, s.URL+"/file?thig=this&that"); err != nil { t.Fatalf("download failed: %v", err) } } @@ -48,10 +50,22 @@ func TestGet_progress(t *testing.T) { p := &MockProgressTracking{} dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) - if err := GetFile(dst, s.URL+"/file?thig=this&that", WithProgress(p)); err != nil { + req := &Request{ + Dst: dst, + Src: s.URL + "/file?thig=this&that", + ProgressListener: p, + Dir: false, + } + if err := DefaultClient.Get(ctx, req); err != nil { t.Fatalf("download failed: %v", err) } - if err := GetFile(dst, s.URL+"/otherfile?thig=this&that", WithProgress(p)); err != nil { + req = &Request{ + Dst: dst, + Src: s.URL + "/otherfile?thig=this&that", + ProgressListener: p, + Dir: false, + } + if err := DefaultClient.Get(ctx, req); err != nil { t.Fatalf("download failed: %v", err) } diff --git a/cmd/go-getter/main.go b/cmd/go-getter/main.go index 3cd028641..61b5c848c 100644 --- a/cmd/go-getter/main.go +++ b/cmd/go-getter/main.go @@ -41,20 +41,16 @@ func main() { log.Fatalf("Error getting wd: %s", err) } - opts := []getter.ClientOption{} - if *progress { - opts = append(opts, getter.WithProgress(defaultProgressBar)) - } - ctx, cancel := context.WithCancel(context.Background()) // Build the client - client := &getter.Client{ - Ctx: ctx, - Src: args[0], - Dst: args[1], - Pwd: pwd, - Mode: mode, - Options: opts, + req := &getter.Request{ + Src: args[0], + Dst: args[1], + Pwd: pwd, + Mode: mode, + } + if *progress { + req.ProgressListener = defaultProgressBar } wg := sync.WaitGroup{} @@ -63,7 +59,7 @@ func main() { go func() { defer wg.Done() defer cancel() - if err := client.Get(); err != nil { + if err := getter.DefaultClient.Get(ctx, req); err != nil { errChan <- err } }() diff --git a/decompress_zip.go b/decompress_zip.go index 0830f7914..650a852f3 100644 --- a/decompress_zip.go +++ b/decompress_zip.go @@ -8,8 +8,8 @@ import ( "path/filepath" ) -// ZipDecompressor is an implementation of Decompressor that can -// decompress zip files. +// ZipDecompressor is an implementation of Decompressor that can decompress zip +// files. type ZipDecompressor struct{} func (d *ZipDecompressor) Decompress(dst, src string, dir bool) error { diff --git a/folder_storage.go b/folder_storage.go index 647ccf459..4481ee27a 100644 --- a/folder_storage.go +++ b/folder_storage.go @@ -1,6 +1,7 @@ package getter import ( + "context" "crypto/md5" "encoding/hex" "fmt" @@ -39,7 +40,7 @@ func (s *FolderStorage) Dir(key string) (d string, e bool, err error) { } // Get implements Storage.Get -func (s *FolderStorage) Get(key string, source string, update bool) error { +func (s *FolderStorage) Get(ctx context.Context, key string, source string, update bool) error { dir := s.dir(key) if !update { if _, err := os.Stat(dir); err == nil { @@ -54,7 +55,7 @@ func (s *FolderStorage) Get(key string, source string, update bool) error { } // Get the source. This always forces an update. - return Get(dir, source) + return Get(ctx, dir, source) } // dir returns the directory name internally that we'll use to map to diff --git a/folder_storage_test.go b/folder_storage_test.go index feb8d3425..64209aef9 100644 --- a/folder_storage_test.go +++ b/folder_storage_test.go @@ -1,6 +1,7 @@ package getter import ( + "context" "os" "path/filepath" "testing" @@ -12,6 +13,7 @@ func TestFolderStorage_impl(t *testing.T) { func TestFolderStorage(t *testing.T) { s := &FolderStorage{StorageDir: tempDir(t)} + ctx := context.Background() module := testModule("basic") @@ -27,7 +29,7 @@ func TestFolderStorage(t *testing.T) { key := "foo" // We can get it - err = s.Get(key, module, false) + err = s.Get(ctx, key, module, false) if err != nil { t.Fatalf("err: %s", err) } diff --git a/get.go b/get.go index c233763c6..2738bc4ce 100644 --- a/get.go +++ b/get.go @@ -13,6 +13,7 @@ package getter import ( "bytes" + "context" "fmt" "net/url" "os/exec" @@ -31,16 +32,16 @@ type Getter interface { // The directory may already exist (if we're updating). If it is in a // format that isn't understood, an error should be returned. Get shouldn't // simply nuke the directory. - Get(string, *url.URL) error + Get(context.Context, *Request) error // GetFile downloads the give URL into the given path. The URL must // reference a single file. If possible, the Getter should check if // the remote end contains the same file and no-op this operation. - GetFile(string, *url.URL) error + GetFile(context.Context, *Request) error // ClientMode returns the mode based on the given URL. This is used to // allow clients to let the getters decide which mode to use. - ClientMode(*url.URL) (ClientMode, error) + ClientMode(context.Context, *url.URL) (ClientMode, error) // SetClient allows a getter to know it's client // in order to access client's Get functions or @@ -59,6 +60,12 @@ var forcedRegexp = regexp.MustCompile(`^([A-Za-z0-9]+)::(.+)$`) // httpClient is the default client to be used by HttpGetters. var httpClient = cleanhttp.DefaultClient() +var DefaultClient = &Client{ + Getters: Getters, + Detectors: Detectors, + Decompressors: Decompressors, +} + func init() { httpGetter := &HttpGetter{ Netrc: true, @@ -80,13 +87,13 @@ func init() { // // src is a URL, whereas dst is always just a file path to a folder. This // folder doesn't need to exist. It will be created if it doesn't exist. -func Get(dst, src string, opts ...ClientOption) error { - return (&Client{ - Src: src, - Dst: dst, - Dir: true, - Options: opts, - }).Get() +func Get(ctx context.Context, dst, src string) error { + req := &Request{ + Src: src, + Dst: dst, + Dir: true, + } + return DefaultClient.Get(ctx, req) } // GetAny downloads a URL into the given destination. Unlike Get or @@ -95,24 +102,24 @@ func Get(dst, src string, opts ...ClientOption) error { // dst must be a directory. If src is a file, it will be downloaded // into dst with the basename of the URL. If src is a directory or // archive, it will be unpacked directly into dst. -func GetAny(dst, src string, opts ...ClientOption) error { - return (&Client{ - Src: src, - Dst: dst, - Mode: ClientModeAny, - Options: opts, - }).Get() +func GetAny(ctx context.Context, dst, src string) error { + req := &Request{ + Src: src, + Dst: dst, + Mode: ClientModeAny, + } + return DefaultClient.Get(ctx, req) } // GetFile downloads the file specified by src into the path specified by // dst. -func GetFile(dst, src string, opts ...ClientOption) error { - return (&Client{ - Src: src, - Dst: dst, - Dir: false, - Options: opts, - }).Get() +func GetFile(ctx context.Context, dst, src string) error { + req := &Request{ + Src: src, + Dst: dst, + Dir: false, + } + return DefaultClient.Get(ctx, req) } // getRunCommand is a helper that will run a command and capture the output diff --git a/get_base.go b/get_base.go index 09e9b6313..4a9ca733f 100644 --- a/get_base.go +++ b/get_base.go @@ -1,7 +1,5 @@ package getter -import "context" - // getter is our base getter; it regroups // fields all getters have in common. type getter struct { @@ -9,12 +7,3 @@ type getter struct { } func (g *getter) SetClient(c *Client) { g.client = c } - -// Context tries to returns the Contex from the getter's -// client. otherwise context.Background() is returned. -func (g *getter) Context() context.Context { - if g == nil || g.client == nil { - return context.Background() - } - return g.client.Ctx -} diff --git a/get_file.go b/get_file.go index 78660839a..eccd41210 100644 --- a/get_file.go +++ b/get_file.go @@ -1,22 +1,20 @@ package getter import ( + "context" + "fmt" "net/url" "os" + "path/filepath" ) // FileGetter is a Getter implementation that will download a module from // a file scheme. type FileGetter struct { getter - - // Copy, if set to true, will copy data instead of using a symlink. If - // false, attempts to symlink to speed up the operation and to lower the - // disk space usage. If the symlink fails, may attempt to copy on windows. - Copy bool } -func (g *FileGetter) ClientMode(u *url.URL) (ClientMode, error) { +func (g *FileGetter) ClientMode(ctx context.Context, u *url.URL) (ClientMode, error) { path := u.Path if u.RawPath != "" { path = u.RawPath @@ -34,3 +32,110 @@ func (g *FileGetter) ClientMode(u *url.URL) (ClientMode, error) { return ClientModeFile, nil } + +func (g *FileGetter) Get(ctx context.Context, req *Request) error { + path := req.u.Path + if req.u.RawPath != "" { + path = req.u.RawPath + } + + // The source path must exist and be a directory to be usable. + if fi, err := os.Stat(path); err != nil { + return fmt.Errorf("source path error: %s", err) + } else if !fi.IsDir() { + return fmt.Errorf("source path must be a directory") + } + + fi, err := os.Lstat(req.Dst) + if err != nil && !os.IsNotExist(err) { + return err + } + + // If the destination already exists, it must be a symlink + if err == nil { + mode := fi.Mode() + if mode&os.ModeSymlink == 0 { + return fmt.Errorf("destination exists and is not a symlink") + } + + // Remove the destination + if err := os.Remove(req.Dst); err != nil { + return err + } + } + + // Create all the parent directories + if err := os.MkdirAll(filepath.Dir(req.Dst), 0755); err != nil { + return err + } + + return SymlinkAny(path, req.Dst) +} + +func (g *FileGetter) GetFile(ctx context.Context, req *Request) error { + path := req.u.Path + if req.u.RawPath != "" { + path = req.u.RawPath + } + + // The source path must exist and be a file to be usable. + if fi, err := os.Stat(path); err != nil { + return fmt.Errorf("source path error: %s", err) + } else if fi.IsDir() { + return fmt.Errorf("source path must be a file") + } + + _, err := os.Lstat(req.Dst) + if err != nil && !os.IsNotExist(err) { + return err + } + + // If the destination already exists, it must be a symlink + if err == nil { + // Remove the destination + if err := os.Remove(req.Dst); err != nil { + return err + } + } + + // Create all the parent directories + if err := os.MkdirAll(filepath.Dir(req.Dst), 0755); err != nil { + return err + } + + // If we're not copying, just symlink and we're done + if !req.Copy { + if err = os.Symlink(path, req.Dst); err == nil { + return err + } + lerr, ok := err.(*os.LinkError) + if !ok { + return err + } + switch lerr.Err { + case ErrUnauthorized: + // On windows this means we don't have + // symlink privilege, let's + // fallback to a copy to avoid an error. + break + default: + return err + } + } + + // Copy + srcF, err := os.Open(path) + if err != nil { + return err + } + defer srcF.Close() + + dstF, err := os.Create(req.Dst) + if err != nil { + return err + } + defer dstF.Close() + + _, err = Copy(ctx, dstF, srcF) + return err +} diff --git a/get_file_symlink.go b/get_file_symlink.go new file mode 100644 index 000000000..12d296aca --- /dev/null +++ b/get_file_symlink.go @@ -0,0 +1,10 @@ +// +build !windows + +package getter + +import ( + "os" +) + +var ErrUnauthorized = os.ErrPermission +var SymlinkAny = os.Symlink diff --git a/get_file_symlink_windows.go b/get_file_symlink_windows.go new file mode 100644 index 000000000..5b8032a2a --- /dev/null +++ b/get_file_symlink_windows.go @@ -0,0 +1,28 @@ +package getter + +import ( + "fmt" + "os/exec" + "strings" + "syscall" +) + +func SymlinkAny(oldname, newname string) error { + sourcePath := toBackslash(oldname) + + // Use mklink to create a junction point + output, err := exec.Command("cmd", "/c", "mklink", "/J", newname, sourcePath).CombinedOutput() + if err != nil { + return fmt.Errorf("failed to run mklink %v %v: %v %q", newname, sourcePath, err, output) + } + return nil +} + +var ErrUnauthorized = syscall.ERROR_PRIVILEGE_NOT_HELD + +// toBackslash returns the result of replacing each slash character +// in path with a backslash ('\') character. Multiple separators are +// replaced by multiple backslashes. +func toBackslash(path string) string { + return strings.Replace(path, "/", "\\", -1) +} diff --git a/get_file_test.go b/get_file_test.go index 94ab3c1c1..02fa03419 100644 --- a/get_file_test.go +++ b/get_file_test.go @@ -1,6 +1,7 @@ package getter import ( + "context" "os" "path/filepath" "testing" @@ -13,9 +14,15 @@ func TestFileGetter_impl(t *testing.T) { func TestFileGetter(t *testing.T) { g := new(FileGetter) dst := tempDir(t) + ctx := context.Background() + + req := &Request{ + Dst: dst, + u: testModuleURL("basic"), + } // With a dir that doesn't exist - if err := g.Get(dst, testModuleURL("basic")); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -38,11 +45,17 @@ func TestFileGetter(t *testing.T) { func TestFileGetter_sourceFile(t *testing.T) { g := new(FileGetter) dst := tempDir(t) + ctx := context.Background() // With a source URL that is a path to a file u := testModuleURL("basic") u.Path += "/main.tf" - if err := g.Get(dst, u); err == nil { + + req := &Request{ + Dst: dst, + u: u, + } + if err := g.Get(ctx, req); err == nil { t.Fatal("should error") } } @@ -50,11 +63,17 @@ func TestFileGetter_sourceFile(t *testing.T) { func TestFileGetter_sourceNoExist(t *testing.T) { g := new(FileGetter) dst := tempDir(t) + ctx := context.Background() // With a source URL that doesn't exist u := testModuleURL("basic") u.Path += "/main" - if err := g.Get(dst, u); err == nil { + + req := &Request{ + Dst: dst, + u: u, + } + if err := g.Get(ctx, req); err == nil { t.Fatal("should error") } } @@ -62,13 +81,18 @@ func TestFileGetter_sourceNoExist(t *testing.T) { func TestFileGetter_dir(t *testing.T) { g := new(FileGetter) dst := tempDir(t) + ctx := context.Background() if err := os.MkdirAll(dst, 0755); err != nil { t.Fatalf("err: %s", err) } + req := &Request{ + Dst: dst, + u: testModuleURL("basic"), + } // With a dir that exists that isn't a symlink - if err := g.Get(dst, testModuleURL("basic")); err == nil { + if err := g.Get(ctx, req); err == nil { t.Fatal("should error") } } @@ -76,6 +100,8 @@ func TestFileGetter_dir(t *testing.T) { func TestFileGetter_dirSymlink(t *testing.T) { g := new(FileGetter) dst := tempDir(t) + ctx := context.Background() + dst2 := tempDir(t) // Make parents @@ -91,8 +117,13 @@ func TestFileGetter_dirSymlink(t *testing.T) { t.Fatalf("err: %s", err) } + req := &Request{ + Dst: dst, + u: testModuleURL("basic"), + } + // With a dir that exists that isn't a symlink - if err := g.Get(dst, testModuleURL("basic")); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -107,9 +138,15 @@ func TestFileGetter_GetFile(t *testing.T) { g := new(FileGetter) dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) + ctx := context.Background() + + req := &Request{ + Dst: dst, + u: testModuleURL("basic-file/foo.txt"), + } // With a dir that doesn't exist - if err := g.GetFile(dst, testModuleURL("basic-file/foo.txt")); err != nil { + if err := g.GetFile(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -128,13 +165,19 @@ func TestFileGetter_GetFile(t *testing.T) { func TestFileGetter_GetFile_Copy(t *testing.T) { g := new(FileGetter) - g.Copy = true dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) + ctx := context.Background() + + req := &Request{ + Dst: dst, + u: testModuleURL("basic-file/foo.txt"), + Copy: true, + } // With a dir that doesn't exist - if err := g.GetFile(dst, testModuleURL("basic-file/foo.txt")); err != nil { + if err := g.GetFile(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -155,9 +198,15 @@ func TestFileGetter_GetFile_Copy(t *testing.T) { func TestFileGetter_percent2F(t *testing.T) { g := new(FileGetter) dst := tempDir(t) + ctx := context.Background() + + req := &Request{ + Dst: dst, + u: testModuleURL("basic%2Ftest"), + } // With a dir that doesn't exist - if err := g.Get(dst, testModuleURL("basic%2Ftest")); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -170,18 +219,20 @@ func TestFileGetter_percent2F(t *testing.T) { func TestFileGetter_ClientMode_notexist(t *testing.T) { g := new(FileGetter) + ctx := context.Background() u := testURL("nonexistent") - if _, err := g.ClientMode(u); err == nil { + if _, err := g.ClientMode(ctx, u); err == nil { t.Fatal("expect source file error") } } func TestFileGetter_ClientMode_file(t *testing.T) { g := new(FileGetter) + ctx := context.Background() // Check the client mode when pointed at a file. - mode, err := g.ClientMode(testModuleURL("basic-file/foo.txt")) + mode, err := g.ClientMode(ctx, testModuleURL("basic-file/foo.txt")) if err != nil { t.Fatalf("err: %s", err) } @@ -192,9 +243,10 @@ func TestFileGetter_ClientMode_file(t *testing.T) { func TestFileGetter_ClientMode_dir(t *testing.T) { g := new(FileGetter) + ctx := context.Background() // Check the client mode when pointed at a directory. - mode, err := g.ClientMode(testModuleURL("basic")) + mode, err := g.ClientMode(ctx, testModuleURL("basic")) if err != nil { t.Fatalf("err: %s", err) } diff --git a/get_file_unix.go b/get_file_unix.go deleted file mode 100644 index c3b28ae51..000000000 --- a/get_file_unix.go +++ /dev/null @@ -1,103 +0,0 @@ -// +build !windows - -package getter - -import ( - "fmt" - "net/url" - "os" - "path/filepath" -) - -func (g *FileGetter) Get(dst string, u *url.URL) error { - path := u.Path - if u.RawPath != "" { - path = u.RawPath - } - - // The source path must exist and be a directory to be usable. - if fi, err := os.Stat(path); err != nil { - return fmt.Errorf("source path error: %s", err) - } else if !fi.IsDir() { - return fmt.Errorf("source path must be a directory") - } - - fi, err := os.Lstat(dst) - if err != nil && !os.IsNotExist(err) { - return err - } - - // If the destination already exists, it must be a symlink - if err == nil { - mode := fi.Mode() - if mode&os.ModeSymlink == 0 { - return fmt.Errorf("destination exists and is not a symlink") - } - - // Remove the destination - if err := os.Remove(dst); err != nil { - return err - } - } - - // Create all the parent directories - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { - return err - } - - return os.Symlink(path, dst) -} - -func (g *FileGetter) GetFile(dst string, u *url.URL) error { - ctx := g.Context() - path := u.Path - if u.RawPath != "" { - path = u.RawPath - } - - // The source path must exist and be a file to be usable. - if fi, err := os.Stat(path); err != nil { - return fmt.Errorf("source path error: %s", err) - } else if fi.IsDir() { - return fmt.Errorf("source path must be a file") - } - - _, err := os.Lstat(dst) - if err != nil && !os.IsNotExist(err) { - return err - } - - // If the destination already exists, it must be a symlink - if err == nil { - // Remove the destination - if err := os.Remove(dst); err != nil { - return err - } - } - - // Create all the parent directories - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { - return err - } - - // If we're not copying, just symlink and we're done - if !g.Copy { - return os.Symlink(path, dst) - } - - // Copy - srcF, err := os.Open(path) - if err != nil { - return err - } - defer srcF.Close() - - dstF, err := os.Create(dst) - if err != nil { - return err - } - defer dstF.Close() - - _, err = Copy(ctx, dstF, srcF) - return err -} diff --git a/get_file_windows.go b/get_file_windows.go deleted file mode 100644 index 24f1acb17..000000000 --- a/get_file_windows.go +++ /dev/null @@ -1,136 +0,0 @@ -// +build windows - -package getter - -import ( - "fmt" - "net/url" - "os" - "os/exec" - "path/filepath" - "strings" - "syscall" -) - -func (g *FileGetter) Get(dst string, u *url.URL) error { - ctx := g.Context() - path := u.Path - if u.RawPath != "" { - path = u.RawPath - } - - // The source path must exist and be a directory to be usable. - if fi, err := os.Stat(path); err != nil { - return fmt.Errorf("source path error: %s", err) - } else if !fi.IsDir() { - return fmt.Errorf("source path must be a directory") - } - - fi, err := os.Lstat(dst) - if err != nil && !os.IsNotExist(err) { - return err - } - - // If the destination already exists, it must be a symlink - if err == nil { - mode := fi.Mode() - if mode&os.ModeSymlink == 0 { - return fmt.Errorf("destination exists and is not a symlink") - } - - // Remove the destination - if err := os.Remove(dst); err != nil { - return err - } - } - - // Create all the parent directories - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { - return err - } - - sourcePath := toBackslash(path) - - // Use mklink to create a junction point - output, err := exec.CommandContext(ctx, "cmd", "/c", "mklink", "/J", dst, sourcePath).CombinedOutput() - if err != nil { - return fmt.Errorf("failed to run mklink %v %v: %v %q", dst, sourcePath, err, output) - } - - return nil -} - -func (g *FileGetter) GetFile(dst string, u *url.URL) error { - ctx := g.Context() - path := u.Path - if u.RawPath != "" { - path = u.RawPath - } - - // The source path must exist and be a directory to be usable. - if fi, err := os.Stat(path); err != nil { - return fmt.Errorf("source path error: %s", err) - } else if fi.IsDir() { - return fmt.Errorf("source path must be a file") - } - - _, err := os.Lstat(dst) - if err != nil && !os.IsNotExist(err) { - return err - } - - // If the destination already exists, it must be a symlink - if err == nil { - // Remove the destination - if err := os.Remove(dst); err != nil { - return err - } - } - - // Create all the parent directories - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { - return err - } - - // If we're not copying, just symlink and we're done - if !g.Copy { - if err = os.Symlink(path, dst); err == nil { - return err - } - lerr, ok := err.(*os.LinkError) - if !ok { - return err - } - switch lerr.Err { - case syscall.ERROR_PRIVILEGE_NOT_HELD: - // no symlink privilege, let's - // fallback to a copy to avoid an error. - break - default: - return err - } - } - - // Copy - srcF, err := os.Open(path) - if err != nil { - return err - } - defer srcF.Close() - - dstF, err := os.Create(dst) - if err != nil { - return err - } - defer dstF.Close() - - _, err = Copy(ctx, dstF, srcF) - return err -} - -// toBackslash returns the result of replacing each slash character -// in path with a backslash ('\') character. Multiple separators are -// replaced by multiple backslashes. -func toBackslash(path string) string { - return strings.Replace(path, "/", "\\", -1) -} diff --git a/get_gcs.go b/get_gcs.go index 6faa70f4f..f81913890 100644 --- a/get_gcs.go +++ b/get_gcs.go @@ -18,8 +18,7 @@ type GCSGetter struct { getter } -func (g *GCSGetter) ClientMode(u *url.URL) (ClientMode, error) { - ctx := g.Context() +func (g *GCSGetter) ClientMode(ctx context.Context, u *url.URL) (ClientMode, error) { // Parse URL bucket, object, err := g.parseURL(u) @@ -55,29 +54,27 @@ func (g *GCSGetter) ClientMode(u *url.URL) (ClientMode, error) { return ClientModeFile, nil } -func (g *GCSGetter) Get(dst string, u *url.URL) error { - ctx := g.Context() - +func (g *GCSGetter) Get(ctx context.Context, req *Request) error { // Parse URL - bucket, object, err := g.parseURL(u) + bucket, object, err := g.parseURL(req.u) if err != nil { return err } // Remove destination if it already exists - _, err = os.Stat(dst) + _, err = os.Stat(req.Dst) if err != nil && !os.IsNotExist(err) { return err } if err == nil { // Remove the destination - if err := os.RemoveAll(dst); err != nil { + if err := os.RemoveAll(req.Dst); err != nil { return err } } // Create all the parent directories - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(req.Dst), 0755); err != nil { return err } @@ -103,7 +100,7 @@ func (g *GCSGetter) Get(dst string, u *url.URL) error { if err != nil { return err } - objDst = filepath.Join(dst, objDst) + objDst = filepath.Join(req.Dst, objDst) // Download the matching object. err = g.getObject(ctx, client, objDst, bucket, obj.Name) if err != nil { @@ -114,11 +111,9 @@ func (g *GCSGetter) Get(dst string, u *url.URL) error { return nil } -func (g *GCSGetter) GetFile(dst string, u *url.URL) error { - ctx := g.Context() - +func (g *GCSGetter) GetFile(ctx context.Context, req *Request) error { // Parse URL - bucket, object, err := g.parseURL(u) + bucket, object, err := g.parseURL(req.u) if err != nil { return err } @@ -127,7 +122,7 @@ func (g *GCSGetter) GetFile(dst string, u *url.URL) error { if err != nil { return err } - return g.getObject(ctx, client, dst, bucket, object) + return g.getObject(ctx, client, req.Dst, bucket, object) } func (g *GCSGetter) getObject(ctx context.Context, client *storage.Client, dst, bucket, object string) error { diff --git a/get_gcs_test.go b/get_gcs_test.go index b9d972a66..54b16b09d 100644 --- a/get_gcs_test.go +++ b/get_gcs_test.go @@ -1,6 +1,7 @@ package getter import ( + "context" "net/url" "os" "path/filepath" @@ -32,10 +33,15 @@ func TestGCSGetter(t *testing.T) { g := new(GCSGetter) dst := tempDir(t) + ctx := context.Background() + + req := &Request{ + Dst: dst, + u: testURL("https://www.googleapis.com/storage/v1/go-getter-test/go-getter/folder"), + } // With a dir that doesn't exist - err := g.Get( - dst, testURL("https://www.googleapis.com/storage/v1/go-getter-test/go-getter/folder")) + err := g.Get(ctx, req) if err != nil { t.Fatalf("err: %s", err) } @@ -52,10 +58,15 @@ func TestGCSGetter_subdir(t *testing.T) { g := new(GCSGetter) dst := tempDir(t) + ctx := context.Background() + + req := &Request{ + Dst: dst, + u: testURL("https://www.googleapis.com/storage/v1/go-getter-test/go-getter/folder/subfolder"), + } // With a dir that doesn't exist - err := g.Get( - dst, testURL("https://www.googleapis.com/storage/v1/go-getter-test/go-getter/folder/subfolder")) + err := g.Get(ctx, req) if err != nil { t.Fatalf("err: %s", err) } @@ -73,10 +84,15 @@ func TestGCSGetter_GetFile(t *testing.T) { g := new(GCSGetter) dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) + ctx := context.Background() + + req := &Request{ + Dst: dst, + u: testURL("https://www.googleapis.com/storage/v1/go-getter-test/go-getter/folder/main.tf"), + } // Download - err := g.GetFile( - dst, testURL("https://www.googleapis.com/storage/v1/go-getter-test/go-getter/folder/main.tf")) + err := g.GetFile(ctx, req) if err != nil { t.Fatalf("err: %s", err) } @@ -92,10 +108,15 @@ func TestGCSGetter_GetFile_notfound(t *testing.T) { g := new(GCSGetter) dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) + ctx := context.Background() + + req := &Request{ + Dst: dst, + u: testURL("https://www.googleapis.com/storage/v1/go-getter-test/go-getter/folder/404.tf"), + } // Download - err := g.GetFile( - dst, testURL("https://www.googleapis.com/storage/v1/go-getter-test/go-getter/folder/404.tf")) + err := g.GetFile(ctx, req) if err == nil { t.Fatalf("expected error, got none") } @@ -105,9 +126,10 @@ func TestGCSGetter_ClientMode_dir(t *testing.T) { defer initGCPCredentials(t)() g := new(GCSGetter) + ctx := context.Background() // Check client mode on a key prefix with only a single key. - mode, err := g.ClientMode( + mode, err := g.ClientMode(ctx, testURL("https://www.googleapis.com/storage/v1/go-getter-test/go-getter/folder/subfolder")) if err != nil { t.Fatalf("err: %s", err) @@ -121,9 +143,10 @@ func TestGCSGetter_ClientMode_file(t *testing.T) { defer initGCPCredentials(t)() g := new(GCSGetter) + ctx := context.Background() // Check client mode on a key prefix which contains sub-keys. - mode, err := g.ClientMode( + mode, err := g.ClientMode(ctx, testURL("https://www.googleapis.com/storage/v1/go-getter-test/go-getter/folder/subfolder/sub.tf")) if err != nil { t.Fatalf("err: %s", err) @@ -137,10 +160,11 @@ func TestGCSGetter_ClientMode_notfound(t *testing.T) { defer initGCPCredentials(t)() g := new(GCSGetter) + ctx := context.Background() // Check the client mode when a non-existent key is looked up. This does not // return an error, but rather should just return the file mode. - mode, err := g.ClientMode( + mode, err := g.ClientMode(ctx, testURL("https://www.googleapis.com/storage/v1/go-getter-test/go-getter/foobar")) if err != nil { t.Fatalf("err: %s", err) diff --git a/get_git.go b/get_git.go index 1b9f4be81..e236067bc 100644 --- a/get_git.go +++ b/get_git.go @@ -28,12 +28,11 @@ type GitGetter struct { var defaultBranchRegexp = regexp.MustCompile(`\s->\sorigin/(.*)`) -func (g *GitGetter) ClientMode(_ *url.URL) (ClientMode, error) { +func (g *GitGetter) ClientMode(_ context.Context, u *url.URL) (ClientMode, error) { return ClientModeDir, nil } -func (g *GitGetter) Get(dst string, u *url.URL) error { - ctx := g.Context() +func (g *GitGetter) Get(ctx context.Context, req *Request) error { if _, err := exec.LookPath("git"); err != nil { return fmt.Errorf("git must be available and on the PATH") } @@ -44,7 +43,7 @@ func (g *GitGetter) Get(dst string, u *url.URL) error { // // This is not necessary in versions of Go which have patched // CVE-2019-14809 (e.g. Go 1.12.8+) - if portStr := u.Port(); portStr != "" { + if portStr := req.u.Port(); portStr != "" { if _, err := strconv.ParseUint(portStr, 10, 16); err != nil { return fmt.Errorf("invalid port number %q; if using the \"scp-like\" git address scheme where a colon introduces the path instead, remove the ssh:// portion and use just the git:: prefix", portStr) } @@ -53,7 +52,7 @@ func (g *GitGetter) Get(dst string, u *url.URL) error { // Extract some query parameters we use var ref, sshKey string var depth int - q := u.Query() + q := req.u.Query() if len(q) > 0 { ref = q.Get("ref") q.Del("ref") @@ -67,9 +66,9 @@ func (g *GitGetter) Get(dst string, u *url.URL) error { q.Del("depth") // Copy the URL - var newU url.URL = *u - u = &newU - u.RawQuery = q.Encode() + var newU url.URL = *req.u + req.u = &newU + req.u.RawQuery = q.Encode() } var sshKeyFile string @@ -107,14 +106,14 @@ func (g *GitGetter) Get(dst string, u *url.URL) error { } // Clone or update the repository - _, err := os.Stat(dst) + _, err := os.Stat(req.Dst) if err != nil && !os.IsNotExist(err) { return err } if err == nil { - err = g.update(ctx, dst, sshKeyFile, ref, depth) + err = g.update(ctx, req.Dst, sshKeyFile, ref, depth) } else { - err = g.clone(ctx, dst, sshKeyFile, u, depth) + err = g.clone(ctx, sshKeyFile, depth, req) } if err != nil { return err @@ -122,18 +121,18 @@ func (g *GitGetter) Get(dst string, u *url.URL) error { // Next: check out the proper tag/branch if it is specified, and checkout if ref != "" { - if err := g.checkout(dst, ref); err != nil { + if err := g.checkout(req.Dst, ref); err != nil { return err } } // Lastly, download any/all submodules. - return g.fetchSubmodules(ctx, dst, sshKeyFile, depth) + return g.fetchSubmodules(ctx, req.Dst, sshKeyFile, depth) } // GetFile for Git doesn't support updating at this time. It will download // the file every time. -func (g *GitGetter) GetFile(dst string, u *url.URL) error { +func (g *GitGetter) GetFile(ctx context.Context, req *Request) error { td, tdcloser, err := safetemp.Dir("", "getter") if err != nil { return err @@ -142,22 +141,26 @@ func (g *GitGetter) GetFile(dst string, u *url.URL) error { // Get the filename, and strip the filename from the URL so we can // just get the repository directly. - filename := filepath.Base(u.Path) - u.Path = filepath.Dir(u.Path) + filename := filepath.Base(req.u.Path) + req.u.Path = filepath.Dir(req.u.Path) + dst := req.Dst + req.Dst = td // Get the full repository - if err := g.Get(td, u); err != nil { + if err := g.Get(ctx, req); err != nil { return err } // Copy the single file - u, err = urlhelper.Parse(fmtFileURL(filepath.Join(td, filename))) + req.u, err = urlhelper.Parse(fmtFileURL(filepath.Join(td, filename))) if err != nil { return err } - fg := &FileGetter{Copy: true} - return fg.GetFile(dst, u) + fg := &FileGetter{} + req.Copy = true + req.Dst = dst + return fg.GetFile(ctx, req) } func (g *GitGetter) checkout(dst string, ref string) error { @@ -166,14 +169,14 @@ func (g *GitGetter) checkout(dst string, ref string) error { return getRunCommand(cmd) } -func (g *GitGetter) clone(ctx context.Context, dst, sshKeyFile string, u *url.URL, depth int) error { +func (g *GitGetter) clone(ctx context.Context, sshKeyFile string, depth int, req *Request) error { args := []string{"clone"} if depth > 0 { args = append(args, "--depth", strconv.Itoa(depth)) } - args = append(args, u.String(), dst) + args = append(args, req.u.String(), req.Dst) cmd := exec.CommandContext(ctx, "git", args...) setupGitEnv(cmd, sshKeyFile) return getRunCommand(cmd) diff --git a/get_git_test.go b/get_git_test.go index b1a7058de..02f9e49f7 100644 --- a/get_git_test.go +++ b/get_git_test.go @@ -2,6 +2,7 @@ package getter import ( "bytes" + "context" "encoding/base64" "io/ioutil" "net/url" @@ -31,6 +32,7 @@ func TestGitGetter(t *testing.T) { if !testHasGit { t.Skip("git not found, skipping") } + ctx := context.Background() g := new(GitGetter) dst := tempDir(t) @@ -38,8 +40,13 @@ func TestGitGetter(t *testing.T) { repo := testGitRepo(t, "basic") repo.commitFile("foo.txt", "hello") + req := &Request{ + Dst: dst, + u: repo.url, + } + // With a dir that doesn't exist - if err := g.Get(dst, repo.url); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -54,6 +61,7 @@ func TestGitGetter_branch(t *testing.T) { if !testHasGit { t.Skip("git not found, skipping") } + ctx := context.Background() g := new(GitGetter) dst := tempDir(t) @@ -66,7 +74,10 @@ func TestGitGetter_branch(t *testing.T) { q.Add("ref", "test-branch") repo.url.RawQuery = q.Encode() - if err := g.Get(dst, repo.url); err != nil { + if err := g.Get(ctx, &Request{ + Dst: dst, + u: repo.url, + }); err != nil { t.Fatalf("err: %s", err) } @@ -77,7 +88,10 @@ func TestGitGetter_branch(t *testing.T) { } // Get again should work - if err := g.Get(dst, repo.url); err != nil { + if err := g.Get(ctx, &Request{ + Dst: dst, + u: repo.url, + }); err != nil { t.Fatalf("err: %s", err) } @@ -94,6 +108,7 @@ func TestGitGetter_remoteWithoutMaster(t *testing.T) { t.Skip() } + ctx := context.Background() g := new(GitGetter) dst := tempDir(t) @@ -103,8 +118,13 @@ func TestGitGetter_remoteWithoutMaster(t *testing.T) { q := repo.url.Query() repo.url.RawQuery = q.Encode() + req := &Request{ + Src: repo.url.String(), + u: repo.url, + Dst: dst, + } - if err := g.Get(dst, repo.url); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -115,7 +135,7 @@ func TestGitGetter_remoteWithoutMaster(t *testing.T) { } // Get again should work - if err := g.Get(dst, repo.url); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -131,6 +151,7 @@ func TestGitGetter_shallowClone(t *testing.T) { t.Log("git not found, skipping") t.Skip() } + ctx := context.Background() g := new(GitGetter) dst := tempDir(t) @@ -144,7 +165,10 @@ func TestGitGetter_shallowClone(t *testing.T) { q.Add("depth", "1") repo.url.RawQuery = q.Encode() - if err := g.Get(dst, repo.url); err != nil { + if err := g.Get(ctx, &Request{ + Dst: dst, + u: repo.url, + }); err != nil { t.Fatalf("err: %s", err) } @@ -166,6 +190,7 @@ func TestGitGetter_branchUpdate(t *testing.T) { if !testHasGit { t.Skip("git not found, skipping") } + ctx := context.Background() g := new(GitGetter) dst := tempDir(t) @@ -179,7 +204,11 @@ func TestGitGetter_branchUpdate(t *testing.T) { q := repo.url.Query() q.Add("ref", "test-branch") repo.url.RawQuery = q.Encode() - if err := g.Get(dst, repo.url); err != nil { + + if err := g.Get(ctx, &Request{ + Dst: dst, + u: repo.url, + }); err != nil { t.Fatalf("err: %s", err) } @@ -193,7 +222,10 @@ func TestGitGetter_branchUpdate(t *testing.T) { repo.commitFile("branch-update.txt", "branch-update") // Get again should work - if err := g.Get(dst, repo.url); err != nil { + if err := g.Get(ctx, &Request{ + Dst: dst, + u: repo.url, + }); err != nil { t.Fatalf("err: %s", err) } @@ -208,6 +240,7 @@ func TestGitGetter_tag(t *testing.T) { if !testHasGit { t.Skip("git not found, skipping") } + ctx := context.Background() g := new(GitGetter) dst := tempDir(t) @@ -220,7 +253,10 @@ func TestGitGetter_tag(t *testing.T) { q.Add("ref", "v1.0") repo.url.RawQuery = q.Encode() - if err := g.Get(dst, repo.url); err != nil { + if err := g.Get(ctx, &Request{ + Dst: dst, + u: repo.url, + }); err != nil { t.Fatalf("err: %s", err) } @@ -231,7 +267,10 @@ func TestGitGetter_tag(t *testing.T) { } // Get again should work - if err := g.Get(dst, repo.url); err != nil { + if err := g.Get(ctx, &Request{ + Dst: dst, + u: repo.url, + }); err != nil { t.Fatalf("err: %s", err) } @@ -246,6 +285,7 @@ func TestGitGetter_GetFile(t *testing.T) { if !testHasGit { t.Skip("git not found, skipping") } + ctx := context.Background() g := new(GitGetter) dst := tempTestFile(t) @@ -256,7 +296,12 @@ func TestGitGetter_GetFile(t *testing.T) { // Download the file repo.url.Path = filepath.Join(repo.url.Path, "file.txt") - if err := g.GetFile(dst, repo.url); err != nil { + req := &Request{ + Dst: dst, + u: repo.url, + } + + if err := g.GetFile(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -310,6 +355,7 @@ func TestGitGetter_sshKey(t *testing.T) { if !testHasGit { t.Skip("git not found, skipping") } + ctx := context.Background() g := new(GitGetter) dst := tempDir(t) @@ -326,7 +372,12 @@ func TestGitGetter_sshKey(t *testing.T) { t.Fatal(err) } - if err := g.Get(dst, u); err != nil { + req := &Request{ + Dst: dst, + u: u, + } + + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -341,6 +392,7 @@ func TestGitGetter_sshSCPStyle(t *testing.T) { t.Skip("git not found, skipping") } + ctx := context.Background() g := new(GitGetter) dst := tempDir(t) @@ -352,13 +404,14 @@ func TestGitGetter_sshSCPStyle(t *testing.T) { // This test exercises the combination of the git detector and the // git getter, to make sure that together they make scp-style URLs work. - client := &Client{ + req := &Request{ Src: "git@github.com:hashicorp/test-private-repo?sshkey=" + encodedKey, Dst: dst, Pwd: ".", Mode: ClientModeDir, - + } + client := &Client{ Detectors: []Detector{ new(GitDetector), }, @@ -367,7 +420,7 @@ func TestGitGetter_sshSCPStyle(t *testing.T) { }, } - if err := client.Get(); err != nil { + if err := client.Get(ctx, req); err != nil { t.Fatalf("client.Get failed: %s", err) } @@ -382,6 +435,7 @@ func TestGitGetter_sshExplicitPort(t *testing.T) { t.Skip("git not found, skipping") } + ctx := context.Background() g := new(GitGetter) dst := tempDir(t) @@ -393,12 +447,14 @@ func TestGitGetter_sshExplicitPort(t *testing.T) { // This test exercises the combination of the git detector and the // git getter, to make sure that together they make scp-style URLs work. - client := &Client{ + req := &Request{ Src: "git::ssh://git@github.com:22/hashicorp/test-private-repo?sshkey=" + encodedKey, Dst: dst, Pwd: ".", Mode: ClientModeDir, + } + client := &Client{ Detectors: []Detector{ new(GitDetector), @@ -408,7 +464,7 @@ func TestGitGetter_sshExplicitPort(t *testing.T) { }, } - if err := client.Get(); err != nil { + if err := client.Get(ctx, req); err != nil { t.Fatalf("client.Get failed: %s", err) } @@ -423,6 +479,7 @@ func TestGitGetter_sshSCPStyleInvalidScheme(t *testing.T) { t.Skip("git not found, skipping") } + ctx := context.Background() g := new(GitGetter) dst := tempDir(t) @@ -434,13 +491,15 @@ func TestGitGetter_sshSCPStyleInvalidScheme(t *testing.T) { // This test exercises the combination of the git detector and the // git getter, to make sure that together they make scp-style URLs work. - client := &Client{ + req := &Request{ Src: "git::ssh://git@github.com:hashicorp/test-private-repo?sshkey=" + encodedKey, Dst: dst, Pwd: ".", Mode: ClientModeDir, + } + client := &Client{ Detectors: []Detector{ new(GitDetector), }, @@ -449,7 +508,7 @@ func TestGitGetter_sshSCPStyleInvalidScheme(t *testing.T) { }, } - err := client.Get() + err := client.Get(ctx, req) if err == nil { t.Fatalf("get succeeded; want error") } @@ -465,6 +524,7 @@ func TestGitGetter_submodule(t *testing.T) { if !testHasGit { t.Skip("git not found, skipping") } + ctx := context.Background() g := new(GitGetter) dst := tempDir(t) @@ -495,8 +555,13 @@ func TestGitGetter_submodule(t *testing.T) { p.git("submodule", "add", "-f", relpath(p.dir, c.dir)) p.git("commit", "-m", "Add child submodule") + req := &Request{ + Dst: dst, + u: p.url, + } + // Clone the root repository - if err := g.Get(dst, p.url); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } diff --git a/get_hg.go b/get_hg.go index 290649c91..365c3caae 100644 --- a/get_hg.go +++ b/get_hg.go @@ -19,17 +19,16 @@ type HgGetter struct { getter } -func (g *HgGetter) ClientMode(_ *url.URL) (ClientMode, error) { +func (g *HgGetter) ClientMode(ctx context.Context, _ *url.URL) (ClientMode, error) { return ClientModeDir, nil } -func (g *HgGetter) Get(dst string, u *url.URL) error { - ctx := g.Context() +func (g *HgGetter) Get(ctx context.Context, req *Request) error { if _, err := exec.LookPath("hg"); err != nil { return fmt.Errorf("hg must be available and on the PATH") } - newURL, err := urlhelper.Parse(u.String()) + newURL, err := urlhelper.Parse(req.u.String()) if err != nil { return err } @@ -48,26 +47,26 @@ func (g *HgGetter) Get(dst string, u *url.URL) error { newURL.RawQuery = q.Encode() } - _, err = os.Stat(dst) + _, err = os.Stat(req.Dst) if err != nil && !os.IsNotExist(err) { return err } if err != nil { - if err := g.clone(dst, newURL); err != nil { + if err := g.clone(req.Dst, newURL); err != nil { return err } } - if err := g.pull(dst, newURL); err != nil { + if err := g.pull(req.Dst, newURL); err != nil { return err } - return g.update(ctx, dst, newURL, rev) + return g.update(ctx, req.Dst, newURL, rev) } // GetFile for Hg doesn't support updating at this time. It will download // the file every time. -func (g *HgGetter) GetFile(dst string, u *url.URL) error { +func (g *HgGetter) GetFile(ctx context.Context, req *Request) error { // Create a temporary directory to store the full source. This has to be // a non-existent directory. td, tdcloser, err := safetemp.Dir("", "getter") @@ -78,27 +77,31 @@ func (g *HgGetter) GetFile(dst string, u *url.URL) error { // Get the filename, and strip the filename from the URL so we can // just get the repository directly. - filename := filepath.Base(u.Path) - u.Path = filepath.ToSlash(filepath.Dir(u.Path)) + filename := filepath.Base(req.u.Path) + req.u.Path = filepath.Dir(req.u.Path) + dst := req.Dst + req.Dst = td // If we're on Windows, we need to set the host to "localhost" for hg if runtime.GOOS == "windows" { - u.Host = "localhost" + req.u.Host = "localhost" } // Get the full repository - if err := g.Get(td, u); err != nil { + if err := g.Get(ctx, req); err != nil { return err } // Copy the single file - u, err = urlhelper.Parse(fmtFileURL(filepath.Join(td, filename))) + req.u, err = urlhelper.Parse(fmtFileURL(filepath.Join(td, filename))) if err != nil { return err } - fg := &FileGetter{Copy: true, getter: g.getter} - return fg.GetFile(dst, u) + fg := &FileGetter{} + req.Copy = true + req.Dst = dst + return fg.GetFile(ctx, req) } func (g *HgGetter) clone(dst string, u *url.URL) error { diff --git a/get_hg_test.go b/get_hg_test.go index ee1657945..442693dfa 100644 --- a/get_hg_test.go +++ b/get_hg_test.go @@ -1,6 +1,7 @@ package getter import ( + "context" "os" "os/exec" "path/filepath" @@ -24,12 +25,18 @@ func TestHgGetter(t *testing.T) { t.Log("hg not found, skipping") t.Skip() } + ctx := context.Background() g := new(HgGetter) dst := tempDir(t) + req := &Request{ + Dst: dst, + u: testModuleURL("basic-hg"), + } + // With a dir that doesn't exist - if err := g.Get(dst, testModuleURL("basic-hg")); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -45,6 +52,7 @@ func TestHgGetter_branch(t *testing.T) { t.Log("hg not found, skipping") t.Skip() } + ctx := context.Background() g := new(HgGetter) dst := tempDir(t) @@ -54,7 +62,12 @@ func TestHgGetter_branch(t *testing.T) { q.Add("rev", "test-branch") url.RawQuery = q.Encode() - if err := g.Get(dst, url); err != nil { + req := &Request{ + Dst: dst, + u: url, + } + + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -65,7 +78,7 @@ func TestHgGetter_branch(t *testing.T) { } // Get again should work - if err := g.Get(dst, url); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -81,13 +94,19 @@ func TestHgGetter_GetFile(t *testing.T) { t.Log("hg not found, skipping") t.Skip() } + ctx := context.Background() g := new(HgGetter) dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) + req := &Request{ + Dst: dst, + u: testModuleURL("basic-hg/foo.txt"), + } + // Download - if err := g.GetFile(dst, testModuleURL("basic-hg/foo.txt")); err != nil { + if err := g.GetFile(ctx, req); err != nil { t.Fatalf("err: %s", err) } diff --git a/get_http.go b/get_http.go index 618a411f9..cfcc6b669 100644 --- a/get_http.go +++ b/get_http.go @@ -52,22 +52,21 @@ type HttpGetter struct { Header http.Header } -func (g *HttpGetter) ClientMode(u *url.URL) (ClientMode, error) { +func (g *HttpGetter) ClientMode(ctx context.Context, u *url.URL) (ClientMode, error) { if strings.HasSuffix(u.Path, "/") { return ClientModeDir, nil } return ClientModeFile, nil } -func (g *HttpGetter) Get(dst string, u *url.URL) error { - ctx := g.Context() +func (g *HttpGetter) Get(ctx context.Context, req *Request) error { // Copy the URL so we can modify it - var newU url.URL = *u - u = &newU + var newU url.URL = *req.u + req.u = &newU if g.Netrc { // Add auth from netrc if we can - if err := addAuthFromNetrc(u); err != nil { + if err := addAuthFromNetrc(req.u); err != nil { return err } } @@ -77,21 +76,20 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error { } // Add terraform-get to the parameter. - q := u.Query() + q := req.u.Query() q.Add("terraform-get", "1") - u.RawQuery = q.Encode() + req.u.RawQuery = q.Encode() // Get the URL - req, err := http.NewRequest("GET", u.String(), nil) + httpReq, err := http.NewRequest("GET", req.u.String(), nil) if err != nil { return err } if g.Header != nil { - req.Header = g.Header.Clone() + httpReq.Header = g.Header.Clone() } - - resp, err := g.Client.Do(req) + resp, err := g.Client.Do(httpReq) if err != nil { return err } @@ -118,16 +116,16 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error { // If there is a subdir component, then we download the root separately // into a temporary directory, then copy over the proper subdir. source, subDir := SourceDirSubdir(source) + req = &Request{ + Dir: true, + Src: source, + Dst: req.Dst, + } if subDir == "" { - var opts []ClientOption - if g.client != nil { - opts = g.client.Options - } - return Get(dst, source, opts...) + return DefaultClient.Get(ctx, req) } - // We have a subdir, time to jump some hoops - return g.getSubdir(ctx, dst, source, subDir) + return g.getSubdir(ctx, req.Dst, source, subDir) } // GetFile fetches the file from src and stores it at dst. @@ -136,20 +134,19 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error { // older version of the destination file does not exist, else it will be either // falsely identified as being replaced, or corrupted with extra bytes // appended. -func (g *HttpGetter) GetFile(dst string, src *url.URL) error { - ctx := g.Context() +func (g *HttpGetter) GetFile(ctx context.Context, req *Request) error { if g.Netrc { // Add auth from netrc if we can - if err := addAuthFromNetrc(src); err != nil { + if err := addAuthFromNetrc(req.u); err != nil { return err } } // Create all the parent directories if needed - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(req.Dst), 0755); err != nil { return err } - f, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE, os.FileMode(0666)) + f, err := os.OpenFile(req.Dst, os.O_RDWR|os.O_CREATE, os.FileMode(0666)) if err != nil { return err } @@ -164,14 +161,14 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error { // We first make a HEAD request so we can check // if the server supports range queries. If the server/URL doesn't // support HEAD requests, we just fall back to GET. - req, err := http.NewRequest("HEAD", src.String(), nil) + httpReq, err := http.NewRequest("HEAD", req.u.String(), nil) if err != nil { return err } if g.Header != nil { - req.Header = g.Header.Clone() + httpReq.Header = g.Header.Clone() } - headResp, err := g.Client.Do(req) + headResp, err := g.Client.Do(httpReq) if err == nil { headResp.Body.Close() if headResp.StatusCode == 200 { @@ -181,7 +178,7 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error { if fi, err := f.Stat(); err == nil { if _, err = f.Seek(0, io.SeekEnd); err == nil { currentFileSize = fi.Size() - req.Header.Set("Range", fmt.Sprintf("bytes=%d-", currentFileSize)) + httpReq.Header.Set("Range", fmt.Sprintf("bytes=%d-", currentFileSize)) if currentFileSize >= headResp.ContentLength { // file already present return nil @@ -191,9 +188,9 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error { } } } - req.Method = "GET" + httpReq.Method = "GET" - resp, err := g.Client.Do(req) + resp, err := g.Client.Do(httpReq) if err != nil { return err } @@ -207,10 +204,10 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error { body := resp.Body - if g.client != nil && g.client.ProgressListener != nil { + if req.ProgressListener != nil { // track download - fn := filepath.Base(src.EscapedPath()) - body = g.client.ProgressListener.TrackProgress(fn, currentFileSize, currentFileSize+resp.ContentLength, resp.Body) + fn := filepath.Base(req.u.EscapedPath()) + body = req.ProgressListener.TrackProgress(fn, currentFileSize, currentFileSize+resp.ContentLength, resp.Body) } defer body.Close() @@ -232,12 +229,8 @@ func (g *HttpGetter) getSubdir(ctx context.Context, dst, source, subDir string) } defer tdcloser.Close() - var opts []ClientOption - if g.client != nil { - opts = g.client.Options - } // Download that into the given directory - if err := Get(td, source, opts...); err != nil { + if err := Get(ctx, td, source); err != nil { return err } diff --git a/get_http_test.go b/get_http_test.go index 6d7fb90ac..d1fd3c012 100644 --- a/get_http_test.go +++ b/get_http_test.go @@ -1,6 +1,7 @@ package getter import ( + "context" "crypto/sha256" "encoding/hex" "errors" @@ -23,6 +24,7 @@ func TestHttpGetter_impl(t *testing.T) { func TestHttpGetter_header(t *testing.T) { ln := testHttpServer(t) defer ln.Close() + ctx := context.Background() g := new(HttpGetter) dst := tempDir(t) @@ -33,8 +35,13 @@ func TestHttpGetter_header(t *testing.T) { u.Host = ln.Addr().String() u.Path = "/header" + req := &Request{ + Dst: dst, + u: &u, + } + // Get it! - if err := g.Get(dst, &u); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -48,6 +55,7 @@ func TestHttpGetter_header(t *testing.T) { func TestHttpGetter_requestHeader(t *testing.T) { ln := testHttpServer(t) defer ln.Close() + ctx := context.Background() g := new(HttpGetter) g.Header = make(http.Header) @@ -61,8 +69,13 @@ func TestHttpGetter_requestHeader(t *testing.T) { u.Path = "/expect-header" u.RawQuery = "expected=X-Foobar" + req := &Request{ + Dst: dst, + u: &u, + } + // Get it! - if err := g.GetFile(dst, &u); err != nil { + if err := g.GetFile(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -76,6 +89,7 @@ func TestHttpGetter_requestHeader(t *testing.T) { func TestHttpGetter_meta(t *testing.T) { ln := testHttpServer(t) defer ln.Close() + ctx := context.Background() g := new(HttpGetter) dst := tempDir(t) @@ -86,8 +100,13 @@ func TestHttpGetter_meta(t *testing.T) { u.Host = ln.Addr().String() u.Path = "/meta" + req := &Request{ + Dst: dst, + u: &u, + } + // Get it! - if err := g.Get(dst, &u); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -101,6 +120,7 @@ func TestHttpGetter_meta(t *testing.T) { func TestHttpGetter_metaSubdir(t *testing.T) { ln := testHttpServer(t) defer ln.Close() + ctx := context.Background() g := new(HttpGetter) dst := tempDir(t) @@ -111,8 +131,13 @@ func TestHttpGetter_metaSubdir(t *testing.T) { u.Host = ln.Addr().String() u.Path = "/meta-subdir" + req := &Request{ + Dst: dst, + u: &u, + } + // Get it! - if err := g.Get(dst, &u); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -126,6 +151,7 @@ func TestHttpGetter_metaSubdir(t *testing.T) { func TestHttpGetter_metaSubdirGlob(t *testing.T) { ln := testHttpServer(t) defer ln.Close() + ctx := context.Background() g := new(HttpGetter) dst := tempDir(t) @@ -136,8 +162,13 @@ func TestHttpGetter_metaSubdirGlob(t *testing.T) { u.Host = ln.Addr().String() u.Path = "/meta-subdir-glob" + req := &Request{ + Dst: dst, + u: &u, + } + // Get it! - if err := g.Get(dst, &u); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -151,6 +182,7 @@ func TestHttpGetter_metaSubdirGlob(t *testing.T) { func TestHttpGetter_none(t *testing.T) { ln := testHttpServer(t) defer ln.Close() + ctx := context.Background() g := new(HttpGetter) dst := tempDir(t) @@ -161,8 +193,13 @@ func TestHttpGetter_none(t *testing.T) { u.Host = ln.Addr().String() u.Path = "/none" + req := &Request{ + Dst: dst, + u: &u, + } + // Get it! - if err := g.Get(dst, &u); err == nil { + if err := g.Get(ctx, req); err == nil { t.Fatal("should error") } } @@ -201,9 +238,10 @@ func TestHttpGetter_resume(t *testing.T) { RawQuery: "checksum=" + checksum, } t.Logf("url: %s", u.String()) + ctx := context.Background() // Finish getting it! - if err := GetFile(dst, u.String()); err != nil { + if err := GetFile(ctx, dst, u.String()); err != nil { t.Fatalf("finishing download should not error: %v", err) } @@ -217,7 +255,7 @@ func TestHttpGetter_resume(t *testing.T) { } // Get it again - if err := GetFile(dst, u.String()); err != nil { + if err := GetFile(ctx, dst, u.String()); err != nil { t.Fatalf("should not error: %v", err) } } @@ -257,9 +295,10 @@ func TestHttpGetter_resumeNoRange(t *testing.T) { RawQuery: "checksum=" + checksum, } t.Logf("url: %s", u.String()) + ctx := context.Background() // Finish getting it! - if err := GetFile(dst, u.String()); err != nil { + if err := GetFile(ctx, dst, u.String()); err != nil { t.Fatalf("finishing download should not error: %v", err) } @@ -276,6 +315,7 @@ func TestHttpGetter_resumeNoRange(t *testing.T) { func TestHttpGetter_file(t *testing.T) { ln := testHttpServer(t) defer ln.Close() + ctx := context.Background() g := new(HttpGetter) dst := tempTestFile(t) @@ -286,8 +326,13 @@ func TestHttpGetter_file(t *testing.T) { u.Host = ln.Addr().String() u.Path = "/file" + req := &Request{ + Dst: dst, + u: &u, + } + // Get it! - if err := g.GetFile(dst, &u); err != nil { + if err := g.GetFile(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -301,6 +346,7 @@ func TestHttpGetter_file(t *testing.T) { func TestHttpGetter_auth(t *testing.T) { ln := testHttpServer(t) defer ln.Close() + ctx := context.Background() g := new(HttpGetter) dst := tempDir(t) @@ -312,8 +358,13 @@ func TestHttpGetter_auth(t *testing.T) { u.Path = "/meta-auth" u.User = url.UserPassword("foo", "bar") + req := &Request{ + Dst: dst, + u: &u, + } + // Get it! - if err := g.Get(dst, &u); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -327,6 +378,7 @@ func TestHttpGetter_auth(t *testing.T) { func TestHttpGetter_authNetrc(t *testing.T) { ln := testHttpServer(t) defer ln.Close() + ctx := context.Background() g := new(HttpGetter) dst := tempDir(t) @@ -342,8 +394,13 @@ func TestHttpGetter_authNetrc(t *testing.T) { defer closer() defer tempEnv(t, "NETRC", path)() + req := &Request{ + Dst: dst, + u: &u, + } + // Get it! - if err := g.Get(dst, &u); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -371,6 +428,7 @@ func TestHttpGetter_cleanhttp(t *testing.T) { defer func() { http.DefaultClient.Transport = http.DefaultTransport }() + ctx := context.Background() g := new(HttpGetter) dst := tempDir(t) @@ -381,8 +439,13 @@ func TestHttpGetter_cleanhttp(t *testing.T) { u.Host = ln.Addr().String() u.Path = "/header" + req := &Request{ + Dst: dst, + u: &u, + } + // Get it! - if err := g.Get(dst, &u); err != nil { + if err := g.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } } diff --git a/get_mock.go b/get_mock.go index e2a98ea28..786677e35 100644 --- a/get_mock.go +++ b/get_mock.go @@ -1,6 +1,7 @@ package getter import ( + "context" "net/url" ) @@ -23,30 +24,30 @@ type MockGetter struct { GetFileErr error } -func (g *MockGetter) Get(dst string, u *url.URL) error { +func (g *MockGetter) Get(ctx context.Context, req *Request) error { g.GetCalled = true - g.GetDst = dst - g.GetURL = u + g.GetDst = req.Dst + g.GetURL = req.u if g.Proxy != nil { - return g.Proxy.Get(dst, u) + return g.Proxy.Get(ctx, req) } return g.GetErr } -func (g *MockGetter) GetFile(dst string, u *url.URL) error { +func (g *MockGetter) GetFile(ctx context.Context, req *Request) error { g.GetFileCalled = true - g.GetFileDst = dst - g.GetFileURL = u + g.GetFileDst = req.Dst + g.GetFileURL = req.u if g.Proxy != nil { - return g.Proxy.GetFile(dst, u) + return g.Proxy.GetFile(ctx, req) } return g.GetFileErr } -func (g *MockGetter) ClientMode(u *url.URL) (ClientMode, error) { +func (g *MockGetter) ClientMode(ctx context.Context, u *url.URL) (ClientMode, error) { if l := len(u.Path); l > 0 && u.Path[l-1:] == "/" { return ClientModeDir, nil } diff --git a/get_s3.go b/get_s3.go index 93eeb0b81..d45962a57 100644 --- a/get_s3.go +++ b/get_s3.go @@ -22,7 +22,7 @@ type S3Getter struct { getter } -func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) { +func (g *S3Getter) ClientMode(ctx context.Context, u *url.URL) (ClientMode, error) { // Parse URL region, bucket, path, _, creds, err := g.parseUrl(u) if err != nil { @@ -61,34 +61,33 @@ func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) { return ClientModeFile, nil } -func (g *S3Getter) Get(dst string, u *url.URL) error { - ctx := g.Context() +func (g *S3Getter) Get(ctx context.Context, req *Request) error { // Parse URL - region, bucket, path, _, creds, err := g.parseUrl(u) + region, bucket, path, _, creds, err := g.parseUrl(req.u) if err != nil { return err } // Remove destination if it already exists - _, err = os.Stat(dst) + _, err = os.Stat(req.Dst) if err != nil && !os.IsNotExist(err) { return err } if err == nil { // Remove the destination - if err := os.RemoveAll(dst); err != nil { + if err := os.RemoveAll(req.Dst); err != nil { return err } } // Create all the parent directories - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + if err := os.MkdirAll(filepath.Dir(req.Dst), 0755); err != nil { return err } - config := g.getAWSConfig(region, u, creds) + config := g.getAWSConfig(region, req.u, creds) sess := session.New(config) client := s3.New(sess) @@ -96,15 +95,15 @@ func (g *S3Getter) Get(dst string, u *url.URL) error { lastMarker := "" hasMore := true for hasMore { - req := &s3.ListObjectsInput{ + s3Req := &s3.ListObjectsInput{ Bucket: aws.String(bucket), Prefix: aws.String(path), } if lastMarker != "" { - req.Marker = aws.String(lastMarker) + s3Req.Marker = aws.String(lastMarker) } - resp, err := client.ListObjects(req) + resp, err := client.ListObjects(s3Req) if err != nil { return err } @@ -126,7 +125,7 @@ func (g *S3Getter) Get(dst string, u *url.URL) error { if err != nil { return err } - objDst = filepath.Join(dst, objDst) + objDst = filepath.Join(req.Dst, objDst) if err := g.getObject(ctx, client, objDst, bucket, objPath, ""); err != nil { return err @@ -137,17 +136,16 @@ func (g *S3Getter) Get(dst string, u *url.URL) error { return nil } -func (g *S3Getter) GetFile(dst string, u *url.URL) error { - ctx := g.Context() - region, bucket, path, version, creds, err := g.parseUrl(u) +func (g *S3Getter) GetFile(ctx context.Context, req *Request) error { + region, bucket, path, version, creds, err := g.parseUrl(req.u) if err != nil { return err } - config := g.getAWSConfig(region, u, creds) + config := g.getAWSConfig(region, req.u, creds) sess := session.New(config) client := s3.New(sess) - return g.getObject(ctx, client, dst, bucket, path, version) + return g.getObject(ctx, client, req.Dst, bucket, path, version) } func (g *S3Getter) getObject(ctx context.Context, client *s3.S3, dst, bucket, key, version string) error { diff --git a/get_s3_test.go b/get_s3_test.go index e2233a28b..303b1ce04 100644 --- a/get_s3_test.go +++ b/get_s3_test.go @@ -1,6 +1,7 @@ package getter import ( + "context" "net/url" "os" "path/filepath" @@ -26,12 +27,16 @@ func TestS3Getter_impl(t *testing.T) { } func TestS3Getter(t *testing.T) { + ctx := context.Background() + g := new(S3Getter) dst := tempDir(t) - + req := &Request{ + Dst: dst, + u: testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder"), + } // With a dir that doesn't exist - err := g.Get( - dst, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder")) + err := g.Get(ctx, req) if err != nil { t.Fatalf("err: %s", err) } @@ -44,12 +49,16 @@ func TestS3Getter(t *testing.T) { } func TestS3Getter_subdir(t *testing.T) { + ctx := context.Background() + g := new(S3Getter) dst := tempDir(t) - + req := &Request{ + Dst: dst, + u: testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/subfolder"), + } // With a dir that doesn't exist - err := g.Get( - dst, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/subfolder")) + err := g.Get(ctx, req) if err != nil { t.Fatalf("err: %s", err) } @@ -62,13 +71,18 @@ func TestS3Getter_subdir(t *testing.T) { } func TestS3Getter_GetFile(t *testing.T) { + ctx := context.Background() + g := new(S3Getter) dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) + req := &Request{ + Dst: dst, + u: testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/main.tf"), + } // Download - err := g.GetFile( - dst, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/main.tf")) + err := g.GetFile(ctx, req) if err != nil { t.Fatalf("err: %s", err) } @@ -81,14 +95,19 @@ func TestS3Getter_GetFile(t *testing.T) { } func TestS3Getter_GetFile_badParams(t *testing.T) { + ctx := context.Background() + g := new(S3Getter) dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) + req := &Request{ + Dst: dst, + u: testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/main.tf?aws_access_key_id=foo&aws_access_key_secret=bar&aws_access_token=baz"), + } + // Download - err := g.GetFile( - dst, - testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/main.tf?aws_access_key_id=foo&aws_access_key_secret=bar&aws_access_token=baz")) + err := g.GetFile(ctx, req) if err == nil { t.Fatalf("expected error, got none") } @@ -99,23 +118,31 @@ func TestS3Getter_GetFile_badParams(t *testing.T) { } func TestS3Getter_GetFile_notfound(t *testing.T) { + ctx := context.Background() + g := new(S3Getter) dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) + req := &Request{ + Dst: dst, + u: testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/404.tf"), + } + // Download - err := g.GetFile( - dst, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/404.tf")) + err := g.GetFile(ctx, req) if err == nil { t.Fatalf("expected error, got none") } } func TestS3Getter_ClientMode_dir(t *testing.T) { + ctx := context.Background() + g := new(S3Getter) // Check client mode on a key prefix with only a single key. - mode, err := g.ClientMode( + mode, err := g.ClientMode(ctx, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder")) if err != nil { t.Fatalf("err: %s", err) @@ -126,10 +153,12 @@ func TestS3Getter_ClientMode_dir(t *testing.T) { } func TestS3Getter_ClientMode_file(t *testing.T) { + ctx := context.Background() + g := new(S3Getter) // Check client mode on a key prefix which contains sub-keys. - mode, err := g.ClientMode( + mode, err := g.ClientMode(ctx, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/folder/main.tf")) if err != nil { t.Fatalf("err: %s", err) @@ -140,6 +169,8 @@ func TestS3Getter_ClientMode_file(t *testing.T) { } func TestS3Getter_ClientMode_notfound(t *testing.T) { + ctx := context.Background() + g := new(S3Getter) // Check the client mode when a non-existent key is looked up. This does not @@ -147,7 +178,7 @@ func TestS3Getter_ClientMode_notfound(t *testing.T) { // can return an appropriate error later on. This also checks that the // prefix is handled properly (e.g., "/fold" and "/folder" don't put the // client mode into "dir". - mode, err := g.ClientMode( + mode, err := g.ClientMode(ctx, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/fold")) if err != nil { t.Fatalf("err: %s", err) @@ -158,11 +189,13 @@ func TestS3Getter_ClientMode_notfound(t *testing.T) { } func TestS3Getter_ClientMode_collision(t *testing.T) { + ctx := context.Background() + g := new(S3Getter) // Check that the client mode is "file" if there is both an object and a // folder with a common prefix (i.e., a "collision" in the namespace). - mode, err := g.ClientMode( + mode, err := g.ClientMode(ctx, testURL("https://s3.amazonaws.com/hc-oss-test/go-getter/collision/foo")) if err != nil { t.Fatalf("err: %s", err) @@ -217,6 +250,7 @@ func TestS3Getter_Url(t *testing.T) { for i, pt := range s3tests { t.Run(pt.name, func(t *testing.T) { + g := new(S3Getter) forced, src := getForcedGetter(pt.url) u, err := url.Parse(src) diff --git a/get_test.go b/get_test.go index d70aed881..58d610a05 100644 --- a/get_test.go +++ b/get_test.go @@ -1,6 +1,7 @@ package getter import ( + "context" "os" "path/filepath" "strings" @@ -8,20 +9,24 @@ import ( ) func TestGet_badSchema(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := testModule("basic") u = strings.Replace(u, "file", "nope", -1) - if err := Get(dst, u); err == nil { + if err := Get(ctx, dst, u); err == nil { t.Fatal("should error") } } func TestGet_file(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := testModule("basic") - if err := Get(dst, u); err != nil { + if err := Get(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -33,10 +38,12 @@ func TestGet_file(t *testing.T) { // https://github.com/hashicorp/terraform/issues/11438 func TestGet_fileDecompressorExt(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := testModule("basic-tgz") - if err := Get(dst, u); err != nil { + if err := Get(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -48,10 +55,12 @@ func TestGet_fileDecompressorExt(t *testing.T) { // https://github.com/hashicorp/terraform/issues/8418 func TestGet_filePercent2F(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := testModule("basic%2Ftest") - if err := Get(dst, u); err != nil { + if err := Get(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -62,6 +71,8 @@ func TestGet_filePercent2F(t *testing.T) { } func TestGet_fileDetect(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := filepath.Join(".", "testdata", "basic") pwd, err := os.Getwd() @@ -69,18 +80,19 @@ func TestGet_fileDetect(t *testing.T) { t.Fatalf("err: %s", err) } - client := &Client{ + req := &Request{ Src: u, Dst: dst, Pwd: pwd, Dir: true, } + client := &Client{} - if err := client.Configure(); err != nil { + if err := client.configure(); err != nil { t.Fatalf("configure: %s", err) } - if err := client.Get(); err != nil { + if err := client.Get(ctx, req); err != nil { t.Fatalf("get: %s", err) } @@ -91,11 +103,13 @@ func TestGet_fileDetect(t *testing.T) { } func TestGet_fileForced(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := testModule("basic") u = "file::" + u - if err := Get(dst, u); err != nil { + if err := Get(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -106,10 +120,12 @@ func TestGet_fileForced(t *testing.T) { } func TestGet_fileSubdir(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := testModule("basic//subdir") - if err := Get(dst, u); err != nil { + if err := Get(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -120,11 +136,13 @@ func TestGet_fileSubdir(t *testing.T) { } func TestGet_archive(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := filepath.Join("./testdata", "archive.tar.gz") u, _ = filepath.Abs(u) - if err := Get(dst, u); err != nil { + if err := Get(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -135,11 +153,13 @@ func TestGet_archive(t *testing.T) { } func TestGetAny_archive(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := filepath.Join("./testdata", "archive.tar.gz") u, _ = filepath.Abs(u) - if err := GetAny(dst, u); err != nil { + if err := GetAny(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -150,9 +170,11 @@ func TestGetAny_archive(t *testing.T) { } func TestGet_archiveRooted(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := testModule("archive-rooted/archive.tar.gz") - if err := Get(dst, u); err != nil { + if err := Get(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -163,10 +185,12 @@ func TestGet_archiveRooted(t *testing.T) { } func TestGet_archiveSubdirWild(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := testModule("archive-rooted/archive.tar.gz") u += "//*" - if err := Get(dst, u); err != nil { + if err := Get(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -177,10 +201,12 @@ func TestGet_archiveSubdirWild(t *testing.T) { } func TestGet_archiveSubdirWildMultiMatch(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := testModule("archive-rooted-multi/archive.tar.gz") u += "//*" - if err := Get(dst, u); err == nil { + if err := Get(ctx, dst, u); err == nil { t.Fatal("should error") } else if !strings.Contains(err.Error(), "multiple") { t.Fatalf("err: %s", err) @@ -188,10 +214,12 @@ func TestGet_archiveSubdirWildMultiMatch(t *testing.T) { } func TestGetAny_file(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := testModule("basic-file/foo.txt") - if err := GetAny(dst, u); err != nil { + if err := GetAny(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -202,11 +230,13 @@ func TestGetAny_file(t *testing.T) { } func TestGetAny_dir(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := filepath.Join("./testdata", "basic") u, _ = filepath.Abs(u) - if err := GetAny(dst, u); err != nil { + if err := GetAny(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -224,11 +254,13 @@ func TestGetAny_dir(t *testing.T) { } func TestGetFile(t *testing.T) { + ctx := context.Background() + dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) u := testModule("basic-file/foo.txt") - if err := GetFile(dst, u); err != nil { + if err := GetFile(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -237,11 +269,13 @@ func TestGetFile(t *testing.T) { } func TestGetFile_archive(t *testing.T) { + ctx := context.Background() + dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) u := testModule("basic-file-archive/archive.tar.gz") - if err := GetFile(dst, u); err != nil { + if err := GetFile(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -250,12 +284,14 @@ func TestGetFile_archive(t *testing.T) { } func TestGetFile_archiveChecksum(t *testing.T) { + ctx := context.Background() + dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) u := testModule( "basic-file-archive/archive.tar.gz?checksum=md5:fbd90037dacc4b1ab40811d610dde2f0") - if err := GetFile(dst, u); err != nil { + if err := GetFile(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -264,12 +300,14 @@ func TestGetFile_archiveChecksum(t *testing.T) { } func TestGetFile_archiveNoUnarchive(t *testing.T) { + ctx := context.Background() + dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) u := testModule("basic-file-archive/archive.tar.gz") u += "?archive=false" - if err := GetFile(dst, u); err != nil { + if err := GetFile(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -282,6 +320,8 @@ func TestGetFile_archiveNoUnarchive(t *testing.T) { } func TestGetFile_checksum(t *testing.T) { + ctx := context.Background() + cases := []struct { Append string Err bool @@ -354,7 +394,7 @@ func TestGetFile_checksum(t *testing.T) { func() { dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) - if err := GetFile(dst, u); (err != nil) != tc.Err { + if err := GetFile(ctx, dst, u); (err != nil) != tc.Err { t.Fatalf("append: %s\n\nerr: %s", tc.Append, err) } @@ -365,6 +405,7 @@ func TestGetFile_checksum(t *testing.T) { } func TestGetFile_checksum_from_file(t *testing.T) { + checksums := testModule("checksum-file") httpChecksums := httpTestModule("checksum-file") defer httpChecksums.Close() @@ -432,9 +473,11 @@ func TestGetFile_checksum_from_file(t *testing.T) { for _, tc := range cases { u := checksums + "/content.txt" + tc.Append t.Run(tc.Append, func(t *testing.T) { + ctx := context.Background() + dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) - if err := GetFile(dst, u); (err != nil) != tc.WantErr { + if err := GetFile(ctx, dst, u); (err != nil) != tc.WantErr { t.Fatalf("append: %s\n\nerr: %s", tc.Append, err) } @@ -447,21 +490,25 @@ func TestGetFile_checksum_from_file(t *testing.T) { } func TestGetFile_checksumURL(t *testing.T) { + ctx := context.Background() + dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) u := testModule("basic-file/foo.txt") + "?checksum=md5:09f7e02f1290be211da707a266f153b3" getter := &MockGetter{Proxy: new(FileGetter)} - client := &Client{ + req := &Request{ Src: u, Dst: dst, Dir: false, + } + client := &Client{ Getters: map[string]Getter{ "file": getter, }, } - if err := client.Get(); err != nil { + if err := client.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -471,12 +518,14 @@ func TestGetFile_checksumURL(t *testing.T) { } func TestGetFile_filename(t *testing.T) { + ctx := context.Background() + dst := tempDir(t) u := testModule("basic-file/foo.txt") u += "?filename=bar.txt" - if err := GetAny(dst, u); err != nil { + if err := GetAny(ctx, dst, u); err != nil { t.Fatalf("err: %s", err) } @@ -487,22 +536,26 @@ func TestGetFile_filename(t *testing.T) { } func TestGetFile_checksumSkip(t *testing.T) { + ctx := context.Background() + dst := tempTestFile(t) defer os.RemoveAll(filepath.Dir(dst)) u := testModule("basic-file/foo.txt") + "?checksum=md5:09f7e02f1290be211da707a266f153b3" getter := &MockGetter{Proxy: new(FileGetter)} - client := &Client{ + req := &Request{ Src: u, Dst: dst, Dir: false, + } + client := &Client{ Getters: map[string]Getter{ "file": getter, }, } // get the file - if err := client.Get(); err != nil { + if err := client.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } @@ -514,7 +567,7 @@ func TestGetFile_checksumSkip(t *testing.T) { getter.Proxy = nil getter.GetFileCalled = false - if err := client.Get(); err != nil { + if err := client.Get(ctx, req); err != nil { t.Fatalf("err: %s", err) } diff --git a/go.mod b/go.mod index a869e8f80..30606f9b8 100644 --- a/go.mod +++ b/go.mod @@ -21,3 +21,5 @@ require ( google.golang.org/api v0.9.0 gopkg.in/cheggaaa/pb.v1 v1.0.27 // indirect ) + +go 1.13 diff --git a/request.go b/request.go new file mode 100644 index 000000000..a93d2ba33 --- /dev/null +++ b/request.go @@ -0,0 +1,42 @@ +package getter + +import "net/url" + +type Request struct { + // Src is the source URL to get. + // + // Dst is the path to save the downloaded thing as. If Dir is set to + // true, then this should be a directory. If the directory doesn't exist, + // it will be created for you. + // + // Pwd is the working directory for detection. If this isn't set, some + // detection may fail. Client will not default pwd to the current + // working directory for security reasons. + Src string + Dst string + Pwd string + + // Mode is the method of download the client will use. See ClientMode + // for documentation. + Mode ClientMode + + // Copy, in local file mode if set to true, will copy data instead of using + // a symlink. If false, attempts to symlink to speed up the operation and + // to lower the disk space usage. If the symlink fails, may attempt to copy + // on windows. + Copy bool + + // Dir, if true, tells the Client it is downloading a directory (versus + // a single file). This distinction is necessary since filenames and + // directory names follow the same format so disambiguating is impossible + // without knowing ahead of time. + // + // WARNING: deprecated. If Mode is set, that will take precedence. + Dir bool + + // ProgressListener allows to track file downloads. + // By default a no op progress listener is used. + ProgressListener ProgressTracker + + u *url.URL +} diff --git a/storage.go b/storage.go index 2bc6b9ec3..3ccc1c3af 100644 --- a/storage.go +++ b/storage.go @@ -1,5 +1,7 @@ package getter +import "context" + // Storage is an interface that knows how to lookup downloaded directories // as well as download and update directories from their sources into the // proper location. @@ -9,5 +11,5 @@ type Storage interface { Dir(string) (string, bool, error) // Get will download and optionally update the given directory. - Get(string, string, bool) error + Get(context.Context, string, string, bool) error }