Commit 89cdca3d authored by 7eu7d7's avatar 7eu7d7
Browse files

zero grad

parent 4474bbc0
Loading
Loading
Loading
Loading
+2 −3
Original line number Diff line number Diff line
torch<2
torch
lpips
matplotlib
torchvision
@@ -16,4 +16,3 @@ accelerate
timm
ftfy
regex
 No newline at end of file
git+https://github.com/openai/CLIP.git
 No newline at end of file
+2 −2
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@ from imgutils.data import load_image
from .prob import get_reg_for_prob

TRAIN_TRANSFORM = [
    transforms.Resize(416),
    transforms.Resize((416, 416)),
    transforms.RandomRotation((-15, 15)),
    transforms.RandomCrop(384),
    transforms.RandomHorizontalFlip(),
@@ -21,7 +21,7 @@ TRAIN_TRANSFORM = [
    transforms.ToTensor(),
]
# Transform list for the test split (going by the name); the entries are
# torchvision-style transform objects: resize, centre crop to 384, tensor
# conversion.
# NOTE(review): this span is taken from a diff view, so the two consecutive
# Resize entries are most likely the before/after versions of a single line
# (Resize(416) -> Resize((416, 416))) rather than two steps that both run.
# Confirm against the actual file before relying on this list as written.
# Presumably the list is handed to a Compose-like wrapper and applied in
# order; the consumer is not visible here.
TEST_TRANSFORM = [
    transforms.Resize(416),
    transforms.Resize((416, 416)),
    transforms.CenterCrop(384),
    transforms.ToTensor(),
]
+1 −1
Original line number Diff line number Diff line
@@ -34,7 +34,7 @@ class CCIPFeature(nn.Module):

    def forward(self, x):
        """Run ``x`` through ``self.backbone`` and return the result.

        As shown here, the backbone output is divided by its norm taken over
        the last dimension (``keepdim=True`` so the division broadcasts).

        NOTE(review): this span comes from a diff view. The active
        normalisation line and the commented-out copy below it are most
        likely the before/after versions of one line, i.e. the commit appears
        to disable the normalisation. Confirm which form the file contains.
        """
        x = self.backbone(x)
        # Scale by the norm along the last axis; a zero norm would divide by
        # zero here (no epsilon guard is visible).
        x = x / x.norm(dim=-1, keepdim=True)
        #x = x / x.norm(dim=-1, keepdim=True)
        return x


+5 −2
Original line number Diff line number Diff line
@@ -136,6 +136,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
        steps_per_epoch=len(train_dataloader), epochs=max_epochs,
        pct_start=0.15, final_div_factor=20.
    )
    #model = torch.compile(model)

    model, optimizer, train_dataloader, test_dataloader, scheduler = \
        accelerator.prepare(model, optimizer, train_dataloader, test_dataloader, scheduler)
@@ -161,11 +162,12 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
            accelerator.backward(loss)
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()

            running_loss += loss.item()*len(char_ids)
            train_pos_total += len(char_ids)

            mask = torch.ones_like(outputs).bool()
            mask = torch.ones_like(outputs).bool().cpu()
            mask ^= torch.diag_embed(torch.diag(mask))
            outputs = outputs.detach().cpu()
            gt_same = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu()
@@ -177,6 +179,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
                    mean_loss = running_loss/train_pos_total
                    if writer:
                        writer.add_scalar('train/loss', mean_loss, (epoch-1)*num_iter + i)
                        writer.add_scalar('train/lr', scheduler.get_last_lr()[0], (epoch-1)*num_iter + i)
                    running_loss = 0.
                    train_pos_total = 0

@@ -206,7 +209,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

                    outputs = model(inputs)  # BxB

                    mask = torch.ones_like(outputs).bool()
                    mask = torch.ones_like(outputs).bool().cpu()
                    mask ^= torch.diag_embed(torch.diag(mask))
                    outputs = outputs.detach().cpu()
                    gt_same = (char_ids.view(-1, 1) == char_ids.view(1, -1)).detach().cpu()