Commit aa5c1611 authored by dzy7e's avatar dzy7e
Browse files

caformer

parent c6e04ad2
Loading
Loading
Loading
Loading
+6 −6
Original line number Diff line number Diff line
@@ -133,14 +133,14 @@ class LeViT(nn.Module):
        self,
        image_size=384,
        num_classes=2,
        dim = (256, 384, 512),
        depth = 4,
        dim = (384, 512, 768),
        depth = 8,
        heads = (4, 6, 8),
        mlp_mult = 2,
        mlp_mult = 3,
        stages = 3,
        dim_key = 32,
        dim_value = 64,
        dropout = 0.,
        dim_key = 64,
        dim_value = 128,
        dropout = 0.2,
        num_distill_classes = None
    ):
        super().__init__()
+37 −0
Original line number Diff line number Diff line
import torch
from torch import nn
from timm.models import create_model
import zoo.monochrome.metaformer_timm # register models

class CAFormerBuilder:
    __model_name__ = 'caformer'
    def __init__(self, arch='caformer_m36_384_in21ft1k', num_classes=2, drop_path_rate=0.4):
        self.create_model_args = dict(
            model_name=arch,
            pretrained=True,
            num_classes=num_classes,
            drop_rate=0.0,
            drop_connect_rate=None,  # DEPRECATED, use drop_path
            drop_path_rate=drop_path_rate,
            drop_block_rate=None,
            global_pool=None,
            bn_momentum=None,
            bn_eps=None,
            scriptable=False,
            checkpoint_path=None
        )
        self.num_classes=num_classes

    def __call__(self, *args, **kwargs):
        model = create_model(**self.create_model_args)
        return model

if __name__ == '__main__':
    from thop import profile

    transformer = CAFormerBuilder()()
    x = torch.randn(1, 3, 384, 384)

    flops, params = profile(transformer, (x,))
    print('FLOPs = ' + str(flops / 1000 ** 3) + 'G')
    print('Params = ' + str(params / 1000 ** 2) + 'M')
 No newline at end of file
+1643 −0

File added.

Preview size limit exceeded, changes collapsed.

+4 −0
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ from .resnet import ResNet18, ResNet34, ResNet50, ResNet101, ResNet152
from .transformer import SigTransformer
from .levit1d import LeSigTransformer
from .levit2d import LeViT
from .metaformer import CAFormerBuilder
from ..base import _TRAIN_DIR as _GLOBAL_TRAIN_DIR

_TRAIN_DIR = os.path.join(_GLOBAL_TRAIN_DIR, 'monochrome')
@@ -44,6 +45,7 @@ _register_model(ResNet152)
_register_model(SigTransformer)
_register_model(LeSigTransformer)
_register_model(LeViT)
_register_model(CAFormerBuilder)


def _find_latest_ckpt(name: str) -> Optional[str]:
@@ -140,6 +142,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        running_loss = 0.0
        train_correct, train_total = 0, 0
        train_fp, train_fn = 0, 0
        model.train()
        for i, (inputs, labels) in enumerate(tqdm(train_dataloader)):
            inputs = inputs.float()
            inputs = inputs.to(accelerator.device)
@@ -178,6 +181,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
                writer.add_scalar('train/fp', train_fp_p, epoch)
                writer.add_scalar('train/fn', train_fn_p, epoch)

        model.eval()
        if epoch % eval_epoch == 0:
            with torch.no_grad():
                test_correct, test_total = 0, 0