From 7c9f0c755554134fe3e7488268b0cda910eabff8 Mon Sep 17 00:00:00 2001
From: Evans Mungai <evans@replicated.com>
Date: Thu, 21 Sep 2023 12:30:53 +0100
Subject: [PATCH] chore: fix how OCI images are pulled

---
 pkg/oci/pull.go      | 78 +++++++++++++++++++++++++++++++++-----------
 pkg/oci/pull_test.go | 56 +++++++++++++++++++++++++++++++
 2 files changed, 115 insertions(+), 19 deletions(-)
 create mode 100644 pkg/oci/pull_test.go

diff --git a/pkg/oci/pull.go b/pkg/oci/pull.go
index 6618262a7..c9f851462 100644
--- a/pkg/oci/pull.go
+++ b/pkg/oci/pull.go
@@ -4,6 +4,7 @@ import (
 	"context"
 	"fmt"
 	"net/http"
+	"net/url"
 	"path/filepath"
 	"strings"
 
@@ -11,6 +12,7 @@ import (
 	"github.com/pkg/errors"
 	"github.com/replicatedhq/troubleshoot/internal/util"
 	"github.com/replicatedhq/troubleshoot/pkg/version"
+	"k8s.io/klog/v2"
 	"oras.land/oras-go/pkg/auth"
 	dockerauth "oras.land/oras-go/pkg/auth/docker"
 	"oras.land/oras-go/pkg/content"
@@ -27,14 +29,39 @@ var (
 )
 
 func PullPreflightFromOCI(uri string) ([]byte, error) {
-	return pullFromOCI(uri, "replicated.preflight.spec", "replicated-preflight")
+	return pullFromOCI(context.Background(), uri, "replicated.preflight.spec", "replicated-preflight")
 }
 
 func PullSupportBundleFromOCI(uri string) ([]byte, error) {
-	return pullFromOCI(uri, "replicated.supportbundle.spec", "replicated-supportbundle")
+	return pullFromOCI(context.Background(), uri, "replicated.supportbundle.spec", "replicated-supportbundle")
 }
 
-func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error) {
+// PullSpecsFromOCI pulls the preflight and support bundle specs stored at uri
+// and returns the raw contents of each one that was pulled successfully.
+func PullSpecsFromOCI(ctx context.Context, uri string) ([]string, error) {
+	rawSpecs := []string{}
+
+	// A missing preflight spec (ErrNoRelease) is tolerated; other errors abort.
+	rawPreflight, err := pullFromOCI(ctx, uri, "replicated.preflight.spec", "replicated-preflight")
+	if err != nil && !errors.Is(err, ErrNoRelease) {
+		return nil, err
+	}
+	if err == nil {
+		rawSpecs = append(rawSpecs, string(rawPreflight))
+	}
+
+	rawSupportBundle, err := pullFromOCI(ctx, uri, "replicated.supportbundle.spec", "replicated-supportbundle")
+	switch {
+	case err == nil:
+		rawSpecs = append(rawSpecs, string(rawSupportBundle))
+	case len(rawSpecs) == 0:
+		// Nothing was pulled at all; with a preflight spec in hand the error is dropped.
+		return nil, err
+	}
+	return rawSpecs, nil
+}
+
+func pullFromOCI(ctx context.Context, uri string, mediaType string, imageName string) ([]byte, error) {
 	// helm credentials
 	helmCredentialsFile := filepath.Join(util.HomeDir(), HelmCredentialsFileBasename)
 	dockerauthClient, err := dockerauth.NewClientWithDockerFallback(helmCredentialsFile)
@@ -52,6 +79,7 @@ func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error)
 		return nil, errors.Wrap(err, "failed to create resolver")
 	}
 
+	// TODO: How do we handle "not found" cases?
 	memoryStore := content.NewMemory()
 	allowedMediaTypes := []string{
 		mediaType,
@@ -60,24 +88,13 @@ func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error)
 	var descriptors, layers []ocispec.Descriptor
 	registryStore := content.Registry{Resolver: resolver}
 
-	// remove the oci://
-	uri = strings.TrimPrefix(uri, "oci://")
-
-	uriParts := strings.Split(uri, ":")
-	uri = fmt.Sprintf("%s/%s", uriParts[0], imageName)
-
-	if len(uriParts) > 1 {
-		uri = fmt.Sprintf("%s:%s", uri, uriParts[1])
-	} else {
-		uri = fmt.Sprintf("%s:latest", uri)
-	}
-
-	parsedRef, err := registry.ParseReference(uri)
+	parsedRef, err := toRegistryRef(uri)
 	if err != nil {
-		return nil, errors.Wrap(err, "failed to parse reference")
+		return nil, err
 	}
+	klog.V(1).Infof("Pulling OCI image from %q", parsedRef.String())
 
-	manifest, err := oras.Copy(context.TODO(), registryStore, parsedRef.String(), memoryStore, "",
+	manifest, err := oras.Copy(ctx, registryStore, parsedRef.String(), memoryStore, "",
 		oras.WithPullEmptyNameAllowed(),
 		oras.WithAllowedMediaTypes(allowedMediaTypes),
 		oras.WithLayerDescriptors(func(l []ocispec.Descriptor) {
@@ -94,7 +111,7 @@ func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error)
 	descriptors = append(descriptors, manifest)
 	descriptors = append(descriptors, layers...)
 
-	// expect 1 descriptor
+	// expect 2 descriptors
 	if len(descriptors) != 2 {
 		return nil, fmt.Errorf("expected 2 descriptor, got %d", len(descriptors))
 	}
@@ -120,3 +137,26 @@ func pullFromOCI(uri string, mediaType string, imageName string) ([]byte, error)
 
 	return matchingSpec, nil
 }
+
+// toRegistryRef converts an "oci://host/path[:tag]" URI into a registry
+// reference, using the "latest" tag when raw does not carry one.
+func toRegistryRef(raw string) (registry.Reference, error) {
+	u, err := url.Parse(raw)
+	if err != nil {
+		return registry.Reference{}, err
+	}
+	// url.Parse does not fail on a missing scheme, so always check it here.
+	if u.Scheme != "oci" {
+		return registry.Reference{}, fmt.Errorf("%q is an invalid OCI registry scheme", u.Scheme)
+	}
+
+	// Split on the first colon only, so a "tag@sha256:..." suffix stays whole.
+	parts := strings.SplitN(u.EscapedPath(), ":", 2)
+	tag := "latest"
+	if len(parts) > 1 {
+		tag = parts[1]
+	}
+	// Rebuild the reference without the oci:// scheme.
+	uri := fmt.Sprintf("%s%s:%s", u.Host, parts[0], tag)
+	return registry.ParseReference(uri)
+}
diff --git a/pkg/oci/pull_test.go b/pkg/oci/pull_test.go
new file mode 100644
index 000000000..e8d98c578
--- /dev/null
+++ b/pkg/oci/pull_test.go
@@ -0,0 +1,56 @@
+package oci
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+// Test_toRegistryRef checks oci:// URI to registry reference conversion.
+func Test_toRegistryRef(t *testing.T) {
+	tests := []struct {
+		name    string
+		uri     string
+		want    string
+		wantErr bool
+	}{
+		{
+			name: "valid uri",
+			uri:  "oci://localhost/replicated-preflight",
+			want: "localhost/replicated-preflight:latest",
+		},
+		{
+			name: "valid uri with port",
+			uri:  "oci://localhost:5000/replicated-preflight",
+			want: "localhost:5000/replicated-preflight:latest",
+		},
+		{
+			name: "valid uri with tag",
+			uri:  "oci://localhost:5000/replicated-preflight:v4",
+			want: "localhost:5000/replicated-preflight:v4",
+		},
+		{
+			name:    "invalid uri - missing scheme",
+			uri:     "localhost:5000/replicated-preflight:v4",
+			wantErr: true,
+		},
+		{
+			name:    "invalid uri - wrong scheme",
+			uri:     "https://localhost:5000/replicated-preflight:v4",
+			wantErr: true,
+		},
+		{
+			name:    "empty uri",
+			wantErr: true,
+		},
+	}
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			got, err := toRegistryRef(tt.uri)
+			require.Equalf(t, tt.wantErr, err != nil, "toRegistryRef() error = %v, wantErr %v", err, tt.wantErr)
+			gotStr := got.String()
+			assert.Equalf(t, tt.want, gotStr, "toRegistryRef() = %v, want %v", gotStr, tt.want)
+		})
+	}
+}