python如何读取tfrecord_tensorflow 的 tfrecord数据写入和读取流程
第一步 生成TFRecod Writer$1 import tensorflow as tf$2 writer = tf.python_io.TFRecordWriter(path, options=None)其中,path: 文件的存放路径,例如path='.../train.tfrecord'options:TFRecordOptions对象,定义TFRecord文件保存的压缩格式,如下:#
第一步 生成TFRecod Writer$1 import tensorflow as tf
$2 writer = tf.python_io.TFRecordWriter(path, options=None)
其中,path: 文件的存放路径,例如path='.../train.tfrecord'
options:TFRecordOptions对象,定义TFRecord文件保存的压缩格式,如下:# 三种文件压缩方式,默认方式3
# 方式1
writer = tf.python_io.TFRecordWriter(path,
options=tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB))
# 方式2
writer = tf.python_io.TFRecordWriter(path,
options=tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP))
# 方式3
writer = tf.python_io.TFRecordWriter(path,
options=tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE))
第二步 tf.train.Feature 生成协议信息$1 feature = {"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}
其中,image.tobytes()表示将一张图片对象转成bytes格式;feature是字典对象,通过键"image"、"label"可以索引到image和label转变后的Feature特征$2 features = tf.train.Features(feature)
表示将多个feature封装成一个features
第三步 tf.train.Example将features封装成特定的PB协议格式$1 example = tf.train.Example(features)
第四步 将example序列化成字符串$1 example = example.SerializeToString()
第五步 将example写入协议缓冲区,并关闭writer$1 writer.write(example)
$2 writer.close()
一个生成CIFAR10数据集的test.tfrecord实例
代码:import tensorflow as tf
import glob
import os
import numpy as np
import cv2
# CIFAR10数据集的类别
CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']
# ==========================================================
# 获取图片和对应的标签
# ==========================================================
def get_data_and_label(path):
img_data = []
img_label = []
for every_class in CLASSES:
# 每一类的所有图片路径,注意路径拼接是否正确
every_class_path = glob.glob(os.path.join(path, every_class) + '/*')
# print(every_class_path)
# 将当前类别的所有图片路径添加到im_data中, 使用list的拼接方法
img_data += every_class_path
# 将当前类别的所有图片对应的标签添加到im_label中,当前的所有图片只对应一个标签
every_class_label = [CLASSES.index(every_class) for _ in range(every_class_path.__len__())]
img_label += every_class_label
return img_data, img_label
# ==========================================================
# 创建tfrecord
# ==========================================================
def create_tfrecord(filename, img_data, img_label):
writer = tf.python_io.TFRecordWriter(filename)
# 打乱数据
index = [i for i in range(img_data.__len__())]
np.random.shuffle(index)
for index_i in index:
# 根据索引index_i在im_data和im_label中获取一张图片路径和一个对应的标签
im_d = img_data[index_i]
im_l = img_label[index_i]
data = cv2.imread(im_d)
features = tf.train.Features(
feature={
"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[data.tobytes()])),
"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[im_l])),
}
)
example = tf.train.Example(features=features)
writer.write(example.SerializeToString())
writer.close()
if __name__ == '__main__':
dataset_path = '.../data/image/test'
tfrecord_file = ".../data/test.tfrecord"
# 获取图片的所有路径和对应的标签
im_data, im_label = get_data_and_label(dataset_path)
# 创建tfrecord
create_tfrecord(tfrecord_file, im_data, im_label)
文件夹目录:

image.png
https://www.jianshu.com/p/0bcc1a2cfc04
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐



所有评论(0)