Skip to content

Commit 255b56f

Browse files
DB Tsaimengxr
authored andcommitted
[SPARK-2479][MLlib] Comparing floating-point numbers using relative error in UnitTests
Floating point math is not exact, and most floating-point numbers end up being slightly imprecise due to rounding errors. Simple values like 0.1 cannot be precisely represented using binary floating point numbers, and the limited precision of floating point numbers means that slight changes in the order of operations or the precision of intermediates can change the result. That means that comparing two floats to see if they are equal is usually not what we want. As long as this imprecision stays small, it can usually be ignored. Based on discussion in the community, we have implemented two different APIs for relative tolerance, and absolute tolerance. It makes sense that test writers should know which one they need depending on their circumstances. Developers also need to explicitly specify the eps, and there is no default value which will sometimes cause confusion. When comparing against zero using relative tolerance, a exception will be raised to warn users that it's meaningless. For relative tolerance, users can now write assert(23.1 ~== 23.52 relTol 0.02) assert(23.1 ~== 22.74 relTol 0.02) assert(23.1 ~= 23.52 relTol 0.02) assert(23.1 ~= 22.74 relTol 0.02) assert(!(23.1 !~= 23.52 relTol 0.02)) assert(!(23.1 !~= 22.74 relTol 0.02)) // This will throw exception with the following message. // "Did not expect 23.1 and 23.52 to be within 0.02 using relative tolerance." assert(23.1 !~== 23.52 relTol 0.02) // "Expected 23.1 and 22.34 to be within 0.02 using relative tolerance." assert(23.1 ~== 22.34 relTol 0.02) For absolute error, assert(17.8 ~== 17.99 absTol 0.2) assert(17.8 ~== 17.61 absTol 0.2) assert(17.8 ~= 17.99 absTol 0.2) assert(17.8 ~= 17.61 absTol 0.2) assert(!(17.8 !~= 17.99 absTol 0.2)) assert(!(17.8 !~= 17.61 absTol 0.2)) // This will throw exception with the following message. // "Did not expect 17.8 and 17.99 to be within 0.2 using absolute error." assert(17.8 !~== 17.99 absTol 0.2) // "Expected 17.8 and 17.59 to be within 0.2 using absolute error." assert(17.8 ~== 17.59 absTol 0.2) Authors: DB Tsai <dbtsaialpinenow.com> Marek Kolodziej <marekalpinenow.com> Author: DB Tsai <[email protected]> Closes #1425 from dbtsai/SPARK-2479_comparing_floating_point and squashes the following commits: 8c7cbcc [DB Tsai] Alpine Data Labs
1 parent 2b8d89e commit 255b56f

File tree

10 files changed

+438
-130
lines changed

10 files changed

+438
-130
lines changed

mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ import org.scalatest.Matchers
2626
import org.apache.spark.mllib.linalg.Vectors
2727
import org.apache.spark.mllib.regression._
2828
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
29+
import org.apache.spark.mllib.util.TestingUtils._
2930

