项目代码详见:https://github.com/xiaozhou-alt/text-classifier


一、项目介绍

项目主要基于 THUCNews 中文文本分类数据集,包含83万新闻数据,采用了 wobert 以词为单位的中文 BERT 模型进行文本分类任务。
话不多说,让我们进入项目。

二、模型介绍

BERT(Bidirectional Encoder Representations from Transformers)是由Google于2018年提出的基于Transformer架构的预训练语言模型,标志着自然语言处理领域进入预训练大模型时代。其核心创新在于通过双向Transformer编码器捕捉上下文语义,突破了传统单向语言模型的限制。BERT采用掩码语言建模(MLM)和下一句预测(NSP)两大预训练任务,在大规模无标注文本(如Wikipedia)上学习通用语言表示,通过随机遮蔽15%的词汇要求模型还原原始文本,同时判断句子间的逻辑关系。这种预训练范式使模型能够捕获词汇、句法和语义的多层次特征,通过微调即可适配文本分类、问答、实体识别等下游任务。
在这里插入图片描述
以下是项目中使用的 wobert 模型与 bert-base 模型的对比:

特性 BERT-base WoBERT
基础架构 Transformer Encoder Transformer Encoder + 词结构优化
​层数 12层 12层
​隐藏层维度 768 768
​注意力头数 12 12
​位置编码 绝对位置编码 相对位置编码(部分变体)

分词层(Tokenizer)设计

# BERT中文版分词示例
text = "自然语言处理"
tokenizer.tokenize(text)['自', '然', '语', '言', '处', '理']  # 字粒度

# WoBERT分词示例
tokenizer.tokenize(text)['自然', '语言', '处理']  # 词粒度
  • BERT:采用字级别的WordPiece分词
    • 优点:避免分词错误传播
    • 缺点:丢失词边界信息
  • ​WoBERT:基于词的分词(需预训练词表)
    • 优点:保留词汇语义完整性
    • 缺点:词表膨胀风险

三、项目实现

数据集介绍:

