#!/usr/bin/env python

import os
from pathlib import Path

import numpy as np
from tqdm import tqdm  # pip install tqdm

# we're going to delay torch imports to inside the relevant function,
# for speed/simplicity for non-torch parts

# make sure we're working in the directory this file lives in,
# for simplicity with imports and relative paths
os.chdir(Path(__file__).parent.resolve())

from bayesian_logistic_regression import BayesianLogisticRegression
from markov_chain import MarkovChain
from utils import handle, main, run


################################################################################


def grad_chain():
    """Build the seven-state Markov chain shared by the ``mc-*`` handlers.

    Returns:
        A ``MarkovChain`` constructed from the initial distribution,
        transition matrix and state names defined below.
    """
    # could use a python enum for the state names,
    # but it doesn't always play super nice with numpy, idk
    names = [
        "VIDEO_GAMES",
        "INDUSTRY",
        "GRAD_SCHOOL",
        "VIDEO_GAMES_WITH_PHD",
        "INDUSTRY_WITH_PHD",
        "ACADEMIA",
        "DECEASED",
    ]

    # all starting mass sits on the first three states
    start = np.array([0.1, 0.6, 0.3, 0.0, 0.0, 0.0, 0.0])

    # every row sums to 1, so presumably row i is the next-state distribution
    # given current state i (confirm against MarkovChain); the last state
    # only ever transitions to itself
    step = np.array(
        [
            [0.08, 0.90, 0.01, 0.00, 0.00, 0.00, 0.01],
            [0.03, 0.95, 0.01, 0.00, 0.00, 0.00, 0.01],
            [0.06, 0.06, 0.75, 0.05, 0.05, 0.02, 0.01],
            [0.00, 0.00, 0.00, 0.30, 0.60, 0.09, 0.01],
            [0.00, 0.00, 0.00, 0.02, 0.95, 0.02, 0.01],
            [0.00, 0.00, 0.00, 0.01, 0.01, 0.97, 0.01],
            [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 1.00],
        ]
    )

    return MarkovChain(init_probs=start, transition_probs=step, state_names=names)


@handle("mc-sample")
def mc_sample():
    """Handler for ``mc-sample``: print an empirical state distribution.

    Draws 10000 samples of length 30 from the grad-school chain and prints
    the relative frequency of each state index found in column 29.
    """
    chain = grad_chain()

    horizon = 30
    n_draws = 10000
    draws = chain.sample(n_draws, horizon)

    # assumes `draws` is an integer array of shape (n_draws, >= horizon)
    # holding state indices -- confirm against MarkovChain.sample
    final_states = draws[:, horizon - 1]
    freqs = np.bincount(final_states) / n_draws
    print(f"Empirical dist, time {horizon}: {np.round(freqs, 3)}")


@handle("mc-marginals")
def mc_marginals():
    """Handler for ``mc-marginals``: placeholder, not implemented yet.

    Raises:
        NotImplementedError: always, after the chain has been built.
    """
    chain = grad_chain()  # unused until the body is filled in

    raise NotImplementedError()


@handle("mc-mostlikely-marginals")
def mc_mostlikely():
    """Handler for ``mc-mostlikely-marginals``: placeholder, not implemented yet.

    Raises:
        NotImplementedError: always, after the chain has been built.
    """
    chain = grad_chain()  # unused until the body is filled in

    raise NotImplementedError()


@handle("mc-mostlikely-sequence")
def mc_mostlikely_sequence():
    """Handler for ``mc-mostlikely-sequence``: print ``mode`` at two horizons.

    Calls ``mode(time)`` on the grad-school chain for time 50 and then 100,
    printing one line per horizon.
    """
    chain = grad_chain()

    for horizon in (50, 100):
        best = chain.mode(horizon)
        print(f"Mode at time {horizon}: {best}")


@handle("mc-conditionals-forward")
def mc_conditionals_forward():
    """Handler for ``mc-conditionals-forward``: placeholder, not implemented yet.

    Raises:
        NotImplementedError: always, after the chain has been built.
    """
    chain = grad_chain()  # unused until the body is filled in

    raise NotImplementedError()


@handle("mc-conditional-mc")
def mc_conditional_mc():
    """Handler for ``mc-conditional-mc``: placeholder, not implemented yet.

    Raises:
        NotImplementedError: always, after the chain has been built.
    """
    chain = grad_chain()  # unused until the body is filled in

    raise NotImplementedError()


@handle("mc-conditional-exact")
def mc_conditional_exact():
    """Handler for ``mc-conditional-exact``: placeholder, not implemented yet.

    Raises:
        NotImplementedError: always, after the chain has been built.
    """
    chain = grad_chain()  # unused until the body is filled in

    raise NotImplementedError()


################################################################################


