CaptainChen

TensorFlow 学习笔记 -- Input Pipeline

2018-08-22 11149 words (about 50 min read) Views
TensorFlow Deep learning

本篇笔记主要总结了如何在 TensorFlow 如何构建高效的 Input Pipeline,目的是协调 CPU 文件预处理和 GPU 模型计算之间的调度,尽最大限度发挥 GPU 算力。其中涉及到 TFRecord 文件的读写,tf.image 模块对图像的处理,以及版本 1.4 前使用的生产者/消费者多线程文件读写流程,和 1.4 后官方主推的 Dataset 处理方式。后者已经开始逐步支持 eager 模式。

吐槽部分:这一部分学完后感触很深,tf 真的很难,它已经从一个 Python 库上升到了一个十分复杂的编程语言的高度,让人抓狂。大部分 API 文档是根据代码注释生成的,往往晦涩难懂,由于时间和精力原因去深入庞杂的 C++ 底层几乎不可能。由此带来的最大困扰是,you seldom know what is happening under the hood。比如有多少人知道 tf.image 模块的 resize 的插值方法中,align corners 和主流图像库处理的方式都不一样?在学习这部分的过程中,我翻阅了大量博客,Github Issue,Stack Overflow 回答,多次对着源码逐行分析并进行代码测试,最终总结成个人笔记可谓是“满纸荒唐言,一把辛酸泪”。不过,也正是学习这部分的钻研过程中,感慨 tf 真的是为了实际应用做了充分的考虑,正如 caffe 作者贾扬清在知乎说道: "TF的确难,但是它给你提供了真正可以产品化的可能性。"

先把这部分我个人觉得要注意的一些点(或者说是大坑)列举一下,方便日后查阅:

  • 读取 TFRecord 文件过程中,解析 Example Protobuf 文件时,decode_raw 得到的数据(如 image raw data) 要通过 reshape 操作恢复 shape,而 shape 参数也是从 TFRecord 文件中获取时,要加 tf.stack 操作: image = tf.reshape(image, tf.stack([height, width, channels]))
  • 读取图片时,PIL 的 Image.tobytes() 或 Numpy 的 np.array().tobytes() 得到的二进制文件搭配 tf.decode_raw() 使用;tf.gfile.GFile(img_name) 读取的图片文件要用 tf.image.decode_png() 等函数解析
  • tf.image.decode_image 得到的图像没有静态 shape,因此无法与 tf.image.resize_images 一起使用。应该尽量使用 tf.image.decode_png 函数
  • tf.image.resize_images 中有个 align_corners 参数,它的机制与主流图像库如 opencv 是不一样的,建议设置为 True。不然后果可参见 How Tensorflow’s tf.image.resize stole 60 days of my life
  • tf.image.sample_distorted_bounding_box 得到的 bbox 信息配合 tf.slice 裁剪得到的图像的最后一个维度是动态的,建议后面跟一个 set_shape 操作
  • tf.image.distorted_bounding_box_crop 函数的 min_object_covered 默认取 0.1,很多时候是不合适的
  • tf.train.match_filenames_once 和 tf.train.string_input_producer 都创建了局部变量,Session 中需要初始化。
  • tf.train.batch 和 tf.train.batch_join 的多线程不一样,后者才是一个文件分配一个线程
  • 用 tf.train.batch 等操作出队得到的多元 Tensor tuple 一定要一起 run,不然会错位交叉。参见 3.2.5 例子。
  • Dataset 模块中, shuffle 的 buffer_size 要好好设计,同时 shuffle 和 repeat 的顺序要注意
  • Dataset.from_tensor_slices 是把传进来的输入以 tf.constant() 的形式存在于 Graph 中,而一个 Graph Protobuf 文件的上限是 2G

推荐阅读: CS230 课件 “An overview of tf.data” 部分列举的所有链接。


1. TFRecord

1.1 TFRecord 格式

TFRecord 文件的数据是通过 tf.train.Example 这个 protobuf 的格式存储的。其定义在 tensorflow/core/example/example.prototensorflow/core/example/feature.proto :

message Example {
  // type: Features, name: features
  Features features = 1;
};

message Features {
  // Map from feature name to feature.
  // Features is a key-value store, where each key (string) 
  // maps to a Feature message
  map<string, Feature> feature = 1;
};

// Containers for non-sequential data.
message Feature {
  // Each feature can be exactly one kind.
  oneof kind {
    BytesList bytes_list = 1;
    FloatList float_list = 2;
    Int64List int64_list = 3;
  }
};

// Containers to hold repeated fundamental values.
message BytesList {
  repeated bytes value = 1;
}
message FloatList {
  repeated float value = 1 [packed = true];
}
message Int64List {
  repeated int64 value = 1 [packed = true];
}

tf.train.Example 中包含名为 features 的 Features 的信息,每个 Features 中包含一个从 feature 名称到 feature 属性 (Feature) 的字典映射(key-value对),其中每个 Feature 的可以从 BytesList (字符串),FloatList (浮点实数列表), Int64List (整数列表)中取。
一个 tf.train.Example 的例子:

features {
    feature {
        key: "age"
        value { float_list {
        value: 29.0
        }}
    }
    feature {
        key: "movie"
        value { bytes_list {
            value: "The Shawshank Redemption"
            value: "Fight Club"
        }}
    }
    feature {
        key: "movie_ratings"
        value { float_list {
            value: 9.0
            value: 9.7
        }}
    }
}

1.2 文件写入 TFRecord

样例程序:

import tensorflow as tf
import os
import cv2

# 输出的 TFRecord 文件带路径的名称
output = './output.tfrecords'
# 创建一个 writer 来写入 TFRecord 文件
writer = tf.python_io.TFRecordWriter(output)

# 辅助函数,生成整数和字符串型的 tf.train.Feature。
def _int64_feature(value):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# 读取图片,并保存到 TFRecord 文件中
img_dir = './img'
imgs = os.listdir(img_dir)
imgs.sort()

for index, img in enumerate(imgs):
    img_path = os.path.join(img_dir, img)
    img_data = cv2.imread(img_path)
    resized_img = cv2.resize(img_data, (128, 128, interpolation=cv2.INTER_AREA))
    # 这里必须把 ndarry 转换成字符串形式的原始二进制数据流
    img_raw = resized_img.tostring()

    # 生成 Example Protobuf 文件
    example = tf.train.Example(features=tf.train.Features(feature={
        'shape_': _int64_feature(resized_img.shape[0]),
        'label_': _int64_feature(index),
        'img_raw': _bytes_feature(img_raw)
    }))

    # 将序列化后的example 写入 TFRecord 文件
    writer.write(example.SerializeToString())
  
writer.close()

1.3 读取 TFRecord 文件

import tensorflow as tf

# 注意默认 shuffle = True
# 返回一个队列 Queue 对象
filename_queue = tf.train.string_input_producer(['./output.tfrecords'], shuffle=False)

# 创建一个 reader 读取 TFRecord 文件中的样例
reader = tf.TFRecordReader()
# 一次读取一个样例。也可以使用 read_up_to 函数一次读取多个样例
# Returns the next record (key, value) pair produced by a reader.
_, serialized_example = reader.read(filename_queue)

# 解析单个 features 文件; 解析多个用 parse_example 函数
features = tf.parse_single_example(
    serialized_example,
    features={
        'shape_': tf.FixedLenFeature([], tf.int64),
        'label_': tf.FixedLenFeature([], tf.int64),
        'img_raw': tf.FixedLenFeature([], tf.string)
    })
# 以上 tf.FixedLenFeature 方法解析的结果为一个 tensor。另一种方法是 tf.VarLenFeature, 解析的
# 结果是 SparseTensor

