Database overview

Contents

Database overview#

⚠️ Caution: Change environment

The following code snippet needs to be run in a separate environment due to the minimum Python version requirement of the pymatviz package. Please use the provided pymatviz_env.yml to create a new environment before executing the following code blocks.

import pickle
import plotly.graph_objects as go
import pandas as pd
import plotly.express as px
import matplotlib
import matplotlib.pyplot as plt

from pathlib import Path
from collections import defaultdict
from multiprocessing import Pool, cpu_count
from tqdm.notebook import tqdm
from monty.serialization import loadfn, dumpfn
from monty.serialization import MontyDecoder
from pymatgen.core import Composition
from pymatgen.core.periodic_table import Element
from pymatviz.enums import ElemCountMode
from pymatviz.sunburst import spacegroup_sunburst
from pymatviz.ptable import ptable_heatmap_plotly
# Font type 42 makes matplotlib embed TrueType fonts in PDF output
# (see the matplotlib rcParams documentation for `pdf.fonttype`).
matplotlib.rcParams['pdf.fonttype'] = 42
# Placeholder paths: replace with the absolute locations on your machine.
# NOTE: despite its name, `structures_dir` points to the gzipped JSON *file*.
structures_dir = "absolute/path/to/paper-ml-with-lobster-descriptors/data/structures/structures.json.gz"
target_parent_dir = "absolute/path/to/paper-ml-with-lobster-descriptors/data/targets/"
df_structs = pd.read_json(structures_dir)

# Restore the objects stored as dicts in the "structure" column
# (presumably pymatgen Structure objects; depends on the file contents).
df_structs["structure"] = df_structs["structure"].apply(
        MontyDecoder().process_decoded
    )
def get_structures(structures_dir:str, target_parent_dir:str, dataset_name:str= None):
    """
    Load the structures dataframe, optionally restricted to one target dataset.

    Parameters
    ----------
    structures_dir : str
        Path to the JSON file with the serialized structures dataframe
        (must contain a ``structure`` column).
    target_parent_dir : str
        Directory containing the ``<dataset_name>.json`` target dataframes.
    dataset_name : str, optional
        Name of the target dataset used to filter the structures. If omitted
        (or empty), all structures are returned.

    Returns
    -------
    pd.DataFrame
        Dataframe whose ``structure`` column holds the decoded objects. When
        ``dataset_name`` is given, only rows whose index appears in the
        target dataframe are returned.

    Raises
    ------
    ValueError
        If ``dataset_name`` is given but ``<dataset_name>.json`` does not
        exist in ``target_parent_dir``.
    """
    target_path = None
    if dataset_name:
        target_path = Path(target_parent_dir) / f"{dataset_name}.json"
        # Fail fast: validate the request before the expensive load/decoding below.
        if not target_path.exists():
            raise ValueError(f"No data for {dataset_name} found in {target_parent_dir}")

    df_structs = pd.read_json(structures_dir)

    if target_path is not None:
        target_df = pd.read_json(target_path)
        # Filter first so that only the requested rows get decoded.
        df_structs = df_structs.loc[target_df.index]

    # restore pymatgen objects from dicts
    decoded = df_structs["structure"].apply(MontyDecoder().process_decoded)

    # `assign` returns a new frame, avoiding chained-assignment warnings on the slice.
    return df_structs.assign(structure=decoded)
    
