Unverified Commit 6ea8cf01 authored by Phil Wang's avatar Phil Wang Committed by GitHub
Browse files

Merge pull request #46 from djqualia/patch-9

Fix logging of recon_loss and small typo
parents 8b70431f b181beac
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -305,7 +305,7 @@ class SoundStreamTrainer(nn.Module):

            accum_log(logs, dict(
                loss = loss.item() / self.grad_accum_every,
                recon_loss = recon_loss / self.grad_accum_every
                recon_loss = recon_loss.item() / self.grad_accum_every
            ))

        if exists(self.max_grad_norm):
@@ -521,7 +521,7 @@ class SemanticTransformerTrainer(nn.Module):
        self.results_folder.mkdir(parents = True, exist_ok = True)
        
        hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
        self.accelerator.init_trackers("semenatic", config=hps)
        self.accelerator.init_trackers("semantic", config=hps)

    def save(self, path):
        pkg = dict(