0. 总结

数据导入及处理部分:在 PyTorch 中,我们通常先将 NumPy 数组转换为 torch.Tensor,再封装到 TensorDataset 或自定义的 Dataset 里,然后用 DataLoader 按批次加载。

模型构建部分:RNN

设置超参数:在这之前需要定义损失函数,学习率(动态学习率),以及根据学习率定义优化器(例如SGD随机梯度下降),用来在训练中更新参数,最小化损失函数。

定义训练函数:函数的传入的参数有四个,分别是设置好的DataLoader(),定义好的模型,损失函数,优化器。函数内部初始化损失准确率为0,接着开始循环,使用DataLoader()获取一个批次的数据,对这个批次的数据带入模型得到预测值,然后使用损失函数计算得到损失值。接下来就是进行反向传播以及使用优化器优化参数,梯度清零放在反向传播之前或者是使用优化器优化之后都是可以的,一般是默认放在反向传播之前。

定义测试函数:函数传入的参数相比训练函数少了优化器,只需传入设置好的DataLoader(),定义好的模型,损失函数。此外除了处理批次数据时无需再设置梯度清零、返向传播以及优化器优化参数,其余部分均和训练函数保持一致。

训练过程:定义训练次数,有几次就使用整个数据集进行几次训练,初始化四个空list分别存储每次训练及测试的准确率及损失。使用model.train()开启训练模式,调用训练函数得到准确率及损失。使用model.eval()将模型设置为评估模式,调用测试函数得到准确率及损失。接着就是将得到的训练及测试的准确率及损失存储到相应list中并合并打印出来,得到每一次整体训练后的准确率及损失。

结果可视化

模型的保存,调取及使用。在PyTorch中,通常使用 torch.save(model.state_dict(), ‘model.pth’) 保存模型的参数,使用 model.load_state_dict(torch.load(‘model.pth’)) 加载参数。

需要改进优化的地方:确保模型和数据的一致性,都存到GPU或者CPU;注意numclasses不要直接用默认的1000,需要根据实际数据集改进;实例化模型也要注意numclasses这个参数;此外注意测试模型需要用(3,224,224)3表示通道数,这和tensorflow定义的顺序是不用的(224,224,3),做代码转换时需要注意。

1. RNN介绍

下面是对 RNN(Recurrent Neural Network) 的一个循序渐进、相对通俗的介绍,帮助你从原理上理解 RNN 的本质与应用,希望对你有所帮助。


a. 什么是 RNN?

RNN,全称 Recurrent Neural Network,即“循环神经网络”。它是一类专门处理序列数据的神经网络模型,与传统的前馈网络(如全连接网络 MLP、卷积网络 CNN 等)最大的区别在于:

  • 序列性:RNN 可以在序列的时间步之间传递信息,具备“记忆”先前输入的能力。
  • 循环结构:在每一个时间步,网络都会基于当前输入上一时刻的隐藏状态来更新当前隐藏状态,然后输出结果。
RNN 的一般应用场景
  • 自然语言处理(NLP):如情感分析、文本分类、机器翻译、文本生成等。
  • 时间序列预测:如股票预测、温度预测、信号处理等。
  • 语音识别或合成:处理音频序列。

b. 传统 RNN 的基本结构

以下是一个最基础(经典版)的 RNN 结构示意:

  ┌───────┐      ┌───────┐      ┌───────┐ 
  │x(t-1) │      │x(t)   │      │x(t+1) │  ← 输入序列
  └──┬────┘      └──┬────┘      └──┬────┘
     │              │              │
   ┌─▼──────────────▼──────────────▼─────────────────────────┐
   │                    RNN 单元 (循环体)                     │
   │                                                        │
   │   h(t-1) ──┐   ┌─────────┐   ┌─────────┐               │
   │            │   │激活函数 f│   │激活函数 g│               │
   │ x(t), h(t-1) → │ 线性运算 → │ (如 tanh)  → h(t)          │
   │            │   └─────────┘   └─────────┘               │
   └────────────┴─────────────────────────────────────────────┘
                   ↑  
             通过时间传递
             (隐藏状态 h)
  • 输入序列:( x(1), x(2), …, x(T) )
  • 隐藏状态:( h(t) ) 表示网络在时间步 ( t ) 的内部记忆。
  • 更新公式(经典 RNN 的简单形式):
    [
    h(t) = \sigma(W_{hh} \cdot h(t-1) + W_{xh} \cdot x(t) + b_h)
    ]
    其中 (\sigma) 通常是一个非线性激活函数,如 (\tanh) 或 (\text{ReLU}) 等。
关键特征
  1. 循环(Recurrent)

    • RNN 通过将过去的隐藏状态 ( h(t-1) ) 反复输入到网络,与当前输入 ( x(t) ) 一起决策新的隐藏状态 ( h(t) )。因此它在时间序列上“循环”展开。
  2. 参数共享(Parameter Sharing)

    • 对于序列中每个时间步,RNN 使用相同的一组权重((W_{hh}, W_{xh}) 等),这与一般的多层感知器(MLP)不同,MLP 每一层都会有一组新的权重。
  3. 序列建模

    • 借助隐藏状态的更新,RNN 在一定程度上能够“记住”之前输入的信息,从而可以用来处理依赖于上下文或时间顺序的任务(如语言模型,每个单词与前面单词息息相关)。

