Skip to content

Commit

Permalink
[util] add RWSpinLock and upgrade ObjectPool
Browse files Browse the repository at this point in the history
  • Loading branch information
wkgcass committed Aug 6, 2024
1 parent 5517931 commit bf6e791
Show file tree
Hide file tree
Showing 7 changed files with 315 additions and 64 deletions.
11 changes: 11 additions & 0 deletions base/src/main/java/io/vproxy/base/util/Utils.java
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,17 @@ public static boolean allZerosAfter(ByteArray bytes, int index) {
return true;
}

public static int minPow2GreaterThan(int n) {
n -= 1;
n |= n >>> 1;
n |= n >>> 2;
n |= n >>> 4;
n |= n >>> 8;
n |= n >>> 16;
n += 1;
return n;
}

public static boolean assertOn() {
return assertOn;
}
Expand Down
55 changes: 55 additions & 0 deletions base/src/main/java/io/vproxy/base/util/lock/ReadWriteSpinLock.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package io.vproxy.base.util.lock;

import java.util.concurrent.atomic.AtomicInteger;

public class ReadWriteSpinLock {
private static final int WRITE_LOCKED = 0x80_00_00_00;
// 32 31 ------ 0
// W RRRR...RRRR
private final AtomicInteger lock = new AtomicInteger(0);
private final AtomicInteger wLockPending = new AtomicInteger(0);
private final int spinTimes;

public ReadWriteSpinLock() {
this(20);
}

public ReadWriteSpinLock(int spinTimes) {
this.spinTimes = spinTimes;
}

public void readLock() {
while (true) {
if (wLockPending.get() != 0) {
spinWait();
continue;
}
if (lock.incrementAndGet() < 0) {
continue;
}
break;
}
}

public void readUnlock() {
lock.decrementAndGet();
}

public void writeLock() {
wLockPending.incrementAndGet();
while (!lock.compareAndSet(0, WRITE_LOCKED)) {
spinWait();
}
}

public void writeUnlock() {
lock.set(0);
wLockPending.decrementAndGet();
}

private void spinWait() {
for (int i = 0; i < spinTimes; ++i) {
Thread.onSpinWait();
}
}
}
Original file line number Diff line number Diff line change
@@ -1,62 +1,77 @@
package io.vproxy.base.util.objectpool;

import io.vproxy.base.util.Utils;
import io.vproxy.base.util.lock.ReadWriteSpinLock;
import io.vproxy.base.util.thread.VProxyThread;

import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.atomic.AtomicReferenceArray;

