Loading audiolm_pytorch/soundstream.py +0 −20 Original line number Diff line number Diff line Loading @@ -246,26 +246,6 @@ class ComplexSTFTDiscriminator(nn.Module): return complex_logits_abs, intermediates # learned EMA blocks class MultiHeadEMABlock(nn.Module): def __init__( self, dim, **kwargs ): super().__init__() self.prenorm = nn.LayerNorm(dim) self.mhema = MultiHeadedEMA(dim = dim, **kwargs) def forward(self, x): residual = x.clone() x = rearrange(x, 'b c n -> b n c') x = self.prenorm(x) x = self.mhema(x) x = rearrange(x, 'b n c -> b c n') return x + residual # sound stream class Residual(nn.Module): Loading Loading
audiolm_pytorch/soundstream.py +0 −20 Original line number Diff line number Diff line Loading @@ -246,26 +246,6 @@ class ComplexSTFTDiscriminator(nn.Module): return complex_logits_abs, intermediates # learned EMA blocks class MultiHeadEMABlock(nn.Module): def __init__( self, dim, **kwargs ): super().__init__() self.prenorm = nn.LayerNorm(dim) self.mhema = MultiHeadedEMA(dim = dim, **kwargs) def forward(self, x): residual = x.clone() x = rearrange(x, 'b c n -> b n c') x = self.prenorm(x) x = self.mhema(x) x = rearrange(x, 'b n c -> b c n') return x + residual # sound stream class Residual(nn.Module): Loading