@@ -267,6 +267,8 @@ def __init__(self,
                num_lattices=None,
                lattice_rank=None,
                interpolation='hypercube',
+               parameterization='all_vertices',
+               num_terms=2,
                separate_calibrators=True,
                use_linear_combination=False,
                use_bias=False,
@@ -305,6 +307,34 @@ def __init__(self,
         'simplex' uses d+1 parameters and thus scales better. For details see
         `tfl.lattice_lib.evaluate_with_simplex_interpolation` and
         `tfl.lattice_lib.evaluate_with_hypercube_interpolation`.
+      parameterization: The parameterization of the lattice function class to
+        use. A lattice function is uniquely determined by specifying its value
+        on every lattice vertex. A parameterization scheme is a mapping from a
+        vector of parameters to a multidimensional array of lattice vertex
+        values. It can be one of:
+          - String `'all_vertices'`: This is the "traditional" parameterization
+            that keeps one scalar parameter per lattice vertex where the mapping
+            is essentially the identity map. With this scheme, the number of
+            parameters scales exponentially with the number of inputs to the
+            lattice. The underlying lattices used will be `tfl.layers.Lattice`
+            layers.
+          - String `'kronecker_factored'`: With this parameterization, for each
+            lattice input i we keep a collection of `num_terms` vectors each
+            having `feature_configs[0].lattice_size` entries (note that all
+            features must have the same lattice size). To obtain the tensor of
+            lattice vertex values, for `t=1,2,...,num_terms` we compute the
+            outer product of the `t'th` vector in each collection, multiply by a
+            per-term scale, and sum the resulting tensors. Finally, we add a
+            single shared bias parameter to each entry in the sum. With this
+            scheme, the number of parameters grows linearly with `lattice_rank`
+            (assuming lattice sizes and `num_terms` are held constant).
+            Currently, only monotonicity shape constraint and bound constraint
+            are supported for this scheme. Regularization is not currently
+            supported. The underlying lattices used will be
+            `tfl.layers.KroneckerFactoredLattice` layers.
+      num_terms: The number of terms in a lattice using `'kronecker_factored'`
+        parameterization. Ignored if parameterization is set to
+        `'all_vertices'`.
       separate_calibrators: If features should be separately calibrated for each
         lattice in the ensemble.
       use_linear_combination: If set to true, a linear combination layer will be
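For reference, the `'kronecker_factored'` reconstruction described in the hunk above can be sketched directly in NumPy: for each lattice input keep `num_terms` vectors, take the outer product of the t-th vectors across inputs, scale each term, sum the terms, and add one shared bias. This is only an illustrative sketch of that math, not the `tfl.layers.KroneckerFactoredLattice` implementation; the helper name, parameter shapes, and random values below are assumptions for illustration.

```python
import numpy as np


def kronecker_factored_vertices(factors, scales, bias):
  """Rebuilds the full tensor of lattice vertex values from factored params.

  Args:
    factors: list of length lattice_rank; factors[i] has shape
      (num_terms, lattice_size) holding the vectors kept for input i.
    scales: array of shape (num_terms,), one scale per term.
    bias: scalar shared bias added to every vertex value.

  Returns:
    Array of shape (lattice_size,) * lattice_rank, one value per vertex.
  """
  num_terms = factors[0].shape[0]
  vertices = 0.0
  for t in range(num_terms):
    # Outer product of the t-th vector of every input, scaled per term.
    term = np.ones(())
    for factor in factors:
      term = np.multiply.outer(term, factor[t])
    vertices = vertices + scales[t] * term
  return vertices + bias


# Example: 3 inputs, each with lattice size 2 and num_terms=2 vectors.
rng = np.random.default_rng(0)
factors = [rng.normal(size=(2, 2)) for _ in range(3)]
vertices = kronecker_factored_vertices(
    factors, scales=np.array([1.0, 0.5]), bias=0.1)
print(vertices.shape)  # (2, 2, 2)
```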
@@ -375,12 +405,15 @@ class CalibratedLatticeConfig(_Config, _HasFeatureConfigs,
   def __init__(self,
                feature_configs=None,
                interpolation='hypercube',
+               parameterization='all_vertices',
+               num_terms=2,
                regularizer_configs=None,
                output_min=None,
                output_max=None,
                output_calibration=False,
                output_calibration_num_keypoints=10,
-               output_initialization='quantiles'):
+               output_initialization='quantiles',
+               random_seed=0):
     """Initializes a `CalibratedLatticeConfig` instance.

     Args:
@@ -392,6 +425,34 @@ def __init__(self,
         'simplex' uses d+1 parameters and thus scales better. For details see
         `tfl.lattice_lib.evaluate_with_simplex_interpolation` and
         `tfl.lattice_lib.evaluate_with_hypercube_interpolation`.
+      parameterization: The parameterization of the lattice function class to
+        use. A lattice function is uniquely determined by specifying its value
+        on every lattice vertex. A parameterization scheme is a mapping from a
+        vector of parameters to a multidimensional array of lattice vertex
+        values. It can be one of:
+          - String `'all_vertices'`: This is the "traditional" parameterization
+            that keeps one scalar parameter per lattice vertex where the mapping
+            is essentially the identity map. With this scheme, the number of
+            parameters scales exponentially with the number of inputs to the
+            lattice. The underlying lattice used will be a `tfl.layers.Lattice`
+            layer.
+          - String `'kronecker_factored'`: With this parameterization, for each
+            lattice input i we keep a collection of `num_terms` vectors each
+            having `feature_configs[0].lattice_size` entries (note that all
+            features must have the same lattice size). To obtain the tensor of
+            lattice vertex values, for `t=1,2,...,num_terms` we compute the
+            outer product of the `t'th` vector in each collection, multiply by a
+            per-term scale, and sum the resulting tensors. Finally, we add a
+            single shared bias parameter to each entry in the sum. With this
+            scheme, the number of parameters grows linearly with
+            `len(feature_configs)` (assuming lattice sizes and `num_terms` are
+            held constant). Currently, only monotonicity shape constraint and
+            bound constraint are supported for this scheme. Regularization is
+            not currently supported. The underlying lattice used will be a
+            `tfl.layers.KroneckerFactoredLattice` layer.
+      num_terms: The number of terms in a lattice using `'kronecker_factored'`
+        parameterization. Ignored if parameterization is set to
+        `'all_vertices'`.
       regularizer_configs: A list of `tfl.configs.RegularizerConfig` instances
         that apply global regularization.
       output_min: Lower bound constraint on the output of the model.
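To make the scaling claim in the hunk above concrete: with 10 features, a shared `lattice_size` of 2, and `num_terms=2`, the `'all_vertices'` scheme stores 2^10 = 1024 vertex values, while `'kronecker_factored'` stores 10 * 2 * 2 + 2 + 1 = 43 parameters (counting the factor vectors, the per-term scales, and the shared bias described in the docstring). These counts simply follow that description; the exact bookkeeping inside `tfl.layers.KroneckerFactoredLattice` may differ slightly.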
@@ -410,6 +471,9 @@ def __init__(self,
           - String `'uniform'`: Output is initialized uniformly in label range.
           - A list of numbers: To be used for initialization of the output
             lattice or output calibrator.
+      random_seed: Random seed to use for initialization of a lattice with
+        `'kronecker_factored'` parameterization. Ignored if parameterization is
+        set to `'all_vertices'`.
     """
     super(CalibratedLatticeConfig, self).__init__(locals())
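Finally, a hedged sketch of how the new arguments could be combined in a `CalibratedLatticeConfig`. The feature names are placeholders and the `tfl.configs.FeatureConfig` fields used here (`name`, `lattice_size`) are not part of this diff, so treat this as an illustration rather than documented usage; the comments restate the constraints from the docstrings above.

```python
import tensorflow_lattice as tfl

# Hypothetical features; 'kronecker_factored' requires every feature to use
# the same lattice_size, and regularizers are currently not supported with it.
model_config = tfl.configs.CalibratedLatticeConfig(
    feature_configs=[
        tfl.configs.FeatureConfig(name='age', lattice_size=2),
        tfl.configs.FeatureConfig(name='income', lattice_size=2),
        tfl.configs.FeatureConfig(name='score', lattice_size=2),
    ],
    parameterization='kronecker_factored',  # uses tfl.layers.KroneckerFactoredLattice
    num_terms=2,     # ignored when parameterization='all_vertices'
    random_seed=42,  # seeds the 'kronecker_factored' lattice initialization
)
```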