diff --git a/rag.py b/rag.py
index a223bb6..f622a94 100644
--- a/rag.py
+++ b/rag.py
@@ -1,65 +1,67 @@
#!/usr/bin/env python3
"""
-RAG interactif robuste.
-• Recharge les passages à partir des fiches (même découpe que l'index) pour disposer du texte.
-• Recherche FAISS top‑k=4 et génération via mistral7b-fast.
+RAG interactif – version alignée sur l'index
+-------------------------------------------
+• Utilise corpus.idx + corpus.meta.json pour connaître l'ordre exact des passages.
+• Recharge **uniquement** les textes correspondants en gardant cet ordre – ainsi, plus
+ d'erreur d'index out‑of‑range quelle que soit la découpe.
+• Recherche FAISS (top‑k=4) + génération via mistral7b-fast (Ollama).
"""
-import os, json, readline, re
+
+import json, readline, re
from pathlib import Path
+from collections import defaultdict
import faiss, numpy as np, requests
from FlagEmbedding import BGEM3FlagModel
from rich import print
-# ---------------------------------------------------------------------------
-ROOT = Path("Fiches") # dossier des fiches sur l'hôte
-CHUNK, OVERLAP = 800, 100 # identiques à l'indexation
-K = 4 # nb de passages remis au modèle
+ROOT = Path("Fiches") # dossier racine des fiches
+K = 4 # nombre de passages remis au LLM
-# --- découpe ---------------------------------------------------------------
+# ------------------ utilitaires de découpe identiques à l'index -------------
+CHUNK, OVERLAP = 800, 100 # garder cohérent avec index.py
def split(text: str):
sents = re.split(r"(?<=[.!?]) +", text)
buf, out = [], []
for s in sents:
buf.append(s)
- if len(" ".join(buf).split()) > CHUNK: # approx 1 mot = 1 token
+ if len(" ".join(buf).split()) > CHUNK:
out.append(" ".join(buf))
buf = buf[-OVERLAP:]
if buf:
out.append(" ".join(buf))
return out
-# --- charger docs + meta dans le même ordre que l'index --------------------
+# ------------------- charger meta et reconstruire passages ------------------
+meta_path = Path("corpus.meta.json")
+if not meta_path.exists():
+ raise SystemExit("corpus.meta.json introuvable – lancez d'abord index.py")
+meta = json.load(meta_path.open())
-docs, meta = [], []
-for fp in ROOT.rglob("*.md"):
- for i, chunk in enumerate(split(fp.read_text(encoding="utf-8"))):
- docs.append(chunk)
- meta.append({"file": fp.name, "part": i})
+# mapping (file, part) -> chunk text
+cache: dict[tuple[str, int], str] = {}
+for fp in sorted(ROOT.rglob("*")):
+ if fp.suffix.lower() not in {".md", ".markdown", ".txt"}:
+ continue
+ chunks = split(fp.read_text(encoding="utf-8"))
+ for i, ch in enumerate(chunks):
+ cache[(fp.name, i)] = ch
-print(f"[dim]Chargé {len(docs)} passages depuis {ROOT}.[/]")
+# reconstruire docs dans le même ordre que l'index ---------------------------
+docs = []
+for m in meta:
+ key = (m["file"], m["part"])
+ docs.append(cache.get(key, "[passage manquant]"))
-# --- FAISS index existant ---------------------------------------------------
+print(f"[dim]Passages rechargés : {len(docs)} (ordre conforme à l'index).[/]")
-idx = faiss.read_index("corpus.idx")
+# ---------------- FAISS + modèle embeddings --------------------------------
+idx = faiss.read_index("corpus.idx")
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
-# --- boucle Q/A -------------------------------------------------------------
-
-def fetch_passage(i: int):
- m = meta[i]
- return f"[{m['file']} · part {m['part']}] {docs[i][:200]}…"
-
-def ask_llm(prompt: str):
- r = requests.post("http://127.0.0.1:11434/api/generate", json={
- "model": "mistral7b-fast",
- "prompt": prompt,
- "stream": False,
- "options": {"temperature": 0.2, "num_predict": 512}
- }, timeout=300)
- return r.json()["response"]
-
+# ---------------- boucle interactive ---------------------------------------
print("RAG prêt. Posez vos questions ! (Ctrl‑D pour sortir)")
try:
while True:
@@ -72,20 +74,34 @@ try:
emb = model.encode([q])
if isinstance(emb, dict):
+ # récupère le 1er ndarray trouvé
emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
q_emb = emb[0] / np.linalg.norm(emb[0])
D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K)
+ hits = I[0]
- context = "\n\n".join(docs[int(idx_id)] for idx_id in I[0])
- prompt = f"""Réponds en français, précis et factuel.\n{context}\n{q}"""
+ context = "\n\n".join(docs[int(i)] for i in hits)
+ prompt = (
+ "Réponds en français, précis et factuel.\n"
+ f"{context}\n{q}"
+ )
+
+ def ask_llm(p):
+ r = requests.post("http://127.0.0.1:11434/api/generate", json={
+ "model": "mistral7b-fast",
+ "prompt": p,
+ "stream": False,
+ "options": {"temperature": 0.2, "num_predict": 512}
+ }, timeout=300)
+ return r.json()["response"]
print("\n[bold]Réponse :[/]")
print(ask_llm(prompt))
- # petite trace des sources
print("\n[dim]--- contexte utilisé ---[/]")
- for idx_id in I[0]:
- print(fetch_passage(int(idx_id)))
+ for rank, idx_id in enumerate(hits, 1):
+ m = meta[int(idx_id)]
+ print(f"[{rank}] {m['file']} · part {m['part']} → {docs[int(idx_id)][:120]}…")
except Exception as e:
print("[red]Erreur :", e)