from pathlib import Path

from torch_geometric.data import InMemoryDataset, download_url
from torch_geometric.io import fs


# Normally you should use torch_geometric.datasets.LRGBDataset.
# Here, we have already computed the Laplacian eigendecomposition of each graph
# for you, and we hard-code a few things (data directory, download URL, split
# names) to keep things simple.


class PeptidesStructWithEigs(InMemoryDataset):
    """Peptides-struct graphs with precomputed Laplacian eigen-decompositions.

    The data lives in ``data-dl/peptides-struct-eigs`` under the parent of the
    directory containing this file. The pre-processed ``{split}.pt`` file for
    the requested split is loaded directly with ``InMemoryDataset.load``.

    Args:
        split: Which split to load: ``"train"``, ``"val"`` or ``"test"``.
        transform: Optional transform forwarded to ``InMemoryDataset``.

    Raises:
        ValueError: If ``split`` is not one of the supported splits.
    """

    def __init__(self, split: str = "train", transform=None):
        # Validate with a real exception rather than ``assert``: asserts are
        # stripped under ``python -O``, and an unknown split would then only
        # fail later when loading ``f"{split}.pt"``.
        if split not in ("train", "val", "test"):
            raise ValueError(
                f"split must be one of 'train', 'val', 'test', got {split!r}"
            )
        self.split = split

        root = Path(__file__).absolute().parent.parent / "data-dl"
        # Set before super().__init__(): the ``raw_dir`` property below reads
        # it, and the base initializer may consult ``raw_dir``.
        self.data_dir = root / "peptides-struct-eigs"

        super().__init__(str(root), transform)
        self.load(str(self.data_dir / f"{split}.pt"))

    # The files were already pre-processed to save time. There is no process()
    # step, and no support for further processing: the downloaded "raw" files
    # are loaded as-is.
    @property
    def raw_dir(self):
        """Directory holding the downloaded split files."""
        return str(self.data_dir)

    raw_file_names = ["train.pt", "val.pt", "test.pt"]

    def download(self):
        """Download every split file from the hard-coded URL into ``raw_dir``.

        Any existing ``raw_dir`` (and its contents) is removed first,
        presumably to discard partial files left by an earlier attempt.
        """
        # Guard the removal so calling download() before the directory exists
        # does not raise.
        if fs.exists(self.raw_dir):
            fs.rm(self.raw_dir)
        for fn in self.raw_file_names:
            download_url(
                f"https://cs.ubc.ca/~dsuth/440/23w2/peptides-struct-eigs/{fn}",
                self.raw_dir,
            )


# For the record, this dataset was generated by the code below. It is guarded
# by `if False:` so it never runs when this module is imported.

