Graph regularization for document classification using natural graphs

Overview

Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.

In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.

The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows:

  1. Generate training data from the input graph and sample features. Nodes in the graph correspond to samples and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
  2. Create a neural network as a base model using the Keras sequential, functional, or subclass API.
  3. Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
  4. Train and evaluate the graph Keras model.
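
As a quick preview, the following condensed sketch illustrates steps 2-4. The base model architecture and hyperparameter values here are placeholders for illustration only; the actual models and values used in this tutorial are defined later.

# Sketch of the NSL recipe (placeholder architecture and hyperparameters).
import neural_structured_learning as nsl
import tensorflow as tf

# Step 2: build any Keras model whose named inputs match the sample features.
base_model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(1433,), name='words'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(7),
])

# Step 3: wrap the base model so that neighbor features contribute a graph
# regularization term to the loss.
graph_reg_config = nsl.configs.make_graph_reg_config(max_neighbors=1,
                                                     multiplier=0.1)
graph_model = nsl.keras.GraphRegularization(base_model, graph_reg_config)

# Step 4: compile and train on data that also contains 'NL_nbr_*' neighbor
# features (created in step 1).
graph_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
# graph_model.fit(train_dataset, epochs=...)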

Setup

Install the Neural Structured Learning package.

pip install --quiet neural-structured-learning

Dependencies and imports

import neural_structured_learning as nsl

import tensorflow as tf

# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")

2023-11-16 12:04:49.460421: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-16 12:04:49.460472: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-16 12:04:49.461916: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Version:  2.15.0
Eager mode:  True
GPU is NOT AVAILABLE
2023-11-16 12:04:51.768240: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected

Cora dataset

The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task is document classification, where the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.

Graph

The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B to have cited A. Although this is not strictly true, in this example we treat citations as a proxy for similarity, which is usually a symmetric relationship.
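
As a minimal illustration of this step (the preprocessing script used later handles it internally), an edge list can be made undirected by adding the reverse of every edge. The paper IDs below are made up:

def make_undirected(edges):
  """Returns the set of edges with the reverse of every edge added.

  Args:
    edges: An iterable of (source_id, target_id) pairs, e.g. citation edges.
  """
  undirected = set()
  for src, dst in edges:
    undirected.add((src, dst))
    undirected.add((dst, src))
  return undirected

# Example: paper 101 cites paper 202; treat the relationship as symmetric.
print(make_undirected([(101, 202)]))  # {(101, 202), (202, 101)}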

Features

Each paper in the input effectively contains 2 features:

  1. Words: A dense, multi-hot bag-of-words representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0/1 indicating whether word 'i' in the vocabulary exists in the given paper or not.
  2. Label: A single integer representing the class ID (category) of the paper.
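
For illustration, a multi-hot 'words' vector can be constructed from the indices of the words that appear in a paper. This is only a sketch; the word indices below are made up:

import tensorflow as tf

VOCAB_SIZE = 1433  # Size of the Cora vocabulary.

def to_multi_hot(word_indices, vocab_size=VOCAB_SIZE):
  """Returns a 0/1 vector with ones at the positions of the given words."""
  return tf.scatter_nd(
      indices=[[i] for i in word_indices],
      updates=tf.ones(len(word_indices), dtype=tf.int64),
      shape=[vocab_size])

# A hypothetical paper containing vocabulary words 3, 10, and 42.
print(to_multi_hot([3, 10, 42]))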

Download the Cora dataset

wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz

cora/
cora/README
cora/cora.cites
cora/cora.content

Convert the Cora data to the NSL format

In order to preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the **preprocess_cora_dataset.py** script, which is included in the NSL GitHub repository. This script does the following:

  1. Generate neighbor features using the original node features and the graph.
  2. Generate train and test data splits containing tf.train.Example instances.
  3. Persist the resulting train and test data in the TFRecord format.
!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py

!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr

--2023-11-16 12:04:52--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>] 11.37K --.-KB/s in 0s

2023-11-16 12:04:53 (75.6 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]

2023-11-16 12:04:53.758687: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-16 12:04:53.758743: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-16 12:04:53.760530: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2023-11-16 12:04:55.968449: E external/local_xla/xla/stream_executor/cuda/cuda_driver.cc:274] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.01 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.44 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.05 minutes.

Global variables

The file paths to the train and test data are based on the command line flag values used to invoke the 'preprocess_cora_dataset.py' script above.

### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'

### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'
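
These constants mirror the naming convention used by the preprocessing script: the features of the i-th neighbor are prefixed with 'NL_nbr_<i>_', and each neighbor's edge weight is stored under a key with the '_weight' suffix. As a small sketch of how such keys are formed (the same format expressions are used later in this notebook):

# Keys for the first (index 0) neighbor of a sample.
nbr_words_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
print(nbr_words_key, nbr_weight_key)  # NL_nbr_0_words NL_nbr_0_weight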

Hyperparameters

We will use an instance of HParams to hold the various hyperparameters and constants used for training and evaluation. Each of them is briefly described in the comments of the code below:

class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()

Load train and test data

As described earlier in this notebook, the input training and test data were created by the 'preprocess_cora_dataset.py' script. We will load them into two tf.data.Dataset objects -- one for train and one for test.

In the input layer of our model, we will extract not just the 'words' and the 'label' features from each sample, but also the corresponding neighbor features based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for those non-existent neighbor features.

def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)

Let's peek into the train dataset to look at its contents.

for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)

Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 1 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 ...
 1. 1.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[2 2 3 6 6 4 3 1 3 4 2 5 4 5 6 4 1 5 1 0 5 6 3 0 4 2 4 4 1 1 1 6 2 2 5 3 3
 5 3 2 0 0 1 5 5 0 4 6 1 4 2 0 2 4 4 1 3 2 2 2 1 2 2 5 2 2 4 1 2 6 1 6 3 0
 5 2 6 4 3 2 4 0 2 1 2 2 2 2 2 2 1 1 6 3 2 4 1 2 1 0 3 0 0 3 2 6 1 2 2 1 2
 2 2 3 2 0 2 3 2 5 3 0 1 1 2 0 2 1], shape=(128,), dtype=int64)

Let's peek into the test dataset to look at its contents.

for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)

Feature list: ['words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of labels: tf.Tensor(
[5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2
 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5
 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6
 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)

Model definition

In order to demonstrate the use of graph regularization, we first build a base model for this problem. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.keras framework -- sequential, functional, and subclass.

Sequential base model

def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes))
  return model

Functional base model

def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(hparams.num_classes)(cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model

Subclass base model

def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(hparams.num_classes)

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()

Create base model(s)

# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #

 words (InputLayer)          [(None, 1433)]            0

 lambda (Lambda)             (None, 1433)              0

 dense (Dense)               (None, 50)                71700

 dropout (Dropout)           (None, 50)                0

 dense_1 (Dense)             (None, 50)                2550

 dropout_1 (Dropout)         (None, 50)                0

 dense_2 (Dense)             (None, 7)                 357

Total params: 74607 (291.43 KB)
Trainable params: 74607 (291.43 KB)
Non-trainable params: 0 (0.00 Byte)
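
If you use the subclass variant instead, remember that its summary is only available after the model has been built, for example by calling it on a batch of inputs. A minimal sketch, reusing HPARAMS and make_mlp_subclass_model() defined above:

# Sketch: building the subclass model before calling summary().
subclass_model = make_mlp_subclass_model(HPARAMS)
# Calling the model on a dummy batch builds all of its layers.
dummy_batch = {'words': tf.zeros([1, HPARAMS.max_seq_length], dtype=tf.int64)}
_ = subclass_model(dummy_batch)
subclass_model.summary()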


Train base MLP model

# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)

Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.9/site-packages/keras/src/engine/functional.py:642: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model. inputs = self._flatten_to_reference_inputs(inputs) 17/17 [==============================] - 1s 6ms/step - loss: 1.9105 - accuracy: 0.2260 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8280 - accuracy: 0.3044 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7240 - accuracy: 0.3299 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5969 - accuracy: 0.3745 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4765 - accuracy: 0.4492 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3235 - accuracy: 0.5276 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1913 - accuracy: 0.5889 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0604 - accuracy: 0.6432 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9628 - accuracy: 0.6821 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8601 - accuracy: 0.7234 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7914 - accuracy: 0.7480 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 0.7230 - accuracy: 0.7633 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 0.6783 - accuracy: 0.7791 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 0.6019 - accuracy: 0.8070 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5587 - accuracy: 0.8367 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 0.5295 - accuracy: 0.8450 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4789 - accuracy: 0.8599 Epoch 18/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4474 - accuracy: 0.8650 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 0.4148 - accuracy: 0.8701 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3812 - accuracy: 0.8896 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3656 - accuracy: 0.8863 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3544 - accuracy: 0.8923 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 0.3050 - accuracy: 0.9165 Epoch 24/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2858 - accuracy: 0.9216 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2821 - accuracy: 0.9234 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2543 - accuracy: 0.9276 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2477 - accuracy: 0.9285 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2413 - accuracy: 0.9295 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2153 - accuracy: 0.9415 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2241 - accuracy: 0.9290 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2118 - accuracy: 0.9374 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 0.2041 - accuracy: 0.9471 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 
0.1951 - accuracy: 0.9392 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1841 - accuracy: 0.9443 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1783 - accuracy: 0.9522 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1742 - accuracy: 0.9485 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1705 - accuracy: 0.9541 Epoch 38/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1507 - accuracy: 0.9592 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1513 - accuracy: 0.9555 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1378 - accuracy: 0.9652 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1471 - accuracy: 0.9587 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1309 - accuracy: 0.9661 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1288 - accuracy: 0.9596 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1327 - accuracy: 0.9629 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1170 - accuracy: 0.9675 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1198 - accuracy: 0.9666 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1183 - accuracy: 0.9680 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1025 - accuracy: 0.9740 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0981 - accuracy: 0.9754 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1076 - accuracy: 0.9708 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0874 - accuracy: 0.9796 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.1027 - accuracy: 0.9735 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0993 - accuracy: 0.9740 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0934 - accuracy: 0.9759 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0932 - accuracy: 0.9759 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0787 - accuracy: 0.9810 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0890 - accuracy: 0.9754 Epoch 58/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0918 - accuracy: 0.9749 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0908 - accuracy: 0.9717 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0825 - accuracy: 0.9777 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0926 - accuracy: 0.9684 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0702 - accuracy: 0.9800 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0720 - accuracy: 0.9842 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0792 - accuracy: 0.9773 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0760 - accuracy: 0.9782 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0736 - accuracy: 0.9800 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0838 - accuracy: 0.9773 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0639 - accuracy: 0.9824 Epoch 69/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.0742 - accuracy: 0.9805 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0798 - accuracy: 0.9782 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0694 - accuracy: 0.9805 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0635 - accuracy: 0.9833 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0587 - accuracy: 0.9824 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0689 - accuracy: 0.9828 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0628 - accuracy: 0.9828 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0570 - accuracy: 0.9842 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0632 - accuracy: 0.9824 Epoch 78/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0673 - accuracy: 0.9782 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0573 - accuracy: 0.9828 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0640 - accuracy: 0.9824 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0610 - accuracy: 0.9810 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0553 - accuracy: 0.9861 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0482 - accuracy: 0.9879 Epoch 84/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0548 - accuracy: 0.9842 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0537 - accuracy: 0.9865 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0540 - accuracy: 0.9828 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0528 - accuracy: 0.9838 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0505 - accuracy: 0.9865 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0473 - accuracy: 0.9833 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0604 - accuracy: 0.9810 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0469 - accuracy: 0.9879 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0554 - accuracy: 0.9810 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0427 - accuracy: 0.9875 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0581 - accuracy: 0.9824 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0488 - accuracy: 0.9842 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0466 - accuracy: 0.9875 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0465 - accuracy: 0.9875 Epoch 98/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0411 - accuracy: 0.9879 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0539 - accuracy: 0.9852 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.0451 - accuracy: 0.9870 <keras.src.callbacks.History at 0x7f459c2e9e50>

Evaluate base MLP model

# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])
eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)

5/5 [==============================] - 0s 5ms/step - loss: 1.4164 - accuracy: 0.7758

Eval accuracy for  Base MLP model :  0.775768518447876
Eval loss for  Base MLP model :  1.4164185523986816

Train MLP model with graph regularization

Incorporating graph regularization into the loss term of an existing tf.keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.keras subclass model, whose loss includes graph regularization.

To assess the incremental benefit of graph regularization, we will create a new base model instance. This is because base_model has already been trained, and reusing that trained model to create a graph-regularized model would not be a fair comparison for base_model.

# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)

Epoch 1/100 17/17 [==============================] - 2s 7ms/step - loss: 1.9586 - accuracy: 0.2107 - scaled_graph_loss: 0.0319 Epoch 2/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8903 - accuracy: 0.2942 - scaled_graph_loss: 0.0282 Epoch 3/100 17/17 [==============================] - 0s 3ms/step - loss: 1.8290 - accuracy: 0.3262 - scaled_graph_loss: 0.0411 Epoch 4/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7762 - accuracy: 0.3248 - scaled_graph_loss: 0.0604 Epoch 5/100 17/17 [==============================] - 0s 3ms/step - loss: 1.7334 - accuracy: 0.3568 - scaled_graph_loss: 0.0792 Epoch 6/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6859 - accuracy: 0.3735 - scaled_graph_loss: 0.0920 Epoch 7/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6506 - accuracy: 0.3935 - scaled_graph_loss: 0.1086 Epoch 8/100 17/17 [==============================] - 0s 3ms/step - loss: 1.6028 - accuracy: 0.4520 - scaled_graph_loss: 0.1249 Epoch 9/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5690 - accuracy: 0.5012 - scaled_graph_loss: 0.1386 Epoch 10/100 17/17 [==============================] - 0s 3ms/step - loss: 1.5332 - accuracy: 0.5420 - scaled_graph_loss: 0.1577 Epoch 11/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4792 - accuracy: 0.5842 - scaled_graph_loss: 0.1642 Epoch 12/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4438 - accuracy: 0.6306 - scaled_graph_loss: 0.1909 Epoch 13/100 17/17 [==============================] - 0s 3ms/step - loss: 1.4155 - accuracy: 0.6617 - scaled_graph_loss: 0.2009 Epoch 14/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3596 - accuracy: 0.6896 - scaled_graph_loss: 0.1964 Epoch 15/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3462 - accuracy: 0.7077 - scaled_graph_loss: 0.2294 Epoch 16/100 17/17 [==============================] - 0s 3ms/step - loss: 1.3151 - accuracy: 0.7295 - scaled_graph_loss: 0.2312 Epoch 17/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2848 - accuracy: 0.7555 - scaled_graph_loss: 0.2319 Epoch 18/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2643 - accuracy: 0.7759 - scaled_graph_loss: 0.2469 Epoch 19/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2434 - accuracy: 0.7921 - scaled_graph_loss: 0.2544 Epoch 20/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2005 - accuracy: 0.8093 - scaled_graph_loss: 0.2473 Epoch 21/100 17/17 [==============================] - 0s 3ms/step - loss: 1.2007 - accuracy: 0.8070 - scaled_graph_loss: 0.2688 Epoch 22/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1876 - accuracy: 0.8135 - scaled_graph_loss: 0.2708 Epoch 23/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1729 - accuracy: 0.8274 - scaled_graph_loss: 0.2662 Epoch 24/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1543 - accuracy: 0.8376 - scaled_graph_loss: 0.2707 Epoch 25/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1228 - accuracy: 0.8538 - scaled_graph_loss: 0.2677 Epoch 26/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1166 - accuracy: 0.8603 - scaled_graph_loss: 0.2785 Epoch 27/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1176 - accuracy: 0.8473 - scaled_graph_loss: 0.2807 Epoch 28/100 17/17 [==============================] - 0s 3ms/step - loss: 1.1085 - accuracy: 0.8473 - 
scaled_graph_loss: 0.2649 Epoch 29/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0751 - accuracy: 0.8691 - scaled_graph_loss: 0.2858 Epoch 30/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0851 - accuracy: 0.8696 - scaled_graph_loss: 0.2996 Epoch 31/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0932 - accuracy: 0.8770 - scaled_graph_loss: 0.2892 Epoch 32/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0619 - accuracy: 0.8821 - scaled_graph_loss: 0.2880 Epoch 33/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0531 - accuracy: 0.8886 - scaled_graph_loss: 0.2847 Epoch 34/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0558 - accuracy: 0.8863 - scaled_graph_loss: 0.2962 Epoch 35/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0375 - accuracy: 0.8891 - scaled_graph_loss: 0.2780 Epoch 36/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0310 - accuracy: 0.8858 - scaled_graph_loss: 0.2932 Epoch 37/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0269 - accuracy: 0.8872 - scaled_graph_loss: 0.2916 Epoch 38/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0273 - accuracy: 0.8928 - scaled_graph_loss: 0.2948 Epoch 39/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9935 - accuracy: 0.9123 - scaled_graph_loss: 0.2910 Epoch 40/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0083 - accuracy: 0.9104 - scaled_graph_loss: 0.2951 Epoch 41/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0196 - accuracy: 0.8951 - scaled_graph_loss: 0.2982 Epoch 42/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9941 - accuracy: 0.9007 - scaled_graph_loss: 0.2898 Epoch 43/100 17/17 [==============================] - 0s 3ms/step - loss: 1.0069 - accuracy: 0.9012 - scaled_graph_loss: 0.3076 Epoch 44/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9816 - accuracy: 0.9049 - scaled_graph_loss: 0.2930 Epoch 45/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9910 - accuracy: 0.9104 - scaled_graph_loss: 0.2954 Epoch 46/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9949 - accuracy: 0.9026 - scaled_graph_loss: 0.3111 Epoch 47/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9715 - accuracy: 0.9114 - scaled_graph_loss: 0.2830 Epoch 48/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9796 - accuracy: 0.9067 - scaled_graph_loss: 0.2970 Epoch 49/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9570 - accuracy: 0.9114 - scaled_graph_loss: 0.2936 Epoch 50/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9691 - accuracy: 0.9049 - scaled_graph_loss: 0.2940 Epoch 51/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9803 - accuracy: 0.9114 - scaled_graph_loss: 0.3083 Epoch 52/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9612 - accuracy: 0.9128 - scaled_graph_loss: 0.2860 Epoch 53/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9627 - accuracy: 0.9216 - scaled_graph_loss: 0.3077 Epoch 54/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9516 - accuracy: 0.9151 - scaled_graph_loss: 0.2906 Epoch 55/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9431 - accuracy: 0.9197 - scaled_graph_loss: 0.2967 Epoch 56/100 17/17 [==============================] - 0s 3ms/step - 
loss: 0.9622 - accuracy: 0.9132 - scaled_graph_loss: 0.3053 Epoch 57/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9410 - accuracy: 0.9188 - scaled_graph_loss: 0.2830 Epoch 58/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9531 - accuracy: 0.9230 - scaled_graph_loss: 0.3049 Epoch 59/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9309 - accuracy: 0.9193 - scaled_graph_loss: 0.3009 Epoch 60/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9300 - accuracy: 0.9248 - scaled_graph_loss: 0.2988 Epoch 61/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9173 - accuracy: 0.9244 - scaled_graph_loss: 0.2884 Epoch 62/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9228 - accuracy: 0.9248 - scaled_graph_loss: 0.2960 Epoch 63/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9394 - accuracy: 0.9174 - scaled_graph_loss: 0.3102 Epoch 64/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9182 - accuracy: 0.9174 - scaled_graph_loss: 0.2899 Epoch 65/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9276 - accuracy: 0.9253 - scaled_graph_loss: 0.2996 Epoch 66/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9229 - accuracy: 0.9244 - scaled_graph_loss: 0.2912 Epoch 67/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9325 - accuracy: 0.9142 - scaled_graph_loss: 0.3088 Epoch 68/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9091 - accuracy: 0.9216 - scaled_graph_loss: 0.2883 Epoch 69/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8987 - accuracy: 0.9267 - scaled_graph_loss: 0.2924 Epoch 70/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9188 - accuracy: 0.9216 - scaled_graph_loss: 0.2970 Epoch 71/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9003 - accuracy: 0.9299 - scaled_graph_loss: 0.2962 Epoch 72/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9086 - accuracy: 0.9206 - scaled_graph_loss: 0.2944 Epoch 73/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9047 - accuracy: 0.9304 - scaled_graph_loss: 0.3174 Epoch 74/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9214 - accuracy: 0.9202 - scaled_graph_loss: 0.2923 Epoch 75/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9081 - accuracy: 0.9276 - scaled_graph_loss: 0.3020 Epoch 76/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9043 - accuracy: 0.9220 - scaled_graph_loss: 0.2892 Epoch 77/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9022 - accuracy: 0.9253 - scaled_graph_loss: 0.2998 Epoch 78/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8871 - accuracy: 0.9332 - scaled_graph_loss: 0.2979 Epoch 79/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8863 - accuracy: 0.9295 - scaled_graph_loss: 0.3021 Epoch 80/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8893 - accuracy: 0.9225 - scaled_graph_loss: 0.2928 Epoch 81/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8850 - accuracy: 0.9258 - scaled_graph_loss: 0.2997 Epoch 82/100 17/17 [==============================] - 0s 3ms/step - loss: 0.9013 - accuracy: 0.9165 - scaled_graph_loss: 0.2961 Epoch 83/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8739 - accuracy: 0.9253 - scaled_graph_loss: 0.2886 Epoch 84/100 17/17 
[==============================] - 0s 3ms/step - loss: 0.8840 - accuracy: 0.9318 - scaled_graph_loss: 0.3040 Epoch 85/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8628 - accuracy: 0.9378 - scaled_graph_loss: 0.2886 Epoch 86/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8745 - accuracy: 0.9313 - scaled_graph_loss: 0.3013 Epoch 87/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8678 - accuracy: 0.9327 - scaled_graph_loss: 0.2980 Epoch 88/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8614 - accuracy: 0.9397 - scaled_graph_loss: 0.2947 Epoch 89/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8589 - accuracy: 0.9327 - scaled_graph_loss: 0.2957 Epoch 90/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8688 - accuracy: 0.9346 - scaled_graph_loss: 0.2996 Epoch 91/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8661 - accuracy: 0.9216 - scaled_graph_loss: 0.2881 Epoch 92/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8828 - accuracy: 0.9318 - scaled_graph_loss: 0.3019 Epoch 93/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8701 - accuracy: 0.9374 - scaled_graph_loss: 0.3051 Epoch 94/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8572 - accuracy: 0.9383 - scaled_graph_loss: 0.2998 Epoch 95/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8765 - accuracy: 0.9327 - scaled_graph_loss: 0.2999 Epoch 96/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8685 - accuracy: 0.9336 - scaled_graph_loss: 0.3013 Epoch 97/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8710 - accuracy: 0.9378 - scaled_graph_loss: 0.3023 Epoch 98/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8746 - accuracy: 0.9327 - scaled_graph_loss: 0.2956 Epoch 99/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8642 - accuracy: 0.9341 - scaled_graph_loss: 0.2984 Epoch 100/100 17/17 [==============================] - 0s 3ms/step - loss: 0.8638 - accuracy: 0.9318 - scaled_graph_loss: 0.2965 <keras.src.callbacks.History at 0x7f445862f130>

Evaluate MLP model with graph regularization

eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)

5/5 [==============================] - 0s 5ms/step - loss: 0.8791 - accuracy: 0.7993

Eval accuracy for  MLP + graph regularization :  0.7992766499519348
Eval loss for  MLP + graph regularization :  0.8790676593780518

The graph-regularized model's accuracy is about 2-3% higher than that of the base model (base_model).

Conclusion

We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial involves synthesizing graphs based on sample embeddings before training a neural network with graph regularization. This approach is useful if the input does not contain an explicit graph.

We encourage users to experiment further by varying the amount of supervision as well as trying different neural architectures for graph regularization.
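
For example, one straightforward experiment is to sweep the graph regularization multiplier and compare test accuracy. A rough sketch, reusing HPARAMS, train_dataset, test_dataset, and make_mlp_functional_model() from earlier in this notebook (the multiplier values below are illustrative):

# Sketch: a simple sweep over the graph regularization multiplier.
for multiplier in [0.0, 0.1, 0.5, 1.0]:
  config = nsl.configs.make_graph_reg_config(
      max_neighbors=HPARAMS.num_neighbors,
      multiplier=multiplier,
      distance_type=HPARAMS.distance_type,
      sum_over_axis=-1)
  model = nsl.keras.GraphRegularization(make_mlp_functional_model(HPARAMS),
                                        config)
  model.compile(
      optimizer='adam',
      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
      metrics=['accuracy'])
  model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=0)
  results = dict(
      zip(model.metrics_names, model.evaluate(test_dataset, verbose=0)))
  print('multiplier={}: test accuracy={:.4f}'.format(multiplier,
                                                     results['accuracy']))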