Commit 504a7e17 authored by Phil Wang
Browse files

allow for curtailing a left part of the raw audio being passed into...

Allow the raw audio passed into SoundStream to be curtailed from the left instead of the right, to make prompt training in NaturalSpeech a bit less error-prone.
parent 657790d6
Loading
Loading
Loading
Loading
+13 −5
Original line number Diff line number Diff line
@@ -675,13 +675,18 @@ class SoundStream(nn.Module):
    def seq_len_multiple_of(self):
        """Return the product of every entry in ``self.strides``.

        ``process_input`` hands this value to ``curtail_to_multiple`` as the
        multiple that the trailing input length is trimmed down to.
        """
        def multiply(running_product, stride):
            return running_product * stride

        # no initial value is supplied, so an empty ``strides`` raises TypeError
        return functools.reduce(multiply, self.strides)

    def process_input(self, x, input_sample_hz = None):
    def process_input(
        self,
        x,
        input_sample_hz = None,
        curtail_from_left = False
    ):
        x, ps = pack([x], '* n')

        if exists(input_sample_hz):
            x = resample(x, input_sample_hz, self.target_sample_hz)

        x = curtail_to_multiple(x, self.seq_len_multiple_of)
        x = curtail_to_multiple(x, self.seq_len_multiple_of, from_left = curtail_from_left)

        if x.ndim == 2:
            x = rearrange(x, 'b n -> b 1 n')
@@ -699,14 +704,17 @@ class SoundStream(nn.Module):
        return_loss_breakdown = False,
        return_recons_only = False,
        input_sample_hz = None,
        apply_grad_penalty = False
        apply_grad_penalty = False,
        curtail_from_left = False
    ):
        assert not (exists(is_denoising) and not exists(target))

        x, ps = self.process_input(x, input_sample_hz = input_sample_hz)
        process_input = partial(self.process_input, input_sample_hz = input_sample_hz, curtail_from_left = curtail_from_left)

        x, ps = process_input(x)

        if exists(target):
            target, _ = self.process_input(target, input_sample_hz = input_sample_hz)
            target, _ = process_input(target)

        orig_x = x.clone()

+4 −2
Original line number Diff line number Diff line
@@ -5,9 +5,11 @@ from torch import nn
def round_down_nearest_multiple(num, divisor):
    """Round ``num`` down to a multiple of ``divisor``.

    Uses floor division, so for a positive ``divisor`` the result is the
    largest multiple of ``divisor`` that does not exceed ``num``.
    """
    whole_multiples = num // divisor
    return whole_multiples * divisor

def curtail_to_multiple(t, mult, from_left = False):
    """Trim the last dimension of ``t`` to the largest length divisible by ``mult``.

    Args:
        t: Object exposing ``.shape`` and supporting ``t[..., slice]``
            indexing (e.g. a torch tensor).
        mult: Positive integer that the trailing length must be a multiple of.
        from_left: When False (default) the surplus elements are dropped from
            the end. When True they are dropped from the start, so the
            rightmost elements are kept.

    Returns:
        ``t`` sliced along its last dimension to ``data_len // mult * mult``
        elements. This is empty along that dimension when the input is
        shorter than ``mult``, regardless of ``from_left``.
    """
    data_len = t.shape[-1]
    excess = data_len % mult

    if from_left:
        # Use an explicit non-negative start: a negative-offset slice would
        # become ``-0:`` (i.e. everything) when data_len < mult.
        return t[..., excess:]

    return t[..., :data_len - excess]

# base class

+1 −1
Original line number Diff line number Diff line
__version__ = '0.30.1'
__version__ = '0.30.2'