Commit 91cee2e4 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): addd x

parent 8a973bd2
Loading
Loading
Loading
Loading
+13 −0
Original line number Diff line number Diff line
@@ -103,6 +103,19 @@ def extract(export_dir: str, model_name: str = "tagger_v_2_2_7", no_optimize: bo
    logging.info(f'{ratio * 100:.2f}% of the logits value are the same.')
    assert close_matrix.all(), 'Not all values can match.'

    matrix_data_file = os.path.join(export_dir, 'matrix.npz')
    bias = classifier.bias.detach().numpy()
    weight = classifier.weight.detach().numpy().T
    logging.info(f'Saving matrix data file to {matrix_data_file!r}, '
                 f'bias: {bias.dtype!r}{bias.shape!r}, weight: {weight.dtype!r}{weight.shape!r}.')
    np.savez(
        matrix_data_file,
        bias=bias,
        weight=weight,
    )
    expected_logits = conv_features.detach().numpy() @ weight + bias
    np.testing.assert_allclose(conv_output.detach().numpy(), expected_logits, rtol=1e-03, atol=1e-05)

    logging.info('Profiling model ...')
    macs, params = profile(model, inputs=(dummy_input,))
    s_macs, s_params = clever_format([macs, params], "%.1f")