Commit 18a9bad5 authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): add more profiles, ci skip

parent 82cf1103
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -91,7 +91,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):
        with open(new_meta_file, 'w') as f:
            json.dump(meta_info, f, indent=4, sort_keys=True, ensure_ascii=False)

        wrapped_model, (conv_features, _) = get_model(handler.model, dummy_input)
        wrapped_model, (conv_features, _, _) = get_model(handler.model, dummy_input)
        conv_features = conv_features.detach().cpu()
        onnx_filename = os.path.join(upload_dir, 'model.onnx')
        with TemporaryDirectory() as td:
@@ -102,11 +102,12 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False):
                dummy_input,
                temp_model_onnx,
                input_names=['input'],
                output_names=['embedding', 'output'],
                output_names=['embedding', 'logits', 'prediction'],
                dynamic_axes={
                    'input': {0: 'batch_size'},
                    'embedding': {0: 'batch_size'},
                    'output': {0: 'batch_size'},
                    'logits': {0: 'batch_size'},
                    'prediction': {0: 'batch_size'},
                },
                opset_version=14,
                do_constant_folding=True,
+2 −1
Original line number Diff line number Diff line
@@ -22,10 +22,11 @@ class TaggingHead(torch.nn.Module):
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.head = torch.nn.Sequential(torch.nn.Linear(input_dim, num_classes))
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        logits = self.head(x)
        probs = torch.nn.functional.sigmoid(logits)
        probs = self.sigmoid(logits)
        return probs


+21 −8
Original line number Diff line number Diff line
@@ -5,44 +5,57 @@ from torch import nn


class ModuleWrapper(nn.Module):
    def __init__(self, base_module: nn.Module, classifier: nn.Module):
    def __init__(self, base_module: nn.Module, classifier: nn.Module, sigmoid: nn.Module):
        super().__init__()
        self.base_module = base_module
        self.classifier = classifier
        self.sigmoid = sigmoid

        self._output_features = None
        self._output_logits = None
        self._register_hook()

    def _register_hook(self):
        def hook_fn(module, input_tensor, output_tensor):
        def hook_fn_embeddings(module, input_tensor, output_tensor):
            assert isinstance(input_tensor, tuple) and len(input_tensor) == 1
            input_tensor = input_tensor[0]
            self._output_features = input_tensor

        self.classifier.register_forward_hook(hook_fn)
        self.classifier.register_forward_hook(hook_fn_embeddings)

        def hook_fn_logits(module, input_tensor, output_tensor):
            assert isinstance(input_tensor, tuple) and len(input_tensor) == 1
            input_tensor = input_tensor[0]
            self._output_logits = input_tensor

        self.sigmoid.register_forward_hook(hook_fn_logits)

    def forward(self, x: torch.Tensor):
        preds = self.base_module(x)

        if self._output_features is None:
            raise RuntimeError("Target module did not receive any input during forward pass")
            raise RuntimeError("Target module did not receive any input during forward pass (features)")
        if self._output_logits is None:
            raise RuntimeError("Target module did not receive any input during forward pass (logits)")
        features, self._output_features = self._output_features, None
        logits, self._output_logits = self._output_logits, None
        assert all([x == 1 for x in features.shape[2:]]), f'Invalid feature shape: {features.shape!r}'
        features = torch.flatten(features, start_dim=1)

        return features, preds
        return features, logits, preds


def get_model(model: nn.Module, dummy_input: torch.Tensor):
    assert isinstance(model, nn.Sequential)
    head = model[-1]
    wrapped_model = ModuleWrapper(model, head)
    wrapped_model = ModuleWrapper(model, head, head.sigmoid)

    logging.info(f'Input size: {dummy_input.shape!r}')
    with torch.no_grad():
        dummy_embedding, dummy_preds = wrapped_model(dummy_input)
        dummy_embedding, dummy_logits, dummy_preds = wrapped_model(dummy_input)
    logging.info(f'Embedding size: {dummy_embedding.shape!r}')
    logging.info(f'Logits size: {dummy_preds.shape!r}')
    logging.info(f'Preds size: {dummy_preds.shape!r}')

    return wrapped_model, (dummy_embedding, dummy_preds)
    return wrapped_model, (dummy_embedding, dummy_logits, dummy_preds)
    # print(model[-1])