GigaProjects

← Back to rag-assistant

app.py

"""CLI for asking questions against the reinsurance corpus."""

from __future__ import annotations

import argparse

from src.config import load_env_file
from src.rag import answer_question


def main() -> None:
    """Entry point: set up, parse the command line, and run the chosen command.

    Calls ``load_env_file()`` first, then hands the parsed arguments to
    ``run_ask`` or ``run_chat`` depending on the sub-command. When no
    sub-command was given, a one-line hint is printed instead.
    """
    load_env_file()
    args = parse_args()

    # Map each sub-command name to the function that implements it.
    handlers = {"ask": run_ask, "chat": run_chat}
    handler = handlers.get(args.command)

    if handler is None:
        print("Choose a command: ask or chat")
    else:
        handler(args)


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse command-line arguments for the ``ask`` and ``chat`` sub-commands.

    Args:
        argv: Argument list to parse, without the program name. When ``None``
            (the default), argparse reads ``sys.argv[1:]``, so existing
            zero-argument callers behave exactly as before.

    Returns:
        A namespace with ``command`` (``"ask"``, ``"chat"`` or ``None`` when
        no sub-command was given), ``top_k`` (int or ``None``) and
        ``show_retrieved`` (bool). The ``ask`` sub-command additionally sets
        ``question``.
    """
    parser = argparse.ArgumentParser(description="Reinsurance RAG assistant.")
    subparsers = parser.add_subparsers(dest="command")

    ask_parser = subparsers.add_parser("ask", help="Ask one question.")
    ask_parser.add_argument("question", help="Question to answer.")

    chat_parser = subparsers.add_parser("chat", help="Ask questions interactively.")

    # Both sub-commands take the same options; register them in one place so
    # the two definitions cannot drift apart.
    for subparser in (ask_parser, chat_parser):
        subparser.add_argument(
            "--top-k",
            type=int,
            default=None,
            help="Override the top_k value used when answering.",
        )
        subparser.add_argument(
            "--show-retrieved",
            action="store_true",
            help="Also list the source and section of each retrieved chunk.",
        )

    return parser.parse_args(argv)


def run_ask(args: argparse.Namespace) -> None:
    """Answer the one question given on the command line and print the result.

    Reads ``question``, ``top_k`` and ``show_retrieved`` from ``args``.
    """
    question = args.question
    outcome = answer_question(question, top_k=args.top_k)
    print_answer(outcome, show_retrieved=args.show_retrieved)


def run_chat(args: argparse.Namespace) -> None:
    """Run an interactive question/answer loop until the user leaves.

    Each non-empty line typed at the prompt is passed to ``answer_question``
    and the result is printed with ``print_answer``. Blank lines are ignored.
    The loop ends when the user types an exit word (see ``should_exit``),
    when stdin is closed (Ctrl-D or the end of piped input), or when Ctrl-C
    is pressed at the prompt.

    Args:
        args: Parsed CLI arguments; ``top_k`` and ``show_retrieved`` are read.
    """
    print("Reinsurance RAG assistant")
    print("Type a question, or type 'exit' to stop.")

    while True:
        try:
            question = input("\nQuestion: ").strip()
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C at the prompt is a normal way to end the
            # session, so leave cleanly instead of dumping a traceback.
            print()  # finish the unfinished prompt line
            print("Goodbye.")
            return

        if should_exit(question):
            print("Goodbye.")
            return
        if not question:
            continue

        result = answer_question(question, top_k=args.top_k)
        print()
        print_answer(result, show_retrieved=args.show_retrieved)


def should_exit(text: str) -> bool:
    """Return True when ``text`` is ``exit`` or ``quit``, ignoring letter case.

    The text is compared as given; surrounding whitespace is not removed here.
    """
    lowered = text.lower()
    return lowered == "exit" or lowered == "quit"


def print_answer(result: dict, show_retrieved: bool) -> None:
    """Print the answer text and, optionally, a numbered list of its chunks.

    Args:
        result: Mapping with an ``"answer"`` entry. When ``show_retrieved`` is
            true it must also have ``"retrieved_chunks"``, an iterable of
            mappings whose ``"metadata"`` holds ``"source"`` and ``"section"``.
        show_retrieved: Whether to list the retrieved chunks after the answer.
    """
    print(result["answer"])

    if show_retrieved:
        # A leading newline separates the listing from the answer text.
        print("\nRetrieved chunks:")
        position = 0
        for entry in result["retrieved_chunks"]:
            position += 1
            details = entry["metadata"]
            source = details["source"]
            section = details["section"]
            print(f"{position}. {source} | {section}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()

Run this code