From c70c3cd71ed7078fc55f619341762ca4fe53cb53 Mon Sep 17 00:00:00 2001 From: Jiarui Li Date: Fri, 9 Jan 2026 10:16:03 +0800 Subject: [PATCH] Reorganize import statements for consistency and clarity in model and training scripts --- model.py | 18 ++++++++++++------ train.py | 25 ++++++++++++++----------- 2 files changed, 26 insertions(+), 17 deletions(-) diff --git a/model.py b/model.py index d9e4523..df395b2 100644 --- a/model.py +++ b/model.py @@ -1,10 +1,16 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder -from backbones import Block -from typing import Optional, List import numpy as np +from typing import Optional, List +from backbones import Block +from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder +import torch.nn.functional as F +import torch.nn as nn +import torch +import sys +from pathlib import Path + +_PROJECT_ROOT = Path(__file__).resolve().parent +if str(_PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(_PROJECT_ROOT)) class TabularEncoder(nn.Module): diff --git a/train.py b/train.py index eff1727..702c0e4 100644 --- a/train.py +++ b/train.py @@ -1,23 +1,26 @@ +from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt +from model import DelphiFork, SapDelphi +from dataset import HealthDataset, health_collate_fn +from tqdm import tqdm +from torch.nn.utils import clip_grad_norm_ +from torch.utils.data import random_split +from torch.utils.data import DataLoader +from torch.optim import AdamW +import torch.nn as nn +import torch import json import os import time import argparse import math +import sys from dataclasses import asdict, dataclass, field from typing import Literal, Sequence from pathlib import Path -import torch -import torch.nn as nn -from torch.optim import AdamW -from torch.utils.data import DataLoader -from torch.utils.data import random_split -from torch.nn.utils import clip_grad_norm_ -from tqdm import tqdm - -from dataset import HealthDataset, health_collate_fn -from model import DelphiFork, SapDelphi -from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt +_PROJECT_ROOT = Path(__file__).resolve().parent +if str(_PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(_PROJECT_ROOT)) @dataclass