-
Notifications
You must be signed in to change notification settings - Fork 4.3k
[refactor] Refactor normalizers and encoders #4275
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,28 +1,19 @@ | ||
import torch | ||
from torch import nn | ||
from typing import Tuple, Optional | ||
|
||
from mlagents.trainers.exception import UnityTrainerException | ||
|
||
class VectorEncoder(nn.Module):
    """Feed-forward vector encoder.

    Stacks one input projection (input_size -> hidden_size) followed by
    (num_layers - 1) hidden stages of Linear + ReLU, and applies them as a
    single nn.Sequential.
    """

    def __init__(self, input_size, hidden_size, num_layers, **kwargs):
        super().__init__(**kwargs)
        # First stage maps the raw observation to the hidden width.
        modules = [nn.Linear(input_size, hidden_size)]
        # Each remaining stage is a hidden Linear followed by a ReLU.
        for _stage in range(1, num_layers):
            modules.extend([nn.Linear(hidden_size, hidden_size), nn.ReLU()])
        self.layers = modules
        self.seq_layers = nn.Sequential(*self.layers)

    def forward(self, inputs):
        """Encode `inputs` by running them through the stacked layers."""
        return self.seq_layers(inputs)
import torch | ||
from torch import nn | ||
|
||
|
||
class Normalizer(nn.Module): | ||
def __init__(self, vec_obs_size, **kwargs): | ||
super().__init__(**kwargs) | ||
def __init__(self, vec_obs_size: int): | ||
super().__init__() | ||
self.normalization_steps = torch.tensor(1) | ||
self.running_mean = torch.zeros(vec_obs_size) | ||
self.running_variance = torch.ones(vec_obs_size) | ||
|
||
def forward(self, inputs): | ||
def forward(self, inputs: torch.Tensor) -> torch.Tensor: | ||
normalized_state = torch.clamp( | ||
(inputs - self.running_mean) | ||
/ torch.sqrt(self.running_variance / self.normalization_steps), | ||
|
@@ -31,7 +22,7 @@ def forward(self, inputs): | |
) | ||
return normalized_state | ||
|
||
def update(self, vector_input): | ||
def update(self, vector_input: torch.Tensor) -> None: | ||
steps_increment = vector_input.size()[0] | ||
total_new_steps = self.normalization_steps + steps_increment | ||
|
||
|
@@ -66,14 +57,96 @@ def conv_output_shape(h_w, kernel_size=1, stride=1, pad=0, dilation=1): | |
return h, w | ||
|
||
|
||
def pool_out_shape(
    h_w: Tuple[int, int], kernel_size: int, stride: int = 2
) -> Tuple[int, int]:
    """
    Compute the (height, width) of the output of a pooling layer with no
    padding and no dilation.
    :param h_w: (height, width) of the input feature map.
    :param kernel_size: size of the square pooling kernel.
    :param stride: pooling stride; defaults to 2, preserving the previously
        hard-coded behavior for existing callers.
    :return: (height, width) of the pooled output.
    """
    height = (h_w[0] - kernel_size) // stride + 1
    width = (h_w[1] - kernel_size) // stride + 1
    return height, width
|
||
|
||
class VectorEncoder(nn.Module):
    """
    Encodes a vector observation through a stack of fully-connected layers,
    optionally normalizing the input first with a running mean/variance
    Normalizer.
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        num_layers: int,
        normalize: bool = False,
    ):
        # Declared before super().__init__() so the attribute always exists;
        # a plain None is stored as a regular attribute by nn.Module.
        self.normalizer: Optional[Normalizer] = None
        super().__init__()
        self.layers = [nn.Linear(input_size, hidden_size)]
        if normalize:
            self.normalizer = Normalizer(input_size)

        for _ in range(num_layers - 1):
            self.layers.append(nn.Linear(hidden_size, hidden_size))
            self.layers.append(nn.ReLU())
        self.seq_layers = nn.Sequential(*self.layers)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Return annotation fixed: this returns the encoded tensor, not None.
        if self.normalizer is not None:
            inputs = self.normalizer(inputs)
        return self.seq_layers(inputs)

    def copy_normalization(self, other_encoder: "VectorEncoder") -> None:
        """
        Copy normalization statistics from another encoder.
        No-op if either encoder was built with normalize=False.
        """
        if self.normalizer is not None and other_encoder.normalizer is not None:
            self.normalizer.copy_from(other_encoder.normalizer)

    def update_normalization(self, inputs: torch.Tensor) -> None:
        """
        Update the running normalization statistics with a batch of inputs.
        No-op if this encoder was built with normalize=False.
        """
        if self.normalizer is not None:
            self.normalizer.update(inputs)
|
||
|
||
class VectorAndUnnormalizedInputEncoder(VectorEncoder):
    """
    Encoder for concatenated vector input (can be normalized) and unnormalized vector input.
    This is used for passing inputs to the network that should not be normalized, such as
    actions in the case of a Q function or task parameterizations. It will result in an encoder with
    this structure:
     ____________       ____________       ____________
    | Vector     |     | Normalize  |     | Fully      |
    |            | --> |            | --> | Connected  |     ___________
    |____________|     |____________|     |            |    | Output    |
     ____________                         |            | -> |           |
    |Unnormalized|                        |            |    |___________|
    | Input      | ---------------------> |            |
    |____________|                        |____________|
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        unnormalized_input_size: int,
        num_layers: int,
        normalize: bool = False,
    ):
        # The fully-connected stack sees the concatenation of both inputs,
        # but normalization must apply only to the first (vector) part, so
        # the parent is built without a normalizer...
        super().__init__(
            input_size + unnormalized_input_size,
            hidden_size,
            num_layers,
            normalize=False,
        )
        # ...and one sized to the normalized portion is attached here.
        if normalize:
            self.normalizer = Normalizer(input_size)
        else:
            self.normalizer = None

    def forward(  # pylint: disable=W0221
        self, inputs: torch.Tensor, unnormalized_inputs: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """
        Normalize `inputs` (if enabled), concatenate them with
        `unnormalized_inputs`, and encode through the shared layer stack.
        :raises UnityTrainerException: if `unnormalized_inputs` is omitted.
        """
        if unnormalized_inputs is None:
            raise UnityTrainerException(
                "Attempted to call a VectorAndUnnormalizedInputEncoder without an unnormalized input."
            )  # The Optional default only exists to satisfy mypy's override check.
        if self.normalizer is not None:
            inputs = self.normalizer(inputs)
        return self.seq_layers(torch.cat([inputs, unnormalized_inputs], dim=-1))
|
||
|
||
class SimpleVisualEncoder(nn.Module): | ||
def __init__(self, height, width, initial_channels, output_size): | ||
def __init__( | ||
self, height: int, width: int, initial_channels: int, output_size: int | ||
): | ||
super().__init__() | ||
self.h_size = output_size | ||
conv_1_hw = conv_output_shape((height, width), 8, 4) | ||
|
@@ -84,7 +157,7 @@ def __init__(self, height, width, initial_channels, output_size): | |
self.conv2 = nn.Conv2d(16, 32, [4, 4], [2, 2]) | ||
self.dense = nn.Linear(self.final_flat, self.h_size) | ||
|
||
def forward(self, visual_obs): | ||
def forward(self, visual_obs: torch.Tensor) -> None: | ||
conv_1 = torch.relu(self.conv1(visual_obs)) | ||
conv_2 = torch.relu(self.conv2(conv_1)) | ||
# hidden = torch.relu(self.dense(conv_2.view([-1, self.final_flat]))) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we add `pool_out_shape` and `conv_output_shape` to the torch utils file?