mirror of
https://github.com/HIllya51/LunaTranslator.git
synced 2025-01-01 10:04:12 +08:00
Feat(sakura): add stream output. (#557)
This commit is contained in:
parent
6c88fbce86
commit
164abaae4e
@ -1,6 +1,7 @@
|
|||||||
from traceback import print_exc
|
from traceback import print_exc
|
||||||
from translator.basetranslator import basetrans
|
from translator.basetranslator import basetrans
|
||||||
import requests
|
import requests
|
||||||
|
import json
|
||||||
# OpenAI
|
# OpenAI
|
||||||
# from openai import OpenAI
|
# from openai import OpenAI
|
||||||
|
|
||||||
@ -91,13 +92,49 @@ class TS(basetrans):
|
|||||||
stream=False,
|
stream=False,
|
||||||
)
|
)
|
||||||
output = self.session.post(self.api_url + "/chat/completions", timeout=self.timeout, json=data).json()
|
output = self.session.post(self.api_url + "/chat/completions", timeout=self.timeout, json=data).json()
|
||||||
|
yield output
|
||||||
except requests.Timeout as e:
|
except requests.Timeout as e:
|
||||||
raise ValueError(f"连接到Sakura API超时:{self.api_url},当前最大连接时间为: {self.timeout},请尝试修改参数。")
|
raise ValueError(f"连接到Sakura API超时:{self.api_url},当前最大连接时间为: {self.timeout},请尝试修改参数。")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
raise ValueError(f"无法连接到Sakura API:{self.api_url},请检查你的API链接是否正确填写,以及API后端是否成功启动。")
|
raise ValueError(f"无法连接到Sakura API:{self.api_url},请检查你的API链接是否正确填写,以及API后端是否成功启动。")
|
||||||
return output
|
|
||||||
|
def send_request_stream(self, query, is_test=False, **kwargs):
    """Send a streaming chat-completion request and yield each decoded chunk.

    The request is POSTed to ``self.api_url + "/chat/completions"`` with
    ``stream=True``. The response body is read as events separated by a
    blank line; each event payload is JSON-decoded and yielded as soon as
    it arrives.

    Args:
        query: Text to translate; forwarded to ``self.make_messages``.
        is_test: When true, ``max_tokens`` is forced to 1 instead of the
            configured ``max_new_token``.
        **kwargs: Forwarded to ``self.make_messages``. If it contains
            ``frequency_penalty``, that value overrides the configured one.

    Yields:
        dict: One JSON-decoded chunk per streamed event.

    Raises:
        ValueError: If the request times out, or if any other error occurs
            while sending the request or reading/decoding the stream.
    """
    extra_query = {
        'do_sample': bool(self.config['do_sample']),
        'num_beams': int(self.config['num_beams']),
        'repetition_penalty': float(self.config['repetition_penalty']),
    }
    messages = self.make_messages(query, **kwargs)
    response = None
    try:
        data = dict(
            model="sukinishiro",
            messages=messages,
            temperature=float(self.config['temperature']),
            top_p=float(self.config['top_p']),
            max_tokens=1 if is_test else int(self.config['max_new_token']),
            # A per-call override (used by the degeneration retry loop)
            # takes precedence over the configured value.
            frequency_penalty=float(kwargs.get('frequency_penalty', self.config['frequency_penalty'])),
            seed=-1,
            extra_query=extra_query,
            stream=True,
        )
        response = self.session.post(self.api_url + "/chat/completions", timeout=self.timeout, json=data, stream=True)
        for raw_event in response.iter_lines(delimiter="\n\n".encode()):
            event = raw_event.decode("utf-8").strip()
            # Strip only the leading SSE field name. A blanket
            # str.replace would also eat "data: " occurring inside the
            # translated text itself.
            if event.startswith("data:"):
                event = event[len("data:"):].strip()
            # Skip keep-alive blanks and the OpenAI-style end-of-stream
            # sentinel, which is not JSON.
            if event == "" or event == "[DONE]":
                continue
            yield json.loads(event)
    except requests.Timeout as e:
        raise ValueError(f"连接到Sakura API超时:{self.api_url},当前最大连接时间为: {self.timeout},请尝试修改参数。") from e
    except Exception as e:
        import traceback
        print(e)
        e1 = traceback.format_exc()
        raise ValueError(f"Error: {str(e1)}. 无法连接到Sakura API:{self.api_url},请检查你的API链接是否正确填写,以及API后端是否成功启动。") from e
    finally:
        # With stream=True the connection is held until the body is fully
        # consumed or the response is closed; release it even when the
        # consumer stops early or an error occurred.
        if response is not None:
            response.close()
