Overview
Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.
In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.
The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:
- Generate training data from the input graph and sample features. Nodes in the graph correspond to samples and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
- Create a neural network as a base model using the Keras sequential, functional, or subclass API.
- Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as a regularization term in its training objective.
- Train and evaluate the graph Keras model.
Setup
Install the Neural Structured Learning package.
pip install --quiet neural-structured-learning
Dependencies and imports
import neural_structured_learning as nsl
import tensorflow as tf
# Resets notebook state
tf.keras.backend.clear_session()
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
"GPU is",
"available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version: 2.8.0-rc0 Eager mode: True GPU is NOT AVAILABLE 2022-01-05 12:39:27.704660: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Cora dataset
The Cora dataset is a citation graph in which nodes represent machine learning papers and edges represent citations between pairs of papers. The task involved is document classification, where the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.
Graph
The original graph is directed. However, for the purpose of this example, we consider the undirected version of the graph. So, if paper A cites paper B, we also consider paper B to have cited A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a commutative property.
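Conceptually, making the graph undirected just means adding the reverse of every edge. Below is a minimal sketch of that operation (an illustration only, not code from this tutorial; make_bidirectional is a hypothetical helper):
def make_bidirectional(edges):
  """Returns the undirected version of a directed edge list."""
  undirected = set()
  for src, dst in edges:
    undirected.add((src, dst))
    undirected.add((dst, src))  # If A cites B, consider B to also cite A.
  return undirected
print(make_bidirectional([('A', 'B'), ('B', 'C')]))
# {('A', 'B'), ('B', 'A'), ('B', 'C'), ('C', 'B')}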
Features
Each paper in the input effectively contains 2 features:
Words: A dense, multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0/1, indicating whether word 'i' in the vocabulary exists in the given paper or not (a tiny sketch of this encoding follows below).
Label: A single integer representing the class ID (category) of the paper.
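To make the 'words' feature concrete, here is a minimal sketch of how a multi-hot bag-of-words vector is derived, using a hypothetical 4-word vocabulary in place of Cora's 1433-word one:
# Illustration only: a tiny vocabulary standing in for Cora's 1433-word one.
vocabulary = ['graph', 'neural', 'network', 'loss']
document_words = {'neural', 'loss'}  # Words appearing in a given paper.
# Position i is 1 if word i of the vocabulary occurs in the paper, else 0.
multi_hot = [1 if word in document_words else 0 for word in vocabulary]
print(multi_hot)  # [0, 1, 0, 1]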
Download the Cora dataset
wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/ cora/README cora/cora.cites cora/cora.content
Convert the Cora data to the NSL format
In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL github repository. This script does the following:
- Generate neighbor features using the original node features and the graph.
- Generate train and test data splits containing tf.train.Example instances (the merged-example format is sketched after this list).
- Persist the resulting train and test data in the TFRecord format.
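As a rough illustration of that merged-example format (an assumption based on the neighbor feature keys we will see later in this notebook, namely 'NL_nbr_<i>_<feature>' and 'NL_nbr_<i>_weight'), a merged tf.train.Example with one neighbor could be constructed like this:
import tensorflow as tf
def make_merged_example(words, label, nbr_words, nbr_weight):
  """Builds a tf.train.Example carrying one neighbor's features."""
  features = {
      # The sample's own features.
      'words': tf.train.Feature(int64_list=tf.train.Int64List(value=words)),
      'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),
      # Neighbor features follow the 'NL_nbr_<i>_...' naming convention.
      'NL_nbr_0_words': tf.train.Feature(
          int64_list=tf.train.Int64List(value=nbr_words)),
      'NL_nbr_0_weight': tf.train.Feature(
          float_list=tf.train.FloatList(value=[nbr_weight])),
  }
  return tf.train.Example(features=tf.train.Features(feature=features))
merged = make_merged_example(
    words=[0, 1, 0], label=3, nbr_words=[1, 1, 0], nbr_weight=1.0)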
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2022-01-05 12:39:28-- https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 11640 (11K) [text/plain] Saving to: ‘preprocess_cora_dataset.py’ preprocess_cora_dat 100%[===================>] 11.37K --.-KB/s in 0s 2022-01-05 12:39:28 (78.9 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640] 2022-01-05 12:39:31.378912: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected Reading graph file: /tmp/cora/cora.cites... Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds). Making all edges bi-directional... Done (0.01 seconds). Total graph nodes: 2708 Joining seed and neighbor tf.train.Examples with graph edges... Done creating and writing 2155 merged tf.train.Examples (1.36 seconds). Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)] Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr. Output test data written to TFRecord file: /tmp/cora/test_examples.tfr. Total running time: 0.04 minutes.
Global variables
The file paths to the train and test data are based on the command line flag values used to invoke the 'preprocess_cora_dataset.py' script above.
### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'
### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
Hyperparameters
We will use an instance of HParams to include various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:
num_classes: There are a total of 7 different classes.
max_seq_length: This is the size of the vocabulary, and all instances in the input have a dense, multi-hot bag-of-words representation. In other words, a value of 1 for a word indicates that the word is present in the input, and a value of 0 indicates that it is not.
distance_type: This is the distance metric used to regularize the sample with its neighbors.
graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function (a simplified sketch of its effect follows this list).
num_neighbors: The number of neighbors used for graph regularization. This value has to be less than or equal to the max_nbrs command-line argument used above when running preprocess_cora_dataset.py.
num_fc_units: The number of fully connected layers in our neural network.
train_epochs: The number of training epochs.
batch_size: Batch size used for training and evaluation.
dropout_rate: Controls the rate of dropout following each fully connected layer.
eval_steps: The number of batches to process before deeming evaluation is complete. If set to None, all instances in the test set are evaluated.
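To build intuition for distance_type and graph_regularization_multiplier, here is a simplified sketch (an assumption for illustration only, not NSL's actual implementation) of how a scaled graph loss behaves for one sample/neighbor pair under a squared-L2 distance:
import tensorflow as tf
def scaled_graph_loss(sample_logits, neighbor_logits, neighbor_weight,
                      multiplier=0.1):
  """Squared-L2 distance to the neighbor, scaled by edge weight and multiplier."""
  l2 = tf.reduce_sum(tf.square(sample_logits - neighbor_logits), axis=-1)
  return multiplier * tf.reduce_mean(neighbor_weight * l2)
# One sample whose prediction differs from its neighbor's by 1.0 in one logit:
loss = scaled_graph_loss(tf.constant([[2.0, 0.5]]),
                         tf.constant([[1.0, 0.5]]),
                         neighbor_weight=tf.constant([1.0]))
print(loss.numpy())  # 0.1 = 0.1 (multiplier) * 1.0 (weighted squared distance)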
class HParams(object):
"""Hyperparameters used for training."""
def __init__(self):
### dataset parameters
self.num_classes = 7
self.max_seq_length = 1433
### neural graph learning parameters
self.distance_type = nsl.configs.DistanceType.L2
self.graph_regularization_multiplier = 0.1
self.num_neighbors = 1
### model architecture
self.num_fc_units = [50, 50]
### training parameters
self.train_epochs = 100
self.batch_size = 128
self.dropout_rate = 0.5
### eval parameters
self.eval_steps = None # All instances in the test set are evaluated.
HPARAMS = HParams()
Load train and test data
As described earlier in this notebook, the input training and test data have been created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects, one for train and one for test.
In the input layer of our model, we will extract not just the 'words' and 'label' features from each sample, but also the corresponding neighbor features, based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for those non-existent neighbor features.
def make_dataset(file_path, training=False):
"""Creates a `tf.data.TFRecordDataset`.
Args:
file_path: Name of the file in the `.tfrecord` format containing
`tf.train.Example` objects.
training: Boolean indicating if we are in training mode.
Returns:
An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
objects.
"""
def parse_example(example_proto):
"""Extracts relevant fields from the `example_proto`.
Args:
example_proto: An instance of `tf.train.Example`.
Returns:
A pair whose first value is a dictionary containing relevant features
and whose second value contains the ground truth label.
"""
# The 'words' feature is a multi-hot, bag-of-words representation of the
# original raw text. A default value is required for examples that don't
# have the feature.
feature_spec = {
'words':
tf.io.FixedLenFeature([HPARAMS.max_seq_length],
tf.int64,
default_value=tf.constant(
0,
dtype=tf.int64,
shape=[HPARAMS.max_seq_length])),
'label':
tf.io.FixedLenFeature((), tf.int64, default_value=-1),
}
# We also extract corresponding neighbor features in a similar manner to
# the features above during training.
if training:
for i in range(HPARAMS.num_neighbors):
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
NBR_WEIGHT_SUFFIX)
feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
[HPARAMS.max_seq_length],
tf.int64,
default_value=tf.constant(
0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))
# We assign a default value of 0.0 for the neighbor weight so that
# graph regularization is done on samples based on their exact number
# of neighbors. In other words, non-existent neighbors are discounted.
feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
[1], tf.float32, default_value=tf.constant([0.0]))
features = tf.io.parse_single_example(example_proto, feature_spec)
label = features.pop('label')
return features, label
dataset = tf.data.TFRecordDataset([file_path])
if training:
dataset = dataset.shuffle(10000)
dataset = dataset.map(parse_example)
dataset = dataset.batch(HPARAMS.batch_size)
return dataset
train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)
Let's peek into the train dataset to look at its contents.
for feature_batch, label_batch in train_dataset.take(1):
print('Feature list:', list(feature_batch.keys()))
print('Batch of inputs:', feature_batch['words'])
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
print('Batch of neighbor weights:',
tf.reshape(feature_batch[nbr_weight_key], [-1]))
print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 1 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor weights: tf.Tensor( [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32) Batch of labels: tf.Tensor( [2 2 6 2 0 6 1 3 5 0 1 2 3 6 1 1 0 3 5 2 3 1 4 1 6 1 3 2 2 2 0 3 2 1 3 3 2 3 3 2 3 2 2 0 2 2 6 0 2 1 1 0 5 2 1 4 2 1 2 4 0 2 5 4 3 6 3 2 1 6 2 4 2 2 6 4 6 4 3 5 2 2 2 4 2 2 2 1 2 2 2 4 2 3 6 2 0 6 6 0 2 6 2 1 2 0 1 1 3 2 0 2 0 2 1 1 3 5 2 1 2 5 1 6 2 4 6 4], shape=(128,), dtype=int64)
Let's peek into the test dataset to look at its contents.
for feature_batch, label_batch in test_dataset.take(1):
print('Feature list:', list(feature_batch.keys()))
print('Batch of inputs:', feature_batch['words'])
print('Batch of labels:', label_batch)
Feature list: ['words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of labels: tf.Tensor( [5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)
Model definition
To demonstrate the use of graph regularization, we build a base model for this problem first. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.Keras framework: sequential, functional, and subclass.
Sequential base model
def make_mlp_sequential_model(hparams):
"""Creates a sequential multi-layer perceptron model."""
model = tf.keras.Sequential()
model.add(
tf.keras.layers.InputLayer(
input_shape=(hparams.max_seq_length,), name='words'))
# Input is already one-hot encoded in the integer format. We cast it to
# floating point format here.
model.add(
tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
for num_units in hparams.num_fc_units:
model.add(tf.keras.layers.Dense(num_units, activation='relu'))
# For sequential models, by default, Keras ensures that the 'dropout' layer
# is invoked only during training.
model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
model.add(tf.keras.layers.Dense(hparams.num_classes))
return model
Functional base model
def make_mlp_functional_model(hparams):
"""Creates a functional API-based multi-layer perceptron model."""
inputs = tf.keras.Input(
shape=(hparams.max_seq_length,), dtype='int64', name='words')
# Input is already one-hot encoded in the integer format. We cast it to
# floating point format here.
cur_layer = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.cast(x, tf.float32))(
inputs)
for num_units in hparams.num_fc_units:
cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
# For functional models, by default, Keras ensures that the 'dropout' layer
# is invoked only during training.
cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)
outputs = tf.keras.layers.Dense(hparams.num_classes)(cur_layer)
model = tf.keras.Model(inputs, outputs=outputs)
return model
Subclass base model
def make_mlp_subclass_model(hparams):
"""Creates a multi-layer perceptron subclass model in Keras."""
class MLP(tf.keras.Model):
"""Subclass model defining a multi-layer perceptron."""
def __init__(self):
super(MLP, self).__init__()
# Input is already one-hot encoded in the integer format. We create a
# layer to cast it to floating point format here.
self.cast_to_float_layer = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.cast(x, tf.float32))
self.dense_layers = [
tf.keras.layers.Dense(num_units, activation='relu')
for num_units in hparams.num_fc_units
]
self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
self.output_layer = tf.keras.layers.Dense(hparams.num_classes)
def call(self, inputs, training=False):
cur_layer = self.cast_to_float_layer(inputs['words'])
for dense_layer in self.dense_layers:
cur_layer = dense_layer(cur_layer)
cur_layer = self.dropout_layer(cur_layer, training=training)
outputs = self.output_layer(cur_layer)
return outputs
return MLP()
Create base model(s)
# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= words (InputLayer) [(None, 1433)] 0 lambda (Lambda) (None, 1433) 0 dense (Dense) (None, 50) 71700 dropout (Dropout) (None, 50) 0 dense_1 (Dense) (None, 50) 2550 dropout_1 (Dropout) (None, 50) 0 dense_2 (Dense) (None, 7) 357 ================================================================= Total params: 74,607 Trainable params: 74,607 Non-trainable params: 0 _________________________________________________________________
Train base MLP model
# Compile and train the base MLP model
base_model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/functional.py:559: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model. inputs = self._flatten_to_reference_inputs(inputs) 17/17 [==============================] - 1s 18ms/step - loss: 1.9521 - accuracy: 0.1838 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8590 - accuracy: 0.3044 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7770 - accuracy: 0.3601 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6655 - accuracy: 0.3898 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5386 - accuracy: 0.4543 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3856 - accuracy: 0.5077 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2736 - accuracy: 0.5531 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1636 - accuracy: 0.5889 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0654 - accuracy: 0.6385 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9703 - accuracy: 0.6761 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8689 - accuracy: 0.7104 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7704 - accuracy: 0.7494 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7157 - accuracy: 0.7810 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 0.6296 - accuracy: 0.8186 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5932 - accuracy: 0.8167 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5526 - accuracy: 0.8464 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5112 - accuracy: 0.8445 Epoch 18/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4624 - accuracy: 0.8613 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4163 - accuracy: 0.8696 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3808 - accuracy: 0.8849 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3564 - accuracy: 0.8933 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3453 - accuracy: 0.9002 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3226 - accuracy: 0.9114 Epoch 24/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3058 - accuracy: 0.9151 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2798 - accuracy: 0.9146 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2638 - accuracy: 0.9248 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2538 - accuracy: 0.9290 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2356 - accuracy: 0.9411 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2080 - accuracy: 0.9425 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2172 - accuracy: 0.9364 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2259 - accuracy: 0.9225 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1944 - accuracy: 0.9480 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 
0.1892 - accuracy: 0.9434 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1718 - accuracy: 0.9592 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1826 - accuracy: 0.9508 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1585 - accuracy: 0.9559 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1605 - accuracy: 0.9545 Epoch 38/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1529 - accuracy: 0.9550 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1411 - accuracy: 0.9615 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1366 - accuracy: 0.9624 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1431 - accuracy: 0.9578 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1241 - accuracy: 0.9619 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1310 - accuracy: 0.9661 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1284 - accuracy: 0.9652 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1215 - accuracy: 0.9633 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1130 - accuracy: 0.9722 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1074 - accuracy: 0.9722 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1143 - accuracy: 0.9694 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1015 - accuracy: 0.9740 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1077 - accuracy: 0.9698 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1035 - accuracy: 0.9684 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1076 - accuracy: 0.9694 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1000 - accuracy: 0.9689 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0967 - accuracy: 0.9749 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0994 - accuracy: 0.9703 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0943 - accuracy: 0.9740 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0923 - accuracy: 0.9735 Epoch 58/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0848 - accuracy: 0.9800 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0836 - accuracy: 0.9782 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0913 - accuracy: 0.9735 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0823 - accuracy: 0.9773 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0753 - accuracy: 0.9810 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0746 - accuracy: 0.9777 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0861 - accuracy: 0.9731 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0765 - accuracy: 0.9787 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0750 - accuracy: 0.9791 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0725 - accuracy: 0.9814 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0762 - accuracy: 0.9791 Epoch 69/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.0645 - accuracy: 0.9842 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0606 - accuracy: 0.9861 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0775 - accuracy: 0.9805 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0655 - accuracy: 0.9800 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0629 - accuracy: 0.9833 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0625 - accuracy: 0.9824 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0607 - accuracy: 0.9838 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0578 - accuracy: 0.9824 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0568 - accuracy: 0.9842 Epoch 78/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0595 - accuracy: 0.9833 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0615 - accuracy: 0.9842 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0555 - accuracy: 0.9852 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0517 - accuracy: 0.9870 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0541 - accuracy: 0.9856 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0533 - accuracy: 0.9884 Epoch 84/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0509 - accuracy: 0.9838 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0600 - accuracy: 0.9828 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0617 - accuracy: 0.9800 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0599 - accuracy: 0.9800 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0502 - accuracy: 0.9870 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0416 - accuracy: 0.9907 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0542 - accuracy: 0.9842 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0490 - accuracy: 0.9847 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0374 - accuracy: 0.9916 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0467 - accuracy: 0.9893 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0426 - accuracy: 0.9879 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0543 - accuracy: 0.9861 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0420 - accuracy: 0.9870 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0461 - accuracy: 0.9861 Epoch 98/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0425 - accuracy: 0.9898 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0406 - accuracy: 0.9907 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0486 - accuracy: 0.9847 <keras.callbacks.History at 0x7f6f9d5eacd0>
Evaluate base MLP model
# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
"""Prints evaluation metrics.
Args:
model_desc: A description of the model.
eval_metrics: A dictionary mapping metric names to corresponding values. It
must contain the loss and accuracy metrics.
"""
print('\n')
print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
if 'graph_loss' in eval_metrics:
print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
zip(base_model.metrics_names,
base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.4192 - accuracy: 0.7939 Eval accuracy for Base MLP model : 0.7938517332077026 Eval loss for Base MLP model : 1.4192423820495605
Train MLP model with graph regularization
Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model, whose loss includes graph regularization.
To assess the incremental benefit of graph regularization, we will create a new base model instance. This is because base_model has already been trained for a few iterations, and reusing this trained model to create a graph-regularized model will not be a fair comparison for base_model.
# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
max_neighbors=HPARAMS.num_neighbors,
multiplier=HPARAMS.graph_regularization_multiplier,
distance_type=HPARAMS.distance_type,
sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
graph_reg_config)
graph_reg_model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape:0", shape=(None, 7), dtype=float32), dense_shape=Tensor("gradient_tape/GraphRegularization/graph_loss/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) 17/17 [==============================] - 2s 4ms/step - loss: 1.9798 - accuracy: 0.1601 - scaled_graph_loss: 0.0373 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.9024 - accuracy: 0.2979 - scaled_graph_loss: 0.0254 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8623 - accuracy: 0.3160 - scaled_graph_loss: 0.0317 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8042 - accuracy: 0.3443 - scaled_graph_loss: 0.0498 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7552 - accuracy: 0.3582 - scaled_graph_loss: 0.0696 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7012 - accuracy: 0.4084 - scaled_graph_loss: 0.0866 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6578 - accuracy: 0.4515 - scaled_graph_loss: 0.1114 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6058 - accuracy: 0.5039 - scaled_graph_loss: 0.1300 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5498 - accuracy: 0.5434 - scaled_graph_loss: 0.1508 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5098 - accuracy: 0.6019 - scaled_graph_loss: 0.1651 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4746 - accuracy: 0.6302 - scaled_graph_loss: 0.1844 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4315 - accuracy: 0.6520 - scaled_graph_loss: 0.1917 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3932 - accuracy: 0.6770 - scaled_graph_loss: 0.2024 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3645 - accuracy: 0.7183 - scaled_graph_loss: 0.2145 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3265 - accuracy: 0.7369 - scaled_graph_loss: 0.2324 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3045 - accuracy: 0.7555 - scaled_graph_loss: 0.2358 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2836 - accuracy: 0.7652 - scaled_graph_loss: 0.2404 Epoch 18/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2456 - accuracy: 0.7898 - scaled_graph_loss: 0.2469 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2348 - accuracy: 0.8074 - scaled_graph_loss: 0.2615 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2000 - accuracy: 0.8074 - scaled_graph_loss: 0.2542 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1994 - accuracy: 0.8260 - scaled_graph_loss: 0.2729 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1825 - accuracy: 0.8269 - scaled_graph_loss: 0.2676 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1598 - accuracy: 0.8455 - scaled_graph_loss: 0.2742 Epoch 
24/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1543 - accuracy: 0.8534 - scaled_graph_loss: 0.2797 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1456 - accuracy: 0.8552 - scaled_graph_loss: 0.2714 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8566 - scaled_graph_loss: 0.2796 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1150 - accuracy: 0.8687 - scaled_graph_loss: 0.2850 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8626 - scaled_graph_loss: 0.2772 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0806 - accuracy: 0.8733 - scaled_graph_loss: 0.2756 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0828 - accuracy: 0.8626 - scaled_graph_loss: 0.2907 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0724 - accuracy: 0.8886 - scaled_graph_loss: 0.2834 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0589 - accuracy: 0.8826 - scaled_graph_loss: 0.2881 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0490 - accuracy: 0.8872 - scaled_graph_loss: 0.2972 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0550 - accuracy: 0.8923 - scaled_graph_loss: 0.2935 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0397 - accuracy: 0.8840 - scaled_graph_loss: 0.2795 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0360 - accuracy: 0.8891 - scaled_graph_loss: 0.2966 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0235 - accuracy: 0.8961 - scaled_graph_loss: 0.2890 Epoch 38/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0219 - accuracy: 0.8984 - scaled_graph_loss: 0.2965 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0168 - accuracy: 0.9044 - scaled_graph_loss: 0.3023 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0148 - accuracy: 0.9035 - scaled_graph_loss: 0.2984 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9118 - scaled_graph_loss: 0.2888 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0019 - accuracy: 0.9021 - scaled_graph_loss: 0.2877 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9049 - scaled_graph_loss: 0.2912 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9986 - accuracy: 0.9026 - scaled_graph_loss: 0.3040 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9939 - accuracy: 0.9067 - scaled_graph_loss: 0.3016 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9828 - accuracy: 0.9058 - scaled_graph_loss: 0.2877 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9629 - accuracy: 0.9137 - scaled_graph_loss: 0.2844 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9645 - accuracy: 0.9146 - scaled_graph_loss: 0.2933 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9752 - accuracy: 0.9165 - scaled_graph_loss: 0.3013 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9552 - accuracy: 0.9179 - scaled_graph_loss: 0.2865 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9539 - accuracy: 0.9193 - 
scaled_graph_loss: 0.3044 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9443 - accuracy: 0.9183 - scaled_graph_loss: 0.3010 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9559 - accuracy: 0.9244 - scaled_graph_loss: 0.2987 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9497 - accuracy: 0.9225 - scaled_graph_loss: 0.2979 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9674 - accuracy: 0.9183 - scaled_graph_loss: 0.3034 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9537 - accuracy: 0.9174 - scaled_graph_loss: 0.2834 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9341 - accuracy: 0.9188 - scaled_graph_loss: 0.2939 Epoch 58/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9392 - accuracy: 0.9225 - scaled_graph_loss: 0.2998 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9240 - accuracy: 0.9313 - scaled_graph_loss: 0.3022 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9368 - accuracy: 0.9267 - scaled_graph_loss: 0.2979 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9306 - accuracy: 0.9234 - scaled_graph_loss: 0.2952 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9197 - accuracy: 0.9230 - scaled_graph_loss: 0.2916 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9360 - accuracy: 0.9206 - scaled_graph_loss: 0.2947 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9181 - accuracy: 0.9299 - scaled_graph_loss: 0.2996 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9105 - accuracy: 0.9341 - scaled_graph_loss: 0.2981 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9014 - accuracy: 0.9323 - scaled_graph_loss: 0.2897 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9059 - accuracy: 0.9364 - scaled_graph_loss: 0.3083 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9053 - accuracy: 0.9309 - scaled_graph_loss: 0.2976 Epoch 69/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9099 - accuracy: 0.9258 - scaled_graph_loss: 0.3069 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9025 - accuracy: 0.9355 - scaled_graph_loss: 0.2890 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8849 - accuracy: 0.9281 - scaled_graph_loss: 0.2933 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8959 - accuracy: 0.9323 - scaled_graph_loss: 0.2918 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9074 - accuracy: 0.9248 - scaled_graph_loss: 0.3065 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8845 - accuracy: 0.9369 - scaled_graph_loss: 0.2874 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8873 - accuracy: 0.9401 - scaled_graph_loss: 0.2996 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8942 - accuracy: 0.9327 - scaled_graph_loss: 0.3086 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9052 - accuracy: 0.9253 - scaled_graph_loss: 0.2986 Epoch 78/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8811 - accuracy: 0.9336 - scaled_graph_loss: 0.2948 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - 
loss: 0.8896 - accuracy: 0.9276 - scaled_graph_loss: 0.2919 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8853 - accuracy: 0.9313 - scaled_graph_loss: 0.2944 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8875 - accuracy: 0.9323 - scaled_graph_loss: 0.2925 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8639 - accuracy: 0.9323 - scaled_graph_loss: 0.2967 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8820 - accuracy: 0.9332 - scaled_graph_loss: 0.3047 Epoch 84/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8752 - accuracy: 0.9346 - scaled_graph_loss: 0.2942 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9374 - scaled_graph_loss: 0.3066 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8765 - accuracy: 0.9332 - scaled_graph_loss: 0.2881 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8691 - accuracy: 0.9420 - scaled_graph_loss: 0.3030 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8631 - accuracy: 0.9374 - scaled_graph_loss: 0.2916 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9392 - scaled_graph_loss: 0.3032 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8632 - accuracy: 0.9420 - scaled_graph_loss: 0.3019 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8600 - accuracy: 0.9425 - scaled_graph_loss: 0.2965 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8569 - accuracy: 0.9346 - scaled_graph_loss: 0.2977 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8704 - accuracy: 0.9374 - scaled_graph_loss: 0.3083 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8562 - accuracy: 0.9406 - scaled_graph_loss: 0.2883 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8545 - accuracy: 0.9415 - scaled_graph_loss: 0.3030 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8592 - accuracy: 0.9332 - scaled_graph_loss: 0.2927 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8503 - accuracy: 0.9397 - scaled_graph_loss: 0.2927 Epoch 98/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8434 - accuracy: 0.9462 - scaled_graph_loss: 0.2937 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8578 - accuracy: 0.9374 - scaled_graph_loss: 0.3064 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8504 - accuracy: 0.9411 - scaled_graph_loss: 0.3043 <keras.callbacks.History at 0x7f70041be650>
Evaluate MLP model with graph regularization
eval_results = dict(
zip(graph_reg_model.metrics_names,
graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 0.8884 - accuracy: 0.7957 Eval accuracy for MLP + graph regularization : 0.7956600189208984 Eval loss for MLP + graph regularization : 0.8883611559867859
The accuracy of the graph-regularized model is about 2-3% higher than that of the base model (base_model).
Conclusion
We demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs based on sample embeddings before training a neural network with graph regularization. This approach is useful if the input does not contain an explicit graph (a sketch of such graph building follows).
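For example, when only sample embeddings are available, a similarity graph can be synthesized with NSL's graph-building tool. The snippet below is a hedged sketch: it assumes embeddings have already been serialized as tf.train.Example records to a placeholder '/tmp/embeddings.tfr' file, and it writes the resulting graph in TSV format to a placeholder '/tmp/graph.tsv' path.
import neural_structured_learning as nsl
# Create edges between samples whose embedding similarity is at least 0.8.
nsl.tools.build_graph(['/tmp/embeddings.tfr'],
                      '/tmp/graph.tsv',
                      similarity_threshold=0.8)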
We encourage users to experiment further by varying the amount of supervision as well as trying different neural architectures for graph regularization.