Creates a baseline task for image classification on CIFAR-100.
tff.simulation.baselines.cifar100.create_image_classification_task(
    train_client_spec: tff.simulation.baselines.ClientSpec,
    eval_client_spec: Optional[tff.simulation.baselines.ClientSpec] = None,
    model_id: Union[str, tff.simulation.baselines.cifar100.ResnetModel] = 'resnet18',
    crop_height: int = DEFAULT_CROP_HEIGHT,
    crop_width: int = DEFAULT_CROP_WIDTH,
    distort_train_images: bool = False,
    cache_dir: Optional[str] = None,
    use_synthetic_data: bool = False
) -> tff.simulation.baselines.BaselineTask
The goal of the task is to minimize the sparse categorical crossentropy
between the model's output and the true label of each image.
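As a rough illustration of this objective, the sketch below computes sparse categorical crossentropy in plain Python. Each label is an integer class index (0 to 99 for CIFAR-100) rather than a one-hot vector, and the loss is the mean negative log-probability that the model assigns to the true class. The sketch assumes the model emits unnormalized logits and applies a numerically stable log-softmax. This page does not say whether the baseline model outputs logits or probabilities, and `sparse_categorical_crossentropy` is a hypothetical helper, not a TFF or Keras function.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]


def sparse_categorical_crossentropy(
    batch_logits: Sequence[Sequence[float]], labels: Sequence[int]
) -> float:
    """Hypothetical helper: mean of -log p(true class) over a batch.

    `labels` holds integer class indices, so no one-hot encoding is needed.
    """
    if len(batch_logits) != len(labels):
        raise ValueError("batch_logits and labels must have the same length")
    losses = [-log_softmax(logits)[label] for logits, label in zip(batch_logits, labels)]
    return sum(losses) / len(losses)
```

For example, a model that outputs 100 equal logits, and therefore guesses uniformly over the CIFAR-100 classes, has a loss of `log(100)`, about 4.6. A confident, correct prediction drives the loss toward zero.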
Args:

`train_client_spec`: A `tff.simulation.baselines.ClientSpec` specifying how to preprocess train client data.

`eval_client_spec`: An optional `tff.simulation.baselines.ClientSpec` specifying how to preprocess evaluation client data. If set to `None`, the evaluation datasets will use a batch size of 64 with no extra preprocessing.

`model_id`: Either a string identifier or a `tff.simulation.baselines.cifar100.ResnetModel` value selecting the image classification model. Strings must be one of `resnet18`, `resnet34`, `resnet50`, `resnet101`, or `resnet152`, which correspond to ResNet architectures of different depths. Unlike standard ResNet architectures, these models replace the batch normalization layers with group normalization.

`crop_height`<a id="crop_height"></a>: An integer specifying the desired height for cropping images. Must be between 1 and 32 (the height of uncropped CIFAR-100 images). By default, this is set to <a href="../../../../tff/simulation/baselines/cifar100#DEFAULT_CROP_HEIGHT"><code>tff.simulation.baselines.cifar100.DEFAULT_CROP_HEIGHT</code></a>.

`crop_width`<a id="crop_width"></a>: An integer specifying the desired width for cropping images. Must be between 1 and 32 (the width of uncropped CIFAR-100 images). By default, this is set to <a href="../../../../tff/simulation/baselines/cifar100#DEFAULT_CROP_WIDTH"><code>tff.simulation.baselines.cifar100.DEFAULT_CROP_WIDTH</code></a>.

`distort_train_images`<a id="distort_train_images"></a>: Whether to distort images in the train preprocessing function.

`cache_dir`<a id="cache_dir"></a>: An optional directory in which to cache the downloaded datasets. If `None`, they are cached to `~/.tff/`.

`use_synthetic_data`<a id="use_synthetic_data"></a>: A boolean indicating whether to use synthetic CIFAR-100 data. This option should only be used for testing purposes, to avoid downloading the entire CIFAR-100 dataset.
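The `eval_client_spec` fallback can be pictured with a small pure-Python sketch. `ClientSpecSketch`, `resolve_eval_spec`, and `batch_client_data` are hypothetical stand-ins, not TFF code. The sketch models only a number of epochs and a batch size, and it assumes that "no extra preprocessing" means a single pass with no shuffling and that a final partial batch is kept. The real preprocessing operates on `tf.data` datasets and may differ in these details.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


@dataclass(frozen=True)
class ClientSpecSketch:
    """Hypothetical stand-in for a ClientSpec, with only two illustrative fields."""
    num_epochs: int
    batch_size: int


# Assumed fallback matching the documented default: batch size 64, one pass,
# no shuffling or other extra preprocessing.
DEFAULT_EVAL_SPEC = ClientSpecSketch(num_epochs=1, batch_size=64)


def resolve_eval_spec(eval_client_spec: Optional[ClientSpecSketch]) -> ClientSpecSketch:
    """Use the caller's spec if one is given, otherwise the default evaluation spec."""
    return DEFAULT_EVAL_SPEC if eval_client_spec is None else eval_client_spec


def batch_client_data(examples: Sequence[T], spec: ClientSpecSketch) -> List[List[T]]:
    """Repeat a client's examples num_epochs times, then split them into batches.

    The final batch may be smaller than batch_size (an assumption of this sketch).
    """
    stream = list(examples) * spec.num_epochs
    return [stream[i:i + spec.batch_size] for i in range(0, len(stream), spec.batch_size)]
```

Under these assumptions, a client with 150 evaluation examples yields batches of 64, 64, and 22 examples when `eval_client_spec` is `None`.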
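The `model_id` description notes that these ResNets use group normalization instead of batch normalization. The minimal sketch below shows what group normalization computes for a single example. The channels are split into contiguous groups, and each group is normalized by the mean and variance of its own values across all spatial positions, so the result never depends on other examples in the batch. `group_norm` is a hypothetical helper operating on a flattened `H*W x C` feature map stored as nested lists. It omits the learnable per-channel scale and offset that a real layer would apply, and it is not the layer implementation used by TFF's models.

```python
import math
from typing import List


def group_norm(features: List[List[float]], num_groups: int, eps: float = 1e-5) -> List[List[float]]:
    """Group-normalize one example.

    `features` is a list of spatial positions, each a list of C channel values.
    Channels are split into `num_groups` contiguous groups; each group is
    normalized with statistics computed over this example only.
    """
    num_channels = len(features[0])
    if num_channels % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    group_size = num_channels // num_groups
    out = [row[:] for row in features]  # copy so the input is not mutated
    for g in range(num_groups):
        channels = range(g * group_size, (g + 1) * group_size)
        # Statistics span every spatial position and every channel in the group.
        values = [row[c] for row in features for c in channels]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)
        inv_std = 1.0 / math.sqrt(var + eps)
        for row in out:
            for c in channels:
                row[c] = (row[c] - mean) * inv_std
    return out
```

With batch normalization, the statistics would instead be pooled over every example in a batch. This sketch takes a single example as input, which makes the per-example nature of group normalization explicit.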
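The bounds on `crop_height` and `crop_width` can be illustrated with a short pure-Python sketch that validates the crop size against the 32x32 CIFAR-100 image size and then crops an image stored as nested lists. The center crop is an assumption chosen for illustration. This page does not say which cropping strategy (for example, center or random) the preprocessing actually uses, and `check_crop_size` and `center_crop` are hypothetical helpers rather than TFF functions.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")

CIFAR100_IMAGE_SIZE = 32  # uncropped CIFAR-100 images are 32x32 pixels


def check_crop_size(crop_height: int, crop_width: int) -> None:
    """Raise ValueError unless both crop dimensions lie in [1, 32]."""
    for name, value in (("crop_height", crop_height), ("crop_width", crop_width)):
        if not 1 <= value <= CIFAR100_IMAGE_SIZE:
            raise ValueError(f"{name} must be between 1 and {CIFAR100_IMAGE_SIZE}, got {value}")


def center_crop(image: Sequence[Sequence[T]], crop_height: int, crop_width: int) -> List[List[T]]:
    """Return the central crop_height x crop_width window of a row-major image."""
    check_crop_size(crop_height, crop_width)
    top = (len(image) - crop_height) // 2
    left = (len(image[0]) - crop_width) // 2
    return [list(row[left:left + crop_width]) for row in image[top:top + crop_height]]
```

For example, a 24x24 center crop of a 32x32 image starts at row 4, column 4, and a request for a 33-pixel crop raises `ValueError`.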