Commit 23b962ab authored by dmMaze
Browse files

fix deprecated parameter max_tokens for openai #704

parent 9c017173
Loading
Loading
Loading
Loading
+20 −7
Original line number Diff line number Diff line
@@ -5,13 +5,16 @@ import time
from typing import List, Dict, Union
import yaml
import traceback
import inspect

import openai

from .base import BaseTranslator, register_translator


# True when the installed openai SDK has major version >= 1; used to choose
# between `openai.chat.completions.create` and `openai.ChatCompletion.create`.
OPENAPI_V1_API = int(openai.__version__.partition('.')[0]) >= 1


class InvalidNumTranslations(Exception):
    """Exception subclass for invalid-translation-count errors; adds no behaviour of its own."""

@@ -310,17 +313,27 @@ class GPTTranslator(BaseTranslator):
            messages.insert(1, {'role': 'user', 'content': chat_sample[0]})
            messages.insert(2, {'role': 'assistant', 'content': chat_sample[1]})

        # Resolve the SDK entry point first: `openai.chat` only exists in
        # openai>=1, so inspecting it unconditionally would raise
        # AttributeError on the legacy 0.x SDK.
        if OPENAPI_V1_API:
            openai_chatcompletions_create = openai.chat.completions.create
        else:
            openai_chatcompletions_create = openai.ChatCompletion.create

        # NOTE(review): the earlier call in this method did not send the two
        # penalty values; confirm self.params defines 'frequency penalty' and
        # 'presence penalty' for this translator.
        func_args = {
            'model': model,
            'messages': messages,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'frequency_penalty': self.params['frequency penalty'],
            # Key must be exactly 'presence_penalty' (no trailing space) to be
            # accepted as a keyword argument by the SDK.
            'presence_penalty': self.params['presence penalty'],
        }
        max_tokens = self.max_tokens // 2  # Assuming that half of the tokens are used for the query

        # Use `max_completion_tokens` when the resolved callable declares it,
        # otherwise fall back to the older `max_tokens` keyword.
        func_parameters = inspect.signature(openai_chatcompletions_create).parameters
        if 'max_completion_tokens' in func_parameters:
            func_args['max_completion_tokens'] = max_tokens
        else:
            func_args['max_tokens'] = max_tokens

        # Single request: every option travels through func_args.
        response = openai_chatcompletions_create(**func_args)

        if OPENAPI_V1_API:
            if response.usage is not None:
+18 −10
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ import time
from typing import List, Dict, Union
import xml.etree.ElementTree as ET
import traceback
import logging
import inspect

import openai

@@ -252,19 +252,27 @@ Then stop, without any other explanations or notes.
            {'role': 'user', 'content': prompt},
        ]

        # Resolve the SDK entry point first: `openai.chat` only exists in
        # openai>=1, so inspecting it unconditionally would raise
        # AttributeError on the legacy 0.x SDK.
        if OPENAPI_V1_API:
            openai_chatcompletions_create = openai.chat.completions.create
        else:
            openai_chatcompletions_create = openai.ChatCompletion.create

        func_args = {
            'model': model,
            'messages': messages,
            'temperature': self.temperature,
            'top_p': self.top_p,
            'frequency_penalty': self.params['frequency penalty'],
            # Key must be exactly 'presence_penalty' (no trailing space) to be
            # accepted as a keyword argument by the SDK.
            'presence_penalty': self.params['presence penalty'],
        }
        max_tokens = self.max_tokens // 2  # Assuming that half of the tokens are used for the query

        # Use `max_completion_tokens` when the resolved callable declares it,
        # otherwise fall back to the older `max_tokens` keyword.
        func_parameters = inspect.signature(openai_chatcompletions_create).parameters
        if 'max_completion_tokens' in func_parameters:
            func_args['max_completion_tokens'] = max_tokens
        else:
            func_args['max_tokens'] = max_tokens

        # Single request: every option travels through func_args.
        response = openai_chatcompletions_create(**func_args)

        self.logger.debug(f'openai response: \n {response}')