CaptainChen

TensorFlow 学习笔记 -- Input Pipeline

本篇笔记主要总结了如何在 TensorFlow 如何构建高效的 Input Pipeline,目的是协调 CPU 文件预处理和 GPU 模型计算之间的调度,尽最大限度发挥 GPU 算力。其中涉及到 TFRecord 文件的读写,tf.image 模块对图像的处理,以及版本 1.4 前使用的生产者/消费者多线程文件读写流程,和 1.4 后官方主推的 Dataset 处理方式。后者已经开始逐步支持 eager 模式。

吐槽部分:这一部分学完后感触很深,tf 真的很难,它已经从一个 Python 库上升到了一个十分复杂的编程语言的高度,让人抓狂。大部分 API 文档是根据代码注释生成的,往往晦涩难懂,由于时间和精力原因去深入庞杂的 C++ 底层几乎不可能。由此带来的最大困扰是,you seldom know what is happening under the hood。比如有多少人知道 tf.image 模块的 resize 的插值方法中,align corners 和主流图像库处理的方式都不一样?在学习这部分的过程中,我翻阅了大量博客,Github Issue,Stack Overflow 回答,多次对着源码逐行分析并进行代码测试,最终总结成个人笔记可谓是“满纸荒唐言,一把辛酸泪”。不过,也正是学习这部分的钻研过程中,感慨 tf 真的是为了实际应用做了充分的考虑,正如 caffe 作者贾扬清在知乎说道: “TF的确难,但是它给你提供了真正可以产品化的可能性。”

先把这部分我个人觉得要注意的一些点(或者说是大坑)列举一下,方便日后查阅:

  • 读取 TFRecord 文件过程中,解析 Example Protobuf 文件时,decode_raw 得到的数据(如 image raw data) 要通过 reshape 操作恢复 shape,而 shape 参数也是从 TFRecord 文件中获取时,要加 tf.stack 操作: image = tf.reshape(image, tf.stack([height, width, channels]))
  • 读取图片时,PIL 的 Image.tobytes() 或 Numpy 的 np.array().tobytes() 得到的二进制文件搭配 tf.decode_raw() 使用;tf.gfile.GFile(img_name) 读取的图片文件要用 tf.image.decode_png() 等函数解析
  • tf.image.decode_image 得到的图像没有静态 shape,因此无法与 tf.image.resize_images 一起使用。应该尽量使用 tf.image.decode_png 函数
  • tf.image.resize_images 中有个 align_corners 参数,它的机制与主流图像库如 opencv 是不一样的,建议设置为 True。不然后果可参见 How Tensorflow’s tf.image.resize stole 60 days of my life
  • tf.image.sample_distorted_bounding_box 得到的 bbox 信息配合 tf.slice 裁剪得到的图像的最后一个维度是动态的,建议后面跟一个 set_shape 操作
  • tf.image.distorted_bounding_box_crop 函数的 min_object_covered 默认取 0.1,很多时候是不合适的
  • tf.train.match_filenames_once 和 tf.train.string_input_producer 都创建了局部变量,Session 中需要初始化。
  • tf.train.batch 和 tf.train.batch_join 的多线程不一样,后者才是一个文件分配一个线程
  • 用 tf.train.batch 等操作出队得到的多元 Tensor tuple 一定要一起 run,不然会错位交叉。参见 3.2.5 例子。
  • Dataset 模块中, shuffle 的 buffer_size 要好好设计,同时 shuffle 和 repeat 的顺序要注意
  • Dataset.from_tensor_slices 是把传进来的输入以 tf.constant() 的形式存在于 Graph 中,而一个 Graph Protobuf 文件的上限是 2G

推荐阅读: CS230 课件 “An overview of tf.data” 部分列举的所有链接。



1. TFRecord

1.1 TFRecord 格式

TFRecord 文件的数据是通过 tf.train.Example 这个 protobuf 的格式存储的。其定义在 tensorflow/core/example/example.prototensorflow/core/example/feature.proto :

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
message Example {
// type: Features, name: features
Features features = 1;
};

message Features {
// Map from feature name to feature.
// Features is a key-value store, where each key (string)
// maps to a Feature message
map<string, Feature> feature = 1;
};

// Containers for non-sequential data.
message Feature {
// Each feature can be exactly one kind.
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};

// Containers to hold repeated fundamental values.
message BytesList {
repeated bytes value = 1;
}
message FloatList {
repeated float value = 1 [packed = true];
}
message Int64List {
repeated int64 value = 1 [packed = true];
}

tf.train.Example 中包含名为 features 的 Features 的信息,每个 Features 中包含一个从 feature 名称到 feature 属性 (Feature) 的字典映射(key-value对),其中每个 Feature 的可以从 BytesList (字符串),FloatList (浮点实数列表), Int64List (整数列表)中取。
一个 tf.train.Example 的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
features {
feature {
key: "age"
value { float_list {
value: 29.0
}}
}
feature {
key: "movie"
value { bytes_list {
value: "The Shawshank Redemption"
value: "Fight Club"
}}
}
feature {
key: "movie_ratings"
value { float_list {
value: 9.0
value: 9.7
}}
}
}

1.2 文件写入 TFRecord

样例程序:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import tensorflow as tf
import os
import cv2

# 输出的 TFRecord 文件带路径的名称
output = './output.tfrecords'
# 创建一个 writer 来写入 TFRecord 文件
writer = tf.python_io.TFRecordWriter(output)

# 辅助函数,生成整数和字符串型的 tf.train.Feature。
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

# 读取图片,并保存到 TFRecord 文件中
img_dir = './img'
imgs = os.listdir(img_dir)
imgs.sort()

for index, img in enumerate(imgs):
img_path = os.path.join(img_dir, img)
img_data = cv2.imread(img_path)
resized_img = cv2.resize(img_data, (128, 128, interpolation=cv2.INTER_AREA))
# 这里必须把 ndarry 转换成字符串形式的原始二进制数据流
img_raw = resized_img.tostring()

# 生成 Example Protobuf 文件
example = tf.train.Example(features=tf.train.Features(feature={
'shape_': _int64_feature(resized_img.shape[0]),
'label_': _int64_feature(index),
'img_raw': _bytes_feature(img_raw)
}))

# 将序列化后的example 写入 TFRecord 文件
writer.write(example.SerializeToString())

writer.close()

1.3 读取 TFRecord 文件

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import tensorflow as tf

# 注意默认 shuffle = True
# 返回一个队列 Queue 对象
filename_queue = tf.train.string_input_producer(['./output.tfrecords'], shuffle=False)

# 创建一个 reader 读取 TFRecord 文件中的样例
reader = tf.TFRecordReader()
# 一次读取一个样例。也可以使用 read_up_to 函数一次读取多个样例
# Returns the next record (key, value) pair produced by a reader.
_, serialized_example = reader.read(filename_queue)

# 解析单个 features 文件; 解析多个用 parse_example 函数
features = tf.parse_single_example(
serialized_example,
features={
'shape_': tf.FixedLenFeature([], tf.int64),
'label_': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)
})
# 以上 tf.FixedLenFeature 方法解析的结果为一个 tensor。另一种方法是 tf.VarLenFeature, 解析的
# 结果是 SparseTensor

