@@ -23,12 +23,13 @@ import (
23
23
type AttestationType string
24
24
25
25
const (
26
+ AttestationAuto AttestationType = "auto"
26
27
AttestationNone AttestationType = "none"
27
28
AttestationAzureTDX AttestationType = "azure-tdx"
28
29
AttestationDCAPTDX AttestationType = "dcap-tdx"
29
30
)
30
31
31
- const AvailableAttestationTypes string = "none, azure-tdx, dcap-tdx"
32
+ const AvailableAttestationTypes string = "auto, none, azure-tdx, dcap-tdx"
32
33
33
34
// DetectAttestationType determines the attestation type based on environment
34
35
func DetectAttestationType () AttestationType {
@@ -49,8 +50,12 @@ func DetectAttestationType() AttestationType {
49
50
return AttestationNone
50
51
}
51
52
52
- func ParseAttestationType (attestationType string ) (AttestationType , error ) {
53
+ func ParseAttestationType (log * slog. Logger , attestationType string ) (AttestationType , error ) {
53
54
switch attestationType {
55
+ case string (AttestationAuto ):
56
+ detectedType := DetectAttestationType ()
57
+ log .With ("detected_attestation" , detectedType ).Info ("Auto-detected attestation type" )
58
+ return detectedType , nil
54
59
case string (AttestationNone ):
55
60
return AttestationNone , nil
56
61
case string (AttestationAzureTDX ):
@@ -75,8 +80,8 @@ func CreateAttestationIssuer(log *slog.Logger, attestationType AttestationType)
75
80
}
76
81
}
77
82
78
- func CreateAttestationValidators (log * slog.Logger , attestationType AttestationType , jsonMeasurementsPath string ) ([]atls.Validator , error ) {
79
- if attestationType == AttestationNone {
83
+ func CreateAttestationValidatorsFromFile (log * slog.Logger , jsonMeasurementsPath string ) ([]atls.Validator , error ) {
84
+ if jsonMeasurementsPath == "" {
80
85
return nil , nil
81
86
}
82
87
@@ -91,26 +96,42 @@ func CreateAttestationValidators(log *slog.Logger, attestationType AttestationTy
91
96
return nil , err
92
97
}
93
98
94
- switch attestationType {
95
- case AttestationAzureTDX :
96
- validators := []atls.Validator {}
97
- for _ , measurement := range parsedMeasurements {
99
+ // Group validators by attestation type
100
+ validatorsByType := make (map [AttestationType ][]atls.Validator )
101
+
102
+ for _ , measurement := range parsedMeasurements {
103
+ attestationType , err := ParseAttestationType (log , measurement .AttestationType )
104
+ if err != nil {
105
+ return nil , fmt .Errorf ("invalid attestation type %s in measurements file" , measurement .AttestationType )
106
+ }
107
+
108
+ switch attestationType {
109
+ case AttestationAzureTDX :
98
110
attConfig := config .DefaultForAzureTDX ()
99
111
attConfig .SetMeasurements (measurement .Measurements )
100
- validators = append (validators , azure_tdx .NewValidator (attConfig , AttestationLogger {Log : log }))
101
- }
102
- return []atls.Validator {NewMultiValidator (validators )}, nil
103
- case AttestationDCAPTDX :
104
- validators := []atls.Validator {}
105
- for _ , measurement := range parsedMeasurements {
112
+ validatorsByType [attestationType ] = append (
113
+ validatorsByType [attestationType ],
114
+ azure_tdx .NewValidator (attConfig , AttestationLogger {Log : log }),
115
+ )
116
+ case AttestationDCAPTDX :
106
117
attConfig := & config.QEMUTDX {Measurements : measurements .DefaultsFor (cloudprovider .QEMU , variant.QEMUTDX {})}
107
118
attConfig .SetMeasurements (measurement .Measurements )
108
- validators = append (validators , dcap_tdx .NewValidator (attConfig , AttestationLogger {Log : log }))
119
+ validatorsByType [attestationType ] = append (
120
+ validatorsByType [attestationType ],
121
+ dcap_tdx .NewValidator (attConfig , AttestationLogger {Log : log }),
122
+ )
123
+ default :
124
+ return nil , fmt .Errorf ("unsupported attestation type %s in measurements file" , measurement .AttestationType )
109
125
}
110
- return []atls.Validator {NewMultiValidator (validators )}, nil
111
- default :
112
- return nil , errors .New ("invalid attestation-type passed in" )
113
126
}
127
+
128
+ // Create a MultiValidator for each attestation type
129
+ var validators []atls.Validator
130
+ for _ , typeValidators := range validatorsByType {
131
+ validators = append (validators , NewMultiValidator (typeValidators ))
132
+ }
133
+
134
+ return validators , nil
114
135
}
115
136
116
137
func ExtractMeasurementsFromExtension (ext * pkix.Extension , v variant.Variant ) (map [uint32 ][]byte , error ) {
0 commit comments