Skip to content

Commit

Permalink
move cheggaaa progress bar usage to it's own pkg + mutex protect prog…
Browse files Browse the repository at this point in the history
…ress bar + add some tests + simplify
  • Loading branch information
azr committed Dec 14, 2018
1 parent 937d58d commit 6d21948
Show file tree
Hide file tree
Showing 7 changed files with 144 additions and 82 deletions.
1 change: 0 additions & 1 deletion client_option.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ type ClientOption func(*Client) error
// Configure configures a client with options.
func (c *Client) Configure(opts ...ClientOption) error {
c.Options = opts
c.ProgressListener = noopProgressListener
for _, opt := range opts {
err := opt(c)
if err != nil {
Expand Down
11 changes: 0 additions & 11 deletions client_option_progress.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,3 @@ type ProgressTracker interface {
// When the download is finished, body shall be closed.
TrackProgress(src string, currentSize, totalSize int64, stream io.ReadCloser) (body io.ReadCloser)
}

// NoopProgressListener is a progress listener
// that has no effect.
type NoopProgressListener struct{}

var noopProgressListener ProgressTracker = &NoopProgressListener{}

// TrackProgress is a no op
func (*NoopProgressListener) TrackProgress(_ string, _, _ int64, stream io.ReadCloser) io.ReadCloser {
return stream
}
67 changes: 0 additions & 67 deletions client_option_progress_cheggaaa.go

This file was deleted.

61 changes: 61 additions & 0 deletions client_option_progress_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package getter

import (
"io"
"net/http"
"net/http/httptest"
"sync"
"testing"
)

type MockProgressTracking struct {
sync.Mutex
downloaded map[string]int
}

func (p *MockProgressTracking) TrackProgress(src string,
currentSize, totalSize int64, stream io.ReadCloser) (body io.ReadCloser) {
p.Lock()
defer p.Unlock()

if p.downloaded == nil {
p.downloaded = map[string]int{}
}

v, _ := p.downloaded[src]
p.downloaded[src] = v + 1
return stream
}

func TestGet_progress(t *testing.T) {
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
// all good
rw.Header().Add("X-Terraform-Get", "something")
}))
defer s.Close()

{ // dl without tracking
dst := tempFile(t)
if err := GetFile(dst, s.URL+"/file?thig=this&that"); err != nil {
t.Fatalf("download failed: %v", err)
}
}

{ // tracking
p := &MockProgressTracking{}
dst := tempFile(t)
if err := GetFile(dst, s.URL+"/file?thig=this&that", WithProgress(p)); err != nil {
t.Fatalf("download failed: %v", err)
}
if err := GetFile(dst, s.URL+"/otherfile?thig=this&that", WithProgress(p)); err != nil {
t.Fatalf("download failed: %v", err)
}

if p.downloaded["file"] != 1 {
t.Error("Expected a file download")
}
if p.downloaded["otherfile"] != 1 {
t.Error("Expected a otherfile download")
}
}
}
3 changes: 2 additions & 1 deletion cmd/go-getter/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"

"github.com/hashicorp/go-getter"
"github.com/hashicorp/go-getter/progresstracking/cheggaaa"
)

func main() {
Expand Down Expand Up @@ -48,7 +49,7 @@ func main() {
}
var opts []getter.ClientOption
if *progress {
opts = append(opts, getter.WithCheggaaaProgressBarV1())
opts = append(opts, getter.WithProgress(cheggaaa.DefaultProgressBar))
}

if err := client.Configure(opts...); err != nil {
Expand Down
5 changes: 3 additions & 2 deletions get_http.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,10 @@ func (g *HttpGetter) GetFile(dst string, src *url.URL) error {

body := resp.Body

if g.client != nil {
if g.client != nil && g.client.ProgressListener != nil {
// track download
body = g.client.ProgressListener.TrackProgress(src.String(), currentFileSize, currentFileSize+resp.ContentLength, resp.Body)
fn := filepath.Base(src.EscapedPath())
body = g.client.ProgressListener.TrackProgress(fn, currentFileSize, currentFileSize+resp.ContentLength, resp.Body)
}
defer body.Close()

Expand Down
78 changes: 78 additions & 0 deletions progresstracking/cheggaaa/progress_tracking.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package cheggaaa

import (
"io"
"path/filepath"
"sync"

"github.com/cheggaaa/pb"
getter "github.com/hashicorp/go-getter"
)

// DefaultProgressBar is the default instance of a cheggaaa
// progress bar. It is recommended to use DefaultProgressBar
// in all places.
var DefaultProgressBar getter.ProgressTracker = &ProgressBar{}

// ProgressBar wraps a github.com/cheggaaa/pb.Pool
// in order to display download progress for one or multiple
// downloads.
//
// If two different instance of ProgressBar try to
// display a progress only one will be displayed.
// It is therefore recommended to use DefaultProgressBar
type ProgressBar struct {
// lock everything below
lock sync.Mutex

pool *pb.Pool

pbs int
}

func defaultCheggaaaProgressBarConfigFN(bar *pb.ProgressBar, prefix string) {
bar.SetUnits(pb.U_BYTES)
bar.Prefix(prefix)
}

// TrackProgress instantiates a new progress bar that will
// display the progress of stream until closed.
// total can be 0.
func (cpb *ProgressBar) TrackProgress(src string, currentSize, totalSize int64, stream io.ReadCloser) io.ReadCloser {
cpb.lock.Lock()
defer cpb.lock.Unlock()

newPb := pb.New64(totalSize)
newPb.Set64(currentSize)
defaultCheggaaaProgressBarConfigFN(newPb, filepath.Base(src))
if cpb.pool == nil {
cpb.pool = pb.NewPool()
cpb.pool.Start()
}
cpb.pool.Add(newPb)
reader := newPb.NewProxyReader(stream)

cpb.pbs++
return &readCloser{
Reader: reader,
close: func() error {
cpb.lock.Lock()
defer cpb.lock.Unlock()

newPb.Finish()
cpb.pbs--
if cpb.pbs <= 0 {
cpb.pool.Stop()
cpb.pool = nil
}
return nil
},
}
}

type readCloser struct {
io.Reader
close func() error
}

func (c *readCloser) Close() error { return c.close() }

0 comments on commit 6d21948

Please sign in to comment.