From 743f225bb853d4aedd9c7889a0b5682548f8b2de Mon Sep 17 00:00:00 2001 From: Wolfgang Kerzendorf Date: Tue, 26 Aug 2025 17:03:50 -0400 Subject: [PATCH 1/9] Refactor imports in base.py and test_interaction.py; add LineInteractionType class in interaction.py --- .../opacity_state_numba.py} | 5 ----- tardis/transport/montecarlo/configuration/base.py | 2 +- tardis/transport/montecarlo/interaction.py | 10 +++++++--- tardis/transport/montecarlo/tests/test_interaction.py | 2 +- 4 files changed, 9 insertions(+), 10 deletions(-) rename tardis/{transport/montecarlo/numba_interface.py => opacities/opacity_state_numba.py} (98%) diff --git a/tardis/transport/montecarlo/numba_interface.py b/tardis/opacities/opacity_state_numba.py similarity index 98% rename from tardis/transport/montecarlo/numba_interface.py rename to tardis/opacities/opacity_state_numba.py index d6f54d8c8d9..d6cfc56d38c 100644 --- a/tardis/transport/montecarlo/numba_interface.py +++ b/tardis/opacities/opacity_state_numba.py @@ -1,4 +1,3 @@ -from enum import IntEnum import numpy as np from numba import float64, int64 @@ -283,7 +282,3 @@ def opacity_state_initialize( ) -class LineInteractionType(IntEnum): - SCATTER = 0 - DOWNBRANCH = 1 - MACROATOM = 2 diff --git a/tardis/transport/montecarlo/configuration/base.py b/tardis/transport/montecarlo/configuration/base.py index 0d5d7c409b3..f1d14841220 100644 --- a/tardis/transport/montecarlo/configuration/base.py +++ b/tardis/transport/montecarlo/configuration/base.py @@ -4,7 +4,7 @@ from numba.experimental import jitclass from tardis.transport.montecarlo.configuration import montecarlo_globals -from tardis.transport.montecarlo.numba_interface import ( +from tardis.transport.montecarlo.interaction import ( LineInteractionType, ) diff --git a/tardis/transport/montecarlo/interaction.py b/tardis/transport/montecarlo/interaction.py index 6bbdd54ccbf..8fd1fe6e5c0 100644 --- a/tardis/transport/montecarlo/interaction.py +++ b/tardis/transport/montecarlo/interaction.py @@ -1,3 +1,4 @@ +from enum import IntEnum import numpy as np from numba import njit @@ -13,9 +14,6 @@ MacroAtomTransitionType, macro_atom_interaction, ) -from tardis.transport.montecarlo.numba_interface import ( - LineInteractionType, -) from tardis.transport.montecarlo.r_packet import ( PacketStatus, ) @@ -419,6 +417,12 @@ def thomson_scatter(r_packet, time_explosion, enable_full_relativity): ) +class LineInteractionType(IntEnum): + SCATTER = 0 + DOWNBRANCH = 1 + MACROATOM = 2 + + @njit(**njit_dict_no_parallel) def line_scatter( r_packet, diff --git a/tardis/transport/montecarlo/tests/test_interaction.py b/tardis/transport/montecarlo/tests/test_interaction.py index 152a0663b2a..727a28c7843 100644 --- a/tardis/transport/montecarlo/tests/test_interaction.py +++ b/tardis/transport/montecarlo/tests/test_interaction.py @@ -3,7 +3,7 @@ import pytest import tardis.transport.montecarlo.interaction as interaction -from tardis.transport.montecarlo.numba_interface import ( +from tardis.transport.montecarlo.interaction import ( LineInteractionType, ) From 47e279a6ecd5d6d6f10f4c6922dae89810b74e30 Mon Sep 17 00:00:00 2001 From: Wolfgang Kerzendorf Date: Tue, 26 Aug 2025 17:07:46 -0400 Subject: [PATCH 2/9] Refactor opacity state initialization to use numba implementation; update related tests --- benchmarks/benchmark_base.py | 4 ++-- tardis/opacities/opacity_state_numba.py | 12 ++++++------ tardis/transport/montecarlo/interaction.py | 1 + tardis/transport/montecarlo/tests/conftest.py | 6 +++--- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/benchmarks/benchmark_base.py b/benchmarks/benchmark_base.py index e93a86a9068..a6bc9a6764c 100644 --- a/benchmarks/benchmark_base.py +++ b/benchmarks/benchmark_base.py @@ -17,7 +17,7 @@ MonteCarloConfiguration, ) from tardis.transport.montecarlo.estimators import radfield_mc_estimators -from tardis.transport.montecarlo.numba_interface import opacity_state_initialize +from tardis.opacities.opacity_state_numba import opacity_state_numba_initialize from tardis.transport.montecarlo.packet_collections import VPacketCollection @@ -138,7 +138,7 @@ def verysimple_time_explosion(self): @functools.cached_property def verysimple_opacity_state(self): - return opacity_state_initialize( + return opacity_state_numba_initialize( self.nb_simulation_verysimple.plasma, line_interaction_type="macroatom", disable_line_scattering=self.nb_simulation_verysimple.transport.montecarlo_configuration.DISABLE_LINE_SCATTERING, diff --git a/tardis/opacities/opacity_state_numba.py b/tardis/opacities/opacity_state_numba.py index d6cfc56d38c..af872e0c5f8 100644 --- a/tardis/opacities/opacity_state_numba.py +++ b/tardis/opacities/opacity_state_numba.py @@ -9,7 +9,7 @@ C_SPEED_OF_LIGHT = const.c.to("cm/s").value -opacity_state_spec = [ +opacity_state_numba_spec = [ ("electron_density", float64[:]), ("t_electrons", float64[:]), ("line_list_nu", float64[:]), @@ -35,8 +35,8 @@ ] -@jitclass(opacity_state_spec) -class OpacityState: +@jitclass(opacity_state_numba_spec) +class OpacityStateNumba: def __init__( self, electron_density, @@ -121,7 +121,7 @@ def __getitem__(self, i: slice): OpacityState : a shallow copy of the current instance """ # NOTE: This currently will not work with continuum processes since it does not slice those arrays - return OpacityState( + return OpacityStateNumba( self.electron_density[i], self.t_electrons[i], self.line_list_nu, @@ -147,7 +147,7 @@ def __getitem__(self, i: slice): ) -def opacity_state_initialize( +def opacity_state_numba_initialize( plasma, line_interaction_type, disable_line_scattering, @@ -256,7 +256,7 @@ def opacity_state_initialize( photo_ion_activation_idx = np.zeros(0, dtype=np.int64) k_packet_idx = np.int64(-1) - return OpacityState( + return OpacityStateNumba( electron_densities, t_electrons, line_list_nu, diff --git a/tardis/transport/montecarlo/interaction.py b/tardis/transport/montecarlo/interaction.py index 8fd1fe6e5c0..298fc206f70 100644 --- a/tardis/transport/montecarlo/interaction.py +++ b/tardis/transport/montecarlo/interaction.py @@ -1,4 +1,5 @@ from enum import IntEnum + import numpy as np from numba import njit diff --git a/tardis/transport/montecarlo/tests/conftest.py b/tardis/transport/montecarlo/tests/conftest.py index 9a28726265c..17c1ab6f942 100644 --- a/tardis/transport/montecarlo/tests/conftest.py +++ b/tardis/transport/montecarlo/tests/conftest.py @@ -10,8 +10,8 @@ from tardis.transport.montecarlo.estimators.radfield_mc_estimators import ( RadiationFieldMCEstimators, ) -from tardis.transport.montecarlo.numba_interface import ( - opacity_state_initialize, +from tardis.opacities.opacity_state_numba import ( + opacity_state_numba_initialize, ) from tardis.transport.montecarlo.packet_collections import ( VPacketCollection, @@ -53,7 +53,7 @@ def simple_weighted_packet_source(): @pytest.fixture(scope="package") def verysimple_opacity_state(nb_simulation_verysimple): - return opacity_state_initialize( + return opacity_state_numba_initialize( nb_simulation_verysimple.plasma, line_interaction_type="macroatom", disable_line_scattering=False, From cc2f2de0d8132cfedf4f72fc28e4a73b35cb16d0 Mon Sep 17 00:00:00 2001 From: Wolfgang Kerzendorf Date: Tue, 26 Aug 2025 17:18:47 -0400 Subject: [PATCH 3/9] Refactor OpacityStateNumba class to use Numba types directly; update initialization and documentation --- tardis/opacities/opacity_state_numba.py | 170 +++++++++++++++--------- 1 file changed, 106 insertions(+), 64 deletions(-) diff --git a/tardis/opacities/opacity_state_numba.py b/tardis/opacities/opacity_state_numba.py index af872e0c5f8..600f7579fc3 100644 --- a/tardis/opacities/opacity_state_numba.py +++ b/tardis/opacities/opacity_state_numba.py @@ -1,6 +1,6 @@ +import numba as nb import numpy as np -from numba import float64, int64 from numba.experimental import jitclass from tardis import constants as const @@ -9,75 +9,105 @@ C_SPEED_OF_LIGHT = const.c.to("cm/s").value -opacity_state_numba_spec = [ - ("electron_density", float64[:]), - ("t_electrons", float64[:]), - ("line_list_nu", float64[:]), - ("tau_sobolev", float64[:, :]), - ("transition_probabilities", float64[:, :]), - ("line2macro_level_upper", int64[:]), - ("macro_block_references", int64[:]), - ("transition_type", int64[:]), - ("destination_level_id", int64[:]), - ("transition_line_id", int64[:]), - ("bf_threshold_list_nu", float64[:]), - ("p_fb_deactivation", float64[:, :]), - ("photo_ion_nu_threshold_mins", float64[:]), - ("photo_ion_nu_threshold_maxs", float64[:]), - ("photo_ion_block_references", int64[:]), - ("chi_bf", float64[:, :]), - ("x_sect", float64[:]), - ("phot_nus", float64[:]), - ("ff_opacity_factor", float64[:]), - ("emissivities", float64[:, :]), - ("photo_ion_activation_idx", int64[:]), - ("k_packet_idx", int64), -] - - -@jitclass(opacity_state_numba_spec) +@jitclass class OpacityStateNumba: + electron_density: nb.float64[:] # type: ignore[misc] + t_electrons: nb.float64[:] # type: ignore[misc] + line_list_nu: nb.float64[:] # type: ignore[misc] + tau_sobolev: nb.float64[:, :] # type: ignore[misc] + transition_probabilities: nb.float64[:, :] # type: ignore[misc] + line2macro_level_upper: nb.int64[:] # type: ignore[misc] + macro_block_references: nb.int64[:] # type: ignore[misc] + transition_type: nb.int64[:] # type: ignore[misc] + destination_level_id: nb.int64[:] # type: ignore[misc] + transition_line_id: nb.int64[:] # type: ignore[misc] + bf_threshold_list_nu: nb.float64[:] # type: ignore[misc] + p_fb_deactivation: nb.float64[:, :] # type: ignore[misc] + photo_ion_nu_threshold_mins: nb.float64[:] # type: ignore[misc] + photo_ion_nu_threshold_maxs: nb.float64[:] # type: ignore[misc] + photo_ion_block_references: nb.int64[:] # type: ignore[misc] + chi_bf: nb.float64[:, :] # type: ignore[misc] + x_sect: nb.float64[:] # type: ignore[misc] + phot_nus: nb.float64[:] # type: ignore[misc] + ff_opacity_factor: nb.float64[:] # type: ignore[misc] + emissivities: nb.float64[:, :] # type: ignore[misc] + photo_ion_activation_idx: nb.int64[:] # type: ignore[misc] + k_packet_idx: nb.int64 # type: ignore[misc] + def __init__( self, - electron_density, - t_electrons, - line_list_nu, - tau_sobolev, - transition_probabilities, - line2macro_level_upper, - macro_block_references, - transition_type, - destination_level_id, - transition_line_id, - bf_threshold_list_nu, - p_fb_deactivation, - photo_ion_nu_threshold_mins, - photo_ion_nu_threshold_maxs, - photo_ion_block_references, - chi_bf, - x_sect, - phot_nus, - ff_opacity_factor, - emissivities, - photo_ion_activation_idx, - k_packet_idx, - ): + electron_density: np.ndarray, + t_electrons: np.ndarray, + line_list_nu: np.ndarray, + tau_sobolev: np.ndarray, + transition_probabilities: np.ndarray, + line2macro_level_upper: np.ndarray, + macro_block_references: np.ndarray, + transition_type: np.ndarray, + destination_level_id: np.ndarray, + transition_line_id: np.ndarray, + bf_threshold_list_nu: np.ndarray, + p_fb_deactivation: np.ndarray, + photo_ion_nu_threshold_mins: np.ndarray, + photo_ion_nu_threshold_maxs: np.ndarray, + photo_ion_block_references: np.ndarray, + chi_bf: np.ndarray, + x_sect: np.ndarray, + phot_nus: np.ndarray, + ff_opacity_factor: np.ndarray, + emissivities: np.ndarray, + photo_ion_activation_idx: np.ndarray, + k_packet_idx: int, + ) -> None: """ - Plasma for the Numba code + Initialize Numba-compatible opacity state for Monte Carlo transport. Parameters ---------- electron_density : numpy.ndarray + Electron density in each shell [cm^-3]. t_electrons : numpy.ndarray + Electron temperature in each shell [K]. line_list_nu : numpy.ndarray + Frequencies of spectral lines [Hz]. tau_sobolev : numpy.ndarray + Sobolev optical depths for line transitions. transition_probabilities : numpy.ndarray + Probabilities for macro atom transitions. line2macro_level_upper : numpy.ndarray + Mapping from line indices to macro atom upper levels. macro_block_references : numpy.ndarray + Block references for macro atom data. transition_type : numpy.ndarray + Type identifiers for transitions. destination_level_id : numpy.ndarray + Destination level indices for transitions. transition_line_id : numpy.ndarray + Line indices for transitions. bf_threshold_list_nu : numpy.ndarray + Bound-free threshold frequencies [Hz]. + p_fb_deactivation : numpy.ndarray + Free-bound deactivation probabilities. + photo_ion_nu_threshold_mins : numpy.ndarray + Minimum photoionization threshold frequencies [Hz]. + photo_ion_nu_threshold_maxs : numpy.ndarray + Maximum photoionization threshold frequencies [Hz]. + photo_ion_block_references : numpy.ndarray + Block references for photoionization data. + chi_bf : numpy.ndarray + Bound-free absorption coefficients. + x_sect : numpy.ndarray + Photoionization cross sections [cm^2]. + phot_nus : numpy.ndarray + Photoionization frequencies [Hz]. + ff_opacity_factor : numpy.ndarray + Free-free opacity factors. + emissivities : numpy.ndarray + Emission coefficients for bound-free transitions. + photo_ion_activation_idx : numpy.ndarray + Indices for photoionization activation. + k_packet_idx : int + Index for k-packet handling. """ self.electron_density = electron_density self.t_electrons = t_electrons @@ -110,15 +140,18 @@ def __init__( self.photo_ion_activation_idx = photo_ion_activation_idx self.k_packet_idx = k_packet_idx - def __getitem__(self, i: slice): - """Get a shell or slice of shells of the attributes of the opacity state + def __getitem__(self, i: slice) -> "OpacityStateNumba": + """Get a shell or slice of shells of the attributes of the opacity state. - Args: - i (slice): shell slice. Will fail if slice is int since class only supports array types + Parameters + ---------- + i : slice + Shell slice. Will fail if slice is int since class only supports array types. Returns ------- - OpacityState : a shallow copy of the current instance + OpacityStateNumba + A shallow copy of the current instance with sliced data. """ # NOTE: This currently will not work with continuum processes since it does not slice those arrays return OpacityStateNumba( @@ -149,16 +182,25 @@ def __getitem__(self, i: slice): def opacity_state_numba_initialize( plasma, - line_interaction_type, - disable_line_scattering, -): + line_interaction_type: str, + disable_line_scattering: bool, +) -> OpacityStateNumba: """ - Initialize the OpacityState object and copy over the data over from TARDIS Plasma + Initialize the OpacityStateNumba object and copy data from TARDIS Plasma. Parameters ---------- plasma : tardis.plasma.BasePlasma - line_interaction_type : enum + The plasma object containing atomic and opacity data. + line_interaction_type : str + Type of line interaction ("scatter" or "macroatom"). + disable_line_scattering : bool + Whether to disable line scattering by setting tau_sobolev to zero. + + Returns + ------- + OpacityStateNumba + Initialized opacity state for Monte Carlo transport. """ electron_densities = plasma.electron_densities.values t_electrons = plasma.t_electrons @@ -241,7 +283,7 @@ def opacity_state_numba_initialize( photo_ion_activation_idx = plasma.photo_ion_idx.loc[ plasma.level2continuum_idx.index, "destination_level_idx" ].values - k_packet_idx = np.int64(plasma.k_packet_idx) + k_packet_idx = int(plasma.k_packet_idx) else: bf_threshold_list_nu = np.zeros(0, dtype=np.float64) p_fb_deactivation = np.zeros((0, 0), dtype=np.float64) @@ -254,7 +296,7 @@ def opacity_state_numba_initialize( ff_opacity_factor = np.zeros(0, dtype=np.float64) emissivities = np.zeros((0, 0), dtype=np.float64) photo_ion_activation_idx = np.zeros(0, dtype=np.int64) - k_packet_idx = np.int64(-1) + k_packet_idx = -1 return OpacityStateNumba( electron_densities, From af7f3b0ab454e9fcad0ec95e192e023c5e1fa394 Mon Sep 17 00:00:00 2001 From: Wolfgang Kerzendorf Date: Tue, 26 Aug 2025 17:50:10 -0400 Subject: [PATCH 4/9] Refactor packet and tracker classes for Monte Carlo simulation; implement RPacket and VPacket with associated methods --- tardis/transport/montecarlo/packets/__init__.py | 0 tardis/transport/montecarlo/{ => packets}/packet_collections.py | 0 tardis/transport/montecarlo/{ => packets}/packet_trackers.py | 0 .../montecarlo/{r_packet.py => packets/radiative_packet.py} | 0 .../montecarlo/{vpacket.py => packets/virtual_packet.py} | 2 +- 5 files changed, 1 insertion(+), 1 deletion(-) create mode 100644 tardis/transport/montecarlo/packets/__init__.py rename tardis/transport/montecarlo/{ => packets}/packet_collections.py (100%) rename tardis/transport/montecarlo/{ => packets}/packet_trackers.py (100%) rename tardis/transport/montecarlo/{r_packet.py => packets/radiative_packet.py} (100%) rename tardis/transport/montecarlo/{vpacket.py => packets/virtual_packet.py} (99%) diff --git a/tardis/transport/montecarlo/packets/__init__.py b/tardis/transport/montecarlo/packets/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tardis/transport/montecarlo/packet_collections.py b/tardis/transport/montecarlo/packets/packet_collections.py similarity index 100% rename from tardis/transport/montecarlo/packet_collections.py rename to tardis/transport/montecarlo/packets/packet_collections.py diff --git a/tardis/transport/montecarlo/packet_trackers.py b/tardis/transport/montecarlo/packets/packet_trackers.py similarity index 100% rename from tardis/transport/montecarlo/packet_trackers.py rename to tardis/transport/montecarlo/packets/packet_trackers.py diff --git a/tardis/transport/montecarlo/r_packet.py b/tardis/transport/montecarlo/packets/radiative_packet.py similarity index 100% rename from tardis/transport/montecarlo/r_packet.py rename to tardis/transport/montecarlo/packets/radiative_packet.py diff --git a/tardis/transport/montecarlo/vpacket.py b/tardis/transport/montecarlo/packets/virtual_packet.py similarity index 99% rename from tardis/transport/montecarlo/vpacket.py rename to tardis/transport/montecarlo/packets/virtual_packet.py index 657e1c016d0..27a6141d355 100644 --- a/tardis/transport/montecarlo/vpacket.py +++ b/tardis/transport/montecarlo/packets/virtual_packet.py @@ -22,7 +22,7 @@ C_SPEED_OF_LIGHT, SIGMA_THOMSON, ) -from tardis.transport.montecarlo.r_packet import ( +from tardis.transport.montecarlo.packets.radiative_packet import ( PacketStatus, ) from tardis.transport.montecarlo.r_packet_transport import ( From 726de1a05eda37b57a30710d685406f561cda2a0 Mon Sep 17 00:00:00 2001 From: Wolfgang Kerzendorf Date: Tue, 26 Aug 2025 18:30:10 -0400 Subject: [PATCH 5/9] Refactor Monte Carlo transport module to reorganize packet structure - Moved packet-related classes and functions into a new 'packets' submodule for better organization. - Updated import statements throughout the codebase to reflect the new structure. - Enhanced type annotations and docstrings for clarity and improved code readability. - Refactored the PacketCollection, LastInteractionTracker, RPacket, VPacket, and associated tracker classes to utilize Numba's JIT compilation more effectively. - Improved the handling of packet properties and interactions, ensuring consistency across the Monte Carlo transport processes. - Updated tests to accommodate the new structure and ensure all functionalities remain intact. --- benchmarks/benchmark_base.py | 5 +- .../transport_montecarlo_packet_trackers.py | 2 +- benchmarks/transport_montecarlo_vpacket.py | 16 +- tardis/transport/montecarlo/__init__.py | 4 +- tardis/transport/montecarlo/base.py | 2 +- tardis/transport/montecarlo/interaction.py | 2 +- .../montecarlo/montecarlo_main_loop.py | 4 +- .../montecarlo/packets/packet_collections.py | 241 +++++++++++------- .../montecarlo/packets/packet_trackers.py | 68 +++-- .../montecarlo/packets/radiative_packet.py | 102 +++++--- .../montecarlo/packets/virtual_packet.py | 70 +++-- .../montecarlo/r_packet_transport.py | 2 +- .../montecarlo/single_packet_loop.py | 4 +- tardis/transport/montecarlo/tests/conftest.py | 2 +- .../montecarlo/tests/test_montecarlo.py | 56 ++-- .../transport/montecarlo/tests/test_packet.py | 6 +- .../test_rpacket_last_interaction_tracker.py | 4 +- .../montecarlo/tests/test_rpacket_tracker.py | 4 +- .../montecarlo/tests/test_tracker_utils.py | 2 +- .../montecarlo/tests/test_vpacket.py | 14 +- 20 files changed, 358 insertions(+), 252 deletions(-) diff --git a/benchmarks/benchmark_base.py b/benchmarks/benchmark_base.py index a6bc9a6764c..1a11fff0f71 100644 --- a/benchmarks/benchmark_base.py +++ b/benchmarks/benchmark_base.py @@ -12,13 +12,14 @@ from tardis.model.geometry.radial1d import NumbaRadial1DGeometry from tardis.simulation import Simulation from tardis.tests.fixtures.atom_data import DEFAULT_ATOM_DATA_MD5 -from tardis.transport.montecarlo import RPacket, packet_trackers +from tardis.transport.montecarlo import RPacket from tardis.transport.montecarlo.configuration.base import ( MonteCarloConfiguration, ) from tardis.transport.montecarlo.estimators import radfield_mc_estimators from tardis.opacities.opacity_state_numba import opacity_state_numba_initialize -from tardis.transport.montecarlo.packet_collections import VPacketCollection +from tardis.transport.montecarlo.packets.packet_collections import VPacketCollection +from tardis.transport.montecarlo.packets import packet_trackers class BenchmarkBase: diff --git a/benchmarks/transport_montecarlo_packet_trackers.py b/benchmarks/transport_montecarlo_packet_trackers.py index 26739847bc6..9fd40ea384d 100644 --- a/benchmarks/transport_montecarlo_packet_trackers.py +++ b/benchmarks/transport_montecarlo_packet_trackers.py @@ -6,7 +6,7 @@ from asv_runner.benchmarks.mark import parameterize from benchmarks.benchmark_base import BenchmarkBase -from tardis.transport.montecarlo import packet_trackers +from tardis.transport.montecarlo.packets import packet_trackers class BenchmarkTransportMontecarloPacketTrackers(BenchmarkBase): diff --git a/benchmarks/transport_montecarlo_vpacket.py b/benchmarks/transport_montecarlo_vpacket.py index b237577074d..aafb2673075 100644 --- a/benchmarks/transport_montecarlo_vpacket.py +++ b/benchmarks/transport_montecarlo_vpacket.py @@ -6,10 +6,10 @@ import numpy as np -import tardis.transport.montecarlo.vpacket as vpacket +import tardis.transport.montecarlo.packets.virtual_packet as virtual_packet from benchmarks.benchmark_base import BenchmarkBase from tardis.transport.frame_transformations import get_doppler_factor -from tardis.transport.montecarlo.r_packet import RPacket +from tardis.transport.montecarlo.packets.radiative_packet import RPacket class BenchmarkMontecarloMontecarloNumbaVpacket(BenchmarkBase): @@ -30,7 +30,7 @@ def setup(self): @functools.cached_property def v_packet(self): - return vpacket.VPacket( + return virtual_packet.VPacket( r=7.5e14, nu=4e15, mu=0.3, @@ -74,7 +74,7 @@ def time_trace_vpacket_within_shell(self): self.enable_full_relativity, ) - vpacket.trace_vpacket_within_shell( + virtual_packet.trace_vpacket_within_shell( self.vpacket, self.numba_radial_1d_geometry, self.time_explosion, @@ -94,7 +94,7 @@ def time_trace_vpacket(self): self.enable_full_relativity, ) - vpacket.trace_vpacket( + virtual_packet.trace_vpacket( self.vpacket, self.numba_radial_1d_geometry, self.time_explosion, @@ -106,7 +106,7 @@ def time_trace_vpacket(self): @functools.cached_property def broken_packet(self): - return vpacket.VPacket( + return virtual_packet.VPacket( r=1286064000000000.0, nu=1660428912896553.2, mu=0.4916053094346575, @@ -119,7 +119,7 @@ def broken_packet(self): def time_trace_bad_vpacket(self): broken_packet = self.broken_packet - vpacket.trace_vpacket( + virtual_packet.trace_vpacket( broken_packet, self.numba_radial_1d_geometry, self.time_explosion, @@ -130,7 +130,7 @@ def time_trace_bad_vpacket(self): ) def time_trace_vpacket_volley(self): - vpacket.trace_vpacket_volley( + virtual_packet.trace_vpacket_volley( self.r_packet, self.verysimple_3vpacket_collection, self.numba_radial_1d_geometry, diff --git a/tardis/transport/montecarlo/__init__.py b/tardis/transport/montecarlo/__init__.py index 8cb07e87401..4dbbffc06a6 100644 --- a/tardis/transport/montecarlo/__init__.py +++ b/tardis/transport/montecarlo/__init__.py @@ -21,7 +21,7 @@ "parallel": False, } -from tardis.transport.montecarlo.packet_collections import ( +from tardis.transport.montecarlo.packets.packet_collections import ( PacketCollection, ) -from tardis.transport.montecarlo.r_packet import RPacket +from tardis.transport.montecarlo.packets.radiative_packet import RPacket diff --git a/tardis/transport/montecarlo/base.py b/tardis/transport/montecarlo/base.py index 0c2ab31a777..7bd404eddba 100644 --- a/tardis/transport/montecarlo/base.py +++ b/tardis/transport/montecarlo/base.py @@ -27,7 +27,7 @@ from tardis.transport.montecarlo.montecarlo_transport_state import ( MonteCarloTransportState, ) -from tardis.transport.montecarlo.packet_trackers import ( +from tardis.transport.montecarlo.packets.packet_trackers import ( generate_rpacket_last_interaction_tracker_list, generate_rpacket_tracker_list, rpacket_trackers_to_dataframe, diff --git a/tardis/transport/montecarlo/interaction.py b/tardis/transport/montecarlo/interaction.py index 298fc206f70..e55d9df06e8 100644 --- a/tardis/transport/montecarlo/interaction.py +++ b/tardis/transport/montecarlo/interaction.py @@ -15,7 +15,7 @@ MacroAtomTransitionType, macro_atom_interaction, ) -from tardis.transport.montecarlo.r_packet import ( +from tardis.transport.montecarlo.packets.radiative_packet import ( PacketStatus, ) from tardis.transport.montecarlo.utils import get_random_mu diff --git a/tardis/transport/montecarlo/montecarlo_main_loop.py b/tardis/transport/montecarlo/montecarlo_main_loop.py index e5daa72c89d..2c4837782a6 100644 --- a/tardis/transport/montecarlo/montecarlo_main_loop.py +++ b/tardis/transport/montecarlo/montecarlo_main_loop.py @@ -5,12 +5,12 @@ from tardis.transport.montecarlo import njit_dict from tardis.transport.montecarlo.configuration import montecarlo_globals -from tardis.transport.montecarlo.packet_collections import ( +from tardis.transport.montecarlo.packets.packet_collections import ( VPacketCollection, consolidate_vpacket_tracker, initialize_last_interaction_tracker, ) -from tardis.transport.montecarlo.r_packet import ( +from tardis.transport.montecarlo.packets.radiative_packet import ( PacketStatus, RPacket, ) diff --git a/tardis/transport/montecarlo/packets/packet_collections.py b/tardis/transport/montecarlo/packets/packet_collections.py index 2f73e083c0d..5f9c65eb81e 100644 --- a/tardis/transport/montecarlo/packets/packet_collections.py +++ b/tardis/transport/montecarlo/packets/packet_collections.py @@ -1,35 +1,49 @@ +import numba as nb import numpy as np -from numba import float64, int64, njit +from numba import njit from numba.experimental import jitclass -from tardis.transport.montecarlo import ( - njit_dict_no_parallel, -) +from tardis.transport.montecarlo import njit_dict_no_parallel -packet_collection_spec = [ - ("initial_radii", float64[:]), - ("initial_nus", float64[:]), - ("initial_mus", float64[:]), - ("initial_energies", float64[:]), - ("packet_seeds", int64[:]), - ("time_of_simulation", float64), - ("radiation_field_luminosity", float64), - ("output_nus", float64[:]), - ("output_energies", float64[:]), -] - - -@jitclass(packet_collection_spec) +@jitclass class PacketCollection: + initial_radii: nb.float64[:] # type: ignore[misc] + initial_nus: nb.float64[:] # type: ignore[misc] + initial_mus: nb.float64[:] # type: ignore[misc] + initial_energies: nb.float64[:] # type: ignore[misc] + packet_seeds: nb.int64[:] # type: ignore[misc] + time_of_simulation: nb.float64 # type: ignore[misc] + radiation_field_luminosity: nb.float64 # type: ignore[misc] + output_nus: nb.float64[:] # type: ignore[misc] + output_energies: nb.float64[:] # type: ignore[misc] + def __init__( self, - initial_radii, - initial_nus, - initial_mus, - initial_energies, - packet_seeds, - radiation_field_luminosity, - ): + initial_radii: np.ndarray, + initial_nus: np.ndarray, + initial_mus: np.ndarray, + initial_energies: np.ndarray, + packet_seeds: np.ndarray, + radiation_field_luminosity: float, + ) -> None: + """ + Initialize Numba-compatible packet collection for Monte Carlo transport. + + Parameters + ---------- + initial_radii : numpy.ndarray + Initial radii of packets [cm]. + initial_nus : numpy.ndarray + Initial frequencies of packets [Hz]. + initial_mus : numpy.ndarray + Initial directional cosines of packets. + initial_energies : numpy.ndarray + Initial energies of packets [erg]. + packet_seeds : numpy.ndarray + Random number seeds for packets. + radiation_field_luminosity : float + Luminosity of the radiation field [erg/s]. + """ self.initial_radii = initial_radii self.initial_nus = initial_nus self.initial_mus = initial_mus @@ -45,7 +59,15 @@ def __init__( ) @property - def number_of_packets(self): + def number_of_packets(self) -> int: + """ + Get the number of packets in the collection. + + Returns + ------- + int + Number of packets. + """ return len(self.initial_radii) @@ -70,27 +92,42 @@ def initialize_last_interaction_tracker(no_of_packets): ) -last_interaction_tracker_spec = [ - ("types", int64[:]), - ("in_nus", float64[:]), - ("in_rs", float64[:]), - ("in_ids", int64[:]), - ("out_ids", int64[:]), - ("shell_ids", int64[:]), -] - - -@jitclass(last_interaction_tracker_spec) +@jitclass class LastInteractionTracker: + types: nb.int64[:] # type: ignore[misc] + in_nus: nb.float64[:] # type: ignore[misc] + in_rs: nb.float64[:] # type: ignore[misc] + in_ids: nb.int64[:] # type: ignore[misc] + out_ids: nb.int64[:] # type: ignore[misc] + shell_ids: nb.int64[:] # type: ignore[misc] + def __init__( self, - types, - in_nus, - in_rs, - in_ids, - out_ids, - shell_ids, - ): + types: np.ndarray, + in_nus: np.ndarray, + in_rs: np.ndarray, + in_ids: np.ndarray, + out_ids: np.ndarray, + shell_ids: np.ndarray, + ) -> None: + """ + Initialize last interaction tracker for Monte Carlo packets. + + Parameters + ---------- + types : numpy.ndarray + Types of last interactions. + in_nus : numpy.ndarray + Incoming frequencies of last interactions [Hz]. + in_rs : numpy.ndarray + Radii of last interactions [cm]. + in_ids : numpy.ndarray + Input line IDs for last interactions. + out_ids : numpy.ndarray + Output line IDs for last interactions. + shell_ids : numpy.ndarray + Shell IDs where last interactions occurred. + """ self.types = types self.in_nus = in_nus self.in_rs = in_rs @@ -98,7 +135,17 @@ def __init__( self.out_ids = out_ids self.shell_ids = shell_ids - def update_last_interaction(self, r_packet, i): + def update_last_interaction(self, r_packet, i: int) -> None: + """ + Update the last interaction information for a packet. + + Parameters + ---------- + r_packet : RPacket + The R-packet with interaction information. + i : int + Index of the packet to update. + """ self.types[i] = r_packet.last_interaction_type self.in_nus[i] = r_packet.last_interaction_in_nu self.in_rs[i] = r_packet.last_interaction_in_r @@ -107,38 +154,53 @@ def update_last_interaction(self, r_packet, i): self.shell_ids[i] = r_packet.last_line_interaction_shell_id -vpacket_collection_spec = [ - ("source_rpacket_index", int64), - ("spectrum_frequency_grid", float64[:]), - ("v_packet_spawn_start_frequency", float64), - ("v_packet_spawn_end_frequency", float64), - ("nus", float64[:]), - ("energies", float64[:]), - ("initial_mus", float64[:]), - ("initial_rs", float64[:]), - ("idx", int64), - ("number_of_vpackets", int64), - ("length", int64), - ("last_interaction_in_nu", float64[:]), - ("last_interaction_in_r", float64[:]), - ("last_interaction_type", int64[:]), - ("last_interaction_in_id", int64[:]), - ("last_interaction_out_id", int64[:]), - ("last_interaction_shell_id", int64[:]), -] - - -@jitclass(vpacket_collection_spec) +@jitclass class VPacketCollection: + source_rpacket_index: nb.int64 # type: ignore[misc] + spectrum_frequency_grid: nb.float64[:] # type: ignore[misc] + v_packet_spawn_start_frequency: nb.float64 # type: ignore[misc] + v_packet_spawn_end_frequency: nb.float64 # type: ignore[misc] + nus: nb.float64[:] # type: ignore[misc] + energies: nb.float64[:] # type: ignore[misc] + initial_mus: nb.float64[:] # type: ignore[misc] + initial_rs: nb.float64[:] # type: ignore[misc] + idx: nb.int64 # type: ignore[misc] + number_of_vpackets: nb.int64 # type: ignore[misc] + length: nb.int64 # type: ignore[misc] + last_interaction_in_nu: nb.float64[:] # type: ignore[misc] + last_interaction_in_r: nb.float64[:] # type: ignore[misc] + last_interaction_type: nb.int64[:] # type: ignore[misc] + last_interaction_in_id: nb.int64[:] # type: ignore[misc] + last_interaction_out_id: nb.int64[:] # type: ignore[misc] + last_interaction_shell_id: nb.int64[:] # type: ignore[misc] + def __init__( self, - source_rpacket_index, - spectrum_frequency_grid, - v_packet_spawn_start_frequency, - v_packet_spawn_end_frequency, - number_of_vpackets, - temporary_v_packet_bins, - ): + source_rpacket_index: int, + spectrum_frequency_grid: np.ndarray, + v_packet_spawn_start_frequency: float, + v_packet_spawn_end_frequency: float, + number_of_vpackets: int, + temporary_v_packet_bins: int, + ) -> None: + """ + Initialize virtual packet collection for Monte Carlo transport. + + Parameters + ---------- + source_rpacket_index : int + Index of the source R-packet. + spectrum_frequency_grid : numpy.ndarray + Frequency grid for spectrum calculation [Hz]. + v_packet_spawn_start_frequency : float + Start frequency for virtual packet spawning [Hz]. + v_packet_spawn_end_frequency : float + End frequency for virtual packet spawning [Hz]. + number_of_vpackets : int + Number of virtual packets to generate. + temporary_v_packet_bins : int + Initial size of temporary storage arrays. + """ self.spectrum_frequency_grid = spectrum_frequency_grid self.v_packet_spawn_start_frequency = v_packet_spawn_start_frequency self.v_packet_spawn_end_frequency = v_packet_spawn_end_frequency @@ -171,17 +233,17 @@ def __init__( def add_packet( self, - nu, - energy, - initial_mu, - initial_r, - last_interaction_in_nu, - last_interaction_in_r, - last_interaction_type, - last_interaction_in_id, - last_interaction_out_id, - last_interaction_shell_id, - ): + nu: float, + energy: float, + initial_mu: float, + initial_r: float, + last_interaction_in_nu: float, + last_interaction_in_r: float, + last_interaction_type: int, + last_interaction_in_id: int, + last_interaction_out_id: int, + last_interaction_shell_id: int, + ) -> None: """ Add a packet to the vpacket collection and potentially resizing the arrays. @@ -277,7 +339,7 @@ def add_packet( self.last_interaction_shell_id[self.idx] = last_interaction_shell_id self.idx += 1 - def finalize_arrays(self): + def finalize_arrays(self) -> None: """ Finalize the arrays by truncating them based on the current index. @@ -302,8 +364,11 @@ def finalize_arrays(self): @njit(**njit_dict_no_parallel) def consolidate_vpacket_tracker( - vpacket_collections, spectrum_frequency_grid, start_frequency, end_frequency -): + vpacket_collections, + spectrum_frequency_grid: np.ndarray, + start_frequency: float, + end_frequency: float, +) -> "VPacketCollection": """ Consolidate the vpacket trackers from multiple collections into a single vpacket tracker. diff --git a/tardis/transport/montecarlo/packets/packet_trackers.py b/tardis/transport/montecarlo/packets/packet_trackers.py index baf0167e605..80ccfad9a54 100644 --- a/tardis/transport/montecarlo/packets/packet_trackers.py +++ b/tardis/transport/montecarlo/packets/packet_trackers.py @@ -1,6 +1,7 @@ +import numba as nb import numpy as np import pandas as pd -from numba import float64, from_dtype, int64, njit +from numba import from_dtype, njit from numba.experimental import jitclass from numba.typed import List @@ -13,26 +14,22 @@ ) -rpacket_tracker_spec = [ - ("seed", int64), - ("index", int64), - ("status", int64[:]), - ("r", float64[:]), - ("nu", float64[:]), - ("mu", float64[:]), - ("energy", float64[:]), - ("shell_id", int64[:]), - ("interaction_type", int64[:]), - ("boundary_interaction", from_dtype(boundary_interaction_dtype)[:]), - ("num_interactions", int64), - ("boundary_interactions_index", int64), - ("event_id", int64), - ("extend_factor", int64), -] - - -@jitclass(rpacket_tracker_spec) +@jitclass class RPacketTracker: + seed: nb.int64 # type: ignore[misc] + index: nb.int64 # type: ignore[misc] + status: nb.int64[:] # type: ignore[misc] + r: nb.float64[:] # type: ignore[misc] + nu: nb.float64[:] # type: ignore[misc] + mu: nb.float64[:] # type: ignore[misc] + energy: nb.float64[:] # type: ignore[misc] + shell_id: nb.int64[:] # type: ignore[misc] + interaction_type: nb.int64[:] # type: ignore[misc] + boundary_interaction: from_dtype(boundary_interaction_dtype)[:] # type: ignore[misc] + num_interactions: nb.int64 # type: ignore[misc] + boundary_interactions_index: nb.int64 # type: ignore[misc] + event_id: nb.int64 # type: ignore[misc] + extend_factor: nb.int64 # type: ignore[misc] """ Numba JITCLASS for storing the information for each interaction a RPacket instance undergoes. @@ -64,7 +61,7 @@ class RPacketTracker: The factor by which to extend the properties array when the size limit is reached """ - def __init__(self, length): + def __init__(self, length: int) -> None: """ Initialize the variables with default value """ @@ -209,18 +206,14 @@ def rpacket_trackers_to_dataframe(rpacket_trackers): ) -rpacket_last_interaction_tracker_spec = [ - ("index", int64), - ("r", float64), - ("nu", float64), - ("energy", float64), - ("shell_id", int64), - ("interaction_type", int64), -] - - -@jitclass(rpacket_last_interaction_tracker_spec) +@jitclass class RPacketLastInteractionTracker: + index: nb.int64 # type: ignore[misc] + r: nb.float64 # type: ignore[misc] + nu: nb.float64 # type: ignore[misc] + energy: nb.float64 # type: ignore[misc] + shell_id: nb.int64 # type: ignore[misc] + interaction_type: nb.int64 # type: ignore[misc] """ Numba JITCLASS for storing the last interaction the RPacket undergoes. @@ -240,7 +233,7 @@ class RPacketLastInteractionTracker: Type of interaction the rpacket undergoes """ - def __init__(self): + def __init__(self) -> None: """ Initialize properties with default values """ @@ -251,9 +244,14 @@ def __init__(self): self.shell_id = -1 self.interaction_type = -1 - def track(self, r_packet): + def track(self, r_packet) -> None: """ - Track properties of RPacket and override the previous values + Track properties of RPacket and override the previous values. + + Parameters + ---------- + r_packet : RPacket + The R-packet to track. """ self.index = r_packet.index self.r = r_packet.r diff --git a/tardis/transport/montecarlo/packets/radiative_packet.py b/tardis/transport/montecarlo/packets/radiative_packet.py index 55c28ae6504..3144bd2a9a5 100644 --- a/tardis/transport/montecarlo/packets/radiative_packet.py +++ b/tardis/transport/montecarlo/packets/radiative_packet.py @@ -1,15 +1,11 @@ -from enum import IntEnum - +import numba as nb import numpy as np -from numba import float64, int64, njit, objmode +from enum import IntEnum +from numba import njit, objmode from numba.experimental import jitclass -from tardis.transport.frame_transformations import ( - get_doppler_factor, -) -from tardis.transport.montecarlo import ( - njit_dict_no_parallel, -) +from tardis.transport.frame_transformations import get_doppler_factor +from tardis.transport.montecarlo import njit_dict_no_parallel class InteractionType(IntEnum): @@ -26,28 +22,51 @@ class PacketStatus(IntEnum): ADIABATIC_COOLING = 4 -rpacket_spec = [ - ("r", float64), - ("mu", float64), - ("nu", float64), - ("energy", float64), - ("next_line_id", int64), - ("current_shell_id", int64), - ("status", int64), - ("seed", int64), - ("index", int64), - ("last_interaction_type", int64), - ("last_interaction_in_nu", float64), - ("last_interaction_in_r", float64), - ("last_line_interaction_in_id", int64), - ("last_line_interaction_out_id", int64), - ("last_line_interaction_shell_id", int64), -] +@jitclass +class RPacket: + r: nb.float64 # type: ignore[misc] + mu: nb.float64 # type: ignore[misc] + nu: nb.float64 # type: ignore[misc] + energy: nb.float64 # type: ignore[misc] + next_line_id: nb.int64 # type: ignore[misc] + current_shell_id: nb.int64 # type: ignore[misc] + status: nb.int64 # type: ignore[misc] + seed: nb.int64 # type: ignore[misc] + index: nb.int64 # type: ignore[misc] + last_interaction_type: nb.int64 # type: ignore[misc] + last_interaction_in_nu: nb.float64 # type: ignore[misc] + last_interaction_in_r: nb.float64 # type: ignore[misc] + last_line_interaction_in_id: nb.int64 # type: ignore[misc] + last_line_interaction_out_id: nb.int64 # type: ignore[misc] + last_line_interaction_shell_id: nb.int64 # type: ignore[misc] + def __init__( + self, + r: float, + mu: float, + nu: float, + energy: float, + seed: int, + index: int = 0, + ) -> None: + """ + Initialize radiative packet for Monte Carlo transport. -@jitclass(rpacket_spec) -class RPacket: - def __init__(self, r, mu, nu, energy, seed, index=0): + Parameters + ---------- + r : float + Initial radius [cm]. + mu : float + Initial directional cosine. + nu : float + Initial frequency [Hz]. + energy : float + Initial energy [erg]. + seed : int + Random number seed. + index : int, optional + Packet index, by default 0. + """ self.r = r self.mu = mu self.nu = nu @@ -82,20 +101,29 @@ def initialize_line_id( @njit(**njit_dict_no_parallel) def print_r_packet_properties(r_packet): """ - Print all packet information + Print all packet information. Parameters ---------- r_packet : RPacket - RPacket object + RPacket object. """ print("-" * 80) print("R-Packet information:") with objmode: - for r_packet_attribute_name, _ in rpacket_spec: - print( - r_packet_attribute_name, - "=", - str(getattr(r_packet, r_packet_attribute_name)), - ) + print("r =", str(r_packet.r)) + print("mu =", str(r_packet.mu)) + print("nu =", str(r_packet.nu)) + print("energy =", str(r_packet.energy)) + print("next_line_id =", str(r_packet.next_line_id)) + print("current_shell_id =", str(r_packet.current_shell_id)) + print("status =", str(r_packet.status)) + print("seed =", str(r_packet.seed)) + print("index =", str(r_packet.index)) + print("last_interaction_type =", str(r_packet.last_interaction_type)) + print("last_interaction_in_nu =", str(r_packet.last_interaction_in_nu)) + print("last_interaction_in_r =", str(r_packet.last_interaction_in_r)) + print("last_line_interaction_in_id =", str(r_packet.last_line_interaction_in_id)) + print("last_line_interaction_out_id =", str(r_packet.last_line_interaction_out_id)) + print("last_line_interaction_shell_id =", str(r_packet.last_line_interaction_shell_id)) print("-" * 80) diff --git a/tardis/transport/montecarlo/packets/virtual_packet.py b/tardis/transport/montecarlo/packets/virtual_packet.py index 27a6141d355..c7603e9b3ac 100644 --- a/tardis/transport/montecarlo/packets/virtual_packet.py +++ b/tardis/transport/montecarlo/packets/virtual_packet.py @@ -1,13 +1,12 @@ import math +import numba as nb import numpy as np -from numba import float64, int64, njit +from numba import njit from numba.experimental import jitclass import tardis.transport.montecarlo.configuration.montecarlo_globals as montecarlo_globals -from tardis.opacities.opacities import ( - chi_continuum_calculator, -) +from tardis.opacities.opacities import chi_continuum_calculator from tardis.transport.frame_transformations import ( angle_aberration_CMF_to_LF, angle_aberration_LF_to_CMF, @@ -22,37 +21,52 @@ C_SPEED_OF_LIGHT, SIGMA_THOMSON, ) -from tardis.transport.montecarlo.packets.radiative_packet import ( - PacketStatus, -) +from tardis.transport.montecarlo.packets.radiative_packet import PacketStatus from tardis.transport.montecarlo.r_packet_transport import ( move_packet_across_shell_boundary, ) -vpacket_spec = [ - ("r", float64), - ("mu", float64), - ("nu", float64), - ("energy", float64), - ("next_line_id", int64), - ("current_shell_id", int64), - ("status", int64), - ("index", int64), -] - - -@jitclass(vpacket_spec) +@jitclass class VPacket: + r: nb.float64 # type: ignore[misc] + mu: nb.float64 # type: ignore[misc] + nu: nb.float64 # type: ignore[misc] + energy: nb.float64 # type: ignore[misc] + next_line_id: nb.int64 # type: ignore[misc] + current_shell_id: nb.int64 # type: ignore[misc] + status: nb.int64 # type: ignore[misc] + index: nb.int64 # type: ignore[misc] + def __init__( self, - r, - mu, - nu, - energy, - current_shell_id, - next_line_id, - index=0, - ): + r: float, + mu: float, + nu: float, + energy: float, + current_shell_id: int, + next_line_id: int, + index: int = 0, + ) -> None: + """ + Initialize virtual packet for Monte Carlo transport. + + Parameters + ---------- + r : float + Initial radius [cm]. + mu : float + Initial directional cosine. + nu : float + Initial frequency [Hz]. + energy : float + Initial energy [erg]. + current_shell_id : int + Current shell index. + next_line_id : int + Next line interaction index. + index : int, optional + Packet index, by default 0. + """ self.r = r self.mu = mu self.nu = nu diff --git a/tardis/transport/montecarlo/r_packet_transport.py b/tardis/transport/montecarlo/r_packet_transport.py index 3df8939feff..aecfb8e749d 100644 --- a/tardis/transport/montecarlo/r_packet_transport.py +++ b/tardis/transport/montecarlo/r_packet_transport.py @@ -14,7 +14,7 @@ update_base_estimators, update_line_estimators, ) -from tardis.transport.montecarlo.r_packet import ( +from tardis.transport.montecarlo.packets.radiative_packet import ( InteractionType, PacketStatus, ) diff --git a/tardis/transport/montecarlo/single_packet_loop.py b/tardis/transport/montecarlo/single_packet_loop.py index 73ab86d3599..0c48ed4d1d6 100644 --- a/tardis/transport/montecarlo/single_packet_loop.py +++ b/tardis/transport/montecarlo/single_packet_loop.py @@ -18,7 +18,7 @@ line_scatter, thomson_scatter, ) -from tardis.transport.montecarlo.r_packet import ( +from tardis.transport.montecarlo.packets.radiative_packet import ( InteractionType, PacketStatus, RPacket, @@ -28,7 +28,7 @@ move_r_packet, trace_packet, ) -from tardis.transport.montecarlo.vpacket import trace_vpacket_volley +from tardis.transport.montecarlo.packets.virtual_packet import trace_vpacket_volley C_SPEED_OF_LIGHT = const.c.to("cm/s").value diff --git a/tardis/transport/montecarlo/tests/conftest.py b/tardis/transport/montecarlo/tests/conftest.py index 17c1ab6f942..d29dcfe8b7c 100644 --- a/tardis/transport/montecarlo/tests/conftest.py +++ b/tardis/transport/montecarlo/tests/conftest.py @@ -13,7 +13,7 @@ from tardis.opacities.opacity_state_numba import ( opacity_state_numba_initialize, ) -from tardis.transport.montecarlo.packet_collections import ( +from tardis.transport.montecarlo.packets.packet_collections import ( VPacketCollection, ) from tardis.transport.montecarlo.weighted_packet_source import ( diff --git a/tardis/transport/montecarlo/tests/test_montecarlo.py b/tardis/transport/montecarlo/tests/test_montecarlo.py index 663e3dbb990..6814a8b578e 100644 --- a/tardis/transport/montecarlo/tests/test_montecarlo.py +++ b/tardis/transport/montecarlo/tests/test_montecarlo.py @@ -4,7 +4,7 @@ import pandas as pd import pytest -import tardis.transport.montecarlo.r_packet as r_packet +import tardis.transport.montecarlo.packets.radiative_packet as radiative_packet import tardis.transport.montecarlo.r_packet_transport as r_packet_transport import tardis.transport.montecarlo.utils as utils from tardis import constants as const @@ -16,7 +16,7 @@ from tardis.transport.montecarlo.estimators.radfield_mc_estimators import ( RadiationFieldMCEstimators, ) -from tardis.transport.montecarlo.packet_trackers import RPacketTracker +from tardis.transport.montecarlo.packets.packet_trackers import RPacketTracker C_SPEED_OF_LIGHT = const.c.to("cm/s").value @@ -96,8 +96,8 @@ def test_get_random_mu_different_output(): """ Ensure that different calls results """ - output1 = r_packet.get_random_mu() - output2 = r_packet.get_random_mu() + output1 = radiative_packet.get_random_mu() + output2 = radiative_packet.get_random_mu() assert output1 != output2 @@ -105,8 +105,8 @@ def test_get_random_mu_different_output(): """ Ensure that different calls results """ - output1 = r_packet.get_random_mu() - output2 = r_packet.get_random_mu() + output1 = radiative_packet.get_random_mu() + output2 = radiative_packet.get_random_mu() assert output1 != output2 @@ -119,9 +119,9 @@ def test_angle_ab_LF_to_CMF_diverge(mu, r, time_explosion): """ nu = 0.4 energy = 0.9 - packet = r_packet.RPacket(r, mu, nu, energy) + packet = radiative_packet.RPacket(r, mu, nu, energy) with pytest.raises(ZeroDivisionError): - obtained = r_packet.angle_aberration_LF_to_CMF( + obtained = radiative_packet.angle_aberration_LF_to_CMF( packet, time_explosion, mu ) @@ -136,7 +136,7 @@ def test_both_angle_aberrations(mu, r, time_explosion): """ nu = 0.4 energy = 0.9 - packet = r_packet.RPacket(r, mu, nu, energy) + packet = radiative_packet.RPacket(r, mu, nu, energy) packet.r = r obtained_mu = angle_aberration_LF_to_CMF(packet, time_explosion, mu) inverse_obtained_mu = angle_aberration_CMF_to_LF( @@ -155,7 +155,7 @@ def test_both_angle_aberrations_inverse(mu, r, time_explosion): """ nu = 0.4 energy = 0.9 - packet = r_packet.RPacket(r, mu, nu, energy) + packet = radiative_packet.RPacket(r, mu, nu, energy) packet.r = r obtained_mu = angle_aberration_CMF_to_LF(packet, time_explosion, mu) inverse_obtained_mu = angle_aberration_LF_to_CMF( @@ -175,12 +175,12 @@ def test_move_packet_across_shell_boundary_emitted( mu = 0.3 nu = 0.4 energy = 0.9 - packet = r_packet.RPacket(r, mu, nu, energy) + packet = radiative_packet.RPacket(r, mu, nu, energy) packet.current_shell_id = current_shell_id r_packet_transport.move_packet_across_shell_boundary( packet, delta_shell, no_of_shells ) - assert packet.status == r_packet.PacketStatus.EMITTED + assert packet.status == radiative_packet.PacketStatus.EMITTED @pytest.mark.parametrize( @@ -194,12 +194,12 @@ def test_move_packet_across_shell_boundary_reabsorbed( mu = 0.3 nu = 0.4 energy = 0.9 - packet = r_packet.RPacket(r, mu, nu, energy) + packet = radiative_packet.RPacket(r, mu, nu, energy) packet.current_shell_id = current_shell_id r_packet_transport.move_packet_across_shell_boundary( packet, delta_shell, no_of_shells ) - assert packet.status == r_packet.PacketStatus.REABSORBED + assert packet.status == radiative_packet.PacketStatus.REABSORBED @pytest.mark.parametrize( @@ -213,7 +213,7 @@ def test_move_packet_across_shell_boundary_increment( mu = 0.3 nu = 0.4 energy = 0.9 - packet = r_packet.RPacket(r, mu, nu, energy) + packet = radiative_packet.RPacket(r, mu, nu, energy) packet.current_shell_id = current_shell_id r_packet_transport.move_packet_across_shell_boundary( packet, delta_shell, no_of_shells @@ -228,8 +228,8 @@ def test_move_packet_across_shell_boundary_increment( def test_packet_energy_limit_one(distance_trace, time_explosion, mu, r): initial_energy = 0.9 nu = 0.4 - packet = r_packet.RPacket(r, mu, nu, initial_energy) - new_energy = r_packet.calc_packet_energy( + packet = radiative_packet.RPacket(r, mu, nu, initial_energy) + new_energy = radiative_packet.calc_packet_energy( packet, distance_trace, time_explosion ) assert new_energy == initial_energy @@ -249,7 +249,7 @@ def test_compute_distance2boundary(packet_params, expected_params): r_inner = np.array([6.912e14, 8.64e14], dtype=np.float64) r_outer = np.array([8.64e14, 1.0368e15], dtype=np.float64) - d_boundary = r_packet.calculate_distance_boundary( + d_boundary = radiative_packet.calculate_distance_boundary( r, mu, r_inner[0], r_outer[0] ) @@ -285,7 +285,7 @@ def test_compute_distance2line(packet_params, expected_params): mu = 0.3 nu = 0.4 energy = 0.9 - packet = r_packet.RPacket(r, mu, nu, energy) + packet = radiative_packet.RPacket(r, mu, nu, energy) nu_line = packet_params["nu_line"] # packet.next_line_id = packet_params['next_line_id'] # packet.last_line = packet_params['last_line'] @@ -300,7 +300,7 @@ def test_compute_distance2line(packet_params, expected_params): d_line = 0 obtained_tardis_error = None try: - d_line = r_packet.calculate_distance_line( + d_line = radiative_packet.calculate_distance_line( packet, comov_nu, nu_line, time_explosion ) except utils.MonteCarloException: @@ -355,7 +355,7 @@ def test_move_packet(packet_params, expected_params, full_relativity): distance = 1e13 r, mu, nu, energy = 7.5e14, 0.3, 0.4, 0.9 time_explosion = 5.2e7 - packet = r_packet.RPacket(r, mu, nu, energy) + packet = radiative_packet.RPacket(r, mu, nu, energy) packet.nu = packet_params["nu"] packet.mu = packet_params["mu"] packet.energy = packet_params["energy"] @@ -466,14 +466,14 @@ def test_move_packet(packet_params, expected_params, full_relativity): ], ) def test_frame_transformations(mu, r, inv_t_exp, full_relativity): - packet = r_packet.RPacket(r=r, mu=mu, energy=0.9, nu=0.4) + packet = radiative_packet.RPacket(r=r, mu=mu, energy=0.9, nu=0.4) mc.ENABLE_FULL_RELATIVITY = bool(full_relativity) mc.ENABLE_FULL_RELATIVITY = full_relativity - inverse_doppler_factor = r_packet.get_inverse_doppler_factor( + inverse_doppler_factor = radiative_packet.get_inverse_doppler_factor( r, mu, 1 / inv_t_exp ) - r_packet.angle_aberration_CMF_to_LF(packet, 1 / inv_t_exp, packet.mu) + radiative_packet.angle_aberration_CMF_to_LF(packet, 1 / inv_t_exp, packet.mu) doppler_factor = get_doppler_factor(r, mu, 1 / inv_t_exp) mc.ENABLE_FULL_RELATIVITY = False @@ -492,7 +492,7 @@ def test_frame_transformations(mu, r, inv_t_exp, full_relativity): ], ) def test_angle_transformation_invariance(mu, r, inv_t_exp): - packet = r_packet.RPacket(r, mu, 0.4, 0.9) + packet = radiative_packet.RPacket(r, mu, 0.4, 0.9) mc.ENABLE_FULL_RELATIVITY = True mu1 = angle_aberration_CMF_to_LF(packet, 1 / inv_t_exp, mu) @@ -518,7 +518,7 @@ def test_angle_transformation_invariance(mu, r, inv_t_exp): def test_compute_distance2line_relativistic( mu, r, t_exp, nu, nu_line, full_relativity ): - packet = r_packet.RPacket(r=r, nu=nu, mu=mu, energy=0.9) + packet = radiative_packet.RPacket(r=r, nu=nu, mu=mu, energy=0.9) # packet.nu_line = nu_line numba_estimator = RadiationFieldMCEstimators( transport.j_estimator, @@ -530,7 +530,7 @@ def test_compute_distance2line_relativistic( doppler_factor = get_doppler_factor(r, mu, t_exp) comov_nu = packet.nu * doppler_factor - distance = r_packet.calculate_distance_line( + distance = radiative_packet.calculate_distance_line( packet, comov_nu, nu_line, t_exp ) r_packet_transport.move_r_packet( @@ -566,7 +566,7 @@ def test_rpacket_tracking(index, seed, r, nu, mu, energy): mc.INITIAL_TRACKING_ARRAY_LENGTH = 10 tracked_rpacket_properties = RPacketTracker() - test_rpacket = r_packet.RPacket( + test_rpacket = radiative_packet.RPacket( index=index, seed=seed, r=r, diff --git a/tardis/transport/montecarlo/tests/test_packet.py b/tardis/transport/montecarlo/tests/test_packet.py index c473e09d2a1..a2a17f1911d 100644 --- a/tardis/transport/montecarlo/tests/test_packet.py +++ b/tardis/transport/montecarlo/tests/test_packet.py @@ -6,7 +6,7 @@ import tardis.transport.geometry.calculate_distances as calculate_distances import tardis.transport.montecarlo.configuration.montecarlo_globals as montecarlo_globals import tardis.transport.montecarlo.estimators.radfield_mc_estimators -import tardis.transport.montecarlo.r_packet as r_packet +import tardis.transport.montecarlo.packets.radiative_packet as radiative_packet import tardis.transport.montecarlo.r_packet_transport as r_packet_transport import tardis.transport.montecarlo.utils as utils from tardis import constants as const @@ -349,7 +349,7 @@ def test_move_packet_across_shell_boundary_emitted( r_packet_transport.move_packet_across_shell_boundary( packet, delta_shell, no_of_shells ) - assert packet.status == r_packet.PacketStatus.EMITTED + assert packet.status == radiative_packet.PacketStatus.EMITTED @pytest.mark.parametrize( @@ -363,7 +363,7 @@ def test_move_packet_across_shell_boundary_reabsorbed( r_packet_transport.move_packet_across_shell_boundary( packet, delta_shell, no_of_shells ) - assert packet.status == r_packet.PacketStatus.REABSORBED + assert packet.status == radiative_packet.PacketStatus.REABSORBED @pytest.mark.parametrize( diff --git a/tardis/transport/montecarlo/tests/test_rpacket_last_interaction_tracker.py b/tardis/transport/montecarlo/tests/test_rpacket_last_interaction_tracker.py index 54c84df5579..25539673e33 100644 --- a/tardis/transport/montecarlo/tests/test_rpacket_last_interaction_tracker.py +++ b/tardis/transport/montecarlo/tests/test_rpacket_last_interaction_tracker.py @@ -2,10 +2,10 @@ import numpy.testing as npt import pytest -from tardis.transport.montecarlo.packet_trackers import ( +from tardis.transport.montecarlo.packets.packet_trackers import ( RPacketLastInteractionTracker, ) -from tardis.transport.montecarlo.r_packet import InteractionType +from tardis.transport.montecarlo.packets.radiative_packet import InteractionType @pytest.fixture(scope="module") diff --git a/tardis/transport/montecarlo/tests/test_rpacket_tracker.py b/tardis/transport/montecarlo/tests/test_rpacket_tracker.py index 67f779091d1..31b526c5499 100644 --- a/tardis/transport/montecarlo/tests/test_rpacket_tracker.py +++ b/tardis/transport/montecarlo/tests/test_rpacket_tracker.py @@ -2,11 +2,11 @@ import numpy.testing as npt import pytest -from tardis.transport.montecarlo.packet_trackers import ( +from tardis.transport.montecarlo.packets.packet_trackers import ( RPacketTracker, rpacket_trackers_to_dataframe, ) -from tardis.transport.montecarlo.r_packet import InteractionType +from tardis.transport.montecarlo.packets.radiative_packet import InteractionType @pytest.fixture diff --git a/tardis/transport/montecarlo/tests/test_tracker_utils.py b/tardis/transport/montecarlo/tests/test_tracker_utils.py index 92ddc5b3cf7..ddc06d7b10b 100644 --- a/tardis/transport/montecarlo/tests/test_tracker_utils.py +++ b/tardis/transport/montecarlo/tests/test_tracker_utils.py @@ -1,7 +1,7 @@ import numpy as np from numba import typeof -from tardis.transport.montecarlo.packet_trackers import ( +from tardis.transport.montecarlo.packets.packet_trackers import ( RPacketLastInteractionTracker, RPacketTracker, generate_rpacket_last_interaction_tracker_list, diff --git a/tardis/transport/montecarlo/tests/test_vpacket.py b/tardis/transport/montecarlo/tests/test_vpacket.py index 51a8e31dd78..56f73d0aac8 100644 --- a/tardis/transport/montecarlo/tests/test_vpacket.py +++ b/tardis/transport/montecarlo/tests/test_vpacket.py @@ -1,7 +1,7 @@ import numpy as np import pytest -import tardis.transport.montecarlo.vpacket as vpacket +import tardis.transport.montecarlo.packets.virtual_packet as virtual_packet from tardis import constants as const from tardis.transport.frame_transformations import ( get_doppler_factor, @@ -14,7 +14,7 @@ @pytest.fixture(scope="function") def v_packet(): - return vpacket.VPacket( + return virtual_packet.VPacket( r=7.5e14, nu=4e15, mu=0.3, @@ -52,7 +52,7 @@ def test_trace_vpacket_within_shell( tau_trace_combined, distance_boundary, delta_shell, - ) = vpacket.trace_vpacket_within_shell( + ) = virtual_packet.trace_vpacket_within_shell( v_packet, verysimple_numba_radial_1d_geometry, verysimple_time_explosion, @@ -80,7 +80,7 @@ def test_trace_vpacket( v_packet, verysimple_opacity_state, verysimple_time_explosion ) - tau_trace_combined = vpacket.trace_vpacket( + tau_trace_combined = virtual_packet.trace_vpacket( v_packet, verysimple_numba_radial_1d_geometry, verysimple_time_explosion, @@ -117,7 +117,7 @@ def test_trace_vpacket_volley( verysimple_opacity_state, verysimple_time_explosion ) - vpacket.trace_vpacket_volley( + virtual_packet.trace_vpacket_volley( packet, verysimple_3vpacket_collection, verysimple_numba_radial_1d_geometry, @@ -131,7 +131,7 @@ def test_trace_vpacket_volley( @pytest.fixture(scope="function") def broken_packet(): - return vpacket.VPacket( + return virtual_packet.VPacket( r=1286064000000000.0, nu=1660428912896553.2, mu=0.4916053094346575, @@ -148,7 +148,7 @@ def test_trace_bad_vpacket( verysimple_time_explosion, verysimple_opacity_state, ): - vpacket.trace_vpacket( + virtual_packet.trace_vpacket( broken_packet, verysimple_numba_radial_1d_geometry, verysimple_time_explosion, From 209883ff302ca420b23ae668c86b10851565d839 Mon Sep 17 00:00:00 2001 From: Wolfgang Kerzendorf Date: Wed, 27 Aug 2025 07:43:39 -0400 Subject: [PATCH 6/9] Refactor packet source and testing structure - Updated import paths in `packet_source.py` to reflect new module structure. --- docs/io/optional/how_to_custom_source.ipynb | 59 ++++++++++++------- .../montecarlo/initialization.ipynb | 24 ++++---- tardis/transport/montecarlo/packet_source.py | 2 +- .../montecarlo/packets/tests/__init__.py | 0 .../montecarlo/packets/tests/conftest.py | 5 ++ .../{ => packets}/tests/test_packet.py | 0 .../test_rpacket_last_interaction_tracker.py | 0 .../tests/test_rpacket_tracker.py | 0 .../{ => packets}/tests/test_vpacket.py | 0 9 files changed, 55 insertions(+), 35 deletions(-) create mode 100644 tardis/transport/montecarlo/packets/tests/__init__.py create mode 100644 tardis/transport/montecarlo/packets/tests/conftest.py rename tardis/transport/montecarlo/{ => packets}/tests/test_packet.py (100%) rename tardis/transport/montecarlo/{ => packets}/tests/test_rpacket_last_interaction_tracker.py (100%) rename tardis/transport/montecarlo/{ => packets}/tests/test_rpacket_tracker.py (100%) rename tardis/transport/montecarlo/{ => packets}/tests/test_vpacket.py (100%) diff --git a/docs/io/optional/how_to_custom_source.ipynb b/docs/io/optional/how_to_custom_source.ipynb index c7b02edcabe..8f6a21fe777 100644 --- a/docs/io/optional/how_to_custom_source.ipynb +++ b/docs/io/optional/how_to_custom_source.ipynb @@ -38,16 +38,16 @@ "outputs": [], "source": [ "# Import necessary packages\n", + "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from tardis import constants as const\n", "from astropy import units as u\n", + "\n", + "from tardis import run_tardis\n", + "from tardis.io.atom_data import download_atom_data\n", "from tardis.transport.montecarlo.packet_source import BlackBodySimpleSource\n", - "from tardis.transport.montecarlo.packet_collections import (\n", + "from tardis.transport.montecarlo.packets.packet_collections import (\n", " PacketCollection,\n", - ")\n", - "from tardis import run_tardis\n", - "import matplotlib.pyplot as plt\n", - "from tardis.io.atom_data import download_atom_data" + ")" ] }, { @@ -90,7 +90,14 @@ " self.truncation_wavelength = truncation_wavelength\n", " super().__init__(**kwargs)\n", "\n", - " def create_packets(self, no_of_packets, drawing_sample_size=None, seed_offset=0, *args, **kwargs):\n", + " def create_packets(\n", + " self,\n", + " no_of_packets,\n", + " drawing_sample_size=None,\n", + " seed_offset=0,\n", + " *args,\n", + " **kwargs,\n", + " ):\n", " \"\"\"\n", " Packet source that generates a truncated Blackbody source.\n", "\n", @@ -110,7 +117,6 @@ " array\n", " Packet energies\n", " \"\"\"\n", - "\n", " self._reseed(self.base_seed + seed_offset)\n", " packet_seeds = self.rng.choice(\n", " self.MAX_SEED_VAL, no_of_packets, replace=True\n", @@ -128,10 +134,9 @@ " drawing_sample_size = 2 * no_of_packets\n", "\n", " # Blackbody will be truncated below truncation_wavelength / above truncation_frequency.\n", - " truncation_frequency = (\n", - " u.Quantity(self.truncation_wavelength, u.Angstrom)\n", - " .to(u.Hz, equivalencies=u.spectral())\n", - " )\n", + " truncation_frequency = u.Quantity(\n", + " self.truncation_wavelength, u.Angstrom\n", + " ).to(u.Hz, equivalencies=u.spectral())\n", "\n", " # Draw nus from blackbody distribution and reject based on truncation_frequency.\n", " # If more nus.shape[0] > no_of_packets use only the first no_of_packets.\n", @@ -141,7 +146,9 @@ " # Only required if the truncation wavelength is too big compared to the maximum\n", " # of the blackbody distribution. Keep sampling until nus.shape[0] > no_of_packets.\n", " while nus.shape[0] < no_of_packets:\n", - " additional_nus = self.create_packet_nus(drawing_sample_size, *args, **kwargs)\n", + " additional_nus = self.create_packet_nus(\n", + " drawing_sample_size, *args, **kwargs\n", + " )\n", " mask = additional_nus < truncation_frequency\n", " additional_nus = additional_nus[mask][:no_of_packets]\n", " nus = np.hstack([nus, additional_nus])[:no_of_packets]\n", @@ -150,7 +157,9 @@ " self.calculate_radfield_luminosity().to(u.erg / u.s).value\n", " )\n", "\n", - " return PacketCollection(radii, nus, mus, energies, packet_seeds, radiation_field_luminosity)" + " return PacketCollection(\n", + " radii, nus, mus, energies, packet_seeds, radiation_field_luminosity\n", + " )" ] }, { @@ -190,14 +199,20 @@ "outputs": [], "source": [ "%matplotlib inline\n", - "plt.plot(mdl.spectrum_solver.spectrum_virtual_packets.wavelength,\n", - " mdl.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n", - " color='red', label='truncated blackbody (custom packet source)')\n", - "plt.plot(mdl_norm.spectrum_solver.spectrum_virtual_packets.wavelength,\n", - " mdl_norm.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n", - " color='blue', label='normal blackbody (default packet source)')\n", - "plt.xlabel(r'$\\lambda [\\AA]$')\n", - "plt.ylabel(r'$L_\\lambda$ [erg/s/$\\AA$]')\n", + "plt.plot(\n", + " mdl.spectrum_solver.spectrum_virtual_packets.wavelength,\n", + " mdl.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n", + " color=\"red\",\n", + " label=\"truncated blackbody (custom packet source)\",\n", + ")\n", + "plt.plot(\n", + " mdl_norm.spectrum_solver.spectrum_virtual_packets.wavelength,\n", + " mdl_norm.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n", + " color=\"blue\",\n", + " label=\"normal blackbody (default packet source)\",\n", + ")\n", + "plt.xlabel(r\"$\\lambda [\\AA]$\")\n", + "plt.ylabel(r\"$L_\\lambda$ [erg/s/$\\AA$]\")\n", "plt.xlim(500, 10000)\n", "plt.legend()" ] diff --git a/docs/physics_walkthrough/montecarlo/initialization.ipynb b/docs/physics_walkthrough/montecarlo/initialization.ipynb index b20324d7e1b..2b0758ac385 100644 --- a/docs/physics_walkthrough/montecarlo/initialization.ipynb +++ b/docs/physics_walkthrough/montecarlo/initialization.ipynb @@ -77,14 +77,12 @@ "metadata": {}, "outputs": [], "source": [ + "import matplotlib.pyplot as plt\n", "import numpy as np\n", - "from tardis.transport.montecarlo.packet_source import BlackBodySimpleSource\n", - "from tardis.transport.montecarlo.packet_collections import (\n", - " PacketCollection,\n", - ")\n", "from astropy import units as u\n", + "\n", "from tardis import constants as const\n", - "import matplotlib.pyplot as plt" + "from tardis.transport.montecarlo.packet_source import BlackBodySimpleSource" ] }, { @@ -135,11 +133,7 @@ "temperature_inner = 10000 * u.K\n", "\n", "luminosity_inner = (\n", - " 4\n", - " * np.pi\n", - " * (r_boundary_inner**2)\n", - " * const.sigma_sb\n", - " * (temperature_inner**4)\n", + " 4 * np.pi * (r_boundary_inner**2) * const.sigma_sb * (temperature_inner**4)\n", ")\n", "\n", "# Makes sure the luminosity is given in erg/s\n", @@ -263,12 +257,18 @@ "source": [ "# We set important quantities for making our histogram\n", "bins = 200\n", - "nus_planck = np.linspace(min(packet_collection.initial_nus), max(packet_collection.initial_nus), bins).value\n", + "nus_planck = np.linspace(\n", + " min(packet_collection.initial_nus), max(packet_collection.initial_nus), bins\n", + ").value\n", "bin_width = nus_planck[1] - nus_planck[0]\n", "\n", "# In the histogram plot below, the weights argument is used\n", "# to make sure our plotted spectrum has the correct y-axis scale\n", - "plt.hist(packet_collection.initial_nus.value, bins=bins, weights=lumin_per_packet / bin_width)\n", + "plt.hist(\n", + " packet_collection.initial_nus.value,\n", + " bins=bins,\n", + " weights=lumin_per_packet / bin_width,\n", + ")\n", "\n", "# We plot the planck function for comparison\n", "plt.plot(nus_planck * u.Hz, planck_function(nus_planck * u.Hz))\n", diff --git a/tardis/transport/montecarlo/packet_source.py b/tardis/transport/montecarlo/packet_source.py index d9a9cd6bc27..1012b7a49ad 100644 --- a/tardis/transport/montecarlo/packet_source.py +++ b/tardis/transport/montecarlo/packet_source.py @@ -6,7 +6,7 @@ from tardis import constants as const from tardis.io.hdf_writer_mixin import HDFWriterMixin -from tardis.transport.montecarlo.packet_collections import ( +from tardis.transport.montecarlo.packets.packet_collections import ( PacketCollection, ) diff --git a/tardis/transport/montecarlo/packets/tests/__init__.py b/tardis/transport/montecarlo/packets/tests/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tardis/transport/montecarlo/packets/tests/conftest.py b/tardis/transport/montecarlo/packets/tests/conftest.py new file mode 100644 index 00000000000..df06d6d7940 --- /dev/null +++ b/tardis/transport/montecarlo/packets/tests/conftest.py @@ -0,0 +1,5 @@ +"""Test fixtures for the packets package.""" + +# Import all fixtures from the parent montecarlo tests conftest +# This ensures the packet tests have access to all the necessary fixtures +from tardis.transport.montecarlo.tests.conftest import * # noqa: F403 diff --git a/tardis/transport/montecarlo/tests/test_packet.py b/tardis/transport/montecarlo/packets/tests/test_packet.py similarity index 100% rename from tardis/transport/montecarlo/tests/test_packet.py rename to tardis/transport/montecarlo/packets/tests/test_packet.py diff --git a/tardis/transport/montecarlo/tests/test_rpacket_last_interaction_tracker.py b/tardis/transport/montecarlo/packets/tests/test_rpacket_last_interaction_tracker.py similarity index 100% rename from tardis/transport/montecarlo/tests/test_rpacket_last_interaction_tracker.py rename to tardis/transport/montecarlo/packets/tests/test_rpacket_last_interaction_tracker.py diff --git a/tardis/transport/montecarlo/tests/test_rpacket_tracker.py b/tardis/transport/montecarlo/packets/tests/test_rpacket_tracker.py similarity index 100% rename from tardis/transport/montecarlo/tests/test_rpacket_tracker.py rename to tardis/transport/montecarlo/packets/tests/test_rpacket_tracker.py diff --git a/tardis/transport/montecarlo/tests/test_vpacket.py b/tardis/transport/montecarlo/packets/tests/test_vpacket.py similarity index 100% rename from tardis/transport/montecarlo/tests/test_vpacket.py rename to tardis/transport/montecarlo/packets/tests/test_vpacket.py From f9a12fa2fabd3d6df04c4ace30bce3aa29fd533b Mon Sep 17 00:00:00 2001 From: Wolfgang Kerzendorf Date: Wed, 27 Aug 2025 10:22:18 -0400 Subject: [PATCH 7/9] Update tardis/opacities/opacity_state_numba.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- tardis/opacities/opacity_state_numba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tardis/opacities/opacity_state_numba.py b/tardis/opacities/opacity_state_numba.py index 600f7579fc3..52565eb8c0e 100644 --- a/tardis/opacities/opacity_state_numba.py +++ b/tardis/opacities/opacity_state_numba.py @@ -283,7 +283,7 @@ def opacity_state_numba_initialize( photo_ion_activation_idx = plasma.photo_ion_idx.loc[ plasma.level2continuum_idx.index, "destination_level_idx" ].values - k_packet_idx = int(plasma.k_packet_idx) + k_packet_idx = plasma.k_packet_idx else: bf_threshold_list_nu = np.zeros(0, dtype=np.float64) p_fb_deactivation = np.zeros((0, 0), dtype=np.float64) From acaa20f6c9dba5938f9ef4cdc70957cd7439700b Mon Sep 17 00:00:00 2001 From: Wolfgang Kerzendorf Date: Wed, 27 Aug 2025 14:06:30 -0400 Subject: [PATCH 8/9] Update energy parameter description in RPacket class --- tardis/transport/montecarlo/packets/radiative_packet.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tardis/transport/montecarlo/packets/radiative_packet.py b/tardis/transport/montecarlo/packets/radiative_packet.py index 3144bd2a9a5..52501097cbc 100644 --- a/tardis/transport/montecarlo/packets/radiative_packet.py +++ b/tardis/transport/montecarlo/packets/radiative_packet.py @@ -61,7 +61,8 @@ def __init__( nu : float Initial frequency [Hz]. energy : float - Initial energy [erg]. + Initial energy. Energy units are scaled with time_simulation. + Adds all up to 1 in a single run. seed : int Random number seed. index : int, optional From ede4621dda5ad1fbc7e22ca393ce2f27274db029 Mon Sep 17 00:00:00 2001 From: Wolfgang Kerzendorf Date: Wed, 27 Aug 2025 15:15:57 -0400 Subject: [PATCH 9/9] empty commit to get CI going