一、描述:
- 将两个模型的知识提炼到一个新模型(学生模型)中,使其兼具两者的优点。
使用 AutoModelForCausalLM
来实现从两个教师模型(teacher1
和 teacher2
)到学生模型的知识蒸馏(Knowledge Distillation)。在因果语言建模(Causal Language Modeling)中,模型的任务是预测序列中的下一个词,因此训练和损失函数与序列分类任务有所不同。
下面,我将提供完整的代码实现,并对每个步骤进行详细解释。
二、步骤概述
- 环境准备:安装必要的库。
- 加载教师模型和数据集:从 Hugging Face Hub 加载教师模型和数据集。
- 初始化学生模型:基于相同的基础模型,创建一个较小的学生模型。
- 定义蒸馏损失函数:结合教师模型输出和真实标签,定义总损失函数。
- 数据预处理和加载:对数据进行预处理,并创建数据加载器。
- 设置训练参数:定义训练超参数,如学习率、批量大小等。
- 创建训练循环:实现训练过程,包括前向传播、计算损失、反向传播和参数更新。
- 保存学生模型:训练完成后,保存模型。
- 完整代码汇总:提供完整的代码。
1. 环境准备
首先,确保您已经安装了必要的库:
1 |
pip install transformers datasets torch |
2. 加载教师模型和数据集
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
from transformers import AutoModelForCausalLM, AutoTokenizer from datasets import load_dataset import torch from torch.utils.data import DataLoader from tqdm.auto import tqdm # 检查设备 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 定义模型名称和数据集名称 teacher1_name = 'your-teacher1-model-name' # 替换为实际的模型名称 teacher2_name = 'your-teacher2-model-name' student_name = 'your-student-model-name' # 可以使用相同的基础模型 dataset_name = 'your-dataset-name' # 替换为实际的数据集名称 # 加载分词器 tokenizer = AutoTokenizer.from_pretrained(teacher1_name) # 加载教师模型 teacher1 = AutoModelForCausalLM.from_pretrained(teacher1_name).to(device) teacher2 = AutoModelForCausalLM.from_pretrained(teacher2_name).to(device) # 加载数据集 dataset = load_dataset(dataset_name) |
注意:请将 'your-teacher1-model-name'
、'your-teacher2-model-name'
和 'your-dataset-name'
替换为实际的模型和数据集名称。
3. 初始化学生模型
1 2 3 4 5 6 7 8 |
from transformers import AutoConfig # 从教师模型的配置加载,并修改层数 student_config = AutoConfig.from_pretrained(teacher1_name) student_config.num_hidden_layers = 6 # 例如,将层数减半 # 从头初始化学生模型(不加载预训练权重) student = AutoModelForCausalLM(config=student_config).to(device) |
4. 定义蒸馏损失函数
对于因果语言模型的知识蒸馏,我们将使用以下损失函数:
- 交叉熵损失(Cross-Entropy Loss):用于学生模型的输出与真实标签之间的损失。
- KL 散度损失(Kullback-Leibler Divergence):用于学生模型与教师模型输出分布之间的损失。
5. 数据预处理和加载
1 2 3 4 5 6 7 8 9 10 11 12 13 |
# 定义数据预处理函数 def preprocess_function(examples): # 对文本进行拼接和分词 inputs = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=128) inputs['labels'] = inputs['input_ids'].copy() return inputs # 对数据集进行预处理 tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=dataset['train'].column_names) # 创建数据加载器 train_loader = DataLoader(tokenized_dataset['train'], batch_size=8, shuffle=True) eval_loader = DataLoader(tokenized_dataset['validation'], batch_size=8) |
说明:
- 我们将
labels
设置为input_ids
的副本,以便模型学习预测下一个词。 - 根据您的资源,调整
max_length
和batch_size
。
6. 设置训练参数
1 2 3 4 |
from transformers import AdamW # 定义优化器 optimizer = AdamW(student.parameters(), lr=5e-5) |
7. 创建训练循环
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import torch.nn.functional as F # 超参数 num_epochs = 3 temperature = 2.0 alpha = 0.5 # 蒸馏损失与真实标签损失的权重平衡 for epoch in range(num_epochs): student.train() progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}") for batch in progress_bar: optimizer.zero_grad() # 将数据移动到设备 input_ids = torch.stack(batch['input_ids']).to(device) attention_mask = torch.stack(batch['attention_mask']).to(device) labels = torch.stack(batch['labels']).to(device) # 获取教师模型的输出 with torch.no_grad(): outputs_t1 = teacher1(input_ids=input_ids, attention_mask=attention_mask) outputs_t2 = teacher2(input_ids=input_ids, attention_mask=attention_mask) # 平均两个教师模型的 logits logits_teacher = (outputs_t1.logits + outputs_t2.logits) / 2 # 学生模型的输出 outputs_student = student(input_ids=input_ids, attention_mask=attention_mask) logits_student = outputs_student.logits # 计算蒸馏损失(KL 散度) loss_kd = F.kl_div( input=F.log_softmax(logits_student / temperature, dim=-1), target=F.softmax(logits_teacher / temperature, dim=-1), reduction='batchmean' ) * (temperature ** 2) # 计算真实标签的交叉熵损失 loss_ce = F.cross_entropy(logits_student.view(-1, logits_student.size(-1)), labels.view(-1), ignore_index=tokenizer.pad_token_id) # 总损失 loss = alpha * loss_ce + (1 - alpha) * loss_kd # 反向传播和优化 loss.backward() optimizer.step() # 更新进度条 progress_bar.set_postfix({'loss': loss.item()}) # 每个 epoch 结束后进行评估 student.eval() total_loss = 0 with torch.no_grad(): for batch in eval_loader: input_ids = torch.stack(batch['input_ids']).to(device) attention_mask = torch.stack(batch['attention_mask']).to(device) labels = torch.stack(batch['labels']).to(device) outputs = student(input_ids=input_ids, attention_mask=attention_mask, labels=labels) loss = outputs.loss total_loss += loss.item() avg_loss = total_loss / len(eval_loader) print(f"Validation Loss after epoch {epoch+1}: {avg_loss:.4f}") |
解释:
- 教师模型输出:使用
torch.no_grad()
,避免计算梯度。 - 蒸馏损失(KL 散度):计算学生模型和教师模型输出分布之间的差异。
- 真实标签损失(交叉熵):计算学生模型输出与真实标签之间的损失。
- 总损失:蒸馏损失和真实标签损失的加权和。
- 忽略填充:在计算交叉熵损失时,使用
ignore_index=tokenizer.pad_token_id
,忽略填充标记。
8. 保存学生模型
1 2 3 |
# 保存训练好的学生模型 student.save_pretrained('path_to_save_student_model') tokenizer.save_pretrained('path_to_save_student_model') |
9. 完整代码汇总
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, AdamW from datasets import load_dataset import torch import torch.nn.functional as F from torch.utils.data import DataLoader from tqdm.auto import tqdm # 检查设备 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 定义模型名称和数据集名称 teacher1_name = 'your-teacher1-model-name' # 替换为实际的模型名称 teacher2_name = 'your-teacher2-model-name' student_name = 'your-student-model-name' # 使用教师模型的名称或其他 dataset_name = 'your-dataset-name' # 替换为实际的数据集名称 # 加载分词器 tokenizer = AutoTokenizer.from_pretrained(teacher1_name) # 加载教师模型 teacher1 = AutoModelForCausalLM.from_pretrained(teacher1_name).to(device) teacher2 = AutoModelForCausalLM.from_pretrained(teacher2_name).to(device) # 加载数据集 dataset = load_dataset(dataset_name) # 定义数据预处理函数 def preprocess_function(examples): inputs = tokenizer(examples['text'], truncation=True, padding='max_length', max_length=128) inputs['labels'] = inputs['input_ids'].copy() return inputs # 对数据集进行预处理 tokenized_dataset = dataset.map(preprocess_function, batched=True, remove_columns=dataset['train'].column_names) # 创建数据加载器 train_loader = DataLoader(tokenized_dataset['train'], batch_size=8, shuffle=True) eval_loader = DataLoader(tokenized_dataset['validation'], batch_size=8) # 从教师模型的配置加载,并修改层数 student_config = AutoConfig.from_pretrained(teacher1_name) student_config.num_hidden_layers = 6 # 例如,将层数减半 # 从头初始化学生模型(不加载预训练权重) student = AutoModelForCausalLM(config=student_config).to(device) # 定义优化器 optimizer = AdamW(student.parameters(), lr=5e-5) # 超参数 num_epochs = 3 temperature = 2.0 alpha = 0.5 # 蒸馏损失与真实标签损失的权重平衡 for epoch in range(num_epochs): student.train() progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}") for batch in progress_bar: optimizer.zero_grad() # 将数据移动到设备 input_ids = batch['input_ids'].to(device) attention_mask = batch['attention_mask'].to(device) labels = batch['labels'].to(device) # 获取教师模型的输出 with torch.no_grad(): outputs_t1 = teacher1(input_ids=input_ids, attention_mask=attention_mask) outputs_t2 = teacher2(input_ids=input_ids, attention_mask=attention_mask) logits_teacher = (outputs_t1.logits + outputs_t2.logits) / 2 # 学生模型的输出 outputs_student = student(input_ids=input_ids, attention_mask=attention_mask) logits_student = outputs_student.logits # 计算蒸馏损失(KL 散度) loss_kd = F.kl_div( input=F.log_softmax(logits_student / temperature, dim=-1), target=F.softmax(logits_teacher / temperature, dim=-1), reduction='batchmean' ) * (temperature ** 2) # 计算真实标签的交叉熵损失 loss_ce = F.cross_entropy(logits_student.view(-1, logits_student.size(-1)), labels.view(-1), ignore_index=tokenizer.pad_token_id) # 总损失 loss = alpha * loss_ce + (1 - alpha) * loss_kd # 反向传播和优化 loss.backward() optimizer.step() # 更新进度条 progress_bar.set_postfix({'loss': loss.item()}) # 每个 epoch 结束后进行评估 student.eval() total_loss = 0 with torch.no_grad(): for batch in eval_loader: input_ids = batch['input_ids'].to(device) attention_mask = batch['attention_mask'].to(device) labels = batch['labels'].to(device) outputs = student(input_ids=input_ids, attention_mask=attention_mask, labels=labels) loss = outputs.loss total_loss += loss.item() avg_loss = total_loss / len(eval_loader) print(f"Validation Loss after epoch {epoch+1}: {avg_loss:.4f}") # 保存训练好的学生模型 student.save_pretrained('path_to_save_student_model') tokenizer.save_pretrained('path_to_save_student_model') |
三、注意事项和解释
1. 数据加载和预处理
- 拼接和分词:在预处理函数中,我们对文本进行分词,并设置
labels
等于input_ids
的副本。 - 删除原始列:使用
remove_columns
删除原始数据集的列,避免不必要的数据冗余。
2. 处理批次数据
torch.stack
:在训练循环中,我们使用torch.stack
将批次中的张量组合起来。- 张量形状:确保输入的张量形状正确,匹配模型的预期输入。
3. 损失函数
- KL 散度损失:用于度量学生模型输出分布与教师模型输出分布之间的差异。
- 交叉熵损失:用于度量学生模型输出与真实标签之间的差异。
- 温度参数:通过温度参数软化概率分布,以更好地学习教师模型的知识。
4. 超参数调整
num_epochs
:根据数据集大小和模型收敛情况调整训练轮数。learning_rate
:学习率对训练稳定性和收敛速度有重要影响,可根据需要调整。temperature
:常用值为 1 到 5,需根据实验效果调整。alpha
:用于平衡蒸馏损失和真实标签损失,取值范围在 0 到 1 之间。
5. 资源要求
- 显存占用:因果语言模型的输出维度为词汇表大小,可能导致显存占用较高。可通过减小
batch_size
或使用梯度累积来缓解。 - 计算时间:训练时间可能较长,建议使用 GPU 加速。
6. 法律和版权
- 模型和数据集许可:在使用预训练模型和数据集时,务必遵守其许可协议和使用条款。
四、总结
通过上述代码,我们实现了使用 AutoModelForCausalLM
的知识蒸馏过程,将两个教师模型的知识蒸馏到一个较小的学生模型中。该学生模型能够在保持性能的同时,减小模型大小,提高推理速度。
如果您有任何疑问或需要进一步的帮助,请随时告诉我!