Commit e7f48bdb authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add embedding check

parent 24d04e52
Loading
Loading
Loading
Loading
+10 −0
Original line number Diff line number Diff line
import numpy as np
import pytest
from PIL import Image

@@ -166,6 +167,7 @@ class TestGenericClassify:
                'scores-top10-descriptions': 'scores-top10-descriptions',
                'scores-top5-definitions': 'scores-top5-definitions',
                'scores-top5-descriptions': 'scores-top5-descriptions',
                'embedding': 'embedding',
            }
        )
        assert results['scores-top10'] == pytest.approx({
@@ -206,3 +208,11 @@ class TestGenericClassify:
            'screwdriver': 0.029927952215075493,
            'chain saw, chainsaw': 0.02070867270231247
        }, abs=1e-3)
        # np.save(get_testfile('png_640_emb.npy'), results['embedding'])
        assert results['embedding'].shape == (1280,)
        expected_embedding = np.load(get_testfile('png_640_emb.npy'))
        emb_1 = results['embedding'] / np.linalg.norm(results['embedding'], axis=-1, keepdims=True)
        emb_2 = expected_embedding / np.linalg.norm(expected_embedding, axis=-1, keepdims=True)
        emb_sims = (emb_1 * emb_2).sum()
        assert emb_sims >= 0.99, 'Direction not match with expected embedding.'
        assert np.linalg.norm(results['embedding']) == pytest.approx(np.linalg.norm(expected_embedding))
+5.13 KiB

File added.

No diff preview for this file type.