Skip to content

Commit 233501c

Browse files
wkerzendorfCopilot
andauthored
Restructure/mc subpackage packets (#3266)
* Refactor imports in base.py and test_interaction.py; add LineInteractionType class in interaction.py * Refactor opacity state initialization to use numba implementation; update related tests * Refactor OpacityStateNumba class to use Numba types directly; update initialization and documentation * Refactor packet and tracker classes for Monte Carlo simulation; implement RPacket and VPacket with associated methods * Refactor Monte Carlo transport module to reorganize packet structure - Moved packet-related classes and functions into a new 'packets' submodule for better organization. - Updated import statements throughout the codebase to reflect the new structure. - Enhanced type annotations and docstrings for clarity and improved code readability. - Refactored the PacketCollection, LastInteractionTracker, RPacket, VPacket, and associated tracker classes to utilize Numba's JIT compilation more effectively. - Improved the handling of packet properties and interactions, ensuring consistency across the Monte Carlo transport processes. - Updated tests to accommodate the new structure and ensure all functionalities remain intact. * Refactor packet source and testing structure - Updated import paths in `packet_source.py` to reflect new module structure. * Update tardis/opacities/opacity_state_numba.py Co-authored-by: Copilot <[email protected]> * Update energy parameter description in RPacket class * empty commit to get CI going --------- Co-authored-by: Copilot <[email protected]>
1 parent 995879b commit 233501c

27 files changed

+478
-351
lines changed

benchmarks/benchmark_base.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@
1212
from tardis.model.geometry.radial1d import NumbaRadial1DGeometry
1313
from tardis.simulation import Simulation
1414
from tardis.tests.fixtures.atom_data import DEFAULT_ATOM_DATA_MD5
15-
from tardis.transport.montecarlo import RPacket, packet_trackers
15+
from tardis.transport.montecarlo import RPacket
1616
from tardis.transport.montecarlo.configuration.base import (
1717
MonteCarloConfiguration,
1818
)
1919
from tardis.transport.montecarlo.estimators import radfield_mc_estimators
2020
from tardis.opacities.opacity_state_numba import opacity_state_numba_initialize
21-
from tardis.transport.montecarlo.packet_collections import VPacketCollection
21+
from tardis.transport.montecarlo.packets.packet_collections import VPacketCollection
22+
from tardis.transport.montecarlo.packets import packet_trackers
2223

2324

2425
class BenchmarkBase:

benchmarks/transport_montecarlo_packet_trackers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from asv_runner.benchmarks.mark import parameterize
77

88
from benchmarks.benchmark_base import BenchmarkBase
9-
from tardis.transport.montecarlo import packet_trackers
9+
from tardis.transport.montecarlo.packets import packet_trackers
1010

1111

1212
class BenchmarkTransportMontecarloPacketTrackers(BenchmarkBase):

benchmarks/transport_montecarlo_vpacket.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66

77
import numpy as np
88

9-
import tardis.transport.montecarlo.vpacket as vpacket
9+
import tardis.transport.montecarlo.packets.virtual_packet as virtual_packet
1010
from benchmarks.benchmark_base import BenchmarkBase
1111
from tardis.transport.frame_transformations import get_doppler_factor
12-
from tardis.transport.montecarlo.r_packet import RPacket
12+
from tardis.transport.montecarlo.packets.radiative_packet import RPacket
1313

1414

1515
class BenchmarkMontecarloMontecarloNumbaVpacket(BenchmarkBase):
@@ -30,7 +30,7 @@ def setup(self):
3030

3131
@functools.cached_property
3232
def v_packet(self):
33-
return vpacket.VPacket(
33+
return virtual_packet.VPacket(
3434
r=7.5e14,
3535
nu=4e15,
3636
mu=0.3,
@@ -74,7 +74,7 @@ def time_trace_vpacket_within_shell(self):
7474
self.enable_full_relativity,
7575
)
7676

77-
vpacket.trace_vpacket_within_shell(
77+
virtual_packet.trace_vpacket_within_shell(
7878
self.vpacket,
7979
self.numba_radial_1d_geometry,
8080
self.time_explosion,
@@ -94,7 +94,7 @@ def time_trace_vpacket(self):
9494
self.enable_full_relativity,
9595
)
9696

97-
vpacket.trace_vpacket(
97+
virtual_packet.trace_vpacket(
9898
self.vpacket,
9999
self.numba_radial_1d_geometry,
100100
self.time_explosion,
@@ -106,7 +106,7 @@ def time_trace_vpacket(self):
106106

107107
@functools.cached_property
108108
def broken_packet(self):
109-
return vpacket.VPacket(
109+
return virtual_packet.VPacket(
110110
r=1286064000000000.0,
111111
nu=1660428912896553.2,
112112
mu=0.4916053094346575,
@@ -119,7 +119,7 @@ def broken_packet(self):
119119
def time_trace_bad_vpacket(self):
120120
broken_packet = self.broken_packet
121121

122-
vpacket.trace_vpacket(
122+
virtual_packet.trace_vpacket(
123123
broken_packet,
124124
self.numba_radial_1d_geometry,
125125
self.time_explosion,
@@ -130,7 +130,7 @@ def time_trace_bad_vpacket(self):
130130
)
131131

132132
def time_trace_vpacket_volley(self):
133-
vpacket.trace_vpacket_volley(
133+
virtual_packet.trace_vpacket_volley(
134134
self.r_packet,
135135
self.verysimple_3vpacket_collection,
136136
self.numba_radial_1d_geometry,

docs/io/optional/how_to_custom_source.ipynb

Lines changed: 37 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,16 @@
3838
"outputs": [],
3939
"source": [
4040
"# Import necessary packages\n",
41+
"import matplotlib.pyplot as plt\n",
4142
"import numpy as np\n",
42-
"from tardis import constants as const\n",
4343
"from astropy import units as u\n",
44+
"\n",
45+
"from tardis import run_tardis\n",
46+
"from tardis.io.atom_data import download_atom_data\n",
4447
"from tardis.transport.montecarlo.packet_source import BlackBodySimpleSource\n",
45-
"from tardis.transport.montecarlo.packet_collections import (\n",
48+
"from tardis.transport.montecarlo.packets.packet_collections import (\n",
4649
" PacketCollection,\n",
47-
")\n",
48-
"from tardis import run_tardis\n",
49-
"import matplotlib.pyplot as plt\n",
50-
"from tardis.io.atom_data import download_atom_data"
50+
")"
5151
]
5252
},
5353
{
@@ -90,7 +90,14 @@
9090
" self.truncation_wavelength = truncation_wavelength\n",
9191
" super().__init__(**kwargs)\n",
9292
"\n",
93-
" def create_packets(self, no_of_packets, drawing_sample_size=None, seed_offset=0, *args, **kwargs):\n",
93+
" def create_packets(\n",
94+
" self,\n",
95+
" no_of_packets,\n",
96+
" drawing_sample_size=None,\n",
97+
" seed_offset=0,\n",
98+
" *args,\n",
99+
" **kwargs,\n",
100+
" ):\n",
94101
" \"\"\"\n",
95102
" Packet source that generates a truncated Blackbody source.\n",
96103
"\n",
@@ -110,7 +117,6 @@
110117
" array\n",
111118
" Packet energies\n",
112119
" \"\"\"\n",
113-
"\n",
114120
" self._reseed(self.base_seed + seed_offset)\n",
115121
" packet_seeds = self.rng.choice(\n",
116122
" self.MAX_SEED_VAL, no_of_packets, replace=True\n",
@@ -128,10 +134,9 @@
128134
" drawing_sample_size = 2 * no_of_packets\n",
129135
"\n",
130136
" # Blackbody will be truncated below truncation_wavelength / above truncation_frequency.\n",
131-
" truncation_frequency = (\n",
132-
" u.Quantity(self.truncation_wavelength, u.Angstrom)\n",
133-
" .to(u.Hz, equivalencies=u.spectral())\n",
134-
" )\n",
137+
" truncation_frequency = u.Quantity(\n",
138+
" self.truncation_wavelength, u.Angstrom\n",
139+
" ).to(u.Hz, equivalencies=u.spectral())\n",
135140
"\n",
136141
" # Draw nus from blackbody distribution and reject based on truncation_frequency.\n",
137142
" # If more nus.shape[0] > no_of_packets use only the first no_of_packets.\n",
@@ -141,7 +146,9 @@
141146
" # Only required if the truncation wavelength is too big compared to the maximum\n",
142147
" # of the blackbody distribution. Keep sampling until nus.shape[0] > no_of_packets.\n",
143148
" while nus.shape[0] < no_of_packets:\n",
144-
" additional_nus = self.create_packet_nus(drawing_sample_size, *args, **kwargs)\n",
149+
" additional_nus = self.create_packet_nus(\n",
150+
" drawing_sample_size, *args, **kwargs\n",
151+
" )\n",
145152
" mask = additional_nus < truncation_frequency\n",
146153
" additional_nus = additional_nus[mask][:no_of_packets]\n",
147154
" nus = np.hstack([nus, additional_nus])[:no_of_packets]\n",
@@ -150,7 +157,9 @@
150157
" self.calculate_radfield_luminosity().to(u.erg / u.s).value\n",
151158
" )\n",
152159
"\n",
153-
" return PacketCollection(radii, nus, mus, energies, packet_seeds, radiation_field_luminosity)"
160+
" return PacketCollection(\n",
161+
" radii, nus, mus, energies, packet_seeds, radiation_field_luminosity\n",
162+
" )"
154163
]
155164
},
156165
{
@@ -190,14 +199,20 @@
190199
"outputs": [],
191200
"source": [
192201
"%matplotlib inline\n",
193-
"plt.plot(mdl.spectrum_solver.spectrum_virtual_packets.wavelength,\n",
194-
" mdl.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n",
195-
" color='red', label='truncated blackbody (custom packet source)')\n",
196-
"plt.plot(mdl_norm.spectrum_solver.spectrum_virtual_packets.wavelength,\n",
197-
" mdl_norm.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n",
198-
" color='blue', label='normal blackbody (default packet source)')\n",
199-
"plt.xlabel(r'$\\lambda [\\AA]$')\n",
200-
"plt.ylabel(r'$L_\\lambda$ [erg/s/$\\AA$]')\n",
202+
"plt.plot(\n",
203+
" mdl.spectrum_solver.spectrum_virtual_packets.wavelength,\n",
204+
" mdl.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n",
205+
" color=\"red\",\n",
206+
" label=\"truncated blackbody (custom packet source)\",\n",
207+
")\n",
208+
"plt.plot(\n",
209+
" mdl_norm.spectrum_solver.spectrum_virtual_packets.wavelength,\n",
210+
" mdl_norm.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n",
211+
" color=\"blue\",\n",
212+
" label=\"normal blackbody (default packet source)\",\n",
213+
")\n",
214+
"plt.xlabel(r\"$\\lambda [\\AA]$\")\n",
215+
"plt.ylabel(r\"$L_\\lambda$ [erg/s/$\\AA$]\")\n",
201216
"plt.xlim(500, 10000)\n",
202217
"plt.legend()"
203218
]

docs/physics_walkthrough/montecarlo/initialization.ipynb

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,12 @@
7777
"metadata": {},
7878
"outputs": [],
7979
"source": [
80+
"import matplotlib.pyplot as plt\n",
8081
"import numpy as np\n",
81-
"from tardis.transport.montecarlo.packet_source import BlackBodySimpleSource\n",
82-
"from tardis.transport.montecarlo.packet_collections import (\n",
83-
" PacketCollection,\n",
84-
")\n",
8582
"from astropy import units as u\n",
83+
"\n",
8684
"from tardis import constants as const\n",
87-
"import matplotlib.pyplot as plt"
85+
"from tardis.transport.montecarlo.packet_source import BlackBodySimpleSource"
8886
]
8987
},
9088
{
@@ -135,11 +133,7 @@
135133
"temperature_inner = 10000 * u.K\n",
136134
"\n",
137135
"luminosity_inner = (\n",
138-
" 4\n",
139-
" * np.pi\n",
140-
" * (r_boundary_inner**2)\n",
141-
" * const.sigma_sb\n",
142-
" * (temperature_inner**4)\n",
136+
" 4 * np.pi * (r_boundary_inner**2) * const.sigma_sb * (temperature_inner**4)\n",
143137
")\n",
144138
"\n",
145139
"# Makes sure the luminosity is given in erg/s\n",
@@ -263,12 +257,18 @@
263257
"source": [
264258
"# We set important quantities for making our histogram\n",
265259
"bins = 200\n",
266-
"nus_planck = np.linspace(min(packet_collection.initial_nus), max(packet_collection.initial_nus), bins).value\n",
260+
"nus_planck = np.linspace(\n",
261+
" min(packet_collection.initial_nus), max(packet_collection.initial_nus), bins\n",
262+
").value\n",
267263
"bin_width = nus_planck[1] - nus_planck[0]\n",
268264
"\n",
269265
"# In the histogram plot below, the weights argument is used\n",
270266
"# to make sure our plotted spectrum has the correct y-axis scale\n",
271-
"plt.hist(packet_collection.initial_nus.value, bins=bins, weights=lumin_per_packet / bin_width)\n",
267+
"plt.hist(\n",
268+
" packet_collection.initial_nus.value,\n",
269+
" bins=bins,\n",
270+
" weights=lumin_per_packet / bin_width,\n",
271+
")\n",
272272
"\n",
273273
"# We plot the planck function for comparison\n",
274274
"plt.plot(nus_planck * u.Hz, planck_function(nus_planck * u.Hz))\n",

tardis/transport/montecarlo/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
"parallel": False,
2222
}
2323

24-
from tardis.transport.montecarlo.packet_collections import (
24+
from tardis.transport.montecarlo.packets.packet_collections import (
2525
PacketCollection,
2626
)
27-
from tardis.transport.montecarlo.r_packet import RPacket
27+
from tardis.transport.montecarlo.packets.radiative_packet import RPacket

tardis/transport/montecarlo/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from tardis.transport.montecarlo.montecarlo_transport_state import (
2828
MonteCarloTransportState,
2929
)
30-
from tardis.transport.montecarlo.packet_trackers import (
30+
from tardis.transport.montecarlo.packets.packet_trackers import (
3131
generate_rpacket_last_interaction_tracker_list,
3232
generate_rpacket_tracker_list,
3333
rpacket_trackers_to_dataframe,

tardis/transport/montecarlo/interaction.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
MacroAtomTransitionType,
1616
macro_atom_interaction,
1717
)
18-
from tardis.transport.montecarlo.r_packet import (
18+
from tardis.transport.montecarlo.packets.radiative_packet import (
1919
PacketStatus,
2020
)
2121
from tardis.transport.montecarlo.utils import get_random_mu

tardis/transport/montecarlo/montecarlo_main_loop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@
55

66
from tardis.transport.montecarlo import njit_dict
77
from tardis.transport.montecarlo.configuration import montecarlo_globals
8-
from tardis.transport.montecarlo.packet_collections import (
8+
from tardis.transport.montecarlo.packets.packet_collections import (
99
VPacketCollection,
1010
consolidate_vpacket_tracker,
1111
initialize_last_interaction_tracker,
1212
)
13-
from tardis.transport.montecarlo.r_packet import (
13+
from tardis.transport.montecarlo.packets.radiative_packet import (
1414
PacketStatus,
1515
RPacket,
1616
)

tardis/transport/montecarlo/packet_source.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from tardis import constants as const
88
from tardis.io.hdf_writer_mixin import HDFWriterMixin
9-
from tardis.transport.montecarlo.packet_collections import (
9+
from tardis.transport.montecarlo.packets.packet_collections import (
1010
PacketCollection,
1111
)
1212

0 commit comments

Comments
 (0)