diff --git a/rag.py b/rag.py index 2ecac8f..a223bb6 100644 --- a/rag.py +++ b/rag.py @@ -1,107 +1,91 @@ #!/usr/bin/env python3 """ -Mini RAG interactif : -• Recherche sémantique FAISS sur le corpus.idx / corpus.meta.json -• Contexte (top‑k=4 passages) envoyé à Mistral‑7B via Ollama. - -Robuste aux différentes sorties de BGEM3FlagModel.encode. +RAG interactif robuste. +• Recharge les passages à partir des fiches (même découpe que l'index) pour disposer du texte. +• Recherche FAISS top‑k=4 et génération via mistral7b-fast. """ -import json -import readline +import os, json, readline, re from pathlib import Path -import faiss -import numpy as np -import requests +import faiss, numpy as np, requests from FlagEmbedding import BGEM3FlagModel from rich import print # --------------------------------------------------------------------------- -# Chargements initiaux -# --------------------------------------------------------------------------- -IDX_FILE = Path("corpus.idx") -META_FILE = Path("corpus.meta.json") +ROOT = Path("Fiches") # dossier des fiches sur l'hôte +CHUNK, OVERLAP = 800, 100 # identiques à l'indexation +K = 4 # nb de passages remis au modèle -if not IDX_FILE.exists() or not META_FILE.exists(): - raise SystemExit("[bold red]Erreur :[/] index absent. Lancez d'abord index.py !") +# --- découpe --------------------------------------------------------------- -index = faiss.read_index(str(IDX_FILE)) -meta = json.loads(META_FILE.read_text()) +def split(text: str): + sents = re.split(r"(?<=[.!?]) +", text) + buf, out = [], [] + for s in sents: + buf.append(s) + if len(" ".join(buf).split()) > CHUNK: # approx 1 mot = 1 token + out.append(" ".join(buf)) + buf = buf[-OVERLAP:] + if buf: + out.append(" ".join(buf)) + return out + +# --- charger docs + meta dans le même ordre que l'index -------------------- + +docs, meta = [], [] +for fp in ROOT.rglob("*.md"): + for i, chunk in enumerate(split(fp.read_text(encoding="utf-8"))): + docs.append(chunk) + meta.append({"file": fp.name, "part": i}) + +print(f"[dim]Chargé {len(docs)} passages depuis {ROOT}.[/]") + +# --- FAISS index existant --------------------------------------------------- + +idx = faiss.read_index("corpus.idx") model = BGEM3FlagModel("BAAI/bge-m3", device="cpu") -# --------------------------------------------------------------------------- -# Utilitaires -# --------------------------------------------------------------------------- +# --- boucle Q/A ------------------------------------------------------------- -def _normalize(x: np.ndarray) -> np.ndarray: - """L2‑normalize each row (tokens=float32).""" - return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-12) - - -def embed(texts): - """Encode list[str] → ndarray (n, dim), quelle que soit la sortie lib.""" - out = model.encode(texts) - # Possible shapes : - # • ndarray - # • dict {"embedding": ndarray} ou {"embeddings": ndarray} - # • dict {"sentence_embeds": [...]} etc. - if isinstance(out, np.ndarray): - arr = out - elif isinstance(out, dict): - # pick the first ndarray-like value - for v in out.values(): - if isinstance(v, (list, tuple)) or hasattr(v, "shape"): - arr = np.asarray(v) - break - else: - raise TypeError("encode() dict sans clé embedding !") - else: # list[list[float]] etc. - arr = np.asarray(out) - return _normalize(arr.astype("float32")) - - -def fetch_passage(i: int) -> str: +def fetch_passage(i: int): m = meta[i] - return f"[{m['file']} · part {m['part']}] {m['text']}" + return f"[{m['file']} · part {m['part']}] {docs[i][:200]}…" - -def ask_llm(prompt: str) -> str: - r = requests.post( - "http://127.0.0.1:11434/api/generate", - json={ - "model": "mistral7b-fast", - "prompt": prompt, - "stream": False, - "options": {"temperature": 0.2, "num_predict": 512}, - }, - timeout=300, - ) +def ask_llm(prompt: str): + r = requests.post("http://127.0.0.1:11434/api/generate", json={ + "model": "mistral7b-fast", + "prompt": prompt, + "stream": False, + "options": {"temperature": 0.2, "num_predict": 512} + }, timeout=300) return r.json()["response"] -# --------------------------------------------------------------------------- -# Boucle interactive -# --------------------------------------------------------------------------- -print("[bold green]RAG prêt.[/] Posez vos questions ! (Ctrl‑D pour sortir)") -while True: - try: - q = input("❓ > ").strip() - if not q: - continue - except (EOFError, KeyboardInterrupt): - print("\n[dim]Bye.[/]") - break +print("RAG prêt. Posez vos questions ! (Ctrl‑D pour sortir)") +try: + while True: + try: + q = input("❓ > ").strip() + if not q: + continue + except (EOFError, KeyboardInterrupt): + print("\nBye."); break - q_emb = embed([q]) # (1, dim) - D, I = index.search(q_emb, 4) - ctx = "\n\n".join(fetch_passage(int(idx)) for idx in I[0]) + emb = model.encode([q]) + if isinstance(emb, dict): + emb = next(v for v in emb.values() if isinstance(v, np.ndarray)) + q_emb = emb[0] / np.linalg.norm(emb[0]) - prompt = ( - "Réponds en français, précis et factuel.\n" - f"{ctx}\n" - f"{q}" - ) + D, I = idx.search(q_emb.astype("float32").reshape(1, -1), K) - print("\n[bold]Réponse :[/]\n") - print(ask_llm(prompt)) - print("\n[dim]--- contexte utilisé ---[/]") - print(ctx) + context = "\n\n".join(docs[int(idx_id)] for idx_id in I[0]) + prompt = f"""Réponds en français, précis et factuel.\n{context}\n{q}""" + + print("\n[bold]Réponse :[/]") + print(ask_llm(prompt)) + + # petite trace des sources + print("\n[dim]--- contexte utilisé ---[/]") + for idx_id in I[0]: + print(fetch_passage(int(idx_id))) +except Exception as e: + print("[red]Erreur :", e)