from typing import Optional, Union, Iterable
from qiskit import QuantumCircuit, transpile
from qiskit_aer import AerSimulator
from qiskit.quantum_info import Statevector, partial_trace, DensityMatrix, Pauli
from qiskit.visualization.bloch import Bloch
from qiskit.result import marginal_counts
from qiskit.visualization import (
    plot_bloch_multivector,
    plot_bloch_vector, 
    plot_histogram, 
    plot_state_city, 
    plot_state_qsphere
)
import numpy as np
import matplotlib.pyplot as plt

# ---------- Interni helper ----------


def _get_statevector(state: Union[QuantumCircuit, Statevector]) -> Statevector:
    """
    Convert a QuantumCircuit or a Statevector into a Statevector.

    A Statevector is returned unchanged. A QuantumCircuit is copied, a
    ``save_statevector`` instruction is appended to the copy, and the copy is
    executed once on an ``AerSimulator`` in statevector mode.

    NOTE(review): if the circuit contains mid-circuit measurements, the saved
    state is the post-measurement state of that single shot.

    Args:
        state: QuantumCircuit or Statevector object.

    Returns:
        Statevector object.

    Raises:
        TypeError: If the argument is neither a QuantumCircuit nor a Statevector.
    """
    if isinstance(state, Statevector):
        return state

    if not isinstance(state, QuantumCircuit):
        raise TypeError("Argument mora da bude QuantumCircuit ili Statevector")

    backend = AerSimulator(method='statevector')
    # Work on a copy so the caller's circuit does not gain a save instruction.
    circuit = state.copy()
    circuit.save_statevector()
    compiled = transpile(circuit, backend)
    job_result = backend.run(compiled, shots=1).result()
    return job_result.get_statevector()

# ---------- Utilities ----------


def get_state(state: Union[QuantumCircuit, Statevector],
              threshold: float = 1e-10) -> str:
    """
    Return a sign-only string representation of a state vector (no amplitudes).

    Every basis state whose amplitude magnitude exceeds ``threshold`` is shown
    as a ket. Consecutive kets are separated by `` + `` or `` - `` according to
    the sign of the real part of the amplitude; purely imaginary amplitudes
    count as positive. Example: (|00> - |11>)/sqrt(2) -> "|00> - |11>".

    Args:
        state: QuantumCircuit or Statevector.
        threshold: Minimum amplitude magnitude for a term to be shown
            (default: 1e-10).

    Returns:
        State string such as "|00> + |11>". A negative first term is written
        with a bare leading minus ("-|01> + |10>"). "|0>" is returned if no
        amplitude exceeds the threshold.
    """
    sv = _get_statevector(state)
    width = sv.num_qubits
    pieces = []

    for i, amp in enumerate(sv.data):
        if np.abs(amp) > threshold:
            ket = f"|{format(i, f'0{width}b')}>"
            negative = np.real(amp) < 0
            if pieces:
                # Build the separator directly; post-hoc str.replace on the
                # joined string produced doubled spaces ("|00>  + |11>").
                pieces.append(f" - {ket}" if negative else f" + {ket}")
            else:
                # First term: no separator, only a leading minus if needed.
                pieces.append(f"-{ket}" if negative else ket)

    if not pieces:
        return "|0>"

    return "".join(pieces)