c. RNN 的优势与局限

优势
  1. 适合序列数据:相比于传统的全连接网络,RNN 能够更好地处理变长的序列输入,捕捉序列中的时序依赖关系。
  2. 参数共享:节省模型参数,防止过度膨胀。
局限与改进
  1. 长期依赖问题:经典 RNN 里,随着序列长度增大,早期输入的信息往往无法传播到后面时间步,会导致梯度消失或梯度爆炸
  2. 训练效率:由于存在序列展开 + 反向传播(BPTT: Back Propagation Through Time)的特殊性,训练速度通常慢于并行度高的卷积网络。
  3. 改进模型
    • LSTM(Long Short-Term Memory)
    • GRU(Gated Recurrent Unit)
      这两种模型通过门控机制(忘记门、输入门、输出门等)来缓解或部分解决长期依赖问题,在实际中广泛使用。

d. RNN 的常见变体:LSTM 和 GRU

由于传统 RNN 在对长序列进行建模时,容易遗忘早期信息,为了解决这个问题,人们提出了带有 “门控” 机制的循环神经网络结构。其中最典型的就是 LSTMGRU

LSTM (Long Short-Term Memory)
  • 记忆单元(Cell state)和 门控机制(input gate、forget gate、output gate)来控制信息的流动,保留长期的梯度信息,从而缓解梯度消失问题。
  • 在很多 NLP 任务中,LSTM 大多表现优于传统 RNN。
GRU (Gated Recurrent Unit)
  • 结构上比 LSTM 更简化,只有 更新门重置门,虽然结构更简单,但也能保留一定的长期依赖能力。
  • 在某些任务中,GRU 的性能与 LSTM 不相上下,而且训练速度更快。

e. RNN 的应用案例

  1. 语言模型

    • 给定前面的单词,预测下一个单词;或给定一段前文,生成下一段文本。
    • 例如早期的机器翻译系统,输入序列是原语言单词,输出序列是翻译后的目标语言单词。
    • 现在更多使用了 Transformer 这种基于自注意力机制的模型,但 RNN 依然是重要的基石概念。
  2. 序列分类

    • 对一段文本或语音做分类,如情感分析(正向/负向)、语音识别(识别说的是哪一句话)等。
  3. 时间序列预测

    • 比如股票预测、流量预测、天气预测,通过过去若干时刻的数据预测未来走向。

f. RNN 在 PyTorch 中的实现方式

在 PyTorch 里,最常见的循环网络层包括:

  • nn.RNN:经典单层 RNN,可选激活函数 tanhReLU
  • nn.LSTM:LSTM 结构
  • nn.GRU:GRU 结构

输入通常需要形状 (batch_size, seq_len, input_size)(当 batch_first=True 时)。
输出需要自己选择:

  • 如果只需要最后一个时间步的输出,往往取 output[:, -1, :]
  • 如果需要所有时间步的输出(比如生成序列时),则直接使用 output
  • 训练时要记得将 hidden state(以及 cell state)正确地传递或重置。

g. 如何更进一步学习 RNN?

  1. 从小例子入手
    • 用 RNN 来解决简单的序列学习任务(例如正弦波预测、小规模字符级语言模型),查看网络是如何随时间迭代的。
  2. 阅读论文与教程
    • LSTM 的原始论文 (Hochreiter & Schmidhuber, 1997)
    • GRU (Cho et al., 2014)
    • 深入理解门控机制,体会为什么能让 RNN 更好地记住/遗忘信息。
  3. 与 Transformer 对比
    • 在大多数 NLP 任务上,目前已被 Transformer 结构占据主流,但 RNN 思想仍是许多研究的基础。理解 RNN 有助于理解注意力机制为什么行之有效。
  4. 深入到框架实现
    • 看 PyTorch 中 nn.RNNnn.LSTMnn.GRU 的源代码或官方文档,了解参数含义及前向、后向的具体计算流程。

h. 总结

  • 核心思想:RNN 可以“循环”地将过去的信息传递到现在,从而在一定程度上捕捉序列数据的依赖关系。
  • 传统 RNN 的问题:容易出现梯度消失或爆炸,难以捕捉长程依赖。
  • 常见改进:LSTM、GRU 等门控结构缓解了长期依赖难题,也成为 RNN 家族的主力。
  • 现今趋势:NLP 等领域更多使用 Transformer,但 RNN 在许多对序列长度不太长的场合依旧可以使用,而且对初学者理解神经网络的“记忆”能力非常有帮助。

如果你刚开始学习,可以:

  1. 多动手调试:写一些小规模 RNN 代码,训练简单的序列数据,观察 loss 和隐藏状态如何变化。
  2. 多画图:用纸笔画 RNN 在时序上的展开图,有助于理解反向传播的流程。
  3. 分门别类:清楚哪些任务用 LSTM/GRU,哪些任务需要 CNN 或 Transformer,知道各种模型的优势与局限。