def classify_composition_oxi(comp):
    """
    Classify a composition by its guessed oxidation states (oxides, halogenides, ...).

    Parameters
    -----------
    comp: pymatgen composition object

    Returns
    --------
    tuple
        ``(main_type, sub_type, comp, oxi)`` where ``oxi`` is the first
        oxidation state guess, or the empty guess container if none was found
    """
    zero_label = "zero oxidation states"

    symbols = {element.symbol for element in comp.elements}
    metals = {symbol for symbol in symbols if Element(symbol).is_metal}
    metalloids = {symbol for symbol in symbols if Element(symbol).is_metalloid}
    non_metals = symbols - metals - metalloids

    single_element = len(symbols) == 1
    # True when at least one metal is present and every other element is a metal or metalloid
    metallic_only = bool(metals) and not non_metals

    guesses = comp.oxi_state_guesses()  # TODO max_sites to e.g. 200 if needed

    # No charge-balanced guess available: classify by element types alone
    if not guesses:
        if single_element:
            fallback_sub_type = "element"
        elif metallic_only:
            fallback_sub_type = "intermetallics"
        else:
            fallback_sub_type = "other"
        return zero_label, fallback_sub_type, comp, guesses

    oxi = guesses[0]
    oxi_values = set(oxi.values())

    if single_element:
        assert oxi_values == {0}
        return zero_label, "element", comp, oxi

    if oxi_values == {0}:
        neutral_sub_type = "intermetallics" if metallic_only else "other"
        return zero_label, neutral_sub_type, comp, oxi

    # From here on at least one oxidation state is non-zero
    if metallic_only:
        return "polar intermetallics", "polar intermetallics", comp, oxi

    anions = {symbol for symbol, state in oxi.items() if state < 0}

    # Anything other than exactly one anionic element counts as mixed anions
    if len(anions) != 1:
        return "mixed anions", "mixed anions", comp, oxi

    anion = next(iter(anions))

    if anion == "H":
        return "hydride", "hydride", comp, oxi

    if anion in HALOGENS:
        return "halogenides", anion, comp, oxi

    # Oxygen must be handled before the generic chalcogen check below
    if anion == "O":
        oxygen_classes = {-2: "oxides", -1: "peroxides", -0.5: "superoxides"}
        oxygen_label = oxygen_classes.get(oxi[anion], "other oxygen compounds")
        return oxygen_label, "O", comp, oxi

    # Checked in this order: chalcogens, pnictogens, group 14
    anion_families = (
        (CHALCOGENS, "other chalcogenides"),
        (PNICTIDES, "pnictides"),
        (GROUP14ANIONS, "group 14 anions"),
    )
    for family, family_label in anion_families:
        if anion in family:
            return family_label, anion, comp, oxi

    if anion in OTHER_ANIONS:
        contains_hydrogen = "H" in oxi
        if contains_hydrogen and len(oxi) > 2:
            return "boron anion", "borohydride", comp, oxi
        if not contains_hydrogen and len(oxi) == 2:
            return "boron anion", "borides", comp, oxi

    return "misc", "misc", comp, oxi
def build_dataframe(compositions:list[Composition], processes:int=None):
    """
    Classify compositions in parallel and return the results as one dataframe

    Parameters
    -----------
    compositions: list of pymatgen composition objects
    processes: int, optional
        Number of worker processes. Defaults to the number of CPUs and is
        capped at the number of compositions.

    Returns
    --------
    pd.DataFrame
        One row per composition with the columns ``main_type``, ``sub_type``,
        ``comp`` and ``oxi``. Rows arrive in completion order, which need not
        match the input order. An empty input yields an empty dataframe that
        still has these columns.
    """
    columns = ["main_type", "sub_type", "comp", "oxi"]

    # Nothing to classify: skip spawning workers but keep the expected schema
    if len(compositions) == 0:
        return pd.DataFrame(columns=columns)

    # More workers than tasks would only add start-up cost
    n_workers = min(processes or cpu_count(), len(compositions))

    with Pool(processes=n_workers) as pool:
        results = tqdm(pool.imap_unordered(classify_composition_oxi, compositions),
                       total=len(compositions), desc="Processing compositions")
        # Consume the iterator inside the `with` block, while the pool is alive
        rows = [dict(zip(columns, result)) for result in results]

    return pd.DataFrame(rows, columns=columns)
def plot_sunburst(df, color_map):
    """
    Build the two-level sunburst plot of the composition classification

    Parameters
    -----------
    df: pd.Dataframe
        Classification labels with `main_type` and `sub_type` columns, as returned by `build_dataframe`
    color_map: dict
        Colour per `main_type` label

    Returns
    --------
    The plotly figure created by `px.sunburst`
    """
    sunburst_options = dict(
        path=["main_type", "sub_type"],
        width=800,
        height=800,
        color="main_type",
        color_discrete_map=color_map,
    )
    figure = px.sunburst(df, **sunburst_options)

    # Each sector shows its label and its share of the whole dataset
    figure.update_traces(
        textinfo="label+percent root",
        textfont_size=24,
        marker_line_width=0.1,
    )

    return figure
def extract_structure_data(df_structures:pd.DataFrame):
    """
    Collect atom counts, element counts and composition type labels per structure.

    Parameters
    -----------
    df_structures: pd.Dataframe
        Dataframe with a `structure` column of pymatgen structure objects

    Returns
    --------
    dict
        Three lists of equal length, in row order:
            - n_atoms (list[int]): number of sites in each structure
            - n_elements (list[int]): number of distinct elements
            - type (list[str]): "Unary" to "Quaternary", otherwise "Higher"
    """
    # Position k holds the label for k + 1 distinct elements
    type_names = ("Unary", "Binary", "Ternary", "Quaternary")

    atom_counts = []
    element_counts = []
    type_labels = []

    for structure in df_structures.structure:
        element_count = len(structure.composition.chemical_system_set)
        if 1 <= element_count <= len(type_names):
            type_label = type_names[element_count - 1]
        else:
            type_label = "Higher"

        atom_counts.append(len(structure))
        element_counts.append(element_count)
        type_labels.append(type_label)

    return {
        "n_atoms": atom_counts,
        "n_elements": element_counts,
        "type": type_labels,
    }