def get_full_state(state: Union[QuantumCircuit, Statevector],
                   threshold: float = 1e-10,
                   precision: int = 3) -> str:
    """
    Return the full string representation of a state vector, with amplitudes.

    Each basis state whose amplitude magnitude exceeds ``threshold`` becomes a
    term ``<amplitude>|bits>``.
      * Real and purely imaginary amplitudes are written bare:
        ``0.707|00>``, ``i|01>``, ``-0.500i|10>``.
      * Amplitudes with both a real and an imaginary part are parenthesised:
        ``(0.500-0.500i)|11>``.
    Terms are joined with `` + ``; a `` + -`` sequence is collapsed to `` - ``.

    Args:
        state: QuantumCircuit or Statevector.
        threshold: Minimum amplitude magnitude for a term to be shown; real or
            imaginary parts smaller than this are treated as zero.
        precision: Number of decimals used when printing amplitudes.

    Returns:
        State string with amplitudes, or "0" if no amplitude exceeds the
        threshold.
    """
    sv = _get_statevector(state)
    width = sv.num_qubits
    terms = []

    for i, amp in enumerate(sv.data):
        if not np.abs(amp) > threshold:
            continue

        real = np.real(amp)
        imag = np.imag(amp)
        # Only a two-part complex number needs grouping: without parentheses
        # "0.500-0.500i|01>" reads as two separate terms.
        needs_parens = False

        if np.abs(imag) < threshold:
            # Real part only
            amp_str = f"{real:.{precision}f}"
        elif np.abs(real) < threshold:
            # Imaginary part only; unit magnitudes are shown as +/- i
            if np.abs(imag - 1) < threshold:
                amp_str = "i"
            elif np.abs(imag + 1) < threshold:
                amp_str = "-i"
            else:
                amp_str = f"{imag:.{precision}f}i"
        else:
            # Both parts present
            sign = '+' if imag >= 0 else '-'
            amp_str = f"{real:.{precision}f}{sign}{abs(imag):.{precision}f}i"
            needs_parens = True

        ket = f"|{format(i, f'0{width}b')}>"
        terms.append(f"({amp_str}){ket}" if needs_parens else f"{amp_str}{ket}")

    if not terms:
        return "0"

    # A bare negative amplitude after " + " becomes a subtraction.
    return " + ".join(terms).replace(" + -", " - ")


def print_state(state: Union[QuantumCircuit, Statevector],
                label: str = "",
                show_amplitudes: bool = False,
                precision: int = 3) -> None:
    """
    Print a quantum state to the console.

    Args:
        state: QuantumCircuit or Statevector.
        label: Optional label; when non-empty the output is "<label> = <state>".
        show_amplitudes: If True, print amplitudes (via ``get_full_state``);
            otherwise print only signs (via ``get_state``).
        precision: Number of decimals for amplitudes (used only when
            ``show_amplitudes`` is True).
    """
    # Resolve once so a circuit is simulated a single time.
    statevector = _get_statevector(state)

    rendered = (
        get_full_state(statevector, precision=precision)
        if show_amplitudes
        else get_state(statevector)
    )

    print(f"{label} = {rendered}" if label else rendered)


def print_raw_state(qc: Union[QuantumCircuit, Statevector],
                    label: str = "",
                    threshold: float = 1e-10) -> None:
    """
    Print every non-negligible amplitude of a state as a raw complex number.

    One line is printed per basis state, formatted as ``|bits⟩: <amplitude>``.
    The amplitude is shown with 6 decimals.

    Args:
        qc: QuantumCircuit or Statevector (anything ``Statevector(...)``
            accepts).
        label: Optional heading printed on its own line before the amplitudes.
        threshold: Amplitudes with magnitude at or below this value are
            skipped (default: 1e-10).
    """
    sv = Statevector(qc)
    # Pad the bitstring to the actual register size; a hard-coded width of 3
    # mislabelled every circuit that does not have exactly 3 qubits.
    width = sv.num_qubits

    if label:
        print(f"\n{label}:")

    for i, amp in enumerate(sv.data):
        if abs(amp) > threshold:
            print(f"|{i:0{width}b}⟩: {amp:.6f}")


def get_probabilities(state: Union[QuantumCircuit, Statevector],
                      threshold: float = 1e-10) -> dict:
    """
    Return the measurement probabilities of all basis states.

    Args:
        state: QuantumCircuit or Statevector.
        threshold: Basis states with probability at or below this value are
            omitted (default: 1e-10, the previously hard-coded cut-off).

    Returns:
        Dictionary ``{basis_bitstring: probability}``. Bitstrings are
        zero-padded to the number of qubits; entries are in increasing
        basis-index order.
    """
    sv = _get_statevector(state)
    width = sv.num_qubits
    # |amplitude|^2 for all basis states at once.
    probabilities = np.abs(sv.data) ** 2

    return {
        format(index, f'0{width}b'): float(prob)
        for index, prob in enumerate(probabilities)
        if prob > threshold
    }