2. 数据导入

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import copy
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import classification_report,confusion_matrix
from sklearn.metrics import r2_score
from sklearn.metrics import mean_absolute_error , mean_absolute_percentage_error , mean_squared_error
data = pd.read_csv("./data/weatherAUS.csv")
df   = data.copy()
data.head()
Date Location MinTemp MaxTemp Rainfall Evaporation Sunshine WindGustDir WindGustSpeed WindDir9am ... Humidity9am Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am Temp3pm RainToday RainTomorrow
0 2008-12-01 Albury 13.4 22.9 0.6 NaN NaN W 44.0 W ... 71.0 22.0 1007.7 1007.1 8.0 NaN 16.9 21.8 No No
1 2008-12-02 Albury 7.4 25.1 0.0 NaN NaN WNW 44.0 NNW ... 44.0 25.0 1010.6 1007.8 NaN NaN 17.2 24.3 No No
2 2008-12-03 Albury 12.9 25.7 0.0 NaN NaN WSW 46.0 W ... 38.0 30.0 1007.6 1008.7 NaN 2.0 21.0 23.2 No No
3 2008-12-04 Albury 9.2 28.0 0.0 NaN NaN NE 24.0 SE ... 45.0 16.0 1017.6 1012.8 NaN NaN 18.1 26.5 No No
4 2008-12-05 Albury 17.5 32.3 1.0 NaN NaN W 41.0 ENE ... 82.0 33.0 1010.8 1006.0 7.0 8.0 17.8 29.7 No No

5 rows × 23 columns

data.describe()
MinTemp MaxTemp Rainfall Evaporation Sunshine WindGustSpeed WindSpeed9am WindSpeed3pm Humidity9am Humidity3pm Pressure9am Pressure3pm Cloud9am Cloud3pm Temp9am Temp3pm
count 143975.000000 144199.000000 142199.000000 82670.000000 75625.000000 135197.000000 143693.000000 142398.000000 142806.000000 140953.000000 130395.00000 130432.000000 89572.000000 86102.000000 143693.000000 141851.00000
mean 12.194034 23.221348 2.360918 5.468232 7.611178 40.035230 14.043426 18.662657 68.880831 51.539116 1017.64994 1015.255889 4.447461 4.509930 16.990631 21.68339
std 6.398495 7.119049 8.478060 4.193704 3.785483 13.607062 8.915375 8.809800 19.029164 20.795902 7.10653 7.037414 2.887159 2.720357 6.488753 6.93665
min -8.500000 -4.800000 0.000000 0.000000 0.000000 6.000000 0.000000 0.000000 0.000000 0.000000 980.50000 977.100000 0.000000 0.000000 -7.200000 -5.40000
25% 7.600000 17.900000 0.000000 2.600000 4.800000 31.000000 7.000000 13.000000 57.000000 37.000000 1012.90000 1010.400000 1.000000 2.000000 12.300000 16.60000
50% 12.000000 22.600000 0.000000 4.800000 8.400000 39.000000 13.000000 19.000000 70.000000 52.000000 1017.60000 1015.200000 5.000000 5.000000 16.700000 21.10000
75% 16.900000 28.200000 0.800000 7.400000 10.600000 48.000000 19.000000 24.000000 83.000000 66.000000 1022.40000 1020.000000 7.000000 7.000000 21.600000 26.40000
max 33.900000 48.100000 371.000000 145.000000 14.500000 135.000000 130.000000 87.000000 100.000000 100.000000 1041.00000 1039.600000 9.000000 9.000000 40.200000 46.70000
data.dtypes
Date              object
Location          object
MinTemp          float64
MaxTemp          float64
Rainfall         float64
Evaporation      float64
Sunshine         float64
WindGustDir       object
WindGustSpeed    float64
WindDir9am        object
WindDir3pm        object
WindSpeed9am     float64
WindSpeed3pm     float64
Humidity9am      float64
Humidity3pm      float64
Pressure9am      float64
Pressure3pm      float64
Cloud9am         float64
Cloud3pm         float64
Temp9am          float64
Temp3pm          float64
RainToday         object
RainTomorrow      object
dtype: object

3. 数据探索性分析

#将数据转换为日期时间格式
data['Date'] = pd.to_datetime(data['Date'])

data['year']  = data['Date'].dt.year
data['Month'] = data['Date'].dt.month
data['day']   = data['Date'].dt.day

data.head()
Date Location MinTemp MaxTemp Rainfall Evaporation Sunshine WindGustDir WindGustSpeed WindDir9am ... Pressure3pm Cloud9am Cloud3pm Temp9am Temp3pm RainToday RainTomorrow year Month day
0 2008-12-01 Albury 13.4 22.9 0.6 NaN NaN W 44.0 W ... 1007.1 8.0 NaN 16.9 21.8 No No 2008 12 1
1 2008-12-02 Albury 7.4 25.1 0.0 NaN NaN WNW 44.0 NNW ... 1007.8 NaN NaN 17.2 24.3 No No 2008 12 2
2 2008-12-03 Albury 12.9 25.7 0.0 NaN NaN WSW 46.0 W ... 1008.7 NaN 2.0 21.0 23.2 No No 2008 12 3
3 2008-12-04 Albury 9.2 28.0 0.0 NaN NaN NE 24.0 SE ... 1012.8 NaN NaN 18.1 26.5 No No 2008 12 4
4 2008-12-05 Albury 17.5 32.3 1.0 NaN NaN W 41.0 ENE ... 1006.0 7.0 8.0 17.8 29.7 No No 2008 12 5

5 rows × 26 columns

