Commit 65495ad5 authored by Phil Wang
Browse files

remove use of eos in fine transformer

parent b6eeb93f
Loading
Loading
Loading
Loading
+14 −19
Original line number Diff line number Diff line
@@ -773,10 +773,8 @@ class FineTransformer(nn.Module):
        self.coarse_start_token = nn.Parameter(torch.randn(dim))
        self.fine_start_token = nn.Parameter(torch.randn(dim))

        codebook_size_with_eos = codebook_size + 1

        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size_with_eos, dim)
        self.fine_embedding = nn.Embedding(num_fine_quantizers * codebook_size_with_eos, dim)
        self.coarse_embedding = nn.Embedding(num_coarse_quantizers * codebook_size, dim)
        self.fine_embedding = nn.Embedding(num_fine_quantizers * codebook_size, dim)

        self.eos_id = codebook_size

@@ -799,8 +797,8 @@ class FineTransformer(nn.Module):
        self.num_coarse_quantizers = num_coarse_quantizers
        self.num_fine_quantizers = num_fine_quantizers

        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size_with_eos, dim)) if project_coarse_logits else None
        self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size_with_eos, dim))
        self.coarse_logit_weights = nn.Parameter(torch.randn(num_coarse_quantizers, codebook_size, dim)) if project_coarse_logits else None
        self.fine_logit_weights = nn.Parameter(torch.randn(num_fine_quantizers, codebook_size, dim))

    @property
    def device(self):
@@ -1313,7 +1311,8 @@ class CoarseTransformerWrapper(nn.Module):

        coarse_loss = F.cross_entropy(
            coarse_logits,
            coarse_labels
            coarse_labels,
            ignore_index = self.pad_id
        )

        return (
@@ -1415,9 +1414,6 @@ class FineTransformerWrapper(nn.Module):

                last_fine_logits = fine_logits[:, -1]

                if not is_last_step:
                    last_fine_logits[:, -1] = float('-inf') # prevent from eos if not last quantizer step, but move this to masking logic within the transformer at some point, for both training and eval

                filtered_logits = top_k(last_fine_logits, thres = filter_thres)
                sampled = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)

@@ -1435,8 +1431,6 @@ class FineTransformerWrapper(nn.Module):

        if mask_out_generated_fine_tokens:
            pos_is_all_padding = (coarse_token_ids == self.pad_id).all(dim = -1, keepdim = True)
            seq_lengths = reduce(~pos_is_all_padding, 'b n 1 -> b', 'sum')

            sampled_fine_token_ids = sampled_fine_token_ids.masked_fill(pos_is_all_padding, self.pad_id)

        # if not reconstructing wave, return just the fine token ids
@@ -1485,17 +1479,16 @@ class FineTransformerWrapper(nn.Module):
        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
        fine_token_ids = rearrange(fine_token_ids, 'b ... -> b (...)')

        if self.training:
            coarse_token_ids = append_eos_id(coarse_token_ids, self.transformer.eos_id)
            fine_token_ids = append_eos_id(fine_token_ids, self.transformer.eos_id)
        # if training, determine labels, should remove one from fine token ids

        if return_loss:
            coarse_labels, fine_labels = coarse_token_ids, fine_token_ids.clone()
            coarse_labels = coarse_token_ids
            fine_labels = fine_token_ids
            fine_token_ids = fine_token_ids[:, :-1]

        # do not attend to any of the coarse padding tokens or coarse end token either

        self_attn_mask = (coarse_token_ids != self.pad_id) & (coarse_token_ids != self.eos_id)
        self_attn_mask = coarse_token_ids != self.pad_id
        coarse_token_ids = coarse_token_ids.masked_fill(~self_attn_mask, 0)

        fine_token_seq_len = fine_token_ids.shape[-1]
@@ -1532,12 +1525,14 @@ class FineTransformerWrapper(nn.Module):

            coarse_loss = F.cross_entropy(
                coarse_logits,
                coarse_labels
                coarse_labels,
                ignore_index = self.pad_id
            )

        fine_loss = F.cross_entropy(
            fine_logits,
            fine_labels
            fine_labels,
            ignore_index = self.pad_id
        )

        return (
+1 −1
Original line number Diff line number Diff line
__version__ = '0.18.2'
__version__ = '0.19.0'