def plot_num_atoms_stacked_by_type_histogram(data, figsize=(10, 10), fontsize=14):
    """
    Draw a histogram of the number of atoms per structure,
    with bars stacked by composition type.

    Parameters
    -----------
    data: dict
        Output of `extract_structure_data` function (uses the `n_atoms` and `type` lists)
    figsize: tuple(int, int)
        Figure size of matplotlib figure instance
    fontsize: int
        Fontsize for axis labels and ticks; the legend uses `fontsize - 6`

    Returns
    --------
    The matplotlib figure holding the histogram
    """
    # Insertion order doubles as the stacking and legend order
    palette = {
        "Unary": "#3B82F6",       # blue
        "Binary": "#14B8A6",      # teal
        "Ternary": "#F59E0B",     # orange
        "Quaternary": "#EF4444",  # red
        "Higher": "#8B5CF6",      # purple
    }

    # Bucket the atom counts per known type; labels outside the palette are not plotted
    atoms_by_type = {label: [] for label in palette}
    for atom_count, label in zip(data["n_atoms"], data["type"]):
        if label in atoms_by_type:
            atoms_by_type[label].append(atom_count)

    # Only types that actually occur get a histogram layer and legend entry
    present_types = [label for label, counts in atoms_by_type.items() if counts]
    stacked_values = [atoms_by_type[label] for label in present_types]
    stacked_colors = [palette[label] for label in present_types]

    fig = plt.figure(figsize=figsize)
    plt.hist(
        stacked_values,
        bins=80,
        color=stacked_colors,
        stacked=True,
        label=present_types,
    )

    legend_fontsize = fontsize - 6
    plt.xlabel("Number of atoms/unit cell", fontsize=fontsize)
    plt.ylabel("Count", fontsize=fontsize)
    plt.legend(
        title="Composition type",
        fontsize=legend_fontsize,
        title_fontsize=legend_fontsize,
    )
    plt.xticks(fontsize=fontsize)
    plt.yticks(fontsize=fontsize)
    plt.tight_layout()
    return fig
# Colour per main classification label, passed to `plot_sunburst` as `color_map`.
# NOTE(review): `classify_composition_oxi` can also return "peroxides",
# "superoxides", "other oxygen compounds" and "misc", which have no entry here;
# presumably plotly falls back to its default colours for them. Confirm.
COLOR_MAP = {
    "oxides": "#e76f51",                 # soft red 
    "halogenides": "#2a9d8f",            # teal-green 
    "hydride": "#e9ecef",                # very light gray 
    "pnictides": "#9d4edd",              # soft purple 
    "other chalcogenides": "#f4a261",    # soft orange
    "boron anion": "#ffcb53",            # yellow
    "mixed anions": "#48bfe3",           # light blue 
    "group 14 anions": "#bc6c25",        # muted brown 
    "polar intermetallics": "#577590",   # desaturated blue-gray 
    "zero oxidation states": "#277da1",  # calm blue
}


# Element-symbol sets used by `classify_composition_oxi` to name the anion family.
HALOGENS = {"F", "Cl", "Br", "I"}
# "O" is listed here but is handled separately (before this set is checked) by the classifier.
CHALCOGENS = {"O", "S", "Se", "Te"}
PNICTIDES = {"N", "P", "As", "Sb", "Bi"}
GROUP14ANIONS = {"C", "Si", "Ge", "Sn"}
# Boron; split by the classifier into "borohydride" and "borides".
OTHER_ANIONS = {"B"}

Save plots#

Save the periodic table heatmap, the composition classification and space group sunburst plots, and the distribution of the number of atoms per composition type.

# Target datasets to analyse; None selects the complete structure database.
dataset_names = ["last_phdos_peak",
                 "log_g_vrh",
                 "log_klat_300",
                 None]
# Prefix of every output file written for the corresponding dataset.
file_name_prefixes = {
    "last_phdos_peak": "vibrational",
    "log_g_vrh": "elasticity",
    "log_klat_300": "anharmonic",
    None : "total"
    }
