Update rag.py

This commit is contained in:
Stéphan Peccini 2025-05-19 08:05:07 +02:00
parent 86a902de9d
commit 4b16c2210e

173
rag.py
View File

@ -1,113 +1,124 @@
#!/usr/bin/env python3
"""
RAG interactif version alignée sur l'index
-------------------------------------------
Utilise corpus.idx + corpus.meta.json pour connaître l'ordre exact des passages.
Recharge **uniquement** les textes correspondants en gardant cet ordre ainsi, plus
d'erreur d'index outofrange quelle que soit la découpe.
Recherche FAISS (topk=4) + génération via mistral7b-fast (Ollama).
rag.py recherche + génération (version robuste, chapitres)
============================================================
Charge **un ou plusieurs** couples index/meta (FAISS + JSON). Par défaut :
chap.idx / chap.meta.json
Reconstitue les textes à partir des fichiers `path` indiqués dans la méta.
Les passages sont déjà prêts (1 par fichier court, ou découpés par index.py).
Recherche : embeddings BGEM3 (CPU) + FAISS (cosinus IP) sur tous les index.
topk configurable (déf. 20 pour index détaillé, 5 pour index chapitres).
trie ensuite les hits mettant en avant ceux contenant un motclé fourni
(ex. « seuil » pour ICS).
Génération : appelle Mistral7B (Ollama) avec temperature 0.1 et consigne :
« Réponds uniquement à partir du contexte. Si linfo manque : Je ne sais pas. »
Usage :
python rag.py [--k 25] [--kw seuil] [--model mistral7b-fast]
"""
import json, readline, re
from __future__ import annotations
import argparse, json, re, sys
from pathlib import Path
from collections import defaultdict
import faiss, numpy as np, requests
from FlagEmbedding import BGEM3FlagModel
from rich import print
ROOT = Path("Fiches") # dossier racine des fiches
K = 30 # nombre de passages remis au LLM
# ------------------------- CLI -------------------------------------------
p = argparse.ArgumentParser()
p.add_argument("--index", nargs="*", default=["chap.idx"],
help="Liste des fichiers FAISS à charger (déf. chap.idx)")
p.add_argument("--meta", nargs="*", default=["chap.meta.json"],
help="Liste des méta JSON assortis (même ordre que --index)")
p.add_argument("--k", type=int, default=15, help="topk cumulés (déf. 15)")
p.add_argument("--kw", default="seuil", help="motclé boosté (déf. seuil)")
p.add_argument("--model", default="mistral7b-fast", help="modèle Ollama")
args = p.parse_args()
# ------------------ utilitaires de découpe identiques à l'index -------------
CHUNK, OVERLAP = 800, 100 # garder cohérent avec index.py
if len(args.index) != len(args.meta):
print("[red]Erreur : --index et --meta doivent avoir la même longueur.")
sys.exit(1)
def split(text: str):
sents = re.split(r"(?<=[.!?]) +", text)
buf, out = [], []
for s in sents:
buf.append(s)
if len(" ".join(buf).split()) > CHUNK:
out.append(" ".join(buf))
buf = buf[-OVERLAP:]
if buf:
out.append(" ".join(buf))
return out
# ------------------------- charger indexes -------------------------------
indexes, metas, start_offset = [], [], []
offset = 0
for idx_f, meta_f in zip(args.index, args.meta):
idx = faiss.read_index(str(idx_f))
meta = json.load(open(meta_f))
if idx.ntotal != len(meta):
print(f"[yellow]Avertissement : {idx_f} contient {idx.ntotal} vecteurs, meta {len(meta)} lignes.[/]")
indexes.append(idx)
metas.append(meta)
start_offset.append(offset)
offset += idx.ntotal
# ------------------- charger meta et reconstruire passages ------------------
meta_path = Path("corpus.meta.json")
if not meta_path.exists():
raise SystemExit("corpus.meta.json introuvable lancez d'abord index.py")
meta = json.load(meta_path.open())
total_passages = offset
print(f"Passages chargés : {total_passages} (agrégat de {len(indexes)} index)")
# mapping (file, part) -> chunk text
cache: dict[tuple[str, int], str] = {}
for fp in sorted(ROOT.rglob("*")):
if fp.suffix.lower() not in {".md", ".markdown", ".txt"}:
continue
chunks = split(fp.read_text(encoding="utf-8"))
for i, ch in enumerate(chunks):
cache[(fp.name, i)] = ch
# ------------------------- cache texte -----------------------------------
DOCS: dict[int,str] = {}
for base_offset, meta in zip(start_offset, metas):
for i, m in enumerate(meta):
DOCS[base_offset + i] = Path(m["path"]).read_text(encoding="utf-8")
print("[dim]Cache texte préchargé.[/]")
# reconstruire docs dans le même ordre que l'index ---------------------------
docs = []
for m in meta:
# compatibilité avec lancien et le nouveau format
path = m.get("file") or m.get("path") # nouvelle clé : "path"
part = m["part"]
key = (Path(path).name, part) # on garde le nom court pour le cache
# ------------------------- modèle embeddings -----------------------------
embedder = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
docs.append(cache.get(key, "[passage manquant]"))
# ------------------------- helpers ---------------------------------------
def encode_query(q: str):
emb = embedder.encode([q])
if isinstance(emb, dict):
emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
v = emb[0]
return (v / np.linalg.norm(v)).astype("float32").reshape(1, -1)
print(f"[dim]Passages rechargés : {len(docs)} (ordre conforme à l'index).[/]")
def search_all(vec):
hits = []
for idx, off in zip(indexes, start_offset):
D, I = idx.search(vec, min(args.k, idx.ntotal))
hits.extend([off + int(i) for i in I[0]])
return hits
# ---------------- FAISS + modèle embeddings --------------------------------
idx = faiss.read_index("corpus.idx")
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
# ---------------- boucle interactive ---------------------------------------
print("RAG prêt. Posez vos questions ! (CtrlD pour sortir)")
try:
# ------------------------- boucle interactive ----------------------------
print("RAG prêt ! (CtrlD pour quitter)")
while True:
try:
q = input("❓ > ").strip()
if not q:
continue
except (EOFError, KeyboardInterrupt):
print("\nBye."); break
if not q: continue
emb = model.encode([q])
if isinstance(emb, dict):
# récupère le 1er ndarray trouvé
emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
q_emb = emb[0] / np.linalg.norm(emb[0])
# correction rapide de typos courantes (substituabilité…)
q_norm = re.sub(r"susbtitu[a-z]+", "substituabilité", q, flags=re.I)
D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K)
hits = I[0]
# réordonne pour mettre en tête les passages contenant “Seuil”
hits = sorted(hits, key=lambda i: "Seuil" not in docs[int(i)])
vec = encode_query(q_norm)
hits = search_all(vec)
# Boost lexical : passages contenant le motclé args.kw dabord
kw_lower = args.kw.lower()
hits.sort(key=lambda i: kw_lower not in DOCS[i].lower())
hits = hits[:args.k]
context = "\n\n".join(DOCS[i] for i in hits)
context = "\n\n".join(docs[int(i)] for i in hits[:K])
prompt = (
"<system>Réponds en français, de façon précise, et uniquement à partir du contexte fourni. Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.</system>\n"
f"<context>{context}</context>\n<user>{q}</user>"
"<system>Réponds en français, de façon précise et uniquement à partir du contexte. "
"Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.</system>\n"
f"<context>{context}</context>\n"
f"<user>{q}</user>"
)
def ask_llm(p):
r = requests.post("http://127.0.0.1:11434/api/generate", json={
"model": "mistral7b-fast",
"prompt": p,
"model": args.model,
"prompt": prompt,
"stream": False,
"options": {"temperature": 0.0, "num_predict": 512}
"options": {"temperature": 0.1, "num_predict": 512}
}, timeout=300)
return r.json()["response"]
answer = r.json().get("response", "(erreur API)")
print("\n[bold]Réponse :[/]")
print(ask_llm(prompt))
print("\n[dim]--- contexte utilisé ---[/]")
print("\n[bold]Réponse :[/]\n", answer)
print("\n[dim]--- contexte utilisé (top " + str(len(hits)) + ") ---[/]")
for rank, idx_id in enumerate(hits, 1):
m = meta[int(idx_id)]
print(f"[{rank}] {Path(m.get('file') or m.get('path')).name} · part {m['part']} → …")
except Exception as e:
print("[red]Erreur :", e)
m = metas[0] # non utilisé ici, on affiche juste le nom
path = DOCS[idx_id].splitlines()[0][:60]
print(f"[{rank}] … {path}")