# mar_optimizer.py
from __future__ import annotations
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Tuple, Optional, Dict, List
from scipy.optimize import minimize

@dataclass
class PortfolioResult:
    """Optimized portfolio weights together with the composite's performance metrics.

    NOTE(review): the dataclass-generated ``__eq__`` compares the ndarray/Series
    fields directly, so ``==`` between two distinct instances will likely raise
    "truth value of an array is ambiguous"; compare fields explicitly instead.
    """
    weights: np.ndarray  # per-asset weights, in the column order of the input price table
    cagr: float  # compound annual growth rate of the composite (decimal; NaN if undefined)
    mean_return: float  # arithmetic mean of the per-period composite returns
    max_drawdown: float  # magnitude of the deepest peak-to-trough decline (>= 0, decimal)
    mar: float  # cagr / max_drawdown; may be -inf or NaN when that ratio is undefined
    composite_returns: pd.Series  # per-period weighted composite returns
    composite_curve: pd.Series  # compounded growth of 1.0 invested in the composite

def _as_years(dt_index: pd.DatetimeIndex) -> float:
    """Return the span between the first and last timestamps of *dt_index* in years.

    The span is measured in whole calendar days over a 365-day year. Spans shorter
    than one day (including a single-element index) count as one day, so the result
    is always strictly positive and safe to divide by.
    """
    span_days = (dt_index[-1] - dt_index[0]).days
    effective_days = 1 if span_days < 1 else span_days
    return effective_days / 365.0

def _price_df_to_returns(df_prices: pd.DataFrame) -> pd.DataFrame:
    """Convert a price table into simple per-period returns aligned across assets.

    If ``df_prices`` is not already indexed by a ``DatetimeIndex``, its first column is
    parsed as dates and used as the index; every remaining column must hold numeric
    prices. Rows are sorted chronologically. The caller's frame is not modified.

    Missing prices are forward-filled before computing returns, so a gap yields a 0
    return followed by the catch-up move. Non-finite returns (e.g. caused by a zero
    price) are discarded, and any row where some asset lacks a finite return is dropped
    (this always includes the first row, which has no prior price).

    Returns
    -------
    DataFrame
        One column per asset, indexed by date, containing only finite returns.
    """
    if isinstance(df_prices.index, pd.DatetimeIndex):
        df = df_prices.sort_index()
    else:
        # Expect the first column to be dates; remaining columns are prices.
        date_col = df_prices.columns[0]
        df = df_prices.copy()
        df[date_col] = pd.to_datetime(df[date_col])
        df = df.set_index(date_col).sort_index()

    prices = df.astype(float).ffill()
    # Explicit ffill + shift instead of pct_change(): pandas >= 2.1 deprecates the
    # implicit forward-fill in pct_change (FutureWarning) and its default is slated to
    # change, which would silently alter results. This reproduces the old behavior.
    rets = prices / prices.shift(1) - 1.0
    # Division by a zero price yields +/-inf, which would poison every downstream metric.
    rets = rets.replace([np.inf, -np.inf], np.nan)
    return rets.dropna(how="any")  # align across names

def _drawdown_stats(returns: pd.Series) -> Tuple[pd.Series, pd.Series, float]:
    """Compute the equity curve, drawdown series and maximum drawdown of *returns*.

    The curve compounds ``returns`` from a starting equity of 1.0. That starting
    equity counts as the initial peak, so a loss in the very first period is
    recorded as a drawdown. This matches CAGR, which measures growth relative to 1.0.

    Returns
    -------
    curve : Series
        Compounded equity, ``cumprod(1 + r)``.
    dd : Series
        ``curve / running_peak - 1``; zero at new highs, negative while under water.
    max_dd : float
        Magnitude of the deepest drawdown (>= 0); NaN for an empty series.
    """
    curve = (1 + returns).cumprod()
    # Floor the peak at the initial capital (1.0). Without this, a loss in the first
    # period, e.g. returns [-0.2, 0.1], would report a max drawdown of 0.
    running_peak = curve.cummax().clip(lower=1.0)
    dd = curve / running_peak - 1.0          # negative during drawdown
    max_dd = float(dd.min())                  # most negative
    return curve, dd, abs(max_dd)

