Skip to content

Commit b1aceef

Browse files
committed
more tests
1 parent e7ab467 commit b1aceef

File tree

4 files changed

+103
-29
lines changed

4 files changed

+103
-29
lines changed

mllib/src/main/scala/org/apache/spark/ml/attribute/AttributeGroup.scala

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ class AttributeGroup private (
3636
val numAttributes: Option[Int],
3737
attrs: Option[Array[Attribute]]) extends Serializable {
3838

39+
require(name.nonEmpty, "Cannot have an empty string for name.")
40+
require(!(numAttributes.isDefined && attrs.isDefined),
41+
"Cannot have both numAttributes and attrs defined.")
42+
3943
/**
4044
* Creates an attribute group without attribute info.
4145
* @param name name of the attribute group
@@ -87,13 +91,20 @@ class AttributeGroup private (
8791
/** Index of an attribute specified by name. */
8892
def indexOf(attrName: String): Int = nameToIndex(attrName)
8993

90-
/** Gets an attribute by name. */
94+
/** Gets an attribute by its name. */
9195
def apply(attrName: String): Attribute = {
9296
attributes.get(indexOf(attrName))
9397
}
9498

99+
/** Gets an attribute by its name. */
100+
def getAttr(attrName: String): Attribute = this(attrName)
101+
102+
/** Gets an attribute by its index. */
95103
def apply(attrIndex: Int): Attribute = attributes.get(attrIndex)
96104

105+
/** Gets an attribute by its index. */
106+
def getAttr(attrIndex: Int): Attribute = this(attrIndex)
107+
97108
/** Converts to metadata without name. */
98109
private[attribute] def toMetadata: Metadata = {
99110
import AttributeKeys._

mllib/src/main/scala/org/apache/spark/ml/attribute/attributes.scala

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,13 @@ import org.apache.spark.sql.types.{DoubleType, Metadata, MetadataBuilder, Struct
2626
*/
2727
sealed abstract class Attribute extends Serializable {
2828

29+
name.foreach { n =>
30+
require(n.nonEmpty, "Cannot have an empty string for name.")
31+
}
32+
index.foreach { i =>
33+
require(i >= 0, s"Index cannot be negative but got $i")
34+
}
35+
2936
/** Attribute type. */
3037
def attrType: AttributeType
3138

@@ -155,6 +162,13 @@ class NumericAttribute private[ml] (
155162
val std: Option[Double] = None,
156163
val sparsity: Option[Double] = None) extends Attribute {
157164

165+
std.foreach { s =>
166+
require(s >= 0.0, s"Standard deviation cannot be negative but got $s.")
167+
}
168+
sparsity.foreach { s =>
169+
require(s >= 0.0 && s <= 1.0, s"Sparsity must be in [0, 1] but got $s.")
170+
}
171+
158172
override def attrType: AttributeType = AttributeType.Numeric
159173

160174
override def withName(name: String): NumericAttribute = copy(name = Some(name))
@@ -271,16 +285,23 @@ object NumericAttribute extends AttributeFactory {
271285
* @param name optional name
272286
* @param index optional index
273287
* @param isOrdinal whether this attribute is ordinal (optional)
274-
* @param cardinality optional number of values
275-
* @param values optional values
288+
* @param numValues optional number of values. At most one of `numValues` and `values` can be
289+
* defined.
290+
* @param values optional values. At most one of `numValues` and `values` can be defined.
276291
*/
277292
class NominalAttribute private[ml] (
278293
override val name: Option[String] = None,
279294
override val index: Option[Int] = None,
280295
val isOrdinal: Option[Boolean] = None,
281-
val cardinality: Option[Int] = None,
296+
val numValues: Option[Int] = None,
282297
val values: Option[Array[String]] = None) extends Attribute {
283298

299+
numValues.foreach { n =>
300+
require(n >= 0, s"numValues cannot be negative but got $n.")
301+
}
302+
require(!(numValues.isDefined && values.isDefined),
303+
"Cannot have both numValues and values defined.")
304+
284305
override def attrType: AttributeType = AttributeType.Nominal
285306

286307
override def isNumeric: Boolean = false
@@ -299,58 +320,57 @@ class NominalAttribute private[ml] (
299320
/** Tests whether this attribute contains a specific value. */
300321
def hasValue(value: String): Boolean = valueToIndex.contains(value)
301322

302-
/** Copy with new values. */
323+
/** Gets a value given its index. */
324+
def getValue(index: Int): String = values.get(index)
325+
326+
override def withName(name: String): NominalAttribute = copy(name = Some(name))
327+
override def withoutName: NominalAttribute = copy(name = None)
328+
329+
override def withIndex(index: Int): NominalAttribute = copy(index = Some(index))
330+
override def withoutIndex: NominalAttribute = copy(index = None)
331+
332+
/** Copy with new values and empty `numValues`. */
303333
def withValues(values: Array[String]): NominalAttribute = {
304-
copy(cardinality = None, values = Some(values))
334+
copy(numValues = None, values = Some(values))
305335
}
306336

307-
/** Copy with new values. */
337+
/** Copy with new values and empty `numValues`. */
308338
@varargs
309339
def withValues(first: String, others: String*): NominalAttribute = {
310-
copy(cardinality = None, values = Some((first +: others).toArray))
340+
copy(numValues = None, values = Some((first +: others).toArray))
311341
}
312342

313343
/** Copy without the values. */
314344
def withoutValues: NominalAttribute = {
315345
copy(values = None)
316346
}
317347

318-
/** Copy with a new cardinality. */
319-
def withCardinality(cardinality: Int): NominalAttribute = {
320-
if (values.isDefined) {
321-
throw new IllegalArgumentException("Cannot copy with cardinality if values are defined.")
322-
} else {
323-
copy(cardinality = Some(cardinality))
324-
}
348+
/** Copy with a new `numValues` and empty `values`. */
349+
def withNumValues(numValues: Int): NominalAttribute = {
350+
copy(numValues = Some(numValues), values = None)
325351
}
326352

327-
/** Copy without the cardinality. */
328-
def withoutCardinality: NominalAttribute = copy(cardinality = None)
353+
/** Copy without the `numValues`. */
354+
def withoutNumValues: NominalAttribute = copy(numValues = None)
329355

330356
/** Creates a copy of this attribute with optional changes. */
331357
private def copy(
332358
name: Option[String] = name,
333359
index: Option[Int] = index,
334360
isOrdinal: Option[Boolean] = isOrdinal,
335-
cardinality: Option[Int] = cardinality,
361+
numValues: Option[Int] = numValues,
336362
values: Option[Array[String]] = values): NominalAttribute = {
337-
new NominalAttribute(name, index, isOrdinal, cardinality, values)
363+
new NominalAttribute(name, index, isOrdinal, numValues, values)
338364
}
339365

340-
override def withName(name: String): NominalAttribute = copy(name = Some(name))
341-
override def withoutName: NominalAttribute = copy(name = None)
342-
343-
override def withIndex(index: Int): NominalAttribute = copy(index = Some(index))
344-
override def withoutIndex: NominalAttribute = copy(index = None)
345-
346366
private[attribute] override def toMetadata(withType: Boolean): Metadata = {
347367
import org.apache.spark.ml.attribute.AttributeKeys._
348368
val bldr = new MetadataBuilder()
349369
if (withType) bldr.putString(TYPE, attrType.name)
350370
name.foreach(bldr.putString(NAME, _))
351371
index.foreach(bldr.putLong(INDEX, _))
352372
isOrdinal.foreach(bldr.putBoolean(ORDINAL, _))
353-
cardinality.foreach(bldr.putLong(CARDINALITY, _))
373+
numValues.foreach(bldr.putLong(CARDINALITY, _))
354374
values.foreach(v => bldr.putStringArray(VALUES, v))
355375
bldr.build()
356376
}
@@ -361,7 +381,7 @@ class NominalAttribute private[ml] (
361381
(name == o.name) &&
362382
(index == o.index) &&
363383
(isOrdinal == o.isOrdinal) &&
364-
(cardinality == o.cardinality) &&
384+
(numValues == o.numValues) &&
365385
(values.map(_.toSeq) == o.values.map(_.toSeq))
366386
case _ =>
367387
false
@@ -373,7 +393,7 @@ class NominalAttribute private[ml] (
373393
sum = 37 * sum + name.hashCode
374394
sum = 37 * sum + index.hashCode
375395
sum = 37 * sum + isOrdinal.hashCode
376-
sum = 37 * sum + cardinality.hashCode
396+
sum = 37 * sum + numValues.hashCode
377397
sum = 37 * sum + values.map(_.toSeq).hashCode
378398
sum
379399
}
@@ -410,6 +430,10 @@ class BinaryAttribute private[ml] (
410430
val values: Option[Array[String]] = None)
411431
extends Attribute {
412432

433+
values.foreach { v =>
434+
require(v.length == 2, s"Number of values must be 2 for a binary attribute but got ${v.toSeq}.")
435+
}
436+
413437
override def attrType: AttributeType = AttributeType.Binary
414438

415439
override def isNumeric: Boolean = true

mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeGroupSuite.scala

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,4 +46,20 @@ class AttributeGroupSuite extends FunSuite {
4646
assert(group === AttributeGroup.fromMetadata(group.toMetadata, group.name))
4747
assert(group === AttributeGroup.fromStructField(group.toStructField()))
4848
}
49+
50+
test("attribute group without attributes") {
51+
val group0 = new AttributeGroup("user", 10)
52+
assert(group0.name === "user")
53+
assert(group0.numAttributes === Some(10))
54+
assert(group0.size === 10)
55+
assert(group0.attributes.isEmpty)
56+
assert(group0 === AttributeGroup.fromMetadata(group0.toMetadata, group0.name))
57+
assert(group0 === AttributeGroup.fromStructField(group0.toStructField()))
58+
59+
val group1 = new AttributeGroup("item")
60+
assert(group1.name === "item")
61+
assert(group1.numAttributes.isEmpty)
62+
assert(group1.attributes.isEmpty)
63+
assert(group1.size === -1)
64+
}
4965
}

mllib/src/test/scala/org/apache/spark/ml/attribute/AttributeSuite.scala

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,15 @@ class AttributeSuite extends FunSuite {
8484
assert(attr2 === Attribute.fromMetadata(attr2.toMetadata()))
8585
}
8686

87+
test("bad numeric attributes") {
88+
val attr = NumericAttribute.defaultAttr
89+
intercept[IllegalArgumentException](attr.withName(""))
90+
intercept[IllegalArgumentException](attr.withIndex(-1))
91+
intercept[IllegalArgumentException](attr.withStd(-0.1))
92+
intercept[IllegalArgumentException](attr.withSparsity(-0.5))
93+
intercept[IllegalArgumentException](attr.withSparsity(1.5))
94+
}
95+
8796
test("default nominal attribute") {
8897
val attr: NominalAttribute = NominalAttribute.defaultAttr
8998
val metadata = Metadata.fromJson("""{"type":"nominal"}""")
@@ -94,7 +103,7 @@ class AttributeSuite extends FunSuite {
94103
assert(attr.name.isEmpty)
95104
assert(attr.index.isEmpty)
96105
assert(attr.values.isEmpty)
97-
assert(attr.cardinality.isEmpty)
106+
assert(attr.numValues.isEmpty)
98107
assert(attr.isOrdinal.isEmpty)
99108
assert(attr.toMetadata() === metadata)
100109
assert(attr.toMetadata(withType = true) === metadata)
@@ -125,6 +134,7 @@ class AttributeSuite extends FunSuite {
125134
assert(attr.index === Some(index))
126135
assert(attr.values === Some(values))
127136
assert(attr.indexOf("medium") === 1)
137+
assert(attr.getValue(1) === "medium")
128138
assert(attr.toMetadata() === metadata)
129139
assert(attr.toMetadata(withType = true) === metadata)
130140
assert(attr.toMetadata(withType = false) === metadataWithoutType)
@@ -141,6 +151,13 @@ class AttributeSuite extends FunSuite {
141151
assert(attr2 === NominalAttribute.fromMetadata(attr2.toMetadata(withType = false)))
142152
}
143153

154+
test("bad nominal attributes") {
155+
val attr = NominalAttribute.defaultAttr
156+
intercept[IllegalArgumentException](attr.withName(""))
157+
intercept[IllegalArgumentException](attr.withIndex(-1))
158+
intercept[IllegalArgumentException](attr.withNumValues(-1))
159+
}
160+
144161
test("default binary attribute") {
145162
val attr = BinaryAttribute.defaultAttr
146163
val metadata = Metadata.fromJson("""{"type":"binary"}""")
@@ -186,4 +203,10 @@ class AttributeSuite extends FunSuite {
186203
assert(attr === BinaryAttribute.fromMetadata(metadataWithoutType))
187204
assert(attr.withoutIndex === Attribute.fromStructField(attr.toStructField()))
188205
}
206+
207+
test("bad binary attributes") {
208+
val attr = BinaryAttribute.defaultAttr
209+
intercept[IllegalArgumentException](attr.withName(""))
210+
intercept[IllegalArgumentException](attr.withIndex(-1))
211+
}
189212
}

0 commit comments

Comments
 (0)