# base code by Danica Sutherland, Hamed Shirzad, Alan Milligan  (spring 2024)

import math

import torch
from torch import nn
import torch.nn.functional as F
import torch_geometric
from torch_geometric.nn import GCNConv
from torch_geometric.utils import (
    add_self_loops,
    degree,
    to_dense_adj,
    to_dense_batch,
    to_nested_tensor,
    from_nested_tensor,
)


class GCN(nn.Module):
    """Graph convolutional network that emits one prediction per graph.

    The pipeline is: a linear embedding of the node features, ``depth``
    graph-convolution layers (each followed by a ReLU), a per-graph
    ``global_add_pool`` readout, and a final linear head.

    Args:
        depth: Number of graph-convolution layers.
        width: Hidden feature size shared by every layer.
        input_dim: Size of the raw node feature vectors.
        output_dim: Size of the per-graph output.
        conv_cls: Convolution layer class. It is instantiated as
            ``conv_cls(width, width)`` and called with ``(x, edge_index)``.
    """

    def __init__(self, depth, width, input_dim, output_dim=1, conv_cls=GCNConv):
        super().__init__()

        # node-feature embedding
        stages = [(nn.Linear(input_dim, width), "x -> x"), nn.ReLU(inplace=True)]

        # message-passing stack: every convolution mixes each node's features
        # with those of its neighbours, then applies a nonlinearity
        for _ in range(depth):
            stages += [
                (conv_cls(width, width), "x, edge_index -> x"),
                nn.ReLU(inplace=True),
            ]

        # readout: pool the nodes of each graph, then predict
        stages += [
            (torch_geometric.nn.global_add_pool, "x, batch -> x"),
            nn.Linear(width, output_dim),
        ]

        self.layers = torch_geometric.nn.Sequential("x, edge_index, batch", stages)

    def forward(self, data):
        """Run the network on a batch object exposing ``x``, ``edge_index`` and ``batch``."""
        node_features = data.x.float()
        return self.layers(node_features, data.edge_index, data.batch)


class MyGCNConv(nn.Module):
    """Hand-written graph convolution with symmetric degree normalisation.

    For every node ``j`` the layer computes::

        out[j] = sum_{i -> j} (x[i] @ W.T) / sqrt(deg[i] * deg[j]) + b

    where the sum runs over all edges into ``j`` *including* an added
    self-loop, and ``deg`` is the in-degree counted after the self-loops
    have been added.

    Args:
        dim_in: Size of the incoming node features.
        dim_out: Size of the produced node features.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()

        self.dim_in = dim_in
        self.dim_out = dim_out

        self.weight = torch.nn.Parameter(torch.empty((dim_out, dim_in)))
        self.bias = torch.nn.Parameter(torch.empty(dim_out))
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the parameters with the scheme used by torch_geometric."""
        nn.init.xavier_uniform_(self.weight)
        nn.init.zeros_(self.bias)

    def forward(self, x, edge_index):
        """Apply the convolution.

        Args:
            x: Node features, shape ``[nodes_in_batch, dim_in]``.
            edge_index: Edges, shape ``[2, edges_in_batch]``; edge ``i`` goes
                from node ``edge_index[0, i]`` to node ``edge_index[1, i]``.
                Undirected edges are expected to appear once per direction.

        Returns:
            Updated node features, shape ``[nodes_in_batch, dim_out]``.
        """
        n_nodes = x.shape[0]

        # add "self-loops" (edges from v to v) for each node
        edge_index, _ = add_self_loops(edge_index, num_nodes=n_nodes)
        src, dst = edge_index[0], edge_index[1]

        # in-degree including the self-loop, so deg >= 1 and the inverse
        # square root below is always finite
        deg = degree(dst, n_nodes, dtype=x.dtype)  # shape [nodes_in_batch]
        inv_sqrt_deg = deg.pow(-0.5)
        edge_weight = inv_sqrt_deg[src] * inv_sqrt_deg[dst]  # shape [edges]

        # linear map first (no bias yet), then aggregate over incoming edges.
        # Nodes from different graphs in the batch never share an edge, so a
        # single scatter-add over the whole batch is enough, and it avoids
        # building a dense [num_nodes, num_nodes] adjacency matrix.
        h = F.linear(x, self.weight)  # [nodes_in_batch, dim_out]
        messages = h[src] * edge_weight.unsqueeze(-1)  # [edges, dim_out]
        out = torch.zeros_like(h).index_add(0, dst, messages)

        return out + self.bias


