# creator : wangdiedang
# time : 2022/6/7 10:13
# filename : Read_File.py

import numpy as np
import struct as st
import matplotlib.pyplot as plt

path = [
    'train-images.idx3-ubyte',
    'train-labels.idx1-ubyte',
    't10k-images.idx3-ubyte',
    't10k-labels.idx1-ubyte'
]


def normalize(data):  # 将图片像素二值化
    data[data != 0] = 1


# 读入图像
def read_idx3(path):
    offset = 0  # 定义偏移量
    fmt_header = ">4i"  # 定义类型 idx3前面有4个整型
    raw_bin_data = open(path, "rb").read()  # 读入字节数据
    magic_number, num_images, num_rows, num_cols = st.unpack_from(fmt_header, raw_bin_data, offset)
    image_size = num_rows * num_cols
    offset += st.calcsize(fmt_header)
    fmt_image = ">" + str(image_size) + "B"
    images = np.empty((num_images, num_rows, num_cols))  # 返回一个空矩阵
    # print(images.shape)
    for i in range(num_images):
        images[i] = np.array(st.unpack_from(fmt_image, raw_bin_data, offset)).reshape((num_rows, num_cols))
        # 二值化图像矩阵
        normalize(images[i])
        offset += st.calcsize(fmt_image)

    return images


# 读入对应标签
def read_idx1(path):
    offset = 0  # 定义偏移量
    fmt_header = ">2i"  # 定义类型 idx1前面有2个整型
    raw_bin_data = open(path, "rb").read()  # 读入字节数据
    magic_number, num_images = st.unpack_from(fmt_header, raw_bin_data, offset)
    offset += st.calcsize(fmt_header)
    fmt_image = ">B"
    labels = np.empty(num_images)  # 返回一个空矩阵
    # print(images.shape) 输出值为(10000,)
    for i in range(num_images):
        labels[i] = st.unpack_from(fmt_image, raw_bin_data, offset)[0]
        offset += st.calcsize(fmt_image)

    return labels


def read_train_and_test(train_imgs_path, train_labels_path, test_imgs_path, test_labels_path):
    # 读入 训练集6w 的 idx3文件
    train_imgs = read_idx3(train_imgs_path)
    # 读入 训练集6w 的 idx1文件
    train_labels = read_idx1(train_labels_path)
    # 读入 测试集1w 的 idx3文件
    test_imgs = read_idx3(test_imgs_path)
    # 读入 测试集1w 的 idx1文件
    test_labels = read_idx1(test_labels_path)
    normalize(train_imgs)
    normalize(test_imgs)
    return train_imgs, train_labels, test_imgs, test_labels


# def show_img(imgs, labels):
#     for i, item in enumerate(imgs):
#         if i >= 15:
#             break
#         plt.imshow(imgs[i], cmap='gray')
#         plt.pause(0.000001)
#         plt.show()
#         print(labels[i])
def show_img(imgs):
    m, r, c = imgs.shape
    t_img = np.zeros((1, 8 * c + 2))
    for i in range(8):
        t = np.zeros((r, 1))
        for j in range(8):
            t = np.hstack((t, imgs[i*8+j]))
        t = np.hstack((t, np.zeros((r, 1))))
        t_img = np.vstack((t_img, t))
    t_img = np.vstack((t_img, np.zeros((1, 8 * c + 2))))
    plt.imshow(t_img)
    plt.show()


def read_main():
    # 调用read_train_and_test读入完整数据矩阵
    return read_train_and_test(path[0], path[1], path[2], path[3])


if __name__ == '__main__':
    read_main()

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