サイトアイコン 知的好奇心

Kerasで転移学習を行う方法

Kerasで転移学習を行う方法をご紹介します。

条件

転移学習とファインチューニング

「ゼロから作るDeep Learning」では以下のように説明されています。

どちらも似たような説明ですが、ネットで色々調べてみると以下のような感じかと思われます。
(一方は概念、もう一方はその中の一つのプロセスといった感じでしょうか)

メリット

学習済みのモデルを流用することで、少ないデータで学習を行うことが出来ます。

転移学習前の推論

KerasではImageNetで学習した重みをもつ画像分類のモデルを簡単に使用することが出来ます。

VGG16の学習済みモデルを用いて、転移学習前の推論を行ってみます。

対象は、ペルーの国旗にも描かれている有名な動物「ビクーニャ」と、日本で有名な「タヌキ」の画像です。

ソース

from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input, decode_predictions
import keras.preprocessing.image as Image
import numpy as np


model = VGG16(weights="imagenet", include_top=True)

image_path = "sample_images_pretrain/biku.jpg"
image = Image.load_img(image_path, target_size=(224, 224))  # imagenet size
x = Image.img_to_array(image)
x = np.expand_dims(x, axis=0)  # add batch size dim
x = preprocess_input(x)

result = model.predict(x)
result = decode_predictions(result, top=3)[0]
print(result[0][1])  # show description

for _, name, score in result:
    print('{}: {:.2%}'.format(name, score))

実行結果

「ビクーニャ」の結果は「ガゼル」と出ました。

2番目の可能性として「インパラ」、3番目には「リャマ」と出ています。
一番似ているであろう「リャマ」の可能性さえ3%程度という不本意な結果です。

gazelle
gazelle: 84.48%
impala: 9.11%
llama: 3.08%

「タヌキ」の結果は「マダガスカルキャット」と出ました。

2番目は「ワラビー」、3番目の結果は「インドリ」と出ています。
どうやら「ビクーニャ」および「タヌキ」は未学習のようです。

Madagascar_cat
Madagascar_cat: 38.88%
wallaby: 26.10%
indri: 7.21%

転移学習

「ビクーニャ」はともかく、「タヌキ」が正しく認識されないのは日本人として受け入れることが出来ないでしょう。

ここでは、「ビクーニャ」or「タヌキ」という分類機を転移学習によって作成したいと思います。

詳しい説明は以下のサイトをご参照ください。(当該サイトを参考にしています。)

決定木の2つの種類とランダムフォレストによる機械学習アルゴリズム入門

学習画像

以下のように、「ビクーニャ」と「タヌキ」の画像を配置します。
(画像ファイル名は任意)

data/
    image/
        train/
            raccoon/
                raccoon001.jpg
                raccoon002.jpg
                ...
            vicuna/
                vicuna001.jpg
                vicuna002.jpg
                ...
        validation/
            raccoon/
                raccoon031.jpg
                raccoon032.jpg
                ...
            vicuna/
                vicuna031.jpg
                vicuna032.jpg
                ...

ソース

ポイント

base_model = VGG16(weights='imagenet', include_top=False, input_tensor=input_tensor)

「include_top=False」とすることで、ネットワークの出力層側にある3つの全結合層を含まないようにします。

for layer in base_model.layers[:15]:
   layer.trainable = False

入力層から15番目の層までを学習させないようにします。(入力層から15番目層までの重みは変わらない)

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(N_CATEGORIES, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)

出力側の層に、全結合層を追加します。
2種類の分類を行うモデルとなります。
summaryで見ると以下のような感じです。

