diff --git a/GEMINI.md b/GEMINI.md
new file mode 100644
index 0000000..d511b32
--- /dev/null
+++ b/GEMINI.md
@@ -0,0 +1,59 @@
+# DeepHealth Project
+
+This is a deep learning project based on PyTorch. It adheres to specific code style and file structure conventions to ensure clarity, maintainability, and reproducibility.
+
+## 1. Project Structure
+
+To maintain a clean and modular project, we adopt the following file organization:
+
+DeepHealth/
+ |-train.py
+ |-models.py
+ |-utils.py
+ |-data/
+ |-requirements.txt
+ |-README.md
+
+
+### File Descriptions
+
+* **`train.py`**:
+    * **Core training script**. It contains the control flow for the entire training process.
+    * Responsible for initializing the model, optimizer, DataLoader, etc.
+    * Executes the training and validation loops.
+    * Handles saving and loading checkpoints, logging, and other related tasks.
+
+* **`models.py`**:
+    * **Model and Loss Function Definitions**. This file stores the architecture for all neural network models.
+    * All subclasses of `torch.nn.Module` should be defined in this file.
+    * Custom loss functions should also be implemented here.
+
+* **`utils.py`**:
+    * **Utility Functions Module**. It contains reusable helper functions for the project.
+    * Primarily responsible for data I/O operations, data preprocessing, performance metric calculations, logger configuration, and other logic that doesn't belong in the core model or training framework.
+
+* **`data/`**:
+    * **Data Storage Directory**. Used to store the datasets required for the project.
+    * `data/raw/` stores the original, unprocessed data.
+    * `data/processed/` stores data after it has been preprocessed.
+
+* **`requirements.txt`**:
+    * **Project Dependencies**. Lists all the Python packages and their versions required to run this project.
+
+* **`README.md`**:
+    * **Project Documentation**. Provides a high-level overview of the project, setup instructions, and usage guidelines.
+
+## 2. Core Framework
+
+* **Deep Learning Framework**: `PyTorch`
+
+## 3. Coding Style
+
+This project uniformly adopts the **Google Python Style Guide**. All submitted code should adhere to this standard to ensure consistency and readability.
+
+Key features include:
+* Using `yapf` or `black` for automatic code formatting.
+* Following detailed naming conventions (`module_name`, `package_name`, `ClassName`, `method_name`, `ExceptionName`, `function_name`, `GLOBAL_CONSTANT_NAME`).
+* Using Google-style docstrings.
+
+Please refer to the official documentation: [Google Python Style Guide](http://google.github.io/styleguide/pyguide.html)
\ No newline at end of file
diff --git a/models.py b/models.py
new file mode 100644
index 0000000..f03ebe7
--- /dev/null
+++ b/models.py
@@ -0,0 +1,284 @@
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+from typing import Tuple
+import math
+
+class CausalSelfAttention(nn.Module):
+    """
+    A vanilla multi-head masked self-attention layer with a projection at the end.
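+    The mask is not a fixed causal buffer: the caller passes a (B, L, L)
+    boolean/0-1 tensor to `forward`, and positions where it is 0 are excluded
+    from attention.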
+ """ + + def __init__(self, n_embd: int, n_head: int, pdrop: float): + super().__init__() + assert n_embd % n_head == 0 + # key, query, value projections for all heads + self.c_attn = nn.Linear(n_embd, 3 * n_embd) + # output projection + self.c_proj = nn.Linear(n_embd, n_embd) + # regularization + self.attn_dropout = nn.Dropout(pdrop) + self.resid_dropout = nn.Dropout(pdrop) + self.n_head = n_head + self.n_embd = n_embd + + def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor: + B, L, D = x.size() # batch size, sequence length, embedding dimensionality (n_embd) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q, k, v = self.c_attn(x).split(self.n_embd, dim=2) + k = k.view(B, L, self.n_head, D // self.n_head).transpose(1, 2) # (B, nh, L, hs) + q = q.view(B, L, self.n_head, D // self.n_head).transpose(1, 2) # (B, nh, L, hs) + v = v.view(B, L, self.n_head, D // self.n_head).transpose(1, 2) # (B, nh, L, hs) + + # causal self-attention; Self-attend: (B, nh, L, hs) x (B, nh, hs, L) -> (B, nh, L, L) + att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) + + # Apply the time-based causal mask + att = att.masked_fill(custom_mask.unsqueeze(1) == 0, float('-inf')) + + att = F.softmax(att, dim=-1) + att = self.attn_dropout(att) + y = att @ v # (B, nh, L, L) x (B, nh, L, hs) -> (B, nh, L, hs) + y = y.transpose(1, 2).contiguous().view(B, L, D) # re-assemble all head outputs side by side + + # output projection + y = self.resid_dropout(self.c_proj(y)) + return y + +class Block(nn.Module): + """ an unassuming Transformer block """ + + def __init__(self, n_embd: int, n_head: int, pdrop: float): + super().__init__() + self.ln_1 = nn.LayerNorm(n_embd) + self.attn = CausalSelfAttention(n_embd, n_head, pdrop) + self.ln_2 = nn.LayerNorm(n_embd) + self.mlp = nn.ModuleDict(dict( + c_fc = nn.Linear(n_embd, 4 * n_embd), + c_proj = nn.Linear(4 * n_embd, n_embd), + act = nn.GELU(), + dropout = nn.Dropout(pdrop), + )) + m = self.mlp + self.mlpf = lambda x: m.dropout(m.c_proj(m.act(m.c_fc(x)))) # MLP forward + + def forward(self, x: torch.Tensor, custom_mask: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.ln_1(x), custom_mask=custom_mask) + x = x + self.mlpf(self.ln_2(x)) + return x + +class AgeSinusoidalEncoding(nn.Module): + """ + Encodes age using sinusoidal functions, similar to positional encodings + in Transformers. This module creates a fixed-size embedding for an age + value given in days. + """ + + def __init__(self, embedding_dim: int): + """ + Initializes the AgeSinusoidalEncoding module. + + Args: + embedding_dim (int): The dimensionality of the output embedding. + Must be an even number. + + Raises: + ValueError: If embedding_dim is not an even number. + """ + super().__init__() + if embedding_dim % 2 != 0: + raise ValueError(f"Embedding dimension must be an even number, but got {embedding_dim}") + + self.embedding_dim = embedding_dim + + # Pre-calculate the divisor term for the sinusoidal formula. + # The formula for the divisor is 10000^(2i/D), where D is the + # embedding_dim and i is the index for each pair of dimensions. + # i ranges from 0 to D/2 - 1. + i = torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) + divisor = torch.pow(10000, i / self.embedding_dim) + + # Register the divisor as a non-trainable buffer. This ensures it is + # moved to the correct device (e.g., GPU) along with the model. 
+ self.register_buffer('divisor', divisor) + + def forward(self, t: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the AgeSinusoidalEncoding. + + Args: + t (torch.Tensor): A tensor of shape (batch_size, sequence_length) + with dtype=torch.float32, representing age in days. + + Returns: + torch.Tensor: The encoded age tensor of shape + (batch_size, sequence_length, embedding_dim). + """ + # 1. Unit Conversion: Convert age from days to years. + # We use 365.25 to account for leap years. + t_years = t / 365.25 + + # 2. Argument Calculation: Calculate the arguments for the sin/cos functions. + # The shapes are broadcast to (B, L, D/2). + # Input t_years: (B, L) -> unsqueezed to (B, L, 1) + # Divisor: (D/2) -> viewed as (1, 1, D/2) + args = t_years.unsqueeze(-1) * self.divisor.view(1, 1, -1) + + # 3. Sinusoidal Application: Create the final output tensor. + # Initialize an empty tensor to store the embeddings. + output = torch.zeros(t.shape[0], t.shape[1], self.embedding_dim, device=t.device) + + # Assign cosine of the arguments to the even indices. + output[:, :, 0::2] = torch.cos(args) + + # Assign sine of the arguments to the odd indices. + output[:, :, 1::2] = torch.sin(args) + + return output + +class TimeAwareGPT2(nn.Module): + """ + A time-aware GPT-2 model with custom temporal features. + """ + + def __init__(self, vocab_size: int, n_embd: int, n_layer: int, n_head: int, pdrop: float, token_pdrop: float): + super().__init__() + self.token_pdrop = token_pdrop + + # Token and positional embeddings + self.wte = nn.Embedding(vocab_size, n_embd) + self.age_encoder = AgeSinusoidalEncoding(n_embd) + self.drop = nn.Dropout(pdrop) + + # Transformer blocks + self.blocks = nn.ModuleList([Block(n_embd, n_head, pdrop) for _ in range(n_layer)]) + + # Final layer norm and linear head + self.ln_f = nn.LayerNorm(n_embd) + self.head = nn.Linear(n_embd, vocab_size, bias=False) + + self.n_embd = n_embd + + def forward(self, event_seq: torch.Tensor, time_seq: torch.Tensor) -> torch.Tensor: + """ + Forward pass for the TimeAwareGPT2 model. + + Args: + event_seq (torch.Tensor): Token indices of shape (B, L). + time_seq (torch.Tensor): Timestamps for each event of shape (B, L). + + Returns: + torch.Tensor: Logits of shape (B, L, vocab_size). + """ + B, L = event_seq.size() + + # 1. Get token embeddings + token_embeddings = self.wte(event_seq) + + # 2. Apply token dropout (only during training) + if self.training and self.token_pdrop > 0: + # Create a mask to randomly zero out entire token embedding vectors + drop_mask = torch.rand(token_embeddings.shape[:2], device=token_embeddings.device) < self.token_pdrop + token_embeddings[drop_mask] = 0.0 + + # 3. Get positional embeddings from time sequence + pos_embeddings = self.age_encoder(time_seq.float()) + + # 4. Combine embeddings and apply dropout + x = self.drop(token_embeddings + pos_embeddings) + + # 5. Generate attention mask + # The attention mask combines two conditions: + # a) Time-based causality: A token i can attend to a token j only if time_seq[j] < time_seq[i]. + # b) Padding mask: Do not attend to positions where the event token is 0. + + # a) Time-based causal mask + t_i = time_seq.unsqueeze(-1) # (B, L, 1) + t_j = time_seq.unsqueeze(1) # (B, 1, L) + time_mask = (t_j < t_i) + + # b) Padding mask (prevents attending to key positions that are padding) + padding_mask = (event_seq != 0).unsqueeze(1) # Shape: (B, 1, L) + + # Combine the masks. 
A position (j) can be attended to by a query (i) only if + # it's in the past (time_mask) AND it's not a padding token (padding_mask). + combined_mask = time_mask & padding_mask + + # 6. Pass through transformer blocks + for block in self.blocks: + x = block(x, custom_mask=combined_mask) + + # 7. Final layer norm and projection to vocab size + x = self.ln_f(x) + logits = self.head(x) + + return logits + + def get_num_params(self) -> float: + """ + Returns the number of trainable parameters in the model in millions. + """ + return sum(p.numel() for p in self.parameters() if p.requires_grad) / 1e6 + +class CombinedLoss(nn.Module): + """ + Computes a two-part loss: a standard cross-entropy loss for event type + prediction and a survival analysis loss for event timing. + """ + + def __init__(self, ignored_token_ids: list[int]): + """ + Initializes the CombinedLoss module. + + Args: + ignored_token_ids (list[int]): A list of event type IDs to be + excluded from all loss calculations. + """ + super().__init__() + self.ignored_token_ids = ignored_token_ids + + def forward(self, logits: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Calculates the combined cross-entropy and survival loss. + + Args: + logits (torch.Tensor): Raw model outputs of shape (B, L, N). + x (torch.Tensor): Ground-truth event labels of shape (B, L). + t (torch.Tensor): True time duration for each event, shape (B, L). + + Returns: + A tuple containing the two scalar loss tensors: (loss_ce, loss_survival). + """ + # 1. Create a mask to filter out ignored token IDs from loss calculation. + # An element is True if the corresponding label in x is NOT in the ignored list. + mask = torch.ones_like(x, dtype=torch.bool) + for token_id in self.ignored_token_ids: + mask = mask & (x != token_id) + + # If the mask is all False (all tokens are ignored), return zero for both losses. + if not mask.any(): + return torch.tensor(0.0, device=logits.device), torch.tensor(0.0, device=logits.device) + + # 2. Part 1: Cross-Entropy Loss (loss_ce) + # Permute logits from (B, L, N) to (B, N, L) for F.cross_entropy. + logits_for_ce = logits.permute(0, 2, 1) + + # Calculate per-element loss without reduction. + per_element_ce = F.cross_entropy(logits_for_ce, x, reduction='none') + + # Apply the mask and compute the mean of valid elements. + loss_ce = per_element_ce[mask].mean() + + # 3. Part 2: Survival Loss (loss_survival) + # Calculate event intensity (lambda) as the sum of exponentiated logits. + intensity = torch.sum(torch.exp(logits), dim=2) + + # Calculate per-element survival loss (negative log-likelihood of exponential dist). + # We add a small epsilon for numerical stability with the log. + per_element_survival = -(torch.log(intensity + 1e-8) - intensity * t) + + # Apply the mask and compute the mean of valid elements. 
+ loss_survival = per_element_survival[mask].mean() + + return loss_ce, loss_survival diff --git a/prepare_data.py b/prepare_data.py new file mode 100644 index 0000000..58bf03f --- /dev/null +++ b/prepare_data.py @@ -0,0 +1,133 @@ +import pandas as pd +import tqdm +import numpy as np + +label_files = 'labels.csv' +ukb_field_to_icd10_file = 'icd10_codes_mod.tsv' +ukb_basket_file = 'ukb_delphi.txt' +train_proportion = 0.8 +output_prefix = 'ukb_real' + +icdict = {} +icdcodes = [] +with open(ukb_field_to_icd10_file) as f: + for line in f: + parts = line.strip().split() + icdict[parts[0]] = parts[5] + icdcodes.append(parts[5]) + +# Using enumerate for cleaner, safer label assignment starting from 0 +label_dict = {} +with open(label_files) as f: + for i, line in enumerate(f): + label_dict[line.strip().split(' ')[0]] = i + +icdict['f.31.0.0'] = "sex" +icdict['f.34.0.0'] = "YEAR" +icdict['f.52.0.0'] = "MONTH" +icdict['f.40000.0.0'] = "Death" + +for j in range(17): + icdict[f'f.40005.{j}.0'] = f'cancer_date_{j}' + icdict[f'f.40006.{j}.0'] = f'cancer_type_{j}' + +icdict['f.53.0.0'] = "assessment_date" +icdict['f.21001.0.0'] = "BMI" +icdict['f.1239.0.0'] = "smoking" +icdict['f.1558.0.0'] = "alcohol" + +len_icd = len(icdcodes) + +# Corrected typo 'aseessment_date' to 'assessment_date' +icdcodes.extend(['Death', 'assessment_date'] + [f'cancer_date_{j}' for j in range(17)]) + +data_list = [] +ukb_iterator = pd.read_csv(ukb_basket_file, sep=',', chunksize=10000, index_col=0, low_memory=False) + +for _, dd in tqdm.tqdm(enumerate(ukb_iterator)): + dd = dd.rename(columns=icdict) + dd.dropna(subset=['sex'], inplace=True) + dd['sex'] += 1 + dd = dd[[col for col in dd.columns if not col.startswith('f.')]] + dd['dob'] = pd.to_datetime(dd[['YEAR', 'MONTH']].assign(DAY=1)) + + present_icdcodes = [c for c in icdcodes if c in dd.columns] + if present_icdcodes: + # Convert date columns to days from date of birth + date_cols = dd[present_icdcodes].apply(pd.to_datetime, format="%Y-%m-%d", errors='coerce') + date_cols_days = date_cols.sub(dd['dob'], axis=0) + dd[present_icdcodes] = date_cols_days.apply(lambda x: x.dt.days) + + # Process ICD codes efficiently using melt + cols_to_process = [col for col in icdcodes[:len_icd + 1] if col in dd.columns] + if cols_to_process: + melted_df = dd.reset_index().melt( + id_vars=['f.eid'], + value_vars=cols_to_process, + var_name='event_code', + value_name='days' + ) + melted_df.dropna(subset=['days'], inplace=True) + if not melted_df.empty: + melted_df['label'] = melted_df['event_code'].map(label_dict) + data_list.append(melted_df[['f.eid', 'days', 'label']].dropna().astype(int).to_numpy()) + + # Process sex + X = dd['sex'].reset_index().to_numpy().astype(int) + data_list.append(np.c_[X[:, 0], np.zeros(X.shape[0]), X[:, 1]].astype(int)) + + # Process cancer data efficiently using wide_to_long + df_res = dd.reset_index() + rename_dict = {f'cancer_date_{j}': f'cancerdate{j}' for j in range(17)} + rename_dict.update({f'cancer_type_{j}': f'cancertype{j}' for j in range(17)}) + df_renamed = df_res.rename(columns=rename_dict) + + stubs_to_use = [] + if any('cancerdate' in col for col in df_renamed.columns): stubs_to_use.append('cancerdate') + if any('cancertype' in col for col in df_renamed.columns): stubs_to_use.append('cancertype') + + if len(stubs_to_use) == 2: + long_cancer = pd.wide_to_long(df_renamed, + stubnames=stubs_to_use, + i=['f.eid'], + j='cancer_num' + ).dropna() + if not long_cancer.empty: + long_cancer['cancer'] = long_cancer['cancertype'].str.slice(0, 3) + 
long_cancer['cancer_label'] = long_cancer['cancer'].map(label_dict) + cancer_array = long_cancer.reset_index()[['f.eid', 'cancerdate', 'cancer_label']].dropna().astype(int).to_numpy() + if cancer_array.size > 0: + data_list.append(cancer_array) + + # Process BMI, smoking, and alcohol + dd_bmi = dd[['assessment_date', 'BMI']].dropna().reset_index() + if not dd_bmi.empty: + dd_bmi['bmi_status'] = np.select([dd_bmi['BMI'] > 28, dd_bmi['BMI'] > 22], [5, 4], default=3) + data_list.append(dd_bmi[['f.eid', 'assessment_date', 'bmi_status']].astype(int).to_numpy()) + + dd_sm = dd[['assessment_date', 'smoking']].dropna().reset_index() + dd_sm = dd_sm[dd_sm['smoking'] != -3] + if not dd_sm.empty: + dd_sm['smoking_status'] = np.select([dd_sm['smoking'] == 1, dd_sm['smoking'] == 2], [8, 7], default=6) + data_list.append(dd_sm[['f.eid', 'assessment_date', 'smoking_status']].astype(int).to_numpy()) + + dd_al = dd[['assessment_date', 'alcohol']].dropna().reset_index() + dd_al = dd_al[dd_al['alcohol'] != -3] + if not dd_al.empty: + dd_al['alcohol_status'] = np.select([dd_al['alcohol'] == 1, dd_al['alcohol'] < 4], [11, 10], default=9) + data_list.append(dd_al[['f.eid', 'assessment_date', 'alcohol_status']].astype(int).to_numpy()) + +data = np.vstack(data_list) +data = data[np.lexsort((data[:, 1], data[:, 2] == data[:, 2].max(), data[:, 0]))] +data = data[data[:, 1] >= 0] +data = pd.DataFrame(data).drop_duplicates([0, 2]).values +data = data.astype(np.uint32) +data.tofile(output_prefix + '.bin') + +# Correctly split train/validation sets +unique_ids = np.unique(data[:, 0]) +split_id = unique_ids[int(len(unique_ids) * train_proportion)] +train_val_split = data[:, 0] <= split_id + +data[train_val_split].tofile(output_prefix + '_train.bin') +data[~train_val_split].tofile(output_prefix + '_val.bin') diff --git a/train.py b/train.py new file mode 100644 index 0000000..fbcfb2a --- /dev/null +++ b/train.py @@ -0,0 +1,170 @@ +import torch +import torch.nn as nn +from torch.optim import Adam +from torch.utils.data import DataLoader +import numpy as np +import math +import tqdm + +from models import TimeAwareGPT2, CombinedLoss +from utils import PatientEventDataset + +# --- Configuration --- +class TrainConfig: + # Data parameters + train_data_path = 'ukb_real_train.bin' + val_data_path = 'ukb_real_val.bin' + block_length = 256 # Sequence length + + # Model parameters + n_embd = 256 + n_layer = 8 + n_head = 8 + pdrop = 0.1 + token_pdrop = 0.1 + + # Training parameters + max_epoch = 200 + batch_size = 128 + lr_initial = 6e-4 + lr_final = 6e-5 + warmup_epochs = 10 + early_stopping_patience = 5 + + # Loss parameters + # 0 = padding, 1 = "no event" + ignored_token_ids = [0, 1] + + # System parameters + device = 'cuda' if torch.cuda.is_available() else 'cpu' + +# --- Main Training Script --- +def main(): + config = TrainConfig() + + # --- 1. 
Data Loading --- + print(f"Loading data from {config.train_data_path} and {config.val_data_path}...") + train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3) + val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3) + + # Infer vocab_size from the data (max label + 1) + vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1 + print(f"Inferred vocabulary size: {vocab_size}") + + train_dataset = PatientEventDataset(train_data_arr, config.block_length) + val_dataset = PatientEventDataset(val_data_arr, config.block_length) + + train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True) + val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True) + + # --- 2. Model, Optimizer, and Loss Initialization --- + print(f"Initializing model on {config.device}...") + model = TimeAwareGPT2( + vocab_size=vocab_size, + n_embd=config.n_embd, + n_layer=config.n_layer, + n_head=config.n_head, + pdrop=config.pdrop, + token_pdrop=config.token_pdrop + ).to(config.device) + + print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.") + + loss_fn = CombinedLoss(config.ignored_token_ids) + optimizer = Adam(model.parameters(), lr=config.lr_initial) + + # --- 3. Training Loop --- + best_val_loss = float('inf') + patience_counter = 0 + print("Starting training...") + for epoch in range(config.max_epoch): + # --- Learning Rate Scheduling --- + if epoch < config.warmup_epochs: + lr = config.lr_initial + else: + progress = (epoch - config.warmup_epochs) / (config.max_epoch - config.warmup_epochs) + lr = config.lr_final + 0.5 * (config.lr_initial - config.lr_final) * (1 + math.cos(math.pi * progress)) + + for param_group in optimizer.param_groups: + param_group['lr'] = lr + + # --- Training Phase --- + model.train() + train_loss_ce_acc, train_loss_surv_acc = 0.0, 0.0 + train_steps = 0 + + pbar = tqdm.tqdm(train_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Train]") + for event_seq, time_seq in pbar: + event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device) + + # Prepare inputs and targets + input_events = event_seq[:, :-1] + input_times = time_seq[:, :-1] + target_events = event_seq[:, 1:] + target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float() + + # Forward pass + logits = model(input_events, input_times) + loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times) + loss = loss_ce + loss_survival + + # Backward pass and optimization + optimizer.zero_grad() + loss.backward() + optimizer.step() + + train_loss_ce_acc += loss_ce.item() + train_loss_surv_acc += loss_survival.item() + train_steps += 1 + pbar.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}', 'lr': f'{lr:.2e}'}) + + avg_train_loss_ce = train_loss_ce_acc / train_steps + avg_train_loss_surv = train_loss_surv_acc / train_steps + + # --- Validation Phase --- + model.eval() + val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0 + val_steps = 0 + + with torch.no_grad(): + pbar_val = tqdm.tqdm(val_loader, desc=f"Epoch {epoch+1}/{config.max_epoch} [Val]") + for event_seq, time_seq in pbar_val: + event_seq, time_seq = event_seq.to(config.device), time_seq.to(config.device) + + input_events = event_seq[:, :-1] + input_times = time_seq[:, :-1] + target_events = event_seq[:, 1:] + target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float() + + 
logits = model(input_events, input_times) + loss_ce, loss_survival = loss_fn(logits, target_events, target_wait_times) + + val_loss_ce_acc += loss_ce.item() + val_loss_surv_acc += loss_survival.item() + val_steps += 1 + pbar_val.set_postfix({'loss_ce': f'{loss_ce.item():.4f}', 'loss_surv': f'{loss_survival.item():.4f}'}) + + avg_val_loss_ce = val_loss_ce_acc / val_steps + avg_val_loss_surv = val_loss_surv_acc / val_steps + total_val_loss = avg_val_loss_ce + avg_val_loss_surv + + print(f"Epoch {epoch+1} Summary: \n" + f" Train Loss: {avg_train_loss_ce + avg_train_loss_surv:.4f} (CE: {avg_train_loss_ce:.4f}, Surv: {avg_train_loss_surv:.4f})\n" + f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})\n" + f" Learning Rate: {lr:.6f}") + + # --- Early Stopping Check --- + if total_val_loss < best_val_loss: + best_val_loss = total_val_loss + patience_counter = 0 + print(f"Validation loss improved to {best_val_loss:.4f}. Resetting patience.") + else: + patience_counter += 1 + print(f"Validation loss did not improve. Patience: {patience_counter}/{config.early_stopping_patience}") + + if patience_counter >= config.early_stopping_patience: + print("\nEarly stopping triggered due to no improvement in validation loss.") + break + +if __name__ == '__main__': + main() diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..5128dc8 --- /dev/null +++ b/utils.py @@ -0,0 +1,104 @@ +import torch +import numpy as np +import random +from collections import defaultdict + +class PatientEventDataset(torch.utils.data.Dataset): + """ + A PyTorch Dataset for handling temporal sequences of patient events. + + This class processes a raw NumPy array of patient records, groups them by + patient ID, and prepares them for training by imputing gaps, padding, or + truncating sequences to a fixed length. + """ + + def __init__(self, data: np.ndarray, block_length: int): + """ + Initializes the dataset by pre-processing the patient event data. + + Args: + data (np.ndarray): A NumPy array of shape (N, 3) with dtype=np.uint32. + The columns represent (patient_id, time_in_days, event_code). + block_length (int): The fixed length for the output sequences. + """ + self.block_length = block_length + + # Group (time_in_days, event_code) pairs by patient_id. + # This pre-processing step allows for efficient lookups in __getitem__. + patient_events = defaultdict(list) + for patient_id, time, event in data: + patient_events[patient_id].append((time, event)) + + # Store a list of unique patient_ids to map indices to patients. + self.patient_ids = list(patient_events.keys()) + self.patient_events = dict(patient_events) + + def __len__(self) -> int: + """ + Returns the total number of unique patients in the dataset. + """ + return len(self.patient_ids) + + def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]: + """ + Retrieves, processes, and returns a single patient's event sequence. + + Args: + idx (int): The index of the patient to retrieve. + + Returns: + A tuple of two torch.long tensors: (event_sequence, time_sequence), + both of shape (block_length,). + """ + # 1. Retrieve and Sort + patient_id = self.patient_ids[idx] + records = sorted(self.patient_events[patient_id], key=lambda x: x[0]) + + # 2. Impute "No Event" Gaps + imputed_sequence = [] + if not records: + # Handle cases with no records for a patient if necessary, though + # the constructor logic would typically prevent this. 
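+            # An empty sequence simply falls through to the padding branch below,
+            # yielding an all-padding block of length `block_length`.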
+ pass + else: + imputed_sequence.append(records[0]) + for i in range(len(records) - 1): + prev_time, _ = records[i] + next_time, _ = records[i+1] + time_gap = next_time - prev_time + + # If the gap is 5 years (1826 days) or more, insert "no event" records. + if time_gap >= 1826: + num_no_event_intervals = time_gap // 1826 + for j in range(1, num_no_event_intervals + 1): + no_event_time = prev_time + j * 1826 + imputed_sequence.append((no_event_time, 1)) # event_code=1 for "no event" + + imputed_sequence.append(records[i+1]) + + # 3. Adjust Sequence Length + seq_len = len(imputed_sequence) + + if seq_len > self.block_length: + # If longer, randomly select a contiguous sub-sequence. + start_index = random.randint(0, seq_len - self.block_length) + final_sequence = imputed_sequence[start_index : start_index + self.block_length] + elif seq_len < self.block_length: + # If shorter, pad the sequence at the end. + padding_needed = self.block_length - seq_len + # Use event_code=0 and time_in_days=36525 for padding. + padding = [(36525, 0)] * padding_needed + final_sequence = imputed_sequence + padding + else: + # If equal, use the sequence as is. + final_sequence = imputed_sequence + + # 4. Return Tensors + # Separate the sequence into event codes and time, then convert to tensors. + event_codes = [item[1] for item in final_sequence] + time_stamps = [item[0] for item in final_sequence] + + event_tensor = torch.tensor(event_codes, dtype=torch.long) + time_tensor = torch.tensor(time_stamps, dtype=torch.long) + + return event_tensor, time_tensor
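+
+# Minimal usage sketch (mirrors train.py; the .bin file is produced by prepare_data.py):
+#
+#   data = np.memmap('ukb_real_train.bin', dtype=np.uint32, mode='r').reshape(-1, 3)
+#   dataset = PatientEventDataset(data, block_length=256)
+#   events, times = dataset[0]  # two (block_length,) torch.long tensors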