Update rag.py
This commit is contained in:
parent
a569c71ad4
commit
c1a0a8e072
13
rag.py
13
rag.py
@ -31,8 +31,17 @@ while True:
|
|||||||
print("\nBye."); break
|
print("\nBye."); break
|
||||||
|
|
||||||
# embeddings & recherche FAISS (top-k=4)
|
# embeddings & recherche FAISS (top-k=4)
|
||||||
q_emb = model.encode([q], normalize_embeddings=True)
|
# remplace ces deux lignes (32-34)
|
||||||
D, I = idx.search(q_emb.astype("float32"), 4)
|
# q_emb = model.encode([q], normalize_embeddings=True)
|
||||||
|
# D, I = idx.search(q_emb.astype("float32"), 4)
|
||||||
|
|
||||||
|
emb = model.encode([q]) # ndarray (1, 1024)
|
||||||
|
if isinstance(emb, dict): # selon la version de FlagEmbedding
|
||||||
|
emb = emb.get("embedding") or emb.get("embeddings")
|
||||||
|
q_emb = emb[0] / np.linalg.norm(emb[0]) # L2 normalisation
|
||||||
|
|
||||||
|
D, I = idx.search(q_emb.astype("float32").reshape(1, -1), 4)
|
||||||
|
|
||||||
|
|
||||||
ctx_blocks = []
|
ctx_blocks = []
|
||||||
for rank, idx_id in enumerate(I[0]):
|
for rank, idx_id in enumerate(I[0]):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user