Skip to content

Commit e586d56

Browse files
authored
SPARKNLP-835: ProtectedParam and ProtectedFeature (#13797)
* SPARKNLP-835: Finalize protected Features
* SPARKNLP-835: Remove redundant checks for protected Features
* SPARKNLP-835: Introduce ProtectedParam
* SPARKNLP-835: Resolve encoding/decoding issue for HasProtectedParams
* SPARKNLP-835: Make caseSensitive settable
* SPARKNLP-835: Make maxSentenceLength, batchSize settable
* SPARKNLP-835: Enable protected Params for Annotators
1 parent 8ed57e4 commit e586d56

File tree

61 files changed

+263
-295
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

61 files changed

+263
-295
lines changed
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package com.johnsnowlabs.nlp
2+
3+
import org.apache.spark.ml.param.{Param, Params}
4+
5+
/** Enables a class to protect a parameter, which means that it can only be set once.
6+
*
7+
* This trait will enable an implicit conversion from Param to ProtectedParam. In addition, the
8+
* new set for ProtectedParam will then check whether or not the value was already set. If so,
9+
* then a warning will be output and the value will not be set again.
10+
*/
11+
trait HasProtectedParams {
12+
this: Params =>
13+
implicit class ProtectedParam[T](baseParam: Param[T])
14+
extends Param[T](baseParam.parent, baseParam.name, baseParam.doc, baseParam.isValid) {
15+
16+
var isProtected = false
17+
18+
/** Sets this parameter to be protected, which means that it can only be set once.
19+
*
20+
* Default values do not count as a set value and can be overridden.
21+
*
22+
* @return
23+
* This object
24+
*/
25+
def setProtected(): this.type = {
26+
isProtected = true
27+
this
28+
}
29+
30+
def toParam: Param[T] = this.asInstanceOf[Param[T]]
31+
32+
// Delegate JSON encoding/decoding to the wrapped Param so its own implementation is used
33+
override def jsonEncode(value: T): String = baseParam.jsonEncode(value)
34+
override def jsonDecode(json: String): T = baseParam.jsonDecode(json)
35+
}
36+
37+
/** Sets the value for a protected Param.
38+
*
39+
* If the parameter is marked as protected and was already set, a warning is printed and it will not be set again. Default values do not count as a
40+
* set value and can be overridden.
41+
*
42+
* @param param
43+
* Protected parameter to set
44+
* @param value
45+
* Value for the parameter
46+
* @tparam T
47+
* Type of the parameter
48+
* @return
49+
* This object
50+
*/
51+
def set[T](param: ProtectedParam[T], value: T): this.type = {
52+
if (param.isProtected && get(param).isDefined)
53+
println(
54+
s"Warning: The parameter ${param.name} is protected and can only be set once." +
55+
" For a pretrained model, this was done during the initialization process." +
56+
" If you are trying to train your own model, please check the documentation." +
57+
" If this is intentional, set the parameter directly with set(annotator.param, value).")
58+
else
59+
set(param.toParam, value)
60+
this
61+
}
62+
}

src/main/scala/com/johnsnowlabs/nlp/annotators/audio/Wav2Vec2ForCTC.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -189,8 +189,7 @@ class Wav2Vec2ForCTC(override val uid: String)
189189

190190
/** @group setParam */
191191
def setSignatures(value: Map[String, String]): this.type = {
192-
if (get(signatures).isEmpty)
193-
set(signatures, value)
192+
set(signatures, value)
194193
this
195194
}
196195

src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForQuestionAnswering.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,7 @@ class AlbertForQuestionAnswering(override val uid: String)
184184

185185
/** @group setParam */
186186
def setSignatures(value: Map[String, String]): this.type = {
187-
if (get(signatures).isEmpty)
188-
set(signatures, value)
187+
set(signatures, value)
189188
this
190189
}
191190

src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForSequenceClassification.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -227,8 +227,7 @@ class AlbertForSequenceClassification(override val uid: String)
227227

228228
/** @group setParam */
229229
def setSignatures(value: Map[String, String]): this.type = {
230-
if (get(signatures).isEmpty)
231-
set(signatures, value)
230+
set(signatures, value)
232231
this
233232
}
234233

@@ -265,9 +264,7 @@ class AlbertForSequenceClassification(override val uid: String)
265264
* @group setParam
266265
*/
267266
override def setCaseSensitive(value: Boolean): this.type = {
268-
if (get(caseSensitive).isEmpty)
269-
set(this.caseSensitive, value)
270-
this
267+
set(this.caseSensitive, value)
271268
}
272269

273270
setDefault(

src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/AlbertForTokenClassification.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,7 @@ class AlbertForTokenClassification(override val uid: String)
205205

206206
/** @group setParam */
207207
def setSignatures(value: Map[String, String]): this.type = {
208-
if (get(signatures).isEmpty)
209-
set(signatures, value)
208+
set(signatures, value)
210209
this
211210
}
212211

@@ -242,9 +241,7 @@ class AlbertForTokenClassification(override val uid: String)
242241
* @group setParam
243242
*/
244243
override def setCaseSensitive(value: Boolean): this.type = {
245-
if (get(caseSensitive).isEmpty)
246-
set(this.caseSensitive, value)
247-
this
244+
set(this.caseSensitive, value)
248245
}
249246

250247
setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> false)

src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForQuestionAnswering.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,8 +198,7 @@ class BertForQuestionAnswering(override val uid: String)
198198

199199
/** @group setParam */
200200
def setSignatures(value: Map[String, String]): this.type = {
201-
if (get(signatures).isEmpty)
202-
set(signatures, value)
201+
set(signatures, value)
203202
this
204203
}
205204

src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForSequenceClassification.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,7 @@ class BertForSequenceClassification(override val uid: String)
243243

244244
/** @group setParam */
245245
def setSignatures(value: Map[String, String]): this.type = {
246-
if (get(signatures).isEmpty)
247-
set(signatures, value)
246+
set(signatures, value)
248247
this
249248
}
250249

@@ -282,9 +281,7 @@ class BertForSequenceClassification(override val uid: String)
282281
* @group setParam
283282
*/
284283
override def setCaseSensitive(value: Boolean): this.type = {
285-
if (get(caseSensitive).isEmpty)
286-
set(this.caseSensitive, value)
287-
this
284+
set(this.caseSensitive, value)
288285
}
289286

290287
setDefault(

src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForTokenClassification.scala

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,7 @@ class BertForTokenClassification(override val uid: String)
217217

218218
/** @group setParam */
219219
def setSignatures(value: Map[String, String]): this.type = {
220-
if (get(signatures).isEmpty)
221-
set(signatures, value)
220+
set(signatures, value)
222221
this
223222
}
224223

@@ -255,9 +254,7 @@ class BertForTokenClassification(override val uid: String)
255254
* @group setParam
256255
*/
257256
override def setCaseSensitive(value: Boolean): this.type = {
258-
if (get(caseSensitive).isEmpty)
259-
set(this.caseSensitive, value)
260-
this
257+
set(this.caseSensitive, value)
261258
}
262259

263260
setDefault(batchSize -> 8, maxSentenceLength -> 128, caseSensitive -> true)

src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/BertForZeroShotClassification.scala

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,7 @@ class BertForZeroShotClassification(override val uid: String)
167167

168168
/** @group setParam */
169169
def setVocabulary(value: Map[String, Int]): this.type = {
170-
if (get(vocabulary).isEmpty)
171-
set(vocabulary, value)
170+
set(vocabulary, value)
172171
this
173172
}
174173

@@ -256,8 +255,7 @@ class BertForZeroShotClassification(override val uid: String)
256255

257256
/** @group setParam */
258257
def setSignatures(value: Map[String, String]): this.type = {
259-
if (get(signatures).isEmpty)
260-
set(signatures, value)
258+
set(signatures, value)
261259
this
262260
}
263261

@@ -295,9 +293,7 @@ class BertForZeroShotClassification(override val uid: String)
295293
* @group setParam
296294
*/
297295
override def setCaseSensitive(value: Boolean): this.type = {
298-
if (get(caseSensitive).isEmpty)
299-
set(this.caseSensitive, value)
300-
this
296+
set(this.caseSensitive, value)
301297
}
302298

303299
setDefault(

src/main/scala/com/johnsnowlabs/nlp/annotators/classifier/dl/CamemBertForQuestionAnswering.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,7 @@ class CamemBertForQuestionAnswering(override val uid: String)
184184

185185
/** @group setParam */
186186
def setSignatures(value: Map[String, String]): this.type = {
187-
if (get(signatures).isEmpty)
188-
set(signatures, value)
187+
set(signatures, value)
189188
this
190189
}
191190

0 commit comments

Comments
 (0)