3031
object LogisticRegressionSuite {
3132

@@ -81,9 +82,8 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
8182
val model = lr.run(testRDD)
8283

8384
// Test the weights
84-
val weight0 = model.weights(0)
85-
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
86-
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
85+
assert(model.weights(0) ~== -1.52 relTol 0.01)
86+
assert(model.intercept ~== 2.00 relTol 0.01)
8787

8888
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
8989
val validationRDD = sc.parallelize(validationData, 2)
@@ -113,9 +113,9 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Match
113113

114114
val model = lr.run(testRDD, initialWeights)
115115

116-
val weight0 = model.weights(0)
117-
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
118-
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
116+
// Test the weights
117+
assert(model.weights(0) ~== -1.50 relTol 0.01)
118+
assert(model.intercept ~== 1.97 relTol 0.01)
119119

120120
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
121121
val validationRDD = sc.parallelize(validationData, 2)

mllib/src/test/scala/org/apache/spark/mllib/clustering/KMeansSuite.scala

Lines changed: 38 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,9 @@ import scala.util.Random
2121

2222
import org.scalatest.FunSuite
2323

24-
import org.apache.spark.mllib.linalg.Vectors
24+
import org.apache.spark.mllib.linalg.{Vector, Vectors}
2525
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
26+
import org.apache.spark.mllib.util.TestingUtils._
2627

2728
class KMeansSuite extends FunSuite with LocalSparkContext {
2829

@@ -41,26 +42,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
4142
// centered at the mean of the points
4243

4344
var model = KMeans.train(data, k = 1, maxIterations = 1)
44-
assert(model.clusterCenters.head === center)
45+
assert(model.clusterCenters.head ~== center absTol 1E-5)
4546

4647
model = KMeans.train(data, k = 1, maxIterations = 2)
47-
assert(model.clusterCenters.head === center)
48+
assert(model.clusterCenters.head ~== center absTol 1E-5)
4849

4950
model = KMeans.train(data, k = 1, maxIterations = 5)
50-
assert(model.clusterCenters.head === center)
51+
assert(model.clusterCenters.head ~== center absTol 1E-5)
5152

5253
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
53-
assert(model.clusterCenters.head === center)
54+
assert(model.clusterCenters.head ~== center absTol 1E-5)
5455

5556
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
56-
assert(model.clusterCenters.head === center)
57+
assert(model.clusterCenters.head ~== center absTol 1E-5)
5758

5859
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
59-
assert(model.clusterCenters.head === center)
60+
assert(model.clusterCenters.head ~== center absTol 1E-5)
6061

6162
model = KMeans.train(
6263
data, k = 1, maxIterations = 1, runs = 1, initializationMode = K_MEANS_PARALLEL)
63-
assert(model.clusterCenters.head === center)
64+
assert(model.clusterCenters.head ~== center absTol 1E-5)
6465
}
6566

6667
test("no distinct points") {
@@ -104,26 +105,26 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
104105

105106
var model = KMeans.train(data, k = 1, maxIterations = 1)
106107
assert(model.clusterCenters.size === 1)
107-
assert(model.clusterCenters.head === center)
108+
assert(model.clusterCenters.head ~== center absTol 1E-5)
108109

109110
model = KMeans.train(data, k = 1, maxIterations = 2)
110-
assert(model.clusterCenters.head === center)
111+
assert(model.clusterCenters.head ~== center absTol 1E-5)
111112

112113
model = KMeans.train(data, k = 1, maxIterations = 5)
113-
assert(model.clusterCenters.head === center)
114+
assert(model.clusterCenters.head ~== center absTol 1E-5)
114115

115116
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
116-
assert(model.clusterCenters.head === center)
117+
assert(model.clusterCenters.head ~== center absTol 1E-5)
117118

118119
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
119-
assert(model.clusterCenters.head === center)
120+
assert(model.clusterCenters.head ~== center absTol 1E-5)
120121

121122
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
122-
assert(model.clusterCenters.head === center)
123+
assert(model.clusterCenters.head ~== center absTol 1E-5)
123124

124125
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
125126
initializationMode = K_MEANS_PARALLEL)
126-
assert(model.clusterCenters.head === center)
127+
assert(model.clusterCenters.head ~== center absTol 1E-5)
127128
}
128129

129130
test("single cluster with sparse data") {
@@ -149,31 +150,39 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
149150
val center = Vectors.sparse(n, Seq((0, 1.0), (1, 3.0), (2, 4.0)))
150151

151152
var model = KMeans.train(data, k = 1, maxIterations = 1)
152-
assert(model.clusterCenters.head === center)
153+
assert(model.clusterCenters.head ~== center absTol 1E-5)
153154

154155
model = KMeans.train(data, k = 1, maxIterations = 2)
155-
assert(model.clusterCenters.head === center)
156+
assert(model.clusterCenters.head ~== center absTol 1E-5)
156157

157158
model = KMeans.train(data, k = 1, maxIterations = 5)
158-
assert(model.clusterCenters.head === center)
159+
assert(model.clusterCenters.head ~== center absTol 1E-5)
159160

160161
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
161-
assert(model.clusterCenters.head === center)
162+
assert(model.clusterCenters.head ~== center absTol 1E-5)
162163

163164
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 5)
164-
assert(model.clusterCenters.head === center)
165+
assert(model.clusterCenters.head ~== center absTol 1E-5)
165166

166167
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1, initializationMode = RANDOM)
167-
assert(model.clusterCenters.head === center)
168+
assert(model.clusterCenters.head ~== center absTol 1E-5)
168169

169170
model = KMeans.train(data, k = 1, maxIterations = 1, runs = 1,
170171
initializationMode = K_MEANS_PARALLEL)
171-
assert(model.clusterCenters.head === center)
172+
assert(model.clusterCenters.head ~== center absTol 1E-5)
172173

173174
data.unpersist()
174175
}
175176

