From 572fb750f5f318823d612ceb21116d19427482df Mon Sep 17 00:00:00 2001 From: Adrien Delorme Date: Thu, 15 Nov 2018 16:06:52 +0100 Subject: [PATCH] allow to track download progress by passing options Also set default Getters, Decompressors & Detectors in Configure func with the idea to add a validate func that will validate if src url will work in the future --- client.go | 30 ++++++----------- client_option.go | 18 +++++++++++ client_option_progress.go | 46 ++++++++++++++++++++++++++ cmd/go-getter/main.go | 5 ++- get.go | 26 ++++++++------- get_base.go | 9 ++++++ get_file.go | 2 ++ get_git.go | 4 ++- get_hg.go | 4 ++- get_http.go | 68 ++++++++++++++++++++++++++++----------- get_mock.go | 2 ++ get_s3.go | 4 ++- 12 files changed, 164 insertions(+), 54 deletions(-) create mode 100644 client_option_progress.go create mode 100644 get_base.go diff --git a/client.go b/client.go index 47c4ba337..9c8f63cbb 100644 --- a/client.go +++ b/client.go @@ -62,6 +62,12 @@ type Client struct { // // WARNING: deprecated. If Mode is set, that will take precedence. Dir bool + + // ProgressListener allows to track file downloads. + // By default a no op progress listener is used. + ProgressListener ProgressTracker + + Options []ClientOption } // Get downloads the configured source to the destination. @@ -76,18 +82,7 @@ func (c *Client) Get() error { } } - // Default decompressor value - decompressors := c.Decompressors - if decompressors == nil { - decompressors = Decompressors - } - - // Detect the URL. This is safe if it is already detected. - detectors := c.Detectors - if detectors == nil { - detectors = Detectors - } - src, err := Detect(c.Src, c.Pwd, detectors) + src, err := Detect(c.Src, c.Pwd, c.Detectors) if err != nil { return err } @@ -119,12 +114,7 @@ func (c *Client) Get() error { force = u.Scheme } - getters := c.Getters - if getters == nil { - getters = Getters - } - - g, ok := getters[force] + g, ok := c.Getters[force] if !ok { return fmt.Errorf( "download not supported for scheme '%s'", force) @@ -150,7 +140,7 @@ func (c *Client) Get() error { if archiveV == "" { // We don't appear to... but is it part of the filename? matchingLen := 0 - for k, _ := range decompressors { + for k := range c.Decompressors { if strings.HasSuffix(u.Path, "."+k) && len(k) > matchingLen { archiveV = k matchingLen = len(k) @@ -163,7 +153,7 @@ func (c *Client) Get() error { // real path. var decompressDst string var decompressDir bool - decompressor := decompressors[archiveV] + decompressor := c.Decompressors[archiveV] if decompressor != nil { // Create a temporary directory to store our archive. We delete // this at the end of everything. diff --git a/client_option.go b/client_option.go index 4bc69a759..81e8a5d03 100644 --- a/client_option.go +++ b/client_option.go @@ -5,11 +5,29 @@ type ClientOption func(*Client) error // Configure configures a client with options. func (c *Client) Configure(opts ...ClientOption) error { + c.Options = opts + c.ProgressListener = noopProgressListener for _, opt := range opts { err := opt(c) if err != nil { return err } } + // Default decompressor values + if c.Decompressors == nil { + c.Decompressors = Decompressors + } + // Default detector values + if c.Detectors == nil { + c.Detectors = Detectors + } + // Default getter values + if c.Getters == nil { + c.Getters = Getters + } + + for _, getter := range c.Getters { + getter.SetClient(c) + } return nil } diff --git a/client_option_progress.go b/client_option_progress.go new file mode 100644 index 000000000..39a46e5c9 --- /dev/null +++ b/client_option_progress.go @@ -0,0 +1,46 @@ +package getter + +import ( + "io" +) + +// WithProgress allows for a user to track +// the progress of a download. +// For example by displaying a progress bar with +// current download. +// Not all getters have progress support yet. +func WithProgress(pl ProgressTracker) func(*Client) error { + return func(c *Client) error { + c.ProgressListener = pl + return nil + } +} + +// ProgressTracker allows to track the progress of downloads. +type ProgressTracker interface { + // TrackProgress should be called when + // a new object is being downloaded. + // src is the location the file is + // downloaded from. + // size is the total size in bytes, + // size can be zero if the file size + // is not known. + // stream is the file being downloaded, every + // written byte will add up to processed size. + // + // TrackProgress returns a ReadCloser that wraps the + // download in progress ( stream ). + // When the download is finished, body shall be closed. + TrackProgress(src string, size int64, stream io.ReadCloser) (body io.ReadCloser) +} + +// NoopProgressListener is a progress listener +// that has no effect. +type NoopProgressListener struct{} + +var noopProgressListener ProgressTracker = &NoopProgressListener{} + +// TrackProgress is a no op +func (*NoopProgressListener) TrackProgress(_ string, _ int64, stream io.ReadCloser) io.ReadCloser { + return stream +} diff --git a/cmd/go-getter/main.go b/cmd/go-getter/main.go index 3b26a1b96..6d0b09dd3 100644 --- a/cmd/go-getter/main.go +++ b/cmd/go-getter/main.go @@ -46,9 +46,12 @@ func main() { Mode: mode, } + if err := client.Configure(); err != nil { + log.Fatalf("Configure: %s", err) + } + if err := client.Get(); err != nil { log.Fatalf("Error downloading: %s", err) - os.Exit(1) } log.Println("Success!") diff --git a/get.go b/get.go index b59c7ff50..838e92d4e 100644 --- a/get.go +++ b/get.go @@ -41,6 +41,11 @@ type Getter interface { // ClientMode returns the mode based on the given URL. This is used to // allow clients to let the getters decide which mode to use. ClientMode(*url.URL) (ClientMode, error) + + // SetClient allows a getter to know it's client + // in order to access client's Get functions or + // progress tracking. + SetClient(*Client) } // Getters is the mapping of scheme to the Getter implementation that will @@ -76,10 +81,9 @@ func init() { // folder doesn't need to exist. It will be created if it doesn't exist. func Get(dst, src string, opts ...ClientOption) error { c := &Client{ - Src: src, - Dst: dst, - Dir: true, - Getters: Getters, + Src: src, + Dst: dst, + Dir: true, } if err := c.Configure(opts...); err != nil { return err @@ -95,10 +99,9 @@ func Get(dst, src string, opts ...ClientOption) error { // archive, it will be unpacked directly into dst. func GetAny(dst, src string, opts ...ClientOption) error { c := &Client{ - Src: src, - Dst: dst, - Mode: ClientModeAny, - Getters: Getters, + Src: src, + Dst: dst, + Mode: ClientModeAny, } if err := c.Configure(opts...); err != nil { return err @@ -110,10 +113,9 @@ func GetAny(dst, src string, opts ...ClientOption) error { // dst. func GetFile(dst, src string, opts ...ClientOption) error { c := &Client{ - Src: src, - Dst: dst, - Dir: false, - Getters: Getters, + Src: src, + Dst: dst, + Dir: false, } if err := c.Configure(opts...); err != nil { return err diff --git a/get_base.go b/get_base.go new file mode 100644 index 000000000..4a9ca733f --- /dev/null +++ b/get_base.go @@ -0,0 +1,9 @@ +package getter + +// getter is our base getter; it regroups +// fields all getters have in common. +type getter struct { + client *Client +} + +func (g *getter) SetClient(c *Client) { g.client = c } diff --git a/get_file.go b/get_file.go index e5d2d61d7..786219f03 100644 --- a/get_file.go +++ b/get_file.go @@ -8,6 +8,8 @@ import ( // FileGetter is a Getter implementation that will download a module from // a file scheme. type FileGetter struct { + getter + // Copy, if set to true, will copy data instead of using a symlink Copy bool } diff --git a/get_git.go b/get_git.go index 988af1dd9..435578782 100644 --- a/get_git.go +++ b/get_git.go @@ -17,7 +17,9 @@ import ( // GitGetter is a Getter implementation that will download a module from // a git repository. -type GitGetter struct{} +type GitGetter struct { + getter +} func (g *GitGetter) ClientMode(_ *url.URL) (ClientMode, error) { return ClientModeDir, nil diff --git a/get_hg.go b/get_hg.go index f38692270..24049c97d 100644 --- a/get_hg.go +++ b/get_hg.go @@ -14,7 +14,9 @@ import ( // HgGetter is a Getter implementation that will download a module from // a Mercurial repository. -type HgGetter struct{} +type HgGetter struct { + getter +} func (g *HgGetter) ClientMode(_ *url.URL) (ClientMode, error) { return ClientModeDir, nil diff --git a/get_http.go b/get_http.go index f66f4e59f..21a34f397 100644 --- a/get_http.go +++ b/get_http.go @@ -18,7 +18,7 @@ import ( // // For file downloads, HTTP is used directly. // -// The protocol for downloading a directory from an HTTP endpoing is as follows: +// The protocol for downloading a directory from an HTTP endpoint is as follows: // // An HTTP GET request is made to the URL with the additional GET parameter // "terraform-get=1". This lets you handle that scenario specially if you @@ -34,6 +34,8 @@ import ( // formed URL. The shorthand syntax of "github.com/foo/bar" or relative // paths are not allowed. type HttpGetter struct { + getter + // Netrc, if true, will lookup and use auth information found // in the user's netrc file if available. Netrc bool @@ -112,52 +114,82 @@ func (g *HttpGetter) Get(dst string, u *url.URL) error { // into a temporary directory, then copy over the proper subdir. source, subDir := SourceDirSubdir(source) if subDir == "" { - return Get(dst, source) + return Get(dst, source, g.client.Options...) } // We have a subdir, time to jump some hoops return g.getSubdir(dst, source, subDir) } -func (g *HttpGetter) GetFile(dst string, u *url.URL) error { +func (g *HttpGetter) GetFile(dst string, src *url.URL) error { if g.Netrc { // Add auth from netrc if we can - if err := addAuthFromNetrc(u); err != nil { + if err := addAuthFromNetrc(src); err != nil { return err } } + // Create all the parent directories if needed + if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { + return err + } + + f, err := os.OpenFile(dst, os.O_RDWR|os.O_CREATE, os.FileMode(0666)) + if err != nil { + return err + } + if g.Client == nil { g.Client = httpClient } - req, err := http.NewRequest("GET", u.String(), nil) + var current int64 + + // We first make a HEAD request so we can check + // if the server supports range queries. If the server/URL doesn't + // support HEAD requests, we just fall back to GET. + req, err := http.NewRequest("HEAD", src.String(), nil) if err != nil { return err } + if g.Header != nil { + req.Header = g.Header + } + headResp, err := g.Client.Do(req) + if err == nil && headResp != nil { + if headResp.StatusCode == 200 { + // If the HEAD request succeeded, then attempt to set the range + // query if we can. + if headResp.Header.Get("Accept-Ranges") == "bytes" { + if fi, err := f.Stat(); err == nil { + if _, err = f.Seek(0, os.SEEK_END); err == nil { + req.Header.Set("Range", fmt.Sprintf("bytes=%d-", fi.Size())) + current = fi.Size() + } + } + } + } + headResp.Body.Close() + } + req.Method = "GET" req.Header = g.Header resp, err := g.Client.Do(req) if err != nil { return err } - - defer resp.Body.Close() if resp.StatusCode != 200 { + resp.Body.Close() return fmt.Errorf("bad response code: %d", resp.StatusCode) } - // Create all the parent directories - if err := os.MkdirAll(filepath.Dir(dst), 0755); err != nil { - return err - } - - f, err := os.Create(dst) - if err != nil { - return err - } + println(g.client) + println(src.String()) + // track download + body := g.client.ProgressListener.TrackProgress(src.String(), current+resp.ContentLength, resp.Body) + defer body.Close() - n, err := io.Copy(f, resp.Body) + n, err := io.Copy(f, body) if err == nil && n < resp.ContentLength { err = io.ErrShortWrite } @@ -179,7 +211,7 @@ func (g *HttpGetter) getSubdir(dst, source, subDir string) error { defer tdcloser.Close() // Download that into the given directory - if err := Get(td, source); err != nil { + if err := Get(td, source, g.client.Options...); err != nil { return err } diff --git a/get_mock.go b/get_mock.go index 882e694dc..e2a98ea28 100644 --- a/get_mock.go +++ b/get_mock.go @@ -6,6 +6,8 @@ import ( // MockGetter is an implementation of Getter that can be used for tests. type MockGetter struct { + getter + // Proxy, if set, will be called after recording the calls below. // If it isn't set, then the *Err values will be returned. Proxy Getter diff --git a/get_s3.go b/get_s3.go index ebb321741..ddea48c5d 100644 --- a/get_s3.go +++ b/get_s3.go @@ -18,7 +18,9 @@ import ( // S3Getter is a Getter implementation that will download a module from // a S3 bucket. -type S3Getter struct{} +type S3Getter struct { + getter +} func (g *S3Getter) ClientMode(u *url.URL) (ClientMode, error) { // Parse URL