此笔记本演示如何使用 Ada 嵌入来实现语义代码搜索。在本演示中,我们使用自己的 openai-python 代码存储库。我们实现了一个简单的文件解析版本,并从 python 文件中提取函数,可以嵌入、索引和查询。
1. 帮助程序函数
我们首先设置了一些简单的解析函数,允许我们从代码库中提取重要信息。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import pandas as pd from pathlib import Path DEF_PREFIXES = ['def ', 'async def '] NEWLINE = '\n' def get_function_name(code): """ Extract function name from a line beginning with 'def' or 'async def'. """ for prefix in DEF_PREFIXES: if code.startswith(prefix): return code[len(prefix): code.index('(')] def get_until_no_space(all_lines, i): """ Get all lines until a line outside the function definition is found. """ ret = [all_lines[i]] for j in range(i + 1, len(all_lines)): if len(all_lines[j]) == 0 or all_lines[j][0] in [' ', '\t', ')']: ret.append(all_lines[j]) else: break return NEWLINE.join(ret) def get_functions(filepath): """ Get all functions in a Python file. """ with open(filepath, 'r') as file: all_lines = file.read().replace('\r', NEWLINE).split(NEWLINE) for i, l in enumerate(all_lines): for prefix in DEF_PREFIXES: if l.startswith(prefix): code = get_until_no_space(all_lines, i) function_name = get_function_name(code) yield { 'code': code, 'function_name': function_name, 'filepath': filepath, } break def extract_functions_from_repo(code_root): """ Extract all .py functions from the repository. """ code_files = list(code_root.glob('**/*.py')) num_files = len(code_files) print(f'Total number of .py files: {num_files}') if num_files == 0: print('Verify openai-python repo exists and code_root is set correctly.') return None all_funcs = [ func for code_file in code_files for func in get_functions(str(code_file)) ] num_funcs = len(all_funcs) print(f'Total number of functions extracted: {num_funcs}') return all_funcs |
2. 数据加载
需要克隆 openai-python
1 |
!git clone https://github.com/openai/openai-python |
我们将首先加载 openai-python 文件夹,并使用我们上面定义的函数提取所需的信息。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 |
# Set user root directory to the 'openai-python' repository import os # 获取当前工作目录 current_directory = os.getcwd() # Assumes the 'openai-python' repository exists in the user's root directory current_directory = current_directory + '\\openai-python' # 假设code_root是你想要搜索的目录的路径 code_root = Path(current_directory) #print(code_root) # Extract all functions from the repository all_funcs = extract_functions_from_repo(code_root) |
运行结果:
1 2 |
Total number of .py files: 289 Total number of functions extracted: 349 |
现在我们有了内容,我们可以将数据传递给 text-embedding-3-small
模型并取回我们的向量嵌入。
1 2 3 4 5 6 7 |
from embeddings_utils import get_embedding df = pd.DataFrame(all_funcs) df['code_embedding'] = df['code'].apply(lambda x: get_embedding(x, model='text-embedding-3-small')) df['filepath'] = df['filepath'].map(lambda x: Path(x).relative_to(code_root)) df.to_csv("data/code_search_openai-python.csv", index=False) df.head() |
运行结果:

3. 测试
让我们通过一些简单的查询来测试我们的终结点。如果您熟悉 openai-python
存储库,您会发现我们只需简单的英文描述即可轻松找到我们正在寻找的函数。
我们定义了一个 search_functions 方法,该方法获取包含嵌入、查询字符串和其他一些配置选项的数据。搜索数据库的过程是这样的:
- 1.我们首先将查询字符串 (code_query) 嵌入到
text-embedding-3-small
.这里的推理是,像“反转字符串的函数”这样的查询字符串和像“def reverse(string): return string[::-1]”这样的函数在嵌入时会非常相似。 - 2.然后,我们计算查询字符串嵌入与数据库中所有数据点之间的余弦相似度。这给出了每个点和我们的查询之间的距离。
- 3.最后,我们按所有数据点与查询字符串的距离对它们进行排序,并返回函数参数中请求的结果数。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
from embeddings_utils import cosine_similarity def search_functions(df, code_query, n=3, pprint=True, n_lines=7): embedding = get_embedding(code_query, model='text-embedding-3-small') df['similarities'] = df.code_embedding.apply(lambda x: cosine_similarity(x, embedding)) res = df.sort_values('similarities', ascending=False).head(n) if pprint: for r in res.iterrows(): print(f"{r[1].filepath}:{r[1].function_name} score={round(r[1].similarities, 3)}") print("\n".join(r[1].code.split("\n")[:n_lines])) print('-' * 70) return res res = search_functions(df, 'fine-tuning input data validation logic', n=3) |
运行结果:

1 |
res = search_functions(df, 'find common suffix', n=2, n_lines=10) |
运行结果:

1 |
res = search_functions(df, 'Command line interface for fine-tuning', n=1, n_lines=20) |
运行结果:

1 |
res = search_functions(df, 'find main', n=1, n_lines=20) |
运行结果:

1 |
res = search_functions(df, 'Completions API tests', n=3) |
运行结果:
