Loading audiolm_pytorch/soundstream.py +13 −5 Original line number Diff line number Diff line Loading @@ -675,13 +675,18 @@ class SoundStream(nn.Module): def seq_len_multiple_of(self): return functools.reduce(lambda x, y: x * y, self.strides) def process_input(self, x, input_sample_hz = None): def process_input( self, x, input_sample_hz = None, curtail_from_left = False ): x, ps = pack([x], '* n') if exists(input_sample_hz): x = resample(x, input_sample_hz, self.target_sample_hz) x = curtail_to_multiple(x, self.seq_len_multiple_of) x = curtail_to_multiple(x, self.seq_len_multiple_of, from_left = curtail_from_left) if x.ndim == 2: x = rearrange(x, 'b n -> b 1 n') Loading @@ -699,14 +704,17 @@ class SoundStream(nn.Module): return_loss_breakdown = False, return_recons_only = False, input_sample_hz = None, apply_grad_penalty = False apply_grad_penalty = False, curtail_from_left = False ): assert not (exists(is_denoising) and not exists(target)) x, ps = self.process_input(x, input_sample_hz = input_sample_hz) process_input = partial(self.process_input, input_sample_hz = input_sample_hz, curtail_from_left = curtail_from_left) x, ps = process_input(x) if exists(target): target, _ = self.process_input(target, input_sample_hz = input_sample_hz) target, _ = process_input(target) orig_x = x.clone() Loading audiolm_pytorch/utils.py +4 −2 Original line number Diff line number Diff line Loading @@ -5,9 +5,11 @@ from torch import nn def round_down_nearest_multiple(num, divisor): return num // divisor * divisor def curtail_to_multiple(t, mult): def curtail_to_multiple(t, mult, from_left = False): data_len = t.shape[-1] return t[..., :round_down_nearest_multiple(data_len, mult)] rounded_seq_len = round_down_nearest_multiple(data_len, mult) seq_slice = slice(None, rounded_seq_len) if not from_left else slice(-rounded_seq_len, None) return t[..., seq_slice] # base class Loading audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.30.1' __version__ = '0.30.2' Loading
audiolm_pytorch/soundstream.py +13 −5 Original line number Diff line number Diff line Loading @@ -675,13 +675,18 @@ class SoundStream(nn.Module): def seq_len_multiple_of(self): return functools.reduce(lambda x, y: x * y, self.strides) def process_input(self, x, input_sample_hz = None): def process_input( self, x, input_sample_hz = None, curtail_from_left = False ): x, ps = pack([x], '* n') if exists(input_sample_hz): x = resample(x, input_sample_hz, self.target_sample_hz) x = curtail_to_multiple(x, self.seq_len_multiple_of) x = curtail_to_multiple(x, self.seq_len_multiple_of, from_left = curtail_from_left) if x.ndim == 2: x = rearrange(x, 'b n -> b 1 n') Loading @@ -699,14 +704,17 @@ class SoundStream(nn.Module): return_loss_breakdown = False, return_recons_only = False, input_sample_hz = None, apply_grad_penalty = False apply_grad_penalty = False, curtail_from_left = False ): assert not (exists(is_denoising) and not exists(target)) x, ps = self.process_input(x, input_sample_hz = input_sample_hz) process_input = partial(self.process_input, input_sample_hz = input_sample_hz, curtail_from_left = curtail_from_left) x, ps = process_input(x) if exists(target): target, _ = self.process_input(target, input_sample_hz = input_sample_hz) target, _ = process_input(target) orig_x = x.clone() Loading
audiolm_pytorch/utils.py +4 −2 Original line number Diff line number Diff line Loading @@ -5,9 +5,11 @@ from torch import nn def round_down_nearest_multiple(num, divisor): return num // divisor * divisor def curtail_to_multiple(t, mult): def curtail_to_multiple(t, mult, from_left = False): data_len = t.shape[-1] return t[..., :round_down_nearest_multiple(data_len, mult)] rounded_seq_len = round_down_nearest_multiple(data_len, mult) seq_slice = slice(None, rounded_seq_len) if not from_left else slice(-rounded_seq_len, None) return t[..., seq_slice] # base class Loading
audiolm_pytorch/version.py +1 −1 Original line number Diff line number Diff line __version__ = '0.30.1' __version__ = '0.30.2'