mirror of
https://github.com/HIllya51/LunaTranslator.git
synced 2025-01-15 08:53:53 +08:00
fix
This commit is contained in:
parent
5f47a1aaa4
commit
967041386c
@ -122,34 +122,31 @@ class TS(basetrans):
|
|||||||
"repetition_penalty": float(self.config["repetition_penalty"]),
|
"repetition_penalty": float(self.config["repetition_penalty"]),
|
||||||
}
|
}
|
||||||
messages = self.make_messages(query, **kwargs)
|
messages = self.make_messages(query, **kwargs)
|
||||||
try:
|
|
||||||
# OpenAI
|
|
||||||
# output = self.client.chat.completions.create(
|
|
||||||
data = dict(
|
|
||||||
model="sukinishiro",
|
|
||||||
messages=messages,
|
|
||||||
temperature=float(self.config["temperature"]),
|
|
||||||
top_p=float(self.config["top_p"]),
|
|
||||||
max_tokens=1 if is_test else int(self.config["max_new_token"]),
|
|
||||||
frequency_penalty=float(
|
|
||||||
kwargs.get("frequency_penalty", self.config["frequency_penalty"])
|
|
||||||
),
|
|
||||||
seed=-1,
|
|
||||||
extra_query=extra_query,
|
|
||||||
stream=False,
|
|
||||||
)
|
|
||||||
output = self.session.post(
|
|
||||||
self.api_url + "/chat/completions", json=data
|
|
||||||
).json()
|
|
||||||
yield output
|
|
||||||
except requests.Timeout as e:
|
|
||||||
raise ValueError(f"连接到Sakura API超时:{self.api_url},请尝试修改参数。")
|
|
||||||
|
|
||||||
except Exception as e:
|
# OpenAI
|
||||||
print(e)
|
# output = self.client.chat.completions.create(
|
||||||
raise ValueError(
|
data = dict(
|
||||||
f"无法连接到Sakura API:{self.api_url},请检查你的API链接是否正确填写,以及API后端是否成功启动。"
|
model="sukinishiro",
|
||||||
)
|
messages=messages,
|
||||||
|
temperature=float(self.config["temperature"]),
|
||||||
|
top_p=float(self.config["top_p"]),
|
||||||
|
max_tokens=1 if is_test else int(self.config["max_new_token"]),
|
||||||
|
frequency_penalty=float(
|
||||||
|
kwargs.get("frequency_penalty", self.config["frequency_penalty"])
|
||||||
|
),
|
||||||
|
seed=-1,
|
||||||
|
extra_query=extra_query,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
output = self.session.post(self.api_url + "/chat/completions", json=data)
|
||||||
|
|
||||||
|
except requests.Timeout as e:
|
||||||
|
raise ValueError(f"连接到Sakura API超时:{self.api_url}")
|
||||||
|
try:
|
||||||
|
yield output.json()
|
||||||
|
except:
|
||||||
|
raise Exception(output.text)
|
||||||
|
|
||||||
def send_request_stream(self, query, is_test=False, **kwargs):
|
def send_request_stream(self, query, is_test=False, **kwargs):
|
||||||
extra_query = {
|
extra_query = {
|
||||||
@ -158,43 +155,39 @@ class TS(basetrans):
|
|||||||
"repetition_penalty": float(self.config["repetition_penalty"]),
|
"repetition_penalty": float(self.config["repetition_penalty"]),
|
||||||
}
|
}
|
||||||
messages = self.make_messages(query, **kwargs)
|
messages = self.make_messages(query, **kwargs)
|
||||||
|
|
||||||
|
# OpenAI
|
||||||
|
# output = self.client.chat.completions.create(
|
||||||
|
data = dict(
|
||||||
|
model="sukinishiro",
|
||||||
|
messages=messages,
|
||||||
|
temperature=float(self.config["temperature"]),
|
||||||
|
top_p=float(self.config["top_p"]),
|
||||||
|
max_tokens=1 if is_test else int(self.config["max_new_token"]),
|
||||||
|
frequency_penalty=float(
|
||||||
|
kwargs.get("frequency_penalty", self.config["frequency_penalty"])
|
||||||
|
),
|
||||||
|
seed=-1,
|
||||||
|
extra_query=extra_query,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
try:
|
try:
|
||||||
# OpenAI
|
|
||||||
# output = self.client.chat.completions.create(
|
|
||||||
data = dict(
|
|
||||||
model="sukinishiro",
|
|
||||||
messages=messages,
|
|
||||||
temperature=float(self.config["temperature"]),
|
|
||||||
top_p=float(self.config["top_p"]),
|
|
||||||
max_tokens=1 if is_test else int(self.config["max_new_token"]),
|
|
||||||
frequency_penalty=float(
|
|
||||||
kwargs.get("frequency_penalty", self.config["frequency_penalty"])
|
|
||||||
),
|
|
||||||
seed=-1,
|
|
||||||
extra_query=extra_query,
|
|
||||||
stream=True,
|
|
||||||
)
|
|
||||||
output = self.session.post(
|
output = self.session.post(
|
||||||
self.api_url + "/chat/completions",
|
self.api_url + "/chat/completions",
|
||||||
json=data,
|
json=data,
|
||||||
stream=True,
|
stream=True,
|
||||||
)
|
)
|
||||||
for o in output.iter_lines(delimiter="\n\n".encode()):
|
except requests.Timeout:
|
||||||
|
raise ValueError(f"连接到Sakura API超时:{self.api_url}")
|
||||||
|
|
||||||
|
for o in output.iter_lines():
|
||||||
|
try:
|
||||||
res = o.decode("utf-8").strip()[6:] # .replace("data: ", "")
|
res = o.decode("utf-8").strip()[6:] # .replace("data: ", "")
|
||||||
print(res)
|
print(res)
|
||||||
if res != "":
|
if res != "":
|
||||||
yield json.loads(res)
|
yield json.loads(res)
|
||||||
except requests.Timeout as e:
|
except:
|
||||||
raise ValueError(f"连接到Sakura API超时:{self.api_url},请尝试修改参数。")
|
raise Exception(o)
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
import traceback
|
|
||||||
|
|
||||||
print(e)
|
|
||||||
e1 = traceback.format_exc()
|
|
||||||
raise ValueError(
|
|
||||||
f"Error: {str(e1)}. 无法连接到Sakura API:{self.api_url},请检查你的API链接是否正确填写,以及API后端是否成功启动。"
|
|
||||||
)
|
|
||||||
|
|
||||||
def translate(self, query):
|
def translate(self, query):
|
||||||
query = json.loads(query)
|
query = json.loads(query)
|
||||||
@ -210,7 +203,7 @@ class TS(basetrans):
|
|||||||
output_text = ""
|
output_text = ""
|
||||||
for o in output:
|
for o in output:
|
||||||
if o["choices"][0]["finish_reason"] == None:
|
if o["choices"][0]["finish_reason"] == None:
|
||||||
text_partial = o["choices"][0]["delta"]["content"]
|
text_partial = o["choices"][0]["delta"].get("content", "")
|
||||||
output_text += text_partial
|
output_text += text_partial
|
||||||
yield text_partial
|
yield text_partial
|
||||||
completion_tokens += 1
|
completion_tokens += 1
|
||||||
|
Loading…
x
Reference in New Issue
Block a user