注意:本实验在百度平台的 AI Studio 上运行

1.什么是文本分类

文本分类就是根据文本内容将文本划分到不同类别,例如新闻系统中,每篇新闻报道会划归到不同的类别。

2.文本分类的应用

  • 内容分类(新闻分类)

  • 邮件过滤(例如垃圾邮件过滤)

  • 用户分类(如商城消费级别、喜好)

  • 评论、文章、对话的情感分类(正面、负面、中性)

3.文本分类案例

  • 任务:建立文本分类模型,并对模型进行训练、评估,从而实现对中文新闻摘要类别的正确划分

  • 数据集:从网站上爬取的56821条中文新闻摘要数据,包含10种类别:国际、文化、娱乐、体育、财经、汽车、教育、科技、房产、证券,各类别样本数量如下表所示:

  • 模型选择:

  • 步骤:

  • 代码

    【预处理部分】

    ########################### 数据预处理 #########################
    import os
    from multiprocessing import cpu_count
    import numpy as np
    import paddle
    import paddle.fluid as fluid
    
    # Shared settings used by every stage of the pipeline.
    data_root = "data/"  # directory that holds the dataset (keeps its trailing slash)
    data_file = "news_classify_data.txt"  # raw dataset
    train_file = "train.txt"  # training-set file
    test_file = "test.txt"  # test-set file
    dict_file = "dict_txt.txt"  # dictionary file (character -> code mapping)

    # Full paths: the directory prefix immediately followed by the file name.
    data_file_path = f"{data_root}{data_file}"
    train_file_path = f"{data_root}{train_file}"
    test_file_path = f"{data_root}{test_file}"
    dict_file_path = f"{data_root}{dict_file}"
    
    # Collect every character that appears in the sample titles, give each one
    # an integer code and store the mapping in the dictionary file.
    def create_dict(data_path=None, dict_path=None):
        """Build the character-to-code dictionary and write it to disk.

        Each line of the raw data file is split on ``"_!_"`` and the last field
        is treated as the title. Every distinct title character receives a code
        starting at 1; the extra key ``"<unk>"`` (for characters never seen in
        the samples) receives the largest code. The mapping is written as the
        ``str()`` of a dict on a single line.

        Characters are numbered in sorted order so that re-running the function
        on the same data always yields the same codes. Iteration order of a set
        of strings can differ between interpreter runs, which would silently
        invalidate a model trained against an earlier dictionary file.

        NOTE(review): create_train_test_file reads the title from field index 3;
        the two only agree when every line has exactly four fields.

        :param data_path: raw data file; defaults to the module-level
            ``data_file_path``
        :param dict_path: output dictionary file; defaults to the module-level
            ``dict_file_path``
        """
        if data_path is None:
            data_path = data_file_path
        if dict_path is None:
            dict_path = dict_file_path

        chars = set()  # set membership removes duplicates
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:  # stream the file instead of loading it whole
                fields = line.replace("\n", "").split("_!_")
                chars.update(fields[-1])  # last field is the title

        # Codes 1..N for the known characters, N+1 for the unknown marker.
        dict_txt = {ch: code for code, ch in enumerate(sorted(chars), start=1)}
        dict_txt["<unk>"] = len(dict_txt) + 1

        with open(dict_path, "w", encoding="utf-8") as f:
            f.write(str(dict_txt))

        print("生成字典结束.")
    
    # Replace every character of a sentence with its code and return the codes
    # together with the label as one text line.
    def line_encoding(title, dict_txt, label):
        """Encode one title as ``"c1,c2,...,cn\\t<label>\\n"``.

        :param title: sentence to encode, processed character by character
        :param dict_txt: mapping from character to integer code; must contain
            the key ``"<unk>"`` if the title can hold unmapped characters
        :param label: class label appended after a tab (converted with ``str``)
        :return: the encoded line, terminated by a newline; an empty title
            yields an empty code list, i.e. ``"\\t<label>\\n"``
        :raises KeyError: if a character is unmapped and ``"<unk>"`` is missing
        """
        # The "<unk>" lookup stays inside the conditional so that it is only
        # required when an unmapped character is actually encountered.
        codes = [
            str(dict_txt[word] if word in dict_txt else dict_txt["<unk>"])
            for word in title
        ]
        # join() builds the line in one pass instead of repeated concatenation.
        return ",".join(codes) + "\t" + str(label) + "\n"
    
    # Read the raw samples, encode each title and split the encoded lines into
    # a test set and a training set.
    def create_train_test_file():
        """Write the encoded training and test files.

        Reads the dictionary from ``dict_file_path`` and the raw samples from
        ``data_file_path``. Each raw line is split on ``"_!_"``; field 1 is the
        label and field 3 the title. The title is encoded with
        ``line_encoding`` and every 10th usable sample (counting from 0) goes
        to ``test_file_path``, the rest to ``train_file_path``. Both output
        files are overwritten.

        Lines with fewer than four fields or an empty title are skipped: the
        former would raise ``IndexError`` and the latter would produce a line
        with no codes, which the readers cannot parse.
        """
        import ast

        # Load the dictionary before touching the outputs so that a missing or
        # broken dictionary does not leave empty train/test files behind.
        with open(dict_file_path, "r", encoding="utf-8") as f_dict:
            # literal_eval only accepts Python literals, unlike eval().
            dict_txt = ast.literal_eval(f_dict.readline())

        sample_count = 0
        # Opening with "w" truncates both outputs; each file is opened once
        # instead of once per sample.
        with open(data_file_path, "r", encoding="utf-8") as f_data, \
                open(train_file_path, "w", encoding="utf-8") as f_train, \
                open(test_file_path, "w", encoding="utf-8") as f_test:
            for line in f_data:
                fields = line.replace("\n", "").split("_!_")
                if len(fields) < 4 or not fields[3]:
                    continue  # blank/malformed line or empty title
                title = fields[3]
                label = fields[1]
                new_line = line_encoding(title, dict_txt, label)

                # Roughly a 1:9 test/train split.
                target = f_test if sample_count % 10 == 0 else f_train
                target.write(new_line)
                sample_count += 1
        print("生成训练集/测试集结束.")
    
    create_dict() # build the character dictionary from the raw samples
    create_train_test_file() # encode the samples and write the train/test files

    输出:

    生成字典结束.
    生成训练集/测试集结束.

    【模型定义与训练】

    # Switch Paddle to static-graph mode before any fluid layers are created.
    paddle.enable_static()
    # Read the dictionary file and return the vocabulary size for the embedding.
    def get_dict_len(dict_path):
        """Return the number of rows an embedding table needs for this dictionary.

        The dictionary file holds the ``str()`` of a dict mapping characters to
        integer codes on its first line. An embedding of size ``n`` accepts the
        indices ``0..n-1``, so the result is the largest code plus one (and
        never less than the number of entries). Returning only the entry count
        is one too small when the codes start at 1: the highest code, the one
        assigned to ``"<unk>"``, would then be out of range.

        :param dict_path: path of the dictionary file
        :return: vocabulary size; 0 for an empty dictionary
        """
        import ast

        with open(dict_path, "r", encoding="utf-8") as f:
            # literal_eval only accepts Python literals, unlike eval().
            dict_txt = ast.literal_eval(f.readline())
        if not dict_txt:
            return 0
        return max(len(dict_txt), max(dict_txt.values()) + 1)
    
    def data_mapper(sample):
        """Convert one ``(codes, label)`` pair of strings into numbers.

        :param sample: pair of a comma-separated code string and a label string
        :return: ``(list of int codes, int label)``
        """
        encoded, category = sample
        # Values arrive as text because they were read from a file.
        codes = list(map(int, encoded.split(",")))
        return codes, int(category)
    
    def train_reader(train_file_path):
        """Create the reader for the training set.

        The inner generator loads all lines of the file, shuffles them with
        ``np.random.shuffle`` and yields the two tab-separated fields of each
        line as ``(data, label)``. It is wrapped with
        ``paddle.reader.xmap_readers`` so that ``data_mapper`` is applied to
        every sample.

        :param train_file_path: path of the encoded training file
        :return: the reader produced by ``paddle.reader.xmap_readers``
        """
        def reader():
            with open(train_file_path, "r") as f:
                samples = f.readlines()
            np.random.shuffle(samples)  # new random order on every pass
            for sample in samples:
                data, label = sample.split("\t")  # exactly two fields expected
                yield data, label

        workers = cpu_count()
        return paddle.reader.xmap_readers(data_mapper, reader, workers, 1024)
    
    def test_reader(test_file_path):
        """Create the reader for the test set.

        The inner generator yields the two tab-separated fields of each line as
        ``(data, label)`` in file order (no shuffling). It is wrapped with
        ``paddle.reader.xmap_readers`` so that ``data_mapper`` is applied to
        every sample.

        :param test_file_path: path of the encoded test file
        :return: the reader produced by ``paddle.reader.xmap_readers``
        """
        def reader():
            with open(test_file_path, "r") as f:
                samples = f.readlines()
            for sample in samples:
                data, label = sample.split("\t")  # exactly two fields expected
                yield data, label

        workers = cpu_count()
        return paddle.reader.xmap_readers(data_mapper, reader, workers, 1024)
    
    # Network definition
    def Text_CNN(data, dict_dim, class_dim=10, emb_dim=128,
                 hid_dim=128, hid_dim2=128):
        """
        Build the TextCNN model.
        :param data: input variable fed to the embedding layer
        :param dict_dim: dictionary size (first dimension of the embedding)
        :param class_dim: number of classes (size of the softmax output)
        :param emb_dim: embedding length
        :param hid_dim: number of filters of the first convolution branch
        :param hid_dim2: number of filters of the second convolution branch
        :return: output of the final softmax fully-connected layer
        """
        # Embedding layer
        embedded = fluid.layers.embedding(input=data, size=[dict_dim, emb_dim])

        # Two parallel convolution + pooling branches over the same embedding:
        # (number of filters, filter size) for each branch, built in this order.
        branch_specs = ((hid_dim, 3), (hid_dim2, 4))
        pooled = [
            fluid.nets.sequence_conv_pool(input=embedded,
                                          num_filters=filters,
                                          filter_size=window,
                                          act="tanh",
                                          pool_type="sqrt")
            for filters, window in branch_specs
        ]

        # Fully-connected layer over both branches
        return fluid.layers.fc(input=pooled, size=class_dim, act="softmax")
    
    # Placeholder tensors
    words = fluid.layers.data(name="words",
                              shape=[1],
                              dtype="int64",
                              lod_level=1)  # LoD tensor, used for variable-length data
    label = fluid.layers.data(name="label",
                              shape=[1],
                              dtype="int64")
    dict_dim = get_dict_len(dict_file_path)  # dictionary size for the embedding
    # Build the model
    model = Text_CNN(words, dict_dim)
    # Loss
    cost = fluid.layers.cross_entropy(input=model, label=label)
    avg_cost = fluid.layers.mean(cost)
    # Accuracy
    accuracy = fluid.layers.accuracy(input=model, label=label)

    # Clone the forward/metric graph BEFORE the optimizer adds its update ops.
    # Evaluating with the main program would also run the Adam step, i.e. the
    # model would be trained on the test set during every evaluation pass.
    test_program = fluid.default_main_program().clone(for_test=True)

    # Optimizer
    optimizer = fluid.optimizer.Adam(learning_rate=0.0001)
    optimizer.minimize(avg_cost)

    # Executor: use GPU 0 when this Paddle build supports CUDA, else the CPU
    place = fluid.CUDAPlace(0) if paddle.is_compiled_with_cuda() else fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Readers
    ## training-set reader
    tr_reader = train_reader(train_file_path)
    batch_train_reader = paddle.batch(tr_reader, batch_size=128)
    ## test-set reader
    ts_reader = test_reader(test_file_path)
    batch_test_reader = paddle.batch(ts_reader, batch_size=128)

    # Feeder
    feeder = fluid.DataFeeder(place=place, feed_list=[words, label])

    # Training
    for epoch in range(80):  # outer loop: epochs
        for batch_id, data in enumerate(batch_train_reader()):  # inner loop: batches
            train_cost, train_acc = exe.run(fluid.default_main_program(),
                                            feed=feeder.feed(data),
                                            fetch_list=[avg_cost, accuracy])
            if batch_id % 100 == 0:
                print("epoch:%d, batch:%d, cost:%f, acc:%f" %
                      (epoch, batch_id, train_cost[0], train_acc[0]))

        # Evaluate on the test set after every epoch
        test_costs_list = []  # per-batch test loss
        test_accs_list = []  # per-batch test accuracy
        test_batch_sizes = []  # samples per batch (the last batch may be smaller)

        for data in batch_test_reader():
            test_cost, test_acc = exe.run(test_program,
                                          feed=feeder.feed(data),
                                          fetch_list=[avg_cost, accuracy])
            test_costs_list.append(test_cost[0])
            test_accs_list.append(test_acc[0])
            test_batch_sizes.append(len(data))

        if test_batch_sizes:  # an empty test set would otherwise divide by zero
            # Weight each batch by its size so a short final batch does not skew
            # the epoch-level mean.
            avg_test_cost = np.average(test_costs_list, weights=test_batch_sizes)
            avg_test_acc = np.average(test_accs_list, weights=test_batch_sizes)
            print("epoch:%d, test_cost:%f, test_acc:%f" %
                  (epoch, avg_test_cost, avg_test_acc))

    # Training finished: save the inference model
    model_save_dir = "model/"
    os.makedirs(model_save_dir, exist_ok=True)
    fluid.io.save_inference_model(model_save_dir,  # target directory
                                  feeded_var_names=[words.name],  # inputs needed at inference time
                                  target_vars=[model],  # prediction output
                                  executor=exe)
    print("模型保存成功.")

    输出

    W0301 17:44:27.392561   134 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 12.0, Runtime API Version: 11.2
    W0301 17:44:27.396858   134 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
    
    epoch:0, batch:0, cost:2.311182, acc:0.062500
    epoch:0, batch:100, cost:2.063974, acc:0.398438
    epoch:0, batch:200, cost:1.564050, acc:0.531250
    epoch:0, batch:300, cost:1.113311, acc:0.726562
    epoch:0, test\_cost:0.984438, test\_acc:0.699619
    epoch:1, batch:0, cost:0.904973, acc:0.703125
    epoch:1, batch:100, cost:1.056342, acc:0.648438
    epoch:1, batch:200, cost:0.757222, acc:0.773438
    epoch:1, batch:300, cost:0.768483, acc:0.757812
    epoch:1, test\_cost:0.743878, test\_acc:0.760737
    epoch:2, batch:0, cost:0.799179, acc:0.773438
    epoch:2, batch:100, cost:0.684252, acc:0.796875
    epoch:2, batch:200, cost:0.700367, acc:0.773438
    epoch:2, batch:300, cost:0.658518, acc:0.820312
    epoch:2, test\_cost:0.662426, test\_acc:0.787224
    epoch:3, batch:0, cost:0.732371, acc:0.757812
    epoch:3, batch:100, cost:0.621147, acc:0.804688
    epoch:3, batch:200, cost:0.592039, acc:0.820312
    epoch:3, batch:300, cost:0.566666, acc:0.835938
    epoch:3, test\_cost:0.613746, test\_acc:0.799462
    epoch:4, batch:0, cost:0.477513, acc:0.835938
    epoch:4, batch:100, cost:0.510523, acc:0.820312
    epoch:4, batch:200, cost:0.635614, acc:0.765625
    epoch:4, batch:300, cost:0.487310, acc:0.820312
    epoch:4, test\_cost:0.579772, test\_acc:0.810488
    epoch:5, batch:0, cost:0.564993, acc:0.796875
    epoch:5, batch:100, cost:0.406841, acc:0.882812
    epoch:5, batch:200, cost:0.444392, acc:0.867188
    epoch:5, batch:300, cost:0.729928, acc:0.765625
    epoch:5, test\_cost:0.549828, test\_acc:0.820210
    epoch:6, batch:0, cost:0.676264, acc:0.757812
    epoch:6, batch:100, cost:0.596377, acc:0.859375
    epoch:6, batch:200, cost:0.532443, acc:0.773438
    epoch:6, batch:300, cost:0.534213, acc:0.820312
    epoch:6, test\_cost:0.529354, test\_acc:0.824554
    epoch:7, batch:0, cost:0.426611, acc:0.867188
    epoch:7, batch:100, cost:0.474698, acc:0.851562
    epoch:7, batch:200, cost:0.529180, acc:0.851562
    epoch:7, batch:300, cost:0.512545, acc:0.820312
    epoch:7, test\_cost:0.510936, test\_acc:0.830113
    epoch:8, batch:0, cost:0.559594, acc:0.859375
    epoch:8, batch:100, cost:0.670438, acc:0.812500
    epoch:8, batch:200, cost:0.519166, acc:0.796875
    epoch:8, batch:300, cost:0.487624, acc:0.835938
    epoch:8, test\_cost:0.495487, test\_acc:0.831584
    epoch:9, batch:0, cost:0.523822, acc:0.875000
    epoch:9, batch:100, cost:0.604066, acc:0.812500
    epoch:9, batch:200, cost:0.420712, acc:0.835938
    epoch:9, batch:300, cost:0.490420, acc:0.867188
    epoch:9, test\_cost:0.480619, test\_acc:0.838099
    epoch:10, batch:0, cost:0.337900, acc:0.875000
    epoch:10, batch:100, cost:0.513283, acc:0.820312
    epoch:10, batch:200, cost:0.471449, acc:0.835938
    epoch:10, batch:300, cost:0.387642, acc:0.859375
    epoch:10, test\_cost:0.470769, test\_acc:0.842698
    epoch:11, batch:0, cost:0.448336, acc:0.835938
    epoch:11, batch:100, cost:0.440521, acc:0.843750
    epoch:11, batch:200, cost:0.436573, acc:0.875000
    epoch:11, batch:300, cost:0.694043, acc:0.781250
    epoch:11, test\_cost:0.460047, test\_acc:0.846606
    epoch:12, batch:0, cost:0.481746, acc:0.882812
    epoch:12, batch:100, cost:0.552047, acc:0.804688
    epoch:12, batch:200, cost:0.386960, acc:0.882812
    epoch:12, batch:300, cost:0.341771, acc:0.882812
    epoch:12, test\_cost:0.450299, test\_acc:0.850769
    epoch:13, batch:0, cost:0.547187, acc:0.859375
    epoch:13, batch:100, cost:0.356574, acc:0.875000
    epoch:13, batch:200, cost:0.379301, acc:0.867188
    epoch:13, batch:300, cost:0.374310, acc:0.898438
    epoch:13, test\_cost:0.440953, test\_acc:0.856413
    epoch:14, batch:0, cost:0.689324, acc:0.843750
    epoch:14, batch:100, cost:0.493557, acc:0.851562
    epoch:14, batch:200, cost:0.332372, acc:0.898438
    epoch:14, batch:300, cost:0.462879, acc:0.867188
    epoch:14, test\_cost:0.434527, test\_acc:0.858323
    epoch:15, batch:0, cost:0.539430, acc:0.859375
    epoch:15, batch:100, cost:0.610829, acc:0.828125
    epoch:15, batch:200, cost:0.567205, acc:0.773438
    epoch:15, batch:300, cost:0.448473, acc:0.851562
    epoch:15, test\_cost:0.429657, test\_acc:0.858929
    epoch:16, batch:0, cost:0.275449, acc:0.914062
    epoch:16, batch:100, cost:0.322028, acc:0.890625
    epoch:16, batch:200, cost:0.433238, acc:0.843750
    epoch:16, batch:300, cost:0.435924, acc:0.875000
    epoch:16, test\_cost:0.420777, test\_acc:0.864314
    epoch:17, batch:0, cost:0.463425, acc:0.859375
    epoch:17, batch:100, cost:0.498210, acc:0.890625
    epoch:17, batch:200, cost:0.347210, acc:0.898438
    epoch:17, batch:300, cost:0.375353, acc:0.867188
    epoch:17, test\_cost:0.414920, test\_acc:0.866483
    epoch:18, batch:0, cost:0.371144, acc:0.906250
    epoch:18, batch:100, cost:0.511294, acc:0.828125
    epoch:18, batch:200, cost:0.431728, acc:0.828125
    epoch:18, batch:300, cost:0.505222, acc:0.843750
    epoch:18, test\_cost:0.412018, test\_acc:0.866136
    epoch:19, batch:0, cost:0.417319, acc:0.859375
    epoch:19, batch:100, cost:0.405875, acc:0.867188
    epoch:19, batch:200, cost:0.466319, acc:0.843750
    epoch:19, batch:300, cost:0.524598, acc:0.820312
    epoch:19, test\_cost:0.408254, test\_acc:0.870387
    epoch:20, batch:0, cost:0.278774, acc:0.921875
    epoch:20, batch:100, cost:0.375402, acc:0.875000
    epoch:20, batch:200, cost:0.512493, acc:0.851562
    epoch:20, batch:300, cost:0.352869, acc:0.867188
    epoch:20, test\_cost:0.402862, test\_acc:0.870646
    epoch:21, batch:0, cost:0.328388, acc:0.890625
    epoch:21, batch:100, cost:0.474930, acc:0.843750
    epoch:21, batch:200, cost:0.279459, acc:0.898438
    epoch:21, batch:300, cost:0.480916, acc:0.843750
    epoch:21, test\_cost:0.398193, test\_acc:0.870476
    epoch:22, batch:0, cost:0.360476, acc:0.914062
    epoch:22, batch:100, cost:0.399123, acc:0.867188
    epoch:22, batch:200, cost:0.330940, acc:0.898438
    epoch:22, batch:300, cost:0.449070, acc:0.851562
    epoch:22, test\_cost:0.396272, test\_acc:0.872644
    epoch:23, batch:0, cost:0.311765, acc:0.882812
    epoch:23, batch:100, cost:0.430598, acc:0.859375
    epoch:23, batch:200, cost:0.371466, acc:0.867188
    epoch:23, batch:300, cost:0.497460, acc:0.859375
    epoch:23, test\_cost:0.391935, test\_acc:0.874990
    epoch:24, batch:0, cost:0.278461, acc:0.921875
    epoch:24, batch:100, cost:0.384332, acc:0.867188
    epoch:24, batch:200, cost:0.687089, acc:0.804688
    epoch:24, batch:300, cost:0.465835, acc:0.835938
    epoch:24, test\_cost:0.386384, test\_acc:0.874905
    epoch:25, batch:0, cost:0.359800, acc:0.914062
    epoch:25, batch:100, cost:0.370942, acc:0.906250
    epoch:25, batch:200, cost:0.343612, acc:0.906250
    epoch:25, batch:300, cost:0.373149, acc:0.859375
    epoch:25, test\_cost:0.385754, test\_acc:0.875249
    epoch:26, batch:0, cost:0.359912, acc:0.859375
    epoch:26, batch:100, cost:0.299233, acc:0.906250
    epoch:26, batch:200, cost:0.321898, acc:0.882812
    epoch:26, batch:300, cost:0.506139, acc:0.820312
    epoch:26, test\_cost:0.382092, test\_acc:0.877597
    epoch:27, batch:0, cost:0.438806, acc:0.882812
    epoch:27, batch:100, cost:0.351698, acc:0.867188
    epoch:27, batch:200, cost:0.413263, acc:0.875000
    epoch:27, batch:300, cost:0.327677, acc:0.875000
    epoch:27, test\_cost:0.379122, test\_acc:0.880460
    epoch:28, batch:0, cost:0.329184, acc:0.921875
    epoch:28, batch:100, cost:0.489258, acc:0.882812
    epoch:28, batch:200, cost:0.375317, acc:0.890625
    epoch:28, batch:300, cost:0.355702, acc:0.859375
    epoch:28, test\_cost:0.377964, test\_acc:0.881066
    epoch:29, batch:0, cost:0.360147, acc:0.882812
    epoch:29, batch:100, cost:0.361545, acc:0.906250
    epoch:29, batch:200, cost:0.535644, acc:0.812500
    epoch:29, batch:300, cost:0.463827, acc:0.789062
    epoch:29, test\_cost:0.374992, test\_acc:0.879156
    epoch:30, batch:0, cost:0.386321, acc:0.843750
    epoch:30, batch:100, cost:0.450116, acc:0.851562
    epoch:30, batch:200, cost:0.380319, acc:0.867188
    epoch:30, batch:300, cost:0.357393, acc:0.914062
    epoch:30, test\_cost:0.372232, test\_acc:0.880198
    epoch:31, batch:0, cost:0.338851, acc:0.882812
    epoch:31, batch:100, cost:0.418707, acc:0.890625
    epoch:31, batch:200, cost:0.349568, acc:0.875000
    epoch:31, batch:300, cost:0.414638, acc:0.882812
    epoch:31, test\_cost:0.373127, test\_acc:0.879245
    epoch:32, batch:0, cost:0.278832, acc:0.906250
    epoch:32, batch:100, cost:0.538143, acc:0.851562
    epoch:32, batch:200, cost:0.418359, acc:0.890625
    epoch:32, batch:300, cost:0.510367, acc:0.875000
    epoch:32, test\_cost:0.370239, test\_acc:0.880896
    epoch:33, batch:0, cost:0.410598, acc:0.835938
    epoch:33, batch:100, cost:0.295002, acc:0.906250
    epoch:33, batch:200, cost:0.430560, acc:0.828125
    epoch:33, batch:300, cost:0.417476, acc:0.859375
    epoch:33, test\_cost:0.367410, test\_acc:0.881155
    epoch:34, batch:0, cost:0.337740, acc:0.937500
    epoch:34, batch:100, cost:0.304080, acc:0.906250
    epoch:34, batch:200, cost:0.359049, acc:0.890625
    epoch:34, batch:300, cost:0.373999, acc:0.890625
    epoch:34, test\_cost:0.367002, test\_acc:0.880113
    epoch:35, batch:0, cost:0.411581, acc:0.898438
    epoch:35, batch:100, cost:0.400797, acc:0.851562
    epoch:35, batch:200, cost:0.482271, acc:0.828125
    epoch:35, batch:300, cost:0.340450, acc:0.890625
    epoch:35, test\_cost:0.363663, test\_acc:0.883068
    epoch:36, batch:0, cost:0.338912, acc:0.875000
    epoch:36, batch:100, cost:0.416916, acc:0.867188
    epoch:36, batch:200, cost:0.313621, acc:0.882812
    epoch:36, batch:300, cost:0.677497, acc:0.796875
    epoch:36, test\_cost:0.361819, test\_acc:0.882983
    epoch:37, batch:0, cost:0.329249, acc:0.867188
    epoch:37, batch:100, cost:0.375915, acc:0.890625
    epoch:37, batch:200, cost:0.290267, acc:0.906250
    epoch:37, batch:300, cost:0.388264, acc:0.859375
    epoch:37, test\_cost:0.363713, test\_acc:0.880025
    epoch:38, batch:0, cost:0.452093, acc:0.875000
    epoch:38, batch:100, cost:0.237014, acc:0.898438
    epoch:38, batch:200, cost:0.334976, acc:0.898438
    epoch:38, batch:300, cost:0.386618, acc:0.875000
    epoch:38, test\_cost:0.357681, test\_acc:0.884889
    epoch:39, batch:0, cost:0.397014, acc:0.867188
    epoch:39, batch:100, cost:0.387132, acc:0.882812
    epoch:39, batch:200, cost:0.262646, acc:0.921875
    epoch:39, batch:300, cost:0.295718, acc:0.906250
    epoch:39, test\_cost:0.358814, test\_acc:0.884542
    epoch:40, batch:0, cost:0.336061, acc:0.875000
    epoch:40, batch:100, cost:0.393282, acc:0.867188
    epoch:40, batch:200, cost:0.453071, acc:0.867188
    epoch:40, batch:300, cost:0.276213, acc:0.921875
    epoch:40, test\_cost:0.355846, test\_acc:0.886278
    epoch:41, batch:0, cost:0.362588, acc:0.867188
    epoch:41, batch:100, cost:0.293396, acc:0.914062
    epoch:41, batch:200, cost:0.351766, acc:0.890625
    epoch:41, batch:300, cost:0.437711, acc:0.820312
    epoch:41, test\_cost:0.356017, test\_acc:0.886799
    epoch:42, batch:0, cost:0.431722, acc:0.843750
    epoch:42, batch:100, cost:0.296809, acc:0.914062
    epoch:42, batch:200, cost:0.300333, acc:0.898438
    epoch:42, batch:300, cost:0.392034, acc:0.859375
    epoch:42, test\_cost:0.354504, test\_acc:0.885580
    epoch:43, batch:0, cost:0.237395, acc:0.945312
    epoch:43, batch:100, cost:0.274653, acc:0.914062
    epoch:43, batch:200, cost:0.320165, acc:0.898438
    epoch:43, batch:300, cost:0.233366, acc:0.937500
    epoch:43, test\_cost:0.352862, test\_acc:0.885410
    epoch:44, batch:0, cost:0.309431, acc:0.953125
    epoch:44, batch:100, cost:0.371803, acc:0.843750
    epoch:44, batch:200, cost:0.309721, acc:0.898438
    epoch:44, batch:300, cost:0.330030, acc:0.898438
    epoch:44, test\_cost:0.348967, test\_acc:0.888017
    epoch:45, batch:0, cost:0.382172, acc:0.890625
    epoch:45, batch:100, cost:0.292855, acc:0.929688
    epoch:45, batch:200, cost:0.445127, acc:0.898438
    epoch:45, batch:300, cost:0.365554, acc:0.890625
    epoch:45, test\_cost:0.352218, test\_acc:0.883932
    epoch:46, batch:0, cost:0.424743, acc:0.898438
    epoch:46, batch:100, cost:0.382699, acc:0.859375
    epoch:46, batch:200, cost:0.319472, acc:0.914062
    epoch:46, batch:300, cost:0.414162, acc:0.859375
    epoch:46, test\_cost:0.349987, test\_acc:0.885498
    epoch:47, batch:0, cost:0.304131, acc:0.890625
    epoch:47, batch:100, cost:0.386861, acc:0.890625
    epoch:47, batch:200, cost:0.608894, acc:0.820312
    epoch:47, batch:300, cost:0.281832, acc:0.898438
    epoch:47, test\_cost:0.349286, test\_acc:0.888276
    epoch:48, batch:0, cost:0.406423, acc:0.882812
    epoch:48, batch:100, cost:0.398680, acc:0.898438
    epoch:48, batch:200, cost:0.291706, acc:0.914062
    epoch:48, batch:300, cost:0.358105, acc:0.875000
    epoch:48, test\_cost:0.348130, test\_acc:0.888361
    epoch:49, batch:0, cost:0.284720, acc:0.914062
    epoch:49, batch:100, cost:0.341173, acc:0.898438
    epoch:49, batch:200, cost:0.341595, acc:0.859375
    epoch:49, batch:300, cost:0.442754, acc:0.820312
    epoch:49, test\_cost:0.347218, test\_acc:0.886012
    epoch:50, batch:0, cost:0.311721, acc:0.906250
    epoch:50, batch:100, cost:0.326822, acc:0.875000
    epoch:50, batch:200, cost:0.331799, acc:0.898438
    epoch:50, batch:300, cost:0.426647, acc:0.851562
    epoch:50, test\_cost:0.347288, test\_acc:0.888535
    epoch:51, batch:0, cost:0.389481, acc:0.867188
    epoch:51, batch:100, cost:0.289127, acc:0.906250
    epoch:51, batch:200, cost:0.328051, acc:0.929688
    epoch:51, batch:300, cost:0.426396, acc:0.890625
    epoch:51, test\_cost:0.344246, test\_acc:0.889839
    epoch:52, batch:0, cost:0.288156, acc:0.906250
    epoch:52, batch:100, cost:0.298805, acc:0.906250
    epoch:52, batch:200, cost:0.371176, acc:0.921875
    epoch:52, batch:300, cost:0.389306, acc:0.875000
    epoch:52, test\_cost:0.345692, test\_acc:0.891224
    epoch:53, batch:0, cost:0.425932, acc:0.890625
    epoch:53, batch:100, cost:0.415528, acc:0.882812
    epoch:53, batch:200, cost:0.434767, acc:0.867188
    epoch:53, batch:300, cost:0.331441, acc:0.914062
    epoch:53, test\_cost:0.340924, test\_acc:0.890101
    epoch:54, batch:0, cost:0.260270, acc:0.906250
    epoch:54, batch:100, cost:0.305412, acc:0.898438
    epoch:54, batch:200, cost:0.330370, acc:0.906250
    epoch:54, batch:300, cost:0.334084, acc:0.898438
    epoch:54, test\_cost:0.341799, test\_acc:0.892010
    epoch:55, batch:0, cost:0.239946, acc:0.937500
    epoch:55, batch:100, cost:0.510334, acc:0.898438
    epoch:55, batch:200, cost:0.331789, acc:0.898438
    epoch:55, batch:300, cost:0.273344, acc:0.898438
    epoch:55, test\_cost:0.341348, test\_acc:0.889403
    epoch:56, batch:0, cost:0.288282, acc:0.914062
    epoch:56, batch:100, cost:0.384843, acc:0.898438
    epoch:56, batch:200, cost:0.391903, acc:0.867188
    epoch:56, batch:300, cost:0.352458, acc:0.882812
    epoch:56, test\_cost:0.338860, test\_acc:0.891054
    epoch:57, batch:0, cost:0.434810, acc:0.828125
    epoch:57, batch:100, cost:0.257800, acc:0.953125
    epoch:57, batch:200, cost:0.283473, acc:0.921875
    epoch:57, batch:300, cost:0.337173, acc:0.867188
    epoch:57, test\_cost:0.339060, test\_acc:0.891575
    epoch:58, batch:0, cost:0.240891, acc:0.898438
    epoch:58, batch:100, cost:0.390225, acc:0.875000
    epoch:58, batch:200, cost:0.393483, acc:0.843750
    epoch:58, batch:300, cost:0.289487, acc:0.890625
    epoch:58, test\_cost:0.337302, test\_acc:0.892269
    epoch:59, batch:0, cost:0.210337, acc:0.960938
    epoch:59, batch:100, cost:0.423231, acc:0.867188
    epoch:59, batch:200, cost:0.319490, acc:0.921875
    epoch:59, batch:300, cost:0.451494, acc:0.859375
    epoch:59, test\_cost:0.336483, test\_acc:0.893137
    epoch:60, batch:0, cost:0.231775, acc:0.937500
    epoch:60, batch:100, cost:0.295306, acc:0.906250
    epoch:60, batch:200, cost:0.378960, acc:0.859375
    epoch:60, batch:300, cost:0.350808, acc:0.843750
    epoch:60, test\_cost:0.335058, test\_acc:0.894267
    epoch:61, batch:0, cost:0.440865, acc:0.867188
    epoch:61, batch:100, cost:0.270725, acc:0.882812
    epoch:61, batch:200, cost:0.398181, acc:0.851562
    epoch:61, batch:300, cost:0.363882, acc:0.921875
    epoch:61, test\_cost:0.336761, test\_acc:0.892875
    epoch:62, batch:0, cost:0.321757, acc:0.898438
    epoch:62, batch:100, cost:0.330311, acc:0.890625
    epoch:62, batch:200, cost:0.406124, acc:0.851562
    epoch:62, batch:300, cost:0.275819, acc:0.898438
    epoch:62, test\_cost:0.342463, test\_acc:0.891565
    epoch:63, batch:0, cost:0.321822, acc:0.898438
    epoch:63, batch:100, cost:0.322195, acc:0.882812
    epoch:63, batch:200, cost:0.432605, acc:0.882812
    epoch:63, batch:300, cost:0.377368, acc:0.898438
    epoch:63, test\_cost:0.333785, test\_acc:0.895221
    epoch:64, batch:0, cost:0.247617, acc:0.882812
    epoch:64, batch:100, cost:0.231372, acc:0.921875
    epoch:64, batch:200, cost:0.336805, acc:0.867188
    epoch:64, batch:300, cost:0.274635, acc:0.898438
    epoch:64, test\_cost:0.332033, test\_acc:0.894179
    epoch:65, batch:0, cost:0.241076, acc:0.906250
    epoch:65, batch:100, cost:0.377462, acc:0.906250
    epoch:65, batch:200, cost:0.297226, acc:0.882812
    epoch:65, batch:300, cost:0.440397, acc:0.867188
    epoch:65, test\_cost:0.330794, test\_acc:0.897045
    epoch:66, batch:0, cost:0.266126, acc:0.898438
    epoch:66, batch:100, cost:0.390715, acc:0.859375
    epoch:66, batch:200, cost:0.292437, acc:0.914062
    epoch:66, batch:300, cost:0.395078, acc:0.867188
    epoch:66, test\_cost:0.330902, test\_acc:0.895221
    epoch:67, batch:0, cost:0.301438, acc:0.929688
    epoch:67, batch:100, cost:0.388324, acc:0.898438
    epoch:67, batch:200, cost:0.439915, acc:0.890625
    epoch:67, batch:300, cost:0.310547, acc:0.867188
    epoch:67, test\_cost:0.330386, test\_acc:0.896521
    epoch:68, batch:0, cost:0.243119, acc:0.929688
    epoch:68, batch:100, cost:0.447522, acc:0.875000
    epoch:68, batch:200, cost:0.470691, acc:0.882812
    epoch:68, batch:300, cost:0.296465, acc:0.882812
    epoch:68, test\_cost:0.326098, test\_acc:0.896266
    epoch:69, batch:0, cost:0.260604, acc:0.898438
    epoch:69, batch:100, cost:0.417193, acc:0.882812
    epoch:69, batch:200, cost:0.483119, acc:0.835938
    epoch:69, batch:300, cost:0.405713, acc:0.875000
    epoch:69, test\_cost:0.328661, test\_acc:0.896957
    epoch:70, batch:0, cost:0.300975, acc:0.882812
    epoch:70, batch:100, cost:0.199427, acc:0.945312
    epoch:70, batch:200, cost:0.207260, acc:0.937500
    epoch:70, batch:300, cost:0.199148, acc:0.914062
    epoch:70, test\_cost:0.327545, test\_acc:0.894958
    epoch:71, batch:0, cost:0.281955, acc:0.914062
    epoch:71, batch:100, cost:0.267508, acc:0.914062
    epoch:71, batch:200, cost:0.561389, acc:0.828125
    epoch:71, batch:300, cost:0.377676, acc:0.867188
    epoch:71, test\_cost:0.325637, test\_acc:0.897740
    epoch:72, batch:0, cost:0.348661, acc:0.890625
    epoch:72, batch:100, cost:0.346154, acc:0.898438
    epoch:72, batch:200, cost:0.447819, acc:0.867188
    epoch:72, batch:300, cost:0.342514, acc:0.929688
    epoch:72, test\_cost:0.325294, test\_acc:0.897304
    epoch:73, batch:0, cost:0.223638, acc:0.929688
    epoch:73, batch:100, cost:0.394560, acc:0.859375
    epoch:73, batch:200, cost:0.341260, acc:0.890625
    epoch:73, batch:300, cost:0.283185, acc:0.898438
    epoch:73, test\_cost:0.326340, test\_acc:0.895394
    epoch:74, batch:0, cost:0.371942, acc:0.921875
    epoch:74, batch:100, cost:0.333636, acc:0.882812
    epoch:74, batch:200, cost:0.397030, acc:0.875000
    epoch:74, batch:300, cost:0.392802, acc:0.875000
    epoch:74, test\_cost:0.322571, test\_acc:0.896089
    epoch:75, batch:0, cost:0.275930, acc:0.921875
    epoch:75, batch:100, cost:0.263152, acc:0.914062
    epoch:75, batch:200, cost:0.296550, acc:0.898438
    epoch:75, batch:300, cost:0.402121, acc:0.898438
    epoch:75, test\_cost:0.320611, test\_acc:0.897134
    epoch:76, batch:0, cost:0.279775, acc:0.921875
    epoch:76, batch:100, cost:0.439274, acc:0.843750
    epoch:76, batch:200, cost:0.330266, acc:0.898438
    epoch:76, batch:300, cost:0.418308, acc:0.851562
    epoch:76, test\_cost:0.320242, test\_acc:0.900429
    epoch:77, batch:0, cost:0.320668, acc:0.890625
    epoch:77, batch:100, cost:0.168939, acc:0.960938
    epoch:77, batch:200, cost:0.244379, acc:0.953125
    epoch:77, batch:300, cost:0.621534, acc:0.875000
    epoch:77, test\_cost:0.319756, test\_acc:0.900865
    epoch:78, batch:0, cost:0.284392, acc:0.914062
    epoch:78, batch:100, cost:0.309243, acc:0.890625
    epoch:78, batch:200, cost:0.273962, acc:0.945312
    epoch:78, batch:300, cost:0.311928, acc:0.906250
    epoch:78, test\_cost:0.318491, test\_acc:0.901818
    epoch:79, batch:0, cost:0.242170, acc:0.898438
    epoch:79, batch:100, cost:0.315753, acc:0.875000
    epoch:79, batch:200, cost:0.252874, acc:0.937500
    epoch:79, batch:300, cost:0.447730, acc:0.812500
    epoch:79, test\_cost:0.318828, test\_acc:0.900603
    模型保存成功.

    【推理预测】

    model_save_dir = "model/" # directory the inference model is loaded from
    
    def get_data(sentence, dict_path=None):
        """Encode a sentence with the codes stored in the dictionary file.

        :param sentence: text to encode, processed character by character
        :param dict_path: dictionary file to read; defaults to the module-level
            ``dict_file_path``
        :return: list of integer codes, one per character; characters missing
            from the dictionary are mapped to the code of ``"<unk>"``
        :raises KeyError: if a character is unmapped and ``"<unk>"`` is missing
        """
        import ast

        if dict_path is None:
            dict_path = dict_file_path
        with open(dict_path, "r", encoding="utf-8") as f:
            # literal_eval only accepts Python literals, unlike eval().
            dict_txt = ast.literal_eval(f.readline())

        # Membership is tested on the dict itself; "<unk>" is only looked up
        # when an unmapped character is actually encountered.
        return [int(dict_txt[w if w in dict_txt else "<unk>"]) for w in sentence]
    
    # Executor (inference runs on the CPU)
    place = fluid.CPUPlace()
    exe = fluid.Executor(place)
    exe.run(fluid.default_startup_program())

    # Load the saved inference program together with its feed names and
    # fetch targets.
    infer_program, feed_names, target_var = \
    fluid.io.load_inference_model(model_save_dir, exe)

    # Sentences to classify
    sentences = [
        "在获得诺贝尔文学奖7年之后,莫言15日晚间在山西汾阳贾家庄如是说",
        "综合'今日美国'、《世界日报》等当地媒体报道,芝加哥河滨警察局表示",
        "中国队2022年冬奥会表现优秀",
        "中国人民银行今日发布通知,降低准备金率,预计释放4000亿流动性",
        "10月20日,第六届世界互联网大会正式开幕",
        "同一户型,为什么高层比低层要贵那么多?",
        "揭秘A股周涨5%资金动向:追捧2类股,抛售600亿香饽饽",
        "宋慧乔陷入感染危机,前夫宋仲基不戴口罩露面,身处国外神态轻松",
        "此盆栽花很好养,花美似牡丹,三季开花,南北都能养,很值得栽培",  # belongs to none of the classes
    ]
    texts = [get_data(sentence) for sentence in sentences]  # encoded sentences

    base_shape = [[len(c) for c in texts]]  # length of every sentence
    tensor_words = fluid.create_lod_tensor(texts, base_shape, place)
    result = exe.run(infer_program,
                     feed={feed_names[0]: tensor_words},
                     fetch_list=target_var)
    # NOTE(review): assumes position i is the class name for label id i in the
    # training data -- verify against the dataset.
    names = ["文化", "娱乐", "体育", "财经", "房产","汽车", "教育", "科技", "国际", "证券"]
    for r in result[0]:
        idx = np.argmax(r)  # index of the highest probability
        print("预测结果:", names[idx], " 概率:", r[idx])

    输出

    预测结果: 财经  概率: 0.81440145
    预测结果: 娱乐  概率: 1.0
    预测结果: 财经  概率: 1.0
    预测结果: 汽车  概率: 0.9996093
    预测结果: 文化  概率: 0.9404757
    预测结果: 娱乐  概率: 0.8715788
    预测结果: 房产  概率: 0.9625704
    预测结果: 科技  概率: 0.985617
    预测结果: 房产  概率: 1.0

    文章涉及到的数据资源链接如下:news_classify_data.zip(蓝奏云,文件大小:2.7 MB):https://wwt.lanzoum.com/iTF6N1q1rnti

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