Remove EMA model implementation from Trainer class and related parameters from TrainConfig
train.py
@@ -53,8 +53,6 @@ class TrainConfig:
     grad_clip: float = 1.0
     weight_decay: float = 1e-2
     device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
-    # EMA parameters
-    ema_decay: float = 0.999


 def parse_args() -> TrainConfig:
@@ -103,8 +101,6 @@ def parse_args() -> TrainConfig:
                         default=1.0, help="Gradient clipping value.")
     parser.add_argument("--weight_decay", type=float,
                         default=1e-2, help="Weight decay for optimizer.")
-    parser.add_argument("--ema_decay", type=float,
-                        default=0.999, help="EMA decay rate.")
     parser.add_argument("--device", type=str, default='cuda' if torch.cuda.is_available() else 'cpu',
                         help="Device to use for training.")
     args = parser.parse_args()
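
With the `--ema_decay` flag gone, argparse now exits with an "unrecognized arguments" error if it is passed; a typical invocation after this change (the flag values shown are just the defaults above):

    python train.py --grad_clip 1.0 --weight_decay 1e-2 --device cuda
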
@@ -215,46 +211,6 @@ class Trainer:
             raise ValueError(f"Unsupported model type: {cfg.model_type}")
         print(f"Model initialized: {cfg.model_type}")
         print(f"Number of trainable parameters: {get_num_params(self.model)}")

-        # Initialize EMA model
-        self.ema_model = None
-        if cfg.ema_decay < 1.0:
-            if cfg.model_type == "delphi_fork":
-                self.ema_model = DelphiFork(
-                    n_disease=dataset.n_disease,
-                    n_tech_tokens=2,
-                    n_embd=cfg.n_embd,
-                    n_head=cfg.n_head,
-                    n_layer=cfg.n_layer,
-                    pdrop=cfg.pdrop,
-                    age_encoder_type=cfg.age_encoder,
-                    n_dim=n_dim,
-                    n_cont=dataset.n_cont,
-                    n_cate=dataset.n_cate,
-                    cate_dims=dataset.cate_dims,
-                ).to(self.device)
-            elif cfg.model_type == "sap_delphi":
-                self.ema_model = SapDelphi(
-                    n_disease=dataset.n_disease,
-                    n_tech_tokens=2,
-                    n_embd=cfg.n_embd,
-                    n_head=cfg.n_head,
-                    n_layer=cfg.n_layer,
-                    pdrop=cfg.pdrop,
-                    age_encoder_type=cfg.age_encoder,
-                    n_dim=n_dim,
-                    n_cont=dataset.n_cont,
-                    n_cate=dataset.n_cate,
-                    cate_dims=dataset.cate_dims,
-                    pretrained_emd_path=cfg.pretrained_emd_path,
-                    freeze_pretrained_emd=True,
-                ).to(self.device)
-            else:
-                raise ValueError(f"Unsupported model type: {cfg.model_type}")
-            self.ema_model.load_state_dict(self.model.state_dict())
-            for param in self.ema_model.parameters():
-                param.requires_grad = False
-            print("EMA model initialized.")
         self.optimizer = AdamW(
             self.model.parameters(),
             lr=cfg.max_lr,
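
For context, the removed block duplicated the EMA network by re-passing every constructor argument for each supported model type. A model-agnostic alternative is to deep-copy the live model; this is a hypothetical sketch, not code from this repo:

    import copy

    import torch


    def make_shadow(model: torch.nn.Module) -> torch.nn.Module:
        # Clone the architecture together with its current weights,
        # replacing the per-model-type constructor calls above.
        shadow = copy.deepcopy(model)
        for param in shadow.parameters():
            param.requires_grad = False  # the shadow is updated manually, never trained
        return shadow
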
@@ -285,17 +241,6 @@ class Trainer:
             json.dump(asdict(self.cfg), f, indent=4)
         print(f"Configuration saved to {cfg_path}")

-    def update_ema(self):
-        if self.ema_model is None:
-            return
-        decay = self.cfg.ema_decay
-        with torch.no_grad():
-            model_params = dict(self.model.named_parameters())
-            ema_params = dict(self.ema_model.named_parameters())
-            for name in model_params.keys():
-                ema_params[name].data.mul_(decay).add_(
-                    model_params[name].data, alpha=1 - decay)
-
     def compute_lr(self, current_step: int) -> float:
         cfg = self.cfg
         if current_step < cfg.warmup_epochs * len(self.train_loader):
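
The deleted `update_ema` applied the standard exponential-moving-average rule θ_ema ← d·θ_ema + (1 − d)·θ after every optimizer step; with d = 0.999 this averages over roughly 1/(1 − d) = 1000 recent steps. A minimal standalone equivalent (names are illustrative; like the original, it tracks parameters only, not buffers such as BatchNorm statistics):

    import torch


    @torch.no_grad()
    def ema_update(ema_model: torch.nn.Module, model: torch.nn.Module,
                   decay: float = 0.999) -> None:
        # theta_ema <- decay * theta_ema + (1 - decay) * theta
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1.0 - decay)
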
@@ -373,7 +318,6 @@ class Trainer:
                 clip_grad_norm_(self.model.parameters(),
                                 self.cfg.grad_clip)
             self.optimizer.step()
-            self.update_ema()
             self.global_step += 1

         if batch_count == 0:
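
Should EMA come back, the hand-rolled hook that used to follow `self.optimizer.step()` can be replaced by PyTorch's built-in averaging utilities; a sketch assuming PyTorch >= 2.1, where `get_ema_multi_avg_fn` is available:

    import torch
    from torch.optim.swa_utils import AveragedModel, get_ema_multi_avg_fn

    model = torch.nn.Linear(4, 2)  # stand-in for DelphiFork / SapDelphi
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    ema_model = AveragedModel(model, multi_avg_fn=get_ema_multi_avg_fn(0.999))

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    ema_model.update_parameters(model)  # one-line replacement for update_ema()
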
@@ -383,7 +327,7 @@ class Trainer:
         train_nll = running_nll / batch_count
         train_reg = running_reg / batch_count

-        self.ema_model.eval()
+        self.model.eval()
         total_val_pairs = 0
         total_val_nll = 0.0
         total_val_reg = 0.0
@@ -407,7 +351,7 @@ class Trainer:
                 continue
             dt, b_prev, t_prev, b_next, t_next = res
             num_pairs = dt.size(0)
-            logits = self.ema_model(
+            logits = self.model(
                 event_seq,
                 time_seq,
                 sexes,
@@ -465,7 +409,7 @@ class Trainer:
             torch.save({
                 "epoch": epoch,
                 "global_step": self.global_step,
-                "model_state_dict": self.ema_model.state_dict(),
+                "model_state_dict": self.model.state_dict(),
                 "criterion_state_dict": self.criterion.state_dict(),
                 "optimizer_state_dict": self.optimizer.state_dict(),
             }, self.best_path)
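
Checkpoints written after this change store the raw model's weights under the same `model_state_dict` key, so downstream loading code is unaffected; a hypothetical consumer (the stand-in model and path are illustrative):

    import torch

    model = torch.nn.Linear(4, 2)  # stand-in for the real DelphiFork / SapDelphi
    ckpt = torch.load("best.pt", map_location="cpu")  # i.e. Trainer.best_path
    model.load_state_dict(ckpt["model_state_dict"])
    print(f"Resumed from epoch {ckpt['epoch']}, global step {ckpt['global_step']}")
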