Loading .gitignore +3 −0 Original line number Diff line number Diff line Loading @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ # Pycharm .idea/ audiolm_pytorch/audiolm_pytorch.py +10 −0 Original line number Diff line number Diff line Loading @@ -1409,6 +1409,11 @@ class CoarseTransformerWrapper(nn.Module): with torch.no_grad(): self.codec.eval() _, indices, _ = self.codec(raw_wave_for_codec, return_encoded = True) batch = raw_wave.shape[0] num_timesteps = raw_wave.shape[1] num_frames = int(num_timesteps / self.codec.seq_len_multiple_of) assert indices.shape[0] == batch and indices.shape[1] == num_frames, \ f'Expected indices to have shape (batch, num_frames, num_coarse_quantizers + num_fine_quantizers), but got {indices.shape}' coarse_token_ids, _ = indices[..., :self.num_coarse_quantizers], indices[..., self.num_coarse_quantizers:] semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)') Loading Loading @@ -1631,6 +1636,11 @@ class FineTransformerWrapper(nn.Module): with torch.no_grad(): self.codec.eval() _, token_ids, _ = self.codec(raw_wave, return_encoded = True) batch = raw_wave.shape[0] num_timesteps = raw_wave.shape[1] num_frames = int(num_timesteps / self.codec.seq_len_multiple_of) assert token_ids.shape == torch.Size((batch, num_frames, self.num_coarse_quantizers + self.num_fine_quantizers)), \ f'Expected token ids to have shape (batch, num_frames, num_coarse_quantizers + num_fine_quantizers), but got {token_ids.shape}' if exists(token_ids): coarse_token_ids, fine_token_ids = token_ids[..., :self.num_coarse_quantizers], token_ids[..., self.num_coarse_quantizers:] Loading audiolm_pytorch/encodec.py +6 −1 Original line number Diff line number Diff line Loading @@ -53,7 +53,12 @@ class EncodecWrapper(nn.Module): # Extract discrete codes from EnCodec with torch.no_grad(): encoded_frames = self.model.encode(wav) codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [batch, num_quantizers, timesteps] # encoded_frames is a list of (frame, scale) tuples. Scale is a scalar but we don't use it. Frame is a tensor # of shape [batch, num_quantizers, num_samples_per_frame]. We want to concatenate the frames to get all the # timesteps concatenated. codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=1) # [batch, num_quantizers, timesteps] # transformer code that uses codec expects codes to be [batch, timesteps, num_quantizers] codes = rearrange(codes, 'b q n -> b n q') # result: [batch, timesteps, num_quantizers] # in original soundstream, is x, indices, commit_loss. But we only use indices in eval mode, so just keep that. return None, codes, None Loading audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.25.6' __version__ = '0.26.0' Loading
.gitignore +3 −0 Original line number Diff line number Diff line Loading @@ -127,3 +127,6 @@ dmypy.json # Pyre type checker .pyre/ # Pycharm .idea/
audiolm_pytorch/audiolm_pytorch.py +10 −0 Original line number Diff line number Diff line Loading @@ -1409,6 +1409,11 @@ class CoarseTransformerWrapper(nn.Module): with torch.no_grad(): self.codec.eval() _, indices, _ = self.codec(raw_wave_for_codec, return_encoded = True) batch = raw_wave.shape[0] num_timesteps = raw_wave.shape[1] num_frames = int(num_timesteps / self.codec.seq_len_multiple_of) assert indices.shape[0] == batch and indices.shape[1] == num_frames, \ f'Expected indices to have shape (batch, num_frames, num_coarse_quantizers + num_fine_quantizers), but got {indices.shape}' coarse_token_ids, _ = indices[..., :self.num_coarse_quantizers], indices[..., self.num_coarse_quantizers:] semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)') Loading Loading @@ -1631,6 +1636,11 @@ class FineTransformerWrapper(nn.Module): with torch.no_grad(): self.codec.eval() _, token_ids, _ = self.codec(raw_wave, return_encoded = True) batch = raw_wave.shape[0] num_timesteps = raw_wave.shape[1] num_frames = int(num_timesteps / self.codec.seq_len_multiple_of) assert token_ids.shape == torch.Size((batch, num_frames, self.num_coarse_quantizers + self.num_fine_quantizers)), \ f'Expected token ids to have shape (batch, num_frames, num_coarse_quantizers + num_fine_quantizers), but got {token_ids.shape}' if exists(token_ids): coarse_token_ids, fine_token_ids = token_ids[..., :self.num_coarse_quantizers], token_ids[..., self.num_coarse_quantizers:] Loading
audiolm_pytorch/encodec.py +6 −1 Original line number Diff line number Diff line Loading @@ -53,7 +53,12 @@ class EncodecWrapper(nn.Module): # Extract discrete codes from EnCodec with torch.no_grad(): encoded_frames = self.model.encode(wav) codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1) # [batch, num_quantizers, timesteps] # encoded_frames is a list of (frame, scale) tuples. Scale is a scalar but we don't use it. Frame is a tensor # of shape [batch, num_quantizers, num_samples_per_frame]. We want to concatenate the frames to get all the # timesteps concatenated. codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=1) # [batch, num_quantizers, timesteps] # transformer code that uses codec expects codes to be [batch, timesteps, num_quantizers] codes = rearrange(codes, 'b q n -> b n q') # result: [batch, timesteps, num_quantizers] # in original soundstream, is x, indices, commit_loss. But we only use indices in eval mode, so just keep that. return None, codes, None Loading
audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.25.6' __version__ = '0.26.0'