Commit 6b779a39 authored by Phil Wang's avatar Phil Wang
Browse files

allow for disabling the prompt asking whether to clear previous results

parent c12d7a82
Loading
Loading
Loading
Loading
+12 −8
Original line number Diff line number Diff line
@@ -132,7 +132,8 @@ class SoundStreamTrainer(nn.Module):
        ema_update_every = 10,
        apply_grad_penalty_every = 4,
        dl_num_workers = 0,
        accelerate_kwargs: dict = dict()
        accelerate_kwargs: dict = dict(),
        force_clear_prev_results = None  # set to True | False to skip the prompt
    ):
        super().__init__()
        self.accelerator = Accelerator(**accelerate_kwargs)
@@ -222,7 +223,7 @@ class SoundStreamTrainer(nn.Module):

        self.results_folder = Path(results_folder)

        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
        if force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)
@@ -445,7 +446,8 @@ class SemanticTransformerTrainer(nn.Module):
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict()
        accelerate_kwargs: dict = dict(),
        force_clear_prev_results = None
    ):
        super().__init__()
        self.accelerator = Accelerator(**accelerate_kwargs)
@@ -530,7 +532,7 @@ class SemanticTransformerTrainer(nn.Module):

        self.results_folder = Path(results_folder)

        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
        if force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)
@@ -674,7 +676,8 @@ class CoarseTransformerTrainer(nn.Module):
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict()
        accelerate_kwargs: dict = dict(),
        force_clear_prev_results = None
    ):
        super().__init__()
        self.accelerator = Accelerator(**accelerate_kwargs)
@@ -765,7 +768,7 @@ class CoarseTransformerTrainer(nn.Module):

        self.results_folder = Path(results_folder)

        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
        if force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)
@@ -911,7 +914,8 @@ class FineTransformerTrainer(nn.Module):
        save_results_every = 100,
        save_model_every = 1000,
        results_folder = './results',
        accelerate_kwargs: dict = dict()
        accelerate_kwargs: dict = dict(),
        force_clear_prev_results = None
    ):
        super().__init__()
        self.accelerator = Accelerator(**accelerate_kwargs)
@@ -997,7 +1001,7 @@ class FineTransformerTrainer(nn.Module):

        self.results_folder = Path(results_folder)

        if len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?'):
        if force_clear_prev_results is True or (not exists(force_clear_prev_results) and len([*self.results_folder.glob('**/*')]) > 0 and yes_or_no('do you want to clear previous experiment checkpoints and results?')):
            rmtree(str(self.results_folder))

        self.results_folder.mkdir(parents = True, exist_ok = True)