diff --git a/impyute/util/preprocess.py b/impyute/util/preprocess.py index c87ed91..b2bd1e0 100644 --- a/impyute/util/preprocess.py +++ b/impyute/util/preprocess.py @@ -2,6 +2,16 @@ from functools import wraps # pylint:disable=invalid-name +# TODO:Some hacky ass code to handle python2 not having `ModuleNotFoundError` +try: + raise ModuleNotFoundError +except NameError: + class ModuleNotFoundError(Exception): + pass +except ModuleNotFoundError: + pass + + def preprocess(fn): """ Base preprocess function for commonly used preprocessing @@ -19,14 +29,29 @@ def preprocess(fn): @wraps(fn) def wrapper(*args, **kwargs): """ Run input checks""" + # convert tuple to list so args can be modified + args = list(args) + + # Either make a copy or use a pointer to the original if "inplace" in kwargs and kwargs['inplace']: - data = args[0] + args[0] = args[0] else: - data = args[0].copy() + args[0] = args[0].copy() - if len(args) == 1: - return fn(data, **kwargs) - return fn(data, *args[1:], **kwargs) - return wrapper + # Check if Pandas exists + try: + import pandas as pd + pd_DataFrame = pd.DataFrame + except (ModuleNotFoundError, ImportError): + pd_DataFrame = None + # If Pandas exists, and the input data is a dataframe + # then cast the input to an np.array and cast the output + # back to a DataFrame. + if pd_DataFrame and isinstance(args[0], pd_DataFrame): + args[0] = args[0].as_matrix() + return pd_DataFrame(fn(*args, **kwargs)) + else: + return fn(*args, **kwargs) + return wrapper diff --git a/test/util/test_preprocess.py b/test/util/test_preprocess.py new file mode 100644 index 0000000..a0af11b --- /dev/null +++ b/test/util/test_preprocess.py @@ -0,0 +1,58 @@ +"""test_preprocess.py""" +import unittest +import numpy as np +from impyute.util import preprocess +from impyute.imputation.cs import mean + +# TODO:Some hacky ass code to handle python2 not having `ModuleNotFoundError` +try: + raise ModuleNotFoundError +except NameError: + class ModuleNotFoundError(Exception): + pass +except ModuleNotFoundError: + pass + +class TestPreprocess(unittest.TestCase): + """ Tests for checks""" + def setUp(self): + @preprocess + def mul(arr, **kwargs): + arr = arr * 25 + return arr + self.mul = mul + + def test_inplace_false(self): + A = np.ones((5, 5)) + A_copy = A.copy() + self.mul(A, inplace=False) + assert all(map(all, A == A_copy)) + + @unittest.skip("Implementation of this is still buggy, kind of \ + works only, depending on input") + def test_inplace_true(self): + A = np.ones((5, 5)) + A_copy = A.copy() + self.mul(A, inplace=False) + assert all(map(all, A != A_copy)) + + def test_pandas_input(self): + """ Input: DataFrame, Output: DataFrame """ + # Skip this test if you don't have pandas + try: + import pandas as pd + except (ModuleNotFoundError, ImportError): + return True + + # Create a DataFrame with a NaN + A = np.arange(25).reshape((5,5)).astype(np.float) + A[0][0] = np.nan + A = pd.DataFrame(A) + + # Assert that the output is a DataFrame + assert isinstance(mean(A), pd.DataFrame) + + +if __name__ == "__main__": + unittest.main() +