7. Model Conversion
Convert the model's parameters into the layout you want by defining the model structure yourself. Following the model.py file in the https://github.com/karpathy/llama2.c project, create a file named model.py and save it in the newsrc directory:
import math
import struct
import inspect
from dataclasses import dataclass
from typing import Any, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

@dataclass
class ModelArgs:
    # default hyperparameters for the Llama 7B model
    dim: int = 4096
    n_layers: int = 32
    n_heads: int = 32
    n_kv_heads: Optional[int] = None
    vocab_size: int = 32000
    hidden_dim: Optional[int] = None
    multiple_of: int = 256  # MLP hidden layer size will be multiple of this
    norm_eps: float = 1e-5
    max_seq_len: int = 2048
    dropout: float = 0.0

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)  # type: ignore
    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cos = torch.cos(freqs)  # real part
    freqs_sin = torch.sin(freqs)  # imaginary part
    return freqs_cos, freqs_sin

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(shape)

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cos: torch.Tensor,
    freqs_sin: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:

    # reshape xq and xk to match the complex representation
    xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
    xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)

    # reshape freqs_cos and freqs_sin for broadcasting
    freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)

    # apply rotation using real numbers
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # flatten last two dimensions
    xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
    xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)

    return xq_out.type_as(xq), xk_out.type_as(xk)

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout

        # use flash attention or a manual implementation?
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")
            mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
            mask = torch.triu(mask, diagonal=1)
            self.register_buffer("mask", mask)

    def forward(
        self,
        x: torch.Tensor,
        freqs_cos: torch.Tensor,
        freqs_sin: torch.Tensor,
    ):
        bsz, seqlen, _ = x.shape

        # QKV
        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

        # RoPE relative positional embeddings
        xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)

        # grouped multiquery attention: expand out keys and values
        xk = repeat_kv(xk, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)
        xv = repeat_kv(xv, self.n_rep)  # (bs, seqlen, n_local_heads, head_dim)

        # make heads into a batch dimension
        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        # flash implementation
        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(xq, xk, xv, attn_mask=None, dropout_p=self.dropout if self.training else 0.0, is_causal=True)
        else:
            # manual implementation
            scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
            assert hasattr(self, 'mask')
            scores = scores + self.mask[:, :, :seqlen, :seqlen]  # (bs, n_local_heads, seqlen, cache_len + seqlen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores, xv)  # (bs, n_local_heads, seqlen, head_dim)

        # restore time as batch dimension and concat heads
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)

        # final projection into the residual stream
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output

class FeedForward(nn.Module):
    def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
        super().__init__()
        if hidden_dim is None:
            hidden_dim = 4 * dim
            hidden_dim = int(2 * hidden_dim / 3)
            hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))

class TransformerBlock(nn.Module):
    def __init__(self, layer_id: int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=args.hidden_dim,
            multiple_of=args.multiple_of,
            dropout=args.dropout,
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)

    def forward(self, x, freqs_cos, freqs_sin):
        h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

class Transformer(nn.Module):
    last_loss: Optional[torch.Tensor]

    def __init__(self, params: ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = torch.nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id, params))
        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
        self.output = nn.Linear(params.dim, params.vocab_size, bias=False)

        # share the unembedding parameters with the embedding parameters
        self.tok_embeddings.weight = self.output.weight  # https://paperswithcode.com/method/weight-tying

        # some useful precompute for the RoPE relative positional embeddings
        freqs_cos, freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        for pn, p in self.named_parameters():
            if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))

        # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
        self.last_loss = None

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
        _bsz, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        h = self.dropout(h)
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]

        for layer in self.layers:
            h = layer(h, freqs_cos, freqs_sin)
        h = self.norm(h)

        if targets is not None:
            # if we are given some desired targets also calculate the loss
            logits = self.output(h)
            self.last_loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)
        else:
            # inference-time mini-optimization: only forward the output on the very last position
            logits = self.output(h[:, [-1], :])  # note: using list [-1] to preserve the time dim
            self.last_loss = None

        return logits

    def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
        # start with all of the candidate parameters
        param_dict = {pn: p for pn, p in self.named_parameters()}
        # filter out those that do not require grad
        param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
        # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
        # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
        decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
        nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
        optim_groups = [
            {'params': decay_params, 'weight_decay': weight_decay},
            {'params': nodecay_params, 'weight_decay': 0.0}
        ]
        num_decay_params = sum(p.numel() for p in decay_params)
        num_nodecay_params = sum(p.numel() for p in nodecay_params)
        print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
        print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
        # Create AdamW optimizer and use the fused version if it is available
        fused_available = 'fused' in inspect.signature(torch.optim.AdamW).parameters
        use_fused = fused_available and device_type == 'cuda'
        extra_args = dict(fused=True) if use_fused else dict()
        optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas, **extra_args)
        print(f"using fused AdamW: {use_fused}")

        return optimizer

    def estimate_mfu(self, fwdbwd_per_iter, dt):
        """ estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS """
        # first estimate the number of flops we do per iteration.
        # see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
        N = sum(p.numel() for p in self.parameters())
        cfg = self.params
        L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim//cfg.n_heads, cfg.max_seq_len
        flops_per_token = 6*N + 12*L*H*Q*T
        flops_per_fwdbwd = flops_per_token * T
        flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
        # express our flops throughput as ratio of A100 bfloat16 peak flops
        flops_achieved = flops_per_iter * (1.0/dt)  # per second
        flops_promised = 312e12  # A100 GPU bfloat16 peak flops is 312 TFLOPS
        mfu = flops_achieved / flops_promised
        return mfu

    @torch.inference_mode()
    def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):
        """
        Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and complete
        the sequence max_new_tokens times, feeding the predictions back into the model each time.
        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
        Also note this is a super inefficient version of sampling with no key/value cache.
        """
        for _ in range(max_new_tokens):
            # if the sequence context is growing too long we must crop it at block_size
            idx_cond = idx if idx.size(1) <= self.params.max_seq_len else idx[:, -self.params.max_seq_len:]
            # forward the model to get the logits for the index in the sequence
            logits = self(idx_cond)
            logits = logits[:, -1, :]  # crop to just the final time step
            if temperature == 0.0:
                # "sample" the single most likely index
                _, idx_next = torch.topk(logits, k=1, dim=-1)
            else:
                # pluck the logits at the final step and scale by desired temperature
                logits = logits / temperature
                # optionally crop the logits to only the top k options
                if top_k is not None:
                    v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < v[:, [-1]]] = -float('Inf')
                # apply softmax to convert logits to (normalized) probabilities
                probs = F.softmax(logits, dim=-1)
                idx_next = torch.multinomial(probs, num_samples=1)

            # append sampled index to the running sequence and continue
            idx = torch.cat((idx, idx_next), dim=1)

        return idx
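Before loading the full 7B checkpoint, it is worth a quick check that model.py imports and runs on its own. Below is a minimal sketch (not part of the tutorial's files); the tiny hyperparameters are arbitrary and chosen only to keep the forward pass fast:

# sanity check for model.py -- hypothetical helper, run from the newsrc directory
import torch
from model import ModelArgs, Transformer

args = ModelArgs(dim=64, n_layers=2, n_heads=4, vocab_size=256, max_seq_len=128)
model = Transformer(args)
model.eval()

tokens = torch.randint(0, args.vocab_size, (1, 16))  # (batch=1, seqlen=16) random token ids
with torch.no_grad():
    logits = model(tokens)  # without targets, only the last position is returned
print(logits.shape)  # expected: torch.Size([1, 1, 256])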
Next is the driver code, modeled on the export.py file in the https://github.com/karpathy/llama2.c project. Name it test07.py and save it in the newsrc directory:
from transformers import AutoModelForCausalLM
from model import ModelArgs, Transformer
import numpy as np
import torch
from torch import nn

def load_hf_model(model_path):

    # load HF model
    hf_model = AutoModelForCausalLM.from_pretrained(model_path)
    hf_dict = hf_model.state_dict()

    # convert LlamaConfig to ModelArgs
    config = ModelArgs()
    config.dim = hf_model.config.hidden_size
    config.n_layers = hf_model.config.num_hidden_layers
    config.n_heads = hf_model.config.num_attention_heads
    config.n_kv_heads = hf_model.config.num_attention_heads
    config.vocab_size = hf_model.config.vocab_size
    config.hidden_dim = hf_model.config.intermediate_size
    config.norm_eps = hf_model.config.rms_norm_eps
    config.max_seq_len = hf_model.config.max_position_embeddings

    # create a new Transformer object and set weights
    model = Transformer(config)

    model.tok_embeddings.weight = nn.Parameter(hf_dict['model.embed_tokens.weight'])
    model.norm.weight = nn.Parameter(hf_dict['model.norm.weight'])

    # huggingface permutes WQ and WK, this function reverses it
    def permute_reverse(w, n_heads=config.n_heads, dim1=config.dim, dim2=config.dim):
        return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

    for layer in model.layers:
        i = layer.layer_id
        layer.attention_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.input_layernorm.weight'])
        layer.attention.wq.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.q_proj.weight']))
        layer.attention.wk.weight = nn.Parameter(permute_reverse(hf_dict[f'model.layers.{i}.self_attn.k_proj.weight']))
        layer.attention.wv.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.v_proj.weight'])
        layer.attention.wo.weight = nn.Parameter(hf_dict[f'model.layers.{i}.self_attn.o_proj.weight'])
        layer.ffn_norm.weight = nn.Parameter(hf_dict[f'model.layers.{i}.post_attention_layernorm.weight'])
        layer.feed_forward.w1.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.gate_proj.weight'])
        layer.feed_forward.w2.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.down_proj.weight'])
        layer.feed_forward.w3.weight = nn.Parameter(hf_dict[f'model.layers.{i}.mlp.up_proj.weight'])

    # final classifier
    model.output.weight = nn.Parameter(hf_dict['lm_head.weight'])
    model.eval()
    return model

# specify the model path
model_path = "meta-llama/Llama-2-7b-chat-hf"
model = load_hf_model(model_path)

print(model)

for name, param in model.named_parameters():
    print(f"{name}: {param.size()}")
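A subtle point in load_hf_model is permute_reverse: the Hugging Face checkpoint stores the q_proj and k_proj weights permuted to fit its own rotary-embedding layout, while model.py expects the original interleaved real/imaginary layout, so the permutation has to be undone here. The sketch below (hypothetical, not part of test07.py) defines a forward permutation assumed to mirror the one in Hugging Face's Llama conversion script and checks that permute_reverse is its exact inverse:

import torch

n_heads, dim = 32, 4096  # Llama 2 7B values

def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
    # assumed to mirror the permutation applied when the HF checkpoint was created
    return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)

def permute_reverse(w, n_heads=n_heads, dim1=dim, dim2=dim):
    # same function as in test07.py
    return w.view(n_heads, 2, dim1 // n_heads // 2, dim2).transpose(1, 2).reshape(dim1, dim2)

w = torch.randn(dim, dim)
assert torch.equal(permute_reverse(permute(w)), w)  # the round trip recovers the original weight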
Run test07.py:
python newsrc/test07.py
Loading checkpoint shards: 100%|███████| 2/2 [01:25<00:00, 42.94s/it]
Transformer(
  (tok_embeddings): Embedding(32000, 4096)
  (dropout): Dropout(p=0.0, inplace=False)
  (layers): ModuleList(
    (0-31): 32 x TransformerBlock(
      (attention): Attention(
        (wq): Linear(in_features=4096, out_features=4096, bias=False)
        (wk): Linear(in_features=4096, out_features=4096, bias=False)
        (wv): Linear(in_features=4096, out_features=4096, bias=False)
        (wo): Linear(in_features=4096, out_features=4096, bias=False)
        (attn_dropout): Dropout(p=0.0, inplace=False)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
      (feed_forward): FeedForward(
        (w1): Linear(in_features=4096, out_features=11008, bias=False)
        (w2): Linear(in_features=11008, out_features=4096, bias=False)
        (w3): Linear(in_features=4096, out_features=11008, bias=False)
        (dropout): Dropout(p=0.0, inplace=False)
      )
      (attention_norm): RMSNorm()
      (ffn_norm): RMSNorm()
    )
  )
  (norm): RMSNorm()
  (output): Linear(in_features=4096, out_features=32000, bias=False)
)
tok_embeddings.weight: torch.Size([32000, 4096])
layers.0.attention.wq.weight: torch.Size([4096, 4096])
layers.0.attention.wk.weight: torch.Size([4096, 4096])
layers.0.attention.wv.weight: torch.Size([4096, 4096])
layers.0.attention.wo.weight: torch.Size([4096, 4096])
layers.0.feed_forward.w1.weight: torch.Size([11008, 4096])
layers.0.feed_forward.w2.weight: torch.Size([4096, 11008])
layers.0.feed_forward.w3.weight: torch.Size([11008, 4096])
layers.0.attention_norm.weight: torch.Size([4096])
layers.0.ffn_norm.weight: torch.Size([4096])
...
layers.31.attention.wq.weight: torch.Size([4096, 4096])
layers.31.attention.wk.weight: torch.Size([4096, 4096])
layers.31.attention.wv.weight: torch.Size([4096, 4096])
layers.31.attention.wo.weight: torch.Size([4096, 4096])
layers.31.feed_forward.w1.weight: torch.Size([11008, 4096])
layers.31.feed_forward.w2.weight: torch.Size([4096, 11008])
layers.31.feed_forward.w3.weight: torch.Size([11008, 4096])
layers.31.attention_norm.weight: torch.Size([4096])
layers.31.ffn_norm.weight: torch.Size([4096])
norm.weight: torch.Size([4096])
output.weight: torch.Size([32000, 4096])
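The parameter shapes above match the Llama 2 7B architecture, but matching shapes alone do not prove that the conversion is numerically faithful. One way to check, sketched below under the assumption that both the Hugging Face model and the converted model fit in memory at the same time, is to feed both the same token ids and compare the logits at the last position (small numerical differences from different attention kernels are expected):

# verification sketch (hypothetical); assumes `model` is the converted model returned by
# load_hf_model in test07.py and that enough memory is available for both copies
import torch
from transformers import AutoModelForCausalLM

hf_model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
hf_model.eval()

tokens = torch.randint(0, 32000, (1, 16))  # arbitrary token ids; a real tokenized prompt also works

with torch.no_grad():
    hf_logits = hf_model(tokens).logits[:, -1, :]  # (1, vocab_size)
    my_logits = model(tokens)[:, -1, :]            # converted model also returns the last position

print(torch.max(torch.abs(hf_logits - my_logits)))  # should be close to 0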