Sakura: Compatible with new SakuraAPI. (#492)

* Feat(Sakura): make Sakura compatible with the new SakuraAPI by using the openai package; update info.

* Fix bugs and change the degeneration-detection method (see the sketch after this list).

* Fix bug.

* Sakura: add a context-count option; fix bugs.

* Sakura: Fix bugs.

* restore

* Sakura: restore the use of requests to interact with the API.
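A minimal sketch of the reworked degeneration check, assuming the send_request helper and config keys defined in the diff below; the wrapper function itself is hypothetical, and the retry cap follows the non-context branch of translate:

# Sketch only: mirrors the new heuristic in translate() below. A completion
# that used every allowed token is treated as degenerated (looping) or
# truncated output, and the request is retried with a stronger penalty.
def fix_degeneration(send_request, query, output, config):
    frequency_penalty = float(config["frequency_penalty"])
    for _ in range(2):  # the plugin gives up after two retries
        if output["usage"]["completion_tokens"] != int(config["max_new_token"]):
            break  # the model stopped on its own; output looks complete
        frequency_penalty += 0.1
        output = send_request(query, frequency_penalty=frequency_penalty)
    return output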
SakuraUmi 2024-01-13 19:44:54 +08:00 committed by GitHub
parent 75713f8322
commit c4e55060f4
2 changed files with 126 additions and 133 deletions


@@ -1,165 +1,151 @@
from traceback import print_exc
from translator.basetranslator import basetrans
import requests
# OpenAI
# from openai import OpenAI

class TS(basetrans):
    def langmap(self):
        return {"zh": "zh-CN"}

    def __init__(self, typename):
        self.api_url = ""
        self.history = {
            "ja": [],
            "zh": []
        }
        self.session = requests.Session()
        super().__init__(typename)

    def sliding_window(self, text_ja, text_zh):
        if text_ja == "" or text_zh == "":
            return
        self.history['ja'].append(text_ja)
        self.history['zh'].append(text_zh)
        if len(self.history['ja']) > int(self.config['附带上下文个数(必须打开利用上文翻译)']) + 1:
            del self.history['ja'][0]
            del self.history['zh'][0]

    def get_history(self, key):
        prompt = ""
        for q in self.history[key]:
            prompt += q + "\n"
        prompt = prompt.strip()
        return prompt
    def get_client(self, api_url):
        # Normalize the endpoint so that it always ends in "/v1" without a
        # trailing slash, whatever form the user entered.
        if api_url[-4:] == "/v1/":
            api_url = api_url[:-1]
        elif api_url[-3:] == "/v1":
            pass
        elif api_url[-1] == '/':
            api_url += "v1"
        else:
            api_url += "/v1"
        self.api_url = api_url
        # OpenAI
        # self.client = OpenAI(api_key="114514", base_url=api_url)

    def make_messages(self, query, history_ja=None, history_zh=None, **kwargs):
        messages = [
            {
                "role": "system",
                "content": "你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。"
            }
        ]
        if history_ja:
            messages.append({
                "role": "user",
                "content": f"将下面的日文文本翻译成中文:{history_ja}"
            })
        if history_zh:
            messages.append({
                "role": "assistant",
                "content": history_zh
            })
        messages.append(
            {
                "role": "user",
                "content": f"将下面的日文文本翻译成中文:{query}"
            }
        )
        return messages
    def send_request(self, query, is_test=False, **kwargs):
        # Sampling options that are not part of the standard chat-completions
        # schema are forwarded in an extra field.
        extra_query = {
            'do_sample': bool(self.config['do_sample']),
            'num_beams': int(self.config['num_beams']),
            'repetition_penalty': float(self.config['repetition_penalty']),
        }
        messages = self.make_messages(query, **kwargs)
        try:
            # OpenAI
            # output = self.client.chat.completions.create(
            data = dict(
                model="sukinishiro",
                messages=messages,
                temperature=float(self.config['temperature']),
                top_p=float(self.config['top_p']),
                max_tokens=1 if is_test else int(self.config['max_new_token']),
                frequency_penalty=float(kwargs['frequency_penalty']) if "frequency_penalty" in kwargs else float(self.config['frequency_penalty']),
                seed=-1,
                extra_query=extra_query,
                stream=False,
            )
            output = self.session.post(self.api_url + "/chat/completions", json=data).json()
        except Exception:
            print_exc()
            raise ValueError(f"无法连接到Sakura API:{self.api_url},请检查你的API链接是否正确填写,以及API后端是否成功启动。")
        return output
    def translate(self, query):
        self.checkempty(['API接口地址'])
        if self.api_url == "":
            self.get_client(self.config['API接口地址'])
        frequency_penalty = float(self.config['frequency_penalty'])
        if not bool(self.config['利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)']):
            output = self.send_request(query)
            completion_tokens = output["usage"]["completion_tokens"]
            output_text = output["choices"][0]["message"]["content"]
            if bool(self.config['fix_degeneration']):
                cnt = 0
                # A completion that used every allowed token suggests the model
                # degenerated; retry with a stronger frequency penalty.
                while completion_tokens == int(self.config['max_new_token']):
                    frequency_penalty += 0.1
                    output = self.send_request(query, frequency_penalty=frequency_penalty)
                    completion_tokens = output["usage"]["completion_tokens"]
                    output_text = output["choices"][0]["message"]["content"]
                    cnt += 1
                    if cnt == 2:
                        break
        else:
            # Experimental feature; whether to adopt it will be decided after testing.
            # fallback = False
            # if self.config['启用日文上下文模式']:
            #     history_prompt = self.get_history('ja')
            #     output = self.send_request(history_prompt + "\n" + query)
            #     completion_tokens = output.usage.completion_tokens
            #     output_text = output.choices[0].message.content
            #     if len(output_text.split("\n")) == len(history_prompt.split("\n")) + 1:
            #         output_text = output_text.split("\n")[-1]
            #     else:
            #         fallback = True
            # If Japanese-context mode fails, fall back to Chinese-context mode.
            # if fallback or not self.config['启用日文上下文模式']:
            history_prompt = self.get_history('zh')
            output = self.send_request(query, history_zh=history_prompt)
            completion_tokens = output["usage"]["completion_tokens"]
            output_text = output["choices"][0]["message"]["content"]
            if bool(self.config['fix_degeneration']):
                cnt = 0
                while completion_tokens == int(self.config['max_new_token']):
                    # detect degeneration, fixing
                    frequency_penalty += 0.1
                    output = self.send_request(query, history_zh=history_prompt, frequency_penalty=frequency_penalty)
                    completion_tokens = output["usage"]["completion_tokens"]
                    output_text = output["choices"][0]["message"]["content"]
                    cnt += 1
                    if cnt == 3:
                        output_text = "Error:模型无法完整输出或退化无法解决,请调大设置中的max_new_token!原输出:" + output_text
                        break
        self.sliding_window(query, output_text)
        return output_text
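For reference, a standalone sketch of the request/response shape this file now relies on. The endpoint path, payload fields, and model name come from the code above, and the sampling values mirror the defaults in the config diff below; the base URL and the sample input line are illustrative assumptions:

import requests

api_url = "http://127.0.0.1:8080/v1"  # assumed local deployment; set via 'API接口地址'

payload = {
    "model": "sukinishiro",
    "messages": [
        {"role": "system", "content": "你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。"},
        {"role": "user", "content": "将下面的日文文本翻译成中文:こんにちは。"},  # sample input
    ],
    "temperature": 0.1,
    "top_p": 0.3,
    "max_tokens": 512,
    "frequency_penalty": 0.05,
    "seed": -1,
    "stream": False,
}

# POST to the OpenAI-compatible chat-completions endpoint with requests,
# mirroring send_request above.
response = requests.post(api_url + "/chat/completions", json=payload).json()
print(response["choices"][0]["message"]["content"])  # the translation
print(response["usage"]["completion_tokens"])        # used by the degeneration check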


@@ -463,10 +463,11 @@
    },
    "sakura": {
        "args": {
            "Sakura部署教程": "https://github.com/SakuraLLM/Sakura-13B-Galgame/wiki",
            "Github仓库": "https://github.com/SakuraLLM/Sakura-13B-Galgame",
            "API接口地址": "http://127.0.0.1:8080/",
            "利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)": false,
            "附带上下文个数(必须打开利用上文翻译)": 3,
            "temperature": 0.1,
            "top_p": 0.3,
            "num_beams": 1,
@@ -492,6 +493,12 @@
        "利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)": {
            "type": "switch"
        },
        "附带上下文个数(必须打开利用上文翻译)": {
            "type": "intspin",
            "min": 1,
            "max": 32,
            "step": 1
        },
        "temperature": {
            "type": "spin",
            "min": 0,