From c4e55060f44a4787f9c0e27d2865e18543193138 Mon Sep 17 00:00:00 2001
From: SakuraUmi
Date: Sat, 13 Jan 2024 19:44:54 +0800
Subject: [PATCH] Sakura: Compatible with new SakuraAPI. (#492)

* Feat(Sakura): Compatible with new SakuraAPI by using openai package; Update info.

* Fix bugs and change degen detecting method.

* Fix bug.

* Sakura: add context num option, fix bugs.

* Sakura: Fix bugs.

* restore

* Sakura: restore using requests to interact with API.
---
 .../LunaTranslator/translator/sakura.py       | 248 +++++++++---------
 .../defaultconfig/translatorsetting.json      |  11 +-
 2 files changed, 126 insertions(+), 133 deletions(-)

diff --git a/LunaTranslator/LunaTranslator/translator/sakura.py b/LunaTranslator/LunaTranslator/translator/sakura.py
index 61ca8d4f..da955d73 100644
--- a/LunaTranslator/LunaTranslator/translator/sakura.py
+++ b/LunaTranslator/LunaTranslator/translator/sakura.py
@@ -1,165 +1,151 @@
 from traceback import print_exc
 from translator.basetranslator import basetrans
 import requests
+# OpenAI
+# from openai import OpenAI
 class TS(basetrans):
     def langmap(self):
         return {"zh": "zh-CN"}
     def __init__(self, typename) :
-        self.api_type = ""
         self.api_url = ""
-        self.model_type = ""
-        self.history = []
+        self.history = {
+            "ja": [],
+            "zh": []
+        }
         self.session = requests.Session()
         super( ).__init__(typename)
-    def sliding_window(self, query):
-        self.history.append(query)
-        if len(self.history) > 4:
-            del self.history[0]
-        return self.history
-    def list_to_prompt(self, query_list):
+    def sliding_window(self, text_ja, text_zh):
+        if text_ja == "" or text_zh == "":
+            return
+        self.history['ja'].append(text_ja)
+        self.history['zh'].append(text_zh)
+        if len(self.history['ja']) > int(self.config['附带上下文个数(必须打开利用上文翻译)']) + 1:
+            del self.history['ja'][0]
+            del self.history['zh'][0]
+    def get_history(self, key):
         prompt = ""
-        for q in query_list:
+        for q in self.history[key]:
             prompt += q + "\n"
         prompt = prompt.strip()
         return prompt
-    def make_prompt(self, context):
-        if self.model_type == "baichuan":
-            prompt = f"将下面的日文文本翻译成中文:{context}"
-        elif self.model_type == "qwen":
-            prompt = f"<|im_start|>system\n你是一个轻小说翻译模型,可以流畅通顺地以日本轻小说的风格将日文翻译成简体中文,并联系上下文正确使用人称代词,不擅自添加原文中没有的代词。<|im_end|>\n<|im_start|>user\n将下面的日文文本翻译成中文:{context}<|im_end|>\n<|im_start|>assistant\n"
+    def get_client(self, api_url):
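+        # Normalize the user-supplied endpoint so that self.api_url always
+        # ends in "/v1" with no trailing slash, whichever of the accepted
+        # forms (".../v1/", ".../v1", "host/", bare host) was configured.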
"top_p": float(self.config['top_p']), - "repetition_penalty": float(self.config['repetition_penalty']), - "num_beams": int(self.config['num_beams']), - "frequency_penalty": float(self.config['frequency_penalty']), - "top_k": 40, - "seed": -1 + ] + if history_ja: + messages.append({ + "role": "user", + "content": f"将下面的日文文本翻译成中文:{history_ja}" + }) + if history_zh: + messages.append({ + "role": "assistant", + "content": history_zh + }) + + messages.append( + { + "role": "user", + "content": f"将下面的日文文本翻译成中文:{query}" } - return request - elif self.api_type == "openai_like": - raise NotImplementedError(f"1: {self.api_type}") - else: - raise NotImplementedError(f"2: {self.api_type}") - def parse_output(self, output: str, length): - output = output.strip() - output_list = output.split("\n") - if len(output_list) != length: - # fallback to no history translation - return None - else: - return output_list[-1] - def do_post(self, request): + ) + return messages + + + def send_request(self, query, is_test=False, **kwargs): + extra_query = { + 'do_sample': bool(self.config['do_sample']), + 'num_beams': int(self.config['num_beams']), + 'repetition_penalty': float(self.config['repetition_penalty']), + } + messages = self.make_messages(query, **kwargs) try: - response = self.session.post(self.api_url, json=request).json() - if self.api_type == "dev_server": - output = response['results'][0]['text'] - new_token = response['results'][0]['new_token'] - elif self.api_type == "llama.cpp": - output = response['content'] - new_token = response['tokens_predicted'] - else: - raise NotImplementedError("3") + # OpenAI + # output = self.client.chat.completions.create( + data = dict( + model="sukinishiro", + messages=messages, + temperature=float(self.config['temperature']), + top_p=float(self.config['top_p']), + max_tokens= 1 if is_test else int(self.config['max_new_token']), + frequency_penalty=float(kwargs['frequency_penalty']) if "frequency_penalty" in kwargs.keys() else float(self.config['frequency_penalty']), + seed=-1, + extra_query=extra_query, + stream=False, + ) + output = self.session.post(self.api_url + "/chat/completions", json=data).json() except Exception as e: - raise Exception(str(e) + f"\napi_type: '{self.api_type}', api_url: '{self.api_url}', model_type: '{self.model_type}'\n与API接口通信失败,请检查设置的API服务器监听地址是否正确,或检查API服务器是否正常开启。") - return output, new_token, response - def set_model_type(self): - #TODO: get model type from api - self.model_type = "NotImplemented" - request = self.make_request("test", is_test=True) - _, _, response = self.do_post(request) - if self.api_type == "llama.cpp": - model_name: str = response['model'] - model_version = model_name.split("-")[-2] - if "0.8" in model_version: - self.model_type = "baichuan" - elif "0.9" in model_version: - self.model_type = "qwen" - return - def set_api_type(self): - endpoint = self.config['API接口地址'] - if endpoint[-1] != "/": - endpoint += "/" - api_url = endpoint + "api/v1/generate" - test_json = self.make_request("test", is_test=True) - try: - response = self.session.post(api_url, json=test_json) - except Exception as e: - raise Exception(str(e) + f"\napi_type: '{self.api_type}', api_url: '{self.api_url}', model_type: '{self.model_type}'\n与API接口通信失败,请检查设置的API服务器监听地址是否正确,或检查API服务器是否正常开启。") - try: - response = response.json() - output = response['results'][0]['text'] - new_token = response['results'][0]['new_token'] - self.api_type = "dev_server" - self.api_url = api_url - except: - self.api_type = "llama.cpp" - self.api_url = endpoint + "completion" 
+            data = dict(
+                model="sukinishiro",
+                messages=messages,
+                temperature=float(self.config['temperature']),
+                top_p=float(self.config['top_p']),
+                max_tokens= 1 if is_test else int(self.config['max_new_token']),
+                frequency_penalty=float(kwargs['frequency_penalty']) if "frequency_penalty" in kwargs.keys() else float(self.config['frequency_penalty']),
+                seed=-1,
+                extra_query=extra_query,
+                stream=False,
+            )
+            output = self.session.post(self.api_url + "/chat/completions", json=data).json()
         except Exception as e:
-            raise Exception(str(e) + f"\napi_type: '{self.api_type}', api_url: '{self.api_url}', model_type: '{self.model_type}'\n与API接口通信失败,请检查设置的API服务器监听地址是否正确,或检查API服务器是否正常开启。")
-        return output, new_token, response
-    def set_model_type(self):
-        #TODO: get model type from api
-        self.model_type = "NotImplemented"
-        request = self.make_request("test", is_test=True)
-        _, _, response = self.do_post(request)
-        if self.api_type == "llama.cpp":
-            model_name: str = response['model']
-            model_version = model_name.split("-")[-2]
-            if "0.8" in model_version:
-                self.model_type = "baichuan"
-            elif "0.9" in model_version:
-                self.model_type = "qwen"
-        return
-    def set_api_type(self):
-        endpoint = self.config['API接口地址']
-        if endpoint[-1] != "/":
-            endpoint += "/"
-        api_url = endpoint + "api/v1/generate"
-        test_json = self.make_request("test", is_test=True)
-        try:
-            response = self.session.post(api_url, json=test_json)
-        except Exception as e:
-            raise Exception(str(e) + f"\napi_type: '{self.api_type}', api_url: '{self.api_url}', model_type: '{self.model_type}'\n与API接口通信失败,请检查设置的API服务器监听地址是否正确,或检查API服务器是否正常开启。")
-        try:
-            response = response.json()
-            output = response['results'][0]['text']
-            new_token = response['results'][0]['new_token']
-            self.api_type = "dev_server"
-            self.api_url = api_url
-        except:
-            self.api_type = "llama.cpp"
-            self.api_url = endpoint + "completion"
-
-        return
+            raise ValueError(f"无法连接到Sakura API:{self.api_url},请检查你的API链接是否正确填写,以及API后端是否成功启动。")
+        return output
+
     def translate(self, query):
         self.checkempty(['API接口地址'])
-        if self.api_type == "":
-            self.set_api_type()
-            self.set_model_type()
+        if self.api_url == "":
+            self.get_client(self.config['API接口地址'])
+        frequency_penalty = float(self.config['frequency_penalty'])
+        if not bool(self.config['利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)']):
+            output = self.send_request(query)
+            completion_tokens = output["usage"]["completion_tokens"]
+            output_text = output["choices"][0]["message"]["content"]
-        prompt = self.make_prompt(query)
-        request = self.make_request(prompt)
-
-        if not self.config['利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)']:
-            output, new_token, _ = self.do_post(request)
-
             if bool(self.config['fix_degeneration']):
                 cnt = 0
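+                # completion_tokens hitting max_new_token means the reply used
+                # the whole token budget, i.e. it was likely truncated or
+                # degenerated into repetition; bump frequency_penalty and retry.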
"https://github.com/SakuraLLM/Sakura-13B-Galgame/wiki", "Github仓库": "https://github.com/SakuraLLM/Sakura-13B-Galgame", - "API接口地址": "http://127.0.0.1:5000/", + "API接口地址": "http://127.0.0.1:8080/", "利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)": false, + "附带上下文个数(必须打开利用上文翻译)": 3, "temperature": 0.1, "top_p": 0.3, "num_beams": 1, @@ -492,6 +493,12 @@ "利用上文信息翻译(通常会有一定的效果提升,但会导致变慢)": { "type": "switch" }, + "附带上下文个数(必须打开利用上文翻译)":{ + "type":"intspin", + "min":1, + "max":32, + "step":1 + }, "temperature":{ "type":"spin", "min":0,