Commit c9a8adfe, authored by narugo1992
Browse files

dev(narugo): fix threshold

parent 736333a1
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -65,7 +65,7 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000):

    s_poss, s_negs = s_poss.cpu(), s_negs.cpu()
    features = torch.cat([s_poss, s_negs]).detach().numpy()
    labels = torch.cat([torch.ones_like(s_poss), torch.zeros_like(s_negs)]).detach().numpy()
    labels = torch.cat([torch.ones_like(s_poss), -torch.ones_like(s_negs)]).detach().numpy()

    model = svm.SVC(kernel='linear')  # 线性核
    model.fit(features.reshape(-1, 1), labels)
@@ -73,7 +73,7 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000):

    coef = model.coef_.reshape(-1)[0].tolist()
    inter = model.intercept_.reshape(-1)[0].tolist()
    threshold = -coef / inter
    threshold = -inter / coef

    return poss.mean().item(), poss.std().item(), negs.mean().item(), negs.std().item(), \
           threshold, accuracy_score(labels, predictions)
@@ -81,7 +81,7 @@ def _sample_analysis(poss, negs, svm_samples: int = 10000):

def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optional[str] = None,
          train_ratio: float = 0.8, max_epochs: int = 500, group_size: int = 100,
          learning_rate: float = 0.001, weight_decay: float = 1e-3,
          learning_rate: float = 0.001, weight_decay: float = 1e-3, tau: float = 0.15,
          save_per_epoch: int = 10, eval_epoch: int = 5,
          model_name: str = 'clip/ViT-B/32', seed: Optional[int] = 0):
    if seed is not None:
@@ -129,7 +129,7 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio
    else:
        logging.info(f'No checkpoint found, new model will be used.')

    loss_fn = NTXentLoss().to(accelerator.device)
    loss_fn = NTXentLoss(tau=tau).to(accelerator.device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    scheduler = lr_scheduler.OneCycleLR(
        optimizer, max_lr=learning_rate,