Skip to content

Commit e5b786c

Browse files
authored
Merge pull request #123 from NickeZ/nickez/write-all-u2fhid-packets
u2fhid: Write as many bytes as possible
2 parents ae18bcb + 8e05da0 commit e5b786c

File tree

2 files changed

+228
-24
lines changed

2 files changed

+228
-24
lines changed

communication/u2fhid/u2fhid.go

+37-24
Original file line numberDiff line numberDiff line change
@@ -95,44 +95,57 @@ func (communication *Communication) sendFrame(msg string) error {
9595
if dataLen == 0 {
9696
return nil
9797
}
98-
send := func(header []byte, readFrom *bytes.Buffer) error {
99-
buf := newBuffer()
100-
buf.Write(header)
101-
buf.Write(readFrom.Next(writeReportSize - buf.Len()))
102-
for buf.Len() < writeReportSize {
103-
buf.WriteByte(0xee)
104-
}
105-
x := buf.Bytes() // needs to be in a var: https://github.com/golang/go/issues/14210#issuecomment-346402945
106-
_, err := communication.device.Write(x)
107-
return errp.WithMessage(errp.WithStack(err), "Failed to send message")
108-
}
98+
10999
readBuffer := bytes.NewBufferString(msg)
100+
out := newBuffer()
101+
102+
// Calculate how large the `out` buffer should be. Round up to an equal
103+
// number of writeReportSize sized bytes
104+
outLen := writeReportSize
105+
initPayloadSize := writeReportSize - 7
106+
contPayloadSize := writeReportSize - 5
107+
if dataLen > initPayloadSize {
108+
contLen := dataLen - initPayloadSize
109+
outLen += ((contLen + contPayloadSize - 1) / contPayloadSize) * writeReportSize
110+
}
111+
out.Grow(outLen)
112+
110113
// init frame
111-
header := newBuffer()
112-
if err := binary.Write(header, binary.BigEndian, cid); err != nil {
114+
if err := binary.Write(out, binary.BigEndian, cid); err != nil {
113115
return errp.WithStack(err)
114116
}
115-
if err := binary.Write(header, binary.BigEndian, communication.cmd); err != nil {
117+
if err := binary.Write(out, binary.BigEndian, communication.cmd); err != nil {
116118
return errp.WithStack(err)
117119
}
118-
if err := binary.Write(header, binary.BigEndian, uint16(dataLen&0xFFFF)); err != nil {
120+
if err := binary.Write(out, binary.BigEndian, uint16(dataLen&0xFFFF)); err != nil {
119121
return errp.WithStack(err)
120122
}
121-
if err := send(header.Bytes(), readBuffer); err != nil {
122-
return err
123-
}
123+
out.Write(readBuffer.Next(initPayloadSize))
124+
125+
// cont frames
124126
for seq := 0; readBuffer.Len() > 0; seq++ {
125-
// cont frame
126-
header = newBuffer()
127-
if err := binary.Write(header, binary.BigEndian, cid); err != nil {
127+
if err := binary.Write(out, binary.BigEndian, cid); err != nil {
128128
return errp.WithStack(err)
129129
}
130-
if err := binary.Write(header, binary.BigEndian, uint8(seq)); err != nil {
130+
if err := binary.Write(out, binary.BigEndian, uint8(seq)); err != nil {
131131
return errp.WithStack(err)
132132
}
133-
if err := send(header.Bytes(), readBuffer); err != nil {
134-
return err
133+
out.Write(readBuffer.Next(contPayloadSize))
134+
}
135+
136+
// Pad to multiple of writeReportSize
137+
for range (writeReportSize - (out.Len() % writeReportSize)) % writeReportSize {
138+
out.WriteByte(0xEE)
139+
}
140+
141+
// Write out packets, write as many bytes as possible in each iteration
142+
for out.Len() > 0 {
143+
x := out.Bytes() // needs to be in a var: https://github.com/golang/go/issues/14210#issuecomment-346402945
144+
n, err := communication.device.Write(x)
145+
if err != nil {
146+
return errp.WithMessage(errp.WithStack(err), "Failed to send message")
135147
}
148+
out.Next(n)
136149
}
137150
return nil
138151
}

communication/u2fhid/u2fhid_test.go

+191
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
// Copyright 2025 Shift Cryptosecurity AG
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package u2fhid
16+
17+
import (
18+
"bytes"
19+
"encoding/binary"
20+
"strings"
21+
"testing"
22+
)
23+
24+
// testRW mocks io.ReadWriteCloser with variable write chunk handling
25+
type testRW struct {
26+
writeBuffer bytes.Buffer
27+
writeChunkSize int // max bytes to write per call
28+
}
29+
30+
func (t *testRW) Write(p []byte) (n int, err error) {
31+
if t.writeChunkSize == 0 || t.writeChunkSize > len(p) {
32+
t.writeBuffer.Write(p)
33+
return len(p), nil
34+
}
35+
36+
written := t.writeChunkSize
37+
t.writeBuffer.Write(p[:written])
38+
return written, nil
39+
}
40+
41+
func (t *testRW) Read(p []byte) (n int, err error) { return 0, nil }
42+
func (t *testRW) Close() error { return nil }
43+
44+
func generateExpectedFrame(cmd byte, message string) []byte {
45+
buf := new(bytes.Buffer)
46+
47+
// Handle empty message case
48+
if len(message) == 0 {
49+
return buf.Bytes()
50+
}
51+
52+
// Init packet
53+
binary.Write(buf, binary.BigEndian, cid)
54+
buf.WriteByte(cmd)
55+
binary.Write(buf, binary.BigEndian, uint16(len(message)))
56+
57+
// Split the message into init part and remaining
58+
remaining := message
59+
initData := remaining[:min(len(remaining), 57)]
60+
remaining = remaining[len(initData):]
61+
buf.WriteString(initData)
62+
if len(initData) < 57 {
63+
buf.Write(bytes.Repeat([]byte{0xee}, 57-len(initData)))
64+
}
65+
66+
// Continue with remaining data for continuation frames
67+
for seq := 0; len(remaining) > 0; seq++ {
68+
cont := new(bytes.Buffer)
69+
binary.Write(cont, binary.BigEndian, cid)
70+
cont.WriteByte(uint8(seq))
71+
72+
contData := remaining[:min(len(remaining), 59)]
73+
cont.WriteString(contData)
74+
remaining = remaining[len(contData):]
75+
if cont.Len() < 64 {
76+
cont.Write(bytes.Repeat([]byte{0xee}, 64-cont.Len()))
77+
}
78+
buf.Write(cont.Bytes())
79+
}
80+
81+
return buf.Bytes()
82+
}
83+
84+
func TestSendFrame(t *testing.T) {
85+
const testCMD = 0xab
86+
87+
tests := []struct {
88+
name string
89+
input string
90+
chunkSize int
91+
expectedFrames string
92+
}{
93+
{
94+
name: "empty message",
95+
input: "",
96+
chunkSize: 64,
97+
expectedFrames: "",
98+
},
99+
{
100+
name: "small message (exact init frame)",
101+
input: strings.Repeat("a", 57),
102+
chunkSize: 64,
103+
expectedFrames: strings.Repeat("a", 57),
104+
},
105+
{
106+
name: "message requiring one continuation frame",
107+
input: strings.Repeat("b", 58),
108+
chunkSize: 32,
109+
expectedFrames: strings.Repeat("b", 58),
110+
},
111+
{
112+
name: "multi-frame message with uneven writes",
113+
input: strings.Repeat("c", 500),
114+
chunkSize: 7,
115+
expectedFrames: strings.Repeat("c", 500),
116+
},
117+
{
118+
name: "boundary case with minimal writes",
119+
input: strings.Repeat("d", 64*3),
120+
chunkSize: 1,
121+
expectedFrames: strings.Repeat("d", 64*3),
122+
},
123+
}
124+
125+
for _, tt := range tests {
126+
t.Run(tt.name, func(t *testing.T) {
127+
rw := &testRW{writeChunkSize: tt.chunkSize}
128+
comm := NewCommunication(rw, testCMD)
129+
130+
err := comm.SendFrame(tt.input)
131+
if err != nil {
132+
t.Fatalf("SendFrame failed: %v", err)
133+
}
134+
135+
expected := generateExpectedFrame(testCMD, tt.input)
136+
actualCount := rw.writeBuffer.Len() / 64 * 64
137+
actual := rw.writeBuffer.Bytes()[:actualCount]
138+
139+
if !bytes.Equal(expected, actual) {
140+
t.Errorf("Frame mismatch\nExpected:\n% x\n\nGot:\n% x", expected, actual)
141+
}
142+
143+
// Verify all padding bytes
144+
totalFrames := len(expected) / 64
145+
for frameNum := range totalFrames {
146+
frameStart := frameNum * 64
147+
frameEnd := (frameNum + 1) * 64
148+
frame := expected[frameStart:frameEnd]
149+
150+
// Check CID in every frame
151+
if binary.BigEndian.Uint32(frame[0:4]) != cid {
152+
t.Error("Invalid CID in frame")
153+
}
154+
155+
if frameNum == 0 {
156+
// Init frame checks
157+
if frame[4] != testCMD {
158+
t.Error("Invalid command byte in init frame")
159+
}
160+
161+
dataLength := binary.BigEndian.Uint16(frame[5:7])
162+
if dataLength != uint16(len(tt.input)) {
163+
t.Error("Invalid data length in init frame")
164+
}
165+
} else if frame[4] != byte(frameNum-1) {
166+
// Continuation frame checks
167+
t.Error("Invalid sequence number in continuation frame")
168+
}
169+
170+
// Verify padding bytes (last byte of filled data to end)
171+
dataEnd := len(tt.input) - (57 + (frameNum-1)*59)
172+
if frameNum > 0 && dataEnd < 0 {
173+
dataEnd = 0
174+
}
175+
paddingStart := 7 + dataEnd
176+
if frameNum == 0 {
177+
paddingStart = 7 + min(len(tt.input), 57)
178+
}
179+
180+
if paddingStart < 64 {
181+
paddingBytes := frame[paddingStart:]
182+
for _, b := range paddingBytes {
183+
if b != 0xee {
184+
t.Errorf("Invalid padding byte: %02x", b)
185+
}
186+
}
187+
}
188+
}
189+
})
190+
}
191+
}

0 commit comments

Comments
 (0)