サイトアイコン 知的好奇心

Kerasにおけるtrain、validation、testについて

Kerasにおけるtrain、validation、testについて簡単に説明します。

各データをざっくり言うと

テストデータは汎化性能をチェックするために、最後に(理想的には一度だけ)利用します。

出典:ゼロから作るDeep Learning

ハイパーパラメータとは?

Kerasでの確認

サマリー

以下のように、model.fitのvalidation_split=に「0.1」を渡した場合、学習データ(train)のうち、10%が検証データ(validation)として使用されることになります。

model.fit(Xtrain, Ytrain, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS,
          validation_split=0.1)

model.summary()での出力は、以下のようになります。

Train on 54000 samples, validate on 6000 samples

同様に、model.fitのvalidation_split=に「0.2」を渡した場合、学習データ(train)のうち、20%が検証データ(validation)として使用されることになります。

model.fit(Xtrain, Ytrain, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS,
          validation_split=0.2)

model.summary()での出力は、以下のようになります。

Train on 48000 samples, validate on 12000 samples

実行ログ

model.fit()のログで出力される値では以下の通り。

48000/48000 [==============================] - 1s 27us/step - loss: 0.0596 - acc: 0.9798 - val_loss: 0.0818 - val_acc: 0.9778

acc < val_acc の場合

学習データの精度が検証データの精度より低い場合、学習が不十分であると言えます。
エポック数を増やしてみましょう。

学習データの精度は評価データの精度を超えるべきです。そうでないときは、十分に学習していないことになります。その場合は学習回数を大幅に増やしてみましょう。

出典:直感 Deep Learing

過学習?

検証データ(validation)の精度推移は、過学習かどうかの確認にも用いられます。

以下のようなソースでグラフ表示を行います。

# 学習実行
history = model.fit(Xtrain, Ytrain, batch_size=BATCH_SIZE, epochs=NUM_EPOCHS,
          validation_split=0.2)

# グラフ描画
# Accuracy
plt.plot(range(1, NUM_EPOCHS+1), history.history['acc'], "-o")
plt.plot(range(1, NUM_EPOCHS+1), history.history['val_acc'], "-o")
plt.title('model accuracy')
plt.ylabel('accuracy')  # Y軸ラベル
plt.xlabel('epoch')  # X軸ラベル
plt.grid()
plt.legend(['acc', 'val_acc'], loc='best')
plt.show()

下のグラフのような結果であるとします。

8エポック以降、検証データ(validation)の精度が下がってきており、過学習が疑われます。

参考

Keras Document:ModelクラスAPI

https://keras.io/ja/models/model/

ゼロから作るDeep Learning

直感 Deep Learning

モバイルバージョンを終了