Commit 3006d2a4 authored by dmMaze
Browse files

fix chatgpt prompt

parent 40b0fa11
Loading
Loading
Loading
Loading
+13 −13
Original line number Diff line number Diff line
@@ -59,7 +59,7 @@ class GPTTranslator(BaseTranslator):
        'max tokens': 4096,
        'temperature': 0.5,
        'top p': 1,
        'return prompt': True,
        # 'return prompt': False,
        'retry attempts': 5,
    }

@@ -141,7 +141,7 @@ class GPTTranslator(BaseTranslator):
    def _assemble_prompts(self, from_lang: str, to_lang: str, queries: List[str]) -> List[str]:
        prompt = ''
        max_tokens = self.max_tokens
        return_prompt = self.params['return prompt']
        # return_prompt = self.params['return prompt']
        prompt_template = self.params['prompt template']['content'].format(to_lang=to_lang)
        prompt += prompt_template

@@ -153,15 +153,15 @@ class GPTTranslator(BaseTranslator):
            # 1 token = ~4 characters according to https://platform.openai.com/tokenizer
            # TODO: potentially add summarizations from special requests as context information
            if max_tokens * 2 and len(''.join(queries[i+1:])) > max_tokens:
                if return_prompt:
                    prompt += '\n<|1|>'
                # if return_prompt:
                #     prompt += '\n<|1|>'
                yield prompt.lstrip()
                prompt = prompt_template
                # Restart counting at 1
                i_offset = i + 1

        if return_prompt:
            prompt += '\n<|1|>'
        # if return_prompt:
        #     prompt += '\n<|1|>'
        yield prompt.lstrip()

    def _format_prompt_log(self, to_lang: str, prompt: str) -> str:
@@ -191,11 +191,9 @@ class GPTTranslator(BaseTranslator):
        from_lang = self.lang_map[self.lang_source]
        to_lang = self.lang_map[self.lang_target]
        queries = text

        # return_prompt = self.params['return prompt']
        chat_sample = self.chat_sample
        for prompt in self._assemble_prompts(from_lang, to_lang, queries):
            # self.logger.info('-- GPT Prompt --\n' + self._format_prompt_log(to_lang, prompt))

            ratelimit_attempt = 0
            server_error_attempt = 0
            while True:
@@ -217,12 +215,14 @@ class GPTTranslator(BaseTranslator):
                    time.sleep(1)

            self.logger.debug('-- GPT Response --\n' + response)
            new_translations = re.split(r'<\|\d+\|>', response)
            if chat_sample is not None:
                new_translations = new_translations[1:]
            new_translations = re.split(r'<\|\d+\|>', response)[-len(queries):]
            # if return_prompt:
            #     new_translations = new_translations[:-1]

            # if chat_sample is not None:
            #     new_translations = new_translations[1:]
            translations.extend([t.strip() for t in new_translations])

        self.logger.debug(translations)
        if self.token_count_last:
            self.logger.info(f'Used {self.token_count_last} tokens (Total: {self.token_count})')