第四节:Pytorch数据处理与模型保存

本节将讲解数据操作与模型保存

其中数据处理包含数据处理与数据导入,数据处理能够帮助导入训练数据集,对数据进行正则化等功能

此外模型保存将会帮助我们保存已有的成果

这节讲解完毕我们就已经能够训练我们自己的网络,下一节我们将讲解网络结构可视化相关工具来帮助我们检测、表达网络的结构

数据处理

常用的类

Pytorch的torch.util.data模块中包含着一系列常用的数据预处理的函数或类,其中有数据的读取、切分、准备等内容

我们下面对于某一类任务的具体数据处理都将基于常用类

常用的类如下

功能
torch.utils.data.TensorDataset() 将数据处理为张量
torch.utils.data.ConcatDataset() 连接多个数据集
torch.utils.data.Subset() 根据索引获取数据集的子集
torch.utils.data.DataLoader() 数据加载器
torch.utils.data.random_split() 将数据集随机拆分为给定长度的费重叠新数据集

高维数组型数据

很多请款下,我们拿到手的raw data都是存储在文本文件中的高维数组型数据,我们需要做的从文本中读取多维数组数据然后将他们处理成可以直接用于训练的数据

通常以高维数组形式来保存的数据的特征就是具有多个预测变量和一个目标变量,即多个特征和一个目标标签。从深度学习的角度来说针对高维数组型数据,网络的目标就是对预测变量进行学习,来预测目标变量。更广泛意义上的机器学习的角度来说针对高维数组型数据,学习机 / 模型的目标是对预测变量进行学习,来预测目标变量

如果目标标签是连续的,则我们的任务就是回归问题;如果目标标签是离散的,则我们的任务就是分类问题

针对这两种不同的任务,我们使用Pytorch建立模型对数据进行学习时候,通常需要对数据进行预处理,将他们转化为网络需要的数据形式

下面我们将使用Scikit-Learn库中的波士顿房价数据和鸢尾花数据集来演示我们是如何将读取的高维数组转化为我们可用于训练的数据

我们所有的代码基于以下的导入

import torch
import torch.utils.data as Data
import numpy as np
from sklearn.datasets import load_boston, load_iris

其中Data是我们即将讲解的对象,我们从sklearn.datasets中导入load_boston和load_iris来帮助我们获取波士顿房价和鸢尾花数据集

准备工作

波士顿房价数据集中一共包含506个样本/观察,每个样本都有13个特征和1个标签

特征中包含房屋和房屋周围的信息,例如城镇人均犯罪率、房屋面积、房间数、税率等等信息,具体信息如下表,其中4不是数据集中原本就有的数据
波士顿数据集的标签就是房价

在这里插入图片描述

Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例。数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都有 4 项特征:花萼长度、花萼宽度、花瓣长度、花瓣宽度,可以通过这4个特征预测鸢尾花卉属于(iris-setosa, iris-versicolour, iris-virginica)中的哪一品种

波士顿数据集的房价是连续的,而鸢尾花数据集的鸢尾花种类有且仅有四种,因此分别是回归任务和分类任务

我们准备工作首先读取数据,即从txt,csv等格式数据中读取到当前脚本中来

#读取波士顿房价数据集
boston_X,boston_y=load_boston(return_X_y=True)
iris_X,iris_y=load_iris(return_X_y=True)

#了解数据集基础信息
print("Boston:")
print(type(boston_X))
print(boston_X.shape)
print(boston_X.dtype)
print(type(boston_y))
print(boston_y.shape)
print(boston_y.dtype)
print()
print("Iris:")
print(type(iris_X))
print(iris_X.shape)
print(iris_X.dtype)
print(type(iris_y))
print(iris_y.shape)
print(iris_y.dtype)
>>>
Boston:
<class 'numpy.ndarray'>
(506, 13)
float64
<class 'numpy.ndarray'>
(506,)
float64

Iris:
<class 'numpy.ndarray'>
(150, 4)
float64
<class 'numpy.ndarray'>
(150,)
int64

由于回归任务和分类任务的标签取值不同,因此我们的数据处理方法也不同,下面将一一介绍

回归数据准备

转化为Tensor

