# -*- coding: utf-8 -*-
"""
Pytheia v0.9.9: Created on Tue Mar 18 16:41:38 2025
Pytheia v1.0.0: Released on August 06, 2025
@author: Nouta Albert Einstein
@supplier: Orionis Group Company, Princeton NJ USA
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import base64
import sys
from termcolor import colored
from sklearn.linear_model import BayesianRidge
from time import time
import traceback

def _ensure_main_injection():
    """Refuse to run unless the launcher has injected the report context.

    The launcher is expected to set ``org_name``, ``version_string`` and
    ``report_timestamp`` as attributes of this module before calling ``main``.

    Raises:
        RuntimeError: listing every context attribute that is still absent.
    """
    this_module = sys.modules.get(__name__)
    absent = []
    for attr_name in ("org_name", "version_string", "report_timestamp"):
        if not hasattr(this_module, attr_name):
            absent.append(attr_name)

    if not absent:
        return

    raise RuntimeError(
        "Pytheia core is not intended to be executed directly. "
        "Missing injected context: %s. Run via Pytheia100_main.py." % ", ".join(absent)
    )

# Default input data file name.
# NOTE(review): no function in this module reads input_dat (main/run_pointing_model take
# the path as an argument) — presumably used by the launcher; confirm.
input_dat = "Ecam_astrometry_pm-02-04-2025.dat"

# Default list of pointing terms. The functions below take their own `selected_terms`
# parameter, which shadows this name inside them.
selected_terms = ["AN", "AW", "IA", "IE", "CA", "TX", "ECES", "HZSZ2"]

# Default output directory name (functions receive output_dir as a parameter).
output_dir = "My_Pointing_model_results"

# Coordinate frame of the data: "azel" (azimuth/elevation) or "hadec" (hour angle/declination).
# Read as a module global by load_data, pointing_model, the plotting/reporting code and main.
#mode = "hadec"
mode = "azel"

# Site latitude in degrees (converted with np.radians below).
# NOTE(review): sign convention presumably positive = north — confirm.
lat = -29.259  

# Latitude in radians; used by the HA/Dec "TF" and "TX" terms in pointing_model.
# Computed once at import time: reassigning `lat` later does not update `phi`.
phi = np.radians(lat)  

# Optional auxiliary values for the AUX* terms: None -> zeros, a scalar is broadcast,
# otherwise one value per data point (see pointing_model).
aux1 = None  
aux2 = None  

# |r| above which two fitted parameters are reported as strongly correlated.
CORRELATION_THRESHOLD = 0.8

# Matplotlib settings used by analyze_and_plot_all.
PLOT_FIGSIZE = (18, 18)       
PLOT_DPI = 300                
PLOT_STYLE = 'default'        

def load_data(file_path):
    """Read and validate a pointing data file and return its values in radians.

    Every non-blank line must hold exactly four whitespace-separated numbers,
    interpreted (in degrees) as coord1_real, coord2_real, coord1_actual,
    coord2_actual. What the coordinates mean depends on the module-level
    ``mode``:

    * ``"azel"``: coord1 = azimuth in [-360, 360] (wrapped to [0, 360)),
      coord2 = elevation in [0, 90].
    * ``"hadec"``: coord1 = hour angle in [-360, 360] (wrapped to [0, 360)),
      coord2 = declination in [-90, 90].

    Args:
        file_path: path of the data file.

    Returns:
        pandas.DataFrame with the four named columns, converted to radians.

    Raises:
        ValueError: for an invalid ``mode``, an unreadable/malformed/empty file
            (the underlying error is kept as ``__cause__``), out-of-range
            values, or missing/infinite values.
    """
    # Fail fast on bad configuration before touching the file.
    if mode not in ('azel', 'hadec'):
        raise ValueError(f"Invalid mode '{mode}'. Must be either 'azel' or 'hadec'")

    try:
        # Pre-scan so a malformed row is reported with its line number and content.
        with open(file_path, 'r') as f:
            for i, line in enumerate(f, 1):
                tokens = line.strip().split()
                if not tokens:
                    continue
                if len(tokens) != 4:
                    raise ValueError(f"Line {i} has {len(tokens)} values (expected 4): {line.strip()}")

        data = pd.read_csv(
            file_path,
            sep=r'\s+',
            header=None,
            names=["coord1_real", "coord2_real", "coord1_actual", "coord2_actual"]
        )

        if len(data) == 0:
            raise ValueError("Empty data file")

        data = data.apply(pd.to_numeric, errors="raise")

    except Exception as e:
        # Every load failure surfaces as ValueError; chain the original error so
        # its type and traceback remain available for debugging.
        raise ValueError(f"Failed to load or validate data file: {str(e)}") from e

    if mode == 'azel':
        for col in ["coord1_real", "coord1_actual"]:
            if (data[col] < -360).any() or (data[col] > 360).any():
                raise ValueError(f"{col} contains invalid azimuth values - outside [-360, 360] range")
            data[col] = data[col] % 360

        for col in ["coord2_real", "coord2_actual"]:
            if (data[col] < 0).any() or (data[col] > 90).any():
                raise ValueError(f"{col} contains invalid elevation values - outside [0, 90] range")

    else:  # 'hadec' (validated above)
        for col in ["coord1_real", "coord1_actual"]:
            if (data[col] < -360).any() or (data[col] > 360).any():
                raise ValueError(f"Hour angle ({col}) must be between -360° and 360°")
            data[col] = data[col] % 360
        for col in ["coord2_real", "coord2_actual"]:
            if (data[col] < -90).any() or (data[col] > 90).any():
                raise ValueError(f"Declination ({col}) must be between -90° and 90°")

    # NaN values pass the range comparisons above, so check for them explicitly.
    for col in data.columns:
        if data[col].isnull().any():
            raise ValueError(f"Column {col} contains missing values")
        if np.isinf(data[col]).any():
            raise ValueError(f"Column {col} contains infinite values")

    return np.radians(data)

def angular_difference(a, b):
    """Return ``a - b`` wrapped into the half-open interval [-pi, pi).

    Works element-wise on scalars, numpy arrays and pandas Series (angles in
    radians).
    """
    full_turn = 2 * np.pi
    shifted = (a - b + np.pi) % full_turn
    return shifted - np.pi

def pointing_model(coord1, coord2, selected_terms, aux1, aux2):
    """Build the linear design matrix for the selected pointing terms.

    Each selected term contributes one column. Rows 0..n-1 hold the term's
    value for the first-axis correction (azimuth or hour angle) and rows
    n..2n-1 for the second-axis correction (elevation or declination),
    matching the stacked targets built in fit_pointing_model. A term that is
    defined for only one axis gets zeros on the other axis.

    Args:
        coord1, coord2: array-likes of length n, in radians (trig functions
            are applied to them directly).
        selected_terms: term names; checked against the tables of the current
            ``mode``.
        aux1, aux2: None (treated as zeros), a scalar (broadcast to n), or an
            array of length n; they feed the AUX* terms.

    Returns:
        (X, term_list): X has shape (2n, len(selected_terms)); term_list is a
        new list of the selected terms in column order.

    Side effects:
        Reads the module globals ``mode`` and ``phi``. If a term is unknown for
        'azel' or 'hadec', prints an error and calls ``sys.exit(1)``. Any other
        ``mode`` value falls through to the HA/Dec tables without that check
        (unknown terms then become all-zero columns).
    """

    coord1 = np.asarray(coord1)
    coord2 = np.asarray(coord2)
    n = len(coord1)

    # Both naming conventions alias the same arrays so the tables read naturally.
    # All four tables below are evaluated regardless of mode; only the active pair is used.
    az = ha = coord1
    el = dec = coord2
    # Zenith distance. `el` is always an array here, so the None branch never triggers.
    Z = 0.5 * np.pi - el if el is not None else None  

    # Normalise the auxiliary inputs to length-n arrays.
    if aux1 is None:
        aux1 = np.zeros(n)
    elif np.isscalar(aux1):
        aux1 = np.full(n, aux1)
    else:
        aux1 = np.asarray(aux1)

    if aux2 is None:
        aux2 = np.zeros(n)
    elif np.isscalar(aux2):
        aux2 = np.full(n, aux2)
    else:
        aux2 = np.asarray(aux2)


    # numpy has no sec(); define it locally.
    sin, cos, tan, sec = np.sin, np.cos, np.tan, lambda x: 1 / np.cos(x)

    # NOTE: tan_el/sec_el diverge at el = 90 deg and 1/tan_el ("TX" below) at el = 0 deg;
    # load_data accepts elevations over the closed range [0, 90]. tan_dec/sec_dec
    # likewise diverge at |dec| = 90 deg.
    sin_az, cos_az = sin(az), cos(az)
    sin_el, cos_el = sin(el), cos(el)
    tan_el, sec_el = tan(el), sec(el)
    sin_ha, cos_ha = sin(ha), cos(ha)
    sin_dec, cos_dec = sin(dec), cos(dec)
    tan_dec, sec_dec = tan(dec), sec(dec)

    # Harmonics and powers of the zenith distance used by the HZ*/HS*/PZZ* terms.
    cos_z = cos(Z)
    cos_2z = cos(2 * Z)
    cos_4z = cos(4 * Z)
    sin_z = sin(Z)
    sin_2z = sin(2 * Z)
    sin_3z = sin(3 * Z)
    sin_4z = sin(4 * Z)
    pow_3z = Z**3
    pow_5z = Z**5

    # Column values for the azimuth correction, keyed by term name.
    az_terms = {
        "IA": np.ones(n),
        "NPAE": tan_el,
        "CA": sec_el,
        "AN": sin_az * tan_el,
        "AW": cos_az * tan_el,
        "HSCA": -cos_az / cos_el,
        "HVCA": -cos_az * tan_el,
        "ACEC": -cos_az,
        "ACES": sin_az,
        "HACA1": cos_az,
        "HSCZ1": cos_z / cos_el,
        "HSCZ2": cos_2z / cos_el,
        "HSCZ4": cos_4z / cos_el,
        "HSSA1": sin_az / cos_el,
        "HSSZ1": sin_z / cos_el,
        "HSSZ2": sin_2z / cos_el,
        "HSSZ3": sin_3z / cos_el,
        "NRX": np.ones(n),
        "NRY": tan_el,
        "AUX1A": aux1,
        "AUX2A": aux2,
        "AUX1S": aux1 * sec_el,
        "AUX2S": aux2 * sec_el
    }

    # Column values for the elevation correction, keyed by term name.
    el_terms = {
        "IE": -np.ones(n),
        "TF": cos_el,
        "TX": 1 / tan_el,
        "AN": cos_az,
        "AW": -sin_az,
        "ECEC": cos_el,
        "ECES": sin_el,
        "HZCZ1": cos_z,
        "HZCZ2": cos_2z,
        "HZCZ4": cos_4z,
        "HZSA": sin_az,
        "HZSA2": sin(2*az),
        "HZSZ1": sin_z,
        "HZSZ2": sin_2z,
        "HZSZ4": sin_4z,
        "PZZ3": pow_3z,
        "PZZ5": pow_5z,
        "NRX": -sin_el,
        "NRY": cos_el,
        "AUX1E": aux1,
        "AUX2E": aux2
    }

    # Column values for the hour-angle correction; TF/TX depend on the site latitude phi.
    h_terms = {
        "IH": np.ones(n),
        "CH": sec_dec,
        "MA": -cos_ha * tan_dec,
        "ME": sin_ha * tan_dec,
        "NP": tan_dec,
        "TF": np.cos(phi) * sin_ha * sec_dec,
        "TX": (np.cos(phi) * sin_ha) / ((sin_dec * np.sin(phi) + cos_dec * cos_ha * cos(phi)) * cos_dec),
        "AUX1H": aux1,
        "AUX1X": aux1 * sec_dec,
        "AUX2H": aux2,
        "AUX2X": aux2 * sec_dec,
        "FO": cos_ha,
        "HCEC": cos_ha,
        "HCES": sin_ha,
        "X2HC": np.cos(2 * ha) * sec_dec,
        "HXCH": np.cos(ha)*sec_dec,
        "HXSH2": np.sin(2*ha)*sec_dec
    }

    # Column values for the declination correction; TF/TX depend on phi.
    dec_terms = {
        "ID": np.ones(n),
        "MA": sin_ha,
        "ME": cos_ha,
        "TF": (np.cos(phi)*cos_ha*sin_dec - np.sin(phi)*cos_dec),
        "TX": (np.cos(phi)*cos_ha*sin_dec - np.sin(phi)*cos_dec) / ((sin_dec*np.sin(phi) + cos_dec*cos_ha*cos(phi)) * cos_dec),
        "DCEC": cos_dec,
        "DCES": sin_dec,
        "AUX1D": aux1,
        "AUX2D": aux2,
        "HDSH2": np.sin(2 * ha)
    }

    # Reject terms unknown to the active mode (a term may appear on either axis).
    azel_terms_set = set(az_terms) | set(el_terms)
    hadec_terms_set = set(h_terms) | set(dec_terms)
    selected_set = set(selected_terms)

    if mode == "azel" and not selected_set <= azel_terms_set:
        print(colored(f"\nERROR: Invalid terms in Az/El mode: {selected_set - azel_terms_set}", "red"))
        print(colored("Please check your terms and try again.", "red"))
        sys.exit(1)  
    elif mode == "hadec" and not selected_set <= hadec_terms_set:
        print(colored(f"\nERROR: Invalid terms in HA/Dec mode: {selected_set - hadec_terms_set}", "red"))
        print(colored("Please check your terms and try again.", "red"))
        sys.exit(1)  

    X_rows = []
    term_list = []

    # One column per term: first-axis values stacked on top of second-axis values;
    # a term missing from one axis table contributes zeros there.
    for term in selected_terms:
        if mode == "azel":
            az_col = az_terms.get(term, np.zeros(n))
            el_col = el_terms.get(term, np.zeros(n))
        else:
            az_col = h_terms.get(term, np.zeros(n))
            el_col = dec_terms.get(term, np.zeros(n))

        full_col = np.concatenate([az_col, el_col])
        X_rows.append(full_col)
        term_list.append(term)

    # Keep a (2n, 0) shape when no terms were selected.
    X = np.column_stack(X_rows) if X_rows else np.zeros((2 * n, 0))

    return X, term_list


def compute_residuals(model, X, coord1_real, coord2_real, coord1_actual, coord2_actual):
    """Return the fitted residuals of both axes, in arcseconds.

    The model correction ``X @ model.coef_`` is split into its first-axis half
    (first n rows) and second-axis half. Each half is added to the matching
    *_real coordinate and compared with *_actual. The first-axis residual is
    wrapped with angular_difference and multiplied by cos(coord2_real) to turn
    it into an on-sky offset.

    Returns:
        1-D array of n first-axis residuals followed by n second-axis residuals.
    """
    count = len(coord1_real)

    correction = X @ model.coef_
    first_axis = correction[:count]
    second_axis = correction[count:]

    on_sky = angular_difference(coord1_real + first_axis, coord1_actual) * np.cos(coord2_real)
    resid_first = on_sky * (180 / np.pi) * 3600
    resid_second = ((coord2_real + second_axis) - coord2_actual) * (180 / np.pi) * 3600

    return np.concatenate([resid_first, resid_second])


def calculate_rms(data, selected_terms, model):
    """Return the per-axis RMS of the fitted residuals, in arcseconds.

    Rebuilds the design matrix from the *_real columns of *data*, applies the
    model correction and measures the remaining offsets. The first-axis
    residual is wrapped and scaled by cos(coord2_real) so that it is on-sky.

    Returns:
        (rms_first_axis, rms_second_axis) in arcseconds.
    """
    real1, real2, actual1, actual2 = [data.iloc[:, idx].to_numpy() for idx in range(4)]

    count = len(real1)
    design, _ = pointing_model(real1, real2, selected_terms, aux1, aux2)

    correction = design @ model.coef_
    corr_first, corr_second = correction[:count], correction[count:]

    resid_first = angular_difference(real1 + corr_first, actual1) * np.cos(real2)
    resid_second = (real2 + corr_second) - actual2

    rms_first = np.sqrt(np.mean(resid_first ** 2)) * (180 / np.pi) * 3600
    rms_second = np.sqrt(np.mean(resid_second ** 2)) * (180 / np.pi) * 3600
    return rms_first, rms_second

def tune_bayesian_ridge(X, Y, coord1_real, coord2_real, coord1_actual, coord2_actual, time_budget=10.0):
    """Choose BayesianRidge prior hyperparameters with a small grid search.

    Each (lambda_2, alpha_2) pair from a fixed log-spaced grid is fitted to
    (X, Y) and scored as 0.5 * RMS of the on-sky residuals (arcsec, from
    compute_residuals) + 0.5 * mean 1-sigma coefficient uncertainty (arcsec).
    The fitted model with the lowest score is returned.

    Args:
        X, Y: design matrix and stacked targets, as built in fit_pointing_model.
        coord1_real, coord2_real, coord1_actual, coord2_actual: per-point
            coordinates in radians, used to evaluate the residuals.
        time_budget: wall-clock seconds after which the search stops and keeps
            the best model found so far (default 10, the previous fixed limit).

    Returns:
        The best fitted BayesianRidge instance.

    Raises:
        RuntimeError: if no candidate produced a finite score.
    """
    import itertools

    lambda_2_options = np.logspace(-1, 2, 5)
    alpha_2_options = np.logspace(-15, 3, 6)

    best_score = float('inf')
    best_model = None
    start_time = time()

    # A single flat loop, so running out of time stops the whole search; a
    # `break` inside nested loops would only leave the inner one.
    for l2, a2 in itertools.product(lambda_2_options, alpha_2_options):
        model = BayesianRidge(
            fit_intercept=False,
            lambda_2=l2,
            alpha_2=a2,
            compute_score=False,
            tol=1e-4,
        )
        model.fit(X, Y)

        res = compute_residuals(model, X, coord1_real, coord2_real, coord1_actual, coord2_actual)
        rms = np.sqrt(np.mean(res ** 2))

        uncertainties = np.sqrt(np.diag(model.sigma_)) * (180 / np.pi) * 3600
        avg_uncertainty = np.mean(uncertainties)

        current_score = 0.5 * rms + 0.5 * avg_uncertainty

        # NaN/inf scores never compare below best_score, so such candidates are skipped.
        if current_score < best_score:
            best_score = current_score
            best_model = model

        if time() - start_time > time_budget:
            break

    if best_model is None:
        raise RuntimeError(
            "Bayesian Ridge tuning failed: no hyperparameter combination produced a finite score"
        )

    print(f"\nTuning completed in {time() - start_time:.2f}s | Best λ₂={best_model.lambda_2:.1e}, α₂={best_model.alpha_2:.1e}")
    return best_model

def fit_pointing_model(data, selected_terms):
    """Fit the selected pointing terms to the offsets in *data*.

    The targets are the per-axis differences actual - real. The first-axis
    difference is wrapped into [-pi, pi) with angular_difference. The two axes
    are stacked in the same row order as the design matrix from pointing_model.
    Prior hyperparameters come from tune_bayesian_ridge. A final BayesianRidge
    is then refitted with a tighter tolerance and with log-marginal-likelihood
    tracking enabled (``scores_``).

    Returns:
        (final_model, term_list): the fitted BayesianRidge and the term names in
        column order.
    """
    real1, real2, actual1, actual2 = (data.iloc[:, k].to_numpy() for k in range(4))

    targets = np.concatenate([
        angular_difference(actual1, real1),
        actual2 - real2,
    ])

    design, term_list = pointing_model(real1, real2, selected_terms, aux1, aux2)

    tuned = tune_bayesian_ridge(design, targets, real1, real2, actual1, actual2)

    final_model = BayesianRidge(
        fit_intercept=False,
        lambda_2=tuned.lambda_2,
        alpha_2=tuned.alpha_2,
        alpha_1=1e-6,
        lambda_1=1e-6,
        compute_score=True,
        tol=1e-6,
    )
    final_model.fit(design, targets)

    return final_model, term_list

def get_uncertainties_and_correlations(model, selected_terms, threshold=None):
    """Extract per-term 1-sigma uncertainties and the parameter correlation matrix.

    Args:
        model: fitted estimator exposing ``coef_`` (radians) and the posterior
            covariance ``sigma_`` (e.g. the BayesianRidge from fit_pointing_model).
        selected_terms: term names, in the same order as the model's columns.
        threshold: |r| above which a pair is reported as strongly correlated.
            Defaults to the module-level CORRELATION_THRESHOLD, looked up at
            call time so runtime overrides of that constant are honoured.

    Returns:
        (term_info, corr_matrix, strong_corrs):
        term_info maps each term to {"coefficient": radians, "uncertainty": arcsec};
        corr_matrix is the covariance normalised by the outer product of the
        standard deviations; strong_corrs lists (term_i, term_j, r) for i < j
        with |r| > threshold, in row-major order.
    """
    import itertools

    if threshold is None:
        threshold = CORRELATION_THRESHOLD

    coef = model.coef_
    covariance = model.sigma_

    # Standard deviations (radians) computed once and reused for both outputs.
    std_rad = np.sqrt(np.diag(covariance))
    std_uncertainties = np.degrees(std_rad) * 3600
    corr_matrix = covariance / np.outer(std_rad, std_rad)

    strong_corrs = [
        (term_i, term_j, corr_matrix[i, j])
        for (i, term_i), (j, term_j) in itertools.combinations(enumerate(selected_terms), 2)
        if abs(corr_matrix[i, j]) > threshold
    ]

    term_info = {
        term: {
            "coefficient": coef[i],
            "uncertainty": std_uncertainties[i]
        }
        for i, term in enumerate(selected_terms)
    }

    return term_info, corr_matrix, strong_corrs


def analyze_and_plot_all(data, selected_terms, output_dir, uncertainties, corr, model):
    """Render the diagnostic figure for a fitted pointing model and save it as a PNG.

    Builds a 4x3 grid of panels:
    - point distribution (polar);
    - raw and fitted residual scatter plots;
    - residuals against each coordinate;
    - the log-marginal-likelihood trace;
    - coefficient magnitude vs. uncertainty;
    - residual histograms;
    - the parameter correlation matrix.

    The figure is written to ``<output_dir>/pointing_analysis_<mode>.png`` and
    then closed.

    Args:
        data: DataFrame whose four columns are coord1_real, coord2_real,
            coord1_actual, coord2_actual in radians (as returned by load_data).
        selected_terms: term names in model column order (used as the
            correlation heatmap labels).
        output_dir: directory for the PNG; created if missing.
        uncertainties: per-term 1-sigma uncertainties in arcsec, plotted
            against |coefficient|.
        corr: square parameter correlation matrix.
        model: fitted estimator exposing ``coef_`` (radians) and ``scores_``.
            ``scores_`` is only populated when the model was fitted with
            compute_score=True, as fit_pointing_model does.

    Reads the module-level ``mode``, ``aux1``, ``aux2`` and ``PLOT_*`` settings.
    Note: ``plt.style.use`` changes matplotlib's global style for the process.
    """

    # pandas Series (not arrays); pointing_model converts them with np.asarray.
    coord1_real = data.iloc[:, 0]  
    coord2_real = data.iloc[:, 1]  
    coord1_actual = data.iloc[:, 2]
    coord2_actual = data.iloc[:, 3]
    n = len(data)

    X, _ = pointing_model(coord1_real, coord2_real, selected_terms, aux1, aux2)
    delta = X @ model.coef_
    delta1 = delta[:n]
    delta2 = delta[n:]

    # Fitted residuals: (real + model correction) - actual. The first-axis residual
    # is wrapped and scaled by cos(coord2) so it is an on-sky offset.
    res1 = angular_difference(coord1_real + delta1, coord1_actual) * np.cos(coord2_real)
    res2 = (coord2_real + delta2) - coord2_actual

    res1_arcsec = res1 * (180 / np.pi) * 3600
    res2_arcsec = res2 * (180 / np.pi) * 3600

    # Raw residuals: the same quantity with no model correction applied.
    raw_res1 = angular_difference(coord1_real, coord1_actual) * np.cos(coord2_real)
    raw_res2 = coord2_real - coord2_actual
    raw_res1_arcsec = raw_res1 * (180 / np.pi) * 3600
    raw_res2_arcsec = raw_res2 * (180 / np.pi) * 3600

    # Per-axis RMS and the two axes combined in quadrature.
    rms1 = np.sqrt(np.mean(res1_arcsec ** 2))
    rms2 = np.sqrt(np.mean(res2_arcsec ** 2))
    rms_tot = np.sqrt(rms1 ** 2 + rms2 ** 2)

    raw_rms1 = np.sqrt(np.mean(raw_res1_arcsec ** 2))
    raw_rms2 = np.sqrt(np.mean(raw_res2_arcsec ** 2))
    raw_rms_tot = np.sqrt(raw_rms1 ** 2 + raw_rms2 ** 2)

    if mode == 'azel':
        coord1_label, coord2_label = "Azimuth", "Elevation"
        res1_label, res2_label = "ΔAz", "ΔEl"
        polar_r_label = "Elevation [deg]"
    else:  # hadec
        coord1_label, coord2_label = "Hour angle", "Declination"
        res1_label, res2_label = "ΔHA", "ΔDec"
        polar_r_label = "Declination [deg]"

    plt.style.use(PLOT_STYLE)
    fig = plt.figure(figsize=PLOT_FIGSIZE)  
    fig.suptitle(
        f'Pointing model analysis ({mode.upper()}): Raw RMS {res1_label}: {raw_rms1:.2f}", {res2_label}: {raw_rms2:.2f}", Δtot: {raw_rms_tot:.2f}" → Fitted RMS {res1_label}: {rms1:.2f}", {res2_label}: {rms2:.2f}", Δtot: {rms_tot:.2f}"',
        fontsize=15, y=0.98)

    gs = fig.add_gridspec(4, 3, hspace=0.4, wspace=0.3)

    # Panel (0,0): distribution of the pointings on a polar plot.
    ax1 = fig.add_subplot(gs[0, 0], projection='polar')
    if mode == 'azel':
        # Radius runs from 90 deg (centre) to 0 deg (edge); angle increases clockwise.
        r = np.degrees(coord2_real)  
        theta = coord1_real  
        sc = ax1.scatter(theta, r, s=10, c=r, cmap='viridis', alpha=0.7)
        ax1.set_ylim(90, 0)
        ax1.set_yticks([90, 60, 30, 0])
        ax1.set_theta_direction(-1)
    else:  # hadec

        # Radius is |dec|, so both hemispheres share the same rings; colour carries the sign.
        r = np.abs(np.degrees(coord2_real))  
        theta = coord1_real  
        sc = ax1.scatter(theta, r, s=10, c=np.degrees(coord2_real), cmap='coolwarm', vmin=-90, vmax=90, alpha=0.9)
        ax1.set_ylim(90, 0)
        ax1.set_yticks([90, 60, 30, 0])

        # Dashed circle at r = 0, i.e. the outer edge (declination 0).
        ax1.plot(np.linspace(0, 2 * np.pi, 100), [0] * 100, 'k--', alpha=0.3)

        # Relabel the 12 angular ticks over 0..2pi as 0h..22h.
        ax1.set_xticks(np.linspace(0, 2 * np.pi, 12, endpoint=False))
        ax1.set_xticklabels(['%dh' % i for i in range(0, 24, 2)])

    ax1.set_theta_zero_location('N')

    plt.colorbar(sc, ax=ax1, label=polar_r_label, pad=0.1)
    ax1.set_title("Points distribution", pad=15)
    ax1.grid(alpha=0.3)

    # Panel (0,1): raw residuals (before fit).
    ax2 = fig.add_subplot(gs[0, 1])
    raw_residual_mag = np.sqrt(raw_res1_arcsec ** 2 + raw_res2_arcsec ** 2)
    raw_circle_rms = raw_rms_tot
    # The 100th percentile is simply the largest residual magnitude.
    raw_circle_100 = np.percentile(raw_residual_mag, 100)
    ax2.scatter(raw_res1_arcsec, raw_res2_arcsec, alpha=0.5, color='orange')
    ax2.axhline(y=0, color='r', linestyle='--', linewidth=1)
    ax2.axvline(x=0, color='r', linestyle='--', linewidth=1)

    def add_directional_labels(ax, mode):
        # Place direction letters at 90% of the current axis limits; must be
        # called after the limits have been set.

        xlim = ax.get_xlim()
        ylim = ax.get_ylim()

        if mode == 'azel':

            ax.text(xlim[1] * 0.9, 0, 'R', ha='center', va='bottom', fontweight='bold', fontsize=12)
            ax.text(xlim[0] * 0.9, 0, 'L', ha='center', va='bottom', fontweight='bold', fontsize=12)
            ax.text(0, ylim[1] * 0.9, 'U', ha='left', va='center', fontweight='bold', fontsize=12)
            ax.text(0, ylim[0] * 0.9, 'D', ha='left', va='center', fontweight='bold', fontsize=12)

        else:  # hadec

            ax.text(xlim[1] * 0.9, 0, 'W', ha='center', va='bottom', fontweight='bold', fontsize=12)
            ax.text(xlim[0] * 0.9, 0, 'E', ha='center', va='bottom', fontweight='bold', fontsize=12)
            ax.text(0, ylim[1] * 0.9, 'N', ha='left', va='center', fontweight='bold', fontsize=12)
            ax.text(0, ylim[0] * 0.9, 'S', ha='left', va='center', fontweight='bold', fontsize=12)

    # Green circle = combined RMS, red circle = maximum residual.
    ax2.add_artist(plt.Circle((0, 0), raw_circle_rms, fill=False, color='green', linestyle='--', alpha=0.7))
    ax2.add_artist(plt.Circle((0, 0), raw_circle_100, fill=False, color='red', linestyle='--', alpha=0.7))

    # Label each circle at its 45-degree point.
    diagonal_offset = raw_circle_rms / np.sqrt(2)
    ax2.annotate(f'{raw_circle_rms:.1f}"', xy=(diagonal_offset, diagonal_offset), xytext=(5, 5),
                 textcoords='offset points', color='green')
    diagonal_offset_100 = raw_circle_100 / np.sqrt(2)
    ax2.annotate(f'{raw_circle_100:.1f}"', xy=(diagonal_offset_100, diagonal_offset_100), xytext=(5, 5),
                 textcoords='offset points', color='red')
    ax2.set_xlabel(f'{res1_label} [arcsec]')
    ax2.set_ylabel(f'{res2_label} [arcsec]')
    ax2.set_title(f'Raw Residuals (before fit): {res1_label} vs {res2_label}')
    ax2.grid(True, linestyle=':', alpha=0.7)
    # Symmetric square limits with 10% margin.
    raw_max_limit = max(abs(raw_res1_arcsec).max(), abs(raw_res2_arcsec).max()) * 1.1
    ax2.set_xlim(-raw_max_limit, raw_max_limit)
    ax2.set_ylim(-raw_max_limit, raw_max_limit)
    ax2.set_aspect('equal')
    add_directional_labels(ax2, mode)

    # Panel (0,2): fitted residuals, drawn the same way as the raw panel.
    ax3 = fig.add_subplot(gs[0, 2])
    residual_mag = np.sqrt(res1_arcsec ** 2 + res2_arcsec ** 2)
    circle_rms = rms_tot
    circle_100 = np.percentile(residual_mag, 100)
    ax3.scatter(res1_arcsec, res2_arcsec, alpha=0.5)
    ax3.axhline(y=0, color='r', linestyle='--', linewidth=1)
    ax3.axvline(x=0, color='r', linestyle='--', linewidth=1)
    ax3.add_artist(plt.Circle((0, 0), circle_rms, fill=False, color='green', linestyle='--', alpha=0.7))
    ax3.add_artist(plt.Circle((0, 0), circle_100, fill=False, color='red', linestyle='--', alpha=0.7))

    diagonal_offset = circle_rms / np.sqrt(2)
    ax3.annotate(f'{circle_rms:.1f}"', xy=(diagonal_offset, diagonal_offset), xytext=(5, 5), textcoords='offset points',
                 color='green')
    diagonal_offset_100 = circle_100 / np.sqrt(2)
    ax3.annotate(f'{circle_100:.1f}"', xy=(diagonal_offset_100, diagonal_offset_100), xytext=(5, 5),
                 textcoords='offset points', color='red')
    ax3.set_xlabel(f'{res1_label} [arcsec]')
    ax3.set_ylabel(f'{res2_label} [arcsec]')
    ax3.set_title(f'Fitted residuals: {res1_label} vs {res2_label}')
    ax3.grid(True, linestyle=':', alpha=0.7)
    max_limit = max(abs(res1_arcsec).max(), abs(res2_arcsec).max()) * 1.1
    ax3.set_xlim(-max_limit, max_limit)
    ax3.set_ylim(-max_limit, max_limit)
    ax3.set_aspect('equal')
    add_directional_labels(ax3, mode)

    # Panels (1,0)/(1,1): first-axis residual against each coordinate.
    ax4 = fig.add_subplot(gs[1, 0])
    ax4.scatter(np.degrees(coord1_real), res1_arcsec, s=10, color='blue', alpha=0.7)
    ax4.set_xlabel(f"{coord1_label} [deg]")
    ax4.set_ylabel(f"{res1_label} [arcsec]")
    ax4.set_title(f"{res1_label} vs {coord1_label}")
    ax4.grid(alpha=0.3)

    ax5 = fig.add_subplot(gs[1, 1])
    ax5.scatter(np.degrees(coord2_real), res1_arcsec, s=10, color='green', alpha=0.7)
    ax5.set_xlabel(f"{coord2_label} [deg]")
    ax5.set_ylabel(f"{res1_label} [arcsec]")
    ax5.set_title(f"{res1_label} vs {coord2_label}")
    ax5.grid(alpha=0.3)

    # Panel (1,2): log-marginal-likelihood per iteration (requires model.scores_).
    ax6 = fig.add_subplot(gs[1, 2])
    ax6.plot(model.scores_, marker='o', linestyle='-', color='darkorange')
    ax6.set_title("Log-Marginal Likelihood")
    ax6.set_xlabel("Iteration")
    ax6.set_ylabel("Log-Likelihood")
    ax6.grid(True, alpha=0.4)

    # Panels (2,0)/(2,1): second-axis residual against each coordinate.
    ax7 = fig.add_subplot(gs[2, 0])
    ax7.scatter(np.degrees(coord1_real), res2_arcsec, s=10, color='red', alpha=0.7)
    ax7.set_xlabel(f"{coord1_label} [deg]")
    ax7.set_ylabel(f"{res2_label} [arcsec]")
    ax7.set_title(f"{res2_label} vs {coord1_label}")
    ax7.grid(alpha=0.3)

    ax8 = fig.add_subplot(gs[2, 1])
    ax8.scatter(np.degrees(coord2_real), res2_arcsec, s=10, color='purple', alpha=0.7)
    ax8.set_xlabel(f"{coord2_label} [deg]")
    ax8.set_ylabel(f"{res2_label} [arcsec]")
    ax8.set_title(f"{res2_label} vs {coord2_label}")
    ax8.grid(alpha=0.3)

    # Panel (2,2): |coefficient| (radians converted to arcsec) vs its uncertainty.
    ax9 = fig.add_subplot(gs[2, 2])
    coef_arcsec = model.coef_ * (180 / np.pi) * 3600
    ax9.scatter(np.abs(coef_arcsec), uncertainties, color='teal', alpha=0.7)
    ax9.set_xlabel("|Fitted value| [arcsec]")
    ax9.set_ylabel("Uncertainty σ [arcsec]")
    ax9.set_title("Parameter magnitude vs uncertainty")
    ax9.grid(alpha=0.4)

    # Panels (3,0)/(3,1): residual histograms with the mean marked.
    ax10 = fig.add_subplot(gs[3, 0])
    ax10.hist(res1_arcsec, bins=30, color='blue', alpha=0.7, density=True)
    ax10.axvline(np.mean(res1_arcsec), color='k', linestyle='dashed', linewidth=1)
    ax10.set_xlabel(f"{res1_label} [arcsec]")
    ax10.set_ylabel("Density")
    ax10.set_title(f"{coord1_label} Residual distribution")
    ax10.grid(alpha=0.3)

    ax11 = fig.add_subplot(gs[3, 1])
    ax11.hist(res2_arcsec, bins=30, color='red', alpha=0.7, density=True)
    ax11.axvline(np.mean(res2_arcsec), color='k', linestyle='dashed', linewidth=1)
    ax11.set_xlabel(f"{res2_label} [arcsec]")
    ax11.set_ylabel("Density")
    ax11.set_title(f"{coord2_label} Residual distribution")
    ax11.grid(alpha=0.3)

    # Panel (3,2): correlation heatmap on a fixed [-1, 1] colour scale.
    ax12 = fig.add_subplot(gs[3, 2])
    im = ax12.imshow(corr, cmap='coolwarm', vmin=-1, vmax=1)
    plt.colorbar(im, ax=ax12, label='Correlation')
    ax12.set_xticks(range(len(selected_terms)))
    ax12.set_yticks(range(len(selected_terms)))
    ax12.set_xticklabels(selected_terms, rotation=90)
    ax12.set_yticklabels(selected_terms)
    ax12.set_title("Parameter correlations", pad=15)

    # plt.savefig/plt.close act on the current figure, which is `fig` here.
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, f"pointing_analysis_{mode}.png"), dpi=PLOT_DPI, bbox_inches='tight')
    plt.close()

def save_results(output_dir, selected_terms, model, std_errors, residuals, overall_rms_x, overall_rms_y, corr_matrix, strong_corrs):
    """Write the text report and correlation table, then bundle everything into HTML.

    Writes ``pointing_results.txt`` and ``correlation_matrix.csv`` into
    *output_dir* and calls bundle_to_html, which embeds them (and the plot) in
    the final HTML report.

    Args:
        output_dir: target directory; created if missing.
        selected_terms: term names in model column order.
        model: fitted BayesianRidge (coef_ in radians; alpha_, lambda_,
            alpha_2, lambda_2 are reported).
        std_errors: per-term 1-sigma uncertainties in arcsec.
        residuals: stacked residuals of both axes in arcsec (2 * n values).
        overall_rms_x, overall_rms_y: per-axis RMS in arcsec.
        corr_matrix: parameter correlation matrix.
        strong_corrs: (term_i, term_j, r) pairs above CORRELATION_THRESHOLD.
    """
    os.makedirs(output_dir, exist_ok=True)
    result_file = os.path.join(output_dir, "pointing_results.txt")

    if mode == 'azel':
        x_label, y_label = "Az", "El"
    else:
        x_label, y_label = "HA", "Dec"

    with open(result_file, "w", encoding='utf-8') as f:
        f.write(f"Fitted pointing terms with 1-sigma uncertainties (arcsec) [{mode.upper()} mode]:\n")
        for term, value, error in zip(selected_terms, model.coef_, std_errors):
            val = np.degrees(value) * 3600
            f.write(f"{term}: {val:.4f} +/- {error:.4f} arcsec\n")

        f.write(f"\nFit RMS {x_label}: {overall_rms_x:.2f} arcsec\n")
        f.write(f"Fit RMS {y_label}: {overall_rms_y:.2f} arcsec\n")
        total_rms = np.sqrt(overall_rms_x**2 + overall_rms_y**2)
        f.write(f"Fit total RMS: {total_rms:.2f} arcsec\n")

        # `residuals` holds both axes (2 * n values), so 0.5 * len() is the number
        # of pointings n. The degrees of freedom subtract the number of fitted terms;
        # with too few pointings the statistic is undefined (avoid inf/NaN output).
        dof = 0.5 * len(residuals) - len(model.coef_)
        if dof > 0:
            pop_sd = np.sqrt(np.sum(residuals**2) / dof)
            f.write(f"\nPopulation standard deviation of residuals: {pop_sd:.2f} arcsec\n")
        else:
            f.write("\nPopulation standard deviation of residuals: undefined "
                    f"(need more pointings than the {len(model.coef_)} fitted terms)\n")

        f.write("\nBayesian Ridge diagnostics:\n")
        f.write(f"Final noise precision (alpha): {model.alpha_:.4e}\n")
        f.write(f"Final weight precision (lambda): {model.lambda_:.4e}\n")
        f.write(f"Prior alpha_2 (noise spread): {model.alpha_2:.4e}\n")
        f.write(f"Prior lambda_2 (weight spread): {model.lambda_2:.4e}\n")

        f.write("\nNoise and regularization analysis:\n")

        # alpha_ is a precision, so the noise standard deviation is 1/sqrt(alpha_) (radians).
        noise_std = np.sqrt(1 / model.alpha_)
        f.write(f"- Estimated noise std: {noise_std:.4e} rad ({np.degrees(noise_std) * 3600:.2f} arcsec)\n")

        noise_arcsec = np.degrees(noise_std) * 3600

        if noise_arcsec < 1.0:
            f.write("  * Excellent: Sub-arcsecond noise level\n")
        elif noise_arcsec < 5.0:
            f.write("  * Good: Acceptable noise level for pointing\n")
        elif noise_arcsec < 10.0:
            f.write("  * Moderate: Higher than ideal noise level\n")
        else:  # >= 10 arcsec
            f.write("  * WARNING: High noise level - check data quality or model\n")

        lambda_val = model.lambda_
        if lambda_val > 1000:
            f.write("  - WARNING: High lambda_! Indicates strong regularization. Possible causes:\n")
            f.write("           * Correlated parameters in your model\n")
            f.write("           * Insufficient data for model complexity\n")
            f.write("           * Need to review term selection\n")
        elif lambda_val < 1e-3:
            f.write("  - WARNING: Low lambda_! Indicates weak regularization. Possible issues:\n")
            f.write("           * Potential overfitting to noise\n")
            f.write("           * Unusually clean data (verify measurements)\n")
            f.write("           * Consider adding more regularization\n")
        else:
            f.write("  - Regularization strength appears appropriate\n")

        if strong_corrs:
            # Report the threshold actually used to select the pairs.
            f.write(f"\nStrongly correlated parameters (|r| > {CORRELATION_THRESHOLD}):\n")
            for term1, term2, corr in strong_corrs:
                f.write(f"{term1} vs {term2}: {corr:.2f}\n")

    # Fixed-width, comma-separated table intended for the HTML report.
    with open(os.path.join(output_dir, "correlation_matrix.csv"), "w", encoding='utf-8') as f:
        header = "Parameter".ljust(12) + "," + ",".join([f"{term:^10}" for term in selected_terms])
        f.write(header + "\n")
        f.write("-" * (12 + 11 * len(selected_terms)) + "\n")
        for i, term in enumerate(selected_terms):
            row_str = f"{term.ljust(12)},"
            corr_values = [f"{corr_matrix[i, j]:^10.3f}" for j in range(len(selected_terms))]
            row_str += ",".join(corr_values)
            f.write(row_str + "\n")

        f.write("\nNotes:\n")
        f.write("Diagonal values (self-correlations) are always 1.000\n")
        f.write(f"Values > |{CORRELATION_THRESHOLD}| may indicate significant parameter coupling\n")

    bundle_to_html(output_dir)

def bundle_to_html(output_dir):
    """Combine the plot and the two text reports into one self-contained HTML file.

    Reads ``pointing_analysis_{mode}.png``, ``pointing_results.txt`` and
    ``correlation_matrix.csv`` from *output_dir*. Writes
    ``combined_results_bayesian.html``, embedding the image as a base64 data
    URI and the reports in ``<pre>`` blocks. Finally deletes the three
    intermediate files.

    The footer fields (org_name, report_timestamp, version_string) are read
    from this module's attributes, with fallbacks when they were not injected.
    """
    import html

    title = f"Pointing model results ({mode.upper()} mode)"

    this_module = sys.modules[__name__]
    org_name = getattr(this_module, 'org_name', 'Unknown Organization')
    report_timestamp = getattr(this_module, 'report_timestamp', '')
    version_string = getattr(this_module, 'version_string', 'Pytheia v1.0.0')

    img_path = os.path.join(output_dir, f"pointing_analysis_{mode}.png")
    with open(img_path, "rb") as img_file:
        encoded_img = base64.b64encode(img_file.read()).decode("utf-8")
    img_data_uri = f"data:image/png;base64,{encoded_img}"

    # save_results writes the report as UTF-8; read it back the same way rather
    # than relying on the platform's default encoding.
    with open(os.path.join(output_dir, "pointing_results.txt"), 'r', encoding='utf-8') as f:
        pointing_results = f.read()

    with open(os.path.join(output_dir, "correlation_matrix.csv"), 'r', encoding='utf-8') as f:
        correlation_matrix = f.read()

    # Escape all text placed into the markup: the reports contain characters such
    # as '>' and the footer values are supplied externally by the launcher.
    esc = html.escape
    html_content = f"""<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<title>{esc(title)}</title>
