This commit is contained in:
恍兮惚兮 2024-07-31 22:45:41 +08:00
parent 5f47a1aaa4
commit 967041386c

View File

@ -122,34 +122,31 @@ class TS(basetrans):
"repetition_penalty": float(self.config["repetition_penalty"]), "repetition_penalty": float(self.config["repetition_penalty"]),
} }
messages = self.make_messages(query, **kwargs) messages = self.make_messages(query, **kwargs)
try:
# OpenAI
# output = self.client.chat.completions.create(
data = dict(
model="sukinishiro",
messages=messages,
temperature=float(self.config["temperature"]),
top_p=float(self.config["top_p"]),
max_tokens=1 if is_test else int(self.config["max_new_token"]),
frequency_penalty=float(
kwargs.get("frequency_penalty", self.config["frequency_penalty"])
),
seed=-1,
extra_query=extra_query,
stream=False,
)
output = self.session.post(
self.api_url + "/chat/completions", json=data
).json()
yield output
except requests.Timeout as e:
raise ValueError(f"连接到Sakura API超时{self.api_url},请尝试修改参数。")
except Exception as e: # OpenAI
print(e) # output = self.client.chat.completions.create(
raise ValueError( data = dict(
f"无法连接到Sakura API{self.api_url}请检查你的API链接是否正确填写以及API后端是否成功启动。" model="sukinishiro",
) messages=messages,
temperature=float(self.config["temperature"]),
top_p=float(self.config["top_p"]),
max_tokens=1 if is_test else int(self.config["max_new_token"]),
frequency_penalty=float(
kwargs.get("frequency_penalty", self.config["frequency_penalty"])
),
seed=-1,
extra_query=extra_query,
stream=False,
)
try:
output = self.session.post(self.api_url + "/chat/completions", json=data)
except requests.Timeout as e:
raise ValueError(f"连接到Sakura API超时{self.api_url}")
try:
yield output.json()
except:
raise Exception(output.text)
def send_request_stream(self, query, is_test=False, **kwargs): def send_request_stream(self, query, is_test=False, **kwargs):
extra_query = { extra_query = {
@ -158,43 +155,39 @@ class TS(basetrans):
"repetition_penalty": float(self.config["repetition_penalty"]), "repetition_penalty": float(self.config["repetition_penalty"]),
} }
messages = self.make_messages(query, **kwargs) messages = self.make_messages(query, **kwargs)
# OpenAI
# output = self.client.chat.completions.create(
data = dict(
model="sukinishiro",
messages=messages,
temperature=float(self.config["temperature"]),
top_p=float(self.config["top_p"]),
max_tokens=1 if is_test else int(self.config["max_new_token"]),
frequency_penalty=float(
kwargs.get("frequency_penalty", self.config["frequency_penalty"])
),
seed=-1,
extra_query=extra_query,
stream=True,
)
try: try:
# OpenAI
# output = self.client.chat.completions.create(
data = dict(
model="sukinishiro",
messages=messages,
temperature=float(self.config["temperature"]),
top_p=float(self.config["top_p"]),
max_tokens=1 if is_test else int(self.config["max_new_token"]),
frequency_penalty=float(
kwargs.get("frequency_penalty", self.config["frequency_penalty"])
),
seed=-1,
extra_query=extra_query,
stream=True,
)
output = self.session.post( output = self.session.post(
self.api_url + "/chat/completions", self.api_url + "/chat/completions",
json=data, json=data,
stream=True, stream=True,
) )
for o in output.iter_lines(delimiter="\n\n".encode()): except requests.Timeout:
raise ValueError(f"连接到Sakura API超时{self.api_url}")
for o in output.iter_lines():
try:
res = o.decode("utf-8").strip()[6:] # .replace("data: ", "") res = o.decode("utf-8").strip()[6:] # .replace("data: ", "")
print(res) print(res)
if res != "": if res != "":
yield json.loads(res) yield json.loads(res)
except requests.Timeout as e: except:
raise ValueError(f"连接到Sakura API超时{self.api_url},请尝试修改参数。") raise Exception(o)
except Exception as e:
import traceback
print(e)
e1 = traceback.format_exc()
raise ValueError(
f"Error: {str(e1)}. 无法连接到Sakura API{self.api_url}请检查你的API链接是否正确填写以及API后端是否成功启动。"
)
def translate(self, query): def translate(self, query):
query = json.loads(query) query = json.loads(query)
@ -210,7 +203,7 @@ class TS(basetrans):
output_text = "" output_text = ""
for o in output: for o in output:
if o["choices"][0]["finish_reason"] == None: if o["choices"][0]["finish_reason"] == None:
text_partial = o["choices"][0]["delta"]["content"] text_partial = o["choices"][0]["delta"].get("content", "")
output_text += text_partial output_text += text_partial
yield text_partial yield text_partial
completion_tokens += 1 completion_tokens += 1