Skip to content

Commit 3a2b84f

Browse files
committed
ENH: Add filename option to most plot functions
1 parent dc99084 commit 3a2b84f

21 files changed

+417
-160
lines changed

rocketpy/mathutils/function.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,13 @@
1616
import numpy as np
1717
from scipy import integrate, linalg, optimize
1818

19+
try:
20+
from functools import cached_property
21+
except ImportError:
22+
from ..tools import cached_property
23+
24+
from ..plots.plot_helpers import show_or_save_plot
25+
1926
NUMERICAL_TYPES = (float, int, complex, np.ndarray, np.integer, np.floating)
2027
INTERPOLATION_TYPES = {
2128
"linear": 0,
@@ -1074,7 +1081,7 @@ def remove_outliers_iqr(self, threshold=1.5):
10741081
)
10751082

10761083
# Define all presentation methods
1077-
def __call__(self, *args):
1084+
def __call__(self, *args, filename=None):
10781085
"""Plot the Function if no argument is given. If an
10791086
argument is given, return the value of the function at the desired
10801087
point.
@@ -1093,8 +1100,8 @@ def __call__(self, *args):
10931100
-------
10941101
ans : None, scalar, list
10951102
"""
1096-
if len(args) == 0:
1097-
return self.plot()
1103+
if len(args) == 0 or (len(args) == 1 and filename != None):
1104+
return self.plot(filename=filename)
10981105
else:
10991106
return self.get_value(*args)
11001107

