has_condition=True# this will have to be set to True
).cuda()
# mock text video dataset (as an example)
# you will have to extend your own from `Dataset`, and return an audio tensor as well as a string (the audio description) in any order (the framework will autodetect and route it into the transformer)
fromtorch.utils.dataimportDataset
classMockTextAudioDataset(Dataset):
def__init__(self,length=100,audio_length=320*32):
super().__init__()
self.audio_length=audio_length
self.len=length
def__len__(self):
returnself.len
def__getitem__(self,idx):
mock_audio=torch.randn(self.audio_length)
mock_caption='audio caption'
returnmock_caption,mock_audio
dataset=MockTextAudioDataset()
# instantiate semantic transformer trainer and train
trainer=SemanticTransformerTrainer(
transformer=semantic_transformer,
wav2vec=wav2vec,
dataset=dataset,
batch_size=4,
grad_accum_every=8,
data_max_length=320*32,
num_train_steps=100000
)
trainer.train()
# after much training above
sample=trainer.generate(text=['sound of rain drops on the rooftops'],batch_size=1,max_length=2)# (1, < 128) - may terminate early if it detects [eos]
```
## Appreciation
-<ahref="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work and open source cutting edge artificial intelligence research