预置Data Search实现
更新时间:2026-05-13
主流程代码
PYTHON
# Define a `main` function that receives `params`, which holds the configured
# input variables.
# The function must return a dict as the output variable (see the output
# parameter docs for its fields).
# Read input variables with params['name'] or params.get('name').
# Runtime: Python 3.
# Preinstalled package: databuilder_logic_sdk (see the product docs for usage).


def main(params):
    """Run recall, optional rerank, top-k truncation and optional chunk expansion.

    Args:
        params: Node input variables (query, ontology, recall, top_k, rerank,
            metadata_filters, chunk_expansion, ...).

    Returns:
        ``{"code": int, "message": str, "chunk": list}``. ``code`` is -1 with
        the error text in ``message`` when any step fails, including invalid
        ``params`` (previously those raised out of ``main``).
    """

    def _error(exc):
        # Same shape/text as BasePipeline._format_error so callers see a single
        # error format (the old message carried a stray leading space).
        return {
            "code": -1,
            "message": str(exc) or "unknown error",
            "chunk": [],
        }

    # step 0 --- validate params; step 1 --- retrieve
    try:
        context = SearchContext.from_params(params)
        retriever = Retriever()
        retrieve_result = retriever.retrieve(context=context)
    except Exception as exc:
        return _error(exc)

    chunks = retrieve_result.get("chunk") or []
    if not chunks:
        # Nothing to rerank/expand; this also passes pipeline error payloads through.
        return retrieve_result

    try:
        # step 2 --- rerank
        if context.rerank_enabled:
            reranker = Reranker()
            # Non-string content is sent as "" so positions stay aligned with `chunks`.
            rerank_inputs: list[str] = []
            for chunk in chunks:
                content = chunk.get("content")
                rerank_inputs.append(content if isinstance(content, str) else "")
            rerank_response = reranker.rerank(
                query=context.query,
                chunks=rerank_inputs,
                top_k=context.top_k,
            )
            retrieve_result["chunk"] = reranker.sort(chunks, rerank_response.results)

        # Pipelines return up to recall_top_n candidates; the node returns top_k.
        if len(retrieve_result["chunk"]) > context.top_k:
            retrieve_result["chunk"] = retrieve_result["chunk"][: context.top_k]

        # step 3 --- content expansion
        chunk_expansion = params.get("chunk_expansion") or {}
        if chunk_expansion:
            window_size = chunk_expansion.get("window_size") or 0
            retrieve_result["chunk"] = retriever.expand_chunks_after_rerank(
                context,
                retrieve_result["chunk"],
                window_size=window_size,
            )
    except Exception as exc:
        return _error(exc)

    return retrieve_result
方法实现
PYTHON
1from __future__ import annotations
2from collections import defaultdict
3from collections.abc import Iterable
4import json
5import builtins
6
7from typing import Any, Sequence
8from databuilder_logic_sdk.core import (
9 create_embedding_service,
10 create_ontology_service,
11 create_rerank_service,
12)
13from databuilder_logic_sdk.utils import logger
14
# Default minimum chunk score; SearchContext.from_params requires [0, 1].
DEFAULT_SCORE_THRESHOLD = 0.1
# Recall strategy used when params["recall"] gives no type.
DEFAULT_RECALL_TYPE = "fulltext"
# Embedding model identifier (not referenced by the code visible here).
DEFAULT_EMBEDDING_MODEL = "text-embed_7b_bf16"

# The ontology `contains` (ann_distance) filter may return 0 rows for very
# small limits (e.g. limit=1) even when larger limits return results, so
# semantic/hybrid recall never asks for fewer candidates than this.
MIN_SEMANTIC_RECALL_TOP_N = 5

# Raw BM25 scores are clipped at this value and divided by it to land in [0, 1].
BM25_NORMALIZATION_CAP = 50.0
# Presumably consumed by hybrid ranking code outside this view — TODO confirm.
HYBRID_AUTO_PASS_THRESHOLD = 0.9

