跳到主要内容

模型蒸馏

模型蒸馏(Knowledge Distillation, KD)是一种模型压缩和能力迁移方法。它用一个能力更强的教师模型(Teacher Model)指导一个更小、更快、更便宜的学生模型(Student Model)训练,让学生模型在较低推理成本下尽量保留教师模型的能力。

在大语言模型场景中,蒸馏常用于:

  • 降低推理成本和显存占用。
  • 把大模型能力迁移到小模型、端侧模型或专用模型。
  • 为特定任务训练更快的专用模型,例如客服、代码、数学、检索问答。
  • 配合量化、剪枝、稀疏化等方法进一步压缩部署成本。

1. 核心思想

传统训练通常使用真实标签,也就是 hard target。
例如分类任务中,正确类别是 1,其他类别是 0。模型只知道“正确答案是什么”,但不知道其他错误答案之间的相似程度。

蒸馏训练会额外使用教师模型输出的 soft target。
soft target 是教师模型对所有候选类别或 token 的概率分布,它包含更多信息:

  • 哪些错误答案更接近正确答案。
  • 哪些 token 或类别之间更容易混淆。
  • 教师模型在不同候选输出上的偏好。
  • 教师模型对样本难度和不确定性的判断。

因此,学生模型不是只学习最终答案,而是在学习教师模型的输出分布、推理偏好和中间表示。


2. Hard Target 与 Soft Target

训练方式监督信号学到的信息典型目标
传统训练真实标签 hard target正确答案本身最大化 ground truth 概率
蒸馏训练教师模型提供的 soft target教师模型的输出偏好、类别关系和中间表示对齐教师模型的行为或表示

这里的 soft target 可以理解为“教师模型提供的更丰富监督信号”,不只限于最终概率分布。在 LLM 中,它可以来自:

  • 教师模型输出的 logits。
  • 教师模型经过 softmax 后的概率分布。
  • 教师模型生成的高质量回答。
  • 教师模型的中间层 hidden states 或 attention patterns。

3. Temperature 的作用

蒸馏通常会在 softmax 中加入温度参数 T:

  • qi=exp(zi/T)jexp(zj/T)q_i = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}

其中:

  • ziz_i 是第 i 个类别或 token 的 logit。
  • TT 是 temperature。
  • qiq_i 是经过温度缩放后的输出概率。

Temperature 会改变概率分布的平滑程度:

温度效果含义
T=1T = 1原始 softmax不做额外平滑
T>1T > 1分布更平滑放大负标签中的相对信息
T<1T < 1分布更陡峭更强调最高概率类别
TT \to \infty接近均匀分布类别差异被大幅削弱
T0T \to 0接近 argmax退化为只关注最高概率

蒸馏里常用较高的 T,是因为教师模型对负标签的排序也有价值。
例如同样不是正确答案,概率 0.18 的类别通常比概率 0.001 的类别更接近正确语义。学生模型学习这种“相似性结构”,往往比只学习 hard target 更有效。


4. 常见训练目标

通用蒸馏目标通常由两部分组成:

损失作用
Distillation Loss让学生模型输出接近教师模型 soft target
Student Loss让学生模型仍然学习真实标签 hard target

常见写法是把两种损失加权:

total_loss = α * distillation_loss + β * student_loss

常用指标包括:

  • KL 散度:对齐教师和学生的概率分布。
  • MSE:对齐 logits、hidden states 或 attention matrices。
  • Cross Entropy:继续使用真实标签训练学生模型。

如果直接 match logits,而不经过 softmax,可以把目标理解为让学生模型直接拟合教师模型的原始输出。直接 match logits 可以看作在 TT \to \infty 情况下的一种特殊近似。


5. 典型方法

5.1 基于输出的蒸馏

基于输出的蒸馏只关注教师模型和学生模型的最终输出。

常见做法:

  • 用教师模型生成 soft labels。
  • 用 KL 散度对齐教师和学生的输出概率分布。
  • 对 LLM 来说,也可以用教师模型生成高质量回答,再用这些回答训练学生模型。

适合场景:

  • 学生模型结构和教师模型不同。
  • 只能访问教师模型 API,不能访问中间层。
  • 想快速蒸馏特定任务能力。

5.2 基于中间层特征的蒸馏

这类方法会对齐教师模型和学生模型的内部表示。

常见对齐对象:

  • Embedding。
  • Hidden states。
  • Attention scores。
  • Value relation。
  • 最后一层或多层中间表示。

