# mar_optimizer.py
from __future__ import annotations
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Tuple, Optional, Dict
from scipy.optimize import minimize

@dataclass
class PortfolioResult:
    """Outcome of a MAR-ratio portfolio optimization.

    Bundles the selected weights with the summary statistics and the
    per-period series of the resulting composite portfolio.
    """
    # Weight assigned to each asset, as returned by the optimizer.
    weights: np.ndarray
    # Compound annual growth rate of the composite portfolio (NaN if undefined).
    cagr: float
    # Arithmetic mean of the composite per-period returns.
    mean_return: float
    # Largest peak-to-trough loss, stored as a non-negative magnitude.
    max_drawdown: float
    # MAR ratio: cagr divided by max_drawdown.
    mar: float
    # Per-period returns of the weighted portfolio.
    composite_returns: pd.Series
    # Cumulative product of (1 + composite_returns), i.e. the equity curve.
    composite_curve: pd.Series

def _as_months(dt_index: pd.DatetimeIndex) -> float:
    """Return the time span covered by ``dt_index`` expressed in years.

    NOTE: despite its name, this helper returns *years*, not months: the
    whole number of days between the first and last timestamp divided by
    365. The name is kept unchanged so existing callers keep working.

    Args:
        dt_index: Datetime index whose first and last entries delimit the
            span (expected to be in chronological order).

    Returns:
        Elapsed years between the first and last entries, or ``0.0`` when
        the index is empty or contains a single timestamp.
    """
    if len(dt_index) == 0:
        # Nothing to measure: report a zero span instead of raising IndexError.
        return 0.0
    elapsed_days = (dt_index[-1] - dt_index[0]).days
    return elapsed_days / 365.0

def _price_df_to_returns(df_prices: pd.DataFrame) -> pd.DataFrame:
    """Convert a table of prices into simple period-over-period returns.

    If the frame is not already indexed by a ``DatetimeIndex``, its first
    column is parsed as dates and becomes the index. Rows are sorted by
    date, every remaining column is cast to float, and rows whose returns
    are missing for all columns (such as the first row) are dropped. The
    caller's frame is not modified.

    Args:
        df_prices: Price table, either date-indexed or carrying the dates
            in its first column.

    Returns:
        DataFrame of percentage changes indexed by date.
    """
    if isinstance(df_prices.index, pd.DatetimeIndex):
        frame = df_prices.sort_index().copy()
    else:
        # Work on a copy so the date parsing never touches the input frame.
        date_col = df_prices.columns[0]
        frame = df_prices.copy()
        frame[date_col] = pd.to_datetime(frame[date_col])
        frame = frame.set_index(date_col).sort_index()
    return frame.astype(float).pct_change().dropna(how="all")

def _drawdown_stats(returns: pd.Series):
    """Compute the equity curve and drawdown statistics of a return series.

    Args:
        returns: Simple per-period returns.

    Returns:
        A tuple ``(curve, dd, max_dd)`` where ``curve`` is the cumulative
        product of ``1 + returns``, ``dd`` is the per-period drawdown
        relative to the running peak (values <= 0), and ``max_dd`` is the
        worst drawdown as a non-negative magnitude.
    """
    curve = (1 + returns).cumprod()
    # The curve implicitly starts from an equity of 1.0 before the first
    # return, so the running peak can never be below 1.0. Without this floor
    # a loss in the opening period(s) would not register as a drawdown.
    running_peak = curve.cummax().clip(lower=1.0)
    dd = curve / running_peak - 1.0
    max_dd = float(dd.min())
    return curve, dd, abs(max_dd)

def _portfolio_metrics(returns: pd.Series) -> Dict[str, float]:
    """Summarize a return series as CAGR, mean return, max drawdown and MAR.

    Args:
        returns: Simple per-period returns indexed by a ``DatetimeIndex``.

    Returns:
        Dict with keys ``cagr``, ``mean_return``, ``max_drawdown`` and
        ``mar``. All values are NaN for an empty series. ``cagr`` is NaN
        when the ending equity is not positive or the index spans no time.
        ``mar`` is ``cagr / max_drawdown`` when the drawdown is positive,
        ``+inf`` for a positive CAGR with no drawdown at all, and ``-inf``
        when the CAGR is undefined or non-positive with no drawdown.
    """
    if returns.empty:
        return dict(cagr=np.nan, mean_return=np.nan, max_drawdown=np.nan, mar=np.nan)
    curve, _, max_dd = _drawdown_stats(returns)
    # `_as_months` returns the span in years despite its name.
    # NOTE(review): the span runs from the first to the last return
    # timestamp, so the period leading up to the first return is not
    # counted -- confirm this is the intended annualization basis.
    years = _as_months(returns.index)
    ending = float(curve.iloc[-1])
    cagr = ending ** (1.0 / years) - 1.0 if ending > 0 and years > 0 else np.nan
    mean_ret = float(returns.mean())
    if not np.isfinite(cagr):
        # Undefined growth rate: rank as the worst possible outcome.
        mar = -np.inf
    elif max_dd > 0:
        mar = cagr / max_dd
    elif max_dd == 0 and cagr > 0:
        # Growth with no drawdown is unboundedly good, not unboundedly bad.
        mar = np.inf
    else:
        mar = -np.inf
    return dict(cagr=cagr, mean_return=mean_ret, max_drawdown=max_dd, mar=mar)

