Commit f8faea31 authored by dzy7e's avatar dzy7e
Browse files

optim dataset

parent 43dd6fb6
Loading
Loading
Loading
Loading
+42 −21
Original line number Diff line number Diff line
@@ -5,6 +5,9 @@ import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.transforms import transforms
from copy import deepcopy
from tqdm.auto import tqdm
import random

from .encode import image_encode

@@ -18,41 +21,59 @@ TRANSFORM = transforms.Compose([
    transforms.Resize(450),
])

# Transform that random_split_dataset assigns to the test split: a single
# Resize(450) step.
TRANSFORM_val = transforms.Compose([transforms.Resize(450)])

class ImageDirectoryDataset(Dataset):
    def __init__(self, root_dir, label: int = 1, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM):
class MonochromeDataset(Dataset):
    """Dataset of ``(sample, label)`` pairs read from two sub-directories.

    Files under ``<root_dir>/monochrome`` get label 1 and files under
    ``<root_dir>/normal`` get label 0. While ``pre_build`` is False each item
    is encoded on access via ``get_hist``; when it is True the stored sample
    is returned unchanged.

    NOTE(review): this listing comes from a diff rendered without +/- markers,
    so a few lines below look like leftovers of the previous implementation.
    They are flagged individually.
    """

    def __init__(self, root_dir: str, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM):
        """Collect the file paths and labels; no image is opened here.

        :param root_dir: Directory expected to contain ``monochrome`` and
            ``normal`` sub-directories.
        :param bins: Forwarded to ``image_encode`` as ``bins``.
        :param fc: Forwarded to ``image_encode`` as ``fc``.
        :param transform: Callable applied to the opened image before the HSV
            conversion; skipped when falsy.
        """
        self.root_dir = root_dir
        # NOTE(review): ``label`` is not a parameter of this __init__; this
        # line appears to be a removed line from the previous class.
        self.label = label
        self.bins = bins
        self.fc = fc
        self.transform = transform
        self.samples = []
        # NOTE(review): the next loop appends bare paths taken directly from
        # root_dir, which does not match the (path, label) tuples unpacked in
        # __getitem__; it appears to be removed code from the previous class.
        for file_name in os.listdir(root_dir):
            file_path = os.path.join(root_dir, file_name)
            self.samples.append(file_path)
        # False: self.samples holds file paths that are encoded on access.
        self.pre_build = False

        # Every entry of <root_dir>/monochrome is recorded with label 1.
        mono_dir = os.path.join(root_dir, 'monochrome')
        for file_name in os.listdir(mono_dir):
            file_path = os.path.join(mono_dir, file_name)
            self.samples.append((file_path, 1))

        # Every entry of <root_dir>/normal is recorded with label 0.
        normal_dir = os.path.join(root_dir, 'normal')
        for file_name in os.listdir(normal_dir):
            file_path = os.path.join(normal_dir, file_name)
            self.samples.append((file_path, 0))

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.samples)

    # NOTE(review): this __getitem__ header and its two body lines appear to
    # be the previous implementation; the __getitem__ defined at the end of
    # the class is the one that takes effect.
    def __getitem__(self, idx):
        file_path = self.samples[idx]
        image = Image.open(file_path)
    def get_hist(self, sample):
        """Open the image at ``sample`` and return its ``image_encode`` output.

        The image is passed through ``self.transform`` (when set), converted
        to HSV, and encoded with ``bins=self.bins``, ``fc=self.fc`` and
        ``normalize=True``.
        """
        image = Image.open(sample)
        if self.transform:
            image = self.transform(image)
        image = image.convert('HSV')
        # NOTE(review): the first return below (the one that also returns
        # torch.tensor(self.label)) appears to be the removed previous return;
        # as rendered it makes the second return unreachable.
        return image_encode(image, bins=self.bins, fc=self.fc, normalize=True), torch.tensor(self.label)
        return image_encode(image, bins=self.bins, fc=self.fc, normalize=True)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for the sample at position ``idx``."""
        sample, label = self.samples[idx]
        if self.pre_build:
            # The stored sample is returned as-is, with no encoding step.
            return sample, label
        else:
            return self.get_hist(sample), label

class MonochromeDataset(Dataset):
    def __init__(self, root_dir: str, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM):
        self.monochrome = ImageDirectoryDataset(os.path.join(root_dir, 'monochrome'), 1, bins, fc, transform)
        self.normal = ImageDirectoryDataset(os.path.join(root_dir, 'normal'), 0, bins, fc, transform)
def random_split_dataset(dataset:MonochromeDataset, train_size, test_size):
    """Shuffle the samples and split them into a train and a test dataset.

    ``train_data`` is a deep copy of ``dataset`` whose sample list is shuffled
    and cut to the first ``train_size`` entries; its items are still encoded
    on access.

    ``test_data`` is ``dataset`` itself, not a copy, so the object passed in
    is modified: its transform becomes ``TRANSFORM_val``, its samples are
    replaced by ``(get_hist(path), label)`` pairs computed here for the next
    ``test_size`` shuffled entries, and ``pre_build`` is set to True.

    The shuffle uses the ``random`` module's shared generator, so the split
    changes between runs unless the caller seeds it.

    :param dataset: Dataset to split; mutated and returned as the test set.
    :param train_size: Number of shuffled samples kept for training.
    :param test_size: Number of shuffled samples taken after the train slice.
    :return: ``(train_data, test_data)``.
    """
    train_data = deepcopy(dataset)
    random.shuffle(train_data.samples)
    # Keep a reference to the full shuffled list before it is truncated.
    all_samples = train_data.samples
    train_data.samples = train_data.samples[:train_size]

    # NOTE(review): this nested __len__ appears to be a removed method of the
    # previous MonochromeDataset shown without a '-' marker; nothing calls it.
    def __len__(self):
        return len(self.monochrome) + len(self.normal)
    # Same object as the argument: the caller's dataset becomes the test set.
    test_data = dataset
    test_data.transform = TRANSFORM_val
    samples_build = []
    print('pre-build testset')
    # Encode every test image once, up front, using the TRANSFORM_val set above.
    for sample, label in tqdm(all_samples[train_size:train_size+test_size]):
        samples_build.append((test_data.get_hist(sample), label))
    test_data.samples = samples_build
    test_data.pre_build=True

    # NOTE(review): this nested __getitem__ also appears to be removed code
    # from the previous MonochromeDataset; nothing calls it.
    def __getitem__(self, idx):
        if idx < len(self.monochrome):
            return self.monochrome[idx]
        else:
            return self.normal[idx - len(self.monochrome)]
    return train_data, test_data
 No newline at end of file
+7 −2
Original line number Diff line number Diff line
@@ -141,6 +141,11 @@ class ResNet152(ResNet):


# Script entry point: build a two-class ResNet50 and print its cost as
# reported by thop.profile.
if __name__ == '__main__':
    from thop import profile

    net = ResNet50(2)
    # NOTE(review): the next two lines (forward pass on a (10, 3, 400) input
    # and shape print) appear to be the previous smoke test shown without a
    # '-' marker.
    y = net(torch.randn(10, 3, 400))
    print(y.shape)
    # Single random input of shape (1, 3, 180) used for profiling.
    x = torch.randn(1, 3, 180)

    # NOTE(review): thop documents profile()'s first return value as MACs;
    # confirm that labelling it 'FLOPs' is intended.
    flops, params = profile(net, (x,))
    # Scaled by 1000**3 and 1000**2 to print in units of 1e9 (G) and 1e6 (M).
    print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
    print('Params = ' + str(params / 1000 ** 2) + 'M')
+63 −45
Original line number Diff line number Diff line
@@ -12,12 +12,14 @@ from torch.optim import lr_scheduler
from tqdm.auto import tqdm

from .alexnet import MonochromeAlexNet
from .dataset import MonochromeDataset
from .dataset import MonochromeDataset, random_split_dataset
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .transformer import SigTransformer
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR
from ..utils import LRTyping, get_init_lr, get_dynamic_lr_scheduler

from accelerate import Accelerator

_TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome')
_LOG_DIR = os.path.join(_TRAIN_DIR, 'logs')
_CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts')
@@ -72,9 +74,16 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:
def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 256, fc: Optional[int] = 100,
          max_epochs: int = 500, learning_rate: LRTyping = 0.001, num_workers=8,
          save_per_epoch: int = 10, model_name: str = 'alexnet'):
          save_per_epoch: int = 10, eval_epoch=5, model_name: str = 'alexnet'):
    accelerator = Accelerator(
        #mixed_precision=self.cfgs.mixed_precision,
        step_scheduler_with_optimizer=False,
    )

    session_name = session_name or model_name
    _log_dir = os.path.join(_LOG_DIR, session_name)

    if accelerator.is_local_main_process:
        os.makedirs(_log_dir, exist_ok=True)
        os.makedirs(_CKPT_DIR, exist_ok=True)
        writer = SummaryWriter(_log_dir)
@@ -91,7 +100,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    test_size = dataset_size - train_size

    # Split the dataset using the random_split function
    train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size])
    train_dataset, test_dataset = random_split_dataset(full_dataset, train_size, test_size)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)

@@ -108,8 +117,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        logging.info(f'No checkpoint found, new model will be used.')

    # Try to use CUDA
    if torch.cuda.is_available():
        model = model.cuda()
    #if torch.cuda.is_available():
    #    model = model.cuda()

    loss_fn = nn.CrossEntropyLoss()
    initial_lr = get_init_lr(learning_rate)
@@ -119,57 +128,66 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
                            steps_per_epoch=len(train_dataloader), epochs=max_epochs,
                            pct_start=0.15, final_div_factor=20.)

    model, optimizer, train_dataloader, test_dataloader, scheduler=accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler)

    for epoch in range(previous_epoch + 1, max_epochs + 1):
        running_loss = 0.0
        train_correct = 0
        for i, (inputs, labels) in enumerate(tqdm(train_dataloader)):
            inputs = inputs.float()
            if torch.cuda.is_available():
                inputs = inputs.cuda()
                labels = labels.cuda()
            inputs = inputs.to(accelerator.device)
            labels = labels.to(accelerator.device)

            optimizer.zero_grad()
            outputs = model(inputs)
            train_correct += (torch.argmax(outputs, dim=1) == labels).sum().detach().item()

            loss = loss_fn(outputs, labels)
            loss.backward()
            #loss.backward()
            accelerator.backward(loss)
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
            scheduler.step()

        epoch_loss = running_loss / len(train_dataset)
        epoch_loss = torch.tensor(epoch_loss).to(accelerator.device)
        epoch_loss = accelerator.reduce(epoch_loss, reduction="sum")

        if accelerator.is_local_main_process:
            epoch_loss = epoch_loss.item()
            logging.info(f'Epoch [{epoch}/{max_epochs+1}] loss: {epoch_loss:.4f}, with learning rate: {scheduler.get_last_lr()[0]:.6f}')
        #scheduler.step()
            writer.add_scalar('train/loss', epoch_loss, epoch)

        with torch.no_grad():
            train_correct = 0
            for i, (inputs, labels) in enumerate(tqdm(train_dataloader)):
                inputs = inputs.float()
                if torch.cuda.is_available():
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                outputs = model(inputs)
                train_correct += (torch.argmax(outputs, dim=1) == labels).sum().item()

        train_accuracy = train_correct / len(train_dataset)
        train_accuracy = torch.tensor(train_accuracy).to(accelerator.device)
        train_accuracy = accelerator.reduce(train_accuracy, reduction="sum")

        if accelerator.is_local_main_process:
            train_accuracy = train_accuracy.item()
            logging.info(f'Epoch {epoch} train accuracy: {train_accuracy:.4f}')
            writer.add_scalar('train/accuracy', train_accuracy, epoch)

        if epoch%eval_epoch == 0:
            with torch.no_grad():
                test_correct = 0
                for i, (inputs, labels) in enumerate(tqdm(test_dataloader)):
                    inputs = inputs.float()
                if torch.cuda.is_available():
                    inputs = inputs.cuda()
                    labels = labels.cuda()
                    inputs = inputs.to(accelerator.device)
                    labels = labels.to(accelerator.device)

                    outputs = model(inputs)
                    test_correct += (torch.argmax(outputs, dim=1) == labels).sum().item()

                test_accuracy = test_correct / len(test_dataset)
                test_accuracy = torch.tensor(test_accuracy).to(accelerator.device)
                test_accuracy = accelerator.reduce(test_accuracy, reduction="sum")

                if accelerator.is_local_main_process:
                    test_accuracy = test_accuracy.item()
                    logging.info(f'Epoch {epoch} test accuracy: {test_accuracy:.4f}')
                    writer.add_scalar('test/accuracy', test_accuracy, epoch)

        if epoch % save_per_epoch == 0:
        if accelerator.is_local_main_process and epoch % save_per_epoch == 0:
            current_ckpt_file = os.path.join(_CKPT_DIR, f'monochrome-{session_name}-{epoch}.ckpt')
            torch.save(model.state_dict(), current_ckpt_file)
            logging.info(f'Saved to {current_ckpt_file!r}.')
+9 −5
Original line number Diff line number Diff line
@@ -31,7 +31,7 @@ class CNNHead(nn.Module):
            #nn.BatchNorm1d(embed_dim // 2),
            #nn.SiLU(),
            #nn.Conv1d(embed_dim // 2, embed_dim, kernel_size=5, stride=2),
            nn.Conv1d(in_chans, embed_dim, kernel_size=2, stride=2),
            nn.Conv1d(in_chans, embed_dim, kernel_size=1, stride=1),
            Rearrange('b h n -> n b h'),
            nn.LayerNorm(embed_dim),
        )
@@ -44,7 +44,7 @@ class CNNHead(nn.Module):
class SigTransformer(nn.Module):
    __model_name__ = 'transformer'

    def __init__(self, in_ch=3, n_cls=2, hidden=512, nlayers=12, dropout=0.1, seq_len=128):
    def __init__(self, in_ch=3, n_cls=2, hidden=512, nlayers=16, dropout=0.1, seq_len=180):
        super(SigTransformer, self).__init__()
        nhead = hidden // 64

@@ -78,7 +78,11 @@ class SigTransformer(nn.Module):


# Script entry point: build a default SigTransformer and print its cost as
# reported by thop.profile.
if __name__ == '__main__':
    from thop import profile

    transformer = SigTransformer()
    # NOTE(review): the (8, 3, 400) input, the forward pass and the shape
    # print below appear to be the previous smoke test shown without a '-'
    # marker; the (1, 3, 180) assignment then rebinds x.
    x = torch.randn(8, 3, 400)
    y = transformer(x)
    print(y.shape)
    # Single random input of shape (1, 3, 180) used for profiling.
    x = torch.randn(1, 3, 180)

    # NOTE(review): thop documents profile()'s first return value as MACs;
    # confirm that labelling it 'FLOPs' is intended.
    flops, params = profile(transformer, (x,))
    # Scaled by 1000**3 and 1000**2 to print in units of 1e9 (G) and 1e6 (M).
    print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
    print('Params = ' + str(params / 1000 ** 2) + 'M')