def print_probabilities(state: Union[QuantumCircuit, Statevector],
                        label: str = "") -> None:
    """
    Print the probabilities of all basis states with non-negligible weight.

    Each line has the form ``|bits>: p (p%)``, sorted by bitstring.

    Args:
        state: QuantumCircuit or Statevector.
        label: Optional heading printed on its own line first.
    """
    distribution = get_probabilities(state)

    if label:
        print(f"\n{label}:")

    # Keys are unique, so sorting keys gives the same order as sorting items.
    for bitstring in sorted(distribution):
        p = distribution[bitstring]
        print(f"|{bitstring}>: {p:.4f} ({p*100:.2f}%)")


def _to_density(state: Union[QuantumCircuit, Statevector, DensityMatrix]):
    """
    Convert a state into a DensityMatrix.

    A DensityMatrix is returned unchanged. A Statevector is wrapped. A
    QuantumCircuit is first turned into a Statevector via
    ``Statevector.from_instruction``.

    Raises:
        TypeError: For any other input type.
    """
    if isinstance(state, DensityMatrix):
        return state

    if isinstance(state, QuantumCircuit):
        state = Statevector.from_instruction(state)
    elif not isinstance(state, Statevector):
        raise TypeError("Unsupported state type.")

    return DensityMatrix(state)


def matrix_power(dm: DensityMatrix, p: float) -> DensityMatrix:
    """
    Compute rho^p through a Hermitian eigen-decomposition.

    Negative eigenvalues (numerical noise) are clamped to 0 before the power
    is applied.

    Args:
        dm: Density matrix rho.
        p: Exponent (e.g. 0.5 for the matrix square root).

    Returns:
        DensityMatrix holding V diag(lambda^p) V^dagger.
    """
    eigenvalues, eigenvectors = np.linalg.eigh(dm.data)
    nonneg = np.maximum(eigenvalues, 0)
    powered_diag = np.diag(nonneg ** p)
    result = eigenvectors @ powered_diag @ eigenvectors.conj().T
    return DensityMatrix(result)


def _reduce_density(dm: DensityMatrix, keep: list) -> DensityMatrix:
    """
    Reduce ``dm`` to the qubits listed in ``keep`` by tracing out all others.

    A state that already has exactly ``len(keep)`` qubits is returned as-is.
    It is treated as already being the compared subsystem, e.g. a 1-qubit
    reference state compared against qubit 2 of a larger register.

    Raises:
        ValueError: If an index in ``keep`` is outside the state's qubits.
    """
    n = dm.num_qubits
    if len(keep) == n:
        return dm

    invalid = [q for q in keep if not 0 <= q < n]
    if invalid:
        raise ValueError(
            f"Invalid qubit indices {invalid} for a {n}-qubit state."
        )

    trace_out = [q for q in range(n) if q not in keep]
    return DensityMatrix(partial_trace(dm, trace_out))


