-
Notifications
You must be signed in to change notification settings - Fork 246
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
move cheggaaa progress bar usage to it's own pkg + mutex protect prog…
…ress bar + add some tests + simplify
- Loading branch information
Showing
7 changed files
with
144 additions
and
82 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
package getter | ||
|
||
import ( | ||
"io" | ||
"net/http" | ||
"net/http/httptest" | ||
"sync" | ||
"testing" | ||
) | ||
|
||
type MockProgressTracking struct { | ||
sync.Mutex | ||
downloaded map[string]int | ||
} | ||
|
||
func (p *MockProgressTracking) TrackProgress(src string, | ||
currentSize, totalSize int64, stream io.ReadCloser) (body io.ReadCloser) { | ||
p.Lock() | ||
defer p.Unlock() | ||
|
||
if p.downloaded == nil { | ||
p.downloaded = map[string]int{} | ||
} | ||
|
||
v, _ := p.downloaded[src] | ||
p.downloaded[src] = v + 1 | ||
return stream | ||
} | ||
|
||
func TestGet_progress(t *testing.T) { | ||
s := httptest.NewServer(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { | ||
// all good | ||
rw.Header().Add("X-Terraform-Get", "something") | ||
})) | ||
defer s.Close() | ||
|
||
{ // dl without tracking | ||
dst := tempFile(t) | ||
if err := GetFile(dst, s.URL+"/file?thig=this&that"); err != nil { | ||
t.Fatalf("download failed: %v", err) | ||
} | ||
} | ||
|
||
{ // tracking | ||
p := &MockProgressTracking{} | ||
dst := tempFile(t) | ||
if err := GetFile(dst, s.URL+"/file?thig=this&that", WithProgress(p)); err != nil { | ||
t.Fatalf("download failed: %v", err) | ||
} | ||
if err := GetFile(dst, s.URL+"/otherfile?thig=this&that", WithProgress(p)); err != nil { | ||
t.Fatalf("download failed: %v", err) | ||
} | ||
|
||
if p.downloaded["file"] != 1 { | ||
t.Error("Expected a file download") | ||
} | ||
if p.downloaded["otherfile"] != 1 { | ||
t.Error("Expected a otherfile download") | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,78 @@ | ||
package cheggaaa | ||
|
||
import ( | ||
"io" | ||
"path/filepath" | ||
"sync" | ||
|
||
"github.com/cheggaaa/pb" | ||
getter "github.com/hashicorp/go-getter" | ||
) | ||
|
||
// DefaultProgressBar is the default instance of a cheggaaa | ||
// progress bar. It is recommended to use DefaultProgressBar | ||
// in all places. | ||
var DefaultProgressBar getter.ProgressTracker = &ProgressBar{} | ||
|
||
// ProgressBar wraps a github.com/cheggaaa/pb.Pool | ||
// in order to display download progress for one or multiple | ||
// downloads. | ||
// | ||
// If two different instance of ProgressBar try to | ||
// display a progress only one will be displayed. | ||
// It is therefore recommended to use DefaultProgressBar | ||
type ProgressBar struct { | ||
// lock everything below | ||
lock sync.Mutex | ||
|
||
pool *pb.Pool | ||
|
||
pbs int | ||
} | ||
|
||
func defaultCheggaaaProgressBarConfigFN(bar *pb.ProgressBar, prefix string) { | ||
bar.SetUnits(pb.U_BYTES) | ||
bar.Prefix(prefix) | ||
} | ||
|
||
// TrackProgress instantiates a new progress bar that will | ||
// display the progress of stream until closed. | ||
// total can be 0. | ||
func (cpb *ProgressBar) TrackProgress(src string, currentSize, totalSize int64, stream io.ReadCloser) io.ReadCloser { | ||
cpb.lock.Lock() | ||
defer cpb.lock.Unlock() | ||
|
||
newPb := pb.New64(totalSize) | ||
newPb.Set64(currentSize) | ||
defaultCheggaaaProgressBarConfigFN(newPb, filepath.Base(src)) | ||
if cpb.pool == nil { | ||
cpb.pool = pb.NewPool() | ||
cpb.pool.Start() | ||
} | ||
cpb.pool.Add(newPb) | ||
reader := newPb.NewProxyReader(stream) | ||
|
||
cpb.pbs++ | ||
return &readCloser{ | ||
Reader: reader, | ||
close: func() error { | ||
cpb.lock.Lock() | ||
defer cpb.lock.Unlock() | ||
|
||
newPb.Finish() | ||
cpb.pbs-- | ||
if cpb.pbs <= 0 { | ||
cpb.pool.Stop() | ||
cpb.pool = nil | ||
} | ||
return nil | ||
}, | ||
} | ||
} | ||
|
||
type readCloser struct { | ||
io.Reader | ||
close func() error | ||
} | ||
|
||
func (c *readCloser) Close() error { return c.close() } |