Update rag.py

Stéphan Peccini 2025-05-19 06:22:52 +02:00
parent c1a0a8e072
commit 852b81ba93

rag.py (131 changed lines)

@@ -1,58 +1,107 @@
 #!/usr/bin/env python3
-import faiss, json, requests, readline, numpy as np
-from rich import print
"""
Mini RAG interactif :
Recherche sémantique FAISS sur le corpus.idx / corpus.meta.json
Contexte (topk=4 passages) envoyé à Mistral7B via Ollama.
Robuste aux différentes sorties de BGEM3FlagModel.encode.
"""
import json
import readline
from pathlib import Path
import faiss
import numpy as np
import requests
from FlagEmbedding import BGEM3FlagModel
from rich import print
-# --- loading ------------------------------------------------------------------
-idx = faiss.read_index("corpus.idx")
-meta = json.load(open("corpus.meta.json"))
-model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")  # same model as at indexing time
+# ---------------------------------------------------------------------------
+# Initial loading
+# ---------------------------------------------------------------------------
+IDX_FILE = Path("corpus.idx")
+META_FILE = Path("corpus.meta.json")
-# simple helper to quickly retrieve a passage
-def fetch_passage(i):
+if not IDX_FILE.exists() or not META_FILE.exists():
+    raise SystemExit("[bold red]Erreur :[/] index absent. Lancez d'abord index.py !")
+
+index = faiss.read_index(str(IDX_FILE))
+meta = json.loads(META_FILE.read_text())
+model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
+# ---------------------------------------------------------------------------
+# Utilities
+# ---------------------------------------------------------------------------
+def _normalize(x: np.ndarray) -> np.ndarray:
+    """L2-normalize each row (float32)."""
+    return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-12)  # epsilon avoids /0
+def embed(texts):
+    """Encode list[str] → ndarray (n, dim), whatever the library returns."""
+    out = model.encode(texts)
+    # Possible shapes:
+    #   - ndarray
+    #   - dict {"embedding": ndarray} or {"embeddings": ndarray}
+    #   - dict {"sentence_embeds": [...]} etc.
+    if isinstance(out, np.ndarray):
+        arr = out
+    elif isinstance(out, dict):
+        # pick the first ndarray-like value
+        for v in out.values():
+            if isinstance(v, (list, tuple)) or hasattr(v, "shape"):
+                arr = np.asarray(v)
+                break
+        else:
+            raise TypeError("encode() dict sans clé embedding !")
+    else:  # list[list[float]] etc.
+        arr = np.asarray(out)
+    return _normalize(arr.astype("float32"))
+def fetch_passage(i: int) -> str:
     m = meta[i]
     return f"[{m['file']} · part {m['part']}] {m['text']}"
-def ask_llm(prompt):
-    r = requests.post("http://127.0.0.1:11434/api/generate", json={
-        "model": "mistral7b-fast",
-        "prompt": prompt,
-        "stream": False,
-        "options": {"temperature":0.2, "num_predict":512}
-    }, timeout=300)
+def ask_llm(prompt: str) -> str:
+    r = requests.post(
+        "http://127.0.0.1:11434/api/generate",
+        json={
+            "model": "mistral7b-fast",
+            "prompt": prompt,
+            "stream": False,
+            "options": {"temperature": 0.2, "num_predict": 512},
+        },
+        timeout=300,
+    )
     return r.json()["response"]
-# --- interactive loop ---------------------------------------------------------
+# ---------------------------------------------------------------------------
+# Interactive loop
+# ---------------------------------------------------------------------------
+print("[bold green]RAG prêt.[/] Posez vos questions ! (Ctrl-D pour sortir)")
 while True:
     try:
         q = input("❓ > ").strip()
-        if not q: continue
-    except (KeyboardInterrupt, EOFError):
-        print("\nBye."); break
+        if not q:
+            continue
+    except (EOFError, KeyboardInterrupt):
+        print("\n[dim]Bye.[/]")
+        break
     # embeddings & FAISS search (top-k = 4)
-    # replace these two lines (32-34)
-    # q_emb = model.encode([q], normalize_embeddings=True)
-    # D, I = idx.search(q_emb.astype("float32"), 4)
+    q_emb = embed([q])  # (1, dim)
+    D, I = index.search(q_emb, 4)
+    ctx = "\n\n".join(fetch_passage(int(idx)) for idx in I[0])
-    emb = model.encode([q])  # ndarray (1, 1024)
-    if isinstance(emb, dict):  # depends on the FlagEmbedding version
-        emb = emb.get("embedding") or emb.get("embeddings")
-    q_emb = emb[0] / np.linalg.norm(emb[0])  # L2 normalization
+    prompt = (
+        "<system>Réponds en français, précis et factuel.</system>\n"
+        f"<context>{ctx}</context>\n"
+        f"<user>{q}</user>"
+    )
-    D, I = idx.search(q_emb.astype("float32").reshape(1, -1), 4)
-    ctx_blocks = []
-    for rank, idx_id in enumerate(I[0]):
-        ctx_blocks.append(fetch_passage(idx_id))
-    context = "\n\n".join(ctx_blocks)
-    prompt = f"""<system>Réponds en français, précis et factuel.</system>
-<context>{context}</context>
-<user>{q}</user>"""
print("\n[bold]Réponse :[/]\n")
print("\n[bold]Réponse :[/]\n")
print(ask_llm(prompt))
print("\n[dim]--- contexte utilisé ---[/]")
print(context)
print(ctx)
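
A note on the new embed() helper: FlagEmbedding has changed the return type of BGEM3FlagModel.encode across releases, which is what this commit works around. Current releases return a dict whose dense vectors sit under the "dense_vecs" key; older code paths can yield a bare ndarray. The "first ndarray-like value" heuristic in embed() assumes the dense vectors come first in that dict, so a key-aware unwrap is slightly more defensive. A minimal sketch; the fallback key names other than "dense_vecs" are assumptions, not taken from this repo:

import numpy as np

def unwrap_dense(out):
    """Return the dense matrix from a BGEM3FlagModel.encode() result."""
    if isinstance(out, dict):
        # "dense_vecs" is what current FlagEmbedding releases use;
        # the other key names are speculative fallbacks for older versions.
        for key in ("dense_vecs", "embedding", "embeddings"):
            if out.get(key) is not None:
                return np.asarray(out[key], dtype="float32")
        raise TypeError(f"no dense embedding key in {list(out)}")
    return np.asarray(out, dtype="float32")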
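
Why _normalize matters: on L2-normalized rows the inner product equals cosine similarity, so an inner-product FAISS index over unit vectors ranks by cosine. This assumes corpus.idx was built as an inner-product index over normalized vectors, which index.py would have to guarantee (not shown in this commit). A two-vector check:

import numpy as np

a = np.array([[3.0, 4.0]], dtype="float32")
b = np.array([[4.0, 3.0]], dtype="float32")
a /= np.linalg.norm(a, axis=1, keepdims=True)
b /= np.linalg.norm(b, axis=1, keepdims=True)
print(float(a @ b.T))  # inner product of unit vectors = cos = 24/25 = 0.96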
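
ask_llm blocks until the complete answer is generated because "stream" is False. Ollama's /api/generate also streams: with "stream": true it emits one JSON object per line, each carrying a "response" fragment and a final "done": true. A sketch of a drop-in streaming variant, reusing the endpoint, model name, and options from the diff:

import json
import requests

def ask_llm_stream(prompt: str) -> str:
    """Print tokens as Ollama generates them; return the full answer."""
    r = requests.post(
        "http://127.0.0.1:11434/api/generate",
        json={"model": "mistral7b-fast", "prompt": prompt, "stream": True,
              "options": {"temperature": 0.2, "num_predict": 512}},
        stream=True,
        timeout=300,
    )
    parts = []
    for line in r.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        parts.append(chunk.get("response", ""))
        print(chunk.get("response", ""), end="", flush=True)
        if chunk.get("done"):
            break
    print()
    return "".join(parts)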
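
Finally, the start-up check points at index.py, which is not part of this commit. For context, a minimal sketch of what a compatible indexer presumably does: write an inner-product FAISS index of normalized BGE-M3 dense vectors to corpus.idx, plus a corpus.meta.json whose entries carry the file / part / text keys that fetch_passage reads. The docs/ directory and the chunk size are illustrative assumptions:

#!/usr/bin/env python3
import json
from pathlib import Path

import faiss
import numpy as np
from FlagEmbedding import BGEM3FlagModel

CHUNK = 1000  # characters per passage (illustrative, not from the repo)

model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")  # same model as rag.py

meta, texts = [], []
for path in sorted(Path("docs").glob("*.txt")):  # assumed corpus location
    raw = path.read_text()
    for part, start in enumerate(range(0, len(raw), CHUNK)):
        meta.append({"file": path.name, "part": part, "text": raw[start:start + CHUNK]})
        texts.append(raw[start:start + CHUNK])

out = model.encode(texts)
vecs = np.asarray(out["dense_vecs"] if isinstance(out, dict) else out, dtype="float32")
vecs /= np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-12  # unit rows: IP = cosine

index = faiss.IndexFlatIP(vecs.shape[1])
index.add(vecs)
faiss.write_index(index, "corpus.idx")
Path("corpus.meta.json").write_text(json.dumps(meta, ensure_ascii=False))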