Skip to content

feat(hint): post state key hint [IDEA] #807

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 39 additions & 10 deletions src/ethereum_test_base_types/composite_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
Base composite types for Ethereum test cases.
"""
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, SupportsBytes, Type, TypeAlias
from typing import Any, ClassVar, Dict, Optional, SupportsBytes, Type, TypeAlias

from pydantic import Field, PrivateAttr, RootModel, TypeAdapter

Expand Down Expand Up @@ -92,22 +92,37 @@ class KeyValueMismatch(Exception):
key: int
want: int
got: int

def __init__(self, address: Address, key: int, want: int, got: int, *args):
hint: Optional[Dict[int, str]] = None

def __init__(
self,
address: Address,
key: int,
want: int,
got: int,
hint: Optional[Dict[int, str]] = None,
*args,
):
super().__init__(args)
self.address = address
self.key = key
self.want = want
self.got = got
self.hint = hint

def __str__(self):
"""Print exception string"""
label_str = ""
if self.address.label is not None:
label_str = f" ({self.address.label})"

key = Hash(self.key)
if self.hint is not None:
key = self.hint[self.key]
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to revert to default value of Hash(self.key) if value is not found


return (
f"incorrect value in address {self.address}{label_str} for "
+ f"key {Hash(self.key)}:"
+ f"key {key}:"
+ f" want {HexNumber(self.want)} (dec:{int(self.want)}),"
+ f" got {HexNumber(self.got)} (dec:{int(self.got)})"
)
Expand Down Expand Up @@ -233,7 +248,12 @@ def must_contain(self, address: Address, other: "Storage"):
address=address, key=key, want=self[key], got=other[key]
)

def must_be_equal(self, address: Address, other: "Storage | None"):
def must_be_equal(
self,
address: Address,
other: "Storage | None",
post_hint: Optional[Dict[int, str]] = None,
):
"""
Succeeds only if "self" is equal to "other" storage.
"""
Expand All @@ -243,17 +263,21 @@ def must_be_equal(self, address: Address, other: "Storage | None"):
for key in self.keys() & other.keys():
if self[key] != other[key]:
raise Storage.KeyValueMismatch(
address=address, key=key, want=self[key], got=other[key]
address=address, key=key, want=self[key], got=other[key], hint=post_hint
)

# Test keys contained in either one of the storage objects
for key in self.keys() ^ other.keys():
if key in self:
if self[key] != 0:
raise Storage.KeyValueMismatch(address=address, key=key, want=self[key], got=0)
raise Storage.KeyValueMismatch(
address=address, key=key, want=self[key], got=0, hint=post_hint
)

elif other[key] != 0:
raise Storage.KeyValueMismatch(address=address, key=key, want=0, got=other[key])
raise Storage.KeyValueMismatch(
address=address, key=key, want=0, got=other[key], hint=post_hint
)

def canary(self) -> "Storage":
"""
Expand Down Expand Up @@ -374,7 +398,12 @@ def __str__(self):
+ f"want {self.want}, got {self.got}"
)

def check_alloc(self: "Account", address: Address, account: "Account"):
def check_alloc(
self: "Account",
address: Address,
account: "Account",
post_hint: Optional[Dict[int, str]] = None,
):
"""
Checks the returned alloc against an expected account in post state.
Raises exception on failure.
Expand Down Expand Up @@ -404,7 +433,7 @@ def check_alloc(self: "Account", address: Address, account: "Account"):
)

if "storage" in self.model_fields_set:
self.storage.must_be_equal(address=address, other=account.storage)
self.storage.must_be_equal(address=address, other=account.storage, post_hint=post_hint)

def __bool__(self: "Account") -> bool:
"""
Expand Down
6 changes: 4 additions & 2 deletions src/ethereum_test_specs/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class StateTest(BaseTest):
pre: Alloc
post: Alloc
tx: Transaction
post_hint: Optional[Dict[int, str]] = None
engine_api_error_code: Optional[EngineAPIError] = None
blockchain_test_header_verify: Optional[Header] = None
blockchain_test_rlp_modifier: Optional[Header] = None
Expand Down Expand Up @@ -117,6 +118,7 @@ def make_state_test_fixture(
t8n: TransitionTool,
fork: Fork,
eips: Optional[List[int]] = None,
post_hint: Optional[Dict[int, str]] = None,
) -> Fixture:
"""
Create a fixture from the state test definition.
Expand Down Expand Up @@ -146,7 +148,7 @@ def make_state_test_fixture(
)

try:
self.post.verify_post_alloc(transition_tool_output.alloc)
self.post.verify_post_alloc(transition_tool_output.alloc, post_hint)
except Exception as e:
print_traces(t8n.get_traces())
raise e
Expand Down Expand Up @@ -183,7 +185,7 @@ def generate(
request=request, t8n=t8n, fork=fork, fixture_format=fixture_format, eips=eips
)
elif fixture_format == StateFixture:
return self.make_state_test_fixture(t8n, fork, eips)
return self.make_state_test_fixture(t8n, fork, eips, self.post_hint)

raise Exception(f"Unknown fixture format: {fixture_format}")

Expand Down
6 changes: 3 additions & 3 deletions src/ethereum_test_types/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from dataclasses import dataclass
from functools import cached_property
from typing import Any, ClassVar, Dict, Generic, List, Literal, Sequence, Tuple
from typing import Any, ClassVar, Dict, Generic, List, Literal, Optional, Sequence, Tuple

from coincurve.keys import PrivateKey, PublicKey
from ethereum import rlp as eth_rlp
Expand Down Expand Up @@ -268,7 +268,7 @@ def state_root(self) -> bytes:
)
return state_root(state)

def verify_post_alloc(self, got_alloc: "Alloc"):
def verify_post_alloc(self, got_alloc: "Alloc", post_hint: Optional[Dict[int, str]] = None):
"""
Verify that the allocation matches the expected post in the test.
Raises exception on unexpected values.
Expand All @@ -284,7 +284,7 @@ def verify_post_alloc(self, got_alloc: "Alloc"):
got_account = got_alloc.root[address]
assert isinstance(got_account, Account)
assert isinstance(account, Account)
account.check_alloc(address, got_account)
account.check_alloc(address, got_account, post_hint)
else:
raise Alloc.MissingAccount(address)

Expand Down