diff --git a/config/config.go b/config/config.go index f6fe04cfcca..2b42b8392ae 100644 --- a/config/config.go +++ b/config/config.go @@ -560,6 +560,7 @@ type Cloud struct { Secret string LocationSecret string // Deprecated: Use LocationSecrets LocationSecrets []LocationSecret + APIKey APIKey LocationID string PrimaryOrgID string MachineID string @@ -588,6 +589,7 @@ type cloudData struct { LocationSecret string `json:"location_secret"` LocationSecrets []LocationSecret `json:"location_secrets"` + APIKey APIKey `json:"api_key"` LocationID string `json:"location_id"` PrimaryOrgID string `json:"primary_org_id"` MachineID string `json:"machine_id"` @@ -603,6 +605,33 @@ type cloudData struct { TLSPrivateKey string `json:"tls_private_key"` } +// APIKey is the cloud app authentication credential +type APIKey struct { + ID string `json:"id"` + Key string `json:"key"` +} + +// IsFullySet returns true if an APIKey has both the ID and Key fields set. +func (a APIKey) IsFullySet() bool { + return a.ID != "" && a.Key != "" +} + +// IsPartiallySet returns true if only one of the ID or Key fields are set. +func (a APIKey) IsPartiallySet() bool { + return (a.ID == "" && a.Key != "") || (a.ID != "" && a.Key == "") +} + +// GetCloudCredsDialOpt returns a dial option with the cloud credentials for this cloud config. +// API keys are always preferred over robot secrets. If neither are set, nil is returned. +func (config *Cloud) GetCloudCredsDialOpt() rpc.DialOption { + if config.APIKey.IsFullySet() { + return rpc.WithEntityCredentials(config.APIKey.ID, rpc.Credentials{rutils.CredentialsTypeAPIKey, config.APIKey.Key}) + } else if config.Secret != "" { + return rpc.WithEntityCredentials(config.ID, rpc.Credentials{rutils.CredentialsTypeRobotSecret, config.Secret}) + } + return nil +} + // UnmarshalJSON unmarshals JSON data into this config. func (config *Cloud) UnmarshalJSON(data []byte) error { var temp cloudData @@ -614,6 +643,7 @@ func (config *Cloud) UnmarshalJSON(data []byte) error { Secret: temp.Secret, LocationSecret: temp.LocationSecret, LocationSecrets: temp.LocationSecrets, + APIKey: temp.APIKey, LocationID: temp.LocationID, PrimaryOrgID: temp.PrimaryOrgID, MachineID: temp.MachineID, @@ -643,6 +673,7 @@ func (config Cloud) MarshalJSON() ([]byte, error) { Secret: config.Secret, LocationSecret: config.LocationSecret, LocationSecrets: config.LocationSecrets, + APIKey: config.APIKey, LocationID: config.LocationID, PrimaryOrgID: config.PrimaryOrgID, MachineID: config.MachineID, @@ -673,8 +704,10 @@ func (config *Cloud) Validate(path string, fromCloud bool) error { if config.LocalFQDN == "" { return resource.NewConfigValidationFieldRequiredError(path, "local_fqdn") } - } else if config.Secret == "" { - return resource.NewConfigValidationFieldRequiredError(path, "secret") + } else if config.APIKey.IsPartiallySet() { + return resource.NewConfigValidationFieldRequiredError(path, "api_key") + } else if config.Secret == "" && !config.APIKey.IsFullySet() { + return resource.NewConfigValidationFieldRequiredError(path, "api_key") } if config.RefreshInterval == 0 { config.RefreshInterval = 10 * time.Second @@ -1060,6 +1093,7 @@ func CreateTLSWithCert(cfg *Config) (*tls.Config, error) { func ProcessConfig(in *Config) (*Config, error) { out := *in var selfCreds *rpc.Credentials + var selfAuthEntity string if in.Cloud != nil { // We expect a cloud config from app to always contain a non-empty `TLSCertificate` field. // We do this empty string check just to cope with unexpected input, such as cached configs @@ -1071,7 +1105,13 @@ func ProcessConfig(in *Config) (*Config, error) { } out.Network.TLSConfig = tlsConfig } - selfCreds = &rpc.Credentials{rutils.CredentialsTypeRobotSecret, in.Cloud.Secret} + if in.Cloud.APIKey.IsFullySet() { + selfCreds = &rpc.Credentials{rutils.CredentialsTypeAPIKey, in.Cloud.APIKey.Key} + selfAuthEntity = in.Cloud.APIKey.ID + } else { + selfCreds = &rpc.Credentials{rutils.CredentialsTypeRobotSecret, in.Cloud.Secret} + selfAuthEntity = in.Cloud.ID + } } out.Remotes = make([]Remote, len(in.Remotes)) @@ -1086,7 +1126,7 @@ func ProcessConfig(in *Config) (*Config, error) { } remoteCopy.Auth.Managed = true remoteCopy.Auth.SignalingServerAddress = in.Cloud.SignalingAddress - remoteCopy.Auth.SignalingAuthEntity = in.Cloud.ID + remoteCopy.Auth.SignalingAuthEntity = selfAuthEntity remoteCopy.Auth.SignalingCreds = selfCreds } out.Remotes[idx] = remoteCopy diff --git a/config/config_test.go b/config/config_test.go index a6b7183a8cf..6a18b4e7e34 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -201,14 +201,24 @@ func TestConfigEnsure(t *testing.T) { invalidCloud.Cloud.ID = "some_id" err = invalidCloud.Ensure(false, logger) test.That(t, err, test.ShouldNotBeNil) - test.That(t, resource.GetFieldFromFieldRequiredError(err), test.ShouldEqual, "secret") + test.That(t, resource.GetFieldFromFieldRequiredError(err), test.ShouldEqual, "api_key") err = invalidCloud.Ensure(true, logger) test.That(t, err, test.ShouldNotBeNil) test.That(t, resource.GetFieldFromFieldRequiredError(err), test.ShouldEqual, "fqdn") invalidCloud.Cloud.Secret = "my_secret" test.That(t, invalidCloud.Ensure(false, logger), test.ShouldBeNil) test.That(t, invalidCloud.Ensure(true, logger), test.ShouldNotBeNil) + invalidCloud.Cloud.APIKey = config.APIKey{ID: "", Key: "key_value"} + err = invalidCloud.Ensure(false, logger) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, resource.GetFieldFromFieldRequiredError(err), test.ShouldEqual, "api_key") + err = invalidCloud.Ensure(true, logger) + test.That(t, err, test.ShouldNotBeNil) invalidCloud.Cloud.Secret = "" + invalidCloud.Cloud.APIKey = config.APIKey{ID: "key_id", Key: "key_value"} + test.That(t, invalidCloud.Ensure(false, logger), test.ShouldBeNil) + test.That(t, invalidCloud.Ensure(true, logger), test.ShouldNotBeNil) + invalidCloud.Cloud.APIKey = config.APIKey{} invalidCloud.Cloud.FQDN = "wooself" err = invalidCloud.Ensure(true, logger) test.That(t, err, test.ShouldNotBeNil) @@ -475,14 +485,24 @@ func TestConfigEnsurePartialStart(t *testing.T) { invalidCloud.Cloud.ID = "some_id" err = invalidCloud.Ensure(false, logger) test.That(t, err, test.ShouldNotBeNil) - test.That(t, resource.GetFieldFromFieldRequiredError(err), test.ShouldEqual, "secret") + test.That(t, resource.GetFieldFromFieldRequiredError(err), test.ShouldEqual, "api_key") err = invalidCloud.Ensure(true, logger) test.That(t, err, test.ShouldNotBeNil) test.That(t, resource.GetFieldFromFieldRequiredError(err), test.ShouldEqual, "fqdn") invalidCloud.Cloud.Secret = "my_secret" test.That(t, invalidCloud.Ensure(false, logger), test.ShouldBeNil) test.That(t, invalidCloud.Ensure(true, logger), test.ShouldNotBeNil) + invalidCloud.Cloud.APIKey = config.APIKey{ID: "", Key: "key_value"} + err = invalidCloud.Ensure(false, logger) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, resource.GetFieldFromFieldRequiredError(err), test.ShouldEqual, "api_key") + err = invalidCloud.Ensure(true, logger) + test.That(t, err, test.ShouldNotBeNil) invalidCloud.Cloud.Secret = "" + invalidCloud.Cloud.APIKey = config.APIKey{ID: "key_id", Key: "key_value"} + test.That(t, invalidCloud.Ensure(false, logger), test.ShouldBeNil) + test.That(t, invalidCloud.Ensure(true, logger), test.ShouldNotBeNil) + invalidCloud.Cloud.APIKey = config.APIKey{} invalidCloud.Cloud.FQDN = "wooself" err = invalidCloud.Ensure(true, logger) test.That(t, err, test.ShouldNotBeNil) diff --git a/config/diff.go b/config/diff.go index a289def5fa7..fb6e83b3f83 100644 --- a/config/diff.go +++ b/config/diff.go @@ -145,6 +145,9 @@ func prettyDiff(left, right Config) (string, error) { conf.Cloud.LocationSecrets[i].Secret = mask } } + if conf.Cloud.APIKey.Key != "" { + conf.Cloud.APIKey.Key = mask + } // Not really a secret but annoying to diff if conf.Cloud.TLSCertificate != "" { conf.Cloud.TLSCertificate = mask diff --git a/config/diff_test.go b/config/diff_test.go index 83f92c76a94..a7c78e875ba 100644 --- a/config/diff_test.go +++ b/config/diff_test.go @@ -604,6 +604,10 @@ func TestDiffSanitize(t *testing.T) { {ID: "id1", Secret: "sec1"}, {ID: "id2", Secret: "sec2"}, }, + APIKey: config.APIKey{ + ID: "api_key_id", + Key: "sec3", + }, TLSCertificate: "foo", TLSPrivateKey: "bar", } @@ -675,6 +679,7 @@ func TestDiffSanitize(t *testing.T) { test.That(t, diffStr, test.ShouldNotContainSubstring, cloud1.LocationSecret) test.That(t, diffStr, test.ShouldNotContainSubstring, cloud1.LocationSecrets[0].Secret) test.That(t, diffStr, test.ShouldNotContainSubstring, cloud1.LocationSecrets[1].Secret) + test.That(t, diffStr, test.ShouldNotContainSubstring, cloud1.APIKey.Key) test.That(t, diffStr, test.ShouldNotContainSubstring, cloud1.TLSCertificate) test.That(t, diffStr, test.ShouldNotContainSubstring, cloud1.TLSPrivateKey) for _, hdlr := range auth1.Handlers { diff --git a/config/reader.go b/config/reader.go index 948e7db14de..241a477ed5a 100644 --- a/config/reader.go +++ b/config/reader.go @@ -730,14 +730,12 @@ func CreateNewGRPCClient(ctx context.Context, cloudCfg *Cloud, logger logging.Lo } dialOpts := make([]rpc.DialOption, 0, 2) - // Only add credentials when secret is set. - if cloudCfg.Secret != "" { - dialOpts = append(dialOpts, rpc.WithEntityCredentials(cloudCfg.ID, - rpc.Credentials{ - Type: rutils.CredentialsTypeRobotSecret, - Payload: cloudCfg.Secret, - }, - )) + + cloudCreds := cloudCfg.GetCloudCredsDialOpt() + + // Only add credentials when they are set. + if cloudCreds != nil { + dialOpts = append(dialOpts, cloudCreds) } if u.Scheme == "http" { diff --git a/config/reader_test.go b/config/reader_test.go index 5e9e654c5ae..720558cf26a 100644 --- a/config/reader_test.go +++ b/config/reader_test.go @@ -68,7 +68,7 @@ func TestFromReader(t *testing.T) { fakeServer.StoreDeviceConfig(robotPartID, protoConfig, certProto) appAddress := fmt.Sprintf("http://%s", fakeServer.Addr().String()) - appConn, err := grpc.NewAppConn(ctx, appAddress, secret, robotPartID, logger) + appConn, err := grpc.NewAppConn(ctx, appAddress, robotPartID, cloudResponse.GetCloudCredsDialOpt(), logger) test.That(t, err, test.ShouldBeNil) defer appConn.Close() cfgText := fmt.Sprintf(`{"cloud":{"id":%q,"app_address":%q,"secret":%q}}`, robotPartID, appAddress, secret) @@ -120,7 +120,7 @@ func TestFromReader(t *testing.T) { fakeServer.StoreDeviceConfig(robotPartID, nil, nil) appAddress := fmt.Sprintf("http://%s", fakeServer.Addr().String()) - appConn, err := grpc.NewAppConn(ctx, appAddress, secret, robotPartID, logger) + appConn, err := grpc.NewAppConn(ctx, appAddress, robotPartID, cachedCloud.GetCloudCredsDialOpt(), logger) test.That(t, err, test.ShouldBeNil) defer appConn.Close() cfgText := fmt.Sprintf(`{"cloud":{"id":%q,"app_address":%q,"secret":%q}}`, robotPartID, appAddress, secret) @@ -162,7 +162,7 @@ func TestFromReader(t *testing.T) { fakeServer.StoreDeviceConfig(robotPartID, protoConfig, certProto) appAddress := fmt.Sprintf("http://%s", fakeServer.Addr().String()) - appConn, err := grpc.NewAppConn(ctx, appAddress, secret, robotPartID, logger) + appConn, err := grpc.NewAppConn(ctx, appAddress, robotPartID, cloudResponse.GetCloudCredsDialOpt(), logger) test.That(t, err, test.ShouldBeNil) defer appConn.Close() cfgText := fmt.Sprintf(`{"cloud":{"id":%q,"app_address":%q,"secret":%q}}`, robotPartID, appAddress, secret) @@ -207,7 +207,7 @@ func TestStoreToCache(t *testing.T) { } cfg.Cloud = cloud - appConn, err := grpc.NewAppConn(ctx, cloud.AppAddress, cloud.Secret, cloud.ID, logger) + appConn, err := grpc.NewAppConn(ctx, cloud.AppAddress, cloud.ID, cfg.Cloud.GetCloudCredsDialOpt(), logger) test.That(t, err, test.ShouldBeNil) defer appConn.Close() diff --git a/config/watcher_test.go b/config/watcher_test.go index 0872bfc3e6c..b6ec09584a8 100644 --- a/config/watcher_test.go +++ b/config/watcher_test.go @@ -275,8 +275,8 @@ func TestNewWatcherCloud(t *testing.T) { storeConfigInServer(confToReturn) - appConn, err := grpc.NewAppConn(context.Background(), confToReturn.Cloud.AppAddress, confToReturn.Cloud.Secret, confToReturn.Cloud.ID, - logger) + appConn, err := grpc.NewAppConn( + context.Background(), confToReturn.Cloud.AppAddress, confToReturn.Cloud.ID, confToReturn.Cloud.GetCloudCredsDialOpt(), logger) test.That(t, err, test.ShouldBeNil) defer appConn.Close() watcher, err := config.NewWatcher(context.Background(), &config.Config{Cloud: newCloudConf()}, logger, appConn) diff --git a/examples/customresources/demos/remoteserver/server.go b/examples/customresources/demos/remoteserver/server.go index af9bab48ff1..569d7dae6af 100644 --- a/examples/customresources/demos/remoteserver/server.go +++ b/examples/customresources/demos/remoteserver/server.go @@ -44,7 +44,8 @@ func mainWithArgs(ctx context.Context, args []string, logger logging.Logger) (er var appConn rpc.ClientConn if cfg.Cloud != nil && cfg.Cloud.AppAddress != "" { - appConn, err = grpc.NewAppConn(ctx, cfg.Cloud.AppAddress, cfg.Cloud.Secret, cfg.Cloud.ID, logger) + cloudCreds := cfg.Cloud.GetCloudCredsDialOpt() + appConn, err = grpc.NewAppConn(ctx, cfg.Cloud.AppAddress, cfg.Cloud.ID, cloudCreds, logger) if err != nil { return nil } diff --git a/grpc/app_conn.go b/grpc/app_conn.go index dcbab571c26..45f07f7b31c 100644 --- a/grpc/app_conn.go +++ b/grpc/app_conn.go @@ -34,7 +34,7 @@ type AppConn struct { // establishing a connection to App will continue to occur, however, in a background Goroutine. These attempts will continue until a // connection is made. If `cloud` is nil, an `AppConn` with a nil underlying connection will return, and the background dialer will not // start. -func NewAppConn(ctx context.Context, appAddress, secret, id string, logger logging.Logger) (rpc.ClientConn, error) { +func NewAppConn(ctx context.Context, appAddress, partID string, cloudCreds rpc.DialOption, logger logging.Logger) (rpc.ClientConn, error) { appConn := &AppConn{ReconfigurableClientConn: &ReconfigurableClientConn{Logger: logger.Sublogger("app_conn")}} grpcURL, err := url.Parse(appAddress) @@ -42,13 +42,17 @@ func NewAppConn(ctx context.Context, appAddress, secret, id string, logger loggi return nil, err } - dialOpts := dialOpts(secret, id) + dialOpts := make([]rpc.DialOption, 0, 2) + + if cloudCreds != nil { + dialOpts = append(dialOpts, cloudCreds) + } if grpcURL.Scheme == "http" { dialOpts = append(dialOpts, rpc.WithInsecure()) } - ctxWithTimeout, ctxWithTimeoutCancel := contextutils.GetTimeoutCtx(ctx, true, id, logger) + ctxWithTimeout, ctxWithTimeoutCancel := contextutils.GetTimeoutCtx(ctx, true, partID, logger) defer ctxWithTimeoutCancel() // there will always be a deadline if deadline, ok := ctxWithTimeout.Deadline(); ok { @@ -131,17 +135,3 @@ func (ac *AppConn) Close() error { return ac.ReconfigurableClientConn.Close() } - -func dialOpts(secret, id string) []rpc.DialOption { - dialOpts := make([]rpc.DialOption, 0, 2) - // Only add credentials when secret is set. - if secret != "" { - dialOpts = append(dialOpts, rpc.WithEntityCredentials(id, - rpc.Credentials{ - Type: "robot-secret", - Payload: secret, - }, - )) - } - return dialOpts -} diff --git a/internal/cloud/service_test.go b/internal/cloud/service_test.go index 9085f3e7175..639a900c8ab 100644 --- a/internal/cloud/service_test.go +++ b/internal/cloud/service_test.go @@ -53,7 +53,7 @@ func TestCloudManaged(t *testing.T) { AppAddress: fmt.Sprintf("http://%s", addr), } - appConn, err := grpc.NewAppConn(context.Background(), conf.AppAddress, "", "", logger) + appConn, err := grpc.NewAppConn(context.Background(), conf.AppAddress, conf.ID, nil, logger) test.That(t, err, test.ShouldBeNil) svc := cloud.NewCloudConnectionService(conf, appConn, logger) @@ -123,7 +123,11 @@ func TestCloudManagedWithAuth(t *testing.T) { logger, rpc.WithAuthHandler( utils.CredentialsTypeRobotSecret, - rpc.MakeSimpleMultiAuthHandler([]string{"foo"}, []string{"bar"}), + rpc.MakeSimpleMultiAuthHandler([]string{"secret_foo"}, []string{"secret_bar"}), + ), + rpc.WithAuthHandler( + utils.CredentialsTypeAPIKey, + rpc.MakeSimpleMultiAuthHandler([]string{"api_foo"}, []string{"api_bar"}), ), ) test.That(t, err, test.ShouldBeNil) @@ -142,100 +146,142 @@ func TestCloudManagedWithAuth(t *testing.T) { addr := server.InternalAddr().String() - conf := &config.Cloud{ - AppAddress: fmt.Sprintf("http://%s", addr), + testCases := []struct { + name string + config *config.Cloud + shouldAuth bool + }{ + { + name: "no credentials - should fail auth", + config: &config.Cloud{ + AppAddress: fmt.Sprintf("http://%s", addr), + }, + shouldAuth: false, + }, + { + name: "robot secret credentials - should succeed", + config: &config.Cloud{ + AppAddress: fmt.Sprintf("http://%s", addr), + ID: "secret_foo", + Secret: "secret_bar", + }, + shouldAuth: true, + }, + { + name: "API key credentials - should succeed", + config: &config.Cloud{ + AppAddress: fmt.Sprintf("http://%s", addr), + APIKey: config.APIKey{ + ID: "api_foo", + Key: "api_bar", + }, + }, + shouldAuth: true, + }, + { + name: "both credentials - API key should be prioritized", + config: &config.Cloud{ + AppAddress: fmt.Sprintf("http://%s", addr), + ID: "secret_foo", + Secret: "secret_bar", + APIKey: config.APIKey{ + ID: "api_foo", + Key: "api_bar", + }, + }, + shouldAuth: true, + }, + { + name: "valid robot secret with invalid API key - should fail", + config: &config.Cloud{ + AppAddress: fmt.Sprintf("http://%s", addr), + ID: "secret_foo", + Secret: "secret_bar", + APIKey: config.APIKey{ + ID: "invalid_foo", + Key: "invalid_bar", + }, + }, + shouldAuth: false, + }, } - appConn, err := grpc.NewAppConn(context.Background(), conf.AppAddress, "", "", logger) - test.That(t, err, test.ShouldBeNil) - - svc := cloud.NewCloudConnectionService(conf, appConn, logger) - id, conn1, err := svc.AcquireConnection(context.Background()) - test.That(t, err, test.ShouldBeNil) - test.That(t, id, test.ShouldBeEmpty) - test.That(t, conn1, test.ShouldEqual, appConn) - - id2, conn2, err := svc.AcquireConnection(context.Background()) - test.That(t, err, test.ShouldBeNil) - test.That(t, id2, test.ShouldBeEmpty) - test.That(t, conn2, test.ShouldEqual, appConn) - - echoClient := echopb.NewEchoServiceClient(conn1) - resp, err := echoClient.Echo(context.Background(), &echopb.EchoRequest{ - Message: "hello", - }) - test.That(t, resp, test.ShouldBeNil) - test.That(t, err, test.ShouldNotBeNil) - test.That(t, status.Code(err), test.ShouldEqual, codes.Unauthenticated) - - test.That(t, svc.Close(context.Background()), test.ShouldBeNil) - test.That(t, appConn.Close(), test.ShouldBeNil) - - conf = &config.Cloud{ - AppAddress: fmt.Sprintf("http://%s", addr), - ID: "foo", - Secret: "bar", + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + testCloudConnectionAuth(t, logger, tc.config, tc.shouldAuth) + }) } +} - appConn, err = grpc.NewAppConn(context.Background(), conf.AppAddress, conf.Secret, conf.ID, logger) +func testCloudConnectionAuth(t *testing.T, logger logging.Logger, conf *config.Cloud, shouldAuth bool) { + cloudCreds := conf.GetCloudCredsDialOpt() + appConn, err := grpc.NewAppConn(context.Background(), conf.AppAddress, conf.ID, cloudCreds, logger) test.That(t, err, test.ShouldBeNil) - svc = cloud.NewCloudConnectionService(conf, appConn, logger) - id, conn1, err = svc.AcquireConnection(context.Background()) + svc := cloud.NewCloudConnectionService(conf, appConn, logger) + _, conn1, err := svc.AcquireConnection(context.Background()) test.That(t, err, test.ShouldBeNil) - test.That(t, id, test.ShouldEqual, "foo") - test.That(t, conn1, test.ShouldNotBeNil) + test.That(t, conn1, test.ShouldEqual, appConn) - id2, conn2, err = svc.AcquireConnection(context.Background()) + _, conn2, err := svc.AcquireConnection(context.Background()) test.That(t, err, test.ShouldBeNil) - test.That(t, id2, test.ShouldEqual, "foo") test.That(t, conn2, test.ShouldEqual, appConn) echoClient1 := echopb.NewEchoServiceClient(conn1) echoClient2 := echopb.NewEchoServiceClient(conn2) - resp, err = echoClient1.Echo(context.Background(), &echopb.EchoRequest{ - Message: "hello", - }) - test.That(t, err, test.ShouldBeNil) - test.That(t, resp.Message, test.ShouldEqual, "hello") - - resp, err = echoClient2.Echo(context.Background(), &echopb.EchoRequest{ - Message: "hello", - }) - test.That(t, err, test.ShouldBeNil) - test.That(t, resp.Message, test.ShouldEqual, "hello") - - test.That(t, appConn.Close(), test.ShouldBeNil) - - // now "both" connections are closed - resp, err = echoClient1.Echo(context.Background(), &echopb.EchoRequest{ - Message: "hello", - }) - test.That(t, resp, test.ShouldBeNil) - test.That(t, err, test.ShouldNotBeNil) - test.That(t, err, test.ShouldEqual, grpc.ErrNotConnected) - - resp, err = echoClient2.Echo(context.Background(), &echopb.EchoRequest{ + // Test first echo call + resp, err := echoClient1.Echo(context.Background(), &echopb.EchoRequest{ Message: "hello", }) - test.That(t, resp, test.ShouldBeNil) - test.That(t, err, test.ShouldNotBeNil) - test.That(t, err, test.ShouldEqual, grpc.ErrNotConnected) - - id3, conn3, err := svc.AcquireConnection(context.Background()) - test.That(t, err, test.ShouldBeNil) - test.That(t, id3, test.ShouldEqual, "foo") - test.That(t, conn3, test.ShouldNotBeNil) - test.That(t, conn3, test.ShouldEqual, conn2) - echoClient3 := echopb.NewEchoServiceClient(conn3) - resp, err = echoClient3.Echo(context.Background(), &echopb.EchoRequest{ - Message: "hello", - }) - test.That(t, resp, test.ShouldBeNil) - test.That(t, err, test.ShouldNotBeNil) - test.That(t, err, test.ShouldEqual, grpc.ErrNotConnected) + if shouldAuth { + test.That(t, err, test.ShouldBeNil) + test.That(t, resp.Message, test.ShouldEqual, "hello") + + // Test second echo call + resp, err = echoClient2.Echo(context.Background(), &echopb.EchoRequest{ + Message: "hello", + }) + test.That(t, err, test.ShouldBeNil) + test.That(t, resp.Message, test.ShouldEqual, "hello") + + test.That(t, appConn.Close(), test.ShouldBeNil) + + // Test connection behavior after close + resp, err = echoClient1.Echo(context.Background(), &echopb.EchoRequest{ + Message: "hello", + }) + test.That(t, resp, test.ShouldBeNil) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err, test.ShouldEqual, grpc.ErrNotConnected) + + resp, err = echoClient2.Echo(context.Background(), &echopb.EchoRequest{ + Message: "hello", + }) + test.That(t, resp, test.ShouldBeNil) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err, test.ShouldEqual, grpc.ErrNotConnected) + + _, conn3, err := svc.AcquireConnection(context.Background()) + test.That(t, err, test.ShouldBeNil) + test.That(t, conn3, test.ShouldNotBeNil) + test.That(t, conn3, test.ShouldEqual, conn2) + + echoClient3 := echopb.NewEchoServiceClient(conn3) + resp, err = echoClient3.Echo(context.Background(), &echopb.EchoRequest{ + Message: "hello", + }) + test.That(t, resp, test.ShouldBeNil) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err, test.ShouldEqual, grpc.ErrNotConnected) + } else { + test.That(t, resp, test.ShouldBeNil) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, status.Code(err), test.ShouldEqual, codes.Unauthenticated) + + test.That(t, appConn.Close(), test.ShouldBeNil) + } test.That(t, svc.Close(context.Background()), test.ShouldBeNil) } diff --git a/logging/net_appender.go b/logging/net_appender.go index 0d3cdbdb5b2..c3e5a77049d 100644 --- a/logging/net_appender.go +++ b/logging/net_appender.go @@ -34,7 +34,7 @@ var ( type CloudConfig struct { AppAddress string ID string - Secret string + CloudCred rpc.DialOption } // NewNetAppender creates a NetAppender to send log events to the app backend. NetAppenders ought to @@ -516,14 +516,10 @@ func CreateNewGRPCClient(ctx context.Context, cloudCfg *CloudConfig, logger Logg } dialOpts := make([]rpc.DialOption, 0, 2) - // Only add credentials when secret is set. - if cloudCfg.Secret != "" { - dialOpts = append(dialOpts, rpc.WithEntityCredentials(cloudCfg.ID, - rpc.Credentials{ - Type: "robot-secret", - Payload: cloudCfg.Secret, - }, - )) + + // Only add credentials when they are set. + if cloudCfg.CloudCred != nil { + dialOpts = append(dialOpts, cloudCfg.CloudCred) } if grpcURL.Scheme == "http" { diff --git a/robot/impl/utils.go b/robot/impl/utils.go index 0c5313ad38d..107af7fc8cb 100644 --- a/robot/impl/utils.go +++ b/robot/impl/utils.go @@ -28,7 +28,8 @@ func setupLocalRobot( var conn rpc.ClientConn var err error if cfg.Cloud != nil && cfg.Cloud.AppAddress != "" { - conn, err = grpc.NewAppConn(ctx, cfg.Cloud.AppAddress, cfg.Cloud.Secret, cfg.Cloud.ID, logger.Sublogger("appconn")) + cloudCreds := cfg.Cloud.GetCloudCredsDialOpt() + conn, err = grpc.NewAppConn(ctx, cfg.Cloud.AppAddress, cfg.Cloud.ID, cloudCreds, logger.Sublogger("appconn")) test.That(t, err, test.ShouldBeNil) } diff --git a/robot/packages/cloud_package_manager.go b/robot/packages/cloud_package_manager.go index e66d8368f7b..f780e06e975 100644 --- a/robot/packages/cloud_package_manager.go +++ b/robot/packages/cloud_package_manager.go @@ -199,7 +199,7 @@ func (m *cloudManager) Sync(ctx context.Context, packages []config.PackageConfig return "", "", err } - return m.downloadFileWithChecksum(ctx, url, dstPath, m.cloudConfig.ID, m.cloudConfig.Secret) + return m.downloadFileWithChecksum(ctx, url, dstPath) }, ) if err != nil { @@ -399,12 +399,19 @@ func (m *cloudManager) downloadFileWithChecksum( ctx context.Context, rawURL string, downloadPath string, - partID string, - partSecret string, ) (string, string, error) { getReq, err := http.NewRequestWithContext(ctx, http.MethodHead, rawURL, nil) - getReq.Header.Add("part_id", partID) - getReq.Header.Add("secret", partSecret) + + headers := make(http.Header) + if m.cloudConfig.APIKey.IsFullySet() { + headers.Add("key_id", m.cloudConfig.APIKey.ID) + headers.Add("key", m.cloudConfig.APIKey.Key) + } else { + headers.Add("part_id", m.cloudConfig.ID) + headers.Add("secret", m.cloudConfig.Secret) + } + getReq.Header = headers + if err != nil { return "", "", err } @@ -438,7 +445,7 @@ func (m *cloudManager) downloadFileWithChecksum( g := getter.HttpGetter{ MaxBytes: maxBytesForTesting, - Header: http.Header{"part_id": []string{partID}, "secret": []string{partSecret}}, + Header: headers, Client: &m.httpClient, } g.SetClient(&getter.Client{Ctx: ctx}) diff --git a/robot/packages/cloud_package_manager_test.go b/robot/packages/cloud_package_manager_test.go index f873dc86534..640134656ea 100644 --- a/robot/packages/cloud_package_manager_test.go +++ b/robot/packages/cloud_package_manager_test.go @@ -656,7 +656,7 @@ func TestDownloadFileWithChecksum(t *testing.T) { t.Run("complete", func(t *testing.T) { dest := filepath.Join(packagesDir, "download1") - _, _, err := pm.downloadFileWithChecksum(t.Context(), server.URL+"/download1", dest, "id", "secret") + _, _, err := pm.downloadFileWithChecksum(t.Context(), server.URL+"/download1", dest) test.That(t, err, test.ShouldBeNil) }) @@ -667,11 +667,11 @@ func TestDownloadFileWithChecksum(t *testing.T) { dest := filepath.Join(packagesDir, "download2") // first attempt fails midway because of maxBytesForTesting - _, _, err := pm.downloadFileWithChecksum(t.Context(), server.URL+"/download2", dest, "id", "secret") + _, _, err := pm.downloadFileWithChecksum(t.Context(), server.URL+"/download2", dest) test.That(t, err.Error(), test.ShouldContainSubstring, "short write") // second attempt finishes - _, _, err = pm.downloadFileWithChecksum(t.Context(), server.URL+"/download2", dest, "id", "secret") + _, _, err = pm.downloadFileWithChecksum(t.Context(), server.URL+"/download2", dest) test.That(t, err, test.ShouldBeNil) // check the length test.That(t, handler.lengths[len(handler.lengths)-1], test.ShouldEqual, diff --git a/robot/web/options/options.go b/robot/web/options/options.go index 59370ff6cb6..32b7e66c7d6 100644 --- a/robot/web/options/options.go +++ b/robot/web/options/options.go @@ -149,10 +149,8 @@ func FromConfig(cfg *config.Config) (Options, error) { }, }) - signalingDialOpts := []rpc.DialOption{rpc.WithEntityCredentials( - cfg.Cloud.ID, - rpc.Credentials{utils.CredentialsTypeRobotSecret, cfg.Cloud.Secret}, - )} + cloudCreds := cfg.Cloud.GetCloudCredsDialOpt() + signalingDialOpts := []rpc.DialOption{cloudCreds} if cfg.Cloud.SignalingInsecure { signalingDialOpts = append(signalingDialOpts, rpc.WithInsecure()) } diff --git a/web/server/entrypoint.go b/web/server/entrypoint.go index b33e6ee5278..5d6f0d3e3f4 100644 --- a/web/server/entrypoint.go +++ b/web/server/entrypoint.go @@ -230,8 +230,10 @@ func RunServer(ctx context.Context, args []string, _ logging.Logger) (err error) // serialized manner if cfgFromDisk.Cloud != nil { cloud := cfgFromDisk.Cloud + + cloudCreds := cfgFromDisk.Cloud.GetCloudCredsDialOpt() appConnLogger := networkingLogger.Sublogger("app_connection") - appConn, err = grpc.NewAppConn(ctx, cloud.AppAddress, cloud.Secret, cloud.ID, appConnLogger) + appConn, err = grpc.NewAppConn(ctx, cloud.AppAddress, cloud.ID, cloudCreds, appConnLogger) if err != nil { return err } @@ -240,7 +242,8 @@ func RunServer(ctx context.Context, args []string, _ logging.Logger) (err error) // if SignalingAddress is specified and different from AppAddress, create a new connection to it. Otherwise reuse appConn. if cloud.SignalingAddress != "" && cloud.SignalingAddress != cloud.AppAddress { signalingConnLogger := networkingLogger.Sublogger("signaling_connection") - signalingConn, err = grpc.NewAppConn(ctx, cloud.SignalingAddress, cloud.Secret, cloud.ID, signalingConnLogger) + signalingConn, err = grpc.NewAppConn( + ctx, cloud.SignalingAddress, cloud.ID, cloudCreds, signalingConnLogger) if err != nil { return err } @@ -248,25 +251,25 @@ func RunServer(ctx context.Context, args []string, _ logging.Logger) (err error) } else { signalingConn = appConn } - } - // Start remote logging with config from disk. - // This is to ensure we make our best effort to write logs for failures loading the remote config. - if cfgFromDisk.Cloud != nil && cfgFromDisk.Cloud.AppAddress != "" { - netAppender, err := logging.NewNetAppender( - &logging.CloudConfig{ - AppAddress: cfgFromDisk.Cloud.AppAddress, - ID: cfgFromDisk.Cloud.ID, - Secret: cfgFromDisk.Cloud.Secret, - }, - appConn, false, logging.NewLogger("NetAppender-loggerWithoutNet"), - ) - if err != nil { - return err - } - defer netAppender.Close() + // Start remote logging with config from disk. + // This is to ensure we make our best effort to write logs for failures loading the remote config. + if cloud.AppAddress != "" { + netAppender, err := logging.NewNetAppender( + &logging.CloudConfig{ + AppAddress: cloud.AppAddress, + ID: cloud.ID, + CloudCred: cloudCreds, + }, + appConn, false, logging.NewLogger("NetAppender-loggerWithoutNet"), + ) + if err != nil { + return err + } + defer netAppender.Close() - registry.AddAppenderToAll(netAppender) + registry.AddAppenderToAll(netAppender) + } } // log startup info and run network checks after netlogger is initialized so it's captured in cloud machine logs. logStartupInfo(rootLogger)