잘라내기 종합 가이드

}TensorFlow.org에서 보기 Google Colab에서 실행하기 GitHub에서소스 보기 노트북 다운로드하기

Keras 가중치 잘라내기에 대한 종합 가이드를 시작합니다.

이 페이지는 다양한 사용 사례를 문서화하고 각각에 대해 API를 사용하는 방법을 보여줍니다. 필요한 API를 알고 나면, API 문서에서 매개변수와 하위 수준의 세부 정보를 찾아보세요.

  • 잘라내기의 이점과 지원되는 기능을 보려면 개요를 참조하세요.
  • 단일 엔드 투 엔드 예는 잘라내기 예를 참조하세요.

다음 사용 사례를 다룹니다.

  • 잘라낸 모델을 정의하고 훈련합니다.
    • 순차 및 함수형
    • Keras model.fit 및 사용자 정의 훈련 루프
  • 잘라낸 모델을 체크포인트 지정하고 역직렬화합니다.
  • 잘라낸 모델을 배포하고 압축 이점을 확인합니다.

잘라내기 알고리즘의 구성에 대해서는 tfmot.sparsity.keras.prune_low_magnitude API 문서를 참조하세요.

설정

필요한 API를 찾고 목적을 이해하기 위해 실행할 수 있지만, 이 섹션은 건너뛸 수 있습니다.

! pip install -q tensorflow-model-optimization

import tensorflow as tf
import numpy as np
import tensorflow_model_optimization as tfmot

%load_ext tensorboard

import tempfile

input_shape = [20]
x_train = np.random.randn(1, 20).astype(np.float32)
y_train = tf.keras.utils.to_categorical(np.random.randn(1), num_classes=20)

def setup_model():
  model = tf.keras.Sequential([
      tf.keras.layers.Dense(20, input_shape=input_shape),
      tf.keras.layers.Flatten()
  ])
  return model

def setup_pretrained_weights():
  model = setup_model()

  model.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
  )

  model.fit(x_train, y_train)

  _, pretrained_weights = tempfile.mkstemp('.tf')

  model.save_weights(pretrained_weights)

  return pretrained_weights

def get_gzipped_model_size(model):
  # Returns size of gzipped model, in bytes.
  import os
  import zipfile

  _, keras_file = tempfile.mkstemp('.h5')
  model.save(keras_file, include_optimizer=False)

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(keras_file)

  return os.path.getsize(zipped_file)

setup_model()
pretrained_weights = setup_pretrained_weights()

모델 정의하기

전체 모델 잘라내기(순차 및 함수형)

모델 정확성의 향상을 위한 팁:

  • 정확성을 가장 많이 떨어뜨리는 레이어 잘라내기를 건너뛰려면 "일부 레이어 잘라내기"를 시도합니다.
  • 일반적으로 처음부터 훈련하는 것보다 잘라내기로 미세 조정하는 것이 좋습니다.

잘라내기로 전체 모델을 훈련하려면, tfmot.sparsity.keras.prune_low_magnitude를 모델에 적용합니다.

base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended.

model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

