-
-
Notifications
You must be signed in to change notification settings - Fork 1.2k
ENH: Plotting for groupby_bins #2152
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 18 commits
78c077c
7b400fa
4175bbf
2d11c10
e43f0b0
0a15f07
a63d68a
347740b
73f790a
ecb0935
b4d05e7
e77e996
6d9416d
ce407cd
389f63b
0dcbf50
3898394
0217b29
447aea3
b87d0f6
98bc369
87ef1cc
826df44
ea6f6df
1c2d6d6
a255857
e60728e
448d6b7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,6 +8,7 @@ | |
from __future__ import absolute_import, division, print_function | ||
|
||
import functools | ||
import itertools | ||
import warnings | ||
from datetime import datetime | ||
|
||
|
@@ -48,14 +49,69 @@ def _ensure_plottable(*args): | |
axis. | ||
""" | ||
numpy_types = [np.floating, np.integer, np.timedelta64, np.datetime64] | ||
other_types = [datetime] | ||
other_types = [datetime, pd.Interval] | ||
|
||
for x in args: | ||
if not (_valid_numpy_subdtype(np.array(x), numpy_types) or | ||
_valid_other_type(np.array(x), other_types)): | ||
raise TypeError('Plotting requires coordinates to be numeric ' | ||
'or dates of type np.datetime64 or ' | ||
'datetime.datetime.') | ||
'datetime.datetime or pd.Interval.') | ||
|
||
|
||
def _interval_to_mid_points(array): | ||
""" | ||
Helper function which returns an array | ||
with the Intervals' mid points. | ||
""" | ||
|
||
return np.array([x.mid for x in array]) | ||
|
||
|
||
def _interval_to_bound_points(array): | ||
""" | ||
Helper function which returns an array | ||
with the Intervals' boundaries. | ||
""" | ||
|
||
array_boundaries = np.array([x.left for x in array]) | ||
array_boundaries = np.concatenate( | ||
(array_boundaries, np.array([array[-1].right]))) | ||
|
||
return array_boundaries | ||
|
||
|
||
def _interval_to_double_bound_points(xarray, yarray): | ||
""" | ||
Helper function to deal with a xarray consisting of pd.Intervals. Each | ||
interval is replaced with both boundaries. I.e. the length of xarray | ||
doubles. yarray is modified so it matches the new shape of xarray. | ||
""" | ||
|
||
xarray1 = np.array([x.left for x in xarray]) | ||
xarray2 = np.array([x.right for x in xarray]) | ||
|
||
xarray = list(itertools.chain.from_iterable(zip(xarray1, xarray2))) | ||
yarray = list(itertools.chain.from_iterable(zip(yarray, yarray))) | ||
|
||
return xarray, yarray | ||
|
||
|
||
def _resolve_intervals_2dplot(val, func_name): | ||
""" | ||
Helper function to replace the values of a coordinate array containing | ||
pd.Interval with their mid-points or - for pcolormesh - boundaries which | ||
increases length by 1. | ||
""" | ||
label_extra = '' | ||
if _valid_other_type(val, [pd.Interval]): | ||
if func_name == 'pcolormesh': | ||
val = _interval_to_bound_points(val) | ||
else: | ||
val = _interval_to_mid_points(val) | ||
label_extra = '_center' | ||
|
||
return val, label_extra | ||
|
||
|
||
def _easy_facetgrid(darray, plotfunc, x, y, row=None, col=None, | ||
|
@@ -317,7 +373,28 @@ def line(darray, *args, **kwargs): | |
|
||
_ensure_plottable(xplt) | ||
|
||
primitive = ax.plot(xplt, yplt, *args, **kwargs) | ||
# Remove pd.Intervals if contained in xplt.values. | ||
if _valid_other_type(xplt.values, [pd.Interval]): | ||
# Is it a step plot? | ||
if kwargs.get('linestyle', '').startswith('steps-'): | ||
xplt_val, yplt_val = _interval_to_double_bound_points(xplt.values, | ||
yplt.values) | ||
# just to be sure that matplotlib is not confused | ||
kwargs['linestyle'] = kwargs['linestyle'].replace( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I know this is quite ugly, but does it make sense to import re only for this one line? |
||
'steps-pre', '').replace( | ||
'steps-post', '').replace( | ||
'steps-mid', '') | ||
if kwargs['linestyle'] == '': | ||
kwargs.pop('linestyle') | ||
else: | ||
xplt_val = _interval_to_mid_points(xplt.values) | ||
yplt_val = yplt.values | ||
xlabel += '_center' | ||
else: | ||
xplt_val = xplt.values | ||
yplt_val = yplt.values | ||
|
||
primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) | ||
|
||
if _labels: | ||
if xlabel is not None: | ||
|
@@ -347,6 +424,46 @@ def line(darray, *args, **kwargs): | |
return primitive | ||
|
||
|
||
def step(darray, *args, **kwargs): | ||
""" | ||
Step plot of DataArray index against values | ||
|
||
Similar to :func:`matplotlib:matplotlib.pyplot.step` | ||
|
||
Parameters | ||
---------- | ||
where : {'pre', 'post', 'mid'}, optional, default 'pre' | ||
Define where the steps should be placed: | ||
- 'pre': The y value is continued constantly to the left from | ||
every *x* position, i.e. the interval ``(x[i-1], x[i]]`` has the | ||
value ``y[i]``. | ||
- 'post': The y value is continued constantly to the right from | ||
every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the | ||
value ``y[i]``. | ||
- 'mid': Steps occur half-way between the *x* positions. | ||
Note that this parameter is ignored if the x coordinate consists of | ||
:py:func:`pandas.Interval` values, e.g. as a result of | ||
:py:func:`xarray.Dataset.groupby_bins`. In this case, the actual | ||
boundaries of the interval are used. | ||
|
||
*args, **kwargs : optional | ||
Additional arguments following :py:func:`xarray.plot.line` | ||
|
||
""" | ||
if ('ls' in kwargs.keys()) and ('linestyle' not in kwargs.keys()): | ||
kwargs['linestyle'] = kwargs.pop('ls') | ||
|
||
where = kwargs.pop('where', 'pre') | ||
|
||
if where not in ('pre', 'post', 'mid'): | ||
raise ValueError("'where' argument to step must be " | ||
"'pre', 'post' or 'mid'") | ||
|
||
kwargs['linestyle'] = 'steps-' + where + kwargs.get('linestyle', '') | ||
|
||
return line(darray, *args, **kwargs) | ||
|
||
|
||
def hist(darray, figsize=None, size=None, aspect=None, ax=None, **kwargs): | ||
""" | ||
Histogram of DataArray | ||
|
@@ -432,6 +549,10 @@ def hist(self, ax=None, **kwargs): | |
def line(self, *args, **kwargs): | ||
return line(self._da, *args, **kwargs) | ||
|
||
@functools.wraps(step) | ||
def step(self, *args, **kwargs): | ||
return step(self._da, *args, **kwargs) | ||
|
||
|
||
def _rescale_imshow_rgb(darray, vmin, vmax, robust): | ||
assert robust or vmin is not None or vmax is not None | ||
|
@@ -661,6 +782,10 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, | |
|
||
_ensure_plottable(xval, yval) | ||
|
||
# Replace pd.Intervals if contained in xval or yval. | ||
xplt, xlab_extra = _resolve_intervals_2dplot(xval, plotfunc.__name__) | ||
yplt, ylab_extra = _resolve_intervals_2dplot(yval, plotfunc.__name__) | ||
|
||
if 'contour' in plotfunc.__name__ and levels is None: | ||
levels = 7 # this is the matplotlib default | ||
|
||
|
@@ -696,15 +821,15 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, | |
"in xarray") | ||
|
||
ax = get_axis(figsize, size, aspect, ax) | ||
primitive = plotfunc(xval, yval, zval, ax=ax, cmap=cmap_params['cmap'], | ||
primitive = plotfunc(xplt, yplt, zval, ax=ax, cmap=cmap_params['cmap'], | ||
vmin=cmap_params['vmin'], | ||
vmax=cmap_params['vmax'], | ||
**kwargs) | ||
|
||
# Label the plot with metadata | ||
if add_labels: | ||
ax.set_xlabel(label_from_attrs(darray[xlab])) | ||
ax.set_ylabel(label_from_attrs(darray[ylab])) | ||
ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra)) | ||
ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra)) | ||
ax.set_title(darray._title_for_slice()) | ||
|
||
if add_colorbar: | ||
|
@@ -725,7 +850,7 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, | |
_update_axes_limits(ax, xincrease, yincrease) | ||
|
||
# Rotate dates on xlabels | ||
if np.issubdtype(xval.dtype, np.datetime64): | ||
if np.issubdtype(xplt.dtype, np.datetime64): | ||
ax.get_figure().autofmt_xdate() | ||
|
||
return primitive | ||
|
@@ -919,14 +1044,22 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): | |
else: | ||
infer_intervals = True | ||
|
||
if infer_intervals: | ||
if (infer_intervals and | ||
((np.shape(x)[0] == np.shape(z)[1]) or | ||
((x.ndim > 1) and (np.shape(x)[1] == np.shape(z)[1])))): | ||
if len(x.shape) == 1: | ||
x = _infer_interval_breaks(x) | ||
y = _infer_interval_breaks(y) | ||
else: | ||
# we have to infer the intervals on both axes | ||
x = _infer_interval_breaks(x, axis=1) | ||
x = _infer_interval_breaks(x, axis=0) | ||
|
||
if (infer_intervals and | ||
(np.shape(y)[0] == np.shape(z)[0])): | ||
if len(y.shape) == 1: | ||
y = _infer_interval_breaks(y) | ||
else: | ||
# we have to infer the intervals on both axes | ||
y = _infer_interval_breaks(y, axis=1) | ||
y = _infer_interval_breaks(y, axis=0) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -319,6 +319,10 @@ def test_convenient_facetgrid_4d(self): | |
with raises_regex(ValueError, '[Ff]acet'): | ||
d.plot(x='x', y='y', col='columns', ax=plt.gca()) | ||
|
||
def test_coord_with_interval(self): | ||
bins = [-1, 0, 1, 2] | ||
self.darray.groupby_bins('dim_0', bins).mean().plot() | ||
|
||
|
||
class TestPlot1D(PlotTestCase): | ||
def setUp(self): | ||
|
@@ -392,6 +396,19 @@ def test_slice_in_title(self): | |
assert 'd = 10' == title | ||
|
||
|
||
class TestPlotStep(PlotTestCase): | ||
def setUp(self): | ||
self.darray = DataArray(easy_array((2, 3, 4))) | ||
|
||
def test_step(self): | ||
self.darray[0, 0].plot.step() | ||
|
||
def test_coord_with_interval_step(self): | ||
bins = [-1, 0, 1, 2] | ||
self.darray.groupby_bins('dim_0', bins).mean().plot.step() | ||
assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2) | ||
|
||
|
||
class TestPlotHistogram(PlotTestCase): | ||
def setUp(self): | ||
self.darray = DataArray(easy_array((2, 3, 4))) | ||
|
@@ -430,6 +447,10 @@ def test_plot_nans(self): | |
self.darray[0, 0, 0] = np.nan | ||
self.darray.plot.hist() | ||
|
||
def test_hist_coord_with_interval(self): | ||
self.darray.groupby_bins('dim_0', [-1, 0, 1, 2]).mean().plot.hist( | ||
range=(-1, 2)) | ||
|
||
|
||
@requires_matplotlib | ||
class TestDetermineCmapParams(TestCase): | ||
|
@@ -1007,6 +1028,12 @@ def test_cmap_and_color_both(self): | |
with pytest.raises(ValueError): | ||
self.plotmethod(colors='k', cmap='RdBu') | ||
|
||
def test_2d_coord_with_interval(self): | ||
for dim in self.darray.dims: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I left the loop here, because for the 2d plots, x and y axis are treated separately. |
||
gp = self.darray.groupby_bins(dim, range(15)).mean(dim) | ||
for kind in ['imshow', 'pcolormesh', 'contourf', 'contour']: | ||
getattr(gp.plot, kind)() | ||
|
||
|
||
@pytest.mark.slow | ||
class TestContourf(Common2dMixin, PlotTestCase): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"step lot" should be step plot?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
of course, thanks!