
Commit 1ff17c2

Make the seed random for HasSeed in python

1 parent bec938f commit 1ff17c2

File tree

5 files changed: +44 -24 lines

python/pyspark/ml/feature.py

Lines changed: 3 additions & 3 deletions

@@ -790,7 +790,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
 
     @keyword_only
     def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
-                 seed=42, inputCol=None, outputCol=None):
+                 seed=None, inputCol=None, outputCol=None):
         """
         __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
                  seed=42, inputCol=None, outputCol=None)
@@ -810,9 +810,9 @@ def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025,
 
     @keyword_only
     def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
-                  seed=42, inputCol=None, outputCol=None):
+                  seed=None, inputCol=None, outputCol=None):
         """
-        setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=42,
+        setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None,
                   inputCol=None, outputCol=None)
         Sets params for this Word2Vec.
         """

python/pyspark/ml/param/__init__.py

Lines changed: 2 additions & 1 deletion

@@ -123,7 +123,7 @@ def isDefined(self, param):
     def getOrDefault(self, param):
         """
         Gets the value of a param in the user-supplied param map or its
-        default value. Raises an error if either is set.
+        default value. Raises an error if neither is set.
         """
         if isinstance(param, Param):
             if param in self.paramMap:
@@ -135,6 +135,7 @@ def getOrDefault(self, param):
         else:
             raise KeyError("Cannot recognize %r as a param." % param)
 
+
     def extractParamMap(self, extraParamMap={}):
         """
         Extracts the embedded default param values and user-supplied
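The corrected docstring matches the lookup order the method implements: the user-supplied value wins, then the declared default, and only when neither exists does it raise. An illustrative sketch using the TestParams helper defined in the tests below:

    tp = TestParams()              # declares a default for maxIter, none for inputCol
    tp.getOrDefault(tp.maxIter)    # -> 10, falls back to the declared default
    tp.setMaxIter(100)
    tp.getOrDefault(tp.maxIter)    # -> 100, the user-supplied value wins
    tp.getOrDefault(tp.inputCol)   # raises KeyError: neither set nor defaulted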

python/pyspark/ml/param/_shared_params_code_gen.py

Lines changed: 7 additions & 4 deletions

@@ -57,8 +57,9 @@ def __init__(self):
         super(Has$Name, self).__init__()
         #: param for $doc
         self.$name = Param(self, "$name", "$doc")
-        if $defaultValueStr is not None:
-            self._setDefault($name=$defaultValueStr)'''
+        x = $defaultValueStr
+        if x is not None:
+            self._setDefault($name=x)'''
 
     Name = name[0].upper() + name[1:]
     return template \
@@ -102,7 +103,9 @@ def get$Name(self):
 if __name__ == "__main__":
     print(header)
     print("\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n")
-    print("from pyspark.ml.param import Param, Params\n\n")
+    print("from pyspark.ml.param import Param, Params\n")
+    print("import random\n")
+    print("import sys\n\n")
     shared = [
         ("maxIter", "max number of iterations (>= 0)", None),
         ("regParam", "regularization parameter (>= 0)", None),
@@ -115,7 +118,7 @@ def get$Name(self):
         ("outputCol", "output column name", None),
         ("numFeatures", "number of features", None),
         ("checkpointInterval", "checkpoint interval (>= 1)", None),
-        ("seed", "random seed", None),
+        ("seed", "random seed", "random.randint(0, sys.maxsize)"),
         ("tol", "the convergence tolerance for iterative algorithms", None),
         ("stepSize", "Step size to be used for each iteration of optimization.", None)]
     code = []
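Substituting the new default string for the seed entry, the template above should expand to roughly this (a sketch of the expected generator output, not text copied from the commit):

    class HasSeed(Params):
        """
        Mixin for param seed: random seed.
        """

        def __init__(self):
            super(HasSeed, self).__init__()
            #: param for random seed
            self.seed = Param(self, "seed", "random seed")
            # $defaultValueStr is evaluated once here, so the instance keeps
            # a single random default rather than re-rolling on every access.
            x = random.randint(0, sys.maxsize)
            if x is not None:
                self._setDefault(seed=x)

Note that the checked-in shared.py below still shows the default expression inlined at both sites, i.e. it was generated before the `x = $defaultValueStr` tweak to the template.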

python/pyspark/ml/param/shared.py

Lines changed: 8 additions & 9 deletions

@@ -19,6 +19,10 @@
 
 from pyspark.ml.param import Param, Params
 
+import random
+
+import sys
+
 
 class HasMaxIter(Params):
     """
@@ -351,8 +355,8 @@ def __init__(self):
         super(HasSeed, self).__init__()
         #: param for random seed
         self.seed = Param(self, "seed", "random seed")
-        if None is not None:
-            self._setDefault(seed=None)
+        if random.randint(0, sys.maxsize) is not None:
+            self._setDefault(seed=random.randint(0, sys.maxsize))
 
     def setSeed(self, value):
         """
@@ -438,22 +442,17 @@ class DecisionTreeParams(Params):
     minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.")
     maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
     cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")
+
 
     def __init__(self):
         super(DecisionTreeParams, self).__init__()
-        #: param for Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
         self.maxDepth = Param(self, "maxDepth", "Maximum depth of the tree. (>= 0) E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.")
-        #: param for Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.
         self.maxBins = Param(self, "maxBins", "Max number of bins for discretizing continuous features. Must be >=2 and >= number of categories for any categorical feature.")
-        #: param for Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.
         self.minInstancesPerNode = Param(self, "minInstancesPerNode", "Minimum number of instances each child must have after split. If a split causes the left or right child to have fewer than minInstancesPerNode, the split will be discarded as invalid. Should be >= 1.")
-        #: param for Minimum information gain for a split to be considered at a tree node.
         self.minInfoGain = Param(self, "minInfoGain", "Minimum information gain for a split to be considered at a tree node.")
-        #: param for Maximum memory in MB allocated to histogram aggregation.
         self.maxMemoryInMB = Param(self, "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
-        #: param for If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.
         self.cacheNodeIds = Param(self, "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")
-
+
     def setMaxDepth(self, value):
         """
         Sets the value of :py:attr:`maxDepth`.
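As generated, the guard `random.randint(0, sys.maxsize) is not None` is always true, and the stored default comes from the second randint call; either way, each instance ends up with one random default fixed at construction time. A small sanity sketch (a hypothetical check, not from the commit):

    from pyspark.ml.param.shared import HasSeed

    a = HasSeed()
    b = HasSeed()
    # Each instance drew its own default; a collision between two draws over
    # [0, sys.maxsize] is possible but vanishingly unlikely.
    print(a.getSeed(), b.getSeed())
    print(a.getSeed() == a.getSeed())  # True: the default is stable per instance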

python/pyspark/ml/tests.py

Lines changed: 24 additions & 7 deletions

@@ -33,7 +33,7 @@
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
 from pyspark.sql import DataFrame
 from pyspark.ml.param import Param
-from pyspark.ml.param.shared import HasMaxIter, HasInputCol
+from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
 from pyspark.ml.pipeline import Estimator, Model, Pipeline, Transformer
 
 
@@ -112,15 +112,15 @@ def test_pipeline(self):
         self.assertEqual(6, dataset.index)
 
 
-class TestParams(HasMaxIter, HasInputCol):
+class TestParams(HasMaxIter, HasInputCol, HasSeed):
     """
-    A subclass of Params mixed with HasMaxIter and HasInputCol.
+    A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
     """
 
-    def __init__(self):
+    def __init__(self, seed=None):
         super(TestParams, self).__init__()
         self._setDefault(maxIter=10)
-
+        self._set(seed=seed)
 
 class ParamTests(PySparkTestCase):
 
@@ -135,9 +135,10 @@ def test_params(self):
         testParams = TestParams()
         maxIter = testParams.maxIter
         inputCol = testParams.inputCol
+        seed = testParams.seed
 
         params = testParams.params
-        self.assertEqual(params, [inputCol, maxIter])
+        self.assertEqual(params, [inputCol, maxIter, seed])
 
         self.assertTrue(testParams.hasDefault(maxIter))
         self.assertFalse(testParams.isSet(maxIter))
@@ -153,10 +154,26 @@ def test_params(self):
         with self.assertRaises(KeyError):
             testParams.getInputCol()
 
+        # Since the default is normally random, set it to a known number for debug str
+        testParams._setDefault(seed=41)
+        testParams.setSeed(43)
+
         self.assertEquals(
             testParams.explainParams(),
             "\n".join(["inputCol: input column name (undefined)",
-                       "maxIter: max number of iterations (>= 0) (default: 10, current: 100)"]))
+                       "maxIter: max number of iterations (>= 0) (default: 10, current: 100)",
+                       "seed: random seed (default: 41, current: 43)"]))
+
+    def test_hasseed(self):
+        noSeedSpecd = TestParams()
+        withSeedSpecd = TestParams(seed=42)
+        # Check that we no longer use 42 as the magic number
+        self.assertNotEqual(noSeedSpecd.getSeed(), 42)
+        origSeed = noSeedSpecd.getSeed()
+        # Check that we only compute the seed once
+        self.assertEqual(noSeedSpecd.getSeed(), origSeed)
+        # Check that a specified seed is honored
+        self.assertEqual(withSeedSpecd.getSeed(), 42)
 
 
 if __name__ == "__main__":
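The `_setDefault(seed=41)` call in test_params pins the otherwise-random default so `explainParams()` has a deterministic string to compare against; test_hasseed then checks the new contract directly. The same pinning pattern in isolation (expected output shown as comments, assuming a fresh TestParams):

    tp = TestParams()
    tp._setDefault(seed=41)  # pin the random default to get a stable debug string
    tp.setSeed(43)
    print(tp.explainParams())
    # inputCol: input column name (undefined)
    # maxIter: max number of iterations (>= 0) (default: 10)
    # seed: random seed (default: 41, current: 43)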
