,
✏️ 编者按
每年暑期,Milvus 社区都会携手中科院软件所,在「开源之夏」活动中为高校学生们准备丰富的工程项目,并安排导师答疑解惑。张煜旻同学在「开源之夏」活动中表现优秀,他相信进一寸有进一寸的欢喜,尝试在贡献开源的过程中超越自我。
他的项目为 Milvus 数据库的向量查询操作提供精度控制,能让开发者自定义返回精度,在减少内存消耗的同时,提高了返回结果的可读性。
想要了解更多优质开源项目和项目经验分享?请戳:有哪些值得参与的开源项目?
项目简介
项目名称:支持指定搜索时返回的距离精度
学生简介:张煜旻,中国科学院大学电子信息软件工程专业硕士在读
项目导师:Zilliz 软件工程师张财
导师评语:张煜旻同学优化了 Milvus 数据库的查询功能,使其在搜索时可以用指定精度去进行查询,使搜索过程更灵活,用户可以根据自己的需求用不同的精度进行查询,给用户带来了便利。
支持指定搜索时返回的距离精度
任务简介
在进行向量查询时,搜索请求返回 id 和 distance 字段,其中的 distance 字段类型是浮点数。Milvus 数据库所计算的距离是一个 32 位浮点数,但是 Python SDK 返回并以 64 位浮点显示它,导致某些精度无效。本项目的贡献是,支持指定搜索时返回的距离精度,解决了在 Python 端显示时部分精度无效的情况,并减少部分内存开销。
项目目标
解决计算结果和显示精度不匹配的问题
支持搜索时返回指定的距离精度
补充相关文档
项目步骤
前期调研,理解 Milvus 整体框架
明确各模块之间的调用关系
设计解决方案和确认结果
项目综述
什么是 Milvus 数据库?
Milvus 是一款开源向量数据库,赋能 AI 应用和向量相似度搜索。在系统设计上, Milvus 数据库的前端有方便用户使用的 Python SDK(Client);在 Milvus 数据库的后端,整个系统分为了接入层(Access Layer)、协调服务(Coordinator Server)、执行节点(Worker Node)和存储服务(Storge)四个层面:
(1)接入层(Access Layer):系统的门面,包含了一组对等的 Proxy 节点。接入层是暴露给用户的统一 endpoint,负责转发请求并收集执行结果。
(2)协调服务(Coordinator Service):系统的大脑,负责分配任务给执行节点。共有四类协调者角色:root 协调者、data 协调者、query 协调者和 index 协调者。
(3)执行节点(Worker Node):系统的四肢,执行节点只负责被动执行协调服务发起的读写请求。目前有三类执行节点:data 节点、query 节点和 index 节点。
(4)存储服务(Storage):系统的骨骼,是所有其他功能实现的基础。Milvus 数据库依赖三类存储:元数据存储、消息存储(log broker)和对象存储。从语言角度来看,则可以看作三个语言层,分别是 Python 构成的 SDK 层、Go 构成的中间层和 C++ 构成的核心计算层。
Milvus 数据库的架构图
向量查询 Search 时,到底发生了什么?
在 Python SDK 端,当用户发起一个 Search API 调用时,这个调用会被封装成 gRPC 请求并发送给 Milvus 后端,同时 SDK 开始等待。而在后端,Proxy 节点首先接受了从 Python SDK 发送过来的请求,然后会对接受的请求进行处理,最后将其封装成 message,经由 Producer 发送到消费队列中。当消息被发送到消费队列后,Coordinator 将会对其进行协调,将信息发送到合适的 query node 中进行消费。而当 query node 接收到消息后,则会对消息进行进一步的处理,最后将信息传递给由 C++ 构成的计算层。在计算层,则会根据不同的情形,调用不同的计算函数对向量间的距离进行计算。当计算完成后,结果则会依次向上传递,直到到达 SDK 端。
解决方案设计
通过前文简单介绍,我们对向量查询的过程有了一个大致的概念。同时,我们也可以清楚地认识到,为了完成查询目标,我们需要对 Python 构成的 SDK 层、Go 构成的中间层和 C++ 构成的计算层都进行修改,修改方案如下:
1. 在 Python 层中的修改步骤:
为向量查询 Search 请求添加一个 round_decimal 参数,从而确定返回的精度信息。同时,需要对参数进行一些合法性检查和异常处理,从而构建 gRPC 的请求:
round_decimal = param_copy("round_decimal", 3)
if not isinstance(round_decimal, (int, str))
raise ParamError("round_decimal must be int or str")
try:
round_decimal = int(round_decimal)
except Exception:
raise ParamError("round_decimal is not illegal")
if round_decimal < 0 or round_decimal > 6:
raise ParamError("round_decimal must be greater than zero and less than seven")
if not instance(params, dict):
raise ParamError("Search params must be a dict")
search_params = {"anns_field": anns_field, "topk": limit, "metric_type": metric_type, "params": params, "round_decimal": round_decimal}
2. 在 Go 层中的修改步骤:
在 task.go 文件中添加 RoundDecimalKey 这个常量,保持风格统一并方便后续调取:
const ( InsertTaskName = "InsertTask" CreateCollectionTaskName = "CreateCollectionTask" DropCollectionTaskName = "DropCollectionTask" SearchTaskName = "SearchTask" RetrieveTaskName = "RetrieveTask" QueryTaskName = "QueryTask" AnnsFieldKey = "anns_field" TopKKey = "topk" MetricTypeKey = "metric_type" SearchParamsKey = "params" RoundDecimalKey = "round_decimal" HasCollectionTaskName = "HasCollectionTask" DescribeCollectionTaskName = "DescribeCollectionTask"
接着,修改 PreExecute 函数,获取 round_decimal 的值,构建 queryInfo 变量,并添加异常处理:
searchParams, err := funcutil.GetAttrByKeyFromRepeatedKV(SearchParamsKey, st.query.SearchParams)
if err != nil {
return errors.New(SearchParamsKey + " not found in search_params")
}
roundDecimalStr, err := funcutil.GetAttrByKeyFromRepeatedKV(RoundDecimalKey, st.query.SearchParams)
if err != nil {
return errors.New(RoundDecimalKey + "not found in search_params")
}
roundDeciaml, err := strconv.Atoi(roundDecimalStr)
if err != nil {
return errors.New(RoundDecimalKey + " " + roundDecimalStr + " is not invalid")
}
queryInfo := &planpb.QueryInfo{
Topk: int64(topK),
MetricType: metricType,
SearchParams: searchParams,
RoundDecimal: int64(roundDeciaml),
}
同时,修改 query 的 proto 文件,为 QueryInfo 添加 round_decimal 变量:
message QueryInfo {
int64 topk = 1;
string metric_type = 3;
string search_params = 4;
int64 round_decimal = 5;
}
3. 在 C++ 层中的修改步骤:
在 SearchInfo 结构体中添加新的变量 round_decimal_ ,从而接受 Go 层传来的 round_decimal 值:
struct SearchInfo {
int64_t topk_;
int64_t round_decimal_;
FieldOffset field_offset_;
MetricType metric_type_;
nlohmann::json search_params_;
};
在 ParseVecNode 和 PlanNodeFromProto 函数中,SearchInfo 结构体需要接受 Go 层中 round_decimal 值:
std::unique_ptrParser::ParseVecNode(const Json& out_body) { Assert(out_body.is_object()); Assert(out_body.size() == 1); auto iter = out_body.begin(); auto field_name = FieldName(iter.key()); auto& vec_info = iter.value(); Assert(vec_info.is_object()); auto topk = vec_info["topk"]; AssertInfo(topk > 0, "topk must greater than 0"); AssertInfo(topk < 16384, "topk is too large"); auto field_offset = schema.get_offset(field_name); auto vec_node = [&]() -> std::unique_ptr { auto& field_meta = schema.operator[](field_name); auto data_type = field_meta.get_data_type(); if (data_type == DataType::VECTOR_FLOAT) { return std::make_unique (); } else { return std::make_unique (); } }(); vec_node->search_info_.topk_ = topk; vec_node->search_info_.metric_type_ = GetMetricType(vec_info.at("metric_type")); vec_node->search_info_.search_params_ = vec_info.at("params"); vec_node->search_info_.field_offset_ = field_offset; vec_node->search_info_.round_decimal_ = vec_info.at("round_decimal"); vec_node->placeholder_tag_ = vec_info.at("query"); auto tag = vec_node->placeholder_tag_; AssertInfo(!tag2field_.count(tag), "duplicated placeholder tag"); tag2field_.emplace(tag, field_offset); return vec_node; }
std::unique_ptrProtoParser::PlanNodeFromProto(const planpb::PlanNode& plan_node_proto) { // TODO: add more buffs Assert(plan_node_proto.has_vector_anns()); auto& anns_proto = plan_node_proto.vector_anns(); auto expr_opt = [&]() -> std::optional { if (!anns_proto.has_predicates()) { return std::nullopt; } else { return ParseExpr(anns_proto.predicates()); } }(); auto& query_info_proto = anns_proto.query_info(); SearchInfo search_info; auto field_id = FieldId(anns_proto.field_id()); auto field_offset = schema.get_offset(field_id); search_info.field_offset_ = field_offset; search_info.metric_type_ = GetMetricType(query_info_proto.metric_type()); search_info.topk_ = query_info_proto.topk(); search_info.round_decimal_ = query_info_proto.round_decimal(); search_info.search_params_ = json::parse(query_info_proto.search_params()); auto plan_node = [&]() -> std::unique_ptr { if (anns_proto.is_binary()) { return std::make_unique (); } else { return std::make_unique (); } }(); plan_node->placeholder_tag_ = anns_proto.placeholder_tag(); plan_node->predicate_ = std::move(expr_opt); plan_node->search_info_ = std::move(search_info); return plan_node; }
在 SubSearchResult 类添加新的成员变量 round_decimal,同时修改每一处的 SubSearchResult 变量声明:
class SubSearchResult {
public:
SubSearchResult(int64_t num_queries, int64_t topk, MetricType metric_type)
: metric_type_(metric_type),
num_queries_(num_queries),
topk_(topk),
labels_(num_queries * topk, -1),
values_(num_queries * topk, init_value(metric_type)) {
}
在 SubSearchResult 类添加一个新的成员函数,以便最后对每一个结果进行四舍五入精度控制:
void
SubSearchResult::round_values() {
if (round_decimal_ == -1)
return;
const float multiplier = pow(10.0, round_decimal_);
for (auto it = this->values_.begin(); it != this->values_.end(); it++) {
*it = round(*it * multiplier) / multiplier;
}
}
为 SearchDataset 结构体添加新的变量 round_decimal,同时修改每一处的 SearchDataset 变量声明:
struct SearchDataset {
MetricType metric_type;
int64_t num_queries;
int64_t topk;
int64_t round_decimal;
int64_t dim;
const void* query_data;
};
修改 C++ 层中各个距离计算函数(FloatSearch、BinarySearchBruteForceFast 等等),使其接受 round_decomal 值:
StatusFloatSearch(const segcore::SegmentGrowingImpl& segment, const query::SearchInfo& info, const float* query_data, int64_t num_queries, int64_t ins_barrier, const BitsetView& bitset, SearchResult& results) { auto& schema = segment.get_schema(); auto& indexing_record = segment.get_indexing_record(); auto& record = segment.get_insert_record(); // step 1: binary search to find the barrier of the snapshot // auto del_barrier = get_barrier(deleted_record_, timestamp);#if 0 auto bitmap_holder = get_deleted_bitmap(del_barrier, timestamp, ins_barrier); Assert(bitmap_holder); auto bitmap = bitmap_holder->bitmap_ptr;#endif // step 2.1: get meta // step 2.2: get which vector field to search auto vecfield_offset = info.field_offset_; auto& field = schema[vecfield_offset]; AssertInfo(field.get_data_type() == DataType::VECTOR_FLOAT, "[FloatSearch]Field data type isn't VECTOR_FLOAT"); auto dim = field.get_dim(); auto topk = info.topk_; auto total_count = topk * num_queries; auto metric_type = info.metric_type_; auto round_decimal = info.round_decimal_; // step 3: small indexing search // std::vector final_uids(total_count, -1); // std::vector final_dis(total_count, std::numeric_limits::max()); SubSearchResult final_qr(num_queries, topk, metric_type, round_decimal); dataset::SearchDataset search_dataset{metric_type, num_queries, topk, round_decimal, dim, query_data}; auto vec_ptr = record.get_field_data(vecfield_offset); int current_chunk_id = 0;
SubSearchResult
BinarySearchBruteForceFast(MetricType metric_type,
int64_t dim,
const uint8_t* binary_chunk,
int64_t size_per_chunk,
int64_t topk,
int64_t num_queries,
int64_t round_decimal,
const uint8_t* query_data,
const faiss::BitsetView& bitset) {
SubSearchResult sub_result(num_queries, topk, metric_type, round_decimal);
float* result_distances = sub_result.get_values();
idx_t* result_labels = sub_result.get_labels();
int64_t code_size = dim / 8;
const idx_t block_size = size_per_chunk;
raw_search(metric_type, binary_chunk, size_per_chunk, code_size, num_queries, query_data, topk, result_distances,
result_labels, bitset);
sub_result.round_values();
return sub_result;
}
结果确认
1. 对 Milvus 数据库进行重新编译:
2. 启动环境容器:
3. 启动 Milvus 数据库:
4.构建向量查询请求:
5. 确认结果,默认保留 3 位小数,0 舍去:
总结和感想
参加这次的夏季开源活动,对我来说是非常宝贵的经历。在这次活动中,我第一次尝试阅读开源项目代码,第一次尝试接触多语言构成的项目,第一次接触到 Make、gRPc、pytest 等等。在编写代码和测试代码阶段,我也遇到来许多意想不到的问题,例如,「奇奇怪怪」的依赖问题、由于 Conda 环境导致的编译失败问题、测试无法通过等等。面对这些问题,我渐渐学会耐心细心地查看报错日志,积极思考、检查代码并进行测试,一步一步缩小错误范围,定位错误代码并尝试各种解决方案。
通过这次的活动,我吸取了许多经验和教训,同时也十分感谢张财导师,感谢他在我开发过程中耐心地帮我答疑解惑、指导方向!同时,希望大家能多多关注 Milvus 社区,相信一定能够有所收获!
最后,欢迎大家多多与我交流( deepmin@mail.deepexplore.top ),我主要的研究方向是自然语言处理,平时喜欢看科幻小说、动画和折腾服务器个人网站,每日闲逛 Stack Overflow 和GitHub。我相信进一寸有进一寸的欢喜,希望能和你一起共同进步。
Zilliz 以重新定义数据科学为愿景,致力于打造一家全球领先的开源技术创新公司,并通过开源和云原生解决方案为企业解锁非结构化数据的隐藏价值。
Zilliz 构建了 Milvus 向量数据库,以加快下一代数据平台的发展。Milvus 数据库是 LF AI & Data 基金会的毕业项目,能够管理大量非结构化数据集,在新药发现、推荐系统、聊天机器人等方面具有广泛的应用。
解锁更多应用场景



