This tutorial uses the surya-ocr library to implement local RAG, together with the embedding model bert-base-multilingual-cased (multilingual) and the inference model Qwen1.5-1.8B-Chat running under Ollama.
Surya is a document OCR toolkit that can process PDF files, images, and more:
- OCR in 90+ languages, competitive with cloud services
- Line-level text detection in any language
- Layout analysis (detection of tables, images, headers, etc.)
- Reading-order detection
Test environment: Windows

1. Install the required libraries
```python
!pip install surya-ocr
!pip install python-magic
!pip install -U sentence_transformers
!pip install -U numpy
```
2. Import all the libraries
```python
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from typing import List, Dict, Any
import json
import requests
import io
import pypdfium2

from surya.detection import batch_text_detection
from surya.layout import batch_layout_detection
from surya.model.detection.segformer import load_model, load_processor
from surya.model.recognition.model import load_model as load_rec_model
from surya.model.recognition.processor import load_processor as load_rec_processor
from surya.model.ordering.processor import load_processor as load_order_processor
from surya.model.ordering.model import load_model as load_order_model
from surya.ordering import batch_ordering
from surya.postprocessing.heatmap import draw_polys_on_image
from surya.ocr import run_ocr
from surya.postprocessing.text import draw_text_on_image
from PIL import Image
from surya.languages import CODE_TO_LANGUAGE
from surya.input.langs import replace_lang_with_code
from surya.schema import OCRResult, TextDetectionResult, LayoutResult, OrderResult
from surya.settings import settings
```
3. PDF-handling function definitions
```python
def open_pdf(pdf_file):
    # Open the file and read its contents into memory
    with open(pdf_file, 'rb') as file:
        pdf_data = file.read()
    stream = io.BytesIO(pdf_data)
    return pypdfium2.PdfDocument(stream)

def page_count(pdf_file):
    doc = open_pdf(pdf_file)
    return len(doc)

def get_page_image(pdf_file, page_num, dpi=96):
    doc = open_pdf(pdf_file)
    renderer = doc.render(
        pypdfium2.PdfBitmap.to_pil,
        page_indices=[page_num - 1],
        scale=dpi / 72,
    )
    png = list(renderer)[0]
    png_image = png.convert("RGB")
    return png_image

def ocr(img, langs: List[str]) -> (Image.Image, OCRResult):
    replace_lang_with_code(langs)
    img_pred = run_ocr([img], [langs], det_model, det_processor, rec_model, rec_processor)[0]
    bboxes = [l.bbox for l in img_pred.text_lines]
    text = [l.text for l in img_pred.text_lines]
    rec_img = draw_text_on_image(bboxes, text, img.size, langs, has_math="_math" in langs)
    return rec_img, img_pred

def load_det_cached():
    checkpoint = settings.DETECTOR_MODEL_CHECKPOINT
    return load_model(checkpoint=checkpoint), load_processor(checkpoint=checkpoint)

def load_rec_cached():
    return load_rec_model(), load_rec_processor()
```
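Note that both `page_count` and `get_page_image` call `open_pdf`, so the PDF bytes are re-read from disk for every page. A minimal sketch of one way to avoid that, using `functools.lru_cache` (the `load_bytes` helper is hypothetical, not part of the tutorial code):

```python
from functools import lru_cache

@lru_cache(maxsize=8)
def load_bytes(path: str) -> bytes:
    # Read the file once; repeated calls with the same path hit the cache
    with open(path, 'rb') as f:
        return f.read()

# open_pdf could then build the document from the cached bytes:
#   return pypdfium2.PdfDocument(io.BytesIO(load_bytes(pdf_file)))
```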
4. Ollama chat function
```python
def chat(messages: List[Dict[str, Any]], model: str = '', stream: bool = False) -> str:
    json_data = {
        'model': model,
        'messages': messages,
        'stream': stream,
    }
    response = requests.post(
        'http://127.0.0.1:11434/api/chat',
        json=json_data
    )
    # print("Request Headers:", response.request.headers)
    # print("Request Body:", response.request.body)
    # print("Response Status Code:", response.status_code)
    # print("Response Body:", response.text)
    response_object = json.loads(response.text)
    return response_object['message']['content']
```

(The return annotation is `str` rather than `requests.Response`, since the function returns the message content extracted from the response body.)
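The function posts a JSON body to Ollama's `/api/chat` endpoint; with `stream` set to `False`, Ollama returns a single JSON object whose `message.content` field holds the reply. A sketch of the payload structure it builds (values illustrative), which can be inspected without a running server:

```python
import json

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
]
# Same shape as the body chat() sends to http://127.0.0.1:11434/api/chat
payload = {"model": "qwen:1.8b-chat", "messages": messages, "stream": False}
body = json.dumps(payload)
print(body)
```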
5. Embedding model
```python
# A small model used to create the embeddings
# embedder = SentenceTransformer('Qwen/Qwen1.5-0.5B-Chat')
embedder = SentenceTransformer('bert-base-multilingual-cased')
```
6. Parameter initialization
```python
languages = ["English"]

# Initialize an empty list to store the embeddings
embeddings_list = []
documents = []

det_model, det_processor = load_det_cached()
rec_model, rec_processor = load_rec_cached()
```
`languages=["English"]` supports multiple languages; you can add others yourself, for example Chinese:

```python
languages = ["English", "Chinese"]
```
7. Process the PDF file
```python
in_file = "data/Learning to Model the World with Language.pdf"
total_pages = page_count(in_file)  # renamed so we don't shadow the page_count function
print(f"page_count={total_pages}")

# Loop over every page
for page_number in range(total_pages):
    pil_image = get_page_image(in_file, page_number + 1)
    rec_img, pred = ocr(pil_image, languages)
    document = "\n".join([p.text for p in pred.text_lines])
    embeddings = embedder.encode(document)
    embeddings_list.append(embeddings)
    print(f"page {page_number + 1}, {len(document)} chars:", document)
    # print(f"embeddings:{len(embeddings)},", embeddings)
    documents.append(document)
```
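The loop above embeds each whole page as a single vector. For long pages, a common refinement is to split the text into overlapping character windows and embed each chunk separately; a minimal sketch (the `chunk_text` helper and its sizes are hypothetical, not part of the tutorial code):

```python
from typing import List

def chunk_text(text: str, chunk_size: int = 500, overlap: int = 100) -> List[str]:
    # Slide a window of chunk_size characters, keeping `overlap` characters
    # of context shared between consecutive chunks
    chunks = []
    step = chunk_size - overlap
    for start in range(0, max(len(text) - overlap, 1), step):
        chunks.append(text[start:start + chunk_size])
    return chunks
```

Each chunk would then go through `embedder.encode` and be appended to `embeddings_list` exactly as the per-page text is above.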
8. Build the FAISS index and run inference
```python
# Build the FAISS index
if embeddings_list:
    embeddings_array = np.vstack(embeddings_list)
    index = faiss.IndexFlatL2(embeddings_array.shape[1])
    index.add(embeddings_array.astype('float32'))

    # Handle the user question and run inference
    # question = "What is the theme of the document?"
    question = "这份文档的主题是什么?"  # "What is the theme of this document?"
    query_embedding = embedder.encode([question])[0].astype('float32')

    # Retrieve the most relevant document segments
    combined_segments = ""
    k = 3  # how many relevant documents to retrieve
    D, I = index.search(np.array([query_embedding]), k=k)
    print("D:", D)
    print("I:", I)
    # print("Top", k, "most relevant document segments:")
    for idx, segment_index in enumerate(I[0]):
        most_relevant_segment = documents[segment_index]
        # print(f"{idx+1}: {most_relevant_segment}\n")
        combined_segments += " " + most_relevant_segment

    prompt = combined_segments + "\n\n###\n\n" + question + "\n\n用中文回答"  # "Answer in Chinese"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    response = chat(messages=messages, model="qwen:1.8b-chat")
    print("Answer to the question:", response)
else:
    print("No embeddings found. Please check your data.")
```
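`faiss.IndexFlatL2` performs exact nearest-neighbour search under squared L2 distance, which is what `D` (distances) and `I` (indices) contain. A plain-NumPy sketch of the same computation, on toy vectors rather than the tutorial's embeddings:

```python
import numpy as np

def l2_search(index_vectors: np.ndarray, query: np.ndarray, k: int):
    # Squared L2 distance from the query to every indexed vector
    dists = ((index_vectors - query) ** 2).sum(axis=1)
    order = np.argsort(dists)[:k]  # indices of the k closest vectors
    return dists[order], order     # analogous to FAISS's D and I

vecs = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]], dtype='float32')
q = np.array([0.9, 0.1], dtype='float32')
D_toy, I_toy = l2_search(vecs, q, k=2)
```

Because the search is exhaustive, results are exact; FAISS's approximate index types trade that exactness for speed on large collections.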
Part of the output:

```
D: [[143.52664 143.77243 145.9426 ]]
I: [[1 0 2]]
Answer to the question: The theme of this document is the implementation of a multimodal language-understanding system. Specifically, the system adopts multimodal input processing and can acquire human language information through multiple channels such as vision and hearing.
In a multimodal language-understanding system, text is taken as an input signal and converted into visual or auditory features such as images and audio. These features are fused and mapped by multimodal neural networks, enabling understanding and manipulation of complex natural-language environments.
To strengthen its generalization and adaptability, the system also incorporates rich knowledge graphs and semantic models to improve cross-modal understanding and generation. In addition, to make full use of these multimodal channels, its design takes extensibility and adaptability into account so that the system's performance can be fully realized.
```