Commit c1de7f22 authored by zhvng's avatar zhvng
Browse files

remove unused MultiHeadedEMABlock

parent 7e957b90
Loading
Loading
Loading
Loading
+0 −20
Original line number Diff line number Diff line
@@ -246,26 +246,6 @@ class ComplexSTFTDiscriminator(nn.Module):

        return complex_logits_abs, intermediates

# learned EMA blocks

class MultiHeadEMABlock(nn.Module):
    def __init__(
        self,
        dim,
        **kwargs
    ):
        super().__init__()
        self.prenorm = nn.LayerNorm(dim)
        self.mhema = MultiHeadedEMA(dim = dim, **kwargs)

    def forward(self, x):
        residual = x.clone()
        x = rearrange(x, 'b c n -> b n c')
        x = self.prenorm(x)
        x = self.mhema(x)
        x = rearrange(x, 'b n c -> b c n')
        return x + residual

# sound stream

class Residual(nn.Module):