-
Notifications
You must be signed in to change notification settings - Fork 28.7k
[SPARK-4588] ML Attributes #4925
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
2a21d6d
7c944da
e7ab467
b1aceef
393ffdc
617be40
71d1bd0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,234 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.attribute | ||
|
||
import scala.collection.mutable.ArrayBuffer | ||
|
||
import org.apache.spark.mllib.linalg.VectorUDT | ||
import org.apache.spark.sql.types.{Metadata, MetadataBuilder, StructField} | ||
|
||
/** | ||
* Attributes that describe a vector ML column. | ||
* | ||
* @param name name of the attribute group (the ML column name) | ||
* @param numAttributes optional number of attributes. At most one of `numAttributes` and `attrs` | ||
* can be defined. | ||
* @param attrs optional array of attributes. Attribute will be copied with their corresponding | ||
* indices in the array. | ||
*/ | ||
class AttributeGroup private ( | ||
val name: String, | ||
val numAttributes: Option[Int], | ||
attrs: Option[Array[Attribute]]) extends Serializable { | ||
|
||
require(name.nonEmpty, "Cannot have an empty string for name.") | ||
require(!(numAttributes.isDefined && attrs.isDefined), | ||
"Cannot have both numAttributes and attrs defined.") | ||
|
||
/** | ||
* Creates an attribute group without attribute info. | ||
* @param name name of the attribute group | ||
*/ | ||
def this(name: String) = this(name, None, None) | ||
|
||
/** | ||
* Creates an attribute group knowing only the number of attributes. | ||
* @param name name of the attribute group | ||
* @param numAttributes number of attributes | ||
*/ | ||
def this(name: String, numAttributes: Int) = this(name, Some(numAttributes), None) | ||
|
||
/** | ||
* Creates an attribute group with attributes. | ||
* @param name name of the attribute group | ||
* @param attrs array of attributes. Attributes will be copied with their corresponding indices in | ||
* the array. | ||
*/ | ||
def this(name: String, attrs: Array[Attribute]) = this(name, None, Some(attrs)) | ||
|
||
/** | ||
* Optional array of attributes. At most one of `numAttributes` and `attributes` can be defined. | ||
*/ | ||
val attributes: Option[Array[Attribute]] = attrs.map(_.view.zipWithIndex.map { case (attr, i) => | ||
attr.withIndex(i) | ||
}.toArray) | ||
|
||
private lazy val nameToIndex: Map[String, Int] = { | ||
attributes.map(_.view.flatMap { attr => | ||
attr.name.map(_ -> attr.index.get) | ||
}.toMap).getOrElse(Map.empty) | ||
} | ||
|
||
/** Size of the attribute group. Returns -1 if the size is unknown. */ | ||
def size: Int = { | ||
if (numAttributes.isDefined) { | ||
numAttributes.get | ||
} else if (attributes.isDefined) { | ||
attributes.get.length | ||
} else { | ||
-1 | ||
} | ||
} | ||
|
||
/** Test whether this attribute group contains a specific attribute. */ | ||
def hasAttr(attrName: String): Boolean = nameToIndex.contains(attrName) | ||
|
||
/** Index of an attribute specified by name. */ | ||
def indexOf(attrName: String): Int = nameToIndex(attrName) | ||
|
||
/** Gets an attribute by its name. */ | ||
def apply(attrName: String): Attribute = { | ||
attributes.get(indexOf(attrName)) | ||
} | ||
|
||
/** Gets an attribute by its name. */ | ||
def getAttr(attrName: String): Attribute = this(attrName) | ||
|
||
/** Gets an attribute by its index. */ | ||
def apply(attrIndex: Int): Attribute = attributes.get(attrIndex) | ||
|
||
/** Gets an attribute by its index. */ | ||
def getAttr(attrIndex: Int): Attribute = this(attrIndex) | ||
|
||
/** Converts to metadata without name. */ | ||
private[attribute] def toMetadata: Metadata = { | ||
import AttributeKeys._ | ||
val bldr = new MetadataBuilder() | ||
if (attributes.isDefined) { | ||
val numericMetadata = ArrayBuffer.empty[Metadata] | ||
val nominalMetadata = ArrayBuffer.empty[Metadata] | ||
val binaryMetadata = ArrayBuffer.empty[Metadata] | ||
attributes.get.foreach { | ||
case numeric: NumericAttribute => | ||
// Skip default numeric attributes. | ||
if (numeric.withoutIndex != NumericAttribute.defaultAttr) { | ||
numericMetadata += numeric.toMetadata(withType = false) | ||
} | ||
case nominal: NominalAttribute => | ||
nominalMetadata += nominal.toMetadata(withType = false) | ||
case binary: BinaryAttribute => | ||
binaryMetadata += binary.toMetadata(withType = false) | ||
} | ||
val attrBldr = new MetadataBuilder | ||
if (numericMetadata.nonEmpty) { | ||
attrBldr.putMetadataArray(AttributeType.Numeric.name, numericMetadata.toArray) | ||
} | ||
if (nominalMetadata.nonEmpty) { | ||
attrBldr.putMetadataArray(AttributeType.Nominal.name, nominalMetadata.toArray) | ||
} | ||
if (binaryMetadata.nonEmpty) { | ||
attrBldr.putMetadataArray(AttributeType.Binary.name, binaryMetadata.toArray) | ||
} | ||
bldr.putMetadata(ATTRIBUTES, attrBldr.build()) | ||
bldr.putLong(NUM_ATTRIBUTES, attributes.get.length) | ||
} else if (numAttributes.isDefined) { | ||
bldr.putLong(NUM_ATTRIBUTES, numAttributes.get) | ||
} | ||
bldr.build() | ||
} | ||
|
||
/** Converts to a StructField with some existing metadata. */ | ||
def toStructField(existingMetadata: Metadata): StructField = { | ||
val newMetadata = new MetadataBuilder() | ||
.withMetadata(existingMetadata) | ||
.putMetadata(AttributeKeys.ML_ATTR, toMetadata) | ||
.build() | ||
StructField(name, new VectorUDT, nullable = false, newMetadata) | ||
} | ||
|
||
/** Converts to a StructField. */ | ||
def toStructField(): StructField = toStructField(Metadata.empty) | ||
|
||
override def equals(other: Any): Boolean = { | ||
other match { | ||
case o: AttributeGroup => | ||
(name == o.name) && | ||
(numAttributes == o.numAttributes) && | ||
(attributes.map(_.toSeq) == o.attributes.map(_.toSeq)) | ||
case _ => | ||
false | ||
} | ||
} | ||
|
||
override def hashCode: Int = { | ||
var sum = 17 | ||
sum = 37 * sum + name.hashCode | ||
sum = 37 * sum + numAttributes.hashCode | ||
sum = 37 * sum + attributes.map(_.toSeq).hashCode | ||
sum | ||
} | ||
} | ||
|
||
/** Factory methods to create attribute groups. */ | ||
object AttributeGroup { | ||
|
||
import AttributeKeys._ | ||
|
||
/** Creates an attribute group from a [[Metadata]] instance with name. */ | ||
private[attribute] def fromMetadata(metadata: Metadata, name: String): AttributeGroup = { | ||
import org.apache.spark.ml.attribute.AttributeType._ | ||
if (metadata.contains(ATTRIBUTES)) { | ||
val numAttrs = metadata.getLong(NUM_ATTRIBUTES).toInt | ||
val attributes = new Array[Attribute](numAttrs) | ||
val attrMetadata = metadata.getMetadata(ATTRIBUTES) | ||
if (attrMetadata.contains(Numeric.name)) { | ||
attrMetadata.getMetadataArray(Numeric.name) | ||
.map(NumericAttribute.fromMetadata) | ||
.foreach { attr => | ||
attributes(attr.index.get) = attr | ||
} | ||
} | ||
if (attrMetadata.contains(Nominal.name)) { | ||
attrMetadata.getMetadataArray(Nominal.name) | ||
.map(NominalAttribute.fromMetadata) | ||
.foreach { attr => | ||
attributes(attr.index.get) = attr | ||
} | ||
} | ||
if (attrMetadata.contains(Binary.name)) { | ||
attrMetadata.getMetadataArray(Binary.name) | ||
.map(BinaryAttribute.fromMetadata) | ||
.foreach { attr => | ||
attributes(attr.index.get) = attr | ||
} | ||
} | ||
var i = 0 | ||
while (i < numAttrs) { | ||
if (attributes(i) == null) { | ||
attributes(i) = NumericAttribute.defaultAttr | ||
} | ||
i += 1 | ||
} | ||
new AttributeGroup(name, attributes) | ||
} else if (metadata.contains(NUM_ATTRIBUTES)) { | ||
new AttributeGroup(name, metadata.getLong(NUM_ATTRIBUTES).toInt) | ||
} else { | ||
new AttributeGroup(name) | ||
} | ||
} | ||
|
||
/** Creates an attribute group from a [[StructField]] instance. */ | ||
def fromStructField(field: StructField): AttributeGroup = { | ||
require(field.dataType == new VectorUDT) | ||
if (field.metadata.contains(ML_ATTR)) { | ||
fromMetadata(field.metadata.getMetadata(ML_ATTR), field.name) | ||
} else { | ||
new AttributeGroup(field.name) | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.attribute | ||
|
||
/** | ||
* Keys used to store attributes. | ||
*/ | ||
private[attribute] object AttributeKeys { | ||
val ML_ATTR: String = "ml_attr" | ||
val TYPE: String = "type" | ||
val NAME: String = "name" | ||
val INDEX: String = "idx" | ||
val MIN: String = "min" | ||
val MAX: String = "max" | ||
val STD: String = "std" | ||
val SPARSITY: String = "sparsity" | ||
val ORDINAL: String = "ord" | ||
val VALUES: String = "vals" | ||
val NUM_VALUES: String = "num_vals" | ||
val ATTRIBUTES: String = "attrs" | ||
val NUM_ATTRIBUTES: String = "num_attrs" | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.ml.attribute | ||
|
||
/** | ||
* An enum-like type for attribute types: [[AttributeType$#Numeric]], [[AttributeType$#Nominal]], | ||
* and [[AttributeType$#Binary]]. | ||
*/ | ||
sealed abstract class AttributeType(val name: String) | ||
|
||
object AttributeType { | ||
|
||
/** Numeric type. */ | ||
val Numeric: AttributeType = { | ||
case object Numeric extends AttributeType("numeric") | ||
Numeric | ||
} | ||
|
||
/** Nominal type. */ | ||
val Nominal: AttributeType = { | ||
case object Nominal extends AttributeType("nominal") | ||
Nominal | ||
} | ||
|
||
/** Binary type. */ | ||
val Binary: AttributeType = { | ||
case object Binary extends AttributeType("binary") | ||
Binary | ||
} | ||
|
||
/** | ||
* Gets the [[AttributeType]] object from its name. | ||
* @param name attribute type name: "numeric", "nominal", or "binary" | ||
*/ | ||
def fromName(name: String): AttributeType = { | ||
if (name == Numeric.name) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can this be done with There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Verified that |
||
Numeric | ||
} else if (name == Nominal.name) { | ||
Nominal | ||
} else if (name == Binary.name) { | ||
Binary | ||
} else { | ||
throw new IllegalArgumentException(s"Cannot recognize type $name.") | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm probably overlooking this but how do I put a
AttributeGroup
inside anotherAttributeGroup
? they are intended to be nest-able?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are no nested groups. An ML column is either a Double column or a Vector column.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK if this is the abstraction for a vector-valued column, what is the abstraction for the overall set of features? Let's say I have two numeric columns and one vector-valued column -- where is that represented?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You need a
VectorAssembler
to merge those three columns into a single vector column. TheVectorAssembler
will merge the ML attributes as well. In this way, the algorithms only need to handle vector columns.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Does that effectively 'flatten' the schema, you mean? so if I have two numeric columns, and one vector-valued columns with 3 features inside, then the assembler makes a representation of 5 features? If that's right, then yes now I get what you mean by this. So, there is nothing here that represents the original schema with its nesting, since that would never be how it is presented to an algorithm, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, that is correct. We will use both the group name and the attribute name to name to flattened attribute. For example, if
user
is a vector input column, in the flattened column, you see feature names likeuser:age
anduser:gender
.