Performs spectral normalization on the weights of a target layer.
Inherits From: Wrapper, Layer, Module
tf.keras.layers.SpectralNormalization(
layer, power_iterations=1, **kwargs
)
This wrapper controls the Lipschitz constant of the weights of a layer by constraining their spectral norm, which can stabilize the training of GANs.
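To make the Lipschitz claim concrete, here is a minimal NumPy sketch (not the layer's implementation). The spectral norm of a matrix is its largest singular value, so dividing a kernel by that value gives a linear map that cannot stretch the distance between two inputs by more than a factor of 1. The matrix `w` and the helper name `spectral_norm` are illustrative assumptions.

```python
import numpy as np

def spectral_norm(w):
    """Largest singular value of a 2D matrix (exact, via SVD)."""
    return np.linalg.svd(w, compute_uv=False)[0]

rng = np.random.default_rng(0)
w = rng.normal(size=(8, 4))      # hypothetical stand-in for a Dense kernel
w_sn = w / spectral_norm(w)      # spectrally normalized kernel

# For any two inputs a and b, the normalized map is 1-Lipschitz:
# ||a @ w_sn - b @ w_sn|| <= ||a - b||
a, b = rng.normal(size=(2, 8))
gap_out = np.linalg.norm(a @ w_sn - b @ w_sn)
gap_in = np.linalg.norm(a - b)
```

The exact SVD is used here only for clarity; it is typically too expensive to run on every training step, which is why an iterative estimate is usually preferred in practice.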
Args | |
---|---|
layer | A keras.layers.Layer instance that has either a kernel (e.g. Conv2D, Dense...) or an embeddings attribute (Embedding layer). |
power_iterations | int, the number of power iterations to run during normalization. Defaults to 1. |
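The effect of `power_iterations` can be illustrated with a small NumPy sketch of power iteration, the standard way to approximate the largest singular value without a full SVD. This is a sketch under stated assumptions, not the layer's code: the function name `power_iteration_sigma`, the random starting vector, and the test matrix are all hypothetical.

```python
import numpy as np

def power_iteration_sigma(w, power_iterations=1, seed=0):
    """Estimate the largest singular value of a 2D matrix by power iteration."""
    if power_iterations < 1:
        raise ValueError("power_iterations must be a positive integer")
    rng = np.random.default_rng(seed)
    u = rng.normal(size=w.shape[1])
    u /= np.linalg.norm(u)
    for _ in range(power_iterations):
        v = w @ u                 # push u through the matrix
        v /= np.linalg.norm(v)
        u = w.T @ v               # pull it back through the transpose
        u /= np.linalg.norm(u)
    return float(v @ w @ u)       # v^T W u approximates sigma_max

rng = np.random.default_rng(1)
w = rng.normal(size=(16, 8))
exact = np.linalg.svd(w, compute_uv=False)[0]

# More iterations give an estimate at least as close to the exact value.
estimates = {k: power_iteration_sigma(w, k) for k in (1, 5, 50)}
```

The estimate never exceeds the exact spectral norm, and it approaches it from below as the iteration count grows. A single iteration per step is cheap; a higher count trades compute for a tighter estimate.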
Examples:
Wrap keras.layers.Conv2D:
>>> import numpy as np
>>> import tensorflow as tf
>>> x = np.random.rand(1, 10, 10, 1)
>>> conv2d = tf.keras.layers.SpectralNormalization(tf.keras.layers.Conv2D(2, 2))
>>> y = conv2d(x)
>>> y.shape
TensorShape([1, 9, 9, 2])
Wrap keras.layers.Dense:
>>> x = np.random.rand(1, 10, 10, 1)
>>> dense = tf.keras.layers.SpectralNormalization(tf.keras.layers.Dense(10))
>>> y = dense(x)
>>> y.shape
TensorShape([1, 10, 10, 10])
Reference:
Methods
normalize_weights
normalize_weights()
Generate spectral normalized weights.
This method will update the value of self.kernel with the spectral normalized value, so that the layer is ready for call().
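The in-place update described above can be pictured with a small stateful NumPy sketch. It is a hypothetical stand-in, not the Keras implementation: the class name `SpectralNormSketch`, the flattening of the kernel to two dimensions, and the persistent vector `u` reused between calls are assumptions made for illustration.

```python
import numpy as np

class SpectralNormSketch:
    """Hypothetical stand-in for the wrapper; not the Keras implementation."""

    def __init__(self, kernel, power_iterations=1, seed=0):
        self.kernel = np.array(kernel, dtype=np.float64)
        self.power_iterations = power_iterations
        rng = np.random.default_rng(seed)
        # Running guess of the leading singular vector, kept between calls.
        self.u = rng.normal(size=(1, self.kernel.shape[-1]))

    def normalize_weights(self):
        # Flatten e.g. a Conv2D kernel (h, w, in, out) to (h * w * in, out).
        w = self.kernel.reshape(-1, self.kernel.shape[-1])
        u = self.u
        for _ in range(self.power_iterations):
            v = u @ w.T
            v /= np.linalg.norm(v)
            u = v @ w
            u /= np.linalg.norm(u)
        sigma = (v @ w @ u.T).item()        # estimated spectral norm
        self.u = u
        self.kernel = self.kernel / sigma   # overwrite the stored kernel

rng = np.random.default_rng(0)
layer = SpectralNormSketch(rng.normal(size=(2, 2, 1, 2)))
for _ in range(20):
    layer.normalize_weights()
top = np.linalg.svd(layer.kernel.reshape(-1, 2), compute_uv=False)[0]
```

Because the stored vector is reused, repeated calls keep refining the estimate, so in this sketch the largest singular value `top` of the stored kernel stays at or slightly above 1 and drifts toward 1 over successive calls.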