# Values of the type column: paragraph rows vs. the sentence-level rows that
# point to a paragraph through the parent id column.
CHUNK_TYPE_CHUNK = "chunk"
CHUNK_TYPE_SENTENCE = "sentence"
CHUNK_TYPE_CUSTOM_SENTENCE = "custom_sentence"
31
32
class SearchContext:
    """Validated search settings shared by the retriever and recall pipelines.

    Build instances with `from_params`, which validates and normalizes the
    raw node parameters; the constructor itself performs no validation.
    """

    def __init__(
        self,
        query: str,
        ontology_name: str,
        object_type: str,
        fulltext_column: str,
        semantic_column: str,
        chunk_id_column: str,
        parent_id_column: str,
        type_column: str,
        recall_type: str,
        top_k: int,
        score_threshold: float,
        recall_top_n: int,
        vec_weight: float,
        rerank_enabled: bool,
        metadata_filters: dict,
    ) -> None:
        """
        Initialize the search context.

        Args:
            query: Search query (stripped, non-empty when built by from_params).
            ontology_name: Ontology name.
            object_type: Ontology object type (used as the DSL `apiName`).
            fulltext_column: Column used for fulltext (`match_any`) recall.
            semantic_column: Column holding embeddings for semantic recall.
            chunk_id_column: Chunk ID column.
            parent_id_column: Column linking sentence rows to their paragraph.
            type_column: Column holding the row type (chunk / sentence / ...).
            recall_type: Recall strategy, e.g. "fulltext", "semantic", "hybrid".
            top_k: Number of results returned by the node.
            score_threshold: Minimum chunk score to keep.
            recall_top_n: Number of candidates requested per recall query.
            vec_weight: Weight of vector similarity in hybrid mixing, in [0, 1].
            rerank_enabled: Whether to rerank after recall.
            metadata_filters: `{"operator", "field", "value"}` filter, or {}.
        """
        self.query = query
        self.ontology_name = ontology_name
        self.object_type = object_type
        self.fulltext_column = fulltext_column
        self.semantic_column = semantic_column
        self.chunk_id_column = chunk_id_column
        self.parent_id_column = parent_id_column
        self.type_column = type_column
        self.recall_type = recall_type
        self.top_k = top_k
        self.score_threshold = score_threshold
        self.recall_top_n = recall_top_n
        self.vec_weight = vec_weight
        self.rerank_enabled = rerank_enabled
        self.metadata_filters = metadata_filters

    @staticmethod
    def _parse_flag(value: Any) -> bool:
        """Interpret a boolean-ish parameter: a JSON bool or "true"/"1"/"yes"/"y"."""
        if isinstance(value, bool):
            return value
        if isinstance(value, str):
            return value.strip().lower() in {"true", "1", "yes", "y"}
        return False

    @classmethod
    def from_params(cls, params: dict[str, Any]) -> SearchContext:
        """Validate raw node parameters and build a SearchContext.

        Raises:
            ValueError: if a required parameter is missing, or a value has the
                wrong type or is out of range.
        """
        query = (params.get("query") or "").strip()
        if not query:
            raise ValueError("query cannot be empty")

        ontology = params.get("ontology") or {}
        ontology_name = ontology.get("ontology_name") or ""
        object_type = ontology.get("object_type") or ""
        if not ontology_name or not object_type:
            raise ValueError("ontology_name and object_type are required")

        fulltext_column = ontology.get("fulltext_column")
        semantic_column = ontology.get("semantic_column")

        recall = params.get("recall") or {}
        # An explicit null/empty type falls back to the default instead of
        # being stringified to "None" when the context is built.
        recall_type = recall.get("type") or DEFAULT_RECALL_TYPE

        if recall_type in {"fulltext", "hybrid"} and not fulltext_column:
            raise ValueError("fulltext_column is required for fulltext/hybrid recall")
        if recall_type in {"semantic", "hybrid"} and not semantic_column:
            raise ValueError("semantic_column is required for semantic/hybrid recall")

        top_k = params.get("top_k") or ""
        if not top_k:
            raise ValueError("top_k is required")
        try:
            top_k = int(top_k)
        except (TypeError, ValueError):
            raise ValueError("top_k must be an integer") from None
        if top_k <= 0:
            raise ValueError("top_k must be positive")

        raw_threshold = params.get("score_threshold")
        if raw_threshold is None:
            raw_threshold = DEFAULT_SCORE_THRESHOLD
        try:
            score_threshold = float(raw_threshold)
        except (TypeError, ValueError):
            raise ValueError("score_threshold must be a number") from None
        # The chained comparison also rejects NaN, which `< 0 or > 1` let through.
        if not 0 <= score_threshold <= 1:
            raise ValueError("score_threshold must be between 0 and 1")

        recall_top_n = recall.get("top_n")
        if recall_top_n is None:
            recall_top_n = top_k * 2
        else:
            try:
                recall_top_n = int(recall_top_n)
            except (TypeError, ValueError):
                raise ValueError("recall.top_n must be an integer") from None
            if recall_top_n <= 0:
                raise ValueError("recall.top_n must be positive")

        vec_weight = recall.get("vec_weight")
        if vec_weight is None:
            vec_weight = 0.5
        else:
            try:
                vec_weight = float(vec_weight)
            except (TypeError, ValueError):
                raise ValueError("recall.vec_weight must be a number") from None
            if not 0 <= vec_weight <= 1:
                raise ValueError("recall.vec_weight must be between 0 and 1")

        # JSON booleans used to be ignored (only strings were parsed).
        rerank_enabled = cls._parse_flag(params.get("rerank"))

        if rerank_enabled and recall_top_n < top_k:
            logger.warning(
                "recall top_n is smaller than top_k; rerank will receive fewer candidates",
                extra={
                    "top_k": top_k,
                    "recall_top_n": recall_top_n,
                },
            )
        if recall_type in {"semantic", "hybrid"} and recall_top_n < MIN_SEMANTIC_RECALL_TOP_N:
            logger.warning(
                "semantic recall_top_n too small for contains; bumping",
                extra={
                    "original": recall_top_n,
                    "bumped": MIN_SEMANTIC_RECALL_TOP_N,
                    "recall_type": recall_type,
                },
            )
            recall_top_n = MIN_SEMANTIC_RECALL_TOP_N

        metadata_filters = params.get("metadata_filters") or {}
        if not isinstance(metadata_filters, dict):
            raise ValueError("metadata_filters must be a dict")
        if metadata_filters:
            operator = metadata_filters.get("operator") or ""
            if operator:
                if operator not in {"in", "notin", "not in"}:
                    raise ValueError('operator must be one of "in", "notin", "not in"')
                if metadata_filters.get("field") != "docId":
                    raise ValueError("field currently can only be docId")
                if not isinstance(metadata_filters.get("value"), list):
                    raise ValueError("value must be a list for 'in'/'notin'")
            else:
                # A filter without an operator has no effect; normalize it away.
                metadata_filters = {}

        # chunk_expansion is validated here but read again from params by the caller.
        chunk_expansion = params.get("chunk_expansion") or {}
        if not isinstance(chunk_expansion, dict):
            raise ValueError("chunk_expansion must be a dict")
        if chunk_expansion:
            expansion_type = chunk_expansion.get("type") or ""
            if expansion_type and expansion_type != "window_expansion":
                raise ValueError("currently chunk_expansion type only support window_expansion")
            window_size = chunk_expansion.get("window_size") or 0
            # bool is a subclass of int; reject it explicitly.
            if isinstance(window_size, bool) or not isinstance(window_size, int):
                raise ValueError("window_size must be int or not empty")
            if window_size < 0 or window_size >= 100:
                raise ValueError("invalid window_size")

        return cls(
            query=query,
            ontology_name=ontology_name,
            object_type=object_type,
            fulltext_column=str(fulltext_column),
            semantic_column=str(semantic_column),
            chunk_id_column=str(ontology.get("chunk_id_column", "chunkId")),
            parent_id_column=str(ontology.get("parent_id_column", "parentId")),
            type_column=str(ontology.get("type_column", "type")),
            recall_type=str(recall_type),
            top_k=top_k,
            score_threshold=score_threshold,
            recall_top_n=recall_top_n,
            vec_weight=vec_weight,
            rerank_enabled=rerank_enabled,
            metadata_filters=metadata_filters,
        )
218
219
class BasePipeline:
    """Helpers shared by the recall pipelines.

    Covers ontology querying and response unpacking, row field extraction,
    score filtering/sorting, metadata-filter injection and output formatting.
    """

    # Numeric lists longer than this (e.g. query embeddings) are replaced by a
    # placeholder when a DSL is echoed into an error message.
    _DSL_VECTOR_PREVIEW_LIMIT = 16

    def __init__(self, ontology_service: Any) -> None:
        # Client exposing `search(dsl=...)`; see `_ontology_search`.
        self.ontology_service = ontology_service

    @staticmethod
    def _format_output(chunks: list[ScoredChunk]) -> dict[str, Any]:
        """Wrap chunks in the success response (code 0, empty message)."""
        return {
            "code": 0,
            "message": "",
            "chunk": [chunk.to_dict() for chunk in chunks],
        }

    @staticmethod
    def _format_error(exc: Exception) -> dict[str, Any]:
        """Wrap an exception in the error response (code -1, no chunks)."""
        error_message = str(exc) or "unknown error"
        return {
            "code": -1,
            "message": f"{error_message}",
            "chunk": [],
        }

    @staticmethod
    def _response_rows(response: Any) -> list[dict[str, Any]]:
        """Return the `result["data"]` rows of an ontology response ([] if empty).

        Raises:
            RuntimeError: if `response.code` is not a success code.
            TypeError: if `result` is not a dict or `data` is not a list.
        """
        code = builtins.getattr(response, "code", None)
        result = builtins.getattr(response, "result", None) or {}
        if code not in (None, 0, "0", "SUCCESS"):
            sql = result.get("sql") if isinstance(result, dict) else None
            raise RuntimeError(
                f"ontology search failed: code={code} message={builtins.getattr(response, 'message', '')} sql={sql}"
            )
        if not isinstance(result, dict):
            # Previously surfaced as an opaque AttributeError on `.get`.
            raise TypeError("ontology response result is not a dict")
        data = result.get("data")
        if not data:
            return []
        if not isinstance(data, list):
            raise TypeError("ontology response data is not a list")
        return data

    def _ontology_search(self, *, dsl: dict[str, Any]) -> list[dict[str, Any]]:
        """Run `dsl` through the ontology service and return its rows.

        Raises:
            RuntimeError: wrapping any response error; the DSL is attached for
                debugging with long numeric lists (embeddings) summarized.
        """
        response = self.ontology_service.search(dsl=dsl)
        try:
            return self._response_rows(response)
        except Exception as exc:
            raise RuntimeError(
                f"ontology search failed: {exc}. dsl={self._summarize_dsl(dsl)}"
            ) from exc

    @classmethod
    def _summarize_dsl(cls, value: Any) -> Any:
        """Return a copy of `value` with long numeric lists replaced by a placeholder.

        Error messages reach callers through `_format_error`; without this a
        full query embedding would be embedded in them.
        """
        if isinstance(value, dict):
            return {key: cls._summarize_dsl(item) for key, item in value.items()}
        if isinstance(value, (list, tuple)):
            if len(value) > cls._DSL_VECTOR_PREVIEW_LIMIT and all(
                isinstance(item, (int, float)) for item in value
            ):
                return f"<{len(value)} numbers>"
            return [cls._summarize_dsl(item) for item in value]
        return value

    @staticmethod
    def _require_row_field(row: dict[str, Any], field: str, *, hint: str) -> Any:
        """Return `row[field]`, raising ValueError (mentioning `hint`) if absent."""
        if field not in row:
            raise ValueError(f"missing required field '{field}' in {hint}: {row}")
        return row.get(field)

    @staticmethod
    def _filter_chunks(chunks: list[ScoredChunk], *, score_threshold: float) -> list[ScoredChunk]:
        """Keep chunks whose score is >= `score_threshold`."""
        return [chunk for chunk in chunks if chunk.score >= score_threshold]

    @staticmethod
    def _sort_chunks(chunks: list[ScoredChunk]) -> None:
        """Sort `chunks` in place by score, highest first."""
        chunks.sort(
            key=lambda chunk: chunk.score,
            reverse=True,
        )

    @staticmethod
    def _normalize_score(score: Any) -> float:
        """Coerce a score to float; unconvertible values become 0.0."""
        try:
            return float(score)
        except (TypeError, ValueError):
            return 0.0

    @staticmethod
    def _build_chunk(
        *,
        chunk_id: str,
        content: str,
        object_type: str,
        ontology: str,
        score: float,
        doc_id: str | None = None,
        chunk_idx: int | None = None,
    ) -> ScoredChunk:
        """Construct a ScoredChunk from keyword arguments."""
        return ScoredChunk(
            chunk_id=chunk_id,
            content=content,
            object_type=object_type,
            ontology=ontology,
            score=score,
            doc_id=doc_id,
            chunk_idx=chunk_idx,
        )

    @staticmethod
    def _extract_doc_id(row: dict[str, Any]) -> str | None:
        """Return `row["docId"]` as a string, or None when missing/falsy."""
        doc_id = row.get("docId")
        if not doc_id:
            return None
        return str(doc_id)

    @staticmethod
    def _extract_chunk_idx(row: dict[str, Any]) -> int | None:
        """Return `row["chunkIdx"]` if it is a real int (bools rejected), else None."""
        chunk_idx = row.get("chunkIdx")
        if isinstance(chunk_idx, bool) or not isinstance(chunk_idx, int):
            return None
        return chunk_idx

    @staticmethod
    def _apply_metadata_filters(where: dict[str, Any], ctx: SearchContext) -> dict[str, Any]:
        """Return `where` with the context's metadata filter appended.

        `where` itself is not mutated; it is returned unchanged when there is
        no filter or the operator is unknown. "notin" is mapped to "not in".
        """
        metadata_filters = ctx.metadata_filters or {}
        if not metadata_filters:
            return where

        operator = metadata_filters.get("operator") or ""
        if not operator:
            return where

        operator_map = {
            "in": "in",
            "notin": "not in",
            "not in": "not in",
        }
        mapped_operator = operator_map.get(operator)
        if not mapped_operator:
            return where
        metadata_node = {
            "type": mapped_operator,
            "field": metadata_filters.get("field"),
            "value": metadata_filters.get("value"),
        }

        updated_where = dict(where)
        existing_filters = list(updated_where.get("filter") or [])
        existing_filters.append(metadata_node)
        updated_where["filter"] = existing_filters
        return updated_where