代表方法:

  • TinyBERT:在预训练和任务微调阶段对齐 embedding、attention、hidden states 和 prediction layer。
  • MiniLM:重点蒸馏最后一层自注意力关系和值关系,减少对完整教师结构的依赖。

适合场景:

  • 可以访问教师模型内部结构。
  • 学生模型和教师模型结构相近,或可以设计层映射关系。
  • 希望学生模型不仅模仿输出,还模仿推理表示。

5.3 分阶段蒸馏

分阶段蒸馏把蒸馏拆成多个训练阶段。

常见阶段:

  • 预训练阶段蒸馏:先让学生模型学习通用语言表示。
  • 任务微调阶段蒸馏:再让学生模型学习特定任务表现。
  • 指令蒸馏:用教师模型生成指令数据,训练学生模型的对话和任务执行能力。

适合场景:

  • 需要兼顾通用能力和任务能力。
  • 数据量较大,训练周期较长。
  • 想训练一个可复用的小模型基座。

5.4 在线蒸馏与自蒸馏

在线蒸馏(Online Distillation)中,教师模型和学生模型可以同时训练,甚至多个模型互相学习。

自蒸馏(Self-Distillation)中,模型从自身的不同层、不同阶段或历史 checkpoint 中学习。

适合场景:

  • 没有固定教师模型。
  • 想提升模型训练稳定性。
  • 希望在不引入额外大模型的情况下提升小模型表现。

6. LLM 蒸馏中的关键问题

6.1 数据选择

蒸馏效果高度依赖数据质量。常见数据来源包括:

  • 原始训练数据。
  • 教师模型生成的合成数据。
  • 人工标注数据。
  • 高质量指令数据。
  • 领域任务数据,例如代码、数学、医疗、客服。

如果教师模型生成的数据有幻觉、偏见或格式问题,学生模型也会继承这些问题。

6.2 教师模型与学生模型差距

教师太强、学生太小,学生可能无法吸收复杂能力。常见处理方式:

  • 选择容量差距适中的教师和学生。
  • 使用中间教师模型逐级蒸馏。
  • 降低任务难度或只蒸馏部分能力。
  • 针对特定任务做专用蒸馏。

6.3 层映射

如果对齐中间层,需要设计教师层和学生层的对应关系。例如:

  • 每隔几层对齐一次。
  • 只对齐最后几层。
  • 按层数比例映射。
  • 用投影层处理 hidden size 不一致的问题。

层映射设计不当,可能会让学生模型学习到不合适的表示。

6.4 能力继承与风险继承

蒸馏不仅会迁移能力,也可能迁移问题:

  • 教师模型的幻觉。
  • 教师模型的偏见。
  • 教师模型的安全绕过模式。
  • 教师模型在特定任务上的错误习惯。

因此蒸馏后仍然需要评测、安全测试和任务回归测试。


7. 经典案例

模型 / 方法主要做法特点
DistilBERT使用 soft target 和中间层表示蒸馏 BERT参数量更小,推理更快
TinyBERT预训练和微调两阶段蒸馏对齐 embedding、attention、hidden states 和 prediction layer
MiniLM蒸馏 self-attention 关系和值关系对教师模型结构依赖较低
MobileBERT通过瓶颈结构和逐层蒸馏压缩 BERT面向移动端推理
指令蒸馏用强模型生成指令数据训练小模型常见于 LLM 对话能力迁移

8. 与其他压缩方法的关系

蒸馏可以单独使用,也可以和其他压缩方法组合。

方法作用和蒸馏的关系
量化降低权重和激活精度常在蒸馏后用于部署压缩
剪枝删除不重要的参数或结构可先剪枝再蒸馏恢复性能
稀疏化让模型权重或激活更稀疏可和蒸馏共同降低推理成本
LoRA / SFT低成本任务适配可用于蒸馏后的任务微调

9. 总结

模型蒸馏的本质是:用更强模型提供更丰富的监督信号,让更小模型学习到超过 hard target 的信息。它适合降低部署成本、构建专用小模型和迁移大模型能力。

实际使用时,需要重点关注三件事:

  • 教师模型是否可靠。
  • 蒸馏数据是否高质量。
  • 学生模型容量是否足够承接目标能力。

对 LLM 来说,蒸馏已经不只是分类概率对齐,还包括指令数据生成、中间表示对齐、推理轨迹模仿和任务行为迁移。它通常会和量化、微调、评测一起组成完整的小模型部署流程。


参考