# mar_optimizer.py
from __future__ import annotations

import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Optional, Dict, List, Union
from scipy.optimize import minimize


@dataclass
class PortfolioResult:
    """Outcome of ``optimize_mar_from_prices``: the chosen weights together with
    the resulting composite portfolio's metrics and return/equity series."""

    weights: np.ndarray  # optimized weight per asset, in the price-column order
    cagr: float  # annualized growth over the price data's first-to-last date span
    mean_return: float  # arithmetic mean of the per-period composite returns
    max_drawdown: float  # max drawdown as a positive fraction (0.25 == 25%)
    mar: float  # cagr / max_drawdown; -inf if max_drawdown is not > 0 or cagr is non-finite
    composite_returns: pd.Series  # weighted per-period composite returns, indexed by date
    composite_curve: pd.Series  # cumulative product of (1 + composite_returns)


# ---------- Helpers for dates / returns / drawdowns ----------

def _as_years_from_dates(start: pd.Timestamp, end: pd.Timestamp) -> float:
    """
    Return the span between two timestamps in 365-day years.

    The span is measured in whole days and floored at one day, so equal or
    reversed dates still give a small positive value (never zero).
    """
    elapsed_days = (end - start).days
    if elapsed_days < 1:
        elapsed_days = 1
    return elapsed_days / 365.0


def _prepare_prices(df_with_dates_and_prices: pd.DataFrame) -> pd.DataFrame:
    """
    Return a date-sorted, float-typed price table indexed by date.

    If the input already has a DatetimeIndex it is used as-is; otherwise the
    FIRST column is parsed with ``pd.to_datetime`` and becomes the index.
    Every remaining column is cast to float. The input frame is not modified.
    """
    raw = df_with_dates_and_prices

    if isinstance(raw.index, pd.DatetimeIndex):
        framed = raw
    else:
        date_col = raw.columns[0]
        framed = raw.copy()
        framed[date_col] = pd.to_datetime(framed[date_col])
        framed = framed.set_index(date_col)

    # Sorting returns a new frame, so the caller's data is never touched.
    return framed.sort_index().astype(float)


def _returns_from_prices(prices: pd.DataFrame) -> pd.DataFrame:
    """
    Convert price history to simple period returns per column, aligned across assets.

    Interior gaps in a price column are forward-filled first, so a missing
    observation contributes a 0% return and the move is booked on the next
    available price. Any date on which at least one column still has no return
    (the first row, or rows before a column's first valid price) is dropped so
    that all columns share the same dates.
    """
    # pct_change()'s implicit fill_method="pad" is deprecated since pandas 2.1
    # (FutureWarning when gaps exist) and is not the default in newer pandas.
    # Forward-fill explicitly so the result does not depend on the pandas version.
    filled = prices.ffill()
    rets = filled.pct_change(fill_method=None)
    # A row-wise "any" drop already removes all-NaN rows, so one pass suffices.
    return rets.dropna(how="any")


def _drawdown_stats(returns: pd.Series) -> tuple[pd.Series, pd.Series, float]:
    """
    Compute equity curve, drawdown series, and max drawdown (as positive number).

    The equity curve is the cumulative product of ``1 + returns``; it does not
    include the implicit starting equity of 1.0, but that starting value IS
    treated as the first peak when measuring drawdowns.

    Returns
    -------
    (curve, dd, max_dd)
        curve  : cumulative equity after each period.
        dd     : curve / running peak - 1 (<= 0; negative while in drawdown).
        max_dd : magnitude of the deepest drawdown (NaN for empty input).
    """
    curve = (1.0 + returns).cumprod()
    # The portfolio holds equity 1.0 before the first return, so the running
    # peak can never be below 1.0. Without this floor, losses in the opening
    # period(s) are invisible (e.g. returns [-0.1, 0.05] would report 0% MaxDD).
    running_peak = curve.cummax().clip(lower=1.0)
    dd = curve / running_peak - 1.0       # negative in drawdown
    max_dd = float(dd.min())              # most negative
    return curve, dd, abs(max_dd)


