Loading audiolm_pytorch/trainer.py +44 −4 Original line number Diff line number Diff line Loading @@ -5,7 +5,11 @@ from pathlib import Path from shutil import rmtree from typing import Union, List, Optional from typeguard import typechecked from typing_extensions import Annotated from beartype import beartype from beartype.door import is_bearable from beartype.vale import Is import torch import torchaudio Loading Loading @@ -39,6 +43,20 @@ from accelerate import Accelerator DEFAULT_SAMPLE_RATE = 16000 # for automatically routing data emitted from a dataset to keywords of the transformer wrappers DATASET_FIELD_TYPE_CONFIG = dict( raw_wave = Annotated[ torch.Tensor, Is[lambda t: t.dtype == torch.float and t.ndim in {2, 3}] ], text = List[str], text_embeds = Annotated[ torch.Tensor, Is[lambda t: t.dtype == torch.float and t.ndim == 3] ], ) # helpers def exists(val): Loading @@ -65,6 +83,28 @@ def accum_log(log, new_logs): log[key] = old_value + new_value return log # auto data to module keyword argument routing functions def has_duplicates(tup): counts = dict() for el in tup: if el not in counts: counts[el] = 0 counts[el] += 1 return any(filter(lambda count: count > 1, counts.values())) def determine_types(data, config): output = [] for el in data: for name, data_type in config.items(): if is_bearable(el, data_type): output.append(name) break else: raise TypeError(f'unable to determine type of {data}') return tuple(output) # main trainer class class SoundStreamTrainer(nn.Module): Loading Loading @@ -336,7 +376,7 @@ class SoundStreamTrainer(nn.Module): # semantic transformer trainer @typechecked @beartype class SemanticTransformerTrainer(nn.Module): def __init__( self, Loading Loading @@ -525,7 +565,7 @@ class SemanticTransformerTrainer(nn.Module): # fine transformer trainer @typechecked @beartype class CoarseTransformerTrainer(nn.Module): def __init__( self, Loading Loading @@ -731,7 +771,7 @@ class CoarseTransformerTrainer(nn.Module): # fine transformer trainer @typechecked @beartype class FineTransformerTrainer(nn.Module): def __init__( self, Loading setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.1.12', version = '0.1.14', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading Loading
audiolm_pytorch/trainer.py +44 −4 Original line number Diff line number Diff line Loading @@ -5,7 +5,11 @@ from pathlib import Path from shutil import rmtree from typing import Union, List, Optional from typeguard import typechecked from typing_extensions import Annotated from beartype import beartype from beartype.door import is_bearable from beartype.vale import Is import torch import torchaudio Loading Loading @@ -39,6 +43,20 @@ from accelerate import Accelerator DEFAULT_SAMPLE_RATE = 16000 # for automatically routing data emitted from a dataset to keywords of the transformer wrappers DATASET_FIELD_TYPE_CONFIG = dict( raw_wave = Annotated[ torch.Tensor, Is[lambda t: t.dtype == torch.float and t.ndim in {2, 3}] ], text = List[str], text_embeds = Annotated[ torch.Tensor, Is[lambda t: t.dtype == torch.float and t.ndim == 3] ], ) # helpers def exists(val): Loading @@ -65,6 +83,28 @@ def accum_log(log, new_logs): log[key] = old_value + new_value return log # auto data to module keyword argument routing functions def has_duplicates(tup): counts = dict() for el in tup: if el not in counts: counts[el] = 0 counts[el] += 1 return any(filter(lambda count: count > 1, counts.values())) def determine_types(data, config): output = [] for el in data: for name, data_type in config.items(): if is_bearable(el, data_type): output.append(name) break else: raise TypeError(f'unable to determine type of {data}') return tuple(output) # main trainer class class SoundStreamTrainer(nn.Module): Loading Loading @@ -336,7 +376,7 @@ class SoundStreamTrainer(nn.Module): # semantic transformer trainer @typechecked @beartype class SemanticTransformerTrainer(nn.Module): def __init__( self, Loading Loading @@ -525,7 +565,7 @@ class SemanticTransformerTrainer(nn.Module): # fine transformer trainer @typechecked @beartype class CoarseTransformerTrainer(nn.Module): def __init__( self, Loading Loading @@ -731,7 +771,7 @@ class CoarseTransformerTrainer(nn.Module): # fine transformer trainer @typechecked @beartype class FineTransformerTrainer(nn.Module): def __init__( self, Loading
setup.py +1 −1 Original line number Diff line number Diff line Loading @@ -3,7 +3,7 @@ from setuptools import setup, find_packages setup( name = 'audiolm-pytorch', packages = find_packages(exclude=[]), version = '0.1.12', version = '0.1.14', license='MIT', description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch', author = 'Phil Wang', Loading