@@ -1,9 +1,9 @@
 from typing import Dict, cast, List, Tuple, Optional
+from collections import defaultdict
 from mlagents.trainers.torch.components.reward_providers.extrinsic_reward_provider import (
     ExtrinsicRewardProvider,
 )
 import numpy as np
-import math
 from mlagents.torch_utils import torch, default_device
 
 from mlagents.trainers.buffer import (
@@ -381,116 +381,109 @@ def _evaluate_by_sequence_team(
         num_experiences = self_obs[0].shape[0]
         all_next_value_mem = AgentBufferField()
         all_next_baseline_mem = AgentBufferField()
-        # In the buffer, the 1st sequence are the ones that are padded. So if seq_len = 3 and
-        # trajectory is of length 10, the 1st sequence is [pad,pad,obs].
-        # Compute the number of elements in this padded seq.
-        leftover = num_experiences % self.policy.sequence_length
-
-        # Compute values for the potentially truncated initial sequence
 
-        first_seq_len = leftover if leftover > 0 else self.policy.sequence_length
-
-        self_seq_obs = []
-        groupmate_seq_obs = []
-        groupmate_seq_act = []
-        seq_obs = []
-        for _self_obs in self_obs:
-            first_seq_obs = _self_obs[0:first_seq_len]
-            seq_obs.append(first_seq_obs)
-        self_seq_obs.append(seq_obs)
-
-        for groupmate_obs, groupmate_action in zip(obs, actions):
-            seq_obs = []
-            for _obs in groupmate_obs:
-                first_seq_obs = _obs[0:first_seq_len]
-                seq_obs.append(first_seq_obs)
-            groupmate_seq_obs.append(seq_obs)
-            _act = groupmate_action.slice(0, first_seq_len)
-            groupmate_seq_act.append(_act)
-
-        # For the first sequence, the initial memory should be the one at the
-        # beginning of this trajectory.
-        for _ in range(first_seq_len):
-            all_next_value_mem.append(ModelUtils.to_numpy(init_value_mem.squeeze()))
-            all_next_baseline_mem.append(
-                ModelUtils.to_numpy(init_baseline_mem.squeeze())
-            )
-
-        all_seq_obs = self_seq_obs + groupmate_seq_obs
-        init_values, _value_mem = self.critic.critic_pass(
-            all_seq_obs, init_value_mem, sequence_length=first_seq_len
-        )
-        all_values = {
-            signal_name: [init_values[signal_name]]
-            for signal_name in init_values.keys()
-        }
+        # When using LSTM, we need to divide the trajectory into sequences of equal length. Sometimes,
+        # that division isn't even, and we must pad the leftover sequence.
+        # In the buffer, the last sequence is the one that is padded. So if seq_len = 3 and
+        # the trajectory is of length 10, the last sequence is [obs,pad,pad].
+        # Compute the number of elements in this padded sequence.
+        leftover_seq_len = num_experiences % self.policy.sequence_length
 
-        groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
-        init_baseline, _baseline_mem = self.critic.baseline(
-            self_seq_obs[0],
-            groupmate_obs_and_actions,
-            init_baseline_mem,
-            sequence_length=first_seq_len,
-        )
-        all_baseline = {
-            signal_name: [init_baseline[signal_name]]
-            for signal_name in init_baseline.keys()
-        }
+        all_values: Dict[str, List[np.ndarray]] = defaultdict(list)
+        all_baseline: Dict[str, List[np.ndarray]] = defaultdict(list)
+        _baseline_mem = init_baseline_mem
+        _value_mem = init_value_mem
 
         # Evaluate other trajectories, carrying over _mem after each
         # trajectory
-        for seq_num in range(
-            1, math.ceil((num_experiences) / (self.policy.sequence_length))
-        ):
+        for seq_num in range(num_experiences // self.policy.sequence_length):
             for _ in range(self.policy.sequence_length):
                 all_next_value_mem.append(ModelUtils.to_numpy(_value_mem.squeeze()))
                 all_next_baseline_mem.append(
                     ModelUtils.to_numpy(_baseline_mem.squeeze())
                 )
 
-            start = seq_num * self.policy.sequence_length - (
-                self.policy.sequence_length - leftover
-            )
-            end = (seq_num + 1) * self.policy.sequence_length - (
-                self.policy.sequence_length - leftover
-            )
+            start = seq_num * self.policy.sequence_length
+            end = (seq_num + 1) * self.policy.sequence_length
 
             self_seq_obs = []
             groupmate_seq_obs = []
             groupmate_seq_act = []
             seq_obs = []
             for _self_obs in self_obs:
-                seq_obs.append(_obs[start:end])
+                seq_obs.append(_self_obs[start:end])
             self_seq_obs.append(seq_obs)
 
-            for groupmate_obs, team_action in zip(obs, actions):
+            for groupmate_obs, groupmate_action in zip(obs, actions):
                 seq_obs = []
-                for (_obs,) in groupmate_obs:
-                    first_seq_obs = _obs[start:end]
-                    seq_obs.append(first_seq_obs)
+                for _obs in groupmate_obs:
+                    sliced_seq_obs = _obs[start:end]
+                    seq_obs.append(sliced_seq_obs)
                 groupmate_seq_obs.append(seq_obs)
-                _act = team_action.slice(start, end)
+                _act = groupmate_action.slice(start, end)
                 groupmate_seq_act.append(_act)
 
             all_seq_obs = self_seq_obs + groupmate_seq_obs
             values, _value_mem = self.critic.critic_pass(
                 all_seq_obs, _value_mem, sequence_length=self.policy.sequence_length
             )
-            all_values = {
-                signal_name: [init_values[signal_name]] for signal_name in values.keys()
-            }
+            for signal_name, _val in values.items():
+                all_values[signal_name].append(_val)
 
             groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
             baselines, _baseline_mem = self.critic.baseline(
                 self_seq_obs[0],
                 groupmate_obs_and_actions,
                 _baseline_mem,
-                sequence_length=first_seq_len,
+                sequence_length=self.policy.sequence_length,
+            )
+            for signal_name, _val in baselines.items():
+                all_baseline[signal_name].append(_val)
+
+        # Compute values for the potentially truncated last sequence
+        if leftover_seq_len > 0:
+            self_seq_obs = []
+            groupmate_seq_obs = []
+            groupmate_seq_act = []
+            seq_obs = []
+            for _self_obs in self_obs:
+                last_seq_obs = _self_obs[-leftover_seq_len:]
+                seq_obs.append(last_seq_obs)
+            self_seq_obs.append(seq_obs)
+
+            for groupmate_obs, groupmate_action in zip(obs, actions):
+                seq_obs = []
+                for _obs in groupmate_obs:
+                    last_seq_obs = _obs[-leftover_seq_len:]
+                    seq_obs.append(last_seq_obs)
+                groupmate_seq_obs.append(seq_obs)
+                _act = groupmate_action.slice(len(_obs) - leftover_seq_len, len(_obs))
+                groupmate_seq_act.append(_act)
+
+            # For the last sequence, the initial memory should be the one carried
+            # over from the previous sequences of this trajectory.
+            seq_obs = []
+            for _ in range(leftover_seq_len):
+                all_next_value_mem.append(ModelUtils.to_numpy(_value_mem.squeeze()))
+                all_next_baseline_mem.append(
+                    ModelUtils.to_numpy(_baseline_mem.squeeze())
+                )
+
+            all_seq_obs = self_seq_obs + groupmate_seq_obs
+            last_values, _value_mem = self.critic.critic_pass(
+                all_seq_obs, _value_mem, sequence_length=leftover_seq_len
+            )
+            for signal_name, _val in last_values.items():
+                all_values[signal_name].append(_val)
+            groupmate_obs_and_actions = (groupmate_seq_obs, groupmate_seq_act)
+            last_baseline, _baseline_mem = self.critic.baseline(
+                self_seq_obs[0],
+                groupmate_obs_and_actions,
+                _baseline_mem,
+                sequence_length=leftover_seq_len,
             )
-            all_baseline = {
-                signal_name: [baselines[signal_name]]
-                for signal_name in baselines.keys()
-            }
+            for signal_name, _val in last_baseline.items():
+                all_baseline[signal_name].append(_val)
         # Create one tensor per reward signal
         all_value_tensors = {
             signal_name: torch.cat(value_list, dim=0)
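For context, the change replaces the old padded-first-sequence indexing with plain front-to-back slicing: the full-length sequences are evaluated first, and the shorter leftover sequence (the one the buffer pads) is handled last. A minimal standalone sketch of that slicing, outside the diff and using an illustrative helper name (`split_into_sequences`) that is not part of ml-agents, assuming only NumPy:

```python
import numpy as np


def split_into_sequences(trajectory: np.ndarray, sequence_length: int):
    """Illustrative helper (not ml-agents API): slice a trajectory the way the
    new code indexes it - full sequences first, padded leftover last."""
    num_experiences = trajectory.shape[0]
    leftover_seq_len = num_experiences % sequence_length

    sequences = []
    for seq_num in range(num_experiences // sequence_length):
        start = seq_num * sequence_length
        end = (seq_num + 1) * sequence_length
        sequences.append(trajectory[start:end])

    # The trailing, shorter-than-sequence_length slice is handled separately,
    # just as the `if leftover_seq_len > 0:` block in the diff does.
    if leftover_seq_len > 0:
        sequences.append(trajectory[-leftover_seq_len:])
    return sequences


# A 10-step trajectory with sequence_length=3 yields slice lengths [3, 3, 3, 1];
# the final 1-step slice is the one that gets padded in the buffer.
print([len(s) for s in split_into_sequences(np.arange(10), 3)])
```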
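The other half of the change is how per-signal outputs are collected. The old code rebuilt `all_values`/`all_baseline` with a dict comprehension on every loop iteration (and read `init_values` instead of `values`), so only one sequence's outputs survived. A rough sketch of the `defaultdict(list)` accumulation pattern the new code adopts, using made-up reward-signal arrays rather than real critic outputs:

```python
from collections import defaultdict
from typing import Dict, List

import numpy as np

# Stand-ins for the dicts returned per sequence by critic_pass/baseline;
# the signal name and values are fabricated for illustration.
per_sequence_outputs = [
    {"extrinsic": np.array([0.1, 0.2, 0.3])},
    {"extrinsic": np.array([0.4])},  # leftover (padded) sequence
]

all_values: Dict[str, List[np.ndarray]] = defaultdict(list)
for values in per_sequence_outputs:
    for signal_name, _val in values.items():
        all_values[signal_name].append(_val)

# One array per reward signal, mirroring the torch.cat at the end of the hunk.
all_value_arrays = {
    signal_name: np.concatenate(value_list, axis=0)
    for signal_name, value_list in all_values.items()
}
print(all_value_arrays["extrinsic"])  # [0.1 0.2 0.3 0.4]
```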