diff --git a/block_device.go b/block_device.go index 2099da8fa..4b663ea15 100644 --- a/block_device.go +++ b/block_device.go @@ -7,6 +7,7 @@ import ( "os" "time" + "github.com/jetkvm/kvm/internal/logging" "github.com/pojntfx/go-nbd/pkg/server" "github.com/rs/zerolog" ) @@ -16,6 +17,8 @@ type remoteImageBackend struct { func (r remoteImageBackend) ReadAt(p []byte, off int64) (n int, err error) { virtualMediaStateMutex.RLock() + + logger := logging.GetSubsystemLogger("nbd") logger.Debug().Interface("currentVirtualMediaState", currentVirtualMediaState).Msg("currentVirtualMediaState") logger.Debug().Int64("read size", int64(len(p))).Int64("off", off).Msg("read size and off") if currentVirtualMediaState == nil { @@ -60,14 +63,21 @@ type NBDDevice struct { serverConn net.Conn clientConn net.Conn dev *os.File - - l *zerolog.Logger } func NewNBDDevice() *NBDDevice { return &NBDDevice{} } +func (d *NBDDevice) getLogger() *zerolog.Logger { + logger := logging.GetSubsystemLogger("nbd"). + With(). + Str("socket_path", nbdSocketPath). + Str("device_path", nbdDevicePath). + Logger() + return &logger +} + func (d *NBDDevice) Start() error { var err error @@ -80,18 +90,10 @@ func (d *NBDDevice) Start() error { return err } - if d.l == nil { - scopedLogger := nbdLogger.With(). - Str("socket_path", nbdSocketPath). - Str("device_path", nbdDevicePath). - Logger() - d.l = &scopedLogger - } - // Remove the socket file if it already exists if _, err := os.Stat(nbdSocketPath); err == nil { if err := os.Remove(nbdSocketPath); err != nil { - d.l.Error().Err(err).Msg("failed to remove existing socket file") + d.getLogger().Error().Err(err).Msg("failed to remove existing socket file") os.Exit(1) } } @@ -133,5 +135,5 @@ func (d *NBDDevice) runServerConn() { SupportsMultiConn: false, }) - d.l.Info().Err(err).Msg("nbd server exited") + d.getLogger().Info().Err(err).Msg("nbd server exited") } diff --git a/block_device_linux.go b/block_device_linux.go index 8ca93722a..3d1b6f67e 100644 --- a/block_device_linux.go +++ b/block_device_linux.go @@ -11,14 +11,14 @@ func (d *NBDDevice) runClientConn() { ExportName: "jetkvm", BlockSize: uint32(4 * 1024), }) - d.l.Info().Err(err).Msg("nbd client exited") + d.getLogger().Info().Err(err).Msg("nbd client exited") } func (d *NBDDevice) Close() { if d.dev != nil { err := client.Disconnect(d.dev) if err != nil { - d.l.Warn().Err(err).Msg("error disconnecting nbd client") + d.getLogger().Warn().Err(err).Msg("error disconnecting nbd client") } _ = d.dev.Close() } diff --git a/block_device_notlinux.go b/block_device_notlinux.go index b6a9abaa5..25cec5ae2 100644 --- a/block_device_notlinux.go +++ b/block_device_notlinux.go @@ -7,11 +7,11 @@ import ( ) func (d *NBDDevice) runClientConn() { - d.l.Error().Msg("platform not supported") + d.getLogger().Error().Msg("platform not supported") os.Exit(1) } func (d *NBDDevice) Close() { - d.l.Error().Msg("platform not supported") + d.getLogger().Error().Msg("platform not supported") os.Exit(1) } diff --git a/cloud.go b/cloud.go index dbbd3bbcc..446fcfa74 100644 --- a/cloud.go +++ b/cloud.go @@ -13,6 +13,8 @@ import ( "github.com/coder/websocket/wsjson" "github.com/google/uuid" + "github.com/jetkvm/kvm/internal/logging" + "github.com/jetkvm/kvm/internal/utils" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -20,7 +22,6 @@ import ( "github.com/coder/websocket" "github.com/gin-gonic/gin" - "github.com/rs/zerolog" ) type CloudRegisterRequest struct { @@ -284,6 +285,7 @@ func disconnectCloud(reason error) { cloudDisconnectLock.Lock() defer cloudDisconnectLock.Unlock() + cloudLogger := logging.GetSubsystemLogger("cloud") if cloudDisconnectChan == nil { cloudLogger.Trace().Msg("cloud disconnect channel is not set, no need to disconnect") return @@ -323,18 +325,13 @@ func runWebsocketClient() error { header.Set("Authorization", "Bearer "+config.CloudToken) dialCtx, cancelDial := context.WithTimeout(context.Background(), CloudWebSocketConnectTimeout) - l := websocketLogger.With(). - Str("source", wsURL.Host). - Str("sourceType", "cloud"). - Logger() - - scopedLogger := &l + logger := logging.GetSubsystemLogger("cloud").With().Str("subcomponent", "websocket").Str("source", wsURL.Host).Str("sourceType", "cloud").Logger() defer cancelDial() c, resp, err := websocket.Dial(dialCtx, wsURL.String(), &websocket.DialOptions{ HTTPHeader: header, OnPingReceived: func(ctx context.Context, payload []byte) bool { - scopedLogger.Debug().Bytes("payload", payload).Int("length", len(payload)).Msg("ping frame received") + logger.Debug().Object("data", utils.ByteSlice(payload)).Int("length", len(payload)).Msg("ping frame received") metricConnectionTotalPingReceivedCount.WithLabelValues("cloud", wsURL.Host).Inc() metricConnectionLastPingReceivedTimestamp.WithLabelValues("cloud", wsURL.Host).SetToCurrentTime() @@ -356,15 +353,13 @@ func runWebsocketClient() error { if connectionId == "" { connectionId = uuid.New().String() - scopedLogger.Warn(). + logger.Warn(). Str("connectionId", connectionId). Msg("no connection id received from the server, generating a new one") } - lWithConnectionId := scopedLogger.With(). - Str("connectionID", connectionId). - Logger() - scopedLogger = &lWithConnectionId + logger = logger.With().Str("connectionID", connectionId).Logger() + cloudLogger := logging.GetSubsystemLogger("cloud") // if the context is canceled, we don't want to return an error if err != nil { @@ -386,7 +381,7 @@ func runWebsocketClient() error { wsResetMetrics(true, "cloud", wsURL.Host) // we don't have a source for the cloud connection - return handleWebRTCSignalWsMessages(c, true, wsURL.Host, connectionId, scopedLogger) + return handleWebRTCSignalWsMessages(c, true, wsURL.Host, connectionId) } func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessionRequest) error { @@ -397,7 +392,7 @@ func authenticateSession(ctx context.Context, c *websocket.Conn, req WebRTCSessi _ = wsjson.Write(context.Background(), c, gin.H{ "error": fmt.Sprintf("failed to initialize OIDC provider: %v", err), }) - cloudLogger.Warn().Err(err).Msg("failed to initialize OIDC provider") + logging.GetSubsystemLogger("cloud").Warn().Err(err).Msg("failed to initialize OIDC provider") return err } @@ -426,7 +421,6 @@ func handleSessionRequest( req WebRTCSessionRequest, isCloudConnection bool, source string, - scopedLogger *zerolog.Logger, ) error { var sourceType string if isCloudConnection { @@ -453,7 +447,6 @@ func handleSessionRequest( IsCloud: isCloudConnection, LocalIP: req.IP, ICEServers: req.ICEServers, - Logger: scopedLogger, }) if err != nil { _ = wsjson.Write(context.Background(), c, gin.H{"error": err}) @@ -474,11 +467,10 @@ func handleSessionRequest( }() } - cloudLogger.Info().Interface("session", session).Msg("new session accepted") - cloudLogger.Trace().Interface("session", session).Msg("new session accepted") + logging.GetSubsystemLogger("cloud").Info().Interface("session", session).Msg("new session accepted") // Cancel any ongoing keyboard macro when session changes - cancelKeyboardMacro() + _ = cancelKeyboardMacro() currentSession = session _ = wsjson.Write(context.Background(), c, gin.H{"type": "answer", "data": sd}) @@ -495,21 +487,21 @@ func RunWebsocketClient() { // If the network is not up, well, we can't connect to the cloud. if !networkManager.IsOnline() { - cloudLogger.Warn().Msg("waiting for network to be online, will retry in 3 seconds") + logging.GetSubsystemLogger("cloud").Warn().Msg("waiting for network to be online, will retry in 3 seconds") time.Sleep(3 * time.Second) continue } // If the system time is not synchronized, the API request will fail anyway because the TLS handshake will fail. if isTimeSyncNeeded() && !timeSync.IsSyncSuccess() { - cloudLogger.Warn().Msg("system time is not synced, will retry in 3 seconds") + logging.GetSubsystemLogger("cloud").Warn().Msg("system time is not synced, will retry in 3 seconds") time.Sleep(3 * time.Second) continue } err := runWebsocketClient() if err != nil { - cloudLogger.Warn().Err(err).Msg("websocket client error") + logging.GetSubsystemLogger("cloud").Warn().Err(err).Msg("websocket client error") metricCloudConnectionStatus.Set(0) metricCloudConnectionFailureCount.Inc() time.Sleep(5 * time.Second) @@ -561,7 +553,7 @@ func rpcDeregisterDevice() error { return fmt.Errorf("failed to save configuration after deregistering: %w", err) } - cloudLogger.Info().Msg("device deregistered, disconnecting from cloud") + logging.GetSubsystemLogger("cloud").Info().Msg("device deregistered, disconnecting from cloud") disconnectCloud(fmt.Errorf("device deregistered")) setCloudConnectionState(CloudConnectionStateNotConfigured) diff --git a/cmd/main.go b/cmd/main.go index fcf2cdfee..a7c3160e0 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -11,10 +11,11 @@ import ( "syscall" "time" - "github.com/erikdubbelboer/gspt" "github.com/jetkvm/kvm" "github.com/jetkvm/kvm/internal/native" "github.com/jetkvm/kvm/internal/supervisor" + + "github.com/erikdubbelboer/gspt" ) var ( diff --git a/config.go b/config.go index a06febd56..ab64fad82 100644 --- a/config.go +++ b/config.go @@ -128,7 +128,7 @@ func (c *Config) GetUpdateAPIURL() string { func (c *Config) GetDisplayRotation() uint16 { rotationInt, err := strconv.ParseUint(c.DisplayRotation, 10, 16) if err != nil { - logger.Warn().Err(err).Msg("invalid display rotation, using default") + logging.GetSubsystemLogger("config").Warn().Err(err).Msg("invalid display rotation, using default") return 270 } return uint16(rotationInt) @@ -138,7 +138,7 @@ func (c *Config) GetDisplayRotation() uint16 { func (c *Config) SetDisplayRotation(rotation string) error { _, err := strconv.ParseUint(rotation, 10, 16) if err != nil { - logger.Warn().Err(err).Msg("invalid display rotation, using default") + logging.GetSubsystemLogger("config").Warn().Err(err).Msg("invalid display rotation, using default") return err } c.DisplayRotation = rotation @@ -224,6 +224,8 @@ func LoadConfig() { configLock.Lock() defer configLock.Unlock() + logger := logging.GetSubsystemLogger("config") + if config != nil { logger.Debug().Msg("config already loaded, skipping") return @@ -272,10 +274,9 @@ func LoadConfig() { loadedConfig.KeyboardLayout = "en-US" } + logging.UpdateConfigLogLevel(loadedConfig.DefaultLogLevel) config = &loadedConfig - logging.GetRootLogger().UpdateLogLevel(config.DefaultLogLevel) - configSuccess.Set(1.0) configSuccessTime.SetToCurrentTime() @@ -294,6 +295,7 @@ func saveConfig(path string) error { configLock.Lock() defer configLock.Unlock() + logger := logging.GetSubsystemLogger("config") logger.Trace().Str("path", path).Msg("Saving config") // fixup old keyboard layout value diff --git a/display.go b/display.go index 9b12ad433..6c70af852 100644 --- a/display.go +++ b/display.go @@ -10,6 +10,7 @@ import ( "sync" "time" + "github.com/jetkvm/kvm/internal/logging" "github.com/prometheus/common/version" ) @@ -75,7 +76,7 @@ func updateDisplay() { nativeInstance.UpdateLabelIfChanged("hdmi_status_label", "Disconnected") _, _ = nativeInstance.UIObjClearState("hdmi_status_label", "LV_STATE_CHECKED") } - nativeInstance.UpdateLabelIfChanged("cloud_status_label", fmt.Sprintf("%d active", actionSessions)) + nativeInstance.UpdateLabelIfChanged("cloud_status_label", fmt.Sprintf("%d active", int(getActiveSessions()))) if networkManager != nil && networkManager.IsUp() { nativeInstance.UISetVar("main_screen", "home_screen") @@ -185,14 +186,13 @@ func requestDisplayUpdate(shouldWakeDisplay bool, reason string) { defer displayUpdateLock.Unlock() if !displayInited { - displayLogger.Info().Msg("display not inited, skipping updates") + logging.GetSubsystemLogger("display").Info().Msg("display not inited, skipping updates") return } go func() { if shouldWakeDisplay { wakeDisplay(false, reason) } - displayLogger.Debug().Msg("display updating") // TODO: only run once regardless how many pending updates updateDisplay() }() @@ -240,7 +240,7 @@ func updateStaticContents() { // configureDisplayOnNativeRestart is called when the native process restarts // it ensures the display is configured correctly after the restart func configureDisplayOnNativeRestart() { - displayLogger.Info().Msg("native restarted, configuring display") + logging.GetSubsystemLogger("display").Info().Msg("native restarted, configuring display") updateStaticContents() requestDisplayUpdate(true, "native_restart") } @@ -266,7 +266,7 @@ func setDisplayBrightness(brightness int, reason string) error { return err } - displayLogger.Info().Int("brightness", brightness).Str("reason", reason).Msg("set brightness") + logging.GetSubsystemLogger("display").Info().Int("brightness", brightness).Str("reason", reason).Msg("set brightness") return nil } @@ -275,7 +275,7 @@ func setDisplayBrightness(brightness int, reason string) error { func tick_displayDim() { err := setDisplayBrightness(config.DisplayMaxBrightness/2, "tick_display_dim") if err != nil { - displayLogger.Warn().Err(err).Msg("failed to dim display") + logging.GetSubsystemLogger("display").Warn().Err(err).Msg("failed to dim display") } dimTicker.Stop() @@ -288,7 +288,7 @@ func tick_displayDim() { func tick_displayOff() { err := setDisplayBrightness(0, "tick_display_off") if err != nil { - displayLogger.Warn().Err(err).Msg("failed to turn off display") + logging.GetSubsystemLogger("display").Warn().Err(err).Msg("failed to turn off display") } offTicker.Stop() @@ -315,7 +315,7 @@ func wakeDisplay(force bool, reason string) { err := setDisplayBrightness(config.DisplayMaxBrightness, reason) if err != nil { - displayLogger.Warn().Err(err).Msg("failed to wake display") + logging.GetSubsystemLogger("display").Warn().Err(err).Msg("failed to wake display") } if config.DisplayDimAfterSec != 0 && dimTicker != nil { @@ -348,6 +348,8 @@ func startBacklightTickers() { offTicker.Stop() } + displayLogger := logging.GetSubsystemLogger("display") + if config.DisplayDimAfterSec != 0 { displayLogger.Info().Msg("dim_ticker has started") dimTicker = time.NewTicker(time.Duration(config.DisplayDimAfterSec) * time.Second) @@ -373,6 +375,7 @@ func startBacklightTickers() { func initDisplay() { go func() { + displayLogger := logging.GetSubsystemLogger("display") displayLogger.Info().Msg("setting initial display contents") time.Sleep(500 * time.Millisecond) updateStaticContents() diff --git a/failsafe.go b/failsafe.go index d7de1c816..a9b01b4ae 100644 --- a/failsafe.go +++ b/failsafe.go @@ -1,11 +1,14 @@ package kvm import ( + "errors" "fmt" + "io/fs" "os" "strings" "sync" + "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/supervisor" ) @@ -53,10 +56,10 @@ func checkFailsafeReason() { } // check if the last crash log file exists - l := failsafeLogger.With().Str("path", lastCrashPath).Logger() + l := logging.GetSubsystemLogger("failsafe").With().Str("path", lastCrashPath).Logger() fi, err := os.Lstat(lastCrashPath) if err != nil { - if !os.IsNotExist(err) { + if !errors.Is(err, fs.ErrNotExist) { l.Warn().Err(err).Msg("failed to stat last crash log") } return @@ -98,9 +101,9 @@ func notifyFailsafeMode(session *Session) { return } - jsonRpcLogger.Info().Str("reason", failsafeModeReason).Msg("sending failsafe mode notification") + logging.GetSubsystemLogger("failsafe").Info().Str("reason", failsafeModeReason).Msg("sending failsafe mode notification") - writeJSONRPCEvent("failsafeMode", FailsafeModeNotification{ + go writeJSONRPCEvent("failsafeMode", FailsafeModeNotification{ Active: true, Reason: failsafeModeReason, }, session) diff --git a/go.mod b/go.mod index aac587993..37199d064 100644 --- a/go.mod +++ b/go.mod @@ -5,6 +5,7 @@ go 1.24.4 require ( github.com/Masterminds/semver/v3 v3.4.0 github.com/beevik/ntp v1.5.0 + github.com/caarlos0/env/v11 v11.3.1 github.com/coder/websocket v1.8.14 github.com/coreos/go-oidc/v3 v3.16.0 github.com/creack/pty v1.1.24 @@ -36,6 +37,8 @@ require ( golang.org/x/crypto v0.43.0 golang.org/x/net v0.46.0 golang.org/x/sys v0.37.0 + google.golang.org/grpc v1.76.0 + google.golang.org/protobuf v1.36.10 ) replace github.com/pojntfx/go-nbd v0.3.2 => github.com/chemhack/go-nbd v0.0.0-20241006125820-59e45f5b1e7b @@ -44,7 +47,6 @@ require ( github.com/beorn7/perks v1.0.1 // indirect github.com/bytedance/sonic v1.14.0 // indirect github.com/bytedance/sonic/loader v0.3.0 // indirect - github.com/caarlos0/env/v11 v11.3.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/creack/goselect v0.1.2 // indirect @@ -88,7 +90,6 @@ require ( github.com/prometheus/client_model v0.6.2 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect github.com/rogpeppe/go-internal v1.14.1 // indirect - github.com/spf13/pflag v1.0.10 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect github.com/u-root/uio v0.0.0-20230220225925-ffce2a382923 // indirect github.com/ugorji/go/codec v1.3.0 // indirect @@ -100,8 +101,5 @@ require ( golang.org/x/sync v0.17.0 // indirect golang.org/x/text v0.30.0 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250804133106-a7a43d27e69b // indirect - google.golang.org/grpc v1.76.0 // indirect - google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 // indirect - google.golang.org/protobuf v1.36.10 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 8d33b159a..465787e70 100644 --- a/go.sum +++ b/go.sum @@ -48,6 +48,10 @@ github.com/go-co-op/gocron/v2 v2.17.0 h1:e/oj6fcAM8vOOKZxv2Cgfmjo+s8AXC46po5ZPta github.com/go-co-op/gocron/v2 v2.17.0/go.mod h1:Zii6he+Zfgy5W9B+JKk/KwejFOW0kZTFvHtwIpR4aBI= github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -59,6 +63,8 @@ github.com/go-playground/validator/v10 v10.27.0/go.mod h1:I5QpIEbmr8On7W0TktmJAu github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= @@ -175,8 +181,6 @@ github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/sourcegraph/tf-dag v0.2.2-0.20250131204052-3e8ff1477b4f h1:VgoRCP1efSCEZIcF2THLQ46+pIBzzgNiaUBe9wEDwYU= github.com/sourcegraph/tf-dag v0.2.2-0.20250131204052-3e8ff1477b4f/go.mod h1:pzro7BGorij2WgrjEammtrkbo3+xldxo+KaGLGUiD+Q= -github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= -github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -204,6 +208,18 @@ github.com/wlynxg/anet v0.0.5 h1:J3VJGi1gvo0JwZ/P1/Yc/8p63SoW98B5dHkYDmpgvvU= github.com/wlynxg/anet v0.0.5/go.mod h1:eay5PRQr7fIVAMbTbchTnO9gG65Hg/uYGdc7mguHxoA= go.bug.st/serial v1.6.4 h1:7FmqNPgVp3pu2Jz5PoPtbZ9jJO5gnEnZIvnI1lzve8A= go.bug.st/serial v1.6.4/go.mod h1:nofMJxTeNVny/m6+KaafC6vJGj3miwQZ6vW4BZUGJPI= +go.opentelemetry.io/auto/sdk v1.1.0 h1:cH53jehLUN6UFLY71z+NDOiNJqDdPRaXzTel0sJySYA= +go.opentelemetry.io/auto/sdk v1.1.0/go.mod h1:3wSPjt5PWp2RhlCcmmOial7AvC4DQqZb7a7wCow3W8A= +go.opentelemetry.io/otel v1.37.0 h1:9zhNfelUvx0KBfu/gb+ZgeAfAgtWrfHJZcAqFC228wQ= +go.opentelemetry.io/otel v1.37.0/go.mod h1:ehE/umFRLnuLa/vSccNq9oS1ErUlkkK71gMcN34UG8I= +go.opentelemetry.io/otel/metric v1.37.0 h1:mvwbQS5m0tbmqML4NqK+e3aDiO02vsf/WgbsdpcPoZE= +go.opentelemetry.io/otel/metric v1.37.0/go.mod h1:04wGrZurHYKOc+RKeye86GwKiTb9FKm1WHtO+4EVr2E= +go.opentelemetry.io/otel/sdk v1.37.0 h1:ItB0QUqnjesGRvNcmAcU0LyvkVyGJ2xftD29bWdDvKI= +go.opentelemetry.io/otel/sdk v1.37.0/go.mod h1:VredYzxUvuo2q3WRcDnKDjbdvmO0sCzOvVAiY+yUkAg= +go.opentelemetry.io/otel/sdk/metric v1.37.0 h1:90lI228XrB9jCMuSdA0673aubgRobVZFhbjxHHspCPc= +go.opentelemetry.io/otel/sdk/metric v1.37.0/go.mod h1:cNen4ZWfiD37l5NhS+Keb5RXVWZWpRE+9WyVCpbo5ps= +go.opentelemetry.io/otel/trace v1.37.0 h1:HLdcFNbRQBE2imdSEgm/kwqmQj1Or1l/7bW6mxVK7z4= +go.opentelemetry.io/otel/trace v1.37.0/go.mod h1:TlgrlQ+PtQO5XFerSPUYG0JSgGyryXewPGyayAWSBS0= go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= @@ -230,12 +246,12 @@ golang.org/x/term v0.36.0 h1:zMPR+aF8gfksFprF/Nc/rd1wRS1EI6nDBGyWAvDzx2Q= golang.org/x/term v0.36.0/go.mod h1:Qu394IJq6V6dCBRgwqshf3mPF85AqzYEzofzRdZkWss= golang.org/x/text v0.30.0 h1:yznKA/E9zq54KzlzBEAWn1NXSQ8DIp/NYMy88xJjl4k= golang.org/x/text v0.30.0/go.mod h1:yDdHFIX9t+tORqspjENWgzaCVXgk0yYnYuSZ8UzzBVM= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= google.golang.org/genproto/googleapis/rpc v0.0.0-20250804133106-a7a43d27e69b h1:zPKJod4w6F1+nRGDI9ubnXYhU9NSWoFAijkHkUXeTK8= google.golang.org/genproto/googleapis/rpc v0.0.0-20250804133106-a7a43d27e69b/go.mod h1:qQ0YXyHHx3XkvlzUtpXDkS29lDSafHMZBAZDc03LQ3A= google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A= google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c= -google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1 h1:F29+wU6Ee6qgu9TddPgooOdaqsxTMunOoj8KA5yuS5A= -google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.5.1/go.mod h1:5KF+wpkbTSbGcR9zteSqZV6fqFOWBl4Yde8En8MryZA= google.golang.org/protobuf v1.36.10 h1:AYd7cD/uASjIL6Q9LiTjz8JLcrh/88q5UObnmY3aOOE= google.golang.org/protobuf v1.36.10/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/hidrpc.go b/hidrpc.go index ebe03daab..e15fb8869 100644 --- a/hidrpc.go +++ b/hidrpc.go @@ -7,99 +7,110 @@ import ( "time" "github.com/jetkvm/kvm/internal/hidrpc" + "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/usbgadget" - "github.com/rs/zerolog" + "github.com/jetkvm/kvm/internal/utils" ) -func handleHidRPCMessage(message hidrpc.Message, session *Session) { - var rpcErr error - +func handleHidRPCMessage(message hidrpc.Message, session *Session) error { switch message.Type() { case hidrpc.TypeHandshake: - message, err := hidrpc.NewHandshakeMessage().Marshal() - if err != nil { - logger.Warn().Err(err).Msg("failed to marshal handshake message") - return - } - if err := session.HidChannel.Send(message); err != nil { - logger.Warn().Err(err).Msg("failed to send handshake message") - return - } - session.hidRPCAvailable = true - case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport: - rpcErr = handleHidRPCKeyboardInput(message) + return handleHidRPCHandshake(session) case hidrpc.TypeKeyboardMacroReport: - keyboardMacroReport, err := message.KeyboardMacroReport() - if err != nil { - logger.Warn().Err(err).Msg("failed to get keyboard macro report") - return - } - rpcErr = rpcExecuteKeyboardMacro(keyboardMacroReport.Steps) + return handleKeyboardMacro(message) + case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport: + return handleHidRPCKeyboardInput(message) case hidrpc.TypeCancelKeyboardMacroReport: - rpcCancelKeyboardMacro() - return + return rpcCancelKeyboardMacro() case hidrpc.TypeKeypressKeepAliveReport: - rpcErr = handleHidRPCKeypressKeepAlive(session) + return handleHidRPCKeypressKeepAlive(session) case hidrpc.TypePointerReport: - pointerReport, err := message.PointerReport() - if err != nil { - logger.Warn().Err(err).Msg("failed to get pointer report") - return - } - rpcErr = rpcAbsMouseReport(pointerReport.X, pointerReport.Y, pointerReport.Button) + return handlePointerReport(message) case hidrpc.TypeMouseReport: - mouseReport, err := message.MouseReport() - if err != nil { - logger.Warn().Err(err).Msg("failed to get mouse report") - return - } - rpcErr = rpcRelMouseReport(mouseReport.DX, mouseReport.DY, mouseReport.Button) - default: - logger.Warn().Uint8("type", uint8(message.Type())).Msg("unknown HID RPC message type") + return handleMouseReport(message) + } + + return fmt.Errorf("unknown HID RPC message type %d", message.Type()) +} + +func handleHidRPCHandshake(session *Session) error { + hidrpc.GetHidRpcLogger().Debug().Msg("handling handshake") + message, err := hidrpc.NewHandshakeMessage().Marshal() + if err != nil { + return err + } + if err = session.HidChannel.Send(message); err != nil { + return err + } + session.hidRPCAvailable = true + return nil +} + +func handleKeyboardMacro(message hidrpc.Message) error { + keyboardMacroReport, err := message.KeyboardMacroReport() + if err != nil { + return err + } + + hidrpc.GetHidRpcLogger().Debug().Interface("keyboardMacroReport", keyboardMacroReport).Msg("handling keyboard macro") + return rpcExecuteKeyboardMacro(keyboardMacroReport.Steps) +} + +func handleMouseReport(message hidrpc.Message) error { + mouseReport, err := message.MouseReport() + if err != nil { + return err } + hidrpc.GetHidRpcLogger().Debug().Interface("mouseReport", mouseReport).Msg("handling relative mouse") + return rpcRelMouseReport(mouseReport.DX, mouseReport.DY, mouseReport.Button) +} - if rpcErr != nil { - logger.Warn().Err(rpcErr).Msg("failed to handle HID RPC message") +func handlePointerReport(message hidrpc.Message) error { + pointerReport, err := message.PointerReport() + if err != nil { + return err } + hidrpc.GetHidRpcLogger().Debug().Interface("pointerReport", pointerReport).Msg("handling absolute pointer") + return rpcAbsMouseReport(pointerReport.X, pointerReport.Y, pointerReport.Button) } -func onHidMessage(msg hidQueueMessage, session *Session) { +func onHidMessage(msg hidQueueMessage, session *Session, index int) { + logger := hidrpc.GetHidRpcLogger().With().Int("queueIndex", index).Str("channel", msg.channel).Logger() data := msg.Data - scopedLogger := hidRPCLogger.With(). - Str("channel", msg.channel). - Bytes("data", data). - Logger() - scopedLogger.Debug().Msg("HID RPC message received") + if logger.GetLevel() <= zerolog.TraceLevel { + logger.Trace().Object("data", utils.ByteSlice(data)).Msg("HID RPC message received") + } if len(data) < 1 { - scopedLogger.Warn().Int("length", len(data)).Msg("received empty data in HID RPC message handler") + logger.Warn().Int("length", len(data)).Msg("received empty data in HID RPC message handler") return } var message hidrpc.Message if err := hidrpc.Unmarshal(data, &message); err != nil { - scopedLogger.Warn().Err(err).Msg("failed to unmarshal HID RPC message") + logger.Warn().Err(err).Msg("failed to unmarshal HID RPC message") return } - if scopedLogger.GetLevel() <= zerolog.DebugLevel { - scopedLogger = scopedLogger.With().Str("descr", message.String()).Logger() + if logger.GetLevel() <= zerolog.DebugLevel { + logger = scopedLogger.With().Str("descr", message.String()).Logger() } t := time.Now() - r := make(chan interface{}) go func() { - handleHidRPCMessage(message, session) - r <- nil + r <- handleHidRPCMessage(message, session) }() select { case <-time.After(1 * time.Second): - scopedLogger.Warn().Msg("HID RPC message timed out") - case <-r: - scopedLogger.Debug().Dur("duration", time.Since(t)).Msg("HID RPC message handled") + logger.Warn().Msg("HID RPC message took too long") + case err := <-r: + logger.Debug().Dur("duration", time.Since(t)).Msg("HID RPC message handled") + if err != nil { + logger.Warn().Err(err.(error)).Msg("failed to handle HID RPC message") + } } } @@ -108,12 +119,10 @@ func onHidMessage(msg hidQueueMessage, session *Session) { // macOS default: 15 * 15 = 225ms https://discussions.apple.com/thread/1316947?sortBy=rank // Linux default: 250ms https://man.archlinux.org/man/kbdrate.8.en // Windows default: 1s `HKEY_CURRENT_USER\Control Panel\Accessibility\Keyboard Response\AutoRepeatDelay` - const expectedRate = 50 * time.Millisecond // expected keepalive interval const maxLateness = 50 * time.Millisecond // max jitter we'll tolerate OR jitter budget const baseExtension = expectedRate + maxLateness // 100ms extension on perfect tick - -const maxStaleness = 225 * time.Millisecond // discard ancient packets outright +const maxStaleness = 225 * time.Millisecond // discard ancient packets outright func handleHidRPCKeypressKeepAlive(session *Session) error { session.keepAliveJitterLock.Lock() @@ -128,7 +137,6 @@ func handleHidRPCKeypressKeepAlive(session *Session) error { return nil } - validTick := true timerExtension := baseExtension if !session.lastKeepAliveArrivalTime.IsZero() { @@ -147,14 +155,11 @@ func handleHidRPCKeypressKeepAlive(session *Session) error { // This is likely a retransmit stall or ordering delay. // We reject the tick entirely and DO NOT extend, // so the auto-release still fires on time. - validTick = false + return nil } } } - if !validTick { - return nil - } // Only valid ticks update our state and extend the timer. session.lastKeepAliveArrivalTime = now session.lastTimerResetTime = now @@ -167,6 +172,8 @@ func handleHidRPCKeypressKeepAlive(session *Session) error { } func handleHidRPCKeyboardInput(message hidrpc.Message) error { + logger := hidrpc.GetHidRpcLogger().With().Interface("message", message).Logger() + switch message.Type() { case hidrpc.TypeKeypressReport: keypressReport, err := message.KeypressReport() @@ -174,6 +181,7 @@ func handleHidRPCKeyboardInput(message hidrpc.Message) error { logger.Warn().Err(err).Msg("failed to get keypress report") return err } + logger.Debug().Interface("keypressReport", keypressReport).Msg("handling key press") return rpcKeypressReport(keypressReport.Key, keypressReport.Press) case hidrpc.TypeKeyboardReport: keyboardReport, err := message.KeyboardReport() @@ -181,6 +189,7 @@ func handleHidRPCKeyboardInput(message hidrpc.Message) error { logger.Warn().Err(err).Msg("failed to get keyboard report") return err } + logger.Debug().Interface("keyboardReport", keyboardReport).Msg("handling keyboard") return rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys) } @@ -188,6 +197,8 @@ func handleHidRPCKeyboardInput(message hidrpc.Message) error { } func reportHidRPC(params any, session *Session) { + logger := hidrpc.GetHidRpcLogger().With().Interface("params", params).Logger() + if session == nil { logger.Warn().Msg("session is nil, skipping reportHidRPC") return @@ -205,6 +216,7 @@ func reportHidRPC(params any, session *Session) { message []byte err error ) + switch params := params.(type) { case usbgadget.KeyboardState: message, err = hidrpc.NewKeyboardLedMessage(params).Marshal() @@ -216,23 +228,35 @@ func reportHidRPC(params any, session *Session) { err = fmt.Errorf("unknown HID RPC message type: %T", params) } - if err != nil { - logger.Warn().Err(err).Msg("failed to marshal HID RPC message") - return - } + logger = logger.With().Type("type", params).Logger() - if message == nil { - logger.Warn().Msg("failed to marshal HID RPC message") + if err != nil || message == nil { + logger.Warn().Err(err).Msg("failed to marshal HID RPC message") return } - if err := session.HidChannel.Send(message); err != nil { - if errors.Is(err, io.ErrClosedPipe) { - logger.Debug().Err(err).Msg("HID RPC channel closed, skipping reportHidRPC") - return + // fire and forget... + go func() { + t := time.Now() + r := make(chan interface{}) + go func() { + logger.Trace().Msg("sending HID RPC report") + r <- session.HidChannel.Send(message) + }() + select { + case <-time.After(1 * time.Second): + logger.Warn().Msg("HID RPC report took too long") + case err := <-r: + logger.Debug().Dur("duration", time.Since(t)).Msg("HID RPC report sent") + if err != nil { + if errors.Is(err.(error), io.ErrClosedPipe) { + logger.Warn().Err(err.(error)).Msg("HID RPC channel closed, skipping reportHidRPC") + return + } + logger.Warn().Err(err.(error)).Msg("failed to send HID RPC report") + } } - logger.Warn().Err(err).Msg("failed to send HID RPC message") - } + }() } func (s *Session) reportHidRPCKeyboardLedState(state usbgadget.KeyboardState) { @@ -244,10 +268,8 @@ func (s *Session) reportHidRPCKeyboardLedState(state usbgadget.KeyboardState) { func (s *Session) reportHidRPCKeysDownState(state usbgadget.KeysDownState) { if !s.hidRPCAvailable { - usbLogger.Debug().Interface("state", state).Msg("reporting keys down state") writeJSONRPCEvent("keysDownState", state, s) } - usbLogger.Debug().Interface("state", state).Msg("reporting keys down state, calling reportHidRPC") reportHidRPC(state, s) } diff --git a/hw.go b/hw.go index f6ec57f5c..902c315f7 100644 --- a/hw.go +++ b/hw.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/ota" ) @@ -32,6 +33,7 @@ func extractSerialNumber() (string, error) { } func hwReboot(force bool, postRebootAction *ota.PostRebootAction, delay time.Duration) error { + logger := logging.GetSubsystemLogger("hw") logger.Info().Dur("delayMs", delay).Msg("reboot requested") writeJSONRPCEvent("willReboot", postRebootAction, currentSession) @@ -71,7 +73,7 @@ func GetDeviceID() string { deviceIDOnce.Do(func() { serial, err := extractSerialNumber() if err != nil { - logger.Warn().Msg("unknown serial number, the program likely not running on RV1106") + logging.GetSubsystemLogger("hw").Warn().Msg("unknown serial number, the program likely not running on RV1106") deviceID = "unknown_device_id" } else { deviceID = serial @@ -90,6 +92,8 @@ func GetDefaultHostname() string { } func runWatchdog() { + watchdogLogger := logging.GetSubsystemLogger("hw").With().Str("subcomponent", "watchdog").Logger() + file, err := os.OpenFile("/dev/watchdog", os.O_WRONLY, 0) if err != nil { watchdogLogger.Warn().Err(err).Msg("unable to open /dev/watchdog, skipping watchdog reset") diff --git a/internal/hidrpc/hidrpc.go b/internal/hidrpc/hidrpc.go index 7313e3b53..b27315835 100644 --- a/internal/hidrpc/hidrpc.go +++ b/internal/hidrpc/hidrpc.go @@ -6,26 +6,9 @@ import ( "github.com/jetkvm/kvm/internal/usbgadget" ) -// MessageType is the type of the HID RPC message -type MessageType byte - const ( - TypeHandshake MessageType = 0x01 - TypeKeyboardReport MessageType = 0x02 - TypePointerReport MessageType = 0x03 - TypeWheelReport MessageType = 0x04 - TypeKeypressReport MessageType = 0x05 - TypeKeypressKeepAliveReport MessageType = 0x09 - TypeMouseReport MessageType = 0x06 - TypeKeyboardMacroReport MessageType = 0x07 - TypeCancelKeyboardMacroReport MessageType = 0x08 - TypeKeyboardLedState MessageType = 0x32 - TypeKeydownState MessageType = 0x33 - TypeKeyboardMacroState MessageType = 0x34 -) - -const ( - Version byte = 0x01 // Version of the HID RPC protocol + Version byte = 0x01 // Version of the HID RPC protocol + MaximumQueues int = 4 // Maximum number of HID RPC queues ) // GetQueueIndex returns the index of the queue to which the message should be enqueued. @@ -57,19 +40,6 @@ func Unmarshal(data []byte, message *Message) error { return nil } -// Marshal marshals the HID RPC message to the data. -func Marshal(message *Message) ([]byte, error) { - if message.t == 0 { - return nil, fmt.Errorf("invalid message type: %d", message.t) - } - - data := make([]byte, len(message.d)+1) - data[0] = byte(message.t) - copy(data[1:], message.d) - - return data, nil -} - // NewHandshakeMessage creates a new handshake message. func NewHandshakeMessage() *Message { return &Message{ diff --git a/internal/hidrpc/log.go b/internal/hidrpc/log.go new file mode 100644 index 000000000..6869b0b0b --- /dev/null +++ b/internal/hidrpc/log.go @@ -0,0 +1,10 @@ +package hidrpc + +import ( + "github.com/jetkvm/kvm/internal/logging" + "github.com/rs/zerolog" +) + +func GetHidRpcLogger() *zerolog.Logger { + return logging.GetSubsystemLogger("hidrpc") +} diff --git a/internal/hidrpc/message.go b/internal/hidrpc/message.go index 3f3506f7f..ec9247acd 100644 --- a/internal/hidrpc/message.go +++ b/internal/hidrpc/message.go @@ -3,6 +3,26 @@ package hidrpc import ( "encoding/binary" "fmt" + + "github.com/rs/zerolog" +) + +// MessageType is the type of the HID RPC message +type MessageType byte + +const ( + TypeHandshake MessageType = 0x01 + TypeKeyboardReport MessageType = 0x02 + TypePointerReport MessageType = 0x03 + TypeWheelReport MessageType = 0x04 + TypeKeypressReport MessageType = 0x05 + TypeKeypressKeepAliveReport MessageType = 0x09 + TypeMouseReport MessageType = 0x06 + TypeKeyboardMacroReport MessageType = 0x07 + TypeCancelKeyboardMacroReport MessageType = 0x08 + TypeKeyboardLedState MessageType = 0x32 + TypeKeydownState MessageType = 0x33 + TypeKeyboardMacroState MessageType = 0x34 ) // Message .. @@ -11,9 +31,22 @@ type Message struct { d []byte } +func (m Message) MarshalZerologObject(e *zerolog.Event) { + e.Uint8("Type", uint8(m.t)) + e.Bytes("Payload", m.d) +} + // Marshal marshals the message to a byte array. func (m *Message) Marshal() ([]byte, error) { - return Marshal(m) + if m.t == 0 { + return nil, fmt.Errorf("invalid message type: %d", m.t) + } + + data := make([]byte, len(m.d)+1) + data[0] = byte(m.t) + copy(data[1:], m.d) + + return data, nil } func (m *Message) Type() MessageType { @@ -56,12 +89,20 @@ func (m *Message) String() string { } } +// HidKeyBufferSize is the size of the keys buffer in the keyboard report. +const HidKeyBufferSize = 6 + // KeypressReport .. type KeypressReport struct { Key byte Press bool } +func (k KeypressReport) MarshalZerologObject(e *zerolog.Event) { + e.Hex("Key", []byte{k.Key}) + e.Bool("Press", k.Press) +} + // KeypressReport returns the keypress report from the message. func (m *Message) KeypressReport() (KeypressReport, error) { if m.t != TypeKeypressReport { @@ -77,7 +118,12 @@ func (m *Message) KeypressReport() (KeypressReport, error) { // KeyboardReport .. type KeyboardReport struct { Modifier byte - Keys []byte + Keys []byte // 6 bytes: HidKeyBufferSize +} + +func (k KeyboardReport) MarshalZerologObject(e *zerolog.Event) { + e.Hex("Modifier", []byte{k.Modifier}) + e.Hex("Keys", k.Keys) } // KeyboardReport returns the keyboard report from the message. @@ -95,17 +141,26 @@ func (m *Message) KeyboardReport() (KeyboardReport, error) { // Macro .. type KeyboardMacroStep struct { Modifier byte // 1 byte - Keys []byte // 6 bytes: hidKeyBufferSize + Keys []byte // 6 bytes: HidKeyBufferSize Delay uint16 // 2 bytes } + +func (s KeyboardMacroStep) MarshalZerologObject(e *zerolog.Event) { + e.Hex("Modifier", []byte{s.Modifier}) + e.Hex("Keys", s.Keys) + e.Uint16("Delay", s.Delay) +} + type KeyboardMacroReport struct { IsPaste bool StepCount uint32 Steps []KeyboardMacroStep } -// HidKeyBufferSize is the size of the keys buffer in the keyboard report. -const HidKeyBufferSize = 6 +func (m KeyboardMacroReport) MarshalZerologObject(e *zerolog.Event) { + e.Bool("IsPaste", m.IsPaste) + e.Uint32("StepCount", m.StepCount) +} // KeyboardMacroReport returns the keyboard macro report from the message. func (m *Message) KeyboardMacroReport() (KeyboardMacroReport, error) { @@ -117,7 +172,9 @@ func (m *Message) KeyboardMacroReport() (KeyboardMacroReport, error) { stepCount := binary.BigEndian.Uint32(m.d[1:5]) // check total length - expectedLength := int(stepCount)*9 + 5 + const StepSize = 1 + HidKeyBufferSize + 2 + + expectedLength := int(stepCount)*StepSize + 5 if len(m.d) != expectedLength { return KeyboardMacroReport{}, fmt.Errorf("invalid length: %d, expected: %d", len(m.d), expectedLength) } @@ -131,7 +188,7 @@ func (m *Message) KeyboardMacroReport() (KeyboardMacroReport, error) { Delay: binary.BigEndian.Uint16(m.d[offset+7 : offset+9]), }) - offset += 1 + HidKeyBufferSize + 2 + offset += StepSize } return KeyboardMacroReport{ @@ -148,6 +205,12 @@ type PointerReport struct { Button uint8 } +func (p PointerReport) MarshalZerologObject(e *zerolog.Event) { + e.Int("X", p.X) + e.Int("Y", p.Y) + e.Uint8("Button", p.Button) +} + func toInt(b []byte) int { return int(b[0])<<24 + int(b[1])<<16 + int(b[2])<<8 + int(b[3])<<0 } @@ -176,6 +239,12 @@ type MouseReport struct { Button uint8 } +func (m MouseReport) MarshalZerologObject(e *zerolog.Event) { + e.Int8("DX", m.DX) + e.Int8("DY", m.DY) + e.Uint8("Button", m.Button) +} + // MouseReport returns the mouse report from the message. func (m *Message) MouseReport() (MouseReport, error) { if m.t != TypeMouseReport { @@ -194,6 +263,11 @@ type KeyboardMacroState struct { IsPaste bool } +func (k KeyboardMacroState) MarshalZerologObject(e *zerolog.Event) { + e.Bool("State", k.State) + e.Bool("IsPaste", k.IsPaste) +} + // KeyboardMacroState returns the keyboard macro state report from the message. func (m *Message) KeyboardMacroState() (KeyboardMacroState, error) { if m.t != TypeKeyboardMacroState { diff --git a/internal/logging/logger.go b/internal/logging/logger.go index 3a8274c51..48fe65511 100644 --- a/internal/logging/logger.go +++ b/internal/logging/logger.go @@ -4,6 +4,7 @@ import ( "fmt" "io" "os" + "slices" "strings" "sync" "time" @@ -12,10 +13,10 @@ import ( ) type Logger struct { - l *zerolog.Logger - scopeLoggers map[string]*zerolog.Logger - scopeLevels map[string]zerolog.Level - scopeLevelMutex sync.Mutex + baseLogger *zerolog.Logger + scopeLoggers map[string]*zerolog.Logger + scopeLevels map[string]zerolog.Level + loggerMutex sync.Mutex defaultLogLevelFromEnv zerolog.Level defaultLogLevelFromConfig zerolog.Level @@ -24,6 +25,7 @@ type Logger struct { const ( defaultLogLevel = zerolog.ErrorLevel + unset = -2 ) type logOutput struct { @@ -45,14 +47,16 @@ func (w *logOutput) Write(p []byte) (n int, err error) { } var ( + excludedFields = []string{"scope", "component", "subcomponent"} + partsOrder = []string{"time", "level", "scope", "component", "subcomponent", "message"} consoleLogOutput io.Writer = zerolog.ConsoleWriter{ Out: os.Stdout, TimeFormat: time.RFC3339, - PartsOrder: []string{"time", "level", "scope", "component", "message"}, - FieldsExclude: []string{"scope", "component"}, + PartsOrder: partsOrder, + FieldsExclude: excludedFields, FormatPartValueByName: func(value any, name string) string { val := fmt.Sprintf("%s", value) - if name == "component" { + if slices.Contains(excludedFields, name) { if value == nil { return "-" } @@ -73,125 +77,147 @@ var ( "INFO": zerolog.InfoLevel, "DEBUG": zerolog.DebugLevel, "TRACE": zerolog.TraceLevel, + "UNSET": unset, } ) func NewLogger(zerologLogger zerolog.Logger) *Logger { return &Logger{ - l: &zerologLogger, + baseLogger: &zerologLogger, scopeLoggers: make(map[string]*zerolog.Logger), scopeLevels: make(map[string]zerolog.Level), - scopeLevelMutex: sync.Mutex{}, - defaultLogLevelFromEnv: -2, - defaultLogLevelFromConfig: -2, + loggerMutex: sync.Mutex{}, + defaultLogLevelFromEnv: unset, + defaultLogLevelFromConfig: unset, defaultLogLevel: defaultLogLevel, } } -func (l *Logger) updateLogLevel() { - l.scopeLevelMutex.Lock() - defer l.scopeLevelMutex.Unlock() +func (l *Logger) updateLogLevels(newConfigLevel zerolog.Level) { + logger := l.baseLogger.Level(zerolog.InfoLevel) + logger.Info().Msgf("updating log levels with new config level: %v", newConfigLevel) + l.defaultLogLevelFromConfig = newConfigLevel l.scopeLevels = make(map[string]zerolog.Level) - finalDefaultLogLevel := l.defaultLogLevel - - for name, level := range zerologLevels { + for name, envLevel := range zerologLevels { env := os.Getenv(fmt.Sprintf("JETKVM_LOG_%s", name)) - if env == "" { - env = os.Getenv(fmt.Sprintf("PION_LOG_%s", name)) - } - - if env == "" { - env = os.Getenv(fmt.Sprintf("PIONS_LOG_%s", name)) - } - - if env == "" { - continue + if env != "" { + env = strings.ToLower(env) + loopLogger := logger.With(). + Str("name", name). + Str("env", env). + Stringer("envLevel", envLevel). + Logger() + + if env == "all" { + loopLogger.Info().Msg("setting log level for ALL scopes from environment") + l.defaultLogLevelFromEnv = envLevel + } else { + // if not "all", parse as comma-separated list of scopes + scopes := strings.SplitSeq(env, ",") + for scope := range scopes { + loopLogger.Info().Msgf("setting log level for scope %s from environment", scope) + + if envLevel == unset { + delete(l.scopeLevels, scope) + } else { + l.scopeLevels[scope] = envLevel + } + } + } } + } - if strings.ToLower(env) == "all" { - l.defaultLogLevelFromEnv = level - - if finalDefaultLogLevel > level { - finalDefaultLogLevel = level - } + l.defaultLogLevel = defaultLogLevel + logger.Info().Msgf("default log level starts at %v", l.defaultLogLevel) - continue - } + if l.defaultLogLevel > l.defaultLogLevelFromEnv { + logger.Info().Msgf("default log level from env %v", l.defaultLogLevelFromEnv) + l.defaultLogLevel = l.defaultLogLevelFromEnv + } - scopes := strings.SplitSeq(strings.ToLower(env), ",") - for scope := range scopes { - l.scopeLevels[scope] = level - } + if l.defaultLogLevel > l.defaultLogLevelFromConfig { + logger.Info().Msgf("default log level from config %v", l.defaultLogLevelFromConfig) + l.defaultLogLevel = l.defaultLogLevelFromConfig } - l.defaultLogLevel = finalDefaultLogLevel + logger.Info().Msgf("default log level %v", l.defaultLogLevel) } func (l *Logger) getScopeLoggerLevel(scope string) zerolog.Level { - if l.scopeLevels == nil { - l.updateLogLevel() + if level, ok := l.scopeLevels[scope]; ok { + return level } - var scopeLevel zerolog.Level - if l.defaultLogLevelFromConfig != -2 { - scopeLevel = l.defaultLogLevelFromConfig - } - if l.defaultLogLevelFromEnv != -2 { - scopeLevel = l.defaultLogLevelFromEnv + // if the scope is not in the map, use the default level from the root logger + if l.defaultLogLevelFromConfig != unset { + return l.defaultLogLevelFromConfig } - // if the scope is not in the map, use the default level from the root logger - if level, ok := l.scopeLevels[scope]; ok { - scopeLevel = level + if l.defaultLogLevelFromEnv != unset { + return l.defaultLogLevelFromEnv } - return scopeLevel + return l.defaultLogLevel } -func (l *Logger) newScopeLogger(scope string) zerolog.Logger { - scopeLevel := l.getScopeLoggerLevel(scope) - logger := l.l.Level(scopeLevel).With().Str("component", scope).Logger() +func (l *Logger) getLogger(scope string) *zerolog.Logger { + if logger, ok := l.scopeLoggers[scope]; ok && logger != nil { + return logger + } - return logger -} + l.loggerMutex.Lock() + defer l.loggerMutex.Unlock() -func (l *Logger) getLogger(scope string) *zerolog.Logger { - logger, ok := l.scopeLoggers[scope] - if !ok || logger == nil { - scopeLogger := l.newScopeLogger(scope) - l.scopeLoggers[scope] = &scopeLogger + // double-check after acquiring the lock + if logger, ok := l.scopeLoggers[scope]; ok && logger != nil { + return logger } - return l.scopeLoggers[scope] + scopeLevel := l.getScopeLoggerLevel(scope) + logger := l.baseLogger.Level(scopeLevel).With().Str("component", scope).Logger() + l.scopeLoggers[scope] = &logger + return &logger } -func (l *Logger) UpdateLogLevel(configDefaultLogLevel string) { - needUpdate := false +func (l *Logger) UpdateConfigLogLevel(configDefaultLogLevel string) { + var newConfigLevel zerolog.Level + + configDefaultLogLevel = strings.ToUpper(configDefaultLogLevel) + loggingContext := l.baseLogger.With().Str("configDefaultLogLevel", configDefaultLogLevel) + logger := loggingContext.Logger() if configDefaultLogLevel != "" { if logLevel, ok := zerologLevels[configDefaultLogLevel]; ok { - l.defaultLogLevelFromConfig = logLevel + logger.Debug().Msgf("setting config log level to %v", logLevel) + newConfigLevel = logLevel } else { - l.l.Warn().Str("logLevel", configDefaultLogLevel).Msg("invalid defaultLogLevel from config, using ERROR") - } - - if l.defaultLogLevelFromConfig != l.defaultLogLevel { - needUpdate = true + logger.Warn().Msg("invalid log level from config") + return } + } else { + newConfigLevel = unset } - l.updateLogLevel() + l.loggerMutex.Lock() + defer l.loggerMutex.Unlock() - if needUpdate { - for scope, logger := range l.scopeLoggers { - currentLevel := logger.GetLevel() - targetLevel := l.getScopeLoggerLevel(scope) - if currentLevel != targetLevel { - *logger = l.newScopeLogger(scope) - } + l.updateLogLevels(newConfigLevel) + + for scope, logger := range l.scopeLoggers { + currentLevel := logger.GetLevel() + targetLevel := l.getScopeLoggerLevel(scope) + + if currentLevel != targetLevel { + debugLogger := loggingContext.Stringer("currentLevel", currentLevel).Stringer("targetLevel", targetLevel).Logger() + debugLogger.Info().Msgf("updating log level for scope %s", scope) + + // update the chosen level by replacing the logger + // with a new one at the target level + newLogger := logger.Level(targetLevel) + l.scopeLoggers[scope] = &newLogger } } } diff --git a/internal/logging/pion.go b/internal/logging/pion.go index 2676caf25..02e0344c2 100644 --- a/internal/logging/pion.go +++ b/internal/logging/pion.go @@ -6,58 +6,58 @@ import ( ) type pionLogger struct { - logger *zerolog.Logger + logger func() *zerolog.Logger } -// Print all messages except trace. func (c pionLogger) Trace(msg string) { - c.logger.Trace().Msg(msg) + c.logger().Trace().Msg(msg) } func (c pionLogger) Tracef(format string, args ...any) { - c.logger.Trace().Msgf(format, args...) + c.logger().Trace().Msgf(format, args...) } func (c pionLogger) Debug(msg string) { - c.logger.Debug().Msg(msg) + c.logger().Debug().Msg(msg) } func (c pionLogger) Debugf(format string, args ...any) { - c.logger.Debug().Msgf(format, args...) + c.logger().Debug().Msgf(format, args...) } + func (c pionLogger) Info(msg string) { - c.logger.Info().Msg(msg) + c.logger().Info().Msg(msg) } func (c pionLogger) Infof(format string, args ...any) { - c.logger.Info().Msgf(format, args...) + c.logger().Info().Msgf(format, args...) } + func (c pionLogger) Warn(msg string) { - c.logger.Warn().Msg(msg) + c.logger().Warn().Msg(msg) } func (c pionLogger) Warnf(format string, args ...any) { - c.logger.Warn().Msgf(format, args...) + c.logger().Warn().Msgf(format, args...) } + func (c pionLogger) Error(msg string) { - c.logger.Error().Msg(msg) + c.logger().Error().Msg(msg) } func (c pionLogger) Errorf(format string, args ...any) { - c.logger.Error().Msgf(format, args...) + c.logger().Error().Msgf(format, args...) } // customLoggerFactory satisfies the interface logging.LoggerFactory // This allows us to create different loggers per subsystem. So we can // add custom behavior. -type pionLoggerFactory struct{} - -func (c pionLoggerFactory) NewLogger(subsystem string) logging.LeveledLogger { - logger := rootLogger.getLogger(subsystem).With(). - Str("scope", "pion"). - Str("component", subsystem). - Logger() - - return pionLogger{logger: &logger} +type pionLoggerFactory struct { + subcomponent string } -var defaultLoggerFactory = &pionLoggerFactory{} +func (c pionLoggerFactory) NewLogger(subcomponent string) logging.LeveledLogger { + return pionLogger{logger: func() *zerolog.Logger { + logger := GetSubsystemLogger("pion").With().Str("subcomponent", subcomponent).Logger() + return &logger + }} +} -func GetPionDefaultLoggerFactory() logging.LoggerFactory { - return defaultLoggerFactory +func GetPionLoggerFactory(subcomponent string) logging.LoggerFactory { + return &pionLoggerFactory{subcomponent: subcomponent} } diff --git a/internal/logging/root.go b/internal/logging/root.go index 397ca6488..333ea7d89 100644 --- a/internal/logging/root.go +++ b/internal/logging/root.go @@ -3,16 +3,16 @@ package logging import "github.com/rs/zerolog" var ( - rootZerologLogger = zerolog.New(defaultLogOutput).With(). - Str("scope", "jetkvm"). - Timestamp(). - Stack(). - Logger() - rootLogger = NewLogger(rootZerologLogger) + rootLogger = NewLogger( + zerolog.New(defaultLogOutput). + With(). + Str("scope", "jetkvm"). + Timestamp(). + Logger()) ) -func GetRootLogger() *Logger { - return rootLogger +func UpdateConfigLogLevel(logLevel string) { + rootLogger.UpdateConfigLogLevel(logLevel) } func GetSubsystemLogger(subsystem string) *zerolog.Logger { diff --git a/internal/logging/sse.go b/internal/logging/sse.go index 05e6e9e25..29ec564af 100644 --- a/internal/logging/sse.go +++ b/internal/logging/sse.go @@ -6,7 +6,6 @@ import ( "net/http" "github.com/gin-gonic/gin" - "github.com/rs/zerolog" ) //go:embed sse.html @@ -24,15 +23,13 @@ type sseClientChan chan string var ( sseServer *sseEvent - sseLogger *zerolog.Logger ) func init() { sseServer = newSseServer() - sseLogger = GetSubsystemLogger("sse") } -// Initialize event and Start procnteessing requests +// Initialize event and Start proceessing requests func newSseServer() (event *sseEvent) { event = &sseEvent{ Message: make(chan string), @@ -54,7 +51,8 @@ func (stream *sseEvent) listen() { // Add new available client case client := <-stream.NewClients: stream.TotalClients[client] = true - sseLogger.Info(). + GetSubsystemLogger("sse"). + Info(). Int("total_clients", len(stream.TotalClients)). Msg("new client connected") @@ -62,7 +60,10 @@ func (stream *sseEvent) listen() { case client := <-stream.ClosedClients: delete(stream.TotalClients, client) close(client) - sseLogger.Info().Int("total_clients", len(stream.TotalClients)).Msg("client disconnected") + GetSubsystemLogger("sse"). + Info(). + Int("total_clients", len(stream.TotalClients)). + Msg("client disconnected") // Broadcast message to client case eventMsg := <-stream.Message: diff --git a/internal/logging/utils.go b/internal/logging/utils.go index 73ae37a84..fce76cd1e 100644 --- a/internal/logging/utils.go +++ b/internal/logging/utils.go @@ -2,31 +2,33 @@ package logging import ( "fmt" - "os" + "sync" "github.com/rs/zerolog" ) -var defaultLogger = zerolog.New(os.Stdout).Level(zerolog.InfoLevel) - -func GetDefaultLogger() *zerolog.Logger { - return &defaultLogger +func UnlockWithTraceLog(logger *zerolog.Logger, lock *sync.Mutex, msg string, args ...any) { + defer lock.Unlock() + logger.Trace().Msgf(msg, args...) } -func ErrorfL(l *zerolog.Logger, format string, err error, args ...any) error { - // TODO: move rootLogger to logging package - if l == nil { - l = &defaultLogger - } - - l.Error().Err(err).Msgf(format, args...) +func ErrorfL(logger *zerolog.Logger, format string, err error, args ...any) error { + logger.Error().Err(err).Msgf(format, args...) if err == nil { return fmt.Errorf(format, args...) } - err_msg := err.Error() + ": %v" + err_msg := err.Error() + ": %w" err_args := append(args, err) return fmt.Errorf(err_msg, err_args...) } + +func IsDebugLevel(logger *zerolog.Logger) bool { + return logger.GetLevel() <= zerolog.DebugLevel +} + +func IsTraceLevel(logger *zerolog.Logger) bool { + return logger.GetLevel() <= zerolog.TraceLevel +} diff --git a/internal/mdns/log.go b/internal/mdns/log.go new file mode 100644 index 000000000..114373f47 --- /dev/null +++ b/internal/mdns/log.go @@ -0,0 +1,16 @@ +package mdns + +import ( + "github.com/jetkvm/kvm/internal/logging" + "github.com/rs/zerolog" +) + +func (m *MDNS) getMdnsLogger() *zerolog.Logger { + logger := logging.GetSubsystemLogger("mdns"). + With(). + Strs("local_names", m.localNames). + Bool("ipv4", m.listenOptions.IPv4). + Bool("ipv6", m.listenOptions.IPv6). + Logger() + return &logger +} diff --git a/internal/mdns/mdns.go b/internal/mdns/mdns.go index 2b954d45d..d7fd81d3f 100644 --- a/internal/mdns/mdns.go +++ b/internal/mdns/mdns.go @@ -9,7 +9,6 @@ import ( "github.com/jetkvm/kvm/internal/logging" pion_mdns "github.com/pion/mdns/v2" - "github.com/rs/zerolog" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) @@ -17,7 +16,6 @@ import ( type MDNS struct { conn *pion_mdns.Conn lock sync.Mutex - l *zerolog.Logger localNames []string listenOptions *MDNSListenOptions @@ -29,7 +27,6 @@ type MDNSListenOptions struct { } type MDNSOptions struct { - Logger *zerolog.Logger LocalNames []string ListenOptions *MDNSListenOptions } @@ -40,10 +37,6 @@ const ( ) func NewMDNS(opts *MDNSOptions) (*MDNS, error) { - if opts.Logger == nil { - opts.Logger = logging.GetDefaultLogger() - } - if opts.ListenOptions == nil { opts.ListenOptions = &MDNSListenOptions{ IPv4: true, @@ -52,7 +45,6 @@ func NewMDNS(opts *MDNSOptions) (*MDNS, error) { } return &MDNS{ - l: opts.Logger, lock: sync.Mutex{}, localNames: opts.LocalNames, listenOptions: opts.ListenOptions, @@ -75,8 +67,10 @@ func (m *MDNS) start(allowRestart bool) error { return fmt.Errorf("listen options not set") } + logger := *m.getMdnsLogger() + if !m.listenOptions.IPv4 && !m.listenOptions.IPv6 { - m.l.Info().Msg("mDNS server disabled") + logger.Info().Msg("mDNS server disabled") return nil } @@ -116,12 +110,6 @@ func (m *MDNS) start(allowRestart bool) error { p6 = ipv6.NewPacketConn(l6) } - scopeLogger := m.l.With(). - Interface("local_names", m.localNames). - Bool("ipv4", m.listenOptions.IPv4). - Bool("ipv6", m.listenOptions.IPv6). - Logger() - newLocalNames := make([]string, len(m.localNames)) for i, name := range m.localNames { newLocalNames[i] = strings.TrimRight(strings.ToLower(name), ".") @@ -132,16 +120,18 @@ func (m *MDNS) start(allowRestart bool) error { mDNSConn, err := pion_mdns.Server(p4, p6, &pion_mdns.Config{ LocalNames: newLocalNames, - LoggerFactory: logging.GetPionDefaultLoggerFactory(), + LoggerFactory: logging.GetPionLoggerFactory("mdns"), }) + logger = logger.With().Interface("mDNSConn", mDNSConn).Logger() + if err != nil { - scopeLogger.Warn().Err(err).Msg("failed to start mDNS server") + logger.Error().Err(err).Msg("failed to start mDNS server") return err } m.conn = mDNSConn - scopeLogger.Info().Msg("mDNS server started") + logger.Info().Msg("mDNS server started") return nil } diff --git a/internal/mdns/utils.go b/internal/mdns/utils.go deleted file mode 100644 index 7565eee2c..000000000 --- a/internal/mdns/utils.go +++ /dev/null @@ -1 +0,0 @@ -package mdns diff --git a/internal/native/cgo/edid.c b/internal/native/cgo/edid.c index 95dfe95e5..00ebce4f0 100644 --- a/internal/native/cgo/edid.c +++ b/internal/native/cgo/edid.c @@ -81,24 +81,20 @@ int set_edid(uint8_t *edid, size_t size) if (size != 128 && size != 256) { errno = EINVAL; - return -1; + return -2; } - int fd; - struct v4l2_edid v4l2_edid; - - fd = open(V4L_SUBDEV, O_RDWR); + int fd = open(V4L_SUBDEV, O_RDWR); if (fd < 0) { log_error("Failed to open device"); - return -1; + return -3; } fix_edid_checksum(edid, size); + struct v4l2_edid v4l2_edid; memset(&v4l2_edid, 0, sizeof(v4l2_edid)); - v4l2_edid.pad = 0; - v4l2_edid.start_block = 0; v4l2_edid.blocks = size / 128; v4l2_edid.edid = edid; @@ -106,7 +102,7 @@ int set_edid(uint8_t *edid, size_t size) { log_error("Failed to set EDID"); close(fd); - return -1; + return -4; } close(fd); diff --git a/internal/native/cgo_linux.go b/internal/native/cgo_linux.go index dcd25e42a..2820980d8 100644 --- a/internal/native/cgo_linux.go +++ b/internal/native/cgo_linux.go @@ -75,7 +75,6 @@ func jetkvm_go_log_handler(level C.int, filename *C.cchar_t, funcname *C.cchar_t FuncName: C.GoString(funcname), Line: int(line), } - logChan <- logMessage } @@ -143,7 +142,7 @@ func videoInit(factor float64) error { ret := C.jetkvm_video_init(factorC) if ret != 0 { - return fmt.Errorf("failed to initialize video: %d", ret) + return fmt.Errorf("failed to initialize video with factor %f: %d", factor, ret) } return nil } @@ -194,7 +193,6 @@ func uiSetVar(name string, value string) { nameCStr := C.CString(name) defer C.free(unsafe.Pointer(nameCStr)) - valueCStr := C.CString(value) defer C.free(unsafe.Pointer(valueCStr)) @@ -217,6 +215,7 @@ func uiSwitchToScreen(screen string) { screenCStr := C.CString(screen) defer C.free(unsafe.Pointer(screenCStr)) + C.jetkvm_ui_load_screen(screenCStr) } @@ -236,6 +235,7 @@ func uiObjAddState(objName string, state string) (bool, error) { defer C.free(unsafe.Pointer(objNameCStr)) stateCStr := C.CString(state) defer C.free(unsafe.Pointer(stateCStr)) + C.jetkvm_ui_add_state(objNameCStr, stateCStr) return true, nil } @@ -248,6 +248,7 @@ func uiObjClearState(objName string, state string) (bool, error) { defer C.free(unsafe.Pointer(objNameCStr)) stateCStr := C.CString(state) defer C.free(unsafe.Pointer(stateCStr)) + C.jetkvm_ui_clear_state(objNameCStr, stateCStr) return true, nil } @@ -268,8 +269,12 @@ func uiObjAddFlag(objName string, flag string) (bool, error) { defer C.free(unsafe.Pointer(objNameCStr)) flagCStr := C.CString(flag) defer C.free(unsafe.Pointer(flagCStr)) - C.jetkvm_ui_add_flag(objNameCStr, flagCStr) - return true, nil + + ret := C.jetkvm_ui_add_flag(objNameCStr, flagCStr) + if ret < 0 { + return false, fmt.Errorf("failed to add flag %s on %s: %d", flag, objName, ret) + } + return ret == 0, nil } func uiObjClearFlag(objName string, flag string) (bool, error) { @@ -280,8 +285,12 @@ func uiObjClearFlag(objName string, flag string) (bool, error) { defer C.free(unsafe.Pointer(objNameCStr)) flagCStr := C.CString(flag) defer C.free(unsafe.Pointer(flagCStr)) - C.jetkvm_ui_clear_flag(objNameCStr, flagCStr) - return true, nil + + ret := C.jetkvm_ui_clear_flag(objNameCStr, flagCStr) + if ret < 0 { + return false, fmt.Errorf("failed to clear flag %s on %s: %d", flag, objName, ret) + } + return ret == 0, nil } func uiObjHide(objName string) (bool, error) { @@ -311,7 +320,6 @@ func uiObjFadeIn(objName string, duration uint32) (bool, error) { defer C.free(unsafe.Pointer(objNameCStr)) C.jetkvm_ui_fade_in(objNameCStr, C.u_int32_t(duration)) - return true, nil } @@ -323,7 +331,6 @@ func uiObjFadeOut(objName string, duration uint32) (bool, error) { defer C.free(unsafe.Pointer(objNameCStr)) C.jetkvm_ui_fade_out(objNameCStr, C.u_int32_t(duration)) - return true, nil } @@ -333,13 +340,12 @@ func uiLabelSetText(objName string, text string) (bool, error) { objNameCStr := C.CString(objName) defer C.free(unsafe.Pointer(objNameCStr)) - textCStr := C.CString(text) defer C.free(unsafe.Pointer(textCStr)) ret := C.jetkvm_ui_set_text(objNameCStr, textCStr) if ret < 0 { - return false, fmt.Errorf("failed to set text: %d", ret) + return false, fmt.Errorf("failed to set %s text to %s: %d", objName, text, ret) } return ret == 0, nil } @@ -350,12 +356,10 @@ func uiImgSetSrc(objName string, src string) (bool, error) { objNameCStr := C.CString(objName) defer C.free(unsafe.Pointer(objNameCStr)) - srcCStr := C.CString(src) defer C.free(unsafe.Pointer(srcCStr)) C.jetkvm_ui_set_image(objNameCStr, srcCStr) - return true, nil } @@ -363,8 +367,6 @@ func uiDispSetRotation(rotation uint16) (bool, error) { cgoLock.Lock() defer cgoLock.Unlock() - nativeLogger.Info().Uint16("rotation", rotation).Msg("setting rotation") - cRotation := C.u_int16_t(rotation) C.jetkvm_ui_set_rotation(cRotation) @@ -383,7 +385,10 @@ func videoSetStreamQualityFactor(factor float64) error { cgoLock.Lock() defer cgoLock.Unlock() - C.jetkvm_video_set_quality_factor(C.float(factor)) + ret := C.jetkvm_video_set_quality_factor(C.float(factor)) + if ret < 0 { + return fmt.Errorf("failed to set video quality factor to %f: %d", factor, ret) + } return nil } @@ -401,7 +406,11 @@ func videoSetEDID(edid string) error { edidCStr := C.CString(edid) defer C.free(unsafe.Pointer(edidCStr)) - C.jetkvm_video_set_edid(edidCStr) + + ret := C.jetkvm_video_set_edid(edidCStr) + if ret < 0 { + return fmt.Errorf("failed to set EDID to %s: %d", edid, ret) + } return nil } diff --git a/internal/native/chan.go b/internal/native/chan.go index cd6d07af1..80fb62e82 100644 --- a/internal/native/chan.go +++ b/internal/native/chan.go @@ -28,7 +28,6 @@ func (n *Native) handleVideoFrameChan() { func (n *Native) handleVideoStateChan() { for { state := <-videoStateChan - n.onVideoStateChange(state) } } @@ -36,7 +35,8 @@ func (n *Native) handleVideoStateChan() { func (n *Native) handleLogChan() { for { entry := <-logChan - l := n.l.With(). + l := GetNativeLogger(). + With(). Str("file", entry.File). Str("func", entry.FuncName). Int("line", entry.Line). diff --git a/internal/native/display.go b/internal/native/display.go index 9c92378d1..56430ad7b 100644 --- a/internal/native/display.go +++ b/internal/native/display.go @@ -99,18 +99,22 @@ func (n *Native) DisplaySetRotation(rotation uint16) (bool, error) { // UpdateLabelIfChanged updates the label if the text has changed func (n *Native) UpdateLabelIfChanged(objName string, newText string) { - l := n.lD.Trace().Str("obj", objName).Str("text", newText) + logger := GetDisplayLogger(). + With(). + Str("obj", objName). + Str("text", newText). + Logger() changed, err := n.UIObjSetLabelText(objName, newText) if err != nil { - n.lD.Warn().Str("obj", objName).Str("text", newText).Err(err).Msg("failed to update label") + logger.Warn().Err(err).Msg("failed to update label") return } if changed { - l.Msg("label changed") + logger.Trace().Msg("label changed") } else { - l.Msg("label not changed") + logger.Trace().Msg("label not changed") } } @@ -134,11 +138,19 @@ func (n *Native) SwitchToScreenIf(screenName string, shouldSwitch []string) { if currentScreen == screenName { return } + + logger := GetDisplayLogger(). + With(). + Str("from", currentScreen). + Str("to", screenName). + Strs("from_screens", shouldSwitch). + Logger() + if len(shouldSwitch) > 0 && !slices.Contains(shouldSwitch, currentScreen) { - n.lD.Trace().Str("from", currentScreen).Str("to", screenName).Msg("skipping screen switch") + logger.Trace().Msg("skipping screen switch") return } - n.lD.Info().Str("from", currentScreen).Str("to", screenName).Msg("switching screen") + logger.Info().Msg("switching screen") uiSwitchToScreen(screenName) } diff --git a/internal/native/grpc_client.go b/internal/native/grpc_client.go index 85a3201bd..beeaff93e 100644 --- a/internal/native/grpc_client.go +++ b/internal/native/grpc_client.go @@ -8,7 +8,9 @@ import ( "sync" "time" + "github.com/jetkvm/kvm/internal/logging" "github.com/rs/zerolog" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" @@ -25,7 +27,6 @@ type GRPCClient struct { conn *grpc.ClientConn client pb.NativeServiceClient - logger *zerolog.Logger eventStream pb.NativeService_StreamEventsClient eventM sync.RWMutex @@ -42,12 +43,20 @@ type GRPCClient struct { type grpcClientOptions struct { SocketPath string - Logger *zerolog.Logger OnVideoStateChange func(state VideoState) OnIndevEvent func(event string) OnRpcEvent func(event string) } +func (client *GRPCClient) getLogger() *zerolog.Logger { + logger := getClientLogger(). + With(). + Str("target", client.conn.Target()). + Bool("closed", client.closed). + Logger() + return &logger +} + // NewGRPCClient creates a new gRPC client connected to the native service func NewGRPCClient(opts grpcClientOptions) (*GRPCClient, error) { // Connect to the Unix domain socket @@ -68,7 +77,6 @@ func NewGRPCClient(opts grpcClientOptions) (*GRPCClient, error) { cancel: cancel, conn: conn, client: client, - logger: opts.Logger, eventCh: make(chan *pb.Event, 100), eventDone: make(chan struct{}), onVideoStateChange: opts.OnVideoStateChange, @@ -103,7 +111,7 @@ func (c *GRPCClient) handleEventStream(stream pb.NativeService_StreamEventsClien }() for { - logger := c.logger.With().Interface("stream", stream).Logger() + logger := c.getLogger().With().Interface("stream", stream).Logger() if stream == nil { logger.Error().Msg("event stream is nil") break @@ -121,7 +129,7 @@ func (c *GRPCClient) handleEventStream(stream pb.NativeService_StreamEventsClien } // enrich the logger with the event type and data, if debug mode is enabled - if c.logger.GetLevel() <= zerolog.DebugLevel { + if logging.IsDebugLevel(&logger) { logger = logger.With(). Str("type", event.Type). Interface("data", event.Data). @@ -150,14 +158,14 @@ func (c *GRPCClient) startEventStream() { // check if the context is done select { case <-c.ctx.Done(): - c.logger.Info().Msg("event stream context done, closing") + c.getLogger().Info().Msg("event stream context done, closing") return default: } stream, err := c.client.StreamEvents(c.ctx, &pb.Empty{}) if err != nil { - c.logger.Warn().Err(err).Msg("failed to start event stream, retrying ...") + c.getLogger().Warn().Err(err).Msg("failed to start event stream, retrying ...") time.Sleep(5 * time.Second) continue } @@ -170,7 +178,7 @@ func (c *GRPCClient) startEventStream() { } func (c *GRPCClient) checkIsReady(ctx context.Context) error { - c.logger.Trace().Msg("connection is idle, connecting ...") + c.getLogger().Trace().Msg("connection is idle, connecting ...") resp, err := c.client.IsReady(ctx, &pb.IsReadyRequest{}) if err != nil { @@ -193,7 +201,7 @@ func (c *GRPCClient) WaitReady() error { prevState := connectivity.Idle for { state := c.conn.GetState() - c.logger. + logger := c.getLogger(). With(). Str("state", state.String()). Int("prev_state", int(prevState)). @@ -207,7 +215,7 @@ func (c *GRPCClient) WaitReady() error { } } - c.logger.Info().Msg("waiting for connection to be ready") + logger.Info().Msg("waiting for connection to be ready") if state == connectivity.Ready { return nil @@ -223,11 +231,13 @@ func (c *GRPCClient) WaitReady() error { } func (c *GRPCClient) handleEvent(event *pb.Event) { + logger := c.getLogger().With().Str("event_type", event.Type).Logger() + switch event.Type { case "video_state_change": state := event.GetVideoState() if state == nil { - c.logger.Warn().Msg("video state event is nil") + logger.Warn().Msg("video state event is nil") return } c.onVideoStateChange(VideoState{ @@ -242,7 +252,7 @@ func (c *GRPCClient) handleEvent(event *pb.Event) { case "rpc_event": c.onRpcEvent(event.GetRpcEvent()) default: - c.logger.Warn().Str("type", event.Type).Msg("unknown event type") + logger.Warn().Str("type", event.Type).Msg("unknown event type") } } @@ -263,7 +273,7 @@ func (c *GRPCClient) Close() error { c.eventM.Lock() if c.eventStream != nil { if err := c.eventStream.CloseSend(); err != nil { - c.logger.Warn().Err(err).Msg("failed to close event stream") + c.getLogger().Warn().Err(err).Msg("failed to close event stream") } } c.eventM.Unlock() diff --git a/internal/native/grpc_server.go b/internal/native/grpc_server.go index 9b54fb5b7..7b7d5bb7d 100644 --- a/internal/native/grpc_server.go +++ b/internal/native/grpc_server.go @@ -6,17 +6,17 @@ import ( "net" "sync" - "github.com/rs/zerolog" "google.golang.org/grpc" pb "github.com/jetkvm/kvm/internal/native/proto" + "github.com/rs/zerolog" ) // grpcServer wraps the Native instance and implements the gRPC service type grpcServer struct { pb.UnimplementedNativeServiceServer native *Native - logger *zerolog.Logger + socketPath string eventStreamChan chan *pb.Event eventStreamMu sync.Mutex eventStreamCtx context.Context @@ -24,10 +24,10 @@ type grpcServer struct { } // NewGRPCServer creates a new gRPC server for the native service -func NewGRPCServer(n *Native, logger *zerolog.Logger) *grpcServer { +func NewGRPCServer(n *Native, socketPath string) *grpcServer { s := &grpcServer{ native: n, - logger: logger, + socketPath: socketPath, eventStreamChan: make(chan *pb.Event, 100), } @@ -83,6 +83,14 @@ func NewGRPCServer(n *Native, logger *zerolog.Logger) *grpcServer { return s } +func (s *grpcServer) getLogger() *zerolog.Logger { + logger := getServerLogger(). + With(). + Str("socketPath", s.socketPath). + Logger() + return &logger +} + func (s *grpcServer) broadcastEvent(event *pb.Event) { s.eventStreamChan <- event } @@ -99,7 +107,7 @@ func (s *grpcServer) StreamEvents(req *pb.Empty, stream pb.NativeService_StreamE // Cancel previous stream if exists s.eventStreamMu.Lock() if s.eventStreamCancel != nil { - s.logger.Debug().Msg("cancelling previous StreamEvents call") + s.getLogger().Debug().Msg("cancelling previous StreamEvents call") s.eventStreamCancel() } @@ -130,7 +138,7 @@ func (s *grpcServer) StreamEvents(req *pb.Empty, stream pb.NativeService_StreamE s.eventStreamMu.Unlock() if !isActive { - s.logger.Debug().Msg("stream replaced by new call, exiting") + s.getLogger().Debug().Msg("stream replaced by new call, exiting") return context.Canceled } @@ -144,8 +152,8 @@ func (s *grpcServer) StreamEvents(req *pb.Empty, stream pb.NativeService_StreamE } // StartGRPCServer starts the gRPC server on a Unix domain socket -func StartGRPCServer(server *grpcServer, socketPath string, logger *zerolog.Logger) (*grpc.Server, net.Listener, error) { - lis, err := net.Listen("unix", socketPath) +func (server *grpcServer) Start() (*grpc.Server, net.Listener, error) { + lis, err := net.Listen("unix", server.socketPath) if err != nil { return nil, nil, fmt.Errorf("failed to listen on socket: %w", err) } @@ -155,10 +163,10 @@ func StartGRPCServer(server *grpcServer, socketPath string, logger *zerolog.Logg go func() { if err := s.Serve(lis); err != nil { - logger.Error().Err(err).Msg("gRPC server error") + server.getLogger().Error().Err(err).Msg("gRPC server error") } }() - logger.Info().Str("socket", socketPath).Msg("gRPC server started") + server.getLogger().Info().Msg("gRPC server started") return s, lis, nil } diff --git a/internal/native/log.go b/internal/native/log.go index 41ae4df9f..0765b74a8 100644 --- a/internal/native/log.go +++ b/internal/native/log.go @@ -1,12 +1,29 @@ package native import ( + "os" + "github.com/jetkvm/kvm/internal/logging" "github.com/rs/zerolog" ) -var nativeLogger = logging.GetSubsystemLogger("native") -var displayLogger = logging.GetSubsystemLogger("display") +var pid = os.Getpid() + +func GetNativeLogger() *zerolog.Logger { + logger := logging.GetSubsystemLogger("native"). + With(). + Int("pid", pid). + Logger() + return &logger +} + +func GetDisplayLogger() *zerolog.Logger { + logging := logging.GetSubsystemLogger("display"). + With(). + Int("pid", pid). + Logger() + return &logging +} type nativeLogMessage struct { Level zerolog.Level diff --git a/internal/native/native.go b/internal/native/native.go index 87eebf185..5d513a483 100644 --- a/internal/native/native.go +++ b/internal/native/native.go @@ -1,18 +1,14 @@ package native import ( - "os" "sync" "time" "github.com/Masterminds/semver/v3" - "github.com/rs/zerolog" ) type Native struct { ready chan struct{} - l *zerolog.Logger - lD *zerolog.Logger systemVersion *semver.Version appVersion *semver.Version displayRotation uint16 @@ -23,7 +19,6 @@ type Native struct { onRpcEvent func(event string) sleepModeSupported bool videoLock sync.Mutex - screenLock sync.Mutex extraLock sync.Mutex } @@ -61,61 +56,49 @@ func (s VideoStreamingStatus) String() string { } func NewNative(opts NativeOptions) *Native { - pid := os.Getpid() - nativeSubLogger := nativeLogger.With().Int("pid", pid).Str("scope", "native").Logger() - displaySubLogger := displayLogger.With().Int("pid", pid).Str("scope", "native").Logger() - - onVideoStateChange := opts.OnVideoStateChange - if onVideoStateChange == nil { - onVideoStateChange = func(state VideoState) { - nativeLogger.Info().Interface("state", state).Msg("video state changed") - } + n := &Native{ + ready: make(chan struct{}), + systemVersion: opts.SystemVersion, + appVersion: opts.AppVersion, + displayRotation: opts.DisplayRotation, + defaultQualityFactor: opts.DefaultQualityFactor, + onVideoStateChange: opts.OnVideoStateChange, + onVideoFrameReceived: opts.OnVideoFrameReceived, + onIndevEvent: opts.OnIndevEvent, + onRpcEvent: opts.OnRpcEvent, + sleepModeSupported: isSleepModeSupported(), + videoLock: sync.Mutex{}, } - onVideoFrameReceived := opts.OnVideoFrameReceived - if onVideoFrameReceived == nil { - onVideoFrameReceived = func(frame []byte, duration time.Duration) { - nativeLogger.Trace().Interface("frame", frame).Dur("duration", duration).Msg("video frame received") + if n.onVideoStateChange == nil { + n.onVideoStateChange = func(state VideoState) { + GetDisplayLogger().Info().Interface("state", state).Msg("video state changed") } } - onIndevEvent := opts.OnIndevEvent - if onIndevEvent == nil { - onIndevEvent = func(event string) { - nativeLogger.Info().Str("event", event).Msg("indev event") + if n.onVideoFrameReceived == nil { + n.onVideoFrameReceived = func(frame []byte, duration time.Duration) { + GetDisplayLogger().Trace().Int("frame_size", len(frame)).Dur("duration", duration).Msg("video frame received") } } - onRpcEvent := opts.OnRpcEvent - if onRpcEvent == nil { - onRpcEvent = func(event string) { - nativeLogger.Info().Str("event", event).Msg("rpc event") + if n.onIndevEvent == nil { + n.onIndevEvent = func(event string) { + GetDisplayLogger().Info().Str("event", event).Msg("indev event") } } - sleepModeSupported := isSleepModeSupported() - - defaultQualityFactor := opts.DefaultQualityFactor - if defaultQualityFactor <= 0 || defaultQualityFactor > 1 { - defaultQualityFactor = 1.0 + if n.onRpcEvent == nil { + n.onRpcEvent = func(event string) { + GetNativeLogger().Info().Str("event", event).Msg("rpc event") + } } - return &Native{ - ready: make(chan struct{}), - l: &nativeSubLogger, - lD: &displaySubLogger, - systemVersion: opts.SystemVersion, - appVersion: opts.AppVersion, - displayRotation: opts.DisplayRotation, - defaultQualityFactor: defaultQualityFactor, - onVideoStateChange: onVideoStateChange, - onVideoFrameReceived: onVideoFrameReceived, - onIndevEvent: onIndevEvent, - onRpcEvent: onRpcEvent, - sleepModeSupported: sleepModeSupported, - videoLock: sync.Mutex{}, - screenLock: sync.Mutex{}, + if n.defaultQualityFactor <= 0 || n.defaultQualityFactor > 1 { + n.defaultQualityFactor = 1.0 } + + return n } func (n *Native) Start() error { @@ -134,7 +117,7 @@ func (n *Native) Start() error { go n.tickUI() if err := videoInit(n.defaultQualityFactor); err != nil { - n.l.Error().Err(err).Msg("failed to initialize video") + GetDisplayLogger().Error().Err(err).Msg("failed to initialize video") return err } @@ -147,7 +130,7 @@ func (n *Native) Start() error { func (n *Native) DoNotUseThisIsForCrashTestingOnly() { defer func() { if r := recover(); r != nil { - n.l.Error().Msg("recovered from crash") + GetNativeLogger().Error().Interface("recovered", r).Msg("recovered from crash") } }() diff --git a/internal/native/proxy.go b/internal/native/proxy.go index 62a913ae4..649059a7e 100644 --- a/internal/native/proxy.go +++ b/internal/native/proxy.go @@ -47,11 +47,20 @@ type nativeProxyOptions struct { OnNativeRestart func() } +func getClientLogger() *zerolog.Logger { + logger := GetNativeLogger(). + With(). + Str("subcomponent", "grpc-client"). + Int("pid", pid). + Logger() + return &logger +} + func randomId(binaryLength int) string { s := make([]byte, binaryLength) _, err := rand.Read(s) if err != nil { - nativeLogger.Error().Err(err).Msg("failed to generate random ID") + getClientLogger().Error().Err(err).Msg("failed to generate random ID") return strings.Repeat("0", binaryLength*2) // return all zeros if error } return hex.EncodeToString(s) @@ -133,7 +142,6 @@ type NativeProxy struct { cmd *cmdWrapper cmdMu sync.Mutex // mutex for the cmd - logger *zerolog.Logger options *nativeProxyOptions restarts uint stopped bool @@ -154,7 +162,6 @@ func NewNativeProxy(opts NativeOptions) (*NativeProxy, error) { nativeUnixSocket: proxyOptions.CtrlUnixSocket, videoStreamUnixSocket: proxyOptions.VideoStreamUnixSocket, binaryPath: exePath, - logger: nativeLogger, options: proxyOptions, restarts: 0, } @@ -167,7 +174,7 @@ func (p *NativeProxy) startVideoStreamListener() error { return nil } - logger := p.logger.With().Str("socketPath", p.videoStreamUnixSocket).Logger() + logger := GetDisplayLogger().With().Str("socketPath", p.videoStreamUnixSocket).Logger() listener, err := net.Listen("unix", p.videoStreamUnixSocket) if err != nil { logger.Warn().Err(err).Msg("failed to start video stream listener") @@ -185,7 +192,7 @@ func (p *NativeProxy) startVideoStreamListener() error { } logger.Info().Msg("video stream socket accepted") - go p.handleVideoFrame(conn) + go p.handleVideoFrames(conn) } }() @@ -252,7 +259,7 @@ func (p *NativeProxy) toProcessCommand() (*cmdWrapper, error) { return cmd, nil } -func (p *NativeProxy) handleVideoFrame(conn net.Conn) { +func (p *NativeProxy) handleVideoFrames(conn net.Conn) { defer conn.Close() inboundPacket := make([]byte, maxFrameSize) @@ -264,14 +271,14 @@ func (p *NativeProxy) handleVideoFrame(conn net.Conn) { _, err := io.ReadFull(conn, frameSizeBuffer[:]) if err != nil { if err != io.EOF { - p.logger.Warn().Err(err).Msg("failed to read frame size from socket") + GetDisplayLogger().Warn().Err(err).Msg("failed to read frame size from socket") } break } frameSize := binary.LittleEndian.Uint32(frameSizeBuffer[:]) if frameSize == 0 || frameSize > maxFrameSize { - p.logger.Error().Uint32("frameSize", frameSize).Uint32("maxFrameSize", maxFrameSize). + GetDisplayLogger().Error().Uint32("frameSize", frameSize).Uint32("maxFrameSize", maxFrameSize). Msg("received invalid frame size") break } @@ -279,7 +286,7 @@ func (p *NativeProxy) handleVideoFrame(conn net.Conn) { // Read the actual frame data _, err = io.ReadFull(conn, inboundPacket[:frameSize]) if err != nil { - p.logger.Warn().Err(err).Msg("failed to read video frame from socket") + GetDisplayLogger().Warn().Err(err).Msg("failed to read video frame from socket") break } @@ -295,24 +302,22 @@ func (p *NativeProxy) setUpGRPCClient() error { // wait until handshake completed select { case <-p.cmd.stdoutHandler.handshakeCh: - p.logger.Info().Msg("handshake completed") + getClientLogger().Info().Msg("grpc handshake completed") case <-time.After(10 * time.Second): return fmt.Errorf("handshake not completed within 10 seconds") } - logger := p.logger.With().Str("socketPath", "@"+p.nativeUnixSocket).Logger() client, err := NewGRPCClient(grpcClientOptions{ SocketPath: p.nativeUnixSocket, - Logger: &logger, OnIndevEvent: p.options.OnIndevEvent, OnRpcEvent: p.options.OnRpcEvent, OnVideoStateChange: p.options.OnVideoStateChange, }) - logger.Info().Msg("created gRPC client") if err != nil { return fmt.Errorf("failed to create gRPC client: %w", err) } + getClientLogger().Info().Str("socketPath", "@"+p.nativeUnixSocket).Msg("created gRPC client") p.client = client // Wait for ready signal from the native process @@ -353,12 +358,7 @@ func (p *NativeProxy) doStart() error { return fmt.Errorf("failed to start native process: %w", err) } - // here we'll replace the logger with a new one that includes the process ID - // there's no need to lock the mutex here as the side effect is acceptable - newLogger := p.logger.With().Int("pid", p.cmd.Process.Pid).Logger() - p.logger = &newLogger - - p.logger.Info().Msg("native process started") + getClientLogger().Info().Int("pid", p.cmd.Process.Pid).Msg("native process started") if err := p.setUpGRPCClient(); err != nil { return fmt.Errorf("failed to set up gRPC client: %w", err) @@ -400,7 +400,7 @@ func (p *NativeProxy) monitorProcess() { select { case <-p.ctx.Done(): - p.logger.Trace().Msg("context done, stopping monitor process [before wait]") + getClientLogger().Trace().Msg("context done, stopping monitor process [before wait]") return default: } @@ -418,19 +418,20 @@ func (p *NativeProxy) monitorProcess() { select { case <-p.ctx.Done(): - p.logger.Trace().Msg("context done, stopping monitor process [after wait]") + getClientLogger().Trace().Msg("context done, stopping monitor process [after wait]") return default: } - p.logger.Warn().Err(err).Msg("native process exited, restarting ...") + logger := getClientLogger() + logger.Warn().Err(err).Msg("native process exited, restarting ...") // Wait a bit before restarting time.Sleep(1 * time.Second) // Restart the process if err := p.restartProcess(); err != nil { - p.logger.Error().Err(err).Msg("failed to restart native process") + logger.Error().Err(err).Msg("failed to restart native process") // Wait longer before retrying time.Sleep(5 * time.Second) continue @@ -441,7 +442,7 @@ func (p *NativeProxy) monitorProcess() { // restartProcess restarts the native process func (p *NativeProxy) restartProcess() error { p.restarts++ - logger := p.logger.With().Uint("attempt", p.restarts).Uint("maxAttempts", p.options.MaxRestartAttempts).Logger() + logger := getClientLogger().With().Uint("attempt", p.restarts).Uint("maxAttempts", p.options.MaxRestartAttempts).Logger() if p.restarts >= p.options.MaxRestartAttempts { logger.Fatal().Msgf("max restart attempts reached, exiting: %s", supervisor.FailsafeReasonVideoMaxRestartAttemptsReached) diff --git a/internal/native/server.go b/internal/native/server.go index 5d50e6e66..9e34d48e0 100644 --- a/internal/native/server.go +++ b/internal/native/server.go @@ -37,19 +37,27 @@ func setProcTitle(status string) { gspt.SetProcTitle(title) } -func monitorCrashSignal(ctx context.Context, logger *zerolog.Logger, nativeInstance NativeInterface) { - logger.Info().Msg("DEBUG mode: will crash the process on SIGHUP signal") +func getServerLogger() *zerolog.Logger { + logger := GetNativeLogger(). + With(). + Str("subcomponent", "grpc-server"). + Int("pid", pid). + Logger() + return &logger +} +func monitorCrashSignal(ctx context.Context, nativeInstance NativeInterface) { + getServerLogger().Info().Msg("DEBUG mode: will crash the process on SIGHUP signal") sigChan := make(chan os.Signal, 1) signal.Notify(sigChan, syscall.SIGHUP) for { select { case sig := <-sigChan: - logger.Info().Str("signal", sig.String()).Msg("received termination signal") + getServerLogger().Warn().Str("signal", sig.String()).Msg("received termination signal, crashing the process for testing purposes") nativeInstance.DoNotUseThisIsForCrashTestingOnly() case <-ctx.Done(): - logger.Info().Msg("context done, stopping monitor process") + getServerLogger().Info().Msg("context done, stopping monitor process") return } } @@ -79,88 +87,114 @@ func RunNativeProcess(binaryName string) { appCtx, appCtxCancel := context.WithCancel(context.Background()) defer appCtxCancel() - logger := nativeLogger.With().Int("pid", os.Getpid()).Logger() - setProcTitle("starting") - - // Parse native options - var proxyOptions nativeProxyOptions - if err := env.Parse(&proxyOptions); err != nil { - logger.Fatal().Err(err).Msg("failed to parse native proxy options") - } - - // Connect to video stream socket - conn, err := net.Dial("unix", proxyOptions.VideoStreamUnixSocket) - if err != nil { - logger.Fatal().Err(err).Msg("failed to connect to video stream socket") - } - logger.Info().Str("videoStreamSocketPath", proxyOptions.VideoStreamUnixSocket).Msg("connected to video stream socket") + // for defer clean-up scoping... this is NOT a goroutine + func() { + setProcTitle("starting") + logger := getServerLogger().With().Str("binaryName", binaryName).Logger() + logger.Info().Msg("native process starting") - nativeOptions := proxyOptions.toNativeOptions() - nativeOptions.OnVideoFrameReceived = func(frame []byte, duration time.Duration) { - // Write 4-byte frame length prefix, then frame data - var frameSizeBuffer [4]byte - binary.LittleEndian.PutUint32(frameSizeBuffer[:], uint32(len(frame))) + // Parse native options + var proxyOptions nativeProxyOptions + if err := env.Parse(&proxyOptions); err != nil { + logger.Fatal().Err(err).Msg("failed to parse native proxy options") + return + } + socketPath := fmt.Sprintf("@%v", proxyOptions.CtrlUnixSocket) + logger = logger.With().Interface("proxyOptions", proxyOptions).Str("socketPath", socketPath).Logger() - if _, err := conn.Write(frameSizeBuffer[:]); err != nil { - logger.Fatal().Err(err).Msg("failed to write frame size to video stream socket") + // Connect to video stream socket + conn, err := net.Dial("unix", proxyOptions.VideoStreamUnixSocket) + if err != nil { + logger.Fatal().Err(err).Msg("failed to connect to video stream socket") + return } - if _, err := conn.Write(frame); err != nil { - logger.Fatal().Err(err).Msg("failed to write frame to video stream socket") + defer conn.Close() + logger = logger.With().Interface("local", conn.LocalAddr()).Interface("remote", conn.RemoteAddr()).Logger() + logger.Info().Msg("connected to video stream socket") + + nativeOptions := proxyOptions.toNativeOptions() + nativeOptions.OnVideoFrameReceived = func(frame []byte, duration time.Duration) { + // Write 4-byte frame length prefix, then frame data + var frameSizeBuffer [4]byte + binary.LittleEndian.PutUint32(frameSizeBuffer[:], uint32(len(frame))) + + if _, err := conn.Write(frameSizeBuffer[:]); err != nil { + logger.Fatal().Err(err).Msg("failed to write frame size to video stream socket") + return + } + if _, err := conn.Write(frame); err != nil { + logger.Fatal().Err(err).Msg("failed to write frame to video stream socket") + return + } + } + nativeOptions.OnVideoStateChange = func(state VideoState) { + updateProcessTitle(&state) } - } - nativeOptions.OnVideoStateChange = func(state VideoState) { - updateProcessTitle(&state) - } - // Create native instance - nativeInstance := NewNative(*nativeOptions) - gspt.SetProcTitle("jetkvm: [native] initializing") + setProcTitle("initializing") + logger.Info().Msg("starting native instance") - // Start native instance - if err := nativeInstance.Start(); err != nil { - logger.Fatal().Err(err).Msg("failed to start native instance") - } + // Create and start native instance + nativeInstance := NewNative(*nativeOptions) + if err := nativeInstance.Start(); err != nil { + logger.Fatal().Err(err).Msg("failed to start native instance") + return + } - grpcLogger := logger.With().Str("socketPath", fmt.Sprintf("@%v", proxyOptions.CtrlUnixSocket)).Logger() - setProcTitle("starting gRPC server") - // Create gRPC server - grpcServer := NewGRPCServer(nativeInstance, &grpcLogger) + setProcTitle("starting gRPC server") + logger.Info().Msg("starting gRPC server") - logger.Info().Msg("starting gRPC server") - // Start gRPC server - server, lis, err := StartGRPCServer(grpcServer, fmt.Sprintf("@%v", proxyOptions.CtrlUnixSocket), &logger) - if err != nil { - logger.Fatal().Err(err).Msg("failed to start gRPC server") - } - setProcTitle("ready") + // Create and Start gRPC server + grpcServer := NewGRPCServer(nativeInstance, socketPath) + server, lis, err := grpcServer.Start() - if _, err := os.Stat(DebugModeFile); err == nil { - logger.Info().Msg("DEBUG mode: enabled") - go monitorCrashSignal(appCtx, &logger, nativeInstance) - } + if err != nil { + logger.Fatal().Err(err).Msg("failed to start gRPC server") + return + } - // Signal that we're ready by writing handshake message to stdout (for parent to read) - // Stdout.Write is used to avoid buffering the message - _, err = os.Stdout.Write([]byte(proxyOptions.HandshakeMessage + "\n")) - if err != nil { - logger.Fatal().Err(err).Msg("failed to write handshake message to stdout") - } + defer lis.Close() // close listener after + defer server.Stop() // forceful server stop - // Set up signal handling - sigChan := make(chan os.Signal, 1) - signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT) + logger = logger.With().Interface("listener", lis).Logger() - // Wait for signal - sig := <-sigChan - logger.Info(). - Str("signal", sig.String()). - Msg("received termination signal") + setProcTitle("ready") - // Graceful shutdown might stuck forever, - // we will use Stop() instead to force quit the gRPC server, - // we can implement a graceful shutdown with a timeout in the future if needed - server.Stop() - lis.Close() + if _, err := os.Stat(DebugModeFile); err == nil { + logger.Info().Msg("DEBUG mode: enabled") + go monitorCrashSignal(appCtx, nativeInstance) + } - logger.Info().Msg("native process exiting") + // Signal that we're ready by writing handshake message to stdout (for parent to read) + // Stdout.Write is used to avoid buffering the message + _, err = os.Stdout.Write([]byte(proxyOptions.HandshakeMessage + "\n")) + if err != nil { + logger.Fatal().Err(err).Msg("failed to write handshake message to stdout") + return + } + logger.Debug().Msg("wrote handshake message for supervisor") + + // Set up signal handling + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT) + + // Wait for signal + sig := <-sigChan + logger.Info().Str("signal", sig.String()).Msg("received termination signal") + + // Graceful shutdown might take a long time so use a 2 second timeout + shutdownCtx, shutdownCancel := context.WithTimeout(appCtx, 2*time.Second) + defer shutdownCancel() + + // Stop gRPC server + setProcTitle("shutting down gRPC server") + logger.Info().Msg("shutting down gRPC server") + go func() { + server.GracefulStop() + }() + + // Wait for shutdown or timeout (and then the defer will force stop) + <-shutdownCtx.Done() + logger.Info().Msg("native process exiting") + }() } diff --git a/internal/native/video.go b/internal/native/video.go index 176511c69..f5bac68b3 100644 --- a/internal/native/video.go +++ b/internal/native/video.go @@ -53,6 +53,8 @@ func (n *Native) setSleepMode(enabled bool) error { return nil } + logger := GetDisplayLogger().With().Bool("enabled", enabled).Logger() + bEnabled := "0" shouldWait := false if enabled { @@ -60,11 +62,11 @@ func (n *Native) setSleepMode(enabled bool) error { switch videoGetStreamingStatus() { case VideoStreamingStatusActive: - n.l.Info().Msg("stopping video stream to enable sleep mode") + logger.Info().Msg("stopping video stream to enable sleep mode") videoStop() shouldWait = true case VideoStreamingStatusStopping: - n.l.Info().Msg("video stream is stopping, will enable sleep mode in a few seconds") + logger.Info().Msg("video stream is stopping, will enable sleep mode in a few seconds") shouldWait = true } } diff --git a/internal/network/types/config.go b/internal/network/types/config.go index 33afbcc7d..ec0fc0785 100644 --- a/internal/network/types/config.go +++ b/internal/network/types/config.go @@ -89,6 +89,6 @@ func (c *NetworkConfig) GetTransportProxyFunc() func(*http.Request) (*url.URL, e // NetworkConfig interface for backward compatibility type NetworkConfigInterface interface { InterfaceName() string - IPv4Addresses() []IPAddress - IPv6Addresses() []IPAddress + IPv4Addresses() IPAddresses + IPv6Addresses() IPAddresses } diff --git a/internal/network/types/dhcp.go b/internal/network/types/dhcp.go index ff34e2f29..f6abf427f 100644 --- a/internal/network/types/dhcp.go +++ b/internal/network/types/dhcp.go @@ -3,6 +3,8 @@ package types import ( "net" "time" + + "github.com/rs/zerolog" ) // DHCPClient is the interface for a DHCP client. @@ -34,15 +36,15 @@ type DHCPLease struct { BootPServerName string `env:"sname" json:"bootp_server_name,omitempty"` // The bootp server name option BootPFile string `env:"boot_file" json:"bootp_file,omitempty"` // The bootp boot file option Timezone string `env:"timezone" json:"timezone,omitempty"` // Offset in seconds from UTC - Routers []net.IP `env:"router" json:"routers,omitempty"` // A list of routers - DNS []net.IP `env:"dns" json:"dns_servers,omitempty"` // A list of DNS servers - NTPServers []net.IP `env:"ntpsrv" json:"ntp_servers,omitempty"` // A list of NTP servers - LPRServers []net.IP `env:"lprsvr" json:"lpr_servers,omitempty"` // A list of LPR servers - TimeServers []net.IP `env:"timesvr" json:"_time_servers,omitempty"` // A list of time servers (obsolete) - IEN116NameServers []net.IP `env:"namesvr" json:"_name_servers,omitempty"` // A list of IEN 116 name servers (obsolete) - LogServers []net.IP `env:"logsvr" json:"_log_servers,omitempty"` // A list of MIT-LCS UDP log servers (obsolete) - CookieServers []net.IP `env:"cookiesvr" json:"_cookie_servers,omitempty"` // A list of RFC 865 cookie servers (obsolete) - WINSServers []net.IP `env:"wins" json:"_wins_servers,omitempty"` // A list of WINS servers + Routers IPs `env:"router" json:"routers,omitempty"` // A list of routers + DNS IPs `env:"dns" json:"dns_servers,omitempty"` // A list of DNS servers + NTPServers IPs `env:"ntpsrv" json:"ntp_servers,omitempty"` // A list of NTP servers + LPRServers IPs `env:"lprsvr" json:"lpr_servers,omitempty"` // A list of LPR servers + TimeServers IPs `env:"timesvr" json:"_time_servers,omitempty"` // A list of time servers (obsolete) + IEN116NameServers IPs `env:"namesvr" json:"_name_servers,omitempty"` // A list of IEN 116 name servers (obsolete) + LogServers IPs `env:"logsvr" json:"_log_servers,omitempty"` // A list of MIT-LCS UDP log servers (obsolete) + CookieServers IPs `env:"cookiesvr" json:"_cookie_servers,omitempty"` // A list of RFC 865 cookie servers (obsolete) + WINSServers IPs `env:"wins" json:"_wins_servers,omitempty"` // A list of WINS servers SwapServer net.IP `env:"swapsvr" json:"_swap_server,omitempty"` // The IP address of the client's swap server BootSize int `env:"bootsize" json:"bootsize,omitempty"` // The length in 512 octect blocks of the bootfile RootPath string `env:"rootpath" json:"root_path,omitempty"` // The path name of the client's root disk @@ -62,6 +64,49 @@ type DHCPLease struct { DHCPClient string `json:"dhcp_client,omitempty"` // The DHCP client that obtained the lease } +func (d DHCPLease) MarshalZerologObject(e *zerolog.Event) { + e.IPAddr("ip_address", d.IPAddress) + e.IPAddr("Netmask", d.Netmask) + e.IPAddr("broadcast", d.Broadcast) + e.Int("ttl", d.TTL) + e.Int("mtu", d.MTU) + e.Str("hostname", d.HostName) + e.Str("domain", d.Domain) + e.Strs("search_list", d.SearchList) + e.IPAddr("bootp_next_server", d.BootPNextServer) + e.Str("bootp_server_name", d.BootPServerName) + e.Str("bootp_file", d.BootPFile) + e.Str("timezone", d.Timezone) + //TODO IPAddrs + e.Array("routers", d.Routers) + e.Array("dns_servers", d.DNS) + e.Array("ntp_servers", d.NTPServers) + e.Array("lpr_servers", d.LPRServers) + e.Array("time_servers", d.TimeServers) + e.Array("name_servers", d.IEN116NameServers) + e.Array("log_servers", d.LogServers) + e.Array("cookie_servers", d.CookieServers) + e.Array("wins_servers", d.WINSServers) + e.IPAddr("swap_servers", d.SwapServer) + e.Int("bootsize", d.BootSize) + e.Str("root_path", d.RootPath) + e.Dur("lease", d.LeaseTime) + e.Dur("renewal", d.RenewalTime) + e.Dur("rebinding", d.RebindingTime) + e.Str("dhcp_type", d.DHCPType) + e.Str("server_id", d.ServerID) + e.Str("reason", d.Message) + e.Str("tftp", d.TFTPServerName) + e.Str("bootfile", d.BootFileName) + e.Dur("uptime", d.Uptime) + e.Str("class_identifier", d.ClassIdentifier) + if d.LeaseExpiry != nil { + e.Time("lease_expiry", *d.LeaseExpiry) + } + e.Str("interface_name", d.InterfaceName) + e.Str("dhcp_client", d.DHCPClient) +} + // IsIPv6 returns true if the DHCP lease is for an IPv6 address func (d *DHCPLease) IsIPv6() bool { return d.IPAddress.To4() == nil @@ -69,19 +114,21 @@ func (d *DHCPLease) IsIPv6() bool { // IPMask returns the IP mask for the DHCP lease func (d *DHCPLease) IPMask() net.IPMask { - if d.IsIPv6() { - // TODO: not implemented + if d.Netmask == nil { return nil } + if d.IsIPv6() { + return net.IPMask(d.Netmask.To16()) + } + mask := net.ParseIP(d.Netmask.String()) return net.IPv4Mask(mask[12], mask[13], mask[14], mask[15]) } // IPNet returns the IP net for the DHCP lease func (d *DHCPLease) IPNet() *net.IPNet { - if d.IsIPv6() { - // TODO: not implemented + if d.IPAddress == nil || d.Netmask == nil { return nil } diff --git a/internal/network/types/interface.go b/internal/network/types/interface.go index cea9620ba..e30a18b01 100644 --- a/internal/network/types/interface.go +++ b/internal/network/types/interface.go @@ -1,9 +1,9 @@ package types import ( - "net" "time" + "github.com/rs/zerolog" "golang.org/x/sys/unix" ) @@ -21,17 +21,47 @@ type InterfaceState struct { IPv6LinkLocal string `json:"ipv6_link_local,omitempty"` IPv6Gateway string `json:"ipv6_gateway,omitempty"` IPv4Addresses []string `json:"ipv4_addresses,omitempty"` - IPv6Addresses []IPv6Address `json:"ipv6_addresses,omitempty"` - NTPServers []net.IP `json:"ntp_servers,omitempty"` + IPv6Addresses IPv6Addresses `json:"ipv6_addresses,omitempty"` + NTPServers IPs `json:"ntp_servers,omitempty"` DHCPLease4 *DHCPLease `json:"dhcp_lease,omitempty"` DHCPLease6 *DHCPLease `json:"dhcp_lease6,omitempty"` LastUpdated time.Time `json:"last_updated"` } +func (a InterfaceState) MarshalZerologObject(e *zerolog.Event) { + e.Str("interface_name", a.InterfaceName) + e.Str("hostname", a.Hostname) + e.Str("mac", a.MACAddress) + e.Bool("up", a.Up) + e.Bool("online", a.Online) + e.Bool("ipv4_ready", a.IPv4Ready) + e.Bool("ipv6_ready", a.IPv6Ready) + e.Str("ipv4_address", a.IPv4Address) + e.Str("ipv6_address", a.IPv6Address) + e.Str("ipv6_link_local", a.IPv6LinkLocal) + e.Str("ipv6_gateway", a.IPv6Gateway) + e.Strs("ipv4_addresses", a.IPv4Addresses) + e.Array("ipv6_addresses", a.IPv6Addresses) + e.Array("ntp_servers", a.NTPServers) + // DHCP leases can be nil + if a.DHCPLease4 != nil { + e.Object("dhcp_lease", a.DHCPLease4) + } + if a.DHCPLease6 != nil { + e.Object("dhcp_lease6", a.DHCPLease6) + } + e.Time("last_updated", a.LastUpdated) +} + // RpcInterfaceState is the RPC representation of an interface state type RpcInterfaceState struct { InterfaceState - IPv6Addresses []RpcIPv6Address `json:"ipv6_addresses"` + IPv6Addresses RpcIPv6Addresses `json:"ipv6_addresses"` +} + +func (s RpcInterfaceState) MarshalZerologObject(e *zerolog.Event) { + e.Object("interface_state", s.InterfaceState) + e.Array("ipv6_addresses", s.IPv6Addresses) } // ToRpcInterfaceState converts an InterfaceState to a RpcInterfaceState diff --git a/internal/network/types/ip.go b/internal/network/types/ip.go index b0a07bb36..98aa557fd 100644 --- a/internal/network/types/ip.go +++ b/internal/network/types/ip.go @@ -5,6 +5,7 @@ import ( "slices" "time" + "github.com/rs/zerolog" "github.com/vishvananda/netlink" ) @@ -18,6 +19,23 @@ type IPAddress struct { Permanent bool } +func (ip IPAddress) MarshalZerologObject(e *zerolog.Event) { + e.Int("family", ip.Family) + e.IPPrefix("address", ip.Address) + e.IPAddr("gateway", ip.Gateway) + e.Int("mtu", int(ip.MTU)) + e.Bool("secondary", ip.Secondary) + e.Bool("permanent", ip.Permanent) +} + +type IPAddresses []IPAddress + +func (addrs IPAddresses) MarshalZerologArray(e *zerolog.Array) { + for _, addr := range addrs { + e.Object(&addr) + } +} + func (a *IPAddress) String() string { return a.Address.String() } @@ -46,16 +64,33 @@ func (a *IPAddress) DefaultRoute(linkIndex int) netlink.Route { } } +type IPs []net.IP + +func (addrs IPs) MarshalZerologArray(e *zerolog.Array) { + for _, addr := range addrs { + e.IPAddr(addr) + } +} + // ParsedIPConfig represents the parsed IP configuration type ParsedIPConfig struct { - Addresses []IPAddress - Nameservers []net.IP + Addresses IPAddresses + Nameservers IPs SearchList []string Domain string MTU int Interface string } +func (a ParsedIPConfig) MarshalZerologObject(e *zerolog.Event) { + e.Array("addresses", a.Addresses) + e.Array("nameservers", a.Nameservers) + e.Strs("search_list", a.SearchList) + e.Str("domain", a.Domain) + e.Int("mtu", a.MTU) + e.Str("interface", a.Interface) +} + // IPv6Address represents an IPv6 address with lifetime information type IPv6Address struct { Address net.IP `json:"address"` @@ -66,6 +101,27 @@ type IPv6Address struct { Scope int `json:"scope"` } +func (a IPv6Address) MarshalZerologObject(e *zerolog.Event) { + e.IPAddr("address", a.Address) + e.IPPrefix("prefix", a.Prefix) + if a.ValidLifetime != nil { + e.Time("valid_lifetime", *a.ValidLifetime) + } + if a.PreferredLifetime != nil { + e.Time("preferred_lifetime", *a.PreferredLifetime) + } + e.Int("flags", a.Flags) + e.Int("scope", a.Scope) +} + +type IPv6Addresses []IPv6Address + +func (addrs IPv6Addresses) MarshalZerologArray(e *zerolog.Array) { + for _, addr := range addrs { + e.Object(addr) + } +} + // RpcIPv6Address is the RPC representation of an IPv6 address type RpcIPv6Address struct { Address string `json:"address"` @@ -83,3 +139,32 @@ type RpcIPv6Address struct { FlagDADFailed bool `json:"flag_dad_failed"` FlagTentative bool `json:"flag_tentative"` } + +func (a RpcIPv6Address) MarshalZerologObject(e *zerolog.Event) { + e.Str("address", a.Address) + e.Str("prefix", a.Prefix) + if a.ValidLifetime != nil { + e.Time("valid_lifetime", *a.ValidLifetime) + } + if a.PreferredLifetime != nil { + e.Time("preferred_lifetime", *a.PreferredLifetime) + } + e.Int("scope", a.Scope) + e.Int("flags", a.Flags) + e.Bool("flag_secondary", a.FlagSecondary) + e.Bool("flag_permanent", a.FlagPermanent) + e.Bool("flag_temporary", a.FlagTemporary) + e.Bool("flag_stable_privacy", a.FlagStablePrivacy) + e.Bool("flag_deprecated", a.FlagDeprecated) + e.Bool("flag_optimistic", a.FlagOptimistic) + e.Bool("flag_dad_failed", a.FlagDADFailed) + e.Bool("flag_tentative", a.FlagTentative) +} + +type RpcIPv6Addresses []RpcIPv6Address + +func (addrs RpcIPv6Addresses) MarshalZerologArray(e *zerolog.Array) { + for _, addr := range addrs { + e.Object(addr) + } +} diff --git a/internal/ota/app.go b/internal/ota/app.go index 55caa8e8a..37c6b2b30 100644 --- a/internal/ota/app.go +++ b/internal/ota/app.go @@ -12,10 +12,10 @@ const ( // DO NOT call it directly, it's not thread safe // Mutex is currently held by the caller, e.g. doUpdate func (s *State) updateApp(ctx context.Context, appUpdate *componentUpdateStatus) error { - l := s.l.With().Str("path", appUpdatePath).Logger() + logger := GetOtaLogger().With().Str("path", appUpdatePath).Logger() - if err := s.downloadFile(ctx, appUpdatePath, appUpdate.url, "app"); err != nil { - return s.componentUpdateError("Error downloading app update", err, &l) + if err := s.downloadFile(ctx, appUpdatePath, appUpdate.url, "app", &logger); err != nil { + return s.componentUpdateError("Error downloading app update", err, &logger) } downloadFinished := time.Now() @@ -27,8 +27,9 @@ func (s *State) updateApp(ctx context.Context, appUpdate *componentUpdateStatus) appUpdatePath, appUpdate.hash, &appUpdate.verificationProgress, + &logger, ); err != nil { - return s.componentUpdateError("Error verifying app update hash", err, &l) + return s.componentUpdateError("Error verifying app update hash", err, &logger) } verifyFinished := time.Now() appUpdate.verifiedAt = verifyFinished @@ -37,9 +38,8 @@ func (s *State) updateApp(ctx context.Context, appUpdate *componentUpdateStatus) appUpdate.updateProgress = 1 s.triggerComponentUpdateState("app", appUpdate) - l.Info().Msg("App update downloaded") + logger.Info().Msg("App update downloaded") s.rebootNeeded = true - return nil } diff --git a/internal/ota/errors.go b/internal/ota/errors.go index a1d0b4c5d..7d69b7d04 100644 --- a/internal/ota/errors.go +++ b/internal/ota/errors.go @@ -12,11 +12,8 @@ var ( ErrVersionNotFound = errors.New("specified version not found") ) -func (s *State) componentUpdateError(prefix string, err error, l *zerolog.Logger) error { - if l == nil { - l = s.l - } - l.Error().Err(err).Msg(prefix) +func (s *State) componentUpdateError(prefix string, err error, logger *zerolog.Logger) error { + logger.Error().Err(err).Msg(prefix) s.error = fmt.Sprintf("%s: %v", prefix, err) s.updating = false s.triggerStateUpdate() diff --git a/internal/ota/log.go b/internal/ota/log.go new file mode 100644 index 000000000..b19550752 --- /dev/null +++ b/internal/ota/log.go @@ -0,0 +1,10 @@ +package ota + +import ( + "github.com/jetkvm/kvm/internal/logging" + "github.com/rs/zerolog" +) + +func GetOtaLogger() *zerolog.Logger { + return logging.GetSubsystemLogger("ota") +} diff --git a/internal/ota/ota.go b/internal/ota/ota.go index 52b38c7d0..221938653 100644 --- a/internal/ota/ota.go +++ b/internal/ota/ota.go @@ -12,6 +12,7 @@ import ( "net/url" "time" + "github.com/jetkvm/kvm/internal/logging" "github.com/rs/zerolog" ) @@ -60,31 +61,43 @@ func (s *State) getUpdateURL(params UpdateParams) (string, error, bool) { // newHTTPRequestWithTrace creates a new HTTP request with a trace logger // TODO: use OTEL instead of doing this manually -func (s *State) newHTTPRequestWithTrace(ctx context.Context, method, url string, body io.Reader, logger func() *zerolog.Event) (*http.Request, error) { +func (s *State) newHTTPRequestWithTrace(ctx context.Context, method, url string, body io.Reader, logger *zerolog.Logger) (*http.Request, error) { localCtx := ctx - if s.l.GetLevel() <= zerolog.TraceLevel { - if logger == nil { - logger = func() *zerolog.Event { return s.l.Trace() } - } - - l := func() *zerolog.Event { return logger().Str("url", url).Str("method", method) } + if logging.IsTraceLevel(logger) { + trace := func() *zerolog.Event { return logger.Trace().Str("url", url).Str("method", method) } localCtx = httptrace.WithClientTrace(localCtx, &httptrace.ClientTrace{ - GetConn: func(hostPort string) { l().Str("hostPort", hostPort).Msg("[conn] starting to create conn") }, - GotConn: func(info httptrace.GotConnInfo) { l().Interface("info", info).Msg("[conn] connection established") }, - PutIdleConn: func(err error) { l().Err(err).Msg("[conn] connection returned to idle pool") }, - GotFirstResponseByte: func() { l().Msg("[resp] first response byte received") }, - Got100Continue: func() { l().Msg("[resp] 100 continue received") }, - DNSStart: func(info httptrace.DNSStartInfo) { l().Interface("info", info).Msg("[dns] starting to look up dns") }, - DNSDone: func(info httptrace.DNSDoneInfo) { l().Interface("info", info).Msg("[dns] done looking up dns") }, + GetConn: func(hostPort string) { + trace().Str("hostPort", hostPort).Msg("[conn] starting to create conn") + }, + GotConn: func(info httptrace.GotConnInfo) { + trace().Interface("info", info).Msg("[conn] connection established") + }, + PutIdleConn: func(err error) { + trace().Err(err).Msg("[conn] connection returned to idle pool") + }, + GotFirstResponseByte: func() { + trace().Msg("[resp] first response byte received") + }, + Got100Continue: func() { + trace().Msg("[resp] 100 continue received") + }, + DNSStart: func(info httptrace.DNSStartInfo) { + trace().Interface("info", info).Msg("[dns] starting to look up dns") + }, + DNSDone: func(info httptrace.DNSDoneInfo) { + trace().Interface("info", info).Msg("[dns] done looking up dns") + }, ConnectStart: func(network, addr string) { - l().Str("network", network).Str("addr", addr).Msg("[tcp] starting tcp connection") + trace().Str("network", network).Str("addr", addr).Msg("[tcp] starting tcp connection") }, ConnectDone: func(network, addr string, err error) { - l().Str("network", network).Str("addr", addr).Err(err).Msg("[tcp] tcp connection created") + trace().Str("network", network).Str("addr", addr).Err(err).Msg("[tcp] tcp connection created") + }, + TLSHandshakeStart: func() { + trace().Msg("[tls] handshake started") }, - TLSHandshakeStart: func() { l().Msg("[tls] handshake started") }, TLSHandshakeDone: func(state tls.ConnectionState, err error) { - l(). + trace(). Str("tlsVersion", tls.VersionName(state.Version)). Str("cipherSuite", tls.CipherSuiteName(state.CipherSuite)). Str("negotiatedProtocol", state.NegotiatedProtocol). @@ -100,27 +113,20 @@ func (s *State) newHTTPRequestWithTrace(ctx context.Context, method, url string, func (s *State) fetchUpdateMetadata(ctx context.Context, params UpdateParams) (*UpdateMetadata, error) { metadata := &UpdateMetadata{} - logger := s.l.With().Logger() - if params.RequestID != "" { - logger = logger.With().Str("requestID", params.RequestID).Logger() - } + logger := GetOtaLogger().With().Interface("params", params).Logger() t := time.Now() traceLogger := func() *zerolog.Event { return logger.Trace().Dur("duration", time.Since(t)) } url, err, isCustomVersion := s.getUpdateURL(params) - traceLogger().Err(err). - Msg("fetchUpdateMetadata: getUpdateURL") if err != nil { + traceLogger().Err(err).Msg("fetchUpdateMetadata: getUpdateURL") return nil, fmt.Errorf("error getting update URL: %w", err) } - traceLogger(). - Str("url", url). - Msg("fetching update metadata") - - req, err := s.newHTTPRequestWithTrace(ctx, "GET", url, nil, traceLogger) + traceLogger().Str("url", url).Bool("custom_version", isCustomVersion).Msg("fetching update metadata") + req, err := s.newHTTPRequestWithTrace(ctx, "GET", url, nil, &logger) if err != nil { return nil, fmt.Errorf("error creating request: %w", err) } @@ -133,9 +139,7 @@ func (s *State) fetchUpdateMetadata(ctx context.Context, params UpdateParams) (* } defer resp.Body.Close() - traceLogger(). - Int("status", resp.StatusCode). - Msg("fetchUpdateMetadata: response") + traceLogger().Int("status", resp.StatusCode).Msg("fetchUpdateMetadata: response") if isCustomVersion && resp.StatusCode == http.StatusNotFound { return nil, ErrVersionNotFound @@ -150,9 +154,7 @@ func (s *State) fetchUpdateMetadata(ctx context.Context, params UpdateParams) (* return nil, fmt.Errorf("error decoding response: %w", err) } - traceLogger(). - Msg("fetchUpdateMetadata: completed") - + traceLogger().Msg("fetchUpdateMetadata: completed") return metadata, nil } @@ -173,19 +175,16 @@ func (s *State) TryUpdate(ctx context.Context, params UpdateParams) error { return fmt.Errorf("update already in progress") } + defer s.mu.Unlock() return s.doUpdate(ctx, params) } // before calling doUpdate, the caller must have locked the mutex // otherwise a runtime error will occur func (s *State) doUpdate(ctx context.Context, params UpdateParams) error { - defer s.mu.Unlock() - - scopedLogger := s.l.With(). - Interface("params", params). - Logger() + logger := GetOtaLogger().With().Interface("params", params).Logger() + logger.Info().Msg("checking for updates") - scopedLogger.Info().Msg("checking for updates") if s.updating { return fmt.Errorf("update already in progress") } @@ -204,13 +203,13 @@ func (s *State) doUpdate(ctx context.Context, params UpdateParams) error { return s.componentUpdateError( "Update aborted: no components were specified to update. Requested components: ", fmt.Errorf("%v", params.Components), - &scopedLogger, + &logger, ) } appUpdate, systemUpdate, err := s.getUpdateStatus(ctx, params) if err != nil { - return s.componentUpdateError("Error checking for updates", err, &scopedLogger) + return s.componentUpdateError("Error checking for updates", err, &logger) } s.metadataFetchedAt = time.Now() @@ -229,54 +228,66 @@ func (s *State) doUpdate(ctx context.Context, params UpdateParams) error { } if !appUpdate.pending && !systemUpdate.pending { - scopedLogger.Info().Msg("No updates available") + logger.Info().Msg("No updates available") s.updating = false s.triggerStateUpdate() return nil } - scopedLogger.Trace().Bool("pending", appUpdate.pending).Msg("Checking for app update") + logger.Trace().Bool("pending", appUpdate.pending).Msg("Checking for app update") if appUpdate.pending { - scopedLogger.Info(). + appLogger := logger. + With(). + Str("subcomponent", "app"). Str("url", appUpdate.url). Str("hash", appUpdate.hash). - Msg("App update available") + Logger() + + appLogger.Info().Msg("App update available") if err := s.updateApp(ctx, appUpdate); err != nil { - return s.componentUpdateError("Error updating app", err, &scopedLogger) + return s.componentUpdateError("Error updating app", err, &appLogger) } } else { - scopedLogger.Info().Msg("App is up to date") + logger.Info().Msg("App is up to date") } - scopedLogger.Trace().Bool("pending", systemUpdate.pending).Msg("Checking for system update") + logger.Trace().Bool("pending", systemUpdate.pending).Msg("Checking for system update") if systemUpdate.pending { + systemLogger := logger. + With(). + Str("subcomponent", "system"). + Str("url", systemUpdate.url). + Str("hash", systemUpdate.hash). + Logger() + + systemLogger.Info().Msg("System update available") + if err := s.updateSystem(ctx, systemUpdate); err != nil { - return s.componentUpdateError("Error updating system", err, &scopedLogger) + return s.componentUpdateError("Error updating system", err, &systemLogger) } } else { - scopedLogger.Info().Msg("System is up to date") + logger.Info().Msg("System is up to date") } if s.rebootNeeded { if appUpdate.customVersionUpdate || systemUpdate.customVersionUpdate { - scopedLogger.Info().Msg("disabling auto-update due to custom version update") + logger.Info().Msg("disabling auto-update due to custom version update") // If they are explicitly updating a custom version, we assume they want to disable auto-update if _, err := s.setAutoUpdate(false); err != nil { - scopedLogger.Warn().Err(err).Msg("Failed to disable auto-update") + logger.Warn().Err(err).Msg("Failed to disable auto-update") } } - scopedLogger.Info().Msg("System Rebooting due to OTA update") - + logger.Info().Msg("System Rebooting due to OTA update") redirectUrl := "/settings/general/update" if params.ResetConfig { - scopedLogger.Info().Msg("Resetting config") + logger.Warn().Msg("Resetting config") if err := s.resetConfig(); err != nil { - return s.componentUpdateError("Error resetting config", err, &scopedLogger) + return s.componentUpdateError("Error resetting config", err, &logger) } redirectUrl = "/welcome" } @@ -290,7 +301,7 @@ func (s *State) doUpdate(ctx context.Context, params UpdateParams) error { // it means that healthCheckUrl will be called after 7 seconds that we send willReboot JSONRPC event // so we need to reboot it within 7 seconds to avoid it being called before the device is rebooted if err := s.reboot(true, postRebootAction, 5*time.Second); err != nil { - return s.componentUpdateError("Error requesting reboot", err, &scopedLogger) + return s.componentUpdateError("Error requesting reboot", err, &logger) } } @@ -351,6 +362,7 @@ func (s *State) checkUpdateStatus( systemUpdateStatus *componentUpdateStatus, ) error { // get the local versions + t := time.Now() systemVersionLocal, appVersionLocal, err := s.getLocalVersion() if err != nil { return fmt.Errorf("error getting local version: %w", err) @@ -358,22 +370,19 @@ func (s *State) checkUpdateStatus( appUpdateStatus.localVersion = appVersionLocal.String() systemUpdateStatus.localVersion = systemVersionLocal.String() - logger := s.l.With().Logger() - if params.RequestID != "" { - logger = logger.With().Str("requestID", params.RequestID).Logger() - } - t := time.Now() + logger := GetOtaLogger(). + With(). + Interface("params", params). + Stringer("appVersionLocal", appVersionLocal). + Stringer("systemVersionLocal", systemVersionLocal). + Logger() - logger.Trace(). - Str("appVersionLocal", appVersionLocal.String()). - Str("systemVersionLocal", systemVersionLocal.String()). - Dur("duration", time.Since(t)). - Msg("checkUpdateStatus: getLocalVersion") + logger.Trace().Dur("duration", time.Since(t)).Msg("checkUpdateStatus: getLocalVersion") // fetch the remote metadata remoteMetadata, err := s.fetchUpdateMetadata(ctx, params) if err != nil { - if err == ErrVersionNotFound || errors.Unwrap(err) == ErrVersionNotFound { + if errors.Is(err, ErrVersionNotFound) { err = ErrVersionNotFound } else { err = fmt.Errorf("error checking for updates: %w", err) @@ -395,6 +404,9 @@ func (s *State) checkUpdateStatus( ); err != nil { return fmt.Errorf("error parsing remote app version: %w", err) } + logger.Trace().Str("subcomponent", "app"). + Interface("componentUpdateStatus", appUpdateStatus). + Msg("checkUpdateStatus: remoteMetadataToComponentStatus") if err := remoteMetadataToComponentStatus( remoteMetadata, @@ -404,11 +416,10 @@ func (s *State) checkUpdateStatus( ); err != nil { return fmt.Errorf("error parsing remote system version: %w", err) } - - if s.l.GetLevel() <= zerolog.TraceLevel { - appUpdateStatus.getZerologLogger(&logger).Trace().Msg("checkUpdateStatus: remoteMetadataToComponentStatus [app]") - systemUpdateStatus.getZerologLogger(&logger).Trace().Msg("checkUpdateStatus: remoteMetadataToComponentStatus [system]") - } + logger.Trace(). + Str("subcomponent", "system"). + Interface("componentUpdateStatus", systemUpdateStatus). + Msg("checkUpdateStatus: remoteMetadataToComponentStatus") logger.Trace(). Dur("duration", time.Since(t)). diff --git a/internal/ota/ota_test.go b/internal/ota/ota_test.go index 2c8ce661d..5f9004355 100644 --- a/internal/ota/ota_test.go +++ b/internal/ota/ota_test.go @@ -11,14 +11,12 @@ import ( "io" "net/http" "net/url" - "os" "path/filepath" "testing" "time" "github.com/Masterminds/semver/v3" "github.com/gwatts/rootcerts" - "github.com/rs/zerolog" "github.com/stretchr/testify/assert" ) @@ -173,15 +171,8 @@ func newOtaState(d *testData, t *testing.T) *State { return systemVersion, appVersion, nil } - traceLevel := zerolog.InfoLevel - - if os.Getenv("TEST_LOG_TRACE") == "1" { - traceLevel = zerolog.TraceLevel - } - logger := zerolog.New(os.Stdout).Level(traceLevel) otaState := NewState(Options{ SkipConfirmSystem: true, - Logger: &logger, ReleaseAPIEndpoint: releaseAPIEndpoint, GetHTTPClient: func() HttpClient { if d.RemoteMetadata != nil { diff --git a/internal/ota/state.go b/internal/ota/state.go index 2bb7055ee..915445e26 100644 --- a/internal/ota/state.go +++ b/internal/ota/state.go @@ -5,7 +5,6 @@ import ( "time" "github.com/Masterminds/semver/v3" - "github.com/rs/zerolog" ) var ( @@ -71,27 +70,6 @@ type componentUpdateStatus struct { verifiedAt time.Time updateProgress float32 updatedAt time.Time - dependsOn []string -} - -func (c *componentUpdateStatus) getZerologLogger(l *zerolog.Logger) *zerolog.Logger { - logger := l.With(). - Bool("pending", c.pending). - Bool("available", c.available). - Str("availableReason", c.availableReason). - Str("version", c.version). - Str("localVersion", c.localVersion). - Str("url", c.url). - Str("hash", c.hash). - Float32("downloadProgress", c.downloadProgress). - Time("downloadFinishedAt", c.downloadFinishedAt). - Float32("verificationProgress", c.verificationProgress). - Time("verifiedAt", c.verifiedAt). - Float32("updateProgress", c.updateProgress). - Time("updatedAt", c.updatedAt). - Strs("dependsOn", c.dependsOn). - Logger() - return &logger } // HwRebootFunc is a function that reboots the hardware @@ -118,7 +96,6 @@ type GetLocalVersionFunc func() (systemVersion *semver.Version, appVersion *semv // State represents the current OTA state for the UI type State struct { releaseAPIEndpoint string - l *zerolog.Logger mu sync.Mutex updating bool error string @@ -178,7 +155,6 @@ func (s *State) IsUpdatePending() bool { // Options represents the options for the OTA state type Options struct { - Logger *zerolog.Logger GetHTTPClient GetHTTPClientFunc GetLocalVersion GetLocalVersionFunc OnStateUpdate OnStateUpdateFunc @@ -198,7 +174,6 @@ func NewState(opts Options) *State { } s := &State{ - l: opts.Logger, client: opts.GetHTTPClient, reboot: opts.HwReboot, onStateUpdate: opts.OnStateUpdate, diff --git a/internal/ota/sys.go b/internal/ota/sys.go index 6a5002f6f..472ad45e6 100644 --- a/internal/ota/sys.go +++ b/internal/ota/sys.go @@ -14,42 +14,52 @@ const ( // DO NOT call it directly, it's not thread safe // Mutex is currently held by the caller, e.g. doUpdate func (s *State) updateSystem(ctx context.Context, systemUpdate *componentUpdateStatus) error { - l := s.l.With().Str("path", systemUpdatePath).Logger() + logger := GetOtaLogger(). + With(). + Str("subcomponent", "system").Str("path", systemUpdatePath). + Logger() - if err := s.downloadFile(ctx, systemUpdatePath, systemUpdate.url, "system"); err != nil { - return s.componentUpdateError("Error downloading system update", err, &l) + downloadStarted := time.Now() + if err := s.downloadFile(ctx, systemUpdatePath, systemUpdate.url, "system", &logger); err != nil { + return s.componentUpdateError("Error downloading system update", err, &logger) } - downloadFinished := time.Now() + logger.Info().Dur("download_time", downloadFinished.Sub(downloadStarted)).Msg("update downloaded") + systemUpdate.downloadFinishedAt = downloadFinished systemUpdate.downloadProgress = 1 + systemUpdate.updateProgress = 0.25 s.triggerComponentUpdateState("system", systemUpdate) + verifyStarted := time.Now() if err := s.verifyFile( systemUpdatePath, systemUpdate.hash, &systemUpdate.verificationProgress, + &logger, ); err != nil { - return s.componentUpdateError("Error verifying system update hash", err, &l) + return s.componentUpdateError("Error verifying system update hash", err, &logger) } verifyFinished := time.Now() + logger.Info().Dur("verification_time", verifyFinished.Sub(verifyStarted)).Msg("update verified") + systemUpdate.verifiedAt = verifyFinished systemUpdate.verificationProgress = 1 - systemUpdate.updatedAt = verifyFinished - systemUpdate.updateProgress = 1 + systemUpdate.updatedAt = verifyFinished // TODO, this seems wrong here + systemUpdate.updateProgress = 0.5 s.triggerComponentUpdateState("system", systemUpdate) - l.Info().Msg("System update downloaded") - - l.Info().Msg("Starting rk_ota command") + logger.Info().Msg("Starting rk_ota command") + upgradeStarted := time.Now() cmd := exec.Command("rk_ota", "--misc=update", "--tar_path=/userdata/jetkvm/update_system.tar", "--save_dir=/userdata/jetkvm/ota_save", "--partition=all") var b bytes.Buffer cmd.Stdout = &b cmd.Stderr = &b if err := cmd.Start(); err != nil { - return s.componentUpdateError("Error starting rk_ota command", err, &l) + return s.componentUpdateError("Error starting rk_ota command", err, &logger) } + ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -76,13 +86,19 @@ func (s *State) updateSystem(ctx context.Context, systemUpdate *componentUpdateS err := cmd.Wait() cancel() - rkLogger := s.l.With(). + + upgradeFinished := time.Now() + logger.Info().Dur("upgrade_time", upgradeFinished.Sub(upgradeStarted)).Msg("upgrade completed") + + logger = logger. + With(). Str("output", b.String()). - Int("exitCode", cmd.ProcessState.ExitCode()).Logger() + Int("exitCode", cmd.ProcessState.ExitCode()). + Logger() if err != nil { - return s.componentUpdateError("Error executing rk_ota command", err, &rkLogger) + return s.componentUpdateError("Error executing rk_ota command", err, &logger) } - rkLogger.Info().Msg("rk_ota success") + logger.Info().Msg("rk_ota success") s.rebootNeeded = true systemUpdate.updateProgress = 1 @@ -93,9 +109,10 @@ func (s *State) updateSystem(ctx context.Context, systemUpdate *componentUpdateS } func (s *State) confirmCurrentSystem() { + logger := GetOtaLogger().With().Str("action", "confirmCurrentSystem").Logger() output, err := exec.Command("rk_ota", "--misc=now").CombinedOutput() if err != nil { - s.l.Warn().Str("output", string(output)).Msg("failed to set current partition in A/B setup") + logger.Warn().Str("output", string(output)).Err(err).Msg("failed to set current partition in A/B setup") } - s.l.Trace().Str("output", string(output)).Msg("current partition in A/B setup set") + logger.Trace().Str("output", string(output)).Msg("current partition in A/B setup set") } diff --git a/internal/ota/utils.go b/internal/ota/utils.go index b03db3420..82f32a834 100644 --- a/internal/ota/utils.go +++ b/internal/ota/utils.go @@ -28,12 +28,14 @@ func syncFilesystem() error { return nil } -func (s *State) downloadFile(ctx context.Context, path string, url string, component string) error { - logger := s.l.With(). +func (s *State) downloadFile(ctx context.Context, path string, url string, component string, l *zerolog.Logger) error { + logger := l. + With(). Str("path", path). Str("url", url). Str("downloadComponent", component). Logger() + t := time.Now() traceLogger := func() *zerolog.Event { return logger.Trace().Dur("duration", time.Since(t)) @@ -70,7 +72,7 @@ func (s *State) downloadFile(ctx context.Context, path string, url string, compo defer file.Close() traceLogger().Msg("creating request") - req, err := s.newHTTPRequestWithTrace(ctx, "GET", url, nil, traceLogger) + req, err := s.newHTTPRequestWithTrace(ctx, "GET", url, nil, &logger) if err != nil { return fmt.Errorf("error creating request: %w", err) } @@ -129,9 +131,8 @@ func (s *State) downloadFile(ctx context.Context, path string, url string, compo return nil } -func (s *State) verifyFile(path string, expectedHash string, verifyProgress *float32) error { - l := s.l.With().Str("path", path).Logger() +func (s *State) verifyFile(path string, expectedHash string, verifyProgress *float32, logger *zerolog.Logger) error { unverifiedPath := path + ".unverified" fileToHash, err := os.Open(unverifiedPath) if err != nil { @@ -175,7 +176,7 @@ func (s *State) verifyFile(path string, expectedHash string, verifyProgress *flo } hashSum := hash.Sum(nil) - l.Info().Str("hash", hex.EncodeToString(hashSum)).Msg("SHA256 hash of") + logger.Info().Str("path", path).Str("hash", hex.EncodeToString(hashSum)).Msg("SHA256 hash of") if hex.EncodeToString(hashSum) != expectedHash { return fmt.Errorf("hash mismatch: %x != %s", hashSum, expectedHash) diff --git a/internal/sync/log.go b/internal/sync/log.go index 36d0b29c8..7a2faab50 100644 --- a/internal/sync/log.go +++ b/internal/sync/log.go @@ -13,46 +13,37 @@ import ( "github.com/rs/zerolog" ) -var defaultLogger = logging.GetSubsystemLogger("synctrace") - -func logTrace(msg string) { - if defaultLogger.GetLevel() > zerolog.TraceLevel { - return - } - - logTrack(3).Trace().Msg(msg) +func getLogger() *zerolog.Logger { + return logging.GetSubsystemLogger("synctrace") } -func logTrack(callerSkip int) *zerolog.Logger { - l := *defaultLogger - if l.GetLevel() > zerolog.TraceLevel { - return &l +func logTrack(callerSkip int) *zerolog.Event { + logger := getLogger() + if !logging.IsTraceLevel(logger) { + return logger.Trace() } + traceEvent := logger.Trace() + pc, file, no, ok := runtime.Caller(callerSkip) if ok { - l = l.With(). - Str("file", file). - Int("line", no). - Logger() + traceEvent = traceEvent.Str("file", file).Int("line", no) details := runtime.FuncForPC(pc) if details != nil { - l = l.With(). - Str("func", details.Name()). - Logger() + traceEvent = traceEvent.Str("func", details.Name()) } } - return &l + return traceEvent +} + +func logTrace(msg string) { + logTrack(3).Msg(msg) } -func logLockTrack(i string) *zerolog.Logger { - l := logTrack(4). - With(). - Str("index", i). - Logger() - return &l +func logLockTrace(i string) *zerolog.Event { + return logTrack(4).Str("index", i) } var ( @@ -99,18 +90,18 @@ func increaseUnlockCount(i string) { func logLock(t trackable) { i := getIndex(t) increaseLockCount(i) - logLockTrack(i).Trace().Msg("locking mutex") + logLockTrace(i).Msg("locking mutex") } func logUnlock(t trackable) { i := getIndex(t) increaseUnlockCount(i) - logLockTrack(i).Trace().Msg("unlocking mutex") + logLockTrace(i).Msg("unlocking mutex") } func logTryLock(t trackable) { i := getIndex(t) - logLockTrack(i).Trace().Msg("trying to lock mutex") + logLockTrace(i).Msg("trying to lock mutex") } func logTryLockResult(t trackable, l bool) { @@ -119,24 +110,24 @@ func logTryLockResult(t trackable, l bool) { } i := getIndex(t) increaseLockCount(i) - logLockTrack(i).Trace().Msg("locked mutex") + logLockTrace(i).Msg("locked mutex") } func logRLock(t trackable) { i := getIndex(t) increaseLockCount(i) - logLockTrack(i).Trace().Msg("locking mutex for reading") + logLockTrace(i).Msg("locking mutex for reading") } func logRUnlock(t trackable) { i := getIndex(t) increaseUnlockCount(i) - logLockTrack(i).Trace().Msg("unlocking mutex for reading") + logLockTrace(i).Msg("unlocking mutex for reading") } func logTryRLock(t trackable) { i := getIndex(t) - logLockTrack(i).Trace().Msg("trying to lock mutex for reading") + logLockTrace(i).Msg("trying to lock mutex for reading") } func logTryRLockResult(t trackable, l bool) { @@ -145,5 +136,5 @@ func logTryRLockResult(t trackable, l bool) { } i := getIndex(t) increaseLockCount(i) - logLockTrack(i).Trace().Msg("locked mutex for reading") + logLockTrace(i).Msg("locked mutex for reading") } diff --git a/internal/sync/once.go b/internal/sync/once.go index e2d90affd..e059df77a 100644 --- a/internal/sync/once.go +++ b/internal/sync/once.go @@ -13,6 +13,10 @@ type Once struct { // Do calls the function f if and only if Do has not been called before for this instance of Once. func (o *Once) Do(f func()) { - logTrace("Doing once") - o.mu.Do(f) + g := func() { + logTrace("Doing once") + f() + logTrace("Once done") + } + o.mu.Do(g) } diff --git a/internal/timesync/http.go b/internal/timesync/http.go index 4375e2a4b..dfc5d6804 100644 --- a/internal/timesync/http.go +++ b/internal/timesync/http.go @@ -22,7 +22,8 @@ var defaultHTTPUrls = []string{ func (t *TimeSync) queryAllHttpTime(httpUrls []string) (now *time.Time) { chunkSize := int(t.networkConfig.TimeSyncParallel.ValueOr(4)) - t.l.Info().Strs("httpUrls", httpUrls).Int("chunkSize", chunkSize).Msg("querying HTTP URLs") + logger := GetTimesyncLogger() + logger.Info().Strs("httpUrls", httpUrls).Int("chunkSize", chunkSize).Msg("querying HTTP URLs") // shuffle the http urls to avoid always querying the same servers rand.Shuffle(len(httpUrls), func(i, j int) { httpUrls[i], httpUrls[j] = httpUrls[j], httpUrls[i] }) @@ -43,13 +44,12 @@ func (t *TimeSync) queryMultipleHttp(urls []string, timeout time.Duration) (now ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() + logger := GetTimesyncLogger().With().Int("urls_count", len(urls)).Dur("timeout", timeout).Logger() for _, url := range urls { - go func(url string) { - scopedLogger := t.l.With(). - Str("http_url", url). - Logger() + loopLogger := logger.With().Str("http_url", url).Logger() + go func(url string) { metricHttpRequestCount.WithLabelValues(url).Inc() metricHttpTotalRequestCount.Inc() @@ -85,22 +85,23 @@ func (t *TimeSync) queryMultipleHttp(urls []string, timeout time.Duration) (now if requestId == "" { requestId = response.Header.Get("Cf-Ray") } - scopedLogger.Info(). - Str("time", now.Format(time.RFC3339)). + loopLogger.Info(). + Time("time", *now). Int("status", status). Str("request_id", requestId). - Str("time_taken", duration.String()). + Dur("time_taken", duration). Msg("HTTP server returned time") cancel() + results <- now } else if errors.Is(err, context.Canceled) { metricHttpCancelCount.WithLabelValues(url).Inc() metricHttpTotalCancelCount.Inc() results <- nil } else { - scopedLogger.Warn(). - Str("error", err.Error()). + loopLogger.Warn(). + Err(err). Int("status", status). Msg("failed to query HTTP server") results <- nil diff --git a/internal/timesync/log.go b/internal/timesync/log.go new file mode 100644 index 000000000..fa1cd073e --- /dev/null +++ b/internal/timesync/log.go @@ -0,0 +1,10 @@ +package timesync + +import ( + "github.com/jetkvm/kvm/internal/logging" + "github.com/rs/zerolog" +) + +func GetTimesyncLogger() *zerolog.Logger { + return logging.GetSubsystemLogger("timesync") +} diff --git a/internal/timesync/ntp.go b/internal/timesync/ntp.go index 7ff410b01..260221dd8 100644 --- a/internal/timesync/ntp.go +++ b/internal/timesync/ntp.go @@ -43,25 +43,27 @@ func (t *TimeSync) filterNTPServers(ntpServers []string) ([]string, error) { return nil, nil } + logger := GetTimesyncLogger() hasIPv4, err := t.preCheckIPv4() if err != nil { - t.l.Error().Err(err).Msg("failed to check IPv4") + logger.Error().Err(err).Msg("failed to check IPv4") return nil, err } hasIPv6, err := t.preCheckIPv6() if err != nil { - t.l.Error().Err(err).Msg("failed to check IPv6") + logger.Error().Err(err).Msg("failed to check IPv6") return nil, err } filteredServers := []string{} for _, server := range ntpServers { ip := net.ParseIP(server) - t.l.Trace().Str("server", server).Interface("ip", ip).Msg("checking NTP server") if ip == nil { + logger.Trace().Str("server", server).Msg("server didn't parse as IP, skipping") continue } + logger.Trace().Str("server", server).IPAddr("ip", ip).Msg("going to check NTP server") if hasIPv4 && ip.To4() != nil { filteredServers = append(filteredServers, server) @@ -74,14 +76,15 @@ func (t *TimeSync) filterNTPServers(ntpServers []string) ([]string, error) { } func (t *TimeSync) queryNetworkTime(ntpServers []string) (now *time.Time, offset *time.Duration) { + logger := GetTimesyncLogger() ntpServers, err := t.filterNTPServers(ntpServers) if err != nil { - t.l.Error().Err(err).Msg("failed to filter NTP servers") + logger.Error().Err(err).Msg("failed to filter NTP servers") return nil, nil } chunkSize := int(t.networkConfig.TimeSyncParallel.ValueOr(4)) - t.l.Info().Strs("servers", ntpServers).Int("chunkSize", chunkSize).Msg("querying NTP servers") + logger.Info().Strs("servers", ntpServers).Int("chunkSize", chunkSize).Msg("querying NTP servers") // shuffle the ntp servers to avoid always querying the same servers rand.Shuffle(len(ntpServers), func(i, j int) { ntpServers[i], ntpServers[j] = ntpServers[j], ntpServers[i] }) @@ -108,11 +111,12 @@ func (t *TimeSync) queryMultipleNTP(servers []string, timeout time.Duration) (no _, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() + logger := GetTimesyncLogger() + logger.Info().Strs("servers", servers).Dur("timeout", timeout).Msg("querying chunk") + for _, server := range servers { go func(server string) { - scopedLogger := t.l.With(). - Str("server", server). - Logger() + loopLogger := logger.With().Str("server", server).Logger() // increase request count metricNtpTotalRequestCount.Inc() @@ -121,17 +125,14 @@ func (t *TimeSync) queryMultipleNTP(servers []string, timeout time.Duration) (no // query the server now, response, err := queryNtpServer(server, timeout) if err != nil { - scopedLogger.Warn(). - Str("error", err.Error()). - Msg("failed to query NTP server") + loopLogger.Warn().Err(err).Msg("failed to query NTP server") results <- nil return } if response.IsKissOfDeath() { - scopedLogger.Warn(). - Str("kiss_code", response.KissCode). - Msg("ignoring NTP server kiss of death") + // TODO back-signal to avoid querying this server if DENY or RSTR, deprioritize if RATE + loopLogger.Warn().Str("kiss_code", response.KissCode).Msg("ignoring NTP server kiss of death") results <- nil return } @@ -160,11 +161,11 @@ func (t *TimeSync) queryMultipleNTP(servers []string, timeout time.Duration) (no metricNtpTotalSuccessCount.Inc() metricNtpSuccessCount.WithLabelValues(server).Inc() - scopedLogger.Info(). - Str("time", now.Format(time.RFC3339)). + loopLogger.Info(). + Time("now", *now). Str("reference", response.ReferenceString()). Float64("rtt", rtt). - Str("clockOffset", response.ClockOffset.String()). + Dur("clockOffset", response.ClockOffset). Uint8("stratum", response.Stratum). Msg("NTP server returned time") diff --git a/internal/timesync/rtc_linux.go b/internal/timesync/rtc_linux.go index 27e4ec79c..aca8dfdd4 100644 --- a/internal/timesync/rtc_linux.go +++ b/internal/timesync/rtc_linux.go @@ -90,9 +90,9 @@ func (t *TimeSync) setRtcTime(tu time.Time) error { return fmt.Errorf("failed to read RTC time: %w", err) } - t.l.Info(). - Interface("rtc_time", tu). - Str("offset", tu.Sub(currentRtcTime).String()). + GetTimesyncLogger().Info(). + Time("rtc_time", tu). + Dur("offset", tu.Sub(currentRtcTime)). Msg("set rtc time") if err := unix.IoctlSetRTCTime(fd, &rt); err != nil { diff --git a/internal/timesync/timesync.go b/internal/timesync/timesync.go index 97cee97d8..198ca114d 100644 --- a/internal/timesync/timesync.go +++ b/internal/timesync/timesync.go @@ -8,7 +8,6 @@ import ( "time" "github.com/jetkvm/kvm/internal/network/types" - "github.com/rs/zerolog" ) const ( @@ -28,7 +27,6 @@ type PreCheckFunc func() (bool, error) type TimeSync struct { syncLock *sync.Mutex - l *zerolog.Logger networkConfig *types.NetworkConfig dhcpNtpAddresses []string @@ -49,7 +47,6 @@ type TimeSyncOptions struct { PreCheckFunc PreCheckFunc PreCheckIPv4 PreCheckFunc PreCheckIPv6 PreCheckFunc - Logger *zerolog.Logger NetworkConfig *types.NetworkConfig } @@ -62,16 +59,16 @@ type SyncMode struct { } func NewTimeSync(opts *TimeSyncOptions) *TimeSync { + logger := GetTimesyncLogger() rtcDevice, err := getRtcDevicePath() if err != nil { - opts.Logger.Error().Err(err).Msg("failed to get RTC device path") + logger.Error().Err(err).Msg("failed to get RTC device path") } else { - opts.Logger.Info().Str("path", rtcDevice).Msg("RTC device found") + logger.Info().Str("path", rtcDevice).Msg("RTC device found") } t := &TimeSync{ syncLock: &sync.Mutex{}, - l: opts.Logger, dhcpNtpAddresses: []string{}, rtcDevicePath: rtcDevice, rtcLock: &sync.Mutex{}, @@ -84,7 +81,7 @@ func NewTimeSync(opts *TimeSyncOptions) *TimeSync { if t.rtcDevicePath != "" { rtcTime, _ := t.readRtcTime() - t.l.Info().Interface("rtc_time", rtcTime).Msg("read RTC time") + logger.Info().Interface("rtc_time", rtcTime).Msg("read RTC time") } return t @@ -122,35 +119,27 @@ func (t *TimeSync) getSyncMode() SyncMode { } } - t.l.Debug(). - Strs("Ordering", syncMode.Ordering). - Bool("Ntp", syncMode.Ntp). - Bool("Http", syncMode.Http). - Bool("NtpUseFallback", syncMode.NtpUseFallback). - Bool("HttpUseFallback", syncMode.HttpUseFallback). - Msg("sync mode") - return syncMode } + func (t *TimeSync) timeSyncLoop() { metricTimeSyncStatus.Set(0) - // use a timer here instead of sleep - for range t.timer.C { + logger := GetTimesyncLogger() if ok, err := t.preCheckFunc(); !ok { if err != nil { - t.l.Error().Err(err).Msg("pre-check failed") + logger.Error().Err(err).Msg("pre-check failed") } t.timer.Reset(timeSyncWaitNetChkInt) continue } - t.l.Info().Msg("syncing system time") + logger.Info().Msg("syncing system time") start := time.Now() err := t.sync() if err != nil { - t.l.Error().Str("error", err.Error()).Msg("failed to sync system time") + logger.Error().Err(err).Msg("failed to sync system time") // retry after a delay timeSyncRetryInterval += timeSyncRetryStep @@ -165,8 +154,9 @@ func (t *TimeSync) timeSyncLoop() { isInitialSync := !t.syncSuccess t.syncSuccess = true - t.l.Info().Str("now", time.Now().Format(time.RFC3339)). - Str("time_taken", time.Since(start).String()). + logger.Info(). + Time("now", time.Now()). + Dur("time_taken", time.Since(start)). Bool("is_initial_sync", isInitialSync). Msg("time sync successful") @@ -183,20 +173,20 @@ func (t *TimeSync) sync() error { var ( now *time.Time offset *time.Duration - log zerolog.Logger ) metricTimeSyncCount.Inc() syncMode := t.getSyncMode() + logger := GetTimesyncLogger().With().Interface("sync_mode", syncMode).Logger() Orders: for _, mode := range syncMode.Ordering { - log = t.l.With().Str("mode", mode).Logger() + loopLogger := logger.With().Str("mode", mode).Logger() switch mode { case "ntp_user_provided": if syncMode.Ntp { - log.Info().Msg("using NTP custom servers") + loopLogger.Info().Msg("using NTP custom servers") now, offset = t.queryNetworkTime(t.networkConfig.TimeSyncNTPServers) if now != nil { break Orders @@ -204,7 +194,7 @@ Orders: } case "ntp_dhcp": if syncMode.Ntp { - log.Info().Msg("using NTP servers from DHCP") + loopLogger.Info().Msg("using NTP servers from DHCP") now, offset = t.queryNetworkTime(t.dhcpNtpAddresses) if now != nil { break Orders @@ -212,10 +202,10 @@ Orders: } case "ntp": if syncMode.Ntp && syncMode.NtpUseFallback { - log.Info().Msg("using NTP fallback IPs") + loopLogger.Info().Msg("using NTP fallback IPs") now, offset = t.queryNetworkTime(DefaultNTPServerIPs) if now == nil { - log.Info().Msg("using NTP fallback hostnames") + loopLogger.Info().Msg("using NTP fallback hostnames") now, offset = t.queryNetworkTime(DefaultNTPServerHostnames) } if now != nil { @@ -224,7 +214,7 @@ Orders: } case "http_user_provided": if syncMode.Http { - log.Info().Msg("using HTTP custom URLs") + loopLogger.Info().Msg("using HTTP custom URLs") now = t.queryAllHttpTime(t.networkConfig.TimeSyncHTTPUrls) if now != nil { break Orders @@ -232,14 +222,14 @@ Orders: } case "http": if syncMode.Http && syncMode.HttpUseFallback { - log.Info().Msg("using HTTP fallback") + loopLogger.Info().Msg("using HTTP fallback") now = t.queryAllHttpTime(defaultHTTPUrls) if now != nil { break Orders } } default: - log.Warn().Msg("unknown time sync mode, skipping") + loopLogger.Warn().Msg("unknown time sync mode, skipping") } } @@ -248,11 +238,12 @@ Orders: } if offset != nil { + logger = logger.With().Dur("offset", *offset).Logger() newNow := time.Now().Add(*offset) now = &newNow } - log.Info().Time("now", *now).Msg("time obtained") + logger.Info().Time("now", *now).Msg("time obtained") err := t.setSystemTime(*now) if err != nil { @@ -260,14 +251,13 @@ Orders: } metricTimeSyncSuccessCount.Inc() - return nil } // Sync triggers a manual time sync func (t *TimeSync) Sync() error { if !t.syncLock.TryLock() { - t.l.Warn().Msg("sync already in progress, skipping") + GetTimesyncLogger().Warn().Msg("sync already in progress, skipping") return nil } t.syncLock.Unlock() @@ -289,7 +279,7 @@ func (t *TimeSync) setSystemTime(now time.Time) error { nowStr := now.Format("2006-01-02 15:04:05") output, err := exec.Command("date", "-s", nowStr).CombinedOutput() if err != nil { - return fmt.Errorf("failed to run date -s: %w, %s", err, string(output)) + return fmt.Errorf("failed to run date -s %s: %w, %s", nowStr, err, string(output)) } if t.rtcDevicePath != "" { diff --git a/internal/tzdata/gen.go b/internal/tzdata/gen.go index 7c168f14c..8f769d0ec 100644 --- a/internal/tzdata/gen.go +++ b/internal/tzdata/gen.go @@ -12,8 +12,10 @@ import ( ) var tmpl = `// Code generated by "go run gen.go". DO NOT EDIT. +// //go:generate env ZONEINFO=$GOROOT/lib/time/zoneinfo.zip go run gen.go -output tzdata.go package tzdata + var TimeZones = []string{ {{- range . }} "{{.}}", diff --git a/internal/tzdata/tzdata.go b/internal/tzdata/tzdata.go index 368c7205f..1d58ae769 100644 --- a/internal/tzdata/tzdata.go +++ b/internal/tzdata/tzdata.go @@ -1,6 +1,8 @@ // Code generated by "go run gen.go". DO NOT EDIT. +// //go:generate env ZONEINFO=$GOROOT/lib/time/zoneinfo.zip go run gen.go -output tzdata.go package tzdata + var TimeZones = []string{ "Africa/Abidjan", "Africa/Accra", diff --git a/internal/usbgadget/changeset.go b/internal/usbgadget/changeset.go index 57f5d7de6..ddb9b9a35 100644 --- a/internal/usbgadget/changeset.go +++ b/internal/usbgadget/changeset.go @@ -9,6 +9,7 @@ import ( "time" "github.com/prometheus/procfs" + "github.com/rs/zerolog" "github.com/sourcegraph/tf-dag/dag" ) @@ -194,15 +195,15 @@ func (fc *FileChange) checkIfDirIsMountPoint() error { } // GetActualState returns the actual state of the file at the given path. -func (fc *FileChange) getActualState() error { - l := defaultLogger.With().Str("path", fc.Path).Logger() +func (fc *FileChange) getActualState(l *zerolog.Logger) error { + logger := l.With().Str("path", fc.Path).Logger() fi, err := os.Lstat(fc.Path) if err != nil { if os.IsNotExist(err) { fc.ActualState = FileStateAbsent } else { - l.Warn().Err(err).Msg("failed to stat file") + logger.Warn().Err(err).Msg("failed to stat file") fc.ActualState = FileStateUnknown } return nil @@ -214,7 +215,7 @@ func (fc *FileChange) getActualState() error { // get the target of the symlink target, err := os.Readlink(fc.Path) if err != nil { - l.Warn().Err(err).Msg("failed to read symlink") + logger.Warn().Err(err).Msg("failed to read symlink") return fmt.Errorf("failed to read symlink") } // check if the target is a relative path @@ -222,7 +223,7 @@ func (fc *FileChange) getActualState() error { // make it absolute target, err = filepath.Abs(filepath.Join(filepath.Dir(fc.Path), target)) if err != nil { - l.Warn().Err(err).Msg("failed to make symlink target absolute") + logger.Warn().Err(err).Msg("failed to make symlink target absolute") return fmt.Errorf("failed to make symlink target absolute") } } @@ -237,13 +238,13 @@ func (fc *FileChange) getActualState() error { case FileStateMountedConfigFS: err := fc.checkIfDirIsMountPoint() if err != nil { - l.Warn().Err(err).Msg("failed to check if dir is mount point") + logger.Warn().Err(err).Msg("failed to check if dir is mount point") return err } case FileStateSymlinkInOrderConfigFS: - state, err := checkIfSymlinksInOrder(fc, &l) + state, err := checkIfSymlinksInOrder(fc, &logger) if err != nil { - l.Warn().Err(err).Msg("failed to check if symlinks are in order") + logger.Warn().Err(err).Msg("failed to check if symlinks are in order") return err } fc.ActualState = state @@ -252,7 +253,7 @@ func (fc *FileChange) getActualState() error { } if fi.Mode()&os.ModeDevice == os.ModeDevice { - l.Info().Msg("file is a device") + logger.Info().Msg("file is a device") return nil } @@ -262,15 +263,14 @@ func (fc *FileChange) getActualState() error { // get the content of the file content, err := os.ReadFile(fc.Path) if err != nil { - l.Warn().Err(err).Msg("failed to read file") + logger.Warn().Err(err).Msg("failed to read file") return fmt.Errorf("failed to read file") } fc.ActualContent = content return nil } - l.Warn().Interface("file_info", fi.Mode()).Bool("is_dir", fi.IsDir()).Msg("unknown file type") - + logger.Warn().Interface("file_info", fi.Mode()).Bool("is_dir", fi.IsDir()).Msg("unknown file type") return fmt.Errorf("unknown file type") } @@ -280,18 +280,16 @@ func (fc *FileChange) ResetActionResolution() { fc.changed = ChangeStateUnknown } -func (fc *FileChange) Action() FileChangeResolvedAction { +func (fc *FileChange) Action(logger *zerolog.Logger) FileChangeResolvedAction { if !fc.checked { - fc.action = fc.getFileChangeResolvedAction() + fc.action = fc.getFileChangeResolvedAction(logger) fc.checked = true } return fc.action } -func (fc *FileChange) getFileChangeResolvedAction() FileChangeResolvedAction { - l := defaultLogger.With().Str("path", fc.Path).Logger() - +func (fc *FileChange) getFileChangeResolvedAction(l *zerolog.Logger) FileChangeResolvedAction { // some actions are not needed to be checked switch fc.ExpectedState { case FileStateFileWrite: @@ -300,8 +298,10 @@ func (fc *FileChange) getFileChangeResolvedAction() FileChangeResolvedAction { return FileChangeResolvedActionTouch } + logger := l.With().Interface("expected_state", FileStateString[fc.ExpectedState]).Logger() + // get the actual state of the file - err := fc.getActualState() + err := fc.getActualState(&logger) if err != nil { return FileChangeResolvedActionDoNothing } @@ -348,7 +348,7 @@ func (fc *FileChange) getFileChangeResolvedAction() FileChangeResolvedAction { } return FileChangeResolvedActionCreateSymlink case FileStateSymlinkInOrderConfigFS: - // if the file is already a symlink, check if the target is the same + // if the file is already a symlink to configfs, check if the target is the same if fc.ActualState == FileStateSymlinkInOrderConfigFS { return FileChangeResolvedActionDoNothing } @@ -364,7 +364,7 @@ func (fc *FileChange) getFileChangeResolvedAction() FileChangeResolvedAction { } return FileChangeResolvedActionMountConfigFS default: - l.Warn().Interface("file_change", FileStateString[fc.ExpectedState]).Msg("unknown expected state") + logger.Warn().Interface("file_change", FileStateString[fc.ExpectedState]).Msg("unknown expected state") return FileChangeResolvedActionDoNothing } } @@ -387,18 +387,19 @@ func (c *ChangeSet) AddFileChange(component string, path string, expectedState F }) } -func (c *ChangeSet) ApplyChanges() error { +func (c *ChangeSet) ApplyChanges(logger *zerolog.Logger) error { r := ChangeSetResolver{ changeset: c, g: &dag.AcyclicGraph{}, - l: defaultLogger, } - return r.Apply() + return r.Apply(logger) } -func (c *ChangeSet) applyChange(change *FileChange) error { - switch change.Action() { +func (c *ChangeSet) applyChange(change *FileChange, logger *zerolog.Logger) error { + action := change.Action(logger) + + switch action { case FileChangeResolvedActionWriteFile: return os.WriteFile(change.Path, change.ExpectedContent, 0644) case FileChangeResolvedActionUpdateFile: @@ -413,7 +414,7 @@ func (c *ChangeSet) applyChange(change *FileChange) error { } return os.Symlink(string(change.ExpectedContent), change.Path) case FileChangeResolvedActionReorderSymlinks: - return recreateSymlinks(change, nil) + return recreateSymlinks(change, logger) case FileChangeResolvedActionCreateDirectory: return os.MkdirAll(change.Path, 0755) case FileChangeResolvedActionRemove: @@ -427,10 +428,10 @@ func (c *ChangeSet) applyChange(change *FileChange) error { case FileChangeResolvedActionDoNothing: return nil default: - return fmt.Errorf("unknown action: %d", change.Action()) + return fmt.Errorf("unknown action: %d", action) } } -func (c *ChangeSet) Apply() error { - return c.ApplyChanges() +func (c *ChangeSet) Apply(logger *zerolog.Logger) error { + return c.ApplyChanges(logger) } diff --git a/internal/usbgadget/changeset_arm_test.go b/internal/usbgadget/changeset_arm_test.go index 8c0abd54f..73a6e2984 100644 --- a/internal/usbgadget/changeset_arm_test.go +++ b/internal/usbgadget/changeset_arm_test.go @@ -75,20 +75,20 @@ var oldAbsoluteMouseCombinedReportDesc = []byte{ func TestUsbGadgetInit(t *testing.T) { assert := assert.New(t) - usbGadget = NewUsbGadget(usbGadgetName, usbDevices, usbConfig, nil) + usbGadget = NewUsbGadget(usbGadgetName, usbDevices, usbConfig) assert.NotNil(usbGadget) } func TestUsbGadgetStrictModeInitFail(t *testing.T) { usbConfig.strictMode = true - u := NewUsbGadget("test", usbDevices, usbConfig, nil) + u := NewUsbGadget("test", usbDevices, usbConfig) assert.Nil(t, u, "should be nil") } func TestUsbGadgetUDCNotBoundAfterReportDescrChanged(t *testing.T) { assert := assert.New(t) - usbGadget = NewUsbGadget(usbGadgetName, usbDevices, usbConfig, nil) + usbGadget = NewUsbGadget(usbGadgetName, usbDevices, usbConfig) assert.NotNil(usbGadget) // release the usb gadget and create a new one @@ -100,7 +100,7 @@ func TestUsbGadgetUDCNotBoundAfterReportDescrChanged(t *testing.T) { oldAbsoluteMouseConfig.reportDesc = oldAbsoluteMouseCombinedReportDesc altGadgetConfig["absolute_mouse"] = oldAbsoluteMouseConfig - usbGadget = newUsbGadget(usbGadgetName, altGadgetConfig, usbDevices, usbConfig, nil) + usbGadget = newUsbGadget(usbGadgetName, altGadgetConfig, usbDevices, usbConfig) assert.NotNil(usbGadget) udcs := getUdcs() diff --git a/internal/usbgadget/changeset_resolver.go b/internal/usbgadget/changeset_resolver.go index 67812e0d6..d36e63e3b 100644 --- a/internal/usbgadget/changeset_resolver.go +++ b/internal/usbgadget/changeset_resolver.go @@ -3,15 +3,14 @@ package usbgadget import ( "fmt" + "github.com/jetkvm/kvm/internal/logging" "github.com/rs/zerolog" "github.com/sourcegraph/tf-dag/dag" ) type ChangeSetResolver struct { changeset *ChangeSet - - l *zerolog.Logger - g *dag.AcyclicGraph + g *dag.AcyclicGraph changesMap map[string]*FileChange conditionalChangesMap map[string]*FileChange @@ -43,13 +42,13 @@ func (c *ChangeSetResolver) toOrderedChanges() error { return nil } -func (c *ChangeSetResolver) doResolveChanges(initial bool) error { +func (c *ChangeSetResolver) doResolveChanges(initial bool, logger *zerolog.Logger) error { resolvedChanges := make([]*FileChange, 0) for _, key := range c.orderedChanges { change := c.changesMap[key.(string)] if change == nil { - c.l.Error().Str("key", key.(string)).Msg("fileChange not found") + logger.Error().Str("key", key.(string)).Msg("fileChange not found") continue } @@ -57,7 +56,7 @@ func (c *ChangeSetResolver) doResolveChanges(initial bool) error { change.ResetActionResolution() } - resolvedAction := change.Action() + resolvedAction := change.Action(logger) resolvedChanges = append(resolvedChanges, change) // no need to check the triggers if there's no change @@ -89,7 +88,7 @@ func (c *ChangeSetResolver) doResolveChanges(initial bool) error { return nil } -func (c *ChangeSetResolver) resolveChanges(initial bool) error { +func (c *ChangeSetResolver) resolveChanges(initial bool, logger *zerolog.Logger) error { // get the ordered changes err := c.toOrderedChanges() if err != nil { @@ -97,39 +96,45 @@ func (c *ChangeSetResolver) resolveChanges(initial bool) error { } // resolve the changes - err = c.doResolveChanges(initial) + err = c.doResolveChanges(initial, logger) if err != nil { return err } - for _, change := range c.resolvedChanges { - c.l.Trace().Str("change", change.String()).Msg("resolved change") + if logging.IsTraceLevel(logger) { + for _, change := range c.resolvedChanges { + logger.Trace().Stringer("change", change).Msg("resolved change") + } } if !c.additionalResolveRequired || !initial { return nil } - return c.resolveChanges(false) + return c.resolveChanges(false, logger) } -func (c *ChangeSetResolver) applyChanges() error { +func (c *ChangeSetResolver) applyChanges(l *zerolog.Logger) error { for _, change := range c.resolvedChanges { change.ResetActionResolution() - action := change.Action() + + action := change.Action(l) actionStr := FileChangeResolvedActionString[action] - l := c.l.Info() + logger := l.With().Str("action", actionStr).Stringer("change", change).Logger() + var event *zerolog.Event if action == FileChangeResolvedActionDoNothing { - l = c.l.Trace() + event = logger.Trace() + } else { + event = logger.Debug() } - l.Str("action", actionStr).Str("change", change.String()).Msg("applying change") + event.Msg("applying change") - err := c.changeset.applyChange(change) + err := c.changeset.applyChange(change, &logger) if err != nil { if change.IgnoreErrors { - c.l.Warn().Str("change", change.String()).Err(err).Msg("ignoring error") + logger.Warn().Err(err).Msg("ignoring error") } else { return err } @@ -139,7 +144,7 @@ func (c *ChangeSetResolver) applyChanges() error { return nil } -func (c *ChangeSetResolver) GetChanges() ([]*FileChange, error) { +func (c *ChangeSetResolver) GetChanges(logger *zerolog.Logger) ([]*FileChange, error) { localChanges := c.changeset.Changes changesMap := make(map[string]*FileChange) conditionalChangesMap := make(map[string]*FileChange) @@ -175,7 +180,7 @@ func (c *ChangeSetResolver) GetChanges() ([]*FileChange, error) { c.changesMap = changesMap c.conditionalChangesMap = conditionalChangesMap - err := c.resolveChanges(true) + err := c.resolveChanges(true, logger) if err != nil { return nil, err } @@ -183,10 +188,10 @@ func (c *ChangeSetResolver) GetChanges() ([]*FileChange, error) { return c.resolvedChanges, nil } -func (c *ChangeSetResolver) Apply() error { - if _, err := c.GetChanges(); err != nil { +func (c *ChangeSetResolver) Apply(logger *zerolog.Logger) error { + if _, err := c.GetChanges(logger); err != nil { return err } - return c.applyChanges() + return c.applyChanges(logger) } diff --git a/internal/usbgadget/changeset_symlink.go b/internal/usbgadget/changeset_symlink.go index d94c75944..2f91d4d1e 100644 --- a/internal/usbgadget/changeset_symlink.go +++ b/internal/usbgadget/changeset_symlink.go @@ -23,11 +23,8 @@ func compareSymlinks(expected []symlink, actual []symlink) bool { return reflect.DeepEqual(expected, actual) } -func checkIfSymlinksInOrder(fc *FileChange, logger *zerolog.Logger) (FileState, error) { - if logger == nil { - logger = defaultLogger - } - l := logger.With().Str("path", fc.Path).Logger() +func checkIfSymlinksInOrder(fc *FileChange, l *zerolog.Logger) (FileState, error) { + logger := l.With().Str("path", fc.Path).Logger() if len(fc.ParamSymlinks) == 0 { return FileStateUnknown, fmt.Errorf("no symlinks to check") @@ -39,7 +36,7 @@ func checkIfSymlinksInOrder(fc *FileChange, logger *zerolog.Logger) (FileState, if os.IsNotExist(err) { return FileStateAbsent, nil } else { - l.Warn().Err(err).Msg("failed to stat file") + logger.Warn().Err(err).Msg("failed to stat file") return FileStateUnknown, fmt.Errorf("failed to stat file") } } @@ -85,40 +82,40 @@ func checkIfSymlinksInOrder(fc *FileChange, logger *zerolog.Logger) (FileState, return FileStateSymlinkInOrderConfigFS, nil } - l.Trace().Interface("expected", fc.ParamSymlinks).Interface("actual", symlinks).Msg("symlinks are not in order") + logger.Trace().Interface("expected", fc.ParamSymlinks).Interface("actual", symlinks).Msg("symlinks are not in order") return FileStateSymlinkNotInOrderConfigFS, nil } -func recreateSymlinks(fc *FileChange, logger *zerolog.Logger) error { - if logger == nil { - logger = defaultLogger - } +func recreateSymlinks(fc *FileChange, l *zerolog.Logger) error { // remove all symlinks files, err := os.ReadDir(fc.Path) if err != nil { return fmt.Errorf("failed to read directory") } - l := logger.With().Str("path", fc.Path).Logger() - l.Info().Msg("recreate symlinks") + logger := l.With().Str("path", fc.Path).Logger() + logger.Info().Msg("recreate symlinks") for _, file := range files { if file.Type()&os.ModeSymlink != os.ModeSymlink { continue } - l.Info().Str("name", file.Name()).Msg("remove symlink") + + logger.Info().Str("name", file.Name()).Msg("remove symlink") err := os.Remove(path.Join(fc.Path, file.Name())) if err != nil { return fmt.Errorf("failed to remove symlink") } } - l.Info().Interface("param-symlinks", fc.ParamSymlinks).Msg("create symlinks") + logger = logger.With().Interface("param-symlinks", fc.ParamSymlinks).Logger() + logger.Debug().Msg("create symlinks") // create the symlinks for _, symlink := range fc.ParamSymlinks { - l.Info().Str("name", symlink.Path).Str("target", symlink.Target).Msg("create symlink") + eachLogger := logger.With().Str("name", symlink.Path).Str("target", symlink.Target).Logger() + eachLogger.Debug().Msg("create symlink") path := symlink.Path if !filepath.IsAbs(path) { @@ -127,7 +124,7 @@ func recreateSymlinks(fc *FileChange, logger *zerolog.Logger) error { err := os.Symlink(symlink.Target, path) if err != nil { - l.Warn().Err(err).Msg("failed to create symlink") + eachLogger.Warn().Err(err).Msg("failed to create symlink") return fmt.Errorf("failed to create symlink") } } diff --git a/internal/usbgadget/config.go b/internal/usbgadget/config.go index 6d1bd391b..3320ee527 100644 --- a/internal/usbgadget/config.go +++ b/internal/usbgadget/config.go @@ -61,18 +61,18 @@ var defaultGadgetConfig = map[string]gadgetConfigItem{ "mass_storage_lun0": massStorageLun0Config, } -func (u *UsbGadget) isGadgetConfigItemEnabled(itemKey string) bool { +func (enabledDevices *Devices) isGadgetConfigItemEnabled(itemKey string) bool { switch itemKey { case "absolute_mouse": - return u.enabledDevices.AbsoluteMouse + return enabledDevices.AbsoluteMouse case "relative_mouse": - return u.enabledDevices.RelativeMouse + return enabledDevices.RelativeMouse case "keyboard": - return u.enabledDevices.Keyboard + return enabledDevices.Keyboard case "mass_storage_base": - return u.enabledDevices.MassStorage + return enabledDevices.MassStorage case "mass_storage_lun0": - return u.enabledDevices.MassStorage + return enabledDevices.MassStorage default: return true } @@ -80,7 +80,7 @@ func (u *UsbGadget) isGadgetConfigItemEnabled(itemKey string) bool { func (u *UsbGadget) loadGadgetConfig() { if u.customConfig.isEmpty { - u.log.Trace().Msg("using default gadget config") + u.getUsbGadgetLogger().Trace().Msg("using default gadget config") return } @@ -115,15 +115,6 @@ func (u *UsbGadget) SetGadgetDevices(devices *Devices) { u.enabledDevices = *devices } -// GetConfigPath returns the path to the config item. -func (u *UsbGadget) GetConfigPath(itemKey string) (string, error) { - item, ok := u.configMap[itemKey] - if !ok { - return "", fmt.Errorf("config item %s not found", itemKey) - } - return joinPath(u.kvmGadgetPath, item.configPath), nil -} - // GetPath returns the path to the item. func (u *UsbGadget) GetPath(itemKey string) (string, error) { item, ok := u.configMap[itemKey] @@ -136,24 +127,34 @@ func (u *UsbGadget) GetPath(itemKey string) (string, error) { // OverrideGadgetConfig overrides the gadget config for the given item and attribute. // It returns an error if the item is not found or the attribute is not found. // It returns true if the attribute is overridden, false otherwise. -func (u *UsbGadget) OverrideGadgetConfig(itemKey string, itemAttr string, value string) (error, bool) { +func (u *UsbGadget) OverrideGadgetConfig(itemKey string, itemAttr string, value string) (bool, error) { u.configLock.Lock() defer u.configLock.Unlock() + logger := u.getUsbGadgetLogger(). + With(). + Str("itemKey", itemKey). + Str("itemAttr", itemAttr). + Str("value", value). + Logger() + // get it as a pointer _, ok := u.configMap[itemKey] if !ok { - return fmt.Errorf("config item %s not found", itemKey), false + err := fmt.Errorf("config item %s not found", itemKey) + logger.Error().Err(err).Msg("overriding gadget config") + return false, err } if u.configMap[itemKey].attrs[itemAttr] == value { - return nil, false + logger.Trace().Msg("unchanged gadget config") + return false, nil } u.configMap[itemKey].attrs[itemAttr] = value - u.log.Info().Str("itemKey", itemKey).Str("itemAttr", itemAttr).Str("value", value).Msg("overriding gadget config") + logger.Info().Msg("overriding gadget config") - return nil, true + return true, nil } func mountConfigFS(path string) error { @@ -200,12 +201,12 @@ func (u *UsbGadget) UpdateGadgetConfig() error { } func (u *UsbGadget) configureUsbGadget(resetUsb bool) error { - return u.WithTransaction(func() error { - u.tx.MountConfigFS() - u.tx.CreateConfigPath() - u.tx.WriteGadgetConfig() + return u.WithTransaction(func(u *UsbGadget, tx *UsbGadgetTransaction) error { + tx.MountConfigFS() + tx.CreateConfigPath(u.configC1Path) + tx.WriteGadgetConfig(u.kvmGadgetPath, u.configC1Path, u.udc, u.getOrderedConfigItems(), &u.enabledDevices) if resetUsb { - u.tx.RebindUsb(true) + tx.RebindUsb(u.udc, true) } return nil }) diff --git a/internal/usbgadget/config_tx.go b/internal/usbgadget/config_tx.go index df8a3d1b9..42fc28c4c 100644 --- a/internal/usbgadget/config_tx.go +++ b/internal/usbgadget/config_tx.go @@ -6,74 +6,70 @@ import ( "path/filepath" "sort" + "github.com/jetkvm/kvm/internal/logging" "github.com/rs/zerolog" ) // no os package should occur in this file type UsbGadgetTransaction struct { - c *ChangeSet - - // below are the fields that are needed to be set by the caller - log *zerolog.Logger - udc string - dwc3Path string - kvmGadgetPath string - configC1Path string - orderedConfigItems orderedGadgetConfigItems - isGadgetConfigItemEnabled func(key string) bool - + c *ChangeSet reorderSymlinkChanges *RequestedFileChange } -func (u *UsbGadget) newUsbGadgetTransaction(lock bool) error { - if lock { - u.txLock.Lock() - defer u.txLock.Unlock() - } - - if u.tx != nil { - return fmt.Errorf("transaction already exists") - } - +func (u *UsbGadget) newUsbGadgetTransaction() *UsbGadgetTransaction { tx := &UsbGadgetTransaction{ - c: &ChangeSet{}, - log: u.log, - udc: u.udc, - dwc3Path: dwc3Path, - kvmGadgetPath: u.kvmGadgetPath, - configC1Path: u.configC1Path, - orderedConfigItems: u.getOrderedConfigItems(), - isGadgetConfigItemEnabled: u.isGadgetConfigItemEnabled, + c: &ChangeSet{}, } - u.tx = tx - - return nil + return tx } -func (u *UsbGadget) WithTransaction(fn func() error) error { +func (u *UsbGadget) WithTransaction(fn func(u2 *UsbGadget, tx *UsbGadgetTransaction) error) error { u.txLock.Lock() defer u.txLock.Unlock() - err := u.newUsbGadgetTransaction(false) - if err != nil { - u.log.Error().Err(err).Msg("failed to create transaction") + logger := u.getUsbGadgetLogger().With().Str("udc", u.udc).Logger() + logger.Info().Msg("starting USB gadget transaction") + + tx := u.newUsbGadgetTransaction() + if err := fn(u, tx); err != nil { + logger.Error().Err(err).Msg("transaction failed") return err } - if err := fn(); err != nil { - u.log.Error().Err(err).Msg("transaction failed") - return err + + err := tx.Commit() + logger.Trace().Err(err).Msg("committed transaction") + return err +} + +func (u *UsbGadget) getOrderedConfigItems() orderedGadgetConfigItems { + items := make([]gadgetConfigItemWithKey, 0) + for key, item := range u.configMap { + items = append(items, gadgetConfigItemWithKey{key, item}) } - result := u.tx.Commit() - u.tx = nil - return result + sort.SliceStable(items, func(i, j int) bool { + return items[i].item.order < items[j].item.order + }) + + return items +} + +func (tx *UsbGadgetTransaction) getLogger() *zerolog.Logger { + logger := logging.GetSubsystemLogger("usbgadget"). + With(). + Str("subcomponent", "transaction"). + Logger() + return &logger } func (tx *UsbGadgetTransaction) addFileChange(component string, change RequestedFileChange) string { change.Component = component tx.c.AddFileChangeStruct(change) + logger := tx.getLogger() + logger.Trace().Interface("change", change).Msg("add change") + key := change.Key if key == "" { key = change.Path @@ -101,28 +97,16 @@ func (tx *UsbGadgetTransaction) removeFile(component string, path string, descri func (tx *UsbGadgetTransaction) Commit() error { tx.addFileChange("gadget-finalize", *tx.reorderSymlinkChanges) - err := tx.c.Apply() + logger := tx.getLogger() + err := tx.c.Apply(logger) if err != nil { - tx.log.Error().Err(err).Msg("failed to update usbgadget configuration") + logger.Error().Err(err).Msg("failed to update usbgadget configuration") return err } - tx.log.Info().Msg("usbgadget configuration updated") + logger.Info().Msg("usbgadget configuration updated") return nil } -func (u *UsbGadget) getOrderedConfigItems() orderedGadgetConfigItems { - items := make([]gadgetConfigItemWithKey, 0) - for key, item := range u.configMap { - items = append(items, gadgetConfigItemWithKey{key, item}) - } - - sort.Slice(items, func(i, j int) bool { - return items[i].item.order < items[j].item.order - }) - - return items -} - func (tx *UsbGadgetTransaction) MountConfigFS() { tx.addFileChange("gadget", RequestedFileChange{ Path: configFSPath, @@ -131,84 +115,71 @@ func (tx *UsbGadgetTransaction) MountConfigFS() { }) } -func (tx *UsbGadgetTransaction) CreateConfigPath() { +func (tx *UsbGadgetTransaction) CreateConfigPath(configC1Path string) { tx.mkdirAll( "gadget", - tx.configC1Path, + configC1Path, "create config path", []string{configFSPath}, ) } -func (tx *UsbGadgetTransaction) WriteGadgetConfig() { +func (tx *UsbGadgetTransaction) WriteGadgetConfig(kvmGadgetPath string, configC1Path string, udc string, orderedConfigItems orderedGadgetConfigItems, enabledDevices *Devices) { // create kvm gadget path tx.mkdirAll( "gadget", - tx.kvmGadgetPath, + kvmGadgetPath, "create kvm gadget path", - []string{tx.configC1Path}, + []string{configC1Path}, ) deps := make([]string, 0) - deps = append(deps, tx.kvmGadgetPath) + deps = append(deps, kvmGadgetPath) - for _, val := range tx.orderedConfigItems { + for _, val := range orderedConfigItems { key := val.key item := val.item // check if the item is enabled in the config - if !tx.isGadgetConfigItemEnabled(key) { - tx.DisableGadgetItemConfig(item) + if !enabledDevices.isGadgetConfigItemEnabled(key) { + tx.DisableGadgetItemConfig(configC1Path, item) continue } - deps = tx.writeGadgetItemConfig(item, deps) - } - - tx.WriteUDC() -} -func (tx *UsbGadgetTransaction) getDisableKeys() []string { - disableKeys := make([]string, 0) - for _, item := range tx.orderedConfigItems { - if !tx.isGadgetConfigItemEnabled(item.key) { - continue - } - if item.item.configPath == nil || item.item.configAttrs != nil { - continue - } - - disableKeys = append(disableKeys, fmt.Sprintf("disable-%s", item.item.device)) + deps = tx.writeGadgetItemConfig(kvmGadgetPath, configC1Path, item, deps) } - return disableKeys + + tx.WriteUDC(kvmGadgetPath, udc) } -func (tx *UsbGadgetTransaction) DisableGadgetItemConfig(item gadgetConfigItem) { +func (tx *UsbGadgetTransaction) DisableGadgetItemConfig(configC1Path string, item gadgetConfigItem) { // remove symlink if exists if item.configPath == nil { return } - configPath := joinPath(tx.configC1Path, item.configPath) + configPath := joinPath(configC1Path, item.configPath) _ = tx.removeFile("gadget", configPath, "remove symlink: disable gadget config") } -func (tx *UsbGadgetTransaction) writeGadgetItemConfig(item gadgetConfigItem, deps []string) []string { +func (tx *UsbGadgetTransaction) writeGadgetItemConfig(kvmGadgetPath string, configC1Path string, item gadgetConfigItem, deps []string) []string { component := item.device // create directory for the item files := make([]string, 0) files = append(files, deps...) - gadgetItemPath := joinPath(tx.kvmGadgetPath, item.path) - if gadgetItemPath != tx.kvmGadgetPath { + gadgetItemPath := joinPath(kvmGadgetPath, item.path) + if gadgetItemPath != kvmGadgetPath { gadgetItemDir := tx.mkdirAll(component, gadgetItemPath, "create gadget item directory", files) files = append(files, gadgetItemDir) } beforeChange := make([]string, 0) disableGadgetItemKey := fmt.Sprintf("disable-%s", item.device) + if item.configPath != nil && item.configAttrs == nil { - beforeChange = append(beforeChange, tx.getDisableKeys()...) + beforeChange = append(beforeChange, disableGadgetItemKey) } if len(item.attrs) > 0 { @@ -245,8 +216,8 @@ func (tx *UsbGadgetTransaction) writeGadgetItemConfig(item gadgetConfigItem, dep // create config directory if configAttrs are set if len(item.configAttrs) > 0 { - configItemPath := joinPath(tx.configC1Path, item.configPath) - if configItemPath != tx.configC1Path { + configItemPath := joinPath(configC1Path, item.configPath) + if configItemPath != configC1Path { configItemDir := tx.mkdirAll(component, configItemPath, "create config item directory", files) files = append(files, configItemDir) } @@ -260,7 +231,7 @@ func (tx *UsbGadgetTransaction) writeGadgetItemConfig(item gadgetConfigItem, dep // create symlink if configPath is set if item.configPath != nil && item.configAttrs == nil { - configPath := joinPath(tx.configC1Path, item.configPath) + configPath := joinPath(configC1Path, item.configPath) // the change will be only applied by `beforeChange` tx.addFileChange(component, RequestedFileChange{ @@ -271,7 +242,7 @@ func (tx *UsbGadgetTransaction) writeGadgetItemConfig(item gadgetConfigItem, dep Description: "remove symlink", }) - tx.addReorderSymlinkChange(configPath, gadgetItemPath, files) + tx.addReorderSymlinkChange(configC1Path, configPath, gadgetItemPath, files) } return files @@ -294,17 +265,19 @@ func (tx *UsbGadgetTransaction) writeGadgetAttrs(basePath string, attrs gadgetAt return files } -func (tx *UsbGadgetTransaction) addReorderSymlinkChange(path string, target string, deps []string) { - tx.log.Trace().Str("path", path).Str("target", target).Msg("add reorder symlink change") +func (tx *UsbGadgetTransaction) addReorderSymlinkChange(configC1Path string, path string, target string, deps []string) { + logger := tx.getLogger() + logger.Trace().Str("configC1Path", configC1Path).Str("path", path).Str("target", target).Msg("add reorder symlink change") if tx.reorderSymlinkChanges == nil { tx.reorderSymlinkChanges = &RequestedFileChange{ - Component: "gadget-finalize", - Key: "reorder-symlinks", - Path: tx.configC1Path, - ExpectedState: FileStateSymlinkInOrderConfigFS, - Description: "order symlinks", - ParamSymlinks: []symlink{}, + Component: "gadget-finalize", + Key: "reorder-symlinks", + Path: configC1Path, + ExpectedState: FileStateSymlinkInOrderConfigFS, + ExpectedContent: []byte(target), + Description: "order symlinks", + ParamSymlinks: []symlink{}, } } @@ -315,35 +288,35 @@ func (tx *UsbGadgetTransaction) addReorderSymlinkChange(path string, target stri }) } -func (tx *UsbGadgetTransaction) WriteUDC() { +func (tx *UsbGadgetTransaction) WriteUDC(kvmGadgetPath string, udc string) { // bound the gadget to a UDC (USB Device Controller) - path := path.Join(tx.kvmGadgetPath, "UDC") + path := path.Join(kvmGadgetPath, "UDC") tx.addFileChange("udc", RequestedFileChange{ Key: "udc", Path: path, ExpectedState: FileStateFileContentMatch, - ExpectedContent: []byte(tx.udc), + ExpectedContent: []byte(udc), DependsOn: []string{"reorder-symlinks"}, Description: "write UDC", }) } -func (tx *UsbGadgetTransaction) RebindUsb(ignoreUnbindError bool) { +func (tx *UsbGadgetTransaction) RebindUsb(udc string, ignoreUnbindError bool) { // remove the gadget from the UDC tx.addFileChange("udc", RequestedFileChange{ - Path: path.Join(tx.dwc3Path, "unbind"), + Path: path.Join(dwc3Path, "unbind"), ExpectedState: FileStateFileWrite, - ExpectedContent: []byte(tx.udc), + ExpectedContent: []byte(udc), Description: "unbind UDC", DependsOn: []string{"udc"}, IgnoreErrors: ignoreUnbindError, }) // bind the gadget to the UDC tx.addFileChange("udc", RequestedFileChange{ - Path: path.Join(tx.dwc3Path, "bind"), + Path: path.Join(dwc3Path, "bind"), ExpectedState: FileStateFileWrite, - ExpectedContent: []byte(tx.udc), + ExpectedContent: []byte(udc), Description: "bind UDC", - DependsOn: []string{path.Join(tx.dwc3Path, "unbind")}, + DependsOn: []string{path.Join(dwc3Path, "unbind")}, }) } diff --git a/internal/usbgadget/hid_keyboard.go b/internal/usbgadget/hid_keyboard.go index 274f0b6aa..9bb273f22 100644 --- a/internal/usbgadget/hid_keyboard.go +++ b/internal/usbgadget/hid_keyboard.go @@ -3,11 +3,14 @@ package usbgadget import ( "bytes" "context" + "errors" "fmt" "os" + "slices" "sync" "time" + "github.com/jetkvm/kvm/internal/logging" "github.com/rs/xid" "github.com/rs/zerolog" ) @@ -95,6 +98,10 @@ type KeyboardState struct { raw byte } +func (s KeyboardState) MarshalZerologObject(e *zerolog.Event) { + e.Hex("State", []byte{s.raw}) +} + // Byte returns the raw byte representation of the keyboard state. func (k *KeyboardState) Byte() byte { return k.raw @@ -117,19 +124,25 @@ func (u *UsbGadget) updateKeyboardState(state byte) { u.keyboardStateLock.Lock() defer u.keyboardStateLock.Unlock() + logger := u.getHidKeyboardLogger().With().Hex("state", []byte{state}).Logger() + if state&^ValidKeyboardLedMasks != 0 { - u.log.Warn().Uint8("state", state).Msg("ignoring invalid bits") - return + logger.Warn().Msg("ignoring invalid bits") + state &= ValidKeyboardLedMasks } + logger = logger.With().Hex("old_state", []byte{u.keyboardState}).Logger() + if u.keyboardState == state { + logger.Trace().Msg("unchanged keyboardState") return } - u.log.Trace().Uint8("old", u.keyboardState).Uint8("new", state).Msg("keyboardState updated") + u.keyboardState = state + logger.Trace().Msg("keyboardState updated") - if u.onKeyboardStateChange != nil { - (*u.onKeyboardStateChange)(getKeyboardState(state)) + if cb := u.onKeyboardStateChange; cb != nil { + go (*cb)(getKeyboardState(state)) // this enqueues to the outgoing hidrpc queue via usb.go → currentSession.reportHidRPCKeyboardLedState(...) } } @@ -151,6 +164,17 @@ func (u *UsbGadget) GetKeysDownState() KeysDownState { return u.keysDownState } +func (u *UsbGadget) ResetRollover() { + u.keyboardStateLock.Lock() + defer u.keyboardStateLock.Unlock() + + if u.keysDownState.Keys[0] == hidErrorRollOver { + for i := range u.keysDownState.Keys { + u.keysDownState.Keys[i] = 0 + } + } +} + func (u *UsbGadget) SetOnKeysDownChange(f func(state KeysDownState)) { u.onKeysDownChange = &f } @@ -162,183 +186,169 @@ func (u *UsbGadget) SetOnKeepAliveReset(f func()) { // DefaultAutoReleaseDuration is the default duration for auto-release of a key. const DefaultAutoReleaseDuration = 100 * time.Millisecond -func (u *UsbGadget) scheduleAutoRelease(key byte) { +func (u *UsbGadget) DelayAutoReleaseWithDuration(resetDuration time.Duration) { + logger := u.getHidKeyboardAutoReleaseLogger().With().Dur("reset_duration", resetDuration).Logger() + u.kbdAutoReleaseLock.Lock() - defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease scheduled") + defer logging.UnlockWithTraceLog(&u.kbdAutoReleaseLock, &logger, "auto-release delayed") - if u.kbdAutoReleaseTimers[key] != nil { - u.kbdAutoReleaseTimers[key].Stop() + for _, timer := range u.kbdAutoReleaseTimers { + if timer != nil { + timer.Reset(resetDuration) + } } - - // TODO: make this configurable - // We currently hardcode the duration to 100ms - // However, it should be the same as the duration of the keep-alive reset called baseExtension. - u.kbdAutoReleaseTimers[key] = time.AfterFunc(100*time.Millisecond, func() { - u.performAutoRelease(key) - }) } -func (u *UsbGadget) cancelAutoRelease(key byte) { +// note: lock must be freed by caller +func (u *UsbGadget) popAutoReleaseTimer(key byte) bool { u.kbdAutoReleaseLock.Lock() - defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease cancelled") + timer, ok := u.kbdAutoReleaseTimers[key] + if ok { + if timer != nil { + timer.Stop() + } - if timer := u.kbdAutoReleaseTimers[key]; timer != nil { - timer.Stop() - u.kbdAutoReleaseTimers[key] = nil delete(u.kbdAutoReleaseTimers, key) - - // Reset keep-alive timing when key is released - if u.onKeepAliveReset != nil { - (*u.onKeepAliveReset)() - } } + + return ok } -func (u *UsbGadget) DelayAutoReleaseWithDuration(resetDuration time.Duration) { - u.kbdAutoReleaseLock.Lock() - defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease delayed") +func (u *UsbGadget) scheduleAutoRelease(key byte) { + logger := u.getHidKeyboardAutoReleaseLogger().With().Hex("key", []byte{key}).Logger() - u.log.Debug().Dur("reset_duration", resetDuration).Msg("delaying auto-release with dynamic duration") + _ = u.popAutoReleaseTimer(key) // discard previous timer if any + defer logging.UnlockWithTraceLog(&u.kbdAutoReleaseLock, &logger, "auto-release scheduled") - for _, timer := range u.kbdAutoReleaseTimers { - if timer != nil { - timer.Reset(resetDuration) - } - } + start := time.Now() + u.kbdAutoReleaseTimers[key] = time.AfterFunc(DefaultAutoReleaseDuration, func() { + logger.Info().Dur("elapsed", time.Since(start)).Msg("fired after") + u.performAutoRelease(key) + }) } -func (u *UsbGadget) performAutoRelease(key byte) { - u.kbdAutoReleaseLock.Lock() +func (u *UsbGadget) cancelAutoRelease(key byte) { + logger := u.getHidKeyboardAutoReleaseLogger().With().Hex("key", []byte{key}).Logger() - if u.kbdAutoReleaseTimers[key] == nil { - u.log.Warn().Uint8("key", key).Msg("autoRelease timer not found") - u.kbdAutoReleaseLock.Unlock() - return + ok := u.popAutoReleaseTimer(key) + defer logging.UnlockWithTraceLog(&u.kbdAutoReleaseLock, &logger, "auto-release cancelled") + + if ok { + // Reset keep-alive timing when key is actually released + if cb := u.onKeepAliveReset; cb != nil { + go (*cb)() + } } +} - u.kbdAutoReleaseTimers[key].Stop() - u.kbdAutoReleaseTimers[key] = nil - delete(u.kbdAutoReleaseTimers, key) - u.kbdAutoReleaseLock.Unlock() +func (u *UsbGadget) performAutoRelease(key byte) { + logger := u.getHidKeyboardAutoReleaseLogger().With().Hex("key", []byte{key}).Logger() - // Skip if already released - state := u.GetKeysDownState() - alreadyReleased := true + ok := u.popAutoReleaseTimer(key) + defer logging.UnlockWithTraceLog(&u.kbdAutoReleaseLock, &logger, "auto-released") - for i := range state.Keys { - if state.Keys[i] == key { - alreadyReleased = false - break + if ok { + // Skip if already released + state := u.GetKeysDownState() + if !slices.Contains(state.Keys, key) { + logger.Trace().Msg("already released") + return } - } - if alreadyReleased { - return - } - - _, err := u.keypressReport(key, false) - if err != nil { - u.log.Warn().Uint8("key", key).Msg("failed to release key") + _, err := u.keypressReport(&logger, key, false) + if err != nil { + logger.Warn().Msg("failed to release key") + } } } func (u *UsbGadget) listenKeyboardEvents() { - var path string - if u.keyboardHidFile != nil { - path = u.keyboardHidFile.Name() - } - l := u.log.With().Str("listener", "keyboardEvents").Str("path", path).Logger() - l.Trace().Msg("starting") - - go func() { - buf := make([]byte, hidReadBufferSize) - for { - select { - case <-u.keyboardStateCtx.Done(): - l.Info().Msg("context done") + buf := make([]byte, hidReadBufferSize) + for { + select { + case <-u.keyboardStateCtx.Done(): + u.getHidKeyboardLogger().Info().Msg("context done") + return + default: + if u.keyboardHidFile == nil { + u.getHidKeyboardLogger().Warn().Msg("keyboardHidFile is nil, stopping keyboard event listener") return - default: - l.Trace().Msg("reading from keyboard for LED state changes") - if u.keyboardHidFile == nil { - u.logWithSuppression("keyboardHidFileNil", 100, &l, nil, "keyboardHidFile is nil") - // show the error every 100 times to avoid spamming the logs - time.Sleep(time.Second) - continue - } - // reset the counter - u.resetLogSuppressionCounter("keyboardHidFileNil") + } - n, err := u.keyboardHidFile.Read(buf) - if err != nil { - u.logWithSuppression("keyboardHidFileRead", 100, &l, err, "failed to read") - continue + logger := u.getHidKeyboardLogger().With().Str("path", u.keyboardHidFile.Name()).Str("listener", "keyboardEvents").Logger() + logger.Trace().Msg("reading from keyboard for LED state changes") + n, err := u.keyboardHidFile.Read(buf) + if err != nil { + if errors.Is(err, os.ErrClosed) { + logger.Warn().Msg("keyboard file is closed, stopping keyboard event listener") + return + } else if exceeded := u.logWithSuppression(&logger, "keyboardHidFileRead", 10, err, "failed to read"); exceeded { + logger.Error().Msg("too many errors reading the keyboard file, stopping keyboard event listener") + return } + } else { u.resetLogSuppressionCounter("keyboardHidFileRead") + } - l.Trace().Int("n", n).Uints8("buf", buf).Msg("got data from keyboard") - if n != 1 { - l.Trace().Int("n", n).Msg("expected 1 byte, got") - continue - } - u.updateKeyboardState(buf[0]) + logger.Trace().Int("n", n).Hex("buf", buf).Msg("got data from keyboard") + if n != 1 { + logger.Warn().Int("n", n).Msg("expected 1 byte") + continue } + u.updateKeyboardState(buf[0]) } - }() + } } -func (u *UsbGadget) openKeyboardHidFile() error { +func (u *UsbGadget) openKeyboardHidFileUnderMutex() error { if u.keyboardHidFile != nil { return nil } - var err error - u.keyboardHidFile, err = os.OpenFile("/dev/hidg0", os.O_RDWR, 0666) - if err != nil { - return fmt.Errorf("failed to open hidg0: %w", err) - } - if u.keyboardStateCancel != nil { u.keyboardStateCancel() + u.keyboardStateCancel = nil + } + + if keyboardFile, err := os.OpenFile("/dev/hidg0", os.O_RDWR, 0666); err == nil { + u.keyboardHidFile = keyboardFile + } else { + return fmt.Errorf("failed to open keyboard on hidg0: %w", err) } u.keyboardStateCtx, u.keyboardStateCancel = context.WithCancel(context.Background()) - u.listenKeyboardEvents() + go u.listenKeyboardEvents() return nil } +var keyboardHidFileLock sync.Mutex + func (u *UsbGadget) OpenKeyboardHidFile() error { - return u.openKeyboardHidFile() -} + keyboardHidFileLock.Lock() + defer keyboardHidFileLock.Unlock() -var keyboardWriteHidFileLock sync.Mutex + return u.openKeyboardHidFileUnderMutex() +} func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error { - keyboardWriteHidFileLock.Lock() - defer keyboardWriteHidFileLock.Unlock() - if err := u.openKeyboardHidFile(); err != nil { + keyboardHidFileLock.Lock() + defer keyboardHidFileLock.Unlock() + + if err := u.openKeyboardHidFileUnderMutex(); err != nil { return err } _, err := u.writeWithTimeout(u.keyboardHidFile, append([]byte{modifier, 0x00}, keys[:hidKeyBufferSize]...)) if err != nil { - u.logWithSuppression("keyboardWriteHidFile", 100, u.log, err, "failed to write to hidg0") u.keyboardHidFile.Close() u.keyboardHidFile = nil return err } - u.resetLogSuppressionCounter("keyboardWriteHidFile") return nil } func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) KeysDownState { - // if we just reported an error roll over, we should clear the keys - if keys[0] == hidErrorRollOver { - for i := range keys { - keys[i] = 0 - } - } - state := KeysDownState{ Modifier: modifier, Keys: []byte(keys[:]), @@ -355,8 +365,8 @@ func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) KeysDownState { u.keysDownState = state u.keyboardStateLock.Unlock() - if u.onKeysDownChange != nil { - (*u.onKeysDownChange)(state) // this enques to the outgoing hidrpc queue via usb.go → currentSession.enqueueKeysDownState(...) + if cb := u.onKeysDownChange; cb != nil { + go (*cb)(state) // this enqueues to the outgoing hidrpc queue via usb.go → currentSession.enqueueKeysDownState(...) } return state } @@ -373,10 +383,11 @@ func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) error { err := u.keyboardWriteHidFile(modifier, keys) if err != nil { - u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keyboard report to hidg0") + u.getHidKeyboardLogger().Warn().Err(err).Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keyboard report to hidg0") } u.UpdateKeysDown(modifier, keys) + defer u.ResetRollover() return err } @@ -417,13 +428,13 @@ var KeyCodeToMaskMap = map[byte]byte{ RightSuper: ModifierMaskRightSuper, } -func (u *UsbGadget) keypressReport(key byte, press bool) (KeysDownState, error) { +func (u *UsbGadget) keypressReport(l *zerolog.Logger, key byte, press bool) (KeysDownState, error) { defer u.resetUserInputTime() + logger := l.With().Hex("key", []byte{key}).Bool("press", press).Logger() - l := u.log.With().Uint8("key", key).Bool("press", press).Logger() - if l.GetLevel() <= zerolog.DebugLevel { + if logger.GetLevel() <= zerolog.DebugLevel { requestID := xid.New() - l = l.With().Str("requestID", requestID.String()).Logger() + logger = logger.With().Stringer("requestID", requestID).Logger() } // IMPORTANT: This code parallels the logic in the kernel's hid-gadget driver @@ -432,7 +443,7 @@ func (u *UsbGadget) keypressReport(key byte, press bool) (KeysDownState, error) // in the client/browser-side code in useKeyboard.ts so make sure to keep // them in sync. var state = u.GetKeysDownState() - l.Trace().Interface("state", state).Msg("got keys down state") + logger.Trace().Object("state", state).Msg("got keys down state") modifier := state.Modifier keys := append([]byte(nil), state.Keys...) @@ -473,14 +484,15 @@ func (u *UsbGadget) keypressReport(key byte, press bool) (KeysDownState, error) // If we reach here it means we didn't find an empty slot or the key in the buffer if overrun { if press { - l.Error().Msg("keyboard buffer overflow, key not added") + logger.Warn().Msg("keyboard buffer overflow, key not added") // Fill all key slots with ErrorRollOver (0x01) to indicate overflow for i := range keys { keys[i] = hidErrorRollOver } + defer u.ResetRollover() // after reporting rollover, we reset the buffer } else { // If we are releasing a key, and we didn't find it in a slot, who cares? - l.Warn().Msg("key not found in buffer, nothing to release") + logger.Debug().Msg("key not found in buffer, nothing to release") } } } @@ -490,13 +502,15 @@ func (u *UsbGadget) keypressReport(key byte, press bool) (KeysDownState, error) } func (u *UsbGadget) KeypressReport(key byte, press bool) error { - state, err := u.keypressReport(key, press) + logger := u.getHidKeyboardLogger().With().Hex("key", []byte{key}).Bool("press", press).Logger() + + state, err := u.keypressReport(&logger, key, press) if err != nil { - u.log.Warn().Uint8("key", key).Bool("press", press).Msg("failed to report key") + logger.Warn().Err(err).Msg("failed to report key") } - isRolledOver := state.Keys[0] == hidErrorRollOver + wasRolledOver := state.Keys[0] == hidErrorRollOver - if isRolledOver { + if wasRolledOver { u.cancelAutoRelease(key) } else if press { u.scheduleAutoRelease(key) diff --git a/internal/usbgadget/hid_mouse_absolute.go b/internal/usbgadget/hid_mouse_absolute.go index 374844f10..4059dc86f 100644 --- a/internal/usbgadget/hid_mouse_absolute.go +++ b/internal/usbgadget/hid_mouse_absolute.go @@ -68,20 +68,17 @@ var absoluteMouseCombinedReportDesc = []byte{ func (u *UsbGadget) absMouseWriteHidFile(data []byte) error { if u.absMouseHidFile == nil { var err error - u.absMouseHidFile, err = os.OpenFile("/dev/hidg1", os.O_RDWR, 0666) - if err != nil { + if u.absMouseHidFile, err = os.OpenFile("/dev/hidg1", os.O_RDWR, 0666); err != nil { return fmt.Errorf("failed to open hidg1: %w", err) } } _, err := u.writeWithTimeout(u.absMouseHidFile, data) if err != nil { - u.logWithSuppression("absMouseWriteHidFile", 100, u.log, err, "failed to write to hidg1") u.absMouseHidFile.Close() u.absMouseHidFile = nil return err } - u.resetLogSuppressionCounter("absMouseWriteHidFile") return nil } @@ -97,12 +94,9 @@ func (u *UsbGadget) AbsMouseReport(x int, y int, buttons uint8) error { byte(y), // Y Low Byte byte(y >> 8), // Y High Byte }) - if err != nil { - return err - } u.resetUserInputTime() - return nil + return err } func (u *UsbGadget) AbsMouseWheelReport(wheelY int8) error { diff --git a/internal/usbgadget/hid_mouse_relative.go b/internal/usbgadget/hid_mouse_relative.go index 070db6e89..8359bc959 100644 --- a/internal/usbgadget/hid_mouse_relative.go +++ b/internal/usbgadget/hid_mouse_relative.go @@ -58,20 +58,17 @@ var relativeMouseCombinedReportDesc = []byte{ func (u *UsbGadget) relMouseWriteHidFile(data []byte) error { if u.relMouseHidFile == nil { var err error - u.relMouseHidFile, err = os.OpenFile("/dev/hidg2", os.O_RDWR, 0666) - if err != nil { - return fmt.Errorf("failed to open hidg1: %w", err) + if u.relMouseHidFile, err = os.OpenFile("/dev/hidg2", os.O_RDWR, 0666); err != nil { + return fmt.Errorf("failed to open hidg2: %w", err) } } _, err := u.writeWithTimeout(u.relMouseHidFile, data) if err != nil { - u.logWithSuppression("relMouseWriteHidFile", 100, u.log, err, "failed to write to hidg2") u.relMouseHidFile.Close() u.relMouseHidFile = nil return err } - u.resetLogSuppressionCounter("relMouseWriteHidFile") return nil } @@ -85,10 +82,7 @@ func (u *UsbGadget) RelMouseReport(mx int8, my int8, buttons uint8) error { byte(my), // Y 0, // Wheel }) - if err != nil { - return err - } u.resetUserInputTime() - return nil + return err } diff --git a/internal/usbgadget/log.go b/internal/usbgadget/log.go index f979f6c1f..03cea0c9c 100644 --- a/internal/usbgadget/log.go +++ b/internal/usbgadget/log.go @@ -2,16 +2,67 @@ package usbgadget import ( "errors" + + "github.com/jetkvm/kvm/internal/logging" + "github.com/rs/zerolog" ) +func (u *UsbGadget) getUsbGadgetLogger() *zerolog.Logger { + logger := logging.GetSubsystemLogger("usbgadget"). + With(). + Str("gadget", u.name). + Logger() + return &logger +} + +func (u *UsbGadget) getHidKeyboardLogger() *zerolog.Logger { + logger := logging.GetSubsystemLogger("hid-keyboard"). + With(). + Str("gadget", u.name). + Logger() + return &logger +} + +func (u *UsbGadget) getHidKeyboardAutoReleaseLogger() *zerolog.Logger { + logger := logging.GetSubsystemLogger("hid-keyboard-auto-release"). + With(). + Str("gadget", u.name). + Logger() + return &logger +} + +func (u *UsbGadget) logWithSuppression(counterName string, every int, logger *zerolog.Logger, err error, msg string, args ...interface{}) bool { + u.logSuppressionLock.Lock() + counter, ok := u.logSuppressionCounter[counterName] // returns 0, false if not found + counter++ + u.logSuppressionCounter[counterName] = counter + u.logSuppressionLock.Unlock() + + // log if it's the first time, and then every N times thereafter + if !ok || counter%every == 0 { + logger.Error().Str("counterName", counterName).Int("counter", counter).Err(err).Msgf(msg, args...) + return ok // return whether we've just exceeded the every interval + } + return false +} + +func (u *UsbGadget) resetLogSuppressionCounter(counterName string) { + u.logSuppressionLock.Lock() + delete(u.logSuppressionCounter, counterName) + u.logSuppressionLock.Unlock() +} + func (u *UsbGadget) logWarn(msg string, err error) error { if err == nil { err = errors.New(msg) } + + u.getUsbGadgetLogger().Warn().Err(err).Msg(msg) + if u.strictMode { return err } - u.log.Warn().Err(err).Msg(msg) + return nil } @@ -19,9 +70,12 @@ func (u *UsbGadget) logError(msg string, err error) error { if err == nil { err = errors.New(msg) } + + u.getUsbGadgetLogger().Error().Err(err).Msg(msg) + if u.strictMode { return err } - u.log.Error().Err(err).Msg(msg) + return nil } diff --git a/internal/usbgadget/udc.go b/internal/usbgadget/udc.go index 4b7fbe361..abc16680d 100644 --- a/internal/usbgadget/udc.go +++ b/internal/usbgadget/udc.go @@ -38,7 +38,7 @@ func rebindUsb(udc string, ignoreUnbindError bool) error { } func (u *UsbGadget) rebindUsb(ignoreUnbindError bool) error { - u.log.Info().Str("udc", u.udc).Msg("rebinding USB gadget to UDC") + u.getUsbGadgetLogger().Info().Str("udc", u.udc).Msg("rebinding USB gadget to UDC") return rebindUsb(u.udc, ignoreUnbindError) } @@ -58,7 +58,7 @@ func (u *UsbGadget) GetUsbState() (state string) { if os.IsNotExist(err) { return "not attached" } else { - u.log.Trace().Err(err).Msg("failed to read usb state") + u.getUsbGadgetLogger().Warn().Err(err).Msg("failed to read usb state") } return "unknown" } diff --git a/internal/usbgadget/usbgadget.go b/internal/usbgadget/usbgadget.go index f01ae09d4..2c11057bd 100644 --- a/internal/usbgadget/usbgadget.go +++ b/internal/usbgadget/usbgadget.go @@ -9,7 +9,8 @@ import ( "sync" "time" - "github.com/jetkvm/kvm/internal/logging" + "github.com/jetkvm/kvm/internal/utils" + "github.com/rs/zerolog" ) @@ -29,9 +30,8 @@ type Config struct { SerialNumber string `json:"serial_number"` Manufacturer string `json:"manufacturer"` Product string `json:"product"` - - strictMode bool // when it's enabled, all warnings will be converted to errors - isEmpty bool + isEmpty bool + strictMode bool // when it's enabled, all warnings will be converted to errors } var defaultUsbGadgetDevices = Devices{ @@ -42,8 +42,13 @@ var defaultUsbGadgetDevices = Devices{ } type KeysDownState struct { - Modifier byte `json:"modifier"` - Keys ByteSlice `json:"keys"` + Modifier byte `json:"modifier"` + Keys utils.ByteSlice `json:"keys"` +} + +func (k KeysDownState) MarshalZerologObject(e *zerolog.Event) { + e.Uint8("modifier", k.Modifier) + e.Object("keys", k.Keys) } // UsbGadget is a struct that represents a USB gadget. @@ -77,21 +82,18 @@ type UsbGadget struct { enabledDevices Devices - strictMode bool // only intended for testing for now + strictMode bool // only intended for testing absMouseAccumulatedWheelY float64 lastUserInput time.Time - tx *UsbGadgetTransaction txLock sync.Mutex onKeyboardStateChange *func(state KeyboardState) onKeysDownChange *func(state KeysDownState) onKeepAliveReset *func() - log *zerolog.Logger - logSuppressionCounter map[string]int logSuppressionLock sync.Mutex } @@ -99,18 +101,12 @@ type UsbGadget struct { const configFSPath = "/sys/kernel/config" const gadgetPath = "/sys/kernel/config/usb_gadget" -var defaultLogger = logging.GetSubsystemLogger("usbgadget") - // NewUsbGadget creates a new UsbGadget. -func NewUsbGadget(name string, enabledDevices *Devices, config *Config, logger *zerolog.Logger) *UsbGadget { - return newUsbGadget(name, defaultGadgetConfig, enabledDevices, config, logger) +func NewUsbGadget(name string, enabledDevices *Devices, config *Config) *UsbGadget { + return newUsbGadget(name, defaultGadgetConfig, enabledDevices, config) } -func newUsbGadget(name string, configMap map[string]gadgetConfigItem, enabledDevices *Devices, config *Config, logger *zerolog.Logger) *UsbGadget { - if logger == nil { - logger = defaultLogger - } - +func newUsbGadget(name string, configMap map[string]gadgetConfigItem, enabledDevices *Devices, config *Config) *UsbGadget { if enabledDevices == nil { enabledDevices = &defaultUsbGadgetDevices } @@ -139,17 +135,15 @@ func newUsbGadget(name string, configMap map[string]gadgetConfigItem, enabledDev kbdAutoReleaseTimers: make(map[byte]*time.Timer), enabledDevices: *enabledDevices, lastUserInput: time.Now(), - log: logger, - - strictMode: config.strictMode, + strictMode: config.strictMode, logSuppressionCounter: make(map[string]int), absMouseAccumulatedWheelY: 0, } + if err := g.Init(); err != nil { - logger.Error().Err(err).Msg("failed to init USB gadget") - return nil + g.getUsbGadgetLogger().Error().Err(err).Msg("failed to init USB gadget") } return g @@ -160,6 +154,7 @@ func (u *UsbGadget) Close() error { // Cancel keyboard state context if u.keyboardStateCancel != nil { u.keyboardStateCancel() + u.keyboardStateCancel = nil } // Stop auto-release timer diff --git a/internal/usbgadget/utils.go b/internal/usbgadget/utils.go index 85bf1579d..cf7125ae8 100644 --- a/internal/usbgadget/utils.go +++ b/internal/usbgadget/utils.go @@ -2,44 +2,15 @@ package usbgadget import ( "bytes" - "encoding/json" "errors" "fmt" "os" "path/filepath" "strconv" "strings" - "sync" "time" - - "github.com/rs/zerolog" ) -type ByteSlice []byte - -func (s ByteSlice) MarshalJSON() ([]byte, error) { - vals := make([]int, len(s)) - for i, v := range s { - vals[i] = int(v) - } - return json.Marshal(vals) -} - -func (s *ByteSlice) UnmarshalJSON(data []byte) error { - var vals []int - if err := json.Unmarshal(data, &vals); err != nil { - return err - } - *s = make([]byte, len(vals)) - for i, v := range vals { - if v < 0 || v > 255 { - return fmt.Errorf("value %d out of byte range", v) - } - (*s)[i] = byte(v) - } - return nil -} - func joinPath(basePath string, paths []string) string { pathArr := append([]string{basePath}, paths...) return filepath.Join(pathArr...) @@ -71,14 +42,21 @@ func hexToOctal(hex string) (string, error) { return octal, nil } +func compareIgnoreTrailingLF(shorter []byte, longer []byte) bool { + shorterLen := len(shorter) + longerLen := len(longer) + return shorterLen+1 == longerLen && + bytes.Equal(longer[:shorterLen], shorter) && + longer[shorterLen] == 0x0a +} + func compareFileContent(oldContent []byte, newContent []byte, looserMatch bool) bool { - if bytes.Equal(oldContent, newContent) { + if len(oldContent) == len(newContent) && bytes.Equal(oldContent, newContent) { return true } - if len(oldContent) == len(newContent)+1 && - bytes.Equal(oldContent[:len(newContent)], newContent) && - oldContent[len(newContent)] == 10 { + // allow for a trailing newline difference if the one did have one and the other does NOT + if compareIgnoreTrailingLF(oldContent, newContent) || compareIgnoreTrailingLF(newContent, oldContent) { return true } @@ -112,67 +90,31 @@ func compareFileContent(oldContent []byte, newContent []byte, looserMatch bool) } func (u *UsbGadget) writeWithTimeout(file *os.File, data []byte) (n int, err error) { + fileName := file.Name() + if err := file.SetWriteDeadline(time.Now().Add(hidWriteTimeout)); err != nil { return -1, err } n, err = file.Write(data) if err == nil { - return + u.resetLogSuppressionCounter("writeWithTimeout_" + fileName) + return n, nil } - u.log.Trace(). - Str("file", file.Name()). - Bytes("data", data). - Err(err). - Msg("write failed") - - if errors.Is(err, os.ErrDeadlineExceeded) { - u.logWithSuppression( - fmt.Sprintf("writeWithTimeout_%s", file.Name()), - 1000, - u.log, - err, - "write timed out: %s", - file.Name(), - ) - err = nil - } + logger := u.getHidKeyboardLogger().With().Str("file", fileName).Bytes("data", data).Logger() + logger.Trace().Err(err).Msg("write failed") - return -} - -func (u *UsbGadget) logWithSuppression(counterName string, every int, logger *zerolog.Logger, err error, msg string, args ...any) { - u.logSuppressionLock.Lock() - defer u.logSuppressionLock.Unlock() - - if _, ok := u.logSuppressionCounter[counterName]; !ok { - u.logSuppressionCounter[counterName] = 0 - } else { - u.logSuppressionCounter[counterName]++ - } - - l := logger.With().Int("counter", u.logSuppressionCounter[counterName]).Logger() - - if u.logSuppressionCounter[counterName]%every == 0 { - if err != nil { - l.Error().Err(err).Msgf(msg, args...) - } else { - l.Error().Msgf(msg, args...) + if errors.Is(err, os.ErrClosed) { + logger.Warn().Msg("keyboard file is closed, stopping writes") + return 0, err + } else if errors.Is(err, os.ErrDeadlineExceeded) { + if exceeded := u.logWithSuppression("writeWithTimeout_"+fileName, 10, &logger, err, "write timed out"); exceeded { + logger.Error().Msg("too many errors writing to the keyboard file, stopping writes") + return 0, err } + return 0, nil } -} - -func (u *UsbGadget) resetLogSuppressionCounter(counterName string) { - u.logSuppressionLock.Lock() - defer u.logSuppressionLock.Unlock() - - if _, ok := u.logSuppressionCounter[counterName]; !ok { - u.logSuppressionCounter[counterName] = 0 - } -} -func unlockWithLog(lock *sync.Mutex, logger *zerolog.Logger, msg string, args ...any) { - logger.Trace().Msgf(msg, args...) - lock.Unlock() + return n, err } diff --git a/internal/utils/byte_slice.go b/internal/utils/byte_slice.go new file mode 100644 index 000000000..406d3916a --- /dev/null +++ b/internal/utils/byte_slice.go @@ -0,0 +1,37 @@ +package utils + +import ( + "encoding/json" + "fmt" + + "github.com/rs/zerolog" +) + +type ByteSlice []byte + +func (s ByteSlice) MarshalJSON() ([]byte, error) { + vals := make([]int, len(s)) + for i, v := range s { + vals[i] = int(v) + } + return json.Marshal(vals) +} + +func (s *ByteSlice) UnmarshalJSON(data []byte) error { + var vals []int + if err := json.Unmarshal(data, &vals); err != nil { + return err + } + *s = make([]byte, len(vals)) + for i, v := range vals { + if v < 0 || v > 255 { + return fmt.Errorf("value %d out of byte range", v) + } + (*s)[i] = byte(v) + } + return nil +} + +func (k ByteSlice) MarshalZerologObject(e *zerolog.Event) { + e.Hex("bytes", k) +} diff --git a/internal/websecure/ed25519_test.go b/internal/websecure/ed25519_test.go index 0753be0d4..5de74feaf 100644 --- a/internal/websecure/ed25519_test.go +++ b/internal/websecure/ed25519_test.go @@ -3,6 +3,8 @@ package websecure import ( "os" "testing" + + "github.com/rs/zerolog" ) var ( @@ -25,17 +27,17 @@ MC4CAQAwBQYDK2VwBCIEIKV08xUsLRHBfMXqZwxVRzIbViOp8G7aQGjPvoRFjujB ) func TestMain(m *testing.M) { + logger := zerolog.New(os.Stdout).Level(zerolog.InfoLevel) tlsStorePath, err := os.MkdirTemp("", "jktls.*") if err != nil { - defaultLogger.Fatal().Err(err).Msg("failed to create temp directory") + logger.Fatal().Err(err).Msg("failed to create temp directory") } - certStore = NewCertStore(tlsStorePath, nil) - certStore.LoadCertificates() + certStore = NewCertStore(tlsStorePath) + certStore.LoadCertificates(&logger) certSigner = NewSelfSigner( certStore, - nil, "ci.jetkvm.com", "JetKVM", "JetKVM", @@ -48,7 +50,8 @@ func TestMain(m *testing.M) { } func TestSaveEd25519Certificate(t *testing.T) { - err, _ := certStore.ValidateAndSaveCertificate("ed25519-test.jetkvm.com", fixtureEd25519Certificate, fixtureEd25519PrivateKey, true) + logger := zerolog.New(os.Stdout).Level(zerolog.InfoLevel) + err, _ := certStore.ValidateAndSaveCertificate("ed25519-test.jetkvm.com", fixtureEd25519Certificate, fixtureEd25519PrivateKey, true, &logger) if err != nil { t.Fatalf("failed to save certificate: %v", err) } diff --git a/internal/websecure/log.go b/internal/websecure/log.go deleted file mode 100644 index f45767ede..000000000 --- a/internal/websecure/log.go +++ /dev/null @@ -1,9 +0,0 @@ -package websecure - -import ( - "os" - - "github.com/rs/zerolog" -) - -var defaultLogger = zerolog.New(os.Stdout).With().Str("component", "websecure").Logger() diff --git a/internal/websecure/selfsign.go b/internal/websecure/selfsign.go index 77efa3716..c29540d39 100644 --- a/internal/websecure/selfsign.go +++ b/internal/websecure/selfsign.go @@ -11,18 +11,16 @@ import ( "strings" "time" - "github.com/rs/zerolog" + "github.com/jetkvm/kvm/internal/logging" + "golang.org/x/net/idna" ) const selfSignerCAMagicName = "__ca__" type SelfSigner struct { - store *CertStore - log *zerolog.Logger - - caInfo pkix.Name - + store *CertStore + caInfo pkix.Name DefaultDomain string DefaultOrg string DefaultOU string @@ -30,7 +28,6 @@ type SelfSigner struct { func NewSelfSigner( store *CertStore, - log *zerolog.Logger, defaultDomain, defaultOrg, defaultOU, @@ -38,7 +35,6 @@ func NewSelfSigner( ) *SelfSigner { return &SelfSigner{ store: store, - log: log, DefaultDomain: defaultDomain, DefaultOrg: defaultOrg, DefaultOU: defaultOU, @@ -59,17 +55,19 @@ func (s *SelfSigner) createSelfSignedCert(hostname string) *tls.Certificate { return tlsCert } + logger := logging.GetSubsystemLogger("web-selfsign").With().Str("hostname", hostname).Logger() + // check if hostname is the CA magic name var ca *tls.Certificate if hostname != selfSignerCAMagicName { ca = s.getCA() if ca == nil { - s.log.Error().Msg("Failed to get CA certificate") + logger.Error().Msg("Failed to get CA certificate") return nil } } - s.log.Info().Str("hostname", hostname).Msg("Creating self-signed certificate") + logger.Info().Msg("Creating self-signed certificate") // lock the store while creating the certificate (do not move upwards) s.store.certLock.Lock() @@ -77,16 +75,16 @@ func (s *SelfSigner) createSelfSignedCert(hostname string) *tls.Certificate { priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { - s.log.Error().Err(err).Msg("Failed to generate private key") + logger.Error().Err(err).Msg("Failed to generate private key") return nil } - notBefore := time.Now() + notBefore := time.Now().AddDate(0, 0, -1) // ensure we don't have issues with clock skew notAfter := notBefore.AddDate(1, 0, 0) serialNumber, err := generateSerialNumber() if err != nil { - s.log.Error().Err(err).Msg("Failed to generate serial number") + logger.Error().Err(err).Msg("Failed to generate serial number") return nil } @@ -139,7 +137,7 @@ func (s *SelfSigner) createSelfSignedCert(hostname string) *tls.Certificate { if ca != nil { parent, err = x509.ParseCertificate(ca.Certificate[0]) if err != nil { - s.log.Error().Err(err).Msg("Failed to parse parent certificate") + logger.Error().Err(err).Msg("Failed to parse parent certificate") return nil } parentPriv = ca.PrivateKey.(*ecdsa.PrivateKey) @@ -147,7 +145,7 @@ func (s *SelfSigner) createSelfSignedCert(hostname string) *tls.Certificate { certBytes, err := x509.CreateCertificate(rand.Reader, &cert, parent, &priv.PublicKey, parentPriv) if err != nil { - s.log.Error().Err(err).Msg("Failed to create certificate") + logger.Error().Err(err).Msg("Failed to create certificate") return nil } @@ -160,7 +158,7 @@ func (s *SelfSigner) createSelfSignedCert(hostname string) *tls.Certificate { } s.store.certificates[hostname] = tlsCert - s.store.saveCertificate(hostname) + s.store.saveCertificate(hostname, &logger) return tlsCert } @@ -175,12 +173,13 @@ func (s *SelfSigner) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate hostname = strings.Split(info.Conn.LocalAddr().String(), ":")[0] } - s.log.Info().Str("hostname", hostname).Strs("supported_protos", info.SupportedProtos).Msg("TLS handshake") + logger := logging.GetSubsystemLogger("web-selfsign").With().Str("hostname", hostname).Strs("supported_protos", info.SupportedProtos).Logger() + logger.Info().Msg("TLS handshake") // convert hostname to punycode h, err := idna.Lookup.ToASCII(hostname) if err != nil { - s.log.Warn().Str("hostname", hostname).Err(err).Str("remote_addr", info.Conn.RemoteAddr().String()).Msg("Hostname is not valid") + logger.Warn().Err(err).Stringer("remote_addr", info.Conn.RemoteAddr()).Msg("Hostname is not valid") hostname = s.DefaultDomain } else { hostname = h diff --git a/internal/websecure/store.go b/internal/websecure/store.go index ea7911c48..0f06a984a 100644 --- a/internal/websecure/store.go +++ b/internal/websecure/store.go @@ -14,27 +14,18 @@ import ( type CertStore struct { certificates map[string]*tls.Certificate certLock *sync.Mutex - - storePath string - - log *zerolog.Logger + storePath string } -func NewCertStore(storePath string, log *zerolog.Logger) *CertStore { - if log == nil { - log = &defaultLogger - } - +func NewCertStore(storePath string) *CertStore { return &CertStore{ certificates: make(map[string]*tls.Certificate), certLock: &sync.Mutex{}, - - storePath: storePath, - log: log, + storePath: storePath, } } -func (s *CertStore) ensureStorePath() error { +func (s *CertStore) ensureStorePath(logger *zerolog.Logger) error { // check if directory exists stat, err := os.Stat(s.storePath) if err == nil { @@ -46,7 +37,7 @@ func (s *CertStore) ensureStorePath() error { } if os.IsNotExist(err) { - s.log.Trace().Str("path", s.storePath).Msg("TLS store directory does not exist, creating directory") + logger.Trace().Str("path", s.storePath).Msg("TLS store directory does not exist, creating directory") err = os.MkdirAll(s.storePath, 0755) if err != nil { return fmt.Errorf("failed to create TLS store path: %w", err) @@ -57,16 +48,18 @@ func (s *CertStore) ensureStorePath() error { return fmt.Errorf("failed to check TLS store path: %w", err) } -func (s *CertStore) LoadCertificates() { - err := s.ensureStorePath() +func (s *CertStore) LoadCertificates(l *zerolog.Logger) { + logger := l.With().Str("storePath", s.storePath).Logger() + + err := s.ensureStorePath(&logger) if err != nil { - s.log.Error().Err(err).Msg("Failed to ensure store path") + logger.Error().Err(err).Msg("Failed to ensure store path") return } files, err := os.ReadDir(s.storePath) if err != nil { - s.log.Error().Err(err).Msg("Failed to read TLS directory") + logger.Error().Err(err).Msg("Failed to read TLS directory") return } @@ -76,30 +69,33 @@ func (s *CertStore) LoadCertificates() { } if strings.HasSuffix(file.Name(), ".crt") { - s.loadCertificate(strings.TrimSuffix(file.Name(), ".crt")) + hostname := strings.TrimSuffix(file.Name(), ".crt") + s.loadCertificate(hostname, &logger) } } } -func (s *CertStore) loadCertificate(hostname string) { +func (s *CertStore) loadCertificate(hostname string, l *zerolog.Logger) { s.certLock.Lock() defer s.certLock.Unlock() + logger := l.With().Str("hostname", hostname).Logger() + keyFile := path.Join(s.storePath, hostname+".key") crtFile := path.Join(s.storePath, hostname+".crt") cert, err := tls.LoadX509KeyPair(crtFile, keyFile) if err != nil { - s.log.Error().Err(err).Str("hostname", hostname).Msg("Failed to load certificate") + logger.Error().Err(err).Msg("Failed to load certificate") return } s.certificates[hostname] = &cert if hostname == selfSignerCAMagicName { - s.log.Info().Msg("loaded CA certificate") + logger.Info().Msg("loaded CA certificate") } else { - s.log.Info().Str("hostname", hostname).Msg("loaded certificate") + logger.Info().Msg("loaded certificate") } } @@ -116,7 +112,9 @@ func (s *CertStore) GetCertificate(hostname string) *tls.Certificate { // returns are: // - error: if the certificate is invalid or if there's any error during saving the certificate // - error: if there's any warning or error during saving the certificate -func (s *CertStore) ValidateAndSaveCertificate(hostname string, cert string, key string, ignoreWarning bool) (error, error) { +func (s *CertStore) ValidateAndSaveCertificate(hostname string, cert string, key string, ignoreWarning bool, l *zerolog.Logger) (error, error) { + logger := l.With().Str("hostname", hostname).Str("cert", cert).Logger() // don't log the key for security reasons + tlsCert, err := tls.X509KeyPair([]byte(cert), []byte(key)) if err != nil { return fmt.Errorf("failed to parse certificate: %w", err), nil @@ -127,7 +125,7 @@ func (s *CertStore) ValidateAndSaveCertificate(hostname string, cert string, key // add recover to avoid panic defer func() { if r := recover(); r != nil { - s.log.Error().Interface("recovered", r).Msg("Failed to verify hostname") + logger.Error().Interface("recovered", r).Msg("Failed to verify hostname") } }() @@ -135,7 +133,7 @@ func (s *CertStore) ValidateAndSaveCertificate(hostname string, cert string, key if !ignoreWarning { return nil, fmt.Errorf("certificate does not match hostname: %w", err) } - s.log.Warn().Err(err).Msg("certificate does not match hostname") + logger.Warn().Err(err).Msg("certificate does not match hostname") } } @@ -143,22 +141,22 @@ func (s *CertStore) ValidateAndSaveCertificate(hostname string, cert string, key s.certificates[hostname] = &tlsCert s.certLock.Unlock() - s.saveCertificate(hostname) + s.saveCertificate(hostname, &logger) return nil, nil } -func (s *CertStore) saveCertificate(hostname string) { +func (s *CertStore) saveCertificate(hostname string, logger *zerolog.Logger) { // check if certificate already exists tlsCert := s.certificates[hostname] if tlsCert == nil { - s.log.Error().Str("hostname", hostname).Msg("Certificate for hostname does not exist, skipping saving certificate") + logger.Error().Msg("Certificate for hostname does not exist, skipping saving certificate") return } - err := s.ensureStorePath() + err := s.ensureStorePath(logger) if err != nil { - s.log.Error().Err(err).Msg("Failed to ensure store path") + logger.Error().Err(err).Msg("Failed to ensure store path") return } @@ -166,14 +164,14 @@ func (s *CertStore) saveCertificate(hostname string) { crtFile := path.Join(s.storePath, hostname+".crt") if err := keyToFile(tlsCert, keyFile); err != nil { - s.log.Error().Err(err).Msg("Failed to save key file") + logger.Error().Err(err).Msg("Failed to save key file") return } if err := certToFile(tlsCert, crtFile); err != nil { - s.log.Error().Err(err).Msg("Failed to save certificate") + logger.Error().Err(err).Msg("Failed to save certificate") return } - s.log.Info().Str("hostname", hostname).Msg("Saved certificate") + logger.Info().Msg("Saved certificate") } diff --git a/jiggler.go b/jiggler.go index b2463e0ab..6fb49d484 100644 --- a/jiggler.go +++ b/jiggler.go @@ -7,6 +7,7 @@ import ( _ "time/tzdata" "github.com/go-co-op/gocron/v2" + "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/tzdata" ) @@ -42,7 +43,7 @@ func rpcGetJigglerConfig() (JigglerConfig, error) { } func rpcSetJigglerConfig(jigglerConfig JigglerConfig) error { - logger.Info().Msgf("jigglerConfig: %v, %v, %v, %v", jigglerConfig.InactivityLimitSeconds, jigglerConfig.JitterPercentage, jigglerConfig.ScheduleCronTab, jigglerConfig.Timezone) + logging.GetSubsystemLogger("jiggler").Info().Msgf("jigglerConfig: %v, %v, %v, %v", jigglerConfig.InactivityLimitSeconds, jigglerConfig.JitterPercentage, jigglerConfig.ScheduleCronTab, jigglerConfig.Timezone) config.JigglerConfig = &jigglerConfig err := removeExistingCrobJobs(scheduler) if err != nil { @@ -73,7 +74,7 @@ func initJiggler() { ensureConfigLoaded() err := runJigglerCronTab() if err != nil { - logger.Error().Msgf("Error scheduling jiggler crontab: %v", err) + logging.GetSubsystemLogger("jiggler").Error().Msgf("Error scheduling jiggler crontab: %v", err) return } } @@ -85,7 +86,7 @@ func runJigglerCronTab() error { if config.JigglerConfig.Timezone != "" && config.JigglerConfig.Timezone != "UTC" { // Validate timezone before applying if _, err := time.LoadLocation(config.JigglerConfig.Timezone); err != nil { - logger.Warn().Msgf("Invalid timezone '%s', falling back to UTC: %v", config.JigglerConfig.Timezone, err) + logging.GetSubsystemLogger("jiggler").Warn().Msgf("Invalid timezone '%s', falling back to UTC: %v", config.JigglerConfig.Timezone, err) // Don't add TZ prefix, let it run in UTC } else { cronTab = fmt.Sprintf("TZ=%s %s", config.JigglerConfig.Timezone, cronTab) @@ -114,7 +115,7 @@ func runJigglerCronTab() error { s.Start() delta, err := calculateJobDelta(s) jobDelta = delta - logger.Info().Msgf("Time between jiggler runs: %v", jobDelta) + logging.GetSubsystemLogger("jiggler").Info().Msgf("Time between jiggler runs: %v", jobDelta) if err != nil { return err } @@ -129,17 +130,17 @@ func runJiggler() { } inactivitySeconds := config.JigglerConfig.InactivityLimitSeconds timeSinceLastInput := time.Since(gadget.GetLastUserInputTime()) - logger.Debug().Msgf("Time since last user input %v", timeSinceLastInput) + logging.GetSubsystemLogger("jiggler").Debug().Msgf("Time since last user input %v", timeSinceLastInput) if timeSinceLastInput > time.Duration(inactivitySeconds)*time.Second { - logger.Debug().Msg("Jiggling mouse...") + logging.GetSubsystemLogger("jiggler").Debug().Msg("Jiggling mouse...") //TODO: change to rel mouse err := rpcAbsMouseReport(1, 1, 0) if err != nil { - logger.Warn().Msgf("Failed to jiggle mouse: %v", err) + logging.GetSubsystemLogger("jiggler").Warn().Msgf("Failed to jiggle mouse: %v", err) } err = rpcAbsMouseReport(0, 0, 0) if err != nil { - logger.Warn().Msgf("Failed to reset mouse position: %v", err) + logging.GetSubsystemLogger("jiggler").Warn().Msgf("Failed to reset mouse position: %v", err) } } } diff --git a/jsonrpc.go b/jsonrpc.go index b401ac593..d84d000c5 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -19,6 +19,7 @@ import ( "go.bug.st/serial" "github.com/jetkvm/kvm/internal/hidrpc" + "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/usbgadget" "github.com/jetkvm/kvm/internal/utils" ) @@ -56,17 +57,19 @@ type BacklightSettings struct { func writeJSONRPCResponse(response JSONRPCResponse, session *Session) { responseBytes, err := json.Marshal(response) if err != nil { - jsonRpcLogger.Warn().Err(err).Msg("Error marshalling JSONRPC response") + getJsonRPCLogger().Error().Err(err).Msg("Error marshalling JSONRPC response") return } err = session.RPCChannel.SendText(string(responseBytes)) if err != nil { - jsonRpcLogger.Warn().Err(err).Msg("Error sending JSONRPC response") + getJsonRPCLogger().Warn().Err(err).Msg("Error sending JSONRPC response") return } } func writeJSONRPCEvent(event string, params any, session *Session) { + logger := getJsonRPCLogger().With().Str("event", event).Interface("params", params).Logger() + request := JSONRPCEvent{ JSONRPC: "2.0", Method: event, @@ -74,33 +77,39 @@ func writeJSONRPCEvent(event string, params any, session *Session) { } requestBytes, err := json.Marshal(request) if err != nil { - jsonRpcLogger.Warn().Err(err).Msg("Error marshalling JSONRPC event") + logger.Warn().Err(err).Msg("Error marshalling JSONRPC event") return } if session == nil || session.RPCChannel == nil { - jsonRpcLogger.Info().Msg("RPC channel not available") + logger.Info().Msg("RPC channel not available") return } - requestString := string(requestBytes) - scopedLogger := jsonRpcLogger.With(). - Str("data", requestString). - Logger() + if logging.IsTraceLevel(&logger) { + logger = logger.With().Object("requestBytes", utils.ByteSlice(requestBytes)).Logger() + } - scopedLogger.Trace().Msg("sending JSONRPC event") + logger.Trace().Msg("sending JSONRPC event") - err = session.RPCChannel.SendText(requestString) + err = session.RPCChannel.SendText(string(requestBytes)) if err != nil { - scopedLogger.Warn().Err(err).Msg("error sending JSONRPC event") + logger.Warn().Err(err).Msg("error sending JSONRPC event") return } } +func getJsonRPCLogger() *zerolog.Logger { + return logging.GetSubsystemLogger("jsonrpc") +} + func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { + logger := getJsonRPCLogger().With().Int("length", len(message.Data)).Logger() + var request JSONRPCRequest err := json.Unmarshal(message.Data, &request) if err != nil { - jsonRpcLogger.Warn(). + logger. + Warn(). Str("data", string(message.Data)). Err(err). Msg("Error unmarshalling JSONRPC request") @@ -117,12 +126,14 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } - scopedLogger := jsonRpcLogger.With(). + logger = logger. + With(). Str("method", request.Method). Interface("params", request.Params). - Interface("id", request.ID).Logger() + Interface("id", request.ID). + Logger() - scopedLogger.Trace().Msg("Received RPC request") + logger.Trace().Msg("Received RPC request") t := time.Now() handler, ok := rpcHandlers[request.Method] @@ -139,9 +150,9 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } - result, err := callRPCHandler(scopedLogger, handler, request.Params) + result, err := callRPCHandler(&logger, handler, request.Params) if err != nil { - scopedLogger.Error().Err(err).Msg("Error calling RPC handler") + logger.Error().Err(err).Msg("Error calling RPC handler") errorResponse := JSONRPCResponse{ JSONRPC: "2.0", Error: map[string]any{ @@ -155,7 +166,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } - scopedLogger.Trace().Dur("duration", time.Since(t)).Interface("result", result).Msg("RPC handler returned") + logger.Trace().Dur("duration", time.Since(t)).Interface("result", result).Msg("RPC handler returned") response := JSONRPCResponse{ JSONRPC: "2.0", @@ -174,7 +185,7 @@ func rpcGetDeviceID() (string, error) { } func rpcReboot(force bool) error { - logger.Info().Msg("Got reboot request via RPC") + getJsonRPCLogger().Debug().Bool("force", force).Msg("Got reboot request via RPC") return hwReboot(force, nil, 0) } @@ -183,7 +194,7 @@ func rpcGetStreamQualityFactor() (float64, error) { } func rpcSetStreamQualityFactor(factor float64) error { - logger.Info().Float64("factor", factor).Msg("Setting stream quality factor") + getJsonRPCLogger().Debug().Float64("factor", factor).Msg("Setting stream quality factor") err := nativeInstance.VideoSetQualityFactor(factor) if err != nil { return err @@ -201,6 +212,7 @@ func rpcGetAutoUpdateState() (bool, error) { } func rpcSetAutoUpdateState(enabled bool) (bool, error) { + getJsonRPCLogger().Debug().Bool("enabled", enabled).Msg("setting auto-update state") config.AutoUpdateEnabled = enabled if err := SaveConfig(); err != nil { return config.AutoUpdateEnabled, fmt.Errorf("failed to save config: %w", err) @@ -218,9 +230,9 @@ func rpcGetEDID() (string, error) { func rpcSetEDID(edid string) error { if edid == "" { - logger.Info().Msg("Restoring EDID to default") + getJsonRPCLogger().Debug().Msg("Restoring EDID to default") } else { - logger.Info().Str("edid", edid).Msg("Setting EDID") + getJsonRPCLogger().Debug().Str("edid", edid).Msg("Setting EDID") } err := nativeInstance.VideoSetEDID(edid) if err != nil { @@ -238,18 +250,18 @@ func rpcGetVideoLogStatus() (string, error) { } func rpcSetDisplayRotation(params DisplayRotationSettings) error { + getJsonRPCLogger().Debug().Interface("params", params).Msg("setting display rotation") + currentRotation := config.DisplayRotation if currentRotation == params.Rotation { return nil } - err := config.SetDisplayRotation(params.Rotation) - if err != nil { + if err := config.SetDisplayRotation(params.Rotation); err != nil { return err } - _, err = nativeInstance.DisplaySetRotation(config.GetDisplayRotation()) - if err != nil { + if _, err := nativeInstance.DisplaySetRotation(config.GetDisplayRotation()); err != nil { return err } @@ -257,7 +269,7 @@ func rpcSetDisplayRotation(params DisplayRotationSettings) error { return fmt.Errorf("failed to save config: %w", err) } - return err + return nil } func rpcGetDisplayRotation() (*DisplayRotationSettings, error) { @@ -267,24 +279,25 @@ func rpcGetDisplayRotation() (*DisplayRotationSettings, error) { } func rpcSetBacklightSettings(params BacklightSettings) error { - blConfig := params + logger := getJsonRPCLogger().With().Interface("params", params).Logger() + logger.Debug().Msg("setting backlight settings") // NOTE: by default, the frontend limits the brightness to 64, as that's what the device originally shipped with. - if blConfig.MaxBrightness > 255 || blConfig.MaxBrightness < 0 { + if params.MaxBrightness > 255 || params.MaxBrightness < 0 { return fmt.Errorf("maxBrightness must be between 0 and 255") } - if blConfig.DimAfter < 0 { + if params.DimAfter < 0 { return fmt.Errorf("dimAfter must be a positive integer") } - if blConfig.OffAfter < 0 { + if params.OffAfter < 0 { return fmt.Errorf("offAfter must be a positive integer") } - config.DisplayMaxBrightness = blConfig.MaxBrightness - config.DisplayDimAfterSec = blConfig.DimAfter - config.DisplayOffAfterSec = blConfig.OffAfter + config.DisplayMaxBrightness = params.MaxBrightness + config.DisplayDimAfterSec = params.DimAfter + config.DisplayOffAfterSec = params.OffAfter if err := SaveConfig(); err != nil { return fmt.Errorf("failed to save config: %w", err) @@ -327,23 +340,21 @@ type SSHKeyState struct { } func rpcGetDevModeState() (DevModeState, error) { - devModeEnabled := false - if _, err := os.Stat(devModeFile); err != nil { - if !os.IsNotExist(err) { - return DevModeState{}, fmt.Errorf("error checking dev mode file: %w", err) - } + if _, err := os.Stat(devModeFile); err != nil && !os.IsNotExist(err) { + return DevModeState{}, fmt.Errorf("error checking dev mode file: %w", err) } else { - devModeEnabled = true + return DevModeState{ + Enabled: err == nil, + }, nil } - - return DevModeState{ - Enabled: devModeEnabled, - }, nil } func rpcSetDevModeState(enabled bool) error { + logger := getJsonRPCLogger().With().Bool("enabled", enabled).Logger() + logger.Debug().Msg("setting dev mode state") + if enabled { - if _, err := os.Stat(devModeFile); os.IsNotExist(err) { + if _, err := os.Stat(devModeFile); err != nil && os.IsNotExist(err) { if err := os.MkdirAll(filepath.Dir(devModeFile), 0755); err != nil { return fmt.Errorf("failed to create directory for devmode file: %w", err) } @@ -367,10 +378,11 @@ func rpcSetDevModeState(enabled bool) error { } } + logger.Info().Msg("restarting dropbear service to apply dev mode changes") cmd := exec.Command("dropbear.sh") output, err := cmd.CombinedOutput() if err != nil { - logger.Warn().Err(err).Bytes("output", output).Msg("Failed to start/stop SSH") + logger.Warn().Err(err).Object("output", utils.ByteSlice(output)).Msg("Failed to start/stop SSH") return fmt.Errorf("failed to start/stop SSH, you may need to reboot for changes to take effect") } @@ -379,15 +391,17 @@ func rpcSetDevModeState(enabled bool) error { func rpcGetSSHKeyState() (string, error) { keyData, err := os.ReadFile(sshKeyFile) - if err != nil { - if !os.IsNotExist(err) { - return "", fmt.Errorf("error reading SSH key file: %w", err) - } + if err != nil && !os.IsNotExist(err) { + return "", fmt.Errorf("error reading SSH key file: %w", err) } return string(keyData), nil } func rpcSetSSHKeyState(sshKey string) error { + // Log the action but avoid logging the actual key for security reasons + logger := logging.GetSubsystemLogger("jsonrpc") + logger.Debug().Msg("setting SSH key") + if sshKey == "" { // Remove SSH key file if empty string is provided if err := os.Remove(sshKeyFile); err != nil && !os.IsNotExist(err) { @@ -419,6 +433,8 @@ func rpcGetTLSState() TLSState { } func rpcSetTLSState(state TLSState) error { + getJsonRPCLogger().Debug().Interface("state", state).Msg("setting TLS state") + err := setTLSState(state) if err != nil { return fmt.Errorf("failed to set TLS state: %w", err) @@ -437,7 +453,7 @@ type RPCHandler struct { } // call the handler but recover from a panic to ensure our RPC thread doesn't collapse on malformed calls -func callRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[string]any) (result any, err error) { +func callRPCHandler(logger *zerolog.Logger, handler RPCHandler, params map[string]any) (result any, err error) { // Use defer to recover from a panic defer func() { if r := recover(); r != nil { @@ -447,6 +463,7 @@ func callRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[string } else { err = fmt.Errorf("panic occurred: %v", r) } + logger.Error().Err(err).Msg("Recovered from panic in RPC handler") } }() @@ -455,7 +472,7 @@ func callRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[string return result, err // do not combine these two lines into one, as it breaks the above defer function's setting of err } -func riskyCallRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[string]any) (any, error) { +func riskyCallRPCHandler(logger *zerolog.Logger, handler RPCHandler, params map[string]any) (any, error) { handlerValue := reflect.ValueOf(handler.Func) handlerType := handlerValue.Type() @@ -467,9 +484,7 @@ func riskyCallRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[s paramNames := handler.Params // Get the parameter names from the RPCHandler if len(paramNames) != numParams { - err := fmt.Errorf("mismatch between handler parameters (%d) and defined parameter names (%d)", numParams, len(paramNames)) - logger.Error().Strs("paramNames", paramNames).Err(err).Msg("Cannot call RPC handler") - return nil, err + return nil, fmt.Errorf("mismatch between handler parameters (%d) and defined parameter names (%d)", numParams, len(paramNames)) } args := make([]reflect.Value, numParams) @@ -479,9 +494,7 @@ func riskyCallRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[s paramName := paramNames[i] paramValue, ok := params[paramName] if !ok { - err := fmt.Errorf("missing parameter: %s", paramName) - logger.Error().Err(err).Msg("Cannot marshal arguments for RPC handler") - return nil, err + return nil, fmt.Errorf("missing parameter: %s", paramName) } convertedValue := reflect.ValueOf(paramValue) @@ -530,7 +543,7 @@ func riskyCallRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[s } } - logger.Trace().Msg("Calling RPC handler") + logger.Trace().Interface("args", args).Msg("Calling RPC handler") results := handlerValue.Call(args) if len(results) == 0 { @@ -567,7 +580,8 @@ func asError(value reflect.Value) (bool, error) { } func rpcSetMassStorageMode(mode string) (string, error) { - logger.Info().Str("mode", mode).Msg("Setting mass storage mode") + logger := getJsonRPCLogger().With().Str("mode", mode).Logger() + logger.Debug().Msg("Setting mass storage mode") var cdrom bool switch mode { case "cdrom": @@ -575,18 +589,18 @@ func rpcSetMassStorageMode(mode string) (string, error) { case "file": cdrom = false default: - logger.Info().Str("mode", mode).Msg("Invalid mode provided") + logger.Info().Msg("Invalid mode provided") return "", fmt.Errorf("invalid mode: %s", mode) } - logger.Info().Str("mode", mode).Msg("Setting mass storage mode") + logger.Info().Msg("Setting mass storage mode") err := setMassStorageMode(cdrom) if err != nil { return "", fmt.Errorf("failed to set mass storage mode: %w", err) } - logger.Info().Str("mode", mode).Msg("Mass storage mode set") + logger.Info().Msg("Mass storage mode set") // Get the updated mode after setting return rpcGetMassStorageMode() @@ -614,6 +628,7 @@ func rpcGetUsbEmulationState() (bool, error) { } func rpcSetUsbEmulationState(enabled bool) error { + getJsonRPCLogger().Debug().Bool("enabled", enabled).Msg("setting USB emulation state") if enabled { return gadget.BindUDC() } else { @@ -627,6 +642,7 @@ func rpcGetUsbConfig() (usbgadget.Config, error) { } func rpcSetUsbConfig(usbConfig usbgadget.Config) error { + getJsonRPCLogger().Debug().Interface("usbConfig", usbConfig).Msg("setting USB emulation state") LoadConfig() config.UsbConfig = &usbConfig gadget.SetGadgetConfig(config.UsbConfig) @@ -650,16 +666,28 @@ func rpcSetWakeOnLanDevices(params SetWakeOnLanDevicesParams) error { } func rpcResetConfig() error { + getJsonRPCLogger().Debug().Msg("resetting configuration to default") defaultConfig := getDefaultConfig() config = &defaultConfig if err := SaveConfig(); err != nil { return fmt.Errorf("failed to reset config: %w", err) } - logger.Info().Msg("Configuration reset to default") + getJsonRPCLogger().Info().Msg("Configuration reset to default") return nil } +func rpcGetLogLevel() string { + return config.DefaultLogLevel +} + +func rpcSetLogLevel(level string) error { + getJsonRPCLogger().Debug().Str("level", level).Msg("setting log level") + + config.DefaultLogLevel = level + return SaveConfig() +} + type DCPowerState struct { IsOn bool `json:"isOn"` Voltage float64 `json:"voltage"` @@ -673,7 +701,7 @@ func rpcGetDCPowerState() (DCPowerState, error) { } func rpcSetDCPowerState(enabled bool) error { - logger.Info().Bool("enabled", enabled).Msg("Setting DC power state") + getJsonRPCLogger().Debug().Bool("enabled", enabled).Msg("setting DC power state") err := setDCPowerState(enabled) if err != nil { return fmt.Errorf("failed to set DC power state: %w", err) @@ -682,7 +710,7 @@ func rpcSetDCPowerState(enabled bool) error { } func rpcSetDCRestoreState(state int) error { - logger.Info().Int("state", state).Msg("Setting DC restore state") + getJsonRPCLogger().Debug().Int("state", state).Msg("setting DC restore state") err := setDCRestoreState(state) if err != nil { return fmt.Errorf("failed to set DC restore state: %w", err) @@ -695,6 +723,7 @@ func rpcGetActiveExtension() (string, error) { } func rpcSetActiveExtension(extensionId string) error { + getJsonRPCLogger().Debug().Str("extensionId", extensionId).Msg("setting active extension") if config.ActiveExtension == extensionId { return nil } @@ -718,7 +747,8 @@ func rpcSetActiveExtension(extensionId string) error { } func rpcSetATXPowerAction(action string) error { - logger.Debug().Str("action", action).Msg("Executing ATX power action") + logger := getJsonRPCLogger().With().Str("action", action).Logger() + logger.Debug().Msg("Executing ATX power action") switch action { case "power-short": logger.Debug().Msg("Simulating short power button press") @@ -790,6 +820,9 @@ func rpcGetSerialSettings() (SerialSettings, error) { var serialPortMode = defaultMode func rpcSetSerialSettings(settings SerialSettings) error { + logger := getJsonRPCLogger().With().Interface("settings", settings).Logger() + logger.Debug().Msg("setting serial settings") + baudRate, err := strconv.Atoi(settings.BaudRate) if err != nil { return fmt.Errorf("invalid baud rate: %v", err) @@ -833,7 +866,9 @@ func rpcSetSerialSettings(settings SerialSettings) error { Parity: parity, } - _ = port.SetMode(serialPortMode) + if err := port.SetMode(serialPortMode); err != nil { + logger.Error().Err(err).Interface("serialPortMode", serialPortMode).Msg("setting port mode") + } return nil } @@ -853,12 +888,14 @@ func updateUsbRelatedConfig() error { } func rpcSetUsbDevices(usbDevices usbgadget.Devices) error { + getJsonRPCLogger().Debug().Interface("usbDevices", usbDevices).Msg("setting USB devices") config.UsbDevices = &usbDevices gadget.SetGadgetDevices(config.UsbDevices) return updateUsbRelatedConfig() } func rpcSetUsbDeviceState(device string, enabled bool) error { + getJsonRPCLogger().Debug().Str("device", device).Bool("enabled", enabled).Msg("setting USB device state") switch device { case "absoluteMouse": config.UsbDevices.AbsoluteMouse = enabled @@ -876,6 +913,8 @@ func rpcSetUsbDeviceState(device string, enabled bool) error { } func rpcSetCloudUrl(apiUrl string, appUrl string) error { + getJsonRPCLogger().Debug().Str("apiUrl", apiUrl).Str("appUrl", appUrl).Msg("setting cloud urls") + currentCloudURL := config.CloudURL config.CloudURL = apiUrl config.CloudAppURL = appUrl @@ -900,6 +939,8 @@ func rpcGetKeyboardLayout() (string, error) { } func rpcSetKeyboardLayout(layout string) error { + getJsonRPCLogger().Debug().Str("layout", layout).Msg("setting keyboard layout") + config.KeyboardLayout = layout if err := SaveConfig(); err != nil { return fmt.Errorf("failed to save config: %w", err) @@ -919,6 +960,7 @@ type KeyboardMacrosParams struct { } func setKeyboardMacros(params KeyboardMacrosParams) (any, error) { + getJsonRPCLogger().Debug().Interface("params", params).Msg("setting keyboard macros") if params.Macros == nil { return nil, fmt.Errorf("missing or invalid macros parameter") } @@ -1005,6 +1047,7 @@ func rpcGetLocalLoopbackOnly() (bool, error) { } func rpcSetLocalLoopbackOnly(enabled bool) error { + getJsonRPCLogger().Debug().Bool("enabled", enabled).Msg("setting local loopback only mode") // Check if the setting is actually changing if config.LocalLoopbackOnly == enabled { return nil @@ -1025,15 +1068,16 @@ var ( ) // cancelKeyboardMacro cancels any ongoing keyboard macro execution -func cancelKeyboardMacro() { +func cancelKeyboardMacro() error { keyboardMacroLock.Lock() defer keyboardMacroLock.Unlock() if keyboardMacroCancel != nil { keyboardMacroCancel() - logger.Info().Msg("canceled keyboard macro") keyboardMacroCancel = nil + getJsonRPCLogger().Info().Msg("cancelled keyboard macro") } + return nil } func setKeyboardMacroCancel(cancel context.CancelFunc) { @@ -1044,7 +1088,8 @@ func setKeyboardMacroCancel(cancel context.CancelFunc) { } func rpcExecuteKeyboardMacro(macro []hidrpc.KeyboardMacroStep) error { - cancelKeyboardMacro() + getJsonRPCLogger().Debug().Int("steps", len(macro)).Msg("executing keyboard macro") + _ = cancelKeyboardMacro() ctx, cancel := context.WithCancel(context.Background()) setKeyboardMacroCancel(cancel) @@ -1055,23 +1100,23 @@ func rpcExecuteKeyboardMacro(macro []hidrpc.KeyboardMacroStep) error { } if currentSession != nil { - currentSession.reportHidRPCKeyboardMacroState(s) + go currentSession.reportHidRPCKeyboardMacroState(s) } err := rpcDoExecuteKeyboardMacro(ctx, macro) - setKeyboardMacroCancel(nil) s.State = false if currentSession != nil { - currentSession.reportHidRPCKeyboardMacroState(s) + go currentSession.reportHidRPCKeyboardMacroState(s) } return err } -func rpcCancelKeyboardMacro() { - cancelKeyboardMacro() +func rpcCancelKeyboardMacro() error { + getJsonRPCLogger().Debug().Msg("cancelling keyboard macro") + return cancelKeyboardMacro() } var keyboardClearStateKeys = make([]byte, hidrpc.HidKeyBufferSize) @@ -1081,7 +1126,8 @@ func isClearKeyStep(step hidrpc.KeyboardMacroStep) bool { } func rpcDoExecuteKeyboardMacro(ctx context.Context, macro []hidrpc.KeyboardMacroStep) error { - logger.Debug().Interface("macro", macro).Msg("Executing keyboard macro") + logger := getJsonRPCLogger().With().Interface("macro", macro).Logger() + logger.Debug().Msg("Executing keyboard macro") for i, step := range macro { delay := time.Duration(step.Delay) * time.Millisecond @@ -1108,7 +1154,7 @@ func rpcDoExecuteKeyboardMacro(ctx context.Context, macro []hidrpc.KeyboardMacro logger.Warn().Err(err).Msg("failed to reset keyboard state") } - logger.Debug().Int("step", i).Msg("Keyboard macro cancelled during sleep") + logger.Debug().Int("step", i).Msg("Keyboard macro cancelled during step") return ctx.Err() } } @@ -1210,4 +1256,6 @@ var rpcHandlers = map[string]RPCHandler{ "getFailSafeLogs": {Func: rpcGetFailsafeLogs}, "getPublicIPAddresses": {Func: rpcGetPublicIPAddresses, Params: []string{"refresh"}}, "checkPublicIPAddresses": {Func: rpcCheckPublicIPAddresses}, + "getLogLevel": {Func: rpcGetLogLevel}, + "setLogLevel": {Func: rpcSetLogLevel, Params: []string{"level"}}, } diff --git a/log.go b/log.go deleted file mode 100644 index 9cd9188e6..000000000 --- a/log.go +++ /dev/null @@ -1,34 +0,0 @@ -package kvm - -import ( - "github.com/jetkvm/kvm/internal/logging" - "github.com/rs/zerolog" -) - -func ErrorfL(l *zerolog.Logger, format string, err error, args ...any) error { - return logging.ErrorfL(l, format, err, args...) -} - -var ( - logger = logging.GetSubsystemLogger("jetkvm") - failsafeLogger = logging.GetSubsystemLogger("failsafe") - networkLogger = logging.GetSubsystemLogger("network") - cloudLogger = logging.GetSubsystemLogger("cloud") - websocketLogger = logging.GetSubsystemLogger("websocket") - webrtcLogger = logging.GetSubsystemLogger("webrtc") - nativeLogger = logging.GetSubsystemLogger("native") - nbdLogger = logging.GetSubsystemLogger("nbd") - timesyncLogger = logging.GetSubsystemLogger("timesync") - jsonRpcLogger = logging.GetSubsystemLogger("jsonrpc") - hidRPCLogger = logging.GetSubsystemLogger("hidrpc") - watchdogLogger = logging.GetSubsystemLogger("watchdog") - websecureLogger = logging.GetSubsystemLogger("websecure") - otaLogger = logging.GetSubsystemLogger("ota") - serialLogger = logging.GetSubsystemLogger("serial") - terminalLogger = logging.GetSubsystemLogger("terminal") - displayLogger = logging.GetSubsystemLogger("display") - wolLogger = logging.GetSubsystemLogger("wol") - usbLogger = logging.GetSubsystemLogger("usb") - // external components - ginLogger = logging.GetSubsystemLogger("gin") -) diff --git a/main.go b/main.go index 83d337d7c..e84857c59 100644 --- a/main.go +++ b/main.go @@ -9,9 +9,10 @@ import ( "syscall" "time" + "github.com/jetkvm/kvm/internal/logging" + "github.com/erikdubbelboer/gspt" "github.com/gwatts/rootcerts" - "github.com/jetkvm/kvm/internal/ota" ) var appCtx context.Context @@ -27,13 +28,13 @@ func setProcTitle(status string) { func Main() { setProcTitle("starting") - + logger := logging.GetSubsystemLogger("jetkvm-main") logger.Log().Msg("JetKVM Starting Up") checkFailsafeReason() if failsafeModeActive { procPrefix = "jetkvm: [app+failsafe]" - logger.Warn().Str("reason", failsafeModeReason).Msg("failsafe mode activated") + logger.Warn().Str("subcomponent", "failsafe").Str("reason", failsafeModeReason).Msg("failsafe mode activated") } LoadConfig() @@ -104,50 +105,12 @@ func Main() { if err := initImagesFolder(); err != nil { logger.Warn().Err(err).Msg("failed to init images folder") } + initJiggler() // start video sleep mode timer startVideoSleepModeTicker() - go func() { - // wait for 15 minutes before starting auto-update checks - // this is to avoid interfering with initial setup processes - // and to ensure the system is stable before checking for updates - time.Sleep(15 * time.Minute) - - for { - logger.Info().Bool("auto_update_enabled", config.AutoUpdateEnabled).Msg("auto-update check") - if !config.AutoUpdateEnabled { - logger.Debug().Msg("auto-update disabled") - time.Sleep(5 * time.Minute) // we'll check if auto-updates are enabled in five minutes - continue - } - - if currentSession != nil { - logger.Debug().Msg("skipping update since a session is active") - time.Sleep(1 * time.Minute) - continue - } - - if isTimeSyncNeeded() || !timeSync.IsSyncSuccess() { - logger.Debug().Msg("system time is not synced, will retry in 30 seconds") - time.Sleep(30 * time.Second) - continue - } - - includePreRelease := config.IncludePreRelease - err = otaState.TryUpdate(context.Background(), ota.UpdateParams{ - DeviceID: GetDeviceID(), - IncludePreRelease: includePreRelease, - }) - if err != nil { - logger.Warn().Err(err).Msg("failed to auto update") - } - - time.Sleep(1 * time.Hour) - } - }() - //go RunFuseServer() go RunWebServer() @@ -159,26 +122,25 @@ func Main() { // As websocket client already checks if the cloud token is set, we can start it here. go RunWebsocketClient() - initPublicIPState() + initPublicIPState() initSerialPort() setProcTitle("ready") + logger.Log().Msg("JetKVM ready") + + go RunAutoUpdateCheck() sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - logger.Log().Msg("JetKVM Shutting Down") - //if fuseServer != nil { - // err := setMassStorageImage(" ") - // if err != nil { - // logger.Infof("Failed to unmount mass storage image: %v", err) - // } - // err = fuseServer.Unmount() - // if err != nil { - // logger.Infof("Failed to unmount fuse: %v", err) - // } + if err := rpcUnmountImage(); err != nil { + logger.Warn().Err(err).Msg("failed to eject virtual media on shutdown") + } + gadget.Close() + + logger.Log().Msg("JetKVM Shutting Down") // os.Exit(0) } diff --git a/mdns.go b/mdns.go index c197d16fa..788a8c74b 100644 --- a/mdns.go +++ b/mdns.go @@ -15,7 +15,6 @@ func initMdns() error { } m, err := mdns.NewMDNS(&mdns.MDNSOptions{ - Logger: logger, LocalNames: options.LocalNames, ListenOptions: options.ListenOptions, }) @@ -25,6 +24,5 @@ func initMdns() error { // do not start the server yet, as we need to wait for the network state to be set mDNS = m - return nil } diff --git a/native.go b/native.go index 5f09ef83a..0ab34c5f8 100644 --- a/native.go +++ b/native.go @@ -5,8 +5,9 @@ import ( "sync" "time" - "github.com/Masterminds/semver/v3" "github.com/jetkvm/kvm/internal/native" + + "github.com/Masterminds/semver/v3" "github.com/pion/webrtc/v4/pkg/media" ) @@ -16,6 +17,8 @@ var ( ) func initNative(systemVersion *semver.Version, appVersion *semver.Version) { + nativeLogger := native.GetNativeLogger() + if failsafeModeActive { nativeInstance = &native.EmptyNativeInterface{} nativeLogger.Warn().Msg("failsafe mode active, using empty native interface") diff --git a/network.go b/network.go index eb14c70d6..fac4de970 100644 --- a/network.go +++ b/network.go @@ -9,9 +9,11 @@ import ( "time" "github.com/jetkvm/kvm/internal/confparser" + "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/mdns" "github.com/jetkvm/kvm/internal/network/types" "github.com/jetkvm/kvm/internal/ota" + "github.com/jetkvm/kvm/pkg/myip" "github.com/jetkvm/kvm/pkg/nmlite" "github.com/jetkvm/kvm/pkg/nmlite/link" @@ -84,7 +86,7 @@ func restartMdns() { } if err := mDNS.SetOptions(options); err != nil { - networkLogger.Error().Err(err).Msg("failed to restart mDNS") + logging.GetSubsystemLogger("network").Error().Err(err).Msg("failed to restart mDNS") } } @@ -92,6 +94,7 @@ func triggerTimeSyncOnNetworkStateChange() { if timeSync == nil { return } + logger := logging.GetSubsystemLogger("network").With().Str("subcomponent", "timesync").Logger() // set the NTP servers from the network manager if networkManager != nil { @@ -99,14 +102,15 @@ func triggerTimeSyncOnNetworkStateChange() { for i, server := range networkManager.NTPServers() { ntpServers[i] = server.String() } - networkLogger.Info().Strs("ntpServers", ntpServers).Msg("setting NTP servers from network manager") + logger = logger.With().Interface("ntpServers", networkManager.NTPServers()).Logger() //TODO IPAddrs + logger.Info().Msg("setting NTP servers from network manager") timeSync.SetDhcpNtpAddresses(ntpServers) } // sync time go func() { if err := timeSync.Sync(); err != nil { - networkLogger.Error().Err(err).Msg("failed to sync time after network state change") + logger.Error().Err(err).Msg("failed to sync time after network state change") } }() } @@ -118,16 +122,18 @@ func setPublicIPReadyState(ipv4Ready, ipv6Ready bool) { publicIPState.SetIPv4AndIPv6(ipv4Ready, ipv6Ready) } -func networkStateChanged(_ string, state types.InterfaceState) { +func networkStateChanged(iface string, state types.InterfaceState) { + logger := logging.GetSubsystemLogger("network").With().Str("interface", iface).Logger() + // do not block the main thread go waitCtrlAndRequestDisplayUpdate(true, "network_state_changed") if currentSession != nil { - writeJSONRPCEvent("networkState", state.ToRpcInterfaceState(), currentSession) + go writeJSONRPCEvent("networkState", state.ToRpcInterfaceState(), currentSession) } if state.Online { - networkLogger.Info().Msg("network state changed to online, triggering time sync") + logger.Info().Msg("network state changed to online, triggering time sync") triggerTimeSyncOnNetworkStateChange() } @@ -145,16 +151,18 @@ func validateNetworkConfig() { return } - networkLogger.Error().Err(err).Msg("failed to validate config, reverting to default config") + logger := logging.GetSubsystemLogger("network") + logger.Error().Err(err).Msg("failed to validate config, reverting to default config") + if err := SaveBackupConfig(); err != nil { - networkLogger.Error().Err(err).Msg("failed to save backup config") + logger.Error().Err(err).Msg("failed to save backup config") } // do not use a pointer to the default config // it has been already changed during LoadConfig config.NetworkConfig = &(types.NetworkConfig{}) if err := SaveConfig(); err != nil { - networkLogger.Error().Err(err).Msg("failed to save config") + logger.Error().Err(err).Msg("failed to save config") } } @@ -166,9 +174,17 @@ func initNetwork() error { nc := config.NetworkConfig - nm := nmlite.NewNetworkManager(context.Background(), networkLogger) - networkLogger.Info().Interface("networkConfig", nc).Str("hostname", nc.Hostname.String).Str("domain", nc.Domain.String).Msg("initializing network manager") - _ = setHostname(nm, nc.Hostname.String, nc.Domain.String) + logger := logging.GetSubsystemLogger("network"). + With(). + Str("hostname", nc.Hostname.String). + Str("domain", nc.Domain.String). + Logger() + logger.Info().Interface("networkConfig", nc).Msg("initializing network manager") + nm := nmlite.NewNetworkManager(context.Background()) + + if err := setHostname(nm, nc.Hostname.String, nc.Domain.String); err != nil { + logger.Error().Err(err).Msg("failed to set hostname") + } nm.SetOnInterfaceStateChange(networkStateChanged) if err := nm.AddInterface(NetIfName, nc); err != nil { return fmt.Errorf("failed to add interface: %w", err) @@ -176,7 +192,6 @@ func initNetwork() error { _ = nm.CleanUpLegacyDHCPClients() networkManager = nm - return nil } @@ -186,7 +201,6 @@ func initPublicIPState() { // but it will be initialized anyway to avoid nil pointer dereferences ps := myip.NewPublicIPState(&myip.PublicIPStateConfig{ - Logger: networkLogger, CloudflareEndpoint: config.CloudURL, APIEndpoint: "", IPv4: false, @@ -229,7 +243,8 @@ func setHostname(nm *nmlite.NetworkManager, hostname, domain string) error { func shouldRebootForNetworkChange(oldConfig, newConfig *types.NetworkConfig) (rebootRequired bool, postRebootAction *ota.PostRebootAction) { oldDhcpClient := oldConfig.DHCPClient.String - l := networkLogger.With(). + logger := logging.GetSubsystemLogger("network"). + With(). Interface("old", oldConfig). Interface("new", newConfig). Logger() @@ -237,7 +252,7 @@ func shouldRebootForNetworkChange(oldConfig, newConfig *types.NetworkConfig) (re // DHCP client change always requires reboot if newConfig.DHCPClient.String != oldDhcpClient { rebootRequired = true - l.Info().Msg("DHCP client changed, reboot required") + logger.Info().Msg("DHCP client changed, reboot required") return rebootRequired, postRebootAction } @@ -247,14 +262,14 @@ func shouldRebootForNetworkChange(oldConfig, newConfig *types.NetworkConfig) (re // IPv4 mode change requires reboot if newIPv4Mode != oldIPv4Mode { rebootRequired = true - l.Info().Msg("IPv4 mode changed with udhcpc, reboot required") + logger.Info().Msg("IPv4 mode changed with udhcpc, reboot required") if newIPv4Mode == "static" && oldIPv4Mode != "static" { postRebootAction = &ota.PostRebootAction{ HealthCheck: fmt.Sprintf("//%s/device/status", newConfig.IPv4Static.Address.String), RedirectTo: fmt.Sprintf("//%s", newConfig.IPv4Static.Address.String), } - l.Info().Interface("postRebootAction", postRebootAction).Msg("IPv4 mode changed to static, reboot required") + logger.Info().Interface("postRebootAction", postRebootAction).Msg("IPv4 mode changed to static, reboot required") } return rebootRequired, postRebootAction @@ -272,7 +287,7 @@ func shouldRebootForNetworkChange(oldConfig, newConfig *types.NetworkConfig) (re RedirectTo: fmt.Sprintf("//%s", newConfig.IPv4Static.Address.String), } - l.Info().Interface("postRebootAction", postRebootAction).Msg("IPv4 static config changed, reboot required") + logger.Info().Interface("postRebootAction", postRebootAction).Msg("IPv4 static config changed, reboot required") } return rebootRequired, postRebootAction @@ -281,12 +296,12 @@ func shouldRebootForNetworkChange(oldConfig, newConfig *types.NetworkConfig) (re // IPv6 mode change requires reboot when using udhcpc if newConfig.IPv6Mode.String != oldConfig.IPv6Mode.String && oldDhcpClient == "udhcpc" { rebootRequired = true - l.Info().Msg("IPv6 mode changed with udhcpc, reboot required") + logger.Info().Msg("IPv6 mode changed with udhcpc, reboot required") } if newConfig.Hostname.String != oldConfig.Hostname.String { rebootRequired = true - l.Info().Msg("Hostname changed, reboot required") + logger.Info().Msg("Hostname changed, reboot required") } return rebootRequired, postRebootAction @@ -304,19 +319,20 @@ func rpcGetNetworkSettings() *RpcNetworkSettings { func rpcSetNetworkSettings(settings RpcNetworkSettings) (*RpcNetworkSettings, error) { netConfig := settings.ToNetworkConfig() - l := networkLogger.With(). + logger := logging.GetSubsystemLogger("network"). + With(). Str("interface", NetIfName). Interface("newConfig", netConfig). Logger() - l.Debug().Msg("setting new config") + logger.Debug().Msg("setting new config") // Check if reboot is needed rebootRequired, postRebootAction := shouldRebootForNetworkChange(config.NetworkConfig, netConfig) // If reboot required, send willReboot event before applying network config if rebootRequired { - l.Info().Msg("Sending willReboot event before applying network config") + logger.Info().Msg("Sending willReboot event before applying network config") writeJSONRPCEvent("willReboot", postRebootAction, currentSession) } @@ -326,7 +342,7 @@ func rpcSetNetworkSettings(settings RpcNetworkSettings) (*RpcNetworkSettings, er if s != nil { return nil, s } - l.Debug().Msg("new config applied") + logger.Debug().Msg("new config applied") newConfig, err := networkManager.GetInterfaceConfig(NetIfName) if err != nil { @@ -334,13 +350,13 @@ func rpcSetNetworkSettings(settings RpcNetworkSettings) (*RpcNetworkSettings, er } config.NetworkConfig = newConfig - l.Debug().Msg("saving new config") + logger.Debug().Msg("saving new config") if err := SaveConfig(); err != nil { return nil, err } if rebootRequired { - l.Info().Msg("Rebooting due to network changes") + logger.Info().Msg("Rebooting due to network changes") if err := hwReboot(true, postRebootAction, 0); err != nil { return nil, err } diff --git a/ota.go b/ota.go index ef7f9c21a..826cdb373 100644 --- a/ota.go +++ b/ota.go @@ -6,9 +6,11 @@ import ( "net/http" "os" "strings" + "time" "github.com/Masterminds/semver/v3" "github.com/google/uuid" + "github.com/jetkvm/kvm/internal/ota" ) @@ -17,29 +19,30 @@ var builtAppVersion = "0.1.0+dev" var otaState *ota.State func initOta() { - otaState = ota.NewState(ota.Options{ - Logger: otaLogger, - ReleaseAPIEndpoint: config.GetUpdateAPIURL(), - GetHTTPClient: func() ota.HttpClient { - transport := http.DefaultTransport.(*http.Transport).Clone() - transport.Proxy = config.NetworkConfig.GetTransportProxyFunc() - - client := &http.Client{ - Transport: transport, - } - return client - }, - GetLocalVersion: GetLocalVersion, - HwReboot: hwReboot, - ResetConfig: rpcResetConfig, - SetAutoUpdate: rpcSetAutoUpdateState, - OnStateUpdate: func(state *ota.RPCState) { - triggerOTAStateUpdate(state) + otaState = ota.NewState( + ota.Options{ + ReleaseAPIEndpoint: config.GetUpdateAPIURL(), + GetHTTPClient: func() ota.HttpClient { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Proxy = config.NetworkConfig.GetTransportProxyFunc() + + client := &http.Client{ + Transport: transport, + } + return client + }, + GetLocalVersion: GetLocalVersion, + HwReboot: hwReboot, + ResetConfig: rpcResetConfig, + SetAutoUpdate: rpcSetAutoUpdateState, + OnStateUpdate: func(state *ota.RPCState) { + triggerOTAStateUpdate(state) + }, + OnProgressUpdate: func(progress float32) { + writeJSONRPCEvent("otaProgress", progress, currentSession) + }, }, - OnProgressUpdate: func(progress float32) { - writeJSONRPCEvent("otaProgress", progress, currentSession) - }, - }) + ) } func triggerOTAStateUpdate(state *ota.RPCState) { @@ -50,6 +53,8 @@ func triggerOTAStateUpdate(state *ota.RPCState) { if state == nil { state = otaState.ToRPCState() } + ota.GetOtaLogger().Trace().Interface("state", state).Msg("Reporting OTA state") + writeJSONRPCEvent("otaState", state, currentSession) }() } @@ -80,11 +85,13 @@ func GetLocalVersion() (systemVersion *semver.Version, appVersion *semver.Versio } func getUpdateStatus(includePreRelease bool) (*ota.UpdateStatus, error) { - updateStatus, err := otaState.GetUpdateStatus(context.Background(), ota.UpdateParams{ - DeviceID: GetDeviceID(), - IncludePreRelease: includePreRelease, - RequestID: uuid.New().String(), - }) + updateStatus, err := otaState.GetUpdateStatus( + context.Background(), + ota.UpdateParams{ + DeviceID: GetDeviceID(), + IncludePreRelease: includePreRelease, + RequestID: uuid.New().String(), + }) // to ensure backwards compatibility, // if there's an error, we won't return an error, but we will set the error field @@ -100,7 +107,7 @@ func getUpdateStatus(includePreRelease bool) (*ota.UpdateStatus, error) { updateStatus.WillDisableAutoUpdate = config.AutoUpdateEnabled } - otaLogger.Info().Interface("updateStatus", updateStatus).Msg("Update status") + ota.GetOtaLogger().Info().Interface("updateStatus", updateStatus).Msg("Update status") return updateStatus, nil } @@ -160,6 +167,7 @@ func rpcCheckUpdateComponents(params updateParams, includePreRelease bool) (*ota IncludePreRelease: includePreRelease, Components: params.Components, } + info, err := otaState.GetUpdateStatus(context.Background(), updateParams) if err != nil { return nil, fmt.Errorf("failed to check update: %w", err) @@ -178,8 +186,53 @@ func rpcTryUpdateComponents(params updateParams, includePreRelease bool, resetCo go func() { err := otaState.TryUpdate(context.Background(), updateParams) if err != nil { - otaLogger.Warn().Err(err).Msg("failed to try update") + ota.GetOtaLogger().Warn().Err(err).Msg("failed to try update") } }() return nil } + +func RunAutoUpdateCheck() { + // initially wait for 15 minutes before starting auto-update checks + // to avoid interfering with initial setup processes and to ensure + // the system is stable before checking for updates + ticker := time.NewTicker(15 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-appCtx.Done(): + return + case <-ticker.C: + logger := ota.GetOtaLogger() + logger.Info().Bool("auto_update_enabled", config.AutoUpdateEnabled).Msg("auto-update check") + + if !config.AutoUpdateEnabled { + logger.Info().Msg("auto-update disabled, waiting 5 minutes") + ticker.Reset(5 * time.Minute) // we'll check if auto-updates are enabled in five minutes + continue + } + + if currentSession != nil { + logger.Info().Msg("skipping update since a session is active for one minute") + ticker.Reset(1 * time.Minute) + continue + } + + if isTimeSyncNeeded() || !timeSync.IsSyncSuccess() { + logger.Info().Msg("system time is not synced, will retry in 30 seconds") + ticker.Reset(30 * time.Second) + continue + } + + if err := otaState.TryUpdate(context.Background(), ota.UpdateParams{ + DeviceID: GetDeviceID(), + IncludePreRelease: config.IncludePreRelease, + }); err != nil { + logger.Warn().Err(err).Msg("failed to auto update") + } + + ticker.Reset(1 * time.Hour) + } + } +} diff --git a/pkg/myip/check.go b/pkg/myip/check.go index 86d3ba50a..2a1b07131 100644 --- a/pkg/myip/check.go +++ b/pkg/myip/check.go @@ -2,6 +2,7 @@ package myip import ( "context" + "errors" "fmt" "io" "net" @@ -92,11 +93,11 @@ func (ps *PublicIPState) checkAPI(_ context.Context, _ int) (*PublicIP, error) { // checkIPs checks both IPv4 and IPv6 public IP addresses in parallel // and updates the IPAddresses slice with the results -func (ps *PublicIPState) checkIPs(ctx context.Context, checkIPv4, checkIPv6 bool) error { +func (ps *PublicIPState) checkIPs(ctx context.Context) error { var wg sync.WaitGroup var mu sync.Mutex var ips []PublicIP - var errors []error + var errs []error checkFamily := func(family int, familyName string) { wg.Add(1) @@ -104,23 +105,22 @@ func (ps *PublicIPState) checkIPs(ctx context.Context, checkIPv4, checkIPv6 bool defer wg.Done() ip, err := ps.checkIPForFamily(ctx, f) + mu.Lock() defer mu.Unlock() if err != nil { - errors = append(errors, fmt.Errorf("%s check failed: %w", name, err)) - return - } - if ip != nil { + errs = append(errs, fmt.Errorf("%s check failed: %w", name, err)) + } else if ip != nil { ips = append(ips, *ip) } }(family, familyName) } - if checkIPv4 { + if ps.ipv4 { checkFamily(link.AfInet, "IPv4") } - if checkIPv6 { + if ps.ipv6 { checkFamily(link.AfInet6, "IPv6") } @@ -134,11 +134,7 @@ func (ps *PublicIPState) checkIPs(ctx context.Context, checkIPv4, checkIPv6 bool ps.lastUpdated = time.Now() } - if len(errors) > 0 && len(ips) == 0 { - return errors[0] - } - - return nil + return errors.Join(errs...) // returns nil if errs is empty otherwise combines errors } func (ps *PublicIPState) checkIPForFamily(ctx context.Context, family int) (*PublicIP, error) { diff --git a/pkg/myip/ip.go b/pkg/myip/ip.go index 15afc24e7..5261298b6 100644 --- a/pkg/myip/ip.go +++ b/pkg/myip/ip.go @@ -18,10 +18,23 @@ type PublicIP struct { LastUpdated time.Time `json:"last_updated"` } +func (pi PublicIP) MarshalZerologObject(e *zerolog.Event) { + e.IPAddr("address", pi.IPAddress) + e.Time("updated", pi.LastUpdated) +} + +type PublicIPs []PublicIP + +func (ips PublicIPs) MarshalZerologArray(a *zerolog.Array) { + for _, ip := range ips { + a.Object(ip) + } +} + type HttpClientGetter func(family int) *http.Client type PublicIPState struct { - addresses []PublicIP + addresses PublicIPs lastUpdated time.Time cloudflareEndpoint string // cdn-cgi/trace domain @@ -29,7 +42,6 @@ type PublicIPState struct { ipv4 bool ipv6 bool httpClient HttpClientGetter - logger *zerolog.Logger timer *time.Timer ctx context.Context @@ -37,13 +49,19 @@ type PublicIPState struct { mu sync.Mutex } +func (ps *PublicIPState) MarshalZerologObject(e *zerolog.Event) { + e.Array("addresses", ps.addresses) + e.Time("lastUpdated", ps.lastUpdated) + e.Bool("ipv4", ps.ipv4) + e.Bool("ipv6", ps.ipv6) +} + type PublicIPStateConfig struct { CloudflareEndpoint string APIEndpoint string IPv4 bool IPv6 bool HttpClientGetter HttpClientGetter - Logger *zerolog.Logger } func stripURLPath(s string) string { @@ -61,10 +79,6 @@ func stripURLPath(s string) string { // NewPublicIPState creates a new PublicIPState func NewPublicIPState(config *PublicIPStateConfig) *PublicIPState { - if config.Logger == nil { - config.Logger = logging.GetSubsystemLogger("publicip") - } - ctx, cancel := context.WithCancel(context.Background()) ps := &PublicIPState{ addresses: make([]PublicIP, 0), @@ -76,7 +90,6 @@ func NewPublicIPState(config *PublicIPStateConfig) *PublicIPState { httpClient: config.HttpClientGetter, ctx: ctx, cancel: cancel, - logger: config.Logger, } // Start the timer automatically ps.Start() @@ -155,7 +168,15 @@ func (ps *PublicIPState) Stop() { // ForceUpdate forces an update of the public IP addresses func (ps *PublicIPState) ForceUpdate() error { - return ps.checkIPs(context.Background(), true, true) + return ps.checkIPs(context.Background()) +} + +func (ps *PublicIPState) getLogger() *zerolog.Logger { + logger := logging.GetSubsystemLogger("myip"). + With(). + Object("ps", ps). + Logger() + return &logger } // timerLoop runs the periodic IP check loop @@ -166,14 +187,12 @@ func (ps *PublicIPState) timerLoop(ctx context.Context) { // Store timer reference for Stop() to access ps.mu.Lock() ps.timer = timer - checkIPv4 := ps.ipv4 - checkIPv6 := ps.ipv6 ps.mu.Unlock() // Perform initial check immediately checkIPs := func() { - if err := ps.checkIPs(ctx, checkIPv4, checkIPv6); err != nil { - ps.logger.Error().Err(err).Msg("failed to check public IP addresses") + if err := ps.checkIPs(ctx); err != nil { + ps.getLogger().Error().Err(err).Msg("failed to check public IP addresses") } } diff --git a/pkg/nmlite/dhcp.go b/pkg/nmlite/dhcp.go index 2a0c47b5a..b7fa75987 100644 --- a/pkg/nmlite/dhcp.go +++ b/pkg/nmlite/dhcp.go @@ -5,6 +5,7 @@ import ( "context" "fmt" + "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/network/types" "github.com/jetkvm/kvm/pkg/nmlite/jetdhcpc" "github.com/jetkvm/kvm/pkg/nmlite/udhcpc" @@ -15,7 +16,6 @@ import ( type DHCPClient struct { ctx context.Context ifaceName string - logger *zerolog.Logger client types.DHCPClient clientType string @@ -28,19 +28,14 @@ type DHCPClient struct { } // NewDHCPClient creates a new DHCP client -func NewDHCPClient(ctx context.Context, ifaceName string, logger *zerolog.Logger, clientType string) (*DHCPClient, error) { +func NewDHCPClient(ctx context.Context, ifaceName string, clientType string) (*DHCPClient, error) { if ifaceName == "" { return nil, fmt.Errorf("interface name cannot be empty") } - if logger == nil { - return nil, fmt.Errorf("logger cannot be nil") - } - return &DHCPClient{ ctx: ctx, ifaceName: ifaceName, - logger: logger, clientType: clientType, }, nil } @@ -77,6 +72,17 @@ func (dc *DHCPClient) initClient() (types.DHCPClient, error) { } } +func (dc *DHCPClient) getLogger() *zerolog.Logger { + logger := logging.GetSubsystemLogger("dhcp"). + With(). + Str("interface", dc.ifaceName). + Str("clientType", dc.clientType). + Bool("ipv4Enabled", dc.ipv4Enabled). + Bool("ipv6Enabled", dc.ipv6Enabled). + Logger() + return &logger +} + func (dc *DHCPClient) initJetDHCPC() (types.DHCPClient, error) { return jetdhcpc.NewClient(dc.ctx, []string{dc.ifaceName}, &jetdhcpc.Config{ IPv4: dc.ipv4Enabled, @@ -90,19 +96,18 @@ func (dc *DHCPClient) initJetDHCPC() (types.DHCPClient, error) { }, UpdateResolvConf: func(nameservers []string) error { // This will be handled by the resolv.conf manager - dc.logger.Debug(). + dc.getLogger().Debug(). Interface("nameservers", nameservers). Msg("DHCP client requested resolv.conf update") return nil }, - }, dc.logger) + }) } func (dc *DHCPClient) initUDHCPC() (types.DHCPClient, error) { c := udhcpc.NewDHCPClient(&udhcpc.DHCPClientOptions{ InterfaceName: dc.ifaceName, PidFile: "", - Logger: dc.logger, OnLeaseChange: func(lease *types.DHCPLease) { dc.handleLeaseChange(lease, false) }, @@ -112,12 +117,13 @@ func (dc *DHCPClient) initUDHCPC() (types.DHCPClient, error) { // Start starts the DHCP client func (dc *DHCPClient) Start() error { + logger := dc.getLogger() if dc.client != nil { - dc.logger.Warn().Msg("DHCP client already started") + logger.Warn().Msg("DHCP client already started") return nil } - dc.logger.Info().Msg("starting DHCP client") + logger.Info().Msg("starting DHCP client") // Create the underlying DHCP client client, err := dc.initClient() @@ -134,7 +140,7 @@ func (dc *DHCPClient) Start() error { return fmt.Errorf("failed to start DHCP client: %w", err) } - dc.logger.Info().Msg("DHCP client started") + logger.Info().Msg("DHCP client started") return nil } @@ -165,10 +171,11 @@ func (dc *DHCPClient) Stop() error { return nil } - dc.logger.Info().Msg("stopping DHCP client") + logger := dc.getLogger() + logger.Info().Msg("stopping DHCP client") dc.client = nil - dc.logger.Info().Msg("DHCP client stopped") + logger.Info().Msg("DHCP client stopped") return nil } @@ -178,7 +185,7 @@ func (dc *DHCPClient) Renew() error { return fmt.Errorf("DHCP client not started") } - dc.logger.Info().Msg("renewing DHCP lease") + dc.getLogger().Info().Msg("renewing DHCP lease") if err := dc.client.Renew(); err != nil { return fmt.Errorf("failed to renew DHCP lease: %w", err) } @@ -191,7 +198,7 @@ func (dc *DHCPClient) Release() error { return fmt.Errorf("DHCP client not started") } - dc.logger.Info().Msg("releasing DHCP lease") + dc.getLogger().Info().Msg("releasing DHCP lease") if err := dc.client.Release(); err != nil { return fmt.Errorf("failed to release DHCP lease: %w", err) } @@ -204,7 +211,7 @@ func (dc *DHCPClient) handleLeaseChange(lease *types.DHCPLease, isIPv6 bool) { return } - dc.logger.Info(). + dc.getLogger().Info(). Bool("ipv6", isIPv6). Str("ip", lease.IPAddress.String()). Msg("DHCP lease changed") diff --git a/pkg/nmlite/hostname.go b/pkg/nmlite/hostname.go index 88f35330c..0d654eebd 100644 --- a/pkg/nmlite/hostname.go +++ b/pkg/nmlite/hostname.go @@ -100,7 +100,7 @@ func (hm *ResolvConfManager) reconcileHostname() error { fqdn = fmt.Sprintf("%s.%s", hostname, domain) } - hm.logger.Info(). + hm.getLogger().Info(). Str("hostname", hostname). Str("fqdn", fqdn). Msg("setting hostname") @@ -120,7 +120,7 @@ func (hm *ResolvConfManager) reconcileHostname() error { return fmt.Errorf("failed to set system hostname: %w", err) } - hm.logger.Info(). + hm.getLogger().Info(). Str("hostname", hostname). Str("fqdn", fqdn). Msg("hostname set successfully") @@ -150,7 +150,7 @@ func (hm *ResolvConfManager) updateEtcHostname(hostname string) error { return fmt.Errorf("failed to write %s: %w", hostnamePath, err) } - hm.logger.Debug().Str("file", hostnamePath).Str("hostname", hostname).Msg("updated /etc/hostname") + hm.getLogger().Debug().Str("file", hostnamePath).Str("hostname", hostname).Msg("updated /etc/hostname") return nil } @@ -204,7 +204,7 @@ func (hm *ResolvConfManager) updateEtcHosts(hostname, fqdn string) error { return fmt.Errorf("failed to write %s: %w", hostsPath, err) } - hm.logger.Debug(). + hm.getLogger().Debug(). Str("file", hostsPath). Str("hostname", hostname). Str("fqdn", fqdn). @@ -220,7 +220,7 @@ func (hm *ResolvConfManager) setSystemHostname(hostname string) error { return fmt.Errorf("failed to run hostname command: %w", err) } - hm.logger.Debug().Str("hostname", hostname).Msg("set system hostname") + hm.getLogger().Debug().Str("hostname", hostname).Msg("set system hostname") return nil } diff --git a/pkg/nmlite/interface.go b/pkg/nmlite/interface.go index a9111a645..956d6886b 100644 --- a/pkg/nmlite/interface.go +++ b/pkg/nmlite/interface.go @@ -8,12 +8,14 @@ import ( "time" + "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/sync" "github.com/jetkvm/kvm/internal/confparser" - "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/network/types" + "github.com/jetkvm/kvm/pkg/nmlite/link" + "github.com/mdlayher/ndp" "github.com/rs/zerolog" "github.com/vishvananda/netlink" @@ -26,7 +28,6 @@ type InterfaceManager struct { ctx context.Context ifaceName string config *types.NetworkConfig - logger *zerolog.Logger state *types.InterfaceState linkState *link.Link stateMu sync.RWMutex @@ -47,17 +48,11 @@ type InterfaceManager struct { } // NewInterfaceManager creates a new interface manager -func NewInterfaceManager(ctx context.Context, ifaceName string, config *types.NetworkConfig, logger *zerolog.Logger) (*InterfaceManager, error) { +func NewInterfaceManager(ctx context.Context, ifaceName string, config *types.NetworkConfig) (*InterfaceManager, error) { if config == nil { return nil, fmt.Errorf("config cannot be nil") } - if logger == nil { - logger = logging.GetSubsystemLogger("interface") - } - - scopedLogger := logger.With().Str("interface", ifaceName).Logger() - // Validate and set defaults if err := confparser.SetDefaultsAndValidate(config); err != nil { return nil, fmt.Errorf("invalid config: %w", err) @@ -67,7 +62,6 @@ func NewInterfaceManager(ctx context.Context, ifaceName string, config *types.Ne ctx: ctx, ifaceName: ifaceName, config: config, - logger: &scopedLogger, state: &types.InterfaceState{ InterfaceName: ifaceName, // LastUpdated: time.Now(), @@ -77,26 +71,27 @@ func NewInterfaceManager(ctx context.Context, ifaceName string, config *types.Ne // Initialize components var err error - im.staticConfig, err = NewStaticConfigManager(ifaceName, &scopedLogger) + im.staticConfig, err = NewStaticConfigManager(ifaceName) if err != nil { return nil, fmt.Errorf("failed to create static config manager: %w", err) } // create the dhcp client - im.dhcpClient, err = NewDHCPClient(ctx, ifaceName, &scopedLogger, config.DHCPClient.String) + im.dhcpClient, err = NewDHCPClient(ctx, ifaceName, config.DHCPClient.String) if err != nil { return nil, fmt.Errorf("failed to create DHCP client: %w", err) } // Set up DHCP client callbacks im.dhcpClient.SetOnLeaseChange(func(lease *types.DHCPLease) { + logger := im.getLogger() if im.config.IPv4Mode.String != "dhcp" { - im.logger.Warn().Str("mode", im.config.IPv4Mode.String).Msg("ignoring DHCP lease, current mode is not DHCP") + logger.Warn().Str("mode", im.config.IPv4Mode.String).Msg("ignoring DHCP lease, current mode is not DHCP") return } if err := im.applyDHCPLease(lease); err != nil { - im.logger.Error().Err(err).Msg("failed to apply DHCP lease") + logger.Error().Err(err).Msg("failed to apply DHCP lease") } im.updateStateFromDHCPLease(lease) if im.onDHCPLeaseChange != nil { @@ -107,12 +102,26 @@ func NewInterfaceManager(ctx context.Context, ifaceName string, config *types.Ne return im, nil } +func (im *InterfaceManager) getLogger() *zerolog.Logger { + logger := logging.GetSubsystemLogger("nmlite").With().Str("interface", im.ifaceName).Logger() + + if logging.IsTraceLevel(&logger) { + logger = logger. + With(). + Object("state", im.state). + Logger() + } + + return &logger +} + // Start starts managing the interface func (im *InterfaceManager) Start() error { im.stateMu.Lock() defer im.stateMu.Unlock() - im.logger.Info().Msg("starting interface manager") + logger := im.getLogger() + logger.Info().Msg("starting interface manager") // Start monitoring interface state im.wg.Add(1) @@ -143,22 +152,23 @@ func (im *InterfaceManager) Start() error { }) if linkUpErr != nil { - im.logger.Error().Err(linkUpErr).Msg("failed to bring interface up, continuing anyway") + logger.Error().Err(linkUpErr).Msg("failed to bring interface up, continuing anyway") } else { // Apply initial configuration if err := im.applyConfiguration(); err != nil { - im.logger.Error().Err(err).Msg("failed to apply initial configuration") + logger.Error().Err(err).Msg("failed to apply initial configuration") return err } } - im.logger.Info().Msg("interface manager started") + logger.Info().Msg("interface manager started") return nil } // Stop stops managing the interface func (im *InterfaceManager) Stop() error { - im.logger.Info().Msg("stopping interface manager") + logger := im.getLogger() + logger.Info().Msg("stopping interface manager") close(im.stopCh) im.wg.Wait() @@ -170,7 +180,7 @@ func (im *InterfaceManager) Stop() error { } } - im.logger.Info().Msg("interface manager stopped") + logger.Info().Msg("interface manager stopped") return nil } @@ -301,9 +311,9 @@ func (im *InterfaceManager) GetState() *types.InterfaceState { im.stateMu.RLock() defer im.stateMu.RUnlock() - // Return a copy to avoid race conditions - im.logger.Debug().Interface("state", im.state).Msg("getting interface state") + im.getLogger().Debug().Msg("getting interface state") + // Return a copy to avoid race conditions state := *im.state return &state } @@ -366,7 +376,7 @@ func (im *InterfaceManager) SetConfig(config *types.NetworkConfig) error { // Apply the new configuration if err := im.applyConfiguration(); err != nil { - im.logger.Error().Err(err).Msg("failed to apply new configuration") + im.getLogger().Error().Err(err).Msg("failed to apply new configuration") return err } @@ -375,7 +385,7 @@ func (im *InterfaceManager) SetConfig(config *types.NetworkConfig) error { im.onConfigChange(config) } - im.logger.Info().Msg("configuration updated") + im.getLogger().Info().Msg("configuration updated") return nil } @@ -410,8 +420,6 @@ func (im *InterfaceManager) SetOnResolvConfChange(callback ResolvConfChangeCallb // applyConfiguration applies the current configuration to the interface func (im *InterfaceManager) applyConfiguration() error { - im.logger.Info().Msg("applying configuration") - // Apply IPv4 configuration if err := im.applyIPv4Config(); err != nil { return fmt.Errorf("failed to apply IPv4 config: %w", err) @@ -428,7 +436,7 @@ func (im *InterfaceManager) applyConfiguration() error { // applyIPv4Config applies IPv4 configuration func (im *InterfaceManager) applyIPv4Config() error { mode := im.config.IPv4Mode.String - im.logger.Info().Str("mode", mode).Msg("applying IPv4 configuration") + im.getLogger().Info().Str("mode", mode).Msg("applying IPv4 configuration") switch mode { case "static": @@ -445,7 +453,7 @@ func (im *InterfaceManager) applyIPv4Config() error { // applyIPv6Config applies IPv6 configuration func (im *InterfaceManager) applyIPv6Config() error { mode := im.config.IPv6Mode.String - im.logger.Info().Str("mode", mode).Msg("applying IPv6 configuration") + im.getLogger().Info().Str("mode", mode).Msg("applying IPv6 configuration") switch mode { case "static": @@ -471,27 +479,27 @@ func (im *InterfaceManager) applyIPv4Static() error { return fmt.Errorf("IPv4 static configuration is nil") } - im.logger.Info().Msg("stopping DHCP") + im.getLogger().Info().Msg("stopping DHCP") // Disable DHCP if im.dhcpClient != nil { im.dhcpClient.SetIPv4(false) } - im.logger.Info().Interface("config", im.config.IPv4Static).Msg("applying IPv4 static configuration") + im.getLogger().Info().Interface("config", im.config.IPv4Static).Msg("applying IPv4 static configuration") config, err := im.staticConfig.ToIPv4Static(im.config.IPv4Static) if err != nil { return fmt.Errorf("failed to convert IPv4 static configuration: %w", err) } - im.logger.Info().Interface("config", config).Msg("converted IPv4 static configuration") + im.getLogger().Info().Interface("config", config).Msg("converted IPv4 static configuration") if err := im.onResolvConfChange(link.AfInet, &types.InterfaceResolvConf{ NameServers: config.Nameservers, Source: "static", }); err != nil { - im.logger.Warn().Err(err).Msg("failed to update resolv.conf") + im.getLogger().Warn().Err(err).Msg("failed to update resolv.conf") } return im.ReconcileLinkAddrs(config.Addresses, link.AfInet) @@ -525,7 +533,7 @@ func (im *InterfaceManager) applyIPv6Static() error { return fmt.Errorf("IPv6 static configuration is nil") } - im.logger.Info().Msg("stopping DHCPv6") + im.getLogger().Info().Msg("stopping DHCPv6") // Disable DHCPv6 if im.dhcpClient != nil { im.dhcpClient.SetIPv6(false) @@ -536,13 +544,13 @@ func (im *InterfaceManager) applyIPv6Static() error { if err != nil { return fmt.Errorf("failed to convert IPv6 static configuration: %w", err) } - im.logger.Info().Interface("config", config).Msg("converted IPv6 static configuration") + im.getLogger().Info().Interface("config", config).Msg("converted IPv6 static configuration") if err := im.onResolvConfChange(link.AfInet6, &types.InterfaceResolvConf{ NameServers: config.Nameservers, Source: "static", }); err != nil { - im.logger.Warn().Err(err).Msg("failed to update resolv.conf") + im.getLogger().Warn().Err(err).Msg("failed to update resolv.conf") } return im.ReconcileLinkAddrs(config.Addresses, link.AfInet6) @@ -584,7 +592,7 @@ func (im *InterfaceManager) applyIPv6SLAAC() error { } if err := im.SendRouterSolicitation(); err != nil { - im.logger.Error().Err(err).Msg("failed to send router solicitation, continuing anyway") + im.getLogger().Error().Err(err).Msg("failed to send router solicitation, continuing anyway") } // Enable SLAAC @@ -638,7 +646,7 @@ func (im *InterfaceManager) handleLinkStateChange(link *link.Link) { im.linkState = link } - im.logger.Info().Interface("link", link).Msg("link state changed") + im.getLogger().Info().Interface("link", link).Msg("link state changed") operState := link.Attrs().OperState if operState == netlink.OperUp { @@ -650,7 +658,7 @@ func (im *InterfaceManager) handleLinkStateChange(link *link.Link) { // SendRouterSolicitation sends a router solicitation func (im *InterfaceManager) SendRouterSolicitation() error { - im.logger.Info().Msg("sending router solicitation") + im.getLogger().Info().Msg("sending router solicitation") m := &ndp.RouterSolicitation{} l, err := im.link() @@ -689,51 +697,53 @@ func (im *InterfaceManager) SendRouterSolicitation() error { return fmt.Errorf("failed to write to %s: %w", targetAddr.String(), err) } - im.logger.Info().Msg("router solicitation sent") + im.getLogger().Info().Msg("router solicitation sent") c.Close() return nil } func (im *InterfaceManager) handleLinkUp() { - im.logger.Info().Msg("link up") + logger := im.getLogger() + logger.Info().Msg("link up") if err := im.applyConfiguration(); err != nil { - im.logger.Error().Err(err).Msg("failed to apply configuration") + logger.Error().Err(err).Msg("failed to apply configuration") } if im.config.IPv4Mode.String == "dhcp" { if err := im.dhcpClient.Renew(); err != nil { - im.logger.Error().Err(err).Msg("failed to renew DHCP lease") + logger.Error().Err(err).Msg("failed to renew DHCP lease") } } if im.config.IPv6Mode.String == "slaac" { if err := im.staticConfig.EnableIPv6SLAAC(); err != nil { - im.logger.Error().Err(err).Msg("failed to enable IPv6 SLAAC") + logger.Error().Err(err).Msg("failed to enable IPv6 SLAAC") } if err := im.SendRouterSolicitation(); err != nil { - im.logger.Error().Err(err).Msg("failed to send router solicitation") + logger.Error().Err(err).Msg("failed to send router solicitation") } } } func (im *InterfaceManager) handleLinkDown() { - im.logger.Info().Msg("link down") + logger := im.getLogger() + logger.Info().Msg("link down") if im.config.IPv4Mode.String == "dhcp" { if err := im.dhcpClient.Stop(); err != nil { - im.logger.Error().Err(err).Msg("failed to stop DHCP client") + logger.Error().Err(err).Msg("failed to stop DHCP client") } } netlinkMgr := getNetlinkManager() if err := netlinkMgr.RemoveAllAddresses(im.linkState, link.AfInet); err != nil { - im.logger.Error().Err(err).Msg("failed to remove all IPv4 addresses") + logger.Error().Err(err).Msg("failed to remove all IPv4 addresses") } if err := netlinkMgr.RemoveNonLinkLocalIPv6Addresses(im.linkState); err != nil { - im.logger.Error().Err(err).Msg("failed to remove non-link-local IPv6 addresses") + logger.Error().Err(err).Msg("failed to remove non-link-local IPv6 addresses") } } @@ -741,7 +751,8 @@ func (im *InterfaceManager) handleLinkDown() { func (im *InterfaceManager) monitorInterfaceState() { defer im.wg.Done() - im.logger.Debug().Msg("monitoring interface state") + logger := im.getLogger() + logger.Debug().Msg("monitoring interface state") // TODO: use netlink subscription instead of polling ticker := time.NewTicker(5 * time.Second) defer ticker.Stop() @@ -754,7 +765,7 @@ func (im *InterfaceManager) monitorInterfaceState() { return case <-ticker.C: if err := im.updateInterfaceState(); err != nil { - im.logger.Error().Err(err).Msg("failed to update interface state") + logger.Error().Err(err).Msg("failed to update interface state") } } } @@ -778,8 +789,10 @@ func (im *InterfaceManager) updateStateFromDHCPLease(lease *types.DHCPLease) { return } + logger := im.getLogger() + if im.ifaceName == "" { - im.logger.Warn().Msg("interface name is empty, skipping resolv.conf update") + logger.Warn().Msg("interface name is empty, skipping resolv.conf update") return } @@ -789,7 +802,7 @@ func (im *InterfaceManager) updateStateFromDHCPLease(lease *types.DHCPLease) { Source: "dhcp", Domain: lease.Domain, }); err != nil { - im.logger.Warn().Err(err).Msg("failed to update resolv.conf") + logger.Warn().Err(err).Msg("failed to update resolv.conf") } } @@ -812,13 +825,15 @@ func (im *InterfaceManager) applyDHCPLease(lease *types.DHCPLease) error { return fmt.Errorf("DHCP lease is nil") } + logger := im.getLogger() + if lease.DHCPClient != "jetdhcpc" { - im.logger.Warn().Str("dhcp_client", lease.DHCPClient).Msg("ignoring DHCP lease, not implemented yet") + logger.Warn().Str("dhcp_client", lease.DHCPClient).Msg("ignoring DHCP lease, not implemented yet") return nil } if lease.IsIPv6() { - im.logger.Warn().Msg("ignoring IPv6 DHCP lease, not implemented yet") + logger.Warn().Msg("ignoring IPv6 DHCP lease, not implemented yet") return nil } @@ -844,7 +859,7 @@ func (im *InterfaceManager) convertDHCPLeaseToIPv4Config(lease *types.DHCPLease) Permanent: false, } - im.logger.Trace(). + im.getLogger().Trace(). Interface("ipv4Addr", ipv4Addr). Interface("lease", lease). Msg("converted DHCP lease to IPv4Config") diff --git a/pkg/nmlite/interface_state.go b/pkg/nmlite/interface_state.go index efa5f087b..aa6540393 100644 --- a/pkg/nmlite/interface_state.go +++ b/pkg/nmlite/interface_state.go @@ -87,7 +87,7 @@ func (im *InterfaceManager) updateInterfaceState() error { // Update IP addresses if ipChanged, err := im.updateInterfaceStateAddresses(nl); err != nil { - im.logger.Error().Err(err).Msg("failed to update IP addresses") + im.getLogger().Error().Err(err).Msg("failed to update IP addresses") } else if ipChanged { stateChanged = true changeReasons = append(changeReasons, IfStateIPAddressesChanged) @@ -98,9 +98,8 @@ func (im *InterfaceManager) updateInterfaceState() error { // Notify callback if state changed if stateChanged && im.onStateChange != nil { - im.logger.Debug(). + im.getLogger().Debug(). Stringer("changeReasons", changeReasons). - Interface("state", im.state). Msg("notifying state change") im.onStateChange(*im.state) } @@ -214,9 +213,8 @@ func (im *InterfaceManager) updateInterfaceStateAddresses(nl *link.Link) (bool, } if stateChanged { - im.logger.Trace(). + im.getLogger().Trace(). Str("changeReason", stateChangeReason). - Interface("state", im.state). Msg("interface state changed") } diff --git a/pkg/nmlite/jetdhcpc/client.go b/pkg/nmlite/jetdhcpc/client.go index 102d3bee0..68a327e7c 100644 --- a/pkg/nmlite/jetdhcpc/client.go +++ b/pkg/nmlite/jetdhcpc/client.go @@ -8,12 +8,14 @@ import ( "time" + "github.com/jetkvm/kvm/internal/logging" + "github.com/jetkvm/kvm/internal/network/types" "github.com/jetkvm/kvm/internal/sync" + "github.com/jetkvm/kvm/pkg/nmlite/link" "github.com/insomniacslk/dhcp/dhcpv4" "github.com/insomniacslk/dhcp/dhcpv6" - "github.com/jetkvm/kvm/internal/network/types" "github.com/rs/zerolog" ) @@ -89,7 +91,6 @@ type Client struct { ifaces []string cfg Config - l *zerolog.Logger ctx context.Context @@ -116,7 +117,7 @@ var ( ) // NewClient creates a new DHCP client for the given interface. -func NewClient(ctx context.Context, ifaces []string, c *Config, l *zerolog.Logger) (*Client, error) { +func NewClient(ctx context.Context, ifaces []string, c *Config) (*Client, error) { timer4 := time.NewTimer(defaultTimerDuration) timer6 := time.NewTimer(defaultTimerDuration) @@ -137,7 +138,6 @@ func NewClient(ctx context.Context, ifaces []string, c *Config, l *zerolog.Logge ctx: ctx, ifaces: ifaces, cfg: cfg, - l: l, stateDir: "/run/jetkvm-dhcp", currentLease4: nil, @@ -154,14 +154,14 @@ func NewClient(ctx context.Context, ifaces []string, c *Config, l *zerolog.Logge }, nil } -func resetTimer(t *time.Timer, attempt int, l *zerolog.Logger) { +func (c *Client) resetTimer(t *time.Timer, attempt int) { // Exponential backoff: 1s, 2s, 4s, 8s, max 8s backoffAttempt := attempt if backoffAttempt > 3 { backoffAttempt = 3 } delay := time.Duration(1< 0 { - l.Info().Int("attempt", attempt-1).Msg("interface is up") + logger.Info().Int("attempt", attempt-1).Msg("interface is up") } return link, nil } - l.Info().Msg("bringing up interface") + logger.Info().Msg("bringing up interface") // bring up the interface if err = nm.LinkSetUp(link); err != nil { - l.Error().Err(err).Msg("interface can't make it up") + logger.Error().Err(err).Msg("interface can't make it up") } // refresh the link attributes if err = link.Refresh(); err != nil { - l.Error().Err(err).Msg("failed to refresh link attributes") + logger.Error().Err(err).Msg("failed to refresh link attributes") } // check the state again state = link.Attrs().OperState - l = l.With().Str("new_state", state.String()).Logger() + logger = logger.With().Stringer("new_state", state).Logger() + if state == netlink.OperUp { - l.Info().Msg("interface is up") + logger.Info().Msg("interface is up") return link, nil } - l.Warn().Msg("interface is still down, retrying") + logger.Warn().Msg("interface is still down, retrying") select { case <-time.After(500 * time.Millisecond): @@ -210,8 +217,7 @@ func (nm *NetlinkManager) EnsureInterfaceUpWithTimeout(ctx context.Context, ifac return nil, ErrInterfaceUpCanceled case <-linkUpTimeout: attempt++ - l.Error(). - Int("attempt", attempt).Msg("interface is still down after timeout") + logger.Error().Int("attempt", attempt).Msg("interface is still down after timeout") if err != nil { return nil, err } @@ -252,7 +258,7 @@ func (nm *NetlinkManager) RemoveAllAddresses(link *Link, family int) error { for _, addr := range addrs { if err := nm.AddrDel(link, &addr); err != nil { - nm.logger.Warn().Err(err).Str("address", addr.IP.String()).Msg("failed to remove address") + nm.getLogger().Warn().Err(err).IPAddr("address", addr.IP).Msg("failed to remove address") } } @@ -269,7 +275,7 @@ func (nm *NetlinkManager) RemoveNonLinkLocalIPv6Addresses(link *Link) error { for _, addr := range addrs { if !addr.IP.IsLinkLocalUnicast() { if err := nm.AddrDel(link, &addr); err != nil { - nm.logger.Warn().Err(err).Str("address", addr.IP.String()).Msg("failed to remove IPv6 address") + nm.getLogger().Warn().Err(err).IPAddr("address", addr.IP).Msg("failed to remove IPv6 address") } } } @@ -313,7 +319,7 @@ func (nm *NetlinkManager) ListDefaultRoutes(family int) ([]netlink.Route, error) netlink.RT_FILTER_DST|netlink.RT_FILTER_TABLE, ) if err != nil { - nm.logger.Error().Err(err).Int("family", family).Msg("failed to list default routes") + nm.getLogger().Error().Err(err).Int("family", family).Msg("failed to list default routes") return nil, err } @@ -359,14 +365,16 @@ func (nm *NetlinkManager) RemoveDefaultRoute(family int) error { for _, route := range routes { if route.Dst != nil { + logger := nm.getLogger() + if family == AfInet && route.Dst.IP.Equal(net.IPv4zero) && route.Dst.Mask.String() == "0.0.0.0/0" { if err := nm.RouteDel(&route); err != nil { - nm.logger.Warn().Err(err).Msg("failed to remove IPv4 default route") + logger.Warn().Err(err).Msg("failed to remove IPv4 default route") } } if family == AfInet6 && route.Dst.IP.Equal(net.IPv6zero) && route.Dst.Mask.String() == "::/0" { if err := nm.RouteDel(&route); err != nil { - nm.logger.Warn().Err(err).Msg("failed to remove IPv6 default route") + logger.Warn().Err(err).Msg("failed to remove IPv6 default route") } } } @@ -386,6 +394,8 @@ func (nm *NetlinkManager) reconcileDefaultRoute(link *Link, expected map[string] return fmt.Errorf("failed to get default routes: %w", err) } + logger := nm.getLogger() + // check existing default routes for _, defaultRoute := range defaultRoutes { // only check the default routes for the current link @@ -400,21 +410,22 @@ func (nm *NetlinkManager) reconcileDefaultRoute(link *Link, expected map[string] continue } - nm.logger.Warn().Str("gateway", key).Msg("keeping default route") + logger.Warn().Str("gateway", key).Msg("keeping default route") delete(expected, key) } // remove remaining default routes for _, defaultRoute := range toRemove { - nm.logger.Warn().Str("gateway", defaultRoute.Gw.String()).Msg("removing default route") + logger.Info().Str("gateway", defaultRoute.Gw.String()).Msg("removing default route") + if err := nm.RouteDel(defaultRoute); err != nil { - nm.logger.Warn().Err(err).Msg("failed to remove default route") + logger.Warn().Err(err).Msg("failed to remove default route") } } // add remaining expected default routes for _, gateway := range expected { - nm.logger.Warn().Str("gateway", gateway.String()).Msg("adding default route") + logger.Info().Str("gateway", gateway.String()).Msg("adding default route") route := &netlink.Route{ Dst: &ipv4DefaultRoute, @@ -425,16 +436,12 @@ func (nm *NetlinkManager) reconcileDefaultRoute(link *Link, expected map[string] route.Dst = &ipv6DefaultRoute } if err := nm.RouteAdd(route); err != nil { - nm.logger.Warn().Err(err).Interface("route", route).Msg("failed to add default route") + logger.Warn().Err(err).Interface("route", route).Msg("failed to add default route") } added++ } - nm.logger.Info(). - Int("added", added). - Int("removed", len(toRemove)). - Msg("default routes reconciled") - + logger.Info().Int("added", added).Int("removed", len(toRemove)).Msg("default routes reconciled") return nil } @@ -444,11 +451,15 @@ func (nm *NetlinkManager) ReconcileLink(link *Link, expected []types.IPAddress, toRemove := make([]*netlink.Addr, 0) toUpdate := make([]*types.IPAddress, 0) expectedAddrs := make(map[string]*types.IPAddress) - expectedGateways := make(map[string]net.IP) + ifname := link.Attrs().Name + linkIndex := link.Attrs().Index + logger := nm.getLogger().With().Str("interface", ifname).Int("index", linkIndex).Logger() + mtu := link.Attrs().MTU expectedMTU := mtu + // add all expected addresses to the map for _, addr := range expected { expectedAddrs[addr.String()] = &addr @@ -461,7 +472,7 @@ func (nm *NetlinkManager) ReconcileLink(link *Link, expected []types.IPAddress, } if expectedMTU != mtu { if err := link.SetMTU(expectedMTU); err != nil { - nm.logger.Warn().Err(err).Int("expected_mtu", expectedMTU).Int("mtu", mtu).Msg("failed to set MTU") + logger.Warn().Err(err).Int("expected_mtu", expectedMTU).Int("mtu", mtu).Msg("failed to set MTU") } } @@ -501,35 +512,28 @@ func (nm *NetlinkManager) ReconcileLink(link *Link, expected []types.IPAddress, for _, addr := range toUpdate { netlinkAddr := addr.NetlinkAddr() if err := nm.AddrDel(link, &netlinkAddr); err != nil { - nm.logger.Warn().Err(err).Str("address", addr.Address.String()).Msg("failed to update address") + logger.Warn().Err(err).IPPrefix("address", addr.Address).Msg("failed to update address") } // we'll add it again later toAdd = append(toAdd, addr) } - for _, addr := range toAdd { - netlinkAddr := addr.NetlinkAddr() - if err := nm.AddrAdd(link, &netlinkAddr); err != nil { - nm.logger.Warn().Err(err).Str("address", addr.Address.String()).Msg("failed to add address") - } - } - for _, netlinkAddr := range toRemove { if err := nm.AddrDel(link, netlinkAddr); err != nil { - nm.logger.Warn().Err(err).Str("address", netlinkAddr.IP.String()).Msg("failed to remove address") + logger.Warn().Err(err).IPAddr("address", netlinkAddr.IP).Msg("failed to remove address") } } for _, addr := range toAdd { netlinkAddr := addr.NetlinkAddr() if err := nm.AddrAdd(link, &netlinkAddr); err != nil { - nm.logger.Warn().Err(err).Str("address", addr.Address.String()).Msg("failed to add address") + logger.Warn().Err(err).IPPrefix("address", addr.Address).Msg("failed to add address") } } actualToAdd := len(toAdd) - len(toUpdate) if len(toAdd) > 0 || len(toUpdate) > 0 || len(toRemove) > 0 { - nm.logger.Info(). + logger.Info(). Int("added", actualToAdd). Int("updated", len(toUpdate)). Int("removed", len(toRemove)). @@ -537,7 +541,7 @@ func (nm *NetlinkManager) ReconcileLink(link *Link, expected []types.IPAddress, } if err := nm.reconcileDefaultRoute(link, expectedGateways, family); err != nil { - nm.logger.Warn().Err(err).Msg("failed to reconcile default route") + logger.Warn().Err(err).Msg("failed to reconcile default route") } return nil diff --git a/pkg/nmlite/manager.go b/pkg/nmlite/manager.go index 03496d9ea..145f18853 100644 --- a/pkg/nmlite/manager.go +++ b/pkg/nmlite/manager.go @@ -7,20 +7,19 @@ import ( "context" "fmt" - "github.com/jetkvm/kvm/internal/sync" - "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/network/types" + "github.com/jetkvm/kvm/internal/sync" + "github.com/rs/zerolog" + "github.com/jetkvm/kvm/pkg/nmlite/jetdhcpc" "github.com/jetkvm/kvm/pkg/nmlite/link" - "github.com/rs/zerolog" ) // NetworkManager manages multiple network interfaces type NetworkManager struct { interfaces map[string]*InterfaceManager mu sync.RWMutex - logger *zerolog.Logger ctx context.Context cancel context.CancelFunc @@ -33,25 +32,25 @@ type NetworkManager struct { } // NewNetworkManager creates a new network manager -func NewNetworkManager(ctx context.Context, logger *zerolog.Logger) *NetworkManager { - if logger == nil { - logger = logging.GetSubsystemLogger("nm") - } - +func NewNetworkManager(ctx context.Context) *NetworkManager { // Initialize the NetlinkManager singleton - link.InitializeNetlinkManager(logger) + link.InitializeNetlinkManager() ctx, cancel := context.WithCancel(ctx) return &NetworkManager{ interfaces: make(map[string]*InterfaceManager), - logger: logger, ctx: ctx, cancel: cancel, - resolvConf: NewResolvConfManager(logger), + resolvConf: NewResolvConfManager(), } } +func (nm *NetworkManager) getLogger() *zerolog.Logger { + logging := logging.GetSubsystemLogger("nmlite").With().Int("interface_count", len(nm.interfaces)).Logger() + return &logging +} + // SetHostname sets the hostname and domain for the network manager func (nm *NetworkManager) SetHostname(hostname string, domain string) error { return nm.resolvConf.SetHostname(hostname, domain) @@ -71,7 +70,7 @@ func (nm *NetworkManager) AddInterface(iface string, config *types.NetworkConfig return fmt.Errorf("interface %s already managed", iface) } - im, err := NewInterfaceManager(nm.ctx, iface, config, nm.logger) + im, err := NewInterfaceManager(nm.ctx, iface, config) if err != nil { return fmt.Errorf("failed to create interface manager for %s: %w", iface, err) } @@ -108,7 +107,7 @@ func (nm *NetworkManager) AddInterface(iface string, config *types.NetworkConfig return fmt.Errorf("failed to start interface manager for %s: %w", iface, err) } - nm.logger.Info().Str("interface", iface).Msg("added interface to network manager") + nm.getLogger().Info().Str("interface", iface).Msg("added interface to network manager") return nil } @@ -123,11 +122,11 @@ func (nm *NetworkManager) RemoveInterface(iface string) error { } if err := im.Stop(); err != nil { - nm.logger.Error().Err(err).Str("interface", iface).Msg("failed to stop interface manager") + nm.getLogger().Error().Err(err).Str("interface", iface).Msg("failed to stop interface manager") } delete(nm.interfaces, iface) - nm.logger.Info().Str("interface", iface).Msg("removed interface from network manager") + nm.getLogger().Info().Str("interface", iface).Msg("removed interface from network manager") return nil } @@ -236,7 +235,7 @@ func (nm *NetworkManager) shouldKillLegacyDHCPClients() bool { func (nm *NetworkManager) CleanUpLegacyDHCPClients() error { shouldKill := nm.shouldKillLegacyDHCPClients() if shouldKill { - return jetdhcpc.KillUdhcpC(nm.logger) + return jetdhcpc.KillUdhcpC(nm.getLogger()) } return nil } @@ -249,12 +248,12 @@ func (nm *NetworkManager) Stop() error { var lastErr error for iface, im := range nm.interfaces { if err := im.Stop(); err != nil { - nm.logger.Error().Err(err).Str("interface", iface).Msg("failed to stop interface manager") + nm.getLogger().Error().Err(err).Str("interface", iface).Msg("failed to stop interface manager") lastErr = err } } nm.cancel() - nm.logger.Info().Msg("network manager stopped") + nm.getLogger().Info().Msg("network manager stopped") return lastErr } diff --git a/pkg/nmlite/resolvconf.go b/pkg/nmlite/resolvconf.go index b94f5c3db..f101b3656 100644 --- a/pkg/nmlite/resolvconf.go +++ b/pkg/nmlite/resolvconf.go @@ -7,6 +7,7 @@ import ( "os" "strings" + "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/network/types" "github.com/jetkvm/kvm/internal/sync" "github.com/jetkvm/kvm/pkg/nmlite/link" @@ -39,24 +40,17 @@ var ( // ResolvConfManager manages the resolv.conf file type ResolvConfManager struct { - logger *zerolog.Logger - mu sync.Mutex - conf *types.ResolvConf + mu sync.Mutex + conf *types.ResolvConf hostname string domain string } // NewResolvConfManager creates a new resolv.conf manager -func NewResolvConfManager(logger *zerolog.Logger) *ResolvConfManager { - if logger == nil { - // Create a no-op logger if none provided - logger = &zerolog.Logger{} - } - +func NewResolvConfManager() *ResolvConfManager { return &ResolvConfManager{ - logger: logger, - mu: sync.Mutex{}, + mu: sync.Mutex{}, conf: &types.ResolvConf{ ConfigIPv4: make(map[string]types.InterfaceResolvConf), ConfigIPv6: make(map[string]types.InterfaceResolvConf), @@ -107,12 +101,21 @@ func (rcm *ResolvConfManager) Reconcile() error { return rcm.update() } +func (rcm *ResolvConfManager) getLogger() *zerolog.Logger { + logging := logging.GetSubsystemLogger("nmlite"). + With(). + Str("subcomponent", "resolvconf"). + Logger() + return &logging +} + // Update updates the resolv.conf file func (rcm *ResolvConfManager) update() error { rcm.mu.Lock() defer rcm.mu.Unlock() - rcm.logger.Debug().Msg("updating resolv.conf") + logger := rcm.getLogger() + logger.Debug().Msg("updating resolv.conf") // Generate resolv.conf content content, err := rcm.generateResolvConf(rcm.conf) @@ -124,11 +127,11 @@ func (rcm *ResolvConfManager) update() error { if _, err := os.Stat(resolvConfPath); err == nil { existingContent, err := os.ReadFile(resolvConfPath) if err != nil { - rcm.logger.Warn().Err(err).Msg("failed to read existing resolv.conf") + logger.Warn().Err(err).Msg("failed to read existing resolv.conf") } if bytes.Equal(existingContent, content) { - rcm.logger.Debug().Msg("resolv.conf is the same, skipping write") + logger.Debug().Msg("resolv.conf is the same, skipping write") return nil } } @@ -138,10 +141,7 @@ func (rcm *ResolvConfManager) update() error { return fmt.Errorf("failed to write resolv.conf: %w", err) } - rcm.logger.Info(). - Interface("config", rcm.conf). - Msg("resolv.conf updated successfully") - + logger.Info().Interface("config", rcm.conf).Msg("resolv.conf updated successfully") return nil } @@ -192,7 +192,7 @@ func (rcm *ResolvConfManager) generateResolvConf(conf *types.ResolvConf) ([]byte mergeConfig(&nameservers, &searchList, &conf.ConfigIPv4) mergeConfig(&nameservers, &searchList, &conf.ConfigIPv6) - rcm.logger.Info(). + rcm.getLogger().Info(). Interface("nameservers", nameservers). Interface("searchList", searchList). Msg("merged config") diff --git a/pkg/nmlite/static.go b/pkg/nmlite/static.go index 9500556b6..584cd74a0 100644 --- a/pkg/nmlite/static.go +++ b/pkg/nmlite/static.go @@ -4,6 +4,7 @@ import ( "fmt" "net" + "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/network/types" "github.com/jetkvm/kvm/pkg/nmlite/link" "github.com/rs/zerolog" @@ -12,25 +13,27 @@ import ( // StaticConfigManager manages static network configuration type StaticConfigManager struct { ifaceName string - logger *zerolog.Logger } // NewStaticConfigManager creates a new static configuration manager -func NewStaticConfigManager(ifaceName string, logger *zerolog.Logger) (*StaticConfigManager, error) { +func NewStaticConfigManager(ifaceName string) (*StaticConfigManager, error) { if ifaceName == "" { return nil, fmt.Errorf("interface name cannot be empty") } - if logger == nil { - return nil, fmt.Errorf("logger cannot be nil") - } - return &StaticConfigManager{ ifaceName: ifaceName, - logger: logger, }, nil } +func (scm *StaticConfigManager) getLogger() *zerolog.Logger { + logging := logging.GetSubsystemLogger("nmlite"). + With(). + Str("interface", scm.ifaceName). + Logger() + return &logging +} + // ToIPv4Static applies static IPv4 configuration func (scm *StaticConfigManager) ToIPv4Static(config *types.IPv4StaticConfig) (*types.ParsedIPConfig, error) { if config == nil { @@ -42,7 +45,6 @@ func (scm *StaticConfigManager) ToIPv4Static(config *types.IPv4StaticConfig) (*t if err != nil { return nil, err } - scm.logger.Info().Str("ipNet", ipNet.String()).Interface("ipc", config).Msg("parsed IPv4 address and netmask") // Parse gateway gateway := net.ParseIP(config.Gateway.String) @@ -119,7 +121,7 @@ func (scm *StaticConfigManager) ToIPv6Static(config *types.IPv6StaticConfig) (*t // DisableIPv4 disables IPv4 on the interface func (scm *StaticConfigManager) DisableIPv4() error { - scm.logger.Info().Msg("disabling IPv4") + scm.getLogger().Info().Msg("disabling IPv4") netlinkMgr := getNetlinkManager() iface, err := netlinkMgr.GetLinkByName(scm.ifaceName) @@ -134,30 +136,30 @@ func (scm *StaticConfigManager) DisableIPv4() error { // Remove default route if err := scm.removeIPv4DefaultRoute(); err != nil { - scm.logger.Warn().Err(err).Msg("failed to remove IPv4 default route") + scm.getLogger().Warn().Err(err).Msg("failed to remove IPv4 default route") } - scm.logger.Info().Msg("IPv4 disabled") + scm.getLogger().Info().Msg("IPv4 disabled") return nil } // DisableIPv6 disables IPv6 on the interface func (scm *StaticConfigManager) DisableIPv6() error { - scm.logger.Info().Msg("disabling IPv6") + scm.getLogger().Info().Msg("disabling IPv6") netlinkMgr := getNetlinkManager() return netlinkMgr.DisableIPv6(scm.ifaceName) } // EnableIPv6SLAAC enables IPv6 SLAAC func (scm *StaticConfigManager) EnableIPv6SLAAC() error { - scm.logger.Info().Msg("enabling IPv6 SLAAC") + scm.getLogger().Info().Msg("enabling IPv6 SLAAC") netlinkMgr := getNetlinkManager() return netlinkMgr.EnableIPv6SLAAC(scm.ifaceName) } // EnableIPv6LinkLocal enables IPv6 link-local only func (scm *StaticConfigManager) EnableIPv6LinkLocal() error { - scm.logger.Info().Msg("enabling IPv6 link-local only") + scm.getLogger().Info().Msg("enabling IPv6 link-local only") netlinkMgr := getNetlinkManager() if err := netlinkMgr.EnableIPv6LinkLocal(scm.ifaceName); err != nil { diff --git a/pkg/nmlite/udhcpc/proc.go b/pkg/nmlite/udhcpc/proc.go index 69c2ab99e..70a9cce72 100644 --- a/pkg/nmlite/udhcpc/proc.go +++ b/pkg/nmlite/udhcpc/proc.go @@ -76,7 +76,7 @@ func (p *DHCPClient) findUdhcpcProcess() (int, error) { // check if it's a udhcpc process if strings.Contains(cmdlineText, fmt.Sprintf("-i %s", p.InterfaceName)) { - p.logger.Debug(). + p.getLogger().Debug(). Str("pid", d.Name()). Interface("cmdline", cmdline). Msg("found udhcpc process") @@ -93,7 +93,7 @@ func (c *DHCPClient) getProcessPid() (int, error) { // try to read the pid file pidHandle, err := os.ReadFile(c.pidFile) if err != nil { - c.logger.Warn().Err(err). + c.getLogger().Warn().Err(err). Str("pidFile", c.pidFile).Msg("failed to read udhcpc pid file") } @@ -101,7 +101,7 @@ func (c *DHCPClient) getProcessPid() (int, error) { if pidHandle != nil { pidFromFile, err := strconv.Atoi(string(pidHandle)) if err != nil { - c.logger.Warn().Err(err). + c.getLogger().Warn().Err(err). Str("pidFile", c.pidFile).Msg("failed to convert pid file to int") } pid = pidFromFile @@ -128,7 +128,7 @@ func (c *DHCPClient) getProcess() *os.Process { process, err := os.FindProcess(pid) if err != nil { - c.logger.Warn().Err(err). + c.getLogger().Warn().Err(err). Int("pid", pid).Msg("failed to find process") return nil } @@ -152,15 +152,15 @@ func (c *DHCPClient) GetProcess() *os.Process { c.process = nil c.process = c.getProcess() if c.process == nil { - c.logger.Error().Msg("failed to find new udhcpc process") + c.getLogger().Error().Msg("failed to find new udhcpc process") return nil } - c.logger.Warn(). + c.getLogger().Warn(). Int("oldPid", oldPid). Int("newPid", c.process.Pid). Msg("udhcpc process pid changed") } else if err != nil { - c.logger.Warn().Err(err). + c.getLogger().Warn().Err(err). Int("pid", c.process.Pid).Msg("udhcpc process is not running") } @@ -193,7 +193,7 @@ func (c *DHCPClient) signalProcess(sig syscall.Signal) error { s := process.Signal(sig) if s != nil { - c.logger.Warn().Err(s). + c.getLogger().Warn().Err(s). Int("pid", process.Pid). Str("signal", sig.String()). Msg("failed to signal udhcpc process") diff --git a/pkg/nmlite/udhcpc/udhcpc.go b/pkg/nmlite/udhcpc/udhcpc.go index 19ce2cb51..d434b6b3f 100644 --- a/pkg/nmlite/udhcpc/udhcpc.go +++ b/pkg/nmlite/udhcpc/udhcpc.go @@ -9,11 +9,12 @@ import ( "time" + "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/sync" + "github.com/rs/zerolog" "github.com/fsnotify/fsnotify" "github.com/jetkvm/kvm/internal/network/types" - "github.com/rs/zerolog" ) const ( @@ -27,7 +28,6 @@ type DHCPClient struct { leaseFile string pidFile string lease *Lease - logger *zerolog.Logger process *os.Process runOnce sync.Once onLeaseChange func(lease *types.DHCPLease) @@ -36,27 +36,29 @@ type DHCPClient struct { type DHCPClientOptions struct { InterfaceName string PidFile string - Logger *zerolog.Logger OnLeaseChange func(lease *types.DHCPLease) } -var defaultLogger = zerolog.New(os.Stdout).Level(zerolog.InfoLevel) - func NewDHCPClient(options *DHCPClientOptions) *DHCPClient { - if options.Logger == nil { - options.Logger = &defaultLogger - } - - l := options.Logger.With().Str("interface", options.InterfaceName).Logger() return &DHCPClient{ InterfaceName: options.InterfaceName, - logger: &l, leaseFile: fmt.Sprintf(DHCPLeaseFile, options.InterfaceName), pidFile: options.PidFile, onLeaseChange: options.OnLeaseChange, } } +func (c *DHCPClient) getLogger() *zerolog.Logger { + logger := logging.GetSubsystemLogger("nmlite"). + With(). + Str("subcomponent", "udhcpc"). + Str("interface", c.InterfaceName). + Str("pidFile", c.pidFile). + Str("leaseFile", c.leaseFile). + Logger() + return &logger +} + func (c *DHCPClient) getWatchPaths() []string { watchPaths := make(map[string]any) watchPaths[filepath.Dir(c.leaseFile)] = nil @@ -98,7 +100,7 @@ func (c *DHCPClient) run() error { } if event.Name == c.leaseFile { - c.logger.Debug(). + c.getLogger().Debug(). Str("event", event.Op.String()). Str("path", event.Name). Msg("udhcpc lease file updated, reloading lease") @@ -108,7 +110,7 @@ func (c *DHCPClient) run() error { if !ok { return } - c.logger.Error().Err(err).Msg("error watching lease file") + c.getLogger().Error().Err(err).Msg("error watching lease file") } } }() @@ -116,7 +118,7 @@ func (c *DHCPClient) run() error { for _, path := range c.getWatchPaths() { err = watcher.Add(path) if err != nil { - c.logger.Error(). + c.getLogger().Error(). Err(err). Str("path", path). Msg("failed to watch directory") @@ -145,7 +147,7 @@ func (c *DHCPClient) loadLeaseFile() error { data := string(file) if data == "" { - c.logger.Debug().Msg("udhcpc lease file is empty") + c.getLogger().Debug().Msg("udhcpc lease file is empty") return nil } @@ -165,7 +167,7 @@ func (c *DHCPClient) loadLeaseFile() error { c.lease = lease if lease.IPAddress == nil { - c.logger.Info(). + c.getLogger().Info(). Interface("lease", lease). Str("data", string(file)). Msg("udhcpc lease cleared") @@ -179,20 +181,20 @@ func (c *DHCPClient) loadLeaseFile() error { leaseExpiry, err := lease.SetLeaseExpiry() if err != nil { - c.logger.Error().Err(err).Msg("failed to get dhcp lease expiry") + c.getLogger().Error().Err(err).Msg("failed to get dhcp lease expiry") } else { expiresIn := time.Until(leaseExpiry) - c.logger.Info(). + c.getLogger().Info(). Interface("expiry", leaseExpiry). - Str("expiresIn", expiresIn.String()). + Dur("expiresIn", expiresIn). Msg("current dhcp lease expiry time calculated") } c.onLeaseChange(lease.ToDHCPLease()) - c.logger.Info(). - Str("ip", lease.IPAddress.String()). - Str("leaseTime", lease.LeaseTime.String()). + c.getLogger().Info(). + IPAddr("ip", lease.IPAddress). + Dur("leaseTime", lease.LeaseTime). Interface("data", lease). Msg(msg) @@ -236,7 +238,7 @@ func (c *DHCPClient) Start() error { go func() { err := c.run() if err != nil { - c.logger.Error().Err(err).Msg("failed to run udhcpc") + c.getLogger().Error().Err(err).Msg("failed to run udhcpc") } }() }) diff --git a/scripts/dev_deploy.sh b/scripts/dev_deploy.sh index 96e7cf60d..dd21ed77c 100755 --- a/scripts/dev_deploy.sh +++ b/scripts/dev_deploy.sh @@ -261,14 +261,16 @@ if [ "$INSTALL_APP" = true ] then msg_info "▶ Building release binary" do_make build_release \ - SKIP_NATIVE_IF_EXISTS=${SKIP_NATIVE_BUILD} \ - SKIP_UI_BUILD=${SKIP_UI_BUILD_RELEASE} \ - ENABLE_SYNC_TRACE=${ENABLE_SYNC_TRACE} + SKIP_NATIVE_IF_EXISTS=${SKIP_NATIVE_BUILD} \ + SKIP_UI_BUILD=${SKIP_UI_BUILD_RELEASE} \ + ENABLE_SYNC_TRACE=${ENABLE_SYNC_TRACE} # Copy the binary to the remote host as if we were the OTA updater. + msg_info "▶ Copying the application update to the remote host" sshdev "cat > /userdata/jetkvm/jetkvm_app.update" < bin/jetkvm_app # Reboot the device, the new app will be deployed by the startup process. + msg_info "▶ Rebooting the remote host" sshdev "reboot" else msg_info "▶ Building development binary" @@ -277,27 +279,9 @@ else SKIP_UI_BUILD=${SKIP_UI_BUILD_RELEASE} \ ENABLE_SYNC_TRACE=${ENABLE_SYNC_TRACE} - # Kill any existing instances of the application - sshdev "killall jetkvm_app_debug || true" - - # Copy the binary to the remote host - sshdev "cat > ${REMOTE_PATH}/jetkvm_app_debug" < bin/jetkvm_app - - if [ "$RESET_USB_HID_DEVICE" = true ]; then - msg_info "▶ Resetting USB HID device" - msg_warn "The option has been deprecated and will be removed in a future version, as JetKVM will now reset USB gadget configuration when needed" - # Remove the old USB gadget configuration - sshdev "rm -rf /sys/kernel/config/usb_gadget/jetkvm/configs/c.1/hid.usb*" - sshdev "ls /sys/class/udc > /sys/kernel/config/usb_gadget/jetkvm/UDC" - fi - - # Deploy and run the application on the remote host + # Kill any existing instances of the application on the remote host + msg_info "▶ Killing any running instances of the application on the remote host" sshdev ash << EOF -set -e - -# Set the library path to include the directory where librockit.so is located -export LD_LIBRARY_PATH=/oem/usr/lib:\$LD_LIBRARY_PATH - # Kill any existing instances of the application killall jetkvm_app || true killall jetkvm_app_debug || true @@ -312,6 +296,18 @@ while [ \$i -le 10 ]; do sleep 1 i=\$((i + 1)) done +EOF + + # Copy the binary to the remote host + msg_info "▶ Copying the application to the remote host" + sshdev "cat > ${REMOTE_PATH}/jetkvm_app_debug" < bin/jetkvm_app + + # Deploy and run the application on the remote host + msg_info "▶ Starting the application on the remote host" + logs=$(printenv | grep '^JETKVM_LOG' | sed "s/=/='/; s/$/'/") + sshdev ${logs} ash << EOF +# Set the library path to include the directory where librockit.so is located +export LD_LIBRARY_PATH=/oem/usr/lib:\$LD_LIBRARY_PATH # Navigate to the directory where the binary will be stored cd "${REMOTE_PATH}" @@ -320,7 +316,7 @@ cd "${REMOTE_PATH}" chmod +x jetkvm_app_debug # Run the application in the background -PION_LOG_TRACE=${LOG_TRACE_SCOPES} ./jetkvm_app_debug | tee -a /tmp/jetkvm_app_debug.log +./jetkvm_app_debug | tee -a /tmp/jetkvm_app_debug.log EOF fi diff --git a/serial.go b/serial.go index 5439d135a..cf6cfc110 100644 --- a/serial.go +++ b/serial.go @@ -7,7 +7,10 @@ import ( "strings" "time" + "github.com/jetkvm/kvm/internal/logging" + "github.com/jetkvm/kvm/internal/utils" "github.com/pion/webrtc/v4" + "github.com/rs/zerolog" "go.bug.st/serial" ) @@ -35,19 +38,19 @@ var ( ) func runATXControl() { - scopedLogger := serialLogger.With().Str("service", "atx_control").Logger() + logger := logging.GetSubsystemLogger("serial").With().Str("service", "atx_control").Logger() reader := bufio.NewReader(port) for { line, err := reader.ReadString('\n') if err != nil { - scopedLogger.Warn().Err(err).Msg("Error reading from serial port") + logger.Warn().Err(err).Msg("Error reading from serial port") return } // Each line should be 4 binary digits + newline if len(line) != 5 { - scopedLogger.Warn().Int("length", len(line)).Msg("Invalid line length") + logger.Warn().Int("length", len(line)).Msg("Invalid line length") continue } @@ -68,7 +71,7 @@ func runATXControl() { newLedPWRState != ledPWRState || newBtnRSTState != btnRSTState || newBtnPWRState != btnPWRState { - scopedLogger.Debug(). + logger.Debug(). Bool("hdd", newLedHDDState). Bool("pwr", newLedPWRState). Bool("rst", newBtnRSTState). @@ -141,40 +144,40 @@ func unmountDCControl() error { var dcState DCPowerState func runDCControl() { - scopedLogger := serialLogger.With().Str("service", "dc_control").Logger() reader := bufio.NewReader(port) hasRestoreFeature := false for { + logger := logging.GetSubsystemLogger("serial").With().Str("service", "dc_control").Logger() line, err := reader.ReadString('\n') if err != nil { - scopedLogger.Warn().Err(err).Msg("Error reading from serial port") + logger.Warn().Err(err).Msg("Error reading from serial port") return } // Split the line by semicolon parts := strings.Split(strings.TrimSpace(line), ";") if len(parts) == 5 { - scopedLogger.Debug().Str("line", line).Msg("Detected DC extension with restore feature") + logger.Debug().Str("line", line).Msg("Detected DC extension with restore feature") hasRestoreFeature = true } else if len(parts) == 4 { - scopedLogger.Debug().Str("line", line).Msg("Detected DC extension without restore feature") + logger.Debug().Str("line", line).Msg("Detected DC extension without restore feature") hasRestoreFeature = false } else { - scopedLogger.Warn().Str("line", line).Msg("Invalid line") + logger.Warn().Str("line", line).Msg("Invalid line") continue } // Parse new states powerState, err := strconv.Atoi(parts[0]) if err != nil { - scopedLogger.Warn().Err(err).Msg("Invalid power state") + logger.Warn().Err(err).Msg("Invalid power state") continue } dcState.IsOn = powerState == 1 if hasRestoreFeature { restoreState, err := strconv.Atoi(parts[4]) if err != nil { - scopedLogger.Warn().Err(err).Msg("Invalid restore state") + logger.Warn().Err(err).Msg("Invalid restore state") continue } dcState.RestoreState = restoreState @@ -184,21 +187,21 @@ func runDCControl() { } milliVolts, err := strconv.ParseFloat(parts[1], 64) if err != nil { - scopedLogger.Warn().Err(err).Msg("Invalid voltage") + logger.Warn().Err(err).Msg("Invalid voltage") continue } volts := milliVolts / 1000 // Convert mV to V milliAmps, err := strconv.ParseFloat(parts[2], 64) if err != nil { - scopedLogger.Warn().Err(err).Msg("Invalid current") + logger.Warn().Err(err).Msg("Invalid current") continue } amps := milliAmps / 1000 // Convert mA to A milliWatts, err := strconv.ParseFloat(parts[3], 64) if err != nil { - scopedLogger.Warn().Err(err).Msg("Invalid power") + logger.Warn().Err(err).Msg("Invalid power") continue } watts := milliWatts / 1000 // Convert mW to W @@ -275,7 +278,8 @@ func reopenSerialPort() error { var err error port, err = serial.Open(serialPortPath, defaultMode) if err != nil { - serialLogger.Error(). + logging.GetSubsystemLogger("serial"). + Error(). Err(err). Str("path", serialPortPath). Interface("mode", defaultMode). @@ -284,10 +288,12 @@ func reopenSerialPort() error { return nil } -func handleSerialChannel(d *webrtc.DataChannel) { - scopedLogger := serialLogger.With(). - Uint16("data_channel_id", *d.ID()).Logger() +func getLogger(d *webrtc.DataChannel) *zerolog.Logger { + logger := logging.GetSubsystemLogger("serial").With().Uint16("data_channel_id", *d.ID()).Logger() + return &logger +} +func handleSerialChannel(d *webrtc.DataChannel) { d.OnOpen(func() { go func() { buf := make([]byte, 1024) @@ -295,13 +301,13 @@ func handleSerialChannel(d *webrtc.DataChannel) { n, err := port.Read(buf) if err != nil { if err != io.EOF { - scopedLogger.Warn().Err(err).Msg("Failed to read from serial port") + getLogger(d).Warn().Err(err).Msg("Failed to read from serial port") } break } err = d.Send(buf[:n]) if err != nil { - scopedLogger.Warn().Err(err).Msg("Failed to send serial output") + getLogger(d).Warn().Err(err).Msg("Failed to send serial output") break } } @@ -312,17 +318,18 @@ func handleSerialChannel(d *webrtc.DataChannel) { if port == nil { return } + getLogger(d).Debug().Object("data", utils.ByteSlice(msg.Data)).Msg("Writing to serial port") _, err := port.Write(msg.Data) if err != nil { - scopedLogger.Warn().Err(err).Msg("Failed to write to serial") + getLogger(d).Warn().Err(err).Object("data", utils.ByteSlice(msg.Data)).Msg("Failed to write to serial") } }) d.OnError(func(err error) { - scopedLogger.Warn().Err(err).Msg("Serial channel error") + getLogger(d).Warn().Err(err).Msg("Serial channel error") }) d.OnClose(func() { - scopedLogger.Info().Msg("Serial channel closed") + getLogger(d).Info().Msg("Serial channel closed") }) } diff --git a/terminal.go b/terminal.go index e06e5cdc1..0d163edfb 100644 --- a/terminal.go +++ b/terminal.go @@ -8,7 +8,9 @@ import ( "os/exec" "github.com/creack/pty" + "github.com/jetkvm/kvm/internal/logging" "github.com/pion/webrtc/v4" + "github.com/rs/zerolog" ) type TerminalSize struct { @@ -16,10 +18,12 @@ type TerminalSize struct { Cols int `json:"cols"` } -func handleTerminalChannel(d *webrtc.DataChannel) { - scopedLogger := terminalLogger.With(). - Uint16("data_channel_id", *d.ID()).Logger() +func getTerminalLogger(d *webrtc.DataChannel) *zerolog.Logger { + logger := logging.GetSubsystemLogger("terminal").With().Uint16("data_channel_id", *d.ID()).Logger() + return &logger +} +func handleTerminalChannel(d *webrtc.DataChannel) { var ptmx *os.File var cmd *exec.Cmd d.OnOpen(func() { @@ -27,7 +31,7 @@ func handleTerminalChannel(d *webrtc.DataChannel) { var err error ptmx, err = pty.Start(cmd) if err != nil { - scopedLogger.Warn().Err(err).Msg("Failed to start pty") + getTerminalLogger(d).Warn().Err(err).Msg("Failed to start pty") d.Close() return } @@ -38,13 +42,13 @@ func handleTerminalChannel(d *webrtc.DataChannel) { n, err := ptmx.Read(buf) if err != nil { if err != io.EOF { - scopedLogger.Warn().Err(err).Msg("Failed to read from pty") + getTerminalLogger(d).Warn().Err(err).Msg("Failed to read from pty") } break } err = d.Send(buf[:n]) if err != nil { - scopedLogger.Warn().Err(err).Msg("Failed to send pty output") + getTerminalLogger(d).Warn().Err(err).Msg("Failed to send pty output") break } } @@ -67,16 +71,16 @@ func handleTerminalChannel(d *webrtc.DataChannel) { Cols: uint16(size.Cols), }) if err == nil { - scopedLogger.Info().Int("rows", size.Rows).Int("cols", size.Cols).Msg("Set terminal size") + getTerminalLogger(d).Info().Int("rows", size.Rows).Int("cols", size.Cols).Msg("Set terminal size") return } } - scopedLogger.Warn().Err(err).Msg("Failed to parse terminal size") + getTerminalLogger(d).Warn().Err(err).Msg("Failed to parse terminal size") } } _, err := ptmx.Write(msg.Data) if err != nil { - scopedLogger.Warn().Err(err).Msg("Failed to write to pty") + getTerminalLogger(d).Warn().Err(err).Msg("Failed to write to pty") } }) @@ -87,10 +91,10 @@ func handleTerminalChannel(d *webrtc.DataChannel) { if cmd != nil && cmd.Process != nil { _ = cmd.Process.Kill() } - scopedLogger.Info().Msg("Terminal channel closed") + getTerminalLogger(d).Info().Msg("Terminal channel closed") }) d.OnError(func(err error) { - scopedLogger.Warn().Err(err).Msg("Terminal channel error") + getTerminalLogger(d).Warn().Err(err).Msg("Terminal channel error") }) } diff --git a/timesync.go b/timesync.go index 956011b3c..4103b25de 100644 --- a/timesync.go +++ b/timesync.go @@ -13,14 +13,16 @@ var ( ) func isTimeSyncNeeded() bool { + logger := timesync.GetTimesyncLogger() + if builtTimestamp == "" { - timesyncLogger.Warn().Msg("built timestamp is not set, time sync is needed") + logger.Warn().Msg("built timestamp is not set, time sync is needed") return true } ts, err := strconv.Atoi(builtTimestamp) if err != nil { - timesyncLogger.Warn().Str("error", err.Error()).Msg("failed to parse built timestamp") + logger.Warn().Str("error", err.Error()).Msg("failed to parse built timestamp") return true } @@ -29,7 +31,7 @@ func isTimeSyncNeeded() bool { now := time.Now() if now.Sub(builtTime) < 0 { - timesyncLogger.Warn(). + logger.Warn(). Str("built_time", builtTime.Format(time.RFC3339)). Str("now", now.Format(time.RFC3339)). Msg("system time is behind the built time, time sync is needed") @@ -41,7 +43,6 @@ func isTimeSyncNeeded() bool { func initTimeSync() { timeSync = timesync.NewTimeSync(×ync.TimeSyncOptions{ - Logger: timesyncLogger, NetworkConfig: config.NetworkConfig, PreCheckIPv4: func() (bool, error) { if !networkManager.IPv4Ready() { diff --git a/ui/localization/messages/da.json b/ui/localization/messages/da.json index d2064532f..4aa213f74 100644 --- a/ui/localization/messages/da.json +++ b/ui/localization/messages/da.json @@ -71,10 +71,22 @@ "advanced_error_reset_config": "Konfigurationen kunne ikke nulstilles: {error}", "advanced_error_set_dev_channel": "Kunne ikke indstille udviklerkanaltilstand: {error}", "advanced_error_set_dev_mode": "Kunne ikke indstille udviklertilstand: {error}", + "advanced_error_set_log_level": "Kunne ikke indstille logniveau: {error}", "advanced_error_update_ssh_key": "Kunne ikke opdatere SSH-nøglen: {error}", "advanced_error_usb_emulation_disable": "Kunne ikke deaktivere USB-emulering: {error}", "advanced_error_usb_emulation_enable": "Kunne ikke aktivere USB-emulering: {error}", "advanced_error_version_update": "Kunne ikke starte versionsopdatering: {error}", + "advanced_log_level_debug": "Fejlfinding", + "advanced_log_level_default": "Misligholdelse", + "advanced_log_level_description": "Indstil logføringsniveauet til fejlfindingsformål", + "advanced_log_level_disabled": "Handicappet", + "advanced_log_level_error": "Fejl", + "advanced_log_level_fatal": "Fatal", + "advanced_log_level_info": "Info", + "advanced_log_level_panic": "Panik", + "advanced_log_level_title": "Logniveau", + "advanced_log_level_trace": "Spor", + "advanced_log_level_warn": "Advare", "advanced_loopback_only_description": "Begræns webgrænsefladeadgang kun til localhost (127.0.0.1)", "advanced_loopback_only_title": "Kun loopback-tilstand", "advanced_loopback_warning_before": "Før du aktiverer denne funktion, skal du sikre dig, at du har enten:", diff --git a/ui/localization/messages/de.json b/ui/localization/messages/de.json index 326baa165..c2dcefc8c 100644 --- a/ui/localization/messages/de.json +++ b/ui/localization/messages/de.json @@ -71,10 +71,22 @@ "advanced_error_reset_config": "Konfiguration konnte nicht zurückgesetzt werden: {error}", "advanced_error_set_dev_channel": "Der Dev-Kanalstatus konnte nicht festgelegt werden: {error}", "advanced_error_set_dev_mode": "Fehler beim Festlegen des Entwicklungsmodus: {error}", + "advanced_error_set_log_level": "Fehler beim Festlegen des Protokollierungslevels: {error}", "advanced_error_update_ssh_key": "SSH-Schlüssel konnte nicht aktualisiert werden: {error}", "advanced_error_usb_emulation_disable": "USB-Emulation konnte nicht deaktiviert werden: {error}", "advanced_error_usb_emulation_enable": "USB-Emulation konnte nicht aktiviert werden: {error}", "advanced_error_version_update": "Versionsaktualisierung konnte nicht initiiert werden: {error}", + "advanced_log_level_debug": "Debuggen", + "advanced_log_level_default": "Standard", + "advanced_log_level_description": "Legen Sie den Ausführlichkeitsgrad der Protokollierung für Fehlerbehebungszwecke fest.", + "advanced_log_level_disabled": "Deaktiviert", + "advanced_log_level_error": "Fehler", + "advanced_log_level_fatal": "Tödlich", + "advanced_log_level_info": "Info", + "advanced_log_level_panic": "Panik", + "advanced_log_level_title": "Protokollierungsstufe", + "advanced_log_level_trace": "Verfolgen", + "advanced_log_level_warn": "Warnen", "advanced_loopback_only_description": "Beschränken Sie den Zugriff auf die Weboberfläche nur auf den lokalen Host (127.0.0.1).", "advanced_loopback_only_title": "Nur-Loopback-Modus", "advanced_loopback_warning_before": "Bevor Sie diese Funktion aktivieren, stellen Sie sicher, dass Sie über Folgendes verfügen:", diff --git a/ui/localization/messages/en.json b/ui/localization/messages/en.json index 7341b930e..c07413310 100644 --- a/ui/localization/messages/en.json +++ b/ui/localization/messages/en.json @@ -71,10 +71,22 @@ "advanced_error_reset_config": "Failed to reset configuration: {error}", "advanced_error_set_dev_channel": "Failed to set dev channel state: {error}", "advanced_error_set_dev_mode": "Failed to set dev mode: {error}", + "advanced_error_set_log_level": "Failed to set log level: {error}", "advanced_error_update_ssh_key": "Failed to update SSH key: {error}", "advanced_error_usb_emulation_disable": "Failed to disable USB emulation: {error}", "advanced_error_usb_emulation_enable": "Failed to enable USB emulation: {error}", "advanced_error_version_update": "Failed to initiate version update: {error}", + "advanced_log_level_debug": "Debug", + "advanced_log_level_default": "Default", + "advanced_log_level_description": "Set the logging verbosity level for troubleshooting purposes", + "advanced_log_level_disabled": "Disabled", + "advanced_log_level_error": "Error", + "advanced_log_level_fatal": "Fatal", + "advanced_log_level_info": "Info", + "advanced_log_level_panic": "Panic", + "advanced_log_level_title": "Log Level", + "advanced_log_level_trace": "Trace", + "advanced_log_level_warn": "Warn", "advanced_loopback_only_description": "Restrict web interface access to localhost only (127.0.0.1)", "advanced_loopback_only_title": "Loopback-Only Mode", "advanced_loopback_warning_before": "Before enabling this feature, make sure you have either:", diff --git a/ui/localization/messages/es.json b/ui/localization/messages/es.json index cec167b1f..5924be19a 100644 --- a/ui/localization/messages/es.json +++ b/ui/localization/messages/es.json @@ -71,10 +71,22 @@ "advanced_error_reset_config": "No se pudo restablecer la configuración: {error}", "advanced_error_set_dev_channel": "No se pudo establecer el estado del canal de desarrollo: {error}", "advanced_error_set_dev_mode": "No se pudo establecer el modo de desarrollo: {error}", + "advanced_error_set_log_level": "No se pudo establecer el nivel de registro: {error}", "advanced_error_update_ssh_key": "No se pudo actualizar la clave SSH: {error}", "advanced_error_usb_emulation_disable": "No se pudo deshabilitar la emulación USB: {error}", "advanced_error_usb_emulation_enable": "No se pudo habilitar la emulación USB: {error}", "advanced_error_version_update": "Error al iniciar la actualización de versión: {error}", + "advanced_log_level_debug": "Depurar", + "advanced_log_level_default": "Por defecto", + "advanced_log_level_description": "Establecer el nivel de verbosidad del registro para fines de resolución de problemas", + "advanced_log_level_disabled": "Desactivado", + "advanced_log_level_error": "Error", + "advanced_log_level_fatal": "Fatal", + "advanced_log_level_info": "Información", + "advanced_log_level_panic": "Pánico", + "advanced_log_level_title": "Nivel de registro", + "advanced_log_level_trace": "Rastro", + "advanced_log_level_warn": "Advertir", "advanced_loopback_only_description": "Restringir el acceso a la interfaz web solo al host local (127.0.0.1)", "advanced_loopback_only_title": "Modo de solo bucle invertido", "advanced_loopback_warning_before": "Antes de habilitar esta función, asegúrese de tener:", diff --git a/ui/localization/messages/fr.json b/ui/localization/messages/fr.json index b86bc160b..d358f11b6 100644 --- a/ui/localization/messages/fr.json +++ b/ui/localization/messages/fr.json @@ -71,10 +71,22 @@ "advanced_error_reset_config": "Échec de la réinitialisation de la configuration : {error}", "advanced_error_set_dev_channel": "Échec de la définition de l'état du canal de développement : {error}", "advanced_error_set_dev_mode": "Échec de la définition du mode de développement : {error}", + "advanced_error_set_log_level": "Échec de la définition du niveau de journalisation : {error}", "advanced_error_update_ssh_key": "Échec de la mise à jour de la clé SSH : {error}", "advanced_error_usb_emulation_disable": "Échec de la désactivation de l'émulation USB : {error}", "advanced_error_usb_emulation_enable": "Échec de l'activation de l'émulation USB : {error}", "advanced_error_version_update": "Échec de la mise à jour de version : {error}", + "advanced_log_level_debug": "Déboguer", + "advanced_log_level_default": "Défaut", + "advanced_log_level_description": "Définissez le niveau de verbosité des journaux à des fins de dépannage.", + "advanced_log_level_disabled": "Désactivé", + "advanced_log_level_error": "Erreur", + "advanced_log_level_fatal": "Fatal", + "advanced_log_level_info": "Info", + "advanced_log_level_panic": "Panique", + "advanced_log_level_title": "Niveau de journalisation", + "advanced_log_level_trace": "Tracer", + "advanced_log_level_warn": "Avertir", "advanced_loopback_only_description": "Restreindre l'accès à l'interface Web à l'hôte local uniquement (127.0.0.1)", "advanced_loopback_only_title": "Mode de bouclage uniquement", "advanced_loopback_warning_before": "Avant d'activer cette fonctionnalité, assurez-vous d'avoir :", diff --git a/ui/localization/messages/it.json b/ui/localization/messages/it.json index 3fe77fedf..8b91227e8 100644 --- a/ui/localization/messages/it.json +++ b/ui/localization/messages/it.json @@ -71,10 +71,22 @@ "advanced_error_reset_config": "Impossibile reimpostare la configurazione: {error}", "advanced_error_set_dev_channel": "Impossibile impostare lo stato del canale di sviluppo: {error}", "advanced_error_set_dev_mode": "Impossibile impostare la modalità di sviluppo: {error}", + "advanced_error_set_log_level": "Impossibile impostare il livello di registro: {error}", "advanced_error_update_ssh_key": "Impossibile aggiornare la chiave SSH: {error}", "advanced_error_usb_emulation_disable": "Impossibile disabilitare l'emulazione USB: {error}", "advanced_error_usb_emulation_enable": "Impossibile abilitare l'emulazione USB: {error}", "advanced_error_version_update": "Impossibile avviare l'aggiornamento della versione: {error}", + "advanced_log_level_debug": "Debug", + "advanced_log_level_default": "Predefinito", + "advanced_log_level_description": "Imposta il livello di verbosità della registrazione per scopi di risoluzione dei problemi", + "advanced_log_level_disabled": "Disabili", + "advanced_log_level_error": "Errore", + "advanced_log_level_fatal": "Fatale", + "advanced_log_level_info": "Informazioni", + "advanced_log_level_panic": "Panico", + "advanced_log_level_title": "Livello di registro", + "advanced_log_level_trace": "Traccia", + "advanced_log_level_warn": "Avvisare", "advanced_loopback_only_description": "Limita l'accesso all'interfaccia web solo a localhost (127.0.0.1)", "advanced_loopback_only_title": "Modalità solo loopback", "advanced_loopback_warning_before": "Prima di abilitare questa funzione, assicurati di avere:", diff --git a/ui/localization/messages/nb.json b/ui/localization/messages/nb.json index 59c8ea64c..7345462f4 100644 --- a/ui/localization/messages/nb.json +++ b/ui/localization/messages/nb.json @@ -71,10 +71,22 @@ "advanced_error_reset_config": "Kunne ikke tilbakestille konfigurasjonen: {error}", "advanced_error_set_dev_channel": "Klarte ikke å angi tilstanden til utviklerkanalen: {error}", "advanced_error_set_dev_mode": "Kunne ikke angi utviklermodus: {error}", + "advanced_error_set_log_level": "Klarte ikke å angi loggnivå: {error}", "advanced_error_update_ssh_key": "Kunne ikke oppdatere SSH-nøkkelen: {error}", "advanced_error_usb_emulation_disable": "Kunne ikke deaktivere USB-emulering: {error}", "advanced_error_usb_emulation_enable": "Kunne ikke aktivere USB-emulering: {error}", "advanced_error_version_update": "Kunne ikke starte versjonsoppdatering: {error}", + "advanced_log_level_debug": "Feilsøking", + "advanced_log_level_default": "Misligholde", + "advanced_log_level_description": "Angi detaljnivået for loggføring for feilsøkingsformål", + "advanced_log_level_disabled": "Funksjonshemmet", + "advanced_log_level_error": "Feil", + "advanced_log_level_fatal": "Dødelig", + "advanced_log_level_info": "Info", + "advanced_log_level_panic": "Panikk", + "advanced_log_level_title": "Loggnivå", + "advanced_log_level_trace": "Spor", + "advanced_log_level_warn": "Varsle", "advanced_loopback_only_description": "Begrens tilgang til webgrensesnittet kun til lokal vert (127.0.0.1)", "advanced_loopback_only_title": "Kun lokal tilgang", "advanced_loopback_warning_before": "Før du aktiverer denne funksjonen, må du sørge for at du har enten:", diff --git a/ui/localization/messages/sv.json b/ui/localization/messages/sv.json index 17c794f30..edf896175 100644 --- a/ui/localization/messages/sv.json +++ b/ui/localization/messages/sv.json @@ -71,10 +71,22 @@ "advanced_error_reset_config": "Misslyckades med att återställa konfigurationen: {error}", "advanced_error_set_dev_channel": "Misslyckades med att ställa in status för utvecklarkanalen: {error}", "advanced_error_set_dev_mode": "Misslyckades med att ställa in utvecklarläge: {error}", + "advanced_error_set_log_level": "Misslyckades med att ställa in loggnivå: {error}", "advanced_error_update_ssh_key": "Misslyckades med att uppdatera SSH-nyckeln: {error}", "advanced_error_usb_emulation_disable": "Misslyckades med att inaktivera USB-emulering: {error}", "advanced_error_usb_emulation_enable": "Misslyckades med att aktivera USB-emulering: {error}", "advanced_error_version_update": "Misslyckades med att initiera versionsuppdatering: {error}", + "advanced_log_level_debug": "Felsök", + "advanced_log_level_default": "Standard", + "advanced_log_level_description": "Ställ in loggningsnivån för felsökning", + "advanced_log_level_disabled": "Funktionshindrad", + "advanced_log_level_error": "Fel", + "advanced_log_level_fatal": "Dödlig", + "advanced_log_level_info": "Info", + "advanced_log_level_panic": "Panik", + "advanced_log_level_title": "Loggnivå", + "advanced_log_level_trace": "Spåra", + "advanced_log_level_warn": "Varna", "advanced_loopback_only_description": "Begränsa åtkomst till webbgränssnittet endast till lokal värd (127.0.0.1)", "advanced_loopback_only_title": "Loopback-läge", "advanced_loopback_warning_before": "Innan du aktiverar den här funktionen, se till att du har antingen:", diff --git a/ui/localization/messages/zh.json b/ui/localization/messages/zh.json index 9510b518a..f1d07c68c 100644 --- a/ui/localization/messages/zh.json +++ b/ui/localization/messages/zh.json @@ -75,6 +75,17 @@ "advanced_error_usb_emulation_disable": "禁用 USB 模拟失败:{error}", "advanced_error_usb_emulation_enable": "启用 USB 模拟失败:{error}", "advanced_error_version_update": "启动版本更新失败:{error}", + "advanced_log_level_debug": "调试", + "advanced_log_level_default": "默认", + "advanced_log_level_description": "设置日志详细级别以进行故障排除", + "advanced_log_level_disabled": "已禁用", + "advanced_log_level_error": "错误", + "advanced_log_level_fatal": "致命的", + "advanced_log_level_info": "信息", + "advanced_log_level_panic": "恐慌", + "advanced_log_level_title": "日志级别", + "advanced_log_level_trace": "痕迹", + "advanced_log_level_warn": "警告", "advanced_loopback_only_description": "将 Web 访问限制为仅本地主机 (127.0.0.1)。", "advanced_loopback_only_title": "环回模式", "advanced_loopback_warning_before": "在启用此功能之前,请确保您已具备以下任一条件:", diff --git a/ui/src/routes/devices.$id.settings.advanced.tsx b/ui/src/routes/devices.$id.settings.advanced.tsx index dd39f9b6d..6b4918d5d 100644 --- a/ui/src/routes/devices.$id.settings.advanced.tsx +++ b/ui/src/routes/devices.$id.settings.advanced.tsx @@ -38,6 +38,7 @@ export default function SettingsAdvancedRoute() { const [resetConfig, setResetConfig] = useState(false); const [versionChangeAcknowledged, setVersionChangeAcknowledged] = useState(false); const [customVersionUpdateLoading, setCustomVersionUpdateLoading] = useState(false); + const [logLevel, setLogLevel] = useState("WARN"); const settings = useSettingsStore(); useEffect(() => { @@ -66,6 +67,12 @@ export default function SettingsAdvancedRoute() { if ("error" in resp) return; setLocalLoopbackOnly(resp.result as boolean); }); + + send("getLogLevel", {}, (resp: JsonRpcResponse) => { + if ("error" in resp) return; + const result = resp.result as { level: string }; + setLogLevel(result.level); + }); }, [send, setDeveloperMode]); const getUsbEmulationState = useCallback(() => { @@ -81,8 +88,12 @@ export default function SettingsAdvancedRoute() { if ("error" in resp) { notifications.error( enabled - ? m.advanced_error_usb_emulation_enable({ error: resp.error.data || m.unknown_error() }) - : m.advanced_error_usb_emulation_disable({ error: resp.error.data || m.unknown_error() }) + ? m.advanced_error_usb_emulation_enable({ + error: resp.error.data || m.unknown_error(), + }) + : m.advanced_error_usb_emulation_disable({ + error: resp.error.data || m.unknown_error(), + }), ); return; } @@ -97,7 +108,7 @@ export default function SettingsAdvancedRoute() { send("resetConfig", {}, (resp: JsonRpcResponse) => { if ("error" in resp) { notifications.error( - m.advanced_error_reset_config({ error: resp.error.data || m.unknown_error() }) + m.advanced_error_reset_config({ error: resp.error.data || m.unknown_error() }), ); return; } @@ -109,7 +120,9 @@ export default function SettingsAdvancedRoute() { send("setSSHKeyState", { sshKey }, (resp: JsonRpcResponse) => { if ("error" in resp) { notifications.error( - m.advanced_error_update_ssh_key({ error: resp.error.data || m.unknown_error() }) + m.advanced_error_update_ssh_key({ + error: resp.error.data || m.unknown_error(), + }), ); return; } @@ -122,7 +135,9 @@ export default function SettingsAdvancedRoute() { send("setDevModeState", { enabled: developerMode }, (resp: JsonRpcResponse) => { if ("error" in resp) { notifications.error( - m.advanced_error_set_dev_mode({ error: resp.error.data || m.unknown_error() }) + m.advanced_error_set_dev_mode({ + error: resp.error.data || m.unknown_error(), + }), ); return; } @@ -137,7 +152,9 @@ export default function SettingsAdvancedRoute() { send("setDevChannelState", { enabled }, (resp: JsonRpcResponse) => { if ("error" in resp) { notifications.error( - m.advanced_error_set_dev_channel({ error: resp.error.data || m.unknown_error() }) + m.advanced_error_set_dev_channel({ + error: resp.error.data || m.unknown_error(), + }), ); return; } @@ -153,8 +170,12 @@ export default function SettingsAdvancedRoute() { if ("error" in resp) { notifications.error( enabled - ? m.advanced_error_loopback_enable({ error: resp.error.data || m.unknown_error() }) - : m.advanced_error_loopback_disable({ error: resp.error.data || m.unknown_error() }) + ? m.advanced_error_loopback_enable({ + error: resp.error.data || m.unknown_error(), + }) + : m.advanced_error_loopback_disable({ + error: resp.error.data || m.unknown_error(), + }), ); return; } @@ -182,6 +203,23 @@ export default function SettingsAdvancedRoute() { [applyLoopbackOnlyMode, setShowLoopbackWarning], ); + const handleLogLevelChange = useCallback( + (level: string) => { + send("setLogLevel", { level: level }, (resp: JsonRpcResponse) => { + if ("error" in resp) { + notifications.error( + m.advanced_error_set_log_level({ + error: resp.error.data || m.unknown_error(), + }), + ); + return; + } + setLogLevel(level); + }); + }, + [send, setLogLevel], + ); + const confirmLoopbackModeEnable = useCallback(() => { applyLoopbackOnlyMode(true); setShowLoopbackWarning(false); @@ -190,9 +228,12 @@ export default function SettingsAdvancedRoute() { const handleVersionUpdateError = useCallback((error?: JsonRpcError | string) => { notifications.error( m.advanced_error_version_update({ - error: typeof error === "string" ? error : (error?.data ?? error?.message ?? m.unknown_error()) + error: + typeof error === "string" + ? error + : (error?.data ?? error?.message ?? m.unknown_error()), }), - { duration: 1000 * 15 } // 15 seconds + { duration: 1000 * 15 }, // 15 seconds ); setCustomVersionUpdateLoading(false); }, []); @@ -200,7 +241,8 @@ export default function SettingsAdvancedRoute() { const handleCustomVersionUpdate = useCallback(async () => { const components: UpdateComponents = {}; if (["app", "both"].includes(updateTarget) && appVersion) components.app = appVersion; - if (["system", "both"].includes(updateTarget) && systemVersion) components.system = systemVersion; + if (["system", "both"].includes(updateTarget) && systemVersion) + components.system = systemVersion; let versionInfo: SystemVersionInfo | undefined; try { @@ -209,7 +251,9 @@ export default function SettingsAdvancedRoute() { setCustomVersionUpdateLoading(true); versionInfo = await checkUpdateComponents({ components, - }, devChannel); + }, + devChannel, + ); } catch (error: unknown) { const jsonRpcError = error as JsonRpcError; handleVersionUpdateError(jsonRpcError); @@ -219,11 +263,19 @@ export default function SettingsAdvancedRoute() { let hasUpdate = false; const pageParams = new URLSearchParams(); - if (components.app && versionInfo?.remote?.appVersion && versionInfo?.appUpdateAvailable) { + if ( + components.app && + versionInfo?.remote?.appVersion && + versionInfo?.appUpdateAvailable + ) { hasUpdate = true; pageParams.set("custom_app_version", versionInfo.remote?.appVersion); } - if (components.system && versionInfo?.remote?.systemVersion && versionInfo?.systemUpdateAvailable) { + if ( + components.system && + versionInfo?.remote?.systemVersion && + versionInfo?.systemUpdateAvailable + ) { hasUpdate = true; pageParams.set("custom_system_version", versionInfo.remote?.systemVersion); } @@ -237,9 +289,14 @@ export default function SettingsAdvancedRoute() { // Navigate to update page navigateTo(`/settings/general/update?${pageParams.toString()}`); }, [ - updateTarget, appVersion, systemVersion, devChannel, - navigateTo, resetConfig, handleVersionUpdateError, - setCustomVersionUpdateLoading + updateTarget, + appVersion, + systemVersion, + devChannel, + navigateTo, + resetConfig, + handleVersionUpdateError, + setCustomVersionUpdateLoading, ]); return ( @@ -319,7 +376,8 @@ export default function SettingsAdvancedRoute() { placeholder={m.advanced_ssh_public_key_placeholder()} />

- {m.advanced_ssh_default_user()}root. + {m.advanced_ssh_default_user()} + root.

@@ -484,12 +564,8 @@ export default function SettingsAdvancedRoute() { title={m.advanced_loopback_warning_title()} description={ <> -

- {m.advanced_loopback_warning_description()} -

-

- {m.advanced_loopback_warning_before()} -

+

{m.advanced_loopback_warning_description()}

+

{m.advanced_loopback_warning_before()}