KeyStreamerAt #7

Merged · 6 commits · Sep 16, 2021
Changes from 1 commit
96 changes: 95 additions & 1 deletion adapter.go
@@ -73,6 +73,32 @@ type KeyReaderAt interface {
ReadAt(key string, p []byte, off int64) (int, int64, error)
}

// KeyStreamerAt is the second interface a handler can implement.
// The same conventions as KeyReaderAt apply with respect to the object size and the
// error handling.
// A reader that implements KeyStreamerAt and does not implement KeyReaderAt can
// be wrapped using KeyReaderAtWrapper.
type KeyStreamerAt interface {
StreamAt(key string, off int64, n int64) (io.ReadCloser, int64, error)
}

type keyReaderAtWrapper struct {
KeyStreamerAt
}

func (w keyReaderAtWrapper) ReadAt(key string, p []byte, off int64) (int, int64, error) {
r, size, err := w.KeyStreamerAt.StreamAt(key, off, int64(len(p)))
if err != nil {
return 0, size, err
}
defer r.Close()
n, err := io.ReadFull(r, p)
if err == io.ErrUnexpectedEOF {
err = io.EOF
}
return n, size, err
}

// BlockCacher is the interface that wraps block caching functionality
//
// Add inserts data to the cache for the given key and blockID.
@@ -109,11 +135,44 @@ type Adapter struct {
numCachedBlocks int
cache BlockCacher
reader KeyReaderAt
canStream bool
splitRanges bool
sizeCache *lru.Cache
retries int
}

func (a *Adapter) srcStreamAt(key string, off int64, n int64) (io.ReadCloser, error) {
s := a.reader.(KeyStreamerAt)
try := 1
delay := 100 * time.Millisecond
var r io.ReadCloser
var tot int64
var err error
for {
r, tot, err = s.StreamAt(key, off, n)
if err != nil && try <= a.retries && errs.Temporary(err) {
try++
time.Sleep(delay)
delay *= 2
continue
}
break
}
if off == 0 {
if err != nil {
if errors.Is(err, syscall.ENOENT) {
a.sizeCache.Add(key, int64(-1))
}
if errors.Is(err, io.EOF) {
a.sizeCache.Add(key, tot)
}
} else {
a.sizeCache.Add(key, tot)
}
}
return r, err
}

func (a *Adapter) srcReadAt(key string, buf []byte, off int64) (int, error) {
try := 1
delay := 100 * time.Millisecond
@@ -335,11 +394,13 @@ const (
)

func NewAdapter(reader KeyReaderAt, opts ...AdapterOption) (*Adapter, error) {
_, okStream := reader.(KeyStreamerAt)
bc := &Adapter{
blockSize: DefaultBlockSize,
numCachedBlocks: DefaultNumCachedBlocks,
reader: reader,
splitRanges: false,
canStream: okStream,
retries: 5,
}
for _, o := range opts {
@@ -384,7 +445,40 @@ func (a *Adapter) getRange(key string, rng blockRange) ([][]byte, error) {
nToFetch++
}
}
if nToFetch == len(blocks) {
if nToFetch == len(blocks) && a.canStream {
r, err := a.srcStreamAt(key, rng.start*a.blockSize, (rng.end-rng.start+1)*a.blockSize)
if err != nil {
for i := rng.start; i <= rng.end; i++ {
blockID := a.blockKey(key, i)
a.blmu.Unlock(blockID)
}
return nil, err
}
defer r.Close()
for bid := int64(0); bid <= rng.end-rng.start; bid++ {
blockID := a.blockKey(key, bid+rng.start)
buf := make([]byte, a.blockSize)
n, err := io.ReadFull(r, buf)
if err == io.ErrUnexpectedEOF {
err = io.EOF
}
if err == nil || err == io.EOF {
blocks[bid] = buf[:n]
a.cache.Add(key, uint(rng.start+bid), blocks[bid])
}
if err != nil {
for i := rng.start + bid; i <= rng.end; i++ {
a.blmu.Unlock(a.blockKey(key, i))
}
if err == io.EOF {
break
}
return nil, err
}
a.blmu.Unlock(blockID)
}
return blocks, nil
} else if nToFetch == len(blocks) && !a.canStream {
buf := make([]byte, (rng.end-rng.start+1)*a.blockSize)
n, err := a.srcReadAt(key, buf, rng.start*a.blockSize)
if err != nil && !errors.Is(err, io.EOF) {
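
A minimal sketch of a handler implementing the new interface may help illustrate the contract. Everything here is illustrative and not part of the PR: memHandler, its in-memory map and the mem:// key are made-up names; only the StreamAt/ReadAt signatures and their conventions (total size returned alongside the data, io.EOF for a range starting at or past the end, syscall.ENOENT with size -1 for a missing key) come from the diff above.

package main

import (
	"bytes"
	"fmt"
	"io"
	"syscall"
)

// memHandler is a hypothetical in-memory handler, used only to illustrate
// the KeyStreamerAt contract introduced by this PR.
type memHandler struct {
	objects map[string][]byte
}

// StreamAt follows the conventions shown in the diff: it returns the total
// object size alongside the stream, io.EOF when the requested offset is at or
// past the end of the object, and syscall.ENOENT (size -1) for a missing key.
func (h memHandler) StreamAt(key string, off int64, n int64) (io.ReadCloser, int64, error) {
	data, ok := h.objects[key]
	if !ok {
		return nil, -1, syscall.ENOENT
	}
	size := int64(len(data))
	if off >= size {
		return nil, size, io.EOF
	}
	end := off + n
	if end > size {
		end = size
	}
	return io.NopCloser(bytes.NewReader(data[off:end])), size, nil
}

// ReadAt mirrors keyReaderAtWrapper above, making memHandler a KeyReaderAt as
// well; passed to NewAdapter, the KeyStreamerAt assertion would then enable
// the streaming path added to getRange.
func (h memHandler) ReadAt(key string, p []byte, off int64) (int, int64, error) {
	r, size, err := h.StreamAt(key, off, int64(len(p)))
	if err != nil {
		return 0, size, err
	}
	defer r.Close()
	n, err := io.ReadFull(r, p)
	if err == io.ErrUnexpectedEOF {
		err = io.EOF
	}
	return n, size, err
}

func main() {
	h := memHandler{objects: map[string][]byte{"mem://demo": []byte("hello, streamer")}}
	r, size, err := h.StreamAt("mem://demo", 7, 8)
	if err != nil {
		panic(err)
	}
	defer r.Close()
	chunk, _ := io.ReadAll(r)
	fmt.Printf("size=%d chunk=%q\n", size, chunk) // size=15 chunk="streamer"
}
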
24 changes: 9 additions & 15 deletions gcs.go
@@ -52,7 +52,7 @@ func GCSBillingProject(projectID string) GCSOption {

// GCSHandle creates a KeyReaderAt suitable for constructing an Adapter
// that accesses objects on Google Cloud Storage
func GCSHandle(ctx context.Context, opts ...GCSOption) (*GCSHandler, error) {
func GCSHandle(ctx context.Context, opts ...GCSOption) (KeyReaderAt, error) {
handler := &GCSHandler{
ctx: ctx,
}
@@ -66,34 +66,28 @@ func GCSHandle(ctx context.Context, opts ...GCSOption) (*GCSHandler, error) {
}
handler.client = cl
}
return handler, nil
return keyReaderAtWrapper{handler}, nil
}

func (gcs *GCSHandler) ReadAt(key string, p []byte, off int64) (int, int64, error) {
func (gcs *GCSHandler) StreamAt(key string, off int64, n int64) (io.ReadCloser, int64, error) {
bucket, object := osuriparse("gs", key)
if len(bucket) == 0 || len(object) == 0 {
return 0, 0, fmt.Errorf("invalid key")
return nil, 0, fmt.Errorf("invalid key")
}
gbucket := gcs.client.Bucket(bucket)
if gcs.billingProjectID != "" {
gbucket = gbucket.UserProject(gcs.billingProjectID)
}
r, err := gbucket.Object(object).NewRangeReader(gcs.ctx, off, int64(len(p)))
//fmt.Printf("read %s [%d-%d]\n", key, off, off+int64(len(p)))
r, err := gbucket.Object(object).NewRangeReader(gcs.ctx, off, n)
if err != nil {
var gerr *googleapi.Error
if off > 0 && errors.As(err, &gerr) && gerr.Code == 416 {
return 0, 0, io.EOF
return nil, 0, io.EOF
}
if errors.Is(err, storage.ErrObjectNotExist) || errors.Is(err, storage.ErrBucketNotExist) {
return 0, -1, syscall.ENOENT
return nil, -1, syscall.ENOENT
}
return 0, 0, fmt.Errorf("new reader for gs://%s/%s: %w", bucket, object, err)
return nil, 0, fmt.Errorf("new reader for gs://%s/%s: %w", bucket, object, err)
}
defer r.Close()
n, err := io.ReadFull(r, p)
if err == io.ErrUnexpectedEOF {
err = io.EOF
}
return n, r.Attrs.Size, err
return r, r.Attrs.Size, err
}
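
With GCSHandle now returning the KeyReaderAt interface (the concrete handler wrapped in keyReaderAtWrapper) instead of *GCSHandler, callers can still read ranges directly or hand the handle to NewAdapter, which detects the underlying KeyStreamerAt and enables the streaming path. A rough usage sketch under stated assumptions: the package name and import path (osio) are not shown in this diff, and the gs:// bucket and object are placeholders.

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"

	"github.com/airbusgeo/osio" // assumed import path; not part of this diff
)

func main() {
	ctx := context.Background()

	// GCSHandle returns a KeyReaderAt; per this PR it is the GCS handler
	// wrapped in keyReaderAtWrapper, so ReadAt is served by StreamAt.
	gcs, err := osio.GCSHandle(ctx)
	if err != nil {
		log.Fatal(err)
	}

	// Direct range read through the wrapper: first 1 KiB of a placeholder object.
	buf := make([]byte, 1024)
	n, size, err := gcs.ReadAt("gs://my-bucket/my-object", buf, 0)
	if err != nil && !errors.Is(err, io.EOF) {
		log.Fatal(err)
	}
	fmt.Printf("read %d of %d bytes\n", n, size)

	// Or wire the handle into a caching Adapter: NewAdapter notices that the
	// wrapped handler implements KeyStreamerAt and takes the streaming path.
	if _, err := osio.NewAdapter(gcs); err != nil {
		log.Fatal(err)
	}
}
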
37 changes: 13 additions & 24 deletions http.go
@@ -58,7 +58,7 @@ func HTTPHeader(key, value string) HTTPOption {

// HTTPHandle creates a KeyReaderAt suitable for constructing an Adapter
// that accesses objects using the http protocol
func HTTPHandle(ctx context.Context, opts ...HTTPOption) (*HTTPHandler, error) {
func HTTPHandle(ctx context.Context, opts ...HTTPOption) (KeyReaderAt, error) {
handler := &HTTPHandler{
ctx: ctx,
}
@@ -68,63 +68,52 @@ func HTTPHandle(ctx context.Context, opts ...HTTPOption) (*HTTPHandler, error) {
if handler.client == nil {
handler.client = &http.Client{}
}
return handler, nil
return keyReaderAtWrapper{handler}, nil
}

func handleResponse(r *http.Response) (int, int64, error) {
func handleResponse(r *http.Response) (io.ReadCloser, int64, error) {
if r.StatusCode == 404 {
return 0, -1, syscall.ENOENT
return nil, -1, syscall.ENOENT
}
if r.StatusCode == 416 {
return 0, 0, io.EOF
return nil, 0, io.EOF
}
return 0, 0, fmt.Errorf("new reader for %s: status code %d", r.Request.URL.String(), r.StatusCode)
return nil, 0, fmt.Errorf("new reader for %s: status code %d", r.Request.URL.String(), r.StatusCode)
}

func (h *HTTPHandler) ReadAt(url string, p []byte, off int64) (int, int64, error) {
func (h *HTTPHandler) StreamAt(key string, off int64, n int64) (io.ReadCloser, int64, error) {
// HEAD request to get object size as it is not returned in range requests
var size int64
if off == 0 {
req, _ := http.NewRequest("HEAD", url, nil)
req, _ := http.NewRequest("HEAD", key, nil)
req = req.WithContext(h.ctx)
for _, mw := range h.requestMiddlewares {
mw(req)
}

r, err := h.client.Do(req)
if err != nil {
return 0, 0, fmt.Errorf("new reader for %s: %w", url, err)
return nil, 0, fmt.Errorf("new reader for %s: %w", key, err)
}
defer r.Body.Close()

if r.StatusCode != 200 {
return handleResponse(r)
}

size = r.ContentLength
}

// GET request to fetch range
req, _ := http.NewRequest("GET", url, nil)
req, _ := http.NewRequest("GET", key, nil)
req = req.WithContext(h.ctx)
for _, mw := range h.requestMiddlewares {
mw(req)
}
req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", off, off+int64(len(p))-1))

req.Header.Add("Range", fmt.Sprintf("bytes=%d-%d", off, off+n-1))
r, err := h.client.Do(req)
if err != nil {
return 0, 0, fmt.Errorf("new reader for %s: %w", url, err)
return nil, 0, fmt.Errorf("new reader for %s: %w", key, err)
}
defer r.Body.Close()

if r.StatusCode != 200 && r.StatusCode != 206 {
return handleResponse(r)
}

n, err := io.ReadFull(r.Body, p)
if err == io.ErrUnexpectedEOF {
err = io.EOF
}
return n, size, err
return r.Body, size, err
}
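
Because keyReaderAtWrapper embeds the KeyStreamerAt it wraps, the value returned by HTTPHandle also exposes StreamAt through a type assertion. A hedged sketch of streaming a byte range over HTTP: the osio import path, the custom header and the URL are assumptions, while the behaviour relied on (the HEAD request for the object size is only issued when off is 0, so a non-zero offset reports a size of 0) comes from the diff above.

package main

import (
	"context"
	"io"
	"log"
	"os"

	"github.com/airbusgeo/osio" // assumed import path; not part of this diff
)

func main() {
	ctx := context.Background()
	h, err := osio.HTTPHandle(ctx, osio.HTTPHeader("User-Agent", "osio-example"))
	if err != nil {
		log.Fatal(err)
	}

	// The returned KeyReaderAt is keyReaderAtWrapper, which embeds the HTTP
	// handler's KeyStreamerAt, so the streaming interface is reachable here.
	s, ok := h.(osio.KeyStreamerAt)
	if !ok {
		log.Fatal("handler does not implement KeyStreamerAt")
	}

	// off > 0: only the ranged GET is issued, and the returned size is 0.
	r, _, err := s.StreamAt("https://example.com/data.bin", 128, 4096)
	if err != nil {
		log.Fatal(err)
	}
	defer r.Close()
	if _, err := io.Copy(os.Stdout, r); err != nil {
		log.Fatal(err)
	}
}
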
32 changes: 12 additions & 20 deletions s3.go
@@ -53,7 +53,7 @@ func S3RequestPayer() S3Option {

// S3Handle creates a KeyReaderAt suitable for constructing an Adapter
// that accesses objects on Amazon S3
func S3Handle(ctx context.Context, opts ...S3Option) (*S3Handler, error) {
func S3Handle(ctx context.Context, opts ...S3Option) (KeyReaderAt, error) {
handler := &S3Handler{
ctx: ctx,
}
@@ -67,24 +67,24 @@ func S3Handle(ctx context.Context, opts ...S3Option) (*S3Handler, error) {
}
handler.client = s3.NewFromConfig(cfg)
}
return handler, nil
return keyReaderAtWrapper{handler}, nil
}

func handleS3ApiError(err error) (int, int64, error) {
func handleS3ApiError(err error) (io.ReadCloser, int64, error) {
var ae smithy.APIError
if errors.As(err, &ae) && ae.ErrorCode() == "InvalidRange" {
return 0, 0, io.EOF
return nil, 0, io.EOF
}
if errors.As(err, &ae) && (ae.ErrorCode() == "NoSuchBucket" || ae.ErrorCode() == "NoSuchKey" || ae.ErrorCode() == "NotFound") {
return 0, -1, syscall.ENOENT
return nil, -1, syscall.ENOENT
}
return 0, 0, err
return nil, 0, err
}

func (h *S3Handler) ReadAt(key string, p []byte, off int64) (int, int64, error) {
func (h *S3Handler) StreamAt(key string, off int64, n int64) (io.ReadCloser, int64, error) {
bucket, object := osuriparse("s3", key)
if len(bucket) == 0 || len(object) == 0 {
return 0, 0, fmt.Errorf("invalid key")
return nil, 0, fmt.Errorf("invalid key")
}

// HEAD request to get object size as it is not returned in range requests
@@ -96,7 +96,7 @@ func (h *S3Handler) ReadAt(key string, p []byte, off int64) (int, int64, error)
RequestPayer: types.RequestPayer(h.requestPayer),
})
if err != nil {
return handleS3ApiError(fmt.Errorf("new reader for %s: %w", key, err))
return handleS3ApiError(fmt.Errorf("new reader for s3://%s/%s: %w", bucket, object, err))
}
size = r.ContentLength
}
@@ -106,18 +106,10 @@
Bucket: &bucket,
Key: &object,
RequestPayer: types.RequestPayer(h.requestPayer),
Range: aws.String(fmt.Sprintf("bytes=%d-%d", off, off+int64(len(p))-1)),
Range: aws.String(fmt.Sprintf("bytes=%d-%d", off, off+n-1)),
})
if err != nil {
return handleS3ApiError(fmt.Errorf("new reader for %s: %w", key, err))
return handleS3ApiError(fmt.Errorf("new reader for s3://%s/%s: %w", bucket, object, err))
}
defer r.Body.Close()

n, err := io.ReadFull(r.Body, p)
if err == io.ErrUnexpectedEOF {
err = io.EOF
}

//fmt.Printf("read %s [%d-%d]\n", key, off, off+int64(len(p)))
return n, size, err
return r.Body, size, err
}
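
The S3 handler keeps the same error contract as the GCS and HTTP ones: syscall.ENOENT with a size of -1 for a missing bucket or key, io.EOF when the requested range starts past the end of the object. A rough sketch of probing an object through the returned KeyReaderAt; the osio import path and the s3:// key are assumptions, while S3RequestPayer is the option shown in this file.

package main

import (
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"syscall"

	"github.com/airbusgeo/osio" // assumed import path; not part of this diff
)

func main() {
	ctx := context.Background()
	s3h, err := osio.S3Handle(ctx, osio.S3RequestPayer())
	if err != nil {
		log.Fatal(err)
	}

	// A zero-offset read doubles as an existence and size probe: per the diff,
	// a missing bucket or key surfaces as syscall.ENOENT with size -1.
	buf := make([]byte, 16)
	_, size, err := s3h.ReadAt("s3://some-bucket/some-key", buf, 0)
	switch {
	case errors.Is(err, syscall.ENOENT):
		fmt.Println("object does not exist")
	case err != nil && !errors.Is(err, io.EOF):
		log.Fatal(err)
	default:
		fmt.Printf("object exists, size %d bytes\n", size)
	}
}
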