From c1a0a8e072f77195d54e7f6f66af430c63c764a5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?St=C3=A9phan?= Date: Mon, 19 May 2025 06:21:30 +0200 Subject: [PATCH] Update rag.py --- rag.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/rag.py b/rag.py index eeba59e..db6a9b8 100644 --- a/rag.py +++ b/rag.py @@ -31,8 +31,17 @@ while True: print("\nBye."); break # embeddings & recherche FAISS (top-k=4) - q_emb = model.encode([q], normalize_embeddings=True) - D, I = idx.search(q_emb.astype("float32"), 4) + # remplace ces deux lignes (32-34) + # q_emb = model.encode([q], normalize_embeddings=True) + # D, I = idx.search(q_emb.astype("float32"), 4) + + emb = model.encode([q]) # ndarray (1, 1024) + if isinstance(emb, dict): # selon la version de FlagEmbedding + emb = emb.get("embedding") or emb.get("embeddings") + q_emb = emb[0] / np.linalg.norm(emb[0]) # L2 normalisation + + D, I = idx.search(q_emb.astype("float32").reshape(1, -1), 4) + ctx_blocks = [] for rank, idx_id in enumerate(I[0]):