Commit df05e1e9 authored by Leon Wu
Browse files

update trainer

parent 2e64311c
Loading
Loading
Loading
Loading
+14 −14
Original line number | Diff line number | Diff line
@@ -22,7 +22,7 @@ from audiolm_pytorch.optimizer import get_optimizer

from ema_pytorch import EMA

from audiolm_pytorch.soundstream import SoundStream
from audiolm_pytorch.soundstream import SoundStream, EncodecWrapper

from audiolm_pytorch.audiolm_pytorch import (
    SemanticTransformer,
@@ -735,14 +735,14 @@ class CoarseTransformerTrainer(nn.Module):
    def __init__(
        self,
        transformer: CoarseTransformer,
        soundstream: SoundStream,
        codec: Union[SoundStream, EncodecWrapper],
        wav2vec: Optional[Union[FairseqVQWav2Vec, HubertWithKmeans]],
        *,
        num_train_steps,
        batch_size,
        audio_conditioner: Optional[AudioConditionerBase] = None,
        dataset: Optional[Dataset] = None,
        ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_soundstream', 'text'),
        ds_fields: Tuple[str, ...] = ('raw_wave', 'raw_wave_for_codec', 'text'),
        data_max_length = None,
        data_max_length_seconds = None,
        folder = None,
@@ -762,12 +762,12 @@ class CoarseTransformerTrainer(nn.Module):
        self.accelerator = Accelerator(**accelerate_kwargs)

        self.transformer = transformer
        self.soundstream = soundstream
        self.codec = codec
        self.wav2vec = wav2vec
        self.audio_conditioner = audio_conditioner

        self.train_wrapper = CoarseTransformerWrapper(
            soundstream = soundstream,
            codec = codec,
            wav2vec = wav2vec,
            transformer = transformer,
            audio_conditioner = audio_conditioner
@@ -797,16 +797,16 @@ class CoarseTransformerTrainer(nn.Module):
            assert not (exists(data_max_length) and exists(data_max_length_seconds))

            if exists(data_max_length_seconds):
                data_max_length = tuple(data_max_length_seconds * hz for hz in (wav2vec.target_sample_hz, soundstream.target_sample_hz))
                data_max_length = tuple(data_max_length_seconds * hz for hz in (wav2vec.target_sample_hz, codec.target_sample_hz))

            self.ds = SoundDataset(
                folder,
                max_length = data_max_length,
                target_sample_hz = (
                    wav2vec.target_sample_hz,
                    soundstream.target_sample_hz
                    codec.target_sample_hz
                ), # need 2 waves resampled differently here
                seq_len_multiple_of = soundstream.seq_len_multiple_of
                seq_len_multiple_of = codec.seq_len_multiple_of
            )

        self.ds_fields = ds_fields
@@ -985,7 +985,7 @@ class FineTransformerTrainer(nn.Module):
    def __init__(
        self,
        transformer: FineTransformer,
        soundstream: SoundStream,
        codec: Union[SoundStream, EncodecWrapper],
        *,
        num_train_steps,
        batch_size,
@@ -1011,11 +1011,11 @@ class FineTransformerTrainer(nn.Module):
        self.accelerator = Accelerator(**accelerate_kwargs)

        self.transformer = transformer
        self.soundstream = soundstream
        self.codec = codec
        self.audio_conditioner = audio_conditioner

        self.train_wrapper = FineTransformerWrapper(
            soundstream = soundstream,
            codec = codec,
            transformer = transformer,
            audio_conditioner = audio_conditioner
        )
@@ -1044,13 +1044,13 @@ class FineTransformerTrainer(nn.Module):
            assert not (exists(data_max_length) and exists(data_max_length_seconds))

            if exists(data_max_length_seconds):
                data_max_length = data_max_length_seconds * soundstream.target_sample_hz
                data_max_length = data_max_length_seconds * codec.target_sample_hz

            self.ds = SoundDataset(
                folder,
                max_length = data_max_length,
                target_sample_hz = soundstream.target_sample_hz,
                seq_len_multiple_of = soundstream.seq_len_multiple_of
                target_sample_hz = codec.target_sample_hz,
                seq_len_multiple_of = codec.seq_len_multiple_of
            )

        self.ds_fields = None