こんにちは。技術研究所のYKです。
最近TensorFlowを触り始めて、使ってみたら楽しかったので記事を書くことにしました。
今回はTensorFlowを使って、画像・正解データから単一ファイルのデータセットを作って学習させるところまでやってみました。
なんで単一ファイルにまとめたかったの?
- こちらに記載があるように、TensorFlowでは画像ファイルを直接読み込んで利用することも可能です。
しかし、学習用のファイルを単一のファイルにまとめることでデータセットを管理し易くなり、かつ人的ミスが減らせるので、単一ファイルにまとめて扱えるようにしたいと思いました。
TensorFlowではTFRecordsという形式でデータセットをまとめられるとのことなので、試してみました。
TFRecords形式について
- ここの Standard TensorFlow format にある通り、TFRecordsはTensorFlow向けの推奨ファイルフォーマットです。
TFRecordsファイルには、Protocol Buffersフォーマットでシリアライズされたデータを詰め込むことができます。
つまり、画像データなどのバイナリデータでも、シリアライズしてしまえば詰め込むことが可能です。 - TFRecords形式に変換しておくメリットとしては、以下の2点が挙げられます。
- TFRecords形式のファイルを扱う機能が提供されている
- 高速
- 直接画像データを読み込むよりも、TFRecords形式に変換したものを利用した方が高速にデータを処理することができます。
- 参考: Performance Guide
試した環境
- OS: macOS Sierra 10.12.5
- メモリ: 8GB
- Python ver: Python 3.6.0 :: Anaconda 4.3.1 (x86_64)
- TensorFlow
- ver: 1.2rc0
- 2017/06/14に試してみた時点では、上記のバージョンが提供されておりましたので、こちらのバージョンで試しました。
しかし、06/22に改めてTensorFlowの公式サイトを確認したところ、バージョン1.2が正式版として提供開始されていました… - 念の為、今回実装したソースコードの動作確認をバージョン1.2にて行いましたが、幸いなことに何ら問題は発生しませんでした。
- 2017/06/14に試してみた時点では、上記のバージョンが提供されておりましたので、こちらのバージョンで試しました。
- CPUモード
- ver: 1.2rc0
データセットの作成・読み込み
データセットの作成
- TFRecordWriterオブジェクトを利用することで、TFRecords形式のデータセットを作成できます。
- ソースコードの処理概要
- 画像ファイルが格納されているディレクトリと、それに対応する正解データを用意
- ここでは事前にlistとして変数に格納しておく
- TFRecordWriterオブジェクトを生成
- 正解データを全件走査して下記の処理を実行
- 該当する画像ファイルを読み込み
- 画像データをリサイズ & バイナリ文字列に変換
- CNNにて利用し易いように、すべての画像を一定のサイズにリサイズ
TFRecordsに保存する上では、サイズは一定でなくとも問題ない
- CNNにて利用し易いように、すべての画像を一定のサイズにリサイズ
- TFRecordWriterオブジェクトを使って書き込み
- 画像ファイルが格納されているディレクトリと、それに対応する正解データを用意
- ソースコード
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748import tensorflow as tfimport numpy as npfrom PIL import Image# 作成するデータセットのファイルパスを指定dataset_path = "dataset.tfrecords"# 格納する画像サイズの指定: MNISTデータセットの画像を利用したので、H28xW28を指定width, height = [28, 28]# クラス数(分類したい種類の数): ラベルをone-hot表現に変換する為に利用# one-hot表現については https://ja.wikipedia.org/wiki/One-hot を参照class_count = 10# 正解データ: 画像ファイル名と正解ラベルのリスト# [[画像ファイル名, 正解ラベル], ...]datas = [["img0.png", 0], ["img2.png", 2], ["img1.png", 1]]# TFRecordsファイルに書き出す為、TFRecordWriterオブジェクトを生成writer = tf.python_io.TFRecordWriter(dataset_path)# datasから、画像ファイル名と正解ラベルの対を1件ずつ取り出すfor img_name, label in datas:# 画像ファイルを読み込み、リサイズ & バイト文字列に変換img_obj = Image.open(img_name).convert("RGB").resize((width, height))img = np.array(img_obj).tostring()# 画像ファイル1件につき、1つのtf.train.Exampleを作成record = tf.train.Example(features=tf.train.Features(feature={"class_count": tf.train.Feature(int64_list=tf.train.Int64List(value=[class_count])),"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label])),"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[img])),"height": tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),"width": tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),"depth": tf.train.Feature(int64_list=tf.train.Int64List(value=[3])),}))# tf.train.ExampleをTFRecordsファイルに書き込むwriter.write(record.SerializeToString())writer.close()
データセットの読み込み
- TensorFlowでは、queueを利用してファイルの読み込みを行います。
また、TFRecordReaderオブジェクトを利用することでTFRecords形式のファイルを読み込むことができます。 - ソースコードの処理概要
- TFRecordsファイルのパスをqueueに詰める
- queue内のTFRecordsファイルを読み込み、デシリアライズする
- デシリアライズしたデータを元の型に変換したり、元のshapeに戻したりする
- ソースコード
1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253import tensorflow as tf# 読み込み対象のファイルをqueueに詰める: TFRecordReaderはqueueを利用してファイルを読み込むfile_name_queue = tf.train.string_input_producer([file_path])# TFRecordsファイルを読み込む為、TFRecordReaderオブジェクトを生成reader = tf.TFRecordReader()# 読み込み: ファイルから読み込み、serialized_exampleに格納する_, serialized_example = reader.read(file_name_queue)# デシリアライズ: serialized_exampleはシリアライズされているので、デシリアライズする# → Tensorオブジェクトが返却されるfeatures = tf.parse_single_example(serialized_example,features={"class_count": tf.FixedLenFeature([], tf.int64),"label": tf.FixedLenFeature([], tf.int64),"image": tf.FixedLenFeature([], tf.string),"height": tf.FixedLenFeature([], tf.int64),"width": tf.FixedLenFeature([], tf.int64),"depth": tf.FixedLenFeature([], tf.int64),})# featuresオブジェクト内の要素はTensorオブジェクトとなっている# でも、Tensorオブジェクトに直接アクセスしても中身が見えない## → 中身を見る為には、session張ってeval()する# → eval()する為にはCoordinatorオブジェクトを生成して、start_queue_runner()しておく必要があるwith tf.Session() as sess:sess.run(tf.local_variables_initializer())coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)try:# --- 実数値に変換する必要があるもののみeval() ---height = tf.cast(features["height"], tf.int32).eval()width = tf.cast(features["width"], tf.int32).eval()depth = tf.cast(features["depth"], tf.int32).eval()class_count = tf.cast(features["class_count"], tf.int32).eval()# --- 画像データとラベルは学習時に適宜取り出したいのでeval()しない ---label = tf.cast(features["label"], tf.int32)# バイト文字列をdecodeし、元のshapeに戻すimg = tf.reshape(tf.decode_raw(features["image"], tf.uint8),tf.stack([height, width, depth]))finally:coord.request_stop()coord.join(threads)# labelをone-hot表現に変換label = tf.one_hot(label, class_count)
データセットからのミニバッチ単位での取り出し
- 「データセットの読み込み」のimg, labelからミニバッチ単位でデータを取り出します。
- ソースコードの処理概要
- img, labelの値の調整と型変換
- ミニバッチ単位で取り出す為の変数を作成
- ミニバッチ分のデータの取り出し
- ソースコード
12345678910111213141516171819202122232425262728293031# ピクセル値が0 ~ 255の範囲の値を取ってしまっているので、0 ~ 1の範囲の値になるように調整img = tf.cast(img, tf.float32) * (1. / 255)label = tf.cast(label, dtype=tf.float32)# ミニバッチのサイズを指定batch_size = 100# ミニバッチ単位で取り出せるようにする# 詳細は https://www.tensorflow.org/api_docs/python/tf/train/batchimages, sparse_labels = tf.train.batch([img, label], batch_size=batch_size, num_threads=2,capacity=1000 + 3 * batch_size)# あとはsession張ってsess.run()すればミニバッチ単位でデータを取り出せるsess = tf.InteractiveSession()sess.run(tf.global_variables_initializer())coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess=sess, coord=coord)try:# ミニバッチ分のデータを取り出すimgs, labels = sess.run([img, label])# 後はこのデータを使って煮るなり焼くなり…...finally:coord.request_stop()coord.join(threads)sess.close()
TFRecordsを使ってみて
- 先述した通りですが、TFRecordsには様々なデータを突っ込めるので非常に便利でした。
- データセットを扱う機能以外にも、Data augmentationに必要な機能なども提供されているようなので、学習前の前処理がかなり軽減されそうな感じがしています。
- まだローカル環境で使っている状態なので、せっかくならGoogle Cloud Platform上でも動かしてみたいところです。