@handle("mcmc-blogreg")
def mcmc_blogreg():
    """Handler for ``mcmc-blogreg``: compare MAP and MCMC test log-likelihoods.

    Loads ``X`` and ``y`` from ``../data/two_threes.npz`` (relative to this
    file's directory, which the module chdir's into at import time), splits
    the rows at random into train/test, fits a ``BayesianLogisticRegression``
    by MAP and by sampling 250,000 weight vectors, saves a weight plot under
    the figure name "blogreg", and prints both test log-likelihoods.
    """
    # np.load on an .npz returns a file-backed archive whose arrays are read
    # on access; the context manager closes the underlying file as soon as
    # both arrays have been pulled into memory.
    with np.load("../data/two_threes.npz") as d:
        X = d["X"]
        y = d["y"]

    # fixed seed so the split is reproducible; each row independently lands
    # in the training set with probability 0.8
    rng = np.random.default_rng(seed=12)
    is_train = rng.random(size=X.shape[0]) < 0.8
    X_train = X[is_train]
    y_train = y[is_train]
    X_test = X[~is_train]
    y_test = y[~is_train]

    model = BayesianLogisticRegression()
    map_est = model.map_estimate(X_train, y_train)
    samples = model.sample_weights(X_train, y_train, n_samples=250_000)
    model.plot_weights(samples, map_est, figname="blogreg")

    # map_test_loglik = model.log_likelihood(X_test, y_test, map_est)
    # the MAP estimate gets a leading axis so it is passed in the same form
    # as `samples` (presumably a stack of weight vectors -- confirm in
    # BayesianLogisticRegression.different_log_likelihoods)
    map_test_loglik = model.different_log_likelihoods(
        X_test, y_test, map_est[np.newaxis]
    )
    mcmc_test_loglik = model.different_log_likelihoods(X_test, y_test, samples)
    print(
        f"Test log-likelihood: MAP {map_test_loglik}\n"
        f"                    MCMC {mcmc_test_loglik}"
    )


################################################################################

# split name -> dataset already constructed and moved to a device
_ds = {}


def get_peptides_struct(split="train"):
    """Return the ``PeptidesStructWithEigs`` dataset for ``split``, cached.

    The first request for a split constructs the dataset and moves it to
    ``get_device()``; later requests return the same cached object, even if
    the device override has changed in the meantime.
    """
    from gnn_utils import PeptidesStructWithEigs

    try:
        return _ds[split]
    except KeyError:
        pass

    dataset = PeptidesStructWithEigs(split).to(get_device())
    _ds[split] = dataset
    return dataset


# empty, or a single-element list holding the device forced by a use-* handler
_device = []


def get_device(dev=None):
    """Resolve the torch device to run on.

    Args:
        dev: anything ``torch.device`` accepts; when given it wins outright.

    Returns:
        ``torch.device(dev)`` if ``dev`` is not None; otherwise the override
        stored in ``_device`` if there is one; otherwise CUDA when available
        and the CPU when not.
    """
    import torch

    if dev is not None:
        return torch.device(dev)

    if _device:
        return _device[0]

    # cpu is 10x faster for me for GCN, but mps is 4x faster for transformer...
    # so mps is deliberately not auto-selected here:
    #     if torch.backends.mps.is_available(): return torch.device("mps")
    fallback = "cuda" if torch.cuda.is_available() else "cpu"
    return torch.device(fallback)


@handle("use-cpu")
def use_cpu():
    """Handler for ``use-cpu``: force the CPU as the device override.

    Replaces the contents of the module-level ``_device`` list in place;
    ``get_device()`` returns its first element when called with no argument.
    """
    import torch

    forced = torch.device("cpu")
    _device.clear()
    _device.append(forced)


@handle("use-mps")
def use_mps():
    """Handler for ``use-mps``: force the MPS backend as the device override.

    Replaces the contents of the module-level ``_device`` list in place;
    ``get_device()`` returns its first element when called with no argument.
    No availability check is made here.
    """
    import torch

    forced = torch.device("mps")
    _device.clear()
    _device.append(forced)


@handle("use-cuda")
def use_cuda():
    """Handler for ``use-cuda``: force CUDA as the device override.

    Replaces the contents of the module-level ``_device`` list in place;
    ``get_device()`` returns its first element when called with no argument.
    No availability check is made here.
    """
    import torch

    forced = torch.device("cuda")
    _device.clear()
    _device.append(forced)


