本篇笔记主要总结了如何在 TensorFlow 如何构建高效的 Input Pipeline,目的是协调 CPU 文件预处理和 GPU 模型计算之间的调度,尽最大限度发挥 GPU 算力。其中涉及到 TFRecord 文件的读写,tf.image 模块对图像的处理,以及版本 1.4 前使用的生产者/消费者多线程文件读写流程,和 1.4 后官方主推的 Dataset 处理方式。后者已经开始逐步支持 eager 模式。
吐槽部分:这一部分学完后感触很深,tf 真的很难,它已经从一个 Python 库上升到了一个十分复杂的编程语言的高度,让人抓狂。大部分 API 文档是根据代码注释生成的,往往晦涩难懂,由于时间和精力原因去深入庞杂的 C++ 底层几乎不可能。由此带来的最大困扰是,you seldom know what is happening under the hood。比如有多少人知道 tf.image 模块的 resize 的插值方法中,align corners 和主流图像库处理的方式都不一样?在学习这部分的过程中,我翻阅了大量博客,Github Issue,Stack Overflow 回答,多次对着源码逐行分析并进行代码测试,最终总结成个人笔记可谓是“满纸荒唐言,一把辛酸泪”。不过,也正是学习这部分的钻研过程中,感慨 tf 真的是为了实际应用做了充分的考虑,正如 caffe 作者贾扬清在知乎说道: "TF的确难,但是它给你提供了真正可以产品化的可能性。"
先把这部分我个人觉得要注意的一些点(或者说是大坑)列举一下,方便日后查阅:
推荐阅读: CS230 课件 “An overview of tf.data” 部分列举的所有链接。
TFRecord 文件的数据是通过 tf.train.Example 这个 protobuf 的格式存储的。其定义在 tensorflow/core/example/example.proto 和 tensorflow/core/example/feature.proto :
message Example {
// type: Features, name: features
Features features = 1;
};
message Features {
// Map from feature name to feature.
// Features is a key-value store, where each key (string)
// maps to a Feature message
map<string, Feature> feature = 1;
};
// Containers for non-sequential data.
message Feature {
// Each feature can be exactly one kind.
oneof kind {
BytesList bytes_list = 1;
FloatList float_list = 2;
Int64List int64_list = 3;
}
};
// Containers to hold repeated fundamental values.
message BytesList {
repeated bytes value = 1;
}
message FloatList {
repeated float value = 1 [packed = true];
}
message Int64List {
repeated int64 value = 1 [packed = true];
}
tf.train.Example 中包含名为 features 的 Features 的信息,每个 Features 中包含一个从 feature 名称到 feature 属性 (Feature) 的字典映射(key-value对),其中每个 Feature 的可以从 BytesList (字符串),FloatList (浮点实数列表), Int64List (整数列表)中取。
一个 tf.train.Example 的例子:
features {
feature {
key: "age"
value { float_list {
value: 29.0
}}
}
feature {
key: "movie"
value { bytes_list {
value: "The Shawshank Redemption"
value: "Fight Club"
}}
}
feature {
key: "movie_ratings"
value { float_list {
value: 9.0
value: 9.7
}}
}
}
样例程序:
import tensorflow as tf
import os
import cv2
# 输出的 TFRecord 文件带路径的名称
output = './output.tfrecords'
# 创建一个 writer 来写入 TFRecord 文件
writer = tf.python_io.TFRecordWriter(output)
# 辅助函数,生成整数和字符串型的 tf.train.Feature。
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# 读取图片,并保存到 TFRecord 文件中
img_dir = './img'
imgs = os.listdir(img_dir)
imgs.sort()
for index, img in enumerate(imgs):
img_path = os.path.join(img_dir, img)
img_data = cv2.imread(img_path)
resized_img = cv2.resize(img_data, (128, 128, interpolation=cv2.INTER_AREA))
# 这里必须把 ndarry 转换成字符串形式的原始二进制数据流
img_raw = resized_img.tostring()
# 生成 Example Protobuf 文件
example = tf.train.Example(features=tf.train.Features(feature={
'shape_': _int64_feature(resized_img.shape[0]),
'label_': _int64_feature(index),
'img_raw': _bytes_feature(img_raw)
}))
# 将序列化后的example 写入 TFRecord 文件
writer.write(example.SerializeToString())
writer.close()
import tensorflow as tf
# 注意默认 shuffle = True
# 返回一个队列 Queue 对象
filename_queue = tf.train.string_input_producer(['./output.tfrecords'], shuffle=False)
# 创建一个 reader 读取 TFRecord 文件中的样例
reader = tf.TFRecordReader()
# 一次读取一个样例。也可以使用 read_up_to 函数一次读取多个样例
# Returns the next record (key, value) pair produced by a reader.
_, serialized_example = reader.read(filename_queue)
# 解析单个 features 文件; 解析多个用 parse_example 函数
features = tf.parse_single_example(
serialized_example,
features={
'shape_': tf.FixedLenFeature([], tf.int64),
'label_': tf.FixedLenFeature([], tf.int64),
'img_raw': tf.FixedLenFeature([], tf.string)
})
# 以上 tf.FixedLenFeature 方法解析的结果为一个 tensor。另一种方法是 tf.VarLenFeature, 解析的
# 结果是 SparseTensor
# 将字符串 tensor 解析成数据 tensor,注意此时为一维数据,需要 reshape
image = tf.decode_raw(features['img_raw'], tf.uint8)
image_shape = tf.stack([shape_, shape_, 3]) # 这一行不能少
image = tf.reshape(image, image_shape)
# 默认是 int64 的 tensor, 转成 int32
shape = tf.cast(features['shape_'], tf.int32)
label = tf.cast(features['label_'], tf.int32)
sess = tf.Session()
# 多线程部分参见第3部分
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(4):
print sess.run([image, shape, label])
读取示例:
tf.InteractiveSession()
image_raw_data = tf.gfile.FastGFile('./imgs/img1.png', 'rb').read()
# 一定要明确指定3通道,默认自动识别好像有bug,不生效
img_data = tf.image.decode_png(image_raw_data, channels=3)
# (732, 808, 3) uint8
# 格式: [H, W, C]
print img_data.eval().shape, img_data.eval().dtype
注意:
(1) tf.gfile 模块为 TensorFlow 的文件读写模块,C++ 实现,API 与 Python 自带文件模块高度相近。区别在于,tf.gfile 实现了多种文件系统的读写,比如本地文件,谷歌云存储文件(前缀 gs://),HDFS 文件(前缀 hdfs://) 等等。TensorFlow 中写入载入 checkpoints,TensorBoard 日志文件等都是用 tf.gfile 模块实现。
简言之,tf.gfile 实现了更多文件读写接口,在处理日常本地文件读写时,tf.gfile 并没有明显的速度优势,用 python 自带文件读写模块就可以了。
(2) tf.gfile.GFile, tf.gfile.FastGFile 二者在r1.8源码实现上并无区别,都是没有线程锁的文件 I/O,所以二者等同。
(3) 图像编码时的搭配:
tf.decode_raw()
函数;如上面的例子,可以改成下面的写法:
image = cv2.imread('./imgs/img1.png')
print image.shape # (732, 808, 3)
image_raw_data = tf.decode_raw(image.tobytes(), tf.uint8)
print image_raw_data.eval().shape # (1774368,) // 1774368 = 732 * 808 * 3
注意 tf.decode_raw() 里面第二个形参 out_type 一定要正确,否则输出的数据的维度就不对。
(4) tf.image 中解码图像的 API:
上面的函数 tf.image.decode_png 等返回的 tensor 是有静态 shape的,而下面的 tf.image.decode_image 由于使用了 tf.cond 判断图片类型,因此返回的 tensor 没有静态 shape,造成它解码的图片无法与tf.image.resize_images() 一起使用;
现在 tf.image.decode_png,tf.image.decode_jpeg 已经能读取所有图片文件类型,和非动态 gif 文件了;
务必明确指定 tf.image.decode_png() 等函数中的 channels,想要 RGB 图务必指定为 3,默认的0在 1.8 版本并未生效。
--> 总结: 使用 tf.image.decode_png() 函数,不要使用 tf.image.decode_images() 函数,注意 channels 参数手动指定。返回的数据类型为 tf.uint8。
(5) 图像的存储 API:
encoded_image = tf.image.encode_png(img_data)
with tf.gfile.FastGFile('test.png', 'wb') as f:
f.write(encoded_image.eval())
函数第一个参数为图片数据,shape: [height, width, channels],数据类型必须为 uint8。同时, 两个函数中提供了图片品质压缩参数,encode_png 为 compression,0-9 之间取值,数值越大压缩越严重;encode_jpeg 为 quality,0-100之间取值,数值越大质量越好。
说明:后续的图像操作,很多只接受浮点图像数据,有些先把图像转成浮点,处理完成后再转为原来的数据类型;如果有多个图像处理操作,来回在 uint8 和 float32 之间的转换会导致精度损失,因此建议在图像处理之前先统一转换成 float32 类型:
img_data = tf.image.convert_image_dtype(img_data, tf.float32)
输入:
images: 4D 的 [N, H, W, C] 或 3D 的 [H, W, C] 数据。因此这个函数支持批处理。
size: [new_height, new_width]
method: 默认双线性插值。可选0-3。0:双线性插值;1:最近邻法;2:双三次插值;3:面积插值
align_corners: 角度是否对齐。`记得务必设置为 True`
说明: tf 中的图片 resize 和 opencv,PIL 等主流库的实现不一样。opencv 等主流库在插值计算时,对齐时是把每个像素看做一个“点区域”,因此用的是中心点对齐,有个 0.5 的偏移计算设置,可参考 CSDN博客;而 tf 在实现的时候,每个像素就是一个点,align_corners 设置为 False 就是上述博客中没加偏移的情况,显然不合理;align_corners 设置为 True 就是对齐四个角顶点,连续插值。个人觉得 tf 的设置 align_corners=True 更加合理。代码的区别可参加图: https://i.loli.net/2018/08/16/5b755dc364464.png。这一部分的相关讨论参见: https://github.com/tensorflow/tensorflow/issues/6720。
若要使用 opencv 的 resize 函数,需使用 tf.py_func 包装起来:
# img_data is a tensor
img = tf.py_func(lambda input: cv2.resize(input, (4, 4)), [img_data], tf.float32, stateful=False)
print img.eval()
输入: image: 4D 的 [N, H, W, C] 或 3D 的 [H, W, C] 数据。因此这个函数支持批处理。
生成一个大小为 [height, width] 的图,图标图像小于原图就裁剪,否则周围填0。注意只是裁剪或填充,没有插值。
输入的 image 为 3D tensor。因此不能批处理。central_fraction为 (0, 1] 之间的浮点数。
做中心裁剪,central_fraction 为长和宽裁剪出的比例。比如 [100, 100] 的原图,central_fraction=0.5,那么输出[50, 50] 大小的图。
输入:
image: 4D 的 [N, H, W, C] 或 3D 的 [H, W, C] 数据。因此这个函数支持批处理。
offset_h, offset_w: 裁剪区域的左上角坐标
target_h, target_w: 输出区域的大小
输入:
image: 4D 的 [N, H, W, C] 或 3D 的 [H, W, C] 数据。因此这个函数支持批处理。
offset_h, offset_w: 原图上面要补充多少行0,原图左侧要补充多少列0
target_h, target_w: 输出区域的大小。多出的区域一律补0
tf.image.flip_up_down(image): 垂直翻转
tf.image.random_flip_up_down(image): 随机垂直翻转
tf.image.flip_left_right(image): 左右翻转
tf.image.random_flip_left_right(image): 随机左右翻转
tf.image.transpose_image(image): 沿对角线翻转。其实就是矩阵转置。
tf.image.rot90(image, k=1): 沿逆时针旋转 k 个 90 度。
以上函数均支持单个图像处理和批量处理。
说明:
(1) 以上前四个均有 random 函数,如: tf.image.random_brightness, tf.image.random_contrast, 具体参见 API。
(2) 在做一连串的调整操作后,图像的数值分布可能已经越界了。因此在最后一步操作后记得做有效截断:
result = tf.clip_by_value(adjusted_image, 0, 1)
输入:
image: 4D 的 [N, H, W, C] 图像。注意数据类型必须为 float。
boxes: [batch, box_num, 4]。bounding box 的坐标,四个点的坐标格式是 [y_min, x_min, y_max, x_max],这四个参数都是 [0, 1] 的数,表示比例。
输出: 带框的图像。
# img: [h, w, c]
batch_img = tf.expand_dims(img, 0) # [1, 400, 800, 3]
result = tf.image.draw_bounding_boxes(batch_img, boxes=[[[0.1, 0.3, 0.5, 0.7]]])
plt.imshow(result.eval()[0])
plt.show()
原图上框的位置: 左上角[240, 40],右下角[560, 200]
注意,框越界并不会报错。
输入:
image_size: [h, w, c] 原图的 shape。
bounding_boxes: [batch, box_num, 4], ground_truth 的位置信息。坐标格式仍然是 [y_min, x_min, y_max, x_max],四个参数均为 [0, 1] 之间的浮点数,表示比例。
min_object_covered: 默认 0.1。提取的区域至少包含某个 gt 标注框的比例的百分比。
aspect_ratio_range: list, 默认 [0.75, 1.33]。提取区域的宽高比的范围 (ratio = width / height)。
max_attempts: 默认100。最大尝试次数。
use_image_if_no_bounding_boxes: 默认 False。不提供标记框时是否返回原图。
输出:元组 (begin, size, bboxes):
begin: [offset_height, offset_width, 0]。输出区域起点,即左上角的坐标。
size: [target_height, target_width, -1]。输出区域的大小。
bboxes: 输出框的坐标,shape: [1, 1, 4]。
其中输出的 begin 和 size 可作为 tf.slice 的输入,bboxes 可作为 tf.image.draw_bounding_boxes 的输入。
# img: [h, w, c]
begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(tf.shape(img),[[[0.1, 0.3, 0.5, 0.7]]], min_object_covered=0.4)
# distorted 为随机提取出来的图像区域
distorted = tf.slice(resized, begin, size)
**小坑注意:**上面的例子中,由于 tf.slice 收到的 size 中最后一个维度的数是 -1, 意思是多个通道全要提取,具体有几个通道是动态的(要根据输入来定),因此导致输出的 distorted 最后一个维度也是动态的。
print distorted.shape # (?, ?, ?)
result = tf.image.resize_images(distorted, [200, 200])
print result.shape # (200, 200, ?)
由于后面对 distorted 的操作一般不会涉及到图像通道数,为了图像的维度的 shape 能正常获取,最好在 tf.slice 后手动设定一下 shape:
# Restore the shape since the dynamic slice based upon the bbox_size loses
# the third dimension.
distorted.set_shape([None, None, 3])
print distorted.shape # (?, ?, 3)
result = tf.image.resize_images(distorted, [200, 200])
print result.shape # (200, 200, 3)
注意,以上操作均需要图片的数据类型为 tf.float32
这部分的官方代码在 slim/preprocessing/inception_preprocessing.py。
需要注意的两点:
(1) distorted_bounding_box_crop 函数的 min_object_covered 默认取 0.1,根据具体数据集调整这个参数比较合适。
(2) resize 都没有 align_corners.
import tensorflow as tf
from tensorflow.python.ops import control_flow_ops
def apply_with_random_selector(x, func, num_cases):
"""
从 func(x, 0), func(x, 1), ..., func(0, num_cases-1) 中随机选一个
"""
sel = tf.random_uniform([], maxval=num_cases, dtype=tf.int32)
# merge(inputs): 依次判断 inputs 中的元素是否存在, 返回第一个存在的元素的[data, data_index]
# switch(data, pred): 返回 (output_false, output_true), 如果 pred 为 true, output_true = data, output_false
# 不存在; 反过来也一样
return control_flow_ops.merge([
func(control_flow_ops.switch(x, tf.equal(sel, case))[1], case)
for case in range(num_cases)])[0]
def distort_color(image, color_ordering=0, fast_mode=True, scope=None):
"""
对图片进行随机色彩变换:调整亮度、饱和度、色相、对比度。
Args:
image (float): [0, 1] 之间的 3D tensor 图像
color_ordering (int, optional): Defaults to 0. 可取 0-3,代表不同的随机变换模式。
fast_mode (bool, optional): Defaults to True. 快速模式下不采用调整色相和对比度的变换。
scope (optional): Defaults to None.
Returns:
颜色变换处理后的 float32 图像。
"""
with tf.name_scope(scope, 'distort_color', [image]):
if fast_mode:
if color_ordering == 0:
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
else:
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_brightness(image, max_delta=32. / 255.)
else:
if color_ordering == 0:
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.2)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
elif color_ordering == 1:
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.2)
elif color_ordering == 2:
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.image.random_hue(image, max_delta=0.2)
image = tf.image.random_brightness(image, max_delta=32. / 255.)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
elif color_ordering == 3:
image = tf.image.random_hue(image, max_delta=0.2)
image = tf.image.random_saturation(image, lower=0.5, upper=1.5)
image = tf.image.random_contrast(image, lower=0.5, upper=1.5)
image = tf.image.random_brightness(image, max_delta=32. / 255.)
else:
raise ValueError('color_ordering must be in [0, 3]')
# 有效截断
return tf.clip_by_value(image, 0.0, 1.0)
# min_object_covered 默认为 0.1,可以根据实际任务稍微改大一点
def distorted_bounding_box_crop(image,
bbox,
min_object_covered=0.1,
aspect_ratio_range=(0.75, 1.33),
area_range=(0.05, 1.0),
max_attempts=100,
scope=None):
with tf.name_scope(scope, 'distorted_bounding_box_crop', [image, bbox]):
sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
tf.shape(image),
bounding_boxes=bbox,
min_object_covered=min_object_covered,
aspect_ratio_range=aspect_ratio_range,
area_range=area_range,
max_attempts=max_attempts,
use_image_if_no_bounding_boxes=True)
bbox_begin, bbox_size, distort_bbox = sample_distorted_bounding_box
cropped_image = tf.slice(image, bbox_begin, bbox_size)
return cropped_image, distort_bbox
def preprocess_for_train(image, height, width, bbox,
fast_mode=True,
scope=None,
add_image_summaries=True):
"""
训练集数据预处理。
Args:
image: 输入图像, uint8 或者 float32(都会被转换为 float32). shape: [H, W, C]
bbox: shape: [1, num_boxes, coords]. coords: [ymin, xmin, ymax, xmax]. 为 None 时取原图整图.
fast_mode: resize 插值和颜色变换是否采用快速模式。
add_image_summaries: 是否画出画出中间处理过程得到的图。
Returns:
3D float 图像。范围在 [-1, 1] 之间。
"""
with tf.name_scope(scope, 'distort_image', [image, height, width, bbox]):
# bbox 为空取原图
if bbox is None:
bbox = tf.constant([0.0, 0.0, 1.0, 1.0],
dtype=tf.float32,
shape=[1, 1, 4])
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
bbox)
if add_image_summaries:
tf.summary.image('image_with_bounding_boxes', image_with_box)
distorted_image, distorted_bbox = distorted_bounding_box_crop(image, bbox)
# Restore the shape since the dynamic slice based upon the bbox_size loses
# the third dimension.
distorted_image.set_shape([None, None, 3])
image_with_distorted_box = tf.image.draw_bounding_boxes(
tf.expand_dims(image, 0), distorted_bbox)
if add_image_summaries:
tf.summary.image('images_with_distorted_bounding_box',
image_with_distorted_box)
# 快速模式下: resize 采取双线性插值,否则是选择随机方法插值。
num_resize_cases = 1 if fast_mode else 4
distorted_image = apply_with_random_selector(
distorted_image,
lambda x, method: tf.image.resize_images(x, [height, width], method),
num_cases=num_resize_cases)
if add_image_summaries:
tf.summary.image('cropped_resized_image',
tf.expand_dims(distorted_image, 0))
# 随机左右翻转
distorted_image = tf.image.random_flip_left_right(distorted_image)
# 调用 distort_color 随机做颜色变换,选择是否采用快速模式。
num_distort_cases = 1 if fast_mode else 4
distorted_image = apply_with_random_selector(
distorted_image,
lambda x, ordering: distort_color(x, ordering, fast_mode),
num_cases=num_distort_cases)
if add_image_summaries:
tf.summary.image('final_distorted_image',
tf.expand_dims(distorted_image, 0))
# [0, 1] 之间的图,变为 [-1, 1] 之间。
distorted_image = tf.subtract(distorted_image, 0.5)
distorted_image = tf.multiply(distorted_image, 2.0)
return distorted_image
def preprocess_for_eval(image, height, width,
central_fraction=0.875, scope=None):
"""
验证集数据预处理。做中心裁剪,双线性插值 resize,数值范围变换到 [-1, 1] 之间。
Returns:
3D float 图像。范围在 [-1, 1] 之间。
"""
with tf.name_scope(scope, 'eval_image', [image, height, width]):
if image.dtype != tf.float32:
image = tf.image.convert_image_dtype(image, dtype=tf.float32)
# 取出中间 0.875 的区域
if central_fraction:
image = tf.image.central_crop(image, central_fraction=central_fraction)
if height and width:
# resize 图像,采用双线性插值
image = tf.expand_dims(image, 0)
image = tf.image.resize_bilinear(image, [height, width],
align_corners=False)
image = tf.squeeze(image, [0])
# 取值范围变为 [-1, 1] 之间
image = tf.subtract(image, 0.5)
image = tf.multiply(image, 2.0)
return image
def preprocess_image(image, height, width,
is_training=False,
bbox=None,
fast_mode=True,
add_image_summaries=True):
"""
训练集和验证集图像预处理。
Args:
image: 输入图像, uint8 或者 float32(都会被转换为 float32). shape: [H, W, C]
is_training: 训练集还是测试集
bbox: 标记框,默认 None 表示取原图整图
fast_mode: resize 和颜色变换是否采用快速模式
add_image_summaries: 中间处理后的图是否画图。
Returns:
3D float 图像。范围在 [-1, 1] 之间。
"""
if is_training:
return preprocess_for_train(image, height, width, bbox, fast_mode,
add_image_summaries=add_image_summaries)
else:
return preprocess_for_eval(image, height, width)
其他的如 VGG 预处理,可参见: vgg_preprocessing.py,VGG 裁剪图片时是确定短边长度后再等比例 resize.
TensorFlow 中提供的队列有:
这些类型的 Queue 的 API 大致差不多,以 FIFOQueue 为例:
tf.FIFOQueue(capacity, dtypes, shapes=None): capacity 为队列容量,dtypes 为队列中元素的数据类型,shapes 为队列中元素的 shape。
-> 常用函数:
- dequeue(): 出队一个元素
- dequeue_many(n): 出队 n 个元素。使用此函数必须手动指定 shapes。
- enqueue(): 入队一个元素
- enqueue_many(val): 入队多个元素. 注意这里 val 要比基元多一维。
- size(): 返回队列中元素多少。返回数据类型为 tensor。
注意: 队列中存储了 capacity 个基元,以 list 形式存在。当基元中包含多个数据时,dtypes 是个 list,长度与基元长度要相同。若要使用 dequeue_many(n),shape 必须手动指定。
q = tf.FIFOQueue(5, tf.int32, shapes=[()])
op1 = q.enqueue_many([[1, 2]])
op1.run()
op2 = q.enqueue([3])
op2.run()
print q.size().eval() # 3
print q.dequeue().eval() # 1
print q.dequeue_many(2).eval() # [2, 3]
TODO(20180817): 当基元有多个元素时,enqueue_many 表现很奇怪,尚未弄懂。因此暂时避免使用 enqueue_many 和 dequeue_many。
该类的常用方法:
使用方法:
try:
coord = tf.train.Coordinator()
... codes of creating threads ...
coord.join(threads)
except Exception as e:
... some codes ...
上面创建线程的代码为:
try:
while not coord.should_stop():
... some work ...
except Exception as e:
coord.request_stop(e)
coord.join(threads)
可以用 stop_on_exception () 简化上面创建线程的代码:
with coord.stop_on_exception():
while not coord.should_stop():
... some work ...
coord.join(threads)
tf 中的多线程使用的队列启动方案。与 tf.train.Coordinator 一起使用。
比如一个经典的文件输入流程: 第一批线程通过往第一个队列里面不断填充要处理的文件名;第二批线程从前面的队列取出文件名,然后进行读取处理等操作,得到的张量放在第二个队列;第三批线程从第二个队列中取出张量,组成 batch,输入网络进行训练。
初始化方法:
__init__(queue=None, enqueue_ops=None)
其中 queue 为要操作的队列,enqueue_ops 为要对该队列执行的多线程操作。
tf.train.QueueRunner 常与以下两个类一起使用:
一个完整的例子:
# 创建队列,入队操作
queue = tf.FIFOQueue(100, tf.float32)
enqueue_op = queue.enqueue([tf.random_normal([])])
# 开启 5 个线程
qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)
tf.train.add_queue_runner(qr)
out_tensor = queue.dequeue()
with tf.Session() as sess:
coord = tf.train.Coordinator()
enqueue_threads = tf.train.start_queue_runners(sess=sess, coord=coord)
for i in range(5):
print sess.run(out_tensor)
# 终止所有线程
coord.request_stop()
coord.join(enqueue_threads)
当然,上面也可以不使用 tf.train.add_queue_runner 和 tf.train.start_queue_runner 来启动线程。直接用 QueueRunner 的 create_threads(sess, coord=None, daemon=False, start=False) 来启动多个线程:
queue = tf.FIFOQueue(100, tf.float32)
enqueue_op = queue.enqueue([tf.random_normal([])])
# 开启 5 个线程
qr = tf.train.QueueRunner(queue, [enqueue_op] * 5)
out_tensor = queue.dequeue()
with tf.Session() as sess:
coord = tf.train.Coordinator()
# 手动启动 qr 负责的入队线程
enqueue_threads = qr.create_threads(sess=sess, coord=coord, start=True)
for i in range(5):
print sess.run(out_tensor)
coord.request_stop()
coord.join(enqueue_threads)
tf.train.match_filenames_once(pattern): 根据正则表达式获取符合要求的文件名列表。返回一个局部变量,为文件名列表。因此使用这和函数务必初始化局部变量。
tf.train.string_input_producer(string_tensor, num_epochs=None, shuffle=True, capacity=32): 根据文件名返回一个文件名队列,供多线程使用。string_tensor: 装有文件名的 tensor,可由 match_filenames_onces 函数返回,也可用 Python glob 生成等;num_epochs: 加载文件列表的最大轮数,默认无限循环;shuffle: 文件名加入队列前是否打乱;capacity: 队列长度。
注意: 从源码中发现,该函数创建了局部变量(对 epoch 进行计数),因此使用时要初始化局部变量(其实不初始化也可以,就不使用 epoch 这个变量)。该函数返回一个 FIFOQueue 用于存放文件名,并生成一个 QueueRunner 进行单线程入队操作,该 QueueRunner 放在默认的 tf.GraphKeys.QUEUE_RUNNERS 这个 collection 中。
在指定 num_epoches (如测试时指定为 1 ) 时,队列为空后继续出队,抛出 OutOfRange 异常。
sess = tf.InteractiveSession()
# 获取所有 png 图片文件列表
files = tf.train.match_filenames_once('some_path/*.png')
filename_queue = tf.train.string_input_producer(files, shuffle=False)
coord = tf.train.Coordinator()
thread = tf.train.start_queue_runners(sess, coord)
tf.local_variables_initializer().run()
print filename_queue.dequeue().eval() # 获取一个文件名
前面 tf.train.string_input_producer 生成了文件名队列,tf 通过各种 Reader 从这个文件名队列中取文件名,进行文件读取解析。常用的 Reader 有:
它们的 API 大致相同。常用的方法有:
用 Reader 实现 2.1 节的图片读取:
files = tf.train.match_filenames_once('some_path/*.png')
filename_queue = tf.train.string_input_producer(files, shuffle=False)
key, value = tf.WholeFileReader().read(filename_queue)
image = tf.image.decode_png(value, channels=3)
with tf.Session() as sess:
sess.run(tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
print sess.run(image)
单线程,效率较低。
一个高效的数据 pipeline 应该是一个生产者——消费者模型,即生产者是一个文件名队列,里面存储要处理的文件的名字,这个队列用单线程处理即可;消费者为多线程从生产者队列里面取出文件名,进行文件读取和预处理,然后放置在另一个队列,最终的数据从这个队列中取出。
一个完整的流程如图:
TensorFlow 中提供的对应的 batch 数据的函数为:
注意: 以上 dynamic_pad 为 False 时,传入的 tensors 必须显式确定,否则抛出异常。
一般我们的文件不止一个,比如 TensorFlow Performance Guide 建议,把大数据文件分割成多个约为 100 MB 的 TFRecord 文件,I/O 性能比较好。这种情况下,多文件,多线程进行读取和预处理操作应该用上面两个函数。
其中,输入的 tensors_list 为 a list of tuples of tensors,创建 len(tensors_list) 个线程,每个线程读取一个文件,然后压入队列:
# features 为解析的 TFRecord 文件
image, label = features['image'], features['label']
# 1. 使用 tf.train.batch:
# 返回的 iamge 是个 [N, H, W, C] 的 tensor
# 记住 image 和 label 要一起 run,不然就错位交叉了
image, label = tf.train.batch([image, label], batch_size=10, num_threads=1, capacity=100)
# 2. 使用 tf.train.batch_join:
image, label = tf.train.batch_join([[image, label] for _ in range(4)], batch_size=10, capacity=100)
比较: tf.train.batch 是多线程读取一个文件,tf.train.batch_join 是多线程读取多个文件,每个线程负责一个文件。如果同一个文件中样本相似,用 tf.train.batch 显然不合适;使用 tf.train.batch_join() 时,如果线程数大于文件数,那么也存在多个线程读取同一个文件的情况,而且多线程读多个文件的硬盘寻址也是有时间开销的,可能会让效率变低。
# coding: utf-8
import tensorflow as tf
files = tf.train.match_filenames_ones('some_path/data.tfrecords-*')
# 训练集,文件名打乱
filename_queue = tf.train.string_input_producer(files, shuffle=True, )
_, serialized_example = tf.TFRecordReader().read(filename_queue)
features = tf.parse_single_example(serialized_example, features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'channels': tf.FixedLenFeature([], tf.int64)
})
image, label, height, width, channels = features['image'], features['label'], features['height'], features['width'], features['channels']
decoded_image = tf.decode_raw(image, tf.uint8)
decoded_image = tf.reshape(decoded_image, tf.stack([height, width, channels]))
image_size = 299
# 前面的 inception 预处理代码
distorted_image = preprocessing_for_train(decoded_image, image_size, image_size, None)
# 这一部分可以具体调整
num_threads = 10
batch_size = 50
min_after_dequeue = 1000
# as suggested by https://www.tensorflow.org/api_guides/python/reading_data#Batching
capacity = min_after_dequeue + (num_threads + 10) * batch_size
image_batch, label_batch = tf.train.shuffle_batch([distorted_image, label], batch_size=batch_size,
capacity=capacity, min_after_dequeue=min_after_dequeue, num_threads=num_threads)
# tf.train.shuffle_batch_join 的方案
# image_batch, label_batch = tf.train.shuffle_batch([[distorted_image, label] for i in range(num_threads)], batch_size=batch_size,
# capacity=capacity, min_after_dequeue=min_after_dequeue)
logit = inference(image_batch)
loss = calc_loss(logit, label_batch)
train_step = ...
with tf.Session() as sess:
sess.run(tf.global_variables_initializer(), tf.local_variables_initializer())
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(sess=sess, coord=coord)
with coord.stop_on_exception():
while not coord.should_stop():
for i in range(training_epoches):
sess.run(train_step)
coord.join(threads)
tf 1.4 后的数据输入框架,抛弃队列处理的旧 API,使用 Dataset 数据集提供数据的输入。这一部分比较简单,官方文档很详细。
(1) 利用数据集的基本步骤:
(2) 常用的 Dataset:
(3) Dataset 常用的属性:
(4) Dataset 常用的方法:
batch(batch_size): 按 batch_size 取一个 batch。最后不够一个 batch_size 也会被取出来,如果不要最后的零头,使用 tf.contrib.data.batch_and_drop_remainder 方法。
padded_batch(batch_size, padded_shapes, padding_values=None): 生成 padded batch。
shuffle(buffer_size, seed=None, reshuffle_each_iteration=None): 打乱数据集。buffer_size: 缓冲区大小。默认 reshuffle_each_iteration 为 True。因此,每迭代一个元素出来,缓冲区中填充一个新元素,shuffle 一次。buffer_size 越大,打乱性能越好,但是第一次启动的时间就较长。可设置 buffer_size 为数据集大小,这样打乱充分。参考链接。
repeat(count=None): 把 Dataset 重复 count 次。None 或 -1 表示无限循环。
skip(count): 跳过前 count 个元素。
take(count): 读取前 count 个元素。
shard(num_shards, index): 生成一个只包含原 Dataset 的 1/num_shards 的新 Dataset。
from_tensors(tensors): 根据单个元素构建 Dataset。
from_tensor_slices(tensors): 根据元素切片构建 Dataset。
from_generator(generator, output_types, out_shapes=None): 根据生成器构建 Dataset。这一部分使用 tf.py_func 实现的。
map(map_func, num_parallel_calls=None): map 运算,返回一个 map 运算后的 Dataset。
flat_map(map_func): map_func 后再 flat 展平,返回一个 Dataset。
filter(predicate): filter 运算,返回一个 Dataset。
interleave(map_func, cycle_length, block_length=1): 适用于分布式文件系统。当有多个文件时,一次对 cycle_length 个文件同时读取,block_length 是每个线程输出元素的个数。因此,map 和 flat_map 相当于 tf.train.batch,interleave 相当于 tf.train.batch_join。
apply(transformation_func): 对一个 Dataset 进行某个操作,类似于 map。
zip(datasets): 和 Python 中的 zip 函数一样,把多个 Dataset zip起来。
concatenate(dataset): 串接一个 Dataset, 返回一个新的合并的 Dataset。
prefetch(buffer_size): 预加载 buffer_size 的数据。
range(*args): 生成一个 RangeDataset。
list_files(file_pattern, shuffle=None): 根据文件名 file_pattern 正则表达式,获取一个包含这些文件的 Dataset。注意,这个顺序是不定的,即使 shuffle 为 False。看源码发现这个函数其实是先用 tf.matching_files 得到文件列表,在用 from_tensor_slices 得到一个 dataset,最后 shuffle。注意,最后一个 shuffle 默认的缓冲区为整个文件名列表。因此,由于 from_tensor_slices 和 shuffle 缓冲区长度的设定,当文件名列表过于巨大时,这一步耗时就会很大,可参考 Github。
(5)迭代器:
以一个简单的文本处理为例,假设为分布式文件系统,有 5 个 txt 文件,每个文件里面存放某一类样本的文件名和类别标号,形如:
- file1.txt:
cat1.jpg 0
cat2.jpg 0
...
cat4.jpg 0
- file2.txt
dog1.jpg 1
dog2.jpg 1
...
dog4.jpg 1
...
- file_n.txt
(1) 基本流程,无各种优化考虑:
# 获取文件列表,并按照 file_{n} 中的数字 n 从小到大排序
txt_files = glob.glob('some_path/*.txt')
txt_files.sort(key=lambda x: int(x.split('.')[-2][-1]))
# 使用 TextLineDataset 读取文本文件
dataset = tf.data.TextLineDataset(txt_files)
dataset = dataset.shuffle(buffer_size=20)
dataset = dataset.repeat(2)
# 分割字符串,按空格分割
dataset = dataset.map(lambda x: tf.string_split([x], delimiter=' ').values)
dataset = dataset.batch(2)
# 创建迭代器
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
输出的结果为:
[['bird1.jpg' '3']
['dog4.jpg' '1']]
[['dog2.jpg' '1']
['bird3.jpg' '3']]
...
要注意的几点:
-- 参考 4.1 中设置 shuffle 的 buffer_size
-- shuffle 和 repeat 的顺序,建议先 shuffle 再 repeat,如果反过来会造成打乱效果变差。参考: Dataset Performance Guide。
(2) 基础性能优化: 多线程读取多文件,多线程 map,预加载 prefetch
上面代码的性能缺点:单线程读取,单线程 map,加上 Dataset 默认的 lazy 属性,性能地下。
基础改进代码:
txt_files = glob.glob('some_path/*.txt')
txt_files.sort(key=lambda x: int(x.split('.')[-2][-1]))
# 改进1:多线程读取文件,创建5个线程,每个线程负责一个文件读取。
# 因此改进后可以同时读 5 个文件
dataset = tf.data.Dataset.from_tensor_slices(txt_files)
dataset = dataset.interleave(lambda x: tf.data.TextLineDataset(x),
cycle_length=5, block_length=1)
dataset = dataset.shuffle(buffer_size=20)
dataset = dataset.repeat(2)
# 改进2: 多线程 map 函数,也就是可以多线程预处理了。
dataset = dataset.map(lambda x: tf.string_split([x], delimiter=' ').values,
num_parallel_calls=4)
dataset = dataset.batch(2)
# 改进3:预加载,输出有缓冲区。类似于消费者队列中填充一定 batch 数,等待消费。
# 这里是基于前一步操作后的元素,因此是预加载 5 个 batch。
dataset = dataset.prefetch(5)
...
TODO: 上面的 interleave 同时读取 5 个文件,但是真的是并行吗?按照 tf.data API slides 应该是并行 I/O,但是按照 Dataset Performance Guide 却建议使用 tf.contrib.data.parallel_interleave() 实现真正的并行 I/O,这里有待进一步明确,是否需要改成下面的版本:
dataset = dataset.apply(tf.contrib.data.parallel_interleave(tf.data.TextLineDataset, cycle_length=4))
(3) 进一步优化:
并行化 batch: 当 batch_size 比较大时,取 batch 也是耗时的,因此可以把 map 和 batch 放在一起做多线程,最终改成这样的版本(使用 tf.contrib.data.map_and_batch 函数):
txt_files = glob.glob('some_path/*.txt')
txt_files.sort(key=lambda x: int(x.split('.')[-2][-1]))
# 确保运行在 CPU 上
with tf.device('/cpu:0'):
dataset = tf.data.Dataset.from_tensor_slices(txt_files)
dataset = dataset.apply(tf.contrib.data.parallel_interleave(tf.data.TextLineDataset, cycle_length=4))
dataset = dataset.shuffle(buffer_size=20)
dataset = dataset.repeat(2)
# 改动这里
dataset = dataset.apply(tf.contrib.data.map_and_batch(lambda x: tf.string_split([x], delimiter=' ').values, batch_size=2, num_parallel_batches=4))
dataset = dataset.batch(2)
dataset = dataset.prefetch(5)
其他的优化还有内存 cache 等的考虑等,参考 Dataset Performance Guide 。
注意:一般而言, batch 耗时相对较少,当电脑核心不太够的时候,并行化的 batch 占据线程也不一定是好事,可能速度也会变慢。
(1) initializable iterator: 动态指定 iterator 参数。配合 placeholder 使用。
# 指定文件名的 placeholder
txt_files = tf.placeholder(tf.string, [])
dataset = tf.data.TextLineDataset([txt_files])
dataset = dataset.map(lambda x: tf.string_split([x], delimiter=' ').values)
iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()
iterator.initializer.run(feed_dict={txt_files: './txt_files/file1.txt'})
for i in range(5):
print next_element.eval()
# 切换到另一个 txt 文件
iterator.initializer.run(feed_dict={txt_files: './txt_files/file2.txt'})
for i in range(2):
print next_element.eval()
# 中间重新初始化,迭代器从头开始
iterator.initializer.run(feed_dict={txt_files: './txt_files/file2.txt'})
for i in range(5):
print next_element.eval()
(2) reinitializable iterator: 一个可重复初始化的迭代器,绑定到不同的数据集上去。
# 两个数据集
cat_dataset = tf.data.TextLineDataset(['./txt_files/file1.txt']).map(lambda x: tf.string_split([x], delimiter=' ').values)
dog_dataset = tf.data.TextLineDataset(['./txt_files/file2.txt']).map(lambda x: tf.string_split([x], delimiter=' ').values)
# 同一个 reinitializable iterator
iterator = tf.data.Iterator.from_structure(cat_dataset.output_types, dog_dataset._output_shapes)
next_element = iterator.get_next()
# 迭代器初始化
cat_init_op = iterator.make_initializer(cat_dataset)
dog_init_op = iterator.make_initializer(dog_dataset)
# 迭代器绑定到 cat dataset
cat_init_op.run()
for _ in range(5):
print next_element.eval()
# 迭代器绑定到 dog dataset
dog_init_op.run()
for _ in range(5):
print next_element.eval()
通常可以用同一个迭代器绑定到 training set 和 test_set,不过,用前面的 initializer_iterator 也可实现相同功能。
(3) feedable iterator: 迭代器是可变的,目的是通过选择不同数据集的 Dataset 的迭代器来迭代不同的数据。
# 创建两个 Dataset
cat_dataset = tf.data.TextLineDataset(['./txt_files/file1.txt']).map(lambda x: tf.string_split([x], delimiter=' ').values)
dog_dataset = tf.data.TextLineDataset(['./txt_files/file2.txt']).map(lambda x: tf.string_split([x], delimiter=' ').values)
# 创建两个 Dataset 对应的迭代器
cat_iterator = cat_dataset.make_one_shot_iterator()
dog_iterator = dog_dataset.make_initializable_iterator()
# 创建代表两个 Dataset 的迭代器的 handle tensor
cat_handle = cat_iterator.string_handle().eval()
dog_handle = dog_iterator.string_handle().eval()
# 根据传入的 handle 选择具体的迭代器
handle = tf.placeholder(tf.string, [])
iterator = tf.data.Iterator.from_string_handle(handle, cat_dataset.output_types, cat_dataset.output_shapes)
next_element = iterator.get_next()
while True:
# 这里切换数据集,之后再切回 cat_dataset 是继续之前的数据迭代。
for _ in range(2):
print next_element.eval(feed_dict={handle: cat_handle})
dog_iterator.initializer.run()
for _ in range(2):
print next_element.eval(feed_dict={handle: dog_handle})
reinitializable iterator 和 feedable iterator 的区别:
二者都是实现切换数据集的功能,但是 reinitializable iterator 是同一个迭代器绑定到不同数据集,切换数据集的时候要重新绑定,然后初始化这个迭代器,特点是切换数据集后从头开始迭代切换到的新数据集;
feedable iterator 是针对每个数据集先建立对应的迭代器,然后选择用哪一个迭代器来选择对应的数据集,如果这些迭代器在创建的时候先初始化好,那么在迭代器之间来回切换的时候,**各自的数据集是接着前一次的时间点继续输出的。**比如训练集特别大,并不想训练集完全迭代一轮再验证,迭代一定数量的较少样本后就想测试一次验证集,就可以用这种迭代器。
(1) 读取 npy 文件:
可以直接用 from_tensor_slices 一次载入:
data = np.load('data.npy')
features = data["features"]
labels = data["labels"]
dataset = tf.data.Dataset.from_tensor_slices((features, labels))
...
缺点是:整个 npy 解析出来的数组以 tf.constant() 的形式存在于 Graph 中,而一个 Graph protobuf 文件的上限大小为 2G。
可以使用 initializable_iterator 配合 placeholder 载入:
data = np.load('data.npy')
features = data["features"]
labels = data["labels"]
# 创建对应的 placeholder
features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)
dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
... 一些其他预处理操作等 ...
iterator = dataset.make_initializable_iterator()
sess.run(iterator.initializer, feed_dict={feature_placeholder: features, labels_placeholder: labels})
上面方案其实也不好。上面两个是官方提供的方案。
Stack Overflow 上提供了其他方案:
比如这个方案计算出 npy 的 header_bytes 长度,然后用 tf.data.FixedLengthRecordDataset 来解析 npy 二进制文件。如果要并行处理,那么多个 npy 的 header_bytes 长度要一样。
另外,可以用 tf.py_func 来调用 np.load() 函数,然后放到 Dataset.map() 函数里面解决:
file_list = ['a.npy', 'b.npy', ...]
def read_npy_file(file):
return np.load(file).as_type(np.float32)
dataset = tf.data.Dataset.from_tensor_slices(file_list)
dataset = dataset.map(lambda file: tuple(tf.py_func(
read_npy_file, [file], [tf.float32])))
...
其他的一些应用可以参考 Tensorflow Dataset Guide -- Importing Data,里面有从 Python 迭代器得到 Dataset,读取 csv 等文件,进行 padding 操作,利用 tf.py_func 调用 opencv 等。
用 tf.data 重写 3.2.6 部分:
import tensorflow as tf
IMAGE_SIZE = 299 # 输入图片大小
BATCH_SIZE = 50
SHUFFLE_BUFFER = 10000
NUM_EPOCHES = 100
# 列举 tfrecord 文件名
training_files = tf.train.match_filenames_ones('training_path/data.tfrecords-*')
val_files = tf.train.match_filenames_ones('val_path/data.tfrecords-*')
def tfrecord_parser(record):
features = tf.parse_single_example(
record, features={
'image': tf.FixedLenFeature([], tf.string),
'label': tf.FixedLenFeature([], tf.int64),
'height': tf.FixedLenFeature([], tf.int64),
'width': tf.FixedLenFeature([], tf.int64),
'channels': tf.FixedLenFeature([], tf.int64)
}
)
label, height, width, channels = features['label'], features['height'], features['width'], features['channels']
decoded_image = tf.decode_raw(features['image'], tf.uint8)
decoded_image = tf.reshape(decoded_image, tf.stack([height, width, channels]))
return decoded_image, label
# 定义 traing dataset 和对应的迭代器
# dataset = tf.data.TFRecordDataset(training_files) 这个效率低
training_dataset = tf.data.Dataset.from_tensor_slices(training_files)
training_dataset = training_dataset.interleave(lambda x: tf.data.TFRecordDataset(x), cycle_length=4)
training_dataset = training_dataset.shuffle(SHUFFLE_BUFFER)
training_dataset = training_dataset.repeat(NUM_EPOCHES)
training_dataset = training_dataset.map(tfrecord_parser, num_parallel_calls=2)
training_dataset = training_dataset.map(lambda image, label: (preprocess_for_train(image, IMAGE_SIZE, IMAGE_SIZE), label), num_parallel_calls=4)
training_dataset = training_dataset.batch(BATCH_SIZE)
training_dataset = training_dataset.prefetch(5)
training_iterator = training_datasetmake_initializable_iterator()
training_image_batch, training_label_batch = training_iterator.get_next()
# 网络训练部分
logit = inference(training_image_batch)
loss = calc_loss(logit, training_label_batch)
train_step = ...
# 创建 validation dataset 和对应的迭代器
val_dataset = tf.data.TFRecordDataset(val_files)
val_dataset = val_dataset.map(tfrecord_parser)
val_dataset = val_dataset.map(lambda image, label: (preprocess_for_val(image, IMAGE_SIZE, IMAGE_SIZE), label), num_parallel_calls=2)
val_dataset = val_dataset.batch(BATCH_SIZE)
val_iterator = val_dataset.make_initializable_iterator()
val_image_batch, val_label_batch = val_iterator.get_next()
# 网络测试部分
val_logit = inference(val_image_batch)
predictions = tf.argmax(val_logit, axis=-1, output_type=tf.int32)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(training_iterator.initializer)
while True:
try:
sess.run(train_step)
except tf.errors.OutOfRangeError:
break
sess.run(val_iterator.initializer)
val_results = []
val_labels = []
while True:
try:
pred, label = sess.run([predictions, val_label_batch])
val_results.extend([pred])
val_labels.extend([label])
except tf.errors.OutOfRangeError:
break
# 计算准确率等指标
correct = [float(y == y_) for (y, y_) in zip(val_results, val_labels)]
acc = sum(correct) / len(correct)