使用Parquet和Polars高效处理文本嵌入

本文探讨如何利用Parquet文件和Polars库高效存储和处理文本嵌入数据,包括嵌入生成、相似性计算、元数据过滤等技术细节,提供比传统向量数据库更轻量灵活的解决方案。

使用Parquet和Polars高效处理文本嵌入

文本嵌入,特别是由大型语言模型生成的现代嵌入,是生成式AI热潮中最实用的应用之一。嵌入是一组代表对象的数字列表:对于文本嵌入,它们可以代表单词、句子、完整段落和文档,并具有惊人的区分度。

最近,我为截至2025年2月Aetherdrift扩展发布的所有独特Magic: the Gathering卡牌创建了文本嵌入:总共32,254张。通过这些嵌入,我可以找到卡牌之间通过其卡牌设计编码表示的数学相似性,包括所有机械属性,如卡牌名称、卡牌成本、卡牌文本,甚至卡牌稀有度。

Wrath of God卡牌及其前四个最相似卡牌

此外,我可以为所有这些卡牌创建一个有趣的2D UMAP投影,这也识别出了有趣的模式:

UMAP降维过程

UMAP降维过程还隐式地将Magic卡牌聚类到逻辑簇中,例如按卡牌颜色和卡牌类型。

这些Magic卡牌嵌入不仅用于漂亮的数据可视化,还用于特殊目的。嵌入是使用新但被低估的gte-modernbert-base嵌入模型生成的,详细过程可在GitHub存储库中找到。嵌入本身(包括重现2D UMAP可视化的坐标值)可作为Hugging Face数据集使用。

大多数涉及嵌入生成的教程忽略了一个明显问题:生成文本嵌入后如何处理它们?常见解决方案是使用向量数据库,如faiss或qdrant,甚至是云托管服务如Pinecone。但这些并不容易使用:faiss有令人困惑的配置选项,qdrant需要使用Docker容器托管存储服务器,Pinecone可能很快变得非常昂贵,其免费Starter层级有限。

