Import de private_gpt et amléiorations de l'analyse IA

This commit is contained in:
Fabrication du Numérique 2025-05-27 17:21:49 +02:00
parent c4fffb829c
commit 95ede9c6f1
81 changed files with 6354 additions and 22 deletions

171
IA/get_regeneration_plan.py Normal file
View File

@ -0,0 +1,171 @@
from datetime import datetime
from collections import defaultdict, deque
import os
import sys
from datetime import timezone
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
# À adapter dans ton environnement
from config import GITEA_URL, ORGANISATION, DEPOT_FICHES, ENV
from utils.gitea import recuperer_date_dernier_commit
from IA.make_config import MAKE # MAKE doit être importé depuis un fichier de config
def get_mtime(path):
try:
return datetime.fromtimestamp(os.path.getmtime(path), tz=timezone.utc)
except FileNotFoundError:
return None
def get_commit_time(path_relative):
commits_url = f"{GITEA_URL}/repos/{ORGANISATION}/{DEPOT_FICHES}/commits?path={path_relative.replace("Fiches", "Documents")}&sha={ENV}"
return recuperer_date_dernier_commit(commits_url)
def resolve_path_from_where(where_str):
parts = where_str.split(".")
current = MAKE
path_stack = []
for part in parts:
if isinstance(current, dict) and part in current:
path_stack.append((part, current))
current = current[part]
else:
return None
if not isinstance(current, str):
return None
for i in range(len(path_stack) - 1, -1, -1):
key, context = path_stack[i]
if "directory" in context:
directory = context["directory"]
if "fiches" in where_str:
return os.path.join("Fiches", directory, current)
else:
return os.path.join(directory, current)
return None
def identifier_type_fiche(path):
for type_fiche, data in MAKE["fiches"].items():
if not isinstance(data, dict):
continue
directory = data.get("directory", "")
prefix = data.get("prefix", "")
base = os.path.join("Fiches", directory, prefix)
if path.startswith(base):
return type_fiche, data
raise ValueError("Type de fiche non reconnu")
def doit_regenerer(fichier, doc_deps, fiche_data=None):
mtime_fichier = get_mtime(fichier)
if fiche_data:
gitea_dep = fiche_data.get("depends_on", {}).get("gitea", {})
if gitea_dep.get("compare") == "file2commit":
commit_time = get_commit_time(fichier)
if commit_time and mtime_fichier and commit_time > mtime_fichier:
return True
for _, dep in doc_deps.items():
if isinstance(dep, dict) and "where" in dep:
source_path = resolve_path_from_where(dep["where"])
if source_path:
if dep["compare"] == "file2file":
mtime_source = get_mtime(source_path)
elif dep["compare"] == "file2commit":
mtime_source = get_commit_time(source_path)
else:
continue
if mtime_source and mtime_fichier and mtime_source > mtime_fichier:
return True
return False
def get_regeneration_plan(fiche_path):
def build_dependency_graph_complete(path, graph=None, visited=None):
if graph is None:
graph = defaultdict(set)
if visited is None:
visited = set()
if path in visited:
return graph
visited.add(path)
try:
_, fiche_data = identifier_type_fiche(path)
except ValueError:
return graph
depends = fiche_data.get("depends_on", {})
doc_deps = depends.get("document", {})
for _, dep_info in doc_deps.items():
if isinstance(dep_info, dict) and "where" in dep_info:
dep_path = resolve_path_from_where(dep_info["where"])
if dep_path:
graph[path].add(dep_path)
build_dependency_graph_complete(dep_path, graph, visited)
if "file" in fiche_data:
for fichier in fiche_data["file"].values():
dir_fiche = fiche_data.get("directory", "")
fichier_path = os.path.join("Fiches", dir_fiche, fichier)
if fichier_path not in graph:
graph[fichier_path] = set()
return graph
def topological_sort(graph):
in_degree = defaultdict(int)
for node in graph:
for dep in graph[node]:
in_degree[dep] += 1
queue = deque([node for node in graph if in_degree[node] == 0])
result = []
while queue:
node = queue.popleft()
result.append(node)
for dep in graph[node]:
in_degree[dep] -= 1
if in_degree[dep] == 0:
queue.append(dep)
all_nodes = set(graph.keys()).union(*graph.values())
for node in all_nodes:
if node not in result:
result.append(node)
return result[::-1]
graph = build_dependency_graph_complete(fiche_path)
sorted_fiches = topological_sort(graph)
if fiche_path not in sorted_fiches:
sorted_fiches.append(fiche_path)
to_regen = []
regen_flags = {}
for fiche in sorted_fiches:
print(f"=> {fiche}")
try:
_, fiche_data = identifier_type_fiche(fiche)
except ValueError:
fiche_data = None
depends = fiche_data.get("depends_on", {}) if fiche_data else {}
doc_deps = depends.get("document", {}) if depends else {}
doit = doit_regenerer(fiche, doc_deps, fiche_data)
if any(regen_flags.get(dep, False) for dep in graph.get(fiche, [])):
doit = True
regen_flags[fiche] = doit
if doit:
to_regen.append(fiche)
return to_regen
plan = get_regeneration_plan("Fiches/Minerai/Fiche minerai antimoine.md")
print(plan)

141
IA/make_config.py Normal file
View File

@ -0,0 +1,141 @@
from utils.gitea import recuperer_date_dernier_commit
#
# from config import GITEA_URL, GITEA_TOKEN, ORGANISATION, DEPOT_FICHES, DEPOT_CODE, ENV, ENV_CODE, DOT_FILE
#
#def recuperer_date_dernier_commit(url):
# headers = {"Authorization": f"token {GITEA_TOKEN}"}
# try:
# response = requests.get(url, headers=headers, timeout=10)
# response.raise_for_status()
# commits = response.json()
# if commits:
# return parser.isoparse(commits[0]["commit"]["author"]["date"])
# except Exception as e:
# logging.error(f"Erreur récupération commit schema : {e}")
# return None
#
# path_relative = f"Documents/{dossier_choisi}/{fiche_choisie}"
# commits_url = f"{GITEA_URL}/repos/{ORGANISATION}/{DEPOT_FICHES}/commits?path={path_relative}&sha={ENV}"
#
# local_mtime = datetime.fromtimestamp(os.path.getmtime(path_relative), tz=timezone.utc)
# remote_mtime = recuperer_date_dernier_commit(commit_url)
MAKE = {
"assets": {
"directory": "assets",
"seuils": {
"depends_on": "None",
},
"file": {
"seuils": "config.yaml"
}
},
"fiches": {
"directory": "Fiches",
"criticites": {
"directory": "Criticités",
"préfix": "Fiche technique ",
"depends_on": {
"gitea": {
"compare": "file2commit"
},
"document": {
"seuils": {
"where": "assets.file.seuils",
"compare": "file2file"
}
}
},
"file": {
"ihh": "Fiche technique IHH.md",
"isg": "Fiche technique ISG.md",
"ivc": "Fiche technique IVC.md",
"ics": "Fiche technique ICS.md"
}
},
"assemblage": {
"directory": "Assemblage",
"prefix": "Fiche assemblage ",
"depends_on": {
"gitea": {
"compare": "file2commit"
},
"document": {
"ihh": {
"where": "fiches.criticites.file.ihh",
"compare": "file2file"
},
"isg": {
"where": "fiches.criticites.file.isg",
"compare": "file2file"
}
}
}
},
"fabrication": {
"directory": "Fabrication",
"prefix": "Fiche fabrication ",
"depends_on": {
"gitea": {
"compare": "file2commit"
},
"document": {
"ihh": {
"where": "fiches.criticites.file.ihh",
"compare": "file2file"
},
"isg": {
"where": "fiches.criticites.file.isg",
"compare": "file2file"
}
}
}
},
"connexe": {
"directory": "Connexe",
"prefix": "Fiche assemblage ",
"depends_on": {
"gitea": {
"compare": "file2commit"
},
"document": {
"ihh": {
"where": "fiches.criticites.file.ihh",
"compare": "file2file"
},
"isg": {
"where": "fiches.criticites.file.isg",
"compare": "file2file"
}
}
}
},
"minerai": {
"directory": "Minerai",
"prefix": "Fiche minerai ",
"depends_on": {
"gitea": {
"compare": "file2commit"
},
"document": {
"ihh": {
"where": "fiches.criticites.file.ihh",
"compare": "file2file"
},
"isg": {
"where": "fiches.criticites.file.isg",
"compare": "file2file"
},
"ics": {
"where": "fiches.criticites.file.ics",
"compare": "file2file"
},
"ivc": {
"where": "fiches.criticites.file.ivc",
"compare": "file2file"
}
}
}
}
}
}

View File

@ -1,5 +1,5 @@
version: 1.1
date: 2025-05-06
date: 2025-05-27
seuils:
IVC: # Indice de vulnérabilité concurrentielle

View File

@ -1139,7 +1139,7 @@ def generate_operations_section(data, results, config):
if total_share > 0:
isg_combined = isg_weighted_sum / total_share
color, suffix = determine_threshold_color(isg_combined, "ISG", config.get('thresholds'))
template.append(f"\n**ISG combiné: {isg_combined:.0f} - {color} ({suffix})**\n")
template.append(f"\n**ISG combiné: {isg_combined:.0f} - {color} ({suffix})**\n\n")
# IHH
ihh_file = find_corpus_file("matrice-des-risques-liés-à-la-fabrication/indice-de-herfindahl-hirschmann", f"Fabrication/Fiche fabrication {component_slug}")
@ -1647,7 +1647,7 @@ def generate_report(data, results, config):
return "\n".join(report), file_names
def generate_text(input_file, full_prompt, system_message, temperature = "0.1"):
def generate_text(input_file, full_prompt, system_message, temperature = "0.1", use_context = True):
"""Génère du texte avec l'API PrivateGPT"""
try:
@ -1657,7 +1657,7 @@ def generate_text(input_file, full_prompt, system_message, temperature = "0.1"):
{"role": "system", "content": system_message},
{"role": "user", "content": full_prompt}
],
"use_context": True, # Active la recherche RAG dans les documents ingérés
"use_context": use_context, # Active la recherche RAG dans les documents ingérés
"temperature": temperature, # Température réduite pour plus de cohérence
"stream": False
}
@ -1733,6 +1733,8 @@ def ia_analyse(file_names):
Votre analyse doit être rigoureuse, accessible, pertinente pour la prise de décision stratégique, et conforme à la méthodologie définie ci-dessous :
{PROMPT_METHODOLOGIE}
/no_think
"""
reponse[produit_final] = f"\n**{produit_final}**\n\n" + generate_text(file, full_prompt, system_message).split("</think>")[-1].strip()
@ -1758,6 +1760,8 @@ def ia_analyse(file_names):
- Être fluide, agréable à lire, avec un ton sobre et professionnel.
Répondez uniquement avec l'introduction rédigée. Ne fournissez aucune autre explication complémentaire.
/no_think
"""
@ -1794,6 +1798,8 @@ def ia_analyse(file_names):
- Inviter de manière dynamique le COMEX à passer immédiatement à l'action.
Votre rédaction doit être fluide, professionnelle, claire et immédiatement exploitable par des dirigeants. Ne fournissez aucune explication supplémentaire. Ne répondez que par la conclusion demandée.
/no_think
"""
conclusion = generate_text("", full_prompt, system_message, "0.7").split("</think>")[-1].strip()
@ -1811,36 +1817,47 @@ def ia_analyse(file_names):
"\n\n## Méthodologie\n\n" + \
PROMPT_METHODOLOGIE
fichier_a_reviser = Path(TEMPLATE_PATH.name.replace(".md", " - analyse à relire.md"))
write_report(analyse, TEMP_SECTIONS / fichier_a_reviser)
ingest_document(TEMP_SECTIONS / fichier_a_reviser)
# fichier_a_reviser = Path(TEMPLATE_PATH.name.replace(".md", " - analyse à relire.md"))
# write_report(analyse, TEMP_SECTIONS / fichier_a_reviser)
# ingest_document(TEMP_SECTIONS / fichier_a_reviser)
full_prompt = f"""
Le fichier à réviser est {fichier_a_reviser}. Suivre scrupuleusement les consignes.
full_prompt = """
Suivre scrupuleusement les consignes.
"""
system_message = f"""
Vous êtes un réviseur professionnel expert en écriture stratégique, maîtrisant parfaitement la langue française et habitué à réviser des textes destinés à des dirigeants de haut niveau (COMEX).
Votre unique tâche est d'améliorer la qualité rédactionnelle du texte dans le fichier {fichier_a_reviser}, sans modifier ni sa structure, ni son sens initial, ni ajouter dinformations nouvelles. Cette révision doit :
Votre tâche unique est d'améliorer strictement la qualité rédactionnelle du texte suivant, sans modifier en aucune manière :
- la structure existante (sections, titres, sous-titres),
- l'ordre des paragraphes et des idées,
- le sens précis du contenu original,
- sans ajouter aucune information nouvelle.
Votre révision doit impérativement respecter les points suivants :
- Éliminer toutes répétitions ou redondances et varier systématiquement les tournures entre les paragraphes.
- Rendre chaque phrase claire, directe et concise. Si une phrase est trop longue, scindez-la en plusieurs phrases courtes.
- Scinder les paragraphes en 2 ou 3 parties cohérentes et bien enchaînées avec des termes de coordinations, d'implication, …
- Remplacer systématiquement les acronymes par les expressions suivantes :
- Rendre chaque phrase claire, directe et concise. Si une phrase est trop longue, scindez-la clairement en plusieurs phrases courtes.
- Structurer chaque paragraphe en 2 à 3 parties cohérentes, reliées entre elles par des termes logiques (coordination, implication, opposition, etc.) et séparées par des retours à la ligne.
- Remplacer systématiquement les acronymes par ces expressions précises :
- ICS « capacité à substituer un minerai »
- IHH « concentration géographique ou industrielle »
- ISG « stabilité géopolitique »
- IVC « concurrence intersectorielle pour les minerais »
Votre texte final doit être fluide, agréable à lire, parfaitement adapté à un COMEX, avec un ton professionnel et sobre.
Votre texte final doit être parfaitement fluide, agréable à lire, adapté à un COMEX, avec un ton professionnel et sobre.
Répondez uniquement avec le texte révisé, sans autre commentaire.
**Important : Ne répondez strictement que par le texte révisé ci-dessous, sans aucun commentaire ou explication supplémentaire.**
Voici le texte à réviser précisément :
{analyse}
/no_think
"""
corps = generate_text(fichier_a_reviser, full_prompt, system_message, "0.6").split("</think>")[-1].strip()
revision = generate_text("", full_prompt, system_message, "0.1", False).split("</think>")[-1].strip()
print("Relecture")
return analyse
return revision
def write_report(report, fichier):
"""Écrit le rapport généré dans le fichier spécifié."""

View File

@ -1 +1,7 @@
{}
{
"stephan": {
"status": "en cours",
"timestamp": 1748358233.0347543,
"position": 0
}
}

16
pgpt/.docker/router.yml Normal file
View File

@ -0,0 +1,16 @@
http:
services:
ollama:
loadBalancer:
healthCheck:
interval: 5s
path: /
servers:
- url: http://ollama-cpu:11434
- url: http://ollama-cuda:11434
- url: http://host.docker.internal:11434
routers:
ollama-router:
rule: "PathPrefix(`/`)"
service: ollama

51
pgpt/Dockerfile.ollama Normal file
View File

@ -0,0 +1,51 @@
FROM python:3.11.6-slim-bookworm AS base
# Install poetry
RUN pip install pipx
RUN python3 -m pipx ensurepath
RUN pipx install poetry==1.8.3
ENV PATH="/root/.local/bin:$PATH"
ENV PATH=".venv/bin/:$PATH"
# https://python-poetry.org/docs/configuration/#virtualenvsin-project
ENV POETRY_VIRTUALENVS_IN_PROJECT=true
FROM base AS dependencies
WORKDIR /home/worker/app
COPY pyproject.toml poetry.lock ./
ARG POETRY_EXTRAS="ui vector-stores-qdrant llms-ollama embeddings-ollama"
RUN poetry install --no-root --extras "${POETRY_EXTRAS}"
FROM base AS app
ENV PYTHONUNBUFFERED=1
ENV PORT=8080
ENV APP_ENV=prod
ENV PYTHONPATH="$PYTHONPATH:/home/worker/app/private_gpt/"
EXPOSE 8080
# Prepare a non-root user
# More info about how to configure UIDs and GIDs in Docker:
# https://github.com/systemd/systemd/blob/main/docs/UIDS-GIDS.md
# Define the User ID (UID) for the non-root user
# UID 100 is chosen to avoid conflicts with existing system users
ARG UID=100
# Define the Group ID (GID) for the non-root user
# GID 65534 is often used for the 'nogroup' or 'nobody' group
ARG GID=65534
RUN adduser --system --gid ${GID} --uid ${UID} --home /home/worker worker
WORKDIR /home/worker/app
RUN chown worker /home/worker/app
RUN mkdir local_data && chown worker local_data
RUN mkdir models && chown worker models
COPY --chown=worker --from=dependencies /home/worker/app/.venv/ .venv
COPY --chown=worker private_gpt/ private_gpt
COPY --chown=worker *.yaml .
COPY --chown=worker scripts/ scripts
USER worker
ENTRYPOINT python -m private_gpt

121
pgpt/docker-compose.yaml Normal file
View File

@ -0,0 +1,121 @@
services:
#-----------------------------------
#---- Private-GPT services ---------
#-----------------------------------
# Private-GPT service for the Ollama CPU and GPU modes
# This service builds from an external Dockerfile and runs the Ollama mode.
private-gpt-ollama:
image: ${PGPT_IMAGE:-zylonai/private-gpt}:${PGPT_TAG:-0.6.2}-ollama # x-release-please-version
user: root
build:
context: .
dockerfile: Dockerfile.ollama
volumes:
- /home/fabnum/fabnum-dev/Fiches:/home/worker/app/local_data/Fiches:Z
ports:
- "127.0.0.1:8001:8001"
environment:
PORT: 8001
PGPT_PROFILES: docker
PGPT_MODE: ollama
PGPT_EMBED_MODE: ollama
PGPT_OLLAMA_API_BASE: http://ollama:11434
HF_TOKEN: ${HF_TOKEN:-}
profiles:
- ""
- ollama-cpu
- ollama-cuda
- ollama-api
depends_on:
ollama:
condition: service_healthy
# Private-GPT service for the local mode
# This service builds from a local Dockerfile and runs the application in local mode.
private-gpt-llamacpp-cpu:
image: ${PGPT_IMAGE:-zylonai/private-gpt}:${PGPT_TAG:-0.6.2}-llamacpp-cpu # x-release-please-version
user: root
build:
context: .
dockerfile: Dockerfile.llamacpp-cpu
volumes:
- ./local_data/:/home/worker/app/local_data
- ./models/:/home/worker/app/models
entrypoint: sh -c ".venv/bin/python scripts/setup && .venv/bin/python -m private_gpt"
ports:
- "127.0.0.1:8001:8001"
environment:
PORT: 8001
PGPT_PROFILES: local
HF_TOKEN: ${HF_TOKEN:-}
profiles:
- llamacpp-cpu
#-----------------------------------
#---- Ollama services --------------
#-----------------------------------
# Traefik reverse proxy for the Ollama service
# This will route requests to the Ollama service based on the profile.
ollama:
image: traefik:v2.10
healthcheck:
test:
[
"CMD",
"sh",
"-c",
"wget -q --spider http://ollama:11434 || exit 1",
]
interval: 10s
retries: 3
start_period: 5s
timeout: 5s
ports:
- "127.0.0.1:8080:8080"
command:
- "--providers.file.filename=/etc/router.yml"
- "--log.level=ERROR"
- "--api.insecure=true"
- "--providers.docker=true"
- "--providers.docker.exposedbydefault=false"
- "--entrypoints.web.address=:11434"
volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro
- ./.docker/router.yml:/etc/router.yml:ro
extra_hosts:
- "host.docker.internal:host-gateway"
profiles:
- ""
- ollama-cpu
- ollama-cuda
- ollama-api
# Ollama service for the CPU mode
ollama-cpu:
image: ollama/ollama:latest
ports:
- "127.0.0.1:11434:11434"
volumes:
- ./models:/root/.ollama:Z
profiles:
- ""
- ollama-cpu
# Ollama service for the CUDA mode
ollama-cuda:
image: ollama/ollama:latest
ports:
- "11434:11434"
volumes:
- ./models:/root/.ollama
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: 1
capabilities: [gpu]
profiles:
- ollama-cuda

View File

@ -0,0 +1,27 @@
"""private-gpt."""
import logging
import os
# Set to 'DEBUG' to have extensive logging turned on, even for libraries
ROOT_LOG_LEVEL = "INFO"
PRETTY_LOG_FORMAT = (
"%(asctime)s.%(msecs)03d [%(levelname)-8s] %(name)+25s - %(message)s"
)
logging.basicConfig(level=ROOT_LOG_LEVEL, format=PRETTY_LOG_FORMAT, datefmt="%H:%M:%S")
logging.captureWarnings(True)
# Disable gradio analytics
# This is done this way because gradio does not solely rely on what values are
# passed to gr.Blocks(enable_analytics=...) but also on the environment
# variable GRADIO_ANALYTICS_ENABLED. `gradio.strings` actually reads this env
# directly, so to fully disable gradio analytics we need to set this env var.
os.environ["GRADIO_ANALYTICS_ENABLED"] = "False"
# Disable chromaDB telemetry
# It is already disabled, see PR#1144
# os.environ["ANONYMIZED_TELEMETRY"] = "False"
# adding tiktoken cache path within repo to be able to run in offline environment.
os.environ["TIKTOKEN_CACHE_DIR"] = "tiktoken_cache"

View File

@ -0,0 +1,11 @@
# start a fastapi server with uvicorn
import uvicorn
from private_gpt.main import app
from private_gpt.settings.settings import settings
# Set log_config=None to do not use the uvicorn logging configuration, and
# use ours instead. For reference, see below:
# https://github.com/tiangolo/fastapi/discussions/7457#discussioncomment-5141108
uvicorn.run(app, host="0.0.0.0", port=settings().server.port, log_config=None)

View File

View File

@ -0,0 +1,82 @@
# mypy: ignore-errors
import json
from typing import Any
import boto3
from llama_index.core.base.embeddings.base import BaseEmbedding
from pydantic import Field, PrivateAttr
class SagemakerEmbedding(BaseEmbedding):
"""Sagemaker Embedding Endpoint.
To use, you must supply the endpoint name from your deployed
Sagemaker embedding model & the region where it is deployed.
To authenticate, the AWS client uses the following methods to
automatically load credentials:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
If a specific credential profile should be used, you must pass
the name of the profile from the ~/.aws/credentials file that is to be used.
Make sure the credentials / roles used have the required policies to
access the Sagemaker endpoint.
See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
"""
endpoint_name: str = Field(description="")
_boto_client: Any = boto3.client(
"sagemaker-runtime",
) # TODO make it an optional field
_async_not_implemented_warned: bool = PrivateAttr(default=False)
@classmethod
def class_name(cls) -> str:
return "SagemakerEmbedding"
def _async_not_implemented_warn_once(self) -> None:
if not self._async_not_implemented_warned:
print("Async embedding not available, falling back to sync method.")
self._async_not_implemented_warned = True
def _embed(self, sentences: list[str]) -> list[list[float]]:
request_params = {
"inputs": sentences,
}
resp = self._boto_client.invoke_endpoint(
EndpointName=self.endpoint_name,
Body=json.dumps(request_params),
ContentType="application/json",
)
response_body = resp["Body"]
response_str = response_body.read().decode("utf-8")
response_json = json.loads(response_str)
return response_json["vectors"]
def _get_query_embedding(self, query: str) -> list[float]:
"""Get query embedding."""
return self._embed([query])[0]
async def _aget_query_embedding(self, query: str) -> list[float]:
# Warn the user that sync is being used
self._async_not_implemented_warn_once()
return self._get_query_embedding(query)
async def _aget_text_embedding(self, text: str) -> list[float]:
# Warn the user that sync is being used
self._async_not_implemented_warn_once()
return self._get_text_embedding(text)
def _get_text_embedding(self, text: str) -> list[float]:
"""Get text embedding."""
return self._embed([text])[0]
def _get_text_embeddings(self, texts: list[str]) -> list[list[float]]:
"""Get text embeddings."""
return self._embed(texts)

View File

