こんにちは。技術研究所のYKです。
最近TensorFlowを触り始めて、使ってみたら楽しかったので記事を書くことにしました。

 

今回はTensorFlowを使って、画像・正解データから単一ファイルのデータセットを作って学習させるところまでやってみました。

なんで単一ファイルにまとめたかったの?

  • こちらに記載があるように、TensorFlowでは画像ファイルを直接読み込んで利用することも可能です。
    しかし、学習用のファイルを単一のファイルにまとめることでデータセットを管理し易くなり、かつ人的ミスが減らせるので、単一ファイルにまとめて扱えるようにしたいと思いました。
    TensorFlowではTFRecordsという形式でデータセットをまとめられるとのことなので、試してみました。

TFRecords形式について

  • ここStandard TensorFlow format にある通り、TFRecordsはTensorFlow向けの推奨ファイルフォーマットです。
    TFRecordsファイルには、Protocol Buffersフォーマットでシリアライズされたデータを詰め込むことができます。
    つまり、画像データなどのバイナリデータでも、シリアライズしてしまえば詰め込むことが可能です。
  • TFRecords形式に変換しておくメリットとしては、以下の2点が挙げられます。
    1. TFRecords形式のファイルを扱う機能が提供されている
    2. 高速
      • 直接画像データを読み込むよりも、TFRecords形式に変換したものを利用した方が高速にデータを処理することができます。
      • 参考: Performance Guide

