@@ -119,9 +119,7 @@ def select_action(self, batch: dict[str, Tensor]) -> Tensor:
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch["observation.images"] = [batch[key] for key in self.config.image_features]

         # If we are doing temporal ensembling, do online updates where we keep track of the number of actions
         # we are ensembling over.
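Why this change matters: `torch.stack(..., dim=-4)` requires every camera tensor to have identical shapes, while a plain list lets each camera keep its own resolution. A minimal sketch of the difference, using hypothetical camera keys and shapes (not taken from the diff):

```python
import torch

# Hypothetical config and batch; keys and shapes are illustrative only.
image_features = ["observation.images.top", "observation.images.wrist"]
batch = {
    "observation.images.top": torch.randn(8, 3, 480, 640),
    "observation.images.wrist": torch.randn(8, 3, 240, 320),
}

# Old behavior: torch.stack requires equal sizes, so mixed resolutions fail.
try:
    torch.stack([batch[k] for k in image_features], dim=-4)
except RuntimeError as e:
    print(f"stack failed: {e}")

# New behavior: a list keeps each camera's tensor (and resolution) as-is.
batch["observation.images"] = [batch[k] for k in image_features]
```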
@@ -149,9 +147,8 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
         batch = self.normalize_inputs(batch)
         if self.config.image_features:
             batch = dict(batch)  # shallow copy so that adding a key doesn't modify the original
-            batch["observation.images"] = torch.stack(
-                [batch[key] for key in self.config.image_features], dim=-4
-            )
+            batch["observation.images"] = [batch[key] for key in self.config.image_features]
+
         batch = self.normalize_targets(batch)
         actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)

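A side note on the `dict(batch)` shallow copy kept as context in both hunks: it creates a new dict sharing the same tensor objects, so adding the `"observation.images"` key never mutates the caller's batch. A quick sketch of that behavior:

```python
import torch

original = {"observation.state": torch.randn(8, 14)}
batch = dict(original)  # new dict, same underlying tensors
batch["observation.images"] = [torch.randn(8, 3, 96, 96)]
assert "observation.images" not in original  # caller's dict is untouched
```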
@@ -413,11 +410,10 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
                 "actions must be provided when using the variational objective in training mode."
             )

-        batch_size = (
-            batch["observation.images"]
-            if "observation.images" in batch
-            else batch["observation.environment_state"]
-        ).shape[0]
+        if "observation.images" in batch:
+            batch_size = batch["observation.images"][0].shape[0]
+        else:
+            batch_size = batch["observation.environment_state"].shape[0]

         # Prepare the latent for input to the transformer encoder.
         if self.config.use_vae and "action" in batch:
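Since `batch["observation.images"]` is now a Python list rather than a stacked tensor, it has no `.shape`; the batch size must come from the first camera tensor, which is exactly what the rewritten branch does. A sketch under the new layout (shapes are illustrative):

```python
import torch

batch = {
    "observation.images": [
        torch.randn(8, 3, 480, 640),  # camera 0
        torch.randn(8, 3, 240, 320),  # camera 1
    ]
}

if "observation.images" in batch:
    batch_size = batch["observation.images"][0].shape[0]
else:
    batch_size = batch["observation.environment_state"].shape[0]

assert batch_size == 8
```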
@@ -490,20 +486,21 @@ def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, tuple[Tensor, Tensor] | tuple[None, None]]:
             all_cam_features = []
             all_cam_pos_embeds = []

-            for cam_index in range(batch["observation.images"].shape[-4]):
-                cam_features = self.backbone(batch["observation.images"][:, cam_index])["feature_map"]
-                # TODO(rcadene, alexander-soare): remove call to `.to` to speedup forward ; precompute and use
-                # buffer
+            # For a list of images, the H and W may vary but H*W is constant.
+            for img in batch["observation.images"]:
+                cam_features = self.backbone(img)["feature_map"]
                 cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
-                cam_features = self.encoder_img_feat_input_proj(cam_features)  # (B, C, h, w)
+                cam_features = self.encoder_img_feat_input_proj(cam_features)
+
+                # Rearrange features to (sequence, batch, dim).
+                cam_features = einops.rearrange(cam_features, "b c h w -> (h w) b c")
+                cam_pos_embed = einops.rearrange(cam_pos_embed, "b c h w -> (h w) b c")
+
                 all_cam_features.append(cam_features)
                 all_cam_pos_embeds.append(cam_pos_embed)
-            # Concatenate camera observation feature maps and positional embeddings along the width dimension,
-            # and move to (sequence, batch, dim).
-            all_cam_features = torch.cat(all_cam_features, axis=-1)
-            encoder_in_tokens.extend(einops.rearrange(all_cam_features, "b c h w -> (h w) b c"))
-            all_cam_pos_embeds = torch.cat(all_cam_pos_embeds, axis=-1)
-            encoder_in_pos_embed.extend(einops.rearrange(all_cam_pos_embeds, "b c h w -> (h w) b c"))
+
+            encoder_in_tokens.extend(torch.cat(all_cam_features, axis=0))
+            encoder_in_pos_embed.extend(torch.cat(all_cam_pos_embeds, axis=0))

         # Stack all tokens along the sequence dimension.
         encoder_in_tokens = torch.stack(encoder_in_tokens, axis=0)
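The old code concatenated raw feature maps along the width axis, which only works if every camera's feature map has the same height. The new code rearranges each camera's features to a `(h*w, batch, dim)` token sequence first and concatenates along the sequence axis, so cameras only need to agree on batch size and channel count. A standalone sketch of the new token path, with illustrative feature shapes (same `h*w`, different `(h, w)`, matching the diff's comment):

```python
import einops
import torch

# Hypothetical backbone outputs for two cameras: different (h, w), same h*w.
feat_top = torch.randn(8, 256, 15, 20)    # (B, C, h, w) -> 300 tokens
feat_wrist = torch.randn(8, 256, 10, 30)  # (B, C, h, w) -> 300 tokens

all_cam_features = [
    einops.rearrange(f, "b c h w -> (h w) b c") for f in (feat_top, feat_wrist)
]

# Concatenate along the sequence dimension: (300 + 300, B, C).
tokens = torch.cat(all_cam_features, dim=0)
print(tokens.shape)  # torch.Size([600, 8, 256])
```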