353
354
class Chunk:
    """Unscored chunk record; `to_dict()` has the same shape as ScoredChunk's."""

    chunk_id: str
    content: str
    object_type: str
    ontology: str

    def __init__(
        self,
        chunk_id: str = "",
        content: str = "",
        object_type: str = "",
        ontology: str = "",
    ) -> None:
        # The class previously declared annotations only, so instances had no
        # attributes and to_dict() raised AttributeError. Defaults keep the
        # old no-argument `Chunk()` call working.
        self.chunk_id = chunk_id
        self.content = content
        self.object_type = object_type
        self.ontology = ontology

    def to_dict(self) -> dict[str, Any]:
        """Return the chunk as a plain dict for node output."""
        return {
            "chunk_id": self.chunk_id,
            "content": self.content,
            "object_type": self.object_type,
            "ontology": self.ontology,
        }
368
369
class ScoredChunk:
    """A retrieved chunk plus the data the pipelines use for ranking.

    `score`, `doc_id` and `chunk_idx` are stored on the instance but are not
    part of `to_dict()`.
    """

    def __init__(
        self,
        chunk_id: str,
        content: str,
        object_type: str,
        ontology: str,
        score: float,
        doc_id: str | None = None,
        chunk_idx: int | None = None,
    ):
        """
        Args:
            chunk_id: Chunk ID.
            content: Chunk text.
            object_type: Ontology object type the chunk came from.
            ontology: Ontology name.
            score: Ranking score. Not necessarily in [0, 1]: the fulltext
                pipeline stores raw BM25 scores here.
            doc_id: Value of the row's `docId` column, if any.
            chunk_idx: Value of the row's `chunkIdx` column, if any.
        """
        self.chunk_id, self.content = chunk_id, content
        self.object_type, self.ontology = object_type, ontology
        self.score = score
        self.doc_id, self.chunk_idx = doc_id, chunk_idx

    def to_dict(self) -> dict[str, Any]:
        """Return the output-facing fields as a plain dict."""
        return dict(
            chunk_id=self.chunk_id,
            content=self.content,
            object_type=self.object_type,
            ontology=self.ontology,
        )
406
407
def _normalize_bm25(score: float) -> float:
    """Scale a raw BM25 score into [0, 1].

    Non-positive scores map to 0.0. Otherwise the score is clipped at
    BM25_NORMALIZATION_CAP and divided by it, so anything at or above the cap
    maps to 1.0.
    """
    if score <= 0:
        return 0.0
    clipped = min(score, BM25_NORMALIZATION_CAP)
    return clipped / BM25_NORMALIZATION_CAP