_________________________________________________________________
global_average_pooling2d_1 ( (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 1024)              525312    
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 2050      
=================================================================
train_datagen = ImageDataGenerator(
 ・・・
)
test_datagen = ImageDataGenerator(
 ・・・
)
train_generator = train_datagen.flow_from_directory(
 ・・・
)
validation_generator = test_datagen.flow_from_directory(
 ・・・
)

画像を回転、水平/垂直方向へシフトなどを行って、学習/バリデーション画像の水増しを行います。

ソース全体

from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D,Input
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt

N_CATEGORIES = 2
IMAGE_SIZE = 224
BATCH_SIZE = 5
NUM_EPOCHS = 50

train_data_dir = 'C:/data/image/train'
validation_data_dir = 'C:/data/image/validation'

NUM_TRAINING = 30
NUM_VALIDATION = 5

input_tensor = Input(shape=(IMAGE_SIZE, IMAGE_SIZE, 3))
base_model = VGG16(weights='imagenet', include_top=False, input_tensor=input_tensor)

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(N_CATEGORIES, activation='softmax')(x)
model = Model(inputs=base_model.input, outputs=predictions)

# for layer in base_model.layers:
for layer in base_model.layers[:15]:
   layer.trainable = False
from keras.optimizers import SGD
model.compile(optimizer=SGD(lr=0.0001, momentum=0.9), loss='categorical_crossentropy',metrics=['accuracy'])

model.summary()

train_datagen = ImageDataGenerator(
   rescale=1.0 / 255,
   shear_range=0.2,
   zoom_range=0.2,
   horizontal_flip=True,
   rotation_range=10)

test_datagen = ImageDataGenerator(
   rescale=1.0 / 255,
)

train_generator = train_datagen.flow_from_directory(
   train_data_dir,
   target_size=(IMAGE_SIZE, IMAGE_SIZE),
   batch_size=BATCH_SIZE,
   class_mode='categorical',
   shuffle=True
)

validation_generator = test_datagen.flow_from_directory(
   validation_data_dir,
   target_size=(IMAGE_SIZE, IMAGE_SIZE),
   batch_size=BATCH_SIZE,
   class_mode='categorical',
   shuffle=True
)

history = model.fit_generator(train_generator,
   steps_per_epoch=NUM_TRAINING//BATCH_SIZE,
   epochs=NUM_EPOCHS,
   verbose=1,
   validation_data=validation_generator,
   validation_steps=NUM_VALIDATION//BATCH_SIZE,
   )

model.save('transfer.h5')  # モデルの保存

# グラフ描画
# Accuracy
plt.plot(range(1, NUM_EPOCHS+1), history.history['acc'], "o-")
plt.plot(range(1, NUM_EPOCHS+1), history.history['val_acc'], "o-")
plt.title('model accuracy')
plt.ylabel('accuracy')  # Y軸ラベル
plt.xlabel('epoch')  # X軸ラベル
plt.legend(['train', 'test'], loc='upper left')
plt.show()
# loss
plt.plot(range(1, NUM_EPOCHS+1), history.history['loss'], "o-")
plt.plot(range(1, NUM_EPOCHS+1), history.history['val_loss'], "o-")
plt.title('model loss')
plt.ylabel('loss')  # Y軸ラベル
plt.xlabel('epoch')  # X軸ラベル
plt.legend(['train', 'test'], loc='upper right')
plt.show()

実行結果

モデルのサマリー

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_1 (InputLayer)         (None, 224, 224, 3)       0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 224, 224, 64)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 224, 224, 64)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 112, 112, 64)      0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 112, 112, 128)     73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 112, 112, 128)     147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 56, 56, 128)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 56, 56, 256)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 56, 56, 256)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 28, 28, 256)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 28, 28, 512)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 28, 28, 512)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 14, 14, 512)       0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_conv3 (Conv2D)        (None, 14, 14, 512)       2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 7, 7, 512)         0         
_________________________________________________________________
global_average_pooling2d_1 ( (None, 512)               0         
_________________________________________________________________
dense_1 (Dense)              (None, 1024)              525312    
_________________________________________________________________
dense_2 (Dense)              (None, 2)                 2050      
=================================================================
Total params: 15,242,050
Trainable params: 7,606,786
Non-trainable params: 7,635,264
_________________________________________________________________
Found 60 images belonging to 2 classes.
Found 10 images belonging to 2 classes.

学習状況グラフ

転移学習後の推論

先ほど出力した「transfer.h5」を読み込んで推論を行います。

ソース

from keras.applications.vgg16 import preprocess_input
import keras.preprocessing.image as Image
import numpy as np
from keras.models import load_model


model = load_model('transfer.h5')

image_path = "sample_images_pretrain/biku.jpg"
# image_path = "sample_images_pretrain/tanu.jpg"

image = Image.load_img(image_path, target_size=(224, 224))  # imagenet size
x = Image.img_to_array(image)  # numpy 配列に変換
x = np.expand_dims(x, axis=0)  # add batch size dim
x = preprocess_input(x)


result = model.predict(x, 1)[0]
for i, score in enumerate(result):
    if i==0:
        print('{}: {:.2%}'.format('タヌキ', score))
    else :
        print('{}: {:.2%}'.format('ビクーニャ', score))

実行結果

いい感じに識別できたようです。

タヌキ: 0.03%
ビクーニャ: 99.97%

タヌキ: 100.00%
ビクーニャ: 0.00%

参考

少ない画像から画像分類を学習させる方法(kerasで転移学習:fine tuning)

決定木の2つの種類とランダムフォレストによる機械学習アルゴリズム入門

Qiita:TensorFlow + Kerasでフレンズ識別する – その3: 分類編

https://qiita.com/croquette0212/items/1f52913cc861ae8cebad

ゼロから作るDeep Learning

直感 Deep Learning

モバイルバージョンを終了