Skip to content

Commit 63b2f4f

Browse files
authored
Merge pull request #135 from fjosw/fix/complex_Corr
Fix/complex corr
2 parents 313dec7 + 6343968 commit 63b2f4f

File tree

2 files changed

+20
-6
lines changed

2 files changed

+20
-6
lines changed

pyerrors/correlators.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,8 @@ def __repr__(self, print_range=None):
969969
content_string += "Description: " + self.tag + "\n"
970970
if self.N != 1:
971971
return content_string
972+
if isinstance(self[0], CObs):
973+
return content_string
972974

973975
if print_range[1]:
974976
print_range[1] += 1
@@ -1136,8 +1138,10 @@ def _apply_func_to_corr(self, func):
11361138
for t in range(self.T):
11371139
if _check_for_none(self, newcontent[t]):
11381140
continue
1139-
if np.isnan(np.sum(newcontent[t]).value):
1140-
newcontent[t] = None
1141+
tmp_sum = np.sum(newcontent[t])
1142+
if hasattr(tmp_sum, "value"):
1143+
if np.isnan(tmp_sum.value):
1144+
newcontent[t] = None
11411145
if all([item is None for item in newcontent]):
11421146
raise Exception('Operation returns undefined correlator')
11431147
return Corr(newcontent)
@@ -1194,8 +1198,8 @@ def __rtruediv__(self, y):
11941198
@property
11951199
def real(self):
11961200
def return_real(obs_OR_cobs):
1197-
if isinstance(obs_OR_cobs, CObs):
1198-
return obs_OR_cobs.real
1201+
if isinstance(obs_OR_cobs.flatten()[0], CObs):
1202+
return np.vectorize(lambda x: x.real)(obs_OR_cobs)
11991203
else:
12001204
return obs_OR_cobs
12011205

@@ -1204,8 +1208,8 @@ def return_real(obs_OR_cobs):
12041208
@property
12051209
def imag(self):
12061210
def return_imag(obs_OR_cobs):
1207-
if isinstance(obs_OR_cobs, CObs):
1208-
return obs_OR_cobs.imag
1211+
if isinstance(obs_OR_cobs.flatten()[0], CObs):
1212+
return np.vectorize(lambda x: x.imag)(obs_OR_cobs)
12091213
else:
12101214
return obs_OR_cobs * 0 # So it stays the right type
12111215

tests/correlators_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,3 +532,13 @@ def test_prune():
532532
with pytest.raises(Exception):
533533
corr_mat.prune(3)
534534
corr_mat.prune(4)
535+
536+
537+
def test_complex_Corr():
538+
o1 = pe.pseudo_Obs(1.0, 0.1, "test")
539+
cobs = pe.CObs(o1, -o1)
540+
ccorr = pe.Corr([cobs, cobs, cobs])
541+
assert np.all([ccorr.imag[i] == -ccorr.real[i] for i in range(ccorr.T)])
542+
print(ccorr)
543+
mcorr = pe.Corr(np.array([[ccorr, ccorr], [ccorr, ccorr]]))
544+
assert np.all([mcorr.imag[i] == -mcorr.real[i] for i in range(mcorr.T)])

0 commit comments

Comments
 (0)