我们上面通过对数据集的描述得知,回归任务的波士顿房价数据集的样本是以numpy的ndarray对象存在的,我们需要用pytorch来进行计算的话,则要求我们将numpy数组转化为pytorch的tensor对象

此外我们转化的时候使用from_numpy函数的时候需要传入的ndarray对象的元素都是32位浮点型,因此我们首先需要转换数据类型再转化为tensor对象

trainBostion_Xt=torch.from_numpy(boston_X.astype(np.float32))
trainBostion_yt=torch.from_numpy(boston_y.astype(np.float32))
print(type(trainBostion_Xt))
print(trainBostion_Xt.shape)
print(type(trainBostion_yt))
print(trainBostion_yt.shape)
>>>
<class 'torch.Tensor'>
torch.Size([506, 13])
<class 'torch.Tensor'>
torch.Size([506])
绑定数据

我们上面得到的数据集的特征和标签都是分开的两个数组,因此我们需要使用Pytorch的util.data模块中的TensorDataset来将X和Y整合到一起

我们只需要给TensorDataset函数传入作为特征和标签的Tensor即可,需要按照特征,标签的顺序传入

fullBoston_data=Data.TensorDataset(trainBostion_Xt,trainBostion_yt)
数据分割

我们每次训练网络的时候用的都是Mini-batch,而且伴随超参数的选择,我们还会进行validation,以及划分测试集,因此我们通常需要对拿到的全部数据进行分割

我们的方法就是使用数据加载器DataLoader将数据分割为不同的batch,选取其中的部分作为测试集,validation集,训练集

trainBostion_loader=Data.DataLoader(dataset=fullBoston_data,batch_size=64,shuffle=True,num_workers=1)

我们指定dataset参数来指定分割的对象,batch_size指定分割的数据集的大小,shuffle表示是否随机打乱数据,num_workers表明使用的进程数

使用数据

我们使用上面分割好的数据的时候,需要使用enumerate函数,DataLoader类本身会返回两个分割好的数据集(样本和标签)的列表,我们需要使用Python内置的enumerate函数来将两者配对

for step,(batch_x,batch_y) in enumerate(trainBostion_loader):
    print(step)
    print(batch_x.shape)
    print(batch_y.shape)
>>>
0
torch.Size([64, 13])
torch.Size([64])
1
torch.Size([64, 13])
torch.Size([64])
2
torch.Size([64, 13])
torch.Size([64])
3
torch.Size([64, 13])
torch.Size([64])
4
torch.Size([64, 13])
torch.Size([64])
5
torch.Size([64, 13])
torch.Size([64])
6
torch.Size([64, 13])
torch.Size([64])
7
torch.Size([58, 13])
torch.Size([58])

分类数据准备

分类数据准备和回归数据大体类似,只不过因为分的类有限,因此y的数据类型是整数,因此我们将numpy的adarray对象转化为tensor对象的时候需要转化为64位整数

trainIris_X=torch.from_numpy(iris_X.astype(np.float32))
trainIris_y=torch.from_numpy(iris_y.astype(np.int64))
fullIris_data=Data.TensorDataset(trainIris_X,trainIris_y)
trainIris_loader=Data.DataLoader(dataset=fullIris_data,batch_size=10,shuffle=True,num_workers=1)
for step,(batch_x,batch_y) in enumerate(trainIris_loader):
    print(step)
    print(batch_x.shape)
    print(batch_y.shape)
>>>
0
torch.Size([10, 4])
torch.Size([10])
1
torch.Size([10, 4])
torch.Size([10])
2
torch.Size([10, 4])
torch.Size([10])
3
torch.Size([10, 4])
torch.Size([10])
4
torch.Size([10, 4])
torch.Size([10])
5
torch.Size([10, 4])
torch.Size([10])
6
torch.Size([10, 4])
torch.Size([10])
7
torch.Size([10, 4])
torch.Size([10])
8
torch.Size([10, 4])
torch.Size([10])
9
torch.Size([10, 4])
torch.Size([10])
10
torch.Size([10, 4])
torch.Size([10])
11
torch.Size([10, 4])
torch.Size([10])
12
torch.Size([10, 4])
torch.Size([10])
13
torch.Size([10, 4])
torch.Size([10])
14
torch.Size([10, 4])
torch.Size([10])

