腾讯开源分类项目码源阅读(NeuralNLP-NeuralClassifier优点与缺点)
NeuralNLP-NeuralClassifier-master1. 所有用超参数用json文件保存2. 训练结束后设置学习率lr=0, 这样就不用设置is_train这个参数了def update_lr(self, optimizer, epoch):if epoch > self.config.train.n...
·
NeuralNLP-NeuralClassifier-master
1. 所有用超参数用json文件保存
2. 训练结束后设置学习率lr=0, 这样就不用设置is_train这个参数了
def update_lr(self, optimizer, epoch):
if epoch > self.config.train.num_epochs_static_embedding:
for param_group in optimizer.param_groups[:2]:
param_group["lr"] = self.config.optimizer.learning_rate
else:
for param_group in optimizer.param_groups[:2]:
param_group["lr"] = 0 # 结束设置为0
3. dataloader中的sample选择有RandomSampler(随机选择), SequentialSampler(顺序选择)
# 1.sample
if batch_sampler is None:
if sampler is None:
if shuffle:
sampler = RandomSampler(dataset)
else:
sampler = SequentialSampler(dataset)
batch_sampler = BatchSampler(sampler, batch_size, drop_last)
# 2.yield
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
4. graph层是一个组件(相当于embedding), 不同于传统代码中的每个模型后边加loss层, loss是在ClassificationTrainer里边
total_loss = 0.
for batch in data_loader:
logits = model(batch)
# hierarchical classification
if self.conf.task_info.hierarchical:
linear_paras = model.linear.weight
is_hierar = True
# 层次化类别惩罚
used_argvs = (self.conf.task_info.hierar_penalty, linear_paras, self.hierar_relations)
loss = self.loss_fn(
logits,
batch[ClassificationDataset.DOC_LABEL].to(self.conf.device),
is_hierar,
is_multi,
*used_argvs)
5. 使用BCEWithLogitsLoss(该loss 层包括了 Sigmoid 层和BCELoss 层. 单类别任务.数值计算稳定性更好(log-sum-exp trik), 相比于Sigmoid +BCELoss.), 而不是BCELoss(计算target 和output 间的二值交叉熵(Binary Cross Entropy))
6. pytorch带的focal_loss(焦点损失, topk更新差异较大的loss而不是全部),tf和keras默认中没有, OHEM(Online Hard Example Mining)
7. 支持多标签分类(multi-class)和层次多标签分类(hierarchical-multi-class),multi-class只支持转化为多类分类的,hierarchical-multi-class只支持hierarchical text classification with BCELoss
8. embedding只有PositionEmbedding和RegionEmbedding(word-context, context-word),没有word2vec、bert等
9. 很常见的macro_average和micro_average的precision, recall, f_score(默认);即所有语料的平均,或者是每个类的平均,默认类平均micro_average
魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。
更多推荐



所有评论(0)