Loading audiolm_pytorch/encodec.py +6 −1 Original line number Diff line number Diff line Loading @@ -53,7 +53,12 @@ class EncodecWrapper(nn.Module): # Extract discrete codes from EnCodec with torch.no_grad(): encoded_frames = self.model.encode(wav) codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=1) # [batch, timesteps, num_quantizers] # encoded_frames is a list of (frame, scale) tuples. Scale is a scalar but we don't use it. Frame is a tensor # of shape [batch, num_quantizers, num_samples_per_frame]. We want to concatenate the frames to get all the # timesteps concatenated. codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=1) # [batch, num_quantizers, timesteps] # transformer code that uses codec expects codes to be [batch, timesteps, num_quantizers] codes = rearrange(codes, 'b q n -> b n q') # result: [batch, timesteps, num_quantizers] # in original soundstream, is x, indices, commit_loss. But we only use indices in eval mode, so just keep that. return None, codes, None Loading Loading
audiolm_pytorch/encodec.py +6 −1 Original line number Diff line number Diff line Loading @@ -53,7 +53,12 @@ class EncodecWrapper(nn.Module): # Extract discrete codes from EnCodec with torch.no_grad(): encoded_frames = self.model.encode(wav) codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=1) # [batch, timesteps, num_quantizers] # encoded_frames is a list of (frame, scale) tuples. Scale is a scalar but we don't use it. Frame is a tensor # of shape [batch, num_quantizers, num_samples_per_frame]. We want to concatenate the frames to get all the # timesteps concatenated. codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=1) # [batch, num_quantizers, timesteps] # transformer code that uses codec expects codes to be [batch, timesteps, num_quantizers] codes = rearrange(codes, 'b q n -> b n q') # result: [batch, timesteps, num_quantizers] # in original soundstream, is x, indices, commit_loss. But we only use indices in eval mode, so just keep that. return None, codes, None Loading