# 将字符串 tensor 解析成数据 tensor,注意此时为一维数据,需要 reshape
image = tf.decode_raw(features['img_raw'], tf.uint8)
image_shape = tf.stack([shape_, shape_, 3]) # 这一行不能少
image = tf.reshape(image, image_shape)
# 默认是 int64 的 tensor, 转成 int32
shape = tf.cast(features['shape_'], tf.int32)
label = tf.cast(features['label_'], tf.int32)

sess = tf.Session()
# 多线程部分参见第3部分
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

for i in range(4):
print sess.run([image, shape, label])

2. tf.image 图像处理

2.1 图像读取,编码

读取示例:

1
2
3
4
5
6
7
8
9
tf.InteractiveSession()

image_raw_data = tf.gfile.FastGFile('./imgs/img1.png', 'rb').read()
# 一定要明确指定3通道,默认自动识别好像有bug,不生效
img_data = tf.image.decode_png(image_raw_data, channels=3)

# (732, 808, 3) uint8
# 格式: [H, W, C]
print img_data.eval().shape, img_data.eval().dtype

注意:

(1) tf.gfile 模块为 TensorFlow 的文件读写模块,C++ 实现,API 与 Python 自带文件模块高度相近。区别在于,tf.gfile 实现了多种文件系统的读写,比如本地文件,谷歌云存储文件(前缀 gs://),HDFS 文件(前缀 hdfs://) 等等。TensorFlow 中写入载入 checkpoints,TensorBoard 日志文件等都是用 tf.gfile 模块实现。

简言之,tf.gfile 实现了更多文件读写接口,在处理日常本地文件读写时,tf.gfile 并没有明显的速度优势,用 python 自带文件读写模块就可以了。

(2) tf.gfile.GFile, tf.gfile.FastGFile 二者在r1.8源码实现上并无区别,都是没有线程锁的文件 I/O,所以二者等同。

(3) 图像编码时的搭配:

  • 使用 PIL 的 Image.tobytes() 或 Numpy 的 np.array().tobytes() 得到的二进制文件,解码时应该使用tf.decode_raw()函数;
  • 使用 tf.gfile.GFile() 读取的图片文件,使用 tf.image 模块下的 tf.image.decode_png() 等解码。

如上面的例子,可以改成下面的写法:

1
2
3
4
image = cv2.imread('./imgs/img1.png')
print image.shape # (732, 808, 3)
image_raw_data = tf.decode_raw(image.tobytes(), tf.uint8)
print image_raw_data.eval().shape # (1774368,) // 1774368 = 732 * 808 * 3

注意 tf.decode_raw() 里面第二个形参 out_type 一定要正确,否则输出的数据的维度就不对。

(4) tf.image 中解码图像的 API:

  • tf.image.decode_png, tf.image.decode_jpeg, tf.image.decode_gif
  • tf.image.decode_image

上面的函数 tf.image.decode_png 等返回的 tensor 是有静态 shape的,而下面的 tf.image.decode_image 由于使用了 tf.cond 判断图片类型,因此返回的 tensor 没有静态 shape,造成它解码的图片无法与tf.image.resize_images() 一起使用;

现在 tf.image.decode_png,tf.image.decode_jpeg 已经能读取所有图片文件类型,和非动态 gif 文件了;

务必明确指定 tf.image.decode_png() 等函数中的 channels,想要 RGB 图务必指定为 3,默认的0在 1.8 版本并未生效。

–> 总结: 使用 tf.image.decode_png() 函数,不要使用 tf.image.decode_images() 函数,注意 channels 参数手动指定。返回的数据类型为 tf.uint8。

(5) 图像的存储 API:

  • tf.image.encode_png
  • tf.image.encode_jpeg
1
2
3
encoded_image = tf.image.encode_png(img_data)
with tf.gfile.FastGFile('test.png', 'wb') as f:
f.write(encoded_image.eval())

函数第一个参数为图片数据,shape: [height, width, channels],数据类型必须为 uint8。同时, 两个函数中提供了图片品质压缩参数,encode_png 为 compression,0-9 之间取值,数值越大压缩越严重;encode_jpeg 为 quality,0-100之间取值,数值越大质量越好。

2.2 图像大小调整

说明:后续的图像操作,很多只接受浮点图像数据,有些先把图像转成浮点,处理完成后再转为原来的数据类型;如果有多个图像处理操作,来回在 uint8 和 float32 之间的转换会导致精度损失,因此建议在图像处理之前先统一转换成 float32 类型:

1
2
> img_data = tf.image.convert_image_dtype(img_data, tf.float32)
>
  • tf.image.resize_images(images, size, method=ResizedMethod.BILINEAR, align_corners=False)

输入:

images: 4D 的 [N, H, W, C] 或 3D 的 [H, W, C] 数据。因此这个函数支持批处理。

size: [new_height, new_width]

method: 默认双线性插值。可选0-3。0:双线性插值;1:最近邻法;2:双三次插值;3:面积插值

  align_corners: 角度是否对齐。`记得务必设置为 True`

说明: tf 中的图片 resize 和 opencv,PIL 等主流库的实现不一样。opencv 等主流库在插值计算时,对齐时是把每个像素看做一个“点区域”,因此用的是中心点对齐,有个 0.5 的偏移计算设置,可参考 CSDN博客;而 tf 在实现的时候,每个像素就是一个点,align_corners 设置为 False 就是上述博客中没加偏移的情况,显然不合理;align_corners 设置为 True 就是对齐四个角顶点,连续插值。个人觉得 tf 的设置 align_corners=True 更加合理。代码的区别可参加图: https://i.loli.net/2018/08/16/5b755dc364464.png。这一部分的相关讨论参见: https://github.com/tensorflow/tensorflow/issues/6720。

若要使用 opencv 的 resize 函数,需使用 tf.py_func 包装起来:

1
2
3
# img_data is a tensor
img = tf.py_func(lambda input: cv2.resize(input, (4, 4)), [img_data], tf.float32, stateful=False)
print img.eval()

2.3 图像裁剪

  • tf.image.resize_image_with_crop_or_pad(image, height, width)

输入: image: 4D 的 [N, H, W, C] 或 3D 的 [H, W, C] 数据。因此这个函数支持批处理。

生成一个大小为 [height, width] 的图,图标图像小于原图就裁剪,否则周围填0。注意只是裁剪或填充,没有插值。

  • tf.image.central_crop(image, central_fraction)

输入的 image 为 3D tensor。因此不能批处理。central_fraction为 (0, 1] 之间的浮点数。

做中心裁剪,central_fraction 为长和宽裁剪出的比例。比如 [100, 100] 的原图,central_fraction=0.5,那么输出[50, 50] 大小的图。

  • tf.image.crop_to_bounding_box(image, offset_h, offset_w, target_h, target_w)

输入:

image: 4D 的 [N, H, W, C] 或 3D 的 [H, W, C] 数据。因此这个函数支持批处理。

offset_h, offset_w: 裁剪区域的左上角坐标

target_h, target_w: 输出区域的大小
  • tf.image.pad_to_bounding_box(image, offset_h, offset_w, target_h, target_w)

输入:

image: 4D 的 [N, H, W, C] 或 3D 的 [H, W, C] 数据。因此这个函数支持批处理。

offset_h, offset_w: 原图上面要补充多少行0,原图左侧要补充多少列0

target_h, target_w: 输出区域的大小。多出的区域一律补0
  • tf.image.crop_and_resize(): crop 和 resize 一并做了,默认使用双线性插值,角落对齐了。略。
  • tf.image.decode_and_crop_jpeg(): 裁剪和解码一起做,但是效率更高,因为只解码要裁剪的区域。

2.4 图像翻转

  • tf.image.flip_up_down(image): 垂直翻转

    tf.image.random_flip_up_down(image): 随机垂直翻转

  • tf.image.flip_left_right(image): 左右翻转

    tf.image.random_flip_left_right(image): 随机左右翻转

  • tf.image.transpose_image(image): 沿对角线翻转。其实就是矩阵转置。

  • tf.image.rot90(image, k=1): 沿逆时针旋转 k 个 90 度。

以上函数均支持单个图像处理和批量处理。

2.5 图像色彩调整

  • tf.image.adjust_brightness(image, delta): 调整图像亮度。delta: [-1, 1] 的实数。原理就是先将图像转成 float,然后整个图像加上 delta,再转换回原图的 dtype。因此建议处理之前直接转成 float。支持批处理。
  • tf.image.adjust_contrast(image, contrast_factor): 调整对比度。contrast_factor: 对比度乘性系数,提高对比度 contrast_factor 倍。原理: (x - mean) * contrast_factor + mean。支持批处理。
  • tf.image.adjust_hue(image, delta): 调整色相。delta: [-1, 1]
  • tf.image.adjust_saturation(image, saturation_factor): 调整饱和度。原理: 先将 RGB 图像转换为 HSV 图,然后 S 通道乘以 saturation_factor,最后做有效裁剪后转换回 RGB。
  • tf.image.adjust_gamma(image, gamma=1, gain=1): 调整图像 gamma 值。
  • tf.image.per_image_standardization(image): 减均值除方差。

说明:

(1) 以上前四个均有 random 函数,如: tf.image.random_brightness, tf.image.random_contrast, 具体参见 API。

(2) 在做一连串的调整操作后,图像的数值分布可能已经越界了。因此在最后一步操作后记得做有效截断:

1
result = tf.clip_by_value(adjusted_image, 0, 1)

2.6 处理标注框

  • tf.image.draw_bounding_boxes(image, boxes): 在图上画标注框。

输入:

image: 4D 的 [N, H, W, C] 图像。注意数据类型必须为 float。

boxes: [batch, box_num, 4]。bounding box 的坐标,四个点的坐标格式是 [y_min, x_min, y_max, x_max],这四个参数都是 [0, 1] 的数,表示比例。

输出: 带框的图像。

1
2
3
4
5
# img: [h, w, c]
batch_img = tf.expand_dims(img, 0) # [1, 400, 800, 3]
result = tf.image.draw_bounding_boxes(batch_img, boxes=[[[0.1, 0.3, 0.5, 0.7]]])
plt.imshow(result.eval()[0])
plt.show()

原图上框的位置: 左上角[240, 40],右下角[560, 200]

注意,框越界并不会报错。

  • tf.image.sample_distorted_bounding_box(image_size, bounding_boxes, min_object_covered=0.1, aspect_ratio_range, max_attempts=None, use_image_if_no_bounding_boxes=None): 根据 Ground_truth 标记框的位置,根据一些约束随机生成图像。用于数据扩充。

输入:

image_size: [h, w, c] 原图的 shape。

bounding_boxes:  [batch, box_num, 4], ground_truth 的位置信息。坐标格式仍然是 [y_min, x_min, y_max, x_max],四个参数均为 [0, 1] 之间的浮点数,表示比例。

min_object_covered: 默认 0.1。提取的区域至少包含某个 gt 标注框的比例的百分比。

aspect_ratio_range: list, 默认 [0.75, 1.33]。提取区域的宽高比的范围 (ratio = width / height)。

max_attempts: 默认100。最大尝试次数。

use_image_if_no_bounding_boxes: 默认 False。不提供标记框时是否返回原图。

输出:元组 (begin, size, bboxes):

begin: [offset_height, offset_width, 0]。输出区域起点,即左上角的坐标。

size: [target_height, target_width, -1]。输出区域的大小。

bboxes: 输出框的坐标,shape: [1, 1, 4]。

其中输出的 begin 和 size 可作为 tf.slice 的输入,bboxes 可作为 tf.image.draw_bounding_boxes 的输入。
1
2
3
4
# img: [h, w, c]
begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(tf.shape(img),[[[0.1, 0.3, 0.5, 0.7]]], min_object_covered=0.4)
# distorted 为随机提取出来的图像区域
distorted = tf.slice(resized, begin, size)

小坑注意:上面的例子中,由于 tf.slice 收到的 size 中最后一个维度的数是 -1, 意思是多个通道全要提取,具体有几个通道是动态的(要根据输入来定),因此导致输出的 distorted 最后一个维度也是动态的。

1
2
3
print distorted.shape  # (?, ?, ?)
result = tf.image.resize_images(distorted, [200, 200])
print result.shape # (200, 200, ?)

由于后面对 distorted 的操作一般不会涉及到图像通道数,为了图像的维度的 shape 能正常获取,最好在 tf.slice 后手动设定一下 shape:

1
2
3
4
5
6
# Restore the shape since the dynamic slice based upon the bbox_size loses
# the third dimension.
distorted.set_shape([None, None, 3])
print distorted.shape # (?, ?, 3)
result = tf.image.resize_images(distorted, [200, 200])
print result.shape # (200, 200, 3)

2.7 颜色空间转换

  • tf.image.rgb_to_grayscale
  • tf.image.grayscale_to_rgb
  • tf.image.hsv_to_rgb
  • tf.image.rgb_to_hsv

注意,以上操作均需要图片的数据类型为 tf.float32

2.8 Inception 图像预处理

这部分的官方代码在 slim/preprocessing/inception_preprocessing.py

需要注意的两点:

(1) distorted_bounding_box_crop 函数的 min_object_covered 默认取 0.1,根据具体数据集调整这个参数比较合适。

(2) resize 都没有 align_corners.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import tensorflow as tf

from tensorflow.python.ops import control_flow_ops

def apply_with_random_selector(x, func, num_cases):
"""
从 func(x, 0), func(x, 1), ..., func(0, num_cases-1) 中随机选一个
"""
sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
# merge(inputs): 依次判断 inputs 中的元素是否存在, 返回第一个存在的元素的[data, data_index]
# switch(data, pred): 返回 (output_false, output_true), 如果 pred 为 true, output_true = data, output_false
# 不存在; 反过来也一样
return control_flow_ops.merge([
func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
for case in range(num_cases)])[0]


def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
"""
对图片进行随机色彩变换:调整亮度、饱和度、色相、对比度。

Args:
image (float): [0, 1] 之间的 3D tensor 图像
color_ordering (int, optional): Defaults to 0. 可取 0-3,代表不同的随机变换模式。
fast_mode (bool, optional): Defaults to True. 快速模式下不采用调整色相和对比度的变换。
scope (optional): Defaults to None.

Returns:
颜色变换处理后的 float32 图像。
"""
with tf.name_scope(scope, 'distort_color', [image]):
if fast_mode:
if color_ordering == 0:
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
else:
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_brightness(image, max_delta=32. / 255.)
else:
if color_ordering == 0:
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.2)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
elif color_ordering == 1:
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.2)
elif color_ordering == 2:
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.2)
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
elif color_ordering == 3:
image = tf.image.random_hue(image, max_delta=0.2)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.image.random_brightness(image, max_delta=32. / 255.)
else:
raise ValueError('color_ordering must be in [0, 3]')

