"""
Functions for plotting dependency graphs and feature learnability
"""
import pandas as pd
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
[docs]
def plot_dependency_graph_from_df(
results_df: pd.DataFrame,
feature1_name: str = "Lobster Features",
feature2_name: str = "Matminer Features",
target_name: str = "Target",
metric: str = "MAE Mean",
figsize: tuple = (10, 10),
node_colors: dict | None = None,
title: str = "",
save_path: str | None = None,
) -> None:
"""
Plot a dependency graph between two feature sets and a target using a results DataFrame.
Parameters
----------
results_df : pd.DataFrame
DataFrame containing 'From', 'To', and metric columns.
feature1_name : str
Name of the first feature node.
feature2_name : str
Name of the second feature node.
target_name : str
Name of the target node.
metric : str
Metric to visualize on edges (e.g., 'MAE Mean', 'R2 Mean').
figsize : tuple
Size of the figure.
node_colors : dict, optional
Colors for the nodes, e.g., {feature1_name: 'blue', feature2_name: 'green', target_name: 'orange'}.
title : str
Title of the plot.
save_path : str, optional
If provided, saves the figure to this path.
"""
if metric not in results_df.columns:
raise ValueError(f"Metric '{metric}' not found in results_df.")
def get_metric_value(frm, to):
row = results_df[(results_df["From"] == frm) & (results_df["To"] == to)]
if row.empty:
return np.nan
return row[metric].values[0]
r_feature1_target = get_metric_value(feature1_name, target_name)
r_feature2_target = get_metric_value(feature2_name, target_name)
r_feature1_feature2 = get_metric_value(feature1_name, feature2_name)
r_feature2_feature1 = get_metric_value(feature2_name, feature1_name)
graph_feat_name_1 = feature1_name.split(" ")[0].upper()
graph_feat_name_2 = feature2_name.split(" ")[0].upper()
default_colors = {
graph_feat_name_1: "#7fc7ff", # blue
graph_feat_name_2: "#a5d6a7", # green
target_name: "#ffcc80", # orange
}
colors = {**default_colors, **(node_colors or {})}
G = nx.MultiDiGraph()
G.add_nodes_from([graph_feat_name_1, graph_feat_name_2, target_name])
G.add_edge(
graph_feat_name_1,
target_name,
key="f1_target",
value=r_feature1_target,
color=colors[graph_feat_name_1],
)
G.add_edge(
graph_feat_name_2,
target_name,
key="f2_target",
value=r_feature2_target,
color=colors[graph_feat_name_2],
)
G.add_edge(
graph_feat_name_1,
graph_feat_name_2,
key="f1_f2",
value=r_feature1_feature2,
connectionstyle="arc3,rad=0.2",
color=colors[graph_feat_name_1],
)
G.add_edge(
graph_feat_name_2,
graph_feat_name_1,
key="f2_f1",
value=r_feature2_feature1,
connectionstyle="arc3,rad=0.2",
color=colors[graph_feat_name_2],
)
pos = {
graph_feat_name_1: (0.3, 0.3),
graph_feat_name_2: (0.7, 0.3),
target_name: (0.5, 0.5),
}
fig, ax = plt.subplots(figsize=figsize)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.set_aspect("equal")
# Draw nodes
for node, color in colors.items():
nx.draw_networkx_nodes(
G, pos, nodelist=[node], node_color=color, node_size=2000, ax=ax
)
nx.draw_networkx_labels(G, pos, font_size=12, font_weight="bold", ax=ax)
# Draw edges
for u, v, key, data in G.edges(keys=True, data=True):
edge_color = data.get("color", "k")
if "connectionstyle" in data:
nx.draw_networkx_edges(
G,
pos,
edgelist=[(u, v)],
connectionstyle=data["connectionstyle"],
edge_color=[edge_color], # list for consistency
style="dashed",
arrows=True,
arrowsize=25,
width=1.5,
ax=ax,
min_source_margin=20,
min_target_margin=20,
)
else:
nx.draw_networkx_edges(
G,
pos,
edgelist=[(u, v)],
edge_color=[edge_color],
style="dashed",
arrows=True,
arrowsize=25,
width=1.5,
ax=ax,
min_source_margin=20,
min_target_margin=20,
)
# Annotate edges with matching color
for u, v, key, data in G.edges(keys=True, data=True):
x1, y1 = pos[u]
x2, y2 = pos[v]
xm, ym = (x1 + x2) / 2, (y1 + y2) / 2
offset = (
-0.03
if (u, v) == (graph_feat_name_1, graph_feat_name_2)
else (0.03 if (u, v) == (graph_feat_name_2, graph_feat_name_1) else 0)
)
ax.text(
xm,
ym + offset,
f"{metric}={data['value']:.2f}",
fontsize=12,
ha="center",
va="center",
fontweight="bold",
bbox=dict(facecolor="white", edgecolor="none"),
color=data.get("color", "k"),
)
ax.set_axis_off()
# Custom title placement
fig.text(0.5, 0.55, title, ha="center", va="center", fontsize=12, fontweight="bold")
if save_path:
plt.savefig(save_path, dpi=300, pad_inches=0, bbox_inches="tight")
plt.close()
else:
plt.show()
[docs]
def plot_feature_learnability(
results: pd.DataFrame,
title: str = "Feature Learnability",
n_feats: int = 20,
save_path: str | None = None,
) -> None:
"""
Create a horizontal bar chart visualization of R² (mean ± std).
Parameters
----------
results : pd.DataFrame
Must contain 'R2_mean' and 'R2_std'.
title : str
Plot title.
n_feats: int
Number of top learned features to plot
save_path : str, optional
If provided, saves the figure.
"""
required_cols = {"R2_mean", "R2_std"}
if not required_cols.issubset(results.columns):
raise ValueError(f"Input results must contain columns: {required_cols}")
# Take top 20 features by R²
df = results.copy()
df = df.sort_values("R2_mean", ascending=True).tail(n_feats)
# Create figure
fig, ax = plt.subplots(figsize=(14, 10))
ax.barh(
df.index.astype(str),
df["R2_mean"],
xerr=df["R2_std"],
capsize=4,
color="#a6cee3",
ecolor="black",
edgecolor="black",
linewidth=0.1,
)
ax.set_xlabel("R² (mean ± std)", fontsize=14)
ax.set_title(title, size=14, pad=12, fontsize=18)
ax.grid(axis="x", linestyle="--", alpha=0.6)
ax.tick_params(axis="both", labelsize=14)
plt.tight_layout()
if save_path:
plt.savefig(save_path, bbox_inches="tight", dpi=300)
plt.close()
else:
plt.show()