Loading zoo/camie/model/initial.py +10 −921 File changed.Preview size limit exceeded, changes collapsed. Show changes zoo/camie/model/refined.py +7 −5 Original line number Diff line number Diff line Loading @@ -58,7 +58,7 @@ class MultiheadAttentionNoFlash(nn.Module): return output class ImageTaggerRefinedONNX(nn.Module): class CamieTaggerRefined(nn.Module): """ Refined CAMIE Image Tagger model without FlashAttention. - EfficientNetV2 backbone Loading @@ -68,7 +68,7 @@ class ImageTaggerRefinedONNX(nn.Module): - Refined classifier for final tag logits """ def __init__(self, total_tags, tag_context_size=256, num_heads=16, dropout=0.1): def __init__(self, total_tags: int, tag_context_size: int = 256, num_heads: int = 16, dropout: float = 0.1): super().__init__() self.tag_context_size = tag_context_size self.embedding_dim = 1280 # EfficientNetV2-L feature dimension Loading Loading @@ -165,7 +165,7 @@ class ImageTaggerRefinedONNX(nn.Module): refined_logits = self.refined_classifier(combined) # [B, total_tags] refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0) return initial_preds, refined_preds return feats, initial_preds, fused_feature, refined_preds if __name__ == '__main__': Loading @@ -180,7 +180,7 @@ if __name__ == '__main__': # state_dict = torch.load("model_refined.pt", map_location="cpu") # Load the saved weights (should be an OrderedDict) # Initialize our model and load weights model = ImageTaggerRefinedONNX(total_tags=total_tags) model = CamieTaggerRefined(total_tags=total_tags) model.load_state_dict(state_dict) model.eval() # set to evaluation mode (disable dropout) print(model) Loading @@ -191,8 +191,10 @@ if __name__ == '__main__': # --- Export to ONNX --- dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float32) with torch.no_grad(): dummy_init_logits, dummy_refined_logits = model(dummy_input) dummy_init_embeddings, dummy_init_logits, dummy_refined_embeddings, dummy_refined_logits = model(dummy_input) 
print(dummy_init_embeddings.shape, dummy_init_embeddings.dtype) print(dummy_init_logits.shape, dummy_init_logits.dtype) print(dummy_refined_embeddings.shape, dummy_refined_embeddings.dtype) print(dummy_refined_logits.shape, dummy_refined_logits.dtype) # output_onnx_file = "camie_refined_no_flash_v15.onnx" Loading zoo/camie/model/wrapper.py 0 → 100644 +26 −0 Original line number Diff line number Diff line import torch from torch import nn class InitialOnlyWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model def forward(self, x): init_embeddings, init_logits, _, _ = self.model(x) init_prediction = torch.sigmoid(init_logits) return init_embeddings, init_logits, init_prediction class FullWrapper(nn.Module): def __init__(self, model): super().__init__() self.model = model def forward(self, x): init_embeddings, init_logits, refined_embeddings, refined_logits = self.model(x) init_prediction = torch.sigmoid(init_logits) refined_prediction = torch.sigmoid(refined_logits) return init_embeddings, init_logits, init_prediction, \ refined_embeddings, refined_logits, refined_prediction Loading
zoo/camie/model/initial.py +10 −921 File changed. Preview size limit exceeded, changes collapsed. Show changes
zoo/camie/model/refined.py +7 −5 Original line number Diff line number Diff line Loading @@ -58,7 +58,7 @@ class MultiheadAttentionNoFlash(nn.Module): return output class ImageTaggerRefinedONNX(nn.Module): class CamieTaggerRefined(nn.Module): """ Refined CAMIE Image Tagger model without FlashAttention. - EfficientNetV2 backbone Loading @@ -68,7 +68,7 @@ class ImageTaggerRefinedONNX(nn.Module): - Refined classifier for final tag logits """ def __init__(self, total_tags, tag_context_size=256, num_heads=16, dropout=0.1): def __init__(self, total_tags: int, tag_context_size: int = 256, num_heads: int = 16, dropout: float = 0.1): super().__init__() self.tag_context_size = tag_context_size self.embedding_dim = 1280 # EfficientNetV2-L feature dimension Loading Loading @@ -165,7 +165,7 @@ class ImageTaggerRefinedONNX(nn.Module): refined_logits = self.refined_classifier(combined) # [B, total_tags] refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0) return initial_preds, refined_preds return feats, initial_preds, fused_feature, refined_preds if __name__ == '__main__': Loading @@ -180,7 +180,7 @@ if __name__ == '__main__': # state_dict = torch.load("model_refined.pt", map_location="cpu") # Load the saved weights (should be an OrderedDict) # Initialize our model and load weights model = ImageTaggerRefinedONNX(total_tags=total_tags) model = CamieTaggerRefined(total_tags=total_tags) model.load_state_dict(state_dict) model.eval() # set to evaluation mode (disable dropout) print(model) Loading @@ -191,8 +191,10 @@ if __name__ == '__main__': # --- Export to ONNX --- dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float32) with torch.no_grad(): dummy_init_logits, dummy_refined_logits = model(dummy_input) dummy_init_embeddings, dummy_init_logits, dummy_refined_embeddings, dummy_refined_logits = model(dummy_input) print(dummy_init_embeddings.shape, dummy_init_embeddings.dtype) print(dummy_init_logits.shape, dummy_init_logits.dtype) 
print(dummy_refined_embeddings.shape, dummy_refined_embeddings.dtype) print(dummy_refined_logits.shape, dummy_refined_logits.dtype) # output_onnx_file = "camie_refined_no_flash_v15.onnx" Loading
# zoo/camie/model/wrapper.py  (new file: 0 → 100644, +26 −0)
#
# Thin nn.Module wrappers around a tagger model whose forward pass returns
# four values: (init_embeddings, init_logits, refined_embeddings,
# refined_logits). Each wrapper additionally returns the sigmoid of the
# logits it exposes.
from typing import Tuple

