GigaProjects

← Back to rag-assistant

run_eval.py

"""Run retrieval and answer evaluation."""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

# Put the project root (two levels up from this file's resolved path, i.e. the
# parent of this script's directory) at the front of sys.path so the `src.*`
# imports below resolve when the script is executed directly.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
sys.path.insert(0, str(PROJECT_ROOT))

from src.config import (
    EVAL_FILE,
    EVAL_RESULTS_DIR,
    get_judge_model,
    get_openrouter_api_key,
    get_top_k,
    load_env_file,
)
from src.evaluation import load_eval_questions
from src.generation import create_openrouter_client
from src.rag import answer_question
from src.retrieval import retrieve


def main() -> None:
    """CLI entry point: load env config, parse flags, and run the selected evaluation.

    Exactly one mode runs: ``--retrieval-only`` (source coverage only) or
    ``--answers`` (generate answers and grade them with the judge model).
    When neither flag is given, a hint is printed to stderr and the process
    exits with status 2 (argparse's usage-error convention). This keeps
    scripted/CI invocations from mistaking a no-op run for a successful
    evaluation.
    """
    # Loaded before parsing, presumably so that get_top_k() (used as the
    # --top-k default in parse_args) can see values from the env file.
    load_env_file()
    args = parse_args()

    if args.retrieval_only:
        run_retrieval_eval(args.top_k)
        return

    if args.answers:
        run_answer_eval(args.top_k, args.limit)
        return

    print("Choose --retrieval-only or --answers.", file=sys.stderr)
    sys.exit(2)


def parse_args() -> argparse.Namespace:
    """Parse command-line options for the evaluation script.

    Options:
        --retrieval-only / --answers: evaluation mode (mutually exclusive).
        --top-k: chunks retrieved per question; defaults to get_top_k().
        --limit: evaluate only the first N questions (answer mode only).

    Both numeric options must be positive integers. Otherwise argparse reports
    a usage error instead of the run silently misbehaving (a negative limit
    would drop questions from the end of the list).
    """

    def positive_int(raw: str) -> int:
        # int() raising ValueError is turned into a usage error by argparse.
        value = int(raw)
        if value < 1:
            raise argparse.ArgumentTypeError(f"must be a positive integer, got {raw!r}")
        return value

    parser = argparse.ArgumentParser(description="Run RAG evaluation.")
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument(
        "--retrieval-only",
        action="store_true",
        help="Score retrieved sources against each question's expected sources.",
    )
    mode.add_argument(
        "--answers",
        action="store_true",
        help="Generate answers and grade them with the judge model.",
    )
    parser.add_argument(
        "--top-k",
        type=positive_int,
        default=get_top_k(),
        help="Number of chunks to retrieve per question (default from config).",
    )
    parser.add_argument(
        "--limit",
        type=positive_int,
        default=None,
        help="Only evaluate the first N questions (answer mode only).",
    )
    return parser.parse_args()


def run_retrieval_eval(top_k: int) -> None:
    """Score retrieval for every eval question, print rows plus a summary, and save JSON.

    Args:
        top_k: Number of chunks requested from ``retrieve`` for each question.
    """
    scored: list[dict] = []
    for item in load_eval_questions(EVAL_FILE):
        chunks = retrieve(item["question"], top_k=top_k)
        row = score_retrieval(item, chunks)
        scored.append(row)
        print_result(row)

    print_summary(scored)
    output = build_retrieval_output(top_k, scored)
    save_results("retrieval_eval_latest.json", output)


def run_answer_eval(top_k: int, limit: int | None = None) -> None:
    """Generate an answer per eval question, grade it with the LLM judge, and save results.

    Args:
        top_k: Number of chunks retrieved per question.
        limit: When a positive integer, only the first ``limit`` questions are
            evaluated. ``None`` or a non-positive value evaluates all of them.

    Exits the process with status 1 (message on stderr) when no OpenRouter API
    key is configured. A missing key makes the evaluation impossible, so it
    must not look like a successful run.
    """
    api_key = get_openrouter_api_key()
    if not api_key:
        # sys.exit(str) writes the message to stderr and exits with status 1.
        sys.exit("OPENROUTER_API_KEY is required for answer evaluation.")

    questions = load_eval_questions(EVAL_FILE)
    # A negative limit would slice from the end and silently drop questions;
    # only a positive limit truncates (0 keeps its previous "no limit" meaning).
    if limit is not None and limit > 0:
        questions = questions[:limit]

    judge_client = create_openrouter_client(api_key)
    results = []

    for question in questions:
        print(f"Evaluating {question['id']}...", flush=True)
        result = evaluate_answer(question, top_k, judge_client)
        results.append(result)
        print_answer_result(result)

    print_answer_summary(results)
    save_results("answer_eval_latest.json", build_answer_output(top_k, results))


