"""Run retrieval and answer evaluation."""
from __future__ import annotations
import argparse
import json
import sys
from pathlib import Path
# Directory one level above the folder that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parents[1]
# Put that directory first on sys.path so the `src.*` imports that follow
# resolve against it.
sys.path.insert(0, str(PROJECT_ROOT))
from src.config import (
EVAL_FILE,
EVAL_RESULTS_DIR,
get_judge_model,
get_openrouter_api_key,
get_top_k,
load_env_file,
)
from src.evaluation import load_eval_questions
from src.generation import create_openrouter_client
from src.rag import answer_question
from src.retrieval import retrieve
def main() -> None:
    """Load the environment, parse CLI flags and run the selected evaluation.

    ``--retrieval-only`` takes precedence when both mode flags are given.
    When neither is given, a usage hint is printed and nothing is run.
    """
    # Environment is loaded before parsing because parse_args() computes the
    # --top-k default via get_top_k() at parser-construction time.
    load_env_file()
    options = parse_args()

    if options.retrieval_only:
        run_retrieval_eval(options.top_k)
    elif options.answers:
        run_answer_eval(options.top_k, options.limit)
    else:
        print("Choose --retrieval-only or --answers.")
def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        Namespace with ``retrieval_only`` and ``answers`` (booleans),
        ``top_k`` (int, defaulting to ``get_top_k()``) and ``limit``
        (int or ``None``).
    """
    cli = argparse.ArgumentParser(description="Run RAG evaluation.")
    # The two mode switches are plain boolean flags.
    for mode_flag in ("--retrieval-only", "--answers"):
        cli.add_argument(mode_flag, action="store_true")
    cli.add_argument("--top-k", type=int, default=get_top_k())
    cli.add_argument("--limit", type=int, default=None)
    return cli.parse_args()
def run_retrieval_eval(top_k: int) -> None:
    """Score retrieval for every evaluation question and save the report.

    Each question is retrieved with ``top_k`` results, scored against its
    expected sources and printed as it completes. A summary is printed at
    the end and the report is written to ``retrieval_eval_latest.json``.

    Args:
        top_k: Number of chunks requested from ``retrieve`` per question.
    """
    scored = []
    for item in load_eval_questions(EVAL_FILE):
        hits = retrieve(item["question"], top_k=top_k)
        outcome = score_retrieval(item, hits)
        scored.append(outcome)
        print_result(outcome)

    print_summary(scored)
    report = build_retrieval_output(top_k, scored)
    save_results("retrieval_eval_latest.json", report)
def run_answer_eval(top_k: int, limit: int | None = None) -> None:
    """Generate and judge answers for the evaluation questions.

    Requires an OpenRouter API key; without one a message is printed and
    nothing is evaluated. Each question is answered, judged and printed as it
    completes. A summary is printed at the end and the report is written to
    ``answer_eval_latest.json``.

    Args:
        top_k: Number of chunks passed through to answer generation.
        limit: Maximum number of questions to evaluate, taken from the start
            of the question list. ``None`` evaluates all of them; ``0`` (or a
            negative value) evaluates none.
    """
    api_key = get_openrouter_api_key()
    if not api_key:
        print("OPENROUTER_API_KEY is required for answer evaluation.")
        return

    questions = load_eval_questions(EVAL_FILE)
    # Compare against None explicitly: a plain truthiness check would treat
    # an explicit limit of 0 as "no limit". Clamp at 0 so a negative value
    # cannot turn into a "drop from the end" slice.
    if limit is not None:
        questions = questions[: max(limit, 0)]

    judge_client = create_openrouter_client(api_key)
    results = []
    for question in questions:
        # Flush so progress shows up immediately while the model calls run.
        print(f"Evaluating {question['id']}...", flush=True)
        result = evaluate_answer(question, top_k, judge_client)
        results.append(result)
        print_answer_result(result)

    print_answer_summary(results)
    save_results("answer_eval_latest.json", build_answer_output(top_k, results))
def evaluate_answer(question: dict, top_k: int, judge_client) -> dict:
    """Answer one question with the RAG pipeline and have the judge grade it.

    Args:
        question: Evaluation question record; ``id``, ``category`` and
            ``question`` are read here, the rest by the scoring helpers.
        top_k: Number of chunks passed to ``answer_question``.
        judge_client: Client forwarded to ``judge_answer``.

    Returns:
        Dict with the question id and category, the retrieval status, the
        generated answer, the de-duplicated retrieved sources and the judge
        verdict.
    """
    rag_output = answer_question(question["question"], top_k=top_k)

    # Retrieval is scored on the chunks the pipeline actually used.
    retrieval = score_retrieval(question, rag_output["retrieved_chunks"])
    sources = retrieval["retrieved_sources"]

    answer_text = rag_output["answer"]
    verdict = judge_answer(question, answer_text, sources, judge_client)

    return {
        "id": question["id"],
        "category": question["category"],
        "retrieval_status": retrieval["status"],
        "answer": answer_text,
        "retrieved_sources": sources,
        "judge": verdict,
    }
def score_retrieval(question: dict, retrieved: list[dict]) -> dict:
    """Compare the retrieved sources with the sources a question expects.

    Args:
        question: Record providing ``id``, ``category``, ``expected_sources``
            and ``must_refuse``.
        retrieved: Retrieved items, each carrying ``item["metadata"]["source"]``.

    Returns:
        Dict with ``id``, ``category``, ``status``, ``expected_sources``,
        ``retrieved_sources`` (de-duplicated, first-seen order) and
        ``missing_sources``. ``status`` is ``"review"`` for must-refuse
        questions, otherwise ``"pass"`` (nothing missing), ``"partial"``
        (some expected sources found) or ``"fail"`` (none found).
    """
    expected = question["expected_sources"]

    # A dict keeps first-seen order while dropping repeated sources.
    seen: dict = {}
    for chunk in retrieved:
        seen.setdefault(chunk["metadata"]["source"], None)
    found = list(seen)

    if question["must_refuse"]:
        # Must-refuse questions are not auto-graded on retrieval.
        status = "review"
        missing = []
    else:
        found_lookup = set(found)
        missing = [name for name in expected if name not in found_lookup]
        if not missing:
            status = "pass"
        elif len(missing) == len(expected):
            status = "fail"
        else:
            status = "partial"

    return {
        "id": question["id"],
        "category": question["category"],
        "status": status,
        "expected_sources": expected,
        "retrieved_sources": found,
        "missing_sources": missing,
    }
def judge_answer(
    question: dict, actual_answer: str, retrieved_sources: list[str], judge_client
) -> dict:
    """Ask the judge model to grade one generated answer.

    Args:
        question: Evaluation question record used to build the judge prompt.
        actual_answer: Answer text produced by the RAG pipeline.
        retrieved_sources: Source names retrieved for the question.
        judge_client: Client exposing ``chat.completions.create``.

    Returns:
        The normalized verdict dict from ``parse_judge_json``. If the API
        call raises, or the response has no choices, no message or empty
        content, an all-``"fail"`` verdict from ``failed_judge_result`` is
        returned with the cause in ``reason``. This function does not raise
        for a bad judge response, so one failure cannot abort the run.
    """
    prompt = build_judge_prompt(question, actual_answer, retrieved_sources)
    try:
        response = judge_client.chat.completions.create(
            model=get_judge_model(),
            messages=[
                {
                    "role": "system",
                    "content": "You are a strict evaluator for a RAG assistant.",
                },
                {"role": "user", "content": prompt},
            ],
            temperature=0,
            response_format={"type": "json_object"},
        )
    except Exception as exc:
        # Deliberate best-effort boundary: record the failure as a verdict
        # and keep evaluating the remaining questions.
        return failed_judge_result(f"Judge API error: {type(exc).__name__}")

    # A response can come back without usable choices (None or empty list);
    # indexing it blindly would crash outside the try block above.
    choices = getattr(response, "choices", None)
    if not choices:
        return failed_judge_result("Judge returned no choices.")

    message = getattr(choices[0], "message", None)
    content = getattr(message, "content", None)
    if not content:
        return failed_judge_result("Judge returned empty content.")
    return parse_judge_json(content)
def build_judge_prompt(
    question: dict, actual_answer: str, retrieved_sources: list[str]
) -> str:
    """Assemble the grading prompt sent to the judge model.

    The prompt is a fixed rubric (required JSON shape plus grading rules)
    followed by the evaluation payload serialized as indented JSON.

    Args:
        question: Record providing ``id``, ``category``, ``question``,
            ``expected_answer``, ``key_facts``, ``expected_sources`` and
            ``must_refuse``.
        actual_answer: Answer text to be graded.
        retrieved_sources: Source names retrieved for the question.

    Returns:
        The prompt text with leading and trailing whitespace removed.
    """
    # Insertion order determines the key order in the serialized payload.
    payload = {"question_id": question["id"]}
    for field in (
        "category",
        "question",
        "expected_answer",
        "key_facts",
        "expected_sources",
        "must_refuse",
    ):
        payload[field] = question[field]
    payload["retrieved_sources"] = retrieved_sources
    payload["actual_answer"] = actual_answer

    rubric = """
