import streamlit as st
from networkx.drawing.nx_agraph import write_dot
import pandas as pd
import plotly.graph_objects as go
import networkx as nx
import logging
import tempfile
from utils.graph_utils import (
extraire_chemins_depuis,
extraire_chemins_vers,
couleur_noeud
)
niveau_labels = {
0: "Produit final",
1: "Composant",
2: "Minerai",
10: "Opération",
11: "Pays d'opération",
12: "Acteur d'opération",
99: "Pays géographique"
}
inverse_niveau_labels = {v: k for k, v in niveau_labels.items()}
def extraire_niveaux(G):
"""Extrait les niveaux des nœuds du graphe"""
niveaux = {}
for node, attrs in G.nodes(data=True):
niveau_str = attrs.get("niveau")
try:
if niveau_str:
niveaux[node] = int(str(niveau_str).strip('"'))
except ValueError:
logging.warning(f"Niveau non entier pour le noeud {node}: {niveau_str}")
return niveaux
def extraire_criticite(G, u, v):
"""Extrait la criticité d'un lien entre deux nœuds"""
data = G.get_edge_data(u, v)
if not data:
return 0
if isinstance(data, dict) and all(isinstance(k, int) for k in data):
return float(data[0].get("criticite", 0))
return float(data.get("criticite", 0))
def extraire_chemins_selon_criteres(G, niveaux, niveau_depart, noeuds_depart, noeuds_arrivee, minerais):
"""Extrait les chemins selon les critères spécifiés"""
chemins = []
if noeuds_depart and noeuds_arrivee:
for nd in noeuds_depart:
for na in noeuds_arrivee:
tous_chemins = extraire_chemins_depuis(G, nd)
chemins.extend([chemin for chemin in tous_chemins if na in chemin])
elif noeuds_depart:
for nd in noeuds_depart:
chemins.extend(extraire_chemins_depuis(G, nd))
elif noeuds_arrivee:
for na in noeuds_arrivee:
chemins.extend(extraire_chemins_vers(G, na, niveau_depart))
else:
sources_depart = [n for n in G.nodes() if niveaux.get(n) == niveau_depart]
for nd in sources_depart:
chemins.extend(extraire_chemins_depuis(G, nd))
if minerais:
chemins = [chemin for chemin in chemins if any(n in minerais for n in chemin)]
return chemins
def verifier_critere_ihh(G, chemin, niveaux, ihh_type):
"""Vérifie si un chemin respecte le critère IHH (concentration géographique ou industrielle)"""
ihh_field = "ihh_pays" if ihh_type == "Pays" else "ihh_acteurs"
for i in range(len(chemin) - 1):
u, v = chemin[i], chemin[i + 1]
niveau_u = niveaux.get(u)
niveau_v = niveaux.get(v)
if niveau_u in (10, 1010) and int(G.nodes[u].get(ihh_field, 0)) > 25:
return True
if niveau_v in (10, 1010) and int(G.nodes[v].get(ihh_field, 0)) > 25:
return True
return False
def verifier_critere_ivc(G, chemin, niveaux):
"""Vérifie si un chemin respecte le critère IVC (criticité par rapport à la concurrence sectorielle)"""
for i in range(len(chemin) - 1):
u = chemin[i]
niveau_u = niveaux.get(u)
if niveau_u in (2, 1002) and int(G.nodes[u].get("ivc", 0)) > 30:
return True
return False
def verifier_critere_ics(G, chemin, niveaux):
"""Vérifie si un chemin respecte le critère ICS (criticité d'un minerai pour un composant)"""
for i in range(len(chemin) - 1):
u, v = chemin[i], chemin[i + 1]
niveau_u = niveaux.get(u)
niveau_v = niveaux.get(v)
if ((niveau_u == 1 and niveau_v == 2) or
(niveau_u == 1001 and niveau_v == 1002) or
(niveau_u == 10 and niveau_v in (1000, 1001))) and extraire_criticite(G, u, v) > 0.66:
return True
return False
def verifier_critere_isg(G, chemin, niveaux):
"""Vérifie si un chemin contient un pays instable (ISG ≥ 60)"""
for i in range(len(chemin) - 1):
u, v = chemin[i], chemin[i + 1]
for n in (u, v):
if niveaux.get(n) == 99 and int(G.nodes[n].get("isg", 0)) >= 60:
return True
elif niveaux.get(n) in (11, 12, 1011, 1012):
for succ in G.successors(n):
if niveaux.get(succ) == 99 and int(G.nodes[succ].get("isg", 0)) >= 60:
return True
return False
def extraire_liens_filtres(chemins, niveaux, niveau_depart, niveau_arrivee, niveaux_speciaux):
"""Extrait les liens des chemins en respectant les niveaux"""
liens = set()
for chemin in chemins:
for i in range(len(chemin) - 1):
u, v = chemin[i], chemin[i + 1]
niveau_u = niveaux.get(u, 999)
niveau_v = niveaux.get(v, 999)
if (
(niveau_depart <= niveau_u <= niveau_arrivee or niveau_u in niveaux_speciaux)
and (niveau_depart <= niveau_v <= niveau_arrivee or niveau_v in niveaux_speciaux)
):
liens.add((u, v))
return liens
def filtrer_chemins_par_criteres(G, chemins, niveaux, niveau_depart, niveau_arrivee,
filtrer_ics, filtrer_ivc, filtrer_ihh, ihh_type, filtrer_isg, logique_filtrage):
"""Filtre les chemins selon les critères de vulnérabilité"""
niveaux_speciaux = [1000, 1001, 1002, 1010, 1011, 1012]
# Extraction des liens sans filtrage
liens_chemins = extraire_liens_filtres(chemins, niveaux, niveau_depart, niveau_arrivee, niveaux_speciaux)
# Si aucun filtre n'est appliqué, retourner tous les chemins
if not any([filtrer_ics, filtrer_ivc, filtrer_ihh, filtrer_isg]):
return liens_chemins, set()
# Application des filtres sur les chemins
chemins_filtres = set()
for chemin in chemins:
# Vérification des critères pour ce chemin
has_ihh = filtrer_ihh and verifier_critere_ihh(G, chemin, niveaux, ihh_type)
has_ivc = filtrer_ivc and verifier_critere_ivc(G, chemin, niveaux)
has_criticite = filtrer_ics and verifier_critere_ics(G, chemin, niveaux)
has_isg_critique = filtrer_isg and verifier_critere_isg(G, chemin, niveaux)
# Appliquer la logique de filtrage
if logique_filtrage == "ET":
keep = True
if filtrer_ihh: keep = keep and has_ihh
if filtrer_ivc: keep = keep and has_ivc
if filtrer_ics: keep = keep and has_criticite
if filtrer_isg: keep = keep and has_isg_critique
if keep:
chemins_filtres.add(tuple(chemin))
elif logique_filtrage == "OU":
if has_ihh or has_ivc or has_criticite or has_isg_critique:
chemins_filtres.add(tuple(chemin))
# Extraction des liens après filtrage
liens_filtres = extraire_liens_filtres(
chemins_filtres, niveaux, niveau_depart, niveau_arrivee, niveaux_speciaux
)
return liens_filtres, chemins_filtres
def couleur_criticite(p):
"""Retourne la couleur en fonction du niveau de criticité"""
if p <= 0.33:
return "darkgreen"
elif p <= 0.66:
return "orange"
else:
return "darkred"
def edge_info(G, u, v):
"""Génère l'info-bulle pour un lien"""
data = G.get_edge_data(u, v)
if not data:
return f"Relation : {u} → {v}"
if isinstance(data, dict) and all(isinstance(k, int) for k in data):
data = data[0]
base = [f"{k}: {v}" for k, v in data.items()]
return f"Relation : {u} → {v}
" + "
".join(base)
def preparer_donnees_sankey(G, liens_chemins, niveaux, chemins):
"""Prépare les données pour le graphique Sankey"""
df_liens = pd.DataFrame(list(liens_chemins), columns=["source", "target"])
df_liens = df_liens.groupby(["source", "target"]).size().reset_index(name="value")
df_liens["criticite"] = df_liens.apply(
lambda row: extraire_criticite(G, row["source"], row["target"]), axis=1)
df_liens["value"] = 0.1
# Ne garder que les nœuds effectivement connectés
niveaux_speciaux = [1000, 1001, 1002, 1010, 1011, 1012]
# Inclure les nœuds connectés + tous les nœuds 10xx traversés dans les chemins
noeuds_utilises = set(df_liens["source"]) | set(df_liens["target"])
for chemin in chemins:
for n in chemin:
if niveaux.get(n) in niveaux_speciaux:
noeuds_utilises.add(n)
df_liens["color"] = df_liens.apply(
lambda row: couleur_criticite(row["criticite"]) if row["criticite"] > 0 else "gray",
axis=1
)
all_nodes = pd.unique(df_liens[["source", "target"]].values.ravel())
sorted_nodes = sorted(
all_nodes, key=lambda x: niveaux.get(x, 99), reverse=True)
node_indices = {name: i for i, name in enumerate(sorted_nodes)}
customdata = []
for n in sorted_nodes:
info = [f"{k}: {v}" for k, v in G.nodes[n].items()]
niveau = niveaux.get(n, 99)
# Ajout d'un ISG hérité si applicable
if niveau in (11, 12, 1011, 1012):
for succ in G.successors(n):
if niveaux.get(succ) == 99 and "isg" in G.nodes[succ]:
isg_val = G.nodes[succ]["isg"]
info.append(f"isg (géographique): {isg_val}")
break
customdata.append("
".join(info))
link_customdata = [
edge_info(G, row["source"], row["target"]) for _, row in df_liens.iterrows()
]
return df_liens, sorted_nodes, customdata, link_customdata, node_indices
def creer_graphique_sankey(G, niveaux, df_liens, sorted_nodes, customdata, link_customdata, node_indices):
"""Crée et retourne le graphique Sankey"""
sources = df_liens["source"].map(node_indices).tolist()
targets = df_liens["target"].map(node_indices).tolist()
values = df_liens["value"].tolist()
fig = go.Figure(go.Sankey(
arrangement="snap",
node=dict(
pad=10,
thickness=8,
label=sorted_nodes,
x=[niveaux.get(n, 99) / 100 for n in sorted_nodes],
color=[couleur_noeud(n, niveaux, G) for n in sorted_nodes],
customdata=customdata,
hovertemplate="%{customdata}"
),
link=dict(
source=sources,
target=targets,
value=values,
color=df_liens["color"].tolist(),
customdata=link_customdata,
hovertemplate="%{customdata}"
)
))
fig.update_layout(
title_text="Hiérarchie filtrée par niveaux et noeuds",
paper_bgcolor="white",
plot_bgcolor="white"
)
return fig
def exporter_graphe_filtre(G, liens_chemins):
"""Gère l'export du graphe filtré au format DOT"""
if not st.session_state.get("logged_in", False) or not liens_chemins:
return
G_export = nx.DiGraph()
for u, v in liens_chemins:
G_export.add_node(u, **G.nodes[u])
G_export.add_node(v, **G.nodes[v])
data = G.get_edge_data(u, v)
if isinstance(data, dict) and all(isinstance(k, int) for k in data):
G_export.add_edge(u, v, **data[0])
elif isinstance(data, dict):
G_export.add_edge(u, v, **data)
else:
G_export.add_edge(u, v)
with tempfile.NamedTemporaryFile(delete=False, suffix=".dot", mode="w", encoding="utf-8") as f:
write_dot(G_export, f.name)
dot_path = f.name
with open(dot_path, encoding="utf-8") as f:
st.download_button(
label="Télécharger le fichier DOT filtré",
data=f.read(),
file_name="graphe_filtré.dot",
mime="text/plain"
)
def afficher_sankey(
G,
niveau_depart, niveau_arrivee,
noeuds_depart=None, noeuds_arrivee=None,
minerais=None,
filtrer_ics=False, filtrer_ivc=False,
filtrer_ihh=False, ihh_type="Pays", filtrer_isg=False,
logique_filtrage="OU"):
# Étape 1 : Extraction des niveaux des nœuds
niveaux = extraire_niveaux(G)
# Étape 2 : Extraction des chemins selon les critères
chemins = extraire_chemins_selon_criteres(G, niveaux, niveau_depart, noeuds_depart, noeuds_arrivee, minerais)
if not chemins:
st.warning("Aucun chemin trouvé pour les critères spécifiés.")
return
# Étape 3 : Filtrage des chemins selon les critères de vulnérabilité
liens_chemins, chemins_filtres = filtrer_chemins_par_criteres(
G, chemins, niveaux, niveau_depart, niveau_arrivee,
filtrer_ics, filtrer_ivc, filtrer_ihh, ihh_type, filtrer_isg, logique_filtrage
)
if not liens_chemins:
st.warning("Aucun chemin ne correspond aux critères.")
return
# Étape 4 : Préparation des données pour le graphique Sankey
df_liens, sorted_nodes, customdata, link_customdata, node_indices = preparer_donnees_sankey(
G, liens_chemins, niveaux, chemins_filtres if any([filtrer_ics, filtrer_ivc, filtrer_ihh, filtrer_isg]) else chemins
)
# Étape 5 : Création et affichage du graphique Sankey
fig = creer_graphique_sankey(G, niveaux, df_liens, sorted_nodes, customdata, link_customdata, node_indices)
st.plotly_chart(fig)
# Étape 6 : Export optionnel du graphe filtré
exporter_graphe_filtre(G, liens_chemins)