# 有效截断
return tf.clip_by_value(image, 0.0, 1.0)

# min_object_covered 默认为 0.1,可以根据实际任务稍微改大一点
def distorted_bounding_box_crop(image,
bbox,
min_object_covered=0.1,
aspect_ratio_range=(0.75, 1.33),
area_range=(0.05, 1.0),
max_attempts=100,
scope=None):
with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
tf.shape(image),
bounding_boxes=bbox,
min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range,
area_range=area_range,
max_attempts=max_attempts,
use_image_if_no_bounding_boxes=True)
bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box

cropped_image = tf.slice(image, bbox_begin, bbox_size)
return cropped_image, distort_bbox

def preprocess_for_train(image, height, width, bbox,
fast_mode=True,
scope=None,
add_image_summaries=True):
"""
训练集数据预处理。

Args:
image: 输入图像, uint8 或者 float32(都会被转换为 float32). shape: [H, W, C]
bbox: shape: [1, num_boxes, coords]. coords: [ymin, xmin, ymax, xmax]. 为 None 时取原图整图.
fast_mode: resize 插值和颜色变换是否采用快速模式。
add_image_summaries: 是否画出画出中间处理过程得到的图。

Returns:
3D float 图像。范围在 [-1, 1] 之间。
"""
with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
# bbox 为空取原图
if bbox is None:
bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
dtype=tf.float32,
shape=[1, 1, 4])
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
bbox)
if add_image_summaries:
tf.summary.image('image_with_bounding_boxes', image_with_box)

distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
# Restore the shape since the dynamic slice based upon the bbox_size loses
# the third dimension.
distorted_image.set_shape([None, None, 3])
image_with_distorted_box = tf.image.draw_bounding_boxes(
tf.expand_dims(image, 0), distorted_bbox)
if add_image_summaries:
tf.summary.image('images_with_distorted_bounding_box',
image_with_distorted_box)

# 快速模式下: resize 采取双线性插值,否则是选择随机方法插值。
num_resize_cases = 1 if fast_mode else 4
distorted_image = apply_with_random_selector(
distorted_image,
lambda x, method: tf.image.resize_images(x, [height, width], method),
num_cases=num_resize_cases)

if add_image_summaries:
tf.summary.image('cropped_resized_image',
tf.expand_dims(distorted_image, 0))

# 随机左右翻转
distorted_image = tf.image.random_flip_left_right(distorted_image)

# 调用 distort_color 随机做颜色变换,选择是否采用快速模式。
num_distort_cases = 1 if fast_mode else 4
distorted_image = apply_with_random_selector(
distorted_image,
lambda x, ordering: distort_color(x, ordering, fast_mode),
num_cases=num_distort_cases)

if add_image_summaries:
tf.summary.image('final_distorted_image',
tf.expand_dims(distorted_image, 0))
# [0, 1] 之间的图,变为 [-1, 1] 之间。
distorted_image = tf.subtract(distorted_image, 0.5)
distorted_image = tf.multiply(distorted_image, 2.0)
return distorted_image


def preprocess_for_eval(image, height, width,
central_fraction=0.875, scope=None):
"""
验证集数据预处理。做中心裁剪,双线性插值 resize,数值范围变换到 [-1, 1] 之间。

Returns:
3D float 图像。范围在 [-1, 1] 之间。
"""

with tf.name_scope(scope, 'eval_image', [image, height, width]):
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
# 取出中间 0.875 的区域
if central_fraction:
image = tf.image.central_crop(image, central_fraction=central_fraction)

if height and width:
# resize 图像,采用双线性插值
image = tf.expand_dims(image, 0)
image = tf.image.resize_bilinear(image, [height, width],
align_corners=False)
image = tf.squeeze(image, [0])
# 取值范围变为 [-1, 1] 之间
image = tf.subtract(image, 0.5)
image = tf.multiply(image, 2.0)
return image


def preprocess_image(image, height, width,
is_training=False,
bbox=None,
fast_mode=True,
add_image_summaries=True):
"""
训练集和验证集图像预处理。

Args:
image: 输入图像, uint8 或者 float32(都会被转换为 float32). shape: [H, W, C]
is_training: 训练集还是测试集
bbox: 标记框,默认 None 表示取原图整图
fast_mode: resize 和颜色变换是否采用快速模式
add_image_summaries: 中间处理后的图是否画图。

Returns:
3D float 图像。范围在 [-1, 1] 之间。
"""

if is_training:
return preprocess_for_train(image, height, width, bbox, fast_mode,
add_image_summaries=add_image_summaries)
else:
return preprocess_for_eval(image, height, width)

其他的如 VGG 预处理,可参见: vgg_preprocessing.py,VGG 裁剪图片时是确定短边长度后再等比例 resize.

3. 队列与多线程(旧 API)

3.1 队列 Queue

TensorFlow 中提供的队列有:

  • tf.FIFOQuese: 先进先出队列
  • tf.RandomShuffleQueue: 随机顺序出列的队列。注意:测试发现,在满足队列容量 > min_after_dequeue 条件下,每 dequeue 一次,整个 Queue 就要 Shuffle 一次。
  • tf.PaddingFIFOQueue: 以固定长度批量出列的队列
  • tf.PriorityQueue: 带优先级出列的队列

这些类型的 Queue 的 API 大致差不多,以 FIFOQueue 为例:

1
2
3
4
5
6
7
tf.FIFOQueue(capacity, dtypes, shapes=None): capacity 为队列容量,dtypes 为队列中元素的数据类型,shapes 为队列中元素的 shape。
-> 常用函数:
- dequeue(): 出队一个元素
- dequeue_many(n): 出队 n 个元素。使用此函数必须手动指定 shapes。
- enqueue(): 入队一个元素
- enqueue_many(val): 入队多个元素. 注意这里 val 要比基元多一维。
- size(): 返回队列中元素多少。返回数据类型为 tensor。

注意: 队列中存储了 capacity 个基元,以 list 形式存在。当基元中包含多个数据时,dtypes 是个 list,长度与基元长度要相同。若要使用 dequeue_many(n),shape 必须手动指定。

1
2
3
4
5
6
7
8
q = tf.FIFOQueue(5, tf.int32, shapes=[()])
op1 = q.enqueue_many([[1, 2]])
op1.run()
op2 = q.enqueue([3])
op2.run()
print q.size().eval() # 3
print q.dequeue().eval() # 1
print q.dequeue_many(2).eval() # [2, 3]

TODO(20180817): 当基元有多个元素时,enqueue_many 表现很奇怪,尚未弄懂。因此暂时避免使用 enqueue_many 和 dequeue_many。

3.2 tf.train.Coordinator 与 tf.train.QueueRunner

  • tf.train.Coordinator: tf 中用来协调线程运行的工具。主要用于协调线程停止。
  • tf.train.QueueRunner: tf 中对操作队列多线程的封装,用于创造线程。

3.2.1 tf.train.Coordinator

该类的常用方法:

  • should_stop(): 查询线程是否要终止。返回布尔值。
  • request_stop(): 请求终止线程。每一个线程都可以调用 request_stop() 来请求终止其他线程,这样其他线程在检查 should_stop() 时得到 True,因此就会终止线程。
  • clear_stop(): 清除线程终止信号。
  • join(threads=None, stop_grace_period_secs=120): 阻塞线程。发出线程终止请求后,其他线程必须在 stop_grace_period_secs 时间内完成终止,否则抛出异常。
  • stop_on_exception(): 上下文管理。当发生异常是,请求终止线程。

使用方法:

