Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions impyute/util/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@
from functools import wraps
# pylint:disable=invalid-name

# Python 2 compatibility: `ModuleNotFoundError` only exists as a builtin on
# Python 3.6+. When it is missing, define a stand-in so `except
# ModuleNotFoundError` clauses remain valid on older interpreters.
try:
    ModuleNotFoundError  # bare lookup: raises NameError if the builtin is absent
except NameError:
    class ModuleNotFoundError(ImportError):
        """Fallback mirroring the Python 3 builtin, an ImportError subclass."""


def preprocess(fn):
    """ Base preprocess function for commonly used preprocessing

    Wraps `fn` so that its first positional argument (the data) is prepared
    before `fn` is called:

    * the data is copied unless the caller passes a truthy `inplace` keyword;
    * if pandas is installed and the data is a DataFrame, `fn` receives the
      underlying numpy array and its result is wrapped in a new DataFrame.
      The new DataFrame is built from the output of `fn` alone, so it
      carries default row/column labels.

    Parameters
    ----------
    fn: callable
        Function that takes the data as its first positional argument.

    Returns
    -------
    callable
        The wrapped function. All positional and keyword arguments
        (including `inplace`) are forwarded to `fn`.

    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        """ Run input checks"""
        # convert tuple to list so args can be modified
        args = list(args)

        # Work on a copy unless the caller explicitly asked for in-place
        if not kwargs.get("inplace"):
            args[0] = args[0].copy()

        # pandas is an optional dependency. A missing module raises
        # ImportError (ModuleNotFoundError is a subclass of it on Python 3).
        try:
            import pandas as pd
            pd_DataFrame = pd.DataFrame
        except ImportError:
            pd_DataFrame = None

        # If pandas exists and the input data is a DataFrame, hand `fn` an
        # np.array and cast the output back to a DataFrame.
        if pd_DataFrame is not None and isinstance(args[0], pd_DataFrame):
            # `.values` instead of `.as_matrix()`, which pandas 1.0 removed
            args[0] = args[0].values
            return pd_DataFrame(fn(*args, **kwargs))
        return fn(*args, **kwargs)

    return wrapper
58 changes: 58 additions & 0 deletions test/util/test_preprocess.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
"""test_preprocess.py"""
import unittest
import numpy as np
from impyute.util import preprocess
from impyute.imputation.cs import mean

# Python 2 compatibility: `ModuleNotFoundError` only exists as a builtin on
# Python 3.6+. When it is missing, define a stand-in so `except
# ModuleNotFoundError` clauses remain valid on older interpreters.
try:
    ModuleNotFoundError  # bare lookup: raises NameError if the builtin is absent
except NameError:
    class ModuleNotFoundError(ImportError):
        """Fallback mirroring the Python 3 builtin, an ImportError subclass."""

class TestPreprocess(unittest.TestCase):
    """ Tests for the `preprocess` decorator"""
    def setUp(self):
        # A minimal decorated function: multiplies the data by 25.
        @preprocess
        def mul(arr, **kwargs):
            arr = arr * 25
            return arr
        self.mul = mul

    def test_inplace_false(self):
        """ inplace=False must leave the caller's array untouched """
        A = np.ones((5, 5))
        A_copy = A.copy()
        self.mul(A, inplace=False)
        self.assertTrue(np.array_equal(A, A_copy))

    @unittest.skip("Implementation of this is still buggy, kind of "
                   "works only, depending on input")
    def test_inplace_true(self):
        """ inplace=True is expected to modify the caller's array """
        A = np.ones((5, 5))
        A_copy = A.copy()
        # Must request in-place operation for this test to be meaningful
        self.mul(A, inplace=True)
        self.assertTrue(np.all(A != A_copy))

    def test_pandas_input(self):
        """ Input: DataFrame, Output: DataFrame """
        # Report a skip (not a pass) when pandas is unavailable
        try:
            import pandas as pd
        except ImportError:
            self.skipTest("pandas is not installed")

        # Create a DataFrame with a NaN. The builtin `float` replaces the
        # `np.float` alias, which NumPy 1.24 removed.
        A = np.arange(25).reshape((5, 5)).astype(float)
        A[0, 0] = np.nan
        A = pd.DataFrame(A)

        # Assert that the output is a DataFrame
        self.assertIsInstance(mean(A), pd.DataFrame)


# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()