176177
test("k-means|| initialization") {
178+
179+
case class VectorWithCompare(x: Vector) extends Ordered[VectorWithCompare] {
180+
@Override def compare(that: VectorWithCompare): Int = {
181+
if(this.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x) >
182+
that.x.toArray.foldLeft[Double](0.0)((acc, x) => acc + x * x)) -1 else 1
183+
}
184+
}
185+
177186
val points = Seq(
178187
Vectors.dense(1.0, 2.0, 6.0),
179188
Vectors.dense(1.0, 3.0, 0.0),
@@ -188,15 +197,19 @@ class KMeansSuite extends FunSuite with LocalSparkContext {
188197
// unselected point as long as it hasn't yet selected all of them
189198

190199
var model = KMeans.train(rdd, k = 5, maxIterations = 1)
191-
assert(Set(model.clusterCenters: _*) === Set(points: _*))
200+
201+
assert(model.clusterCenters.sortBy(VectorWithCompare(_))
202+
.zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5))
192203

193204
// Iterations of Lloyd's should not change the answer either
194205
model = KMeans.train(rdd, k = 5, maxIterations = 10)
195-
assert(Set(model.clusterCenters: _*) === Set(points: _*))
206+
assert(model.clusterCenters.sortBy(VectorWithCompare(_))
207+
.zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5))
196208

197209
// Neither should more runs
198210
model = KMeans.train(rdd, k = 5, maxIterations = 10, runs = 5)
199-
assert(Set(model.clusterCenters: _*) === Set(points: _*))
211+
assert(model.clusterCenters.sortBy(VectorWithCompare(_))
212+
.zip(points.sortBy(VectorWithCompare(_))).forall(x => x._1 ~== (x._2) absTol 1E-5))
200213
}
201214

202215
test("two clusters") {

mllib/src/test/scala/org/apache/spark/mllib/evaluation/AreaUnderCurveSuite.scala

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,27 +20,28 @@ package org.apache.spark.mllib.evaluation
2020
import org.scalatest.FunSuite
2121

2222
import org.apache.spark.mllib.util.LocalSparkContext
23+
import org.apache.spark.mllib.util.TestingUtils._
2324

2425
class AreaUnderCurveSuite extends FunSuite with LocalSparkContext {
2526
test("auc computation") {
2627
val curve = Seq((0.0, 0.0), (1.0, 1.0), (2.0, 3.0), (3.0, 0.0))
2728
val auc = 4.0
28-
assert(AreaUnderCurve.of(curve) === auc)
29+
assert(AreaUnderCurve.of(curve) ~== auc absTol 1E-5)
2930
val rddCurve = sc.parallelize(curve, 2)
30-
assert(AreaUnderCurve.of(rddCurve) == auc)
31+
assert(AreaUnderCurve.of(rddCurve) ~== auc absTol 1E-5)
3132
}
3233

3334
test("auc of an empty curve") {
3435
val curve = Seq.empty[(Double, Double)]
35-
assert(AreaUnderCurve.of(curve) === 0.0)
36+
assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5)
3637
val rddCurve = sc.parallelize(curve, 2)
37-
assert(AreaUnderCurve.of(rddCurve) === 0.0)
38+
assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5)
3839
}
3940

4041
test("auc of a curve with a single point") {
4142
val curve = Seq((1.0, 1.0))
42-
assert(AreaUnderCurve.of(curve) === 0.0)
43+
assert(AreaUnderCurve.of(curve) ~== 0.0 absTol 1E-5)
4344
val rddCurve = sc.parallelize(curve, 2)
44-
assert(AreaUnderCurve.of(rddCurve) === 0.0)
45+
assert(AreaUnderCurve.of(rddCurve) ~== 0.0 absTol 1E-5)
4546
}
4647
}

mllib/src/test/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetricsSuite.scala

Lines changed: 15 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -20,25 +20,14 @@ package org.apache.spark.mllib.evaluation
2020
import org.scalatest.FunSuite
2121

2222
import org.apache.spark.mllib.util.LocalSparkContext
23-
import org.apache.spark.mllib.util.TestingUtils.DoubleWithAlmostEquals
23+
import org.apache.spark.mllib.util.TestingUtils._
2424

2525
class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
2626

27-
// TODO: move utility functions to TestingUtils.
27+
def cond1(x: (Double, Double)): Boolean = x._1 ~= (x._2) absTol 1E-5
2828

