[SPARK-5015] [mllib] Random seed for GMM + make test suite deterministic #3981

Status: Closed · wants to merge 1 commit
GaussianMixtureEM.scala:

@@ -24,6 +24,7 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.mllib.linalg.{Matrices, Vector, Vectors, DenseVector, DenseMatrix, BLAS}
 import org.apache.spark.mllib.stat.impl.MultivariateGaussian
 import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.util.Utils

 /**
  * This class performs expectation maximization for multivariate Gaussian
@@ -45,10 +46,11 @@ import org.apache.spark.mllib.util.MLUtils
 class GaussianMixtureEM private (
     private var k: Int,
     private var convergenceTol: Double,
-    private var maxIterations: Int) extends Serializable {
+    private var maxIterations: Int,
+    private var seed: Long) extends Serializable {

   /** A default instance, 2 Gaussians, 100 iterations, 0.01 log-likelihood threshold */
-  def this() = this(2, 0.01, 100)
+  def this() = this(2, 0.01, 100, Utils.random.nextLong())

   // number of samples per cluster to use when initializing Gaussians
   private val nSamples = 5
@@ -100,11 +102,21 @@ class GaussianMixtureEM private (
     this
   }

-  /** Return the largest change in log-likelihood at which convergence is
-   * considered to have occurred.
+  /**
+   * Return the largest change in log-likelihood at which convergence is
+   * considered to have occurred.
    */
   def getConvergenceTol: Double = convergenceTol

+  /** Set the random seed */
+  def setSeed(seed: Long): this.type = {
+    this.seed = seed
+    this
+  }
+
+  /** Return the random seed */
+  def getSeed: Long = seed
+
   /** Perform expectation maximization */
   def run(data: RDD[Vector]): GaussianMixtureModel = {
     val sc = data.sparkContext
@@ -113,7 +125,7 @@
     val breezeData = data.map(u => u.toBreeze.toDenseVector).cache()

     // Get length of the input vectors
-    val d = breezeData.first.length
+    val d = breezeData.first().length

     // Determine initial weights and corresponding Gaussians.
     // If the user supplied an initial GMM, we use those values, otherwise
@@ -126,7 +138,7 @@
     })

     case None => {
-      val samples = breezeData.takeSample(true, k * nSamples, scala.util.Random.nextInt)
+      val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed)
       (Array.fill(k)(1.0 / k), Array.tabulate(k) { i =>
         val slice = samples.view(i * nSamples, (i + 1) * nSamples)
         new MultivariateGaussian(vectorMean(slice), initCovariance(slice))
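The change above replaces the unseeded `scala.util.Random.nextInt` with an explicit, settable seed passed to `takeSample`, so initialization draws the same sample on every run with the same seed. A minimal Python sketch of the idea (not Spark code; `sample_with_replacement` is a hypothetical stand-in for `RDD.takeSample(withReplacement = true, num, seed)`):

```python
import random

def sample_with_replacement(data, n, seed):
    """Draw n elements with replacement using an explicitly seeded RNG.

    Hypothetical stand-in for Spark's RDD.takeSample(withReplacement=True, n, seed).
    """
    rng = random.Random(seed)
    return [data[rng.randrange(len(data))] for _ in range(n)]

data = list(range(100))
a = sample_with_replacement(data, 10, seed=42)
b = sample_with_replacement(data, 10, seed=42)
assert a == b  # same seed -> identical initialization samples
print("deterministic:", a == b)
```

With a per-instance seed, a test can fix the seed and get reproducible cluster initialization, while the default constructor still draws a fresh seed from `Utils.random` so ordinary runs remain randomized.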
GMMExpectationMaximizationSuite.scala:

@@ -35,12 +35,14 @@ class GMMExpectationMaximizationSuite extends FunSuite with MLlibTestSparkContext {
     val Ew = 1.0
     val Emu = Vectors.dense(5.0, 10.0)
     val Esigma = Matrices.dense(2, 2, Array(2.0 / 3.0, -2.0 / 3.0, -2.0 / 3.0, 2.0 / 3.0))

-    val gmm = new GaussianMixtureEM().setK(1).run(data)
-
-    assert(gmm.weight(0) ~== Ew absTol 1E-5)
-    assert(gmm.mu(0) ~== Emu absTol 1E-5)
-    assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+    val seeds = Array(314589, 29032897, 50181, 494821, 4660)
+    seeds.foreach { seed =>
+      val gmm = new GaussianMixtureEM().setK(1).setSeed(seed).run(data)
+      assert(gmm.weight(0) ~== Ew absTol 1E-5)
+      assert(gmm.mu(0) ~== Emu absTol 1E-5)
+      assert(gmm.sigma(0) ~== Esigma absTol 1E-5)
+    }
Contributor:

For the single cluster case, I don't think the seed can make a difference. With only one cluster, this just checks that the mean and covariance are computed correctly.

Member Author:

In general, you're correct. The issue is that, with some probability, the 5 data points chosen to initialize a cluster will be identical, causing initialization of the covariance matrix to fail.

I'd be fine with having this test use only 1 random seed instead of 5, but it's fast anyway.
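The failure mode described above is easy to see numerically: if all of the sampled points for a cluster are identical, their sample covariance is the zero matrix, which is singular, so the inverse needed by the multivariate Gaussian density does not exist. A small illustrative Python check (plain Python, not the MLlib code; `sample_covariance` is a hypothetical helper):

```python
def sample_covariance(points):
    """Biased sample covariance of a list of d-dimensional points."""
    n = len(points)
    d = len(points[0])
    mean = [sum(p[j] for p in points) / n for j in range(d)]
    return [[sum((p[i] - mean[i]) * (p[j] - mean[j]) for p in points) / n
             for j in range(d)] for i in range(d)]

# Five identical sample points, as can happen when nSamples is small:
identical = [[1.0, 2.0]] * 5
sigma = sample_covariance(identical)
# Every entry is zero: sigma is singular, so a Gaussian density cannot be
# initialized from this sample.
print(sigma)  # [[0.0, 0.0], [0.0, 0.0]]
```

Fixing the seeds in the test suite rules out the unlucky draws that trigger this, which is why the test iterates over a few known-good seeds.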

Contributor:

Got it!

}

test("two clusters") {