feat(models): Refactor generate function in TimeAwareGPT2 with competing risks sampling
models.py (55 changed lines)
@@ -177,9 +177,10 @@ class TimeAwareGPT2(nn.Module):
     A time-aware GPT-2 model with custom temporal features.
     """

-    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float):
+    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float, ignore_tokens: list[int] = None):
         super().__init__()
         self.token_pdrop = token_pdrop
+        self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []

         self.wte = nn.Embedding(vocab_size, n_embd)
         self.age_encoder = AgeSinusoidalEncoding(n_embd)
@@ -234,6 +235,58 @@ class TimeAwareGPT2(nn.Module):
         """
         return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6

+    @torch.no_grad()
+    def generate(self, x, t, max_new_tokens=100, max_age=85*365.25, no_repeat=True, termination_tokens=None, top_k=None):
+        """
+        Take a conditioning sequence of indices x (LongTensor of shape (b,t)) and complete
+        the sequence max_new_tokens times, feeding the predictions back into the model each time.
+        Most likely you'll want to make sure to be in model.eval() mode of operation for this.
+        """
+        self.eval()
+
+        if termination_tokens is None:
+            termination_tokens = [1269]
+
+        termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
+        mask_time = -10000
+
+        for _ in range(max_new_tokens):
+            logits = self(x, t)
+            logits = logits[:, -1, :]
+
+            if self.ignore_tokens:
+                logits[:, self.ignore_tokens] = -torch.inf
+
+            if no_repeat:
+                fill = x.clone()
+                fill[fill == 1] = 0
+                logits = logits.scatter(1, fill, -torch.inf)
+
+            t_next_dist = torch.clamp(-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(), min=0, max=365*80)
+            t_next_val, idx_next = t_next_dist.min(1)
+
+            idx_next = idx_next.unsqueeze(1)
+            age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
+
+            x = torch.cat((x, idx_next), dim=1)
+            t = torch.cat((t, age_next), dim=1)
+
+            if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
+                break
+
+        pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
+
+        final_logits = self(x, t)
+        x[pad] = 0
+        t[pad] = mask_time
+
+        if no_repeat:
+            fill = x.clone()
+            fill[fill == 1] = 0
+            final_logits = torch.stack([final_logits[:,j].scatter(1, fill[:,:j+1], -torch.inf) for j in range(fill.shape[1])]).transpose(0,1)
+
+        return x, t, final_logits
+
 class CovariateAwareGPT2(nn.Module):
     """
     Extends TimeAwareGPT2 to incorporate static and time-varying covariates.
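
Note on the sampling step above: the line t_next_dist = clamp(-exp(-logits) * log(u)) draws, for each candidate token k, an exponential waiting time with rate lambda_k = exp(logit_k), since for u ~ Uniform(0, 1) the quantity -log(u) / lambda_k is Exponential(lambda_k). Taking the per-row minimum then runs a competing-risks race: the earliest event supplies both the next token and the time increment, and the winning token is distributed exactly as softmax(logits). The standalone sketch below (hypothetical helper name, not part of this commit) isolates that draw:

import torch

def exponential_race(logits: torch.Tensor):
    # Rate for token k is lambda_k = exp(logit_k); for u ~ Uniform(0, 1),
    # tau_k = -log(u) / lambda_k = -exp(-logit_k) * log(u) ~ Exponential(lambda_k).
    u = torch.rand_like(logits)
    tau = -torch.exp(-logits) * u.log()   # same expression as t_next_dist in the diff
    t_next, idx_next = tau.min(dim=-1)    # earliest event wins the race
    return idx_next, t_next

# The winner's distribution matches softmax(logits): P(argmin = k) = lambda_k / sum_j lambda_j.
logits = torch.tensor([1.0, 0.0, -1.0])
draws = torch.stack([exponential_race(logits)[0] for _ in range(20_000)])
print(torch.bincount(draws, minlength=3) / 20_000)   # roughly [0.665, 0.245, 0.090]
print(torch.softmax(logits, dim=-1))                 # tensor([0.6652, 0.2447, 0.0900])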
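The post-loop padding line is also dense: torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1 marks every position strictly after the first termination token in each row, and the + (t > max_age) term then ORs in anything past the age horizon, so those slots can be zeroed in x and set to mask_time in t. A minimal demo of the double-cumsum trick, with made-up token IDs and 1269 as the default termination token:

import torch

x = torch.tensor([[5, 9, 1269, 7, 1269, 3]])
term = torch.tensor([1269])

hit = torch.isin(x, term)                        # True at every termination token
on_or_after = torch.cumsum(hit, 1).bool().int()  # 1 from the first hit onward
pad = torch.cumsum(on_or_after, 1) > 1           # strictly after the first hit
print(pad)  # tensor([[False, False, False,  True,  True,  True]])

The first termination token itself is kept, so each generated sequence ends with exactly one terminal event before the padded tail.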