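Reasoning Model Distillation in Practice: Replicating a Large Model's Chain-of-Thought Ability in a Small Model
In 2026, reasoning models have become a core paradigm in AI. From OpenAI o3 to DeepSeek R1, these models use chain-of-thought (CoT) to achieve striking results on math, coding, and complex reasoning tasks. But they often run to hundreds of billions of parameters, which makes them expensive to deploy. Model distillation has become a key technique for transferring reasoning ability to small models.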
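Core Principles of Reasoning Distillation
The key difference between reasoning distillation and traditional knowledge distillation is that we distill not only the probability distribution of the final output but also the intermediate reasoning process. The student model has to learn to "think like the teacher."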
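To make the distinction concrete, here is a minimal sketch of how a reasoning-distillation training target differs from an answer-only target. The helper names, the whitespace "tokenizer", and the toy example are illustrative assumptions rather than part of any library; the `<think>` delimiters follow the format used later in this article.

```python
# Sketch: answer-only target vs. reasoning-trace target for SFT-style distillation.
# Assumption: whitespace splitting stands in for a real tokenizer.

IGNORE_INDEX = -100  # label value conventionally skipped by the cross-entropy loss


def build_target(reasoning: str, answer: str, include_reasoning: bool) -> str:
    """Return the text the student is trained to produce."""
    if include_reasoning:
        return f"<think>\n{reasoning}\n</think>\n\n{answer}"
    return answer


def build_labels(prompt_tokens: list, target_tokens: list) -> list:
    """Mask prompt positions so the loss covers only the teacher-written target."""
    return [IGNORE_INDEX] * len(prompt_tokens) + list(target_tokens)


prompt = "What is 17 * 6 ?".split()
reasoning = "17 * 6 = 17 * 5 + 17 = 85 + 17 = 102"
answer = "102"

answer_only = build_target(reasoning, answer, include_reasoning=False).split()
with_trace = build_target(reasoning, answer, include_reasoning=True).split()

labels_answer_only = build_labels(prompt, answer_only)
labels_with_trace = build_labels(prompt, with_trace)

# Count how many positions actually contribute to the loss in each setup.
supervised_answer_only = sum(1 for x in labels_answer_only if x != IGNORE_INDEX)
supervised_with_trace = sum(1 for x in labels_with_trace if x != IGNORE_INDEX)
print(supervised_answer_only, supervised_with_trace)
```

With the trace included, most of the supervised positions are reasoning tokens, which is exactly where the student picks up the teacher's way of working through a problem.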
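Comparing the Mainstream Distillation Approaches of 2026
DeepSeek R1 Distillation
The DeepSeek team open-sourced the R1 distilled model series at the end of 2025, and by 2026 it had been updated to the R1-0528 release. The core idea is to use the 671B-parameter R1 to generate 800,000 high-quality CoT samples, then fine-tune student models of different sizes, from 1.5B to 70B, on that data.
Key finding: the distilled Qwen3-8B reaches 91.2% on MATH-500, up from 78.5% for the original Qwen3-8B, and even approaches the level of Qwen3-72B.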
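The article does not spell out how "high-quality" samples are selected. One common approach is rejection sampling: keep only teacher traces whose final answer matches a known reference. The sketch below illustrates that idea under assumptions; the `reference` field, the normalization rule, and the length cap are hypothetical, while `question`, `reasoning`, and `answer` mirror the field names used in the training script later on.

```python
# Sketch: filter teacher-generated CoT records before using them for distillation.
# Assumptions: each record carries a `reference` answer; thresholds are arbitrary.

def normalize(ans: str) -> str:
    """Crude normalization: trim, lowercase, drop spaces and a trailing period."""
    return ans.strip().lower().replace(" ", "").rstrip(".")


def filter_traces(records: list, max_reasoning_chars: int = 20000) -> list:
    kept = []
    seen_questions = set()
    for rec in records:
        if normalize(rec["answer"]) != normalize(rec["reference"]):
            continue  # the teacher got it wrong: do not teach this trace
        reasoning = rec["reasoning"].strip()
        if not reasoning or len(reasoning) > max_reasoning_chars:
            continue  # empty or runaway reasoning
        if rec["question"] in seen_questions:
            continue  # keep one trace per question in this simple version
        seen_questions.add(rec["question"])
        kept.append(rec)
    return kept


samples = [
    {"question": "2+2?", "reasoning": "2 plus 2 is 4", "answer": "4.", "reference": "4"},
    {"question": "2+2?", "reasoning": "Adding gives 4", "answer": "4", "reference": "4"},
    {"question": "3*3?", "reasoning": "3 times 3 is 6", "answer": "6", "reference": "9"},
    {"question": "5-1?", "reasoning": "   ", "answer": "4", "reference": "4"},
]
clean = filter_traces(samples)
print(len(clean), "of", len(samples), "records kept")
```

Exact-match checking only works for tasks with a verifiable final answer, such as math; open-ended tasks would need a different quality signal.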
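Qwen3 Reasoning Distillation
Alibaba's Qwen team released Qwen3-235B-A22B-Thinking, which uses an MoE architecture, in Q2 2026. Its distillation strategy has two stages:
- Stage 1: generate chain-of-thought data with Qwen3-235B and run SFT on Qwen3-8B
- Stage 2: further optimize the reasoning paths with reinforcement learning (GRPO)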
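For readers unfamiliar with GRPO, its central step is to score several sampled completions for the same prompt and normalize each reward against the group's mean and standard deviation, which removes the need for a separate value model. The sketch below shows only that advantage computation. The exact-match reward, the use of the population standard deviation, and the `eps` value are illustrative assumptions, not the Qwen team's actual setup.

```python
# Sketch: group-relative advantages, the core quantity in GRPO.
# Assumptions: binary correctness reward; population std; eps chosen arbitrarily.
import statistics


def correctness_reward(completion_answer: str, reference: str) -> float:
    """1.0 if the sampled completion's final answer matches the reference, else 0.0."""
    return 1.0 if completion_answer.strip() == reference.strip() else 0.0


def group_relative_advantages(rewards: list, eps: float = 1e-6) -> list:
    """Normalize each reward against its own sampling group."""
    mean = statistics.fmean(rewards)
    std = statistics.pstdev(rewards)
    return [(r - mean) / (std + eps) for r in rewards]


# Four sampled completions for one prompt whose reference answer is "102".
sampled_answers = ["102", "112", "102", "96"]
rewards = [correctness_reward(a, "102") for a in sampled_answers]
advantages = group_relative_advantages(rewards)
print(rewards, [round(a, 3) for a in advantages])
```

Completions that beat the group average get a positive advantage and are reinforced; if every completion in a group receives the same reward, all advantages are zero and the group contributes no gradient signal.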
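OpenAI o3-mini Distillation
In early 2026 OpenAI offered a distillation API for o3 that lets developers fine-tune GPT-4o-mini on o3's reasoning outputs. It is currently the only closed-source-to-closed-source distillation option, and it is carried out entirely through the API.
Benchmark Comparison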
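Hands-On Code: A Complete LoRA Distillation Workflow
Environment Setup
# Install Axolotl v0.9.0+ (2026 release)
pip install "axolotl[flash-attn]==0.9.2"
pip install trl==0.17.0 transformers==4.52.0
# Or train directly with TRL (quote the specifiers so the shell does not treat ">" as a redirect)
pip install "trl>=0.17.0" "peft>=0.15.0"
Method 1: Distillation with Axolotl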
# distill_config.yaml
base_model: Qwen/Qwen3-8B
model_type: AutoModelForCausalLM
adapter: lora
lora_r: 64
lora_alpha: 128
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
load_in_4bit: true
bf16: auto
datasets:
  - path: 51domino/deepseek-r1-distill-800k
    type: sharegpt
    conversation: chatml
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
micro_batch_size: 2
gradient_accumulation_steps: 8
num_epochs: 3
learning_rate: 2e-4
lr_scheduler: cosine
warmup_ratio: 0.05
optimizer: adamw_8bit
output_dir: ./qwen3-8b-r1-distill
logging_steps: 10
save_strategy: steps
save_steps: 200
eval_steps: 200
# Start training
accelerate launch -m axolotl.cli.train distill_config.yaml
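As a quick sanity check on the config above, the effective batch size per optimizer step is the product of the micro-batch size, the gradient accumulation steps, and the number of GPUs. The sketch below works that out; the GPU count is an assumption, since the config does not state the hardware.

```python
# Sketch: effective batch size implied by the Axolotl config above.
# Assumption: num_gpus is hypothetical; the article does not state the hardware used.

def effective_batch(micro_batch_size: int, grad_accum_steps: int,
                    num_gpus: int, sequence_len: int) -> tuple:
    """Return (sequences per optimizer step, upper bound on tokens per step)."""
    sequences = micro_batch_size * grad_accum_steps * num_gpus
    max_tokens = sequences * sequence_len
    return sequences, max_tokens


for gpus in (1, 8):
    seqs, toks = effective_batch(
        micro_batch_size=2, grad_accum_steps=8, num_gpus=gpus, sequence_len=8192
    )
    print(f"{gpus} GPU(s): {seqs} sequences/step, up to {toks:,} tokens/step")
```

Because `sample_packing` fills each row close to `sequence_len`, the token count per step sits near this upper bound, so the learning rate may need retuning when the GPU count changes.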
Method 2: Knowledge Distillation with TRL
# distill_with_trl.py
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import LoraConfig, get_peft_model
from trl import SFTTrainer, SFTConfig
from datasets import load_dataset
# Load the student model
student_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-8B",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
# Fall back to EOS as the padding token only if the tokenizer does not define one
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# LoRA configuration
lora_config = LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.05,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
student_model = get_peft_model(student_model, lora_config)
# Prints trainable vs. total parameter counts; the exact figures depend on
# the base model's layer shapes and the LoRA rank
student_model.print_trainable_parameters()
# Load the distillation dataset (CoT data generated by the teacher model)
dataset = load_dataset("51domino/deepseek-r1-distill-800k", split="train")
def format_example(example):
    """Format one CoT record as a chat conversation."""
    messages = [
        {"role": "system", "content": "You are a helpful assistant that thinks step by step."},
        {"role": "user", "content": example["question"]},
        {"role": "assistant", "content": f"<think>\n{example['reasoning']}\n</think>\n\n{example['answer']}"},
    ]
    return {"messages": messages}
formatted_dataset = dataset.map(format_example, remove_columns=dataset.column_names)
# Training configuration
training_config = SFTConfig(
    output_dir="./qwen3-8b-r1-distill-trl",
    num_train_epochs=3,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    learning_rate=2e-4,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    bf16=True,
    logging_steps=10,
    save_strategy="steps",
    save_steps=200,
    max_length=8192,  # named `max_seq_length` in older TRL releases
    dataset_text_field=None,
    packing=False,
    report_to="wandb",
)
trainer = SFTTrainer(
    model=student_model,
    args=training_config,
    train_dataset=formatted_dataset,
    processing_class=tokenizer,
)
trainer.train()
# For a PEFT-wrapped model this saves the LoRA adapter weights, not a merged checkpoint
trainer.save_model("./qwen3-8b-r1-distill-final")
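A quick way to sanity-check the figure that `print_trainable_parameters()` reports in the script above is to count LoRA parameters by hand: each adapted linear layer with input size `d_in` and output size `d_out` adds `r * (d_in + d_out)` weights. The layer shapes below are assumptions about Qwen3-8B (hidden size 4096, MLP size 12288, key/value projection width 1024, 36 layers) and should be checked against the model's `config.json`.

```python
# Sketch: count LoRA trainable parameters from layer shapes.
# Assumption: the shapes below are believed to match Qwen3-8B; verify in config.json.

def lora_param_count(d_in: int, d_out: int, r: int) -> int:
    """LoRA adds A (r x d_in) and B (d_out x r) for each adapted linear layer."""
    return r * (d_in + d_out)


HIDDEN, INTERMEDIATE, KV_DIM, NUM_LAYERS, RANK = 4096, 12288, 1024, 36, 64

# (input features, output features) for each module in target_modules
module_shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}

per_layer = sum(lora_param_count(i, o, RANK) for i, o in module_shapes.values())
total_trainable = per_layer * NUM_LAYERS
print(f"per layer: {per_layer:,}  total: {total_trainable:,}")
```

Under these assumed shapes, rank 64 across all seven projections yields roughly 175M trainable parameters, about 2% of an 8B model. If the number printed during training differs substantially, the rank or target modules are probably not what you intended.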
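Advanced: Adding a KL-Divergence Distillation Loss
# advanced_distill.py - also uses the teacher model's logits
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM
# Note: a logit-level KL term requires teacher and student to share one tokenizer
# and vocabulary. DeepSeek-R1 and Qwen3 use different vocabularies, so for a Qwen3
# student choose a teacher from the same family, or rely on SFT over generated traces.
teacher_model = AutoModelForCausalLM.from_pretrained(
    "deepseek-ai/DeepSeek-R1-0528",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    attn_implementation="flash_attention_2",
)
teacher_model.eval()  # the teacher stays frozen; run its forward pass under torch.no_grad()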
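class DistillationLoss(torch.nn.Module):
    def __init__(self, temperature=2.0, alpha=0.5):
        super().__init__()
        self.temperature = temperature
        self.alpha = alpha  # weight of the KL term
    def forward(self, student_logits, teacher_logits, labels):
        # Shift so that position t predicts token t+1 (labels are assumed to be
        # unshifted, following the Hugging Face convention)
        vocab_size = student_logits.size(-1)
        student_flat = student_logits[:, :-1, :].reshape(-1, vocab_size)
        teacher_flat = teacher_logits[:, :-1, :].reshape(-1, vocab_size)
        labels_flat = labels[:, 1:].reshape(-1)
        # Standard cross-entropy loss against the target tokens
        ce_loss = F.cross_entropy(student_flat, labels_flat, ignore_index=-100)
        # KL-divergence loss, the core of logit distillation, computed only on
        # supervised positions so prompt and padding tokens do not contribute
        mask = labels_flat != -100
        student_log_probs = F.log_softmax(
            student_flat[mask] / self.temperature, dim=-1
        )
        teacher_probs = F.softmax(
            teacher_flat[mask] / self.temperature, dim=-1
        )
        # With 2-D inputs, "batchmean" averages over the supervised tokens
        kl_loss = F.kl_div(
            student_log_probs, teacher_probs,
            reduction="batchmean"
        ) * (self.temperature ** 2)
        # Combined loss
        total_loss = (1 - self.alpha) * ce_loss + self.alpha * kl_loss
        return total_loss, ce_loss, kl_loss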
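To see what the temperature in the loss above does, the following standalone sketch computes temperature-softened distributions and the KL divergence from teacher to student on a toy three-token vocabulary. The logit values are made up for illustration, and the code uses plain Python rather than PyTorch so the arithmetic stays visible.

```python
# Sketch: effect of temperature on softened distributions and the KL term.
# Assumption: the logits are invented toy values over a three-token vocabulary.
import math


def softmax(logits: list, temperature: float = 1.0) -> list:
    scaled = [x / temperature for x in logits]
    peak = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: list, q: list) -> float:
    """KL(p || q) for two discrete distributions over the same support."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


teacher_logits = [4.0, 2.0, 0.5]
student_logits = [2.5, 2.0, 1.0]

sharp = softmax(teacher_logits, 1.0)
soft = softmax(teacher_logits, 4.0)

for T in (1.0, 2.0, 4.0):
    p = softmax(teacher_logits, T)
    q = softmax(student_logits, T)
    kl = kl_divergence(p, q)
    print(f"T={T}: teacher={[round(x, 3) for x in p]} KL={kl:.4f} KL*T^2={kl * T * T:.4f}")
```

A higher temperature flattens the teacher distribution, so the student also learns from the relative ranking of less likely tokens. The `temperature ** 2` factor in the loss compensates for the smaller gradients that softened distributions produce.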
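Distillation Best Practices
Deployment and Inference Optimization
Once distillation is complete, use vLLM 0.8.x or SGLang for efficient inference deployment: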
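# Inference deployment example
# Note: the `model` path must hold a full checkpoint. If the directory contains only
# a LoRA adapter, merge it into the base weights before loading it here.
from vllm import LLM, SamplingParams
llm = LLM(
    model="./qwen3-8b-r1-distill-final",
    dtype="bfloat16",
    max_model_len=8192,
    gpu_memory_utilization=0.9,
)
sampling = SamplingParams(
    temperature=0.6,
    top_p=0.95,
    max_tokens=4096,
)
# Use the chat interface so the chat template is applied, matching the training format
messages = [
    {"role": "system", "content": "You are a helpful assistant that thinks step by step."},
    {"role": "user", "content": "Prove that √2 is irrational"},
]
outputs = llm.chat(messages, sampling_params=sampling)
print(outputs[0].outputs[0].text)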
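Because the student is trained to wrap its reasoning in `<think>` tags, the generated text usually needs to be split before it is shown to users. The helper below is a minimal sketch of that post-processing step. The function name is hypothetical, and handling a missing opening tag is an assumption about templates that prefill `<think>` in the prompt.

```python
# Sketch: separate the reasoning trace from the final answer in generated text.
# Assumption: output follows the <think>...</think> format used during training.

def split_reasoning(text: str) -> tuple:
    """Return (reasoning, answer); reasoning is empty if no closing tag is found."""
    head, sep, tail = text.partition("</think>")
    if not sep:
        # No closing tag: either the model skipped reasoning or generation was
        # truncated mid-thought. Return everything as the answer and let the caller decide.
        return "", text.strip()
    reasoning = head.replace("<think>", "", 1).strip()
    return reasoning, tail.strip()


sample_output = "<think>\nAssume sqrt(2) = p/q in lowest terms...\n</think>\n\nsqrt(2) is irrational."
reasoning, answer = split_reasoning(sample_output)
print("reasoning chars:", len(reasoning))
print("answer:", answer)
```

If many responses come back without a closing tag, `max_tokens` is likely too small for the length of reasoning the model produces.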
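Summary
Reasoning model distillation is one of the most practical model optimization techniques of 2026. With sound data preparation, LoRA configuration, and training strategy, a distilled 8B model can come close to an undistilled 70B-class model on reasoning tasks. That means lower inference costs and a wider range of deployment options.
This article was written by the 51domino.com team, and the code is open source on GitHub. Questions are welcome in the comments.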