def get_fidelity(
    state1: Union[QuantumCircuit, Statevector, DensityMatrix],
    state2: Union[QuantumCircuit, Statevector, DensityMatrix],
    qubits: Optional[Iterable[int]] = None,
) -> float:
    """
    Compute the (Uhlmann) fidelity F = (Tr sqrt(sqrt(rho) sigma sqrt(rho)))^2.

    - Works for pure-pure, pure-mixed and mixed-mixed comparisons.
    - Subsystem comparison: when ``qubits`` is given, every state with more
      qubits than ``len(qubits)`` is reduced (partial trace) to those qubits.
      A state that already has exactly ``len(qubits)`` qubits is used as-is.
      This allows comparing, e.g., qubit 2 of a 3-qubit circuit with a
      1-qubit reference state.

    Args:
        state1, state2: Quantum states (QuantumCircuit, Statevector or
            DensityMatrix).
        qubits: Qubit indices to compare (e.g. [2]). If None, the full
            systems are compared and must have the same number of qubits.

    Returns:
        Fidelity value, clipped to [0, 1] to absorb round-off.

    Raises:
        ValueError: If the dimensions differ and ``qubits`` is None, if
            ``qubits`` contains duplicates or out-of-range indices, or if the
            reduced states still have different sizes.
    """
    dm1 = _to_density(state1)
    dm2 = _to_density(state2)

    if qubits is not None:
        # Materialise once: an iterator argument would otherwise be exhausted
        # by the first reduction.
        keep = list(qubits)
        if len(set(keep)) != len(keep):
            raise ValueError(f"Duplicate qubit indices in {keep}.")
        # DensityMatrix has no .reduce(); reduce via partial_trace instead.
        dm1 = _reduce_density(dm1, keep)
        dm2 = _reduce_density(dm2, keep)
    elif dm1.num_qubits != dm2.num_qubits:
        raise ValueError(
            "State dimensions differ. Specify qubits=[...] for subsystem fidelity."
        )

    if dm1.num_qubits != dm2.num_qubits:
        raise ValueError(
            f"Reduced states have different sizes "
            f"({dm1.num_qubits} vs {dm2.num_qubits} qubits)."
        )

    # rho^{1/2} sigma rho^{1/2}
    sqrt_rho = matrix_power(dm1, 0.5).data
    product = sqrt_rho @ dm2.data @ sqrt_rho

    # Clamp tiny negative eigenvalues caused by round-off before sqrt.
    evals = np.maximum(np.linalg.eigvalsh(product), 0)

    # (Tr sqrt(product))^2
    fidelity = np.sum(np.sqrt(evals)) ** 2

    return float(np.clip(np.real(fidelity), 0.0, 1.0))


def _single_qubit_bloch_vector(sv: Statevector, qubit: int) -> list:
    """
    Return the Bloch vector [<X>, <Y>, <Z>] of one qubit of ``sv``.

    The qubit's reduced density matrix is obtained by tracing out all other
    qubits of the full state.
    """
    X = np.array([[0, 1],
                  [1, 0]], dtype=complex)
    Y = np.array([[0, -1j],
                  [1j, 0]], dtype=complex)
    Z = np.array([[1, 0],
                  [0, -1]], dtype=complex)

    n = sv.num_qubits
    if n > 1:
        others = [i for i in range(n) if i != qubit]
        rho = DensityMatrix(partial_trace(sv, others))
    else:
        rho = DensityMatrix(sv)

    return [float(np.real(np.trace(rho.data @ pauli))) for pauli in (X, Y, Z)]


