diff --git a/cmd/attested-get/main.go b/cmd/attested-get/main.go index 4521587..2897c71 100644 --- a/cmd/attested-get/main.go +++ b/cmd/attested-get/main.go @@ -109,7 +109,7 @@ func runClient(cCtx *cli.Context) (err error) { } // Create validators based on the attestation type - attestationType, err := proxy.ParseAttestationType(attestationTypeStr) + attestationType, err := proxy.ParseAttestationType(log, attestationTypeStr) if err != nil { log.With("attestation-type", attestationType).Error("invalid attestation-type passed, see --help") return err diff --git a/cmd/proxy-client/main.go b/cmd/proxy-client/main.go index d6a34ed..fcefa28 100644 --- a/cmd/proxy-client/main.go +++ b/cmd/proxy-client/main.go @@ -25,11 +25,6 @@ var flags []cli.Flag = []cli.Flag{ Value: "https://localhost:80", Usage: "address to proxy requests to", }, - &cli.StringFlag{ - Name: "server-attestation-type", - Value: string(proxy.AttestationAzureTDX), - Usage: "type of attestation to expect and verify (" + proxy.AvailableAttestationTypes + ")", - }, &cli.StringFlag{ Name: "server-measurements", Usage: "optional path to JSON measurements enforced on the server", @@ -45,8 +40,8 @@ var flags []cli.Flag = []cli.Flag{ }, &cli.StringFlag{ Name: "client-attestation-type", - Value: string(proxy.AttestationNone), - Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + ")", + Value: "auto", + Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + "). If not set, automatically detected.", }, &cli.BoolFlag{ Name: "log-json", @@ -96,32 +91,27 @@ func runClient(cCtx *cli.Context) error { Version: common.Version, }) - if cCtx.String("server-attestation-type") != "none" && verifyTLS { - log.Error("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)") - return errors.New("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)") + if serverMeasurements != "" && verifyTLS { + log.Error("invalid combination of --verify-tls and --server-measurements passed (cannot add server measurements and verify default TLS at the same time)") + return errors.New("invalid combination of --verify-tls and --server-measurements passed (cannot add server measurements and verify default TLS at the same time)") } - clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type")) + // Auto-detect client attestation type if not specified + clientAttestationType, err := proxy.ParseAttestationType(log, cCtx.String("client-attestation-type")) if err != nil { log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help") return err } - serverAttestationType, err := proxy.ParseAttestationType(cCtx.String("server-attestation-type")) - if err != nil { - log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help") - return err - } - issuer, err := proxy.CreateAttestationIssuer(log, clientAttestationType) if err != nil { log.Error("could not create attestation issuer", "err", err) return err } - validators, err := proxy.CreateAttestationValidators(log, serverAttestationType, serverMeasurements) + validators, err := proxy.CreateAttestationValidatorsFromFile(log, serverMeasurements) if err != nil { - log.Error("could not create attestation validators", "err", err) + log.Error("could not create attestation validators from file", "err", err) return err } diff --git a/cmd/proxy-server/main.go b/cmd/proxy-server/main.go index 618ba88..dca97cf 100644 --- a/cmd/proxy-server/main.go +++ b/cmd/proxy-server/main.go @@ -40,8 +40,8 @@ var flags []cli.Flag = []cli.Flag{ &cli.StringFlag{ Name: "server-attestation-type", EnvVars: []string{"SERVER_ATTESTATION_TYPE"}, - Value: string(proxy.AttestationAzureTDX), - Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + ")", + Value: "auto", + Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + "). If not set, automatically detected.", }, &cli.StringFlag{ Name: "tls-certificate-path", @@ -53,12 +53,6 @@ var flags []cli.Flag = []cli.Flag{ EnvVars: []string{"TLS_PRIVATE_KEY_PATH"}, Usage: "Path to private key file for the certificate. Only valid with --tls-certificate-path", }, - &cli.StringFlag{ - Name: "client-attestation-type", - EnvVars: []string{"CLIENT_ATTESTATION_TYPE"}, - Value: string(proxy.AttestationNone), - Usage: "type of attestation to expect and verify (" + proxy.AvailableAttestationTypes + ")", - }, &cli.StringFlag{ Name: "client-measurements", EnvVars: []string{"CLIENT_MEASUREMENTS"}, @@ -132,21 +126,16 @@ func runServer(cCtx *cli.Context) error { return errors.New("not all of --tls-certificate-path and --tls-private-key-path specified") } - serverAttestationType, err := proxy.ParseAttestationType(serverAttestationTypeFlag) + // Auto-detect server attestation type if not specified + serverAttestationType, err := proxy.ParseAttestationType(log, cCtx.String("server-attestation-type")) if err != nil { log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help") return err } - clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type")) - if err != nil { - log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help") - return err - } - - validators, err := proxy.CreateAttestationValidators(log, clientAttestationType, clientMeasurements) + validators, err := proxy.CreateAttestationValidatorsFromFile(log, clientMeasurements) if err != nil { - log.Error("could not create attestation validators", "err", err) + log.Error("could not create attestation validators from file", "err", err) return err } diff --git a/proxy/atls_config.go b/proxy/atls_config.go index eee5fb2..67e8b40 100644 --- a/proxy/atls_config.go +++ b/proxy/atls_config.go @@ -23,15 +23,39 @@ import ( type AttestationType string const ( + AttestationAuto AttestationType = "auto" AttestationNone AttestationType = "none" AttestationAzureTDX AttestationType = "azure-tdx" AttestationDCAPTDX AttestationType = "dcap-tdx" ) -const AvailableAttestationTypes string = "none, azure-tdx, dcap-tdx" +const AvailableAttestationTypes string = "auto, none, azure-tdx, dcap-tdx" + +// DetectAttestationType determines the attestation type based on environment +func DetectAttestationType() AttestationType { + // Check for TDX device files - these indicate DCAP TDX + _, tdxErr1 := os.Stat("/dev/tdx-guest") + _, tdxErr2 := os.Stat("/dev/tdx_guest") + if tdxErr1 == nil || tdxErr2 == nil { + return AttestationDCAPTDX + } + + // Try Azure TDX attestation - if it works, we're in Azure TDX + issuer := azure_tdx.NewIssuer(nil) // nil logger for detection + _, err := issuer.Issue(context.Background(), []byte("test"), []byte("test")) + if err == nil { + return AttestationAzureTDX + } + + return AttestationNone +} -func ParseAttestationType(attestationType string) (AttestationType, error) { +func ParseAttestationType(log *slog.Logger, attestationType string) (AttestationType, error) { switch attestationType { + case string(AttestationAuto): + detectedType := DetectAttestationType() + log.With("detected_attestation", detectedType).Info("Auto-detected attestation type") + return detectedType, nil case string(AttestationNone): return AttestationNone, nil case string(AttestationAzureTDX): @@ -56,8 +80,8 @@ func CreateAttestationIssuer(log *slog.Logger, attestationType AttestationType) } } -func CreateAttestationValidators(log *slog.Logger, attestationType AttestationType, jsonMeasurementsPath string) ([]atls.Validator, error) { - if attestationType == AttestationNone { +func CreateAttestationValidatorsFromFile(log *slog.Logger, jsonMeasurementsPath string) ([]atls.Validator, error) { + if jsonMeasurementsPath == "" { return nil, nil } @@ -72,26 +96,42 @@ func CreateAttestationValidators(log *slog.Logger, attestationType AttestationTy return nil, err } - switch attestationType { - case AttestationAzureTDX: - validators := []atls.Validator{} - for _, measurement := range parsedMeasurements { + // Group validators by attestation type + validatorsByType := make(map[AttestationType][]atls.Validator) + + for _, measurement := range parsedMeasurements { + attestationType, err := ParseAttestationType(log, measurement.AttestationType) + if err != nil { + return nil, fmt.Errorf("invalid attestation type %s in measurements file", measurement.AttestationType) + } + + switch attestationType { + case AttestationAzureTDX: attConfig := config.DefaultForAzureTDX() attConfig.SetMeasurements(measurement.Measurements) - validators = append(validators, azure_tdx.NewValidator(attConfig, AttestationLogger{Log: log})) - } - return []atls.Validator{NewMultiValidator(validators)}, nil - case AttestationDCAPTDX: - validators := []atls.Validator{} - for _, measurement := range parsedMeasurements { + validatorsByType[attestationType] = append( + validatorsByType[attestationType], + azure_tdx.NewValidator(attConfig, AttestationLogger{Log: log}), + ) + case AttestationDCAPTDX: attConfig := &config.QEMUTDX{Measurements: measurements.DefaultsFor(cloudprovider.QEMU, variant.QEMUTDX{})} attConfig.SetMeasurements(measurement.Measurements) - validators = append(validators, dcap_tdx.NewValidator(attConfig, AttestationLogger{Log: log})) + validatorsByType[attestationType] = append( + validatorsByType[attestationType], + dcap_tdx.NewValidator(attConfig, AttestationLogger{Log: log}), + ) + default: + return nil, fmt.Errorf("unsupported attestation type %s in measurements file", measurement.AttestationType) } - return []atls.Validator{NewMultiValidator(validators)}, nil - default: - return nil, errors.New("invalid attestation-type passed in") } + + // Create a MultiValidator for each attestation type + var validators []atls.Validator + for _, typeValidators := range validatorsByType { + validators = append(validators, NewMultiValidator(typeValidators)) + } + + return validators, nil } func ExtractMeasurementsFromExtension(ext *pkix.Extension, v variant.Variant) (map[uint32][]byte, error) {