关于文本嵌入,许多人不知道的是,如果数据不太大,你不需要向量数据库来计算最近邻相似性。使用numpy和我的Magic卡牌嵌入,一个32,254个float32嵌入的2D矩阵,维度为768D(“较小"LLM嵌入模型的常见维度),占用94.49 MB系统内存,这对于现代个人计算机来说相对较低,并且可以适应云VM的免费使用层级。如果查询向量和嵌入本身都是单位归一化的(许多嵌入生成器默认归一化),那么查询和嵌入之间的矩阵点积会产生[-1, 1]之间的余弦相似性,其中分数越高越好/越相似。由于点积是线性代数的一个基本方面,numpy的实现非常快:借助额外的numpy排序技巧,在我的M3 Pro MacBook Pro上,平均只需1.08毫秒即可计算所有32,254个点积,找到前3个最相似的嵌入,并返回它们的矩阵idx和余弦相似性分数。

1
2
3
4
5
6
7
8
9
def fast_dot_product(query, matrix, k=3):
    dot_products = query @ matrix.T

    idx = np.argpartition(dot_products, -k)[-k:]
    idx = idx[np.argsort(dot_products[idx])[::-1]]

    score = dot_products[idx]

    return idx, score

在向量数据库的大多数实现中,一旦插入嵌入,它们就会以专有序列化格式卡在那里,并且你被锁定在该库和服务中。如果你只是构建个人宠物项目或健全检查嵌入以确保结果良好,那将带来巨大的摩擦。例如,当我想试验嵌入时,我在带有GPU的云服务器上生成它们,因为基于LLM的嵌入模型通常在没有GPU的情况下生成速度很慢,然后将它们下载到我的个人计算机上。处理嵌入的最佳方式是什么,以便它们可以轻松在机器之间移动,并且采用非专有格式?

经过大量个人试错,答案是Parquet文件,这仍然有惊人的细微差别。但在讨论为什么Parquet文件好之前,让我们先讨论如何不存储嵌入。

最差的存储嵌入方式

错误但不幸常见的方式是以文本格式(如CSV文件)存储嵌入。文本数据比float32数据大得多:例如,具有全精度的十进制数(例如2.145829051733016968e-02)作为float32是32位/4字节,而作为文本表示(在这种情况下是24个ASCII字符)是24字节,大了6倍。当CSV保存和加载时,数据必须在numpy和数组的字符串表示之间序列化,这增加了显著开销。尽管如此,在某中心官方教程中,他们使用pandas将嵌入保存为CSV,并承认警告"因为此示例仅使用几千个字符串,我们将它们存储在CSV文件中。(对于更大的数据集,使用向量数据库,这将更高效。)"。在Magic卡牌嵌入的情况下,pandas-to-CSV在任何编码选项中表现最差:稍后详细说明原因。

Numpy有本地方法将嵌入保存和加载为.txt,这很简单:

1
2
3
np.savetxt("embeddings_txt.txt", embeddings)

embeddings_r = np.loadtxt("embeddings_txt.txt", dtype=np.float32, delimiter=" ")

生成的文件不仅需要几秒钟保存和加载,而且巨大:631.5 MB!

顺便说一句,HTTP API(如某中心的嵌入API)确实通过文本传输嵌入,这增加了不必要的延迟和带宽开销。我希望更多嵌入提供商提供gRPC API,允许传输二进制float32数据以获得性能提升:例如,某机构的Python SDK就是这样做的。

第二种错误方法是将嵌入矩阵保存到磁盘作为Python pickle对象,它使用原生pickle库的几行代码将其内存表示存储在磁盘上。不幸的是,pickling在机器学习行业中很常见,因为许多ML框架(如scikit-learn)没有简单的方法来序列化编码器和模型。但它有两个主要警告:pickled文件是巨大的安全风险,因为它们可以执行任意代码,并且pickled文件可能无法保证在其他机器或Python版本上打开。现在是2025年,如果可以,请停止pickling。

在Magic卡牌嵌入的情况下,它确实可以即时保存/加载,磁盘上的文件大小为94.49 MB:与其内存消耗相同,并且大约是文本大小的1/6:

1
2
3
4
5
with open("embeddings_matrix.pkl", "wb") as f:
    pickle.dump(embeddings, f)

with open("embeddings_matrix.pkl", "rb") as f:
    embeddings_r = pickle.load(f)

但仍然有更好和更简单的方法。

预期但不太好的存储嵌入方式

Numpy本身有一种规范的方法来保存和加载矩阵——令人恼火的是,出于兼容性原因,默认保存为pickle,但幸运的是可以通过设置allow_pickle=False来禁用:

1
2
3
np.save("embeddings_matrix.npy", embeddings, allow_pickle=False)

embeddings_r = np.load("embeddings_matrix.npy", allow_pickle=False)

文件大小和I/O速度与pickle方法相同。

这有效——并且是我使用了一段时间的东西——但在此过程中暴露了另一个问题:我们如何将元数据(在这种情况下是Magic卡牌)映射到嵌入?目前,我们使用最相似匹配的idx来执行对源数据的高效批量查找。在这种情况下,行数完全匹配卡牌数量,但如果需要更改嵌入矩阵,例如添加或删除卡牌及其嵌入,会发生什么?如果你想添加数据集过滤器会发生什么?它变成了一团糟,不可避免地导致技术债务。

解决方案是将元数据(如卡牌名称、卡牌文本和属性)与其嵌入共置:这样,如果它们后来被添加、删除或排序,结果将保持不变。现代向量数据库(如qdrant和某机构)就是这样做的,能够在查询最相似向量的同时过滤和排序元数据。在numpy本身中这样做是一个坏主意,因为它更优化数字而不是其他数据类型(如字符串),这些数据类型可用的操作有限。

解决方案是寻找另一种可以同时存储元数据和嵌入的文件格式,答案是Parquet文件。但关于什么是最好的交互方式,有一个兔子洞。

什么是Parquet文件?

Parquet,由开源Apache Parquet项目开发,是一种用于处理列式数据的文件格式,但尽管于2013年首次发布,直到最近才在数据科学社区中流行起来。¹ Parquet最相关的特性是生成的文件为每列类型化,并且这种类型化包括嵌套列表,例如嵌入只是一个float32值列表。作为奖励,列式格式允许下游库选择性地非常快速地保存/加载它们,比CSV快得多,并且很少出现解析错误。文件格式还允许高效的压缩和解压缩,但对于嵌入来说效果较差,因为冗余数据很少。

对于Parquet文件I/O,标准方法是使用Apache Arrow协议,该协议在内存中是列式的,这与磁盘上的Parquet存储介质互补。但你如何使用Arrow?

如何在Python中使用Parquet文件进行嵌入?

理想情况下,我们需要一个可以轻松处理嵌套数据、可以与numpy互操作以序列化为矩阵,并且可以运行快速点积的库。

在Python中本地与Parquet交互的官方Arrow库是pyarrow。这里,我有一个示例Parquet文件,使用[剧透]生成,包含卡牌元数据和嵌入列,每行的嵌入对应于该卡牌。

1
df = pa.parquet.read_table("mtg-embeddings.parquet")

但pyarrow不是DataFrame库,尽管数据在Table中,但很难切片和访问:文档建议如果你需要更高级的操作,可以导出到pandas。

其他更传统的数据科学库可以直接利用pyarrow。最流行的是pandas本身,它可以读取/写入Parquet。有很多很多关于使用pandas的资源,因此它通常是数据科学从业者的首选。

1
2
df = pd.read_parquet("mtg-embeddings.parquet", columns=["name", "embedding"])
df

对于嵌入用例,有一个主要弱点:pandas非常不擅长嵌套数据。从上图你会看到嵌入列看起来是一个数字列表,但它实际上是一个numpy对象列表,这是一种非常低效的数据类型,这也是为什么我怀疑将其写入CSV非常慢。简单地用df[“embedding”].to_numpy()转换为numpy会得到一个1D数组,这肯定是错误的,并且尝试将其转换为float32不起作用。我发现从pandas嵌入列提取嵌入矩阵的最佳方法是np.vstack()嵌入,例如np.vstack(df[“embedding”].to_numpy()),这确实会产生一个(32254, 768) float32矩阵,如预期。这增加了大量计算和内存开销,以及不必要的numpy数组副本。最后,在计算候选查询和嵌入矩阵之间的点积之后,可以使用df.loc[idx]检索具有最相似值的行元数据。²

然而,还有另一个更新的表格数据库,不仅比pandas更快,而且对嵌套数据有适当支持。该库是polars。

Polars的力量

Polars是一个相对较新的Python库,主要用Rust编写并支持Arrow,这使其比pandas和许多其他DataFrame库具有巨大的性能提升。在Magic卡牌的情况下,32k行远非"大数据”,使用高性能库的收益较小,但有一些意外特性恰好完美适用于嵌入用例。

与pandas一样,你使用read_parquet()读取parquet文件:

1
2
df = pl.read_parquet("mtg-embeddings.parquet", columns=["name", "embedding"])
df

与pandas相比,表格输出有显著差异:它还报告其列的数据类型,更重要的是,它显示嵌入列由数组组成,全部为float32,全部长度为768。这是一个很好的开始!

polars也有一个to_numpy()函数。与pandas不同,如果你在列上调用to_numpy()作为Series,例如df[’embedding’].to_numpy(),返回的对象是一个numpy 2D矩阵:不需要np.vstack()。如果你查看该函数的文档,有一个好奇的特性:

此操作仅在必要时复制数据。当以下所有条件成立时,转换是零复制:[…]

零复制!对于列式存储的嵌入,条件将始终成立,但你可以设置allow_copy=False以防万一抛出错误。

相反,如果你想将2D嵌入矩阵添加到现有DataFrame并共置每个嵌入的相应元数据,例如在批量生成数千个嵌入后想要保存和下载生成的Parquet,就像向DataFrame添加列一样简单。

1
2
3
df = pl.with_columns(embedding=embeddings)

df.write_parquet("mtg-embeddings.parquet")

现在,让我们使用所有Magic卡牌元数据测试速度。如果我们对Magic卡牌执行嵌入相似性,但事先根据用户参数动态过滤数据集(因此同时过滤候选嵌入,因为它们共置),并像往常一样快速执行相似性计算,会怎样?让我们尝试Lightning Helix,一张即使不玩Magic的人也能自我解释其效果的卡牌。

Lightning Helix的最相似卡牌确实有类似效果,尽管"Lightning"卡牌造成伤害是Magic中的常见比喻。Warleader’s Helix是Lightning Helix的直接参考。

现在我们还可以找到Lightning Helix的类似卡牌,但带有过滤器。在这种情况下,让我们寻找一张Sorcery(类似于Instants但往往更强,因为它们有游戏限制)并且黑色作为其颜色之一。这将候选限制到原始数据集的约3%。给定query_embed,结果代码将如下所示:

1
2
3
4
5
6
7
8
df_filter = df.filter(
    pl.col("type").str.contains("Sorcery"),
    pl.col("manaCost").str.contains("B"),
)

embeddings_filter = df_filter["embedding"].to_numpy(allow_copy=False)
idx, _ = fast_dot_product(query_embed, embeddings_filter, k=4)
related_cards = df_filter[idx]

顺便说一句,在polars中,你可以用df[idx]调用DataFrame的行子集,这比pandas及其df.iloc[idx]无限好。

结果类似卡牌:

在这种情况下,相似性侧重于卡牌文本相似性,这些卡牌有近乎相同的文本。Smiting Helix也是Lightning Helix的直接参考。

速度方面,代码平均运行约1.48毫秒,或比计算所有点积慢约37%,因此过滤仍然有一些开销,这并不奇怪,因为过滤的dataframe确实复制了嵌入。总的来说,对于业余项目来说,它仍然足够快。

我创建了一个交互式Colab Notebook,你可以在其中生成任何Magic卡牌的相似性,并应用任何你想要的过滤器!

扩展到向量数据库

再次强调,所有这些都假设你将嵌入用于较小/非商业项目。如果你扩展到数十万个嵌入,用于查找相似性的parquet和点积方法应该仍然可以,但如果它是业务关键应用,查询向量数据库的边际成本可能低于快速相似性查找的边际收入。决定如何做出这些权衡是MLOps的有趣部分!

如果向量数量太大无法放入内存,但你不想全力投入向量数据库,另一个可能值得考虑的选项是使用可以现在支持向量嵌入的老式数据库。值得注意的是,SQLite数据库只是一个单一的可移植文件,然而与它们交互比polars的read_parquet()和write_parquet()有更多的技术开销和考虑。SQLite中向量数据库的一个显著实现是sqlite-vec扩展,它也允许同时过滤和相似性计算。

下次你处理嵌入时,考虑你是否真的需要向量数据库。对于许多应用,Parquet文件和polars的组合提供了你需要的一切:高效存储、快速相似性搜索和容易的元数据过滤。有时最简单的解决方案是最好的。

用于处理Magic卡牌数据、创建嵌入和绘制UMAP 2D投影的代码都可在此GitHub存储库中找到。

¹ 我怀疑广泛Parquet支持的主要瓶颈是某中心Excel和其他电子表格软件缺乏对该格式的本地支持。如果/当它们这样做时,每个数据科学家都会非常、非常高兴!

² 某中心使用pandas查找共置相似性的方法是手动迭代整个dataframe,计算候选和查询之间每行的每个余弦相似性,然后按分数排序。该实现肯定不可扩展。

comments powered by Disqus
使用 Hugo 构建
主题 StackJimmy 设计