@@ -26,6 +26,13 @@ import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, Struct
26
26
*/
27
27
sealed abstract class Attribute extends Serializable {
28
28
29
// Construction-time validation of the optional name and index.
// NOTE(review): this runs in the base-class initializer, so it presumably relies on
// subclasses supplying `name` and `index` before this point -- confirm.

// A defined name must not be the empty string.
require(name.forall(_.nonEmpty), "Cannot have an empty string for name.")

// A defined index must be non-negative.
for (i <- index) {
  require(i >= 0, s"Index cannot be negative but got $i")
}
35
+
29
36
/** Attribute type. */
30
37
def attrType : AttributeType
31
38
@@ -155,6 +162,13 @@ class NumericAttribute private[ml] (
155
162
val std : Option [Double ] = None ,
156
163
val sparsity : Option [Double ] = None ) extends Attribute {
157
164
165
// Construction-time validation of the optional statistics.

// A defined standard deviation must be non-negative (NaN is rejected by the comparison).
for (sd <- std) {
  require(sd >= 0.0, s"Standard deviation cannot be negative but got $sd.")
}

// A defined sparsity must lie in the closed interval [0, 1].
for (sp <- sparsity) {
  require(sp >= 0.0 && sp <= 1.0, s"Sparsity must be in [0, 1] but got $sp.")
}
171
+
158
172
override def attrType : AttributeType = AttributeType .Numeric
159
173
160
174
override def withName (name : String ): NumericAttribute = copy(name = Some (name))
@@ -271,16 +285,23 @@ object NumericAttribute extends AttributeFactory {
271
285
* @param name optional name
272
286
* @param index optional index
273
287
* @param isOrdinal whether this attribute is ordinal (optional)
274
- * @param cardinality optional number of values
275
- * @param values optional values
288
+ * @param numValues optional number of values. At most one of `numValues` and `values` can be
289
+ * defined.
290
+ * @param values optional values. At most one of `numValues` and `values` can be defined.
276
291
*/
277
292
class NominalAttribute private [ml] (
278
293
override val name : Option [String ] = None ,
279
294
override val index : Option [Int ] = None ,
280
295
val isOrdinal : Option [Boolean ] = None ,
281
- val cardinality : Option [Int ] = None ,
296
+ val numValues : Option [Int ] = None ,
282
297
val values : Option [Array [String ]] = None ) extends Attribute {
283
298
299
// Construction-time validation of `numValues` and `values`.

// A defined number of values must be non-negative.
for (count <- numValues) {
  require(count >= 0, s"numValues cannot be negative but got $count.")
}

// At most one of `numValues` and `values` may be defined.
require(numValues.isEmpty || values.isEmpty,
  "Cannot have both numValues and values defined.")
304
+
284
305
override def attrType : AttributeType = AttributeType .Nominal
285
306
286
307
override def isNumeric : Boolean = false
@@ -299,58 +320,57 @@ class NominalAttribute private[ml] (
299
320
/**
 * Tests whether this attribute contains a specific value.
 *
 * @param value the value to look up in `valueToIndex`
 * @return true if `valueToIndex` contains `value`
 */
def hasValue(value: String): Boolean = {
  valueToIndex.contains(value)
}
301
322
302
- /** Copy with new values. */
323
/**
 * Gets a value given its index.
 *
 * @param index position of the requested value within `values`
 * @return the value stored at `index`
 * @throws NoSuchElementException if `values` is not defined
 */
def getValue(index: Int): String = values match {
  case Some(definedValues) =>
    definedValues(index)
  case None =>
    // Same exception type that `Option.get` raised before, but with an actionable message
    // instead of the opaque "None.get".
    throw new NoSuchElementException(
      s"Cannot get the value at index $index because values are not defined.")
}
325
+
326
/** Copy with the given name. */
override def withName(name: String): NominalAttribute = {
  copy(name = Some(name))
}

/** Copy without the name. */
override def withoutName: NominalAttribute = {
  copy(name = None)
}

/** Copy with the given index. */
override def withIndex(index: Int): NominalAttribute = {
  copy(index = Some(index))
}

/** Copy without the index. */
override def withoutIndex: NominalAttribute = {
  copy(index = None)
}
331
+
332
/**
 * Copy with new values and empty `numValues`.
 *
 * @param values the values to store in the copy
 */
def withValues(values: Array[String]): NominalAttribute =
  copy(numValues = None, values = Some(values))
306
336
307
/**
 * Copy with new values and empty `numValues`.
 *
 * @param first the first value
 * @param others the remaining values, in order
 */
@varargs
def withValues(first: String, others: String*): NominalAttribute = {
  val allValues: Array[String] = (first +: others).toArray
  copy(numValues = None, values = Some(allValues))
}
312
342
313
343
/** Copy without the values; every other field is carried over unchanged. */
def withoutValues: NominalAttribute = copy(values = None)
317
347
318
/**
 * Copy with a new `numValues` and empty `values`.
 *
 * @param numValues the number of values to store in the copy
 */
def withNumValues(numValues: Int): NominalAttribute =
  copy(numValues = Some(numValues), values = None)
326
352
327
/** Copy without the `numValues`; every other field is carried over unchanged. */
def withoutNumValues: NominalAttribute = {
  copy(numValues = None)
}
329
355
330
356
/**
 * Creates a copy of this attribute with optional changes.
 *
 * Every parameter defaults to the current value of the field with the same name, so
 * callers pass only the fields they want to replace.
 */
private def copy(
    name: Option[String] = name,
    index: Option[Int] = index,
    isOrdinal: Option[Boolean] = isOrdinal,
    numValues: Option[Int] = numValues,
    values: Option[Array[String]] = values): NominalAttribute =
  new NominalAttribute(
    name = name,
    index = index,
    isOrdinal = isOrdinal,
    numValues = numValues,
    values = values)
339
365
340
- override def withName (name : String ): NominalAttribute = copy(name = Some (name))
341
- override def withoutName : NominalAttribute = copy(name = None )
342
-
343
- override def withIndex (index : Int ): NominalAttribute = copy(index = Some (index))
344
- override def withoutIndex : NominalAttribute = copy(index = None )
345
-
346
366
/**
 * Converts this attribute to [[Metadata]].
 *
 * Only defined fields are written. `numValues` is stored under the `CARDINALITY` key.
 *
 * @param withType whether to also write the attribute type name under the `TYPE` key
 */
private[attribute] override def toMetadata(withType: Boolean): Metadata = {
  import org.apache.spark.ml.attribute.AttributeKeys._
  val builder = new MetadataBuilder()
  if (withType) {
    builder.putString(TYPE, attrType.name)
  }
  name.foreach { n => builder.putString(NAME, n) }
  index.foreach { i => builder.putLong(INDEX, i) }
  isOrdinal.foreach { flag => builder.putBoolean(ORDINAL, flag) }
  numValues.foreach { count => builder.putLong(CARDINALITY, count) }
  values.foreach { vs => builder.putStringArray(VALUES, vs) }
  builder.build()
}
@@ -361,7 +381,7 @@ class NominalAttribute private[ml] (
361
381
(name == o.name) &&
362
382
(index == o.index) &&
363
383
(isOrdinal == o.isOrdinal) &&
364
- (cardinality == o.cardinality ) &&
384
+ (numValues == o.numValues ) &&
365
385
(values.map(_.toSeq) == o.values.map(_.toSeq))
366
386
case _ =>
367
387
false
@@ -373,7 +393,7 @@ class NominalAttribute private[ml] (
373
393
sum = 37 * sum + name.hashCode
374
394
sum = 37 * sum + index.hashCode
375
395
sum = 37 * sum + isOrdinal.hashCode
376
- sum = 37 * sum + cardinality .hashCode
396
+ sum = 37 * sum + numValues .hashCode
377
397
sum = 37 * sum + values.map(_.toSeq).hashCode
378
398
sum
379
399
}
@@ -410,6 +430,10 @@ class BinaryAttribute private[ml] (
410
430
val values : Option [Array [String ]] = None )
411
431
extends Attribute {
412
432
433
// Construction-time validation: when values are defined there must be exactly two of them.
for (vs <- values) {
  require(vs.length == 2,
    s"Number of values must be 2 for a binary attribute but got ${vs.toSeq}.")
}
436
+
413
437
override def attrType : AttributeType = AttributeType .Binary
414
438
415
439
override def isNumeric : Boolean = true
0 commit comments