Database overview
⚠️ Caution: Change environment
The following code needs to be run in a separate environment because the pymatviz package requires a newer minimum Python version. Please create a new environment from the provided pymatviz_env.yml before executing the following code blocks.
import pickle
import plotly.graph_objects as go
import pandas as pd
import plotly.express as px
import matplotlib
import matplotlib.pyplot as plt
from pathlib import Path
from collections import defaultdict
from multiprocessing import Pool, cpu_count
from tqdm.notebook import tqdm
from monty.serialization import loadfn, dumpfn
from monty.serialization import MontyDecoder
from pymatgen.core import Composition
from pymatgen.core.periodic_table import Element
from pymatviz.enums import ElemCountMode
from pymatviz.sunburst import spacegroup_sunburst
from pymatviz.ptable import ptable_heatmap_plotly
# Store fonts as TrueType (Type 42) in saved PDF files.
matplotlib.rcParams.update({"pdf.fonttype": 42})

# Placeholder locations: replace with the absolute paths of your local checkout.
structures_dir = "absolute/path/to/paper-ml-with-lobster-descriptors/data/structures/structures.json.gz"
target_parent_dir = "absolute/path/to/paper-ml-with-lobster-descriptors/data/targets/"

# Load every structure and turn the serialized dicts back into pymatgen objects.
# NOTE(review): the export loop further down reloads the structures via
# `get_structures`, so this eager load looks redundant unless other cells use it.
df_structs = pd.read_json(structures_dir)
df_structs["structure"] = df_structs["structure"].apply(MontyDecoder().process_decoded)
def get_structures(structures_dir:str, target_parent_dir:str, dataset_name:str= None):
"""
Get pymatgen structure objects based on data_name arg
Parameters
----------
structures_dir : str
Path to pymatgen structures dataframe json
target_parent_dir : pd.DataFrame
Path to targets dataframe jsons
dataset_name : str, optional
Name of target to filter structures
Returns
-------
pd.DataFrame
Dataframe with pymatgen structure objects
"""
df_structs = pd.read_json(structures_dir)
# restore pymatgen objects from dicts
df_structs["structure"] = df_structs["structure"].apply(MontyDecoder().process_decoded)
if dataset_name:
target_dir = Path(target_parent_dir) / f"{dataset_name}.json"
if target_dir.exists():
target_df = pd.read_json(target_dir)
return df_structs.loc[target_df.index]
else:
raise ValueError(f"No data for {dataset_name} found in {target_parent_dir}")
return df_structs
def classify_composition_oxi(comp, max_sites=None):
    """
    Classify a composition (oxides, halogenides, intermetallics, ...) from its oxidation-state guess.

    Parameters
    -----------
    comp: pymatgen composition object
        Composition to classify.
    max_sites: int, optional
        Forwarded to ``comp.oxi_state_guesses``; ``None`` keeps pymatgen's default
        behaviour. Can be set to bound the cost of guessing for large compositions.

    Returns
    --------
    tuple
        ``(main_type, sub_type, comp, oxi)`` where ``oxi`` is the most probable
        oxidation-state guess (a dict of element -> oxidation state), or an empty
        dict when pymatgen finds no charge-balanced guess.
    """
    elements = {el.symbol for el in comp.elements}
    metals = {el for el in elements if Element(el).is_metal}
    metalloids = {el for el in elements if Element(el).is_metalloid}
    non_metals = elements - metals - metalloids
    only_metals = bool(metals) and not non_metals

    guesses = comp.oxi_state_guesses(max_sites=max_sites)
    if not guesses:
        # No guess available: report an empty dict so the `oxi` field is always a dict.
        if len(elements) == 1:
            return "zero oxidation states", "element", comp, {}
        if only_metals:
            return "zero oxidation states", "intermetallics", comp, {}
        return "zero oxidation states", "other", comp, {}

    oxi = guesses[0]  # first guess is used as the most probable one
    oxi_values = set(oxi.values())

    if len(elements) == 1:
        # A charge-neutral single-element composition can only have zero oxidation states.
        assert oxi_values == {0}
        return "zero oxidation states", "element", comp, oxi

    if oxi_values == {0}:
        if only_metals:
            return "zero oxidation states", "intermetallics", comp, oxi
        return "zero oxidation states", "other", comp, oxi

    if only_metals:
        # Metals/metalloids only, but with non-zero oxidation states.
        return "polar intermetallics", "polar intermetallics", comp, oxi

    anions = {el for el, val in oxi.items() if val < 0}
    if len(anions) != 1:
        # Several anionic species (the zero-anion case is also routed here, as before).
        return "mixed anions", "mixed anions", comp, oxi

    (anion,) = anions
    if anion == "H":
        return "hydride", "hydride", comp, oxi
    if anion in HALOGENS:
        return "halogenides", anion, comp, oxi
    if anion == "O":
        # Float oxidation states compare equal to these keys (e.g. -2.0 == -2).
        oxygen_classes = {
            -2: "oxides",
            -1: "peroxides",
            -0.5: "superoxides",
        }
        return oxygen_classes.get(oxi[anion], "other oxygen compounds"), "O", comp, oxi
    if anion in CHALCOGENS:
        return "other chalcogenides", anion, comp, oxi
    if anion in PNICTIDES:
        return "pnictides", anion, comp, oxi
    if anion in GROUP14ANIONS:
        return "group 14 anions", anion, comp, oxi
    if anion in OTHER_ANIONS:
        if "H" in oxi and len(oxi) > 2:
            return "boron anion", "borohydride", comp, oxi
        if "H" not in oxi and len(oxi) == 2:
            return "boron anion", "borides", comp, oxi
    return "misc", "misc", comp, oxi
