Skip to content
156 changes: 101 additions & 55 deletions chainerrl/q_functions/state_action_q_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,94 +34,112 @@ def __call__(self, x, a):
return h


class FCSAQFunction(chainer.ChainList, StateActionQFunction):
"""Fully-connected (s,a)-input continuous Q-function.
class FCSAQFunction(MLP, StateActionQFunction):
"""Fully-connected (s,a)-input Q-function.

Args:
n_dim_obs: number of dimensions of observation space
n_dim_action: number of dimensions of action space
n_hidden_channels: number of hidden channels
n_hidden_layers: number of hidden layers
n_dim_obs (int): Number of dimensions of observation space.
n_dim_action (int): Number of dimensions of action space.
n_hidden_channels (int): Number of hidden channels.
n_hidden_layers (int): Number of hidden layers.
nonlinearity (callable): Nonlinearity between layers. It must accept a
Variable as an argument and return a Variable with the same shape.
Nonlinearities with learnable parameters such as PReLU are not
supported. It is not used if n_hidden_layers is zero.
last_wscale (float): Scale of weight initialization of the last layer.
"""

def __init__(self, n_dim_obs, n_dim_action, n_hidden_channels,
n_hidden_layers, nonlinearity=F.relu,
last_wscale=1):
last_wscale=1.):
self.n_input_channels = n_dim_obs + n_dim_action
self.n_hidden_layers = n_hidden_layers
self.n_hidden_channels = n_hidden_channels
self.nonlinearity = nonlinearity

layers = []
assert self.n_hidden_layers >= 1
layers.append(
L.Linear(self.n_input_channels, self.n_hidden_channels))
for i in range(self.n_hidden_layers - 1):
layers.append(
L.Linear(self.n_hidden_channels, self.n_hidden_channels))
layers.append(L.Linear(self.n_hidden_channels, 1,
initialW=LeCunNormal(last_wscale)))
super().__init__(*layers)
self.output = layers[-1]
super().__init__(
in_size=self.n_input_channels,
out_size=1,
hidden_sizes=[self.n_hidden_channels] * self.n_hidden_layers,
nonlinearity=nonlinearity,
last_wscale=last_wscale,
)

def __call__(self, state, action):
h = F.concat((state, action), axis=1)
for layer in self[:-1]:
h = self.nonlinearity(layer(h))
h = self[-1](h)
return h
return super().__call__(h)


class FCLSTMSAQFunction(chainer.Chain, StateActionQFunction,
RecurrentChainMixin):
"""Fully-connected (s,a)-input continuous Q-function.
"""Fully-connected + LSTM (s,a)-input Q-function.

Args:
n_dim_obs: number of dimensions of observation space
n_dim_action: number of dimensions of action space
n_hidden_channels: number of hidden channels
n_hidden_layers: number of hidden layers
n_dim_obs (int): Number of dimensions of observation space.
n_dim_action (int): Number of dimensions of action space.
n_hidden_channels (int): Number of hidden channels.
n_hidden_layers (int): Number of hidden layers.
nonlinearity (callable): Nonlinearity between layers. It must accept a
Variable as an argument and return a Variable with the same shape.
Nonlinearities with learnable parameters such as PReLU are not
supported.
last_wscale (float): Scale of weight initialization of the last layer.
"""

def __init__(self, n_dim_obs, n_dim_action, n_hidden_channels,
n_hidden_layers):
n_hidden_layers, nonlinearity=F.relu, last_wscale=1.):
self.n_input_channels = n_dim_obs + n_dim_action
self.n_hidden_layers = n_hidden_layers
self.n_hidden_channels = n_hidden_channels
self.nonlinearity = nonlinearity
super().__init__()
with self.init_scope():
self.fc = MLP(self.n_input_channels, n_hidden_channels,
[self.n_hidden_channels] * self.n_hidden_layers)
[self.n_hidden_channels] * self.n_hidden_layers,
nonlinearity=nonlinearity,
)
self.lstm = L.LSTM(n_hidden_channels, n_hidden_channels)
self.out = L.Linear(n_hidden_channels, 1)
self.out = L.Linear(n_hidden_channels, 1,
initialW=LeCunNormal(last_wscale))

def __call__(self, x, a):
h = F.concat((x, a), axis=1)
h = F.relu(self.fc(h))
h = self.nonlinearity(self.fc(h))
h = self.lstm(h)
return self.out(h)


class FCBNSAQFunction(MLPBN, StateActionQFunction):
"""Fully-connected (s,a)-input continuous Q-function.
"""Fully-connected + BN (s,a)-input Q-function.