# 将字符串 tensor 解析成数据 tensor,注意此时为一维数据,需要 reshape
image = tf.decode_raw(features['img_raw'], tf.uint8)
image_shape = tf.stack([shape_, shape_, 3])  # 这一行不能少
image = tf.reshape(image, image_shape)
# 默认是 int64 的 tensor, 转成 int32
shape = tf.cast(features['shape_'], tf.int32)
label = tf.cast(features['label_'], tf.int32)

sess = tf.Session()
# 多线程部分参见第3部分
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

for i in range(4):
    print sess.run([image, shape, label])

2. tf.image 图像处理

2.1 图像读取,编码

读取示例:

tf.InteractiveSession()

image_raw_data = tf.gfile.FastGFile('./imgs/img1.png', 'rb').read()
# 一定要明确指定3通道,默认自动识别好像有bug,不生效
img_data = tf.image.decode_png(image_raw_data, channels=3)

# (732, 808, 3) uint8
# 格式: [H, W, C]
print img_data.eval().shape, img_data.eval().dtype

注意:

(1) tf.gfile 模块为 TensorFlow 的文件读写模块,C++ 实现,API 与 Python 自带文件模块高度相近。区别在于,tf.gfile 实现了多种文件系统的读写,比如本地文件,谷歌云存储文件(前缀 gs://),HDFS 文件(前缀 hdfs://) 等等。TensorFlow 中写入载入 checkpoints,TensorBoard 日志文件等都是用 tf.gfile 模块实现。

简言之,tf.gfile 实现了更多文件读写接口,在处理日常本地文件读写时,tf.gfile 并没有明显的速度优势,用 python 自带文件读写模块就可以了。

(2) tf.gfile.GFile, tf.gfile.FastGFile 二者在r1.8源码实现上并无区别,都是没有线程锁的文件 I/O,所以二者等同。

(3) 图像编码时的搭配:

  • 使用 PIL 的 Image.tobytes() 或 Numpy 的 np.array().tobytes() 得到的二进制文件,解码时应该使用tf.decode_raw()函数;
  • 使用 tf.gfile.GFile() 读取的图片文件,使用 tf.image 模块下的 tf.image.decode_png() 等解码。

如上面的例子,可以改成下面的写法:

image = cv2.imread('./imgs/img1.png')
print image.shape  # (732, 808, 3)
image_raw_data = tf.decode_raw(image.tobytes(), tf.uint8)
print image_raw_data.eval().shape  # (1774368,) // 1774368 = 732 * 808 * 3

注意 tf.decode_raw() 里面第二个形参 out_type 一定要正确,否则输出的数据的维度就不对。

(4) tf.image 中解码图像的 API:

  • tf.image.decode_png, tf.image.decode_jpeg, tf.image.decode_gif
  • tf.image.decode_image

上面的函数 tf.image.decode_png 等返回的 tensor 是有静态 shape的,而下面的 tf.image.decode_image 由于使用了 tf.cond 判断图片类型,因此返回的 tensor 没有静态 shape,造成它解码的图片无法与tf.image.resize_images() 一起使用;

现在 tf.image.decode_png,tf.image.decode_jpeg 已经能读取所有图片文件类型,和非动态 gif 文件了;

务必明确指定 tf.image.decode_png() 等函数中的 channels,想要 RGB 图务必指定为 3,默认的0在 1.8 版本并未生效。

--> 总结: 使用 tf.image.decode_png() 函数,不要使用 tf.image.decode_images() 函数,注意 channels 参数手动指定。返回的数据类型为 tf.uint8。

(5) 图像的存储 API:

  • tf.image.encode_png
  • tf.image.encode_jpeg
encoded_image = tf.image.encode_png(img_data)
with tf.gfile.FastGFile('test.png', 'wb') as f:
    f.write(encoded_image.eval())

函数第一个参数为图片数据,shape: [height, width, channels],数据类型必须为 uint8。同时, 两个函数中提供了图片品质压缩参数,encode_png 为 compression,0-9 之间取值,数值越大压缩越严重;encode_jpeg 为 quality,0-100之间取值,数值越大质量越好。

2.2 图像大小调整

说明:后续的图像操作,很多只接受浮点图像数据,有些先把图像转成浮点,处理完成后再转为原来的数据类型;如果有多个图像处理操作,来回在 uint8 和 float32 之间的转换会导致精度损失,因此建议在图像处理之前先统一转换成 float32 类型:

img_data = tf.image.convert_image_dtype(img_data, tf.float32)
  • tf.image.resize_images(images, size, method=ResizedMethod.BILINEAR, align_corners=False)

输入:

images: 4D 的 [N, H, W, C] 或 3D 的 [H, W, C] 数据。因此这个函数支持批处理。

size: [new_height, new_width]

method: 默认双线性插值。可选0-3。0:双线性插值;1:最近邻法;2:双三次插值;3:面积插值
	
align_corners: 角度是否对齐。`记得务必设置为 True`

说明: tf 中的图片 resize 和 opencv,PIL 等主流库的实现不一样。opencv 等主流库在插值计算时,对齐时是把每个像素看做一个“点区域”,因此用的是中心点对齐,有个 0.5 的偏移计算设置,可参考 CSDN博客;而 tf 在实现的时候,每个像素就是一个点,align_corners 设置为 False 就是上述博客中没加偏移的情况,显然不合理;align_corners 设置为 True 就是对齐四个角顶点,连续插值。个人觉得 tf 的设置 align_corners=True 更加合理。代码的区别可参加图: https://i.loli.net/2018/08/16/5b755dc364464.png。这一部分的相关讨论参见: https://github.com/tensorflow/tensorflow/issues/6720。

若要使用 opencv 的 resize 函数,需使用 tf.py_func 包装起来:

# img_data is a tensor
img = tf.py_func(lambda input: cv2.resize(input, (4, 4)), [img_data], tf.float32, stateful=False)
print img.eval()

2.3 图像裁剪

  • tf.image.resize_image_with_crop_or_pad(image, height, width)

输入: image: 4D 的 [N, H, W, C] 或 3D 的 [H, W, C] 数据。因此这个函数支持批处理。

生成一个大小为 [height, width] 的图,图标图像小于原图就裁剪,否则周围填0。注意只是裁剪或填充,没有插值。

  • tf.image.central_crop(image, central_fraction)

输入的 image 为 3D tensor。因此不能批处理。central_fraction为 (0, 1] 之间的浮点数。

做中心裁剪,central_fraction 为长和宽裁剪出的比例。比如 [100, 100] 的原图,central_fraction=0.5,那么输出[50, 50] 大小的图。

  • tf.image.crop_to_bounding_box(image, offset_h, offset_w, target_h, target_w)

输入:

image: 4D 的 [N, H, W, C] 或 3D 的 [H, W, C] 数据。因此这个函数支持批处理。

offset_h, offset_w: 裁剪区域的左上角坐标

target_h, target_w: 输出区域的大小
  • tf.image.pad_to_bounding_box(image, offset_h, offset_w, target_h, target_w)

输入:

image: 4D 的 [N, H, W, C] 或 3D 的 [H, W, C] 数据。因此这个函数支持批处理。

offset_h, offset_w: 原图上面要补充多少行0,原图左侧要补充多少列0

target_h, target_w: 输出区域的大小。多出的区域一律补0
  • tf.image.crop_and_resize(): crop 和 resize 一并做了,默认使用双线性插值,角落对齐了。略。
  • tf.image.decode_and_crop_jpeg(): 裁剪和解码一起做,但是效率更高,因为只解码要裁剪的区域。

2.4 图像翻转

  • tf.image.flip_up_down(image): 垂直翻转

    tf.image.random_flip_up_down(image): 随机垂直翻转

  • tf.image.flip_left_right(image): 左右翻转

    tf.image.random_flip_left_right(image): 随机左右翻转

  • tf.image.transpose_image(image): 沿对角线翻转。其实就是矩阵转置。

  • tf.image.rot90(image, k=1): 沿逆时针旋转 k 个 90 度。

以上函数均支持单个图像处理和批量处理。

2.5 图像色彩调整

  • tf.image.adjust_brightness(image, delta): 调整图像亮度。delta: [-1, 1] 的实数。原理就是先将图像转成 float,然后整个图像加上 delta,再转换回原图的 dtype。因此建议处理之前直接转成 float。支持批处理。
  • tf.image.adjust_contrast(image, contrast_factor): 调整对比度。contrast_factor: 对比度乘性系数,提高对比度 contrast_factor 倍。原理: (x - mean) * contrast_factor + mean。支持批处理。
  • tf.image.adjust_hue(image, delta): 调整色相。delta: [-1, 1]
  • tf.image.adjust_saturation(image, saturation_factor): 调整饱和度。原理: 先将 RGB 图像转换为 HSV 图,然后 S 通道乘以 saturation_factor,最后做有效裁剪后转换回 RGB。
  • tf.image.adjust_gamma(image, gamma=1, gain=1): 调整图像 gamma 值。
  • tf.image.per_image_standardization(image): 减均值除方差。

说明:

(1) 以上前四个均有 random 函数,如: tf.image.random_brightness, tf.image.random_contrast, 具体参见 API。

(2) 在做一连串的调整操作后,图像的数值分布可能已经越界了。因此在最后一步操作后记得做有效截断:

result = tf.clip_by_value(adjusted_image, 0, 1)

2.6 处理标注框

  • tf.image.draw_bounding_boxes(image, boxes): 在图上画标注框。

输入:

image: 4D 的 [N, H, W, C] 图像。注意数据类型必须为 float。

boxes: [batch, box_num, 4]。bounding box 的坐标,四个点的坐标格式是 [y_min, x_min, y_max, x_max],这四个参数都是 [0, 1] 的数,表示比例。

输出: 带框的图像。

# img: [h, w, c]
batch_img = tf.expand_dims(img, 0) # [1, 400, 800, 3]
result = tf.image.draw_bounding_boxes(batch_img, boxes=[[[0.1, 0.3, 0.5, 0.7]]])
plt.imshow(result.eval()[0])
plt.show()

原图上框的位置: 左上角[240, 40],右下角[560, 200]

注意,框越界并不会报错。

  • tf.image.sample_distorted_bounding_box(image_size, bounding_boxes, min_object_covered=0.1, aspect_ratio_range, max_attempts=None, use_image_if_no_bounding_boxes=None): 根据 Ground_truth 标记框的位置,根据一些约束随机生成图像。用于数据扩充。

输入:

image_size: [h, w, c] 原图的 shape。

bounding_boxes:  [batch, box_num, 4], ground_truth 的位置信息。坐标格式仍然是 [y_min, x_min, y_max, x_max],四个参数均为 [0, 1] 之间的浮点数,表示比例。

min_object_covered: 默认 0.1。提取的区域至少包含某个 gt 标注框的比例的百分比。

aspect_ratio_range: list, 默认 [0.75, 1.33]。提取区域的宽高比的范围 (ratio = width / height)。

max_attempts: 默认100。最大尝试次数。

use_image_if_no_bounding_boxes: 默认 False。不提供标记框时是否返回原图。

输出:元组 (begin, size, bboxes):

begin: [offset_height, offset_width, 0]。输出区域起点,即左上角的坐标。

size: [target_height, target_width, -1]。输出区域的大小。

bboxes: 输出框的坐标,shape: [1, 1, 4]。

其中输出的 begin 和 size 可作为 tf.slice 的输入,bboxes 可作为 tf.image.draw_bounding_boxes 的输入。

# img: [h, w, c]
begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(tf.shape(img),[[[0.1, 0.3, 0.5, 0.7]]], min_object_covered=0.4)
# distorted 为随机提取出来的图像区域
distorted = tf.slice(resized, begin, size)

**小坑注意:**上面的例子中,由于 tf.slice 收到的 size 中最后一个维度的数是 -1, 意思是多个通道全要提取,具体有几个通道是动态的(要根据输入来定),因此导致输出的 distorted 最后一个维度也是动态的。

print distorted.shape  # (?, ?, ?)
result = tf.image.resize_images(distorted, [200, 200])
print result.shape  # (200, 200, ?)

由于后面对 distorted 的操作一般不会涉及到图像通道数,为了图像的维度的 shape 能正常获取,最好在 tf.slice 后手动设定一下 shape:

# Restore the shape since the dynamic slice based upon the bbox_size loses
# the third dimension.
distorted.set_shape([None, None, 3])
print distorted.shape  # (?, ?, 3)
result = tf.image.resize_images(distorted, [200, 200])
print result.shape  # (200, 200, 3)

2.7 颜色空间转换

  • tf.image.rgb_to_grayscale
  • tf.image.grayscale_to_rgb
  • tf.image.hsv_to_rgb
  • tf.image.rgb_to_hsv

注意,以上操作均需要图片的数据类型为 tf.float32

2.8 Inception 图像预处理

这部分的官方代码在 slim/preprocessing/inception_preprocessing.py

需要注意的两点:

(1) distorted_bounding_box_crop 函数的 min_object_covered 默认取 0.1,根据具体数据集调整这个参数比较合适。

(2) resize 都没有 align_corners.

import tensorflow as tf

from tensorflow.python.ops import control_flow_ops

def apply_with_random_selector(x, func, num_cases):
    """
    从 func(x, 0), func(x, 1), ..., func(0, num_cases-1) 中随机选一个
    """
    sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
    # merge(inputs): 依次判断 inputs 中的元素是否存在, 返回第一个存在的元素的[data, data_index]
    # switch(data, pred): 返回 (output_false, output_true), 如果 pred 为 true, output_true = data, output_false
    # 不存在; 反过来也一样
    return control_flow_ops.merge([
        func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
        for case in range(num_cases)])[0]


def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
    """
    对图片进行随机色彩变换:调整亮度、饱和度、色相、对比度。
    
    Args:
        image (float): [0, 1] 之间的 3D tensor 图像
        color_ordering (int, optional): Defaults to 0. 可取 0-3,代表不同的随机变换模式。
        fast_mode (bool, optional): Defaults to True. 快速模式下不采用调整色相和对比度的变换。
        scope (optional): Defaults to None.
    
    Returns:
        颜色变换处理后的 float32 图像。
    """
    with tf.name_scope(scope, 'distort_color', [image]):
        if fast_mode:
            if color_ordering == 0:
                image = tf.image.random_brightness(image, max_delta=32. / 255.)
                image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
            else:
                image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
                image = tf.image.random_brightness(image, max_delta=32. / 255.)
        else:
            if color_ordering == 0:
                image = tf.image.random_brightness(image, max_delta=32. / 255.)
                image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
                image = tf.image.random_hue(image, max_delta=0.2)
                image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
            elif color_ordering == 1:
                image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
                image = tf.image.random_brightness(image, max_delta=32. / 255.)
                image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
                image = tf.image.random_hue(image, max_delta=0.2)
            elif color_ordering == 2:
                image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
                image = tf.image.random_hue(image, max_delta=0.2)
                image = tf.image.random_brightness(image, max_delta=32. / 255.)
                image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
            elif color_ordering == 3:
                image = tf.image.random_hue(image, max_delta=0.2)
                image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
                image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
                image = tf.image.random_brightness(image, max_delta=32. / 255.)
            else:
                raise ValueError('color_ordering must be in [0, 3]')

        # 有效截断
        return tf.clip_by_value(image, 0.0, 1.0)

# min_object_covered 默认为 0.1,可以根据实际任务稍微改大一点
def distorted_bounding_box_crop(image,
                                bbox,
                                min_object_covered=0.1,
                                aspect_ratio_range=(0.75, 1.33),
                                area_range=(0.05, 1.0),
                                max_attempts=100,
                                scope=None):
    with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
        sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
            tf.shape(image),
            bounding_boxes=bbox,
            min_object_covered=min_object_covered,
            aspect_ratio_range=aspect_ratio_range,
            area_range=area_range,
            max_attempts=max_attempts,
            use_image_if_no_bounding_boxes=True)
        bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box

        cropped_image = tf.slice(image, bbox_begin, bbox_size)
        return cropped_image, distort_bbox

