Commit 64ffeed1 authored by dmMaze's avatar dmMaze
Browse files

try to fix chatgpt (#173)

parent 3006d2a4
Loading
Loading
Loading
Loading
+26 −12
Original line number Diff line number Diff line
@@ -8,7 +8,9 @@ from typing import List, Dict, Union
import yaml
from .base import BaseTranslator, register_translator

CONFIG = None

class InvalidNumTranslations(Exception):
    """Signal that a response yielded the wrong number of translations.

    Raised when the count of translation segments split out of the model
    response differs from the number of source queries placed in the prompt,
    so the caller can retry the request or fall back to empty results.
    """

@register_translator('chatgpt')
class GPTTranslator(BaseTranslator):
@@ -23,7 +25,7 @@ class GPTTranslator(BaseTranslator):
                'gpt35-turbo',
                'gpt4',
            ],
            'select': 'gpt3'
            'select': 'gpt35-turbo'
        },
        'prompt template': {
            'type': 'editor',
@@ -142,12 +144,14 @@ class GPTTranslator(BaseTranslator):
        prompt = ''
        max_tokens = self.max_tokens
        # return_prompt = self.params['return prompt']
        prompt_template = self.params['prompt template']['content'].format(to_lang=to_lang)
        prompt_template = self.params['prompt template']['content'].format(to_lang=to_lang).rstrip()
        prompt += prompt_template

        i_offset = 0
        num_src = 0
        for i, query in enumerate(queries):
            prompt += f'\n<|{i+1-i_offset}|>{query}'
            num_src += 1
            # If prompt is growing too large and there's still a lot of text left
            # split off the rest of the queries into new prompts.
            # 1 token = ~4 characters according to https://platform.openai.com/tokenizer
@@ -155,14 +159,15 @@ class GPTTranslator(BaseTranslator):
            if max_tokens * 2 and len(''.join(queries[i+1:])) > max_tokens:
                # if return_prompt:
                #     prompt += '\n<|1|>'
                yield prompt.lstrip()
                yield prompt.lstrip(), num_src
                prompt = prompt_template
                # Restart counting at 1
                i_offset = i + 1
                num_src = 0

        # if return_prompt:
        #     prompt += '\n<|1|>'
        yield prompt.lstrip()
        yield prompt.lstrip(), num_src

    def _format_prompt_log(self, to_lang: str, prompt: str) -> str:
        chat_sample = self.chat_sample
@@ -193,12 +198,15 @@ class GPTTranslator(BaseTranslator):
        queries = text
        # return_prompt = self.params['return prompt']
        chat_sample = self.chat_sample
        for prompt in self._assemble_prompts(from_lang, to_lang, queries):
        for prompt, num_src in self._assemble_prompts(from_lang, to_lang, queries):
            ratelimit_attempt = 0
            server_error_attempt = 0
            retry_attempt = 0
            while True:
                try:
                    response = self._request_translation(prompt, chat_sample)
                    new_translations = re.split(r'<\|\d+\|>', response)[-num_src:]
                    if len(new_translations) != num_src:
                        raise InvalidNumTranslations
                    break
                except openai.error.RateLimitError: # Server returned ratelimit response
                    ratelimit_attempt += 1
@@ -207,15 +215,21 @@ class GPTTranslator(BaseTranslator):
                    self.logger.warn(f'Restarting request due to ratelimiting by openai servers. Attempt: {ratelimit_attempt}')
                    time.sleep(2)
                except openai.error.APIError: # Server returned 500 error (probably server load)
                    server_error_attempt += 1
                    if server_error_attempt >= self.retry_attempts:
                    retry_attempt += 1
                    if retry_attempt >= self.retry_attempts:
                        self.logger.error('OpenAI encountered a server error, possibly due to high server load. Use a different translator or try again later.')
                        raise
                    self.logger.warn(f'Restarting request due to a server error. Attempt: {server_error_attempt}')
                    self.logger.warn(f'Restarting request due to a server error. Attempt: {retry_attempt}')
                    time.sleep(1)
                except InvalidNumTranslations:
                    retry_attempt += 1
                    message = f'number of translations does not match to source:\nprompt:\n    {prompt}\ntranslations:\n  {new_translations}\nopenai response:\n  {response}'
                    if retry_attempt >= self.retry_attempts:
                        self.logger.error(message)
                        new_translations = [''] * num_src
                        break
                    self.logger.warn(message + '\n' + f'Restarting request. Attempt: {retry_attempt}')

            self.logger.debug('-- GPT Response --\n' + response)
            new_translations = re.split(r'<\|\d+\|>', response)[-len(queries):]
            # if return_prompt:
            #     new_translations = new_translations[:-1]