data.drop('Date',axis=1,inplace=True)
data.columns
Index(['Location', 'MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine',
       'WindGustDir', 'WindGustSpeed', 'WindDir9am', 'WindDir3pm',
       'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm',
       'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm', 'Temp9am',
       'Temp3pm', 'RainToday', 'RainTomorrow', 'year', 'Month', 'day'],
      dtype='object')

a. 数据相关性探索

plt.figure(figsize=(15,13))
# data.corr()表示了data中的两个变量之间的相关性
ax = sns.heatmap(data.corr(), square=True, annot=True, fmt='.2f')
ax.set_xticklabels(ax.get_xticklabels(), rotation=90)          
plt.show()

在这里插入图片描述


b. 是否会下雨

# 设置样式和调色板
sns.set(style="whitegrid", palette="Set2")

# 创建一个 1 行 2 列的图像布局
fig, axes = plt.subplots(1, 2, figsize=(10, 4))  # 图形尺寸调大 (10, 4)

# 图表标题样式
title_font = {'fontsize': 14, 'fontweight': 'bold', 'color': 'darkblue'}

# 第一张图:RainTomorrow
sns.countplot(x='RainTomorrow', data=data, ax=axes[0], edgecolor='black')  # 添加边框
axes[0].set_title('Rain Tomorrow', fontdict=title_font)  # 设置标题
axes[0].set_xlabel('Will it Rain Tomorrow?', fontsize=12)  # X轴标签
axes[0].set_ylabel('Count', fontsize=12)  # Y轴标签
axes[0].tick_params(axis='x', labelsize=11)  # X轴刻度字体大小
axes[0].tick_params(axis='y', labelsize=11)  # Y轴刻度字体大小

# 第二张图:RainToday
sns.countplot(x='RainToday', data=data, ax=axes[1], edgecolor='black')  # 添加边框
axes[1].set_title('Rain Today', fontdict=title_font)  # 设置标题
axes[1].set_xlabel('Did it Rain Today?', fontsize=12)  # X轴标签
axes[1].set_ylabel('Count', fontsize=12)  # Y轴标签
axes[1].tick_params(axis='x', labelsize=11)  # X轴刻度字体大小
axes[1].tick_params(axis='y', labelsize=11)  # Y轴刻度字体大小

sns.despine()      # 去除图表顶部和右侧的边框
plt.tight_layout() # 调整布局,避免图形之间的重叠
plt.show()

在这里插入图片描述

x=pd.crosstab(data['RainTomorrow'],data['RainToday'])
x
RainToday No Yes
RainTomorrow
No 92728 16858
Yes 16604 14597
y=x/x.transpose().sum().values.reshape(2,1)*100
y
RainToday No Yes
RainTomorrow
No 84.616648 15.383352
Yes 53.216243 46.783757
  • 如果今天不下雨,那么明天下雨的机会 = 53.22%

  • 如果今天下雨明天下雨的机会 = 46.78%

y.plot(kind="bar",figsize=(4,3),color=['#006666','#d279a6']);

在这里插入图片描述

c. 地理位置与下雨的关系

x=pd.crosstab(data['Location'],data['RainToday']) 
# 获取每个城市下雨天数和非下雨天数的百分比
y=x/x.transpose().sum().values.reshape((-1, 1))*100
# 按每个城市的雨天百分比排序
y=y.sort_values(by='Yes',ascending=True )

color=['#cc6699','#006699','#006666','#862d86','#ff9966'  ]
y.Yes.plot(kind="barh",figsize=(15,20),color=color)
<Axes: ylabel='Location'>

在这里插入图片描述

位置影响下雨,对于 Portland 来说,有 36% 的时间在下雨,而对于 Woomers 来说,只有6%的时间在下雨

d. 湿度和压力对下雨的影响

data.columns
Index(['Location', 'MinTemp', 'MaxTemp', 'Rainfall', 'Evaporation', 'Sunshine',
       'WindGustDir', 'WindGustSpeed', 'WindDir9am', 'WindDir3pm',
       'WindSpeed9am', 'WindSpeed3pm', 'Humidity9am', 'Humidity3pm',
       'Pressure9am', 'Pressure3pm', 'Cloud9am', 'Cloud3pm', 'Temp9am',
       'Temp3pm', 'RainToday', 'RainTomorrow', 'year', 'Month', 'day'],
      dtype='object')
plt.figure(figsize=(8,6))
sns.scatterplot(data=data,x='Pressure9am',
                y='Pressure3pm',hue='RainTomorrow');

在这里插入图片描述

plt.figure(figsize=(8,6))
sns.scatterplot(data=data,x='Humidity9am',
                y='Humidity3pm',hue='RainTomorrow');

在这里插入图片描述


低压与高湿度会增加第二天下雨的概率,尤其是下午 3 点的空气湿度。

e. 气温对下雨的影响

plt.figure(figsize=(8,6))
sns.scatterplot(x='MaxTemp', y='MinTemp', 
                data=data, hue='RainTomorrow');

请添加图片描述

4. 数据预处理

处理缺损值

