diff --git a/rag.py b/rag.py
index 1ed322c..290c292 100644
--- a/rag.py
+++ b/rag.py
@@ -1,113 +1,124 @@
#!/usr/bin/env python3
"""
-RAG interactif – version alignée sur l'index
--------------------------------------------
-• Utilise corpus.idx + corpus.meta.json pour connaître l'ordre exact des passages.
-• Recharge **uniquement** les textes correspondants en gardant cet ordre – ainsi, plus
- d'erreur d'index out‑of‑range quelle que soit la découpe.
-• Recherche FAISS (top‑k=4) + génération via mistral7b-fast (Ollama).
+rag.py — recherche + génération (version robuste, chapitres)
+============================================================
+• Charge **un ou plusieurs** couples index/meta (FAISS + JSON). Par défaut :
+ chap.idx / chap.meta.json
+• Reconstitue les textes à partir des fichiers `path` indiqués dans la méta.
+ – Les passages sont déjà prêts (1 par fichier court, ou découpés par index.py).
+• Recherche : embeddings BGE‑M3 (CPU) + FAISS (cosinus IP) sur tous les index.
+ – top‑k configurable (déf. 20 pour index détaillé, 5 pour index chapitres).
+ – trie ensuite les hits mettant en avant ceux contenant un mot‑clé fourni
+ (ex. « seuil » pour ICS).
+• Génération : appelle Mistral‑7B (Ollama) avec temperature 0.1 et consigne :
+ « Réponds uniquement à partir du contexte. Si l’info manque : Je ne sais pas. »
+
+Usage :
+ python rag.py [--k 25] [--kw seuil] [--model mistral7b-fast]
"""
-
-import json, readline, re
+from __future__ import annotations
+import argparse, json, re, sys
from pathlib import Path
-from collections import defaultdict
-
import faiss, numpy as np, requests
from FlagEmbedding import BGEM3FlagModel
from rich import print
-ROOT = Path("Fiches") # dossier racine des fiches
-K = 30 # nombre de passages remis au LLM
+# ------------------------- CLI -------------------------------------------
+p = argparse.ArgumentParser()
+p.add_argument("--index", nargs="*", default=["chap.idx"],
+ help="Liste des fichiers FAISS à charger (déf. chap.idx)")
+p.add_argument("--meta", nargs="*", default=["chap.meta.json"],
+ help="Liste des méta JSON assortis (même ordre que --index)")
+p.add_argument("--k", type=int, default=15, help="top‑k cumulés (déf. 15)")
+p.add_argument("--kw", default="seuil", help="mot‑clé boosté (déf. seuil)")
+p.add_argument("--model", default="mistral7b-fast", help="modèle Ollama")
+args = p.parse_args()
-# ------------------ utilitaires de découpe identiques à l'index -------------
-CHUNK, OVERLAP = 800, 100 # garder cohérent avec index.py
+if len(args.index) != len(args.meta):
+ print("[red]Erreur : --index et --meta doivent avoir la même longueur.")
+ sys.exit(1)
-def split(text: str):
- sents = re.split(r"(?<=[.!?]) +", text)
- buf, out = [], []
- for s in sents:
- buf.append(s)
- if len(" ".join(buf).split()) > CHUNK:
- out.append(" ".join(buf))
- buf = buf[-OVERLAP:]
- if buf:
- out.append(" ".join(buf))
- return out
+# ------------------------- charger indexes -------------------------------
+indexes, metas, start_offset = [], [], []
+offset = 0
+for idx_f, meta_f in zip(args.index, args.meta):
+ idx = faiss.read_index(str(idx_f))
+ meta = json.load(open(meta_f))
+ if idx.ntotal != len(meta):
+ print(f"[yellow]Avertissement : {idx_f} contient {idx.ntotal} vecteurs, meta {len(meta)} lignes.[/]")
+ indexes.append(idx)
+ metas.append(meta)
+ start_offset.append(offset)
+ offset += idx.ntotal
-# ------------------- charger meta et reconstruire passages ------------------
-meta_path = Path("corpus.meta.json")
-if not meta_path.exists():
- raise SystemExit("corpus.meta.json introuvable – lancez d'abord index.py")
-meta = json.load(meta_path.open())
+total_passages = offset
+print(f"Passages chargés : {total_passages} (agrégat de {len(indexes)} index)")
-# mapping (file, part) -> chunk text
-cache: dict[tuple[str, int], str] = {}
-for fp in sorted(ROOT.rglob("*")):
- if fp.suffix.lower() not in {".md", ".markdown", ".txt"}:
- continue
- chunks = split(fp.read_text(encoding="utf-8"))
- for i, ch in enumerate(chunks):
- cache[(fp.name, i)] = ch
+# ------------------------- cache texte -----------------------------------
+DOCS: dict[int,str] = {}
+for base_offset, meta in zip(start_offset, metas):
+ for i, m in enumerate(meta):
+ DOCS[base_offset + i] = Path(m["path"]).read_text(encoding="utf-8")
+print("[dim]Cache texte préchargé.[/]")
-# reconstruire docs dans le même ordre que l'index ---------------------------
-docs = []
-for m in meta:
- # compatibilité avec l’ancien et le nouveau format
- path = m.get("file") or m.get("path") # nouvelle clé : "path"
- part = m["part"]
- key = (Path(path).name, part) # on garde le nom court pour le cache
+# ------------------------- modèle embeddings -----------------------------
+embedder = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
- docs.append(cache.get(key, "[passage manquant]"))
+# ------------------------- helpers ---------------------------------------
+def encode_query(q: str):
+ emb = embedder.encode([q])
+ if isinstance(emb, dict):
+ emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
+ v = emb[0]
+ return (v / np.linalg.norm(v)).astype("float32").reshape(1, -1)
-print(f"[dim]Passages rechargés : {len(docs)} (ordre conforme à l'index).[/]")
+def search_all(vec):
+ hits = []
+ for idx, off in zip(indexes, start_offset):
+ D, I = idx.search(vec, min(args.k, idx.ntotal))
+ hits.extend([off + int(i) for i in I[0]])
+ return hits
-# ---------------- FAISS + modèle embeddings --------------------------------
-idx = faiss.read_index("corpus.idx")
-model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
+# ------------------------- boucle interactive ----------------------------
+print("RAG prêt ! (Ctrl‑D pour quitter)")
+while True:
+ try:
+ q = input("❓ > ").strip()
+ except (EOFError, KeyboardInterrupt):
+ print("\nBye."); break
+ if not q: continue
-# ---------------- boucle interactive ---------------------------------------
-print("RAG prêt. Posez vos questions ! (Ctrl‑D pour sortir)")
-try:
- while True:
- try:
- q = input("❓ > ").strip()
- if not q:
- continue
- except (EOFError, KeyboardInterrupt):
- print("\nBye."); break
+ # correction rapide de typos courantes (substituabilité…)
+ q_norm = re.sub(r"susbtitu[a-z]+", "substituabilité", q, flags=re.I)
- emb = model.encode([q])
- if isinstance(emb, dict):
- # récupère le 1er ndarray trouvé
- emb = next(v for v in emb.values() if isinstance(v, np.ndarray))
- q_emb = emb[0] / np.linalg.norm(emb[0])
+ vec = encode_query(q_norm)
+ hits = search_all(vec)
- D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K)
- hits = I[0]
- # réordonne pour mettre en tête les passages contenant “Seuil”
- hits = sorted(hits, key=lambda i: "Seuil" not in docs[int(i)])
+ # Boost lexical : passages contenant le mot‑clé args.kw d’abord
+ kw_lower = args.kw.lower()
+ hits.sort(key=lambda i: kw_lower not in DOCS[i].lower())
+ hits = hits[:args.k]
- context = "\n\n".join(docs[int(i)] for i in hits[:K])
- prompt = (
- "Réponds en français, de façon précise, et uniquement à partir du contexte fourni. Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.\n"
- f"{context}\n{q}"
- )
+ context = "\n\n".join(DOCS[i] for i in hits)
- def ask_llm(p):
- r = requests.post("http://127.0.0.1:11434/api/generate", json={
- "model": "mistral7b-fast",
- "prompt": p,
- "stream": False,
- "options": {"temperature": 0.0, "num_predict": 512}
- }, timeout=300)
- return r.json()["response"]
+ prompt = (
+ "Réponds en français, de façon précise et uniquement à partir du contexte. "
+ "Si l'information n'est pas dans le contexte, réponds : 'Je ne sais pas'.\n"
+ f"{context}\n"
+ f"{q}"
+ )
- print("\n[bold]Réponse :[/]")
- print(ask_llm(prompt))
+ r = requests.post("http://127.0.0.1:11434/api/generate", json={
+ "model": args.model,
+ "prompt": prompt,
+ "stream": False,
+ "options": {"temperature": 0.1, "num_predict": 512}
+ }, timeout=300)
+ answer = r.json().get("response", "(erreur API)")
- print("\n[dim]--- contexte utilisé ---[/]")
- for rank, idx_id in enumerate(hits, 1):
- m = meta[int(idx_id)]
- print(f"[{rank}] {Path(m.get('file') or m.get('path')).name} · part {m['part']} → …")
-except Exception as e:
- print("[red]Erreur :", e)
+ print("\n[bold]Réponse :[/]\n", answer)
+ print("\n[dim]--- contexte utilisé (top " + str(len(hits)) + ") ---[/]")
+ for rank, idx_id in enumerate(hits, 1):
+ m = metas[0] # non utilisé ici, on affiche juste le nom
+ path = DOCS[idx_id].splitlines()[0][:60]
+ print(f"[{rank}] … {path}…")