mirror of https://github.com/HIllya51/LunaTranslator.git
synced 2024-12-27 15:44:12 +08:00
Sakura: Compatible with new SakuraAPI. (#492)
* Feat(Sakura): compatible with the new SakuraAPI by using the openai package; update info.
* Fix bugs and change the degeneration-detection method.
* Fix bug.
* Sakura: add a context-count option, fix bugs.
* Sakura: fix bugs.
* restore
* Sakura: restore using requests to interact with the API.
parent 75713f8322
commit c4e55060f4
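For orientation before the diff: the new code path drops the backend-specific api/v1/generate and completion endpoints and instead posts to an OpenAI-compatible /v1/chat/completions endpoint with plain requests (the openai client itself is left commented out). Below is a minimal sketch of the request this produces; the host/port, max_tokens, and frequency_penalty values are illustrative placeholders, while "sukinishiro" is the model name hardcoded in the diff and temperature/top_p match the config defaults in the second hunk.

import requests

# Assumed local Sakura backend; adjust host/port to your own deployment.
api_url = "http://127.0.0.1:8080/v1"
data = dict(
    model="sukinishiro",  # hardcoded by this commit
    messages=[
        {"role": "system", "content": "你是一个轻小说翻译模型……"},
        {"role": "user", "content": "将下面的日文文本翻译成中文:お疲れ様です。"},
    ],
    temperature=0.1,      # default from the config hunk below
    top_p=0.3,            # default from the config hunk below
    max_tokens=512,       # illustrative budget
    frequency_penalty=0.0,
    seed=-1,
    stream=False,
)
response = requests.post(api_url + "/chat/completions", json=data).json()
print(response["choices"][0]["message"]["content"])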
@@ -1,165 +1,151 @@
 from traceback import print_exc
 from translator.basetranslator import basetrans
 import requests
+# OpenAI
+# from openai import OpenAI

 class TS(basetrans):
     def langmap(self):
         return {"zh": "zh-CN"}
     def __init__(self, typename) :
-        self.api_type = ""
         self.api_url = ""
-        self.model_type = ""
-        self.history = []
+        self.history = {
+            "ja": [],
+            "zh": []
+        }
         self.session = requests.Session()
         super( ).__init__(typename)
-    def sliding_window(self, query):
-        self.history.append(query)
-        if len(self.history) > 4:
-            del self.history[0]
-        return self.history
-    def list_to_prompt(self, query_list):
+    def sliding_window(self, text_ja, text_zh):
+        if text_ja == "" or text_zh == "":
+            return
+        self.history['ja'].append(text_ja)
+        self.history['zh'].append(text_zh)
+        if len(self.history['ja']) > int(self.config['附带上下文个数(必须打开利用上文翻译)']) + 1:
+            del self.history['ja'][0]
+            del self.history['zh'][0]
+    def get_history(self, key):
         prompt = ""
-        for q in query_list:
+        for q in self.history[key]:
             prompt += q + "\n"
         prompt = prompt.strip()
         return prompt
-    def make_prompt(self, context):
-        if self.model_type == "baichuan":
-            prompt = f"<reserved_106>将下面的日文文本翻译成中文:{context}<reserved_107>"
-        elif self.model_type == "qwen":
-            prompt = f"<|im_start|>system\n你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。<|im_end|>\n<|im_start|>user\n将下面的日文文本翻译成中文:{context}<|im_end|>\n<|im_start|>assistant\n"
+    def get_client(self, api_url):
+        if api_url[-4:] == "/v1/":
+            api_url = api_url[:-1]
+        elif api_url[-3:] == "/v1":
+            pass
+        elif api_url[-1] == '/':
+            api_url += "v1"
         else:
-            prompt = f"<reserved_106>将下面的日文文本翻译成中文:{context}<reserved_107>"
-
-        return prompt
-
-    def make_request(self, prompt, is_test=False):
-        if self.api_type == "llama.cpp":
-            request = {
-                "prompt": prompt,
-                "n_predict": 1 if is_test else int(self.config['max_new_token']),
-                "temperature": float(self.config['temperature']),
-                "top_p": float(self.config['top_p']),
-                "repeat_penalty": float(self.config['repetition_penalty']),
-                "frequency_penalty": float(self.config['frequency_penalty']),
-                "top_k": 40,
-                "seed": -1
-            }
-            return request
-        elif self.api_type == "dev_server" or is_test:
-            request = {
-                "prompt": prompt,
-                "max_new_tokens": 1 if is_test else int(self.config['max_new_token']),
-                "do_sample": bool(self.config['do_sample']),
-                "temperature": float(self.config['temperature']),
-                "top_p": float(self.config['top_p']),
-                "repetition_penalty": float(self.config['repetition_penalty']),
-                "num_beams": int(self.config['num_beams']),
-                "frequency_penalty": float(self.config['frequency_penalty']),
-                "top_k": 40,
-                "seed": -1
-            }
-            return request
-        elif self.api_type == "openai_like":
-            raise NotImplementedError(f"1: {self.api_type}")
-        else:
-            raise NotImplementedError(f"2: {self.api_type}")
-    def parse_output(self, output: str, length):
-        output = output.strip()
-        output_list = output.split("\n")
-        if len(output_list) != length:
-            # fallback to no history translation
-            return None
-        else:
-            return output_list[-1]
-    def do_post(self, request):
+            api_url += "/v1"
+        self.api_url = api_url
+        # OpenAI
+        # self.client = OpenAI(api_key="114514", base_url=api_url)
+    def make_messages(self, query, history_ja=None, history_zh=None, **kwargs):
+        messages = [
+            {
+                "role": "system",
+                "content": "你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。"
+            }
+        ]
+        if history_ja:
+            messages.append({
+                "role": "user",
+                "content": f"将下面的日文文本翻译成中文:{history_ja}"
+            })
+        if history_zh:
+            messages.append({
+                "role": "assistant",
+                "content": history_zh
+            })
+
+        messages.append(
+            {
+                "role": "user",
+                "content": f"将下面的日文文本翻译成中文:{query}"
+            }
+        )
+        return messages
+
+    def send_request(self, query, is_test=False, **kwargs):
+        extra_query = {
+            'do_sample': bool(self.config['do_sample']),
+            'num_beams': int(self.config['num_beams']),
+            'repetition_penalty': float(self.config['repetition_penalty']),
+        }
+        messages = self.make_messages(query, **kwargs)
         try:
-            response = self.session.post(self.api_url, json=request).json()
-            if self.api_type == "dev_server":
-                output = response['results'][0]['text']
-                new_token = response['results'][0]['new_token']
-            elif self.api_type == "llama.cpp":
-                output = response['content']
-                new_token = response['tokens_predicted']
-            else:
-                raise NotImplementedError("3")
+            # OpenAI
+            # output = self.client.chat.completions.create(
+            data = dict(
+                model="sukinishiro",
+                messages=messages,
+                temperature=float(self.config['temperature']),
+                top_p=float(self.config['top_p']),
+                max_tokens= 1 if is_test else int(self.config['max_new_token']),
+                frequency_penalty=float(kwargs['frequency_penalty']) if "frequency_penalty" in kwargs.keys() else float(self.config['frequency_penalty']),
+                seed=-1,
+                extra_query=extra_query,
+                stream=False,
+            )
+            output = self.session.post(self.api_url + "/chat/completions", json=data).json()
         except Exception as e:
-            raise Exception(str(e) + f"\napi_type: '{self.api_type}', api_url: '{self.api_url}', model_type: '{self.model_type}'\n与API接口通信失败,请检查设置的API服务器监听地址是否正确,或检查API服务器是否正常开启。")
-        return output, new_token, response
-    def set_model_type(self):
-        #TODO: get model type from api
-        self.model_type = "NotImplemented"
-        request = self.make_request("test", is_test=True)
-        _, _, response = self.do_post(request)
-        if self.api_type == "llama.cpp":
-            model_name: str = response['model']
-            model_version = model_name.split("-")[-2]
-            if "0.8" in model_version:
-                self.model_type = "baichuan"
-            elif "0.9" in model_version:
-                self.model_type = "qwen"
-        return
-    def set_api_type(self):
-        endpoint = self.config['API接口地址']
-        if endpoint[-1] != "/":
-            endpoint += "/"
-        api_url = endpoint + "api/v1/generate"
-        test_json = self.make_request("test", is_test=True)
-        try:
-            response = self.session.post(api_url, json=test_json)
-        except Exception as e:
-            raise Exception(str(e) + f"\napi_type: '{self.api_type}', api_url: '{self.api_url}', model_type: '{self.model_type}'\n与API接口通信失败,请检查设置的API服务器监听地址是否正确,或检查API服务器是否正常开启。")
-        try:
-            response = response.json()
-            output = response['results'][0]['text']
-            new_token = response['results'][0]['new_token']
-            self.api_type = "dev_server"
-            self.api_url = api_url
-        except:
-            self.api_type = "llama.cpp"
-            self.api_url = endpoint + "completion"
-
-        return
+            raise ValueError(f"无法连接到Sakura API:{self.api_url},请检查你的API链接是否正确填写,以及API后端是否成功启动。")
+        return output

     def translate(self, query):
         self.checkempty(['API接口地址'])
-        if self.api_type == "":
-            self.set_api_type()
-            self.set_model_type()
-        prompt = self.make_prompt(query)
-        request = self.make_request(prompt)
-
-        if not self.config['利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)']:
-            output, new_token, _ = self.do_post(request)
-
+        if self.api_url == "":
+            self.get_client(self.config['API接口地址'])
+        frequency_penalty = float(self.config['frequency_penalty'])
+        if not bool(self.config['利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)']):
+            output = self.send_request(query)
+            completion_tokens = output["usage"]["completion_tokens"]
+            output_text = output["choices"][0]["message"]["content"]
+
             if bool(self.config['fix_degeneration']):
                 cnt = 0
-                while new_token == self.config['max_new_token']:
+                while completion_tokens == int(self.config['max_new_token']):
                     # detect degeneration, fixing
-                    request['frequency_penalty'] += 0.1
-                    output, new_token, _ = self.do_post(request)
-
+                    frequency_penalty += 0.1
+                    output = self.send_request(query, frequency_penalty=frequency_penalty)
+                    completion_tokens = output["usage"]["completion_tokens"]
+                    output_text = output["choices"][0]["message"]["content"]
                     cnt += 1
                     if cnt == 2:
                         break
         else:
-            query_list = self.sliding_window(query)
-            request['prompt'] = self.make_prompt(query)
-            output, new_token, _ = self.do_post(request)
-
+            # Experimental feature; decide whether to keep it after testing the results.
+            # fallback = False
+            # if self.config['启用日文上下文模式']:
+            #     history_prompt = self.get_history('ja')
+            #     output = self.send_request(history_prompt + "\n" + query)
+            #     completion_tokens = output.usage.completion_tokens
+            #     output_text = output.choices[0].message.content
+
+            #     if len(output_text.split("\n")) == len(history_prompt.split("\n")) + 1:
+            #         output_text = output_text.split("\n")[-1]
+            #     else:
+            #         fallback = True
+            # If the Japanese-context mode fails, fall back to the Chinese-context mode.
+            # if fallback or not self.config['启用日文上下文模式']:
+
+            history_prompt = self.get_history('zh')
+            output = self.send_request(query, history_zh=history_prompt)
+            completion_tokens = output["usage"]["completion_tokens"]
+            output_text = output["choices"][0]["message"]["content"]
+
             if bool(self.config['fix_degeneration']):
                 cnt = 0
-                while new_token == self.config['max_new_token']:
-                    # detect degeneration, fixing
-                    request['frequency_penalty'] += 0.1
-                    output, new_token, _ = self.do_post(request)
-
+                while completion_tokens == int(self.config['max_new_token']):
+                    frequency_penalty += 0.1
+                    output = self.send_request(query, history_zh=history_prompt, frequency_penalty=frequency_penalty)
+                    completion_tokens = output["usage"]["completion_tokens"]
+                    output_text = output["choices"][0]["message"]["content"]
                     cnt += 1
-                    if cnt == 2:
+                    if cnt == 3:
+                        output_text = "Error:模型无法完整输出或退化无法解决,请调大设置中的max_new_token!!!原输出:" + output_text
                         break

-            output = self.parse_output(output, len(query_list))
-            if not output:
-                request['prompt'] = self.make_prompt(query)
-                output, new_token, _ = self.do_post(request)
-        return output
+        self.sliding_window(query, output_text)
+        return output_text
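A note on the context option wired in above: history is now kept as paired ja/zh lists, and sliding_window() trims both in lockstep so at most 附带上下文个数 + 1 pairs survive. A self-contained sketch of that trimming behavior (the loop inputs are made-up placeholders):

# Illustrative stand-in for the new sliding_window()/history behavior.
history = {"ja": [], "zh": []}
context_num = 3  # the new '附带上下文个数(必须打开利用上文翻译)' setting

def sliding_window(text_ja, text_zh):
    if text_ja == "" or text_zh == "":
        return
    history["ja"].append(text_ja)
    history["zh"].append(text_zh)
    # Keep at most context_num + 1 pairs: the current line plus its context.
    if len(history["ja"]) > context_num + 1:
        del history["ja"][0]
        del history["zh"][0]

for i in range(6):
    sliding_window(f"ja{i}", f"zh{i}")
print(history["ja"])  # ['ja2', 'ja3', 'ja4', 'ja5']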
@@ -463,10 +463,11 @@
     },
     "sakura": {
         "args": {
-            "Sakura部署教程": "https://sakura.srpr.moe",
+            "Sakura部署教程": "https://github.com/SakuraLLM/Sakura-13B-Galgame/wiki",
             "Github仓库": "https://github.com/SakuraLLM/Sakura-13B-Galgame",
-            "API接口地址": "http://127.0.0.1:5000/",
+            "API接口地址": "http://127.0.0.1:8080/",
             "利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)": false,
+            "附带上下文个数(必须打开利用上文翻译)": 3,
             "temperature": 0.1,
             "top_p": 0.3,
             "num_beams": 1,
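The degeneration check also changes shape with the new API: instead of comparing a backend-specific new_token count, translate() now watches usage.completion_tokens; if the reply consumed the entire max_new_token budget it is assumed to have degenerated into repetition, and the request is retried with frequency_penalty raised by 0.1, capped at two retries (three in the context branch). A condensed sketch of that loop, with send_request and config standing in for the class internals:

def translate_with_degen_fix(send_request, config, query):
    # Sketch of the retry logic introduced by this commit (no-context branch).
    frequency_penalty = float(config["frequency_penalty"])
    output = send_request(query)
    completion_tokens = output["usage"]["completion_tokens"]
    output_text = output["choices"][0]["message"]["content"]
    cnt = 0
    while completion_tokens == int(config["max_new_token"]):
        # Output filled the whole token budget: likely repetition/degeneration.
        frequency_penalty += 0.1
        output = send_request(query, frequency_penalty=frequency_penalty)
        completion_tokens = output["usage"]["completion_tokens"]
        output_text = output["choices"][0]["message"]["content"]
        cnt += 1
        if cnt == 2:  # the context branch allows 3 and then prepends an error note
            break
    return output_text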
@@ -492,6 +493,12 @@
         "利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)": {
             "type": "switch"
         },
+        "附带上下文个数(必须打开利用上文翻译)":{
+            "type":"intspin",
+            "min":1,
+            "max":32,
+            "step":1
+        },
         "temperature":{
             "type":"spin",
             "min":0,
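Finally, get_client() normalizes whatever is typed into the API接口地址 field, so the four common spellings of the endpoint all resolve to the same /v1 base. A quick check with illustrative inputs:

def normalize(api_url):
    # Mirrors the URL handling in the new get_client().
    if api_url[-4:] == "/v1/":
        api_url = api_url[:-1]
    elif api_url[-3:] == "/v1":
        pass
    elif api_url[-1] == "/":
        api_url += "v1"
    else:
        api_url += "/v1"
    return api_url

for url in ("http://127.0.0.1:8080",
            "http://127.0.0.1:8080/",
            "http://127.0.0.1:8080/v1",
            "http://127.0.0.1:8080/v1/"):
    assert normalize(url) == "http://127.0.0.1:8080/v1"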