संघीय शिक्षण अनुसंधान के लिए TFF: मॉडल और अद्यतन संपीड़न

TensorFlow.org पर देखें Google Colab में चलाएं GitHub पर स्रोत देखें नोटबुक डाउनलोड करें

इस ट्यूटोरियल में, हम का उपयोग EMNIST संघीय औसत का उपयोग करते हुए एल्गोरिथ्म में संचार लागत को कम करने हानिपूर्ण संपीड़न एल्गोरिदम सक्षम करने के लिए कैसे प्रदर्शित करने के लिए डाटासेट tff.learning.build_federated_averaging_process एपीआई और tensor_encoding एपीआई। संघीय औसत का कलन विधि के बारे में अधिक जानकारी के लिए, कागज को देखने के विकेन्द्रीकृत डाटा से दीप नेटवर्क के संचार कुशल सीखना

हमारे शुरू करने से पहले

शुरू करने से पहले, कृपया यह सुनिश्चित करने के लिए निम्नलिखित चलाएँ कि आपका परिवेश सही ढंग से सेटअप है। आप एक ग्रीटिंग दिखाई नहीं देता है, का संदर्भ लें स्थापना निर्देश के लिए गाइड।

!pip install --quiet --upgrade tensorflow-federated-nightly
!pip install --quiet --upgrade tensorflow-model-optimization
!pip install --quiet --upgrade nest-asyncio

import nest_asyncio
nest_asyncio.apply()
%load_ext tensorboard

import functools

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

from tensorflow_model_optimization.python.core.internal import tensor_encoding as te

सत्यापित करें कि क्या TFF काम कर रहा है।

@tff.federated_computation
def hello_world():
  return 'Hello, World!'

hello_world()
b'Hello, World!'

इनपुट डेटा तैयार करना

इस खंड में हम TFF में शामिल EMNIST डेटासेट को लोड और प्रीप्रोसेस करते हैं। कृपया पहले, के लिए छवि वर्गीकरण संघीय लर्निंग EMNIST डाटासेट के बारे में अधिक जानकारी के लिए ट्यूटोरियल।

# This value only applies to EMNIST dataset, consider choosing appropriate
# values if switching to other datasets.
MAX_CLIENT_DATASET_SIZE = 418

CLIENT_EPOCHS_PER_ROUND = 1
CLIENT_BATCH_SIZE = 20
TEST_BATCH_SIZE = 500

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data(
    only_digits=True)

def reshape_emnist_element(element):
  return (tf.expand_dims(element['pixels'], axis=-1), element['label'])

def preprocess_train_dataset(dataset):
  """Preprocessing function for the EMNIST training dataset."""
  return (dataset
          # Shuffle according to the largest client dataset
          .shuffle(buffer_size=MAX_CLIENT_DATASET_SIZE)
          # Repeat to do multiple local epochs
          .repeat(CLIENT_EPOCHS_PER_ROUND)
          # Batch to a fixed client batch size
          .batch(CLIENT_BATCH_SIZE, drop_remainder=False)
          # Preprocessing step
          .map(reshape_emnist_element))

emnist_train = emnist_train.preprocess(preprocess_train_dataset)

एक मॉडल को परिभाषित करना

यहाँ हम orginial FedAvg सीएनएन के आधार पर एक keras मॉडल को परिभाषित है, और फिर का एक उदाहरण में keras मॉडल लपेट tff.learning.Model इतना है कि यह TFF द्वारा उपयोग किया जा।

ध्यान दें कि हम एक समारोह जो एक मॉडल के बजाय सिर्फ एक मॉडल के सीधे पैदा करता है की आवश्यकता होगी। इसके अलावा, समारोह सिर्फ एक पूर्व निर्माण मॉडल पर कब्जा नहीं कर सकते हैं, यह संदर्भ में मॉडल है कि यह कहा जाता है बनाना होगा। इसका कारण यह है कि TFF को उपकरणों पर जाने के लिए डिज़ाइन किया गया है, और संसाधनों का निर्माण कब किया जाता है, इस पर नियंत्रण की आवश्यकता होती है ताकि उन्हें कैप्चर और पैक किया जा सके।