图像数据

图像数据和高维数组型数据不同,torchvision中的datasets模块中包含多种常用的分类数据集下载以及导入函数,我们可以很方便的导入数据以及验证模型的效果

dataset模块提供的部分常用图像数据集如下

数据集对应的类 描述
datasets.MNIST 手写字体数据集
datasets.FashionMNIST 衣服、鞋子、包等10类数据集
datasets.KMNIST 一些文字的灰度数据
datasets.CocoCaptions 微软的用于图像标注的数据集
datasets.CocoDetection 用于图像检测的MS COCO数据集
datasets.LSUN 10个常见和20个目标的分类数据集
datasets.CIFAR10 CIFAR数据集截取了部分,只有10个类的自数据及
datasets.CIFAR100 同上,只不过是100个类
datasets.STL100 包含10类的分类数据集和大量的未标记数据
datasets.ImageFolder 定义一个数据加载器从文件夹中读取数据

而torchvision中的transforms模块可以针对对每张图像进行预处理操作,该模块中常用的操作的类如下

说明
transforms.Compose 将多个transform组合起来使用
transforms.Scale 按照指定的图像尺寸对图像进行调整
transforms.CenterCrop 对图像进行中心切割,得到给定的大小
transforms.RandomCrop 切割中心点的位置随机选取
transforms.RandomHorizontalFlip 图像随机水平翻转
transforms.RandomSizedCrop 将给定的图像随机切割,然后再变换为给定大小
transforms.Pad 将图像所有边用给定的pad value填充
transforms.ToTensor 讲一个取值范围为[0,255]的PIL图像或形状为[H,W,C]的数组转换为形状为[C,H,W]的、取值为[0,1.0]的tensor
transforms.Normalize 将给定的图像进行正则化操作
transforms.Lambda 使用自定义的lambda表达式作为转化器,可以自定义图像操作方法

下面将以FashionMNIST为例讲解如何使用上面的数据/正常情况下如何完成一次导入

我们下面的代码将基于如下的导入

import torch
import torch.utils.data as Data
from torchvision.datasets import FashionMNIST
import torchvision.transforms as Transforms
from torchvision.datasets import ImageFolder
import numpy as np

导入数据

根据我们导入的图像数据的来源不同,可以是从网站直接导入,也可以从本地导入,因此具体有两种不同的方式

网络导入数据

我们以FashionMNIST为例,该数据集包含60000张28×28的灰度图像作为训练集,以及10000张28×28的灰度图片作为测试集,数据一共分10类,分别是鞋子、T恤、连衣裙等饰品

下载/导入数据集

我们直接实例化一个FashionMNIST对象即可

train_data=FashionMNIST(root='./data/FashionMNIST',train=True,transform=Transforms.ToTensor(),download=True)

我们指定root为读取数据集的路径,如果该路径下以及有数据集,那么指定download参数为True时,就会自动下载到指定路径;如果该路径下已存在数据集,则会直接导入

我们这里指定train参数为True,表明我们读取训练集数据,我们指定transform参数表明我们我们将导入的数据的转换为形状为[C,H,W]、值介于0~1之间的Tensor

下载成功后如下

在这里插入图片描述

我们去当前文件夹下就能看到下载的数据集

在这里插入图片描述

在这里插入图片描述

我们成功下载数据集后,我们再使用本地导入来导入训练集数据看看效果

test_data=FashionMNIST(root='./data/FashionMNIST',train=False,transform=Transforms.ToTensor,download=False)

此外,我们实例化的这些数据类都具有targets属性作为标签,data属性作为特征

分割训练集

接下来我们使用上面讲解过得DataLoader来分割、加载数据

train_loader=Data.DataLoader(dataset=train_data,batch_size=64,shuffle=True,num_workers=2)
print(len(train_loader))
>>>
938

前面讲解过,DataLoader是一个可以遍历的对象,因此我们使用len即可查询所有的batch数量

调整测试集

我们前面讲过,我们通过nn模块中现有的层或者我们自己写的层或者模块都需要接受形如[batch_size,channel,Height,Width]的输入,因此在我们处理成可用的数据前我们需要对数据进行最后的修改

