feat(models): Refactor generate function in TimeAwareGPT2 with competing risks sampling
This commit is contained in:
55
models.py
55
models.py
@@ -177,9 +177,10 @@ class TimeAwareGPT2(nn.Module):
|
||||
A time-aware GPT-2 model with custom temporal features.
|
||||
"""
|
||||
|
||||
def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float):
|
||||
def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float, ignore_tokens: list[int] = None):
|
||||
super().__init__()
|
||||
self.token_pdrop = token_pdrop
|
||||
self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []
|
||||
|
||||
self.wte = nn.Embedding(vocab_size, n_embd)
|
||||
self.age_encoder = AgeSinusoidalEncoding(n_embd)
|
||||
@@ -234,6 +235,58 @@ class TimeAwareGPT2(nn.Module):
|
||||
"""
|
||||
return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(self, x, t, max_new_tokens=100, max_age=85*365.25, no_repeat=True, termination_tokens=None, top_k=None):
|
||||
"""
|
||||
Take a conditioning sequence of indices x (LongTensor of shape (b,t)) and complete
|
||||
the sequence max_new_tokens times, feeding the predictions back into the model each time.
|
||||
Most likely you'll want to make sure to be in model.eval() mode of operation for this.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
if termination_tokens is None:
|
||||
termination_tokens = [1269]
|
||||
|
||||
termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
|
||||
mask_time = -10000
|
||||
|
||||
for _ in range(max_new_tokens):
|
||||
logits = self(x, t)
|
||||
logits = logits[:, -1, :]
|
||||
|
||||
if self.ignore_tokens:
|
||||
logits[:, self.ignore_tokens] = -torch.inf
|
||||
|
||||
if no_repeat:
|
||||
fill = x.clone()
|
||||
fill[fill == 1] = 0
|
||||
logits = logits.scatter(1, fill, -torch.inf)
|
||||
|
||||
t_next_dist = torch.clamp(-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(), min=0, max=365*80)
|
||||
t_next_val, idx_next = t_next_dist.min(1)
|
||||
|
||||
idx_next = idx_next.unsqueeze(1)
|
||||
age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
|
||||
|
||||
x = torch.cat((x, idx_next), dim=1)
|
||||
t = torch.cat((t, age_next), dim=1)
|
||||
|
||||
if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
|
||||
break
|
||||
|
||||
pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
|
||||
|
||||
final_logits = self(x, t)
|
||||
x[pad] = 0
|
||||
t[pad] = mask_time
|
||||
|
||||
if no_repeat:
|
||||
fill = x.clone()
|
||||
fill[fill == 1] = 0
|
||||
final_logits = torch.stack([final_logits[:,j].scatter(1, fill[:,:j+1], -torch.inf) for j in range(fill.shape[1])]).transpose(0,1)
|
||||
|
||||
return x, t, final_logits
|
||||
|
||||
class CovariateAwareGPT2(nn.Module):
|
||||
"""
|
||||
Extends TimeAwareGPT2 to incorporate static and time-varying covariates.
|
||||
|
Reference in New Issue
Block a user