@ -0,0 +1,167 @@
import logging
from injector import inject, singleton
from llama_index.core.embeddings import BaseEmbedding, MockEmbedding
from private_gpt.paths import models_cache_path
from private_gpt.settings.settings import Settings
logger = logging.getLogger(__name__)
@singleton
class EmbeddingComponent:
embedding_model: BaseEmbedding
@inject
def __init__(self, settings: Settings) -> None:
embedding_mode = settings.embedding.mode
logger.info("Initializing the embedding model in mode=%s", embedding_mode)
match embedding_mode:
case "huggingface":
try:
from llama_index.embeddings.huggingface import ( # type: ignore
HuggingFaceEmbedding,
)
except ImportError as e:
raise ImportError(
"Local dependencies not found, install with `poetry install --extras embeddings-huggingface`"
) from e
self.embedding_model = HuggingFaceEmbedding(
model_name=settings.huggingface.embedding_hf_model_name,
cache_folder=str(models_cache_path),
trust_remote_code=settings.huggingface.trust_remote_code,
)
case "sagemaker":
try:
from private_gpt.components.embedding.custom.sagemaker import (
SagemakerEmbedding,
)
except ImportError as e:
raise ImportError(
"Sagemaker dependencies not found, install with `poetry install --extras embeddings-sagemaker`"
) from e
self.embedding_model = SagemakerEmbedding(
endpoint_name=settings.sagemaker.embedding_endpoint_name,
)
case "openai":
try:
from llama_index.embeddings.openai import ( # type: ignore
OpenAIEmbedding,
)
except ImportError as e:
raise ImportError(
"OpenAI dependencies not found, install with `poetry install --extras embeddings-openai`"
) from e
api_base = (
settings.openai.embedding_api_base or settings.openai.api_base
)
api_key = settings.openai.embedding_api_key or settings.openai.api_key
model = settings.openai.embedding_model
self.embedding_model = OpenAIEmbedding(
api_base=api_base,
api_key=api_key,
model=model,
)
case "ollama":
try:
from llama_index.embeddings.ollama import ( # type: ignore
OllamaEmbedding,
)
from ollama import Client # type: ignore
except ImportError as e:
raise ImportError(
"Local dependencies not found, install with `poetry install --extras embeddings-ollama`"
) from e
ollama_settings = settings.ollama
# Calculate embedding model. If not provided tag, it will be use latest
model_name = (
ollama_settings.embedding_model + ":latest"
if ":" not in ollama_settings.embedding_model
else ollama_settings.embedding_model
)
self.embedding_model = OllamaEmbedding(
model_name=model_name,
base_url=ollama_settings.embedding_api_base,
)
if ollama_settings.autopull_models:
if ollama_settings.autopull_models:
from private_gpt.utils.ollama import (
check_connection,
pull_model,
)
# TODO: Reuse llama-index client when llama-index is updated
client = Client(
host=ollama_settings.embedding_api_base,
timeout=ollama_settings.request_timeout,
)
if not check_connection(client):
raise ValueError(
f"Failed to connect to Ollama, "
f"check if Ollama server is running on {ollama_settings.api_base}"
)
pull_model(client, model_name)
case "azopenai":
try:
from llama_index.embeddings.azure_openai import ( # type: ignore
AzureOpenAIEmbedding,
)
except ImportError as e:
raise ImportError(
"Azure OpenAI dependencies not found, install with `poetry install --extras embeddings-azopenai`"
) from e
azopenai_settings = settings.azopenai
self.embedding_model = AzureOpenAIEmbedding(
model=azopenai_settings.embedding_model,
deployment_name=azopenai_settings.embedding_deployment_name,
api_key=azopenai_settings.api_key,
azure_endpoint=azopenai_settings.azure_endpoint,
api_version=azopenai_settings.api_version,
)
case "gemini":
try:
from llama_index.embeddings.gemini import ( # type: ignore
GeminiEmbedding,
)
except ImportError as e:
raise ImportError(
"Gemini dependencies not found, install with `poetry install --extras embeddings-gemini`"
) from e
self.embedding_model = GeminiEmbedding(
api_key=settings.gemini.api_key,
model_name=settings.gemini.embedding_model,
)
case "mistralai":
try:
from llama_index.embeddings.mistralai import ( # type: ignore
MistralAIEmbedding,
)
except ImportError as e:
raise ImportError(
"Mistral dependencies not found, install with `poetry install --extras embeddings-mistral`"
) from e
api_key = settings.openai.api_key
model = settings.openai.embedding_model
self.embedding_model = MistralAIEmbedding(
api_key=api_key,
model=model,
)
case "mock":
# Not a random number, is the dimensionality used by
# the default embedding model
self.embedding_model = MockEmbedding(384)

View File