项目使用THUCNews中文文本分类数据集,该数据集包含 83 万篇新闻文档,总计 14 类;在数据集的基础上可以进行文本分类、词向量的训练等任务。数据集的下载地址:(http://thuctc.thunlp.org/)
数据组成展示图:在这里插入图片描述
下载后的数据文件格式如下:

  • 财经
    • 1.txt
      • [标题]
        [记者](如果有记录)
        [正文]
    • 2.txt
  • 体育
  • 时政
  • 教育

数据预处理

1.数据统计

def count_text(category_path):
    total = {}
    for dir in glob(category_path + '/*'):  # 遍历每个类别目录
        total[dir] = len(glob(dir + '/*.txt'))  # 统计txt文件数
    print(f"共{len(total)}类, 总文件数:{sum(total.values())}")
    return total

作用:统计原始数据的类别分布
输入THUCNews / 目录路径
输出:字典格式 {“/data/THUCNews/体育”: 1000, …}

2.数据集划分

def cut_corpus(path):
    label = re.findall(r"[\u4e00-\u9fa5]+", path)  # 从路径提取中文标签
    files = glob(path + '/*.txt')
    
    # 分层随机划分
    train, test = train_test_split(files, test_size=0.3, random_state=2020)
    valid, test = train_test_split(test, test_size=0.5, random_state=2021)
    return train, test, valid, label

首先将整个数据集按照 7 : 3 划分为训练集和测试集;然后将测试集进一步按照 1 : 1 划分为验证集和测试集

3.保存数据为csv文件

def process(path_dict, filename='train', frac=1):
    samples = []
    for label, data in path_dict.items():
        sample = read_data(data, label, debug=True, frac=frac)
        samples.append(sample)
    
    df = pd.concat(samples)
    save_path = f"{os.path.dirname(__file__)}/data/{filename}.csv"
    df.to_csv(save_path, sep='\t', index=False)

使用正则表达式提取出标签:

re.findall(r"[\u4e00-\u9fa5]+", path)  # 匹配路径中的中文字符

示例:/project/data/THUCNews/ 体育 → 提取出标签体育
生成文件示例​(train.csv):

title content label
欧冠决赛… 北京时间5月30日… 体育
电影票房… 据猫眼专业版数据… 娱乐

剩余的两个csv格式同 trian.csv 一致

4.标签映射

def process(path_dict, filename='train', frac=1):
    samples = []
    for label, data in path_dict.items():
        sample = read_data(data, label, debug=True, frac=frac)
        samples.append(sample)
    
    df = pd.concat(samples)
    save_path = f"{os.path.dirname(__file__)}/data/{filename}.csv"
    df.to_csv(save_path, sep='\t', index=False)

生成文件 ​label2id.json

{"体育": 0, "娱乐": 1, "时政": 2, ...}

运行效果展示:
各类别及其数量
数据清洗以及划分过程

模型训练

主要组成模块

1. NewsDataset : 数据加载与预处理(继承torch Dataset)
2. NewsClassifier : 核心训练类
   |- 模型初始化
   |- 训练流程(train)
   |- 评估函数(evaluate)
   |- 预测函数(predict)
3. 主程序 : 配置初始化与流程控制

模型初始化

def __init__(self, data_path, label2id, tokenizer, max_len=128):
    # 读取CSV文件并清洗数据
    self.data = pd.read_csv(data_path, sep='\t')
    self.data = self.data.dropna(subset=['title', 'label'])  # 删除关键字段缺失的样本
    self.data = self.data[self.data['title'].str.len() > 0]  # 过滤空标题
    self.data['title'] = self.data['title'].str.replace(r'\s+', ' ', regex=True)  # 标准化空格
  • 三阶段清洗:缺失值处理 → 空文本过滤 → 非常规空格替换
  • 正则表达式 + 匹配各种空白字符(包括全角空格),统一替换为标准半角空格
  • 保留原始数据路径信息,便于后续错误追踪

训练循环

def train():
    # 核心训练逻辑:
    for epoch in epochs:
        for batch in train_loader:
            # 梯度累积实现:
            if (step+1) % 4 == 0:  # accumulation_steps=4
                optimizer.step()
                optimizer.zero_grad()
            
            # 记录训练细节:
            if global_step % 150 == 0:
                record_step(...)
        
        # 验证与模型保存:
        val_acc = evaluate(valid_loader)
        if val_acc > best_acc:
            save_pretrained()  # 保存完整模型
  • 梯度累积:每4个batch更新一次参数,等效增大batch size至128(32 * 4),提升训练稳定性
  • ​异步数据传输:non_blocking=True实现CPU-GPU并行数据传输
  • 内存锁页:pin_memory=True加速数据加载到GPU的过程

评估与预测

def evaluate(self, data_loader):
    # 多指标计算
    all_preds = []
    all_labels = []
    ...
    return accuracy_score(all_labels, all_preds), total_loss/len(data_loader)

def _save_training_records(self):
    # 双层级监控
    pd.DataFrame(self.step_records).to_excel(...)  # 详细步骤记录
    pd.DataFrame(self.epoch_records).to_excel(...)  # 阶段汇总

def predict(self, text, confidence_threshold=0.6):
    probs = torch.nn.functional.softmax(...)
    if max_prob.item() < confidence_threshold:
        return "Unknown", max_prob.item()
  • 置信度阈值:过滤低置信度预测(<60%),提升线上服务可靠性
  • ​概率校准:通过softmax获取归一化概率,避免模型过度自信
  • ​Unknown类别:置信度较低(<60%)的预测结果,标签被设置为Unknown,为后续主动学习提供数据采集入口

优化器配置

optimizer = AdamW(
    self.model.parameters(), 
    lr=2e-5,  # BERT标准学习率
    weight_decay=0.01)  # L2正则防止过拟合
  • 学习率2e-5:遵循BERT论文微调建议,避免破坏预训练参数
  • 权重衰减0.01:平衡模型复杂度和拟合能力
  • 未使用学习率预热:适合小数据集快速收敛

代码训练过程建议使用云端平台的GPU服务器如:kaggle、modelscope(本人电脑仅为3050 i5较为垃圾,如果使用本地CPU,预测每一个epoch要跑57个小时。。。,故使用云端8核32GB的GPU服务器,每epoch时间缩短到1h10min,nice!ლ(´ڡ`ლ))

模型训练部分过程展示:
在这里插入图片描述
预测结果:
在这里插入图片描述
模型在各类新闻数据上的表现( support 表示数据量):
在这里插入图片描述
混淆矩阵输出:
在这里插入图片描述
如果你喜欢我的文章,不妨给小周一个免费的点赞和关注吧!

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