# -*- coding: utf-8 -*-
"""
Name: two_orbitals.py
Last edit: 27.7.2024
Author: Roman Michelko
--------------------------------
Description:
    Numerical solution of the 1D diatomic chain problem with two orbitals per primitive cell (PC)

"""

import numpy as np
import numpy.typing as npt
import matplotlib.pyplot as plt
from matplotlib.ticker import MultipleLocator, FormatStrFormatter

#%% VARIABLES
N: int = 20                 # Number of primitive cells; main() uses it as the number of k-points for the numerical calculation
eps_A: float = 1.0          # Eigen-energy of the first atomic orbital (diagonal element [0,0] of the matrix)
eps_B: float = 1.2          # Eigen-energy of the second atomic orbital (diagonal element [1,1] of the matrix)
t: float = 0.1              # Hopping factor; scales the k-dependent coupling 2*t*cos(pi*k/2) between the orbitals

#%% FUNCTIONS AND LAMBDAS

def gen_matrix(k: float, *, eps_a: "float | None" = None,
               eps_b: "float | None" = None,
               hop: "float | None" = None) -> np.ndarray:
    """Build the 2x2 matrix whose eigenvalues are the band energies at ``k``.

    Parameters
    ----------
    k
        Wave number in units of pi/a (scalar).
    eps_a, eps_b
        Energies of the first and second orbital. When left as ``None`` the
        module-level ``eps_A`` / ``eps_B`` are used.
    hop
        Hopping factor. When left as ``None`` the module-level ``t`` is used.

    Returns
    -------
    np.ndarray
        ``complex128`` array of shape (2, 2) with the orbital energies on the
        diagonal and ``-2*hop*cos(pi*k/2)`` in both off-diagonal elements.
    """
    # Resolve defaults at call time so later changes to the globals are honoured.
    eps_a = eps_A if eps_a is None else eps_a
    eps_b = eps_B if eps_b is None else eps_b
    hop = t if hop is None else hop

    coupling = -2 * hop * np.cos(0.5 * np.pi * k)

    # Pass the dtype *type*; an instance such as np.complex128() only works
    # through NumPy's implicit ".dtype attribute" coercion.
    return np.array([[eps_a, coupling],
                     [coupling, eps_b]], dtype=np.complex128)

def calc_E(k: npt.ArrayLike, *, eps_a: "float | None" = None,
           eps_b: "float | None" = None,
           hop: "float | None" = None) -> "tuple[np.ndarray, np.ndarray]":
    """Evaluate the closed-form energies of both bands at ``k``.

    Parameters
    ----------
    k
        Wave number(s) in units of pi/a; a scalar or an array-like.
    eps_a, eps_b
        Energies of the first and second orbital. When left as ``None`` the
        module-level ``eps_A`` / ``eps_B`` are used.
    hop
        Hopping factor. When left as ``None`` the module-level ``t`` is used.

    Returns
    -------
    tuple
        ``(upper, lower)`` band energies, each with the same shape as ``k``
        (NumPy scalars for scalar input).
    """
    # Resolve defaults at call time so later changes to the globals are honoured.
    eps_a = eps_A if eps_a is None else eps_a
    eps_b = eps_B if eps_b is None else eps_b
    hop = t if hop is None else hop

    centre = 0.5 * (eps_a + eps_b)
    half_gap = 0.5 * (eps_a - eps_b)
    coupling = 2 * hop * np.cos(0.5 * np.pi * np.asarray(k, dtype=float))

    # sqrt(half_gap**2 + coupling**2) without forming the squares explicitly
    split = np.hypot(half_gap, coupling)
    return centre + split, centre - split

def plot_energy(k_a: np.ndarray, E_a: np.ndarray, k_n: np.ndarray, E_n: np.ndarray, name: str) -> None:
    """Plot one band: analytical curve as a line, numerical points as dots.

    ``k_a``/``E_a`` are drawn as a solid line labelled "Analytical",
    ``k_n``/``E_n`` as unconnected point markers labelled "Numerical".
    ``name`` becomes the figure title. Text rendering is switched to LaTeX
    (with the Slovak babel package) through the global rcParams, and the
    figure is shown with ``plt.show()``.
    """
    # Global Matplotlib settings: LaTeX text rendering with Slovak babel
    plt.rc('text', usetex=True)
    plt.rcParams['text.latex.preamble'] = r'\usepackage[slovak]{babel}'

    axes = plt.gca()
    axes.set_title(name)

    axes.plot(k_a, E_a, label="Analytical", linestyle='-')
    axes.plot(k_n, E_n, label="Numerical", marker='.', linestyle="none")

    axes.set_xlabel(r"k")
    axes.set_ylabel(r"E(k)")

    axes.legend()
    axes.grid(True)

    plt.show()

