Kerasで転移学習を行う方法をご紹介します。
目次
条件
- Python 3.7.0
- Keras 2.1.2
転移学習とファインチューニング
「ゼロから作るDeep Learning」では以下のように説明されています。
- 転移学習
- 学習済みの重み(の一部)を別のニューラルネットワークにコピーして再学習を行うこと。
- ファインチューニング
- 学習済みの重みを初期値とし、新しいデータセットを対象に行う再学習のこと。
どちらも似たような説明ですが、ネットで色々調べてみると以下のような感じかと思われます。
(一方は概念、もう一方はその中の一つのプロセスといった感じでしょうか)
- 転移学習
- 学習済みモデル(ニューラルネットワーク)の重みを流用して、別のモデルの学習に用いるという手法。
- ファインチューニング
- 転移学習において、未学習データセットを流用した重みに再学習させるプロセス。
メリット
学習済みのモデルを流用することで、少ないデータで学習を行うことが出来ます。
転移学習前の推論
KerasではImageNetで学習した重みをもつ画像分類のモデルを簡単に使用することが出来ます。
VGG16の学習済みモデルを用いて、転移学習前の推論を行ってみます。
対象は、ペルーの国旗にも描かれている有名な動物「ビクーニャ」と、日本で有名な「タヌキ」の画像です。
ソース
from keras.applications.vgg16 import VGG16 from keras.applications.vgg16 import preprocess_input, decode_predictions import keras.preprocessing.image as Image import numpy as np model = VGG16(weights="imagenet", include_top=True) image_path = "sample_images_pretrain/biku.jpg" image = Image.load_img(image_path, target_size=(224, 224)) # imagenet size x = Image.img_to_array(image) x = np.expand_dims(x, axis=0) # add batch size dim x = preprocess_input(x) result = model.predict(x) result = decode_predictions(result, top=3)[0] print(result[0][1]) # show description for _, name, score in result: print('{}: {:.2%}'.format(name, score))
実行結果
「ビクーニャ」の結果は「ガゼル」と出ました。
2番目の可能性として「インパラ」、3番目には「リャマ」と出ています。
一番似ているであろう「リャマ」の可能性さえ3%程度という不本意な結果です。
gazelle gazelle: 84.48% impala: 9.11% llama: 3.08%
「タヌキ」の結果は「マダガスカルキャット」と出ました。
2番目は「ワラビー」、3番目の結果は「インドリ」と出ています。
どうやら「ビクーニャ」および「タヌキ」は未学習のようです。
Madagascar_cat Madagascar_cat: 38.88% wallaby: 26.10% indri: 7.21%
転移学習
「ビクーニャ」はともかく、「タヌキ」が正しく認識されないのは日本人として受け入れることが出来ないでしょう。
ここでは、「ビクーニャ」or「タヌキ」という分類機を転移学習によって作成したいと思います。
詳しい説明は以下のサイトをご参照ください。(当該サイトを参考にしています。)
学習画像
以下のように、「ビクーニャ」と「タヌキ」の画像を配置します。
(画像ファイル名は任意)
data/ image/ train/ raccoon/ raccoon001.jpg raccoon002.jpg ... vicuna/ vicuna001.jpg vicuna002.jpg ... validation/ raccoon/ raccoon031.jpg raccoon032.jpg ... vicuna/ vicuna031.jpg vicuna032.jpg ...
ソース
ポイント
base_model = VGG16(weights='imagenet', include_top=False, input_tensor=input_tensor)
「include_top=False」とすることで、ネットワークの出力層側にある3つの全結合層を含まないようにします。
for layer in base_model.layers[:15]: layer.trainable = False
入力層から15番目の層までを学習させないようにします。(入力層から15番目層までの重みは変わらない)
x = base_model.output x = GlobalAveragePooling2D()(x) x = Dense(1024, activation='relu')(x) predictions = Dense(N_CATEGORIES, activation='softmax')(x) model = Model(inputs=base_model.input, outputs=predictions)
出力側の層に、全結合層を追加します。
2種類の分類を行うモデルとなります。
summaryで見ると以下のような感じです。
_________________________________________________________________ global_average_pooling2d_1 ( (None, 512) 0 _________________________________________________________________ dense_1 (Dense) (None, 1024) 525312 _________________________________________________________________ dense_2 (Dense) (None, 2) 2050 =================================================================
train_datagen = ImageDataGenerator( ・・・ ) test_datagen = ImageDataGenerator( ・・・ ) train_generator = train_datagen.flow_from_directory( ・・・ ) validation_generator = test_datagen.flow_from_directory( ・・・ )
画像を回転、水平/垂直方向へシフトなどを行って、学習/バリデーション画像の水増しを行います。
ソース全体
from keras.models import Model from keras.layers import Dense, GlobalAveragePooling2D,Input from keras.applications.vgg16 import VGG16 from keras.preprocessing.image import ImageDataGenerator import matplotlib.pyplot as plt N_CATEGORIES = 2 IMAGE_SIZE = 224 BATCH_SIZE = 5 NUM_EPOCHS = 50 train_data_dir = 'C:/data/image/train' validation_data_dir = 'C:/data/image/validation' NUM_TRAINING = 30 NUM_VALIDATION = 5 input_tensor = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3)) base_model = VGG16(weights='imagenet', include_top=False, input_tensor=input_tensor) x = base_model.output x = GlobalAveragePooling2D()(x) x = Dense(1024, activation='relu')(x) predictions = Dense(N_CATEGORIES, activation='softmax')(x) model = Model(inputs=base_model.input, outputs=predictions) # for layer in base_model.layers: for layer in base_model.layers[:15]: layer.trainable = False from keras.optimizers import SGD model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy',metrics=['accuracy']) model.summary() train_datagen = ImageDataGenerator( rescale=1.0 / 255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True, rotation_range=10) test_datagen = ImageDataGenerator( rescale=1.0 / 255, ) train_generator = train_datagen.flow_from_directory( train_data_dir, target_size=(IMAGE_SIZE, IMAGE_SIZE), batch_size=BATCH_SIZE, class_mode='categorical', shuffle=True ) validation_generator = test_datagen.flow_from_directory( validation_data_dir, target_size=(IMAGE_SIZE, IMAGE_SIZE), batch_size=BATCH_SIZE, class_mode='categorical', shuffle=True ) history = model.fit_generator(train_generator, steps_per_epoch=NUM_TRAINING//BATCH_SIZE, epochs=NUM_EPOCHS, verbose=1, validation_data=validation_generator, validation_steps=NUM_VALIDATION//BATCH_SIZE, ) model.save('transfer.h5') # モデルの保存 # グラフ描画 # Accuracy plt.plot(range(1, NUM_EPOCHS+1), history.history['acc'], "o-") plt.plot(range(1, NUM_EPOCHS+1), history.history['val_acc'], "o-") plt.title('model accuracy') plt.ylabel('accuracy') # Y軸ラベル plt.xlabel('epoch') # X軸ラベル plt.legend(['train', 'test'], loc='upper left') plt.show() # loss plt.plot(range(1, NUM_EPOCHS+1), history.history['loss'], "o-") plt.plot(range(1, NUM_EPOCHS+1), history.history['val_loss'], "o-") plt.title('model loss') plt.ylabel('loss') # Y軸ラベル plt.xlabel('epoch') # X軸ラベル plt.legend(['train', 'test'], loc='upper right') plt.show()
実行結果
モデルのサマリー
_________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) (None, 224, 224, 3) 0 _________________________________________________________________ block1_conv1 (Conv2D) (None, 224, 224, 64) 1792 _________________________________________________________________ block1_conv2 (Conv2D) (None, 224, 224, 64) 36928 _________________________________________________________________ block1_pool (MaxPooling2D) (None, 112, 112, 64) 0 _________________________________________________________________ block2_conv1 (Conv2D) (None, 112, 112, 128) 73856 _________________________________________________________________ block2_conv2 (Conv2D) (None, 112, 112, 128) 147584 _________________________________________________________________ block2_pool (MaxPooling2D) (None, 56, 56, 128) 0 _________________________________________________________________ block3_conv1 (Conv2D) (None, 56, 56, 256) 295168 _________________________________________________________________ block3_conv2 (Conv2D) (None, 56, 56, 256) 590080 _________________________________________________________________ block3_conv3 (Conv2D) (None, 56, 56, 256) 590080 _________________________________________________________________ block3_pool (MaxPooling2D) (None, 28, 28, 256) 0 _________________________________________________________________ block4_conv1 (Conv2D) (None, 28, 28, 512) 1180160 _________________________________________________________________ block4_conv2 (Conv2D) (None, 28, 28, 512) 2359808 _________________________________________________________________ block4_conv3 (Conv2D) (None, 28, 28, 512) 2359808 _________________________________________________________________ block4_pool (MaxPooling2D) (None, 14, 14, 512) 0 _________________________________________________________________ block5_conv1 (Conv2D) (None, 14, 14, 512) 2359808 _________________________________________________________________ block5_conv2 (Conv2D) (None, 14, 14, 512) 2359808 _________________________________________________________________ block5_conv3 (Conv2D) (None, 14, 14, 512) 2359808 _________________________________________________________________ block5_pool (MaxPooling2D) (None, 7, 7, 512) 0 _________________________________________________________________ global_average_pooling2d_1 ( (None, 512) 0 _________________________________________________________________ dense_1 (Dense) (None, 1024) 525312 _________________________________________________________________ dense_2 (Dense) (None, 2) 2050 ================================================================= Total params: 15,242,050 Trainable params: 7,606,786 Non-trainable params: 7,635,264 _________________________________________________________________ Found 60 images belonging to 2 classes. Found 10 images belonging to 2 classes.
学習状況グラフ
転移学習後の推論
先ほど出力した「transfer.h5」を読み込んで推論を行います。
ソース
from keras.applications.vgg16 import preprocess_input import keras.preprocessing.image as Image import numpy as np from keras.models import load_model model = load_model('transfer.h5') image_path = "sample_images_pretrain/biku.jpg" # image_path = "sample_images_pretrain/tanu.jpg" image = Image.load_img(image_path, target_size=(224, 224)) # imagenet size x = Image.img_to_array(image) # numpy 配列に変換 x = np.expand_dims(x, axis=0) # add batch size dim x = preprocess_input(x) result = model.predict(x, 1)[0] for i, score in enumerate(result): if i==0: print('{}: {:.2%}'.format('タヌキ', score)) else : print('{}: {:.2%}'.format('ビクーニャ', score))
実行結果
いい感じに識別できたようです。
タヌキ: 0.03% ビクーニャ: 99.97%
タヌキ: 100.00% ビクーニャ: 0.00%
参考
少ない画像から画像分類を学習させる方法(kerasで転移学習:fine tuning)
Qiita:TensorFlow + Kerasでフレンズ識別する – その3: 分類編
https://qiita.com/croquette0212/items/1f52913cc861ae8cebad