这里由于我们直接加载的数据集,而数据集中的训练集已经是被处理成形状为[channel,Height,Width]的Tensor的图像,而我们使用DataLoader的时候已经划分了Batch,因此我们只需要将测试集的数据进行处理为符合要求的数据即可

test_data_X=test_data.data.type(torch.FloatTensor) / 255.0
test_data_X=torch.unsqueeze(test_data_X,dim=1)
test_data_y=test_data.targets
print(test_data_X.shape)
print(test_data_y.shape)
>>>
torch.Size([10000, 1, 28, 28])
torch.Size([10000])
文件夹导入数据

我们除了可以用已经准备好的数据,我们也可以从文件夹导入自己的数据,具体方式就是使用ImageFolder函数,该函数能够读取特定格式存储的数据集,具体形式如下:

  • data
    • dog
      • dog1.png
      • dog2.png
    • cat
      • cat1.png
      • cat2.png

  1. 所有的图像数据需要放在同一个文件夹下
  2. 具有相同标签的数据放在同一个子文件夹下
  3. 不同标签的数据放在不同的子文件夹下

我们现在制作一个自己的图像数据集MyImageData,其中具有lena一个类/文件夹,lena文件夹中有三张lena的图片

在这里插入图片描述

接下来我们就将读取MyImageData数据集

设置变换器

上面我们导入的数据自动就具有了良好的、可以直接被处理的格式:Tensor形式存在,Tensor形状为[batch_size,channel,Heigh,Width]

而如果需要读取我们自己的数据的话,就需要我们自己定义一个变换器来对数据进行操作

变换器是指我们设定的对raw data进行的操作,例如我们设置一个用于将原始输入图像变换成28*28大小的图像,那么就能称之为一个变换器

我们通过多种变换器最终将输入的原图像raw data转换为可以被处理的数据,接下来进一步处理为可以直接用于训练的数据

train_data_transformers=Transforms.Compose([Transforms.RandomResizedCrop(size=224),\
                                           Transforms.RandomHorizontalFlip(p=0.5),\
                                           Transforms.ToTensor(),\
                                           Transforms.Normalize([0.485,0.456,0.406],[0.299,0.224,0.255])])

这里我们使用了Compose来组合多个变换器,首先我们将输入的数据变换为224×224大小的图片,接下来随机选择50%的图片水平翻转,接下来将图片转化为Tensor,最后我们对已经成为Tensor的图片进行正则化,最为可用于处理的数据

读取数据

接下来我们只需要给ImageFolder指定数据集的位置和需要进行的变换即可

train_data_dir='./data/MyImageData'
train_data=ImageFolder(root=train_data_dir,transform=train_data_transformers)
train_data_loader=Data.DataLoader(dataset=train_data,batch_size=1,shuffle=True,num_workers=1)
for step,(batch_X,batch_y) in enumerate(train_data_loader):
    print(step)
    print(batch_X.shape)
    print(batch_y.shape)
>>>
0
torch.Size([1, 3, 224, 224])
torch.Size([1])
1
torch.Size([1, 3, 224, 224])
torch.Size([1])
2
torch.Size([1, 3, 224, 224])
torch.Size([1])

模型保存

对于已经训练好的模型,我们通常有保存整个模型和保存模型的参数两种方法

我们首先搭建如下的网络作为测试

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

class TestConvNet(nn.Module):
    def __init__(self):
        super(TestConvNet,self).__init__()
        self.conv1=nn.Sequential(\
                                nn.Conv2d(in_channels=1,out_channels=1,kernel_size=(3,3),stride=1,padding=1),\
                                nn.ReLU(),\
                                nn.MaxPool2d(kernel_size=(2,2),stride=2))
        
        self.conv2=nn.Sequential(\
                                nn.Conv2d(in_channels=1,out_channels=1,kernel_size=(3,3),stride=1,padding=1),\
                                nn.ReLU(),\
                                nn.MaxPool2d(kernel_size=(2,2),stride=2))
        
        self.fc1=nn.Sequential(\
                                nn.Linear(in_features=100,out_features=50),\
                                nn.ReLU())
        
        self.fc2=nn.Sequential(\
                                nn.Linear(in_features=50,out_features=10),
                                nn.ReLU())
        self.predict=nn.Linear(in_features=10,out_features=1)
    
    def forward(self,x):
        x=self.conv1(x)
        x=self.conv2(x)
        x=self.fc1(x)
        x=self.fc2(x)
        output=self.predict(x)
        return output