def _metrics_from_returns(returns: pd.Series) -> Dict[str, float]:
    """Summarize a per-period return series as CAGR, mean return, max drawdown and MAR.

    Returns a dict with keys ``cagr``, ``mean_return``, ``max_drawdown`` and ``mar``
    (in that order). Every value is NaN for an empty series. ``cagr`` is NaN when the
    compounded equity ends at or below zero. ``mar`` is ``cagr / max_drawdown``; it
    falls back to ``-inf`` whenever that ratio is undefined (zero drawdown or
    non-finite CAGR), which ranks such a series below any finite MAR.
    """
    if returns.empty:
        return {"cagr": np.nan, "mean_return": np.nan, "max_drawdown": np.nan, "mar": np.nan}

    equity, _drawdowns, worst_dd = _drawdown_stats(returns)
    span_years = _as_years(returns.index)
    final_equity = float(equity.iloc[-1])

    # Annualize only when the equity survived; a wiped-out or negative ending
    # equity leaves CAGR undefined (NaN).
    if final_equity > 0:
        cagr = final_equity ** (1.0 / span_years) - 1.0
    else:
        cagr = np.nan

    mean_ret = float(returns.mean())

    if worst_dd > 0 and np.isfinite(cagr):
        mar = cagr / worst_dd
    else:
        mar = -np.inf

    return {
        "cagr": cagr,
        "mean_return": mean_ret,
        "max_drawdown": worst_dd,
        "mar": mar,
    }

def _comp_returns(rets_mat: np.ndarray, idx: pd.DatetimeIndex, w: np.ndarray) -> pd.Series:
    """Return the weighted composite return series ``rets_mat @ w``, indexed by *idx*.

    ``rets_mat`` holds one row per date and one column per asset; ``w`` holds one
    weight per asset (same column order).
    """
    weighted = np.matmul(rets_mat, w)
    return pd.Series(weighted, index=idx)

def _percent_to_decimal(x: Optional[float]) -> Optional[float]:
    """Normalize a drawdown cap given either as a percent or as a decimal fraction.

    ``None`` passes through unchanged. Non-positive inputs clamp to ``0.0``. Values
    above 1 are read as percentages (``10`` -> ``0.10``); values in ``(0, 1]`` are
    already decimals and are returned as floats.
    """
    if x is None:
        return None
    if x <= 0:
        return 0.0
    fraction = float(x)
    if x > 1:
        fraction /= 100.0
    return fraction

def _weight_bounds(wmax: float | np.ndarray, n: int) -> List[Tuple[float, float]]:
    """Return ``n`` per-asset ``(0.0, cap)`` bounds from a scalar or length-``n`` cap vector.

    Raises
    ------
    ValueError
        If ``wmax`` is array-like but not of shape ``(n,)``.
    """
    if np.isscalar(wmax):
        return [(0.0, float(wmax)) for _ in range(n)]
    caps = np.asarray(wmax, dtype=float)
    if caps.shape != (n,):
        raise ValueError("wmax vector must have length equal to number of assets.")
    return [(0.0, float(u)) for u in caps]


def _feasible_start(x0: np.ndarray, hi: np.ndarray) -> np.ndarray:
    """Map a starting guess onto ``{w : 0 <= w <= hi, sum(w) == 1}`` when the caps allow it.

    Weights are first clipped into the box ``[0, hi]``. An excess is removed by
    uniform rescaling, which can never push a weight above its cap. A shortfall is
    added in proportion to each asset's remaining headroom ``hi - w``, which also
    stays within the caps whenever ``sum(hi) >= 1``. If the caps are jointly
    infeasible (``sum(hi) < 1``), the result is simply filled up to the caps.
    """
    w = np.clip(x0, 0.0, hi)
    total = float(w.sum())
    if total > 1.0:
        return w / total
    headroom = hi - w
    room = float(headroom.sum())
    shortfall = 1.0 - total
    if shortfall > 0.0 and room > 0.0:
        # min() guards the infeasible-caps case where room < shortfall.
        w = w + headroom * min(1.0, shortfall / room)
    return w


