AI21 Jamba 1.5 系列模型是最先进的混合 SSM-Transformer 指令,遵循基础模型。Jamba 模型是市场上最强大、最高效的长上下文模型,其推理速度比同等大小的领先模型快 2.5 倍。
这些模型展示了卓越的长上下文处理、速度和质量。它们标志着非变压器模型首次成功地扩展到市场领先模型的质量和强度。
Jamba 1.5 Mini(12B 活动/52B 总计)和 Jamba 1.5 Large(94B 活动/398B 总计)还针对业务用例和功能进行了优化,例如函数调用、结构化输出 (JSON) 和接地生成。
这些模型是在 Jamba 开放模型许可证下发布的,该许可证是一种宽松的许可证,允许根据许可条款进行全面研究使用和商业用途。如果您需要根据自己的需求许可该模型,请与我们联系。
Model Details 型号详细信息
- Developed by: AI21
开发者: AI21 - Model type: Joint Attention and Mamba (Jamba)
型号类型:联合关注和曼巴 (Jamba) - License: Jamba Open Model License
许可证:Jamba Open Model License - Context length: 256K
上下文长度:256K - Knowledge cutoff date: March 5, 2024
知识截止日期:3月 5, 2024 - Supported languages: English, Spanish, French, Portuguese, Italian, Dutch, German, Arabic and Hebrew
支持的语言:英语、西班牙语、法语、葡萄牙语、意大利语、荷兰语、德语、阿拉伯语和希伯来语
用法
先决条件
为了运行优化的 Mamba 实现,您首先需要安装并:mamba-ssm
causal-conv1d
您还必须在 CUDA 设备上拥有模型。
1 2 |
pip install -U mamba-ssm pip install -U causal-conv1d |
使用 vLLM 运行模型
使用 Jamba 1.5 Mini 执行高效推理的推荐方法是使用 vLLM。首先,确保安装 vLLM(需要 0.5.4 或更高版本)
1 |
pip install -U vllm |
这里建议从源码构建 vLLM
1 2 3 |
git clone https://github.com/vllm-project/vllm.git cd vllm pip install -e . # This may take 5-10 minutes. |
在下面的示例中,number_gpus
应与要部署 Jamba 1.5 Mini 的 GPU 数量匹配。至少需要 2 个 80GB GPU,或是8个24G GPU.
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
from vllm import LLM, SamplingParams from transformers import AutoTokenizer model = "ai21labs/AI21-Jamba-1.5-Mini" number_gpus = 2 llm = LLM(model=model, max_model_len=200*1024, tensor_parallel_size=number_gpus) tokenizer = AutoTokenizer.from_pretrained(model) messages = [ {"role": "system", "content": "You are an ancient oracle who speaks in cryptic but wise phrases, always hinting at deeper meanings."}, {"role": "user", "content": "Hello!"}, ] prompts = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) sampling_params = SamplingParams(temperature=0.4, top_p=0.95, max_tokens=100) outputs = llm.generate(prompts, sampling_params) generated_text = outputs[0].outputs[0].text print(generated_text) #Output: Seek and you shall find. The path is winding, but the journey is enlightening. What wisdom do you seek from the ancient echoes? |
在8个 RTX 4090(24G)的环境下,max_model_len 需要减少,max_model_len 越大,时间越长。模型在58k的情况下加载时间大概半个小时左右,
在24G GPU 内存下,max_model_len 过大则无法执行,下面是一个简单的对话代码:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from vllm import LLM, SamplingParams from transformers import AutoTokenizer model = "ai21labs/AI21-Jamba-1.5-Mini" number_gpus = 8 llm = LLM(model=model, max_model_len=58*1024, # 200*1024, tensor_parallel_size=number_gpus) tokenizer = AutoTokenizer.from_pretrained(model) # 初始化系统消息 system_message = {"role": "system", "content": "You are an ancient oracle who speaks in cryptic but wise phrases, always hinting at deeper meanings."} while True: # 获取用户输入 user_input = input("You: ") # 如果输入为空或者用户输入 'exit' 则退出循环 if not user_input.strip() or user_input.lower() == "exit": print("Exiting...") break # 构建消息列表 messages = [ system_message, {"role": "user", "content": user_input}, ] # 生成 prompts prompts = tokenizer.apply_chat_template(messages, add_generation_prompt=True, tokenize=False) # 设置生成参数 sampling_params = SamplingParams(temperature=0.4, top_p=0.95, max_tokens=100) # 生成回答 outputs = llm.generate(prompts, sampling_params) # 获取并输出生成的文本 generated_text = outputs[0].outputs[0].text print(f"Oracle: {generated_text}") |
使用 2 个 80GB A100 GPU 上的默认 BF16 精度和默认的 vLLM 配置,您将能够对长达 200K 令牌的提示执行推理。在超过 2 个 80GB GPU 上,您可以轻松适应完整的 256K 环境。
注意:vLLM 的主
分支有一些特定于 Jamba 架构的内存利用率改进,允许在 2 个 80 GPU 上使用完整的 256K 上下文长度。如果您希望使用它们,您可以从源代码构建 vLLM。
ExpertsInt8 量化
我们开发了一种创新且高效的量化技术 ExpertsInt8,专为 vLLM 中部署的 MoE 模型(包括 Jamba 模型)而设计。使用它,您将能够在单个80 GB GPU 上部署 Jamba 1.5 Mini 。
ExpertsInt8 在最新的 vLLM 版本上尚不可用,但它已合并到 main
分支。要使用它,请从源代码构建 vLLM。
使用默认的 vLLM 配置,您可以在单个 80GB A100 GPU 上安装高达 100K 的提示:
1 2 3 4 5 6 7 |
import os os.environ['VLLM_FUSED_MOE_CHUNK_SIZE']='32768' # This is a workaround a bug in vLLM's fused_moe kernel from vllm import LLM llm = LLM(model="ai21labs/AI21-Jamba-1.5-Mini", max_model_len=100*1024, quantization="experts_int8") |
使用transformers
运行模型
以下示例以 BF16 精度将 Jamba 1.5 Mini 加载到 GPU,使用优化的 FlashAttention2 和 Mamba 内核,并使用 Accelerate 在多个 GPU 上并行化模型。请注意,在半精度 (FP16/BF16) 下,Jamba 1.5 Mini 太大,无法安装在单个 80GB GPU 上,因此您至少需要 2 个这样的 GPU。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch from transformers import AutoModelForCausalLM, AutoTokenizer model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-1.5-Mini", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", device_map="auto") tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-1.5-Mini") messages = [ {"role": "system", "content": "You are an ancient oracle who speaks in cryptic but wise phrases, always hinting at deeper meanings."}, {"role": "user", "content": "Hello!"}, ] input_ids = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors='pt').to(model.device) outputs = model.generate(input_ids, max_new_tokens=216) # Decode the output conversation = tokenizer.decode(outputs[0], skip_special_tokens=True) # Split the conversation to get only the assistant's response assistant_response = conversation.split(messages[-1]['content'])[1].strip() print(assistant_response) # Output: Seek and you shall find. The path is winding, but the journey is enlightening. What wisdom do you seek from the ancient echoes? |
注意:transformers
版本 4.44.0 和 4.44.1 存在一个 bug,该 bug 限制了运行 Jamba 架构的能力。请确保您未使用这些版本。
注意:如果您在为优化的 Mamba 内核安装 mamba-ssm
和 causal-conv1d
时遇到问题,您可以在没有它们的情况下运行 Jamba 1.5 Mini,但代价是额外的延迟。为此,请在通过 AutoModelForCausalLM.from_pretained()
加载模型时添加 kwarg use_mamba_kernels=False
。
以 8 位加载模型
使用 8 位精度,可以在单个 80GB GPU 上适应高达 140K 的序列长度。您可以使用 bitsandbytes 轻松地将模型量化为 8 位。为了不降低模型质量,我们建议从量化中排除 Mamba 块:
使用 8 位精度,可以在单个 80GB GPU 上适应高达 140K 的序列长度。您可以使用 bitsandbytes 轻松地将模型量化为 8 位。为了不降低模型质量,我们建议从量化中排除 Mamba 块:
1 2 3 4 5 6 7 |
from transformers import AutoModelForCausalLM, BitsAndBytesConfig quantization_config = BitsAndBytesConfig(load_in_8bit=True, llm_int8_skip_modules=["mamba"]) model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-1.5-Mini", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", quantization_config=quantization_config) |
在 CPU 上加载模型
如果您无法访问 GPU,您还可以在 CPU 上加载和运行 Jamba 1.5 Mini。请注意,这将导致推理性能不佳。
1 2 3 |
from transformers import AutoModelForCausalLM model = AutoModelForCausalLM.from_pretrained("ai21labs/AI21-Jamba-1.5-Mini", use_mamba_kernels=False) |
模型特性
与 Jamba 一起使用的工具
Jamba 1.5 根据 Huggingface 的工具使用 API,支持工具使用能力。用户定义的工具入到聊天模板的专用部分中,模型经过训练可以识别该部分。
给定包含工具的对话,模型可以输出内容和/或工具调用。工具调用在助手消息中被格式化为 json 格式的词典列表,包装在专用的特殊令牌中,如以下示例所示。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-1.5-Mini") messages = [ { "role": "user", "content": "What's the weather like right now in Jerusalem and in London?" } ] tools = [ { 'type': 'function', 'function': { 'name': 'get_current_weather', 'description': 'Get the current weather', 'parameters': { 'type': 'object', 'properties': { 'location': {'type': 'string', 'description': 'The city and state, e.g. San Francisco, CA'}, 'format': {'type': 'string', 'enum': ['celsius', 'fahrenheit'], 'description': 'The temperature unit to use. Infer this from the users location.'} }, 'required': ['location', 'format'] } } } ] prompt = tokenizer.apply_chat_template( messages, tools=tools, tokenize=False, ) |
输出:
1 2 3 4 |
<tool_calls>[ {"name": "get_current_weather", "arguments": {"location": "Jerusalem", "format": "celsius"}}, {"name": "get_current_weather", "arguments": {"location": "celsius", "format": "celsius"}} ]</tool_calls> |
将工具响应反馈到模型中
现在模型调用了工具,我们需要将工具响应反馈给模型。您可以从模型的响应中解析工具调用,并在 assistant 消息中传播 tool_calls
字段,或者只是将格式化的响应保留在 content
字段中。请注意,响应的顺序应与相应 Assistant 消息中工具调用的顺序一致。这部分介绍如何操作。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-1.5-Mini") # Note that you must send the tool responses in the same order as the model called the tools: messages = [ { "role": "user", "content": "What's the weather like right now in Jerusalem and in London?" }, { "role": "assistant", "content": null, "tool_calls": [ { "name": "get_current_weather", "arguments": "{\"location\": \"Jerusalem\", \"format\": \"celsius\"}" }, { "name": "get_current_weather", "arguments": "{\"location\": \"London\", \"format\": \"celsium\"}" } ] }, { "role": "tool", "content": "The weather in Jerusalem is 18 degrees celsius." }, { "role": "tool", "content": "The weather in London is 8 degrees celsius." } ] tool_use_prompt = tokenizer.apply_chat_template( messages, tools=tools, tokenize=False, ) |
输出示例:
1 |
The weather in Jerusalem is currently 18 degrees Celsius. In London, it is 8 degrees Celsius. |
与Jamba接地的一代
一个常见的用例LLMs是接地生成和 RAG,其中模型需要根据给定的文档集或文档片段回答问题或遵循说明。为了标准化这一过程,Jamba 在其聊天模板中接受了特定的 “documents” 部分的训练。该模型经过训练来处理此部分,当以这种方式格式化任务时,接地生成任务显示出更好的性能。
与工具类似,除了对话之外,工具还作为模型的外部参数提供,文档也以类似的方式提供。为了支持文档级元数据,文档被定义为具有您选择的键值的字典。这些在 chat 模板中进行了格式化。获得特殊处理的两个键是 “title” 和 “text” ,前者的格式显示在文档顶部,后者是必填字段,用于定义文档的实际文本。
将文档附加到 Jamba 1.5 提示符
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-1.5-Mini") messages = [ { "role": "user", "content": "Who wrote Harry Potter?" } ] documents = [ { "text": "Harry Potter is a series of seven fantasy novels written by British author J. K. Rowling.", "title": "Harry Potter" }, { "text": "The Great Gatsby is a novel by American writer F. Scott Fitzgerald.", "title": "The Great Gatsby", "country": "United States", "genre": "Novel" } ] prompt = tokenizer.apply_chat_template( messages, documents=documents, tokenize=False, ) # Output: J. K. Rowling |
JSON 模式
Jamba 1.5 使用特定的 “旋钮” 进行训练,这有助于引导模型实现常见的请求行为。通过在系统消息中包含特定的预定义文本来启用每种行为。为了便于使用,我们已将它们作为标志包含在 Jamba 1.5 的聊天模板中,因此可以通过向聊天模板传递适当的参数来切换它们。
Jamba 1.5 经过训练,可以在请求时生成有效的 JSON。它自然而然地这样做,但是当 JSON 模式旋钮被激活时,有效 json 的可能性会大大增加。在 JSON 模式下,Jamba 1.5 将尝试输出有效的 JSON,而不管用户请求如何。但是,强烈建议在用户请求或系统消息中指定有关预期 json 架构的信息,以获得最佳结果,如以下示例所示。
Jamba 1.5 中 JSON 旋钮的使用
1 2 3 4 5 6 7 8 9 10 11 12 13 |
from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-1.5-Mini") messages = [ {'role':'user', 'content':'Describe the first American president. Include year of birth (number) and name (string).'} ] prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False, knobs={"response_format": "json_object", "is_set": True}) #Output: "{ "year of birth": 1732, "name": "George Washington." }" |
微调示例
以下示例使用 huggingface/trl 中的 SFTTrainer
,因此请确保已安装它:
1 |
pip install trl |
以下是在 bfloat16 中使用 LoRA PEFT 进行微调的示例(需要 ~130GB GPU RAM,例如 2xA100 80GB):
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
import torch from transformers import AutoTokenizer, AutoModelForCausalLM from datasets import load_dataset from trl import SFTTrainer, SFTConfig from peft import LoraConfig tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-1.5-Mini") model = AutoModelForCausalLM.from_pretrained( "ai21labs/AI21-Jamba-1.5-Mini", device_map="auto", torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", ) lora_config = LoraConfig( r=8, target_modules=[ "embed_tokens", "x_proj", "in_proj", "out_proj", # mamba "gate_proj", "up_proj", "down_proj", # mlp "q_proj", "k_proj", "v_proj", "o_proj", # attention ], task_type="CAUSAL_LM", bias="none", ) dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train") training_args = SFTConfig( output_dir="/dev/shm/results", logging_dir="./logs", num_train_epochs=2, per_device_train_batch_size=4, learning_rate=1e-5, logging_steps=10, gradient_checkpointing=True, max_seq_length=4096, save_steps=100, ) trainer = SFTTrainer( model=model, tokenizer=tokenizer, args=training_args, peft_config=lora_config, train_dataset=dataset, ) trainer.train() |
请注意,示例中的数据集使用对话格式(带有消息
列),因此 SFTTrainer
会自动应用 Jamba 的聊天模板,如 TRL 文档中所述。
QLoRA 示例
要在单个 80GB GPU 上进行微调,您可以读取 QLoRA,它将 LoRA 与量化为 4 位的冻结模型相结合:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig from datasets import load_dataset from trl import SFTTrainer, SFTConfig from peft import LoraConfig tokenizer = AutoTokenizer.from_pretrained("ai21labs/AI21-Jamba-1.5-Mini") quantization_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, ) model = AutoModelForCausalLM.from_pretrained( "ai21labs/AI21-Jamba-1.5-Mini", device_map="auto", quantization_config=quantization_config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2", ) lora_config = LoraConfig( r=8, target_modules=[ "embed_tokens", "x_proj", "in_proj", "out_proj", # mamba "gate_proj", "up_proj", "down_proj", # mlp "q_proj", "k_proj", "v_proj", "o_proj", # attention ], task_type="CAUSAL_LM", bias="none", ) dataset = load_dataset("philschmid/dolly-15k-oai-style", split="train") training_args = SFTConfig( output_dir="./results", logging_dir="./logs", num_train_epochs=2, per_device_train_batch_size=8, learning_rate=1e-5, logging_steps=1, gradient_checkpointing=True, gradient_checkpointing_kwargs={"use_reentrant": False}, save_steps=100, max_seq_length=4096, ) trainer = SFTTrainer( model=model, tokenizer=tokenizer, args=training_args, peft_config=lora_config, train_dataset=dataset, ) trainer.train() |
注意:上面的示例需要 4 位量化的 bitsandbytes
包:
1 |
pip install bitsandbytes |