Skip to content

Fix/merge idx #172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 40 additions & 43 deletions pyerrors/obs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import warnings
import hashlib
import pickle
from math import gcd
from functools import reduce
import numpy as np
import autograd.numpy as anp # Thinly-wrapped numpy
from autograd import jacobian
Expand Down Expand Up @@ -280,7 +278,7 @@ def _parse_kwarg(kwarg_name):

def _compute_drho(i):
    # Estimate the statistical error of the autocorrelation function rho
    # at lag i and store it in self.e_drho[e_name][i].
    # NOTE(review): the reversed slice uses a stop of None when
    # i - w_max // 2 <= 0 (instead of < 0), because a stop of 0 with step -1
    # would silently drop the element at index 0 from the reflected part.
    tmp = (self.e_rho[e_name][i + 1:w_max]
           + np.concatenate([self.e_rho[e_name][i - 1:None if i - w_max // 2 <= 0 else 2 * (i - w_max // 2):-1],
                             self.e_rho[e_name][1:max(1, w_max - 2 * i)]])
           - 2 * self.e_rho[e_name][i] * self.e_rho[e_name][1:w_max - i])
    # Root-mean-square of the combined terms, normalized by the ensemble size e_N.
    self.e_drho[e_name][i] = np.sqrt(np.sum(tmp ** 2) / e_N)
Expand Down Expand Up @@ -1022,63 +1020,52 @@ def _expand_deltas(deltas, idx, shape, gapsize):


def _merge_idx(idl):
    """Returns the union of all lists in idl as range or sorted list

    Parameters
    ----------
    idl : list
        List of lists or ranges.
    """

    # Fast path: if all index lists are identical, return the first one unchanged.
    if _check_lists_equal(idl):
        return idl[0]

    idunion = sorted(set().union(*idl))

    # Check whether idunion can be expressed as range
    try:
        idrange = range(idunion[0], idunion[-1] + 1, idunion[1] - idunion[0])
        idtest = [list(idrange), idunion]
        if _check_lists_equal(idtest):
            return idrange
    except IndexError:
        # idunion has fewer than two elements, so no step can be derived;
        # fall through and return the plain sorted union (mirrors the
        # IndexError guard in _intersection_idx).
        pass

    return idunion


def _intersection_idx(idl):
    """Returns the intersection of all lists in idl as range or sorted list

    Parameters
    ----------
    idl : list
        List of lists or ranges.
    """

    # Fast path: if all index lists are identical, return the first one unchanged.
    if _check_lists_equal(idl):
        return idl[0]

    idinter = sorted(set.intersection(*[set(o) for o in idl]))

    # Check whether idinter can be expressed as range
    try:
        idrange = range(idinter[0], idinter[-1] + 1, idinter[1] - idinter[0])
        idtest = [list(idrange), idinter]
        if _check_lists_equal(idtest):
            return idrange
    except IndexError:
        # idinter has fewer than two elements, so no step can be derived;
        # fall through and return the plain sorted intersection.
        pass

    return idinter


def _expand_deltas_for_merge(deltas, idx, shape, new_idx):
Expand Down Expand Up @@ -1299,13 +1286,8 @@ def _reduce_deltas(deltas, idx_old, idx_new):
if type(idx_old) is range and type(idx_new) is range:
if idx_old == idx_new:
return deltas
# Use groupby to efficiently check whether all elements of idx_old and idx_new are identical
try:
g = groupby([idx_old, idx_new])
if next(g, True) and not next(g, False):
return deltas
except Exception:
pass
if _check_lists_equal([idx_old, idx_new]):
return deltas
indices = np.intersect1d(idx_old, idx_new, assume_unique=True, return_indices=True)[1]
if len(indices) < len(idx_new):
raise Exception('Error in _reduce_deltas: Config of idx_new not in idx_old')
Expand Down Expand Up @@ -1650,3 +1632,18 @@ def _determine_gap(o, e_content, e_name):
raise Exception(f"Replica for ensemble {e_name} do not have a common spacing.", gaps)

return gap


def _check_lists_equal(idl):
'''
Use groupby to efficiently check whether all elements of idl are identical.
Returns True if all elements are equal, otherwise False.

Parameters
----------
idl : list of lists, ranges or np.ndarrays
'''
g = groupby([np.nditer(el) if isinstance(el, np.ndarray) else el for el in idl])
if next(g, True) and not next(g, False):
return True
return False
40 changes: 37 additions & 3 deletions tests/obs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,21 @@ def test_correlate():

def test_merge_idx():
assert pe.obs._merge_idx([range(10, 1010, 10), range(10, 1010, 50)]) == range(10, 1010, 10)
assert pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6250, 50)
assert isinstance(pe.obs._merge_idx([range(10, 1010, 10), range(10, 1010, 50)]), range)
assert pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6001, 50)
assert isinstance(pe.obs._merge_idx([range(500, 6050, 50), range(500, 6250, 250)]), range)
assert pe.obs._merge_idx([range(1, 1011, 2), range(1, 1010, 1)]) == range(1, 1010, 1)
assert isinstance(pe.obs._merge_idx([range(1, 1011, 2), range(1, 1010, 1)]), range)
assert pe.obs._merge_idx([range(1, 100, 2), range(2, 100, 2)]) == range(1, 100, 1)
assert isinstance(pe.obs._merge_idx([range(1, 100, 2), range(2, 100, 2)]), range)

