MNISTデータの内容

MNISTデータの内容についてご紹介します。

条件

  • Python 3.7.0
  • Windows 10 64bit
  • Keras 2.2.4

MNISTデータについて

MNISTは、0から9までの手書き数字の画像データセットです。
60,000の学習用画像と、10,000のテスト用画像で構成されています。
各画像データと正解ラベルのデータが対になっており、機械学習の実験用データとしてよく用いられます。

Kerasの場合、以下のコマンドようなコマンドを実行すると、MNISTデータがダウンロードされます。

from keras.datasets import mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()

Windowsの場合、下記のようなフォルダに保存されます。

C:\Users\user\.keras\datasets\mnist.npz

ちなみにnpzファイルは、NumPyの複数の配列を一つのファイルにまとめたバイナリファイルです。

画像とラベル確認

pyplotで表示

サンプルソース

学習データとテストデータを、1つずつ表示してみます。

28×28の配列に変換し、グレースケールで表示します。

from keras.datasets import mnist
import matplotlib.pyplot as plt


# mnistデータのダウンロード
(X_train, y_train), (X_test, y_test) = mnist.load_data()

print("学習データのラベル:", y_train[0])
plt.imshow(X_train[0].reshape(28, 28), cmap='Greys')
plt.show()

print("テストデータのラベル:", y_test[0])
plt.imshow(X_test[0].reshape(28, 28), cmap='Greys')
plt.show()

実行結果

学習データ、テストデータ共に、ラベルと画像が対になっていることがわかります。

学習データのラベル: 5
テストデータのラベル: 7

PILで表示

Pythonの画像処理ライブラリである、PIL(Python Image Library)でも同様の表示を行ってみます。

サンプルソース

28×28の配列に変換し、PIL用データオブジェクトに変換して表示します。

import numpy as np
from keras.datasets import mnist
from PIL import Image


def img_show(img):
    pil_img = Image.fromarray(np.uint8(img))  # PIL用データオブジェクトに変換
    pil_img.show()


# mnistデータのダウンロード
(X_train, y_train), (X_test, y_test) = mnist.load_data()

img = X_train[0].reshape(28, 28)
print("学習データのラベル:", y_train[0])
img_show(img)  # 学習データの画像

img =X_test[0].reshape(28, 28)
print("テストデータのラベル:", y_test[0])
img_show(img)  # テストデータの画像

実行結果

学習データのラベル: 5
テストデータのラベル: 7

npzファイルを指定して読み込み

以下のように、npzファイルを指定して読み込むことも出来ます。

import matplotlib.pyplot as plt
import numpy as np


with np.load('C:/Users/user/.keras/datasets/mnist.npz') as data:

    X_train, y_train = data['x_train'], data['y_train']
    X_test, y_test = data['x_test'], data['y_test']

print("学習データのラベル:", y_train[0])
plt.imshow(X_train[0].reshape(28, 28), cmap='Greys')
plt.show()

print("テストデータのラベル:", y_test[0])
plt.imshow(X_test[0].reshape(28, 28), cmap='Greys')
plt.show()

参考

THE MNIST DATABASE

http://yann.lecun.com/exdb/mnist/

NumPy:Input and output

https://www.numpy.org/devdocs/reference/routines.io.html?highlight=npz

Pillow

https://pillow.readthedocs.io/en/stable/

matplotlib:matplotlib.pyplot.imshow

https://matplotlib.org/3.1.0/api/_as_gen/matplotlib.pyplot.imshow.html

matplotlib:color example code

https://matplotlib.org/examples/color/colormaps_reference.html

GitHub:keras-team/keras

https://github.com/keras-team/keras/blob/master/keras/datasets/mnist.py

コメントを残す

メールアドレスが公開されることはありません。 が付いている欄は必須項目です