Loading audiolm_pytorch/trainer.py +14 −14 Original line number Diff line number Diff line Loading @@ -22,7 +22,7 @@ from audiolm_pytorch.optimizer import get_optimizer from ema_pytorch import EMA from audiolm_pytorch.soundstream import SoundStream from audiolm_pytorch.soundstream import SoundStream, EncodecWrapper from audiolm_pytorch.audiolm_pytorch import ( SemanticTransformer, Loading Loading @@ -735,14 +735,14 @@ class CoarseTransformerTrainer(nn.Module): def __init__( self, transformer: CoarseTransformer, soundstream: SoundStream, codec: Union[SoundStream, EncodecWrapper], wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]], *, num_train_steps, batch_size, audio_conditioner: Optional[AudioConditionerBase] = None, dataset: Optional[Dataset] = None, ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_soundstream', 'text'), ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_codec', 'text'), data_max_length = None, data_max_length_seconds = None, folder = None, Loading @@ -762,12 +762,12 @@ class CoarseTransformerTrainer(nn.Module): self.accelerator = Accelerator(**accelerate_kwargs) self.transformer = transformer self.soundstream = soundstream self.codec = codec self.wav2vec = wav2vec self.audio_conditioner = audio_conditioner self.train_wrapper = CoarseTransformerWrapper( soundstream = soundstream, codec = codec, wav2vec = wav2vec, transformer = transformer, audio_conditioner = audio_conditioner Loading Loading @@ -797,16 +797,16 @@ class CoarseTransformerTrainer(nn.Module): assert not (exists(data_max_length) and exists(data_max_length_seconds)) if exists(data_max_length_seconds): data_max_length = tuple(data_max_length_seconds * hz for hz in (wav2vec.target_sample_hz, soundstream.target_sample_hz)) data_max_length = tuple(data_max_length_seconds * hz for hz in (wav2vec.target_sample_hz, codec.target_sample_hz)) self.ds = SoundDataset( folder, max_length = data_max_length, target_sample_hz = ( wav2vec.target_sample_hz, soundstream.target_sample_hz codec.target_sample_hz ), # need 2 waves resampled differently here seq_len_multiple_of = soundstream.seq_len_multiple_of seq_len_multiple_of = codec.seq_len_multiple_of ) self.ds_fields = ds_fields Loading Loading @@ -985,7 +985,7 @@ class FineTransformerTrainer(nn.Module): def __init__( self, transformer: FineTransformer, soundstream: SoundStream, codec: Union[SoundStream, EncodecWrapper], *, num_train_steps, batch_size, Loading @@ -1011,11 +1011,11 @@ class FineTransformerTrainer(nn.Module): self.accelerator = Accelerator(**accelerate_kwargs) self.transformer = transformer self.soundstream = soundstream self.codec = codec self.audio_conditioner = audio_conditioner self.train_wrapper = FineTransformerWrapper( soundstream = soundstream, codec = codec, transformer = transformer, audio_conditioner = audio_conditioner ) Loading Loading @@ -1044,13 +1044,13 @@ class FineTransformerTrainer(nn.Module): assert not (exists(data_max_length) and exists(data_max_length_seconds)) if exists(data_max_length_seconds): data_max_length = data_max_length_seconds * soundstream.target_sample_hz data_max_length = data_max_length_seconds * codec.target_sample_hz self.ds = SoundDataset( folder, max_length = data_max_length, target_sample_hz = soundstream.target_sample_hz, seq_len_multiple_of = soundstream.seq_len_multiple_of target_sample_hz = codec.target_sample_hz, seq_len_multiple_of = codec.seq_len_multiple_of ) self.ds_fields = None Loading Loading
audiolm_pytorch/trainer.py +14 −14 Original line number Diff line number Diff line Loading @@ -22,7 +22,7 @@ from audiolm_pytorch.optimizer import get_optimizer from ema_pytorch import EMA from audiolm_pytorch.soundstream import SoundStream from audiolm_pytorch.soundstream import SoundStream, EncodecWrapper from audiolm_pytorch.audiolm_pytorch import ( SemanticTransformer, Loading Loading @@ -735,14 +735,14 @@ class CoarseTransformerTrainer(nn.Module): def __init__( self, transformer: CoarseTransformer, soundstream: SoundStream, codec: Union[SoundStream, EncodecWrapper], wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]], *, num_train_steps, batch_size, audio_conditioner: Optional[AudioConditionerBase] = None, dataset: Optional[Dataset] = None, ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_soundstream', 'text'), ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_codec', 'text'), data_max_length = None, data_max_length_seconds = None, folder = None, Loading @@ -762,12 +762,12 @@ class CoarseTransformerTrainer(nn.Module): self.accelerator = Accelerator(**accelerate_kwargs) self.transformer = transformer self.soundstream = soundstream self.codec = codec self.wav2vec = wav2vec self.audio_conditioner = audio_conditioner self.train_wrapper = CoarseTransformerWrapper( soundstream = soundstream, codec = codec, wav2vec = wav2vec, transformer = transformer, audio_conditioner = audio_conditioner Loading Loading @@ -797,16 +797,16 @@ class CoarseTransformerTrainer(nn.Module): assert not (exists(data_max_length) and exists(data_max_length_seconds)) if exists(data_max_length_seconds): data_max_length = tuple(data_max_length_seconds * hz for hz in (wav2vec.target_sample_hz, soundstream.target_sample_hz)) data_max_length = tuple(data_max_length_seconds * hz for hz in (wav2vec.target_sample_hz, codec.target_sample_hz)) self.ds = SoundDataset( folder, max_length = data_max_length, target_sample_hz = ( wav2vec.target_sample_hz, soundstream.target_sample_hz codec.target_sample_hz ), # need 2 waves resampled differently here seq_len_multiple_of = soundstream.seq_len_multiple_of seq_len_multiple_of = codec.seq_len_multiple_of ) self.ds_fields = ds_fields Loading Loading @@ -985,7 +985,7 @@ class FineTransformerTrainer(nn.Module): def __init__( self, transformer: FineTransformer, soundstream: SoundStream, codec: Union[SoundStream, EncodecWrapper], *, num_train_steps, batch_size, Loading @@ -1011,11 +1011,11 @@ class FineTransformerTrainer(nn.Module): self.accelerator = Accelerator(**accelerate_kwargs) self.transformer = transformer self.soundstream = soundstream self.codec = codec self.audio_conditioner = audio_conditioner self.train_wrapper = FineTransformerWrapper( soundstream = soundstream, codec = codec, transformer = transformer, audio_conditioner = audio_conditioner ) Loading Loading @@ -1044,13 +1044,13 @@ class FineTransformerTrainer(nn.Module): assert not (exists(data_max_length) and exists(data_max_length_seconds)) if exists(data_max_length_seconds): data_max_length = data_max_length_seconds * soundstream.target_sample_hz data_max_length = data_max_length_seconds * codec.target_sample_hz self.ds = SoundDataset( folder, max_length = data_max_length, target_sample_hz = soundstream.target_sample_hz, seq_len_multiple_of = soundstream.seq_len_multiple_of target_sample_hz = codec.target_sample_hz, seq_len_multiple_of = codec.seq_len_multiple_of ) self.ds_fields = None Loading