Remove EMA model implementation from Trainer class and related parameters from TrainConfig

2026-01-08 13:14:29 +08:00
parent 615e2fe748
commit 5382f9f159

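For context, the deleted code kept a frozen exponential-moving-average (EMA) copy of the model and, after every optimizer step, blended it toward the live weights as ema = decay * ema + (1 - decay) * model (see the removed update_ema method below). A minimal, self-contained sketch of that update rule, with a plain torch.nn.Linear standing in for the project's DelphiFork / SapDelphi models:

import torch
import torch.nn as nn

# Stand-in modules; the removed code used DelphiFork / SapDelphi instances.
model = nn.Linear(8, 3)
ema_model = nn.Linear(8, 3)
ema_model.load_state_dict(model.state_dict())
for p in ema_model.parameters():
    p.requires_grad = False          # the EMA copy is never trained directly

decay = 0.999                        # the removed TrainConfig.ema_decay default
with torch.no_grad():                # run once after each optimizer.step()
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        ema_params[name].data.mul_(decay).add_(param.data, alpha=1 - decay)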

@@ -53,8 +53,6 @@ class TrainConfig:
     grad_clip: float = 1.0
     weight_decay: float = 1e-2
     device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
-    # EMA parameters
-    ema_decay: float = 0.999
 def parse_args() -> TrainConfig:
@@ -103,8 +101,6 @@ def parse_args() -> TrainConfig:
                         default=1.0, help="Gradient clipping value.")
     parser.add_argument("--weight_decay", type=float,
                         default=1e-2, help="Weight decay for optimizer.")
-    parser.add_argument("--ema_decay", type=float,
-                        default=0.999, help="EMA decay rate.")
     parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                         help="Device to use for training.")
     args = parser.parse_args()
@@ -215,46 +211,6 @@ class Trainer:
             raise ValueError(f"Unsupported model type: {cfg.model_type}")
         print(f"Model initialized: {cfg.model_type}")
         print(f"Number of trainable parameters: {get_num_params(self.model)}")
-        # Initialize EMA model
-        self.ema_model = None
-        if cfg.ema_decay < 1.0:
-            if cfg.model_type == "delphi_fork":
-                self.ema_model = DelphiFork(
-                    n_disease=dataset.n_disease,
-                    n_tech_tokens=2,
-                    n_embd=cfg.n_embd,
-                    n_head=cfg.n_head,
-                    n_layer=cfg.n_layer,
-                    pdrop=cfg.pdrop,
-                    age_encoder_type=cfg.age_encoder,
-                    n_dim=n_dim,
-                    n_cont=dataset.n_cont,
-                    n_cate=dataset.n_cate,
-                    cate_dims=dataset.cate_dims,
-                ).to(self.device)
-            elif cfg.model_type == "sap_delphi":
-                self.ema_model = SapDelphi(
-                    n_disease=dataset.n_disease,
-                    n_tech_tokens=2,
-                    n_embd=cfg.n_embd,
-                    n_head=cfg.n_head,
-                    n_layer=cfg.n_layer,
-                    pdrop=cfg.pdrop,
-                    age_encoder_type=cfg.age_encoder,
-                    n_dim=n_dim,
-                    n_cont=dataset.n_cont,
-                    n_cate=dataset.n_cate,
-                    cate_dims=dataset.cate_dims,
-                    pretrained_emd_path=cfg.pretrained_emd_path,
-                    freeze_pretrained_emd=True,
-                ).to(self.device)
-            else:
-                raise ValueError(f"Unsupported model type: {cfg.model_type}")
-            self.ema_model.load_state_dict(self.model.state_dict())
-            for param in self.ema_model.parameters():
-                param.requires_grad = False
-            print("EMA model initialized.")
         self.optimizer = AdamW(
             self.model.parameters(),
             lr=cfg.max_lr,
@@ -285,17 +241,6 @@ class Trainer:
             json.dump(asdict(self.cfg), f, indent=4)
         print(f"Configuration saved to {cfg_path}")
-    def update_ema(self):
-        if self.ema_model is None:
-            return
-        decay = self.cfg.ema_decay
-        with torch.no_grad():
-            model_params = dict(self.model.named_parameters())
-            ema_params = dict(self.ema_model.named_parameters())
-            for name in model_params.keys():
-                ema_params[name].data.mul_(decay).add_(
-                    model_params[name].data, alpha=1 - decay)
     def compute_lr(self, current_step: int) -> float:
         cfg = self.cfg
         if current_step < cfg.warmup_epochs * len(self.train_loader):
@@ -373,7 +318,6 @@
                 clip_grad_norm_(self.model.parameters(),
                                 self.cfg.grad_clip)
                 self.optimizer.step()
-                self.update_ema()
                 self.global_step += 1
             if batch_count == 0:
@@ -383,7 +327,7 @@
             train_nll = running_nll / batch_count
             train_reg = running_reg / batch_count
-            self.ema_model.eval()
+            self.model.eval()
             total_val_pairs = 0
             total_val_nll = 0.0
             total_val_reg = 0.0
@@ -407,7 +351,7 @@
                         continue
                     dt, b_prev, t_prev, b_next, t_next = res
                     num_pairs = dt.size(0)
-                    logits = self.ema_model(
+                    logits = self.model(
                         event_seq,
                         time_seq,
                         sexes,
@@ -465,7 +409,7 @@
                 torch.save({
                     "epoch": epoch,
                     "global_step": self.global_step,
-                    "model_state_dict": self.ema_model.state_dict(),
+                    "model_state_dict": self.model.state_dict(),
                     "criterion_state_dict": self.criterion.state_dict(),
                     "optimizer_state_dict": self.optimizer.state_dict(),
                 }, self.best_path)
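After this change the best-model checkpoint stores the raw model weights rather than the EMA copy, under the same "model_state_dict" key. A hedged sketch of restoring such a checkpoint, again with torch.nn.Linear as a stand-in for the real model class and "best.pt" as a placeholder path:

import torch
import torch.nn as nn

model = nn.Linear(8, 3)                            # stand-in for the real model class
torch.save({"epoch": 0, "global_step": 0,
            "model_state_dict": model.state_dict()}, "best.pt")   # placeholder checkpoint

ckpt = torch.load("best.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])    # raw weights, no EMA copy
print(f"Restored epoch {ckpt['epoch']}, global step {ckpt['global_step']}")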