Commit 8ad550a5 authored by Phil Wang's avatar Phil Wang
Browse files

set things up for eventual text conditioned audio synthesis

parent a1cf75ca
Loading
Loading
Loading
Loading
+44 −4
Original line number Diff line number Diff line
@@ -5,7 +5,11 @@ from pathlib import Path
from shutil import rmtree

from typing import Union, List, Optional
from typeguard import typechecked
from typing_extensions import Annotated

from beartype import beartype
from beartype.door import is_bearable
from beartype.vale import Is

import torch
import torchaudio
@@ -39,6 +43,20 @@ from accelerate import Accelerator

DEFAULT_SAMPLE_RATE = 16000

# for automatically routing data emitted from a dataset to keywords of the transformer wrappers

DATASET_FIELD_TYPE_CONFIG = dict(
    raw_wave = Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.float and t.ndim in {2, 3}]
    ],
    text = List[str],
    text_embeds = Annotated[
        torch.Tensor,
        Is[lambda t: t.dtype == torch.float and t.ndim == 3]
    ],
)

# helpers

def exists(val):
@@ -65,6 +83,28 @@ def accum_log(log, new_logs):
        log[key] = old_value + new_value
    return log

# auto data to module keyword argument routing functions

def has_duplicates(tup):
    counts = dict()
    for el in tup:
        if el not in counts:
            counts[el] = 0
        counts[el] += 1
    return any(filter(lambda count: count > 1, counts.values()))

def determine_types(data, config):
    output = []
    for el in data:
        for name, data_type in config.items():
            if is_bearable(el, data_type):
                output.append(name)
                break
        else:
            raise TypeError(f'unable to determine type of {data}')

    return tuple(output)

# main trainer class

class SoundStreamTrainer(nn.Module):
@@ -336,7 +376,7 @@ class SoundStreamTrainer(nn.Module):

# semantic transformer trainer

@typechecked
@beartype
class SemanticTransformerTrainer(nn.Module):
    def __init__(
        self,
@@ -525,7 +565,7 @@ class SemanticTransformerTrainer(nn.Module):

# fine transformer trainer

@typechecked
@beartype
class CoarseTransformerTrainer(nn.Module):
    def __init__(
        self,
@@ -731,7 +771,7 @@ class CoarseTransformerTrainer(nn.Module):

# fine transformer trainer

@typechecked
@beartype
class FineTransformerTrainer(nn.Module):
    def __init__(
        self,
+1 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from setuptools import setup, find_packages
setup(
  name = 'audiolm-pytorch',
  packages = find_packages(exclude=[]),
  version = '0.1.12',
  version = '0.1.14',
  license='MIT',
  description = 'AudioLM - Language Modeling Approach to Audio Generation from Google Research - Pytorch',
  author = 'Phil Wang',