試した環境

  • OS: macOS Sierra 10.12.5
  • メモリ: 8GB
  • Python ver: Python 3.6.0 :: Anaconda 4.3.1 (x86_64)
  • TensorFlow
    • ver: 1.2rc0
      • 2017/06/14に試してみた時点では、上記のバージョンが提供されておりましたので、こちらのバージョンで試しました。
        しかし、06/22に改めてTensorFlowの公式サイトを確認したところ、バージョン1.2が正式版として提供開始されていました…
      • 念の為、今回実装したソースコードの動作確認をバージョン1.2にて行いましたが、幸いなことに何ら問題は発生しませんでした。
  • CPUモード
  • データセットの作成・読み込み

    データセットの作成

    • TFRecordWriterオブジェクトを利用することで、TFRecords形式のデータセットを作成できます。
    • ソースコードの処理概要
      1. 画像ファイルが格納されているディレクトリと、それに対応する正解データを用意
        • ここでは事前にlistとして変数に格納しておく
  • TFRecordWriterオブジェクトを生成
  • 正解データを全件走査して下記の処理を実行
    1. 該当する画像ファイルを読み込み
    2. 画像データをリサイズ & バイナリ文字列に変換
      • CNNにて利用し易いように、すべての画像を一定のサイズにリサイズ
        TFRecordsに保存する上では、サイズは一定でなくとも問題ない
    • TFRecordWriterオブジェクトを使って書き込み

    ソースコード

    import tensorflow as tf
    import numpy as np
    from PIL import Image
     
    # 作成するデータセットのファイルパスを指定
    dataset_path = "dataset.tfrecords"
     
    # 格納する画像サイズの指定: MNISTデータセットの画像を利用したので、H28xW28を指定
    width, height = [28, 28]
     
    # クラス数(分類したい種類の数): ラベルをone-hot表現に変換する為に利用
    # one-hot表現については https://ja.wikipedia.org/wiki/One-hot を参照
    class_count = 10
     
    # 正解データ: 画像ファイル名と正解ラベルのリスト
    # [[画像ファイル名, 正解ラベル], ...]
    datas = [["img0.png", 0], ["img2.png", 2], ["img1.png", 1]]
     
    # TFRecordsファイルに書き出す為、TFRecordWriterオブジェクトを生成
    writer = tf.python_io.TFRecordWriter(dataset_path)
     
    # datasから、画像ファイル名と正解ラベルの対を1件ずつ取り出す
    for img_name, label in datas:
    # 画像ファイルを読み込み、リサイズ & バイト文字列に変換
    img_obj = Image.open(img_name).convert("RGB").resize((width, height))
    img = np.array(img_obj).tostring()
     
    # 画像ファイル1件につき、1つのtf.train.Exampleを作成
    record = tf.train.Example(features=tf.train.Features(feature={
    "class_count": tf.train.Feature(
    int64_list=tf.train.Int64List(value=[class_count])),
    "label": tf.train.Feature(
    int64_list=tf.train.Int64List(value=[label])),
    "image": tf.train.Feature(
    bytes_list=tf.train.BytesList(value=[img])),
    "height": tf.train.Feature(
    int64_list=tf.train.Int64List(value=[height])),
    "width": tf.train.Feature(
    int64_list=tf.train.Int64List(value=[width])),
    "depth": tf.train.Feature(
    int64_list=tf.train.Int64List(value=[3])),
    }))
     
    # tf.train.ExampleをTFRecordsファイルに書き込む
    writer.write(record.SerializeToString())
     
    writer.close()

    データセットの読み込み

    • TensorFlowでは、queueを利用してファイルの読み込みを行います。
      また、TFRecordReaderオブジェクトを利用することでTFRecords形式のファイルを読み込むことができます。
    • ソースコードの処理概要
      1. TFRecordsファイルのパスをqueueに詰める
      2. queue内のTFRecordsファイルを読み込み、デシリアライズする
      3. デシリアライズしたデータを元の型に変換したり、元のshapeに戻したりする

    ソースコード

    import tensorflow as tf
    # 読み込み対象のファイルをqueueに詰める: TFRecordReaderはqueueを利用してファイルを読み込む
    file_name_queue = tf.train.string_input_producer([file_path])
    # TFRecordsファイルを読み込む為、TFRecordReaderオブジェクトを生成
    reader = tf.TFRecordReader()
    # 読み込み: ファイルから読み込み、serialized_exampleに格納する
    _, serialized_example = reader.read(file_name_queue)
    # デシリアライズ: serialized_exampleはシリアライズされているので、デシリアライズする
    # → Tensorオブジェクトが返却される
    features = tf.parse_single_example(
    serialized_example,
    features={
    "class_count": tf.FixedLenFeature([], tf.int64),
    "label": tf.FixedLenFeature([], tf.int64),
    "image": tf.FixedLenFeature([], tf.string),
    "height": tf.FixedLenFeature([], tf.int64),
    "width": tf.FixedLenFeature([], tf.int64),
    "depth": tf.FixedLenFeature([], tf.int64),
    })
    # featuresオブジェクト内の要素はTensorオブジェクトとなっている
    # でも、Tensorオブジェクトに直接アクセスしても中身が見えない
    #
    # → 中身を見る為には、session張ってeval()する
    # → eval()する為にはCoordinatorオブジェクトを生成して、start_queue_runner()しておく必要がある
    with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
    # --- 実数値に変換する必要があるもののみeval() ---
    height = tf.cast(features["height"], tf.int32).eval()
    width = tf.cast(features["width"], tf.int32).eval()
    depth = tf.cast(features["depth"], tf.int32).eval()
    class_count = tf.cast(features["class_count"], tf.int32).eval()
    # --- 画像データとラベルは学習時に適宜取り出したいのでeval()しない ---
    label = tf.cast(features["label"], tf.int32)
    # バイト文字列をdecodeし、元のshapeに戻す
    img = tf.reshape(tf.decode_raw(features["image"], tf.uint8),
    tf.stack([height, width, depth]))
    finally:
    coord.request_stop()
    coord.join(threads)
    # labelをone-hot表現に変換
    label = tf.one_hot(label, class_count)

    データセットからのミニバッチ単位での取り出し

    • 「データセットの読み込み」のimg, labelからミニバッチ単位でデータを取り出します。
    • ソースコードの処理概要
      1. img, labelの値の調整と型変換
      2. ミニバッチ単位で取り出す為の変数を作成
      3. ミニバッチ分のデータの取り出し

    ソースコード

    # ピクセル値が0 ~ 255の範囲の値を取ってしまっているので、0 ~ 1の範囲の値になるように調整
    img = tf.cast(img, tf.float32) * (1. / 255)
    label = tf.cast(label, dtype=tf.float32)
    # ミニバッチのサイズを指定
    batch_size = 100
    # ミニバッチ単位で取り出せるようにする
    # 詳細は https://www.tensorflow.org/api_docs/python/tf/train/batch
    images, sparse_labels = tf.train.batch(
    [img, label], batch_size=batch_size, num_threads=2,
    capacity=1000 + 3 * batch_size)
    # あとはsession張ってsess.run()すればミニバッチ単位でデータを取り出せる
    sess = tf.InteractiveSession()
    sess.run(tf.global_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
    # ミニバッチ分のデータを取り出す
    imgs, labels = sess.run([img, label])
    # 後はこのデータを使って煮るなり焼くなり…
    ...
    finally:
    coord.request_stop()
    coord.join(threads)
    sess.close()

    TFRecordsを使ってみて

    • 先述した通りですが、TFRecordsには様々なデータを突っ込めるので非常に便利でした。
    • データセットを扱う機能以外にも、Data augmentationに必要な機能なども提供されているようなので、学習前の前処理がかなり軽減されそうな感じがしています。
    • まだローカル環境で使っている状態なので、せっかくならGoogle Cloud Platform上でも動かしてみたいところです。