#!/usr/bin/env python3 """ rag.py — recherche + génération (version robuste, chapitres) ============================================================ • Charge **un ou plusieurs** couples index/meta (FAISS + JSON). Par défaut : rapport.idx / rapport.meta.json • Reconstitue les textes à partir des fichiers `path` indiqués dans la méta. – Les passages sont déjà prêts (1 par fichier court, ou découpés par index.py). • Recherche : embeddings BGE‑M3 (CPU) + FAISS (cosinus IP) sur tous les index. – top‑k configurable (déf. 20 pour index détaillé, 5 pour index chapitres). – trie ensuite les hits mettant en avant ceux contenant un mot‑clé fourni (ex. « seuil » pour ICS). • Génération : appelle llama3-8b-fast (Ollama) avec temperature 0.1 et consigne : « Réponds uniquement à partir du contexte. Si l’info manque : Je ne sais pas. » Usage : python rag.py [--k 25] [--kw seuil] [--model llama3-8b-fast] """ from __future__ import annotations import argparse, json, re, sys from pathlib import Path import faiss, numpy as np, requests from FlagEmbedding import BGEM3FlagModel from rich import print ROOT = Path("Rapport") # ------------------------- CLI ------------------------------------------- p = argparse.ArgumentParser() p.add_argument("--index", nargs="*", default=["rapport.idx"], help="Liste des fichiers FAISS à charger (déf. rapport.idx)") p.add_argument("--meta", nargs="*", default=["rapport.meta.json"], help="Liste des méta JSON assortis (même ordre que --index)") p.add_argument("--k", type=int, default=15, help="top‑k cumulés (déf. 15)") p.add_argument("--kw", default="seuil", help="mot‑clé boosté (déf. seuil)") p.add_argument("--model", default="llama3-8b-fast", help="modèle Ollama") args = p.parse_args() if len(args.index) != len(args.meta): print("[red]Erreur : --index et --meta doivent avoir la même longueur.") sys.exit(1) # ------------------------- charger indexes ------------------------------- indexes, metas, start_offset = [], [], [] offset = 0 for idx_f, meta_f in zip(args.index, args.meta): idx = faiss.read_index(str(idx_f)) meta = json.load(open(meta_f)) if idx.ntotal != len(meta): print(f"[yellow]Avertissement : {idx_f} contient {idx.ntotal} vecteurs, meta {len(meta)} lignes.[/]") indexes.append(idx) metas.append(meta) start_offset.append(offset) offset += idx.ntotal total_passages = offset print(f"Passages chargés : {total_passages} (agrégat de {len(indexes)} index)") # ------------------------- cache texte ----------------------------------- DOCS: dict[int,str] = {} for base_offset, meta in zip(start_offset, metas): for i, m in enumerate(meta): rel_path = m.get("path") or m.get("file") full_path = ROOT / rel_path DOCS[base_offset + i] = full_path.read_text(encoding="utf-8") print("[dim]Cache texte préchargé.[/]") # ------------------------- modèle embeddings ----------------------------- embedder = BGEM3FlagModel("BAAI/bge-m3", device="cpu") # ------------------------- helpers --------------------------------------- def encode_query(q: str): emb = embedder.encode([q]) if isinstance(emb, dict): emb = next(v for v in emb.values() if isinstance(v, np.ndarray)) v = emb[0] return (v / np.linalg.norm(v)).astype("float32").reshape(1, -1) def search_all(vec): hits = [] for idx, off in zip(indexes, start_offset): D, I = idx.search(vec, min(args.k, idx.ntotal)) hits.extend([off + int(i) for i in I[0]]) return hits # ------------------------- boucle interactive ---------------------------- print("RAG prêt ! (Ctrl‑D pour quitter)") while True: try: q = input("❓ > ").strip() except (EOFError, KeyboardInterrupt): print("\nBye."); break if not q: continue # correction rapide de typos courantes (substituabilité…) q_norm = re.sub(r"susbtitu[a-z]+", "substituabilité", q, flags=re.I) vec = encode_query(q_norm) hits = search_all(vec) # Boost lexical : passages contenant le mot‑clé args.kw d’abord kw_lower = args.kw.lower() hits.sort(key=lambda i: kw_lower not in DOCS[i].lower()) hits = hits[:args.k] context = "\n\n".join(DOCS[i] for i in hits) prompt = ( "Réponds en français, de façon précise et uniquement à partir du contexte. " "Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.\n" f"{context}\n" f"{q}" ) r = requests.post("http://127.0.0.1:11434/api/generate", json={ "model": args.model, "prompt": prompt, "stream": False, "options": {"temperature": 0.1, "num_predict": 512} }, timeout=300) answer = r.json().get("response", "(erreur API)") print("\n[bold]Réponse :[/]\n", answer) print("\n[dim]--- contexte utilisé (top " + str(len(hits)) + ") ---[/]") for rank, idx_id in enumerate(hits, 1): m = metas[0] # non utilisé ici, on affiche juste le nom path = DOCS[idx_id].splitlines()[0][:250] print(f"[{rank}] … {path}…")