Base class for Preprocessing Layers.
tf.keras.layers.experimental.preprocessing.PreprocessingLayer(**kwargs)
Don't use this class directly: it's an abstract base class! You may be looking for one of the many built-in preprocessing layers instead.
Preprocessing layers are layers whose state gets computed before model training starts. They do not get updated during training. Most preprocessing layers implement an adapt() method for state computation.
The PreprocessingLayer class is the base class you would subclass to implement your own preprocessing layers.
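As a rough sketch of subclassing, the hypothetical MaxAbsScaling layer below divides its inputs by the largest absolute value seen during adapt. It assumes TF 2.6 to 2.15 (Keras 2), where the base class's adapt calls reset_state() before re-adapting an already-built layer and update_state() once per batch. Those hook names are not documented on this page, so treat them as an assumption about the base class.

```python
import tensorflow as tf

PreprocessingLayer = tf.keras.layers.experimental.preprocessing.PreprocessingLayer


class MaxAbsScaling(PreprocessingLayer):
    """Hypothetical layer: divides inputs by the largest |x| seen in adapt()."""

    def build(self, input_shape):
        # Non-trainable scalar state; training never updates it.
        self.max_abs = self.add_weight(
            name="max_abs", shape=(), initializer="zeros", trainable=False)
        super().build(input_shape)

    def update_state(self, data):
        # Assumed hook: called once per batch of `data` during adapt().
        data = tf.cast(data, tf.float32)
        self.max_abs.assign(
            tf.maximum(self.max_abs, tf.reduce_max(tf.abs(data))))

    def reset_state(self):
        # Assumed hook: called before re-adapting an already-built layer.
        self.max_abs.assign(0.0)

    def call(self, inputs):
        inputs = tf.cast(inputs, tf.float32)
        # Guard against division by zero if adapt() was never called
        # or only saw zeros.
        return inputs / tf.maximum(self.max_abs, 1e-12)


scaler = MaxAbsScaling()
scaler.adapt([-4.0, 2.0])
scaled = scaler(tf.constant([2.0, -1.0]))
```

After adapt, max_abs holds 4, so the inputs 2 and -1 map to 0.5 and -0.25.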
Attributes | |
---|---|
is_adapted | Whether the layer has been fit to data already. |
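For example, with the built-in Normalization layer used elsewhere on this page (assuming TF 2.6 or later with Keras 2), is_adapted flips from False to True once adapt finishes:

```python
import tensorflow as tf

layer = tf.keras.layers.Normalization(axis=None)
before = layer.is_adapted  # adapt() has not been called yet
layer.adapt([0.0, 2.0])
after = layer.is_adapted   # set once adapt() has finished
```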
Methods
adapt
adapt(data, batch_size=None, steps=None)
Fits the state of the preprocessing layer to the data being passed.
After adapt has been called on a layer, the layer's state does not update during training. To make preprocessing layers efficient in any distribution context, their state is kept constant with respect to any compiled tf.Graphs that call the layer. This has no effect when each layer is adapted only once, but if you adapt a layer multiple times you need to re-compile any compiled functions as follows:
- If you are adding a preprocessing layer to a keras.Model, you need to call model.compile after each subsequent call to adapt.
- If you are calling a preprocessing layer inside tf.data.Dataset.map, you should call map again on the input tf.data.Dataset after each adapt.
- If you are directly using a tf.function that calls a preprocessing layer, you need to call tf.function again on your callable after each subsequent call to adapt.
tf.keras.Model example with multiple adapts:
>>> import tensorflow as tf
>>> layer = tf.keras.layers.Normalization(axis=None)
>>> layer.adapt([0, 2])
>>> model = tf.keras.Sequential(layer)
>>> model.predict([0, 1, 2])
array([-1., 0., 1.], dtype=float32)
>>> layer.adapt([-1, 1])
>>> model.compile()  # This is needed to re-compile model.predict!
>>> model.predict([0, 1, 2])
array([0., 1., 2.], dtype=float32)
tf.data.Dataset example with multiple adapts:
>>> import tensorflow as tf
>>> layer = tf.keras.layers.Normalization(axis=None)
>>> layer.adapt([0, 2])
>>> input_ds = tf.data.Dataset.range(3)
>>> normalized_ds = input_ds.map(layer)
>>> list(normalized_ds.as_numpy_iterator())
[array([-1.], dtype=float32),
 array([0.], dtype=float32),
 array([1.], dtype=float32)]
>>> layer.adapt([-1, 1])
>>> normalized_ds = input_ds.map(layer)  # Re-map over the input dataset.
>>> list(normalized_ds.as_numpy_iterator())
[array([0.], dtype=float32),
 array([1.], dtype=float32),
 array([2.], dtype=float32)]
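The tf.function case has no example on this page. A minimal sketch, assuming TF 2.6 to 2.15 (Keras 2) and a hypothetical helper named normalize, wraps the Python callable again after the second adapt so that a fresh graph is traced with the updated mean and variance:

```python
import tensorflow as tf

layer = tf.keras.layers.Normalization(axis=None)
layer.adapt([0.0, 2.0])  # learns mean 1, variance 1


def normalize(x):
    # Hypothetical helper: a plain Python callable that uses the layer.
    return layer(x)


normalize_fn = tf.function(normalize)
first = normalize_fn(tf.constant([0.0, 1.0, 2.0])).numpy()

layer.adapt([-1.0, 1.0])  # learns mean 0, variance 1
# Wrap the callable again so a new graph is traced with the new state;
# the old normalize_fn may keep using the state captured at trace time.
normalize_fn = tf.function(normalize)
second = normalize_fn(tf.constant([0.0, 1.0, 2.0])).numpy()
```

Here first should be [-1, 0, 1] and second [0, 1, 2], mirroring the model.predict example.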
adapt() is meant only as a single-machine utility to compute layer state. To analyze a dataset that cannot fit on a single machine, see TensorFlow Transform for a multi-machine, map-reduce solution.
Arguments | |
---|---|
data | The data to train on. It can be passed either as a tf.data.Dataset or as a NumPy array. |
batch_size | Integer or None. Number of samples per state update. If unspecified, batch_size defaults to 32. Do not specify batch_size if your data is in the form of datasets, generators, or keras.utils.Sequence instances, since they generate batches themselves. |
steps | Integer or None. Total number of steps (batches of samples). When training with input tensors such as TensorFlow data tensors, the default None is equal to the number of samples in your dataset divided by the batch size, or 1 if that cannot be determined. If data is a tf.data dataset and steps is None, adapt will run until the input dataset is exhausted. When passing an infinitely repeating dataset, you must specify the steps argument. This argument is not supported with array inputs. |
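A sketch of how batch_size and steps fit together, assuming TF 2.6 to 2.15 (Keras 2) and made-up sample values: an array is batched by adapt itself, while an infinitely repeating dataset arrives pre-batched and needs steps to terminate.

```python
import numpy as np
import tensorflow as tf

values = np.array([0.0, 2.0, 4.0, 6.0], dtype="float32")  # made-up data

# NumPy input: adapt() batches the array itself, two samples per
# state update. `steps` is not passed, as it is unsupported here.
layer_a = tf.keras.layers.Normalization(axis=None)
layer_a.adapt(values, batch_size=2)

# Repeating dataset: it is already batched, so batch_size is omitted, and
# steps=2 stops adapt() after exactly one pass over the four values.
ds = tf.data.Dataset.from_tensor_slices(values).batch(2).repeat()
layer_b = tf.keras.layers.Normalization(axis=None)
layer_b.adapt(ds, steps=2)
```

Both layers see the same four values, so both should learn a mean of 3 and a variance of 5.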
compile
compile(run_eagerly=None, steps_per_execution=None)
Configures the layer for adapt.
Arguments | |
---|---|
run_eagerly | Bool. If True, this layer's adapt logic will not be wrapped in a tf.function. Recommended to leave this as None unless your layer cannot be run inside a tf.function. Defaults to False. |
steps_per_execution | Int. The number of batches to run during each tf.function call. Running multiple batches inside a single tf.function call can greatly improve performance on TPUs or small models with a large Python overhead. Defaults to 1. |
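A minimal sketch of calling compile before adapt, assuming TF 2.6 to 2.15 (Keras 2): run_eagerly=True skips the tf.function wrapping, which can make a custom update_state easier to step through in a debugger. In the Keras 2 implementation, adapt falls back to default settings when compile has not been called, so this step is optional.

```python
import tensorflow as tf

eager_layer = tf.keras.layers.Normalization(axis=None)
# Skip tf.function wrapping for adapt(); each batch is processed eagerly.
eager_layer.compile(run_eagerly=True)
eager_layer.adapt([0.0, 2.0, 4.0, 6.0], batch_size=2)
```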