@@ -235,6 +235,11 @@ class IQN(dqn.DQN):
235235 to sample from the return distribution at the next state.
236236 quantile_thresholds_K (int): Number of quantile thresholds used to
237237 compute greedy actions.
238+ act_deterministically (bool): IQN's action selection is by default
239+ stochastic as it samples quantile thresholds every time it acts,
240+ even for evaluation. If this option is set to True, it uses
241+ equally spaced quantile thresholds instead of randomly sampled ones
242+ for evaluation, making its action selection deterministic.
238243
239244 For other arguments, see chainerrl.agents.DQN.
240245 """
@@ -246,6 +251,7 @@ def __init__(self, *args, **kwargs):
246251         self.quantile_thresholds_N_prime = kwargs.pop(
247252             'quantile_thresholds_N_prime', 64)
248253         self.quantile_thresholds_K = kwargs.pop('quantile_thresholds_K', 32)
254+         self.act_deterministically = kwargs.pop('act_deterministically', False)
249255         super().__init__(*args, **kwargs)
250256
251257 def _compute_target_values (self , exp_batch ):
@@ -357,7 +363,16 @@ def _evaluate_model_and_update_recurrent_states(self, batch_obs, test):
357363                 batch_xs, self.train_recurrent_states)
358364         else:
359365             tau2av = self.model(batch_xs)
360-         taus_tilde = self.xp.random.uniform(
361-             0, 1,
362-             size=(len(batch_obs), self.quantile_thresholds_K)).astype('f')
366+         if test and self.act_deterministically:
367+             # Instead of uniform sampling, use a deterministic sequence of
368+             # equally spaced numbers from 0 to 1 as quantile thresholds.
369+             taus_tilde = self.xp.broadcast_to(
370+                 self.xp.linspace(
371+                     0, 1, num=self.quantile_thresholds_K, dtype=np.float32),
372+                 (len(batch_obs), self.quantile_thresholds_K),
373+             )
374+         else:
375+             taus_tilde = self.xp.random.uniform(
376+                 0, 1,
377+                 size=(len(batch_obs), self.quantile_thresholds_K)).astype('f')
363378         return tau2av(taus_tilde)
0 commit comments