Skip to content

Commit

Permalink
Optimize votes lookups for recent checkpoints (OpenZeppelin#3673)
Browse files Browse the repository at this point in the history
  • Loading branch information
frangio authored and JulissaDantes committed Nov 3, 2022
1 parent a4f6d97 commit b1bc18a
Show file tree
Hide file tree
Showing 9 changed files with 94 additions and 102 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
* `Address`: optimize `functionCall` functions by checking contract size only if there is no returned data. ([#3469](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3469))
* `GovernorCompatibilityBravo`: remove unused `using` statements. ([#3506](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3506))
* `ERC20`: optimize `_transfer`, `_mint` and `_burn` by using `unchecked` arithmetic when possible. ([#3513](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3513))
* `ERC20Votes`, `ERC721Votes`: optimize `getPastVotes` for looking up recent checkpoints. ([#3673](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3673))
* `ERC20FlashMint`: add an internal `_flashFee` function for overriding. ([#3551](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3551))
* `ERC4626`: use the same `decimals()` as the underlying asset by default (if available). ([#3639](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3639))
* `ERC4626`: add internal `_initialConvertToShares` and `_initialConvertToAssets` functions to customize empty vaults behavior. ([#3639](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3639))
Expand Down
4 changes: 2 additions & 2 deletions contracts/governance/utils/Votes.sol
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ abstract contract Votes is IVotes, Context, EIP712 {
* - `blockNumber` must have been already mined
*/
function getPastVotes(address account, uint256 blockNumber) public view virtual override returns (uint256) {
return _delegateCheckpoints[account].getAtBlock(blockNumber);
return _delegateCheckpoints[account].getAtProbablyRecentBlock(blockNumber);
}

/**
Expand All @@ -72,7 +72,7 @@ abstract contract Votes is IVotes, Context, EIP712 {
*/
function getPastTotalSupply(uint256 blockNumber) public view virtual override returns (uint256) {
require(blockNumber < block.number, "Votes: block not yet mined");
return _totalCheckpoints.getAtBlock(blockNumber);
return _totalCheckpoints.getAtProbablyRecentBlock(blockNumber);
}

/**
Expand Down
12 changes: 2 additions & 10 deletions contracts/mocks/CheckpointsMock.sol
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ contract CheckpointsMock {
return _totalCheckpoints.getAtBlock(blockNumber);
}

function getAtRecentBlock(uint256 blockNumber) public view returns (uint256) {
return _totalCheckpoints.getAtRecentBlock(blockNumber);
function getAtProbablyRecentBlock(uint256 blockNumber) public view returns (uint256) {
return _totalCheckpoints.getAtProbablyRecentBlock(blockNumber);
}

function length() public view returns (uint256) {
Expand Down Expand Up @@ -52,10 +52,6 @@ contract Checkpoints224Mock {
return _totalCheckpoints.upperLookup(key);
}

function upperLookupRecent(uint32 key) public view returns (uint224) {
return _totalCheckpoints.upperLookupRecent(key);
}

function length() public view returns (uint256) {
return _totalCheckpoints._checkpoints.length;
}
Expand All @@ -82,10 +78,6 @@ contract Checkpoints160Mock {
return _totalCheckpoints.upperLookup(key);
}

function upperLookupRecent(uint96 key) public view returns (uint224) {
return _totalCheckpoints.upperLookupRecent(key);
}

function length() public view returns (uint256) {
return _totalCheckpoints._checkpoints.length;
}
Expand Down
35 changes: 29 additions & 6 deletions contracts/token/ERC20/extensions/ERC20Votes.sol
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ abstract contract ERC20Votes is IVotes, ERC20Permit {
function _checkpointsLookup(Checkpoint[] storage ckpts, uint256 blockNumber) private view returns (uint256) {
// We run a binary search to look for the earliest checkpoint taken after `blockNumber`.
//
// Initially we check if the block is recent to narrow the search range.
// During the loop, the index of the wanted checkpoint remains in the range [low-1, high).
// With each iteration, either `low` or `high` is moved towards the middle of the range to maintain the invariant.
// - If the middle checkpoint is after `blockNumber`, we look in [low, mid)
Expand All @@ -106,18 +107,30 @@ abstract contract ERC20Votes is IVotes, ERC20Permit {
// Note that if the latest checkpoint available is exactly for `blockNumber`, we end up with an index that is
// past the end of the array, so we technically don't find a checkpoint after `blockNumber`, but it works out
// the same.
uint256 high = ckpts.length;
uint256 length = ckpts.length;

uint256 low = 0;
uint256 high = length;

if (length > 5) {
uint256 mid = length - Math.sqrt(length);
if (_unsafeAccess(ckpts, mid).fromBlock > blockNumber) {
high = mid;
} else {
low = mid + 1;
}
}

while (low < high) {
uint256 mid = Math.average(low, high);
if (ckpts[mid].fromBlock > blockNumber) {
if (_unsafeAccess(ckpts, mid).fromBlock > blockNumber) {
high = mid;
} else {
low = mid + 1;
}
}

return high == 0 ? 0 : ckpts[high - 1].votes;
return high == 0 ? 0 : _unsafeAccess(ckpts, high - 1).votes;
}

/**
Expand Down Expand Up @@ -229,11 +242,14 @@ abstract contract ERC20Votes is IVotes, ERC20Permit {
uint256 delta
) private returns (uint256 oldWeight, uint256 newWeight) {
uint256 pos = ckpts.length;
oldWeight = pos == 0 ? 0 : ckpts[pos - 1].votes;

Checkpoint memory oldCkpt = pos == 0 ? Checkpoint(0, 0) : _unsafeAccess(ckpts, pos - 1);

oldWeight = oldCkpt.votes;
newWeight = op(oldWeight, delta);

if (pos > 0 && ckpts[pos - 1].fromBlock == block.number) {
ckpts[pos - 1].votes = SafeCast.toUint224(newWeight);
if (pos > 0 && oldCkpt.fromBlock == block.number) {
_unsafeAccess(ckpts, pos - 1).votes = SafeCast.toUint224(newWeight);
} else {
ckpts.push(Checkpoint({fromBlock: SafeCast.toUint32(block.number), votes: SafeCast.toUint224(newWeight)}));
}
Expand All @@ -246,4 +262,11 @@ abstract contract ERC20Votes is IVotes, ERC20Permit {
function _subtract(uint256 a, uint256 b) private pure returns (uint256) {
return a - b;
}

function _unsafeAccess(Checkpoint[] storage ckpts, uint256 pos) private view returns (Checkpoint storage result) {
assembly {
mstore(0, ckpts.slot)
result.slot := add(keccak256(0, 0x20), pos)
}
}
}
60 changes: 14 additions & 46 deletions contracts/utils/Checkpoints.sol
Original file line number Diff line number Diff line change
Expand Up @@ -49,22 +49,28 @@ library Checkpoints {

/**
* @dev Returns the value at a given block number. If a checkpoint is not available at that block, the closest one
* before it is returned, or zero otherwise. Similarly to {upperLookup} but optimized for the case when the search
* key is known to be recent.
* before it is returned, or zero otherwise. Similar to {upperLookup} but optimized for the case when the searched
* checkpoint is probably "recent", defined as being among the last sqrt(N) checkpoints where N is the number of
* checkpoints.
*/
function getAtRecentBlock(History storage self, uint256 blockNumber) internal view returns (uint256) {
function getAtProbablyRecentBlock(History storage self, uint256 blockNumber) internal view returns (uint256) {
require(blockNumber < block.number, "Checkpoints: block not yet mined");
uint32 key = SafeCast.toUint32(blockNumber);

uint256 length = self._checkpoints.length;
uint256 offset = 1;

while (offset <= length && _unsafeAccess(self._checkpoints, length - offset)._blockNumber > key) {
offset <<= 1;
uint256 low = 0;
uint256 high = length;

if (length > 5) {
uint256 mid = length - Math.sqrt(length);
if (key < _unsafeAccess(self._checkpoints, mid)._blockNumber) {
high = mid;
} else {
low = mid + 1;
}
}

uint256 low = offset < length ? length - offset : 0;
uint256 high = length - (offset >> 1);
uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high);

return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
Expand Down Expand Up @@ -225,25 +231,6 @@ library Checkpoints {
return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
}

/**
* @dev Returns the value in the most recent checkpoint with key lower or equal than the search key (similarly to
* {upperLookup}), optimized for the case when the search key is known to be recent.
*/
function upperLookupRecent(Trace224 storage self, uint32 key) internal view returns (uint224) {
uint256 length = self._checkpoints.length;
uint256 offset = 1;

while (offset <= length && _unsafeAccess(self._checkpoints, length - offset)._key > key) {
offset <<= 1;
}

uint256 low = 0 < offset && offset < length ? length - offset : 0;
uint256 high = length - (offset >> 1);
uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high);

return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
}

/**
* @dev Pushes a (`key`, `value`) pair into an ordered list of checkpoints, either by inserting a new checkpoint,
* or by updating the last one.
Expand Down Expand Up @@ -380,25 +367,6 @@ library Checkpoints {
return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
}

/**
* @dev Returns the value in the most recent checkpoint with key lower or equal than the search key (similarly to
* {upperLookup}), optimized for the case when the search key is known to be recent.
*/
function upperLookupRecent(Trace160 storage self, uint96 key) internal view returns (uint160) {
uint256 length = self._checkpoints.length;
uint256 offset = 1;

while (offset <= length && _unsafeAccess(self._checkpoints, length - offset)._key > key) {
offset <<= 1;
}

uint256 low = 0 < offset && offset < length ? length - offset : 0;
uint256 high = length - (offset >> 1);
uint256 pos = _upperBinaryLookup(self._checkpoints, key, low, high);

return pos == 0 ? 0 : _unsafeAccess(self._checkpoints, pos - 1)._value;
}

/**
* @dev Pushes a (`key`, `value`) pair into an ordered list of checkpoints, either by inserting a new checkpoint,
* or by updating the last one.
Expand Down
41 changes: 14 additions & 27 deletions scripts/generate/templates/Checkpoints.js
Original file line number Diff line number Diff line change
Expand Up @@ -70,25 +70,6 @@ function upperLookup(${opts.historyTypeName} storage self, ${opts.keyTypeName} k
uint256 pos = _upperBinaryLookup(self.${opts.checkpointFieldName}, key, 0, length);
return pos == 0 ? 0 : _unsafeAccess(self.${opts.checkpointFieldName}, pos - 1).${opts.valueFieldName};
}
/**
* @dev Returns the value in the most recent checkpoint with key lower or equal than the search key (similarly to
* {upperLookup}), optimized for the case when the search key is known to be recent.
*/
function upperLookupRecent(${opts.historyTypeName} storage self, ${opts.keyTypeName} key) internal view returns (${opts.valueTypeName}) {
uint256 length = self.${opts.checkpointFieldName}.length;
uint256 offset = 1;
while (offset <= length && _unsafeAccess(self.${opts.checkpointFieldName}, length - offset).${opts.keyFieldName} > key) {
offset <<= 1;
}
uint256 low = 0 < offset && offset < length ? length - offset : 0;
uint256 high = length - (offset >> 1);
uint256 pos = _upperBinaryLookup(self.${opts.checkpointFieldName}, key, low, high);
return pos == 0 ? 0 : _unsafeAccess(self.${opts.checkpointFieldName}, pos - 1).${opts.valueFieldName};
}
`;

const legacyOperations = opts => `\
Expand All @@ -115,22 +96,28 @@ function getAtBlock(${opts.historyTypeName} storage self, uint256 blockNumber) i
/**
* @dev Returns the value at a given block number. If a checkpoint is not available at that block, the closest one
* before it is returned, or zero otherwise. Similarly to {upperLookup} but optimized for the case when the search
* key is known to be recent.
* before it is returned, or zero otherwise. Similar to {upperLookup} but optimized for the case when the searched
* checkpoint is probably "recent", defined as being among the last sqrt(N) checkpoints where N is the number of
* checkpoints.
*/
function getAtRecentBlock(${opts.historyTypeName} storage self, uint256 blockNumber) internal view returns (uint256) {
function getAtProbablyRecentBlock(${opts.historyTypeName} storage self, uint256 blockNumber) internal view returns (uint256) {
require(blockNumber < block.number, "Checkpoints: block not yet mined");
uint32 key = SafeCast.toUint32(blockNumber);
uint256 length = self.${opts.checkpointFieldName}.length;
uint256 offset = 1;
while (offset <= length && _unsafeAccess(self.${opts.checkpointFieldName}, length - offset).${opts.keyFieldName} > key) {
offset <<= 1;
uint256 low = 0;
uint256 high = length;
if (length > 5) {
uint256 mid = length - Math.sqrt(length);
if (key < _unsafeAccess(self.${opts.checkpointFieldName}, mid)._blockNumber) {
high = mid;
} else {
low = mid + 1;
}
}
uint256 low = offset < length ? length - offset : 0;
uint256 high = length - (offset >> 1);
uint256 pos = _upperBinaryLookup(self.${opts.checkpointFieldName}, key, low, high);
return pos == 0 ? 0 : _unsafeAccess(self.${opts.checkpointFieldName}, pos - 1).${opts.valueFieldName};
Expand Down
8 changes: 2 additions & 6 deletions scripts/generate/templates/CheckpointsMock.js
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ contract CheckpointsMock {
return _totalCheckpoints.getAtBlock(blockNumber);
}
function getAtRecentBlock(uint256 blockNumber) public view returns (uint256) {
return _totalCheckpoints.getAtRecentBlock(blockNumber);
function getAtProbablyRecentBlock(uint256 blockNumber) public view returns (uint256) {
return _totalCheckpoints.getAtProbablyRecentBlock(blockNumber);
}
function length() public view returns (uint256) {
Expand Down Expand Up @@ -58,10 +58,6 @@ contract Checkpoints${length}Mock {
return _totalCheckpoints.upperLookup(key);
}
function upperLookupRecent(uint${256 - length} key) public view returns (uint224) {
return _totalCheckpoints.upperLookupRecent(key);
}
function length() public view returns (uint256) {
return _totalCheckpoints._checkpoints.length;
}
Expand Down
13 changes: 13 additions & 0 deletions test/token/ERC20/extensions/ERC20Votes.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,19 @@ contract('ERC20Votes', function (accounts) {
);
});

it('recent checkpoints', async function () {
await this.token.delegate(holder, { from: holder });
for (let i = 0; i < 6; i++) {
await this.token.mint(holder, 1);
}
const block = await web3.eth.getBlockNumber();
expect(await this.token.numCheckpoints(holder)).to.be.bignumber.equal('6');
// recent
expect(await this.token.getPastVotes(holder, block - 1)).to.be.bignumber.equal('5');
// non-recent
expect(await this.token.getPastVotes(holder, block - 6)).to.be.bignumber.equal('0');
});

describe('set delegation', function () {
describe('call', function () {
it('delegation with balance', async function () {
Expand Down
22 changes: 17 additions & 5 deletions test/utils/Checkpoints.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,10 @@ contract('Checkpoints', function (accounts) {

it('returns zero as past value', async function () {
await time.advanceBlock();
expect(await this.checkpoint.getAtBlock(await web3.eth.getBlockNumber() - 1)).to.be.bignumber.equal('0');
expect(await this.checkpoint.getAtRecentBlock(await web3.eth.getBlockNumber() - 1)).to.be.bignumber.equal('0');
expect(await this.checkpoint.getAtBlock(await web3.eth.getBlockNumber() - 1))
.to.be.bignumber.equal('0');
expect(await this.checkpoint.getAtProbablyRecentBlock(await web3.eth.getBlockNumber() - 1))
.to.be.bignumber.equal('0');
});
});

Expand All @@ -41,7 +43,7 @@ contract('Checkpoints', function (accounts) {
expect(await this.checkpoint.latest()).to.be.bignumber.equal('3');
});

for (const fn of [ 'getAtBlock(uint256)', 'getAtRecentBlock(uint256)' ]) {
for (const fn of [ 'getAtBlock(uint256)', 'getAtProbablyRecentBlock(uint256)' ]) {
describe(`lookup: ${fn}`, function () {
it('returns past values', async function () {
expect(await this.checkpoint.methods[fn](this.tx1.receipt.blockNumber - 1)).to.be.bignumber.equal('0');
Expand Down Expand Up @@ -78,6 +80,18 @@ contract('Checkpoints', function (accounts) {
expect(await this.checkpoint.length()).to.be.bignumber.equal(lengthBefore.addn(1));
expect(await this.checkpoint.latest()).to.be.bignumber.equal('10');
});

it('more than 5 checkpoints', async function () {
for (let i = 4; i <= 6; i++) {
await this.checkpoint.push(i);
}
expect(await this.checkpoint.length()).to.be.bignumber.equal('6');
const block = await web3.eth.getBlockNumber();
// recent
expect(await this.checkpoint.getAtProbablyRecentBlock(block - 1)).to.be.bignumber.equal('5');
// non-recent
expect(await this.checkpoint.getAtProbablyRecentBlock(block - 9)).to.be.bignumber.equal('0');
});
});
});

Expand All @@ -95,7 +109,6 @@ contract('Checkpoints', function (accounts) {
it('lookup returns 0', async function () {
expect(await this.contract.lowerLookup(0)).to.be.bignumber.equal('0');
expect(await this.contract.upperLookup(0)).to.be.bignumber.equal('0');
expect(await this.contract.upperLookupRecent(0)).to.be.bignumber.equal('0');
});
});

Expand Down Expand Up @@ -149,7 +162,6 @@ contract('Checkpoints', function (accounts) {
const value = last(this.checkpoints.filter(x => i >= x.key))?.value || '0';

expect(await this.contract.upperLookup(i)).to.be.bignumber.equal(value);
expect(await this.contract.upperLookupRecent(i)).to.be.bignumber.equal(value);
}
});
});
Expand Down

0 comments on commit b1bc18a

Please sign in to comment.