Skip to content

Commit 61d241d

Browse files
committed
ENH: Add filename option to most plot functions
1 parent fba6c8c commit 61d241d

21 files changed

+591
-184
lines changed

rocketpy/mathutils/function.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
import numpy as np
1717
from scipy import integrate, linalg, optimize
1818

19+
from ..plots.plot_helpers import show_or_save_plot
20+
1921
NUMERICAL_TYPES = (float, int, complex, np.ndarray, np.integer, np.floating)
2022
INTERPOLATION_TYPES = {
2123
"linear": 0,
@@ -1074,7 +1076,7 @@ def remove_outliers_iqr(self, threshold=1.5):
10741076
)
10751077

10761078
# Define all presentation methods
1077-
def __call__(self, *args):
1079+
def __call__(self, *args, filename=None):
10781080
"""Plot the Function if no argument is given. If an
10791081
argument is given, return the value of the function at the desired
10801082
point.
@@ -1088,13 +1090,15 @@ def __call__(self, *args):
10881090
evaluated at all points in the list and a list of floats will be
10891091
returned. If the function is N-D, N arguments must be given, each
10901092
one being an scalar or list.
1093+
filename : str | None, optional
1094+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff and webp.
10911095
10921096
Returns
10931097
-------
10941098
ans : None, scalar, list
10951099
"""
1096-
if len(args) == 0:
1097-
return self.plot()
1100+
if len(args) == 0 or (len(args) == 1 and filename != None):
1101+
return self.plot(filename=filename)
10981102
else:
10991103
return self.get_value(*args)
11001104

