Skip to content

Commit 80df1e4

Browse files
author
fnerdman
committed
feat: fetches validator attestation type from measurments, removes flag
1 parent e36bf5b commit 80df1e4

File tree

4 files changed

+53
-67
lines changed

4 files changed

+53
-67
lines changed

cmd/attested-get/main.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func runClient(cCtx *cli.Context) (err error) {
109109
}
110110

111111
// Create validators based on the attestation type
112-
attestationType, err := proxy.ParseAttestationType(attestationTypeStr)
112+
attestationType, err := proxy.ParseAttestationType(log, attestationTypeStr)
113113
if err != nil {
114114
log.With("attestation-type", attestationType).Error("invalid attestation-type passed, see --help")
115115
return err

cmd/proxy-client/main.go

Lines changed: 8 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,6 @@ var flags []cli.Flag = []cli.Flag{
2525
Value: "https://localhost:80",
2626
Usage: "address to proxy requests to",
2727
},
28-
&cli.StringFlag{
29-
Name: "server-attestation-type",
30-
Value: string(proxy.AttestationAzureTDX),
31-
Usage: "type of attestation to expect and verify (" + proxy.AvailableAttestationTypes + ")",
32-
},
3328
&cli.StringFlag{
3429
Name: "server-measurements",
3530
Usage: "optional path to JSON measurements enforced on the server",
@@ -45,7 +40,7 @@ var flags []cli.Flag = []cli.Flag{
4540
},
4641
&cli.StringFlag{
4742
Name: "client-attestation-type",
48-
Value: "",
43+
Value: "auto",
4944
Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + "). If not set, automatically detected.",
5045
},
5146
&cli.BoolFlag{
@@ -96,27 +91,15 @@ func runClient(cCtx *cli.Context) error {
9691
Version: common.Version,
9792
})
9893

99-
if cCtx.String("server-attestation-type") != "none" && verifyTLS {
100-
log.Error("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)")
101-
return errors.New("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)")
94+
if serverMeasurements != "" && verifyTLS {
95+
log.Error("invalid combination of --verify-tls and --server-measurements passed (cannot add server measurements and verify default TLS at the same time)")
96+
return errors.New("invalid combination of --verify-tls and --server-measurements passed (cannot add server measurements and verify default TLS at the same time)")
10297
}
10398

10499
// Auto-detect client attestation type if not specified
105-
clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type"))
106-
if err != nil {
107-
// If parsing fails and no type was specified, use auto-detection
108-
if cCtx.String("client-attestation-type") == "" {
109-
clientAttestationType = proxy.DetectAttestationType()
110-
log.With("detected_attestation", clientAttestationType).Info("Auto-detected client attestation type")
111-
} else {
112-
log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help")
113-
return err
114-
}
115-
}
116-
117-
serverAttestationType, err := proxy.ParseAttestationType(cCtx.String("server-attestation-type"))
100+
clientAttestationType, err := proxy.ParseAttestationType(log, cCtx.String("client-attestation-type"))
118101
if err != nil {
119-
log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help")
102+
log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help")
120103
return err
121104
}
122105

@@ -126,9 +109,9 @@ func runClient(cCtx *cli.Context) error {
126109
return err
127110
}
128111

129-
validators, err := proxy.CreateAttestationValidators(log, serverAttestationType, serverMeasurements)
112+
validators, err := proxy.CreateAttestationValidatorsFromFile(log, serverMeasurements)
130113
if err != nil {
131-
log.Error("could not create attestation validators", "err", err)
114+
log.Error("could not create attestation validators from file", "err", err)
132115
return err
133116
}
134117

cmd/proxy-server/main.go

Lines changed: 5 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ var flags []cli.Flag = []cli.Flag{
4040
&cli.StringFlag{
4141
Name: "server-attestation-type",
4242
EnvVars: []string{"SERVER_ATTESTATION_TYPE"},
43-
Value: "",
43+
Value: "auto",
4444
Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + "). If not set, automatically detected.",
4545
},
4646
&cli.StringFlag{
@@ -53,12 +53,6 @@ var flags []cli.Flag = []cli.Flag{
5353
EnvVars: []string{"TLS_PRIVATE_KEY_PATH"},
5454
Usage: "Path to private key file for the certificate. Only valid with --tls-certificate-path",
5555
},
56-
&cli.StringFlag{
57-
Name: "client-attestation-type",
58-
EnvVars: []string{"CLIENT_ATTESTATION_TYPE"},
59-
Value: string(proxy.AttestationNone),
60-
Usage: "type of attestation to expect and verify (" + proxy.AvailableAttestationTypes + ")",
61-
},
6256
&cli.StringFlag{
6357
Name: "client-measurements",
6458
EnvVars: []string{"CLIENT_MEASUREMENTS"},
@@ -133,27 +127,15 @@ func runServer(cCtx *cli.Context) error {
133127
}
134128

135129
// Auto-detect server attestation type if not specified
136-
serverAttestationType, err := proxy.ParseAttestationType(cCtx.String("server-attestation-type"))
137-
if err != nil {
138-
// If parsing fails and no type was specified, use auto-detection
139-
if cCtx.String("server-attestation-type") == "" {
140-
serverAttestationType = proxy.DetectAttestationType()
141-
log.With("detected_attestation", serverAttestationType).Info("Auto-detected server attestation type")
142-
} else {
143-
log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help")
144-
return err
145-
}
146-
}
147-
148-
clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type"))
130+
serverAttestationType, err := proxy.ParseAttestationType(log, cCtx.String("server-attestation-type"))
149131
if err != nil {
150-
log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help")
132+
log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help")
151133
return err
152134
}
153135

154-
validators, err := proxy.CreateAttestationValidators(log, clientAttestationType, clientMeasurements)
136+
validators, err := proxy.CreateAttestationValidatorsFromFile(log, clientMeasurements)
155137
if err != nil {
156-
log.Error("could not create attestation validators", "err", err)
138+
log.Error("could not create attestation validators from file", "err", err)
157139
return err
158140
}
159141

proxy/atls_config.go

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -23,12 +23,13 @@ import (
2323
type AttestationType string
2424

2525
const (
26+
AttestationAuto AttestationType = "auto"
2627
AttestationNone AttestationType = "none"
2728
AttestationAzureTDX AttestationType = "azure-tdx"
2829
AttestationDCAPTDX AttestationType = "dcap-tdx"
2930
)
3031

31-
const AvailableAttestationTypes string = "none, azure-tdx, dcap-tdx"
32+
const AvailableAttestationTypes string = "auto, none, azure-tdx, dcap-tdx"
3233

3334
// DetectAttestationType determines the attestation type based on environment
3435
func DetectAttestationType() AttestationType {
@@ -49,8 +50,12 @@ func DetectAttestationType() AttestationType {
4950
return AttestationNone
5051
}
5152

52-
func ParseAttestationType(attestationType string) (AttestationType, error) {
53+
func ParseAttestationType(log *slog.Logger, attestationType string) (AttestationType, error) {
5354
switch attestationType {
55+
case string(AttestationAuto):
56+
detectedType := DetectAttestationType()
57+
log.With("detected_attestation", detectedType).Info("Auto-detected attestation type")
58+
return detectedType, nil
5459
case string(AttestationNone):
5560
return AttestationNone, nil
5661
case string(AttestationAzureTDX):
@@ -75,8 +80,8 @@ func CreateAttestationIssuer(log *slog.Logger, attestationType AttestationType)
7580
}
7681
}
7782

78-
func CreateAttestationValidators(log *slog.Logger, attestationType AttestationType, jsonMeasurementsPath string) ([]atls.Validator, error) {
79-
if attestationType == AttestationNone {
83+
func CreateAttestationValidatorsFromFile(log *slog.Logger, jsonMeasurementsPath string) ([]atls.Validator, error) {
84+
if jsonMeasurementsPath == "" {
8085
return nil, nil
8186
}
8287

@@ -91,26 +96,42 @@ func CreateAttestationValidators(log *slog.Logger, attestationType AttestationTy
9196
return nil, err
9297
}
9398

94-
switch attestationType {
95-
case AttestationAzureTDX:
96-
validators := []atls.Validator{}
97-
for _, measurement := range parsedMeasurements {
99+
// Group validators by attestation type
100+
validatorsByType := make(map[AttestationType][]atls.Validator)
101+
102+
for _, measurement := range parsedMeasurements {
103+
attestationType, err := ParseAttestationType(log, measurement.AttestationType)
104+
if err != nil {
105+
return nil, fmt.Errorf("invalid attestation type %s in measurements file", measurement.AttestationType)
106+
}
107+
108+
switch attestationType {
109+
case AttestationAzureTDX:
98110
attConfig := config.DefaultForAzureTDX()
99111
attConfig.SetMeasurements(measurement.Measurements)
100-
validators = append(validators, azure_tdx.NewValidator(attConfig, AttestationLogger{Log: log}))
101-
}
102-
return []atls.Validator{NewMultiValidator(validators)}, nil
103-
case AttestationDCAPTDX:
104-
validators := []atls.Validator{}
105-
for _, measurement := range parsedMeasurements {
112+
validatorsByType[attestationType] = append(
113+
validatorsByType[attestationType],
114+
azure_tdx.NewValidator(attConfig, AttestationLogger{Log: log}),
115+
)
116+
case AttestationDCAPTDX:
106117
attConfig := &config.QEMUTDX{Measurements: measurements.DefaultsFor(cloudprovider.QEMU, variant.QEMUTDX{})}
107118
attConfig.SetMeasurements(measurement.Measurements)
108-
validators = append(validators, dcap_tdx.NewValidator(attConfig, AttestationLogger{Log: log}))
119+
validatorsByType[attestationType] = append(
120+
validatorsByType[attestationType],
121+
dcap_tdx.NewValidator(attConfig, AttestationLogger{Log: log}),
122+
)
123+
default:
124+
return nil, fmt.Errorf("unsupported attestation type %s in measurements file", measurement.AttestationType)
109125
}
110-
return []atls.Validator{NewMultiValidator(validators)}, nil
111-
default:
112-
return nil, errors.New("invalid attestation-type passed in")
113126
}
127+
128+
// Create a MultiValidator for each attestation type
129+
var validators []atls.Validator
130+
for _, typeValidators := range validatorsByType {
131+
validators = append(validators, NewMultiValidator(typeValidators))
132+
}
133+
134+
return validators, nil
114135
}
115136

116137
func ExtractMeasurementsFromExtension(ext *pkix.Extension, v variant.Variant) (map[uint32][]byte, error) {

0 commit comments

Comments
 (0)