Mirror of https://github.com/HIllya51/LunaTranslator.git, synced 2024-12-28 08:04:13 +08:00.
Sakura: Compatible with new SakuraAPI. (#492)
* Feat(Sakura): Compatible with new SakuraAPI by using openai package; Update info.
* Fix bugs and change degen detecting method.
* Fix bug.
* Sakura: add context num option, fix bugs.
* Sakura: Fix bugs.
* restore
* Sakura: restore using requests to interact with API.
This commit is contained in:
parent 75713f8322
commit c4e55060f4
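The diff below drops the old llama.cpp / dev_server endpoint probing and talks to the OpenAI-compatible chat API that newer SakuraAPI backends expose. For reviewers who want to poke the backend directly, here is a minimal sketch of the request the new send_request() posts, assuming a Sakura server on the commit's new default address; the endpoint path, payload keys, and the model name "sukinishiro" mirror the diff, while the concrete URL and max_tokens value are illustrative:

    import requests

    # The new default address from this commit, already normalized the way
    # get_client() would normalize it (trailing "/v1", no trailing slash).
    api_url = "http://127.0.0.1:8080/v1"

    data = {
        "model": "sukinishiro",
        "messages": [
            {"role": "system", "content": "你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。"},
            {"role": "user", "content": "将下面的日文文本翻译成中文:こんにちは"},
        ],
        "temperature": 0.1,   # defaults from the sakura config hunk below
        "top_p": 0.3,
        "max_tokens": 512,    # illustrative; the real code uses config['max_new_token']
        "frequency_penalty": 0.0,
        "seed": -1,
        "stream": False,
    }

    resp = requests.post(api_url + "/chat/completions", json=data).json()
    print(resp["choices"][0]["message"]["content"])  # translated text
    print(resp["usage"]["completion_tokens"])        # used to detect degeneration

The usage.completion_tokens field of this same response body is what the new degeneration check keys off.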
@@ -1,165 +1,151 @@
 from traceback import print_exc
 from translator.basetranslator import basetrans
 import requests
+# OpenAI
+# from openai import OpenAI
+
 class TS(basetrans):
     def langmap(self):
         return {"zh": "zh-CN"}
     def __init__(self, typename) :
-        self.api_type = ""
         self.api_url = ""
-        self.model_type = ""
-        self.history = []
+        self.history = {
+            "ja": [],
+            "zh": []
+        }
         self.session = requests.Session()
         super( ).__init__(typename)
-    def sliding_window(self, query):
-        self.history.append(query)
-        if len(self.history) > 4:
-            del self.history[0]
-        return self.history
-    def list_to_prompt(self, query_list):
+    def sliding_window(self, text_ja, text_zh):
+        if text_ja == "" or text_zh == "":
+            return
+        self.history['ja'].append(text_ja)
+        self.history['zh'].append(text_zh)
+        if len(self.history['ja']) > int(self.config['附带上下文个数(必须打开利用上文翻译)']) + 1:
+            del self.history['ja'][0]
+            del self.history['zh'][0]
+    def get_history(self, key):
         prompt = ""
-        for q in query_list:
+        for q in self.history[key]:
             prompt += q + "\n"
         prompt = prompt.strip()
         return prompt
-    def make_prompt(self, context):
-        if self.model_type == "baichuan":
-            prompt = f"<reserved_106>将下面的日文文本翻译成中文:{context}<reserved_107>"
-        elif self.model_type == "qwen":
-            prompt = f"<|im_start|>system\n你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。<|im_end|>\n<|im_start|>user\n将下面的日文文本翻译成中文:{context}<|im_end|>\n<|im_start|>assistant\n"
+    def get_client(self, api_url):
+        if api_url[-4:] == "/v1/":
+            api_url = api_url[:-1]
+        elif api_url[-3:] == "/v1":
+            pass
+        elif api_url[-1] == '/':
+            api_url += "v1"
         else:
-            prompt = f"<reserved_106>将下面的日文文本翻译成中文:{context}<reserved_107>"
-        return prompt
+            api_url += "/v1"
+        self.api_url = api_url
+        # OpenAI
+        # self.client = OpenAI(api_key="114514", base_url=api_url)
-    def make_request(self, prompt, is_test=False):
-        if self.api_type == "llama.cpp":
-            request = {
-                "prompt": prompt,
-                "n_predict": 1 if is_test else int(self.config['max_new_token']),
-                "temperature": float(self.config['temperature']),
-                "top_p": float(self.config['top_p']),
-                "repeat_penalty": float(self.config['repetition_penalty']),
-                "frequency_penalty": float(self.config['frequency_penalty']),
-                "top_k": 40,
-                "seed": -1
+    def make_messages(self, query, history_ja=None, history_zh=None, **kwargs):
+        messages = [
+            {
+                "role": "system",
+                "content": "你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。"
             }
-            return request
-        elif self.api_type == "dev_server" or is_test:
-            request = {
-                "prompt": prompt,
-                "max_new_tokens": 1 if is_test else int(self.config['max_new_token']),
-                "do_sample": bool(self.config['do_sample']),
-                "temperature": float(self.config['temperature']),
-                "top_p": float(self.config['top_p']),
-                "repetition_penalty": float(self.config['repetition_penalty']),
-                "num_beams": int(self.config['num_beams']),
-                "frequency_penalty": float(self.config['frequency_penalty']),
-                "top_k": 40,
-                "seed": -1
+        ]
+        if history_ja:
+            messages.append({
+                "role": "user",
+                "content": f"将下面的日文文本翻译成中文:{history_ja}"
+            })
+        if history_zh:
+            messages.append({
+                "role": "assistant",
+                "content": history_zh
+            })
+        messages.append(
+            {
+                "role": "user",
+                "content": f"将下面的日文文本翻译成中文:{query}"
             }
-            return request
-        elif self.api_type == "openai_like":
-            raise NotImplementedError(f"1: {self.api_type}")
-        else:
-            raise NotImplementedError(f"2: {self.api_type}")
-    def parse_output(self, output: str, length):
-        output = output.strip()
-        output_list = output.split("\n")
-        if len(output_list) != length:
-            # fallback to no history translation
-            return None
-        else:
-            return output_list[-1]
-    def do_post(self, request):
+        )
+        return messages
+    def send_request(self, query, is_test=False, **kwargs):
+        extra_query = {
+            'do_sample': bool(self.config['do_sample']),
+            'num_beams': int(self.config['num_beams']),
+            'repetition_penalty': float(self.config['repetition_penalty']),
+        }
+        messages = self.make_messages(query, **kwargs)
         try:
-            response = self.session.post(self.api_url, json=request).json()
-            if self.api_type == "dev_server":
-                output = response['results'][0]['text']
-                new_token = response['results'][0]['new_token']
-            elif self.api_type == "llama.cpp":
-                output = response['content']
-                new_token = response['tokens_predicted']
-            else:
-                raise NotImplementedError("3")
+            # OpenAI
+            # output = self.client.chat.completions.create(
+            data = dict(
+                model="sukinishiro",
+                messages=messages,
+                temperature=float(self.config['temperature']),
+                top_p=float(self.config['top_p']),
+                max_tokens= 1 if is_test else int(self.config['max_new_token']),
+                frequency_penalty=float(kwargs['frequency_penalty']) if "frequency_penalty" in kwargs.keys() else float(self.config['frequency_penalty']),
+                seed=-1,
+                extra_query=extra_query,
+                stream=False,
+            )
+            output = self.session.post(self.api_url + "/chat/completions", json=data).json()
         except Exception as e:
-            raise Exception(str(e) + f"\napi_type: '{self.api_type}', api_url: '{self.api_url}', model_type: '{self.model_type}'\n与API接口通信失败,请检查设置的API服务器监听地址是否正确,或检查API服务器是否正常开启。")
-        return output, new_token, response
-    def set_model_type(self):
-        #TODO: get model type from api
-        self.model_type = "NotImplemented"
-        request = self.make_request("test", is_test=True)
-        _, _, response = self.do_post(request)
-        if self.api_type == "llama.cpp":
-            model_name: str = response['model']
-            model_version = model_name.split("-")[-2]
-            if "0.8" in model_version:
-                self.model_type = "baichuan"
-            elif "0.9" in model_version:
-                self.model_type = "qwen"
-        return
-    def set_api_type(self):
-        endpoint = self.config['API接口地址']
-        if endpoint[-1] != "/":
-            endpoint += "/"
-        api_url = endpoint + "api/v1/generate"
-        test_json = self.make_request("test", is_test=True)
-        try:
-            response = self.session.post(api_url, json=test_json)
-        except Exception as e:
-            raise Exception(str(e) + f"\napi_type: '{self.api_type}', api_url: '{self.api_url}', model_type: '{self.model_type}'\n与API接口通信失败,请检查设置的API服务器监听地址是否正确,或检查API服务器是否正常开启。")
-        try:
-            response = response.json()
-            output = response['results'][0]['text']
-            new_token = response['results'][0]['new_token']
-            self.api_type = "dev_server"
-            self.api_url = api_url
-        except:
-            self.api_type = "llama.cpp"
-            self.api_url = endpoint + "completion"
-        return
+            raise ValueError(f"无法连接到Sakura API:{self.api_url},请检查你的API链接是否正确填写,以及API后端是否成功启动。")
+        return output
     def translate(self, query):
         self.checkempty(['API接口地址'])
-        if self.api_type == "":
-            self.set_api_type()
-            self.set_model_type()
-        prompt = self.make_prompt(query)
-        request = self.make_request(prompt)
-        if not self.config['利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)']:
-            output, new_token, _ = self.do_post(request)
+        if self.api_url == "":
+            self.get_client(self.config['API接口地址'])
+        frequency_penalty = float(self.config['frequency_penalty'])
+        if not bool(self.config['利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)']):
+            output = self.send_request(query)
+            completion_tokens = output["usage"]["completion_tokens"]
+            output_text = output["choices"][0]["message"]["content"]
             if bool(self.config['fix_degeneration']):
                 cnt = 0
-                while new_token == self.config['max_new_token']:
+                while completion_tokens == int(self.config['max_new_token']):
                     # detect degeneration, fixing
-                    request['frequency_penalty'] += 0.1
-                    output, new_token, _ = self.do_post(request)
+                    frequency_penalty += 0.1
+                    output = self.send_request(query, frequency_penalty=frequency_penalty)
+                    completion_tokens = output["usage"]["completion_tokens"]
+                    output_text = output["choices"][0]["message"]["content"]
                     cnt += 1
                     if cnt == 2:
                         break
         else:
-            query_list = self.sliding_window(query)
-            request['prompt'] = self.make_prompt(query)
-            output, new_token, _ = self.do_post(request)
+            # 实验性功能,测试效果后决定是否加入。
+            # fallback = False
+            # if self.config['启用日文上下文模式']:
+            #     history_prompt = self.get_history('ja')
+            #     output = self.send_request(history_prompt + "\n" + query)
+            #     completion_tokens = output.usage.completion_tokens
+            #     output_text = output.choices[0].message.content
+            #     if len(output_text.split("\n")) == len(history_prompt.split("\n")) + 1:
+            #         output_text = output_text.split("\n")[-1]
+            #     else:
+            #         fallback = True
+            # 如果日文上下文模式失败,则fallback到中文上下文模式。
+            # if fallback or not self.config['启用日文上下文模式']:
+            history_prompt = self.get_history('zh')
+            output = self.send_request(query, history_zh=history_prompt)
+            completion_tokens = output["usage"]["completion_tokens"]
+            output_text = output["choices"][0]["message"]["content"]
             if bool(self.config['fix_degeneration']):
                 cnt = 0
-                while new_token == self.config['max_new_token']:
-                    # detect degeneration, fixing
-                    request['frequency_penalty'] += 0.1
-                    output, new_token, _ = self.do_post(request)
+                while completion_tokens == int(self.config['max_new_token']):
+                    frequency_penalty += 0.1
+                    output = self.send_request(query, history_zh=history_prompt, frequency_penalty=frequency_penalty)
+                    completion_tokens = output["usage"]["completion_tokens"]
+                    output_text = output["choices"][0]["message"]["content"]
                     cnt += 1
-                    if cnt == 2:
+                    if cnt == 3:
+                        output_text = "Error:模型无法完整输出或退化无法解决,请调大设置中的max_new_token!!!原输出:" + output_text
                         break
-            output = self.parse_output(output, len(query_list))
-            if not output:
-                request['prompt'] = self.make_prompt(query)
-                output, new_token, _ = self.do_post(request)
-        return output
+        self.sliding_window(query, output_text)
+        return output_text
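A note on the degeneration handling in translate() above: a completion that spends exactly the max_new_token budget is read as the model having degenerated into repetition, so the request is retried with a progressively larger frequency penalty. A self-contained sketch of that retry policy, where send_request stands in for TS.send_request and the cap of 2 retries matches the no-context branch (the with-context branch allows a third attempt and substitutes an error string instead of silently truncating):

    def fix_degeneration(send_request, query, max_new_token, frequency_penalty):
        # First attempt with the user-configured penalty.
        output = send_request(query, frequency_penalty=frequency_penalty)
        completion_tokens = output["usage"]["completion_tokens"]
        output_text = output["choices"][0]["message"]["content"]
        cnt = 0
        # Exhausting the token budget exactly suggests runaway repetition:
        # raise the penalty by 0.1 and retry, at most twice.
        while completion_tokens == max_new_token:
            frequency_penalty += 0.1
            output = send_request(query, frequency_penalty=frequency_penalty)
            completion_tokens = output["usage"]["completion_tokens"]
            output_text = output["choices"][0]["message"]["content"]
            cnt += 1
            if cnt == 2:
                break
        return output_text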
@@ -463,10 +463,11 @@
     },
     "sakura": {
         "args": {
-            "Sakura部署教程": "https://sakura.srpr.moe",
+            "Sakura部署教程": "https://github.com/SakuraLLM/Sakura-13B-Galgame/wiki",
             "Github仓库": "https://github.com/SakuraLLM/Sakura-13B-Galgame",
-            "API接口地址": "http://127.0.0.1:5000/",
+            "API接口地址": "http://127.0.0.1:8080/",
             "利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)": false,
+            "附带上下文个数(必须打开利用上文翻译)": 3,
             "temperature": 0.1,
             "top_p": 0.3,
             "num_beams": 1,
@@ -492,6 +493,12 @@
         "利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)": {
             "type": "switch"
         },
+        "附带上下文个数(必须打开利用上文翻译)":{
+            "type":"intspin",
+            "min":1,
+            "max":32,
+            "step":1
+        },
         "temperature":{
             "type":"spin",
             "min":0,
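The two config hunks above add the 附带上下文个数(必须打开利用上文翻译) option (default 3, UI range 1 to 32), which caps how many previous lines sliding_window() keeps for make_messages() to replay as translation context. A small self-contained sketch of that trimming behaviour, with context_num standing in for the config value:

    history = {"ja": [], "zh": []}
    context_num = 3  # default from the config hunk above

    def sliding_window(text_ja, text_zh):
        # Mirrors TS.sliding_window: ignore empty pairs, then trim both
        # parallel lists once they grow past context_num + 1 entries.
        if text_ja == "" or text_zh == "":
            return
        history["ja"].append(text_ja)
        history["zh"].append(text_zh)
        if len(history["ja"]) > context_num + 1:
            del history["ja"][0]
            del history["zh"][0]

    for i in range(6):
        sliding_window(f"ja{i}", f"zh{i}")
    print(history["ja"])  # ['ja2', 'ja3', 'ja4', 'ja5'], the last context_num + 1 entries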