From e36bf5ba5390a452837904af6101467da3b07fc9 Mon Sep 17 00:00:00 2001 From: fnerdman Date: Tue, 28 Jan 2025 09:22:04 +0100 Subject: [PATCH 1/2] feat: auto detect tee env --- cmd/proxy-client/main.go | 15 +++++++++++---- cmd/proxy-server/main.go | 17 ++++++++++++----- proxy/atls_config.go | 19 +++++++++++++++++++ 3 files changed, 42 insertions(+), 9 deletions(-) diff --git a/cmd/proxy-client/main.go b/cmd/proxy-client/main.go index d6a34ed..dfe4383 100644 --- a/cmd/proxy-client/main.go +++ b/cmd/proxy-client/main.go @@ -45,8 +45,8 @@ var flags []cli.Flag = []cli.Flag{ }, &cli.StringFlag{ Name: "client-attestation-type", - Value: string(proxy.AttestationNone), - Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + ")", + Value: "", + Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + "). If not set, automatically detected.", }, &cli.BoolFlag{ Name: "log-json", @@ -101,10 +101,17 @@ func runClient(cCtx *cli.Context) error { return errors.New("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)") } + // Auto-detect client attestation type if not specified clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type")) if err != nil { - log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help") - return err + // If parsing fails and no type was specified, use auto-detection + if cCtx.String("client-attestation-type") == "" { + clientAttestationType = proxy.DetectAttestationType() + log.With("detected_attestation", clientAttestationType).Info("Auto-detected client attestation type") + } else { + log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help") + return err + } } serverAttestationType, err := proxy.ParseAttestationType(cCtx.String("server-attestation-type")) diff --git a/cmd/proxy-server/main.go b/cmd/proxy-server/main.go index 618ba88..20f0ca0 100644 --- a/cmd/proxy-server/main.go +++ b/cmd/proxy-server/main.go @@ -40,8 +40,8 @@ var flags []cli.Flag = []cli.Flag{ &cli.StringFlag{ Name: "server-attestation-type", EnvVars: []string{"SERVER_ATTESTATION_TYPE"}, - Value: string(proxy.AttestationAzureTDX), - Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + ")", + Value: "", + Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + "). If not set, automatically detected.", }, &cli.StringFlag{ Name: "tls-certificate-path", @@ -132,10 +132,17 @@ func runServer(cCtx *cli.Context) error { return errors.New("not all of --tls-certificate-path and --tls-private-key-path specified") } - serverAttestationType, err := proxy.ParseAttestationType(serverAttestationTypeFlag) + // Auto-detect server attestation type if not specified + serverAttestationType, err := proxy.ParseAttestationType(cCtx.String("server-attestation-type")) if err != nil { - log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help") - return err + // If parsing fails and no type was specified, use auto-detection + if cCtx.String("server-attestation-type") == "" { + serverAttestationType = proxy.DetectAttestationType() + log.With("detected_attestation", serverAttestationType).Info("Auto-detected server attestation type") + } else { + log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help") + return err + } } clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type")) diff --git a/proxy/atls_config.go b/proxy/atls_config.go index eee5fb2..d44ca37 100644 --- a/proxy/atls_config.go +++ b/proxy/atls_config.go @@ -30,6 +30,25 @@ const ( const AvailableAttestationTypes string = "none, azure-tdx, dcap-tdx" +// DetectAttestationType determines the attestation type based on environment +func DetectAttestationType() AttestationType { + // Check for TDX device files - these indicate DCAP TDX + _, tdxErr1 := os.Stat("/dev/tdx-guest") + _, tdxErr2 := os.Stat("/dev/tdx_guest") + if tdxErr1 == nil || tdxErr2 == nil { + return AttestationDCAPTDX + } + + // Try Azure TDX attestation - if it works, we're in Azure TDX + issuer := azure_tdx.NewIssuer(nil) // nil logger for detection + _, err := issuer.Issue(context.Background(), []byte("test"), []byte("test")) + if err == nil { + return AttestationAzureTDX + } + + return AttestationNone +} + func ParseAttestationType(attestationType string) (AttestationType, error) { switch attestationType { case string(AttestationNone): From 80df1e47e61f0eb4cdb7bec143e8f93350f2c553 Mon Sep 17 00:00:00 2001 From: fnerdman Date: Tue, 28 Jan 2025 15:18:57 +0100 Subject: [PATCH 2/2] feat: fetches validator attestation type from measurments, removes flag --- cmd/attested-get/main.go | 2 +- cmd/proxy-client/main.go | 33 ++++++----------------- cmd/proxy-server/main.go | 28 ++++---------------- proxy/atls_config.go | 57 +++++++++++++++++++++++++++------------- 4 files changed, 53 insertions(+), 67 deletions(-) diff --git a/cmd/attested-get/main.go b/cmd/attested-get/main.go index 4521587..2897c71 100644 --- a/cmd/attested-get/main.go +++ b/cmd/attested-get/main.go @@ -109,7 +109,7 @@ func runClient(cCtx *cli.Context) (err error) { } // Create validators based on the attestation type - attestationType, err := proxy.ParseAttestationType(attestationTypeStr) + attestationType, err := proxy.ParseAttestationType(log, attestationTypeStr) if err != nil { log.With("attestation-type", attestationType).Error("invalid attestation-type passed, see --help") return err diff --git a/cmd/proxy-client/main.go b/cmd/proxy-client/main.go index dfe4383..fcefa28 100644 --- a/cmd/proxy-client/main.go +++ b/cmd/proxy-client/main.go @@ -25,11 +25,6 @@ var flags []cli.Flag = []cli.Flag{ Value: "https://localhost:80", Usage: "address to proxy requests to", }, - &cli.StringFlag{ - Name: "server-attestation-type", - Value: string(proxy.AttestationAzureTDX), - Usage: "type of attestation to expect and verify (" + proxy.AvailableAttestationTypes + ")", - }, &cli.StringFlag{ Name: "server-measurements", Usage: "optional path to JSON measurements enforced on the server", @@ -45,7 +40,7 @@ var flags []cli.Flag = []cli.Flag{ }, &cli.StringFlag{ Name: "client-attestation-type", - Value: "", + Value: "auto", Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + "). If not set, automatically detected.", }, &cli.BoolFlag{ @@ -96,27 +91,15 @@ func runClient(cCtx *cli.Context) error { Version: common.Version, }) - if cCtx.String("server-attestation-type") != "none" && verifyTLS { - log.Error("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)") - return errors.New("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)") + if serverMeasurements != "" && verifyTLS { + log.Error("invalid combination of --verify-tls and --server-measurements passed (cannot add server measurements and verify default TLS at the same time)") + return errors.New("invalid combination of --verify-tls and --server-measurements passed (cannot add server measurements and verify default TLS at the same time)") } // Auto-detect client attestation type if not specified - clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type")) - if err != nil { - // If parsing fails and no type was specified, use auto-detection - if cCtx.String("client-attestation-type") == "" { - clientAttestationType = proxy.DetectAttestationType() - log.With("detected_attestation", clientAttestationType).Info("Auto-detected client attestation type") - } else { - log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help") - return err - } - } - - serverAttestationType, err := proxy.ParseAttestationType(cCtx.String("server-attestation-type")) + clientAttestationType, err := proxy.ParseAttestationType(log, cCtx.String("client-attestation-type")) if err != nil { - log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help") + log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help") return err } @@ -126,9 +109,9 @@ func runClient(cCtx *cli.Context) error { return err } - validators, err := proxy.CreateAttestationValidators(log, serverAttestationType, serverMeasurements) + validators, err := proxy.CreateAttestationValidatorsFromFile(log, serverMeasurements) if err != nil { - log.Error("could not create attestation validators", "err", err) + log.Error("could not create attestation validators from file", "err", err) return err } diff --git a/cmd/proxy-server/main.go b/cmd/proxy-server/main.go index 20f0ca0..dca97cf 100644 --- a/cmd/proxy-server/main.go +++ b/cmd/proxy-server/main.go @@ -40,7 +40,7 @@ var flags []cli.Flag = []cli.Flag{ &cli.StringFlag{ Name: "server-attestation-type", EnvVars: []string{"SERVER_ATTESTATION_TYPE"}, - Value: "", + Value: "auto", Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + "). If not set, automatically detected.", }, &cli.StringFlag{ @@ -53,12 +53,6 @@ var flags []cli.Flag = []cli.Flag{ EnvVars: []string{"TLS_PRIVATE_KEY_PATH"}, Usage: "Path to private key file for the certificate. Only valid with --tls-certificate-path", }, - &cli.StringFlag{ - Name: "client-attestation-type", - EnvVars: []string{"CLIENT_ATTESTATION_TYPE"}, - Value: string(proxy.AttestationNone), - Usage: "type of attestation to expect and verify (" + proxy.AvailableAttestationTypes + ")", - }, &cli.StringFlag{ Name: "client-measurements", EnvVars: []string{"CLIENT_MEASUREMENTS"}, @@ -133,27 +127,15 @@ func runServer(cCtx *cli.Context) error { } // Auto-detect server attestation type if not specified - serverAttestationType, err := proxy.ParseAttestationType(cCtx.String("server-attestation-type")) - if err != nil { - // If parsing fails and no type was specified, use auto-detection - if cCtx.String("server-attestation-type") == "" { - serverAttestationType = proxy.DetectAttestationType() - log.With("detected_attestation", serverAttestationType).Info("Auto-detected server attestation type") - } else { - log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help") - return err - } - } - - clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type")) + serverAttestationType, err := proxy.ParseAttestationType(log, cCtx.String("server-attestation-type")) if err != nil { - log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help") + log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help") return err } - validators, err := proxy.CreateAttestationValidators(log, clientAttestationType, clientMeasurements) + validators, err := proxy.CreateAttestationValidatorsFromFile(log, clientMeasurements) if err != nil { - log.Error("could not create attestation validators", "err", err) + log.Error("could not create attestation validators from file", "err", err) return err } diff --git a/proxy/atls_config.go b/proxy/atls_config.go index d44ca37..67e8b40 100644 --- a/proxy/atls_config.go +++ b/proxy/atls_config.go @@ -23,12 +23,13 @@ import ( type AttestationType string const ( + AttestationAuto AttestationType = "auto" AttestationNone AttestationType = "none" AttestationAzureTDX AttestationType = "azure-tdx" AttestationDCAPTDX AttestationType = "dcap-tdx" ) -const AvailableAttestationTypes string = "none, azure-tdx, dcap-tdx" +const AvailableAttestationTypes string = "auto, none, azure-tdx, dcap-tdx" // DetectAttestationType determines the attestation type based on environment func DetectAttestationType() AttestationType { @@ -49,8 +50,12 @@ func DetectAttestationType() AttestationType { return AttestationNone } -func ParseAttestationType(attestationType string) (AttestationType, error) { +func ParseAttestationType(log *slog.Logger, attestationType string) (AttestationType, error) { switch attestationType { + case string(AttestationAuto): + detectedType := DetectAttestationType() + log.With("detected_attestation", detectedType).Info("Auto-detected attestation type") + return detectedType, nil case string(AttestationNone): return AttestationNone, nil case string(AttestationAzureTDX): @@ -75,8 +80,8 @@ func CreateAttestationIssuer(log *slog.Logger, attestationType AttestationType) } } -func CreateAttestationValidators(log *slog.Logger, attestationType AttestationType, jsonMeasurementsPath string) ([]atls.Validator, error) { - if attestationType == AttestationNone { +func CreateAttestationValidatorsFromFile(log *slog.Logger, jsonMeasurementsPath string) ([]atls.Validator, error) { + if jsonMeasurementsPath == "" { return nil, nil } @@ -91,26 +96,42 @@ func CreateAttestationValidators(log *slog.Logger, attestationType AttestationTy return nil, err } - switch attestationType { - case AttestationAzureTDX: - validators := []atls.Validator{} - for _, measurement := range parsedMeasurements { + // Group validators by attestation type + validatorsByType := make(map[AttestationType][]atls.Validator) + + for _, measurement := range parsedMeasurements { + attestationType, err := ParseAttestationType(log, measurement.AttestationType) + if err != nil { + return nil, fmt.Errorf("invalid attestation type %s in measurements file", measurement.AttestationType) + } + + switch attestationType { + case AttestationAzureTDX: attConfig := config.DefaultForAzureTDX() attConfig.SetMeasurements(measurement.Measurements) - validators = append(validators, azure_tdx.NewValidator(attConfig, AttestationLogger{Log: log})) - } - return []atls.Validator{NewMultiValidator(validators)}, nil - case AttestationDCAPTDX: - validators := []atls.Validator{} - for _, measurement := range parsedMeasurements { + validatorsByType[attestationType] = append( + validatorsByType[attestationType], + azure_tdx.NewValidator(attConfig, AttestationLogger{Log: log}), + ) + case AttestationDCAPTDX: attConfig := &config.QEMUTDX{Measurements: measurements.DefaultsFor(cloudprovider.QEMU, variant.QEMUTDX{})} attConfig.SetMeasurements(measurement.Measurements) - validators = append(validators, dcap_tdx.NewValidator(attConfig, AttestationLogger{Log: log})) + validatorsByType[attestationType] = append( + validatorsByType[attestationType], + dcap_tdx.NewValidator(attConfig, AttestationLogger{Log: log}), + ) + default: + return nil, fmt.Errorf("unsupported attestation type %s in measurements file", measurement.AttestationType) } - return []atls.Validator{NewMultiValidator(validators)}, nil - default: - return nil, errors.New("invalid attestation-type passed in") } + + // Create a MultiValidator for each attestation type + var validators []atls.Validator + for _, typeValidators := range validatorsByType { + validators = append(validators, NewMultiValidator(typeValidators)) + } + + return validators, nil } func ExtractMeasurementsFromExtension(ext *pkix.Extension, v variant.Variant) (map[uint32][]byte, error) {