こんにちは。技術研究所のYKです。
最近TensorFlowを触り始めて、使ってみたら楽しかったので記事を書くことにしました。
今回はTensorFlowを使って、画像・正解データから単一ファイルのデータセットを作って学習させるところまでやってみました。
- こちらに記載があるように、TensorFlowでは画像ファイルを直接読み込んで利用することも可能です。
しかし、学習用のファイルを単一のファイルにまとめることでデータセットを管理し易くなり、かつ人的ミスが減らせるので、単一ファイルにまとめて扱えるようにしたいと思いました。
TensorFlowではTFRecordsという形式でデータセットをまとめられるとのことなので、試してみました。
- ここの Standard TensorFlow format にある通り、TFRecordsはTensorFlow向けの推奨ファイルフォーマットです。
TFRecordsファイルには、Protocol Buffersフォーマットでシリアライズされたデータを詰め込むことができます。
つまり、画像データなどのバイナリデータでも、シリアライズしてしまえば詰め込むことが可能です。 - TFRecords形式に変換しておくメリットとしては、以下の2点が挙げられます。
- TFRecords形式のファイルを扱う機能が提供されている
- 高速
- 直接画像データを読み込むよりも、TFRecords形式に変換したものを利用した方が高速にデータを処理することができます。
- 参考: Performance Guide
- OS: macOS Sierra 10.12.5
- メモリ: 8GB
- Python ver: Python 3.6.0 :: Anaconda 4.3.1 (x86_64)
- TensorFlow
- ver: 1.2rc0
- 2017/06/14に試してみた時点では、上記のバージョンが提供されておりましたので、こちらのバージョンで試しました。
しかし、06/22に改めてTensorFlowの公式サイトを確認したところ、バージョン1.2が正式版として提供開始されていました… - 念の為、今回実装したソースコードの動作確認をバージョン1.2にて行いましたが、幸いなことに何ら問題は発生しませんでした。
- 2017/06/14に試してみた時点では、上記のバージョンが提供されておりましたので、こちらのバージョンで試しました。
- ver: 1.2rc0
- TFRecordWriterオブジェクトを利用することで、TFRecords形式のデータセットを作成できます。
- ソースコードの処理概要
- 画像ファイルが格納されているディレクトリと、それに対応する正解データを用意
- ここでは事前にlistとして変数に格納しておく
- 画像ファイルが格納されているディレクトリと、それに対応する正解データを用意
- 該当する画像ファイルを読み込み
- 画像データをリサイズ & バイナリ文字列に変換
- CNNにて利用し易いように、すべての画像を一定のサイズにリサイズ
TFRecordsに保存する上では、サイズは一定でなくとも問題ない
- CNNにて利用し易いように、すべての画像を一定のサイズにリサイズ
- TFRecordWriterオブジェクトを使って書き込み
ソースコード
| import tensorflow as tf |
| import numpy as np |
| from PIL import Image |
| # 作成するデータセットのファイルパスを指定 |
| dataset_path = "dataset.tfrecords" |
| # 格納する画像サイズの指定: MNISTデータセットの画像を利用したので、H28xW28を指定 |
| width, height = [28, 28] |
| # クラス数(分類したい種類の数): ラベルをone-hot表現に変換する為に利用 |
| # one-hot表現については https://ja.wikipedia.org/wiki/One-hot を参照 |
| class_count = 10 |
| # 正解データ: 画像ファイル名と正解ラベルのリスト |
| # [[画像ファイル名, 正解ラベル], ...] |
| datas = [["img0.png", 0], ["img2.png", 2], ["img1.png", 1]] |
| # TFRecordsファイルに書き出す為、TFRecordWriterオブジェクトを生成 |
| writer = tf.python_io.TFRecordWriter(dataset_path) |
| # datasから、画像ファイル名と正解ラベルの対を1件ずつ取り出す |
| for img_name, label in datas: |
| # 画像ファイルを読み込み、リサイズ & バイト文字列に変換 |
| img_obj = Image.open(img_name).convert("RGB").resize((width, height)) |
| img = np.array(img_obj).tostring() |
| # 画像ファイル1件につき、1つのtf.train.Exampleを作成 |
| record = tf.train.Example(features=tf.train.Features(feature={ |
| "class_count": tf.train.Feature( |
| int64_list=tf.train.Int64List(value=[class_count])), |
| "label": tf.train.Feature( |
| int64_list=tf.train.Int64List(value=[label])), |
| "image": tf.train.Feature( |
| bytes_list=tf.train.BytesList(value=[img])), |
| "height": tf.train.Feature( |
| int64_list=tf.train.Int64List(value=[height])), |
| "width": tf.train.Feature( |
| int64_list=tf.train.Int64List(value=[width])), |
| "depth": tf.train.Feature( |
| int64_list=tf.train.Int64List(value=[3])), |
| })) |
| # tf.train.ExampleをTFRecordsファイルに書き込む |
| writer.write(record.SerializeToString()) |
| writer.close() |
- TensorFlowでは、queueを利用してファイルの読み込みを行います。
また、TFRecordReaderオブジェクトを利用することでTFRecords形式のファイルを読み込むことができます。 - ソースコードの処理概要
- TFRecordsファイルのパスをqueueに詰める
- queue内のTFRecordsファイルを読み込み、デシリアライズする
- デシリアライズしたデータを元の型に変換したり、元のshapeに戻したりする
ソースコード
| import tensorflow as tf |
| # 読み込み対象のファイルをqueueに詰める: TFRecordReaderはqueueを利用してファイルを読み込む |
| file_name_queue = tf.train.string_input_producer([file_path]) |
| # TFRecordsファイルを読み込む為、TFRecordReaderオブジェクトを生成 |
| reader = tf.TFRecordReader() |
| # 読み込み: ファイルから読み込み、serialized_exampleに格納する |
| _, serialized_example = reader.read(file_name_queue) |
| # デシリアライズ: serialized_exampleはシリアライズされているので、デシリアライズする |
| # → Tensorオブジェクトが返却される |
| features = tf.parse_single_example( |
| serialized_example, |
| features={ |
| "class_count": tf.FixedLenFeature([], tf.int64), |
| "label": tf.FixedLenFeature([], tf.int64), |
| "image": tf.FixedLenFeature([], tf.string), |
| "height": tf.FixedLenFeature([], tf.int64), |
| "width": tf.FixedLenFeature([], tf.int64), |
| "depth": tf.FixedLenFeature([], tf.int64), |
| }) |
| # featuresオブジェクト内の要素はTensorオブジェクトとなっている |
| # でも、Tensorオブジェクトに直接アクセスしても中身が見えない |
| # |
| # → 中身を見る為には、session張ってeval()する |
| # → eval()する為にはCoordinatorオブジェクトを生成して、start_queue_runner()しておく必要がある |
| with tf.Session() as sess: |
| sess.run(tf.local_variables_initializer()) |
| coord = tf.train.Coordinator() |
| threads = tf.train.start_queue_runners(sess=sess, coord=coord) |
| try: |
| # --- 実数値に変換する必要があるもののみeval() --- |
| height = tf.cast(features["height"], tf.int32).eval() |
| width = tf.cast(features["width"], tf.int32).eval() |
| depth = tf.cast(features["depth"], tf.int32).eval() |
| class_count = tf.cast(features["class_count"], tf.int32).eval() |
| # --- 画像データとラベルは学習時に適宜取り出したいのでeval()しない --- |
| label = tf.cast(features["label"], tf.int32) |
| # バイト文字列をdecodeし、元のshapeに戻す |
| img = tf.reshape(tf.decode_raw(features["image"], tf.uint8), |
| tf.stack([height, width, depth])) |
| finally: |
| coord.request_stop() |
| coord.join(threads) |
| # labelをone-hot表現に変換 |
| label = tf.one_hot(label, class_count) |
- 「データセットの読み込み」のimg, labelからミニバッチ単位でデータを取り出します。
- ソースコードの処理概要
- img, labelの値の調整と型変換
- ミニバッチ単位で取り出す為の変数を作成
- ミニバッチ分のデータの取り出し
ソースコード
| # ピクセル値が0 ~ 255の範囲の値を取ってしまっているので、0 ~ 1の範囲の値になるように調整 |
| img = tf.cast(img, tf.float32) * (1. / 255) |
| label = tf.cast(label, dtype=tf.float32) |
| # ミニバッチのサイズを指定 |
| batch_size = 100 |
| # ミニバッチ単位で取り出せるようにする |
| # 詳細は https://www.tensorflow.org/api_docs/python/tf/train/batch |
| images, sparse_labels = tf.train.batch( |
| [img, label], batch_size=batch_size, num_threads=2, |
| capacity=1000 + 3 * batch_size) |
| # あとはsession張ってsess.run()すればミニバッチ単位でデータを取り出せる |
| sess = tf.InteractiveSession() |
| sess.run(tf.global_variables_initializer()) |
| coord = tf.train.Coordinator() |
| threads = tf.train.start_queue_runners(sess=sess, coord=coord) |
| try: |
| # ミニバッチ分のデータを取り出す |
| imgs, labels = sess.run([img, label]) |
| # 後はこのデータを使って煮るなり焼くなり… |
| ... |
| finally: |
| coord.request_stop() |
| coord.join(threads) |
| sess.close() |
- 先述した通りですが、TFRecordsには様々なデータを突っ込めるので非常に便利でした。
- データセットを扱う機能以外にも、Data augmentationに必要な機能なども提供されているようなので、学習前の前処理がかなり軽減されそうな感じがしています。
- まだローカル環境で使っている状態なので、せっかくならGoogle Cloud Platform上でも動かしてみたいところです。


