diff --git a/rag.py b/rag.py
index db6a9b8..2ecac8f 100644
--- a/rag.py
+++ b/rag.py
@@ -1,58 +1,107 @@
#!/usr/bin/env python3
-import faiss, json, requests, readline, numpy as np
-from rich import print
+"""
+Mini RAG interactif :
+• Recherche sémantique FAISS sur le corpus.idx / corpus.meta.json
+• Contexte (top‑k=4 passages) envoyé à Mistral‑7B via Ollama.
+
+Robuste aux différentes sorties de BGEM3FlagModel.encode.
+"""
+import json
+import readline
+from pathlib import Path
+
+import faiss
+import numpy as np
+import requests
from FlagEmbedding import BGEM3FlagModel
+from rich import print
-# --- chargements -------------------------------------------------------------
-idx = faiss.read_index("corpus.idx")
-meta = json.load(open("corpus.meta.json"))
-model = BGEM3FlagModel("BAAI/bge-m3", device="cpu") # même qu’à l’indexation
+# ---------------------------------------------------------------------------
+# Initial loading
+# ---------------------------------------------------------------------------
+IDX_FILE = Path("corpus.idx")
+META_FILE = Path("corpus.meta.json")
-# simple aide mémoire pour retrouver rapidement un passage
-def fetch_passage(i):
+# Both artifacts are required; stop early with a readable message otherwise.
+if not IDX_FILE.exists() or not META_FILE.exists():
+    # SystemExit writes its message verbatim to stderr (it does not go through
+    # rich.print), so the message must not contain rich markup tags.
+    raise SystemExit("Erreur : index absent. Lancez d'abord index.py !")
+
+index = faiss.read_index(str(IDX_FILE))
+# JSON is UTF-8 by specification; do not depend on the platform locale encoding.
+meta = json.loads(META_FILE.read_text(encoding="utf-8"))
+# Must be the same embedding model that was used to build the index.
+model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
+
+# ---------------------------------------------------------------------------
+# Utilitaires
+# ---------------------------------------------------------------------------
+
+def _normalize(x: np.ndarray) -> np.ndarray:
+    """Return *x* with every row divided by its L2 norm.
+
+    A 1e-12 term is added to each row norm before dividing, so an all-zero
+    row yields zeros instead of a division by zero.
+    """
+    row_norms = np.linalg.norm(x, axis=1, keepdims=True)
+    return x / (row_norms + 1e-12)
+
+
+def embed(texts):
+    """Encode a list of strings into an L2-normalized float32 matrix.
+
+    ``model.encode`` may return an ndarray, a dict of outputs, or a nested
+    list depending on the library version. For a dict, the dense sentence
+    vectors are looked up by their known key names first; only when none of
+    them is present does the function fall back to the first list/array-like
+    value.
+
+    Args:
+        texts: list of strings to encode.
+
+    Returns:
+        ndarray of shape (n, dim), dtype float32, with rows normalized by
+        ``_normalize``.
+
+    Raises:
+        TypeError: if ``encode`` returns a dict holding no list/array-like
+            value.
+    """
+    out = model.encode(texts)
+    if isinstance(out, dict):
+        # Prefer the dense-vector entry by name: relying on dict order could
+        # select a sparse or multi-vector output instead.
+        for key in ("dense_vecs", "embedding", "embeddings", "sentence_embeds"):
+            if out.get(key) is not None:
+                arr = np.asarray(out[key])
+                break
+        else:
+            # Unknown key names: keep the first list/array-like value.
+            for v in out.values():
+                if isinstance(v, (list, tuple)) or hasattr(v, "shape"):
+                    arr = np.asarray(v)
+                    break
+            else:
+                raise TypeError("encode() dict sans clé embedding !")
+    else:  # ndarray, list[list[float]], etc.
+        arr = np.asarray(out)
+    # A single vector may come back with shape (dim,); promote it to (1, dim)
+    # because _normalize works row-wise (axis=1).
+    arr = np.atleast_2d(arr)
+    return _normalize(arr.astype("float32"))
+
+
+def fetch_passage(i: int) -> str:
+    """Format entry *i* of ``meta`` as ``[<file> · part <part>] <text>``."""
 m = meta[i]
 return f"[{m['file']} · part {m['part']}] {m['text']}"
