Commit 70f02c5b authored by Phil Wang's avatar Phil Wang
Browse files

semantic token ids will have variable lengths because of unique consecutive,...

semantic token ids will have variable lengths because of unique consecutive, so eos token must be manually selected and then used to predict the first coarse token, in the coarse transformer
parent c6bcd118
Loading
Loading
Loading
Loading
+87 −73
Original line number Diff line number Diff line
@@ -361,6 +361,7 @@ class SemanticTransformer(nn.Module):
        cond_scale = 3,
        filter_thres = 0.9,
        temperature = 1.,
        include_eos_in_output = True,  # if doing hierarchical sampling, eos must be kept for an easy time
        **kwargs
    ):
        device = self.device
@@ -422,7 +423,7 @@ class SemanticTransformer(nn.Module):

            last_logit_indices += 1

        output = mask_out_after_eos_id(output, self.pad_id, include_eos = False)
        output = mask_out_after_eos_id(output, self.pad_id, include_eos = include_eos_in_output)
        return output

    def forward_with_cond_scale(
@@ -612,7 +613,17 @@ class CoarseTransformer(nn.Module):

        tokens = self.transformer(tokens, context = text_embeds, self_attn_mask = self_attn_mask, context_mask = text_mask)

        pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, semantic_seq_len:]
        pred_semantic_tokens, pred_coarse_tokens = tokens[:, :semantic_seq_len], tokens[:, (semantic_seq_len + 1):]

        # get the eos token from predicted semantic tokens, and use that to predict the first coarse token

        semantic_eos = semantic_token_ids == self.semantic_eos_id
        pred_semantic_eos_tokens = tokens[:, 1:(semantic_seq_len + 1)][semantic_eos]

        pred_coarse_tokens = torch.cat((
            rearrange(pred_semantic_eos_tokens, 'b d -> b 1 d'),
            pred_coarse_tokens),
        dim = 1)

        # semantic logits

@@ -789,74 +800,6 @@ class FineTransformer(nn.Module):

# training wrappers

class FineTransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        transformer: FineTransformer,
        soundstream: Optional[SoundStream] = None,
        num_coarse_quantize = 3
    ):
        super().__init__()
        self.soundstream = soundstream
        self.transformer = transformer

        assert num_coarse_quantize > 0
        self.num_coarse_quantize = num_coarse_quantize

    def forward(
        self,
        *,
        raw_wave = None,
        coarse_token_ids = None,
        fine_token_ids = None,
        return_loss = False,
        **kwargs
    ):
        assert exists(raw_wave) ^ (exists(coarse_token_ids) and exists(fine_token_ids)), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'

        if exists(raw_wave):
            assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'

            with torch.no_grad():
                self.soundstream.eval()
                _, indices, _ = self.soundstream(raw_wave, return_encoded = True)
                coarse_token_ids, fine_token_ids = indices[..., :self.num_coarse_quantize], indices[..., self.num_coarse_quantize:]

        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
        fine_token_ids = rearrange(fine_token_ids, 'b ... -> b (...)')

        coarse_token_ids = append_eos_id(coarse_token_ids, self.transformer.eos_id)
        fine_token_ids = append_eos_id(fine_token_ids, self.transformer.eos_id)

        if return_loss:
            coarse_labels, fine_labels = coarse_token_ids, fine_token_ids.clone()
            fine_token_ids = fine_token_ids[:, :-1]

        coarse_logits, fine_logits = self.transformer(
            coarse_token_ids = coarse_token_ids,
            fine_token_ids = fine_token_ids,
            **kwargs
        )

        if not return_loss:
            return coarse_logits, fine_logits

        coarse_logits, fine_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, fine_logits))

        num_coarse_logits, num_fine_logits = coarse_logits.shape[-1], fine_logits.shape[-1]

        coarse_loss = F.cross_entropy(
            coarse_logits,
            coarse_labels
        )

        fine_loss = F.cross_entropy(
            fine_logits,
            fine_labels
        )

        return (coarse_loss * num_coarse_logits + fine_loss * num_fine_logits) / (num_coarse_logits + num_fine_logits)

class CoarseTransformerWrapper(nn.Module):
    def __init__(
@@ -905,11 +848,12 @@ class CoarseTransformerWrapper(nn.Module):
                _, indices, _ = self.soundstream(raw_wave, return_encoded = True)
                coarse_token_ids, _ = indices[..., :self.num_coarse_quantize], indices[..., self.num_coarse_quantize:]

        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
        semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)')
        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')

        coarse_token_ids = append_eos_id(coarse_token_ids, self.transformer.coarse_eos_id)
        if self.training:
            semantic_token_ids = append_eos_id(semantic_token_ids, self.transformer.semantic_eos_id)
            coarse_token_ids = append_eos_id(coarse_token_ids, self.transformer.coarse_eos_id)

        if self.unique_consecutive:
            semantic_token_ids = batch_unique_consecutive(semantic_token_ids, pad_value = self.pad_id)
@@ -954,6 +898,76 @@ class CoarseTransformerWrapper(nn.Module):

        return (semantic_loss * num_semantic_logits + coarse_loss * num_coarse_logits) / (num_semantic_logits + num_coarse_logits)

class FineTransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        transformer: FineTransformer,
        soundstream: Optional[SoundStream] = None,
        num_coarse_quantize = 3
    ):
        super().__init__()
        self.soundstream = soundstream
        self.transformer = transformer

        assert num_coarse_quantize > 0
        self.num_coarse_quantize = num_coarse_quantize

    def forward(
        self,
        *,
        raw_wave = None,
        coarse_token_ids = None,
        fine_token_ids = None,
        return_loss = False,
        **kwargs
    ):
        assert exists(raw_wave) ^ (exists(coarse_token_ids) and exists(fine_token_ids)), 'either raw waveform (raw_wav) is given, or coarse and fine token ids (coarse_token_ids, fine_token_ids)'

        if exists(raw_wave):
            assert exists(self.soundstream), 'SoundStream must be provided if given raw wave for training'

            with torch.no_grad():
                self.soundstream.eval()
                _, indices, _ = self.soundstream(raw_wave, return_encoded = True)
                coarse_token_ids, fine_token_ids = indices[..., :self.num_coarse_quantize], indices[..., self.num_coarse_quantize:]

        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
        fine_token_ids = rearrange(fine_token_ids, 'b ... -> b (...)')

        if self.training:
            coarse_token_ids = append_eos_id(coarse_token_ids, self.transformer.eos_id)
            fine_token_ids = append_eos_id(fine_token_ids, self.transformer.eos_id)

        if return_loss:
            coarse_labels, fine_labels = coarse_token_ids, fine_token_ids.clone()
            fine_token_ids = fine_token_ids[:, :-1]

        coarse_logits, fine_logits = self.transformer(
            coarse_token_ids = coarse_token_ids,
            fine_token_ids = fine_token_ids,
            **kwargs
        )

        if not return_loss:
            return coarse_logits, fine_logits

        coarse_logits, fine_logits = map(lambda t: rearrange(t, 'b n c -> b c n'), (coarse_logits, fine_logits))

        num_coarse_logits, num_fine_logits = coarse_logits.shape[-1], fine_logits.shape[-1]

        coarse_loss = F.cross_entropy(
            coarse_logits,
            coarse_labels
        )

        fine_loss = F.cross_entropy(
            fine_logits,
            fine_labels
        )

        return (coarse_loss * num_coarse_logits + fine_loss * num_fine_logits) / (num_coarse_logits + num_fine_logits)

# audio LM

class AudioLM(nn.Module):
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.31',
  version = '0.0.32',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',