{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# IESO Coincident Peak Prediction — Model Training & Selection\n",
    "\n",
    "This notebook trains and compares multiple approaches for predicting daily maximum\n",
    "Ontario demand:\n",
    "- **Approach A:** Daily max demand regression (Linear, Random Forest, XGBoost)\n",
    "- **Approach B:** Peak day classification (handles class imbalance)\n",
    "- **Approach C:** Threshold-based heuristic (baseline)\n",
    "\n",
    "The regression approach is expected to outperform classification because it avoids\n",
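    "the extreme class imbalance problem: only 5 coincident-peak hours out of 8,760 per\n",
    "base period (0.057%), and even at daily granularity at most 5 of ~365 days (≤1.4%)\n",
    "are peak days."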
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import warnings\n",
    "from pathlib import Path\n",
    "\n",
    "import joblib\n",
    "import lightgbm as lgb\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import shap\n",
    "import xgboost as xgb\n",
    "from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier\n",
    "from sklearn.linear_model import LinearRegression, LogisticRegression\n",
    "from sklearn.metrics import (mean_squared_error, mean_absolute_error, r2_score,\n",
    "                             precision_score, recall_score, f1_score,\n",
    "                             confusion_matrix, precision_recall_curve,\n",
    "                             classification_report)\n",
    "from sklearn.model_selection import GridSearchCV, TimeSeriesSplit\n",
    "from sklearn.preprocessing import StandardScaler\n",
    "\n",
    "# Silences ALL warnings (including library deprecation notices) to keep the\n",
    "# output readable; narrow or remove this filter when debugging.\n",
    "warnings.filterwarnings('ignore')\n",
    "sns.set_theme(style='whitegrid', font_scale=1.1)\n",
    "\n",
    "# Project root is read from the IESO_PROJECT_ROOT environment variable so the\n",
    "# notebook runs on other machines; the original absolute path is only a fallback.\n",
    "PROJECT_ROOT = Path(os.environ.get('IESO_PROJECT_ROOT', r'C:/wamp64/www/Spec_Driven_Dev_Website'))\n",
    "DATA_DIR = PROJECT_ROOT / 'notebooks' / 'source' / 'data'\n",
    "MODEL_DIR = PROJECT_ROOT / 'notebooks' / 'source' / 'models'\n",
    "MODEL_DIR.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "print('Libraries loaded successfully')\n",
    "print(f'Project root: {PROJECT_ROOT}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load feature-engineered dataset\n",
    "features = pd.read_parquet(DATA_DIR / 'ieso_features_daily.parquet')\n",
    "features['Date'] = pd.to_datetime(features['Date'])\n",
    "\n",
    "# Sort chronologically with a fresh 0..n-1 index: later cells depend on row order\n",
    "# (TimeSeriesSplit in the tuning step assigns folds by position), and nothing here\n",
    "# guarantees the parquet file is stored in date order.\n",
    "features = features.sort_values('Date', kind='stable').reset_index(drop=True)\n",
    "\n",
    "print(f'Feature matrix: {len(features)} days x {len(features.columns)} columns')\n",
    "print(f'Date range: {features[\"Date\"].min()} to {features[\"Date\"].max()}')\n",
    "print(f'Base periods: {sorted(features[\"base_period\"].unique())}')\n",
    "print(f'Peak days: {features[\"is_peak_day\"].sum()}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train / Validation / Test Split\n",
    "\n",
    "Split by base period boundary to prevent temporal leakage:\n",
    "- **Train:** Base periods 2010–2021 (12 years)\n",
    "- **Validation:** Base periods 2022–2023 (2 years)\n",
    "- **Test:** Base period 2024 (held out, never seen during development)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define feature columns for modeling\n",
    "FEATURE_COLS = [\n",
    "    # Weather features (primary)\n",
    "    'daily_max_temp', 'daily_mean_temp', 'daily_max_humidex', 'daily_cdh',\n",
    "    'daily_mean_rh', 'daily_mean_dewpoint', 'temp_3day_avg', 'cdh_3day_avg',\n",
    "    # Temporal features\n",
    "    'month', 'day_of_week', 'is_business_day', 'day_of_year',\n",
    "    # Demand momentum (available by morning)\n",
    "    'prev_day_max_demand', 'rolling_7d_max_demand', 'rolling_7d_mean_demand',\n",
    "    # Peak context\n",
    "    'current_5th_peak', 'max_demand_so_far',\n",
    "]\n",
    "\n",
    "TARGET_COL = 'daily_max_demand'\n",
    "PEAK_COL = 'is_peak_day'\n",
    "\n",
    "# Split by base period\n",
    "train_mask = features['base_period'].isin(range(2010, 2022))\n",
    "val_mask = features['base_period'].isin([2022, 2023])\n",
    "test_mask = features['base_period'] == 2024\n",
    "\n",
    "# Drop rows with any missing feature or target value\n",
    "complete_mask = features[FEATURE_COLS + [TARGET_COL]].notna().all(axis=1)\n",
    "print(f'Dropping {(~complete_mask).sum()} of {len(features)} days with missing features/target')\n",
    "\n",
    "train = features[train_mask & complete_mask].copy()\n",
    "val = features[val_mask & complete_mask].copy()\n",
    "test = features[test_mask & complete_mask].copy()\n",
    "\n",
    "# Sanity checks: every split is populated and the splits do not overlap in time,\n",
    "# which is what makes the base-period split free of temporal leakage.\n",
    "for split_name, split_df in [('train', train), ('val', val), ('test', test)]:\n",
    "    assert len(split_df) > 0, f'{split_name} split is empty'\n",
    "assert train['Date'].max() < val['Date'].min(), 'train and validation overlap in time'\n",
    "assert val['Date'].max() < test['Date'].min(), 'validation and test overlap in time'\n",
    "\n",
    "X_train, y_train = train[FEATURE_COLS], train[TARGET_COL]\n",
    "X_val, y_val = val[FEATURE_COLS], val[TARGET_COL]\n",
    "X_test, y_test = test[FEATURE_COLS], test[TARGET_COL]\n",
    "\n",
    "# Classification targets\n",
    "y_train_cls = train[PEAK_COL]\n",
    "y_val_cls = val[PEAK_COL]\n",
    "y_test_cls = test[PEAK_COL]\n",
    "\n",
    "print(f'Train: {len(train)} days (base periods 2010–2021), '\n",
    "      f'{y_train_cls.sum()} peak days ({y_train_cls.mean()*100:.2f}%)')\n",
    "print(f'Val:   {len(val)} days (base periods 2022–2023), '\n",
    "      f'{y_val_cls.sum()} peak days ({y_val_cls.mean()*100:.2f}%)')\n",
    "print(f'Test:  {len(test)} days (base period 2024), '\n",
    "      f'{y_test_cls.sum()} peak days ({y_test_cls.mean()*100:.2f}%)')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Approach A: Daily Max Demand Regression\n",
    "\n",
    "Predict daily maximum Ontario Demand (MW) as a continuous quantity.\n",
    "Then derive peak day alerts by comparing predictions to the displacement threshold\n",
    "(the `current_5th_peak` feature, i.e. the current 5th-highest peak of the base period)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model A1: Linear Regression (baseline)\n",
    "# Ordinary least squares on the unscaled feature matrix; fit() returns the estimator.\n",
    "lr = LinearRegression().fit(X_train, y_train)\n",
    "\n",
    "# Out-of-sample predictions for the validation and test base periods\n",
    "y_pred_lr_val, y_pred_lr_test = lr.predict(X_val), lr.predict(X_test)\n",
    "\n",
    "print('=== Linear Regression ===')\n",
    "print(f'Train R²:  {lr.score(X_train, y_train):.4f}')\n",
    "print(f'Val RMSE:  {np.sqrt(mean_squared_error(y_val, y_pred_lr_val)):.1f} MW')\n",
    "print(f'Val MAE:   {mean_absolute_error(y_val, y_pred_lr_val):.1f} MW')\n",
    "print(f'Val R²:    {r2_score(y_val, y_pred_lr_val):.4f}')\n",
    "print(f'Test RMSE: {np.sqrt(mean_squared_error(y_test, y_pred_lr_test)):.1f} MW')\n",
    "print(f'Test R²:   {r2_score(y_test, y_pred_lr_test):.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model A2: Random Forest Regressor\n",
    "# Hyperparameters collected in one dict so they are easy to inspect and reuse.\n",
    "rf_params = dict(n_estimators=200, max_depth=15, min_samples_leaf=5,\n",
    "                 random_state=42, n_jobs=-1)\n",
    "rf = RandomForestRegressor(**rf_params).fit(X_train, y_train)\n",
    "\n",
    "y_pred_rf_val, y_pred_rf_test = rf.predict(X_val), rf.predict(X_test)\n",
    "\n",
    "print('=== Random Forest ===')\n",
    "print(f'Train R²:  {rf.score(X_train, y_train):.4f}')\n",
    "print(f'Val RMSE:  {np.sqrt(mean_squared_error(y_val, y_pred_rf_val)):.1f} MW')\n",
    "print(f'Val MAE:   {mean_absolute_error(y_val, y_pred_rf_val):.1f} MW')\n",
    "print(f'Val R²:    {r2_score(y_val, y_pred_rf_val):.4f}')\n",
    "print(f'Test RMSE: {np.sqrt(mean_squared_error(y_test, y_pred_rf_test)):.1f} MW')\n",
    "print(f'Test R²:   {r2_score(y_test, y_pred_rf_test):.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model A3: XGBoost Regressor\n",
    "xgb_params = dict(n_estimators=300, max_depth=6, learning_rate=0.05,\n",
    "                  subsample=0.8, colsample_bytree=0.8, min_child_weight=5,\n",
    "                  random_state=42, verbosity=0)\n",
    "# eval_set is only monitored (no early stopping is configured), so the\n",
    "# validation split does not influence the fitted trees.\n",
    "xgb_model = xgb.XGBRegressor(**xgb_params).fit(\n",
    "    X_train, y_train, eval_set=[(X_val, y_val)], verbose=False\n",
    ")\n",
    "\n",
    "y_pred_xgb_val, y_pred_xgb_test = xgb_model.predict(X_val), xgb_model.predict(X_test)\n",
    "\n",
    "print('=== XGBoost ===')\n",
    "print(f'Train R²:  {xgb_model.score(X_train, y_train):.4f}')\n",
    "print(f'Val RMSE:  {np.sqrt(mean_squared_error(y_val, y_pred_xgb_val)):.1f} MW')\n",
    "print(f'Val MAE:   {mean_absolute_error(y_val, y_pred_xgb_val):.1f} MW')\n",
    "print(f'Val R²:    {r2_score(y_val, y_pred_xgb_val):.4f}')\n",
    "print(f'Test RMSE: {np.sqrt(mean_squared_error(y_test, y_pred_xgb_test)):.1f} MW')\n",
    "print(f'Test R²:   {r2_score(y_test, y_pred_xgb_test):.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model A4: LightGBM Regressor\n",
    "# NOTE: LightGBM only performs row subsampling (bagging) when subsample_freq > 0;\n",
    "# with the default subsample_freq=0 the subsample=0.8 setting is silently ignored.\n",
    "lgb_model = lgb.LGBMRegressor(\n",
    "    n_estimators=300,\n",
    "    max_depth=6,\n",
    "    learning_rate=0.05,\n",
    "    subsample=0.8,\n",
    "    subsample_freq=1,  # re-draw the 80% row sample at every boosting iteration\n",
    "    colsample_bytree=0.8,\n",
    "    min_child_samples=10,\n",
    "    random_state=42,\n",
    "    verbose=-1\n",
    ")\n",
    "# eval_set is only monitored here (no early-stopping callback is passed).\n",
    "lgb_model.fit(\n",
    "    X_train, y_train,\n",
    "    eval_set=[(X_val, y_val)],\n",
    ")\n",
    "\n",
    "y_pred_lgb_val = lgb_model.predict(X_val)\n",
    "y_pred_lgb_test = lgb_model.predict(X_test)\n",
    "\n",
    "print('=== LightGBM ===')\n",
    "print(f'Train R²:  {lgb_model.score(X_train, y_train):.4f}')\n",
    "print(f'Val RMSE:  {np.sqrt(mean_squared_error(y_val, y_pred_lgb_val)):.1f} MW')\n",
    "print(f'Val MAE:   {mean_absolute_error(y_val, y_pred_lgb_val):.1f} MW')\n",
    "print(f'Val R²:    {r2_score(y_val, y_pred_lgb_val):.4f}')\n",
    "print(f'Test RMSE: {np.sqrt(mean_squared_error(y_test, y_pred_lgb_test)):.1f} MW')\n",
    "print(f'Test R²:   {r2_score(y_test, y_pred_lgb_test):.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Approach B: Peak Day Classification\n",
    "\n",
    "Binary classification with class imbalance handling. Included for comparison\n",
    "to demonstrate why regression is preferred."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model B1: Logistic Regression with class weights\n",
    "# The scaler is fit on the training split only, then applied to val/test, so the\n",
    "# linear model sees features on a comparable scale without leaking later data.\n",
    "scaler = StandardScaler()\n",
    "X_train_scaled = scaler.fit_transform(X_train)\n",
    "X_val_scaled, X_test_scaled = scaler.transform(X_val), scaler.transform(X_test)\n",
    "\n",
    "log_clf = LogisticRegression(class_weight='balanced', max_iter=1000, random_state=42)\n",
    "log_clf.fit(X_train_scaled, y_train_cls)\n",
    "\n",
    "y_pred_logcls_val = log_clf.predict(X_val_scaled)\n",
    "y_pred_logcls_test = log_clf.predict(X_test_scaled)\n",
    "\n",
    "print('=== Logistic Regression (Classification) ===')\n",
    "for split_label, y_true, y_hat in [('Val ', y_val_cls, y_pred_logcls_val),\n",
    "                                   ('Test', y_test_cls, y_pred_logcls_test)]:\n",
    "    print(f'{split_label} — Precision: {precision_score(y_true, y_hat, zero_division=0):.3f}, '\n",
    "          f'Recall: {recall_score(y_true, y_hat, zero_division=0):.3f}, '\n",
    "          f'F1: {f1_score(y_true, y_hat, zero_division=0):.3f}')\n",
    "print(f'Val false alarms: {((y_pred_logcls_val == 1) & (y_val_cls == 0)).sum()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model B2: XGBoost Classifier\n",
    "# Up-weight the rare positive class by the negative:positive ratio; max(..., 1)\n",
    "# avoids division by zero if the training split contains no peak days.\n",
    "n_negative = len(y_train_cls[y_train_cls == 0])\n",
    "n_positive = len(y_train_cls[y_train_cls == 1])\n",
    "xgb_clf = xgb.XGBClassifier(\n",
    "    n_estimators=200,\n",
    "    max_depth=4,\n",
    "    learning_rate=0.05,\n",
    "    scale_pos_weight=n_negative / max(n_positive, 1),\n",
    "    random_state=42,\n",
    "    verbosity=0\n",
    ")\n",
    "xgb_clf.fit(X_train, y_train_cls, eval_set=[(X_val, y_val_cls)], verbose=False)\n",
    "\n",
    "y_pred_xgbcls_val, y_pred_xgbcls_test = xgb_clf.predict(X_val), xgb_clf.predict(X_test)\n",
    "\n",
    "print('=== XGBoost Classifier ===')\n",
    "for split_label, y_true, y_hat in [('Val ', y_val_cls, y_pred_xgbcls_val),\n",
    "                                   ('Test', y_test_cls, y_pred_xgbcls_test)]:\n",
    "    print(f'{split_label} — Precision: {precision_score(y_true, y_hat, zero_division=0):.3f}, '\n",
    "          f'Recall: {recall_score(y_true, y_hat, zero_division=0):.3f}, '\n",
    "          f'F1: {f1_score(y_true, y_hat, zero_division=0):.3f}')\n",
    "print(f'Val false alarms: {((y_pred_xgbcls_val == 1) & (y_val_cls == 0)).sum()}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Approach C: Threshold-Based Heuristic (Baseline)\n",
    "\n",
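    "Simple rule: alert if daily max temp > 30°C AND the day is a business day AND month in {6,7,8}\n",
    "(evaluated here on the `daily_max_temp` feature; in operation this would be the forecast).\n",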
    "This represents what an experienced energy manager might do without a model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def heuristic_alert(df, temp_threshold_c=30, months=(6, 7, 8)):\n",
    "    \"\"\"Simple temperature heuristic for peak day prediction.\n",
    "\n",
    "    Flags a day with 1 when its daily max temperature is strictly above\n",
    "    ``temp_threshold_c``, it is a business day, and its month is in ``months``;\n",
    "    every other day gets 0.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : DataFrame with 'daily_max_temp', 'is_business_day' and 'month' columns.\n",
    "    temp_threshold_c : temperature cut-off (default 30, i.e. the >30°C rule).\n",
    "    months : months in which an alert may fire (default June–August).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    Series of 0/1 ints aligned to ``df.index``.\n",
    "    \"\"\"\n",
    "    is_hot = df['daily_max_temp'] > temp_threshold_c\n",
    "    is_workday = df['is_business_day'] == 1\n",
    "    in_season = df['month'].isin(list(months))\n",
    "    return (is_hot & is_workday & in_season).astype(int)\n",
    "\n",
    "y_pred_heuristic_val, y_pred_heuristic_test = heuristic_alert(val), heuristic_alert(test)\n",
    "\n",
    "print('=== Temperature Heuristic (>30°C, weekday, Jun-Aug) ===')\n",
    "for split_label, y_true, y_hat in [('Val ', y_val_cls, y_pred_heuristic_val),\n",
    "                                   ('Test', y_test_cls, y_pred_heuristic_test)]:\n",
    "    print(f'{split_label} — Precision: {precision_score(y_true, y_hat, zero_division=0):.3f}, '\n",
    "          f'Recall: {recall_score(y_true, y_hat, zero_division=0):.3f}, '\n",
    "          f'F1: {f1_score(y_true, y_hat, zero_division=0):.3f}')\n",
    "# Number of days on which the rule would have raised an alert\n",
    "print(f'Val alert days: {y_pred_heuristic_val.sum()}')\n",
    "print(f'Test alert days: {y_pred_heuristic_test.sum()}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Regression → Alert Conversion & Model Comparison\n",
    "\n",
    "Convert regression predictions into RED/YELLOW/GREEN alerts using the\n",
    "displacement threshold, then evaluate peak detection performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def regression_to_alerts(y_pred, threshold_series, buffer_mw=500, default_threshold_mw=20000):\n",
    "    \"\"\"Convert regression predictions to RED/YELLOW/GREEN alerts.\n",
    "    \n",
    "    RED:    predicted > threshold + buffer\n",
    "    YELLOW: |predicted - threshold| <= buffer\n",
    "    GREEN:  predicted < threshold - buffer\n",
    "    \n",
    "    ``y_pred`` and ``threshold_series`` are matched by position (arrays or\n",
    "    Series with any index are accepted) and must have the same length.\n",
    "    Non-positive or NaN thresholds fall back to ``default_threshold_mw``;\n",
    "    NaN predictions map to GREEN.\n",
    "    \n",
    "    Returns a Series of alert labels indexed 0..n-1.\n",
    "    Raises ValueError if the two inputs differ in length.\n",
    "    \"\"\"\n",
    "    predicted = np.asarray(y_pred, dtype=float)\n",
    "    thresholds = np.asarray(threshold_series, dtype=float)\n",
    "    if len(predicted) != len(thresholds):\n",
    "        raise ValueError(f'y_pred has {len(predicted)} rows but threshold_series has {len(thresholds)}')\n",
    "    # NOTE(review): non-positive thresholds presumably mean the running 5th peak is\n",
    "    # not yet known; default_threshold_mw is a fixed stand-in for that case — confirm.\n",
    "    thresholds = np.where(thresholds > 0, thresholds, default_threshold_mw)\n",
    "    diff = predicted - thresholds\n",
    "    labels = np.select([diff > buffer_mw, np.abs(diff) <= buffer_mw],\n",
    "                       ['RED', 'YELLOW'], default='GREEN')\n",
    "    return pd.Series(labels)\n",
    "\n",
    "# Generate alerts from the XGBoost regression model, using the current 5th peak\n",
    "# as the threshold (reset_index so thresholds line up with predictions by position).\n",
    "val_alerts = regression_to_alerts(y_pred_xgb_val, val['current_5th_peak'].reset_index(drop=True))\n",
    "test_alerts = regression_to_alerts(y_pred_xgb_test, test['current_5th_peak'].reset_index(drop=True))\n",
    "\n",
    "# Binary 'alert raised' flag: RED *or* YELLOW counts as a predicted peak day\n",
    "# (despite the *_red variable names).\n",
    "val_red = val_alerts.isin(['RED', 'YELLOW']).astype(int)\n",
    "test_red = test_alerts.isin(['RED', 'YELLOW']).astype(int)\n",
    "\n",
    "print('=== XGBoost Regression → Alert System ===')\n",
    "for split_label, y_true, flagged in [('Val ', y_val_cls, val_red),\n",
    "                                     ('Test', y_test_cls, test_red)]:\n",
    "    print(f'{split_label} — Precision: {precision_score(y_true.values, flagged.values, zero_division=0):.3f}, '\n",
    "          f'Recall: {recall_score(y_true.values, flagged.values, zero_division=0):.3f}')\n",
    "print(f'\\nAlert distribution (test set):')\n",
    "print(test_alerts.value_counts().to_string())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Model Comparison Table\n",
    "def eval_regression(name, y_true, y_pred, y_true_cls, threshold_series, buffer_mw=500):\n",
    "    \"\"\"Score a regression model on demand accuracy and peak-day detection.\n",
    "\n",
    "    RMSE/MAE/R² compare ``y_pred`` with ``y_true`` (MW). For the peak metrics, a\n",
    "    RED or YELLOW alert from ``regression_to_alerts`` counts as a predicted peak\n",
    "    day and is compared with ``y_true_cls``. Inputs are compared by position;\n",
    "    pandas Series and numpy arrays are both accepted.\n",
    "    \"\"\"\n",
    "    alerts = regression_to_alerts(y_pred, threshold_series, buffer_mw=buffer_mw)\n",
    "    alert_binary = alerts.isin(['RED', 'YELLOW']).astype(int).to_numpy()\n",
    "    peaks_true = np.asarray(y_true_cls)\n",
    "    return {\n",
    "        'Model': name,\n",
    "        'RMSE (MW)': np.sqrt(mean_squared_error(y_true, y_pred)),\n",
    "        'MAE (MW)': mean_absolute_error(y_true, y_pred),\n",
    "        'R²': r2_score(y_true, y_pred),\n",
    "        'Peak Precision': precision_score(peaks_true, alert_binary, zero_division=0),\n",
    "        'Peak Recall': recall_score(peaks_true, alert_binary, zero_division=0),\n",
    "        'Peak F1': f1_score(peaks_true, alert_binary, zero_division=0),\n",
    "        'Alert Days': alert_binary.sum(),\n",
    "    }\n",
    "\n",
    "def eval_classifier(name, y_true_cls, y_pred_cls):\n",
    "    \"\"\"Score a direct peak-day classifier; regression metrics are reported as NaN.\"\"\"\n",
    "    scores = {'Model': name}\n",
    "    scores.update({metric: np.nan for metric in ('RMSE (MW)', 'MAE (MW)', 'R²')})\n",
    "    scores['Peak Precision'] = precision_score(y_true_cls, y_pred_cls, zero_division=0)\n",
    "    scores['Peak Recall'] = recall_score(y_true_cls, y_pred_cls, zero_division=0)\n",
    "    scores['Peak F1'] = f1_score(y_true_cls, y_pred_cls, zero_division=0)\n",
    "    scores['Alert Days'] = y_pred_cls.sum()\n",
    "    return scores\n",
    "\n",
    "threshold_val = val['current_5th_peak'].reset_index(drop=True)\n",
    "threshold_test = test['current_5th_peak'].reset_index(drop=True)\n",
    "\n",
    "# Validation comparison: use this table for model selection so the 2024 test\n",
    "# base period stays a genuinely held-out final check.\n",
    "comparison_val = pd.DataFrame([\n",
    "    eval_regression('Linear Regression', y_val, y_pred_lr_val, y_val_cls, threshold_val),\n",
    "    eval_regression('Random Forest', y_val, y_pred_rf_val, y_val_cls, threshold_val),\n",
    "    eval_regression('XGBoost', y_val, y_pred_xgb_val, y_val_cls, threshold_val),\n",
    "    eval_regression('LightGBM', y_val, y_pred_lgb_val, y_val_cls, threshold_val),\n",
    "    eval_classifier('Logistic (balanced)', y_val_cls, y_pred_logcls_val),\n",
    "    eval_classifier('XGBoost Classifier', y_val_cls, y_pred_xgbcls_val),\n",
    "    eval_classifier('Heuristic (>30°C)', y_val_cls, y_pred_heuristic_val),\n",
    "])\n",
    "\n",
    "# Test comparison (this is the table written to model_comparison.csv later on).\n",
    "comparison = pd.DataFrame([\n",
    "    eval_regression('Linear Regression', y_test, y_pred_lr_test, y_test_cls, threshold_test),\n",
    "    eval_regression('Random Forest', y_test, y_pred_rf_test, y_test_cls, threshold_test),\n",
    "    eval_regression('XGBoost', y_test, y_pred_xgb_test, y_test_cls, threshold_test),\n",
    "    eval_regression('LightGBM', y_test, y_pred_lgb_test, y_test_cls, threshold_test),\n",
    "    eval_classifier('Logistic (balanced)', y_test_cls, y_pred_logcls_test),\n",
    "    eval_classifier('XGBoost Classifier', y_test_cls, y_pred_xgbcls_test),\n",
    "    eval_classifier('Heuristic (>30°C)', y_test_cls, y_pred_heuristic_test),\n",
    "])\n",
    "\n",
    "print('=== Model Comparison (Validation Set: 2022–2023 Base Periods) ===')\n",
    "print(comparison_val.round(3).to_string(index=False))\n",
    "print()\n",
    "print('=== Model Comparison (Test Set: 2024 Base Period) ===')\n",
    "print(comparison.round(3).to_string(index=False))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## SHAP Feature Importance Analysis\n",
    "\n",
    "SHAP (SHapley Additive exPlanations) values quantify each feature's contribution\n",
    "to individual predictions. This reveals what drives the model's decisions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SHAP analysis on the best model (XGBoost)\n",
    "# SHAP values give each feature's contribution to each validation-row prediction\n",
    "# of the XGBoost regressor (same units as the target, MW).\n",
    "explainer = shap.TreeExplainer(xgb_model)\n",
    "shap_values = explainer.shap_values(X_val)\n",
    "\n",
    "# Summary plot (beeswarm)\n",
    "# NOTE(review): shap.summary_plot draws on the current pyplot figure and may\n",
    "# resize it itself, so this figsize may not be honored — confirm.\n",
    "fig, ax = plt.subplots(figsize=(10, 8))\n",
    "shap.summary_plot(shap_values, X_val, feature_names=FEATURE_COLS, \n",
    "                  show=False, max_display=15)\n",
    "plt.title('SHAP Feature Importance — XGBoost Demand Regression')\n",
    "plt.tight_layout()\n",
    "# The figure is written next to the data files in DATA_DIR.\n",
    "plt.savefig(DATA_DIR / 'shap_summary.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# SHAP dependence plots for top 3 features\n",
    "# Rank features by mean |SHAP| over the validation rows and keep the three largest.\n",
    "mean_abs_shap = np.abs(shap_values).mean(axis=0)\n",
    "top_features = (\n",
    "    pd.Series(mean_abs_shap, index=FEATURE_COLS)\n",
    "    .sort_values(ascending=False)\n",
    "    .head(3)\n",
    "    .index.values\n",
    ")\n",
    "\n",
    "# One panel per feature: raw feature value vs. its SHAP contribution.\n",
    "fig, axes = plt.subplots(1, 3, figsize=(18, 5))\n",
    "for ax, feat in zip(axes, top_features):\n",
    "    col_idx = FEATURE_COLS.index(feat)\n",
    "    ax.scatter(X_val[feat].values, shap_values[:, col_idx], alpha=0.3, s=5, color='#1565C0')\n",
    "    ax.set(xlabel=feat, ylabel='SHAP Value (MW)', title=f'SHAP Dependence: {feat}')\n",
    "    ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)\n",
    "\n",
    "plt.suptitle('SHAP Dependence Plots — Top 3 Features', fontsize=14, y=1.02)\n",
    "plt.tight_layout()\n",
    "plt.savefig(DATA_DIR / 'shap_dependence.png', dpi=150, bbox_inches='tight')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Hyperparameter Tuning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Light hyperparameter search on XGBoost\n",
    "param_grid = {\n",
    "    'max_depth': [4, 6, 8],\n",
    "    'learning_rate': [0.03, 0.05, 0.1],\n",
    "    'n_estimators': [200, 300, 500],\n",
    "    'min_child_weight': [3, 5, 10],\n",
    "}\n",
    "\n",
    "# Use time-series cross-validation on training data. TimeSeriesSplit assigns folds\n",
    "# by row position, so sort by Date explicitly to guarantee every fold trains on the\n",
    "# past and validates on the future.\n",
    "train_chrono = train.sort_values('Date', kind='stable')\n",
    "X_train_chrono = train_chrono[FEATURE_COLS]\n",
    "y_train_chrono = train_chrono[TARGET_COL]\n",
    "\n",
    "tscv = TimeSeriesSplit(n_splits=3)\n",
    "grid_search = GridSearchCV(\n",
    "    xgb.XGBRegressor(subsample=0.8, colsample_bytree=0.8,\n",
    "                     random_state=42, verbosity=0),\n",
    "    param_grid,\n",
    "    cv=tscv,\n",
    "    scoring='neg_root_mean_squared_error',\n",
    "    refit=True,  # refit the best parameter set on all of train_chrono\n",
    "    n_jobs=-1,\n",
    "    verbose=0\n",
    ")\n",
    "grid_search.fit(X_train_chrono, y_train_chrono)\n",
    "\n",
    "print(f'Best parameters: {grid_search.best_params_}')\n",
    "print(f'Best CV RMSE: {-grid_search.best_score_:.1f} MW')\n",
    "\n",
    "# With refit=True, best_estimator_ has already been retrained on the full training\n",
    "# split with the best parameters, so no separate retraining step is needed.\n",
    "best_model = grid_search.best_estimator_\n",
    "y_pred_best_test = best_model.predict(X_test)\n",
    "print(f'\\nTuned model test RMSE: {np.sqrt(mean_squared_error(y_test, y_pred_best_test)):.1f} MW')\n",
    "print(f'Tuned model test R²: {r2_score(y_test, y_pred_best_test):.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save Trained Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the best model and metadata\n",
    "model_path = MODEL_DIR / 'ieso_peak_model.joblib'\n",
    "comparison_path = DATA_DIR / 'model_comparison.csv'\n",
    "\n",
    "model_artifact = dict(\n",
    "    model=best_model,\n",
    "    feature_columns=FEATURE_COLS,\n",
    "    target_column=TARGET_COL,\n",
    "    # The scaler belongs to the logistic-regression baseline; best_model is\n",
    "    # trained on unscaled features and does not need it.\n",
    "    scaler=scaler,\n",
    "    training_base_periods=list(range(2010, 2022)),\n",
    "    validation_base_periods=[2022, 2023],\n",
    "    test_base_period=2024,\n",
    "    best_params=grid_search.best_params_,\n",
    "    test_rmse=np.sqrt(mean_squared_error(y_test, y_pred_best_test)),\n",
    "    test_r2=r2_score(y_test, y_pred_best_test),\n",
    ")\n",
    "\n",
    "joblib.dump(model_artifact, model_path)\n",
    "print(f'Model saved to: {model_path}')\n",
    "print(f'Best test RMSE: {model_artifact[\"test_rmse\"]:.1f} MW')\n",
    "print(f'Best test R²: {model_artifact[\"test_r2\"]:.4f}')\n",
    "\n",
    "# Also save the comparison table\n",
    "comparison.to_csv(comparison_path, index=False)\n",
    "print(f'\\nModel comparison saved to: {comparison_path}')\n",
    "\n",
    "print('\\n=== Notebook 3 complete ===')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.12.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
