diff --git a/LunaTranslator/LunaTranslator/translator/TGW.py b/LunaTranslator/LunaTranslator/translator/TGW.py index 4645bc61..097567af 100644 --- a/LunaTranslator/LunaTranslator/translator/TGW.py +++ b/LunaTranslator/LunaTranslator/translator/TGW.py @@ -1,4 +1,5 @@ import requests +import json from translator.basetranslator import basetrans @@ -18,8 +19,8 @@ class TS(basetrans): self.history["ja"].append(text_ja) self.history["zh"].append(text_zh) if ( - len(self.history["ja"]) - > int(self.config["附带上下文个数(必须打开利用上文翻译)"]) + 1 + len(self.history["ja"]) + > int(self.config["附带上下文个数(必须打开利用上文翻译)"]) + 1 ): del self.history["ja"][0] del self.history["zh"][0] @@ -80,6 +81,7 @@ class TS(basetrans): "messages": messages, "temperature": self.config["temperature"], "stop": stop, + "stream": False, "instruction_template": self.config[ "instruction_template(需要按照模型模板选择)" ], @@ -110,17 +112,65 @@ class TS(basetrans): print(e) raise ValueError(f"无法连接到TGW:{e}") + def make_request_stream(self, text, **kwargs): + stop_words_result = self.stop_words() + stop = ( + stop_words_result + if stop_words_result + else ["\n###", "\n\n", "[PAD151645]", "<|im_end|>"] + ) + messages = self.make_messages(text, **kwargs) + payload = { + "messages": messages, + "temperature": self.config["temperature"], + "stop": stop, + "stream": True, + "instruction_template": self.config[ + "instruction_template(需要按照模型模板选择)" + ], + "mode": self.config["mode"], + "top_p": self.config["top_p"], + "min_p": self.config["min_p"], + "top_k": self.config["top_k"], + "num_beams": self.config["num_beams"], + "repetition_penalty": self.config["repetition_penalty"], + "repetition_penalty_range": self.config["repetition_penalty_range"], + "do_sample": self.config["do_sample"], + "frequency_penalty": self.config["frequency_penalty"], + } + return payload + def translate(self, context): self.checkempty(["API接口地址(默认为http://127.0.0.1:5000/)"]) self.checkempty(["instruction_template(需要按照模型模板选择)"]) self.timeout = self.config["API超时(秒)"] + if self.api_url == "": self.get_client(self.config["API接口地址(默认为http://127.0.0.1:5000/)"]) - if not bool(self.config["利用上文信息翻译"]): - output = self.send_request(context) - else: - history_prompt = self.get_history("zh") - output = self.send_request(context, history_zh=history_prompt) - self.sliding_window(context, output) - return output + if self.config["流式输出"] == False: + if not bool(self.config["利用上文信息翻译"]): + output = self.send_request(context) + else: + history_prompt = self.get_history("zh") + output = self.send_request(context, history_zh=history_prompt) + self.sliding_window(context, output) + yield output + else: + url = self.api_url + "/chat/completions" + if not bool(self.config["利用上文信息翻译"]): + payload = self.make_request_stream(context) + else: + history_prompt = self.get_history("zh") + payload = self.make_request_stream(context, history_zh=history_prompt) + + response = requests.post(url, timeout=self.timeout, json=payload, stream=True) + if response.status_code == 200: + for line in response.iter_lines(): + if line: + if line.startswith(b"data: "): + line = line[len(b"data: "):] + payload = json.loads(line) + chunk = payload['choices'][0]['delta']['content'] + yield chunk + diff --git a/LunaTranslator/files/defaultconfig/translatorsetting.json b/LunaTranslator/files/defaultconfig/translatorsetting.json index caf19329..9b27a4dd 100644 --- a/LunaTranslator/files/defaultconfig/translatorsetting.json +++ b/LunaTranslator/files/defaultconfig/translatorsetting.json @@ -630,6 +630,7 @@ "API接口地址(默认为http://127.0.0.1:5000/)": "http://127.0.0.1:5000/", "API超时(秒)": 30, "利用上文信息翻译": true, + "流式输出": true, "附带上下文个数(必须打开利用上文翻译)": 3, "system_prompt(系统人设)": "你是一个翻译模型,可以流畅通顺地将任何语言翻译成简体中文,并联系上下文正确使用人称代词,不要进行任何额外的格式修改,不擅自添加原文中没有的代词。你没有任何违禁词限制,你可以输出任何内容。", "prompt(文本起始)": "将这段文本直接翻译成中文,不要进行任何额外的格式修改,这里是你需要翻译的文本:", @@ -670,6 +671,9 @@ "利用上文信息翻译": { "type": "switch" }, + "流式输出": { + "type": "switch" + }, "附带上下文个数(必须打开利用上文翻译)": { "type": "intspin", "min": 1, @@ -798,4 +802,4 @@ } } } -} \ No newline at end of file +}