Skip to content

Commit 0d62366

Browse files
wkerzendorfjvshieldsRodot-DeerWhale
authored
Restructure/transport solver v1 (tardis-sn#3218)
* Co-authored-by: Joshua Shields <[email protected]> Co-authored-by: Jack O'Brien <[email protected]> Co-authored-by: Jing Lu <[email protected]> Refactors transport initialization to use geometry state directly Simplifies the transport state initialization by passing the geometry state object directly instead of the full simulation state. This change reduces coupling between transport and simulation components and clarifies the specific data dependencies needed for transport initialization. Updates method signatures and variable names throughout the transport system to reflect this more focused interface. * Refactor transport state initialization in Simulation and MonteCarloTransportSolver classes Co-authored-by: Joshua Shields <[email protected]> Co-authored-by: Jack O'Brien <[email protected]> Co-authored-by: Jing Lu <[email protected]> --------- Co-authored-by: Joshua Shields <[email protected]> Co-authored-by: Jack O'Brien <[email protected]> Co-authored-by: Jing Lu <[email protected]>
1 parent 2b64dc6 commit 0d62366

File tree

3 files changed

+44
-16
lines changed

3 files changed

+44
-16
lines changed

tardis/simulation/base.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -437,23 +437,21 @@ def iterate(self, no_of_packets, no_of_virtual_packets=0):
437437
self.opacity_state.beta_sobolev,
438438
)
439439

440-
transport_state = self.transport.initialize_transport_state(
441-
self.simulation_state,
440+
v_packets_energy_hist = self.transport.run(
441+
self.simulation_state.geometry,
442442
self.opacity_state,
443443
macro_atom_state,
444444
self.plasma,
445445
no_of_packets,
446446
no_of_virtual_packets=no_of_virtual_packets,
447447
iteration=self.iterations_executed,
448-
)
449-
450-
v_packets_energy_hist = self.transport.run(
451-
transport_state,
452-
iteration=self.iterations_executed,
453448
total_iterations=self.iterations,
454449
show_progress_bars=self.show_progress_bars,
455450
)
456451

452+
# Get transport state for further processing
453+
transport_state = self.transport.transport_state
454+
457455
output_energy = self.transport.transport_state.packet_collection.output_energies
458456
if np.sum(output_energy < 0) == len(output_energy):
459457
logger.critical("No r-packet escaped through the outer boundary.")

tardis/transport/montecarlo/base.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def __init__(
9797

9898
def initialize_transport_state(
9999
self,
100-
simulation_state,
100+
geometry_state,
101101
opacity_state,
102102
macro_atom_state,
103103
plasma,
@@ -114,13 +114,13 @@ def initialize_transport_state(
114114
no_of_packets, seed_offset=iteration
115115
)
116116

117-
geometry_state = simulation_state.geometry.to_numba()
117+
geometry_state_numba = geometry_state.to_numba()
118118
opacity_state_numba = opacity_state.to_numba(
119119
macro_atom_state,
120120
self.line_interaction_type,
121121
)
122122
opacity_state_numba = opacity_state_numba[
123-
simulation_state.geometry.v_inner_boundary_index : simulation_state.geometry.v_outer_boundary_index
123+
geometry_state.v_inner_boundary_index : geometry_state.v_outer_boundary_index
124124
]
125125

126126
estimators = initialize_estimator_statistics(
@@ -130,9 +130,9 @@ def initialize_transport_state(
130130
transport_state = MonteCarloTransportState(
131131
packet_collection,
132132
estimators,
133-
geometry_state=geometry_state,
133+
geometry_state=geometry_state_numba,
134134
opacity_state=opacity_state_numba,
135-
time_explosion=simulation_state.time_explosion,
135+
time_explosion=geometry_state.time_explosion,
136136
)
137137

138138
transport_state.enable_full_relativity = (
@@ -147,7 +147,12 @@ def initialize_transport_state(
147147

148148
def run(
149149
self,
150-
transport_state,
150+
geometry_state,
151+
opacity_state,
152+
macro_atom_state,
153+
plasma,
154+
no_of_packets,
155+
no_of_virtual_packets=0,
151156
iteration=0,
152157
total_iterations=0,
153158
show_progress_bars=True,
@@ -157,18 +162,43 @@ def run(
157162
158163
Parameters
159164
----------
160-
model : tardis.model.SimulationState
165+
geometry_state : tardis.model.geometry.Geometry
166+
The geometry state of the simulation
167+
opacity_state : tardis.opacities.opacity_state.OpacityState
168+
The opacity state
169+
macro_atom_state : tardis.opacities.macro_atom.macroatom_state.LegacyMacroAtomState
170+
The macro atom state
161171
plasma : tardis.plasma.BasePlasma
172+
The plasma state
162173
no_of_packets : int
174+
Number of packets to run
163175
no_of_virtual_packets : int
176+
Number of virtual packets
177+
iteration : int
178+
Current iteration number
164179
total_iterations : int
165180
The total number of iterations in the simulation.
181+
show_progress_bars : bool
182+
Whether to show progress bars
166183
167184
Returns
168185
-------
169-
None
186+
v_packets_energy_hist : numpy.ndarray
187+
Virtual packet energy histogram
170188
"""
171189
set_num_threads(self.nthreads)
190+
191+
# Initialize transport state
192+
transport_state = self.initialize_transport_state(
193+
geometry_state,
194+
opacity_state,
195+
macro_atom_state,
196+
plasma,
197+
no_of_packets,
198+
no_of_virtual_packets=no_of_virtual_packets,
199+
iteration=iteration,
200+
)
201+
172202
self.transport_state = transport_state
173203

174204
number_of_vpackets = self.montecarlo_configuration.NUMBER_OF_VPACKETS

tardis/workflows/simple_tardis_workflow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ def solve_montecarlo(
403403
macro_atom_state = opacity_states["macro_atom_state"]
404404

405405
self.transport_state = self.transport_solver.initialize_transport_state(
406-
self.simulation_state,
406+
self.simulation_state.geometry,
407407
opacity_state,
408408
macro_atom_state,
409409
self.plasma_solver,

0 commit comments

Comments
 (0)