1
2
3
4
5
6
try:
coord = tf.train.Coordinator()
... codes of creating threads ...
coord.join(threads)
except Exception as e:
... some codes ...

上面创建线程的代码为:

1
2
3
4
5
6
try:
while not coord.should_stop():
... some work ...
except Exception as e:
coord.request_stop(e)
coord.join(threads)

可以用 stop_on_exception () 简化上面创建线程的代码:

1
2
3
4
with coord.stop_on_exception():
while not coord.should_stop():
... some work ...
coord.join(threads)

3.2.2 tf.train.QueueRunner

tf 中的多线程使用的队列启动方案。与 tf.train.Coordinator 一起使用。

比如一个经典的文件输入流程: 第一批线程通过往第一个队列里面不断填充要处理的文件名;第二批线程从前面的队列取出文件名,然后进行读取处理等操作,得到的张量放在第二个队列;第三批线程从第二个队列中取出张量,组成 batch,输入网络进行训练。

初始化方法:

1
__init__(queue=None, enqueue_ops=None)

其中 queue 为要操作的队列,enqueue_ops 为要对该队列执行的多线程操作。

tf.train.QueueRunner 常与以下两个类一起使用:

  • tf.train.add_queue_runner(qr, collection=tf.GraphKeys.QUEUE_RUNNERS): 将一个 QueueRunner 添加到图的 collection 中。默认放在 tf.GraphKeys.QUEUE_RUNNERS 这个 collection 中。
  • tf.train.start_queue_runners(sess=None, coord=None, daemon=True, start=True, collection=tf.GraphKeys.QUEUE_RUNNERS ): 启动图的 collection 中的所有 QueueRunner。默认的 collection 为 tf.GraphKeys.QUEUE_RUNNERS。daemon: 是否为守护进程; start: 是否启动线程,不启动就只是创建线程。

一个完整的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
# 创建队列,入队操作
queue = tf.FIFOQueue(100, tf.float32)
enqueue_op = queue.enqueue([tf.random_normal([])])

# 开启 5 个线程
qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)
tf.train.add_queue_runner(qr)

out_tensor = queue.dequeue()

with tf.Session() as sess:
coord = tf.train.Coordinator()
enqueue_threads = tf.train.start_queue_runners(sess=sess, coord=coord)

for i in range(5):
print sess.run(out_tensor)

# 终止所有线程
coord.request_stop()
coord.join(enqueue_threads)

当然,上面也可以不使用 tf.train.add_queue_runner 和 tf.train.start_queue_runner 来启动线程。直接用 QueueRunner 的 create_threads(sess, coord=None, daemon=False, start=False) 来启动多个线程:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
queue = tf.FIFOQueue(100, tf.float32)
enqueue_op = queue.enqueue([tf.random_normal([])])

# 开启 5 个线程
qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)

out_tensor = queue.dequeue()

with tf.Session() as sess:
coord = tf.train.Coordinator()
# 手动启动 qr 负责的入队线程
enqueue_threads = qr.create_threads(sess=sess, coord=coord, start=True)

for i in range(5):
print sess.run(out_tensor)

coord.request_stop()
coord.join(enqueue_threads)

3.2.3 输入文件队列

  • tf.train.match_filenames_once(pattern): 根据正则表达式获取符合要求的文件名列表。返回一个局部变量,为文件名列表。因此使用这和函数务必初始化局部变量

  • tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True, capacity=32): 根据文件名返回一个文件名队列,供多线程使用。string_tensor: 装有文件名的 tensor,可由 match_filenames_onces 函数返回,也可用 Python glob 生成等;num_epochs: 加载文件列表的最大轮数,默认无限循环;shuffle: 文件名加入队列前是否打乱;capacity: 队列长度。

    注意: 从源码中发现,该函数创建了局部变量(对 epoch 进行计数),因此使用时要初始化局部变量(其实不初始化也可以,就不使用 epoch 这个变量)。该函数返回一个 FIFOQueue 用于存放文件名,并生成一个 QueueRunner 进行单线程入队操作,该 QueueRunner 放在默认的 tf.GraphKeys.QUEUE_RUNNERS 这个 collection 中。

    在指定 num_epoches (如测试时指定为 1 ) 时,队列为空后继续出队,抛出 OutOfRange 异常。

1
2
3
4
5
6
7
8
9
10
sess = tf.InteractiveSession()
# 获取所有 png 图片文件列表
files = tf.train.match_filenames_once('some_path/*.png')

filename_queue = tf.train.string_input_producer(files, shuffle=False)
coord = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess, coord)
tf.local_variables_initializer().run()

print filename_queue.dequeue().eval() # 获取一个文件名

3.2.4 Readers

前面 tf.train.string_input_producer 生成了文件名队列,tf 通过各种 Reader 从这个文件名队列中取文件名,进行文件读取解析。常用的 Reader 有:

  • tf.TFRecordReader: 读取 TFRecord 文件
  • tf.WholeFileReader: 读取一个文件的全部
  • tf.TextLineReader: 读取文本文件
  • tf.FixedLengthRecordReader: 读取固定长度的文件

它们的 API 大致相同。常用的方法有:

  • read(queue): 输入为一个 Queue,以 (key, value) 的形式输出单个文件。读取后 Queue 出队一个。
  • read_up_to(queue, num_records): 以 (keys, values) 的形式输出多个文件。

用 Reader 实现 2.1 节的图片读取:

1
2
3
4
5
6
7
8
9
10
files = tf.train.match_filenames_once('some_path/*.png')
filename_queue = tf.train.string_input_producer(files, shuffle=False)
key, value = tf.WholeFileReader().read(filename_queue)
image = tf.image.decode_png(value, channels=3)

with tf.Session() as sess:
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
print sess.run(image)

单线程,效率较低。

3.2.5 组合训练数据 (batching)

一个高效的数据 pipeline 应该是一个生产者——消费者模型,即生产者是一个文件名队列,里面存储要处理的文件的名字,这个队列用单线程处理即可;消费者为多线程从生产者队列里面取出文件名,进行文件读取和预处理,然后放置在另一个队列,最终的数据从这个队列中取出。

一个完整的流程如图:

TensorFlow 中提供的对应的 batch 数据的函数为:

  • tf.train.batch(tensors, batch_size, num_threads=1, capacity=32, enqueue_many=False, dynamic_pad=False, allow_smaller_final_batch=False): 组合输入的 tensors, 形成一个 batch,放置在新建的一个 Queue 中,并同时新建一个 QueueRunner (添加到默认的 QueueRunner collection)。allow_smaller_final_batch 为 True 时,最后的不够一个 batch 的数据也会留下。返回一个 batch 的数据。
  • tf.train.shuffle_batch(tensors, batch_size, capacity, min_after_dequeue, num_threads=1, seed=None, dynamic_pad=False, allow_smaller_final_batch=False): 产生打乱顺序的 batch,。与 tf.train.batch 的唯一区别是多了 min_after_dequeue 参数,限制了出队操作后队列中最少元素的数量。因为当队列中元素太少时,打乱的意义就不大了。创建的 Queue 类型是 RandomShuffleQueue。

注意: 以上 dynamic_pad 为 False 时,传入的 tensors 必须显式确定,否则抛出异常。

  • tf.train.batch_join(tensors_list, batch_size, capacity=32, allow_smaller_final_batch=False)
  • tf.train.shuffle_batch_join(tensors_list, batch_size, capacity, min_after_dequeue, allow_smaller_final_batch=False)