# 每列中缺失数据的百分比
data.isnull().sum()/data.shape[0]*100
Location          0.000000
MinTemp           1.020899
MaxTemp           0.866905
Rainfall          2.241853
Evaporation      43.166506
Sunshine         48.009762
WindGustDir       7.098859
WindGustSpeed     7.055548
WindDir9am        7.263853
WindDir3pm        2.906641
WindSpeed9am      1.214767
WindSpeed3pm      2.105046
Humidity9am       1.824557
Humidity3pm       3.098446
Pressure9am      10.356799
Pressure3pm      10.331363
Cloud9am         38.421559
Cloud3pm         40.807095
Temp9am           1.214767
Temp3pm           2.481094
RainToday         2.241853
RainTomorrow      2.245978
year              0.000000
Month             0.000000
day               0.000000
dtype: float64
# 在该列中随机选择数进行填充
lst=['Evaporation','Sunshine','Cloud9am','Cloud3pm']
for col in lst:
    fill_list = data[col].dropna()
    data[col] = data[col].fillna(pd.Series(np.random.choice(fill_list, size=len(data.index))))
s = (data.dtypes == "object")
object_cols = list(s[s].index)
object_cols
['Location',
 'WindGustDir',
 'WindDir9am',
 'WindDir3pm',
 'RainToday',
 'RainTomorrow']
# inplace=True:直接修改原对象,不创建副本
# data[i].mode()[0] 返回频率出现最高的选项,众数

for i in object_cols:
    data[i].fillna(data[i].mode()[0], inplace=True)
t = (data.dtypes == "float64")
num_cols = list(t[t].index)
num_cols
['MinTemp',
 'MaxTemp',
 'Rainfall',
 'Evaporation',
 'Sunshine',
 'WindGustSpeed',
 'WindSpeed9am',
 'WindSpeed3pm',
 'Humidity9am',
 'Humidity3pm',
 'Pressure9am',
 'Pressure3pm',
 'Cloud9am',
 'Cloud3pm',
 'Temp9am',
 'Temp3pm']
# .median(), 中位数
for i in num_cols:
    data[i].fillna(data[i].median(), inplace=True)
data.isnull().sum()
Location         0
MinTemp          0
MaxTemp          0
Rainfall         0
Evaporation      0
Sunshine         0
WindGustDir      0
WindGustSpeed    0
WindDir9am       0
WindDir3pm       0
WindSpeed9am     0
WindSpeed3pm     0
Humidity9am      0
Humidity3pm      0
Pressure9am      0
Pressure3pm      0
Cloud9am         0
Cloud3pm         0
Temp9am          0
Temp3pm          0
RainToday        0
RainTomorrow     0
year             0
Month            0
day              0
dtype: int64

5. 构建数据集

from sklearn.preprocessing import LabelEncoder

label_encoder = LabelEncoder()
for i in object_cols:
    data[i] = label_encoder.fit_transform(data[i])
X = data.drop(['RainTomorrow','day'],axis=1).values
y = data['RainTomorrow'].values
X_train, X_test, y_train, y_test = train_test_split(X,y,test_size=0.25,random_state=101)
scaler = MinMaxScaler()
scaler.fit(X_train)
X_train = scaler.transform(X_train)
X_test  = scaler.transform(X_test)
# 创建pytorch dataset 和 dataloader
"""
在 PyTorch 中,我们通常先将 NumPy 数组转换为 torch.Tensor,
再封装到 TensorDataset 或自定义的 Dataset 里,然后用 DataLoader 按批次加载。
"""
import torch
from torch.utils.data import Dataset, DataLoader, TensorDataset

X_train = X_train.reshape(X_train.shape[0],X_train.shape[1],1)
X_test = X_test.reshape(X_test.shape[0],X_test.shape[1],1)

# 如果要做二分类 + Sigmoid + nn.BCELoss,那么标签可以用 float32
# 如果要做多分类(例如 softmax + CrossEntropy),则需把标签转为 long
y_train = y_train.astype(np.float32)  # 二分类: float32
y_test = y_test.astype(np.float32)    # 二分类: float32

# 转换为张量
X_train_tensor = torch.from_numpy(X_train).float()  # shape:[samples, 13, 1]
y_train_tensor = torch.from_numpy(y_train)          # shape:[samples]

X_test_tensor = torch.from_numpy(X_test).float()
y_test_tensor = torch.from_numpy(y_test)

# 如果后续需要在训练中对标签执行 pred>0.5 判定,可以保持 y 的 shape=[samples] 即可
# 也可 reshape([-1,1]) 保持和网络输出尺寸一致,不过这并非必须。
# y_train_tensor = y_train_tensor.view(-1,1)
# y_test_tensor = y_test_tensor.view(-1,1)

# 用 TensorDataset 直接封装
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

# 创建 DataLoader
batch_size = 32

train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dl = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

6. 定义模型

