#!/usr/bin/env python3
"""
rag.py — recherche + génération (version robuste, chapitres)
============================================================
• Charge **un ou plusieurs** couples index/meta (FAISS + JSON). Par défaut :
rapport.idx / rapport.meta.json
• Reconstitue les textes à partir des fichiers `path` indiqués dans la méta.
– Les passages sont déjà prêts (1 par fichier court, ou découpés par index.py).
• Recherche : embeddings BGE‑M3 (CPU) + FAISS (cosinus IP) sur tous les index.
– top‑k configurable (déf. 20 pour index détaillé, 5 pour index chapitres).
– trie ensuite les hits mettant en avant ceux contenant un mot‑clé fourni
(ex. « seuil » pour ICS).
• Génération : appelle llama3-8b-fast (Ollama) avec temperature 0.1 et consigne :
« Réponds uniquement à partir du contexte. Si l’info manque : Je ne sais pas. »
Usage :
python rag.py [--k 25] [--kw seuil] [--model llama3-8b-fast]
"""
from __future__ import annotations
import argparse, json, re, sys
from pathlib import Path
import faiss, numpy as np, requests
from FlagEmbedding import BGEM3FlagModel
from rich import print
ROOT = Path("Rapport")
# ------------------------- CLI -------------------------------------------
p = argparse.ArgumentParser()
p.add_argument("--index", nargs="*", default=["rapport.idx"],
help="Liste des fichiers FAISS à charger (déf. rapport.idx)")
p.add_argument("--meta", nargs="*", default=["rapport.meta.json"],
help="Liste des méta JSON assortis (même ordre que --index)")
p.add_argument("--k", type=int, default=15, help="top‑k cumulés (déf. 15)")
p.add_argument("--kw", default="seuil", help="mot‑clé boosté (déf. seuil)")
p.add_argument("--model", default="llama3-8b-fast", help="modèle Ollama")
args = p.parse_args()
if len(args.index) != len(args.meta):
print("[red]Erreur : --index et --meta doivent avoir la même longueur.")
sys.exit(1)
# ------------------------- charger indexes -------------------------------
indexes, metas, start_offset = [], [], []
offset = 0
for idx_f, meta_f in zip(args.index, args.meta):
idx = faiss.read_index(str(idx_f))
meta = json.load(open(meta_f))
if idx.ntotal != len(meta):
print(f"[yellow]Avertissement : {idx_f} contient {idx.ntotal} vecteurs, meta {len(meta)} lignes.[/]")
indexes.append(idx)
metas.append(meta)
start_offset.append(offset)
offset += idx.ntotal
total_passages = offset
print(f"Passages chargés : {total_passages} (agrégat de {len(indexes)} index)")
# ------------------------- cache texte -----------------------------------
DOCS: dict[int,str] = {}
for base_offset, meta in zip(start_offset, metas):
for i, m in enumerate(meta):
rel_path = m.get("path") or m.get("file")
full_path = ROOT / rel_path
DOCS[base_offset + i] = full_path.read_text(encoding="utf-8")
print("[dim]Cache texte préchargé.[/]")
# ------------------------- modèle embeddings -----------------------------
embedder = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
# ------------------------- helpers ---------------------------------------
def encode_query(q: str):
emb = embedder.encode([q])
if isinstance(emb, dict):
emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
v = emb[0]
return (v / np.linalg.norm(v)).astype("float32").reshape(1, -1)
def search_all(vec):
hits = []
for idx, off in zip(indexes, start_offset):
D, I = idx.search(vec, min(args.k, idx.ntotal))
hits.extend([off + int(i) for i in I[0]])
return hits
# ------------------------- boucle interactive ----------------------------
print("RAG prêt ! (Ctrl‑D pour quitter)")
while True:
try:
q = input("❓ > ").strip()
except (EOFError, KeyboardInterrupt):
print("\nBye."); break
if not q: continue
# correction rapide de typos courantes (substituabilité…)
q_norm = re.sub(r"susbtitu[a-z]+", "substituabilité", q, flags=re.I)
vec = encode_query(q_norm)
hits = search_all(vec)
# Boost lexical : passages contenant le mot‑clé args.kw d’abord
kw_lower = args.kw.lower()
hits.sort(key=lambda i: kw_lower not in DOCS[i].lower())
hits = hits[:args.k]
context = "\n\n".join(DOCS[i] for i in hits)
prompt = (
"Réponds en français, de façon précise et uniquement à partir du contexte. "
"Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.\n"
f"{context}\n"
f"{q}"
)
r = requests.post("http://127.0.0.1:11434/api/generate", json={
"model": args.model,
"prompt": prompt,
"stream": False,
"options": {"temperature": 0.1, "num_predict": 512}
}, timeout=300)
answer = r.json().get("response", "(erreur API)")
print("\n[bold]Réponse :[/]\n", answer)
print("\n[dim]--- contexte utilisé (top " + str(len(hits)) + ") ---[/]")
for rank, idx_id in enumerate(hits, 1):
m = metas[0] # non utilisé ici, on affiche juste le nom
path = DOCS[idx_id].splitlines()[0][:250]
print(f"[{rank}] … {path}…")