@ -0,0 +1,517 @@
import abc
import itertools
import logging
import multiprocessing
import multiprocessing.pool
import os
import threading
from pathlib import Path
from queue import Queue
from typing import Any
from llama_index.core.data_structs import IndexDict
from llama_index.core.embeddings.utils import EmbedType
from llama_index.core.indices import VectorStoreIndex, load_index_from_storage
from llama_index.core.indices.base import BaseIndex
from llama_index.core.ingestion import run_transformations
from llama_index.core.schema import BaseNode, Document, TransformComponent
from llama_index.core.storage import StorageContext
from private_gpt.components.ingest.ingest_helper import IngestionHelper
from private_gpt.paths import local_data_path
from private_gpt.settings.settings import Settings
from private_gpt.utils.eta import eta
logger = logging.getLogger(__name__)
class BaseIngestComponent(abc.ABC):
def __init__(
self,
storage_context: StorageContext,
embed_model: EmbedType,
transformations: list[TransformComponent],
*args: Any,
**kwargs: Any,
) -> None:
logger.debug("Initializing base ingest component type=%s", type(self).__name__)
self.storage_context = storage_context
self.embed_model = embed_model
self.transformations = transformations
@abc.abstractmethod
def ingest(self, file_name: str, file_data: Path) -> list[Document]:
pass
@abc.abstractmethod
def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
pass
@abc.abstractmethod
def delete(self, doc_id: str) -> None:
pass
class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
def __init__(
self,
storage_context: StorageContext,
embed_model: EmbedType,
transformations: list[TransformComponent],
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
self.show_progress = True
self._index_thread_lock = (
threading.Lock()
) # Thread lock! Not Multiprocessing lock
self._index = self._initialize_index()
def _initialize_index(self) -> BaseIndex[IndexDict]:
"""Initialize the index from the storage context."""
try:
# Load the index with store_nodes_override=True to be able to delete them
index = load_index_from_storage(
storage_context=self.storage_context,
store_nodes_override=True, # Force store nodes in index and document stores
show_progress=self.show_progress,
embed_model=self.embed_model,
transformations=self.transformations,
)
except ValueError:
# There are no index in the storage context, creating a new one
logger.info("Creating a new vector store index")
index = VectorStoreIndex.from_documents(
[],
storage_context=self.storage_context,
store_nodes_override=True, # Force store nodes in index and document stores
show_progress=self.show_progress,
embed_model=self.embed_model,
transformations=self.transformations,
)
index.storage_context.persist(persist_dir=local_data_path)
return index
def _save_index(self) -> None:
self._index.storage_context.persist(persist_dir=local_data_path)
def delete(self, doc_id: str) -> None:
with self._index_thread_lock:
# Delete the document from the index
self._index.delete_ref_doc(doc_id, delete_from_docstore=True)
# Save the index
self._save_index()
class SimpleIngestComponent(BaseIngestComponentWithIndex):
def __init__(
self,
storage_context: StorageContext,
embed_model: EmbedType,
transformations: list[TransformComponent],
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
def ingest(self, file_name: str, file_data: Path) -> list[Document]:
logger.info("Ingesting file_name=%s", file_name)
documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
logger.info(
"Transformed file=%s into count=%s documents", file_name, len(documents)
)
logger.debug("Saving the documents in the index and doc store")
return self._save_docs(documents)
def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
saved_documents = []
for file_name, file_data in files:
documents = IngestionHelper.transform_file_into_documents(
file_name, file_data
)
saved_documents.extend(self._save_docs(documents))
return saved_documents
def _save_docs(self, documents: list[Document]) -> list[Document]:
logger.debug("Transforming count=%s documents into nodes", len(documents))
with self._index_thread_lock:
for document in documents:
self._index.insert(document, show_progress=True)
logger.debug("Persisting the index and nodes")
# persist the index and nodes
self._save_index()
logger.debug("Persisted the index and nodes")
return documents
class BatchIngestComponent(BaseIngestComponentWithIndex):
"""Parallelize the file reading and parsing on multiple CPU core.
This also makes the embeddings to be computed in batches (on GPU or CPU).
"""
def __init__(
self,
storage_context: StorageContext,
embed_model: EmbedType,
transformations: list[TransformComponent],
count_workers: int,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
# Make an efficient use of the CPU and GPU, the embedding
# must be in the transformations
assert (
len(self.transformations) >= 2
), "Embeddings must be in the transformations"
assert count_workers > 0, "count_workers must be > 0"
self.count_workers = count_workers
self._file_to_documents_work_pool = multiprocessing.Pool(
processes=self.count_workers
)
def ingest(self, file_name: str, file_data: Path) -> list[Document]:
logger.info("Ingesting file_name=%s", file_name)
documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
logger.info(
"Transformed file=%s into count=%s documents", file_name, len(documents)
)
logger.debug("Saving the documents in the index and doc store")
return self._save_docs(documents)
def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
documents = list(
itertools.chain.from_iterable(
self._file_to_documents_work_pool.starmap(
IngestionHelper.transform_file_into_documents, files
)
)
)
logger.info(
"Transformed count=%s files into count=%s documents",
len(files),
len(documents),
)
return self._save_docs(documents)
def _save_docs(self, documents: list[Document]) -> list[Document]:
logger.debug("Transforming count=%s documents into nodes", len(documents))
nodes = run_transformations(
documents, # type: ignore[arg-type]
self.transformations,
show_progress=self.show_progress,
)
# Locking the index to avoid concurrent writes
with self._index_thread_lock:
logger.info("Inserting count=%s nodes in the index", len(nodes))
self._index.insert_nodes(nodes, show_progress=True)
for document in documents:
self._index.docstore.set_document_hash(
document.get_doc_id(), document.hash
)
logger.debug("Persisting the index and nodes")
# persist the index and nodes
self._save_index()
logger.debug("Persisted the index and nodes")
return documents
class ParallelizedIngestComponent(BaseIngestComponentWithIndex):
"""Parallelize the file ingestion (file reading, embeddings, and index insertion).
This use the CPU and GPU in parallel (both running at the same time), and
reduce the memory pressure by not loading all the files in memory at the same time.
"""
def __init__(
self,
storage_context: StorageContext,
embed_model: EmbedType,
transformations: list[TransformComponent],
count_workers: int,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
# To make an efficient use of the CPU and GPU, the embeddings
# must be in the transformations (to be computed in batches)
assert (
len(self.transformations) >= 2
), "Embeddings must be in the transformations"
assert count_workers > 0, "count_workers must be > 0"
self.count_workers = count_workers
# We are doing our own multiprocessing
# To do not collide with the multiprocessing of huggingface, we disable it
os.environ["TOKENIZERS_PARALLELISM"] = "false"
self._ingest_work_pool = multiprocessing.pool.ThreadPool(
processes=self.count_workers
)
self._file_to_documents_work_pool = multiprocessing.Pool(
processes=self.count_workers
)
def ingest(self, file_name: str, file_data: Path) -> list[Document]:
logger.info("Ingesting file_name=%s", file_name)
# Running in a single (1) process to release the current
# thread, and take a dedicated CPU core for computation
documents = self._file_to_documents_work_pool.apply(
IngestionHelper.transform_file_into_documents, (file_name, file_data)
)
logger.info(
"Transformed file=%s into count=%s documents", file_name, len(documents)
)
logger.debug("Saving the documents in the index and doc store")
return self._save_docs(documents)
def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
# Lightweight threads, used for parallelize the
# underlying IO calls made in the ingestion
documents = list(
itertools.chain.from_iterable(
self._ingest_work_pool.starmap(self.ingest, files)
)
)
return documents
def _save_docs(self, documents: list[Document]) -> list[Document]:
logger.debug("Transforming count=%s documents into nodes", len(documents))
nodes = run_transformations(
documents, # type: ignore[arg-type]
self.transformations,
show_progress=self.show_progress,
)
# Locking the index to avoid concurrent writes
with self._index_thread_lock:
logger.info("Inserting count=%s nodes in the index", len(nodes))
self._index.insert_nodes(nodes, show_progress=True)
for document in documents:
self._index.docstore.set_document_hash(
document.get_doc_id(), document.hash
)
logger.debug("Persisting the index and nodes")
# persist the index and nodes
self._save_index()
logger.debug("Persisted the index and nodes")
return documents
def __del__(self) -> None:
# We need to do the appropriate cleanup of the multiprocessing pools
# when the object is deleted. Using root logger to avoid
# the logger to be deleted before the pool
logging.debug("Closing the ingest work pool")
self._ingest_work_pool.close()
self._ingest_work_pool.join()
self._ingest_work_pool.terminate()
logging.debug("Closing the file to documents work pool")
self._file_to_documents_work_pool.close()
self._file_to_documents_work_pool.join()
self._file_to_documents_work_pool.terminate()
class PipelineIngestComponent(BaseIngestComponentWithIndex):
"""Pipeline ingestion - keeping the embedding worker pool as busy as possible.
This class implements a threaded ingestion pipeline, which comprises two threads
and two queues. The primary thread is responsible for reading and parsing files
into documents. These documents are then placed into a queue, which is
distributed to a pool of worker processes for embedding computation. After
embedding, the documents are transferred to another queue where they are
accumulated until a threshold is reached. Upon reaching this threshold, the
accumulated documents are flushed to the document store, index, and vector
store.
Exception handling ensures robustness against erroneous files. However, in the
pipelined design, one error can lead to the discarding of multiple files. Any
discarded files will be reported.
"""
NODE_FLUSH_COUNT = 5000 # Save the index every # nodes.
def __init__(
self,
storage_context: StorageContext,
embed_model: EmbedType,
transformations: list[TransformComponent],
count_workers: int,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(storage_context, embed_model, transformations, *args, **kwargs)
self.count_workers = count_workers
assert (
len(self.transformations) >= 2
), "Embeddings must be in the transformations"
assert count_workers > 0, "count_workers must be > 0"
self.count_workers = count_workers
# We are doing our own multiprocessing
# To do not collide with the multiprocessing of huggingface, we disable it
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# doc_q stores parsed files as Document chunks.
# Using a shallow queue causes the filesystem parser to block
# when it reaches capacity. This ensures it doesn't outpace the
# computationally intensive embeddings phase, avoiding unnecessary
# memory consumption. The semaphore is used to bound the async worker
# embedding computations to cause the doc Q to fill and block.
self.doc_semaphore = multiprocessing.Semaphore(
self.count_workers
) # limit the doc queue to # items.
self.doc_q: Queue[tuple[str, str | None, list[Document] | None]] = Queue(20)
# node_q stores documents parsed into nodes (embeddings).
# Larger queue size so we don't block the embedding workers during a slow
# index update.
self.node_q: Queue[
tuple[str, str | None, list[Document] | None, list[BaseNode] | None]
] = Queue(40)
threading.Thread(target=self._doc_to_node, daemon=True).start()
threading.Thread(target=self._write_nodes, daemon=True).start()
def _doc_to_node(self) -> None:
# Parse documents into nodes
with multiprocessing.pool.ThreadPool(processes=self.count_workers) as pool:
while True:
try:
cmd, file_name, documents = self.doc_q.get(
block=True
) # Documents for a file
if cmd == "process":
# Push CPU/GPU embedding work to the worker pool
# Acquire semaphore to control access to worker pool
self.doc_semaphore.acquire()
pool.apply_async(
self._doc_to_node_worker, (file_name, documents)
)
elif cmd == "quit":
break
finally:
if cmd != "process":
self.doc_q.task_done() # unblock Q joins
def _doc_to_node_worker(self, file_name: str, documents: list[Document]) -> None:
# CPU/GPU intensive work in its own process
try:
nodes = run_transformations(
documents, # type: ignore[arg-type]
self.transformations,
show_progress=self.show_progress,
)
self.node_q.put(("process", file_name, documents, list(nodes)))
finally:
self.doc_semaphore.release()
self.doc_q.task_done() # unblock Q joins
def _save_docs(
self, files: list[str], documents: list[Document], nodes: list[BaseNode]
) -> None:
try:
logger.info(
f"Saving {len(files)} files ({len(documents)} documents / {len(nodes)} nodes)"
)
self._index.insert_nodes(nodes)
for document in documents:
self._index.docstore.set_document_hash(
document.get_doc_id(), document.hash
)
self._save_index()
except Exception:
# Tell the user so they can investigate these files
logger.exception(f"Processing files {files}")
finally:
# Clearing work, even on exception, maintains a clean state.
nodes.clear()
documents.clear()
files.clear()
def _write_nodes(self) -> None:
# Save nodes to index. I/O intensive.
node_stack: list[BaseNode] = []
doc_stack: list[Document] = []
file_stack: list[str] = []
while True:
try:
cmd, file_name, documents, nodes = self.node_q.get(block=True)
if cmd in ("flush", "quit"):
if file_stack:
self._save_docs(file_stack, doc_stack, node_stack)
if cmd == "quit":
break
elif cmd == "process":
node_stack.extend(nodes) # type: ignore[arg-type]
doc_stack.extend(documents) # type: ignore[arg-type]
file_stack.append(file_name) # type: ignore[arg-type]
# Constant saving is heavy on I/O - accumulate to a threshold
if len(node_stack) >= self.NODE_FLUSH_COUNT:
self._save_docs(file_stack, doc_stack, node_stack)
finally:
self.node_q.task_done()
def _flush(self) -> None:
self.doc_q.put(("flush", None, None))
self.doc_q.join()
self.node_q.put(("flush", None, None, None))
self.node_q.join()
def ingest(self, file_name: str, file_data: Path) -> list[Document]:
documents = IngestionHelper.transform_file_into_documents(file_name, file_data)
self.doc_q.put(("process", file_name, documents))
self._flush()
return documents
def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[Document]:
docs = []
for file_name, file_data in eta(files):
try:
documents = IngestionHelper.transform_file_into_documents(
file_name, file_data
)
self.doc_q.put(("process", file_name, documents))
docs.extend(documents)
except Exception:
logger.exception(f"Skipping {file_data.name}")
self._flush()
return docs
def get_ingestion_component(
storage_context: StorageContext,
embed_model: EmbedType,
transformations: list[TransformComponent],
settings: Settings,
) -> BaseIngestComponent:
"""Get the ingestion component for the given configuration."""
ingest_mode = settings.embedding.ingest_mode
if ingest_mode == "batch":
return BatchIngestComponent(
storage_context=storage_context,
embed_model=embed_model,
transformations=transformations,
count_workers=settings.embedding.count_workers,
)
elif ingest_mode == "parallel":
return ParallelizedIngestComponent(
storage_context=storage_context,
embed_model=embed_model,
transformations=transformations,
count_workers=settings.embedding.count_workers,
)
elif ingest_mode == "pipeline":
return PipelineIngestComponent(
storage_context=storage_context,
embed_model=embed_model,
transformations=transformations,
count_workers=settings.embedding.count_workers,
)
else:
return SimpleIngestComponent(
storage_context=storage_context,
embed_model=embed_model,
transformations=transformations,
)

View File

@ -0,0 +1,111 @@
import logging
from pathlib import Path
from llama_index.core.readers import StringIterableReader
from llama_index.core.readers.base import BaseReader
from llama_index.core.readers.json import JSONReader
from llama_index.core.schema import Document
logger = logging.getLogger(__name__)
# Inspired by the `llama_index.core.readers.file.base` module
def _try_loading_included_file_formats() -> dict[str, type[BaseReader]]:
try:
from llama_index.readers.file.docs import ( # type: ignore
DocxReader,
HWPReader,
PDFReader,
)
from llama_index.readers.file.epub import EpubReader # type: ignore
from llama_index.readers.file.image import ImageReader # type: ignore
from llama_index.readers.file.ipynb import IPYNBReader # type: ignore
from llama_index.readers.file.markdown import MarkdownReader # type: ignore
from llama_index.readers.file.mbox import MboxReader # type: ignore
from llama_index.readers.file.slides import PptxReader # type: ignore
from llama_index.readers.file.tabular import PandasCSVReader # type: ignore
from llama_index.readers.file.video_audio import ( # type: ignore
VideoAudioReader,
)
except ImportError as e:
raise ImportError("`llama-index-readers-file` package not found") from e
default_file_reader_cls: dict[str, type[BaseReader]] = {
".hwp": HWPReader,
".pdf": PDFReader,
".docx": DocxReader,
".pptx": PptxReader,
".ppt": PptxReader,
".pptm": PptxReader,
".jpg": ImageReader,
".png": ImageReader,
".jpeg": ImageReader,
".mp3": VideoAudioReader,
".mp4": VideoAudioReader,
".csv": PandasCSVReader,
".epub": EpubReader,
".md": MarkdownReader,
".mbox": MboxReader,
".ipynb": IPYNBReader,
}
return default_file_reader_cls
# Patching the default file reader to support other file types
FILE_READER_CLS = _try_loading_included_file_formats()
FILE_READER_CLS.update(
{
".json": JSONReader,
}
)
class IngestionHelper:
"""Helper class to transform a file into a list of documents.
This class should be used to transform a file into a list of documents.
These methods are thread-safe (and multiprocessing-safe).
"""
@staticmethod
def transform_file_into_documents(
file_name: str, file_data: Path
) -> list[Document]:
documents = IngestionHelper._load_file_to_documents(file_name, file_data)
for document in documents:
document.metadata["file_name"] = file_name
IngestionHelper._exclude_metadata(documents)
return documents
@staticmethod
def _load_file_to_documents(file_name: str, file_data: Path) -> list[Document]:
logger.debug("Transforming file_name=%s into documents", file_name)
extension = Path(file_name).suffix
reader_cls = FILE_READER_CLS.get(extension)
if reader_cls is None:
logger.debug(
"No reader found for extension=%s, using default string reader",
extension,
)
# Read as a plain text
string_reader = StringIterableReader()
return string_reader.load_data([file_data.read_text()])
logger.debug("Specific reader found for extension=%s", extension)
documents = reader_cls().load_data(file_data)
# Sanitize NUL bytes in text which can't be stored in Postgres
for i in range(len(documents)):
documents[i].text = documents[i].text.replace("\u0000", "")
return documents
@staticmethod
def _exclude_metadata(documents: list[Document]) -> None:
logger.debug("Excluding metadata from count=%s documents", len(documents))
for document in documents:
document.metadata["doc_id"] = document.doc_id
# We don't want the Embeddings search to receive this metadata
document.excluded_embed_metadata_keys = ["doc_id"]
# We don't want the LLM to receive these metadata in the context
document.excluded_llm_metadata_keys = ["file_name", "doc_id", "page_label"]

View File

@ -0,0 +1 @@
"""LLM implementations."""

View File

@ -0,0 +1,276 @@
# mypy: ignore-errors
from __future__ import annotations
import io
import json
import logging
from typing import TYPE_CHECKING, Any
import boto3 # type: ignore
from llama_index.core.base.llms.generic_utils import (
completion_response_to_chat_response,
stream_completion_response_to_chat_response,
)
from llama_index.core.bridge.pydantic import Field
from llama_index.core.llms import (
CompletionResponse,
CustomLLM,
LLMMetadata,
)
from llama_index.core.llms.callbacks import (
llm_chat_callback,
llm_completion_callback,
)
if TYPE_CHECKING:
from collections.abc import Sequence
from llama_index.callbacks import CallbackManager
from llama_index.llms import (
ChatMessage,
ChatResponse,
ChatResponseGen,
CompletionResponseGen,
)
logger = logging.getLogger(__name__)
class LineIterator:
r"""A helper class for parsing the byte stream input from TGI container.
The output of the model will be in the following format:
```
b'data:{"token": {"text": " a"}}\n\n'
b'data:{"token": {"text": " challenging"}}\n\n'
b'data:{"token": {"text": " problem"
b'}}'
...
```
While usually each PayloadPart event from the event stream will contain a byte array
with a full json, this is not guaranteed and some of the json objects may be split
across PayloadPart events. For example:
```
{'PayloadPart': {'Bytes': b'{"outputs": '}}
{'PayloadPart': {'Bytes': b'[" problem"]}\n'}}
```
This class accounts for this by concatenating bytes written via the 'write' function
and then exposing a method which will return lines (ending with a '\n' character)
within the buffer via the 'scan_lines' function. It maintains the position of the
last read position to ensure that previous bytes are not exposed again. It will
also save any pending lines that doe not end with a '\n' to make sure truncations
are concatinated
"""
def __init__(self, stream: Any) -> None:
"""Line iterator initializer."""
self.byte_iterator = iter(stream)
self.buffer = io.BytesIO()
self.read_pos = 0
def __iter__(self) -> Any:
"""Self iterator."""
return self
def __next__(self) -> Any:
"""Next element from iterator."""
while True:
self.buffer.seek(self.read_pos)
line = self.buffer.readline()
if line and line[-1] == ord("\n"):
self.read_pos += len(line)
return line[:-1]
try:
chunk = next(self.byte_iterator)
except StopIteration:
if self.read_pos < self.buffer.getbuffer().nbytes:
continue
raise
if "PayloadPart" not in chunk:
logger.warning("Unknown event type=%s", chunk)
continue
self.buffer.seek(0, io.SEEK_END)
self.buffer.write(chunk["PayloadPart"]["Bytes"])
class SagemakerLLM(CustomLLM):
"""Sagemaker Inference Endpoint models.
To use, you must supply the endpoint name from your deployed
Sagemaker model & the region where it is deployed.
To authenticate, the AWS client uses the following methods to
automatically load credentials:
https://boto3.amazonaws.com/v1/documentation/api/latest/guide/credentials.html
If a specific credential profile should be used, you must pass
the name of the profile from the ~/.aws/credentials file that is to be used.
Make sure the credentials / roles used have the required policies to
access the Sagemaker endpoint.
See: https://docs.aws.amazon.com/IAM/latest/UserGuide/access_policies.html
"""
endpoint_name: str = Field(description="")
temperature: float = Field(description="The temperature to use for sampling.")
max_new_tokens: int = Field(description="The maximum number of tokens to generate.")
context_window: int = Field(
description="The maximum number of context tokens for the model."
)
messages_to_prompt: Any = Field(
description="The function to convert messages to a prompt.", exclude=True
)
completion_to_prompt: Any = Field(
description="The function to convert a completion to a prompt.", exclude=True
)
generate_kwargs: dict[str, Any] = Field(
default_factory=dict, description="Kwargs used for generation."
)
model_kwargs: dict[str, Any] = Field(
default_factory=dict, description="Kwargs used for model initialization."
)
verbose: bool = Field(description="Whether to print verbose output.")
_boto_client: Any = boto3.client(
"sagemaker-runtime",
) # TODO make it an optional field
def __init__(
self,
endpoint_name: str | None = "",
temperature: float = 0.1,
max_new_tokens: int = 512, # to review defaults
context_window: int = 2048, # to review defaults
messages_to_prompt: Any = None,
completion_to_prompt: Any = None,
callback_manager: CallbackManager | None = None,
generate_kwargs: dict[str, Any] | None = None,
model_kwargs: dict[str, Any] | None = None,
verbose: bool = True,
) -> None:
"""SagemakerLLM initializer."""
model_kwargs = model_kwargs or {}
model_kwargs.update({"n_ctx": context_window, "verbose": verbose})
messages_to_prompt = messages_to_prompt or {}
completion_to_prompt = completion_to_prompt or {}
generate_kwargs = generate_kwargs or {}
generate_kwargs.update(
{"temperature": temperature, "max_tokens": max_new_tokens}
)
super().__init__(
endpoint_name=endpoint_name,
temperature=temperature,
context_window=context_window,
max_new_tokens=max_new_tokens,
messages_to_prompt=messages_to_prompt,
completion_to_prompt=completion_to_prompt,
callback_manager=callback_manager,
generate_kwargs=generate_kwargs,
model_kwargs=model_kwargs,
verbose=verbose,
)
@property
def inference_params(self):
# TODO expose the rest of params
return {
"do_sample": True,
"top_p": 0.7,
"temperature": self.temperature,
"top_k": 50,
"max_new_tokens": self.max_new_tokens,
}
@property
def metadata(self) -> LLMMetadata:
"""Get LLM metadata."""
return LLMMetadata(
context_window=self.context_window,
num_output=self.max_new_tokens,
model_name="Sagemaker LLama 2",
)
@llm_completion_callback()
def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse:
self.generate_kwargs.update({"stream": False})
is_formatted = kwargs.pop("formatted", False)
if not is_formatted:
prompt = self.completion_to_prompt(prompt)
request_params = {
"inputs": prompt,
"stream": False,
"parameters": self.inference_params,
}
resp = self._boto_client.invoke_endpoint(
EndpointName=self.endpoint_name,
Body=json.dumps(request_params),
ContentType="application/json",
)
response_body = resp["Body"]
response_str = response_body.read().decode("utf-8")
response_dict = json.loads(response_str)
return CompletionResponse(
text=response_dict[0]["generated_text"][len(prompt) :], raw=resp
)
@llm_completion_callback()
def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen:
def get_stream():
text = ""
request_params = {
"inputs": prompt,
"stream": True,
"parameters": self.inference_params,
}
resp = self._boto_client.invoke_endpoint_with_response_stream(
EndpointName=self.endpoint_name,
Body=json.dumps(request_params),
ContentType="application/json",
)
event_stream = resp["Body"]
start_json = b"{"
stop_token = "<|endoftext|>"
first_token = True
for line in LineIterator(event_stream):
if line != b"" and start_json in line:
data = json.loads(line[line.find(start_json) :].decode("utf-8"))
special = data["token"]["special"]
stop = data["token"]["text"] == stop_token
if not special and not stop:
delta = data["token"]["text"]
# trim the leading space for the first token if present
if first_token:
delta = delta.lstrip()
first_token = False
text += delta
yield CompletionResponse(delta=delta, text=text, raw=data)
return get_stream()
@llm_chat_callback()
def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse:
prompt = self.messages_to_prompt(messages)
completion_response = self.complete(prompt, formatted=True, **kwargs)
return completion_response_to_chat_response(completion_response)
@llm_chat_callback()
def stream_chat(
self, messages: Sequence[ChatMessage], **kwargs: Any
) -> ChatResponseGen:
prompt = self.messages_to_prompt(messages)
completion_response = self.stream_complete(prompt, formatted=True, **kwargs)
return stream_completion_response_to_chat_response(completion_response)

View File

@ -0,0 +1,225 @@
import logging
from collections.abc import Callable
from typing import Any
from injector import inject, singleton
from llama_index.core.llms import LLM, MockLLM
from llama_index.core.settings import Settings as LlamaIndexSettings
from llama_index.core.utils import set_global_tokenizer
from transformers import AutoTokenizer # type: ignore
from private_gpt.components.llm.prompt_helper import get_prompt_style
from private_gpt.paths import models_cache_path, models_path
from private_gpt.settings.settings import Settings
logger = logging.getLogger(__name__)
@singleton
class LLMComponent:
llm: LLM
@inject
def __init__(self, settings: Settings) -> None:
llm_mode = settings.llm.mode
if settings.llm.tokenizer and settings.llm.mode != "mock":
# Try to download the tokenizer. If it fails, the LLM will still work
# using the default one, which is less accurate.
try:
set_global_tokenizer(
AutoTokenizer.from_pretrained(
pretrained_model_name_or_path=settings.llm.tokenizer,
cache_dir=str(models_cache_path),
token=settings.huggingface.access_token,
)
)
except Exception as e:
logger.warning(
f"Failed to download tokenizer {settings.llm.tokenizer}: {e!s}"
f"Please follow the instructions in the documentation to download it if needed: "
f"https://docs.privategpt.dev/installation/getting-started/troubleshooting#tokenizer-setup."
f"Falling back to default tokenizer."
)
logger.info("Initializing the LLM in mode=%s", llm_mode)
match settings.llm.mode:
case "llamacpp":
try:
from llama_index.llms.llama_cpp import LlamaCPP # type: ignore
except ImportError as e:
raise ImportError(
"Local dependencies not found, install with `poetry install --extras llms-llama-cpp`"
) from e
prompt_style = get_prompt_style(settings.llm.prompt_style)
settings_kwargs = {
"tfs_z": settings.llamacpp.tfs_z, # ollama and llama-cpp
"top_k": settings.llamacpp.top_k, # ollama and llama-cpp
"top_p": settings.llamacpp.top_p, # ollama and llama-cpp
"repeat_penalty": settings.llamacpp.repeat_penalty, # ollama llama-cpp
"n_gpu_layers": -1,
"offload_kqv": True,
}
self.llm = LlamaCPP(
model_path=str(models_path / settings.llamacpp.llm_hf_model_file),
temperature=settings.llm.temperature,
max_new_tokens=settings.llm.max_new_tokens,
context_window=settings.llm.context_window,
generate_kwargs={},
callback_manager=LlamaIndexSettings.callback_manager,
# All to GPU
model_kwargs=settings_kwargs,
# transform inputs into Llama2 format
messages_to_prompt=prompt_style.messages_to_prompt,
completion_to_prompt=prompt_style.completion_to_prompt,
verbose=True,
)
case "sagemaker":
try:
from private_gpt.components.llm.custom.sagemaker import SagemakerLLM
except ImportError as e:
raise ImportError(
"Sagemaker dependencies not found, install with `poetry install --extras llms-sagemaker`"
) from e
self.llm = SagemakerLLM(
endpoint_name=settings.sagemaker.llm_endpoint_name,
max_new_tokens=settings.llm.max_new_tokens,
context_window=settings.llm.context_window,
)
case "openai":
try:
from llama_index.llms.openai import OpenAI # type: ignore
except ImportError as e:
raise ImportError(
"OpenAI dependencies not found, install with `poetry install --extras llms-openai`"
) from e
openai_settings = settings.openai
self.llm = OpenAI(
api_base=openai_settings.api_base,
api_key=openai_settings.api_key,
model=openai_settings.model,
)
case "openailike":
try:
from llama_index.llms.openai_like import OpenAILike # type: ignore
except ImportError as e:
raise ImportError(
"OpenAILike dependencies not found, install with `poetry install --extras llms-openai-like`"
) from e
prompt_style = get_prompt_style(settings.llm.prompt_style)
openai_settings = settings.openai
self.llm = OpenAILike(
api_base=openai_settings.api_base,
api_key=openai_settings.api_key,
model=openai_settings.model,
is_chat_model=True,
max_tokens=settings.llm.max_new_tokens,
api_version="",
temperature=settings.llm.temperature,
context_window=settings.llm.context_window,
messages_to_prompt=prompt_style.messages_to_prompt,
completion_to_prompt=prompt_style.completion_to_prompt,
tokenizer=settings.llm.tokenizer,
timeout=openai_settings.request_timeout,
reuse_client=False,
)
case "ollama":
try:
from llama_index.llms.ollama import Ollama # type: ignore
except ImportError as e:
raise ImportError(
"Ollama dependencies not found, install with `poetry install --extras llms-ollama`"
) from e
ollama_settings = settings.ollama
settings_kwargs = {
"tfs_z": ollama_settings.tfs_z, # ollama and llama-cpp
"num_predict": ollama_settings.num_predict, # ollama only
"top_k": ollama_settings.top_k, # ollama and llama-cpp
"top_p": ollama_settings.top_p, # ollama and llama-cpp
"repeat_last_n": ollama_settings.repeat_last_n, # ollama
"repeat_penalty": ollama_settings.repeat_penalty, # ollama llama-cpp
}
# calculate llm model. If not provided tag, it will be use latest
model_name = (
ollama_settings.llm_model + ":latest"
if ":" not in ollama_settings.llm_model
else ollama_settings.llm_model
)
llm = Ollama(
model=model_name,
base_url=ollama_settings.api_base,
temperature=settings.llm.temperature,
context_window=settings.llm.context_window,
additional_kwargs=settings_kwargs,
request_timeout=ollama_settings.request_timeout,
)
if ollama_settings.autopull_models:
from private_gpt.utils.ollama import check_connection, pull_model
if not check_connection(llm.client):
raise ValueError(
f"Failed to connect to Ollama, "
f"check if Ollama server is running on {ollama_settings.api_base}"
)
pull_model(llm.client, model_name)
if (
ollama_settings.keep_alive
!= ollama_settings.model_fields["keep_alive"].default
):
# Modify Ollama methods to use the "keep_alive" field.
def add_keep_alive(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapper(*args: Any, **kwargs: Any) -> Any:
kwargs["keep_alive"] = ollama_settings.keep_alive
return func(*args, **kwargs)
return wrapper
Ollama.chat = add_keep_alive(Ollama.chat) # type: ignore
Ollama.stream_chat = add_keep_alive(Ollama.stream_chat) # type: ignore
Ollama.complete = add_keep_alive(Ollama.complete) # type: ignore
Ollama.stream_complete = add_keep_alive(Ollama.stream_complete) # type: ignore
self.llm = llm
case "azopenai":
try:
from llama_index.llms.azure_openai import ( # type: ignore
AzureOpenAI,
)
except ImportError as e:
raise ImportError(
"Azure OpenAI dependencies not found, install with `poetry install --extras llms-azopenai`"
) from e
azopenai_settings = settings.azopenai
self.llm = AzureOpenAI(
model=azopenai_settings.llm_model,
deployment_name=azopenai_settings.llm_deployment_name,
api_key=azopenai_settings.api_key,
azure_endpoint=azopenai_settings.azure_endpoint,
api_version=azopenai_settings.api_version,
)
case "gemini":
try:
from llama_index.llms.gemini import ( # type: ignore
Gemini,
)
except ImportError as e:
raise ImportError(
"Google Gemini dependencies not found, install with `poetry install --extras llms-gemini`"
) from e
gemini_settings = settings.gemini
self.llm = Gemini(
model_name=gemini_settings.model, api_key=gemini_settings.api_key
)
case "mock":
self.llm = MockLLM()

View File

@ -0,0 +1,310 @@
import abc
import logging
from collections.abc import Sequence
from typing import Any, Literal
from llama_index.core.llms import ChatMessage, MessageRole
logger = logging.getLogger(__name__)
class AbstractPromptStyle(abc.ABC):
"""Abstract class for prompt styles.
This class is used to format a series of messages into a prompt that can be
understood by the models. A series of messages represents the interaction(s)
between a user and an assistant. This series of messages can be considered as a
session between a user X and an assistant Y.This session holds, through the
messages, the state of the conversation. This session, to be understood by the
model, needs to be formatted into a prompt (i.e. a string that the models
can understand). Prompts can be formatted in different ways,
depending on the model.
The implementations of this class represent the different ways to format a
series of messages into a prompt.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
logger.debug("Initializing prompt_style=%s", self.__class__.__name__)
@abc.abstractmethod
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
pass
@abc.abstractmethod
def _completion_to_prompt(self, completion: str) -> str:
pass
def messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
prompt = self._messages_to_prompt(messages)
logger.debug("Got for messages='%s' the prompt='%s'", messages, prompt)
return prompt
def completion_to_prompt(self, prompt: str) -> str:
completion = prompt # Fix: Llama-index parameter has to be named as prompt
prompt = self._completion_to_prompt(completion)
logger.debug("Got for completion='%s' the prompt='%s'", completion, prompt)
return prompt
class DefaultPromptStyle(AbstractPromptStyle):
"""Default prompt style that uses the defaults from llama_utils.
It basically passes None to the LLM, indicating it should use
the default functions.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
# Hacky way to override the functions
# Override the functions to be None, and pass None to the LLM.
self.messages_to_prompt = None # type: ignore[method-assign, assignment]
self.completion_to_prompt = None # type: ignore[method-assign, assignment]
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
return ""
def _completion_to_prompt(self, completion: str) -> str:
return ""
class Llama2PromptStyle(AbstractPromptStyle):
"""Simple prompt style that uses llama 2 prompt style.
Inspired by llama_index/legacy/llms/llama_utils.py
It transforms the sequence of messages into a prompt that should look like:
```text
<s> [INST] <<SYS>> your system prompt here. <</SYS>>
user message here [/INST] assistant (model) response here </s>
```
"""
BOS, EOS = "<s>", "</s>"
B_INST, E_INST = "[INST]", "[/INST]"
B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. \
Always answer as helpfully as possible and follow ALL given instructions. \
Do not speculate or make up information. \
Do not reference any given instructions or context. \
"""
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
string_messages: list[str] = []
if messages[0].role == MessageRole.SYSTEM:
# pull out the system message (if it exists in messages)
system_message_str = messages[0].content or ""
messages = messages[1:]
else:
system_message_str = self.DEFAULT_SYSTEM_PROMPT
system_message_str = f"{self.B_SYS} {system_message_str.strip()} {self.E_SYS}"
for i in range(0, len(messages), 2):
# first message should always be a user
user_message = messages[i]
assert user_message.role == MessageRole.USER
if i == 0:
# make sure system prompt is included at the start
str_message = f"{self.BOS} {self.B_INST} {system_message_str} "
else:
# end previous user-assistant interaction
string_messages[-1] += f" {self.EOS}"
# no need to include system prompt
str_message = f"{self.BOS} {self.B_INST} "
# include user message content
str_message += f"{user_message.content} {self.E_INST}"
if len(messages) > (i + 1):
# if assistant message exists, add to str_message
assistant_message = messages[i + 1]
assert assistant_message.role == MessageRole.ASSISTANT
str_message += f" {assistant_message.content}"
string_messages.append(str_message)
return "".join(string_messages)
def _completion_to_prompt(self, completion: str) -> str:
system_prompt_str = self.DEFAULT_SYSTEM_PROMPT
return (
f"{self.BOS} {self.B_INST} {self.B_SYS} {system_prompt_str.strip()} {self.E_SYS} "
f"{completion.strip()} {self.E_INST}"
)
class Llama3PromptStyle(AbstractPromptStyle):
r"""Template for Meta's Llama 3.1.
The format follows this structure:
<|begin_of_text|>
<|start_header_id|>system<|end_header_id|>
[System message content]<|eot_id|>
<|start_header_id|>user<|end_header_id|>
[User message content]<|eot_id|>
<|start_header_id|>assistant<|end_header_id|>
[Assistant message content]<|eot_id|>
...
(Repeat for each message, including possible 'ipython' role)
"""
BOS, EOS = "<|begin_of_text|>", "<|end_of_text|>"
B_INST, E_INST = "<|start_header_id|>", "<|end_header_id|>"
EOT = "<|eot_id|>"
B_SYS, E_SYS = "<|start_header_id|>system<|end_header_id|>", "<|eot_id|>"
ASSISTANT_INST = "<|start_header_id|>assistant<|end_header_id|>"
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. \
Always answer as helpfully as possible and follow ALL given instructions. \
Do not speculate or make up information. \
Do not reference any given instructions or context. \
"""
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
prompt = ""
has_system_message = False
for i, message in enumerate(messages):
if not message or message.content is None:
continue
if message.role == MessageRole.SYSTEM:
prompt += f"{self.B_SYS}\n\n{message.content.strip()}{self.E_SYS}"
has_system_message = True
else:
role_header = f"{self.B_INST}{message.role.value}{self.E_INST}"
prompt += f"{role_header}\n\n{message.content.strip()}{self.EOT}"
# Add assistant header if the last message is not from the assistant
if i == len(messages) - 1 and message.role != MessageRole.ASSISTANT:
prompt += f"{self.ASSISTANT_INST}\n\n"
# Add default system prompt if no system message was provided
if not has_system_message:
prompt = (
f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}" + prompt
)
# TODO: Implement tool handling logic
return prompt
def _completion_to_prompt(self, completion: str) -> str:
return (
f"{self.B_SYS}\n\n{self.DEFAULT_SYSTEM_PROMPT}{self.E_SYS}"
f"{self.B_INST}user{self.E_INST}\n\n{completion.strip()}{self.EOT}"
f"{self.ASSISTANT_INST}\n\n"
)
class TagPromptStyle(AbstractPromptStyle):
"""Tag prompt style (used by Vigogne) that uses the prompt style `<|ROLE|>`.
It transforms the sequence of messages into a prompt that should look like:
```text
<|system|>: your system prompt here.
<|user|>: user message here
(possibly with context and question)
<|assistant|>: assistant (model) response here.
```
FIXME: should we add surrounding `<s>` and `</s>` tags, like in llama2?
"""
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
"""Format message to prompt with `<|ROLE|>: MSG` style."""
prompt = ""
for message in messages:
role = message.role
content = message.content or ""
message_from_user = f"<|{role.lower()}|>: {content.strip()}"
message_from_user += "\n"
prompt += message_from_user
# we are missing the last <|assistant|> tag that will trigger a completion
prompt += "<|assistant|>: "
return prompt
def _completion_to_prompt(self, completion: str) -> str:
return self._messages_to_prompt(
[ChatMessage(content=completion, role=MessageRole.USER)]
)
class MistralPromptStyle(AbstractPromptStyle):
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
inst_buffer = []
text = ""
for message in messages:
if message.role == MessageRole.SYSTEM or message.role == MessageRole.USER:
inst_buffer.append(str(message.content).strip())
elif message.role == MessageRole.ASSISTANT:
text += "<s>[INST] " + "\n".join(inst_buffer) + " [/INST]"
text += " " + str(message.content).strip() + "</s>"
inst_buffer.clear()
else:
raise ValueError(f"Unknown message role {message.role}")
if len(inst_buffer) > 0:
text += "<s>[INST] " + "\n".join(inst_buffer) + " [/INST]"
return text
def _completion_to_prompt(self, completion: str) -> str:
return self._messages_to_prompt(
[ChatMessage(content=completion, role=MessageRole.USER)]
)
class ChatMLPromptStyle(AbstractPromptStyle):
def _messages_to_prompt(self, messages: Sequence[ChatMessage]) -> str:
prompt = "<|im_start|>system\n"
for message in messages:
role = message.role
content = message.content or ""
if role.lower() == "system":
message_from_user = f"{content.strip()}"
prompt += message_from_user
elif role.lower() == "user":
prompt += "<|im_end|>\n<|im_start|>user\n"
message_from_user = f"{content.strip()}<|im_end|>\n"
prompt += message_from_user
prompt += "<|im_start|>assistant\n"
return prompt
def _completion_to_prompt(self, completion: str) -> str:
return self._messages_to_prompt(
[ChatMessage(content=completion, role=MessageRole.USER)]
)
def get_prompt_style(
prompt_style: (
Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"] | None
)
) -> AbstractPromptStyle:
"""Get the prompt style to use from the given string.
:param prompt_style: The prompt style to use.
:return: The prompt style to use.
"""
if prompt_style is None or prompt_style == "default":
return DefaultPromptStyle()
elif prompt_style == "llama2":
return Llama2PromptStyle()
elif prompt_style == "llama3":
return Llama3PromptStyle()
elif prompt_style == "tag":
return TagPromptStyle()
elif prompt_style == "mistral":
return MistralPromptStyle()
elif prompt_style == "chatml":
return ChatMLPromptStyle()
raise ValueError(f"Unknown prompt_style='{prompt_style}'")

View File

@ -0,0 +1,68 @@
import logging
from injector import inject, singleton
from llama_index.core.storage.docstore import BaseDocumentStore, SimpleDocumentStore
from llama_index.core.storage.index_store import SimpleIndexStore
from llama_index.core.storage.index_store.types import BaseIndexStore
from private_gpt.paths import local_data_path
from private_gpt.settings.settings import Settings
logger = logging.getLogger(__name__)
@singleton
class NodeStoreComponent:
index_store: BaseIndexStore
doc_store: BaseDocumentStore
@inject
def __init__(self, settings: Settings) -> None:
match settings.nodestore.database:
case "simple":
try:
self.index_store = SimpleIndexStore.from_persist_dir(
persist_dir=str(local_data_path)
)
except FileNotFoundError:
logger.debug("Local index store not found, creating a new one")
self.index_store = SimpleIndexStore()
try:
self.doc_store = SimpleDocumentStore.from_persist_dir(
persist_dir=str(local_data_path)
)
except FileNotFoundError:
logger.debug("Local document store not found, creating a new one")
self.doc_store = SimpleDocumentStore()
case "postgres":
try:
from llama_index.storage.docstore.postgres import ( # type: ignore
PostgresDocumentStore,
)
from llama_index.storage.index_store.postgres import ( # type: ignore
PostgresIndexStore,
)
except ImportError:
raise ImportError(
"Postgres dependencies not found, install with `poetry install --extras storage-nodestore-postgres`"
) from None
if settings.postgres is None:
raise ValueError("Postgres index/doc store settings not found.")
self.index_store = PostgresIndexStore.from_params(
**settings.postgres.model_dump(exclude_none=True)
)
self.doc_store = PostgresDocumentStore.from_params(
**settings.postgres.model_dump(exclude_none=True)
)
case _:
# Should be unreachable
# The settings validator should have caught this
raise ValueError(
f"Database {settings.nodestore.database} not supported"
)

View File

@ -0,0 +1,106 @@
from collections.abc import Generator, Sequence
from typing import TYPE_CHECKING, Any
from llama_index.core.schema import BaseNode, MetadataMode
from llama_index.core.vector_stores.utils import node_to_metadata_dict
from llama_index.vector_stores.chroma import ChromaVectorStore # type: ignore
if TYPE_CHECKING:
from collections.abc import Mapping
def chunk_list(
lst: Sequence[BaseNode], max_chunk_size: int
) -> Generator[Sequence[BaseNode], None, None]:
"""Yield successive max_chunk_size-sized chunks from lst.
Args:
lst (List[BaseNode]): list of nodes with embeddings
max_chunk_size (int): max chunk size
Yields:
Generator[List[BaseNode], None, None]: list of nodes with embeddings
"""
for i in range(0, len(lst), max_chunk_size):
yield lst[i : i + max_chunk_size]
class BatchedChromaVectorStore(ChromaVectorStore): # type: ignore
"""Chroma vector store, batching additions to avoid reaching the max batch limit.
In this vector store, embeddings are stored within a ChromaDB collection.
During query time, the index uses ChromaDB to query for the top
k most similar nodes.
Args:
chroma_client (from chromadb.api.API):
API instance
chroma_collection (chromadb.api.models.Collection.Collection):
ChromaDB collection instance
"""
chroma_client: Any | None
def __init__(
self,
chroma_client: Any,
chroma_collection: Any,
host: str | None = None,
port: str | None = None,
ssl: bool = False,
headers: dict[str, str] | None = None,
collection_kwargs: dict[Any, Any] | None = None,
) -> None:
super().__init__(
chroma_collection=chroma_collection,
host=host,
port=port,
ssl=ssl,
headers=headers,
collection_kwargs=collection_kwargs or {},
)
self.chroma_client = chroma_client
def add(self, nodes: Sequence[BaseNode], **add_kwargs: Any) -> list[str]:
"""Add nodes to index, batching the insertion to avoid issues.
Args:
nodes: List[BaseNode]: list of nodes with embeddings
add_kwargs: _
"""
if not self.chroma_client:
raise ValueError("Client not initialized")
if not self._collection:
raise ValueError("Collection not initialized")
max_chunk_size = self.chroma_client.max_batch_size
node_chunks = chunk_list(nodes, max_chunk_size)
all_ids = []
for node_chunk in node_chunks:
embeddings: list[Sequence[float]] = []
metadatas: list[Mapping[str, Any]] = []
ids = []
documents = []
for node in node_chunk:
embeddings.append(node.get_embedding())
metadatas.append(
node_to_metadata_dict(
node, remove_text=True, flat_metadata=self.flat_metadata
)
)
ids.append(node.node_id)
documents.append(node.get_content(metadata_mode=MetadataMode.NONE))
self._collection.add(
embeddings=embeddings,
ids=ids,
metadatas=metadatas,
documents=documents,
)
all_ids.extend(ids)
return all_ids

View File

@ -0,0 +1,217 @@
import logging
import typing
from injector import inject, singleton
from llama_index.core.indices.vector_store import VectorIndexRetriever, VectorStoreIndex
from llama_index.core.vector_stores.types import (
BasePydanticVectorStore,
FilterCondition,
MetadataFilter,
MetadataFilters,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.paths import local_data_path
from private_gpt.settings.settings import Settings
logger = logging.getLogger(__name__)
def _doc_id_metadata_filter(
context_filter: ContextFilter | None,
) -> MetadataFilters:
filters = MetadataFilters(filters=[], condition=FilterCondition.OR)
if context_filter is not None and context_filter.docs_ids is not None:
for doc_id in context_filter.docs_ids:
filters.filters.append(MetadataFilter(key="doc_id", value=doc_id))
return filters
@singleton
class VectorStoreComponent:
settings: Settings
vector_store: BasePydanticVectorStore
@inject
def __init__(self, settings: Settings) -> None:
self.settings = settings
match settings.vectorstore.database:
case "postgres":
try:
from llama_index.vector_stores.postgres import ( # type: ignore
PGVectorStore,
)
except ImportError as e:
raise ImportError(
"Postgres dependencies not found, install with `poetry install --extras vector-stores-postgres`"
) from e
if settings.postgres is None:
raise ValueError(
"Postgres settings not found. Please provide settings."
)
self.vector_store = typing.cast(
BasePydanticVectorStore,
PGVectorStore.from_params(
**settings.postgres.model_dump(exclude_none=True),
table_name="embeddings",
embed_dim=settings.embedding.embed_dim,
),
)
case "chroma":
try:
import chromadb # type: ignore
from chromadb.config import ( # type: ignore
Settings as ChromaSettings,
)
from private_gpt.components.vector_store.batched_chroma import (
BatchedChromaVectorStore,
)
except ImportError as e:
raise ImportError(
"ChromaDB dependencies not found, install with `poetry install --extras vector-stores-chroma`"
) from e
chroma_settings = ChromaSettings(anonymized_telemetry=False)
chroma_client = chromadb.PersistentClient(
path=str((local_data_path / "chroma_db").absolute()),
settings=chroma_settings,
)
chroma_collection = chroma_client.get_or_create_collection(
"make_this_parameterizable_per_api_call"
) # TODO
self.vector_store = typing.cast(
BasePydanticVectorStore,
BatchedChromaVectorStore(
chroma_client=chroma_client, chroma_collection=chroma_collection
),
)
case "qdrant":
try:
from llama_index.vector_stores.qdrant import ( # type: ignore
QdrantVectorStore,
)
from qdrant_client import QdrantClient # type: ignore
except ImportError as e:
raise ImportError(
"Qdrant dependencies not found, install with `poetry install --extras vector-stores-qdrant`"
) from e
if settings.qdrant is None:
logger.info(
"Qdrant config not found. Using default settings."
"Trying to connect to Qdrant at localhost:6333."
)
client = QdrantClient()
else:
client = QdrantClient(
**settings.qdrant.model_dump(exclude_none=True)
)
self.vector_store = typing.cast(
BasePydanticVectorStore,
QdrantVectorStore(
client=client,
collection_name="make_this_parameterizable_per_api_call",
), # TODO
)
case "milvus":
try:
from llama_index.vector_stores.milvus import ( # type: ignore
MilvusVectorStore,
)
except ImportError as e:
raise ImportError(
"Milvus dependencies not found, install with `poetry install --extras vector-stores-milvus`"
) from e
if settings.milvus is None:
logger.info(
"Milvus config not found. Using default settings.\n"
"Trying to connect to Milvus at local_data/private_gpt/milvus/milvus_local.db "
"with collection 'make_this_parameterizable_per_api_call'."
)
self.vector_store = typing.cast(
BasePydanticVectorStore,
MilvusVectorStore(
dim=settings.embedding.embed_dim,
collection_name="make_this_parameterizable_per_api_call",
overwrite=True,
),
)
else:
self.vector_store = typing.cast(
BasePydanticVectorStore,
MilvusVectorStore(
dim=settings.embedding.embed_dim,
uri=settings.milvus.uri,
token=settings.milvus.token,
collection_name=settings.milvus.collection_name,
overwrite=settings.milvus.overwrite,
),
)
case "clickhouse":
try:
from clickhouse_connect import ( # type: ignore
get_client,
)
from llama_index.vector_stores.clickhouse import ( # type: ignore
ClickHouseVectorStore,
)
except ImportError as e:
raise ImportError(
"ClickHouse dependencies not found, install with `poetry install --extras vector-stores-clickhouse`"
) from e
if settings.clickhouse is None:
raise ValueError(
"ClickHouse settings not found. Please provide settings."
)
clickhouse_client = get_client(
host=settings.clickhouse.host,
port=settings.clickhouse.port,
username=settings.clickhouse.username,
password=settings.clickhouse.password,
)
self.vector_store = ClickHouseVectorStore(
clickhouse_client=clickhouse_client
)
case _:
# Should be unreachable
# The settings validator should have caught this
raise ValueError(
f"Vectorstore database {settings.vectorstore.database} not supported"
)
def get_retriever(
self,
index: VectorStoreIndex,
context_filter: ContextFilter | None = None,
similarity_top_k: int = 2,
) -> VectorIndexRetriever:
# This way we support qdrant (using doc_ids) and the rest (using filters)
return VectorIndexRetriever(
index=index,
similarity_top_k=similarity_top_k,
doc_ids=context_filter.docs_ids if context_filter else None,
filters=(
_doc_id_metadata_filter(context_filter)
if self.settings.vectorstore.database != "qdrant"
else None
),
)
def close(self) -> None:
if hasattr(self.vector_store.client, "close"):
self.vector_store.client.close()

View File

@ -0,0 +1,3 @@
from pathlib import Path
PROJECT_ROOT_PATH: Path = Path(__file__).parents[1]

19
pgpt/private_gpt/di.py Normal file
View File

@ -0,0 +1,19 @@
from injector import Injector
from private_gpt.settings.settings import Settings, unsafe_typed_settings
def create_application_injector() -> Injector:
_injector = Injector(auto_bind=True)
_injector.binder.bind(Settings, to=unsafe_typed_settings)
return _injector
"""
Global injector for the application.
Avoid using this reference, it will make your code harder to test.
Instead, use the `request.state.injector` reference, which is bound to every request
"""
global_injector: Injector = create_application_injector()

View File

@ -0,0 +1,69 @@
"""FastAPI app creation, logger configuration and main API routes."""
import logging
from fastapi import Depends, FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from injector import Injector
from llama_index.core.callbacks import CallbackManager
from llama_index.core.callbacks.global_handlers import create_global_handler
from llama_index.core.settings import Settings as LlamaIndexSettings
from private_gpt.server.chat.chat_router import chat_router
from private_gpt.server.chunks.chunks_router import chunks_router
from private_gpt.server.completions.completions_router import completions_router
from private_gpt.server.embeddings.embeddings_router import embeddings_router
from private_gpt.server.health.health_router import health_router
from private_gpt.server.ingest.ingest_router import ingest_router
from private_gpt.server.recipes.summarize.summarize_router import summarize_router
from private_gpt.settings.settings import Settings
logger = logging.getLogger(__name__)
def create_app(root_injector: Injector) -> FastAPI:
# Start the API
async def bind_injector_to_request(request: Request) -> None:
request.state.injector = root_injector
app = FastAPI(dependencies=[Depends(bind_injector_to_request)])
app.include_router(completions_router)
app.include_router(chat_router)
app.include_router(chunks_router)
app.include_router(ingest_router)
app.include_router(summarize_router)
app.include_router(embeddings_router)
app.include_router(health_router)
# Add LlamaIndex simple observability
global_handler = create_global_handler("simple")
if global_handler:
LlamaIndexSettings.callback_manager = CallbackManager([global_handler])
settings = root_injector.get(Settings)
if settings.server.cors.enabled:
logger.debug("Setting up CORS middleware")
app.add_middleware(
CORSMiddleware,
allow_credentials=settings.server.cors.allow_credentials,
allow_origins=settings.server.cors.allow_origins,
allow_origin_regex=settings.server.cors.allow_origin_regex,
allow_methods=settings.server.cors.allow_methods,
allow_headers=settings.server.cors.allow_headers,
)
if settings.ui.enabled:
logger.debug("Importing the UI module")
try:
from private_gpt.ui.ui import PrivateGptUi
except ImportError as e:
raise ImportError(
"UI dependencies not found, install with `poetry install --extras ui`"
) from e
ui = root_injector.get(PrivateGptUi)
ui.mount_in_app(app, settings.ui.path)
return app

6
pgpt/private_gpt/main.py Normal file
View File

@ -0,0 +1,6 @@
"""FastAPI app creation, logger configuration and main API routes."""
from private_gpt.di import global_injector
from private_gpt.launcher import create_app
app = create_app(global_injector)

View File

@ -0,0 +1 @@
"""OpenAI compatibility utilities."""

View File

@ -0,0 +1 @@
"""OpenAI API extensions."""

View File

@ -0,0 +1,7 @@
from pydantic import BaseModel, Field
class ContextFilter(BaseModel):
docs_ids: list[str] | None = Field(
examples=[["c202d5e6-7b69-4869-81cc-dd574ee8ee11"]]
)

View File

@ -0,0 +1,122 @@
import time
import uuid
from collections.abc import Iterator
from typing import Literal
from llama_index.core.llms import ChatResponse, CompletionResponse
from pydantic import BaseModel, Field
from private_gpt.server.chunks.chunks_service import Chunk
class OpenAIDelta(BaseModel):
"""A piece of completion that needs to be concatenated to get the full message."""
content: str | None
class OpenAIMessage(BaseModel):
"""Inference result, with the source of the message.
Role could be the assistant or system
(providing a default response, not AI generated).
"""
role: Literal["assistant", "system", "user"] = Field(default="user")
content: str | None
class OpenAIChoice(BaseModel):
"""Response from AI.
Either the delta or the message will be present, but never both.
Sources used will be returned in case context retrieval was enabled.
"""
finish_reason: str | None = Field(examples=["stop"])
delta: OpenAIDelta | None = None
message: OpenAIMessage | None = None
sources: list[Chunk] | None = None
index: int = 0
class OpenAICompletion(BaseModel):
"""Clone of OpenAI Completion model.
For more information see: https://platform.openai.com/docs/api-reference/chat/object
"""
id: str
object: Literal["completion", "completion.chunk"] = Field(default="completion")
created: int = Field(..., examples=[1623340000])
model: Literal["private-gpt"]
choices: list[OpenAIChoice]
@classmethod
def from_text(
cls,
text: str | None,
finish_reason: str | None = None,
sources: list[Chunk] | None = None,
) -> "OpenAICompletion":
return OpenAICompletion(
id=str(uuid.uuid4()),
object="completion",
created=int(time.time()),
model="private-gpt",
choices=[
OpenAIChoice(
message=OpenAIMessage(role="assistant", content=text),
finish_reason=finish_reason,
sources=sources,
)
],
)
@classmethod
def json_from_delta(
cls,
*,
text: str | None,
finish_reason: str | None = None,
sources: list[Chunk] | None = None,
) -> str:
chunk = OpenAICompletion(
id=str(uuid.uuid4()),
object="completion.chunk",
created=int(time.time()),
model="private-gpt",
choices=[
OpenAIChoice(
delta=OpenAIDelta(content=text),
finish_reason=finish_reason,
sources=sources,
)
],
)
return chunk.model_dump_json()
def to_openai_response(
response: str | ChatResponse, sources: list[Chunk] | None = None
) -> OpenAICompletion:
if isinstance(response, ChatResponse):
return OpenAICompletion.from_text(response.delta, finish_reason="stop")
else:
return OpenAICompletion.from_text(
response, finish_reason="stop", sources=sources
)
def to_openai_sse_stream(
response_generator: Iterator[str | CompletionResponse | ChatResponse],
sources: list[Chunk] | None = None,
) -> Iterator[str]:
for response in response_generator:
if isinstance(response, CompletionResponse | ChatResponse):
yield f"data: {OpenAICompletion.json_from_delta(text=response.delta)}\n\n"
else:
yield f"data: {OpenAICompletion.json_from_delta(text=response, sources=sources)}\n\n"
yield f"data: {OpenAICompletion.json_from_delta(text='', finish_reason='stop')}\n\n"
yield "data: [DONE]\n\n"

18
pgpt/private_gpt/paths.py Normal file
View File

@ -0,0 +1,18 @@
from pathlib import Path
from private_gpt.constants import PROJECT_ROOT_PATH
from private_gpt.settings.settings import settings
def _absolute_or_from_project_root(path: str) -> Path:
if path.startswith("/"):
return Path(path)
return PROJECT_ROOT_PATH / path
models_path: Path = PROJECT_ROOT_PATH / "models"
models_cache_path: Path = models_path / "cache"
docs_path: Path = PROJECT_ROOT_PATH / "docs"
local_data_path: Path = _absolute_or_from_project_root(
settings().data.local_data_folder
)

View File

@ -0,0 +1 @@
"""private-gpt server."""

View File

View File

@ -0,0 +1,115 @@
from fastapi import APIRouter, Depends, Request
from llama_index.core.llms import ChatMessage, MessageRole
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.open_ai.openai_models import (
OpenAICompletion,
OpenAIMessage,
to_openai_response,
to_openai_sse_stream,
)
from private_gpt.server.chat.chat_service import ChatService
from private_gpt.server.utils.auth import authenticated
chat_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
class ChatBody(BaseModel):
messages: list[OpenAIMessage]
use_context: bool = False
context_filter: ContextFilter | None = None
include_sources: bool = True
stream: bool = False
model_config = {
"json_schema_extra": {
"examples": [
{
"messages": [
{
"role": "system",
"content": "You are a rapper. Always answer with a rap.",
},
{
"role": "user",
"content": "How do you fry an egg?",
},
],
"stream": False,
"use_context": True,
"include_sources": True,
"context_filter": {
"docs_ids": ["c202d5e6-7b69-4869-81cc-dd574ee8ee11"]
},
}
]
}
}
@chat_router.post(
"/chat/completions",
response_model=None,
responses={200: {"model": OpenAICompletion}},
tags=["Contextual Completions"],
openapi_extra={
"x-fern-streaming": {
"stream-condition": "stream",
"response": {"$ref": "#/components/schemas/OpenAICompletion"},
"response-stream": {"$ref": "#/components/schemas/OpenAICompletion"},
}
},
)
def chat_completion(
request: Request, body: ChatBody
) -> OpenAICompletion | StreamingResponse:
"""Given a list of messages comprising a conversation, return a response.
Optionally include an initial `role: system` message to influence the way
the LLM answers.
If `use_context` is set to `true`, the model will use context coming
from the ingested documents to create the response. The documents being used can
be filtered using the `context_filter` and passing the document IDs to be used.
Ingested documents IDs can be found using `/ingest/list` endpoint. If you want
all ingested documents to be used, remove `context_filter` altogether.
When using `'include_sources': true`, the API will return the source Chunks used
to create the response, which come from the context provided.
When using `'stream': true`, the API will return data chunks following [OpenAI's
streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
```
{"id":"12345","object":"completion.chunk","created":1694268190,
"model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
"finish_reason":null}]}
```
"""
service = request.state.injector.get(ChatService)
all_messages = [
ChatMessage(content=m.content, role=MessageRole(m.role)) for m in body.messages
]
if body.stream:
completion_gen = service.stream_chat(
messages=all_messages,
use_context=body.use_context,
context_filter=body.context_filter,
)
return StreamingResponse(
to_openai_sse_stream(
completion_gen.response,
completion_gen.sources if body.include_sources else None,
),
media_type="text/event-stream",
)
else:
completion = service.chat(
messages=all_messages,
use_context=body.use_context,
context_filter=body.context_filter,
)
return to_openai_response(
completion.response, completion.sources if body.include_sources else None
)

View File

@ -0,0 +1,217 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING
from injector import inject, singleton
from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine
from llama_index.core.chat_engine.types import (
BaseChatEngine,
)
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.postprocessor import (
SentenceTransformerRerank,
SimilarityPostprocessor,
)
from llama_index.core.storage import StorageContext
from llama_index.core.types import TokenGen
from pydantic import BaseModel
from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chunks.chunks_service import Chunk
from private_gpt.settings.settings import Settings
if TYPE_CHECKING:
from llama_index.core.postprocessor.types import BaseNodePostprocessor
class Completion(BaseModel):
response: str
sources: list[Chunk] | None = None
class CompletionGen(BaseModel):
response: TokenGen
sources: list[Chunk] | None = None
@dataclass
class ChatEngineInput:
system_message: ChatMessage | None = None
last_message: ChatMessage | None = None
chat_history: list[ChatMessage] | None = None
@classmethod
def from_messages(cls, messages: list[ChatMessage]) -> "ChatEngineInput":
# Detect if there is a system message, extract the last message and chat history
system_message = (
messages[0]
if len(messages) > 0 and messages[0].role == MessageRole.SYSTEM
else None
)
last_message = (
messages[-1]
if len(messages) > 0 and messages[-1].role == MessageRole.USER
else None
)
# Remove from messages list the system message and last message,
# if they exist. The rest is the chat history.
if system_message:
messages.pop(0)
if last_message:
messages.pop(-1)
chat_history = messages if len(messages) > 0 else None
return cls(
system_message=system_message,
last_message=last_message,
chat_history=chat_history,
)
@singleton
class ChatService:
settings: Settings
@inject
def __init__(
self,
settings: Settings,
llm_component: LLMComponent,
vector_store_component: VectorStoreComponent,
embedding_component: EmbeddingComponent,
node_store_component: NodeStoreComponent,
) -> None:
self.settings = settings
self.llm_component = llm_component
self.embedding_component = embedding_component
self.vector_store_component = vector_store_component
self.storage_context = StorageContext.from_defaults(
vector_store=vector_store_component.vector_store,
docstore=node_store_component.doc_store,
index_store=node_store_component.index_store,
)
self.index = VectorStoreIndex.from_vector_store(
vector_store_component.vector_store,
storage_context=self.storage_context,
llm=llm_component.llm,
embed_model=embedding_component.embedding_model,
show_progress=True,
)
def _chat_engine(
self,
system_prompt: str | None = None,
use_context: bool = False,
context_filter: ContextFilter | None = None,
) -> BaseChatEngine:
settings = self.settings
if use_context:
vector_index_retriever = self.vector_store_component.get_retriever(
index=self.index,
context_filter=context_filter,
similarity_top_k=self.settings.rag.similarity_top_k,
)
node_postprocessors: list[BaseNodePostprocessor] = [
MetadataReplacementPostProcessor(target_metadata_key="window"),
]
if settings.rag.similarity_value:
node_postprocessors.append(
SimilarityPostprocessor(
similarity_cutoff=settings.rag.similarity_value
)
)
if settings.rag.rerank.enabled:
rerank_postprocessor = SentenceTransformerRerank(
model=settings.rag.rerank.model, top_n=settings.rag.rerank.top_n
)
node_postprocessors.append(rerank_postprocessor)
return ContextChatEngine.from_defaults(
system_prompt=system_prompt,
retriever=vector_index_retriever,
llm=self.llm_component.llm, # Takes no effect at the moment
node_postprocessors=node_postprocessors,
)
else:
return SimpleChatEngine.from_defaults(
system_prompt=system_prompt,
llm=self.llm_component.llm,
)
def stream_chat(
self,
messages: list[ChatMessage],
use_context: bool = False,
context_filter: ContextFilter | None = None,
) -> CompletionGen:
chat_engine_input = ChatEngineInput.from_messages(messages)
last_message = (
chat_engine_input.last_message.content
if chat_engine_input.last_message
else None
)
system_prompt = (
chat_engine_input.system_message.content
if chat_engine_input.system_message
else None
)
chat_history = (
chat_engine_input.chat_history if chat_engine_input.chat_history else None
)
chat_engine = self._chat_engine(
system_prompt=system_prompt,
use_context=use_context,
context_filter=context_filter,
)
streaming_response = chat_engine.stream_chat(
message=last_message if last_message is not None else "",
chat_history=chat_history,
)
sources = [Chunk.from_node(node) for node in streaming_response.source_nodes]
completion_gen = CompletionGen(
response=streaming_response.response_gen, sources=sources
)
return completion_gen
def chat(
self,
messages: list[ChatMessage],
use_context: bool = False,
context_filter: ContextFilter | None = None,
) -> Completion:
chat_engine_input = ChatEngineInput.from_messages(messages)
last_message = (
chat_engine_input.last_message.content
if chat_engine_input.last_message
else None
)
system_prompt = (
chat_engine_input.system_message.content
if chat_engine_input.system_message
else None
)
chat_history = (
chat_engine_input.chat_history if chat_engine_input.chat_history else None
)
chat_engine = self._chat_engine(
system_prompt=system_prompt,
use_context=use_context,
context_filter=context_filter,
)
wrapped_response = chat_engine.chat(
message=last_message if last_message is not None else "",
chat_history=chat_history,
)
sources = [Chunk.from_node(node) for node in wrapped_response.source_nodes]
completion = Completion(response=wrapped_response.response, sources=sources)
return completion

View File

@ -0,0 +1,55 @@
from typing import Literal
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel, Field
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.server.utils.auth import authenticated
chunks_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
class ChunksBody(BaseModel):
text: str = Field(examples=["Q3 2023 sales"])
context_filter: ContextFilter | None = None
limit: int = 10
prev_next_chunks: int = Field(default=0, examples=[2])
class ChunksResponse(BaseModel):
object: Literal["list"]
model: Literal["private-gpt"]
data: list[Chunk]
@chunks_router.post("/chunks", tags=["Context Chunks"])
def chunks_retrieval(request: Request, body: ChunksBody) -> ChunksResponse:
"""Given a `text`, returns the most relevant chunks from the ingested documents.
The returned information can be used to generate prompts that can be
passed to `/completions` or `/chat/completions` APIs. Note: it is usually a very
fast API, because only the Embeddings model is involved, not the LLM. The
returned information contains the relevant chunk `text` together with the source
`document` it is coming from. It also contains a score that can be used to
compare different results.
The max number of chunks to be returned is set using the `limit` param.
Previous and next chunks (pieces of text that appear right before or after in the
document) can be fetched by using the `prev_next_chunks` field.
The documents being used can be filtered using the `context_filter` and passing
the document IDs to be used. Ingested documents IDs can be found using
`/ingest/list` endpoint. If you want all ingested documents to be used,
remove `context_filter` altogether.
"""
service = request.state.injector.get(ChunksService)
results = service.retrieve_relevant(
body.text, body.context_filter, body.limit, body.prev_next_chunks
)
return ChunksResponse(
object="list",
model="private-gpt",
data=results,
)

View File

@ -0,0 +1,125 @@
from typing import TYPE_CHECKING, Literal
from injector import inject, singleton
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.schema import NodeWithScore
from llama_index.core.storage import StorageContext
from pydantic import BaseModel, Field
from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.ingest.model import IngestedDoc
if TYPE_CHECKING:
from llama_index.core.schema import RelatedNodeInfo
class Chunk(BaseModel):
object: Literal["context.chunk"]
score: float = Field(examples=[0.023])
document: IngestedDoc
text: str = Field(examples=["Outbound sales increased 20%, driven by new leads."])
previous_texts: list[str] | None = Field(
default=None,
examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]],
)
next_texts: list[str] | None = Field(
default=None,
examples=[
[
"New leads came from Google Ads campaign.",
"The campaign was run by the Marketing Department",
]
],
)
@classmethod
def from_node(cls: type["Chunk"], node: NodeWithScore) -> "Chunk":
doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
return cls(
object="context.chunk",
score=node.score or 0.0,
document=IngestedDoc(
object="ingest.document",
doc_id=doc_id,
doc_metadata=node.metadata,
),
text=node.get_content(),
)
@singleton
class ChunksService:
@inject
def __init__(
self,
llm_component: LLMComponent,
vector_store_component: VectorStoreComponent,
embedding_component: EmbeddingComponent,
node_store_component: NodeStoreComponent,
) -> None:
self.vector_store_component = vector_store_component
self.llm_component = llm_component
self.embedding_component = embedding_component
self.storage_context = StorageContext.from_defaults(
vector_store=vector_store_component.vector_store,
docstore=node_store_component.doc_store,
index_store=node_store_component.index_store,
)
def _get_sibling_nodes_text(
self, node_with_score: NodeWithScore, related_number: int, forward: bool = True
) -> list[str]:
explored_nodes_texts = []
current_node = node_with_score.node
for _ in range(related_number):
explored_node_info: RelatedNodeInfo | None = (
current_node.next_node if forward else current_node.prev_node
)
if explored_node_info is None:
break
explored_node = self.storage_context.docstore.get_node(
explored_node_info.node_id
)
explored_nodes_texts.append(explored_node.get_content())
current_node = explored_node
return explored_nodes_texts
def retrieve_relevant(
self,
text: str,
context_filter: ContextFilter | None = None,
limit: int = 10,
prev_next_chunks: int = 0,
) -> list[Chunk]:
index = VectorStoreIndex.from_vector_store(
self.vector_store_component.vector_store,
storage_context=self.storage_context,
llm=self.llm_component.llm,
embed_model=self.embedding_component.embedding_model,
show_progress=True,
)
vector_index_retriever = self.vector_store_component.get_retriever(
index=index, context_filter=context_filter, similarity_top_k=limit
)
nodes = vector_index_retriever.retrieve(text)
nodes.sort(key=lambda n: n.score or 0.0, reverse=True)
retrieved_nodes = []
for node in nodes:
chunk = Chunk.from_node(node)
chunk.previous_texts = self._get_sibling_nodes_text(
node, prev_next_chunks, False
)
chunk.next_texts = self._get_sibling_nodes_text(node, prev_next_chunks)
retrieved_nodes.append(chunk)
return retrieved_nodes

View File

@ -0,0 +1 @@
"""Deprecated Openai compatibility endpoint."""

View File

@ -0,0 +1,92 @@
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.open_ai.openai_models import (
OpenAICompletion,
OpenAIMessage,
)
from private_gpt.server.chat.chat_router import ChatBody, chat_completion
from private_gpt.server.utils.auth import authenticated
completions_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
class CompletionsBody(BaseModel):
prompt: str
system_prompt: str | None = None
use_context: bool = False
context_filter: ContextFilter | None = None
include_sources: bool = True
stream: bool = False
model_config = {
"json_schema_extra": {
"examples": [
{
"prompt": "How do you fry an egg?",
"system_prompt": "You are a rapper. Always answer with a rap.",
"stream": False,
"use_context": False,
"include_sources": False,
}
]
}
}
@completions_router.post(
"/completions",
response_model=None,
summary="Completion",
responses={200: {"model": OpenAICompletion}},
tags=["Contextual Completions"],
openapi_extra={
"x-fern-streaming": {
"stream-condition": "stream",
"response": {"$ref": "#/components/schemas/OpenAICompletion"},
"response-stream": {"$ref": "#/components/schemas/OpenAICompletion"},
}
},
)
def prompt_completion(
request: Request, body: CompletionsBody
) -> OpenAICompletion | StreamingResponse:
"""We recommend most users use our Chat completions API.
Given a prompt, the model will return one predicted completion.
Optionally include a `system_prompt` to influence the way the LLM answers.
If `use_context`
is set to `true`, the model will use context coming from the ingested documents
to create the response. The documents being used can be filtered using the
`context_filter` and passing the document IDs to be used. Ingested documents IDs
can be found using `/ingest/list` endpoint. If you want all ingested documents to
be used, remove `context_filter` altogether.
When using `'include_sources': true`, the API will return the source Chunks used
to create the response, which come from the context provided.
When using `'stream': true`, the API will return data chunks following [OpenAI's
streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
```
{"id":"12345","object":"completion.chunk","created":1694268190,
"model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
"finish_reason":null}]}
```
"""
messages = [OpenAIMessage(content=body.prompt, role="user")]
# If system prompt is passed, create a fake message with the system prompt.
if body.system_prompt:
messages.insert(0, OpenAIMessage(content=body.system_prompt, role="system"))
chat_body = ChatBody(
messages=messages,
use_context=body.use_context,
stream=body.stream,
include_sources=body.include_sources,
context_filter=body.context_filter,
)
return chat_completion(request, chat_body)

View File

@ -0,0 +1,35 @@
from typing import Literal
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from private_gpt.server.embeddings.embeddings_service import (
Embedding,
EmbeddingsService,
)
from private_gpt.server.utils.auth import authenticated
embeddings_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
class EmbeddingsBody(BaseModel):
input: str | list[str]
class EmbeddingsResponse(BaseModel):
object: Literal["list"]
model: Literal["private-gpt"]
data: list[Embedding]
@embeddings_router.post("/embeddings", tags=["Embeddings"])
def embeddings_generation(request: Request, body: EmbeddingsBody) -> EmbeddingsResponse:
"""Get a vector representation of a given input.
That vector representation can be easily consumed
by machine learning models and algorithms.
"""
service = request.state.injector.get(EmbeddingsService)
input_texts = body.input if isinstance(body.input, list) else [body.input]
embeddings = service.texts_embeddings(input_texts)
return EmbeddingsResponse(object="list", model="private-gpt", data=embeddings)

View File

@ -0,0 +1,30 @@
from typing import Literal
from injector import inject, singleton
from pydantic import BaseModel, Field
from private_gpt.components.embedding.embedding_component import EmbeddingComponent
class Embedding(BaseModel):
index: int
object: Literal["embedding"]
embedding: list[float] = Field(examples=[[0.0023064255, -0.009327292]])
@singleton
class EmbeddingsService:
@inject
def __init__(self, embedding_component: EmbeddingComponent) -> None:
self.embedding_model = embedding_component.embedding_model
def texts_embeddings(self, texts: list[str]) -> list[Embedding]:
texts_embeddings = self.embedding_model.get_text_embedding_batch(texts)
return [
Embedding(
index=texts_embeddings.index(embedding),
object="embedding",
embedding=embedding,
)
for embedding in texts_embeddings
]

View File

@ -0,0 +1,17 @@
from typing import Literal
from fastapi import APIRouter
from pydantic import BaseModel, Field
# Not authentication or authorization required to get the health status.
health_router = APIRouter()
class HealthResponse(BaseModel):
status: Literal["ok"] = Field(default="ok")
@health_router.get("/health", tags=["Health"])
def health() -> HealthResponse:
"""Return ok if the system is up."""
return HealthResponse(status="ok")

View File

@ -0,0 +1,104 @@
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException, Request, UploadFile
from pydantic import BaseModel, Field
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.server.ingest.model import IngestedDoc
from private_gpt.server.utils.auth import authenticated
ingest_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
class IngestTextBody(BaseModel):
file_name: str = Field(examples=["Avatar: The Last Airbender"])
text: str = Field(
examples=[
"Avatar is set in an Asian and Arctic-inspired world in which some "
"people can telekinetically manipulate one of the four elements—water, "
"earth, fire or air—through practices known as 'bending', inspired by "
"Chinese martial arts."
]
)
class IngestResponse(BaseModel):
object: Literal["list"]
model: Literal["private-gpt"]
data: list[IngestedDoc]
@ingest_router.post("/ingest", tags=["Ingestion"], deprecated=True)
def ingest(request: Request, file: UploadFile) -> IngestResponse:
"""Ingests and processes a file.
Deprecated. Use ingest/file instead.
"""
return ingest_file(request, file)
@ingest_router.post("/ingest/file", tags=["Ingestion"])
def ingest_file(request: Request, file: UploadFile) -> IngestResponse:
"""Ingests and processes a file, storing its chunks to be used as context.
The context obtained from files is later used in
`/chat/completions`, `/completions`, and `/chunks` APIs.
Most common document
formats are supported, but you may be prompted to install an extra dependency to
manage a specific file type.
A file can generate different Documents (for example a PDF generates one Document
per page). All Documents IDs are returned in the response, together with the
extracted Metadata (which is later used to improve context retrieval). Those IDs
can be used to filter the context used to create responses in
`/chat/completions`, `/completions`, and `/chunks` APIs.
"""
service = request.state.injector.get(IngestService)
if file.filename is None:
raise HTTPException(400, "No file name provided")
ingested_documents = service.ingest_bin_data(file.filename, file.file)
return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
@ingest_router.post("/ingest/text", tags=["Ingestion"])
def ingest_text(request: Request, body: IngestTextBody) -> IngestResponse:
"""Ingests and processes a text, storing its chunks to be used as context.
The context obtained from files is later used in
`/chat/completions`, `/completions`, and `/chunks` APIs.
A Document will be generated with the given text. The Document
ID is returned in the response, together with the
extracted Metadata (which is later used to improve context retrieval). That ID
can be used to filter the context used to create responses in
`/chat/completions`, `/completions`, and `/chunks` APIs.
"""
service = request.state.injector.get(IngestService)
if len(body.file_name) == 0:
raise HTTPException(400, "No file name provided")
ingested_documents = service.ingest_text(body.file_name, body.text)
return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
@ingest_router.get("/ingest/list", tags=["Ingestion"])
def list_ingested(request: Request) -> IngestResponse:
"""Lists already ingested Documents including their Document ID and metadata.
Those IDs can be used to filter the context used to create responses
in `/chat/completions`, `/completions`, and `/chunks` APIs.
"""
service = request.state.injector.get(IngestService)
ingested_documents = service.list_ingested()
return IngestResponse(object="list", model="private-gpt", data=ingested_documents)
@ingest_router.delete("/ingest/{doc_id}", tags=["Ingestion"])
def delete_ingested(request: Request, doc_id: str) -> None:
"""Delete the specified ingested Document.
The `doc_id` can be obtained from the `GET /ingest/list` endpoint.
The document will be effectively deleted from your storage context.
"""
service = request.state.injector.get(IngestService)
service.delete(doc_id)

View File

@ -0,0 +1,125 @@
import logging
import tempfile
from pathlib import Path
from typing import TYPE_CHECKING, AnyStr, BinaryIO
from injector import inject, singleton
from llama_index.core.node_parser import SentenceWindowNodeParser
from llama_index.core.storage import StorageContext
from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.ingest.ingest_component import get_ingestion_component
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
VectorStoreComponent,
)
from private_gpt.server.ingest.model import IngestedDoc
from private_gpt.settings.settings import settings
if TYPE_CHECKING:
from llama_index.core.storage.docstore.types import RefDocInfo
logger = logging.getLogger(__name__)
@singleton
class IngestService:
@inject
def __init__(
self,
llm_component: LLMComponent,
vector_store_component: VectorStoreComponent,
embedding_component: EmbeddingComponent,
node_store_component: NodeStoreComponent,
) -> None:
self.llm_service = llm_component
self.storage_context = StorageContext.from_defaults(
vector_store=vector_store_component.vector_store,
docstore=node_store_component.doc_store,
index_store=node_store_component.index_store,
)
node_parser = SentenceWindowNodeParser.from_defaults()
self.ingest_component = get_ingestion_component(
self.storage_context,
embed_model=embedding_component.embedding_model,
transformations=[node_parser, embedding_component.embedding_model],
settings=settings(),
)
def _ingest_data(self, file_name: str, file_data: AnyStr) -> list[IngestedDoc]:
logger.debug("Got file data of size=%s to ingest", len(file_data))
# llama-index mainly supports reading from files, so
# we have to create a tmp file to read for it to work
# delete=False to avoid a Windows 11 permission error.
with tempfile.NamedTemporaryFile(delete=False) as tmp:
try:
path_to_tmp = Path(tmp.name)
if isinstance(file_data, bytes):
path_to_tmp.write_bytes(file_data)
else:
path_to_tmp.write_text(str(file_data))
return self.ingest_file(file_name, path_to_tmp)
finally:
tmp.close()
path_to_tmp.unlink()
def ingest_file(self, file_name: str, file_data: Path) -> list[IngestedDoc]:
logger.info("Ingesting file_name=%s", file_name)
documents = self.ingest_component.ingest(file_name, file_data)
logger.info("Finished ingestion file_name=%s", file_name)
return [IngestedDoc.from_document(document) for document in documents]
def ingest_text(self, file_name: str, text: str) -> list[IngestedDoc]:
logger.debug("Ingesting text data with file_name=%s", file_name)
return self._ingest_data(file_name, text)
def ingest_bin_data(
self, file_name: str, raw_file_data: BinaryIO
) -> list[IngestedDoc]:
logger.debug("Ingesting binary data with file_name=%s", file_name)
file_data = raw_file_data.read()
return self._ingest_data(file_name, file_data)
def bulk_ingest(self, files: list[tuple[str, Path]]) -> list[IngestedDoc]:
logger.info("Ingesting file_names=%s", [f[0] for f in files])
documents = self.ingest_component.bulk_ingest(files)
logger.info("Finished ingestion file_name=%s", [f[0] for f in files])
return [IngestedDoc.from_document(document) for document in documents]
def list_ingested(self) -> list[IngestedDoc]:
ingested_docs: list[IngestedDoc] = []
try:
docstore = self.storage_context.docstore
ref_docs: dict[str, RefDocInfo] | None = docstore.get_all_ref_doc_info()
if not ref_docs:
return ingested_docs
for doc_id, ref_doc_info in ref_docs.items():
doc_metadata = None
if ref_doc_info is not None and ref_doc_info.metadata is not None:
doc_metadata = IngestedDoc.curate_metadata(ref_doc_info.metadata)
ingested_docs.append(
IngestedDoc(
object="ingest.document",
doc_id=doc_id,
doc_metadata=doc_metadata,
)
)
except ValueError:
logger.warning("Got an exception when getting list of docs", exc_info=True)
pass
logger.debug("Found count=%s ingested documents", len(ingested_docs))
return ingested_docs
def delete(self, doc_id: str) -> None:
"""Delete an ingested document.
:raises ValueError: if the document does not exist
"""
logger.info(
"Deleting the ingested document=%s in the doc and index store", doc_id
)
self.ingest_component.delete(doc_id)

View File

@ -0,0 +1,45 @@
from collections.abc import Callable
from pathlib import Path
from typing import Any
from watchdog.events import (
FileCreatedEvent,
FileModifiedEvent,
FileSystemEvent,
FileSystemEventHandler,
)
from watchdog.observers import Observer
class IngestWatcher:
def __init__(
self, watch_path: Path, on_file_changed: Callable[[Path], None]
) -> None:
self.watch_path = watch_path
self.on_file_changed = on_file_changed
class Handler(FileSystemEventHandler):
def on_modified(self, event: FileSystemEvent) -> None:
if isinstance(event, FileModifiedEvent):
on_file_changed(Path(event.src_path))
def on_created(self, event: FileSystemEvent) -> None:
if isinstance(event, FileCreatedEvent):
on_file_changed(Path(event.src_path))
event_handler = Handler()
observer: Any = Observer()
self._observer = observer
self._observer.schedule(event_handler, str(watch_path), recursive=True)
def start(self) -> None:
self._observer.start()
while self._observer.is_alive():
try:
self._observer.join(1)
except KeyboardInterrupt:
break
def stop(self) -> None:
self._observer.stop()
self._observer.join()

View File

@ -0,0 +1,32 @@
from typing import Any, Literal
from llama_index.core.schema import Document
from pydantic import BaseModel, Field
class IngestedDoc(BaseModel):
object: Literal["ingest.document"]
doc_id: str = Field(examples=["c202d5e6-7b69-4869-81cc-dd574ee8ee11"])
doc_metadata: dict[str, Any] | None = Field(
examples=[
{
"page_label": "2",
"file_name": "Sales Report Q3 2023.pdf",
}
]
)
@staticmethod
def curate_metadata(metadata: dict[str, Any]) -> dict[str, Any]:
"""Remove unwanted metadata keys."""
for key in ["doc_id", "window", "original_text"]:
metadata.pop(key, None)
return metadata
@staticmethod
def from_document(document: Document) -> "IngestedDoc":
return IngestedDoc(
object="ingest.document",
doc_id=document.doc_id,
doc_metadata=IngestedDoc.curate_metadata(document.metadata),
)

View File

@ -0,0 +1,86 @@
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from starlette.responses import StreamingResponse
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.open_ai.openai_models import (
to_openai_sse_stream,
)
from private_gpt.server.recipes.summarize.summarize_service import SummarizeService
from private_gpt.server.utils.auth import authenticated
summarize_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])
class SummarizeBody(BaseModel):
text: str | None = None
use_context: bool = False
context_filter: ContextFilter | None = None
prompt: str | None = None
instructions: str | None = None
stream: bool = False
class SummarizeResponse(BaseModel):
summary: str
@summarize_router.post(
"/summarize",
response_model=None,
summary="Summarize",
responses={200: {"model": SummarizeResponse}},
tags=["Recipes"],
)
def summarize(
request: Request, body: SummarizeBody
) -> SummarizeResponse | StreamingResponse:
"""Given a text, the model will return a summary.
Optionally include `instructions` to influence the way the summary is generated.
If `use_context`
is set to `true`, the model will also use the content coming from the ingested
documents in the summary. The documents being used can
be filtered by their metadata using the `context_filter`.
Ingested documents metadata can be found using `/ingest/list` endpoint.
If you want all ingested documents to be used, remove `context_filter` altogether.
If `prompt` is set, it will be used as the prompt for the summarization,
otherwise the default prompt will be used.
When using `'stream': true`, the API will return data chunks following [OpenAI's
streaming model](https://platform.openai.com/docs/api-reference/chat/streaming):
```
{"id":"12345","object":"completion.chunk","created":1694268190,
"model":"private-gpt","choices":[{"index":0,"delta":{"content":"Hello"},
"finish_reason":null}]}
```
"""
service: SummarizeService = request.state.injector.get(SummarizeService)
if body.stream:
completion_gen = service.stream_summarize(
text=body.text,
instructions=body.instructions,
use_context=body.use_context,
context_filter=body.context_filter,
prompt=body.prompt,
)
return StreamingResponse(
to_openai_sse_stream(
response_generator=completion_gen,
),
media_type="text/event-stream",
)
else:
completion = service.summarize(
text=body.text,
instructions=body.instructions,
use_context=body.use_context,
context_filter=body.context_filter,
prompt=body.prompt,
)
return SummarizeResponse(
summary=completion,
)

View File

@ -0,0 +1,182 @@
from itertools import chain
from injector import inject, singleton
from llama_index.core import (
Document,
StorageContext,
SummaryIndex,
)
from llama_index.core.base.response.schema import Response, StreamingResponse
from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.response_synthesizers import ResponseMode
from llama_index.core.storage.docstore.types import RefDocInfo
from llama_index.core.types import TokenGen
from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.settings.settings import Settings
#######
# Modification par SPC
#
# Forcer à utiliser le prompt de settings.yaml et non un prompt par défaut
#
# DEFAULT_SUMMARIZE_PROMPT = (
# "Provide a comprehensive summary of the provided context information. "
# "The summary should cover all the key points and main ideas presented in "
# "the original text, while also condensing the information into a concise "
# "and easy-to-understand format. Please ensure that the summary includes "
# "relevant details and examples that support the main ideas, while avoiding "
# "any unnecessary information or repetition."
# )
from private_gpt.settings.settings import settings
DEFAULT_SUMMARIZE_PROMPT = settings().ui.default_summarization_system_prompt
#
# Fin modification par SPC
#######
@singleton
class SummarizeService:
@inject
def __init__(
self,
settings: Settings,
llm_component: LLMComponent,
node_store_component: NodeStoreComponent,
vector_store_component: VectorStoreComponent,
embedding_component: EmbeddingComponent,
) -> None:
self.settings = settings
self.llm_component = llm_component
self.node_store_component = node_store_component
self.vector_store_component = vector_store_component
self.embedding_component = embedding_component
self.storage_context = StorageContext.from_defaults(
vector_store=vector_store_component.vector_store,
docstore=node_store_component.doc_store,
index_store=node_store_component.index_store,
)
@staticmethod
def _filter_ref_docs(
ref_docs: dict[str, RefDocInfo], context_filter: ContextFilter | None
) -> list[RefDocInfo]:
if context_filter is None or not context_filter.docs_ids:
return list(ref_docs.values())
return [
ref_doc
for doc_id, ref_doc in ref_docs.items()
if doc_id in context_filter.docs_ids
]
def _summarize(
self,
use_context: bool = False,
stream: bool = False,
text: str | None = None,
instructions: str | None = None,
context_filter: ContextFilter | None = None,
prompt: str | None = None,
) -> str | TokenGen:
nodes_to_summarize = []
# Add text to summarize
if text:
text_documents = [Document(text=text)]
nodes_to_summarize += (
SentenceSplitter.from_defaults().get_nodes_from_documents(
text_documents
)
)
# Add context documents to summarize
if use_context:
# 1. Recover all ref docs
ref_docs: dict[str, RefDocInfo] | None = (
self.storage_context.docstore.get_all_ref_doc_info()
)
if ref_docs is None:
raise ValueError("No documents have been ingested yet.")
# 2. Filter documents based on context_filter (if provided)
filtered_ref_docs = self._filter_ref_docs(ref_docs, context_filter)
# 3. Get all nodes from the filtered documents
filtered_node_ids = chain.from_iterable(
[ref_doc.node_ids for ref_doc in filtered_ref_docs]
)
filtered_nodes = self.storage_context.docstore.get_nodes(
node_ids=list(filtered_node_ids),
)
nodes_to_summarize += filtered_nodes
# Create a SummaryIndex to summarize the nodes
summary_index = SummaryIndex(
nodes=nodes_to_summarize,
storage_context=StorageContext.from_defaults(), # In memory SummaryIndex
show_progress=True,
)
# Make a tree summarization query
# above the set of all candidate nodes
query_engine = summary_index.as_query_engine(
llm=self.llm_component.llm,
response_mode=ResponseMode.TREE_SUMMARIZE,
streaming=stream,
use_async=self.settings.summarize.use_async,
)
prompt = prompt or DEFAULT_SUMMARIZE_PROMPT
summarize_query = prompt + "\n" + (instructions or "")
response = query_engine.query(summarize_query)
if isinstance(response, Response):
return response.response or ""
elif isinstance(response, StreamingResponse):
return response.response_gen
else:
raise TypeError(f"The result is not of a supported type: {type(response)}")
def summarize(
self,
use_context: bool = False,
text: str | None = None,
instructions: str | None = None,
context_filter: ContextFilter | None = None,
prompt: str | None = None,
) -> str:
return self._summarize(
use_context=use_context,
stream=False,
text=text,
instructions=instructions,
context_filter=context_filter,
prompt=prompt,
) # type: ignore
def stream_summarize(
self,
use_context: bool = False,
text: str | None = None,
instructions: str | None = None,
context_filter: ContextFilter | None = None,
prompt: str | None = None,
) -> TokenGen:
return self._summarize(
use_context=use_context,
stream=True,
text=text,
instructions=instructions,
context_filter=context_filter,
prompt=prompt,
) # type: ignore

View File

@ -0,0 +1,69 @@
"""Authentication mechanism for the API.
Define a simple mechanism to authenticate requests.
More complex authentication mechanisms can be defined here, and be placed in the
`authenticated` method (being a 'bean' injected in fastapi routers).
Authorization can also be made after the authentication, and depends on
the authentication. Authorization should not be implemented in this file.
Authorization can be done by following fastapi's guides:
* https://fastapi.tiangolo.com/advanced/security/oauth2-scopes/
* https://fastapi.tiangolo.com/tutorial/security/
* https://fastapi.tiangolo.com/tutorial/dependencies/dependencies-in-path-operation-decorators/
"""
# mypy: ignore-errors
# Disabled mypy error: All conditional function variants must have identical signatures
# We are changing the implementation of the authenticated method, based on
# the config. If the auth is not enabled, we are not defining the complex method
# with its dependencies.
import logging
import secrets
from typing import Annotated
from fastapi import Depends, Header, HTTPException
from private_gpt.settings.settings import settings
# 401 signify that the request requires authentication.
# 403 signify that the authenticated user is not authorized to perform the operation.
NOT_AUTHENTICATED = HTTPException(
status_code=401,
detail="Not authenticated",
headers={"WWW-Authenticate": 'Basic realm="All the API", charset="UTF-8"'},
)
logger = logging.getLogger(__name__)
def _simple_authentication(authorization: Annotated[str, Header()] = "") -> bool:
"""Check if the request is authenticated."""
if not secrets.compare_digest(authorization, settings().server.auth.secret):
# If the "Authorization" header is not the expected one, raise an exception.
raise NOT_AUTHENTICATED
return True
if not settings().server.auth.enabled:
logger.debug(
"Defining a dummy authentication mechanism for fastapi, always authenticating requests"
)
# Define a dummy authentication method that always returns True.
def authenticated() -> bool:
"""Check if the request is authenticated."""
return True
else:
logger.info("Defining the given authentication mechanism for the API")
# Method to be used as a dependency to check if the request is authenticated.
def authenticated(
_simple_authentication: Annotated[bool, Depends(_simple_authentication)]
) -> bool:
"""Check if the request is authenticated."""
assert settings().server.auth.enabled
if not _simple_authentication:
raise NOT_AUTHENTICATED
return True

View File

@ -0,0 +1 @@
"""Settings."""

View File

@ -0,0 +1,638 @@
from typing import Any, Literal
from pydantic import BaseModel, Field
from private_gpt.settings.settings_loader import load_active_settings
class CorsSettings(BaseModel):
"""CORS configuration.
For more details on the CORS configuration, see:
# * https://fastapi.tiangolo.com/tutorial/cors/
# * https://developer.mozilla.org/en-US/docs/Web/HTTP/CORS
"""
enabled: bool = Field(
description="Flag indicating if CORS headers are set or not."
"If set to True, the CORS headers will be set to allow all origins, methods and headers.",
default=False,
)
allow_credentials: bool = Field(
description="Indicate that cookies should be supported for cross-origin requests",
default=False,
)
allow_origins: list[str] = Field(
description="A list of origins that should be permitted to make cross-origin requests.",
default=[],
)
allow_origin_regex: list[str] = Field(
description="A regex string to match against origins that should be permitted to make cross-origin requests.",
default=None,
)
allow_methods: list[str] = Field(
description="A list of HTTP methods that should be allowed for cross-origin requests.",
default=[
"GET",
],
)
allow_headers: list[str] = Field(
description="A list of HTTP request headers that should be supported for cross-origin requests.",
default=[],
)
class AuthSettings(BaseModel):
"""Authentication configuration.
The implementation of the authentication strategy must
"""
enabled: bool = Field(
description="Flag indicating if authentication is enabled or not.",
default=False,
)
secret: str = Field(
description="The secret to be used for authentication. "
"It can be any non-blank string. For HTTP basic authentication, "
"this value should be the whole 'Authorization' header that is expected"
)
class IngestionSettings(BaseModel):
"""Ingestion configuration.
This configuration is used to control the ingestion of data into the system
using non-server methods. This is useful for local development and testing;
or to ingest in bulk from a folder.
Please note that this configuration is not secure and should be used in
a controlled environment only (setting right permissions, etc.).
"""
enabled: bool = Field(
description="Flag indicating if local ingestion is enabled or not.",
default=False,
)
allow_ingest_from: list[str] = Field(
description="A list of folders that should be permitted to make ingest requests.",
default=[],
)
class ServerSettings(BaseModel):
env_name: str = Field(
description="Name of the environment (prod, staging, local...)"
)
port: int = Field(description="Port of PrivateGPT FastAPI server, defaults to 8001")
cors: CorsSettings = Field(
description="CORS configuration", default=CorsSettings(enabled=False)
)
auth: AuthSettings = Field(
description="Authentication configuration",
default_factory=lambda: AuthSettings(enabled=False, secret="secret-key"),
)
class DataSettings(BaseModel):
local_ingestion: IngestionSettings = Field(
description="Ingestion configuration",
default_factory=lambda: IngestionSettings(allow_ingest_from=["*"]),
)
local_data_folder: str = Field(
description="Path to local storage."
"It will be treated as an absolute path if it starts with /"
)
class LLMSettings(BaseModel):
mode: Literal[
"llamacpp",
"openai",
"openailike",
"azopenai",
"sagemaker",
"mock",
"ollama",
"gemini",
]
max_new_tokens: int = Field(
256,
description="The maximum number of token that the LLM is authorized to generate in one completion.",
)
context_window: int = Field(
3900,
description="The maximum number of context tokens for the model.",
)
tokenizer: str = Field(
None,
description="The model id of a predefined tokenizer hosted inside a model repo on "
"huggingface.co. Valid model ids can be located at the root-level, like "
"`bert-base-uncased`, or namespaced under a user or organization name, "
"like `HuggingFaceH4/zephyr-7b-beta`. If not set, will load a tokenizer matching "
"gpt-3.5-turbo LLM.",
)
temperature: float = Field(
0.1,
description="The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual.",
)
prompt_style: Literal["default", "llama2", "llama3", "tag", "mistral", "chatml"] = (
Field(
"llama2",
description=(
"The prompt style to use for the chat engine. "
"If `default` - use the default prompt style from the llama_index. It should look like `role: message`.\n"
"If `llama2` - use the llama2 prompt style from the llama_index. Based on `<s>`, `[INST]` and `<<SYS>>`.\n"
"If `llama3` - use the llama3 prompt style from the llama_index."
"If `tag` - use the `tag` prompt style. It should look like `<|role|>: message`. \n"
"If `mistral` - use the `mistral prompt style. It shoudl look like <s>[INST] {System Prompt} [/INST]</s>[INST] { UserInstructions } [/INST]"
"`llama2` is the historic behaviour. `default` might work better with your custom models."
),
)
)
class VectorstoreSettings(BaseModel):
database: Literal["chroma", "qdrant", "postgres", "clickhouse", "milvus"]
class NodeStoreSettings(BaseModel):
database: Literal["simple", "postgres"]
class LlamaCPPSettings(BaseModel):
llm_hf_repo_id: str
llm_hf_model_file: str
tfs_z: float = Field(
1.0,
description="Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting.",
)
top_k: int = Field(
40,
description="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)",
)
top_p: float = Field(
0.9,
description="Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)",
)
repeat_penalty: float = Field(
1.1,
description="Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)",
)
class HuggingFaceSettings(BaseModel):
embedding_hf_model_name: str = Field(
description="Name of the HuggingFace model to use for embeddings"
)
access_token: str = Field(
None,
description="Huggingface access token, required to download some models",
)
trust_remote_code: bool = Field(
False,
description="If set to True, the code from the remote model will be trusted and executed.",
)
class EmbeddingSettings(BaseModel):
mode: Literal[
"huggingface",
"openai",
"azopenai",
"sagemaker",
"ollama",
"mock",
"gemini",
"mistralai",
]
ingest_mode: Literal["simple", "batch", "parallel", "pipeline"] = Field(
"simple",
description=(
"The ingest mode to use for the embedding engine:\n"
"If `simple` - ingest files sequentially and one by one. It is the historic behaviour.\n"
"If `batch` - if multiple files, parse all the files in parallel, "
"and send them in batch to the embedding model.\n"
"In `pipeline` - The Embedding engine is kept as busy as possible\n"
"If `parallel` - parse the files in parallel using multiple cores, and embedd them in parallel.\n"
"`parallel` is the fastest mode for local setup, as it parallelize IO RW in the index.\n"
"For modes that leverage parallelization, you can specify the number of "
"workers to use with `count_workers`.\n"
),
)
count_workers: int = Field(
2,
description=(
"The number of workers to use for file ingestion.\n"
"In `batch` mode, this is the number of workers used to parse the files.\n"
"In `parallel` mode, this is the number of workers used to parse the files and embed them.\n"
"In `pipeline` mode, this is the number of workers that can perform embeddings.\n"
"This is only used if `ingest_mode` is not `simple`.\n"
"Do not go too high with this number, as it might cause memory issues. (especially in `parallel` mode)\n"
"Do not set it higher than your number of threads of your CPU."
),
)
embed_dim: int = Field(
384,
description="The dimension of the embeddings stored in the Postgres database",
)
class SagemakerSettings(BaseModel):
llm_endpoint_name: str
embedding_endpoint_name: str
class OpenAISettings(BaseModel):
api_base: str = Field(
None,
description="Base URL of OpenAI API. Example: 'https://api.openai.com/v1'.",
)
api_key: str
model: str = Field(
"gpt-3.5-turbo",
description="OpenAI Model to use. Example: 'gpt-4'.",
)
request_timeout: float = Field(
120.0,
description="Time elapsed until openailike server times out the request. Default is 120s. Format is float. ",
)
embedding_api_base: str = Field(
None,
description="Base URL of OpenAI API. Example: 'https://api.openai.com/v1'.",
)
embedding_api_key: str
embedding_model: str = Field(
"text-embedding-ada-002",
description="OpenAI embedding Model to use. Example: 'text-embedding-3-large'.",
)
class GeminiSettings(BaseModel):
api_key: str
model: str = Field(
"models/gemini-pro",
description="Google Model to use. Example: 'models/gemini-pro'.",
)
embedding_model: str = Field(
"models/embedding-001",
description="Google Embedding Model to use. Example: 'models/embedding-001'.",
)
class OllamaSettings(BaseModel):
api_base: str = Field(
"http://localhost:11434",
description="Base URL of Ollama API. Example: 'https://localhost:11434'.",
)
embedding_api_base: str = Field(
"http://localhost:11434",
description="Base URL of Ollama embedding API. Example: 'https://localhost:11434'.",
)
llm_model: str = Field(
None,
description="Model to use. Example: 'llama2-uncensored'.",
)
embedding_model: str = Field(
None,
description="Model to use. Example: 'nomic-embed-text'.",
)
keep_alive: str = Field(
"5m",
description="Time the model will stay loaded in memory after a request. examples: 5m, 5h, '-1' ",
)
tfs_z: float = Field(
1.0,
description="Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting.",
)
num_predict: int = Field(
None,
description="Maximum number of tokens to predict when generating text. (Default: 128, -1 = infinite generation, -2 = fill context)",
)
top_k: int = Field(
40,
description="Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)",
)
top_p: float = Field(
0.9,
description="Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)",
)
repeat_last_n: int = Field(
64,
description="Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)",
)
repeat_penalty: float = Field(
1.1,
description="Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)",
)
request_timeout: float = Field(
120.0,
description="Time elapsed until ollama times out the request. Default is 120s. Format is float. ",
)
autopull_models: bool = Field(
False,
description="If set to True, the Ollama will automatically pull the models from the API base.",
)
class AzureOpenAISettings(BaseModel):
api_key: str
azure_endpoint: str
api_version: str = Field(
"2023_05_15",
description="The API version to use for this operation. This follows the YYYY-MM-DD format.",
)
embedding_deployment_name: str
embedding_model: str = Field(
"text-embedding-ada-002",
description="OpenAI Model to use. Example: 'text-embedding-ada-002'.",
)
llm_deployment_name: str
llm_model: str = Field(
"gpt-35-turbo",
description="OpenAI Model to use. Example: 'gpt-4'.",
)
class UISettings(BaseModel):
enabled: bool
path: str
default_mode: Literal["RAG", "Search", "Basic", "Summarize"] = Field(
"RAG",
description="The default mode.",
)
default_chat_system_prompt: str = Field(
None,
description="The default system prompt to use for the chat mode.",
)
default_query_system_prompt: str = Field(
None, description="The default system prompt to use for the query mode."
)
default_summarization_system_prompt: str = Field(
None,
description="The default system prompt to use for the summarization mode.",
)
delete_file_button_enabled: bool = Field(
True, description="If the button to delete a file is enabled or not."
)
delete_all_files_button_enabled: bool = Field(
False, description="If the button to delete all files is enabled or not."
)
class RerankSettings(BaseModel):
enabled: bool = Field(
False,
description="This value controls whether a reranker should be included in the RAG pipeline.",
)
model: str = Field(
"cross-encoder/ms-marco-MiniLM-L-2-v2",
description="Rerank model to use. Limited to SentenceTransformer cross-encoder models.",
)
top_n: int = Field(
2,
description="This value controls the number of documents returned by the RAG pipeline.",
)
class RagSettings(BaseModel):
similarity_top_k: int = Field(
2,
description="This value controls the number of documents returned by the RAG pipeline or considered for reranking if enabled.",
)
similarity_value: float = Field(
None,
description="If set, any documents retrieved from the RAG must meet a certain match score. Acceptable values are between 0 and 1.",
)
rerank: RerankSettings
class SummarizeSettings(BaseModel):
use_async: bool = Field(
True,
description="If set to True, the summarization will be done asynchronously.",
)
class ClickHouseSettings(BaseModel):
host: str = Field(
"localhost",
description="The server hosting the ClickHouse database",
)
port: int = Field(
8443,
description="The port on which the ClickHouse database is accessible",
)
username: str = Field(
"default",
description="The username to use to connect to the ClickHouse database",
)
password: str = Field(
"",
description="The password to use to connect to the ClickHouse database",
)
database: str = Field(
"__default__",
description="The default database to use for connections",
)
secure: bool | str = Field(
False,
description="Use https/TLS for secure connection to the server",
)
interface: str | None = Field(
None,
description="Must be either 'http' or 'https'. Determines the protocol to use for the connection",
)
settings: dict[str, Any] | None = Field(
None,
description="Specific ClickHouse server settings to be used with the session",
)
connect_timeout: int | None = Field(
None,
description="Timeout in seconds for establishing a connection",
)
send_receive_timeout: int | None = Field(
None,
description="Read timeout in seconds for http connection",
)
verify: bool | None = Field(
None,
description="Verify the server certificate in secure/https mode",
)
ca_cert: str | None = Field(
None,
description="Path to Certificate Authority root certificate (.pem format)",
)
client_cert: str | None = Field(
None,
description="Path to TLS Client certificate (.pem format)",
)
client_cert_key: str | None = Field(
None,
description="Path to the private key for the TLS Client certificate",
)
http_proxy: str | None = Field(
None,
description="HTTP proxy address",
)
https_proxy: str | None = Field(
None,
description="HTTPS proxy address",
)
server_host_name: str | None = Field(
None,
description="Server host name to be checked against the TLS certificate",
)
class PostgresSettings(BaseModel):
host: str = Field(
"localhost",
description="The server hosting the Postgres database",
)
port: int = Field(
5432,
description="The port on which the Postgres database is accessible",
)
user: str = Field(
"postgres",
description="The user to use to connect to the Postgres database",
)
password: str = Field(
"postgres",
description="The password to use to connect to the Postgres database",
)
database: str = Field(
"postgres",
description="The database to use to connect to the Postgres database",
)
schema_name: str = Field(
"public",
description="The name of the schema in the Postgres database to use",
)
class QdrantSettings(BaseModel):
location: str | None = Field(
None,
description=(
"If `:memory:` - use in-memory Qdrant instance.\n"
"If `str` - use it as a `url` parameter.\n"
),
)
url: str | None = Field(
None,
description=(
"Either host or str of 'Optional[scheme], host, Optional[port], Optional[prefix]'."
),
)
port: int | None = Field(6333, description="Port of the REST API interface.")
grpc_port: int | None = Field(6334, description="Port of the gRPC interface.")
prefer_grpc: bool | None = Field(
False,
description="If `true` - use gRPC interface whenever possible in custom methods.",
)
https: bool | None = Field(
None,
description="If `true` - use HTTPS(SSL) protocol.",
)
api_key: str | None = Field(
None,
description="API key for authentication in Qdrant Cloud.",
)
prefix: str | None = Field(
None,
description=(
"Prefix to add to the REST URL path."
"Example: `service/v1` will result in "
"'http://localhost:6333/service/v1/{qdrant-endpoint}' for REST API."
),
)
timeout: float | None = Field(
None,
description="Timeout for REST and gRPC API requests.",
)
host: str | None = Field(
None,
description="Host name of Qdrant service. If url and host are None, set to 'localhost'.",
)
path: str | None = Field(None, description="Persistence path for QdrantLocal.")
force_disable_check_same_thread: bool | None = Field(
True,
description=(
"For QdrantLocal, force disable check_same_thread. Default: `True`"
"Only use this if you can guarantee that you can resolve the thread safety outside QdrantClient."
),
)
class MilvusSettings(BaseModel):
uri: str = Field(
"local_data/private_gpt/milvus/milvus_local.db",
description="The URI of the Milvus instance. For example: 'local_data/private_gpt/milvus/milvus_local.db' for Milvus Lite.",
)
token: str = Field(
"",
description=(
"A valid access token to access the specified Milvus instance. "
"This can be used as a recommended alternative to setting user and password separately. "
),
)
collection_name: str = Field(
"make_this_parameterizable_per_api_call",
description="The name of the collection in Milvus. Default is 'make_this_parameterizable_per_api_call'.",
)
overwrite: bool = Field(
True, description="Overwrite the previous collection schema if it exists."
)
class Settings(BaseModel):
server: ServerSettings
data: DataSettings
ui: UISettings
llm: LLMSettings
embedding: EmbeddingSettings
llamacpp: LlamaCPPSettings
huggingface: HuggingFaceSettings
sagemaker: SagemakerSettings
openai: OpenAISettings
gemini: GeminiSettings
ollama: OllamaSettings
azopenai: AzureOpenAISettings
vectorstore: VectorstoreSettings
nodestore: NodeStoreSettings
rag: RagSettings
summarize: SummarizeSettings
qdrant: QdrantSettings | None = None
postgres: PostgresSettings | None = None
clickhouse: ClickHouseSettings | None = None
milvus: MilvusSettings | None = None
"""
This is visible just for DI or testing purposes.
Use dependency injection or `settings()` method instead.
"""
unsafe_settings = load_active_settings()
"""
This is visible just for DI or testing purposes.
Use dependency injection or `settings()` method instead.
"""
unsafe_typed_settings = Settings(**unsafe_settings)
def settings() -> Settings:
"""Get the current loaded settings from the DI container.
This method exists to keep compatibility with the existing code,
that require global access to the settings.
For regular components use dependency injection instead.
"""
from private_gpt.di import global_injector
return global_injector.get(Settings)

View File

@ -0,0 +1,57 @@
import functools
import logging
import os
import sys
from collections.abc import Iterable
from pathlib import Path
from typing import Any
from pydantic.v1.utils import deep_update, unique_list
from private_gpt.constants import PROJECT_ROOT_PATH
from private_gpt.settings.yaml import load_yaml_with_envvars
logger = logging.getLogger(__name__)
_settings_folder = os.environ.get("PGPT_SETTINGS_FOLDER", PROJECT_ROOT_PATH)
# if running in unittest, use the test profile
_test_profile = ["test"] if "tests.fixtures" in sys.modules else []
active_profiles: list[str] = unique_list(
["default"]
+ [
item.strip()
for item in os.environ.get("PGPT_PROFILES", "").split(",")
if item.strip()
]
+ _test_profile
)
def merge_settings(settings: Iterable[dict[str, Any]]) -> dict[str, Any]:
return functools.reduce(deep_update, settings, {})
def load_settings_from_profile(profile: str) -> dict[str, Any]:
if profile == "default":
profile_file_name = "settings.yaml"
else:
profile_file_name = f"settings-{profile}.yaml"
path = Path(_settings_folder) / profile_file_name
with Path(path).open("r") as f:
config = load_yaml_with_envvars(f)
if not isinstance(config, dict):
raise TypeError(f"Config file has no top-level mapping: {path}")
return config
def load_active_settings() -> dict[str, Any]:
"""Load active profiles and merge them."""
logger.info("Starting application with profiles=%s", active_profiles)
loaded_profiles = [
load_settings_from_profile(profile) for profile in active_profiles
]
merged: dict[str, Any] = merge_settings(loaded_profiles)
return merged

View File

@ -0,0 +1,41 @@
import os
import re
import typing
from typing import Any, TextIO
from yaml import SafeLoader
_env_replace_matcher = re.compile(r"\$\{(\w|_)+:?.*}")
@typing.no_type_check # pyaml does not have good hints, everything is Any
def load_yaml_with_envvars(
stream: TextIO, environ: dict[str, Any] = os.environ
) -> dict[str, Any]:
"""Load yaml file with environment variable expansion.
The pattern ${VAR} or ${VAR:default} will be replaced with
the value of the environment variable.
"""
loader = SafeLoader(stream)
def load_env_var(_, node) -> str:
"""Extract the matched value, expand env variable, and replace the match."""
value = str(node.value).removeprefix("${").removesuffix("}")
split = value.split(":", 1)
env_var = split[0]
value = environ.get(env_var)
default = None if len(split) == 1 else split[1]
if value is None and default is None:
raise ValueError(
f"Environment variable {env_var} is not set and not default was provided"
)
return value or default
loader.add_implicit_resolver("env_var_replacer", _env_replace_matcher, None)
loader.add_constructor("env_var_replacer", load_env_var)
try:
return loader.get_single_data()
finally:
loader.dispose()

View File

@ -0,0 +1 @@
"""Gradio based UI."""

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

View File

@ -0,0 +1 @@
logo_svg = "data:image/svg+xml;base64,PHN2ZyB3aWR0aD0iODYxIiBoZWlnaHQ9Ijk4IiB2aWV3Qm94PSIwIDAgODYxIDk4IiBmaWxsPSJub25lIiB4bWxucz0iaHR0cDovL3d3dy53My5vcmcvMjAwMC9zdmciPgo8cGF0aCBkPSJNNDguMTM0NSAwLjE1NzkxMUMzNi44Mjk5IDEuMDM2NTQgMjYuMTIwNSA1LjU1MzI4IDE3LjYyNTYgMTMuMDI1QzkuMTMwNDYgMjAuNDk2NyAzLjMxMTcgMzAuNTE2OSAxLjA0OTUyIDQxLjU3MDVDLTEuMjEyNzMgNTIuNjIzOCAwLjIwNDQxOSA2NC4xMDk0IDUuMDg2MiA3NC4yOTA1QzkuOTY4NjggODQuNDcxNiAxOC4wNTAzIDkyLjc5NDMgMjguMTA5OCA5OEwzMy43MDI2IDgyLjU5MDdMMzUuNDU0MiA3Ny43NjU2QzI5LjgzODcgNzQuMTY5MiAyNS41NDQ0IDY4Ljg2MDcgMjMuMjE0IDYyLjYzNDRDMjAuODgyMiA1Ni40MDg2IDIwLjYzOSA0OS41OTkxIDIyLjUyMDQgNDMuMjI0M0MyNC40MDI5IDM2Ljg0OTUgMjguMzA5NiAzMS4yNTI1IDMzLjY1NjEgMjcuMjcwNkMzOS4wMDIgMjMuMjg4MyA0NS41MDAzIDIxLjEzNSA1Mi4xNzg5IDIxLjEzM0M1OC44NTczIDIxLjEzMDMgNjUuMzU3MSAyMy4yNzgzIDcwLjcwNjUgMjcuMjU1OEM3Ni4wNTU0IDMxLjIzNCA3OS45NjY0IDM2LjgyNzcgODEuODU0MyA0My4yMDA2QzgzLjc0MjkgNDkuNTczNiA4My41MDYyIDU2LjM4MzYgODEuMTgwMSA2Mi42MTE3Qzc4Ljg1NDUgNjguODM5NiA3NC41NjUgNzQuMTUxNCA2OC45NTI5IDc3Ljc1MjhMNzAuNzA3NCA4Mi41OTA3TDc2LjMwMDIgOTcuOTk3MUM4Ni45Nzg4IDkyLjQ3MDUgOTUuNDA4OCA4My40NDE5IDEwMC4xNjMgNzIuNDQwNEMxMDQuOTE3IDYxLjQzOTQgMTA1LjcwNCA0OS4xNDE3IDEwMi4zODkgMzcuNjNDOTkuMDc0NiAyNi4xMTc5IDkxLjg2MjcgMTYuMDk5MyA4MS45NzQzIDkuMjcwNzlDNzIuMDg2MSAyLjQ0MTkxIDYwLjEyOTEgLTAuNzc3MDg2IDQ4LjEyODYgMC4xNTg5MzRMNDguMTM0NSAwLjE1NzkxMVoiIGZpbGw9IiMxRjFGMjkiLz4KPGcgY2xpcC1wYXRoPSJ1cmwoI2NsaXAwXzVfMTkpIj4KPHBhdGggZD0iTTIyMC43NzIgMTIuNzUyNEgyNTIuNjM5QzI2Ny4yNjMgMTIuNzUyNCAyNzcuNzM5IDIxLjk2NzUgMjc3LjczOSAzNS40MDUyQzI3Ny43MzkgNDYuNzg3IDI2OS44ODEgNTUuMzUwOCAyNTguMzE0IDU3LjQxMDdMMjc4LjgzIDg1LjM3OTRIMjYxLjM3TDI0Mi4wNTQgNTcuOTUzM0gyMzUuNTA2Vjg1LjM3OTRIMjIwLjc3NEwyMjAuNzcyIDEyLjc1MjRaTTIzNS41MDQgMjYuMzAyOFY0NC40MDdIMjUyLjYzMkMyNTguOTYyIDQ0LjQwNyAyNjIuOTk5IDQwLjgyOTggMjYyLjk5OSAzNS40MTAyQzI2Mi45OTkgMjkuODgwOSAyNTguOTYyIDI2LjMwMjggMjUyLjYzMiAyNi4zMDI4SDIzNS41MDRaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik0yOTUuMTc2IDg1LjM4NDRWMTIuNzUyNEgzMDkuOTA5Vjg1LjM4NDRIMjk1LjE3NloiIGZpbGw9IiMxRjFGMjkiLz4KPHBhdGggZD0iTTM2My43OTUgNjUuNzYzTDM4NS42MiAxMi43NTI0SDQwMS40NDRMMzcxLjIxNSA4NS4zODQ0SDM1Ni40ODNMMzI2LjI1NCAxMi43NTI0SDM0Mi4wNzhMMzYzLjc5NSA2NS43NjNaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik00NDguMzI3IDcyLjA1MDRINDE1LjY5OEw0MTAuMjQxIDg1LjM4NDRIMzk0LjQxOEw0MjQuNjQ3IDEyLjc1MjRINDM5LjM3OUw0NjkuNjA4IDg1LjM4NDRINDUzLjc4M0w0NDguMzI3IDcyLjA1MDRaTTQ0Mi43NjEgNTguNUw0MzIuMDY2IDMyLjM3NDhMNDIxLjI2MiA1OC41SDQ0Mi43NjFaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik00NjUuMjIxIDEyLjc1MjRINTMwLjU5MlYyNi4zMDI4SDUwNS4yNzVWODUuMzg0NEg0OTAuNTM5VjI2LjMwMjhINDY1LjIyMVYxMi43NTI0WiIgZmlsbD0iIzFGMUYyOSIvPgo8cGF0aCBkPSJNNTk1LjE5MyAxMi43NTI0VjI2LjMwMjhINTYyLjEyOFY0MS4xNTUxSDU5NS4xOTNWNTQuNzA2NUg1NjIuMTI4VjcxLjgzNEg1OTUuMTkzVjg1LjM4NDRINTQ3LjM5NVYxMi43NTI0SDU5NS4xOTNaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik0xNjcuMjAxIDU3LjQxNThIMTg2LjUzNkMxOTAuODg2IDU3LjQ2NjIgMTk1LjE2OCA1Ni4zMzQ4IDE5OC45MTggNTQuMTQzN0MyMDIuMTc5IDUyLjIxOTkgMjA0Ljg2OSA0OS40NzM2IDIwNi43MTYgNDYuMTgzNUMyMDguNTYyIDQyLjg5MzQgMjA5LjUgMzkuMTc2NiAyMDkuNDMzIDM1LjQxMDJDMjA5LjQzMyAyMS45Njc1IDE5OC45NTggMTIuNzU3NCAxODQuMzM0IDEyLjc1NzRIMTUyLjQ2OFY4NS4zODk0SDE2Ny4yMDFWNTcuNDIwN1Y1Ny40MTU4Wk0xNjcuMjAxIDI2LjMwNThIMTg0LjMyOUMxOTAuNjU4IDI2LjMwNTggMTk0LjY5NiAyOS44ODQgMTk0LjY5NiAzNS40MTMzQzE5NC42OTYgNDAuODMyOSAxOTAuNjU4IDQ0LjQwOTkgMTg0LjMyOSA0NC40MDk5SDE2Ny4yMDFWMjYuMzA1OFoiIGZpbGw9IiMxRjFGMjkiLz4KPHBhdGggZD0iTTc5NC44MzUgMTIuNzUyNEg4NjAuMjA2VjI2LjMwMjhIODM0Ljg4OVY4NS4zODQ0SDgyMC4xNTZWMjYuMzAyOEg3OTQuODM1VjEyLjc1MjRaIiBmaWxsPSIjMUYxRjI5Ii8+CjxwYXRoIGQ9Ik03NDEuOTA3IDU3LjQxNThINzYxLjI0MUM3NjUuNTkyIDU3LjQ2NjEgNzY5Ljg3NCA1Ni4zMzQ3IDc3My42MjQgNTQuMTQzN0M3NzYuODg0IDUyLjIxOTkgNzc5LjU3NSA0OS40NzM2IDc4MS40MjEgNDYuMTgzNUM3ODMuMjY4IDQyLjg5MzQgNzg0LjIwNiAzOS4xNzY2IDc4NC4xMzkgMzUuNDEwMkM3ODQuMTM5IDIxLjk2NzUgNzczLjY2NCAxMi43NTc0IDc1OS4wMzkgMTIuNzU3NEg3MjcuMTc1Vjg1LjM4OTRINzQxLjkwN1Y1Ny40MjA3VjU3LjQxNThaTTc0MS45MDcgMjYuMzA1OEg3NTkuMDM1Qzc2NS4zNjUgMjYuMzA1OCA3NjkuNDAzIDI5Ljg4NCA3NjkuNDAzIDM1LjQxMzNDNzY5LjQwMyA0MC44MzI5IDc2NS4zNjUgNDQuNDA5OSA3NTkuMDM1IDQ0LjQwOTlINzQxLjkwN1YyNi4zMDU4WiIgZmlsbD0iIzFGMUYyOSIvPgo8cGF0aCBkPSJNNjgxLjA2OSA0Ny4wMTE1VjU5LjAxMjVINjk1LjM3OVY3MS42NzE5QzY5Mi41MjYgNzMuNDM2OCA2ODguNTI0IDc0LjMzMTkgNjgzLjQ3NyA3NC4zMzE5QzY2Ni4wMDMgNzQuMzMxOSA2NTguMDQ1IDYxLjgxMjQgNjU4LjA0NSA1MC4xOEM2NTguMDQ1IDMzLjk2MDUgNjcxLjAwOCAyNS40NzMyIDY4My44MTIgMjUuNDczMkM2OTAuNDI1IDI1LjQ2MjggNjk2LjkwOSAyNy4yODA0IDcwMi41NDEgMzAuNzIyNkw3MDMuMTU3IDMxLjEyNTRMNzA1Ljk1OCAxOC4xODZMNzA1LjY2MyAxNy45OTc3QzcwMC4wNDYgMTQuNDAwNCA2OTEuMjkxIDEyLjI1OSA2ODIuMjUxIDEyLjI1OUM2NjMuMTk3IDEyLjI1OSA2NDIuOTQ5IDI1LjM5NjcgNjQyLjk0OSA0OS43NDVDNjQyLjk0OSA2MS4wODQ1IDY0Ny4yOTMgNzAuNzE3NCA2NTUuNTExIDc3LjYwMjlDNjYzLjIyNCA4My44MjQ1IDY3Mi44NzQgODcuMTg5IDY4Mi44MDkgODcuMTIwMUM2OTQuMzYzIDg3LjEyMDEgNzAzLjA2MSA4NC42NDk1IDcwOS40MDIgNzkuNTY5Mkw3MDkuNTg5IDc5LjQxODFWNDcuMDExNUg2ODEuMDY5WiIgZmlsbD0iIzFGMUYyOSIvPgo8L2c+CjxkZWZzPgo8Y2xpcFBhdGggaWQ9ImNsaXAwXzVfMTkiPgo8cmVjdCB3aWR0aD0iNzA3Ljc3OCIgaGVpZ2h0PSI3NC44NjExIiBmaWxsPSJ3aGl0ZSIgdHJhbnNmb3JtPSJ0cmFuc2xhdGUoMTUyLjQ0NCAxMi4yNSkiLz4KPC9jbGlwUGF0aD4KPC9kZWZzPgo8L3N2Zz4K"

596
pgpt/private_gpt/ui/ui.py Normal file
View File

@ -0,0 +1,596 @@
"""This file should be imported if and only if you want to run the UI locally."""
import base64
import logging
import time
from collections.abc import Iterable
from enum import Enum
from pathlib import Path
from typing import Any
import gradio as gr # type: ignore
from fastapi import FastAPI
from gradio.themes.utils.colors import slate # type: ignore
from injector import inject, singleton
from llama_index.core.llms import ChatMessage, ChatResponse, MessageRole
from llama_index.core.types import TokenGen
from pydantic import BaseModel
from private_gpt.constants import PROJECT_ROOT_PATH
from private_gpt.di import global_injector
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.chat.chat_service import ChatService, CompletionGen
from private_gpt.server.chunks.chunks_service import Chunk, ChunksService
from private_gpt.server.ingest.ingest_service import IngestService
from private_gpt.server.recipes.summarize.summarize_service import SummarizeService
from private_gpt.settings.settings import settings
from private_gpt.ui.images import logo_svg
logger = logging.getLogger(__name__)
THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)
# Should be "private_gpt/ui/avatar-bot.ico"
AVATAR_BOT = THIS_DIRECTORY_RELATIVE / "avatar-bot.ico"
UI_TAB_TITLE = "My Private GPT"
SOURCES_SEPARATOR = "<hr>Sources: \n"
class Modes(str, Enum):
RAG_MODE = "RAG"
SEARCH_MODE = "Search"
BASIC_CHAT_MODE = "Basic"
SUMMARIZE_MODE = "Summarize"
MODES: list[Modes] = [
Modes.RAG_MODE,
Modes.SEARCH_MODE,
Modes.BASIC_CHAT_MODE,
Modes.SUMMARIZE_MODE,
]
class Source(BaseModel):
file: str
page: str
text: str
class Config:
frozen = True
@staticmethod
def curate_sources(sources: list[Chunk]) -> list["Source"]:
curated_sources = []
for chunk in sources:
doc_metadata = chunk.document.doc_metadata
file_name = doc_metadata.get("file_name", "-") if doc_metadata else "-"
page_label = doc_metadata.get("page_label", "-") if doc_metadata else "-"
source = Source(file=file_name, page=page_label, text=chunk.text)
curated_sources.append(source)
curated_sources = list(
dict.fromkeys(curated_sources).keys()
) # Unique sources only
return curated_sources
@singleton
class PrivateGptUi:
@inject
def __init__(
self,
ingest_service: IngestService,
chat_service: ChatService,
chunks_service: ChunksService,
summarizeService: SummarizeService,
) -> None:
self._ingest_service = ingest_service
self._chat_service = chat_service
self._chunks_service = chunks_service
self._summarize_service = summarizeService
# Cache the UI blocks
self._ui_block = None
self._selected_filename = None
# Initialize system prompt based on default mode
default_mode_map = {mode.value: mode for mode in Modes}
self._default_mode = default_mode_map.get(
settings().ui.default_mode, Modes.RAG_MODE
)
self._system_prompt = self._get_default_system_prompt(self._default_mode)
def _chat(
self, message: str, history: list[list[str]], mode: Modes, *_: Any
) -> Any:
def yield_deltas(completion_gen: CompletionGen) -> Iterable[str]:
full_response: str = ""
stream = completion_gen.response
for delta in stream:
if isinstance(delta, str):
full_response += str(delta)
elif isinstance(delta, ChatResponse):
full_response += delta.delta or ""
yield full_response
time.sleep(0.02)
if completion_gen.sources:
full_response += SOURCES_SEPARATOR
cur_sources = Source.curate_sources(completion_gen.sources)
sources_text = "\n\n\n"
used_files = set()
for index, source in enumerate(cur_sources, start=1):
if f"{source.file}-{source.page}" not in used_files:
sources_text = (
sources_text
+ f"{index}. {source.file} (page {source.page}) \n\n"
)
used_files.add(f"{source.file}-{source.page}")
sources_text += "<hr>\n\n"
full_response += sources_text
yield full_response
def yield_tokens(token_gen: TokenGen) -> Iterable[str]:
full_response: str = ""
for token in token_gen:
full_response += str(token)
yield full_response
def build_history() -> list[ChatMessage]:
history_messages: list[ChatMessage] = []
for interaction in history:
history_messages.append(
ChatMessage(content=interaction[0], role=MessageRole.USER)
)
if len(interaction) > 1 and interaction[1] is not None:
history_messages.append(
ChatMessage(
# Remove from history content the Sources information
content=interaction[1].split(SOURCES_SEPARATOR)[0],
role=MessageRole.ASSISTANT,
)
)
# max 20 messages to try to avoid context overflow
return history_messages[:20]
new_message = ChatMessage(content=message, role=MessageRole.USER)
all_messages = [*build_history(), new_message]
# If a system prompt is set, add it as a system message
if self._system_prompt:
all_messages.insert(
0,
ChatMessage(
content=self._system_prompt,
role=MessageRole.SYSTEM,
),
)
match mode:
case Modes.RAG_MODE:
# Use only the selected file for the query
context_filter = None
if self._selected_filename is not None:
docs_ids = []
for ingested_document in self._ingest_service.list_ingested():
if (
ingested_document.doc_metadata["file_name"]
== self._selected_filename
):
docs_ids.append(ingested_document.doc_id)
context_filter = ContextFilter(docs_ids=docs_ids)
query_stream = self._chat_service.stream_chat(
messages=all_messages,
use_context=True,
context_filter=context_filter,
)
yield from yield_deltas(query_stream)
case Modes.BASIC_CHAT_MODE:
llm_stream = self._chat_service.stream_chat(
messages=all_messages,
use_context=False,
)
yield from yield_deltas(llm_stream)
case Modes.SEARCH_MODE:
response = self._chunks_service.retrieve_relevant(
text=message, limit=4, prev_next_chunks=0
)
sources = Source.curate_sources(response)
yield "\n\n\n".join(
f"{index}. **{source.file} "
f"(page {source.page})**\n "
f"{source.text}"
for index, source in enumerate(sources, start=1)
)
case Modes.SUMMARIZE_MODE:
# Summarize the given message, optionally using selected files
context_filter = None
if self._selected_filename:
docs_ids = []
for ingested_document in self._ingest_service.list_ingested():
if (
ingested_document.doc_metadata["file_name"]
== self._selected_filename
):
docs_ids.append(ingested_document.doc_id)
context_filter = ContextFilter(docs_ids=docs_ids)
summary_stream = self._summarize_service.stream_summarize(
use_context=True,
context_filter=context_filter,
instructions=message,
)
yield from yield_tokens(summary_stream)
# On initialization and on mode change, this function set the system prompt
# to the default prompt based on the mode (and user settings).
@staticmethod
def _get_default_system_prompt(mode: Modes) -> str:
p = ""
match mode:
# For query chat mode, obtain default system prompt from settings
case Modes.RAG_MODE:
p = settings().ui.default_query_system_prompt
# For chat mode, obtain default system prompt from settings
case Modes.BASIC_CHAT_MODE:
p = settings().ui.default_chat_system_prompt
# For summarization mode, obtain default system prompt from settings
case Modes.SUMMARIZE_MODE:
p = settings().ui.default_summarization_system_prompt
# For any other mode, clear the system prompt
case _:
p = ""
return p
@staticmethod
def _get_default_mode_explanation(mode: Modes) -> str:
match mode:
case Modes.RAG_MODE:
return "Get contextualized answers from selected files."
case Modes.SEARCH_MODE:
return "Find relevant chunks of text in selected files."
case Modes.BASIC_CHAT_MODE:
return "Chat with the LLM using its training data. Files are ignored."
case Modes.SUMMARIZE_MODE:
#######
# Modification par SPC
#
# return "Generate a summary of the selected files. Prompt to customize the result."
return "Générer l'analyse selon les conditions présentées dans le prompt système."
#
# Fin modification par SPC
#######
case _:
return ""
def _set_system_prompt(self, system_prompt_input: str) -> None:
logger.info(f"Setting system prompt to: {system_prompt_input}")
self._system_prompt = system_prompt_input
def _set_explanatation_mode(self, explanation_mode: str) -> None:
self._explanation_mode = explanation_mode
def _set_current_mode(self, mode: Modes) -> Any:
self.mode = mode
self._set_system_prompt(self._get_default_system_prompt(mode))
self._set_explanatation_mode(self._get_default_mode_explanation(mode))
interactive = self._system_prompt is not None
return [
gr.update(placeholder=self._system_prompt, interactive=interactive),
gr.update(value=self._explanation_mode),
]
def _list_ingested_files(self) -> list[list[str]]:
files = set()
for ingested_document in self._ingest_service.list_ingested():
if ingested_document.doc_metadata is None:
# Skipping documents without metadata
continue
file_name = ingested_document.doc_metadata.get(
"file_name", "[FILE NAME MISSING]"
)
files.add(file_name)
return [[row] for row in files]
def _upload_file(self, files: list[str]) -> None:
logger.debug("Loading count=%s files", len(files))
paths = [Path(file) for file in files]
# remove all existing Documents with name identical to a new file upload:
file_names = [path.name for path in paths]
doc_ids_to_delete = []
for ingested_document in self._ingest_service.list_ingested():
if (
ingested_document.doc_metadata
and ingested_document.doc_metadata["file_name"] in file_names
):
doc_ids_to_delete.append(ingested_document.doc_id)
if len(doc_ids_to_delete) > 0:
logger.info(
"Uploading file(s) which were already ingested: %s document(s) will be replaced.",
len(doc_ids_to_delete),
)
for doc_id in doc_ids_to_delete:
self._ingest_service.delete(doc_id)
self._ingest_service.bulk_ingest([(str(path.name), path) for path in paths])
def _delete_all_files(self) -> Any:
ingested_files = self._ingest_service.list_ingested()
logger.debug("Deleting count=%s files", len(ingested_files))
for ingested_document in ingested_files:
self._ingest_service.delete(ingested_document.doc_id)
return [
gr.List(self._list_ingested_files()),
gr.components.Button(interactive=False),
gr.components.Button(interactive=False),
gr.components.Textbox("All files"),
]
def _delete_selected_file(self) -> Any:
logger.debug("Deleting selected %s", self._selected_filename)
# Note: keep looping for pdf's (each page became a Document)
for ingested_document in self._ingest_service.list_ingested():
if (
ingested_document.doc_metadata
and ingested_document.doc_metadata["file_name"]
== self._selected_filename
):
self._ingest_service.delete(ingested_document.doc_id)
return [
gr.List(self._list_ingested_files()),
gr.components.Button(interactive=False),
gr.components.Button(interactive=False),
gr.components.Textbox("All files"),
]
def _deselect_selected_file(self) -> Any:
self._selected_filename = None
return [
gr.components.Button(interactive=False),
gr.components.Button(interactive=False),
gr.components.Textbox("All files"),
]
def _selected_a_file(self, select_data: gr.SelectData) -> Any:
self._selected_filename = select_data.value
return [
gr.components.Button(interactive=True),
gr.components.Button(interactive=True),
gr.components.Textbox(self._selected_filename),
]
def _build_ui_blocks(self) -> gr.Blocks:
logger.debug("Creating the UI blocks")
with gr.Blocks(
title=UI_TAB_TITLE,
theme=gr.themes.Soft(primary_hue=slate),
css=".logo { "
"display:flex;"
"background-color: #C7BAFF;"
"height: 80px;"
"border-radius: 8px;"
"align-content: center;"
"justify-content: center;"
"align-items: center;"
"}"
".logo img { height: 25% }"
".contain { display: flex !important; flex-direction: column !important; }"
"#component-0, #component-3, #component-10, #component-8 { height: 100% !important; }"
"#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
"#col { height: calc(100vh - 112px - 16px) !important; }"
"hr { margin-top: 1em; margin-bottom: 1em; border: 0; border-top: 1px solid #FFF; }"
".avatar-image { background-color: antiquewhite; border-radius: 2px; }"
".footer { text-align: center; margin-top: 20px; font-size: 14px; display: flex; align-items: center; justify-content: center; }"
".footer-zylon-link { display:flex; margin-left: 5px; text-decoration: auto; color: var(--body-text-color); }"
".footer-zylon-link:hover { color: #C7BAFF; }"
".footer-zylon-ico { height: 20px; margin-left: 5px; background-color: antiquewhite; border-radius: 2px; }",
) as blocks:
with gr.Row():
gr.HTML(f"<div class='logo'/><img src={logo_svg} alt=PrivateGPT></div")
with gr.Row(equal_height=False):
with gr.Column(scale=3):
default_mode = self._default_mode
mode = gr.Radio(
[mode.value for mode in MODES],
label="Mode",
value=default_mode,
)
explanation_mode = gr.Textbox(
placeholder=self._get_default_mode_explanation(default_mode),
show_label=False,
max_lines=3,
interactive=False,
)
upload_button = gr.components.UploadButton(
"Upload File(s)",
type="filepath",
file_count="multiple",
size="sm",
)
ingested_dataset = gr.List(
self._list_ingested_files,
headers=["File name"],
label="Ingested Files",
height=235,
interactive=False,
render=False, # Rendered under the button
)
upload_button.upload(
self._upload_file,
inputs=upload_button,
outputs=ingested_dataset,
)
ingested_dataset.change(
self._list_ingested_files,
outputs=ingested_dataset,
)
ingested_dataset.render()
deselect_file_button = gr.components.Button(
"De-select selected file", size="sm", interactive=False
)
selected_text = gr.components.Textbox(
"All files", label="Selected for Query or Deletion", max_lines=1
)
delete_file_button = gr.components.Button(
"🗑️ Delete selected file",
size="sm",
visible=settings().ui.delete_file_button_enabled,
interactive=False,
)
delete_files_button = gr.components.Button(
"⚠️ Delete ALL files",
size="sm",
visible=settings().ui.delete_all_files_button_enabled,
)
deselect_file_button.click(
self._deselect_selected_file,
outputs=[
delete_file_button,
deselect_file_button,
selected_text,
],
)
ingested_dataset.select(
fn=self._selected_a_file,
outputs=[
delete_file_button,
deselect_file_button,
selected_text,
],
)
delete_file_button.click(
self._delete_selected_file,
outputs=[
ingested_dataset,
delete_file_button,
deselect_file_button,
selected_text,
],
)
delete_files_button.click(
self._delete_all_files,
outputs=[
ingested_dataset,
delete_file_button,
deselect_file_button,
selected_text,
],
)
system_prompt_input = gr.Textbox(
placeholder=self._system_prompt,
label="System Prompt",
lines=2,
interactive=True,
render=False,
)
# When mode changes, set default system prompt, and other stuffs
mode.change(
self._set_current_mode,
inputs=mode,
outputs=[system_prompt_input, explanation_mode],
)
# On blur, set system prompt to use in queries
system_prompt_input.blur(
self._set_system_prompt,
inputs=system_prompt_input,
)
def get_model_label() -> str | None:
"""Get model label from llm mode setting YAML.
Raises:
ValueError: If an invalid 'llm_mode' is encountered.
Returns:
str: The corresponding model label.
"""
# Get model label from llm mode setting YAML
# Labels: local, openai, openailike, sagemaker, mock, ollama
config_settings = settings()
if config_settings is None:
raise ValueError("Settings are not configured.")
# Get llm_mode from settings
llm_mode = config_settings.llm.mode
# Mapping of 'llm_mode' to corresponding model labels
model_mapping = {
"llamacpp": config_settings.llamacpp.llm_hf_model_file,
"openai": config_settings.openai.model,
"openailike": config_settings.openai.model,
"azopenai": config_settings.azopenai.llm_model,
"sagemaker": config_settings.sagemaker.llm_endpoint_name,
"mock": llm_mode,
"ollama": config_settings.ollama.llm_model,
"gemini": config_settings.gemini.model,
}
if llm_mode not in model_mapping:
print(f"Invalid 'llm mode': {llm_mode}")
return None
return model_mapping[llm_mode]
with gr.Column(scale=7, elem_id="col"):
# Determine the model label based on the value of PGPT_PROFILES
model_label = get_model_label()
if model_label is not None:
label_text = (
f"LLM: {settings().llm.mode} | Model: {model_label}"
)
else:
label_text = f"LLM: {settings().llm.mode}"
_ = gr.ChatInterface(
self._chat,
chatbot=gr.Chatbot(
label=label_text,
show_copy_button=True,
elem_id="chatbot",
render=False,
avatar_images=(
None,
AVATAR_BOT,
),
),
additional_inputs=[mode, upload_button, system_prompt_input],
)
with gr.Row():
avatar_byte = AVATAR_BOT.read_bytes()
f_base64 = f"data:image/png;base64,{base64.b64encode(avatar_byte).decode('utf-8')}"
gr.HTML(
f"<div class='footer'><a class='footer-zylon-link' href='https://zylon.ai/'>Maintained by Zylon <img class='footer-zylon-ico' src='{f_base64}' alt=Zylon></a></div>"
)
return blocks
def get_ui_blocks(self) -> gr.Blocks:
if self._ui_block is None:
self._ui_block = self._build_ui_blocks()
return self._ui_block
def mount_in_app(self, app: FastAPI, path: str) -> None:
blocks = self.get_ui_blocks()
blocks.queue()
logger.info("Mounting the gradio UI, at path=%s", path)
gr.mount_gradio_app(app, blocks, path=path, favicon_path=AVATAR_BOT)
if __name__ == "__main__":
ui = global_injector.get(PrivateGptUi)
_blocks = ui.get_ui_blocks()
_blocks.queue()
_blocks.launch(debug=False, show_api=False)

