This repository contains the codebase for ContagionRL, a platform introduced in our paper for reinforcement learning in spatial epidemic environments which unifies compartmental epidemiology and agent-based modeling (ABM).
Create an environment with Conda:
conda create --name ContagionRL python=3.13.1Activate the environment:
conda activate ContagionRL
Install the dependencies:
pip install -r requirements.txtThe following libraries are used in ContagionRL. Their licenses and usage details are listed below. Our codebase is intended for release under BSD-3 license.
| Library | Version | Purpose | License |
|---|---|---|---|
| gymnasium | 1.0.0 | Toolkit for reinforcement learning environments | MIT License |
| imageio | 2.37.0 | Image and video I/O library | BSD 2-Clause |
| matplotlib | 3.10.3 | Comprehensive plotting library | Matplotlib License (BSD-based) |
| pandas | 2.2.3 | Data manipulation and analysis | BSD 3-Clause |
| scipy | 1.15.3 | Advanced scientific computations | BSD 3-Clause |
| seaborn | 0.13.2 | Statistical data visualization | BSD 3-Clause |
| stable_baselines3 | 2.5.0 | Reinforcement learning algorithms | MIT License |
| statsmodels | 0.14.4 | Statistical modeling and testing | BSD 3-Clause |
| torch | 2.6.0 | Deep learning framework | BSD 3-Clause |
| tqdm | 4.67.1 | Progress bar utility | MPL 2.0 (Mozilla Public License) |
| tueplots | 0.2.0 | Publication-style plotting presets | MIT License |
| wandb | 0.19.8 | Experiment tracking and visualization | MIT License |
The SIRS+D environment can be found as SIRSDEnvironment in environment.py file. SIRSDEnvironment can be used directly or using registered version with gym.make. The env_config corresponds to the environmental parameters in Table 3 of the paper.
The SIRSDEnvironment can be used in the default epidemic configuration as follows:
from environment import SIRSDEnvironment
env = SIRSDEnvironment()Critically, to support an extensive range of epidemics, it could be configured and parameterized:
from environment import SIRSDEnvironment
env_config = {
"simulation_time": 512,
"grid_size": 50,
"n_humans": 40,
"n_infected": 10,
"beta": 0.5,
"initial_agent_adherence": 0,
"distance_decay": 0.3,
"lethality": 0,
"immunity_loss_prob": 0.25,
"recovery_rate": 0.1,
"adherence_penalty_factor": 1,
"adherence_effectiveness": 0.2,
"movement_type": "continuous_random",
"movement_scale": 1,
"visibility_radius": -1,
"reinfection_count": 5,
"safe_distance": 10,
"init_agent_distance": 5,
"max_distance_for_beta_calculation": 10,
"reward_type": "potential_field",
"reward_ablation": "full",
"render_mode": None
}
env = SIRSDEnvironment(**env_config)Alternatively, it is also possible to use the registered form of environment directly from gymnasium:
import registerSIRSD # handles the registeration of SIRSDEnvironment
import gymnasium as gym
env_config = {
"simulation_time": 512,
"grid_size": 50,
"n_humans": 40,
"n_infected": 10,
"beta": 0.5,
"initial_agent_adherence": 0,
"distance_decay": 0.3,
"lethality": 0,
"immunity_loss_prob": 0.25,
"recovery_rate": 0.1,
"adherence_penalty_factor": 1,
"adherence_effectiveness": 0.2,
"movement_type": "continuous_random",
"movement_scale": 1,
"visibility_radius": -1,
"reinfection_count": 5,
"safe_distance": 10,
"init_agent_distance": 5,
"max_distance_for_beta_calculation": 10,
"reward_type": "potential_field",
"reward_ablation": "full",
"render_mode": None
}
env = gym.make('SIRSD-v0', **env_config)The registerSIRSD.py script registers our environment with gymnasium under the name SIRSD-v0.
To regenerate our data, figures and tables in the results or additional experiments follow the instructions here. The table below shows a grouping of the figures and tables in our paper. The figures are grouped by topic (and the script figureX.py where X is an integer) that produces them.
| Group Name | Supporting Figures and Tables | Figure Training Script | Figure Generation Script |
|---|---|---|---|
| Comparison of Reinforcement Learning Algorithms Performance | Figure 2, Figure 9, Table 1 | train_figure4_models.py |
figure4.py |
| Comparison of Reward Functions | Figure 3, Figure 10, Table 8 | train_figure2_models.py |
figure2.py |
| Potential Field Reward Function Ablation Study | Figure 4, Figure 11, Table 9 | train_figure5_models.py |
figure5.py |
| Impact of Visibility Radius Constraints | Figure 5 | train_figure9_models.py |
figure9.py |
| Performance Comparison Across Movement Patterns | Figure 6, Table 2 | train_figure10_models.py |
figure10.py |
| Static Render of Environment | Figure 8 | no training needed | figure_render.py |
| Environmental Parameter Variation: Infection Rate | Figure 12, Table 11 | train_figure3_models.py |
figure3.py |
| Environmental Parameter Variation: Population Density (Grid Size) | Figure 13, Table 12 | train_figure6_models.py |
figure6.py |
| Environmental Parameter Variation: Adherence Effectiveness | Figure 14, Table 13 | train_figure7_models.py |
figure7.py |
| Environmental Parameter Variation: Distance Decay | Figure 15, Table 14 | train_figure8_models.py |
figure8.py |
The results section is divided into two parts:
figureTraining: Handle the training of the necessary models to make the graphs. For example:train_figure2_models.pyhandles the training of the models needed to makefigure2.py. By default, all of the models are trained across 3 seeds.figureGeneration: Create graphs based on the train model made by its corresponding training file. EachfigureX.pyproduces multiple visualizations. The charts are saved to afiguresdirectory. The tables are printed (with pretty print) into the terminal that is running the script.
Note: train_figure1_models.py and figure1.py are legacy scripts that are not part of the current figure generation workflow, and no figures in the paper are derived from them.
To recreate the figures in our paper, follow these instructions. Make sure that you have activated the Conda environment, see requirements. First, train the models:
python train_figure[X]_models.pyIn training you may be asked to log into W&B. You can skip this using --no-wandb flag.
Then, recreate the figures/tables:
python figure[X].py --model-base Fig[X] --runs 100The above command would create a new folder called figures/, which will contain the figure(s). The tables (if applicable) will be printed to the terminal output.
- Be sure to replace
Xwith a valid integer based on the grouping table above. The value forXshould be the same for both of the scripts. --model-baseflag provides a reference to which trained models inlogsdirectory should be used to making the figure(s).--runs 100is used to do 100 inferences per seed per model (as done in the paper).