@@ -1155,8 +1159,11 @@ def plot(self, *args, **kwargs):
11551159
Function.plot_2d if Function is 2-Dimensional and forward arguments
11561160
and key-word arguments."""
11571161
if isinstance(self, list):
1162+
# Extract filename from kwargs
1163+
filename = kwargs.get("filename", None)
1164+
11581165
# Compare multiple plots
1159-
Function.compare_plots(self)
1166+
Function.compare_plots(self, filename)
11601167
else:
11611168
if self.__dom_dim__ == 1:
11621169
self.plot_1d(*args, **kwargs)
@@ -1184,6 +1191,7 @@ def plot_1d(
11841191
force_points=False,
11851192
return_object=False,
11861193
equal_axis=False,
1194+
filename=None,
11871195
):
11881196
"""Plot 1-Dimensional Function, from a lower limit to an upper limit,
11891197
by sampling the Function several times in the interval. The title of
@@ -1214,6 +1222,8 @@ def plot_1d(
12141222
Setting force_points to True will plot all points, as a scatter, in
12151223
which the Function was evaluated in the dataset. Default value is
12161224
False.
1225+
filename : str | None, optional
1226+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff and webp.
12171227
12181228
Returns
12191229
-------
@@ -1254,7 +1264,7 @@ def plot_1d(
12541264
plt.title(self.title)
12551265
plt.xlabel(self.__inputs__[0].title())
12561266
plt.ylabel(self.__outputs__[0].title())
1257-
plt.show()
1267+
show_or_save_plot(filename)
12581268
if return_object:
12591269
return fig, ax
12601270

@@ -1277,6 +1287,7 @@ def plot_2d(
12771287
disp_type="surface",
12781288
alpha=0.6,
12791289
cmap="viridis",
1290+
filename=None,
12801291
):
12811292
"""Plot 2-Dimensional Function, from a lower limit to an upper limit,
12821293
by sampling the Function several times in the interval. The title of
@@ -1316,6 +1327,8 @@ def plot_2d(
13161327
cmap : string, optional
13171328
Colormap of plotted graph, which can be any of the color maps
13181329
available in matplotlib. Default value is viridis.
1330+
filename : str | None, optional
1331+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff and webp.
13191332
13201333
Returns
13211334
-------
@@ -1389,7 +1402,7 @@ def plot_2d(
13891402
axes.set_xlabel(self.__inputs__[0].title())
13901403
axes.set_ylabel(self.__inputs__[1].title())
13911404
axes.set_zlabel(self.__outputs__[0].title())
1392-
plt.show()
1405+
show_or_save_plot(filename)
13931406

13941407
@staticmethod
13951408
def compare_plots(
@@ -1404,6 +1417,7 @@ def compare_plots(
14041417
force_points=False,
14051418
return_object=False,
14061419
show=True,
1420+
filename=None,
14071421
):
14081422
"""Plots N 1-Dimensional Functions in the same plot, from a lower
14091423
limit to an upper limit, by sampling the Functions several times in
@@ -1448,6 +1462,8 @@ def compare_plots(
14481462
False.
14491463
show : bool, optional
14501464
If True, shows the plot. Default value is True.
1465+
filename : str | None, optional
1466+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff and webp.
14511467
14521468
Returns
14531469
-------
@@ -1522,7 +1538,7 @@ def compare_plots(
15221538
plt.ylabel(ylabel)
15231539

15241540
if show:
1525-
plt.show()
1541+
show_or_save_plot(filename)
15261542

15271543
if return_object:
15281544
return fig, ax

rocketpy/motors/hybrid_motor.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -600,14 +600,34 @@ def add_tank(self, tank, position):
600600
)
601601
reset_funcified_methods(self)
602602

603-
def draw(self):
604-
"""Draws a representation of the HybridMotor."""
605-
self.plots.draw()
603+
def draw(self, filename=None):
604+
"""Draws a representation of the HybridMotor.
606605
607-
def info(self):
608-
"""Prints out basic data about the Motor."""
606+
Parameters
607+
----------
608+
filename : str | None, optional
609+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff and webp.
610+
611+
Returns
612+
-------
613+
None
614+
"""
615+
self.plots.draw(filename)
616+
617+
def info(self, filename=None):
618+
"""Prints out basic data about the Motor.
619+
620+
Parameters
621+
----------
622+
filename : str | None, optional
623+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff and webp.
624+
625+
Returns
626+
-------
627+
None
628+
"""
609629
self.prints.all()
610-
self.plots.thrust()
630+
self.plots.thrust(filename=filename)
611631
return None
612632

613633
def all_info(self):

rocketpy/motors/liquid_motor.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -466,9 +466,19 @@ def add_tank(self, tank, position):
466466
self.positioned_tanks.append({"tank": tank, "position": position})
467467
reset_funcified_methods(self)
468468

469-
def draw(self):
470-
"""Draw a representation of the LiquidMotor."""
471-
self.plots.draw()
469+
def draw(self, filename=None):
470+
"""Draw a representation of the LiquidMotor.
471+
472+
Parameters
473+
----------
474+
filename : str | None, optional
475+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff and webp.
476+
477+
Returns
478+
-------
479+
None
480+
"""
481+
self.plots.draw(filename)
472482

473483
def info(self):
474484
"""Prints out basic data about the Motor."""

rocketpy/motors/motor.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1035,13 +1035,22 @@ def get_attr_value(obj, attr_name, multiplier=1):
10351035

10361036
return None
10371037

1038-
def info(self):
1038+
def info(self, filename=None):
10391039
"""Prints out a summary of the data and graphs available about the
10401040
Motor.
1041+
1042+
Parameters
1043+
----------
1044+
filename : str | None, optional
1045+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff and webp.
1046+
1047+
Returns
1048+
-------
1049+
None
10411050
"""
10421051
# Print motor details
10431052
self.prints.all()
1044-
self.plots.thrust()
1053+
self.plots.thrust(filename=filename)
10451054
return None
10461055

10471056
@abstractmethod

rocketpy/motors/solid_motor.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -699,9 +699,19 @@ def propellant_I_13(self):
699699
def propellant_I_23(self):
700700
return 0
701701

702-
def draw(self):
703-
"""Draw a representation of the SolidMotor."""
704-
self.plots.draw()
702+
def draw(self, filename=None):
703+
"""Draw a representation of the SolidMotor.
704+
705+
Parameters
706+
----------
707+
filename : str | None, optional
708+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff and webp.
709+
710+
Returns
711+
-------
712+
None
713+
"""
714+
self.plots.draw(filename)
705715

706716
def info(self):
707717
"""Prints out basic data about the SolidMotor."""

rocketpy/motors/tank.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -485,9 +485,19 @@ def underfill_height_exception(param_name, param):
485485
elif (height < bottom_tolerance).any():
486486
underfill_height_exception(name, height)
487487

488-
def draw(self):
489-
"""Draws the tank geometry."""
490-
self.plots.draw()
488+
def draw(self, filename=None):
489+
"""Draws the tank geometry.
490+
491+
Parameters
492+
----------
493+
filename : str | None, optional
494+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff and webp.
495+
496+
Returns
497+
-------
498+
None
499+
"""
500+
self.plots.draw(filename)
491501

492502

493503
class MassFlowRateBasedTank(Tank):

rocketpy/plots/aero_surface_plots.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import numpy as np
55
from matplotlib.patches import Ellipse
66

7+
from .plot_helpers import show_or_save_plot
8+
79

810
class _AeroSurfacePlots(ABC):
911
"""Abstract class that contains all aero surface plots."""
@@ -24,7 +26,7 @@ def __init__(self, aero_surface):
2426
return None
2527

2628
@abstractmethod
27-
def draw(self):
29+
def draw(self, filename=None):
2830
pass
2931

3032
def lift(self):
@@ -70,10 +72,15 @@ def __init__(self, nosecone):
7072
super().__init__(nosecone)
7173
return None
7274

73-
def draw(self):
75+
def draw(self, filename=None):
7476
"""Draw the nosecone shape along with some important information,
7577
including the center line and the center of pressure position.
7678
79+
Parameters
80+
----------
81+
filename : str | None, optional
82+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff and webp.
83+
7784
Returns
7885
-------
7986
None
@@ -141,7 +148,7 @@ def draw(self):
141148
ax.set_title(self.aero_surface.kind + " Nose Cone")
142149
ax.legend(bbox_to_anchor=(1, -0.2))
143150
# Show Plot
144-
plt.show()
151+
show_or_save_plot(filename)
145152
return None
146153

147154

@@ -165,7 +172,7 @@ def __init__(self, fin_set):
165172
return None
166173

167174
@abstractmethod
168-
def draw(self):
175+
def draw(self, filename=None):
169176
pass
170177

171178
def airfoil(self):
@@ -233,10 +240,15 @@ def __init__(self, fin_set):
233240
super().__init__(fin_set)
234241
return None
235242

236-
def draw(self):
243+
def draw(self, filename=None):
237244
"""Draw the fin shape along with some important information, including
238245
the center line, the quarter line and the center of pressure position.
239246
247+
Parameters
248+
----------
249+
filename : str | None, optional
250+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff and webp.
251+
240252
Returns
241253
-------
242254
None
@@ -347,7 +359,7 @@ def draw(self):
347359
ax.legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
348360

349361
plt.tight_layout()
350-
plt.show()
362+
show_or_save_plot(filename)
351363
return None
352364

353365

@@ -358,10 +370,15 @@ def __init__(self, fin_set):
358370
super().__init__(fin_set)
359371
return None
360372

361-
def draw(self):
373+
def draw(self, filename=None):
362374
"""Draw the fin shape along with some important information.
363375
These being: the center line and the center of pressure position.
364376
377+
Parameters
378+
----------
379+
filename : str | None, optional
380+
The path the plot should be saved to. By default None, in which case the plot will be shown instead of saved. Supported file endings are: eps, jpg, jpeg, pdf, pgf, png, ps, raw, rgba, svg, svgz, tif, tiff and webp.
381+
365382
Returns
366383
-------
367384
None
@@ -422,7 +439,7 @@ def draw(self):
422439
ax.legend(bbox_to_anchor=(1.05, 1.0), loc="upper left")
423440

424441
plt.tight_layout()
425-
plt.show()
442+
show_or_save_plot(filename)
426443

427444
return None
428445

@@ -445,7 +462,7 @@ def __init__(self, tail):
445462
super().__init__(tail)
446463
return None
447464

448-
def draw(self):
465+
def draw(self, filename=None):
449466
# This will de done in the future
450467
return None
451468

@@ -474,7 +491,7 @@ def drag_coefficient_curve(self):
474491
else:
475492
return self.aero_surface.drag_coefficient.plot()
476493

477-
def draw(self):
494+
def draw(self, filename=None):
478495
raise NotImplementedError
479496

480497
def all(self):

0 commit comments

Comments
 (0)