This commit is contained in:
恍兮惚兮 2024-07-19 17:13:24 +08:00
parent 464c11b7ae
commit 8760a9fae6
2 changed files with 126 additions and 59 deletions

View File

@@ -1,5 +1,5 @@
from translator.basetranslator import basetrans
import json
import json, requests
from traceback import print_exc
@@ -28,10 +28,25 @@ class gptcommon(basetrans):
self.context = []
super().__init__(typename)
def createurl(self):
    """Build the chat-completions endpoint from the configured base URL.

    If the configured address already ends with "/chat/completions" it is
    used verbatim; otherwise it is normalized via checkv1() and the
    standard path suffix is appended.
    """
    endpoint = self.config["API接口地址"]
    if not endpoint.endswith("/chat/completions"):
        endpoint = self.checkv1(endpoint) + "/chat/completions"
    return endpoint
def createdata(self, message):
    """Build the JSON body for an OpenAI-compatible /chat/completions call.

    message: the full messages list (system/context/user entries).
    Returns a dict ready to be passed as the request's ``json=`` payload.
    """
    try:
        temperature = float(self.config["Temperature"])
    # Narrowed from a bare ``except:`` — only a missing key or an
    # un-parseable value should fall back to the default, not e.g.
    # KeyboardInterrupt.
    except (KeyError, TypeError, ValueError):
        temperature = 0.3
    data = dict(
        model=self.config["model"],
        messages=message,
        # optional
        max_tokens=self.config["max_tokens"],
        n=1,
        # stop=None,
        top_p=self.config["top_p"],
        temperature=temperature,
        frequency_penalty=self.config["frequency_penalty"],
        stream=self.config["流式输出"],
    )
    return data
def createparam(self):
return None
@@ -54,14 +69,88 @@ class gptcommon(basetrans):
api_url += "/v1"
return api_url
def alicreateheaders(self):
    """Headers for Aliyun DashScope: base headers plus SSE Accept when streaming."""
    headers = self.createheaders()
    if self.config["流式输出"]:
        # Streamed DashScope responses are delivered as server-sent events.
        headers["Accept"] = "text/event-stream"
    return headers
def alicreatedata(self, message):
    """Build the request body for Aliyun DashScope text-generation.

    DashScope nests the messages list under "input"; when streaming,
    incremental_output makes each event carry only the new text.
    """
    payload = {"model": self.config["model"], "input": {"messages": message}}
    if self.config["流式输出"]:
        payload["parameters"] = {"incremental_output": True}
    return payload
def aliparseresponse(self, query, response: requests.ResponseBase, usingstream):
    """Parse an Aliyun DashScope response, yielding translated text chunks.

    query: the source text just sent; recorded into self.context on success.
    response: project HTTP response wrapper (iter_lines() / .json() / .text).
    usingstream: True when the request was made with SSE streaming.
    Raises Exception carrying the raw payload when parsing fails.
    """
    if usingstream:
        message = ""
        for chunk in response.iter_lines():
            response_data = chunk.decode("utf-8").strip()
            if not response_data:
                continue
            # Only "data:" frames carry payload; skip other SSE fields
            # (idiomatic ``not x.startswith`` instead of ``== False``).
            if not response_data.startswith("data:"):
                continue
            try:
                json_data = json.loads(response_data[5:])
                msg = json_data["output"]["text"]
                yield msg
                message += msg
            # ``except Exception`` (not bare ``except:``): a bare except in a
            # generator would also swallow GeneratorExit when the consumer
            # closes this generator early.
            except Exception:
                print_exc()
                raise Exception(response_data)
            # .get() avoids a KeyError on chunks that omit finish_reason.
            if json_data["output"].get("finish_reason") == "stop":
                break
    else:
        try:
            message = (
                response.json()["output"]["text"].replace("\n\n", "\n").strip()
            )
            yield message
        except Exception:
            raise Exception(response.text)
    self.context.append({"role": "user", "content": query})
    self.context.append({"role": "assistant", "content": message})
