Commit 733bf5b8 authored by dzy7e's avatar dzy7e
Browse files

WeakRandAugment

parent 7a3d6b3d
Loading
Loading
Loading
Loading
+22 −5
Original line number Diff line number Diff line
@@ -11,16 +11,33 @@ from torchvision import transforms
from imgutils.data import load_image
from .prob import get_reg_for_prob

class WeakRandAugment(transforms.RandAugment):
    def _augmentation_space(self, num_bins: int, image_size: Tuple[int, int]) -> Dict[str, Tuple[torch.Tensor, bool]]:
        return {
            # op_name: (magnitudes, signed)
            "Identity": (torch.tensor(0.0), False),
            "ShearX": (torch.linspace(0.0, 0.1, num_bins), True),
            "ShearY": (torch.linspace(0.0, 0.1, num_bins), True),
            "TranslateX": (torch.linspace(0.0, 0.1 * image_size[1], num_bins), True),
            "TranslateY": (torch.linspace(0.0, 0.1 * image_size[0], num_bins), True),
            "Rotate": (torch.linspace(0.0, 8.0, num_bins), True),
            "Brightness": (torch.linspace(0.0, 0.1, num_bins), True),
            "Contrast": (torch.linspace(0.0, 0.1, num_bins), True),
            "Sharpness": (torch.linspace(0.0, 0.2, num_bins), True),
            "Posterize": (8 - (torch.arange(num_bins) / ((num_bins - 1) / 4)).round().int(), False),
            "AutoContrast": (torch.tensor(0.0), False),
            "Equalize": (torch.tensor(0.0), False),
        }

TRAIN_TRANSFORM = [
    transforms.Resize((416, 416)),
    transforms.RandomRotation((-15, 15)),
    transforms.RandomCrop(384),
    transforms.Resize((272, 272)),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(0.10, 0.10, 0.10, 0.10),
    WeakRandAugment(),
    transforms.RandomCrop(256),
    transforms.ToTensor(),
]
TEST_TRANSFORM = [
    transforms.Resize((384, 384)),
    transforms.Resize((256, 256)),
    #transforms.c(384),
    transforms.ToTensor(),
]