Skip to content

Commit cb09fe2

Browse files
erusseil (Etienne Russeil)
authored
Superluminous supernovae classifier for ZTF (#552)
* first commit superluminous ztf * Documentation * pep8 formating * Add extra dependencies for the tests * Fixed remove None * Removed useless imports * added header + minor changes * Only compute lc>30 days * fixed indentation consistency * fix indentation fr * Minor consistency changes * Updated classifier with more SLSN-I * Fixed tests * Fixed doctests * Added is_transient filtering * Code optimization * pep8 fix * All invalid alerts=-1 --------- Co-authored-by: Etienne Russeil <[email protected]>
1 parent 8cb1352 commit cb09fe2

File tree

6 files changed

+570
-0
lines changed

6 files changed

+570
-0
lines changed
226 KB
Binary file not shown.

fink_science/ztf/superluminous/__init__.py

Whitespace-only changes.
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright 2025 AstroLab Software
2+
# Author: Etienne Russeil
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import os
17+
from fink_science import __file__
18+
19+
curdir = os.path.dirname(os.path.abspath(__file__))
20+
21+
classifier_path = curdir + "/data/models/superluminous_classifier.joblib"
22+
band_wave_aa = {1: 4770.0, 2: 6231.0, 3: 7625.0}
23+
temperature = "sigmoid"
24+
bolometric = "bazin"
25+
min_points_total = 7
26+
min_points_perband = 3
27+
min_duration = 30
Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
# Copyright 2020-2023 AstroLab Software
2+
# Author: Etienne Russeil
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
from line_profiler import profile
17+
from fink_science import __file__
18+
from pyspark.sql.functions import pandas_udf, PandasUDFType
19+
from pyspark.sql.types import DoubleType
20+
from fink_science.tester import spark_unit_tests
21+
import numpy as np
22+
import pandas as pd
23+
import fink_science.ztf.superluminous.slsn_classifier as slsn
24+
from fink_science.ztf.superluminous.kernel import classifier_path
25+
import joblib
26+
import os
27+
28+
29+
@pandas_udf(DoubleType())
30+
@profile
31+
def superluminous_score(
32+
cjd: pd.Series,
33+
cfid: pd.Series,
34+
cmagpsf: pd.Series,
35+
csigmapsf: pd.Series,
36+
distnr: pd.Series,
37+
is_transient: pd.Series,
38+
) -> pd.Series:
39+
"""High level spark wrapper for the superluminous classifier on ztf data
40+
41+
Parameters
42+
----------
43+
cjd: Spark DataFrame Column
44+
JD times (vectors of floats)
45+
cfid: Spark DataFrame Column
46+
Filter IDs (vectors of str)
47+
cmagpsf, csigmapsf: Spark DataFrame Columns
48+
Magnitude and magnitude error from photometry (vectors of floats)
49+
distnr: Spark DataFrame Column
50+
The angular distance to the nearest reference source.
51+
is_transient: Spark DataFrame Column
52+
Is the source likely a transient.
53+
54+
Returns
55+
-------
56+
np.array
57+
Superluminous supernovae classification probability vector
58+
Return 0 if not enough points were available for feature extraction
59+
60+
Examples
61+
--------
62+
>>> from fink_utils.spark.utils import concat_col
63+
>>> from pyspark.sql import functions as F
64+
>>> from fink_filters.ztf.filter_transient_complete.filter import transient_complete_filter
65+
>>> from fink_science.ztf.transient_features.processor import extract_transient_features
66+
>>> sdf = spark.read.load(ztf_alert_sample)
67+
>>> sdf = extract_transient_features(sdf)
68+
>>> sdf = sdf.withColumn(
69+
... "is_transient",
70+
... transient_complete_filter(
71+
... "faint", "positivesubtraction", "real", "pointunderneath",
72+
... "brightstar", "variablesource", "stationary", "roid"))
73+
74+
# Required alert columns
75+
>>> what = ['jd', 'fid', 'magpsf', 'sigmapsf']
76+
77+
# Use for creating temp name
78+
>>> prefix = 'c'
79+
>>> what_prefix = [prefix + i for i in what]
80+
81+
# Append temp columns with historical + current measurements
82+
>>> for colname in what:
83+
... sdf = concat_col(sdf, colname, prefix=prefix)
84+
85+
# Perform the fit + classification (default model)
86+
>>> args = [F.col(i) for i in what_prefix]
87+
>>> args += ["candidate.distnr", "is_transient"]
88+
>>> sdf = sdf.withColumn('proba', superluminous_score(*args))
89+
>>> sdf.filter(sdf['proba']==-1.0).count()
90+
57
91+
"""
92+
pdf = pd.DataFrame(
93+
{
94+
"cjd": cjd,
95+
"cmagpsf": cmagpsf,
96+
"csigmapsf": csigmapsf,
97+
"cfid": cfid,
98+
"distnr": distnr,
99+
"is_transient": is_transient,
100+
}
101+
)
102+
103+
# If no alert pass the transient filter,
104+
# directly return invalid value for everyone.
105+
if sum(pdf["is_transient"]) == 0:
106+
return pd.Series([-1.0]*len(pdf))
107+
108+
else:
109+
110+
# Initialise all probas to -1
111+
probas_total = np.zeros(len(pdf), dtype=float) - 1
112+
mask_valid = pdf["is_transient"]
113+
114+
# select only trasnient alerts
115+
pdf_valid = pdf[mask_valid]
116+
117+
# Assign default -1 proba for every valid alert
118+
probas = np.zeros(len(pdf_valid), dtype=float) - 1
119+
120+
pdf_valid = slsn.compute_flux(pdf_valid)
121+
pdf_valid = slsn.remove_nan(pdf_valid)
122+
123+
# Perform feature extraction
124+
features = slsn.extract_features(pdf_valid)
125+
126+
# Load classifier
127+
clf = joblib.load(classifier_path)
128+
129+
# Modify proba for alerts that were feature extracted
130+
extracted = np.sum(features.isnull(), axis=1) == 0
131+
probas[extracted] = clf.predict_proba(
132+
features.loc[extracted, clf.feature_names_in_]
133+
)[:, 1]
134+
135+
probas_total[mask_valid] = probas
136+
137+
return pd.Series(probas_total)
138+
139+
140+
if __name__ == "__main__":
141+
globs = globals()
142+
path = os.path.dirname(__file__)
143+
144+
ztf_alert_sample = "file://{}/data/alerts/datatest/part-00003-bdab8e46-89c4-4ac1-8603-facd71833e8a-c000.snappy.parquet".format(
145+
path
146+
)
147+
globs["ztf_alert_sample"] = ztf_alert_sample
148+
149+
# Run the test suite
150+
spark_unit_tests(globs)

0 commit comments

Comments
 (0)