update evaluate

This commit is contained in:
2025-10-20 13:47:50 +08:00
parent 1c9e2a2fb3
commit 8f44018bae
3 changed files with 182 additions and 76 deletions

View File

@@ -42,17 +42,22 @@ class PatientEventDataset(torch.utils.data.Dataset):
"""
return len(self.patient_ids)
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
def __getitem__(self, idx):
"""
Retrieves, processes, and returns a single patient's event sequence.
Retrieves, processes, and returns a single patient's event sequence,
or a list of sequences if a slice is provided.
Args:
idx (int): The index of the patient to retrieve.
idx (int or slice): The index or slice of the patient(s) to retrieve.
Returns:
A tuple of two torch.long tensors: (event_sequence, time_sequence),
both of shape (block_length,).
If idx is an int, a tuple of two torch.long tensors:
(event_sequence, time_sequence), both of shape (block_length,).
If idx is a slice, a list of such tuples.
"""
if isinstance(idx, slice):
return [self[i] for i in range(*idx.indices(len(self)))]
# 1. Retrieve and Sort
patient_id = self.patient_ids[idx]
records = sorted(self.patient_events[patient_id], key=lambda x: x[0])
@@ -150,3 +155,35 @@ def load_model(config_path, model_path, vocab_size, device='cpu'):
print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
return model
def get_batch(dataset: "PatientEventDataset", batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Retrieves a batch of data from a PatientEventDataset and prepares it for model training.

    Targets are the inputs shifted right by one position (next-token prediction):
    for each patient sequence of length L, inputs are tokens [0, L-1) and targets
    are tokens [1, L).

    Args:
        dataset (PatientEventDataset): The dataset to retrieve data from. Its
            ``__getitem__`` must accept a slice and return a list of
            (event_sequence, time_sequence) tensor pairs.
        batch_slice (slice): The slice defining the batch of patients to retrieve.

    Returns:
        A tuple containing four tensors, each of shape
        (batch_size, sequence_length - 1):
            - input_events
            - input_times
            - target_events
            - target_times

    Raises:
        ValueError: If ``batch_slice`` selects no patients, since stacking an
            empty batch is meaningless.
    """
    batch_data = dataset[batch_slice]
    if not batch_data:
        # torch.stack raises a cryptic RuntimeError on an empty list; fail clearly.
        raise ValueError("batch_slice selected no patients; cannot build a batch")
    # Shift-by-one split: drop the last token for inputs, the first for targets.
    input_events = torch.stack([events[:-1] for events, _ in batch_data])
    input_times = torch.stack([times[:-1] for _, times in batch_data])
    target_events = torch.stack([events[1:] for events, _ in batch_data])
    target_times = torch.stack([times[1:] for _, times in batch_data])
    return input_events, input_times, target_events, target_times