Update rag.py

Stéphan Peccini 2025-05-19 06:24:29 +02:00
parent 852b81ba93
commit c56a46545f

rag.py

#!/usr/bin/env python3
"""
Robust interactive RAG.
Reloads the passages from the Fiches folder (same chunking as the index) so their text is available.
FAISS top-k=4 retrieval, then generation via mistral7b-fast.
"""
import os, json, readline, re
from pathlib import Path
import faiss, numpy as np, requests
from FlagEmbedding import BGEM3FlagModel
from rich import print

# ---------------------------------------------------------------------------
ROOT = Path("Fiches")          # folder holding the fiches on the host
CHUNK, OVERLAP = 800, 100      # must match the indexing run
K = 4                          # number of passages handed to the model

# --- chunking ----------------------------------------------------------------

def split(text: str):
    sents = re.split(r"(?<=[.!?]) +", text)
    buf, out = [], []
    for s in sents:
        buf.append(s)
        if len(" ".join(buf).split()) > CHUNK:   # approx. 1 word = 1 token
            out.append(" ".join(buf))
            buf = buf[-OVERLAP:]                 # keep the last OVERLAP sentences as overlap
    if buf:
        out.append(" ".join(buf))
    return out

# --- reload docs + meta in the same order as the index -----------------------
docs, meta = [], []
for fp in ROOT.rglob("*.md"):
    for i, chunk in enumerate(split(fp.read_text(encoding="utf-8"))):
        docs.append(chunk)
        meta.append({"file": fp.name, "part": i})
print(f"[dim]Loaded {len(docs)} passages from {ROOT}.[/]")

# --- existing FAISS index -----------------------------------------------------
idx = faiss.read_index("corpus.idx")
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")

# --- Q/A loop -----------------------------------------------------------------

def fetch_passage(i: int):
    """Source tag plus the first 200 characters of passage i."""
    m = meta[i]
    return f"[{m['file']} · part {m['part']}] {docs[i][:200]}"

def ask_llm(prompt: str):
    r = requests.post("http://127.0.0.1:11434/api/generate", json={
        "model": "mistral7b-fast",
        "prompt": prompt,
        "stream": False,
        "options": {"temperature": 0.2, "num_predict": 512},
    }, timeout=300)
    return r.json()["response"]

print("RAG ready. Ask your questions! (Ctrl-D to quit)")
try:
    while True:
        try:
            q = input("❓ > ").strip()
            if not q:
                continue
        except (EOFError, KeyboardInterrupt):
            print("\nBye.")
            break

        emb = model.encode([q])
        if isinstance(emb, dict):                # encode() may return a dict of arrays
            emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
        q_emb = emb[0] / np.linalg.norm(emb[0])  # L2-normalize the query vector
        D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K)

        context = "\n\n".join(docs[int(idx_id)] for idx_id in I[0])
        prompt = (
            "<system>Answer in French, precisely and factually.</system>\n"
            f"<context>{context}</context>\n"
            f"<user>{q}</user>"
        )

        print("\n[bold]Answer:[/]")
        print(ask_llm(prompt))

        # small trace of the sources
        print("\n[dim]--- context used ---[/]")
        for idx_id in I[0]:
            print(fetch_passage(int(idx_id)))
except Exception as e:
    print("[red]Error:[/]", e)