414
415
416def _parse_embedding(value: Any) -> list[float]:
417 if value is None:
418 return []
419
420 if isinstance(value, list):
421 return value
422
423 if isinstance(value, str):
424 try:
425 parsed = json.loads(value)
426 except json.JSONDecodeError:
427 return []
428 if not isinstance(parsed, list):
429 return []
430 out = []
431 for item in parsed:
432 try:
433 out.append(float(item))
434 except (TypeError, ValueError):
435 return []
436 return out
437 return []
438
439
440def _cosine_similarity(left: Sequence[float], right: Sequence[float]) -> float:
441 if len(left) != len(right) or not left:
442 return 0.0
443
444 dot = sum(x * y for x, y in zip(left, right))
445 norm_l = sum(x * x for x in left) ** 0.5
446 norm_r = sum(x * x for x in right) ** 0.5
447 return dot / (norm_l * norm_r) if norm_l and norm_r else 0.0
448
449
class FulltextPipeline(BasePipeline):
    """Fulltext (BM25) recall through sentence-level rows.

    `sentence` and `custom_sentence` rows are matched against the query. Each
    parent paragraph is scored with the best `__SCORE` among its matching
    sentences, and the parent paragraphs (type `chunk`) are returned.

    NOTE(review): these scores are raw `__SCORE` values (not scaled by
    `_normalize_bm25`), yet they are compared with `score_threshold`, which is
    validated to [0, 1] — confirm this is intended.
    """

    def __init__(self, ontology_service: Any) -> None:
        super().__init__(ontology_service)

    def run(self, ctx: SearchContext) -> dict[str, Any]:
        """Execute fulltext recall; any exception is logged and returned as an error dict."""
        try:
            matched_rows = [
                *self._search_sentences(ctx, chunk_type=CHUNK_TYPE_SENTENCE, limit=ctx.recall_top_n),
                *self._search_sentences(ctx, chunk_type=CHUNK_TYPE_CUSTOM_SENTENCE, limit=ctx.recall_top_n),
            ]
            if not matched_rows:
                return self._format_output([])

            best_score_by_parent = self._aggregate_parent_scores(matched_rows, ctx)
            if not best_score_by_parent:
                return self._format_output([])

            paragraphs = self._fetch_paragraphs_by_ids(ctx, list(best_score_by_parent))

            candidates: list[ScoredChunk] = []
            for paragraph in paragraphs:
                paragraph_id = paragraph.get(ctx.chunk_id_column)
                text = paragraph.get("content")
                if not paragraph_id or text is None:
                    continue
                if paragraph_id not in best_score_by_parent:
                    raise ValueError(f"missing bm25 for chunk {paragraph_id}")
                candidates.append(
                    self._build_chunk(
                        chunk_id=paragraph_id,
                        content=text,
                        object_type=ctx.object_type,
                        ontology=ctx.ontology_name,
                        score=best_score_by_parent[paragraph_id],
                        doc_id=self._extract_doc_id(paragraph),
                        chunk_idx=self._extract_chunk_idx(paragraph),
                    )
                )

            kept = self._filter_chunks(candidates, score_threshold=ctx.score_threshold)
            self._sort_chunks(kept)
            return self._format_output(kept[: ctx.recall_top_n])
        except Exception as exc:
            logger.error("fulltext pipeline failed: %s", exc)
            return self._format_error(exc)

    def _search_sentences(self, ctx: SearchContext, *, chunk_type: str, limit: int) -> list[dict[str, Any]]:
        """BM25-match the query against rows of `chunk_type`, best `__SCORE` first."""
        conditions = [
            {"type": "match_any", "field": ctx.fulltext_column, "value": ctx.query},
            {"type": "eq", "field": ctx.type_column, "value": chunk_type},
        ]
        dsl = {
            "ontology": ctx.ontology_name,
            "apiName": ctx.object_type,
            "limit": limit,
            "select": [
                {"field": "content"},
                {"field": ctx.chunk_id_column},
                {"field": ctx.parent_id_column},
                {"field": ctx.type_column},
                {"field": "__SCORE", "orderBy": "desc"},
            ],
            "where": self._apply_metadata_filters({"type": "and", "filter": conditions}, ctx),
        }
        return self._ontology_search(dsl=dsl)

    def _aggregate_parent_scores(self, rows: Iterable[dict[str, Any]], ctx: SearchContext) -> dict[str, float]:
        """Map each parent id to max(0.0, best sentence score).

        Every row must carry `__SCORE`; rows without a parent id are ignored.
        """
        best: dict[str, float] = defaultdict(float)
        for row in rows:
            if "__SCORE" not in row:
                raise ValueError(f"missing __SCORE in fulltext row: {row}")
            score = self._normalize_score(row.get("__SCORE"))
            parent_id = row.get(ctx.parent_id_column)
            if parent_id:
                best[parent_id] = max(best[parent_id], score)
        return best

    def _fetch_paragraphs_by_ids(self, ctx: SearchContext, parent_ids: list[str]) -> list[dict[str, Any]]:
        """Fetch `chunk`-type rows whose chunk id is in `parent_ids`."""
        if not parent_ids:
            return []

        conditions = [
            {"type": "eq", "field": ctx.type_column, "value": CHUNK_TYPE_CHUNK},
            {"type": "in", "field": ctx.chunk_id_column, "value": parent_ids},
        ]
        dsl = {
            "ontology": ctx.ontology_name,
            "apiName": ctx.object_type,
            "limit": len(parent_ids),
            "select": [
                {"field": "content"},
                {"field": ctx.chunk_id_column},
                {"field": ctx.parent_id_column},
                {"field": ctx.type_column},
                {"field": "docId"},
                {"field": "chunkIdx"},
            ],
            "where": self._apply_metadata_filters({"type": "and", "filter": conditions}, ctx),
        }
        return self._ontology_search(dsl=dsl)
562
563
class SemanticPipeline(BasePipeline):
    """Vector recall over paragraph (`chunk`-type) rows.

    The ontology `contains` filter selects up to `recall_top_n` candidates.
    Each candidate is then re-scored client-side with the cosine similarity
    between the query embedding and its stored embedding.
    """

    def __init__(self, ontology_service: Any, embedding_service: Any) -> None:
        super().__init__(ontology_service)
        # Client exposing `embed(texts)`; `.results[i].embedding` is read in run().
        self.embedding_service = embedding_service

    def run(self, ctx: SearchContext) -> dict[str, Any]:
        """Embed the query, recall paragraphs by vector and return scored chunks.

        Rows without an id or content are ignored. Rows whose embedding cannot
        be parsed are skipped and counted in a warning. Survivors of the score
        threshold are sorted by score (descending) and cut to `recall_top_n`.
        Any exception is logged and returned as an error dict.
        """
        try:
            embedded = self.embedding_service.embed([ctx.query])
            if not embedded.results:
                return self._format_output([])
            query_vector = embedded.results[0].embedding

            where = self._apply_metadata_filters(
                {
                    "type": "and",
                    "filter": [
                        {"type": "contains", "field": ctx.semantic_column, "value": query_vector},
                        {"type": "eq", "field": ctx.type_column, "value": CHUNK_TYPE_CHUNK},
                    ],
                },
                ctx,
            )
            rows = self._ontology_search(
                dsl={
                    "ontology": ctx.ontology_name,
                    "apiName": ctx.object_type,
                    "limit": ctx.recall_top_n,
                    "select": [
                        {"field": "content"},
                        {"field": ctx.chunk_id_column},
                        {"field": ctx.semantic_column},
                        {"field": ctx.type_column},
                        {"field": "docId"},
                        {"field": "chunkIdx"},
                    ],
                    "where": where,
                }
            )

            candidates: list[ScoredChunk] = []
            unparsable = 0
            for row in rows or []:
                chunk_id = row.get(ctx.chunk_id_column)
                text = row.get("content")
                if not chunk_id or text is None:
                    continue
                stored_vector = _parse_embedding(row.get(ctx.semantic_column))
                if not stored_vector:
                    unparsable += 1
                    continue
                candidates.append(
                    self._build_chunk(
                        chunk_id=chunk_id,
                        content=text,
                        object_type=ctx.object_type,
                        ontology=ctx.ontology_name,
                        score=_cosine_similarity(query_vector, stored_vector),
                        doc_id=self._extract_doc_id(row),
                        chunk_idx=self._extract_chunk_idx(row),
                    )
                )

            if unparsable:
                logger.warning(
                    "semantic rows missing embedding",
                    extra={
                        "skipped": unparsable,
                        "ontology": ctx.ontology_name,
                        "object_type": ctx.object_type,
                        "embedding_field": ctx.semantic_column,
                    },
                )

            kept = self._filter_chunks(candidates, score_threshold=ctx.score_threshold)
            self._sort_chunks(kept)
            return self._format_output(kept[: ctx.recall_top_n])
        except Exception as exc:
            logger.error("semantic pipeline failed: %s", exc)
            return self._format_error(exc)
