Skip to content

Commit eaa6aaf

Browse files
timhoffmmroeschkedatapythonista
authored
BUG: Copy attrs on pd.merge() (#60357)
* BUG: Copy attrs on pd.merge() This uses the same logic as `pd.concat()`: Copy `attrs` only if all input `attrs` are identical. I've refactored the handling in __finalize__ from special-casing based on th the method name (previously only "concat") to handling "other" parameters that have an `input_objs` attribute. This is a more scalable architecture compared to hard-coding method names in __finalize__. Tests added for `concat()` and `merge()`. Closes #60351. * Update docs * Add release note * Fix wrong link Co-authored-by: Matthew Roeschke <[email protected]> * Move release note to v3.0.0 --------- Co-authored-by: Matthew Roeschke <[email protected]> Co-authored-by: Marc Garcia <[email protected]>
1 parent 52e00a3 commit eaa6aaf

File tree

8 files changed

+54
-18
lines changed

8 files changed

+54
-18
lines changed

doc/source/whatsnew/v3.0.0.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,9 @@ Enhancement2
2828

2929
Other enhancements
3030
^^^^^^^^^^^^^^^^^^
31+
- :func:`pandas.merge` propagates the ``attrs`` attribute to the result if all
32+
inputs have identical ``attrs``, as has so far already been the case for
33+
:func:`pandas.concat`.
3134
- :class:`pandas.api.typing.FrozenList` is available for typing the outputs of :attr:`MultiIndex.names`, :attr:`MultiIndex.codes` and :attr:`MultiIndex.levels` (:issue:`58237`)
3235
- :class:`pandas.api.typing.SASReader` is available for typing the output of :func:`read_sas` (:issue:`55689`)
3336
- Added :meth:`.Styler.to_typst` to write Styler objects to file, buffer or string in Typst format (:issue:`57617`)

pandas/core/generic.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -330,8 +330,8 @@ def attrs(self) -> dict[Hashable, Any]:
330330
-----
331331
Many operations that create new datasets will copy ``attrs``. Copies
332332
are always deep so that changing ``attrs`` will only affect the
333-
present dataset. ``pandas.concat`` copies ``attrs`` only if all input
334-
datasets have the same ``attrs``.
333+
present dataset. :func:`pandas.concat` and :func:`pandas.merge` will
334+
only copy ``attrs`` if all input datasets have the same ``attrs``.
335335
336336
Examples
337337
--------
@@ -6090,11 +6090,11 @@ def __finalize__(self, other, method: str | None = None, **kwargs) -> Self:
60906090
assert isinstance(name, str)
60916091
object.__setattr__(self, name, getattr(other, name, None))
60926092

6093-
if method == "concat":
6094-
objs = other.objs
6095-
# propagate attrs only if all concat arguments have the same attrs
6093+
elif hasattr(other, "input_objs"):
6094+
objs = other.input_objs
6095+
# propagate attrs only if all inputs have the same attrs
60966096
if all(bool(obj.attrs) for obj in objs):
6097-
# all concatenate arguments have non-empty attrs
6097+
# all inputs have non-empty attrs
60986098
attrs = objs[0].attrs
60996099
have_same_attrs = all(obj.attrs == attrs for obj in objs[1:])
61006100
if have_same_attrs:

pandas/core/reshape/concat.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def _get_result(
550550
result = sample._constructor_from_mgr(mgr, axes=mgr.axes)
551551
result._name = name
552552
return result.__finalize__(
553-
types.SimpleNamespace(objs=objs), method="concat"
553+
types.SimpleNamespace(input_objs=objs), method="concat"
554554
)
555555

556556
# combine as columns in a frame
@@ -571,7 +571,9 @@ def _get_result(
571571
)
572572
df = cons(data, index=index, copy=False)
573573
df.columns = columns
574-
return df.__finalize__(types.SimpleNamespace(objs=objs), method="concat")
574+
return df.__finalize__(
575+
types.SimpleNamespace(input_objs=objs), method="concat"
576+
)
575577

576578
# combine block managers
577579
else:
@@ -610,7 +612,7 @@ def _get_result(
610612
)
611613

612614
out = sample._constructor_from_mgr(new_data, axes=new_data.axes)
613-
return out.__finalize__(types.SimpleNamespace(objs=objs), method="concat")
615+
return out.__finalize__(types.SimpleNamespace(input_objs=objs), method="concat")
614616

615617

616618
def new_axes(

pandas/core/reshape/merge.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
)
1111
import datetime
1212
from functools import partial
13+
import types
1314
from typing import (
1415
TYPE_CHECKING,
1516
Literal,
@@ -1134,7 +1135,10 @@ def get_result(self) -> DataFrame:
11341135
join_index, left_indexer, right_indexer = self._get_join_info()
11351136

11361137
result = self._reindex_and_concat(join_index, left_indexer, right_indexer)
1137-
result = result.__finalize__(self, method=self._merge_type)
1138+
result = result.__finalize__(
1139+
types.SimpleNamespace(input_objs=[self.left, self.right]),
1140+
method=self._merge_type,
1141+
)
11381142

11391143
if self.indicator:
11401144
result = self._indicator_post_merge(result)
@@ -1143,7 +1147,9 @@ def get_result(self) -> DataFrame:
11431147

11441148
self._maybe_restore_index_levels(result)
11451149

1146-
return result.__finalize__(self, method="merge")
1150+
return result.__finalize__(
1151+
types.SimpleNamespace(input_objs=[self.left, self.right]), method="merge"
1152+
)
11471153

11481154
@final
11491155
@cache_readonly

pandas/tests/frame/test_api.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ def test_attrs(self):
315315
result = df.rename(columns=str)
316316
assert result.attrs == {"version": 1}
317317

318-
def test_attrs_deepcopy(self):
318+
def test_attrs_is_deepcopy(self):
319319
df = DataFrame({"A": [2, 3]})
320320
assert df.attrs == {}
321321
df.attrs["tags"] = {"spam", "ham"}
@@ -324,6 +324,30 @@ def test_attrs_deepcopy(self):
324324
assert result.attrs == df.attrs
325325
assert result.attrs["tags"] is not df.attrs["tags"]
326326

327+
def test_attrs_concat(self):
328+
# concat propagates attrs if all input attrs are equal
329+
df1 = DataFrame({"A": [2, 3]})
330+
df1.attrs = {"a": 1, "b": 2}
331+
df2 = DataFrame({"A": [4, 5]})
332+
df2.attrs = df1.attrs.copy()
333+
df3 = DataFrame({"A": [6, 7]})
334+
df3.attrs = df1.attrs.copy()
335+
assert pd.concat([df1, df2, df3]).attrs == df1.attrs
336+
# concat does not propagate attrs if input attrs are different
337+
df2.attrs = {"c": 3}
338+
assert pd.concat([df1, df2, df3]).attrs == {}
339+
340+
def test_attrs_merge(self):
341+
# merge propagates attrs if all input attrs are equal
342+
df1 = DataFrame({"key": ["a", "b"], "val1": [1, 2]})
343+
df1.attrs = {"a": 1, "b": 2}
344+
df2 = DataFrame({"key": ["a", "b"], "val2": [3, 4]})
345+
df2.attrs = df1.attrs.copy()
346+
assert pd.merge(df1, df2).attrs == df1.attrs
347+
# merge does not propagate attrs if input attrs are different
348+
df2.attrs = {"c": 3}
349+
assert pd.merge(df1, df2).attrs == {}
350+
327351
@pytest.mark.parametrize("allows_duplicate_labels", [True, False, None])
328352
def test_set_flags(
329353
self,

pandas/tests/generic/test_duplicate_labels.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,6 @@ def test_concat(self, objs, kwargs):
164164
allows_duplicate_labels=False
165165
),
166166
False,
167-
marks=not_implemented,
168167
),
169168
# false true false
170169
pytest.param(
@@ -173,7 +172,6 @@ def test_concat(self, objs, kwargs):
173172
),
174173
pd.DataFrame({"B": [0, 1]}, index=["a", "d"]),
175174
False,
176-
marks=not_implemented,
177175
),
178176
# true true true
179177
(
@@ -296,7 +294,6 @@ def test_concat_raises(self):
296294
with pytest.raises(pd.errors.DuplicateLabelError, match=msg):
297295
pd.concat(objs, axis=1)
298296

299-
@not_implemented
300297
def test_merge_raises(self):
301298
a = pd.DataFrame({"A": [0, 1, 2]}, index=["a", "b", "c"]).set_flags(
302299
allows_duplicate_labels=False

pandas/tests/generic/test_frame.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,12 +80,16 @@ def test_metadata_propagation_indiv(self, monkeypatch):
8080
def finalize(self, other, method=None, **kwargs):
8181
for name in self._metadata:
8282
if method == "merge":
83-
left, right = other.left, other.right
83+
left, right = other.input_objs
8484
value = getattr(left, name, "") + "|" + getattr(right, name, "")
8585
object.__setattr__(self, name, value)
8686
elif method == "concat":
8787
value = "+".join(
88-
[getattr(o, name) for o in other.objs if getattr(o, name, None)]
88+
[
89+
getattr(o, name)
90+
for o in other.input_objs
91+
if getattr(o, name, None)
92+
]
8993
)
9094
object.__setattr__(self, name, value)
9195
else:

pandas/tests/generic/test_series.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def finalize(self, other, method=None, **kwargs):
9797
value = "+".join(
9898
[
9999
getattr(obj, name)
100-
for obj in other.objs
100+
for obj in other.input_objs
101101
if getattr(obj, name, None)
102102
]
103103
)

0 commit comments

Comments
 (0)