diff --git a/chainerrl/agents/reinforce.py b/chainerrl/agents/reinforce.py index 5c6f07467..cfc7f7ccc 100644 --- a/chainerrl/agents/reinforce.py +++ b/chainerrl/agents/reinforce.py @@ -140,7 +140,7 @@ def stop_episode_and_train(self, obs, reward, done=False): def accumulate_grad(self): if self.n_backward == 0: - self.model.zerograds() + self.model.cleargrads() # Compute losses losses = [] for r_seq, log_prob_seq, ent_seq in zip(self.reward_sequences, @@ -168,7 +168,7 @@ def batch_update(self): assert len(self.log_prob_sequences) == self.batchsize assert len(self.entropy_sequences) == self.batchsize # Update the model - self.model.zerograds() + assert self.n_backward == 0 self.accumulate_grad() self.optimizer.update() self.n_backward = 0