Files
DeepHealth/train_iter.py
Jiarui Li a631ac6d59 feat: Add load_model function and update training script
Added a `load_model` function to `utils.py` to allow loading of trained models from configuration and state dictionary files.

The `train_iter.py` script was also modified, likely to incorporate or test this new functionality.
2025-10-18 11:07:59 +08:00

219 lines
7.6 KiB
Python

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.utils.data import DataLoader
import numpy as np
import math
import tqdm
import matplotlib.pyplot as plt
import json
import itertools
from models import TimeAwareGPT2, CombinedLoss
from utils import PatientEventDataset
# --- Configuration ---
class TrainConfig:
    """Static settings for the iteration-based training run.

    Every setting is a class attribute, so it can be read either from the
    class itself or from an instance.
    """

    # --- Data ---
    train_data_path = 'ukb_real_train.bin'
    val_data_path = 'ukb_real_val.bin'
    block_length = 48  # sequence length

    # --- Model ---
    n_embd = 120
    n_layer = 12
    n_head = 12
    pdrop = 0.0
    token_pdrop = 0.0

    # --- Training ---
    max_iter = 200_000
    batch_size = 128
    lr_initial = 6e-4
    lr_final = 6e-5
    weight_decay = 0.2
    warmup_iter = 1000

    # --- Loss ---
    # Token 0 is padding and token 1 is "no event". The ids handed to the
    # loss as "ignored" are 0 and 2..12 (example values).
    ignored_token_ids = [0, *range(2, 13)]

    # --- System ---
    # Use the GPU when one is visible to torch, otherwise fall back to CPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
# --- Main Training Script ---
def _config_to_dict(config):
    """Return the public, non-callable settings of *config* as a plain dict.

    ``vars(instance)`` only sees attributes stored on the instance, so a
    config whose settings are declared at class level would serialise as an
    empty dict. Class-level attributes are therefore collected first, and any
    instance-level overrides are layered on top.
    """
    merged = {**vars(type(config)), **vars(config)}
    return {
        name: value
        for name, value in merged.items()
        if not name.startswith('__') and not callable(value)
    }


def _cycle_batches(loader):
    """Yield batches from *loader* forever, starting a new pass when it ends.

    Unlike ``itertools.cycle`` this does not keep a copy of every batch, so
    memory stays flat, and every pass is a fresh iteration of the loader (a
    shuffling DataLoader therefore draws a new order each pass).

    Raises:
        ValueError: if a full pass over *loader* produces no batches.
    """
    while True:
        produced_any = False
        for batch in loader:
            produced_any = True
            yield batch
        if not produced_any:
            raise ValueError("training DataLoader yielded no batches")


def _learning_rate(iter_num, config):
    """Return the learning rate for 1-based iteration *iter_num*.

    The rate is held at ``lr_initial`` while ``iter_num < warmup_iter`` and
    then follows a cosine curve from ``lr_initial`` down to ``lr_final``.
    """
    # NOTE(review): the "warmup" phase keeps the rate constant rather than
    # ramping it up -- confirm this is intended.
    if iter_num < config.warmup_iter:
        return config.lr_initial
    # Guard against a zero-length decay phase (max_iter == warmup_iter).
    decay_span = max(1, config.max_iter - config.warmup_iter)
    progress = (iter_num - config.warmup_iter) / decay_span
    cosine = 0.5 * (1 + math.cos(math.pi * progress))
    return config.lr_final + (config.lr_initial - config.lr_final) * cosine


def _compute_losses(model, loss_fn, event_seq, time_seq, device):
    """Run one forward pass and return ``(loss_ce, loss_survival)``.

    Inputs are every position but the last along dim 1; targets are the next
    event and the difference between consecutive entries of *time_seq*.
    """
    event_seq, time_seq = event_seq.to(device), time_seq.to(device)
    input_events = event_seq[:, :-1]
    input_times = time_seq[:, :-1]
    target_events = event_seq[:, 1:]
    target_wait_times = (time_seq[:, 1:] - time_seq[:, :-1]).float()
    logits = model(input_events, input_times)
    return loss_fn(logits, target_events, target_wait_times)


