diff --git a/chainerrl/replay_buffer.py b/chainerrl/replay_buffer.py index ffc9b68b5..3656e4cd3 100644 --- a/chainerrl/replay_buffer.py +++ b/chainerrl/replay_buffer.py @@ -10,6 +10,7 @@ from abc import ABCMeta from abc import abstractmethod from abc import abstractproperty +import collections import numpy as np import six.moves.cPickle as pickle @@ -153,6 +154,10 @@ def save(self, filename): def load(self, filename): with open(filename, 'rb') as f: self.memory = pickle.load(f) + if isinstance(self.memory, collections.deque): + # Load v0.2 + self.memory = RandomAccessQueue( + self.memory, maxlen=self.memory.maxlen) def stop_current_episode(self): pass @@ -281,7 +286,23 @@ def save(self, filename): def load(self, filename): with open(filename, 'rb') as f: - self.memory, self.episodic_memory = pickle.load(f) + memory = pickle.load(f) + if isinstance(memory, tuple): + self.memory, self.episodic_memory = memory + else: + # Load v0.2 + # FIXME: The code works with EpisodicReplayBuffer + # but not with PrioritizedEpisodicReplayBuffer + self.memory = RandomAccessQueue(memory) + self.episodic_memory = RandomAccessQueue() + + # Recover episodic_memory with best effort. + episode = [] + for item in self.memory: + episode.append(item) + if item['is_state_terminal']: + self.episodic_memory.append(episode) + episode = [] def stop_current_episode(self): if self.current_episode: