diff --git a/models.py b/models.py
index 3fa8ffe..203e84f 100644
--- a/models.py
+++ b/models.py
@@ -177,9 +177,10 @@ class TimeAwareGPT2(nn.Module):
     A time-aware GPT-2 model with custom temporal features.
     """

-    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float):
+    def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float, ignore_tokens: list[int] = None):
         super().__init__()
         self.token_pdrop = token_pdrop
+        self.ignore_tokens = ignore_tokens if ignore_tokens is not None else []

         self.wte = nn.Embedding(vocab_size, n_embd)
         self.age_encoder = AgeSinusoidalEncoding(n_embd)
@@ -234,6 +235,72 @@ class TimeAwareGPT2(nn.Module):
         """
         return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6

+    @torch.no_grad()
+    def generate(self, x, t, max_new_tokens=100, max_age=85*365.25, no_repeat=True, termination_tokens=None, top_k=None):
+        """
+        Autoregressively extend a batch of event-token sequences x (LongTensor of shape
+        (batch, seq_len)) with matching ages t (FloatTensor, same shape, in days). At each
+        step an exponential waiting time with rate exp(logit) is sampled per candidate
+        token and the earliest event is appended together with its age. Generation stops
+        after max_new_tokens steps, or once every sequence has emitted a termination token
+        or exceeded max_age. Returns the extended x and t (positions past termination or
+        max_age reset to padding) and the logits of a final forward pass. The top_k
+        argument is accepted but currently unused.
+        """
+        self.eval()
+
+        if termination_tokens is None:
+            termination_tokens = [1269]
+
+        termination_tokens = torch.tensor(termination_tokens, dtype=torch.int64, device=x.device)
+        mask_time = -10000
+
+        for _ in range(max_new_tokens):
+            logits = self(x, t)
+            logits = logits[:, -1, :]
+
+            # Never sample explicitly excluded tokens.
+            if self.ignore_tokens:
+                logits[:, self.ignore_tokens] = -torch.inf
+
+            # Forbid tokens that already occur in the sequence (occurrences of token 1
+            # are remapped to 0 first, so token 1 itself is never masked).
+            if no_repeat:
+                fill = x.clone()
+                fill[fill == 1] = 0
+                logits = logits.scatter(1, fill, -torch.inf)
+
+            # Competing risks: sample an exponential waiting time with rate exp(logit)
+            # for every candidate token and keep the earliest event.
+            t_next_dist = torch.clamp(-torch.exp(-logits) * torch.rand(logits.shape, device=x.device).log(), min=0, max=365*80)
+            t_next_val, idx_next = t_next_dist.min(1)
+
+            idx_next = idx_next.unsqueeze(1)
+            age_next = t[:, -1].unsqueeze(1) + t_next_val.unsqueeze(1)
+
+            x = torch.cat((x, idx_next), dim=1)
+            t = torch.cat((t, age_next), dim=1)
+
+            # Stop once every sequence has produced a termination token or exceeded max_age.
+            if torch.logical_or(torch.isin(x, termination_tokens).any(-1), age_next.squeeze() > max_age).all():
+                break
+
+        # Mark positions after the first termination token, or beyond max_age, as padding.
+        pad = (torch.cumsum(torch.cumsum(torch.isin(x, termination_tokens), 1).bool().int(), 1) > 1) + (t > max_age)
+
+        # Logits of a final pass over the full sequences, computed before padding is applied.
+        final_logits = self(x, t)
+        x[pad] = 0
+        t[pad] = mask_time
+
+        # Apply the same no-repeat masking to the returned logits at every position.
+        if no_repeat:
+            fill = x.clone()
+            fill[fill == 1] = 0
+            final_logits = torch.stack([final_logits[:,j].scatter(1, fill[:,:j+1], -torch.inf) for j in range(fill.shape[1])]).transpose(0,1)
+
+        return x, t, final_logits
+

 class CovariateAwareGPT2(nn.Module):
     """
     Extends TimeAwareGPT2 to incorporate static and time-varying covariates.
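
A minimal usage sketch of the new ignore_tokens argument and generate method (not part of the diff). The vocabulary size, model hyperparameters, checkpoint path, and the choice of tokens 0 and 1 as padding/no-event tokens are illustrative assumptions; only the constructor and generate signatures and the default termination token 1269 come from the code above.

    import torch

    # Assumed hyperparameters and token conventions -- purely illustrative.
    model = TimeAwareGPT2(
        vocab_size=1270, n_embd=256, n_layer=8, n_head=8,
        pdrop=0.1, token_pdrop=0.1,
        ignore_tokens=[0, 1],  # hypothetical: exclude padding/no-event tokens from sampling
    )
    model.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"))  # hypothetical path

    # Conditioning prompt: event tokens and the corresponding ages in days.
    x = torch.tensor([[5, 42, 317]], dtype=torch.long)
    t = torch.tensor([[0.0, 3650.0, 7300.0]])

    x_gen, t_gen, logits = model.generate(
        x, t,
        max_new_tokens=64,
        max_age=85 * 365.25,        # stop once generated ages exceed 85 years
        termination_tokens=[1269],  # default termination token from the diff
    )
    print(x_gen.shape, t_gen.shape, logits.shape)

Note that generate returns the full sequences including the conditioning prefix, so the newly generated events and ages are the columns after the original prompt length.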