Loading audiolm_pytorch/audiolm_pytorch.py +8 −6 Original line number Diff line number Diff line Loading @@ -1410,9 +1410,10 @@ class CoarseTransformerWrapper(nn.Module): self.codec.eval() _, indices, _ = self.codec(raw_wave_for_codec, return_encoded = True) batch = raw_wave.shape[0] timesteps = raw_wave.shape[1] assert indices.shape == torch.Size((batch, timesteps, self.num_coarse_quantizers + self.num_fine_quantizers)), \ f'Expected codec to have shape (batch, timesteps, num_coarse_quantizers + num_fine_quantizers), but got {indices.shape}' num_timesteps = raw_wave.shape[1] num_frames = int(num_timesteps / self.codec.seq_len_multiple_of) assert indices.shape[0] == batch and indices.shape[1] == num_frames, \ f'Expected indices to have shape (batch, num_frames, num_coarse_quantizers + num_fine_quantizers), but got {indices.shape}' coarse_token_ids, _ = indices[..., :self.num_coarse_quantizers], indices[..., self.num_coarse_quantizers:] semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)') Loading Loading @@ -1636,9 +1637,10 @@ class FineTransformerWrapper(nn.Module): self.codec.eval() _, token_ids, _ = self.codec(raw_wave, return_encoded = True) batch = raw_wave.shape[0] timesteps = raw_wave.shape[1] assert token_ids.shape == torch.Size((batch, timesteps, self.num_coarse_quantizers + self.num_fine_quantizers)), \ f'Expected token ids to have shape (batch, timesteps, num_coarse_quantizers + num_fine_quantizers), but got {token_ids.shape}' num_timesteps = raw_wave.shape[1] num_frames = int(num_timesteps / self.codec.seq_len_multiple_of) assert token_ids.shape == torch.Size((batch, num_frames, self.num_coarse_quantizers + self.num_fine_quantizers)), \ f'Expected token ids to have shape (batch, num_frames, num_coarse_quantizers + num_fine_quantizers), but got {token_ids.shape}' if exists(token_ids): coarse_token_ids, fine_token_ids = token_ids[..., :self.num_coarse_quantizers], token_ids[..., self.num_coarse_quantizers:] Loading Loading
audiolm_pytorch/audiolm_pytorch.py +8 −6 Original line number Diff line number Diff line Loading @@ -1410,9 +1410,10 @@ class CoarseTransformerWrapper(nn.Module): self.codec.eval() _, indices, _ = self.codec(raw_wave_for_codec, return_encoded = True) batch = raw_wave.shape[0] timesteps = raw_wave.shape[1] assert indices.shape == torch.Size((batch, timesteps, self.num_coarse_quantizers + self.num_fine_quantizers)), \ f'Expected codec to have shape (batch, timesteps, num_coarse_quantizers + num_fine_quantizers), but got {indices.shape}' num_timesteps = raw_wave.shape[1] num_frames = int(num_timesteps / self.codec.seq_len_multiple_of) assert indices.shape[0] == batch and indices.shape[1] == num_frames, \ f'Expected indices to have shape (batch, num_frames, num_coarse_quantizers + num_fine_quantizers), but got {indices.shape}' coarse_token_ids, _ = indices[..., :self.num_coarse_quantizers], indices[..., self.num_coarse_quantizers:] semantic_token_ids = rearrange(semantic_token_ids, 'b ... -> b (...)') Loading Loading @@ -1636,9 +1637,10 @@ class FineTransformerWrapper(nn.Module): self.codec.eval() _, token_ids, _ = self.codec(raw_wave, return_encoded = True) batch = raw_wave.shape[0] timesteps = raw_wave.shape[1] assert token_ids.shape == torch.Size((batch, timesteps, self.num_coarse_quantizers + self.num_fine_quantizers)), \ f'Expected token ids to have shape (batch, timesteps, num_coarse_quantizers + num_fine_quantizers), but got {token_ids.shape}' num_timesteps = raw_wave.shape[1] num_frames = int(num_timesteps / self.codec.seq_len_multiple_of) assert token_ids.shape == torch.Size((batch, num_frames, self.num_coarse_quantizers + self.num_fine_quantizers)), \ f'Expected token ids to have shape (batch, num_frames, num_coarse_quantizers + num_fine_quantizers), but got {token_ids.shape}' if exists(token_ids): coarse_token_ids, fine_token_ids = token_ids[..., :self.num_coarse_quantizers], token_ids[..., self.num_coarse_quantizers:] Loading