def preprocess_for_train(image, height, width, bbox,
                         fast_mode=True,
                         scope=None,
                         add_image_summaries=True):
    """
    训练集数据预处理。
    
    Args:
        image: 输入图像, uint8 或者 float32(都会被转换为 float32). shape: [H, W, C]
        bbox: shape: [1, num_boxes, coords]. coords: [ymin, xmin, ymax, xmax]. 为 None 时取原图整图.
        fast_mode: resize 插值和颜色变换是否采用快速模式。
        add_image_summaries: 是否画出画出中间处理过程得到的图。
    
    Returns:
        3D float 图像。范围在 [-1, 1] 之间。
    """
    with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
        # bbox 为空取原图
        if bbox is None:
            bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
                                dtype=tf.float32,
                                shape=[1, 1, 4])
        if image.dtype != tf.float32:
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
                                                    bbox)
        if add_image_summaries:
        tf.summary.image('image_with_bounding_boxes', image_with_box)

        distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
        # Restore the shape since the dynamic slice based upon the bbox_size loses
        # the third dimension.
        distorted_image.set_shape([None, None, 3])
        image_with_distorted_box = tf.image.draw_bounding_boxes(
            tf.expand_dims(image, 0), distorted_bbox)
        if add_image_summaries:
        tf.summary.image('images_with_distorted_bounding_box',
                        image_with_distorted_box)

        # 快速模式下: resize 采取双线性插值,否则是选择随机方法插值。
        num_resize_cases = 1 if fast_mode else 4
        distorted_image = apply_with_random_selector(
            distorted_image,
            lambda x, method: tf.image.resize_images(x, [height, width], method),
            num_cases=num_resize_cases)

        if add_image_summaries:
        tf.summary.image('cropped_resized_image',
                        tf.expand_dims(distorted_image, 0))

        # 随机左右翻转
        distorted_image = tf.image.random_flip_left_right(distorted_image)

        # 调用 distort_color 随机做颜色变换,选择是否采用快速模式。
        num_distort_cases = 1 if fast_mode else 4
        distorted_image = apply_with_random_selector(
            distorted_image,
            lambda x, ordering: distort_color(x, ordering, fast_mode),
            num_cases=num_distort_cases)

        if add_image_summaries:
        tf.summary.image('final_distorted_image',
                        tf.expand_dims(distorted_image, 0))
        # [0, 1] 之间的图,变为 [-1, 1] 之间。
        distorted_image = tf.subtract(distorted_image, 0.5)
        distorted_image = tf.multiply(distorted_image, 2.0)
        return distorted_image


