We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 0431025 commit 98da4b1Copy full SHA for 98da4b1
ml-agents/mlagents/trainers/torch/distributions.py
@@ -124,7 +124,7 @@ def sample(self):
124
return torch.multinomial(self.probs, 1)
125
126
def deterministic_sample(self):
127
- return torch.argmax(self.probs).reshape((1, 1))
+ return torch.argmax(self.probs, dim=1, keepdim=True)
128
129
def pdf(self, value):
130
# This function is equivalent to torch.diag(self.probs.T[value.flatten().long()]),
0 commit comments