def create_original_fedavg_cnn_model(only_digits=True):
  """The CNN model used in https://arxiv.org/abs/1602.05629."""
  data_format = 'channels_last'

  max_pool = functools.partial(
      tf.keras.layers.MaxPooling2D,
      pool_size=(2, 2),
      padding='same',
      data_format=data_format)
  conv2d = functools.partial(
      tf.keras.layers.Conv2D,
      kernel_size=5,
      padding='same',
      data_format=data_format,
      activation=tf.nn.relu)

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(28, 28, 1)),
      conv2d(filters=32),
      max_pool(),
      conv2d(filters=64),
      max_pool(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(512, activation=tf.nn.relu),
      tf.keras.layers.Dense(10 if only_digits else 62),
      tf.keras.layers.Softmax(),
  ])

  return model

# Gets the type information of the input data. TFF is a strongly typed
# functional programming framework, and needs type information about inputs to 
# the model.
input_spec = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0]).element_spec

def tff_model_fn():
  keras_model = create_original_fedavg_cnn_model()
  return tff.learning.from_keras_model(
      keras_model=keras_model,
      input_spec=input_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

मॉडल का प्रशिक्षण और प्रशिक्षण मेट्रिक्स का उत्पादन

अब हम एक फ़ेडरेटेड एवरेजिंग एल्गोरिथम बनाने और EMNIST डेटासेट पर परिभाषित मॉडल को प्रशिक्षित करने के लिए तैयार हैं।

पहले हम का उपयोग कर एक संघीय औसत का कलन विधि का निर्माण करने की जरूरत है tff.learning.build_federated_averaging_process एपीआई।

federated_averaging = tff.learning.build_federated_averaging_process(
    model_fn=tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))

अब फेडरेटेड एवरेजिंग एल्गोरिथम चलाते हैं। TFF के दृष्टिकोण से एक फ़ेडरेटेड लर्निंग एल्गोरिथम का निष्पादन इस तरह दिखता है:

  1. एल्गोरिथम को इनिशियलाइज़ करें और इनबिल्ट सर्वर स्टेट प्राप्त करें। सर्वर स्थिति में एल्गोरिथम को निष्पादित करने के लिए आवश्यक जानकारी होती है। याद रखें, चूंकि टीएफएफ कार्यात्मक है, इस राज्य में एल्गोरिदम का उपयोग करने वाले किसी भी अनुकूलक राज्य (जैसे गति शर्तों) के साथ-साथ मॉडल पैरामीटर भी शामिल हैं - इन्हें तर्क के रूप में पारित किया जाएगा और टीएफएफ गणनाओं से परिणाम के रूप में लौटाया जाएगा।
  2. एल्गोरिथम को गोल-गोल निष्पादित करें। प्रत्येक दौर में, एक नया सर्वर स्थिति लौटा दी जाएगी क्योंकि प्रत्येक क्लाइंट अपने डेटा पर मॉडल को प्रशिक्षण देता है। आमतौर पर एक दौर में:
    1. सर्वर ने सभी भाग लेने वाले ग्राहकों को मॉडल प्रसारित किया।
    2. प्रत्येक ग्राहक मॉडल और अपने स्वयं के डेटा के आधार पर कार्य करता है।
    3. सर्वर सभी मॉडल को एक अलग स्थिति उत्पन्न करने के लिए एकत्रित करता है जिसमें एक नया मॉडल होता है।

अधिक जानकारी के लिए, कृपया देखें लागू संघीय औसत का: कस्टम संघीय एल्गोरिदम, भाग 2 ट्यूटोरियल।

प्रशिक्षण के बाद प्रदर्शित करने के लिए प्रशिक्षण मेट्रिक्स को Tensorboard निर्देशिका में लिखा जाता है।

लोड उपयोगिता कार्य