if False:
    from functools import partial
    import time

    import numpy as np
    import torch
    import torch.nn.functional as F
    from torch_geometric.datasets import LRGBDataset
    from torch_geometric.utils import (
        get_laplacian,
        to_scipy_sparse_matrix,
        to_undirected,
    )
    from tqdm import tqdm

    def pre_transform_in_memory(dataset, transform_func, show_progress=False):
        """Apply ``transform_func`` to every graph of ``dataset`` and rebuild it in place.

        Each graph is fetched with ``dataset.get(idx)`` and transformed. Any
        result that is falsy (e.g. ``None``) is discarded. The index view is then
        reset, the transformed graphs are cached as the dataset's data list, and
        the collated storage (``data``/``slices``) is rebuilt from them.

        Args:
            dataset: In-memory dataset to modify in place.
            transform_func: Callable mapping one graph to its transformed version.
            show_progress: Show a tqdm progress bar while transforming.
        """
        total = len(dataset)
        progress = tqdm(
            range(total),
            disable=not show_progress,
            mininterval=10,
            miniters=total // 20,
        )
        results = [transform_func(dataset.get(idx)) for idx in progress]
        kept = [graph for graph in results if graph]

        dataset._indices = None
        dataset._data_list = kept
        dataset.data, dataset.slices = dataset.collate(kept)

    def add_LapPE(data, is_undirected, max_freqs, eigvec_norm="L2"):
        """Attach Laplacian positional-encoding tensors to a graph.

        Builds the graph Laplacian with ``get_laplacian(..., normalization=None)``,
        eigen-decomposes it densely with ``np.linalg.eigh``, and stores the output
        of ``get_lap_decomp_stats`` on ``data`` as ``EigVals`` and ``EigVecs``.

        Args:
            data: Graph object exposing ``edge_index`` and either ``num_nodes``
                or ``x``.
            is_undirected: If True, ``data.edge_index`` is used as-is. Otherwise
                it is symmetrized with ``to_undirected`` first.
            max_freqs: Number of smallest eigenpairs to keep.
            eigvec_norm: Forwarded unchanged to ``get_lap_decomp_stats``.

        Returns:
            The same ``data`` object, modified in place.
        """
        # Prefer the explicit node count. Otherwise use the number of feature
        # rows, which also counts disconnected nodes.
        num_nodes = data.num_nodes if hasattr(data, "num_nodes") else data.x.shape[0]

        edge_index = (
            data.edge_index if is_undirected else to_undirected(data.edge_index)
        )

        # Dense eigen-decomposition; assumes graphs are small enough to densify.
        lap_edge_index, lap_edge_weight = get_laplacian(
            edge_index, normalization=None, num_nodes=num_nodes
        )
        dense_laplacian = to_scipy_sparse_matrix(lap_edge_index, lap_edge_weight).toarray()
        eigenvalues, eigenvectors = np.linalg.eigh(dense_laplacian)

        data.EigVals, data.EigVecs = get_lap_decomp_stats(
            evals=eigenvalues,
            evects=eigenvectors,
            max_freqs=max_freqs,
            eigvec_norm=eigvec_norm,
        )
        return data

    def get_lap_decomp_stats(evals, evects, max_freqs, eigvec_norm="L2"):
        """Compute Laplacian eigen-decomposition-based PE stats of the given graph.

        Keeps the ``max_freqs`` smallest eigenvalues and their eigenvectors, and
        L2-normalizes each kept eigenvector. When the graph has fewer than
        ``max_freqs`` nodes, both outputs are padded with NaN up to
        ``max_freqs`` frequencies.

        Args:
            evals: Precomputed eigenvalues; ``len(evals)`` is taken as the
                number of nodes.
            evects: Precomputed eigenvectors, where ``evects[:, i]`` pairs with
                ``evals[i]``.
            max_freqs: Maximum number of smallest frequencies / eigenvectors
                to keep.
            eigvec_norm: Accepted for interface compatibility but currently
                ignored. Eigenvectors are always L2-normalized.

        Returns:
            EigVals: tensor ``(num_nodes, max_freqs, 1)``. Holds the kept
                eigenvalues (real part, clamped at 0, NaN-padded), repeated for
                each node. Its dtype follows ``evals`` (float64 for
                ``np.linalg.eigh`` output).
            EigVecs: float32 tensor ``(num_nodes, max_freqs)`` of per-node
                eigenvector entries (NaN-padded).
        """
        num_nodes = len(evals)  # Number of nodes, including disconnected nodes.

        # Keep up to the maximum desired number of (smallest) frequencies.
        order = evals.argsort()[:max_freqs]
        kept_vals = torch.from_numpy(np.real(evals[order])).clamp_min(0)
        kept_vecs = torch.from_numpy(np.real(evects[:, order])).float()

        # L2-normalize each eigenvector (column). The clamp avoids dividing by
        # zero for an all-zero column.
        col_norms = kept_vecs.norm(p=2, dim=0, keepdim=True).clamp_min(1e-12)
        kept_vecs = kept_vecs / col_norms

        # Graphs with fewer than max_freqs nodes have fewer eigenpairs. Pad the
        # missing frequencies with NaN so every graph yields max_freqs columns.
        num_missing = max_freqs - num_nodes
        if num_missing > 0:
            kept_vecs = F.pad(kept_vecs, (0, num_missing), value=float("nan"))
            kept_vals = F.pad(kept_vals, (0, num_missing), value=float("nan"))

        # Repeat the eigenvalue row for every node: (num_nodes, max_freqs, 1).
        EigVals = kept_vals.unsqueeze(0).repeat(num_nodes, 1).unsqueeze(2)
        EigVecs = kept_vecs
        return EigVals, EigVecs

    # Load the three LRGB Peptides-struct splits (train, val, test, in that order).
    train_dataset = LRGBDataset("data", "Peptides-struct", "train")
    val_dataset = LRGBDataset("data", "Peptides-struct", "val")
    test_dataset = LRGBDataset("data", "Peptides-struct", "test")

    max_freqs = 4
    start = time.perf_counter()
    print("Precomputing Positional Encoding statistics: ")

    # Only the first 10 training graphs are checked, to save time; the result
    # is assumed to hold for every split.
    sample_graphs = train_dataset[:10]
    is_undirected = all(graph.is_undirected() for graph in sample_graphs)
    print(f"  ...estimated to be undirected: {is_undirected}")

    # The same Laplacian-PE transform is applied to every split, in place.
    lap_pe_transform = partial(
        add_LapPE, is_undirected=is_undirected, max_freqs=max_freqs
    )
    for split_dataset in (train_dataset, val_dataset, test_dataset):
        pre_transform_in_memory(split_dataset, lap_pe_transform, show_progress=True)

    # NOTE(review): the encodings are only computed in memory here; the step
    # that wrote the per-split .pt files is not part of this block.
    elapsed = time.perf_counter() - start
    hours_minutes_seconds = time.strftime("%H:%M:%S", time.gmtime(elapsed))
    fractional_part = f"{elapsed:.2f}"[-3:]  # e.g. ".37"
    timestr = hours_minutes_seconds + fractional_part
    print(f"Done! Took {timestr}")