29-
def elementsAlmostEqual(actual: Seq[Double], expected: Seq[Double]): Boolean = {
30-
actual.zip(expected).forall { case (x1, x2) =>
31-
x1.almostEquals(x2)
32-
}
33-
}
34-
35-
def elementsAlmostEqual(
36-
actual: Seq[(Double, Double)],
37-
expected: Seq[(Double, Double)])(implicit dummy: DummyImplicit): Boolean = {
38-
actual.zip(expected).forall { case ((x1, y1), (x2, y2)) =>
39-
x1.almostEquals(x2) && y1.almostEquals(y2)
40-
}
41-
}
29+
def cond2(x: ((Double, Double), (Double, Double))): Boolean =
30+
(x._1._1 ~= x._2._1 absTol 1E-5) && (x._1._2 ~= x._2._2 absTol 1E-5)
4231

4332
test("binary evaluation metrics") {
4433
val scoreAndLabels = sc.parallelize(
@@ -57,16 +46,17 @@ class BinaryClassificationMetricsSuite extends FunSuite with LocalSparkContext {
5746
val rocCurve = Seq((0.0, 0.0)) ++ fpr.zip(recall) ++ Seq((1.0, 1.0))
5847
val pr = recall.zip(precision)
5948
val prCurve = Seq((0.0, 1.0)) ++ pr
60-
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r) }
49+
val f1 = pr.map { case (r, p) => 2.0 * (p * r) / (p + r)}
6150
val f2 = pr.map { case (r, p) => 5.0 * (p * r) / (4.0 * p + r)}
62-
assert(elementsAlmostEqual(metrics.thresholds().collect(), threshold))
63-
assert(elementsAlmostEqual(metrics.roc().collect(), rocCurve))
64-
assert(metrics.areaUnderROC().almostEquals(AreaUnderCurve.of(rocCurve)))
65-
assert(elementsAlmostEqual(metrics.pr().collect(), prCurve))
66-
assert(metrics.areaUnderPR().almostEquals(AreaUnderCurve.of(prCurve)))
67-
assert(elementsAlmostEqual(metrics.fMeasureByThreshold().collect(), threshold.zip(f1)))
68-
assert(elementsAlmostEqual(metrics.fMeasureByThreshold(2.0).collect(), threshold.zip(f2)))
69-
assert(elementsAlmostEqual(metrics.precisionByThreshold().collect(), threshold.zip(precision)))
70-
assert(elementsAlmostEqual(metrics.recallByThreshold().collect(), threshold.zip(recall)))
51+
52+
assert(metrics.thresholds().collect().zip(threshold).forall(cond1))
53+
assert(metrics.roc().collect().zip(rocCurve).forall(cond2))
54+
assert(metrics.areaUnderROC() ~== AreaUnderCurve.of(rocCurve) absTol 1E-5)
55+
assert(metrics.pr().collect().zip(prCurve).forall(cond2))
56+
assert(metrics.areaUnderPR() ~== AreaUnderCurve.of(prCurve) absTol 1E-5)
57+
assert(metrics.fMeasureByThreshold().collect().zip(threshold.zip(f1)).forall(cond2))
58+
assert(metrics.fMeasureByThreshold(2.0).collect().zip(threshold.zip(f2)).forall(cond2))
59+
assert(metrics.precisionByThreshold().collect().zip(threshold.zip(precision)).forall(cond2))
60+
assert(metrics.recallByThreshold().collect().zip(threshold.zip(recall)).forall(cond2))
7161
}
7262
}

mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.scalatest.{FunSuite, Matchers}
2525
import org.apache.spark.mllib.linalg.Vectors
2626
import org.apache.spark.mllib.regression._
2727
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
28+
import org.apache.spark.mllib.util.TestingUtils._
2829

2930
object GradientDescentSuite {
3031

@@ -126,19 +127,14 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with Matchers
126127
val (newWeights1, loss1) = GradientDescent.runMiniBatchSGD(
127128
dataRDD, gradient, updater, 1, 1, regParam1, 1.0, initialWeightsWithIntercept)
128129

129-
def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = {
130-
math.abs(x - y) / (math.abs(y) + 1e-15) < tol
131-
}
132-
133-
assert(compareDouble(
134-
loss1(0),
135-
loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) +
136-
math.pow(initialWeightsWithIntercept(1), 2)) / 2),
130+
assert(
131+
loss1(0) ~= (loss0(0) + (math.pow(initialWeightsWithIntercept(0), 2) +
132+
math.pow(initialWeightsWithIntercept(1), 2)) / 2) absTol 1E-5,
137133
"""For non-zero weights, the regVal should be \frac{1}{2}\sum_i w_i^2.""")
138134