def train(federated_averaging_process, num_rounds, num_clients_per_round, summary_writer):
  """Trains the federated averaging process and output metrics."""
  # Create a environment to get communication cost.
  environment = set_sizing_environment()

  # Initialize the Federated Averaging algorithm to get the initial server state.
  state = federated_averaging_process.initialize()

  with summary_writer.as_default():
    for round_num in range(num_rounds):
      # Sample the clients parcitipated in this round.
      sampled_clients = np.random.choice(
          emnist_train.client_ids,
          size=num_clients_per_round,
          replace=False)
      # Create a list of `tf.Dataset` instances from the data of sampled clients.
      sampled_train_data = [
          emnist_train.create_tf_dataset_for_client(client)
          for client in sampled_clients
      ]
      # Round one round of the algorithm based on the server state and client data
      # and output the new state and metrics.
      state, metrics = federated_averaging_process.next(state, sampled_train_data)

      # For more about size_info, please see https://www.tensorflow.org/federated/api_docs/python/tff/framework/SizeInfo
      size_info = environment.get_size_info()
      broadcasted_bits = size_info.broadcast_bits[-1]
      aggregated_bits = size_info.aggregate_bits[-1]

      print('round {:2d}, metrics={}, broadcasted_bits={}, aggregated_bits={}'.format(round_num, metrics, format_size(broadcasted_bits), format_size(aggregated_bits)))

      # Add metrics to Tensorboard.
      for name, value in metrics['train'].items():
          tf.summary.scalar(name, value, step=round_num)

      # Add broadcasted and aggregated data size to Tensorboard.
      tf.summary.scalar('cumulative_broadcasted_bits', broadcasted_bits, step=round_num)
      tf.summary.scalar('cumulative_aggregated_bits', aggregated_bits, step=round_num)
      summary_writer.flush()
# Clean the log directory to avoid conflicts.
try:
  tf.io.gfile.rmtree('/tmp/logs/scalars')
except tf.errors.OpError as e:
  pass  # Path doesn't exist

# Set up the log directory and writer for Tensorboard.
logdir = "/tmp/logs/scalars/original/"
summary_writer = tf.summary.create_file_writer(logdir)

train(federated_averaging_process=federated_averaging, num_rounds=10,
      num_clients_per_round=10, summary_writer=summary_writer)
round  0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07383774), ('loss', 2.3276227)])), ('stat', OrderedDict([('num_examples', 1097)]))]), broadcasted_bits=507.62Mibit, aggregated_bits=507.62Mibit
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.099585064), ('loss', 2.3152695)])), ('stat', OrderedDict([('num_examples', 964)]))]), broadcasted_bits=1015.24Mibit, aggregated_bits=1015.24Mibit
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09760766), ('loss', 2.3077576)])), ('stat', OrderedDict([('num_examples', 1045)]))]), broadcasted_bits=1.49Gibit, aggregated_bits=1.49Gibit
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.0963035), ('loss', 2.3066626)])), ('stat', OrderedDict([('num_examples', 1028)]))]), broadcasted_bits=1.98Gibit, aggregated_bits=1.98Gibit
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10694184), ('loss', 2.3033001)])), ('stat', OrderedDict([('num_examples', 1066)]))]), broadcasted_bits=2.48Gibit, aggregated_bits=2.48Gibit
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.1185567), ('loss', 2.2999184)])), ('stat', OrderedDict([('num_examples', 970)]))]), broadcasted_bits=2.97Gibit, aggregated_bits=2.97Gibit
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.11751663), ('loss', 2.296883)])), ('stat', OrderedDict([('num_examples', 902)]))]), broadcasted_bits=3.47Gibit, aggregated_bits=3.47Gibit
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13063477), ('loss', 2.2990246)])), ('stat', OrderedDict([('num_examples', 1087)]))]), broadcasted_bits=3.97Gibit, aggregated_bits=3.97Gibit
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12742382), ('loss', 2.2971866)])), ('stat', OrderedDict([('num_examples', 1083)]))]), broadcasted_bits=4.46Gibit, aggregated_bits=4.46Gibit
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13555992), ('loss', 2.2934425)])), ('stat', OrderedDict([('num_examples', 1018)]))]), broadcasted_bits=4.96Gibit, aggregated_bits=4.96Gibit

प्रशिक्षण मेट्रिक्स प्रदर्शित करने के लिए ऊपर निर्दिष्ट रूट लॉग निर्देशिका के साथ TensorBoard प्रारंभ करें। डेटा लोड होने में कुछ सेकंड लग सकते हैं। हानि और सटीकता को छोड़कर, हम प्रसारित और एकत्रित डेटा की मात्रा का भी उत्पादन करते हैं। ब्रॉडकास्टेड डेटा टेंसर को संदर्भित करता है जिसे सर्वर प्रत्येक क्लाइंट को पुश करता है जबकि एग्रीगेट डेटा टेनर्स को संदर्भित करता है जो प्रत्येक क्लाइंट सर्वर पर लौटता है।

%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard...
Reusing TensorBoard on port 34445 (pid 579503), started 1:53:14 ago. (Use '!kill 579503' to kill it.)
<IPython.core.display.Javascript at 0x7f9135ef1630>

