Enhance LandmarkEvaluator with model compilation and optimization options

2026-01-18 17:56:59 +08:00
parent 014393a33f
commit a4b19b6e08


@@ -35,6 +35,19 @@ from losses import (
warnings.filterwarnings('ignore')

def _maybe_torch_compile(module: torch.nn.Module, enabled: bool = True) -> torch.nn.Module:
    """Best-effort torch.compile() wrapper (PyTorch 2.x)."""
    if not enabled:
        return module
    try:
        torch_compile = getattr(torch, "compile", None)
        if torch_compile is None:
            return module
        return torch_compile(module, mode="reduce-overhead")
    except Exception:
        return module
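
A minimal standalone sketch of how this best-effort wrapper is meant to behave (the probe module and input are illustrative only): on PyTorch 2.x the module comes back compiled with mode="reduce-overhead", while on builds without torch.compile the original module is returned unchanged, so call sites never need to branch.

import torch

probe = torch.nn.Linear(8, 4)
fast_probe = _maybe_torch_compile(probe, enabled=True)
# PyTorch 2.x: `fast_probe` is an OptimizedModule wrapping `probe`.
# Older PyTorch (no torch.compile): `fast_probe` is just `probe` again.
out = fast_probe(torch.randn(2, 8))  # shape (2, 4) in either case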

def _ensure_dir(path: str) -> str:
    os.makedirs(path, exist_ok=True)
    return path
@@ -225,11 +238,10 @@ class LandmarkEvaluator:
        device: str = 'cuda',
        batch_size: int = 256,
        num_workers: int = 4,
        compile_model: bool = True,
    ):
        self.model = model.to(device).eval()
        self.head = head.to(device).eval()
        self.loss_fn = loss_fn
        self.dataset = dataset
        self.eval_indices = eval_indices
@@ -237,6 +249,22 @@ class LandmarkEvaluator:
        self.batch_size = batch_size
        self.num_workers = num_workers
        use_cuda = str(self.device).startswith("cuda") and torch.cuda.is_available()
        if use_cuda:
            torch.backends.cudnn.benchmark = True
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            try:
                torch.set_float32_matmul_precision("high")
            except Exception:
                pass

        # JIT/compile optimization (best effort)
        if compile_model and use_cuda:
            self.model = _maybe_torch_compile(self.model, enabled=True)
            self.head = _maybe_torch_compile(self.head, enabled=True)
        # Evaluation parameters from design doc
        self.age_cutoffs = [50, 60, 70]
        self.horizons = [0.25, 0.5, 1, 2, 5, 10]
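
For context, a hedged note on the new speed flags (a standalone sketch; actual speedups depend on GPU generation): cudnn.benchmark autotunes convolution algorithms for fixed input shapes, while the TF32 flags and the "high" matmul precision let float32 matrix multiplies run on tensor cores with a reduced 10-bit mantissa.

import torch

# The same knobs in isolation; setting them is harmless on CPU-only machines.
torch.backends.cudnn.benchmark = True           # autotune conv kernels for fixed shapes
torch.backends.cuda.matmul.allow_tf32 = True    # TF32 matmuls on Ampere+ GPUs
torch.backends.cudnn.allow_tf32 = True          # TF32 convolutions
torch.set_float32_matmul_precision("high")      # matmul-precision API (PyTorch >= 1.12)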
@@ -247,6 +275,150 @@ class LandmarkEvaluator:
        self.age_cutoffs_days = [age * 365.25 for age in self.age_cutoffs]
        self.horizons_days = [h * 365.25 for h in self.horizons]

    @staticmethod
    def _last_time(time_batch: torch.Tensor, event_batch: torch.Tensor) -> torch.Tensor:
        """Compute last observed (non-padding) time per patient."""
        real_mask = event_batch >= 1
        masked = time_batch.masked_fill(~real_mask, float('-inf'))
        return masked.max(dim=1).values

    @staticmethod
    def _anchor_indices(
        time_batch: torch.Tensor,
        event_batch: torch.Tensor,
        cutoff_days: float,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Find anchor index/time: last valid record before cutoff."""
        real_mask = event_batch >= 1
        before = time_batch < cutoff_days
        valid_before = real_mask & before
        has_anchor = valid_before.any(dim=1)
        # argmax of position under mask gives last True position
        L = event_batch.size(1)
        pos = torch.arange(L, device=event_batch.device).view(1, L)
        anchor_idx = (valid_before.to(torch.long) * pos).max(dim=1).values.to(torch.long)
        t_anchor = time_batch.gather(1, anchor_idx.view(-1, 1)).squeeze(1)
        return has_anchor, anchor_idx, t_anchor
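
A tiny illustration of the position-max trick used above (toy tensors, not from the dataset): multiplying positions by the validity mask and taking a row-wise max gives the index of the last True entry, and falls back to 0 for rows with no True, which is why has_anchor is returned alongside anchor_idx.

import torch

valid_before = torch.tensor([[True, True, False, True, False],
                             [False, False, False, False, False]])
pos = torch.arange(valid_before.size(1)).view(1, -1)
last_true = (valid_before.long() * pos).max(dim=1).values
# last_true == tensor([3, 0]): row 0 -> last valid record at index 3,
# row 1 -> no valid record, so the 0 must be masked out downstream.
has_anchor = valid_before.any(dim=1)  # tensor([True, False])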

    def _labels_and_validity_for_cutoff(
        self,
        time_batch: torch.Tensor,
        event_batch: torch.Tensor,
        cutoff_days: float,
        horizons_days: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Vectorized label + validity computation for all horizons at a cutoff.

        Returns:
            labels: (B, H, K) float32 {0,1}
            valid_cc: (B, H, K) bool
            valid_clean: (B, H, K) bool
        """
        n_tech_tokens = 2
        K = int(self.dataset.n_disease)
        death_code = int(K - 1)
        B, L = event_batch.shape
        H = int(horizons_days.numel())

        # Disease token mask and indices
        is_disease = event_batch >= n_tech_tokens
        disease_idx = (event_batch - n_tech_tokens).clamp(min=0, max=K - 1)

        # ever_has_disease: (B, K)
        ever = torch.zeros((B, K), dtype=torch.bool, device=event_batch.device)
        if is_disease.any():
            b_idx, t_idx = is_disease.nonzero(as_tuple=True)
            d_idx = disease_idx[b_idx, t_idx]
            ever[b_idx, d_idx] = True

        # Events within horizon windows: (B, L, H)
        offset = time_batch - float(cutoff_days)
        within = is_disease.unsqueeze(-1) & (offset.unsqueeze(-1) >= 0) & (
            offset.unsqueeze(-1) <= horizons_days.view(1, 1, H)
        )
        labels_bool = torch.zeros((B, H, K), dtype=torch.bool, device=event_batch.device)
        if within.any():
            b2, t2, h2 = within.nonzero(as_tuple=True)
            d2 = disease_idx[b2, t2]
            labels_bool[b2, h2, d2] = True
        labels = labels_bool.to(torch.float32)

        last_time = self._last_time(time_batch, event_batch)  # (B,)
        horizon_end = float(cutoff_days) + horizons_days.view(1, H)  # (1, H)
        death_in_horizon = labels_bool[:, :, death_code]  # (B, H)
        observed_past_horizon = last_time.view(B, 1) > horizon_end
        lost_within_horizon = last_time.view(B, 1) <= horizon_end

        # Track A (Complete-Case):
        # - if observed past horizon OR death in horizon => valid all diseases
        # - else (censored within horizon) => valid only for diseases that occurred within horizon
        valid_cc = labels_bool.clone()
        full_mask = (observed_past_horizon | death_in_horizon).unsqueeze(-1)
        if full_mask.any():
            valid_cc = torch.where(
                full_mask.expand(-1, -1, K), torch.ones_like(valid_cc), valid_cc)

        # Track B (Clean-Control) per disease:
        # valid[k] = hit_in_window(k) OR (never_has_k AND not lost_within_window)
        never = ~ever  # (B, K)
        valid_clean = (~death_in_horizon).unsqueeze(-1) & (
            labels_bool | (never.unsqueeze(1) & (~lost_within_horizon).unsqueeze(-1))
        )
        return labels, valid_cc, valid_clean
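
A hypothetical worked example of the two tracks (numbers invented for illustration): take a 50-year cutoff, a 1-year horizon, and a patient whose last observed record is at age 50.4 with no death token. A diagnosis of disease k at age 50.3 falls inside the window, so labels[k] = 1 and k is valid under both tracks. A disease j the patient never has is valid under neither track here: the patient drops out at 50.4, inside the window, so complete-case only trusts the diseases that actually occurred in-window, and clean-control only accepts a negative when follow-up extends past the window's end (age 51).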

    def _compute_risk_scores_many_horizons(
        self,
        logits: torch.Tensor,
        t_start_days: torch.Tensor,
        horizons_days: torch.Tensor,
    ) -> torch.Tensor:
        """Compute risk increments for all horizons in one vectorized call.

        Args:
            logits: model head outputs for anchor points.
            t_start_days: (B,) time from anchor to cutoff (days).
            horizons_days: (H,) horizons in days.

        Returns:
            risk: (B, H, K) float32
        """
        t_start_days = torch.clamp(t_start_days, min=0)
        t_end_days = torch.clamp(t_start_days.unsqueeze(1) + horizons_days.view(1, -1), min=0)
        t_query_years = torch.cat([t_start_days.unsqueeze(1), t_end_days], dim=1) / 365.25  # (B, H+1)

        # calculate_cifs returns (B, K) if scalar/per-sample, else (B, K, T)
        if hasattr(self.loss_fn, "calculate_cifs"):
            cifs = self.loss_fn.calculate_cifs(logits, t_query_years, return_survival=False)
        else:
            raise ValueError(
                f"Loss function does not support calculate_cifs: {type(self.loss_fn)}")

        if cifs.ndim == 2:
            # (B, K) -> (B, 1, K)
            cifs_bt_k = cifs.unsqueeze(1)
        else:
            # (B, K, T) -> (B, T, K)
            cifs_bt_k = cifs.permute(0, 2, 1)

        cif_start = cifs_bt_k[:, :1, :]  # (B, 1, K)
        cif_end = cifs_bt_k[:, 1:, :]  # (B, H, K)
        risk = torch.clamp(cif_end - cif_start, min=0)
        return risk
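
In other words, per disease k and horizon h the score is the CIF increment over the landmark window, conditional on the history encoded at the anchor (a restatement of the code above, with days converted to years inside the CIF call):

    risk_k(h) = max(0, CIF_k((t_start + h) / 365.25) - CIF_k(t_start / 365.25)),  t_start = cutoff - t_anchor (days)

So, with illustrative numbers, a patient whose CIF for k rises from 0.10 at the cutoff to 0.18 one horizon later gets risk 0.08; the clamp guards against tiny negative increments from numerical noise.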

    @torch.no_grad()
    def compute_risk_scores(
        self,
@@ -926,22 +1098,169 @@ class LandmarkEvaluator:
        return results

    def run_full_evaluation(self) -> Dict:
        """Run the full evaluation using a single-pass DataLoader.

        Key optimizations:
        - iterate DataLoader exactly once
        - run transformer backbone once per batch
        - reuse hidden states per cutoff (3x head only)
        - vectorize CIF/risk over all horizons in one call
        """
        # Build evaluation subset loader
        indices = self.eval_indices if self.eval_indices is not None else list(range(len(self.dataset)))
        subset = Subset(self.dataset, indices)
        loader = DataLoader(
            subset,
            batch_size=self.batch_size,
            shuffle=False,
            collate_fn=health_collate_fn,
            num_workers=self.num_workers,
            pin_memory=True if str(self.device).startswith('cuda') else False,
        )

        cutoffs_days = torch.tensor(
            self.age_cutoffs_days, dtype=torch.float32, device=self.device)  # (C,)
        horizons_days = torch.tensor(
            self.horizons_days, dtype=torch.float32, device=self.device)  # (H,)
        C = int(cutoffs_days.numel())
        H = int(horizons_days.numel())
        K = int(self.dataset.n_disease)

        # Buffers: store per landmark/track arrays in chunks to avoid repeated I/O.
        # Each key stores lists of numpy arrays to be concatenated at the end.
        buffers: Dict[Tuple[int, int, str], Dict[str, List[np.ndarray]]] = {}
        for ci in range(C):
            for hi in range(H):
                for track in ("complete_case", "clean_control"):
                    buffers[(ci, hi, track)] = {"risk": [], "labels": [], "valid": []}

        with torch.inference_mode():
            for batch in tqdm(loader, desc="Single-pass evaluation", ncols=100):
                event_batch, time_batch, cont_batch, cate_batch, sex_batch = batch
                event_batch = event_batch.to(self.device, non_blocking=True)
                time_batch = time_batch.to(self.device, non_blocking=True)
                cont_batch = cont_batch.to(self.device, non_blocking=True)
                cate_batch = cate_batch.to(self.device, non_blocking=True)
                sex_batch = sex_batch.to(self.device, non_blocking=True)
                B, L = event_batch.shape
                batch_idx = torch.arange(B, device=self.device)

                # Backbone once per batch
                hidden = self.model(
                    event_batch, time_batch, sex_batch, cont_batch, cate_batch)  # (B, L, D)

                for ci in range(C):
                    cutoff = float(cutoffs_days[ci].item())
                    has_anchor, anchor_idx, t_anchor = self._anchor_indices(
                        time_batch, event_batch, cutoff)
                    if not has_anchor.any():
                        continue

                    # Hidden states at anchor positions
                    hidden_anchor = hidden[batch_idx, anchor_idx]  # (B, D)
                    logits = self.head(hidden_anchor)

                    # Vectorized labels/validity for all horizons
                    labels_bhk, valid_cc_bhk, valid_clean_bhk = self._labels_and_validity_for_cutoff(
                        time_batch, event_batch, cutoff, horizons_days
                    )

                    # Risk scores for all horizons (B, H, K)
                    t_start = torch.clamp(torch.tensor(cutoff, device=self.device) - t_anchor, min=0)
                    risk_bhk = self._compute_risk_scores_many_horizons(logits, t_start, horizons_days)

                    # Apply anchor constraint to validity
                    anchor_mask = has_anchor.view(B, 1, 1)
                    valid_cc_bhk = valid_cc_bhk & anchor_mask
                    valid_clean_bhk = valid_clean_bhk & anchor_mask

                    # Push per-horizon chunks
                    for hi in range(H):
                        for track, valid_bk in (
                            ("complete_case", valid_cc_bhk[:, hi, :]),
                            ("clean_control", valid_clean_bhk[:, hi, :]),
                        ):
                            row_mask = valid_bk.any(dim=1)
                            if not row_mask.any():
                                continue
                            r = risk_bhk[row_mask, hi, :].to(torch.float32).cpu().numpy()
                            y = labels_bhk[row_mask, hi, :].to(torch.float32).cpu().numpy()
                            m = valid_bk[row_mask, :].to(torch.bool).cpu().numpy()
                            buffers[(ci, hi, track)]["risk"].append(r)
                            buffers[(ci, hi, track)]["labels"].append(y)
                            buffers[(ci, hi, track)]["valid"].append(m)

        # Assemble results in the original output schema
        all_results: Dict = {
            'age_cutoffs': self.age_cutoffs,
            'horizons': self.horizons,
            'landmarks': [],
        }
        for ci, age in enumerate(self.age_cutoffs):
            for hi, horizon in enumerate(self.horizons):
                landmark_results = {
                    'age_cutoff': age,
                    'horizon': horizon,
                    'complete_case': {},
                    'clean_control': {},
                }
                for track in ("complete_case", "clean_control"):
                    chunks = buffers[(ci, hi, track)]
                    if len(chunks["risk"]) == 0:
                        continue
                    risk_scores = np.concatenate(chunks["risk"], axis=0)
                    labels = np.concatenate(chunks["labels"], axis=0)
                    valid_mask = np.concatenate(chunks["valid"], axis=0)

                    auc_scores = self.compute_auc_per_disease(risk_scores, labels, valid_mask)
                    mean_auc = np.nanmean(list(auc_scores.values()))
                    track_out = {
                        'n_patients': int(valid_mask.shape[0]),
                        'n_valid': int(valid_mask.sum()),
                        'n_valid_patients': int((valid_mask.any(axis=1)).sum()),
                        'auc_per_disease': auc_scores,
                        'mean_auc': mean_auc,
                    }
                    if track == "complete_case":
                        brier_metrics = self.compute_brier_score(risk_scores, labels, valid_mask)
                        capture_metrics = self.compute_disease_capture_at_k(risk_scores, labels, valid_mask)
                        lift_yield_metrics = self.compute_lift_and_yield(risk_scores, labels, valid_mask)
                        dca_metrics = self.compute_dca_net_benefit(risk_scores, labels, valid_mask)
                        track_out.update({
                            'brier_score': brier_metrics['brier_score'],
                            'brier_skill_score': brier_metrics['brier_skill_score'],
                            'disease_capture_at_k': capture_metrics,
                            'lift_and_yield': lift_yield_metrics,
                            'dca': dca_metrics,
                        })
                    landmark_results[track] = track_out
                all_results['landmarks'].append(landmark_results)

        return all_results
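
A short sketch of consuming the returned dictionary (the evaluator variable name is illustrative; the keys match the schema assembled above, and a track dict stays empty for landmarks with no valid rows):

results = evaluator.run_full_evaluation()
for lm in results['landmarks']:
    cc = lm['complete_case']
    if not cc:
        continue
    print(f"age {lm['age_cutoff']} | horizon {lm['horizon']}y | "
          f"mean AUC {cc['mean_auc']:.3f} | valid patients {cc['n_valid_patients']}")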
@@ -1202,6 +1521,11 @@ def main():
        default=4,
        help='Number of data loader workers'
    )
    parser.add_argument(
        '--no_compile',
        action='store_true',
        help='Disable torch.compile optimization (useful if your PyTorch build does not support it well)'
    )
    args = parser.parse_args()
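
Hypothetical invocation (the script name is assumed; --device, --batch_size and --num_workers correspond to the args fields used below, and --no_compile is the flag added in this commit): python landmark_evaluation.py --device cuda --batch_size 256 --num_workers 4 --no_compile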
@@ -1219,6 +1543,7 @@ def main():
        device=args.device,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
        compile_model=(not args.no_compile),
    )
    # Run evaluation