Graph regularization for document classification using natural graphs


Overview

Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.
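To make the objective concrete, here is a minimal, hedged sketch (illustrative only, not the NSL implementation; all names are made up) of what a graph-regularized loss looks like for a single neighbor per sample:

import tensorflow as tf

def graph_regularized_loss(labels, sample_logits, nbr_logits, nbr_weights,
                           multiplier=0.1):
  """Illustrative sketch of a graph-regularized objective (not the NSL code)."""
  # Supervised term: ordinary cross-entropy on the labeled sample.
  supervised = tf.keras.losses.sparse_categorical_crossentropy(
      labels, sample_logits, from_logits=True)
  # Graph term: squared L2 distance between the sample's logits and its
  # neighbor's logits, scaled by the edge weight from the graph.
  graph_penalty = nbr_weights * tf.reduce_sum(
      tf.square(sample_logits - nbr_logits), axis=-1)
  return tf.reduce_mean(supervised) + multiplier * tf.reduce_mean(graph_penalty)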

In this tutorial, we explore the use of graph regularization to classify documents that form a natural (organic) graph.

The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:

  1. Generate training data from the input graph and sample features. Nodes in the graph correspond to samples and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
  2. Create a neural network as a base model using the Keras sequential, functional, or subclass API.
  3. Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
  4. Train and evaluate the graph Keras model.

Setup

Install the Neural Structured Learning package.

pip install --quiet neural-structured-learning

Dependencies and imports

import neural_structured_learning as nsl

import tensorflow as tf

# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version:  2.8.0-rc0
Eager mode:  True
GPU is NOT AVAILABLE
2022-01-05 12:39:27.704660: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Cora dataset

The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task involved is document classification, where the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.

Graph

The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B as having cited A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a commutative property.
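As an aside, symmetrizing the citation edges is simple to do by hand; the preprocessing script used later handles it for us. A minimal sketch, assuming each line of cora.cites holds a whitespace-separated pair of paper IDs (cited, citing):

def load_undirected_edges(cites_path='/tmp/cora/cora.cites'):
  """Builds an undirected edge set from the directed citation list."""
  edges = set()
  with open(cites_path) as f:
    for line in f:
      a, b = line.split()
      # Add both directions so "A cites B" also counts as "B cites A".
      edges.add((a, b))
      edges.add((b, a))
  return edges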

Features

Each paper in the input effectively contains 2 features:

  1. Words: A dense, multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0/1, indicating whether word 'i' in the vocabulary exists in the given paper or not (a toy encoding example follows this list).

  2. Label: A single integer representing the class ID (category) of the paper.
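As a toy illustration of the multi-hot encoding in feature 1, using a made-up 5-word vocabulary instead of the real 1433-word one:

vocabulary = ['learning', 'graph', 'neural', 'network', 'citation']
paper_words = {'graph', 'citation'}  # words appearing in a hypothetical paper
multi_hot = [1 if word in paper_words else 0 for word in vocabulary]
print(multi_hot)  # [0, 1, 0, 0, 1]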

Download the Cora dataset

wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/
cora/README
cora/cora.cites
cora/cora.content

Convert the Cora data to the NSL format

In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL github repository. This script does the following:

  1. Generate neighbor features using the original node features and the graph.
  2. Generate train and test data splits containing tf.train.Example instances.
  3. Persist the resulting train and test data in the TFRecord format (a sketch of the merged example format follows this list).
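The following sketch shows the rough shape of a merged training example: the seed sample's features plus its neighbors' features under 'NL_nbr_<i>_...' keys, the naming convention used later in this tutorial. The values here are made up and much shorter than the real 1433-dimensional word vectors:

import tensorflow as tf

# Illustrative only: rough shape of a merged example with one neighbor.
merged_example = tf.train.Example(features=tf.train.Features(feature={
    'words': tf.train.Feature(int64_list=tf.train.Int64List(value=[0, 1, 0])),
    'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[3])),
    'NL_nbr_0_words': tf.train.Feature(
        int64_list=tf.train.Int64List(value=[1, 1, 0])),
    'NL_nbr_0_weight': tf.train.Feature(
        float_list=tf.train.FloatList(value=[1.0])),
}))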
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py

!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2022-01-05 12:39:28--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>]  11.37K  --.-KB/s    in 0s      

2022-01-05 12:39:28 (78.9 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]

2022-01-05 12:39:31.378912: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.01 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.36 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.04 minutes.

Global variables

The file paths to the train and test data are based on the command line flag values used to invoke the 'preprocess_cora_dataset.py' script above.

### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'

### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'

Hyperparameters

We will use an instance of HParams to hold various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:

  • num_classes: There are a total of 7 different classes.

  • max_seq_length: This is the size of the vocabulary; all instances in the input have a dense, multi-hot bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input and a value of 0 indicates that it is not.

  • distance_type: This is the distance metric used to regularize the sample with its neighbors.

  • graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function.

  • num_neighbors: The number of neighbors used for graph regularization. This value has to be less than or equal to the max_nbrs command-line argument used above when running preprocess_cora_dataset.py.

  • num_fc_units: The number of units in each fully connected layer of our neural network.

  • train_epochs: The number of training epochs.

  • batch_size: Batch size used for training and evaluation.

  • dropout_rate: Controls the rate of dropout following each fully connected layer.

  • eval_steps: The number of batches to process before deeming evaluation complete. If set to None, all instances in the test set are evaluated.

class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()

Load train and test data

As described earlier in this notebook, the input training and test data were created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects -- one for train and one for test.

In the input layer of our model, we will extract not just the 'words' and 'label' features from each sample, but also the corresponding neighbor features based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for the non-existent neighbor features.

def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)

Let's peek into the train dataset to look at its contents.

for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 1 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[2 2 6 2 0 6 1 3 5 0 1 2 3 6 1 1 0 3 5 2 3 1 4 1 6 1 3 2 2 2 0 3 2 1 3 3 2
 3 3 2 3 2 2 0 2 2 6 0 2 1 1 0 5 2 1 4 2 1 2 4 0 2 5 4 3 6 3 2 1 6 2 4 2 2
 6 4 6 4 3 5 2 2 2 4 2 2 2 1 2 2 2 4 2 3 6 2 0 6 6 0 2 6 2 1 2 0 1 1 3 2 0
 2 0 2 1 1 3 5 2 1 2 5 1 6 2 4 6 4], shape=(128,), dtype=int64)

Let's peek into the test dataset to look at its contents.

for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)
Feature list: ['words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of labels: tf.Tensor(
[5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2
 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5
 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6
 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)

Model definition

To demonstrate the use of graph regularization, we first build a base model for this problem. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.Keras framework -- sequential, functional, and subclass.

Sequential base model

def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes))
  return model

Functional base model

def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(hparams.num_classes)(cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model

Subclass base model

def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(hparams.num_classes)

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()

Create base model(s)

# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 words (InputLayer)          [(None, 1433)]            0         
                                                                 
 lambda (Lambda)             (None, 1433)              0         
                                                                 
 dense (Dense)               (None, 50)                71700     
                                                                 
 dropout (Dropout)           (None, 50)                0         
                                                                 
 dense_1 (Dense)             (None, 50)                2550      
                                                                 
 dropout_1 (Dropout)         (None, 50)                0         
                                                                 
 dense_2 (Dense)             (None, 7)                 357       
                                                                 
=================================================================
Total params: 74,607
Trainable params: 74,607
Non-trainable params: 0
_________________________________________________________________

Train the base MLP model

# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/functional.py:559: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model.
  inputs = self._flatten_to_reference_inputs(inputs)
17/17 [==============================] - 1s 18ms/step - loss: 1.9521 - accuracy: 0.1838
Epoch 2/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8590 - accuracy: 0.3044
Epoch 3/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7770 - accuracy: 0.3601
Epoch 4/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6655 - accuracy: 0.3898
Epoch 5/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5386 - accuracy: 0.4543
Epoch 6/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3856 - accuracy: 0.5077
Epoch 7/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2736 - accuracy: 0.5531
Epoch 8/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1636 - accuracy: 0.5889
Epoch 9/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0654 - accuracy: 0.6385
Epoch 10/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9703 - accuracy: 0.6761
Epoch 11/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8689 - accuracy: 0.7104
Epoch 12/100
17/17 [==============================] - 0s 3ms/step - loss: 0.7704 - accuracy: 0.7494
Epoch 13/100
17/17 [==============================] - 0s 3ms/step - loss: 0.7157 - accuracy: 0.7810
Epoch 14/100
17/17 [==============================] - 0s 3ms/step - loss: 0.6296 - accuracy: 0.8186
Epoch 15/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5932 - accuracy: 0.8167
Epoch 16/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5526 - accuracy: 0.8464
Epoch 17/100
17/17 [==============================] - 0s 3ms/step - loss: 0.5112 - accuracy: 0.8445
Epoch 18/100
17/17 [==============================] - 0s 3ms/step - loss: 0.4624 - accuracy: 0.8613
Epoch 19/100
17/17 [==============================] - 0s 3ms/step - loss: 0.4163 - accuracy: 0.8696
Epoch 20/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3808 - accuracy: 0.8849
Epoch 21/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3564 - accuracy: 0.8933
Epoch 22/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3453 - accuracy: 0.9002
Epoch 23/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3226 - accuracy: 0.9114
Epoch 24/100
17/17 [==============================] - 0s 3ms/step - loss: 0.3058 - accuracy: 0.9151
Epoch 25/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2798 - accuracy: 0.9146
Epoch 26/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2638 - accuracy: 0.9248
Epoch 27/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2538 - accuracy: 0.9290
Epoch 28/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2356 - accuracy: 0.9411
Epoch 29/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2080 - accuracy: 0.9425
Epoch 30/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2172 - accuracy: 0.9364
Epoch 31/100
17/17 [==============================] - 0s 3ms/step - loss: 0.2259 - accuracy: 0.9225
Epoch 32/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1944 - accuracy: 0.9480
Epoch 33/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1892 - accuracy: 0.9434
Epoch 34/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1718 - accuracy: 0.9592
Epoch 35/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1826 - accuracy: 0.9508
Epoch 36/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1585 - accuracy: 0.9559
Epoch 37/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1605 - accuracy: 0.9545
Epoch 38/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1529 - accuracy: 0.9550
Epoch 39/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1411 - accuracy: 0.9615
Epoch 40/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1366 - accuracy: 0.9624
Epoch 41/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1431 - accuracy: 0.9578
Epoch 42/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1241 - accuracy: 0.9619
Epoch 43/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1310 - accuracy: 0.9661
Epoch 44/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1284 - accuracy: 0.9652
Epoch 45/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1215 - accuracy: 0.9633
Epoch 46/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1130 - accuracy: 0.9722
Epoch 47/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1074 - accuracy: 0.9722
Epoch 48/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1143 - accuracy: 0.9694
Epoch 49/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1015 - accuracy: 0.9740
Epoch 50/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1077 - accuracy: 0.9698
Epoch 51/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1035 - accuracy: 0.9684
Epoch 52/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1076 - accuracy: 0.9694
Epoch 53/100
17/17 [==============================] - 0s 3ms/step - loss: 0.1000 - accuracy: 0.9689
Epoch 54/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0967 - accuracy: 0.9749
Epoch 55/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0994 - accuracy: 0.9703
Epoch 56/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0943 - accuracy: 0.9740
Epoch 57/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0923 - accuracy: 0.9735
Epoch 58/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0848 - accuracy: 0.9800
Epoch 59/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0836 - accuracy: 0.9782
Epoch 60/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0913 - accuracy: 0.9735
Epoch 61/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0823 - accuracy: 0.9773
Epoch 62/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0753 - accuracy: 0.9810
Epoch 63/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0746 - accuracy: 0.9777
Epoch 64/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0861 - accuracy: 0.9731
Epoch 65/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0765 - accuracy: 0.9787
Epoch 66/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0750 - accuracy: 0.9791
Epoch 67/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0725 - accuracy: 0.9814
Epoch 68/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0762 - accuracy: 0.9791
Epoch 69/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0645 - accuracy: 0.9842
Epoch 70/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0606 - accuracy: 0.9861
Epoch 71/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0775 - accuracy: 0.9805
Epoch 72/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0655 - accuracy: 0.9800
Epoch 73/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0629 - accuracy: 0.9833
Epoch 74/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0625 - accuracy: 0.9824
Epoch 75/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0607 - accuracy: 0.9838
Epoch 76/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0578 - accuracy: 0.9824
Epoch 77/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0568 - accuracy: 0.9842
Epoch 78/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0595 - accuracy: 0.9833
Epoch 79/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0615 - accuracy: 0.9842
Epoch 80/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0555 - accuracy: 0.9852
Epoch 81/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0517 - accuracy: 0.9870
Epoch 82/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0541 - accuracy: 0.9856
Epoch 83/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0533 - accuracy: 0.9884
Epoch 84/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0509 - accuracy: 0.9838
Epoch 85/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0600 - accuracy: 0.9828
Epoch 86/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0617 - accuracy: 0.9800
Epoch 87/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0599 - accuracy: 0.9800
Epoch 88/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0502 - accuracy: 0.9870
Epoch 89/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0416 - accuracy: 0.9907
Epoch 90/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0542 - accuracy: 0.9842
Epoch 91/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0490 - accuracy: 0.9847
Epoch 92/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0374 - accuracy: 0.9916
Epoch 93/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0467 - accuracy: 0.9893
Epoch 94/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0426 - accuracy: 0.9879
Epoch 95/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0543 - accuracy: 0.9861
Epoch 96/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0420 - accuracy: 0.9870
Epoch 97/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0461 - accuracy: 0.9861
Epoch 98/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0425 - accuracy: 0.9898
Epoch 99/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0406 - accuracy: 0.9907
Epoch 100/100
17/17 [==============================] - 0s 3ms/step - loss: 0.0486 - accuracy: 0.9847
<keras.callbacks.History at 0x7f6f9d5eacd0>

Evaluate the base MLP model

# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.4192 - accuracy: 0.7939


Eval accuracy for  Base MLP model :  0.7938517332077026
Eval loss for  Base MLP model :  1.4192423820495605

Train the MLP model with graph regularization

Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model, whose loss includes graph regularization.

To assess the incremental benefit of graph regularization, we will create a new base model instance. This is because base_model has already been trained for a few iterations, and reusing this trained model to create a graph-regularized model would not be a fair comparison with base_model.

# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape:0", shape=(None, 7), dtype=float32), dense_shape=Tensor("gradient_tape/GraphRegularization/graph_loss/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory.
  "shape. This may consume a large amount of memory." % value)
17/17 [==============================] - 2s 4ms/step - loss: 1.9798 - accuracy: 0.1601 - scaled_graph_loss: 0.0373
Epoch 2/100
17/17 [==============================] - 0s 3ms/step - loss: 1.9024 - accuracy: 0.2979 - scaled_graph_loss: 0.0254
Epoch 3/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8623 - accuracy: 0.3160 - scaled_graph_loss: 0.0317
Epoch 4/100
17/17 [==============================] - 0s 3ms/step - loss: 1.8042 - accuracy: 0.3443 - scaled_graph_loss: 0.0498
Epoch 5/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7552 - accuracy: 0.3582 - scaled_graph_loss: 0.0696
Epoch 6/100
17/17 [==============================] - 0s 3ms/step - loss: 1.7012 - accuracy: 0.4084 - scaled_graph_loss: 0.0866
Epoch 7/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6578 - accuracy: 0.4515 - scaled_graph_loss: 0.1114
Epoch 8/100
17/17 [==============================] - 0s 3ms/step - loss: 1.6058 - accuracy: 0.5039 - scaled_graph_loss: 0.1300
Epoch 9/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5498 - accuracy: 0.5434 - scaled_graph_loss: 0.1508
Epoch 10/100
17/17 [==============================] - 0s 3ms/step - loss: 1.5098 - accuracy: 0.6019 - scaled_graph_loss: 0.1651
Epoch 11/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4746 - accuracy: 0.6302 - scaled_graph_loss: 0.1844
Epoch 12/100
17/17 [==============================] - 0s 3ms/step - loss: 1.4315 - accuracy: 0.6520 - scaled_graph_loss: 0.1917
Epoch 13/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3932 - accuracy: 0.6770 - scaled_graph_loss: 0.2024
Epoch 14/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3645 - accuracy: 0.7183 - scaled_graph_loss: 0.2145
Epoch 15/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3265 - accuracy: 0.7369 - scaled_graph_loss: 0.2324
Epoch 16/100
17/17 [==============================] - 0s 3ms/step - loss: 1.3045 - accuracy: 0.7555 - scaled_graph_loss: 0.2358
Epoch 17/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2836 - accuracy: 0.7652 - scaled_graph_loss: 0.2404
Epoch 18/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2456 - accuracy: 0.7898 - scaled_graph_loss: 0.2469
Epoch 19/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2348 - accuracy: 0.8074 - scaled_graph_loss: 0.2615
Epoch 20/100
17/17 [==============================] - 0s 3ms/step - loss: 1.2000 - accuracy: 0.8074 - scaled_graph_loss: 0.2542
Epoch 21/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1994 - accuracy: 0.8260 - scaled_graph_loss: 0.2729
Epoch 22/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1825 - accuracy: 0.8269 - scaled_graph_loss: 0.2676
Epoch 23/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1598 - accuracy: 0.8455 - scaled_graph_loss: 0.2742
Epoch 24/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1543 - accuracy: 0.8534 - scaled_graph_loss: 0.2797
Epoch 25/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1456 - accuracy: 0.8552 - scaled_graph_loss: 0.2714
Epoch 26/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8566 - scaled_graph_loss: 0.2796
Epoch 27/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1150 - accuracy: 0.8687 - scaled_graph_loss: 0.2850
Epoch 28/100
17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8626 - scaled_graph_loss: 0.2772
Epoch 29/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0806 - accuracy: 0.8733 - scaled_graph_loss: 0.2756
Epoch 30/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0828 - accuracy: 0.8626 - scaled_graph_loss: 0.2907
Epoch 31/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0724 - accuracy: 0.8886 - scaled_graph_loss: 0.2834
Epoch 32/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0589 - accuracy: 0.8826 - scaled_graph_loss: 0.2881
Epoch 33/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0490 - accuracy: 0.8872 - scaled_graph_loss: 0.2972
Epoch 34/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0550 - accuracy: 0.8923 - scaled_graph_loss: 0.2935
Epoch 35/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0397 - accuracy: 0.8840 - scaled_graph_loss: 0.2795
Epoch 36/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0360 - accuracy: 0.8891 - scaled_graph_loss: 0.2966
Epoch 37/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0235 - accuracy: 0.8961 - scaled_graph_loss: 0.2890
Epoch 38/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0219 - accuracy: 0.8984 - scaled_graph_loss: 0.2965
Epoch 39/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0168 - accuracy: 0.9044 - scaled_graph_loss: 0.3023
Epoch 40/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0148 - accuracy: 0.9035 - scaled_graph_loss: 0.2984
Epoch 41/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9118 - scaled_graph_loss: 0.2888
Epoch 42/100
17/17 [==============================] - 0s 3ms/step - loss: 1.0019 - accuracy: 0.9021 - scaled_graph_loss: 0.2877
Epoch 43/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9049 - scaled_graph_loss: 0.2912
Epoch 44/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9986 - accuracy: 0.9026 - scaled_graph_loss: 0.3040
Epoch 45/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9939 - accuracy: 0.9067 - scaled_graph_loss: 0.3016
Epoch 46/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9828 - accuracy: 0.9058 - scaled_graph_loss: 0.2877
Epoch 47/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9629 - accuracy: 0.9137 - scaled_graph_loss: 0.2844
Epoch 48/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9645 - accuracy: 0.9146 - scaled_graph_loss: 0.2933
Epoch 49/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9752 - accuracy: 0.9165 - scaled_graph_loss: 0.3013
Epoch 50/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9552 - accuracy: 0.9179 - scaled_graph_loss: 0.2865
Epoch 51/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9539 - accuracy: 0.9193 - scaled_graph_loss: 0.3044
Epoch 52/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9443 - accuracy: 0.9183 - scaled_graph_loss: 0.3010
Epoch 53/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9559 - accuracy: 0.9244 - scaled_graph_loss: 0.2987
Epoch 54/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9497 - accuracy: 0.9225 - scaled_graph_loss: 0.2979
Epoch 55/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9674 - accuracy: 0.9183 - scaled_graph_loss: 0.3034
Epoch 56/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9537 - accuracy: 0.9174 - scaled_graph_loss: 0.2834
Epoch 57/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9341 - accuracy: 0.9188 - scaled_graph_loss: 0.2939
Epoch 58/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9392 - accuracy: 0.9225 - scaled_graph_loss: 0.2998
Epoch 59/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9240 - accuracy: 0.9313 - scaled_graph_loss: 0.3022
Epoch 60/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9368 - accuracy: 0.9267 - scaled_graph_loss: 0.2979
Epoch 61/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9306 - accuracy: 0.9234 - scaled_graph_loss: 0.2952
Epoch 62/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9197 - accuracy: 0.9230 - scaled_graph_loss: 0.2916
Epoch 63/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9360 - accuracy: 0.9206 - scaled_graph_loss: 0.2947
Epoch 64/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9181 - accuracy: 0.9299 - scaled_graph_loss: 0.2996
Epoch 65/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9105 - accuracy: 0.9341 - scaled_graph_loss: 0.2981
Epoch 66/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9014 - accuracy: 0.9323 - scaled_graph_loss: 0.2897
Epoch 67/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9059 - accuracy: 0.9364 - scaled_graph_loss: 0.3083
Epoch 68/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9053 - accuracy: 0.9309 - scaled_graph_loss: 0.2976
Epoch 69/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9099 - accuracy: 0.9258 - scaled_graph_loss: 0.3069
Epoch 70/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9025 - accuracy: 0.9355 - scaled_graph_loss: 0.2890
Epoch 71/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8849 - accuracy: 0.9281 - scaled_graph_loss: 0.2933
Epoch 72/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8959 - accuracy: 0.9323 - scaled_graph_loss: 0.2918
Epoch 73/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9074 - accuracy: 0.9248 - scaled_graph_loss: 0.3065
Epoch 74/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8845 - accuracy: 0.9369 - scaled_graph_loss: 0.2874
Epoch 75/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8873 - accuracy: 0.9401 - scaled_graph_loss: 0.2996
Epoch 76/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8942 - accuracy: 0.9327 - scaled_graph_loss: 0.3086
Epoch 77/100
17/17 [==============================] - 0s 3ms/step - loss: 0.9052 - accuracy: 0.9253 - scaled_graph_loss: 0.2986
Epoch 78/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8811 - accuracy: 0.9336 - scaled_graph_loss: 0.2948
Epoch 79/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8896 - accuracy: 0.9276 - scaled_graph_loss: 0.2919
Epoch 80/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8853 - accuracy: 0.9313 - scaled_graph_loss: 0.2944
Epoch 81/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8875 - accuracy: 0.9323 - scaled_graph_loss: 0.2925
Epoch 82/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8639 - accuracy: 0.9323 - scaled_graph_loss: 0.2967
Epoch 83/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8820 - accuracy: 0.9332 - scaled_graph_loss: 0.3047
Epoch 84/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8752 - accuracy: 0.9346 - scaled_graph_loss: 0.2942
Epoch 85/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9374 - scaled_graph_loss: 0.3066
Epoch 86/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8765 - accuracy: 0.9332 - scaled_graph_loss: 0.2881
Epoch 87/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8691 - accuracy: 0.9420 - scaled_graph_loss: 0.3030
Epoch 88/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8631 - accuracy: 0.9374 - scaled_graph_loss: 0.2916
Epoch 89/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9392 - scaled_graph_loss: 0.3032
Epoch 90/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8632 - accuracy: 0.9420 - scaled_graph_loss: 0.3019
Epoch 91/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8600 - accuracy: 0.9425 - scaled_graph_loss: 0.2965
Epoch 92/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8569 - accuracy: 0.9346 - scaled_graph_loss: 0.2977
Epoch 93/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8704 - accuracy: 0.9374 - scaled_graph_loss: 0.3083
Epoch 94/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8562 - accuracy: 0.9406 - scaled_graph_loss: 0.2883
Epoch 95/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8545 - accuracy: 0.9415 - scaled_graph_loss: 0.3030
Epoch 96/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8592 - accuracy: 0.9332 - scaled_graph_loss: 0.2927
Epoch 97/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8503 - accuracy: 0.9397 - scaled_graph_loss: 0.2927
Epoch 98/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8434 - accuracy: 0.9462 - scaled_graph_loss: 0.2937
Epoch 99/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8578 - accuracy: 0.9374 - scaled_graph_loss: 0.3064
Epoch 100/100
17/17 [==============================] - 0s 3ms/step - loss: 0.8504 - accuracy: 0.9411 - scaled_graph_loss: 0.3043
<keras.callbacks.History at 0x7f70041be650>

Evaluate the MLP model with graph regularization

eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 0.8884 - accuracy: 0.7957


Eval accuracy for  MLP + graph regularization :  0.7956600189208984
Eval loss for  MLP + graph regularization :  0.8883611559867859

The accuracy of the graph-regularized model is about 2-3% higher than that of the base model (base_model).

Conclusion

We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing a graph based on sample embeddings before training a neural network with graph regularization. That approach is useful if the input does not contain an explicit graph.
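For reference, graph synthesis in that setting uses NSL's graph-building tool. A minimal, hedged sketch, assuming a TFRecord file of tf.train.Example instances with 'id' and 'embedding' features (the paths and threshold below are illustrative):

import neural_structured_learning as nsl

# Build a similarity graph by comparing sample embeddings pairwise and keeping
# edges whose cosine similarity exceeds the given threshold.
nsl.tools.build_graph(['/tmp/cora/embeddings.tfr'],
                      '/tmp/cora/synthesized_graph.tsv',
                      similarity_threshold=0.8)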

We encourage users to experiment further by varying the amount of supervision as well as trying different neural architectures for graph regularization.