def commonparseresponse(self, query, response: requests.ResponseBase, usingstream):
    """Parse an OpenAI-compatible /chat/completions response, yielding text.

    query: the source text just sent; recorded into self.context on success.
    response: project HTTP response wrapper (iter_lines() / .json() / .text).
    usingstream: True when the request was made with SSE streaming.
    Raises Exception carrying the raw payload when parsing fails.
    """
    if usingstream:
        message = ""
        for chunk in response.iter_lines():
            response_data = chunk.decode("utf-8").strip()
            if not response_data:
                continue
            # SSE frames arrive as "data: {...}"; strip the prefix instead of
            # blindly slicing [6:], which broke on "data:{" (no space).
            if response_data.startswith("data:"):
                response_data = response_data[5:].strip()
            # OpenAI terminates the stream with a literal "[DONE]" sentinel,
            # which is not JSON and previously raised.
            if response_data == "[DONE]":
                break
            try:
                json_data = json.loads(response_data)
                choice = json_data["choices"][0]
                rs = choice.get("finish_reason")
                if rs and rs != "null":
                    break
                # The first delta of a stream may carry only {"role": ...};
                # default missing content to "" instead of raising KeyError.
                msg = choice["delta"].get("content") or ""
                if msg:
                    yield msg
                    message += msg
            # ``except Exception`` (not bare ``except:``): a bare except in a
            # generator would also swallow GeneratorExit when the consumer
            # closes this generator early.
            except Exception:
                print_exc()
                raise Exception(response_data)
    else:
        try:
            message = (
                response.json()["choices"][0]["message"]["content"]
                .replace("\n\n", "\n")
                .strip()
            )
            yield message
        except Exception:
            raise Exception(response.text)
    self.context.append({"role": "user", "content": query})
    self.context.append({"role": "assistant", "content": message})
def translate(self, query):
self.contextnum = int(self.config["附带上下文个数"])
try:
temperature = float(self.config["Temperature"])
except:
temperature = 0.3
if self.config["使用自定义promt"]:
message = [{"role": "user", "content": self.config["自定义promt"]}]
else:
@@ -85,52 +174,25 @@ class gptcommon(basetrans):
message.append({"role": "user", "content": query})
usingstream = self.config["流式输出"]
data = dict(
model=self.config["model"],
messages=message,
# optional
max_tokens=self.config["max_tokens"],
n=1,
# stop=None,
top_p=self.config["top_p"],
temperature=temperature,
frequency_penalty=self.config["frequency_penalty"],
stream=usingstream,
)
response = self.proxysession.post(
self.createurl(),
headers=self.createheaders(),
params=self.createparam(),
json=data,
stream=usingstream,
)
if usingstream:
message = ""
for chunk in response.iter_lines():
response_data = chunk.decode("utf-8").strip()
if not response_data:
continue
try:
json_data = json.loads(response_data[6:])
rs = json_data["choices"][0]["finish_reason"]
if rs and rs != "null":
break
msg = json_data["choices"][0]["delta"]["content"]
yield msg
message += msg
except:
print_exc()
raise Exception(response_data)
url = self.config["API接口地址"]
parseresponse = self.commonparseresponse
createheaders = self.createheaders
createdata = self.createdata
if url.endswith("/chat/completions"):
pass
elif url.endswith("/text-generation/generation") or 'dashscope.aliyuncs.com' in url:
# https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
# 阿里云百炼
parseresponse = self.aliparseresponse
createheaders = self.alicreateheaders
createdata = self.alicreatedata
else:
try:
message = (
response.json()["choices"][0]["message"]["content"]
.replace("\n\n", "\n")
.strip()
url = self.checkv1(url) + "/chat/completions"
response = self.proxysession.post(
url,
headers=createheaders(),
params=self.createparam(),
json=createdata(message),
stream=usingstream,
)
yield message
except:
raise Exception(response.text)
self.context.append({"role": "user", "content": query})
self.context.append({"role": "assistant", "content": message})
return parseresponse(query, response, usingstream)

View File

@@ -1,6 +1,6 @@
## 国产大模型如何使用ChatGPT兼容接口
### 火山引擎(豆包大模型等)
### 字节跳动豆包大模型等
**API接口地址** `https://ark.cn-beijing.volces.com/api/v3`
@@ -10,3 +10,8 @@
![img](https://image.lunatranslator.xyz/zh/damoxing/doubao.png)
### 阿里云百炼大模型
**API接口地址** `https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation`