Remove EMA model implementation from Trainer class and related parameters from TrainConfig
 train.py | 62 +++-----------------------------------------------------------
 1 file changed, 3 insertions(+), 59 deletions(-)
--- a/train.py
+++ b/train.py
@@ -53,8 +53,6 @@ class TrainConfig:
     grad_clip: float = 1.0
     weight_decay: float = 1e-2
     device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
-    # EMA parameters
-    ema_decay: float = 0.999
 
 
 def parse_args() -> TrainConfig:
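For context, the removed `ema_decay` field doubled as the feature switch: the Trainer (see the `@@ -215,46` hunk below) only built an EMA copy when `ema_decay < 1.0`. A minimal sketch of that convention, with unrelated fields elided:

from dataclasses import dataclass

@dataclass
class TrainConfig:
    # ...other training fields elided...
    ema_decay: float = 0.999  # values < 1.0 enabled the EMA copy (removed in this commit)

cfg = TrainConfig()
use_ema = cfg.ema_decay < 1.0  # the gate the old Trainer.__init__ checked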
@@ -103,8 +101,6 @@ def parse_args() -> TrainConfig:
                         default=1.0, help="Gradient clipping value.")
     parser.add_argument("--weight_decay", type=float,
                         default=1e-2, help="Weight decay for optimizer.")
-    parser.add_argument("--ema_decay", type=float,
-                        default=0.999, help="EMA decay rate.")
     parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                         help="Device to use for training.")
     args = parser.parse_args()
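Because `parse_args()` exits with an "unrecognized arguments" error on unknown flags, any launch script still passing `--ema_decay` will now fail. A hedged migration sketch using the standard-library `parse_known_args`, which collects stale flags instead of exiting (the warning text is illustrative):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--weight_decay", type=float, default=1e-2)

# parse_known_args() returns (namespace, leftovers) rather than exiting
# on flags that no longer exist, such as the removed --ema_decay.
args, unknown = parser.parse_known_args(["--weight_decay", "0.01", "--ema_decay", "0.99"])
if unknown:
    print(f"Ignoring stale flags: {unknown}")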
@@ -215,46 +211,6 @@ class Trainer:
             raise ValueError(f"Unsupported model type: {cfg.model_type}")
         print(f"Model initialized: {cfg.model_type}")
         print(f"Number of trainable parameters: {get_num_params(self.model)}")
-
-        # Initialize EMA model
-        self.ema_model = None
-        if cfg.ema_decay < 1.0:
-            if cfg.model_type == "delphi_fork":
-                self.ema_model = DelphiFork(
-                    n_disease=dataset.n_disease,
-                    n_tech_tokens=2,
-                    n_embd=cfg.n_embd,
-                    n_head=cfg.n_head,
-                    n_layer=cfg.n_layer,
-                    pdrop=cfg.pdrop,
-                    age_encoder_type=cfg.age_encoder,
-                    n_dim=n_dim,
-                    n_cont=dataset.n_cont,
-                    n_cate=dataset.n_cate,
-                    cate_dims=dataset.cate_dims,
-                ).to(self.device)
-            elif cfg.model_type == "sap_delphi":
-                self.ema_model = SapDelphi(
-                    n_disease=dataset.n_disease,
-                    n_tech_tokens=2,
-                    n_embd=cfg.n_embd,
-                    n_head=cfg.n_head,
-                    n_layer=cfg.n_layer,
-                    pdrop=cfg.pdrop,
-                    age_encoder_type=cfg.age_encoder,
-                    n_dim=n_dim,
-                    n_cont=dataset.n_cont,
-                    n_cate=dataset.n_cate,
-                    cate_dims=dataset.cate_dims,
-                    pretrained_emd_path=cfg.pretrained_emd_path,
-                    freeze_pretrained_emd=True,
-                ).to(self.device)
-            else:
-                raise ValueError(f"Unsupported model type: {cfg.model_type}")
-            self.ema_model.load_state_dict(self.model.state_dict())
-            for param in self.ema_model.parameters():
-                param.requires_grad = False
-            print("EMA model initialized.")
         self.optimizer = AdamW(
             self.model.parameters(),
             lr=cfg.max_lr,
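The deleted block re-instantiated the full model class for each supported `model_type` just to hold the averaged weights. A common, shorter equivalent, shown here as a sketch rather than the project's code, is to deep-copy the live model, which also removes the duplicated constructor argument lists:

import copy
import torch

def make_ema_model(model: torch.nn.Module) -> torch.nn.Module:
    # Clone the live model, then detach the clone from training:
    # the optimizer never sees these parameters and autograd skips them.
    ema_model = copy.deepcopy(model)
    for param in ema_model.parameters():
        param.requires_grad = False
    return ema_model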
@@ -285,17 +241,6 @@ class Trainer:
             json.dump(asdict(self.cfg), f, indent=4)
         print(f"Configuration saved to {cfg_path}")
 
-    def update_ema(self):
-        if self.ema_model is None:
-            return
-        decay = self.cfg.ema_decay
-        with torch.no_grad():
-            model_params = dict(self.model.named_parameters())
-            ema_params = dict(self.ema_model.named_parameters())
-            for name in model_params.keys():
-                ema_params[name].data.mul_(decay).add_(
-                    model_params[name].data, alpha=1 - decay)
-
     def compute_lr(self, current_step: int) -> float:
         cfg = self.cfg
         if current_step < cfg.warmup_epochs * len(self.train_loader):
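The deleted `update_ema` is the textbook exponential moving average, theta_ema <- d * theta_ema + (1 - d) * theta. With the default d = 0.999, each step moves the average 0.1% toward the live weights, an effective horizon of roughly 1 / (1 - d) = 1000 steps. A standalone sketch of the same in-place update; like the deleted code, it covers parameters only, not buffers such as BatchNorm running statistics:

import torch

@torch.no_grad()
def update_ema(model: torch.nn.Module, ema_model: torch.nn.Module, decay: float = 0.999):
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        # Tensor.lerp_(end, w) computes self + w * (end - self), which for
        # w = 1 - decay equals decay * ema + (1 - decay) * param; this is
        # mathematically equivalent to the deleted mul_/add_ pair.
        ema_params[name].lerp_(param, 1.0 - decay)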
@@ -373,7 +318,6 @@ class Trainer:
                 clip_grad_norm_(self.model.parameters(),
                                 self.cfg.grad_clip)
                 self.optimizer.step()
-                self.update_ema()
                 self.global_step += 1
 
             if batch_count == 0:
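For the record, the removed call sat immediately after `optimizer.step()`, so the average tracked post-update weights. A skeletal sketch of that ordering, reusing the `update_ema` sketch above (the loss computation is a placeholder, not the project's forward pass):

from torch.nn.utils import clip_grad_norm_

def train_step(model, ema_model, optimizer, batch, cfg):
    optimizer.zero_grad()
    loss = model(batch)  # placeholder for the real forward/loss logic
    loss.backward()
    clip_grad_norm_(model.parameters(), cfg.grad_clip)
    optimizer.step()
    update_ema(model, ema_model, cfg.ema_decay)  # the line this hunk deletes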
@@ -383,7 +327,7 @@ class Trainer:
             train_nll = running_nll / batch_count
             train_reg = running_reg / batch_count
 
-            self.ema_model.eval()
+            self.model.eval()
             total_val_pairs = 0
             total_val_nll = 0.0
             total_val_reg = 0.0
@@ -407,7 +351,7 @@ class Trainer:
                     continue
                 dt, b_prev, t_prev, b_next, t_next = res
                 num_pairs = dt.size(0)
-                logits = self.ema_model(
+                logits = self.model(
                     event_seq,
                     time_seq,
                     sexes,
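These two validation hunks swap inference from the EMA copy back to the live model. Should EMA evaluation be reintroduced later, `torch.optim.swa_utils` offers a maintained alternative to the hand-rolled copy; a sketch assuming PyTorch 2.1 or newer, where `get_ema_multi_avg_fn` was added:

import torch
from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn

model = torch.nn.Linear(4, 2)  # stand-in for DelphiFork / SapDelphi
# AveragedModel owns the EMA copy and applies
# theta_ema <- 0.999 * theta_ema + 0.001 * theta on each update.
ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.999))

ema_model.update_parameters(model)  # call after each optimizer.step()
ema_model.eval()                    # then validate with ema_model(...)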
@@ -465,7 +409,7 @@ class Trainer:
                 torch.save({
                     "epoch": epoch,
                     "global_step": self.global_step,
-                    "model_state_dict": self.ema_model.state_dict(),
+                    "model_state_dict": self.model.state_dict(),
                     "criterion_state_dict": self.criterion.state_dict(),
                     "optimizer_state_dict": self.optimizer.state_dict(),
                 }, self.best_path)
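One migration note: checkpoints written before this commit hold EMA weights under the same "model_state_dict" key that now holds raw weights, so loading code cannot tell the two apart. A resume sketch under that assumption (`build_model` and the checkpoint path are hypothetical placeholders):

import torch

model = build_model()  # hypothetical; mirrors the Trainer's model setup
optimizer = torch.optim.AdamW(model.parameters())

checkpoint = torch.load("best.pt", map_location="cpu")  # placeholder path
# Pre-commit checkpoints: EMA weights; post-commit: raw weights. Same key.
model.load_state_dict(checkpoint["model_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
start_epoch = checkpoint["epoch"] + 1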