[SPARK-7045] [MLlib] Avoid intermediate representation when creating model #5748
Changes from all commits:
```diff
@@ -403,17 +403,8 @@ class Word2Vec extends Serializable with Logging {
     }
     newSentences.unpersist()
 
-    val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
-    var i = 0
-    while (i < vocabSize) {
-      val word = bcVocab.value(i).word
-      val vector = new Array[Float](vectorSize)
-      Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
-      word2VecMap += word -> vector
-      i += 1
-    }
-
-    new Word2VecModel(word2VecMap.toMap)
+    val wordArray = vocab.map(_.word)
+    new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
   }
 
   /**
```
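The rewritten ending of `fit` no longer copies each vector out of `syn0Global` into a per-word map: the flat buffer itself becomes the model's storage, addressed by word index. A minimal sketch of the layout this relies on, with hypothetical toy values (three words, `vectorSize = 2`), not the actual Spark code:

```scala
// Flat layout assumed by the new constructor call above: word i's vector
// occupies offsets [i * vectorSize, (i + 1) * vectorSize) of one big array.
val vectorSize = 2
val vocabWords = Array("a", "b", "c")       // hypothetical vocabulary
val syn0Global = Array(0.1f, 0.2f,          // vector for "a" (index 0)
                       0.3f, 0.4f,          // vector for "b" (index 1)
                       0.5f, 0.6f)          // vector for "c" (index 2)

val wordIndex: Map[String, Int] = vocabWords.zipWithIndex.toMap

// Any vector can be recovered by slicing, with no intermediate
// Map[String, Array[Float]] ever materialized.
val i = wordIndex("b")
val vecB = syn0Global.slice(i * vectorSize, i * vectorSize + vectorSize)
// vecB: Array(0.3, 0.4)
```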
```diff
@@ -429,38 +420,42 @@ class Word2Vec extends Serializable with Logging {
 /**
  * :: Experimental ::
  * Word2Vec model
+ * @param wordIndex maps each word to an index, which can retrieve the corresponding
+ *                  vector from wordVectors
+ * @param wordVectors array of length numWords * vectorSize, vector corresponding
+ *                    to the word mapped with index i can be retrieved by the slice
+ *                    (i * vectorSize, i * vectorSize + vectorSize)
  */
 @Experimental
-class Word2VecModel private[spark] (
-    model: Map[String, Array[Float]]) extends Serializable with Saveable {
-
-  // wordList: Ordered list of words obtained from model.
-  private val wordList: Array[String] = model.keys.toArray
-
-  // wordIndex: Maps each word to an index, which can retrieve the corresponding
-  // vector from wordVectors (see below).
-  private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
+class Word2VecModel private[mllib] (
+    private val wordIndex: Map[String, Int],
+    private val wordVectors: Array[Float]) extends Serializable with Saveable {
 
-  // vectorSize: Dimension of each word's vector.
-  private val vectorSize = model.head._2.size
+  private val numWords = wordIndex.size
+  // vectorSize: Dimension of each word's vector.
+  private val vectorSize = wordVectors.length / numWords
 
-  // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word
-  // mapped with index i can be retrieved by the slice
-  // (ind * vectorSize, ind * vectorSize + vectorSize)
+  // wordList: Ordered list of words obtained from wordIndex.
+  private val wordList: Array[String] = {
+    val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
+    wl.toArray
+  }
 
   // wordVecNorms: Array of length numWords, each value being the Euclidean norm
   // of the wordVector.
-  private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
-    val wordVectors = new Array[Float](vectorSize * numWords)
+  private val wordVecNorms: Array[Double] = {
     val wordVecNorms = new Array[Double](numWords)
     var i = 0
     while (i < numWords) {
-      val vec = model.get(wordList(i)).get
-      Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize)
+      val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
       wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
       i += 1
     }
-    (wordVectors, wordVecNorms)
+    wordVecNorms
   }
 
+  def this(model: Map[String, Array[Float]]) = {
+    this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))
+  }
+
   private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
```
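Since the primary constructor now takes only `wordIndex` and `wordVectors`, everything else (`numWords`, `vectorSize`, `wordList`, `wordVecNorms`) is derived from them. The ordered word list is rebuilt by sorting the index entries; a standalone sketch of that recovery with a hypothetical toy index:

```scala
// Hypothetical index, as handed to the new primary constructor.
val wordIndex = Map("cat" -> 2, "dog" -> 0, "fish" -> 1)

// Same recovery as the wordList field above: sort by index so that
// array position i matches slice i of the flat wordVectors buffer.
val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
val wordList = wl.toArray
// wordList: Array("dog", "fish", "cat")
```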
```diff
@@ -484,8 +479,9 @@ class Word2VecModel private[spark] (
    * @return vector representation of word
    */
   def transform(word: String): Vector = {
-    model.get(word) match {
-      case Some(vec) =>
+    wordIndex.get(word) match {
+      case Some(ind) =>
+        val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
         Vectors.dense(vec.map(_.toDouble))
       case None =>
         throw new IllegalStateException(s"$word not in vocabulary")
```
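`transform` now resolves a word in two steps, an index lookup followed by an array slice, instead of a map lookup that returned a pre-built vector. A self-contained sketch of that path (toy data, and plain `Array[Double]` standing in for MLlib's `Vector`):

```scala
// Hypothetical toy model state.
val vectorSize = 2
val wordIndex = Map("hello" -> 0, "world" -> 1)
val wordVectors = Array(1.0f, 2.0f, 3.0f, 4.0f)

// Two-step lookup mirroring the new transform: index first, then slice.
def vectorFor(word: String): Array[Double] = wordIndex.get(word) match {
  case Some(ind) =>
    wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize).map(_.toDouble)
  case None =>
    throw new IllegalStateException(s"$word not in vocabulary")
}

vectorFor("world")  // Array(3.0, 4.0)
```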
```diff
@@ -511,7 +507,7 @@ class Word2VecModel private[spark] (
    */
   def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
     require(num > 0, "Number of similar words should > 0")
 
     // TODO: optimize top-k
 
     val fVector = vector.toArray.map(_.toFloat)
     val cosineVec = Array.fill[Float](numWords)(0)
     val alpha: Float = 1
```

A review thread is attached to the `// TODO: optimize top-k` line:

> Is there a JIRA for this? If so, can you please note the JIRA number here?

> I see. Can you please make a JIRA and add its number to the comment here?

```diff
@@ -521,13 +517,13 @@ class Word2VecModel private[spark] (
       "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
 
     // Need not divide with the norm of the given vector since it is constant.
-    val updatedCosines = new Array[Double](numWords)
+    val cosVec = cosineVec.map(_.toDouble)
     var ind = 0
     while (ind < numWords) {
-      updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind)
+      cosVec(ind) /= wordVecNorms(ind)
       ind += 1
     }
-    wordList.zip(updatedCosines)
+    wordList.zip(cosVec)
       .toSeq
       .sortBy(- _._2)
       .take(num + 1)
```
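The `sgemv` call computes every word's dot product with the query in one matrix-vector multiply over the flat `wordVectors` buffer; only the per-word norms are divided out afterwards, because the query's own norm scales all scores equally and cannot change the ranking. A plain-Scala sketch of the same arithmetic, with toy values standing in for the native BLAS call:

```scala
// Toy setup: three words with 2-dimensional vectors, flattened row by row.
val vectorSize = 2
val numWords = 3
val wordVectors = Array(1f, 0f, 0f, 1f, 1f, 1f)
val fVector = Array(1f, 1f)                     // query vector

// What sgemv("T", ...) produces here: cosineVec(w) = <word w, query>.
val cosineVec = new Array[Float](numWords)
for (w <- 0 until numWords; j <- 0 until vectorSize) {
  cosineVec(w) += wordVectors(w * vectorSize + j) * fVector(j)
}
// cosineVec: Array(1.0, 1.0, 2.0)

// Divide by each word's norm only; the query norm is a common factor.
val wordVecNorms = Array(1.0, 1.0, math.sqrt(2.0))
val cosVec = cosineVec.map(_.toDouble)
for (ind <- 0 until numWords) cosVec(ind) /= wordVecNorms(ind)
// Ranking by cosVec now matches ranking by true cosine similarity.
```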
```diff
@@ -548,6 +544,23 @@ class Word2VecModel private[spark] (
 @Experimental
 object Word2VecModel extends Loader[Word2VecModel] {
 
+  private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = {
+    model.keys.zipWithIndex.toMap
+  }
+
+  private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = {
+    require(model.nonEmpty, "Word2VecMap should be non-empty")
+    val (vectorSize, numWords) = (model.head._2.size, model.size)
+    val wordList = model.keys.toArray
+    val wordVectors = new Array[Float](vectorSize * numWords)
+    var i = 0
+    while (i < numWords) {
+      Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize)
+      i += 1
+    }
+    wordVectors
+  }
+
   private object SaveLoadV1_0 {
 
     val formatVersionV1_0 = "1.0"
```
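These helpers keep the old `Map[String, Array[Float]]` entry point working through the auxiliary constructor shown earlier. One subtlety worth noting: both enumerate `model.keys`, so the index that `zipWithIndex` assigns to a word must line up with the slot the copy loop writes, which holds as long as the same map instance iterates in the same order both times (the case for Scala's immutable maps in practice). A hypothetical round trip through the same logic:

```scala
// Toy map -> (wordIndex, wordVectors) -> vector recovered by slicing.
val model = Map("a" -> Array(0.1f, 0.2f), "b" -> Array(0.3f, 0.4f))

val wordIndex = model.keys.zipWithIndex.toMap      // as buildWordIndex does
val vectorSize = model.head._2.length

// Flatten exactly as buildWordVectors does.
val wordList = model.keys.toArray
val wordVectors = new Array[Float](vectorSize * model.size)
var i = 0
while (i < model.size) {
  Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize)
  i += 1
}

// Slicing at a word's index recovers its original vector.
val ind = wordIndex("b")
wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)  // Array(0.3, 0.4)
```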
A second review thread discusses the new `wordVectors.slice(...)` calls:

> Does this work if you call `wordVectors.view.slice(...)` instead? I think "view" will tell Scala not to physically create a copy of the slice.

> Are you sure? I think a copy of the slice will be produced anyway. It seems that `collection.view` avoids copying only the collection itself.
> Ref: http://stackoverflow.com/a/6799739/1170730

> It gives me a compilation error (because `Vectors.dense` accepts just an `Array`), so that also works in favor of not changing it :p

> You're right, for this one, we have to make a copy anyways.
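A minimal sketch of the point settled in the thread, using a toy array: `view.slice` defers the copy but yields a lazy wrapper rather than an `Array`, so anything that requires an `Array` (as `Vectors.dense` does) forces a copy in the end anyway.

```scala
val wordVectors = Array(1.0f, 2.0f, 3.0f, 4.0f)

// slice eagerly copies the elements into a fresh Array[Float].
val copied: Array[Float] = wordVectors.slice(0, 2)

// view.slice builds a lazy wrapper over the original array: no copy yet,
// but it is not an Array, so it cannot be passed where an Array is required.
val lazySlice = wordVectors.view.slice(0, 2)

// Materializing it back into an Array performs the copy after all.
val forced: Array[Float] = lazySlice.toArray
```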