Args:
n_dim_obs: number of dimensions of observation space
n_dim_action: number of dimensions of action space
n_hidden_channels: number of hidden channels
n_hidden_layers: number of hidden layers
n_dim_obs (int): Number of dimensions of observation space.
n_dim_action (int): Number of dimensions of action space.
n_hidden_channels (int): Number of hidden channels.
n_hidden_layers (int): Number of hidden layers.
normalize_input (bool): If set to True, Batch Normalization is applied
to both observations and actions.
nonlinearity (callable): Nonlinearity between layers. It must accept a
Variable as an argument and return a Variable with the same shape.
Nonlinearities with learnable parameters such as PReLU are not
supported. It is not used if n_hidden_layers is zero.
last_wscale (float): Scale of weight initialization of the last layer.
"""

def __init__(self, n_dim_obs, n_dim_action, n_hidden_channels,
n_hidden_layers, normalize_input=True):
n_hidden_layers, normalize_input=True,
nonlinearity=F.relu, last_wscale=1.):
self.n_input_channels = n_dim_obs + n_dim_action
self.n_hidden_layers = n_hidden_layers
self.n_hidden_channels = n_hidden_channels
self.normalize_input = normalize_input
self.nonlinearity = nonlinearity
super().__init__(
in_size=self.n_input_channels, out_size=1,
hidden_sizes=[self.n_hidden_channels] * self.n_hidden_layers,
normalize_input=self.normalize_input)
normalize_input=self.normalize_input,
nonlinearity=nonlinearity,
last_wscale=last_wscale,
)

def __call__(self, state, action):
h = F.concat((state, action), axis=1)
Expand All @@ -130,78 +148,106 @@ def __call__(self, state, action):

class FCBNLateActionSAQFunction(chainer.Chain, StateActionQFunction,
RecurrentChainMixin):
"""Fully-connected (s,a)-input continuous Q-function.
"""Fully-connected + BN (s,a)-input Q-function with late action input.

Actions are not included until the second hidden layer and not normalized.
This architecture is used in the DDPG paper:
http://arxiv.org/abs/1509.02971

Args:
n_dim_obs: number of dimensions of observation space
n_dim_action: number of dimensions of action space
n_hidden_channels: number of hidden channels
n_hidden_layers: number of hidden layers
n_dim_obs (int): Number of dimensions of observation space.
n_dim_action (int): Number of dimensions of action space.
n_hidden_channels (int): Number of hidden channels.
n_hidden_layers (int): Number of hidden layers. It must be greater than
or equal to 1.
normalize_input (bool): If set to True, Batch Normalization is applied
to observations. Actions are never normalized.
nonlinearity (callable): Nonlinearity between layers. It must accept a
Variable as an argument and return a Variable with the same shape.
Nonlinearities with learnable parameters such as PReLU are not
supported.
last_wscale (float): Scale of weight initialization of the last layer.
"""

def __init__(self, n_dim_obs, n_dim_action, n_hidden_channels,
n_hidden_layers, normalize_input=True):
n_hidden_layers, normalize_input=True,
nonlinearity=F.relu, last_wscale=1.):
assert n_hidden_layers >= 1
self.n_input_channels = n_dim_obs + n_dim_action
self.n_hidden_layers = n_hidden_layers
self.n_hidden_channels = n_hidden_channels
self.normalize_input = normalize_input
self.nonlinearity = nonlinearity

super().__init__()
with self.init_scope():
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment that nonlinearity does not need to be passed to MLPBN because hidden_sizes is empty?

# No need to pass nonlinearity to obs_mlp because it has no
# hidden layers
self.obs_mlp = MLPBN(in_size=n_dim_obs, out_size=n_hidden_channels,
hidden_sizes=[],
normalize_input=normalize_input,
normalize_output=True)
self.mlp = MLP(in_size=n_hidden_channels + n_dim_action,
out_size=1,
hidden_sizes=([self.n_hidden_channels] *
(self.n_hidden_layers - 1)))
(self.n_hidden_layers - 1)),
nonlinearity=nonlinearity,
last_wscale=last_wscale,
)

self.output = self.mlp.output

def __call__(self, state, action):
h = F.relu(self.obs_mlp(state))
h = self.nonlinearity(self.obs_mlp(state))
h = F.concat((h, action), axis=1)
return self.mlp(h)


class FCLateActionSAQFunction(chainer.Chain, StateActionQFunction,
RecurrentChainMixin):
"""Fully-connected (s,a)-input continuous Q-function.
"""Fully-connected (s,a)-input Q-function with late action input.

Actions are not included until the second hidden layer and not normalized.
This architecture is used in the DDPG paper:
http://arxiv.org/abs/1509.02971

Args:
n_dim_obs: number of dimensions of observation space
n_dim_action: number of dimensions of action space
n_hidden_channels: number of hidden channels
n_hidden_layers: number of hidden layers
n_dim_obs (int): Number of dimensions of observation space.
n_dim_action (int): Number of dimensions of action space.
n_hidden_channels (int): Number of hidden channels.
n_hidden_layers (int): Number of hidden layers. It must be greater than
or equal to 1.
nonlinearity (callable): Nonlinearity between layers. It must accept a
Variable as an argument and return a Variable with the same shape.
Nonlinearities with learnable parameters such as PReLU are not
supported.
last_wscale (float): Scale of weight initialization of the last layer.
"""

def __init__(self, n_dim_obs, n_dim_action, n_hidden_channels,
n_hidden_layers):
n_hidden_layers, nonlinearity=F.relu, last_wscale=1.):
assert n_hidden_layers >= 1
self.n_input_channels = n_dim_obs + n_dim_action
self.n_hidden_layers = n_hidden_layers
self.n_hidden_channels = n_hidden_channels
self.nonlinearity = nonlinearity

super().__init__()
with self.init_scope():
# No need to pass nonlinearity to obs_mlp because it has no
# hidden layers
self.obs_mlp = MLP(in_size=n_dim_obs, out_size=n_hidden_channels,
hidden_sizes=[])
self.mlp = MLP(in_size=n_hidden_channels + n_dim_action,
out_size=1,
hidden_sizes=([self.n_hidden_channels] *
(self.n_hidden_layers - 1)))
(self.n_hidden_layers - 1)),
nonlinearity=nonlinearity,
last_wscale=last_wscale,
)

self.output = self.mlp.output

def __call__(self, state, action):
h = F.relu(self.obs_mlp(state))
h = self.nonlinearity(self.obs_mlp(state))
h = F.concat((h, action), axis=1)
return self.mlp(h)
Loading