Skip to content
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 162 additions & 0 deletions hidrpc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
package kvm

import (
"fmt"
"time"

"github.com/jetkvm/kvm/internal/hidrpc"
"github.com/jetkvm/kvm/internal/usbgadget"
)

func handleHidRPCMessage(message hidrpc.Message, session *Session) {
var rpcErr error

switch message.Type() {
case hidrpc.TypeHandshake:
message, err := hidrpc.NewHandshakeMessage().Marshal()
if err != nil {
logger.Warn().Err(err).Msg("failed to marshal handshake message")
return
}
if err := session.HidChannel.Send(message); err != nil {
logger.Warn().Err(err).Msg("failed to send handshake message")
return
}
session.hidRPCAvailable = true
case hidrpc.TypeKeypressReport, hidrpc.TypeKeyboardReport:
keysDownState, err := handleHidRPCKeyboardInput(message)
if keysDownState != nil {
session.reportHidRPCKeysDownState(*keysDownState)
}
rpcErr = err
case hidrpc.TypePointerReport:
pointerReport, err := message.PointerReport()
if err != nil {
logger.Warn().Err(err).Msg("failed to get pointer report")
return
}
rpcErr = rpcAbsMouseReport(pointerReport.X, pointerReport.Y, pointerReport.Button)
case hidrpc.TypeMouseReport:
mouseReport, err := message.MouseReport()
if err != nil {
logger.Warn().Err(err).Msg("failed to get mouse report")
return
}
rpcErr = rpcRelMouseReport(mouseReport.DX, mouseReport.DY, mouseReport.Button)
default:
logger.Warn().Uint8("type", uint8(message.Type())).Msg("unknown HID RPC message type")
}

if rpcErr != nil {
logger.Warn().Err(rpcErr).Msg("failed to handle HID RPC message")
}
}

func onHidMessage(data []byte, session *Session) {
scopedLogger := hidRPCLogger.With().Bytes("data", data).Logger()
scopedLogger.Debug().Msg("HID RPC message received")

if len(data) < 1 {
scopedLogger.Warn().Int("length", len(data)).Msg("received empty data in HID RPC message handler")
return
}

var message hidrpc.Message

if err := hidrpc.Unmarshal(data, &message); err != nil {
scopedLogger.Warn().Err(err).Msg("failed to unmarshal HID RPC message")
return
}

scopedLogger = scopedLogger.With().Str("descr", message.String()).Logger()

t := time.Now()

r := make(chan interface{})
go func() {
handleHidRPCMessage(message, session)
r <- nil
}()
select {
case <-time.After(1 * time.Second):
scopedLogger.Warn().Msg("HID RPC message timed out")
case <-r:
scopedLogger.Debug().Dur("duration", time.Since(t)).Msg("HID RPC message handled")
}
}

func handleHidRPCKeyboardInput(message hidrpc.Message) (*usbgadget.KeysDownState, error) {
switch message.Type() {
case hidrpc.TypeKeypressReport:
keypressReport, err := message.KeypressReport()
if err != nil {
logger.Warn().Err(err).Msg("failed to get keypress report")
return nil, err
}
keysDownState, rpcError := rpcKeypressReport(keypressReport.Key, keypressReport.Press)
return &keysDownState, rpcError
case hidrpc.TypeKeyboardReport:
keyboardReport, err := message.KeyboardReport()
if err != nil {
logger.Warn().Err(err).Msg("failed to get keyboard report")
return nil, err
}
keysDownState, rpcError := rpcKeyboardReport(keyboardReport.Modifier, keyboardReport.Keys)
return &keysDownState, rpcError
}

return nil, fmt.Errorf("unknown HID RPC message type: %d", message.Type())
}