### 构建RNN模型
# -----------------------------
# 1. 定义模型结构
# -----------------------------
class SimpleRNNModel(nn.Module):
    def __init__(self):
        super(SimpleRNNModel, self).__init__()
        # TensorFlow 中 input_shape=(13,1),即序列长度 seq_len = 13,特征维度 input_dim = 1
        # PyTorch RNN 层若设置 batch_first=True:
        #   输入张量形状: (batch_size, seq_len, input_dim)
        #   输出张量形状: (batch_size, seq_len, hidden_size)
        self.rnn = nn.RNN(
            input_size=1,         # 对应 TF 的 input_dim=1
            hidden_size=200,      # 对应 TF 的 RNN(200)
            batch_first=True,
            nonlinearity='relu'   # 对应 TF 的 activation='relu' 
        )
        
        self.fc1 = nn.Linear(200, 100)  # 对应 Dense(100, activation='relu')
        self.fc2 = nn.Linear(100, 1)    # 对应 Dense(1, activation='sigmoid')
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: [batch_size, 13, 1]
        # RNN 输出: output, hidden
        #   output shape = [batch_size, seq_len, hidden_size]
        #   hidden shape = [num_layers, batch_size, hidden_size]
        out, hidden = self.rnn(x)

        # 取最后一个 time_step 的输出, 与 TensorFlow 里 SimpleRNN 的默认行为一致
        out = out[:, -1, :]  # shape: [batch_size, hidden_size]
        
        # 与 Dense(100, relu)
        out = F.relu(self.fc1(out)) # [batch_size, 100]
        
        # 与 Dense(1, sigmoid)
        out = self.sigmoid(self.fc2(out)) # [batch_size, 1]
        return out

7. 初始化模型与优化器

# -----------------------------
# 2. 初始化模型与优化器
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = SimpleRNNModel().to(device)
print(model)

# 与 TF 中 loss='binary_crossentropy' 对应,PyTorch 用 BCE:nn.BCELoss
loss_fn = nn.BCELoss()

# 多分类问题使用nn.CrossEntropyLoss()
# criterion = nn.CrossEntropyLoss()


