import os
import torch
import numpy as np
import random
from collections import defaultdict
import json

from models import TimeAwareGPT2, TimeAwareGPT2Learnable


class PatientEventDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset for handling temporal sequences of patient events.

    This class processes a raw NumPy array of patient records, groups them by
    patient ID, and prepares them for training by imputing gaps, padding, or
    truncating sequences to a fixed length.
    """

    def __init__(self, data: np.ndarray, block_length: int):
        """
        Initializes the dataset by pre-processing the patient event data.

        Args:
            data (np.ndarray): A NumPy array of shape (N, 3) with dtype=np.uint32.
                The columns represent (patient_id, time_in_days, event_code).
            block_length (int): The fixed length for the output sequences.
        """
        self.block_length = block_length

        # Group (time_in_days, event_code) pairs by patient_id.
        # This pre-processing step allows for efficient lookups in __getitem__.
        patient_events = defaultdict(list)
        for patient_id, time, event in data:
            patient_events[patient_id].append((time, event))

        # Store a list of unique patient_ids to map indices to patients.
        self.patient_ids = list(patient_events.keys())
        self.patient_events = dict(patient_events)

    def __len__(self) -> int:
        """Returns the total number of unique patients in the dataset."""
        return len(self.patient_ids)

    def __getitem__(self, idx):
        """
        Retrieves, processes, and returns a single patient's event sequence,
        or a list of sequences if a slice is provided.

        Args:
            idx (int or slice): The index or slice of the patient(s) to retrieve.

        Returns:
            If idx is an int, a tuple of two torch.long tensors:
            (event_sequence, time_sequence), both of shape (block_length,).
            If idx is a slice, a list of such tuples.
        """
        if isinstance(idx, slice):
            return [self[i] for i in range(*idx.indices(len(self)))]

        # 1. Retrieve and Sort
        patient_id = self.patient_ids[idx]
        records = sorted(self.patient_events[patient_id], key=lambda x: x[0])

        # 2. Impute "No Event" Gaps
        imputed_sequence = []
        if not records:
            # Handle cases with no records for a patient if necessary, though
            # the constructor logic would typically prevent this.
            pass
        else:
            imputed_sequence.append(records[0])
            for i in range(len(records) - 1):
                prev_time, _ = records[i]
                next_time, _ = records[i + 1]
                time_gap = next_time - prev_time

                # If the gap is 5 years (1826 days) or more, insert "no event" records.
                if time_gap >= 1826:
                    num_no_event_intervals = time_gap // 1826
                    for j in range(1, num_no_event_intervals + 1):
                        no_event_time = prev_time + j * 1826
                        imputed_sequence.append((no_event_time, 1))  # event_code=1 for "no event"

                imputed_sequence.append(records[i + 1])

        # 3. Adjust Sequence Length
        seq_len = len(imputed_sequence)

        if seq_len > self.block_length:
            # If longer, randomly select a contiguous sub-sequence.
            start_index = random.randint(0, seq_len - self.block_length)
            final_sequence = imputed_sequence[start_index : start_index + self.block_length]
        elif seq_len < self.block_length:
            # If shorter, pad the sequence at the end.
            padding_needed = self.block_length - seq_len
            # Use event_code=0 and time_in_days=36525 for padding.
            padding = [(36525, 0)] * padding_needed
            final_sequence = imputed_sequence + padding
        else:
            # If equal, use the sequence as is.
            final_sequence = imputed_sequence

        # 4. Return Tensors
        # Separate the sequence into event codes and times, then convert to tensors.
        event_codes = [item[1] for item in final_sequence]
        time_stamps = [item[0] for item in final_sequence]

        event_tensor = torch.tensor(event_codes, dtype=torch.long)
        time_tensor = torch.tensor(time_stamps, dtype=torch.long)

        return event_tensor, time_tensor
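
# A minimal usage sketch for PatientEventDataset (the three-row array below is
# invented for illustration; columns are (patient_id, time_in_days, event_code)):
#
#     data = np.array([
#         [1, 100, 5],
#         [1, 2200, 7],  # 2100-day gap from the first record, so one
#                        # (1926, 1) "no event" entry is imputed between them
#         [2, 50, 3],
#     ], dtype=np.uint32)
#     ds = PatientEventDataset(data, block_length=4)
#     events, times = ds[0]  # two (4,) torch.long tensors, padded with (36525, 0)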


def load_model(config_path: str, device: str = 'cpu'):
    """
    Load a trained model based on the training configuration, inferring the
    checkpoint filename from the configuration.

    According to train.py, models may be either 'TimeAwareGPT2' or
    'TimeAwareGPT2Learnable'. This function:
    - Reads the config JSON to get architecture hyperparameters
    - Selects the model class using config.model_name (defaults to TimeAwareGPT2 if absent)
    - Infers the checkpoint path from the config values
    - Infers vocab_size from the checkpoint
    - Loads weights and returns the model in eval mode on the requested device

    Args:
        config_path: Path to the JSON configuration file saved during training.
        device: 'cpu' or 'cuda'.

    Returns:
        torch.nn.Module: Loaded model ready for inference.
    """
    # 1) Read config.
    with open(config_path, 'r') as f:
        config_dict = json.load(f)

    # Wrap the config dict for attribute-style access while staying tolerant
    # to missing keys.
    class AttrDict(dict):
        def __getattr__(self, item):
            try:
                return self[item]
            except KeyError:
                raise AttributeError(item)

    config = AttrDict(config_dict)
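
    # For example, config.n_embd and config['n_embd'] are now equivalent; a
    # missing key raises AttributeError, which the getattr(..., default) calls
    # below convert into a default value.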

    # 2) Decide model class (train.py supports two variants).
    model_name = getattr(config, 'model_name', 'TimeAwareGPT2')
    model_cls = {
        'TimeAwareGPT2': TimeAwareGPT2,
        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
    }.get(model_name, TimeAwareGPT2)
2025-10-22 11:39:10 +08:00
|
|
|
# 3) Infer checkpoint filename from config
|
2025-10-22 11:30:59 +08:00
|
|
|
n_embd = getattr(config, 'n_embd')
|
|
|
|
n_layer = getattr(config, 'n_layer')
|
|
|
|
n_head = getattr(config, 'n_head')
|
2025-10-22 11:39:10 +08:00
|
|
|
|
|
|
|
# Newer naming (includes model_name) used by train.py when model_name is present
|
|
|
|
suffix_with_model = f"{model_name}_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
|
|
|
|
ckpt_with_model = f"best_model_{suffix_with_model}.pt"
|
|
|
|
|
|
|
|
# Older naming (without model_name) matches existing repo files
|
|
|
|
suffix_legacy = f"n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}"
|
|
|
|
ckpt_legacy = f"best_model_{suffix_legacy}.pt"
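
    # With hypothetical values n_embd=256, n_layer=8, n_head=8 this yields
    # "best_model_TimeAwareGPT2_n_embd_256_n_layer_8_n_head_8.pt" (newer) or
    # "best_model_n_embd_256_n_layer_8_n_head_8.pt" (legacy).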

    # Prefer whichever file exists on disk.
    if os.path.exists(ckpt_with_model):
        model_path = ckpt_with_model
    elif os.path.exists(ckpt_legacy):
        model_path = ckpt_legacy
    else:
        # Fall back to the model_name-based filename; the checkpoint may still
        # have been saved under that name.
        model_path = ckpt_with_model
        print(f"Warning: Could not find checkpoint on disk. Expected one of: {ckpt_with_model}, {ckpt_legacy}")

    # 4) Infer vocab_size from the checkpoint. Both the token embedding
    # 'wte.weight' and the output 'head.weight' have vocab_size as their
    # first dimension.
    state_preview = torch.load(model_path, map_location='cpu')
    if 'wte.weight' in state_preview:
        vocab_size = state_preview['wte.weight'].shape[0]
    elif 'head.weight' in state_preview:
        vocab_size = state_preview['head.weight'].shape[0]
    else:
        # Last resort: take the largest dimension of any 2-D tensor, on the
        # assumption that the vocabulary is larger than the embedding width.
        candidate = None
        for k, v in state_preview.items():
            if isinstance(v, torch.Tensor) and v.ndim == 2:
                V = max(v.shape)
                if candidate is None or V > candidate:
                    candidate = V
        if candidate is None:
            raise ValueError("Unable to infer vocab_size from checkpoint. Unknown tensor shapes.")
        vocab_size = candidate

    # 5) Build the model from the config (tolerant to missing dropout fields).
    pdrop = getattr(config, 'pdrop', 0.1)
    token_pdrop = getattr(config, 'token_pdrop', 0.1)

    model = model_cls(
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_layer=n_layer,
        n_head=n_head,
        pdrop=pdrop,
        token_pdrop=token_pdrop,
    ).to(device)

    # 6) Load weights. strict=False tolerates minor naming differences between
    # the two model variants; any mismatches are surfaced as warnings below.
    state_dict = torch.load(model_path, map_location=device)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)

    if missing:
        print(f"Warning: Missing keys when loading state_dict: {missing}")
    if unexpected:
        print(f"Warning: Unexpected keys when loading state_dict: {unexpected}")

    model.eval()
    try:
        num_params_m = model.get_num_params()
        print(f"Model loaded from {model_path} with {num_params_m:.2f}M parameters.")
    except Exception:
        pass

    return model
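
# A minimal usage sketch for load_model (the config filename is hypothetical;
# pass whatever JSON file train.py wrote for your run):
#
#     model = load_model('train_config.json',
#                        device='cuda' if torch.cuda.is_available() else 'cpu')
#     with torch.no_grad():
#         ...  # model is already in eval mode, ready for inference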


def get_batch(dataset: PatientEventDataset, batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Retrieves a batch of data from a PatientEventDataset and prepares it for
    model training with the standard next-token shift: inputs drop the last
    element of each sequence and targets drop the first.

    Args:
        dataset (PatientEventDataset): The dataset to retrieve data from.
        batch_slice (slice): The slice defining the batch of patients to retrieve.

    Returns:
        A tuple of four tensors, each of shape (batch_size, block_length - 1):
        - input_events
        - input_times
        - target_events
        - target_times
    """
    batch_data = dataset[batch_slice]

    # Shift by one position: the sequence up to position t predicts position t + 1.
    input_events = [item[0][:-1] for item in batch_data]
    input_times = [item[1][:-1] for item in batch_data]
    target_events = [item[0][1:] for item in batch_data]
    target_times = [item[1][1:] for item in batch_data]

    input_events = torch.stack(input_events)
    input_times = torch.stack(input_times)
    target_events = torch.stack(target_events)
    target_times = torch.stack(target_times)

    return input_events, input_times, target_events, target_times
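
# A minimal end-to-end sketch combining the pieces above (the .npy path and all
# numbers are invented for illustration):
#
#     data = np.load('patient_records.npy')  # hypothetical (N, 3) uint32 array
#     dataset = PatientEventDataset(data, block_length=128)
#     input_events, input_times, target_events, target_times = get_batch(
#         dataset, slice(0, 32))             # first 32 patients as one batch
#     # Each tensor has shape (32, 127), ready for a time-aware GPT-2 forward pass.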