Unverified Commit c63b1362 authored by djqualia's avatar djqualia Committed by GitHub
Browse files

Add Accelerate init_trackers and log losses

This supports visualization tools, e.g. TensorBoard. It requires initializing the accelerator trackers. If that is done before the "do you want to clear" step, it can lead to weird edge cases.
parent 920a8a46
Loading
Loading
Loading
Loading
+20 −0
Original line number Diff line number Diff line
@@ -225,6 +225,9 @@ class SoundStreamTrainer(nn.Module):

        self.results_folder.mkdir(parents = True, exist_ok = True)

        hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
        self.accelerator.init_trackers("soundstream", config=hps)        

    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.soundstream),
@@ -344,6 +347,7 @@ class SoundStreamTrainer(nn.Module):
        # build pretty printed losses

        losses_str = f"{steps}: soundstream total loss: {logs['loss']:.3f}, soundstream recon loss: {logs['recon_loss']:.3f}"
        self.accelerator.log({"total_loss": logs['loss'], "recon_loss": logs['recon_loss']}, step=steps)

        for key, loss in logs.items():
            if not key.startswith('scale:'):
@@ -351,6 +355,7 @@ class SoundStreamTrainer(nn.Module):
            _, scale_factor = key.split(':')

            losses_str += f" | discr (scale {scale_factor}) loss: {loss:.3f}"
            self.accelerator.log({f"discr_loss (scale {scale_factor})": loss}, step=steps)

        # log

@@ -515,6 +520,9 @@ class SemanticTransformerTrainer(nn.Module):

        self.results_folder.mkdir(parents = True, exist_ok = True)
        
        hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
        self.accelerator.init_trackers("semenatic", config=hps)

    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.transformer),
@@ -591,6 +599,7 @@ class SemanticTransformerTrainer(nn.Module):
        # log

        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # sample results every so often

@@ -602,6 +611,7 @@ class SemanticTransformerTrainer(nn.Module):
                valid_loss = self.train_wrapper(**data_kwargs, return_loss = True)

            self.print(f'{steps}: valid loss {valid_loss}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # save model every so often

@@ -742,6 +752,9 @@ class CoarseTransformerTrainer(nn.Module):

        self.results_folder.mkdir(parents = True, exist_ok = True)

        hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
        self.accelerator.init_trackers("coarse", config=hps)        

        self.train_wrapper.to(self.device)

    def save(self, path):
@@ -817,6 +830,7 @@ class CoarseTransformerTrainer(nn.Module):
        # log

        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # sample results every so often

@@ -832,6 +846,7 @@ class CoarseTransformerTrainer(nn.Module):
                )

            self.print(f'{steps}: valid loss {valid_loss}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # save model every so often

@@ -965,6 +980,9 @@ class FineTransformerTrainer(nn.Module):

        self.results_folder.mkdir(parents = True, exist_ok = True)

        hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
        self.accelerator.init_trackers("fine", config=hps)        

        self.train_wrapper.to(self.device)

    def save(self, path):
@@ -1043,6 +1061,7 @@ class FineTransformerTrainer(nn.Module):
        # log

        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # sample results every so often

@@ -1054,6 +1073,7 @@ class FineTransformerTrainer(nn.Module):
                valid_loss = self.train_wrapper(**data_kwargs, return_loss = True)

            self.print(f'{steps}: valid loss {valid_loss}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # save model every so often