func reportHidRPC(params any, session *Session) {
if session == nil {
logger.Warn().Msg("session is nil, skipping reportHidRPC")
return
}

if !session.hidRPCAvailable || session.HidChannel == nil {
logger.Warn().Msg("HID RPC is not available, skipping reportHidRPC")
return
}

var (
message []byte
err error
)
switch params := params.(type) {
case usbgadget.KeyboardState:
message, err = hidrpc.NewKeyboardLedMessage(params).Marshal()
case usbgadget.KeysDownState:
message, err = hidrpc.NewKeydownStateMessage(params).Marshal()
default:
err = fmt.Errorf("unknown HID RPC message type: %T", params)
}

if err != nil {
logger.Warn().Err(err).Msg("failed to marshal HID RPC message")
return
}

if message == nil {
logger.Warn().Msg("failed to marshal HID RPC message")
return
}

if err := session.HidChannel.Send(message); err != nil {
logger.Warn().Err(err).Msg("failed to send HID RPC message")
}
}

func (s *Session) reportHidRPCKeyboardLedState(state usbgadget.KeyboardState) {
if !s.hidRPCAvailable {
writeJSONRPCEvent("keyboardLedState", state, s)
}
reportHidRPC(state, s)
}

func (s *Session) reportHidRPCKeysDownState(state usbgadget.KeysDownState) {
if !s.hidRPCAvailable {
writeJSONRPCEvent("keysDownState", state, s)
}
reportHidRPC(state, s)
}
100 changes: 100 additions & 0 deletions internal/hidrpc/hidrpc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package hidrpc

import (
"fmt"

"github.com/jetkvm/kvm/internal/usbgadget"
)

// MessageType is the type of the HID RPC message
type MessageType byte

const (
TypeHandshake MessageType = 0x01
TypeKeyboardReport MessageType = 0x02
TypePointerReport MessageType = 0x03
TypeWheelReport MessageType = 0x04
TypeKeypressReport MessageType = 0x05
TypeMouseReport MessageType = 0x06
TypeKeyboardLedState MessageType = 0x32
TypeKeydownState MessageType = 0x33
)

const (
Version byte = 0x01 // Version of the HID RPC protocol
)

// GetQueueIndex returns the index of the queue to which the message should be enqueued.
func GetQueueIndex(messageType MessageType) int {
switch messageType {
case TypeHandshake:
return 0
case TypeKeyboardReport, TypeKeypressReport, TypeKeyboardLedState, TypeKeydownState:
return 1
case TypePointerReport, TypeMouseReport, TypeWheelReport:
return 2
default:
return 3
}
}

// Unmarshal unmarshals the HID RPC message from the data.
func Unmarshal(data []byte, message *Message) error {
l := len(data)
if l < 1 {
return fmt.Errorf("invalid data length: %d", l)
}

message.t = MessageType(data[0])
message.d = data[1:]
return nil
}

// Marshal marshals the HID RPC message to the data.
func Marshal(message *Message) ([]byte, error) {
if message.t == 0 {
return nil, fmt.Errorf("invalid message type: %d", message.t)
}

data := make([]byte, len(message.d)+1)
data[0] = byte(message.t)
copy(data[1:], message.d)

return data, nil
}

// NewHandshakeMessage creates a new handshake message.
func NewHandshakeMessage() *Message {
return &Message{
t: TypeHandshake,
d: []byte{Version},
}
}

// NewKeyboardReportMessage creates a new keyboard report message.
func NewKeyboardReportMessage(keys []byte, modifier uint8) *Message {
return &Message{
t: TypeKeyboardReport,
d: append([]byte{modifier}, keys...),
}
}

// NewKeyboardLedMessage creates a new keyboard LED message.
func NewKeyboardLedMessage(state usbgadget.KeyboardState) *Message {
return &Message{
t: TypeKeyboardLedState,
d: []byte{state.Byte()},
}
}

// NewKeydownStateMessage creates a new keydown state message.
func NewKeydownStateMessage(state usbgadget.KeysDownState) *Message {
data := make([]byte, len(state.Keys)+1)
data[0] = state.Modifier
copy(data[1:], state.Keys)

return &Message{
t: TypeKeydownState,
d: data,
}
}
133 changes: 133 additions & 0 deletions internal/hidrpc/message.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
package hidrpc

