GigaProjects

← Back to rag-assistant

retrieval.py

"""Retrieve relevant chunks from ChromaDB."""

from __future__ import annotations

import re

from src.config import INDEX_DIR, get_top_k
from src.embeddings import embed_query
from src.index_store import get_collection, search


# Common English function words ignored when comparing question and chunk
# vocabularies. Note: extract_terms separately drops any term of two
# characters or fewer, so the short entries here are belt-and-braces.
STOP_WORDS = set(
    (
        "a and are as at for from had in of on or should the to "
        "was what which who why with"
    ).split()
)

# Largest amount subtracted from a candidate's distance, reached when every
# question term also appears in the chunk text (see rerank_score).
MAX_KEYWORD_BOOST = 0.4


def _retrieve_diverse(
    question: str,
    top_k: int | None,
    *,
    pool_scale: int,
    max_per_source: int,
) -> list[dict]:
    """Run the shared retrieval pipeline: embed, search, rerank, diversify.

    Args:
        question: Natural-language query; blank input yields no results.
        top_k: Requested result count; ``None`` means ``get_top_k()``.
        pool_scale: Multiplier applied to the resolved limit to get the
            number of chunks returned.
        max_per_source: Cap on chunks taken from any single source.

    Returns:
        Up to ``limit * pool_scale`` result dicts, ordered best-first.
    """
    if not question.strip():
        return []

    # Only ``None`` selects the configured default; an explicit 0 (or a
    # negative value) means "no results" rather than being silently replaced.
    limit = get_top_k() if top_k is None else top_k
    if limit <= 0:
        return []

    pool_size = limit * pool_scale
    query_embedding = embed_query(question)
    collection = get_collection(INDEX_DIR)
    # Over-fetch so reranking and per-source capping still have enough
    # candidates left to fill pool_size.
    candidates = search(collection, query_embedding, candidate_count(pool_size))
    candidates = rerank_with_keyword_overlap(question, candidates)
    return keep_best_chunks_per_source(
        candidates, pool_size, max_per_source=max_per_source
    )


def retrieve(question: str, top_k: int | None = None) -> list[dict]:
    """Return the best-matching chunks for *question*, one per source.

    Args:
        question: Natural-language query; blank input yields ``[]``.
        top_k: Number of chunks to return; ``None`` uses ``get_top_k()``.
            Zero or negative values yield ``[]``.

    Returns:
        Up to ``top_k`` result dicts ordered by keyword-adjusted distance.
    """
    return _retrieve_diverse(question, top_k, pool_scale=1, max_per_source=1)


def retrieve_for_generation(question: str, top_k: int | None = None) -> list[dict]:
    """Return a wider context set for answer generation.

    Fetches twice the requested number of chunks and allows up to three
    chunks from the same source, so the generator sees more surrounding
    context than :func:`retrieve` returns.

    Args:
        question: Natural-language query; blank input yields ``[]``.
        top_k: Base chunk count; ``None`` uses ``get_top_k()``. Zero or
            negative values yield ``[]``.

    Returns:
        Up to ``2 * top_k`` result dicts ordered by keyword-adjusted distance.
    """
    return _retrieve_diverse(question, top_k, pool_scale=2, max_per_source=3)


def keep_best_chunks_per_source(
    results: list[dict], limit: int, max_per_source: int
) -> list[dict]:
    """Select results in order while capping how many come from each source.

    Results are consumed in the order given, so callers should pass them
    best-first. A result is skipped once its ``result["metadata"]["source"]``
    has already contributed ``max_per_source`` entries.

    Args:
        results: Candidate dicts, each with ``["metadata"]["source"]``.
        limit: Maximum number of results to return; ``<= 0`` returns ``[]``.
        max_per_source: Maximum results kept per distinct source.

    Returns:
        At most ``limit`` results, preserving input order.
    """
    selected: list[dict] = []
    # Without this guard the first result was appended before the limit
    # check ran, so a zero/negative limit still returned one chunk.
    if limit <= 0:
        return selected

    source_counts: dict = {}
    for result in results:
        source = result["metadata"]["source"]
        current_count = source_counts.get(source, 0)
        if current_count >= max_per_source:
            continue

        selected.append(result)
        source_counts[source] = current_count + 1
        if len(selected) >= limit:
            break

    return selected


def candidate_count(limit: int) -> int:
    """Return how many nearest neighbours to fetch before reranking.

    Four candidates per requested result, but never fewer than 40.
    """
    floor = 40
    oversampled = 4 * limit
    return oversampled if oversampled > floor else floor


def rerank_with_keyword_overlap(question: str, results: list[dict]) -> list[dict]:
    """Reorder candidates by distance adjusted for keyword overlap.

    Each result is scored with ``rerank_score`` against the question's terms
    and the list is sorted ascending (lower score is better). ``sorted`` is
    stable, so ties keep their incoming order. When the question yields no
    usable terms, the input list itself is returned untouched.
    """
    terms = extract_terms(question)
    if terms:
        def score(candidate: dict) -> float:
            return rerank_score(candidate, terms)

        return sorted(results, key=score)
    return results


def rerank_score(result: dict, question_terms: set[str]) -> float:
    """Return the candidate's distance reduced by a keyword-overlap bonus.

    Lower is better. The bonus is ``MAX_KEYWORD_BOOST`` times the fraction of
    ``question_terms`` that also appear in ``extract_terms(result["text"])``.

    Args:
        result: Candidate dict with ``"text"`` and ``"distance"`` keys.
        question_terms: Terms extracted from the user's question.

    Returns:
        ``result["distance"] - bonus``. When ``question_terms`` is empty there
        is nothing to match, so the raw distance is returned; previously this
        raised ``ZeroDivisionError``.
    """
    distance = result["distance"]
    if not question_terms:
        return distance

    document_terms = extract_terms(result["text"])
    overlap_count = len(question_terms.intersection(document_terms))
    overlap_ratio = overlap_count / len(question_terms)
    # NOTE(review): the bonus is subtracted on the same scale as the raw
    # vector distance; confirm 0..MAX_KEYWORD_BOOST is meaningful for the
    # collection's distance metric.
    keyword_boost = overlap_ratio * MAX_KEYWORD_BOOST
    return distance - keyword_boost


def extract_terms(text: str) -> set[str]:
    """Return the distinct content terms in *text* used for keyword matching.

    The text is case-folded and split into runs of letters and digits.
    Hyphenated compounds such as "state-of-the-art" stay a single term.
    Terms in ``STOP_WORDS`` and terms of two characters or fewer are dropped.

    Args:
        text: Arbitrary text (a question or a chunk body).

    Returns:
        A set of case-folded terms; empty when nothing qualifies.
    """
    # [^\W_] means "a Unicode letter or digit": a word character other than
    # the underscore. Accented and non-Latin words stay intact instead of
    # being cut at the first non-ASCII character. Hyphens are only allowed
    # between alphanumeric runs, so a dangling "foo-" still yields "foo".
    words = re.findall(r"[^\W_]+(?:-[^\W_]+)*", text.casefold())
    return {word for word in words if len(word) > 2 and word not in STOP_WORDS}

Run this code