/**
* The pool is split into a few partitions, each partition has a read array and a write array.
* When adding, elements will be added into the write array.
* When polling, elements will be polled from the read array.
* If read array is empty and write array is full, and when running polling, the two arrays will be swapped
* (they will not be swapped when adding).
* If read array is empty and write array is full, and when running polling or adding, the two arrays will be swapped
* The arrays will not be operated when they are being swapped.
* When concurrency occurs, the operations will retry for maximum 10 times.
*
* @param <E> element type
*/
public class ConcurrentObjectPool<E> {
private final int partitionCount;
private final int partitionCountMinusOne;
private final Partition<E>[] partitions;
private final int maxTraversal;

public ConcurrentObjectPool(int capacityHint) {
this(capacityHint, 16, 4);
this(capacityHint, 16, 0);
}

public ConcurrentObjectPool(int capacityHint, int partitionCountHint, int minPartitionCapHint) {
capacityHint -= 1;
capacityHint |= capacityHint >>> 1;
capacityHint |= capacityHint >>> 2;
capacityHint |= capacityHint >>> 4;
capacityHint |= capacityHint >>> 8;
capacityHint |= capacityHint >>> 16;
capacityHint += 1;
public ConcurrentObjectPool(int capacityHint, int partitionCountHint, int maxTraversalHint) {
capacityHint = Utils.minPow2GreaterThan(capacityHint) / 2;
partitionCountHint = Utils.minPow2GreaterThan(partitionCountHint);

if (capacityHint / minPartitionCapHint == 0) {
if (capacityHint / partitionCountHint == 0) {
partitionCount = 1;
} else {
partitionCount = Math.min(capacityHint / minPartitionCapHint, partitionCountHint);
partitionCount = partitionCountHint;
}
partitionCountMinusOne = partitionCount - 1;

//noinspection unchecked
this.partitions = new Partition[partitionCount];
for (int i = 0; i < partitionCount; ++i) {
partitions[i] = new Partition<>(capacityHint / partitionCount);
}

if (maxTraversalHint <= 0 || maxTraversalHint >= partitionCount) {
maxTraversal = partitionCount;
} else {
maxTraversal = maxTraversalHint;
}
}

private int hashForPartition() {
var tid = VProxyThread.current().threadId;
return (int) (tid & partitionCountMinusOne);
}

public boolean add(E e) {
for (int i = 0; i < partitionCount; ++i) {
if (partitions[i].add(e)) {
int m = maxTraversal;
int hash = hashForPartition();
for (int i = hash; m > 0; ++i, --m) {
if (partitions[i & partitionCountMinusOne].add(e)) {
return true;
}
}
return false;
}

public E poll() {
for (int i = 0; i < partitionCount; ++i) {
E e = partitions[i].poll();
int m = maxTraversal;
int hash = hashForPartition();
for (int i = hash; m > 0; ++i, --m) {
E e = partitions[i & partitionCountMinusOne].poll();
if (e != null) {
return e;
}
Expand All @@ -73,27 +88,37 @@ public int size() {
}

private static class Partition<E> {
private final AtomicReference<StorageArray<E>> read;
private volatile StorageArray<E> write;
private final StorageArray<E> _1;
private final StorageArray<E> _2;
private final ReadWriteSpinLock lock = new ReadWriteSpinLock();
private volatile ArrayQueue<E> read;
private volatile ArrayQueue<E> write;
private final ArrayQueue<E> _1;
private final ArrayQueue<E> _2;

public Partition(int capacity) {
_1 = new StorageArray<>(capacity);
_2 = new StorageArray<>(capacity);
read = new AtomicReference<>(_1);
_1 = new ArrayQueue<>(capacity, lock);
_2 = new ArrayQueue<>(capacity, lock);
read = _1;
write = _2;
}

public boolean add(E e) {
StorageArray<E> write = this.write;
return add(1, e);
}

private boolean add(int retry, E e) {
if (retry > 10) { // max retry for 10 times
return false; // too many retries
}

// adding is always safe
//noinspection RedundantIfStatement
var write = this.write;
if (write.add(e)) {
return true;
}
// $write is full, storing fails

// the $write is full now
if (swap(read, write, false)) {
return add(retry + 1, e);
}
return false;
}

Expand All @@ -106,103 +131,134 @@ private E poll(int retry) {
return null; // too many retries
}

StorageArray<E> read = this.read.get();
StorageArray<E> write = this.write;
var read = this.read;
var write = this.write;

// polling is always safe
E ret = read.poll();
var ret = read.poll();
if (ret != null) {
return ret;
}

// no elements in the $read now
// check whether we can swap (whether $write is full)
if (swap(read, write, true)) {
return poll(retry + 1);
}
return null;
}

int writeEnd = write.end.get();
if (writeEnd < write.capacity) {
return null; // capacity not reached, do not swap and return nothing
// no retry here because the write array will not change until something written into it
// return true -> need retry
// return false -> failed and should not retry
private boolean swap(ArrayQueue<E> read, ArrayQueue<E> write, boolean isPolling) {
// check whether we can swap
if (read == write) {
// is being swapped by another thread
return true;
}
// also we should check whether there are no elements being stored
if (write.storing.get() != 0) { // element is being stored into the array
return poll(retry + 1); // try again

if (isPolling) { // $read is empty
int writeEnd = write.end.get();
if (writeEnd < write.capacity) {
return false; // capacity not reached, do not swap and return nothing
// no retry here because the write array will not change until something written into it
}
} else { // $write is full
int readStart = read.start.get();
if (readStart < read.end.get()) {
return false; // still have objects to fetch, do not swap
// no retry here because the read array will not change until something polling from it
}
}
// now we can know that writing operations will not happen in this partition

// we can safely swap the two arrays now
if (!this.read.compareAndSet(read, write)) {
return poll(retry + 1); // concurrency detected: another thread is swapping
lock.writeLock();
if (this.read != read) {
// already swapped by another thread
lock.writeUnlock();
return true;
}
// we can safely swap the two arrays now
this.read = write;
// the $read is expected to be empty
assert read.size() == 0;
read.reset(); // reset the cursors, so further operations can store data into this array
this.write = read; // swapping is done
return poll(retry + 1); // poll again
lock.writeUnlock();

return true;
}

public int size() {
return _1.size() + _2.size();
}
}

private static class StorageArray<E> {
private static class ArrayQueue<E> {
private final int capacity;
private final ReadWriteSpinLock lock;
private final AtomicReferenceArray<E> array;
private final AtomicInteger start = new AtomicInteger(-1);
private final AtomicInteger end = new AtomicInteger(-1);
private final AtomicInteger storing = new AtomicInteger(0);
private final AtomicInteger start = new AtomicInteger(0);
private final AtomicInteger end = new AtomicInteger(0);

private StorageArray(int capacity) {
private ArrayQueue(int capacity, ReadWriteSpinLock lock) {
this.capacity = capacity;
this.lock = lock;
this.array = new AtomicReferenceArray<>(capacity);
}

boolean add(E e) {
storing.incrementAndGet();
lock.readLock();

if (end.get() >= capacity) {
storing.decrementAndGet();
lock.readUnlock();
return false; // exceeds capacity
}
int index = end.incrementAndGet();
int index = end.getAndIncrement();
if (index < capacity) {
// storing should succeed
array.set(index, e);
storing.decrementAndGet();
lock.readUnlock();
return true;
} else {
// storing failed
storing.decrementAndGet();
lock.readUnlock();
return false;
}
}

E poll() {
if (start.get() + 1 >= end.get() || start.get() + 1 >= capacity) {
lock.readLock();

if (start.get() >= end.get() || start.get() >= capacity) {
lock.readUnlock();
return null;
}
int idx = start.incrementAndGet();
int idx = start.getAndIncrement();
if (idx >= end.get() || idx >= capacity) {
lock.readUnlock();
return null; // concurrent polling
}
return array.get(idx);
var e = array.get(idx);
lock.readUnlock();
return e;
}

int size() {
int start = this.start.get() + 1;
int start = this.start.get();
if (start >= capacity) {
return 0;
}
int cap = end.get() + 1;
int cap = end.get();
if (cap > capacity) {
cap = capacity;
}
if (start > cap) {
return 0;
}
return cap - start;
}

void reset() {
end.set(-1);
start.set(-1);
end.set(0);
start.set(0);
}
}
}
Loading

0 comments on commit bf6e791

Please sign in to comment.