Skip to content

Commit a7cfba0

Browse files
authored
Merge pull request #2113 from kleros/refactor/sortition-trees-library
Sortition trees extracted to a library
2 parents 953bd95 + 323ba18 commit a7cfba0

File tree

10 files changed

+1116
-248
lines changed

10 files changed

+1116
-248
lines changed

contracts/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ The format is based on [Common Changelog](https://common-changelog.org/).
2222
- Make the primary VRF-based RNG fall back to `BlockhashRNG` if the VRF request is not fulfilled within a timeout ([#2054](https://github.com/kleros/kleros-v2/issues/2054))
2323
- Authenticate the calls to the RNGs to prevent 3rd parties from depleting the Chainlink VRF subscription funds ([#2054](https://github.com/kleros/kleros-v2/issues/2054))
2424
- Use `block.timestamp` rather than `block.number` for `BlockhashRNG` for better reliability on Arbitrum as block production is sporadic depending on network conditions. ([#2054](https://github.com/kleros/kleros-v2/issues/2054))
25+
- Replace the `bytes32 _key` parameter in `SortitionTrees.createTree()` and `SortitionTrees.draw()` by `uint96 courtID` ([#2113](https://github.com/kleros/kleros-v2/issues/2113))
26+
- Extract the sortition sum trees logic into a library `SortitionTrees` ([#2113](https://github.com/kleros/kleros-v2/issues/2113))
2527
- Set the Hardhat Solidity version to v0.8.30 and enable the IR pipeline ([#2069](https://github.com/kleros/kleros-v2/issues/2069))
2628
- Set the Foundry Solidity version to v0.8.30 and enable the IR pipeline ([#2073](https://github.com/kleros/kleros-v2/issues/2073))
2729
- Widen the allowed solc version to any v0.8.x for the interfaces only ([#2083](https://github.com/kleros/kleros-v2/issues/2083))

contracts/src/arbitration/KlerosCoreBase.sol

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ abstract contract KlerosCoreBase is IArbitratorV2, Initializable, UUPSProxiable
223223
// FORKING_COURT
224224
// TODO: Fill the properties for the Forking court, emit CourtCreated.
225225
courts.push();
226-
sortitionModule.createTree(bytes32(uint256(FORKING_COURT)), _sortitionExtraData);
226+
sortitionModule.createTree(FORKING_COURT, _sortitionExtraData);
227227

228228
// GENERAL_COURT
229229
Court storage court = courts.push();
@@ -236,7 +236,7 @@ abstract contract KlerosCoreBase is IArbitratorV2, Initializable, UUPSProxiable
236236
court.jurorsForCourtJump = _courtParameters[3];
237237
court.timesPerPeriod = _timesPerPeriod;
238238

239-
sortitionModule.createTree(bytes32(uint256(GENERAL_COURT)), _sortitionExtraData);
239+
sortitionModule.createTree(GENERAL_COURT, _sortitionExtraData);
240240

241241
uint256[] memory supportedDisputeKits = new uint256[](1);
242242
supportedDisputeKits[0] = DISPUTE_KIT_CLASSIC;
@@ -343,7 +343,7 @@ abstract contract KlerosCoreBase is IArbitratorV2, Initializable, UUPSProxiable
343343
if (_supportedDisputeKits.length == 0) revert UnsupportedDisputeKit();
344344
if (_parent == FORKING_COURT) revert InvalidForkingCourtAsParent();
345345

346-
uint256 courtID = courts.length;
346+
uint96 courtID = uint96(courts.length);
347347
Court storage court = courts.push();
348348

349349
for (uint256 i = 0; i < _supportedDisputeKits.length; i++) {
@@ -364,7 +364,7 @@ abstract contract KlerosCoreBase is IArbitratorV2, Initializable, UUPSProxiable
364364
court.jurorsForCourtJump = _jurorsForCourtJump;
365365
court.timesPerPeriod = _timesPerPeriod;
366366

367-
sortitionModule.createTree(bytes32(courtID), _sortitionExtraData);
367+
sortitionModule.createTree(courtID, _sortitionExtraData);
368368

369369
// Update the parent.
370370
courts[_parent].children.push(courtID);

contracts/src/arbitration/SortitionModuleBase.sol

Lines changed: 19 additions & 232 deletions
Original file line numberDiff line numberDiff line change
@@ -7,25 +7,20 @@ import {ISortitionModule} from "./interfaces/ISortitionModule.sol";
77
import {IDisputeKit} from "./interfaces/IDisputeKit.sol";
88
import {Initializable} from "../proxy/Initializable.sol";
99
import {UUPSProxiable} from "../proxy/UUPSProxiable.sol";
10+
import {SortitionTrees, TreeKey, CourtID} from "../libraries/SortitionTrees.sol";
1011
import {IRNG} from "../rng/IRNG.sol";
1112
import "../libraries/Constants.sol";
1213

1314
/// @title SortitionModuleBase
1415
/// @dev A factory of trees that keeps track of staked values for sortition.
1516
abstract contract SortitionModuleBase is ISortitionModule, Initializable, UUPSProxiable {
17+
using SortitionTrees for SortitionTrees.Tree;
18+
using SortitionTrees for mapping(TreeKey key => SortitionTrees.Tree);
19+
1620
// ************************************* //
1721
// * Enums / Structs * //
1822
// ************************************* //
1923

20-
struct SortitionSumTree {
21-
uint256 K; // The maximum number of children per node.
22-
uint256[] stack; // We use this to keep track of vacant positions in the tree after removing a leaf. This is for keeping the tree as balanced as possible without spending gas on moving nodes around.
23-
uint256[] nodes; // The tree nodes.
24-
// Two-way mapping of IDs to node indexes. Note that node index 0 is reserved for the root node, and means the ID does not have a node.
25-
mapping(bytes32 stakePathID => uint256 nodeIndex) IDsToNodeIndexes;
26-
mapping(uint256 nodeIndex => bytes32 stakePathID) nodeIndexesToIDs;
27-
}
28-
2924
struct DelayedStake {
3025
address account; // The address of the juror.
3126
uint96 courtID; // The ID of the court.
@@ -56,7 +51,7 @@ abstract contract SortitionModuleBase is ISortitionModule, Initializable, UUPSPr
5651
uint256 public rngLookahead; // DEPRECATED: to be removed in the next redeploy
5752
uint256 public delayedStakeWriteIndex; // The index of the last `delayedStake` item that was written to the array. 0 index is skipped.
5853
uint256 public delayedStakeReadIndex; // The index of the next `delayedStake` item that should be processed. Starts at 1 because 0 index is skipped.
59-
mapping(bytes32 treeHash => SortitionSumTree) sortitionSumTrees; // The mapping trees by keys.
54+
mapping(TreeKey key => SortitionTrees.Tree) sortitionSumTrees; // The mapping of sortition trees by keys.
6055
mapping(address account => Juror) public jurors; // The jurors.
6156
mapping(uint256 => DelayedStake) public delayedStakes; // Stores the stakes that were changed during Drawing phase, to update them when the phase is switched to Staking.
6257
mapping(address jurorAccount => mapping(uint96 courtId => uint256)) public latestDelayedStakeIndex; // DEPRECATED. Maps the juror to its latest delayed stake. If there is already a delayed stake for this juror then it'll be replaced. latestDelayedStakeIndex[juror][courtID].
@@ -185,15 +180,12 @@ abstract contract SortitionModuleBase is ISortitionModule, Initializable, UUPSPr
185180
}
186181

187182
/// @dev Create a sortition sum tree at the specified key.
188-
/// @param _key The key of the new tree.
183+
/// @param _courtID The ID of the court.
189184
/// @param _extraData Extra data that contains the number of children each node in the tree should have.
190-
function createTree(bytes32 _key, bytes memory _extraData) external override onlyByCore {
191-
SortitionSumTree storage tree = sortitionSumTrees[_key];
185+
function createTree(uint96 _courtID, bytes memory _extraData) external override onlyByCore {
186+
TreeKey key = CourtID.wrap(_courtID).toTreeKey();
192187
uint256 K = _extraDataToTreeK(_extraData);
193-
if (tree.K != 0) revert TreeAlreadyExists();
194-
if (K <= 1) revert KMustBeGreaterThanOne();
195-
tree.K = K;
196-
tree.nodes.push(0);
188+
sortitionSumTrees.createTree(key, K);
197189
}
198190

199191
/// @dev Executes the next delayed stakes.
@@ -398,12 +390,13 @@ abstract contract SortitionModuleBase is ISortitionModule, Initializable, UUPSPr
398390
}
399391

400392
// Update the sortition sum tree.
401-
bytes32 stakePathID = _accountAndCourtIDToStakePathID(_account, _courtID);
393+
bytes32 stakePathID = SortitionTrees.toStakePathID(_account, _courtID);
402394
bool finished = false;
403395
uint96 currentCourtID = _courtID;
404396
while (!finished) {
405397
// Tokens are also implicitly staked in parent courts through sortition module to increase the chance of being drawn.
406-
_set(bytes32(uint256(currentCourtID)), _newStake, stakePathID);
398+
TreeKey key = CourtID.wrap(currentCourtID).toTreeKey();
399+
sortitionSumTrees[key].set(_newStake, stakePathID);
407400
if (currentCourtID == GENERAL_COURT) {
408401
finished = true;
409402
} else {
@@ -477,71 +470,32 @@ abstract contract SortitionModuleBase is ISortitionModule, Initializable, UUPSPr
477470

478471
/// @dev Draw an ID from a tree using a number.
479472
/// Note that this function reverts if the sum of all values in the tree is 0.
480-
/// @param _key The key of the tree.
473+
/// @param _courtID The ID of the court.
481474
/// @param _coreDisputeID Index of the dispute in Kleros Core.
482475
/// @param _nonce Nonce to hash with random number.
483476
/// @return drawnAddress The drawn address.
484477
/// `O(k * log_k(n))` where
485478
/// `k` is the maximum number of children per node in the tree,
486479
/// and `n` is the maximum number of nodes ever appended.
487480
function draw(
488-
bytes32 _key,
481+
uint96 _courtID,
489482
uint256 _coreDisputeID,
490483
uint256 _nonce
491484
) public view override returns (address drawnAddress, uint96 fromSubcourtID) {
492485
if (phase != Phase.drawing) revert NotDrawingPhase();
493-
SortitionSumTree storage tree = sortitionSumTrees[_key];
494-
495-
if (tree.nodes[0] == 0) {
496-
return (address(0), 0); // No jurors staked.
497-
}
498-
499-
uint256 currentDrawnNumber = uint256(keccak256(abi.encodePacked(randomNumber, _coreDisputeID, _nonce))) %
500-
tree.nodes[0];
501-
502-
// While it still has children
503-
uint256 treeIndex = 0;
504-
while ((tree.K * treeIndex) + 1 < tree.nodes.length) {
505-
for (uint256 i = 1; i <= tree.K; i++) {
506-
// Loop over children.
507-
uint256 nodeIndex = (tree.K * treeIndex) + i;
508-
uint256 nodeValue = tree.nodes[nodeIndex];
509-
510-
if (currentDrawnNumber >= nodeValue) {
511-
// Go to the next child.
512-
currentDrawnNumber -= nodeValue;
513-
} else {
514-
// Pick this child.
515-
treeIndex = nodeIndex;
516-
break;
517-
}
518-
}
519-
}
520486

521-
bytes32 stakePathID = tree.nodeIndexesToIDs[treeIndex];
522-
(drawnAddress, fromSubcourtID) = _stakePathIDToAccountAndCourtID(stakePathID);
487+
TreeKey key = CourtID.wrap(_courtID).toTreeKey();
488+
(drawnAddress, fromSubcourtID) = sortitionSumTrees[key].draw(_coreDisputeID, _nonce, randomNumber);
523489
}
524490

525491
/// @dev Get the stake of a juror in a court.
526492
/// @param _juror The address of the juror.
527493
/// @param _courtID The ID of the court.
528494
/// @return value The stake of the juror in the court.
529495
function stakeOf(address _juror, uint96 _courtID) public view returns (uint256) {
530-
bytes32 stakePathID = _accountAndCourtIDToStakePathID(_juror, _courtID);
531-
return stakeOf(bytes32(uint256(_courtID)), stakePathID);
532-
}
533-
534-
/// @dev Get the stake of a juror in a court.
535-
/// @param _key The key of the tree, corresponding to a court.
536-
/// @param _stakePathID The stake path ID, corresponding to a juror.
537-
/// @return The stake of the juror in the court.
538-
function stakeOf(bytes32 _key, bytes32 _stakePathID) public view returns (uint256) {
539-
SortitionSumTree storage tree = sortitionSumTrees[_key];
540-
uint treeIndex = tree.IDsToNodeIndexes[_stakePathID];
541-
if (treeIndex == 0) {
542-
return 0;
543-
}
544-
return tree.nodes[treeIndex];
496+
bytes32 stakePathID = SortitionTrees.toStakePathID(_juror, _courtID);
497+
TreeKey key = CourtID.wrap(_courtID).toTreeKey();
498+
return sortitionSumTrees[key].stakeOf(stakePathID);
545499
}
546500

547501
/// @dev Gets the balance of a juror in a court.
@@ -590,26 +544,6 @@ abstract contract SortitionModuleBase is ISortitionModule, Initializable, UUPSPr
590544
// * Internal * //
591545
// ************************************* //
592546

593-
/// @dev Update all the parents of a node.
594-
/// @param _key The key of the tree to update.
595-
/// @param _treeIndex The index of the node to start from.
596-
/// @param _plusOrMinus Whether to add (true) or substract (false).
597-
/// @param _value The value to add or substract.
598-
/// `O(log_k(n))` where
599-
/// `k` is the maximum number of children per node in the tree,
600-
/// and `n` is the maximum number of nodes ever appended.
601-
function _updateParents(bytes32 _key, uint256 _treeIndex, bool _plusOrMinus, uint256 _value) private {
602-
SortitionSumTree storage tree = sortitionSumTrees[_key];
603-
604-
uint256 parentIndex = _treeIndex;
605-
while (parentIndex != 0) {
606-
parentIndex = (parentIndex - 1) / tree.K;
607-
tree.nodes[parentIndex] = _plusOrMinus
608-
? tree.nodes[parentIndex] + _value
609-
: tree.nodes[parentIndex] - _value;
610-
}
611-
}
612-
613547
function _extraDataToTreeK(bytes memory _extraData) internal pure returns (uint256 K) {
614548
if (_extraData.length >= 32) {
615549
assembly {
@@ -621,151 +555,6 @@ abstract contract SortitionModuleBase is ISortitionModule, Initializable, UUPSPr
621555
}
622556
}
623557

624-
/// @dev Set a value in a tree.
625-
/// @param _key The key of the tree.
626-
/// @param _value The new value.
627-
/// @param _stakePathID The ID of the value.
628-
/// `O(log_k(n))` where
629-
/// `k` is the maximum number of children per node in the tree,
630-
/// and `n` is the maximum number of nodes ever appended.
631-
function _set(bytes32 _key, uint256 _value, bytes32 _stakePathID) internal {
632-
SortitionSumTree storage tree = sortitionSumTrees[_key];
633-
uint256 treeIndex = tree.IDsToNodeIndexes[_stakePathID];
634-
635-
if (treeIndex == 0) {
636-
// No existing node.
637-
if (_value != 0) {
638-
// Non zero value.
639-
// Append.
640-
// Add node.
641-
if (tree.stack.length == 0) {
642-
// No vacant spots.
643-
// Get the index and append the value.
644-
treeIndex = tree.nodes.length;
645-
tree.nodes.push(_value);
646-
647-
// Potentially append a new node and make the parent a sum node.
648-
if (treeIndex != 1 && (treeIndex - 1) % tree.K == 0) {
649-
// Is first child.
650-
uint256 parentIndex = treeIndex / tree.K;
651-
bytes32 parentID = tree.nodeIndexesToIDs[parentIndex];
652-
uint256 newIndex = treeIndex + 1;
653-
tree.nodes.push(tree.nodes[parentIndex]);
654-
delete tree.nodeIndexesToIDs[parentIndex];
655-
tree.IDsToNodeIndexes[parentID] = newIndex;
656-
tree.nodeIndexesToIDs[newIndex] = parentID;
657-
}
658-
} else {
659-
// Some vacant spot.
660-
// Pop the stack and append the value.
661-
treeIndex = tree.stack[tree.stack.length - 1];
662-
tree.stack.pop();
663-
tree.nodes[treeIndex] = _value;
664-
}
665-
666-
// Add label.
667-
tree.IDsToNodeIndexes[_stakePathID] = treeIndex;
668-
tree.nodeIndexesToIDs[treeIndex] = _stakePathID;
669-
670-
_updateParents(_key, treeIndex, true, _value);
671-
}
672-
} else {
673-
// Existing node.
674-
if (_value == 0) {
675-
// Zero value.
676-
// Remove.
677-
// Remember value and set to 0.
678-
uint256 value = tree.nodes[treeIndex];
679-
tree.nodes[treeIndex] = 0;
680-
681-
// Push to stack.
682-
tree.stack.push(treeIndex);
683-
684-
// Clear label.
685-
delete tree.IDsToNodeIndexes[_stakePathID];
686-
delete tree.nodeIndexesToIDs[treeIndex];
687-
688-
_updateParents(_key, treeIndex, false, value);
689-
} else if (_value != tree.nodes[treeIndex]) {
690-
// New, non zero value.
691-
// Set.
692-
bool plusOrMinus = tree.nodes[treeIndex] <= _value;
693-
uint256 plusOrMinusValue = plusOrMinus
694-
? _value - tree.nodes[treeIndex]
695-
: tree.nodes[treeIndex] - _value;
696-
tree.nodes[treeIndex] = _value;
697-
698-
_updateParents(_key, treeIndex, plusOrMinus, plusOrMinusValue);
699-
}
700-
}
701-
}
702-
703-
/// @dev Packs an account and a court ID into a stake path ID: [20 bytes of address][12 bytes of courtID] = 32 bytes total.
704-
/// @param _account The address of the juror to pack.
705-
/// @param _courtID The court ID to pack.
706-
/// @return stakePathID The stake path ID.
707-
function _accountAndCourtIDToStakePathID(
708-
address _account,
709-
uint96 _courtID
710-
) internal pure returns (bytes32 stakePathID) {
711-
assembly {
712-
// solium-disable-line security/no-inline-assembly
713-
let ptr := mload(0x40)
714-
715-
// Write account address (first 20 bytes)
716-
for {
717-
let i := 0x00
718-
} lt(i, 0x14) {
719-
i := add(i, 0x01)
720-
} {
721-
mstore8(add(ptr, i), byte(add(0x0c, i), _account))
722-
}
723-
724-
// Write court ID (last 12 bytes)
725-
for {
726-
let i := 0x14
727-
} lt(i, 0x20) {
728-
i := add(i, 0x01)
729-
} {
730-
mstore8(add(ptr, i), byte(i, _courtID))
731-
}
732-
stakePathID := mload(ptr)
733-
}
734-
}
735-
736-
/// @dev Retrieves both juror's address and court ID from the stake path ID.
737-
/// @param _stakePathID The stake path ID to unpack.
738-
/// @return account The account.
739-
/// @return courtID The court ID.
740-
function _stakePathIDToAccountAndCourtID(
741-
bytes32 _stakePathID
742-
) internal pure returns (address account, uint96 courtID) {
743-
assembly {
744-
// solium-disable-line security/no-inline-assembly
745-
let ptr := mload(0x40)
746-
747-
// Read account address (first 20 bytes)
748-
for {
749-
let i := 0x00
750-
} lt(i, 0x14) {
751-
i := add(i, 0x01)
752-
} {
753-
mstore8(add(add(ptr, 0x0c), i), byte(i, _stakePathID))
754-
}
755-
account := mload(ptr)
756-
757-
// Read court ID (last 12 bytes)
758-
for {
759-
let i := 0x00
760-
} lt(i, 0x0c) {
761-
i := add(i, 0x01)
762-
} {
763-
mstore8(add(add(ptr, 0x14), i), byte(add(i, 0x14), _stakePathID))
764-
}
765-
courtID := mload(ptr)
766-
}
767-
}
768-
769558
// ************************************* //
770559
// * Errors * //
771560
// ************************************* //
@@ -776,8 +565,6 @@ abstract contract SortitionModuleBase is ISortitionModule, Initializable, UUPSPr
776565
error NoDisputesThatNeedJurors();
777566
error RandomNumberNotReady();
778567
error DisputesWithoutJurorsAndMaxDrawingTimeNotPassed();
779-
error TreeAlreadyExists();
780-
error KMustBeGreaterThanOne();
781568
error NotStakingPhase();
782569
error NoDelayedStakeToExecute();
783570
error NotEligibleForWithdrawal();

contracts/src/arbitration/dispute-kits/DisputeKitClassicBase.sol

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -232,9 +232,7 @@ abstract contract DisputeKitClassicBase is IDisputeKit, Initializable, UUPSProxi
232232

233233
ISortitionModule sortitionModule = core.sortitionModule();
234234
(uint96 courtID, , , , ) = core.disputes(_coreDisputeID);
235-
bytes32 key = bytes32(uint256(courtID)); // Get the ID of the tree.
236-
237-
(drawnAddress, fromSubcourtID) = sortitionModule.draw(key, _coreDisputeID, _nonce);
235+
(drawnAddress, fromSubcourtID) = sortitionModule.draw(courtID, _coreDisputeID, _nonce);
238236
if (drawnAddress == address(0)) {
239237
// Sortition can return 0 address if no one has staked yet.
240238
return (drawnAddress, fromSubcourtID);

0 commit comments

Comments
 (0)