Skip to content

Commit e36bf5b

Browse files
author
fnerdman
committed
feat: auto detect tee env
1 parent 8e5c9a1 commit e36bf5b

File tree

3 files changed

+42
-9
lines changed

3 files changed

+42
-9
lines changed

cmd/proxy-client/main.go

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@ var flags []cli.Flag = []cli.Flag{
4545
},
4646
&cli.StringFlag{
4747
Name: "client-attestation-type",
48-
Value: string(proxy.AttestationNone),
49-
Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + ")",
48+
Value: "",
49+
Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + "). If not set, automatically detected.",
5050
},
5151
&cli.BoolFlag{
5252
Name: "log-json",
@@ -101,10 +101,17 @@ func runClient(cCtx *cli.Context) error {
101101
return errors.New("invalid combination of --verify-tls and --server-attestation-type passed (only 'none' is allowed)")
102102
}
103103

104+
// Auto-detect client attestation type if not specified
104105
clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type"))
105106
if err != nil {
106-
log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help")
107-
return err
107+
// If parsing fails and no type was specified, use auto-detection
108+
if cCtx.String("client-attestation-type") == "" {
109+
clientAttestationType = proxy.DetectAttestationType()
110+
log.With("detected_attestation", clientAttestationType).Info("Auto-detected client attestation type")
111+
} else {
112+
log.With("attestation-type", cCtx.String("client-attestation-type")).Error("invalid client-attestation-type passed, see --help")
113+
return err
114+
}
108115
}
109116

110117
serverAttestationType, err := proxy.ParseAttestationType(cCtx.String("server-attestation-type"))

cmd/proxy-server/main.go

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,8 @@ var flags []cli.Flag = []cli.Flag{
4040
&cli.StringFlag{
4141
Name: "server-attestation-type",
4242
EnvVars: []string{"SERVER_ATTESTATION_TYPE"},
43-
Value: string(proxy.AttestationAzureTDX),
44-
Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + ")",
43+
Value: "",
44+
Usage: "type of attestation to present (" + proxy.AvailableAttestationTypes + "). If not set, automatically detected.",
4545
},
4646
&cli.StringFlag{
4747
Name: "tls-certificate-path",
@@ -132,10 +132,17 @@ func runServer(cCtx *cli.Context) error {
132132
return errors.New("not all of --tls-certificate-path and --tls-private-key-path specified")
133133
}
134134

135-
serverAttestationType, err := proxy.ParseAttestationType(serverAttestationTypeFlag)
135+
// Auto-detect server attestation type if not specified
136+
serverAttestationType, err := proxy.ParseAttestationType(cCtx.String("server-attestation-type"))
136137
if err != nil {
137-
log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help")
138-
return err
138+
// If parsing fails and no type was specified, use auto-detection
139+
if cCtx.String("server-attestation-type") == "" {
140+
serverAttestationType = proxy.DetectAttestationType()
141+
log.With("detected_attestation", serverAttestationType).Info("Auto-detected server attestation type")
142+
} else {
143+
log.With("attestation-type", cCtx.String("server-attestation-type")).Error("invalid server-attestation-type passed, see --help")
144+
return err
145+
}
139146
}
140147

141148
clientAttestationType, err := proxy.ParseAttestationType(cCtx.String("client-attestation-type"))

proxy/atls_config.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,25 @@ const (
3030

3131
const AvailableAttestationTypes string = "none, azure-tdx, dcap-tdx"
3232

33+
// DetectAttestationType determines the attestation type based on environment
34+
func DetectAttestationType() AttestationType {
35+
// Check for TDX device files - these indicate DCAP TDX
36+
_, tdxErr1 := os.Stat("/dev/tdx-guest")
37+
_, tdxErr2 := os.Stat("/dev/tdx_guest")
38+
if tdxErr1 == nil || tdxErr2 == nil {
39+
return AttestationDCAPTDX
40+
}
41+
42+
// Try Azure TDX attestation - if it works, we're in Azure TDX
43+
issuer := azure_tdx.NewIssuer(nil) // nil logger for detection
44+
_, err := issuer.Issue(context.Background(), []byte("test"), []byte("test"))
45+
if err == nil {
46+
return AttestationAzureTDX
47+
}
48+
49+
return AttestationNone
50+
}
51+
3352
func ParseAttestationType(attestationType string) (AttestationType, error) {
3453
switch attestationType {
3554
case string(AttestationNone):

0 commit comments

Comments
 (0)