Creating a custom Counterfactual Logit Pairing Dataset

Applying Counterfactual Logit Pairing (CLP) to evaluate and improve the fairness of your model requires a counterfactual dataset. You create a counterfactual dataset by duplicating your existing dataset and changing the new dataset to add, remove, or modify identity terminology. This tutorial explains the approach and techniques for creating a counterfactual dataset for your existing text dataset.

You use your counterfactual dataset with the CLP technique by creating a new data object, CounterfactualPackedInputs, that contains the original_input and counterfactual_data. CounterfactualPackedInputs looks like the following:

CounterfactualPackedInputs(
  original_input=(x, y, sample_weight),
  counterfactual_data=(original_x, counterfactual_x,
                       counterfactual_sample_weight)
)
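As a rough mental model for a single element, the nesting can be mirrored with plain Python named tuples. `OriginalInput`, `CounterfactualTriple`, and `PackedInputsSketch` below are hypothetical stand-ins, not the library's `CounterfactualPackedInputs` class. In the real object, both fields hold datasets rather than single values. The label value is a placeholder.

```python
from collections import namedtuple

# Hypothetical stand-ins that mirror the nesting of CounterfactualPackedInputs
# for one element. They are NOT the tensorflow_model_remediation classes.
OriginalInput = namedtuple("OriginalInput", ["x", "y", "sample_weight"])
CounterfactualTriple = namedtuple(
    "CounterfactualTriple",
    ["original_x", "counterfactual_x", "counterfactual_sample_weight"])
PackedInputsSketch = namedtuple(
    "PackedInputsSketch", ["original_input", "counterfactual_data"])

packed = PackedInputsSketch(
    original_input=OriginalInput(x="I am a gay man", y=0, sample_weight=1.0),
    counterfactual_data=CounterfactualTriple(
        original_x="I am a gay man",
        counterfactual_x="I am a man",
        counterfactual_sample_weight=1.0),
)

# Unpacking follows the same order as the real structure.
x, y, sample_weight = packed.original_input
original_x, counterfactual_x, cf_weight = packed.counterfactual_data
print(counterfactual_x)  # I am a man
```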

The original_input should be the original dataset used to train your Keras model. counterfactual_data should be a tf.data.Dataset containing the original x value, the corresponding counterfactual_x value, and the counterfactual_sample_weight. The counterfactual_x value is nearly identical to the original value, but with one or more of the sensitive attributes removed or replaced. This dataset is used to pair the loss between the original value and the counterfactual value, with the goal of ensuring that the model's prediction doesn't change when the sensitive attribute is different. original_input and counterfactual_data need to have the same shape; you can duplicate values in counterfactual_data so that it has the same number of elements as original_input.
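The pairing idea can be illustrated in plain Python: the model's logit for each original_x is compared with its logit for the matching counterfactual_x, and the gap is penalized, weighted by counterfactual_sample_weight. The squared-difference penalty, the function name `paired_logit_penalty`, and the logit values are illustrative assumptions, not the library's implementation.

```python
def paired_logit_penalty(original_logits, counterfactual_logits, weights):
    """Weighted mean squared gap between paired logits (illustrative only)."""
    if not (len(original_logits) == len(counterfactual_logits) == len(weights)):
        raise ValueError("paired inputs must have the same length")
    total_weight = sum(weights)
    if total_weight == 0:
        return 0.0
    weighted = sum(
        w * (o - c) ** 2
        for o, c, w in zip(original_logits, counterfactual_logits, weights))
    return weighted / total_weight

# Identical logits give no penalty; a gap of 1.0 on one of two pairs gives 0.5.
print(paired_logit_penalty([2.0, -1.0], [2.0, -1.0], [1.0, 1.0]))  # 0.0
print(paired_logit_penalty([2.0, -1.0], [1.0, -1.0], [1.0, 1.0]))  # 0.5
```

Minimizing a term like this alongside the usual task loss pushes the model toward giving the same prediction whether or not the identity term is present.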

Properties of counterfactual_data:

  • Every original_x value references an identity group
  • Each counterfactual_x value is identical to its original_x value, but with one or more of the attributes removed or replaced
  • It has the same shape as original_input (you can duplicate values so that the shapes match)

counterfactual_data does not need to:

  • Overlap with the data in original_input
  • Have ground truth labels
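The properties above can be spot-checked with a small script before training. The helpers `check_counterfactual_pairs` and `resize_by_cycling` are hypothetical. The whole-word identity check is a simplification, and cycling is only one way to duplicate values; it is not how pack_counterfactual_data is implemented.

```python
import itertools

def check_counterfactual_pairs(pairs, identity_terms):
    """Return a list of problems found in (original_x, counterfactual_x) pairs."""
    problems = []
    for original_x, counterfactual_x in pairs:
        words = original_x.lower().split()
        if not any(term in words for term in identity_terms):
            problems.append(f"no identity term in {original_x!r}")
        if original_x == counterfactual_x:
            problems.append(f"counterfactual unchanged for {original_x!r}")
    return problems

def resize_by_cycling(items, target_length):
    """Repeat items in order until there are exactly target_length of them."""
    if not items:
        raise ValueError("cannot resize an empty list")
    return list(itertools.islice(itertools.cycle(items), target_length))

pairs = [("I am a gay man", "I am a man")]
print(check_counterfactual_pairs(pairs, ["gay"]))  # []
print(len(resize_by_cycling(pairs, 5)))            # 5
```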

Here’s an example of what counterfactual_data would look like if you remove the term "gay":

original_x: I am a gay man
counterfactual_x: I am a man 
counterfactual_sample_weight: 1
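The example above can be reproduced with Python's standard `re` module. This is one reasonable approach, not what build_counterfactual_data does internally. The helper `remove_term` is hypothetical. It removes whole-word matches, then collapses the leftover double space so the result matches the tidy "I am a man" shown above. The library's regex removal, shown later in this tutorial, keeps the double space.

```python
import re

def remove_term(text, term):
    """Drop whole-word occurrences of term, then collapse repeated spaces."""
    pattern = r"\b" + re.escape(term) + r"\b"
    without_term = re.sub(pattern, "", text)
    return " ".join(without_term.split())

print(remove_term("I am a gay man", "gay"))  # I am a man
```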

If you have a text classifier, you can use build_counterfactual_data to help create a counterfactual dataset. For all other data types, you need to provide a counterfactual dataset directly.
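For non-text features, you assemble the counterfactual triples yourself. The sketch below assumes a hypothetical tabular record in which `sex` is the sensitive feature. Each counterfactual copy flips only that field and leaves every other field untouched. The function name, field names, and values are illustrative, and converting the result into a tf.data.Dataset is not shown.

```python
def make_counterfactual_rows(rows, feature, swap):
    """Return (original_x, counterfactual_x, weight) triples for rows whose
    sensitive feature has an entry in swap; other rows are skipped."""
    triples = []
    for row in rows:
        value = row.get(feature)
        if value not in swap:
            continue
        counterfactual = dict(row)  # copy so the original row is not modified
        counterfactual[feature] = swap[value]
        triples.append((row, counterfactual, 1.0))
    return triples

rows = [
    {"age": 34, "sex": "male", "hours_per_week": 40},
    {"age": 29, "sex": "female", "hours_per_week": 38},
    {"age": 51, "sex": None, "hours_per_week": 45},
]
triples = make_counterfactual_rows(
    rows, "sex", {"male": "female", "female": "male"})
for original, counterfactual, weight in triples:
    print(original["sex"], "->", counterfactual["sex"], weight)
```

Rows without a recognized value for the sensitive feature are skipped. This mirrors the requirement that every original_x references an identity group.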

Setup

You'll begin by installing TensorFlow Model Remediation.

pip install --upgrade tensorflow-model-remediation
import tensorflow as tf
from tensorflow_model_remediation import counterfactual

Create a simple Dataset

For demonstration purposes, you'll create counterfactual data from the original input using build_counterfactual_data. Note that you can also construct counterfactual data from unlabeled data rather than from the original input. The original_input is a simple dataset of 20 sentences: ten variations of "I am a gay man" and ten of "I am a man", each with a numeric suffix. It is built in the Option 1 code below.

Build a Counterfactual Dataset

As this is a text classifier, you can create the counterfactual dataset with build_counterfactual_data in two ways:

  1. Remove terms: Pass build_counterfactual_data a list of words that will be removed from the dataset via tf.strings.regex_replace.
  2. Replace terms: Create a custom function to pass to build_counterfactual_data. This might include using more specific regular expressions to replace words within your original dataset, or supporting non-text features.

build_counterfactual_data takes in original_input and either removes or replaces terms, depending on which optional parameters you pass. In most cases, removing terms (option 1) should be sufficient to run CLP. Passing a custom function (option 2) gives you more precise control over the counterfactual values.
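To make the two modes concrete, here is a pure-Python model of the behavior this tutorial describes. It keeps only the examples that contain a sensitive term. It then either strips those terms or passes each kept example to a custom function. `build_pairs` is a hypothetical name. The real build_counterfactual_data works on a tf.data.Dataset, and its details may differ.

```python
import re

def build_pairs(originals, sensitive_terms, custom_fn=None):
    """Filter to examples mentioning a sensitive term, then build pairs."""
    pattern = re.compile("|".join(re.escape(t) for t in sensitive_terms))
    kept = [x for x in originals if pattern.search(x)]
    if custom_fn is None:
        # Option 1: delete the matched terms (plain substring regex).
        return [(x, pattern.sub("", x)) for x in kept]
    # Option 2: the caller decides how each counterfactual is produced.
    return [(x, custom_fn(x)) for x in kept]

originals = ([f"I am a gay man{i}" for i in range(10)] +
             [f"I am a man{i}" for i in range(10)])

removed = build_pairs(originals, ["gay"])
print(len(originals), len(removed))  # 20 10
print(removed[0])                    # ('I am a gay man0', 'I am a  man0')

swapped = build_pairs(originals, ["gay"],
                      lambda x: re.sub(r"\bman", "woman", x))
print(swapped[0][1])                 # I am a gay woman0
```

The filtering step explains why the dataset shrinks from 20 to 10 elements in both options.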

Option 1: List of Words to Remove

Pass a list of identity terms to remove to build_counterfactual_data.

When using simple regular expressions to create the counterfactual dataset, keep in mind that they may alter words that shouldn’t be changed. It is good practice to check that the changes made to each counterfactual_x value make sense in the context of its original_x value. Additionally, build_counterfactual_data returns only the values that include a counterfactual instance. This can result in a dataset with a different shape from original_input, but it will be resized when passed to pack_counterfactual_data.

# Build 20 example sentences: 10 mention "gay" and 10 do not.
simple_dataset_x = tf.constant(
    ["I am a gay man" + str(i) for i in range(10)] +
    ["I am a man" + str(i) for i in range(10)])
print("Length of starting values: " + str(len(simple_dataset_x)))

# Only x is provided; the label (y) and sample_weight slots are left as None.
simple_dataset = tf.data.Dataset.from_tensor_slices(
    (simple_dataset_x, None, None))

# Remove the term "gay" to create the counterfactual values.
counterfactual_data = counterfactual.keras.utils.build_counterfactual_data(
    original_input=simple_dataset,
    sensitive_terms_to_remove=['gay'])

# Inspect the content of the TF Counterfactual Dataset
for original_value, counterfactual_value, _ in counterfactual_data.take(1):
  print("original: ", original_value)
  print("counterfactual: ", counterfactual_value)
print("Length of dataset after build_counterfactual_data: " +
      str(len(list(counterfactual_data))))
Length of starting values: 20
original:  tf.Tensor(b'I am a gay man0', shape=(), dtype=string)
counterfactual:  tf.Tensor(b'I am a  man0', shape=(), dtype=string)
Length of dataset after build_counterfactual_data: 10
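As noted for Option 1, simple regular expressions can alter words that shouldn't change. A quick way to see the risk is to compare a bare pattern with a word-boundary pattern on text that contains the term inside longer words. The helper `flag_partial_matches` is hypothetical. It lists words that contain a term without being equal to it, which are the spots worth reviewing by hand. Be aware that a trailing boundary would also block matches such as `man0` in this tutorial's synthetic data, so choose boundaries that fit your text.

```python
import re

text = "I am a woman and a manager"

# A bare pattern also edits longer words that merely contain "man".
print(re.sub("man", "", text))       # I am a wo and a ager
# Word boundaries on both sides restrict the edit to the standalone word.
print(re.sub(r"\bman\b", "", text))  # I am a woman and a manager

def flag_partial_matches(texts, terms):
    """Return (text, word) pairs where a word contains a term but is not it."""
    flagged = []
    for t in texts:
        for word in t.split():
            for term in terms:
                if term in word.lower() and word.lower() != term:
                    flagged.append((t, word))
    return flagged

print(flag_partial_matches([text], ["man"]))
```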

Option 2: Custom Function

For more flexibility around ways of modifying your original dataset, you can instead pass a custom function to build_counterfactual_data.

In this example, you replace identity terms that reference men with terms that reference women. You can do this with a function that replaces each key in a dictionary of words with its value.

The custom function has only two requirements. First, it must be a callable that accepts and returns a tuple in the format used by Model.fit. Second, values that do not include any changes should be removed, which you can do by also passing the terms to sensitive_terms_to_remove.

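# Map each term to its replacement.
words_to_replace = {"man": "woman"}
print("Length of starting values: " + str(len(simple_dataset_x)))

def replace_words(original_batch):
  original_x, _, original_sample_weight = (
      tf.keras.utils.unpack_x_y_sample_weight(original_batch))
  # Apply every replacement in turn, starting from the original text, so that
  # all entries in words_to_replace take effect (not just the last one).
  counterfactual_x = original_x
  for word, replacement in words_to_replace.items():
    # The leading \b stops "man" inside "woman" from being replaced again.
    counterfactual_x = tf.strings.regex_replace(
        counterfactual_x, rf'\b{word}', replacement)
  return tf.keras.utils.pack_x_y_sample_weight(
      original_x, counterfactual_x, sample_weight=original_sample_weight)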

counterfactual_data = counterfactual.keras.utils.build_counterfactual_data(
    original_input=simple_dataset,
    sensitive_terms_to_remove=['gay'],
    custom_counterfactual_function=replace_words)

# Inspect the content of the TF Counterfactual Dataset
for original_value, counterfactual_value in counterfactual_data.take(1):
  print("original: ", original_value)
  print("counterfactual: ", counterfactual_value)
print("Length of dataset after build_counterfactual_data: " +
      str(len(list(counterfactual_data))))
Length of starting values: 20
original:  tf.Tensor(b'I am a gay man0', shape=(), dtype=string)
counterfactual:  tf.Tensor(b'I am a gay man0', shape=(), dtype=string)
Length of dataset after build_counterfactual_data: 10
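When words_to_replace holds more than one entry, and especially when it swaps terms in both directions, applying the replacements one after another can undo earlier edits: "man" becomes "woman", then "woman" becomes "man" again. A single pass with one alternation pattern and a dictionary lookup avoids this. The sketch below uses Python's `re` module and a hypothetical `swap_terms` helper. Inside a tf.data pipeline, you would need an equivalent built from tf.strings ops, which this sketch does not show.

```python
import re

def swap_terms(text, mapping):
    """Replace every whole-word key of mapping with its value in one pass."""
    # Longer keys first so that, e.g., "woman" is tried before "man".
    keys = sorted(mapping, key=len, reverse=True)
    pattern = re.compile(r"\b(" + "|".join(map(re.escape, keys)) + r")\b")
    return pattern.sub(lambda m: mapping[m.group(1)], text)

mapping = {"man": "woman", "woman": "man", "he": "she", "she": "he"}

# Sequential replacement: each later swap reverses an earlier one.
naive = "he is a man"
for old, new in mapping.items():
    naive = re.sub(r"\b" + re.escape(old) + r"\b", new, naive)
print(naive)  # he is a man

print(swap_terms("he is a man and she is a woman", mapping))
# she is a woman and he is a man
```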

To learn more, see the API documentation for build_counterfactual_data.