#!/usr/bin/env python3
"""Interactive mini-RAG.

* Semantic search with FAISS over ``corpus.idx`` / ``corpus.meta.json``.
* The top-k=4 retrieved passages are sent as context to Mistral-7B
  through a local Ollama server.

Robust to the different return shapes of ``BGEM3FlagModel.encode``.
"""
import json
import readline  # noqa: F401 -- side-effect import: line editing/history for input()
from pathlib import Path

import faiss
import numpy as np
import requests
from FlagEmbedding import BGEM3FlagModel
from rich import print

# ---------------------------------------------------------------------------
# Initial loading
# ---------------------------------------------------------------------------
IDX_FILE = Path("corpus.idx")
META_FILE = Path("corpus.meta.json")

if not IDX_FILE.exists() or not META_FILE.exists():
    # Fail fast with a readable hint when the index has not been built yet.
    raise SystemExit("[bold red]Erreur :[/] index absent. \nLancez d'abord index.py !")

index = faiss.read_index(str(IDX_FILE))
meta = json.loads(META_FILE.read_text())
# NOTE: must be the same model (and device) that was used at indexing time,
# otherwise query and corpus embeddings live in different spaces.
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _normalize(x: np.ndarray) -> np.ndarray:
    """L2-normalize each row of *x*; the epsilon guards against zero rows."""
    return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-12)


def embed(texts) -> np.ndarray:
    """Encode ``list[str]`` into a float32 ndarray of shape (n, dim).

    Depending on the FlagEmbedding version, ``encode()`` may return:
      * a plain ndarray,
      * a dict such as ``{"dense_vecs": ...}``, ``{"embedding": ...}`` or
        ``{"embeddings": ...}`` (BGE-M3 notably returns a dict that can also
        carry ``lexical_weights`` / ``colbert_vecs``),
      * a nested list of floats.

    Raises:
        TypeError: if a dict output contains no array-like value at all.
    """
    out = model.encode(texts)
    if isinstance(out, np.ndarray):
        arr = out
    elif isinstance(out, dict):
        # Prefer the keys known to hold the dense embedding; picking the
        # *first* array-like value blindly could normalize lexical weights
        # instead of the dense vectors if the dict order changes.
        for key in ("dense_vecs", "embedding", "embeddings"):
            if key in out:
                arr = np.asarray(out[key])
                break
        else:
            # Fallback: first array-like value (original best-effort behavior).
            for v in out.values():
                if isinstance(v, (list, tuple)) or hasattr(v, "shape"):
                    arr = np.asarray(v)
                    break
            else:
                raise TypeError("encode() dict sans clé embedding !")
    else:  # e.g. list[list[float]]
        arr = np.asarray(out)
    return _normalize(arr.astype("float32"))


def fetch_passage(i: int) -> str:
    """Return passage *i* formatted as ``[file · part N] text``."""
    m = meta[i]
    return f"[{m['file']} · part {m['part']}] {m['text']}"


def ask_llm(prompt: str) -> str:
    """Send *prompt* to the local Ollama server and return its answer.

    Raises:
        requests.RequestException: on network failure, timeout, or a non-2xx
            HTTP status (raise_for_status turns the latter into a clear
            HTTPError instead of an opaque ``KeyError: 'response'``).
    """
    r = requests.post(
        "http://127.0.0.1:11434/api/generate",
        json={
            "model": "mistral7b-fast",
            "prompt": prompt,
            "stream": False,
            "options": {"temperature": 0.2, "num_predict": 512},
        },
        timeout=300,
    )
    r.raise_for_status()
    return r.json()["response"]


# ---------------------------------------------------------------------------
# Interactive loop
# ---------------------------------------------------------------------------
print("[bold green]RAG prêt.[/] Posez vos questions ! (Ctrl‑D pour sortir)")
while True:
    try:
        q = input("❓ > ").strip()
        if not q:
            continue
    except (EOFError, KeyboardInterrupt):
        print("\n[dim]Bye.[/]")
        break

    q_emb = embed([q])  # shape (1, dim), L2-normalized
    D, I = index.search(q_emb, 4)  # top-4 nearest passages
    ctx = "\n\n".join(fetch_passage(int(idx)) for idx in I[0])

    prompt = (
        "Réponds en français, précis et factuel.\n"
        f"{ctx}\n"
        f"{q}"
    )

    print("\n[bold]Réponse :[/]\n")
    print(ask_llm(prompt))
    print("\n[dim]--- contexte utilisé ---[/]")
    print(ctx)