Commit d2535557 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): oh damn it

parent c9a8adfe
Loading
Loading
Loading
Loading
+8 −4
Original line number Diff line number Diff line
@@ -87,10 +87,14 @@ class CharacterDataset(Dataset):
    def __getitem__(self, item):
        idx = self._id_map[item]
        current_samples = self._x_to_y(self.group_size)
        if current_samples > len(self.groups[idx]) and self.force_prob:
        if current_samples > len(self.groups[idx]):
            if self.force_prob:
                total_samples = self._y_to_x(len(self.groups[idx]))
                current_samples = self._x_to_y(total_samples)
                ex_samples = total_samples - current_samples
            else:
                current_samples = len(self.groups[idx])
                ex_samples = self.group_size - current_samples
        else:
            ex_samples = self.group_size - current_samples

+12 −6
Original line number Diff line number Diff line
@@ -38,11 +38,17 @@ class NTXentLoss(nn.Module):
        self.register_buffer('tau', torch.as_tensor(tau, dtype=torch.float))
        self.register_buffer('eps', torch.as_tensor(eps, dtype=torch.float))

    def forward(self, similarities, state):
        """
        Compute the NT-Xent loss, contrasting every positive against all negatives.

        For each positive similarity ``p`` the loss term is
        ``-log(exp(p / tau) / (exp(p / tau) + sum(exp(negs / tau))))``, i.e. the
        positive only competes with the negatives, never with other positives.
        The terms are averaged, with ``eps`` added to both the sum and the count,
        so an input without any positive sample yields ``eps / eps == 1`` instead
        of raising.

        :param similarities: Similarities, float32[N]
        :param state: Positive sample or not, bool[N]
        :return: Scalar loss tensor.
        """
        pos_logits = similarities[state] / self.tau
        neg_logits = similarities[~state] / self.tau

        # One row per positive: [pos_i, neg_1, ..., neg_M]. Built in a single
        # batch so that zero positives produces an empty tensor rather than an
        # empty Python list (which ``torch.stack`` would reject).
        logits = torch.cat([
            pos_logits.unsqueeze(-1),
            neg_logits.unsqueeze(0).expand(pos_logits.shape[0], -1),
        ], dim=-1)

        # -log(exp(p) / sum(exp(row))) == logsumexp(row) - p, evaluated without
        # forming the raw exponentials, which can overflow for small ``tau``.
        pos_items = torch.logsumexp(logits, dim=-1) - pos_logits
        return (pos_items.sum() + self.eps) / (pos_items.shape[0] + self.eps)
+2 −2
Original line number Diff line number Diff line
@@ -80,7 +80,7 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000):


def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 100,
          train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 30,
          learning_rate: float = 0.001, weight_decay: float = 1e-3, tau: float = 0.15,
          save_per_epoch: int = 10, eval_epoch: int = 5,
          model_name: str = 'clip/ViT-B/32', seed: Optional[int] = 0):
@@ -117,7 +117,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    train_image_dataset.transform = Compose([*TRAIN_TRANSFORM.transforms, *model.preprocess.transforms])
    test_image_dataset.transform = Compose([*TEST_TRANSFORM.transforms, *model.preprocess.transforms])

    train_dataset = CharacterDataset(train_image_dataset, group_size)
    train_dataset = CharacterDataset(train_image_dataset, group_size, force_prob=False)
    test_dataset = CharacterDataset(test_image_dataset, group_size)

    if from_ckpt is None: