Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions src/ethproto/w3wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
from environs import Env
from eth_account.account import Account, LocalAccount
from eth_account.signers.base import BaseAccount
from eth_utils.abi import event_abi_to_log_topic
from eth_utils import add_0x_prefix, event_abi_to_log_topic
from hexbytes import HexBytes
from web3.contract import Contract
from web3.contract.contract import ContractEvent
from web3.exceptions import ContractLogicError, ExtraDataLengthError
from web3.middleware import ExtraDataToPOAMiddleware

Expand Down Expand Up @@ -448,7 +449,9 @@ def deploy(self, eth_contract, init_params, from_, **kwargs):
kwargs["from"] = from_
return self.construct(factory, init_params, kwargs)

def get_events(self, eth_wrapper, event_name, filter_kwargs={}):
def get_events(
self, eth_wrapper, event_names: Union[list[Union[str, ContractEvent]], str], filter_kwargs=None
):
"""Returns a list of events given a filter, like this:

>>> provider.get_events(currencywrapper, "Transfer", dict(from_block=0))
Expand All @@ -468,12 +471,36 @@ def get_events(self, eth_wrapper, event_name, filter_kwargs={}):
'blockNumber': 23
})]
"""
contract = eth_wrapper.contract
event = getattr(contract.events, event_name)
if "from_block" not in filter_kwargs:
filter_kwargs["from_block"] = self.get_first_block(eth_wrapper)
event_filter = event.create_filter(**filter_kwargs)
return event_filter.get_all_entries()
if filter_kwargs is None:
filter_kwargs = {}

if isinstance(event_names, (str, ContractEvent)):
# Backwards compatibility, if we don't get a list we're getting a single event name/ref
event_names = [event_names]

topics = {}

for name in event_names:
if isinstance(name, str):
# We got a plain event name, let's get the event from the contract
event: ContractEvent = getattr(eth_wrapper.contract.events, name)
else:
# Assume we already got an event reference
event: ContractEvent = name

topics[event.topic] = event

filter_params = {
"fromBlock": filter_kwargs.get("from_block", self.get_first_block(eth_wrapper)),
"toBlock": filter_kwargs.get("to_block", "latest"),
"address": eth_wrapper.contract.address,
"topics": [list(topics.keys())],
}

logs = self.w3.eth.get_logs(filter_params)

parsed_events = [topics[add_0x_prefix(log["topics"][0].hex())].process_log(log) for log in logs]
return parsed_events

def init_eth_wrapper(self, eth_wrapper, owner, init_params, kwargs):
eth_wrapper.owner = self.address_book.get_account(owner)
Expand Down
30 changes: 30 additions & 0 deletions tests/test_w3.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,36 @@ def test_get_events():
assert event2[0].args.value == 2


def test_get_multiple_events():
provider = wrappers.get_provider("w3")
contract_def = provider.get_contract_def("EventLauncher")
wrapper = wrappers.ETHWrapper.build_from_def(contract_def)

launcher = wrapper(owner="owner")

launcher.launchEvent1(1)
launcher.launchEvent2(2)
launcher.launchEvent1(3)

all_events = provider.get_events(launcher, ["Event1", "Event2"], dict(from_block=0))
assert len(all_events) == 3

all_events_by_reference = provider.get_events(
launcher, [launcher.contract.events.Event1, launcher.contract.events.Event2], dict(from_block=0)
)
assert len(all_events_by_reference) == 3

all_events_mixed = provider.get_events(
launcher, ["Event1", launcher.contract.events.Event2], dict(from_block=0)
)
assert len(all_events_mixed) == 3

single_event_by_reference = provider.get_events(
launcher, launcher.contract.events.Event1, dict(from_block=0)
)
assert len(single_event_by_reference) == 2


@pytest.fixture
def sign_and_send(mocker, hardhat_node):
"""Sets up sign-and-send transact mode with a well-known address, returns the address"""
Expand Down