139135
assert(
140-
compareDouble(newWeights1(0) , newWeights0(0) - initialWeightsWithIntercept(0)) &&
141-
compareDouble(newWeights1(1) , newWeights0(1) - initialWeightsWithIntercept(1)),
136+
(newWeights1(0) ~= (newWeights0(0) - initialWeightsWithIntercept(0)) absTol 1E-5) &&
137+
(newWeights1(1) ~= (newWeights0(1) - initialWeightsWithIntercept(1)) absTol 1E-5),
142138
"The different between newWeights with/without regularization " +
143139
"should be initialWeightsWithIntercept.")
144140
}

mllib/src/test/scala/org/apache/spark/mllib/optimization/LBFGSSuite.scala

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import org.scalatest.{FunSuite, Matchers}
2424
import org.apache.spark.mllib.linalg.Vectors
2525
import org.apache.spark.mllib.regression.LabeledPoint
2626
import org.apache.spark.mllib.util.{LocalClusterSparkContext, LocalSparkContext}
27+
import org.apache.spark.mllib.util.TestingUtils._
2728

2829
class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
2930

@@ -49,10 +50,6 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
4950

5051
lazy val dataRDD = sc.parallelize(data, 2).cache()
5152

52-
def compareDouble(x: Double, y: Double, tol: Double = 1E-3): Boolean = {
53-
math.abs(x - y) / (math.abs(y) + 1e-15) < tol
54-
}
55-
5653
test("LBFGS loss should be decreasing and match the result of Gradient Descent.") {
5754
val regParam = 0
5855

@@ -126,15 +123,15 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
126123
miniBatchFrac,
127124
initialWeightsWithIntercept)
128125

129-
assert(compareDouble(lossGD(0), lossLBFGS(0)),
126+
assert(lossGD(0) ~= lossLBFGS(0) absTol 1E-5,
130127
"The first losses of LBFGS and GD should be the same.")
131128

132129
// The 2% difference here is based on observation, but is not theoretically guaranteed.
133-
assert(compareDouble(lossGD.last, lossLBFGS.last, 0.02),
130+
assert(lossGD.last ~= lossLBFGS.last relTol 0.02,
134131
"The last losses of LBFGS and GD should be within 2% difference.")
135132

136-
assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) &&
137-
compareDouble(weightLBFGS(1), weightGD(1), 0.02),
133+
assert(
134+
(weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02),
138135
"The weight differences between LBFGS and GD should be within 2%.")
139136
}
140137

@@ -226,8 +223,8 @@ class LBFGSSuite extends FunSuite with LocalSparkContext with Matchers {
226223
initialWeightsWithIntercept)
227224

228225
// for class LBFGS and the optimize method, we only look at the weights
229-
assert(compareDouble(weightLBFGS(0), weightGD(0), 0.02) &&
230-
compareDouble(weightLBFGS(1), weightGD(1), 0.02),
226+
assert(
227+
(weightLBFGS(0) ~= weightGD(0) relTol 0.02) && (weightLBFGS(1) ~= weightGD(1) relTol 0.02),
231228
"The weight differences between LBFGS and GD should be within 2%.")
232229
}
233230
}

mllib/src/test/scala/org/apache/spark/mllib/optimization/NNLSSuite.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ import scala.util.Random
2121

2222
import org.scalatest.FunSuite
2323

24-
import org.jblas.{DoubleMatrix, SimpleBlas, NativeBlas}
24+
import org.jblas.{DoubleMatrix, SimpleBlas}
25+
26+
import org.apache.spark.mllib.util.TestingUtils._
2527

2628
class NNLSSuite extends FunSuite {
2729
/** Generate an NNLS problem whose optimal solution is the all-ones vector. */
@@ -73,7 +75,7 @@ class NNLSSuite extends FunSuite {
7375
val ws = NNLS.createWorkspace(n)
7476
val x = NNLS.solve(ata, atb, ws)
7577
for (i <- 0 until n) {
76-
assert(Math.abs(x(i) - goodx(i)) < 1e-3)
78+
assert(x(i) ~== goodx(i) absTol 1E-3)
7779
assert(x(i) >= 0)
7880
}
7981
}

0 commit comments

Comments
 (0)