Skip to content

Commit ca38dda

Browse files
fnerdmanFrieder Paapefnerdman
authored
feat: auto handle attestation type on the validator side (#33)
* feat: auto handle attestation type on the validator side * fix: indentation * chore: readds but deprecates validator attestation type flag --------- Co-authored-by: Frieder Paape <[email protected]> Co-authored-by: fnerdman <[email protected]>
1 parent a2abde9 commit ca38dda

File tree

3 files changed

+50
-41
lines changed

3 files changed

+50
-41
lines changed

cmd/proxy-client/main.go

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ var flags []cli.Flag = []cli.Flag{
2727
},
2828
&cli.StringFlag{
2929
Name: "server-attestation-type",
30-
Value: string(proxy.AttestationAzureTDX),
31-
Usage: "type of attestation to expect and verify (" + proxy.AvailableAttestationTypes + ")",
30+
Usage: "Deprecated and not used. Server attestation types are set via the measurements file.",
3231
},
3332
&cli.StringFlag{
3433
Name: "server-measurements",
@@ -37,7 +36,7 @@ var flags []cli.Flag = []cli.Flag{
3736
&cli.BoolFlag{
3837
Name: "verify-tls",
3938
Value: false,
40-
Usage: "verify server's TLS certificate instead of server's attestation. Only valid for server-attestation-type=none.",
39+
Usage: "verify server's TLS certificate instead of server's attestation. Only valid when not specifying measurements.",
4140
},
4241
&cli.StringFlag{
4342
Name: "tls-ca-certificate",
@@ -96,20 +95,18 @@ func runClient(cCtx *cli.Context) error {
9695
Version: common.Version,
9796
})
9897

99-
if cCtx.String("server-attestation-type") != "none" && verifyTLS {
100-
log.Error("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)")
101-
return errors.New("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)")
98+
if cCtx.String("server-attestation-type") != "" {
99+
log.Warn("DEPRECATED: --server-attestation-type is deprecated and will be removed in a future version")
102100
}
103101

104-
clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type"))
105-
if err != nil {
106-
log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help")
107-
return err
102+
if serverMeasurements != "" && verifyTLS {
103+
log.Error("invalid combination of --verify-tls and --server-measurements passed (cannot add server measurements and verify default TLS at the same time)")
104+
return errors.New("invalid combination of --verify-tls and --server-measurements passed (cannot add server measurements and verify default TLS at the same time)")
108105
}
109106

110-
serverAttestationType, err := proxy.ParseAttestationType(cCtx.String("server-attestation-type"))
107+
clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type"))
111108
if err != nil {
112-
log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help")
109+
log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help")
113110
return err
114111
}
115112

@@ -119,9 +116,9 @@ func runClient(cCtx *cli.Context) error {
119116
return err
120117
}
121118

122-
validators, err := proxy.CreateAttestationValidators(log, serverAttestationType, serverMeasurements)
119+
validators, err := proxy.CreateAttestationValidatorsFromFile(log, serverMeasurements)
123120
if err != nil {
124-
log.Error("could not create attestation validators", "err", err)
121+
log.Error("could not create attestation validators from file", "err", err)
125122
return err
126123
}
127124

cmd/proxy-server/main.go

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ var flags []cli.Flag = []cli.Flag{
5555
},
5656
&cli.StringFlag{
5757
Name: "client-attestation-type",
58-
EnvVars: []string{"CLIENT_ATTESTATION_TYPE"},
59-
Value: string(proxy.AttestationNone),
60-
Usage: "type of attestation to expect and verify (" + proxy.AvailableAttestationTypes + ")",
58+
Usage: "Deprecated and not used. Client attestation types are set via the measurements file.",
6159
},
6260
&cli.StringFlag{
6361
Name: "client-measurements",
@@ -123,6 +121,10 @@ func runServer(cCtx *cli.Context) error {
123121
Version: common.Version,
124122
})
125123

124+
if cCtx.String("client-attestation-type") != "" {
125+
log.Warn("DEPRECATED: --client-attestation-type is deprecated and will be removed in a future version")
126+
}
127+
126128
useRegularTLS := certFile != "" || keyFile != ""
127129
if serverAttestationTypeFlag != "none" && useRegularTLS {
128130
return errors.New("invalid combination of --tls-certificate-path, --tls-private-key-path and --server-attestation-type flags passed (only 'none' is allowed)")
@@ -138,15 +140,9 @@ func runServer(cCtx *cli.Context) error {
138140
return err
139141
}
140142

141-
clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type"))
142-
if err != nil {
143-
log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help")
144-
return err
145-
}
146-
147-
validators, err := proxy.CreateAttestationValidators(log, clientAttestationType, clientMeasurements)
143+
validators, err := proxy.CreateAttestationValidatorsFromFile(log, clientMeasurements)
148144
if err != nil {
149-
log.Error("could not create attestation validators", "err", err)
145+
log.Error("could not create attestation validators from file", "err", err)
150146
return err
151147
}
152148

proxy/atls_config.go

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ func CreateAttestationIssuer(log *slog.Logger, attestationType AttestationType)
8383
}
8484
}
8585

86-
func CreateAttestationValidators(log *slog.Logger, attestationType AttestationType, jsonMeasurementsPath string) ([]atls.Validator, error) {
87-
if attestationType == AttestationNone {
86+
func CreateAttestationValidatorsFromFile(log *slog.Logger, jsonMeasurementsPath string) ([]atls.Validator, error) {
87+
if jsonMeasurementsPath == "" {
8888
return nil, nil
8989
}
9090

@@ -99,26 +99,42 @@ func CreateAttestationValidators(log *slog.Logger, attestationType AttestationTy
9999
return nil, err
100100
}
101101

102-
switch attestationType {
103-
case AttestationAzureTDX:
104-
validators := []atls.Validator{}
105-
for _, measurement := range parsedMeasurements {
102+
// Group validators by attestation type
103+
validatorsByType := make(map[AttestationType][]atls.Validator)
104+
105+
for _, measurement := range parsedMeasurements {
106+
attestationType, err := ParseAttestationType(measurement.AttestationType)
107+
if err != nil {
108+
return nil, fmt.Errorf("invalid attestation type %s in measurements file", measurement.AttestationType)
109+
}
110+
111+
switch attestationType {
112+
case AttestationAzureTDX:
106113
attConfig := config.DefaultForAzureTDX()
107114
attConfig.SetMeasurements(measurement.Measurements)
108-
validators = append(validators, azure_tdx.NewValidator(attConfig, AttestationLogger{Log: log}))
109-
}
110-
return []atls.Validator{NewMultiValidator(validators)}, nil
111-
case AttestationDCAPTDX:
112-
validators := []atls.Validator{}
113-
for _, measurement := range parsedMeasurements {
115+
validatorsByType[attestationType] = append(
116+
validatorsByType[attestationType],
117+
azure_tdx.NewValidator(attConfig, AttestationLogger{Log: log}),
118+
)
119+
case AttestationDCAPTDX:
114120
attConfig := &config.QEMUTDX{Measurements: measurements.DefaultsFor(cloudprovider.QEMU, variant.QEMUTDX{})}
115121
attConfig.SetMeasurements(measurement.Measurements)
116-
validators = append(validators, dcap_tdx.NewValidator(attConfig, AttestationLogger{Log: log}))
122+
validatorsByType[attestationType] = append(
123+
validatorsByType[attestationType],
124+
dcap_tdx.NewValidator(attConfig, AttestationLogger{Log: log}),
125+
)
126+
default:
127+
return nil, fmt.Errorf("unsupported attestation type %s in measurements file", measurement.AttestationType)
117128
}
118-
return []atls.Validator{NewMultiValidator(validators)}, nil
119-
default:
120-
return nil, errors.New("invalid attestation-type passed in")
121129
}
130+
131+
// Create a MultiValidator for each attestation type
132+
var validators []atls.Validator
133+
for _, typeValidators := range validatorsByType {
134+
validators = append(validators, NewMultiValidator(typeValidators))
135+
}
136+
137+
return validators, nil
122138
}
123139

124140
func ExtractMeasurementsFromExtension(ext *pkix.Extension, v variant.Variant) (map[uint32][]byte, error) {

0 commit comments

Comments
 (0)