@@ -460,6 +460,17 @@ def _dominance_constraints_from_feature_configs(feature_configs):
460
460
return monotonic_dominances
461
461
462
462
463
+ def _canonical_feature_names (model_config , feature_names = None ):
464
+ if feature_names is not None :
465
+ return feature_names
466
+ if model_config .feature_configs is None :
467
+ raise ValueError (
468
+ 'Feature configs must be specified if feature names are not provided.' )
469
+ return [
470
+ feature_config .name for feature_config in model_config .feature_configs
471
+ ]
472
+
473
+
463
474
def build_linear_layer (linear_input , feature_configs , model_config ,
464
475
weighted_average , submodel_index , dtype ):
465
476
"""Creates a `tfl.layers.Linear` layer initialized to be an average.
@@ -937,15 +948,7 @@ def set_random_lattice_ensemble(model_config, feature_names=None):
937
948
.format (type (model_config )))
938
949
if model_config .lattices != 'random' :
939
950
raise ValueError ('model_config.lattices must be set to \' random\' .' )
940
- # Extract feature names
941
- if feature_names is None :
942
- if model_config .feature_configs is None :
943
- raise ValueError (
944
- 'Feature configs must be specified if feature names are not provided.'
945
- )
946
- feature_names = [
947
- feature_config .name for feature_config in model_config .feature_configs
948
- ]
951
+ feature_names = _canonical_feature_names (model_config , feature_names )
949
952
# Start by using each feature once.
950
953
np .random .seed (model_config .random_seed )
951
954
model_config .lattices = [[] for _ in range (model_config .num_lattices )]
@@ -1032,15 +1035,7 @@ def construct_prefitting_model_config(model_config, feature_names=None):
1032
1035
.format (type (model_config )))
1033
1036
if model_config .lattices != 'crystals' :
1034
1037
raise ValueError ('model_config.lattices must be set to \' crystals\' .' )
1035
- # Extract feature names from model_config if not provided.
1036
- if feature_names is None :
1037
- if model_config .feature_configs is None :
1038
- raise ValueError (
1039
- 'Feature configs must be specified if feature names are not provided.'
1040
- )
1041
- feature_names = [
1042
- feature_config .name for feature_config in model_config .feature_configs
1043
- ]
1038
+ feature_names = _canonical_feature_names (model_config , feature_names )
1044
1039
1045
1040
# Make a copy of the model config provided and set all pairs covered.
1046
1041
prefitting_model_config = copy .deepcopy (model_config )
@@ -1330,14 +1325,7 @@ def set_crystals_lattice_ensemble(model_config,
1330
1325
# the proper type will have undefined behavior.
1331
1326
# To perform this check, we must first extract feature names if they are not
1332
1327
# provided, which we need for later steps anyway.
1333
- if feature_names is None :
1334
- if model_config .feature_configs is None :
1335
- raise ValueError (
1336
- 'Feature configs must be specified if feature names are not provided.'
1337
- )
1338
- feature_names = [
1339
- feature_config .name for feature_config in model_config .feature_configs
1340
- ]
1328
+ feature_names = _canonical_feature_names (model_config , feature_names )
1341
1329
_verify_prefitting_model (prefitting_model , feature_names )
1342
1330
1343
1331
# Now we can extract the crystals and finalize model_config.
@@ -1351,6 +1339,232 @@ def set_crystals_lattice_ensemble(model_config,
1351
1339
] for lattice in lattices ]
1352
1340
1353
1341
1342
+ def _weighted_quantile (sorted_values , quantiles , weights ):
1343
+ """Calculates weighted quantiles of the given sorted and unique values."""
1344
+ if len (sorted_values ) < len (quantiles ):
1345
+ raise ValueError (
1346
+ 'Not enough unique values ({}) to calculate {} quantiles.' .format (
1347
+ len (sorted_values ), len (quantiles )))
1348
+ # Weighted quantiles of the observed (sorted) values.
1349
+ # Weights are spread equaly before and after the observed values.
1350
+ weighted_quantiles = (np .cumsum (weights ) - 0.5 * weights ) / np .sum (weights )
1351
+
1352
+ # Use linear interpolation to find index of the quantile values.
1353
+ index_values = np .arange (len (sorted_values ))
1354
+ quantiles_idx = np .interp (x = quantiles , xp = weighted_quantiles , fp = index_values )
1355
+ quantiles_idx = np .rint (quantiles_idx ).astype (int )
1356
+
1357
+ # Replace repeated quantile values with neighbouring values.
1358
+ unique_idx , first_use = np .unique (quantiles_idx , return_index = True )
1359
+ used_idx = set (unique_idx )
1360
+ num_values = len (sorted_values )
1361
+ for i in range (len (quantiles_idx )):
1362
+ if i not in first_use :
1363
+ # Since this is not the first use of a (repeated) quantile value, we will
1364
+ # need to find an unused neighbouring value.
1365
+ for delta , direction in itertools .product (range (1 , num_values ), [- 1 , 1 ]):
1366
+ candidate_idx = quantiles_idx [i ] + direction * delta
1367
+ if (candidate_idx >= 0 and candidate_idx < num_values and
1368
+ candidate_idx not in used_idx ):
1369
+ used_idx .add (candidate_idx )
1370
+ quantiles_idx [i ] = candidate_idx
1371
+ break
1372
+ quantiles_idx = np .sort (quantiles_idx )
1373
+
1374
+ return sorted_values [quantiles_idx ]
1375
+
1376
+
1377
+ def compute_keypoints (values ,
1378
+ num_keypoints ,
1379
+ keypoints = 'quantiles' ,
1380
+ clip_min = None ,
1381
+ clip_max = None ,
1382
+ default_value = None ,
1383
+ weights = None ,
1384
+ weight_reduction = 'mean' ,
1385
+ feature_name = '' ):
1386
+ """Calculates keypoints for the given set of values.
1387
+
1388
+ Args:
1389
+ values: Values to use for quantile calculation.
1390
+ num_keypoints: Number of keypoints to compute.
1391
+ keypoints: String `'quantiles'` or `'uniform'`.
1392
+ clip_min: Input values are lower clipped by this value.
1393
+ clip_max: Input values are upper clipped by this value.
1394
+ default_value: If provided, occurances will be removed from values.
1395
+ weights: Weights to be used for quantile calculation.
1396
+ weight_reduction: Reduction applied to weights for repeated values. Must be
1397
+ either 'mean' or 'sum'.
1398
+ feature_name: Name to use for error logs.
1399
+
1400
+ Returns:
1401
+ A list of keypoints of `num_keypoints` length.
1402
+ """
1403
+ # Remove default values before calculating stats.
1404
+ non_default_idx = values != default_value
1405
+ values = values [non_default_idx ]
1406
+ if weights is not None :
1407
+ weights = weights [non_default_idx ]
1408
+
1409
+ # Clip min and max if requested. Note that we add clip bounds to the values
1410
+ # so that the first and last keypoints are set to those values.
1411
+ if clip_min is not None :
1412
+ values = np .maximum (values , clip_min )
1413
+ values = np .append (values , clip_min )
1414
+ if weights is not None :
1415
+ weights = np .append (weights , 0 )
1416
+ if clip_max is not None :
1417
+ values = np .minimum (values , clip_max )
1418
+ values = np .append (values , clip_max )
1419
+ if weights is not None :
1420
+ weights = np .append (weights , 0 )
1421
+
1422
+ # We do not allow nans in the data, even as default_value.
1423
+ if np .isnan (values ).any ():
1424
+ raise ValueError (
1425
+ 'NaN values were observed for numeric feature `{}`. '
1426
+ 'Consider replacing the values in transform or input_fn.' .format (
1427
+ feature_name ))
1428
+
1429
+ # Remove duplicates and sort value before calculating stats.
1430
+ # This is emperically useful as we use of keypoints more efficiently.
1431
+ if weights is None :
1432
+ sorted_values = np .unique (values )
1433
+ else :
1434
+ # First sort the values and reorder weights.
1435
+ idx = np .argsort (values )
1436
+ values = values [idx ]
1437
+ weights = weights [idx ]
1438
+
1439
+ # Set the weight of each unique element to be the sum or average of the
1440
+ # weights of repeated instances. Using 'mean' reduction results in parity
1441
+ # between unweighted calculation and having equal weights for all values.
1442
+ sorted_values , idx , counts = np .unique (
1443
+ values , return_index = True , return_counts = True )
1444
+ weights = np .add .reduceat (weights , idx )
1445
+ if weight_reduction == 'mean' :
1446
+ weights = weights / counts
1447
+ elif weight_reduction != 'sum' :
1448
+ raise ValueError ('Invalid weight reduction: {}' .format (weight_reduction ))
1449
+
1450
+ if keypoints == 'quantiles' :
1451
+ if sorted_values .size < num_keypoints :
1452
+ logging .info (
1453
+ 'Not enough unique values observed for feature `%s` to '
1454
+ 'construct %d keypoints for pwl calibration. Using %d unique '
1455
+ 'values as keypoints.' , feature_name , num_keypoints ,
1456
+ sorted_values .size )
1457
+ return sorted_values .astype (float )
1458
+
1459
+ quantiles = np .linspace (0. , 1. , num_keypoints )
1460
+ if weights is not None :
1461
+ return _weighted_quantile (
1462
+ sorted_values = sorted_values , quantiles = quantiles ,
1463
+ weights = weights ).astype (float )
1464
+ else :
1465
+ return np .quantile (
1466
+ sorted_values , quantiles , interpolation = 'nearest' ).astype (float )
1467
+
1468
+ elif keypoints == 'uniform' :
1469
+ return np .linspace (sorted_values [0 ], sorted_values [- 1 ], num_keypoints )
1470
+ else :
1471
+ raise ValueError ('Invalid keypoint generation mode: {}' .format (keypoints ))
1472
+
1473
+
1474
+ def _feature_config_by_name (feature_configs , feature_name , add_if_missing ):
1475
+ """Returns feature_config with the given name."""
1476
+ for feature_config in feature_configs :
1477
+ if feature_config .name == feature_name :
1478
+ return feature_config
1479
+ # Use the default FeatureConfig if not present.
1480
+ feature_config = configs .FeatureConfig (feature_name )
1481
+ if add_if_missing :
1482
+ feature_configs .append (feature_config )
1483
+ return feature_config
1484
+
1485
+
1486
+ def compute_feature_keypoints (feature_configs ,
1487
+ features ,
1488
+ weights = None ,
1489
+ weight_reduction = 'mean' ):
1490
+ """Computes feature keypoints with the data provide in `features` dict."""
1491
+ # Calculate feature keypoitns.
1492
+ feature_keypoints = {}
1493
+ for feature_name , values in six .iteritems (features ):
1494
+ feature_config = _feature_config_by_name (
1495
+ feature_configs = feature_configs ,
1496
+ feature_name = feature_name ,
1497
+ add_if_missing = False )
1498
+
1499
+ if feature_config .num_buckets :
1500
+ # Skip categorical features.
1501
+ continue
1502
+ if isinstance (feature_config .pwl_calibration_input_keypoints , str ):
1503
+ feature_keypoints [feature_name ] = compute_keypoints (
1504
+ values ,
1505
+ num_keypoints = feature_config .pwl_calibration_num_keypoints ,
1506
+ keypoints = feature_config .pwl_calibration_input_keypoints ,
1507
+ clip_min = feature_config .pwl_calibration_clip_min ,
1508
+ clip_max = feature_config .pwl_calibration_clip_max ,
1509
+ weights = weights ,
1510
+ weight_reduction = weight_reduction ,
1511
+ feature_name = feature_name ,
1512
+ )
1513
+ else :
1514
+ # User-specified keypoint values.
1515
+ feature_keypoints [
1516
+ feature_name ] = feature_config .pwl_calibration_input_keypoints
1517
+ return feature_keypoints
1518
+
1519
+
1520
+ def set_feature_keypoints (feature_configs , feature_keypoints ,
1521
+ add_missing_feature_configs ):
1522
+ """Updates the feature configs with provided keypoints."""
1523
+ for feature_name , keypoints in six .iteritems (feature_keypoints ):
1524
+ feature_config = _feature_config_by_name (
1525
+ feature_configs = feature_configs ,
1526
+ feature_name = feature_name ,
1527
+ add_if_missing = add_missing_feature_configs )
1528
+ feature_config .pwl_calibration_input_keypoints = keypoints
1529
+
1530
+
1531
+ def compute_label_keypoints (model_config ,
1532
+ labels ,
1533
+ logits_output ,
1534
+ weights = None ,
1535
+ weight_reduction = 'mean' ):
1536
+ """Computes label keypoints with the data provide in `lables` array."""
1537
+ if not np .issubdtype (labels [0 ], np .number ):
1538
+ # Default feature_values to [0, ... n_class-1] for string labels.
1539
+ labels = np .arange (len (set (labels )))
1540
+ weights = None
1541
+
1542
+ if isinstance (model_config .output_initialization , str ):
1543
+ # If model is expected to produce logits, initialize linearly in the
1544
+ # range [-2, 2], ignoring the label distribution.
1545
+ if logits_output :
1546
+ return np .linspace (- 2 , 2 , model_config .output_calibration_num_keypoints )
1547
+
1548
+ return compute_keypoints (
1549
+ labels ,
1550
+ num_keypoints = model_config .output_calibration_num_keypoints ,
1551
+ keypoints = model_config .output_initialization ,
1552
+ clip_min = model_config .output_min ,
1553
+ clip_max = model_config .output_max ,
1554
+ weights = weights ,
1555
+ weight_reduction = weight_reduction ,
1556
+ feature_name = 'label' ,
1557
+ )
1558
+ else :
1559
+ # User-specified keypoint values.
1560
+ return model_config .output_initialization
1561
+
1562
+
1563
+ def set_label_keypoints (model_config , label_keypoints ):
1564
+ """Updates the label keypoints in the `model_config`."""
1565
+ model_config .output_initialization = label_keypoints
1566
+
1567
+
1354
1568
def _verify_ensemble_config (model_config ):
1355
1569
"""Verifies that an ensemble model and feature configs are properly specified.
1356
1570
0 commit comments