View File

@ -0,0 +1 @@
"""general utils."""

View File

@ -0,0 +1,122 @@
import datetime
import logging
import math
import time
from collections import deque
from typing import Any
logger = logging.getLogger(__name__)
def human_time(*args: Any, **kwargs: Any) -> str:
def timedelta_total_seconds(timedelta: datetime.timedelta) -> float:
return (
timedelta.microseconds
+ 0.0
+ (timedelta.seconds + timedelta.days * 24 * 3600) * 10**6
) / 10**6
secs = float(timedelta_total_seconds(datetime.timedelta(*args, **kwargs)))
# We want (ms) precision below 2 seconds
if secs < 2:
return f"{secs * 1000}ms"
units = [("y", 86400 * 365), ("d", 86400), ("h", 3600), ("m", 60), ("s", 1)]
parts = []
for unit, mul in units:
if secs / mul >= 1 or mul == 1:
if mul > 1:
n = int(math.floor(secs / mul))
secs -= n * mul
else:
# >2s we drop the (ms) component.
n = int(secs)
if n:
parts.append(f"{n}{unit}")
return " ".join(parts)
def eta(iterator: list[Any]) -> Any:
"""Report an ETA after 30s and every 60s thereafter."""
total = len(iterator)
_eta = ETA(total)
_eta.needReport(30)
for processed, data in enumerate(iterator, start=1):
yield data
_eta.update(processed)
if _eta.needReport(60):
logger.info(f"{processed}/{total} - ETA {_eta.human_time()}")
class ETA:
"""Predict how long something will take to complete."""
def __init__(self, total: int):
self.total: int = total # Total expected records.
self.rate: float = 0.0 # per second
self._timing_data: deque[tuple[float, int]] = deque(maxlen=100)
self.secondsLeft: float = 0.0
self.nexttime: float = 0.0
def human_time(self) -> str:
if self._calc():
return f"{human_time(seconds=self.secondsLeft)} @ {int(self.rate * 60)}/min"
return "(computing)"
def update(self, count: int) -> None:
# count should be in the range 0 to self.total
assert count > 0
assert count <= self.total
self._timing_data.append((time.time(), count)) # (X,Y) for pearson
def needReport(self, whenSecs: int) -> bool:
now = time.time()
if now > self.nexttime:
self.nexttime = now + whenSecs
return True
return False
def _calc(self) -> bool:
# A sample before a prediction. Need two points to compute slope!
if len(self._timing_data) < 3:
return False
# http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient
# Calculate means and standard deviations.
samples = len(self._timing_data)
# column wise sum of the timing tuples to compute their mean.
mean_x, mean_y = (
sum(i) / samples for i in zip(*self._timing_data, strict=False)
)
std_x = math.sqrt(
sum(pow(i[0] - mean_x, 2) for i in self._timing_data) / (samples - 1)
)
std_y = math.sqrt(
sum(pow(i[1] - mean_y, 2) for i in self._timing_data) / (samples - 1)
)
# Calculate coefficient.
sum_xy, sum_sq_v_x, sum_sq_v_y = 0.0, 0.0, 0
for x, y in self._timing_data:
x -= mean_x
y -= mean_y
sum_xy += x * y
sum_sq_v_x += pow(x, 2)
sum_sq_v_y += pow(y, 2)
pearson_r = sum_xy / math.sqrt(sum_sq_v_x * sum_sq_v_y)
# Calculate regression line.
# y = mx + b where m is the slope and b is the y-intercept.
m = self.rate = pearson_r * (std_y / std_x)
y = self.total
b = mean_y - m * mean_x
x = (y - b) / m
# Calculate fitted line (transformed/shifted regression line horizontally).
fitted_b = self._timing_data[-1][1] - (m * self._timing_data[-1][0])
fitted_x = (y - fitted_b) / m
_, count = self._timing_data[-1] # adjust last data point progress count
adjusted_x = ((fitted_x - x) * (count / self.total)) + x
eta_epoch = adjusted_x
self.secondsLeft = max([eta_epoch - time.time(), 0])
return True

