Commit c3fabc36 authored by Phil Wang
Browse files
parent 7cfe6095
Loading
Loading
Loading
Loading
+10 −5
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ from pathlib import Path

from functools import partial, wraps
from itertools import zip_longest
from typing import Optional

import torch
from torch import nn, einsum
@@ -438,7 +439,8 @@ class SoundStream(nn.Module):
        attn_xpos_scale_base = None,
        attn_dynamic_pos_bias = False,
        squeeze_excite = False,
        complex_stft_discr_logits_abs = True
        complex_stft_discr_logits_abs = True,
        stft_discriminator: Optional[nn.Module] = None  # can pass in own stft discriminator
    ):
        super().__init__()

@@ -522,6 +524,9 @@ class SoundStream(nn.Module):
        discr_rel_factors = [int(s1 / s2) for s1, s2 in zip(discr_multi_scales[:-1], discr_multi_scales[1:])]
        self.downsamples = nn.ModuleList([nn.Identity()] + [nn.AvgPool1d(2 * factor, stride = factor, padding = factor) for factor in discr_rel_factors])

        self.stft_discriminator = stft_discriminator

        if not exists(self.stft_discriminator):
            self.stft_discriminator = ComplexSTFTDiscriminator(
                stft_normalized = stft_normalized,
                logits_abs = complex_stft_discr_logits_abs  # whether to output as abs() or use view_as_real
+1 −1
Original line number Diff line number Diff line
__version__ = '0.28.0'
__version__ = '0.28.1'