def build_dataframe(compositions:list[Composition], processes:int=None):
"""
Classify compositions in parallel and return concatenated dataframe
Parameters
-----------
compositions: list of pymatgen composition objects
"""
processes = processes or cpu_count()
rows = []
with Pool(processes=processes) as pool:
for main, sub, comp, oxi in tqdm(pool.imap_unordered(classify_composition_oxi, compositions),
total=len(compositions), desc="Processing compositions"):
rows.append({"main_type": main, "sub_type": sub, "comp": comp, "oxi": oxi})
return pd.DataFrame(rows)
def plot_sunburst(df, color_map, width=800, height=800, title=None):
    """
    Get composition classification sunburst plot.

    Parameters
    -----------
    df: pd.Dataframe
        Composition classification labels, output of `build_dataframe`; the
        "main_type" and "sub_type" columns form the inner and outer rings.
    color_map: dict
        Colour for each "main_type" label; unmapped labels get plotly's default colours.
    width: int
        Figure width in pixels.
    height: int
        Figure height in pixels.
    title: str, optional
        Figure title; no title by default.

    Returns
    --------
    plotly figure
        The sunburst figure produced by ``plotly.express.sunburst``.
    """
    fig = px.sunburst(
        df,
        path=["main_type", "sub_type"],
        title=title,
        width=width,
        height=height,
        color_discrete_map=color_map,
        color="main_type",
    )
    # Show each wedge's label together with its percentage of the whole dataset (root).
    fig.update_traces(textinfo="label+percent root", textfont_size=24, marker_line_width=0.1)
    return fig
def extract_structure_data(df_structures: pd.DataFrame):
    """
    Count atoms and distinct elements per structure and label each composition type.

    Parameters
    -----------
    df_structures: pd.Dataframe
        Dataframe whose "structure" column holds pymatgen structure objects.

    Returns
    --------
    dict
        Parallel lists, one entry per structure:
        - "n_atoms": number of sites (``len(structure)``)
        - "n_elements": number of distinct elements in the composition
        - "type": "Unary", "Binary", "Ternary", "Quaternary", or "Higher" otherwise
    """
    type_names = ("Unary", "Binary", "Ternary", "Quaternary")
    structures = list(df_structures.structure)

    atom_counts = [len(structure) for structure in structures]
    element_counts = [len(structure.composition.chemical_system_set) for structure in structures]
    # 1-4 distinct elements map to a named class; every other count is "Higher".
    labels = [type_names[count - 1] if 1 <= count <= 4 else "Higher" for count in element_counts]

    return {
        "n_atoms": atom_counts,
        "n_elements": element_counts,
        "type": labels,
    }