def preprocess_for_eval(image, height, width,
                        central_fraction=0.875, scope=None):
    """
    验证集数据预处理。做中心裁剪,双线性插值 resize,数值范围变换到 [-1, 1] 之间。
    
    Returns:
        3D float 图像。范围在 [-1, 1] 之间。
    """

    with tf.name_scope(scope, 'eval_image', [image, height, width]):
        if image.dtype != tf.float32:
            image = tf.image.convert_image_dtype(image, dtype=tf.float32)
        # 取出中间 0.875 的区域
        if central_fraction:
            image = tf.image.central_crop(image, central_fraction=central_fraction)

        if height and width:
            # resize 图像,采用双线性插值
            image = tf.expand_dims(image, 0)
            image = tf.image.resize_bilinear(image, [height, width],
                                            align_corners=False)
            image = tf.squeeze(image, [0])
            # 取值范围变为 [-1, 1] 之间
            image = tf.subtract(image, 0.5)
            image = tf.multiply(image, 2.0)
        return image


def preprocess_image(image, height, width,
                     is_training=False,
                     bbox=None,
                     fast_mode=True,
                     add_image_summaries=True):
    """
    训练集和验证集图像预处理。
    
    Args:
        image: 输入图像, uint8 或者 float32(都会被转换为 float32). shape: [H, W, C]
        is_training: 训练集还是测试集
        bbox: 标记框,默认 None 表示取原图整图
        fast_mode: resize 和颜色变换是否采用快速模式
        add_image_summaries: 中间处理后的图是否画图。
    
    Returns:
        3D float 图像。范围在 [-1, 1] 之间。
    """

    if is_training:
        return preprocess_for_train(image, height, width, bbox, fast_mode,
                                    add_image_summaries=add_image_summaries)
    else:
        return preprocess_for_eval(image, height, width)

其他的如 VGG 预处理,可参见: vgg_preprocessing.py,VGG 裁剪图片时是确定短边长度后再等比例 resize.

3. 队列与多线程(旧 API)

3.1 队列 Queue

TensorFlow 中提供的队列有:

  • tf.FIFOQuese: 先进先出队列
  • tf.RandomShuffleQueue: 随机顺序出列的队列。注意:测试发现,在满足队列容量 > min_after_dequeue 条件下,每 dequeue 一次,整个 Queue 就要 Shuffle 一次。
  • tf.PaddingFIFOQueue: 以固定长度批量出列的队列
  • tf.PriorityQueue: 带优先级出列的队列

这些类型的 Queue 的 API 大致差不多,以 FIFOQueue 为例:

tf.FIFOQueue(capacity, dtypes, shapes=None): capacity 为队列容量,dtypes 为队列中元素的数据类型,shapes 为队列中元素的 shape。
-> 常用函数:
- dequeue(): 出队一个元素
- dequeue_many(n): 出队 n 个元素。使用此函数必须手动指定 shapes。
- enqueue(): 入队一个元素
- enqueue_many(val): 入队多个元素. 注意这里 val 要比基元多一维。
- size(): 返回队列中元素多少。返回数据类型为 tensor。

注意: 队列中存储了 capacity 个基元,以 list 形式存在。当基元中包含多个数据时,dtypes 是个 list,长度与基元长度要相同。若要使用 dequeue_many(n),shape 必须手动指定。

q = tf.FIFOQueue(5, tf.int32, shapes=[()])
op1 = q.enqueue_many([[1, 2]])
op1.run()
op2 = q.enqueue([3])
op2.run()
print q.size().eval()  # 3
print q.dequeue().eval()  # 1
print q.dequeue_many(2).eval()  # [2, 3]

TODO(20180817): 当基元有多个元素时,enqueue_many 表现很奇怪,尚未弄懂。因此暂时避免使用 enqueue_many 和 dequeue_many。

3.2 tf.train.Coordinator 与 tf.train.QueueRunner

  • tf.train.Coordinator: tf 中用来协调线程运行的工具。主要用于协调线程停止。
  • tf.train.QueueRunner: tf 中对操作队列多线程的封装,用于创造线程。

