Skip to content

Commit e94cd6e

Browse files
erusseilEtienne Russeil
andauthored
SLSN v2 [ISSUE #595] (#596)
* Now extract peak mag * also extract ra/dec * SLSN module update * SLSN module update * PEP * More rigourous definition of dust extinction * Fix doctest * Fix doctest * Try to fix doctest * Attempt to fix slsn doctest crash * Rerun PR --------- Co-authored-by: Etienne Russeil <[email protected]>
1 parent aeeaa30 commit e94cd6e

File tree

6 files changed

+365
-68
lines changed

6 files changed

+365
-68
lines changed
3.26 MB
Binary file not shown.
-134 KB
Binary file not shown.

fink_science/ztf/superluminous/kernel.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,4 @@
2525
min_points_total = 7
2626
min_points_perband = 3
2727
min_duration = 20
28+
not_sl_threshold = -19.75

fink_science/ztf/superluminous/processor.py

Lines changed: 81 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import numpy as np
2222
import pandas as pd
2323
import fink_science.ztf.superluminous.slsn_classifier as slsn
24-
from fink_science.ztf.superluminous.kernel import classifier_path
24+
import fink_science.ztf.superluminous.kernel as kern
2525
import joblib
2626
import os
2727
import requests
@@ -92,12 +92,10 @@ def superluminous_score(
9292
>>> for colname in what:
9393
... sdf = concat_col(sdf, colname, prefix=prefix)
9494
95-
# Perform the fit + classification (default model)
9695
>>> args = ['is_transient', 'objectId', 'candidate.jdstarthist']
97-
98-
# Perform the fit + classification (default model)
9996
>>> args += [F.col(i) for i in what_prefix]
10097
98+
# Perform the fit + classification
10199
>>> sdf = sdf.withColumn('proba', superluminous_score(*args))
102100
>>> pdf = sdf.toPandas()
103101
>>> sum(pdf['proba']==-1)
@@ -179,31 +177,87 @@ def superluminous_score(
179177
else:
180178
lcs[field] = np.array(combined_values)
181179

182-
# FIXME: why lcs would be None here?
183-
if lcs is not None:
184-
# Assign default -1 proba for every valid alert
185-
probas = np.zeros(len(pdf_valid), dtype=float) - 1
180+
# Assign default -1 proba for every valid alert
181+
probas = np.zeros(len(pdf_valid), dtype=float) - 1
186182

187-
lcs = slsn.compute_flux(lcs)
188-
lcs = slsn.remove_nan(lcs)
183+
lcs = slsn.compute_flux(lcs)
184+
lcs = slsn.remove_nan(lcs)
189185

190-
# Perform feature extraction
191-
features = slsn.extract_features(lcs)
186+
# Perform feature extraction
187+
features = slsn.extract_features(lcs)
192188

193-
# Load classifier
194-
clf = joblib.load(classifier_path)
189+
# Load classifier
190+
clf = joblib.load(kern.classifier_path)
195191

196-
# Modify proba for alerts that were feature extracted
197-
extracted = np.sum(features.isna(), axis=1) == 0
198-
probas[extracted] = clf.predict_proba(
199-
features.loc[extracted, clf.feature_names_in_]
200-
)[:, 1]
192+
# Compute proba for alerts that were feature extracted
193+
extracted = np.sum(features.isna(), axis=1) == 0
194+
probas[extracted] = clf.predict_proba(
195+
features.loc[extracted, clf.feature_names_in_]
196+
)[:, 1]
201197

202-
probas_total[mask_valid] = probas
203-
return pd.Series(probas_total)
198+
# Mask only alerts classified as SLSN
199+
mask_is_SLSN = probas > clf.optimal_threshold
204200

205-
else:
206-
return pd.Series([-1.0] * len(objectId))
201+
# Check the SDSS photo-z for these alerts
202+
SLSN_features = features[mask_is_SLSN].copy()
203+
204+
if len(SLSN_features) > 0:
205+
SLSN_features["objectId"] = lcs.loc[mask_is_SLSN, "objectId"]
206+
SLSN_features = slsn.add_all_photoz(SLSN_features)
207+
208+
# Compute upper bound for abs magnitude
209+
upper_M = np.array(
210+
SLSN_features.apply(
211+
lambda x: slsn.abs_peak(
212+
x["peak_mag"], x["photoz"], x["photozerr"], x["ebv"]
213+
)[2],
214+
axis=1,
215+
)
216+
)
217+
218+
# Sources clearly not SL are masked
219+
mask_not_SL = upper_M > kern.not_sl_threshold
220+
zero_proba_idx = SLSN_features[mask_not_SL].index
221+
222+
# And have their probabilities put to 0.
223+
probas[zero_proba_idx] = 0
224+
225+
# Apply the proba computed for valid sources
226+
probas_total[mask_valid] = probas
227+
228+
return pd.Series(probas_total)
229+
230+
231+
def protected_mean(arr):
232+
"""Returns the mean protected in case of only Nans.
233+
234+
Parameters
235+
----------
236+
arr: np.array
237+
238+
Returns
239+
-------
240+
float
241+
Mean of the list. Or 0 if the list is made
242+
of Nans/Nones only.
243+
244+
Example
245+
-------
246+
>>> protected_mean(np.array([10., 20]))
247+
15.0
248+
>>> protected_mean(np.array([10, 20., None]))
249+
15.0
250+
>>> protected_mean(np.array([None, None]))
251+
0.0
252+
"""
253+
# Keep only numerical values
254+
mask = [type(element) is not type(None) for element in arr]
255+
new_arr = arr[mask]
256+
257+
if len(new_arr) > 0:
258+
return np.nanmean(new_arr)
259+
260+
return 0.0
207261

208262

209263
def get_and_format(ZTF_name):
@@ -233,7 +287,7 @@ def get_and_format(ZTF_name):
233287
>>> data = get_and_format(["ZTF21abfmbix", "ZTF21abfmbix"])
234288
>>> (data["distnr"].iloc[0] > 3.0) & (data["distnr"].iloc[0] < 3.5)
235289
True
236-
>>> list(data.columns) == ['objectId', 'cjd', 'cmagpsf', 'csigmapsf', 'cfid', 'distnr']
290+
>>> list(data.columns) == ['objectId', 'ra', 'dec', 'cjd', 'cmagpsf', 'csigmapsf', 'cfid', 'distnr']
237291
True
238292
>>> len(data['cjd'].iloc[0]) >= 14
239293
True
@@ -252,7 +306,7 @@ def get_and_format(ZTF_name):
252306
"https://api.fink-portal.org/api/v1/objects",
253307
json={
254308
"objectId": name,
255-
"columns": "i:objectId,i:jd,i:magpsf,i:sigmapsf,i:fid,i:jd,i:distnr,d:tag",
309+
"columns": "i:objectId,i:jd,i:magpsf,i:sigmapsf,i:fid,i:jd,i:distnr,d:tag,i:ra,i:dec",
256310
"output-format": "json",
257311
"withupperlim": "True",
258312
},
@@ -281,6 +335,8 @@ def get_and_format(ZTF_name):
281335
lcs = pd.DataFrame(
282336
data={
283337
"objectId": [lc["i:objectId"].iloc[0] for lc in pdfs],
338+
"ra": [protected_mean(lc["i:ra"]) for lc in pdfs],
339+
"dec": [protected_mean(lc["i:dec"]) for lc in pdfs],
284340
"cjd": [np.array(lc["i:jd"].values, dtype=float) for lc in pdfs],
285341
"cmagpsf": [
286342
np.array(lc["i:magpsf"].values, dtype=float) for lc in pdfs

0 commit comments

Comments
 (0)