{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Copyright (c) Qluon Inc. All rights reserved.\n",
        "\n",
        "Provided for Learn-By-Wire Guard evaluation and customer testing under the applicable Qluon license terms.\n",
        "\n",
        "# LBW Guard Ablation Colab\n",
        "\n",
        "This notebook is a black-box ablation test for `lbw_guard` in a lighter Colab form:\n",
        "\n",
        "1. Build one or more ablation scenarios.\n",
        "2. Run the same model, data slice, and training loop for `adamw` and `lbw_guard`.\n",
        "3. Write common metrics and LBW-vs-AdamW gain tables.\n",
        "\n",
        "It does not import local source folders. The only LBW code used is the installed `LBW-Guard` package that provides `lbw.Guard`.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# @title 1. Install public dependencies and LBW Guard\n",
        "import subprocess\n",
        "import sys\n",
        "\n",
        "public_deps = [\n",
        "    \"transformers>=4.45\",\n",
        "    \"datasets>=2.20\",\n",
        "    \"peft>=0.12\",\n",
        "    \"accelerate>=0.33\",\n",
        "    \"sentencepiece\",\n",
        "    \"pandas\",\n",
        "]\n",
        "subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"--upgrade\", *public_deps])\n",
        "\n",
        "# Colab can include an old torchao build. Newer PEFT versions reject it,\n",
        "# and this notebook does not need torchao for LoRA, so remove it if present.\n",
        "subprocess.call([sys.executable, \"-m\", \"pip\", \"uninstall\", \"-y\", \"-q\", \"torchao\"])\n",
        "\n",
        "subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"LBW-Guard\"])\n",
        "print(\"Dependency install complete. If this cell changed packages, restart runtime and run all cells once.\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# @title 2. Configure ablation plan\n",
        "import importlib\n",
        "from copy import deepcopy\n",
        "\n",
        "import torch\n",
        "\n",
        "lbw = importlib.import_module(\"lbw\")\n",
        "print(\"lbw module:\", lbw.__file__)\n",
        "print(\"lbw.Guard:\", lbw.Guard)\n",
        "\n",
        "MODEL_NAME = \"Qwen/Qwen2.5-0.5B\"\n",
        "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "OPTIMIZERS = [\"adamw\", \"lbw_guard\"]\n",
        "\n",
        "# Keep the default close to the local ablation test objective, but small enough for Colab.\n",
        "# Add \"lr\", \"schedule\", \"steps\", \"data\", or \"lora\" for a wider matrix.\n",
        "ABLATIONS = [\"optimizer\"]\n",
        "\n",
        "BASE_CONFIG = {\n",
        "    \"seed\": 42,\n",
        "    \"max_steps\": 200,\n",
        "    \"eval_every\": 50,\n",
        "    \"eval_batches\": 8,\n",
        "    \"seq_len\": 64,\n",
        "    \"batch_size\": 1,\n",
        "    \"max_chars\": 20000,\n",
        "    \"eval_chars\": 8000,\n",
        "    \"full_wikitext_train\": False,\n",
        "    \"full_wikitext_eval\": False,\n",
        "    \"full_validation_ppl\": False,\n",
        "    \"lr\": 5e-4,\n",
        "    \"betas\": (0.9, 0.999),\n",
        "    \"weight_decay\": 0.01,\n",
        "    \"warmup_steps\": 10,\n",
        "    \"schedule_mode\": \"constant\",  # constant or cosine\n",
        "    \"lora_r\": 8,\n",
        "    \"lora_alpha\": 16,\n",
        "    \"lora_dropout\": 0.05,\n",
        "    \"lbw_stats_freq\": 10,\n",
        "    \"lbw_stress_th\": 1.1,\n",
        "    \"lbw_spike_th\": 1.5,\n",
        "    \"lbw_rec_fast\": 0.01,\n",
        "    \"lbw_ema_decay\": 0.95,\n",
        "}\n",
        "\n",
        "LR_SWEEP = [1e-3, 5e-4]\n",
        "SCHEDULE_SWEEP = [\"constant\", \"cosine\"]\n",
        "STEP_SWEEP = [100, 200]\n",
        "DATA_SWEEP = [\n",
        "    {\"max_chars\": 20000, \"eval_chars\": 8000, \"label\": \"small-data\"},\n",
        "    {\"max_chars\": 80000, \"eval_chars\": 20000, \"label\": \"larger-data\"},\n",
        "]\n",
        "LORA_R_SWEEP = [4, 8, 16]\n",
        "\n",
        "print(\"Device:\", DEVICE)\n",
        "if DEVICE == \"cuda\":\n",
        "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
        "print(\"Selected ablations:\", ABLATIONS)\n",
        "print(\"Default optimizer steps:\", BASE_CONFIG[\"max_steps\"])\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# @title 3. Define ablation scenarios\n",
        "import pandas as pd\n",
        "\n",
        "\n",
        "def scenario(slug, label, note, overrides=None):\n",
        "    cfg = deepcopy(BASE_CONFIG)\n",
        "    if overrides:\n",
        "        cfg.update(overrides)\n",
        "    return {\n",
        "        \"slug\": slug,\n",
        "        \"label\": label,\n",
        "        \"note\": note,\n",
        "        \"config\": cfg,\n",
        "        \"optimizers\": list(OPTIMIZERS),\n",
        "    }\n",
        "\n",
        "\n",
        "def build_scenarios():\n",
        "    selected = {str(item).strip().lower() for item in ABLATIONS}\n",
        "    scenarios = []\n",
        "\n",
        "    if \"optimizer\" in selected:\n",
        "        scenarios.append(scenario(\n",
        "            \"optimizer-adamw-vs-lbw-guard\",\n",
        "            \"Optimizer: AdamW vs lbw_guard\",\n",
        "            \"Direct optimizer comparison with the base config.\",\n",
        "        ))\n",
        "\n",
        "    if \"lr\" in selected:\n",
        "        for lr in LR_SWEEP:\n",
        "            scenarios.append(scenario(\n",
        "                f\"lr-{lr:g}\",\n",
        "                f\"Learning Rate: {lr:g}\",\n",
        "                \"Learning-rate sensitivity check.\",\n",
        "                {\"lr\": float(lr)},\n",
        "            ))\n",
        "\n",
        "    if \"schedule\" in selected:\n",
        "        for mode in SCHEDULE_SWEEP:\n",
        "            scenarios.append(scenario(\n",
        "                f\"schedule-{mode}\",\n",
        "                f\"Schedule: {mode}\",\n",
        "                \"Scheduler-shape sensitivity check.\",\n",
        "                {\"schedule_mode\": mode},\n",
        "            ))\n",
        "\n",
        "    if \"steps\" in selected:\n",
        "        for steps in STEP_SWEEP:\n",
        "            scenarios.append(scenario(\n",
        "                f\"steps-{steps}\",\n",
        "                f\"Steps: {steps}\",\n",
        "                \"Training-length sensitivity check.\",\n",
        "                {\"max_steps\": int(steps), \"eval_every\": max(1, int(steps) // 4)},\n",
        "            ))\n",
        "\n",
        "    if \"data\" in selected:\n",
        "        for item in DATA_SWEEP:\n",
        "            label = item.get(\"label\", f\"data-{item['max_chars']}\")\n",
        "            scenarios.append(scenario(\n",
        "                label,\n",
        "                f\"Data Slice: {label}\",\n",
        "                \"WikiText slice-size sensitivity check.\",\n",
        "                {\"max_chars\": int(item[\"max_chars\"]), \"eval_chars\": int(item[\"eval_chars\"])},\n",
        "            ))\n",
        "\n",
        "    if \"lora\" in selected:\n",
        "        for rank in LORA_R_SWEEP:\n",
        "            scenarios.append(scenario(\n",
        "                f\"lora-r{rank}\",\n",
        "                f\"LoRA Rank: {rank}\",\n",
        "                \"Adapter-capacity sensitivity check.\",\n",
        "                {\"lora_r\": int(rank), \"lora_alpha\": int(rank) * 2},\n",
        "            ))\n",
        "\n",
        "    if not scenarios:\n",
        "        raise ValueError(\"No scenarios selected. Set ABLATIONS to include optimizer, lr, schedule, steps, data, or lora.\")\n",
        "    return scenarios\n",
        "\n",
        "\n",
        "SCENARIOS = build_scenarios()\n",
        "plan_rows = []\n",
        "for item in SCENARIOS:\n",
        "    cfg = item[\"config\"]\n",
        "    plan_rows.append({\n",
        "        \"scenario\": item[\"label\"],\n",
        "        \"optimizers\": \",\".join(item[\"optimizers\"]),\n",
        "        \"steps\": cfg[\"max_steps\"],\n",
        "        \"lr\": cfg[\"lr\"],\n",
        "        \"schedule\": cfg[\"schedule_mode\"],\n",
        "        \"train_chars\": \"FULL\" if cfg[\"full_wikitext_train\"] else cfg[\"max_chars\"],\n",
        "        \"eval_chars\": \"FULL\" if cfg[\"full_wikitext_eval\"] else cfg[\"eval_chars\"],\n",
        "        \"lora_r\": cfg[\"lora_r\"],\n",
        "        \"note\": item[\"note\"],\n",
        "    })\n",
        "\n",
        "plan_df = pd.DataFrame(plan_rows)\n",
        "display(plan_df)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# @title 4. Define data, model, optimizer, and metric helpers\n",
        "import gc\n",
        "import math\n",
        "import random\n",
        "import time\n",
        "\n",
        "from datasets import load_dataset\n",
        "from peft import LoraConfig, TaskType, get_peft_model\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
        "\n",
        "TOKENIZER = None\n",
        "DATA_CACHE = {}\n",
        "\n",
        "\n",
        "def set_seed(seed):\n",
        "    random.seed(int(seed))\n",
        "    torch.manual_seed(int(seed))\n",
        "    if torch.cuda.is_available():\n",
        "        torch.cuda.manual_seed_all(int(seed))\n",
        "\n",
        "\n",
        "def get_tokenizer():\n",
        "    global TOKENIZER\n",
        "    if TOKENIZER is None:\n",
        "        TOKENIZER = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)\n",
        "        if TOKENIZER.pad_token is None:\n",
        "            TOKENIZER.pad_token = TOKENIZER.eos_token\n",
        "    return TOKENIZER\n",
        "\n",
        "\n",
        "def build_wikitext_chunks(tokenizer, split, seq_len, max_chars):\n",
        "    cap = None if max_chars is None else int(max_chars)\n",
        "    print(f\"Preparing WikiText split={split!r}\" + (f\" with char cap {cap:,}\" if cap is not None else \" with full split\"))\n",
        "    ds = load_dataset(\"wikitext\", \"wikitext-103-raw-v1\", split=split)\n",
        "    pieces = []\n",
        "    chars_used = 0\n",
        "    rows_used = 0\n",
        "    first_piece = True\n",
        "    for row in ds:\n",
        "        text = str(row.get(\"text\", \"\") or \"\")\n",
        "        if not text.strip():\n",
        "            continue\n",
        "        piece = text if first_piece else \" \" + text\n",
        "        if cap is not None:\n",
        "            remain = cap - chars_used\n",
        "            if remain <= 0:\n",
        "                break\n",
        "            if len(piece) > remain:\n",
        "                piece = piece[:remain]\n",
        "        pieces.append(piece)\n",
        "        chars_used += len(piece)\n",
        "        rows_used += 1\n",
        "        first_piece = False\n",
        "        if cap is not None and chars_used >= cap:\n",
        "            break\n",
        "    text = \"\".join(pieces)\n",
        "    token_ids = tokenizer(text, add_special_tokens=False)[\"input_ids\"]\n",
        "    ids = torch.tensor(token_ids, dtype=torch.long)\n",
        "    n = ids.numel() // int(seq_len)\n",
        "    if n <= 0:\n",
        "        raise RuntimeError(\"Not enough tokens. Increase max_chars or reduce seq_len.\")\n",
        "    ids = ids[: n * int(seq_len)].view(n, int(seq_len)).contiguous()\n",
        "    print(f\"Prepared split={split!r}: {chars_used:,} chars across {rows_used:,} rows -> {ids.size(0):,} sequences\")\n",
        "    return {\"input_ids\": ids, \"chars\": chars_used, \"rows\": rows_used, \"cap\": cap, \"seq_len\": int(seq_len)}\n",
        "\n",
        "\n",
        "def get_chunks(cfg):\n",
        "    tokenizer = get_tokenizer()\n",
        "    train_cap = None if cfg[\"full_wikitext_train\"] else int(cfg[\"max_chars\"])\n",
        "    eval_cap = None if cfg[\"full_wikitext_eval\"] else int(cfg[\"eval_chars\"])\n",
        "    key = (int(cfg[\"seq_len\"]), train_cap, eval_cap)\n",
        "    if key not in DATA_CACHE:\n",
        "        DATA_CACHE[key] = {\n",
        "            \"train\": build_wikitext_chunks(tokenizer, \"train\", cfg[\"seq_len\"], train_cap),\n",
        "            \"eval\": build_wikitext_chunks(tokenizer, \"validation\", cfg[\"seq_len\"], eval_cap),\n",
        "        }\n",
        "    return DATA_CACHE[key][\"train\"], DATA_CACHE[key][\"eval\"]\n",
        "\n",
        "\n",
        "def batch_iter(chunks, batch_size):\n",
        "    ids = chunks[\"input_ids\"]\n",
        "    i = 0\n",
        "    batch_size = int(batch_size)\n",
        "    while True:\n",
        "        if i + batch_size > ids.size(0):\n",
        "            i = 0\n",
        "        batch = ids[i : i + batch_size].to(DEVICE, non_blocking=True)\n",
        "        i += batch_size\n",
        "        yield batch\n",
        "\n",
        "\n",
        "def load_lora_model(cfg):\n",
        "    dtype = torch.float16 if DEVICE == \"cuda\" else torch.float32\n",
        "    model = AutoModelForCausalLM.from_pretrained(\n",
        "        MODEL_NAME,\n",
        "        dtype=dtype,\n",
        "        low_cpu_mem_usage=True,\n",
        "        use_safetensors=True,\n",
        "    )\n",
        "    if getattr(model.config, \"use_cache\", None) is not None:\n",
        "        model.config.use_cache = False\n",
        "    model.to(DEVICE)\n",
        "    lora_cfg = LoraConfig(\n",
        "        r=int(cfg[\"lora_r\"]),\n",
        "        lora_alpha=int(cfg[\"lora_alpha\"]),\n",
        "        lora_dropout=float(cfg[\"lora_dropout\"]),\n",
        "        target_modules=[\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"],\n",
        "        task_type=TaskType.CAUSAL_LM,\n",
        "        bias=\"none\",\n",
        "    )\n",
        "    return get_peft_model(model, lora_cfg)\n",
        "\n",
        "\n",
        "def make_optimizer(name, model, cfg):\n",
        "    params = [p for p in model.parameters() if p.requires_grad]\n",
        "    if name == \"adamw\":\n",
        "        return torch.optim.AdamW(params, lr=float(cfg[\"lr\"]), betas=tuple(cfg[\"betas\"]), weight_decay=float(cfg[\"weight_decay\"]))\n",
        "    if name == \"lbw_guard\":\n",
        "        return lbw.Guard(\n",
        "            params,\n",
        "            lr=float(cfg[\"lr\"]),\n",
        "            betas=tuple(cfg[\"betas\"]),\n",
        "            weight_decay=float(cfg[\"weight_decay\"]),\n",
        "            mode=\"eval\",\n",
        "            auto_enabled=True,\n",
        "            stats_freq=int(cfg[\"lbw_stats_freq\"]),\n",
        "            stress_threshold=float(cfg[\"lbw_stress_th\"]),\n",
        "            spike_threshold=float(cfg[\"lbw_spike_th\"]),\n",
        "            recovery_fast=float(cfg[\"lbw_rec_fast\"]),\n",
        "            ema_decay=float(cfg[\"lbw_ema_decay\"]),\n",
        "            use_max_rms=True,\n",
        "        )\n",
        "    raise ValueError(f\"Unknown optimizer: {name}\")\n",
        "\n",
        "\n",
        "def scheduled_lr(cfg, step):\n",
        "    base_lr = float(cfg[\"lr\"])\n",
        "    warmup = max(int(cfg.get(\"warmup_steps\", 0)), 0)\n",
        "    max_steps = max(int(cfg[\"max_steps\"]), 1)\n",
        "    if warmup > 0 and step <= warmup:\n",
        "        return base_lr * float(step) / float(warmup)\n",
        "    mode = str(cfg.get(\"schedule_mode\", \"constant\")).lower()\n",
        "    if mode == \"cosine\":\n",
        "        progress = (step - warmup) / max(max_steps - warmup, 1)\n",
        "        progress = min(max(progress, 0.0), 1.0)\n",
        "        return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))\n",
        "    return base_lr\n",
        "\n",
        "\n",
        "def set_lr(opt, value):\n",
        "    for group in opt.param_groups:\n",
        "        group[\"lr\"] = float(value)\n",
        "\n",
        "\n",
        "@torch.no_grad()\n",
        "def evaluate_ppl(model, eval_chunks, cfg, full_pass=False):\n",
        "    model.eval()\n",
        "    ids = eval_chunks[\"input_ids\"]\n",
        "    batch_size = int(cfg[\"batch_size\"])\n",
        "    max_sequences = ids.size(0) if full_pass else min(ids.size(0), int(cfg[\"eval_batches\"]) * batch_size)\n",
        "    losses = []\n",
        "    for start in range(0, max_sequences, batch_size):\n",
        "        xb = ids[start : start + batch_size].to(DEVICE, non_blocking=True)\n",
        "        with torch.autocast(device_type=DEVICE, dtype=torch.float16, enabled=(DEVICE == \"cuda\")):\n",
        "            loss = model(input_ids=xb, labels=xb).loss\n",
        "        losses.append(float(loss.detach().cpu()))\n",
        "    avg_loss = sum(losses) / max(len(losses), 1)\n",
        "    return avg_loss, math.exp(min(avg_loss, 20.0))\n",
        "\n",
        "\n",
        "def optimizer_state(opt):\n",
        "    state = dict(getattr(opt, \"state\", {}).get(\"lbw\", {}) or {})\n",
        "    return {\n",
        "        \"scale\": float(state.get(\"scale\", state.get(\"lbw_scale\", 1.0))),\n",
        "        \"ratio\": float(state.get(\"ratio\", 1.0)),\n",
        "        \"stress_mode\": str(state.get(\"stress_mode\", \"none\")),\n",
        "    }\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "107c58b1",
      "metadata": {},
      "outputs": [],
      "source": [
        "# @title 5. Run ablation matrix\n",
        "\n",
        "def run_one_optimizer(scenario_item, optimizer_name):\n",
        "    cfg = scenario_item[\"config\"]\n",
        "    train_chunks, eval_chunks = get_chunks(cfg)\n",
        "    set_seed(cfg[\"seed\"])\n",
        "    model = load_lora_model(cfg)\n",
        "    model.train()\n",
        "    opt = make_optimizer(optimizer_name, model, cfg)\n",
        "    train_batches = batch_iter(train_chunks, cfg[\"batch_size\"])\n",
        "\n",
        "    start_time = time.time()\n",
        "    losses = []\n",
        "    eval_loss = None\n",
        "    eval_ppl = None\n",
        "    last_lr = float(cfg[\"lr\"])\n",
        "\n",
        "    for step in range(1, int(cfg[\"max_steps\"]) + 1):\n",
        "        last_lr = scheduled_lr(cfg, step)\n",
        "        set_lr(opt, last_lr)\n",
        "        xb = next(train_batches)\n",
        "        with torch.autocast(device_type=DEVICE, dtype=torch.float16, enabled=(DEVICE == \"cuda\")):\n",
        "            loss = model(input_ids=xb, labels=xb).loss\n",
        "        loss.backward()\n",
        "        torch.nn.utils.clip_grad_norm_([p for p in model.parameters() if p.requires_grad], 1.0)\n",
        "        opt.step()\n",
        "        opt.zero_grad(set_to_none=True)\n",
        "        loss_value = float(loss.detach().cpu())\n",
        "        losses.append(loss_value)\n",
        "\n",
        "        if step == 1 or step == int(cfg[\"max_steps\"]) or step % int(cfg[\"eval_every\"]) == 0:\n",
        "            eval_loss, eval_ppl = evaluate_ppl(model, eval_chunks, cfg, full_pass=False)\n",
        "            state = optimizer_state(opt)\n",
        "            print(\n",
        "                f\"[{scenario_item['slug']}] {optimizer_name} step {step}/{cfg['max_steps']}: \"\n",
        "                f\"loss={loss_value:.4f}, sampled_eval_ppl={eval_ppl:.4f}, \"\n",
        "                f\"lr={last_lr:.2e}, scale={state['scale']:.4f}, ratio={state['ratio']:.4f}\"\n",
        "            )\n",
        "            model.train()\n",
        "\n",
        "    final_full_pass = bool(cfg[\"full_validation_ppl\"])\n",
        "    final_scope = \"full_wikitext\" if final_full_pass and eval_chunks[\"cap\"] is None else (\"full_loaded_subset\" if final_full_pass else \"sampled\")\n",
        "    final_loss, final_ppl = evaluate_ppl(model, eval_chunks, cfg, full_pass=final_full_pass)\n",
        "    state = optimizer_state(opt)\n",
        "    wall_time = max(time.time() - start_time, 1e-9)\n",
        "    trained_tokens = int(cfg[\"max_steps\"]) * int(cfg[\"batch_size\"]) * int(cfg[\"seq_len\"])\n",
        "\n",
        "    result = {\n",
        "        \"scenario_slug\": scenario_item[\"slug\"],\n",
        "        \"scenario\": scenario_item[\"label\"],\n",
        "        \"optimizer\": optimizer_name,\n",
        "        \"final_eval_ppl\": final_ppl,\n",
        "        \"final_eval_loss\": final_loss,\n",
        "        \"train_loss_last\": losses[-1] if losses else None,\n",
        "        \"final_eval_scope\": final_scope,\n",
        "        \"max_steps\": int(cfg[\"max_steps\"]),\n",
        "        \"lr\": float(cfg[\"lr\"]),\n",
        "        \"scheduled_lr_last\": float(last_lr),\n",
        "        \"schedule_mode\": cfg[\"schedule_mode\"],\n",
        "        \"batch_size\": int(cfg[\"batch_size\"]),\n",
        "        \"seq_len\": int(cfg[\"seq_len\"]),\n",
        "        \"lora_r\": int(cfg[\"lora_r\"]),\n",
        "        \"train_chars\": int(train_chunks[\"chars\"]),\n",
        "        \"eval_chars\": int(eval_chunks[\"chars\"]),\n",
        "        \"train_sequences\": int(train_chunks[\"input_ids\"].size(0)),\n",
        "        \"eval_sequences\": int(eval_chunks[\"input_ids\"].size(0)),\n",
        "        \"scale\": state[\"scale\"],\n",
        "        \"ratio\": state[\"ratio\"],\n",
        "        \"stress_mode\": state[\"stress_mode\"],\n",
        "        \"wall_time_sec\": wall_time,\n",
        "        \"tokens_per_sec_wall\": trained_tokens / wall_time,\n",
        "    }\n",
        "\n",
        "    del model, opt\n",
        "    gc.collect()\n",
        "    if DEVICE == \"cuda\":\n",
        "        torch.cuda.empty_cache()\n",
        "    return result\n",
        "\n",
        "\n",
        "results = []\n",
        "for scenario_item in SCENARIOS:\n",
        "    print(\"\\n=== Scenario:\", scenario_item[\"label\"], \"===\")\n",
        "    for optimizer_name in scenario_item[\"optimizers\"]:\n",
        "        print(\"\\n---\", optimizer_name, \"---\")\n",
        "        results.append(run_one_optimizer(scenario_item, optimizer_name))\n",
        "\n",
        "metrics_df = pd.DataFrame(results)\n",
        "display(metrics_df)\n",
        "metrics_path = \"/content/lbw_guard_ablation_metrics.csv\"\n",
        "metrics_df.to_csv(metrics_path, index=False)\n",
        "print(\"Wrote\", metrics_path)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# @title 6. Compute LBW-vs-AdamW gains\n",
        "\n",
        "def build_gain_rows(metrics):\n",
        "    gain_rows = []\n",
        "    for scenario_slug, group in metrics.groupby(\"scenario_slug\"):\n",
        "        baseline_rows = group[group[\"optimizer\"] == \"adamw\"]\n",
        "        if baseline_rows.empty:\n",
        "            continue\n",
        "        baseline = baseline_rows.iloc[0]\n",
        "        for _, row in group.iterrows():\n",
        "            if row[\"optimizer\"] == \"adamw\":\n",
        "                continue\n",
        "            ppl_gain_pct = (baseline[\"final_eval_ppl\"] - row[\"final_eval_ppl\"]) / baseline[\"final_eval_ppl\"] * 100.0\n",
        "            loss_gain_pct = (baseline[\"final_eval_loss\"] - row[\"final_eval_loss\"]) / baseline[\"final_eval_loss\"] * 100.0\n",
        "            speed_gain_pct = (row[\"tokens_per_sec_wall\"] - baseline[\"tokens_per_sec_wall\"]) / baseline[\"tokens_per_sec_wall\"] * 100.0\n",
        "            gain_rows.append({\n",
        "                \"scenario_slug\": scenario_slug,\n",
        "                \"scenario\": row[\"scenario\"],\n",
        "                \"optimizer\": row[\"optimizer\"],\n",
        "                \"adamw_final_eval_ppl\": baseline[\"final_eval_ppl\"],\n",
        "                \"optimizer_final_eval_ppl\": row[\"final_eval_ppl\"],\n",
        "                \"ppl_gain_pct_vs_adamw\": ppl_gain_pct,\n",
        "                \"loss_gain_pct_vs_adamw\": loss_gain_pct,\n",
        "                \"speed_gain_pct_vs_adamw\": speed_gain_pct,\n",
        "                \"adamw_tokens_per_sec_wall\": baseline[\"tokens_per_sec_wall\"],\n",
        "                \"optimizer_tokens_per_sec_wall\": row[\"tokens_per_sec_wall\"],\n",
        "                \"lbw_scale\": row[\"scale\"],\n",
        "                \"lbw_ratio\": row[\"ratio\"],\n",
        "                \"lbw_stress_mode\": row[\"stress_mode\"],\n",
        "            })\n",
        "    return gain_rows\n",
        "\n",
        "\n",
        "gains_df = pd.DataFrame(build_gain_rows(metrics_df))\n",
        "display(gains_df if not gains_df.empty else pd.DataFrame([{\"message\": \"No gain rows. Keep adamw and lbw_guard in OPTIMIZERS.\"}]))\n",
        "gains_path = \"/content/lbw_guard_ablation_gains.csv\"\n",
        "gains_df.to_csv(gains_path, index=False)\n",
        "print(\"Wrote\", gains_path)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## How to read this ablation\n",
        "\n",
        "- `scenario`: the ablation condition being tested.\n",
        "- `optimizer`: `adamw` is the baseline; `lbw_guard` is the TCG optimizer under test.\n",
        "- `final_eval_ppl`: lower is better for this WikiText smoke benchmark.\n",
        "- `ppl_gain_pct_vs_adamw`: positive means `lbw_guard` achieved lower perplexity than AdamW in that scenario.\n",
        "- `scale`: LBW Guard's control scale for the effective update.\n",
        "- `ratio`: LBW Guard's gradient stress ratio.\n",
        "- `stress_mode`: LBW Guard's current controller regime.\n",
        "- `final_eval_scope`: `sampled`, `full_loaded_subset`, or `full_wikitext`.\n",
        "\n",
        "For true full WikiText validation PPL, set both values in `BASE_CONFIG`:\n",
        "\n",
        "```python\n",
        "\"full_wikitext_eval\": True,\n",
        "\"full_validation_ppl\": True,\n",
        "```\n",
        "\n",
        "For a wider ablation matrix, change:\n",
        "\n",
        "```python\n",
        "ABLATIONS = [\"optimizer\", \"lr\", \"schedule\", \"steps\", \"data\", \"lora\"]\n",
        "```\n",
        "\n",
        "That will take longer because each scenario runs both AdamW and `lbw_guard`.\n"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "LBW_Guard_Ablation_Test_COLAB.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
