Skip to content

add keep_matlab_shapes to oct2py #334

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions oct2py/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class Oct2Py:
If true, convert integer types to float when passing to Octave.
backend: string, optional
The graphics_toolkit to use for plotting.
keep_matlab_shapes: bool, optional
If true, matlab shapes will be preserved (scalars as (1,1) etc)
"""

def __init__( # noqa
Expand All @@ -68,6 +70,7 @@ def __init__( # noqa
temp_dir=None,
convert_to_float=True,
backend=None,
keep_matlab_shapes=False,
):
"""Start Octave and set up the session."""
self._oned_as = oned_as
Expand All @@ -76,6 +79,7 @@ def __init__( # noqa
self.logger = logger
self.timeout = timeout
self.backend = backend or "default"
self.keep_matlab_shapes = keep_matlab_shapes
if temp_dir is None:
temp_dir_obj = tempfile.mkdtemp()
self.temp_dir = temp_dir_obj
Expand Down
48 changes: 24 additions & 24 deletions oct2py/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,17 @@ class DataFrame: # type:ignore[no-redef]
_WRITE_LOCK = threading.Lock()


def read_file(path, session=None):
def read_file(path, session=None, keep_matlab_shapes=False):
"""Read the data from the given file path."""
if session:
keep_matlab_shapes = keep_matlab_shapes or session.keep_matlab_shapes
try:
data = loadmat(path, struct_as_record=True)
except UnicodeDecodeError as e:
raise Oct2PyError(str(e)) from None
out = {}
for key, value in data.items():
out[key] = _extract(value, session)
out[key] = _extract(value, session, keep_matlab_shapes)
return out


Expand Down Expand Up @@ -151,22 +153,19 @@ class StructArray(np.recarray): # type:ignore[type-arg]
4.0
"""

def __new__(cls, value, session=None):
def __new__(cls, value, session=None, keep_matlab_shapes=False):
"""Create a struct array from a value and optional Octave session."""
value = np.asarray(value)
# Squeeze the last element if it is 1
if value.shape[value.ndim - 1] == 1:
if value.shape[value.ndim - 1] == 1 and not keep_matlab_shapes:
value = value.squeeze(axis=value.ndim - 1)
value = np.atleast_1d(value)

if not session:
return value.view(cls)

# Extract the values.
obj = np.empty(value.size, dtype=value.dtype).view(cls)
for i, item in enumerate(value.ravel()):
for name in value.dtype.names:
obj[i][name] = _extract(item[name], session)
obj[i][name] = _extract(item[name], session, keep_matlab_shapes)
return obj.reshape(value.shape)

