diff --git a/cloud.go b/cloud.go index dbbd3bbcc..7a0e34fc7 100644 --- a/cloud.go +++ b/cloud.go @@ -475,10 +475,9 @@ func handleSessionRequest( } cloudLogger.Info().Interface("session", session).Msg("new session accepted") - cloudLogger.Trace().Interface("session", session).Msg("new session accepted") // Cancel any ongoing keyboard macro when session changes - cancelKeyboardMacro() + _ = cancelKeyboardMacro() currentSession = session _ = wsjson.Write(context.Background(), c, gin.H{"type": "answer", "data": sd}) diff --git a/display.go b/display.go index 9b12ad433..a0e7780ad 100644 --- a/display.go +++ b/display.go @@ -75,7 +75,7 @@ func updateDisplay() { nativeInstance.UpdateLabelIfChanged("hdmi_status_label", "Disconnected") _, _ = nativeInstance.UIObjClearState("hdmi_status_label", "LV_STATE_CHECKED") } - nativeInstance.UpdateLabelIfChanged("cloud_status_label", fmt.Sprintf("%d active", actionSessions)) + nativeInstance.UpdateLabelIfChanged("cloud_status_label", fmt.Sprintf("%d active", int(getActiveSessions()))) if networkManager != nil && networkManager.IsUp() { nativeInstance.UISetVar("main_screen", "home_screen") diff --git a/hidrpc.go b/hidrpc.go index ebe03daab..0d1c7e439 100644 --- a/hidrpc.go +++ b/hidrpc.go @@ -8,98 +8,108 @@ import ( "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/usbgadget" + "github.com/jetkvm/kvm/internal/utils" "github.com/rs/zerolog" ) -func handleHidRPCMessage(message hidrpc.Message, session *Session) { - var rpcErr error - +func handleHidRPCMessage(message hidrpc.Message, session *Session) error { switch message.Type() { case hidrpc.TypeHandshake: - message, err := hidrpc.NewHandshakeMessage().Marshal() - if err != nil { - logger.Warn().Err(err).Msg("failed to marshal handshake message") - return - } - if err := session.HidChannel.Send(message); err != nil { - logger.Warn().Err(err).Msg("failed to send handshake message") - return - } - session.hidRPCAvailable = true - case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport: - rpcErr = handleHidRPCKeyboardInput(message) + return handleHidRPCHandshake(session) case hidrpc.TypeKeyboardMacroReport: - keyboardMacroReport, err := message.KeyboardMacroReport() - if err != nil { - logger.Warn().Err(err).Msg("failed to get keyboard macro report") - return - } - rpcErr = rpcExecuteKeyboardMacro(keyboardMacroReport.Steps) + return handleKeyboardMacro(message) + case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport: + return handleHidRPCKeyboardInput(message) case hidrpc.TypeCancelKeyboardMacroReport: - rpcCancelKeyboardMacro() - return + return rpcCancelKeyboardMacro() case hidrpc.TypeKeypressKeepAliveReport: - rpcErr = handleHidRPCKeypressKeepAlive(session) + return handleHidRPCKeypressKeepAlive(session) case hidrpc.TypePointerReport: - pointerReport, err := message.PointerReport() - if err != nil { - logger.Warn().Err(err).Msg("failed to get pointer report") - return - } - rpcErr = rpcAbsMouseReport(pointerReport.X, pointerReport.Y, pointerReport.Button) + return handlePointerReport(message) case hidrpc.TypeMouseReport: - mouseReport, err := message.MouseReport() - if err != nil { - logger.Warn().Err(err).Msg("failed to get mouse report") - return - } - rpcErr = rpcRelMouseReport(mouseReport.DX, mouseReport.DY, mouseReport.Button) - default: - logger.Warn().Uint8("type", uint8(message.Type())).Msg("unknown HID RPC message type") + return handleMouseReport(message) } - if rpcErr != nil { - logger.Warn().Err(rpcErr).Msg("failed to handle HID RPC message") + return fmt.Errorf("unknown HID RPC message type %d", message.Type()) +} + +func handleHidRPCHandshake(session *Session) error { + hidRPCLogger.Debug().Msg("handling handshake") + message, err := hidrpc.NewHandshakeMessage().Marshal() + if err != nil { + return err } + if err = session.HidChannel.Send(message); err != nil { + return err + } + session.hidRPCAvailable = true + return nil } -func onHidMessage(msg hidQueueMessage, session *Session) { +func handleKeyboardMacro(message hidrpc.Message) error { + keyboardMacroReport, err := message.KeyboardMacroReport() + if err != nil { + return err + } + hidRPCLogger.Debug().Interface("keyboardMacroReport", keyboardMacroReport).Msg("handling keyboard macro") + return rpcExecuteKeyboardMacro(keyboardMacroReport.Steps) +} + +func handleMouseReport(message hidrpc.Message) error { + mouseReport, err := message.MouseReport() + if err != nil { + return err + } + hidRPCLogger.Debug().Interface("mouseReport", mouseReport).Msg("handling relative mouse") + return rpcRelMouseReport(mouseReport.DX, mouseReport.DY, mouseReport.Button) +} + +func handlePointerReport(message hidrpc.Message) error { + pointerReport, err := message.PointerReport() + if err != nil { + return err + } + hidRPCLogger.Debug().Interface("pointerReport", pointerReport).Msg("handling absolute pointer") + return rpcAbsMouseReport(pointerReport.X, pointerReport.Y, pointerReport.Button) +} + +func onHidMessage(msg hidQueueMessage, session *Session, index int) { + logger := hidRPCLogger.With().Int("queueIndex", index).Str("channel", msg.channel).Logger() data := msg.Data - scopedLogger := hidRPCLogger.With(). - Str("channel", msg.channel). - Bytes("data", data). - Logger() - scopedLogger.Debug().Msg("HID RPC message received") + if logger.GetLevel() <= zerolog.TraceLevel { + logger.Trace().Object("data", utils.ByteSlice(data)).Msg("HID RPC message received") + } if len(data) < 1 { - scopedLogger.Warn().Int("length", len(data)).Msg("received empty data in HID RPC message handler") + logger.Warn().Int("length", len(data)).Msg("received empty data in HID RPC message handler") return } var message hidrpc.Message if err := hidrpc.Unmarshal(data, &message); err != nil { - scopedLogger.Warn().Err(err).Msg("failed to unmarshal HID RPC message") + logger.Warn().Err(err).Msg("failed to unmarshal HID RPC message") return } - if scopedLogger.GetLevel() <= zerolog.DebugLevel { - scopedLogger = scopedLogger.With().Str("descr", message.String()).Logger() + if logger.GetLevel() <= zerolog.DebugLevel { + logger = logger.With().Str("descr", message.String()).Logger() } t := time.Now() - r := make(chan interface{}) go func() { - handleHidRPCMessage(message, session) - r <- nil + r <- handleHidRPCMessage(message, session) }() select { case <-time.After(1 * time.Second): - scopedLogger.Warn().Msg("HID RPC message timed out") - case <-r: - scopedLogger.Debug().Dur("duration", time.Since(t)).Msg("HID RPC message handled") + logger.Warn().Msg("HID RPC message took too long") + case err := <-r: + logger.Debug().Dur("duration", time.Since(t)).Msg("HID RPC message handled") + if err != nil { + logger.Warn().Err(err.(error)).Msg("failed to handle HID RPC message") + } } } @@ -108,12 +118,10 @@ func onHidMessage(msg hidQueueMessage, session *Session) { // macOS default: 15 * 15 = 225ms https://discussions.apple.com/thread/1316947?sortBy=rank // Linux default: 250ms https://man.archlinux.org/man/kbdrate.8.en // Windows default: 1s `HKEY_CURRENT_USER\Control Panel\Accessibility\Keyboard Response\AutoRepeatDelay` - const expectedRate = 50 * time.Millisecond // expected keepalive interval const maxLateness = 50 * time.Millisecond // max jitter we'll tolerate OR jitter budget const baseExtension = expectedRate + maxLateness // 100ms extension on perfect tick - -const maxStaleness = 225 * time.Millisecond // discard ancient packets outright +const maxStaleness = 225 * time.Millisecond // discard ancient packets outright func handleHidRPCKeypressKeepAlive(session *Session) error { session.keepAliveJitterLock.Lock() @@ -128,7 +136,6 @@ func handleHidRPCKeypressKeepAlive(session *Session) error { return nil } - validTick := true timerExtension := baseExtension if !session.lastKeepAliveArrivalTime.IsZero() { @@ -147,14 +154,11 @@ func handleHidRPCKeypressKeepAlive(session *Session) error { // This is likely a retransmit stall or ordering delay. // We reject the tick entirely and DO NOT extend, // so the auto-release still fires on time. - validTick = false + return nil } } } - if !validTick { - return nil - } // Only valid ticks update our state and extend the timer. session.lastKeepAliveArrivalTime = now session.lastTimerResetTime = now @@ -167,6 +171,8 @@ func handleHidRPCKeypressKeepAlive(session *Session) error { } func handleHidRPCKeyboardInput(message hidrpc.Message) error { + logger := hidRPCLogger.With().Interface("message", message).Logger() + switch message.Type() { case hidrpc.TypeKeypressReport: keypressReport, err := message.KeypressReport() @@ -174,6 +180,7 @@ func handleHidRPCKeyboardInput(message hidrpc.Message) error { logger.Warn().Err(err).Msg("failed to get keypress report") return err } + logger.Debug().Interface("keypressReport", keypressReport).Msg("handling key press") return rpcKeypressReport(keypressReport.Key, keypressReport.Press) case hidrpc.TypeKeyboardReport: keyboardReport, err := message.KeyboardReport() @@ -181,6 +188,7 @@ func handleHidRPCKeyboardInput(message hidrpc.Message) error { logger.Warn().Err(err).Msg("failed to get keyboard report") return err } + logger.Debug().Interface("keyboardReport", keyboardReport).Msg("handling keyboard") return rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys) } @@ -188,6 +196,8 @@ func handleHidRPCKeyboardInput(message hidrpc.Message) error { } func reportHidRPC(params any, session *Session) { + logger := hidRPCLogger.With().Interface("params", params).Logger() + if session == nil { logger.Warn().Msg("session is nil, skipping reportHidRPC") return @@ -205,6 +215,7 @@ func reportHidRPC(params any, session *Session) { message []byte err error ) + switch params := params.(type) { case usbgadget.KeyboardState: message, err = hidrpc.NewKeyboardLedMessage(params).Marshal() @@ -216,23 +227,35 @@ func reportHidRPC(params any, session *Session) { err = fmt.Errorf("unknown HID RPC message type: %T", params) } - if err != nil { - logger.Warn().Err(err).Msg("failed to marshal HID RPC message") - return - } + logger = logger.With().Type("type", params).Logger() - if message == nil { - logger.Warn().Msg("failed to marshal HID RPC message") + if err != nil || message == nil { + logger.Warn().Err(err).Msg("failed to marshal HID RPC message") return } - if err := session.HidChannel.Send(message); err != nil { - if errors.Is(err, io.ErrClosedPipe) { - logger.Debug().Err(err).Msg("HID RPC channel closed, skipping reportHidRPC") - return + // fire and forget... + go func() { + t := time.Now() + r := make(chan interface{}) + go func() { + logger.Trace().Msg("sending HID RPC report") + r <- session.HidChannel.Send(message) + }() + select { + case <-time.After(1 * time.Second): + logger.Warn().Msg("HID RPC report took too long") + case err := <-r: + logger.Debug().Dur("duration", time.Since(t)).Msg("HID RPC report sent") + if err != nil { + if errors.Is(err.(error), io.ErrClosedPipe) { + logger.Warn().Err(err.(error)).Msg("HID RPC channel closed, skipping reportHidRPC") + return + } + logger.Warn().Err(err.(error)).Msg("failed to send HID RPC report") + } } - logger.Warn().Err(err).Msg("failed to send HID RPC message") - } + }() } func (s *Session) reportHidRPCKeyboardLedState(state usbgadget.KeyboardState) { diff --git a/internal/hidrpc/hidrpc.go b/internal/hidrpc/hidrpc.go index 7313e3b53..b27315835 100644 --- a/internal/hidrpc/hidrpc.go +++ b/internal/hidrpc/hidrpc.go @@ -6,26 +6,9 @@ import ( "github.com/jetkvm/kvm/internal/usbgadget" ) -// MessageType is the type of the HID RPC message -type MessageType byte - const ( - TypeHandshake MessageType = 0x01 - TypeKeyboardReport MessageType = 0x02 - TypePointerReport MessageType = 0x03 - TypeWheelReport MessageType = 0x04 - TypeKeypressReport MessageType = 0x05 - TypeKeypressKeepAliveReport MessageType = 0x09 - TypeMouseReport MessageType = 0x06 - TypeKeyboardMacroReport MessageType = 0x07 - TypeCancelKeyboardMacroReport MessageType = 0x08 - TypeKeyboardLedState MessageType = 0x32 - TypeKeydownState MessageType = 0x33 - TypeKeyboardMacroState MessageType = 0x34 -) - -const ( - Version byte = 0x01 // Version of the HID RPC protocol + Version byte = 0x01 // Version of the HID RPC protocol + MaximumQueues int = 4 // Maximum number of HID RPC queues ) // GetQueueIndex returns the index of the queue to which the message should be enqueued. @@ -57,19 +40,6 @@ func Unmarshal(data []byte, message *Message) error { return nil } -// Marshal marshals the HID RPC message to the data. -func Marshal(message *Message) ([]byte, error) { - if message.t == 0 { - return nil, fmt.Errorf("invalid message type: %d", message.t) - } - - data := make([]byte, len(message.d)+1) - data[0] = byte(message.t) - copy(data[1:], message.d) - - return data, nil -} - // NewHandshakeMessage creates a new handshake message. func NewHandshakeMessage() *Message { return &Message{ diff --git a/internal/hidrpc/message.go b/internal/hidrpc/message.go index 3f3506f7f..ec9247acd 100644 --- a/internal/hidrpc/message.go +++ b/internal/hidrpc/message.go @@ -3,6 +3,26 @@ package hidrpc import ( "encoding/binary" "fmt" + + "github.com/rs/zerolog" +) + +// MessageType is the type of the HID RPC message +type MessageType byte + +const ( + TypeHandshake MessageType = 0x01 + TypeKeyboardReport MessageType = 0x02 + TypePointerReport MessageType = 0x03 + TypeWheelReport MessageType = 0x04 + TypeKeypressReport MessageType = 0x05 + TypeKeypressKeepAliveReport MessageType = 0x09 + TypeMouseReport MessageType = 0x06 + TypeKeyboardMacroReport MessageType = 0x07 + TypeCancelKeyboardMacroReport MessageType = 0x08 + TypeKeyboardLedState MessageType = 0x32 + TypeKeydownState MessageType = 0x33 + TypeKeyboardMacroState MessageType = 0x34 ) // Message .. @@ -11,9 +31,22 @@ type Message struct { d []byte } +func (m Message) MarshalZerologObject(e *zerolog.Event) { + e.Uint8("Type", uint8(m.t)) + e.Bytes("Payload", m.d) +} + // Marshal marshals the message to a byte array. func (m *Message) Marshal() ([]byte, error) { - return Marshal(m) + if m.t == 0 { + return nil, fmt.Errorf("invalid message type: %d", m.t) + } + + data := make([]byte, len(m.d)+1) + data[0] = byte(m.t) + copy(data[1:], m.d) + + return data, nil } func (m *Message) Type() MessageType { @@ -56,12 +89,20 @@ func (m *Message) String() string { } } +// HidKeyBufferSize is the size of the keys buffer in the keyboard report. +const HidKeyBufferSize = 6 + // KeypressReport .. type KeypressReport struct { Key byte Press bool } +func (k KeypressReport) MarshalZerologObject(e *zerolog.Event) { + e.Hex("Key", []byte{k.Key}) + e.Bool("Press", k.Press) +} + // KeypressReport returns the keypress report from the message. func (m *Message) KeypressReport() (KeypressReport, error) { if m.t != TypeKeypressReport { @@ -77,7 +118,12 @@ func (m *Message) KeypressReport() (KeypressReport, error) { // KeyboardReport .. type KeyboardReport struct { Modifier byte - Keys []byte + Keys []byte // 6 bytes: HidKeyBufferSize +} + +func (k KeyboardReport) MarshalZerologObject(e *zerolog.Event) { + e.Hex("Modifier", []byte{k.Modifier}) + e.Hex("Keys", k.Keys) } // KeyboardReport returns the keyboard report from the message. @@ -95,17 +141,26 @@ func (m *Message) KeyboardReport() (KeyboardReport, error) { // Macro .. type KeyboardMacroStep struct { Modifier byte // 1 byte - Keys []byte // 6 bytes: hidKeyBufferSize + Keys []byte // 6 bytes: HidKeyBufferSize Delay uint16 // 2 bytes } + +func (s KeyboardMacroStep) MarshalZerologObject(e *zerolog.Event) { + e.Hex("Modifier", []byte{s.Modifier}) + e.Hex("Keys", s.Keys) + e.Uint16("Delay", s.Delay) +} + type KeyboardMacroReport struct { IsPaste bool StepCount uint32 Steps []KeyboardMacroStep } -// HidKeyBufferSize is the size of the keys buffer in the keyboard report. -const HidKeyBufferSize = 6 +func (m KeyboardMacroReport) MarshalZerologObject(e *zerolog.Event) { + e.Bool("IsPaste", m.IsPaste) + e.Uint32("StepCount", m.StepCount) +} // KeyboardMacroReport returns the keyboard macro report from the message. func (m *Message) KeyboardMacroReport() (KeyboardMacroReport, error) { @@ -117,7 +172,9 @@ func (m *Message) KeyboardMacroReport() (KeyboardMacroReport, error) { stepCount := binary.BigEndian.Uint32(m.d[1:5]) // check total length - expectedLength := int(stepCount)*9 + 5 + const StepSize = 1 + HidKeyBufferSize + 2 + + expectedLength := int(stepCount)*StepSize + 5 if len(m.d) != expectedLength { return KeyboardMacroReport{}, fmt.Errorf("invalid length: %d, expected: %d", len(m.d), expectedLength) } @@ -131,7 +188,7 @@ func (m *Message) KeyboardMacroReport() (KeyboardMacroReport, error) { Delay: binary.BigEndian.Uint16(m.d[offset+7 : offset+9]), }) - offset += 1 + HidKeyBufferSize + 2 + offset += StepSize } return KeyboardMacroReport{ @@ -148,6 +205,12 @@ type PointerReport struct { Button uint8 } +func (p PointerReport) MarshalZerologObject(e *zerolog.Event) { + e.Int("X", p.X) + e.Int("Y", p.Y) + e.Uint8("Button", p.Button) +} + func toInt(b []byte) int { return int(b[0])<<24 + int(b[1])<<16 + int(b[2])<<8 + int(b[3])<<0 } @@ -176,6 +239,12 @@ type MouseReport struct { Button uint8 } +func (m MouseReport) MarshalZerologObject(e *zerolog.Event) { + e.Int8("DX", m.DX) + e.Int8("DY", m.DY) + e.Uint8("Button", m.Button) +} + // MouseReport returns the mouse report from the message. func (m *Message) MouseReport() (MouseReport, error) { if m.t != TypeMouseReport { @@ -194,6 +263,11 @@ type KeyboardMacroState struct { IsPaste bool } +func (k KeyboardMacroState) MarshalZerologObject(e *zerolog.Event) { + e.Bool("State", k.State) + e.Bool("IsPaste", k.IsPaste) +} + // KeyboardMacroState returns the keyboard macro state report from the message. func (m *Message) KeyboardMacroState() (KeyboardMacroState, error) { if m.t != TypeKeyboardMacroState { diff --git a/internal/logging/root.go b/internal/logging/root.go index 397ca6488..1a05aa523 100644 --- a/internal/logging/root.go +++ b/internal/logging/root.go @@ -3,12 +3,12 @@ package logging import "github.com/rs/zerolog" var ( - rootZerologLogger = zerolog.New(defaultLogOutput).With(). - Str("scope", "jetkvm"). - Timestamp(). - Stack(). - Logger() - rootLogger = NewLogger(rootZerologLogger) + rootLogger = NewLogger( + zerolog.New(defaultLogOutput). + With(). + Str("scope", "jetkvm"). + Timestamp(). + Logger()) ) func GetRootLogger() *Logger { diff --git a/internal/logging/utils.go b/internal/logging/utils.go index 73ae37a84..9eea63db0 100644 --- a/internal/logging/utils.go +++ b/internal/logging/utils.go @@ -2,30 +2,24 @@ package logging import ( "fmt" - "os" + "sync" "github.com/rs/zerolog" ) -var defaultLogger = zerolog.New(os.Stdout).Level(zerolog.InfoLevel) - -func GetDefaultLogger() *zerolog.Logger { - return &defaultLogger +func UnlockWithTraceLog(lock *sync.Mutex, logger *zerolog.Logger, msg string, args ...any) { + defer lock.Unlock() + logger.Trace().Msgf(msg, args...) } -func ErrorfL(l *zerolog.Logger, format string, err error, args ...any) error { - // TODO: move rootLogger to logging package - if l == nil { - l = &defaultLogger - } - - l.Error().Err(err).Msgf(format, args...) +func ErrorfL(logger *zerolog.Logger, format string, err error, args ...any) error { + logger.Error().Err(err).Msgf(format, args...) if err == nil { return fmt.Errorf(format, args...) } - err_msg := err.Error() + ": %v" + err_msg := err.Error() + ": %w" err_args := append(args, err) return fmt.Errorf(err_msg, err_args...) diff --git a/internal/mdns/mdns.go b/internal/mdns/mdns.go index 2b954d45d..77004da87 100644 --- a/internal/mdns/mdns.go +++ b/internal/mdns/mdns.go @@ -40,10 +40,6 @@ const ( ) func NewMDNS(opts *MDNSOptions) (*MDNS, error) { - if opts.Logger == nil { - opts.Logger = logging.GetDefaultLogger() - } - if opts.ListenOptions == nil { opts.ListenOptions = &MDNSListenOptions{ IPv4: true, diff --git a/internal/ota/app.go b/internal/ota/app.go index 55caa8e8a..e3ee89ad5 100644 --- a/internal/ota/app.go +++ b/internal/ota/app.go @@ -12,10 +12,10 @@ const ( // DO NOT call it directly, it's not thread safe // Mutex is currently held by the caller, e.g. doUpdate func (s *State) updateApp(ctx context.Context, appUpdate *componentUpdateStatus) error { - l := s.l.With().Str("path", appUpdatePath).Logger() + logger := s.l.With().Str("path", appUpdatePath).Logger() if err := s.downloadFile(ctx, appUpdatePath, appUpdate.url, "app"); err != nil { - return s.componentUpdateError("Error downloading app update", err, &l) + return s.componentUpdateError("Error downloading app update", err, &logger) } downloadFinished := time.Now() @@ -28,7 +28,7 @@ func (s *State) updateApp(ctx context.Context, appUpdate *componentUpdateStatus) appUpdate.hash, &appUpdate.verificationProgress, ); err != nil { - return s.componentUpdateError("Error verifying app update hash", err, &l) + return s.componentUpdateError("Error verifying app update hash", err, &logger) } verifyFinished := time.Now() appUpdate.verifiedAt = verifyFinished @@ -37,7 +37,7 @@ func (s *State) updateApp(ctx context.Context, appUpdate *componentUpdateStatus) appUpdate.updateProgress = 1 s.triggerComponentUpdateState("app", appUpdate) - l.Info().Msg("App update downloaded") + logger.Info().Msg("App update downloaded") s.rebootNeeded = true diff --git a/internal/usbgadget/changeset.go b/internal/usbgadget/changeset.go index 57f5d7de6..ddb9b9a35 100644 --- a/internal/usbgadget/changeset.go +++ b/internal/usbgadget/changeset.go @@ -9,6 +9,7 @@ import ( "time" "github.com/prometheus/procfs" + "github.com/rs/zerolog" "github.com/sourcegraph/tf-dag/dag" ) @@ -194,15 +195,15 @@ func (fc *FileChange) checkIfDirIsMountPoint() error { } // GetActualState returns the actual state of the file at the given path. -func (fc *FileChange) getActualState() error { - l := defaultLogger.With().Str("path", fc.Path).Logger() +func (fc *FileChange) getActualState(l *zerolog.Logger) error { + logger := l.With().Str("path", fc.Path).Logger() fi, err := os.Lstat(fc.Path) if err != nil { if os.IsNotExist(err) { fc.ActualState = FileStateAbsent } else { - l.Warn().Err(err).Msg("failed to stat file") + logger.Warn().Err(err).Msg("failed to stat file") fc.ActualState = FileStateUnknown } return nil @@ -214,7 +215,7 @@ func (fc *FileChange) getActualState() error { // get the target of the symlink target, err := os.Readlink(fc.Path) if err != nil { - l.Warn().Err(err).Msg("failed to read symlink") + logger.Warn().Err(err).Msg("failed to read symlink") return fmt.Errorf("failed to read symlink") } // check if the target is a relative path @@ -222,7 +223,7 @@ func (fc *FileChange) getActualState() error { // make it absolute target, err = filepath.Abs(filepath.Join(filepath.Dir(fc.Path), target)) if err != nil { - l.Warn().Err(err).Msg("failed to make symlink target absolute") + logger.Warn().Err(err).Msg("failed to make symlink target absolute") return fmt.Errorf("failed to make symlink target absolute") } } @@ -237,13 +238,13 @@ func (fc *FileChange) getActualState() error { case FileStateMountedConfigFS: err := fc.checkIfDirIsMountPoint() if err != nil { - l.Warn().Err(err).Msg("failed to check if dir is mount point") + logger.Warn().Err(err).Msg("failed to check if dir is mount point") return err } case FileStateSymlinkInOrderConfigFS: - state, err := checkIfSymlinksInOrder(fc, &l) + state, err := checkIfSymlinksInOrder(fc, &logger) if err != nil { - l.Warn().Err(err).Msg("failed to check if symlinks are in order") + logger.Warn().Err(err).Msg("failed to check if symlinks are in order") return err } fc.ActualState = state @@ -252,7 +253,7 @@ func (fc *FileChange) getActualState() error { } if fi.Mode()&os.ModeDevice == os.ModeDevice { - l.Info().Msg("file is a device") + logger.Info().Msg("file is a device") return nil } @@ -262,15 +263,14 @@ func (fc *FileChange) getActualState() error { // get the content of the file content, err := os.ReadFile(fc.Path) if err != nil { - l.Warn().Err(err).Msg("failed to read file") + logger.Warn().Err(err).Msg("failed to read file") return fmt.Errorf("failed to read file") } fc.ActualContent = content return nil } - l.Warn().Interface("file_info", fi.Mode()).Bool("is_dir", fi.IsDir()).Msg("unknown file type") - + logger.Warn().Interface("file_info", fi.Mode()).Bool("is_dir", fi.IsDir()).Msg("unknown file type") return fmt.Errorf("unknown file type") } @@ -280,18 +280,16 @@ func (fc *FileChange) ResetActionResolution() { fc.changed = ChangeStateUnknown } -func (fc *FileChange) Action() FileChangeResolvedAction { +func (fc *FileChange) Action(logger *zerolog.Logger) FileChangeResolvedAction { if !fc.checked { - fc.action = fc.getFileChangeResolvedAction() + fc.action = fc.getFileChangeResolvedAction(logger) fc.checked = true } return fc.action } -func (fc *FileChange) getFileChangeResolvedAction() FileChangeResolvedAction { - l := defaultLogger.With().Str("path", fc.Path).Logger() - +func (fc *FileChange) getFileChangeResolvedAction(l *zerolog.Logger) FileChangeResolvedAction { // some actions are not needed to be checked switch fc.ExpectedState { case FileStateFileWrite: @@ -300,8 +298,10 @@ func (fc *FileChange) getFileChangeResolvedAction() FileChangeResolvedAction { return FileChangeResolvedActionTouch } + logger := l.With().Interface("expected_state", FileStateString[fc.ExpectedState]).Logger() + // get the actual state of the file - err := fc.getActualState() + err := fc.getActualState(&logger) if err != nil { return FileChangeResolvedActionDoNothing } @@ -348,7 +348,7 @@ func (fc *FileChange) getFileChangeResolvedAction() FileChangeResolvedAction { } return FileChangeResolvedActionCreateSymlink case FileStateSymlinkInOrderConfigFS: - // if the file is already a symlink, check if the target is the same + // if the file is already a symlink to configfs, check if the target is the same if fc.ActualState == FileStateSymlinkInOrderConfigFS { return FileChangeResolvedActionDoNothing } @@ -364,7 +364,7 @@ func (fc *FileChange) getFileChangeResolvedAction() FileChangeResolvedAction { } return FileChangeResolvedActionMountConfigFS default: - l.Warn().Interface("file_change", FileStateString[fc.ExpectedState]).Msg("unknown expected state") + logger.Warn().Interface("file_change", FileStateString[fc.ExpectedState]).Msg("unknown expected state") return FileChangeResolvedActionDoNothing } } @@ -387,18 +387,19 @@ func (c *ChangeSet) AddFileChange(component string, path string, expectedState F }) } -func (c *ChangeSet) ApplyChanges() error { +func (c *ChangeSet) ApplyChanges(logger *zerolog.Logger) error { r := ChangeSetResolver{ changeset: c, g: &dag.AcyclicGraph{}, - l: defaultLogger, } - return r.Apply() + return r.Apply(logger) } -func (c *ChangeSet) applyChange(change *FileChange) error { - switch change.Action() { +func (c *ChangeSet) applyChange(change *FileChange, logger *zerolog.Logger) error { + action := change.Action(logger) + + switch action { case FileChangeResolvedActionWriteFile: return os.WriteFile(change.Path, change.ExpectedContent, 0644) case FileChangeResolvedActionUpdateFile: @@ -413,7 +414,7 @@ func (c *ChangeSet) applyChange(change *FileChange) error { } return os.Symlink(string(change.ExpectedContent), change.Path) case FileChangeResolvedActionReorderSymlinks: - return recreateSymlinks(change, nil) + return recreateSymlinks(change, logger) case FileChangeResolvedActionCreateDirectory: return os.MkdirAll(change.Path, 0755) case FileChangeResolvedActionRemove: @@ -427,10 +428,10 @@ func (c *ChangeSet) applyChange(change *FileChange) error { case FileChangeResolvedActionDoNothing: return nil default: - return fmt.Errorf("unknown action: %d", change.Action()) + return fmt.Errorf("unknown action: %d", action) } } -func (c *ChangeSet) Apply() error { - return c.ApplyChanges() +func (c *ChangeSet) Apply(logger *zerolog.Logger) error { + return c.ApplyChanges(logger) } diff --git a/internal/usbgadget/changeset_resolver.go b/internal/usbgadget/changeset_resolver.go index 67812e0d6..0b419376a 100644 --- a/internal/usbgadget/changeset_resolver.go +++ b/internal/usbgadget/changeset_resolver.go @@ -9,9 +9,7 @@ import ( type ChangeSetResolver struct { changeset *ChangeSet - - l *zerolog.Logger - g *dag.AcyclicGraph + g *dag.AcyclicGraph changesMap map[string]*FileChange conditionalChangesMap map[string]*FileChange @@ -43,13 +41,13 @@ func (c *ChangeSetResolver) toOrderedChanges() error { return nil } -func (c *ChangeSetResolver) doResolveChanges(initial bool) error { +func (c *ChangeSetResolver) doResolveChanges(initial bool, logger *zerolog.Logger) error { resolvedChanges := make([]*FileChange, 0) for _, key := range c.orderedChanges { change := c.changesMap[key.(string)] if change == nil { - c.l.Error().Str("key", key.(string)).Msg("fileChange not found") + logger.Error().Str("key", key.(string)).Msg("fileChange not found") continue } @@ -57,7 +55,7 @@ func (c *ChangeSetResolver) doResolveChanges(initial bool) error { change.ResetActionResolution() } - resolvedAction := change.Action() + resolvedAction := change.Action(logger) resolvedChanges = append(resolvedChanges, change) // no need to check the triggers if there's no change @@ -89,7 +87,7 @@ func (c *ChangeSetResolver) doResolveChanges(initial bool) error { return nil } -func (c *ChangeSetResolver) resolveChanges(initial bool) error { +func (c *ChangeSetResolver) resolveChanges(initial bool, logger *zerolog.Logger) error { // get the ordered changes err := c.toOrderedChanges() if err != nil { @@ -97,39 +95,43 @@ func (c *ChangeSetResolver) resolveChanges(initial bool) error { } // resolve the changes - err = c.doResolveChanges(initial) + err = c.doResolveChanges(initial, logger) if err != nil { return err } for _, change := range c.resolvedChanges { - c.l.Trace().Str("change", change.String()).Msg("resolved change") + logger.Trace().Stringer("change", change).Msg("resolved change") } if !c.additionalResolveRequired || !initial { return nil } - return c.resolveChanges(false) + return c.resolveChanges(false, logger) } -func (c *ChangeSetResolver) applyChanges() error { +func (c *ChangeSetResolver) applyChanges(l *zerolog.Logger) error { for _, change := range c.resolvedChanges { change.ResetActionResolution() - action := change.Action() + + action := change.Action(l) actionStr := FileChangeResolvedActionString[action] - l := c.l.Info() + logger := l.With().Str("action", actionStr).Stringer("change", change).Logger() + var event *zerolog.Event if action == FileChangeResolvedActionDoNothing { - l = c.l.Trace() + event = logger.Trace() + } else { + event = logger.Debug() } - l.Str("action", actionStr).Str("change", change.String()).Msg("applying change") + event.Msg("applying change") - err := c.changeset.applyChange(change) + err := c.changeset.applyChange(change, &logger) if err != nil { if change.IgnoreErrors { - c.l.Warn().Str("change", change.String()).Err(err).Msg("ignoring error") + logger.Warn().Err(err).Msg("ignoring error") } else { return err } @@ -139,7 +141,7 @@ func (c *ChangeSetResolver) applyChanges() error { return nil } -func (c *ChangeSetResolver) GetChanges() ([]*FileChange, error) { +func (c *ChangeSetResolver) GetChanges(logger *zerolog.Logger) ([]*FileChange, error) { localChanges := c.changeset.Changes changesMap := make(map[string]*FileChange) conditionalChangesMap := make(map[string]*FileChange) @@ -175,7 +177,7 @@ func (c *ChangeSetResolver) GetChanges() ([]*FileChange, error) { c.changesMap = changesMap c.conditionalChangesMap = conditionalChangesMap - err := c.resolveChanges(true) + err := c.resolveChanges(true, logger) if err != nil { return nil, err } @@ -183,10 +185,10 @@ func (c *ChangeSetResolver) GetChanges() ([]*FileChange, error) { return c.resolvedChanges, nil } -func (c *ChangeSetResolver) Apply() error { - if _, err := c.GetChanges(); err != nil { +func (c *ChangeSetResolver) Apply(logger *zerolog.Logger) error { + if _, err := c.GetChanges(logger); err != nil { return err } - return c.applyChanges() + return c.applyChanges(logger) } diff --git a/internal/usbgadget/changeset_symlink.go b/internal/usbgadget/changeset_symlink.go index d94c75944..2f91d4d1e 100644 --- a/internal/usbgadget/changeset_symlink.go +++ b/internal/usbgadget/changeset_symlink.go @@ -23,11 +23,8 @@ func compareSymlinks(expected []symlink, actual []symlink) bool { return reflect.DeepEqual(expected, actual) } -func checkIfSymlinksInOrder(fc *FileChange, logger *zerolog.Logger) (FileState, error) { - if logger == nil { - logger = defaultLogger - } - l := logger.With().Str("path", fc.Path).Logger() +func checkIfSymlinksInOrder(fc *FileChange, l *zerolog.Logger) (FileState, error) { + logger := l.With().Str("path", fc.Path).Logger() if len(fc.ParamSymlinks) == 0 { return FileStateUnknown, fmt.Errorf("no symlinks to check") @@ -39,7 +36,7 @@ func checkIfSymlinksInOrder(fc *FileChange, logger *zerolog.Logger) (FileState, if os.IsNotExist(err) { return FileStateAbsent, nil } else { - l.Warn().Err(err).Msg("failed to stat file") + logger.Warn().Err(err).Msg("failed to stat file") return FileStateUnknown, fmt.Errorf("failed to stat file") } } @@ -85,40 +82,40 @@ func checkIfSymlinksInOrder(fc *FileChange, logger *zerolog.Logger) (FileState, return FileStateSymlinkInOrderConfigFS, nil } - l.Trace().Interface("expected", fc.ParamSymlinks).Interface("actual", symlinks).Msg("symlinks are not in order") + logger.Trace().Interface("expected", fc.ParamSymlinks).Interface("actual", symlinks).Msg("symlinks are not in order") return FileStateSymlinkNotInOrderConfigFS, nil } -func recreateSymlinks(fc *FileChange, logger *zerolog.Logger) error { - if logger == nil { - logger = defaultLogger - } +func recreateSymlinks(fc *FileChange, l *zerolog.Logger) error { // remove all symlinks files, err := os.ReadDir(fc.Path) if err != nil { return fmt.Errorf("failed to read directory") } - l := logger.With().Str("path", fc.Path).Logger() - l.Info().Msg("recreate symlinks") + logger := l.With().Str("path", fc.Path).Logger() + logger.Info().Msg("recreate symlinks") for _, file := range files { if file.Type()&os.ModeSymlink != os.ModeSymlink { continue } - l.Info().Str("name", file.Name()).Msg("remove symlink") + + logger.Info().Str("name", file.Name()).Msg("remove symlink") err := os.Remove(path.Join(fc.Path, file.Name())) if err != nil { return fmt.Errorf("failed to remove symlink") } } - l.Info().Interface("param-symlinks", fc.ParamSymlinks).Msg("create symlinks") + logger = logger.With().Interface("param-symlinks", fc.ParamSymlinks).Logger() + logger.Debug().Msg("create symlinks") // create the symlinks for _, symlink := range fc.ParamSymlinks { - l.Info().Str("name", symlink.Path).Str("target", symlink.Target).Msg("create symlink") + eachLogger := logger.With().Str("name", symlink.Path).Str("target", symlink.Target).Logger() + eachLogger.Debug().Msg("create symlink") path := symlink.Path if !filepath.IsAbs(path) { @@ -127,7 +124,7 @@ func recreateSymlinks(fc *FileChange, logger *zerolog.Logger) error { err := os.Symlink(symlink.Target, path) if err != nil { - l.Warn().Err(err).Msg("failed to create symlink") + eachLogger.Warn().Err(err).Msg("failed to create symlink") return fmt.Errorf("failed to create symlink") } } diff --git a/internal/usbgadget/config.go b/internal/usbgadget/config.go index 6d1bd391b..18dcbb379 100644 --- a/internal/usbgadget/config.go +++ b/internal/usbgadget/config.go @@ -61,18 +61,18 @@ var defaultGadgetConfig = map[string]gadgetConfigItem{ "mass_storage_lun0": massStorageLun0Config, } -func (u *UsbGadget) isGadgetConfigItemEnabled(itemKey string) bool { +func (enabledDevices *Devices) isGadgetConfigItemEnabled(itemKey string) bool { switch itemKey { case "absolute_mouse": - return u.enabledDevices.AbsoluteMouse + return enabledDevices.AbsoluteMouse case "relative_mouse": - return u.enabledDevices.RelativeMouse + return enabledDevices.RelativeMouse case "keyboard": - return u.enabledDevices.Keyboard + return enabledDevices.Keyboard case "mass_storage_base": - return u.enabledDevices.MassStorage + return enabledDevices.MassStorage case "mass_storage_lun0": - return u.enabledDevices.MassStorage + return enabledDevices.MassStorage default: return true } @@ -115,15 +115,6 @@ func (u *UsbGadget) SetGadgetDevices(devices *Devices) { u.enabledDevices = *devices } -// GetConfigPath returns the path to the config item. -func (u *UsbGadget) GetConfigPath(itemKey string) (string, error) { - item, ok := u.configMap[itemKey] - if !ok { - return "", fmt.Errorf("config item %s not found", itemKey) - } - return joinPath(u.kvmGadgetPath, item.configPath), nil -} - // GetPath returns the path to the item. func (u *UsbGadget) GetPath(itemKey string) (string, error) { item, ok := u.configMap[itemKey] @@ -136,24 +127,34 @@ func (u *UsbGadget) GetPath(itemKey string) (string, error) { // OverrideGadgetConfig overrides the gadget config for the given item and attribute. // It returns an error if the item is not found or the attribute is not found. // It returns true if the attribute is overridden, false otherwise. -func (u *UsbGadget) OverrideGadgetConfig(itemKey string, itemAttr string, value string) (error, bool) { +func (u *UsbGadget) OverrideGadgetConfig(itemKey string, itemAttr string, value string) (bool, error) { u.configLock.Lock() defer u.configLock.Unlock() + logger := u.log. + With(). + Str("itemKey", itemKey). + Str("itemAttr", itemAttr). + Str("value", value). + Logger() + // get it as a pointer _, ok := u.configMap[itemKey] if !ok { - return fmt.Errorf("config item %s not found", itemKey), false + err := fmt.Errorf("config item %s not found", itemKey) + logger.Error().Err(err).Msg("overriding gadget config") + return false, err } if u.configMap[itemKey].attrs[itemAttr] == value { - return nil, false + logger.Trace().Msg("unchanged gadget config") + return false, nil } u.configMap[itemKey].attrs[itemAttr] = value - u.log.Info().Str("itemKey", itemKey).Str("itemAttr", itemAttr).Str("value", value).Msg("overriding gadget config") + logger.Info().Msg("overriding gadget config") - return nil, true + return true, nil } func mountConfigFS(path string) error { @@ -200,12 +201,12 @@ func (u *UsbGadget) UpdateGadgetConfig() error { } func (u *UsbGadget) configureUsbGadget(resetUsb bool) error { - return u.WithTransaction(func() error { - u.tx.MountConfigFS() - u.tx.CreateConfigPath() - u.tx.WriteGadgetConfig() + return u.WithTransaction(func(u *UsbGadget, tx *UsbGadgetTransaction) error { + tx.MountConfigFS() + tx.CreateConfigPath(u.configC1Path) + tx.WriteGadgetConfig(u.kvmGadgetPath, u.configC1Path, u.udc, u.getOrderedConfigItems(), &u.enabledDevices) if resetUsb { - u.tx.RebindUsb(true) + tx.RebindUsb(u.udc, true) } return nil }) diff --git a/internal/usbgadget/config_tx.go b/internal/usbgadget/config_tx.go index df8a3d1b9..c08bc28e4 100644 --- a/internal/usbgadget/config_tx.go +++ b/internal/usbgadget/config_tx.go @@ -12,68 +12,57 @@ import ( // no os package should occur in this file type UsbGadgetTransaction struct { - c *ChangeSet - - // below are the fields that are needed to be set by the caller - log *zerolog.Logger - udc string - dwc3Path string - kvmGadgetPath string - configC1Path string - orderedConfigItems orderedGadgetConfigItems - isGadgetConfigItemEnabled func(key string) bool - + c *ChangeSet + log *zerolog.Logger reorderSymlinkChanges *RequestedFileChange } -func (u *UsbGadget) newUsbGadgetTransaction(lock bool) error { - if lock { - u.txLock.Lock() - defer u.txLock.Unlock() - } - - if u.tx != nil { - return fmt.Errorf("transaction already exists") - } - +func (u *UsbGadget) newUsbGadgetTransaction() *UsbGadgetTransaction { tx := &UsbGadgetTransaction{ - c: &ChangeSet{}, - log: u.log, - udc: u.udc, - dwc3Path: dwc3Path, - kvmGadgetPath: u.kvmGadgetPath, - configC1Path: u.configC1Path, - orderedConfigItems: u.getOrderedConfigItems(), - isGadgetConfigItemEnabled: u.isGadgetConfigItemEnabled, + c: &ChangeSet{}, + log: u.log, } - u.tx = tx - - return nil + return tx } -func (u *UsbGadget) WithTransaction(fn func() error) error { +func (u *UsbGadget) WithTransaction(fn func(u2 *UsbGadget, tx *UsbGadgetTransaction) error) error { u.txLock.Lock() defer u.txLock.Unlock() - err := u.newUsbGadgetTransaction(false) - if err != nil { - u.log.Error().Err(err).Msg("failed to create transaction") + logger := u.log.With().Str("udc", u.udc).Logger() + logger.Info().Msg("starting USB gadget transaction") + + tx := u.newUsbGadgetTransaction() + if err := fn(u, tx); err != nil { + logger.Error().Err(err).Msg("transaction failed") return err } - if err := fn(); err != nil { - u.log.Error().Err(err).Msg("transaction failed") - return err + + err := tx.Commit() + logger.Trace().Err(err).Msg("committed transaction") + return err +} + +func (u *UsbGadget) getOrderedConfigItems() orderedGadgetConfigItems { + items := make([]gadgetConfigItemWithKey, 0) + for key, item := range u.configMap { + items = append(items, gadgetConfigItemWithKey{key, item}) } - result := u.tx.Commit() - u.tx = nil - return result + sort.SliceStable(items, func(i, j int) bool { + return items[i].item.order < items[j].item.order + }) + + return items } func (tx *UsbGadgetTransaction) addFileChange(component string, change RequestedFileChange) string { change.Component = component tx.c.AddFileChangeStruct(change) + logger := tx.log + logger.Trace().Interface("change", change).Msg("add change") + key := change.Key if key == "" { key = change.Path @@ -101,28 +90,16 @@ func (tx *UsbGadgetTransaction) removeFile(component string, path string, descri func (tx *UsbGadgetTransaction) Commit() error { tx.addFileChange("gadget-finalize", *tx.reorderSymlinkChanges) - err := tx.c.Apply() + logger := tx.log + err := tx.c.Apply(logger) if err != nil { - tx.log.Error().Err(err).Msg("failed to update usbgadget configuration") + logger.Error().Err(err).Msg("failed to update usbgadget configuration") return err } - tx.log.Info().Msg("usbgadget configuration updated") + logger.Info().Msg("usbgadget configuration updated") return nil } -func (u *UsbGadget) getOrderedConfigItems() orderedGadgetConfigItems { - items := make([]gadgetConfigItemWithKey, 0) - for key, item := range u.configMap { - items = append(items, gadgetConfigItemWithKey{key, item}) - } - - sort.Slice(items, func(i, j int) bool { - return items[i].item.order < items[j].item.order - }) - - return items -} - func (tx *UsbGadgetTransaction) MountConfigFS() { tx.addFileChange("gadget", RequestedFileChange{ Path: configFSPath, @@ -131,84 +108,71 @@ func (tx *UsbGadgetTransaction) MountConfigFS() { }) } -func (tx *UsbGadgetTransaction) CreateConfigPath() { +func (tx *UsbGadgetTransaction) CreateConfigPath(configC1Path string) { tx.mkdirAll( "gadget", - tx.configC1Path, + configC1Path, "create config path", []string{configFSPath}, ) } -func (tx *UsbGadgetTransaction) WriteGadgetConfig() { +func (tx *UsbGadgetTransaction) WriteGadgetConfig(kvmGadgetPath string, configC1Path string, udc string, orderedConfigItems orderedGadgetConfigItems, enabledDevices *Devices) { // create kvm gadget path tx.mkdirAll( "gadget", - tx.kvmGadgetPath, + kvmGadgetPath, "create kvm gadget path", - []string{tx.configC1Path}, + []string{configC1Path}, ) deps := make([]string, 0) - deps = append(deps, tx.kvmGadgetPath) + deps = append(deps, kvmGadgetPath) - for _, val := range tx.orderedConfigItems { + for _, val := range orderedConfigItems { key := val.key item := val.item // check if the item is enabled in the config - if !tx.isGadgetConfigItemEnabled(key) { - tx.DisableGadgetItemConfig(item) + if !enabledDevices.isGadgetConfigItemEnabled(key) { + tx.DisableGadgetItemConfig(configC1Path, item) continue } - deps = tx.writeGadgetItemConfig(item, deps) - } - - tx.WriteUDC() -} -func (tx *UsbGadgetTransaction) getDisableKeys() []string { - disableKeys := make([]string, 0) - for _, item := range tx.orderedConfigItems { - if !tx.isGadgetConfigItemEnabled(item.key) { - continue - } - if item.item.configPath == nil || item.item.configAttrs != nil { - continue - } - - disableKeys = append(disableKeys, fmt.Sprintf("disable-%s", item.item.device)) + deps = tx.writeGadgetItemConfig(kvmGadgetPath, configC1Path, item, deps) } - return disableKeys + + tx.WriteUDC(kvmGadgetPath, udc) } -func (tx *UsbGadgetTransaction) DisableGadgetItemConfig(item gadgetConfigItem) { +func (tx *UsbGadgetTransaction) DisableGadgetItemConfig(configC1Path string, item gadgetConfigItem) { // remove symlink if exists if item.configPath == nil { return } - configPath := joinPath(tx.configC1Path, item.configPath) + configPath := joinPath(configC1Path, item.configPath) _ = tx.removeFile("gadget", configPath, "remove symlink: disable gadget config") } -func (tx *UsbGadgetTransaction) writeGadgetItemConfig(item gadgetConfigItem, deps []string) []string { +func (tx *UsbGadgetTransaction) writeGadgetItemConfig(kvmGadgetPath string, configC1Path string, item gadgetConfigItem, deps []string) []string { component := item.device // create directory for the item files := make([]string, 0) files = append(files, deps...) - gadgetItemPath := joinPath(tx.kvmGadgetPath, item.path) - if gadgetItemPath != tx.kvmGadgetPath { + gadgetItemPath := joinPath(kvmGadgetPath, item.path) + if gadgetItemPath != kvmGadgetPath { gadgetItemDir := tx.mkdirAll(component, gadgetItemPath, "create gadget item directory", files) files = append(files, gadgetItemDir) } beforeChange := make([]string, 0) disableGadgetItemKey := fmt.Sprintf("disable-%s", item.device) + if item.configPath != nil && item.configAttrs == nil { - beforeChange = append(beforeChange, tx.getDisableKeys()...) + beforeChange = append(beforeChange, disableGadgetItemKey) } if len(item.attrs) > 0 { @@ -245,8 +209,8 @@ func (tx *UsbGadgetTransaction) writeGadgetItemConfig(item gadgetConfigItem, dep // create config directory if configAttrs are set if len(item.configAttrs) > 0 { - configItemPath := joinPath(tx.configC1Path, item.configPath) - if configItemPath != tx.configC1Path { + configItemPath := joinPath(configC1Path, item.configPath) + if configItemPath != configC1Path { configItemDir := tx.mkdirAll(component, configItemPath, "create config item directory", files) files = append(files, configItemDir) } @@ -260,7 +224,7 @@ func (tx *UsbGadgetTransaction) writeGadgetItemConfig(item gadgetConfigItem, dep // create symlink if configPath is set if item.configPath != nil && item.configAttrs == nil { - configPath := joinPath(tx.configC1Path, item.configPath) + configPath := joinPath(configC1Path, item.configPath) // the change will be only applied by `beforeChange` tx.addFileChange(component, RequestedFileChange{ @@ -271,7 +235,7 @@ func (tx *UsbGadgetTransaction) writeGadgetItemConfig(item gadgetConfigItem, dep Description: "remove symlink", }) - tx.addReorderSymlinkChange(configPath, gadgetItemPath, files) + tx.addReorderSymlinkChange(configC1Path, configPath, gadgetItemPath, files) } return files @@ -294,17 +258,19 @@ func (tx *UsbGadgetTransaction) writeGadgetAttrs(basePath string, attrs gadgetAt return files } -func (tx *UsbGadgetTransaction) addReorderSymlinkChange(path string, target string, deps []string) { - tx.log.Trace().Str("path", path).Str("target", target).Msg("add reorder symlink change") +func (tx *UsbGadgetTransaction) addReorderSymlinkChange(configC1Path string, path string, target string, deps []string) { + logger := tx.log + logger.Trace().Str("configC1Path", configC1Path).Str("path", path).Str("target", target).Msg("add reorder symlink change") if tx.reorderSymlinkChanges == nil { tx.reorderSymlinkChanges = &RequestedFileChange{ - Component: "gadget-finalize", - Key: "reorder-symlinks", - Path: tx.configC1Path, - ExpectedState: FileStateSymlinkInOrderConfigFS, - Description: "order symlinks", - ParamSymlinks: []symlink{}, + Component: "gadget-finalize", + Key: "reorder-symlinks", + Path: configC1Path, + ExpectedState: FileStateSymlinkInOrderConfigFS, + ExpectedContent: []byte(target), + Description: "order symlinks", + ParamSymlinks: []symlink{}, } } @@ -315,35 +281,35 @@ func (tx *UsbGadgetTransaction) addReorderSymlinkChange(path string, target stri }) } -func (tx *UsbGadgetTransaction) WriteUDC() { +func (tx *UsbGadgetTransaction) WriteUDC(kvmGadgetPath string, udc string) { // bound the gadget to a UDC (USB Device Controller) - path := path.Join(tx.kvmGadgetPath, "UDC") + path := path.Join(kvmGadgetPath, "UDC") tx.addFileChange("udc", RequestedFileChange{ Key: "udc", Path: path, ExpectedState: FileStateFileContentMatch, - ExpectedContent: []byte(tx.udc), + ExpectedContent: []byte(udc), DependsOn: []string{"reorder-symlinks"}, Description: "write UDC", }) } -func (tx *UsbGadgetTransaction) RebindUsb(ignoreUnbindError bool) { +func (tx *UsbGadgetTransaction) RebindUsb(udc string, ignoreUnbindError bool) { // remove the gadget from the UDC tx.addFileChange("udc", RequestedFileChange{ - Path: path.Join(tx.dwc3Path, "unbind"), + Path: path.Join(dwc3Path, "unbind"), ExpectedState: FileStateFileWrite, - ExpectedContent: []byte(tx.udc), + ExpectedContent: []byte(udc), Description: "unbind UDC", DependsOn: []string{"udc"}, IgnoreErrors: ignoreUnbindError, }) // bind the gadget to the UDC tx.addFileChange("udc", RequestedFileChange{ - Path: path.Join(tx.dwc3Path, "bind"), + Path: path.Join(dwc3Path, "bind"), ExpectedState: FileStateFileWrite, - ExpectedContent: []byte(tx.udc), + ExpectedContent: []byte(udc), Description: "bind UDC", - DependsOn: []string{path.Join(tx.dwc3Path, "unbind")}, + DependsOn: []string{path.Join(dwc3Path, "unbind")}, }) } diff --git a/internal/usbgadget/hid_keyboard.go b/internal/usbgadget/hid_keyboard.go index 274f0b6aa..f43d5c950 100644 --- a/internal/usbgadget/hid_keyboard.go +++ b/internal/usbgadget/hid_keyboard.go @@ -3,11 +3,14 @@ package usbgadget import ( "bytes" "context" + "errors" "fmt" "os" + "slices" "sync" "time" + "github.com/jetkvm/kvm/internal/logging" "github.com/rs/xid" "github.com/rs/zerolog" ) @@ -95,6 +98,10 @@ type KeyboardState struct { raw byte } +func (s KeyboardState) MarshalZerologObject(e *zerolog.Event) { + e.Hex("State", []byte{s.raw}) +} + // Byte returns the raw byte representation of the keyboard state. func (k *KeyboardState) Byte() byte { return k.raw @@ -117,19 +124,25 @@ func (u *UsbGadget) updateKeyboardState(state byte) { u.keyboardStateLock.Lock() defer u.keyboardStateLock.Unlock() + logger := u.log.With().Hex("state", []byte{state}).Logger() + if state&^ValidKeyboardLedMasks != 0 { - u.log.Warn().Uint8("state", state).Msg("ignoring invalid bits") - return + logger.Warn().Msg("ignoring invalid bits") + state &= ValidKeyboardLedMasks } + logger = logger.With().Hex("old_state", []byte{u.keyboardState}).Logger() + if u.keyboardState == state { + logger.Trace().Msg("unchanged keyboardState") return } - u.log.Trace().Uint8("old", u.keyboardState).Uint8("new", state).Msg("keyboardState updated") + u.keyboardState = state + logger.Trace().Msg("keyboardState updated") - if u.onKeyboardStateChange != nil { - (*u.onKeyboardStateChange)(getKeyboardState(state)) + if cb := u.onKeyboardStateChange; cb != nil { + go (*cb)(getKeyboardState(state)) // this enqueues to the outgoing hidrpc queue via usb.go → currentSession.reportHidRPCKeyboardLedState(...) } } @@ -151,6 +164,17 @@ func (u *UsbGadget) GetKeysDownState() KeysDownState { return u.keysDownState } +func (u *UsbGadget) ResetRollover() { + u.keyboardStateLock.Lock() + defer u.keyboardStateLock.Unlock() + + if u.keysDownState.Keys[0] == hidErrorRollOver { + for i := range u.keysDownState.Keys { + u.keysDownState.Keys[i] = 0 + } + } +} + func (u *UsbGadget) SetOnKeysDownChange(f func(state KeysDownState)) { u.onKeysDownChange = &f } @@ -162,183 +186,171 @@ func (u *UsbGadget) SetOnKeepAliveReset(f func()) { // DefaultAutoReleaseDuration is the default duration for auto-release of a key. const DefaultAutoReleaseDuration = 100 * time.Millisecond -func (u *UsbGadget) scheduleAutoRelease(key byte) { +func (u *UsbGadget) DelayAutoReleaseWithDuration(resetDuration time.Duration) { + logger := u.log.With().Dur("reset_duration", resetDuration).Logger() + u.kbdAutoReleaseLock.Lock() - defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease scheduled") + defer logging.UnlockWithTraceLog(&u.kbdAutoReleaseLock, &logger, "auto-release delayed") - if u.kbdAutoReleaseTimers[key] != nil { - u.kbdAutoReleaseTimers[key].Stop() + for _, timer := range u.kbdAutoReleaseTimers { + if timer != nil { + timer.Reset(resetDuration) + } } - - // TODO: make this configurable - // We currently hardcode the duration to 100ms - // However, it should be the same as the duration of the keep-alive reset called baseExtension. - u.kbdAutoReleaseTimers[key] = time.AfterFunc(100*time.Millisecond, func() { - u.performAutoRelease(key) - }) } -func (u *UsbGadget) cancelAutoRelease(key byte) { +// note: lock must be freed by caller +func (u *UsbGadget) popAutoReleaseTimer(key byte) bool { u.kbdAutoReleaseLock.Lock() - defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease cancelled") + timer, ok := u.kbdAutoReleaseTimers[key] + if ok { + if timer != nil { + timer.Stop() + } - if timer := u.kbdAutoReleaseTimers[key]; timer != nil { - timer.Stop() - u.kbdAutoReleaseTimers[key] = nil delete(u.kbdAutoReleaseTimers, key) - - // Reset keep-alive timing when key is released - if u.onKeepAliveReset != nil { - (*u.onKeepAliveReset)() - } } + + return ok } -func (u *UsbGadget) DelayAutoReleaseWithDuration(resetDuration time.Duration) { - u.kbdAutoReleaseLock.Lock() - defer unlockWithLog(&u.kbdAutoReleaseLock, u.log, "autoRelease delayed") +func (u *UsbGadget) scheduleAutoRelease(key byte) { + logger := u.log.With().Hex("key", []byte{key}).Logger() - u.log.Debug().Dur("reset_duration", resetDuration).Msg("delaying auto-release with dynamic duration") + _ = u.popAutoReleaseTimer(key) // discard previous timer if any + defer logging.UnlockWithTraceLog(&u.kbdAutoReleaseLock, &logger, "auto-release scheduled") - for _, timer := range u.kbdAutoReleaseTimers { - if timer != nil { - timer.Reset(resetDuration) - } - } + start := time.Now() + u.kbdAutoReleaseTimers[key] = time.AfterFunc(DefaultAutoReleaseDuration, func() { + logger.Info().Dur("elapsed", time.Since(start)).Msg("fired after") + u.performAutoRelease(key) + }) } -func (u *UsbGadget) performAutoRelease(key byte) { - u.kbdAutoReleaseLock.Lock() +func (u *UsbGadget) cancelAutoRelease(key byte) { + logger := u.log.With().Hex("key", []byte{key}).Logger() - if u.kbdAutoReleaseTimers[key] == nil { - u.log.Warn().Uint8("key", key).Msg("autoRelease timer not found") - u.kbdAutoReleaseLock.Unlock() - return + ok := u.popAutoReleaseTimer(key) + defer logging.UnlockWithTraceLog(&u.kbdAutoReleaseLock, &logger, "auto-release cancelled") + + if ok { + // Reset keep-alive timing when key is actually released + if cb := u.onKeepAliveReset; cb != nil { + go (*cb)() + } } +} - u.kbdAutoReleaseTimers[key].Stop() - u.kbdAutoReleaseTimers[key] = nil - delete(u.kbdAutoReleaseTimers, key) - u.kbdAutoReleaseLock.Unlock() +func (u *UsbGadget) performAutoRelease(key byte) { + logger := u.log.With().Hex("key", []byte{key}).Logger() - // Skip if already released - state := u.GetKeysDownState() - alreadyReleased := true + ok := u.popAutoReleaseTimer(key) + defer logging.UnlockWithTraceLog(&u.kbdAutoReleaseLock, &logger, "auto-released") - for i := range state.Keys { - if state.Keys[i] == key { - alreadyReleased = false - break + if ok { + // Skip if already released + state := u.GetKeysDownState() + if !slices.Contains(state.Keys, key) { + logger.Trace().Msg("already released") + return } - } - if alreadyReleased { - return - } - - _, err := u.keypressReport(key, false) - if err != nil { - u.log.Warn().Uint8("key", key).Msg("failed to release key") + _, err := u.keypressReport(&logger, key, false) + if err != nil { + logger.Warn().Msg("failed to release key") + } } } func (u *UsbGadget) listenKeyboardEvents() { - var path string - if u.keyboardHidFile != nil { - path = u.keyboardHidFile.Name() - } - l := u.log.With().Str("listener", "keyboardEvents").Str("path", path).Logger() - l.Trace().Msg("starting") - - go func() { - buf := make([]byte, hidReadBufferSize) - for { - select { - case <-u.keyboardStateCtx.Done(): - l.Info().Msg("context done") + buf := make([]byte, hidReadBufferSize) + for { + select { + case <-u.keyboardStateCtx.Done(): + u.log.Info().Msg("context done") + return + default: + if u.keyboardHidFile == nil { + u.log.Warn().Msg("keyboardHidFile is nil, stopping keyboard event listener") return - default: - l.Trace().Msg("reading from keyboard for LED state changes") - if u.keyboardHidFile == nil { - u.logWithSuppression("keyboardHidFileNil", 100, &l, nil, "keyboardHidFile is nil") - // show the error every 100 times to avoid spamming the logs - time.Sleep(time.Second) - continue - } - // reset the counter - u.resetLogSuppressionCounter("keyboardHidFileNil") + } - n, err := u.keyboardHidFile.Read(buf) - if err != nil { - u.logWithSuppression("keyboardHidFileRead", 100, &l, err, "failed to read") - continue + logger := u.log.With().Str("path", u.keyboardHidFile.Name()).Str("listener", "keyboardEvents").Logger() + logger.Trace().Msg("reading from keyboard for LED state changes") + n, err := u.keyboardHidFile.Read(buf) + if err != nil { + if errors.Is(err, os.ErrClosed) { + logger.Warn().Msg("keyboard file is closed, stopping keyboard event listener") + return + } else if exceeded := u.logWithSuppression("keyboardHidFileRead", 10, &logger, err, "failed to read"); exceeded { + logger.Error().Msg("too many errors reading the keyboard file, stopping keyboard event listener") + return } + } else { u.resetLogSuppressionCounter("keyboardHidFileRead") + } - l.Trace().Int("n", n).Uints8("buf", buf).Msg("got data from keyboard") - if n != 1 { - l.Trace().Int("n", n).Msg("expected 1 byte, got") - continue - } - u.updateKeyboardState(buf[0]) + logger.Trace().Int("n", n).Hex("buf", buf).Msg("got data from keyboard") + if n != 1 { + logger.Warn().Int("n", n).Msg("expected 1 byte") + continue } + u.updateKeyboardState(buf[0]) } - }() + } } -func (u *UsbGadget) openKeyboardHidFile() error { +func (u *UsbGadget) openKeyboardHidFileUnderMutex() error { if u.keyboardHidFile != nil { return nil } - var err error - u.keyboardHidFile, err = os.OpenFile("/dev/hidg0", os.O_RDWR, 0666) - if err != nil { - return fmt.Errorf("failed to open hidg0: %w", err) - } - if u.keyboardStateCancel != nil { u.keyboardStateCancel() + u.keyboardStateCancel = nil + } + + if keyboardFile, err := os.OpenFile("/dev/hidg0", os.O_RDWR, 0666); err == nil { + u.keyboardHidFile = keyboardFile + } else { + return fmt.Errorf("failed to open keyboard on hidg0: %w", err) } u.keyboardStateCtx, u.keyboardStateCancel = context.WithCancel(context.Background()) - u.listenKeyboardEvents() + go u.listenKeyboardEvents() return nil } +var keyboardHidFileLock sync.Mutex + func (u *UsbGadget) OpenKeyboardHidFile() error { - return u.openKeyboardHidFile() -} + keyboardHidFileLock.Lock() + defer keyboardHidFileLock.Unlock() -var keyboardWriteHidFileLock sync.Mutex + return u.openKeyboardHidFileUnderMutex() +} func (u *UsbGadget) keyboardWriteHidFile(modifier byte, keys []byte) error { - keyboardWriteHidFileLock.Lock() - defer keyboardWriteHidFileLock.Unlock() - if err := u.openKeyboardHidFile(); err != nil { + keyboardHidFileLock.Lock() + defer keyboardHidFileLock.Unlock() + + if err := u.openKeyboardHidFileUnderMutex(); err != nil { return err } _, err := u.writeWithTimeout(u.keyboardHidFile, append([]byte{modifier, 0x00}, keys[:hidKeyBufferSize]...)) if err != nil { - u.logWithSuppression("keyboardWriteHidFile", 100, u.log, err, "failed to write to hidg0") - u.keyboardHidFile.Close() + if cerr := u.keyboardHidFile.Close(); cerr != nil { + u.log.Error().Err(cerr).Msg("failed to close keyboard HID file after write error") + } u.keyboardHidFile = nil return err } - u.resetLogSuppressionCounter("keyboardWriteHidFile") return nil } func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) KeysDownState { - // if we just reported an error roll over, we should clear the keys - if keys[0] == hidErrorRollOver { - for i := range keys { - keys[i] = 0 - } - } - state := KeysDownState{ Modifier: modifier, Keys: []byte(keys[:]), @@ -355,8 +367,8 @@ func (u *UsbGadget) UpdateKeysDown(modifier byte, keys []byte) KeysDownState { u.keysDownState = state u.keyboardStateLock.Unlock() - if u.onKeysDownChange != nil { - (*u.onKeysDownChange)(state) // this enques to the outgoing hidrpc queue via usb.go → currentSession.enqueueKeysDownState(...) + if cb := u.onKeysDownChange; cb != nil { + go (*cb)(state) // this enqueues to the outgoing hidrpc queue via usb.go → currentSession.enqueueKeysDownState(...) } return state } @@ -373,10 +385,11 @@ func (u *UsbGadget) KeyboardReport(modifier byte, keys []byte) error { err := u.keyboardWriteHidFile(modifier, keys) if err != nil { - u.log.Warn().Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keyboard report to hidg0") + u.log.Warn().Err(err).Uint8("modifier", modifier).Uints8("keys", keys).Msg("Could not write keyboard report to hidg0") } u.UpdateKeysDown(modifier, keys) + defer u.ResetRollover() return err } @@ -417,13 +430,13 @@ var KeyCodeToMaskMap = map[byte]byte{ RightSuper: ModifierMaskRightSuper, } -func (u *UsbGadget) keypressReport(key byte, press bool) (KeysDownState, error) { +func (u *UsbGadget) keypressReport(l *zerolog.Logger, key byte, press bool) (KeysDownState, error) { defer u.resetUserInputTime() + logger := l.With().Hex("key", []byte{key}).Bool("press", press).Logger() - l := u.log.With().Uint8("key", key).Bool("press", press).Logger() - if l.GetLevel() <= zerolog.DebugLevel { + if logger.GetLevel() <= zerolog.DebugLevel { requestID := xid.New() - l = l.With().Str("requestID", requestID.String()).Logger() + logger = logger.With().Stringer("requestID", requestID).Logger() } // IMPORTANT: This code parallels the logic in the kernel's hid-gadget driver @@ -432,7 +445,7 @@ func (u *UsbGadget) keypressReport(key byte, press bool) (KeysDownState, error) // in the client/browser-side code in useKeyboard.ts so make sure to keep // them in sync. var state = u.GetKeysDownState() - l.Trace().Interface("state", state).Msg("got keys down state") + logger.Trace().Object("state", state).Msg("got keys down state") modifier := state.Modifier keys := append([]byte(nil), state.Keys...) @@ -473,14 +486,15 @@ func (u *UsbGadget) keypressReport(key byte, press bool) (KeysDownState, error) // If we reach here it means we didn't find an empty slot or the key in the buffer if overrun { if press { - l.Error().Msg("keyboard buffer overflow, key not added") + logger.Warn().Msg("keyboard buffer overflow, key not added") // Fill all key slots with ErrorRollOver (0x01) to indicate overflow for i := range keys { keys[i] = hidErrorRollOver } + defer u.ResetRollover() // after reporting rollover, we reset the buffer } else { // If we are releasing a key, and we didn't find it in a slot, who cares? - l.Warn().Msg("key not found in buffer, nothing to release") + logger.Debug().Msg("key not found in buffer, nothing to release") } } } @@ -490,13 +504,15 @@ func (u *UsbGadget) keypressReport(key byte, press bool) (KeysDownState, error) } func (u *UsbGadget) KeypressReport(key byte, press bool) error { - state, err := u.keypressReport(key, press) + logger := u.log.With().Hex("key", []byte{key}).Bool("press", press).Logger() + + state, err := u.keypressReport(&logger, key, press) if err != nil { - u.log.Warn().Uint8("key", key).Bool("press", press).Msg("failed to report key") + logger.Warn().Err(err).Msg("failed to report key") } - isRolledOver := state.Keys[0] == hidErrorRollOver + wasRolledOver := state.Keys[0] == hidErrorRollOver - if isRolledOver { + if wasRolledOver { u.cancelAutoRelease(key) } else if press { u.scheduleAutoRelease(key) diff --git a/internal/usbgadget/hid_mouse_absolute.go b/internal/usbgadget/hid_mouse_absolute.go index 374844f10..bc6734814 100644 --- a/internal/usbgadget/hid_mouse_absolute.go +++ b/internal/usbgadget/hid_mouse_absolute.go @@ -68,20 +68,19 @@ var absoluteMouseCombinedReportDesc = []byte{ func (u *UsbGadget) absMouseWriteHidFile(data []byte) error { if u.absMouseHidFile == nil { var err error - u.absMouseHidFile, err = os.OpenFile("/dev/hidg1", os.O_RDWR, 0666) - if err != nil { + if u.absMouseHidFile, err = os.OpenFile("/dev/hidg1", os.O_RDWR, 0666); err != nil { return fmt.Errorf("failed to open hidg1: %w", err) } } _, err := u.writeWithTimeout(u.absMouseHidFile, data) if err != nil { - u.logWithSuppression("absMouseWriteHidFile", 100, u.log, err, "failed to write to hidg1") - u.absMouseHidFile.Close() + if cerr := u.absMouseHidFile.Close(); cerr != nil { + u.log.Error().Err(cerr).Msg("failed to close absolute mouse HID file after write error") + } u.absMouseHidFile = nil return err } - u.resetLogSuppressionCounter("absMouseWriteHidFile") return nil } @@ -97,12 +96,9 @@ func (u *UsbGadget) AbsMouseReport(x int, y int, buttons uint8) error { byte(y), // Y Low Byte byte(y >> 8), // Y High Byte }) - if err != nil { - return err - } u.resetUserInputTime() - return nil + return err } func (u *UsbGadget) AbsMouseWheelReport(wheelY int8) error { diff --git a/internal/usbgadget/hid_mouse_relative.go b/internal/usbgadget/hid_mouse_relative.go index 070db6e89..fe654441c 100644 --- a/internal/usbgadget/hid_mouse_relative.go +++ b/internal/usbgadget/hid_mouse_relative.go @@ -58,20 +58,19 @@ var relativeMouseCombinedReportDesc = []byte{ func (u *UsbGadget) relMouseWriteHidFile(data []byte) error { if u.relMouseHidFile == nil { var err error - u.relMouseHidFile, err = os.OpenFile("/dev/hidg2", os.O_RDWR, 0666) - if err != nil { - return fmt.Errorf("failed to open hidg1: %w", err) + if u.relMouseHidFile, err = os.OpenFile("/dev/hidg2", os.O_RDWR, 0666); err != nil { + return fmt.Errorf("failed to open hidg2: %w", err) } } _, err := u.writeWithTimeout(u.relMouseHidFile, data) if err != nil { - u.logWithSuppression("relMouseWriteHidFile", 100, u.log, err, "failed to write to hidg2") - u.relMouseHidFile.Close() + if cerr := u.relMouseHidFile.Close(); cerr != nil { + u.log.Error().Err(cerr).Msg("failed to close relative mouse HID file after write error") + } u.relMouseHidFile = nil return err } - u.resetLogSuppressionCounter("relMouseWriteHidFile") return nil } @@ -85,10 +84,7 @@ func (u *UsbGadget) RelMouseReport(mx int8, my int8, buttons uint8) error { byte(my), // Y 0, // Wheel }) - if err != nil { - return err - } u.resetUserInputTime() - return nil + return err } diff --git a/internal/usbgadget/log.go b/internal/usbgadget/log.go index f979f6c1f..ee9461584 100644 --- a/internal/usbgadget/log.go +++ b/internal/usbgadget/log.go @@ -2,16 +2,42 @@ package usbgadget import ( "errors" + + "github.com/rs/zerolog" ) +func (u *UsbGadget) logWithSuppression(counterName string, every int, logger *zerolog.Logger, err error, msg string, args ...interface{}) bool { + u.logSuppressionLock.Lock() + counter, ok := u.logSuppressionCounter[counterName] // returns 0, false if not found + counter++ + u.logSuppressionCounter[counterName] = counter + u.logSuppressionLock.Unlock() + + // log if it's the first time, and then every N times thereafter + if !ok || counter%every == 0 { + logger.Error().Str("counterName", counterName).Int("counter", counter).Err(err).Msgf(msg, args...) + return ok // return whether we've just exceeded the every interval + } + return false +} + +func (u *UsbGadget) resetLogSuppressionCounter(counterName string) { + u.logSuppressionLock.Lock() + delete(u.logSuppressionCounter, counterName) + u.logSuppressionLock.Unlock() +} + func (u *UsbGadget) logWarn(msg string, err error) error { if err == nil { err = errors.New(msg) } + + u.log.Warn().Err(err).Msg(msg) + if u.strictMode { return err } - u.log.Warn().Err(err).Msg(msg) + return nil } @@ -19,9 +45,12 @@ func (u *UsbGadget) logError(msg string, err error) error { if err == nil { err = errors.New(msg) } + + u.log.Error().Err(err).Msg(msg) + if u.strictMode { return err } - u.log.Error().Err(err).Msg(msg) + return nil } diff --git a/internal/usbgadget/udc.go b/internal/usbgadget/udc.go index 4b7fbe361..f8aed097d 100644 --- a/internal/usbgadget/udc.go +++ b/internal/usbgadget/udc.go @@ -58,7 +58,7 @@ func (u *UsbGadget) GetUsbState() (state string) { if os.IsNotExist(err) { return "not attached" } else { - u.log.Trace().Err(err).Msg("failed to read usb state") + u.log.Warn().Err(err).Msg("failed to read usb state") } return "unknown" } diff --git a/internal/usbgadget/usbgadget.go b/internal/usbgadget/usbgadget.go index f01ae09d4..dcfd0059d 100644 --- a/internal/usbgadget/usbgadget.go +++ b/internal/usbgadget/usbgadget.go @@ -9,7 +9,8 @@ import ( "sync" "time" - "github.com/jetkvm/kvm/internal/logging" + "github.com/jetkvm/kvm/internal/utils" + "github.com/rs/zerolog" ) @@ -29,9 +30,8 @@ type Config struct { SerialNumber string `json:"serial_number"` Manufacturer string `json:"manufacturer"` Product string `json:"product"` - - strictMode bool // when it's enabled, all warnings will be converted to errors - isEmpty bool + isEmpty bool + strictMode bool // when it's enabled, all warnings will be converted to errors } var defaultUsbGadgetDevices = Devices{ @@ -42,8 +42,13 @@ var defaultUsbGadgetDevices = Devices{ } type KeysDownState struct { - Modifier byte `json:"modifier"` - Keys ByteSlice `json:"keys"` + Modifier byte `json:"modifier"` + Keys utils.ByteSlice `json:"keys"` +} + +func (k KeysDownState) MarshalZerologObject(e *zerolog.Event) { + e.Uint8("modifier", k.Modifier) + e.Object("keys", k.Keys) } // UsbGadget is a struct that represents a USB gadget. @@ -77,13 +82,12 @@ type UsbGadget struct { enabledDevices Devices - strictMode bool // only intended for testing for now + strictMode bool // only intended for testing absMouseAccumulatedWheelY float64 lastUserInput time.Time - tx *UsbGadgetTransaction txLock sync.Mutex onKeyboardStateChange *func(state KeyboardState) @@ -99,18 +103,12 @@ type UsbGadget struct { const configFSPath = "/sys/kernel/config" const gadgetPath = "/sys/kernel/config/usb_gadget" -var defaultLogger = logging.GetSubsystemLogger("usbgadget") - // NewUsbGadget creates a new UsbGadget. func NewUsbGadget(name string, enabledDevices *Devices, config *Config, logger *zerolog.Logger) *UsbGadget { return newUsbGadget(name, defaultGadgetConfig, enabledDevices, config, logger) } func newUsbGadget(name string, configMap map[string]gadgetConfigItem, enabledDevices *Devices, config *Config, logger *zerolog.Logger) *UsbGadget { - if logger == nil { - logger = defaultLogger - } - if enabledDevices == nil { enabledDevices = &defaultUsbGadgetDevices } @@ -141,15 +139,14 @@ func newUsbGadget(name string, configMap map[string]gadgetConfigItem, enabledDev lastUserInput: time.Now(), log: logger, - strictMode: config.strictMode, - + strictMode: config.strictMode, logSuppressionCounter: make(map[string]int), absMouseAccumulatedWheelY: 0, } + if err := g.Init(); err != nil { logger.Error().Err(err).Msg("failed to init USB gadget") - return nil } return g @@ -160,6 +157,7 @@ func (u *UsbGadget) Close() error { // Cancel keyboard state context if u.keyboardStateCancel != nil { u.keyboardStateCancel() + u.keyboardStateCancel = nil } // Stop auto-release timer diff --git a/internal/usbgadget/utils.go b/internal/usbgadget/utils.go index 85bf1579d..20d853db0 100644 --- a/internal/usbgadget/utils.go +++ b/internal/usbgadget/utils.go @@ -2,44 +2,15 @@ package usbgadget import ( "bytes" - "encoding/json" "errors" "fmt" "os" "path/filepath" "strconv" "strings" - "sync" "time" - - "github.com/rs/zerolog" ) -type ByteSlice []byte - -func (s ByteSlice) MarshalJSON() ([]byte, error) { - vals := make([]int, len(s)) - for i, v := range s { - vals[i] = int(v) - } - return json.Marshal(vals) -} - -func (s *ByteSlice) UnmarshalJSON(data []byte) error { - var vals []int - if err := json.Unmarshal(data, &vals); err != nil { - return err - } - *s = make([]byte, len(vals)) - for i, v := range vals { - if v < 0 || v > 255 { - return fmt.Errorf("value %d out of byte range", v) - } - (*s)[i] = byte(v) - } - return nil -} - func joinPath(basePath string, paths []string) string { pathArr := append([]string{basePath}, paths...) return filepath.Join(pathArr...) @@ -71,14 +42,21 @@ func hexToOctal(hex string) (string, error) { return octal, nil } +func compareIgnoreTrailingLF(shorter []byte, longer []byte) bool { + shorterLen := len(shorter) + longerLen := len(longer) + return shorterLen+1 == longerLen && + bytes.Equal(longer[:shorterLen], shorter) && + longer[shorterLen] == 0x0a +} + func compareFileContent(oldContent []byte, newContent []byte, looserMatch bool) bool { - if bytes.Equal(oldContent, newContent) { + if len(oldContent) == len(newContent) && bytes.Equal(oldContent, newContent) { return true } - if len(oldContent) == len(newContent)+1 && - bytes.Equal(oldContent[:len(newContent)], newContent) && - oldContent[len(newContent)] == 10 { + // allow for a trailing newline difference if the one did have one and the other does NOT + if compareIgnoreTrailingLF(oldContent, newContent) || compareIgnoreTrailingLF(newContent, oldContent) { return true } @@ -112,67 +90,31 @@ func compareFileContent(oldContent []byte, newContent []byte, looserMatch bool) } func (u *UsbGadget) writeWithTimeout(file *os.File, data []byte) (n int, err error) { + fileName := file.Name() + if err := file.SetWriteDeadline(time.Now().Add(hidWriteTimeout)); err != nil { return -1, err } n, err = file.Write(data) if err == nil { - return + u.resetLogSuppressionCounter("writeWithTimeout_" + fileName) + return n, nil } - u.log.Trace(). - Str("file", file.Name()). - Bytes("data", data). - Err(err). - Msg("write failed") - - if errors.Is(err, os.ErrDeadlineExceeded) { - u.logWithSuppression( - fmt.Sprintf("writeWithTimeout_%s", file.Name()), - 1000, - u.log, - err, - "write timed out: %s", - file.Name(), - ) - err = nil - } + logger := u.log.With().Str("file", fileName).Bytes("data", data).Logger() + logger.Trace().Err(err).Msg("write failed") - return -} - -func (u *UsbGadget) logWithSuppression(counterName string, every int, logger *zerolog.Logger, err error, msg string, args ...any) { - u.logSuppressionLock.Lock() - defer u.logSuppressionLock.Unlock() - - if _, ok := u.logSuppressionCounter[counterName]; !ok { - u.logSuppressionCounter[counterName] = 0 - } else { - u.logSuppressionCounter[counterName]++ - } - - l := logger.With().Int("counter", u.logSuppressionCounter[counterName]).Logger() - - if u.logSuppressionCounter[counterName]%every == 0 { - if err != nil { - l.Error().Err(err).Msgf(msg, args...) - } else { - l.Error().Msgf(msg, args...) + if errors.Is(err, os.ErrClosed) { + logger.Warn().Msg("file is closed, stopping writes") + return 0, err + } else if errors.Is(err, os.ErrDeadlineExceeded) { + if exceeded := u.logWithSuppression("writeWithTimeout_"+fileName, 10, &logger, err, "write timed out"); exceeded { + logger.Error().Msg("too many errors writing to the file, stopping writes") + return 0, err } + return 0, nil } -} - -func (u *UsbGadget) resetLogSuppressionCounter(counterName string) { - u.logSuppressionLock.Lock() - defer u.logSuppressionLock.Unlock() - - if _, ok := u.logSuppressionCounter[counterName]; !ok { - u.logSuppressionCounter[counterName] = 0 - } -} -func unlockWithLog(lock *sync.Mutex, logger *zerolog.Logger, msg string, args ...any) { - logger.Trace().Msgf(msg, args...) - lock.Unlock() + return n, err } diff --git a/internal/utils/byte_slice.go b/internal/utils/byte_slice.go new file mode 100644 index 000000000..406d3916a --- /dev/null +++ b/internal/utils/byte_slice.go @@ -0,0 +1,37 @@ +package utils + +import ( + "encoding/json" + "fmt" + + "github.com/rs/zerolog" +) + +type ByteSlice []byte + +func (s ByteSlice) MarshalJSON() ([]byte, error) { + vals := make([]int, len(s)) + for i, v := range s { + vals[i] = int(v) + } + return json.Marshal(vals) +} + +func (s *ByteSlice) UnmarshalJSON(data []byte) error { + var vals []int + if err := json.Unmarshal(data, &vals); err != nil { + return err + } + *s = make([]byte, len(vals)) + for i, v := range vals { + if v < 0 || v > 255 { + return fmt.Errorf("value %d out of byte range", v) + } + (*s)[i] = byte(v) + } + return nil +} + +func (k ByteSlice) MarshalZerologObject(e *zerolog.Event) { + e.Hex("bytes", k) +} diff --git a/internal/websecure/ed25519_test.go b/internal/websecure/ed25519_test.go index 0753be0d4..b2124c4b8 100644 --- a/internal/websecure/ed25519_test.go +++ b/internal/websecure/ed25519_test.go @@ -35,7 +35,7 @@ func TestMain(m *testing.M) { certSigner = NewSelfSigner( certStore, - nil, + &defaultLogger, "ci.jetkvm.com", "JetKVM", "JetKVM", diff --git a/internal/websecure/store.go b/internal/websecure/store.go index ea7911c48..c7cd84873 100644 --- a/internal/websecure/store.go +++ b/internal/websecure/store.go @@ -21,10 +21,6 @@ type CertStore struct { } func NewCertStore(storePath string, log *zerolog.Logger) *CertStore { - if log == nil { - log = &defaultLogger - } - return &CertStore{ certificates: make(map[string]*tls.Certificate), certLock: &sync.Mutex{}, diff --git a/jsonrpc.go b/jsonrpc.go index b401ac593..0be862dc7 100644 --- a/jsonrpc.go +++ b/jsonrpc.go @@ -53,54 +53,75 @@ type BacklightSettings struct { OffAfter int `json:"off_after"` } +func getOpenChannel(session *Session) (*webrtc.DataChannel, error) { + if session == nil { + return nil, errors.New("session is nil") + } + + rpcChannel := session.RPCChannel + + if rpcChannel == nil || rpcChannel.ReadyState() != webrtc.DataChannelStateOpen { + return nil, errors.New("RPC channel is nil or not open") + } + + return rpcChannel, nil +} + func writeJSONRPCResponse(response JSONRPCResponse, session *Session) { + logger := jsonRpcLogger.With().Interface("response", response).Logger() + responseBytes, err := json.Marshal(response) - if err != nil { - jsonRpcLogger.Warn().Err(err).Msg("Error marshalling JSONRPC response") - return - } - err = session.RPCChannel.SendText(string(responseBytes)) - if err != nil { - jsonRpcLogger.Warn().Err(err).Msg("Error sending JSONRPC response") - return + if err == nil { + if logger.GetLevel() <= zerolog.TraceLevel { + logger = logger.With().Object("responseBytes", utils.ByteSlice(responseBytes)).Logger() + } + + var rpcChannel *webrtc.DataChannel + if rpcChannel, err = getOpenChannel(session); err == nil { + logger.Trace().Msg("sending JSONRPC response") + if err = rpcChannel.SendText(string(responseBytes)); err == nil { + return + } + } } + + logger.Warn().Err(err).Msg("Error sending JSONRPC response") } func writeJSONRPCEvent(event string, params any, session *Session) { + logger := jsonRpcLogger.With().Str("event", event).Interface("params", params).Logger() + request := JSONRPCEvent{ JSONRPC: "2.0", Method: event, Params: params, } - requestBytes, err := json.Marshal(request) - if err != nil { - jsonRpcLogger.Warn().Err(err).Msg("Error marshalling JSONRPC event") - return - } - if session == nil || session.RPCChannel == nil { - jsonRpcLogger.Info().Msg("RPC channel not available") - return - } - requestString := string(requestBytes) - scopedLogger := jsonRpcLogger.With(). - Str("data", requestString). - Logger() - - scopedLogger.Trace().Msg("sending JSONRPC event") + requestBytes, err := json.Marshal(request) + if err == nil { + if logger.GetLevel() <= zerolog.TraceLevel { + logger = logger.With().Object("requestBytes", utils.ByteSlice(requestBytes)).Logger() + } - err = session.RPCChannel.SendText(requestString) - if err != nil { - scopedLogger.Warn().Err(err).Msg("error sending JSONRPC event") - return + var rpcChannel *webrtc.DataChannel + if rpcChannel, err = getOpenChannel(session); err == nil { + logger.Trace().Msg("sending JSONRPC event") + if err = rpcChannel.SendText(string(requestBytes)); err == nil { + return + } + } } + + logger.Warn().Err(err).Msg("Error sending JSONRPC event") } func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { + logger := jsonRpcLogger.With().Int("length", len(message.Data)).Logger() + var request JSONRPCRequest err := json.Unmarshal(message.Data, &request) if err != nil { - jsonRpcLogger.Warn(). + logger.Warn(). Str("data", string(message.Data)). Err(err). Msg("Error unmarshalling JSONRPC request") @@ -117,12 +138,14 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } - scopedLogger := jsonRpcLogger.With(). + logger = logger. + With(). Str("method", request.Method). Interface("params", request.Params). - Interface("id", request.ID).Logger() + Interface("id", request.ID). + Logger() - scopedLogger.Trace().Msg("Received RPC request") + logger.Trace().Msg("Received RPC request") t := time.Now() handler, ok := rpcHandlers[request.Method] @@ -139,9 +162,9 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } - result, err := callRPCHandler(scopedLogger, handler, request.Params) + result, err := callRPCHandler(&logger, handler, request.Params) if err != nil { - scopedLogger.Error().Err(err).Msg("Error calling RPC handler") + logger.Error().Err(err).Msg("Error calling RPC handler") errorResponse := JSONRPCResponse{ JSONRPC: "2.0", Error: map[string]any{ @@ -155,7 +178,7 @@ func onRPCMessage(message webrtc.DataChannelMessage, session *Session) { return } - scopedLogger.Trace().Dur("duration", time.Since(t)).Interface("result", result).Msg("RPC handler returned") + logger.Trace().Dur("duration", time.Since(t)).Interface("result", result).Msg("RPC handler returned") response := JSONRPCResponse{ JSONRPC: "2.0", @@ -174,7 +197,7 @@ func rpcGetDeviceID() (string, error) { } func rpcReboot(force bool) error { - logger.Info().Msg("Got reboot request via RPC") + jsonRpcLogger.Debug().Bool("force", force).Msg("Got reboot request via RPC") return hwReboot(force, nil, 0) } @@ -183,7 +206,7 @@ func rpcGetStreamQualityFactor() (float64, error) { } func rpcSetStreamQualityFactor(factor float64) error { - logger.Info().Float64("factor", factor).Msg("Setting stream quality factor") + jsonRpcLogger.Debug().Float64("factor", factor).Msg("Setting stream quality factor") err := nativeInstance.VideoSetQualityFactor(factor) if err != nil { return err @@ -201,6 +224,7 @@ func rpcGetAutoUpdateState() (bool, error) { } func rpcSetAutoUpdateState(enabled bool) (bool, error) { + jsonRpcLogger.Debug().Bool("enabled", enabled).Msg("setting auto-update state") config.AutoUpdateEnabled = enabled if err := SaveConfig(); err != nil { return config.AutoUpdateEnabled, fmt.Errorf("failed to save config: %w", err) @@ -217,10 +241,11 @@ func rpcGetEDID() (string, error) { } func rpcSetEDID(edid string) error { + logger := jsonRpcLogger.With().Str("edid", edid).Logger() if edid == "" { - logger.Info().Msg("Restoring EDID to default") + logger.Debug().Msg("Restoring EDID to default") } else { - logger.Info().Str("edid", edid).Msg("Setting EDID") + logger.Debug().Str("edid", edid).Msg("Setting EDID") } err := nativeInstance.VideoSetEDID(edid) if err != nil { @@ -238,18 +263,18 @@ func rpcGetVideoLogStatus() (string, error) { } func rpcSetDisplayRotation(params DisplayRotationSettings) error { + jsonRpcLogger.Debug().Interface("params", params).Msg("setting display rotation") + currentRotation := config.DisplayRotation if currentRotation == params.Rotation { return nil } - err := config.SetDisplayRotation(params.Rotation) - if err != nil { + if err := config.SetDisplayRotation(params.Rotation); err != nil { return err } - _, err = nativeInstance.DisplaySetRotation(config.GetDisplayRotation()) - if err != nil { + if _, err := nativeInstance.DisplaySetRotation(config.GetDisplayRotation()); err != nil { return err } @@ -257,7 +282,7 @@ func rpcSetDisplayRotation(params DisplayRotationSettings) error { return fmt.Errorf("failed to save config: %w", err) } - return err + return nil } func rpcGetDisplayRotation() (*DisplayRotationSettings, error) { @@ -267,24 +292,25 @@ func rpcGetDisplayRotation() (*DisplayRotationSettings, error) { } func rpcSetBacklightSettings(params BacklightSettings) error { - blConfig := params + logger := jsonRpcLogger.With().Interface("params", params).Logger() + logger.Debug().Msg("setting backlight settings") // NOTE: by default, the frontend limits the brightness to 64, as that's what the device originally shipped with. - if blConfig.MaxBrightness > 255 || blConfig.MaxBrightness < 0 { + if params.MaxBrightness > 255 || params.MaxBrightness < 0 { return fmt.Errorf("maxBrightness must be between 0 and 255") } - if blConfig.DimAfter < 0 { + if params.DimAfter < 0 { return fmt.Errorf("dimAfter must be a positive integer") } - if blConfig.OffAfter < 0 { + if params.OffAfter < 0 { return fmt.Errorf("offAfter must be a positive integer") } - config.DisplayMaxBrightness = blConfig.MaxBrightness - config.DisplayDimAfterSec = blConfig.DimAfter - config.DisplayOffAfterSec = blConfig.OffAfter + config.DisplayMaxBrightness = params.MaxBrightness + config.DisplayDimAfterSec = params.DimAfter + config.DisplayOffAfterSec = params.OffAfter if err := SaveConfig(); err != nil { return fmt.Errorf("failed to save config: %w", err) @@ -327,23 +353,21 @@ type SSHKeyState struct { } func rpcGetDevModeState() (DevModeState, error) { - devModeEnabled := false - if _, err := os.Stat(devModeFile); err != nil { - if !os.IsNotExist(err) { - return DevModeState{}, fmt.Errorf("error checking dev mode file: %w", err) - } + if _, err := os.Stat(devModeFile); err != nil && !os.IsNotExist(err) { + return DevModeState{}, fmt.Errorf("error checking dev mode file: %w", err) } else { - devModeEnabled = true + return DevModeState{ + Enabled: err == nil, + }, nil } - - return DevModeState{ - Enabled: devModeEnabled, - }, nil } func rpcSetDevModeState(enabled bool) error { + logger := jsonRpcLogger.With().Bool("enabled", enabled).Logger() + logger.Debug().Msg("setting dev mode state") + if enabled { - if _, err := os.Stat(devModeFile); os.IsNotExist(err) { + if _, err := os.Stat(devModeFile); err != nil && os.IsNotExist(err) { if err := os.MkdirAll(filepath.Dir(devModeFile), 0755); err != nil { return fmt.Errorf("failed to create directory for devmode file: %w", err) } @@ -367,10 +391,11 @@ func rpcSetDevModeState(enabled bool) error { } } + logger.Info().Msg("restarting dropbear service to apply dev mode changes") cmd := exec.Command("dropbear.sh") output, err := cmd.CombinedOutput() if err != nil { - logger.Warn().Err(err).Bytes("output", output).Msg("Failed to start/stop SSH") + logger.Warn().Err(err).Object("output", utils.ByteSlice(output)).Msg("Failed to start/stop SSH") return fmt.Errorf("failed to start/stop SSH, you may need to reboot for changes to take effect") } @@ -379,15 +404,16 @@ func rpcSetDevModeState(enabled bool) error { func rpcGetSSHKeyState() (string, error) { keyData, err := os.ReadFile(sshKeyFile) - if err != nil { - if !os.IsNotExist(err) { - return "", fmt.Errorf("error reading SSH key file: %w", err) - } + if err != nil && !os.IsNotExist(err) { + return "", fmt.Errorf("error reading SSH key file: %w", err) } return string(keyData), nil } func rpcSetSSHKeyState(sshKey string) error { + // Log the action but avoid logging the actual key for security reasons + jsonRpcLogger.Debug().Msg("setting SSH key") + if sshKey == "" { // Remove SSH key file if empty string is provided if err := os.Remove(sshKeyFile); err != nil && !os.IsNotExist(err) { @@ -419,6 +445,8 @@ func rpcGetTLSState() TLSState { } func rpcSetTLSState(state TLSState) error { + jsonRpcLogger.Debug().Interface("state", state).Msg("setting TLS state") + err := setTLSState(state) if err != nil { return fmt.Errorf("failed to set TLS state: %w", err) @@ -437,7 +465,7 @@ type RPCHandler struct { } // call the handler but recover from a panic to ensure our RPC thread doesn't collapse on malformed calls -func callRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[string]any) (result any, err error) { +func callRPCHandler(logger *zerolog.Logger, handler RPCHandler, params map[string]any) (result any, err error) { // Use defer to recover from a panic defer func() { if r := recover(); r != nil { @@ -447,6 +475,7 @@ func callRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[string } else { err = fmt.Errorf("panic occurred: %v", r) } + logger.Error().Err(err).Msg("Recovered from panic in RPC handler") } }() @@ -455,7 +484,7 @@ func callRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[string return result, err // do not combine these two lines into one, as it breaks the above defer function's setting of err } -func riskyCallRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[string]any) (any, error) { +func riskyCallRPCHandler(logger *zerolog.Logger, handler RPCHandler, params map[string]any) (any, error) { handlerValue := reflect.ValueOf(handler.Func) handlerType := handlerValue.Type() @@ -467,9 +496,7 @@ func riskyCallRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[s paramNames := handler.Params // Get the parameter names from the RPCHandler if len(paramNames) != numParams { - err := fmt.Errorf("mismatch between handler parameters (%d) and defined parameter names (%d)", numParams, len(paramNames)) - logger.Error().Strs("paramNames", paramNames).Err(err).Msg("Cannot call RPC handler") - return nil, err + return nil, fmt.Errorf("mismatch between handler parameters (%d) and defined parameter names (%d)", numParams, len(paramNames)) } args := make([]reflect.Value, numParams) @@ -479,9 +506,7 @@ func riskyCallRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[s paramName := paramNames[i] paramValue, ok := params[paramName] if !ok { - err := fmt.Errorf("missing parameter: %s", paramName) - logger.Error().Err(err).Msg("Cannot marshal arguments for RPC handler") - return nil, err + return nil, fmt.Errorf("missing parameter: %s", paramName) } convertedValue := reflect.ValueOf(paramValue) @@ -530,7 +555,7 @@ func riskyCallRPCHandler(logger zerolog.Logger, handler RPCHandler, params map[s } } - logger.Trace().Msg("Calling RPC handler") + logger.Trace().Interface("args", args).Msg("Calling RPC handler") results := handlerValue.Call(args) if len(results) == 0 { @@ -567,7 +592,8 @@ func asError(value reflect.Value) (bool, error) { } func rpcSetMassStorageMode(mode string) (string, error) { - logger.Info().Str("mode", mode).Msg("Setting mass storage mode") + logger := jsonRpcLogger.With().Str("mode", mode).Logger() + logger.Debug().Msg("Setting mass storage mode") var cdrom bool switch mode { case "cdrom": @@ -575,18 +601,18 @@ func rpcSetMassStorageMode(mode string) (string, error) { case "file": cdrom = false default: - logger.Info().Str("mode", mode).Msg("Invalid mode provided") + logger.Info().Msg("Invalid mode provided") return "", fmt.Errorf("invalid mode: %s", mode) } - logger.Info().Str("mode", mode).Msg("Setting mass storage mode") + logger.Info().Msg("Setting mass storage mode") err := setMassStorageMode(cdrom) if err != nil { return "", fmt.Errorf("failed to set mass storage mode: %w", err) } - logger.Info().Str("mode", mode).Msg("Mass storage mode set") + logger.Info().Msg("Mass storage mode set") // Get the updated mode after setting return rpcGetMassStorageMode() @@ -614,6 +640,7 @@ func rpcGetUsbEmulationState() (bool, error) { } func rpcSetUsbEmulationState(enabled bool) error { + jsonRpcLogger.Debug().Bool("enabled", enabled).Msg("setting USB emulation state") if enabled { return gadget.BindUDC() } else { @@ -627,6 +654,7 @@ func rpcGetUsbConfig() (usbgadget.Config, error) { } func rpcSetUsbConfig(usbConfig usbgadget.Config) error { + jsonRpcLogger.Debug().Interface("usbConfig", usbConfig).Msg("setting USB emulation state") LoadConfig() config.UsbConfig = &usbConfig gadget.SetGadgetConfig(config.UsbConfig) @@ -650,13 +678,14 @@ func rpcSetWakeOnLanDevices(params SetWakeOnLanDevicesParams) error { } func rpcResetConfig() error { + jsonRpcLogger.Debug().Msg("resetting configuration to default") defaultConfig := getDefaultConfig() config = &defaultConfig if err := SaveConfig(); err != nil { return fmt.Errorf("failed to reset config: %w", err) } - logger.Info().Msg("Configuration reset to default") + jsonRpcLogger.Info().Msg("Configuration reset to default") return nil } @@ -673,7 +702,7 @@ func rpcGetDCPowerState() (DCPowerState, error) { } func rpcSetDCPowerState(enabled bool) error { - logger.Info().Bool("enabled", enabled).Msg("Setting DC power state") + powerLogger.Info().Bool("enabled", enabled).Msg("Setting DC power state") err := setDCPowerState(enabled) if err != nil { return fmt.Errorf("failed to set DC power state: %w", err) @@ -682,7 +711,7 @@ func rpcSetDCPowerState(enabled bool) error { } func rpcSetDCRestoreState(state int) error { - logger.Info().Int("state", state).Msg("Setting DC restore state") + powerLogger.Info().Int("state", state).Msg("Setting DC restore state") err := setDCRestoreState(state) if err != nil { return fmt.Errorf("failed to set DC restore state: %w", err) @@ -718,7 +747,8 @@ func rpcSetActiveExtension(extensionId string) error { } func rpcSetATXPowerAction(action string) error { - logger.Debug().Str("action", action).Msg("Executing ATX power action") + logger := powerLogger.With().Str("action", action).Logger() + logger.Debug().Msg("Executing ATX power action") switch action { case "power-short": logger.Debug().Msg("Simulating short power button press") @@ -790,6 +820,9 @@ func rpcGetSerialSettings() (SerialSettings, error) { var serialPortMode = defaultMode func rpcSetSerialSettings(settings SerialSettings) error { + logger := jsonRpcLogger.With().Interface("settings", settings).Logger() + logger.Debug().Msg("setting serial settings") + baudRate, err := strconv.Atoi(settings.BaudRate) if err != nil { return fmt.Errorf("invalid baud rate: %v", err) @@ -833,7 +866,9 @@ func rpcSetSerialSettings(settings SerialSettings) error { Parity: parity, } - _ = port.SetMode(serialPortMode) + if err := port.SetMode(serialPortMode); err != nil { + logger.Error().Err(err).Interface("serialPortMode", serialPortMode).Msg("setting port mode") + } return nil } @@ -853,12 +888,14 @@ func updateUsbRelatedConfig() error { } func rpcSetUsbDevices(usbDevices usbgadget.Devices) error { + jsonRpcLogger.Debug().Interface("usbDevices", usbDevices).Msg("setting USB devices") config.UsbDevices = &usbDevices gadget.SetGadgetDevices(config.UsbDevices) return updateUsbRelatedConfig() } func rpcSetUsbDeviceState(device string, enabled bool) error { + jsonRpcLogger.Debug().Str("device", device).Bool("enabled", enabled).Msg("setting USB device state") switch device { case "absoluteMouse": config.UsbDevices.AbsoluteMouse = enabled @@ -876,6 +913,8 @@ func rpcSetUsbDeviceState(device string, enabled bool) error { } func rpcSetCloudUrl(apiUrl string, appUrl string) error { + jsonRpcLogger.Debug().Str("apiUrl", apiUrl).Str("appUrl", appUrl).Msg("setting cloud urls") + currentCloudURL := config.CloudURL config.CloudURL = apiUrl config.CloudAppURL = appUrl @@ -900,6 +939,8 @@ func rpcGetKeyboardLayout() (string, error) { } func rpcSetKeyboardLayout(layout string) error { + jsonRpcLogger.Debug().Str("layout", layout).Msg("setting keyboard layout") + config.KeyboardLayout = layout if err := SaveConfig(); err != nil { return fmt.Errorf("failed to save config: %w", err) @@ -919,6 +960,7 @@ type KeyboardMacrosParams struct { } func setKeyboardMacros(params KeyboardMacrosParams) (any, error) { + jsonRpcLogger.Debug().Interface("params", params).Msg("setting keyboard macros") if params.Macros == nil { return nil, fmt.Errorf("missing or invalid macros parameter") } @@ -1005,6 +1047,7 @@ func rpcGetLocalLoopbackOnly() (bool, error) { } func rpcSetLocalLoopbackOnly(enabled bool) error { + jsonRpcLogger.Debug().Bool("enabled", enabled).Msg("setting local loopback only mode") // Check if the setting is actually changing if config.LocalLoopbackOnly == enabled { return nil @@ -1025,15 +1068,16 @@ var ( ) // cancelKeyboardMacro cancels any ongoing keyboard macro execution -func cancelKeyboardMacro() { +func cancelKeyboardMacro() error { keyboardMacroLock.Lock() defer keyboardMacroLock.Unlock() if keyboardMacroCancel != nil { keyboardMacroCancel() - logger.Info().Msg("canceled keyboard macro") keyboardMacroCancel = nil + jsonRpcLogger.Info().Msg("canceled keyboard macro") } + return nil } func setKeyboardMacroCancel(cancel context.CancelFunc) { @@ -1044,7 +1088,8 @@ func setKeyboardMacroCancel(cancel context.CancelFunc) { } func rpcExecuteKeyboardMacro(macro []hidrpc.KeyboardMacroStep) error { - cancelKeyboardMacro() + jsonRpcLogger.Debug().Int("steps", len(macro)).Msg("executing keyboard macro") + _ = cancelKeyboardMacro() ctx, cancel := context.WithCancel(context.Background()) setKeyboardMacroCancel(cancel) @@ -1055,7 +1100,7 @@ func rpcExecuteKeyboardMacro(macro []hidrpc.KeyboardMacroStep) error { } if currentSession != nil { - currentSession.reportHidRPCKeyboardMacroState(s) + go currentSession.reportHidRPCKeyboardMacroState(s) } err := rpcDoExecuteKeyboardMacro(ctx, macro) @@ -1064,14 +1109,15 @@ func rpcExecuteKeyboardMacro(macro []hidrpc.KeyboardMacroStep) error { s.State = false if currentSession != nil { - currentSession.reportHidRPCKeyboardMacroState(s) + go currentSession.reportHidRPCKeyboardMacroState(s) } return err } -func rpcCancelKeyboardMacro() { - cancelKeyboardMacro() +func rpcCancelKeyboardMacro() error { + jsonRpcLogger.Debug().Msg("cancelling keyboard macro") + return cancelKeyboardMacro() } var keyboardClearStateKeys = make([]byte, hidrpc.HidKeyBufferSize) @@ -1081,7 +1127,8 @@ func isClearKeyStep(step hidrpc.KeyboardMacroStep) bool { } func rpcDoExecuteKeyboardMacro(ctx context.Context, macro []hidrpc.KeyboardMacroStep) error { - logger.Debug().Interface("macro", macro).Msg("Executing keyboard macro") + logger := jsonRpcLogger.With().Interface("macro", macro).Logger() + logger.Debug().Msg("Executing keyboard macro") for i, step := range macro { delay := time.Duration(step.Delay) * time.Millisecond @@ -1108,7 +1155,7 @@ func rpcDoExecuteKeyboardMacro(ctx context.Context, macro []hidrpc.KeyboardMacro logger.Warn().Err(err).Msg("failed to reset keyboard state") } - logger.Debug().Int("step", i).Msg("Keyboard macro cancelled during sleep") + logger.Debug().Int("step", i).Msg("Keyboard macro cancelled during step") return ctx.Err() } } diff --git a/log.go b/log.go index 9cd9188e6..00aa1e02a 100644 --- a/log.go +++ b/log.go @@ -2,13 +2,8 @@ package kvm import ( "github.com/jetkvm/kvm/internal/logging" - "github.com/rs/zerolog" ) -func ErrorfL(l *zerolog.Logger, format string, err error, args ...any) error { - return logging.ErrorfL(l, format, err, args...) -} - var ( logger = logging.GetSubsystemLogger("jetkvm") failsafeLogger = logging.GetSubsystemLogger("failsafe") @@ -29,6 +24,7 @@ var ( displayLogger = logging.GetSubsystemLogger("display") wolLogger = logging.GetSubsystemLogger("wol") usbLogger = logging.GetSubsystemLogger("usb") + powerLogger = logging.GetSubsystemLogger("dcpower") // external components ginLogger = logging.GetSubsystemLogger("gin") ) diff --git a/main.go b/main.go index 83d337d7c..f79a419cf 100644 --- a/main.go +++ b/main.go @@ -11,7 +11,6 @@ import ( "github.com/erikdubbelboer/gspt" "github.com/gwatts/rootcerts" - "github.com/jetkvm/kvm/internal/ota" ) var appCtx context.Context @@ -33,7 +32,7 @@ func Main() { checkFailsafeReason() if failsafeModeActive { procPrefix = "jetkvm: [app+failsafe]" - logger.Warn().Str("reason", failsafeModeReason).Msg("failsafe mode activated") + logger.Warn().Str("subcomponent", "failsafe").Str("reason", failsafeModeReason).Msg("failsafe mode activated") } LoadConfig() @@ -104,50 +103,12 @@ func Main() { if err := initImagesFolder(); err != nil { logger.Warn().Err(err).Msg("failed to init images folder") } + initJiggler() // start video sleep mode timer startVideoSleepModeTicker() - go func() { - // wait for 15 minutes before starting auto-update checks - // this is to avoid interfering with initial setup processes - // and to ensure the system is stable before checking for updates - time.Sleep(15 * time.Minute) - - for { - logger.Info().Bool("auto_update_enabled", config.AutoUpdateEnabled).Msg("auto-update check") - if !config.AutoUpdateEnabled { - logger.Debug().Msg("auto-update disabled") - time.Sleep(5 * time.Minute) // we'll check if auto-updates are enabled in five minutes - continue - } - - if currentSession != nil { - logger.Debug().Msg("skipping update since a session is active") - time.Sleep(1 * time.Minute) - continue - } - - if isTimeSyncNeeded() || !timeSync.IsSyncSuccess() { - logger.Debug().Msg("system time is not synced, will retry in 30 seconds") - time.Sleep(30 * time.Second) - continue - } - - includePreRelease := config.IncludePreRelease - err = otaState.TryUpdate(context.Background(), ota.UpdateParams{ - DeviceID: GetDeviceID(), - IncludePreRelease: includePreRelease, - }) - if err != nil { - logger.Warn().Err(err).Msg("failed to auto update") - } - - time.Sleep(1 * time.Hour) - } - }() - //go RunFuseServer() go RunWebServer() @@ -159,26 +120,25 @@ func Main() { // As websocket client already checks if the cloud token is set, we can start it here. go RunWebsocketClient() - initPublicIPState() + initPublicIPState() initSerialPort() setProcTitle("ready") + logger.Log().Msg("JetKVM ready") + + go RunAutoUpdateCheck() sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) <-sigs - logger.Log().Msg("JetKVM Shutting Down") - //if fuseServer != nil { - // err := setMassStorageImage(" ") - // if err != nil { - // logger.Infof("Failed to unmount mass storage image: %v", err) - // } - // err = fuseServer.Unmount() - // if err != nil { - // logger.Infof("Failed to unmount fuse: %v", err) - // } + if err := rpcUnmountImage(); err != nil { + logger.Warn().Err(err).Msg("failed to eject virtual media on shutdown") + } + gadget.Close() + + logger.Log().Msg("JetKVM Shutting Down") // os.Exit(0) } diff --git a/ota.go b/ota.go index ef7f9c21a..260316d71 100644 --- a/ota.go +++ b/ota.go @@ -6,6 +6,7 @@ import ( "net/http" "os" "strings" + "time" "github.com/Masterminds/semver/v3" "github.com/google/uuid" @@ -17,29 +18,31 @@ var builtAppVersion = "0.1.0+dev" var otaState *ota.State func initOta() { - otaState = ota.NewState(ota.Options{ - Logger: otaLogger, - ReleaseAPIEndpoint: config.GetUpdateAPIURL(), - GetHTTPClient: func() ota.HttpClient { - transport := http.DefaultTransport.(*http.Transport).Clone() - transport.Proxy = config.NetworkConfig.GetTransportProxyFunc() - - client := &http.Client{ - Transport: transport, - } - return client - }, - GetLocalVersion: GetLocalVersion, - HwReboot: hwReboot, - ResetConfig: rpcResetConfig, - SetAutoUpdate: rpcSetAutoUpdateState, - OnStateUpdate: func(state *ota.RPCState) { - triggerOTAStateUpdate(state) - }, - OnProgressUpdate: func(progress float32) { - writeJSONRPCEvent("otaProgress", progress, currentSession) + otaState = ota.NewState( + ota.Options{ + Logger: otaLogger, + ReleaseAPIEndpoint: config.GetUpdateAPIURL(), + GetHTTPClient: func() ota.HttpClient { + transport := http.DefaultTransport.(*http.Transport).Clone() + transport.Proxy = config.NetworkConfig.GetTransportProxyFunc() + + client := &http.Client{ + Transport: transport, + } + return client + }, + GetLocalVersion: GetLocalVersion, + HwReboot: hwReboot, + ResetConfig: rpcResetConfig, + SetAutoUpdate: rpcSetAutoUpdateState, + OnStateUpdate: func(state *ota.RPCState) { + triggerOTAStateUpdate(state) + }, + OnProgressUpdate: func(progress float32) { + writeJSONRPCEvent("otaProgress", progress, currentSession) + }, }, - }) + ) } func triggerOTAStateUpdate(state *ota.RPCState) { @@ -50,6 +53,8 @@ func triggerOTAStateUpdate(state *ota.RPCState) { if state == nil { state = otaState.ToRPCState() } + otaLogger.Trace().Interface("state", state).Msg("Reporting OTA state") + writeJSONRPCEvent("otaState", state, currentSession) }() } @@ -80,11 +85,13 @@ func GetLocalVersion() (systemVersion *semver.Version, appVersion *semver.Versio } func getUpdateStatus(includePreRelease bool) (*ota.UpdateStatus, error) { - updateStatus, err := otaState.GetUpdateStatus(context.Background(), ota.UpdateParams{ - DeviceID: GetDeviceID(), - IncludePreRelease: includePreRelease, - RequestID: uuid.New().String(), - }) + updateStatus, err := otaState.GetUpdateStatus( + context.Background(), + ota.UpdateParams{ + DeviceID: GetDeviceID(), + IncludePreRelease: includePreRelease, + RequestID: uuid.New().String(), + }) // to ensure backwards compatibility, // if there's an error, we won't return an error, but we will set the error field @@ -160,6 +167,7 @@ func rpcCheckUpdateComponents(params updateParams, includePreRelease bool) (*ota IncludePreRelease: includePreRelease, Components: params.Components, } + info, err := otaState.GetUpdateStatus(context.Background(), updateParams) if err != nil { return nil, fmt.Errorf("failed to check update: %w", err) @@ -183,3 +191,48 @@ func rpcTryUpdateComponents(params updateParams, includePreRelease bool, resetCo }() return nil } + +func RunAutoUpdateCheck() { + // initially wait for 15 minutes before starting auto-update checks + // to avoid interfering with initial setup processes and to ensure + // the system is stable before checking for updates + ticker := time.NewTicker(15 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-appCtx.Done(): + return + case <-ticker.C: + logger := otaLogger + logger.Info().Bool("auto_update_enabled", config.AutoUpdateEnabled).Msg("auto-update check") + + if !config.AutoUpdateEnabled { + logger.Info().Msg("auto-update disabled, waiting 5 minutes") + ticker.Reset(5 * time.Minute) // we'll check if auto-updates are enabled in five minutes + continue + } + + if currentSession != nil { + logger.Info().Msg("skipping update since a session is active for one minute") + ticker.Reset(1 * time.Minute) + continue + } + + if isTimeSyncNeeded() || !timeSync.IsSyncSuccess() { + logger.Info().Msg("system time is not synced, will retry in 30 seconds") + ticker.Reset(30 * time.Second) + continue + } + + if err := otaState.TryUpdate(context.Background(), ota.UpdateParams{ + DeviceID: GetDeviceID(), + IncludePreRelease: config.IncludePreRelease, + }); err != nil { + logger.Warn().Err(err).Msg("failed to auto update") + } + + ticker.Reset(1 * time.Hour) + } + } +} diff --git a/pkg/nmlite/interface.go b/pkg/nmlite/interface.go index a9111a645..13c6d9402 100644 --- a/pkg/nmlite/interface.go +++ b/pkg/nmlite/interface.go @@ -118,7 +118,7 @@ func (im *InterfaceManager) Start() error { im.wg.Add(1) go im.monitorInterfaceState() - nl := getNetlinkManager() + nl := getNetlinkManager(im.logger) // Set the link state linkState, err := nl.GetLinkByName(im.ifaceName) @@ -175,7 +175,7 @@ func (im *InterfaceManager) Stop() error { } func (im *InterfaceManager) link() (*link.Link, error) { - nl := getNetlinkManager() + nl := getNetlinkManager(im.logger) if nl == nil { return nil, fmt.Errorf("netlink manager not initialized") } @@ -572,7 +572,7 @@ func (im *InterfaceManager) applyIPv6SLAAC() error { return fmt.Errorf("failed to get interface: %w", err) } - netlinkMgr := getNetlinkManager() + netlinkMgr := getNetlinkManager(im.logger) // Ensure interface is up if err := netlinkMgr.EnsureInterfaceUp(l); err != nil { @@ -727,7 +727,7 @@ func (im *InterfaceManager) handleLinkDown() { } } - netlinkMgr := getNetlinkManager() + netlinkMgr := getNetlinkManager(im.logger) if err := netlinkMgr.RemoveAllAddresses(im.linkState, link.AfInet); err != nil { im.logger.Error().Err(err).Msg("failed to remove all IPv4 addresses") } @@ -795,7 +795,7 @@ func (im *InterfaceManager) updateStateFromDHCPLease(lease *types.DHCPLease) { // ReconcileLinkAddrs reconciles the link addresses func (im *InterfaceManager) ReconcileLinkAddrs(addrs []types.IPAddress, family int) error { - nl := getNetlinkManager() + nl := getNetlinkManager(im.logger) link, err := im.link() if err != nil { return fmt.Errorf("failed to get interface: %w", err) diff --git a/pkg/nmlite/interface_state.go b/pkg/nmlite/interface_state.go index efa5f087b..53648d894 100644 --- a/pkg/nmlite/interface_state.go +++ b/pkg/nmlite/interface_state.go @@ -110,7 +110,7 @@ func (im *InterfaceManager) updateInterfaceState() error { // updateIPAddresses updates the IP addresses in the state func (im *InterfaceManager) updateInterfaceStateAddresses(nl *link.Link) (bool, error) { - mgr := getNetlinkManager() + mgr := getNetlinkManager(im.logger) addrs, err := nl.AddrList(link.AfUnspec) if err != nil { diff --git a/pkg/nmlite/jetdhcpc/client.go b/pkg/nmlite/jetdhcpc/client.go index 102d3bee0..6583d119f 100644 --- a/pkg/nmlite/jetdhcpc/client.go +++ b/pkg/nmlite/jetdhcpc/client.go @@ -220,7 +220,7 @@ func (c *Client) requestLoop(t *time.Timer, family int, ifname string) { } func (c *Client) ensureInterfaceUp(ifname string) (*link.Link, error) { - nlm := link.GetNetlinkManager() + nlm := link.GetNetlinkManager(c.l) iface, err := nlm.GetLinkByName(ifname) if err != nil { return nil, err diff --git a/pkg/nmlite/link/manager.go b/pkg/nmlite/link/manager.go index c9b9410c5..05dbfdc23 100644 --- a/pkg/nmlite/link/manager.go +++ b/pkg/nmlite/link/manager.go @@ -42,9 +42,9 @@ func newNetlinkManager(logger *zerolog.Logger) *NetlinkManager { } // GetNetlinkManager returns the singleton NetlinkManager instance -func GetNetlinkManager() *NetlinkManager { +func GetNetlinkManager(logger *zerolog.Logger) *NetlinkManager { netlinkManagerOnce.Do(func() { - netlinkManagerInstance = newNetlinkManager(nil) + netlinkManagerInstance = newNetlinkManager(logger) }) return netlinkManagerInstance } diff --git a/pkg/nmlite/netlink.go b/pkg/nmlite/netlink.go index cca2fc09e..c3b33280e 100644 --- a/pkg/nmlite/netlink.go +++ b/pkg/nmlite/netlink.go @@ -1,7 +1,10 @@ package nmlite -import "github.com/jetkvm/kvm/pkg/nmlite/link" +import ( + "github.com/jetkvm/kvm/pkg/nmlite/link" + "github.com/rs/zerolog" +) -func getNetlinkManager() *link.NetlinkManager { - return link.GetNetlinkManager() +func getNetlinkManager(logger *zerolog.Logger) *link.NetlinkManager { + return link.GetNetlinkManager(logger) } diff --git a/pkg/nmlite/static.go b/pkg/nmlite/static.go index 9500556b6..0a99ed245 100644 --- a/pkg/nmlite/static.go +++ b/pkg/nmlite/static.go @@ -121,7 +121,7 @@ func (scm *StaticConfigManager) ToIPv6Static(config *types.IPv6StaticConfig) (*t func (scm *StaticConfigManager) DisableIPv4() error { scm.logger.Info().Msg("disabling IPv4") - netlinkMgr := getNetlinkManager() + netlinkMgr := getNetlinkManager(scm.logger) iface, err := netlinkMgr.GetLinkByName(scm.ifaceName) if err != nil { return fmt.Errorf("failed to get interface: %w", err) @@ -144,14 +144,14 @@ func (scm *StaticConfigManager) DisableIPv4() error { // DisableIPv6 disables IPv6 on the interface func (scm *StaticConfigManager) DisableIPv6() error { scm.logger.Info().Msg("disabling IPv6") - netlinkMgr := getNetlinkManager() + netlinkMgr := getNetlinkManager(scm.logger) return netlinkMgr.DisableIPv6(scm.ifaceName) } // EnableIPv6SLAAC enables IPv6 SLAAC func (scm *StaticConfigManager) EnableIPv6SLAAC() error { scm.logger.Info().Msg("enabling IPv6 SLAAC") - netlinkMgr := getNetlinkManager() + netlinkMgr := getNetlinkManager(scm.logger) return netlinkMgr.EnableIPv6SLAAC(scm.ifaceName) } @@ -159,7 +159,7 @@ func (scm *StaticConfigManager) EnableIPv6SLAAC() error { func (scm *StaticConfigManager) EnableIPv6LinkLocal() error { scm.logger.Info().Msg("enabling IPv6 link-local only") - netlinkMgr := getNetlinkManager() + netlinkMgr := getNetlinkManager(scm.logger) if err := netlinkMgr.EnableIPv6LinkLocal(scm.ifaceName); err != nil { return err } @@ -179,6 +179,6 @@ func (scm *StaticConfigManager) EnableIPv6LinkLocal() error { // removeIPv4DefaultRoute removes IPv4 default route func (scm *StaticConfigManager) removeIPv4DefaultRoute() error { - netlinkMgr := getNetlinkManager() + netlinkMgr := getNetlinkManager(scm.logger) return netlinkMgr.RemoveDefaultRoute(link.AfInet) } diff --git a/serial.go b/serial.go index 5439d135a..a14fe4e8b 100644 --- a/serial.go +++ b/serial.go @@ -271,6 +271,7 @@ func initSerialPort() { func reopenSerialPort() error { if port != nil { port.Close() + port = nil } var err error port, err = serial.Open(serialPortPath, defaultMode) diff --git a/terminal.go b/terminal.go index e06e5cdc1..af60f95a7 100644 --- a/terminal.go +++ b/terminal.go @@ -83,6 +83,7 @@ func handleTerminalChannel(d *webrtc.DataChannel) { d.OnClose(func() { if ptmx != nil { ptmx.Close() + ptmx = nil } if cmd != nil && cmd.Process != nil { _ = cmd.Process.Kill() diff --git a/timesync.go b/timesync.go index 956011b3c..7b10d5ef5 100644 --- a/timesync.go +++ b/timesync.go @@ -13,14 +13,16 @@ var ( ) func isTimeSyncNeeded() bool { + logger := timesyncLogger + if builtTimestamp == "" { - timesyncLogger.Warn().Msg("built timestamp is not set, time sync is needed") + logger.Warn().Msg("built timestamp is not set, time sync is needed") return true } ts, err := strconv.Atoi(builtTimestamp) if err != nil { - timesyncLogger.Warn().Str("error", err.Error()).Msg("failed to parse built timestamp") + logger.Warn().Str("error", err.Error()).Msg("failed to parse built timestamp") return true } @@ -29,7 +31,7 @@ func isTimeSyncNeeded() bool { now := time.Now() if now.Sub(builtTime) < 0 { - timesyncLogger.Warn(). + logger.Warn(). Str("built_time", builtTime.Format(time.RFC3339)). Str("now", now.Format(time.RFC3339)). Msg("system time is behind the built time, time sync is needed") diff --git a/usb.go b/usb.go index af57692f6..0ecb43f8a 100644 --- a/usb.go +++ b/usb.go @@ -11,7 +11,7 @@ var gadget *usbgadget.UsbGadget // initUsbGadget initializes the USB gadget. // call it only after the config is loaded. -func initUsbGadget() { +func initUsbGadget() *usbgadget.UsbGadget { gadget = usbgadget.NewUsbGadget( "jetkvm", config.UsbDevices, @@ -19,13 +19,6 @@ func initUsbGadget() { usbLogger, ) - go func() { - for { - checkUSBState() - time.Sleep(500 * time.Millisecond) - } - }() - gadget.SetOnKeyboardStateChange(func(state usbgadget.KeyboardState) { if currentSession != nil { currentSession.reportHidRPCKeyboardLedState(state) @@ -44,10 +37,21 @@ func initUsbGadget() { } }) - // open the keyboard hid file to listen for keyboard events - if err := gadget.OpenKeyboardHidFile(); err != nil { - usbLogger.Error().Err(err).Msg("failed to open keyboard hid file") - } + go func() { + for { + // is the USB configured? + if checkUSBState() { + // ensure we have opened the keyboard hid file to listen for keyboard events + if err := gadget.OpenKeyboardHidFile(); err != nil { + usbLogger.Error().Err(err).Msg("failed to open keyboard hid file") + // but keep trying... + } + } + time.Sleep(500 * time.Millisecond) + } + }() + + return gadget } func rpcKeyboardReport(modifier byte, keys []byte) error { @@ -90,28 +94,26 @@ func rpcGetUSBState() (state string) { func triggerUSBStateUpdate() { go func() { if currentSession == nil { - usbLogger.Info().Msg("No active RPC session, skipping USB state update") + usbLogger.Debug().Msg("No active RPC session, skipping USB state update") return } writeJSONRPCEvent("usbState", usbState, currentSession) }() } -func checkUSBState() { +func checkUSBState() bool { usbStateLock.Lock() defer usbStateLock.Unlock() newState := gadget.GetUsbState() - usbLogger.Trace().Str("old", usbState).Str("new", newState).Msg("Checking USB state") + if newState != usbState { + usbLogger.Debug().Str("from", usbState).Str("to", newState).Msg("USB state changed") + usbState = newState - if newState == usbState { - return + requestDisplayUpdate(true, "usb_state_changed") + triggerUSBStateUpdate() } - usbState = newState - usbLogger.Info().Str("from", usbState).Str("to", newState).Msg("USB state changed") - - requestDisplayUpdate(true, "usb_state_changed") - triggerUSBStateUpdate() + return newState == "configured" } diff --git a/usb_mass_storage.go b/usb_mass_storage.go index 0f1f4b934..bb149134c 100644 --- a/usb_mass_storage.go +++ b/usb_mass_storage.go @@ -41,13 +41,15 @@ func getMassStorageImage() (string, error) { func setMassStorageImage(imagePath string) error { massStorageFunctionPath, err := gadget.GetPath("mass_storage_lun0") + if err != nil { - return fmt.Errorf("failed to get mass storage path: %w", err) + return fmt.Errorf("failed to get mass storage path error: %w", err) } if err := writeFile(path.Join(massStorageFunctionPath, "file"), imagePath); err != nil { - return fmt.Errorf("failed to set image path: %w", err) + return fmt.Errorf("failed to set image path %s error: %w", imagePath, err) } + return nil } @@ -57,7 +59,7 @@ func setMassStorageMode(cdrom bool) error { mode = "1" } - err, changed := gadget.OverrideGadgetConfig("mass_storage_lun0", "cdrom", mode) + changed, err := gadget.OverrideGadgetConfig("mass_storage_lun0", "cdrom", mode) if err != nil { return fmt.Errorf("failed to set cdrom mode: %w", err) } @@ -70,18 +72,14 @@ func setMassStorageMode(cdrom bool) error { } func mountImage(imagePath string) error { - err := setMassStorageImage("") - if err != nil { + if err := setMassStorageImage(""); err != nil { return fmt.Errorf("remove mass storage image error: %w", err) } - err = setMassStorageImage(imagePath) - if err != nil { - return fmt.Errorf("set mass storage image error: %w", err) - } - err = setMassStorageImage(imagePath) - if err != nil { - return fmt.Errorf("set Mass Storage Image Error: %w", err) + + if err := setMassStorageImage(imagePath); err != nil { + return fmt.Errorf("set mass storage image path %s error: %w", imagePath, err) } + return nil } @@ -192,16 +190,22 @@ func rpcGetVirtualMediaState() (*VirtualMediaState, error) { func rpcUnmountImage() error { virtualMediaStateMutex.Lock() defer virtualMediaStateMutex.Unlock() - err := setMassStorageImage("\n") - if err != nil { - logger.Warn().Err(err).Msg("Remove Mass Storage Image Error") - } - //TODO: check if we still need it - time.Sleep(500 * time.Millisecond) + + logger.Info().Msg("Unmounting virtual media image") + if nbdDevice != nil { + logger.Trace().Msg("Stopping nbd device") nbdDevice.Close() nbdDevice = nil } + + time.Sleep(500 * time.Millisecond) + + err := setMassStorageImage("") + if err != nil { + logger.Warn().Err(err).Msg("Remove Mass Storage Image Error") + } + currentVirtualMediaState = nil return nil } @@ -263,18 +267,27 @@ func setInitialVirtualMediaState() error { } func rpcMountWithHTTP(url string, mode VirtualMediaMode) error { + if err := configureHttpStorageDevice(url, mode); err != nil { + return fmt.Errorf("failed to mount with http: %w", err) + } + return initializeNBDDevice() +} + +func configureHttpStorageDevice(url string, mode VirtualMediaMode) error { virtualMediaStateMutex.Lock() + defer virtualMediaStateMutex.Unlock() + if currentVirtualMediaState != nil { - virtualMediaStateMutex.Unlock() return fmt.Errorf("another virtual media is already mounted") } + httpRangeReader = httpreadat.New(url) n, err := httpRangeReader.Size() if err != nil { - virtualMediaStateMutex.Unlock() return fmt.Errorf("failed to use http url: %w", err) } - logger.Info().Str("url", url).Int64("size", n).Msg("using remote url") + + logger.Info().Str("url", url).Bool("cdrom", mode == CDROM).Int64("size", n).Msg("using remote url") if err := setMassStorageMode(mode == CDROM); err != nil { return fmt.Errorf("failed to set mass storage mode: %w", err) @@ -286,22 +299,28 @@ func rpcMountWithHTTP(url string, mode VirtualMediaMode) error { URL: url, Size: n, } - virtualMediaStateMutex.Unlock() + return nil +} + +func initializeNBDDevice() error { logger.Debug().Msg("Starting nbd device") + nbdDevice = NewNBDDevice() - err = nbdDevice.Start() - if err != nil { + if err := nbdDevice.Start(); err != nil { logger.Warn().Err(err).Msg("failed to start nbd device") - return err + return fmt.Errorf("failed to set ndb device: %w", err) } + logger.Debug().Msg("nbd device started") + //TODO: replace by polling on block device having right size time.Sleep(1 * time.Second) - err = setMassStorageImage("/dev/nbd0") - if err != nil { - return err + + if err := setMassStorageImage("/dev/nbd0"); err != nil { + return fmt.Errorf("failed to set mass storage image to /dev/nbd0: %w", err) } + logger.Info().Msg("usb mass storage mounted") return nil } @@ -309,11 +328,12 @@ func rpcMountWithHTTP(url string, mode VirtualMediaMode) error { func rpcMountWithStorage(filename string, mode VirtualMediaMode) error { filename, err := sanitizeFilename(filename) if err != nil { - return err + return fmt.Errorf("failed to sanitize filename %s: %w", filename, err) } virtualMediaStateMutex.Lock() defer virtualMediaStateMutex.Unlock() + if currentVirtualMediaState != nil { return fmt.Errorf("another virtual media is already mounted") } @@ -321,16 +341,15 @@ func rpcMountWithStorage(filename string, mode VirtualMediaMode) error { fullPath := filepath.Join(imagesFolder, filename) fileInfo, err := os.Stat(fullPath) if err != nil { - return fmt.Errorf("failed to get file info: %w", err) + return fmt.Errorf("failed to get file info for %s: %w", fullPath, err) } if err := setMassStorageMode(mode == CDROM); err != nil { - return fmt.Errorf("failed to set mass storage mode: %w", err) + return fmt.Errorf("failed to set mass storage mode %s: %w", mode, err) } - err = setMassStorageImage(fullPath) - if err != nil { - return fmt.Errorf("failed to set mass storage image: %w", err) + if err := setMassStorageImage(fullPath); err != nil { + return fmt.Errorf("failed to set mass storage image to %s: %w", fullPath, err) } currentVirtualMediaState = &VirtualMediaState{ Source: Storage, @@ -350,7 +369,7 @@ func rpcGetStorageSpace() (*StorageSpace, error) { var stat syscall.Statfs_t err := syscall.Statfs(imagesFolder, &stat) if err != nil { - return nil, fmt.Errorf("failed to get storage stats: %v", err) + return nil, fmt.Errorf("failed to get storage stats: %w", err) } totalSpace := stat.Blocks * uint64(stat.Bsize) @@ -376,7 +395,7 @@ type StorageFiles struct { func rpcListStorageFiles() (*StorageFiles, error) { files, err := os.ReadDir(imagesFolder) if err != nil { - return nil, fmt.Errorf("failed to read directory: %v", err) + return nil, fmt.Errorf("failed to read directory: %w", err) } storageFiles := make([]StorageFile, 0) @@ -387,7 +406,7 @@ func rpcListStorageFiles() (*StorageFiles, error) { info, err := file.Info() if err != nil { - return nil, fmt.Errorf("failed to get file info: %v", err) + return nil, fmt.Errorf("failed to get file info: %w", err) } storageFiles = append(storageFiles, StorageFile{ @@ -415,18 +434,18 @@ func sanitizeFilename(filename string) (string, error) { func rpcDeleteStorageFile(filename string) error { sanitizedFilename, err := sanitizeFilename(filename) if err != nil { - return err + return fmt.Errorf("failed to sanitize filename %s: %w", filename, err) } fullPath := filepath.Join(imagesFolder, sanitizedFilename) if _, err := os.Stat(fullPath); os.IsNotExist(err) { - return fmt.Errorf("file does not exist: %s", filename) + return fmt.Errorf("file %s does not exist: %w", fullPath, err) } err = os.Remove(fullPath) if err != nil { - return fmt.Errorf("failed to delete file: %v", err) + return fmt.Errorf("failed to delete file %s: %w", fullPath, err) } return nil @@ -442,7 +461,7 @@ const uploadIdPrefix = "upload_" func rpcStartStorageFileUpload(filename string, size int64) (*StorageFileUpload, error) { sanitizedFilename, err := sanitizeFilename(filename) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to sanitize filename %s: %w", filename, err) } filePath := path.Join(imagesFolder, sanitizedFilename) @@ -460,7 +479,7 @@ func rpcStartStorageFileUpload(filename string, size int64) (*StorageFileUpload, uploadId := uploadIdPrefix + uuid.New().String() file, err := os.OpenFile(uploadPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) if err != nil { - return nil, fmt.Errorf("failed to open file for upload: %v", err) + return nil, fmt.Errorf("failed to open file %s for upload: %w", uploadPath, err) } pendingUploadsMutex.Lock() pendingUploads[uploadId] = pendingUpload{ diff --git a/web.go b/web.go index 667dee1c3..0640300fa 100644 --- a/web.go +++ b/web.go @@ -22,6 +22,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/jetkvm/kvm/internal/logging" + "github.com/jetkvm/kvm/internal/utils" "github.com/pion/webrtc/v4" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" @@ -233,7 +234,7 @@ func handleWebRTCSession(c *gin.Context) { } // Cancel any ongoing keyboard macro when session changes - cancelKeyboardMacro() + _ = cancelKeyboardMacro() currentSession = session c.JSON(http.StatusOK, gin.H{"sd": sd}) @@ -247,21 +248,22 @@ var ( func handleLocalWebRTCSignal(c *gin.Context) { // get the source from the request source := c.ClientIP() + sourceType := "local" connectionID := uuid.New().String() - scopedLogger := websocketLogger.With(). + logger := websocketLogger.With(). Str("component", "websocket"). Str("source", source). - Str("sourceType", "local"). + Str("sourceType", sourceType). Logger() - scopedLogger.Info().Msg("new websocket connection established") + logger.Info().Msg("new websocket connection established") // Create WebSocket options with InsecureSkipVerify to bypass origin check wsOptions := &websocket.AcceptOptions{ InsecureSkipVerify: true, // Allow connections from any origin OnPingReceived: func(ctx context.Context, payload []byte) bool { - scopedLogger.Debug().Bytes("payload", payload).Msg("ping frame received") + logger.Debug().Object("payload", utils.ByteSlice(payload)).Msg("ping frame received") metricConnectionTotalPingReceivedCount.WithLabelValues("local", source).Inc() metricConnectionLastPingReceivedTimestamp.WithLabelValues("local", source).SetToCurrentTime() @@ -275,17 +277,16 @@ func handleLocalWebRTCSignal(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - - // Now use conn for websocket operations defer wsCon.Close(websocket.StatusNormalClosure, "") + // Now use conn for websocket operations err = wsjson.Write(context.Background(), wsCon, gin.H{"type": "device-metadata", "data": gin.H{"deviceVersion": builtAppVersion}}) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return } - err = handleWebRTCSignalWsMessages(wsCon, false, source, connectionID, &scopedLogger) + err = handleWebRTCSignalWsMessages(wsCon, false, source, connectionID, &logger) if err != nil { c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()}) return @@ -297,7 +298,7 @@ func handleWebRTCSignalWsMessages( isCloudConnection bool, source string, connectionID string, - scopedLogger *zerolog.Logger, + l *zerolog.Logger, ) error { runCtx, cancelRun := context.WithCancel(context.Background()) defer func() { @@ -315,13 +316,13 @@ func handleWebRTCSignalWsMessages( sourceType = "local" } - l := scopedLogger.With(). + logger := l.With(). Str("source", source). Str("sourceType", sourceType). Str("connectionID", connectionID). Logger() - l.Info().Msg("new websocket connection established") + logger.Info().Msg("new websocket connection established") go func() { for { @@ -329,9 +330,9 @@ func handleWebRTCSignalWsMessages( if ctxErr := runCtx.Err(); ctxErr != nil { if !errors.Is(ctxErr, context.Canceled) { - l.Warn().Str("error", ctxErr.Error()).Msg("websocket connection closed") + logger.Warn().Err(ctxErr).Msg("websocket connection closed") } else { - l.Trace().Str("error", ctxErr.Error()).Msg("websocket connection closed as the context was canceled") + logger.Trace().Err(ctxErr).Msg("websocket connection closed as the context was canceled") } return } @@ -342,10 +343,10 @@ func handleWebRTCSignalWsMessages( metricConnectionPingDuration.WithLabelValues(sourceType, source).Observe(v) })) - l.Trace().Msg("sending ping frame") + logger.Trace().Msg("sending ping frame") err := wsCon.Ping(runCtx) if err != nil { - l.Warn().Str("error", err.Error()).Msg("websocket ping error") + logger.Warn().Err(err).Msg("websocket ping error") cancelRun() return } @@ -356,7 +357,7 @@ func handleWebRTCSignalWsMessages( metricConnectionTotalPingSentCount.WithLabelValues(sourceType, source).Inc() metricConnectionLastPingTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime() - l.Trace().Str("duration", duration.String()).Msg("received pong frame") + logger.Trace().Dur("duration", duration).Msg("received pong frame") } }() @@ -381,7 +382,7 @@ func handleWebRTCSignalWsMessages( for { typ, msg, err := wsCon.Read(runCtx) if err != nil { - l.Warn().Str("error", err.Error()).Msg("websocket read error") + logger.Warn().Err(err).Msg("websocket read error") return err } if typ != websocket.MessageText { @@ -395,70 +396,69 @@ func handleWebRTCSignalWsMessages( } if bytes.Equal(msg, pingMessage) { - l.Info().Str("message", string(msg)).Msg("ping message received") + logger.Info().Str("message", string(msg)).Msg("ping message received") err = wsCon.Write(context.Background(), websocket.MessageText, pongMessage) if err != nil { - l.Warn().Str("error", err.Error()).Msg("unable to write pong message") + logger.Warn().Err(err).Msg("unable to write pong message") return err } metricConnectionTotalPingReceivedCount.WithLabelValues(sourceType, source).Inc() metricConnectionLastPingReceivedTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime() - continue } err = json.Unmarshal(msg, &message) if err != nil { - l.Warn().Str("error", err.Error()).Msg("unable to parse ws message") + logger.Warn().Err(err).Msg("unable to parse ws message") continue } if message.Type == "offer" { - l.Info().Msg("new session request received") + logger.Info().Msg("new session request received") var req WebRTCSessionRequest err = json.Unmarshal(message.Data, &req) if err != nil { - l.Warn().Str("error", err.Error()).Msg("unable to parse session request data") + logger.Warn().Err(err).Msg("unable to parse session request data") continue } if req.OidcGoogle != "" { - l.Info().Str("oidcGoogle", req.OidcGoogle).Msg("new session request with OIDC Google") + logger.Info().Str("oidcGoogle", req.OidcGoogle).Msg("new session request with OIDC Google") } metricConnectionSessionRequestCount.WithLabelValues(sourceType, source).Inc() metricConnectionLastSessionRequestTimestamp.WithLabelValues(sourceType, source).SetToCurrentTime() - err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection, source, &l) + err = handleSessionRequest(runCtx, wsCon, req, isCloudConnection, source, &logger) if err != nil { - l.Warn().Str("error", err.Error()).Msg("error starting new session") + logger.Warn().Err(err).Msg("error starting new session") continue } } else if message.Type == "new-ice-candidate" { - l.Info().Str("data", string(message.Data)).Msg("The client sent us a new ICE candidate") + logger.Info().Str("data", string(message.Data)).Msg("The client sent us a new ICE candidate") var candidate webrtc.ICECandidateInit // Attempt to unmarshal as a ICECandidateInit if err := json.Unmarshal(message.Data, &candidate); err != nil { - l.Warn().Str("error", err.Error()).Msg("unable to parse incoming ICE candidate data") + logger.Warn().Err(err).Msg("unable to parse incoming ICE candidate data") continue } if candidate.Candidate == "" { - l.Warn().Msg("empty incoming ICE candidate, skipping") + logger.Warn().Msg("empty incoming ICE candidate, skipping") continue } - l.Info().Str("data", fmt.Sprintf("%v", candidate)).Msg("unmarshalled incoming ICE candidate") + logger.Info().Str("data", fmt.Sprintf("%v", candidate)).Msg("unmarshalled incoming ICE candidate") if currentSession == nil { - l.Warn().Msg("no current session, skipping incoming ICE candidate") + logger.Warn().Msg("no current session, skipping incoming ICE candidate") continue } - l.Info().Str("data", fmt.Sprintf("%v", candidate)).Msg("adding incoming ICE candidate to current session") + logger.Info().Str("data", fmt.Sprintf("%v", candidate)).Msg("adding incoming ICE candidate to current session") if err = currentSession.peerConnection.AddICECandidate(candidate); err != nil { - l.Warn().Str("error", err.Error()).Msg("failed to add incoming ICE candidate to our peer connection") + logger.Warn().Err(err).Msg("failed to add incoming ICE candidate to our peer connection") } } } @@ -822,7 +822,7 @@ func handleSendWOLMagicPacket(c *gin.Context) { macAddrString := macAddr.String() err = rpcSendWOLMagicPacket(macAddrString) if err != nil { - logger.Warn().Err(err).Str("macAddrString", macAddrString).Msg("Failed to send WOL magic packet") + logger.Warn().Err(err).MACAddr("macAddr", macAddr).Msg("Failed to send WOL magic packet") c.String(http.StatusInternalServerError, "Failed to send WOL to %s: %v", macAddrString, err) return } diff --git a/web_tls.go b/web_tls.go index 41f532ea9..6d019ea8a 100644 --- a/web_tls.go +++ b/web_tls.go @@ -10,6 +10,7 @@ import ( "sync" "github.com/jetkvm/kvm/internal/websecure" + "github.com/rs/zerolog" ) const ( @@ -33,17 +34,17 @@ type TLSState struct { PrivateKey string `json:"privateKey"` } -func initCertStore() { +func initCertStore(logger *zerolog.Logger) { if certStore != nil { websecureLogger.Warn().Msg("TLS store already initialized, it should not be initialized again") return } - certStore = websecure.NewCertStore(tlsStorePath, websecureLogger) + certStore = websecure.NewCertStore(tlsStorePath, logger) certStore.LoadCertificates() certSigner = websecure.NewSelfSigner( certStore, - websecureLogger, + logger, webSecureSelfSignedDefaultDomain, webSecureSelfSignedOrganization, webSecureSelfSignedOU, @@ -109,7 +110,7 @@ func setTLSState(s TLSState) error { } // parse pem to cert and key if certStore == nil { - initCertStore() + initCertStore(logger) } err, _ := certStore.ValidateAndSaveCertificate(webSecureCustomCertificateName, s.Certificate, s.PrivateKey, true) // warn doesn't matter as ... we don't know the hostname yet @@ -170,7 +171,6 @@ func runWebSecureServer() { GetCertificate: getCertificate, }, } - websecureLogger.Info().Str("listen", webSecureListen).Msg("Starting websecure server") go func() { for range stopTLS { @@ -182,6 +182,7 @@ func runWebSecureServer() { } }() + websecureLogger.Info().Str("listen", webSecureListen).Msg("Starting websecure server") err := server.ListenAndServeTLS("", "") if !errors.Is(err, http.ErrServerClosed) { panic(err) @@ -207,8 +208,9 @@ func startWebSecureServer() { func RunWebSecureServer() { for range startTLS { websecureLogger.Info().Msg("Starting websecure server, as we have received a start signal") + if certStore == nil { - initCertStore() + initCertStore(websecureLogger) } go runWebSecureServer() } diff --git a/webrtc.go b/webrtc.go index 10c43ddf1..907c037c0 100644 --- a/webrtc.go +++ b/webrtc.go @@ -7,16 +7,19 @@ import ( "net" "strings" "sync" + "sync/atomic" "time" - "github.com/coder/websocket" - "github.com/coder/websocket/wsjson" - "github.com/gin-gonic/gin" "github.com/jetkvm/kvm/internal/hidrpc" "github.com/jetkvm/kvm/internal/logging" "github.com/jetkvm/kvm/internal/usbgadget" - "github.com/pion/webrtc/v4" + "github.com/jetkvm/kvm/internal/utils" "github.com/rs/zerolog" + + "github.com/coder/websocket" + "github.com/coder/websocket/wsjson" + "github.com/gin-gonic/gin" + "github.com/pion/webrtc/v4" ) type Session struct { @@ -39,39 +42,25 @@ type Session struct { keysDownStateQueue chan usbgadget.KeysDownState } -var ( - actionSessions int = 0 - activeSessionsMutex = &sync.Mutex{} -) - -func incrActiveSessions() int { - activeSessionsMutex.Lock() - defer activeSessionsMutex.Unlock() +var activeSessions atomic.Int32 - actionSessions++ - return actionSessions +func incrActiveSessions() int32 { + return activeSessions.Add(1) } -func decrActiveSessions() int { - activeSessionsMutex.Lock() - defer activeSessionsMutex.Unlock() - - actionSessions-- - return actionSessions +func decrActiveSessions() int32 { + return activeSessions.Add(-1) } -func getActiveSessions() int { - activeSessionsMutex.Lock() - defer activeSessionsMutex.Unlock() - - return actionSessions +func getActiveSessions() int32 { + return activeSessions.Load() } func (s *Session) resetKeepAliveTime() { s.keepAliveJitterLock.Lock() - defer s.keepAliveJitterLock.Unlock() s.lastKeepAliveArrivalTime = time.Time{} // Reset keep-alive timing tracking s.lastTimerResetTime = time.Time{} // Reset auto-release timer tracking + s.keepAliveJitterLock.Unlock() } type hidQueueMessage struct { @@ -121,20 +110,98 @@ func (s *Session) ExchangeOffer(offerStr string) (string, error) { return base64.StdEncoding.EncodeToString(localDescription), nil } +func (s *Session) startupSession() { + s.rpcQueue = make(chan webrtc.DataChannelMessage, 256) + s.initQueues() + s.initKeysDownStateQueue() + + go func() { + for msg := range s.rpcQueue { + // TODO: only use goroutine if the task is asynchronous + go onRPCMessage(msg, s) + } + }() + + for i := 0; i < len(s.hidQueue); i++ { + go s.handleQueue(i) + } +} + func (s *Session) initQueues() { s.hidQueueLock.Lock() defer s.hidQueueLock.Unlock() - s.hidQueue = make([]chan hidQueueMessage, 0) - for i := 0; i < 4; i++ { + s.hidQueue = make([]chan hidQueueMessage, hidrpc.MaximumQueues) + for i := 0; i < hidrpc.MaximumQueues; i++ { q := make(chan hidQueueMessage, 256) - s.hidQueue = append(s.hidQueue, q) + s.hidQueue[i] = q } } -func (s *Session) handleQueues(index int) { +func (s *Session) shutdownSession() { + // Stop RPC processor + if s.rpcQueue != nil { + close(s.rpcQueue) + s.rpcQueue = nil + } + + // Stop HID RPC processors + if s.hidQueue != nil { + for i := 0; i < len(s.hidQueue); i++ { + if s.hidQueue[i] != nil { + close(s.hidQueue[i]) + s.hidQueue[i] = nil + } + } + s.hidQueue = nil + } + + if s.keysDownStateQueue != nil { + close(s.keysDownStateQueue) + s.keysDownStateQueue = nil + } + + if s.shouldUmountVirtualMedia { + go func() { + if err := rpcUnmountImage(); err != nil { + logger.Warn().Err(err).Msg("unmount image failed on connection close") + } + }() + } + + if s.ControlChannel != nil { + go s.ControlChannel.GracefulClose() + s.ControlChannel = nil + } + + if s.RPCChannel != nil { + go s.RPCChannel.GracefulClose() + s.RPCChannel = nil + } + + if s.HidChannel != nil { + go s.HidChannel.GracefulClose() + s.HidChannel = nil + } + + s.hidRPCAvailable = false + + // TODO what about the other channels? + + if s.VideoTrack != nil { + // there's no Close() on this, just set to nil + s.VideoTrack = nil + } + + if s.peerConnection != nil { + go s.peerConnection.GracefulClose() + s.peerConnection = nil + } +} + +func (s *Session) handleQueue(index int) { for msg := range s.hidQueue[index] { - onHidMessage(msg, s) + onHidMessage(msg, s, index) } } @@ -160,37 +227,38 @@ func (s *Session) enqueueKeysDownState(state usbgadget.KeysDownState) { select { case s.keysDownStateQueue <- state: default: - hidRPCLogger.Warn().Msg("dropping keys down state update; queue full") + hidRPCLogger.Error().Msg("dropping keys down state update; queue full") } } -func getOnHidMessageHandler(session *Session, scopedLogger *zerolog.Logger, channel string) func(msg webrtc.DataChannelMessage) { +func getOnHidMessageHandler(session *Session, l *zerolog.Logger, channel string) func(msg webrtc.DataChannelMessage) { return func(msg webrtc.DataChannelMessage) { - l := scopedLogger.With(). - Str("channel", channel). - Int("length", len(msg.Data)). - Logger() - // only log data if the log level is debug or lower - if scopedLogger.GetLevel() > zerolog.DebugLevel { - l = l.With().Str("data", string(msg.Data)).Logger() - } + logger := l.With().Str("channel", channel).Interface("msg", msg).Logger() if msg.IsString { - l.Warn().Msg("received string data in HID RPC message handler") + logger.Warn().Msg("received string data in HID RPC message handler") return } - if len(msg.Data) < 1 { - l.Warn().Msg("received empty data in HID RPC message handler") - return + dataLength := len(msg.Data) + logger = logger.With().Int("length", dataLength).Logger() + + // only log data if the log level is debug or lower + if logger.GetLevel() <= zerolog.DebugLevel { + logger = logger.With().Object("data", utils.ByteSlice(msg.Data)).Logger() } - l.Trace().Msg("received data in HID RPC message handler") + if dataLength < 1 { + logger.Warn().Msg("received empty data in HID RPC message handler") + return + } // Enqueue to ensure ordered processing queueIndex := hidrpc.GetQueueIndex(hidrpc.MessageType(msg.Data[0])) + logger = logger.With().Int("queueIndex", queueIndex).Logger() + if queueIndex >= len(session.hidQueue) || queueIndex < 0 { - l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue index not found") + logger.Warn().Msg("received data in HID RPC message handler, but queue index not found") queueIndex = 3 } @@ -200,8 +268,9 @@ func getOnHidMessageHandler(session *Session, scopedLogger *zerolog.Logger, chan DataChannelMessage: msg, channel: channel, } + logger.Trace().Msg("queued HID RPC message") } else { - l.Warn().Int("queueIndex", queueIndex).Msg("received data in HID RPC message handler, but queue is nil") + logger.Warn().Msg("received data in HID RPC message handler, but queue is nil") return } } @@ -213,27 +282,25 @@ func newSession(config SessionConfig) (*Session, error) { } iceServer := webrtc.ICEServer{} - var scopedLogger *zerolog.Logger + var logger = webrtcLogger if config.Logger != nil { l := config.Logger.With().Str("component", "webrtc").Logger() - scopedLogger = &l - } else { - scopedLogger = webrtcLogger + logger = &l } if config.IsCloud { if config.ICEServers == nil { - scopedLogger.Info().Msg("ICE Servers not provided by cloud") + logger.Info().Msg("ICE Servers not provided by cloud") } else { iceServer.URLs = config.ICEServers - scopedLogger.Info().Interface("iceServers", iceServer.URLs).Msg("Using ICE Servers provided by cloud") + logger.Info().Strs("iceServers", iceServer.URLs).Msg("Using ICE Servers provided by cloud") } if config.LocalIP == "" || net.ParseIP(config.LocalIP) == nil { - scopedLogger.Info().Str("localIP", config.LocalIP).Msg("Local IP address not provided or invalid, won't set NAT1To1IPs") + logger.Info().Str("localIP", config.LocalIP).Msg("Local IP address not provided or invalid, won't set NAT1To1IPs") } else { webrtcSettingEngine.SetNAT1To1IPs([]string{config.LocalIP}, webrtc.ICECandidateTypeSrflx) - scopedLogger.Info().Str("localIP", config.LocalIP).Msg("Setting NAT1To1IPs") + logger.Info().Str("localIP", config.LocalIP).Msg("Setting NAT1To1IPs") } } @@ -242,44 +309,32 @@ func newSession(config SessionConfig) (*Session, error) { ICEServers: []webrtc.ICEServer{iceServer}, }) if err != nil { - scopedLogger.Warn().Err(err).Msg("Failed to create PeerConnection") + logger.Warn().Err(err).Msg("Failed to create PeerConnection") return nil, err } session := &Session{peerConnection: peerConnection} - session.rpcQueue = make(chan webrtc.DataChannelMessage, 256) - session.initQueues() - session.initKeysDownStateQueue() - - go func() { - for msg := range session.rpcQueue { - // TODO: only use goroutine if the task is asynchronous - go onRPCMessage(msg, session) - } - }() - - for i := 0; i < len(session.hidQueue); i++ { - go session.handleQueues(i) - } + session.startupSession() peerConnection.OnDataChannel(func(d *webrtc.DataChannel) { defer func() { if r := recover(); r != nil { - scopedLogger.Error().Interface("error", r).Msg("Recovered from panic in DataChannel handler") + logger.Error().Interface("recovered", r).Msg("Recovered from panic in DataChannel handler") } }() - scopedLogger.Info().Str("label", d.Label()).Uint16("id", *d.ID()).Msg("New DataChannel") + logger := logger.With().Str("label", d.Label()).Uint16("id", *d.ID()).Logger() + logger.Info().Msg("New DataChannel") switch d.Label() { case "hidrpc": session.HidChannel = d - d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc")) + d.OnMessage(getOnHidMessageHandler(session, &logger, "hidrpc")) // we won't send anything over the unreliable channels case "hidrpc-unreliable-ordered": - d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc-unreliable-ordered")) + d.OnMessage(getOnHidMessageHandler(session, &logger, "hidrpc-unreliable-ordered")) case "hidrpc-unreliable-nonordered": - d.OnMessage(getOnHidMessageHandler(session, scopedLogger, "hidrpc-unreliable-nonordered")) + d.OnMessage(getOnHidMessageHandler(session, &logger, "hidrpc-unreliable-nonordered")) case "rpc": session.RPCChannel = d d.OnMessage(func(msg webrtc.DataChannelMessage) { @@ -306,13 +361,13 @@ func newSession(config SessionConfig) (*Session, error) { session.VideoTrack, err = webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{MimeType: webrtc.MimeTypeH264}, "video", "kvm") if err != nil { - scopedLogger.Warn().Err(err).Msg("Failed to create VideoTrack") + logger.Warn().Err(err).Msg("Failed to create VideoTrack") return nil, err } rtpSender, err := peerConnection.AddTrack(session.VideoTrack) if err != nil { - scopedLogger.Warn().Err(err).Msg("Failed to add VideoTrack to PeerConnection") + logger.Warn().Err(err).Msg("Failed to add VideoTrack to PeerConnection") return nil, err } @@ -330,63 +385,49 @@ func newSession(config SessionConfig) (*Session, error) { var isConnected bool peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) { - scopedLogger.Info().Interface("candidate", candidate).Msg("WebRTC peerConnection has a new ICE candidate") + logger := logger.With().Interface("candidate", candidate).Logger() + logger.Info().Msg("WebRTC peerConnection has a new ICE candidate") if candidate != nil { err := wsjson.Write(context.Background(), config.ws, gin.H{"type": "new-ice-candidate", "data": candidate.ToJSON()}) if err != nil { - scopedLogger.Warn().Err(err).Msg("failed to write new-ice-candidate to WebRTC signaling channel") + logger.Warn().Err(err).Msg("failed to write new-ice-candidate to WebRTC signaling channel") } } }) peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { - scopedLogger.Info().Str("connectionState", connectionState.String()).Msg("ICE Connection State has changed") + logger := logger.With().Stringer("connectionState", connectionState).Logger() + logger.Info().Msg("ICE Connection State has changed") if connectionState == webrtc.ICEConnectionStateConnected { if !isConnected { isConnected = true onActiveSessionsChanged() if incrActiveSessions() == 1 { + logger.Info().Msg("first session connected, starting video stream") onFirstSessionConnected() } } } //state changes on closing browser tab disconnected->failed, we need to manually close it if connectionState == webrtc.ICEConnectionStateFailed { - scopedLogger.Debug().Msg("ICE Connection State is failed, closing peerConnection") + logger.Debug().Msg("ICE Connection State is failed, closing peerConnection") _ = peerConnection.Close() } if connectionState == webrtc.ICEConnectionStateClosed { - scopedLogger.Debug().Msg("ICE Connection State is closed, unmounting virtual media") + logger.Debug().Msg("ICE Connection State is closed, shutting down session") if session == currentSession { - // Cancel any ongoing keyboard report multi when session closes - cancelKeyboardMacro() + // Cancel any ongoing keyboard macro when session closes + _ = cancelKeyboardMacro() currentSession = nil } - // Stop RPC processor - if session.rpcQueue != nil { - close(session.rpcQueue) - session.rpcQueue = nil - } - // Stop HID RPC processor - for i := 0; i < len(session.hidQueue); i++ { - close(session.hidQueue[i]) - session.hidQueue[i] = nil - } - - close(session.keysDownStateQueue) - session.keysDownStateQueue = nil + session.shutdownSession() - if session.shouldUmountVirtualMedia { - if err := rpcUnmountImage(); err != nil { - scopedLogger.Warn().Err(err).Msg("unmount image failed on connection close") - } - } if isConnected { isConnected = false onActiveSessionsChanged() if decrActiveSessions() == 0 { - scopedLogger.Info().Msg("last session disconnected, stopping video stream") + logger.Info().Msg("last session disconnected, stopping video stream") onLastSessionDisconnected() } } diff --git a/wol.go b/wol.go index c3d0de2d0..1343be78c 100644 --- a/wol.go +++ b/wol.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "net" + "github.com/jetkvm/kvm/internal/logging" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ) @@ -30,7 +31,7 @@ func rpcSendWOLMagicPacket(macAddress string) error { mac, err := net.ParseMAC(macAddress) if err != nil { wolErrors.Inc() - return ErrorfL(wolLogger, "invalid MAC address", err) + return logging.ErrorfL(wolLogger, "invalid MAC address", err) } // Create the magic packet @@ -40,7 +41,7 @@ func rpcSendWOLMagicPacket(macAddress string) error { conn, err := net.Dial("udp", "255.255.255.255:9") if err != nil { wolErrors.Inc() - return ErrorfL(wolLogger, "failed to establish UDP connection", err) + return logging.ErrorfL(wolLogger, "failed to establish UDP connection", err) } defer conn.Close() @@ -48,7 +49,7 @@ func rpcSendWOLMagicPacket(macAddress string) error { _, err = conn.Write(packet) if err != nil { wolErrors.Inc() - return ErrorfL(wolLogger, "failed to send WOL packet", err) + return logging.ErrorfL(wolLogger, "failed to send WOL packet", err) } wolLogger.Info().Str("mac", macAddress).Msg("WOL packet sent")