एक कस्टम प्रसारण और समग्र कार्य बनाएँ

अब हम प्रसारित डेटा और समेकित डेटा का उपयोग करने पर हानिपूर्ण संपीड़न एल्गोरिदम का उपयोग करने के लिए समारोह को लागू करते हैं tensor_encoding एपीआई।

सबसे पहले, हम दो कार्यों को परिभाषित करते हैं:

  • broadcast_encoder_fn जो का एक उदाहरण बनाता te.core.SimpleEncoder ग्राहक संचार (ब्रॉडकास्ट डेटा) के लिए सर्वर में एनकोड tensors या चर करने के लिए।
  • mean_encoder_fn जो का एक उदाहरण बनाता te.core.GatherEncoder सर्वर communicaiton (एकत्रीकरण डेटा) के लिए ग्राहक में एनकोड tensors या चर करने के लिए।

यह ध्यान रखना महत्वपूर्ण है कि हम एक बार में पूरे मॉडल पर एक संपीड़न विधि लागू नहीं करते हैं। इसके बजाय, हम तय करते हैं कि कैसे (और क्या) मॉडल के प्रत्येक चर को स्वतंत्र रूप से संपीड़ित किया जाए। इसका कारण यह है कि आम तौर पर, पूर्वाग्रह जैसे छोटे चर अशुद्धि के प्रति अधिक संवेदनशील होते हैं, और अपेक्षाकृत छोटा होने के कारण, संभावित संचार बचत भी अपेक्षाकृत कम होती है। इसलिए हम डिफ़ॉल्ट रूप से छोटे चर को संपीड़ित नहीं करते हैं। इस उदाहरण में, हम 10000 से अधिक तत्वों के साथ प्रत्येक चर के लिए 8 बिट्स (256 बाल्टी) पर एक समान परिमाणीकरण लागू करते हैं, और केवल अन्य चर के लिए पहचान लागू करते हैं।

def broadcast_encoder_fn(value):
  """Function for building encoded broadcast."""
  spec = tf.TensorSpec(value.shape, value.dtype)
  if value.shape.num_elements() > 10000:
    return te.encoders.as_simple_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_simple_encoder(te.encoders.identity(), spec)


def mean_encoder_fn(tensor_spec):
  """Function for building a GatherEncoder."""
  spec = tf.TensorSpec(tensor_spec.shape, tensor_spec.dtype)
  if tensor_spec.shape.num_elements() > 10000:
    return te.encoders.as_gather_encoder(
        te.encoders.uniform_quantization(bits=8), spec)
  else:
    return te.encoders.as_gather_encoder(te.encoders.identity(), spec)

TFF एक प्रारूप में एनकोडर समारोह कन्वर्ट करने के लिए है कि API प्रदान करता है tff.learning.build_federated_averaging_process एपीआई उपभोग कर सकते हैं। का उपयोग करके tff.learning.framework.build_encoded_broadcast_from_model और tff.aggregators.MeanFactory , हम दो वस्तुओं है कि में पारित किया जा सकता बना सकते हैं broadcast_process और model_update_aggregation_factory की agruments tff.learning.build_federated_averaging_process एक हानिपूर्ण संपीड़न एल्गोरिथ्म के साथ एक संघीय के औसत एल्गोरिदम बनाने के लिए।

encoded_broadcast_process = (
    tff.learning.framework.build_encoded_broadcast_process_from_model(
        tff_model_fn, broadcast_encoder_fn))

mean_factory = tff.aggregators.MeanFactory(
    tff.aggregators.EncodedSumFactory(mean_encoder_fn), # numerator
    tff.aggregators.EncodedSumFactory(mean_encoder_fn), # denominator
)

federated_averaging_with_compression = tff.learning.build_federated_averaging_process(
    tff_model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0),
    broadcast_process=encoded_broadcast_process,
    model_update_aggregation_factory=mean_factory)

मॉडल को फिर से प्रशिक्षण

अब चलिए नया फेडरेटेड एवरेजिंग एल्गोरिथम चलाते हैं।

logdir_for_compression = "/tmp/logs/scalars/compression/"
summary_writer_for_compression = tf.summary.create_file_writer(
    logdir_for_compression)

