Skip to content

Commit 56fc358

Browse files
Merge pull request #1347 from cuixq:client
PiperOrigin-RevId: 808791051
2 parents f91afc5 + 1f30003 commit 56fc358

File tree

12 files changed

+96
-41
lines changed

12 files changed

+96
-41
lines changed

clients/datasource/maven_registry.go

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,11 +34,15 @@ import (
3434
"deps.dev/util/semver"
3535
"github.com/google/osv-scalibr/log"
3636
"golang.org/x/net/html/charset"
37+
"golang.org/x/oauth2/google"
3738
)
3839

3940
// mavenCentral holds the URL of Maven Central Repository.
4041
const mavenCentral = "https://repo.maven.apache.org/maven2"
4142

43+
// artifactRegistryScheme defines the scheme for Google Artifact Registry.
44+
const artifactRegistryScheme = "artifactregistry"
45+
4246
var errAPIFailed = errors.New("API query failed")
4347

4448
// MavenRegistryAPIClient defines a client to fetch metadata from a Maven registry.
@@ -47,6 +51,7 @@ type MavenRegistryAPIClient struct {
4751
registries []MavenRegistry // Additional registries specified to fetch projects
4852
registryAuths map[string]*HTTPAuthentication // Authentication for the registries keyed by registry ID. From settings.xml
4953
localRegistry string // The local directory that holds Maven manifests
54+
googleClient *http.Client // A client for authenticating with Google services, used for Artifact Registry.
5055

5156
// Cache fields
5257
mu *sync.Mutex
@@ -71,7 +76,7 @@ type MavenRegistry struct {
7176
}
7277

7378
// NewMavenRegistryAPIClient returns a new MavenRegistryAPIClient.
74-
func NewMavenRegistryAPIClient(registry MavenRegistry, localRegistry string) (*MavenRegistryAPIClient, error) {
79+
func NewMavenRegistryAPIClient(ctx context.Context, registry MavenRegistry, localRegistry string) (*MavenRegistryAPIClient, error) {
7580
if registry.URL == "" {
7681
registry.URL = mavenCentral
7782
registry.ID = "central"
@@ -90,14 +95,18 @@ func NewMavenRegistryAPIClient(registry MavenRegistry, localRegistry string) (*M
9095
globalSettings := ParseMavenSettings(globalMavenSettingsFile())
9196
userSettings := ParseMavenSettings(userMavenSettingsFile())
9297

93-
return &MavenRegistryAPIClient{
98+
client := &MavenRegistryAPIClient{
9499
// We assume only downloading releases is allowed on the default registry.
95100
defaultRegistry: registry,
96101
localRegistry: localRegistry,
97102
mu: &sync.Mutex{},
98103
responses: NewRequestCache[string, response](),
99104
registryAuths: MakeMavenAuth(globalSettings, userSettings),
100-
}, nil
105+
}
106+
if registry.Parsed.Scheme == artifactRegistryScheme {
107+
client.createGoogleClient(ctx)
108+
}
109+
return client, nil
101110
}
102111

103112
// SetLocalRegistry sets the local directory that stores the downloaded Maven manifests.
@@ -114,13 +123,14 @@ func (m *MavenRegistryAPIClient) WithoutRegistries() *MavenRegistryAPIClient {
114123
cacheTimestamp: m.cacheTimestamp,
115124
responses: m.responses,
116125
registryAuths: m.registryAuths,
126+
googleClient: m.googleClient,
117127
}
118128
}
119129

120130
// AddRegistry adds the given registry to the list of registries if it has not been added.
121-
func (m *MavenRegistryAPIClient) AddRegistry(registry MavenRegistry) error {
131+
func (m *MavenRegistryAPIClient) AddRegistry(ctx context.Context, registry MavenRegistry) error {
122132
if registry.ID == m.defaultRegistry.ID {
123-
return m.updateDefaultRegistry(registry)
133+
return m.updateDefaultRegistry(ctx, registry)
124134
}
125135

126136
for _, reg := range m.registries {
@@ -136,20 +146,42 @@ func (m *MavenRegistryAPIClient) AddRegistry(registry MavenRegistry) error {
136146

137147
registry.Parsed = u
138148
m.registries = append(m.registries, registry)
149+
if registry.Parsed.Scheme == artifactRegistryScheme {
150+
m.createGoogleClient(ctx)
151+
}
139152

140153
return nil
141154
}
142155

143-
func (m *MavenRegistryAPIClient) updateDefaultRegistry(registry MavenRegistry) error {
156+
func (m *MavenRegistryAPIClient) updateDefaultRegistry(ctx context.Context, registry MavenRegistry) error {
144157
u, err := url.Parse(registry.URL)
145158
if err != nil {
146159
return err
147160
}
148161
registry.Parsed = u
149162
m.defaultRegistry = registry
163+
if registry.Parsed.Scheme == artifactRegistryScheme {
164+
m.createGoogleClient(ctx)
165+
}
150166
return nil
151167
}
152168

169+
// createGoogleClient creates a client for authenticating with Google services.
170+
func (m *MavenRegistryAPIClient) createGoogleClient(ctx context.Context) {
171+
if m.googleClient != nil {
172+
return
173+
}
174+
// This is the scope that artifact-registry-go-tools uses.
175+
// https://github.com/GoogleCloudPlatform/artifact-registry-go-tools/blob/main/pkg/auth/auth.go
176+
client, err := google.DefaultClient(ctx, "https://www.googleapis.com/auth/cloud-platform")
177+
if err != nil {
178+
// We don't return an error here so that we can fall back to a regular http client.
179+
log.Warnf("failed to create Google default client, Artifact Registry access will be unavailable: %v", err)
180+
return
181+
}
182+
m.googleClient = client
183+
}
184+
153185
// GetRegistries returns the registries added to this client.
154186
func (m *MavenRegistryAPIClient) GetRegistries() (registries []MavenRegistry) {
155187
return m.registries
@@ -269,17 +301,28 @@ func (m *MavenRegistryAPIClient) get(ctx context.Context, auth *HTTPAuthenticati
269301
}
270302
}
271303

272-
u := registry.Parsed.JoinPath(paths...).String()
304+
httpClient := http.DefaultClient
305+
requestURL := *registry.Parsed
306+
isArtifactRegistry := requestURL.Scheme == artifactRegistryScheme
307+
if isArtifactRegistry {
308+
requestURL.Scheme = "https"
309+
// For Artifact Registry, use google.DefaultClient for ADC if available.
310+
if m.googleClient != nil {
311+
httpClient = m.googleClient
312+
}
313+
}
314+
315+
u := requestURL.JoinPath(paths...).String()
273316
resp, err := m.responses.Get(u, func() (response, error) {
274317
log.Infof("Fetching response from: %s", u)
275-
resp, err := auth.Get(ctx, http.DefaultClient, u)
318+
resp, err := auth.Get(ctx, httpClient, u)
276319
if err != nil {
277320
return response{}, fmt.Errorf("%w: Maven registry query failed: %w", errAPIFailed, err)
278321
}
279322
defer resp.Body.Close()
280323

281-
if !slices.Contains([]int{http.StatusOK, http.StatusNotFound, http.StatusUnauthorized}, resp.StatusCode) {
282-
// Only cache responses with Status OK, NotFound, or Unauthorized
324+
if !slices.Contains([]int{http.StatusOK, http.StatusNotFound, http.StatusUnauthorized, http.StatusForbidden}, resp.StatusCode) {
325+
// Only cache responses with Status OK, NotFound, Unauthorized, or Forbidden
283326
return response{}, fmt.Errorf("%w: Maven registry query status: %d", errAPIFailed, resp.StatusCode)
284327
}
285328

@@ -301,6 +344,10 @@ func (m *MavenRegistryAPIClient) get(ctx context.Context, auth *HTTPAuthenticati
301344
return err
302345
}
303346

347+
if resp.StatusCode == http.StatusForbidden && isArtifactRegistry {
348+
return fmt.Errorf("%w: Maven registry query status: %d (Forbidden). Please check your Application Default Credentials (ADC) have permission to read from %s", errAPIFailed, resp.StatusCode, registry.URL)
349+
}
350+
304351
if resp.StatusCode != http.StatusOK {
305352
return fmt.Errorf("%w: Maven registry query status: %d", errAPIFailed, resp.StatusCode)
306353
}

clients/datasource/maven_registry_auth_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func TestWithoutRegistriesMaintainsAuthData(t *testing.T) {
2828
srv := clienttest.NewMockHTTPServer(t)
2929

3030
// Create original client with multiple registries
31-
client, _ := NewMavenRegistryAPIClient(MavenRegistry{URL: srv.URL, ReleasesEnabled: true}, "")
31+
client, _ := NewMavenRegistryAPIClient(t.Context(), MavenRegistry{URL: srv.URL, ReleasesEnabled: true}, "")
3232
testRegistry1 := MavenRegistry{
3333
URL: "https://test1.maven.org/maven2/",
3434
ID: "test1",
@@ -39,10 +39,10 @@ func TestWithoutRegistriesMaintainsAuthData(t *testing.T) {
3939
ID: "test2",
4040
SnapshotsEnabled: true,
4141
}
42-
if err := client.AddRegistry(testRegistry1); err != nil {
42+
if err := client.AddRegistry(t.Context(), testRegistry1); err != nil {
4343
t.Fatalf("failed to add registry %s: %v", testRegistry1.URL, err)
4444
}
45-
if err := client.AddRegistry(testRegistry2); err != nil {
45+
if err := client.AddRegistry(t.Context(), testRegistry2); err != nil {
4646
t.Fatalf("failed to add registry %s: %v", testRegistry2.URL, err)
4747
}
4848

clients/datasource/maven_registry_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ import (
2828

2929
func TestGetProject(t *testing.T) {
3030
srv := clienttest.NewMockHTTPServer(t)
31-
client, _ := datasource.NewMavenRegistryAPIClient(datasource.MavenRegistry{URL: srv.URL, ReleasesEnabled: true}, "")
31+
client, _ := datasource.NewMavenRegistryAPIClient(t.Context(), datasource.MavenRegistry{URL: srv.URL, ReleasesEnabled: true}, "")
3232
srv.SetResponse(t, "org/example/x.y.z/1.0.0/x.y.z-1.0.0.pom", []byte(`
3333
<project>
3434
<groupId>org.example</groupId>
@@ -55,7 +55,7 @@ func TestGetProject(t *testing.T) {
5555

5656
func TestGetProjectSnapshot(t *testing.T) {
5757
srv := clienttest.NewMockHTTPServer(t)
58-
client, _ := datasource.NewMavenRegistryAPIClient(datasource.MavenRegistry{URL: srv.URL, SnapshotsEnabled: true}, "")
58+
client, _ := datasource.NewMavenRegistryAPIClient(t.Context(), datasource.MavenRegistry{URL: srv.URL, SnapshotsEnabled: true}, "")
5959
srv.SetResponse(t, "org/example/x.y.z/3.3.1-SNAPSHOT/maven-metadata.xml", []byte(`
6060
<metadata>
6161
<groupId>org.example</groupId>
@@ -107,7 +107,7 @@ func TestGetProjectSnapshot(t *testing.T) {
107107

108108
func TestMultipleRegistry(t *testing.T) {
109109
dft := clienttest.NewMockHTTPServer(t)
110-
client, _ := datasource.NewMavenRegistryAPIClient(datasource.MavenRegistry{URL: dft.URL, ReleasesEnabled: true}, "")
110+
client, _ := datasource.NewMavenRegistryAPIClient(t.Context(), datasource.MavenRegistry{URL: dft.URL, ReleasesEnabled: true}, "")
111111
dft.SetResponse(t, "org/example/x.y.z/maven-metadata.xml", []byte(`
112112
<metadata>
113113
<groupId>org.example</groupId>
@@ -138,7 +138,7 @@ func TestMultipleRegistry(t *testing.T) {
138138
`))
139139

140140
srv := clienttest.NewMockHTTPServer(t)
141-
if err := client.AddRegistry(datasource.MavenRegistry{URL: srv.URL, ReleasesEnabled: true}); err != nil {
141+
if err := client.AddRegistry(t.Context(), datasource.MavenRegistry{URL: srv.URL, ReleasesEnabled: true}); err != nil {
142142
t.Fatalf("failed to add registry %s: %v", srv.URL, err)
143143
}
144144
srv.SetResponse(t, "org/example/x.y.z/maven-metadata.xml", []byte(`
@@ -197,7 +197,7 @@ func TestMultipleRegistry(t *testing.T) {
197197

198198
func TestUpdateDefaultRegistry(t *testing.T) {
199199
dft := clienttest.NewMockHTTPServer(t)
200-
client, _ := datasource.NewMavenRegistryAPIClient(datasource.MavenRegistry{URL: dft.URL, ReleasesEnabled: true}, "")
200+
client, _ := datasource.NewMavenRegistryAPIClient(t.Context(), datasource.MavenRegistry{URL: dft.URL, ReleasesEnabled: true}, "")
201201
dft.SetResponse(t, "org/example/x.y.z/maven-metadata.xml", []byte(`
202202
<metadata>
203203
<groupId>org.example</groupId>
@@ -222,7 +222,7 @@ func TestUpdateDefaultRegistry(t *testing.T) {
222222
}
223223

224224
srv := clienttest.NewMockHTTPServer(t)
225-
if err := client.AddRegistry(datasource.MavenRegistry{URL: srv.URL, ID: "default", ReleasesEnabled: true}); err != nil {
225+
if err := client.AddRegistry(t.Context(), datasource.MavenRegistry{URL: srv.URL, ID: "default", ReleasesEnabled: true}); err != nil {
226226
t.Fatalf("failed to add registry %s: %v", srv.URL, err)
227227
}
228228
srv.SetResponse(t, "org/example/x.y.z/maven-metadata.xml", []byte(`
@@ -252,7 +252,7 @@ func TestUpdateDefaultRegistry(t *testing.T) {
252252
func TestMavenLocalRegistry(t *testing.T) {
253253
tempDir := t.TempDir()
254254
srv := clienttest.NewMockHTTPServer(t)
255-
client, _ := datasource.NewMavenRegistryAPIClient(datasource.MavenRegistry{URL: srv.URL, ReleasesEnabled: true}, tempDir)
255+
client, _ := datasource.NewMavenRegistryAPIClient(t.Context(), datasource.MavenRegistry{URL: srv.URL, ReleasesEnabled: true}, tempDir)
256256
path := "org/example/x.y.z/1.0.0/x.y.z-1.0.0.pom"
257257
resp := []byte(`
258258
<project>

clients/resolution/client.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,16 @@
1616
package resolution
1717

1818
import (
19+
"context"
20+
1921
"deps.dev/util/resolve"
2022
)
2123

2224
// ClientWithRegistries is a resolve.Client that allows package registries to be added.
2325
type ClientWithRegistries interface {
2426
resolve.Client
2527
// AddRegistries adds the specified registries to fetch data.
26-
AddRegistries(registries []Registry) error
28+
AddRegistries(ctx context.Context, registries []Registry) error
2729
}
2830

2931
// Registry is the interface of a registry to fetch data.

clients/resolution/combined_native_client.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ func NewCombinedNativeClient(opts CombinedNativeClientOptions) (*CombinedNativeC
5050

5151
// Version returns metadata of a version specified by the VersionKey.
5252
func (c *CombinedNativeClient) Version(ctx context.Context, vk resolve.VersionKey) (resolve.Version, error) {
53-
client, err := c.clientForSystem(vk.System)
53+
client, err := c.clientForSystem(ctx, vk.System)
5454
if err != nil {
5555
return resolve.Version{}, err
5656
}
@@ -59,7 +59,7 @@ func (c *CombinedNativeClient) Version(ctx context.Context, vk resolve.VersionKe
5959

6060
// Versions returns all the available versions of the package specified by the given PackageKey.
6161
func (c *CombinedNativeClient) Versions(ctx context.Context, pk resolve.PackageKey) ([]resolve.Version, error) {
62-
client, err := c.clientForSystem(pk.System)
62+
client, err := c.clientForSystem(ctx, pk.System)
6363
if err != nil {
6464
return nil, err
6565
}
@@ -68,7 +68,7 @@ func (c *CombinedNativeClient) Versions(ctx context.Context, pk resolve.PackageK
6868

6969
// Requirements returns requirements of a version specified by the VersionKey.
7070
func (c *CombinedNativeClient) Requirements(ctx context.Context, vk resolve.VersionKey) ([]resolve.RequirementVersion, error) {
71-
client, err := c.clientForSystem(vk.System)
71+
client, err := c.clientForSystem(ctx, vk.System)
7272
if err != nil {
7373
return nil, err
7474
}
@@ -77,20 +77,20 @@ func (c *CombinedNativeClient) Requirements(ctx context.Context, vk resolve.Vers
7777

7878
// MatchingVersions returns versions matching the requirement specified by the VersionKey.
7979
func (c *CombinedNativeClient) MatchingVersions(ctx context.Context, vk resolve.VersionKey) ([]resolve.Version, error) {
80-
client, err := c.clientForSystem(vk.System)
80+
client, err := c.clientForSystem(ctx, vk.System)
8181
if err != nil {
8282
return nil, err
8383
}
8484
return client.MatchingVersions(ctx, vk)
8585
}
8686

8787
// AddRegistries adds registries to the MavenRegistryClient.
88-
func (c *CombinedNativeClient) AddRegistries(registries []Registry) error {
88+
func (c *CombinedNativeClient) AddRegistries(ctx context.Context, registries []Registry) error {
8989
// TODO(#541): Currently only MavenRegistryClient supports adding registries.
9090
// We might need to add support for PyPIRegistryClient.
9191
// But this AddRegistries method should take a system as input,
9292
// so that we can add registries to the corresponding client.
93-
client, err := c.clientForSystem(resolve.Maven)
93+
client, err := c.clientForSystem(ctx, resolve.Maven)
9494
if err != nil {
9595
return err
9696
}
@@ -99,10 +99,10 @@ func (c *CombinedNativeClient) AddRegistries(registries []Registry) error {
9999
// Currently should not happen.
100100
return nil
101101
}
102-
return regCl.AddRegistries(registries)
102+
return regCl.AddRegistries(ctx, registries)
103103
}
104104

105-
func (c *CombinedNativeClient) clientForSystem(sys resolve.System) (resolve.Client, error) {
105+
func (c *CombinedNativeClient) clientForSystem(ctx context.Context, sys resolve.System) (resolve.Client, error) {
106106
c.mu.Lock()
107107
defer c.mu.Unlock()
108108

@@ -114,7 +114,7 @@ func (c *CombinedNativeClient) clientForSystem(sys resolve.System) (resolve.Clie
114114
if localRegistry != "" {
115115
localRegistry = filepath.Join(c.opts.LocalRegistry, "maven")
116116
}
117-
c.mavenRegistryClient, err = NewMavenRegistryClient(c.opts.MavenRegistry, localRegistry)
117+
c.mavenRegistryClient, err = NewMavenRegistryClient(ctx, c.opts.MavenRegistry, localRegistry)
118118
if err != nil {
119119
return nil, err
120120
}

clients/resolution/maven_registry_client.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ type MavenRegistryClient struct {
3333
}
3434

3535
// NewMavenRegistryClient makes a new MavenRegistryClient.
36-
func NewMavenRegistryClient(remote, local string) (*MavenRegistryClient, error) {
37-
client, err := datasource.NewMavenRegistryAPIClient(datasource.MavenRegistry{URL: remote, ReleasesEnabled: true}, local)
36+
func NewMavenRegistryClient(ctx context.Context, remote, local string) (*MavenRegistryClient, error) {
37+
client, err := datasource.NewMavenRegistryAPIClient(ctx, datasource.MavenRegistry{URL: remote, ReleasesEnabled: true}, local)
3838
if err != nil {
3939
return nil, err
4040
}
@@ -174,13 +174,13 @@ func (c *MavenRegistryClient) MatchingVersions(ctx context.Context, vk resolve.V
174174
}
175175

176176
// AddRegistries adds registries to the MavenRegistryClient.
177-
func (c *MavenRegistryClient) AddRegistries(registries []Registry) error {
177+
func (c *MavenRegistryClient) AddRegistries(ctx context.Context, registries []Registry) error {
178178
for _, reg := range registries {
179179
specific, ok := reg.(datasource.MavenRegistry)
180180
if !ok {
181181
return errors.New("invalid Maven registry information")
182182
}
183-
if err := c.api.AddRegistry(specific); err != nil {
183+
if err := c.api.AddRegistry(ctx, specific); err != nil {
184184
return err
185185
}
186186
}

extractor/filesystem/language/java/pomxmlnet/pomxmlnet.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ type Config struct {
5858
// NewConfig returns the configuration given the URL of the Maven registry to fetch metadata.
5959
func NewConfig(remote, local string) Config {
6060
// No need to check errors since we are using the default Maven Central URL.
61-
mavenClient, _ := datasource.NewMavenRegistryAPIClient(datasource.MavenRegistry{
61+
mavenClient, _ := datasource.NewMavenRegistryAPIClient(context.Background(), datasource.MavenRegistry{
6262
URL: remote,
6363
ReleasesEnabled: true,
6464
}, local)
@@ -117,7 +117,7 @@ func (e Extractor) Extract(ctx context.Context, input *filesystem.ScanInput) (in
117117
// Clear the registries that may be from other extraction.
118118
e.MavenClient = e.MavenClient.WithoutRegistries()
119119
for _, repo := range project.Repositories {
120-
if err := e.MavenClient.AddRegistry(datasource.MavenRegistry{
120+
if err := e.MavenClient.AddRegistry(ctx, datasource.MavenRegistry{
121121
URL: string(repo.URL),
122122
ID: string(repo.ID),
123123
ReleasesEnabled: repo.Releases.Enabled.Boolean(),
@@ -150,7 +150,7 @@ func (e Extractor) Extract(ctx context.Context, input *filesystem.ScanInput) (in
150150
clientRegs[i] = reg
151151
}
152152
if cl, ok := e.depClient.(resolution.ClientWithRegistries); ok {
153-
if err := cl.AddRegistries(clientRegs); err != nil {
153+
if err := cl.AddRegistries(ctx, clientRegs); err != nil {
154154
return inventory.Inventory{}, err
155155
}
156156
}

extractor/filesystem/language/java/pomxmlnet/pomxmlnet_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ func TestExtractor_Extract_WithMockServer(t *testing.T) {
478478
</project>
479479
`))
480480

481-
apiClient, err := datasource.NewMavenRegistryAPIClient(datasource.MavenRegistry{URL: srv.URL, ReleasesEnabled: true}, "")
481+
apiClient, err := datasource.NewMavenRegistryAPIClient(t.Context(), datasource.MavenRegistry{URL: srv.URL, ReleasesEnabled: true}, "")
482482
if err != nil {
483483
t.Fatalf("%v", err)
484484
}

0 commit comments

Comments
 (0)