diff --git a/index.py b/index.py index 7e78524..195826e 100644 --- a/index.py +++ b/index.py @@ -98,7 +98,9 @@ def main(): ) model = BGEM3FlagModel(MODEL_NAME, device="cpu") - emb = model.encode(new_docs, batch_size=BATCH, return_dict=False) + emb_out = model.encode(new_docs, batch_size=BATCH) + # FlagEmbedding renvoie soit un ndarray, soit un dict {"embedding": …} + emb = emb_out["embedding"] if isinstance(emb_out, dict) else emb_out emb = emb.astype("float32") emb /= np.linalg.norm(emb, axis=1, keepdims=True) + 1e-12