@@ -185,34 +185,35 @@ public long readU8le() {
185
185
//region Unaligned bit values
186
186
187
187
public void alignToByte () {
188
- bits = 0 ;
189
188
bitsLeft = 0 ;
189
+ bits = 0 ;
190
190
}
191
191
192
192
public long readBitsIntBe (int n ) {
193
+ long res = 0 ;
194
+
193
195
int bitsNeeded = n - bitsLeft ;
196
+ bitsLeft = -bitsNeeded & 7 ; // `-bitsNeeded mod 8`
197
+
194
198
if (bitsNeeded > 0 ) {
195
199
// 1 bit => 1 byte
196
200
// 8 bits => 1 byte
197
201
// 9 bits => 2 bytes
198
- int bytesNeeded = ((bitsNeeded - 1 ) / 8 ) + 1 ;
202
+ int bytesNeeded = ((bitsNeeded - 1 ) / 8 ) + 1 ; // `ceil(bitsNeeded / 8)`
199
203
byte [] buf = readBytes (bytesNeeded );
200
204
for (byte b : buf ) {
201
- bits <<= 8 ;
202
- // b is signed byte, convert to unsigned using "& 0xff" trick
203
- bits |= (b & 0xff );
204
- bitsLeft += 8 ;
205
+ // `b` is signed byte, convert to unsigned using the "& 0xff" trick
206
+ res = res << 8 | (b & 0xff );
205
207
}
208
+
209
+ long newBits = res ;
210
+ res = res >>> bitsLeft | bits << bitsNeeded ;
211
+ bits = newBits ; // will be masked at the end of the function
212
+ } else {
213
+ res = bits >>> -bitsNeeded ; // shift unneeded bits out
206
214
}
207
215
208
- // raw mask with required number of 1s, starting from lowest bit
209
- long mask = getMaskOnes (n );
210
- // shift "bits" to align the highest bits with the mask & derive the result
211
- int shiftBits = bitsLeft - n ;
212
- long res = (bits >>> shiftBits ) & mask ;
213
- // clear top bits that we've just read => AND with 1s
214
- bitsLeft -= n ;
215
- mask = getMaskOnes (bitsLeft );
216
+ long mask = (1L << bitsLeft ) - 1 ; // `bitsLeft` is in range 0..7, so `(1L << 64)` does not have to be considered
216
217
bits &= mask ;
217
218
218
219
return res ;
@@ -229,36 +230,40 @@ public long readBitsInt(int n) {
229
230
}
230
231
231
232
public long readBitsIntLe (int n ) {
233
+ long res = 0 ;
232
234
int bitsNeeded = n - bitsLeft ;
235
+
233
236
if (bitsNeeded > 0 ) {
234
237
// 1 bit => 1 byte
235
238
// 8 bits => 1 byte
236
239
// 9 bits => 2 bytes
237
- int bytesNeeded = ((bitsNeeded - 1 ) / 8 ) + 1 ;
240
+ int bytesNeeded = ((bitsNeeded - 1 ) / 8 ) + 1 ; // `ceil(bitsNeeded / 8)`
238
241
byte [] buf = readBytes (bytesNeeded );
239
- for (byte b : buf ) {
240
- bits |= (( long ) ( b & 0xff ) << bitsLeft );
241
- bitsLeft += 8 ;
242
+ for (int i = 0 ; i < bytesNeeded ; i ++ ) {
243
+ // `buf[i]` is signed byte, convert to unsigned using the " & 0xff" trick
244
+ res |= (( long ) ( buf [ i ] & 0xff )) << ( i * 8 ) ;
242
245
}
246
+
247
+ // NB: in Java, bit shift operators on left-hand operand of type `long` work
248
+ // as if the right-hand operand were subjected to `& 63` (`& 0b11_1111`) (see
249
+ // https://docs.oracle.com/javase/specs/jls/se7/html/jls-15.html#jls-15.19),
250
+ // so `res >>> 64` is equivalent to `res >>> 0` (but we don't want that)
251
+ long newBits = bitsNeeded < 64 ? res >>> bitsNeeded : 0 ;
252
+ res = res << bitsLeft | bits ;
253
+ bits = newBits ;
254
+ } else {
255
+ res = bits ;
256
+ bits >>>= n ;
243
257
}
244
258
245
- // raw mask with required number of 1s, starting from lowest bit
246
- long mask = getMaskOnes (n );
247
- // derive reading result
248
- long res = bits & mask ;
249
- // remove bottom bits that we've just read by shifting
250
- bits >>>= n ;
251
- bitsLeft -= n ;
259
+ bitsLeft = -bitsNeeded & 7 ; // `-bitsNeeded mod 8`
252
260
253
- return res ;
254
- }
255
-
256
- private static long getMaskOnes (int n ) {
257
- if (n == 64 ) {
258
- return 0xffffffffffffffffL ;
259
- } else {
260
- return (1L << n ) - 1 ;
261
+ if (n < 64 ) {
262
+ long mask = (1L << n ) - 1 ;
263
+ res &= mask ;
261
264
}
265
+ // if `n == 64`, do nothing
266
+ return res ;
262
267
}
263
268
264
269
//endregion
0 commit comments