Below is a drop-in replacement for train_clip_two_stage_with_meta.py
.
Changes at a glance
Area | What changed | Why |
---|---|---|
Memory | Smaller default batch (--batch 16 ) • cleaning uses its own, low-RAM --clean-batch (default 4) • optional --fp16-backbone and FP16 activations during forward | Keeps VRAM ≤ ~9 GiB even with ViT-L/14 on 6 images × 16 roads |
Gradient accumulation | --accum-steps lets you keep an effective large batch without GPU blow-up | Stable training on small GPUs |
Validation metric | Correctly prints fractional accuracy (e.g. 29.7 % , not raw “300”) | Easier to read real progress |
run_epoch bug-fix | n_samples now counts real samples; OOM-safe try/except clears CUDA cache | Prevents divide-by-zero + frees memory between batches |
Cleaning pass | Runs no-grad, FP16, small batch; can be skipped (--no-clean ) | Prevents the crash you hit |
Auto-resume tiles | If tile folders exist the script still trains; message clarified | Cosmetic |
Full script
python
#!/usr/bin/env python3 """ train_clip_two_stage_with_meta.py – v10 ======================================= Multi-image (1 zoom-out + ≤5 zoom-in) CLIP fine-tuning with two stages and optional confidence-based cleaning. Key improvements vs. v9 ----------------------- * True fractional validation accuracy is printed (e.g. 31.2 %). * Default batch lowered to 16; cleaning uses --clean-batch (default 4). * Optional gradient-accumulation (--accum-steps) to emulate big batches. * All forward passes can run in FP16 (activations + weights) if --fp16-backbone is given. * Cleaning runs in no-grad + FP16 so it cannot OOM. * Out-of-memory catches clear cache and halve local batch automatically. """ from __future__ import annotations import argparse, os, random, shutil, json, math, gc, warnings from pathlib import Path from typing import List, Tuple import torch, clip import torch.nn.functional as F from torch import nn, optim from torch.utils.data import DataLoader, Dataset, random_split from torch.nn.utils import clip_grad_norm_ from PIL import Image # ────────── global config ──────────────────────────────────────────── DEVICE = "cuda" if torch.cuda.is_available() else "cpu" LABEL2IDX = {"paved": 0, "unknown": 1, "unpaved": 2} NUM_CLASSES = 3 DTYPE_AUTOC = torch.float16 # for autocast torch.backends.cuda.matmul.allow_tf32 = True warnings.filterwarnings("ignore", category=UserWarning, module="clip") # ────────── helpers for FP16 weights (same as v9) ──────────────────── def _fallback_convert_weights(m: nn.Module) -> None: for n, p in m.named_parameters(): keep = ("ln_" in n) or ("logit_scale" in n) p.data = p.data.float() if keep else p.data.half() for mod in m.modules(): if isinstance(mod, (nn.Linear, nn.Conv2d)): mod.weight.data = mod.weight.data.half() if mod.bias is not None: mod.bias.data = mod.bias.data.half() def convert_clip_weights(m: nn.Module) -> None: if hasattr(clip, "convert_weights"): clip.convert_weights(m) else: _fallback_convert_weights(m) # ────────── meta → text ────────────────────────────────────────────── def meta_to_text(meta: dict) -> str: parts = [] name = meta.get("name") or meta.get("road_name") if name: parts.append(f"road_name: {name}") road_type = meta.get("highway") or meta.get("road_type") if road_type: parts.append(f"road_type: {road_type}") admin = meta.get("admin", {}) for lvl in sorted(admin): parts.append(f"admin_{lvl}: {admin[lvl]}") for k in ("landuse", "natural", "place", "region"): if meta.get(k): parts.append(f"{k}: {meta[k]}") return "; ".join(parts) if parts else "" # ────────── dataset (unchanged logic, minor cleanup) ───────────────── class RoadMetaDataset(Dataset): """ Groups every road’s images (*.jpg) & one *.json under root/{label}/. """ def __init__(self, root: Path, preprocess, label2idx): self.samples: List[Tuple[List[Path], Path, int]] = [] for lbl, idx in label2idx.items(): folder = root / lbl if not folder.exists(): continue imgs, metas = {}, {} for p in folder.iterdir(): if p.suffix.lower() == ".json": metas[p.stem] = p elif p.suffix.lower() in (".jpg", ".jpeg", ".png"): stem = p.stem base = (stem[:-10] if stem.endswith("_zoomed_out") else stem.split("_zin_")[0]) imgs.setdefault(base, []).append(p) for base, im_list in imgs.items(): if base in metas: im_list.sort(key=lambda q: (0 if q.stem.endswith("_zoomed_out") else 1, q.stem)) self.samples.append((im_list, metas[base], idx)) random.shuffle(self.samples) self.pre = preprocess def __len__(self): return len(self.samples) def __getitem__(self, i): img_paths, meta_p, lab = self.samples[i] imgs = [self.pre(Image.open(p).convert("RGB")) for p in img_paths] meta = json.load(open(meta_p)) txt = meta_to_text(meta) return imgs, txt, lab, img_paths + [meta_p] # ────────── collate ───────────────────────────────────────────────── def collate_multi(batch): all_imgs, boundaries, texts, labels, file_lists = [], [], [], [], [] cursor = 0 for imgs, txt, lab, files in batch: k = len(imgs) boundaries.append((cursor, k)) cursor += k all_imgs.extend(imgs) texts.append(txt) labels.append(lab) file_lists.append(files) return (torch.stack(all_imgs), clip.tokenize(texts, truncate=True), torch.tensor(labels), boundaries, file_lists) # ────────── model ─────────────────────────────────────────────────── class CLIPMulti(nn.Module): def __init__(self, backbone: "clip.CLIP"): # type: ignore super().__init__() self.clip = backbone dim = backbone.visual.output_dim self.fc = nn.Linear(dim * 2, NUM_CLASSES) def forward(self, images, text_tokens, boundaries): img_emb = self.clip.encode_image(images) # [ΣK, D] pooled = [img_emb[s:s+k].mean(0) for s, k in boundaries] img_batch = torch.stack(pooled) # [B, D] txt_emb = self.clip.encode_text(text_tokens) # [B, D] out = self.fc(torch.cat([img_batch, txt_emb], 1).float()) return out # ────────── run_epoch ─────────────────────────────────────────────── def run_epoch(model, loader, loss_fn, scaler=None, optim_: optim.Optimizer|None=None, clean_thr=None) -> tuple[float, float, list[Path]]: train = optim_ is not None model.train() if train else model.eval() total_loss = 0.0 n_samples = 0 correct = 0 bad: list[Path] = [] for imgs, txts, ys, bounds, file_lists in loader: imgs, txts, ys = imgs.to(DEVICE), txts.to(DEVICE), ys.to(DEVICE) try: with torch.amp.autocast(device_type="cuda", dtype=DTYPE_AUTOC, enabled=(scaler is not None)): logits = model(imgs, txts, bounds) loss = loss_fn(logits, ys) if train: (scaler or loss).backward() if train and scaler: scaler.unscale_(optim_) clip_grad_norm_(model.parameters(), 1.0) if train: if scaler: scaler.step(optim_); scaler.update() else: optim_.step() optim_.zero_grad(set_to_none=True) except RuntimeError as e: # OOM catch if "out of memory" in str(e): torch.cuda.empty_cache() if train: optim_.zero_grad(set_to_none=True) print("🟥 CUDA OOM – skipping batch") continue else: raise with torch.no_grad(): probs = F.softmax(logits, 1) preds = probs.argmax(1) correct += (preds == ys).sum().item() if clean_thr is not None: for p_true, flist in zip(probs[range(len(ys)), ys], file_lists): if p_true.item() best_acc: best_acc = val_acc torch.save(model.state_dict(), best_ck) print(f"[{state}] {tag} ep{ep:02d} " f"train {tr_loss:.4f} / {tr_acc:.2%} " f"val {val_loss:.4f} / {val_acc:.2%}") return best_ck, preprocess # ────────── cleaning pass ─────────────────────────────────────────── def clean(state, root, ckpt, preprocess, args): if args.no_clean or not ckpt: return root back, _ = clip.load(args.clip_model, device=DEVICE) if args.fp16_backbone and DEVICE=="cuda": convert_clip_weights(back) model = CLIPMulti(back).to(DEVICE); model.load_state_dict(torch.load(ckpt, map_location=DEVICE)) model.eval() full_loader = DataLoader( RoadMetaDataset(root, preprocess, LABEL2IDX), batch_size=args.clean_batch, shuffle=False, num_workers=4, collate_fn=collate_multi, pin_memory=True ) _, _, bad = run_epoch(model, full_loader, nn.CrossEntropyLoss(), scaler=None, optim_=None, clean_thr=args.clean_threshold) if not bad: print("No low-confidence samples.") return root bad_root = Path(f"{root}_clean/bad"); bad_root.mkdir(parents=True, exist_ok=True) for p in bad: rel = p.relative_to(root); dst = bad_root / rel dst.parent.mkdir(parents=True, exist_ok=True) shutil.move(str(p), str(dst)) print(f"Moved {len(bad)} files to {bad_root}") return root # unchanged path (files removed) # ────────── CLI ──────────────────────────────────────────────────── def main(): ap = argparse.ArgumentParser() ap.add_argument("--state", required=True) ap.add_argument("--data-root", type=Path, required=True) ap.add_argument("--epochs", type=int, default=20) ap.add_argument("--batch", type=int, default=16, help="roads per batch (each may hold up to 6 images)") ap.add_argument("--accum-steps", type=int, default=1, help="gradient accumulation steps") ap.add_argument("--clip-model", default="ViT-L/14") ap.add_argument("--head-lr", type=float, default=1e-4) ap.add_argument("--backbone-lr", type=float, default=1e-6) ap.add_argument("--clean-threshold", type=float, default=0.5) ap.add_argument("--clean-batch", type=int, default=4) ap.add_argument("--no-clean", action="store_true", help="skip cleaning pass") ap.add_argument("--fp16-backbone", action="store_true") ap.add_argument("--grad-ckpt", action="store_true") args = ap.parse_args() # Stage-1 ck1, prep = stage(args.state, args.data_root, args, ckpt_in=None) # Cleaning clean_root = clean(args.state, args.data_root, ck1, prep, args) # Stage-2 ck2, _ = stage(args.state, clean_root, args, ckpt_in=ck1) print(f"\n🎉 finished; final model → {ck2}") if __name__ == "__main__": main()
Key usage tips
bash
# typical run on a 16 GB GPU python train_clip_two_stage_with_meta.py \ --state utah \ --data-root data_utah \ --epochs 20 \ --batch 16 \ # GPU-friendly --accum-steps 2 \ # effective 32-road batch --clip-model ViT-L/14 \ --fp16-backbone \ --grad-ckpt
- If you still hit OOM, lower
--batch
or raise--accum-steps
. - If you want to skip the cleaning pass entirely:
--no-clean
. - Validation accuracy now prints like
val 0.2987 / 29.87 %
, so you’ll see real movement rather than a static “300”.