import numpy as np


class MarkovChain:
    """Time-homogeneous Markov chain over a finite set of states.

    States are identified by the integers ``0 .. num_states - 1``.  The chain
    is parameterised by

    * ``init_probs[c]``            = p(x_1 = c)
    * ``transition_probs[c', c]``  = p(x_j = c | x_{j-1} = c')

    Parameters
    ----------
    transition_probs : array-like, shape (num_states, num_states)
        Row ``c'`` is the distribution of the next state given the current
        state ``c'``.
    init_probs : array-like, shape (num_states,)
        Distribution of the first state.
    state_names : sequence, optional
        One label per state, used by :meth:`rename` to turn state indices
        into readable labels.
    """

    def __init__(self, transition_probs, init_probs, state_names=None):
        # p(x_1 = c) = init_probs[c]
        # p(x_j = c | x_{j-1} = c') = transition_probs[c', c]
        # np.asarray is a no-op for ndarrays and lets plain lists be passed.
        self.init_probs = np.asarray(init_probs)
        self.transition_probs = np.asarray(transition_probs)
        self.num_states = len(self.init_probs)
        assert self.init_probs.shape == (self.num_states,)
        assert self.transition_probs.shape == (self.num_states, self.num_states)

        # Always define the attribute so rename() works on unnamed chains.
        self.state_names = None
        if state_names is not None:
            assert len(state_names) == self.num_states
            self.state_names = np.asarray(state_names)

    def rename(self, ary):
        """Map an array of state indices to state names.

        Returns ``ary`` unchanged when the chain was built without
        ``state_names``.
        """
        if self.state_names is None:
            return ary
        return self.state_names[ary]

    def sample(self, n_samples, length, rng=None):
        """Draw independent sequences from the chain by ancestral sampling.

        Parameters
        ----------
        n_samples : int
            Number of independent sequences to draw.
        length : int
            Number of time steps in each sequence.
        rng : numpy.random.Generator, optional
            Source of randomness; a fresh ``default_rng()`` is used if omitted.

        Returns
        -------
        numpy.ndarray of int, shape (n_samples, length)
            ``samples[i, j]`` is the state index of sequence ``i`` at step ``j``.
        """
        if rng is None:
            rng = np.random.default_rng()

        samples = np.zeros((n_samples, length), dtype=int)
        if n_samples == 0 or length == 0:
            return samples

        samples[:, 0] = rng.choice(
            self.num_states, size=n_samples, p=self.init_probs
        )
        for j in range(1, length):
            prev = samples[:, j - 1]
            # Sequences sharing the same previous state share the same
            # transition row, so draw all of their next states in one call.
            for c in range(self.num_states):
                rows = prev == c
                count = int(np.count_nonzero(rows))
                if count:
                    samples[rows, j] = rng.choice(
                        self.num_states, size=count, p=self.transition_probs[c]
                    )

        return samples

    def marginals(self, length):
        """Compute the marginal distribution of every time step.

        Parameters
        ----------
        length : int
            Number of time steps.

        Returns
        -------
        numpy.ndarray of float, shape (num_states, length)
            ``margs[c, j]`` is the probability of being in state ``c`` at
            step ``j`` (0-based, so column 0 equals ``init_probs``).
        """
        margs = np.zeros((self.num_states, length))
        if length == 0:
            return margs

        margs[:, 0] = self.init_probs
        for j in range(1, length):
            # p(x_j = c) = sum_{c'} p(x_{j-1} = c') * transition_probs[c', c]
            margs[:, j] = margs[:, j - 1] @ self.transition_probs

        return margs

    def mode(self, length):
        """Find the most probable state sequence (Viterbi decoding).

        Parameters
        ----------
        length : int
            Number of time steps.

        Returns
        -------
        numpy.ndarray of int, shape (length,)
            State indices of the jointly most likely sequence.  Ties are
            broken in favour of the lowest state index.
        """
        if length == 0:
            return np.zeros(0, dtype=int)

        # Work in log space so long sequences do not underflow; log(0) = -inf
        # is the intended score of an impossible move, so silence the warning.
        with np.errstate(divide="ignore"):
            log_init = np.log(self.init_probs)
            log_trans = np.log(self.transition_probs)

        # best[c]: log-probability of the best prefix that ends in state c.
        best = log_init
        # backptr[j, c]: predecessor of state c at step j on that best prefix.
        backptr = np.zeros((length, self.num_states), dtype=int)
        for j in range(1, length):
            scores = best[:, None] + log_trans  # scores[c', c]
            backptr[j] = np.argmax(scores, axis=0)
            best = np.max(scores, axis=0)

        path = np.zeros(length, dtype=int)
        path[-1] = np.argmax(best)
        for j in range(length - 1, 0, -1):
            path[j - 1] = backptr[j, path[j]]
        return path

    # TODO: method here for mc-conditional-exact
