Code/rag.py
Stéphan a3608353a2 Improve text chunking to preserve Markdown tables
Enhance split function to detect and preserve Markdown tables when
chunking text. Tables are now kept intact by forcing splits before
and after table content.

Also increase K value from 10 to 30 in rag.py to provide more
passages to the LLM.
2025-05-19 06:45:46 +02:00
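Note: the split() kept in rag.py below is still the plain sentence-based version and does not show the table handling described in the commit message. The following is only a minimal sketch of what such a table-aware splitter could look like, assuming a Markdown table is a run of consecutive lines starting with '|' and reusing rag.py's CHUNK/OVERLAP limits; the function name split_preserving_tables and the detection heuristic are illustrative, not the project's actual code.

import re

CHUNK, OVERLAP = 800, 100   # same limits as the sentence-based split() in rag.py

def split_preserving_tables(text: str):
    """Chunk text, keeping each Markdown table intact in its own chunk."""
    # 1) Segment the text into alternating prose / table blocks.
    #    Heuristic (assumption): a table line starts with '|'.
    blocks, current, in_table = [], [], False
    for line in text.splitlines():
        is_table_line = line.lstrip().startswith("|")
        if current and is_table_line != in_table:
            blocks.append((in_table, "\n".join(current)))
            current = []
        in_table = is_table_line
        current.append(line)
    if current:
        blocks.append((in_table, "\n".join(current)))

    # 2) Chunk prose sentence by sentence; emit each table block whole.
    out, buf = [], []
    def flush():
        if buf:
            out.append(" ".join(buf))
            buf.clear()
    for is_table, block in blocks:
        if is_table:
            flush()              # force a split before the table
            out.append(block)    # the table becomes one self-contained chunk
            continue             # the next prose sentence starts a new chunk
        for sent in re.split(r"(?<=[.!?]) +", block):
            buf.append(sent)
            if len(" ".join(buf).split()) > CHUNK:
                out.append(" ".join(buf))
                buf[:] = buf[-OVERLAP:]
    flush()
    return out

Flushing the sentence buffer before the table and emitting the table as its own chunk guarantees it is never cut in the middle, at the cost of occasionally producing a chunk larger than CHUNK words.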

#!/usr/bin/env python3
"""
RAG interactif version alignée sur l'index
-------------------------------------------
• Utilise corpus.idx + corpus.meta.json pour connaître l'ordre exact des passages.
• Recharge **uniquement** les textes correspondants en gardant cet ordre ainsi, plus
d'erreur d'index outofrange quelle que soit la découpe.
• Recherche FAISS (topk=4) + génération via mistral7b-fast (Ollama).
"""
import json, readline, re
from pathlib import Path
from collections import defaultdict
import faiss, numpy as np, requests
from FlagEmbedding import BGEM3FlagModel
from rich import print

ROOT = Path("Fiches")        # root folder containing the source documents
K = 30                       # number of passages handed to the LLM

# ------------------ chunking helpers, identical to the index ----------------
CHUNK, OVERLAP = 800, 100    # keep consistent with index.py
def split(text: str):
    sents = re.split(r"(?<=[.!?]) +", text)
    buf, out = [], []
    for s in sents:
        buf.append(s)
        if len(" ".join(buf).split()) > CHUNK:
            out.append(" ".join(buf))
            buf = buf[-OVERLAP:]
    if buf:
        out.append(" ".join(buf))
    return out
# ------------------- load meta and rebuild the passages ---------------------
meta_path = Path("corpus.meta.json")
if not meta_path.exists():
    raise SystemExit("corpus.meta.json introuvable : lancez d'abord index.py")
meta = json.load(meta_path.open())

# mapping (file, part) -> chunk text
cache: dict[tuple[str, int], str] = {}
for fp in sorted(ROOT.rglob("*")):
    if fp.suffix.lower() not in {".md", ".markdown", ".txt"}:
        continue
    chunks = split(fp.read_text(encoding="utf-8"))
    for i, ch in enumerate(chunks):
        cache[(fp.name, i)] = ch
# rebuild docs in the same order as the index --------------------------------
docs = []
for m in meta:
    key = (m["file"], m["part"])
    docs.append(cache.get(key, "[passage manquant]"))
print(f"[dim]Passages rechargés : {len(docs)} (ordre conforme à l'index).[/]")
# ---------------- FAISS + embedding model -----------------------------------
idx = faiss.read_index("corpus.idx")
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")

# ---------------- interactive loop -------------------------------------------
print("RAG prêt. Posez vos questions ! (Ctrl-D pour sortir)")
try:
    while True:
        try:
            q = input("❓ > ").strip()
            if not q:
                continue
        except (EOFError, KeyboardInterrupt):
            print("\nBye."); break

        emb = model.encode([q])
        if isinstance(emb, dict):
            # grab the first ndarray returned (BGE-M3 encode may return a dict)
            emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
        q_emb = emb[0] / np.linalg.norm(emb[0])
        D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K)
        hits = I[0]
        # re-order so that passages containing "Seuil" come first
        hits = sorted(hits, key=lambda i: "Seuil" not in docs[int(i)])
        context = "\n\n".join(docs[int(i)] for i in hits[:K])
        prompt = (
            "<system>Réponds en français, de façon précise, et uniquement à partir du contexte fourni. Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.</system>\n"
            f"<context>{context}</context>\n<user>{q}</user>"
        )

        def ask_llm(p):
            r = requests.post("http://127.0.0.1:11434/api/generate", json={
                "model": "mistral7b-fast",
                "prompt": p,
                "stream": False,
                "options": {"temperature": 0.0, "num_predict": 512}
            }, timeout=300)
            return r.json()["response"]

        print("\n[bold]Réponse :[/]")
        print(ask_llm(prompt))
        print("\n[dim]--- contexte utilisé ---[/]")
        for rank, idx_id in enumerate(hits, 1):
            m = meta[int(idx_id)]
            print(f"[{rank}] {m['file']} · part {m['part']} : {docs[int(idx_id)][:120]}")
except Exception as e:
    print("[red]Erreur :", e)