#-*- coding:utf-8 -*-

__author__ = 'Leo.Z'

#Tensorflow Version:1.14.0

importosimporttensorflow as tffrom PIL importImage

BATCH_SIZE= 128

defread_cifar10(filenames):

label_bytes= 1height= 32width= 32depth= 3image_bytes= height * width *depth

record_bytes= label_bytes +image_bytes#lamda函数体

#def load_transform(x):

## Convert these examples to dense labels and processed images.

#per_record = tf.reshape(tf.decode_raw(x, tf.uint8), [record_bytes])

#return per_record

#tf v1.14.0版本的FixedLengthRecordDataset(filename_list,bin_data_len)

datasets = tf.data.FixedLengthRecordDataset(filenames=filenames, record_bytes=record_bytes)#是否打乱数据

#datasets.shuffle()

#重复几轮epoches

datasets = datasets.shuffle(buffer_size=BATCH_SIZE).repeat(2).batch(BATCH_SIZE)#使用map,也可使用lamda(注意,后面使用迭代器的时候这里转换为uint8没用,后面还得转一次,否则会报错)

#datasets.map(load_transform)

#datasets.map(lamda x : tf.reshape(tf.decode_raw(x, tf.uint8), [record_bytes]))

#创建一起迭代器tf v1.14.0

iter =tf.compat.v1.data.make_one_shot_iterator(datasets)#获取下一条数据(label+image的二进制数据1+32*32*3长度的bytes)

rec =iter.get_next()#这里转uint8才生效,在map中转貌似有问题?

rec =tf.decode_raw(rec, tf.uint8)

label=tf.cast(tf.slice(rec, [0, 0], [BATCH_SIZE, label_bytes]), tf.int32)#从第二个字节开始获取图片二进制数据大小为32*32*3

depth_major =tf.reshape(

tf.slice(rec, [0, label_bytes], [BATCH_SIZE, image_bytes]),

[BATCH_SIZE, depth, height, width])#将维度变换顺序,变为[H,W,C]

image = tf.transpose(depth_major, [0, 2, 3, 1])#返回获取到的label和image组成的元组

return(label, image)defget_data_from_files(data_dir):#filenames一共5个,从data_batch_1.bin到data_batch_5.bin

#读入的都是训练图像

filenames = [os.path.join(data_dir, 'data_batch_%d.bin' %i)for i in range(1, 6)]#判断文件是否存在

for f infilenames:if nottf.io.gfile.exists(f):raise ValueError('Failed to find file:' +f)#获取一张图片数据的数据,格式为(label,image)

data_tuple =read_cifar10(filenames)returndata_tupleif __name__ == "__main__":#获取label和type的对应关系

label_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

name_list= ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']

label_map=dict(zip(label_list, name_list))

with tf.compat.v1.Session() as sess:

batch_data= get_data_from_files('cifar10_dir/cifar-10-batches-bin')#在之前的旧版本中,因为使用了filename_queue,所以要使用start_queue_runners进行数据填充

#1.14.0由于没有使用filename_queue所以不需要

#threads = tf.train.start_queue_runners(sess=sess)

sess.run(tf.compat.v1.global_variables_initializer())#创建一个文件夹用于存放图片

if not os.path.exists('cifar10_dir/raw'):

os.mkdir('cifar10_dir/raw')#存放30张,以index-typename.jpg命名,例如1-frog.jpg

for i in range(30):#获取一个batch的数据,BATCH_SIZE

#batch_data中包含一个batch的image和label

batch_data_tuple =sess.run(batch_data)#打印(128, 1)

print(batch_data_tuple[0].shape)#打印(128, 32, 32, 3)

print(batch_data_tuple[1].shape)#每个batch存放第一张图片作为实验

Image.fromarray(batch_data_tuple[1][0]).save("cifar10_dir/raw/{index}-{type}.jpg".format(

index=i, type=label_map[batch_data_tuple[0][0][0]]))

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