0. Introduction
DBRX is a transformer-based, decoder-only large language model (LLM) trained with next-token prediction. It uses a fine-grained mixture-of-experts (MoE) architecture with 132B total parameters, of which 36B are active on any given input. It was pretrained on 12T tokens of text and code data. Compared with other open MoE models such as Mixtral-8x7B and Grok-1, DBRX is fine-grained, meaning it uses a larger number of smaller experts: DBRX has 16 experts and selects 4, while Mixtral-8x7B and Grok-1 have 8 experts and select 2. This provides 65x more possible combinations of experts, which we found improves model quality. DBRX uses rotary position embeddings (RoPE), gated linear units (GLU), and grouped-query attention (GQA), and it uses the GPT-4 tokenizer as provided in the tiktoken repository. We made these choices based on exhaustive evaluation and scaling experiments.
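The "65x" figure follows directly from counting expert subsets: DBRX can route each token to C(16, 4) = 1820 different expert combinations per MoE layer, versus C(8, 2) = 28 for Mixtral-8x7B or Grok-1, and 1820 / 28 = 65. A quick check in Python:

```python
from math import comb

# Number of possible sets of active experts per MoE layer
dbrx = comb(16, 4)    # 16 experts, 4 active -> 1820
mixtral = comb(8, 2)  # 8 experts, 2 active  -> 28

print(dbrx, mixtral, dbrx / mixtral)  # 1820 28 65.0
```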
DBRX was pretrained on 12T tokens of carefully curated data with a maximum context length of 32K tokens. We estimate this data is at least 2x better token-for-token than the data we used to pretrain the MPT family of models. This new dataset was developed with the full suite of Databricks tools, including Apache Spark™ and Databricks notebooks for data processing and Unity Catalog for data management and governance. We used curriculum learning for pretraining, changing the data mix during training in ways we found substantially improve model quality.
Input: DBRX accepts only text-based input and supports a context length of up to 32,768 tokens.
Output: DBRX generates only text-based output.
Model URL: https://huggingface.co/databricks/dbrx-instruct
If you have roughly 550 GB of RAM or swap space, you can load the model on the CPU and print the model structure:
```python
from transformers import AutoTokenizer
# DbrxForCausalLM is imported from a local modeling_dbrx.py placed alongside this script
from modeling_dbrx import DbrxForCausalLM

tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx", trust_remote_code=True)
model = DbrxForCausalLM.from_pretrained("databricks/dbrx", trust_remote_code=True)
print(model)

prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")

# Generate
generate_ids = model.generate(inputs.input_ids, max_length=30)
tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
```
Run result:
```
DbrxForCausalLM(
  (transformer): DbrxModel(
    (wte): Embedding(100352, 6144)
    (blocks): ModuleList(
      (0-39): 40 x DbrxBlock(
        (norm_attn_norm): DbrxNormAttentionNorm(
          (norm_1): LayerNorm((6144,), eps=1e-05, elementwise_affine=True)
          (attn): DbrxAttention(
            (Wqkv): Linear(in_features=6144, out_features=8192, bias=False)
            (out_proj): Linear(in_features=6144, out_features=6144, bias=False)
            (rotary_emb): DbrxRotaryEmbedding()
          )
          (norm_2): LayerNorm((6144,), eps=1e-05, elementwise_affine=True)
        )
        (ffn): DbrxFFN(
          (router): DbrxRouter(
            (layer): Linear(in_features=6144, out_features=16, bias=False)
          )
          (experts): DbrxExperts(
            (mlp): DbrxExpertGLU()
          )
        )
      )
    )
    (norm_f): LayerNorm((6144,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=6144, out_features=100352, bias=False)
)
```
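Note that, depending on the transformers version, from_pretrained may load the bfloat16 checkpoint as float32 on CPU unless a dtype is requested, roughly doubling the memory footprint (hence the ~550 GB figure above). A minimal sketch, assuming the same local checkpoint path, that keeps the weights in bfloat16 so CPU memory stays close to the on-disk size:

```python
import torch
from transformers import AutoTokenizer
from modeling_dbrx import DbrxForCausalLM

# Request bfloat16 explicitly instead of letting from_pretrained upcast the weights.
tokenizer = AutoTokenizer.from_pretrained("databricks/dbrx", trust_remote_code=True)
model = DbrxForCausalLM.from_pretrained(
    "databricks/dbrx",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
print(model.dtype)  # torch.bfloat16
```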
1. Analyzing the model.safetensors.index.json file
This model is large: from model.safetensors.index.json we can see it totals 263,193,047,040 bytes, i.e. about 245 GiB:
{ "metadata": { "total_size": 263193047040 }, "weight_map": { "lm_head.weight": "model-00061-of-00061.safetensors", "transformer.blocks.0.ffn.experts.mlp.v1": "model-00002-of-00061.safetensors", "transformer.blocks.0.ffn.experts.mlp.w1": "model-00001-of-00061.safetensors", "transformer.blocks.0.ffn.experts.mlp.w2": "model-00002-of-00061.safetensors", "transformer.blocks.0.ffn.router.layer.weight": "model-00001-of-00061.safetensors", "transformer.blocks.0.norm_attn_norm.attn.Wqkv.weight": "model-00001-of-00061.safetensors", "transformer.blocks.0.norm_attn_norm.attn.out_proj.weight": "model-00001-of-00061.safetensors", "transformer.blocks.0.norm_attn_norm.norm_1.weight": "model-00001-of-00061.safetensors", "transformer.blocks.0.norm_attn_norm.norm_2.weight": "model-00001-of-00061.safetensors", "transformer.blocks.1.ffn.experts.mlp.v1": "model-00003-of-00061.safetensors", "transformer.blocks.1.ffn.experts.mlp.w1": "model-00003-of-00061.safetensors", "transformer.blocks.1.ffn.experts.mlp.w2": "model-00004-of-00061.safetensors", "transformer.blocks.1.ffn.router.layer.weight": "model-00002-of-00061.safetensors", "transformer.blocks.1.norm_attn_norm.attn.Wqkv.weight": "model-00002-of-00061.safetensors", "transformer.blocks.1.norm_attn_norm.attn.out_proj.weight": "model-00002-of-00061.safetensors", "transformer.blocks.1.norm_attn_norm.norm_1.weight": "model-00002-of-00061.safetensors", "transformer.blocks.1.norm_attn_norm.norm_2.weight": "model-00002-of-00061.safetensors", ... "transformer.blocks.39.ffn.experts.mlp.v1": "model-00060-of-00061.safetensors", "transformer.blocks.39.ffn.experts.mlp.w1": "model-00060-of-00061.safetensors", "transformer.blocks.39.ffn.experts.mlp.w2": "model-00061-of-00061.safetensors", "transformer.blocks.39.ffn.router.layer.weight": "model-00059-of-00061.safetensors", "transformer.blocks.39.norm_attn_norm.attn.Wqkv.weight": "model-00059-of-00061.safetensors", "transformer.blocks.39.norm_attn_norm.attn.out_proj.weight": "model-00059-of-00061.safetensors", "transformer.blocks.39.norm_attn_norm.norm_1.weight": "model-00059-of-00061.safetensors", "transformer.blocks.39.norm_attn_norm.norm_2.weight": "model-00059-of-00061.safetensors", "transformer.norm_f.weight": "model-00061-of-00061.safetensors", "transformer.wte.weight": "model-00001-of-00061.safetensors" } } |
The weights fall into four top-level groups, much like most transformer checkpoints, only with slightly different names (a small verification sketch follows the list):
- lm_head.weight
- transformer.blocks, 40 of them (0-39)
- transformer.norm_f.weight
- transformer.wte.weight
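To double-check this grouping (and the block count) without touching all 61 shards, the index file alone is enough. A minimal sketch, assuming model.safetensors.index.json sits in the databricks/dbrx/ directory used throughout this post; it should report the four groups above and 40 blocks (0-39):

```python
import json
import re
from collections import Counter

# Parse only the index file; no shard needs to be opened for this.
with open("databricks/dbrx/model.safetensors.index.json") as f:
    index = json.load(f)

total_bytes = index["metadata"]["total_size"]
print(f"total_size: {total_bytes:,} bytes = {total_bytes / 2**30:.1f} GiB")

# Collapse per-block weight names into a single "transformer.blocks" bucket.
groups = Counter()
block_ids = set()
for name in index["weight_map"]:
    m = re.match(r"transformer\.blocks\.(\d+)\.", name)
    if m:
        block_ids.add(int(m.group(1)))
        groups["transformer.blocks"] += 1
    else:
        groups[name] += 1

print(groups)
print(f"blocks: {len(block_ids)} ({min(block_ids)}-{max(block_ids)})")
```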
1.1 Weight sizes
To see how large these four weight groups actually are, we can follow model.safetensors.index.json: at minimum, three .safetensors files are needed (shards 1, 2 and 61 together contain block 0, the embeddings, the final norm and the LM head). The code is as follows:
```python
import os
from safetensors import safe_open

safetensors_path = "databricks/dbrx/"
# These three shards contain block 0, the embeddings, the final norm and the LM head
# (plus parts of blocks 1 and 39).
safetensors_files = [
    "model-00001-of-00061.safetensors",
    "model-00002-of-00061.safetensors",
    "model-00061-of-00061.safetensors",
]

# Print name, shape, dtype and byte size of every tensor in the selected shards.
for file in safetensors_files:
    file_path = os.path.join(safetensors_path, file)
    with safe_open(file_path, 'pt') as f:
        for k in f.keys():
            tensor = f.get_tensor(k)
            total_bytes = tensor.numel() * tensor.element_size()
            formatted_bytes_size = "{:,}".format(total_bytes)
            print(f"{k}, {tensor.size()}, {tensor.dtype}, {formatted_bytes_size} bytes")
```
Run result:
```
python test05.py
transformer.blocks.0.ffn.experts.mlp.w1, torch.Size([172032, 6144]), torch.bfloat16, 2,113,929,216 bytes
transformer.blocks.0.ffn.router.layer.weight, torch.Size([16, 6144]), torch.bfloat16, 196,608 bytes
transformer.blocks.0.norm_attn_norm.attn.Wqkv.weight, torch.Size([8192, 6144]), torch.bfloat16, 100,663,296 bytes
transformer.blocks.0.norm_attn_norm.attn.out_proj.weight, torch.Size([6144, 6144]), torch.bfloat16, 75,497,472 bytes
transformer.blocks.0.norm_attn_norm.norm_1.weight, torch.Size([6144]), torch.bfloat16, 12,288 bytes
transformer.blocks.0.norm_attn_norm.norm_2.weight, torch.Size([6144]), torch.bfloat16, 12,288 bytes
transformer.wte.weight, torch.Size([100352, 6144]), torch.bfloat16, 1,233,125,376 bytes
transformer.blocks.0.ffn.experts.mlp.v1, torch.Size([172032, 6144]), torch.bfloat16, 2,113,929,216 bytes
transformer.blocks.0.ffn.experts.mlp.w2, torch.Size([172032, 6144]), torch.bfloat16, 2,113,929,216 bytes
transformer.blocks.1.ffn.router.layer.weight, torch.Size([16, 6144]), torch.bfloat16, 196,608 bytes
transformer.blocks.1.norm_attn_norm.attn.Wqkv.weight, torch.Size([8192, 6144]), torch.bfloat16, 100,663,296 bytes
transformer.blocks.1.norm_attn_norm.attn.out_proj.weight, torch.Size([6144, 6144]), torch.bfloat16, 75,497,472 bytes
transformer.blocks.1.norm_attn_norm.norm_1.weight, torch.Size([6144]), torch.bfloat16, 12,288 bytes
transformer.blocks.1.norm_attn_norm.norm_2.weight, torch.Size([6144]), torch.bfloat16, 12,288 bytes
lm_head.weight, torch.Size([100352, 6144]), torch.bfloat16, 1,233,125,376 bytes
transformer.blocks.39.ffn.experts.mlp.w2, torch.Size([172032, 6144]), torch.bfloat16, 2,113,929,216 bytes
transformer.norm_f.weight, torch.Size([6144]), torch.bfloat16, 12,288 bytes
```
As the output shows, every tensor is stored as bfloat16, which occupies two bytes per element.
We can check the size of lm_head.weight by hand: 100352 * 6144 * 2 = 1,233,125,376 bytes.
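Taking this one step further, the shapes above are enough to sanity-check the headline figures of 132B total and 36B active parameters. Each expert matrix (w1, v1, w2) is stored as [172032, 6144], i.e. 16 stacked experts with 10752 hidden units each, and only 4 of the 16 experts are active per token. A rough back-of-the-envelope sketch (how the official numbers are rounded, e.g. whether the embeddings are counted, is an assumption here):

```python
d_model   = 6144
vocab     = 100352
n_blocks  = 40
n_experts = 16
top_k     = 4

attn_qkv = 8192 * d_model        # Wqkv
attn_out = d_model * d_model     # out_proj
norms    = 2 * d_model           # norm_1 + norm_2
router   = n_experts * d_model
experts  = 3 * 172032 * d_model  # w1, v1, w2 (all 16 experts stacked)

block_total  = attn_qkv + attn_out + norms + router + experts
block_active = attn_qkv + attn_out + norms + router + experts * top_k // n_experts
embeddings   = 2 * vocab * d_model + d_model   # wte + lm_head + norm_f

print(f"total:  {(n_blocks * block_total + embeddings) / 1e9:.1f}B")   # ~131.6B ("132B")
print(f"active: {(n_blocks * block_active + embeddings) / 1e9:.1f}B")  # ~36.5B  ("36B")
```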
1.2 Saving each weight to its own file
```python
import os
from safetensors import safe_open
import torch

safetensors_path = "databricks/dbrx/"
safetensors_files = [
    "model-00001-of-00061.safetensors",
    "model-00002-of-00061.safetensors",
    "model-00061-of-00061.safetensors",
]

# Read every tensor from the selected shards into memory.
tensors = {}
for file in safetensors_files:
    file_path = os.path.join(safetensors_path, file)
    with safe_open(file_path, 'pt') as f:
        for k in f.keys():
            tensor = f.get_tensor(k)
            tensors[k] = tensor

directory = "pt"
if not os.path.exists(directory):
    os.mkdir(directory)

# Save the embedding, LM head, final norm and block-0 tensors, one .pt file each.
for k, tensor in tensors.items():
    if any([x in k for x in ['wte', 'lm_head', 'norm_f', 'blocks.0']]):
        total_bytes = tensor.numel() * tensor.element_size()
        formatted_bytes_size = "{:,}".format(total_bytes)
        file_path = os.path.join("pt", k)
        print(f"save {file_path}.pt, {formatted_bytes_size} bytes")
        torch.save(tensor, f"{file_path}.pt")
```
Run result:
```
python test06.py
save pt/transformer.blocks.0.ffn.experts.mlp.w1.pt, 2,113,929,216 bytes
save pt/transformer.blocks.0.ffn.router.layer.weight.pt, 196,608 bytes
save pt/transformer.blocks.0.norm_attn_norm.attn.Wqkv.weight.pt, 100,663,296 bytes
save pt/transformer.blocks.0.norm_attn_norm.attn.out_proj.weight.pt, 75,497,472 bytes
save pt/transformer.blocks.0.norm_attn_norm.norm_1.weight.pt, 12,288 bytes
save pt/transformer.blocks.0.norm_attn_norm.norm_2.weight.pt, 12,288 bytes
save pt/transformer.wte.weight.pt, 1,233,125,376 bytes
save pt/transformer.blocks.0.ffn.experts.mlp.v1.pt, 2,113,929,216 bytes
save pt/transformer.blocks.0.ffn.experts.mlp.w2.pt, 2,113,929,216 bytes
save pt/lm_head.weight.pt, 1,233,125,376 bytes
save pt/transformer.norm_f.weight.pt, 12,288 bytes

ls -l pt
total 8773904
-rwxrwxrwx 1 tony tony 1233126591 Apr 3 20:29 lm_head.weight.pt
-rwxrwxrwx 1 tony tony 2113930684 Apr 3 20:27 transformer.blocks.0.ffn.experts.mlp.v1.pt
-rwxrwxrwx 1 tony tony 2113930684 Apr 3 20:25 transformer.blocks.0.ffn.experts.mlp.w1.pt
-rwxrwxrwx 1 tony tony 2113930684 Apr 3 20:28 transformer.blocks.0.ffn.experts.mlp.w2.pt
-rwxrwxrwx 1 tony tony 198101 Apr 3 20:25 transformer.blocks.0.ffn.router.layer.weight.pt
-rwxrwxrwx 1 tony tony 100664829 Apr 3 20:26 transformer.blocks.0.norm_attn_norm.attn.Wqkv.weight.pt
-rwxrwxrwx 1 tony tony 75499089 Apr 3 20:26 transformer.blocks.0.norm_attn_norm.attn.out_proj.weight.pt
-rwxrwxrwx 1 tony tony 13806 Apr 3 20:26 transformer.blocks.0.norm_attn_norm.norm_1.weight.pt
-rwxrwxrwx 1 tony tony 13806 Apr 3 20:26 transformer.blocks.0.norm_attn_norm.norm_2.weight.pt
-rwxrwxrwx 1 tony tony 13622 Apr 3 20:29 transformer.norm_f.weight.pt
-rwxrwxrwx 1 tony tony 1233126695 Apr 3 20:26 transformer.wte.weight.pt
```
1.3 Saving tensors grouped by category
```python
import os
import torch
from safetensors import safe_open

# Checkpoint path and the list of safetensors shards to read
safetensors_path = "databricks/dbrx/"
safetensors_files = [
    "model-00001-of-00061.safetensors",
    "model-00002-of-00061.safetensors",
    "model-00061-of-00061.safetensors",
]

# Tensors grouped by category
group1_tensors = {}  # 'wte', 'lm_head', 'norm_f'
group2_tensors = {}  # 'blocks.0'

# Load the tensors and sort them into the two groups
for file in safetensors_files:
    file_path = os.path.join(safetensors_path, file)
    with safe_open(file_path, 'pt') as f:
        for k in f.keys():
            tensor = f.get_tensor(k)
            if any(x in k for x in ['wte', 'lm_head', 'norm_f']):
                group1_tensors[k] = tensor
            elif 'blocks.0' in k:
                group2_tensors[k] = tensor

# Make sure the output directory exists
directory = "pt"
os.makedirs(directory, exist_ok=True)

# Save the two groups to separate files
torch.save(group1_tensors, os.path.join(directory, "group1_tensors.pt"))
torch.save(group2_tensors, os.path.join(directory, "group2_tensors.pt"))

print("Saved: group1_tensors.pt contains 'wte', 'lm_head', 'norm_f'")
print("Saved: group2_tensors.pt contains 'blocks.0'")
```
Run result:
```
python test07.py
Saved: group1_tensors.pt contains 'wte', 'lm_head', 'norm_f'
Saved: group2_tensors.pt contains 'blocks.0'

ls -l pt
total 17547772
-rwxrwxrwx 1 tony tony 2466264837 Apr 3 20:42 group1_tensors.pt
-rwxrwxrwx 1 tony tony 6518173056 Apr 3 20:45 group2_tensors.pt
```
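As an aside, these grouped dictionaries could just as well be written back out as .safetensors files instead of pickled .pt files, which keeps them readable with the same safe_open tooling used earlier. A minimal sketch, assuming the group*.pt files from the step above already exist under pt/:

```python
import torch
from safetensors.torch import save_file, load_file

# Convert the grouped .pt dictionaries into .safetensors files.
for name in ["group1_tensors", "group2_tensors"]:
    tensors = torch.load(f"pt/{name}.pt")
    # safetensors requires contiguous tensors that do not share storage.
    save_file({k: v.contiguous() for k, v in tensors.items()}, f"pt/{name}.safetensors")

# The result can be read back with load_file (or safe_open, as before).
reloaded = load_file("pt/group1_tensors.safetensors")
print(list(reloaded.keys()))
```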
1.4 Loading the saved tensors
```python
import os
import torch

# Path and list of the .pt files saved in the previous step
pt_path = "pt/"
pt_files = [
    "group1_tensors.pt",
    "group2_tensors.pt"
]

# Load each file and print the metadata of every tensor it contains
for file in pt_files:
    file_path = os.path.join(pt_path, file)
    tensors = torch.load(file_path)
    for k, tensor in tensors.items():
        total_bytes = tensor.numel() * tensor.element_size()
        formatted_bytes_size = "{:,}".format(total_bytes)
        print(f"{file_path}, {k}, {tensor.size()}, {tensor.dtype}, {formatted_bytes_size} bytes")
```
Run result:
```
python test08.py
pt/group1_tensors.pt, transformer.wte.weight, torch.Size([100352, 6144]), torch.bfloat16, 1,233,125,376 bytes
pt/group1_tensors.pt, lm_head.weight, torch.Size([100352, 6144]), torch.bfloat16, 1,233,125,376 bytes
pt/group1_tensors.pt, transformer.norm_f.weight, torch.Size([6144]), torch.bfloat16, 12,288 bytes
pt/group2_tensors.pt, transformer.blocks.0.ffn.experts.mlp.w1, torch.Size([172032, 6144]), torch.bfloat16, 2,113,929,216 bytes
pt/group2_tensors.pt, transformer.blocks.0.ffn.router.layer.weight, torch.Size([16, 6144]), torch.bfloat16, 196,608 bytes
pt/group2_tensors.pt, transformer.blocks.0.norm_attn_norm.attn.Wqkv.weight, torch.Size([8192, 6144]), torch.bfloat16, 100,663,296 bytes
pt/group2_tensors.pt, transformer.blocks.0.norm_attn_norm.attn.out_proj.weight, torch.Size([6144, 6144]), torch.bfloat16, 75,497,472 bytes
pt/group2_tensors.pt, transformer.blocks.0.norm_attn_norm.norm_1.weight, torch.Size([6144]), torch.bfloat16, 12,288 bytes
pt/group2_tensors.pt, transformer.blocks.0.norm_attn_norm.norm_2.weight, torch.Size([6144]), torch.bfloat16, 12,288 bytes
pt/group2_tensors.pt, transformer.blocks.0.ffn.experts.mlp.v1, torch.Size([172032, 6144]), torch.bfloat16, 2,113,929,216 bytes
pt/group2_tensors.pt, transformer.blocks.0.ffn.experts.mlp.w2, torch.Size([172032, 6144]), torch.bfloat16, 2,113,929,216 bytes
```
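One caveat about torch.load: if the tensors had been saved from GPU memory, loading them on a CPU-only machine would need map_location="cpu", and on recent PyTorch versions (1.13+) weights_only=True restricts unpickling to plain tensors, which is a good habit for weight files. A hedged variant of the loading call:

```python
import torch

# map_location pins the tensors to CPU regardless of the device they were saved from;
# weights_only=True refuses to unpickle arbitrary Python objects.
tensors = torch.load("pt/group1_tensors.pt", map_location="cpu", weights_only=True)
print(len(tensors), "tensors loaded")
```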
2. Understanding transformer.blocks.0
From the preceding analysis, each block contains an ffn module and a norm_attn_norm module.