import torch
from torch import nn


class InitialOnlyWrapper(nn.Module):
    """Expose only the initial-stage outputs of a wrapped tagger model.

    The wrapped model is called in full; its third and fourth outputs are
    discarded by this wrapper.
    """

    def __init__(self, model: nn.Module) -> None:
        """Store the model whose forward pass yields a 4-tuple of tensors."""
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the wrapped model on ``x`` and return the initial-stage results.

        Returns:
            ``(init_embeddings, init_logits, init_prediction)`` where
            ``init_prediction`` is ``torch.sigmoid(init_logits)``.
        """
        # Only the first two outputs are used; the last two are ignored.
        init_embeddings, init_logits, _, _ = self.model(x)
        init_prediction = torch.sigmoid(init_logits)
        return init_embeddings, init_logits, init_prediction


class FullWrapper(nn.Module):
    """Expose both the initial and refined outputs of a wrapped tagger model."""

    def __init__(self, model: nn.Module) -> None:
        """Store the model whose forward pass yields a 4-tuple of tensors."""
        super().__init__()
        self.model = model

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor,
        torch.Tensor, torch.Tensor, torch.Tensor,
    ]:
        """Run the wrapped model on ``x`` and return all outputs plus sigmoids.

        Returns:
            ``(init_embeddings, init_logits, init_prediction,
            refined_embeddings, refined_logits, refined_prediction)`` where
            each ``*_prediction`` is ``torch.sigmoid`` of the matching logits.
        """
        init_embeddings, init_logits, refined_embeddings, refined_logits = self.model(x)
        init_prediction = torch.sigmoid(init_logits)
        refined_prediction = torch.sigmoid(refined_logits)
        # Parenthesized tuple instead of a backslash continuation (PEP 8).
        return (
            init_embeddings, init_logits, init_prediction,
            refined_embeddings, refined_logits, refined_prediction,
        )