642
643
644class HybridPipeline(BasePipeline):
645 def __init__(self, ontology_service: Any, embedding_service: Any) -> None:
646 super().__init__(ontology_service)
647 self.embedding_service = embedding_service
648
649 def run(self, ctx: SearchContext) -> dict[str, Any]:
650 try:
651 vec_weight = float(ctx.vec_weight)
652 if not 0.0 <= vec_weight <= 1.0:
653 raise ValueError("vec_weight must be between 0 and 1")
654
655 embedding_field = ctx.semantic_column
656
657 embeddings = self.embedding_service.embed([ctx.query])
658 if not embeddings.results:
659 return self._format_output([])
660 query_vector = embeddings.results[0].embedding
661
662 recall_top_n = int(ctx.recall_top_n)
663
664 # Fulltext branch: sentence/custom_sentence -> paragraph mix.
665 fulltext_rows: list[dict[str, Any]] = []
666 fulltext_rows.extend(
667 list(
668 self._search_fulltext_sentences(
669 ctx,
670 query=ctx.query,
671 chunk_type=CHUNK_TYPE_SENTENCE,
672 limit=recall_top_n,
673 embedding_field=embedding_field,
674 )
675 )
676 )
677 fulltext_rows.extend(
678 list(
679 self._search_fulltext_sentences(
680 ctx,
681 query=ctx.query,
682 chunk_type=CHUNK_TYPE_CUSTOM_SENTENCE,
683 limit=recall_top_n,
684 embedding_field=embedding_field,
685 )
686 )
687 )
688
689 fulltext_mix, known_parent_bm25 = self._hybrid_fulltext_paragraph_mix(
690 ctx,
691 query_vector=query_vector,
692 vec_weight=vec_weight,
693 rows=fulltext_rows,
694 embedding_field=embedding_field,
695 )
696
697 # Semantic branch: paragraph vector recall.
698 paragraph_vec = self._semantic_paragraph_vectors(
699 ctx,
700 query_vector=query_vector,
701 limit=recall_top_n,
702 embedding_field=embedding_field,
703 )
704
705 semantic_mix = self._hybrid_semantic_paragraph_mix(
706 ctx,
707 query=ctx.query,
708 query_vector=query_vector,
709 vec_weight=vec_weight,
710 paragraph_ids=list(paragraph_vec.keys()),
711 known_parent_bm25=known_parent_bm25,
712 bm25_recall_limit=recall_top_n,
713 embedding_field=embedding_field,
714 )
715
716 merged_ids = set(fulltext_mix.keys()) | set(semantic_mix.keys())
717 if not merged_ids:
718 return self._format_output([])
719
720 self._supplement_paragraph_vectors(
721 ctx,
722 paragraph_vec=paragraph_vec,
723 paragraph_ids=list(merged_ids),
724 query_vector=query_vector,
725 embedding_field=embedding_field,
726 )
727
728 paragraph_mix: dict[str, float] = {}
729 for pid in merged_ids:
730 fulltext_score = fulltext_mix.get(pid)
731 semantic_score = semantic_mix.get(pid)
732 if fulltext_score is None and semantic_score is None:
733 raise ValueError(f"missing mix score for paragraph {pid}")
734 if fulltext_score is None:
735 if semantic_score is None:
736 raise ValueError(f"missing semantic mix score for paragraph {pid}")
737 paragraph_mix[pid] = semantic_score
738 elif semantic_score is None:
739 paragraph_mix[pid] = fulltext_score
740 else:
741 paragraph_mix[pid] = max(fulltext_score, semantic_score)
742
743 scored = self._merge_and_prerank(
744 ctx,
745 paragraph_mix=paragraph_mix,
746 paragraph_vec=paragraph_vec,
747 vec_weight=vec_weight,
748 )
749 if not scored:
750 return self._format_output([])
751
752 paragraph_ids = [item[0] for item in scored]
753 paragraphs = self._fetch_chunks_by_ids(ctx, paragraph_ids)
754 content_by_id: dict[str, str] = {}
755 row_by_id: dict[str, dict[str, Any]] = {}
756 for row in paragraphs:
757 cid = row.get(ctx.chunk_id_column)
758 content = row.get("content")
759 if cid and isinstance(content, str):
760 content_by_id[cid] = content
761 row_by_id[cid] = row
762
763 chunks: list[ScoredChunk] = []
764 for cid, pre_rank_score in scored:
765 content = content_by_id.get(cid)
766 if content is None:
767 continue
768 row = row_by_id.get(cid, {})
769 doc_id = self._extract_doc_id(row)
770 chunk_idx = self._extract_chunk_idx(row)
771 chunks.append(
772 self._build_chunk(
773 chunk_id=cid,
774 content=content,
775 object_type=ctx.object_type,
776 ontology=ctx.ontology_name,
777 score=pre_rank_score,
778 doc_id=doc_id,
779 chunk_idx=chunk_idx,
780 )
781 )
782
783 filtered = self._filter_chunks(chunks, score_threshold=ctx.score_threshold)
784 self._sort_chunks(filtered)
785 return self._format_output(filtered[: ctx.recall_top_n])
786 except Exception as exc:
787 logger.error("hybrid pipeline failed: %s", exc)
788 return self._format_error(exc)
789
790 def _search_fulltext_sentences(
791 self,
792 ctx: SearchContext,
793 *,
794 query: str,
795 chunk_type: str,
796 limit: int,
797 embedding_field: str,
798 ) -> list[dict[str, Any]]:
799 dsl = {
800 "ontology": ctx.ontology_name,
801 "apiName": ctx.object_type,
802 "limit": limit,
803 "select": [
804 {"field": ctx.parent_id_column},
805 {"field": embedding_field},
806 {"field": "__SCORE", "orderBy": "desc"},
807 {"field": ctx.type_column},
808 ],
809 "where": {
810 "type": "and",
811 "filter": [
812 {"type": "match_any", "field": ctx.fulltext_column, "value": query},
813 {"type": "eq", "field": ctx.type_column, "value": chunk_type},
814 ],
815 },
816 }
817 dsl["where"] = self._apply_metadata_filters(dsl["where"], ctx)
818 return self._ontology_search(dsl=dsl)
819
820 def _semantic_paragraph_vectors(
821 self,
822 ctx: SearchContext,
823 *,
824 query_vector: list[float],
825 limit: int,
826 embedding_field: str,
827 ) -> dict[str, float]:
828 # Ontology API orders by ann_distance(), but does not return the distance column.
829 # We only use this call to get candidate chunkIds, then fetch embeddings by ids
830 # and compute cosine similarity client-side.
831 dsl = {
832 "ontology": ctx.ontology_name,
833 "apiName": ctx.object_type,
834 "limit": limit,
835 "select": [
836 {"field": ctx.chunk_id_column},
837 {"field": ctx.type_column},
838 ],
839 "where": {
840 "type": "and",
841 "filter": [
842 {"type": "contains", "field": ctx.semantic_column, "value": query_vector},
843 {"type": "eq", "field": ctx.type_column, "value": CHUNK_TYPE_CHUNK},
844 ],
845 },
846 }
847 dsl["where"] = self._apply_metadata_filters(dsl["where"], ctx)
848 rows = self._ontology_search(dsl=dsl)
849
850 chunk_ids: list[str] = []
851 for row in rows or []:
852 cid = row.get(ctx.chunk_id_column)
853 if cid:
854 chunk_ids.append(cid)
855
856 if not chunk_ids:
857 return {}
858
859 embeddings = self._fetch_chunk_embeddings_by_ids(ctx, chunk_ids, embedding_field=embedding_field)
860 paragraph_vec: dict[str, float] = {}
861 for cid, embedding in embeddings.items():
862 paragraph_vec[cid] = _cosine_similarity(query_vector, embedding)
863 return paragraph_vec
864
865 def _supplement_paragraph_vectors(
866 self,
867 ctx: SearchContext,
868 *,
869 paragraph_vec: dict[str, float],
870 paragraph_ids: list[str],
871 query_vector: list[float],
872 embedding_field: str,
873 ) -> None:
874 missing_ids = [pid for pid in paragraph_ids if pid not in paragraph_vec]
875 if not missing_ids:
876 return
877
878 embeddings = self._fetch_chunk_embeddings_by_ids(ctx, missing_ids, embedding_field=embedding_field)
879 for pid, embedding in embeddings.items():
880 paragraph_vec[pid] = _cosine_similarity(query_vector, embedding)
881
882 missing_after = [pid for pid in missing_ids if pid not in paragraph_vec]
883 if missing_after:
884 raise ValueError(f"missing embeddings for paragraphs: {missing_after}")
885
886 def _hybrid_fulltext_paragraph_mix(
887 self,
888 ctx: SearchContext,
889 *,
890 query_vector: list[float],
891 vec_weight: float,
892 rows: list[dict[str, Any]],
893 embedding_field: str,
894 ) -> tuple[dict[str, float], dict[str, float]]:
895 paragraph_mix: dict[str, float] = {}
896 # paragraph_id -> max sentence BM25 (used for semantic branch supplementation)
897 known_parent_bm25: dict[str, float] = {}
898
899 for row in rows:
900 paragraph_id = row.get(ctx.parent_id_column)
901 if not paragraph_id:
902 continue
903 if "__SCORE" not in row:
904 raise ValueError(f"missing __SCORE in fulltext row: {row}")
905 bm25 = self._normalize_score(row.get("__SCORE"))
906
907 existing_bm25 = known_parent_bm25.get(paragraph_id)
908 if existing_bm25 is None or bm25 > existing_bm25:
909 known_parent_bm25[paragraph_id] = bm25
910
911 embedding = _parse_embedding(row.get(embedding_field))
912 if not embedding:
913 continue
914
915 norm_bm25 = _normalize_bm25(bm25)
916 sent_vec = _cosine_similarity(query_vector, embedding)
917 mix = vec_weight * sent_vec + (1.0 - vec_weight) * norm_bm25
918 existing_mix = paragraph_mix.get(paragraph_id)
919 if existing_mix is None or mix > existing_mix:
920 paragraph_mix[paragraph_id] = mix
921
922 return paragraph_mix, known_parent_bm25
923
924 def _fetch_children_by_parent_ids(
925 self,
926 ctx: SearchContext,
927 parent_ids: list[str],
928 *,
929 embedding_field: str,
930 ) -> list[dict[str, Any]]:
931 if not parent_ids:
932 return []
933 dsl = {
934 "ontology": ctx.ontology_name,
935 "apiName": ctx.object_type,
936 "select": [
937 {"field": ctx.parent_id_column},
938 {"field": embedding_field},
939 {"field": ctx.type_column},
940 ],
941 "where": {
942 "type": "and",
943 "filter": [
944 {"type": "in", "field": ctx.parent_id_column, "value": parent_ids},
945 {
946 "type": "in",
947 "field": ctx.type_column,
948 "value": [CHUNK_TYPE_SENTENCE, CHUNK_TYPE_CUSTOM_SENTENCE],
949 },
950 ],
951 },
952 }
953 dsl["where"] = self._apply_metadata_filters(dsl["where"], ctx)
954 return self._ontology_search(dsl=dsl)
955
956 def _fetch_chunk_embeddings_by_ids(
957 self,
958 ctx: SearchContext,
959 chunk_ids: list[str],
960 *,
961 embedding_field: str,
962 ) -> dict[str, list[float]]:
963 if not chunk_ids:
964 return {}
965
966 dsl = {
967 "ontology": ctx.ontology_name,
968 "apiName": ctx.object_type,
969 "limit": len(chunk_ids),
970 "select": [
971 {"field": ctx.chunk_id_column},
972 {"field": embedding_field},
973 {"field": ctx.type_column},
974 ],
975 "where": {
976 "type": "and",
977 "filter": [
978 {"type": "eq", "field": ctx.type_column, "value": CHUNK_TYPE_CHUNK},
979 {"type": "in", "field": ctx.chunk_id_column, "value": chunk_ids},
980 ],
981 },
982 }
983 dsl["where"] = self._apply_metadata_filters(dsl["where"], ctx)
984 rows = self._ontology_search(dsl=dsl)
985
986 embeddings: dict[str, list[float]] = {}
987 for row in rows or []:
988 cid = row.get(ctx.chunk_id_column)
989 if not cid:
990 continue
991 embedding = _parse_embedding(row.get(embedding_field))
992 if not embedding:
993 raise ValueError(f"missing embedding for chunk: {row}")
994 embeddings[cid] = embedding
995
996 return embeddings
997
def _calculate_sentence_bm25_batch(
    self,
    ctx: SearchContext,
    *,
    query: str,
    parent_ids: set[str],
    limit: int,
) -> dict[str, float]:
    """Return the best sentence-level BM25 score for each requested parent id.

    Runs a ``match_any`` full-text query on ``ctx.fulltext_column`` over
    sentence / custom-sentence rows whose parent id is in ``parent_ids``
    (metadata filters applied), ordered by ``__SCORE`` descending and capped
    at ``limit`` rows. For each parent, the maximum normalised score among its
    matched sentences is kept.

    Args:
        ctx: Search context supplying the ontology/object names and column names.
        query: Full-text query string.
        parent_ids: Parent (paragraph) ids to score; empty returns ``{}``.
        limit: Maximum number of sentence rows to fetch.

    Returns:
        Mapping of parent id -> max normalised BM25. Parents with no matching
        sentence within ``limit`` are absent.

    Raises:
        ValueError: if a matching row has no ``__SCORE`` field.
    """
    if not parent_ids:
        return {}

    # Restrict the query to the requested parents server-side. Without this
    # filter the ``limit`` budget is spent on the corpus-wide top sentence
    # hits, which can omit every parent this call was asked to score. The
    # client-side membership check below is kept as a safeguard.
    dsl = {
        "ontology": ctx.ontology_name,
        "apiName": ctx.object_type,
        "limit": limit,
        "select": [
            {"field": ctx.parent_id_column},
            {"field": "__SCORE", "orderBy": "desc"},
            {"field": ctx.type_column},
        ],
        "where": {
            "type": "and",
            "filter": [
                {"type": "match_any", "field": ctx.fulltext_column, "value": query},
                {
                    "type": "in",
                    "field": ctx.type_column,
                    "value": [CHUNK_TYPE_SENTENCE, CHUNK_TYPE_CUSTOM_SENTENCE],
                },
                {
                    "type": "in",
                    "field": ctx.parent_id_column,
                    # Sorted for a deterministic DSL (set order varies by hash seed).
                    "value": sorted(parent_ids, key=str),
                },
            ],
        },
    }
    dsl["where"] = self._apply_metadata_filters(dsl["where"], ctx)
    rows = self._ontology_search(dsl=dsl)

    # parent_id -> max bm25 among matched sentences
    result: dict[str, float] = {}
    for row in rows or []:
        pid = row.get(ctx.parent_id_column)
        if not pid or pid not in parent_ids:
            continue
        if "__SCORE" not in row:
            raise ValueError(f"missing __SCORE in sentence bm25 row: {row}")
        bm25 = self._normalize_score(row.get("__SCORE"))
        existing_bm25 = result.get(pid)
        if existing_bm25 is None or bm25 > existing_bm25:
            result[pid] = bm25
    return result
