Reorganize import statements for consistency and clarity in model and training scripts
This commit is contained in:
18
model.py
18
model.py
@@ -1,10 +1,16 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
|
||||
from backbones import Block
|
||||
from typing import Optional, List
|
||||
import numpy as np
|
||||
from typing import Optional, List
|
||||
from backbones import Block
|
||||
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent
|
||||
if str(_PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_PROJECT_ROOT))
|
||||
|
||||
|
||||
class TabularEncoder(nn.Module):
|
||||
|
||||
25
train.py
25
train.py
@@ -1,23 +1,26 @@
|
||||
from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt
|
||||
from model import DelphiFork, SapDelphi
|
||||
from dataset import HealthDataset, health_collate_fn
|
||||
from tqdm import tqdm
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from torch.utils.data import random_split
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.optim import AdamW
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import argparse
|
||||
import math
|
||||
import sys
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Literal, Sequence
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import random_split
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
from tqdm import tqdm
|
||||
|
||||
from dataset import HealthDataset, health_collate_fn
|
||||
from model import DelphiFork, SapDelphi
|
||||
from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt
|
||||
_PROJECT_ROOT = Path(__file__).resolve().parent
|
||||
if str(_PROJECT_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(_PROJECT_ROOT))
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
Reference in New Issue
Block a user