-def ask_llm(prompt):
- r = requests.post("http://127.0.0.1:11434/api/generate", json={
- "model": "mistral7b-fast",
- "prompt": prompt,
- "stream": False,
- "options": {"temperature":0.2, "num_predict":512}
- }, timeout=300)
+
+def ask_llm(prompt: str) -> str:
+    """Send *prompt* to the local Ollama server and return the generated text.
+
+    Posts a non-streaming request to ``/api/generate`` on 127.0.0.1:11434 for
+    the ``mistral7b-fast`` model (options: temperature 0.2, num_predict 512)
+    with ``timeout=300``.
+
+    Returns:
+        The ``"response"`` field of the JSON reply.
+
+    Raises:
+        requests.HTTPError: if the server answers with a 4xx/5xx status.
+        requests.RequestException: on connection failure or timeout.
+    """
+    r = requests.post(
+        "http://127.0.0.1:11434/api/generate",
+        json={
+            "model": "mistral7b-fast",
+            "prompt": prompt,
+            "stream": False,
+            "options": {"temperature": 0.2, "num_predict": 512},
+        },
+        timeout=300,
+    )
+    # Fail with the HTTP status rather than a KeyError on "response" when the
+    # server returns an error payload (e.g. unknown model).
+    r.raise_for_status()
 return r.json()["response"]
-# --- boucle interactive ------------------------------------------------------
+# ---------------------------------------------------------------------------
+# Interactive loop
+# ---------------------------------------------------------------------------
+# Retrieved passages and model output are arbitrary text; they are escaped
+# before going through rich.print so brackets are not parsed as markup tags.
+from rich.markup import escape
+
+print("[bold green]RAG prêt.[/] Posez vos questions ! (Ctrl‑D pour sortir)")
while True:
 try:
 q = input("❓ > ").strip()
- if not q: continue
- except (KeyboardInterrupt, EOFError):
- print("\nBye."); break
+        if not q:
+            continue
+    except (EOFError, KeyboardInterrupt):
+        print("\n[dim]Bye.[/]")
+        break
- # embeddings & recherche FAISS (top-k=4)
- # remplace ces deux lignes (32-34)
- # q_emb = model.encode([q], normalize_embeddings=True)
- # D, I = idx.search(q_emb.astype("float32"), 4)
+    # Embed the question and retrieve the 4 nearest passages.
+    q_emb = embed([q])  # (1, dim)
+    _scores, ids = index.search(q_emb, 4)
+    # FAISS pads the result with -1 when fewer than k neighbours exist; a
+    # negative index would silently select a passage from the end of meta.
+    ctx = "\n\n".join(fetch_passage(int(pid)) for pid in ids[0] if pid >= 0)
- emb = model.encode([q]) # ndarray (1, 1024)
- if isinstance(emb, dict): # selon la version de FlagEmbedding
- emb = emb.get("embedding") or emb.get("embeddings")
- q_emb = emb[0] / np.linalg.norm(emb[0]) # L2 normalisation
+    prompt = (
+        "Réponds en français, précis et factuel.\n"
+        f"{ctx}\n"
+        f"{q}"
+    )
- D, I = idx.search(q_emb.astype("float32").reshape(1, -1), 4)
-
-
- ctx_blocks = []
- for rank, idx_id in enumerate(I[0]):
- ctx_blocks.append(fetch_passage(idx_id))
- context = "\n\n".join(ctx_blocks)
-
- prompt = f"""Réponds en français, précis et factuel.
-{context}
-{q}"""
-
- print("\n[bold]Réponse :[/]\n")
+    print("\n[bold]Réponse :[/]\n")
- print(ask_llm(prompt))
+    print(escape(ask_llm(prompt)))
 print("\n[dim]--- contexte utilisé ---[/]")
- print(context)
+    print(escape(ctx))