def plot_num_atoms_stacked_by_type_histogram(data, figsize=(10, 10), fontsize=14, bins=80):
    """
    Plot stacked histogram of number of atoms, stacked by composition type.

    Parameters
    -----------
    data: dict
        Output of `extract_structure_data` function ("n_atoms" and "type" lists).
    figsize: tuple(int, int)
        Figure size of matplotlib figure instance.
    fontsize: int
        Font size for axis labels and ticks; the legend uses ``fontsize - 6``
        (clamped to at least 1).
    bins: int
        Number of histogram bins.

    Returns
    --------
    matplotlib.figure.Figure
        The figure; the caller is responsible for saving/closing it.
    """
    color_map = {
        "Unary": "#3B82F6",       # blue
        "Binary": "#14B8A6",      # teal
        "Ternary": "#F59E0B",     # orange
        "Quaternary": "#EF4444",  # red
        "Higher": "#8B5CF6",      # purple
    }
    # Fixed order keeps the stacking and legend consistent across datasets.
    type_order = ["Unary", "Binary", "Ternary", "Quaternary", "Higher"]

    groups = defaultdict(list)
    for n_atoms, comp_type in zip(data["n_atoms"], data["type"]):
        groups[comp_type].append(n_atoms)
    ordered_types = [t for t in type_order if t in groups]

    fig, ax = plt.subplots(figsize=figsize)
    if ordered_types:
        # Only draw when there is data; hist/legend have nothing to show otherwise.
        ax.hist(
            [groups[t] for t in ordered_types],
            bins=bins,
            color=[color_map[t] for t in ordered_types],
            stacked=True,
            label=ordered_types,
        )
        legend_fontsize = max(fontsize - 6, 1)  # a non-positive font size is invalid
        ax.legend(title="Composition type", fontsize=legend_fontsize, title_fontsize=legend_fontsize)
    ax.set_xlabel("Number of atoms/unit cell", fontsize=fontsize)
    ax.set_ylabel("Count", fontsize=fontsize)
    ax.tick_params(axis="both", labelsize=fontsize)
    fig.tight_layout()
    return fig
# Colours for every "main_type" label that `classify_composition_oxi` can emit;
# labels missing here would otherwise receive arbitrary plotly default colours.
COLOR_MAP = {
    "oxides": "#e76f51",                  # soft red
    "peroxides": "#c44536",               # darker red (oxygen family)
    "superoxides": "#f4978e",             # light red (oxygen family)
    "other oxygen compounds": "#fbc4ab",  # pale red (oxygen family)
    "halogenides": "#2a9d8f",             # teal-green
    "hydride": "#e9ecef",                 # very light gray
    "pnictides": "#9d4edd",               # soft purple
    "other chalcogenides": "#f4a261",     # soft orange
    "boron anion": "#ffcb53",
    "mixed anions": "#48bfe3",            # light blue
    "group 14 anions": "#bc6c25",         # muted brown
    "polar intermetallics": "#577590",    # desaturated blue-gray
    "zero oxidation states": "#277da1",   # calm blue
    "misc": "#adb5bd",                    # neutral gray
}

# Element symbols used by `classify_composition_oxi` to categorize the single anion.
HALOGENS = {"F", "Cl", "Br", "I"}
CHALCOGENS = {"O", "S", "Se", "Te"}
PNICTIDES = {"N", "P", "As", "Sb", "Bi"}
GROUP14ANIONS = {"C", "Si", "Ge", "Sn"}
OTHER_ANIONS = {"B"}
Save plots
Save the periodic table heatmaps, the composition-classification and space-group sunburst plots, and the distributions of the number of atoms per composition type.
# Output file-name prefix for each target dataset; the ``None`` key stands for the full database.
file_name_prefixes = {
    "last_phdos_peak": "vibrational",
    "log_g_vrh": "elasticity",
    "log_klat_300": "anharmonic",
    None: "total",
}

# Datasets to process, in the same order as the mapping above.
dataset_names = list(file_name_prefixes)
# For each dataset: classify compositions and save the classification, periodic-table,
# space-group and number-of-atoms figures (PDF and PNG) under its file-name prefix.
for dataset_name in dataset_names:
    df_structs = get_structures(
        structures_dir=structures_dir,
        target_parent_dir=target_parent_dir,
        dataset_name=dataset_name,
    )
    structures = list(df_structs.structure)
    # Built once and reused for the classification and the periodic-table heatmap.
    compositions = [structure.composition for structure in structures]
    file_name_prefix = file_name_prefixes.get(dataset_name)

    # Compound classification: raw labels plus sunburst plot.
    df = build_dataframe(compositions=compositions)
    df.to_pickle(f"{file_name_prefix}_comp_class.pkl")
    fig_sb = plot_sunburst(df, color_map=COLOR_MAP)
    for ext in ("pdf", "png"):
        fig_sb.write_image(f"{file_name_prefix}_sunburst.{ext}", scale=2)

    # Periodic table heatmap of per-element occurrences (log colour scale).
    pt_fig = ptable_heatmap_plotly(
        values=compositions,
        heat_mode="value",
        log=True,
        count_mode=ElemCountMode.occurrence,
        fill_value=None,
    )
    pt_fig.update_traces(textfont_size=24)
    pt_fig.update_layout(paper_bgcolor="white", plot_bgcolor="white")
    for ext in ("pdf", "png"):
        pt_fig.write_image(f"{file_name_prefix}_ptable.{ext}", scale=2)

    # Space group distribution sunburst.
    spg_fig = spacegroup_sunburst(
        [structure.get_space_group_info()[0] for structure in structures],
        show_counts="percent",
    )
    spg_fig.update_traces(textfont_size=24, marker_line_width=0.1)
    spg_fig.update_layout(width=800, height=800, paper_bgcolor="white", plot_bgcolor="white")
    for ext in ("pdf", "png"):
        spg_fig.write_image(f"{file_name_prefix}_spg.{ext}", scale=2)

    # Number-of-atoms histogram stacked by composition type.
    hist_data = extract_structure_data(df_structures=df_structs)
    fig_hist = plot_num_atoms_stacked_by_type_histogram(data=hist_data, fontsize=24)
    for ext in ("pdf", "png"):
        fig_hist.savefig(f"{file_name_prefix}_num_atoms.{ext}", dpi=300)
    # Release the matplotlib figure; otherwise one stays open per dataset.
    plt.close(fig_hist)
