Update rag.py
This commit is contained in:
parent
c1a0a8e072
commit
852b81ba93
131
rag.py
131
rag.py
@ -1,58 +1,107 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import faiss, json, requests, readline, numpy as np
|
"""
|
||||||
from rich import print
|
Mini RAG interactif :
|
||||||
|
• Recherche sémantique FAISS sur le corpus.idx / corpus.meta.json
|
||||||
|
• Contexte (top‑k=4 passages) envoyé à Mistral‑7B via Ollama.
|
||||||
|
|
||||||
|
Robuste aux différentes sorties de BGEM3FlagModel.encode.
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import readline
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import faiss
|
||||||
|
import numpy as np
|
||||||
|
import requests
|
||||||
from FlagEmbedding import BGEM3FlagModel
|
from FlagEmbedding import BGEM3FlagModel
|
||||||
|
from rich import print
|
||||||
|
|
||||||
# --- chargements -------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
idx = faiss.read_index("corpus.idx")
|
# Chargements initiaux
|
||||||
meta = json.load(open("corpus.meta.json"))
|
# ---------------------------------------------------------------------------
|
||||||
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu") # même qu’à l’indexation
|
IDX_FILE = Path("corpus.idx")
|
||||||
|
META_FILE = Path("corpus.meta.json")
|
||||||
|
|
||||||
# simple aide mémoire pour retrouver rapidement un passage
|
if not IDX_FILE.exists() or not META_FILE.exists():
|
||||||
def fetch_passage(i):
|
raise SystemExit("[bold red]Erreur :[/] index absent. Lancez d'abord index.py !")
|
||||||
|
|
||||||
|
index = faiss.read_index(str(IDX_FILE))
|
||||||
|
meta = json.loads(META_FILE.read_text())
|
||||||
|
model = BGEM3FlagModel("BAAI/bge-m3", device="cpu")
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Utilitaires
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _normalize(x: np.ndarray) -> np.ndarray:
|
||||||
|
"""L2‑normalize each row (tokens=float32)."""
|
||||||
|
return x / (np.linalg.norm(x, axis=1, keepdims=True) + 1e-12)
|
||||||
|
|
||||||
|
|
||||||
|
def embed(texts):
|
||||||
|
"""Encode list[str] → ndarray (n, dim), quelle que soit la sortie lib."""
|
||||||
|
out = model.encode(texts)
|
||||||
|
# Possible shapes :
|
||||||
|
# • ndarray
|
||||||
|
# • dict {"embedding": ndarray} ou {"embeddings": ndarray}
|
||||||
|
# • dict {"sentence_embeds": [...]} etc.
|
||||||
|
if isinstance(out, np.ndarray):
|
||||||
|
arr = out
|
||||||
|
elif isinstance(out, dict):
|
||||||
|
# pick the first ndarray-like value
|
||||||
|
for v in out.values():
|
||||||
|
if isinstance(v, (list, tuple)) or hasattr(v, "shape"):
|
||||||
|
arr = np.asarray(v)
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
raise TypeError("encode() dict sans clé embedding !")
|
||||||
|
else: # list[list[float]] etc.
|
||||||
|
arr = np.asarray(out)
|
||||||
|
return _normalize(arr.astype("float32"))
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_passage(i: int) -> str:
|
||||||
m = meta[i]
|
m = meta[i]
|
||||||
return f"[{m['file']} · part {m['part']}] {m['text']}"
|
return f"[{m['file']} · part {m['part']}] {m['text']}"
|
||||||
|
|
||||||
def ask_llm(prompt):
|
|
||||||
r = requests.post("http://127.0.0.1:11434/api/generate", json={
|
def ask_llm(prompt: str) -> str:
|
||||||
"model": "mistral7b-fast",
|
r = requests.post(
|
||||||
"prompt": prompt,
|
"http://127.0.0.1:11434/api/generate",
|
||||||
"stream": False,
|
json={
|
||||||
"options": {"temperature":0.2, "num_predict":512}
|
"model": "mistral7b-fast",
|
||||||
}, timeout=300)
|
"prompt": prompt,
|
||||||
|
"stream": False,
|
||||||
|
"options": {"temperature": 0.2, "num_predict": 512},
|
||||||
|
},
|
||||||
|
timeout=300,
|
||||||
|
)
|
||||||
return r.json()["response"]
|
return r.json()["response"]
|
||||||
|
|
||||||
# --- boucle interactive ------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
# Boucle interactive
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
print("[bold green]RAG prêt.[/] Posez vos questions ! (Ctrl‑D pour sortir)")
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
q = input("❓ > ").strip()
|
q = input("❓ > ").strip()
|
||||||
if not q: continue
|
if not q:
|
||||||
except (KeyboardInterrupt, EOFError):
|
continue
|
||||||
print("\nBye."); break
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
print("\n[dim]Bye.[/]")
|
||||||
|
break
|
||||||
|
|
||||||
# embeddings & recherche FAISS (top-k=4)
|
q_emb = embed([q]) # (1, dim)
|
||||||
# remplace ces deux lignes (32-34)
|
D, I = index.search(q_emb, 4)
|
||||||
# q_emb = model.encode([q], normalize_embeddings=True)
|
ctx = "\n\n".join(fetch_passage(int(idx)) for idx in I[0])
|
||||||
# D, I = idx.search(q_emb.astype("float32"), 4)
|
|
||||||
|
|
||||||
emb = model.encode([q]) # ndarray (1, 1024)
|
prompt = (
|
||||||
if isinstance(emb, dict): # selon la version de FlagEmbedding
|
"<system>Réponds en français, précis et factuel.</system>\n"
|
||||||
emb = emb.get("embedding") or emb.get("embeddings")
|
f"<context>{ctx}</context>\n"
|
||||||
q_emb = emb[0] / np.linalg.norm(emb[0]) # L2 normalisation
|
f"<user>{q}</user>"
|
||||||
|
)
|
||||||
|
|
||||||
D, I = idx.search(q_emb.astype("float32").reshape(1, -1), 4)
|
print("\n[bold]Réponse :[/]\n")
|
||||||
|
|
||||||
|
|
||||||
ctx_blocks = []
|
|
||||||
for rank, idx_id in enumerate(I[0]):
|
|
||||||
ctx_blocks.append(fetch_passage(idx_id))
|
|
||||||
context = "\n\n".join(ctx_blocks)
|
|
||||||
|
|
||||||
prompt = f"""<system>Réponds en français, précis et factuel.</system>
|
|
||||||
<context>{context}</context>
|
|
||||||
<user>{q}</user>"""
|
|
||||||
|
|
||||||
print("\n[bold]Réponse :[/]\n")
|
|
||||||
print(ask_llm(prompt))
|
print(ask_llm(prompt))
|
||||||
print("\n[dim]--- contexte utilisé ---[/]")
|
print("\n[dim]--- contexte utilisé ---[/]")
|
||||||
print(context)
|
print(ctx)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user