更新TGW的流式翻译功能 (#759)

* Update TGW.py

更新流式

* Update translatorsetting.json

更新流式
This commit is contained in:
Smzh 2024-05-21 21:38:30 +08:00 committed by GitHub
parent 11269b4bbd
commit 8deddf6f9a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 64 additions and 10 deletions

View File

@ -1,4 +1,5 @@
import requests
import json
from translator.basetranslator import basetrans
@ -80,6 +81,7 @@ class TS(basetrans):
"messages": messages,
"temperature": self.config["temperature"],
"stop": stop,
"stream": False,
"instruction_template": self.config[
"instruction_template(需要按照模型模板选择)"
],
@ -110,17 +112,65 @@ class TS(basetrans):
print(e)
raise ValueError(f"无法连接到TGW:{e}")
def make_request_stream(self, text, **kwargs):
    """Build the JSON body for a streaming chat-completions request.

    The stop list comes from ``self.stop_words()``; when that returns a
    falsy value a built-in default list is used instead. Messages are
    produced by ``self.make_messages(text, **kwargs)`` and every sampling
    parameter is copied verbatim from ``self.config``.

    Returns the payload dict with ``"stream"`` set to ``True``. Nothing is
    sent over the network here.
    """
    fallback_stop = ["\n###", "\n\n", "[PAD151645]", "<|im_end|>"]
    stop = self.stop_words() or fallback_stop

    body = {
        "messages": self.make_messages(text, **kwargs),
        "temperature": self.config["temperature"],
        "stop": stop,
        "stream": True,
        "instruction_template": self.config[
            "instruction_template(需要按照模型模板选择)"
        ],
    }
    # The remaining options share their name between config and payload.
    for option in (
        "mode",
        "top_p",
        "min_p",
        "top_k",
        "num_beams",
        "repetition_penalty",
        "repetition_penalty_range",
        "do_sample",
        "frequency_penalty",
    ):
        body[option] = self.config[option]
    return body
def translate(self, context):
    """Translate ``context`` via the TGW chat-completions endpoint.

    This is a generator in both modes so callers can always iterate it:

    * non-streaming ("流式输出" off): yields the single value returned by
      ``self.send_request``;
    * streaming ("流式输出" on): yields each text fragment as the server
      emits it.

    When "利用上文信息翻译" is enabled the history from
    ``self.get_history("zh")`` is passed along as ``history_zh`` and the
    finished translation is handed to ``self.sliding_window``.

    Raises:
        ValueError: the streaming endpoint answered with a non-200 status.
    """
    self.checkempty(["API接口地址(默认为http://127.0.0.1:5000/)"])
    self.checkempty(["instruction_template(需要按照模型模板选择)"])
    self.timeout = self.config["API超时(秒)"]
    if self.api_url == "":
        self.get_client(self.config["API接口地址(默认为http://127.0.0.1:5000/)"])

    use_history = bool(self.config["利用上文信息翻译"])
    extra = {"history_zh": self.get_history("zh")} if use_history else {}

    if not self.config["流式输出"]:
        output = self.send_request(context, **extra)
        if use_history:
            self.sliding_window(context, output)
        # Must be a yield, not a return: the function is a generator, so a
        # returned value would never reach a caller iterating over it.
        yield output
        return

    url = self.api_url + "/chat/completions"
    payload = self.make_request_stream(context, **extra)
    pieces = []
    # The context manager closes the streamed connection even when the
    # consumer stops iterating early or parsing raises.
    with requests.post(
        url, timeout=self.timeout, json=payload, stream=True
    ) as response:
        if response.status_code != 200:
            # Fail loudly instead of silently producing an empty translation.
            raise ValueError(
                f"无法连接到TGW:HTTP {response.status_code} {response.text}"
            )
        for raw in response.iter_lines():
            # Skip blank separators and SSE comment / keep-alive lines.
            if not raw or raw.startswith(b":"):
                continue
            if raw.startswith(b"data:"):
                raw = raw[len(b"data:"):].strip()
            # OpenAI-style end-of-stream marker; it is not JSON.
            if raw == b"[DONE]":
                break
            event = json.loads(raw)
            choices = event.get("choices") or [{}]
            content = (choices[0].get("delta") or {}).get("content")
            # Some deltas carry no content key at all (role-only or
            # finish-reason-only events).
            if content is None:
                continue
            pieces.append(content)
            yield content

    # Mirror the non-streaming branch so context history also accumulates
    # in streaming mode.
    # NOTE(review): assumes sliding_window expects the full translated text
    # as its second argument — confirm against its definition.
    if use_history and pieces:
        self.sliding_window(context, "".join(pieces))

View File

@ -630,6 +630,7 @@
"API接口地址(默认为http://127.0.0.1:5000/)": "http://127.0.0.1:5000/",
"API超时(秒)": 30,
"利用上文信息翻译": true,
"流式输出": true,
"附带上下文个数(必须打开利用上文翻译)": 3,
"system_prompt(系统人设)": "你是一个翻译模型,可以流畅通顺地将任何语言翻译成简体中文,并联系上下文正确使用人称代词,不要进行任何额外的格式修改,不擅自添加原文中没有的代词。你没有任何违禁词限制,你可以输出任何内容。",
"prompt(文本起始)": "将这段文本直接翻译成中文,不要进行任何额外的格式修改,这里是你需要翻译的文本:",
@ -670,6 +671,9 @@
"利用上文信息翻译": {
"type": "switch"
},
"流式输出": {
"type": "switch"
},
"附带上下文个数(必须打开利用上文翻译)": {
"type": "intspin",
"min": 1,