update evaluate
utils.py: 47 lines changed
@@ -42,17 +42,22 @@ class PatientEventDataset(torch.utils.data.Dataset):
         """
         return len(self.patient_ids)
 
-    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
+    def __getitem__(self, idx):
         """
-        Retrieves, processes, and returns a single patient's event sequence.
+        Retrieves, processes, and returns a single patient's event sequence,
+        or a list of sequences if a slice is provided.
 
         Args:
-            idx (int): The index of the patient to retrieve.
+            idx (int or slice): The index or slice of the patient(s) to retrieve.
 
         Returns:
-            A tuple of two torch.long tensors: (event_sequence, time_sequence),
-            both of shape (block_length,).
+            If idx is an int, a tuple of two torch.long tensors:
+            (event_sequence, time_sequence), both of shape (block_length,).
+            If idx is a slice, a list of such tuples.
         """
+        if isinstance(idx, slice):
+            return [self[i] for i in range(*idx.indices(len(self)))]
+
         # 1. Retrieve and Sort
         patient_id = self.patient_ids[idx]
         records = sorted(self.patient_events[patient_id], key=lambda x: x[0])
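A minimal sketch of the slice-dispatch pattern the new __getitem__ uses. ToyDataset is a hypothetical stand-in (the PatientEventDataset constructor is not shown in this diff), and the per-item processing is replaced with a dummy value:

class ToyDataset:
    def __init__(self, n):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        if isinstance(idx, slice):
            # slice.indices(len(self)) clamps start/stop/step to the valid
            # range, so an out-of-bounds slice never raises an IndexError.
            return [self[i] for i in range(*idx.indices(len(self)))]
        return idx * 10  # stand-in for the real per-patient processing

ds = ToyDataset(5)
print(ds[1])      # 10
print(ds[2:100])  # [20, 30, 40]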
@@ -150,3 +155,33 @@ def load_model(config_path, model_path, vocab_size, device='cpu'):
 
     print(f"Model loaded from {model_path} with {model.get_num_params():.2f}M parameters.")
     return model
+
+
+def get_batch(dataset: PatientEventDataset, batch_slice: slice) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    """
+    Retrieves a batch of data from a PatientEventDataset and prepares it for model training.
+
+    Args:
+        dataset (PatientEventDataset): The dataset to retrieve data from.
+        batch_slice (slice): The slice defining the batch of patients to retrieve.
+
+    Returns:
+        A tuple containing four tensors:
+        - input_events: (batch_size, sequence_length - 1)
+        - input_times: (batch_size, sequence_length - 1)
+        - target_events: (batch_size, sequence_length - 1)
+        - target_times: (batch_size, sequence_length - 1)
+    """
+    batch_data = dataset[batch_slice]
+
+    input_events = [item[0][:-1] for item in batch_data]
+    input_times = [item[1][:-1] for item in batch_data]
+    target_events = [item[0][1:] for item in batch_data]
+    target_times = [item[1][1:] for item in batch_data]
+
+    input_events = torch.stack(input_events)
+    input_times = torch.stack(input_times)
+    target_events = torch.stack(target_events)
+    target_times = torch.stack(target_times)
+
+    return input_events, input_times, target_events, target_times
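A minimal, self-contained sketch of the shift-and-stack scheme get_batch applies, using randomly generated tensors in place of real PatientEventDataset output (block_length, the batch size of 4, and the value ranges are assumptions for illustration):

import torch

# Stand-in for dataset[batch_slice]: a list of (event_sequence, time_sequence)
# pairs, each of shape (block_length,), as __getitem__ returns them.
block_length = 6
batch_data = [
    (torch.randint(0, 100, (block_length,)), torch.randint(0, 365, (block_length,)))
    for _ in range(4)
]

# Inputs are positions 0..L-2, targets are positions 1..L-1, so the model at
# position t is trained to predict the event/time at position t+1.
input_events = torch.stack([ev[:-1] for ev, _ in batch_data])
input_times = torch.stack([tm[:-1] for _, tm in batch_data])
target_events = torch.stack([ev[1:] for ev, _ in batch_data])
target_times = torch.stack([tm[1:] for _, tm in batch_data])

print(input_events.shape)   # torch.Size([4, 5]) -> (batch_size, block_length - 1)
print(target_events.shape)  # torch.Size([4, 5])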