使用SWIFT微调大模型完成回归任务

传统大语言模型通过预测下一个token的概率来生成文本,其损失函数通常采用softmax分类器,这使得它们难以直接处理回归任务(需要输出连续数值)。然而,大模型具备强大的高维特征编码能力,我们可以利用预训练模型作为特征提取器,通过微调适配回归任务,从而获得优异的性能表现。

SWIFT框架对大模型任务提供了全面支持,最新版本已专门增加了回归任务功能。本文将详细介绍SWIFT框架实现回归任务的技术方案。

SWIFT执行回归命令

执行回归任务的命令示例如下:

CUDA_VISIBLE_DEVICES=0 \
swift sft \
    --model Qwen/Qwen2.5-0.5B \
    --train_type lora \
    --dataset 'sentence-transformers/stsb:reg#20000' \
    --torch_dtype bfloat16 \
    --num_train_epochs 1 \
    --per_device_train_batch_size 16 \
    --per_device_eval_batch_size 16 \
    --learning_rate 1e-4 \
    --lora_rank 8 \
    --lora_alpha 32 \
    --target_modules all-linear \
    --gradient_accumulation_steps 1 \
    --eval_steps 100 \
    --save_steps 100 \
    --save_total_limit 2 \
    --logging_steps 5 \
    --max_length 2048 \
    --output_dir output \
    --warmup_ratio 0.05 \
    --dataloader_num_workers 4 \
    --dataset_num_proc 4 \
    --num_labels 1 \
    --task_type seq_cls \
    --use_chat_template false \
    --problem_type regression
关键参数说明

必须指定的核心参数:

--num_labels 1 \
--task_type seq_cls 

可选参数(框架会自动推断):

--problem_type regression

本示例使用大模型计算句子相似度,输出为0-1之间的连续相似度分数。

SWIFT框架回归任务的实现机制

损失函数设计

框架通过swift.trainers.Trainer类调用以下损失函数:
损失函数示意图

def ForSequenceClassificationLoss(labels: torch.Tensor, pooled_logits: torch.Tensor, config, **kwargs) -> torch.Tensor:
    num_labels = config.num_labels
    if config.problem_type is None:
        if num_labels == 1:
            config.problem_type = "regression"
        elif num_labels > 1 and (labels.dtype in (torch.long, torch.int)):
            config.problem_type = "single_label_classification"
        else:
            config.problem_type = "multi_label_classification"

    labels = labels.to(pooled_logits.device)
    if config.problem_type == "regression":
        loss_fct = MSELoss()
        if num_labels == 1:
            return loss_fct(pooled_logits.squeeze(), labels.squeeze())
        else:
            return loss_fct(pooled_logits, labels)
    if config.problem_type == "single_label_classification":
        return fixed_cross_entropy(pooled_logits.view(-1, num_labels), labels.view(-1), **kwargs)

    if config.problem_type == "multi_label_classification":
        loss_fct = BCEWithLogitsLoss()
        return loss_fct(pooled_logits, labels)

    raise RuntimeError(f"Invalid problem type: {config.problem_type}")

关键实现细节:
• 当num_labels=1时自动设置为回归任务

• 使用MSE损失函数(均方误差)

• 通过squeeze()操作确保张量维度匹配

模型架构解析
运行日志显示模型结构如下:

