第一步 生成TFRecod Writer$1 import tensorflow as tf

$2 writer = tf.python_io.TFRecordWriter(path, options=None)

其中,path: 文件的存放路径,例如path='.../train.tfrecord'

options:TFRecordOptions对象,定义TFRecord文件保存的压缩格式,如下:# 三种文件压缩方式,默认方式3

# 方式1

writer = tf.python_io.TFRecordWriter(path,

options=tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.ZLIB))

# 方式2

writer = tf.python_io.TFRecordWriter(path,

options=tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.GZIP))

# 方式3

writer = tf.python_io.TFRecordWriter(path,

options=tf.python_io.TFRecordOptions(tf.python_io.TFRecordCompressionType.NONE))

第二步 tf.train.Feature 生成协议信息$1 feature = {"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image.tobytes()])),

"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[label]))}

其中,image.tobytes()表示将一张图片对象转成bytes格式;feature是字典对象,通过键"image"、"label"可以索引到image和label转变后的Feature特征$2 features = tf.train.Features(feature)

表示将多个feature封装成一个features

第三步 tf.train.Example将features封装成特定的PB协议格式$1 example = tf.train.Example(features)

第四步 将example序列化成字符串$1 example = example.SerializeToString()

第五步 将example写入协议缓冲区,并关闭writer$1 writer.write(example)

$2 writer.close()

一个生成CIFAR10数据集的test.tfrecord实例

代码:import tensorflow as tf

import glob

import os

import numpy as np

import cv2

# CIFAR10数据集的类别

CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',

'dog', 'frog', 'horse', 'ship', 'truck']

# ==========================================================

# 获取图片和对应的标签

# ==========================================================

def get_data_and_label(path):

img_data = []

img_label = []

for every_class in CLASSES:

# 每一类的所有图片路径,注意路径拼接是否正确

every_class_path = glob.glob(os.path.join(path, every_class) + '/*')

# print(every_class_path)

# 将当前类别的所有图片路径添加到im_data中, 使用list的拼接方法

img_data += every_class_path

# 将当前类别的所有图片对应的标签添加到im_label中,当前的所有图片只对应一个标签

every_class_label = [CLASSES.index(every_class) for _ in range(every_class_path.__len__())]

img_label += every_class_label

return img_data, img_label

# ==========================================================

# 创建tfrecord

# ==========================================================

def create_tfrecord(filename, img_data, img_label):

writer = tf.python_io.TFRecordWriter(filename)

# 打乱数据

index = [i for i in range(img_data.__len__())]

np.random.shuffle(index)

for index_i in index:

# 根据索引index_i在im_data和im_label中获取一张图片路径和一个对应的标签

im_d = img_data[index_i]

im_l = img_label[index_i]

data = cv2.imread(im_d)

features = tf.train.Features(

feature={

"image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[data.tobytes()])),

"label": tf.train.Feature(int64_list=tf.train.Int64List(value=[im_l])),

}

)

example = tf.train.Example(features=features)

writer.write(example.SerializeToString())

writer.close()

if __name__ == '__main__':

dataset_path = '.../data/image/test'

tfrecord_file = ".../data/test.tfrecord"

# 获取图片的所有路径和对应的标签

im_data, im_label = get_data_and_label(dataset_path)

# 创建tfrecord

create_tfrecord(tfrecord_file, im_data, im_label)

文件夹目录:

wAAACwAAAAAAQABAEACAkQBADs=

image.png

https://www.jianshu.com/p/0bcc1a2cfc04

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