Giới thiệu
Trong hầu hết các trường hợp, sử dụng MinDiffModel
trực tiếp như mô tả trong "Lồng ghép MinDiff với MinDiffModel" hướng dẫn là đủ. Tuy nhiên, có thể bạn sẽ cần hành vi tùy chỉnh. Hai lý do chính cho điều này là:
- Các
keras.Model
bạn đang sử dụng có hành vi tùy chỉnh mà bạn muốn giữ. - Bạn muốn
MinDiffModel
cư xử khác so với mặc định.
Trong cả hai trường hợp, bạn sẽ cần phải phân lớp MinDiffModel
để đạt được kết quả mong muốn.
Thành lập
pip install -q --upgrade tensorflow-model-remediation
import tensorflow as tf
tf.get_logger().setLevel('ERROR') # Avoid TF warnings.
from tensorflow_model_remediation import min_diff
from tensorflow_model_remediation.tools.tutorials_utils import uci as tutorials_utils
Đầu tiên, tải xuống dữ liệu. Đối với tính cô đọng, logic chuẩn bị đầu vào đã được yếu tố ra thành các hàm helper như mô tả trong hướng dẫn chuẩn bị đầu vào . Bạn có thể đọc hướng dẫn đầy đủ để biết chi tiết về quy trình này.
# Original Dataset for training, sampled at 0.3 for reduced runtimes.
train_df = tutorials_utils.get_uci_data(split='train', sample=0.3)
train_ds = tutorials_utils.df_to_dataset(train_df, batch_size=128)
# Dataset needed to train with MinDiff.
train_with_min_diff_ds = (
tutorials_utils.get_uci_with_min_diff_dataset(split='train', sample=0.3))
Bảo tồn các tùy chỉnh của mô hình gốc
tf.keras.Model
được thiết kế để dễ dàng tùy chỉnh thông qua subclassing như đã mô tả ở đây . Nếu mô hình của bạn đã tùy chỉnh triển khai mà bạn muốn giữ khi áp dụng MinDiff, bạn sẽ cần phải phân lớp MinDiffModel
.
Mô hình tùy chỉnh ban đầu
Để xem làm thế nào bạn có thể giữ gìn tùy chỉnh, tạo ra một mô hình tùy chỉnh mà bộ một thuộc tính để True
khi tùy chỉnh của nó train_step
được gọi. Đây không phải là một tùy chỉnh hữu ích nhưng sẽ phục vụ để minh họa hành vi.
class CustomModel(tf.keras.Model):
# Customized train_step
def train_step(self, *args, **kwargs):
self.used_custom_train_step = True # Marker that we can check for.
return super(CustomModel, self).train_step(*args, **kwargs)
Đào tạo một mô hình như vậy sẽ trông giống như một bình thường Sequential
mô hình.
model = tutorials_utils.get_uci_model(model_class=CustomModel) # Use CustomModel.
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(train_ds.take(1), epochs=1, verbose=0)
# Model has used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step')) # True
Model used the custom train_step: True
Phân lớp MinDiffModel
Nếu bạn đã thử và sử dụng MinDiffModel
trực tiếp, mô hình sẽ không sử dụng tùy chỉnh train_step
.
model = tutorials_utils.get_uci_model(model_class=CustomModel)
model = min_diff.keras.MinDiffModel(model, min_diff.losses.MMDLoss())
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(train_with_min_diff_ds.take(1), epochs=1, verbose=0)
# Model has not used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step')) # False
Model used the custom train_step: False
Để sử dụng đúng train_step
phương pháp, bạn cần một lớp tùy chỉnh mà lớp con cả MinDiffModel
và CustomModel
.
class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):
pass # No need for any further implementation.
Đào tạo mô hình này sẽ sử dụng train_step
từ CustomModel
.
model = tutorials_utils.get_uci_model(model_class=CustomModel)
model = CustomMinDiffModel(model, min_diff.losses.MMDLoss())
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(train_with_min_diff_ds.take(1), epochs=1, verbose=0)
# Model has used the custom train_step.
print('Model used the custom train_step:')
print(hasattr(model, 'used_custom_train_step')) # True
Model used the custom train_step: True
Tùy hành vi mặc định của MinDiffModel
Trong trường hợp khác, bạn có thể muốn thay đổi hành vi mặc định cụ thể của MinDiffModel
. Các trường hợp sử dụng phổ biến nhất của việc này đang thay đổi mặc định các hành vi giải nén để xử lý đúng đắn dữ liệu của bạn nếu bạn không sử dụng pack_min_diff_data
.
Khi đóng gói dữ liệu thành một định dạng tùy chỉnh, điều này có thể xuất hiện như sau.
def _reformat_input(inputs, original_labels):
min_diff_data = min_diff.keras.utils.unpack_min_diff_data(inputs)
original_inputs = min_diff.keras.utils.unpack_original_inputs(inputs)
return ({
'min_diff_data': min_diff_data,
'original_inputs': original_inputs}, original_labels)
customized_train_with_min_diff_ds = train_with_min_diff_ds.map(_reformat_input)
Các customized_train_with_min_diff_ds
bộ dữ liệu trở lại lô gồm tuples (x, y)
mà x
là một dict chứa min_diff_data
và original_inputs
và y
là original_labels
.
for x, _ in customized_train_with_min_diff_ds.take(1):
print('Type of x:', type(x)) # dict
print('Keys of x:', x.keys()) # 'min_diff_data', 'original_inputs'
Type of x: <class 'dict'> Keys of x: dict_keys(['min_diff_data', 'original_inputs'])
Định dạng dữ liệu này không phải là điều MinDiffModel
hy vọng theo mặc định và đi qua customized_train_with_min_diff_ds
để nó sẽ dẫn đến hành vi bất ngờ. Để khắc phục điều này, bạn sẽ cần tạo lớp con của riêng mình.
class CustomUnpackingMinDiffModel(min_diff.keras.MinDiffModel):
def unpack_min_diff_data(self, inputs):
return inputs['min_diff_data']
def unpack_original_inputs(self, inputs):
return inputs['original_inputs']
Với lớp con này, bạn có thể huấn luyện như với các ví dụ khác.
model = tutorials_utils.get_uci_model()
model = CustomUnpackingMinDiffModel(model, min_diff.losses.MMDLoss())
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(customized_train_with_min_diff_ds, epochs=1)
77/77 [==============================] - 4s 30ms/step - loss: 0.6690 - min_diff_loss: 0.0395
Hạn chế của một Customized MinDiffModel
Tạo một tùy chỉnh MinDiffModel
cung cấp một số lượng lớn linh hoạt đối với trường hợp sử dụng phức tạp hơn. Tuy nhiên, vẫn có một số trường hợp cạnh mà nó sẽ không hỗ trợ.
Tiền xử lý hoặc Xác nhận của đầu vào trước khi call
Hạn chế lớn nhất đối với một lớp con của MinDiffModel
là nó đòi hỏi sự x
thành phần của dữ liệu đầu vào (tức là phần tử đầu tiên hay duy nhất trong hàng loạt được trả về bởi các tf.data.Dataset
) để được đi qua mà không cần tiền xử lý hoặc xác nhận để call
.
Điều này đơn giản là vì min_diff_data
được đóng gói vào x
thành phần của dữ liệu đầu vào. Bất kỳ tiền xử lý hoặc xác nhận sẽ không mong đợi cơ cấu thêm chứa min_diff_data
và có khả năng sẽ phá vỡ.
Nếu quá trình tiền xử lý hoặc xác thực có thể dễ dàng tùy chỉnh (ví dụ: biến yếu tố thành phương thức riêng của nó) thì điều này có thể dễ dàng giải quyết bằng cách ghi đè nó để đảm bảo nó xử lý cấu trúc bổ sung một cách chính xác.
Một ví dụ với xác thực có thể trông như thế này:
class CustomMinDiffModel(min_diff.keras.MinDiffModel, CustomModel):
# Override so that it correctly handles additional `min_diff_data`.
def validate_inputs(self, inputs):
original_inputs = self.unpack_original_inputs(inputs)
... # Optionally also validate min_diff_data
# Call original validate method with correct inputs
return super(CustomMinDiffModel, self).validate(original_inputs)
Nếu tiền xử lý hoặc xác nhận không phải là dễ dàng tùy biến, sau đó sử dụng MinDiffModel
có thể không làm việc cho bạn và bạn sẽ cần phải tích hợp MinDiff mà không có nó như mô tả trong hướng dẫn này .
Xung đột tên phương pháp
Có thể là mô hình của bạn có phương pháp có tên xung đột với những người thực hiện trong MinDiffModel
(xem danh sách đầy đủ các phương pháp nào trong tài liệu API ).
Điều này chỉ có vấn đề nếu chúng sẽ được gọi trên một phiên bản của mô hình (thay vì nội bộ trong một số phương thức khác). Trong khi rất khó, nếu bạn đang ở trong tình huống này bạn sẽ phải hoặc là ghi đè và đổi tên một số phương pháp hay, nếu không thể, bạn có thể cần phải xem xét tích hợp MinDiff mà không MinDiffModel
như mô tả trong hướng dẫn này về đề tài này .
Tài nguyên bổ sung
- Đối với một trong cuộc thảo luận sâu về đánh giá công bằng xem các hướng dẫn chỉ số Công bằng
- Để biết thông tin chung về Xử lý ô nhiễm và MinDiff, xem tổng quan về khắc phục hậu quả .
- Để biết chi tiết về các yêu cầu xung quanh MinDiff thấy hướng dẫn này .
- Để xem một end-to-end hướng dẫn về sử dụng MinDiff trong Keras, xem hướng dẫn này .