learn_rate = 1e-4  
# learn_rate = 3e-4
lambda1 = lambda epoch:(0.92**(epoch//2))

optimizer = torch.optim.Adam(model.parameters(),lr = learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda=lambda1) # 选定调整方法

SimpleRNNModel(
  (rnn): RNN(1, 200, batch_first=True)
  (fc1): Linear(in_features=200, out_features=100, bias=True)
  (fc2): Linear(in_features=100, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

8. 训练函数

def train(dataloader, model, loss_fn, optimizer):
    size = len(dataloader.dataset)  # 训练集大小
    num_batches = len(dataloader)   # 批次数目
    
    train_loss, train_acc = 0, 0
    
    for X, y in dataloader:
        X, y = X.to(device), y.to(device)

        # 计算预测
        pred = model(X).view(-1) # [batch_size]
        loss = loss_fn(pred, y)
        
        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 记录acc与loss
        # 情况1: 如果是多分类(N>1), pred.shape=[batch_size, N],可以用argmax(1).
        # train_acc += (pred.argmax(1) == y).type(torch.float).sum().item()
        
        # 情况2: 如果是二分类且只有1个输出(使用 Sigmoid),则 pred.shape=[batch_size,1],
        # 那么可用 (pred>0.5) 转为0/1来比较:
        pred_label = (pred > 0.5).long() # [batch_size]
        train_acc += (pred_label == y.long()).sum().item()
        
        train_loss += loss.item()
    
    train_acc /= size
    train_loss /= num_batches
    
    return train_acc, train_loss

9. 测试函数

def test(dataloader, model, loss_fn):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    
    test_acc, test_loss = 0, 0
    
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            
            # 计算预测
            pred = model(X).view(-1) # [batch_size]
            loss = loss_fn(pred, y)
            
            # 情况1: 多分类(N>1):
            # test_acc += (pred.argmax(1) == y).type(torch.float).sum().item()

            # 情况2: 二分类单输出:
            pred_label = (pred > 0.5).long() # [batch_size]
            # test_acc += (pred_label.view(-1) == y).type(torch.float).sum().item()
            test_acc += (pred_label == y.long()).sum().item()
            
            test_loss += loss.item()
    
    test_acc /= size
    test_loss /= num_batches
    
    return test_acc, test_loss

10. 执行训练

# -----------------------------
# 打印可用 GPU 信息
# -----------------------------
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f"GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"Initial Memory Allocated: {torch.cuda.memory_allocated(i)/1024**2:.2f} MB")
        print(f"Initial Memory Reserved: {torch.cuda.memory_reserved(i)/1024**2:.2f} MB")
else:
    print("No GPU available. Using CPU.")

# -----------------------------
# 训练主循环
# -----------------------------
epochs = 60

train_acc_list = []
train_loss_list = []
test_acc_list = []
test_loss_list = []

best_acc = 0.0
best_model = None

for epoch in range(epochs):
    # 更新学习率——使用自定义学习率时使用
    # adjust_learning_rate(optimizer,epoch,learn_rate)
    
    # 切换为训练模式
    model.train()
    epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)
    
    # 更新学习率
    scheduler.step() # 更新学习率——调用官方动态学习率时使用
    
    # 切换为评估模式
    model.eval()
    epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)
    
    # 保存最佳模型
    if epoch_test_acc > best_acc:
        best_acc = epoch_test_acc
        best_model = copy.deepcopy(model)
    
    train_acc_list.append(epoch_train_acc)
    train_loss_list.append(epoch_train_loss)
    test_acc_list.append(epoch_test_acc)
    test_loss_list.append(epoch_test_loss)
    
    # 当前学习率
    lr = optimizer.state_dict()['param_groups'][0]['lr']
    
    template = (
        'Epoch:{:2d}, '
        'Train_acc:{:.1f}%, Train_loss:{:.3f}, '
        'Test_acc:{:.1f}%, Test_loss:{:.3f}, '
        'Lr:{:.2E}'
    )
    print(template.format(
        epoch+1,
        epoch_train_acc*100, epoch_train_loss,
        epoch_test_acc*100, epoch_test_loss,
        lr
    ))
    
    # 实时监控 GPU 状态
    if torch.cuda.is_available():
        for i in range(torch.cuda.device_count()):
            print(f"GPU {i} Usage:")
            print(f"  Memory Allocated: {torch.cuda.memory_allocated(i)/1024**2:.2f} MB")
            print(f"  Memory Reserved: {torch.cuda.memory_reserved(i)/1024**2:.2f} MB")
            print(f"  Max Memory Allocated: {torch.cuda.max_memory_allocated(i)/1024**2:.2f} MB")
            print(f"  Max Memory Reserved: {torch.cuda.max_memory_reserved(i)/1024**2:.2f} MB")

print('Done. Best test acc: ', best_acc)
No GPU available. Using CPU.
Epoch: 1, Train_acc:80.1%, Train_loss:0.460, Test_acc:82.7%, Test_loss:0.397, Lr:1.00E-04
Epoch: 2, Train_acc:83.4%, Train_loss:0.387, Test_acc:83.8%, Test_loss:0.374, Lr:9.20E-05
Epoch: 3, Train_acc:83.9%, Train_loss:0.375, Test_acc:84.1%, Test_loss:0.367, Lr:9.20E-05
Epoch: 4, Train_acc:83.9%, Train_loss:0.370, Test_acc:84.2%, Test_loss:0.365, Lr:8.46E-05
Epoch: 5, Train_acc:84.1%, Train_loss:0.368, Test_acc:83.9%, Test_loss:0.375, Lr:8.46E-05
Epoch: 6, Train_acc:84.3%, Train_loss:0.366, Test_acc:84.3%, Test_loss:0.364, Lr:7.79E-05
Epoch: 7, Train_acc:84.3%, Train_loss:0.365, Test_acc:84.3%, Test_loss:0.363, Lr:7.79E-05
Epoch: 8, Train_acc:84.3%, Train_loss:0.364, Test_acc:84.3%, Test_loss:0.362, Lr:7.16E-05
Epoch: 9, Train_acc:84.3%, Train_loss:0.364, Test_acc:84.4%, Test_loss:0.362, Lr:7.16E-05
Epoch:10, Train_acc:84.4%, Train_loss:0.362, Test_acc:84.3%, Test_loss:0.363, Lr:6.59E-05
Epoch:11, Train_acc:84.3%, Train_loss:0.361, Test_acc:84.4%, Test_loss:0.363, Lr:6.59E-05
Epoch:12, Train_acc:84.4%, Train_loss:0.361, Test_acc:84.4%, Test_loss:0.359, Lr:6.06E-05
Epoch:13, Train_acc:84.5%, Train_loss:0.360, Test_acc:84.4%, Test_loss:0.362, Lr:6.06E-05
Epoch:14, Train_acc:84.4%, Train_loss:0.360, Test_acc:84.5%, Test_loss:0.359, Lr:5.58E-05
Epoch:15, Train_acc:84.5%, Train_loss:0.358, Test_acc:84.4%, Test_loss:0.358, Lr:5.58E-05
Epoch:16, Train_acc:84.5%, Train_loss:0.358, Test_acc:84.5%, Test_loss:0.361, Lr:5.13E-05
Epoch:17, Train_acc:84.6%, Train_loss:0.357, Test_acc:84.5%, Test_loss:0.358, Lr:5.13E-05
Epoch:18, Train_acc:84.6%, Train_loss:0.357, Test_acc:84.6%, Test_loss:0.357, Lr:4.72E-05
Epoch:19, Train_acc:84.6%, Train_loss:0.356, Test_acc:84.6%, Test_loss:0.357, Lr:4.72E-05
Epoch:20, Train_acc:84.7%, Train_loss:0.356, Test_acc:84.6%, Test_loss:0.356, Lr:4.34E-05
Epoch:21, Train_acc:84.6%, Train_loss:0.355, Test_acc:84.6%, Test_loss:0.356, Lr:4.34E-05
Epoch:22, Train_acc:84.6%, Train_loss:0.355, Test_acc:84.6%, Test_loss:0.356, Lr:4.00E-05
Epoch:23, Train_acc:84.7%, Train_loss:0.354, Test_acc:84.6%, Test_loss:0.356, Lr:4.00E-05
Epoch:24, Train_acc:84.7%, Train_loss:0.354, Test_acc:84.5%, Test_loss:0.358, Lr:3.68E-05
Epoch:25, Train_acc:84.7%, Train_loss:0.353, Test_acc:84.6%, Test_loss:0.357, Lr:3.68E-05
Epoch:26, Train_acc:84.8%, Train_loss:0.353, Test_acc:84.7%, Test_loss:0.354, Lr:3.38E-05
Epoch:27, Train_acc:84.7%, Train_loss:0.352, Test_acc:84.7%, Test_loss:0.353, Lr:3.38E-05
Epoch:28, Train_acc:84.8%, Train_loss:0.352, Test_acc:84.7%, Test_loss:0.354, Lr:3.11E-05
Epoch:29, Train_acc:84.8%, Train_loss:0.352, Test_acc:84.8%, Test_loss:0.354, Lr:3.11E-05
Epoch:30, Train_acc:84.9%, Train_loss:0.352, Test_acc:84.8%, Test_loss:0.353, Lr:2.86E-05
Epoch:31, Train_acc:84.9%, Train_loss:0.351, Test_acc:84.7%, Test_loss:0.356, Lr:2.86E-05
Epoch:32, Train_acc:84.9%, Train_loss:0.351, Test_acc:84.6%, Test_loss:0.354, Lr:2.63E-05
Epoch:33, Train_acc:84.8%, Train_loss:0.350, Test_acc:84.8%, Test_loss:0.352, Lr:2.63E-05
Epoch:34, Train_acc:84.9%, Train_loss:0.350, Test_acc:84.7%, Test_loss:0.354, Lr:2.42E-05
Epoch:35, Train_acc:84.9%, Train_loss:0.350, Test_acc:84.8%, Test_loss:0.352, Lr:2.42E-05
Epoch:36, Train_acc:84.9%, Train_loss:0.350, Test_acc:84.6%, Test_loss:0.354, Lr:2.23E-05
Epoch:37, Train_acc:84.9%, Train_loss:0.349, Test_acc:84.8%, Test_loss:0.353, Lr:2.23E-05
Epoch:38, Train_acc:84.9%, Train_loss:0.349, Test_acc:84.6%, Test_loss:0.356, Lr:2.05E-05
Epoch:39, Train_acc:85.0%, Train_loss:0.348, Test_acc:85.0%, Test_loss:0.351, Lr:2.05E-05
Epoch:40, Train_acc:85.0%, Train_loss:0.349, Test_acc:84.8%, Test_loss:0.351, Lr:1.89E-05
Epoch:41, Train_acc:85.0%, Train_loss:0.348, Test_acc:84.8%, Test_loss:0.351, Lr:1.89E-05
Epoch:42, Train_acc:85.0%, Train_loss:0.348, Test_acc:84.9%, Test_loss:0.351, Lr:1.74E-05
Epoch:43, Train_acc:85.0%, Train_loss:0.347, Test_acc:85.0%, Test_loss:0.350, Lr:1.74E-05
Epoch:44, Train_acc:85.0%, Train_loss:0.347, Test_acc:85.0%, Test_loss:0.351, Lr:1.60E-05
Epoch:45, Train_acc:85.1%, Train_loss:0.347, Test_acc:84.9%, Test_loss:0.350, Lr:1.60E-05
Epoch:46, Train_acc:85.1%, Train_loss:0.347, Test_acc:84.9%, Test_loss:0.350, Lr:1.47E-05
Epoch:47, Train_acc:85.0%, Train_loss:0.347, Test_acc:84.9%, Test_loss:0.351, Lr:1.47E-05
Epoch:48, Train_acc:85.0%, Train_loss:0.346, Test_acc:84.9%, Test_loss:0.350, Lr:1.35E-05
Epoch:49, Train_acc:85.1%, Train_loss:0.346, Test_acc:85.0%, Test_loss:0.349, Lr:1.35E-05
Epoch:50, Train_acc:85.1%, Train_loss:0.346, Test_acc:84.9%, Test_loss:0.350, Lr:1.24E-05
Epoch:51, Train_acc:85.1%, Train_loss:0.346, Test_acc:84.9%, Test_loss:0.349, Lr:1.24E-05
Epoch:52, Train_acc:85.1%, Train_loss:0.346, Test_acc:84.9%, Test_loss:0.350, Lr:1.14E-05
Epoch:53, Train_acc:85.1%, Train_loss:0.345, Test_acc:84.9%, Test_loss:0.349, Lr:1.14E-05
Epoch:54, Train_acc:85.1%, Train_loss:0.345, Test_acc:85.0%, Test_loss:0.349, Lr:1.05E-05
Epoch:55, Train_acc:85.1%, Train_loss:0.345, Test_acc:85.0%, Test_loss:0.349, Lr:1.05E-05
Epoch:56, Train_acc:85.1%, Train_loss:0.345, Test_acc:84.8%, Test_loss:0.350, Lr:9.68E-06
Epoch:57, Train_acc:85.1%, Train_loss:0.345, Test_acc:85.0%, Test_loss:0.348, Lr:9.68E-06
Epoch:58, Train_acc:85.1%, Train_loss:0.344, Test_acc:84.8%, Test_loss:0.350, Lr:8.91E-06
Epoch:59, Train_acc:85.2%, Train_loss:0.344, Test_acc:85.0%, Test_loss:0.348, Lr:8.91E-06
Epoch:60, Train_acc:85.2%, Train_loss:0.344, Test_acc:84.9%, Test_loss:0.349, Lr:8.20E-06
Done. Best test acc:  0.8500206242265915

11. 过程可视化

epochs_range = range(epochs)

plt.figure(figsize=(12, 5))

# 准确率曲线
plt.subplot(1, 2, 1)
plt.plot(epochs_range, train_acc_list, label='Training Accuracy')
plt.plot(epochs_range, test_acc_list, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

# 损失曲线
plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss_list, label='Training Loss')
plt.plot(epochs_range, test_loss_list, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')

plt.show()

请添加图片描述


Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