接下里我们实例化一个网络

ConvNet1=TestConvNet()
print(list(ConvNet1.conv1.parameters()))
>>>
[Parameter containing:
tensor([[[[-0.0856,  0.2359, -0.2708],
          [ 0.0364,  0.1312,  0.0799],
          [ 0.0720,  0.1978, -0.2602]]]], requires_grad=True), Parameter containing:
tensor([0.2487], requires_grad=True)]

假设该网络已经训练好,我们接下里需要保存这个模型

保存整个模型

我们使用torch.save函数就能保存整个模型,只需要传入需要保存的模型和路径即可

torch.save(ConvNet1,f='./MyConvNet.pkl')

需要注意的是,Pytorch保存模型的用的是Python自带的pickle库,因此文件的后缀名要是pkl

对应的我们载入模型就使用Pytorch的torch.load函数,传入模型的位置即可

MyConvNet1Load=torch.load(f='./MyConvNet.pkl')
print(list(MyConvNet1Load.conv1.parameters()))
>>>
[Parameter containing:
tensor([[[[-0.0856,  0.2359, -0.2708],
          [ 0.0364,  0.1312,  0.0799],
          [ 0.0720,  0.1978, -0.2602]]]], requires_grad=True), Parameter containing:
tensor([0.2487], requires_grad=True)]

保存模型参数

我们也可以选择只保存模型参数,而后自己搭建一个和原网络形状相同的网络,并将参数初值赋为保存的参数就能够实现复原

torch.save(ConvNet1.state_dict(),'./MyConvNetState.pkl')

同样,我们使用load函数载入即可