View File

@ -0,0 +1,95 @@
import logging
from collections import deque
from collections.abc import Iterator, Mapping
from typing import Any
from httpx import ConnectError
from tqdm import tqdm # type: ignore
from private_gpt.utils.retry import retry
try:
from ollama import Client, ResponseError # type: ignore
except ImportError as e:
raise ImportError(
"Ollama dependencies not found, install with `poetry install --extras llms-ollama or embeddings-ollama`"
) from e
logger = logging.getLogger(__name__)
_MAX_RETRIES = 5
_JITTER = (3.0, 10.0)
@retry(
is_async=False,
exceptions=(ConnectError, ResponseError),
tries=_MAX_RETRIES,
jitter=_JITTER,
logger=logger,
)
def check_connection(client: Client) -> bool:
try:
client.list()
return True
except (ConnectError, ResponseError) as e:
raise e
except Exception as e:
logger.error(f"Failed to connect to Ollama: {type(e).__name__}: {e!s}")
return False
def process_streaming(generator: Iterator[Mapping[str, Any]]) -> None:
progress_bars = {}
queue = deque() # type: ignore
def create_progress_bar(dgt: str, total: int) -> Any:
return tqdm(
total=total, desc=f"Pulling model {dgt[7:17]}...", unit="B", unit_scale=True
)
current_digest = None
for chunk in generator:
digest = chunk.get("digest")
completed_size = chunk.get("completed", 0)
total_size = chunk.get("total")
if digest and total_size is not None:
if digest not in progress_bars and completed_size > 0:
progress_bars[digest] = create_progress_bar(digest, total=total_size)
if current_digest is None:
current_digest = digest
else:
queue.append(digest)
if digest in progress_bars:
progress_bar = progress_bars[digest]
progress = completed_size - progress_bar.n
if completed_size > 0 and total_size >= progress != progress_bar.n:
if digest == current_digest:
progress_bar.update(progress)
if progress_bar.n >= total_size:
progress_bar.close()
current_digest = queue.popleft() if queue else None
else:
# Store progress for later update
progress_bars[digest].total = total_size
progress_bars[digest].n = completed_size
# Close any remaining progress bars at the end
for progress_bar in progress_bars.values():
progress_bar.close()
def pull_model(client: Client, model_name: str, raise_error: bool = True) -> None:
try:
installed_models = [model["name"] for model in client.list().get("models", {})]
if model_name not in installed_models:
logger.info(f"Pulling model {model_name}. Please wait...")
process_streaming(client.pull(model_name, stream=True))
logger.info(f"Model {model_name} pulled successfully")
except Exception as e:
logger.error(f"Failed to pull model {model_name}: {e!s}")
if raise_error:
raise e

