Commit 9ec98366 authored by Phil Wang's avatar Phil Wang
Browse files

fix coarse transformer wrapper loss if semantic loss weight set to 0

parent 504a7e17
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -1463,12 +1463,16 @@ class CoarseTransformerWrapper(nn.Module):
        coarse_logits, semantic_logits = map(lambda t: maybe(rearrange)(t, 'b n c -> b c n'), (coarse_logits, semantic_logits))

        if self.unique_consecutive:
            num_coarse_logits, num_semantic_logits = coarse_labels.numel(), (semantic_labels != self.pad_id).sum()
            num_coarse_logits, _num_semantic_logits = coarse_labels.numel(), (semantic_labels != self.pad_id).sum()
        else:
            num_coarse_logits, num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1]
            num_coarse_logits, _num_semantic_logits = coarse_logits.shape[-1], semantic_logits.shape[-1]

        semantic_loss = 0.
        num_semantic_logits = 0

        if self.semantic_cross_entropy_loss_weight > 0 and exists(semantic_logits):
            num_semantic_logits = _num_semantic_logits

            semantic_loss = F.cross_entropy(
                semantic_logits,
                semantic_labels,
+1 −1
Original line number Diff line number Diff line
__version__ = '0.30.2'
__version__ = '0.30.3'