NeuralNLP-NeuralClassifier-master
  1. 所有用超参数用json文件保存
      

  2. 训练结束后设置学习率lr=0, 这样就不用设置is_train这个参数了
      
    def update_lr(self, optimizer, epoch):
        if epoch > self.config.train.num_epochs_static_embedding:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = self.config.optimizer.learning_rate
        else:
            for param_group in optimizer.param_groups[:2]:
                param_group["lr"] = 0 # 结束设置为0

  3. dataloader中的sample选择有RandomSampler(随机选择), SequentialSampler(顺序选择)
     
         # 1.sample
         if batch_sampler is None:
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(dataset)
                else:
                    sampler = SequentialSampler(dataset)
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
 
        # 2.yield
        def __iter__(self):
            batch = []
            for idx in self.sampler:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            if len(batch) > 0 and not self.drop_last:
                yield batch

  4. graph层是一个组件(相当于embedding), 不同于传统代码中的每个模型后边加loss层, loss是在ClassificationTrainer里边

  
        total_loss = 0.
        for batch in data_loader:
            logits = model(batch)
            # hierarchical classification
            if self.conf.task_info.hierarchical:
                linear_paras = model.linear.weight
                is_hierar = True
                # 层次化类别惩罚
                used_argvs = (self.conf.task_info.hierar_penalty, linear_paras, self.hierar_relations)
                loss = self.loss_fn(
                    logits,
                    batch[ClassificationDataset.DOC_LABEL].to(self.conf.device),
                    is_hierar,
                    is_multi,
                    *used_argvs)
  5. 使用BCEWithLogitsLoss(该loss 层包括了 Sigmoid 层和BCELoss 层. 单类别任务.数值计算稳定性更好(log-sum-exp trik), 相比于Sigmoid +BCELoss.), 而不是BCELoss(计算target 和output 间的二值交叉熵(Binary Cross Entropy))

  6. pytorch带的focal_loss(焦点损失, topk更新差异较大的loss而不是全部),tf和keras默认中没有, OHEM(Online Hard Example Mining)

  7. 支持多标签分类(multi-class)和层次多标签分类(hierarchical-multi-class),multi-class只支持转化为多类分类的,hierarchical-multi-class只支持hierarchical text classification with BCELoss

  8. embedding只有PositionEmbedding和RegionEmbedding(word-context, context-word),没有word2vec、bert等

  9. 很常见的macro_average和micro_average的precision, recall, f_score(默认);即所有语料的平均,或者是每个类的平均,默认类平均micro_average
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