{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "headers"
},
"source": [
"Project: /responsible_ai/_project.yaml\n",
"Book: /responsible_ai/_book.yaml\n",
"\n",
"\n",
"{% comment %}\n",
"The source of truth file can be found [here]: http://google3/third_party/py/tensorflow_model_remediation/docs\n",
"{% endcomment %}"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SdAOaOmwHEjq"
},
"source": [
"# Customizing MinDiffModel\n",
"\n",
"There are two main reasons you may need to customize `MinDiffModel`:\n",
"\n",
"* The `keras.Model` you are using has custom behavior that you want to preserve.\n",
"* You want the `MinDiffModel` to behave differently from the default.\n",
"\n",
"In either case, you will need to subclass `MinDiffModel` to achieve the desired results."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "X0KEthjiwvwg"
},
"source": [
"## Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kmAHyZt9TErX"
},
"outputs": [],
"source": [
"!pip install --upgrade tensorflow-model-remediation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "XRa49AkYS6n1"
},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"tf.get_logger().setLevel('ERROR') # Avoid TF warnings.\n",
"from tensorflow_model_remediation import min_diff\n",
"from tensorflow_model_remediation.tools.tutorials_utils import uci as tutorials_utils"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "SIKBbSGZQvBt"
},
"source": [
"First, download the data. For succinctness, the input preparation logic has been factored out into helper functions as described in the [input preparation guide](./min_diff_data_preparation#utility_functions_for_other_guides). You can read the full guide for details on this process."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Yhbj5eqLQpzC"
},
"outputs": [],
"source": [
"# Original Dataset for training, sampled at 0.3 for reduced runtimes.\n",
"train_df = tutorials_utils.get_uci_data(split='train', sample=0.3)\n",
"train_ds = tutorials_utils.df_to_dataset(train_df, batch_size=128)\n",
"\n",
"# Dataset needed to train with MinDiff.\n",
"train_with_min_diff_ds = (\n",
" tutorials_utils.get_uci_with_min_diff_dataset(split='train', sample=0.3))"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "M5UY0aDskbsO"
},
"source": [
"## Preserving Original Model Customizations\n",
"\n",
"`tf.keras.Model` is designed to be easily customized via subclassing as described [here](https://www.tensorflow.org/guide/keras/custom_layers_and_models#the_model_class). If your model has customized implementations that you wish to preserve when applying MinDiff, you will need to subclass `MinDiffModel`."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "XPUDWc_ro-sX"
},
"source": [
"### Original Custom Model\n",
"\n",
"To see how you can preserve customizations, create a custom model that sets an attribute to `True` when its custom `train_step` is called. This is not a useful customization but will serve to illustrate behavior."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "EowVwf79pE7j"
},
"outputs": [],
"source": [
"class CustomModel(tf.keras.Model):\n",
"\n",
"  # Customized train_step that records when it has been called.\n",
"  def train_step(self, *args, **kwargs):\n",
"    self.used_custom_train_step = True  # Marker that we can check for.\n",
"    return super(CustomModel, self).train_step(*args, **kwargs)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "jdeDa67XbNiv"
},
"source": [
"Training such a model looks the same as training a normal `Sequential` model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9reVfMpvbJGW"
},
"outputs": [],
"source": [
"model = tutorials_utils.get_uci_model(model_class=CustomModel) # Use CustomModel.\n",
"\n",
"model.compile(optimizer='adam', loss='binary_crossentropy')\n",
"\n",
"_ = model.fit(train_ds.take(1), epochs=1, verbose=0)\n",
"\n",
"# Model has used the custom train_step.\n",
"print('Model used the custom train_step:')\n",
"print(hasattr(model, 'used_custom_train_step')) # True"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "6S19mOvVpDU2"
},
"source": [
"### Subclassing MinDiffModel\n",
"\n",
"If you were to try and use `MinDiffModel` directly, the model would not use the custom `train_step`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "motiNz_NbGB3"
},
"outputs": [],
"source": [
"model = tutorials_utils.get_uci_model(model_class=CustomModel)\n",
"model = min_diff.keras.MinDiffModel(model, min_diff.losses.MMDLoss())\n",
"\n",
"model.compile(optimizer='adam', loss='binary_crossentropy')\n",
"\n",
"_ = model.fit(train_with_min_diff_ds.take(1), epochs=1, verbose=0)\n",
"\n",
"# Model has not used the custom train_step.\n",
"print('Model used the custom train_step:')\n",
"print(hasattr(model, 'used_custom_train_step')) # False"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QGZszcNnfWs3"
},
"source": [
"In order to use the correct `train_step` method, you need a custom class that subclasses both `MinDiffModel` and `CustomModel`.\n",
"\n",
"Note: Make sure to inherit from `MinDiffModel` first. This is important since you need to make sure that certain functions such as `__init__` and `call` still come from `MinDiffModel`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "l5ozMo_gplLP"
},
"outputs": [],
"source": [
"class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):\n",
" pass # No need for any further implementation."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ZYy40uGKgAB_"
},
"source": [
"Training this model will use the `train_step` from `CustomModel`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "AieekYU4f_D_"
},
"outputs": [],
"source": [
"model = tutorials_utils.get_uci_model(model_class=CustomModel)\n",
"\n",
"model = CustomMinDiffModel(model, min_diff.losses.MMDLoss())\n",
"\n",
"model.compile(optimizer='adam', loss='binary_crossentropy')\n",
"\n",
"_ = model.fit(train_with_min_diff_ds.take(1), epochs=1, verbose=0)\n",
"\n",
"# Model has used the custom train_step.\n",
"print('Model used the custom train_step:')\n",
"print(hasattr(model, 'used_custom_train_step')) # True"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "x8M1EqzEmIq4"
},
"source": [
"## Customizing default behaviors of `MinDiffModel`\n",
"\n",
"In other cases, you may want to change specific default behaviors of `MinDiffModel`. The most common use case is changing the default unpacking behavior so that it handles your data properly when you don't use `pack_min_diff_data`.\n",
"\n",
"When packing the data into a custom format, this might appear as follows."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "-37_74R4jTRN"
},
"outputs": [],
"source": [
"def _reformat_input(inputs, original_labels):\n",
" min_diff_data = min_diff.keras.utils.unpack_min_diff_data(inputs)\n",
" original_inputs = min_diff.keras.utils.unpack_original_inputs(inputs)\n",
"\n",
" return ({\n",
" 'min_diff_data': min_diff_data,\n",
" 'original_inputs': original_inputs}, original_labels)\n",
"\n",
"customized_train_with_min_diff_ds = train_with_min_diff_ds.map(_reformat_input)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wDHdQXp0v27r"
},
"source": [
"The `customized_train_with_min_diff_ds` dataset returns batches composed of tuples `(x, y)` where `x` is a dict containing `min_diff_data` and `original_inputs` and `y` is the `original_labels`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "ch4tJRP1KqwP"
},
"outputs": [],
"source": [
"for x, _ in customized_train_with_min_diff_ds.take(1):\n",
" print('Type of x:', type(x)) # dict\n",
" print('Keys of x:', x.keys()) # 'min_diff_data', 'original_inputs'"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-s8HuF8jKqGK"
},
"source": [
"This data format is not what `MinDiffModel` expects by default, and passing `customized_train_with_min_diff_ds` to it would result in unexpected behavior. To fix this, you will need to create your own subclass that overrides the unpacking methods."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "Nh0v7HA6ipOl"
},
"outputs": [],
"source": [
"class CustomUnpackingMinDiffModel(min_diff.keras.MinDiffModel):\n",
"\n",
"  # Read min_diff_data from the custom dict format.\n",
"  def unpack_min_diff_data(self, inputs):\n",
"    return inputs['min_diff_data']\n",
"\n",
"  # Read the original inputs from the custom dict format.\n",
"  def unpack_original_inputs(self, inputs):\n",
"    return inputs['original_inputs']"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Mqj2QsLwHic"
},
"source": [
"With this subclass, you can train as with the other examples."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "yLAfa0HguFWH"
},
"outputs": [],
"source": [
"model = tutorials_utils.get_uci_model()\n",
"model = CustomUnpackingMinDiffModel(model, min_diff.losses.MMDLoss())\n",
"\n",
"model.compile(optimizer='adam', loss='binary_crossentropy')\n",
"\n",
"_ = model.fit(customized_train_with_min_diff_ds, epochs=1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-fNyaok_L5Q4"
},
"source": [
"## Limitations of a Customized `MinDiffModel`\n",
"\n",
"Creating a custom `MinDiffModel` provides a huge amount of flexibility for more complex use cases. However, there are still some edge cases that it will not support."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "QYO2Uxzxe8km"
},
"source": [
"### Preprocessing or Validation of inputs before `call`\n",
"\n",
"The biggest limitation for a subclass of `MinDiffModel` is that it requires the `x` component of the input data (i.e. the first or only element in the batch returned by the `tf.data.Dataset`) to be passed through to `call` without preprocessing or validation.\n",
"\n",
"This is simply because the `min_diff_data` is packed into the `x` component of the input data. Any preprocessing or validation will not expect the additional structure containing `min_diff_data` and will likely break.\n",
"\n",
"If the preprocessing or validation is easily customizable (e.g. factored into its own method), then this is easily addressed by overriding it to ensure it handles the additional structure correctly.\n",
"\n",
"An example with validation might look like this:\n",
"```\n",
"class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):\n",
"\n",
"  # Override so that it correctly handles additional `min_diff_data`.\n",
"  def validate_inputs(self, inputs):\n",
"    original_inputs = self.unpack_original_inputs(inputs)\n",
"    ...  # Optionally also validate min_diff_data\n",
"    # Call original validate_inputs method with correct inputs\n",
"    return super(CustomMinDiffModel, self).validate_inputs(original_inputs)\n",
"```\n",
"If the preprocessing or validation isn't easily customizable, then using `MinDiffModel` may not work for you and you will need to integrate MinDiff without it as described in [this guide](./integrating_min_diff_without_min_diff_model)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-w4b9AOc3j-N"
},
"source": [
"### Method name collisions\n",
"\n",
"It is possible that your model has methods whose names clash with those implemented in `MinDiffModel` (see full list of public methods in the [API documentation](https://www.tensorflow.org/responsible_ai/model_remediation/api_docs/python/model_remediation/min_diff/keras/MinDiffModel#methods)).\n",
"\n",
"Note: Both `compile` and `save` will use any customized implementations provided, so they are not a cause for concern. Similarly, `call` will invoke the original model's `call` method after calculating the `min_diff_loss`.\n",
"\n",
"A name collision is only problematic if the clashing methods will be called on an instance of the model (rather than internally in some other method). While this is highly unlikely, if you are in this situation you will have to either override and rename some methods or, if that is not possible, consider integrating MinDiff without `MinDiffModel` as described in [this guide on the subject](./integrating_min_diff_without_min_diff_model)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1sGofxqcHYD5"
},
"source": [
"## Additional Resources\n",
"\n",
"* For an in-depth discussion on fairness evaluation, see the [Fairness Indicators guidance](https://www.tensorflow.org/responsible_ai/fairness_indicators/guide/guidance).\n",
"* For general information on Remediation and MinDiff, see the [remediation overview](https://www.tensorflow.org/responsible_ai/model_remediation).\n",
"* For details on requirements surrounding MinDiff see [this guide](https://www.tensorflow.org/responsible_ai/model_remediation/min_diff/guide/requirements).\n",
"* To see an end-to-end tutorial on using MinDiff in Keras, see [this tutorial](https://www.tensorflow.org/responsible_ai/model_remediation/min_diff/tutorials/min_diff_keras)."
]
}
],
"metadata": {
"colab": {
"collapsed_sections": [],
"name": "customizing_min_diff_model.ipynb",
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}