|
||||||
|
|
||||||
def translate(self, query):
|
def translate(self, query):
|
||||||
self.checkempty(['API接口地址'])
|
self.checkempty(['API接口地址'])
|
||||||
@ -106,18 +143,49 @@ class TS(basetrans):
|
|||||||
self.get_client(self.config['API接口地址'])
|
self.get_client(self.config['API接口地址'])
|
||||||
frequency_penalty = float(self.config['frequency_penalty'])
|
frequency_penalty = float(self.config['frequency_penalty'])
|
||||||
if not bool(self.config['利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)']):
|
if not bool(self.config['利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)']):
|
||||||
output = self.send_request(query)
|
if bool(self.config['流式输出']) == True:
|
||||||
completion_tokens = output["usage"]["completion_tokens"]
|
output = self.send_request_stream(query)
|
||||||
output_text = output["choices"][0]["message"]["content"]
|
completion_tokens = 0
|
||||||
|
for o in output:
|
||||||
|
if o['choices'][0]['finish_reason'] == None:
|
||||||
|
yield o['choices'][0]['delta']['content']
|
||||||
|
completion_tokens += 1
|
||||||
|
else:
|
||||||
|
finish_reason = o['choices'][0]['finish_reason']
|
||||||
|
else:
|
||||||
|
output = self.send_request(query)
|
||||||
|
for o in output:
|
||||||
|
completion_tokens = o["usage"]["completion_tokens"]
|
||||||
|
output_text = o["choices"][0]["message"]["content"]
|
||||||
|
yield output_text
|
||||||
|
|
||||||
if bool(self.config['fix_degeneration']):
|
if bool(self.config['fix_degeneration']):
|
||||||
cnt = 0
|
cnt = 0
|
||||||
|
print(completion_tokens)
|
||||||
while completion_tokens == int(self.config['max_new_token']):
|
while completion_tokens == int(self.config['max_new_token']):
|
||||||
# detect degeneration, fixing
|
# detect degeneration, fixing
|
||||||
frequency_penalty += 0.1
|
frequency_penalty += 0.1
|
||||||
output = self.send_request(query, frequency_penalty=frequency_penalty)
|
yield '\0'
|
||||||
completion_tokens = output["usage"]["completion_tokens"]
|
print("------------------清零------------------")
|
||||||
output_text = output["choices"][0]["message"]["content"]
|
if bool(self.config['流式输出']) == True:
|
||||||
|
output = self.send_request_stream(query, frequency_penalty=frequency_penalty)
|
||||||
|
completion_tokens = 0
|
||||||
|
for o in output:
|
||||||
|
if o['choices'][0]['finish_reason'] == None:
|
||||||
|
yield o['choices'][0]['delta']['content']
|
||||||
|
completion_tokens += 1
|
||||||
|
else:
|
||||||
|
finish_reason = o['choices'][0]['finish_reason']
|
||||||
|
else:
|
||||||
|
output = self.send_request(query, frequency_penalty=frequency_penalty)
|
||||||
|
for o in output:
|
||||||
|
completion_tokens = o["usage"]["completion_tokens"]
|
||||||
|
output_text = o["choices"][0]["message"]["content"]
|
||||||
|
yield output_text
|
||||||
|
|
||||||
|
# output = self.send_request(query, frequency_penalty=frequency_penalty)
|
||||||
|
# completion_tokens = output["usage"]["completion_tokens"]
|
||||||
|
# output_text = output["choices"][0]["message"]["content"]
|
||||||
cnt += 1
|
cnt += 1
|
||||||
if cnt == 2:
|
if cnt == 2:
|
||||||
break
|
break
|
||||||
@ -138,20 +206,54 @@ class TS(basetrans):
|
|||||||
# if fallback or not self.config['启用日文上下文模式']:
|
# if fallback or not self.config['启用日文上下文模式']:
|
||||||
|
|
||||||
history_prompt = self.get_history('zh')
|
history_prompt = self.get_history('zh')
|
||||||
output = self.send_request(query, history_zh=history_prompt)
|
# output = self.send_request(query, history_zh=history_prompt)
|
||||||
completion_tokens = output["usage"]["completion_tokens"]
|
# completion_tokens = output["usage"]["completion_tokens"]
|
||||||
output_text = output["choices"][0]["message"]["content"]
|
# output_text = output["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
if bool(self.config['流式输出']) == True:
|
||||||
|
output = self.send_request_stream(query, history_zh=history_prompt)
|
||||||
|
completion_tokens = 0
|
||||||
|
for o in output:
|
||||||
|
if o['choices'][0]['finish_reason'] == None:
|
||||||
|
yield o['choices'][0]['delta']['content']
|
||||||
|
completion_tokens += 1
|
||||||
|
else:
|
||||||
|
finish_reason = o['choices'][0]['finish_reason']
|
||||||
|
else:
|
||||||
|
output = self.send_request(query, history_zh=history_prompt)
|
||||||
|
for o in output:
|
||||||
|
completion_tokens = o["usage"]["completion_tokens"]
|
||||||
|
output_text = o["choices"][0]["message"]["content"]
|
||||||
|
yield output_text
|
||||||
|
|
||||||
if bool(self.config['fix_degeneration']):
|
if bool(self.config['fix_degeneration']):
|
||||||
cnt = 0
|
cnt = 0
|
||||||
|
print(completion_tokens)
|
||||||
while completion_tokens == int(self.config['max_new_token']):
|
while completion_tokens == int(self.config['max_new_token']):
|
||||||
frequency_penalty += 0.1
|
frequency_penalty += 0.1
|
||||||
output = self.send_request(query, history_zh=history_prompt, frequency_penalty=frequency_penalty)
|
yield '\0'
|
||||||
completion_tokens = output["usage"]["completion_tokens"]
|
print("------------------清零------------------")
|
||||||
output_text = output["choices"][0]["message"]["content"]
|
if bool(self.config['流式输出']) == True:
|
||||||
|
output = self.send_request_stream(query, history_zh=history_prompt, frequency_penalty=frequency_penalty)
|
||||||
|
completion_tokens = 0
|
||||||
|
for o in output:
|
||||||
|
if o['choices'][0]['finish_reason'] == None:
|
||||||
|
yield o['choices'][0]['delta']['content']
|
||||||
|
completion_tokens += 1
|
||||||
|
else:
|
||||||
|
finish_reason = o['choices'][0]['finish_reason']
|
||||||
|
else:
|
||||||
|
output = self.send_request(query, history_zh=history_prompt, frequency_penalty=frequency_penalty)
|
||||||
|
for o in output:
|
||||||
|
completion_tokens = o["usage"]["completion_tokens"]
|
||||||
|
output_text = o["choices"][0]["message"]["content"]
|
||||||
|
yield output_text
|
||||||
|
|
||||||
|
# output = self.send_request(query, history_zh=history_prompt, frequency_penalty=frequency_penalty)
|
||||||
|
# completion_tokens = output["usage"]["completion_tokens"]
|
||||||
|
# output_text = output["choices"][0]["message"]["content"]
|
||||||
cnt += 1
|
cnt += 1
|
||||||
if cnt == 3:
|
if cnt == 3:
|
||||||
output_text = "Error:模型无法完整输出或退化无法解决,请调大设置中的max_new_token!!!原输出:" + output_text
|
output_text = "Error:模型无法完整输出或退化无法解决,请调大设置中的max_new_token!!!原输出:" + output_text
|
||||||
break
|
break
|
||||||
self.sliding_window(query, output_text)
|
self.sliding_window(query, output_text)
|
||||||
return output_text
|
|
Loading…
x
Reference in New Issue
Block a user