{
  "cells": [
    {
      "cell_type": "markdown",
      "id": "7646fe20",
      "metadata": {},
      "source": [
        "Copyright (c) Qluon Inc. All rights reserved.\n",
        "\n",
        "Provided for Learn-By-Wire Guard evaluation and customer testing under the applicable Qluon license terms.\n",
        "\n",
        "# LBW Guard Easy Test Colab\n",
        "\n",
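        "This notebook is a black-box smoke test for the commercial evaluation of `lbw_guard`. It fine-tunes LoRA adapters on a small slice of WikiText-103 (default model: `TinyLlama/TinyLlama_v1.1`), once with `torch.optim.AdamW` and once with `lbw.Guard`, and compares the resulting validation perplexity.\n",
        "\n",
        "It imports no local source folders: the only LBW code it runs is the `LBW-Guard` package installed with pip in step 2, which provides `lbw.Guard`."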
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "7ce8f05d",
      "metadata": {},
      "outputs": [],
      "source": [
        "# @title 1. Install public dependencies\n",
        "# Install the public (non-LBW) libraries into the interpreter that runs this\n",
        "# kernel. Only minimum versions are given; pin exact versions here if you need\n",
        "# a reproducible environment.\n",
        "import subprocess\n",
        "import sys\n",
        "\n",
        "deps = [\"transformers>=4.45\", \"datasets>=2.20\", \"peft>=0.12\", \"accelerate>=0.33\"]\n",
        "deps.extend([\"sentencepiece\", \"pandas\"])\n",
        "\n",
        "pip_install = [sys.executable, \"-m\", \"pip\", \"install\", \"-q\"]\n",
        "subprocess.check_call(pip_install + deps)\n",
        "print(\"Public dependency install complete.\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "0d73bff3",
      "metadata": {},
      "outputs": [],
      "source": [
        "# @title 2. Install LBW Guard package\n",
        "import importlib\n",
        "import subprocess\n",
        "import sys\n",
        "\n",
        "# Install the LBW-Guard distribution into the interpreter running this kernel.\n",
        "subprocess.check_call([sys.executable, \"-m\", \"pip\", \"install\", \"-q\", \"LBW-Guard\"])\n",
        "# The package was installed after this interpreter started, so clear the import\n",
        "# system's cached directory listings first; otherwise the new `lbw` package may\n",
        "# not be found until the kernel restarts. If `lbw` was already imported in this\n",
        "# kernel, import_module returns that cached module -- restart the runtime to\n",
        "# pick up an upgraded version.\n",
        "importlib.invalidate_caches()\n",
        "lbw = importlib.import_module(\"lbw\")\n",
        "\n",
        "print(\"lbw module:\", lbw.__file__)\n",
        "print(\"lbw.Guard:\", lbw.Guard)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "3c5a9e17",
      "metadata": {},
      "outputs": [],
      "source": [
        "# @title 3. Configure the easy test\n",
        "import torch\n",
        "\n",
        "# Model and hardware.\n",
        "MODEL_NAME = \"TinyLlama/TinyLlama_v1.1\"\n",
        "DEVICE = \"cpu\"\n",
        "if torch.cuda.is_available():\n",
        "    DEVICE = \"cuda\"\n",
        "\n",
        "# Run schedule. Each optimizer in OPTIMIZERS gets its own freshly seeded run.\n",
        "OPTIMIZERS = [\"adamw\", \"lbw_guard\"]\n",
        "SEED = 42\n",
        "MAX_STEPS = 5      # optimizer steps per run\n",
        "EVAL_EVERY = 5     # sampled-eval interval; steps 1 and MAX_STEPS are always evaluated\n",
        "EVAL_BATCHES = 8   # batches scored by each sampled eval\n",
        "\n",
        "# Data. WikiText text is tokenized and cut into SEQ_LEN-token sequences.\n",
        "SEQ_LEN = 64\n",
        "BATCH_SIZE = 1\n",
        "MAX_CHARS = 20_000   # train-text character cap, used unless FULL_WIKITEXT_TRAIN\n",
        "EVAL_CHARS = 8_000   # validation-text character cap, used unless FULL_WIKITEXT_EVAL\n",
        "\n",
        "# Full-data switches (all off for the quick default).\n",
        "FULL_WIKITEXT_TRAIN = False   # load the entire train split instead of MAX_CHARS\n",
        "FULL_WIKITEXT_EVAL = False    # load the entire validation split instead of EVAL_CHARS\n",
        "FULL_VALIDATION_PPL = False   # final eval scores every loaded validation sequence\n",
        "\n",
        "# Hyperparameters shared by AdamW and lbw.Guard.\n",
        "LR = 5e-4\n",
        "BETAS = (0.9, 0.999)\n",
        "WEIGHT_DECAY = 0.01\n",
        "\n",
        "# lbw.Guard-only settings, forwarded as stats_freq, stress_threshold,\n",
        "# spike_threshold, recovery_fast and ema_decay respectively.\n",
        "# NOTE(review): MAX_STEPS (5) < LBW_STATS_FREQ (10); if stats_freq is a step\n",
        "# interval, Guard may not refresh its statistics in the default run -- confirm.\n",
        "LBW_STATS_FREQ = 10\n",
        "LBW_STRESS_TH = 1.1\n",
        "LBW_SPIKE_TH = 1.5\n",
        "LBW_REC_FAST = 0.01\n",
        "LBW_EMA_DECAY = 0.95\n",
        "\n",
        "print(\"Device:\", DEVICE)\n",
        "if DEVICE == \"cuda\":\n",
        "    print(\"GPU:\", torch.cuda.get_device_name(0))\n",
        "print(\"For true full WikiText validation PPL, set FULL_WIKITEXT_EVAL=True and FULL_VALIDATION_PPL=True.\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "8f2d4b60",
      "metadata": {},
      "outputs": [],
      "source": [
        "# @title 4. Define data, model, training, and evaluation helpers\n",
        "import gc\n",
        "import math\n",
        "import random\n",
        "import time\n",
        "\n",
        "import pandas as pd\n",
        "from datasets import load_dataset\n",
        "from peft import LoraConfig, TaskType, get_peft_model\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
        "\n",
        "def set_seed(seed):\n",
        "    \"\"\"Seed Python's `random` module and torch (CPU, plus all CUDA devices when available).\"\"\"\n",
        "    random.seed(seed)\n",
        "    torch.manual_seed(seed)\n",
        "    if not torch.cuda.is_available():\n",
        "        return\n",
        "    torch.cuda.manual_seed_all(seed)\n",
        "\n",
        "def build_wikitext_chunks(tokenizer, split, max_chars):\n",
        "    \"\"\"Tokenize (a prefix of) a WikiText-103 split into fixed-length sequences.\n",
        "\n",
        "    Non-blank rows of `wikitext-103-raw-v1` are joined with a single space. If\n",
        "    `max_chars` is not None, collection stops once that many characters are\n",
        "    gathered (the last row may be cut short). The joined text is tokenized\n",
        "    without special tokens and cut into SEQ_LEN-token rows; a trailing partial\n",
        "    row is dropped.\n",
        "\n",
        "    Returns a dict: `input_ids` (int64 tensor of shape (n, SEQ_LEN)), `chars`\n",
        "    (characters kept), `rows` (dataset rows used) and `cap` (the int cap, or\n",
        "    None for the full split). Raises RuntimeError if fewer than SEQ_LEN tokens\n",
        "    are produced.\n",
        "    \"\"\"\n",
        "    cap = int(max_chars) if max_chars is not None else None\n",
        "    scope = f\" with char cap {cap:,}\" if cap is not None else \" with full split\"\n",
        "    print(f\"Preparing WikiText split={split!r}\" + scope)\n",
        "    dataset = load_dataset(\"wikitext\", \"wikitext-103-raw-v1\", split=split)\n",
        "\n",
        "    kept = []\n",
        "    total_chars = 0\n",
        "    for record in dataset:\n",
        "        raw = str(record.get(\"text\", \"\") or \"\")\n",
        "        if not raw.strip():\n",
        "            continue  # skip empty / whitespace-only rows\n",
        "        # Every kept row after the first is preceded by one space separator.\n",
        "        segment = \" \" + raw if kept else raw\n",
        "        if cap is not None:\n",
        "            budget = cap - total_chars\n",
        "            if budget <= 0:\n",
        "                break\n",
        "            segment = segment[:budget]\n",
        "        kept.append(segment)\n",
        "        total_chars += len(segment)\n",
        "        if cap is not None and total_chars >= cap:\n",
        "            break\n",
        "    rows_kept = len(kept)\n",
        "\n",
        "    token_ids = tokenizer(\"\".join(kept), add_special_tokens=False)[\"input_ids\"]\n",
        "    flat = torch.tensor(token_ids, dtype=torch.long)\n",
        "    num_seqs = flat.numel() // SEQ_LEN\n",
        "    if num_seqs <= 0:\n",
        "        raise RuntimeError(\"Not enough tokens. Increase MAX_CHARS or reduce SEQ_LEN.\")\n",
        "    sequences = flat[: num_seqs * SEQ_LEN].view(num_seqs, SEQ_LEN).contiguous()\n",
        "    print(f\"Prepared split={split!r}: {total_chars:,} chars across {rows_kept:,} rows -> {sequences.size(0):,} sequences\")\n",
        "    return {\"input_ids\": sequences, \"chars\": total_chars, \"rows\": rows_kept, \"cap\": cap}\n",
        "\n",
        "def batch_iter(chunks):\n",
        "    \"\"\"Endlessly yield consecutive BATCH_SIZE-row slices of chunks['input_ids'] on DEVICE.\n",
        "\n",
        "    Rows are served in order; when fewer than BATCH_SIZE rows remain, iteration\n",
        "    wraps back to row 0 (the leftover rows are skipped for that pass). If there\n",
        "    are fewer rows than BATCH_SIZE in total, every batch is that shorter slice.\n",
        "    \"\"\"\n",
        "    all_ids = chunks[\"input_ids\"]\n",
        "    cursor = 0\n",
        "    while True:\n",
        "        if cursor + BATCH_SIZE > all_ids.size(0):\n",
        "            cursor = 0  # not enough rows left for a full batch: wrap around\n",
        "        yield all_ids[cursor : cursor + BATCH_SIZE].to(DEVICE, non_blocking=True)\n",
        "        cursor += BATCH_SIZE\n",
        "\n",
        "def load_lora_model():\n",
        "    \"\"\"Load MODEL_NAME onto DEVICE and wrap it in a freshly initialized LoRA adapter.\n",
        "\n",
        "    Base weights load in float16 on CUDA and float32 on CPU. The adapter uses\n",
        "    rank 8, alpha 16 and dropout 0.05 on the q/k/v/o and gate/up/down projection\n",
        "    modules, with bias=\\\"none\\\". Returns the PEFT-wrapped model.\n",
        "    \"\"\"\n",
        "    weight_dtype = torch.float32\n",
        "    if DEVICE == \"cuda\":\n",
        "        weight_dtype = torch.float16\n",
        "    base_model = AutoModelForCausalLM.from_pretrained(\n",
        "        MODEL_NAME, torch_dtype=weight_dtype, low_cpu_mem_usage=True\n",
        "    )\n",
        "    # The KV cache only serves incremental generation; switch it off for training.\n",
        "    if getattr(base_model.config, \"use_cache\", None) is not None:\n",
        "        base_model.config.use_cache = False\n",
        "    base_model.to(DEVICE)\n",
        "\n",
        "    projection_names = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\", \"gate_proj\", \"up_proj\", \"down_proj\"]\n",
        "    adapter_config = LoraConfig(\n",
        "        task_type=TaskType.CAUSAL_LM,\n",
        "        target_modules=projection_names,\n",
        "        r=8,\n",
        "        lora_alpha=16,\n",
        "        lora_dropout=0.05,\n",
        "        bias=\"none\",\n",
        "    )\n",
        "    return get_peft_model(base_model, adapter_config)\n",
        "\n",
        "def make_optimizer(name, model):\n",
        "    \"\"\"Build optimizer `name` ('adamw' or 'lbw_guard') over the model's trainable parameters.\n",
        "\n",
        "    Both share LR, BETAS and WEIGHT_DECAY; lbw.Guard additionally receives the\n",
        "    LBW_* config values. Raises ValueError for any other name.\n",
        "    \"\"\"\n",
        "    trainable = [param for param in model.parameters() if param.requires_grad]\n",
        "    shared = {\"lr\": LR, \"betas\": BETAS, \"weight_decay\": WEIGHT_DECAY}\n",
        "    if name == \"adamw\":\n",
        "        return torch.optim.AdamW(trainable, **shared)\n",
        "    if name == \"lbw_guard\":\n",
        "        return lbw.Guard(\n",
        "            trainable,\n",
        "            **shared,\n",
        "            mode=\"eval\",\n",
        "            auto_enabled=True,\n",
        "            stats_freq=LBW_STATS_FREQ,\n",
        "            stress_threshold=LBW_STRESS_TH,\n",
        "            spike_threshold=LBW_SPIKE_TH,\n",
        "            recovery_fast=LBW_REC_FAST,\n",
        "            ema_decay=LBW_EMA_DECAY,\n",
        "            use_max_rms=True,\n",
        "        )\n",
        "    raise ValueError(f\"Unknown optimizer: {name}\")\n",
        "\n",
        "@torch.no_grad()\n",
        "def evaluate_ppl(model, eval_chunks, full_pass=False):\n",
        "    \"\"\"Return `(avg_loss, ppl)` for `model` on the sequences in `eval_chunks`.\n",
        "\n",
        "    With `full_pass=False` only the first `EVAL_BATCHES * BATCH_SIZE` sequences\n",
        "    are scored; with `full_pass=True` every loaded sequence is. `avg_loss` is\n",
        "    the per-batch loss weighted by the number of next-token targets in that\n",
        "    batch, so a short final batch (when BATCH_SIZE does not divide the number\n",
        "    of scored sequences) counts in proportion to its size rather than as a\n",
        "    full batch. `ppl` is `exp(avg_loss)` with `avg_loss` clamped at 20.0\n",
        "    (ppl ~ 4.85e8) so a diverged run cannot overflow `math.exp`.\n",
        "\n",
        "    Leaves the model in eval mode; callers that keep training must call\n",
        "    `model.train()` afterwards.\n",
        "    \"\"\"\n",
        "    model.eval()\n",
        "    ids = eval_chunks[\"input_ids\"]\n",
        "    max_sequences = ids.size(0) if full_pass else min(ids.size(0), EVAL_BATCHES * BATCH_SIZE)\n",
        "    total_nll = 0.0\n",
        "    total_targets = 0\n",
        "    for start in range(0, max_sequences, BATCH_SIZE):\n",
        "        xb = ids[start : start + BATCH_SIZE].to(DEVICE, non_blocking=True)\n",
        "        with torch.autocast(device_type=\"cuda\", dtype=torch.float16, enabled=(DEVICE == \"cuda\")):\n",
        "            loss = model(input_ids=xb, labels=xb).loss\n",
        "        # Assumes the loss is the mean over the shifted next-token targets\n",
        "        # (seq_len - 1 per sequence, none masked since labels == input_ids), as\n",
        "        # for Hugging Face causal LMs -- weight it back into a summed NLL.\n",
        "        n_targets = xb.size(0) * (xb.size(1) - 1)\n",
        "        total_nll += float(loss.detach().cpu()) * n_targets\n",
        "        total_targets += n_targets\n",
        "    avg_loss = total_nll / max(total_targets, 1)\n",
        "    return avg_loss, math.exp(min(avg_loss, 20.0))\n",
        "\n",
        "def optimizer_state(opt):\n",
        "    \"\"\"Summarize the LBW control values kept in opt.state['lbw'], if any.\n",
        "\n",
        "    Returns a dict with `scale` (read from 'scale', else 'lbw_scale', else 1.0),\n",
        "    `ratio` (default 1.0) and `stress_mode` (default 'none'). An optimizer with no\n",
        "    'lbw' state entry, such as torch.optim.AdamW, reports these defaults.\n",
        "    \"\"\"\n",
        "    lbw_entry = getattr(opt, \"state\", {}).get(\"lbw\", {})\n",
        "    lbw_state = dict(lbw_entry or {})\n",
        "    default_scale = lbw_state.get(\"lbw_scale\", 1.0)\n",
        "    scale = float(lbw_state.get(\"scale\", default_scale))\n",
        "    ratio = float(lbw_state.get(\"ratio\", 1.0))\n",
        "    stress_mode = str(lbw_state.get(\"stress_mode\", \"none\"))\n",
        "    return {\"scale\": scale, \"ratio\": ratio, \"stress_mode\": stress_mode}\n",
        "\n",
        "def run_one_optimizer(name, train_chunks, eval_chunks):\n",
        "    \"\"\"Fine-tune a fresh LoRA model with optimizer `name` and return one result row.\n",
        "\n",
        "    RNGs are reseeded with SEED before the model is built, so every optimizer\n",
        "    starts from the same seeded state. The loop runs MAX_STEPS steps over\n",
        "    `train_chunks` (fp16 autocast on CUDA, grad-norm clipping at 1.0) and prints\n",
        "    the training loss, a sampled validation PPL and the LBW scale/ratio at step 1,\n",
        "    every EVAL_EVERY steps and at the last step. A final eval then scores every\n",
        "    loaded validation sequence when FULL_VALIDATION_PPL is set, otherwise the\n",
        "    same sample.\n",
        "\n",
        "    Returns a dict with the final eval metrics and scope, the train/eval\n",
        "    character counts, the final LBW state and the wall time from just before\n",
        "    the first step to the end of the final eval.\n",
        "\n",
        "    The model, optimizer and in-flight tensors are released (and the CUDA cache\n",
        "    emptied) even when a step raises, e.g. on CUDA OOM, so a failed run does not\n",
        "    keep GPU memory pinned for the next optimizer or a re-run of the cell.\n",
        "    \"\"\"\n",
        "    set_seed(SEED)\n",
        "    try:\n",
        "        model = load_lora_model()\n",
        "        model.train()\n",
        "        opt = make_optimizer(name, model)\n",
        "        # Exactly the tensors the optimizer updates; built once instead of every step.\n",
        "        trainable = [p for p in model.parameters() if p.requires_grad]\n",
        "        train_batches = batch_iter(train_chunks)\n",
        "        start_time = time.time()\n",
        "        for step in range(1, MAX_STEPS + 1):\n",
        "            xb = next(train_batches)\n",
        "            # NOTE(review): fp16 autocast without a GradScaler can underflow small\n",
        "            # gradients; fine for this short smoke test, revisit for long runs.\n",
        "            with torch.autocast(device_type=\"cuda\", dtype=torch.float16, enabled=(DEVICE == \"cuda\")):\n",
        "                loss = model(input_ids=xb, labels=xb).loss\n",
        "            loss.backward()\n",
        "            torch.nn.utils.clip_grad_norm_(trainable, 1.0)\n",
        "            opt.step()\n",
        "            opt.zero_grad(set_to_none=True)\n",
        "            last_loss = float(loss.detach().cpu())\n",
        "            if step == 1 or step == MAX_STEPS or step % EVAL_EVERY == 0:\n",
        "                state = optimizer_state(opt)\n",
        "                _, last_eval_ppl = evaluate_ppl(model, eval_chunks, full_pass=False)\n",
        "                print(f\"{name} step {step}/{MAX_STEPS}: loss={last_loss:.4f}, sampled_eval_ppl={last_eval_ppl:.4f}, scale={state['scale']:.4f}, ratio={state['ratio']:.4f}\")\n",
        "                model.train()  # evaluate_ppl leaves the model in eval mode\n",
        "        final_full_pass = bool(FULL_VALIDATION_PPL)\n",
        "        final_scope = \"full_wikitext\" if final_full_pass and eval_chunks[\"cap\"] is None else (\"full_loaded_subset\" if final_full_pass else \"sampled\")\n",
        "        final_loss, final_ppl = evaluate_ppl(model, eval_chunks, full_pass=final_full_pass)\n",
        "        state = optimizer_state(opt)\n",
        "        wall_time = time.time() - start_time\n",
        "        return {\n",
        "            \"optimizer\": name,\n",
        "            \"final_eval_ppl\": final_ppl,\n",
        "            \"final_eval_loss\": final_loss,\n",
        "            \"final_eval_scope\": final_scope,\n",
        "            \"train_chars\": train_chunks[\"chars\"],\n",
        "            \"eval_chars\": eval_chunks[\"chars\"],\n",
        "            \"scale\": state[\"scale\"],\n",
        "            \"ratio\": state[\"ratio\"],\n",
        "            \"stress_mode\": state[\"stress_mode\"],\n",
        "            \"wall_time_sec\": wall_time,\n",
        "        }\n",
        "    finally:\n",
        "        # Rebinding (unlike `del`) works however far the try block got. It drops\n",
        "        # this frame's references to the model, its parameters and any live\n",
        "        # autograd graph, which a traceback would otherwise keep alive.\n",
        "        model = opt = trainable = train_batches = xb = loss = None\n",
        "        gc.collect()\n",
        "        if DEVICE == \"cuda\":\n",
        "            torch.cuda.empty_cache()\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "id": "b41e7a93",
      "metadata": {},
      "outputs": [],
      "source": [
        "# @title 5. Run AdamW vs lbw_guard on WikiText-103\n",
        "import os\n",
        "\n",
        "set_seed(SEED)\n",
        "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)\n",
        "if tokenizer.pad_token is None:\n",
        "    tokenizer.pad_token = tokenizer.eos_token\n",
        "\n",
        "# Build the token sequences once and share them across all optimizer runs.\n",
        "train_cap = None if FULL_WIKITEXT_TRAIN else MAX_CHARS\n",
        "eval_cap = None if FULL_WIKITEXT_EVAL else EVAL_CHARS\n",
        "train_chunks = build_wikitext_chunks(tokenizer, \"train\", train_cap)\n",
        "eval_chunks = build_wikitext_chunks(tokenizer, \"validation\", eval_cap)\n",
        "\n",
        "results = []\n",
        "for optimizer_name in OPTIMIZERS:\n",
        "    print(\"\\n===\", optimizer_name, \"===\")\n",
        "    results.append(run_one_optimizer(optimizer_name, train_chunks, eval_chunks))\n",
        "\n",
        "df = pd.DataFrame(results)\n",
        "display(df)\n",
        "\n",
        "# /content exists on Colab; on other hosts (e.g. a local Jupyter server) it\n",
        "# usually does not, so fall back to the current working directory instead of\n",
        "# failing with FileNotFoundError after all the training work is done.\n",
        "results_dir = \"/content\" if os.path.isdir(\"/content\") else os.getcwd()\n",
        "results_path = os.path.join(results_dir, \"lbw_guard_easy_test_results.csv\")\n",
        "df.to_csv(results_path, index=False)\n",
        "print(f\"Wrote {results_path}\")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "id": "e6d0c2f8",
      "metadata": {},
      "source": [
        "## Reading the result\n",
        "\n",
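        "- `final_eval_ppl`: validation perplexity after training; lower is better for this smoke test.\n",
        "- `final_eval_scope`: which validation sequences `final_eval_ppl` covers: `sampled` (up to the first `EVAL_BATCHES * BATCH_SIZE` loaded sequences), `full_loaded_subset` (every sequence of the character-capped validation text), or `full_wikitext` (the entire validation split).\n",
        "- `scale`: the LBW Guard control scale applied to the effective update. AdamW keeps no LBW state, so it always reports the default `1.0`.\n",
        "- `ratio`: the LBW Guard gradient stress ratio. AdamW always reports the default `1.0`.\n",
        "- `stress_mode`: the stress mode reported in LBW Guard's state. AdamW always reports `none`.\n",
        "\n",
        "The default configuration is intentionally tiny (5 training steps on about 20,000 characters of text). Use it to check installation and behavior, not to claim final benchmark quality."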
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "LBW_Guard_Easy_Test_COLAB.ipynb",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 5
}
