عرض على TensorFlow.org | تشغيل في Google Colab | عرض المصدر على جيثب | تحميل دفتر |
ملخص
الرسم البياني تنظيم هو أسلوب محدد في إطار نموذج أوسع من العصبية الرسم البياني التعلم ( بوي وآخرون، 2018 ). الفكرة الأساسية هي تدريب نماذج الشبكة العصبية بهدف منظم للرسم البياني ، وتسخير البيانات المصنفة وغير المصنفة.
في هذا البرنامج التعليمي ، سوف نستكشف استخدام تنظيم الرسم البياني لتصنيف المستندات التي تشكل رسمًا بيانيًا طبيعيًا (عضويًا).
الوصفة العامة لإنشاء نموذج منظم للرسم البياني باستخدام إطار عمل التعلم المهيكل العصبي (NSL) هو كما يلي:
- قم بتوليد بيانات التدريب من الرسم البياني للإدخال وميزات العينة. تتوافق العقد في الرسم البياني مع العينات وتتوافق الحواف في الرسم البياني مع التشابه بين أزواج العينات. ستحتوي بيانات التدريب الناتجة على ميزات الجوار بالإضافة إلى ميزات العقدة الأصلية.
- إنشاء الشبكة العصبية نموذجا قاعدة باستخدام
Keras
متتابعة، وظيفية، أو API فئة فرعية. - التفاف قاعدة نموذجية مع
GraphRegularization
فئة المجمع، والتي يتم توفيرها من قبل إطار قانون الأمن القومي، لإنشاء رسم بياني جديدKeras
نموذج. سيتضمن هذا النموذج الجديد خسارة تنظيم الرسم البياني كمصطلح التنظيم في هدف التدريب الخاص به. - تدريب وتقييم الرسم البياني
Keras
نموذج.
يثبت
قم بتثبيت حزمة التعلم المهيكل العصبي.
pip install --quiet neural-structured-learning
التبعيات والواردات
import neural_structured_learning as nsl
import tensorflow as tf
# Resets notebook state
tf.keras.backend.clear_session()
print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
"GPU is",
"available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")
Version: 2.8.0-rc0 Eager mode: True GPU is NOT AVAILABLE 2022-01-05 12:39:27.704660: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
مجموعة بيانات كورا
و رقة العمل كورا هو الرسم البياني الاقتباس حيث تمثل العقد أوراق تعلم الآلة وحواف تمثل الاستشهادات بين أزواج من الأوراق. المهمة المتضمنة هي تصنيف الوثيقة حيث الهدف هو تصنيف كل ورقة في واحدة من سبع فئات. بمعنى آخر ، هذه مشكلة تصنيف متعددة الفئات تتكون من 7 فئات.
رسم بياني
الرسم البياني الأصلي موجه. ومع ذلك ، لغرض هذا المثال ، فإننا نعتبر النسخة غير الموجهة من هذا الرسم البياني. لذلك ، إذا استشهدت الورقة (أ) بالورقة (ب) ، فإننا نعتبر أيضًا أن الورقة (ب) قد استشهدت بـ (أ) على الرغم من أن هذا ليس صحيحًا بالضرورة ، في هذا المثال ، فإننا نعتبر الاستشهادات وكيلًا للتشابه ، والتي تكون عادةً خاصية تبادلية.
سمات
تحتوي كل ورقة في الإدخال بشكل فعال على ميزتين:
كلمات: A كثيفة ومتعددة الساخن حقيبة من بين الكلمات تمثيل النص في ورقة. تحتوي مفردات مجموعة بيانات Cora على 1433 كلمة فريدة. لذا ، فإن طول هذه الميزة هو 1433 ، والقيمة في الموضع "i" تساوي 0/1 مما يشير إلى ما إذا كانت الكلمة "i" في المفردات موجودة في ورقة معينة أم لا.
التسمية: عدد صحيح واحد يمثل معرف فئة (فئة) من ورقة.
قم بتنزيل مجموعة بيانات Cora
wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz
cora/ cora/README cora/cora.cites cora/cora.content
قم بتحويل بيانات Cora إلى تنسيق NSL
من أجل المعالجة المسبقة لمجموعة البيانات كورا وتحويله إلى الشكل المطلوب عن طريق التعلم العصبية الهيكلية، ونحن سوف تشغيل "preprocess_cora_dataset.py 'النصي، والتي يتم تضمينها في مستودع الأمن القومي جيثب. يقوم هذا البرنامج النصي بما يلي:
- قم بإنشاء ميزات الجوار باستخدام ميزات العقدة الأصلية والرسم البياني.
- توليد القطار وبيانات الاختبار الانقسامات التي تحتوي على
tf.train.Example
الحالات. - تستمر القطار الناتجة عنها، وبيانات الاختبار في
TFRecord
الشكل.
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr
--2022-01-05 12:39:28-- https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ... Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected. HTTP request sent, awaiting response... 200 OK Length: 11640 (11K) [text/plain] Saving to: ‘preprocess_cora_dataset.py’ preprocess_cora_dat 100%[===================>] 11.37K --.-KB/s in 0s 2022-01-05 12:39:28 (78.9 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640] 2022-01-05 12:39:31.378912: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected Reading graph file: /tmp/cora/cora.cites... Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds). Making all edges bi-directional... Done (0.01 seconds). Total graph nodes: 2708 Joining seed and neighbor tf.train.Examples with graph edges... Done creating and writing 2155 merged tf.train.Examples (1.36 seconds). Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)] Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr. Output test data written to TFRecord file: /tmp/cora/test_examples.tfr. Total running time: 0.04 minutes.
المتغيرات العالمية
وتستند مسارات الملفات لبيانات القطار والاختبار على قيم علامة سطر الأوامر المستخدمة لاستدعاء "preprocess_cora_dataset.py" السيناريو أعلاه.
### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'
### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
Hyperparameters
سوف نستخدم مثيل HParams
لتشمل مختلف hyperparameters والثوابت المستخدمة للتدريب والتقييم. نصف بإيجاز كل منهم أدناه:
num_classes: هناك ما مجموعه 7 فئات مختلفة
max_seq_length: هذا هو حجم المفردات وجميع الحالات في مدخلات لها كثيف متعدد الساخن، وحقيبة من بين الكلمات التمثيل. بمعنى آخر ، تشير القيمة 1 لكلمة إلى أن الكلمة موجودة في الإدخال وتشير القيمة 0 إلى أنها ليست كذلك.
distance_type: هذه هي المسافة متري تستخدم لتنظيم العينة مع جيرانها.
graph_regularization_multiplier: يتحكم هذا الوزن النسبي لمصطلح الرسم البياني تسوية في فقدان وظيفة الكلية.
num_neighbors: عدد من الدول المجاورة تستخدم لتنظيم الرسم البياني. هذه القيمة يجب أن تكون أقل من أو يساوي
max_nbrs
قيادة خط الحجة المستخدمة أعلاه عند تشغيلpreprocess_cora_dataset.py
.num_fc_units: عدد طبقات مرتبطة ارتباطا كاملا في الشبكة العصبية لدينا.
train_epochs: عدد العهود التدريب.
حجم دفعة تستخدم للتدريب والتقييم: batch_size.
dropout_rate: التحكم في معدل التسرب بعد كل طبقة متصلة بشكل كامل
eval_steps: عدد دفعات لعملية قبل اعتبار تقييم كامل. إذا تم تعيين إلى
None
، يتم تقييم جميع الحالات في مجموعة الاختبار.
class HParams(object):
"""Hyperparameters used for training."""
def __init__(self):
### dataset parameters
self.num_classes = 7
self.max_seq_length = 1433
### neural graph learning parameters
self.distance_type = nsl.configs.DistanceType.L2
self.graph_regularization_multiplier = 0.1
self.num_neighbors = 1
### model architecture
self.num_fc_units = [50, 50]
### training parameters
self.train_epochs = 100
self.batch_size = 128
self.dropout_rate = 0.5
### eval parameters
self.eval_steps = None # All instances in the test set are evaluated.
HPARAMS = HParams()
تحميل بيانات القطار والاختبار
كما هو موضح سابقا في هذه المفكرة، وقد تم إنشاء بيانات التدريب المدخلات واختبار من قبل 'preprocess_cora_dataset.py. وسوف تحميلها على اثنين tf.data.Dataset
الكائنات - واحدة للقطار واحد للاختبار.
في طبقة مدخلات النموذج، فإننا سوف استخراج وليس فقط "كلمات" و "التسمية" ملامح من كل عينة، ولكن ملامح جار أيضا المقابلة على أساس hparams.num_neighbors
القيمة. الحالات مع الدول المجاورة أقل من hparams.num_neighbors
سيتم تعيين وهمية القيم لتلك الميزات الجار غير موجودة.
def make_dataset(file_path, training=False):
"""Creates a `tf.data.TFRecordDataset`.
Args:
file_path: Name of the file in the `.tfrecord` format containing
`tf.train.Example` objects.
training: Boolean indicating if we are in training mode.
Returns:
An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
objects.
"""
def parse_example(example_proto):
"""Extracts relevant fields from the `example_proto`.
Args:
example_proto: An instance of `tf.train.Example`.
Returns:
A pair whose first value is a dictionary containing relevant features
and whose second value contains the ground truth label.
"""
# The 'words' feature is a multi-hot, bag-of-words representation of the
# original raw text. A default value is required for examples that don't
# have the feature.
feature_spec = {
'words':
tf.io.FixedLenFeature([HPARAMS.max_seq_length],
tf.int64,
default_value=tf.constant(
0,
dtype=tf.int64,
shape=[HPARAMS.max_seq_length])),
'label':
tf.io.FixedLenFeature((), tf.int64, default_value=-1),
}
# We also extract corresponding neighbor features in a similar manner to
# the features above during training.
if training:
for i in range(HPARAMS.num_neighbors):
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
NBR_WEIGHT_SUFFIX)
feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
[HPARAMS.max_seq_length],
tf.int64,
default_value=tf.constant(
0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))
# We assign a default value of 0.0 for the neighbor weight so that
# graph regularization is done on samples based on their exact number
# of neighbors. In other words, non-existent neighbors are discounted.
feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
[1], tf.float32, default_value=tf.constant([0.0]))
features = tf.io.parse_single_example(example_proto, feature_spec)
label = features.pop('label')
return features, label
dataset = tf.data.TFRecordDataset([file_path])
if training:
dataset = dataset.shuffle(10000)
dataset = dataset.map(parse_example)
dataset = dataset.batch(HPARAMS.batch_size)
return dataset
train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)
دعنا نلقي نظرة خاطفة على مجموعة بيانات القطار لنلقي نظرة على محتوياتها.
for feature_batch, label_batch in train_dataset.take(1):
print('Feature list:', list(feature_batch.keys()))
print('Batch of inputs:', feature_batch['words'])
nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
print('Batch of neighbor weights:',
tf.reshape(feature_batch[nbr_weight_key], [-1]))
print('Batch of labels:', label_batch)
Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 1 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of neighbor weights: tf.Tensor( [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.], shape=(128,), dtype=float32) Batch of labels: tf.Tensor( [2 2 6 2 0 6 1 3 5 0 1 2 3 6 1 1 0 3 5 2 3 1 4 1 6 1 3 2 2 2 0 3 2 1 3 3 2 3 3 2 3 2 2 0 2 2 6 0 2 1 1 0 5 2 1 4 2 1 2 4 0 2 5 4 3 6 3 2 1 6 2 4 2 2 6 4 6 4 3 5 2 2 2 4 2 2 2 1 2 2 2 4 2 3 6 2 0 6 6 0 2 6 2 1 2 0 1 1 3 2 0 2 0 2 1 1 3 5 2 1 2 5 1 6 2 4 6 4], shape=(128,), dtype=int64)
دعنا نلقي نظرة خاطفة على مجموعة بيانات الاختبار لنلقي نظرة على محتوياتها.
for feature_batch, label_batch in test_dataset.take(1):
print('Feature list:', list(feature_batch.keys()))
print('Batch of inputs:', feature_batch['words'])
print('Batch of labels:', label_batch)
Feature list: ['words'] Batch of inputs: tf.Tensor( [[0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] ... [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0] [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64) Batch of labels: tf.Tensor( [5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)
تعريف النموذج
لتوضيح استخدام تنظيم الرسم البياني ، قمنا ببناء نموذج أساسي لهذه المشكلة أولاً. سنستخدم شبكة عصبونية بسيطة للتغذية الأمامية مع طبقتين مخفيتين ومنفصل بينهما. نحن لتوضيح إنشاء قاعدة نموذجية باستخدام جميع أنواع نموذج معتمد من قبل tf.Keras
الإطار - متتابعة، وظيفية، وفئة فرعية.
نموذج القاعدة المتسلسل
def make_mlp_sequential_model(hparams):
"""Creates a sequential multi-layer perceptron model."""
model = tf.keras.Sequential()
model.add(
tf.keras.layers.InputLayer(
input_shape=(hparams.max_seq_length,), name='words'))
# Input is already one-hot encoded in the integer format. We cast it to
# floating point format here.
model.add(
tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
for num_units in hparams.num_fc_units:
model.add(tf.keras.layers.Dense(num_units, activation='relu'))
# For sequential models, by default, Keras ensures that the 'dropout' layer
# is invoked only during training.
model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
model.add(tf.keras.layers.Dense(hparams.num_classes))
return model
نموذج أساسي وظيفي
def make_mlp_functional_model(hparams):
"""Creates a functional API-based multi-layer perceptron model."""
inputs = tf.keras.Input(
shape=(hparams.max_seq_length,), dtype='int64', name='words')
# Input is already one-hot encoded in the integer format. We cast it to
# floating point format here.
cur_layer = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.cast(x, tf.float32))(
inputs)
for num_units in hparams.num_fc_units:
cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
# For functional models, by default, Keras ensures that the 'dropout' layer
# is invoked only during training.
cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)
outputs = tf.keras.layers.Dense(hparams.num_classes)(cur_layer)
model = tf.keras.Model(inputs, outputs=outputs)
return model
نموذج قاعدة الفئة الفرعية
def make_mlp_subclass_model(hparams):
"""Creates a multi-layer perceptron subclass model in Keras."""
class MLP(tf.keras.Model):
"""Subclass model defining a multi-layer perceptron."""
def __init__(self):
super(MLP, self).__init__()
# Input is already one-hot encoded in the integer format. We create a
# layer to cast it to floating point format here.
self.cast_to_float_layer = tf.keras.layers.Lambda(
lambda x: tf.keras.backend.cast(x, tf.float32))
self.dense_layers = [
tf.keras.layers.Dense(num_units, activation='relu')
for num_units in hparams.num_fc_units
]
self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
self.output_layer = tf.keras.layers.Dense(hparams.num_classes)
def call(self, inputs, training=False):
cur_layer = self.cast_to_float_layer(inputs['words'])
for dense_layer in self.dense_layers:
cur_layer = dense_layer(cur_layer)
cur_layer = self.dropout_layer(cur_layer, training=training)
outputs = self.output_layer(cur_layer)
return outputs
return MLP()
إنشاء نموذج (نماذج) أساسي
# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= words (InputLayer) [(None, 1433)] 0 lambda (Lambda) (None, 1433) 0 dense (Dense) (None, 50) 71700 dropout (Dropout) (None, 50) 0 dense_1 (Dense) (None, 50) 2550 dropout_1 (Dropout) (None, 50) 0 dense_2 (Dense) (None, 7) 357 ================================================================= Total params: 74,607 Trainable params: 74,607 Non-trainable params: 0 _________________________________________________________________
نموذج تدريب قاعدة MLP
# Compile and train the base MLP model
base_model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/keras/engine/functional.py:559: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model. inputs = self._flatten_to_reference_inputs(inputs) 17/17 [==============================] - 1s 18ms/step - loss: 1.9521 - accuracy: 0.1838 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8590 - accuracy: 0.3044 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7770 - accuracy: 0.3601 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6655 - accuracy: 0.3898 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5386 - accuracy: 0.4543 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3856 - accuracy: 0.5077 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2736 - accuracy: 0.5531 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1636 - accuracy: 0.5889 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0654 - accuracy: 0.6385 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9703 - accuracy: 0.6761 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8689 - accuracy: 0.7104 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7704 - accuracy: 0.7494 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7157 - accuracy: 0.7810 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 0.6296 - accuracy: 0.8186 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5932 - accuracy: 0.8167 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5526 - accuracy: 0.8464 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5112 - accuracy: 0.8445 Epoch 18/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4624 - accuracy: 0.8613 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4163 - accuracy: 0.8696 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3808 - accuracy: 0.8849 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3564 - accuracy: 0.8933 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3453 - accuracy: 0.9002 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3226 - accuracy: 0.9114 Epoch 24/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3058 - accuracy: 0.9151 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2798 - accuracy: 0.9146 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2638 - accuracy: 0.9248 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2538 - accuracy: 0.9290 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2356 - accuracy: 0.9411 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2080 - accuracy: 0.9425 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2172 - accuracy: 0.9364 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2259 - accuracy: 0.9225 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1944 - accuracy: 0.9480 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1892 - accuracy: 0.9434 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1718 - accuracy: 0.9592 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1826 - accuracy: 0.9508 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1585 - accuracy: 0.9559 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1605 - accuracy: 0.9545 Epoch 38/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1529 - accuracy: 0.9550 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1411 - accuracy: 0.9615 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1366 - accuracy: 0.9624 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1431 - accuracy: 0.9578 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1241 - accuracy: 0.9619 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1310 - accuracy: 0.9661 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1284 - accuracy: 0.9652 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1215 - accuracy: 0.9633 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1130 - accuracy: 0.9722 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1074 - accuracy: 0.9722 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1143 - accuracy: 0.9694 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1015 - accuracy: 0.9740 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1077 - accuracy: 0.9698 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1035 - accuracy: 0.9684 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1076 - accuracy: 0.9694 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1000 - accuracy: 0.9689 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0967 - accuracy: 0.9749 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0994 - accuracy: 0.9703 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0943 - accuracy: 0.9740 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0923 - accuracy: 0.9735 Epoch 58/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0848 - accuracy: 0.9800 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0836 - accuracy: 0.9782 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0913 - accuracy: 0.9735 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0823 - accuracy: 0.9773 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0753 - accuracy: 0.9810 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0746 - accuracy: 0.9777 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0861 - accuracy: 0.9731 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0765 - accuracy: 0.9787 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0750 - accuracy: 0.9791 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0725 - accuracy: 0.9814 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0762 - accuracy: 0.9791 Epoch 69/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0645 - accuracy: 0.9842 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0606 - accuracy: 0.9861 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0775 - accuracy: 0.9805 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0655 - accuracy: 0.9800 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0629 - accuracy: 0.9833 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0625 - accuracy: 0.9824 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0607 - accuracy: 0.9838 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0578 - accuracy: 0.9824 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0568 - accuracy: 0.9842 Epoch 78/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0595 - accuracy: 0.9833 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0615 - accuracy: 0.9842 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0555 - accuracy: 0.9852 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0517 - accuracy: 0.9870 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0541 - accuracy: 0.9856 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0533 - accuracy: 0.9884 Epoch 84/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0509 - accuracy: 0.9838 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0600 - accuracy: 0.9828 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0617 - accuracy: 0.9800 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0599 - accuracy: 0.9800 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0502 - accuracy: 0.9870 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0416 - accuracy: 0.9907 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0542 - accuracy: 0.9842 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0490 - accuracy: 0.9847 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0374 - accuracy: 0.9916 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0467 - accuracy: 0.9893 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0426 - accuracy: 0.9879 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0543 - accuracy: 0.9861 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0420 - accuracy: 0.9870 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0461 - accuracy: 0.9861 Epoch 98/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0425 - accuracy: 0.9898 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0406 - accuracy: 0.9907 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0486 - accuracy: 0.9847 <keras.callbacks.History at 0x7f6f9d5eacd0>
تقييم نموذج MLP الأساسي
# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
"""Prints evaluation metrics.
Args:
model_desc: A description of the model.
eval_metrics: A dictionary mapping metric names to corresponding values. It
must contain the loss and accuracy metrics.
"""
print('\n')
print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
if 'graph_loss' in eval_metrics:
print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
zip(base_model.metrics_names,
base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 1.4192 - accuracy: 0.7939 Eval accuracy for Base MLP model : 0.7938517332077026 Eval loss for Base MLP model : 1.4192423820495605
تدريب نموذج MLP مع تنظيم الرسم البياني
دمج الرسم البياني تسوية في المدى فقدان موجود tf.Keras.Model
يتطلب فقط بضعة أسطر من التعليمات البرمجية. هو التفاف قاعدة نموذجية لخلق جديد tf.Keras
نموذج فئة فرعية، التي تشمل الرسم البياني تسوية الخسائر.
لتقييم الفائدة المتزايدة لتسوية الرسم البياني ، سننشئ مثيل نموذج أساسي جديد. وذلك لأن base_model
تم بالفعل تدريب لبضعة التكرار، وسوف إعادة استخدام هذا النموذج مدربة على إنشاء نموذج تنظيما الرسم البياني لا تكون المقارنة عادلة ل base_model
.
# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
max_neighbors=HPARAMS.num_neighbors,
multiplier=HPARAMS.graph_regularization_multiplier,
distance_type=HPARAMS.distance_type,
sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
graph_reg_config)
graph_reg_model.compile(
optimizer='adam',
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)
Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/framework/indexed_slices.py:446: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape:0", shape=(None, 7), dtype=float32), dense_shape=Tensor("gradient_tape/GraphRegularization/graph_loss/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) 17/17 [==============================] - 2s 4ms/step - loss: 1.9798 - accuracy: 0.1601 - scaled_graph_loss: 0.0373 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.9024 - accuracy: 0.2979 - scaled_graph_loss: 0.0254 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8623 - accuracy: 0.3160 - scaled_graph_loss: 0.0317 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8042 - accuracy: 0.3443 - scaled_graph_loss: 0.0498 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7552 - accuracy: 0.3582 - scaled_graph_loss: 0.0696 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7012 - accuracy: 0.4084 - scaled_graph_loss: 0.0866 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6578 - accuracy: 0.4515 - scaled_graph_loss: 0.1114 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6058 - accuracy: 0.5039 - scaled_graph_loss: 0.1300 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5498 - accuracy: 0.5434 - scaled_graph_loss: 0.1508 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5098 - accuracy: 0.6019 - scaled_graph_loss: 0.1651 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4746 - accuracy: 0.6302 - scaled_graph_loss: 0.1844 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4315 - accuracy: 0.6520 - scaled_graph_loss: 0.1917 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3932 - accuracy: 0.6770 - scaled_graph_loss: 0.2024 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3645 - accuracy: 0.7183 - scaled_graph_loss: 0.2145 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3265 - accuracy: 0.7369 - scaled_graph_loss: 0.2324 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3045 - accuracy: 0.7555 - scaled_graph_loss: 0.2358 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2836 - accuracy: 0.7652 - scaled_graph_loss: 0.2404 Epoch 18/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2456 - accuracy: 0.7898 - scaled_graph_loss: 0.2469 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2348 - accuracy: 0.8074 - scaled_graph_loss: 0.2615 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2000 - accuracy: 0.8074 - scaled_graph_loss: 0.2542 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1994 - accuracy: 0.8260 - scaled_graph_loss: 0.2729 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1825 - accuracy: 0.8269 - scaled_graph_loss: 0.2676 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1598 - accuracy: 0.8455 - scaled_graph_loss: 0.2742 Epoch 24/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1543 - accuracy: 0.8534 - scaled_graph_loss: 0.2797 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1456 - accuracy: 0.8552 - scaled_graph_loss: 0.2714 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8566 - scaled_graph_loss: 0.2796 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1150 - accuracy: 0.8687 - scaled_graph_loss: 0.2850 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1154 - accuracy: 0.8626 - scaled_graph_loss: 0.2772 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0806 - accuracy: 0.8733 - scaled_graph_loss: 0.2756 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0828 - accuracy: 0.8626 - scaled_graph_loss: 0.2907 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0724 - accuracy: 0.8886 - scaled_graph_loss: 0.2834 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0589 - accuracy: 0.8826 - scaled_graph_loss: 0.2881 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0490 - accuracy: 0.8872 - scaled_graph_loss: 0.2972 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0550 - accuracy: 0.8923 - scaled_graph_loss: 0.2935 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0397 - accuracy: 0.8840 - scaled_graph_loss: 0.2795 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0360 - accuracy: 0.8891 - scaled_graph_loss: 0.2966 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0235 - accuracy: 0.8961 - scaled_graph_loss: 0.2890 Epoch 38/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0219 - accuracy: 0.8984 - scaled_graph_loss: 0.2965 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0168 - accuracy: 0.9044 - scaled_graph_loss: 0.3023 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0148 - accuracy: 0.9035 - scaled_graph_loss: 0.2984 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9118 - scaled_graph_loss: 0.2888 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0019 - accuracy: 0.9021 - scaled_graph_loss: 0.2877 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9956 - accuracy: 0.9049 - scaled_graph_loss: 0.2912 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9986 - accuracy: 0.9026 - scaled_graph_loss: 0.3040 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9939 - accuracy: 0.9067 - scaled_graph_loss: 0.3016 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9828 - accuracy: 0.9058 - scaled_graph_loss: 0.2877 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9629 - accuracy: 0.9137 - scaled_graph_loss: 0.2844 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9645 - accuracy: 0.9146 - scaled_graph_loss: 0.2933 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9752 - accuracy: 0.9165 - scaled_graph_loss: 0.3013 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9552 - accuracy: 0.9179 - scaled_graph_loss: 0.2865 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9539 - accuracy: 0.9193 - scaled_graph_loss: 0.3044 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9443 - accuracy: 0.9183 - scaled_graph_loss: 0.3010 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9559 - accuracy: 0.9244 - scaled_graph_loss: 0.2987 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9497 - accuracy: 0.9225 - scaled_graph_loss: 0.2979 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9674 - accuracy: 0.9183 - scaled_graph_loss: 0.3034 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9537 - accuracy: 0.9174 - scaled_graph_loss: 0.2834 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9341 - accuracy: 0.9188 - scaled_graph_loss: 0.2939 Epoch 58/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9392 - accuracy: 0.9225 - scaled_graph_loss: 0.2998 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9240 - accuracy: 0.9313 - scaled_graph_loss: 0.3022 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9368 - accuracy: 0.9267 - scaled_graph_loss: 0.2979 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9306 - accuracy: 0.9234 - scaled_graph_loss: 0.2952 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9197 - accuracy: 0.9230 - scaled_graph_loss: 0.2916 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9360 - accuracy: 0.9206 - scaled_graph_loss: 0.2947 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9181 - accuracy: 0.9299 - scaled_graph_loss: 0.2996 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9105 - accuracy: 0.9341 - scaled_graph_loss: 0.2981 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9014 - accuracy: 0.9323 - scaled_graph_loss: 0.2897 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9059 - accuracy: 0.9364 - scaled_graph_loss: 0.3083 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9053 - accuracy: 0.9309 - scaled_graph_loss: 0.2976 Epoch 69/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9099 - accuracy: 0.9258 - scaled_graph_loss: 0.3069 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9025 - accuracy: 0.9355 - scaled_graph_loss: 0.2890 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8849 - accuracy: 0.9281 - scaled_graph_loss: 0.2933 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8959 - accuracy: 0.9323 - scaled_graph_loss: 0.2918 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9074 - accuracy: 0.9248 - scaled_graph_loss: 0.3065 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8845 - accuracy: 0.9369 - scaled_graph_loss: 0.2874 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8873 - accuracy: 0.9401 - scaled_graph_loss: 0.2996 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8942 - accuracy: 0.9327 - scaled_graph_loss: 0.3086 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9052 - accuracy: 0.9253 - scaled_graph_loss: 0.2986 Epoch 78/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8811 - accuracy: 0.9336 - scaled_graph_loss: 0.2948 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8896 - accuracy: 0.9276 - scaled_graph_loss: 0.2919 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8853 - accuracy: 0.9313 - scaled_graph_loss: 0.2944 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8875 - accuracy: 0.9323 - scaled_graph_loss: 0.2925 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8639 - accuracy: 0.9323 - scaled_graph_loss: 0.2967 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8820 - accuracy: 0.9332 - scaled_graph_loss: 0.3047 Epoch 84/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8752 - accuracy: 0.9346 - scaled_graph_loss: 0.2942 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9374 - scaled_graph_loss: 0.3066 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8765 - accuracy: 0.9332 - scaled_graph_loss: 0.2881 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8691 - accuracy: 0.9420 - scaled_graph_loss: 0.3030 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8631 - accuracy: 0.9374 - scaled_graph_loss: 0.2916 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8651 - accuracy: 0.9392 - scaled_graph_loss: 0.3032 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8632 - accuracy: 0.9420 - scaled_graph_loss: 0.3019 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8600 - accuracy: 0.9425 - scaled_graph_loss: 0.2965 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8569 - accuracy: 0.9346 - scaled_graph_loss: 0.2977 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8704 - accuracy: 0.9374 - scaled_graph_loss: 0.3083 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8562 - accuracy: 0.9406 - scaled_graph_loss: 0.2883 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8545 - accuracy: 0.9415 - scaled_graph_loss: 0.3030 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8592 - accuracy: 0.9332 - scaled_graph_loss: 0.2927 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8503 - accuracy: 0.9397 - scaled_graph_loss: 0.2927 Epoch 98/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8434 - accuracy: 0.9462 - scaled_graph_loss: 0.2937 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8578 - accuracy: 0.9374 - scaled_graph_loss: 0.3064 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8504 - accuracy: 0.9411 - scaled_graph_loss: 0.3043 <keras.callbacks.History at 0x7f70041be650>
تقييم نموذج MLP مع تنظيم الرسم البياني
eval_results = dict(
zip(graph_reg_model.metrics_names,
graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)
5/5 [==============================] - 0s 5ms/step - loss: 0.8884 - accuracy: 0.7957 Eval accuracy for MLP + graph regularization : 0.7956600189208984 Eval loss for MLP + graph regularization : 0.8883611559867859
دقة طراز قننت الرسم البياني هو حوالي 2-3٪ أعلى من نموذج قاعدة ( base_model
).
استنتاج
لقد أظهرنا استخدام تنظيم الرسم البياني لتصنيف المستندات على رسم بياني اقتباس طبيعي (كورا) باستخدام إطار عمل التعلم الهيكلية العصبية (NSL). لدينا تعليمي متطور يشمل تجميع الرسوم البيانية استنادا التضمينات عينة قبل تدريب الشبكة العصبية مع تنظيم الرسم البياني. هذا الأسلوب مفيد إذا كان الإدخال لا يحتوي على رسم بياني صريح.
نحن نشجع المستخدمين على إجراء المزيد من التجارب من خلال تغيير مقدار الإشراف وكذلك تجربة البنى العصبية المختلفة لتنظيم الرسم البياني.