1047
def _hybrid_semantic_paragraph_mix(
    self,
    ctx: SearchContext,
    *,
    query: str,
    query_vector: list[float],
    vec_weight: float,
    paragraph_ids: list[str],
    known_parent_bm25: dict[str, float],
    bm25_recall_limit: int,
    embedding_field: str,
) -> dict[str, float]:
    """Score paragraphs by mixing child-row vector similarity with paragraph BM25.

    For every child row returned by ``_fetch_children_by_parent_ids`` for
    ``paragraph_ids``, the score is
    ``vec_weight * cosine(query_vector, child_embedding)
    + (1 - vec_weight) * _normalize_bm25(parent_bm25)``.
    Parents missing from ``known_parent_bm25`` are looked up in one batch via
    ``_calculate_sentence_bm25_batch``. ``known_parent_bm25`` is updated IN
    PLACE with whatever that lookup returns. Parents that still have no score
    use BM25 = 0 and a warning is logged once per such paragraph. Each
    paragraph keeps the maximum score over its children. Children without a
    parent id or a parsable embedding are skipped.

    Returns:
        Mapping of paragraph id -> best mixed score; ``{}`` when there is
        nothing to score.
    """
    if not paragraph_ids:
        return {}

    children = self._fetch_children_by_parent_ids(ctx, paragraph_ids, embedding_field=embedding_field)
    if not children:
        return {}

    parent_ids_need_bm25: set[str] = {
        pid
        for child in children
        if (pid := child.get(ctx.parent_id_column)) and pid not in known_parent_bm25
    }

    if parent_ids_need_bm25:
        supplement = self._calculate_sentence_bm25_batch(
            ctx,
            query=query,
            parent_ids=parent_ids_need_bm25,
            limit=bm25_recall_limit,
        )
        known_parent_bm25.update(supplement)

    paragraph_mix: dict[str, float] = {}
    # A paragraph usually has many children; without this set the same
    # "missing bm25" warning would be emitted once per child row.
    warned_missing_bm25: set[str] = set()
    for child in children:
        paragraph_id = child.get(ctx.parent_id_column)
        if not paragraph_id:
            continue
        embedding = _parse_embedding(child.get(embedding_field))
        if not embedding:
            continue

        bm25 = known_parent_bm25.get(paragraph_id)
        if bm25 is None:
            bm25 = 0.0
            if paragraph_id not in warned_missing_bm25:
                warned_missing_bm25.add(paragraph_id)
                logger.warning(
                    "missing bm25 for paragraph, fallback to 0 (no fulltext match for sentence/custom_sentence)",
                    extra={
                        "paragraph_id": paragraph_id,
                        "ontology": ctx.ontology_name,
                        "object_type": ctx.object_type,
                    },
                )
        norm_bm25 = _normalize_bm25(bm25)
        mix = vec_weight * _cosine_similarity(query_vector, embedding) + (1.0 - vec_weight) * norm_bm25
        existing_mix = paragraph_mix.get(paragraph_id)
        if existing_mix is None or mix > existing_mix:
            paragraph_mix[paragraph_id] = mix

    return paragraph_mix