View File

@ -0,0 +1,31 @@
import logging
from collections.abc import Callable
from typing import Any
from retry_async import retry as retry_untyped # type: ignore
retry_logger = logging.getLogger(__name__)
def retry(
exceptions: Any = Exception,
*,
is_async: bool = False,
tries: int = -1,
delay: float = 0,
max_delay: float | None = None,
backoff: float = 1,
jitter: float | tuple[float, float] = 0,
logger: logging.Logger = retry_logger,
) -> Callable[..., Any]:
wrapped = retry_untyped(
exceptions=exceptions,
is_async=is_async,
tries=tries,
delay=delay,
max_delay=max_delay,
backoff=backoff,
jitter=jitter,
logger=logger,
)
return wrapped # type: ignore

View File

@ -0,0 +1,5 @@
from typing import TypeVar
T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")

197
pgpt/pyproject.toml Normal file
View File

@ -0,0 +1,197 @@
[tool.poetry]
name = "private-gpt"
version = "0.6.2"
description = "Private GPT"
authors = ["Zylon <hi@zylon.ai>"]
[tool.poetry.dependencies]
python = ">=3.11,<3.12"
# PrivateGPT
fastapi = { extras = ["all"], version = "^0.115.0" }
python-multipart = "^0.0.10"
injector = "^0.22.0"
pyyaml = "^6.0.2"
watchdog = "^4.0.1"
transformers = "^4.44.2"
docx2txt = "^0.8"
cryptography = "^3.1"
# LlamaIndex core libs
llama-index-core = ">=0.11.2,<0.12.0"
llama-index-readers-file = "*"
# Optional LlamaIndex integration libs
llama-index-llms-llama-cpp = {version = "*", optional = true}
llama-index-llms-openai = {version ="*", optional = true}
llama-index-llms-openai-like = {version ="*", optional = true}
llama-index-llms-ollama = {version ="*", optional = true}
llama-index-llms-azure-openai = {version ="*", optional = true}
llama-index-llms-gemini = {version ="*", optional = true}
llama-index-embeddings-ollama = {version ="*", optional = true}
llama-index-embeddings-huggingface = {version ="*", optional = true}
llama-index-embeddings-openai = {version ="*", optional = true}
llama-index-embeddings-azure-openai = {version ="*", optional = true}
llama-index-embeddings-gemini = {version ="*", optional = true}
llama-index-embeddings-mistralai = {version ="*", optional = true}
llama-index-vector-stores-qdrant = {version ="*", optional = true}
llama-index-vector-stores-milvus = {version ="*", optional = true}
llama-index-vector-stores-chroma = {version ="*", optional = true}
llama-index-vector-stores-postgres = {version ="*", optional = true}
llama-index-vector-stores-clickhouse = {version ="*", optional = true}
llama-index-storage-docstore-postgres = {version ="*", optional = true}
llama-index-storage-index-store-postgres = {version ="*", optional = true}
# Postgres
psycopg2-binary = {version ="^2.9.9", optional = true}
asyncpg = {version="^0.29.0", optional = true}
# ClickHouse
clickhouse-connect = {version = "^0.7.19", optional = true}
# Optional Sagemaker dependency
boto3 = {version ="^1.35.26", optional = true}
# Optional Reranker dependencies
torch = {version ="^2.4.1", optional = true}
sentence-transformers = {version ="^3.1.1", optional = true}
# Optional UI
gradio = {version ="^4.44.0", optional = true}
ffmpy = {version ="^0.4.0", optional = true}
# Optional HF Transformers
einops = {version = "^0.8.0", optional = true}
retry-async = "^0.1.4"
[tool.poetry.extras]
ui = ["gradio", "ffmpy"]
llms-llama-cpp = ["llama-index-llms-llama-cpp"]
llms-openai = ["llama-index-llms-openai"]
llms-openai-like = ["llama-index-llms-openai-like"]
llms-ollama = ["llama-index-llms-ollama"]
llms-sagemaker = ["boto3"]
llms-azopenai = ["llama-index-llms-azure-openai"]
llms-gemini = ["llama-index-llms-gemini"]
embeddings-ollama = ["llama-index-embeddings-ollama"]
embeddings-huggingface = ["llama-index-embeddings-huggingface", "einops"]
embeddings-openai = ["llama-index-embeddings-openai"]
embeddings-sagemaker = ["boto3"]
embeddings-azopenai = ["llama-index-embeddings-azure-openai"]
embeddings-gemini = ["llama-index-embeddings-gemini"]
embeddings-mistral = ["llama-index-embeddings-mistralai"]
vector-stores-qdrant = ["llama-index-vector-stores-qdrant"]
vector-stores-clickhouse = ["llama-index-vector-stores-clickhouse", "clickhouse_connect"]
vector-stores-chroma = ["llama-index-vector-stores-chroma"]
vector-stores-postgres = ["llama-index-vector-stores-postgres"]
vector-stores-milvus = ["llama-index-vector-stores-milvus"]
storage-nodestore-postgres = ["llama-index-storage-docstore-postgres","llama-index-storage-index-store-postgres","psycopg2-binary","asyncpg"]
rerank-sentence-transformers = ["torch", "sentence-transformers"]
[tool.poetry.group.dev.dependencies]
black = "^24"
mypy = "^1.11"
pre-commit = "^3"
pytest = "^8"
pytest-cov = "^5"
ruff = "^0"
pytest-asyncio = "^0.24.0"
types-pyyaml = "^6.0.12.20240917"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"
# Packages configs
## coverage
[tool.coverage.run]
branch = true
[tool.coverage.report]
skip_empty = true
precision = 2
## black
[tool.black]
target-version = ['py311']
## ruff
# Recommended ruff config for now, to be updated as we go along.
[tool.ruff]
target-version = 'py311'
# See all rules at https://beta.ruff.rs/docs/rules/
lint.select = [
"E", # pycodestyle
"W", # pycodestyle
"F", # Pyflakes
"B", # flake8-bugbear
"C4", # flake8-comprehensions
"D", # pydocstyle
"I", # isort
"SIM", # flake8-simplify
"TCH", # flake8-type-checking
"TID", # flake8-tidy-imports
"Q", # flake8-quotes
"UP", # pyupgrade
"PT", # flake8-pytest-style
"RUF", # Ruff-specific rules
]
lint.ignore = [
"E501", # "Line too long"
# -> line length already regulated by black
"PT011", # "pytest.raises() should specify expected exception"
# -> would imply to update tests every time you update exception message
"SIM102", # "Use a single `if` statement instead of nested `if` statements"
# -> too restrictive,
"D100",
"D101",
"D102",
"D103",
"D104",
"D105",
"D106",
"D107"
# -> "Missing docstring in public function too restrictive"
]
[tool.ruff.lint.pydocstyle]
# Automatically disable rules that are incompatible with Google docstring convention
convention = "google"
[tool.ruff.lint.pycodestyle]
max-doc-length = 88
[tool.ruff.lint.flake8-tidy-imports]
ban-relative-imports = "all"
[tool.ruff.lint.flake8-type-checking]
strict = true
runtime-evaluated-base-classes = ["pydantic.BaseModel"]
# Pydantic needs to be able to evaluate types at runtime
# see https://pypi.org/project/flake8-type-checking/ for flake8-type-checking documentation
# see https://beta.ruff.rs/docs/settings/#flake8-type-checking-runtime-evaluated-base-classes for ruff documentation
[tool.ruff.lint.per-file-ignores]
# Allow missing docstrings for tests
"tests/**/*.py" = ["D1"]
## mypy
[tool.mypy]
python_version = "3.11"
strict = true
check_untyped_defs = false
explicit_package_bases = true
warn_unused_ignores = false
exclude = ["tests"]
[tool.mypy-llama-index]
ignore_missing_imports = true
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
addopts = [
"--import-mode=importlib",
]

