Loading zoo/monochrome/dataset.py +42 −21 Original line number Diff line number Diff line Loading @@ -5,6 +5,9 @@ import torch from PIL import Image from torch.utils.data import Dataset from torchvision.transforms import transforms from copy import deepcopy from tqdm.auto import tqdm import random from .encode import image_encode Loading @@ -18,41 +21,59 @@ TRANSFORM = transforms.Compose([ transforms.Resize(450), ]) TRANSFORM_val = transforms.Compose([ transforms.Resize(450), ]) class ImageDirectoryDataset(Dataset): def __init__(self, root_dir, label: int = 1, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM): class MonochromeDataset(Dataset): def __init__(self, root_dir: str, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM): self.root_dir = root_dir self.label = label self.bins = bins self.fc = fc self.transform = transform self.samples = [] for file_name in os.listdir(root_dir): file_path = os.path.join(root_dir, file_name) self.samples.append(file_path) self.pre_build = False mono_dir = os.path.join(root_dir, 'monochrome') for file_name in os.listdir(mono_dir): file_path = os.path.join(mono_dir, file_name) self.samples.append((file_path, 1)) normal_dir = os.path.join(root_dir, 'normal') for file_name in os.listdir(normal_dir): file_path = os.path.join(normal_dir, file_name) self.samples.append((file_path, 0)) def __len__(self): return len(self.samples) def __getitem__(self, idx): file_path = self.samples[idx] image = Image.open(file_path) def get_hist(self, sample): image = Image.open(sample) if self.transform: image = self.transform(image) image = image.convert('HSV') return image_encode(image, bins=self.bins, fc=self.fc, normalize=True), torch.tensor(self.label) return image_encode(image, bins=self.bins, fc=self.fc, normalize=True) def __getitem__(self, idx): sample, label = self.samples[idx] if self.pre_build: return sample, label else: return self.get_hist(sample), label class MonochromeDataset(Dataset): def __init__(self, root_dir: str, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM): self.monochrome = ImageDirectoryDataset(os.path.join(root_dir, 'monochrome'), 1, bins, fc, transform) self.normal = ImageDirectoryDataset(os.path.join(root_dir, 'normal'), 0, bins, fc, transform) def random_split_dataset(dataset:MonochromeDataset, train_size, test_size): train_data = deepcopy(dataset) random.shuffle(train_data.samples) all_samples = train_data.samples train_data.samples = train_data.samples[:train_size] def __len__(self): return len(self.monochrome) + len(self.normal) test_data = dataset test_data.transform = TRANSFORM_val samples_build = [] print('pre-build testset') for sample, label in tqdm(all_samples[train_size:train_size+test_size]): samples_build.append((test_data.get_hist(sample), label)) test_data.samples = samples_build test_data.pre_build=True def __getitem__(self, idx): if idx < len(self.monochrome): return self.monochrome[idx] else: return self.normal[idx - len(self.monochrome)] return train_data, test_data No newline at end of file zoo/monochrome/resnet.py +7 −2 Original line number Diff line number Diff line Loading @@ -141,6 +141,11 @@ class ResNet152(ResNet): if __name__ == '__main__': from thop import profile net = ResNet50(2) y = net(torch.randn(10, 3, 400)) print(y.shape) x = torch.randn(1, 3, 180) flops, params = profile(net, (x,)) print('FLOPs = ' + str(flops / 1000 ** 3) + 'G') print('Params = ' + str(params / 1000 ** 2) + 'M') zoo/monochrome/train_.py +63 −45 Original line number Diff line number Diff line Loading @@ -12,12 +12,14 @@ from torch.optim import lr_scheduler from tqdm.auto import tqdm from .alexnet import MonochromeAlexNet from .dataset import MonochromeDataset from .dataset import MonochromeDataset, random_split_dataset from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 from .transformer import SigTransformer from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR from ..utils import LRTyping, get_init_lr, get_dynamic_lr_scheduler from accelerate import Accelerator _TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome') _LOG_DIR = os.path.join(_TRAIN_DIR, 'logs') _CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts') Loading Loading @@ -72,9 +74,16 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]: def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None, train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 256, fc: Optional[int] = 100, max_epochs: int = 500, learning_rate: LRTyping = 0.001, num_workers=8, save_per_epoch: int = 10, model_name: str = 'alexnet'): save_per_epoch: int = 10, eval_epoch=5, model_name: str = 'alexnet'): accelerator = Accelerator( #mixed_precision=self.cfgs.mixed_precision, step_scheduler_with_optimizer=False, ) session_name = session_name or model_name _log_dir = os.path.join(_LOG_DIR, session_name) if accelerator.is_local_main_process: os.makedirs(_log_dir, exist_ok=True) os.makedirs(_CKPT_DIR, exist_ok=True) writer = SummaryWriter(_log_dir) Loading @@ -91,7 +100,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio test_size = dataset_size - train_size # 使用 random_split 函数拆分数据集 train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size]) train_dataset, test_dataset = random_split_dataset(full_dataset, train_size, test_size) train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True) test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers) Loading @@ -108,8 +117,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio logging.info(f'No checkpoint found, new model will be used.') # Try use cude if torch.cuda.is_available(): model = model.cuda() #if torch.cuda.is_available(): # model = model.cuda() loss_fn = nn.CrossEntropyLoss() initial_lr = get_init_lr(learning_rate) Loading @@ -119,57 +128,66 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio steps_per_epoch=len(train_dataloader), epochs=max_epochs, pct_start=0.15, final_div_factor=20.) model, optimizer, train_dataloader, test_dataloader, scheduler=accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler) for epoch in range(previous_epoch + 1, max_epochs + 1): running_loss = 0.0 train_correct = 0 for i, (inputs, labels) in enumerate(tqdm(train_dataloader)): inputs = inputs.float() if torch.cuda.is_available(): inputs = inputs.cuda() labels = labels.cuda() inputs = inputs.to(accelerator.device) labels = labels.to(accelerator.device) optimizer.zero_grad() outputs = model(inputs) train_correct += (torch.argmax(outputs, dim=1) == labels).sum().detach().item() loss = loss_fn(outputs, labels) loss.backward() #loss.backward() accelerator.backward(loss) optimizer.step() running_loss += loss.item() * inputs.size(0) scheduler.step() epoch_loss = running_loss / len(train_dataset) epoch_loss = torch.tensor(epoch_loss).to(accelerator.device) epoch_loss = accelerator.reduce(epoch_loss, reduction="sum") if accelerator.is_local_main_process: epoch_loss = epoch_loss.item() logging.info(f'Epoch [{epoch}/{max_epochs+1}] loss: {epoch_loss:.4f}, with learning rate: {scheduler.get_last_lr()[0]:.6f}') #scheduler.step() writer.add_scalar('train/loss', epoch_loss, epoch) with torch.no_grad(): train_correct = 0 for i, (inputs, labels) in enumerate(tqdm(train_dataloader)): inputs = inputs.float() if torch.cuda.is_available(): inputs = inputs.cuda() labels = labels.cuda() outputs = model(inputs) train_correct += (torch.argmax(outputs, dim=1) == labels).sum().item() train_accuracy = train_correct / len(train_dataset) train_accuracy = torch.tensor(train_accuracy).to(accelerator.device) train_accuracy = accelerator.reduce(train_accuracy, reduction="sum") if accelerator.is_local_main_process: train_accuracy = train_accuracy.item() logging.info(f'Epoch {epoch} train accuracy: {train_accuracy:.4f}') writer.add_scalar('train/accuracy', train_accuracy, epoch) if epoch%eval_epoch == 0: with torch.no_grad(): test_correct = 0 for i, (inputs, labels) in enumerate(tqdm(test_dataloader)): inputs = inputs.float() if torch.cuda.is_available(): inputs = inputs.cuda() labels = labels.cuda() inputs = inputs.to(accelerator.device) labels = labels.to(accelerator.device) outputs = model(inputs) test_correct += (torch.argmax(outputs, dim=1) == labels).sum().item() test_accuracy = test_correct / len(test_dataset) test_accuracy = torch.tensor(test_accuracy).to(accelerator.device) test_accuracy = accelerator.reduce(test_accuracy, reduction="sum") if accelerator.is_local_main_process: test_accuracy = test_accuracy.item() logging.info(f'Epoch {epoch} test accuracy: {test_accuracy:.4f}') writer.add_scalar('test/accuracy', test_accuracy, epoch) if epoch % save_per_epoch == 0: if accelerator.is_local_main_process and epoch % save_per_epoch == 0: current_ckpt_file = os.path.join(_CKPT_DIR, f'monochrome-{session_name}-{epoch}.ckpt') torch.save(model.state_dict(), current_ckpt_file) logging.info(f'Saved to {current_ckpt_file!r}.') zoo/monochrome/transformer.py +9 −5 Original line number Diff line number Diff line Loading @@ -31,7 +31,7 @@ class CNNHead(nn.Module): #nn.BatchNorm1d(embed_dim // 2), #nn.SiLU(), #nn.Conv1d(embed_dim // 2, embed_dim, kernel_size=5, stride=2), nn.Conv1d(in_chans, embed_dim, kernel_size=2, stride=2), nn.Conv1d(in_chans, embed_dim, kernel_size=1, stride=1), Rearrange('b h n -> n b h'), nn.LayerNorm(embed_dim), ) Loading @@ -44,7 +44,7 @@ class CNNHead(nn.Module): class SigTransformer(nn.Module): __model_name__ = 'transformer' def __init__(self, in_ch=3, n_cls=2, hidden=512, nlayers=12, dropout=0.1, seq_len=128): def __init__(self, in_ch=3, n_cls=2, hidden=512, nlayers=16, dropout=0.1, seq_len=180): super(SigTransformer, self).__init__() nhead = hidden // 64 Loading Loading @@ -78,7 +78,11 @@ class SigTransformer(nn.Module): if __name__ == '__main__': from thop import profile transformer = SigTransformer() x = torch.randn(8, 3, 400) y = transformer(x) print(y.shape) x = torch.randn(1, 3, 180) flops, params = profile(transformer, (x,)) print('FLOPs = ' + str(flops / 1000 ** 3) + 'G') print('Params = ' + str(params / 1000 ** 2) + 'M') Loading
zoo/monochrome/dataset.py +42 −21 Original line number Diff line number Diff line Loading @@ -5,6 +5,9 @@ import torch from PIL import Image from torch.utils.data import Dataset from torchvision.transforms import transforms from copy import deepcopy from tqdm.auto import tqdm import random from .encode import image_encode Loading @@ -18,41 +21,59 @@ TRANSFORM = transforms.Compose([ transforms.Resize(450), ]) TRANSFORM_val = transforms.Compose([ transforms.Resize(450), ]) class ImageDirectoryDataset(Dataset): def __init__(self, root_dir, label: int = 1, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM): class MonochromeDataset(Dataset): def __init__(self, root_dir: str, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM): self.root_dir = root_dir self.label = label self.bins = bins self.fc = fc self.transform = transform self.samples = [] for file_name in os.listdir(root_dir): file_path = os.path.join(root_dir, file_name) self.samples.append(file_path) self.pre_build = False mono_dir = os.path.join(root_dir, 'monochrome') for file_name in os.listdir(mono_dir): file_path = os.path.join(mono_dir, file_name) self.samples.append((file_path, 1)) normal_dir = os.path.join(root_dir, 'normal') for file_name in os.listdir(normal_dir): file_path = os.path.join(normal_dir, file_name) self.samples.append((file_path, 0)) def __len__(self): return len(self.samples) def __getitem__(self, idx): file_path = self.samples[idx] image = Image.open(file_path) def get_hist(self, sample): image = Image.open(sample) if self.transform: image = self.transform(image) image = image.convert('HSV') return image_encode(image, bins=self.bins, fc=self.fc, normalize=True), torch.tensor(self.label) return image_encode(image, bins=self.bins, fc=self.fc, normalize=True) def __getitem__(self, idx): sample, label = self.samples[idx] if self.pre_build: return sample, label else: return self.get_hist(sample), label class MonochromeDataset(Dataset): def __init__(self, root_dir: str, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM): self.monochrome = ImageDirectoryDataset(os.path.join(root_dir, 'monochrome'), 1, bins, fc, transform) self.normal = ImageDirectoryDataset(os.path.join(root_dir, 'normal'), 0, bins, fc, transform) def random_split_dataset(dataset:MonochromeDataset, train_size, test_size): train_data = deepcopy(dataset) random.shuffle(train_data.samples) all_samples = train_data.samples train_data.samples = train_data.samples[:train_size] def __len__(self): return len(self.monochrome) + len(self.normal) test_data = dataset test_data.transform = TRANSFORM_val samples_build = [] print('pre-build testset') for sample, label in tqdm(all_samples[train_size:train_size+test_size]): samples_build.append((test_data.get_hist(sample), label)) test_data.samples = samples_build test_data.pre_build=True def __getitem__(self, idx): if idx < len(self.monochrome): return self.monochrome[idx] else: return self.normal[idx - len(self.monochrome)] return train_data, test_data No newline at end of file
zoo/monochrome/resnet.py +7 −2 Original line number Diff line number Diff line Loading @@ -141,6 +141,11 @@ class ResNet152(ResNet): if __name__ == '__main__': from thop import profile net = ResNet50(2) y = net(torch.randn(10, 3, 400)) print(y.shape) x = torch.randn(1, 3, 180) flops, params = profile(net, (x,)) print('FLOPs = ' + str(flops / 1000 ** 3) + 'G') print('Params = ' + str(params / 1000 ** 2) + 'M')
zoo/monochrome/train_.py +63 −45 Original line number Diff line number Diff line Loading @@ -12,12 +12,14 @@ from torch.optim import lr_scheduler from tqdm.auto import tqdm from .alexnet import MonochromeAlexNet from .dataset import MonochromeDataset from .dataset import MonochromeDataset, random_split_dataset from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152 from .transformer import SigTransformer from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR from ..utils import LRTyping, get_init_lr, get_dynamic_lr_scheduler from accelerate import Accelerator _TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome') _LOG_DIR = os.path.join(_TRAIN_DIR, 'logs') _CKPT_DIR = os.path.join(_TRAIN_DIR, 'ckpts') Loading Loading @@ -72,9 +74,16 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]: def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None, train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 256, fc: Optional[int] = 100, max_epochs: int = 500, learning_rate: LRTyping = 0.001, num_workers=8, save_per_epoch: int = 10, model_name: str = 'alexnet'): save_per_epoch: int = 10, eval_epoch=5, model_name: str = 'alexnet'): accelerator = Accelerator( #mixed_precision=self.cfgs.mixed_precision, step_scheduler_with_optimizer=False, ) session_name = session_name or model_name _log_dir = os.path.join(_LOG_DIR, session_name) if accelerator.is_local_main_process: os.makedirs(_log_dir, exist_ok=True) os.makedirs(_CKPT_DIR, exist_ok=True) writer = SummaryWriter(_log_dir) Loading @@ -91,7 +100,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio test_size = dataset_size - train_size # 使用 random_split 函数拆分数据集 train_dataset, test_dataset = torch.utils.data.random_split(full_dataset, [train_size, test_size]) train_dataset, test_dataset = random_split_dataset(full_dataset, train_size, test_size) train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True) test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers) Loading @@ -108,8 +117,8 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio logging.info(f'No checkpoint found, new model will be used.') # Try use cude if torch.cuda.is_available(): model = model.cuda() #if torch.cuda.is_available(): # model = model.cuda() loss_fn = nn.CrossEntropyLoss() initial_lr = get_init_lr(learning_rate) Loading @@ -119,57 +128,66 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio steps_per_epoch=len(train_dataloader), epochs=max_epochs, pct_start=0.15, final_div_factor=20.) model, optimizer, train_dataloader, test_dataloader, scheduler=accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler) for epoch in range(previous_epoch + 1, max_epochs + 1): running_loss = 0.0 train_correct = 0 for i, (inputs, labels) in enumerate(tqdm(train_dataloader)): inputs = inputs.float() if torch.cuda.is_available(): inputs = inputs.cuda() labels = labels.cuda() inputs = inputs.to(accelerator.device) labels = labels.to(accelerator.device) optimizer.zero_grad() outputs = model(inputs) train_correct += (torch.argmax(outputs, dim=1) == labels).sum().detach().item() loss = loss_fn(outputs, labels) loss.backward() #loss.backward() accelerator.backward(loss) optimizer.step() running_loss += loss.item() * inputs.size(0) scheduler.step() epoch_loss = running_loss / len(train_dataset) epoch_loss = torch.tensor(epoch_loss).to(accelerator.device) epoch_loss = accelerator.reduce(epoch_loss, reduction="sum") if accelerator.is_local_main_process: epoch_loss = epoch_loss.item() logging.info(f'Epoch [{epoch}/{max_epochs+1}] loss: {epoch_loss:.4f}, with learning rate: {scheduler.get_last_lr()[0]:.6f}') #scheduler.step() writer.add_scalar('train/loss', epoch_loss, epoch) with torch.no_grad(): train_correct = 0 for i, (inputs, labels) in enumerate(tqdm(train_dataloader)): inputs = inputs.float() if torch.cuda.is_available(): inputs = inputs.cuda() labels = labels.cuda() outputs = model(inputs) train_correct += (torch.argmax(outputs, dim=1) == labels).sum().item() train_accuracy = train_correct / len(train_dataset) train_accuracy = torch.tensor(train_accuracy).to(accelerator.device) train_accuracy = accelerator.reduce(train_accuracy, reduction="sum") if accelerator.is_local_main_process: train_accuracy = train_accuracy.item() logging.info(f'Epoch {epoch} train accuracy: {train_accuracy:.4f}') writer.add_scalar('train/accuracy', train_accuracy, epoch) if epoch%eval_epoch == 0: with torch.no_grad(): test_correct = 0 for i, (inputs, labels) in enumerate(tqdm(test_dataloader)): inputs = inputs.float() if torch.cuda.is_available(): inputs = inputs.cuda() labels = labels.cuda() inputs = inputs.to(accelerator.device) labels = labels.to(accelerator.device) outputs = model(inputs) test_correct += (torch.argmax(outputs, dim=1) == labels).sum().item() test_accuracy = test_correct / len(test_dataset) test_accuracy = torch.tensor(test_accuracy).to(accelerator.device) test_accuracy = accelerator.reduce(test_accuracy, reduction="sum") if accelerator.is_local_main_process: test_accuracy = test_accuracy.item() logging.info(f'Epoch {epoch} test accuracy: {test_accuracy:.4f}') writer.add_scalar('test/accuracy', test_accuracy, epoch) if epoch % save_per_epoch == 0: if accelerator.is_local_main_process and epoch % save_per_epoch == 0: current_ckpt_file = os.path.join(_CKPT_DIR, f'monochrome-{session_name}-{epoch}.ckpt') torch.save(model.state_dict(), current_ckpt_file) logging.info(f'Saved to {current_ckpt_file!r}.')
zoo/monochrome/transformer.py +9 −5 Original line number Diff line number Diff line Loading @@ -31,7 +31,7 @@ class CNNHead(nn.Module): #nn.BatchNorm1d(embed_dim // 2), #nn.SiLU(), #nn.Conv1d(embed_dim // 2, embed_dim, kernel_size=5, stride=2), nn.Conv1d(in_chans, embed_dim, kernel_size=2, stride=2), nn.Conv1d(in_chans, embed_dim, kernel_size=1, stride=1), Rearrange('b h n -> n b h'), nn.LayerNorm(embed_dim), ) Loading @@ -44,7 +44,7 @@ class CNNHead(nn.Module): class SigTransformer(nn.Module): __model_name__ = 'transformer' def __init__(self, in_ch=3, n_cls=2, hidden=512, nlayers=12, dropout=0.1, seq_len=128): def __init__(self, in_ch=3, n_cls=2, hidden=512, nlayers=16, dropout=0.1, seq_len=180): super(SigTransformer, self).__init__() nhead = hidden // 64 Loading Loading @@ -78,7 +78,11 @@ class SigTransformer(nn.Module): if __name__ == '__main__': from thop import profile transformer = SigTransformer() x = torch.randn(8, 3, 400) y = transformer(x) print(y.shape) x = torch.randn(1, 3, 180) flops, params = profile(transformer, (x,)) print('FLOPs = ' + str(flops / 1000 ** 3) + 'G') print('Params = ' + str(params / 1000 ** 2) + 'M')