model_for_pruning.summary()
WARNING:tensorflow:From /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py:200: Layer.add_variable (from tensorflow.python.keras.engine.base_layer) is deprecated and will be removed in a future version.
Instructions for updating:
Please use `layer.add_weight` method instead.
Model: "sequential_2"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_2  (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________

일부 레이어 잘라내기(순차 및 함수형)

모델을 잘라내면 정확성에 부정적인 영향을 미칠 수 있습니다. 모델의 레이어를 선택적으로 잘라내어 정확성, 속도 및 모델 크기 간의 균형을 탐색할 수 있습니다.

모델 정확성의 향상을 위한 팁:

  • 일반적으로 처음부터 훈련하는 것보다 잘라내기로 미세 조정하는 것이 좋습니다.
  • 첫 번째 레이어 대신 이후 레이어를 잘라냅니다.
  • 중요 레이어(예: attention 메커니즘)을 잘라내지 마세요.

추가 자료:

  • tfmot.sparsity.keras.prune_low_magnitude API 문서는 레이어별로 잘라내기 구성을 변경하는 방법에 대한 세부 정보를 제공합니다.

아래 예에서는 Dense 레이어만 잘라냅니다.

# Create a base model
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy

# Helper function uses `prune_low_magnitude` to make only the 
# Dense layers train with pruning.
def apply_pruning_to_dense(layer):
  if isinstance(layer, tf.keras.layers.Dense):
    return tfmot.sparsity.keras.prune_low_magnitude(layer)
  return layer

# Use `tf.keras.models.clone_model` to apply `apply_pruning_to_dense` 
# to the layers of the model.
model_for_pruning = tf.keras.models.clone_model(
    base_model,
    clone_function=apply_pruning_to_dense,
)

model_for_pruning.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_3  (None, 20)                822       
_________________________________________________________________
flatten_3 (Flatten)          (None, 20)                0         
=================================================================
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________

이 예에서는 레이어 유형을 사용하여 잘라낼 레이어를 결정했지만, 특정 레이어를 잘라내는 가장 쉬운 방법은 name 속성을 설정하고 clone_function에서 해당 내용을 찾는 것입니다.

print(base_model.layers[0].name)
dense_3

읽기 더 쉽지만 잠재적으로 모델 정확성이 낮음

잘라내기를 사용한 미세 조정과 호환되지 않으므로 미세 조정을 지원하는 위의 예보다 정확성이 떨어질 수 있습니다.

초기 모델을 정의하는 동안 prune_low_magnitude를 적용할 수 있지만, 이후에 가중치를 로드하면 아래 예에서 동작하지 않습니다.

함수형 예

# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
i = tf.keras.Input(shape=(20,))
x = tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(10))(i)
o = tf.keras.layers.Flatten()(x)
model_for_pruning = tf.keras.Model(inputs=i, outputs=o)

model_for_pruning.summary()
Model: "functional_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         [(None, 20)]              0         
_________________________________________________________________
prune_low_magnitude_dense_4  (None, 10)                412       
_________________________________________________________________
flatten_4 (Flatten)          (None, 10)                0         
=================================================================
Total params: 412
Trainable params: 210
Non-trainable params: 202
_________________________________________________________________

순차 예

# Use `prune_low_magnitude` to make the `Dense` layer train with pruning.
model_for_pruning = tf.keras.Sequential([
  tfmot.sparsity.keras.prune_low_magnitude(tf.keras.layers.Dense(20, input_shape=input_shape)),
  tf.keras.layers.Flatten()
])

model_for_pruning.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_5  (None, 20)                822       
_________________________________________________________________
flatten_5 (Flatten)          (None, 20)                0         
=================================================================
Total params: 822
Trainable params: 420
Non-trainable params: 402
_________________________________________________________________

사용자 정의 Keras 레이어를 잘라내거나 잘라낼 레이어의 일부를 수정합니다.

일반적인 실수: 바이어스를 제거하면 일반적으로 모델 정확성이 너무 많이 손상됩니다.

tfmot.sparsity.keras.PrunableLayer는 두 가지 사용 사례를 제공합니다.

  1. 사용자 정의 Keras 레이어를 잘라냅니다.
  2. 내장 Keras 레이어의 일부를 수정하여 잘라냅니다.

예를 들어, API는 기본적으로 Dense 레이어의 커널만 잘라냅니다. 아래의 예는 바이어스도 제거합니다.

class MyDenseLayer(tf.keras.layers.Dense, tfmot.sparsity.keras.PrunableLayer):

  def get_prunable_weights(self):
    # Prune bias also, though that usually harms model accuracy too much.
    return [self.kernel, self.bias]

# Use `prune_low_magnitude` to make the `MyDenseLayer` layer train with pruning.
model_for_pruning = tf.keras.Sequential([
  tfmot.sparsity.keras.prune_low_magnitude(MyDenseLayer(20, input_shape=input_shape)),
  tf.keras.layers.Flatten()
])

