Unverified Commit a1fd0c77 authored by narugo1992's avatar narugo1992 Committed by GitHub
Browse files

Merge pull request #1 from deepghs/one_cycle

Add transformer for monochrome training
parents 0d49eaf4 25d3ed20
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -10,3 +10,4 @@ onnxruntime
click
di-toolkit
tensorboard
einops
+21 −19
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ from typing import Optional, Type
import torch
from ditk import logging
from torch import nn
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm.auto import tqdm
@@ -15,7 +16,6 @@ from .dataset import MonochromeDataset
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .transformer import SigTransformer
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR
from ..utils import LRTyping, get_init_lr, get_dynamic_lr_scheduler

_TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome')
_LOG_DIR = os.path.join(_TRAIN_DIR, 'logs')
@@ -69,10 +69,10 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:


def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 256, fc: Optional[int] = 100,
          max_epochs: int = 500, learning_rate: LRTyping = 0.001,
          weight_decay: float = 1e-2, num_workers: Optional[int] = None,
          device: Optional[str] = None, save_per_epoch: int = 10, model_name: str = 'alexnet'):
          train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 180, fc: Optional[int] = 75,
          max_epochs: int = 500, learning_rate: float = 0.001, weight_decay: float = 1e-3,
          num_workers: Optional[int] = None, device: Optional[str] = None,
          save_per_epoch: int = 10, model_name: str = 'alexnet'):
    session_name = session_name or model_name
    _log_dir = os.path.join(_LOG_DIR, session_name)
    os.makedirs(_log_dir, exist_ok=True)
@@ -93,12 +93,12 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    # 使用 random_split 函数拆分数据集
    num_workers = num_workers or os.cpu_count()
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                                  drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)

    # Load previous epoch
    model = _KNOWN_MODELS[model_name]().float()
    # model = MonochromeAlexNet().float()
    if from_ckpt is None:
        from_ckpt = _find_latest_ckpt(session_name)
    previous_epoch = _ckpt_epoch(from_ckpt) or 0
@@ -108,20 +108,20 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    else:
        logging.info(f'No checkpoint found, new model will be used.')

    # Try use cude
    # Try use cuda
    if not device:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    loss_fn = nn.CrossEntropyLoss()
    initial_lr = get_init_lr(learning_rate)
    optimizer = torch.optim.AdamW(
        [{'params': model.parameters(), 'initial_lr': initial_lr}],
        lr=initial_lr, weight_decay=weight_decay,
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate,
        steps_per_epoch=len(train_dataloader), epochs=max_epochs,
        pct_start=0.15, final_div_factor=20.
    )
    scheduler = get_dynamic_lr_scheduler(optimizer, lr=learning_rate, last_epoch=previous_epoch)

    for epoch in tqdm(range(previous_epoch + 1, max_epochs + 1)):
    for epoch in range(previous_epoch + 1, max_epochs + 1):
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(tqdm(train_dataloader)):
            inputs = inputs.float().to(device)
@@ -132,33 +132,35 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            scheduler.step()

        epoch_loss = running_loss / len(train_dataset)
        logging.info(f'Epoch [{epoch}/{max_epochs + 1}] loss: {epoch_loss:.4f}, '
                     f'with learning rate: {scheduler.get_last_lr()[0]:.6f}')
        scheduler.step()
        writer.add_scalar('train/loss', epoch_loss, epoch)

        with torch.no_grad():
            train_correct = 0
            train_correct, train_total = 0, 0
            for i, (inputs, labels) in enumerate(tqdm(train_dataloader)):
                inputs = inputs.float().to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                train_correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
                train_total += labels.shape[0]

            train_accuracy = train_correct / len(train_dataset)
            train_accuracy = train_correct / train_total
            logging.info(f'Epoch {epoch} train accuracy: {train_accuracy:.4f}')
            writer.add_scalar('train/accuracy', train_accuracy, epoch)

            test_correct = 0
            test_correct, test_total = 0, 0
            for i, (inputs, labels) in enumerate(tqdm(test_dataloader)):
                inputs = inputs.float().to(device)
                labels = labels.to(device)
                outputs = model(inputs)
                test_correct += (torch.argmax(outputs, dim=1) == labels).sum().item()
                test_total += labels.shape[0]

            test_accuracy = test_correct / len(test_dataset)
            test_accuracy = test_correct / test_total
            logging.info(f'Epoch {epoch} test accuracy: {test_accuracy:.4f}')
            writer.add_scalar('test/accuracy', test_accuracy, epoch)

+13 −8
Original line number Diff line number Diff line
import math

import torch
from einops import repeat, rearrange
from einops.layers.torch import Rearrange
from torch import nn


@@ -31,23 +33,26 @@ class CNNHead(nn.Module):
            # nn.SiLU(),
            # nn.Conv1d(embed_dim // 2, embed_dim, kernel_size=5, stride=2),
            nn.Conv1d(in_chans, embed_dim, kernel_size=2, stride=2),
            Rearrange('b h n -> n b h'),
            nn.LayerNorm(embed_dim),
        )

    def forward(self, x):  # x:[B,ch,N_seq]
        x = self.proj(x).permute(2, 0, 1)
        x = self.proj(x)
        return x


class SigTransformer(nn.Module):
    __model_name__ = 'transformer'

    def __init__(self, in_ch=3, n_cls=2, hidden=512, nlayers=5, dropout=0.1, seq_len=128):
    def __init__(self, in_ch=3, n_cls=2, hidden=512, nlayers=12, dropout=0.1, seq_len=90):
        super(SigTransformer, self).__init__()
        nhead = hidden // 64

        self.head = CNNHead(in_ch, hidden)
        # self.pos_encoder = PositionalEncoding(hidden, dropout)
        self.pos_embedding = nn.Parameter(torch.randn(seq_len + 1, 1, hidden))
        self.pos_embedding = nn.Parameter(torch.randn(seq_len + 1, 1, hidden) * 0.02)
        self.pos_drop = nn.Dropout(p=dropout)

        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden) * 0.02)

@@ -60,13 +65,14 @@ class SigTransformer(nn.Module):
        )

    def forward(self, src):
        src = self.head(src)  # [N,B,emb]
        cls_tokens = self.cls_token.expand(-1, src.shape[1], -1)
        src = self.head(src)  # [N,B,h]
        cls_tokens = repeat(self.cls_token, '1 1 h -> 1 b h', b=src.shape[1])
        src = torch.cat((cls_tokens, src), dim=0)
        # src = self.pos_encoder(src)
        src += self.pos_embedding
        src = self.pos_drop(src)

        output = self.encoder(src).transpose(0, 1)  # [B,N,emb]
        output = rearrange(self.encoder(src), 'n b h -> b n h')
        output = self.mlp_head(output[:, 0, :])

        return output
@@ -74,7 +80,6 @@ class SigTransformer(nn.Module):

if __name__ == '__main__':
    transformer = SigTransformer()
    x = torch.randn(8, 3, 256)
    x = torch.randn(8, 3, 180)
    y = transformer(x)
    print(y)
    print(y.shape)