From 4b16c2210e410e714a5e0152db1ec5acab7d39d3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phan?= Date: Mon, 19 May 2025 08:05:07 +0200 Subject: [PATCH] Update rag.py --- rag.py | 191 ++++++++++++++++++++++++++++++--------------------------- 1 file changed, 101 insertions(+), 90 deletions(-) diff --git a/rag.py b/rag.py index 1ed322c..290c292 100644 --- a/rag.py +++ b/rag.py @@ -1,113 +1,124 @@ #!/usr/bin/env python3 """ -RAG interactif – version alignée sur l'index -------------------------------------------- -• Utilise corpus.idx + corpus.meta.json pour connaître l'ordre exact des passages. -• Recharge **uniquement** les textes correspondants en gardant cet ordre – ainsi, plus - d'erreur d'index out‑of‑range quelle que soit la découpe. -• Recherche FAISS (top‑k=4) + génération via mistral7b-fast (Ollama). +rag.py — recherche + génération (version robuste, chapitres) +============================================================ +• Charge **un ou plusieurs** couples index/meta (FAISS + JSON). Par défaut : + chap.idx / chap.meta.json +• Reconstitue les textes à partir des fichiers `path` indiqués dans la méta. + – Les passages sont déjà prêts (1 par fichier court, ou découpés par index.py). +• Recherche : embeddings BGE‑M3 (CPU) + FAISS (cosinus IP) sur tous les index. + – top‑k configurable (déf. 20 pour index détaillé, 5 pour index chapitres). + – trie ensuite les hits mettant en avant ceux contenant un mot‑clé fourni + (ex. « seuil » pour ICS). +• Génération : appelle Mistral‑7B (Ollama) avec temperature 0.1 et consigne : + « Réponds uniquement à partir du contexte. Si l’info manque : Je ne sais pas. » + +Usage : + python rag.py [--k 25] [--kw seuil] [--model mistral7b-fast] """ - -import json, readline, re +from __future__ import annotations +import argparse, json, re, sys from pathlib import Path -from collections import defaultdict - import faiss, numpy as np, requests from FlagEmbedding import BGEM3FlagModel from rich import print -ROOT = Path("Fiches") # dossier racine des fiches -K = 30 # nombre de passages remis au LLM +# ------------------------- CLI ------------------------------------------- +p = argparse.ArgumentParser() +p.add_argument("--index", nargs="*", default=["chap.idx"], + help="Liste des fichiers FAISS à charger (déf. chap.idx)") +p.add_argument("--meta", nargs="*", default=["chap.meta.json"], + help="Liste des méta JSON assortis (même ordre que --index)") +p.add_argument("--k", type=int, default=15, help="top‑k cumulés (déf. 15)") +p.add_argument("--kw", default="seuil", help="mot‑clé boosté (déf. seuil)") +p.add_argument("--model", default="mistral7b-fast", help="modèle Ollama") +args = p.parse_args() -# ------------------ utilitaires de découpe identiques à l'index ------------- -CHUNK, OVERLAP = 800, 100 # garder cohérent avec index.py +if len(args.index) != len(args.meta): + print("[red]Erreur : --index et --meta doivent avoir la même longueur.") + sys.exit(1) -def split(text: str): - sents = re.split(r"(?<=[.!?]) +", text) - buf, out = [], [] - for s in sents: - buf.append(s) - if len(" ".join(buf).split()) > CHUNK: - out.append(" ".join(buf)) - buf = buf[-OVERLAP:] - if buf: - out.append(" ".join(buf)) - return out +# ------------------------- charger indexes ------------------------------- +indexes, metas, start_offset = [], [], [] +offset = 0 +for idx_f, meta_f in zip(args.index, args.meta): + idx = faiss.read_index(str(idx_f)) + meta = json.load(open(meta_f)) + if idx.ntotal != len(meta): + print(f"[yellow]Avertissement : {idx_f} contient {idx.ntotal} vecteurs, meta {len(meta)} lignes.[/]") + indexes.append(idx) + metas.append(meta) + start_offset.append(offset) + offset += idx.ntotal -# ------------------- charger meta et reconstruire passages ------------------ -meta_path = Path("corpus.meta.json") -if not meta_path.exists(): - raise SystemExit("corpus.meta.json introuvable – lancez d'abord index.py") -meta = json.load(meta_path.open()) +total_passages = offset +print(f"Passages chargés : {total_passages} (agrégat de {len(indexes)} index)") -# mapping (file, part) -> chunk text -cache: dict[tuple[str, int], str] = {} -for fp in sorted(ROOT.rglob("*")): - if fp.suffix.lower() not in {".md", ".markdown", ".txt"}: - continue - chunks = split(fp.read_text(encoding="utf-8")) - for i, ch in enumerate(chunks): - cache[(fp.name, i)] = ch +# ------------------------- cache texte ----------------------------------- +DOCS: dict[int,str] = {} +for base_offset, meta in zip(start_offset, metas): + for i, m in enumerate(meta): + DOCS[base_offset + i] = Path(m["path"]).read_text(encoding="utf-8") +print("[dim]Cache texte préchargé.[/]") -# reconstruire docs dans le même ordre que l'index --------------------------- -docs = [] -for m in meta: - # compatibilité avec l’ancien et le nouveau format - path = m.get("file") or m.get("path") # nouvelle clé : "path" - part = m["part"] - key = (Path(path).name, part) # on garde le nom court pour le cache +# ------------------------- modèle embeddings ----------------------------- +embedder = BGEM3FlagModel("BAAI/bge-m3", device="cpu") - docs.append(cache.get(key, "[passage manquant]")) +# ------------------------- helpers --------------------------------------- +def encode_query(q: str): + emb = embedder.encode([q]) + if isinstance(emb, dict): + emb = next(v for v in emb.values() if isinstance(v, np.ndarray)) + v = emb[0] + return (v / np.linalg.norm(v)).astype("float32").reshape(1, -1) -print(f"[dim]Passages rechargés : {len(docs)} (ordre conforme à l'index).[/]") +def search_all(vec): + hits = [] + for idx, off in zip(indexes, start_offset): + D, I = idx.search(vec, min(args.k, idx.ntotal)) + hits.extend([off + int(i) for i in I[0]]) + return hits -# ---------------- FAISS + modèle embeddings -------------------------------- -idx = faiss.read_index("corpus.idx") -model = BGEM3FlagModel("BAAI/bge-m3", device="cpu") +# ------------------------- boucle interactive ---------------------------- +print("RAG prêt ! (Ctrl‑D pour quitter)") +while True: + try: + q = input("❓ > ").strip() + except (EOFError, KeyboardInterrupt): + print("\nBye."); break + if not q: continue -# ---------------- boucle interactive --------------------------------------- -print("RAG prêt. Posez vos questions ! (Ctrl‑D pour sortir)") -try: - while True: - try: - q = input("❓ > ").strip() - if not q: - continue - except (EOFError, KeyboardInterrupt): - print("\nBye."); break + # correction rapide de typos courantes (substituabilité…) + q_norm = re.sub(r"susbtitu[a-z]+", "substituabilité", q, flags=re.I) - emb = model.encode([q]) - if isinstance(emb, dict): - # récupère le 1er ndarray trouvé - emb = next(v for v in emb.values() if isinstance(v, np.ndarray)) - q_emb = emb[0] / np.linalg.norm(emb[0]) + vec = encode_query(q_norm) + hits = search_all(vec) - D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K) - hits = I[0] - # réordonne pour mettre en tête les passages contenant “Seuil” - hits = sorted(hits, key=lambda i: "Seuil" not in docs[int(i)]) + # Boost lexical : passages contenant le mot‑clé args.kw d’abord + kw_lower = args.kw.lower() + hits.sort(key=lambda i: kw_lower not in DOCS[i].lower()) + hits = hits[:args.k] - context = "\n\n".join(docs[int(i)] for i in hits[:K]) - prompt = ( - "Réponds en français, de façon précise, et uniquement à partir du contexte fourni. Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.\n" - f"{context}\n{q}" - ) + context = "\n\n".join(DOCS[i] for i in hits) - def ask_llm(p): - r = requests.post("http://127.0.0.1:11434/api/generate", json={ - "model": "mistral7b-fast", - "prompt": p, - "stream": False, - "options": {"temperature": 0.0, "num_predict": 512} - }, timeout=300) - return r.json()["response"] + prompt = ( + "Réponds en français, de façon précise et uniquement à partir du contexte. " + "Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.\n" + f"{context}\n" + f"{q}" + ) - print("\n[bold]Réponse :[/]") - print(ask_llm(prompt)) + r = requests.post("http://127.0.0.1:11434/api/generate", json={ + "model": args.model, + "prompt": prompt, + "stream": False, + "options": {"temperature": 0.1, "num_predict": 512} + }, timeout=300) + answer = r.json().get("response", "(erreur API)") - print("\n[dim]--- contexte utilisé ---[/]") - for rank, idx_id in enumerate(hits, 1): - m = meta[int(idx_id)] - print(f"[{rank}] {Path(m.get('file') or m.get('path')).name} · part {m['part']} → …") -except Exception as e: - print("[red]Erreur :", e) + print("\n[bold]Réponse :[/]\n", answer) + print("\n[dim]--- contexte utilisé (top " + str(len(hits)) + ") ---[/]") + for rank, idx_id in enumerate(hits, 1): + m = metas[0] # non utilisé ici, on affiche juste le nom + path = DOCS[idx_id].splitlines()[0][:60] + print(f"[{rank}] … {path}…")