model_for_pruning.summary()
Model: "sequential_5"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_my_dense (None, 20)                843       
_________________________________________________________________
flatten_6 (Flatten)          (None, 20)                0         
=================================================================
Total params: 843
Trainable params: 420
Non-trainable params: 423
_________________________________________________________________

모델 훈련하기

Model.fit

훈련 중에 tfmot.sparsity.keras.UpdatePruningStep 콜백을 호출합니다.

훈련 디버깅에 tfmot.sparsity.keras.PruningSummaries 콜백을 사용합니다.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

log_dir = tempfile.mkdtemp()
callbacks = [
    tfmot.sparsity.keras.UpdatePruningStep(),
    # Log sparsity and other metrics in Tensorboard.
    tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir)
]

model_for_pruning.compile(
      loss=tf.keras.losses.categorical_crossentropy,
      optimizer='adam',
      metrics=['accuracy']
)

model_for_pruning.fit(
    x_train,
    y_train,
    callbacks=callbacks,
    epochs=2,
)

#docs_infra: no_execute
%tensorboard --logdir={log_dir}

Colab이 아닌 사용자의 경우, TensorBoard.dev에서 이 코드 블록의 이전 실행의 결과를 볼 수 있습니다.

사용자 정의 훈련 루프

훈련 중에 tfmot.sparsity.keras.UpdatePruningStep 콜백을 호출합니다.

To help debug training, use the tfmot.sparsity.keras.PruningSummaries callback.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Boilerplate
loss = tf.keras.losses.categorical_crossentropy
optimizer = tf.keras.optimizers.Adam()
log_dir = tempfile.mkdtemp()
unused_arg = -1
epochs = 2
batches = 1 # example is hardcoded so that the number of batches cannot change.

# Non-boilerplate.
model_for_pruning.optimizer = optimizer
step_callback = tfmot.sparsity.keras.UpdatePruningStep()
step_callback.set_model(model_for_pruning)
log_callback = tfmot.sparsity.keras.PruningSummaries(log_dir=log_dir) # Log sparsity and other metrics in Tensorboard.
log_callback.set_model(model_for_pruning)

step_callback.on_train_begin() # run pruning callback
for _ in range(epochs):
  log_callback.on_epoch_begin(epoch=unused_arg) # run pruning callback
  for _ in range(batches):
    step_callback.on_train_batch_begin(batch=unused_arg) # run pruning callback

    with tf.GradientTape() as tape:
      logits = model_for_pruning(x_train, training=True)
      loss_value = loss(y_train, logits)
      grads = tape.gradient(loss_value, model_for_pruning.trainable_variables)
      optimizer.apply_gradients(zip(grads, model_for_pruning.trainable_variables))

  step_callback.on_epoch_end(batch=unused_arg) # run pruning callback

#docs_infra: no_execute
%tensorboard --logdir={log_dir}

Colab이 아닌 사용자의 경우, TensorBoard.dev에서 이 코드 블록의 이전 실행의 결과를 볼 수 있습니다.

잘라낸 모델의 정확성 향상하기

먼저, tfmot.sparsity.keras.prune_low_magnitude API 문서를 보고 잘라내기 일정이 무엇인지, 그리고 각 잘라내기 일정 유형의 수학을 이해합니다.

:

  • 모델이 잘라내기를 수행할 때 학습률이 너무 높거나 낮지 않습니다. 잘라내기 일정을 하이퍼 매개변수로 간주합니다.

  • 빠른 테스트로, tfmot.sparsity.keras.ConstantSparsity 일정으로 begin_step을 0으로 설정하여 훈련 시작 시 모델을 최종 희소성까지 잘라내는 실험을 시도해 보세요. 운이 좋으면 우수한 결과를 얻을 수도 있습니다.

  • 모델이 복구할 시간을 주기 위해 자주 잘라내기를 수행하지 마세요. 잘라내기 일정에서 적절한 기본 빈도를 제공합니다.

  • 모델 정확성을 개선하기 위한 일반적인 아이디어는 '모델 정의하기'에서 사용 사례에 대한 팁을 찾아보세요.

체크포인트 및 역직렬화

