Unverified Commit c1d008c6 authored by Phil Wang, committed by GitHub
Browse files

Merge pull request #44 from djqualia/patch-6

Add Accelerate init_trackers and log losses
parents 920a8a46 c63b1362
Loading
Loading
Loading
Loading
+20 −0
Original line number Diff line number Diff line
@@ -225,6 +225,9 @@ class SoundStreamTrainer(nn.Module):

        self.results_folder.mkdir(parents = True, exist_ok = True)

        hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
        self.accelerator.init_trackers("soundstream", config=hps)        

    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.soundstream),
@@ -344,6 +347,7 @@ class SoundStreamTrainer(nn.Module):
        # build pretty printed losses

        losses_str = f"{steps}: soundstream total loss: {logs['loss']:.3f}, soundstream recon loss: {logs['recon_loss']:.3f}"
        self.accelerator.log({"total_loss": logs['loss'], "recon_loss": logs['recon_loss']}, step=steps)

        for key, loss in logs.items():
            if not key.startswith('scale:'):
@@ -351,6 +355,7 @@ class SoundStreamTrainer(nn.Module):
            _, scale_factor = key.split(':')

            losses_str += f" | discr (scale {scale_factor}) loss: {loss:.3f}"
            self.accelerator.log({f"discr_loss (scale {scale_factor})": loss}, step=steps)

        # log

@@ -515,6 +520,9 @@ class SemanticTransformerTrainer(nn.Module):

        self.results_folder.mkdir(parents = True, exist_ok = True)
        
        hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
        self.accelerator.init_trackers("semenatic", config=hps)

    def save(self, path):
        pkg = dict(
            model = self.accelerator.get_state_dict(self.transformer),
@@ -591,6 +599,7 @@ class SemanticTransformerTrainer(nn.Module):
        # log

        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # sample results every so often

@@ -602,6 +611,7 @@ class SemanticTransformerTrainer(nn.Module):
                valid_loss = self.train_wrapper(**data_kwargs, return_loss = True)

            self.print(f'{steps}: valid loss {valid_loss}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # save model every so often

@@ -742,6 +752,9 @@ class CoarseTransformerTrainer(nn.Module):

        self.results_folder.mkdir(parents = True, exist_ok = True)

        hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
        self.accelerator.init_trackers("coarse", config=hps)        

        self.train_wrapper.to(self.device)

    def save(self, path):
@@ -817,6 +830,7 @@ class CoarseTransformerTrainer(nn.Module):
        # log

        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # sample results every so often

@@ -832,6 +846,7 @@ class CoarseTransformerTrainer(nn.Module):
                )

            self.print(f'{steps}: valid loss {valid_loss}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # save model every so often

@@ -965,6 +980,9 @@ class FineTransformerTrainer(nn.Module):

        self.results_folder.mkdir(parents = True, exist_ok = True)

        hps = {"num_train_steps": num_train_steps, "data_max_length": data_max_length, "learning_rate": lr}
        self.accelerator.init_trackers("fine", config=hps)        

        self.train_wrapper.to(self.device)

    def save(self, path):
@@ -1043,6 +1061,7 @@ class FineTransformerTrainer(nn.Module):
        # log

        self.print(f"{steps}: loss: {logs['loss']}")
        self.accelerator.log({"train_loss": logs['loss']}, step=steps)

        # sample results every so often

@@ -1054,6 +1073,7 @@ class FineTransformerTrainer(nn.Module):
                valid_loss = self.train_wrapper(**data_kwargs, return_loss = True)

            self.print(f'{steps}: valid loss {valid_loss}')
            self.accelerator.log({"valid_loss": valid_loss}, step=steps)

        # save model every so often