Unverified Commit 1dc5ab1c authored by Erlend Aune's avatar Erlend Aune Committed by GitHub
Browse files

Update trainer.py

Missing argument in EMA.
parent 039ec88b
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -96,7 +96,7 @@ class SoundStreamTrainer(nn.Module):
        self.accelerator = Accelerator(**accelerate_kwargs)

        self.soundstream = soundstream
        self.ema_soundstream = EMA(soundstream, update_after_step = ema_update_after_step, update_every = ema_update_every)
        self.ema_soundstream = EMA(soundstream, beta=ema_beta, update_after_step = ema_update_after_step, update_every = ema_update_every)

        self.register_buffer('steps', torch.Tensor([0]))