Reorganize import statements for consistency and clarity in model and training scripts
This commit is contained in:
18
model.py
18
model.py
@@ -1,10 +1,16 @@
|
|||||||
import torch
|
|
||||||
import torch.nn as nn
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
|
|
||||||
from backbones import Block
|
|
||||||
from typing import Optional, List
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from typing import Optional, List
|
||||||
|
from backbones import Block
|
||||||
|
from age_encoder import AgeSinusoidalEncoder, AgeMLPEncoder
|
||||||
|
import torch.nn.functional as F
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
_PROJECT_ROOT = Path(__file__).resolve().parent
|
||||||
|
if str(_PROJECT_ROOT) not in sys.path:
|
||||||
|
sys.path.insert(0, str(_PROJECT_ROOT))
|
||||||
|
|
||||||
|
|
||||||
class TabularEncoder(nn.Module):
|
class TabularEncoder(nn.Module):
|
||||||
|
|||||||
25
train.py
25
train.py
@@ -1,23 +1,26 @@
|
|||||||
|
from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt
|
||||||
|
from model import DelphiFork, SapDelphi
|
||||||
|
from dataset import HealthDataset, health_collate_fn
|
||||||
|
from tqdm import tqdm
|
||||||
|
from torch.nn.utils import clip_grad_norm_
|
||||||
|
from torch.utils.data import random_split
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torch.optim import AdamW
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import argparse
|
import argparse
|
||||||
import math
|
import math
|
||||||
|
import sys
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from typing import Literal, Sequence
|
from typing import Literal, Sequence
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import torch
|
_PROJECT_ROOT = Path(__file__).resolve().parent
|
||||||
import torch.nn as nn
|
if str(_PROJECT_ROOT) not in sys.path:
|
||||||
from torch.optim import AdamW
|
sys.path.insert(0, str(_PROJECT_ROOT))
|
||||||
from torch.utils.data import DataLoader
|
|
||||||
from torch.utils.data import random_split
|
|
||||||
from torch.nn.utils import clip_grad_norm_
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
from dataset import HealthDataset, health_collate_fn
|
|
||||||
from model import DelphiFork, SapDelphi
|
|
||||||
from losses import ExponentialNLLLoss, PiecewiseExponentialLoss, WeibullNLLLoss, get_valid_pairs_and_dt
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
Reference in New Issue
Block a user