Skip to content

Commit de9c46c

Browse files
committed
Merge branch 'main' into feature/branch_file_selector
2 parents 288c12a + 5e75aa0 commit de9c46c

15 files changed

Lines changed: 1321 additions & 35 deletions

File tree

backend/internal/agent/agent.go

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"github.com/OrcaCD/orca-cd/internal/agent/docker"
1515
messages "github.com/OrcaCD/orca-cd/internal/proto"
1616
"github.com/OrcaCD/orca-cd/internal/version"
17+
"github.com/golang-jwt/jwt/v5"
1718
"github.com/gorilla/websocket"
1819
"github.com/rs/zerolog"
1920
"google.golang.org/protobuf/proto"
@@ -24,6 +25,7 @@ type Config struct {
2425
LogJSON bool
2526
HubUrl string
2627
AuthToken string
28+
AgentID string
2729
}
2830

2931
func DefaultConfig() (Config, error) {
@@ -41,6 +43,11 @@ func DefaultConfig() (Config, error) {
4143
return Config{}, errors.New("AUTH_TOKEN is required")
4244
}
4345

46+
agentID, err := parseAgentID(authToken)
47+
if err != nil {
48+
return Config{}, fmt.Errorf("AUTH_TOKEN: %w", err)
49+
}
50+
4451
logLevel, err := zerolog.ParseLevel(logLevelStr)
4552
if err != nil || logLevelStr == "" {
4653
logLevel = zerolog.InfoLevel
@@ -53,9 +60,23 @@ func DefaultConfig() (Config, error) {
5360
LogJSON: logJSON,
5461
HubUrl: hubUrl,
5562
AuthToken: authToken,
63+
AgentID: agentID,
5664
}, nil
5765
}
5866

67+
func parseAgentID(authToken string) (string, error) {
68+
p := jwt.NewParser()
69+
token, _, err := p.ParseUnverified(authToken, &jwt.RegisteredClaims{})
70+
if err != nil {
71+
return "", fmt.Errorf("could not parse token to extract agent ID: %w", err)
72+
}
73+
claims, ok := token.Claims.(*jwt.RegisteredClaims)
74+
if !ok || claims.Subject == "" {
75+
return "", errors.New("token is missing subject claim")
76+
}
77+
return claims.Subject, nil
78+
}
79+
5980
func newLogger(logJSON bool) zerolog.Logger {
6081
if logJSON {
6182
return zerolog.New(os.Stderr).With().Timestamp().Str("service", "agent").Logger()
@@ -120,6 +141,13 @@ func Run(cfg Config) error {
120141
return nil
121142
}
122143

144+
session, err := performHandshake(conn, cfg.AgentID)
145+
if err != nil {
146+
Log.Error().Err(err).Msg("handshake failed")
147+
_ = conn.Close()
148+
return fmt.Errorf("handshake: %w", err)
149+
}
150+
123151
for {
124152
_, data, readErr := conn.ReadMessage()
125153
if readErr != nil {
@@ -136,13 +164,19 @@ func Run(cfg Config) error {
136164
Log.Info().Msg("agent stopped")
137165
return nil
138166
}
167+
session, err = performHandshake(conn, cfg.AgentID)
168+
if err != nil {
169+
Log.Error().Err(err).Msg("handshake failed on reconnect")
170+
_ = conn.Close()
171+
return fmt.Errorf("handshake on reconnect: %w", err)
172+
}
139173
continue
140174
}
141175
msg := &messages.ServerMessage{}
142176
if err := proto.Unmarshal(data, msg); err != nil {
143177
Log.Error().Err(err).Msg("unmarshal error")
144178
continue
145179
}
146-
handleServerMessage(msg, conn)
180+
handleServerMessage(msg, conn, session)
147181
}
148182
}

backend/internal/agent/agent_config_test.go

Lines changed: 59 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,27 @@
11
package agent
22

33
import (
4+
"context"
5+
"encoding/base64"
46
"testing"
7+
"time"
58

9+
"github.com/gorilla/websocket"
610
"github.com/rs/zerolog"
711
)
812

13+
// testAgentToken is a minimal JWT-format string with a known subject.
14+
// The signature is intentionally fake; DefaultConfig only calls jwt.ParseUnverified,
15+
// which does not check the signature.
16+
var testAgentToken = func() string {
17+
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"EdDSA","typ":"JWT"}`))
18+
payload := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":"test-agent-id"}`))
19+
return header + "." + payload + ".fakesig"
20+
}()
21+
922
func TestDefaultConfig_Valid(t *testing.T) {
1023
t.Setenv("HUB_URL", "https://hub.example.com")
11-
t.Setenv("AUTH_TOKEN", "test-token")
24+
t.Setenv("AUTH_TOKEN", testAgentToken)
1225
t.Setenv("LOG_LEVEL", "debug")
1326
t.Setenv("LOG_JSON", "true")
1427

@@ -26,14 +39,17 @@ func TestDefaultConfig_Valid(t *testing.T) {
2639
if cfg.HubUrl != "wss://hub.example.com/api/v1/ws" {
2740
t.Errorf("HubUrl = %q, want %q", cfg.HubUrl, "wss://hub.example.com/api/v1/ws")
2841
}
29-
if cfg.AuthToken != "test-token" {
30-
t.Errorf("AuthToken = %q, want %q", cfg.AuthToken, "test-token")
42+
if cfg.AuthToken != testAgentToken {
43+
t.Errorf("AuthToken = %q, want %q", cfg.AuthToken, testAgentToken)
44+
}
45+
if cfg.AgentID != "test-agent-id" {
46+
t.Errorf("AgentID = %q, want %q", cfg.AgentID, "test-agent-id")
3147
}
3248
}
3349

3450
func TestDefaultConfig_Defaults(t *testing.T) {
3551
t.Setenv("HUB_URL", "https://hub.example.com")
36-
t.Setenv("AUTH_TOKEN", "test-token")
52+
t.Setenv("AUTH_TOKEN", testAgentToken)
3753
t.Setenv("LOG_LEVEL", "")
3854
t.Setenv("LOG_JSON", "")
3955

@@ -66,7 +82,7 @@ func TestDefaultConfig_LogLevels(t *testing.T) {
6682
for _, tt := range tests {
6783
t.Run(tt.input, func(t *testing.T) {
6884
t.Setenv("HUB_URL", "https://hub.example.com")
69-
t.Setenv("AUTH_TOKEN", "test-token")
85+
t.Setenv("AUTH_TOKEN", testAgentToken)
7086
t.Setenv("LOG_LEVEL", tt.input)
7187

7288
cfg, err := DefaultConfig()
@@ -95,7 +111,7 @@ func TestDefaultConfig_LogJSON(t *testing.T) {
95111
for _, tt := range tests {
96112
t.Run(tt.input, func(t *testing.T) {
97113
t.Setenv("HUB_URL", "https://hub.example.com")
98-
t.Setenv("AUTH_TOKEN", "test-token")
114+
t.Setenv("AUTH_TOKEN", testAgentToken)
99115
t.Setenv("LOG_JSON", tt.input)
100116

101117
cfg, err := DefaultConfig()
@@ -109,14 +125,50 @@ func TestDefaultConfig_LogJSON(t *testing.T) {
109125
}
110126
}
111127

128+
func TestParseAgentID_MissingSubject(t *testing.T) {
129+
// JWT with empty subject field.
130+
header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"EdDSA","typ":"JWT"}`))
131+
payload := base64.RawURLEncoding.EncodeToString([]byte(`{"sub":""}`))
132+
token := header + "." + payload + ".fakesig"
133+
134+
_, err := parseAgentID(token)
135+
if err == nil {
136+
t.Fatal("expected error for token with empty subject")
137+
}
138+
}
139+
140+
func TestParseAgentID_InvalidToken(t *testing.T) {
141+
_, err := parseAgentID("not.a.jwt")
142+
if err == nil {
143+
t.Fatal("expected error for malformed JWT")
144+
}
145+
}
146+
147+
func TestConnTracker_SetAndCancelled_CancelledCtx(t *testing.T) {
148+
srv := newTestServer(t, func(serverConn *websocket.Conn) {
149+
serverConn.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) //nolint:errcheck,gosec
150+
_, _, _ = serverConn.ReadMessage()
151+
})
152+
153+
clientConn := dialServer(t, srv)
154+
155+
ctx, cancel := context.WithCancel(context.Background())
156+
cancel() // pre-cancel so ctx.Err() != nil
157+
158+
var tracker connTracker
159+
if cancelled := tracker.setAndCancelled(ctx, clientConn); !cancelled {
160+
t.Error("expected setAndCancelled to return true for a pre-cancelled context")
161+
}
162+
}
163+
112164
func TestDefaultConfig_Errors(t *testing.T) {
113165
tests := []struct {
114166
name string
115167
hubURL string
116168
authToken string
117169
}{
118170
{"missing auth token", "https://hub.example.com", ""},
119-
{"invalid hub url", "not-a-url", "test-token"},
171+
{"invalid hub url", "not-a-url", testAgentToken},
120172
}
121173

122174
for _, tt := range tests {

backend/internal/agent/websocket.go

Lines changed: 101 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,38 +2,129 @@ package agent
22

33
import (
44
"context"
5+
"fmt"
56
"net/http"
67
"strings"
78
"time"
89

910
messages "github.com/OrcaCD/orca-cd/internal/proto"
11+
"github.com/OrcaCD/orca-cd/internal/shared/wscrypto"
1012
"github.com/gorilla/websocket"
1113
"google.golang.org/protobuf/proto"
1214
)
1315

14-
func handleServerMessage(msg *messages.ServerMessage, conn *websocket.Conn) {
16+
const handshakeTimeout = 15 * time.Second
17+
18+
func performHandshake(conn *websocket.Conn, agentID string) (*wscrypto.Session, error) {
19+
if err := conn.SetReadDeadline(time.Now().Add(handshakeTimeout)); err != nil {
20+
return nil, fmt.Errorf("set read deadline: %w", err)
21+
}
22+
23+
_, data, err := conn.ReadMessage()
24+
if err != nil {
25+
return nil, fmt.Errorf("read KeyExchangeInit: %w", err)
26+
}
27+
28+
serverMsg := &messages.ServerMessage{}
29+
if err := proto.Unmarshal(data, serverMsg); err != nil {
30+
return nil, fmt.Errorf("unmarshal KeyExchangeInit: %w", err)
31+
}
32+
33+
init, ok := serverMsg.Payload.(*messages.ServerMessage_KeyExchangeInit)
34+
if !ok {
35+
return nil, fmt.Errorf("expected KeyExchangeInit, got %T", serverMsg.Payload)
36+
}
37+
38+
mlkemCiphertext, agentX25519Pub, sessionKey, err := wscrypto.AgentHandshake(
39+
init.KeyExchangeInit.MlkemEncapsulationKey,
40+
init.KeyExchangeInit.X25519PublicKey,
41+
agentID,
42+
)
43+
if err != nil {
44+
return nil, fmt.Errorf("agent handshake: %w", err)
45+
}
46+
47+
resp := &messages.ClientMessage{
48+
Payload: &messages.ClientMessage_KeyExchangeResponse{
49+
KeyExchangeResponse: &messages.KeyExchangeResponse{
50+
MlkemCiphertext: mlkemCiphertext,
51+
AgentX25519PublicKey: agentX25519Pub,
52+
},
53+
},
54+
}
55+
respData, err := proto.Marshal(resp)
56+
if err != nil {
57+
return nil, fmt.Errorf("marshal KeyExchangeResponse: %w", err)
58+
}
59+
if err := conn.WriteMessage(websocket.BinaryMessage, respData); err != nil {
60+
return nil, fmt.Errorf("send KeyExchangeResponse: %w", err)
61+
}
62+
63+
// Clear the handshake deadline — the hub's read loop will set its own deadline.
64+
if err := conn.SetReadDeadline(time.Time{}); err != nil {
65+
return nil, fmt.Errorf("clear read deadline: %w", err)
66+
}
67+
68+
Log.Debug().Str("agent_id", agentID).Msg("WS handshake complete")
69+
return wscrypto.NewSession(sessionKey)
70+
}
71+
72+
func sendMessage(conn *websocket.Conn, session *wscrypto.Session, msg *messages.ClientMessage) error {
73+
outMsg := msg
74+
if !wscrypto.AllowedUnencrypted(msg) {
75+
env, err := session.Encrypt(msg)
76+
if err != nil {
77+
return fmt.Errorf("encrypt: %w", err)
78+
}
79+
outMsg = &messages.ClientMessage{
80+
Payload: &messages.ClientMessage_EncryptedPayload{
81+
EncryptedPayload: env,
82+
},
83+
}
84+
}
85+
data, err := proto.Marshal(outMsg)
86+
if err != nil {
87+
return fmt.Errorf("marshal: %w", err)
88+
}
89+
return conn.WriteMessage(websocket.BinaryMessage, data)
90+
}
91+
92+
func handleServerMessage(msg *messages.ServerMessage, conn *websocket.Conn, session *wscrypto.Session) {
93+
_, isEncrypted := msg.Payload.(*messages.ServerMessage_EncryptedPayload)
94+
if !isEncrypted && !wscrypto.AllowedUnencrypted(msg) {
95+
Log.Warn().Msgf("dropping unencrypted message of type %T", msg.Payload)
96+
return
97+
}
98+
99+
// Unwrap at most one layer of encryption to prevent a malicious server from
100+
// crafting a chain of nested EncryptedPayloads that causes unbounded recursion.
101+
if p, ok := msg.Payload.(*messages.ServerMessage_EncryptedPayload); ok {
102+
inner := &messages.ServerMessage{}
103+
if err := session.Decrypt(p.EncryptedPayload, inner); err != nil {
104+
Log.Error().Err(err).Msg("failed to decrypt message")
105+
return
106+
}
107+
if _, stillEncrypted := inner.Payload.(*messages.ServerMessage_EncryptedPayload); stillEncrypted {
108+
Log.Warn().Msg("dropping doubly-encrypted message")
109+
return
110+
}
111+
msg = inner
112+
}
113+
15114
switch p := msg.Payload.(type) {
16115
case *messages.ServerMessage_Ping:
17116
latency := time.Now().UnixMilli() - p.Ping.Timestamp
18117
Log.Debug().Msgf("ping received, latency: %dms", latency)
19118

20-
// Respond with Pong
21119
pong := &messages.ClientMessage{
22120
Payload: &messages.ClientMessage_Pong{
23121
Pong: &messages.PongResponse{
24122
Timestamp: time.Now().UnixMilli(),
25123
},
26124
},
27125
}
28-
data, err := proto.Marshal(pong)
29-
30-
if err != nil {
31-
Log.Error().Err(err).Msg("failed to marshal Pong response")
32-
return
33-
}
34-
if err := conn.WriteMessage(websocket.BinaryMessage, data); err != nil {
126+
if err := sendMessage(conn, session, pong); err != nil {
35127
Log.Error().Err(err).Msg("failed to send Pong response")
36-
return
37128
}
38129
default:
39130
Log.Warn().Msg("unknown message type received")

0 commit comments

Comments
 (0)