Skip to content

Commit 2e66f03

Browse files
authored
Merge pull request #147 from fjosw/fix/non_overlapping_cnfgs
Fix non overlapping configurations
2 parents 3236ba5 + 37c59a1 commit 2e66f03

File tree

4 files changed

+127
-30
lines changed

4 files changed

+127
-30
lines changed

pyerrors/input/dobs.py

Lines changed: 26 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -397,7 +397,7 @@ def read_pobs(fname, full_output=False, gz=True, separator_insertion=None):
397397

398398

399399
# this is based on Mattia Bruno's implementation at https://github.com/mbruno46/pyobs/blob/master/pyobs/IO/xml.py
400-
def import_dobs_string(content, noempty=False, full_output=False, separator_insertion=True):
400+
def import_dobs_string(content, full_output=False, separator_insertion=True):
401401
"""Import a list of Obs from a string in the Zeuthen dobs format.
402402
403403
Tags are not written or recovered automatically.
@@ -406,9 +406,6 @@ def import_dobs_string(content, noempty=False, full_output=False, separator_inse
406406
----------
407407
content : str
408408
XML string containing the data
409-
noempty : bool
410-
If True, ensembles with no contribution to the Obs are not included.
411-
If False, ensembles are included as written in the file, possibly with vanishing entries.
412409
full_output : bool
413410
If True, a dict containing auxiliary information and the data is returned.
414411
If False, only the data is returned as list.
@@ -457,7 +454,6 @@ def import_dobs_string(content, noempty=False, full_output=False, separator_inse
457454
_check(dobs[4].tag == "ne")
458455
ne = int(dobs[4].text.strip())
459456
_check(dobs[5].tag == "nc")
460-
nc = int(dobs[5].text.strip())
461457

462458
idld = {}
463459
deltad = {}
@@ -507,7 +503,11 @@ def import_dobs_string(content, noempty=False, full_output=False, separator_inse
507503

508504
for name in names:
509505
for i in range(len(deltad[name])):
510-
deltad[name][i] = np.array(deltad[name][i]) + mean[i]
506+
tmp = np.zeros_like(deltad[name][i])
507+
for j in range(len(deltad[name][i])):
508+
if deltad[name][i][j] != 0.:
509+
tmp[j] = deltad[name][i][j] + mean[i]
510+
deltad[name][i] = tmp
511511

512512
res = []
513513
for i in range(len(mean)):
@@ -516,25 +516,30 @@ def import_dobs_string(content, noempty=False, full_output=False, separator_inse
516516
obs_names = []
517517
for name in names:
518518
h = np.unique(deltad[name][i])
519-
if len(h) == 1 and np.all(h == mean[i]) and noempty:
519+
if len(h) == 1 and np.all(h == mean[i]):
520520
continue
521-
deltas.append(deltad[name][i])
522-
obs_names.append(name)
523-
idl.append(idld[name])
521+
repdeltas = []
522+
repidl = []
523+
for j in range(len(deltad[name][i])):
524+
if deltad[name][i][j] != 0.:
525+
repdeltas.append(deltad[name][i][j])
526+
repidl.append(idld[name][j])
527+
if len(repdeltas) > 0:
528+
obs_names.append(name)
529+
deltas.append(repdeltas)
530+
idl.append(repidl)
531+
524532
res.append(Obs(deltas, obs_names, idl=idl))
525533
res[-1]._value = mean[i]
526534
_check(len(e_names) == ne)
527535

528536
cnames = list(covd.keys())
529537
for i in range(len(res)):
530538
new_covobs = {name: Covobs(0, covd[name], name, grad=gradd[name][i]) for name in cnames}
531-
if noempty:
532-
for name in cnames:
533-
if np.all(new_covobs[name].grad == 0):
534-
del new_covobs[name]
535-
cnames_loc = list(new_covobs.keys())
536-
else:
537-
cnames_loc = cnames
539+
for name in cnames:
540+
if np.all(new_covobs[name].grad == 0):
541+
del new_covobs[name]
542+
cnames_loc = list(new_covobs.keys())
538543
for name in cnames_loc:
539544
res[i].names.append(name)
540545
res[i].shape[name] = 1
@@ -546,8 +551,6 @@ def import_dobs_string(content, noempty=False, full_output=False, separator_inse
546551
res[i].tag = symbol[i]
547552
if res[i].tag == 'None':
548553
res[i].tag = None
549-
if not noempty:
550-
_check(len(res[0].covobs.keys()) == nc)
551554
if full_output:
552555
retd = {}
553556
tool = file_origin.get('tool', None)
@@ -568,7 +571,7 @@ def import_dobs_string(content, noempty=False, full_output=False, separator_inse
568571
return res
569572

570573

571-
def read_dobs(fname, noempty=False, full_output=False, gz=True, separator_insertion=True):
574+
def read_dobs(fname, full_output=False, gz=True, separator_insertion=True):
572575
"""Import a list of Obs from an xml.gz file in the Zeuthen dobs format.
573576
574577
Tags are not written or recovered automatically.
@@ -577,9 +580,6 @@ def read_dobs(fname, noempty=False, full_output=False, gz=True, separator_insert
577580
----------
578581
fname : str
579582
Filename of the input file.
580-
noempty : bool
581-
If True, ensembles with no contribution to the Obs are not included.
582-
If False, ensembles are included as written in the file.
583583
full_output : bool
584584
If True, a dict containing auxiliary information and the data is returned.
585585
If False, only the data is returned as list.
@@ -615,7 +615,7 @@ def read_dobs(fname, noempty=False, full_output=False, gz=True, separator_insert
615615
with open(fname, 'r') as fin:
616616
content = fin.read()
617617

618-
return import_dobs_string(content, noempty, full_output, separator_insertion=separator_insertion)
618+
return import_dobs_string(content, full_output, separator_insertion=separator_insertion)
619619

620620

621621
def _dobsdict_to_xmlstring(d):
@@ -782,7 +782,7 @@ def create_dobs_string(obsl, name, spec='dobs v1.0', origin='', symbol=[], who=N
782782
o = obsl[oi]
783783
if repname in o.idl:
784784
if counters[oi] < 0:
785-
num = offsets[oi]
785+
num = 0
786786
if num == 0:
787787
data += '0 '
788788
else:
@@ -798,7 +798,7 @@ def create_dobs_string(obsl, name, spec='dobs v1.0', origin='', symbol=[], who=N
798798
if counters[oi] >= len(o.idl[repname]):
799799
counters[oi] = -1
800800
else:
801-
num = offsets[oi]
801+
num = 0
802802
if num == 0:
803803
data += '0 '
804804
else:

pyerrors/obs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1097,7 +1097,7 @@ def _expand_deltas_for_merge(deltas, idx, shape, new_idx):
10971097
ret = np.zeros(new_idx[-1] - new_idx[0] + 1)
10981098
for i in range(shape):
10991099
ret[idx[i] - new_idx[0]] = deltas[i]
1100-
return np.array([ret[new_idx[i] - new_idx[0]] for i in range(len(new_idx))])
1100+
return np.array([ret[new_idx[i] - new_idx[0]] for i in range(len(new_idx))]) * len(new_idx) / len(idx)
11011101

11021102

11031103
def derived_observable(func, data, array_mode=False, **kwargs):

tests/json_io_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,7 @@ def test_dobsio():
339339

340340
dobsio.write_dobs(ol, fname, 'TEST')
341341

342-
rl = dobsio.read_dobs(fname, noempty=True)
342+
rl = dobsio.read_dobs(fname)
343343
os.remove(fname + '.xml.gz')
344344
[o.gamma_method() for o in rl]
345345

tests/obs_test.py

Lines changed: 99 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ def test_intersection_reduce():
566566
intersection = pe.obs._intersection_idx([o.idl["ens"] for o in [obs1, obs_merge]])
567567
coll = pe.obs._reduce_deltas(obs_merge.deltas["ens"], obs_merge.idl["ens"], range1)
568568

569-
assert np.all(coll == obs1.deltas["ens"])
569+
assert np.allclose(coll, obs1.deltas["ens"] * (len(obs_merge.idl["ens"]) / len(range1)))
570570

571571

572572
def test_irregular_error_propagation():
@@ -878,7 +878,7 @@ def test_correlation_intersection_of_idls():
878878
cov1 = pe.covariance([obs1, obs2_a])
879879
corr1 = pe.covariance([obs1, obs2_a], correlation=True)
880880

881-
obs2_b = obs2_a + pe.Obs([np.random.normal(1.0, 0.1, len(range2))], ["ens"], idl=[range2])
881+
obs2_b = (obs2_a + pe.Obs([np.random.normal(1.0, 0.1, len(range2))], ["ens"], idl=[range2])) / 2
882882
obs2_b.gamma_method()
883883

884884
cov2 = pe.covariance([obs1, obs2_b])
@@ -1038,6 +1038,7 @@ def test_hash():
10381038
assert hash(obs) != hash(o1)
10391039
assert hash(o1) != hash(o2)
10401040

1041+
10411042
def test_gm_alias():
10421043
samples = np.random.rand(500)
10431044

@@ -1049,3 +1050,99 @@ def test_gm_alias():
10491050

10501051
assert np.isclose(tt1.dvalue, tt2.dvalue)
10511052

1053+
1054+
def test_overlapping_missing_cnfgs():
1055+
length = 200000
1056+
1057+
l_samp = np.random.normal(2.87, 0.5, length)
1058+
s_samp = np.random.normal(7.87, 0.7, length // 2)
1059+
1060+
o1 = pe.Obs([l_samp], ["test"])
1061+
o2 = pe.Obs([s_samp], ["test"], idl=[range(1, length, 2)])
1062+
1063+
a2 = pe.Obs([s_samp], ["alt"])
1064+
t1 = o1 + o2
1065+
t1.gm(S=0)
1066+
1067+
t2 = o1 + a2
1068+
t2.gm(S=0)
1069+
assert np.isclose(t1.value, t2.value)
1070+
assert np.isclose(t1.dvalue, t2.dvalue, rtol=0.01)
1071+
1072+
1073+
def test_non_overlapping_missing_cnfgs():
1074+
length = 100000
1075+
1076+
xsamp = np.random.normal(1.0, 1.0, length)
1077+
1078+
1079+
full = pe.Obs([xsamp], ["ensemble"], idl=[range(0, length)])
1080+
full.gm()
1081+
1082+
even = pe.Obs([xsamp[0:length:2]], ["ensemble"], idl=[range(0, length, 2)])
1083+
odd = pe.Obs([xsamp[1:length:2]], ["ensemble"], idl=[range(1, length, 2)])
1084+
1085+
average = (even + odd) / 2
1086+
average.gm(S=0)
1087+
assert np.isclose(full.value, average.value)
1088+
assert np.isclose(full.dvalue, average.dvalue, rtol=0.01)
1089+
1090+
1091+
def test_non_overlapping_operations():
1092+
length = 100000
1093+
1094+
samples = np.random.normal(0.93, 0.5, length)
1095+
1096+
e = pe.Obs([samples[0:length:2]], ["ensemble"], idl=[range(0, length, 2)])
1097+
o = pe.Obs([samples[1:length:2]], ["ensemble"], idl=[range(1, length, 2)])
1098+
1099+
1100+
e2 = pe.Obs([samples[0:length:2]], ["even"])
1101+
o2 = pe.Obs([samples[1:length:2]], ["odd"])
1102+
1103+
for func in [lambda a, b: a + b,
1104+
lambda a, b: a - b,
1105+
lambda a, b: a * b,
1106+
lambda a, b: a / b,
1107+
lambda a, b: a ** b]:
1108+
1109+
res1 = func(e, o)
1110+
res1.gm(S=0)
1111+
res2 = func(e2, o2)
1112+
res2.gm(S=0)
1113+
1114+
print(res1, res2)
1115+
print((res1.dvalue - res2.dvalue) / res1.dvalue)
1116+
1117+
assert np.isclose(res1.value, res2.value)
1118+
assert np.isclose(res1.dvalue, res2.dvalue, rtol=0.01)
1119+
1120+
1121+
def test_non_overlapping_operations_different_lengths():
1122+
length = 100000
1123+
1124+
samples = np.random.normal(0.93, 0.5, length)
1125+
first = samples[:length // 5]
1126+
second = samples[length // 5:]
1127+
1128+
f1 = pe.Obs([first], ["ensemble"], idl=[range(1, length // 5 + 1)])
1129+
s1 = pe.Obs([second], ["ensemble"], idl=[range(length // 5, length)])
1130+
1131+
1132+
f2 = pe.Obs([first], ["first"])
1133+
s2 = pe.Obs([second], ["second"])
1134+
1135+
for func in [lambda a, b: a + b,
1136+
lambda a, b: a - b,
1137+
lambda a, b: a * b,
1138+
lambda a, b: a / b,
1139+
lambda a, b: a ** b,
1140+
lambda a, b: a ** 2 + b ** 2 / a]:
1141+
1142+
res1 = func(f1, f1)
1143+
res1.gm(S=0)
1144+
res2 = func(f2, f2)
1145+
res2.gm(S=0)
1146+
1147+
assert np.isclose(res1.value, res2.value)
1148+
assert np.isclose(res1.dvalue, res2.dvalue, rtol=0.01)

0 commit comments

Comments
 (0)