Commit 46664c9b authored by narugo1992
Browse files

dev(narugo): refactor initial and refined model

parent f40f6991
Loading
Loading
Loading
Loading
+10 −921

File changed.

Preview size limit exceeded, changes collapsed.

+7 −5
Original line number Diff line number Diff line
@@ -58,7 +58,7 @@ class MultiheadAttentionNoFlash(nn.Module):
        return output


class ImageTaggerRefinedONNX(nn.Module):
class CamieTaggerRefined(nn.Module):
    """
    Refined CAMIE Image Tagger model without FlashAttention.
    - EfficientNetV2 backbone
@@ -68,7 +68,7 @@ class ImageTaggerRefinedONNX(nn.Module):
    - Refined classifier for final tag logits
    """

    def __init__(self, total_tags, tag_context_size=256, num_heads=16, dropout=0.1):
    def __init__(self, total_tags: int, tag_context_size: int = 256, num_heads: int = 16, dropout: float = 0.1):
        super().__init__()
        self.tag_context_size = tag_context_size
        self.embedding_dim = 1280  # EfficientNetV2-L feature dimension
@@ -165,7 +165,7 @@ class ImageTaggerRefinedONNX(nn.Module):
        refined_logits = self.refined_classifier(combined)  # [B, total_tags]
        refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)

        return initial_preds, refined_preds
        return feats, initial_preds, fused_feature, refined_preds


if __name__ == '__main__':
@@ -180,7 +180,7 @@ if __name__ == '__main__':
    # state_dict = torch.load("model_refined.pt", map_location="cpu")  # Load the saved weights (should be an OrderedDict)

    # Initialize our model and load weights
    model = ImageTaggerRefinedONNX(total_tags=total_tags)
    model = CamieTaggerRefined(total_tags=total_tags)
    model.load_state_dict(state_dict)
    model.eval()  # set to evaluation mode (disable dropout)
    print(model)
@@ -191,8 +191,10 @@ if __name__ == '__main__':
    # --- Export to ONNX ---
    dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float32)
    with torch.no_grad():
        dummy_init_logits, dummy_refined_logits = model(dummy_input)
        dummy_init_embeddings, dummy_init_logits, dummy_refined_embeddings, dummy_refined_logits = model(dummy_input)
    print(dummy_init_embeddings.shape, dummy_init_embeddings.dtype)
    print(dummy_init_logits.shape, dummy_init_logits.dtype)
    print(dummy_refined_embeddings.shape, dummy_refined_embeddings.dtype)
    print(dummy_refined_logits.shape, dummy_refined_logits.dtype)

    # output_onnx_file = "camie_refined_no_flash_v15.onnx"
+26 −0
Original line number Diff line number Diff line
import torch
from torch import nn


class InitialOnlyWrapper(nn.Module):
    """
    Expose only the initial-stage outputs of a wrapped four-output model.

    The wrapped model is called with the input unchanged and must return
    exactly four values. The first two are treated as the initial embeddings
    and the initial logits; the last two are discarded.
    """

    def __init__(self, model):
        """
        Keep a reference to the wrapped model.

        :param model: Callable (normally an ``nn.Module``) returning four values.
        """
        super().__init__()
        self.model = model

    def forward(self, x):
        """
        Run the wrapped model and return the initial-stage results.

        :param x: Input passed as-is to the wrapped model.
        :return: Tuple ``(embeddings, logits, probabilities)``, where the
            probabilities are ``torch.sigmoid`` applied to the logits.
        """
        # Strict four-way unpack: the two refined outputs are dropped here.
        embeddings, logits, _refined_embeddings, _refined_logits = self.model(x)
        probabilities = torch.sigmoid(logits)
        return embeddings, logits, probabilities


class FullWrapper(nn.Module):
    """
    Expose both the initial and refined outputs of a wrapped four-output model.

    The wrapped model is called with the input unchanged and must return
    exactly four values, taken in order as: initial embeddings, initial
    logits, refined embeddings, refined logits. A sigmoid of each logits
    tensor is added to the outputs.
    """

    def __init__(self, model):
        """
        Keep a reference to the wrapped model.

        :param model: Callable (normally an ``nn.Module``) returning four values.
        """
        super().__init__()
        self.model = model

    def forward(self, x):
        """
        Run the wrapped model and return all outputs plus probabilities.

        :param x: Input passed as-is to the wrapped model.
        :return: Six-element tuple ``(init_embeddings, init_logits,
            init_probabilities, refined_embeddings, refined_logits,
            refined_probabilities)``; each ``*_probabilities`` entry is
            ``torch.sigmoid`` of the matching logits.
        """
        first_emb, first_logits, second_emb, second_logits = self.model(x)

        first_prob = torch.sigmoid(first_logits)
        second_prob = torch.sigmoid(second_logits)

        return (
            first_emb, first_logits, first_prob,
            second_emb, second_logits, second_prob,
        )