@@ -1155,8 +1162,11 @@ def plot(self, *args, **kwargs):
11551162
Function.plot_2d if Function is 2-Dimensional and forward arguments
11561163
and key-word arguments."""
11571164
if isinstance(self, list):
1165+
# Extract filename from kwargs
1166+
filename = kwargs.get("filename", None)
1167+
11581168
# Compare multiple plots
1159-
Function.compare_plots(self)
1169+
Function.compare_plots(self, filename)
11601170
else:
11611171
if self.__dom_dim__ == 1:
11621172
self.plot_1d(*args, **kwargs)
@@ -1184,6 +1194,7 @@ def plot_1d(
11841194
force_points=False,
11851195
return_object=False,
11861196
equal_axis=False,
1197+
filename=None,
11871198
):
11881199
"""Plot 1-Dimensional Function, from a lower limit to an upper limit,
11891200
by sampling the Function several times in the interval. The title of
@@ -1214,6 +1225,8 @@ def plot_1d(
12141225
Setting force_points to True will plot all points, as a scatter, in
12151226
which the Function was evaluated in the dataset. Default value is
12161227
False.
1228+
filename : str | None, optional
1229+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: png, pdf, ps, eps and svg.
12171230
12181231
Returns
12191232
-------
@@ -1254,7 +1267,7 @@ def plot_1d(
12541267
plt.title(self.title)
12551268
plt.xlabel(self.__inputs__[0].title())
12561269
plt.ylabel(self.__outputs__[0].title())
1257-
plt.show()
1270+
show_or_save_plot(filename)
12581271
if return_object:
12591272
return fig, ax
12601273

@@ -1277,6 +1290,7 @@ def plot_2d(
12771290
disp_type="surface",
12781291
alpha=0.6,
12791292
cmap="viridis",
1293+
filename=None,
12801294
):
12811295
"""Plot 2-Dimensional Function, from a lower limit to an upper limit,
12821296
by sampling the Function several times in the interval. The title of
@@ -1316,6 +1330,8 @@ def plot_2d(
13161330
cmap : string, optional
13171331
Colormap of plotted graph, which can be any of the color maps
13181332
available in matplotlib. Default value is viridis.
1333+
filename : str | None, optional
1334+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: png, pdf, ps, eps and svg.
13191335
13201336
Returns
13211337
-------
@@ -1389,7 +1405,7 @@ def plot_2d(
13891405
axes.set_xlabel(self.__inputs__[0].title())
13901406
axes.set_ylabel(self.__inputs__[1].title())
13911407
axes.set_zlabel(self.__outputs__[0].title())
1392-
plt.show()
1408+
show_or_save_plot(filename)
13931409

13941410
@staticmethod
13951411
def compare_plots(
@@ -1404,6 +1420,7 @@ def compare_plots(
14041420
force_points=False,
14051421
return_object=False,
14061422
show=True,
1423+
filename=None,
14071424
):
14081425
"""Plots N 1-Dimensional Functions in the same plot, from a lower
14091426
limit to an upper limit, by sampling the Functions several times in
@@ -1448,6 +1465,8 @@ def compare_plots(
14481465
False.
14491466
show : bool, optional
14501467
If True, shows the plot. Default value is True.
1468+
filename : str | None, optional
1469+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: png, pdf, ps, eps and svg.
14511470
14521471
Returns
14531472
-------
@@ -1522,7 +1541,7 @@ def compare_plots(
15221541
plt.ylabel(ylabel)
15231542

15241543
if show:
1525-
plt.show()
1544+
show_or_save_plot(filename)
15261545

15271546
if return_object:
15281547
return fig, ax

rocketpy/motors/hybrid_motor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -600,14 +600,14 @@ def add_tank(self, tank, position):
600600
)
601601
reset_funcified_methods(self)
602602

603-
def draw(self):
603+
def draw(self, filename=None):
604604
"""Draws a representation of the HybridMotor."""
605-
self.plots.draw()
605+
self.plots.draw(filename)
606606

607-
def info(self):
607+
def info(self, filename=None):
608608
"""Prints out basic data about the Motor."""
609609
self.prints.all()
610-
self.plots.thrust()
610+
self.plots.thrust(filename=filename)
611611
return None
612612

613613
def all_info(self):

rocketpy/motors/liquid_motor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -466,9 +466,9 @@ def add_tank(self, tank, position):
466466
self.positioned_tanks.append({"tank": tank, "position": position})
467467
reset_funcified_methods(self)
468468

469-
def draw(self):
469+
def draw(self, filename=None):
470470
"""Draw a representation of the LiquidMotor."""
471-
self.plots.draw()
471+
self.plots.draw(filename)
472472

473473
def info(self):
474474
"""Prints out basic data about the Motor."""

rocketpy/motors/motor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,13 +1035,13 @@ def get_attr_value(obj, attr_name, multiplier=1):
10351035

10361036
return None
10371037

1038-
def info(self):
1038+
def info(self, filename=None):
10391039
"""Prints out a summary of the data and graphs available about the
10401040
Motor.
10411041
"""
10421042
# Print motor details
10431043
self.prints.all()
1044-
self.plots.thrust()
1044+
self.plots.thrust(filename=filename)
10451045
return None
10461046

10471047
@abstractmethod

rocketpy/motors/solid_motor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -699,9 +699,9 @@ def propellant_I_13(self):
699699
def propellant_I_23(self):
700700
return 0
701701

702-
def draw(self):
702+
def draw(self, filename=None):
703703
"""Draw a representation of the SolidMotor."""
704-
self.plots.draw()
704+
self.plots.draw(filename)
705705

706706
def info(self):
707707
"""Prints out basic data about the SolidMotor."""

rocketpy/motors/tank.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -485,9 +485,9 @@ def underfill_height_exception(param_name, param):
485485
elif (height < bottom_tolerance).any():
486486
underfill_height_exception(name, height)
487487

488-
def draw(self):
488+
def draw(self, filename=None):
489489
"""Draws the tank geometry."""
490-
self.plots.draw()
490+
self.plots.draw(filename)
491491

492492

493493
class MassFlowRateBasedTank(Tank):

rocketpy/plots/aero_surface_plots.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import numpy as np
55
from matplotlib.patches import Ellipse
66

7+
from .plot_helpers import show_or_save_plot
8+
79

810
class _AeroSurfacePlots(ABC):
911
"""Abstract class that contains all aero surface plots."""
@@ -24,7 +26,7 @@ def __init__(self, aero_surface):
2426
return None
2527

2628
@abstractmethod
27-
def draw(self):
29+
def draw(self, filename=None):
2830
pass
2931

3032
def lift(self):
@@ -70,10 +72,15 @@ def __init__(self, nosecone):
7072
super().__init__(nosecone)
7173
return None
7274

73-
def draw(self):
75+
def draw(self, filename=None):
7476
"""Draw the nosecone shape along with some important information,
7577
including the center line and the center of pressure position.
7678
79+
Parameters
80+
----------
81+
filename : str | None, optional
82+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: png, pdf, ps, eps and svg.
83+
7784
Returns
7885
-------
7986
None
@@ -141,7 +148,7 @@ def draw(self):
141148
ax.set_title(self.aero_surface.kind + " Nose Cone")
142149
ax.legend(bbox_to_anchor=(1, -0.2))
143150
# Show Plot
144-
plt.show()
151+
show_or_save_plot(filename)
145152
return None
146153

147154

@@ -165,7 +172,7 @@ def __init__(self, fin_set):
165172
return None
166173

167174
@abstractmethod
168-
def draw(self):
175+
def draw(self, filename=None):
169176
pass
170177

171178
def airfoil(self):
@@ -233,10 +240,15 @@ def __init__(self, fin_set):
233240
super().__init__(fin_set)
234241
return None
235242

236-
def draw(self):
243+
def draw(self, filename=None):
237244
"""Draw the fin shape along with some important information, including
238245
the center line, the quarter line and the center of pressure position.
239246
247+
Parameters
248+
----------
249+
filename : str | None, optional
250+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: png, pdf, ps, eps and svg.
251+
240252
Returns
241253
-------
242254
None
@@ -347,7 +359,7 @@ def draw(self):
347359
ax.legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
348360

349361
plt.tight_layout()
350-
plt.show()
362+
show_or_save_plot(filename)
351363
return None
352364

353365

@@ -358,10 +370,15 @@ def __init__(self, fin_set):
358370
super().__init__(fin_set)
359371
return None
360372

361-
def draw(self):
373+
def draw(self, filename=None):
362374
"""Draw the fin shape along with some important information.
363375
These being: the center line and the center of pressure position.
364376
377+
Parameters
378+
----------
379+
filename : str | None, optional
380+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: png, pdf, ps, eps and svg.
381+
365382
Returns
366383
-------
367384
None
@@ -422,7 +439,7 @@ def draw(self):
422439
ax.legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
423440

424441
plt.tight_layout()
425-
plt.show()
442+
show_or_save_plot(filename)
426443

427444
return None
428445

@@ -445,7 +462,7 @@ def __init__(self, tail):
445462
super().__init__(tail)
446463
return None
447464

448-
def draw(self):
465+
def draw(self, filename=None):
449466
# This will de done in the future
450467
return None
451468

@@ -474,7 +491,7 @@ def drag_coefficient_curve(self):
474491
else:
475492
return self.aero_surface.drag_coefficient.plot()
476493

477-
def draw(self):
494+
def draw(self, filename=None):
478495
raise NotImplementedError
479496

480497
def all(self):

rocketpy/plots/compare/compare_flights.py

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33

44
from .compare import Compare
55

6+
from ..plot_helpers import show_or_save_plot, show_or_save_fig
7+
68

79
class CompareFlights(Compare):
810
"""A class to compare the results of multiple flights.
@@ -90,11 +92,9 @@ def __process_savefig(self, filename, fig):
9092
-------
9193
None
9294
"""
95+
show_or_save_fig(fig, filename)
9396
if filename:
94-
fig.savefig(filename)
9597
print("Plot saved to file: " + filename)
96-
else:
97-
plt.show()
9898
return None
9999

100100
def __process_legend(self, legend, fig):
@@ -1276,10 +1276,7 @@ def compare_trajectories_3d(
12761276
fig1.tight_layout()
12771277

12781278
# Save figure
1279-
if filename:
1280-
plt.savefig(filename)
1281-
else:
1282-
plt.show()
1279+
show_or_save_plot(filename)
12831280

12841281
return None
12851282

@@ -1474,6 +1471,8 @@ def __plot_xz(
14741471
.svg, .pgf, .eps
14751472
figsize : tuple, optional
14761473
Tuple with the size of the figure. The default is (7, 7).
1474+
filename : str | None, optional
1475+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: png, pdf, ps, eps and svg.
14771476
14781477
Returns
14791478
-------
@@ -1517,10 +1516,7 @@ def __plot_xz(
15171516
fig.tight_layout()
15181517

15191518
# Save figure
1520-
if filename:
1521-
plt.savefig(filename)
1522-
else:
1523-
plt.show()
1519+
show_or_save_plot(filename)
15241520

15251521
return None
15261522

0 commit comments

Comments
 (0)