diff --git a/rag.py b/rag.py index a223bb6..f622a94 100644 --- a/rag.py +++ b/rag.py @@ -1,65 +1,67 @@ #!/usr/bin/env python3 """ -RAG interactif robuste. -• Recharge les passages à partir des fiches (même découpe que l'index) pour disposer du texte. -• Recherche FAISS top‑k=4 et génération via mistral7b-fast. +RAG interactif – version alignée sur l'index +------------------------------------------- +• Utilise corpus.idx + corpus.meta.json pour connaître l'ordre exact des passages. +• Recharge **uniquement** les textes correspondants en gardant cet ordre – ainsi, plus + d'erreur d'index out‑of‑range quelle que soit la découpe. +• Recherche FAISS (top‑k=4) + génération via mistral7b-fast (Ollama). """ -import os, json, readline, re + +import json, readline, re from pathlib import Path +from collections import defaultdict import faiss, numpy as np, requests from FlagEmbedding import BGEM3FlagModel from rich import print -# --------------------------------------------------------------------------- -ROOT = Path("Fiches") # dossier des fiches sur l'hôte -CHUNK, OVERLAP = 800, 100 # identiques à l'indexation -K = 4 # nb de passages remis au modèle +ROOT = Path("Fiches") # dossier racine des fiches +K = 4 # nombre de passages remis au LLM -# --- découpe --------------------------------------------------------------- +# ------------------ utilitaires de découpe identiques à l'index ------------- +CHUNK, OVERLAP = 800, 100 # garder cohérent avec index.py def split(text: str): sents = re.split(r"(?<=[.!?]) +", text) buf, out = [], [] for s in sents: buf.append(s) - if len(" ".join(buf).split()) > CHUNK: # approx 1 mot = 1 token + if len(" ".join(buf).split()) > CHUNK: out.append(" ".join(buf)) buf = buf[-OVERLAP:] if buf: out.append(" ".join(buf)) return out -# --- charger docs + meta dans le même ordre que l'index -------------------- +# ------------------- charger meta et reconstruire passages ------------------ +meta_path = Path("corpus.meta.json") +if not meta_path.exists(): + raise SystemExit("corpus.meta.json introuvable – lancez d'abord index.py") +meta = json.load(meta_path.open()) -docs, meta = [], [] -for fp in ROOT.rglob("*.md"): - for i, chunk in enumerate(split(fp.read_text(encoding="utf-8"))): - docs.append(chunk) - meta.append({"file": fp.name, "part": i}) +# mapping (file, part) -> chunk text +cache: dict[tuple[str, int], str] = {} +for fp in sorted(ROOT.rglob("*")): + if fp.suffix.lower() not in {".md", ".markdown", ".txt"}: + continue + chunks = split(fp.read_text(encoding="utf-8")) + for i, ch in enumerate(chunks): + cache[(fp.name, i)] = ch -print(f"[dim]Chargé {len(docs)} passages depuis {ROOT}.[/]") +# reconstruire docs dans le même ordre que l'index --------------------------- +docs = [] +for m in meta: + key = (m["file"], m["part"]) + docs.append(cache.get(key, "[passage manquant]")) -# --- FAISS index existant --------------------------------------------------- +print(f"[dim]Passages rechargés : {len(docs)} (ordre conforme à l'index).[/]") -idx = faiss.read_index("corpus.idx") +# ---------------- FAISS + modèle embeddings -------------------------------- +idx = faiss.read_index("corpus.idx") model = BGEM3FlagModel("BAAI/bge-m3", device="cpu") -# --- boucle Q/A ------------------------------------------------------------- - -def fetch_passage(i: int): - m = meta[i] - return f"[{m['file']} · part {m['part']}] {docs[i][:200]}…" - -def ask_llm(prompt: str): - r = requests.post("http://127.0.0.1:11434/api/generate", json={ - "model": "mistral7b-fast", - "prompt": prompt, - "stream": False, - "options": {"temperature": 0.2, "num_predict": 512} - }, timeout=300) - return r.json()["response"] - +# ---------------- boucle interactive --------------------------------------- print("RAG prêt. Posez vos questions ! (Ctrl‑D pour sortir)") try: while True: @@ -72,20 +74,34 @@ try: emb = model.encode([q]) if isinstance(emb, dict): + # récupère le 1er ndarray trouvé emb = next(v for v in emb.values() if isinstance(v, np.ndarray)) q_emb = emb[0] / np.linalg.norm(emb[0]) D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K) + hits = I[0] - context = "\n\n".join(docs[int(idx_id)] for idx_id in I[0]) - prompt = f"""Réponds en français, précis et factuel.\n{context}\n{q}""" + context = "\n\n".join(docs[int(i)] for i in hits) + prompt = ( + "Réponds en français, précis et factuel.\n" + f"{context}\n{q}" + ) + + def ask_llm(p): + r = requests.post("http://127.0.0.1:11434/api/generate", json={ + "model": "mistral7b-fast", + "prompt": p, + "stream": False, + "options": {"temperature": 0.2, "num_predict": 512} + }, timeout=300) + return r.json()["response"] print("\n[bold]Réponse :[/]") print(ask_llm(prompt)) - # petite trace des sources print("\n[dim]--- contexte utilisé ---[/]") - for idx_id in I[0]: - print(fetch_passage(int(idx_id))) + for rank, idx_id in enumerate(hits, 1): + m = meta[int(idx_id)] + print(f"[{rank}] {m['file']} · part {m['part']} → {docs[int(idx_id)][:120]}…") except Exception as e: print("[red]Erreur :", e)