42
pgpt/settings-docker.yaml Normal file
View File

@ -0,0 +1,42 @@
server:
env_name: ${APP_ENV:prod}
port: ${PORT:8080}
llm:
mode: ${PGPT_MODE:mock}
embedding:
mode: ${PGPT_EMBED_MODE:mock}
llamacpp:
llm_hf_repo_id: ${PGPT_HF_REPO_ID:lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF}
llm_hf_model_file: ${PGPT_HF_MODEL_FILE:Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf}
huggingface:
embedding_hf_model_name: ${PGPT_EMBEDDING_HF_MODEL_NAME:nomic-ai/nomic-embed-text-v1.5}
sagemaker:
llm_endpoint_name: ${PGPT_SAGEMAKER_LLM_ENDPOINT_NAME:}
embedding_endpoint_name: ${PGPT_SAGEMAKER_EMBEDDING_ENDPOINT_NAME:}
ollama:
#llm_model: ${PGPT_OLLAMA_LLM_MODEL:llama3:8b-instruct-q4_K_M}
#llm_model: llama3:8b-instruct-q4_K_M
llm_model: qwen3:14b
context_window: 5000
# llm_model: gemma3:12b
# llm_model: deepseek-r1:14b
embedding_model: ${PGPT_OLLAMA_EMBEDDING_MODEL:mxbai-embed-large}
api_base: ${PGPT_OLLAMA_API_BASE:http://ollama:11434}
embedding_api_base: ${PGPT_OLLAMA_EMBEDDING_API_BASE:http://ollama:11434}
tfs_z: ${PGPT_OLLAMA_TFS_Z:1.0}
top_k: ${PGPT_OLLAMA_TOP_K:40}
top_p: ${PGPT_OLLAMA_TOP_P:0.9}
repeat_last_n: ${PGPT_OLLAMA_REPEAT_LAST_N:64}
repeat_penalty: ${PGPT_OLLAMA_REPEAT_PENALTY:1.2}
request_timeout: ${PGPT_OLLAMA_REQUEST_TIMEOUT:6000.0}
autopull_models: ${PGPT_OLLAMA_AUTOPULL_MODELS:true}
ui:
enabled: true
path: /

31
pgpt/settings-ollama.yaml Normal file
View File

@ -0,0 +1,31 @@
server:
env_name: ${APP_ENV:ollama}
llm:
mode: ollama
max_new_tokens: 512
context_window: 3900
temperature: 0.1 #The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual. (Default: 0.1)
embedding:
mode: ollama
ollama:
llm_model: qwen3:14b
context_window: 5000
embedding_model: mxbai-embed-large
api_base: http://ollama:11434
embedding_api_base: http://ollama:11434 # change if your embedding model runs on another ollama
keep_alive: 5m
tfs_z: 2.0 # Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting.
top_k: 40 # Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
top_p: 0.9 # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
repeat_last_n: -1 # Sets how far back for the model to look back to prevent repetition. (Default: 64, 0 = disabled, -1 = num_ctx)
repeat_penalty: 1.5 # Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)
request_timeout: 120.0 # Time elapsed until ollama times out the request. Default is 120s. Format is float.
vectorstore:
database: qdrant
qdrant:
path: local_data/private_gpt/qdrant

158
pgpt/settings.yaml Normal file
View File

@ -0,0 +1,158 @@
# The default configuration file.
# More information about configuration can be found in the documentation: https://docs.privategpt.dev/
# Syntax in `private_pgt/settings/settings.py`
server:
env_name: ${APP_ENV:prod}
port: ${PORT:8001}
cors:
enabled: true
allow_origins: ["*"]
allow_methods: ["*"]
allow_headers: ["*"]
auth:
enabled: false
# python -c 'import base64; print("Basic " + base64.b64encode("secret:key".encode()).decode())'
# 'secret' is the username and 'key' is the password for basic auth by default
# If the auth is enabled, this value must be set in the "Authorization" header of the request.
secret: "Basic c2VjcmV0OmtleQ=="
#data:
# local_ingestion:
# enabled: ${LOCAL_INGESTION_ENABLED:false}
# allow_ingest_from: ["*"]
# local_data_folder: local_data/Corpus/private_gpt
data:
local_ingestion:
enabled: true
allow_ingest_from: ["*"]
local_data_folder: local_data/private_gpt
ui:
enabled: true
path: /
# "RAG", "Search", "Basic", or "Summarize"
default_mode: "RAG"
default_chat_system_prompt: >
Vous ne devez répondre aux questions qu'à partir des données du contexte.
Si vous connaissez la réponse, mais qu'elle n'est pas basée sur le contexte,
faites en suggestion et non une réponse. Annoncez le clairement.
default_query_system_prompt: >
Vous ne devez répondre aux questions qu'à partir des données du contexte.
Si vous connaissez la réponse, mais qu'elle n'est pas basée sur le contexte,
faites en suggestion et non une réponse. Annoncez le clairement.
default_summarization_system_prompt: >
Vous ne devez répondre aux questions qu'à partir des données du contexte.
Si vous connaissez la réponse, mais qu'elle n'est pas basée sur le contexte,
faites en suggestion et non une réponse. Annoncez le clairement.
delete_file_button_enabled: true
delete_all_files_button_enabled: true
llm:
mode: llamacpp
prompt_style: "llama3"
# Should be matching the selected model
max_new_tokens: 512
context_window: 10000
# Select your tokenizer. Llama-index tokenizer is the default.
# tokenizer: meta-llama/Meta-Llama-3.1-8B-Instruct
temperature: 0.1 # The temperature of the model. Increasing the temperature will make the model answer more creatively. A value of 0.1 would be more factual. (Default: 0.1)
rag:
similarity_top_k: 15
#This value controls how many "top" documents the RAG returns to use in the context.
# similarity_value: 0.9
#This value is disabled by default. If you enable this settings, the RAG will only use articles that meet a certain percentage score.
rerank:
enabled: false
model: cross-encoder/ms-marco-MiniLM-L-2-v2
top_n: 1
summarize:
use_async: false
clickhouse:
host: localhost
port: 8443
username: admin
password: clickhouse
database: embeddings
llamacpp:
llm_hf_repo_id: lmstudio-community/Meta-Llama-3.1-8B-Instruct-GGUF
llm_hf_model_file: Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
tfs_z: 2.0 # Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting
top_k: 10 # Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
top_p: 0.3 # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)
repeat_penalty: 1.1 # Sets how strongly to penalize repetitions. A higher value (e.g., 1.5) will penalize repetitions more strongly, while a lower value (e.g., 0.9) will be more lenient. (Default: 1.1)
#embedding:
# # Should be matching the value above in most cases
# mode: huggingface
# ingest_mode: simple
# embed_dim: 768 # 768 is for nomic-ai/nomic-embed-text-v1.5
embedding:
mode: ollama
model: mxbai-embed-large
ingest_mode: simple
embed_dim: 1536
huggingface:
embedding_hf_model_name: nomic-ai/nomic-embed-text-v1.5
access_token: ${HF_TOKEN:}
# Warning: Enabling this option will allow the model to download and execute code from the internet.
# Nomic AI requires this option to be enabled to use the model, be aware if you are using a different model.
trust_remote_code: true
nodestore:
database: simple
milvus:
uri: local_data/private_gpt/milvus/milvus_local.db
collection_name: milvus_db
overwrite: false
vectorstore:
database: qdrant
qdrant:
path: local_data/private_gpt/qdrant
postgres:
host: localhost
port: 5432
database: postgres
user: postgres
password: postgres
schema_name: private_gpt
sagemaker:
llm_endpoint_name: huggingface-pytorch-tgi-inference-2023-09-25-19-53-32-140
embedding_endpoint_name: huggingface-pytorch-inference-2023-11-03-07-41-36-479
openai:
api_key: ${OPENAI_API_KEY:}
model: gpt-3.5-turbo
embedding_api_key: ${OPENAI_API_KEY:}
ollama:
llm_model: qwen3:14b
embedding_model: mxbai-embed-large
api_base: http://localhost:11434
embedding_api_base: http://localhost:11434 # change if your embedding model runs on another ollama
keep_alive: 5m
request_timeout: 1200.0
autopull_models: true
azopenai:
api_key: ${AZ_OPENAI_API_KEY:}
azure_endpoint: ${AZ_OPENAI_ENDPOINT:}
embedding_deployment_name: ${AZ_OPENAI_EMBEDDING_DEPLOYMENT_NAME:}
llm_deployment_name: ${AZ_OPENAI_LLM_DEPLOYMENT_NAME:}
api_version: "2023-05-15"
embedding_model: text-embedding-ada-002
llm_model: gpt-35-turbo
gemini:
api_key: ${GOOGLE_API_KEY:}
model: models/gemini-pro
embedding_model: models/embedding-001

View File

@ -0,0 +1,96 @@
#!/bin/bash
# Script de démarrage pour Private-GPT avec Ollama personnalisé
# Permet de lancer facilement private-gpt avec des modèles d'IA personnalisés
set -e
# Répertoire courant où se trouve le script
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
cd "$SCRIPT_DIR"
# Couleurs pour les messages
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
RED='\033[0;31m'
BLUE='\033[0;34m'
NC='\033[0m' # No Color
# Afficher les modèles disponibles
list_models() {
echo -e "${BLUE}Modèles Ollama disponibles:${NC}"
docker exec -it pgpt-ollama-cpu-1 ollama list
}
# Télécharger un nouveau modèle
download_model() {
if [ -z "$1" ]; then
echo -e "${RED}Erreur: Veuillez spécifier un nom de modèle${NC}"
echo "Exemple: $0 download llama3"
exit 1
fi
echo -e "${YELLOW}Téléchargement du modèle $1...${NC}"
docker exec -it pgpt-ollama-cpu-1 ollama pull "$1"
echo -e "${GREEN}Modèle $1 téléchargé avec succès${NC}"
}
# Arrêter les services
stop_services() {
echo -e "${YELLOW}Arrêt des services Private-GPT...${NC}"
docker compose -f docker-compose.yaml down
echo -e "${GREEN}Services arrêtés${NC}"
}
# Démarrer les services
start_services() {
local profile="$1"
if [ -z "$profile" ]; then
profile="ollama-cpu"
fi
echo -e "${YELLOW}Démarrage des services Private-GPT avec le profil $profile...${NC}"
docker compose -f docker-compose.yaml --profile "$profile" up -d
echo -e "${GREEN}Services démarrés avec succès${NC}"
echo -e "${BLUE}Interface web disponible sur: ${GREEN}http://localhost:8001${NC}"
}
# Afficher l'aide
show_help() {
echo -e "${BLUE}Script de démarrage pour Private-GPT avec Ollama${NC}"
echo ""
echo -e "Usage: $0 [commande]"
echo ""
echo "Commandes:"
echo " start [profile] Démarrer les services (profile: ollama-cpu (défaut), ollama-cuda)"
echo " stop Arrêter les services"
echo " list Lister les modèles Ollama disponibles"
echo " download <modèle> Télécharger un nouveau modèle Ollama"
echo " help Afficher cette aide"
echo ""
echo "Exemples:"
echo " $0 start Démarrer avec CPU (par défaut)"
echo " $0 start ollama-cuda Démarrer avec GPU (CUDA)"
echo " $0 download mistral Télécharger le modèle Mistral"
echo ""
}
# Traitement des commandes
case "$1" in
start)
start_services "$2"
;;
stop)
stop_services
;;
list)
list_models
;;
download)
download_model "$2"
;;
*)
show_help
;;
esac

