Skip to content

Commit 2088b42

Browse files
authored
Match DataFrame.set_index with pandas (#6231)
1 parent c5aefae commit 2088b42

File tree

3 files changed

+271
-32
lines changed

3 files changed

+271
-32
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
- PR #6214 Small clean up to use more algorithms
6363
- PR #6209 Remove CXX11 ABI handling from CMake
6464
- PR #6223 Remove CXX11 ABI flag from JNI build
65+
- PR #6231 Adds `inplace`, `append`, `verify_integrity` fields to `DataFrame.set_index`
6566
- PR #6215 Add cmake command-line setting for spdlog logging level
6667
- PR #6242 Added cudf::detail::host_span and device_span
6768
- PR #6240 Don't shallow copy index in as_index() unless necessary

python/cudf/cudf/core/dataframe.py

Lines changed: 197 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from cudf.utils.dtypes import (
4141
cudf_dtype_from_pydata_dtype,
4242
is_categorical_dtype,
43+
is_column_like,
4344
is_list_dtype,
4445
is_list_like,
4546
is_scalar,
@@ -2627,42 +2628,209 @@ def reindex(
26272628

26282629
return DataFrame(cols, idx)
26292630

2630-
def set_index(self, index, drop=True):
2631+
def _set_index(
2632+
self, index, to_drop=None, inplace=False, verify_integrity=False,
2633+
):
2634+
"""Helper for `.set_index`
2635+
2636+
Parameters
2637+
----------
2638+
index : Index
2639+
The new index to set.
2640+
to_drop : list optional, default None
2641+
A list of labels indicating columns to drop.
2642+
inplace : boolean, default False
2643+
Modify the DataFrame in place (do not create a new object).
2644+
verify_integrity : boolean, default False
2645+
Check for duplicates in the new index.
2646+
"""
2647+
if not isinstance(index, Index):
2648+
raise ValueError("Parameter index should be type `Index`.")
2649+
2650+
df = self if inplace else self.copy(deep=True)
2651+
2652+
if verify_integrity and not index.is_unique:
2653+
raise ValueError(f"Values in Index are not unique: {index}")
2654+
2655+
if to_drop:
2656+
df.drop(columns=to_drop, inplace=True)
2657+
2658+
df.index = index
2659+
return df if not inplace else None
2660+
2661+
def set_index(
2662+
self,
2663+
index,
2664+
drop=True,
2665+
append=False,
2666+
inplace=False,
2667+
verify_integrity=False,
2668+
):
26312669
"""Return a new DataFrame with a new index
26322670
26332671
Parameters
26342672
----------
2635-
index : Index, Series-convertible, str, or list of str
2673+
index : Index, Series-convertible, label-like, or list
26362674
Index : the new index.
26372675
Series-convertible : values for the new index.
2638-
str : name of column to be used as series
2639-
list of str : name of columns to be converted to a MultiIndex
2640-
drop : boolean
2641-
whether to drop corresponding column for str index argument
2642-
"""
2643-
# When index is a list of column names
2644-
if isinstance(index, list):
2645-
if len(index) > 1:
2646-
df = self.copy(deep=False)
2647-
if drop:
2648-
df = df.drop(columns=index, axis=1)
2649-
return df.set_index(
2650-
cudf.MultiIndex.from_frame(self[index], names=index)
2651-
)
2652-
index = index[0] # List contains single item
2653-
2654-
# When index is a column name
2655-
if isinstance(index, str):
2656-
df = self.copy(deep=False)
2657-
if drop:
2658-
df._drop_column(index)
2659-
return df.set_index(self[index])
2660-
# Otherwise
2676+
Label-like : Label of column to be used as index.
2677+
List : List of items from above.
2678+
drop : boolean, default True
2679+
Whether to drop corresponding column for str index argument
2680+
append : boolean, default True
2681+
Whether to append columns to the existing index,
2682+
resulting in a MultiIndex.
2683+
inplace : boolean, default False
2684+
Modify the DataFrame in place (do not create a new object).
2685+
verify_integrity : boolean, default False
2686+
Check for duplicates in the new index.
2687+
2688+
Examples
2689+
--------
2690+
>>> df = cudf.DataFrame({"a": [1, 2, 3, 4, 5],
2691+
... "b": ["a", "b", "c", "d","e"],
2692+
... "c": [1.0, 2.0, 3.0, 4.0, 5.0]})
2693+
>>> df
2694+
a b c
2695+
0 1 a 1.0
2696+
1 2 b 2.0
2697+
2 3 c 3.0
2698+
3 4 d 4.0
2699+
4 5 e 5.0
2700+
2701+
Set the index to become the ‘b’ column:
2702+
2703+
>>> df.set_index('b')
2704+
a c
2705+
b
2706+
a 1 1.0
2707+
b 2 2.0
2708+
c 3 3.0
2709+
d 4 4.0
2710+
e 5 5.0
2711+
2712+
Create a MultiIndex using columns ‘a’ and ‘b’:
2713+
2714+
>>> df.set_index(["a", "b"])
2715+
c
2716+
a b
2717+
1 a 1.0
2718+
2 b 2.0
2719+
3 c 3.0
2720+
4 d 4.0
2721+
5 e 5.0
2722+
2723+
Set new Index instance as index:
2724+
2725+
>>> df.set_index(cudf.RangeIndex(10, 15))
2726+
a b c
2727+
10 1 a 1.0
2728+
11 2 b 2.0
2729+
12 3 c 3.0
2730+
13 4 d 4.0
2731+
14 5 e 5.0
2732+
2733+
Setting `append=True` will combine current index with column `a`:
2734+
2735+
>>> df.set_index("a", append=True)
2736+
b c
2737+
a
2738+
0 1 a 1.0
2739+
1 2 b 2.0
2740+
2 3 c 3.0
2741+
3 4 d 4.0
2742+
4 5 e 5.0
2743+
2744+
`set_index` supports `inplace` parameter too:
2745+
2746+
>>> df.set_index("a", inplace=True)
2747+
>>> df
2748+
b c
2749+
a
2750+
1 a 1.0
2751+
2 b 2.0
2752+
3 c 3.0
2753+
4 d 4.0
2754+
5 e 5.0
2755+
"""
2756+
2757+
if not isinstance(index, list):
2758+
index = [index]
2759+
2760+
# Preliminary type check
2761+
col_not_found = []
2762+
columns_to_add = []
2763+
names = []
2764+
to_drop = []
2765+
for i, col in enumerate(index):
2766+
# Is column label
2767+
if is_scalar(col) or isinstance(col, tuple):
2768+
if col in self.columns:
2769+
columns_to_add.append(self[col])
2770+
names.append(col)
2771+
if drop:
2772+
to_drop.append(col)
2773+
else:
2774+
col_not_found.append(col)
2775+
else:
2776+
# Try coerce into column
2777+
if not is_column_like(col):
2778+
try:
2779+
col = as_column(col)
2780+
except TypeError:
2781+
msg = f"{col} cannot be converted to column-like."
2782+
raise TypeError(msg)
2783+
if isinstance(col, (cudf.MultiIndex, pd.MultiIndex)):
2784+
col = (
2785+
cudf.from_pandas(col)
2786+
if isinstance(col, pd.MultiIndex)
2787+
else col
2788+
)
2789+
cols = [col._data[x] for x in col._data]
2790+
columns_to_add.extend(cols)
2791+
names.extend(col.names)
2792+
else:
2793+
if isinstance(col, (pd.RangeIndex, cudf.RangeIndex)):
2794+
# Corner case: RangeIndex does not need to instantiate
2795+
columns_to_add.append(col)
2796+
else:
2797+
# For pandas obj, convert to gpu obj
2798+
columns_to_add.append(as_column(col))
2799+
if isinstance(
2800+
col, (cudf.Series, cudf.Index, pd.Series, pd.Index)
2801+
):
2802+
names.append(col.name)
2803+
else:
2804+
names.append(None)
2805+
2806+
if col_not_found:
2807+
raise KeyError(f"None of {col_not_found} are in the columns")
2808+
2809+
if append:
2810+
idx_cols = [self.index._data[x] for x in self.index._data]
2811+
if isinstance(self.index, cudf.MultiIndex):
2812+
idx_names = self.index.names
2813+
else:
2814+
idx_names = [self.index.name]
2815+
columns_to_add = idx_cols + columns_to_add
2816+
names = idx_names + names
2817+
2818+
if len(columns_to_add) == 0:
2819+
raise ValueError("No valid columns to be added to index.")
2820+
elif len(columns_to_add) == 1:
2821+
idx = cudf.Index(columns_to_add[0], name=names[0])
26612822
else:
2662-
index = index if isinstance(index, Index) else as_index(index)
2663-
df = self.copy(deep=False)
2664-
df.index = index
2665-
return df
2823+
idf = cudf.DataFrame()
2824+
for i, col in enumerate(columns_to_add):
2825+
idf[i] = col
2826+
idx = cudf.MultiIndex.from_frame(idf, names=names)
2827+
2828+
return self._set_index(
2829+
index=idx,
2830+
to_drop=to_drop,
2831+
inplace=inplace,
2832+
verify_integrity=verify_integrity,
2833+
)
26662834

26672835
def reset_index(
26682836
self, level=None, drop=False, inplace=False, col_level=0, col_fill=""

python/cudf/cudf/tests/test_dataframe.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2290,10 +2290,80 @@ def test_reset_index_inplace(pdf, gdf, drop):
22902290
assert_eq(pdf, gdf)
22912291

22922292

2293+
@pytest.mark.parametrize(
2294+
"data",
2295+
[
2296+
{
2297+
"a": [1, 2, 3, 4, 5],
2298+
"b": ["a", "b", "c", "d", "e"],
2299+
"c": [1.0, 2.0, 3.0, 4.0, 5.0],
2300+
}
2301+
],
2302+
)
2303+
@pytest.mark.parametrize(
2304+
"index",
2305+
[
2306+
"a",
2307+
["a", "b"],
2308+
pd.CategoricalIndex(["I", "II", "III", "IV", "V"]),
2309+
pd.Series(["h", "i", "k", "l", "m"]),
2310+
["b", pd.Index(["I", "II", "III", "IV", "V"])],
2311+
["c", [11, 12, 13, 14, 15]],
2312+
pd.MultiIndex(
2313+
levels=[
2314+
["I", "II", "III", "IV", "V"],
2315+
["one", "two", "three", "four", "five"],
2316+
],
2317+
codes=[[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]],
2318+
names=["col1", "col2"],
2319+
),
2320+
pd.RangeIndex(0, 5), # corner case
2321+
[pd.Series(["h", "i", "k", "l", "m"]), pd.RangeIndex(0, 5)],
2322+
[
2323+
pd.MultiIndex(
2324+
levels=[
2325+
["I", "II", "III", "IV", "V"],
2326+
["one", "two", "three", "four", "five"],
2327+
],
2328+
codes=[[0, 1, 2, 3, 4], [4, 3, 2, 1, 0]],
2329+
names=["col1", "col2"],
2330+
),
2331+
pd.RangeIndex(0, 5),
2332+
],
2333+
],
2334+
)
22932335
@pytest.mark.parametrize("drop", [True, False])
2294-
def test_set_index(pdf, gdf, drop):
2295-
for col in pdf.columns:
2296-
assert_eq(pdf.set_index(col, drop=drop), gdf.set_index(col, drop=drop))
2336+
@pytest.mark.parametrize("append", [True, False])
2337+
@pytest.mark.parametrize("inplace", [True, False])
2338+
def test_set_index(data, index, drop, append, inplace):
2339+
gdf = gd.DataFrame(data)
2340+
pdf = gdf.to_pandas()
2341+
2342+
expected = pdf.set_index(index, inplace=inplace, drop=drop, append=append)
2343+
actual = gdf.set_index(index, inplace=inplace, drop=drop, append=append)
2344+
2345+
if inplace:
2346+
expected = pdf
2347+
actual = gdf
2348+
assert_eq(expected, actual)
2349+
2350+
2351+
@pytest.mark.parametrize(
2352+
"data",
2353+
[
2354+
{
2355+
"a": [1, 1, 2, 2, 5],
2356+
"b": ["a", "b", "c", "d", "e"],
2357+
"c": [1.0, 2.0, 3.0, 4.0, 5.0],
2358+
}
2359+
],
2360+
)
2361+
@pytest.mark.parametrize("index", ["a", pd.Index([1, 1, 2, 2, 3])])
2362+
@pytest.mark.parametrize("verify_integrity", [True])
2363+
@pytest.mark.xfail
2364+
def test_set_index_verify_integrity(data, index, verify_integrity):
2365+
gdf = gd.DataFrame(data)
2366+
gdf.set_index(index, verify_integrity=verify_integrity)
22972367

22982368

22992369
@pytest.mark.parametrize("drop", [True, False])

0 commit comments

Comments
 (0)