def evaluate_answer(question: dict, top_k: int, judge_client) -> dict:
    """Run the RAG pipeline for one eval question and grade the answer with the judge.

    Args:
        question: Eval question record (uses "id", "category", "question",
            plus the fields read by score_retrieval and the judge prompt).
        top_k: Number of chunks retrieved for the question.
        judge_client: OpenRouter client used by judge_answer.

    Returns:
        A flat result row with the question id/category, the retrieval status
        from score_retrieval, the generated answer, the de-duplicated retrieved
        sources, and the judge verdict.

    If the RAG pipeline raises, the row records a failed verdict instead of
    propagating. One transient API error would otherwise abort the whole run
    and lose every result gathered so far (judge_answer already treats judge
    API errors the same way).
    """
    try:
        rag_result = answer_question(question["question"], top_k=top_k)
    except Exception as exc:
        return {
            "id": question["id"],
            "category": question["category"],
            # No chunks were obtained; the summaries only know
            # pass/partial/fail/review, so this counts as a retrieval failure.
            "retrieval_status": "fail",
            "answer": "",
            "retrieved_sources": [],
            "judge": failed_judge_result(f"RAG pipeline error: {type(exc).__name__}"),
        }

    retrieval_result = score_retrieval(question, rag_result["retrieved_chunks"])
    retrieved_sources = retrieval_result["retrieved_sources"]
    judge_result = judge_answer(question, rag_result["answer"], retrieved_sources, judge_client)

    return {
        "id": question["id"],
        "category": question["category"],
        "retrieval_status": retrieval_result["status"],
        "answer": rag_result["answer"],
        "retrieved_sources": retrieved_sources,
        "judge": judge_result,
    }


def score_retrieval(question: dict, retrieved: list[dict]) -> dict:
    """Compare the sources of retrieved chunks with a question's expected sources.

    Retrieved sources are de-duplicated in first-seen order. Questions flagged
    ``must_refuse`` get status "review" (they need manual review). Otherwise
    the status is "pass" when every expected source was retrieved, "fail" when
    none was, and "partial" in between.

    Returns:
        dict with keys id, category, status, expected_sources,
        retrieved_sources (unique) and missing_sources.
    """
    expected = question["expected_sources"]
    unique_sources = list({chunk["metadata"]["source"]: None for chunk in retrieved})

    if question["must_refuse"]:
        status, missing = "review", []
    else:
        missing = [src for src in expected if src not in unique_sources]
        if not missing:
            status = "pass"
        elif len(missing) == len(expected):
            status = "fail"
        else:
            status = "partial"

    return {
        "id": question["id"],
        "category": question["category"],
        "status": status,
        "expected_sources": expected,
        "retrieved_sources": unique_sources,
        "missing_sources": missing,
    }


def judge_answer(
    question: dict, actual_answer: str, retrieved_sources: list[str], judge_client
) -> dict:
    """Ask the judge model to grade ``actual_answer`` for ``question``.

    Args:
        question: Eval question record used to build the judge prompt.
        actual_answer: Answer produced by the RAG assistant.
        retrieved_sources: De-duplicated sources the assistant retrieved.
        judge_client: OpenAI-compatible client (``chat.completions.create``).

    Returns:
        A normalized verdict dict. These problems become a failed verdict
        whose "reason" explains what went wrong, rather than an exception:
        API errors, a response with no choices, empty content, and
        unparseable JSON.
    """
    prompt = build_judge_prompt(question, actual_answer, retrieved_sources)

    try:
        response = judge_client.chat.completions.create(
            model=get_judge_model(),
            messages=[
                {
                    "role": "system",
                    "content": "You are a strict evaluator for a RAG assistant.",
                },
                {"role": "user", "content": prompt},
            ],
            temperature=0,
            response_format={"type": "json_object"},
        )
    except Exception as exc:
        # Only the exception type is reported, keeping messages short and
        # free of request details.
        return failed_judge_result(f"Judge API error: {type(exc).__name__}")

    # Upstream providers can return an empty or missing choices list
    # (e.g. on moderation/provider errors); indexing it directly would crash the run.
    choices = getattr(response, "choices", None)
    if not choices:
        return failed_judge_result("Judge returned no choices.")

    message = getattr(choices[0], "message", None)
    content = getattr(message, "content", None)
    if not content:
        return failed_judge_result("Judge returned empty content.")

    return parse_judge_json(content)