Evaluate the RAG assistant answer.
Return only valid JSON with this exact shape:
{
"correctness": "pass|partial|fail",
"grounding": "pass|partial|fail",
"sources": "pass|partial|fail",
"refusal": "pass|fail|not_applicable",
"overall": "pass|partial|fail",
"reason": "one short sentence"
}
Rules:
- Correctness checks whether the answer matches the expected answer and key facts.
- Grounding checks whether the answer is supported by retrieved sources.
- Sources checks whether citations are appropriate and not invented.
- For refusal answers, "Sources: none" is appropriate when no source directly supports the requested fact.
- Refusal is pass only when must_refuse is true and the answer refuses to invent.
- If must_refuse is false, refusal must be "not_applicable".
- Overall should be fail for serious hallucination, wrong final answer, or bad refusal.
- Overall should be partial when the answer is mostly right but incomplete.
Evaluation payload:
"""
    serialized = json.dumps(payload, indent=2)
    return (rubric + serialized + "\n").strip()
def parse_judge_json(content: str) -> dict:
    """Parse the judge model's reply into a normalized verdict.

    The reply is first parsed as a whole. If that fails, or yields something
    other than a JSON object, the text between the first ``{`` and the last
    ``}`` is parsed instead, which tolerates replies that wrap the object in
    extra prose or code fences.

    Args:
        content: Raw text returned by the judge model.

    Returns:
        The normalized verdict, or an all-``"fail"`` verdict from
        ``failed_judge_result`` when no JSON object can be extracted.
    """
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError:
        parsed = None

    # Valid JSON is not necessarily an object: a bare list, string, number or
    # null has no .get() and would crash normalization. Accept only a dict
    # here and let everything else fall through to brace extraction.
    if isinstance(parsed, dict):
        return normalize_judge_result(parsed)

    start = content.find("{")
    end = content.rfind("}")
    if start == -1 or end == -1 or end <= start:
        return failed_judge_result("Judge did not return JSON.")

    try:
        # The slice starts with "{", so a successful parse is always a dict.
        extracted = json.loads(content[start : end + 1])
    except json.JSONDecodeError:
        return failed_judge_result("Judge returned invalid JSON.")
    return normalize_judge_result(extracted)
def normalize_judge_result(result: dict) -> dict:
    """Coerce a parsed judge reply into the fixed verdict schema.

    Every status field is passed through ``clean_status`` with the values
    allowed for that field. ``reason`` is converted to a string and stripped
    of surrounding whitespace, defaulting to an empty string when absent.

    Args:
        result: Mapping decoded from the judge's JSON reply.

    Returns:
        Dict with ``correctness``, ``grounding``, ``sources``, ``refusal``,
        ``overall`` and ``reason``, in that order.
    """
    graded = {"pass", "partial", "fail"}
    allowed_by_field = {
        "correctness": graded,
        "grounding": graded,
        "sources": graded,
        # Refusal has no "partial" but can be marked not applicable.
        "refusal": {"pass", "fail", "not_applicable"},
        "overall": graded,
    }
    verdict = {
        field: clean_status(result.get(field), allowed)
        for field, allowed in allowed_by_field.items()
    }
    verdict["reason"] = str(result.get("reason", "")).strip()
    return verdict
def clean_status(value: object, allowed: set[str]) -> str:
    """Normalize a status value, falling back to ``"fail"`` when invalid.

    Args:
        value: Raw status from the judge reply; may be any type.
        allowed: Accepted lowercase status strings.

    Returns:
        The stripped, lowercased value if it is a string contained in
        ``allowed``; otherwise ``"fail"``.
    """
    if isinstance(value, str):
        candidate = value.strip().lower()
        if candidate in allowed:
            return candidate
    # Non-strings and unrecognized strings are both treated as failures.
    return "fail"
def failed_judge_result(reason: str) -> dict:
    """Build a verdict in which every graded field is ``"fail"``.

    Args:
        reason: Explanation stored under the ``reason`` key.

    Returns:
        Dict with ``correctness``, ``grounding``, ``sources``, ``refusal``
        and ``overall`` set to ``"fail"``, followed by ``reason``.
    """
    graded_fields = ("correctness", "grounding", "sources", "refusal", "overall")
    verdict = dict.fromkeys(graded_fields, "fail")
    verdict["reason"] = reason
    return verdict
def print_result(result: dict) -> None:
    """Print one retrieval result: status line, expected, retrieved, missing.

    Args:
        result: Retrieval score dict as produced by ``score_retrieval``.
    """
    print(f"{result['id']} [{result['category']}]: {result['status']}")

    wanted = result["expected_sources"]
    expected_text = ", ".join(wanted) if wanted else "no specific source"
    print(f" expected: {expected_text}")

    retrieved_text = ", ".join(result["retrieved_sources"])
    print(f" retrieved: {retrieved_text}")

    # The "missing" line is only shown when something is actually missing.
    absent = result["missing_sources"]
    if absent:
        missing_text = ", ".join(absent)
        print(f" missing: {missing_text}")
def print_answer_result(result: dict) -> None:
    """Print one answer evaluation: overall verdict, then each sub-score.

    Args:
        result: Answer evaluation dict as produced by ``evaluate_answer``.
    """
    verdict = result["judge"]
    print(f"{result['id']} [{result['category']}]: {verdict['overall']}")
    print(f" retrieval: {result['retrieval_status']}")
    # The remaining lines are the judge's fields, in a fixed order.
    for field in ("correctness", "grounding", "sources", "refusal", "reason"):
        print(f" {field}: {verdict[field]}")
def print_summary(results: list[dict]) -> None:
    """Print how many retrieval results fell into each status.

    Args:
        results: Retrieval score dicts; each ``status`` must be one of
            ``pass``, ``partial``, ``fail`` or ``review`` (anything else
            raises ``KeyError``).
    """
    tally = dict.fromkeys(("pass", "partial", "fail", "review"), 0)
    for entry in results:
        tally[entry["status"]] += 1

    print()
    print("Retrieval summary")
    for label in ("pass", "partial", "fail"):
        print(f" {label}: {tally[label]}")
    print(f" review: {tally['review']} (unanswerable questions need manual review)")
def print_answer_summary(results: list[dict]) -> None:
    """Print judge-verdict counts followed by retrieval-status counts.

    Args:
        results: Answer evaluation dicts; each needs ``judge["overall"]``
            and ``retrieval_status`` holding a known status (anything else
            raises ``KeyError``).
    """
    overall_tally = dict.fromkeys(("pass", "partial", "fail"), 0)
    for entry in results:
        overall_tally[entry["judge"]["overall"]] += 1

    retrieval_tally = dict.fromkeys(("pass", "partial", "fail", "review"), 0)
    for entry in results:
        retrieval_tally[entry["retrieval_status"]] += 1

    print()
    print("Answer evaluation summary")
    for label, count in overall_tally.items():
        print(f" {label}: {count}")

    print()
    print("Retrieval coverage inside answer eval")
    for label, count in retrieval_tally.items():
        print(f" {label}: {count}")
def build_retrieval_output(top_k: int, results: list[dict]) -> dict:
    """Assemble the report saved for a retrieval-only run.

    Args:
        top_k: The ``top_k`` value the run used.
        results: Retrieval score dicts, one per question.

    Returns:
        Dict with ``mode``, ``top_k``, a per-status ``summary`` and the raw
        ``results``.
    """
    status_names = ["pass", "partial", "fail", "review"]
    summary = count_statuses(results, "status", status_names)
    return dict(mode="retrieval", top_k=top_k, summary=summary, results=results)
def build_answer_output(top_k: int, results: list[dict]) -> dict:
    """Assemble the report saved for an answer-evaluation run.

    Args:
        top_k: The ``top_k`` value the run used.
        results: Answer evaluation dicts, one per question.

    Returns:
        Dict with ``mode``, ``top_k``, ``judge_model``, ``answer_summary``
        (judge overall counts), ``retrieval_summary`` (retrieval status
        counts) and the raw ``results``.
    """
    report = {"mode": "answers", "top_k": top_k}
    report["judge_model"] = get_judge_model()
    report["answer_summary"] = count_judge_overall(results)
    report["retrieval_summary"] = count_statuses(
        results, "retrieval_status", ["pass", "partial", "fail", "review"]
    )
    report["results"] = results
    return report
def count_statuses(results: list[dict], key: str, statuses: list[str]) -> dict:
    """Count how often each status appears under ``key`` across ``results``.

    Args:
        results: Result dicts to tally.
        key: Name of the field holding the status in each result.
        statuses: Every status that may occur; each starts at zero.

    Returns:
        Dict mapping each status to its count, in ``statuses`` order.

    Raises:
        KeyError: If a result holds a status that is not in ``statuses``.
    """
    tally = dict.fromkeys(statuses, 0)
    for entry in results:
        tally[entry[key]] += 1
    return tally
def count_judge_overall(results: list[dict]) -> dict:
    """Count the judge's overall verdicts across answer evaluation results.

    Args:
        results: Result dicts, each holding ``judge["overall"]``.

    Returns:
        Dict with the ``pass``, ``partial`` and ``fail`` counts.

    Raises:
        KeyError: If an overall verdict is not one of those three values.
    """
    tally = dict.fromkeys(("pass", "partial", "fail"), 0)
    for verdict in (entry["judge"]["overall"] for entry in results):
        tally[verdict] += 1
    return tally
def save_results(filename: str, payload: dict) -> None:
    """Write an evaluation report as indented JSON and print where it went.

    The results directory is created if needed, and any existing file with
    the same name is overwritten.

    Args:
        filename: File name to create inside ``EVAL_RESULTS_DIR``.
        payload: JSON-serializable report.
    """
    EVAL_RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    path = EVAL_RESULTS_DIR / filename
    path.write_text(json.dumps(payload, indent=2), encoding="utf-8")

    # relative_to() raises ValueError when the results directory is not
    # located under PROJECT_ROOT (or is itself a relative path). The file is
    # already written by then, so fall back to the path as-is instead of
    # crashing on a purely cosmetic message.
    try:
        shown = path.relative_to(PROJECT_ROOT)
    except ValueError:
        shown = path

    print()
    print(f"Saved results to {shown}")
# Run the evaluation CLI only when this file is executed directly, not when
# it is imported as a module.
if __name__ == "__main__":
    main()