Skip to content

Commit a954873

Browse files
authored
Merge pull request #41 from axiom-data-science/cache_path
Added ability to input filepath for saving/reading interpolator file
2 parents c1e74c7 + e0b9ec9 commit a954873

File tree

7 files changed

+132
-58
lines changed

7 files changed

+132
-58
lines changed

docs/whats_new.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
# What's New
22

3-
## unreleased
3+
## v0.10.1 (January 30, 2025)
44

55
* Added built-in way to create plots for simulation using OpenDrift. Details available in {ref}`plots`.
6+
* User can now input a location to both save and read the interpolator, which avoids using the built-in cache location.
67

78
## v0.9.6 (November 15, 2024)
89

particle_tracking_manager/cli.py

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -135,20 +135,23 @@ def main():
135135
"output_file"
136136
] = f"output-results_{datetime.utcnow():%Y-%m-%dT%H%M:%SZ}.nc"
137137

138-
log_file = args.kwargs["output_file"].replace(".nc", ".log")
138+
# log_file = args.kwargs["output_file"].replace(".nc", ".log")
139139

140140
# Convert the string representation of the dictionary to an actual dictionary
141141
# not clear why I can't use `args.plots` in here but it isn't working
142-
plots = ast.literal_eval(parser.parse_args().plots)
142+
if parser.parse_args().plots is not None:
143+
plots = ast.literal_eval(parser.parse_args().plots)
144+
else:
145+
plots = None
143146

144-
# Create a file handler
145-
file_handler = logging.FileHandler(log_file)
147+
# # Create a file handler
148+
# file_handler = logging.FileHandler(log_file)
146149

147-
# Create a formatter and add it to the handler
148-
formatter = logging.Formatter(
149-
"%(asctime)s - %(name)s - %(levelname)s - %(message)s"
150-
)
151-
file_handler.setFormatter(formatter)
150+
# # Create a formatter and add it to the handler
151+
# formatter = logging.Formatter(
152+
# "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
153+
# )
154+
# file_handler.setFormatter(formatter)
152155

153156
m = ptm.OpenDriftModel(**args.kwargs, plots=plots)
154157

@@ -160,10 +163,10 @@ def main():
160163

161164
else:
162165

163-
# Add the handler to the logger
164-
m.logger.addHandler(file_handler)
166+
# # Add the handler to the logger
167+
# m.logger.addHandler(file_handler)
165168

166-
m.logger.info(f"filename: {args.kwargs['output_file']}")
169+
# m.logger.info(f"filename: {args.kwargs['output_file']}")
167170

168171
m.add_reader()
169172
print(m.drift_model_config())
@@ -173,6 +176,6 @@ def main():
173176

174177
print(m.outfile_name)
175178

176-
# Remove the handler at the end of the loop
177-
m.logger.removeHandler(file_handler)
178-
file_handler.close()
179+
# # Remove the handler at the end of the loop
180+
# m.logger.removeHandler(file_handler)
181+
# file_handler.close()

particle_tracking_manager/models/opendrift/opendrift.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def __init__(
223223
],
224224
biodegradation: bool = config_model["biodegradation"]["default"],
225225
log: str = config_model["log"]["default"],
226-
plots: dict = config_model["plots"]["default"],
226+
plots: Optional[dict] = config_model["plots"]["default"],
227227
**kw,
228228
) -> None:
229229
"""Inputs for OpenDrift model."""

particle_tracking_manager/the_manager.py

