Commit 7d017fc3 authored by blackheaven's avatar blackheaven
Browse files

Add flask support

parent e50305de
Loading
Loading
Loading
Loading

Pipfile

0 → 100644
+15 −0
Original line number Diff line number Diff line
[[source]]
url = "https://pypi.org/simple"
verify_ssl = true
name = "pypi"

[packages]
flask = "*"
audiolm-pytorch = "*"
tensorboardx = "*"

[dev-packages]

[requires]
python_version = "3.8"
python_full_version = "3.8.16"

Pipfile.lock

0 → 100644
+1326 −0

File added.

Preview size limit exceeded, changes collapsed.

app.py

0 → 100644
+129 −0
Original line number Diff line number Diff line
# Dependencies
from flask import Flask, request, send_file

# ML
import torchaudio
import torch
from audiolm_pytorch import AudioLM, SoundStream, HubertWithKmeans, SemanticTransformer, SemanticTransformerTrainer, CoarseTransformer, CoarseTransformerTrainer, FineTransformer, FineTransformerTrainer

# define all dataset paths, checkpoints, etc
dataset_folder = 'placeholder_dataset'
ckpt_folder = 'placeholder_ckpt'

wav2vec = HubertWithKmeans(
    checkpoint_path = f'{ckpt_folder}/hubert/hubert_base_ls960.pt',
	kmeans_path = f'{ckpt_folder}/hubert/hubert_base_ls960_L9_km500.bin')

# SoundStream
soundstream = SoundStream.init_and_load_from(f'{ckpt_folder}/soundstream/soundstream.pt')

# Semantic
semantic_transformer = SemanticTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    dim = 1024,
    depth = 6
).cuda()

trainer_Semantic = SemanticTransformerTrainer(
    transformer = semantic_transformer,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    num_train_steps = 1, 
    # force_clear_prev_results: bool = None  # set to True | False to skip the prompt
    force_clear_prev_results = False
)

trainer_Semantic.load(f'{ckpt_folder}/semantic/semantic.pt')

# Coarse
coarse_transformer = CoarseTransformer(
    num_semantic_tokens = wav2vec.codebook_size,
    codebook_size = 1024,
    num_coarse_quantizers = 3,
    dim = 512,
    depth = 6
)

trainer_Coarse = CoarseTransformerTrainer(
    transformer = coarse_transformer,
    codec = soundstream,
    wav2vec = wav2vec,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    # num_train_steps = 1_000_000
    num_train_steps = 50_000,
    # force_clear_prev_results: bool = None  # set to True | False to skip the prompt
    force_clear_prev_results = False
)

trainer_Coarse.load(f'{ckpt_folder}/coarse/coarse.pt')

# Fine
fine_transformer = FineTransformer(
    num_coarse_quantizers = 3,
    num_fine_quantizers = 5,
    codebook_size = 1024,
    dim = 512,
    depth = 6
)

trainer_Fine = FineTransformerTrainer(
    transformer = fine_transformer,
    codec = soundstream,
    folder = dataset_folder,
    batch_size = 1,
    data_max_length = 320 * 32,
    # num_train_steps = 1_000_000
    num_train_steps = 50_000,
    # force_clear_prev_results: bool = None  # set to True | False to skip the prompt
    force_clear_prev_results = False
)

trainer_Fine.load(f'{ckpt_folder}/fine/fine.pt')

# Mix All together

audiolm = AudioLM(
    wav2vec = wav2vec,
    codec = soundstream,
    semantic_transformer = semantic_transformer,
    coarse_transformer = coarse_transformer,
    fine_transformer = fine_transformer
)

# Your API definition
app = Flask(__name__)

@app.route('/', methods=['GET'])
def index():
    return 'there\'s nothing here, yet'

@app.route('/health', methods=['GET'])
def healthcheck():
    return 'OK'

@app.route('/text2audio', methods=['GET', 'POST'])
# https://pythonprogramming.net/flask-send-file-tutorial/
# https://stackoverflow.com/a/53398209
# https://medium.com/analytics-vidhya/receive-or-return-files-flask-api-8389d42b0684
def text2audio():
    if request.method == 'GET':
        return 'return audio file something'
    elif request.method == 'POST':
        input_text = request.form['text']
        generated_wav_with_text_condition = audiolm(text = [input_text])
        output_path = "output.wav"
        sample_rate = 44100
        torchaudio.save(output_path, generated_wav_with_text_condition.cpu(), sample_rate)
        try:
            # return 'something file'
            return send_file('output.wav', as_attachment=True) #attachment_filename='sample.wav' )
        except:
            print(error)

if __name__ == '__main__':
    # https://www.datacamp.com/tutorial/machine-learning-models-api-python
    app.run(host='0.0.0.0', port='12345', debug=True)