From fc011b893fb26cf4dfa17088b59555ff83936914 Mon Sep 17 00:00:00 2001
From: GDM Neurolab
Date: Thu, 20 Nov 2025 16:28:32 -0800
Subject: [PATCH] internal

PiperOrigin-RevId: 834956434
---
 .../library/two_armed_bandits.py | 105 ++++++++++--------
 1 file changed, 57 insertions(+), 48 deletions(-)

diff --git a/disentangled_rnns/library/two_armed_bandits.py b/disentangled_rnns/library/two_armed_bandits.py
index a7f3474..825da28 100644
--- a/disentangled_rnns/library/two_armed_bandits.py
+++ b/disentangled_rnns/library/two_armed_bandits.py
@@ -58,7 +58,7 @@ def new_session(self):
     """

   @abstractmethod
-  def step(self, attempted_choice: int) -> tuple[int, float, int]:
+  def step(self, attempted_choice: int) -> tuple[int, float | int, int]:
     """Executes a single step in the environment.

     Args:
@@ -91,17 +91,18 @@ class EnvironmentBanditsDrift(BaseEnvironment):
     n_arms: The number of arms in the environment.
   """

-  def __init__(self,
-               sigma: float,
-               p_instructed: float = 0.0,
-               seed: Optional[int] = None,
-               n_arms: int = 2,
-               ):
+  def __init__(
+      self,
+      sigma: float,
+      p_instructed: float = 0.0,
+      seed: Optional[int] = None,
+      n_arms: int = 2,
+  ):
     super().__init__(seed=seed, n_arms=n_arms)

     # Check inputs
     if sigma < 0:
-      msg = ('sigma was {}, but must be greater than 0')
+      msg = 'sigma was {}, but must be greater than 0'
       raise ValueError(msg.format(sigma))

     # Initialize persistent properties
@@ -116,8 +117,7 @@ def new_session(self):
     # Sample randomly between 0 and 1
     self._reward_probs = self._random_state.rand(self.n_arms)

-  def step(self,
-           attempted_choice: int) -> tuple[int, float, int]:
+  def step(self, attempted_choice: int) -> tuple[int, float, int]:
     """Run a single trial of the task.

     Args:
@@ -129,7 +129,6 @@ def step(self,
         that trial.
       reward: The reward to be given to the agent. 0 or 1.
       instructed: 1 if the choice was instructed, 0 otherwise
-
     """
     if attempted_choice == -1:
       choice = -1
@@ -139,8 +138,10 @@ def step(self,

     # Check inputs
     if attempted_choice not in list(range(self.n_arms)):
-      msg = (f'choice given was {attempted_choice}, but must be one of '
-             f'{list(range(self.n_arms))}.')
+      msg = (
+          f'choice given was {attempted_choice}, but must be one of '
+          f'{list(range(self.n_arms))}.'
+      )
       raise ValueError(msg)

     # If choice was instructed, overrule it and decide randomly
@@ -154,7 +155,8 @@ def step(self,
     reward = self._random_state.rand() < self._reward_probs[choice]
     # Add gaussian noise to reward probabilities
     drift = self._random_state.normal(
-        loc=0, scale=self._sigma, size=self.n_arms)
+        loc=0, scale=self._sigma, size=self.n_arms
+    )
     self._reward_probs += drift

     # Fix reward probs that've drifted below 0 or above 1
@@ -186,7 +188,7 @@ def __init__(
     """Initialize the environment.

     Args:
-      payout_matrix: A numpy array of shape (n_sessions, n_actions, n_trials)
+      payout_matrix: A numpy array of shape (n_sessions, n_trials, n_actions)
         giving the reward for each session, action, and trial. These are
         deterministic, i.e. for the same trial_num, session_num, and action,
         the reward will always be the same. (If you'd like stochastic rewards you
@@ -206,7 +208,9 @@ def __init__(
     if instructed_matrix is not None:
       self._instructed_matrix = instructed_matrix
     else:
-      self._instructed_matrix = np.nan * np.zeros_like(payout_matrix)
+      self._instructed_matrix = np.full(
+          (self._n_sessions, self._n_trials), np.nan
+      )

     self._current_session = 0
     self._current_trial = 0
@@ -221,7 +225,7 @@ def new_session(self):
     )
     self._current_trial = 0

-  def step(self, attempted_choice: int) -> tuple[int, float, int]:
+  def step(self, attempted_choice: int) -> tuple[int, float | int, int]:
     # If agent choice is default empty value -1, return -1 for all outputs.
     if attempted_choice == -1:
       choice = -1
@@ -231,8 +235,10 @@ def step(self, attempted_choice: int) -> tuple[int, float, int]:

     # Check inputted choice is valid.
     if attempted_choice not in list(range(self.n_arms)):
-      msg = (f'choice given was {attempted_choice}, but must be one of '
-             f'{list(range(self.n_arms))}.')
+      msg = (
+          f'choice given was {attempted_choice}, but must be one of '
+          f'{list(range(self.n_arms))}.'
+      )
       raise ValueError(msg)

     if self._current_trial >= self._n_trials:
@@ -256,13 +262,15 @@ def step(self, attempted_choice: int) -> tuple[int, float, int]:
         self._current_session, self._current_trial, choice
     ]
     self._current_trial += 1
-    return choice, float(reward), int(instructed)
+    return choice, reward, int(instructed)

   @property
   def payout(self) -> np.ndarray:
     """Get possible payouts for current session, trial across actions."""
     return self._payout_matrix[
-        self._current_session, self._current_trial, :].copy()
+        self._current_session, self._current_trial, :
+    ].copy()
+
 ##########
 # AGENTS #
 ##########
@@ -274,7 +282,6 @@ class AgentQ:

   Attributes:
     q: The agent's current estimate of the reward probability on each arm
-
   """

   def __init__(
@@ -298,7 +305,8 @@ def new_session(self):

   def get_choice_probs(self) -> np.ndarray:
     choice_probs = np.exp(self._beta * self.q) / np.sum(
-        np.exp(self._beta * self.q))
+        np.exp(self._beta * self.q)
+    )
     return choice_probs

   def get_choice(self) -> int:
@@ -308,9 +316,7 @@ def get_choice(self) -> int:
     choice = np.random.choice(2, p=choice_probs)
     return choice

-  def update(self,
-             choice: int,
-             reward: float):
+  def update(self, choice: int, reward: float):
     """Update the agent after one step of the task.

     Args:
@@ -350,12 +356,11 @@ def __init__(

   def new_session(self):
     """Reset the agent for the beginning of a new session."""
-    self.theta = 0. * np.ones(2)
+    self.theta = 0.0 * np.ones(2)
     self.v = 0.5

   def get_choice_probs(self) -> np.ndarray:
-    choice_probs = np.exp(self.theta) / np.sum(
-        np.exp(self.theta))
+    choice_probs = np.exp(self.theta) / np.sum(np.exp(self.theta))
     return choice_probs

   def get_choice(self) -> int:
@@ -379,9 +384,11 @@ def update(self, choice: int, reward: float):
     choice_probs = self.get_choice_probs()
     rpe = reward - self.v
     self.theta[choice] = (1 - self._alpha_actor_forget) * self.theta[
-        choice] + self._alpha_actor_learn * rpe * (1 - choice_probs[choice])
+        choice
+    ] + self._alpha_actor_learn * rpe * (1 - choice_probs[choice])
     self.theta[unchosen] = (1 - self._alpha_actor_forget) * self.theta[
-        unchosen] - self._alpha_actor_learn * rpe * (choice_probs[unchosen])
+        unchosen
+    ] - self._alpha_actor_learn * rpe * (choice_probs[unchosen])

     # Critic learing: V moves towards reward
     self.v = (1 - self._alpha_critic) * self.v + self._alpha_critic * reward
@@ -395,9 +402,7 @@ class AgentNetwork:
     params: A set of Haiku parameters suitable for that architecture
   """

-  def __init__(self,
-               make_network: Callable[[], hk.RNNCore],
-               params: hk.Params):
+  def __init__(self, make_network: Callable[[], hk.RNNCore], params: hk.Params):

     def step_network(
         xs: np.ndarray, state: hk.State
@@ -449,9 +454,9 @@ class SessData(NamedTuple):
   n_trials: int


-def run_experiment(agent: Agent,
-                   environment: EnvironmentBanditsDrift,
-                   n_steps: int) -> SessData:
+def run_experiment(
+    agent: Agent, environment: EnvironmentBanditsDrift, n_steps: int
+) -> SessData:
   """Runs a behavioral session from a given agent and environment.

   Args:
@@ -479,25 +484,29 @@ def run_experiment(agent: Agent,
     choices[step] = choice
     rewards[step] = reward

-  experiment = SessData(choices=choices,
-                        rewards=rewards,
-                        n_trials=n_steps,
-                        reward_probs=reward_probs)
+  experiment = SessData(
+      choices=choices,
+      rewards=rewards,
+      n_trials=n_steps,
+      reward_probs=reward_probs,
+  )
   return experiment


-def create_dataset(agent: Agent,
-                   environment: EnvironmentBanditsDrift,
-                   n_steps_per_session: int,
-                   n_sessions: int,
-                   batch_size: int) -> rnn_utils.DatasetRNN:
+def create_dataset(
+    agent: Agent,
+    environment: EnvironmentBanditsDrift,
+    n_steps_per_session: int,
+    n_sessions: int,
+    batch_size: int,
+) -> rnn_utils.DatasetRNN:
   """Generates a behavioral dataset from a given agent and environment.

   Args:
     agent: An agent object to generate choices
     environment: An environment object to generate rewards
-    n_steps_per_session: The number of trials in each behavioral session to
-      be generated
+    n_steps_per_session: The number of trials in each behavioral session to be
+      generated
     n_sessions: The number of sessions to generate
     batch_size: The size of the batches to serve from the dataset
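
Usage note (not part of the patch): a minimal sketch of how the APIs touched above fit together. The EnvironmentBanditsDrift, run_experiment, and create_dataset signatures are the ones visible in the hunks; the import path and the AgentQ constructor arguments (alpha, beta) are assumptions, since AgentQ's __init__ parameters are not shown in this diff.

# Sketch only; anything marked "assumed" is not confirmed by the hunks above.
from disentangled_rnns.library import two_armed_bandits  # assumed import path

env = two_armed_bandits.EnvironmentBanditsDrift(sigma=0.1, seed=0)
agent = two_armed_bandits.AgentQ(alpha=0.3, beta=3.0)  # assumed constructor args

# One hand-rolled trial: step() returns (choice, reward, instructed), and an
# attempted_choice of -1 is passed through as -1 for all outputs.
choice = agent.get_choice()
choice, reward, instructed = env.step(choice)
agent.update(choice, reward)

# Helpers whose signatures are reformatted (not changed) by this patch.
session = two_armed_bandits.run_experiment(agent, env, n_steps=200)
dataset = two_armed_bandits.create_dataset(
    agent,
    env,
    n_steps_per_session=200,
    n_sessions=100,
    batch_size=100,
)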
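
A second note on the payout_matrix docstring fix: the newly documented shape (n_sessions, n_trials, n_actions) matches how step() and the payout property index the array (self._payout_matrix[session, trial, choice]). The sketch below builds a schedule in that layout; the class name EnvironmentPayoutMatrix and a constructor taking only payout_matrix are assumptions, since the hunks show the class's methods but not its name or full signature.

import numpy as np

from disentangled_rnns.library import two_armed_bandits  # assumed import path

# Deterministic schedule in the documented (n_sessions, n_trials, n_actions)
# layout: arm 1 pays on even-numbered trials, arm 0 on odd-numbered trials.
n_sessions, n_trials, n_actions = 2, 5, 2
payout_matrix = np.zeros((n_sessions, n_trials, n_actions))
payout_matrix[:, ::2, 1] = 1.0
payout_matrix[:, 1::2, 0] = 1.0

# This is exactly the (session, trial, action) indexing order step() uses.
assert payout_matrix[0, 2, 1] == 1.0  # session 0, trial 2, action 1

env = two_armed_bandits.EnvironmentPayoutMatrix(payout_matrix)  # assumed name
choice, reward, instructed = env.step(1)  # reward read from trial 0, arm 1
print(env.payout)  # payouts for the environment's current session and trial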