def _metrics_from_returns(
    returns: pd.Series,
    price_start: pd.Timestamp,
    price_end: pd.Timestamp,
) -> Dict[str, float]:
    """
    Summarize a composite return series as CAGR, mean return, MaxDD and MAR.

    The CAGR is annualized over the PRICE span (price_start -> price_end),
    not over the possibly shorter return index. MAR is CAGR / MaxDD and falls
    back to -inf when MaxDD is not positive or CAGR is not finite. An empty
    series yields NaN for every metric.
    """
    if returns.empty:
        return {
            "cagr": np.nan,
            "mean_return": np.nan,
            "max_drawdown": np.nan,
            "mar": np.nan,
        }

    equity, _, max_dd = _drawdown_stats(returns)
    span_years = _as_years_from_dates(price_start, price_end)
    terminal = float(equity.iloc[-1])

    # A non-positive terminal value (total wipe-out) has no real annualized root.
    cagr = np.nan
    if terminal > 0 and span_years > 0:
        cagr = terminal ** (1.0 / span_years) - 1.0

    ratio_defined = max_dd > 0 and np.isfinite(cagr)
    mar = cagr / max_dd if ratio_defined else -np.inf

    return {
        "cagr": cagr,
        "mean_return": float(returns.mean()),
        "max_drawdown": max_dd,
        "mar": mar,
    }


def _comp_returns(rets_mat: np.ndarray, idx: pd.DatetimeIndex, w: np.ndarray) -> pd.Series:
    """
    Blend per-asset returns into one composite return per period.

    Each row of ``rets_mat`` (one period) is dotted with the weight vector
    ``w``; the resulting values are labelled with ``idx``.
    """
    weighted = rets_mat @ w  # one composite value per row
    return pd.Series(data=weighted, index=idx)


def _percent_to_decimal(x: Optional[float]) -> Optional[float]:
    """
    Normalize a MaxDD input to a decimal fraction.

    ``None`` passes through unchanged; values <= 0 become 0.0; values above 1
    are read as percentages (10 -> 0.10); anything in (0, 1] is already a
    decimal and is returned as a float.
    """
    if x is None:
        return None
    if x <= 0:
        return 0.0

    value = float(x)
    if x > 1:
        value /= 100.0  # percent -> decimal
    return value


# ---------- Main optimizer ----------

