diff --git a/LunaTranslator/LunaTranslator/translator/sakura.py b/LunaTranslator/LunaTranslator/translator/sakura.py index 0b4672b9..e195d7b5 100644 --- a/LunaTranslator/LunaTranslator/translator/sakura.py +++ b/LunaTranslator/LunaTranslator/translator/sakura.py @@ -122,34 +122,31 @@ class TS(basetrans): "repetition_penalty": float(self.config["repetition_penalty"]), } messages = self.make_messages(query, **kwargs) - try: - # OpenAI - # output = self.client.chat.completions.create( - data = dict( - model="sukinishiro", - messages=messages, - temperature=float(self.config["temperature"]), - top_p=float(self.config["top_p"]), - max_tokens=1 if is_test else int(self.config["max_new_token"]), - frequency_penalty=float( - kwargs.get("frequency_penalty", self.config["frequency_penalty"]) - ), - seed=-1, - extra_query=extra_query, - stream=False, - ) - output = self.session.post( - self.api_url + "/chat/completions", json=data - ).json() - yield output - except requests.Timeout as e: - raise ValueError(f"连接到Sakura API超时:{self.api_url},请尝试修改参数。") - except Exception as e: - print(e) - raise ValueError( - f"无法连接到Sakura API:{self.api_url},请检查你的API链接是否正确填写,以及API后端是否成功启动。" - ) + # OpenAI + # output = self.client.chat.completions.create( + data = dict( + model="sukinishiro", + messages=messages, + temperature=float(self.config["temperature"]), + top_p=float(self.config["top_p"]), + max_tokens=1 if is_test else int(self.config["max_new_token"]), + frequency_penalty=float( + kwargs.get("frequency_penalty", self.config["frequency_penalty"]) + ), + seed=-1, + extra_query=extra_query, + stream=False, + ) + try: + output = self.session.post(self.api_url + "/chat/completions", json=data) + + except requests.Timeout as e: + raise ValueError(f"连接到Sakura API超时:{self.api_url}") + try: + yield output.json() + except: + raise Exception(output.text) def send_request_stream(self, query, is_test=False, **kwargs): extra_query = { @@ -158,43 +155,39 @@ class TS(basetrans): "repetition_penalty": float(self.config["repetition_penalty"]), } messages = self.make_messages(query, **kwargs) + + # OpenAI + # output = self.client.chat.completions.create( + data = dict( + model="sukinishiro", + messages=messages, + temperature=float(self.config["temperature"]), + top_p=float(self.config["top_p"]), + max_tokens=1 if is_test else int(self.config["max_new_token"]), + frequency_penalty=float( + kwargs.get("frequency_penalty", self.config["frequency_penalty"]) + ), + seed=-1, + extra_query=extra_query, + stream=True, + ) try: - # OpenAI - # output = self.client.chat.completions.create( - data = dict( - model="sukinishiro", - messages=messages, - temperature=float(self.config["temperature"]), - top_p=float(self.config["top_p"]), - max_tokens=1 if is_test else int(self.config["max_new_token"]), - frequency_penalty=float( - kwargs.get("frequency_penalty", self.config["frequency_penalty"]) - ), - seed=-1, - extra_query=extra_query, - stream=True, - ) output = self.session.post( self.api_url + "/chat/completions", json=data, stream=True, ) - for o in output.iter_lines(delimiter="\n\n".encode()): + except requests.Timeout: + raise ValueError(f"连接到Sakura API超时:{self.api_url}") + + for o in output.iter_lines(): + try: res = o.decode("utf-8").strip()[6:] # .replace("data: ", "") print(res) if res != "": yield json.loads(res) - except requests.Timeout as e: - raise ValueError(f"连接到Sakura API超时:{self.api_url},请尝试修改参数。") - - except Exception as e: - import traceback - - print(e) - e1 = traceback.format_exc() - raise ValueError( - f"Error: {str(e1)}. 无法连接到Sakura API:{self.api_url},请检查你的API链接是否正确填写,以及API后端是否成功启动。" - ) + except: + raise Exception(o) def translate(self, query): query = json.loads(query) @@ -210,7 +203,7 @@ class TS(basetrans): output_text = "" for o in output: if o["choices"][0]["finish_reason"] == None: - text_partial = o["choices"][0]["delta"]["content"] + text_partial = o["choices"][0]["delta"].get("content", "") output_text += text_partial yield text_partial completion_tokens += 1