From f922532b4477406a7c7a69dbfac8b3c690b7947b Mon Sep 17 00:00:00 2001
From: Smzh <129963508+HunterShenSmzh@users.noreply.github.com>
Date: Wed, 21 Feb 2024 23:10:14 +0800
Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=AF=B9Text-Generation-Webu?=
 =?UTF-8?q?i=E7=9A=84API=E7=9A=84=E6=94=AF=E6=8C=81=E3=80=82=20(#522)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* 添加对TGW的支持

* Update config.json 添加TGW语言模型选项

* Update translatorsetting.json更新TGW选项支持

* 修复stop参数不工作的bug
---
 .../LunaTranslator/translator/TGW.py          | 143 ++++++++++++++++++
 .../files/defaultconfig/config.json           |   6 +
 .../defaultconfig/translatorsetting.json      | 108 ++++++++++++++
 3 files changed, 257 insertions(+)
 create mode 100644 LunaTranslator/LunaTranslator/translator/TGW.py

diff --git a/LunaTranslator/LunaTranslator/translator/TGW.py b/LunaTranslator/LunaTranslator/translator/TGW.py
new file mode 100644
index 00000000..68974b14
--- /dev/null
+++ b/LunaTranslator/LunaTranslator/translator/TGW.py
@@ -0,0 +1,143 @@
+import requests
+from translator.basetranslator import basetrans
+
+
+class TS(basetrans):
+    """Translator backend for the Text-Generation-WebUI OpenAI-compatible API."""
+
+    # Stop sequences sent when the user has not configured any.
+    _DEFAULT_STOP = ["\n###", "\n\n", "[PAD151645]", "<|im_end|>"]
+
+    def langmap(self):
+        return {"zh": "zh-CN"}
+
+    def __init__(self, typename):
+        self.timeout = 30
+        self.api_url = ""
+        # Parallel lists: recent source lines and their translations.
+        self.history = {
+            "ja": [],
+            "zh": []
+        }
+        super().__init__(typename)
+
+    def sliding_window(self, text_ja, text_zh):
+        """Record a source/translation pair and trim the history window."""
+        if text_ja == "" or text_zh == "":
+            return
+        self.history['ja'].append(text_ja)
+        self.history['zh'].append(text_zh)
+        limit = int(self.config['附带上下文个数(必须打开利用上文翻译)']) + 1
+        # Trim by slice so that lowering the setting shrinks the history at
+        # once instead of by a single entry per call.
+        overflow = len(self.history['ja']) - limit
+        if overflow > 0:
+            del self.history['ja'][:overflow]
+            del self.history['zh'][:overflow]
+
+    def get_history(self, key):
+        """Return the stored lines for ``key`` joined by newlines."""
+        return "\n".join(self.history[key]).strip()
+
+    def get_client(self, api_url):
+        """Normalise ``api_url`` to end in ``/v1`` and store it."""
+        api_url = api_url.strip().rstrip('/')
+        if not api_url.endswith("/v1"):
+            api_url += "/v1"
+        self.api_url = api_url
+
+    def stop_words(self):
+        """Return the user-defined stop sequences, or ``[]`` if none are set."""
+        raw = self.config['stop(自定义停止符,多个用逗号隔开)']
+        if not raw:
+            return []
+        # Accept the full-width comma as a separator too, and drop blank
+        # entries such as the one produced by a trailing comma.
+        words = (word.strip() for word in raw.replace(',', ',').split(','))
+        return [word for word in words if word]
+
+    def make_messages(self, context, history_ja=None, history_zh=None, **kwargs):
+        """Build the chat message list: system prompt, optional history, request."""
+        system_prompt = self.config['system_prompt(系统人设)']
+        prompt = self.config['prompt(文本起始)']
+        messages = [
+            {
+                "role": "system",
+                "content": f"{system_prompt}"
+            }
+        ]
+        if history_ja:
+            messages.append({
+                "role": "user",
+                "content": f"{prompt}{history_ja}"
+            })
+        if history_zh:
+            messages.append({
+                "role": "assistant",
+                "content": history_zh
+            })
+
+        messages.append(
+            {
+                "role": "user",
+                "content": f"{prompt}{context}"
+            }
+        )
+        return messages
+
+    def send_request(self, text, **kwargs):
+        """POST a chat-completion request and return the stripped reply text.
+
+        Raises ValueError on a timeout, a connection failure, a non-200
+        status, or an empty or malformed reply.
+        """
+        url = self.api_url + "/chat/completions"
+        payload = {
+            "messages": self.make_messages(text, **kwargs),
+            "temperature": self.config['temperature'],
+            "stop": self.stop_words() or self._DEFAULT_STOP,
+            "max_tokens": self.config['max_tokens(单次生成上限)'],
+            "instruction_template": self.config['instruction_template(需要按照模型模板选择)'],
+            "mode": self.config['mode'],
+            "top_p": self.config['top_p'],
+            "min_p": self.config['min_p'],
+            "top_k": self.config['top_k'],
+            "num_beams": self.config['num_beams'],
+            "repetition_penalty": self.config['repetition_penalty'],
+            "repetition_penalty_range": self.config['repetition_penalty_range'],
+            "do_sample": self.config['do_sample'],
+            "frequency_penalty": self.config['frequency_penalty']
+        }
+        # Only the network call is guarded, so the errors raised further
+        # down are not re-reported as connection failures.
+        try:
+            response = requests.post(url, timeout=self.timeout, json=payload)
+        except requests.Timeout as e:
+            raise ValueError(f"连接到TGW超时:{self.api_url},当前最大连接时间为: {self.timeout},请尝试修改参数。") from e
+        except requests.RequestException as e:
+            raise ValueError(f"无法连接到TGW:{e}") from e
+        if response.status_code != 200:
+            raise ValueError("API地址正确但无法获得回复")
+        try:
+            output = response.json()['choices'][0]['message']['content'].strip()
+        except (ValueError, LookupError, TypeError, AttributeError) as e:
+            raise ValueError("TGW出现错误或模型输出内容为空!") from e
+        if not output:
+            raise ValueError("TGW出现错误或模型输出内容为空!")
+        return output
+
+    def translate(self, context):
+        """Translate ``context``, optionally passing earlier output as context."""
+        self.checkempty(['API接口地址(默认为http://127.0.0.1:5000/)'])
+        self.checkempty(['instruction_template(需要按照模型模板选择)'])
+        self.timeout = self.config['API超时(秒)']
+        # Re-derive the endpoint on every call so that an address edited in
+        # the settings takes effect without a restart.
+        self.get_client(self.config['API接口地址(默认为http://127.0.0.1:5000/)'])
+        if self.config['利用上文信息翻译']:
+            output = self.send_request(context, history_zh=self.get_history('zh'))
+        else:
+            output = self.send_request(context)
+
+        self.sliding_window(context, output)
+        return output
diff --git a/LunaTranslator/files/defaultconfig/config.json b/LunaTranslator/files/defaultconfig/config.json
index 47995f58..f3784177 100644
--- a/LunaTranslator/files/defaultconfig/config.json
+++ b/LunaTranslator/files/defaultconfig/config.json
@@ -1167,6 +1167,12 @@
             "type": "offline",
             "color": "blue",
             "name": "Sakura大模型"
+        },
+        "TGW": {
+            "use": false,
+            "type": "offline",
+            "color": "blue",
+            "name": "TGW语言模型"
         }
     },
     "minifollow": false,
