Kerasにおけるtrain、validation、testについて

Kerasにおけるtrain、validation、testについて簡単に説明します。

各データをざっくり言うと

  • train
    • 実際にニューラルネットワークの重みを更新する学習データ。
  • validation
    • ニューラルネットワークのハイパーパラメータの良し悪しを確かめるための検証データ。学習は行わない。
  • test
    • 学習後に汎化性能を確かめるテストデータ。学習は行わない。

テストデータは汎化性能をチェックするために、最後に(理想的には一度だけ)利用します。

出典:ゼロから作るDeep Learning

ハイパーパラメータとは?

  • 各層のニューロン数、バッチサイズ、パラメータ更新の際の学習係数など、人の手によって設定されるパラメータのこと。
  • ニューラルネットワークの重みは、学習アルゴリズムによって自動獲得されるものであり、ハイパーパラメータではない。
  • 一般的に、ハイパーパラメータをいろいろな値で試しながら、うまく学習できるケースを探すという作業が必要になる。

Kerasでの確認

サマリー

以下のように、model.fitのvalidation_split=に「0.1」を渡した場合、学習データ(train)のうち、10%が検証データ(validation)として使用されることになります。

model.fit(Xtrain, Ytrain, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS,
          validation_split=0.1)

model.summary()での出力は、以下のようになります。

  • 60000データのうち、90%である54000が学習データ(train)に使用される。
  • 60000データのうち、10%である6000が検証データ(validation)に使用される。
Train on 54000 samples, validate on 6000 samples

同様に、model.fitのvalidation_split=に「0.2」を渡した場合、学習データ(train)のうち、20%が検証データ(validation)として使用されることになります。

model.fit(Xtrain, Ytrain, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS,
          validation_split=0.2)

model.summary()での出力は、以下のようになります。

  • 60000データのうち、80%である48000が学習データ(train)に使用される。
  • 60000データのうち、20%である12000が検証データ(validation)に使用される。
Train on 48000 samples, validate on 12000 samples

実行ログ

model.fit()のログで出力される値では以下の通り。

  • acc:学習データ(train)の精度
  • val_acc:検証データ(validation)の精度
48000/48000 [==============================] - 1s 27us/step - loss: 0.0596 - acc: 0.9798 - val_loss: 0.0818 - val_acc: 0.9778

acc < val_acc の場合

学習データの精度が検証データの精度より低い場合、学習が不十分であると言えます。
エポック数を増やしてみましょう。

学習データの精度は評価データの精度を超えるべきです。そうでないときは、十分に学習していないことになります。その場合は学習回数を大幅に増やしてみましょう。

出典:直感 Deep Learing

過学習?

検証データ(validation)の精度推移は、過学習かどうかの確認にも用いられます。

以下のようなソースでグラフ表示を行います。

# 学習実行
history = model.fit(Xtrain, Ytrain, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS,
          validation_split=0.2)

# グラフ描画
# Accuracy
plt.plot(range(1, NUM_EPOCHS+1), history.history['acc'], "-o")
plt.plot(range(1, NUM_EPOCHS+1), history.history['val_acc'], "-o")
plt.title('model accuracy')
plt.ylabel('accuracy')  # Y軸ラベル
plt.xlabel('epoch')  # X軸ラベル
plt.grid()
plt.legend(['acc', 'val_acc'], loc='best')
plt.show()

下のグラフのような結果であるとします。

8エポック以降、検証データ(validation)の精度が下がってきており、過学習が疑われます。

参考

Keras Document:ModelクラスAPI

https://keras.io/ja/models/model/

ゼロから作るDeep Learning

直感 Deep Learning

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

次の記事

MNISTデータの内容