Loading zoo/ccip/train_.py +4 −4 Original line number Diff line number Diff line Loading @@ -196,10 +196,10 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio outputs = model.diff(dists.reshape(-1, 1)) preds = torch.argmax(outputs, dim=1) train_correct += (preds == labels).sum().item() train_fp += (preds[labels == 0] == 1).sum().item() train_fn += (preds[labels == 1] == 0).sum().item() train_total += labels.shape[0] test_correct += (preds == labels).sum().item() test_fp += (preds[labels == 0] == 1).sum().item() test_fn += (preds[labels == 1] == 0).sum().item() test_total += labels.shape[0] test_accuracy = test_correct / test_total test_fp_p = test_fp / test_total Loading Loading
zoo/ccip/train_.py +4 −4 Original line number Diff line number Diff line Loading @@ -196,10 +196,10 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio outputs = model.diff(dists.reshape(-1, 1)) preds = torch.argmax(outputs, dim=1) train_correct += (preds == labels).sum().item() train_fp += (preds[labels == 0] == 1).sum().item() train_fn += (preds[labels == 1] == 0).sum().item() train_total += labels.shape[0] test_correct += (preds == labels).sum().item() test_fp += (preds[labels == 0] == 1).sum().item() test_fn += (preds[labels == 1] == 0).sum().item() test_total += labels.shape[0] test_accuracy = test_correct / test_total test_fp_p = test_fp / test_total Loading