1
pgpt/version.txt Normal file
View File

@ -0,0 +1 @@
0.6.2

View File

@ -1421,7 +1421,7 @@ digraph Hierarchie_Composants_Electroniques_Simplifiee {
subgraph cluster_ProcedeDUV {
label="ProcedeDUV";
fillcolor="#ffd699";
ProcedeDUV [fillcolor="#ffd699", label="Photolitographie DUV", niveau="1000"];
ProcedeDUV [fillcolor="#ffd699", label="Procédé DUV", niveau="1000"];
// Relations sortantes
ProcedeDUV -> Assemblage_ProcedeDUV [];
@ -1476,7 +1476,7 @@ digraph Hierarchie_Composants_Electroniques_Simplifiee {
subgraph cluster_ProcedeEUV {
label="ProcedeEUV";
fillcolor="#ffd699";
ProcedeEUV [fillcolor="#ffd699", label="Photolitographie EUV", niveau="1000"];
ProcedeEUV [fillcolor="#ffd699", label="Procédé EUV", niveau="1000"];
// Relations sortantes
ProcedeEUV -> Assemblage_ProcedeEUV [];
@ -4250,7 +4250,7 @@ digraph Hierarchie_Composants_Electroniques_Simplifiee {
subgraph cluster_CreusetGraphite {
label="CreusetGraphite";
fillcolor="#ffd699";
CreusetGraphite [fillcolor="#ffd699", label="Creuset en graphite - Pour fusion de métaux", niveau="1001"];
CreusetGraphite [fillcolor="#ffd699", label="Creuset graphite", niveau="1001"];
// Relations sortantes
CreusetGraphite -> Graphite [cout="0.5", delai="0.3", ics="0.44", technique="0.5"];
@ -4259,7 +4259,7 @@ digraph Hierarchie_Composants_Electroniques_Simplifiee {
subgraph cluster_CreusetQuartz {
label="CreusetQuartz";
fillcolor="#ffd699";
CreusetQuartz [fillcolor="#ffd699", label="Creuset en quartz - Pour silicium monocristallin", niveau="1001"];
CreusetQuartz [fillcolor="#ffd699", label="Creuset quartz", niveau="1001"];
// Relations sortantes
CreusetQuartz -> Verre [];