def show_bloch_sphere(qc: Union[QuantumCircuit, Statevector],
                      from_instruction: bool = True,
                      qubit_index: Optional[int] = None,
                      title: Optional[str] = None,
                      figsize: tuple = (10, 5)) -> None:
    """
    Show the Bloch sphere of every qubit, or of one selected qubit.

    Each qubit's Bloch vector comes from its reduced density matrix (partial
    trace over the other qubits). Entangled qubits therefore appear as
    vectors shorter than 1.

    Args:
        qc: QuantumCircuit or Statevector.
        from_instruction: If True, a circuit is converted with
            ``Statevector.from_instruction`` (mathematical, faster); otherwise
            it is simulated with AerSimulator.
        qubit_index: None -> show all qubits (at most the first 5);
            an int -> show only that qubit.
        title: Optional figure title.
        figsize: Figure size for the single-qubit view; the all-qubit view
            uses (4 * number_of_spheres, 4).

    Raises:
        ValueError: If ``qubit_index`` is out of range.
    """
    # ---- Statevector ----
    if isinstance(qc, Statevector):
        sv = qc
    elif from_instruction:
        sv = Statevector.from_instruction(qc)
    else:
        sv = _get_statevector(qc)

    num_qubits = sv.num_qubits

    # ---- One selected qubit ----
    if qubit_index is not None:
        if qubit_index < 0 or qubit_index >= num_qubits:
            raise ValueError(
                f"Nevažeći qubit_index {qubit_index}. Mora biti između 0 i {num_qubits-1}."
            )

        fig = plt.figure(figsize=figsize)
        ax = fig.add_subplot(111, projection='3d')
        plot_bloch_vector(_single_qubit_bloch_vector(sv, qubit_index),
                          title=f"q{qubit_index}", ax=ax)

        if title:
            plt.suptitle(title, fontsize=16)

        plt.tight_layout()
        plt.show()
        return

    # ---- All qubits (capped at 5 spheres) ----
    shown = num_qubits
    if num_qubits > 5:
        print(f"Upozorenje: Previše qubita ({num_qubits}) — prikazujem samo prvih 5.")
        shown = 5
    # Reduced states are always taken from the FULL state vector. Truncating
    # the amplitude array to 2**5 entries would yield an unnormalised state
    # conditioned on the higher qubits being |0>, i.e. wrong Bloch vectors.

    fig, axes = plt.subplots(1, shown, figsize=(4 * shown, 4),
                             subplot_kw={'projection': '3d'})
    if shown == 1:
        axes = [axes]

    for q in range(shown):
        plot_bloch_vector(_single_qubit_bloch_vector(sv, q),
                          title=f"q{q}", ax=axes[q])

    if title:
        plt.suptitle(title, fontsize=16)

    plt.tight_layout()
    plt.show()


def show_qsphere(state: Union[QuantumCircuit, Statevector],
                 style: str = "qsphere") -> None:
    """
    Visualise the state of a 2-qubit system.

    Args:
        state: QuantumCircuit or Statevector (must have exactly 2 qubits).
        style: Plot style, either 'city' (``plot_state_city``) or 'qsphere'
            (``plot_state_qsphere``).

    Raises:
        ValueError: If the state does not have exactly 2 qubits, or the style
            is not one of the supported values.
    """
    vector = _get_statevector(state)

    if vector.num_qubits != 2:
        raise ValueError(f"Funkcija radi samo sa 2 qubita, a prosleđeno je {vector.num_qubits}")

    if style not in ("city", "qsphere"):
        raise ValueError("Style mora da bude 'city' ili 'qsphere'")

    figure = plot_state_city(vector) if style == "city" else plot_state_qsphere(vector)

    plt.tight_layout()
    plt.show()
    plt.close(figure)


def show_qc(qc: QuantumCircuit, figsize: tuple = (12, 6)) -> None:
    """
    Draw a quantum circuit with the matplotlib drawer ('iqp' style, no folding).

    Args:
        qc: QuantumCircuit to display.
        figsize: Figure size applied when the drawer returns an object that
            supports ``set_size_inches``.
    """
    figure = qc.draw(output='mpl', style='iqp', fold=-1)

    resize = getattr(figure, 'set_size_inches', None)
    if resize is not None:
        resize(figsize)

    plt.tight_layout()
    plt.show()
    plt.close(figure)


def _select_bits_counts(counts: dict, selected: list) -> dict:
    """
    Re-key ``counts`` to only the selected bit indices, summing merged keys.

    Bit index 0 is the rightmost character of a counts key (Qiskit ordering).
    The new keys list the selected bits from highest to lowest index.
    Register-separating spaces in the keys are ignored.
    """
    order = sorted(selected, reverse=True)  # MSB first, e.g. q2 q0
    filtered = {}
    for bitstring, count in counts.items():
        bits = bitstring.replace(" ", "")
        width = len(bits)
        key = "".join(bits[width - 1 - q] for q in order)
        filtered[key] = filtered.get(key, 0) + count
    return filtered


