Skip to content

Commit c11231d

Browse files
author
Ervin T
authored
Get step from policy (#3223)
1 parent 0a7977c commit c11231d

File tree

2 files changed

+2
-0
lines changed

2 files changed

+2
-0
lines changed

ml-agents/mlagents/trainers/ppo/trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
258258
if not isinstance(policy, PPOPolicy):
259259
raise RuntimeError("Non-PPOPolicy passed to PPOTrainer.add_policy()")
260260
self.policy = policy
261+
self.step = policy.get_current_step()
261262

262263
def get_policy(self, name_behavior_id: str) -> TFPolicy:
263264
"""

ml-agents/mlagents/trainers/sac/trainer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ def add_policy(self, name_behavior_id: str, policy: TFPolicy) -> None:
340340
if not isinstance(policy, SACPolicy):
341341
raise RuntimeError("Non-SACPolicy passed to SACTrainer.add_policy()")
342342
self.policy = policy
343+
self.step = policy.get_current_step()
343344

344345
def get_policy(self, name_behavior_id: str) -> TFPolicy:
345346
"""

0 commit comments

Comments
 (0)