今天尝试总结一下 tf.data 这个API的一些用法吧。之所以会用到这个API,是因为需要处理的数据量很大,而且数据均是分布式的存储在多台服务器上,所以没有办法采用传统的喂数据方式,而是运用了 tf.data 对数据进行了相应的预处理,并且最近正赶上总结需要,尝试写一下关于 tf.data 的一些用法,有错误的地方一定告诉我哈。
Tensorflow的数据读取
先来看一下Tensorflow的数据读取机制吧
这一篇文章对于 tensorflow的数据读取机制 讲解得很不错,大噶可以先看一下,有一个了解。
Dataset API是怎么用的呢
虽然上面的资料关于 tf.data 讲解得都很好,但是我没有找到一个很完整滴运用 tf.data.TextLineDataset() 和 tf.data.TFRecordDataset() 的例子,所以才想尝试写一写这篇总结。
MNIST的经典例子
本篇博客结合 mnist 的经典例子,针对不同的源数据:csv数据和tfrecord数据,分别运用 tf.data.TextLineDataset() 和 tf.data.TFRecordDataset() 创建不同的 Dataset 并运用四种不同的 Iterator ,分别是 单次,可初始化,可重新初始化,以及可馈送迭代器 的方式实现对源数据的预处理工作。
我将相关的资料放在了澜子的Github 上,欢迎互粉哇(星星眼)。其中包括了所需的 后缀名为csv和tfrecords的源数据 (data
的文件夹),以及在 jupyter notebook实现的具体代码 (tf_dataset_learn.ipynb
)。
如果有需要的同学可以直接git clone https://github.com/lanhongvp/tensorflow_dataset_learn.git
然后用 jupyter 跑一跑看看输出,这样可以有一个比较直观的认识。关于 Git和Github 的使用,大噶可以看我VSCODE_GIT这一篇博客啦。接下来,针对MNIST例子做一个简单的说明吧。
tf.data.TFRecordDataset() & make_one_shot_iterator()
tf.data.TFRecordDataset() 输入参数直接是后缀名为tfrecords
的文件路径,正因如此,即可解决数据量过大,导致无法单机训练的问题。本篇博客中,文件路径即为/Users/honglan/Desktop/train_output.tfrecords
,此处是我自己电脑上的路径,大家可以 根据自己的需要修改为对应的文件路径。
make_one_shot_iterator() 即为单次迭代器,是最简单的迭代器形式,仅支持对数据集进行一次迭代,不需要显式初始化。
配合 MNIST数据集以及tf.data.TFRecordDataset(),实现代码如下。
1 | # Validate tf.data.TFRecordDataset() using make_one_shot_iterator() |
tf.data.TFRecordDataset() & Initializable iterator
make_initializable_iterator()
为可初始化迭代器,运用此迭代器首先需要先运行显式 iterator.initializer
操作,然后才能使用。并且,可运用 可初始化迭代器实现训练集和验证集的切换。
配合 MNIST数据集 实现代码如下。
1 | # Validate tf.data.TFRecordDataset() using make_initializable_iterator() |
tf.data.TextLineDataset() & Reinitializable iterator
tf.data.TextLineDataset()
,输入参数可以是后缀名为csv
或者是txt
的源数据的文件路径。
此处用的迭代器是 Reinitializable iterator
,即为可重新初始化迭代器。官方定义如下。配合 MNIST数据集 实现代码见第二部分。
可重新初始化迭代器可以通过多个不同的 Dataset 对象进行初始化。例如,您可能有一个训练输入管道,它会对输入图片进行随机扰动来改善泛化;还有一个验证输入管道,它会评估对未修改数据的预测。这些管道通常会使用不同的 Dataset 对象,这些对象具有相同的结构(即每个组件具有相同类型和兼容形状)。
1 | # validate tf.data.TextLineDataset() using Reinitializable iterator |
tf.data.TextLineDataset() & Feedable iterator.
数据集读取方式同上一部分一样,运用tf.data.TextLineDataset()
此处运用的迭代器是 可馈送迭代器,其可以与 tf.placeholder
一起使用,通过熟悉的 feed_dict
机制选择每次调用 tf.Session.run
时所使用的 Iterator。并使用 tf.data.Iterator.from_string_handle
定义一个可让在两个数据集之间切换的可馈送迭代器,结合 MNIST数据集 的代码如下
1 | # validate tf.data.TextLineDataset() using two different iterator |
小结
- 运用
tfrecords
处理数据的速度明显加快 - 可以根据自身需要选择不同的
iterator
方式对源数据进行预处理 - 单机训练时也可以采用
tf.data
中API的相应处理方式