def main():
    """Train TimeAwareGPT2 for a fixed number of iterations, then validate and save.

    Steps: write the configuration to JSON, memory-map the train/validation
    arrays, train for ``config.max_iter`` optimizer steps, run one pass over
    the validation set, save the model state dict, and write the per-iteration
    training losses to a text file and a PNG plot.
    """
    config = TrainConfig()
    run_tag = f"n_embd_{config.n_embd}_n_layer_{config.n_layer}_n_head_{config.n_head}_iter"
    model_filename = f"best_model_{run_tag}.pt"

    # --- 0. Save Configuration ---
    config_filename = f"config_{run_tag}.json"
    with open(config_filename, 'w') as f:
        json.dump(_config_to_dict(config), f, indent=4)
    print(f"Configuration saved to {config_filename}")

    # --- 1. Data Loading ---
    print(f"Loading data from {config.train_data_path} and {config.val_data_path}...")
    train_data_arr = np.memmap(config.train_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)
    val_data_arr = np.memmap(config.val_data_path, dtype=np.uint32, mode='r').reshape(-1, 3)

    # Infer vocab_size from the data (max label + 1); labels are in column 2.
    vocab_size = int(max(train_data_arr[:, 2].max(), val_data_arr[:, 2].max())) + 1
    print(f"Inferred vocabulary size: {vocab_size}")

    train_dataset = PatientEventDataset(train_data_arr, config.block_length)
    val_dataset = PatientEventDataset(val_data_arr, config.block_length)

    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size, shuffle=False, num_workers=4, pin_memory=True)
    train_iter_loader = _cycle_batches(train_loader)

    # --- 2. Model, Optimizer, and Loss Initialization ---
    print(f"Initializing model on {config.device}...")
    model = TimeAwareGPT2(
        vocab_size=vocab_size,
        n_embd=config.n_embd,
        n_layer=config.n_layer,
        n_head=config.n_head,
        pdrop=config.pdrop,
        token_pdrop=config.token_pdrop
    ).to(config.device)
    print(f"Model initialized with {model.get_num_params():.2f}M trainable parameters.")

    loss_fn = CombinedLoss(config.ignored_token_ids)
    optimizer = AdamW(model.parameters(), lr=config.lr_initial, weight_decay=config.weight_decay, betas=(0.9, 0.99))

    # --- 3. Training Loop ---
    # Per-iteration loss history, written to disk and plotted after training.
    train_losses_ce, train_losses_surv, train_losses_total = [], [], []

    print("Starting training...")
    # Nothing switches the model to eval mode inside the loop, so train mode
    # only needs to be set once.
    model.train()
    pbar = tqdm.tqdm(range(1, config.max_iter + 1), desc="Training")
    for iter_num in pbar:
        # --- Learning Rate Scheduling ---
        lr = _learning_rate(iter_num, config)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # --- Training Step ---
        event_seq, time_seq = next(train_iter_loader)
        loss_ce, loss_survival = _compute_losses(model, loss_fn, event_seq, time_seq, config.device)
        loss = loss_ce + loss_survival

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Convert each loss to a Python float once and reuse the value.
        ce_value = loss_ce.item()
        surv_value = loss_survival.item()
        train_losses_ce.append(ce_value)
        train_losses_surv.append(surv_value)
        train_losses_total.append(loss.item())

        pbar.set_postfix({'loss_ce': f'{ce_value:.4f}', 'loss_surv': f'{surv_value:.4f}', 'lr': f'{lr:.2e}'})

    print("\nTraining finished.")

    # --- 4. Final Validation ---
    print("Running final validation...")
    model.eval()
    val_loss_ce_acc, val_loss_surv_acc = 0.0, 0.0
    val_steps = 0

    with torch.no_grad():
        pbar_val = tqdm.tqdm(val_loader, desc="Final Validation")
        for event_seq, time_seq in pbar_val:
            loss_ce, loss_survival = _compute_losses(model, loss_fn, event_seq, time_seq, config.device)
            ce_value = loss_ce.item()
            surv_value = loss_survival.item()

            val_loss_ce_acc += ce_value
            val_loss_surv_acc += surv_value
            val_steps += 1
            pbar_val.set_postfix({'loss_ce': f'{ce_value:.4f}', 'loss_surv': f'{surv_value:.4f}'})

    if val_steps:
        # Averages are per batch (unweighted mean of the batch losses).
        avg_val_loss_ce = val_loss_ce_acc / val_steps
        avg_val_loss_surv = val_loss_surv_acc / val_steps
        total_val_loss = avg_val_loss_ce + avg_val_loss_surv
        print(f"Final Validation Summary: \n"
              f" Val Loss: {total_val_loss:.4f} (CE: {avg_val_loss_ce:.4f}, Surv: {avg_val_loss_surv:.4f})")
    else:
        # An empty validation loader would otherwise divide by zero.
        print("Validation loader produced no batches; skipping validation summary.")

    # --- 5. Save Model ---
    print(f"Saving final model to {model_filename}")
    torch.save(model.state_dict(), model_filename)

    # --- 6. Save and Plot Losses ---
    losses_filename = f"losses_{run_tag}.txt"
    with open(losses_filename, 'w') as f:
        f.write("iteration,train_loss_ce,train_loss_surv,train_loss_total\n")
        loss_rows = zip(train_losses_ce, train_losses_surv, train_losses_total)
        for i, (ce, surv, total) in enumerate(loss_rows, start=1):
            f.write(f"{i},{ce},{surv},{total}\n")
    print(f"\nLosses saved to {losses_filename}")

    # Plot and Save Loss Curves: one panel each for CE, survival and total.
    iterations = range(1, len(train_losses_total) + 1)
    panels = [
        (train_losses_ce, 'Train CE', 'Cross-Entropy Loss'),
        (train_losses_surv, 'Train Survival', 'Survival Loss'),
        (train_losses_total, 'Train Total', 'Total Loss'),
    ]

    plt.figure(figsize=(18, 5))
    for position, (values, label, title) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(iterations, values, label=label)
        plt.title(title)
        plt.xlabel('Iterations')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)

    plt.tight_layout()
    plt.savefig('loss_curves_iter.png')
    plt.close()  # release the figure once it has been written to disk
    print("\nLoss curves saved to loss_curves_iter.png")
# Start training only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()