Skip to content

Commit

Permalink
Split request options from client into a new request struct (#230)
Browse files Browse the repository at this point in the history
and document checksumming a bit better.
  • Loading branch information
azr authored Feb 4, 2020
1 parent 0d2e82b commit 8c9292b
Show file tree
Hide file tree
Showing 33 changed files with 854 additions and 662 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
vendor
44 changes: 21 additions & 23 deletions checksum.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package getter
import (
"bufio"
"bytes"
"context"
"crypto/md5"
"crypto/sha1"
"crypto/sha256"
Expand Down Expand Up @@ -93,7 +94,7 @@ func (c *FileChecksum) checksum(source string) error {
// <checksum> *file2
//
// see parseChecksumLine for more detail on checksum file parsing
func (c *Client) extractChecksum(u *url.URL) (*FileChecksum, error) {
func (c *Client) extractChecksum(ctx context.Context, u *url.URL) (*FileChecksum, error) {
q := u.Query()
v := q.Get("checksum")

Expand All @@ -115,7 +116,7 @@ func (c *Client) extractChecksum(u *url.URL) (*FileChecksum, error) {

switch checksumType {
case "file":
return c.ChecksumFromFile(checksumValue, u)
return c.ChecksumFromFile(ctx, checksumValue, u.Path)
default:
return newChecksumFromType(checksumType, checksumValue, filepath.Base(u.EscapedPath()))
}
Expand Down Expand Up @@ -183,15 +184,16 @@ func newChecksumFromValue(checksumValue, filename string) (*FileChecksum, error)
return c, nil
}

// ChecksumFromFile will return all the FileChecksums found in file
// ChecksumFromFile will return the first file checksum found in the
// `checksumURL` file that corresponds to the `checksummedURL` path.
//
// ChecksumFromFile will try to guess the hashing algorithm based on content
// of checksum file
// ChecksumFromFile will infer the hashing algorithm based on the checksumURL
// file content.
//
// ChecksumFromFile will only return checksums for files that match file
// behind src
func (c *Client) ChecksumFromFile(checksumFile string, src *url.URL) (*FileChecksum, error) {
checksumFileURL, err := urlhelper.Parse(checksumFile)
// ChecksumFromFile will only return checksums for files that match
// checksummedURL, which is the object being checksummed.
func (c *Client) ChecksumFromFile(ctx context.Context, checksumURL, checksummedURL string) (*FileChecksum, error) {
checksumFileURL, err := urlhelper.Parse(checksumURL)
if err != nil {
return nil, err
}
Expand All @@ -202,24 +204,20 @@ func (c *Client) ChecksumFromFile(checksumFile string, src *url.URL) (*FileCheck
}
defer os.Remove(tempfile)

c2 := &Client{
Ctx: c.Ctx,
Getters: c.Getters,
Decompressors: c.Decompressors,
Detectors: c.Detectors,
Pwd: c.Pwd,
Dir: false,
Src: checksumFile,
Dst: tempfile,
ProgressListener: c.ProgressListener,
req := &Request{
// Pwd: c.Pwd, TODO(adrien): pass pwd ?
Dir: false,
Src: checksumURL,
Dst: tempfile,
// ProgressListener: c.ProgressListener, TODO(adrien): pass progress bar ?
}
if err = c2.Get(); err != nil {
if err = c.Get(ctx, req); err != nil {
return nil, fmt.Errorf(
"Error downloading checksum file: %s", err)
}

filename := filepath.Base(src.Path)
absPath, err := filepath.Abs(src.Path)
filename := filepath.Base(checksummedURL)
absPath, err := filepath.Abs(checksummedURL)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -277,7 +275,7 @@ func (c *Client) ChecksumFromFile(checksumFile string, src *url.URL) (*FileCheck
}
}
}
return nil, fmt.Errorf("no checksum found in: %s", checksumFile)
return nil, fmt.Errorf("no checksum found in: %s", checksumURL)
}

// parseChecksumLine takes a line from a checksum file and returns
Expand Down
121 changes: 44 additions & 77 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,25 +19,6 @@ import (
// Using a client directly allows more fine-grained control over how downloading
// is done, as well as customizing the protocols supported.
type Client struct {
// Ctx for cancellation
Ctx context.Context

// Src is the source URL to get.
//
// Dst is the path to save the downloaded thing as. If Dir is set to
// true, then this should be a directory. If the directory doesn't exist,
// it will be created for you.
//
// Pwd is the working directory for detection. If this isn't set, some
// detection may fail. Client will not default pwd to the current
// working directory for security reasons.
Src string
Dst string
Pwd string

// Mode is the method of download the client will use. See ClientMode
// for documentation.
Mode ClientMode

// Detectors is the list of detectors that are tried on the source.
// If this is nil, then the default Detectors will be used.
Expand All @@ -50,68 +31,54 @@ type Client struct {
// Getters is the map of protocols supported by this client. If this
// is nil, then the default Getters variable will be used.
Getters map[string]Getter

// Dir, if true, tells the Client it is downloading a directory (versus
// a single file). This distinction is necessary since filenames and
// directory names follow the same format so disambiguating is impossible
// without knowing ahead of time.
//
// WARNING: deprecated. If Mode is set, that will take precedence.
Dir bool

// ProgressListener allows to track file downloads.
// By default a no op progress listener is used.
ProgressListener ProgressTracker

Options []ClientOption
}

// Get downloads the configured source to the destination.
func (c *Client) Get() error {
if err := c.Configure(c.Options...); err != nil {
func (c *Client) Get(ctx context.Context, req *Request) error {
if err := c.configure(); err != nil {
return err
}

// Store this locally since there are cases we swap this
mode := c.Mode
if mode == ClientModeInvalid {
if c.Dir {
mode = ClientModeDir
if req.Mode == ClientModeInvalid {
if req.Dir {
req.Mode = ClientModeDir
} else {
mode = ClientModeFile
req.Mode = ClientModeFile
}
}

src, err := Detect(c.Src, c.Pwd, c.Detectors)
var err error
req.Src, err = Detect(req.Src, req.Pwd, c.Detectors)
if err != nil {
return err
}

var force string
// Determine if we have a forced protocol, i.e. "git::http://..."
force, src := getForcedGetter(src)
force, req.Src = getForcedGetter(req.Src)

// If there is a subdir component, then we download the root separately
// and then copy over the proper subdir.
var realDst string
dst := c.Dst
src, subDir := SourceDirSubdir(src)
var realDst, subDir string
req.Src, subDir = SourceDirSubdir(req.Src)
if subDir != "" {
td, tdcloser, err := safetemp.Dir("", "getter")
if err != nil {
return err
}
defer tdcloser.Close()

realDst = dst
dst = td
realDst = req.Dst
req.Dst = td
}

u, err := urlhelper.Parse(src)
req.u, err = urlhelper.Parse(req.Src)
if err != nil {
return err
}
if force == "" {
force = u.Scheme
force = req.u.Scheme
}

g, ok := c.Getters[force]
Expand All @@ -121,15 +88,15 @@ func (c *Client) Get() error {
}

// We have magic query parameters that we use to signal different features
q := u.Query()
q := req.u.Query()

// Determine if we have an archive type
archiveV := q.Get("archive")
if archiveV != "" {
// Delete the parameter since it is a magic parameter we don't
// want to pass on to the Getter
q.Del("archive")
u.RawQuery = q.Encode()
req.u.RawQuery = q.Encode()

// If we can parse the value as a bool and it is false, then
// set the archive to "-" which should never map to a decompressor
Expand All @@ -141,7 +108,7 @@ func (c *Client) Get() error {
// We don't appear to... but is it part of the filename?
matchingLen := 0
for k := range c.Decompressors {
if strings.HasSuffix(u.Path, "."+k) && len(k) > matchingLen {
if strings.HasSuffix(req.u.Path, "."+k) && len(k) > matchingLen {
archiveV = k
matchingLen = len(k)
}
Expand All @@ -166,65 +133,65 @@ func (c *Client) Get() error {

// Swap the download directory to be our temporary path and
// store the old values.
decompressDst = dst
decompressDir = mode != ClientModeFile
dst = filepath.Join(td, "archive")
mode = ClientModeFile
decompressDst = req.Dst
decompressDir = req.Mode != ClientModeFile
req.Dst = filepath.Join(td, "archive")
req.Mode = ClientModeFile
}

// Determine checksum if we have one
checksum, err := c.extractChecksum(u)
checksum, err := c.extractChecksum(ctx, req.u)
if err != nil {
return fmt.Errorf("invalid checksum: %s", err)
}

// Delete the query parameter if we have it.
q.Del("checksum")
u.RawQuery = q.Encode()
req.u.RawQuery = q.Encode()

if mode == ClientModeAny {
if req.Mode == ClientModeAny {
// Ask the getter which client mode to use
mode, err = g.ClientMode(u)
req.Mode, err = g.ClientMode(ctx, req.u)
if err != nil {
return err
}

// Destination is the base name of the URL path in "any" mode when
// a file source is detected.
if mode == ClientModeFile {
filename := filepath.Base(u.Path)
if req.Mode == ClientModeFile {
filename := filepath.Base(req.u.Path)

// Determine if we have a custom file name
if v := q.Get("filename"); v != "" {
// Delete the query parameter if we have it.
q.Del("filename")
u.RawQuery = q.Encode()
req.u.RawQuery = q.Encode()

filename = v
}

dst = filepath.Join(dst, filename)
req.Dst = filepath.Join(req.Dst, filename)
}
}

// If we're not downloading a directory, then just download the file
// and return.
if mode == ClientModeFile {
if req.Mode == ClientModeFile {
getFile := true
if checksum != nil {
if err := checksum.checksum(dst); err == nil {
if err := checksum.checksum(req.Dst); err == nil {
// don't get the file if the checksum of dst is correct
getFile = false
}
}
if getFile {
err := g.GetFile(dst, u)
err := g.GetFile(ctx, req)
if err != nil {
return err
}

if checksum != nil {
if err := checksum.checksum(dst); err != nil {
if err := checksum.checksum(req.Dst); err != nil {
return err
}
}
Expand All @@ -233,24 +200,24 @@ func (c *Client) Get() error {
if decompressor != nil {
// We have a decompressor, so decompress the current destination
// into the final destination with the proper mode.
err := decompressor.Decompress(decompressDst, dst, decompressDir)
err := decompressor.Decompress(decompressDst, req.Dst, decompressDir)
if err != nil {
return err
}

// Swap the information back
dst = decompressDst
req.Dst = decompressDst
if decompressDir {
mode = ClientModeAny
req.Mode = ClientModeAny
} else {
mode = ClientModeFile
req.Mode = ClientModeFile
}
}

// We check the dir value again because it can be switched back
// if we were unarchiving. If we're still only Get-ing a file, then
// we're done.
if mode == ClientModeFile {
if req.Mode == ClientModeFile {
return nil
}
}
Expand All @@ -269,9 +236,9 @@ func (c *Client) Get() error {

// We're downloading a directory, which might require a bit more work
// if we're specifying a subdir.
err := g.Get(dst, u)
err := g.Get(ctx, req)
if err != nil {
err = fmt.Errorf("error downloading '%s': %s", src, err)
err = fmt.Errorf("error downloading '%s': %s", req.Src, err)
return err
}
}
Expand All @@ -286,12 +253,12 @@ func (c *Client) Get() error {
}

// Process any globs
subDir, err := SubdirGlob(dst, subDir)
subDir, err := SubdirGlob(req.Dst, subDir)
if err != nil {
return err
}

return copyDir(c.Ctx, realDst, subDir, false)
return copyDir(ctx, realDst, subDir, false)
}

return nil
Expand Down
Loading

0 comments on commit 8c9292b

Please sign in to comment.