################################################################################


class MultiHeadSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a batch of graphs.

    Every node attends to all nodes of its own graph. The input is expected
    to be a nested tensor so that graphs with different node counts can be
    processed together without padding.

    Args:
        dim_in: Size of the incoming node features.
        n_heads: Number of attention heads.
        dim_v: Per-head size of the value vectors.
        dim_qk: Per-head size of the query/key vectors; defaults to ``dim_in``.
        dim_out: Size of the output features; defaults to ``dim_v * n_heads``.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dim_in, n_heads, dim_v, dim_qk=None, dim_out=None, dropout=0.1):
        super().__init__()

        self.dim_in = dim_in
        self.n_heads = n_heads
        self.dim_v = dim_v
        self.dim_qk = dim_in if dim_qk is None else dim_qk
        self.dim_out = dim_v * n_heads if dim_out is None else dim_out

        self.map_Q = nn.Linear(self.dim_in, self.n_heads * self.dim_qk)
        self.map_K = nn.Linear(self.dim_in, self.n_heads * self.dim_qk)
        self.map_V = nn.Linear(self.dim_in, self.n_heads * self.dim_v)

        self.map_out = nn.Linear(self.dim_v * n_heads, self.dim_out)

        self.dropout = nn.Dropout(dropout)  # to apply to the attention scores

    def forward(self, x):
        """Apply self-attention within every graph of the batch.

        Args:
            x: A ``torch.nested.nested_tensor`` where ``x[graph_i, node_i, j]``
                is the ``j``-th feature of node ``node_i`` in graph ``graph_i``.

        Returns:
            A nested tensor of shape ``[n_graphs, <irregular>, dim_out]``.
        """
        n_graphs = x.size(0)
        #  x.size(1) is irregular and would throw an error
        #  x.size(2) is the feature dimension == self.dim_in

        q = self.map_Q(x)  # [n_graphs, <irregular>, n_heads * dim_qk]
        k = self.map_K(x)  # [n_graphs, <irregular>, n_heads * dim_qk]
        v = self.map_V(x)  # [n_graphs, <irregular>, n_heads * dim_v]

        # split things up per head
        # for nested_tensors, -1 here means "stay the same"
        # so we end at [n_graphs, n_heads, <irregular>, dim_*]
        q = q.reshape(n_graphs, -1, self.n_heads, self.dim_qk).transpose(2, 1)
        k = k.reshape(n_graphs, -1, self.n_heads, self.dim_qk).transpose(2, 1)
        v = v.reshape(n_graphs, -1, self.n_heads, self.dim_v).transpose(2, 1)

        # scaled dot-product scores between every pair of nodes in a graph:
        # [n_graphs, n_heads, <irreg>, <irreg>]. The 1/sqrt(dim_qk) factor
        # keeps the logits at a scale where softmax does not saturate.
        scale = 1.0 / math.sqrt(self.dim_qk)
        scores = torch.matmul(q, k.transpose(-1, -2)) * scale

        # normalise over the keys, then randomly drop attention weights
        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # weighted average of the values: [n_graphs, n_heads, <irreg>, dim_v]
        x = torch.matmul(attn, v)

        # merge the heads back together: [n_graphs, <irreg>, n_heads * dim_v]
        x = x.transpose(2, 1).reshape(n_graphs, -1, self.n_heads * self.dim_v)

        x = self.map_out(x)  # [n_graphs, <irreg>, dim_out]
        return x