1111
def _merge_and_prerank(
    self,
    ctx: SearchContext,
    *,
    paragraph_mix: dict[str, float],
    paragraph_vec: dict[str, float],
    vec_weight: float,
) -> list[tuple[str, float]]:
    """Blend mix and vector scores per paragraph and keep the best ``recall_top_n``.

    ``pre_rank = vec_weight * vec_score + (1 - vec_weight) * mix_score``.
    Paragraphs whose ``pre_rank`` is below ``ctx.score_threshold`` are dropped.
    Paragraphs whose mix score is at least ``HYBRID_AUTO_PASS_THRESHOLD`` form
    a first tier placed ahead of all others. Each tier is ordered by
    descending ``pre_rank``, and ties keep their ``paragraph_mix`` order.

    Returns:
        At most ``int(ctx.recall_top_n)`` ``(paragraph_id, pre_rank)`` pairs.

    Raises:
        ValueError: if a paragraph in ``paragraph_mix`` has no entry in ``paragraph_vec``.
    """
    cap = int(ctx.recall_top_n)
    floor = float(ctx.score_threshold)

    preferred: list[tuple[str, float]] = []
    remainder: list[tuple[str, float]] = []
    for paragraph_id, mix_score in paragraph_mix.items():
        if paragraph_id not in paragraph_vec:
            raise ValueError(f"missing vector score for paragraph {paragraph_id}")
        blended = vec_weight * paragraph_vec[paragraph_id] + (1.0 - vec_weight) * mix_score
        if blended < floor:
            continue
        bucket = preferred if mix_score >= HYBRID_AUTO_PASS_THRESHOLD else remainder
        bucket.append((paragraph_id, blended))

    ranked: list[tuple[str, float]] = []
    for bucket in (preferred, remainder):
        ranked.extend(sorted(bucket, key=lambda pair: pair[1], reverse=True))
    return ranked[:cap]
1142
def _fetch_chunks_by_ids(self, ctx: SearchContext, chunk_ids: list[str]) -> list[dict[str, Any]]:
    """Fetch the content rows for ``chunk_ids``.

    Only rows whose type column equals ``CHUNK_TYPE_CHUNK`` are queried, with
    the context's metadata filters applied. The selected columns are content,
    chunk id, type, ``docId`` and ``chunkIdx``, and the row limit is
    ``len(chunk_ids)``.

    Returns:
        The rows from the ontology search, or ``[]`` without querying when
        ``chunk_ids`` is empty.
    """
    if not chunk_ids:
        return []

    columns = ("content", ctx.chunk_id_column, ctx.type_column, "docId", "chunkIdx")
    request = {
        "ontology": ctx.ontology_name,
        "apiName": ctx.object_type,
        "limit": len(chunk_ids),
        "select": [{"field": column} for column in columns],
        "where": {
            "type": "and",
            "filter": [
                {"type": "eq", "field": ctx.type_column, "value": CHUNK_TYPE_CHUNK},
                {"type": "in", "field": ctx.chunk_id_column, "value": chunk_ids},
            ],
        },
    }
    request["where"] = self._apply_metadata_filters(request["where"], ctx)
    return self._ontology_search(dsl=request)
1167
1168
class Reranker:
    """Thin wrapper around the ``bce-reranker-base`` rerank service.

    ``rerank`` forwards to the service and clips the response's ``results``
    list to ``top_k``. ``sort`` delegates the reordering of the original
    chunk dicts to the service.
    """

    def __init__(self) -> None:
        # The model name is fixed; one service object is created per Reranker.
        self.rerank_service = create_rerank_service("bce-reranker-base")

    def rerank(self, query: str, chunks: list[str], top_k: int):
        """Score ``chunks`` against ``query`` and return the service response.

        When the response exposes a list-valued ``results`` attribute, that
        attribute is replaced by its first ``top_k`` entries. Any other
        response is returned untouched.
        """
        reranked = self.rerank_service.rerank(query, chunks, top_k)
        scored = builtins.getattr(reranked, "results", None)
        if not isinstance(scored, list):
            return reranked
        reranked.results = scored[:top_k]
        return reranked

    def sort(self, chunks: list[dict[str, Any]], results: list[Any]) -> list[dict[str, Any]]:
        """Reorder ``chunks`` according to rerank ``results`` (delegated to the service)."""
        ordered = self.rerank_service.sort(chunks, results)
        return ordered
