Update rag.py

This commit is contained in:
Stéphan Peccini 2025-05-19 08:05:07 +02:00
parent 86a902de9d
commit 4b16c2210e

191
rag.py
View File

@ -1,113 +1,124 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
""" """
RAG interactif version alignée sur l'index rag.py recherche + génération (version robuste, chapitres)
------------------------------------------- ============================================================
Utilise corpus.idx + corpus.meta.json pour connaître l'ordre exact des passages. Charge **un ou plusieurs** couples index/meta (FAISS + JSON). Par défaut :
Recharge **uniquement** les textes correspondants en gardant cet ordre ainsi, plus chap.idx / chap.meta.json
d'erreur d'index outofrange quelle que soit la découpe. Reconstitue les textes à partir des fichiers `path` indiqués dans la méta.
Recherche FAISS (topk=4) + génération via mistral7b-fast (Ollama). Les passages sont déjà prêts (1 par fichier court, ou découpés par index.py).
Recherche : embeddings BGEM3 (CPU) + FAISS (cosinus IP) sur tous les index.
topk configurable (déf. 20 pour index détaillé, 5 pour index chapitres).
trie ensuite les hits mettant en avant ceux contenant un motclé fourni
(ex. « seuil » pour ICS).
Génération : appelle Mistral7B (Ollama) avec temperature 0.1 et consigne :
« Réponds uniquement à partir du contexte. Si linfo manque : Je ne sais pas. »
Usage :
python rag.py [--k 25] [--kw seuil] [--model mistral7b-fast]
""" """
from __future__ import annotations
import json, readline, re import argparse, json, re, sys
from pathlib import Path from pathlib import Path
from collections import defaultdict
import faiss, numpy as np, requests import faiss, numpy as np, requests
from FlagEmbedding import BGEM3FlagModel from FlagEmbedding import BGEM3FlagModel
from rich import print from rich import print
ROOT = Path("Fiches") # dossier racine des fiches # ------------------------- CLI -------------------------------------------
K = 30 # nombre de passages remis au LLM p = argparse.ArgumentParser()
p.add_argument("--index", nargs="*", default=["chap.idx"],
help="Liste des fichiers FAISS à charger (déf. chap.idx)")
p.add_argument("--meta", nargs="*", default=["chap.meta.json"],
help="Liste des méta JSON assortis (même ordre que --index)")
p.add_argument("--k", type=int, default=15, help="topk cumulés (déf. 15)")
p.add_argument("--kw", default="seuil", help="motclé boosté (déf. seuil)")
p.add_argument("--model", default="mistral7b-fast", help="modèle Ollama")
args = p.parse_args()
# ------------------ utilitaires de découpe identiques à l'index ------------- if len(args.index) != len(args.meta):
CHUNK, OVERLAP = 800, 100 # garder cohérent avec index.py print("[red]Erreur : --index et --meta doivent avoir la même longueur.")
sys.exit(1)
def split(text: str): # ------------------------- charger indexes -------------------------------
sents = re.split(r"(?<=[.!?]) +", text) indexes, metas, start_offset = [], [], []
buf, out = [], [] offset = 0
for s in sents: for idx_f, meta_f in zip(args.index, args.meta):
buf.append(s) idx = faiss.read_index(str(idx_f))
if len(" ".join(buf).split()) > CHUNK: meta = json.load(open(meta_f))
out.append(" ".join(buf)) if idx.ntotal != len(meta):
buf = buf[-OVERLAP:] print(f"[yellow]Avertissement : {idx_f} contient {idx.ntotal} vecteurs, meta {len(meta)} lignes.[/]")
if buf: indexes.append(idx)
out.append(" ".join(buf)) metas.append(meta)
return out start_offset.append(offset)
offset += idx.ntotal
# ------------------- charger meta et reconstruire passages ------------------ total_passages = offset
meta_path = Path("corpus.meta.json") print(f"Passages chargés : {total_passages} (agrégat de {len(indexes)} index)")
if not meta_path.exists():
raise SystemExit("corpus.meta.json introuvable lancez d'abord index.py")
meta = json.load(meta_path.open())
# mapping (file, part) -> chunk text # ------------------------- cache texte -----------------------------------
cache: dict[tuple[str, int], str] = {} DOCS: dict[int,str] = {}
for fp in sorted(ROOT.rglob("*")): for base_offset, meta in zip(start_offset, metas):
if fp.suffix.lower() not in {".md", ".markdown", ".txt"}: for i, m in enumerate(meta):
continue DOCS[base_offset + i] = Path(m["path"]).read_text(encoding="utf-8")
chunks = split(fp.read_text(encoding="utf-8")) print("[dim]Cache texte préchargé.[/]")
for i, ch in enumerate(chunks):
cache[(fp.name, i)] = ch
# reconstruire docs dans le même ordre que l'index --------------------------- # ------------------------- modèle embeddings -----------------------------
docs = [] embedder = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
for m in meta:
# compatibilité avec lancien et le nouveau format
path = m.get("file") or m.get("path") # nouvelle clé : "path"
part = m["part"]
key = (Path(path).name, part) # on garde le nom court pour le cache
docs.append(cache.get(key, "[passage manquant]")) # ------------------------- helpers ---------------------------------------
def encode_query(q: str):
emb = embedder.encode([q])
if isinstance(emb, dict):
emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
v = emb[0]
return (v / np.linalg.norm(v)).astype("float32").reshape(1, -1)
print(f"[dim]Passages rechargés : {len(docs)} (ordre conforme à l'index).[/]") def search_all(vec):
hits = []
for idx, off in zip(indexes, start_offset):
D, I = idx.search(vec, min(args.k, idx.ntotal))
hits.extend([off + int(i) for i in I[0]])
return hits
# ---------------- FAISS + modèle embeddings -------------------------------- # ------------------------- boucle interactive ----------------------------
idx = faiss.read_index("corpus.idx") print("RAG prêt ! (CtrlD pour quitter)")
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu") while True:
try:
q = input("❓ > ").strip()
except (EOFError, KeyboardInterrupt):
print("\nBye."); break
if not q: continue
# ---------------- boucle interactive --------------------------------------- # correction rapide de typos courantes (substituabilité…)
print("RAG prêt. Posez vos questions ! (CtrlD pour sortir)") q_norm = re.sub(r"susbtitu[a-z]+", "substituabilité", q, flags=re.I)
try:
while True:
try:
q = input("❓ > ").strip()
if not q:
continue
except (EOFError, KeyboardInterrupt):
print("\nBye."); break
emb = model.encode([q]) vec = encode_query(q_norm)
if isinstance(emb, dict): hits = search_all(vec)
# récupère le 1er ndarray trouvé
emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
q_emb = emb[0] / np.linalg.norm(emb[0])
D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K) # Boost lexical : passages contenant le motclé args.kw dabord
hits = I[0] kw_lower = args.kw.lower()
# réordonne pour mettre en tête les passages contenant “Seuil” hits.sort(key=lambda i: kw_lower not in DOCS[i].lower())
hits = sorted(hits, key=lambda i: "Seuil" not in docs[int(i)]) hits = hits[:args.k]
context = "\n\n".join(docs[int(i)] for i in hits[:K]) context = "\n\n".join(DOCS[i] for i in hits)
prompt = (
"<system>Réponds en français, de façon précise, et uniquement à partir du contexte fourni. Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.</system>\n"
f"<context>{context}</context>\n<user>{q}</user>"
)
def ask_llm(p): prompt = (
r = requests.post("http://127.0.0.1:11434/api/generate", json={ "<system>Réponds en français, de façon précise et uniquement à partir du contexte. "
"model": "mistral7b-fast", "Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.</system>\n"
"prompt": p, f"<context>{context}</context>\n"
"stream": False, f"<user>{q}</user>"
"options": {"temperature": 0.0, "num_predict": 512} )
}, timeout=300)
return r.json()["response"]
print("\n[bold]Réponse :[/]") r = requests.post("http://127.0.0.1:11434/api/generate", json={
print(ask_llm(prompt)) "model": args.model,
"prompt": prompt,
"stream": False,
"options": {"temperature": 0.1, "num_predict": 512}
}, timeout=300)
answer = r.json().get("response", "(erreur API)")
print("\n[dim]--- contexte utilisé ---[/]") print("\n[bold]Réponse :[/]\n", answer)
for rank, idx_id in enumerate(hits, 1): print("\n[dim]--- contexte utilisé (top " + str(len(hits)) + ") ---[/]")
m = meta[int(idx_id)] for rank, idx_id in enumerate(hits, 1):
print(f"[{rank}] {Path(m.get('file') or m.get('path')).name} · part {m['part']} → …") m = metas[0] # non utilisé ici, on affiche juste le nom
except Exception as e: path = DOCS[idx_id].splitlines()[0][:60]
print("[red]Erreur :", e) print(f"[{rank}] … {path}")