1. A roundup of distillation schemes


2. Distillation tool

https://github.com/airaria/TextBrewer#quickstart

2.1 Distillation steps:


2.2 Usage (it looks fairly simple):
```python
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig

# Assumed to be prepared beforehand (not shown here): teacher_model, student_model,
# optimizer, dataloader, scheduler_class and scheduler_args.

# Show the statistics of model parameters
print("\nteacher_model's parameters:")
result, _ = textbrewer.utils.display_parameters(teacher_model, max_level=3)
print(result)
print("student_model's parameters:")
result, _ = textbrewer.utils.display_parameters(student_model, max_level=3)
print(result)

# Define an adaptor for interpreting the model inputs and outputs
def simple_adaptor(batch, model_outputs):
    # The second and third elements of model outputs are the logits and hidden states
    # (this depends on how your model orders its outputs)
    return {'logits': model_outputs[1],
            'hidden': model_outputs[2]}

# Training configuration
train_config = TrainingConfig()
# Distillation configuration
# Matching different layers of the student and the teacher:
# teacher layer 0 -> student layer 0, teacher layer 8 -> student layer 2
distill_config = DistillationConfig(
    intermediate_matches=[
        {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
        {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1}])

# Build distiller
distiller = GeneralDistiller(
    train_config=train_config, distill_config=distill_config,
    model_T=teacher_model, model_S=student_model,
    adaptor_T=simple_adaptor, adaptor_S=simple_adaptor)

# Start!
with distiller:
    distiller.train(optimizer, dataloader, num_epochs=1,
                    scheduler_class=scheduler_class, scheduler_args=scheduler_args,
                    callback=None)
```
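TextBrewer computes the losses internally from what the adaptors return. To make the configuration above more concrete, here is a from-scratch, pure-Python sketch (not TextBrewer's code) of the two kinds of terms it asks for: a soft-label cross-entropy between temperature-scaled teacher and student logits (the standard prediction-layer distillation loss), plus one weighted mean-squared-error term per entry in `intermediate_matches`. The helper names (`log_softmax`, `kd_ce_loss`, `hidden_mse`, `distillation_loss`), the toy model outputs, and the temperature and KD weight values are hypothetical; TextBrewer works on PyTorch tensors and its own defaults and weighting may differ. The sketch also assumes teacher and student hidden sizes match and only handles `'hidden_mse'`.

```python
import math

def simple_adaptor(batch, model_outputs):
    # Same layout assumption as above: outputs = (loss, logits, hidden_states)
    return {'logits': model_outputs[1], 'hidden': model_outputs[2]}

def log_softmax(xs, temperature=1.0):
    """Numerically stable log-softmax of one logit vector divided by temperature."""
    scaled = [x / temperature for x in xs]
    m = max(scaled)
    log_z = m + math.log(sum(math.exp(x - m) for x in scaled))
    return [x - log_z for x in scaled]

def kd_ce_loss(logits_S, logits_T, temperature=1.0):
    """Soft-label cross-entropy -sum(p_T * log p_S), averaged over examples."""
    total = 0.0
    for s, t in zip(logits_S, logits_T):
        p_T = [math.exp(v) for v in log_softmax(t, temperature)]
        log_p_S = log_softmax(s, temperature)
        total += -sum(pt * lps for pt, lps in zip(p_T, log_p_S))
    return total / len(logits_S)

def hidden_mse(h_S, h_T):
    """Mean squared error over all token vectors and dimensions of one layer."""
    diffs = [(a - b) ** 2
             for vec_S, vec_T in zip(h_S, h_T)
             for a, b in zip(vec_S, vec_T)]
    return sum(diffs) / len(diffs)

def distillation_loss(out_T, out_S, matches, temperature=1.0, kd_weight=1.0):
    """KD loss on logits plus weighted intermediate-layer losses (hypothetical helper)."""
    loss = kd_weight * kd_ce_loss(out_S['logits'], out_T['logits'], temperature)
    for m in matches:
        if m['loss'] != 'hidden_mse':
            raise ValueError("this sketch only implements 'hidden_mse'")
        h_T = out_T[m['feature']][m['layer_T']]  # teacher layer to imitate
        h_S = out_S[m['feature']][m['layer_S']]  # matched student layer
        loss += m['weight'] * hidden_mse(h_S, h_T)
    return loss

# Toy outputs: one example, two classes, one token vector of size 2 per layer.
# The teacher has hidden states 0..8, the student has hidden states 0..2.
teacher_outputs = (None, [[2.0, 0.0]], [[[float(i), 0.0]] for i in range(9)])
student_outputs = (None, [[1.0, 0.0]], [[[0.0, 0.0]], [[1.0, 1.0]], [[7.0, 0.0]]])
matches = [
    {'layer_T': 0, 'layer_S': 0, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
    {'layer_T': 8, 'layer_S': 2, 'feature': 'hidden', 'loss': 'hidden_mse', 'weight': 1},
]
out_T = simple_adaptor(None, teacher_outputs)
out_S = simple_adaptor(None, student_outputs)
loss = distillation_loss(out_T, out_S, matches)
print(f"{loss:.4f}")  # prints 0.9325
```

Here the logits term contributes about 0.4325, the 0-to-0 match contributes nothing because those layers are identical, and the 8-to-2 match contributes 0.5. Note that the cross-entropy term never reaches zero even when the student copies the teacher exactly: its minimum is the entropy of the teacher's softened distribution.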

3. References

A Complete Guide to BERT Distillation | Principles / Tricks / Code: https://mp.weixin.qq.com/s/p0EZ4uFrLBLUuRRNMezGiQ