for j in range(5):
idll = [range(1, int(round(np.random.uniform(300, 700))), int(round(np.random.uniform(1, 14)))) for i in range(10)]
assert pe.obs._merge_idx(idll) == sorted(set().union(*idll))

for j in range(5):
idll = [range(int(round(np.random.uniform(1, 28))), int(round(np.random.uniform(300, 700))), int(round(np.random.uniform(1, 14)))) for i in range(10)]
assert pe.obs._merge_idx(idll) == sorted(set().union(*idll))

idl = [list(np.arange(1, 14)) + list(range(16, 100, 4)), range(4, 604, 4), [2, 4, 5, 6, 8, 9, 12, 24], range(1, 20, 1), range(50, 789, 7)]
new_idx = pe.obs._merge_idx(idl)
Expand All @@ -550,10 +564,21 @@ def test_intersection_idx():
assert pe.obs._intersection_idx([range(1, 100), range(1, 100), range(1, 100)]) == range(1, 100)
assert pe.obs._intersection_idx([range(1, 100, 10), range(1, 100, 2)]) == range(1, 100, 10)
assert pe.obs._intersection_idx([range(10, 1010, 10), range(10, 1010, 50)]) == range(10, 1010, 50)
assert pe.obs._intersection_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6050, 250)
assert pe.obs._intersection_idx([range(500, 6050, 50), range(500, 6250, 250)]) == range(500, 6001, 250)
assert pe.obs._intersection_idx([range(1, 1011, 2), range(1, 1010, 1)]) == range(1, 1010, 2)
idll = [range(1, 100, 2), range(5, 105, 1)]
assert pe.obs._intersection_idx(idll) == range(5, 100, 2)
assert isinstance(pe.obs._intersection_idx(idll), range)
idll = [range(1, 100, 2), list(range(5, 105, 1))]
assert pe.obs._intersection_idx(idll) == range(5, 100, 2)
assert isinstance(pe.obs._intersection_idx(idll), range)

for ids in [[list(range(1, 80, 3)), list(range(1, 100, 2))], [range(1, 80, 3), range(1, 100, 2), range(1, 100, 7)]]:
assert list(pe.obs._intersection_idx(ids)) == pe.obs._intersection_idx([list(o) for o in ids])
interlist = pe.obs._intersection_idx([list(o) for o in ids])
listinter = list(pe.obs._intersection_idx(ids))
assert len(interlist) == len(listinter)
assert all([o in listinter for o in interlist])
assert all([o in interlist for o in listinter])


def test_merge_intersection():
Expand Down Expand Up @@ -733,6 +758,15 @@ def gen_autocorrelated_array(inarr, rho):
with pytest.raises(Exception):
my_obs.gm()

# check cases where tau is large compared to the chain length
N = 15
for i in range(10):
arr = np.random.normal(1, .2, size=N)
for rho in .1 * np.arange(20):
carr = gen_autocorrelated_array(arr, rho)
a = pe.Obs([carr], ['a'])
a.gm()


def test_irregular_gapped_dtauint():
my_idl = list(range(0, 5010, 10))
Expand Down