GigaProjects

← Back to rag-assistant

chunking.py

"""Split documents into retrieval chunks."""

from __future__ import annotations

import re


# Upper bound on chunk size, measured with word_count() (whitespace-separated
# tokens). Markdown documents above this size are split by heading, and
# sections still above it are split further by split_long_text(). Non-Markdown
# documents are always kept as a single chunk (see chunk_document).
MAX_WORDS_PER_CHUNK: int = 900


def chunk_documents(documents: list[dict]) -> list[dict]:
    """Chunk every document and return all chunks as one flat list.

    Chunks appear in input order: all chunks of the first document, then all
    chunks of the second, and so on.
    """
    return [
        chunk
        for document in documents
        for chunk in chunk_document(document)
    ]


def chunk_document(document: dict) -> list[dict]:
    """Chunk a single document according to its file extension.

    Markdown (``.md``) documents are delegated to ``chunk_markdown_document``;
    every other document becomes exactly one "Full document" chunk.
    """
    is_markdown = document["extension"] == ".md"
    if not is_markdown:
        return [make_chunk(document, section="Full document", text=document["text"], index=0)]

    return chunk_markdown_document(document)


def chunk_markdown_document(document: dict) -> list[dict]:
    """Chunk a Markdown document, splitting by heading only when it is too long.

    A document of at most ``MAX_WORDS_PER_CHUNK`` words becomes a single
    "Full document" chunk. Longer documents are split into heading sections:
    a section within the limit becomes one chunk titled after its heading, and
    a section over the limit is divided by ``split_long_text`` into chunks
    titled "<heading> part N" (N starting at 1).

    Chunk indices (used in the chunk id) run consecutively across the whole
    document, which is why ``len(chunks)`` is used as the next index.
    """
    text = document["text"]
    if word_count(text) <= MAX_WORDS_PER_CHUNK:
        return [make_chunk(document, section="Full document", text=text, index=0)]

    chunks = []
    for section in split_markdown_sections(text):
        section_title = section["title"]
        section_text = section["text"]

        if word_count(section_text) <= MAX_WORDS_PER_CHUNK:
            chunks.append(make_chunk(document, section_title, section_text, len(chunks)))
            continue

        parts = split_long_text(section_text)
        for part_number, part_text in enumerate(parts, start=1):
            part_title = f"{section_title} part {part_number}"
            chunks.append(make_chunk(document, part_title, part_text, len(chunks)))

    return chunks


def split_markdown_sections(text: str) -> list[dict]:
    """Split Markdown *text* into sections delimited by headings.

    Each section is a dict with ``"title"`` (the heading text, or
    "Document start" for content before the first heading) and ``"text"``
    (the heading line plus its body, stripped). Sections whose text is blank
    are dropped. Lines inside fenced code blocks (``` or ~~~) are never
    treated as headings, so comments such as ``# install deps`` in a code
    sample do not split it.

    Returns:
        The non-empty sections in order, or a single "Full document" section
        holding the original *text* when no non-blank section was found.
    """
    sections = []
    current_title = "Document start"
    current_lines = []
    # Marker of the currently open code fence (e.g. "```"), or None outside one.
    open_fence = None

    for line in text.splitlines():
        fence_match = re.match(r" {0,3}(`{3,}|~{3,})", line)
        if fence_match:
            marker = fence_match.group(1)
            if open_fence is None:
                open_fence = marker
            elif (
                # CommonMark: a closing fence uses the same character, is at
                # least as long as the opener, and has no info string.
                marker[0] == open_fence[0]
                and len(marker) >= len(open_fence)
                and not line[fence_match.end():].strip()
            ):
                open_fence = None
            current_lines.append(line)
            continue

        heading = parse_heading(line) if open_fence is None else None
        if heading:
            add_section(sections, current_title, current_lines)
            current_title = heading
            current_lines = [line]
            continue

        current_lines.append(line)

    add_section(sections, current_title, current_lines)

    if sections:
        return sections

    return [{"title": "Full document", "text": text}]


def parse_heading(line: str) -> str | None:
    match = re.match(r"^(#{1,3})\s+(.+)$", line.strip())
    if not match:
        return None

    return match.group(2).strip()


def add_section(sections: list[dict], title: str, lines: list[str]) -> None:
    """Append a ``{"title", "text"}`` section built from *lines* to *sections*.

    The lines are joined with newlines and stripped. Nothing is appended when
    the result is empty, so blank sections never reach the caller.
    """
    body = "\n".join(lines).strip()
    if body:
        sections.append({"title": title, "text": body})


def split_long_text(text: str) -> list[str]:
    paragraphs = [paragraph.strip() for paragraph in text.split("\n\n") if paragraph.strip()]
    chunks = []
    current_paragraphs = []
    current_words = 0

    for paragraph in paragraphs:
        paragraph_words = word_count(paragraph)

        if current_paragraphs and current_words + paragraph_words > MAX_WORDS_PER_CHUNK:
            chunks.append("\n\n".join(current_paragraphs))
            current_paragraphs = []
            current_words = 0

        current_paragraphs.append(paragraph)
        current_words += paragraph_words

    if current_paragraphs:
        chunks.append("\n\n".join(current_paragraphs))

    if chunks:
        return chunks

    return [text]


def make_chunk(document: dict, section: str, text: str, index: int) -> dict:
    """Build the chunk record for one piece of *document*.

    The id is ``"<source>::chunk-NNN"`` with *index* zero-padded to three
    digits. The record copies the document's source, path, relative_path and
    document_type. It stores the stripped *text* and a header-prefixed
    ``searchable_text`` produced by ``build_searchable_text``.
    """
    source = document["source"]
    searchable = build_searchable_text(document, section, text)

    return {
        "id": f"{source}::chunk-{index:03d}",
        "source": source,
        "path": document["path"],
        "relative_path": document["relative_path"],
        "document_type": document["document_type"],
        "section": section,
        "text": text.strip(),
        "searchable_text": searchable,
    }


def build_searchable_text(document: dict, section: str, text: str) -> str:
    """Prefix the stripped *text* with source, document-type and section headers.

    The three header lines are followed by one empty line, then the body.
    """
    header = (
        f"Source: {document['source']}\n"
        f"Document type: {document['document_type']}\n"
        f"Section: {section}\n"
    )
    return f"{header}\n{text.strip()}"


def word_count(text: str) -> int:
    """Return how many whitespace-separated tokens *text* contains."""
    tokens = text.split()
    return len(tokens)

Run this code