def train_model(model, epochs, batch_size=32):
    """Fit ``model`` on the "train" split with Adam and an MSE objective.

    Prints one line of training loss per epoch, and additionally evaluates on
    the "val" split every 10th epoch and after the final epoch.

    Args:
        model: module to optimise; batches are moved to the device of its
            first parameter.
        epochs: number of passes over the training split.
        batch_size: batch size handed to the ``DataLoader``.

    Returns:
        The same ``model`` object, left in training mode.
    """
    import torch
    import torch.nn.functional as F
    from torch_geometric.loader import DataLoader  # pip install torch_geometric

    target_device = next(model.parameters()).device
    model.train()

    dataset = get_peptides_struct("train")
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=5e-4)

    for epoch in range(1, epochs + 1):
        weighted_loss = 0
        progress = tqdm(loader, leave=False, desc=f"epoch {epoch:>5,}", unit="batch")
        for graphs in progress:
            graphs = graphs.to(target_device)

            optimizer.zero_grad()
            batch_loss = F.mse_loss(model(graphs), graphs.y)
            # weight each batch's mean loss by len(graphs) -- presumably the
            # number of graphs in the batch; confirm for the installed
            # torch_geometric version
            weighted_loss += batch_loss.detach().cpu().item() * len(graphs)
            batch_loss.backward()
            optimizer.step()

        print(
            f"epoch {epoch:>5,}: average batch train mse {weighted_loss / len(dataset) : >8.6f}"
        )

        is_last = epoch == epochs
        if is_last or epoch % 10 == 0:
            val_mse, val_mae = eval_model(model, "val")
            # eval_model switches the model to eval mode; switch back
            model.train()
            print(" " * 13 + f"val mse {val_mse:>8.6f}  val mae {val_mae:>8.6f}")

    return model


def eval_model(model, split="val", batch_size=64):
    """Compute MSE and MAE of ``model`` over one dataset split.

    Puts the model into eval mode and does NOT restore the previous mode;
    callers that keep training must call ``model.train()`` afterwards.

    Args:
        model: module to evaluate; batches are moved to the device of its
            first parameter.
        split: name passed to ``get_peptides_struct``.
        batch_size: batch size handed to the ``DataLoader``.

    Returns:
        ``(mse, mae)``: per-batch mean losses weighted by ``len(batch)`` and
        divided by ``len`` of the dataset.
    """
    import torch
    import torch.nn.functional as F
    from torch_geometric.loader import DataLoader

    target_device = next(model.parameters()).device
    model.eval()

    dataset = get_peptides_struct(split)
    batches = DataLoader(dataset, batch_size=batch_size)

    abs_total = 0.0
    sq_total = 0.0
    with torch.no_grad():
        for graphs in batches:
            graphs = graphs.to(target_device)
            preds = model(graphs)
            targets = graphs.y

            # presumably len(graphs) is the number of graphs in the batch --
            # confirm for the installed torch_geometric version
            weight = len(graphs)
            abs_total += F.l1_loss(preds, targets).cpu().item() * weight
            sq_total += F.mse_loss(preds, targets).cpu().item() * weight

    n_examples = len(dataset)
    return sq_total / n_examples, abs_total / n_examples


# "my-gcn" runs the same pipeline as "gcn" but with our own conv layer;
# the lambda looks up train_gcn at call time, so defining it below is fine
handle("my-gcn")(lambda: train_gcn(use_our_impl=True))


@handle("gcn")
def train_gcn(epochs=3, use_our_impl=False):
    """Handler for ``gcn``: train a GCN, then print test MSE and MAE.

    Args:
        epochs: number of training epochs passed to ``train_model``.
        use_our_impl: build the network from ``MyGCNConv`` when true,
            from ``GCNConv`` otherwise.
    """
    from gnns import GCN, GCNConv, MyGCNConv

    # the first training example fixes the input and output widths
    example = get_peptides_struct("train")[0]
    n_features = example.x.shape[1]
    n_targets = example.y.shape[1]

    if use_our_impl:
        layer_cls = MyGCNConv
    else:
        layer_cls = GCNConv

    net = GCN(
        depth=4,
        width=32,
        input_dim=n_features,
        output_dim=n_targets,
        conv_cls=layer_cls,
    )
    net = net.to(get_device())

    train_model(net, epochs=epochs)

    test_mse, test_mae = eval_model(net, split="test")
    print(f"test mse: {test_mse:.5f}   test mae: {test_mae:.5f}")


# "my-graph-transformer" runs the same pipeline as "graph-transformer" with
# use_our_impl=True; the lambda looks the function up at call time
handle("my-graph-transformer")(lambda: train_graph_transformer(use_our_impl=True))


@handle("graph-transformer")
def train_graph_transformer(epochs=3, use_our_impl=False):
    """Handler for ``graph-transformer``: train it, then print test MSE and MAE.

    Args:
        epochs: number of training epochs passed to ``train_model``.
        use_our_impl: forwarded unchanged to ``GraphTransformer``.
    """
    from gnns import GraphTransformer

    # the first training example fixes the input and output widths
    example = get_peptides_struct("train")[0]
    n_features = example.x.shape[1]
    n_targets = example.y.shape[1]

    net = GraphTransformer(
        depth=4,
        width=32,
        input_dim=n_features,
        output_dim=n_targets,
        use_our_impl=use_our_impl,
    )
    net = net.to(get_device())

    train_model(net, epochs=epochs)

    test_mse, test_mae = eval_model(net, split="test")
    print(f"test mse: {test_mse:.5f}   test mae: {test_mae:.5f}")


################################################################################

# Script entry point: only runs when the file is executed directly, not when
# it is imported. `main` comes from utils; presumably it dispatches to the
# handlers registered above via `handle` -- confirm in utils.
if __name__ == "__main__":
    main()
