Mean squared displacement data extraction and convergence tests

Mean squared displacement data extraction and convergence tests

import pickle
import os
import pandas as pd
import numpy as np
import phonopy
import logging
from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt
from phonopy.structure.grid_points import length2mesh
# === Set up logger ===
# Log to a file (overwritten on every run, filemode='w') so long runs can be monitored.
logging.basicConfig(
    filename='mesh_convergence.log',
    filemode='w',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger()

# Read the supercell matrix information saved when the `convert_ddb_to_phonopy.ipynb` script was run.
# Rows are looked up by mpid and the "sc_matrix" column is used in the convergence loop.
sc_df = pd.read_json("/path/to/sc_matrix.json")

# One material id per line. strip() also removes '\r' from Windows line endings,
# and blank lines are skipped so they do not become bogus (empty) mpids.
with open("msd_stable_mpids.txt", "r", encoding="utf-8") as fh:
    stable_mpids = [line.strip() for line in fh if line.strip()]

# Output directory for the per-material thermal property / displacement YAML files.
BASE_RESULTS_DIR = "msd_convergence_results/yamls/"
# exist_ok=True so that re-running the script does not fail on an existing directory.
os.makedirs(BASE_RESULTS_DIR, exist_ok=True)

# adjust this path accordingly to location where outputs from `convert_ddb_to_phonopy.ipynb` scripts are stored
# Directory should consist of POSCAR, FORCE_CONSTANTS and BORN file for each mpid. See files in path below for reference
BASE_PHONOPY_FC_DIR = "example_phonon_db_files/phonopy_fc"

# NOTE: this overrides the list read from msd_stable_mpids.txt above so that only the
# example material is processed. Remove this line to run all materials.
stable_mpids = ["mp-66"]
⚠️ Caution: Long Runtime

Running this script can take a very long time — probably more than 24 hours on an 8-core system.

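The code snippet below runs the phonopy MSD calculation for each material at 300 K and 600 K. It scans mesh length parameters from 40 to 240 in increments of 20; phonopy's `length2mesh` converts each one to a mesh grid. Different cutoff frequencies are also tested, i.e., 0.0, 0.1 and 0.13 THz.

A cutoff is considered converged once two conditions hold. First, the zero-point energy changes by less than 0.01 % over each of the last three mesh steps. Second, every site's mean MSD changes by less than 0.001 Å² between the last two meshes. Larger meshes are then skipped for that cutoff.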

The end result is a Python dictionary. Below is an example of the result dictionary for material ID mp-66:

{'mp-66': {0.0: {'mesh': [40, 60],
   'mesh_array': [array([19, 19, 19]), array([29, 29, 29])],
   'zpe': [34.957745672815214, 34.9578078521624],
   'tdisp': [array([[[0.00180217, 0.00180217, 0.00180217],
            [0.00180217, 0.00180217, 0.00180217]],
    
           [[0.00241203, 0.00241203, 0.00241203],
            [0.00241203, 0.00241203, 0.00241203]]]),
    array([[[0.00181413, 0.00181413, 0.00181413],
            [0.00181413, 0.00181413, 0.00181413]],
    
           [[0.00243585, 0.00243585, 0.00243585],
            [0.00243585, 0.00243585, 0.00243585]]])],
   'tdisp_sites_mean': {300.0: {'C1': [0.0018021732363467302,
      0.0018141306382915814],
     'C2': [0.0018021732363467334, 0.0018141306382915866]},
    600.0: {'C1': [0.002412028148042243, 0.0024358547783380423],
     'C2': [0.0024120281480422513, 0.0024358547783380467]}},
   'tdisp_temperatures': [array([300., 600.]), array([300., 600.])],
   'is_converged': False,
   'converged_at': None},
  0.1: {'mesh': [40, 60],
   'mesh_array': [array([19, 19, 19]), array([29, 29, 29])],
   'zpe': [34.957745672815214, 34.9578078521624],
   'tdisp': [array([[[0.00180217, 0.00180217, 0.00180217],
            [0.00180217, 0.00180217, 0.00180217]],
    
           [[0.00241203, 0.00241203, 0.00241203],
            [0.00241203, 0.00241203, 0.00241203]]]),
    array([[[0.00181413, 0.00181413, 0.00181413],
            [0.00181413, 0.00181413, 0.00181413]],
    
           [[0.00243585, 0.00243585, 0.00243585],
            [0.00243585, 0.00243585, 0.00243585]]])],
   'tdisp_sites_mean': {300.0: {'C1': [0.0018021732363467302,
      0.0018141306382915814],
     'C2': [0.0018021732363467334, 0.0018141306382915866]},
    600.0: {'C1': [0.002412028148042243, 0.0024358547783380423],
     'C2': [0.0024120281480422513, 0.0024358547783380467]}},
   'tdisp_temperatures': [array([300., 600.]), array([300., 600.])],
   'is_converged': False,
   'converged_at': None},
  0.13: {'mesh': [40, 60],
   'mesh_array': [array([19, 19, 19]), array([29, 29, 29])],
   'zpe': [34.957745672815214, 34.9578078521624],
   'tdisp': [array([[[0.00180217, 0.00180217, 0.00180217],
            [0.00180217, 0.00180217, 0.00180217]],
    
           [[0.00241203, 0.00241203, 0.00241203],
            [0.00241203, 0.00241203, 0.00241203]]]),
    array([[[0.00181413, 0.00181413, 0.00181413],
            [0.00181413, 0.00181413, 0.00181413]],
    
           [[0.00243585, 0.00243585, 0.00243585],
            [0.00243585, 0.00243585, 0.00243585]]])],
   'tdisp_sites_mean': {300.0: {'C1': [0.0018021732363467302,
      0.0018141306382915814],
     'C2': [0.0018021732363467334, 0.0018141306382915866]},
    600.0: {'C1': [0.002412028148042243, 0.0024358547783380423],
     'C2': [0.0024120281480422513, 0.0024358547783380467]}},
   'tdisp_temperatures': [array([300., 600.]), array([300., 600.])],
   'is_converged': False,
   'converged_at': None}}}
logger.info("Started convergence checks")

convergence_data = {}
cutoff_freqs = [0.0, 0.1, 0.13]
temperatures = [300.0, 600.0]
# Mesh length parameters passed to phonopy's length2mesh. range() excludes its end
# value, so 260 is used to include 240 in the scan (40, 60, ..., 240).
mesh_sizes = list(range(40, 260, 20))

# Convergence thresholds
ZPE_TOL_PERCENT = 0.01  # max |relative ZPE change| in %, required for each of the last three mesh steps
MSD_TOL = 0.001  # max |change| of every site's mean MSD between the last two meshes


def _new_cutoff_record():
    """Return an empty result record for one (mpid, cutoff) pair."""
    return {
        "mesh": [],
        "mesh_array": [],
        "zpe": [],
        "tdisp": [],
        "tdisp_sites_mean": {t: {} for t in temperatures},
        "tdisp_temperatures": [],
        "is_converged": False,
        "converged_at": None,
    }


def _is_converged(data):
    """Return True when both ZPE and per-site mean MSD have stabilised for one cutoff.

    Needs at least four mesh points, i.e. three consecutive ZPE changes. Then:
    - each of the last three relative ZPE changes must be below ZPE_TOL_PERCENT (%), and
    - the last absolute change of every site's mean MSD, at every temperature,
      must be below MSD_TOL.
    """
    if len(data["zpe"]) <= 3:
        return False

    zpe = np.asarray(data["zpe"], dtype=float)
    percent_change_zpe = np.absolute(np.diff(zpe) / zpe[:-1]) * 100
    zpe_ok = np.all(percent_change_zpe[-3:] < ZPE_TOL_PERCENT)

    # Every per-site series has one entry per sampled mesh, so [-1] - [-2] is the
    # change between the last two meshes.
    last_msd_changes = [
        abs(series[-1] - series[-2])
        for per_site in data["tdisp_sites_mean"].values()
        for series in per_site.values()
    ]
    msd_ok = np.all(np.array(last_msd_changes) < MSD_TOL)
    return bool(zpe_ok and msd_ok)


for mpid in tqdm(stable_mpids):
    convergence_data[mpid] = {}

    # Track convergence status per cutoff
    convergence_flags = {cutoff: False for cutoff in cutoff_freqs}

    fc_dir = f"{BASE_PHONOPY_FC_DIR}/{mpid}"
    out_dir = os.path.join(BASE_RESULTS_DIR, mpid)
    os.makedirs(out_dir, exist_ok=True)

    # The structure and force constants do not depend on the mesh, so load them once
    # per material. run_mesh() below replaces the sampling mesh on every iteration.
    phonon = phonopy.load(
        supercell_matrix=np.eye(3) * sc_df.loc[mpid, "sc_matrix"],
        primitive_matrix=np.eye(3),
        unitcell_filename=f"{fc_dir}/POSCAR",
        is_symmetry=False,
        force_constants_filename=f"{fc_dir}/FORCE_CONSTANTS",
        is_nac=True,
        born_filename=f"{fc_dir}/BORN",
    )
    primitive = phonon.primitive
    symbols = primitive.symbols
    prim_symmetry = phonon.primitive_symmetry
    rots = prim_symmetry.pointgroup_operations if prim_symmetry is not None else None

    for mesh in mesh_sizes:
        # If all cutoffs converged, stop looping
        if all(convergence_flags.values()):
            break

        mesh_nums = length2mesh(mesh, primitive.cell, rotations=rots)

        phonon.run_mesh(
            mesh_nums,
            with_eigenvectors=True,
            is_gamma_center=False,
            with_group_velocities=False,
            is_time_reversal=True,
            is_mesh_symmetry=False
        )

        # Now loop over cutoff frequencies
        for cutoff in cutoff_freqs:
            if convergence_flags[cutoff]:
                continue  # Skip if already converged

            data = convergence_data[mpid].setdefault(cutoff, _new_cutoff_record())

            phonon.run_thermal_properties(temperatures=temperatures, cutoff_frequency=cutoff)
            phonon.thermal_properties.write_yaml(filename=f"{out_dir}/thermal_prop_{mesh}_{cutoff}.yaml")
            zpe = phonon.thermal_properties.zero_point_energy

            phonon.run_thermal_displacements(temperatures=temperatures, freq_min=cutoff)
            phonon.thermal_displacements.write_yaml(filename=f"{out_dir}/thermal_disp_{mesh}_{cutoff}.yaml")
            # Reshape to (n_temperatures, n_atoms_in_primitive, 3)
            td = phonon.thermal_displacements.thermal_displacements.reshape(len(temperatures), len(symbols), 3)

            # Append data
            data["mesh"].append(mesh)
            data["mesh_array"].append(mesh_nums)
            data["zpe"].append(zpe)
            data["tdisp"].append(td)
            data["tdisp_temperatures"].append(phonon.thermal_displacements.temperatures)

            # Per-site mean over x, y, z for each temperature; sites are named symbol + 1-based index
            for ele_i, ele in enumerate(symbols):
                site_name = f"{ele}{ele_i + 1}"
                for temp_ix, temp in enumerate(temperatures):
                    data["tdisp_sites_mean"][temp].setdefault(site_name, []).append(td[temp_ix][ele_i].mean())

            if _is_converged(data):
                msg = f"{mpid} | Cutoff {cutoff} : Converged at mesh: {mesh} i.e., {mesh_nums}"
                logger.info(msg)
                data["is_converged"] = True
                data["converged_at"] = mesh
                convergence_flags[cutoff] = True

    # Make non-converged materials visible in the log instead of failing silently.
    for cutoff, converged in convergence_flags.items():
        if not converged:
            logger.warning("%s | Cutoff %s : not converged up to mesh %s", mpid, cutoff, mesh_sizes[-1])
logger.info("All convergence checks completed.")

# Persist the complete results dictionary so plots/analysis can be redone without recomputing.
with open('msd_convergence_results/convergence_data.pkl', 'wb') as pkl_out:
    pickle.dump(convergence_data, pkl_out)

# Output directory for the per-site MSD bar charts.
BASE_RESULTS_DIR_BAR = "msd_convergence_results/msd_bar/"
os.makedirs("msd_convergence_results/msd_bar", exist_ok=True)
def plot_msd_per_site_temperatures(convergence_data, cutoff_freqs, temperatures, BASE_RESULTS_DIR=BASE_RESULTS_DIR_BAR):
    """
    Visualize MSD per site and temperature using the convergence data for different mesh sizes and cutoff freq.

    One figure per material id is saved to ``{BASE_RESULTS_DIR}/msd_{mpid}.png``. The
    figure is a grid of subplots, with one row per site and one column per temperature.
    Each subplot shows grouped bars of the per-site mean MSD for every mesh sampled by
    any cutoff, with one bar per cutoff. Meshes that a cutoff did not sample are drawn
    as NaN (empty).

    Parameters
    ----------
    convergence_data : dict
        ``{mpid: {cutoff: record}}`` as built by the convergence loop.
        ``record["mesh"]`` lists the sampled meshes, and
        ``record["tdisp_sites_mean"][temp][site]`` holds the per-site means in the same order.
    cutoff_freqs : sequence of float
        Cutoff frequencies to plot; cutoffs without data are skipped.
    temperatures : sequence of float
        Temperatures to plot (one subplot column each).
    BASE_RESULTS_DIR : str
        Directory the PNG files are written to.
    """
    colors = ['tab:blue', 'tab:orange', 'tab:green']
    # Keep each group of bars narrower than the spacing between mesh ticks (1.0),
    # even when more cutoffs are passed.
    bar_width = min(0.2, 0.8 / max(len(cutoff_freqs), 1))

    def _site_values(data, temp, site, all_meshes):
        """Per-site mean MSD aligned to ``all_meshes``, NaN where not sampled."""
        series = data["tdisp_sites_mean"].get(temp, {}).get(site, [])
        position = {mesh: m for m, mesh in enumerate(data["mesh"])}
        values = []
        for mesh in all_meshes:
            m = position.get(mesh)
            values.append(series[m] if m is not None and m < len(series) else np.nan)
        return values

    for mpid, per_cutoff in convergence_data.items():
        # (position in cutoff_freqs, cutoff, record) for every cutoff that has data;
        # the position fixes bar offset and colour.
        records = [(k, cutoff, per_cutoff[cutoff])
                   for k, cutoff in enumerate(cutoff_freqs) if per_cutoff.get(cutoff)]

        # Collect all unique site names for this material
        sites = sorted({site
                        for _, _, data in records
                        for temp in temperatures
                        for site in data["tdisp_sites_mean"].get(temp, {})})
        if not sites:
            # Nothing to plot (plt.subplots would fail with zero rows).
            continue

        # Collect all mesh points used overall (union)
        all_meshes = sorted({mesh for _, _, data in records for mesh in data["mesh"]})
        index = np.arange(len(all_meshes))

        n_sites = len(sites)
        n_temps = len(temperatures)
        fig, axes = plt.subplots(n_sites, n_temps, figsize=(5 * n_temps, 4 * n_sites), squeeze=False)
        fig.suptitle(f"Mean Squared Displacement for {mpid}", fontsize=16)

        for i, site in enumerate(sites):
            for j, temp in enumerate(temperatures):
                ax = axes[i, j]

                for k, cutoff, data in records:
                    disp_vals = _site_values(data, temp, site, all_meshes)
                    ax.bar(index + k * bar_width, disp_vals, width=bar_width,
                           color=colors[k % len(colors)], label=f'Cutoff {cutoff} THz')

                # Centre the tick under each group of bars
                ax.set_xticks(index + (bar_width * (len(cutoff_freqs) - 1)) / 2)
                ax.set_xticklabels(all_meshes, rotation=45)
                ax.set_xlabel('Mesh size')
                ax.set_ylabel(f'{site}\nMean Squared Displacement (Å^2)' if j == 0 else '')
                ax.set_title(f'Temp = {temp} K')
                ax.grid(True, linestyle='--', alpha=0.4)

        # Leave space for the suptitle (top) and the figure legend (right)
        fig.tight_layout(rect=[0, 0, 0.85, 0.96])
        handles, labels = axes[0, 0].get_legend_handles_labels()
        fig.legend(handles, labels, loc='center right', bbox_to_anchor=(1.0, 0.5), fontsize=9)
        fig.savefig(os.path.join(BASE_RESULTS_DIR, f"msd_{mpid}.png"))
        plt.close(fig)
# Output directory for the MSD-difference line plots.
BASE_RESULTS_DIR_DIFF = "msd_convergence_results/msd_diff/"
os.makedirs("msd_convergence_results/msd_diff", exist_ok=True)
def plot_msd_differences_per_site_temperatures(convergence_data, cutoff_freqs, temperatures, BASE_RESULTS_DIR=BASE_RESULTS_DIR_DIFF):
    """
    Visualize change in MSD per site and temperature using the convergence data for different mesh sizes and cutoff freq.

    One figure per material id is saved to ``{BASE_RESULTS_DIR}/msd_differences_{mpid}.png``.
    The figure is a grid with one row per site and one column per temperature. For every
    cutoff, each subplot draws a line of the per-site mean MSD differences (``np.diff``)
    between consecutive meshes. The meshes are the union of all meshes sampled for that
    material. Each difference is plotted against the larger mesh of its pair. Differences
    that involve a mesh not sampled for that cutoff are NaN and are therefore not drawn.

    Parameters
    ----------
    convergence_data : dict
        ``{mpid: {cutoff: record}}`` as built by the convergence loop.
        ``record["mesh"]`` lists the sampled meshes, and
        ``record["tdisp_sites_mean"][temp][site]`` holds the per-site means in the same order.
    cutoff_freqs : sequence of float
        Cutoff frequencies to plot; cutoffs without data are skipped.
    temperatures : sequence of float
        Temperatures to plot (one subplot column each).
    BASE_RESULTS_DIR : str
        Directory the PNG files are written to.
    """
    colors = ['tab:blue', 'tab:orange', 'tab:green']

    def _site_values(data, temp, site, all_meshes):
        """Per-site mean MSD aligned to ``all_meshes``, NaN where not sampled."""
        series = data["tdisp_sites_mean"].get(temp, {}).get(site, [])
        position = {mesh: m for m, mesh in enumerate(data["mesh"])}
        values = []
        for mesh in all_meshes:
            m = position.get(mesh)
            values.append(series[m] if m is not None and m < len(series) else np.nan)
        return values

    for mpid, per_cutoff in convergence_data.items():
        # (position in cutoff_freqs, cutoff, record) for every cutoff that has data;
        # the position fixes the line colour.
        records = [(k, cutoff, per_cutoff[cutoff])
                   for k, cutoff in enumerate(cutoff_freqs) if per_cutoff.get(cutoff)]

        # Collect all unique site names for this material
        sites = sorted({site
                        for _, _, data in records
                        for temp in temperatures
                        for site in data["tdisp_sites_mean"].get(temp, {})})
        if not sites:
            # Nothing to plot (plt.subplots would fail with zero rows).
            continue

        # Collect all mesh points used overall (union)
        all_meshes = sorted({mesh for _, _, data in records for mesh in data["mesh"]})

        n_sites = len(sites)
        n_temps = len(temperatures)
        fig, axes = plt.subplots(n_sites, n_temps, figsize=(5 * n_temps, 4 * n_sites), squeeze=False)
        fig.suptitle(f"MSD Differences for {mpid}", fontsize=16)

        for i, site in enumerate(sites):
            for j, temp in enumerate(temperatures):
                ax = axes[i, j]

                for k, cutoff, data in records:
                    disp_vals = np.array(_site_values(data, temp, site, all_meshes), dtype=float)
                    # Differences between consecutive mesh sizes
                    diffs = np.diff(disp_vals)
                    ax.plot(all_meshes[1:], diffs, marker='o', color=colors[k % len(colors)],
                            label=f'Cutoff {cutoff} THz', alpha=0.75)

                ax.set_xticks(all_meshes)
                ax.set_xticklabels(all_meshes, rotation=45)
                ax.set_xlabel('Mesh size')
                ax.set_ylabel(f'{site}\nΔTDisp (Ų)' if j == 0 else '')
                ax.set_title(f'Temp = {temp} K')
                ax.axhline(0, color='grey', linestyle='--', linewidth=1)
                ax.tick_params(axis='both', direction='in')

        # Leave space for the suptitle (top) and the figure legend (right)
        fig.tight_layout(rect=[0, 0, 0.85, 0.96])
        handles, labels = axes[0, 0].get_legend_handles_labels()
        fig.legend(handles, labels, loc='center right', bbox_to_anchor=(1.0, 0.5), fontsize=9)
        fig.savefig(os.path.join(BASE_RESULTS_DIR, f"msd_differences_{mpid}.png"))
        plt.close(fig)
# Reuse the cutoffs/temperatures of the convergence run so the plots always match the
# computed data (duplicated literals could silently drift out of sync).
plot_msd_per_site_temperatures(convergence_data, cutoff_freqs=cutoff_freqs, temperatures=temperatures)
plot_msd_differences_per_site_temperatures(convergence_data, cutoff_freqs=cutoff_freqs, temperatures=temperatures)