文本分类:基于BERT模型处理新闻分类
使用BERT模型完成基础的NLP新闻分类任务
项目代码详见:https://github.com/xiaozhou-alt/text-classifier
一、项目介绍
项目主要基于 THUCNews 中文文本分类数据集,包含83万新闻数据,采用了 wobert 以词为单位的中文 BERT 模型进行文本分类任务。
话不多说,让我们进入项目。
二、模型介绍
BERT(Bidirectional Encoder Representations from Transformers)是由Google于2018年提出的基于Transformer架构的预训练语言模型,标志着自然语言处理领域进入预训练大模型时代。其核心创新在于通过双向Transformer编码器捕捉上下文语义,突破了传统单向语言模型的限制。BERT采用掩码语言建模(MLM)和下一句预测(NSP)两大预训练任务,在大规模无标注文本(如Wikipedia)上学习通用语言表示,通过随机遮蔽15%的词汇要求模型还原原始文本,同时判断句子间的逻辑关系。这种预训练范式使模型能够捕获词汇、句法和语义的多层次特征,通过微调即可适配文本分类、问答、实体识别等下游任务。
以下是项目中使用的 wobert 模型与 bert-base 模型的对比:
| 特性 | BERT-base | WoBERT |
|---|---|---|
| 基础架构 | Transformer Encoder | Transformer Encoder + 词结构优化 |
| 层数 | 12层 | 12层 |
| 隐藏层维度 | 768 | 768 |
| 注意力头数 | 12 | 12 |
| 位置编码 | 绝对位置编码 | 相对位置编码(部分变体) |
分词层(Tokenizer)设计:
# BERT中文版分词示例
text = "自然语言处理"
tokenizer.tokenize(text) → ['自', '然', '语', '言', '处', '理'] # 字粒度
# WoBERT分词示例
tokenizer.tokenize(text) → ['自然', '语言', '处理'] # 词粒度
- BERT:采用字级别的WordPiece分词
- 优点:避免分词错误传播
- 缺点:丢失词边界信息
- WoBERT:基于词的分词(需预训练词表)
- 优点:保留词汇语义完整性
- 缺点:词表膨胀风险
三、项目实现
数据集介绍:
项目使用THUCNews中文文本分类数据集,该数据集包含 83 万篇新闻文档,总计 14 类;在数据集的基础上可以进行文本分类、词向量的训练等任务。数据集的下载地址:(http://thuctc.thunlp.org/)
数据组成展示图:
下载后的数据文件格式如下:
- 财经
- 1.txt
- [标题]
[记者](如果有记录)
[正文]
- [标题]
- 2.txt
…
- 1.txt
- 体育
- 时政
… - 教育
数据预处理
1.数据统计
def count_text(category_path):
total = {}
for dir in glob(category_path + '/*'): # 遍历每个类别目录
total[dir] = len(glob(dir + '/*.txt')) # 统计txt文件数
print(f"共{len(total)}类, 总文件数:{sum(total.values())}")
return total
作用:统计原始数据的类别分布
输入:THUCNews / 目录路径
输出:字典格式 {“/data/THUCNews/体育”: 1000, …}
2.数据集划分
def cut_corpus(path):
label = re.findall(r"[\u4e00-\u9fa5]+", path) # 从路径提取中文标签
files = glob(path + '/*.txt')
# 分层随机划分
train, test = train_test_split(files, test_size=0.3, random_state=2020)
valid, test = train_test_split(test, test_size=0.5, random_state=2021)
return train, test, valid, label
首先将整个数据集按照 7 : 3 划分为训练集和测试集;然后将测试集进一步按照 1 : 1 划分为验证集和测试集
3.保存数据为csv文件
def process(path_dict, filename='train', frac=1):
samples = []
for label, data in path_dict.items():
sample = read_data(data, label, debug=True, frac=frac)
samples.append(sample)
df = pd.concat(samples)
save_path = f"{os.path.dirname(__file__)}/data/{filename}.csv"
df.to_csv(save_path, sep='\t', index=False)
使用正则表达式提取出标签:
re.findall(r"[\u4e00-\u9fa5]+", path) # 匹配路径中的中文字符
示例:/project/data/THUCNews/ 体育 → 提取出标签体育
生成文件示例(train.csv):
| title | content | label |
|---|---|---|
| 欧冠决赛… | 北京时间5月30日… | 体育 |
| … | … | … |
| 电影票房… | 据猫眼专业版数据… | 娱乐 |
剩余的两个csv格式同 trian.csv 一致
4.标签映射
def process(path_dict, filename='train', frac=1):
samples = []
for label, data in path_dict.items():
sample = read_data(data, label, debug=True, frac=frac)
samples.append(sample)
df = pd.concat(samples)
save_path = f"{os.path.dirname(__file__)}/data/{filename}.csv"
df.to_csv(save_path, sep='\t', index=False)
生成文件 label2id.json:
{"体育": 0, "娱乐": 1, "时政": 2, ...}
运行效果展示:

模型训练
主要组成模块:
1. NewsDataset : 数据加载与预处理(继承torch Dataset)
2. NewsClassifier : 核心训练类
|- 模型初始化
|- 训练流程(train)
|- 评估函数(evaluate)
|- 预测函数(predict)
3. 主程序 : 配置初始化与流程控制
模型初始化:
def __init__(self, data_path, label2id, tokenizer, max_len=128):
# 读取CSV文件并清洗数据
self.data = pd.read_csv(data_path, sep='\t')
self.data = self.data.dropna(subset=['title', 'label']) # 删除关键字段缺失的样本
self.data = self.data[self.data['title'].str.len() > 0] # 过滤空标题
self.data['title'] = self.data['title'].str.replace(r'\s+', ' ', regex=True) # 标准化空格
- 三阶段清洗:缺失值处理 → 空文本过滤 → 非常规空格替换
- 正则表达式 + 匹配各种空白字符(包括全角空格),统一替换为标准半角空格
- 保留原始数据路径信息,便于后续错误追踪
训练循环:
def train():
# 核心训练逻辑:
for epoch in epochs:
for batch in train_loader:
# 梯度累积实现:
if (step+1) % 4 == 0: # accumulation_steps=4
optimizer.step()
optimizer.zero_grad()
# 记录训练细节:
if global_step % 150 == 0:
record_step(...)
# 验证与模型保存:
val_acc = evaluate(valid_loader)
if val_acc > best_acc:
save_pretrained() # 保存完整模型
- 梯度累积:每4个batch更新一次参数,等效增大batch size至128(32 * 4),提升训练稳定性
- 异步数据传输:non_blocking=True实现CPU-GPU并行数据传输
- 内存锁页:pin_memory=True加速数据加载到GPU的过程
评估与预测:
def evaluate(self, data_loader):
# 多指标计算
all_preds = []
all_labels = []
...
return accuracy_score(all_labels, all_preds), total_loss/len(data_loader)
def _save_training_records(self):
# 双层级监控
pd.DataFrame(self.step_records).to_excel(...) # 详细步骤记录
pd.DataFrame(self.epoch_records).to_excel(...) # 阶段汇总
def predict(self, text, confidence_threshold=0.6):
probs = torch.nn.functional.softmax(...)
if max_prob.item() < confidence_threshold:
return "Unknown", max_prob.item()
- 置信度阈值:过滤低置信度预测(<60%),提升线上服务可靠性
- 概率校准:通过softmax获取归一化概率,避免模型过度自信
- Unknown类别:置信度较低(<60%)的预测结果,标签被设置为Unknown,为后续主动学习提供数据采集入口
优化器配置:
optimizer = AdamW(
self.model.parameters(),
lr=2e-5, # BERT标准学习率
weight_decay=0.01) # L2正则防止过拟合
- 学习率2e-5:遵循BERT论文微调建议,避免破坏预训练参数
- 权重衰减0.01:平衡模型复杂度和拟合能力
- 未使用学习率预热:适合小数据集快速收敛
代码训练过程建议使用云端平台的GPU服务器如:kaggle、modelscope(本人电脑仅为3050 i5较为垃圾,如果使用本地CPU,预测每一个epoch要跑57个小时。。。,故使用云端8核32GB的GPU服务器,每epoch时间缩短到1h10min,nice!ლ(´ڡ`ლ))
模型训练部分过程展示:
预测结果:
模型在各类新闻数据上的表现( support 表示数据量):
混淆矩阵输出:
如果你喜欢我的文章,不妨给小周一个免费的点赞和关注吧!
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐


所有评论(0)