MyConvNet1Parameters=torch.load(f='./MyConvNetState.pkl')
print(MyConvNet1Parameters)
>>>
OrderedDict([('conv1.0.weight', tensor([[[[-0.0856,  0.2359, -0.2708],
          [ 0.0364,  0.1312,  0.0799],
          [ 0.0720,  0.1978, -0.2602]]]])), ('conv1.0.bias', tensor([0.2487])), ('conv2.0.weight', tensor([[[[-0.1575,  0.0613, -0.0051],
          [ 0.0738,  0.2174,  0.2874],
          [-0.2416, -0.1645,  0.2076]]]])), ('conv2.0.bias', tensor([0.1067])), ('fc1.0.weight', tensor([[ 0.0976,  0.0563,  0.0211,  ..., -0.0565,  0.0145, -0.0107],
        [-0.0905,  0.0853, -0.0122,  ...,  0.0567, -0.0395,  0.0664],
        [ 0.0521,  0.0430,  0.0784,  ...,  0.0092,  0.0441, -0.0798],
        ...,
        [ 0.0297, -0.0779,  0.0400,  ...,  0.0307, -0.0687,  0.0837],
        [-0.0983,  0.0853, -0.0534,  ...,  0.0284, -0.0344, -0.0952],
        [-0.0995,  0.0719, -0.0223,  ..., -0.0745, -0.0418,  0.0595]])), ('fc1.0.bias', tensor([ 0.0585,  0.0032,  0.0757, -0.0183, -0.0525,  0.0412, -0.0416, -0.0200,
         0.0472, -0.0664, -0.0231, -0.0627,  0.0624, -0.0136,  0.0417,  0.0625,
         0.0028,  0.0678,  0.0889,  0.0719, -0.0982,  0.0819,  0.0050,  0.0559,
         0.0913,  0.0547,  0.0246, -0.0670, -0.0460, -0.0739, -0.0586,  0.0408,
         0.0062,  0.0647, -0.0831, -0.0406,  0.0444,  0.0406, -0.0603,  0.0219,
         0.0163, -0.0107,  0.0826, -0.0433,  0.0281,  0.0939, -0.0771,  0.0230,
        -0.0128,  0.0669])), ('fc2.0.weight', tensor([[-0.0041, -0.1265, -0.0821, -0.0979, -0.0859, -0.0389,  0.0774,  0.0543,
          0.0882, -0.0571, -0.0540,  0.0421,  0.1091,  0.0830, -0.0361, -0.0995,
         -0.1083,  0.1345,  0.1227, -0.0645, -0.0377,  0.0654, -0.0356,  0.0307,
          0.1212,  0.0621,  0.1238, -0.0979,  0.0104, -0.0337, -0.0989,  0.0602,
         -0.1288,  0.0287, -0.0684, -0.0842, -0.0779, -0.1151,  0.0272, -0.0568,
         -0.0909, -0.0183,  0.0593, -0.1340, -0.1225,  0.0076,  0.0554,  0.0807,
          0.1183, -0.0006],
        [-0.0017, -0.0632, -0.0569, -0.1268,  0.0709, -0.0547,  0.1328, -0.0228,
          0.0289, -0.1408,  0.1317, -0.1037, -0.0597, -0.1013,  0.0869,  0.0121,
          0.1344,  0.1262,  0.0884, -0.0596, -0.0309,  0.0947, -0.0680,  0.1288,
         -0.0453,  0.1366, -0.0758,  0.0842, -0.0853, -0.0450, -0.0190,  0.0257,
         -0.0927,  0.0383,  0.0047, -0.1287,  0.0092, -0.1043, -0.0148,  0.0983,
          0.0724, -0.0984,  0.0390,  0.0412,  0.0136,  0.1016,  0.0545, -0.0859,
          0.0677, -0.0713],
        [ 0.1218,  0.1186, -0.0373, -0.0321,  0.0587, -0.1179,  0.0429, -0.1284,
          0.0826,  0.0642, -0.1060, -0.0636,  0.0158, -0.0235,  0.0982,  0.0107,
         -0.1261, -0.0989,  0.1098,  0.1283, -0.0760, -0.0784,  0.0030, -0.1134,
          0.0700,  0.0789, -0.1367,  0.1329,  0.0882,  0.0277, -0.1269, -0.1211,
          0.0391,  0.0615, -0.0883, -0.0684, -0.0546,  0.1028, -0.1302, -0.0347,
         -0.0224, -0.1307,  0.0262, -0.1356, -0.0699,  0.0872,  0.0555, -0.0853,
          0.1076, -0.0792],
        [ 0.0979,  0.0927,  0.0741,  0.1034,  0.1033,  0.0821,  0.0530, -0.0245,
         -0.1121, -0.0098, -0.0187, -0.1185,  0.1080,  0.0160,  0.0924,  0.1155,
         -0.1316,  0.0941, -0.1245, -0.1294,  0.0251, -0.0161,  0.0632,  0.0913,
          0.0332, -0.1007, -0.0653, -0.0392, -0.1358, -0.0420, -0.0159,  0.0591,
         -0.0506,  0.0130, -0.0303,  0.1049,  0.1081, -0.1075,  0.1175, -0.0513,
          0.0064, -0.0749,  0.0678, -0.0104,  0.0809, -0.0982,  0.1126, -0.0913,
          0.0294,  0.1278],
        [-0.1295,  0.0892,  0.0597,  0.0144, -0.0793, -0.1250, -0.1391, -0.0386,
         -0.0892, -0.0293, -0.0048,  0.0628, -0.0697,  0.1043,  0.1364, -0.0712,
          0.0223,  0.0039, -0.1159,  0.1125,  0.0631, -0.1022, -0.0413,  0.0123,
         -0.0833,  0.0813,  0.0281, -0.0845,  0.1071, -0.0380, -0.1205,  0.0211,
          0.1133,  0.0522,  0.1262,  0.0342,  0.0329, -0.0822, -0.0088, -0.0079,
          0.0179, -0.0962,  0.0922,  0.0356,  0.0481, -0.0966, -0.0458,  0.0683,
          0.0120,  0.0903],
        [ 0.0455,  0.0389,  0.1308, -0.1083, -0.1311,  0.0211, -0.0723,  0.0201,
         -0.0054, -0.0841, -0.0604,  0.1124, -0.0828, -0.1317, -0.0528,  0.1092,
          0.1210,  0.1230,  0.1376,  0.0057, -0.0401, -0.0087, -0.0491, -0.0452,
         -0.1107, -0.1236, -0.0445, -0.1326, -0.0315, -0.0241,  0.1021, -0.1297,
         -0.0015, -0.1104,  0.1264,  0.0868,  0.0946,  0.1347, -0.1298,  0.0923,
          0.0187, -0.1099, -0.0827,  0.0924, -0.0463,  0.1208, -0.0577, -0.1128,
          0.0643,  0.0368],
        [-0.0083,  0.0303,  0.0089,  0.1150, -0.1332, -0.0611,  0.0357, -0.0983,
          0.0697, -0.0126, -0.1038, -0.1133,  0.1003, -0.1320, -0.0619,  0.1011,
          0.1222,  0.0568, -0.0081, -0.1107, -0.0698,  0.0300,  0.1356,  0.1370,
         -0.0567, -0.1068,  0.0695, -0.0740,  0.0017,  0.1310, -0.0652,  0.1143,
         -0.0379,  0.0571, -0.1051,  0.1261, -0.0621,  0.0704,  0.0191, -0.0858,
          0.0348,  0.0265,  0.0646,  0.0549,  0.1137,  0.0346,  0.0920,  0.0148,
          0.1376, -0.0228],
        [ 0.0013,  0.1306,  0.0852, -0.0806, -0.0144, -0.1149,  0.1043,  0.1032,
         -0.1022, -0.0720, -0.1110, -0.0403, -0.0834, -0.0519, -0.0598,  0.1251,
         -0.0636, -0.0718, -0.0539,  0.0572, -0.0885, -0.0419, -0.0004,  0.1392,
          0.1239, -0.1079,  0.0158, -0.0758,  0.0798,  0.1208, -0.0157,  0.0160,
         -0.0788,  0.0016,  0.0875,  0.0846,  0.0058,  0.1167, -0.0177,  0.1300,
          0.0300, -0.0965, -0.1005, -0.0323,  0.0437,  0.1360,  0.0127,  0.0788,
         -0.1342,  0.0289],
        [-0.0774, -0.0756, -0.0468, -0.0677,  0.0740, -0.0960,  0.1312, -0.0947,
         -0.0965,  0.1177,  0.0743,  0.0998, -0.1248, -0.0651,  0.0665,  0.1403,
         -0.0466, -0.0069,  0.0854, -0.0917,  0.1362, -0.0870,  0.0098,  0.0599,
          0.0673,  0.0755,  0.0607, -0.0756, -0.0722,  0.0414, -0.1024, -0.0641,
          0.0300,  0.1411, -0.1340, -0.0479,  0.0334,  0.0630, -0.0012, -0.1238,
          0.0035, -0.1068, -0.0647,  0.1236, -0.0542,  0.1002, -0.0414,  0.0750,
         -0.1071,  0.0616],
        [ 0.1034, -0.0619, -0.0179,  0.1262, -0.1172, -0.0343, -0.0466,  0.0049,
         -0.0196, -0.0204, -0.1245,  0.0900,  0.1212,  0.0941,  0.0417,  0.0144,
          0.0559,  0.0489,  0.0791,  0.0873,  0.0875, -0.0498, -0.1014,  0.0250,
          0.1244,  0.1289, -0.0023,  0.1065, -0.0038,  0.1163, -0.1376,  0.0116,
          0.0893,  0.0468,  0.0808, -0.0381, -0.0909, -0.0313, -0.1032,  0.0615,
         -0.0451,  0.0327, -0.0684, -0.0873,  0.0570, -0.0411,  0.0799,  0.0615,
         -0.1232, -0.1309]])), ('fc2.0.bias', tensor([ 0.0226,  0.1385,  0.0289, -0.0177,  0.0703,  0.0572,  0.0211, -0.1167,
        -0.0980, -0.0262])), ('predict.weight', tensor([[ 0.0122,  0.1662,  0.2079, -0.0570,  0.0929, -0.0008, -0.1581,  0.2524,
          0.1307,  0.0755]])), ('predict.bias', tensor([0.2225]))])
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