Enhance LandmarkEvaluator with model compilation and optimization options
evaluate.py (351 changed lines)
@@ -35,6 +35,19 @@ from losses import (
 warnings.filterwarnings('ignore')
 
 
+def _maybe_torch_compile(module: torch.nn.Module, enabled: bool = True) -> torch.nn.Module:
+    """Best-effort torch.compile() wrapper (PyTorch 2.x)."""
+    if not enabled:
+        return module
+    try:
+        torch_compile = getattr(torch, "compile", None)
+        if torch_compile is None:
+            return module
+        return torch_compile(module, mode="reduce-overhead")
+    except Exception:
+        return module
+
+
 def _ensure_dir(path: str) -> str:
     os.makedirs(path, exist_ok=True)
     return path
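Note: the helper above is deliberately defensive. A minimal usage sketch of the fallback behavior (the toy head below is hypothetical, not part of evaluate.py):

import torch
import torch.nn as nn

# Hypothetical stand-in for the evaluator's prediction head.
head = nn.Sequential(nn.Linear(256, 512), nn.GELU(), nn.Linear(512, 64))

# On PyTorch 2.x this returns an optimized module; on older builds
# (no torch.compile) or if compilation raises, the original module is
# returned unchanged, so callers never need a version check.
head = _maybe_torch_compile(head, enabled=True)

with torch.no_grad():
    out = head(torch.randn(8, 256))  # same shape either way: (8, 64)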
@@ -225,11 +238,10 @@ class LandmarkEvaluator:
         device: str = 'cuda',
         batch_size: int = 256,
         num_workers: int = 4,
+        compile_model: bool = True,
     ):
-        self.model = model.to(device)
-        self.model.eval()
-        self.head = head.to(device)
-        self.head.eval()
+        self.model = model.to(device).eval()
+        self.head = head.to(device).eval()
         self.loss_fn = loss_fn
         self.dataset = dataset
         self.eval_indices = eval_indices
@@ -237,6 +249,22 @@ class LandmarkEvaluator:
         self.batch_size = batch_size
         self.num_workers = num_workers
 
+        use_cuda = str(self.device).startswith(
+            "cuda") and torch.cuda.is_available()
+        if use_cuda:
+            torch.backends.cudnn.benchmark = True
+            torch.backends.cuda.matmul.allow_tf32 = True
+            torch.backends.cudnn.allow_tf32 = True
+            try:
+                torch.set_float32_matmul_precision("high")
+            except Exception:
+                pass
+
+        # JIT/compile optimization (best effort)
+        if compile_model and use_cuda:
+            self.model = _maybe_torch_compile(self.model, enabled=True)
+            self.head = _maybe_torch_compile(self.head, enabled=True)
+
         # Evaluation parameters from design doc
         self.age_cutoffs = [50, 60, 70]
         self.horizons = [0.25, 0.5, 1, 2, 5, 10]
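Note: the TF32 flags only change behavior on Ampere-or-newer GPUs; elsewhere they are no-ops. A standalone sketch of the precision trade they make (illustrative, not from the diff):

import torch

if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")

    torch.backends.cuda.matmul.allow_tf32 = False
    exact = a @ b   # full FP32 matmul

    torch.backends.cuda.matmul.allow_tf32 = True
    fast = a @ b    # TF32 inputs (~10-bit mantissa), faster on Ampere+

    # The absolute error is small relative to the values themselves,
    # which is why TF32 is generally safe for inference-time evaluation.
    print((exact - fast).abs().max().item())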
@@ -247,6 +275,150 @@ class LandmarkEvaluator:
         self.age_cutoffs_days = [age * 365.25 for age in self.age_cutoffs]
         self.horizons_days = [h * 365.25 for h in self.horizons]
+
+    @staticmethod
+    def _last_time(time_batch: torch.Tensor, event_batch: torch.Tensor) -> torch.Tensor:
+        """Compute last observed (non-padding) time per patient."""
+        real_mask = event_batch >= 1
+        masked = time_batch.masked_fill(~real_mask, float('-inf'))
+        return masked.max(dim=1).values
+
+    @staticmethod
+    def _anchor_indices(
+        time_batch: torch.Tensor,
+        event_batch: torch.Tensor,
+        cutoff_days: float,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Find anchor index/time: last valid record before cutoff."""
+        real_mask = event_batch >= 1
+        before = time_batch < cutoff_days
+        valid_before = real_mask & before
+        has_anchor = valid_before.any(dim=1)
+
+        # argmax of position under mask gives last True position
+        L = event_batch.size(1)
+        pos = torch.arange(L, device=event_batch.device).view(1, L)
+        anchor_idx = (valid_before.to(torch.long) *
+                      pos).max(dim=1).values.to(torch.long)
+        t_anchor = time_batch.gather(1, anchor_idx.view(-1, 1)).squeeze(1)
+        return has_anchor, anchor_idx, t_anchor
+
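Note: the (mask * position).max() idiom in _anchor_indices returns 0 for all-False rows, which is why has_anchor must be checked before trusting anchor_idx. A toy demonstration:

import torch

valid_before = torch.tensor([
    [True, True, False, True, False],    # last True at index 3
    [False, False, False, False, False], # no anchor at all
])
L = valid_before.size(1)
pos = torch.arange(L).view(1, L)

anchor_idx = (valid_before.to(torch.long) * pos).max(dim=1).values
has_anchor = valid_before.any(dim=1)

print(anchor_idx)  # tensor([3, 0]) -- the 0 in row 1 is meaningless
print(has_anchor)  # tensor([ True, False])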
+    def _labels_and_validity_for_cutoff(
+        self,
+        time_batch: torch.Tensor,
+        event_batch: torch.Tensor,
+        cutoff_days: float,
+        horizons_days: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        """Vectorized label + validity computation for all horizons at a cutoff.
+
+        Returns:
+            labels: (B, H, K) float32 {0,1}
+            valid_cc: (B, H, K) bool
+            valid_clean: (B, H, K) bool
+        """
+        n_tech_tokens = 2
+        K = int(self.dataset.n_disease)
+        death_code = int(K - 1)
+
+        B, L = event_batch.shape
+        H = int(horizons_days.numel())
+
+        # Disease token mask and indices
+        is_disease = event_batch >= n_tech_tokens
+        disease_idx = (event_batch - n_tech_tokens).clamp(min=0, max=K - 1)
+
+        # ever_has_disease: (B, K)
+        ever = torch.zeros((B, K), dtype=torch.bool, device=event_batch.device)
+        if is_disease.any():
+            b_idx, t_idx = is_disease.nonzero(as_tuple=True)
+            d_idx = disease_idx[b_idx, t_idx]
+            ever[b_idx, d_idx] = True
+
+        # Events within horizon windows: (B, L, H)
+        offset = time_batch - float(cutoff_days)
+        within = is_disease.unsqueeze(-1) & (offset.unsqueeze(-1) >= 0) & (
+            offset.unsqueeze(-1) <= horizons_days.view(1, 1, H)
+        )
+
+        labels_bool = torch.zeros(
+            (B, H, K), dtype=torch.bool, device=event_batch.device)
+        if within.any():
+            b2, t2, h2 = within.nonzero(as_tuple=True)
+            d2 = disease_idx[b2, t2]
+            labels_bool[b2, h2, d2] = True
+
+        labels = labels_bool.to(torch.float32)
+
+        last_time = self._last_time(time_batch, event_batch)  # (B,)
+        horizon_end = float(cutoff_days) + horizons_days.view(1, H)  # (1, H)
+
+        death_in_horizon = labels_bool[:, :, death_code]  # (B, H)
+        observed_past_horizon = last_time.view(B, 1) > horizon_end
+        lost_within_horizon = last_time.view(B, 1) <= horizon_end
+
+        # Track A (Complete-Case):
+        # - if observed past horizon OR death in horizon => valid all diseases
+        # - else (censored within horizon) => valid only for diseases that occurred within horizon
+        valid_cc = labels_bool.clone()
+        full_mask = (observed_past_horizon | death_in_horizon).unsqueeze(-1)
+        if full_mask.any():
+            valid_cc = torch.where(
+                full_mask.expand(-1, -1, K), torch.ones_like(valid_cc), valid_cc)
+
+        # Track B (Clean-Control) per disease:
+        # valid[k] = hit_in_window(k) OR (never_has_k AND not lost_within_window)
+        never = ~ever  # (B, K)
+        valid_clean = (~death_in_horizon).unsqueeze(-1) & (
+            labels_bool | (never.unsqueeze(1) & (
+                ~lost_within_horizon).unsqueeze(-1))
+        )
+
+        return labels, valid_cc, valid_clean
+
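Note: to make the two tracks concrete, here is an illustrative single patient (one horizon, K=3 diseases with disease 2 = death) whose follow-up ends inside the window; the H dimension of the real method is dropped for brevity:

import torch

labels_bool = torch.tensor([[True, False, False]])  # disease 0 hit in window
death_in_horizon = torch.tensor([False])
observed_past_horizon = torch.tensor([False])       # censored within horizon
lost_within_horizon = ~observed_past_horizon
ever = torch.tensor([[True, False, True]])          # disease 2 occurs after the window

# Track A: censored within the window, so only observed positives stay valid.
full = (observed_past_horizon | death_in_horizon).unsqueeze(-1)
valid_cc = labels_bool | full
print(valid_cc)     # tensor([[ True, False, False]])

# Track B: hits are always valid; clean controls additionally require never
# having the disease AND full follow-up, so diseases 1 and 2 drop out here.
valid_clean = (~death_in_horizon).unsqueeze(-1) & (
    labels_bool | (~ever & (~lost_within_horizon).unsqueeze(-1)))
print(valid_clean)  # tensor([[ True, False, False]])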
+    def _compute_risk_scores_many_horizons(
+        self,
+        logits: torch.Tensor,
+        t_start_days: torch.Tensor,
+        horizons_days: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute risk increments for all horizons in one vectorized call.
+
+        Args:
+            logits: model head outputs for anchor points.
+            t_start_days: (B,) time from anchor to cutoff (days).
+            horizons_days: (H,) horizons in days.
+
+        Returns:
+            risk: (B, H, K) float32
+        """
+        t_start_days = torch.clamp(t_start_days, min=0)
+        t_end_days = torch.clamp(t_start_days.unsqueeze(
+            1) + horizons_days.view(1, -1), min=0)
+
+        t_query_years = torch.cat([t_start_days.unsqueeze(
+            1), t_end_days], dim=1) / 365.25  # (B, H+1)
+
+        # calculate_cifs returns (B, K) if scalar/per-sample, else (B, K, T)
+        if hasattr(self.loss_fn, "calculate_cifs"):
+            cifs = self.loss_fn.calculate_cifs(
+                logits, t_query_years, return_survival=False)
+        else:
+            raise ValueError(
+                f"Loss function does not support calculate_cifs: {type(self.loss_fn)}")
+
+        if cifs.ndim == 2:
+            # (B, K) -> (B, 1, K)
+            cifs_bt_k = cifs.unsqueeze(1)
+        else:
+            # (B, K, T) -> (B, T, K)
+            cifs_bt_k = cifs.permute(0, 2, 1)
+
+        cif_start = cifs_bt_k[:, :1, :]  # (B, 1, K)
+        cif_end = cifs_bt_k[:, 1:, :]  # (B, H, K)
+        risk = torch.clamp(cif_end - cif_start, min=0)
+        return risk
+
     @torch.no_grad()
     def compute_risk_scores(
         self,
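Note: the risk score is a CIF increment, risk = CIF(t_start + h) - CIF(t_start), clamped at zero because the model's CIF estimates need not be perfectly monotone. A small numeric sketch with made-up CIF values:

import torch

cif_start = torch.tensor([0.10])                   # CIF at the cutoff offset
cif_end = torch.tensor([0.12, 0.18, 0.30, 0.55])   # CIF at four horizons

risk = torch.clamp(cif_end - cif_start, min=0)
print(risk)  # tensor([0.0200, 0.0800, 0.2000, 0.4500])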
@@ -926,22 +1098,169 @@ class LandmarkEvaluator:
         return results
 
     def run_full_evaluation(self) -> Dict:
-        """
-        Run complete landmark analysis across all cutoffs and horizons.
-
-        Returns:
-            all_results: Nested dictionary with all evaluation results
-        """
+        """Run the full evaluation using a single-pass DataLoader.
+
+        Key optimizations:
+        - iterate DataLoader exactly once
+        - run transformer backbone once per batch
+        - reuse hidden states per cutoff (3x head only)
+        - vectorize CIF/risk over all horizons in one call
+        """
+        # Build evaluation subset loader
+        indices = self.eval_indices if self.eval_indices is not None else list(
+            range(len(self.dataset)))
+        subset = Subset(self.dataset, indices)
+        loader = DataLoader(
+            subset,
+            batch_size=self.batch_size,
+            shuffle=False,
+            collate_fn=health_collate_fn,
+            num_workers=self.num_workers,
+            pin_memory=True if str(self.device).startswith('cuda') else False,
+        )
+
+        cutoffs_days = torch.tensor(
+            self.age_cutoffs_days, dtype=torch.float32, device=self.device)  # (C,)
+        horizons_days = torch.tensor(
+            self.horizons_days, dtype=torch.float32, device=self.device)  # (H,)
+        C = int(cutoffs_days.numel())
+        H = int(horizons_days.numel())
+        K = int(self.dataset.n_disease)
+
+        # Buffers: store per landmark/track arrays in chunks to avoid repeated I/O.
+        # Each key stores lists of numpy arrays to be concatenated at the end.
+        buffers: Dict[Tuple[int, int, str], Dict[str, List[np.ndarray]]] = {}
+        for ci in range(C):
+            for hi in range(H):
+                for track in ("complete_case", "clean_control"):
+                    buffers[(ci, hi, track)] = {
+                        "risk": [], "labels": [], "valid": []}
+
+        with torch.inference_mode():
+            for batch in tqdm(loader, desc="Single-pass evaluation", ncols=100):
+                event_batch, time_batch, cont_batch, cate_batch, sex_batch = batch
+                event_batch = event_batch.to(self.device, non_blocking=True)
+                time_batch = time_batch.to(self.device, non_blocking=True)
+                cont_batch = cont_batch.to(self.device, non_blocking=True)
+                cate_batch = cate_batch.to(self.device, non_blocking=True)
+                sex_batch = sex_batch.to(self.device, non_blocking=True)
+
+                B, L = event_batch.shape
+                batch_idx = torch.arange(B, device=self.device)
+
+                # Backbone once per batch
+                hidden = self.model(
+                    event_batch, time_batch, sex_batch, cont_batch, cate_batch)  # (B, L, D)
+
+                for ci in range(C):
+                    cutoff = float(cutoffs_days[ci].item())
+
+                    has_anchor, anchor_idx, t_anchor = self._anchor_indices(
+                        time_batch, event_batch, cutoff)
+                    if not has_anchor.any():
+                        continue
+
+                    # Hidden states at anchor positions
+                    hidden_anchor = hidden[batch_idx, anchor_idx]  # (B, D)
+                    logits = self.head(hidden_anchor)
+
+                    # Vectorized labels/validity for all horizons
+                    labels_bhk, valid_cc_bhk, valid_clean_bhk = self._labels_and_validity_for_cutoff(
+                        time_batch, event_batch, cutoff, horizons_days
+                    )
+
+                    # Risk scores for all horizons (B, H, K)
+                    t_start = torch.clamp(torch.tensor(
+                        cutoff, device=self.device) - t_anchor, min=0)
+                    risk_bhk = self._compute_risk_scores_many_horizons(
+                        logits, t_start, horizons_days)
+
+                    # Apply anchor constraint to validity
+                    anchor_mask = has_anchor.view(B, 1, 1)
+                    valid_cc_bhk = valid_cc_bhk & anchor_mask
+                    valid_clean_bhk = valid_clean_bhk & anchor_mask
+
+                    # Push per-horizon chunks
+                    for hi in range(H):
+                        for track, valid_bk in (
+                            ("complete_case", valid_cc_bhk[:, hi, :]),
+                            ("clean_control", valid_clean_bhk[:, hi, :]),
+                        ):
+                            row_mask = valid_bk.any(dim=1)
+                            if not row_mask.any():
+                                continue
+
+                            r = risk_bhk[row_mask, hi, :].to(
+                                torch.float32).cpu().numpy()
+                            y = labels_bhk[row_mask, hi, :].to(
+                                torch.float32).cpu().numpy()
+                            m = valid_bk[row_mask, :].to(
+                                torch.bool).cpu().numpy()
+
+                            buffers[(ci, hi, track)]["risk"].append(r)
+                            buffers[(ci, hi, track)]["labels"].append(y)
+                            buffers[(ci, hi, track)]["valid"].append(m)
+
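Note: the buffers follow the standard append-then-concatenate pattern; a minimal sketch of why (illustrative sizes):

import numpy as np

chunks = []
for _ in range(5):          # stands in for the batch loop
    chunks.append(np.random.rand(256, 10))

# One O(total) copy at the end, instead of a quadratically growing
# np.concatenate inside the loop.
risk_scores = np.concatenate(chunks, axis=0)
print(risk_scores.shape)    # (1280, 10)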
-        all_results = {
+        # Assemble results in the original output schema
+        all_results: Dict = {
             'age_cutoffs': self.age_cutoffs,
             'horizons': self.horizons,
             'landmarks': [],
         }
 
-        # Evaluate each landmark
-        for age_cutoff in self.age_cutoffs:
-            for horizon in self.horizons:
-                landmark_results = self.evaluate_landmark(age_cutoff, horizon)
+        for ci, age in enumerate(self.age_cutoffs):
+            for hi, horizon in enumerate(self.horizons):
+                landmark_results = {
+                    'age_cutoff': age,
+                    'horizon': horizon,
+                    'complete_case': {},
+                    'clean_control': {},
+                }
+
+                for track in ("complete_case", "clean_control"):
+                    chunks = buffers[(ci, hi, track)]
+                    if len(chunks["risk"]) == 0:
+                        continue
+
+                    risk_scores = np.concatenate(chunks["risk"], axis=0)
+                    labels = np.concatenate(chunks["labels"], axis=0)
+                    valid_mask = np.concatenate(chunks["valid"], axis=0)
+
+                    auc_scores = self.compute_auc_per_disease(
+                        risk_scores, labels, valid_mask)
+                    mean_auc = np.nanmean(list(auc_scores.values()))
+
+                    track_out = {
+                        'n_patients': int(valid_mask.shape[0]),
+                        'n_valid': int(valid_mask.sum()),
+                        'n_valid_patients': int((valid_mask.any(axis=1)).sum()),
+                        'auc_per_disease': auc_scores,
+                        'mean_auc': mean_auc,
+                    }
+
+                    if track == "complete_case":
+                        brier_metrics = self.compute_brier_score(
+                            risk_scores, labels, valid_mask)
+                        capture_metrics = self.compute_disease_capture_at_k(
+                            risk_scores, labels, valid_mask)
+                        lift_yield_metrics = self.compute_lift_and_yield(
+                            risk_scores, labels, valid_mask)
+                        dca_metrics = self.compute_dca_net_benefit(
+                            risk_scores, labels, valid_mask)
+                        track_out.update({
+                            'brier_score': brier_metrics['brier_score'],
+                            'brier_skill_score': brier_metrics['brier_skill_score'],
+                            'disease_capture_at_k': capture_metrics,
+                            'lift_and_yield': lift_yield_metrics,
+                            'dca': dca_metrics,
+                        })
+
+                    landmark_results[track] = track_out
+
                 all_results['landmarks'].append(landmark_results)
 
         return all_results
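Note: since the output schema is unchanged, downstream consumers keep working as before; a hypothetical traversal (assumes an existing evaluator instance):

results = evaluator.run_full_evaluation()

for lm in results['landmarks']:
    cc = lm['complete_case']
    if cc:  # stays an empty dict when no valid patients reached this landmark
        print(f"age {lm['age_cutoff']}, horizon {lm['horizon']}y: "
              f"mean AUC {cc['mean_auc']:.3f} "
              f"({cc['n_valid_patients']} valid patients)")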
@@ -1202,6 +1521,11 @@ def main():
         default=4,
         help='Number of data loader workers'
     )
+    parser.add_argument(
+        '--no_compile',
+        action='store_true',
+        help='Disable torch.compile optimization (useful if your PyTorch build does not support it well)'
+    )
 
     args = parser.parse_args()
 
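Note: --no_compile is a plain store_true switch, so compilation stays on by default and is disabled only when the flag is passed; a quick check of the parsing behavior in isolation:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no_compile', action='store_true')

print(parser.parse_args([]).no_compile)                # False -> compile_model=True
print(parser.parse_args(['--no_compile']).no_compile)  # True  -> compile_model=False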
@@ -1219,6 +1543,7 @@ def main():
         device=args.device,
         batch_size=args.batch_size,
         num_workers=args.num_workers,
+        compile_model=(not args.no_compile),
     )
 
     # Run evaluation