From 5382f9f15908976f5165a38c82fcfce0b2702af5 Mon Sep 17 00:00:00 2001
From: Jiarui Li
Date: Thu, 8 Jan 2026 13:14:29 +0800
Subject: [PATCH] Remove EMA model implementation from Trainer class and
 related parameters from TrainConfig

---
 train.py | 62 +++-----------------------------------------------------
 1 file changed, 3 insertions(+), 59 deletions(-)

diff --git a/train.py b/train.py
index fedf03b..78ce0c0 100644
--- a/train.py
+++ b/train.py
@@ -53,8 +53,6 @@ class TrainConfig:
     grad_clip: float = 1.0
     weight_decay: float = 1e-2
     device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
-    # EMA parameters
-    ema_decay: float = 0.999
 
 
 def parse_args() -> TrainConfig:
@@ -103,8 +101,6 @@ def parse_args() -> TrainConfig:
                         default=1.0, help="Gradient clipping value.")
     parser.add_argument("--weight_decay", type=float,
                         default=1e-2, help="Weight decay for optimizer.")
-    parser.add_argument("--ema_decay", type=float,
-                        default=0.999, help="EMA decay rate.")
     parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                         help="Device to use for training.")
     args = parser.parse_args()
@@ -215,46 +211,6 @@ class Trainer:
             raise ValueError(f"Unsupported model type: {cfg.model_type}")
         print(f"Model initialized: {cfg.model_type}")
         print(f"Number of trainable parameters: {get_num_params(self.model)}")
-
-        # Initialize EMA model
-        self.ema_model = None
-        if cfg.ema_decay < 1.0:
-            if cfg.model_type == "delphi_fork":
-                self.ema_model = DelphiFork(
-                    n_disease=dataset.n_disease,
-                    n_tech_tokens=2,
-                    n_embd=cfg.n_embd,
-                    n_head=cfg.n_head,
-                    n_layer=cfg.n_layer,
-                    pdrop=cfg.pdrop,
-                    age_encoder_type=cfg.age_encoder,
-                    n_dim=n_dim,
-                    n_cont=dataset.n_cont,
-                    n_cate=dataset.n_cate,
-                    cate_dims=dataset.cate_dims,
-                ).to(self.device)
-            elif cfg.model_type == "sap_delphi":
-                self.ema_model = SapDelphi(
-                    n_disease=dataset.n_disease,
-                    n_tech_tokens=2,
-                    n_embd=cfg.n_embd,
-                    n_head=cfg.n_head,
-                    n_layer=cfg.n_layer,
-                    pdrop=cfg.pdrop,
-                    age_encoder_type=cfg.age_encoder,
-                    n_dim=n_dim,
-                    n_cont=dataset.n_cont,
-                    n_cate=dataset.n_cate,
-                    cate_dims=dataset.cate_dims,
-                    pretrained_emd_path=cfg.pretrained_emd_path,
-                    freeze_pretrained_emd=True,
-                ).to(self.device)
-            else:
-                raise ValueError(f"Unsupported model type: {cfg.model_type}")
-            self.ema_model.load_state_dict(self.model.state_dict())
-            for param in self.ema_model.parameters():
-                param.requires_grad = False
-            print("EMA model initialized.")
         self.optimizer = AdamW(
             self.model.parameters(),
             lr=cfg.max_lr,
@@ -285,17 +241,6 @@ class Trainer:
             json.dump(asdict(self.cfg), f, indent=4)
         print(f"Configuration saved to {cfg_path}")
 
-    def update_ema(self):
-        if self.ema_model is None:
-            return
-        decay = self.cfg.ema_decay
-        with torch.no_grad():
-            model_params = dict(self.model.named_parameters())
-            ema_params = dict(self.ema_model.named_parameters())
-            for name in model_params.keys():
-                ema_params[name].data.mul_(decay).add_(
-                    model_params[name].data, alpha=1 - decay)
-
     def compute_lr(self, current_step: int) -> float:
         cfg = self.cfg
         if current_step < cfg.warmup_epochs * len(self.train_loader):
@@ -373,7 +318,6 @@ class Trainer:
                 clip_grad_norm_(self.model.parameters(), self.cfg.grad_clip)
 
                 self.optimizer.step()
-                self.update_ema()
                 self.global_step += 1
 
             if batch_count == 0:
@@ -383,7 +327,7 @@ class Trainer:
             train_nll = running_nll / batch_count
             train_reg = running_reg / batch_count
 
-            self.ema_model.eval()
+            self.model.eval()
             total_val_pairs = 0
             total_val_nll = 0.0
             total_val_reg = 0.0
@@ -407,7 +351,7 @@ class Trainer:
                     continue
                 dt, b_prev, t_prev, b_next, t_next = res
                 num_pairs = dt.size(0)
-                logits = self.ema_model(
+                logits = self.model(
                     event_seq,
                     time_seq,
                     sexes,
@@ -465,7 +409,7 @@ class Trainer:
                 torch.save({
                     "epoch": epoch,
                     "global_step": self.global_step,
-                    "model_state_dict": self.ema_model.state_dict(),
+                    "model_state_dict": self.model.state_dict(),
                     "criterion_state_dict": self.criterion.state_dict(),
                     "optimizer_state_dict": self.optimizer.state_dict(),
                 }, self.best_path)
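
Note (not part of the patch itself): the deleted update_ema() implemented a
standard per-parameter exponential moving average, ema = decay * ema +
(1 - decay) * param, applied after every optimizer.step(). Should shadow
weights ever be needed again, a minimal self-contained sketch follows; the
update_ema helper and the torch.nn.Linear stand-in are illustrative only and
do not exist in the repository after this patch.

    import copy
    import torch

    @torch.no_grad()
    def update_ema(ema_model: torch.nn.Module,
                   model: torch.nn.Module,
                   decay: float = 0.999) -> None:
        # ema = decay * ema + (1 - decay) * param, per parameter;
        # 0.999 matches the TrainConfig.ema_decay default this patch removes.
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1 - decay)

    model = torch.nn.Linear(4, 2)      # stand-in for DelphiFork / SapDelphi
    ema_model = copy.deepcopy(model)   # frozen shadow copy
    for p in ema_model.parameters():
        p.requires_grad = False

    update_ema(ema_model, model)       # call once after each optimizer.step()

In recent PyTorch, torch.optim.swa_utils.AveragedModel (with
get_ema_multi_avg_fn) provides the same behavior without hand-rolled
bookkeeping.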