栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

python 余弦相似度计算(faiss)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

python 余弦相似度计算(faiss)

faiss

faiss是为稠密向量提供高效相似度搜索和聚类的框架。由Facebook AI Research研发。

详见github https://github.com/facebookresearch/faiss

faiss常用的两个相似度搜索是L2欧氏距离搜索和余弦距离搜索(注意不是余弦相似度)

简单的使用流程:

import faiss 
index = faiss.IndexFlatL2(d)  # 建立L2索引,d是向量维度
index = faiss.IndexFlatIP(d) # 建立Inner product索引
index.add(train)  # 添加矩阵
D,I = index.search(test, k) # (D.shape = test.shape[0] * k, I同理)

上述代码实现了对于test向量(也可以是矩阵)索引train中L2距离最近的k个向量,返回其具体distance和索引index

IndexFlatIP()函数实现的是余弦距离的计算也就是 x y t xy^t xyt,显然,当向量范数不为一的情况下不能等同于余弦相似度 x y t ∣ ∣ x ∣ ∣ ∣ ∣ y ∣ ∣ frac{xy^t}{||x||||y||} ∣∣x∣∣∣∣y∣∣xyt​

在许多论文特别是需要计算索引的时候,相似度往往选择余弦相似度,因此在这里记录一下如何实现:

cosine similarity 实现
train = np.array([[1.0,1.0],[2.5,0],[0,2.5],[1.5,0.5]]).astype('float32') # 注意 必须为float32类型
test = np.array([[0.5,0.5]]).astype('float32')
print('L2 norm of train', np.linalg.norm(train[0]))
print('L2 norm of test', np.linalg.norm(test))
faiss.normalize_L2(train)
faiss.normalize_L2(test)
print('L2 norm of train', np.linalg.norm(train[0]))
print('L2 norm of test', np.linalg.norm(test))

L2 norm of train 1.4142135
L2 norm of test 0.70710677
L2 norm of train 0.99999994
L2 norm of test 0.99999994

对于被索引矩阵和查询向量,都先经过L2归一化,(normlize_L2函数)

定义索引函数

def KNN_cos(train_set, test_set, n_neighbours):
    index = faiss.IndexFlatIP(train_set.shape[1])
    index.add(train_set)
    D, I = index.search(test_set, n_neighbours)
    return D,I	
    

测试

Distance, Index = KNN_cos(train, test,3)

Distance: (array([[0.99999994, 0.8944272 , 0.70710677]], dtype=float32),
Index: array([[0, 3, 2]]))

在github上看到有人给出这样的解决方法

num_vectors = 1000000
vector_dim = 1024
vectors = np.random.rand(num_vectors, vector_dim)

#sample index code
quantizer = faiss.IndexFlatIP(1024)
index = faiss.IndexIVFFlat(quantizer, vector_dim, int(np.sqrt(num_vectors)), faiss.METRIC_INNER_PRODUCT) # 利用IVFFLat提升效率
train_vectors = vectors[:int(num_vectors/2)].copy()
faiss.normalize_L2(train_vectors)
index.train(train_vectors)
faiss.normalize_L2(vectors)
index.add(vectors)
#index creation done

#let's search
query_vector = np.random.rand(10, 1024)
faiss.normalize_L2(query_vector)
D, I = index.search(query_vector, 100)

print(D)

其实这里做了个提速:利用IVFlat先进行聚类再索引,提升效率,详见可以看官方源码

总结

关于faiss库进行索引查询还有很多操作,特别是对于海量数据,合理的利用faiss可以极大提升效率。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/422660.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号