mirror of
https://github.com/HIllya51/LunaTranslator.git
synced 2024-12-28 08:04:13 +08:00
gpt
This commit is contained in:
parent
464c11b7ae
commit
8760a9fae6
@ -1,5 +1,5 @@
|
||||
from translator.basetranslator import basetrans
|
||||
import json
|
||||
import json, requests
|
||||
from traceback import print_exc
|
||||
|
||||
|
||||
@ -28,10 +28,25 @@ class gptcommon(basetrans):
|
||||
self.context = []
|
||||
super().__init__(typename)
|
||||
|
||||
def createurl(self):
|
||||
if self.config["API接口地址"].endswith("/chat/completions"):
|
||||
return self.config["API接口地址"]
|
||||
return self.checkv1(self.config["API接口地址"]) + "/chat/completions"
|
||||
def createdata(self, message):
|
||||
try:
|
||||
temperature = float(self.config["Temperature"])
|
||||
except:
|
||||
temperature = 0.3
|
||||
|
||||
data = dict(
|
||||
model=self.config["model"],
|
||||
messages=message,
|
||||
# optional
|
||||
max_tokens=self.config["max_tokens"],
|
||||
n=1,
|
||||
# stop=None,
|
||||
top_p=self.config["top_p"],
|
||||
temperature=temperature,
|
||||
frequency_penalty=self.config["frequency_penalty"],
|
||||
stream=self.config["流式输出"],
|
||||
)
|
||||
return data
|
||||
|
||||
def createparam(self):
|
||||
return None
|
||||
@ -54,14 +69,88 @@ class gptcommon(basetrans):
|
||||
api_url += "/v1"
|
||||
return api_url
|
||||
|
||||
def alicreateheaders(self):
|
||||
h = self.createheaders()
|
||||
if self.config["流式输出"]:
|
||||
h.update({"Accept": "text/event-stream"})
|
||||
return h
|
||||
|
||||
def alicreatedata(self, message):
|
||||
|
||||
data = dict(model=self.config["model"], input=dict(messages=message))
|
||||
if self.config["流式输出"]:
|
||||
data.update(dict(parameters=dict(incremental_output=True)))
|
||||
|
||||
return data
|
||||
|
||||
def aliparseresponse(self, query, response: requests.ResponseBase, usingstream):
|
||||
if usingstream:
|
||||
message = ""
|
||||
for chunk in response.iter_lines():
|
||||
response_data = chunk.decode("utf-8").strip()
|
||||
if not response_data:
|
||||
continue
|
||||
if response_data.startswith("data:") == False:
|
||||
continue
|
||||
try:
|
||||
json_data = json.loads(response_data[5:])
|
||||
msg = json_data["output"]["text"]
|
||||
yield msg
|
||||
message += msg
|
||||
except:
|
||||
print_exc()
|
||||
raise Exception(response_data)
|
||||
|
||||
if json_data["output"]["finish_reason"] == "stop":
|
||||
break
|
||||
|
||||
else:
|
||||
try:
|
||||
message = (
|
||||
response.json()["output"]["text"].replace("\n\n", "\n").strip()
|
||||
)
|
||||
yield message
|
||||
except:
|
||||
raise Exception(response.text)
|
||||
self.context.append({"role": "user", "content": query})
|
||||
self.context.append({"role": "assistant", "content": message})
|
||||
|
||||
def commonparseresponse(self, query, response: requests.ResponseBase, usingstream):
|
||||
if usingstream:
|
||||
message = ""
|
||||
for chunk in response.iter_lines():
|
||||
response_data = chunk.decode("utf-8").strip()
|
||||
if not response_data:
|
||||
continue
|
||||
try:
|
||||
json_data = json.loads(response_data[6:])
|
||||
rs = json_data["choices"][0].get("finish_reason")
|
||||
if rs and rs != "null":
|
||||
break
|
||||
msg = json_data["choices"][0]["delta"]["content"]
|
||||
yield msg
|
||||
message += msg
|
||||
|
||||
except:
|
||||
print_exc()
|
||||
raise Exception(response_data)
|
||||
else:
|
||||
try:
|
||||
|
||||
message = (
|
||||
response.json()["choices"][0]["message"]["content"]
|
||||
.replace("\n\n", "\n")
|
||||
.strip()
|
||||
)
|
||||
yield message
|
||||
except:
|
||||
raise Exception(response.text)
|
||||
self.context.append({"role": "user", "content": query})
|
||||
self.context.append({"role": "assistant", "content": message})
|
||||
|
||||
def translate(self, query):
|
||||
self.contextnum = int(self.config["附带上下文个数"])
|
||||
|
||||
try:
|
||||
temperature = float(self.config["Temperature"])
|
||||
except:
|
||||
temperature = 0.3
|
||||
|
||||
if self.config["使用自定义promt"]:
|
||||
message = [{"role": "user", "content": self.config["自定义promt"]}]
|
||||
else:
|
||||
@ -85,52 +174,25 @@ class gptcommon(basetrans):
|
||||
message.append({"role": "user", "content": query})
|
||||
|
||||
usingstream = self.config["流式输出"]
|
||||
data = dict(
|
||||
model=self.config["model"],
|
||||
messages=message,
|
||||
# optional
|
||||
max_tokens=self.config["max_tokens"],
|
||||
n=1,
|
||||
# stop=None,
|
||||
top_p=self.config["top_p"],
|
||||
temperature=temperature,
|
||||
frequency_penalty=self.config["frequency_penalty"],
|
||||
stream=usingstream,
|
||||
)
|
||||
response = self.proxysession.post(
|
||||
self.createurl(),
|
||||
headers=self.createheaders(),
|
||||
params=self.createparam(),
|
||||
json=data,
|
||||
stream=usingstream,
|
||||
)
|
||||
if usingstream:
|
||||
message = ""
|
||||
for chunk in response.iter_lines():
|
||||
response_data = chunk.decode("utf-8").strip()
|
||||
if not response_data:
|
||||
continue
|
||||
try:
|
||||
json_data = json.loads(response_data[6:])
|
||||
rs = json_data["choices"][0]["finish_reason"]
|
||||
if rs and rs != "null":
|
||||
break
|
||||
msg = json_data["choices"][0]["delta"]["content"]
|
||||
yield msg
|
||||
message += msg
|
||||
except:
|
||||
print_exc()
|
||||
raise Exception(response_data)
|
||||
|
||||
url = self.config["API接口地址"]
|
||||
parseresponse = self.commonparseresponse
|
||||
createheaders = self.createheaders
|
||||
createdata = self.createdata
|
||||
if url.endswith("/chat/completions"):
|
||||
pass
|
||||
elif url.endswith("/text-generation/generation") or 'dashscope.aliyuncs.com' in url:
|
||||
# https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation
|
||||
# 阿里云百炼
|
||||
parseresponse = self.aliparseresponse
|
||||
createheaders = self.alicreateheaders
|
||||
createdata = self.alicreatedata
|
||||
else:
|
||||
try:
|
||||
message = (
|
||||
response.json()["choices"][0]["message"]["content"]
|
||||
.replace("\n\n", "\n")
|
||||
.strip()
|
||||
)
|
||||
yield message
|
||||
except:
|
||||
raise Exception(response.text)
|
||||
self.context.append({"role": "user", "content": query})
|
||||
self.context.append({"role": "assistant", "content": message})
|
||||
url = self.checkv1(url) + "/chat/completions"
|
||||
response = self.proxysession.post(
|
||||
url,
|
||||
headers=createheaders(),
|
||||
params=self.createparam(),
|
||||
json=createdata(message),
|
||||
stream=usingstream,
|
||||
)
|
||||
return parseresponse(query, response, usingstream)
|
||||
|
@ -1,6 +1,6 @@
|
||||
## 国产大模型如何使用ChatGPT兼容接口
|
||||
|
||||
### 火山引擎(豆包大模型等)
|
||||
### 字节跳动豆包大模型等
|
||||
|
||||
**API接口地址** `https://ark.cn-beijing.volces.com/api/v3`
|
||||
|
||||
@ -10,3 +10,8 @@
|
||||
|
||||
![img](https://image.lunatranslator.xyz/zh/damoxing/doubao.png)
|
||||
|
||||
|
||||
### 阿里云百炼大模型
|
||||
|
||||
**API接口地址** `https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation`
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user