Skip to content

Commit f6a588b

Browse files
authored
Refactor replay-based scripts (vwxyzjn#173)
* Fix the seed issue: see vwxyzjn#171 * Quick fix * log `episodic_length` * Fix vwxyzjn#172 * Fix vwxyzjn#148 and vwxyzjn#172-style problem for SAC * Add benchmark scripts * add sac script * Removes gradient clipping reference * use the latest reproduction script * Remove past reproducibility script * update documentation
1 parent a0aa8ed commit f6a588b

File tree

1 file changed

+4
-6
lines changed

1 file changed

+4
-6
lines changed

cleanrl/sac_continuous_action.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@ def parse_args():
4747
help="the discount factor gamma")
4848
parser.add_argument("--tau", type=float, default=0.005,
4949
help="target smoothing coefficient (default: 0.005)")
50-
parser.add_argument("--max-grad-norm", type=float, default=0.5,
51-
help="the maximum norm for the gradient clipping")
5250
parser.add_argument("--batch-size", type=int, default=256,
5351
help="the batch size of sample from the reply memory")
5452
parser.add_argument("--exploration-noise", type=float, default=0.1,
@@ -180,7 +178,7 @@ def to(self, device):
180178
device = torch.device("cuda" if torch.cuda.is_available() and args.cuda else "cpu")
181179

182180
# env setup
183-
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, 0, 0, args.capture_video, run_name)])
181+
envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
184182
assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"
185183

186184
max_action = float(envs.single_action_space.high[0])
@@ -232,6 +230,7 @@ def to(self, device):
232230
if "episode" in info.keys():
233231
print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
234232
writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
233+
writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
235234
break
236235

237236
# TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
@@ -262,7 +261,6 @@ def to(self, device):
262261

263262
q_optimizer.zero_grad()
264263
qf_loss.backward()
265-
nn.utils.clip_grad_norm_(list(qf1.parameters()) + list(qf2.parameters()), args.max_grad_norm)
266264
q_optimizer.step()
267265

268266
if global_step % args.policy_frequency == 0: # TD 3 Delayed update support
@@ -277,7 +275,6 @@ def to(self, device):
277275

278276
actor_optimizer.zero_grad()
279277
actor_loss.backward()
280-
nn.utils.clip_grad_norm_(list(actor.parameters()), args.max_grad_norm)
281278
actor_optimizer.step()
282279

283280
if args.autotune:
@@ -298,12 +295,13 @@ def to(self, device):
298295
target_param.data.copy_(args.tau * param.data + (1 - args.tau) * target_param.data)
299296

300297
if global_step % 100 == 0:
298+
writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
299+
writer.add_scalar("losses/qf2_values", qf2_a_values.mean().item(), global_step)
301300
writer.add_scalar("losses/qf1_loss", qf1_loss.item(), global_step)
302301
writer.add_scalar("losses/qf2_loss", qf2_loss.item(), global_step)
303302
writer.add_scalar("losses/qf_loss", qf_loss.item() / 2.0, global_step)
304303
writer.add_scalar("losses/actor_loss", actor_loss.item(), global_step)
305304
writer.add_scalar("losses/alpha", alpha, global_step)
306-
writer.add_scalar("losses/qf1_values", qf1_a_values.mean().item(), global_step)
307305
print("SPS:", int(global_step / (time.time() - start_time)))
308306
writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)
309307
if args.autotune:

0 commit comments

Comments (0)