from losses import ExponentialNLLLoss, DiscreteTimeCIFNLLLoss, LogNormalBasisBinnedHazardCIFNLLLoss, get_valid_pairs_and_dt
from model import DelphiFork, SapDelphi, SimpleHead
from dataset import HealthDataset, health_collate_fn
from tqdm import tqdm
from torch.nn.utils import clip_grad_norm_
from torch.utils.data import DataLoader, random_split
from torch.optim import AdamW
import torch.nn as nn
import torch
import json
import os
import time
import argparse
import math
from dataclasses import asdict, dataclass, field
from typing import Literal, Optional, Sequence


@dataclass
class TrainConfig:
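    """Training hyper-parameters.

    Defaults match the command-line flags defined in parse_args(); bin_edges is
    config-only (no CLI flag), so a config can be built from the CLI or
    constructed directly in code.
    """
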
    # Model Parameters
    model_type: Literal['sap_delphi', 'delphi_fork'] = 'delphi_fork'
    loss_type: Literal['exponential', 'discrete_time_cif',
                       'lognormal_basis_binned_hazard_cif'] = 'exponential'
    age_encoder: Literal['sinusoidal', 'mlp'] = 'sinusoidal'
    full_cov: bool = False
    n_embd: int = 120
    n_head: int = 12
    n_layer: int = 12
    pdrop: float = 0.1
    lambda_reg: float = 1e-4
    bin_edges: Sequence[float] = field(
        default_factory=lambda: [0.0, 0.24, 0.72,
                                 1.61, 3.84, 10.0, 31.0, float('inf')]
    )
    # LogNormal basis (shared by Route-3 binned hazard)
    lognormal_centers: Optional[Sequence[float]] = field(
        default_factory=list)  # mu_r in log-time
    loss_eps: float = 1e-8
    bandwidth_init: float = 0.7
    bandwidth_min: float = 1e-3
    bandwidth_max: float = 10.0
    lambda_sigma_reg: float = 1e-4
    sigma_reg_target: Optional[float] = None
    rank: int = 16
    # SapDelphi specific
    pretrained_emd_path: str = "icd10_sapbert_embeddings.npy"
    # Data Parameters
    data_prefix: str = "ukb"
    train_ratio: float = 0.7
    val_ratio: float = 0.15
    random_seed: int = 42
    # Training Parameters
    batch_size: int = 128
    max_epochs: int = 200
    warmup_epochs: int = 10
    patience: int = 10
    min_lr: float = 1e-5
    max_lr: float = 5e-4
    grad_clip: float = 1.0
    weight_decay: float = 1e-2
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    num_workers: int = 0
    prefetch_factor: int = 2
    persistent_workers: bool = False


def parse_args() -> TrainConfig:
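    """Parse command-line flags into a TrainConfig.

    If --lognormal_centers is not given, one center per bin is derived from
    bin_edges (see below).
    """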
    parser = argparse.ArgumentParser(description="Train Delphi Model")
    parser.add_argument("--model_type", type=str, choices=[
                        'sap_delphi', 'delphi_fork'], default='delphi_fork', help="Type of model to use.")
    parser.add_argument(
        "--loss_type",
        type=str,
        choices=['exponential', 'discrete_time_cif',
                 'lognormal_basis_binned_hazard_cif'],
        default='exponential',
        help="Type of loss function to use.")
    parser.add_argument("--age_encoder", type=str, choices=[
                        'sinusoidal', 'mlp'], default='sinusoidal', help="Type of age encoder to use.")
    parser.add_argument("--n_embd", type=int, default=120,
                        help="Embedding dimension.")
    parser.add_argument("--n_head", type=int, default=12,
                        help="Number of attention heads.")
    parser.add_argument("--n_layer", type=int, default=12,
                        help="Number of transformer layers.")
    parser.add_argument("--pdrop", type=float, default=0.1,
                        help="Dropout probability.")
    parser.add_argument("--lambda_reg", type=float,
                        default=1e-4, help="Regularization weight.")
    parser.add_argument(
        "--lognormal_centers",
        type=float,
        nargs='*',
        default=None,
        help="LogNormal basis centers (mu_r) in log-time, given as space-separated floats. If omitted, centers are derived from bin_edges.")
    parser.add_argument("--loss_eps", type=float, default=1e-8,
                        help="Epsilon for log clamps in lognormal-basis losses.")
    parser.add_argument("--bandwidth_init", type=float, default=0.7,
                        help="Initial sigma for the lognormal basis.")
    parser.add_argument("--bandwidth_min", type=float, default=1e-3,
                        help="Minimum sigma clamp for the lognormal basis.")
    parser.add_argument("--bandwidth_max", type=float, default=10.0,
                        help="Maximum sigma clamp for the lognormal basis.")
    parser.add_argument("--lambda_sigma_reg", type=float, default=1e-4,
                        help="Sigma regularization strength for the lognormal basis.")
    parser.add_argument("--sigma_reg_target", type=float, default=None,
                        help="Optional sigma target for regularization (defaults to bandwidth_init).")
    parser.add_argument("--rank", type=int, default=16,
                        help="Rank for low-rank parameterization (if applicable).")
    parser.add_argument("--pretrained_emd_path", type=str, default="icd10_sapbert_embeddings.npy",
                        help="Path to pretrained embeddings for SapDelphi.")
    parser.add_argument("--data_prefix", type=str,
                        default="ukb", help="Prefix for dataset files.")
    parser.add_argument("--full_cov", action='store_true',
                        help="Whether to use the full covariate set.")
    parser.add_argument("--train_ratio", type=float,
                        default=0.7, help="Training data ratio.")
    parser.add_argument("--val_ratio", type=float,
                        default=0.15, help="Validation data ratio.")
    parser.add_argument("--random_seed", type=int, default=42,
                        help="Random seed for data splitting.")
    parser.add_argument("--batch_size", type=int,
                        default=128, help="Batch size.")
    parser.add_argument("--max_epochs", type=int, default=200,
                        help="Maximum number of epochs.")
    parser.add_argument("--warmup_epochs", type=int,
                        default=10, help="Number of warmup epochs.")
    parser.add_argument("--patience", type=int, default=10,
                        help="Early stopping patience.")
    parser.add_argument("--min_lr", type=float, default=1e-5,
                        help="Minimum learning rate.")
    parser.add_argument("--max_lr", type=float, default=5e-4,
                        help="Maximum learning rate.")
    parser.add_argument("--grad_clip", type=float,
                        default=1.0, help="Gradient clipping value.")
    parser.add_argument("--weight_decay", type=float,
                        default=1e-2, help="Weight decay for the optimizer.")
    parser.add_argument("--num_workers", type=int, default=0,
                        help="DataLoader workers (0 is safest on Windows).")
    parser.add_argument("--prefetch_factor", type=int, default=2,
                        help="DataLoader prefetch factor (only used when num_workers > 0).")
    parser.add_argument("--persistent_workers", action='store_true',
                        help="Keep DataLoader workers alive between epochs (only if num_workers > 0).")
    parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                        help="Device to use for training.")
    args = parser.parse_args()

    cfg = TrainConfig(**vars(args))

    # If lognormal_centers are not provided, derive a reasonable default from bin_edges.
    if cfg.lognormal_centers is None or len(cfg.lognormal_centers) == 0:
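        # Each derived center is the geometric mean of its bin's endpoints, i.e.
        # the bin midpoint in log-time. With the default bin_edges, for example,
        # the (0.24, 0.72] bin gets the center log(sqrt(0.24 * 0.72)) ≈ -0.88.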
        edges = [float(x) for x in cfg.bin_edges]
        finite = [e for e in edges if math.isfinite(e)]
        if len(finite) < 2:
            raise ValueError(
                "bin_edges must contain at least two finite edges to derive lognormal_centers")
        e1 = finite[1]
        t_min = (e1 * 1e-3) if e1 > 0 else 1e-12
        # Build one center per bin (including the +inf last bin if present).
        centers: list[float] = []
        for i in range(1, len(edges)):
            left = float(edges[i - 1])
            right = float(edges[i])
            if i == 1 and left <= 0.0:
                left_pos = t_min
            else:
                left_pos = max(left, t_min)
            if math.isinf(right):
                mid = max(left_pos * 2.0, left_pos + 1e-6)
            else:
                right_pos = max(right, t_min)
                mid = math.sqrt(left_pos * right_pos)
            centers.append(math.log(max(mid, t_min)))
        cfg.lognormal_centers = centers

    return cfg


def get_num_params(model: nn.Module) -> int:
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


class Trainer:
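    """End-to-end training harness.

    Builds the dataset splits and loaders, instantiates the backbone, the
    prediction head and the loss for the configured loss_type, then runs AdamW
    training with a warmup + cosine LR schedule and early stopping on the
    validation NLL.
    """
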
    def __init__(
        self,
        cfg: TrainConfig,
    ):
        self.cfg = cfg
        self.device = cfg.device
        self.global_step = 0

        use_cuda = str(self.device).startswith("cuda") and torch.cuda.is_available()
        if use_cuda:
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            try:
                torch.set_float32_matmul_precision("high")
            except Exception:
                pass

        if cfg.full_cov:
            cov_list = None
        else:
            cov_list = ["bmi", "smoking", "alcohol"]
        dataset = HealthDataset(
            data_prefix=cfg.data_prefix,
            covariate_list=cov_list,
        )
        print("Dataset loaded.")
        n_total = len(dataset)
        print(f"Total samples in dataset: {n_total}")
        print(f"Number of diseases: {dataset.n_disease}")
        print(f"Number of continuous covariates: {dataset.n_cont}")
        print(f"Number of categorical covariates: {dataset.n_cate}")
        self.train_data, self.val_data, _ = random_split(
            dataset,
            [
                int(n_total * cfg.train_ratio),
                int(n_total * cfg.val_ratio),
                n_total - int(n_total * cfg.train_ratio) -
                int(n_total * cfg.val_ratio),
            ],
            generator=torch.Generator().manual_seed(cfg.random_seed),
        )

        pin_memory = use_cuda
        loader_kwargs = dict(
            collate_fn=health_collate_fn,
            pin_memory=pin_memory,
        )
        if cfg.num_workers > 0:
            loader_kwargs["num_workers"] = cfg.num_workers
            loader_kwargs["prefetch_factor"] = cfg.prefetch_factor
            loader_kwargs["persistent_workers"] = cfg.persistent_workers

        self.train_loader = DataLoader(
            self.train_data,
            batch_size=cfg.batch_size,
            shuffle=True,
            **loader_kwargs,
        )
        self.val_loader = DataLoader(
            self.val_data,
            batch_size=cfg.batch_size,
            shuffle=False,
            **loader_kwargs,
        )

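        # Head output dims required by each loss (M = valid event pairs in the
        # batch, K = n_disease, R = number of lognormal basis centers):
        #   exponential                       -> logits (M, K)
        #   discrete_time_cif                 -> logits (M, K+1, n_bins+1)
        #   lognormal_basis_binned_hazard_cif -> logits (M, K, R)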
        if cfg.loss_type == "exponential":
            self.criterion = ExponentialNLLLoss(
                lambda_reg=cfg.lambda_reg,
            ).to(self.device)
            out_dims = [dataset.n_disease]
        elif cfg.loss_type == "discrete_time_cif":
            self.criterion = DiscreteTimeCIFNLLLoss(
                bin_edges=cfg.bin_edges,
                lambda_reg=cfg.lambda_reg,
            ).to(self.device)
            # logits shape (M, K+1, n_bins+1)
            out_dims = [dataset.n_disease + 1, len(cfg.bin_edges)]
        elif cfg.loss_type == "lognormal_basis_binned_hazard_cif":
            r = len(cfg.lognormal_centers)
            if r <= 0:
                raise ValueError(
                    "lognormal_centers must be non-empty for lognormal_basis_binned_hazard_cif")
            self.criterion = LogNormalBasisBinnedHazardCIFNLLLoss(
                bin_edges=cfg.bin_edges,
                centers=cfg.lognormal_centers,
                eps=cfg.loss_eps,
                bandwidth_init=cfg.bandwidth_init,
                bandwidth_min=cfg.bandwidth_min,
                bandwidth_max=cfg.bandwidth_max,
                lambda_sigma_reg=cfg.lambda_sigma_reg,
                sigma_reg_target=cfg.sigma_reg_target,
                lambda_reg=cfg.lambda_reg,
            ).to(self.device)
            # Head emits (M, J, R) for Route-3.
            out_dims = [dataset.n_disease, r]
        else:
            raise ValueError(f"Unsupported loss type: {cfg.loss_type}")

        if cfg.model_type == "delphi_fork":
            self.model = DelphiFork(
                n_disease=dataset.n_disease,
                n_tech_tokens=2,
                n_embd=cfg.n_embd,
                n_head=cfg.n_head,
                n_layer=cfg.n_layer,
                pdrop=cfg.pdrop,
                age_encoder_type=cfg.age_encoder,
                n_cont=dataset.n_cont,
                n_cate=dataset.n_cate,
                cate_dims=dataset.cate_dims,
            ).to(self.device)
        elif cfg.model_type == "sap_delphi":
            self.model = SapDelphi(
                n_disease=dataset.n_disease,
                n_tech_tokens=2,
                n_embd=cfg.n_embd,
                n_head=cfg.n_head,
                n_layer=cfg.n_layer,
                pdrop=cfg.pdrop,
                age_encoder_type=cfg.age_encoder,
                n_cont=dataset.n_cont,
                n_cate=dataset.n_cate,
                cate_dims=dataset.cate_dims,
                pretrained_weights_path=cfg.pretrained_emd_path,
                freeze_embeddings=True,
            ).to(self.device)
        else:
            raise ValueError(f"Unsupported model type: {cfg.model_type}")

        # Prediction head maps context vectors -> logits with the shape required by the loss.
        self.head = SimpleHead(
            n_embd=cfg.n_embd,
            out_dims=out_dims,
        ).to(self.device)

        print(f"Model initialized: {cfg.model_type}")
        print(
            f"Number of trainable parameters (backbone): {get_num_params(self.model)}")
        print(
            f"Number of trainable parameters (head): {get_num_params(self.head)}")

        # Optimize backbone and head parameters, plus any learnable loss
        # parameters (e.g. the lognormal bandwidth sigmas, if the criterion
        # defines them).
        self._optim_params = (
            list(self.model.parameters())
            + list(self.head.parameters())
            + list(self.criterion.parameters())
        )
        self.optimizer = AdamW(
            self._optim_params,
            lr=cfg.max_lr,
            weight_decay=cfg.weight_decay,
            betas=(0.9, 0.99),
        )
        self.total_steps = len(self.train_loader) * cfg.max_epochs
        print(f"Total optimization steps: {self.total_steps}")
        while True:
            cov_suffix = "fullcov" if cfg.full_cov else "partcov"
            name = f"{cfg.model_type}_{cfg.loss_type}_{cfg.age_encoder}_{cov_suffix}"
            timestamp = time.strftime("%Y%m%d-%H%M%S")
            model_dir = os.path.join("runs", f"{name}_{timestamp}")
            if not os.path.exists(model_dir):
                self.out_dir = model_dir
                os.makedirs(model_dir)
                break
            time.sleep(1)
        print(f"Output directory: {self.out_dir}")
        self.best_path = os.path.join(self.out_dir, "best_model.pt")
        self.global_step = 0
        self.save_config()

    def save_config(self):
        cfg_path = os.path.join(self.out_dir, "train_config.json")
        with open(cfg_path, 'w') as f:
            json.dump(asdict(self.cfg), f, indent=4)
        print(f"Configuration saved to {cfg_path}")

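    # LR schedule: linear warmup from 0 to max_lr over the first
    # warmup_epochs * len(train_loader) steps, then cosine decay toward min_lr:
    #   lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * progress))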
    def compute_lr(self, current_step: int) -> float:
        cfg = self.cfg
        warmup_steps = cfg.warmup_epochs * len(self.train_loader)
        if current_step < warmup_steps:
            lr = cfg.max_lr * (current_step / warmup_steps)
        else:
            denom = (cfg.max_epochs - cfg.warmup_epochs) * \
                len(self.train_loader)
            progress = (current_step - warmup_steps) / denom
            lr = cfg.min_lr + 0.5 * \
                (cfg.max_lr - cfg.min_lr) * (1 + math.cos(math.pi * progress))
        return lr

    def train(self) -> None:
        history = []
        best_val_score = float('inf')
        patience_counter = 0
        for epoch in range(1, self.cfg.max_epochs + 1):
            self.model.train()
            self.head.train()
            total_train_pairs = 0
            total_train_nll = 0.0
            total_train_reg = 0.0
            pbar = tqdm(self.train_loader,
                        desc=f"Epoch {epoch}/{self.cfg.max_epochs} - Training", ncols=100)
            batch_count = 0
            for batch in pbar:
                (
                    event_seq,
                    time_seq,
                    cont_feats,
                    cate_feats,
                    sexes,
                ) = batch
                event_seq = event_seq.to(self.device, non_blocking=True)
                time_seq = time_seq.to(self.device, non_blocking=True)
                cont_feats = cont_feats.to(self.device, non_blocking=True)
                cate_feats = cate_feats.to(self.device, non_blocking=True)
                sexes = sexes.to(self.device, non_blocking=True)
                res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
                if res is None:
                    continue
                dt, b_prev, t_prev, b_next, t_next = res
                num_pairs = dt.size(0)
                self.optimizer.zero_grad()
                lr = self.compute_lr(self.global_step)
                self.optimizer.param_groups[0]['lr'] = lr
                h = self.model(
                    event_seq,
                    time_seq,
                    sexes,
                    cont_feats,
                    cate_feats,
                )

                # Context vectors for selected previous events
                c = h[b_prev, t_prev]  # (M, D)
                logits = self.head(c)

                target_event = event_seq[b_next, t_next] - 2
                nll_vec, reg = self.criterion(
                    logits,
                    target_event,
                    dt,
                    reduction="none",
                )
                nll = nll_vec.mean()
                loss = nll + reg
                batch_count += 1
                total_train_pairs += num_pairs
                total_train_nll += nll_vec.sum().item()
                total_train_reg += reg.item() * num_pairs
                avg_train_nll = total_train_nll / total_train_pairs
                avg_train_reg = total_train_reg / total_train_pairs
                pbar.set_postfix({
                    "lr": lr,
                    "NLL": avg_train_nll,
                    "Reg": avg_train_reg,
                })
                loss.backward()
                if self.cfg.grad_clip > 0:
                    clip_grad_norm_(self._optim_params, self.cfg.grad_clip)
                self.optimizer.step()
                self.global_step += 1

            if batch_count == 0:
                print("No valid batches in this epoch, skipping validation.")
                continue

            train_nll = total_train_nll / total_train_pairs if total_train_pairs > 0 else 0.0
            train_reg = total_train_reg / total_train_pairs if total_train_pairs > 0 else 0.0

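            # Validation mirrors the training pass but accumulates NLL/reg as
            # sums over valid event pairs, so the epoch averages stay comparable
            # across batches with different numbers of pairs.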
            self.model.eval()
            self.head.eval()
            total_val_pairs = 0
            total_val_nll = 0.0
            total_val_reg = 0.0
            with torch.no_grad():
                val_pbar = tqdm(self.val_loader, desc="Validation")
                for batch in val_pbar:
                    (
                        event_seq,
                        time_seq,
                        cont_feats,
                        cate_feats,
                        sexes,
                    ) = batch
                    event_seq = event_seq.to(self.device, non_blocking=True)
                    time_seq = time_seq.to(self.device, non_blocking=True)
                    cont_feats = cont_feats.to(self.device, non_blocking=True)
                    cate_feats = cate_feats.to(self.device, non_blocking=True)
                    sexes = sexes.to(self.device, non_blocking=True)
                    res = get_valid_pairs_and_dt(event_seq, time_seq, 2)
                    if res is None:
                        continue
                    dt, b_prev, t_prev, b_next, t_next = res
                    num_pairs = dt.size(0)
                    h = self.model(
                        event_seq,
                        time_seq,
                        sexes,
                        cont_feats,
                        cate_feats,
                    )

                    c = h[b_prev, t_prev]
                    logits = self.head(c)

                    target_events = event_seq[b_next, t_next] - 2
                    nll, reg = self.criterion(
                        logits,
                        target_events,
                        dt,
                        reduction="none",
                    )
                    batch_nll_sum = nll.sum().item()
                    total_val_nll += batch_nll_sum
                    total_val_reg += reg.item() * num_pairs
                    total_val_pairs += num_pairs

                    current_val_avg_nll = total_val_nll / \
                        total_val_pairs if total_val_pairs > 0 else 0.0
                    current_val_avg_reg = total_val_reg / \
                        total_val_pairs if total_val_pairs > 0 else 0.0

                    val_pbar.set_postfix({
                        "NLL": f"{current_val_avg_nll:.4f}",
                        "Reg": f"{current_val_avg_reg:.4f}",
                    })

            val_nll = total_val_nll / total_val_pairs if total_val_pairs > 0 else 0.0
            val_reg = total_val_reg / total_val_pairs if total_val_pairs > 0 else 0.0

            history.append({
                "epoch": epoch,
                "train_nll": train_nll,
                "train_reg": train_reg,
                "val_nll": val_nll,
                "val_reg": val_reg,
            })

tqdm.write(f"\nEpoch {epoch+1}/{self.cfg.max_epochs} Stats:")
|
|
|
|
|
tqdm.write(f" Train NLL: {train_nll:.4f}")
|
2026-01-09 12:49:29 +08:00
|
|
|
tqdm.write(f" Train Reg: {train_reg:.4f}")
|
2026-01-07 23:57:29 +08:00
|
|
|
tqdm.write(f" Val NLL: {val_nll:.4f} ← PRIMARY METRIC")
|
2026-01-09 12:49:29 +08:00
|
|
|
tqdm.write(f" Val Reg: {val_reg:.4f}")
|
2026-01-07 23:57:29 +08:00
|
|
|
|
|
|
|
|
            with open(os.path.join(self.out_dir, "training_history.json"), "w") as f:
                json.dump(history, f, indent=4)

            # Check for improvement
            if val_nll < best_val_score:
                best_val_score = val_nll
                patience_counter = 0
                tqdm.write("  ✓ New best validation score. Saving checkpoint.")

                torch.save({
                    "epoch": epoch,
                    "global_step": self.global_step,
                    "model_state_dict": self.model.state_dict(),
                    "head_state_dict": self.head.state_dict(),
                    "criterion_state_dict": self.criterion.state_dict(),
                    "optimizer_state_dict": self.optimizer.state_dict(),
                }, self.best_path)
            else:
                patience_counter += 1
                if epoch >= self.cfg.warmup_epochs and patience_counter >= self.cfg.patience:
                    tqdm.write(
                        f"\n⚠ No improvement in validation score for {patience_counter} epochs. Early stopping.")
                    return
                tqdm.write(
                    f"  No improvement (patience: {patience_counter}/{self.cfg.patience})")

        tqdm.write("\n🎉 Training complete!")


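# Example invocation (assuming this file is saved as train.py and the
# HealthDataset files matching --data_prefix are available):
#   python train.py --model_type delphi_fork \
#       --loss_type lognormal_basis_binned_hazard_cif \
#       --batch_size 128 --max_epochs 200 --num_workers 4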
if __name__ == "__main__":
    cfg = parse_args()
    trainer = Trainer(cfg)
    trainer.train()