def show_measurement(qc: QuantumCircuit,
                     shots: int = 1024,
                     figsize: tuple = (10, 6),
                     label_size: int = 10,
                     marginal_counts_flag: bool = False,
                     selected_qubits: Optional[list[int]] = None) -> dict:
    """
    Simulate measuring a circuit and show a histogram of the results.

    If the circuit has no classical bits, ``measure_all`` is applied to a copy
    first. Indices in ``selected_qubits`` address classical bits (positions in
    the counts keys). With ``measure_all`` or ``qc.measure(i, i)`` these
    coincide with qubit indices.

    Args:
        qc: QuantumCircuit to measure.
        shots: Number of shots.
        figsize: Figure size.
        label_size: Font size of the x-axis tick labels.
        marginal_counts_flag: If True, show one marginal distribution per bit,
            side by side.
        selected_qubits: Bits to show. Examples:
                         None  -> show all
                         [2]   -> show only q2
                         [0,2] -> show only q0 and q2
            In marginal mode, indices >= the number of classical bits are
            skipped.

    Returns:
        The full (unfiltered) counts dictionary.

    Raises:
        ValueError: If a selected index is out of range (joint mode), or if no
            valid index remains (marginal mode).
    """
    simulator = AerSimulator()
    qc_copy = qc.copy()

    # No classical bits -> measure every qubit.
    if qc_copy.num_clbits == 0:
        qc_copy.measure_all()

    compiled_circuit = transpile(qc_copy, simulator)
    result = simulator.run(compiled_circuit, shots=shots).result()
    counts = result.get_counts()

    # Counts keys carry one character per CLASSICAL bit, so that (not the
    # qubit count) bounds valid indices and the marginal range.
    num_bits = qc_copy.num_clbits

    if selected_qubits is not None and not marginal_counts_flag:
        # Joint distribution of only the selected bits.
        invalid = [q for q in selected_qubits if not 0 <= q < num_bits]
        if invalid:
            raise ValueError(
                f"Nevažeći indeksi qubita {invalid}. Moraju biti između 0 i {num_bits-1}."
            )
        fig = plot_histogram(_select_bits_counts(counts, selected_qubits),
                             figsize=figsize)

    elif marginal_counts_flag:
        # One marginal distribution per bit: [counts(q0), counts(q1), ...]
        all_marginal_counts = [marginal_counts(counts, [bit]) for bit in range(num_bits)]

        if selected_qubits is not None:
            # Keep only the requested marginals (best-effort: skip indices
            # beyond the available bits).
            chosen = [q for q in selected_qubits if q < num_bits]
            if not chosen:
                raise ValueError(
                    f"Nijedan od {selected_qubits} nije validan indeks (0..{num_bits-1})."
                )
        else:
            # selected_qubits is None -> show every marginal.
            chosen = list(range(num_bits))

        # plot_histogram draws a list of dicts side by side; the legend tells
        # the otherwise identically labelled ("0"/"1") series apart.
        fig = plot_histogram([all_marginal_counts[q] for q in chosen],
                             figsize=figsize,
                             legend=[f"q{q}" for q in chosen])

    else:
        # Full distribution.
        fig = plot_histogram(counts, figsize=figsize)

    # Styling
    ax = fig.axes[0]
    ax.tick_params(axis='x', labelsize=label_size)

    plt.tight_layout()
    plt.show()
    plt.close(fig)

    return counts


def compute_orbits(a: int, N: int):
    """
    Compute the orbits of the map y -> (a * y) mod N over 0..N-1.

    Starting values are taken in increasing order. Each orbit is traced from
    its start until an already-visited element is reached, so every integer in
    range(N) appears in exactly one orbit. When gcd(a, N) != 1 the map is not a
    permutation, and a traced sequence may end by running into an earlier
    orbit rather than closing on itself.

    Args:
        a: Multiplier of the modular map.
        N: Modulus; values 0..N-1 are partitioned (N <= 0 gives []).

    Returns:
        List of orbits, each a list of ints in visiting order.
    """
    seen = set()

    def trace(start):
        # Follow the map from `start`, claiming every newly reached element.
        path = []
        node = start
        while node not in seen:
            seen.add(node)
            path.append(node)
            node = (a * node) % N
        return path

    # The condition is evaluated per element in order, after earlier trace()
    # calls have already populated `seen`.
    return [trace(start) for start in range(N) if start not in seen]