class GraphTransformerLayer(nn.Module):
    """One transformer block operating on the nodes of a batch of graphs.

    The block applies self-attention among the nodes of each graph, adds the
    result back to its input and batch-normalises it, then does the same
    with a two-layer feed-forward network.

    Args:
        dim: Feature size of the nodes (unchanged by the layer).
        n_heads: Number of attention heads.
        dropout_attn: Dropout probability handed to the attention module.
        dropout: Dropout probability used after attention and inside the
            feed-forward network.
        use_our_impl: Use ``MultiHeadSelfAttention`` from this file instead
            of ``nn.MultiheadAttention``.
    """

    def __init__(self, dim, n_heads, dropout_attn, dropout, use_our_impl=False):
        super().__init__()

        self.use_our_impl = use_our_impl
        self.self_attn = (
            MultiHeadSelfAttention(
                dim, n_heads, dim_v=dim, dim_qk=dim, dim_out=dim, dropout=dropout_attn
            )
            if use_our_impl
            else nn.MultiheadAttention(
                dim, n_heads, dropout=dropout_attn, batch_first=True
            )
        )
        self.bn1 = nn.BatchNorm1d(dim)
        self.dropout = nn.Dropout(dropout)

        hidden = dim * 2
        self.feedforward_block = nn.Sequential(
            nn.Linear(dim, hidden),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(hidden, dim),
            nn.Dropout(dropout),
        )
        self.bn2 = nn.BatchNorm1d(dim)

    def _attend(self, x, batch):
        """Return the self-attention output for the flat node tensor ``x``."""
        if self.use_our_impl:
            # our implementation works on nested tensors directly
            nested = to_nested_tensor(x, batch)
            return from_nested_tensor(self.self_attn(nested))

        # PyTorch's attention doesn't support nested tensors when training,
        # so pad every graph to a common size and mask out the padding.
        # This wastes work on padded slots whose results are then discarded.
        padded, valid = to_dense_batch(x, batch)
        attended, _ = self.self_attn(
            padded, padded, padded, key_padding_mask=~valid, need_weights=False
        )
        return attended[valid]

    def forward(self, x, batch):
        """Apply the block to node features ``x`` with graph assignment ``batch``."""
        # attention sub-block: dropout, residual connection, batch norm
        h = self.dropout(self._attend(x, batch))
        h = self.bn1(h + x)

        # feed-forward sub-block: residual connection, batch norm
        h = self.bn2(h + self.feedforward_block(h))
        return h


class GraphTransformer(nn.Module):
    """Graph transformer that emits one prediction per graph.

    Node features are first combined with Laplacian positional encodings by
    ``LapPENodeEncoder``, then passed through ``depth``
    ``GraphTransformerLayer`` blocks, pooled per graph with
    ``global_add_pool`` and mapped to the output by a linear head.

    Args:
        depth: Number of transformer layers.
        width: Hidden feature size used throughout the model.
        input_dim: Size of the raw node features. ``None`` falls back to
            ``width``.
        output_dim: Size of the per-graph output.
        num_head: Number of attention heads per layer.
        dropout_attn: Dropout probability for the attention weights.
        dropout: Dropout probability used elsewhere in each layer.
        use_our_impl: Use the attention implementation from this file
            instead of ``nn.MultiheadAttention``.
    """

    def __init__(
        self,
        depth,
        width,
        input_dim,
        output_dim=1,
        num_head=4,
        dropout_attn=0.5,
        dropout=0.2,
        use_our_impl=False,
    ):
        super().__init__()
        if input_dim is None:
            # This branch previously read an undefined name and raised
            # NameError; the hidden width is the only size available here.
            input_dim = width
        self.pos_enc = LapPENodeEncoder(width, input_dim)

        layers = [
            (
                GraphTransformerLayer(
                    width, num_head, dropout_attn, dropout, use_our_impl=use_our_impl
                ),
                "x, batch -> x",
            )
            for _ in range(depth)
        ] + [
            (torch_geometric.nn.global_add_pool, "x, batch -> x"),
            (nn.Linear(width, output_dim), "x -> x"),
        ]
        self.layers = torch_geometric.nn.Sequential("x, batch", layers)

    def forward(self, data):
        """Run the model on a batch object.

        ``data`` must carry the attributes read by ``LapPENodeEncoder``
        (``x``, ``EigVals``, ``EigVecs``) plus ``batch``. The positional
        encoder overwrites ``data.x`` in place.
        """
        data = self.pos_enc(data)
        return self.layers(data.x.float(), data.batch)


################################################################################