1184
1185
class Retriever:
    """Run the configured recall pipeline and optionally widen hits with neighbouring chunks.

    ``retrieve`` dispatches on ``context.recall_type`` to the fulltext, semantic
    or hybrid pipeline built in ``__init__``. ``expand_chunks_after_rerank``
    adds the chunks of the same document within a window of each hit and
    returns the result deduplicated and ordered by ``(docId, chunkIdx)``.
    """

    def __init__(self) -> None:
        # One ontology service is shared by all pipelines; the embedding
        # service is passed to the semantic and hybrid pipelines.
        self.ontology_service = create_ontology_service()
        self.embedding_service = create_embedding_service(
            DEFAULT_EMBEDDING_MODEL
        )
        self.fulltext_pipeline = FulltextPipeline(self.ontology_service)
        self.semantic_pipeline = SemanticPipeline(self.ontology_service, self.embedding_service)
        self.hybrid_pipeline = HybridPipeline(self.ontology_service, self.embedding_service)

    def _ontology_search(self, *, dsl: dict[str, Any]) -> list[dict[str, Any]]:
        """Run ``dsl`` against the ontology service and return its result rows.

        Raises:
            RuntimeError: if the response cannot be converted to rows; the
                message includes the DSL to ease debugging. Errors raised by
                the ``search`` call itself propagate unchanged.
        """
        response = self.ontology_service.search(dsl=dsl)
        try:
            return BasePipeline._response_rows(response)
        except Exception as exc:
            raise RuntimeError(f"ontology search failed: {exc}. dsl={dsl}") from exc

    def _fetch_chunk_metadata_by_ids(
        self,
        ctx: SearchContext,
        chunk_ids: list[str],
    ) -> dict[str, tuple[str | None, int | None]]:
        """Return ``{str(chunk_id): (docId, chunkIdx)}`` for the given chunk ids.

        Only rows whose type column equals ``CHUNK_TYPE_CHUNK`` are queried,
        with metadata filters applied. Rows without a chunk id are skipped, and
        an empty ``chunk_ids`` returns ``{}`` without querying.
        """
        if not chunk_ids:
            return {}
        dsl = {
            "ontology": ctx.ontology_name,
            "apiName": ctx.object_type,
            "limit": len(chunk_ids),
            "select": [
                {"field": ctx.chunk_id_column},
                {"field": "docId"},
                {"field": "chunkIdx"},
                {"field": ctx.type_column},
            ],
            "where": {
                "type": "and",
                "filter": [
                    {"type": "eq", "field": ctx.type_column, "value": CHUNK_TYPE_CHUNK},
                    {"type": "in", "field": ctx.chunk_id_column, "value": chunk_ids},
                ],
            },
        }
        dsl["where"] = BasePipeline._apply_metadata_filters(dsl["where"], ctx)
        rows = self._ontology_search(dsl=dsl)

        metadata: dict[str, tuple[str | None, int | None]] = {}
        for row in rows or []:
            chunk_id = row.get(ctx.chunk_id_column)
            if not chunk_id:
                continue
            metadata[str(chunk_id)] = (
                BasePipeline._extract_doc_id(row),
                BasePipeline._extract_chunk_idx(row),
            )
        return metadata

    def _fetch_chunks_by_doc_and_idx(
        self,
        ctx: SearchContext,
        *,
        doc_id: str,
        chunk_indexes: list[int],
    ) -> list[dict[str, Any]]:
        """Fetch chunk rows of document ``doc_id`` whose ``chunkIdx`` is in ``chunk_indexes``.

        Only rows whose type column equals ``CHUNK_TYPE_CHUNK`` are queried,
        with metadata filters applied. An empty ``chunk_indexes`` returns ``[]``
        without querying.
        """
        if not chunk_indexes:
            return []
        dsl = {
            "ontology": ctx.ontology_name,
            "apiName": ctx.object_type,
            "limit": len(chunk_indexes),
            "select": [
                {"field": "content"},
                {"field": ctx.chunk_id_column},
                {"field": "docId"},
                {"field": "chunkIdx"},
                {"field": ctx.type_column},
            ],
            "where": {
                "type": "and",
                "filter": [
                    {"type": "eq", "field": ctx.type_column, "value": CHUNK_TYPE_CHUNK},
                    {"type": "eq", "field": "docId", "value": doc_id},
                    {"type": "in", "field": "chunkIdx", "value": chunk_indexes},
                ],
            },
        }
        dsl["where"] = BasePipeline._apply_metadata_filters(dsl["where"], ctx)
        return self._ontology_search(dsl=dsl)

    def retrieve(self, *, context: SearchContext) -> dict[str, Any]:
        """Run the pipeline selected by ``context.recall_type`` and return its result.

        Raises:
            ValueError: if ``recall_type`` is not ``"fulltext"``, ``"semantic"`` or ``"hybrid"``.
        """
        recall_type = context.recall_type

        if recall_type == "fulltext":
            return self.fulltext_pipeline.run(context)
        if recall_type == "semantic":
            return self.semantic_pipeline.run(context)
        if recall_type == "hybrid":
            return self.hybrid_pipeline.run(context)

        raise ValueError(f"unsupported recall type: {recall_type}")

    def expand_chunks_after_rerank(
        self,
        ctx: SearchContext,
        chunks: list[dict[str, Any]],
        *,
        window_size: int,
    ) -> list[dict[str, Any]]:
        """Add the neighbouring chunks of each input chunk and return them in reading order.

        For every chunk whose ``chunk_id`` resolves to a ``docId``/``chunkIdx``,
        the chunks of the same document with index in
        ``[chunkIdx - window_size, chunkIdx + window_size]`` (clamped at 0) are
        fetched. Input chunks precede fetched ones in the candidate list, so
        when an id appears more than once the first input occurrence is kept.
        Fetched neighbours carry only ``chunk_id``, ``content``, ``object_type``
        and ``ontology``.

        Args:
            ctx: Search context supplying ontology/object names and column names.
            chunks: Chunk dicts identified by their ``"chunk_id"`` key.
            window_size: Neighbours to include on each side. It may come from
                user-supplied params, so it is coerced with ``int`` (``None`` is
                treated as 0). A value ``<= 0`` disables expansion.

        Returns:
            ``chunks`` itself when it is empty or expansion is disabled.
            Otherwise a new list, deduplicated by ``chunk_id`` and sorted by
            ``(docId, chunkIdx)``, with chunks of unknown position placed last.
        """
        if not chunks:
            return chunks
        # Params may deliver "2", 2.0 or None; comparing a str with 0 or
        # feeding a float to range() would otherwise raise TypeError.
        window_size = int(window_size or 0)
        if window_size <= 0:
            return chunks

        chunk_ids = [str(chunk_id) for chunk in chunks if (chunk_id := chunk.get("chunk_id"))]
        metadata = self._fetch_chunk_metadata_by_ids(ctx, chunk_ids)

        doc_to_indexes = self._collect_expansion_windows(chunks, metadata, window_size)

        expanded = list(chunks)
        expanded.extend(self._fetch_neighbor_chunks(ctx, doc_to_indexes, metadata))

        deduped = self._dedupe_by_chunk_id(expanded)
        # NOTE(review): documents are ordered by str(docId), not by rerank
        # relevance. Confirm this reading order is what consumers expect.
        deduped.sort(key=lambda item: self._reading_order_key(item, metadata))
        return deduped

    @staticmethod
    def _collect_expansion_windows(
        chunks: list[dict[str, Any]],
        metadata: dict[str, tuple[str | None, int | None]],
        window_size: int,
    ) -> dict[str, set[int]]:
        """Map each docId to every chunk index within ``window_size`` of an input chunk."""
        doc_to_indexes: dict[str, set[int]] = defaultdict(set)
        for chunk in chunks:
            chunk_id = chunk.get("chunk_id")
            if not chunk_id:
                logger.warning("chunk missing chunk_id for expansion")
                continue
            doc_id, chunk_idx = metadata.get(str(chunk_id), (None, None))
            if not doc_id or chunk_idx is None or chunk_idx < 0:
                logger.warning(
                    "chunk missing docId/chunkIdx for expansion",
                    extra={"chunk_id": chunk_id},
                )
                continue
            start = max(0, chunk_idx - window_size)
            end = chunk_idx + window_size
            doc_to_indexes[str(doc_id)].update(range(start, end + 1))
        return doc_to_indexes

    def _fetch_neighbor_chunks(
        self,
        ctx: SearchContext,
        doc_to_indexes: dict[str, set[int]],
        metadata: dict[str, tuple[str | None, int | None]],
    ) -> list[dict[str, Any]]:
        """Fetch the windowed chunks per document and record their position in ``metadata``.

        Rows without a chunk id or with ``content`` of ``None`` are skipped.
        ``metadata`` is updated in place so that the new chunks can be ordered.
        """
        neighbors: list[dict[str, Any]] = []
        for doc_id, idxs in doc_to_indexes.items():
            rows = self._fetch_chunks_by_doc_and_idx(ctx, doc_id=doc_id, chunk_indexes=sorted(idxs))
            for row in rows or []:
                chunk_id = row.get(ctx.chunk_id_column)
                content = row.get("content")
                if not chunk_id or content is None:
                    continue
                neighbors.append(
                    {
                        "chunk_id": chunk_id,
                        "content": content,
                        "object_type": ctx.object_type,
                        "ontology": ctx.ontology_name,
                    }
                )
                metadata[str(chunk_id)] = (
                    BasePipeline._extract_doc_id(row),
                    BasePipeline._extract_chunk_idx(row),
                )
        return neighbors

    @staticmethod
    def _dedupe_by_chunk_id(chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Keep the first chunk per ``str(chunk_id)``; id-less chunks are kept and logged."""
        seen: set[str] = set()
        deduped: list[dict[str, Any]] = []
        for chunk in chunks:
            chunk_id = chunk.get("chunk_id")
            if not chunk_id:
                logger.warning("chunk missing chunk_id for expansion", extra={"chunk": chunk})
                deduped.append(chunk)
                continue
            chunk_id_str = str(chunk_id)
            if chunk_id_str in seen:
                continue
            seen.add(chunk_id_str)
            deduped.append(chunk)
        return deduped

    @staticmethod
    def _reading_order_key(
        item: dict[str, Any],
        metadata: dict[str, tuple[str | None, int | None]],
    ) -> tuple[int, str, int]:
        """Sort key ``(0, docId, chunkIdx)``; chunks of unknown position get ``(1, "", 0)``."""
        chunk_id = item.get("chunk_id")
        doc_id, chunk_idx = metadata.get(str(chunk_id), (None, None))
        if not doc_id or chunk_idx is None or chunk_idx < 0:
            logger.warning(
                "chunk missing docId/chunkIdx for ordering",
                extra={"chunk_id": chunk_id},
            )
            return (1, "", 0)
        return (0, str(doc_id), int(chunk_idx))
评价此篇文章
