Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ignite/contrib/handlers/tqdm_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ def attach( # type: ignore[override]
raise ValueError(f"Logging event {event_name.name} is not in allowed events for this engine")

if isinstance(closing_event_name, CallableEventWithFilter):
if closing_event_name.filter != CallableEventWithFilter.default_event_filter:
if closing_event_name.filter is not None:
raise ValueError("Closing Event should not be a filtered event")

if not self._compare_lt(event_name, closing_event_name):
Expand Down
15 changes: 1 addition & 14 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,7 @@ def execute_something():
for e in event_name:
self.add_event_handler(e, handler, *args, **kwargs)
return RemovableEventHandle(event_name, handler, self)
if (
isinstance(event_name, CallableEventWithFilter)
and event_name.filter != CallableEventWithFilter.default_event_filter
):
if isinstance(event_name, CallableEventWithFilter) and event_name.filter is not None:
event_filter = event_name.filter
handler = self._handler_wrapper(handler, event_name, event_filter)

Expand All @@ -312,16 +309,6 @@ def execute_something():

return RemovableEventHandle(event_name, handler, self)

@staticmethod
def _assert_non_filtered_event(event_name: Any) -> None:
if (
isinstance(event_name, CallableEventWithFilter)
and event_name.filter != CallableEventWithFilter.default_event_filter
):
raise TypeError(
"Argument event_name should not be a filtered event, " "please use event without any event filtering"
)

def has_event_handler(self, handler: Callable, event_name: Optional[Any] = None) -> bool:
"""Check if the specified event has the specified handler.

Expand Down
13 changes: 8 additions & 5 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numbers
import warnings
import weakref
from enum import Enum
from types import DynamicClassAttribute
Expand Down Expand Up @@ -27,8 +28,6 @@ class CallableEventWithFilter:
"""

def __init__(self, value: str, event_filter: Optional[Callable] = None, name: Optional[str] = None) -> None:
if event_filter is None:
event_filter = CallableEventWithFilter.default_event_filter
self.filter = event_filter

if not hasattr(self, "_value_"):
Expand Down Expand Up @@ -117,11 +116,15 @@ def wrapper(engine: "Engine", event: int) -> bool:

@staticmethod
def default_event_filter(engine: "Engine", event: int) -> bool:
"""Default event filter."""
"""Default event filter. This method is is deprecated and will be removed. Please, use None instead"""
warnings.warn("Events.default_event_filter is deprecated and will be removed. Please, use None instead")
return True

def __str__(self) -> str:
return "<event=%s, filter=%r>" % (self.name, self.filter)
def __repr__(self) -> str:
out = f"Events.{self.name}"
if self.filter is not None:
out += f"(filter={self.filter})"
return out

def __eq__(self, other: Any) -> bool:
if isinstance(other, CallableEventWithFilter):
Expand Down
64 changes: 35 additions & 29 deletions tests/ignite/engine/test_custom_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,46 +148,52 @@ def test_callable_events_with_wrong_inputs():
with pytest.raises(ValueError, match=r"but will be called with"):
Events.ITERATION_STARTED(event_filter=lambda x: x)


def test_callable_events():

assert isinstance(Events.ITERATION_STARTED.value, str)

def foo(engine, event):
with pytest.warns(UserWarning, match=r"default_event_filter is deprecated and will be removed"):
Events.default_event_filter(None, None)


@pytest.mark.parametrize(
"event",
[
Events.ITERATION_STARTED,
Events.ITERATION_COMPLETED,
Events.EPOCH_STARTED,
Events.EPOCH_COMPLETED,
Events.GET_BATCH_STARTED,
Events.GET_BATCH_COMPLETED,
Events.STARTED,
Events.COMPLETED,
],
)
def test_callable_events(event):

assert isinstance(event.value, str)

def foo(engine, _):
return True

ret = Events.ITERATION_STARTED(event_filter=foo)
ret = event(event_filter=foo)
assert isinstance(ret, CallableEventWithFilter)
assert ret == Events.ITERATION_STARTED
assert ret == event
assert ret.filter == foo
assert isinstance(Events.ITERATION_STARTED.value, str)

# assert ret in Events
assert Events.ITERATION_STARTED.name in f"{ret}"
# assert ret in State.event_to_attr
assert event.name in f"{ret}"

ret = Events.ITERATION_STARTED(every=10)
ret = event(every=10)
assert isinstance(ret, CallableEventWithFilter)
assert ret == Events.ITERATION_STARTED
assert ret == event
assert ret.filter is not None
assert event.name in f"{ret}"

# assert ret in Events
assert Events.ITERATION_STARTED.name in f"{ret}"
# assert ret in State.event_to_attr

ret = Events.ITERATION_STARTED(once=10)
ret = event(once=10)
assert isinstance(ret, CallableEventWithFilter)
assert ret == Events.ITERATION_STARTED
assert ret == event
assert ret.filter is not None
assert event.name in f"{ret}"

# assert ret in Events
assert Events.ITERATION_STARTED.name in f"{ret}"
# assert ret in State.event_to_attr

def _attach(e1, e2):
assert id(e1) != id(e2)

_attach(Events.ITERATION_STARTED(every=10), Events.ITERATION_COMPLETED(every=10))
ret = event
assert isinstance(ret, CallableEventWithFilter)
assert ret.filter is None
assert event.name in f"{ret}"


def test_callable_events_every_eq_one():
Expand Down