Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion ignite/engine/deterministic.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def _setup_engine(self) -> None:
# Below we define initial counter value for _run_once_on_dataset to measure a single epoch
if self.state.epoch_length is not None:
iteration %= self.state.epoch_length
self._init_iter.append(iteration)
self._init_iter = iteration

# restore rng state if in the middle
in_the_middle = self.state.iteration % self._dataloader_len > 0 if self._dataloader_len is not None else False
Expand Down
16 changes: 6 additions & 10 deletions ignite/engine/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def __init__(self, process_function: Callable[["Engine", Any], Any]):
self._allowed_events = [] # type: List[EventEnum]

self._dataloader_iter = None # type: Optional[Iterator[Any]]
self._init_iter = [] # type: List[int]
self._init_iter = None # type: Optional[int]

self.register_events(*Events)

Expand Down Expand Up @@ -723,7 +723,7 @@ def switch_batch(engine):
if self.should_terminate:
# If engine was terminated and now is resuming from terminated state
# we need to initialize iter_counter as 0
self._init_iter.append(0)
self._init_iter = 0

self.state.dataloader = data
return self._internal_run()
Expand Down Expand Up @@ -756,12 +756,12 @@ def _setup_dataloader_iter(self) -> None:
def _setup_engine(self) -> None:
self._setup_dataloader_iter()

if len(self._init_iter) == 0:
if self._init_iter is None:
iteration = self.state.iteration
# Below we define initial counter value for _run_once_on_dataset to measure a single epoch
if self.state.epoch_length is not None:
iteration %= self.state.epoch_length
self._init_iter.append(iteration)
self._init_iter = iteration

def _internal_run(self) -> State:
self.should_terminate = self.should_terminate_single_epoch = False
Expand Down Expand Up @@ -832,12 +832,8 @@ def _run_once_on_dataset(self) -> float:
start_time = time.time()

# We need to setup iter_counter > 0 if we resume from an iteration
if len(self._init_iter) > 1:
raise RuntimeError(
"Internal error, len(self._init_iter) should 0 or 1, "
f"but got: {len(self._init_iter)}, {self._init_iter}"
)
iter_counter = self._init_iter.pop() if len(self._init_iter) > 0 else 0
iter_counter = 0 if self._init_iter is None else self._init_iter
self._init_iter = None
should_exit = False
try:
if self._dataloader_iter is None:
Expand Down
6 changes: 3 additions & 3 deletions ignite/engine/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,9 +280,9 @@ class CustomEvents(EventEnum):
"""Event attribute indicating epoch is ended."""

STARTED = "started"
"""triggered when engines run is started."""
"""triggered when engine's run is started."""
COMPLETED = "completed"
""""triggered when engines run is completed"""
"""triggered when engine's run is completed"""

ITERATION_STARTED = "iteration_started"
"""triggered when an iteration is started."""
Expand All @@ -297,7 +297,7 @@ class CustomEvents(EventEnum):
"""triggered after the batch is fetched."""

DATALOADER_STOP_ITERATION = "dataloader_stop_iteration"
""""engines specific event triggered when dataloader has no more data to provide"""
"""engine's specific event triggered when dataloader has no more data to provide"""
TERMINATE = "terminate"
"""triggered when the run is about to end completely, after receiving terminate() call."""
TERMINATE_SINGLE_EPOCH = "terminate_single_epoch"
Expand Down
2 changes: 1 addition & 1 deletion tests/ignite/engine/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ def test__setup_engine():
data = list(range(100))
engine.state.dataloader = data
engine._setup_engine()
assert len(engine._init_iter) == 1 and engine._init_iter[0] == 10
assert engine._init_iter == 10


def test_run_asserts():
Expand Down