import numpy as np
from tqdm import tqdm


class BayesianLogisticRegression:
    """Logistic regression with a Gaussian weight penalty, fit by sampling.

    The objective used throughout is ``-sum(log(1 + exp(-y * X @ w)))
    - lam / 2 * ||w||^2``.  ``log_likelihood`` and ``map_estimate`` multiply
    the logits by ``y`` directly, so ``y`` acts as a signed label there;
    ``different_log_likelihoods`` only checks ``y > 0``.
    """

    def __init__(self, lam=1):
        """Store ``lam``, the coefficient of the ``||w||^2 / 2`` penalty."""
        self.lam = lam

    def sample_weights(self, X, y, n_samples, rng=None, step_size=1.0, verbose=True):
        """Draw weight vectors with random-walk Metropolis.

        The chain starts at ``w = 0`` and uses ``self.log_likelihood`` as the
        (unnormalized) log target density.

        Args:
            X: 2-D array of inputs, shape ``(n, d)``.
            y: labels, passed through to ``log_likelihood``.
            n_samples: number of Metropolis steps; one row is stored per step.
            rng: ``numpy.random.Generator``; a fresh default one if ``None``.
            step_size: standard deviation of the Gaussian proposal noise.
            verbose: show the tqdm progress bar and print the final
                acceptance summary.

        Returns:
            Array of shape ``(n_samples, d)``; row ``i`` is the chain state
            after step ``i`` (rejected proposals repeat the previous state).
        """
        if rng is None:
            rng = np.random.default_rng()

        _, d = X.shape

        samples = np.zeros((n_samples, d))
        w = np.zeros(d)
        log_p = self.log_likelihood(X, y, w)

        n_accept = 0
        steps = tqdm(range(n_samples)) if verbose else range(n_samples)
        for i in steps:
            w_hat = w + rng.normal(loc=0, scale=step_size, size=d)
            log_phat = self.log_likelihood(X, y, w_hat)

            # The proposal is symmetric, so the acceptance ratio is just the
            # ratio of target densities; compare in log space.
            log_r = log_phat - log_p
            if np.log(rng.uniform()) < log_r:
                w = w_hat
                n_accept += 1
                log_p = log_phat
            samples[i, :] = w
            if verbose:
                steps.set_postfix({"accept_rate": f"{n_accept / (i + 1):.2%}"})

        if verbose:
            # Guard the rate so an empty chain reports 0% instead of raising.
            accept_rate = n_accept / n_samples if n_samples else 0.0
            print(
                f"did {n_samples:,} steps; accepted {n_accept:,} samples ({accept_rate:.2%})"
            )
        return samples

    def log_likelihood(self, X, y, w, lam=None):
        """Return the penalized log-likelihood of weights ``w``.

        Computes ``-sum(log(1 + exp(-y * X @ w))) - lam / 2 * (w @ w)``, i.e.
        the logistic log-likelihood plus the log of an unnormalized Gaussian
        prior on ``w``.

        Args:
            X: inputs, shape ``(n, d)``.
            y: signed labels, shape ``(n,)``.
            w: weight vector, shape ``(d,)``.
            lam: penalty strength; falls back to ``self.lam`` when ``None``.
        """
        if lam is None:
            lam = self.lam
        # logaddexp(0, z) is log(1 + exp(z)) without overflowing for large z.
        return -np.logaddexp(0, -y * (X @ w)).sum() - lam / 2 * (w @ w)

    def plot_weights(self, sample_weights, point_est=None, figname=None):
        """Plot histograms of sampled weights for a fixed set of coordinates.

        Draws a 6x3 grid with one column per group listed in ``indices`` and
        one histogram per weight index; the largest index used is 255, so
        ``sample_weights`` needs at least 256 columns.

        Args:
            sample_weights: array of shape ``(n_samples, d)``.
            point_est: optional weight vector; its value at each plotted index
                is marked with a red vertical line.
            figname: if ``None`` the figure is shown interactively, otherwise
                it is saved to ``../figs/{figname}.png``.
        """
        import matplotlib.pyplot as plt

        indices = [
            ("positive variables", [119, 144, 152, 162, 184, 196]),
            ("negative variables", [30, 75, 106, 109, 123, 124]),
            ("neutral variables", [79, 167, 182, 213, 222, 255]),
        ]

        fig, axes = plt.subplots(
            6, 3, figsize=(8, 8), constrained_layout=True, sharex=True, sharey=True
        )

        for col, (name, ary) in enumerate(indices):
            for row, idx in enumerate(ary):
                ax = axes[row, col]
                ax.hist(sample_weights[:, idx], bins=50)
                if point_est is not None:
                    ax.axvline(point_est[idx], color="r")

                if row == 0:
                    ax.set_title(name)

        if figname is None:
            plt.show()
        else:
            fn = f"../figs/{figname}.png"
            fig.savefig(fn, bbox_inches="tight", pad_inches=0.1)
            # Release the figure so repeated calls do not accumulate open figures.
            plt.close(fig)
            print(f"Saved to {fn}")

    def map_estimate(self, X, y):
        """Return the weights that minimize the negated ``log_likelihood``.

        Runs a single ``torch.optim.LBFGS`` step call (``max_iter=100_000``,
        ``lr=0.1``) from ``w = 0`` in float32.

        Args:
            X: inputs, shape ``(n, d)``.
            y: signed labels, shape ``(n,)``.

        Returns:
            NumPy float32 array of shape ``(d,)``.
        """
        import torch

        # there are lots of ways to solve this, but you'll have torch installed anyway
        # and this ensures it's exactly the same parameterization

        X = torch.tensor(X, dtype=torch.float32, requires_grad=False)
        y = torch.tensor(y, dtype=torch.float32, requires_grad=False)

        w = torch.zeros(X.shape[1], requires_grad=True)
        opt = torch.optim.LBFGS([w], max_iter=100_000, lr=0.1)

        def step():
            opt.zero_grad()
            # softplus(z) == log(1 + exp(z)) but does not overflow to inf (and
            # NaN gradients) in float32 when z is large, matching the stable
            # logaddexp form used in log_likelihood.
            nll = torch.nn.functional.softplus(-y * (X @ w)).sum()
            loss = nll + self.lam / 2 * (w @ w)
            loss.backward()
            return loss

        opt.step(step)

        return w.detach().numpy()

    def predict_proba(self, X, w):
        """Return ``sigmoid(X @ w)``, the probability of the positive class.

        Args:
            X: inputs, shape ``(n, d)``.
            w: weights, shape ``(d,)`` (anything ``X @ w`` accepts).
        """
        # sigmoid(z) = exp(-log(1 + exp(-z))); going through logaddexp avoids
        # the overflow that exp(-z) hits for strongly negative logits.
        return np.exp(-np.logaddexp(0, -(X @ w)))

    def different_log_likelihoods(self, X, y, samples):
        """Score data under posterior samples in four different ways.

        Args:
            X: inputs, shape ``(n, d)``.
            y: 1-D label array; entries ``> 0`` are the positive class and
                everything else the negative class.
            samples: weight samples, shape ``(n_samples, d)``.

        Returns:
            Tuple ``(method_one, method_two, method_three, method_four)``:

            1. log-probability using the add-one-smoothed fraction of samples
               whose decision ``X @ w > 0`` is positive,
            2. per-point mean over samples of the log-likelihood, summed over
               points,
            3. per-sample total log-likelihood, averaged over samples,
            4. log of the sample-averaged predictive probability, summed over
               points.
        """
        preds = X @ samples.T  # shape [n, n_samples]

        avg_decision = ((preds > 0).sum(axis=1) + 1) / (preds.shape[1] + 2)
        method_one = np.log(np.where(y > 0, avg_decision, 1 - avg_decision)).sum()

        # log sigmoid(s * z) = -log(1 + exp(-s * z)) with s = +1 for positive
        # labels and -1 otherwise; this stays finite where log(1 - sigmoid(z))
        # would round to log(0).
        sign = np.where(y > 0, 1.0, -1.0)
        log_liks = -np.logaddexp(0, -sign[:, np.newaxis] * preds)  # [n, n_samples]

        method_two = log_liks.mean(axis=1).sum()
        method_three = log_liks.sum(axis=0).mean()

        # log(mean_s p_s) evaluated as logsumexp_s(log p_s) - log(n_samples).
        log_mean_probs = np.logaddexp.reduce(log_liks, axis=1) - np.log(
            preds.shape[1]
        )
        method_four = log_mean_probs.sum()

        return method_one, method_two, method_three, method_four
