feat(models): Refactor generate function in TimeAwareGPT2 with competing risks sampling

2025-10-18 12:42:14 +08:00
parent a631ac6d59
commit 082c719975


@@ -177,9 +177,10 @@ class TimeAwareGPT2(nn.Module):
     A time-aware GPT-2 model with custom temporal features.
     """
-    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float):
+    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float, ignore_tokens: list[int] = None):
         super().__init__()
         self.token_pdrop = token_pdrop
+        self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
 
         self.wte = nn.Embedding(vocab_size, n_embd)
         self.age_encoder = AgeSinusoidalEncoding(n_embd)
@@ -234,6 +235,58 @@ class TimeAwareGPT2(nn.Module):
         """
         return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
 
+    @torch.no_grad()
+    def generate(self, x, t, max_new_tokens=100, max_age=85*365.25, no_repeat=True, termination_tokens=None, top_k=None):
+        """
+        Take a conditioning sequence of event indices x (LongTensor of shape (b, t)) and
+        matching ages t, and complete the sequence up to max_new_tokens times, feeding the
+        predictions back into the model each time. The model is switched to eval() mode.
+        Note: top_k is accepted for interface compatibility but is not used below.
+        """
+        self.eval()
+        if termination_tokens is None:
+            termination_tokens = [1269]  # default termination token id
+        termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
+        mask_time = -10000  # sentinel age for padded positions
+
+        for _ in range(max_new_tokens):
+            logits = self(x, t)
+            logits = logits[:, -1, :]  # keep only the logits for the next event
+            if self.ignore_tokens:
+                logits[:, self.ignore_tokens] = -torch.inf
+            if no_repeat:
+                # Forbid already-observed tokens; occurrences of token id 1 are
+                # redirected to index 0, so index 0 is masked in their place.
+                fill = x.clone()
+                fill[fill == 1] = 0
+                logits = logits.scatter(1, fill, -torch.inf)
+
+            # Competing risks sampling: each token k has rate lambda_k = exp(logit_k).
+            # Draw an exponential waiting time per token by inverse transform,
+            # tau_k = -log(U)/lambda_k, clamp to 80 years, and let the token with
+            # the earliest waiting time win.
+            t_next_dist = torch.clamp(-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(), min=0, max=365*80)
+            t_next_val, idx_next = t_next_dist.min(1)
+            idx_next = idx_next.unsqueeze(1)
+            age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
+
+            x = torch.cat((x, idx_next), dim=1)
+            t = torch.cat((t, age_next), dim=1)
+
+            # Stop once every sequence has emitted a termination token or aged past max_age.
+            if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
+                break
+
+        # Mark positions strictly after the first termination token (double cumsum)
+        # or beyond max_age as padding.
+        pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
+        final_logits = self(x, t)
+        x[pad] = 0
+        t[pad] = mask_time
+        if no_repeat:
+            # Re-apply the no-repeat mask step by step, so the logits returned at
+            # position j only forbid tokens seen up to and including step j.
+            fill = x.clone()
+            fill[fill == 1] = 0
+            final_logits = torch.stack([final_logits[:, j].scatter(1, fill[:, :j+1], -torch.inf) for j in range(fill.shape[1])]).transpose(0, 1)
+        return x, t, final_logits
+
 class CovariateAwareGPT2(nn.Module):
     """
     Extends TimeAwareGPT2 to incorporate static and time-varying covariates.
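
For reference on the commit title: with rates lambda_k = exp(logit_k), the sampling step above draws one exponential waiting time per token, tau_k = -log(U_k)/lambda_k, and keeps the earliest. For independent exponentials, P(token k wins) = lambda_k / sum_j lambda_j = softmax(logits)_k, so the winning index follows the usual softmax distribution while the winning waiting time doubles as the age increment. A minimal standalone sketch of just this step (the vocabulary size, seed, and sample count are illustrative, not part of the commit):

import torch

torch.manual_seed(0)

# Hypothetical next-event logits over a 5-token vocabulary.
logits = torch.randn(5)

# One exponential waiting time per token via inverse-transform sampling:
# tau_k = -log(U_k) / exp(logit_k) = -exp(-logit_k) * log(U_k).
n = 200_000
u = torch.rand(n, 5)
tau = -torch.exp(-logits) * u.log()
winners = tau.argmin(dim=1)  # earliest event wins each draw

# Empirical winner frequencies should match softmax(logits) up to sampling noise.
freq = torch.bincount(winners, minlength=5) / n
print(freq)
print(torch.softmax(logits, dim=0))

In generate above, the same draw supplies both t_next_val and idx_next, so the sampled token and the time until it occurs come from a single coherent sample rather than two independent ones.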