Applies Mixup and/or Cutmix to a batch of images.
tfm.vision.augment.MixupAndCutmix(
num_classes: int,
mixup_alpha: float = 0.8,
cutmix_alpha: float = 1.0,
prob: float = 1.0,
switch_prob: float = 0.5,
label_smoothing: float = 0.1
)
- Mixup: https://arxiv.org/abs/1710.09412
- Cutmix: https://arxiv.org/abs/1905.04899
Implementation is inspired by https://github.com/rwightman/pytorch-image-models
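As a rough illustration of the CutMix half of this augmentation, the sketch below pastes a random box from one image into another, following the box-sampling convention used in pytorch-image-models. It is plain Python on 2-D lists rather than tensors, and the helper names `cutmix_box` and `cutmix_pair` are hypothetical; this class's actual implementation may differ in its details.

```python
import math
import random


def cutmix_box(height, width, lam, rng):
    """Return (top, bottom, left, right) of a box covering roughly (1 - lam) of the image."""
    ratio = math.sqrt(1.0 - lam)  # side ratio, so the area ratio is about 1 - lam
    cut_h, cut_w = int(height * ratio), int(width * ratio)
    cy, cx = rng.randrange(height), rng.randrange(width)  # random box center
    top, bottom = max(cy - cut_h // 2, 0), min(cy + cut_h // 2, height)
    left, right = max(cx - cut_w // 2, 0), min(cx + cut_w // 2, width)
    return top, bottom, left, right


def cutmix_pair(image_a, image_b, box):
    """Paste the box region of image_b into a copy of image_a (2-D lists of pixels)."""
    top, bottom, left, right = box
    mixed = [row[:] for row in image_a]
    for y in range(top, bottom):
        mixed[y][left:right] = image_b[y][left:right]
    return mixed


rng = random.Random(0)
height, width = 8, 8
image_a = [[0.0] * width for _ in range(height)]
image_b = [[1.0] * width for _ in range(height)]

lam = rng.betavariate(1.0, 1.0)  # assumed: lam ~ Beta(cutmix_alpha, cutmix_alpha)
box = cutmix_box(height, width, lam, rng)
mixed = cutmix_pair(image_a, image_b, box)

# Clipping at the border can shrink the box, so recompute lam from the actual area.
top, bottom, left, right = box
lam_adjusted = 1.0 - (bottom - top) * (right - left) / (height * width)
```

Mixup differs only in the image operation: instead of pasting a box, it blends the two images pixel-wise as `lam * image_a + (1 - lam) * image_b`.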
Methods
distort
distort(
images: tf.Tensor, labels: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor]
Applies Mixup and/or Cutmix to a batch of images and transforms the labels accordingly.
Args | |
---|---|
images | tf.Tensor of shape [batch_size, height, width, 3] representing a batch of images, or [batch_size, time, height, width, 3] representing a batch of videos. |
labels | tf.Tensor of shape [batch_size] representing the class id of each image in the batch. |
Returns | |
---|---|
Tuple[tf.Tensor, tf.Tensor]: The augmented images and labels. |
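The label transformation can be sketched in plain Python as follows. This assumes the convention used in pytorch-image-models: class ids become label-smoothed one-hot vectors, and each example is blended with its partner from the batch in reverse order using a coefficient `lam` drawn from a Beta distribution. The helper names `smooth_one_hot` and `mix_labels` are hypothetical, and the resulting soft-label shape of [batch_size, num_classes] is an assumption rather than something this page states.

```python
import random


def smooth_one_hot(label, num_classes, smoothing=0.1):
    """Label-smoothed one-hot vector; the entries sum to 1."""
    off_value = smoothing / num_classes
    on_value = 1.0 - smoothing + off_value
    return [on_value if c == label else off_value for c in range(num_classes)]


def mix_labels(labels, num_classes, lam, smoothing=0.1):
    """Blend each smoothed target with the target of its reverse-order partner."""
    targets = [smooth_one_hot(y, num_classes, smoothing) for y in labels]
    partners = targets[::-1]  # assumed pairing: batch in reverse order
    return [
        [lam * a + (1.0 - lam) * b for a, b in zip(target, partner)]
        for target, partner in zip(targets, partners)
    ]


rng = random.Random(0)
lam = rng.betavariate(0.8, 0.8)  # assumed: lam ~ Beta(mixup_alpha, mixup_alpha)
mixed = mix_labels([0, 2, 1], num_classes=3, lam=lam)
```

Because the mixed labels are soft distributions rather than integer class ids, a loss that accepts dense targets (for example categorical cross-entropy instead of its sparse variant) is the natural fit downstream.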
__call__
__call__(
images: tf.Tensor, labels: tf.Tensor
) -> Tuple[tf.Tensor, tf.Tensor]
Call self as a function.