Save ionicity distribution plots
Ionicity is calculated from the atomic charges obtained from LOBSTER calculations, using the formulation
\[I_{\text{Charges}}=\frac{1}{N_{\text{Atoms}}} \sum_{i=1}^{N_{\text{Atoms}}}\left(\frac{q_i}{v_{\text{eff},i}}\right)\]
introduced in R. Nelson, C. Ertural, P. C. Müller, R. Dronskowski, in Compr. Inorg. Chem. III, Elsevier, 2023, pp. 141–201, where \(q_i\) and \(v_{\text{eff},i}\) denote the atomic charge and the effective valence of atom \(i\) in a structure, respectively, and \(N_{\text{Atoms}}\) is the number of atoms in that structure.
# Reference compounds marked on every ionicity histogram: (mp-id, line colour, legend label).
_IONICITY_REFERENCES = (
    ("mp-1784", "red", "Ionic (mp-1784: CsF)"),
    ("mp-1487", "#f4a261", "Intermetallic (mp-1487: AlNi)"),
    ("mp-10597", "darkblue", "Metal (mp-10597: Ag)"),
)


def _plot_ionicity_histogram(df_ionicity, df_reference, column, xlabel, out_stem):
    """
    Plot one ionicity histogram with reference-compound markers and save it as PDF and PNG.

    Parameters
    -----------
    df_ionicity: pd.DataFrame
        Rows to histogram (the full ionicity table or a dataset subset of it).
    df_reference: pd.DataFrame
        Full ionicity table; reference compounds are looked up here so their markers
        are drawn even when they are not part of ``df_ionicity``.
    column: str
        Ionicity column to plot, e.g. "Ionicity_Mull" or "Ionicity_Loew".
    xlabel: str
        X-axis label.
    out_stem: str
        Output path without extension; ".pdf" and ".png" are appended.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.hist(df_ionicity[column].values, bins=100, color="#a6cee3")
    ax.set_xlabel(xlabel, fontsize=24)
    ax.set_ylabel("Counts", fontsize=24)
    ax.tick_params(axis="both", labelsize=24)
    for mp_id, color, label in _IONICITY_REFERENCES:
        ax.axvline(
            df_reference.loc[mp_id, column],
            color=color,
            linestyle="dashed",
            linewidth=2,
            label=label,
        )
    ax.legend(loc="upper center", fontsize=18, title_fontsize=18, title="Compound type")
    fig.tight_layout()
    fig.savefig(f"{out_stem}.pdf", dpi=300)
    fig.savefig(f"{out_stem}.png", dpi=300)
    plt.close(fig)


# Pre-computed ionicity data calculated from LOBSTER charges.
df_ionicity_ref = pd.read_json("ionicity_data.json")
for dataset_name in dataset_names:
    if dataset_name:
        target_df_path = Path(target_parent_dir) / f"{dataset_name}.json"
        if not target_df_path.exists():
            raise ValueError(f"No data for {dataset_name} found in {target_parent_dir}")
        target_df = pd.read_json(target_df_path)
        df_ionicity = df_ionicity_ref.loc[target_df.index]
    else:
        df_ionicity = df_ionicity_ref
    file_name_prefix = file_name_prefixes.get(dataset_name)

    _plot_ionicity_histogram(
        df_ionicity,
        df_ionicity_ref,
        column="Ionicity_Mull",
        xlabel="Ionicity (Mulliken charges)",
        out_stem=f"{file_name_prefix}_ionicity_mull",
    )
    _plot_ionicity_histogram(
        df_ionicity,
        df_ionicity_ref,
        column="Ionicity_Loew",
        xlabel="Ionicity (Loewdin charges)",
        out_stem=f"{file_name_prefix}_ionicity_loew",
    )