模型蒸馏
模型蒸馏(Knowledge Distillation, KD)是一种模型压缩和能力迁移方法。它用一个能力更强的教师模型(Teacher Model)指导一个更小、更快、更便宜的学生模型(Student Model)训练,让学生模型在较低推理成本下尽量保留教师模型的能力。
在大语言模型场景中,蒸馏常用于:
- 降低推理成本和显存占用。
- 把大模型能力迁移到小模型、端侧模型或专用模型。
- 为特定任务训练更快的专用模型,例如客服、代码、数学、检索问答。
- 配合量化、剪枝、稀疏化等方法进一步压缩部署成本。
1. 核心思想
传统训练通常使用真实标签,也就是 hard target。
例如分类任务中,正确类别是 1,其他类别是 0。模型只知道“正确答案是什么”,但不知道其他错误答案之间的相似程度。
蒸馏训练会额外使用教师模型输出的 soft target。
soft target 是教师模型对所有候选类别或 token 的概率分布,它包含更多信息:
- 哪些错误答案更接近正确答案。
- 哪些 token 或类别之间更容易混淆。
- 教师模型在不同候选输出上的偏好。
- 教师模型对样本难度和不确定性的判断。
因此,学生模型不是只学习最终答案,而是在学习教师模型的输出分布、推理偏好和中间表示。
2. Hard Target 与 Soft Target
| 训练方式 | 监督信号 | 学到的信息 | 典型目标 |
|---|---|---|---|
| 传统训练 | 真实标签 hard target | 正确答案本身 | 最大化 ground truth 概率 |
| 蒸馏训练 | 教师模型提供的 soft target | 教师模型的输出偏好、类别关系和中间表示 | 对齐教师模型的行为或表示 |
这里的 soft target 可以理解为“教师模型提供的更丰富监督信号”,不只限于最终概率分布。在 LLM 中,它可以来自:
- 教师模型输出的 logits。
- 教师模型经过 softmax 后的概率分布。
- 教师模型生成的高质量回答。
- 教师模型的中间层 hidden states 或 attention patterns。
3. Temperature 的作用
蒸馏通常会在 softmax 中加入温度参数 T:
其中:
- 是第 i 个类别或 token 的 logit。
- 是 temperature。
- 是经过温度缩放后的输出概率。
Temperature 会改变概率分布的平滑程度:
| 温度 | 效果 | 含义 |
|---|---|---|
| 原始 softmax | 不做额外平滑 | |
| 分布更平滑 | 放大负标签中的相对信息 | |
| 分布更陡峭 | 更强调最高概率类别 | |
| 接近均匀分布 | 类别差异被大幅削弱 | |
| 接近 argmax | 退化为只关注最高概率 |
蒸馏里常用较高的 T,是因为教师模型对负标签的排序也有价值。
例如同样不是正确答案,概率 0.18 的类别通常比概率 0.001 的类别更接近正确语义。学生模型学习这种“相似性结构”,往往比只学习 hard target 更有效。
4. 常见训练目标
通用蒸馏目标通常由两部分组成:
| 损失 | 作用 |
|---|---|
| Distillation Loss | 让学生模型输出接近教师模型 soft target |
| Student Loss | 让学生模型仍然学习真实标签 hard target |
常见写法是把两种损失加权:
total_loss = α * distillation_loss + β * student_loss
常用指标包括:
- KL 散度:对齐教师和学生的概率分布。
- MSE:对齐 logits、hidden states 或 attention matrices。
- Cross Entropy:继续使用真实标签训练学生模型。
如果直接 match logits,而不经过 softmax,可以把目标理解为让学生模型直接拟合教师模型的原始输出。直接 match logits 可以看作在 情况下的一种特殊近似。
5. 典型方法
5.1 基于输出的蒸馏
基于输出的蒸馏只关注教师模型和学生模型的最终输出。
常见做法:
- 用教师模型生成 soft labels。
- 用 KL 散度对齐教师和学生的输出概率分布。
- 对 LLM 来说,也可以用教师模型生成高质量回答,再用这些回答训练学生模型。
适合场景:
- 学生模型结构和教师模型不同。
- 只能访问教师模型 API,不能访问中间层。
- 想快速蒸馏特定任务能力。
5.2 基于中间层特征的蒸馏
这类方法会对齐教师模型和学生模型的内部表示。
常见对齐对象:
- Embedding。
- Hidden states。
- Attention scores。
- Value relation。
- 最后一层或多层中间表示。
代表方法:
- TinyBERT:在预训练和任务微调阶段对齐 embedding、attention、hidden states 和 prediction layer。
- MiniLM:重点蒸馏最后一层自注意力关系和值关系,减少对完整教师结构的依赖。
适合场景:
- 可以访问教师模型内部结构。
- 学生模型和教师模型结构相近,或可以设计层映射关系。
- 希望学生模型不仅模仿输出,还模仿推理表示。
5.3 分阶段蒸馏
分阶段蒸馏把蒸馏拆成多个训练阶段。
常见阶段:
- 预训练阶段蒸馏:先让学生模型学习通用语言表示。
- 任务微调阶段蒸馏:再让学生模型学习特定任务表现。
- 指令蒸馏:用教师模型生成指令数据,训练学生模型的对话和任务执行能力。
适合场景:
- 需要兼顾通用能力和任务能力。
- 数据量较大,训练周期较长。
- 想训练一个可复用的小模型基座。
5.4 在线蒸馏与自蒸馏
在线蒸馏(Online Distillation)中,教师模型和学生模型可以同时训练,甚至多个模型互相学习。
自蒸馏(Self-Distillation)中,模型从自身的不同层、不同阶段或历史 checkpoint 中学习。
适合场景:
- 没有固定教师模型。
- 想提升模型训练稳定性。
- 希望在不引入额外大模型的情况下提升小模型表现。
6. LLM 蒸馏中的关键问题
6.1 数据选择
蒸馏效果高度依赖数据质量。常见数据来源包括:
- 原始训练数据。
- 教师模型生成的合成数据。
- 人工标注数据。
- 高质量指令数据。
- 领域任务数据,例如代码、数学、医疗、客服。
如果教师模型生成的数据有幻觉、偏见或格式问题,学生模型也会继承这些问题。
6.2 教师模型与学生模型差距
教师太强、学生太小,学生可能无法吸收复杂能力。常见处理方式:
- 选择容量差距适中的教师和学生。
- 使用中间教师模型逐级蒸馏。
- 降低任务难度或只蒸馏部分能力。
- 针对特定任务做专用蒸馏。
6.3 层映射
如果对齐中间层,需要设计教师层和学生层的对应关系。例如:
- 每隔几层对齐一次。
- 只对齐最后几层。
- 按层数比例映射。
- 用投影层处理 hidden size 不一致的问题。
层映射设计不当,可能会让学生模型学习到不合适的表示。
6.4 能力继承与风险继承
蒸馏不仅会迁移能力,也可能迁移问题:
- 教师模型的幻觉。
- 教师模型的偏见。
- 教师模型的安全绕过模式。
- 教师模型在特定任务上的错误习惯。
因此蒸馏后仍然需要评测、安全测试和任务回归测试。
7. 经典案例
| 模型 / 方法 | 主要做法 | 特点 |
|---|---|---|
| DistilBERT | 使用 soft target 和中间层表示蒸馏 BERT | 参数量更小,推理更快 |
| TinyBERT | 预训练和微调两阶段蒸馏 | 对齐 embedding、attention、hidden states 和 prediction layer |
| MiniLM | 蒸馏 self-attention 关系和值关系 | 对教师模型结构依赖较低 |
| MobileBERT | 通过瓶颈结构和逐层蒸馏压缩 BERT | 面向移动端推理 |
| 指令蒸馏 | 用强模型生成指令数据训练小模型 | 常见于 LLM 对话能力迁移 |
8. 与其他压缩方法的关系
蒸馏可以单独使用,也可以和其他压缩方法组合。
| 方法 | 作用 | 和蒸馏的关系 |
|---|---|---|
| 量化 | 降低权重和激活精度 | 常在蒸馏后用于部署压缩 |
| 剪枝 | 删除不重要的参数或结构 | 可先剪枝再蒸馏恢复性能 |
| 稀疏化 | 让模型权重或激活更稀疏 | 可和蒸馏共同降低推理成本 |
| LoRA / SFT | 低成本任务适配 | 可用于蒸馏后的任务微调 |
9. 总结
模型蒸馏的本质是:用更强模型提供更丰富的监督信号,让更小模型学习到超过 hard target 的信息。它适合降低部署成本、构建专用小模型和迁移大模型能力。
实际使用时,需要重点关注三件事:
- 教师模型是否可靠。
- 蒸馏数据是否高质量。
- 学生模型容量是否足够承接目标能力。
对 LLM 来说,蒸馏已经不只是分类概率对齐,还包括指令数据生成、中间表示对齐、推理轨迹模仿和任务行为迁移。它通常会和量化、微调、评测一起组成完整的小模型部署流程。