Enhance LandmarkEvaluator with model compilation and optimization options
This commit is contained in:
351
evaluate.py
351
evaluate.py
@@ -35,6 +35,19 @@ from losses import (
|
||||
warnings.filterwarnings('ignore')
|
||||
|
||||
|
||||
def _maybe_torch_compile(module: torch.nn.Module, enabled: bool = True) -> torch.nn.Module:
|
||||
"""Best-effort torch.compile() wrapper (PyTorch 2.x)."""
|
||||
if not enabled:
|
||||
return module
|
||||
try:
|
||||
torch_compile = getattr(torch, "compile", None)
|
||||
if torch_compile is None:
|
||||
return module
|
||||
return torch_compile(module, mode="reduce-overhead")
|
||||
except Exception:
|
||||
return module
|
||||
|
||||
|
||||
def _ensure_dir(path: str) -> str:
|
||||
os.makedirs(path, exist_ok=True)
|
||||
return path
|
||||
@@ -225,11 +238,10 @@ class LandmarkEvaluator:
|
||||
device: str = 'cuda',
|
||||
batch_size: int = 256,
|
||||
num_workers: int = 4,
|
||||
compile_model: bool = True,
|
||||
):
|
||||
self.model = model.to(device)
|
||||
self.model.eval()
|
||||
self.head = head.to(device)
|
||||
self.head.eval()
|
||||
self.model = model.to(device).eval()
|
||||
self.head = head.to(device).eval()
|
||||
self.loss_fn = loss_fn
|
||||
self.dataset = dataset
|
||||
self.eval_indices = eval_indices
|
||||
@@ -237,6 +249,22 @@ class LandmarkEvaluator:
|
||||
self.batch_size = batch_size
|
||||
self.num_workers = num_workers
|
||||
|
||||
use_cuda = str(self.device).startswith(
|
||||
"cuda") and torch.cuda.is_available()
|
||||
if use_cuda:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
torch.backends.cudnn.allow_tf32 = True
|
||||
try:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# JIT/compile optimization (best effort)
|
||||
if compile_model and use_cuda:
|
||||
self.model = _maybe_torch_compile(self.model, enabled=True)
|
||||
self.head = _maybe_torch_compile(self.head, enabled=True)
|
||||
|
||||
# Evaluation parameters from design doc
|
||||
self.age_cutoffs = [50, 60, 70]
|
||||
self.horizons = [0.25, 0.5, 1, 2, 5, 10]
|
||||
@@ -247,6 +275,150 @@ class LandmarkEvaluator:
|
||||
self.age_cutoffs_days = [age * 365.25 for age in self.age_cutoffs]
|
||||
self.horizons_days = [h * 365.25 for h in self.horizons]
|
||||
|
||||
@staticmethod
|
||||
def _last_time(time_batch: torch.Tensor, event_batch: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute last observed (non-padding) time per patient."""
|
||||
real_mask = event_batch >= 1
|
||||
masked = time_batch.masked_fill(~real_mask, float('-inf'))
|
||||
return masked.max(dim=1).values
|
||||
|
||||
@staticmethod
|
||||
def _anchor_indices(
|
||||
time_batch: torch.Tensor,
|
||||
event_batch: torch.Tensor,
|
||||
cutoff_days: float,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Find anchor index/time: last valid record before cutoff."""
|
||||
real_mask = event_batch >= 1
|
||||
before = time_batch < cutoff_days
|
||||
valid_before = real_mask & before
|
||||
has_anchor = valid_before.any(dim=1)
|
||||
|
||||
# argmax of position under mask gives last True position
|
||||
L = event_batch.size(1)
|
||||
pos = torch.arange(L, device=event_batch.device).view(1, L)
|
||||
anchor_idx = (valid_before.to(torch.long) *
|
||||
pos).max(dim=1).values.to(torch.long)
|
||||
t_anchor = time_batch.gather(1, anchor_idx.view(-1, 1)).squeeze(1)
|
||||
return has_anchor, anchor_idx, t_anchor
|
||||
|
||||
def _labels_and_validity_for_cutoff(
|
||||
self,
|
||||
time_batch: torch.Tensor,
|
||||
event_batch: torch.Tensor,
|
||||
cutoff_days: float,
|
||||
horizons_days: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Vectorized label + validity computation for all horizons at a cutoff.
|
||||
|
||||
Returns:
|
||||
labels: (B, H, K) float32 {0,1}
|
||||
valid_cc: (B, H, K) bool
|
||||
valid_clean: (B, H, K) bool
|
||||
"""
|
||||
|
||||
n_tech_tokens = 2
|
||||
K = int(self.dataset.n_disease)
|
||||
death_code = int(K - 1)
|
||||
|
||||
B, L = event_batch.shape
|
||||
H = int(horizons_days.numel())
|
||||
|
||||
# Disease token mask and indices
|
||||
is_disease = event_batch >= n_tech_tokens
|
||||
disease_idx = (event_batch - n_tech_tokens).clamp(min=0, max=K - 1)
|
||||
|
||||
# ever_has_disease: (B, K)
|
||||
ever = torch.zeros((B, K), dtype=torch.bool, device=event_batch.device)
|
||||
if is_disease.any():
|
||||
b_idx, t_idx = is_disease.nonzero(as_tuple=True)
|
||||
d_idx = disease_idx[b_idx, t_idx]
|
||||
ever[b_idx, d_idx] = True
|
||||
|
||||
# Events within horizon windows: (B, L, H)
|
||||
offset = time_batch - float(cutoff_days)
|
||||
within = is_disease.unsqueeze(-1) & (offset.unsqueeze(-1) >= 0) & (
|
||||
offset.unsqueeze(-1) <= horizons_days.view(1, 1, H)
|
||||
)
|
||||
|
||||
labels_bool = torch.zeros(
|
||||
(B, H, K), dtype=torch.bool, device=event_batch.device)
|
||||
if within.any():
|
||||
b2, t2, h2 = within.nonzero(as_tuple=True)
|
||||
d2 = disease_idx[b2, t2]
|
||||
labels_bool[b2, h2, d2] = True
|
||||
|
||||
labels = labels_bool.to(torch.float32)
|
||||
|
||||
last_time = self._last_time(time_batch, event_batch) # (B,)
|
||||
horizon_end = float(cutoff_days) + horizons_days.view(1, H) # (1, H)
|
||||
|
||||
death_in_horizon = labels_bool[:, :, death_code] # (B, H)
|
||||
observed_past_horizon = last_time.view(B, 1) > horizon_end
|
||||
lost_within_horizon = last_time.view(B, 1) <= horizon_end
|
||||
|
||||
# Track A (Complete-Case):
|
||||
# - if observed past horizon OR death in horizon => valid all diseases
|
||||
# - else (censored within horizon) => valid only for diseases that occurred within horizon
|
||||
valid_cc = labels_bool.clone()
|
||||
full_mask = (observed_past_horizon | death_in_horizon).unsqueeze(-1)
|
||||
if full_mask.any():
|
||||
valid_cc = torch.where(
|
||||
full_mask.expand(-1, -1, K), torch.ones_like(valid_cc), valid_cc)
|
||||
|
||||
# Track B (Clean-Control) per disease:
|
||||
# valid[k] = hit_in_window(k) OR (never_has_k AND not lost_within_window)
|
||||
never = ~ever # (B, K)
|
||||
valid_clean = (~death_in_horizon).unsqueeze(-1) & (
|
||||
labels_bool | (never.unsqueeze(1) & (
|
||||
~lost_within_horizon).unsqueeze(-1))
|
||||
)
|
||||
|
||||
return labels, valid_cc, valid_clean
|
||||
|
||||
def _compute_risk_scores_many_horizons(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
t_start_days: torch.Tensor,
|
||||
horizons_days: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""Compute risk increments for all horizons in one vectorized call.
|
||||
|
||||
Args:
|
||||
logits: model head outputs for anchor points.
|
||||
t_start_days: (B,) time from anchor to cutoff (days).
|
||||
horizons_days: (H,) horizons in days.
|
||||
|
||||
Returns:
|
||||
risk: (B, H, K) float32
|
||||
"""
|
||||
t_start_days = torch.clamp(t_start_days, min=0)
|
||||
t_end_days = torch.clamp(t_start_days.unsqueeze(
|
||||
1) + horizons_days.view(1, -1), min=0)
|
||||
|
||||
t_query_years = torch.cat([t_start_days.unsqueeze(
|
||||
1), t_end_days], dim=1) / 365.25 # (B, H+1)
|
||||
|
||||
# calculate_cifs returns (B, K) if scalar/per-sample, else (B, K, T)
|
||||
if hasattr(self.loss_fn, "calculate_cifs"):
|
||||
cifs = self.loss_fn.calculate_cifs(
|
||||
logits, t_query_years, return_survival=False)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Loss function does not support calculate_cifs: {type(self.loss_fn)}")
|
||||
|
||||
if cifs.ndim == 2:
|
||||
# (B, K) -> (B, 1, K)
|
||||
cifs_bt_k = cifs.unsqueeze(1)
|
||||
else:
|
||||
# (B, K, T) -> (B, T, K)
|
||||
cifs_bt_k = cifs.permute(0, 2, 1)
|
||||
|
||||
cif_start = cifs_bt_k[:, :1, :] # (B, 1, K)
|
||||
cif_end = cifs_bt_k[:, 1:, :] # (B, H, K)
|
||||
risk = torch.clamp(cif_end - cif_start, min=0)
|
||||
return risk
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_risk_scores(
|
||||
self,
|
||||
@@ -926,22 +1098,169 @@ class LandmarkEvaluator:
|
||||
return results
|
||||
|
||||
def run_full_evaluation(self) -> Dict:
|
||||
"""
|
||||
Run complete landmark analysis across all cutoffs and horizons.
|
||||
"""Run the full evaluation using a single-pass DataLoader.
|
||||
|
||||
Returns:
|
||||
all_results: Nested dictionary with all evaluation results
|
||||
Key optimizations:
|
||||
- iterate DataLoader exactly once
|
||||
- run transformer backbone once per batch
|
||||
- reuse hidden states per cutoff (3x head only)
|
||||
- vectorize CIF/risk over all horizons in one call
|
||||
"""
|
||||
all_results = {
|
||||
|
||||
# Build evaluation subset loader
|
||||
indices = self.eval_indices if self.eval_indices is not None else list(
|
||||
range(len(self.dataset)))
|
||||
subset = Subset(self.dataset, indices)
|
||||
loader = DataLoader(
|
||||
subset,
|
||||
batch_size=self.batch_size,
|
||||
shuffle=False,
|
||||
collate_fn=health_collate_fn,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=True if str(self.device).startswith('cuda') else False,
|
||||
)
|
||||
|
||||
cutoffs_days = torch.tensor(
|
||||
# (C,)
|
||||
self.age_cutoffs_days, dtype=torch.float32, device=self.device)
|
||||
horizons_days = torch.tensor(
|
||||
# (H,)
|
||||
self.horizons_days, dtype=torch.float32, device=self.device)
|
||||
C = int(cutoffs_days.numel())
|
||||
H = int(horizons_days.numel())
|
||||
K = int(self.dataset.n_disease)
|
||||
|
||||
# Buffers: store per landmark/track arrays in chunks to avoid repeated I/O.
|
||||
# Each key stores lists of numpy arrays to be concatenated at the end.
|
||||
buffers: Dict[Tuple[int, int, str], Dict[str, List[np.ndarray]]] = {}
|
||||
for ci in range(C):
|
||||
for hi in range(H):
|
||||
for track in ("complete_case", "clean_control"):
|
||||
buffers[(ci, hi, track)] = {
|
||||
"risk": [], "labels": [], "valid": []}
|
||||
|
||||
with torch.inference_mode():
|
||||
for batch in tqdm(loader, desc="Single-pass evaluation", ncols=100):
|
||||
event_batch, time_batch, cont_batch, cate_batch, sex_batch = batch
|
||||
event_batch = event_batch.to(self.device, non_blocking=True)
|
||||
time_batch = time_batch.to(self.device, non_blocking=True)
|
||||
cont_batch = cont_batch.to(self.device, non_blocking=True)
|
||||
cate_batch = cate_batch.to(self.device, non_blocking=True)
|
||||
sex_batch = sex_batch.to(self.device, non_blocking=True)
|
||||
|
||||
B, L = event_batch.shape
|
||||
batch_idx = torch.arange(B, device=self.device)
|
||||
|
||||
# Backbone once per batch
|
||||
hidden = self.model(
|
||||
# (B, L, D)
|
||||
event_batch, time_batch, sex_batch, cont_batch, cate_batch)
|
||||
|
||||
for ci in range(C):
|
||||
cutoff = float(cutoffs_days[ci].item())
|
||||
|
||||
has_anchor, anchor_idx, t_anchor = self._anchor_indices(
|
||||
time_batch, event_batch, cutoff)
|
||||
if not has_anchor.any():
|
||||
continue
|
||||
|
||||
# Hidden states at anchor positions
|
||||
hidden_anchor = hidden[batch_idx, anchor_idx] # (B, D)
|
||||
logits = self.head(hidden_anchor)
|
||||
|
||||
# Vectorized labels/validity for all horizons
|
||||
labels_bhk, valid_cc_bhk, valid_clean_bhk = self._labels_and_validity_for_cutoff(
|
||||
time_batch, event_batch, cutoff, horizons_days
|
||||
)
|
||||
|
||||
# Risk scores for all horizons (B, H, K)
|
||||
t_start = torch.clamp(torch.tensor(
|
||||
cutoff, device=self.device) - t_anchor, min=0)
|
||||
risk_bhk = self._compute_risk_scores_many_horizons(
|
||||
logits, t_start, horizons_days)
|
||||
|
||||
# Apply anchor constraint to validity
|
||||
anchor_mask = has_anchor.view(B, 1, 1)
|
||||
valid_cc_bhk = valid_cc_bhk & anchor_mask
|
||||
valid_clean_bhk = valid_clean_bhk & anchor_mask
|
||||
|
||||
# Push per-horizon chunks
|
||||
for hi in range(H):
|
||||
for track, valid_bk in (
|
||||
("complete_case", valid_cc_bhk[:, hi, :]),
|
||||
("clean_control", valid_clean_bhk[:, hi, :]),
|
||||
):
|
||||
row_mask = valid_bk.any(dim=1)
|
||||
if not row_mask.any():
|
||||
continue
|
||||
|
||||
r = risk_bhk[row_mask, hi, :].to(
|
||||
torch.float32).cpu().numpy()
|
||||
y = labels_bhk[row_mask, hi, :].to(
|
||||
torch.float32).cpu().numpy()
|
||||
m = valid_bk[row_mask, :].to(
|
||||
torch.bool).cpu().numpy()
|
||||
|
||||
buffers[(ci, hi, track)]["risk"].append(r)
|
||||
buffers[(ci, hi, track)]["labels"].append(y)
|
||||
buffers[(ci, hi, track)]["valid"].append(m)
|
||||
|
||||
# Assemble results in the original output schema
|
||||
all_results: Dict = {
|
||||
'age_cutoffs': self.age_cutoffs,
|
||||
'horizons': self.horizons,
|
||||
'landmarks': [],
|
||||
}
|
||||
|
||||
# Evaluate each landmark
|
||||
for age_cutoff in self.age_cutoffs:
|
||||
for horizon in self.horizons:
|
||||
landmark_results = self.evaluate_landmark(age_cutoff, horizon)
|
||||
for ci, age in enumerate(self.age_cutoffs):
|
||||
for hi, horizon in enumerate(self.horizons):
|
||||
landmark_results = {
|
||||
'age_cutoff': age,
|
||||
'horizon': horizon,
|
||||
'complete_case': {},
|
||||
'clean_control': {},
|
||||
}
|
||||
|
||||
for track in ("complete_case", "clean_control"):
|
||||
chunks = buffers[(ci, hi, track)]
|
||||
if len(chunks["risk"]) == 0:
|
||||
continue
|
||||
|
||||
risk_scores = np.concatenate(chunks["risk"], axis=0)
|
||||
labels = np.concatenate(chunks["labels"], axis=0)
|
||||
valid_mask = np.concatenate(chunks["valid"], axis=0)
|
||||
|
||||
auc_scores = self.compute_auc_per_disease(
|
||||
risk_scores, labels, valid_mask)
|
||||
mean_auc = np.nanmean(list(auc_scores.values()))
|
||||
|
||||
track_out = {
|
||||
'n_patients': int(valid_mask.shape[0]),
|
||||
'n_valid': int(valid_mask.sum()),
|
||||
'n_valid_patients': int((valid_mask.any(axis=1)).sum()),
|
||||
'auc_per_disease': auc_scores,
|
||||
'mean_auc': mean_auc,
|
||||
}
|
||||
|
||||
if track == "complete_case":
|
||||
brier_metrics = self.compute_brier_score(
|
||||
risk_scores, labels, valid_mask)
|
||||
capture_metrics = self.compute_disease_capture_at_k(
|
||||
risk_scores, labels, valid_mask)
|
||||
lift_yield_metrics = self.compute_lift_and_yield(
|
||||
risk_scores, labels, valid_mask)
|
||||
dca_metrics = self.compute_dca_net_benefit(
|
||||
risk_scores, labels, valid_mask)
|
||||
track_out.update({
|
||||
'brier_score': brier_metrics['brier_score'],
|
||||
'brier_skill_score': brier_metrics['brier_skill_score'],
|
||||
'disease_capture_at_k': capture_metrics,
|
||||
'lift_and_yield': lift_yield_metrics,
|
||||
'dca': dca_metrics,
|
||||
})
|
||||
|
||||
landmark_results[track] = track_out
|
||||
|
||||
all_results['landmarks'].append(landmark_results)
|
||||
|
||||
return all_results
|
||||
@@ -1202,6 +1521,11 @@ def main():
|
||||
default=4,
|
||||
help='Number of data loader workers'
|
||||
)
|
||||
parser.add_argument(
|
||||
'--no_compile',
|
||||
action='store_true',
|
||||
help='Disable torch.compile optimization (useful if your PyTorch build does not support it well)'
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
@@ -1219,6 +1543,7 @@ def main():
|
||||
device=args.device,
|
||||
batch_size=args.batch_size,
|
||||
num_workers=args.num_workers,
|
||||
compile_model=(not args.no_compile),
|
||||
)
|
||||
|
||||
# Run evaluation
|
||||
|
||||
Reference in New Issue
Block a user