Code/rag.py
2025-05-19 13:49:43 +02:00

129 lines
5.1 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
rag.py — recherche + génération (version robuste, chapitres)
============================================================
• Charge **un ou plusieurs** couples index/meta (FAISS + JSON). Par défaut :
rapport.idx / rapport.meta.json
• Reconstitue les textes à partir des fichiers `path` indiqués dans la méta.
Les passages sont déjà prêts (1 par fichier court, ou découpés par index.py).
• Recherche : embeddings BGEM3 (CPU) + FAISS (cosinus IP) sur tous les index.
topk configurable (déf. 20 pour index détaillé, 5 pour index chapitres).
trie ensuite les hits mettant en avant ceux contenant un motclé fourni
(ex. « seuil » pour ICS).
• Génération : appelle llama3-8b-fast (Ollama) avec temperature 0.1 et consigne :
« Réponds uniquement à partir du contexte. Si linfo manque : Je ne sais pas. »
Usage :
python rag.py [--k 25] [--kw seuil] [--model llama3-8b-fast]
"""
from __future__ import annotations
import argparse, json, re, sys
from pathlib import Path
import faiss, numpy as np, requests
from FlagEmbedding import BGEM3FlagModel
from rich import print
ROOT = Path("Rapport")
# ------------------------- CLI -------------------------------------------
p = argparse.ArgumentParser()
p.add_argument("--index", nargs="*", default=["rapport.idx"],
help="Liste des fichiers FAISS à charger (déf. rapport.idx)")
p.add_argument("--meta", nargs="*", default=["rapport.meta.json"],
help="Liste des méta JSON assortis (même ordre que --index)")
p.add_argument("--k", type=int, default=15, help="topk cumulés (déf. 15)")
p.add_argument("--kw", default="seuil", help="motclé boosté (déf. seuil)")
p.add_argument("--model", default="llama3-8b-fast", help="modèle Ollama")
args = p.parse_args()
if len(args.index) != len(args.meta):
print("[red]Erreur : --index et --meta doivent avoir la même longueur.")
sys.exit(1)
# ------------------------- charger indexes -------------------------------
indexes, metas, start_offset = [], [], []
offset = 0
for idx_f, meta_f in zip(args.index, args.meta):
idx = faiss.read_index(str(idx_f))
meta = json.load(open(meta_f))
if idx.ntotal != len(meta):
print(f"[yellow]Avertissement : {idx_f} contient {idx.ntotal} vecteurs, meta {len(meta)} lignes.[/]")
indexes.append(idx)
metas.append(meta)
start_offset.append(offset)
offset += idx.ntotal
total_passages = offset
print(f"Passages chargés : {total_passages} (agrégat de {len(indexes)} index)")
# ------------------------- cache texte -----------------------------------
DOCS: dict[int,str] = {}
for base_offset, meta in zip(start_offset, metas):
for i, m in enumerate(meta):
rel_path = m.get("path") or m.get("file")
full_path = ROOT / rel_path
DOCS[base_offset + i] = full_path.read_text(encoding="utf-8")
print("[dim]Cache texte préchargé.[/]")
# ------------------------- modèle embeddings -----------------------------
embedder = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
# ------------------------- helpers ---------------------------------------
def encode_query(q: str):
emb = embedder.encode([q])
if isinstance(emb, dict):
emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
v = emb[0]
return (v / np.linalg.norm(v)).astype("float32").reshape(1, -1)
def search_all(vec):
hits = []
for idx, off in zip(indexes, start_offset):
D, I = idx.search(vec, min(args.k, idx.ntotal))
hits.extend([off + int(i) for i in I[0]])
return hits
# ------------------------- boucle interactive ----------------------------
print("RAG prêt ! (CtrlD pour quitter)")
while True:
try:
q = input("❓ > ").strip()
except (EOFError, KeyboardInterrupt):
print("\nBye."); break
if not q: continue
# correction rapide de typos courantes (substituabilité…)
q_norm = re.sub(r"susbtitu[a-z]+", "substituabilité", q, flags=re.I)
vec = encode_query(q_norm)
hits = search_all(vec)
# Boost lexical : passages contenant le motclé args.kw dabord
kw_lower = args.kw.lower()
hits.sort(key=lambda i: kw_lower not in DOCS[i].lower())
hits = hits[:args.k]
context = "\n\n".join(DOCS[i] for i in hits)
prompt = (
"<system>Réponds en français, de façon précise et uniquement à partir du contexte. "
"Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.</system>\n"
f"<context>{context}</context>\n"
f"<user>{q}</user>"
)
r = requests.post("http://127.0.0.1:11434/api/generate", json={
"model": args.model,
"prompt": prompt,
"stream": False,
"options": {"temperature": 0.1, "num_predict": 512}
}, timeout=300)
answer = r.json().get("response", "(erreur API)")
print("\n[bold]Réponse :[/]\n", answer)
print("\n[dim]--- contexte utilisé (top " + str(len(hits)) + ") ---[/]")
for rank, idx_id in enumerate(hits, 1):
m = metas[0] # non utilisé ici, on affiche juste le nom
path = DOCS[idx_id].splitlines()[0][:250]
print(f"[{rank}] … {path}")