Skip to content

Commit e7067d7

Browse files
committed
Rewrite readBitsInt{Be,Le}()
See kaitai-io/kaitai_struct#949
1 parent b83000e commit e7067d7

File tree

1 file changed

+38
-33
lines changed

1 file changed

+38
-33
lines changed

src/main/java/io/kaitai/struct/KaitaiStream.java

Lines changed: 38 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -185,34 +185,35 @@ public long readU8le() {
185185
//region Unaligned bit values
186186

187187
public void alignToByte() {
188-
bits = 0;
189188
bitsLeft = 0;
189+
bits = 0;
190190
}
191191

192192
public long readBitsIntBe(int n) {
193+
long res = 0;
194+
193195
int bitsNeeded = n - bitsLeft;
196+
bitsLeft = -bitsNeeded & 7; // `-bitsNeeded mod 8`
197+
194198
if (bitsNeeded > 0) {
195199
// 1 bit => 1 byte
196200
// 8 bits => 1 byte
197201
// 9 bits => 2 bytes
198-
int bytesNeeded = ((bitsNeeded - 1) / 8) + 1;
202+
int bytesNeeded = ((bitsNeeded - 1) / 8) + 1; // `ceil(bitsNeeded / 8)`
199203
byte[] buf = readBytes(bytesNeeded);
200204
for (byte b : buf) {
201-
bits <<= 8;
202-
// b is signed byte, convert to unsigned using "& 0xff" trick
203-
bits |= (b & 0xff);
204-
bitsLeft += 8;
205+
// `b` is signed byte, convert to unsigned using the "& 0xff" trick
206+
res = res << 8 | (b & 0xff);
205207
}
208+
209+
long newBits = res;
210+
res = res >>> bitsLeft | bits << bitsNeeded;
211+
bits = newBits; // will be masked at the end of the function
212+
} else {
213+
res = bits >>> -bitsNeeded; // shift unneeded bits out
206214
}
207215

208-
// raw mask with required number of 1s, starting from lowest bit
209-
long mask = getMaskOnes(n);
210-
// shift "bits" to align the highest bits with the mask & derive the result
211-
int shiftBits = bitsLeft - n;
212-
long res = (bits >>> shiftBits) & mask;
213-
// clear top bits that we've just read => AND with 1s
214-
bitsLeft -= n;
215-
mask = getMaskOnes(bitsLeft);
216+
long mask = (1L << bitsLeft) - 1; // `bitsLeft` is in range 0..7, so `(1L << 64)` does not have to be considered
216217
bits &= mask;
217218

218219
return res;
@@ -229,36 +230,40 @@ public long readBitsInt(int n) {
229230
}
230231

231232
public long readBitsIntLe(int n) {
233+
long res = 0;
232234
int bitsNeeded = n - bitsLeft;
235+
233236
if (bitsNeeded > 0) {
234237
// 1 bit => 1 byte
235238
// 8 bits => 1 byte
236239
// 9 bits => 2 bytes
237-
int bytesNeeded = ((bitsNeeded - 1) / 8) + 1;
240+
int bytesNeeded = ((bitsNeeded - 1) / 8) + 1; // `ceil(bitsNeeded / 8)`
238241
byte[] buf = readBytes(bytesNeeded);
239-
for (byte b : buf) {
240-
bits |= ((long) (b & 0xff) << bitsLeft);
241-
bitsLeft += 8;
242+
for (int i = 0; i < bytesNeeded; i++) {
243+
// `buf[i]` is signed byte, convert to unsigned using the "& 0xff" trick
244+
res |= ((long) (buf[i] & 0xff)) << (i * 8);
242245
}
246+
247+
// NB: in Java, bit shift operators on left-hand operand of type `long` work
248+
// as if the right-hand operand were subjected to `& 63` (`& 0b11_1111`) (see
249+
// https://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.19),
250+
// so `res >>> 64` is equivalent to `res >>> 0` (but we don't want that)
251+
long newBits = bitsNeeded < 64 ? res >>> bitsNeeded : 0;
252+
res = res << bitsLeft | bits;
253+
bits = newBits;
254+
} else {
255+
res = bits;
256+
bits >>>= n;
243257
}
244258

245-
// raw mask with required number of 1s, starting from lowest bit
246-
long mask = getMaskOnes(n);
247-
// derive reading result
248-
long res = bits & mask;
249-
// remove bottom bits that we've just read by shifting
250-
bits >>>= n;
251-
bitsLeft -= n;
259+
bitsLeft = -bitsNeeded & 7; // `-bitsNeeded mod 8`
252260

253-
return res;
254-
}
255-
256-
private static long getMaskOnes(int n) {
257-
if (n == 64) {
258-
return 0xffffffffffffffffL;
259-
} else {
260-
return (1L << n) - 1;
261+
if (n < 64) {
262+
long mask = (1L << n) - 1;
263+
res &= mask;
261264
}
265+
// if `n == 64`, do nothing
266+
return res;
262267
}
263268

264269
//endregion

0 commit comments

Comments
 (0)