def optimize_mar_from_prices(
    df_with_dates_and_prices: pd.DataFrame,
    wmax: float | np.ndarray = 1.0,
    n_restarts: int = 21,
    random_state: Optional[int] = 42,
    MaxDD: Optional[float] = None,
) -> PortfolioResult:
    """
    Optimize a composite portfolio given price history.

    Default behavior (MaxDD is None):
        Maximize MAR = CAGR / Max Drawdown, subject to 0 <= w_i <= wmax and sum w_i = 1.

    If MaxDD is provided (e.g., 10 for 10% or 0.10):
        Maximize CAGR subject to Max Drawdown <= MaxDD and the same weight constraints.

    Parameters
    ----------
    df_with_dates_and_prices : DataFrame
        Dates in the first column; remaining columns are price histories for strategies.
    wmax : float or array-like
        Per-asset cap(s). Scalar applies to all; array length must equal number of assets.
    n_restarts : int
        Number of starting points (Dirichlet samples + equal) for multi-start SLSQP.
        Each start is mapped onto the feasible weight set before optimizing.
    random_state : int or None
        RNG seed for reproducibility of restarts.
    MaxDD : float or None
        Maximum allowed drawdown. If > 1, interpreted as percent (e.g., 10 -> 10%).
        If between 0 and 1, interpreted as decimal (e.g., 0.10 -> 10%).

    Returns
    -------
    PortfolioResult

    Raises
    ------
    ValueError
        If no price columns are found, if the price series share no overlapping
        return dates, or if ``wmax`` is a vector of the wrong length.
    RuntimeError
        If SLSQP fails to converge from every starting point (this also happens
        when the caps are jointly infeasible, i.e. they sum to less than 1).
    """
    rets = _price_df_to_returns(df_with_dates_and_prices)
    if rets.shape[1] == 0:
        raise ValueError("No valid price columns with returns found.")
    if rets.shape[0] == 0:
        # Nothing to optimize: every objective evaluation would be NaN.
        raise ValueError("No overlapping return observations found across price columns.")
    n = rets.shape[1]
    idx = rets.index
    R = rets.values

    bounds = _weight_bounds(wmax, n)
    hi = np.array([ub for _, ub in bounds])

    # Sum-to-one constraint
    cons: List[Dict] = [{"type": "eq", "fun": lambda w: np.sum(w) - 1.0}]

    # Drawdown constraint if requested
    md_cap = _percent_to_decimal(MaxDD)
    if md_cap is not None:
        def dd_constraint(w: np.ndarray) -> float:
            comp = _comp_returns(R, idx, w)
            _, _, mdd = _drawdown_stats(comp)
            return md_cap - mdd  # >= 0 when constraint satisfied
        cons.append({"type": "ineq", "fun": dd_constraint})
    constraints = tuple(cons)

    def objective(w: np.ndarray) -> float:
        comp = _comp_returns(R, idx, w)
        m = _metrics_from_returns(comp)
        if md_cap is None:
            # maximize MAR
            val = -m["mar"]
        else:
            # maximize CAGR subject to MaxDD via constraint
            val = -m["cagr"] if np.isfinite(m["cagr"]) else 1e6
        # SLSQP cannot cope with inf/NaN; substitute a large finite penalty.
        if not np.isfinite(val):
            val = 1e6
        return float(val)

    # Starting points: equal weights plus Dirichlet samples (uniform on the simplex).
    rng = np.random.default_rng(random_state)
    inits = [np.full(n, 1.0 / n)]
    inits.extend(rng.dirichlet(np.ones(n)) for _ in range(max(0, n_restarts - 1)))

    best, best_val = None, np.inf
    for x0 in inits:
        res = minimize(
            objective,
            _feasible_start(x0, hi),
            method="SLSQP",
            bounds=bounds,
            constraints=constraints,
            options={"maxiter": 400, "ftol": 1e-9, "disp": False},
        )
        if res.success and res.fun < best_val:
            best, best_val = res, res.fun

    if best is None:
        raise RuntimeError("Optimization failed to converge for all initializations.")

    w_star = best.x
    comp = _comp_returns(R, idx, w_star)
    curve, _, _ = _drawdown_stats(comp)
    m = _metrics_from_returns(comp)

    return PortfolioResult(
        weights=w_star,
        cagr=m["cagr"],
        mean_return=m["mean_return"],
        max_drawdown=m["max_drawdown"],
        mar=m["mar"],
        composite_returns=comp.rename("Composite_Return"),
        composite_curve=curve.rename("Composite_Equity_Curve"),
    )