for dataset_name in dataset_names:
    df_structs = get_structures(structures_dir=structures_dir, target_parent_dir=target_parent_dir,dataset_name=dataset_name)
    # Computed once and reused for both the classification and the periodic table heatmap
    compositions = [i.composition for i in df_structs.structure]

    # Index instead of `.get` so an unmapped dataset fails loudly rather than writing "None_*" files
    file_name_prefix = file_name_prefixes[dataset_name]

    df = build_dataframe(compositions=compositions)
    df.to_pickle(f"{file_name_prefix}_comp_class.pkl") # save compound classification data

    # sunburst plot for compound classification
    fig_sb = plot_sunburst(df, color_map=COLOR_MAP)
    for extension in ("pdf", "png"):
        fig_sb.write_image(f"{file_name_prefix}_sunburst.{extension}", scale=2)

    # periodic table heatmap as per element occurrences
    pt_fig = ptable_heatmap_plotly(values=compositions, heat_mode="value", log=True,
                      count_mode=ElemCountMode.occurrence, fill_value=None)
    pt_fig.update_traces(textfont_size=24)
    pt_fig.update_layout(paper_bgcolor="white",plot_bgcolor="white")
    for extension in ("pdf", "png"):
        pt_fig.write_image(f"{file_name_prefix}_ptable.{extension}", scale=2)

    # space group distribution figure
    spg_fig = spacegroup_sunburst([i.get_space_group_info()[0] for i in df_structs.structure], show_counts="percent")
    spg_fig.update_traces(textfont_size=24, marker_line_width=0.1)
    spg_fig.update_layout(width=800,height=800, paper_bgcolor="white",plot_bgcolor="white")
    for extension in ("pdf", "png"):
        spg_fig.write_image(f"{file_name_prefix}_spg.{extension}", scale=2)

    # number of atoms distribution, stacked by composition type
    hist_data = extract_structure_data(df_structures=df_structs)

    fig_hist = plot_num_atoms_stacked_by_type_histogram(data=hist_data, fontsize=24)
    for extension in ("pdf", "png"):
        fig_hist.savefig(f"{file_name_prefix}_num_atoms.{extension}", dpi=300)
    # Release the figure; pyplot otherwise keeps one open figure per dataset
    plt.close(fig_hist)

Save ionicity distribution plots

Ionicity is calculated using atomic charges obtained from LOBSTER calculations using the formulation

\[I_{\text {Charges }}=\frac{1}{N_{\text {Atoms }}} \sum_i^{N_{\text {Atoms }}}\left(\frac{q_i}{v_{\text {eff }, i}}\right)\]

introduced in R. Nelson, C. Ertural, P. C. Müller, R. Dronskowski, in Compr. Inorg. Chem. III, Elsevier, 2023, pp. 141–201, where \(q_i\) and \(v_{\text{eff}, i}\) denote the atomic charge and the effective valence of atom \(i\) in a structure, respectively.

df_ionicity_ref = pd.read_json("ionicity_data.json") # read pre-computed ionicity data calculated from Lobster charges data

# (dataframe column, name used in the axis label, output file suffix) per charge scheme
charge_schemes = (
    ("Ionicity_Mull", "Mulliken", "mull"),
    ("Ionicity_Loew", "Loewdin", "loew"),
)

# (index label in df_ionicity_ref, line colour, legend label) of the reference compounds
reference_compounds = (
    ("mp-1784", "red", "Ionic (mp-1784: CsF)"),
    ("mp-1487", "#f4a261", "Intermetallic (mp-1487: AlNi)"),
    ("mp-10597", "darkblue", "Metal (mp-10597: Ag)"),
)

for dataset_name in dataset_names:
    if dataset_name:
        target_df_path = Path(target_parent_dir) / f"{dataset_name}.json"

        if not target_df_path.exists():
            raise ValueError(f"No data for {dataset_name} found in {target_parent_dir}")

        # Restrict the ionicity data to the entries of this target dataset
        target_df = pd.read_json(target_df_path)
        df_ionicity = df_ionicity_ref.loc[target_df.index]
    else:
        df_ionicity = df_ionicity_ref

    # Index instead of `.get` so an unmapped dataset fails loudly rather than writing "None_*" files
    file_name_prefix = file_name_prefixes[dataset_name]

    for column, scheme_name, file_suffix in charge_schemes:
        fig = plt.figure(figsize=(10, 10))
        plt.hist(df_ionicity[column].values, bins=100, color="#a6cee3")
        plt.xlabel(
            f"Ionicity ({scheme_name} charges)",
            fontsize=24,
        )
        plt.ylabel("Counts", fontsize=24)
        plt.xticks(fontsize=24)
        plt.yticks(fontsize=24)

        # Reference lines are read from the full table, so they are drawn
        # even when a reference compound is not part of the current subset
        for mp_id, line_color, legend_label in reference_compounds:
            plt.axvline(df_ionicity_ref.loc[mp_id, column], color=line_color,
                        linestyle="dashed", linewidth=2, label=legend_label)

        plt.legend(loc="upper center", fontsize=18, title_fontsize=18, title="Compound type")
        plt.tight_layout()
        for extension in ("pdf", "png"):
            plt.savefig(f"{file_name_prefix}_ionicity_{file_suffix}.{extension}", dpi=300)
        plt.close(fig)