一般我们的文件不止一个,比如 TensorFlow Performance Guide 建议,把大数据文件分割成多个约为 100 MB 的 TFRecord 文件,I/O 性能比较好。这种情况下,多文件,多线程进行读取和预处理操作应该用上面两个函数。

其中,输入的 tensors_list 为 a list of tuples of tensors,创建 len(tensors_list) 个线程,每个线程读取一个文件,然后压入队列:

1
2
3
4
5
6
7
8
9
10
# features 为解析的 TFRecord 文件
image, label = features['image'], features['label']

# 1. 使用 tf.train.batch:
# 返回的 iamge 是个 [N, H, W, C] 的 tensor
# 记住 image 和 label 要一起 run,不然就错位交叉了
image, label = tf.train.batch([image, label], batch_size=10, num_threads=1, capacity=100)

# 2. 使用 tf.train.batch_join:
image, label = tf.train.batch_join([[image, label] for _ in range(4)], batch_size=10, capacity=100)

比较: tf.train.batch 是多线程读取一个文件,tf.train.batch_join 是多线程读取多个文件,每个线程负责一个文件。如果同一个文件中样本相似,用 tf.train.batch 显然不合适;使用 tf.train.batch_join() 时,如果线程数大于文件数,那么也存在多个线程读取同一个文件的情况,而且多线程读多个文件的硬盘寻址也是有时间开销的,可能会让效率变低。

3.2.6 Inception 数据输入框架 (旧 API)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# coding: utf-8

import tensorflow as tf

files = tf.train.match_filenames_ones('some_path/data.tfrecords-*')
# 训练集,文件名打乱
filename_queue = tf.train.string_input_producer(files, shuffle=True, )

_, serialized_example = tf.TFRecordReader().read(filename_queue)
features = tf.parse_single_example(serialized_example, features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'channels': tf.FixedLenFeature([], tf.int64)
})

image, label, height, width, channels = features['image'], features['label'], features['height'], features['width'], features['channels']

decoded_image = tf.decode_raw(image, tf.uint8)
decoded_image = tf.reshape(decoded_image, tf.stack([height, width, channels]))

image_size = 299

# 前面的 inception 预处理代码
distorted_image = preprocessing_for_train(decoded_image, image_size, image_size, None)

# 这一部分可以具体调整
num_threads = 10
batch_size = 50
min_after_dequeue = 1000
# as suggested by https://www.tensorflow.org/api_guides/python/reading_data#Batching
capacity = min_after_dequeue + (num_threads + 10) * batch_size

image_batch, label_batch = tf.train.shuffle_batch([distorted_image, label], batch_size=batch_size,
capacity=capacity, min_after_dequeue=min_after_dequeue, num_threads=num_threads)
# tf.train.shuffle_batch_join 的方案
# image_batch, label_batch = tf.train.shuffle_batch([[distorted_image, label] for i in range(num_threads)], batch_size=batch_size,
# capacity=capacity, min_after_dequeue=min_after_dequeue)

logit = inference(image_batch)
loss = calc_loss(logit, label_batch)
train_step = ...

with tf.Session() as sess:
sess.run(tf.global_variables_initializer(), tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)

with coord.stop_on_exception():
while not coord.should_stop():
for i in range(training_epoches):
sess.run(train_step)

coord.join(threads)

4. tf.data 模块 (新 API)

tf 1.4 后的数据输入框架,抛弃队列处理的旧 API,使用 Dataset 数据集提供数据的输入。这一部分比较简单,官方文档很详细。

4.1 API 概览

(1) 利用数据集的基本步骤:

  • 根据数据类型创建对应的 Dataset
  • 定义迭代器 Iterator,并进行相应的初始化
  • 预处理、shuffle、batch
  • 使用 get_next() 从迭代器中取出数据张量

(2) 常用的 Dataset:

  • tf.data.Dataset(): 所有 Dataset 的基类
  • tf.data.TextLineDataset(filenames, compression_type=None, buffer_size=None): 从一个或多个文本文件中读取内容
  • tf.data.TFRecordDataset(filenames, compression_type=None, buffer_size=None, num_parallel_reads=None): 从一个或多个 TFRecord 文件中读取内容
  • tf.data.FixedLengthRecordDataset(filenames, record_bytes, header_bytes=None, footer_bytes=None, buffer_size=None): 从一个或多个二进制文件中读取固定长度内容。

(3) Dataset 常用的属性:

  • output_shapes: Dataset 中每个元素的 shape
  • output_types: Dataset 中每个元素的 type

(4) Dataset 常用的方法:

  • batch(batch_size): 按 batch_size 取一个 batch。最后不够一个 batch_size 也会被取出来,如果不要最后的零头,使用 tf.contrib.data.batch_and_drop_remainder 方法。
  • padded_batch(batch_size, padded_shapes, padding_values=None): 生成 padded batch。
  • shuffle(buffer_size, seed=None, reshuffle_each_iteration=None): 打乱数据集。buffer_size: 缓冲区大小。默认 reshuffle_each_iteration 为 True。因此,每迭代一个元素出来,缓冲区中填充一个新元素,shuffle 一次。buffer_size 越大,打乱性能越好,但是第一次启动的时间就较长。可设置 buffer_size 为数据集大小,这样打乱充分。参考链接
  • repeat(count=None): 把 Dataset 重复 count 次。None 或 -1 表示无限循环。
  • skip(count): 跳过前 count 个元素。
  • take(count): 读取前 count 个元素。
  • shard(num_shards, index): 生成一个只包含原 Dataset 的 1/num_shards 的新 Dataset。
  • from_tensors(tensors): 根据单个元素构建 Dataset。
  • from_tensor_slices(tensors): 根据元素切片构建 Dataset。
  • from_generator(generator, output_types, out_shapes=None): 根据生成器构建 Dataset。这一部分使用 tf.py_func 实现的。
  • map(map_func, num_parallel_calls=None): map 运算,返回一个 map 运算后的 Dataset。
  • flat_map(map_func): map_func 后再 flat 展平,返回一个 Dataset。
  • filter(predicate): filter 运算,返回一个 Dataset。
  • interleave(map_func, cycle_length, block_length=1): 适用于分布式文件系统。当有多个文件时,一次对 cycle_length 个文件同时读取,block_length 是每个线程输出元素的个数。因此,map 和 flat_map 相当于 tf.train.batch,interleave 相当于 tf.train.batch_join。
  • apply(transformation_func): 对一个 Dataset 进行某个操作,类似于 map。
  • zip(datasets): 和 Python 中的 zip 函数一样,把多个 Dataset zip起来。
  • concatenate(dataset): 串接一个 Dataset, 返回一个新的合并的 Dataset。
  • prefetch(buffer_size): 预加载 buffer_size 的数据。
  • range(*args): 生成一个 RangeDataset。
  • list_files(file_pattern, shuffle=None): 根据文件名 file_pattern 正则表达式,获取一个包含这些文件的 Dataset。注意,这个顺序是不定的,即使 shuffle 为 False。看源码发现这个函数其实是先用 tf.matching_files 得到文件列表,在用 from_tensor_slices 得到一个 dataset,最后 shuffle。注意,最后一个 shuffle 默认的缓冲区为整个文件名列表。因此,由于 from_tensor_slices 和 shuffle 缓冲区长度的设定,当文件名列表过于巨大时,这一步耗时就会很大,可参考 Github