import (
"fmt"
)

// Message ..
type Message struct {
t MessageType
d []byte
}

// Marshal marshals the message to a byte array.
func (m *Message) Marshal() ([]byte, error) {
return Marshal(m)
}

func (m *Message) Type() MessageType {
return m.t
}

func (m *Message) String() string {
switch m.t {
case TypeHandshake:
return "Handshake"
case TypeKeypressReport:
if len(m.d) < 2 {
return fmt.Sprintf("KeypressReport{Malformed: %v}", m.d)
}
return fmt.Sprintf("KeypressReport{Key: %d, Press: %v}", m.d[0], m.d[1] == uint8(1))
case TypeKeyboardReport:
if len(m.d) < 2 {
return fmt.Sprintf("KeyboardReport{Malformed: %v}", m.d)
}
return fmt.Sprintf("KeyboardReport{Modifier: %d, Keys: %v}", m.d[0], m.d[1:])
case TypePointerReport:
if len(m.d) < 9 {
return fmt.Sprintf("PointerReport{Malformed: %v}", m.d)
}
return fmt.Sprintf("PointerReport{X: %d, Y: %d, Button: %d}", m.d[0:4], m.d[4:8], m.d[8])
case TypeMouseReport:
if len(m.d) < 3 {
return fmt.Sprintf("MouseReport{Malformed: %v}", m.d)
}
return fmt.Sprintf("MouseReport{DX: %d, DY: %d, Button: %d}", m.d[0], m.d[1], m.d[2])
default:
return fmt.Sprintf("Unknown{Type: %d, Data: %v}", m.t, m.d)
}
}

// KeypressReport ..
type KeypressReport struct {
Key byte
Press bool
}

// KeypressReport returns the keypress report from the message.
func (m *Message) KeypressReport() (KeypressReport, error) {
if m.t != TypeKeypressReport {
return KeypressReport{}, fmt.Errorf("invalid message type: %d", m.t)
}

return KeypressReport{
Key: m.d[0],
Press: m.d[1] == uint8(1),
}, nil
}

// KeyboardReport ..
type KeyboardReport struct {
Modifier byte
Keys []byte
}

// KeyboardReport returns the keyboard report from the message.
func (m *Message) KeyboardReport() (KeyboardReport, error) {
if m.t != TypeKeyboardReport {
return KeyboardReport{}, fmt.Errorf("invalid message type: %d", m.t)
}

return KeyboardReport{
Modifier: m.d[0],
Keys: m.d[1:],
}, nil
}

// PointerReport ..
type PointerReport struct {
X int
Y int
Button uint8
}

func toInt(b []byte) int {
return int(b[0])<<24 + int(b[1])<<16 + int(b[2])<<8 + int(b[3])<<0
}

// PointerReport returns the point report from the message.
func (m *Message) PointerReport() (PointerReport, error) {
if m.t != TypePointerReport {
return PointerReport{}, fmt.Errorf("invalid message type: %d", m.t)
}

if len(m.d) != 9 {
return PointerReport{}, fmt.Errorf("invalid message length: %d", len(m.d))
}

return PointerReport{
X: toInt(m.d[0:4]),
Y: toInt(m.d[4:8]),
Button: uint8(m.d[8]),
}, nil
}

// MouseReport ..
type MouseReport struct {
DX int8
DY int8
Button uint8
}

// MouseReport returns the mouse report from the message.
func (m *Message) MouseReport() (MouseReport, error) {
if m.t != TypeMouseReport {
return MouseReport{}, fmt.Errorf("invalid message type: %d", m.t)
}

return MouseReport{
DX: int8(m.d[0]),
DY: int8(m.d[1]),
Button: uint8(m.d[2]),
}, nil
}
Loading