3.2.1 tf.train.Coordinator

该类的常用方法:

  • should_stop(): 查询线程是否要终止。返回布尔值。
  • request_stop(): 请求终止线程。每一个线程都可以调用 request_stop() 来请求终止其他线程,这样其他线程在检查 should_stop() 时得到 True,因此就会终止线程。
  • clear_stop(): 清除线程终止信号。
  • join(threads=None, stop_grace_period_secs=120): 阻塞线程。发出线程终止请求后,其他线程必须在 stop_grace_period_secs 时间内完成终止,否则抛出异常。
  • stop_on_exception(): 上下文管理。当发生异常是,请求终止线程。

使用方法:

try:
    coord = tf.train.Coordinator()
    ... codes of creating threads ...
    coord.join(threads)
except Exception as e:
    ... some codes ...

上面创建线程的代码为:

try:
    while not coord.should_stop():
        ... some work ...
except Exception as e:
    coord.request_stop(e)
coord.join(threads)

可以用 stop_on_exception () 简化上面创建线程的代码:

with coord.stop_on_exception():
    while not coord.should_stop():
        ... some work ...
coord.join(threads)

3.2.2 tf.train.QueueRunner

tf 中的多线程使用的队列启动方案。与 tf.train.Coordinator 一起使用。

比如一个经典的文件输入流程: 第一批线程通过往第一个队列里面不断填充要处理的文件名;第二批线程从前面的队列取出文件名,然后进行读取处理等操作,得到的张量放在第二个队列;第三批线程从第二个队列中取出张量,组成 batch,输入网络进行训练。

初始化方法:

__init__(queue=None, enqueue_ops=None)

其中 queue 为要操作的队列,enqueue_ops 为要对该队列执行的多线程操作。

tf.train.QueueRunner 常与以下两个类一起使用:

  • tf.train.add_queue_runner(qr, collection=tf.GraphKeys.QUEUE_RUNNERS): 将一个 QueueRunner 添加到图的 collection 中。默认放在 tf.GraphKeys.QUEUE_RUNNERS 这个 collection 中。
  • tf.train.start_queue_runners(sess=None, coord=None, daemon=True, start=True, collection=tf.GraphKeys.QUEUE_RUNNERS ): 启动图的 collection 中的所有 QueueRunner。默认的 collection 为 tf.GraphKeys.QUEUE_RUNNERS。daemon: 是否为守护进程; start: 是否启动线程,不启动就只是创建线程。

一个完整的例子:

# 创建队列,入队操作
queue = tf.FIFOQueue(100, tf.float32)
enqueue_op = queue.enqueue([tf.random_normal([])])

# 开启 5 个线程
qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)
tf.train.add_queue_runner(qr)

out_tensor = queue.dequeue()

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    enqueue_threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    for i in range(5):
        print sess.run(out_tensor)
    
    # 终止所有线程
    coord.request_stop()
    coord.join(enqueue_threads)

当然,上面也可以不使用 tf.train.add_queue_runner 和 tf.train.start_queue_runner 来启动线程。直接用 QueueRunner 的 create_threads(sess, coord=None, daemon=False, start=False) 来启动多个线程:

queue = tf.FIFOQueue(100, tf.float32)
enqueue_op = queue.enqueue([tf.random_normal([])])

# 开启 5 个线程
qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)

out_tensor = queue.dequeue()

with tf.Session() as sess:
    coord = tf.train.Coordinator()
    # 手动启动 qr 负责的入队线程
    enqueue_threads = qr.create_threads(sess=sess, coord=coord, start=True)

    for i in range(5):
        print sess.run(out_tensor)
    
    coord.request_stop()
    coord.join(enqueue_threads)

3.2.3 输入文件队列

  • tf.train.match_filenames_once(pattern): 根据正则表达式获取符合要求的文件名列表。返回一个局部变量,为文件名列表。因此使用这和函数务必初始化局部变量

  • tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True, capacity=32): 根据文件名返回一个文件名队列,供多线程使用。string_tensor: 装有文件名的 tensor,可由 match_filenames_onces 函数返回,也可用 Python glob 生成等;num_epochs: 加载文件列表的最大轮数,默认无限循环;shuffle: 文件名加入队列前是否打乱;capacity: 队列长度。

    注意: 从源码中发现,该函数创建了局部变量(对 epoch 进行计数),因此使用时要初始化局部变量(其实不初始化也可以,就不使用 epoch 这个变量)。该函数返回一个 FIFOQueue 用于存放文件名,并生成一个 QueueRunner 进行单线程入队操作,该 QueueRunner 放在默认的 tf.GraphKeys.QUEUE_RUNNERS 这个 collection 中。

    在指定 num_epoches (如测试时指定为 1 ) 时,队列为空后继续出队,抛出 OutOfRange 异常。

sess = tf.InteractiveSession()
# 获取所有 png 图片文件列表
files = tf.train.match_filenames_once('some_path/*.png')

filename_queue = tf.train.string_input_producer(files, shuffle=False)
coord = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess, coord)
tf.local_variables_initializer().run()

print filename_queue.dequeue().eval()  # 获取一个文件名

3.2.4 Readers

前面 tf.train.string_input_producer 生成了文件名队列,tf 通过各种 Reader 从这个文件名队列中取文件名,进行文件读取解析。常用的 Reader 有:

  • tf.TFRecordReader: 读取 TFRecord 文件
  • tf.WholeFileReader: 读取一个文件的全部
  • tf.TextLineReader: 读取文本文件
  • tf.FixedLengthRecordReader: 读取固定长度的文件

它们的 API 大致相同。常用的方法有:

  • read(queue): 输入为一个 Queue,以 (key, value) 的形式输出单个文件。读取后 Queue 出队一个。
  • read_up_to(queue, num_records): 以 (keys, values) 的形式输出多个文件。

用 Reader 实现 2.1 节的图片读取:

files = tf.train.match_filenames_once('some_path/*.png')
filename_queue = tf.train.string_input_producer(files, shuffle=False)
key, value = tf.WholeFileReader().read(filename_queue)
image = tf.image.decode_png(value, channels=3)

with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    print sess.run(image)

单线程,效率较低。

3.2.5 组合训练数据 (batching)

一个高效的数据 pipeline 应该是一个生产者——消费者模型,即生产者是一个文件名队列,里面存储要处理的文件的名字,这个队列用单线程处理即可;消费者为多线程从生产者队列里面取出文件名,进行文件读取和预处理,然后放置在另一个队列,最终的数据从这个队列中取出。

一个完整的流程如图:

TensorFlow 中提供的对应的 batch 数据的函数为:

  • tf.train.batch(tensors, batch_size, num_threads=1, capacity=32, enqueue_many=False, dynamic_pad=False, allow_smaller_final_batch=False): 组合输入的 tensors, 形成一个 batch,放置在新建的一个 Queue 中,并同时新建一个 QueueRunner (添加到默认的 QueueRunner collection)。allow_smaller_final_batch 为 True 时,最后的不够一个 batch 的数据也会留下。返回一个 batch 的数据。
  • tf.train.shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, dynamic_pad=False, allow_smaller_final_batch=False): 产生打乱顺序的 batch,。与 tf.train.batch 的唯一区别是多了 min_after_dequeue 参数,限制了出队操作后队列中最少元素的数量。因为当队列中元素太少时,打乱的意义就不大了。创建的 Queue 类型是 RandomShuffleQueue。

注意: 以上 dynamic_pad 为 False 时,传入的 tensors 必须显式确定,否则抛出异常。

  • tf.train.batch_join(tensors_list, batch_size, capacity=32, allow_smaller_final_batch=False)
  • tf.train.shuffle_batch_join(tensors_list, batch_size, capacity, min_after_dequeue, allow_smaller_final_batch=False)

