Overview
Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.
In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.
The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:
- Generate training data from the input graph and sample features. Nodes in the graph correspond to samples and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
- Create a neural network as a base model using the `Keras` sequential, functional, or subclass API.
- Wrap the base model with the `GraphRegularization` wrapper class, which is provided by the NSL framework, to create a new graph `Keras` model. This new model will include a graph regularization loss as the regularization term in its training objective.
- Train and evaluate the graph `Keras` model.
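Concretely, the four steps above map onto just a few NSL calls. Below is a minimal sketch, assuming an already-built Keras `base_model` and a `train_dataset` whose parsed features include the neighbor features described in step 1; every step is implemented in full later in this tutorial.
import neural_structured_learning as nsl
# Step 3: wrap the base model so that the training objective also includes
# a graph regularization term computed from neighbor features.
graph_reg_config = nsl.configs.make_graph_reg_config(max_neighbors=1, multiplier=0.1)
graph_reg_model = nsl.keras.GraphRegularization(base_model, graph_reg_config)
# Step 4: compile, train, and evaluate like any other Keras model.
graph_reg_model.compile(
    optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=5)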
Setup
Install the Neural Structured Learning package.
!pip install --quiet neural-structured-learning
Dependencies and imports
import neural_structured_learning as nsl
import tensorflow as tf
# Resets notebook state
tf.keras.backend.clear_session()
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
"GPU is",
"available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version: 2.4.0 Eager mode: True GPU is available
Cora dataset
The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task is document classification: categorizing each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.
Graph
The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph: if paper A cites paper B, we also consider paper B to have cited A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a commutative property.
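As a small illustration of this symmetrization (the 'preprocess_cora_dataset.py' script used later handles this for us), a directed edge list like 'cora.cites' can be made undirected by adding each reverse edge:
undirected_edges = set()
with open('/tmp/cora/cora.cites') as f:
  for line in f:
    # Each line holds one pair of paper IDs representing a citation.
    a, b = line.split()
    undirected_edges.add((a, b))
    undirected_edges.add((b, a))  # treat citation as symmetric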
Features
Each paper in the input effectively contains 2 features:
- Words: A dense multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words, so the length of this feature is 1433, and the value at position 'i' is 0 or 1, indicating whether word 'i' in the vocabulary exists in the given paper.
- Label: A single integer representing the class ID (category) of the paper.
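To make the multi-hot encoding concrete, here is a toy example with a 5-word vocabulary instead of Cora's 1433 words:
import tensorflow as tf
vocab = ['graph', 'neural', 'learning', 'paper', 'model']  # toy vocabulary
paper_words = {'neural', 'model'}  # words occurring in a hypothetical paper
# Position i is 1 if vocab[i] appears in the paper, and 0 otherwise.
words_feature = tf.constant(
    [1 if w in paper_words else 0 for w in vocab], dtype=tf.int64)
print(words_feature)  # tf.Tensor([0 1 0 0 1], shape=(5,), dtype=int64)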
Download the Cora dataset
!wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
!tar -C /tmp -xvzf /tmp/cora.tgz
cora/ cora/README cora/cora.cites cora/cora.content
Convert the Cora data to the NSL format
To preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we run the 'preprocess_cora_dataset.py' script, which is included in the NSL github repository. This script does the following:
- Generate neighbor features using the original node features and the graph.
- Generate train and test data splits containing `tf.train.Example` instances.
- Persist the resulting train and test data in the `TFRecord` format.
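The neighbor-joining step performed by the script is the same kind of operation exposed by NSL's `nsl.tools.pack_nbrs` utility. Below is a rough sketch of an equivalent call; the input/output paths are illustrative only, and the actual script adds preprocessing of its own.
import neural_structured_learning as nsl
# Join each labeled example with the features of up to 5 of its graph
# neighbors and persist the merged tf.train.Examples as a TFRecord file.
nsl.tools.pack_nbrs(
    '/tmp/cora/train_examples.tfr',  # labeled examples (illustrative path)
    '',                              # no separate unlabeled examples
    '/tmp/cora/cora.cites',          # graph edges as a TSV file
    '/tmp/cora/train_merged_examples.tfr',
    add_undirected_edges=True,
    max_nbrs=5)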
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2021-01-15 02:26:25-- https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.192.133, 151.101.128.133, 151.101.64.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.192.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 11640 (11K) [text/plain] Saving to: ‘preprocess_cora_dataset.py’ preprocess_cora_dat 100%[===================>] 11.37K --.-KB/s in 0.001s 2021-01-15 02:26:26 (20.7 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640] 2021-01-15 02:26:26.653369: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0 Reading graph file: /tmp/cora/cora.cites... Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds). Making all edges bi-directional... Done (0.01 seconds). Total graph nodes: 2708 Joining seed and neighbor tf.train.Examples with graph edges... Done creating and writing 2155 merged tf.train.Examples (1.51 seconds). Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)] Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr. Output test data written to TFRecord file: /tmp/cora/test_examples.tfr. Total running time: 0.05 minutes.
Global variables
The file paths to the train and test data are based on the command line flag values used to invoke the 'preprocess_cora_dataset.py' script above.
### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'
### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
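With these constants, the features of a sample's i-th neighbor are keyed as follows (shown for i = 0, mirroring the keys we will parse later when loading the data):
i = 0
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')          # 'NL_nbr_0_words'
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i, NBR_WEIGHT_SUFFIX)  # 'NL_nbr_0_weight'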
Hyperparameters
We will use an instance of `HParams` to hold various hyperparameters and constants used for training and evaluation. We briefly describe each of them below:
- num_classes: There are a total of 7 classes.
- max_seq_length: This is the size of the vocabulary. Every instance in the input has a dense multi-hot bag-of-words representation: a value of 1 for a word indicates that the word is present in the input and a value of 0 indicates that it is not.
- distance_type: This is the distance metric used to regularize a sample with its neighbors.
- graph_regularization_multiplier: This controls the relative weight of the graph regularization term in the overall loss function.
- num_neighbors: The number of neighbors used for graph regularization. This value has to be less than or equal to the `max_nbrs` command-line argument used above when running `preprocess_cora_dataset.py`.
- num_fc_units: The number of units in each fully connected layer of the neural network.
- train_epochs: The number of training epochs.
- batch_size: The batch size used for training and evaluation.
- dropout_rate: Controls the rate of dropout following each fully connected layer.
- eval_steps: The number of batches to process before deeming evaluation complete. If set to `None`, all instances in the test set are evaluated.
class HParams(object):
"""Hyperparameters used for training."""
def __init__(self):
### dataset parameters
self.num_classes = 7
self.max_seq_length = 1433
### neural graph learning parameters
self.distance_type = nsl.configs.DistanceType.L2
self.graph_regularization_multiplier = 0.1
self.num_neighbors = 1
### model architecture
self.num_fc_units = [50, 50]
### training parameters
self.train_epochs = 100
self.batch_size = 128
self.dropout_rate = 0.5
### eval parameters
self.eval_steps = None # All instances in the test set are evaluated.
HPARAMS = HParams()
Load train and test data
As described earlier in this notebook, the input train and test data were created by 'preprocess_cora_dataset.py'. We will load them into two `tf.data.Dataset` objects -- one for train and one for test.
In the input layer of our model, we will extract not just the 'words' and 'label' features from each sample, but also the corresponding neighbor features based on the `hparams.num_neighbors` value. Instances with fewer neighbors than `hparams.num_neighbors` will be assigned dummy values for the non-existent neighbor features.
def make_dataset(file_path, training=False):
"""Creates a `tf.data.TFRecordDataset`.
Args:
file_path: Name of the file in the `.tfrecord` format containing
`tf.train.Example` objects.
training: Boolean indicating if we are in training mode.
Returns:
An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
objects.
"""
def parse_example(example_proto):
"""Extracts relevant fields from the `example_proto`.
Args:
example_proto: An instance of `tf.train.Example`.
Returns:
A pair whose first value is a dictionary containing relevant features
and whose second value contains the ground truth label.
"""
# The 'words' feature is a multi-hot, bag-of-words representation of the
# original raw text. A default value is required for examples that don't
# have the feature.
feature_spec = {
'words':
tf.io.FixedLenFeature([HPARAMS.max_seq_length],
tf.int64,
default_value=tf.constant(
0,
dtype=tf.int64,
shape=[HPARAMS.max_seq_length])),
'label':
tf.io.FixedLenFeature((), tf.int64, default_value=-1),
}
# We also extract corresponding neighbor features in a similar manner to
# the features above during training.
if training:
for i in range(HPARAMS.num_neighbors):
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
NBR_WEIGHT_SUFFIX)
feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
[HPARAMS.max_seq_length],
tf.int64,
default_value=tf.constant(
0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))
# We assign a default value of 0.0 for the neighbor weight so that
# graph regularization is done on samples based on their exact number
# of neighbors. In other words, non-existent neighbors are discounted.
feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
[1], tf.float32, default_value=tf.constant([0.0]))
features = tf.io.parse_single_example(example_proto, feature_spec)
label = features.pop('label')
return features, label
dataset = tf.data.TFRecordDataset([file_path])
if training:
dataset = dataset.shuffle(10000)
dataset = dataset.map(parse_example)
dataset = dataset.batch(HPARAMS.batch_size)
return dataset
train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)
Let's peek into the train dataset to look at its contents.
for feature_batch, label_batch in train_dataset.take(1):
print('Feature list:', list(feature_batch.keys()))
print('Batch of inputs:', feature_batch['words'])
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
print('Batch of neighbor weights:',
tf.reshape(feature_batch[nbr_weight_key], [-1]))
print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words'] Batch of inputs: tf.Tensor( [[0 0 1 ... 0 0 0] [1 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor weights: tf.Tensor( [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32) Batch of labels: tf.Tensor( [2 2 2 0 0 6 6 1 2 6 2 0 5 6 1 2 5 3 3 0 3 3 2 6 6 2 1 2 1 3 6 4 6 3 0 2 2 1 3 3 6 6 3 2 2 1 2 2 6 6 5 0 6 2 0 2 6 6 2 2 2 5 2 3 3 0 3 3 6 3 6 3 1 2 2 3 3 3 2 3 0 1 2 2 0 2 3 3 2 6 3 2 3 1 2 4 2 1 2 2 3 6 1 2 3 2 5 2 2 3 2 2 1 1 3 2 1 4 0 2 3 5 2 1 2 2 0 1], shape=(128,), dtype=int64)
Let's peek into the test dataset to look at its contents.
for feature_batch, label_batch in test_dataset.take(1):
print('Feature list:', list(feature_batch.keys()))
print('Batch of inputs:', feature_batch['words'])
print('Batch of labels:', label_batch)
Feature list: ['words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of labels: tf.Tensor( [5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)
Model definition
To demonstrate the use of graph regularization, we first build a base model for this problem: a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the `tf.Keras` framework -- sequential, functional, and subclass.
Sequential base model
def make_mlp_sequential_model(hparams):
"""Creates a sequential multi-layer perceptron model."""
model = tf.keras.Sequential()
model.add(
tf.keras.layers.InputLayer(
input_shape=(hparams.max_seq_length,), name='words'))
# Input is already one-hot encoded in the integer format. We cast it to
# floating point format here.
model.add(
tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
for num_units in hparams.num_fc_units:
model.add(tf.keras.layers.Dense(num_units, activation='relu'))
# For sequential models, by default, Keras ensures that the 'dropout' layer
# is invoked only during training.
model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
model.add(tf.keras.layers.Dense(hparams.num_classes, activation='softmax'))
return model
Functional base model
def make_mlp_functional_model(hparams):
"""Creates a functional API-based multi-layer perceptron model."""
inputs = tf.keras.Input(
shape=(hparams.max_seq_length,), dtype='int64', name='words')
# Input is already one-hot encoded in the integer format. We cast it to
# floating point format here.
cur_layer = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.cast(x, tf.float32))(
inputs)
for num_units in hparams.num_fc_units:
cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
# For functional models, by default, Keras ensures that the 'dropout' layer
# is invoked only during training.
cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)
outputs = tf.keras.layers.Dense(
hparams.num_classes, activation='softmax')(
cur_layer)
model = tf.keras.Model(inputs, outputs=outputs)
return model
Subclass base model
def make_mlp_subclass_model(hparams):
"""Creates a multi-layer perceptron subclass model in Keras."""
class MLP(tf.keras.Model):
"""Subclass model defining a multi-layer perceptron."""
def __init__(self):
super(MLP, self).__init__()
# Input is already one-hot encoded in the integer format. We create a
# layer to cast it to floating point format here.
self.cast_to_float_layer = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.cast(x, tf.float32))
self.dense_layers = [
tf.keras.layers.Dense(num_units, activation='relu')
for num_units in hparams.num_fc_units
]
self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
self.output_layer = tf.keras.layers.Dense(
hparams.num_classes, activation='softmax')
def call(self, inputs, training=False):
cur_layer = self.cast_to_float_layer(inputs['words'])
for dense_layer in self.dense_layers:
cur_layer = dense_layer(cur_layer)
cur_layer = self.dropout_layer(cur_layer, training=training)
outputs = self.output_layer(cur_layer)
return outputs
return MLP()
Create the base model
# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= words (InputLayer) [(None, 1433)] 0 _________________________________________________________________ lambda (Lambda) (None, 1433) 0 _________________________________________________________________ dense (Dense) (None, 50) 71700 _________________________________________________________________ dropout (Dropout) (None, 50) 0 _________________________________________________________________ dense_1 (Dense) (None, 50) 2550 _________________________________________________________________ dropout_1 (Dropout) (None, 50) 0 _________________________________________________________________ dense_2 (Dense) (None, 7) 357 ================================================================= Total params: 74,607 Trainable params: 74,607 Non-trainable params: 0 _________________________________________________________________
Train the base MLP model
# Compile and train the base MLP model
base_model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py:595: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model. [n for n in tensors.keys() if n not in ref_input_names]) 17/17 [==============================] - 1s 12ms/step - loss: 1.9717 - accuracy: 0.1436 Epoch 2/100 17/17 [==============================] - 0s 12ms/step - loss: 1.8710 - accuracy: 0.2844 Epoch 3/100 17/17 [==============================] - 0s 11ms/step - loss: 1.7676 - accuracy: 0.3435 Epoch 4/100 17/17 [==============================] - 0s 12ms/step - loss: 1.6748 - accuracy: 0.3555 Epoch 5/100 17/17 [==============================] - 0s 11ms/step - loss: 1.5762 - accuracy: 0.3930 Epoch 6/100 17/17 [==============================] - 0s 12ms/step - loss: 1.4101 - accuracy: 0.5144 Epoch 7/100 17/17 [==============================] - 0s 12ms/step - loss: 1.2371 - accuracy: 0.5929 Epoch 8/100 17/17 [==============================] - 0s 11ms/step - loss: 1.1377 - accuracy: 0.6022 Epoch 9/100 17/17 [==============================] - 0s 12ms/step - loss: 0.9814 - accuracy: 0.6762 Epoch 10/100 17/17 [==============================] - 0s 11ms/step - loss: 0.8619 - accuracy: 0.7089 Epoch 11/100 17/17 [==============================] - 0s 12ms/step - loss: 0.8184 - accuracy: 0.7439 Epoch 12/100 17/17 [==============================] - 0s 12ms/step - loss: 0.7374 - accuracy: 0.7569 Epoch 13/100 17/17 [==============================] - 0s 12ms/step - loss: 0.6403 - accuracy: 0.7911 Epoch 14/100 17/17 [==============================] - 0s 12ms/step - loss: 0.6167 - accuracy: 0.8038 Epoch 15/100 17/17 [==============================] - 0s 11ms/step - loss: 0.5324 - accuracy: 0.8423 Epoch 16/100 17/17 [==============================] - 0s 11ms/step - loss: 0.4927 - accuracy: 0.8447 Epoch 17/100 17/17 [==============================] - 0s 12ms/step - loss: 0.4550 - accuracy: 0.8589 Epoch 18/100 17/17 [==============================] - 0s 12ms/step - loss: 0.4416 - accuracy: 0.8682 Epoch 19/100 17/17 [==============================] - 0s 12ms/step - loss: 0.3883 - accuracy: 0.8835 Epoch 20/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3845 - accuracy: 0.8679 Epoch 21/100 17/17 [==============================] - 0s 12ms/step - loss: 0.3481 - accuracy: 0.8922 Epoch 22/100 17/17 [==============================] - 0s 12ms/step - loss: 0.3229 - accuracy: 0.8996 Epoch 23/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2873 - accuracy: 0.9198 Epoch 24/100 17/17 [==============================] - 0s 12ms/step - loss: 0.2848 - accuracy: 0.9158 Epoch 25/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2820 - accuracy: 0.9076 Epoch 26/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2745 - accuracy: 0.9187 Epoch 27/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2591 - accuracy: 0.9247 Epoch 28/100 17/17 [==============================] - 0s 12ms/step - loss: 0.2313 - accuracy: 0.9377 Epoch 29/100 17/17 [==============================] - 0s 12ms/step - loss: 0.2242 - accuracy: 0.9371 Epoch 30/100 17/17 [==============================] - 0s 12ms/step - loss: 0.2171 - accuracy: 0.9358 Epoch 31/100 17/17 [==============================] - 0s 12ms/step - loss: 0.2285 - accuracy: 0.9365 Epoch 32/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2079 - accuracy: 0.9358 Epoch 33/100 17/17 
[==============================] - 0s 12ms/step - loss: 0.1881 - accuracy: 0.9430 Epoch 34/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1703 - accuracy: 0.9556 Epoch 35/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1751 - accuracy: 0.9464 Epoch 36/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1843 - accuracy: 0.9495 Epoch 37/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1580 - accuracy: 0.9588 Epoch 38/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1557 - accuracy: 0.9548 Epoch 39/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1647 - accuracy: 0.9548 Epoch 40/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1494 - accuracy: 0.9584 Epoch 41/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1299 - accuracy: 0.9665 Epoch 42/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1432 - accuracy: 0.9657 Epoch 43/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1293 - accuracy: 0.9613 Epoch 44/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1050 - accuracy: 0.9759 Epoch 45/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1292 - accuracy: 0.9569 Epoch 46/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1182 - accuracy: 0.9670 Epoch 47/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1220 - accuracy: 0.9626 Epoch 48/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1210 - accuracy: 0.9598 Epoch 49/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1015 - accuracy: 0.9733 Epoch 50/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1042 - accuracy: 0.9714 Epoch 51/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1079 - accuracy: 0.9707 Epoch 52/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1176 - accuracy: 0.9615 Epoch 53/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0952 - accuracy: 0.9727 Epoch 54/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1062 - accuracy: 0.9697 Epoch 55/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0891 - accuracy: 0.9743 Epoch 56/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0906 - accuracy: 0.9764 Epoch 57/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0931 - accuracy: 0.9707 Epoch 58/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0901 - accuracy: 0.9762 Epoch 59/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0759 - accuracy: 0.9794 Epoch 60/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0954 - accuracy: 0.9700 Epoch 61/100 17/17 [==============================] - 0s 12ms/step - loss: 0.0933 - accuracy: 0.9769 Epoch 62/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0798 - accuracy: 0.9783 Epoch 63/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0695 - accuracy: 0.9845 Epoch 64/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0670 - accuracy: 0.9822 Epoch 65/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0751 - accuracy: 0.9807 Epoch 66/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0743 - accuracy: 0.9781 Epoch 67/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0619 - accuracy: 0.9855 Epoch 68/100 17/17 
[==============================] - 0s 12ms/step - loss: 0.0683 - accuracy: 0.9820 Epoch 69/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0669 - accuracy: 0.9822 Epoch 70/100 17/17 [==============================] - 0s 12ms/step - loss: 0.0658 - accuracy: 0.9830 Epoch 71/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0681 - accuracy: 0.9841 Epoch 72/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0857 - accuracy: 0.9760 Epoch 73/100 17/17 [==============================] - 0s 12ms/step - loss: 0.0701 - accuracy: 0.9767 Epoch 74/100 17/17 [==============================] - 0s 12ms/step - loss: 0.0820 - accuracy: 0.9799 Epoch 75/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0612 - accuracy: 0.9854 Epoch 76/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0662 - accuracy: 0.9804 Epoch 77/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0678 - accuracy: 0.9800 Epoch 78/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0590 - accuracy: 0.9838 Epoch 79/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0640 - accuracy: 0.9807 Epoch 80/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0595 - accuracy: 0.9846 Epoch 81/100 17/17 [==============================] - 0s 12ms/step - loss: 0.0556 - accuracy: 0.9824 Epoch 82/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0632 - accuracy: 0.9816 Epoch 83/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0637 - accuracy: 0.9822 Epoch 84/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0524 - accuracy: 0.9858 Epoch 85/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0665 - accuracy: 0.9780 Epoch 86/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0586 - accuracy: 0.9807 Epoch 87/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0542 - accuracy: 0.9844 Epoch 88/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0506 - accuracy: 0.9847 Epoch 89/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0447 - accuracy: 0.9866 Epoch 90/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0465 - accuracy: 0.9880 Epoch 91/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0651 - accuracy: 0.9754 Epoch 92/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0584 - accuracy: 0.9825 Epoch 93/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0479 - accuracy: 0.9889 Epoch 94/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0517 - accuracy: 0.9843 Epoch 95/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0566 - accuracy: 0.9835 Epoch 96/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0478 - accuracy: 0.9844 Epoch 97/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0511 - accuracy: 0.9809 Epoch 98/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0516 - accuracy: 0.9866 Epoch 99/100 17/17 [==============================] - 0s 12ms/step - loss: 0.0454 - accuracy: 0.9891 Epoch 100/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0530 - accuracy: 0.9822 <tensorflow.python.keras.callbacks.History at 0x7fc82c426c88>
Evaluate the base MLP model
# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
"""Prints evaluation metrics.
Args:
model_desc: A description of the model.
eval_metrics: A dictionary mapping metric names to corresponding values. It
must contain the loss and accuracy metrics.
"""
print('\n')
print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
if 'graph_loss' in eval_metrics:
print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
zip(base_model.metrics_names,
base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 8ms/step - loss: 1.4110 - accuracy: 0.7866 Eval accuracy for Base MLP model : 0.7866184711456299 Eval loss for Base MLP model : 1.4110491275787354
Train the MLP model with graph regularization
Incorporating graph regularization into the loss term of an existing `tf.Keras.Model` requires just a few lines of code. The base model is wrapped to create a new `tf.Keras` subclass model, whose loss includes graph regularization.
To assess the incremental benefit of graph regularization, we create a new base model instance. This is because `base_model` has already been trained for a few iterations, and reusing that trained model to create a graph-regularized model would not be a fair comparison against `base_model`.
# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
max_neighbors=HPARAMS.num_neighbors,
multiplier=HPARAMS.graph_regularization_multiplier,
distance_type=HPARAMS.distance_type,
sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
graph_reg_config)
graph_reg_model.compile(
optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/indexed_slices.py:437: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape:0", shape=(None, 7), dtype=float32), dense_shape=Tensor("gradient_tape/GraphRegularization/graph_loss/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) 17/17 [==============================] - 2s 11ms/step - loss: 1.9225 - accuracy: 0.1801 - scaled_graph_loss: 9.2100e-04 Epoch 2/100 17/17 [==============================] - 0s 11ms/step - loss: 1.8425 - accuracy: 0.2831 - scaled_graph_loss: 0.0014 Epoch 3/100 17/17 [==============================] - 0s 11ms/step - loss: 1.7455 - accuracy: 0.3194 - scaled_graph_loss: 0.0025 Epoch 4/100 17/17 [==============================] - 0s 11ms/step - loss: 1.6523 - accuracy: 0.3600 - scaled_graph_loss: 0.0042 Epoch 5/100 17/17 [==============================] - 0s 11ms/step - loss: 1.5537 - accuracy: 0.3918 - scaled_graph_loss: 0.0061 Epoch 6/100 17/17 [==============================] - 0s 12ms/step - loss: 1.3923 - accuracy: 0.4934 - scaled_graph_loss: 0.0091 Epoch 7/100 17/17 [==============================] - 0s 11ms/step - loss: 1.2615 - accuracy: 0.5619 - scaled_graph_loss: 0.0131 Epoch 8/100 17/17 [==============================] - 0s 12ms/step - loss: 1.1398 - accuracy: 0.6262 - scaled_graph_loss: 0.0167 Epoch 9/100 17/17 [==============================] - 0s 11ms/step - loss: 1.0197 - accuracy: 0.6717 - scaled_graph_loss: 0.0212 Epoch 10/100 17/17 [==============================] - 0s 11ms/step - loss: 0.9155 - accuracy: 0.7002 - scaled_graph_loss: 0.0247 Epoch 11/100 17/17 [==============================] - 0s 11ms/step - loss: 0.7946 - accuracy: 0.7688 - scaled_graph_loss: 0.0258 Epoch 12/100 17/17 [==============================] - 0s 11ms/step - loss: 0.7516 - accuracy: 0.7877 - scaled_graph_loss: 0.0272 Epoch 13/100 17/17 [==============================] - 0s 12ms/step - loss: 0.6715 - accuracy: 0.8285 - scaled_graph_loss: 0.0276 Epoch 14/100 17/17 [==============================] - 0s 11ms/step - loss: 0.6117 - accuracy: 0.8109 - scaled_graph_loss: 0.0317 Epoch 15/100 17/17 [==============================] - 0s 11ms/step - loss: 0.5729 - accuracy: 0.8365 - scaled_graph_loss: 0.0313 Epoch 16/100 17/17 [==============================] - 0s 12ms/step - loss: 0.5208 - accuracy: 0.8520 - scaled_graph_loss: 0.0327 Epoch 17/100 17/17 [==============================] - 0s 12ms/step - loss: 0.4611 - accuracy: 0.8802 - scaled_graph_loss: 0.0307 Epoch 18/100 17/17 [==============================] - 0s 11ms/step - loss: 0.4573 - accuracy: 0.8776 - scaled_graph_loss: 0.0324 Epoch 19/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3964 - accuracy: 0.9074 - scaled_graph_loss: 0.0320 Epoch 20/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3887 - accuracy: 0.9051 - scaled_graph_loss: 0.0337 Epoch 21/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3882 - accuracy: 0.8998 - scaled_graph_loss: 0.0350 Epoch 22/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3457 - accuracy: 0.9086 - scaled_graph_loss: 0.0337 Epoch 23/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3666 - accuracy: 0.9020 - 
scaled_graph_loss: 0.0332 Epoch 24/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3249 - accuracy: 0.9178 - scaled_graph_loss: 0.0336 Epoch 25/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3070 - accuracy: 0.9131 - scaled_graph_loss: 0.0348 Epoch 26/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2703 - accuracy: 0.9342 - scaled_graph_loss: 0.0323 Epoch 27/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2743 - accuracy: 0.9369 - scaled_graph_loss: 0.0346 Epoch 28/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2609 - accuracy: 0.9313 - scaled_graph_loss: 0.0334 Epoch 29/100 17/17 [==============================] - 0s 12ms/step - loss: 0.2561 - accuracy: 0.9355 - scaled_graph_loss: 0.0334 Epoch 30/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2222 - accuracy: 0.9489 - scaled_graph_loss: 0.0318 Epoch 31/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2030 - accuracy: 0.9545 - scaled_graph_loss: 0.0324 Epoch 32/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2269 - accuracy: 0.9437 - scaled_graph_loss: 0.0329 Epoch 33/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2138 - accuracy: 0.9498 - scaled_graph_loss: 0.0351 Epoch 34/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2171 - accuracy: 0.9490 - scaled_graph_loss: 0.0347 Epoch 35/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2096 - accuracy: 0.9519 - scaled_graph_loss: 0.0344 Epoch 36/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2035 - accuracy: 0.9517 - scaled_graph_loss: 0.0350 Epoch 37/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1795 - accuracy: 0.9619 - scaled_graph_loss: 0.0330 Epoch 38/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1818 - accuracy: 0.9603 - scaled_graph_loss: 0.0346 Epoch 39/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1778 - accuracy: 0.9596 - scaled_graph_loss: 0.0340 Epoch 40/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1788 - accuracy: 0.9597 - scaled_graph_loss: 0.0348 Epoch 41/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1604 - accuracy: 0.9699 - scaled_graph_loss: 0.0332 Epoch 42/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1692 - accuracy: 0.9682 - scaled_graph_loss: 0.0357 Epoch 43/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1506 - accuracy: 0.9744 - scaled_graph_loss: 0.0342 Epoch 44/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1763 - accuracy: 0.9628 - scaled_graph_loss: 0.0352 Epoch 45/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1721 - accuracy: 0.9657 - scaled_graph_loss: 0.0354 Epoch 46/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1386 - accuracy: 0.9726 - scaled_graph_loss: 0.0325 Epoch 47/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1458 - accuracy: 0.9669 - scaled_graph_loss: 0.0332 Epoch 48/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1211 - accuracy: 0.9815 - scaled_graph_loss: 0.0334 Epoch 49/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1281 - accuracy: 0.9786 - scaled_graph_loss: 0.0326 Epoch 50/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1285 - accuracy: 0.9814 - scaled_graph_loss: 0.0343 Epoch 51/100 17/17 
[==============================] - 0s 11ms/step - loss: 0.1317 - accuracy: 0.9748 - scaled_graph_loss: 0.0355 Epoch 52/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1420 - accuracy: 0.9706 - scaled_graph_loss: 0.0343 Epoch 53/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1395 - accuracy: 0.9715 - scaled_graph_loss: 0.0338 Epoch 54/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1260 - accuracy: 0.9758 - scaled_graph_loss: 0.0350 Epoch 55/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1261 - accuracy: 0.9778 - scaled_graph_loss: 0.0321 Epoch 56/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1205 - accuracy: 0.9793 - scaled_graph_loss: 0.0341 Epoch 57/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1170 - accuracy: 0.9814 - scaled_graph_loss: 0.0337 Epoch 58/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1223 - accuracy: 0.9715 - scaled_graph_loss: 0.0338 Epoch 59/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1181 - accuracy: 0.9737 - scaled_graph_loss: 0.0332 Epoch 60/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1104 - accuracy: 0.9827 - scaled_graph_loss: 0.0341 Epoch 61/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0919 - accuracy: 0.9840 - scaled_graph_loss: 0.0339 Epoch 62/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0999 - accuracy: 0.9838 - scaled_graph_loss: 0.0331 Epoch 63/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1078 - accuracy: 0.9833 - scaled_graph_loss: 0.0339 Epoch 64/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0994 - accuracy: 0.9854 - scaled_graph_loss: 0.0324 Epoch 65/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1016 - accuracy: 0.9820 - scaled_graph_loss: 0.0355 Epoch 66/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0962 - accuracy: 0.9859 - scaled_graph_loss: 0.0327 Epoch 67/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0999 - accuracy: 0.9846 - scaled_graph_loss: 0.0345 Epoch 68/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1032 - accuracy: 0.9823 - scaled_graph_loss: 0.0333 Epoch 69/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1035 - accuracy: 0.9828 - scaled_graph_loss: 0.0349 Epoch 70/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1052 - accuracy: 0.9828 - scaled_graph_loss: 0.0344 Epoch 71/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0921 - accuracy: 0.9874 - scaled_graph_loss: 0.0329 Epoch 72/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0907 - accuracy: 0.9867 - scaled_graph_loss: 0.0344 Epoch 73/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0982 - accuracy: 0.9851 - scaled_graph_loss: 0.0344 Epoch 74/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0892 - accuracy: 0.9817 - scaled_graph_loss: 0.0319 Epoch 75/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0966 - accuracy: 0.9835 - scaled_graph_loss: 0.0345 Epoch 76/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0888 - accuracy: 0.9888 - scaled_graph_loss: 0.0339 Epoch 77/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1019 - accuracy: 0.9800 - scaled_graph_loss: 0.0330 Epoch 78/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0931 - 
accuracy: 0.9806 - scaled_graph_loss: 0.0334 Epoch 79/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0862 - accuracy: 0.9848 - scaled_graph_loss: 0.0351 Epoch 80/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0852 - accuracy: 0.9870 - scaled_graph_loss: 0.0321 Epoch 81/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0927 - accuracy: 0.9837 - scaled_graph_loss: 0.0344 Epoch 82/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0954 - accuracy: 0.9826 - scaled_graph_loss: 0.0370 Epoch 83/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0810 - accuracy: 0.9904 - scaled_graph_loss: 0.0333 Epoch 84/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0827 - accuracy: 0.9874 - scaled_graph_loss: 0.0304 Epoch 85/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0898 - accuracy: 0.9854 - scaled_graph_loss: 0.0330 Epoch 86/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0829 - accuracy: 0.9864 - scaled_graph_loss: 0.0332 Epoch 87/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0784 - accuracy: 0.9893 - scaled_graph_loss: 0.0336 Epoch 88/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0811 - accuracy: 0.9876 - scaled_graph_loss: 0.0321 Epoch 89/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0811 - accuracy: 0.9887 - scaled_graph_loss: 0.0327 Epoch 90/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0813 - accuracy: 0.9856 - scaled_graph_loss: 0.0342 Epoch 91/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0765 - accuracy: 0.9896 - scaled_graph_loss: 0.0333 Epoch 92/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0888 - accuracy: 0.9814 - scaled_graph_loss: 0.0342 Epoch 93/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0796 - accuracy: 0.9843 - scaled_graph_loss: 0.0329 Epoch 94/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0784 - accuracy: 0.9859 - scaled_graph_loss: 0.0333 Epoch 95/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0842 - accuracy: 0.9882 - scaled_graph_loss: 0.0332 Epoch 96/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0739 - accuracy: 0.9920 - scaled_graph_loss: 0.0337 Epoch 97/100 17/17 [==============================] - 0s 12ms/step - loss: 0.0810 - accuracy: 0.9857 - scaled_graph_loss: 0.0347 Epoch 98/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0856 - accuracy: 0.9871 - scaled_graph_loss: 0.0356 Epoch 99/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0810 - accuracy: 0.9879 - scaled_graph_loss: 0.0305 Epoch 100/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0735 - accuracy: 0.9882 - scaled_graph_loss: 0.0339 <tensorflow.python.keras.callbacks.History at 0x7fc8206484a8>
Evaluate the MLP model with graph regularization
eval_results = dict(
zip(graph_reg_model.metrics_names,
graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 8ms/step - loss: 1.3513 - accuracy: 0.7946 Eval accuracy for MLP + graph regularization : 0.8010849952697754 Eval loss for MLP + graph regularization : 1.349246859550476
The accuracy of the graph-regularized model is about 2-3% higher than that of the base model (`base_model`).
Conclusion
We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs based on sample embeddings before training a neural network with graph regularization; that approach is useful when the input does not contain an explicit graph.
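For reference, graph synthesis in that setting can be done with NSL's graph-building utility. The following is a minimal sketch, with illustrative file paths and similarity threshold:
import neural_structured_learning as nsl
# Connect samples whose embeddings have cosine similarity >= 0.8.
# Embeddings are read from TFRecord files of tf.train.Examples, and the
# resulting edges are written as a TSV file.
nsl.tools.build_graph(['/tmp/embeddings.tfr'], '/tmp/graph.tsv',
                      similarity_threshold=0.8)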
We encourage users to experiment further by varying the amount of supervision as well as trying different neural architectures for graph regularization.