def build_judge_prompt(
    question: dict, actual_answer: str, retrieved_sources: list[str]
) -> str:
    """Build the user prompt asking the judge to grade one answer.

    The prompt contains the required JSON response shape, the grading rules,
    and a pretty-printed JSON payload. The payload holds the question's
    reference data, the retrieved sources and the assistant's actual answer.
    """
    copied_fields = (
        "category",
        "question",
        "expected_answer",
        "key_facts",
        "expected_sources",
        "must_refuse",
    )
    # Insertion order defines the key order of the serialized payload.
    payload = {"question_id": question["id"]}
    payload.update((field, question[field]) for field in copied_fields)
    payload["retrieved_sources"] = retrieved_sources
    payload["actual_answer"] = actual_answer

    instructions = """\
Evaluate the RAG assistant answer.

Return only valid JSON with this exact shape:
{
  "correctness": "pass|partial|fail",
  "grounding": "pass|partial|fail",
  "sources": "pass|partial|fail",
  "refusal": "pass|fail|not_applicable",
  "overall": "pass|partial|fail",
  "reason": "one short sentence"
}

Rules:
- Correctness checks whether the answer matches the expected answer and key facts.
- Grounding checks whether the answer is supported by retrieved sources.
- Sources checks whether citations are appropriate and not invented.
- For refusal answers, "Sources: none" is appropriate when no source directly supports the requested fact.
- Refusal is pass only when must_refuse is true and the answer refuses to invent.
- If must_refuse is false, refusal must be "not_applicable".
- Overall should be fail for serious hallucination, wrong final answer, or bad refusal.
- Overall should be partial when the answer is mostly right but incomplete.

Evaluation payload:
"""
    return (instructions + json.dumps(payload, indent=2)).strip()


def parse_judge_json(content: str) -> dict:
    """Parse the judge's raw reply into a normalized verdict dict.

    First the whole reply is decoded as JSON. If that fails or does not yield
    a JSON object, the span from the first "{" to the last "}" is tried
    instead, which tolerates prose or code fences around the object.
    Unparseable replies become a failed verdict whose "reason" says why; this
    function does not raise for bad judge output.
    """
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        parsed = None
    # Valid JSON that is not an object (list, string, number, null) used to
    # reach normalize_judge_result and crash on `.get`; only dicts go through.
    if isinstance(parsed, dict):
        return normalize_judge_result(parsed)

    start = content.find("{")
    end = content.rfind("}")
    if start == -1 or end == -1 or end <= start:
        return failed_judge_result("Judge did not return JSON.")

    try:
        # Text that starts with "{" and ends with "}" can only decode to a dict.
        extracted = json.loads(content[start : end + 1])
    except json.JSONDecodeError:
        return failed_judge_result("Judge returned invalid JSON.")
    return normalize_judge_result(extracted)


def normalize_judge_result(result: dict) -> dict:
    """Coerce a raw judge JSON object into the fixed verdict schema.

    Status fields outside their allowed values (or missing) become "fail".
    A missing or null "reason" becomes an empty string.
    """
    graded = {"pass", "partial", "fail"}
    reason = result.get("reason")
    return {
        "correctness": clean_status(result.get("correctness"), graded),
        "grounding": clean_status(result.get("grounding"), graded),
        "sources": clean_status(result.get("sources"), graded),
        "refusal": clean_status(
            result.get("refusal"), {"pass", "fail", "not_applicable"}
        ),
        "overall": clean_status(result.get("overall"), graded),
        # JSON null would otherwise be stringified to the text "None".
        "reason": "" if reason is None else str(reason).strip(),
    }


def clean_status(value: object, allowed: set[str]) -> str:
    """Return ``value`` trimmed and lower-cased if it is in ``allowed``, else "fail".

    Spaces and hyphens are folded to underscores so spellings such as
    "Not Applicable" or "not-applicable" match "not_applicable".
    """
    if not isinstance(value, str):
        return "fail"

    normalized = value.strip().lower().replace("-", "_").replace(" ", "_")
    return normalized if normalized in allowed else "fail"


def failed_judge_result(reason: str) -> dict:
    """Return a verdict with every status set to "fail" and the given reason."""
    return {
        "correctness": "fail",
        "grounding": "fail",
        "sources": "fail",
        "refusal": "fail",
        "overall": "fail",
        "reason": reason,
    }


def print_result(result: dict) -> None:
    """Print one retrieval-eval row: a status header, then expected, retrieved and missing sources."""
    expected = result["expected_sources"]
    expected_text = ", ".join(expected) if expected else "no specific source"

    lines = [
        f"{result['id']} [{result['category']}]: {result['status']}",
        f"  expected: {expected_text}",
        f"  retrieved: {', '.join(result['retrieved_sources'])}",
    ]
    missing = result["missing_sources"]
    if missing:
        lines.append(f"  missing: {', '.join(missing)}")

    print("\n".join(lines))