(5)迭代器:

  • one-shot iterator: Dataset.make_one_shot_iterator()。绑定一个 Dataset 的单次迭代器,只对 Dataset 进行一次迭代。只有这个不需要显式初始化。所有参数都已经确定,中间不能修改了。
  • initializable iterator: Dataset.make_initializable_iterator(): 绑定一个 Dataset 的可初始化迭代器。需要手动初始化迭代器。中间可以修改迭代器参数,可以配合 placeholder 使用。
  • reintializable iterator: tf.data.Iterator.from_structure(output_types, output_shapes=None, shard_names=None, output_classes=None): 这个迭代器没有绑定 Dataset,因此是用 Iterator 这个类创建的。中间可以重复初始化。把同一个迭代器应用到不同的数据集从而实现切换数据集的功能。需要手动初始化,需要使用 iterator.make_initializer(dataset) 来针对给定的 Dataset 初始化。
  • feedable iterator: tf.data.Iterator.from_string_handle(string_handle, output_types, output_shaps=None, output_classes=None): 这个迭代器没有绑定 Dataset。目的是切换到不同的 iterator 。string_handle 是代表一个 Iterator 的 Tensor,可先用一个 placeholder 占据,选择数据集的时候 feed 为具体 Dataset 的 Iterator (用这个 Dataset 的 Iterator.string_handle() 得到)。

4.2 案例分析

4.2.1 性能优化

以一个简单的文本处理为例,假设为分布式文件系统,有 5 个 txt 文件,每个文件里面存放某一类样本的文件名和类别标号,形如:

1
2
3
4
5
6
7
8
9
10
11
12
- file1.txt:
cat1.jpg 0
cat2.jpg 0
...
cat4.jpg 0
- file2.txt
dog1.jpg 1
dog2.jpg 1
...
dog4.jpg 1
...
- file_n.txt

(1) 基本流程,无各种优化考虑:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
# 获取文件列表,并按照 file_{n} 中的数字 n 从小到大排序
txt_files = glob.glob('some_path/*.txt')
txt_files.sort(key=lambda x: int(x.split('.')[-2][-1]))

# 使用 TextLineDataset 读取文本文件
dataset = tf.data.TextLineDataset(txt_files)
dataset = dataset.shuffle(buffer_size=20)
dataset = dataset.repeat(2)
# 分割字符串,按空格分割
dataset = dataset.map(lambda x: tf.string_split([x], delimiter=' ').values)
dataset = dataset.batch(2)

# 创建迭代器
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

输出的结果为:

1
2
3
4
5
[['bird1.jpg' '3']
['dog4.jpg' '1']]
[['dog2.jpg' '1']
['bird3.jpg' '3']]
...

要注意的几点:

– 参考 4.1 中设置 shuffle 的 buffer_size

– shuffle 和 repeat 的顺序,建议先 shuffle 再 repeat,如果反过来会造成打乱效果变差。参考: Dataset Performance Guide

(2) 基础性能优化: 多线程读取多文件,多线程 map,预加载 prefetch

上面代码的性能缺点:单线程读取,单线程 map,加上 Dataset 默认的 lazy 属性,性能地下。

基础改进代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
txt_files = glob.glob('some_path/*.txt')
txt_files.sort(key=lambda x: int(x.split('.')[-2][-1]))

# 改进1:多线程读取文件,创建5个线程,每个线程负责一个文件读取。
# 因此改进后可以同时读 5 个文件
dataset = tf.data.Dataset.from_tensor_slices(txt_files)
dataset = dataset.interleave(lambda x: tf.data.TextLineDataset(x),
cycle_length=5, block_length=1)

dataset = dataset.shuffle(buffer_size=20)
dataset = dataset.repeat(2)
# 改进2: 多线程 map 函数,也就是可以多线程预处理了。
dataset = dataset.map(lambda x: tf.string_split([x], delimiter=' ').values,
num_parallel_calls=4)
dataset = dataset.batch(2)

# 改进3:预加载,输出有缓冲区。类似于消费者队列中填充一定 batch 数,等待消费。
# 这里是基于前一步操作后的元素,因此是预加载 5 个 batch。
dataset = dataset.prefetch(5)
...

TODO: 上面的 interleave 同时读取 5 个文件,但是真的是并行吗?按照 tf.data API slides 应该是并行 I/O,但是按照 Dataset Performance Guide 却建议使用 tf.contrib.data.parallel_interleave() 实现真正的并行 I/O,这里有待进一步明确,是否需要改成下面的版本:

1
dataset = dataset.apply(tf.contrib.data.parallel_interleave(tf.data.TextLineDataset, cycle_length=4))

(3) 进一步优化:

并行化 batch: 当 batch_size 比较大时,取 batch 也是耗时的,因此可以把 map 和 batch 放在一起做多线程,最终改成这样的版本(使用 tf.contrib.data.map_and_batch 函数):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
txt_files = glob.glob('some_path/*.txt')
txt_files.sort(key=lambda x: int(x.split('.')[-2][-1]))

# 确保运行在 CPU 上
with tf.device('/cpu:0'):
dataset = tf.data.Dataset.from_tensor_slices(txt_files)
dataset = dataset.apply(tf.contrib.data.parallel_interleave(tf.data.TextLineDataset, cycle_length=4))

dataset = dataset.shuffle(buffer_size=20)
dataset = dataset.repeat(2)
# 改动这里
dataset = dataset.apply(tf.contrib.data.map_and_batch(lambda x: tf.string_split([x], delimiter=' ').values, batch_size=2, num_parallel_batches=4))
dataset = dataset.batch(2)

dataset = dataset.prefetch(5)

其他的优化还有内存 cache 等的考虑等,参考 Dataset Performance Guide

注意:一般而言, batch 耗时相对较少,当电脑核心不太够的时候,并行化的 batch 占据线程也不一定是好事,可能速度也会变慢。

4.2.2 各种迭代器的使用及比较

(1) initializable iterator: 动态指定 iterator 参数。配合 placeholder 使用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
# 指定文件名的 placeholder
txt_files = tf.placeholder(tf.string, [])

dataset = tf.data.TextLineDataset([txt_files])
dataset = dataset.map(lambda x: tf.string_split([x], delimiter=' ').values)

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

iterator.initializer.run(feed_dict={txt_files: './txt_files/file1.txt'})
for i in range(5):
print next_element.eval()

# 切换到另一个 txt 文件
iterator.initializer.run(feed_dict={txt_files: './txt_files/file2.txt'})
for i in range(2):
print next_element.eval()

# 中间重新初始化,迭代器从头开始
iterator.initializer.run(feed_dict={txt_files: './txt_files/file2.txt'})
for i in range(5):
print next_element.eval()

(2) reinitializable iterator: 一个可重复初始化的迭代器,绑定到不同的数据集上去。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 两个数据集
cat_dataset = tf.data.TextLineDataset(['./txt_files/file1.txt']).map(lambda x: tf.string_split([x], delimiter=' ').values)
dog_dataset = tf.data.TextLineDataset(['./txt_files/file2.txt']).map(lambda x: tf.string_split([x], delimiter=' ').values)

# 同一个 reinitializable iterator
iterator = tf.data.Iterator.from_structure(cat_dataset.output_types, dog_dataset._output_shapes)
next_element = iterator.get_next()

# 迭代器初始化
cat_init_op = iterator.make_initializer(cat_dataset)
dog_init_op = iterator.make_initializer(dog_dataset)

# 迭代器绑定到 cat dataset
cat_init_op.run()
for _ in range(5):
print next_element.eval()