一般我们的文件不止一个,比如 TensorFlow Performance Guide 建议,把大数据文件分割成多个约为 100 MB 的 TFRecord 文件,I/O 性能比较好。这种情况下,多文件,多线程进行读取和预处理操作应该用上面两个函数。

其中,输入的 tensors_list 为 a list of tuples of tensors,创建 len(tensors_list) 个线程,每个线程读取一个文件,然后压入队列:

# features 为解析的 TFRecord 文件
image, label = features['image'], features['label']

# 1. 使用 tf.train.batch:
# 返回的 iamge 是个 [N, H, W, C] 的 tensor
# 记住 image 和 label 要一起 run,不然就错位交叉了
image, label = tf.train.batch([image, label], batch_size=10, num_threads=1, capacity=100)

# 2. 使用 tf.train.batch_join:
image, label = tf.train.batch_join([[image, label] for _ in range(4)], batch_size=10, capacity=100)

比较: tf.train.batch 是多线程读取一个文件,tf.train.batch_join 是多线程读取多个文件,每个线程负责一个文件。如果同一个文件中样本相似,用 tf.train.batch 显然不合适;使用 tf.train.batch_join() 时,如果线程数大于文件数,那么也存在多个线程读取同一个文件的情况,而且多线程读多个文件的硬盘寻址也是有时间开销的,可能会让效率变低。

3.2.6 Inception 数据输入框架 (旧 API)

# coding: utf-8

import tensorflow as tf

files = tf.train.match_filenames_ones('some_path/data.tfrecords-*')
# 训练集,文件名打乱
filename_queue = tf.train.string_input_producer(files, shuffle=True, )

_, serialized_example = tf.TFRecordReader().read(filename_queue)
features = tf.parse_single_example(serialized_example, features={
    'image': tf.FixedLenFeature([], tf.string),
    'label': tf.FixedLenFeature([], tf.int64),
    'height': tf.FixedLenFeature([], tf.int64),
    'width': tf.FixedLenFeature([], tf.int64),
    'channels': tf.FixedLenFeature([], tf.int64)
})

image, label, height, width, channels = features['image'], features['label'], features['height'], features['width'], features['channels']

decoded_image = tf.decode_raw(image, tf.uint8)
decoded_image = tf.reshape(decoded_image, tf.stack([height, width, channels]))

image_size = 299

# 前面的 inception 预处理代码
distorted_image = preprocessing_for_train(decoded_image, image_size, image_size, None)

# 这一部分可以具体调整
num_threads = 10
batch_size = 50
min_after_dequeue = 1000
# as suggested by https://www.tensorflow.org/api_guides/python/reading_data#Batching
capacity = min_after_dequeue + (num_threads + 10) * batch_size

image_batch, label_batch = tf.train.shuffle_batch([distorted_image, label], batch_size=batch_size,
    capacity=capacity, min_after_dequeue=min_after_dequeue, num_threads=num_threads)
# tf.train.shuffle_batch_join 的方案
# image_batch, label_batch = tf.train.shuffle_batch([[distorted_image, label] for i in range(num_threads)], batch_size=batch_size,
#     capacity=capacity, min_after_dequeue=min_after_dequeue)

logit = inference(image_batch)
loss = calc_loss(logit, label_batch)
train_step = ...

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer(), tf.local_variables_initializer())
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
	
    with coord.stop_on_exception():
        while not coord.should_stop():
            for i in range(training_epoches):
                sess.run(train_step)
 
    coord.join(threads)

4. tf.data 模块 (新 API)

tf 1.4 后的数据输入框架,抛弃队列处理的旧 API,使用 Dataset 数据集提供数据的输入。这一部分比较简单,官方文档很详细。

4.1 API 概览

(1) 利用数据集的基本步骤:

  • 根据数据类型创建对应的 Dataset
  • 定义迭代器 Iterator,并进行相应的初始化
  • 预处理、shuffle、batch
  • 使用 get_next() 从迭代器中取出数据张量

(2) 常用的 Dataset:

  • tf.data.Dataset(): 所有 Dataset 的基类
  • tf.data.TextLineDataset(filenames, compression_type=None, buffer_size=None): 从一个或多个文本文件中读取内容
  • tf.data.TFRecordDataset(filenames, compression_type=None, buffer_size=None, num_parallel_reads=None): 从一个或多个 TFRecord 文件中读取内容
  • tf.data.FixedLengthRecordDataset(filenames, record_bytes, header_bytes=None, footer_bytes=None, buffer_size=None): 从一个或多个二进制文件中读取固定长度内容。

(3) Dataset 常用的属性:

  • output_shapes: Dataset 中每个元素的 shape
  • output_types: Dataset 中每个元素的 type

(4) Dataset 常用的方法:

  • batch(batch_size): 按 batch_size 取一个 batch。最后不够一个 batch_size 也会被取出来,如果不要最后的零头,使用 tf.contrib.data.batch_and_drop_remainder 方法。

  • padded_batch(batch_size, padded_shapes, padding_values=None): 生成 padded batch。

  • shuffle(buffer_size, seed=None, reshuffle_each_iteration=None): 打乱数据集。buffer_size: 缓冲区大小。默认 reshuffle_each_iteration 为 True。因此,每迭代一个元素出来,缓冲区中填充一个新元素,shuffle 一次。buffer_size 越大,打乱性能越好,但是第一次启动的时间就较长。可设置 buffer_size 为数据集大小,这样打乱充分。参考链接

  • repeat(count=None): 把 Dataset 重复 count 次。None 或 -1 表示无限循环。

  • skip(count): 跳过前 count 个元素。

  • take(count): 读取前 count 个元素。

  • shard(num_shards, index): 生成一个只包含原 Dataset 的 1/num_shards 的新 Dataset。

  • from_tensors(tensors): 根据单个元素构建 Dataset。

  • from_tensor_slices(tensors): 根据元素切片构建 Dataset。

  • from_generator(generator, output_types, out_shapes=None): 根据生成器构建 Dataset。这一部分使用 tf.py_func 实现的。

  • map(map_func, num_parallel_calls=None): map 运算,返回一个 map 运算后的 Dataset。

  • flat_map(map_func): map_func 后再 flat 展平,返回一个 Dataset。

  • filter(predicate): filter 运算,返回一个 Dataset。

  • interleave(map_func, cycle_length, block_length=1): 适用于分布式文件系统。当有多个文件时,一次对 cycle_length 个文件同时读取,block_length 是每个线程输出元素的个数。因此,map 和 flat_map 相当于 tf.train.batch,interleave 相当于 tf.train.batch_join。

  • apply(transformation_func): 对一个 Dataset 进行某个操作,类似于 map。

  • zip(datasets): 和 Python 中的 zip 函数一样,把多个 Dataset zip起来。

  • concatenate(dataset): 串接一个 Dataset, 返回一个新的合并的 Dataset。

  • prefetch(buffer_size): 预加载 buffer_size 的数据。

  • range(*args): 生成一个 RangeDataset。

  • list_files(file_pattern, shuffle=None): 根据文件名 file_pattern 正则表达式,获取一个包含这些文件的 Dataset。注意,这个顺序是不定的,即使 shuffle 为 False。看源码发现这个函数其实是先用 tf.matching_files 得到文件列表,在用 from_tensor_slices 得到一个 dataset,最后 shuffle。注意,最后一个 shuffle 默认的缓冲区为整个文件名列表。因此,由于 from_tensor_slices 和 shuffle 缓冲区长度的设定,当文件名列表过于巨大时,这一步耗时就会很大,可参考 Github

