"""Retrieve relevant chunks from ChromaDB."""
from __future__ import annotations
import re
from src.config import INDEX_DIR, get_top_k
from src.embeddings import embed_query
from src.index_store import get_collection, search
# Lower-case words that are dropped when text is reduced to comparison terms.
# NOTE: extract_terms also discards tokens of one or two characters, so the
# short entries here are redundant for that caller but kept for completeness.
STOP_WORDS = {
    "a", "and", "are", "as", "at",
    "for", "from", "had", "in", "of",
    "on", "or", "should", "the", "to",
    "was", "what", "which", "who", "why",
    "with",
}

# Upper bound on the amount subtracted from a result's distance: the boost is
# this value scaled by the fraction of question terms found in the chunk text.
MAX_KEYWORD_BOOST = 0.4
def retrieve(question: str, top_k: int | None = None) -> list[dict]:
    """Return the best-matching chunks for *question*, one per source.

    The question is embedded, an oversized candidate set is fetched from the
    collection at ``INDEX_DIR``, the candidates are reranked by keyword
    overlap, and at most one chunk per source is kept.

    Args:
        question: Free-text query. Blank or whitespace-only input yields ``[]``
            without touching the embedder or the index.
        top_k: Maximum number of chunks to return. Any falsy value (``None``
            or ``0``) falls back to ``get_top_k()``.

    Returns:
        Up to ``top_k`` result dicts, in reranked order.
    """
    if question.strip() == "":
        return []

    wanted = top_k if top_k else get_top_k()
    embedding = embed_query(question)
    store = get_collection(INDEX_DIR)
    # Over-fetch so the per-source filter still has enough to choose from.
    pool = search(store, embedding, candidate_count(wanted))
    ranked = rerank_with_keyword_overlap(question, pool)
    return keep_best_chunks_per_source(ranked, wanted, max_per_source=1)
def retrieve_for_generation(question: str, top_k: int | None = None) -> list[dict]:
    """Return a wider set of chunks for *question*, up to three per source.

    Follows the same embed, search, rerank and filter steps as ``retrieve``,
    but works with twice the requested limit and allows up to three chunks
    from the same source.

    Args:
        question: Free-text query. Blank or whitespace-only input yields ``[]``
            without touching the embedder or the index.
        top_k: Base result count. Any falsy value (``None`` or ``0``) falls
            back to ``get_top_k()``. The effective limit is double this value.

    Returns:
        Up to ``2 * top_k`` result dicts, in reranked order.
    """
    if question.strip() == "":
        return []

    base_limit = top_k if top_k else get_top_k()
    widened_limit = base_limit * 2
    embedding = embed_query(question)
    store = get_collection(INDEX_DIR)
    pool = search(store, embedding, candidate_count(widened_limit))
    ranked = rerank_with_keyword_overlap(question, pool)
    return keep_best_chunks_per_source(ranked, widened_limit, max_per_source=3)
def keep_best_chunks_per_source(
    results: list[dict], limit: int, max_per_source: int
) -> list[dict]:
    """Select results in order, capping how many come from each source.

    Args:
        results: Result dicts, already ordered best-first. Each must carry
            ``result["metadata"]["source"]``.
        limit: Maximum number of results to return. A value of zero or less
            returns an empty list.
        max_per_source: Maximum number of results kept for any one source.

    Returns:
        A new list with at most ``limit`` results, preserving input order.

    Raises:
        KeyError: If a result lacks ``"metadata"`` or ``"source"``.
    """
    # Without this guard the loop would append one result before ever
    # comparing against the limit, so a non-positive limit returned one item.
    if limit <= 0:
        return []

    selected: list[dict] = []
    source_counts: dict = {}
    for result in results:
        source = result["metadata"]["source"]
        current_count = source_counts.get(source, 0)
        if current_count >= max_per_source:
            continue
        selected.append(result)
        source_counts[source] = current_count + 1
        if len(selected) >= limit:
            break
    return selected
def candidate_count(limit: int) -> int:
    """Return how many candidates to fetch for a final result count of *limit*.

    The result is four times *limit*, but never less than 40.
    """
    oversampled = limit * 4
    return max(oversampled, 40)
def rerank_with_keyword_overlap(question: str, results: list[dict]) -> list[dict]:
    """Order *results* by distance adjusted for keyword overlap with *question*.

    Args:
        question: The user's query; its terms come from ``extract_terms``.
        results: Search results to reorder.

    Returns:
        A new list sorted ascending by ``rerank_score``. If the question
        yields no usable terms, *results* itself is returned unchanged.
    """
    terms = extract_terms(question)
    if terms:
        reordered = list(results)
        # list.sort is stable, so results with equal scores keep their order.
        reordered.sort(key=lambda item: rerank_score(item, terms))
        return reordered
    return results
def rerank_score(result: dict, question_terms: set[str]) -> float:
    """Return the sort key for one result: its distance minus a keyword boost.

    The boost is the fraction of *question_terms* found in the result's text,
    scaled by ``MAX_KEYWORD_BOOST``. The score can therefore drop below the
    raw distance by at most ``MAX_KEYWORD_BOOST``.

    Args:
        result: Result dict providing ``"text"`` and ``"distance"``.
        question_terms: Terms extracted from the question.

    Returns:
        The adjusted score, or the raw distance when *question_terms* is empty.

    Raises:
        KeyError: If *result* lacks ``"distance"``, or ``"text"`` when it is
            needed.
    """
    distance = result["distance"]
    if not question_terms:
        # With no terms there is nothing to match, and the ratio below would
        # divide by zero; leave the distance as it is.
        return distance

    document_terms = extract_terms(result["text"])
    overlap_count = len(question_terms.intersection(document_terms))
    overlap_ratio = overlap_count / len(question_terms)
    keyword_boost = overlap_ratio * MAX_KEYWORD_BOOST
    return distance - keyword_boost
def extract_terms(text: str) -> set[str]:
    """Return the distinct lower-cased terms in *text* used for matching.

    A term is a run of ASCII letters, digits and hyphens that starts with a
    letter or digit. Terms of one or two characters and terms listed in
    ``STOP_WORDS`` are dropped.
    """
    terms = set()
    for token in re.findall(r"[a-zA-Z0-9][a-zA-Z0-9-]*", text.lower()):
        if len(token) <= 2 or token in STOP_WORDS:
            continue
        terms.add(token)
    return terms