diff --git a/rag.py b/rag.py
index 2ecac8f..a223bb6 100644
--- a/rag.py
+++ b/rag.py
@@ -1,107 +1,91 @@
#!/usr/bin/env python3
"""
-Interactive mini-RAG:
-• FAISS semantic search over corpus.idx / corpus.meta.json
-• Context (top-k=4 passages) sent to Mistral-7B via Ollama.
-
-Robust to the various output shapes of BGEM3FlagModel.encode.
+Robust interactive RAG.
+• Reloads the passages from the source notes (same chunking as the index) so the text is available.
+• FAISS top-k=4 search, then generation via mistral7b-fast.
"""
-import json
-import readline
+import json, re, readline  # readline is imported for its input() line-editing side effect
from pathlib import Path
-import faiss
-import numpy as np
-import requests
+import faiss, numpy as np, requests
from FlagEmbedding import BGEM3FlagModel
from rich import print
# ---------------------------------------------------------------------------
-# Initial loading
-# ---------------------------------------------------------------------------
-IDX_FILE = Path("corpus.idx")
-META_FILE = Path("corpus.meta.json")
+ROOT = Path("Fiches")       # folder containing the source notes on the host
+CHUNK, OVERLAP = 800, 100   # must match the chunking parameters used in index.py
+K = 4                       # number of passages handed to the model
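+# (K=4 chunks of ~800 words is a few thousand tokens; assumed to fit the model's context window.)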
-if not IDX_FILE.exists() or not META_FILE.exists():
-    raise SystemExit("[bold red]Error:[/] index missing. Run index.py first!")
+# --- chunking ---------------------------------------------------------------
-index = faiss.read_index(str(IDX_FILE))
-meta = json.loads(META_FILE.read_text())
+def split(text: str) -> list[str]:
+    """Sentence-based chunking; must stay identical to the chunking done by index.py."""
+    sents = re.split(r"(?<=[.!?]) +", text)
+    buf, out = [], []
+    for s in sents:
+        buf.append(s)
+        if len(" ".join(buf).split()) > CHUNK:  # word count as a rough token proxy
+            out.append(" ".join(buf))
+            buf = buf[-OVERLAP:]  # NB: keeps the last OVERLAP *sentences*, not words
+    if buf:
+        out.append(" ".join(buf))
+    return out
+
+# --- reload docs + meta in the same order as the index ----------------------
+
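+# NOTE: Path.rglob order is filesystem-dependent; it is assumed to match index.py's
+# traversal, otherwise FAISS ids will point at the wrong passages.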
+docs, meta = [], []
+for fp in ROOT.rglob("*.md"):
+ for i, chunk in enumerate(split(fp.read_text(encoding="utf-8"))):
+ docs.append(chunk)
+ meta.append({"file": fp.name, "part": i})
+
+print(f"[dim]Chargé {len(docs)} passages depuis {ROOT}.[/]")
+
+# --- existing FAISS index ----------------------------------------------------
+
+IDX_FILE = Path("corpus.idx")
+if not IDX_FILE.exists():
+    raise SystemExit("[bold red]Error:[/] index missing. Run index.py first!")
+idx = faiss.read_index(str(IDX_FILE))
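+
+# Sanity check; assumes index.py embedded exactly one vector per chunk, in this same order.
+if idx.ntotal != len(docs):
+    raise SystemExit(f"Index/corpus mismatch: {idx.ntotal} vectors vs {len(docs)} passages.")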
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
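+# Queries are L2-normalized before search; this assumes index.py stored normalized
+# vectors (inner-product metric), as the previous embed() helper did.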
-# ---------------------------------------------------------------------------
-# Utilitaires
-# ---------------------------------------------------------------------------
+# --- Q/A loop ----------------------------------------------------------------
-def _normalize(x: np.ndarray) -> np.ndarray:
- """L2‑normalize each row (tokens=float32)."""
- return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-12)
-
-
-def embed(texts):
-    """Encode list[str] → ndarray (n, dim), whatever the library returns."""
- out = model.encode(texts)
- # Possible shapes :
- # • ndarray
- # • dict {"embedding": ndarray} ou {"embeddings": ndarray}
- # • dict {"sentence_embeds": [...]} etc.
- if isinstance(out, np.ndarray):
- arr = out
- elif isinstance(out, dict):
- # pick the first ndarray-like value
- for v in out.values():
- if isinstance(v, (list, tuple)) or hasattr(v, "shape"):
- arr = np.asarray(v)
- break
- else:
-            raise TypeError("encode() returned a dict with no embedding key!")
- else: # list[list[float]] etc.
- arr = np.asarray(out)
- return _normalize(arr.astype("float32"))
-
-
-def fetch_passage(i: int) -> str:
+def fetch_passage(i: int) -> str:
m = meta[i]
- return f"[{m['file']} · part {m['part']}] {m['text']}"
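+    # docs[i] holds the full chunk; only a 200-char preview goes into the source trace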
+ return f"[{m['file']} · part {m['part']}] {docs[i][:200]}…"
-
-def ask_llm(prompt: str) -> str:
- r = requests.post(
- "http://127.0.0.1:11434/api/generate",
- json={
- "model": "mistral7b-fast",
- "prompt": prompt,
- "stream": False,
- "options": {"temperature": 0.2, "num_predict": 512},
- },
- timeout=300,
- )
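+# One-shot (non-streaming) call to Ollama's /api/generate; assumes the
+# "mistral7b-fast" model already exists locally (see `ollama list`).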
+def ask_llm(prompt: str) -> str:
+    r = requests.post("http://127.0.0.1:11434/api/generate", json={
+        "model": "mistral7b-fast",
+        "prompt": prompt,
+        "stream": False,
+        "options": {"temperature": 0.2, "num_predict": 512}
+    }, timeout=300)
+    r.raise_for_status()  # surface HTTP errors instead of a KeyError on "response"
return r.json()["response"]
-# ---------------------------------------------------------------------------
-# Interactive loop
-# ---------------------------------------------------------------------------
-print("[bold green]RAG ready.[/] Ask your questions! (Ctrl-D to quit)")
-while True:
- try:
- q = input("❓ > ").strip()
- if not q:
- continue
- except (EOFError, KeyboardInterrupt):
- print("\n[dim]Bye.[/]")
- break
+print("RAG prêt. Posez vos questions ! (Ctrl‑D pour sortir)")
+try:
+ while True:
+ try:
+ q = input("❓ > ").strip()
+ if not q:
+ continue
+ except (EOFError, KeyboardInterrupt):
+ print("\nBye."); break
- q_emb = embed([q]) # (1, dim)
- D, I = index.search(q_emb, 4)
- ctx = "\n\n".join(fetch_passage(int(idx)) for idx in I[0])
+        emb = model.encode([q])
+        if isinstance(emb, dict):  # BGEM3FlagModel returns e.g. {"dense_vecs": ndarray}
+            emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
+        q_emb = emb[0] / (np.linalg.norm(emb[0]) + 1e-12)  # L2-normalize, as at indexing
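+        # FAISS search over the single normalized query vector: D = scores, I = passage ids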
- prompt = (
-        "Answer in French, be precise and factual.\n"
- f"{ctx}\n"
- f"{q}"
- )
+ D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K)
-    print("\n[bold]Answer:[/]\n")
-    print(ask_llm(prompt))
-    print("\n[dim]--- context used ---[/]")
- print(ctx)
+        context = "\n\n".join(docs[int(idx_id)] for idx_id in I[0] if idx_id >= 0)
+        prompt = f"Answer in French, be precise and factual.\n{context}\n{q}"
+
+        print("\n[bold]Answer:[/]")
+ print(ask_llm(prompt))
+
+        # brief trace of the sources used
+        print("\n[dim]--- context used ---[/]")
+        for idx_id in I[0]:
+            if idx_id >= 0:  # FAISS pads with -1 when fewer than K neighbours exist
+                print(fetch_passage(int(idx_id)))
+    except Exception as e:  # e.g. Ollama unreachable or malformed response
+        print(f"[red]Error:[/] {e}")