Loading zoo/camie/model/refined.py +3 −3 Original line number Diff line number Diff line Loading @@ -161,13 +161,13 @@ class CamieTaggerRefined(nn.Module): # 7. Fuse features: average the cross-attended tag outputs, and combine with original features fused_feature = cross_attn.mean(dim=1) # [B, embedding_dim] combined = torch.cat([feats, fused_feature], dim=1) # [B, embedding_dim*2] combined_feature = torch.cat([feats, fused_feature], dim=1) # [B, embedding_dim*2] # 8. Refined tag prediction refined_logits = self.refined_classifier(combined) # [B, total_tags] refined_logits = self.refined_classifier(combined_feature) # [B, total_tags] refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0) return feats, initial_preds, fused_feature, refined_preds return feats, initial_preds, combined_feature, refined_preds def create_refined_model(): Loading Loading
zoo/camie/model/refined.py +3 −3 Original line number Diff line number Diff line Loading @@ -161,13 +161,13 @@ class CamieTaggerRefined(nn.Module): # 7. Fuse features: average the cross-attended tag outputs, and combine with original features fused_feature = cross_attn.mean(dim=1) # [B, embedding_dim] combined = torch.cat([feats, fused_feature], dim=1) # [B, embedding_dim*2] combined_feature = torch.cat([feats, fused_feature], dim=1) # [B, embedding_dim*2] # 8. Refined tag prediction refined_logits = self.refined_classifier(combined) # [B, total_tags] refined_logits = self.refined_classifier(combined_feature) # [B, total_tags] refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0) return feats, initial_preds, fused_feature, refined_preds return feats, initial_preds, combined_feature, refined_preds def create_refined_model(): Loading