A preprocessing layer which buckets continuous features by ranges.
Inherits From: PreprocessingLayer, Layer, Module
tf.keras.layers.Discretization(
bin_boundaries=None,
num_bins=None,
epsilon=0.01,
output_mode='int',
sparse=False,
**kwargs
)
This layer will place each element of its input data into one of several contiguous ranges and output an integer index indicating which range each element was placed in.
For an overview and full list of preprocessing layers, see the preprocessing guide.
Input shape | |
---|---|
Any tf.Tensor or tf.RaggedTensor of dimension 2 or higher. |
Output shape | |
---|---|
Same as input shape. |
Examples:
Bucketize float values based on provided buckets.
>>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])
>>> layer = tf.keras.layers.Discretization(bin_boundaries=[0., 1., 2.])
>>> layer(input)
<tf.Tensor: shape=(2, 4), dtype=int64, numpy=
array([[0, 2, 3, 1],
[1, 3, 2, 1]])>
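The indices in the output above are consistent with half-open bins: values below 0 map to 0, values in [0, 1) map to 1, values in [1, 2) map to 2, and values of 2 or more map to 3. A minimal pure-Python sketch of that rule follows. The helper name bucketize is hypothetical and is not part of TensorFlow; it only reproduces the indices shown above and makes no claim about how the layer is implemented.

```python
from bisect import bisect_right


def bucketize(values, bin_boundaries):
    """Return the bin index of each value (hypothetical helper, not a TF API).

    bisect_right counts how many boundaries are <= the value, so a value
    equal to a boundary lands in the bin that starts at that boundary.
    """
    return [bisect_right(bin_boundaries, v) for v in values]


boundaries = [0.0, 1.0, 2.0]
print(bucketize([-1.5, 1.0, 3.4, 0.5], boundaries))  # [0, 2, 3, 1]
print(bucketize([0.0, 3.0, 1.3, 0.0], boundaries))   # [1, 3, 2, 1]
```

Note that three boundaries produce four bins, because the first and last bins are open-ended.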
Bucketize float values based on a number of buckets to compute.
>>> input = np.array([[-1.5, 1.0, 3.4, .5], [0.0, 3.0, 1.3, 0.0]])
>>> layer = tf.keras.layers.Discretization(num_bins=4, epsilon=0.01)
>>> layer.adapt(input)
>>> layer(input)
<tf.Tensor: shape=(2, 4), dtype=int64, numpy=
array([[0, 2, 3, 2],
[1, 3, 3, 1]])>
Attributes | |
---|---|
is_adapted | Whether the layer has been fit to data already. |
Methods
adapt
adapt(
data, batch_size=None, steps=None
)
Computes bin boundaries from quantiles in an input dataset.
Calling adapt() on a Discretization layer is an alternative to passing in a bin_boundaries argument during construction. A Discretization layer should always be either adapted over a dataset or passed bin_boundaries.
During adapt(), the layer will estimate the quantile boundaries of the input dataset. The number of quantiles can be controlled via the num_bins argument, and the error tolerance for quantile boundaries can be controlled via the epsilon argument.
In order to make Discretization efficient in any distribution context, the computed boundaries are kept static with respect to any compiled tf.Graphs that call the layer. As a consequence, if the layer is adapted a second time, any models using the layer should be re-compiled. For more information see tf.keras.layers.experimental.preprocessing.PreprocessingLayer.adapt.
adapt() is meant only as a single-machine utility to compute layer state. To analyze a dataset that cannot fit on a single machine, see TensorFlow Transform for a multi-machine, map-reduce solution.
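To illustrate the idea of deriving bin boundaries from quantiles, here is a rough standard-library sketch. It assumes exact quantiles from statistics.quantiles as a stand-in for the estimate that adapt() computes within the epsilon tolerance, so the boundaries and indices it produces may differ from what the layer returns. The helper names quantile_boundaries and bucketize are hypothetical.

```python
from bisect import bisect_right
from statistics import quantiles


def quantile_boundaries(samples, num_bins):
    """Exact quantile cut points (assumption: stands in for adapt()'s estimate).

    Splitting data into num_bins bins needs num_bins - 1 interior boundaries,
    which is what statistics.quantiles returns for n=num_bins.
    """
    return quantiles(samples, n=num_bins)


def bucketize(values, boundaries):
    """Map each value to the index of the bin it falls into."""
    return [bisect_right(boundaries, v) for v in values]


samples = [-1.5, 1.0, 3.4, 0.5, 0.0, 3.0, 1.3, 0.0]
boundaries = quantile_boundaries(samples, num_bins=4)
indices = bucketize(samples, boundaries)
print(boundaries)  # three sorted cut points
print(indices)     # one index in range(4) per sample
```

With quantile-based boundaries, each bin receives roughly the same share of the adapted data, which is the main difference from hand-picked bin_boundaries.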
Arguments | |
---|---|
data | The data to train on. It can be passed either as a tf.data.Dataset, or as a numpy array. |
batch_size | Integer or None. Number of samples per state update. If unspecified, batch_size will default to 32. Do not specify the batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances (since they generate batches). |
steps | Integer or None. Total number of steps (batches of samples). When training with input tensors such as TensorFlow data tensors, the default None is equal to the number of samples in your dataset divided by the batch size, or 1 if that cannot be determined. If data is a tf.data dataset and steps is None, adapt() will run until the input dataset is exhausted. When passing an infinitely repeating dataset, you must specify the steps argument. This argument is not supported with array inputs. |
compile
compile(
run_eagerly=None, steps_per_execution=None
)
Configures the layer for adapt.
Arguments | |
---|---|
run_eagerly | Bool. If True, this Model's logic will not be wrapped in a tf.function. Recommended to leave this as None unless your Model cannot be run inside a tf.function. Defaults to False. |
steps_per_execution | Int. The number of batches to run during each tf.function call. Running multiple batches inside a single tf.function call can greatly improve performance on TPUs or small models with a large Python overhead. Defaults to 1. |
reset_state
reset_state()
Resets the statistics of the preprocessing layer.
update_state
update_state(
data
)
Accumulates statistics for the preprocessing layer.
Arguments | |
---|---|
data | A mini-batch of inputs to the layer. |