Commit 6143d0ca authored by narugo1992's avatar narugo1992
Browse files

dev(narugo): use combined feature, ci skip

parent 5c45f8b7
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -161,13 +161,13 @@ class CamieTaggerRefined(nn.Module):

        # 7. Fuse features: average the cross-attended tag outputs, and combine with original features
        fused_feature = cross_attn.mean(dim=1)  # [B, embedding_dim]
        combined = torch.cat([feats, fused_feature], dim=1)  # [B, embedding_dim*2]
        combined_feature = torch.cat([feats, fused_feature], dim=1)  # [B, embedding_dim*2]

        # 8. Refined tag prediction
        refined_logits = self.refined_classifier(combined)  # [B, total_tags]
        refined_logits = self.refined_classifier(combined_feature)  # [B, total_tags]
        refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)

        return feats, initial_preds, fused_feature, refined_preds
        return feats, initial_preds, combined_feature, refined_preds


def create_refined_model():