Commit 7f1f3b25 authored by Leon Wu's avatar Leon Wu
Browse files

Fix token/indices shape code

parent 9f2d1532
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -127,3 +127,6 @@ dmypy.json

# Pyre type checker
.pyre/

# Pycharm
.idea/
+8 −0
Original line number Diff line number Diff line
@@ -1409,6 +1409,10 @@ class CoarseTransformerWrapper(nn.Module):
            with torch.no_grad():
                self.codec.eval()
                _, indices, _ = self.codec(raw_wave_for_codec, return_encoded = True)
                batch = raw_wave.shape[0]
                timesteps = raw_wave.shape[1]
                assert indices.shape == torch.Size((batch, timesteps, self.num_coarse_quantizers + self.num_fine_quantizers)), \
                    f'Expected codec to have shape (batch, timesteps, num_coarse_quantizers + num_fine_quantizers), but got {indices.shape}'
                coarse_token_ids, _ = indices[..., :self.num_coarse_quantizers], indices[..., self.num_coarse_quantizers:]

        semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)')
@@ -1631,6 +1635,10 @@ class FineTransformerWrapper(nn.Module):
            with torch.no_grad():
                self.codec.eval()
                _, token_ids, _ = self.codec(raw_wave, return_encoded = True)
                batch = raw_wave.shape[0]
                timesteps = raw_wave.shape[1]
                assert token_ids.shape == torch.Size((batch, timesteps, self.num_coarse_quantizers + self.num_fine_quantizers)), \
                    f'Expected token ids to have shape (batch, timesteps, num_coarse_quantizers + num_fine_quantizers), but got {token_ids.shape}'

        if exists(token_ids):
            coarse_token_ids, fine_token_ids = token_ids[..., :self.num_coarse_quantizers], token_ids[..., self.num_coarse_quantizers:]
+1 −1
Original line number Diff line number Diff line
@@ -53,7 +53,7 @@ class EncodecWrapper(nn.Module):
        # Extract discrete codes from EnCodec
        with torch.no_grad():
            encoded_frames = self.model.encode(wav)
        codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1)  # [batch, num_quantizers, timesteps]
        codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=1)  # [batch, timesteps, num_quantizers]
        # in original soundstream, is x, indices, commit_loss. But we only use indices in eval mode, so just keep that.
        return None, codes, None

+1 −1
Original line number Diff line number Diff line
__version__ = '0.25.6'
__version__ = '0.26.0'