def optimize_mar_from_prices(
    df_with_dates_and_prices: pd.DataFrame,
    wmax: Union[float, np.ndarray] = 1.0,
    n_restarts: int = 21,
    random_state: Optional[int] = 42,
    MaxDD: Optional[float] = None,
) -> PortfolioResult:
    """
    Optimize a composite portfolio given price history.

    Default behavior (MaxDD is None):
        Maximize MAR = CAGR / Max Drawdown,
        subject to 0 <= w_i <= wmax and sum(w_i) = 1.

    If MaxDD is provided (e.g., 10 for 10% or 0.10):
        Maximize CAGR subject to Max Drawdown <= MaxDD
        and the same weight constraints.

    Parameters
    ----------
    df_with_dates_and_prices : DataFrame
        Price histories, one column per strategy. Dates come from the index if
        it is already a DatetimeIndex, otherwise from the first column.
    wmax : float or array-like
        Per-asset cap(s). A scalar (including a 0-d array) applies to all
        assets; a vector must have one entry per asset.
    n_restarts : int
        Number of starting points (equal weights + Dirichlet samples) for
        multi-start SLSQP.
    random_state : int or None
        RNG seed for reproducibility of restarts.
    MaxDD : float or None
        Maximum allowed drawdown. If > 1, interpreted as percent (e.g., 10 -> 10%).
        If between 0 and 1, interpreted as decimal (e.g., 0.10 -> 10%).
        Values <= 0 are treated as a 0% cap.

    Returns
    -------
    PortfolioResult
        Weights are in the column order of the price data.

    Raises
    ------
    ValueError
        If there are no price columns, fewer than two price rows, no dates left
        after aligning returns across assets, or a wmax vector of the wrong length.
    RuntimeError
        If no SLSQP start reports success.
    """
    # 1) Prepare prices and remember ORIGINAL start/end dates (CAGR spans these).
    prices = _prepare_prices(df_with_dates_and_prices)
    if prices.shape[1] == 0:
        raise ValueError("No valid price columns found.")

    price_idx = prices.index
    if len(price_idx) < 2:
        raise ValueError("Need at least two price observations to compute CAGR.")

    price_start = price_idx[0]
    price_end = price_idx[-1]

    # 2) Returns matrix (index may be shorter than price index, but CAGR uses
    #    price_start/price_end).
    rets = _returns_from_prices(prices)
    # NaN cleaning drops ROWS, so the row count must be checked too: with no
    # dates left every metric is NaN, the objective collapses to its constant
    # penalty, and any "optimal" weights would be meaningless.
    if rets.shape[0] == 0 or rets.shape[1] == 0:
        raise ValueError("No valid return series after cleaning NaNs.")

    n = rets.shape[1]
    idx = rets.index
    R = rets.to_numpy()

    # 3) Bounds 0 <= w_i <= wmax (a 0-d array is accepted as a scalar cap).
    if np.isscalar(wmax) or (isinstance(wmax, np.ndarray) and wmax.ndim == 0):
        caps = np.full(n, float(wmax))
    else:
        caps = np.asarray(wmax, dtype=float)
        if caps.shape != (n,):
            raise ValueError("wmax vector must have length equal to number of assets.")
    bounds = [(0.0, float(u)) for u in caps]
    floors = np.zeros(n)

    # 4) Sum(w) = 1 constraint
    cons: List[Dict] = [{"type": "eq", "fun": lambda w: np.sum(w) - 1.0}]

    # 5) Drawdown constraint if requested
    md_cap = _percent_to_decimal(MaxDD)
    if md_cap is not None:
        def dd_constraint(w: np.ndarray) -> float:
            comp = _comp_returns(R, idx, w)
            _, _, mdd = _drawdown_stats(comp)
            # SLSQP "ineq" requires fun(w) >= 0, i.e. mdd <= md_cap.
            return md_cap - mdd

        cons.append({"type": "ineq", "fun": dd_constraint})
    constraints = tuple(cons)

    rng = np.random.default_rng(random_state)

    # 6) Objective (minimized): negative MAR, or negative CAGR under the cap.
    def objective(w: np.ndarray) -> float:
        comp = _comp_returns(R, idx, w)
        m = _metrics_from_returns(comp, price_start, price_end)
        val = -m["mar"] if md_cap is None else -m["cagr"]
        # Undefined metrics (NaN / +-inf) are replaced by a large finite
        # penalty so the value handed to SLSQP is always finite.
        if not np.isfinite(val):
            val = 1e6
        return float(val)

    # 7) Multi-start initialization: equal weights, then Dirichlet samples.
    inits = [np.full(n, 1.0 / n)]
    inits.extend(rng.dirichlet(np.ones(n)) for _ in range(max(0, n_restarts - 1)))

    best = None
    best_val = np.inf

    for x0 in inits:
        # Project the initial guess onto the bounds, then renormalize to sum to 1.
        # Renormalizing may push an entry back above its cap; SLSQP still
        # receives the bounds explicitly.
        x0 = np.clip(x0, floors, caps)
        s = x0.sum()
        x0 = (x0 / s) if s > 0 else np.full(n, 1.0 / n)

        res = minimize(
            objective,
            x0,
            method="SLSQP",
            bounds=bounds,
            constraints=constraints,
            options={"maxiter": 400, "ftol": 1e-9, "disp": False},
        )

        if res.success and res.fun < best_val:
            best, best_val = res, res.fun

    if best is None:
        raise RuntimeError("Optimization failed to converge for all initializations.")

    # 8) Final portfolio series & metrics
    w_star = best.x
    comp = _comp_returns(R, idx, w_star)
    curve, _, _ = _drawdown_stats(comp)
    m = _metrics_from_returns(comp, price_start, price_end)

    return PortfolioResult(
        weights=w_star,
        cagr=m["cagr"],
        mean_return=m["mean_return"],
        max_drawdown=m["max_drawdown"],
        mar=m["mar"],
        composite_returns=comp.rename("Composite_Return"),
        composite_curve=curve.rename("Composite_Equity_Curve"),
    )
