2121import numpy as np
2222import pandas as pd
2323import fink_science .ztf .superluminous .slsn_classifier as slsn
24- from fink_science .ztf .superluminous .kernel import classifier_path
24+ import fink_science .ztf .superluminous .kernel as kern
2525import joblib
2626import os
2727import requests
@@ -92,12 +92,10 @@ def superluminous_score(
9292 >>> for colname in what:
9393 ... sdf = concat_col(sdf, colname, prefix=prefix)
9494
95- # Perform the fit + classification (default model)
9695 >>> args = ['is_transient', 'objectId', 'candidate.jdstarthist']
97-
98- # Perform the fit + classification (default model)
9996 >>> args += [F.col(i) for i in what_prefix]
10097
98+ # Perform the fit + classification
10199 >>> sdf = sdf.withColumn('proba', superluminous_score(*args))
102100 >>> pdf = sdf.toPandas()
103101 >>> sum(pdf['proba']==-1)
@@ -179,31 +177,87 @@ def superluminous_score(
179177 else :
180178 lcs [field ] = np .array (combined_values )
181179
182- # FIXME: why lcs would be None here?
183- if lcs is not None :
184- # Assign default -1 proba for every valid alert
185- probas = np .zeros (len (pdf_valid ), dtype = float ) - 1
180+ # Assign default -1 proba for every valid alert
181+ probas = np .zeros (len (pdf_valid ), dtype = float ) - 1
186182
187- lcs = slsn .compute_flux (lcs )
188- lcs = slsn .remove_nan (lcs )
183+ lcs = slsn .compute_flux (lcs )
184+ lcs = slsn .remove_nan (lcs )
189185
190- # Perform feature extraction
191- features = slsn .extract_features (lcs )
186+ # Perform feature extraction
187+ features = slsn .extract_features (lcs )
192188
193- # Load classifier
194- clf = joblib .load (classifier_path )
189+ # Load classifier
190+ clf = joblib .load (kern . classifier_path )
195191
196- # Modify proba for alerts that were feature extracted
197- extracted = np .sum (features .isna (), axis = 1 ) == 0
198- probas [extracted ] = clf .predict_proba (
199- features .loc [extracted , clf .feature_names_in_ ]
200- )[:, 1 ]
192+ # Compute proba for alerts that were feature extracted
193+ extracted = np .sum (features .isna (), axis = 1 ) == 0
194+ probas [extracted ] = clf .predict_proba (
195+ features .loc [extracted , clf .feature_names_in_ ]
196+ )[:, 1 ]
201197
202- probas_total [ mask_valid ] = probas
203- return pd . Series ( probas_total )
198+ # Mask only alerts classified as SLSN
199+ mask_is_SLSN = probas > clf . optimal_threshold
204200
205- else :
206- return pd .Series ([- 1.0 ] * len (objectId ))
201+ # Check the SDSS photo-z for these alerts
202+ SLSN_features = features [mask_is_SLSN ].copy ()
203+
204+ if len (SLSN_features ) > 0 :
205+ SLSN_features ["objectId" ] = lcs .loc [mask_is_SLSN , "objectId" ]
206+ SLSN_features = slsn .add_all_photoz (SLSN_features )
207+
208+ # Compute upper bound for abs magnitude
209+ upper_M = np .array (
210+ SLSN_features .apply (
211+ lambda x : slsn .abs_peak (
212+ x ["peak_mag" ], x ["photoz" ], x ["photozerr" ], x ["ebv" ]
213+ )[2 ],
214+ axis = 1 ,
215+ )
216+ )
217+
218+ # Sources clearly not SL are masked
219+ mask_not_SL = upper_M > kern .not_sl_threshold
220+ zero_proba_idx = SLSN_features [mask_not_SL ].index
221+
222+ # And have their probabilities put to 0.
223+ probas [zero_proba_idx ] = 0
224+
225+ # Apply the proba computed for valid sources
226+ probas_total [mask_valid ] = probas
227+
228+ return pd .Series (probas_total )
229+
230+
def protected_mean(arr):
    """Return the mean of an array, ignoring None and NaN entries.

    Parameters
    ----------
    arr: np.array
        Array of numbers, possibly containing None and/or NaN entries
        (e.g. an object-dtype column coming from the Fink API).

    Returns
    -------
    float
        Mean of the valid values, or 0.0 if the array is empty or
        made of None/NaN values only.

    Example
    -------
    >>> protected_mean(np.array([10., 20]))
    15.0
    >>> protected_mean(np.array([10, 20., None]))
    15.0
    >>> protected_mean(np.array([None, None]))
    0.0
    >>> protected_mean(np.array([np.nan, np.nan]))
    0.0
    """
    # Drop None entries first: object arrays cannot go through np.isnan,
    # and `is not None` is the idiomatic identity test.
    numeric = np.array([x for x in arr if x is not None], dtype=float)

    # Also drop NaN explicitly: np.nanmean on an all-NaN array would emit a
    # RuntimeWarning and return nan instead of the documented 0.0 fallback.
    numeric = numeric[~np.isnan(numeric)]

    if numeric.size > 0:
        return float(np.mean(numeric))

    return 0.0
207261
208262
209263def get_and_format (ZTF_name ):
@@ -233,7 +287,7 @@ def get_and_format(ZTF_name):
233287 >>> data = get_and_format(["ZTF21abfmbix", "ZTF21abfmbix"])
234288 >>> (data["distnr"].iloc[0] > 3.0) & (data["distnr"].iloc[0] < 3.5)
235289 True
236- >>> list(data.columns) == ['objectId', 'cjd', 'cmagpsf', 'csigmapsf', 'cfid', 'distnr']
290+ >>> list(data.columns) == ['objectId', 'ra', 'dec', 'cjd', 'cmagpsf', 'csigmapsf', 'cfid', 'distnr']
237291 True
238292 >>> len(data['cjd'].iloc[0]) >= 14
239293 True
@@ -252,7 +306,7 @@ def get_and_format(ZTF_name):
252306 "https://api.fink-portal.org/api/v1/objects" ,
253307 json = {
254308 "objectId" : name ,
255- "columns" : "i:objectId,i:jd,i:magpsf,i:sigmapsf,i:fid,i:jd,i:distnr,d:tag" ,
309+ "columns" : "i:objectId,i:jd,i:magpsf,i:sigmapsf,i:fid,i:jd,i:distnr,d:tag,i:ra,i:dec" ,
256310 "output-format" : "json" ,
257311 "withupperlim" : "True" ,
258312 },
@@ -281,6 +335,8 @@ def get_and_format(ZTF_name):
281335 lcs = pd .DataFrame (
282336 data = {
283337 "objectId" : [lc ["i:objectId" ].iloc [0 ] for lc in pdfs ],
338+ "ra" : [protected_mean (lc ["i:ra" ]) for lc in pdfs ],
339+ "dec" : [protected_mean (lc ["i:dec" ]) for lc in pdfs ],
284340 "cjd" : [np .array (lc ["i:jd" ].values , dtype = float ) for lc in pdfs ],
285341 "cmagpsf" : [
286342 np .array (lc ["i:magpsf" ].values , dtype = float ) for lc in pdfs
0 commit comments