Commit 1d83673a authored by dzy7e's avatar dzy7e
Browse files

levit 2d

parent 6bb9bbc6
Loading
Loading
Loading
Loading
+41 −9
Original line number Diff line number Diff line
@@ -25,6 +25,24 @@ TRANSFORM_val = transforms.Compose([
    transforms.Resize(450),
])

# Training-time pipeline for the 2D (image) models: geometric and colour
# augmentation followed by a fixed 384x384 tensor normalised with mean 0.5 / std 0.5.
TRANSFORM2 = transforms.Compose([
    transforms.Resize(size=450),
    transforms.RandomCrop(size=400, padding=50, pad_if_needed=True, padding_mode='reflect'),
    transforms.RandomRotation(degrees=(-180, 180)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.10, contrast=0.10, saturation=0.10, hue=0.10),
    transforms.Resize(size=(384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Evaluation-time counterpart: no augmentation, same output size and normalisation.
TRANSFORM2_val = transforms.Compose([
    transforms.Resize(size=(384, 384)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

class MonochromeDataset(Dataset):
    def __init__(self, root_dir: str, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM):
        self.root_dir = root_dir
@@ -47,33 +65,47 @@ class MonochromeDataset(Dataset):
    def __len__(self):
        return len(self.samples)

    def get_hist(self, sample):
    def pre_process(self, sample):
        image = Image.open(sample).convert('RGB')  # image must be rgb
        if self.transform:
            image = self.transform(image)
        image = image.convert('HSV')
        return image_encode(image, bins=self.bins, fc=self.fc, normalize=True)

    def cache_data(self, repeats=1):
        samples_build = []
        for sample, label in tqdm(self.samples*repeats):
            samples_build.append((self.pre_process(sample), label))
        self.samples = samples_build
        self.pre_build = True

    def __getitem__(self, idx):
        sample, label = self.samples[idx]
        if self.pre_build:
            return sample, label
        else:
            return self.get_hist(sample), label
            return self.pre_process(sample), label

class Monochrome2DDataset(MonochromeDataset):
    """Dataset variant whose samples are the transformed images themselves.

    Only ``pre_process`` is overridden: instead of encoding the image into a
    feature vector it returns whatever ``self.transform`` produces (or the raw
    RGB image when no transform is set). ``bins`` and ``fc`` are accepted for
    signature compatibility and forwarded to the parent constructor unchanged.
    """

    def __init__(self, root_dir: str, bins: int = 200, fc: Optional[int] = 50, transform=TRANSFORM2):
        super().__init__(root_dir, bins, fc, transform)

    def pre_process(self, sample):
        """Load ``sample`` as an RGB image and apply the configured transform.

        Args:
            sample: anything ``Image.open`` accepts (typically a file path).

        Returns:
            The output of ``self.transform`` when one is set, otherwise the
            RGB image.
        """
        # convert() returns an independent image, so the source (and its file
        # handle) can be released as soon as the conversion is done.
        with Image.open(sample) as source:
            image = source.convert('RGB')  # image must be rgb
        if self.transform:
            image = self.transform(image)
        return image

def random_split_dataset(dataset:MonochromeDataset, train_size, test_size):
def random_split_dataset(dataset:MonochromeDataset, train_size, test_size, trans_val=TRANSFORM_val):
    train_data = deepcopy(dataset)
    random.shuffle(train_data.samples)
    all_samples = train_data.samples
    train_data.samples = train_data.samples[:train_size]

    test_data = dataset
    test_data.transform = TRANSFORM_val
    samples_build = []
    test_data.transform = trans_val
    print('pre-build testset')
    for sample, label in tqdm(all_samples[train_size:train_size+test_size]):
        samples_build.append((test_data.get_hist(sample), label))
    test_data.samples = samples_build
    test_data.pre_build=True
    test_data.samples = all_samples[train_size:train_size+test_size]
    test_data.cache_data()

    return train_data, test_data
 No newline at end of file
+199 −0
Original line number Diff line number Diff line
from math import ceil

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    """Return ``True`` for every value except ``None``."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if exists(val):
        return val
    return d

def cast_tuple(val, l = 3):
    """Coerce ``val`` to a tuple of at least ``l`` items.

    A non-tuple value is wrapped in a 1-tuple first. Tuples shorter than ``l``
    are padded by repeating their last element; longer tuples are returned
    with all their items (no truncation).
    """
    items = val if isinstance(val, tuple) else (val,)
    missing = max(l - len(items), 0)
    padding = (items[-1],) * missing
    return (*items, *padding)

def always(val):
    """Build a callable that ignores its arguments and returns ``val``."""
    def constant(*args, **kwargs):
        return val
    return constant

# classes

class FeedForward(nn.Module):
    """Pointwise (kernel size 1) two-layer MLP over the channel axis.

    Channels are expanded from ``dim`` to ``dim * mult``, passed through
    Hardswish, and projected back to ``dim``; dropout follows both the
    activation and the final projection.
    """

    def __init__(self, dim, mult, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        stages = [
            nn.Conv1d(dim, hidden, 1),
            nn.Hardswish(),
            nn.Dropout(dropout),
            nn.Conv1d(hidden, dim, 1),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Multi-head attention over a 1D feature map with a learned positional bias.

    Queries, keys and values come from bias-free pointwise ``Conv1d`` layers
    followed by batch norm. With ``downsample = True`` the query projection
    uses stride 2, so the output has half as many positions (rounded up) as
    the input. A per-head bias, looked up by absolute query/key distance, is
    added to the attention logits.

    Args:
        dim: number of input channels.
        fmap_size: input sequence length; sizes the positional-bias table.
        heads: number of attention heads.
        dim_key: per-head width of queries and keys.
        dim_value: per-head width of values.
        dropout: dropout rate on the attention weights and on the output.
        dim_out: number of output channels; falls back to ``dim``.
        downsample: whether queries are taken at every second position.
    """

    def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        dim_out = default(dim_out, dim)
        key_width = heads * dim_key
        value_width = heads * dim_value
        q_stride = 2 if downsample else 1

        self.heads = heads
        self.scale = dim_key ** -0.5

        def projection(width, stride = 1):
            # bias-free pointwise conv followed by batch norm
            return nn.Sequential(
                nn.Conv1d(dim, width, 1, stride = stride, bias = False),
                nn.BatchNorm1d(width)
            )

        self.to_q = projection(key_width, stride = q_stride)
        self.to_k = projection(key_width)
        self.to_v = projection(value_width)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # the scale (weight) of the last batch norm is zero-initialised
        final_norm = nn.BatchNorm1d(dim_out)
        nn.init.zeros_(final_norm.weight)

        self.to_out = nn.Sequential(
            nn.GELU(),
            nn.Conv1d(value_width, dim_out, 1),
            final_norm,
            nn.Dropout(dropout)
        )

        # positional bias: one learned vector (an entry per head) for every
        # absolute distance between a query position and a key position

        self.pos_bias = nn.Embedding(fmap_size, heads)

        query_positions = torch.arange(0, fmap_size, step = q_stride)
        key_positions = torch.arange(fmap_size)
        distances = (query_positions[:, None] - key_positions[None, :]).abs()

        self.register_buffer('pos_indices', distances)

    def apply_pos_bias(self, fmap):
        """Add the per-head positional bias (divided by ``self.scale``) to the logits ``fmap``."""
        per_head = rearrange(self.pos_bias(self.pos_indices), 'i j h -> () h i j')
        return fmap + (per_head / self.scale)

    def forward(self, x):
        """Attend over ``x`` (expected layout: batch, channels, length).

        Returns a tensor with ``dim_out`` channels and as many positions as
        the query projection produced.
        """
        h = self.heads

        q = self.to_q(x)
        l = q.shape[2]  # number of query positions
        k = self.to_k(x)
        v = self.to_v(x)

        # split channels into heads and move positions in front of features
        q, k, v = (rearrange(t, 'b (h d) ... -> b h (...) d', h = h) for t in (q, k, v))

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        scores = self.apply_pos_bias(scores)

        weights = self.dropout(self.attend(scores))

        out = einsum('b h i j, b h j d -> b h i d', weights, v)
        out = rearrange(out, 'b h l d -> b (h d) l', h = h, l = l)
        return self.to_out(out)

class Transformer(nn.Module):
    """A stack of ``depth`` (attention, feed-forward) pairs.

    The feed-forward output is always added back onto its input. The
    attention output is added onto its input only when the block neither
    downsamples nor changes the channel count (``dim == dim_out``).

    NOTE: every attention layer is built with ``dim`` input channels, so a
    stack with ``depth > 1`` only lines up when ``dim_out`` equals ``dim``.
    """

    def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        dim_out = default(dim_out, dim)

        def make_pair():
            attention = Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out)
            feed_forward = FeedForward(dim_out, mlp_mult, dropout = dropout)
            return nn.ModuleList([attention, feed_forward])

        self.layers = nn.ModuleList([make_pair() for _ in range(depth)])
        self.attn_residual = (not downsample) and dim == dim_out

    def forward(self, x):
        keep_input = self.attn_residual
        for attention, feed_forward in self.layers:
            attended = attention(x)
            x = attended + (x if keep_input else 0)
            x = feed_forward(x) + x
        return x

class LeSigTransformer(nn.Module):
    """LeViT-style staged transformer classifier for 3-channel 1D sequences.

    The input is embedded by three stride-1 ``Conv1d`` layers, then passed
    through ``stages`` groups of attention blocks. Between consecutive stages
    a single downsampling block halves the sequence length (rounding up) and
    switches to the next stage's channel width. The result is average-pooled
    over the sequence axis and fed to a linear classification head.

    Args:
        seq_len: length of the input sequence; sizes the positional-bias tables.
        num_classes: number of outputs of the classification head.
        dim: channel width per stage. A scalar or a tuple shorter than
            ``stages`` is padded by repeating its last value.
        depth: number of blocks per stage (same padding rule as ``dim``).
        heads: number of attention heads per stage (same padding rule).
        mlp_mult: feed-forward expansion factor of the per-stage blocks.
        stages: number of stages.
        dim_key: per-head query/key width.
        dim_value: per-head value width.
        dropout: dropout rate of the per-stage blocks. The downsampling
            blocks are built without it and so use their own default.
        num_distill_classes: when given, a second linear head of this size is
            added and ``forward`` returns ``(logits, distill_logits)``.
    """

    __model_name__ = 'le_transformer'

    def __init__(
        self,
        seq_len = 180,
        num_classes = 2,
        dim = (256, 384, 512),
        depth = 4,
        heads = (4, 6, 8),
        mlp_mult = 4,
        stages = 3,
        dim_key = 32,
        dim_value = 64,
        dropout = 0.1,
        num_distill_classes = None
    ):
        super().__init__()

        dims = cast_tuple(dim, stages)
        depths = cast_tuple(depth, stages)
        layer_heads = cast_tuple(heads, stages)

        # cast_tuple pads short tuples but never truncates, so this only
        # fails when a tuple has more entries than there are stages
        assert all(len(t) == stages for t in (dims, depths, layer_heads)), 'dimensions, depths, and heads must each be a single value or a tuple no longer than the designated number of stages'

        self.conv_embedding = nn.Sequential(
            nn.Conv1d(3, 32, 3, stride = 1, padding = 1),
            nn.Conv1d(32, 128, 3, stride = 1, padding = 1),
            nn.Conv1d(128, dims[0], 3, stride = 1, padding = 1)
        )

        # the stride-1, padding-1 embedding keeps the sequence length
        fmap_size = seq_len
        layers = []
        last_stage = stages - 1

        # distinct loop names keep the constructor arguments intact
        for ind, (stage_dim, stage_depth, stage_heads) in enumerate(zip(dims, depths, layer_heads)):
            layers.append(Transformer(stage_dim, fmap_size, stage_depth, stage_heads, dim_key, dim_value, mlp_mult, dropout))

            if ind != last_stage:
                next_dim = dims[ind + 1]
                layers.append(Transformer(stage_dim, fmap_size, 1, stage_heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
                fmap_size = ceil(fmap_size / 2)

        self.backbone = nn.Sequential(*layers)

        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool1d(1),
            Rearrange('... () -> ...')
        )

        # both heads read the channel width of the last stage
        final_dim = dims[-1]
        self.distill_head = nn.Linear(final_dim, num_distill_classes) if exists(num_distill_classes) else always(None)
        self.mlp_head = nn.Linear(final_dim, num_classes)

    def forward(self, img):
        """Classify ``img`` (expected layout: batch, 3, seq_len).

        Returns the class logits, or ``(logits, distill_logits)`` when a
        distillation head was configured.
        """
        x = self.conv_embedding(img)

        x = self.backbone(x)

        x = self.pool(x)

        out = self.mlp_head(x)
        distill = self.distill_head(x)

        if exists(distill):
            return out, distill

        return out

if __name__ == '__main__':
    # Smoke test: profile the default model on one random 3x180 sequence and
    # report the cost in giga-FLOPs and millions of parameters.
    from thop import profile

    model = LeSigTransformer()
    dummy_input = torch.randn(1, 3, 180)

    flops, params = profile(model, (dummy_input,))
    gflops = flops / 1000 ** 3
    mparams = params / 1000 ** 2
    print('FLOPs = ' + str(gflops) + 'G')
    print('Params = ' + str(mparams) + 'M')
+206 −0
Original line number Diff line number Diff line
from math import ceil

import torch
from torch import nn, einsum
import torch.nn.functional as F

from einops import rearrange, repeat
from einops.layers.torch import Rearrange

# helpers

def exists(val):
    """Return ``True`` for every value except ``None``."""
    if val is None:
        return False
    return True

def default(val, d):
    """Return ``val`` unless it is ``None``, in which case return the fallback ``d``."""
    if exists(val):
        return val
    return d

def cast_tuple(val, l = 3):
    """Coerce ``val`` to a tuple of at least ``l`` items.

    A non-tuple value is wrapped in a 1-tuple first. Tuples shorter than ``l``
    are padded by repeating their last element; longer tuples are returned
    with all their items (no truncation).
    """
    items = val if isinstance(val, tuple) else (val,)
    missing = max(l - len(items), 0)
    padding = (items[-1],) * missing
    return (*items, *padding)

def always(val):
    """Build a callable that ignores its arguments and returns ``val``."""
    def constant(*args, **kwargs):
        return val
    return constant

# classes

class FeedForward(nn.Module):
    """Pointwise (1x1 convolution) two-layer MLP over the channel axis.

    Channels are expanded from ``dim`` to ``dim * mult``, passed through
    Hardswish, and projected back to ``dim``; dropout follows both the
    activation and the final projection.
    """

    def __init__(self, dim, mult, dropout = 0.):
        super().__init__()
        hidden = dim * mult
        stages = [
            nn.Conv2d(dim, hidden, 1),
            nn.Hardswish(),
            nn.Dropout(dropout),
            nn.Conv2d(hidden, dim, 1),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Multi-head attention over a square 2D feature map with a learned positional bias.

    Queries, keys and values come from bias-free 1x1 ``Conv2d`` layers
    followed by batch norm. With ``downsample = True`` the query projection
    uses stride 2, so the output map has half the side length (rounded up) of
    the input. A per-head bias, looked up by the absolute row/column offset
    between query and key, is added to the attention logits.

    Args:
        dim: number of input channels.
        fmap_size: side length of the (square) input feature map; sizes the
            positional-bias table.
        heads: number of attention heads.
        dim_key: per-head width of queries and keys.
        dim_value: per-head width of values.
        dropout: dropout rate on the attention weights and on the output.
        dim_out: number of output channels; falls back to ``dim``.
        downsample: whether queries are taken at every second row and column.
    """

    def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        dim_out = default(dim_out, dim)
        key_width = heads * dim_key
        value_width = heads * dim_value
        q_stride = 2 if downsample else 1

        self.heads = heads
        self.scale = dim_key ** -0.5

        def projection(width, stride = 1):
            # bias-free 1x1 conv followed by batch norm
            return nn.Sequential(
                nn.Conv2d(dim, width, 1, stride = stride, bias = False),
                nn.BatchNorm2d(width)
            )

        self.to_q = projection(key_width, stride = q_stride)
        self.to_k = projection(key_width)
        self.to_v = projection(value_width)

        self.attend = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

        # the scale (weight) of the last batch norm is zero-initialised
        final_norm = nn.BatchNorm2d(dim_out)
        nn.init.zeros_(final_norm.weight)

        self.to_out = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(value_width, dim_out, 1),
            final_norm,
            nn.Dropout(dropout)
        )

        # positional bias: one learned vector (an entry per head) for every
        # (row offset, column offset) pair, stored as row * fmap_size + column

        self.pos_bias = nn.Embedding(fmap_size * fmap_size, heads)

        def flat_grid(axis):
            # all (row, col) coordinate pairs of a square grid, row-major
            coords = torch.stack(torch.meshgrid(axis, axis, indexing = 'ij'), dim = -1)
            return rearrange(coords, 'i j c -> (i j) c')

        query_coords = flat_grid(torch.arange(0, fmap_size, step = q_stride))
        key_coords = flat_grid(torch.arange(fmap_size))

        offsets = (query_coords[:, None, ...] - key_coords[None, :, ...]).abs()
        row_offset, col_offset = offsets.unbind(dim = -1)

        self.register_buffer('pos_indices', (row_offset * fmap_size) + col_offset)

    def apply_pos_bias(self, fmap):
        """Add the per-head positional bias (divided by ``self.scale``) to the logits ``fmap``."""
        per_head = rearrange(self.pos_bias(self.pos_indices), 'i j h -> () h i j')
        return fmap + (per_head / self.scale)

    def forward(self, x):
        """Attend over ``x`` (expected layout: batch, channels, height, width).

        Returns a tensor with ``dim_out`` channels and the spatial size of the
        query projection.
        """
        h = self.heads

        q = self.to_q(x)
        # axis 2 of the query map is reused as the inner size when the
        # flattened positions are folded back into a grid, which matches the
        # square maps this module is built for
        y = q.shape[2]
        k = self.to_k(x)
        v = self.to_v(x)

        # split channels into heads and flatten the spatial axes
        q, k, v = (rearrange(t, 'b (h d) ... -> b h (...) d', h = h) for t in (q, k, v))

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        scores = self.apply_pos_bias(scores)

        weights = self.dropout(self.attend(scores))

        out = einsum('b h i j, b h j d -> b h i d', weights, v)
        out = rearrange(out, 'b h (x y) d -> b (h d) x y', h = h, y = y)
        return self.to_out(out)

class Transformer(nn.Module):
    """A stack of ``depth`` (attention, feed-forward) pairs.

    The feed-forward output is always added back onto its input. The
    attention output is added onto its input only when the block neither
    downsamples nor changes the channel count (``dim == dim_out``).

    NOTE: every attention layer is built with ``dim`` input channels, so a
    stack with ``depth > 1`` only lines up when ``dim_out`` equals ``dim``.
    """

    def __init__(self, dim, fmap_size, depth, heads, dim_key, dim_value, mlp_mult = 2, dropout = 0., dim_out = None, downsample = False):
        super().__init__()
        dim_out = default(dim_out, dim)

        def make_pair():
            attention = Attention(dim, fmap_size = fmap_size, heads = heads, dim_key = dim_key, dim_value = dim_value, dropout = dropout, downsample = downsample, dim_out = dim_out)
            feed_forward = FeedForward(dim_out, mlp_mult, dropout = dropout)
            return nn.ModuleList([attention, feed_forward])

        self.layers = nn.ModuleList([make_pair() for _ in range(depth)])
        self.attn_residual = (not downsample) and dim == dim_out

    def forward(self, x):
        keep_input = self.attn_residual
        for attention, feed_forward in self.layers:
            attended = attention(x)
            x = attended + (x if keep_input else 0)
            x = feed_forward(x) + x
        return x

class LeViT(nn.Module):
    """LeViT-style staged vision transformer for square RGB images.

    Four stride-2 ``Conv2d`` layers embed the image, then ``stages`` groups
    of attention blocks are applied. Between consecutive stages a single
    downsampling block halves the feature-map side (rounding up) and switches
    to the next stage's channel width. The result is average-pooled over the
    spatial axes and fed to a linear classification head.

    Args:
        image_size: side length of the (square) input image. It no longer has
            to be a multiple of 16; the feature-map size is derived from the
            actual convolution arithmetic.
        num_classes: number of outputs of the classification head.
        dim: channel width per stage. A scalar or a tuple shorter than
            ``stages`` is padded by repeating its last value.
        depth: number of blocks per stage (same padding rule as ``dim``).
        heads: number of attention heads per stage (same padding rule).
        mlp_mult: feed-forward expansion factor of the per-stage blocks.
        stages: number of stages.
        dim_key: per-head query/key width.
        dim_value: per-head value width.
        dropout: dropout rate of the per-stage blocks. The downsampling
            blocks are built without it and so use their own default.
        num_distill_classes: when given, a second linear head of this size is
            added and ``forward`` returns ``(logits, distill_logits)``.
    """

    __model_name__ = 'levit'

    def __init__(
        self,
        image_size=384,
        num_classes=2,
        dim = (256, 384, 512),
        depth = 4,
        heads = (4, 6, 8),
        mlp_mult = 2,
        stages = 3,
        dim_key = 32,
        dim_value = 64,
        dropout = 0.,
        num_distill_classes = None
    ):
        super().__init__()

        dims = cast_tuple(dim, stages)
        depths = cast_tuple(depth, stages)
        layer_heads = cast_tuple(heads, stages)

        # cast_tuple pads short tuples but never truncates, so this only
        # fails when a tuple has more entries than there are stages
        assert all(len(t) == stages for t in (dims, depths, layer_heads)), 'dimensions, depths, and heads must each be a single value or a tuple no longer than the designated number of stages'

        self.conv_embedding = nn.Sequential(
            nn.Conv2d(3, 32, 3, stride = 2, padding = 1),
            nn.Conv2d(32, 64, 3, stride = 2, padding = 1),
            nn.Conv2d(64, 128, 3, stride = 2, padding = 1),
            nn.Conv2d(128, dims[0], 3, stride = 2, padding = 1)
        )

        # Each embedding conv (kernel 3, stride 2, padding 1) maps a side of n
        # to floor((n - 1) / 2) + 1 == ceil(n / 2). Applying that four times
        # equals image_size // 16 for multiples of 16 and stays correct for
        # every other size, where the floor division would undercount and the
        # positional-bias table would no longer match the attention map.
        fmap_size = image_size
        for _ in range(len(self.conv_embedding)):
            fmap_size = ceil(fmap_size / 2)

        layers = []
        last_stage = stages - 1

        # distinct loop names keep the constructor arguments intact
        for ind, (stage_dim, stage_depth, stage_heads) in enumerate(zip(dims, depths, layer_heads)):
            layers.append(Transformer(stage_dim, fmap_size, stage_depth, stage_heads, dim_key, dim_value, mlp_mult, dropout))

            if ind != last_stage:
                next_dim = dims[ind + 1]
                layers.append(Transformer(stage_dim, fmap_size, 1, stage_heads * 2, dim_key, dim_value, dim_out = next_dim, downsample = True))
                fmap_size = ceil(fmap_size / 2)

        self.backbone = nn.Sequential(*layers)

        self.pool = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Rearrange('... () () -> ...')
        )

        # both heads read the channel width of the last stage
        final_dim = dims[-1]
        self.distill_head = nn.Linear(final_dim, num_distill_classes) if exists(num_distill_classes) else always(None)
        self.mlp_head = nn.Linear(final_dim, num_classes)

    def forward(self, img):
        """Classify ``img`` (expected layout: batch, 3, image_size, image_size).

        Returns the class logits, or ``(logits, distill_logits)`` when a
        distillation head was configured.
        """
        x = self.conv_embedding(img)

        x = self.backbone(x)

        x = self.pool(x)

        out = self.mlp_head(x)
        distill = self.distill_head(x)

        if exists(distill):
            return out, distill

        return out

if __name__ == '__main__':
    # Smoke test: profile the default model on one random 384x384 RGB image
    # and report the cost in giga-FLOPs and millions of parameters.
    from thop import profile

    model = LeViT()
    dummy_input = torch.randn(1, 3, 384, 384)

    flops, params = profile(model, (dummy_input,))
    gflops = flops / 1000 ** 3
    mparams = params / 1000 ** 2
    print('FLOPs = ' + str(gflops) + 'G')
    print('Params = ' + str(mparams) + 'M')
 No newline at end of file
+9 −5
Original line number Diff line number Diff line
@@ -13,9 +13,11 @@ from torchvision.ops import sigmoid_focal_loss
from tqdm.auto import tqdm

from .alexnet import MonochromeAlexNet
from .dataset import MonochromeDataset, random_split_dataset
from .dataset import MonochromeDataset, Monochrome2DDataset, random_split_dataset, TRANSFORM_val, TRANSFORM2_val
from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .transformer import SigTransformer
from .levit1d import LeSigTransformer
from .levit2d import LeViT
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR


@@ -42,6 +44,8 @@ _register_model(ResNet50)
_register_model(ResNet101)
_register_model(ResNet152)
_register_model(SigTransformer)
_register_model(LeSigTransformer)
_register_model(LeViT)


def _find_latest_ckpt(name: str) -> Optional[str]:
@@ -75,8 +79,8 @@ def _ckpt_epoch(filename: Optional[str]) -> Optional[int]:
def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, batch_size: int = 4, feature_bins: int = 180, fc: Optional[int] = 75,
          max_epochs: int = 500, learning_rate: float = 0.001, weight_decay: float = 1e-3,
          num_workers: Optional[int] = None, device: Optional[str] = None,
          save_per_epoch: int = 10, eval_epoch: int=5, model_name: str = 'alexnet'):
          num_workers: Optional[int] = 8, device: Optional[str] = None,
          save_per_epoch: int = 10, eval_epoch: int=5, model_name: str = 'alexnet', data_2d=False):
    accelerator = Accelerator(
        # mixed_precision=self.cfgs.mixed_precision,
        step_scheduler_with_optimizer=False,
@@ -96,13 +100,13 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        })

    # Initialize dataset
    full_dataset = MonochromeDataset(dataset_dir, bins=feature_bins, fc=fc)
    full_dataset = (Monochrome2DDataset if data_2d else MonochromeDataset)(dataset_dir, bins=feature_bins, fc=fc)
    dataset_size = len(full_dataset)
    train_size = int(train_ratio * dataset_size)
    test_size = dataset_size - train_size

    # Split the dataset into train and test subsets with random_split_dataset
    train_dataset, test_dataset = random_split_dataset(full_dataset, train_size, test_size)
    train_dataset, test_dataset = random_split_dataset(full_dataset, train_size, test_size, trans_val=TRANSFORM2_val if data_2d else TRANSFORM_val)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, num_workers=num_workers)

+90 −52
Original line number Diff line number Diff line
@@ -5,84 +5,122 @@ from einops import repeat, rearrange
from einops.layers.torch import Rearrange
from torch import nn

# helpers

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=200):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
def posemb_sincos_1d(patches, temperature = 10000, dtype = torch.float32):
    _, n, dim, device, dtype = *patches.shape, patches.device, patches.dtype

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)
    n = torch.arange(n, device = device)
    assert (dim % 2) == 0, 'feature dimension must be multiple of 2 for sincos emb'
    omega = torch.arange(dim // 2, device = device) / (dim // 2 - 1)
    omega = 1. / (temperature ** omega)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
    n = n.flatten()[:, None] * omega[None, :]
    pe = torch.cat((n.sin(), n.cos()), dim = 1)
    return pe.type(dtype)

# classes

class CNNHead(nn.Module):
    def __init__(self, in_chans=1, embed_dim=768):
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.proj = nn.Sequential(
            #nn.Conv1d(in_chans, embed_dim // 2, kernel_size=7, stride=2),
            #nn.BatchNorm1d(embed_dim // 2),
            #nn.SiLU(),
            #nn.Conv1d(embed_dim // 2, embed_dim, kernel_size=5, stride=2),
            nn.Conv1d(in_chans, embed_dim, kernel_size=1, stride=1),
            Rearrange('b h n -> n b h'),
            nn.LayerNorm(embed_dim),
        self.net = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim),
        )
    def forward(self, x):
        return self.net(x)

    def forward(self, x):  # x:[B,ch,N_seq]
        x = self.proj(x)
        return x
class Attention(nn.Module):
    """Pre-norm multi-head self-attention over a (batch, tokens, dim) tensor.

    The input is layer-normalised, projected to queries, keys and values by
    one bias-free linear layer, attended with scaled dot-product attention,
    and projected back to ``dim`` channels. No residual is added here.

    Args:
        dim: token feature width.
        heads: number of attention heads.
        dim_head: feature width of each head.
    """

    def __init__(self, dim, heads = 8, dim_head = 64):
        super().__init__()
        inner_dim = heads * dim_head

        self.heads = heads
        self.scale = dim_head ** -0.5
        self.norm = nn.LayerNorm(dim)

        self.attend = nn.Softmax(dim = -1)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

    def forward(self, x):
        normed = self.norm(x)

        # one projection, cut into three equal slices along the feature axis
        slices = self.to_qkv(normed).chunk(3, dim = -1)
        q, k, v = (rearrange(t, 'b n (h d) -> b h n d', h = self.heads) for t in slices)

        scores = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        weights = self.attend(scores)

        mixed = torch.matmul(weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """A stack of ``depth`` (attention, feed-forward) pairs with residual connections.

    Each sub-layer's output is added back onto its own input.

    Args:
        dim: token feature width.
        depth: number of (attention, feed-forward) pairs.
        heads: number of attention heads.
        dim_head: feature width of each head.
        mlp_dim: hidden width of the feed-forward sub-layer.
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super().__init__()

        def make_pair():
            attention = Attention(dim, heads = heads, dim_head = dim_head)
            feed_forward = FeedForward(dim, mlp_dim)
            return nn.ModuleList([attention, feed_forward])

        self.layers = nn.ModuleList([make_pair() for _ in range(depth)])

    def forward(self, x):
        for attention, feed_forward in self.layers:
            x = attention(x) + x
            x = feed_forward(x) + x
        return x

class SigTransformer(nn.Module):
    __model_name__ = 'transformer'

    def __init__(self, in_ch=3, n_cls=2, hidden=512, nlayers=16, dropout=0.1, seq_len=180):
        super(SigTransformer, self).__init__()
        nhead = hidden // 64
    def __init__(self, seq_len=180, patch_size=1, num_classes=2, dim=1024, depth=6, heads=8, mlp_dim=2048, channels = 3, dim_head = 64):
        super().__init__()

        assert seq_len % patch_size == 0

        self.head = CNNHead(in_ch, hidden)
        # self.pos_encoder = PositionalEncoding(hidden, dropout)
        self.pos_embedding = nn.Parameter(torch.randn(seq_len + 1, 1, hidden) * 0.02)
        self.pos_drop = nn.Dropout(p=dropout)
        num_patches = seq_len // patch_size
        patch_dim = channels * patch_size

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (n p) -> b n (p c)', p = patch_size),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden) * 0.02)
        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim)

        encoder_layer = nn.TransformerEncoderLayer(hidden, nhead, dim_feedforward=2048, dropout=dropout)
        encoder_norm = nn.LayerNorm(hidden)
        self.encoder = nn.TransformerEncoder(encoder_layer, nlayers, encoder_norm)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(hidden),
            nn.Linear(hidden, n_cls)
        self.to_latent = nn.Identity()
        self.linear_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, src):
        src = self.head(src)  # [N,B,h]
        cls_tokens = repeat(self.cls_token, '1 1 h -> 1 b h', b=src.shape[1])
        src = torch.cat((cls_tokens, src), dim=0)
        # src = self.pos_encoder(src)
        src += self.pos_embedding
        src = self.pos_drop(src)
    def forward(self, series):
        *_, n, dtype = *series.shape, series.dtype

        output = rearrange(self.encoder(src), 'n b h -> b n h')
        output = self.mlp_head(output[:, 0, :])
        x = self.to_patch_embedding(series)
        pe = posemb_sincos_1d(x)
        x = rearrange(x, 'b ... d -> b (...) d') + pe

        return output
        x = self.transformer(x)
        x = x.mean(dim = 1)

        x = self.to_latent(x)
        return self.linear_head(x)

if __name__ == '__main__':
    from thop import profile

    transformer = SigTransformer()
    transformer = SigTransformer(180, 1, 2)
    x = torch.randn(1, 3, 180)
    y = transformer(x)
    print(y.shape)

    flops, params = profile(transformer, (x,))
    print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')