# 迭代器绑定到 dog dataset
dog_init_op.run()
for _ in range(5):
print next_element.eval()

通常可以用同一个迭代器绑定到 training set 和 test_set,不过,用前面的 initializer_iterator 也可实现相同功能。

(3) feedable iterator: 迭代器是可变的,目的是通过选择不同数据集的 Dataset 的迭代器来迭代不同的数据。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# 创建两个 Dataset
cat_dataset = tf.data.TextLineDataset(['./txt_files/file1.txt']).map(lambda x: tf.string_split([x], delimiter=' ').values)
dog_dataset = tf.data.TextLineDataset(['./txt_files/file2.txt']).map(lambda x: tf.string_split([x], delimiter=' ').values)

# 创建两个 Dataset 对应的迭代器
cat_iterator = cat_dataset.make_one_shot_iterator()
dog_iterator = dog_dataset.make_initializable_iterator()

# 创建代表两个 Dataset 的迭代器的 handle tensor
cat_handle = cat_iterator.string_handle().eval()
dog_handle = dog_iterator.string_handle().eval()

# 根据传入的 handle 选择具体的迭代器
handle = tf.placeholder(tf.string, [])
iterator = tf.data.Iterator.from_string_handle(handle, cat_dataset.output_types, cat_dataset.output_shapes)
next_element = iterator.get_next()

while True:
# 这里切换数据集,之后再切回 cat_dataset 是继续之前的数据迭代。
for _ in range(2):
print next_element.eval(feed_dict={handle: cat_handle})

dog_iterator.initializer.run()
for _ in range(2):
print next_element.eval(feed_dict={handle: dog_handle})

reinitializable iterator 和 feedable iterator 的区别:

二者都是实现切换数据集的功能,但是 reinitializable iterator 是同一个迭代器绑定到不同数据集,切换数据集的时候要重新绑定,然后初始化这个迭代器,特点是切换数据集后从头开始迭代切换到的新数据集

feedable iterator 是针对每个数据集先建立对应的迭代器,然后选择用哪一个迭代器来选择对应的数据集,如果这些迭代器在创建的时候先初始化好,那么在迭代器之间来回切换的时候,各自的数据集是接着前一次的时间点继续输出的。比如训练集特别大,并不想训练集完全迭代一轮再验证,迭代一定数量的较少样本后就想测试一次验证集,就可以用这种迭代器。

4.2.3 一些应用

(1) 读取 npy 文件:

可以直接用 from_tensor_slices 一次载入:

1
2
3
4
5
6
data = np.load('data.npy')
features = data["features"]
labels = data["labels"]

dataset = tf.data.Dataset.from_tensor_slices((features, labels))
...

缺点是:整个 npy 解析出来的数组以 tf.constant() 的形式存在于 Graph 中,而一个 Graph protobuf 文件的上限大小为 2G。

可以使用 initializable_iterator 配合 placeholder 载入:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
data = np.load('data.npy')
features = data["features"]
labels = data["labels"]

# 创建对应的 placeholder
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
... 一些其他预处理操作等 ...

iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={feature_placeholder: features, labels_placeholder: labels})

上面方案其实也不好。上面两个是官方提供的方案。

Stack Overflow 上提供了其他方案:

比如这个方案计算出 npy 的 header_bytes 长度,然后用 tf.data.FixedLengthRecordDataset 来解析 npy 二进制文件。如果要并行处理,那么多个 npy 的 header_bytes 长度要一样。

另外,可以用 tf.py_func 来调用 np.load() 函数,然后放到 Dataset.map() 函数里面解决:

1
2
3
4
5
6
7
8
9
file_list = ['a.npy', 'b.npy', ...]

def read_npy_file(file):
return np.load(file).as_type(np.float32)

dataset = tf.data.Dataset.from_tensor_slices(file_list)
dataset = dataset.map(lambda file: tuple(tf.py_func(
read_npy_file, [file], [tf.float32])))
...

其他的一些应用可以参考 Tensorflow Dataset Guide – Importing Data,里面有从 Python 迭代器得到 Dataset,读取 csv 等文件,进行 padding 操作,利用 tf.py_func 调用 opencv 等。

4.2.4 Inception 数据输入框架 (新 API)

用 tf.data 重写 3.2.6 部分:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import tensorflow as tf

IMAGE_SIZE = 299 # 输入图片大小
BATCH_SIZE = 50
SHUFFLE_BUFFER = 10000
NUM_EPOCHES = 100

# 列举 tfrecord 文件名
training_files = tf.train.match_filenames_ones('training_path/data.tfrecords-*')
val_files = tf.train.match_filenames_ones('val_path/data.tfrecords-*')

def tfrecord_parser(record):
features = tf.parse_single_example(
record, features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'channels': tf.FixedLenFeature([], tf.int64)
}
)
label, height, width, channels = features['label'], features['height'], features['width'], features['channels']
decoded_image = tf.decode_raw(features['image'], tf.uint8)
decoded_image = tf.reshape(decoded_image, tf.stack([height, width, channels]))
return decoded_image, label

# 定义 traing dataset 和对应的迭代器
# dataset = tf.data.TFRecordDataset(training_files) 这个效率低
training_dataset = tf.data.Dataset.from_tensor_slices(training_files)
training_dataset = training_dataset.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=4)
training_dataset = training_dataset.shuffle(SHUFFLE_BUFFER)
training_dataset = training_dataset.repeat(NUM_EPOCHES)
training_dataset = training_dataset.map(tfrecord_parser, num_parallel_calls=2)
training_dataset = training_dataset.map(lambda image, label: (preprocess_for_train(image, IMAGE_SIZE, IMAGE_SIZE), label), num_parallel_calls=4)
training_dataset = training_dataset.batch(BATCH_SIZE)
training_dataset = training_dataset.prefetch(5)

training_iterator = training_datasetmake_initializable_iterator()
training_image_batch, training_label_batch = training_iterator.get_next()

# 网络训练部分
logit = inference(training_image_batch)
loss = calc_loss(logit, training_label_batch)
train_step = ...

# 创建 validation dataset 和对应的迭代器
val_dataset = tf.data.TFRecordDataset(val_files)
val_dataset = val_dataset.map(tfrecord_parser)
val_dataset = val_dataset.map(lambda image, label: (preprocess_for_val(image, IMAGE_SIZE, IMAGE_SIZE), label), num_parallel_calls=2)
val_dataset = val_dataset.batch(BATCH_SIZE)

val_iterator = val_dataset.make_initializable_iterator()
val_image_batch, val_label_batch = val_iterator.get_next()

# 网络测试部分
val_logit = inference(val_image_batch)
predictions = tf.argmax(val_logit, axis=-1, output_type=tf.int32)

with tf.Session() as sess:
sess.run(tf.global_variables_initializer(), tf.local_variables_initializer())

sess.run(training_iterator.initializer)
while True:
try:
sess.run(train_step)
except tf.errors.OutOfRangeError:
break

sess.run(val_iterator.initializer)
val_results = []
val_labels = []
while True:
try:
pred, label = sess.run([predictions, val_label_batch])
val_results.extend([pred])
val_labels.extend([label])
except tf.errors.OutOfRangeError:
break

# 计算准确率等指标
correct = [float(y == y_) for (y, y_) in zip(val_results, val_labels)]
acc = sum(correct) / len(correct)
觉得文章不错,就赏我一杯咖啡钱吧~