Graph regularization for document classification using natural graphs

Overview

Graph regularization is a specific technique under the broader paradigm of Neural Graph Learning (Bui et al., 2018). The core idea is to train neural network models with a graph-regularized objective, harnessing both labeled and unlabeled data.

In this tutorial, we will explore the use of graph regularization to classify documents that form a natural (organic) graph.

The general recipe for creating a graph-regularized model using the Neural Structured Learning (NSL) framework is as follows (a minimal end-to-end sketch in code follows the list):

  1. Generate training data from the input graph and sample features. Nodes in the graph correspond to samples, and edges in the graph correspond to similarity between pairs of samples. The resulting training data will contain neighbor features in addition to the original node features.
  2. Create a neural network as a base model using the Keras sequential, functional, or subclass API.
  3. Wrap the base model with the GraphRegularization wrapper class, which is provided by the NSL framework, to create a new graph Keras model. This new model will include a graph regularization loss as the regularization term in its training objective.
  4. Train and evaluate the graph Keras model.
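Before moving on to the Cora workflow, here is a minimal, hedged sketch of the four steps above in code. The file paths and the tiny base model are hypothetical placeholders; this tutorial itself performs step 1 with the preprocess_cora_dataset.py script rather than by calling nsl.tools.pack_nbrs directly.

import neural_structured_learning as nsl
import tensorflow as tf

# Step 1: join each labeled example with up to `max_nbrs` of its graph
# neighbors (hypothetical file paths).
nsl.tools.pack_nbrs(
    '/tmp/labeled_examples.tfr',  # labeled tf.train.Example records
    '',                           # no unlabeled examples in this sketch
    '/tmp/graph.tsv',             # TSV edge list: source, target[, weight]
    '/tmp/nsl_train_data.tfr',    # output: merged training examples
    add_undirected_edges=True,
    max_nbrs=3)

# Step 2: any Keras model can serve as the base model.
base_model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(1433,), name='words'),
    tf.keras.layers.Dense(50, activation='relu'),
    tf.keras.layers.Dense(7, activation='softmax'),
])

# Step 3: wrap the base model so that a graph regularization term is added to
# the training loss.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=3, multiplier=0.1)
graph_reg_model = nsl.keras.GraphRegularization(base_model, graph_reg_config)

# Step 4: compile, train, and evaluate like any other Keras model.
graph_reg_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])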

Setup

Install the Neural Structured Learning package.

pip install --quiet neural-structured-learning

Dependencies and imports

import neural_structured_learning as nsl

import tensorflow as tf

# Resets notebook state
tf.keras.backend.clear_session()

print("Version: ", tf.__version__)
print("Eager mode: ", tf.executing_eagerly())
print(
    "GPU is",
    "available" if tf.config.list_physical_devices("GPU") else "NOT AVAILABLE")

Version: 2.4.1
Eager mode: True
GPU is available

Cora dataset

The Cora dataset is a citation graph where nodes represent machine learning papers and edges represent citations between pairs of papers. The task involved is document classification, where the goal is to categorize each paper into one of 7 categories. In other words, this is a multi-class classification problem with 7 classes.

Graph

The original graph is directed. However, for the purpose of this example, we consider the undirected version of this graph. So, if paper A cites paper B, we also consider paper B to have cited A. Although this is not necessarily true, in this example we treat citations as a proxy for similarity, which is usually a commutative relationship. The toy snippet below illustrates this symmetrization.
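As a toy illustration of the symmetrization (not part of the pipeline; preprocess_cora_dataset.py handles it in its 'Making all edges bi-directional' step), with hypothetical paper IDs:

# Mirror each directed citation edge so the graph becomes undirected.
directed_edges = [('paper_A', 'paper_B'), ('paper_B', 'paper_C')]
undirected_edges = set(directed_edges) | {(b, a) for (a, b) in directed_edges}
print(sorted(undirected_edges))
# [('paper_A', 'paper_B'), ('paper_B', 'paper_A'),
#  ('paper_B', 'paper_C'), ('paper_C', 'paper_B')]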

Features

Each paper in the input effectively contains 2 features:

  1. Words: A dense multi-hot bag-of-words (BoW) representation of the text in the paper. The vocabulary for the Cora dataset contains 1433 unique words. So, the length of this feature is 1433, and the value at position 'i' is 0 or 1, indicating whether word 'i' in the vocabulary exists in the given paper or not (see the toy example after this list).
  2. Label: A single integer representing the class ID (category) of the paper.
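To make the multi-hot encoding concrete, here is a toy version of the 'words' feature with a 5-word vocabulary instead of the real 1433-word one (all names here are hypothetical):

# Multi-hot bag of words: position i is 1 if vocabulary word i occurs in the
# paper and 0 otherwise; word frequency is not recorded.
vocab = ['graph', 'neural', 'bayes', 'kernel', 'markov']
paper_words = {'neural', 'graph'}
words_feature = [1 if w in paper_words else 0 for w in vocab]
print(words_feature)  # [1, 1, 0, 0, 0]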

Download the Cora dataset

wget --quiet -P /tmp https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz
tar -C /tmp -xvzf /tmp/cora.tgz

cora/
cora/README
cora/cora.cites
cora/cora.content
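To inspect the raw files before preprocessing, a quick peek might look like the sketch below. Per the dataset's README, each line of cora.cites has the form '<ID of cited paper> <ID of citing paper>'.

# Print the first few citation edges from the raw Cora graph file.
with open('/tmp/cora/cora.cites') as f:
  for _ in range(3):
    print(f.readline().rstrip())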

Convert the Cora dataset to the NSL format

To preprocess the Cora dataset and convert it to the format required by Neural Structured Learning, we will run the 'preprocess_cora_dataset.py' script, which is included in the NSL GitHub repository. This script does the following:

  1. Generate neighbor features using the original node features and the graph.
  2. Generate train and test data splits containing tf.train.Example instances.
  3. Persist the resulting train and test data in the TFRecord format.

!wget https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py

!python preprocess_cora_dataset.py \
--input_cora_content=/tmp/cora/cora.content \
--input_cora_graph=/tmp/cora/cora.cites \
--max_nbrs=5 \
--output_train_data=/tmp/cora/train_merged_examples.tfr \
--output_test_data=/tmp/cora/test_examples.tfr

--2021-02-12 22:29:54--  https://raw.githubusercontent.com/tensorflow/neural-structured-learning/master/neural_structured_learning/examples/preprocess/cora/preprocess_cora_dataset.py
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 11640 (11K) [text/plain]
Saving to: ‘preprocess_cora_dataset.py’

preprocess_cora_dat 100%[===================>] 11.37K --.-KB/s in 0.001s

2021-02-12 22:29:54 (19.3 MB/s) - ‘preprocess_cora_dataset.py’ saved [11640/11640]

2021-02-12 22:29:55.197371: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
Reading graph file: /tmp/cora/cora.cites...
Done reading 5429 edges from: /tmp/cora/cora.cites (0.01 seconds).
Making all edges bi-directional...
Done (0.01 seconds). Total graph nodes: 2708
Joining seed and neighbor tf.train.Examples with graph edges...
Done creating and writing 2155 merged tf.train.Examples (1.36 seconds).
Out-degree histogram: [(1, 386), (2, 468), (3, 452), (4, 309), (5, 540)]
Output training data written to TFRecord file: /tmp/cora/train_merged_examples.tfr.
Output test data written to TFRecord file: /tmp/cora/test_examples.tfr.
Total running time: 0.04 minutes.

Global variables

The file paths to the train and test data are based on the command line flag values used to invoke the 'preprocess_cora_dataset.py' script above.

### Experiment dataset
TRAIN_DATA_PATH = '/tmp/cora/train_merged_examples.tfr'
TEST_DATA_PATH = '/tmp/cora/test_examples.tfr'

### Constants used to identify neighbor features in the input.
NBR_FEATURE_PREFIX = 'NL_nbr_'
NBR_WEIGHT_SUFFIX = '_weight'

Hyperparameters

We will use an instance of HParams to hold various hyperparameters and constants used for training and evaluation. Each of them is briefly described by the comments in the class below.

class HParams(object):
  """Hyperparameters used for training."""
  def __init__(self):
    ### dataset parameters
    self.num_classes = 7
    self.max_seq_length = 1433
    ### neural graph learning parameters
    self.distance_type = nsl.configs.DistanceType.L2
    self.graph_regularization_multiplier = 0.1
    self.num_neighbors = 1
    ### model architecture
    self.num_fc_units = [50, 50]
    ### training parameters
    self.train_epochs = 100
    self.batch_size = 128
    self.dropout_rate = 0.5
    ### eval parameters
    self.eval_steps = None  # All instances in the test set are evaluated.

HPARAMS = HParams()

Load train and test data

As mentioned earlier in this notebook, the input training and test data have been created by 'preprocess_cora_dataset.py'. We will load them into two tf.data.Dataset objects: one for training and one for testing.

In the input layer of our model, we will extract not just the 'words' and 'label' features from each sample, but also the corresponding neighbor features based on the hparams.num_neighbors value. Instances with fewer neighbors than hparams.num_neighbors will be assigned dummy values for the non-existent neighbor features.

def make_dataset(file_path, training=False):
  """Creates a `tf.data.TFRecordDataset`.

  Args:
    file_path: Name of the file in the `.tfrecord` format containing
      `tf.train.Example` objects.
    training: Boolean indicating if we are in training mode.

  Returns:
    An instance of `tf.data.TFRecordDataset` containing the `tf.train.Example`
    objects.
  """

  def parse_example(example_proto):
    """Extracts relevant fields from the `example_proto`.

    Args:
      example_proto: An instance of `tf.train.Example`.

    Returns:
      A pair whose first value is a dictionary containing relevant features
      and whose second value contains the ground truth label.
    """
    # The 'words' feature is a multi-hot, bag-of-words representation of the
    # original raw text. A default value is required for examples that don't
    # have the feature.
    feature_spec = {
        'words':
            tf.io.FixedLenFeature([HPARAMS.max_seq_length],
                                  tf.int64,
                                  default_value=tf.constant(
                                      0,
                                      dtype=tf.int64,
                                      shape=[HPARAMS.max_seq_length])),
        'label':
            tf.io.FixedLenFeature((), tf.int64, default_value=-1),
    }
    # We also extract corresponding neighbor features in a similar manner to
    # the features above during training.
    if training:
      for i in range(HPARAMS.num_neighbors):
        nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, i, 'words')
        nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, i,
                                         NBR_WEIGHT_SUFFIX)
        feature_spec[nbr_feature_key] = tf.io.FixedLenFeature(
            [HPARAMS.max_seq_length],
            tf.int64,
            default_value=tf.constant(
                0, dtype=tf.int64, shape=[HPARAMS.max_seq_length]))

        # We assign a default value of 0.0 for the neighbor weight so that
        # graph regularization is done on samples based on their exact number
        # of neighbors. In other words, non-existent neighbors are discounted.
        feature_spec[nbr_weight_key] = tf.io.FixedLenFeature(
            [1], tf.float32, default_value=tf.constant([0.0]))

    features = tf.io.parse_single_example(example_proto, feature_spec)

    label = features.pop('label')
    return features, label

  dataset = tf.data.TFRecordDataset([file_path])
  if training:
    dataset = dataset.shuffle(10000)
  dataset = dataset.map(parse_example)
  dataset = dataset.batch(HPARAMS.batch_size)
  return dataset


train_dataset = make_dataset(TRAIN_DATA_PATH, training=True)
test_dataset = make_dataset(TEST_DATA_PATH)

Let's peek into the train dataset to look at its contents.

for feature_batch, label_batch in train_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  nbr_feature_key = '{}{}_{}'.format(NBR_FEATURE_PREFIX, 0, 'words')
  nbr_weight_key = '{}{}{}'.format(NBR_FEATURE_PREFIX, 0, NBR_WEIGHT_SUFFIX)
  print('Batch of neighbor inputs:', feature_batch[nbr_feature_key])
  print('Batch of neighbor weights:',
        tf.reshape(feature_batch[nbr_weight_key], [-1]))
  print('Batch of labels:', label_batch)

Feature list: ['NL_nbr_0_weight', 'NL_nbr_0_words', 'words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of neighbor weights: tf.Tensor(
[1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.
 ...
 1. 1.], shape=(128,), dtype=float32)
Batch of labels: tf.Tensor(
[1 3 3 6 3 3 2 0 3 2 6 3 2 1 2 3 2 3 1 1 2 0 5 2 2 2 1 2 2 2 0 2 0 1 1 6 6
 5 6 5 6 5 1 0 1 4 1 5 1 3 3 0 6 1 1 2 6 5 0 3 6 4 2 6 2 3 2 3 2 0 3 1 2 2
 0 2 3 4 1 2 0 4 6 2 4 3 3 4 0 1 3 3 3 6 2 6 1 1 2 0 3 3 2 5 4 4 1 3 1 3 5
 3 3 5 2 6 2 3 5 3 0 3 1 6 1 1 3 3], shape=(128,), dtype=int64)

Let's peek into the test dataset to look at its contents.

for feature_batch, label_batch in test_dataset.take(1):
  print('Feature list:', list(feature_batch.keys()))
  print('Batch of inputs:', feature_batch['words'])
  print('Batch of labels:', label_batch)

Feature list: ['words']
Batch of inputs: tf.Tensor(
[[0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 ...
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]
 [0 0 0 ... 0 0 0]], shape=(128, 1433), dtype=int64)
Batch of labels: tf.Tensor(
[5 2 2 2 1 2 6 3 2 3 6 1 3 6 4 4 2 3 3 0 2 0 5 2 1 0 6 3 6 4 2 2 3 0 4 2 2
 2 2 3 2 2 2 0 2 2 2 2 4 2 3 4 0 2 6 2 1 4 2 0 0 1 4 2 6 0 5 2 2 3 2 5 2 5
 2 3 2 2 2 2 2 6 6 3 2 4 2 6 3 2 2 6 2 4 2 2 1 3 4 6 0 0 2 4 2 1 3 6 6 2 6
 6 6 1 4 6 4 3 6 6 0 0 2 6 2 4 0 0], shape=(128,), dtype=int64)

Model definition

To demonstrate the use of graph regularization, we first build a base model for this problem. We will use a simple feed-forward neural network with 2 hidden layers and dropout in between. We illustrate the creation of the base model using all model types supported by the tf.Keras framework: sequential, functional, and subclass.

Sequential base model

def make_mlp_sequential_model(hparams):
  """Creates a sequential multi-layer perceptron model."""
  model = tf.keras.Sequential()
  model.add(
      tf.keras.layers.InputLayer(
          input_shape=(hparams.max_seq_length,), name='words'))
  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  model.add(
      tf.keras.layers.Lambda(lambda x: tf.keras.backend.cast(x, tf.float32)))
  for num_units in hparams.num_fc_units:
    model.add(tf.keras.layers.Dense(num_units, activation='relu'))
    # For sequential models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    model.add(tf.keras.layers.Dropout(hparams.dropout_rate))
  model.add(tf.keras.layers.Dense(hparams.num_classes, activation='softmax'))
  return model

Functional base model

def make_mlp_functional_model(hparams):
  """Creates a functional API-based multi-layer perceptron model."""
  inputs = tf.keras.Input(
      shape=(hparams.max_seq_length,), dtype='int64', name='words')

  # Input is already one-hot encoded in the integer format. We cast it to
  # floating point format here.
  cur_layer = tf.keras.layers.Lambda(
      lambda x: tf.keras.backend.cast(x, tf.float32))(
          inputs)

  for num_units in hparams.num_fc_units:
    cur_layer = tf.keras.layers.Dense(num_units, activation='relu')(cur_layer)
    # For functional models, by default, Keras ensures that the 'dropout' layer
    # is invoked only during training.
    cur_layer = tf.keras.layers.Dropout(hparams.dropout_rate)(cur_layer)

  outputs = tf.keras.layers.Dense(
      hparams.num_classes, activation='softmax')(
          cur_layer)

  model = tf.keras.Model(inputs, outputs=outputs)
  return model

Subclass base model

def make_mlp_subclass_model(hparams):
  """Creates a multi-layer perceptron subclass model in Keras."""

  class MLP(tf.keras.Model):
    """Subclass model defining a multi-layer perceptron."""

    def __init__(self):
      super(MLP, self).__init__()
      # Input is already one-hot encoded in the integer format. We create a
      # layer to cast it to floating point format here.
      self.cast_to_float_layer = tf.keras.layers.Lambda(
          lambda x: tf.keras.backend.cast(x, tf.float32))
      self.dense_layers = [
          tf.keras.layers.Dense(num_units, activation='relu')
          for num_units in hparams.num_fc_units
      ]
      self.dropout_layer = tf.keras.layers.Dropout(hparams.dropout_rate)
      self.output_layer = tf.keras.layers.Dense(
          hparams.num_classes, activation='softmax')

    def call(self, inputs, training=False):
      cur_layer = self.cast_to_float_layer(inputs['words'])
      for dense_layer in self.dense_layers:
        cur_layer = dense_layer(cur_layer)
        cur_layer = self.dropout_layer(cur_layer, training=training)

      outputs = self.output_layer(cur_layer)

      return outputs

  return MLP()

Create base model(s)

# Create a base MLP model using the functional API.
# Alternatively, you can also create a sequential or subclass base model using
# the make_mlp_sequential_model() or make_mlp_subclass_model() functions
# respectively, defined above. Note that if a subclass model is used, its
# summary cannot be generated until it is built.
base_model_tag, base_model = 'FUNCTIONAL', make_mlp_functional_model(HPARAMS)
base_model.summary()

Model: "model" _________________________________________________________________ Layer (type) Output Shape Param #

words (InputLayer) [(None, 1433)] 0
_________________________________________________________________ lambda (Lambda) (None, 1433) 0
_________________________________________________________________ dense (Dense) (None, 50) 71700
_________________________________________________________________ dropout (Dropout) (None, 50) 0
_________________________________________________________________ dense_1 (Dense) (None, 50) 2550
_________________________________________________________________ dropout_1 (Dropout) (None, 50) 0
_________________________________________________________________ dense_2 (Dense) (None, 7) 357

Total params: 74,607 Trainable params: 74,607 Non-trainable params: 0


Train base MLP model

# Compile and train the base MLP model
base_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
base_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)

Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/keras/engine/functional.py:595: UserWarning: Input dict contained keys ['NL_nbr_0_weight', 'NL_nbr_0_words'] which did not match any model input. They will be ignored by the model. [n for n in tensors.keys() if n not in ref_input_names]) 17/17 [==============================] - 1s 11ms/step - loss: 1.9608 - accuracy: 0.1465 Epoch 2/100 17/17 [==============================] - 0s 11ms/step - loss: 1.8578 - accuracy: 0.2793 Epoch 3/100 17/17 [==============================] - 0s 11ms/step - loss: 1.7744 - accuracy: 0.3468 Epoch 4/100 17/17 [==============================] - 0s 10ms/step - loss: 1.6850 - accuracy: 0.3542 Epoch 5/100 17/17 [==============================] - 0s 11ms/step - loss: 1.5511 - accuracy: 0.4065 Epoch 6/100 17/17 [==============================] - 0s 11ms/step - loss: 1.3826 - accuracy: 0.5161 Epoch 7/100 17/17 [==============================] - 0s 11ms/step - loss: 1.2052 - accuracy: 0.5874 Epoch 8/100 17/17 [==============================] - 0s 11ms/step - loss: 1.0876 - accuracy: 0.6437 Epoch 9/100 17/17 [==============================] - 0s 10ms/step - loss: 0.9621 - accuracy: 0.6866 Epoch 10/100 17/17 [==============================] - 0s 11ms/step - loss: 0.8881 - accuracy: 0.7042 Epoch 11/100 17/17 [==============================] - 0s 11ms/step - loss: 0.8042 - accuracy: 0.7365 Epoch 12/100 17/17 [==============================] - 0s 11ms/step - loss: 0.7164 - accuracy: 0.7680 Epoch 13/100 17/17 [==============================] - 0s 11ms/step - loss: 0.6374 - accuracy: 0.8080 Epoch 14/100 17/17 [==============================] - 0s 10ms/step - loss: 0.5826 - accuracy: 0.8164 Epoch 15/100 17/17 [==============================] - 0s 12ms/step - loss: 0.5169 - accuracy: 0.8426 Epoch 16/100 17/17 [==============================] - 0s 11ms/step - loss: 0.5486 - accuracy: 0.8348 Epoch 17/100 17/17 [==============================] - 0s 11ms/step - loss: 0.4695 - accuracy: 0.8565 Epoch 18/100 17/17 [==============================] - 0s 10ms/step - loss: 0.4333 - accuracy: 0.8688 Epoch 19/100 17/17 [==============================] - 0s 10ms/step - loss: 0.4156 - accuracy: 0.8735 Epoch 20/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3798 - accuracy: 0.8881 Epoch 21/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3469 - accuracy: 0.9021 Epoch 22/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3103 - accuracy: 0.9090 Epoch 23/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3284 - accuracy: 0.8891 Epoch 24/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2758 - accuracy: 0.9196 Epoch 25/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2780 - accuracy: 0.9124 Epoch 26/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2244 - accuracy: 0.9427 Epoch 27/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2555 - accuracy: 0.9215 Epoch 28/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2261 - accuracy: 0.9410 Epoch 29/100 17/17 [==============================] - 0s 10ms/step - loss: 0.2545 - accuracy: 0.9228 Epoch 30/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2161 - accuracy: 0.9354 Epoch 31/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2065 - accuracy: 0.9445 Epoch 32/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2176 - accuracy: 0.9336 Epoch 33/100 17/17 
[==============================] - 0s 11ms/step - loss: 0.2013 - accuracy: 0.9421 Epoch 34/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1651 - accuracy: 0.9513 Epoch 35/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1848 - accuracy: 0.9514 Epoch 36/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1634 - accuracy: 0.9558 Epoch 37/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1605 - accuracy: 0.9598 Epoch 38/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1438 - accuracy: 0.9651 Epoch 39/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1602 - accuracy: 0.9569 Epoch 40/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1487 - accuracy: 0.9576 Epoch 41/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1324 - accuracy: 0.9742 Epoch 42/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1173 - accuracy: 0.9698 Epoch 43/100 17/17 [==============================] - 0s 12ms/step - loss: 0.1148 - accuracy: 0.9690 Epoch 44/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1214 - accuracy: 0.9672 Epoch 45/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1289 - accuracy: 0.9645 Epoch 46/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1255 - accuracy: 0.9628 Epoch 47/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1151 - accuracy: 0.9697 Epoch 48/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1153 - accuracy: 0.9672 Epoch 49/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1074 - accuracy: 0.9681 Epoch 50/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1201 - accuracy: 0.9616 Epoch 51/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1033 - accuracy: 0.9784 Epoch 52/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0972 - accuracy: 0.9701 Epoch 53/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1056 - accuracy: 0.9733 Epoch 54/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1073 - accuracy: 0.9707 Epoch 55/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0989 - accuracy: 0.9705 Epoch 56/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0868 - accuracy: 0.9787 Epoch 57/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0956 - accuracy: 0.9745 Epoch 58/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0833 - accuracy: 0.9805 Epoch 59/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0848 - accuracy: 0.9805 Epoch 60/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1015 - accuracy: 0.9743 Epoch 61/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0894 - accuracy: 0.9735 Epoch 62/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0755 - accuracy: 0.9780 Epoch 63/100 17/17 [==============================] - 0s 12ms/step - loss: 0.0736 - accuracy: 0.9793 Epoch 64/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0867 - accuracy: 0.9751 Epoch 65/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0708 - accuracy: 0.9783 Epoch 66/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0712 - accuracy: 0.9784 Epoch 67/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0678 - accuracy: 0.9816 Epoch 68/100 17/17 
[==============================] - 0s 12ms/step - loss: 0.0697 - accuracy: 0.9771 Epoch 69/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0960 - accuracy: 0.9764 Epoch 70/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0707 - accuracy: 0.9809 Epoch 71/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0818 - accuracy: 0.9771 Epoch 72/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0739 - accuracy: 0.9775 Epoch 73/100 17/17 [==============================] - 0s 12ms/step - loss: 0.0710 - accuracy: 0.9796 Epoch 74/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0621 - accuracy: 0.9824 Epoch 75/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0567 - accuracy: 0.9881 Epoch 76/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0496 - accuracy: 0.9890 Epoch 77/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0609 - accuracy: 0.9837 Epoch 78/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0779 - accuracy: 0.9812 Epoch 79/100 17/17 [==============================] - 0s 12ms/step - loss: 0.0591 - accuracy: 0.9837 Epoch 80/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0673 - accuracy: 0.9791 Epoch 81/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0568 - accuracy: 0.9839 Epoch 82/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0568 - accuracy: 0.9830 Epoch 83/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0635 - accuracy: 0.9830 Epoch 84/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0570 - accuracy: 0.9846 Epoch 85/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0617 - accuracy: 0.9854 Epoch 86/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0624 - accuracy: 0.9831 Epoch 87/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0462 - accuracy: 0.9884 Epoch 88/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0520 - accuracy: 0.9884 Epoch 89/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0468 - accuracy: 0.9875 Epoch 90/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0600 - accuracy: 0.9806 Epoch 91/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0507 - accuracy: 0.9823 Epoch 92/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0528 - accuracy: 0.9841 Epoch 93/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0483 - accuracy: 0.9865 Epoch 94/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0615 - accuracy: 0.9832 Epoch 95/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0470 - accuracy: 0.9856 Epoch 96/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0419 - accuracy: 0.9900 Epoch 97/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0407 - accuracy: 0.9942 Epoch 98/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0573 - accuracy: 0.9826 Epoch 99/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0465 - accuracy: 0.9877 Epoch 100/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0442 - accuracy: 0.9880 <tensorflow.python.keras.callbacks.History at 0x7f7860e3d048>

Evaluate base MLP model

# Helper function to print evaluation metrics.
def print_metrics(model_desc, eval_metrics):
  """Prints evaluation metrics.

  Args:
    model_desc: A description of the model.
    eval_metrics: A dictionary mapping metric names to corresponding values. It
      must contain the loss and accuracy metrics.
  """
  print('\n')
  print('Eval accuracy for ', model_desc, ': ', eval_metrics['accuracy'])
  print('Eval loss for ', model_desc, ': ', eval_metrics['loss'])
  if 'graph_loss' in eval_metrics:
    print('Eval graph loss for ', model_desc, ': ', eval_metrics['graph_loss'])


eval_results = dict(
    zip(base_model.metrics_names,
        base_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('Base MLP model', eval_results)

5/5 [==============================] - 0s 9ms/step - loss: 1.2943 - accuracy: 0.7939

Eval accuracy for Base MLP model : 0.7938517332077026
Eval loss for Base MLP model : 1.2943289279937744

Train MLP model with graph regularization

Incorporating graph regularization into the loss term of an existing tf.Keras.Model requires just a few lines of code. The base model is wrapped to create a new tf.Keras subclass model, whose loss includes graph regularization.

To assess the incremental benefit of graph regularization, we create a new instance of the base model. This is because base_model has already been trained for a few iterations, and reusing that trained model to create a graph-regularized model would not make for a fair comparison against base_model.

# Build a new base MLP model.
base_reg_model_tag, base_reg_model = 'FUNCTIONAL', make_mlp_functional_model(
    HPARAMS)
# Wrap the base MLP model with graph regularization.
graph_reg_config = nsl.configs.make_graph_reg_config(
    max_neighbors=HPARAMS.num_neighbors,
    multiplier=HPARAMS.graph_regularization_multiplier,
    distance_type=HPARAMS.distance_type,
    sum_over_axis=-1)
graph_reg_model = nsl.keras.GraphRegularization(base_reg_model,
                                                graph_reg_config)
graph_reg_model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'])
graph_reg_model.fit(train_dataset, epochs=HPARAMS.train_epochs, verbose=1)

Epoch 1/100 /tmpfs/src/tf_docs_env/lib/python3.6/site-packages/tensorflow/python/framework/indexed_slices.py:437: UserWarning: Converting sparse IndexedSlices(IndexedSlices(indices=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape_1:0", shape=(None,), dtype=int32), values=Tensor("gradient_tape/GraphRegularization/graph_loss/Reshape:0", shape=(None, 7), dtype=float32), dense_shape=Tensor("gradient_tape/GraphRegularization/graph_loss/Cast:0", shape=(2,), dtype=int32))) to a dense Tensor of unknown shape. This may consume a large amount of memory. "shape. This may consume a large amount of memory." % value) 17/17 [==============================] - 2s 11ms/step - loss: 1.9542 - accuracy: 0.1519 - scaled_graph_loss: 7.4008e-04 Epoch 2/100 17/17 [==============================] - 0s 11ms/step - loss: 1.8780 - accuracy: 0.2452 - scaled_graph_loss: 8.6243e-04 Epoch 3/100 17/17 [==============================] - 0s 11ms/step - loss: 1.7961 - accuracy: 0.3197 - scaled_graph_loss: 0.0016 Epoch 4/100 17/17 [==============================] - 0s 10ms/step - loss: 1.6863 - accuracy: 0.3774 - scaled_graph_loss: 0.0032 Epoch 5/100 17/17 [==============================] - 0s 11ms/step - loss: 1.5712 - accuracy: 0.3973 - scaled_graph_loss: 0.0054 Epoch 6/100 17/17 [==============================] - 0s 11ms/step - loss: 1.4242 - accuracy: 0.4789 - scaled_graph_loss: 0.0087 Epoch 7/100 17/17 [==============================] - 0s 11ms/step - loss: 1.3093 - accuracy: 0.5452 - scaled_graph_loss: 0.0125 Epoch 8/100 17/17 [==============================] - 0s 11ms/step - loss: 1.1419 - accuracy: 0.6088 - scaled_graph_loss: 0.0169 Epoch 9/100 17/17 [==============================] - 0s 11ms/step - loss: 1.0283 - accuracy: 0.6588 - scaled_graph_loss: 0.0207 Epoch 10/100 17/17 [==============================] - 0s 11ms/step - loss: 0.9211 - accuracy: 0.7076 - scaled_graph_loss: 0.0243 Epoch 11/100 17/17 [==============================] - 0s 11ms/step - loss: 0.8022 - accuracy: 0.7699 - scaled_graph_loss: 0.0262 Epoch 12/100 17/17 [==============================] - 0s 11ms/step - loss: 0.7787 - accuracy: 0.7628 - scaled_graph_loss: 0.0284 Epoch 13/100 17/17 [==============================] - 0s 10ms/step - loss: 0.6991 - accuracy: 0.7949 - scaled_graph_loss: 0.0298 Epoch 14/100 17/17 [==============================] - 0s 11ms/step - loss: 0.6366 - accuracy: 0.8353 - scaled_graph_loss: 0.0298 Epoch 15/100 17/17 [==============================] - 0s 11ms/step - loss: 0.5447 - accuracy: 0.8312 - scaled_graph_loss: 0.0316 Epoch 16/100 17/17 [==============================] - 0s 11ms/step - loss: 0.5165 - accuracy: 0.8604 - scaled_graph_loss: 0.0295 Epoch 17/100 17/17 [==============================] - 0s 11ms/step - loss: 0.4780 - accuracy: 0.8717 - scaled_graph_loss: 0.0307 Epoch 18/100 17/17 [==============================] - 0s 11ms/step - loss: 0.4786 - accuracy: 0.8763 - scaled_graph_loss: 0.0304 Epoch 19/100 17/17 [==============================] - 0s 10ms/step - loss: 0.4446 - accuracy: 0.8762 - scaled_graph_loss: 0.0328 Epoch 20/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3954 - accuracy: 0.8953 - scaled_graph_loss: 0.0322 Epoch 21/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3739 - accuracy: 0.8967 - scaled_graph_loss: 0.0320 Epoch 22/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3835 - accuracy: 0.9009 - scaled_graph_loss: 0.0329 Epoch 23/100 17/17 [==============================] - 0s 10ms/step - loss: 0.3242 - accuracy: 0.9201 - 
scaled_graph_loss: 0.0330 Epoch 24/100 17/17 [==============================] - 0s 11ms/step - loss: 0.3034 - accuracy: 0.9214 - scaled_graph_loss: 0.0310 Epoch 25/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2909 - accuracy: 0.9281 - scaled_graph_loss: 0.0345 Epoch 26/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2921 - accuracy: 0.9249 - scaled_graph_loss: 0.0347 Epoch 27/100 17/17 [==============================] - 0s 12ms/step - loss: 0.2439 - accuracy: 0.9483 - scaled_graph_loss: 0.0335 Epoch 28/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2524 - accuracy: 0.9445 - scaled_graph_loss: 0.0330 Epoch 29/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2310 - accuracy: 0.9424 - scaled_graph_loss: 0.0319 Epoch 30/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2389 - accuracy: 0.9388 - scaled_graph_loss: 0.0334 Epoch 31/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2204 - accuracy: 0.9523 - scaled_graph_loss: 0.0355 Epoch 32/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2159 - accuracy: 0.9525 - scaled_graph_loss: 0.0334 Epoch 33/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2022 - accuracy: 0.9561 - scaled_graph_loss: 0.0345 Epoch 34/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1926 - accuracy: 0.9601 - scaled_graph_loss: 0.0345 Epoch 35/100 17/17 [==============================] - 0s 11ms/step - loss: 0.2049 - accuracy: 0.9493 - scaled_graph_loss: 0.0343 Epoch 36/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1732 - accuracy: 0.9627 - scaled_graph_loss: 0.0335 Epoch 37/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1914 - accuracy: 0.9573 - scaled_graph_loss: 0.0327 Epoch 38/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1781 - accuracy: 0.9578 - scaled_graph_loss: 0.0332 Epoch 39/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1650 - accuracy: 0.9730 - scaled_graph_loss: 0.0324 Epoch 40/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1650 - accuracy: 0.9621 - scaled_graph_loss: 0.0328 Epoch 41/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1721 - accuracy: 0.9644 - scaled_graph_loss: 0.0339 Epoch 42/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1672 - accuracy: 0.9687 - scaled_graph_loss: 0.0356 Epoch 43/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1642 - accuracy: 0.9600 - scaled_graph_loss: 0.0343 Epoch 44/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1469 - accuracy: 0.9735 - scaled_graph_loss: 0.0334 Epoch 45/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1478 - accuracy: 0.9708 - scaled_graph_loss: 0.0340 Epoch 46/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1537 - accuracy: 0.9640 - scaled_graph_loss: 0.0367 Epoch 47/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1513 - accuracy: 0.9691 - scaled_graph_loss: 0.0355 Epoch 48/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1252 - accuracy: 0.9768 - scaled_graph_loss: 0.0327 Epoch 49/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1443 - accuracy: 0.9722 - scaled_graph_loss: 0.0352 Epoch 50/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1339 - accuracy: 0.9731 - scaled_graph_loss: 0.0333 Epoch 51/100 17/17 
[==============================] - 0s 11ms/step - loss: 0.1385 - accuracy: 0.9741 - scaled_graph_loss: 0.0362 Epoch 52/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1347 - accuracy: 0.9732 - scaled_graph_loss: 0.0333 Epoch 53/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1222 - accuracy: 0.9785 - scaled_graph_loss: 0.0353 Epoch 54/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1258 - accuracy: 0.9738 - scaled_graph_loss: 0.0354 Epoch 55/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1209 - accuracy: 0.9771 - scaled_graph_loss: 0.0352 Epoch 56/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1279 - accuracy: 0.9787 - scaled_graph_loss: 0.0352 Epoch 57/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1273 - accuracy: 0.9719 - scaled_graph_loss: 0.0312 Epoch 58/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1206 - accuracy: 0.9747 - scaled_graph_loss: 0.0332 Epoch 59/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1109 - accuracy: 0.9814 - scaled_graph_loss: 0.0342 Epoch 60/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1168 - accuracy: 0.9778 - scaled_graph_loss: 0.0338 Epoch 61/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1125 - accuracy: 0.9820 - scaled_graph_loss: 0.0341 Epoch 62/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1055 - accuracy: 0.9824 - scaled_graph_loss: 0.0359 Epoch 63/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1183 - accuracy: 0.9771 - scaled_graph_loss: 0.0361 Epoch 64/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1063 - accuracy: 0.9835 - scaled_graph_loss: 0.0343 Epoch 65/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1117 - accuracy: 0.9786 - scaled_graph_loss: 0.0306 Epoch 66/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1091 - accuracy: 0.9783 - scaled_graph_loss: 0.0343 Epoch 67/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0958 - accuracy: 0.9882 - scaled_graph_loss: 0.0340 Epoch 68/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1077 - accuracy: 0.9842 - scaled_graph_loss: 0.0366 Epoch 69/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1092 - accuracy: 0.9767 - scaled_graph_loss: 0.0353 Epoch 70/100 17/17 [==============================] - 0s 10ms/step - loss: 0.1159 - accuracy: 0.9777 - scaled_graph_loss: 0.0338 Epoch 71/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0875 - accuracy: 0.9881 - scaled_graph_loss: 0.0325 Epoch 72/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0905 - accuracy: 0.9864 - scaled_graph_loss: 0.0337 Epoch 73/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1021 - accuracy: 0.9767 - scaled_graph_loss: 0.0321 Epoch 74/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1047 - accuracy: 0.9773 - scaled_graph_loss: 0.0328 Epoch 75/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0958 - accuracy: 0.9812 - scaled_graph_loss: 0.0338 Epoch 76/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0997 - accuracy: 0.9802 - scaled_graph_loss: 0.0335 Epoch 77/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0853 - accuracy: 0.9877 - scaled_graph_loss: 0.0314 Epoch 78/100 17/17 [==============================] - 0s 11ms/step - loss: 0.1016 - 
accuracy: 0.9810 - scaled_graph_loss: 0.0346 Epoch 79/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0978 - accuracy: 0.9809 - scaled_graph_loss: 0.0317 Epoch 80/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0908 - accuracy: 0.9864 - scaled_graph_loss: 0.0329 Epoch 81/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0742 - accuracy: 0.9902 - scaled_graph_loss: 0.0332 Epoch 82/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0910 - accuracy: 0.9875 - scaled_graph_loss: 0.0345 Epoch 83/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0908 - accuracy: 0.9848 - scaled_graph_loss: 0.0345 Epoch 84/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0848 - accuracy: 0.9831 - scaled_graph_loss: 0.0328 Epoch 85/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0997 - accuracy: 0.9804 - scaled_graph_loss: 0.0345 Epoch 86/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0901 - accuracy: 0.9859 - scaled_graph_loss: 0.0326 Epoch 87/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0825 - accuracy: 0.9873 - scaled_graph_loss: 0.0334 Epoch 88/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0805 - accuracy: 0.9885 - scaled_graph_loss: 0.0332 Epoch 89/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0776 - accuracy: 0.9885 - scaled_graph_loss: 0.0330 Epoch 90/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0960 - accuracy: 0.9799 - scaled_graph_loss: 0.0341 Epoch 91/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0787 - accuracy: 0.9888 - scaled_graph_loss: 0.0337 Epoch 92/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0736 - accuracy: 0.9914 - scaled_graph_loss: 0.0348 Epoch 93/100 17/17 [==============================] - 0s 10ms/step - loss: 0.0806 - accuracy: 0.9892 - scaled_graph_loss: 0.0347 Epoch 94/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0723 - accuracy: 0.9912 - scaled_graph_loss: 0.0314 Epoch 95/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0671 - accuracy: 0.9887 - scaled_graph_loss: 0.0295 Epoch 96/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0780 - accuracy: 0.9887 - scaled_graph_loss: 0.0327 Epoch 97/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0843 - accuracy: 0.9871 - scaled_graph_loss: 0.0331 Epoch 98/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0796 - accuracy: 0.9901 - scaled_graph_loss: 0.0333 Epoch 99/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0871 - accuracy: 0.9847 - scaled_graph_loss: 0.0329 Epoch 100/100 17/17 [==============================] - 0s 11ms/step - loss: 0.0787 - accuracy: 0.9859 - scaled_graph_loss: 0.0335 <tensorflow.python.keras.callbacks.History at 0x7f786083d518>

Evaluate MLP model with graph regularization

eval_results = dict(
    zip(graph_reg_model.metrics_names,
        graph_reg_model.evaluate(test_dataset, steps=HPARAMS.eval_steps)))
print_metrics('MLP + graph regularization', eval_results)

5/5 [==============================] - 0s 8ms/step - loss: 1.2281 - accuracy: 0.8103

Eval accuracy for MLP + graph regularization : 0.8155515193939209
Eval loss for MLP + graph regularization : 1.2275753021240234

The graph-regularized model's accuracy is about 2-3% higher than that of the base model (base_model).

Conclusion

We have demonstrated the use of graph regularization for document classification on a natural citation graph (Cora) using the Neural Structured Learning (NSL) framework. Our advanced tutorial synthesizes a graph based on sample embeddings before training a neural network with graph regularization. This approach is useful when the input does not contain an explicit graph.
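For reference, graph synthesis in that setting can be done with nsl.tools.build_graph. A hedged sketch, assuming sample embeddings have already been written as tf.train.Example records to a hypothetical '/tmp/embeddings.tfr' file:

import neural_structured_learning as nsl

# Connect pairs of samples whose embedding similarity is at least the given
# threshold, writing the resulting edges to a TSV file.
nsl.tools.build_graph(['/tmp/embeddings.tfr'],
                      '/tmp/synthesized_graph.tsv',
                      similarity_threshold=0.8)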

We encourage users to experiment further by varying the amount of supervision as well as trying different neural architectures for graph regularization. One such experiment is sketched below.
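As one example, this sketch sweeps the number of neighbors used for graph regularization, reusing the helpers defined above. Note that num_neighbors must not exceed the --max_nbrs value of 5 used during preprocessing, and that the training dataset must be rebuilt so that parse_example() extracts the matching neighbor features.

for num_neighbors in [1, 3, 5]:
  HPARAMS.num_neighbors = num_neighbors
  # Rebuild the training dataset so that parsing picks up the new value.
  train_ds = make_dataset(TRAIN_DATA_PATH, training=True)
  base = make_mlp_functional_model(HPARAMS)
  config = nsl.configs.make_graph_reg_config(
      max_neighbors=HPARAMS.num_neighbors,
      multiplier=HPARAMS.graph_regularization_multiplier,
      distance_type=HPARAMS.distance_type,
      sum_over_axis=-1)
  model = nsl.keras.GraphRegularization(base, config)
  model.compile(
      optimizer='adam',
      loss='sparse_categorical_crossentropy',
      metrics=['accuracy'])
  model.fit(train_ds, epochs=HPARAMS.train_epochs, verbose=0)
  results = dict(
      zip(model.metrics_names, model.evaluate(test_dataset, verbose=0)))
  print('num_neighbors={}: {}'.format(num_neighbors, results))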
