Loading zoo/pixai_tagger/export.py +4 −3 Original line number Diff line number Diff line Loading @@ -91,7 +91,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): with open(new_meta_file, 'w') as f: json.dump(meta_info, f, indent=4, sort_keys=True, ensure_ascii=False) wrapped_model, (conv_features, _) = get_model(handler.model, dummy_input) wrapped_model, (conv_features, _, _) = get_model(handler.model, dummy_input) conv_features = conv_features.detach().cpu() onnx_filename = os.path.join(upload_dir, 'model.onnx') with TemporaryDirectory() as td: Loading @@ -102,11 +102,12 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): dummy_input, temp_model_onnx, input_names=['input'], output_names=['embedding', 'output'], output_names=['embedding', 'logits', 'prediction'], dynamic_axes={ 'input': {0: 'batch_size'}, 'embedding': {0: 'batch_size'}, 'output': {0: 'batch_size'}, 'logits': {0: 'batch_size'}, 'prediction': {0: 'batch_size'}, }, opset_version=14, do_constant_folding=True, Loading zoo/pixai_tagger/min_script.py +2 −1 Original line number Diff line number Diff line Loading @@ -22,10 +22,11 @@ class TaggingHead(torch.nn.Module): self.input_dim = input_dim self.num_classes = num_classes self.head = torch.nn.Sequential(torch.nn.Linear(input_dim, num_classes)) self.sigmoid = torch.nn.Sigmoid() def forward(self, x): logits = self.head(x) probs = torch.nn.functional.sigmoid(logits) probs = self.sigmoid(logits) return probs Loading zoo/pixai_tagger/onnx.py +21 −8 Original line number Diff line number Diff line Loading @@ -5,44 +5,57 @@ from torch import nn class ModuleWrapper(nn.Module): def __init__(self, base_module: nn.Module, classifier: nn.Module): def __init__(self, base_module: nn.Module, classifier: nn.Module, sigmoid: nn.Module): super().__init__() self.base_module = base_module self.classifier = classifier self.sigmoid = sigmoid self._output_features = None self._output_logits = None self._register_hook() def _register_hook(self): def hook_fn(module, input_tensor, output_tensor): def hook_fn_embeddings(module, input_tensor, output_tensor): assert isinstance(input_tensor, tuple) and len(input_tensor) == 1 input_tensor = input_tensor[0] self._output_features = input_tensor self.classifier.register_forward_hook(hook_fn) self.classifier.register_forward_hook(hook_fn_embeddings) def hook_fn_logits(module, input_tensor, output_tensor): assert isinstance(input_tensor, tuple) and len(input_tensor) == 1 input_tensor = input_tensor[0] self._output_logits = input_tensor self.sigmoid.register_forward_hook(hook_fn_logits) def forward(self, x: torch.Tensor): preds = self.base_module(x) if self._output_features is None: raise RuntimeError("Target module did not receive any input during forward pass") raise RuntimeError("Target module did not receive any input during forward pass (features)") if self._output_logits is None: raise RuntimeError("Target module did not receive any input during forward pass (logits)") features, self._output_features = self._output_features, None logits, self._output_logits = self._output_logits, None assert all([x == 1 for x in features.shape[2:]]), f'Invalid feature shape: {features.shape!r}' features = torch.flatten(features, start_dim=1) return features, preds return features, logits, preds def get_model(model: nn.Module, dummy_input: torch.Tensor): assert isinstance(model, nn.Sequential) head = model[-1] wrapped_model = ModuleWrapper(model, head) wrapped_model = ModuleWrapper(model, head, head.sigmoid) logging.info(f'Input size: {dummy_input.shape!r}') with torch.no_grad(): dummy_embedding, dummy_preds = wrapped_model(dummy_input) dummy_embedding, dummy_logits, dummy_preds = wrapped_model(dummy_input) logging.info(f'Embedding size: {dummy_embedding.shape!r}') logging.info(f'Logits size: {dummy_preds.shape!r}') logging.info(f'Preds size: {dummy_preds.shape!r}') return wrapped_model, (dummy_embedding, dummy_preds) return wrapped_model, (dummy_embedding, dummy_logits, dummy_preds) # print(model[-1]) Loading
zoo/pixai_tagger/export.py +4 −3 Original line number Diff line number Diff line Loading @@ -91,7 +91,7 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): with open(new_meta_file, 'w') as f: json.dump(meta_info, f, indent=4, sort_keys=True, ensure_ascii=False) wrapped_model, (conv_features, _) = get_model(handler.model, dummy_input) wrapped_model, (conv_features, _, _) = get_model(handler.model, dummy_input) conv_features = conv_features.detach().cpu() onnx_filename = os.path.join(upload_dir, 'model.onnx') with TemporaryDirectory() as td: Loading @@ -102,11 +102,12 @@ def sync(src_repo: str, dst_repo: str, no_optimize: bool = False): dummy_input, temp_model_onnx, input_names=['input'], output_names=['embedding', 'output'], output_names=['embedding', 'logits', 'prediction'], dynamic_axes={ 'input': {0: 'batch_size'}, 'embedding': {0: 'batch_size'}, 'output': {0: 'batch_size'}, 'logits': {0: 'batch_size'}, 'prediction': {0: 'batch_size'}, }, opset_version=14, do_constant_folding=True, Loading
zoo/pixai_tagger/min_script.py +2 −1 Original line number Diff line number Diff line Loading @@ -22,10 +22,11 @@ class TaggingHead(torch.nn.Module): self.input_dim = input_dim self.num_classes = num_classes self.head = torch.nn.Sequential(torch.nn.Linear(input_dim, num_classes)) self.sigmoid = torch.nn.Sigmoid() def forward(self, x): logits = self.head(x) probs = torch.nn.functional.sigmoid(logits) probs = self.sigmoid(logits) return probs Loading
zoo/pixai_tagger/onnx.py +21 −8 Original line number Diff line number Diff line Loading @@ -5,44 +5,57 @@ from torch import nn class ModuleWrapper(nn.Module): def __init__(self, base_module: nn.Module, classifier: nn.Module): def __init__(self, base_module: nn.Module, classifier: nn.Module, sigmoid: nn.Module): super().__init__() self.base_module = base_module self.classifier = classifier self.sigmoid = sigmoid self._output_features = None self._output_logits = None self._register_hook() def _register_hook(self): def hook_fn(module, input_tensor, output_tensor): def hook_fn_embeddings(module, input_tensor, output_tensor): assert isinstance(input_tensor, tuple) and len(input_tensor) == 1 input_tensor = input_tensor[0] self._output_features = input_tensor self.classifier.register_forward_hook(hook_fn) self.classifier.register_forward_hook(hook_fn_embeddings) def hook_fn_logits(module, input_tensor, output_tensor): assert isinstance(input_tensor, tuple) and len(input_tensor) == 1 input_tensor = input_tensor[0] self._output_logits = input_tensor self.sigmoid.register_forward_hook(hook_fn_logits) def forward(self, x: torch.Tensor): preds = self.base_module(x) if self._output_features is None: raise RuntimeError("Target module did not receive any input during forward pass") raise RuntimeError("Target module did not receive any input during forward pass (features)") if self._output_logits is None: raise RuntimeError("Target module did not receive any input during forward pass (logits)") features, self._output_features = self._output_features, None logits, self._output_logits = self._output_logits, None assert all([x == 1 for x in features.shape[2:]]), f'Invalid feature shape: {features.shape!r}' features = torch.flatten(features, start_dim=1) return features, preds return features, logits, preds def get_model(model: nn.Module, dummy_input: torch.Tensor): assert isinstance(model, nn.Sequential) head = model[-1] wrapped_model = ModuleWrapper(model, head) wrapped_model = ModuleWrapper(model, head, head.sigmoid) logging.info(f'Input size: {dummy_input.shape!r}') with torch.no_grad(): dummy_embedding, dummy_preds = wrapped_model(dummy_input) dummy_embedding, dummy_logits, dummy_preds = wrapped_model(dummy_input) logging.info(f'Embedding size: {dummy_embedding.shape!r}') logging.info(f'Logits size: {dummy_preds.shape!r}') logging.info(f'Preds size: {dummy_preds.shape!r}') return wrapped_model, (dummy_embedding, dummy_preds) return wrapped_model, (dummy_embedding, dummy_logits, dummy_preds) # print(model[-1])