Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions benchmarks/benchmark_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@
from tardis.model.geometry.radial1d import NumbaRadial1DGeometry
from tardis.simulation import Simulation
from tardis.tests.fixtures.atom_data import DEFAULT_ATOM_DATA_MD5
from tardis.transport.montecarlo import RPacket, packet_trackers
from tardis.transport.montecarlo import RPacket
from tardis.transport.montecarlo.configuration.base import (
MonteCarloConfiguration,
)
from tardis.transport.montecarlo.estimators import radfield_mc_estimators
from tardis.opacities.opacity_state_numba import opacity_state_numba_initialize
from tardis.transport.montecarlo.packet_collections import VPacketCollection
from tardis.transport.montecarlo.packets.packet_collections import VPacketCollection
from tardis.transport.montecarlo.packets import packet_trackers


class BenchmarkBase:
Expand Down
2 changes: 1 addition & 1 deletion benchmarks/transport_montecarlo_packet_trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from asv_runner.benchmarks.mark import parameterize

from benchmarks.benchmark_base import BenchmarkBase
from tardis.transport.montecarlo import packet_trackers
from tardis.transport.montecarlo.packets import packet_trackers


class BenchmarkTransportMontecarloPacketTrackers(BenchmarkBase):
Expand Down
16 changes: 8 additions & 8 deletions benchmarks/transport_montecarlo_vpacket.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

import numpy as np

import tardis.transport.montecarlo.vpacket as vpacket
import tardis.transport.montecarlo.packets.virtual_packet as virtual_packet
from benchmarks.benchmark_base import BenchmarkBase
from tardis.transport.frame_transformations import get_doppler_factor
from tardis.transport.montecarlo.r_packet import RPacket
from tardis.transport.montecarlo.packets.radiative_packet import RPacket


class BenchmarkMontecarloMontecarloNumbaVpacket(BenchmarkBase):
Expand All @@ -30,7 +30,7 @@ def setup(self):