def optimize_mar_from_prices(df_with_dates_and_prices: pd.DataFrame, wmax=1.0, n_restarts: int = 21, random_state: Optional[int] = 42) -> PortfolioResult:
    """Find long-only weights that maximize the portfolio MAR ratio.

    Prices are converted to simple returns, rows with any missing return
    are discarded, and SLSQP is run from several starting points (equal
    weights plus random Dirichlet draws) subject to the weights summing to
    one and lying in ``[0, wmax]``. The best converged run wins.

    Args:
        df_with_dates_and_prices: Price table, either indexed by a
            ``DatetimeIndex`` or carrying the dates in its first column.
        wmax: Upper bound on each weight; a scalar applied to every asset
            or an array-like with one entry per asset.
        n_restarts: Total number of starting points, including the
            equal-weight start.
        random_state: Seed for the random starting points.

    Returns:
        PortfolioResult holding the optimal weights, the summary metrics
        and the composite return / equity series.

    Raises:
        ValueError: If no price columns yield returns, if no complete
            return rows remain, or if ``wmax`` has the wrong shape.
        RuntimeError: If the optimizer fails from every starting point.
    """
    rets = _price_df_to_returns(df_with_dates_and_prices).dropna(how="any")
    n = rets.shape[1]
    if n == 0:
        raise ValueError("No valid price columns with returns found.")
    if rets.empty:
        # Without observations every metric is NaN and any weights would be
        # reported as "optimal"; fail loudly instead.
        raise ValueError("No complete return observations remain after dropping missing values.")

    if np.isscalar(wmax):
        upper = np.full(n, float(wmax))
    else:
        upper = np.asarray(wmax, dtype=float)
        # Explicit check rather than `assert`, which is stripped under -O.
        if upper.shape != (n,):
            raise ValueError(f"wmax must be a scalar or have shape ({n},); got shape {upper.shape}.")
    lower = np.zeros(n)
    bounds = [(0.0, float(u)) for u in upper]
    cons = ({"type": "eq", "fun": lambda w: np.sum(w) - 1.0},)
    rng = np.random.default_rng(random_state)

    ret_matrix = rets.values
    ret_index = rets.index

    def mar_objective_neg(w: np.ndarray) -> float:
        # Negated MAR so that minimizing maximizes the ratio.
        comp_s = pd.Series(ret_matrix @ w, index=ret_index)
        val = -_portfolio_metrics(comp_s)["mar"]
        # Non-finite ratios (undefined CAGR, zero drawdown) get a large
        # finite penalty so the solver never sees inf/NaN.
        return float(val) if np.isfinite(val) else 1e6

    inits = [np.full(n, 1.0 / n)]
    inits.extend(rng.dirichlet(np.ones(n)) for _ in range(max(0, n_restarts - 1)))

    best, best_val = None, np.inf
    for x0 in inits:
        # Pull the start inside the box, then rescale it onto the simplex.
        x0 = np.clip(x0, lower, upper)
        s = x0.sum()
        x0 = (x0 / s) if s > 0 else np.full(n, 1.0 / n)
        res = minimize(mar_objective_neg, x0, method="SLSQP", bounds=bounds, constraints=cons, options={"maxiter": 200, "ftol": 1e-9, "disp": False})
        if res.success and res.fun < best_val:
            best, best_val = res, res.fun
    if best is None:
        raise RuntimeError("Optimization failed to converge for all initializations.")

    w_star = best.x
    comp_returns = pd.Series(ret_matrix @ w_star, index=ret_index, name="Composite_Return")
    curve, _, _ = _drawdown_stats(comp_returns)
    metrics = _portfolio_metrics(comp_returns)
    return PortfolioResult(
        weights=w_star,
        cagr=metrics["cagr"],
        mean_return=metrics["mean_return"],
        max_drawdown=metrics["max_drawdown"],
        mar=metrics["mar"],
        composite_returns=comp_returns,
        composite_curve=curve.rename("Composite_Equity_Curve"),
    )