@property
Expand Down Expand Up @@ -224,20 +223,16 @@ class Cell(np.ndarray): # type:ignore[type-arg]
[1.0, 1.0]
"""

def __new__(cls, value, session=None):
def __new__(cls, value, session=None, keep_matlab_shapes=False):
"""Create a cell array from a value and optional Octave session."""
# Use atleast_2d to preserve Octave size()
value = np.atleast_2d(np.asarray(value, dtype=object))

if not session:
return value.view(cls)

# Extract the values.
obj = np.empty(value.size, dtype=object).view(cls)
for i, item in enumerate(value.ravel()):
obj[i] = _extract(item, session)
obj = obj.reshape(value.shape) # type:ignore[assignment]

obj[i] = _extract(item, session, keep_matlab_shapes)
obj = obj.reshape(value.shape)
return obj

def __repr__(self):
Expand All @@ -258,11 +253,11 @@ def __getitem__(self, key):
return super().__getitem__(key)


def _extract(data, session=None): # noqa
def _extract(data, session=None, keep_matlab_shapes=False): # noqa
"""Convert the Octave values to values suitable for Python."""
# Extract each item of a list.
if isinstance(data, list):
return [_extract(d, session) for d in data]
return [_extract(d, session, keep_matlab_shapes) for d in data]

# Ignore leaf objects.
if not isinstance(data, np.ndarray):
Expand All @@ -277,20 +272,25 @@ def _extract(data, session=None): # noqa
if data.dtype.names:
# Singular struct
if data.size == 1:
return _create_struct(data, session)
return _create_struct(data, session, keep_matlab_shapes)
# Struct array
return StructArray(data, session)
return StructArray(data, session, keep_matlab_shapes)

# Extract cells.
if data.dtype.kind == "O":
return Cell(data, session)
return Cell(data, session, keep_matlab_shapes)

# Compress singleton values.
if data.size == 1:
return data.item()
if not keep_matlab_shapes:
return data.item()
else:
if data.dtype.kind in "US":
return data.item()
return data

# Compress empty values.
if data.size == 0:
if data.shape in ((0,), (0, 0)):
if data.dtype.kind in "US":
return ""
return []
Expand All @@ -299,15 +299,15 @@ def _extract(data, session=None): # noqa
return data


def _create_struct(data, session):
def _create_struct(data, session, keep_matlab_shapes=False):
"""Create a struct from session data."""
out = Struct()
for name in data.dtype.names:
item = data[name]
# Extract values that are cells (they are doubly wrapped).
if isinstance(item, np.ndarray) and item.dtype.kind == "O":
item = item.squeeze().tolist()
out[name] = _extract(item, session)
out[name] = _extract(item, session, keep_matlab_shapes)
return out


Expand Down
146 changes: 146 additions & 0 deletions tests/test_keep_matlab_shapes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
import os
import warnings

import numpy as np

from oct2py import Oct2Py


class TestNumpy:
"""Check value and type preservation of Numpy arrays"""

oc: Oct2Py
codes = np.typecodes["All"]

@classmethod
def setup_class(cls):
cls.oc = Oct2Py(keep_matlab_shapes=True)
cls.oc.addpath(os.path.dirname(__file__))

def teardown_class(cls): # noqa
cls.oc.exit()

def test_scalars(self):
"""Send scalar numpy types and make sure we get the same number back."""
for typecode in self.codes:
if typecode == "V":
continue
outgoing = np.random.randint(-255, 255) + np.random.rand(1)
if typecode in "US":
outgoing = np.array("spam").astype(typecode)
with warnings.catch_warnings():
warnings.simplefilter("ignore", RuntimeWarning)
try:
outgoing = outgoing.astype(typecode)
except TypeError:
continue
incoming = self.oc.roundtrip(outgoing)
try:
assert np.allclose(incoming, outgoing)
except (ValueError, TypeError, NotImplementedError, AssertionError):
assert np.all(np.array(incoming).astype(typecode) == outgoing)

def test_ndarrays(self):
"""Send ndarrays and make sure we get the same array back"""
for typecode in self.codes:
if typecode == "V":
continue
for ndims in [2, 3, 4]:
size = [np.random.randint(1, 10) for i in range(ndims)]
outgoing = np.random.randint(-255, 255, tuple(size))
try:
outgoing += np.random.rand(*size).astype(outgoing.dtype, casting="unsafe")
except TypeError: # pragma: no cover
outgoing += np.random.rand(*size).astype(outgoing.dtype)
if typecode in ["U", "S"]:
outgoing = [ # type:ignore
[["spam", "eggs", "hash"], ["spam", "eggs", "hash"]],
[["spam", "eggs", "hash"], ["spam", "eggs", "hash"]],
]
outgoing = np.array(outgoing).astype(typecode)
else:
try:
outgoing = outgoing.astype(typecode)
except TypeError:
continue
with warnings.catch_warnings():
warnings.simplefilter("ignore", FutureWarning)
incoming = self.oc.roundtrip(outgoing)
incoming = np.array(incoming)
if outgoing.size == 1:
outgoing = outgoing.squeeze()
if len(outgoing.shape) > 2 and 1 in outgoing.shape:
incoming = incoming.squeeze()
outgoing = outgoing.squeeze()
elif incoming.size == 1:
incoming = incoming.squeeze()
if typecode == "O":
incoming = incoming.squeeze()
outgoing = outgoing.squeeze()
assert incoming.shape == outgoing.shape
try:
assert np.allclose(incoming, outgoing)
except (AssertionError, ValueError, TypeError, NotImplementedError):
if "c" in incoming.dtype.str:
incoming = np.abs(incoming)
outgoing = np.abs(outgoing)
assert np.all(np.array(incoming).astype(typecode) == outgoing)

def test_sparse(self):
"""Test roundtrip sparse matrices"""
from scipy.sparse import csr_matrix, identity # type:ignore

rand = np.random.rand(100, 100)
rand = csr_matrix(rand)
iden = identity(1000)
for item in [rand, iden]:
incoming, type_ = self.oc.roundtrip(item, nout=2)
assert item.shape == incoming.shape
assert item.nnz == incoming.nnz
assert np.allclose(item.todense(), incoming.todense())
assert item.dtype == incoming.dtype
assert type_ in ("double", "cell")

def test_empty(self):
"""Test roundtrip empty matrices"""
empty = np.empty((100, 100))
incoming, type_ = self.oc.roundtrip(empty, nout=2)
assert empty.squeeze().shape == incoming.squeeze().shape
assert np.allclose(empty[np.isfinite(empty)], incoming[np.isfinite(incoming)])
assert type_ == "double"

def test_masked(self):
"""Test support for masked arrays"""
test = np.random.rand(100)
test = np.ma.array(test)
incoming, type_ = self.oc.roundtrip(test, nout=2)
assert np.allclose(test, incoming)
assert test.dtype == incoming.dtype
assert type_ == "double"

def test_shaped_but_zero_sized(self):
"""Test support for shaped but zero-sized arrays"""
test = np.zeros((0, 1, 2))
incoming, type_ = self.oc.roundtrip(test, nout=2)
assert test.shape == incoming.shape
assert test.dtype == incoming.dtype
assert type_ == "double"

def test_keep_matlab_shape(self):
"""Test support for keep_matlab_shape"""
tests = [
((1,), (1, 1)),
((2,), (1, 2)),
((1, 1), (1, 1)),
((1, 2), (1, 2)),
((2, 1), (2, 1)),
((2, 2), (2, 2)),
((1, 2, 3), (1, 2, 3)),
((2, 1, 1), (2, 1)),
]
for test_out_shape, test_in_shape in tests:
outgoing = np.zeros(test_out_shape)
incoming, type_ = self.oc.roundtrip(outgoing, nout=2)
assert incoming.shape == test_in_shape
assert outgoing.dtype == incoming.dtype
assert type_ == "double"
8 changes: 8 additions & 0 deletions tests/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,11 @@ def test_masked(self):
assert np.allclose(test, incoming)
assert test.dtype == incoming.dtype
assert type_ == "double"

def test_shaped_but_zero_sized(self):
"""Test support for shaped but zero-sized arrays"""
test = np.zeros((0, 1, 2))
incoming, type_ = self.oc.roundtrip(test, nout=2)
assert test.shape == incoming.shape
assert test.dtype == incoming.dtype
assert type_ == "double"
Loading