@functools.cached_property
def v_packet(self):
return vpacket.VPacket(
return virtual_packet.VPacket(
r=7.5e14,
nu=4e15,
mu=0.3,
Expand Down Expand Up @@ -74,7 +74,7 @@ def time_trace_vpacket_within_shell(self):
self.enable_full_relativity,
)

vpacket.trace_vpacket_within_shell(
virtual_packet.trace_vpacket_within_shell(
self.vpacket,
self.numba_radial_1d_geometry,
self.time_explosion,
Expand All @@ -94,7 +94,7 @@ def time_trace_vpacket(self):
self.enable_full_relativity,
)

vpacket.trace_vpacket(
virtual_packet.trace_vpacket(
self.vpacket,
self.numba_radial_1d_geometry,
self.time_explosion,
Expand All @@ -106,7 +106,7 @@ def time_trace_vpacket(self):

@functools.cached_property
def broken_packet(self):
return vpacket.VPacket(
return virtual_packet.VPacket(
r=1286064000000000.0,
nu=1660428912896553.2,
mu=0.4916053094346575,
Expand All @@ -119,7 +119,7 @@ def broken_packet(self):
def time_trace_bad_vpacket(self):
broken_packet = self.broken_packet

vpacket.trace_vpacket(
virtual_packet.trace_vpacket(
broken_packet,
self.numba_radial_1d_geometry,
self.time_explosion,
Expand All @@ -130,7 +130,7 @@ def time_trace_bad_vpacket(self):
)

def time_trace_vpacket_volley(self):
vpacket.trace_vpacket_volley(
virtual_packet.trace_vpacket_volley(
self.r_packet,
self.verysimple_3vpacket_collection,
self.numba_radial_1d_geometry,
Expand Down
59 changes: 37 additions & 22 deletions docs/io/optional/how_to_custom_source.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,16 @@
"outputs": [],
"source": [
"# Import necessary packages\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from tardis import constants as const\n",
"from astropy import units as u\n",
"\n",
"from tardis import run_tardis\n",
"from tardis.io.atom_data import download_atom_data\n",
"from tardis.transport.montecarlo.packet_source import BlackBodySimpleSource\n",
"from tardis.transport.montecarlo.packet_collections import (\n",
"from tardis.transport.montecarlo.packets.packet_collections import (\n",
" PacketCollection,\n",
")\n",
"from tardis import run_tardis\n",
"import matplotlib.pyplot as plt\n",
"from tardis.io.atom_data import download_atom_data"
")"
]
},
{
Expand Down Expand Up @@ -90,7 +90,14 @@
" self.truncation_wavelength = truncation_wavelength\n",
" super().__init__(**kwargs)\n",
"\n",
" def create_packets(self, no_of_packets, drawing_sample_size=None, seed_offset=0, *args, **kwargs):\n",
" def create_packets(\n",
" self,\n",
" no_of_packets,\n",
" drawing_sample_size=None,\n",
" seed_offset=0,\n",
" *args,\n",
" **kwargs,\n",
" ):\n",
" \"\"\"\n",
" Packet source that generates a truncated Blackbody source.\n",
"\n",
Expand All @@ -110,7 +117,6 @@
" array\n",
" Packet energies\n",
" \"\"\"\n",
"\n",
" self._reseed(self.base_seed + seed_offset)\n",
" packet_seeds = self.rng.choice(\n",
" self.MAX_SEED_VAL, no_of_packets, replace=True\n",
Expand All @@ -128,10 +134,9 @@
" drawing_sample_size = 2 * no_of_packets\n",
"\n",
" # Blackbody will be truncated below truncation_wavelength / above truncation_frequency.\n",
" truncation_frequency = (\n",
" u.Quantity(self.truncation_wavelength, u.Angstrom)\n",
" .to(u.Hz, equivalencies=u.spectral())\n",
" )\n",
" truncation_frequency = u.Quantity(\n",
" self.truncation_wavelength, u.Angstrom\n",
" ).to(u.Hz, equivalencies=u.spectral())\n",
"\n",
" # Draw nus from blackbody distribution and reject based on truncation_frequency.\n",
" # If more nus.shape[0] > no_of_packets use only the first no_of_packets.\n",
Expand All @@ -141,7 +146,9 @@
" # Only required if the truncation wavelength is too big compared to the maximum\n",
" # of the blackbody distribution. Keep sampling until nus.shape[0] > no_of_packets.\n",
" while nus.shape[0] < no_of_packets:\n",
" additional_nus = self.create_packet_nus(drawing_sample_size, *args, **kwargs)\n",
" additional_nus = self.create_packet_nus(\n",
" drawing_sample_size, *args, **kwargs\n",
" )\n",
" mask = additional_nus < truncation_frequency\n",
" additional_nus = additional_nus[mask][:no_of_packets]\n",
" nus = np.hstack([nus, additional_nus])[:no_of_packets]\n",
Expand All @@ -150,7 +157,9 @@
" self.calculate_radfield_luminosity().to(u.erg / u.s).value\n",
" )\n",
"\n",
" return PacketCollection(radii, nus, mus, energies, packet_seeds, radiation_field_luminosity)"
" return PacketCollection(\n",
" radii, nus, mus, energies, packet_seeds, radiation_field_luminosity\n",
" )"
]
},
{
Expand Down Expand Up @@ -190,14 +199,20 @@
"outputs": [],
"source": [
"%matplotlib inline\n",
"plt.plot(mdl.spectrum_solver.spectrum_virtual_packets.wavelength,\n",
" mdl.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n",
" color='red', label='truncated blackbody (custom packet source)')\n",
"plt.plot(mdl_norm.spectrum_solver.spectrum_virtual_packets.wavelength,\n",
" mdl_norm.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n",
" color='blue', label='normal blackbody (default packet source)')\n",
"plt.xlabel(r'$\\lambda [\\AA]$')\n",
"plt.ylabel(r'$L_\\lambda$ [erg/s/$\\AA$]')\n",
"plt.plot(\n",
" mdl.spectrum_solver.spectrum_virtual_packets.wavelength,\n",
" mdl.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n",
" color=\"red\",\n",
" label=\"truncated blackbody (custom packet source)\",\n",
")\n",
"plt.plot(\n",
" mdl_norm.spectrum_solver.spectrum_virtual_packets.wavelength,\n",
" mdl_norm.spectrum_solver.spectrum_virtual_packets.luminosity_density_lambda,\n",
" color=\"blue\",\n",
" label=\"normal blackbody (default packet source)\",\n",
")\n",
"plt.xlabel(r\"$\\lambda [\\AA]$\")\n",
"plt.ylabel(r\"$L_\\lambda$ [erg/s/$\\AA$]\")\n",
"plt.xlim(500, 10000)\n",
"plt.legend()"
]
Expand Down
24 changes: 12 additions & 12 deletions docs/physics_walkthrough/montecarlo/initialization.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,12 @@
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from tardis.transport.montecarlo.packet_source import BlackBodySimpleSource\n",
"from tardis.transport.montecarlo.packet_collections import (\n",
" PacketCollection,\n",
")\n",
"from astropy import units as u\n",
"\n",
"from tardis import constants as const\n",
"import matplotlib.pyplot as plt"
"from tardis.transport.montecarlo.packet_source import BlackBodySimpleSource"
]
},
{
Expand Down Expand Up @@ -135,11 +133,7 @@
"temperature_inner = 10000 * u.K\n",
"\n",
"luminosity_inner = (\n",
" 4\n",
" * np.pi\n",
" * (r_boundary_inner**2)\n",
" * const.sigma_sb\n",
" * (temperature_inner**4)\n",
" 4 * np.pi * (r_boundary_inner**2) * const.sigma_sb * (temperature_inner**4)\n",
")\n",
"\n",
"# Makes sure the luminosity is given in erg/s\n",
Expand Down Expand Up @@ -263,12 +257,18 @@
"source": [
"# We set important quantities for making our histogram\n",
"bins = 200\n",
"nus_planck = np.linspace(min(packet_collection.initial_nus), max(packet_collection.initial_nus), bins).value\n",
"nus_planck = np.linspace(\n",
" min(packet_collection.initial_nus), max(packet_collection.initial_nus), bins\n",
").value\n",
"bin_width = nus_planck[1] - nus_planck[0]\n",
"\n",
"# In the histogram plot below, the weights argument is used\n",
"# to make sure our plotted spectrum has the correct y-axis scale\n",
"plt.hist(packet_collection.initial_nus.value, bins=bins, weights=lumin_per_packet / bin_width)\n",
"plt.hist(\n",
" packet_collection.initial_nus.value,\n",
" bins=bins,\n",
" weights=lumin_per_packet / bin_width,\n",
")\n",
"\n",
"# We plot the planck function for comparison\n",
"plt.plot(nus_planck * u.Hz, planck_function(nus_planck * u.Hz))\n",
Expand Down
4 changes: 2 additions & 2 deletions tardis/transport/montecarlo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"parallel": False,
}

from tardis.transport.montecarlo.packet_collections import (
from tardis.transport.montecarlo.packets.packet_collections import (
PacketCollection,
)
from tardis.transport.montecarlo.r_packet import RPacket
from tardis.transport.montecarlo.packets.radiative_packet import RPacket
2 changes: 1 addition & 1 deletion tardis/transport/montecarlo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from tardis.transport.montecarlo.montecarlo_transport_state import (
MonteCarloTransportState,
)
from tardis.transport.montecarlo.packet_trackers import (
from tardis.transport.montecarlo.packets.packet_trackers import (
generate_rpacket_last_interaction_tracker_list,
generate_rpacket_tracker_list,
rpacket_trackers_to_dataframe,
Expand Down
2 changes: 1 addition & 1 deletion tardis/transport/montecarlo/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
MacroAtomTransitionType,
macro_atom_interaction,
)
from tardis.transport.montecarlo.r_packet import (
from tardis.transport.montecarlo.packets.radiative_packet import (
PacketStatus,
)
from tardis.transport.montecarlo.utils import get_random_mu
Expand Down
4 changes: 2 additions & 2 deletions tardis/transport/montecarlo/montecarlo_main_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@

from tardis.transport.montecarlo import njit_dict
from tardis.transport.montecarlo.configuration import montecarlo_globals
from tardis.transport.montecarlo.packet_collections import (
from tardis.transport.montecarlo.packets.packet_collections import (
VPacketCollection,
consolidate_vpacket_tracker,
initialize_last_interaction_tracker,
)
from tardis.transport.montecarlo.r_packet import (
from tardis.transport.montecarlo.packets.radiative_packet import (
PacketStatus,
RPacket,
)
Expand Down
2 changes: 1 addition & 1 deletion tardis/transport/montecarlo/packet_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from tardis import constants as const
from tardis.io.hdf_writer_mixin import HDFWriterMixin
from tardis.transport.montecarlo.packet_collections import (
from tardis.transport.montecarlo.packets.packet_collections import (
PacketCollection,
)

Expand Down
Empty file.
Loading
Loading