train(federated_averaging_process=federated_averaging_with_compression, 
      num_rounds=10,
      num_clients_per_round=10,
      summary_writer=summary_writer_for_compression)
round  0, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.093), ('loss', 2.3194966)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=146.46Mibit, aggregated_bits=146.46Mibit
round  1, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.10432034), ('loss', 2.3079953)])), ('stat', OrderedDict([('num_examples', 949)]))]), broadcasted_bits=292.92Mibit, aggregated_bits=292.93Mibit
round  2, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.07886754), ('loss', 2.3101337)])), ('stat', OrderedDict([('num_examples', 989)]))]), broadcasted_bits=439.38Mibit, aggregated_bits=439.39Mibit
round  3, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09774436), ('loss', 2.305069)])), ('stat', OrderedDict([('num_examples', 1064)]))]), broadcasted_bits=585.84Mibit, aggregated_bits=585.85Mibit
round  4, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09404097), ('loss', 2.302943)])), ('stat', OrderedDict([('num_examples', 1074)]))]), broadcasted_bits=732.30Mibit, aggregated_bits=732.32Mibit
round  5, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.09), ('loss', 2.304385)])), ('stat', OrderedDict([('num_examples', 1000)]))]), broadcasted_bits=878.77Mibit, aggregated_bits=878.78Mibit
round  6, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.14368932), ('loss', 2.2973824)])), ('stat', OrderedDict([('num_examples', 1030)]))]), broadcasted_bits=1.00Gibit, aggregated_bits=1.00Gibit
round  7, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.12140871), ('loss', 2.2993405)])), ('stat', OrderedDict([('num_examples', 1079)]))]), broadcasted_bits=1.14Gibit, aggregated_bits=1.14Gibit
round  8, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13600783), ('loss', 2.2953267)])), ('stat', OrderedDict([('num_examples', 1022)]))]), broadcasted_bits=1.29Gibit, aggregated_bits=1.29Gibit
round  9, metrics=OrderedDict([('broadcast', ()), ('aggregation', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('train', OrderedDict([('sparse_categorical_accuracy', 0.13844621), ('loss', 2.295768)])), ('stat', OrderedDict([('num_examples', 1004)]))]), broadcasted_bits=1.43Gibit, aggregated_bits=1.43Gibit

दो रन के बीच प्रशिक्षण मेट्रिक्स की तुलना करने के लिए फिर से TensorBoard प्रारंभ करें।

आप Tensorboard में देख सकते हैं, वहाँ के बीच एक महत्वपूर्ण कमी है orginial और compression में घटता broadcasted_bits और aggregated_bits भूखंडों समय में loss और sparse_categorical_accuracy भूखंड दो घटता बहुत समान हैं।

अंत में, हमने एक कम्प्रेशन एल्गोरिथम लागू किया जो कि ओरिजिनल फेडरेटेड एवरेजिंग एल्गोरिथम के समान प्रदर्शन प्राप्त कर सकता है, जबकि कम्युनिकेशन लागत काफी कम हो जाती है।

%tensorboard --logdir /tmp/logs/scalars/ --port=0
Launching TensorBoard...
Reusing TensorBoard on port 34445 (pid 579503), started 1:54:12 ago. (Use '!kill 579503' to kill it.)
<IPython.core.display.Javascript at 0x7f9140eb5ef0>

अभ्यास

एक कस्टम संपीड़न एल्गोरिथ्म को लागू करने और इसे प्रशिक्षण लूप पर लागू करने के लिए, आप यह कर सकते हैं:

  1. का एक उपवर्ग के रूप में एक नया संपीड़न एल्गोरिथ्म लागू EncodingStageInterface या उसके अधिक सामान्य संस्करण है, AdaptiveEncodingStageInterface निम्नलिखित इस उदाहरण
  2. अपने नए निर्माण Encoder और इसके लिए विशेषज्ञ मॉडल प्रसारण या मॉडल अद्यतन औसत
  3. पूरे निर्माण करने के लिए उन वस्तुओं का प्रयोग करें प्रशिक्षण गणना

संभावित रूप से मूल्यवान खुले शोध प्रश्नों में शामिल हैं: गैर-समान परिमाणीकरण, हफ़मैन कोडिंग जैसे दोषरहित संपीड़न, और पिछले प्रशिक्षण दौर की जानकारी के आधार पर संपीड़न को अपनाने के लिए तंत्र।

अनुशंसित पठन सामग्री: