Commit 3e1f569e authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): fix a bug

parent 5eb326d6
Loading
Loading
Loading
Loading
+4 −4
Original line number Diff line number Diff line
@@ -196,10 +196,10 @@ def train(dataset_dir: str, session_name: Optional[str] = None, from_ckpt: Optio

                    outputs = model.diff(dists.reshape(-1, 1))
                    preds = torch.argmax(outputs, dim=1)
                    train_correct += (preds == labels).sum().item()
                    train_fp += (preds[labels == 0] == 1).sum().item()
                    train_fn += (preds[labels == 1] == 0).sum().item()
                    train_total += labels.shape[0]
                    test_correct += (preds == labels).sum().item()
                    test_fp += (preds[labels == 0] == 1).sum().item()
                    test_fn += (preds[labels == 1] == 0).sum().item()
                    test_total += labels.shape[0]

                test_accuracy = test_correct / test_total
                test_fp_p = test_fp / test_total