[INFO:swift] model: PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): Qwen2ForSequenceClassification(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 896)
        (layers): ModuleList(
          (0-23): 24 x Qwen2DecoderLayer(
            (self_attn): Qwen2Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=896, out_features=896, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=896, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear(
                (base_layer): Linear(in_features=896, out_features=128, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=128, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (v_proj): lora.Linear(
                (base_layer): Linear(in_features=896, out_features=128, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=128, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (o_proj): lora.Linear(
                (base_layer): Linear(in_features=896, out_features=896, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=896, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
            )
            (mlp): Qwen2MLP(
              (gate_proj): lora.Linear(
                (base_layer): Linear(in_features=896, out_features=4864, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4864, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (up_proj): lora.Linear(
                (base_layer): Linear(in_features=896, out_features=4864, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=4864, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (down_proj): lora.Linear(
                (base_layer): Linear(in_features=4864, out_features=896, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4864, out_features=8, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=8, out_features=896, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (act_fn): SiLU()
            )
            (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
            (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
          )
        )
        (norm): Qwen2RMSNorm((896,), eps=1e-06)
        (rotary_emb): Qwen2RotaryEmbedding()
      )
      (score): ModulesToSaveWrapper(
        (original_module): Linear(in_features=896, out_features=1, bias=False)
        (modules_to_save): ModuleDict(
          (default): Linear(in_features=896, out_features=1, bias=False)
        )
      )
    )
  )
)

整个模型的架构类似如下:
Qwen2架构图
我觉得可能跟Llama2架构更像,但是还是有不一样的(有时间进行修改更新):
Llama2架构图

模型架构深度解析

核心组件

1. 基座模型Qwen2
• 类型:Decoder-only架构(24层堆叠)

• 关键参数:

 ◦ 词嵌入:151,936词汇表×896维度

 ◦ 注意力机制:7个注意力头,键/值维度128

 ◦ MLP扩展比:5.44倍(896→4864)

 ◦ 归一化:Qwen2RMSNorm(ε=1e-6)

 ◦ 位置编码:RoPE旋转位置编码

2. LoRA适配器
• 适配位置:

 ◦ 注意力层:q/k/v/o_proj

 ◦ MLP层:gate/up/down_proj

• 参数配置:

 ◦ 秩r=8

 ◦ 低秩分解公式:W' = W + B·A

 ◦ Dropout率:0.05

*3. 回归输出层
• 结构:Linear(896→1)

• 训练方式:取最后一个token的隐藏状态(1,896),使用RMS损失函数计算回归结果。

举一个例子(以下AI生成,大致看了一下修改了一下)

假设输入三个TOKEN

输入序列:[token₁, token₂, token₃]
目标输出:数值预测值 y ∈ ℝ
  1. 输入嵌入层(Embedding)
    • 输入矩阵:X ∈ ℕ³ (三个token的ID序列)

• 操作:Embedding(151936, 896)

• 输出矩阵:E ∈ ℝ³×896

• 计算过程:

E = embed_layer(X)  
# 将每个token映射为896维向量
  1. Decoder层处理(24层迭代)
    每层包含以下关键操作:

    2.1 自注意力机制
    • 输入矩阵:H_in ∈ ℝ³×896

    • 线性变换(带LoRA):

    # 原始权重 W_q ∈ ℝ^{896×896} 
    # LoRA分解:ΔW = B_q @ A_q,其中 A_q ∈ ℝ^{896×8}, B_q ∈ ℝ^{8×896}
    Q = (H_in @ W_q) + (H_in @ A_q.T @ B_q.T)  # ∈ ℝ³×896
    
    # 类似处理K ∈ ℝ³×128, V ∈ ℝ³×128
    K = (H_in @ W_k) + (H_in @ A_k.T @ B_k.T)  
    V = (H_in @ W_v) + (H_in @ A_v.T @ B_v.T)
    

    • 注意力计算:

    attn_scores = Q @ K.T / √d_k  # d_k=128 → ∈ ℝ³×³
    attn_weights = softmax(attn_scores)  # ∈ ℝ³×³
    attn_output = attn_weights @ V  # ∈ ℝ³×128
    

    2.2 注意力输出投影(带LoRA)

    # 原始权重 W_o ∈ ℝ^{128×896}
    # LoRA分解:ΔW = B_o @ A_o,A_o ∈ ℝ^{128×8}, B_o ∈ ℝ^{8×896}
    attn_proj = (attn_output @ W_o) + (attn_output @ A_o.T @ B_o.T)  # ∈ ℝ³×896
    

    2.3 残差连接

    H_mid = H_in + attn_proj  # ∈ ℝ³×896
    

    2.4 MLP层(带LoRA)

    # 门控投影(gate_proj):
    gate = (H_mid @ W_gate) + (H_mid @ A_gate.T @ B_gate.T)  # ∈ ℝ³×4864
    
    # 上投影(up_proj): 
    up = (H_mid @ W_up) + (H_mid @ A_up.T @ B_up.T)  # ∈ ℝ³×4864
    
    # 激活与门控:
    activated = SiLU(gate) * up  # ∈ ℝ³×4864
    
    # 下投影(down_proj):
    down = (activated @ W_down) + (activated @ A_down.T @ B_down.T)  # ∈ ℝ³×896
    

    2.5 最终输出

    H_out = H_mid + down  # ∈ ℝ³×896
    
  2. 回归头计算
    • 取最后一个token的表示:

    last_hidden = H_out[-1]  # ∈ ℝ^896
    

    • 回归投影:

    y_pred = last_hidden @ W_reg  # W_reg ∈ ℝ^{896×1} → y_pred ∈ ℝ
    

关键矩阵维度表

操作阶段 核心矩阵 维度说明
输入嵌入 E (3, 896)
自注意力Q/K/V Q (3, 896)
K (3, 128)
V (3, 128)
注意力分数矩阵 attn_scores (3, 3)
注意力输出 attn_output (3, 128)
MLP门控投影 gate (3, 4864)
最终隐藏状态 H_out (每层输出) (3, 896)
回归权重 W_reg (896, 1)

计算过程可视化

输入序列: [tok₁, tok₂, tok₃]
    │
    ▼
嵌入层: (3,896)
    │
    ▼
[Decoder×24]: (3,896) → ... → (3,896)
    │
    ▼
取最后token: (1,896)
    │
    ▼
回归投影: (896,1) → y ∈ ℝ
Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