Files
DeepHealth/utils.py

256 lines
10 KiB
Python
Raw Normal View History

import os
import torch
import numpy as np
import random
from collections import defaultdict
import json
from models import TimeAwareGPT2, TimeAwareGPT2Learnable
class PatientEventDataset(torch.utils.data.Dataset):
    """
    Per-patient event sequences for training a sequence model.

    At construction the raw (patient_id, time_in_days, event_code) rows are
    grouped by patient. Each ``__getitem__`` call then sorts that patient's
    events by time, fills long gaps with "no event" tokens, and crops or pads
    the result to exactly ``block_length`` steps.
    """

    def __init__(self, data: np.ndarray, block_length: int):
        """
        Group the raw event rows by patient.

        Args:
            data (np.ndarray): Array of shape (N, 3) whose rows are
                (patient_id, time_in_days, event_code); documented as
                dtype=np.uint32.
            block_length (int): Length of every sequence returned by
                ``__getitem__``.
        """
        self.block_length = block_length

        events_by_patient = defaultdict(list)
        for pid, day, code in data:
            events_by_patient[pid].append((day, code))

        # Plain dict for lookups; its insertion order (first appearance of each
        # patient in ``data``) defines the index -> patient mapping.
        self.patient_events = dict(events_by_patient)
        self.patient_ids = list(self.patient_events)

    def __len__(self) -> int:
        """Return the number of distinct patients (one item per patient)."""
        return len(self.patient_ids)

    def __getitem__(self, idx):
        """
        Build the fixed-length sequence(s) for one patient or a slice of patients.

        Processing per patient:
          1. Sort the patient's (time, event) records by time. The sort is
             stable, so same-day records keep their original order.
          2. For every gap g >= 1826 days between consecutive records, insert
             g // 1826 records with event code 1 ("no event"), spaced 1826 days
             apart starting from the earlier record.
          3. If the result is longer than ``block_length``, keep a random
             contiguous window chosen with ``random.randint``. If it is
             shorter, right-pad it with (time=36525, event=0).

        Args:
            idx (int or slice): Patient index, or a slice of indices.

        Returns:
            For an int: a tuple (event_sequence, time_sequence) of torch.long
            tensors, each of shape (block_length,).
            For a slice: a list of such tuples.
        """
        if isinstance(idx, slice):
            # Slicing a range yields exactly the indices idx.indices(len) would.
            return [self[i] for i in range(len(self))[idx]]

        timeline = sorted(self.patient_events[self.patient_ids[idx]], key=lambda rec: rec[0])

        filled = []
        previous_day = None
        for day, code in timeline:
            if previous_day is not None:
                # Gaps shorter than 1826 days yield an empty range (no fillers).
                # NOTE(review): when the gap is an exact multiple of 1826, the
                # last filler falls on the same day as the next real event.
                for k in range(1, (day - previous_day) // 1826 + 1):
                    filled.append((previous_day + k * 1826, 1))
            filled.append((day, code))
            previous_day = day

        overflow = len(filled) - self.block_length
        if overflow > 0:
            start = random.randint(0, overflow)
            window = filled[start:start + self.block_length]
        else:
            # overflow == 0 appends nothing; overflow < 0 pads to block_length.
            window = filled + [(36525, 0)] * (-overflow)

        days = [day for day, _ in window]
        codes = [code for _, code in window]
        return torch.tensor(codes, dtype=torch.long), torch.tensor(days, dtype=torch.long)
def load_model(config_path: str, device: str = 'cpu'):
    """
    Load a trained model described by a training configuration file.

    The checkpoint filename is inferred from the configuration. According to
    train.py, models may be either 'TimeAwareGPT2' or 'TimeAwareGPT2Learnable'.
    This function:

    - Reads the config JSON to get the architecture hyperparameters.
    - Locates the checkpoint. Two filenames are tried: the newer one that
      embeds ``model_name`` and the legacy one without it. Each is tried first
      relative to the current working directory, then in the directory
      containing the config file.
    - Selects the model class from ``config.model_name``. The default is
      TimeAwareGPT2; an unknown name falls back to TimeAwareGPT2 with a
      warning.
    - Infers ``vocab_size`` from the checkpoint's tensor shapes.
    - Loads the weights (non-strict) and returns the model in eval mode on
      the requested device.

    Args:
        config_path: Path to the JSON configuration file saved during training.
        device: Target device for the returned model, e.g. 'cpu' or 'cuda'.

    Returns:
        torch.nn.Module: Loaded model ready for inference.

    Raises:
        AttributeError: If ``n_embd``, ``n_layer`` or ``n_head`` is missing
            from the config.
        FileNotFoundError: If none of the candidate checkpoint files exists.
        ValueError: If ``vocab_size`` cannot be inferred from the checkpoint.
    """
    # 1) Read config
    with open(config_path, 'r', encoding='utf-8') as f:
        config_dict = json.load(f)

    # Attribute-style access: a missing key raises AttributeError, so
    # getattr(config, key, default) gives tolerant lookups for optional fields.
    class AttrDict(dict):
        def __getattr__(self, item):
            try:
                return self[item]
            except KeyError:
                raise AttributeError(item) from None

    config = AttrDict(config_dict)
    model_name = getattr(config, 'model_name', 'TimeAwareGPT2')
    n_embd = config.n_embd
    n_layer = config.n_layer
    n_head = config.n_head

    # 2) Locate the checkpoint.
    # The newer naming (includes model_name) is used by train.py when model_name
    # is present. The older naming (without model_name) matches existing files.
    ckpt_with_model = f"best_model_{model_name}_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}.pt"
    ckpt_legacy = f"best_model_n_embd_{n_embd}_n_layer_{n_layer}_n_head_{n_head}.pt"

    # '' means "relative to the current working directory" (the original lookup).
    # The config's own directory is only a fallback, so existing setups that
    # run from the checkpoint directory resolve exactly as before.
    search_dirs = ['']
    config_dir = os.path.dirname(config_path)
    if config_dir and os.path.abspath(config_dir) != os.getcwd():
        search_dirs.append(config_dir)
    candidates = [
        os.path.join(directory, name)
        for directory in search_dirs
        for name in (ckpt_with_model, ckpt_legacy)
    ]
    model_path = next((path for path in candidates if os.path.exists(path)), None)
    if model_path is None:
        # Same exception type torch.load would raise, but naming every path tried.
        raise FileNotFoundError(
            f"Could not find a checkpoint for config {config_path!r}. "
            f"Expected one of: {', '.join(candidates)}"
        )

    # 3) Decide model class (train.py supports two variants)
    model_classes = {
        'TimeAwareGPT2': TimeAwareGPT2,
        'TimeAwareGPT2Learnable': TimeAwareGPT2Learnable,
    }
    model_cls = model_classes.get(model_name)
    if model_cls is None:
        print(f"Warning: Unknown model_name {model_name!r}; falling back to TimeAwareGPT2.")
        model_cls = TimeAwareGPT2

    # 4) Load the checkpoint once, on CPU. load_state_dict below copies the
    # tensors into the parameters of the model already moved to `device`.
    # NOTE: torch.load deserializes with pickle; only load trusted checkpoints.
    state_dict = torch.load(model_path, map_location='cpu')

    # 5) Infer vocab_size from the checkpoint.
    if 'wte.weight' in state_dict:
        vocab_size = state_dict['wte.weight'].shape[0]
    elif 'head.weight' in state_dict:
        vocab_size = state_dict['head.weight'].shape[0]
    else:
        # Heuristic fallback: the largest dimension among all 2-D tensors.
        # NOTE(review): this can pick a non-vocabulary dimension if some weight
        # matrix is wider than the vocabulary.
        dims = [
            max(v.shape)
            for v in state_dict.values()
            if isinstance(v, torch.Tensor) and v.ndim == 2
        ]
        if not dims:
            raise ValueError("Unable to infer vocab_size from checkpoint. Unknown tensor shapes.")
        vocab_size = max(dims)

    # 6) Build model from config (tolerant to missing dropout fields).
    pdrop = getattr(config, 'pdrop', 0.1)
    token_pdrop = getattr(config, 'token_pdrop', 0.1)
    model = model_cls(
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_layer=n_layer,
        n_head=n_head,
        pdrop=pdrop,
        token_pdrop=token_pdrop,
    ).to(device)

    # 7) Load weights non-strictly, reporting any key mismatches.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        print(f"Warning: Missing keys when loading state_dict: {missing}")
    if unexpected:
        print(f"Warning: Unexpected keys when loading state_dict: {unexpected}")
    model.eval()

    # The parameter-count report is informational only; never fail the load
    # because get_num_params is missing or returns something unformattable.
    try:
        num_params_m = model.get_num_params()
        print(f"Model loaded from {model_path} with {num_params_m:.2f}M parameters.")
    except Exception:
        pass
    return model
def get_batch(dataset: "PatientEventDataset", batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Build a next-step-prediction batch from a slice of a PatientEventDataset.

    Each selected item is an (event_sequence, time_sequence) pair of 1-D
    tensors. All items must share the same length L for ``torch.stack``.
    Inputs are positions 0..L-2. Targets are the same sequences shifted left
    by one (positions 1..L-1).

    Args:
        dataset (PatientEventDataset): The dataset to retrieve data from.
            ``dataset[batch_slice]`` must return a list of
            (event_tensor, time_tensor) pairs.
        batch_slice (slice): The slice defining the batch of patients to retrieve.

    Returns:
        A tuple containing four contiguous tensors, each of shape
        (batch_size, sequence_length - 1):
            - input_events
            - input_times
            - target_events
            - target_times

    Note:
        An empty slice is an error, because ``torch.stack`` rejects an empty
        list.
    """
    batch_data = dataset[batch_slice]

    # Slice each item first, then stack. This yields contiguous tensors,
    # unlike stacking first and slicing the stacked columns afterwards.
    input_events = torch.stack([item[0][:-1] for item in batch_data])
    input_times = torch.stack([item[1][:-1] for item in batch_data])
    target_events = torch.stack([item[0][1:] for item in batch_data])
    target_times = torch.stack([item[1][1:] for item in batch_data])
    return input_events, input_times, target_events, target_times