Commit a7cb08d8 authored by Phil Wang
Browse files

cleanup

parent bc626a64
Loading
Loading
Loading
Loading
+9 −9
Original line number Diff line number Diff line
@@ -683,6 +683,8 @@ class FineTransformer(nn.Module):
        self.embed_text = partial(t5_encode_text, name = t5_name)
        self.cond_drop_prob = cond_drop_prob

        self.num_coarse_quantizers = num_coarse_quantizers

        self.coarse_start_token = nn.Parameter(torch.randn(dim))
        self.fine_start_token = nn.Parameter(torch.randn(dim))

@@ -824,10 +826,9 @@ class CoarseTransformerWrapper(nn.Module):
    def __init__(
        self,
        *,
        transformer: FineTransformer,
        transformer: CoarseTransformer,
        soundstream: Optional[SoundStream]  = None,
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]] = None,
        num_coarse_quantize = 3,
        pad_id = -1,
        unique_consecutive = True
    ):
@@ -839,8 +840,7 @@ class CoarseTransformerWrapper(nn.Module):
        self.unique_consecutive = unique_consecutive
        self.pad_id = pad_id

        assert num_coarse_quantize > 0
        self.num_coarse_quantize = num_coarse_quantize
        self.num_coarse_quantizers = transformer.num_coarse_quantizers

    def forward(
        self,
@@ -865,7 +865,7 @@ class CoarseTransformerWrapper(nn.Module):
            with torch.no_grad():
                self.soundstream.eval()
                _, indices, _ = self.soundstream(raw_wave, return_encoded = True)
                coarse_token_ids, _ = indices[..., :self.num_coarse_quantize], indices[..., self.num_coarse_quantize:]
                coarse_token_ids, _ = indices[..., :self.num_coarse_quantizers], indices[..., self.num_coarse_quantizers:]

        semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)')
        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
@@ -924,15 +924,15 @@ class FineTransformerWrapper(nn.Module):
        *,
        transformer: FineTransformer,
        soundstream: Optional[SoundStream] = None,
        num_coarse_quantize = 3,
        pad_id = -1
    ):
        super().__init__()
        self.soundstream = soundstream
        self.transformer = transformer

        assert num_coarse_quantize > 0
        self.num_coarse_quantize = num_coarse_quantize
        self.num_coarse_quantizers = transformer.num_coarse_quantizers
        assert self.num_coarse_quantizers > 0

        self.pad_id = pad_id

    def forward(
@@ -952,7 +952,7 @@ class FineTransformerWrapper(nn.Module):
            with torch.no_grad():
                self.soundstream.eval()
                _, indices, _ = self.soundstream(raw_wave, return_encoded = True)
                coarse_token_ids, fine_token_ids = indices[..., :self.num_coarse_quantize], indices[..., self.num_coarse_quantize:]
                coarse_token_ids, fine_token_ids = indices[..., :self.num_coarse_quantizers], indices[..., self.num_coarse_quantizers:]

        coarse_token_ids = rearrange(coarse_token_ids, 'b ... -> b (...)')
        fine_token_ids = rearrange(fine_token_ids, 'b ... -> b (...)')
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.0.36',
  version = '0.0.38',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',