def print_answer_result(result: dict) -> None:
    """Print one answer-eval row: the overall verdict, the retrieval status, then each judge field."""
    verdict = result["judge"]
    print(f"{result['id']} [{result['category']}]: {verdict['overall']}")
    print(f"  retrieval: {result['retrieval_status']}")
    for field in ("correctness", "grounding", "sources", "refusal", "reason"):
        print(f"  {field}: {verdict[field]}")


def print_summary(results: list[dict]) -> None:
    """Print per-status totals for a retrieval evaluation run.

    Raises KeyError for a status outside pass/partial/fail/review.
    """
    tally = dict.fromkeys(("pass", "partial", "fail", "review"), 0)
    for entry in results:
        tally[entry["status"]] += 1

    report = ["", "Retrieval summary"]
    report.extend(f"  {label}: {tally[label]}" for label in ("pass", "partial", "fail"))
    report.append(f"  review: {tally['review']} (unanswerable questions need manual review)")
    print("\n".join(report))


def print_answer_summary(results: list[dict]) -> None:
    """Print totals of judge verdicts and of retrieval statuses for an answer-eval run.

    Raises KeyError for a verdict or retrieval status outside the known sets.
    """
    answer_labels = ("pass", "partial", "fail")
    retrieval_labels = ("pass", "partial", "fail", "review")

    answer_tally = dict.fromkeys(answer_labels, 0)
    for entry in results:
        answer_tally[entry["judge"]["overall"]] += 1

    retrieval_tally = dict.fromkeys(retrieval_labels, 0)
    for entry in results:
        retrieval_tally[entry["retrieval_status"]] += 1

    print()
    print("Answer evaluation summary")
    for label in answer_labels:
        print(f"  {label}: {answer_tally[label]}")
    print()
    print("Retrieval coverage inside answer eval")
    for label in retrieval_labels:
        print(f"  {label}: {retrieval_tally[label]}")


def build_retrieval_output(top_k: int, results: list[dict]) -> dict:
    """Assemble the JSON document saved by the retrieval evaluation (mode, top_k, summary, rows)."""
    summary = count_statuses(results, "status", ["pass", "partial", "fail", "review"])
    return {"mode": "retrieval", "top_k": top_k, "summary": summary, "results": results}


def build_answer_output(top_k: int, results: list[dict]) -> dict:
    """Assemble the JSON document saved by the answer evaluation.

    It includes the run configuration (mode, top_k, judge model), the
    judge-verdict totals, the retrieval-status totals, and the per-question rows.
    """
    judge_model = get_judge_model()
    answer_summary = count_judge_overall(results)
    retrieval_summary = count_statuses(
        results, "retrieval_status", ["pass", "partial", "fail", "review"]
    )
    return {
        "mode": "answers",
        "top_k": top_k,
        "judge_model": judge_model,
        "answer_summary": answer_summary,
        "retrieval_summary": retrieval_summary,
        "results": results,
    }


def count_statuses(results: list[dict], key: str, statuses: list[str]) -> dict:
    """Count how many rows have each of ``statuses`` under ``results[i][key]``.

    The returned dict lists every status in ``statuses`` order, including
    zero counts. A value outside ``statuses`` raises KeyError.
    """
    tally = dict.fromkeys(statuses, 0)
    for value in (row[key] for row in results):
        tally[value] += 1
    return tally


def count_judge_overall(results: list[dict]) -> dict:
    """Count judge "overall" verdicts (pass/partial/fail) across answer-eval rows.

    A verdict outside those three values raises KeyError.
    """
    tally = dict.fromkeys(("pass", "partial", "fail"), 0)
    for row in results:
        tally[row["judge"]["overall"]] += 1
    return tally


def save_results(filename: str, payload: dict) -> None:
    """Write ``payload`` as indented JSON to EVAL_RESULTS_DIR/filename and print where it went.

    The printed path is relative to PROJECT_ROOT when the output lives inside
    the project. Otherwise the path itself is printed; previously this case
    raised ValueError after the file had already been written.
    """
    EVAL_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    path = EVAL_RESULTS_DIR / filename
    path.write_text(json.dumps(payload, indent=2), encoding="utf-8")

    try:
        # PROJECT_ROOT is a resolved path, so resolve the output path before comparing.
        display_path = path.resolve().relative_to(PROJECT_ROOT)
    except ValueError:
        display_path = path

    print()
    print(f"Saved results to {display_path}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

Run this code