Commit 9184c694 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): basic refactor

parent 6bb9bbc6
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -11,3 +11,4 @@ click
di-toolkit
tensorboard
einops
thop
+11 −0
Original line number Diff line number Diff line
@@ -39,3 +39,14 @@ class MonochromeAlexNet(nn.Module):
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x


if __name__ == '__main__':
    from thop import profile

    net = MonochromeAlexNet()
    x = torch.randn(1, 3, 180)

    flops, params = profile(net, (x,))
    print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
    print('Params = ' + str(params / 1000 ** 2) + 'M')
+12 −11
Original line number Diff line number Diff line
import os
import random
from copy import deepcopy
from typing import Optional

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from copy import deepcopy
from tqdm.auto import tqdm
import random

from .encode import image_encode

@@ -21,12 +20,13 @@ TRANSFORM = transforms.Compose([
    transforms.Resize(450),
])

TRANSFORM_val = transforms.Compose([
TRANSFORM_VAL = transforms.Compose([
    transforms.Resize(450),
])


class MonochromeDataset(Dataset):
    def __init__(self, root_dir: str, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM):
    def __init__(self, root_dir: str, bins: int = 180, fc: Optional[int] = 75, transform=TRANSFORM):
        self.root_dir = root_dir
        self.bins = bins
        self.fc = fc
@@ -61,6 +61,7 @@ class MonochromeDataset(Dataset):
        else:
            return self.get_hist(sample), label


def random_split_dataset(dataset: MonochromeDataset, train_size, test_size):
    train_data = deepcopy(dataset)
    random.shuffle(train_data.samples)
@@ -68,9 +69,9 @@ def random_split_dataset(dataset:MonochromeDataset, train_size, test_size):
    train_data.samples = train_data.samples[:train_size]

    test_data = dataset
    test_data.transform = TRANSFORM_val
    test_data.transform = TRANSFORM_VAL
    samples_build = []
    print('pre-build testset')
    # print('pre-build testset')
    for sample, label in tqdm(all_samples[train_size:train_size + test_size]):
        samples_build.append((test_data.get_hist(sample), label))
    test_data.samples = samples_build
+10 −11
Original line number Diff line number Diff line
@@ -46,18 +46,15 @@ class Bottleneck(nn.Module):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv1d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(planes)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.conv2 = nn.Conv1d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(planes)
        self.conv3 = nn.Conv1d(planes, self.expansion *
                               planes, kernel_size=1, bias=False)
        self.conv3 = nn.Conv1d(planes, self.expansion * planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm1d(self.expansion * planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion * planes:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_planes, self.expansion * planes,
                          kernel_size=1, stride=stride, bias=False),
                nn.Conv1d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(self.expansion * planes)
            )

@@ -143,9 +140,11 @@ class ResNet152(ResNet):
if __name__ == '__main__':
    from thop import profile

    net = ResNet50(2)
    for resnet_class in [ResNet18, ResNet34, ResNet50, ResNet101, ResNet152]:
        net = resnet_class()
        x = torch.randn(1, 3, 180)

        flops, params = profile(net, (x,))
        print(f'{resnet_class.__name__}:')
        print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
        print('Params = ' + str(params / 1000 ** 2) + 'M')
+23 −15
Original line number Diff line number Diff line
@@ -4,12 +4,12 @@ from functools import partial
from typing import Optional, Type

import torch
from accelerate import Accelerator
from ditk import logging
from torch import nn
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.ops import sigmoid_focal_loss
from tqdm.auto import tqdm

from .alexnet import MonochromeAlexNet
@@ -18,9 +18,6 @@ from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .transformer import SigTransformer
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR


from accelerate import Accelerator

_TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome')
_LOG_DIR = os.path.join(_TRAIN_DIR, 'logs')
_CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts')
@@ -94,6 +91,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
                "accuracy": ["Multiline", ["train/accuracy", "test/accuracy"]],
            },
        })
    else:
        writer = None

    # Initialize dataset
    full_dataset = MonochromeDataset(dataset_dir, bins=feature_bins, fc=fc)
@@ -103,7 +102,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

    # 使用 random_split 函数拆分数据集
    train_dataset, test_dataset = random_split_dataset(full_dataset, train_size, test_size)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                  drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)

    # Load previous epoch
@@ -130,7 +130,9 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        pct_start=0.15, final_div_factor=20.
    )

    model, optimizer, train_dataloader, test_dataloader, scheduler=accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler)
    model, optimizer, train_dataloader, test_dataloader, scheduler = accelerator.prepare(model, optimizer,
                                                                                         train_dataloader,
                                                                                         test_dataloader, scheduler)

    for epoch in range(previous_epoch + 1, max_epochs + 1):
        running_loss = 0.0
@@ -157,7 +159,11 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

        if accelerator.is_local_main_process:
            epoch_loss = epoch_loss.item()
            logging.info(f'Epoch [{epoch}/{max_epochs+1}] loss: {epoch_loss:.4f}, with learning rate: {scheduler.get_last_lr()[0]:.6f}')
            logging.info(
                f'Epoch [{epoch}/{max_epochs + 1}] loss: {epoch_loss:.4f}, '
                f'with learning rate: {scheduler.get_last_lr()[0]:.6f}'
            )
            if writer:
                writer.add_scalar('train/loss', epoch_loss, epoch)

        train_accuracy = train_correct / len(train_dataset)
@@ -167,6 +173,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        if accelerator.is_local_main_process:
            train_accuracy = train_accuracy.item()
            logging.info(f'Epoch {epoch} train accuracy: {train_accuracy:.4f}')
            if writer:
                writer.add_scalar('train/accuracy', train_accuracy, epoch)

        if epoch % eval_epoch == 0:
@@ -187,6 +194,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
                if accelerator.is_local_main_process:
                    test_accuracy = test_accuracy.item()
                    logging.info(f'Epoch {epoch} test accuracy: {test_accuracy:.4f}')
                    if writer:
                        writer.add_scalar('test/accuracy', test_accuracy, epoch)

        if accelerator.is_local_main_process and epoch % save_per_epoch == 0:
+4 −4

File changed.

Contains only whitespace changes.

Loading