def plot_bands(k_a, E_ap, E_am, k_n, E_np, E_nm) -> None:
    """Plot both bands in one figure, save it as a PNG and show it.

    ``E_ap``/``E_am`` are the upper/lower band on the grid ``k_a`` (drawn as
    solid lines, labelled "anlt."); ``E_np``/``E_nm`` are the upper/lower
    band on the grid ``k_n`` (drawn as point markers, labelled "num.").
    The figure is written to ``../diatomic_two_orbitals_bands.png`` at
    300 dpi. Text rendering is switched to LaTeX (with the Slovak babel
    package) through the global rcParams.
    """
    # Global Matplotlib settings: LaTeX text rendering with Slovak babel
    plt.rc('text', usetex=True)
    plt.rcParams['text.latex.preamble'] = r'\usepackage[slovak]{babel}'

    as_line = {"linestyle": '-'}
    as_dots = {"marker": '.', "linestyle": "none"}

    # Drawn in this order so legend entries and default colours stay as before.
    curves = (
        (k_a, E_ap, r"$E_{2;\mathrm{anlt.}}$", as_line),
        (k_n, E_np, r"$E_{2;\mathrm{num.}}$", as_dots),
        (k_a, E_am, r"$E_{1;\mathrm{anlt.}}$", as_line),
        (k_n, E_nm, r"$E_{1;\mathrm{num.}}$", as_dots),
    )

    axes = plt.gca()
    axes.set_title(r"Diatomic chain; two orbitals per PC -- bandstructure")
    for x_vals, y_vals, legend_label, style in curves:
        axes.plot(x_vals, y_vals, label=legend_label, **style)

    axes.set_xlabel(r"$k \ (\pi/a)$")
    axes.set_ylabel(r"$E(k)$ (a.u.)")

    # Major x ticks every 0.5, printed with one decimal place
    k_axis = axes.xaxis
    k_axis.set_major_locator(MultipleLocator(0.5))
    k_axis.set_major_formatter(FormatStrFormatter('%.1f'))

    axes.legend()
    axes.grid(True)

    plt.savefig("../diatomic_two_orbitals_bands.png", dpi=300)
    plt.show()
    
def save_data(fname: str, k_arr: list, E_lst: list) -> None:
    """Write k-values and band energies to a plain-text table.

    Parameters
    ----------
    fname
        Path of the file to create (overwritten if it exists).
    k_arr
        Sequence of k-values; one output row per entry.
    E_lst
        Sequence of bands; ``E_lst[j][i]`` is the energy of band ``j + 1`` at
        ``k_arr[i]``. Every band must have at least ``len(k_arr)`` entries.

    Notes
    -----
    The first line is a ``#``-prefixed header with tab-separated column
    names. Each following line holds the k-value and the band energies with
    12 decimals, separated by single spaces. No trailing newline is written.
    """
    # Raw string: "\p" is not a valid escape sequence in a normal literal.
    header = r"#k (\pi/a)" + "".join(
        "\tE_{:d} (a.u.)".format(band_no)
        for band_no in range(1, len(E_lst) + 1)
    )

    # The context manager closes the file even if formatting a value fails.
    with open(fname, "w", encoding="utf-8") as fw:
        fw.write(header)
        for i, k in enumerate(k_arr):
            fw.write("\n{:.12f}".format(k))
            fw.writelines(" {:.12f}".format(band[i]) for band in E_lst)
    
#%% MAIN FUNCTION

def main() -> None:
    """Compute both bands analytically and numerically, then save and plot them.

    The analytical bands come from ``calc_E`` on a 100-point k-grid. The
    numerical bands come from diagonalising ``gen_matrix(k)`` on an N-point
    k-grid. Both data sets are written with ``save_data`` and drawn with
    ``plot_bands``.
    """
    # Define arrays of k-values (units of pi/a)
    k_anlt = np.linspace(-1, 1, num=100)
    k_num = np.linspace(-1, 1, num=N)

    # Analytical calculation
    E_2_anlt, E_1_anlt = calc_E(k_anlt)

    # Numerical calculation.
    # gen_matrix(k) is real symmetric, so the Hermitian solver applies: it
    # returns real eigenvalues in ascending order, i.e. column 0 is the lower
    # band and column 1 the upper band. reshape keeps this valid for N == 0.
    levels = np.array(
        [np.linalg.eigvalsh(gen_matrix(k)) for k in k_num]
    ).reshape(-1, 2)
    E_1_num = levels[:, 0]
    E_2_num = levels[:, 1]

    # Save data
    save_data(r"two_orbitals_analytical.dat", k_anlt, [E_1_anlt, E_2_anlt])
    save_data(r"two_orbitals_numerical.dat", k_num, [E_1_num, E_2_num])

    # Create plots
    # Optional per-band plots, kept disabled:
    # plot_energy(k_anlt, E_2_anlt, k_num, E_2_num, r"$E_2$")
    # plot_energy(k_anlt, E_1_anlt, k_num, E_1_num, r"$E_1$")
    plot_bands(k_anlt, E_2_anlt, E_1_anlt, k_num, E_2_num, E_1_num)
    
#%% RUN
# Run the calculation only when executed as a script, not when imported.
if __name__ == '__main__':
    main()