class LapPENodeEncoder(nn.Module):
    """Node encoder that appends a Laplacian positional encoding (LapPE).

    A positional encoding of size `dim_pe` is computed for every node from
    precomputed Laplacian eigenvalues/eigenvectors and concatenated to the
    node's feature vector. With `project_x` enabled the original features
    are first linearly mapped to `dim_emb - dim_pe` dimensions, so that the
    concatenation has size `dim_emb`. Typically `dim_emb` > `dim_in`.

    Args:
        dim_emb: Size of the final node embedding.
        dim_in: Size of the incoming node features.
        dim_pe: Size of the positional encoding.
        n_layers: Depth of the per-frequency encoder, counting the initial
            projection of the (eigenvector entry, eigenvalue) pair.
        max_freqs: Accepted for interface compatibility; not used here.
        project_x: Project node features `x` from dim_in to (dim_emb - dim_pe).

    Raises:
        ValueError: If `dim_pe` leaves no room for node features in `dim_emb`.
    """

    def __init__(
        self, dim_emb, dim_in, dim_pe=8, n_layers=2, max_freqs=4, project_x=True
    ):
        super().__init__()

        if dim_emb - dim_pe < 1:
            raise ValueError(
                f"LapPE size {dim_pe} is too large for "
                f"desired embedding size of {dim_emb}."
            )

        self.project_x = project_x
        if project_x:
            self.linear_x = nn.Linear(dim_in, dim_emb - dim_pe)

        # First layer: maps each (eigenvector entry, eigenvalue) pair.
        # No batch norm is applied here, although some variants use one.
        self.linear_A = nn.Linear(2, dim_pe)

        # DeepSet-style encoder applied independently to every frequency
        if n_layers == 1:
            stack = [nn.ReLU()]
        else:
            hidden = 2 * dim_pe
            # deeper encoders use a wider first layer; this replaces the
            # projection created above (construction order is kept so that
            # seeded initialisation stays the same)
            self.linear_A = nn.Linear(2, hidden)
            stack = [nn.ReLU()]
            for _ in range(n_layers - 2):
                stack += [nn.Linear(hidden, hidden), nn.ReLU()]
            stack += [nn.Linear(hidden, dim_pe), nn.ReLU()]

        self.pe_encoder = nn.Sequential(*stack)

    def forward(self, batch):
        """Attach positional encodings to ``batch.x`` and return ``batch``.

        Args:
            batch: Object with attributes ``x``, ``EigVals`` and ``EigVecs``.
                ``EigVecs`` is indexed as (node, frequency); ``EigVals`` is
                presumably (node, frequency, 1) since it is concatenated
                with ``EigVecs.unsqueeze(2)`` along the last axis. NaN
                entries mark missing frequencies.

        Returns:
            The same ``batch`` object, with ``batch.x`` overwritten.

        Raises:
            ValueError: If the eigen-decomposition attributes are missing.
        """
        if not hasattr(batch, "EigVals") or not hasattr(batch, "EigVecs"):
            raise ValueError("Precomputed eigen values and vectors are required")

        eigvals, eigvecs = batch.EigVals, batch.EigVecs

        if self.training:
            # eigenvectors are only defined up to sign, so flip each
            # frequency at random (same flip for every node)
            n_freqs = eigvecs.size(1)
            signs = 2 * torch.randint(
                0, 2, (n_freqs,), device=eigvecs.device, dtype=eigvecs.dtype
            ) - 1
            eigvecs = eigvecs * signs.unsqueeze(0)

        # (num nodes) x (num eigenvectors) x 2
        freq_feats = torch.cat((eigvecs.unsqueeze(2), eigvals), dim=2)
        missing = torch.isnan(freq_feats)
        freq_feats = freq_feats.masked_fill(missing, 0)

        # (num nodes) x (num eigenvectors) x dim_pe
        encoded = self.pe_encoder(self.linear_A(freq_feats))

        # zero out the encodings of missing frequencies (out of place, so
        # the encoder output itself is never overwritten)
        missing_freq = missing[:, :, 0].unsqueeze(2)
        encoded = encoded.masked_fill(missing_freq, 0.0)

        # sum pooling over frequencies -> (num nodes) x dim_pe
        node_pe = encoded.sum(dim=1)

        # project node features if needed
        h = batch.x.float()
        if self.project_x:
            h = self.linear_x(h)

        # concatenate final PEs to the input embedding
        batch.x = torch.cat((h, node_pe), 1)
        return batch
