Skip to content

Commit ef3f186

Browse files
committed
Merge commit for internal changes
2 parents 784eca5 + 68a23a4 commit ef3f186

File tree

8 files changed

+441
-301
lines changed

8 files changed

+441
-301
lines changed
87.7 KB
Loading

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
# This version number should always be that of the *next* (unreleased) version.
2828
# Immediately after uploading a package to PyPI, you should increment the
2929
# version number and push to gitHub.
30-
__version__ = "2.0.9"
30+
__version__ = "2.0.10"
3131

3232
if "--release" in sys.argv:
3333
sys.argv.remove("--release")

tensorflow_lattice/python/BUILD

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ py_library(
5959
srcs_version = "PY2AND3",
6060
deps = [
6161
":internal_utils",
62-
# enum dep,
6362
# tensorflow:tensorflow_no_contrib dep,
6463
],
6564
)
@@ -319,7 +318,6 @@ py_library(
319318
name = "model_info",
320319
srcs = ["model_info.py"],
321320
srcs_version = "PY2AND3",
322-
deps = [],
323321
)
324322

325323
py_library(
@@ -386,7 +384,6 @@ py_library(
386384
":rtl_layer",
387385
":utils",
388386
# absl/logging dep,
389-
# enum dep,
390387
# numpy dep,
391388
# six dep,
392389
# tensorflow dep,
@@ -429,7 +426,6 @@ py_library(
429426
srcs_version = "PY2AND3",
430427
deps = [
431428
":utils",
432-
# enum dep,
433429
# tensorflow:tensorflow_no_contrib dep,
434430
],
435431
)

tensorflow_lattice/python/estimators.py

Lines changed: 107 additions & 118 deletions
Large diffs are not rendered by default.

tensorflow_lattice/python/estimators_test.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -392,14 +392,23 @@ def testCalibratedLatticeClassifier(self, feature_names, output_calibration,
392392
self.assertGreater(results['auc'], auc)
393393

394394
@parameterized.parameters(
395-
(['age', 'sex', 'fbs', 'restecg', 'ca', 'thal'], False, False, 0.7),
395+
(['age', 'sex', 'fbs', 'restecg', 'ca', 'thal'
396+
], False, False, None, None, 'mean', 0.7),
396397
([
397398
'age', 'sex', 'cp', 'trestbps', 'chol', 'fbs', 'restecg', 'thalach',
398399
'exang', 'oldpeak', 'slope', 'ca', 'thal'
399-
], True, True, 0.8),
400+
], True, True, None, None, 'mean', 0.8),
401+
(['age', 'sex', 'fbs', 'restecg', 'ca', 'thal'
402+
], False, False, 'thalach', None, 'mean', 0.7),
403+
(['age', 'sex', 'fbs', 'restecg', 'ca', 'thal'
404+
], False, False, 'thalach', 'thalach', 'mean', 0.7),
405+
(['age', 'sex', 'fbs', 'restecg', 'ca', 'thal'
406+
], False, False, 'thalach', 'thalach', 'sum', 0.7),
400407
)
401408
def testCalibratedLinearClassifier(self, feature_names, output_calibration,
402-
use_bias, auc):
409+
use_bias, weight_column,
410+
feature_analysis_weight_column,
411+
feature_analysis_weight_reduction, auc):
403412
self._ResetAllBackends()
404413
feature_columns = [
405414
feature_column for feature_column in self.heart_feature_columns
@@ -420,6 +429,9 @@ def testCalibratedLinearClassifier(self, feature_names, output_calibration,
420429
feature_columns=feature_columns,
421430
model_config=model_config,
422431
feature_analysis_input_fn=self._GetHeartTrainInputFn(num_epochs=1),
432+
weight_column=weight_column,
433+
feature_analysis_weight_column=feature_analysis_weight_column,
434+
feature_analysis_weight_reduction=feature_analysis_weight_reduction,
423435
optimizer=tf.keras.optimizers.Adam(0.01))
424436
estimator.train(input_fn=self._GetHeartTrainInputFn(num_epochs=200))
425437
results = estimator.evaluate(input_fn=self._GetHeartTestInputFn())

tensorflow_lattice/python/lattice_layer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ class Lattice(keras.layers.Layer):
5858
There are several types of constraints on the shape of the learned function
5959
that are either 1 or 2 dimensional:
6060
61+
![Shape constraint visual example images](https://www.tensorflow.org/lattice/images/2d_shape_constraints_picture_color.png)
62+
6163
* **Monotonicity:** constrains the function to be increasing in the
6264
corresponding dimension. To achieve decreasing monotonicity, either pass the
6365
inputs through a `tfl.layers.PWLCalibration` with `decreasing` monotonicity,

tensorflow_lattice/python/premade_lib.py

Lines changed: 240 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -460,6 +460,17 @@ def _dominance_constraints_from_feature_configs(feature_configs):
460460
return monotonic_dominances
461461

462462

463+
def _canonical_feature_names(model_config, feature_names=None):
464+
if feature_names is not None:
465+
return feature_names
466+
if model_config.feature_configs is None:
467+
raise ValueError(
468+
'Feature configs must be specified if feature names are not provided.')
469+
return [
470+
feature_config.name for feature_config in model_config.feature_configs
471+
]
472+
473+
463474
def build_linear_layer(linear_input, feature_configs, model_config,
464475
weighted_average, submodel_index, dtype):
465476
"""Creates a `tfl.layers.Linear` layer initialized to be an average.
@@ -937,15 +948,7 @@ def set_random_lattice_ensemble(model_config, feature_names=None):
937948
.format(type(model_config)))
938949
if model_config.lattices != 'random':
939950
raise ValueError('model_config.lattices must be set to \'random\'.')
940-
# Extract feature names
941-
if feature_names is None:
942-
if model_config.feature_configs is None:
943-
raise ValueError(
944-
'Feature configs must be specified if feature names are not provided.'
945-
)
946-
feature_names = [
947-
feature_config.name for feature_config in model_config.feature_configs
948-
]
951+
feature_names = _canonical_feature_names(model_config, feature_names)
949952
# Start by using each feature once.
950953
np.random.seed(model_config.random_seed)
951954
model_config.lattices = [[] for _ in range(model_config.num_lattices)]
@@ -1032,15 +1035,7 @@ def construct_prefitting_model_config(model_config, feature_names=None):
10321035
.format(type(model_config)))
10331036
if model_config.lattices != 'crystals':
10341037
raise ValueError('model_config.lattices must be set to \'crystals\'.')
1035-
# Extract feature names from model_config if not provided.
1036-
if feature_names is None:
1037-
if model_config.feature_configs is None:
1038-
raise ValueError(
1039-
'Feature configs must be specified if feature names are not provided.'
1040-
)
1041-
feature_names = [
1042-
feature_config.name for feature_config in model_config.feature_configs
1043-
]
1038+
feature_names = _canonical_feature_names(model_config, feature_names)
10441039

10451040
# Make a copy of the model config provided and set all pairs covered.
10461041
prefitting_model_config = copy.deepcopy(model_config)
@@ -1330,14 +1325,7 @@ def set_crystals_lattice_ensemble(model_config,
13301325
# the proper type will have undefined behavior.
13311326
# To perform this check, we must first extract feature names if they are not
13321327
# provided, which we need for later steps anyway.
1333-
if feature_names is None:
1334-
if model_config.feature_configs is None:
1335-
raise ValueError(
1336-
'Feature configs must be specified if feature names are not provided.'
1337-
)
1338-
feature_names = [
1339-
feature_config.name for feature_config in model_config.feature_configs
1340-
]
1328+
feature_names = _canonical_feature_names(model_config, feature_names)
13411329
_verify_prefitting_model(prefitting_model, feature_names)
13421330

13431331
# Now we can extract the crystals and finalize model_config.
@@ -1351,6 +1339,232 @@ def set_crystals_lattice_ensemble(model_config,
13511339
] for lattice in lattices]
13521340

13531341

1342+
def _weighted_quantile(sorted_values, quantiles, weights):
1343+
"""Calculates weighted quantiles of the given sorted and unique values."""
1344+
if len(sorted_values) < len(quantiles):
1345+
raise ValueError(
1346+
'Not enough unique values ({}) to calculate {} quantiles.'.format(
1347+
len(sorted_values), len(quantiles)))
1348+
# Weighted quantiles of the observed (sorted) values.
1349+
# Weights are spread equaly before and after the observed values.
1350+
weighted_quantiles = (np.cumsum(weights) - 0.5 * weights) / np.sum(weights)
1351+
1352+
# Use linear interpolation to find index of the quantile values.
1353+
index_values = np.arange(len(sorted_values))
1354+
quantiles_idx = np.interp(x=quantiles, xp=weighted_quantiles, fp=index_values)
1355+
quantiles_idx = np.rint(quantiles_idx).astype(int)
1356+
1357+
# Replace repeated quantile values with neighbouring values.
1358+
unique_idx, first_use = np.unique(quantiles_idx, return_index=True)
1359+
used_idx = set(unique_idx)
1360+
num_values = len(sorted_values)
1361+
for i in range(len(quantiles_idx)):
1362+
if i not in first_use:
1363+
# Since this is not the first use of a (repeated) quantile value, we will
1364+
# need to find an unused neighbouring value.
1365+
for delta, direction in itertools.product(range(1, num_values), [-1, 1]):
1366+
candidate_idx = quantiles_idx[i] + direction * delta
1367+
if (candidate_idx >= 0 and candidate_idx < num_values and
1368+
candidate_idx not in used_idx):
1369+
used_idx.add(candidate_idx)
1370+
quantiles_idx[i] = candidate_idx
1371+
break
1372+
quantiles_idx = np.sort(quantiles_idx)
1373+
1374+
return sorted_values[quantiles_idx]
1375+
1376+
1377+
def compute_keypoints(values,
1378+
num_keypoints,
1379+
keypoints='quantiles',
1380+
clip_min=None,
1381+
clip_max=None,
1382+
default_value=None,
1383+
weights=None,
1384+
weight_reduction='mean',
1385+
feature_name=''):
1386+
"""Calculates keypoints for the given set of values.
1387+
1388+
Args:
1389+
values: Values to use for quantile calculation.
1390+
num_keypoints: Number of keypoints to compute.
1391+
keypoints: String `'quantiles'` or `'uniform'`.
1392+
clip_min: Input values are lower clipped by this value.
1393+
clip_max: Input values are upper clipped by this value.
1394+
default_value: If provided, occurances will be removed from values.
1395+
weights: Weights to be used for quantile calculation.
1396+
weight_reduction: Reduction applied to weights for repeated values. Must be
1397+
either 'mean' or 'sum'.
1398+
feature_name: Name to use for error logs.
1399+
1400+
Returns:
1401+
A list of keypoints of `num_keypoints` length.
1402+
"""
1403+
# Remove default values before calculating stats.
1404+
non_default_idx = values != default_value
1405+
values = values[non_default_idx]
1406+
if weights is not None:
1407+
weights = weights[non_default_idx]
1408+
1409+
# Clip min and max if requested. Note that we add clip bounds to the values
1410+
# so that the first and last keypoints are set to those values.
1411+
if clip_min is not None:
1412+
values = np.maximum(values, clip_min)
1413+
values = np.append(values, clip_min)
1414+
if weights is not None:
1415+
weights = np.append(weights, 0)
1416+
if clip_max is not None:
1417+
values = np.minimum(values, clip_max)
1418+
values = np.append(values, clip_max)
1419+
if weights is not None:
1420+
weights = np.append(weights, 0)
1421+
1422+
# We do not allow nans in the data, even as default_value.
1423+
if np.isnan(values).any():
1424+
raise ValueError(
1425+
'NaN values were observed for numeric feature `{}`. '
1426+
'Consider replacing the values in transform or input_fn.'.format(
1427+
feature_name))
1428+
1429+
# Remove duplicates and sort value before calculating stats.
1430+
# This is emperically useful as we use of keypoints more efficiently.
1431+
if weights is None:
1432+
sorted_values = np.unique(values)
1433+
else:
1434+
# First sort the values and reorder weights.
1435+
idx = np.argsort(values)
1436+
values = values[idx]
1437+
weights = weights[idx]
1438+
1439+
# Set the weight of each unique element to be the sum or average of the
1440+
# weights of repeated instances. Using 'mean' reduction results in parity
1441+
# between unweighted calculation and having equal weights for all values.
1442+
sorted_values, idx, counts = np.unique(
1443+
values, return_index=True, return_counts=True)
1444+
weights = np.add.reduceat(weights, idx)
1445+
if weight_reduction == 'mean':
1446+
weights = weights / counts
1447+
elif weight_reduction != 'sum':
1448+
raise ValueError('Invalid weight reduction: {}'.format(weight_reduction))
1449+
1450+
if keypoints == 'quantiles':
1451+
if sorted_values.size < num_keypoints:
1452+
logging.info(
1453+
'Not enough unique values observed for feature `%s` to '
1454+
'construct %d keypoints for pwl calibration. Using %d unique '
1455+
'values as keypoints.', feature_name, num_keypoints,
1456+
sorted_values.size)
1457+
return sorted_values.astype(float)
1458+
1459+
quantiles = np.linspace(0., 1., num_keypoints)
1460+
if weights is not None:
1461+
return _weighted_quantile(
1462+
sorted_values=sorted_values, quantiles=quantiles,
1463+
weights=weights).astype(float)
1464+
else:
1465+
return np.quantile(
1466+
sorted_values, quantiles, interpolation='nearest').astype(float)
1467+
1468+
elif keypoints == 'uniform':
1469+
return np.linspace(sorted_values[0], sorted_values[-1], num_keypoints)
1470+
else:
1471+
raise ValueError('Invalid keypoint generation mode: {}'.format(keypoints))
1472+
1473+
1474+
def _feature_config_by_name(feature_configs, feature_name, add_if_missing):
1475+
"""Returns feature_config with the given name."""
1476+
for feature_config in feature_configs:
1477+
if feature_config.name == feature_name:
1478+
return feature_config
1479+
# Use the default FeatureConfig if not present.
1480+
feature_config = configs.FeatureConfig(feature_name)
1481+
if add_if_missing:
1482+
feature_configs.append(feature_config)
1483+
return feature_config
1484+
1485+
1486+
def compute_feature_keypoints(feature_configs,
1487+
features,
1488+
weights=None,
1489+
weight_reduction='mean'):
1490+
"""Computes feature keypoints with the data provide in `features` dict."""
1491+
# Calculate feature keypoitns.
1492+
feature_keypoints = {}
1493+
for feature_name, values in six.iteritems(features):
1494+
feature_config = _feature_config_by_name(
1495+
feature_configs=feature_configs,
1496+
feature_name=feature_name,
1497+
add_if_missing=False)
1498+
1499+
if feature_config.num_buckets:
1500+
# Skip categorical features.
1501+
continue
1502+
if isinstance(feature_config.pwl_calibration_input_keypoints, str):
1503+
feature_keypoints[feature_name] = compute_keypoints(
1504+
values,
1505+
num_keypoints=feature_config.pwl_calibration_num_keypoints,
1506+
keypoints=feature_config.pwl_calibration_input_keypoints,
1507+
clip_min=feature_config.pwl_calibration_clip_min,
1508+
clip_max=feature_config.pwl_calibration_clip_max,
1509+
weights=weights,
1510+
weight_reduction=weight_reduction,
1511+
feature_name=feature_name,
1512+
)
1513+
else:
1514+
# User-specified keypoint values.
1515+
feature_keypoints[
1516+
feature_name] = feature_config.pwl_calibration_input_keypoints
1517+
return feature_keypoints
1518+
1519+
1520+
def set_feature_keypoints(feature_configs, feature_keypoints,
1521+
add_missing_feature_configs):
1522+
"""Updates the feature configs with provided keypoints."""
1523+
for feature_name, keypoints in six.iteritems(feature_keypoints):
1524+
feature_config = _feature_config_by_name(
1525+
feature_configs=feature_configs,
1526+
feature_name=feature_name,
1527+
add_if_missing=add_missing_feature_configs)
1528+
feature_config.pwl_calibration_input_keypoints = keypoints
1529+
1530+
1531+
def compute_label_keypoints(model_config,
1532+
labels,
1533+
logits_output,
1534+
weights=None,
1535+
weight_reduction='mean'):
1536+
"""Computes label keypoints with the data provide in `lables` array."""
1537+
if not np.issubdtype(labels[0], np.number):
1538+
# Default feature_values to [0, ... n_class-1] for string labels.
1539+
labels = np.arange(len(set(labels)))
1540+
weights = None
1541+
1542+
if isinstance(model_config.output_initialization, str):
1543+
# If model is expected to produce logits, initialize linearly in the
1544+
# range [-2, 2], ignoring the label distribution.
1545+
if logits_output:
1546+
return np.linspace(-2, 2, model_config.output_calibration_num_keypoints)
1547+
1548+
return compute_keypoints(
1549+
labels,
1550+
num_keypoints=model_config.output_calibration_num_keypoints,
1551+
keypoints=model_config.output_initialization,
1552+
clip_min=model_config.output_min,
1553+
clip_max=model_config.output_max,
1554+
weights=weights,
1555+
weight_reduction=weight_reduction,
1556+
feature_name='label',
1557+
)
1558+
else:
1559+
# User-specified keypoint values.
1560+
return model_config.output_initialization
1561+
1562+
1563+
def set_label_keypoints(model_config, label_keypoints):
1564+
"""Updates the label keypoints in the `model_config`."""
1565+
model_config.output_initialization = label_keypoints
1566+
1567+
13541568
def _verify_ensemble_config(model_config):
13551569
"""Verifies that an ensemble model and feature configs are properly specified.
13561570

0 commit comments

Comments
 (0)