(5)迭代器:

  • one-shot iterator: Dataset.make_one_shot_iterator()。绑定一个 Dataset 的单次迭代器,只对 Dataset 进行一次迭代。只有这个不需要显式初始化。所有参数都已经确定,中间不能修改了。
  • initializable iterator: Dataset.make_initializable_iterator(): 绑定一个 Dataset 的可初始化迭代器。需要手动初始化迭代器。中间可以修改迭代器参数,可以配合 placeholder 使用。
  • reintializable iterator: tf.data.Iterator.from_structure(output_types, output_shapes=None, shard_names=None, output_classes=None): 这个迭代器没有绑定 Dataset,因此是用 Iterator 这个类创建的。中间可以重复初始化。把同一个迭代器应用到不同的数据集从而实现切换数据集的功能。需要手动初始化,需要使用 iterator.make_initializer(dataset) 来针对给定的 Dataset 初始化。
  • feedable iterator: tf.data.Iterator.from_string_handle(string_handle, output_types, output_shaps=None, output_classes=None): 这个迭代器没有绑定 Dataset。目的是切换到不同的 iterator 。string_handle 是代表一个 Iterator 的 Tensor,可先用一个 placeholder 占据,选择数据集的时候 feed 为具体 Dataset 的 Iterator (用这个 Dataset 的 Iterator.string_handle() 得到)。

4.2 案例分析

4.2.1 性能优化

以一个简单的文本处理为例,假设为分布式文件系统,有 5 个 txt 文件,每个文件里面存放某一类样本的文件名和类别标号,形如:

- file1.txt:
	cat1.jpg 0
 	cat2.jpg 0
 	...
 	cat4.jpg 0
- file2.txt
	dog1.jpg 1
	dog2.jpg 1
	...
	dog4.jpg 1
...
- file_n.txt

(1) 基本流程,无各种优化考虑:

# 获取文件列表,并按照 file_{n} 中的数字 n 从小到大排序
txt_files = glob.glob('some_path/*.txt')
txt_files.sort(key=lambda x: int(x.split('.')[-2][-1]))

# 使用 TextLineDataset 读取文本文件
dataset = tf.data.TextLineDataset(txt_files)
dataset = dataset.shuffle(buffer_size=20)
dataset = dataset.repeat(2)
# 分割字符串,按空格分割
dataset = dataset.map(lambda x: tf.string_split([x], delimiter=' ').values)
dataset = dataset.batch(2)

# 创建迭代器
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

输出的结果为:

[['bird1.jpg' '3']
 ['dog4.jpg' '1']]
[['dog2.jpg' '1']
 ['bird3.jpg' '3']]
 ...

要注意的几点:

-- 参考 4.1 中设置 shuffle 的 buffer_size

-- shuffle 和 repeat 的顺序,建议先 shuffle 再 repeat,如果反过来会造成打乱效果变差。参考: Dataset Performance Guide

(2) 基础性能优化: 多线程读取多文件,多线程 map,预加载 prefetch

上面代码的性能缺点:单线程读取,单线程 map,加上 Dataset 默认的 lazy 属性,性能地下。

基础改进代码:

txt_files = glob.glob('some_path/*.txt')
txt_files.sort(key=lambda x: int(x.split('.')[-2][-1]))

# 改进1:多线程读取文件,创建5个线程,每个线程负责一个文件读取。
# 因此改进后可以同时读 5 个文件
dataset = tf.data.Dataset.from_tensor_slices(txt_files)
dataset = dataset.interleave(lambda x: tf.data.TextLineDataset(x),
    cycle_length=5, block_length=1)

dataset = dataset.shuffle(buffer_size=20)
dataset = dataset.repeat(2)
# 改进2: 多线程 map 函数,也就是可以多线程预处理了。
dataset = dataset.map(lambda x: tf.string_split([x], delimiter=' ').values,
    num_parallel_calls=4)
dataset = dataset.batch(2)

# 改进3:预加载,输出有缓冲区。类似于消费者队列中填充一定 batch 数,等待消费。
# 这里是基于前一步操作后的元素,因此是预加载 5 个 batch。
dataset = dataset.prefetch(5)
...

TODO: 上面的 interleave 同时读取 5 个文件,但是真的是并行吗?按照 tf.data API slides 应该是并行 I/O,但是按照 Dataset Performance Guide 却建议使用 tf.contrib.data.parallel_interleave() 实现真正的并行 I/O,这里有待进一步明确,是否需要改成下面的版本:

dataset = dataset.apply(tf.contrib.data.parallel_interleave(tf.data.TextLineDataset, cycle_length=4))

(3) 进一步优化:

并行化 batch: 当 batch_size 比较大时,取 batch 也是耗时的,因此可以把 map 和 batch 放在一起做多线程,最终改成这样的版本(使用 tf.contrib.data.map_and_batch 函数):

txt_files = glob.glob('some_path/*.txt')
txt_files.sort(key=lambda x: int(x.split('.')[-2][-1]))

# 确保运行在 CPU 上
with tf.device('/cpu:0'):
    dataset = tf.data.Dataset.from_tensor_slices(txt_files)
    dataset = dataset.apply(tf.contrib.data.parallel_interleave(tf.data.TextLineDataset, cycle_length=4))

    dataset = dataset.shuffle(buffer_size=20)
    dataset = dataset.repeat(2)
    # 改动这里
    dataset = dataset.apply(tf.contrib.data.map_and_batch(lambda x: tf.string_split([x], delimiter=' ').values, batch_size=2, num_parallel_batches=4))
    dataset = dataset.batch(2)

    dataset = dataset.prefetch(5)

其他的优化还有内存 cache 等的考虑等,参考 Dataset Performance Guide

注意:一般而言, batch 耗时相对较少,当电脑核心不太够的时候,并行化的 batch 占据线程也不一定是好事,可能速度也会变慢。

4.2.2 各种迭代器的使用及比较

(1) initializable iterator: 动态指定 iterator 参数。配合 placeholder 使用。

# 指定文件名的 placeholder
txt_files = tf.placeholder(tf.string, [])

dataset = tf.data.TextLineDataset([txt_files])
dataset = dataset.map(lambda x: tf.string_split([x], delimiter=' ').values)

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

iterator.initializer.run(feed_dict={txt_files: './txt_files/file1.txt'})
for i in range(5):
    print next_element.eval()

# 切换到另一个 txt 文件
iterator.initializer.run(feed_dict={txt_files: './txt_files/file2.txt'})
for i in range(2):
    print next_element.eval()

# 中间重新初始化,迭代器从头开始
iterator.initializer.run(feed_dict={txt_files: './txt_files/file2.txt'})
for i in range(5):
    print next_element.eval()
    

(2) reinitializable iterator: 一个可重复初始化的迭代器,绑定到不同的数据集上去。

# 两个数据集
cat_dataset = tf.data.TextLineDataset(['./txt_files/file1.txt']).map(lambda x: tf.string_split([x], delimiter=' ').values)
dog_dataset = tf.data.TextLineDataset(['./txt_files/file2.txt']).map(lambda x: tf.string_split([x], delimiter=' ').values)

# 同一个 reinitializable iterator
iterator = tf.data.Iterator.from_structure(cat_dataset.output_types, dog_dataset._output_shapes)
next_element = iterator.get_next()

# 迭代器初始化
cat_init_op = iterator.make_initializer(cat_dataset)
dog_init_op = iterator.make_initializer(dog_dataset)

# 迭代器绑定到 cat dataset
cat_init_op.run()
for _ in range(5):
    print next_element.eval()