Lines changed: 66 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,11 @@ class ParticleTrackingManager:
142142
Name of input/output module type to use for writing Lagrangian model output. Default is "netcdf".
143143
use_cache : bool
144144
Set to True to use cache for saving interpolators, by default True.
145+
interpolator_filename : Optional[Union[pathlib.Path,str]], optional
146+
Filename to save interpolators to, by default None. The full path should be given, but no suffix.
147+
Use this to either read from an existing file at a non-default location or to save to a
148+
non-default location. If None and use_cache==True, the filename is set to a built-in name in an
149+
`appdirs` cache directory.
145150
146151
Notes
147152
-----
@@ -198,6 +203,9 @@ def __init__(
198203
output_file: Optional[str] = config_ptm["output_file"]["default"],
199204
output_format: str = config_ptm["output_format"]["default"],
200205
use_cache: bool = config_ptm["use_cache"]["default"],
206+
interpolator_filename: Optional[Union[pathlib.Path, str]] = config_ptm[
207+
"interpolator_filename"
208+
]["default"],
201209
**kw,
202210
) -> None:
203211
"""Inputs necessary for any particle tracking."""
@@ -231,24 +239,16 @@ def __init__(
231239

232240
self.output_file_initial = None
233241

234-
# Set all attributes which will trigger some checks and changes in __setattr__
235-
# these will also update "value" in the config dict
236-
for key in sig.parameters.keys():
237-
# no need to run through for init if value is None (already set to None)
238-
if locals()[key] is not None:
239-
self.__setattr__(key, locals()[key])
240-
241-
self.kw = kw
242+
if output_file is None:
243+
output_file = f"output-results_{datetime.datetime.now():%Y-%m-%dT%H%M:%SZ}"
242244

243-
if self.__dict__["output_file"] is None:
244-
self.__dict__[
245-
"output_file"
246-
] = f"output-results_{datetime.datetime.now():%Y-%m-%dT%H%M:%SZ}"
245+
# want output_file to not include any suffix
246+
output_file = output_file.rstrip(".nc").rstrip(".parq")
247247

248248
## set up log for this simulation
249249
# Create a file handler
250-
assert self.__dict__["output_file"] is not None
251-
logfile_name = self.__dict__["output_file"] + ".log"
250+
assert output_file is not None
251+
logfile_name = output_file + ".log"
252252
self.file_handler = logging.FileHandler(logfile_name)
253253
self.logfile_name = logfile_name
254254

@@ -264,6 +264,56 @@ def __init__(
264264
self.logger.info(f"filename: {logfile_name}")
265265
##
266266

267+
if interpolator_filename is not None and not use_cache:
268+
raise ValueError(
269+
"If interpolator_filename is input, use_cache must be True."
270+
)
271+
272+
# deal with caching/interpolators
273+
# save interpolators to save time
274+
if use_cache:
275+
cache_dir = pathlib.Path(
276+
appdirs.user_cache_dir(
277+
appname="particle-tracking-manager",
278+
appauthor="axiom-data-science",
279+
)
280+
)
281+
cache_dir.mkdir(parents=True, exist_ok=True)
282+
cache_dir = cache_dir
283+
if interpolator_filename is None:
284+
interpolator_filename = cache_dir / pathlib.Path(
285+
f"{ocean_model}_interpolator"
286+
).with_suffix(".pickle")
287+
else:
288+
interpolator_filename = pathlib.Path(interpolator_filename).with_suffix(
289+
".pickle"
290+
)
291+
interpolator_filename = str(interpolator_filename)
292+
self.save_interpolator = True
293+
# if interpolator_filename already exists, load that
294+
if pathlib.Path(interpolator_filename).exists():
295+
self.logger.info(
296+
f"Loading the interpolator from {interpolator_filename}."
297+
)
298+
else:
299+
self.logger.info(
300+
f"A new interpolator will be saved to {interpolator_filename}."
301+
)
302+
else:
303+
self.save_interpolator = False
304+
# this is already None
305+
# self.interpolator_filename = None
306+
self.logger.info("Interpolators will not be saved.")
307+
308+
# Set all attributes which will trigger some checks and changes in __setattr__
309+
# these will also update "value" in the config dict
310+
for key in sig.parameters.keys():
311+
# no need to run through for init if value is None (already set to None)
312+
if locals()[key] is not None:
313+
self.__setattr__(key, locals()[key])
314+
315+
self.kw = kw
316+
267317
def __setattr_model__(self, name: str, value) -> None:
268318
"""Implement this in model class to add specific __setattr__ there too."""
269319
pass
@@ -383,7 +433,8 @@ def __setattr__(self, name: str, value) -> None:
383433
# by this point, output_file should already be a filename like what is
384434
# available here, from OpenDrift (if run from there)
385435
if self.output_file is not None:
386-
output_file = self.output_file.rstrip(".nc")
436+
output_file = self.output_file
437+
# output_file = self.output_file.rstrip(".nc")
387438
else:
388439
output_file = (
389440
f"output-results_{datetime.datetime.now():%Y-%m-%dT%H%M:%SZ}"
@@ -477,29 +528,6 @@ def __setattr__(self, name: str, value) -> None:
477528
)
478529
self.seed_seafloor = False
479530

480-
# save interpolators to save time
481-
if name == "use_cache":
482-
if value:
483-
cache_dir = pathlib.Path(
484-
appdirs.user_cache_dir(
485-
appname="particle-tracking-manager",
486-
appauthor="axiom-data-science",
487-
)
488-
)
489-
cache_dir.mkdir(parents=True, exist_ok=True)
490-
self.cache_dir = cache_dir
491-
self.interpolator_filename = cache_dir / pathlib.Path(
492-
f"{self.ocean_model}_interpolator"
493-
)
494-
self.save_interpolator = True
495-
self.logger.info(
496-
f"Interpolators will be saved to {self.interpolator_filename}."
497-
)
498-
else:
499-
self.save_interpolator = False
500-
self.interpolator_filename = None
501-
self.logger.info("Interpolators will not be saved.")
502-
503531
# if reader, lon, and lat set, check inputs
504532
if (
505533
name == "has_added_reader"

particle_tracking_manager/the_manager_config.json

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,5 +193,11 @@
193193
"default": true,
194194
"description": "Set to True to use cache for storing interpolators. This saves time on repeat simulations, may be used for other items in the future.",
195195
"ptm_level": 3
196+
},
197+
"interpolator_filename": {
198+
"type": "str",
199+
"default": "None",
200+
"description": "Filename to save interpolator to. The full path should be given, but no suffix. Use this to either read from an existing file at a non-default location or to save to a non-default location. If None and use_cache==True, a default name is used.",
201+
"ptm_level": 3
196202
}
197203
}

tests/test_manager.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,5 +387,24 @@ def test_start_time_none(self):
387387
self.m.seed()
388388

389389

390+
def test_interpolator_filename():
391+
with pytest.raises(ValueError):
392+
m = ptm.OpenDriftModel(interpolator_filename="test", use_cache=False)
393+
394+
m = ptm.OpenDriftModel(interpolator_filename="test")
395+
assert m.interpolator_filename == "test.pickle"
396+
397+
398+
def test_log_name():
399+
m = ptm.OpenDriftModel(output_file="newtest")
400+
assert m.logfile_name == "newtest.log"
401+
402+
m = ptm.OpenDriftModel(output_file="newtest.nc")
403+
assert m.logfile_name == "newtest.log"
404+
405+
m = ptm.OpenDriftModel(output_file="newtest.parq")
406+
assert m.logfile_name == "newtest.log"
407+
408+
390409
if __name__ == "__main__":
391410
unittest.main()

tests/test_realistic.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Test realistic scenarios, which are slower."""
22

3+
import pickle
4+
35
import pytest
46
import xarray as xr
57

@@ -42,16 +44,31 @@ def test_run_parquet():
4244
def test_run_netcdf():
4345
"""Set up and run."""
4446

47+
import tempfile
48+
4549
import xroms
4650

4751
seeding_kwargs = dict(lon=-90, lat=28.7, number=1)
48-
manager = ptm.OpenDriftModel(
49-
**seeding_kwargs, use_static_masks=True, steps=2, output_format="netcdf"
50-
)
52+
with tempfile.NamedTemporaryFile(delete=False) as temp_file:
53+
manager = ptm.OpenDriftModel(
54+
**seeding_kwargs,
55+
use_static_masks=True,
56+
steps=2,
57+
output_format="netcdf",
58+
use_cache=True,
59+
interpolator_filename=temp_file.name
60+
)
5161
url = xroms.datasets.CLOVER.fetch("ROMS_example_full_grid.nc")
5262
ds = xr.open_dataset(url, decode_times=False)
5363
manager.add_reader(ds=ds, name="txla")
5464
manager.seed()
5565
manager.run()
5666

5767
assert "nc" in manager.o.outfile_name
68+
assert manager.interpolator_filename == temp_file.name + ".pickle"
69+
70+
# Confirm the saved interpolator pickle contains the expected spline entries
71+
with open(manager.interpolator_filename, "rb") as file:
72+
data = pickle.load(file)
73+
assert "spl_x" in data
74+
assert "spl_y" in data

0 commit comments

Comments
 (0)