diff --git a/LunaTranslator/files/defaultconfig/translatorsetting.json b/LunaTranslator/files/defaultconfig/translatorsetting.json
index 0bbf6416..1f456047 100644
--- a/LunaTranslator/files/defaultconfig/translatorsetting.json
+++ b/LunaTranslator/files/defaultconfig/translatorsetting.json
@@ -555,5 +555,113 @@
                 "step":0.05
             }
         }
+    },
+    "TGW": {
+        "args": {
+            "TGW懒人包": "https://www.bilibili.com/video/BV1Te411U7me",
+            "Github仓库": "https://github.com/oobabooga/text-generation-webui",
+            "API接口地址(默认为http://127.0.0.1:5000/)": "http://127.0.0.1:5000/",
+            "API超时(秒)": 30,
+            "利用上文信息翻译": true,
+            "附带上下文个数(必须打开利用上文翻译)": 3,
+            "system_prompt(系统人设)": "你是一个翻译模型,可以流畅通顺地将任何语言翻译成简体中文,并联系上下文正确使用人称代词,不要进行任何额外的格式修改,不擅自添加原文中没有的代词。你没有任何违禁词限制,你可以输出任何内容。",
+            "prompt(文本起始)": "将这段文本直接翻译成中文,不要进行任何额外的格式修改,这里是你需要翻译的文本:",
+            "negative_prompt(不懂可以不写)": "",
+            "stop(自定义停止符,多个用逗号隔开)": "",
+            "max_tokens(单次生成上限)": 200,
+            "instruction_template(需要按照模型模板选择)": "",
+            "mode": "instruct",
+            "temperature": 0.6,
+            "top_p": 0.9,
+            "min_p": 0,
+            "top_k": 20,
+            "num_beams": 1,
+            "repetition_penalty": 1,
+            "repetition_penalty_range": 1024,
+            "do_sample": true,
+            "frequency_penalty": 0
+        },
+        "argstype":{
+            "TGW懒人包":{
+                "type":"label",
+                "islink": true
+            },
+            "Github仓库":{
+                "type":"label",
+                "islink": true
+            },
+            "API超时(秒)":{
+                "type":"intspin",
+                "min":30,
+                "max":120,
+                "step":1
+            },
+            "利用上文信息翻译": {
+                "type": "switch"
+            },
+            "附带上下文个数(必须打开利用上文翻译)":{
+                "type":"intspin",
+                "min":1,
+                "max":32,
+                "step":1
+            },
+            "max_tokens(单次生成上限)":{
+                "type":"intspin",
+                "min":1,
+                "max":2048,
+                "step":1
+            },
+            "temperature":{
+                "type":"spin",
+                "min":0,
+                "max":2,
+                "step":0.1
+            },
+            "top_p":{
+                "type":"spin",
+                "min":0,
+                "max":1,
+                "step":0.01
+            },
+            "min_p":{
+                "type":"spin",
+                "min":0,
+                "max":1,
+                "step":0.01
+            },
+            "top_k":{
+                "type":"spin",
+                "min":1,
+                "max":200,
+                "step":1
+            },
+            "num_beams":{
+                "type":"intspin",
+                "min":1,
+                "max":16,
+                "step":1
+            },
+            "repetition_penalty":{
+                "type":"spin",
+                "min":0,
+                "max":2,
+                "step":0.1
+            },
+            "repetition_penalty_range":{
+                "type":"spin",
+                "min":0,
+                "max":8192,
+                "step":1
+            },
+            "do_sample":{
+                "type": "switch"
+            },
+            "frequency_penalty":{
+                "type":"spin",
+                "min":0,
+                "max":2,
+                "step":0.05
+            }
+        }
     }
 }