# 迭代器绑定到 dog dataset
dog_init_op.run()
for _ in range(5):
    print next_element.eval()

通常可以用同一个迭代器绑定到 training set 和 test_set,不过,用前面的 initializer_iterator 也可实现相同功能。

(3) feedable iterator: 迭代器是可变的,目的是通过选择不同数据集的 Dataset 的迭代器来迭代不同的数据。

# 创建两个 Dataset
cat_dataset = tf.data.TextLineDataset(['./txt_files/file1.txt']).map(lambda x: tf.string_split([x], delimiter=' ').values)
dog_dataset = tf.data.TextLineDataset(['./txt_files/file2.txt']).map(lambda x: tf.string_split([x], delimiter=' ').values)

# 创建两个 Dataset 对应的迭代器
cat_iterator = cat_dataset.make_one_shot_iterator()
dog_iterator = dog_dataset.make_initializable_iterator()

# 创建代表两个 Dataset 的迭代器的 handle tensor
cat_handle = cat_iterator.string_handle().eval()
dog_handle = dog_iterator.string_handle().eval()

# 根据传入的 handle 选择具体的迭代器
handle = tf.placeholder(tf.string, [])
iterator = tf.data.Iterator.from_string_handle(handle, cat_dataset.output_types, cat_dataset.output_shapes)
next_element = iterator.get_next()

while True:
    # 这里切换数据集,之后再切回 cat_dataset 是继续之前的数据迭代。
    for _ in range(2):
        print next_element.eval(feed_dict={handle: cat_handle})

    dog_iterator.initializer.run()
    for _ in range(2):
        print next_element.eval(feed_dict={handle: dog_handle})

reinitializable iterator 和 feedable iterator 的区别:

二者都是实现切换数据集的功能,但是 reinitializable iterator 是同一个迭代器绑定到不同数据集,切换数据集的时候要重新绑定,然后初始化这个迭代器,特点是切换数据集后从头开始迭代切换到的新数据集

feedable iterator 是针对每个数据集先建立对应的迭代器,然后选择用哪一个迭代器来选择对应的数据集,如果这些迭代器在创建的时候先初始化好,那么在迭代器之间来回切换的时候,**各自的数据集是接着前一次的时间点继续输出的。**比如训练集特别大,并不想训练集完全迭代一轮再验证,迭代一定数量的较少样本后就想测试一次验证集,就可以用这种迭代器。

4.2.3 一些应用

(1) 读取 npy 文件:

可以直接用 from_tensor_slices 一次载入:

data = np.load('data.npy')
features = data["features"]
labels = data["labels"]

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
...

缺点是:整个 npy 解析出来的数组以 tf.constant() 的形式存在于 Graph 中,而一个 Graph protobuf 文件的上限大小为 2G。

可以使用 initializable_iterator 配合 placeholder 载入:

data = np.load('data.npy')
features = data["features"]
labels = data["labels"]

# 创建对应的 placeholder
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
... 一些其他预处理操作等 ...

iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={feature_placeholder: features, labels_placeholder: labels})

上面方案其实也不好。上面两个是官方提供的方案。

Stack Overflow 上提供了其他方案:

比如这个方案计算出 npy 的 header_bytes 长度,然后用 tf.data.FixedLengthRecordDataset 来解析 npy 二进制文件。如果要并行处理,那么多个 npy 的 header_bytes 长度要一样。

另外,可以用 tf.py_func 来调用 np.load() 函数,然后放到 Dataset.map() 函数里面解决:

file_list = ['a.npy', 'b.npy', ...]

def read_npy_file(file):
    return np.load(file).as_type(np.float32)

dataset = tf.data.Dataset.from_tensor_slices(file_list)
dataset = dataset.map(lambda file: tuple(tf.py_func(
	read_npy_file, [file], [tf.float32])))
...

其他的一些应用可以参考 Tensorflow Dataset Guide -- Importing Data,里面有从 Python 迭代器得到 Dataset,读取 csv 等文件,进行 padding 操作,利用 tf.py_func 调用 opencv 等。

4.2.4 Inception 数据输入框架 (新 API)

用 tf.data 重写 3.2.6 部分:

import tensorflow as tf

IMAGE_SIZE = 299  # 输入图片大小
BATCH_SIZE = 50
SHUFFLE_BUFFER = 10000
NUM_EPOCHES = 100

# 列举 tfrecord 文件名
training_files = tf.train.match_filenames_ones('training_path/data.tfrecords-*')
val_files = tf.train.match_filenames_ones('val_path/data.tfrecords-*')

def tfrecord_parser(record):
    features = tf.parse_single_example(
        record, features={
            'image': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
            'height': tf.FixedLenFeature([], tf.int64),
            'width': tf.FixedLenFeature([], tf.int64),
            'channels': tf.FixedLenFeature([], tf.int64)
        }
    )
    label, height, width, channels = features['label'], features['height'], features['width'], features['channels']
    decoded_image = tf.decode_raw(features['image'], tf.uint8)
    decoded_image = tf.reshape(decoded_image, tf.stack([height, width, channels]))
    return decoded_image, label

# 定义 traing dataset 和对应的迭代器
# dataset = tf.data.TFRecordDataset(training_files) 这个效率低
training_dataset = tf.data.Dataset.from_tensor_slices(training_files)
training_dataset = training_dataset.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=4)
training_dataset = training_dataset.shuffle(SHUFFLE_BUFFER)
training_dataset = training_dataset.repeat(NUM_EPOCHES)
training_dataset = training_dataset.map(tfrecord_parser, num_parallel_calls=2)
training_dataset = training_dataset.map(lambda image, label: (preprocess_for_train(image, IMAGE_SIZE, IMAGE_SIZE), label), num_parallel_calls=4)
training_dataset = training_dataset.batch(BATCH_SIZE)
training_dataset = training_dataset.prefetch(5)

training_iterator = training_datasetmake_initializable_iterator()
training_image_batch, training_label_batch = training_iterator.get_next()

# 网络训练部分
logit = inference(training_image_batch)
loss = calc_loss(logit, training_label_batch)
train_step = ...

# 创建 validation dataset 和对应的迭代器
val_dataset = tf.data.TFRecordDataset(val_files)
val_dataset = val_dataset.map(tfrecord_parser)
val_dataset = val_dataset.map(lambda image, label: (preprocess_for_val(image, IMAGE_SIZE, IMAGE_SIZE), label), num_parallel_calls=2)
val_dataset = val_dataset.batch(BATCH_SIZE)

val_iterator = val_dataset.make_initializable_iterator()
val_image_batch, val_label_batch = val_iterator.get_next()

# 网络测试部分
val_logit = inference(val_image_batch)
predictions = tf.argmax(val_logit, axis=-1, output_type=tf.int32)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer(), tf.local_variables_initializer())

    sess.run(training_iterator.initializer)
    while True:
        try:
            sess.run(train_step)
        except tf.errors.OutOfRangeError:
            break
    
    sess.run(val_iterator.initializer)
    val_results = []
    val_labels = []
    while True:
        try:
            pred, label = sess.run([predictions, val_label_batch])
            val_results.extend([pred])
            val_labels.extend([label])
        except tf.errors.OutOfRangeError:
            break
    
    # 计算准确率等指标
    correct = [float(y == y_) for (y, y_) in zip(val_results, val_labels)]
    acc = sum(correct) / len(correct)
如果你这里看到有广告,不妨点击一下,就是对本站最大的支持~



本文由 CaptainChen 创作
该文章采用 知识共享署名-非商业性使用 4.0 国际许可协议进行许可。转载请注明出处!
CopyRight © 2017 - 2020
本站已稳定运行 天 总访问量