<style>
    body {{ font-family: Arial; margin: 20px }}
    .section {{ margin-bottom: 30px; border-bottom: 1px solid #eee; padding-bottom: 15px }}
    img {{ max-width: 100% }}
    .footer {{ text-align: center; padding: 20px 0; margin-top: 30px; border-top: 1px solid #eee; color: #666; font-size: 0.9em }}
</style>
</head>
<body>
    <h1>{esc(title)}</h1>
    <div class="section">
        <h2>Diagnostic plots</h2>
        <img src="{img_data_uri}">
    </div>

    <div class="section">
    <h2>Pointing results</h2>
    <pre style="font-size: 1.2em;">{esc(pointing_results)}</pre>
    </div>

    <div class="section">
        <h2>Correlation Matrix</h2>
        <pre style="font-size: 1.1em;">{esc(correlation_matrix)}</pre>
    </div>
    <div class="footer">
        {esc(str(version_string))} | Report generated: {esc(str(report_timestamp))} | Licensed to: {esc(str(org_name))} | &copy; Orionis Group Company, Princeton NJ USA
    </div>
</body>
</html>
"""

    # Encoding matches the <meta charset> declared above.
    with open(os.path.join(output_dir, "combined_results_bayesian.html"), 'w', encoding='utf-8') as f:
        f.write(html_content)

    # The intermediate files are now embedded in the HTML; remove them.
    for file in [f"pointing_analysis_{mode}.png", "pointing_results.txt", "correlation_matrix.csv"]:
        os.remove(os.path.join(output_dir, file))


def main(input_file, output_dir, selected_terms):
    """Run the complete pointing-model pipeline for one data file.

    The pipeline:
    1. Load and validate the data.
    2. Fit the model.
    3. Print RMS, regularisation diagnostics, per-term values and strong
       correlations to the console.
    4. Write the diagnostic figure and the HTML report into *output_dir*.

    Args:
        input_file: path of the four-column data file (see load_data).
        output_dir: directory for the results; created if missing.
        selected_terms: pointing terms to fit.

    Raises:
        RuntimeError: if the launcher's injected context is missing.
        ValueError: if ``mode`` is not 'azel'/'hadec' or the data file is invalid.
    """
    _ensure_main_injection()

    # Validate the configuration before creating directories or reading data.
    # Raising (instead of sys.exit()) lets run_pointing_model report the problem
    # and avoids exiting with a zero "success" status.
    if mode not in {"azel", "hadec"}:
        raise ValueError(f"Invalid mode '{mode}'. Must be either 'azel' or 'hadec'")

    os.makedirs(output_dir, exist_ok=True)
    data = load_data(input_file)
    print(f"\nFitting the following terms in {mode.upper()} mode:", selected_terms)

    coord1_real = data.iloc[:, 0].to_numpy()
    coord2_real = data.iloc[:, 1].to_numpy()
    coord1_actual = data.iloc[:, 2].to_numpy()
    coord2_actual = data.iloc[:, 3].to_numpy()

    # Design matrix at the *_real positions, used to evaluate the fitted residuals.
    X, _ = pointing_model(coord1_real, coord2_real, selected_terms, aux1, aux2)

    model, term_list = fit_pointing_model(data, selected_terms)

    residuals = compute_residuals(model, X, coord1_real, coord2_real, coord1_actual, coord2_actual)

    term_info, corr_matrix, strong_corrs = get_uncertainties_and_correlations(model, term_list)

    overall_rms_x, overall_rms_y = calculate_rms(data, selected_terms, model)

    x_label = "Az" if mode == 'azel' else "HA"
    y_label = "El" if mode == 'azel' else "Dec"

    print(f"\nFit RMS Δ{x_label}: {overall_rms_x:.2f} arcsec")
    print(f"Fit RMS Δ{y_label}: {overall_rms_y:.2f} arcsec")
    print(f"Fit total RMS: {np.sqrt(overall_rms_x ** 2 + overall_rms_y ** 2):.2f} arcsec")

    print(f"\nEstimated noise level (alpha_): {model.alpha_:.4e}")
    print(f"Regularization strength (lambda_): {model.lambda_:.4e}")

    if model.lambda_ > 1e3:
        print(colored("Strong regularization: lambda_ is quite large — terms may be highly constrained.", "yellow"))
    elif model.lambda_ < 1e-3:
        print(colored("Weak regularization: lambda_ is very small — potential overfitting.", "yellow"))

    # Coefficients are in radians; display them in arcseconds.
    print("\nFitted parameters (arcsec):")
    std_errors = []
    for term, info in term_info.items():
        coef = info["coefficient"]
        std = info["uncertainty"]
        print(colored(f"{term}: {np.degrees(coef) * 3600:.4f} ± {std:.4f}", "magenta"))
        std_errors.append(std)

    # Report the threshold actually used to select the pairs.
    if strong_corrs:
        print(f"\nStrongly correlated parameters (|corrs| > {CORRELATION_THRESHOLD}):")
        for term1, term2, corr in strong_corrs:
            print(f"{term1} vs {term2}: {corr:.2f}")
    else:
        print(f"\nNo strongly correlated parameter pairs found (all |corrs| ≤ {CORRELATION_THRESHOLD})")

    analyze_and_plot_all(data, term_list, output_dir, std_errors, corr_matrix, model=model)
    save_results(output_dir, term_list, model, std_errors, residuals, overall_rms_x, overall_rms_y, corr_matrix, strong_corrs)

    print("\nResults saved in:", output_dir)

def run_pointing_model(input_file, output_dir, selected_terms):
    """Run ``main`` and report success or failure on the console.

    Exceptions derived from Exception are caught and summarised (message,
    file, line and source of the failing location) instead of propagating.
    SystemExit is not caught.
    """
    try:
        main(input_file, output_dir, selected_terms)
        print(colored(
            f"\nPointing model ({mode.upper()} mode) successfully computed and saved. Check output directory for results!",
            "green"))
    except Exception as e:
        tb_frames = traceback.extract_tb(e.__traceback__)

        # The first frame is always the `main(...)` call above, which says nothing
        # about where the failure happened. Prefer the innermost frame in this
        # module: the Pytheia line that raised, or that called into a failing
        # library. Fall back to the innermost frame overall.
        module_file = globals().get("__file__")

        def _in_this_module(frame):
            return bool(module_file) and (
                os.path.normcase(os.path.abspath(frame.filename))
                == os.path.normcase(os.path.abspath(module_file))
            )

        own_frames = [fr for fr in tb_frames if _in_this_module(fr)]
        error_frame = own_frames[-1] if own_frames else tb_frames[-1]

        error_msg = (
            f"Failed to compute pointing model:\n"
            f" * Error: {str(e)}\n"
            f" * File: {os.path.basename(error_frame.filename)}\n"
            f" * Line: {error_frame.lineno}\n"
            f" * Code: {error_frame.line.strip() if error_frame.line else 'Unknown'}"
        )
        print(colored(error_msg, "red"))

if __name__ == "__main__":
    # This module only provides the pipeline; the launcher must inject the report
    # context (see _ensure_main_injection) before calling it. `sys` is already
    # imported at module level, and errors belong on stderr.
    print("ERROR: Do not run Pytheia100_core.py directly.", file=sys.stderr)
    print("Please execute: python Pytheia100_main.py", file=sys.stderr)
    sys.exit(1)