Update rag.py

This commit is contained in:
Stéphan Peccini 2025-05-19 06:25:57 +02:00
parent c56a46545f
commit be9c3709db

94
rag.py
View File

@ -1,65 +1,67 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
RAG interactif robuste. RAG interactif version alignée sur l'index
Recharge les passages à partir des fiches (même découpe que l'index) pour disposer du texte. -------------------------------------------
Recherche FAISS topk=4 et génération via mistral7b-fast. Utilise corpus.idx + corpus.meta.json pour connaître l'ordre exact des passages.
Recharge **uniquement** les textes correspondants en gardant cet ordre ainsi, plus
d'erreur d'index outofrange quelle que soit la découpe.
Recherche FAISS (topk=4) + génération via mistral7b-fast (Ollama).
""" """
import os, json, readline, re
import json, readline, re
from pathlib import Path from pathlib import Path
from collections import defaultdict
import faiss, numpy as np, requests import faiss, numpy as np, requests
from FlagEmbedding import BGEM3FlagModel from FlagEmbedding import BGEM3FlagModel
from rich import print from rich import print
# --------------------------------------------------------------------------- ROOT = Path("Fiches") # dossier racine des fiches
ROOT = Path("Fiches") # dossier des fiches sur l'hôte K = 4 # nombre de passages remis au LLM
CHUNK, OVERLAP = 800, 100 # identiques à l'indexation
K = 4 # nb de passages remis au modèle
# --- découpe --------------------------------------------------------------- # ------------------ utilitaires de découpe identiques à l'index -------------
CHUNK, OVERLAP = 800, 100 # garder cohérent avec index.py
def split(text: str): def split(text: str):
sents = re.split(r"(?<=[.!?]) +", text) sents = re.split(r"(?<=[.!?]) +", text)
buf, out = [], [] buf, out = [], []
for s in sents: for s in sents:
buf.append(s) buf.append(s)
if len(" ".join(buf).split()) > CHUNK: # approx 1 mot = 1 token if len(" ".join(buf).split()) > CHUNK:
out.append(" ".join(buf)) out.append(" ".join(buf))
buf = buf[-OVERLAP:] buf = buf[-OVERLAP:]
if buf: if buf:
out.append(" ".join(buf)) out.append(" ".join(buf))
return out return out
# --- charger docs + meta dans le même ordre que l'index -------------------- # ------------------- charger meta et reconstruire passages ------------------
meta_path = Path("corpus.meta.json")
if not meta_path.exists():
raise SystemExit("corpus.meta.json introuvable lancez d'abord index.py")
meta = json.load(meta_path.open())
docs, meta = [], [] # mapping (file, part) -> chunk text
for fp in ROOT.rglob("*.md"): cache: dict[tuple[str, int], str] = {}
for i, chunk in enumerate(split(fp.read_text(encoding="utf-8"))): for fp in sorted(ROOT.rglob("*")):
docs.append(chunk) if fp.suffix.lower() not in {".md", ".markdown", ".txt"}:
meta.append({"file": fp.name, "part": i}) continue
chunks = split(fp.read_text(encoding="utf-8"))
for i, ch in enumerate(chunks):
cache[(fp.name, i)] = ch
print(f"[dim]Chargé {len(docs)} passages depuis {ROOT}.[/]") # reconstruire docs dans le même ordre que l'index ---------------------------
docs = []
for m in meta:
key = (m["file"], m["part"])
docs.append(cache.get(key, "[passage manquant]"))
# --- FAISS index existant --------------------------------------------------- print(f"[dim]Passages rechargés : {len(docs)} (ordre conforme à l'index).[/]")
idx = faiss.read_index("corpus.idx") # ---------------- FAISS + modèle embeddings --------------------------------
idx = faiss.read_index("corpus.idx")
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu") model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
# --- boucle Q/A ------------------------------------------------------------- # ---------------- boucle interactive ---------------------------------------
def fetch_passage(i: int):
m = meta[i]
return f"[{m['file']} · part {m['part']}] {docs[i][:200]}"
def ask_llm(prompt: str):
r = requests.post("http://127.0.0.1:11434/api/generate", json={
"model": "mistral7b-fast",
"prompt": prompt,
"stream": False,
"options": {"temperature": 0.2, "num_predict": 512}
}, timeout=300)
return r.json()["response"]
print("RAG prêt. Posez vos questions ! (CtrlD pour sortir)") print("RAG prêt. Posez vos questions ! (CtrlD pour sortir)")
try: try:
while True: while True:
@ -72,20 +74,34 @@ try:
emb = model.encode([q]) emb = model.encode([q])
if isinstance(emb, dict): if isinstance(emb, dict):
# récupère le 1er ndarray trouvé
emb = next(v for v in emb.values() if isinstance(v, np.ndarray)) emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
q_emb = emb[0] / np.linalg.norm(emb[0]) q_emb = emb[0] / np.linalg.norm(emb[0])
D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K) D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K)
hits = I[0]
context = "\n\n".join(docs[int(idx_id)] for idx_id in I[0]) context = "\n\n".join(docs[int(i)] for i in hits)
prompt = f"""<system>Réponds en français, précis et factuel.</system>\n<context>{context}</context>\n<user>{q}</user>""" prompt = (
"<system>Réponds en français, précis et factuel.</system>\n"
f"<context>{context}</context>\n<user>{q}</user>"
)
def ask_llm(p):
r = requests.post("http://127.0.0.1:11434/api/generate", json={
"model": "mistral7b-fast",
"prompt": p,
"stream": False,
"options": {"temperature": 0.2, "num_predict": 512}
}, timeout=300)
return r.json()["response"]
print("\n[bold]Réponse :[/]") print("\n[bold]Réponse :[/]")
print(ask_llm(prompt)) print(ask_llm(prompt))
# petite trace des sources
print("\n[dim]--- contexte utilisé ---[/]") print("\n[dim]--- contexte utilisé ---[/]")
for idx_id in I[0]: for rank, idx_id in enumerate(hits, 1):
print(fetch_passage(int(idx_id))) m = meta[int(idx_id)]
print(f"[{rank}] {m['file']} · part {m['part']}{docs[int(idx_id)][:120]}")
except Exception as e: except Exception as e:
print("[red]Erreur :", e) print("[red]Erreur :", e)