체크포인트 중에 옵티마이저 단계를 보존해야 합니다. 즉, 체크포인트 지정을 위해 Keras HDF5 모델을 사용할 수 있지만, Keras HDF5 가중치는 사용할 수 없습니다.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

_, keras_model_file = tempfile.mkstemp('.h5')

# Checkpoint: saving the optimizer is necessary (include_optimizer=True is the default).
model_for_pruning.save(keras_model_file, include_optimizer=True)
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.

위의 코드가 일반적으로 적용됩니다. 아래 코드는 HDF5 모델 형식(HDF5 가중치 및 기타 형식이 아님)에만 필요합니다.

# Deserialize model.
with tfmot.sparsity.keras.prune_scope():
  loaded_model = tf.keras.models.load_model(keras_model_file)

loaded_model.summary()
WARNING:tensorflow:No training configuration found in the save file, so the model was *not* compiled. Compile it manually.
Model: "sequential_6"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_6  (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________

잘라낸 모델 배포하기

크기 압축으로 모델 내보내기

일반적인 실수: 잘라내기의 압축 이점을 확인하려면, strip_pruning과 표준 압축 알고리즘(예: gzip을 통해)을 적용하는 것이 모두 필요합니다.

# Define the model.
base_model = setup_model()
base_model.load_weights(pretrained_weights) # optional but recommended for model accuracy
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model)

# Typically you train the model here.

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

print("final model")
model_for_export.summary()

print("\n")
print("Size of gzipped pruned model without stripping: %.2f bytes" % (get_gzipped_model_size(model_for_pruning)))
print("Size of gzipped pruned model with stripping: %.2f bytes" % (get_gzipped_model_size(model_for_export)))
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
final model
Model: "sequential_7"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_7 (Dense)              (None, 20)                420       
_________________________________________________________________
flatten_8 (Flatten)          (None, 20)                0         
=================================================================
Total params: 420
Trainable params: 420
Non-trainable params: 0
_________________________________________________________________


Size of gzipped pruned model without stripping: 3291.00 bytes
Size of gzipped pruned model with stripping: 2866.00 bytes

하드웨어별 최적화

여러 백엔드에서 잘라내기를 사용하여 지연 시간을 개선하면, 블록 희소성을 사용하여 특정 하드웨어의 지연 시간을 개선할 수 있습니다.

블록 크기를 늘리면 대상 모델의 정확성에 대해 달성할 수 있는 최대 희소성이 감소합니다. 그럼에도 불구하고, 지연 시간은 여전히 개선될 수 있습니다.

블록 희소성에 지원되는 항목에 대한 자세한 내용은 tfmot.sparsity.keras.prune_low_magnitude API 문서를 참조하세요.

base_model = setup_model()

# For using intrinsics on a CPU with 128-bit registers, together with 8-bit
# quantized weights, a 1x16 block size is nice because the block perfectly
# fits into the register.
pruning_params = {'block_size': [1, 16]}
model_for_pruning = tfmot.sparsity.keras.prune_low_magnitude(base_model, **pruning_params)

model_for_pruning.summary()
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.iter
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_1
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.beta_2
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.decay
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer.learning_rate
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'm' for (root).layer_with_weights-0.bias
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.kernel
WARNING:tensorflow:Unresolved object in checkpoint: (root).optimizer's state 'v' for (root).layer_with_weights-0.bias
WARNING:tensorflow:A checkpoint was restored (e.g. tf.train.Checkpoint.restore or tf.keras.Model.load_weights) but not all checkpointed values were used. See above for specific issues. Use expect_partial() on the load status object, e.g. tf.train.Checkpoint.restore(...).expect_partial(), to silence these warnings, or use assert_consumed() to make the check explicit. See https://www.tensorflow.org/guide/checkpoint#loading_mechanics for details.
Model: "sequential_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
prune_low_magnitude_dense_8  (None, 20)                822       
_________________________________________________________________
prune_low_magnitude_flatten_ (None, 20)                1         
=================================================================
Total params: 823
Trainable params: 420
Non-trainable params: 403
_________________________________________________________________