Introduction
In most cases, using MinDiffModel directly as described in the "Integrating MinDiff with MinDiffModel" guide is sufficient. However, it is possible that you will need customized behavior. The two primary reasons for this are:
- The keras.Model you are using has custom behavior that you want to preserve.
- You want the MinDiffModel to behave differently from the default.
In either case, you will need to subclass MinDiffModel to achieve the desired results.
Setup
pip install --upgrade tensorflow-model-remediation
import tensorflow as tf
tf.get_logger().setLevel('ERROR') # Avoid TF warnings.
from tensorflow_model_remediation import min_diff
from tensorflow_model_remediation.tools.tutorials_utils import uci as tutorials_utils
First, download the data. For succinctness, the input preparation logic has been factored out into helper functions as described in the input preparation guide. You can read the full guide for details on this process.
# Original Dataset for training, sampled at 0.3 for reduced runtimes.
train_df = tutorials_utils.get_uci_data(split='train', sample=0.3)
train_ds = tutorials_utils.df_to_dataset(train_df, batch_size=128)
# Dataset needed to train with MinDiff.
train_with_min_diff_ds = (
    tutorials_utils.get_uci_with_min_diff_dataset(split='train', sample=0.3))
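If you want to see what the helper produced, you can peek at one batch. This is an optional check added here (not part of the original setup) that uses the same unpack utilities that appear later in this guide, and it assumes, as in the input preparation guide, that the original x is a dict of features.
# Optional: inspect one batch of the packed MinDiff dataset.
for packed_x, y in train_with_min_diff_ds.take(1):
  original_inputs = min_diff.keras.utils.unpack_original_inputs(packed_x)
  min_diff_data = min_diff.keras.utils.unpack_min_diff_data(packed_x)
  print('Original input keys:', sorted(original_inputs.keys()))
  print('min_diff_data type:', type(min_diff_data).__name__)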
Preserving Original Model Customizations
tf.keras.Model is designed to be easily customized via subclassing, as described here. If your model has customized implementations that you wish to preserve when applying MinDiff, you will need to subclass MinDiffModel.
Original Custom Model
To see how you can preserve customizations, create a custom model that sets an attribute to True when its custom train_step is called. This is not a useful customization but will serve to illustrate behavior.
class CustomModel(tf.keras.Model):

  # Customized train_step
  def train_step(self, *args, **kwargs):
    self.used_custom_train_step = True  # Marker that we can check for.
    return super(CustomModel, self).train_step(*args, **kwargs)
Training such a model looks the same as it would for a normal Sequential model.
model = tutorials_utils.get_uci_model(model_class=CustomModel) # Use CustomModel.
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(train_ds.take(1), epochs=1, verbose=0)
# Model has used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step')) # True
Subclassing MinDiffModel
If you were to try to use MinDiffModel directly, the model would not use the custom train_step.
model = tutorials_utils.get_uci_model(model_class=CustomModel)
model = min_diff.keras.MinDiffModel(model, min_diff.losses.MMDLoss())
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(train_with_min_diff_ds.take(1), epochs=1, verbose=0)
# Model has not used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step')) # False
This happens because MinDiffModel wraps your original model rather than subclassing it, so calling fit runs the default tf.keras.Model train_step on the wrapper. In order to use the correct train_step method, you need a custom class that subclasses both MinDiffModel and CustomModel.
class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):
  pass  # No need for any further implementation.
Training this model will use the train_step from CustomModel.
model = tutorials_utils.get_uci_model(model_class=CustomModel)
model = CustomMinDiffModel(model, min_diff.losses.MMDLoss())
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(train_with_min_diff_ds.take(1), epochs=1, verbose=0)
# Model has used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step')) # True
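This works because of Python's method resolution order: with MinDiffModel listed first, its overrides (such as call) take precedence, while train_step, which is not among MinDiffModel's overrides (as the result above demonstrates), falls through to CustomModel. You can verify the ordering with a quick illustrative check:
# MinDiffModel comes before CustomModel in the MRO, so its overrides win,
# but anything it doesn't define (like train_step) resolves to CustomModel.
print([cls.__name__ for cls in CustomMinDiffModel.__mro__[:3]])
# ['CustomMinDiffModel', 'MinDiffModel', 'CustomModel']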
Customizing default behaviors of MinDiffModel
In other cases, you may want to change specific default behaviors of MinDiffModel. The most common use case is changing the default unpacking behavior to properly handle your data if you don't use pack_min_diff_data.
Packing the data into a custom format might look like the following.
def _reformat_input(inputs, original_labels):
  min_diff_data = min_diff.keras.utils.unpack_min_diff_data(inputs)
  original_inputs = min_diff.keras.utils.unpack_original_inputs(inputs)
  return ({
      'min_diff_data': min_diff_data,
      'original_inputs': original_inputs
  }, original_labels)

customized_train_with_min_diff_ds = train_with_min_diff_ds.map(_reformat_input)
The customized_train_with_min_diff_ds dataset returns batches composed of tuples (x, y) where x is a dict containing min_diff_data and original_inputs, and y is the original_labels.
for x, _ in customized_train_with_min_diff_ds.take(1):
  print('Type of x:', type(x))  # dict
  print('Keys of x:', x.keys())  # 'min_diff_data', 'original_inputs'
This data format is not what MinDiffModel expects by default, and passing customized_train_with_min_diff_ds to it would result in unexpected behavior. To fix this you will need to create your own subclass.
class CustomUnpackingMinDiffModel(min_diff.keras.MinDiffModel):

  def unpack_min_diff_data(self, inputs):
    return inputs['min_diff_data']

  def unpack_original_inputs(self, inputs):
    return inputs['original_inputs']
With this subclass, you can train as with the other examples.
model = tutorials_utils.get_uci_model()
model = CustomUnpackingMinDiffModel(model, min_diff.losses.MMDLoss())
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(customized_train_with_min_diff_ds, epochs=1)
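One consequence of this design (an observation added here, not from the original guide): the subclass expects the same custom dict format whenever it is called, since the default MinDiffModel.call goes through the unpack methods. Calling it on a batch from the customized dataset should therefore work, while plain unpacked inputs would not:
# Calling the model directly should work on the custom format, since fit's
# default train_step already called it on exactly this structure. Plain,
# unpacked inputs would fail because our override looks up
# inputs['original_inputs'].
for x, _ in customized_train_with_min_diff_ds.take(1):
  preds = model(x)
  print('Predictions shape:', preds.shape)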
Limitations of a Customized MinDiffModel
Creating a custom MinDiffModel provides a huge amount of flexibility for more complex use cases. However, there are still some edge cases that it will not support.
Preprocessing or Validation of inputs before call
The biggest limitation for a subclass of MinDiffModel is that it requires the x component of the input data (i.e. the first or only element in the batch returned by the tf.data.Dataset) to be passed through to call without any preprocessing or validation.
This is simply because the min_diff_data is packed into the x component of the input data. Any preprocessing or validation will not expect the additional structure containing min_diff_data and will likely break.
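To make this concrete, here is an illustrative sketch (not part of the original guide) of validation breaking. It assumes, as the API suggests, that the packed x is a MinDiffPackedInputs object rather than a plain dict of features.
def naive_validate(inputs):
  # Hypothetical validation that expects a plain dict of features.
  if not isinstance(inputs, dict):
    raise ValueError(
        'Expected a dict of features, got %s' % type(inputs).__name__)

for packed_x, _ in train_with_min_diff_ds.take(1):
  try:
    naive_validate(packed_x)
  except ValueError as e:
    print('Validation broke on packed inputs:', e)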
If the preprocessing or validation is easily customizable (e.g. factored into its own method), this can be addressed by overriding it to ensure it handles the additional structure correctly.
An example with validation might look like this:
class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):

  # Override so that it correctly handles additional `min_diff_data`.
  def validate_inputs(self, inputs):
    original_inputs = self.unpack_original_inputs(inputs)
    ...  # Optionally also validate min_diff_data.
    # Call the original validation method with the unpacked inputs.
    return super(CustomMinDiffModel, self).validate_inputs(original_inputs)
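Note that validate_inputs here stands for a method of your own model; it is not part of tf.keras.Model or MinDiffModel. A self-contained sketch of the full pattern, using hypothetical names (ValidatingModel and its dict check are illustrative only), might look like this:
class ValidatingModel(tf.keras.Model):

  # Hypothetical validation, factored into its own overridable method.
  def validate_inputs(self, inputs):
    if not isinstance(inputs, dict):
      raise ValueError('Expected a dict of features.')
    return inputs


class ValidatingMinDiffModel(min_diff.keras.MinDiffModel, ValidatingModel):

  # Unpack first so the original check never sees `min_diff_data`.
  def validate_inputs(self, inputs):
    original_inputs = self.unpack_original_inputs(inputs)
    return super(ValidatingMinDiffModel, self).validate_inputs(original_inputs)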
If the preprocessing or validation isn't easily customizable, then using MinDiffModel may not work for you and you will need to integrate MinDiff without it, as described in this guide.
Method name collisions
It is possible that your model has methods whose names clash with those implemented in MinDiffModel (see the full list of public methods in the API documentation).
This is only problematic if these methods will be called on an instance of the model (rather than internally in some other method). While this is highly unlikely, if you find yourself in this situation you will have to either override and rename some methods or, if that is not possible, consider integrating MinDiff without MinDiffModel, as described in this guide on the subject.
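One illustrative way to spot potential collisions ahead of time (a sketch added here, not from the original guide) is to intersect the names each class defines directly, using vars so that inherited tf.keras.Model methods are not counted:
# Names your class defines that MinDiffModel also implements itself.
custom_names = {name for name in vars(CustomModel) if not name.startswith('_')}
min_diff_names = {
    name for name in vars(min_diff.keras.MinDiffModel)
    if not name.startswith('_')
}
print('Potential collisions:', sorted(custom_names & min_diff_names))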
Additional Resources
- For an in-depth discussion on fairness evaluation, see the Fairness Indicators guidance.
- For general information on Remediation and MinDiff, see the remediation overview.
- For details on requirements surrounding MinDiff see this guide.
- To see an end-to-end tutorial on using MinDiff in Keras, see this tutorial.