diff --git a/LunaTranslator/LunaTranslator/translator/gptcommon.py b/LunaTranslator/LunaTranslator/translator/gptcommon.py index 226a8ed8..2b808fea 100644 --- a/LunaTranslator/LunaTranslator/translator/gptcommon.py +++ b/LunaTranslator/LunaTranslator/translator/gptcommon.py @@ -1,5 +1,5 @@ from translator.basetranslator import basetrans -import json +import json, requests from traceback import print_exc @@ -28,10 +28,25 @@ class gptcommon(basetrans): self.context = [] super().__init__(typename) - def createurl(self): - if self.config["API接口地址"].endswith("/chat/completions"): - return self.config["API接口地址"] - return self.checkv1(self.config["API接口地址"]) + "/chat/completions" + def createdata(self, message): + try: + temperature = float(self.config["Temperature"]) + except: + temperature = 0.3 + + data = dict( + model=self.config["model"], + messages=message, + # optional + max_tokens=self.config["max_tokens"], + n=1, + # stop=None, + top_p=self.config["top_p"], + temperature=temperature, + frequency_penalty=self.config["frequency_penalty"], + stream=self.config["流式输出"], + ) + return data def createparam(self): return None @@ -54,14 +69,88 @@ class gptcommon(basetrans): api_url += "/v1" return api_url + def alicreateheaders(self): + h = self.createheaders() + if self.config["流式输出"]: + h.update({"Accept": "text/event-stream"}) + return h + + def alicreatedata(self, message): + + data = dict(model=self.config["model"], input=dict(messages=message)) + if self.config["流式输出"]: + data.update(dict(parameters=dict(incremental_output=True))) + + return data + + def aliparseresponse(self, query, response: requests.ResponseBase, usingstream): + if usingstream: + message = "" + for chunk in response.iter_lines(): + response_data = chunk.decode("utf-8").strip() + if not response_data: + continue + if response_data.startswith("data:") == False: + continue + try: + json_data = json.loads(response_data[5:]) + msg = json_data["output"]["text"] + yield msg + message += msg + except: + print_exc() + raise Exception(response_data) + + if json_data["output"]["finish_reason"] == "stop": + break + + else: + try: + message = ( + response.json()["output"]["text"].replace("\n\n", "\n").strip() + ) + yield message + except: + raise Exception(response.text) + self.context.append({"role": "user", "content": query}) + self.context.append({"role": "assistant", "content": message}) + + def commonparseresponse(self, query, response: requests.ResponseBase, usingstream): + if usingstream: + message = "" + for chunk in response.iter_lines(): + response_data = chunk.decode("utf-8").strip() + if not response_data: + continue + try: + json_data = json.loads(response_data[6:]) + rs = json_data["choices"][0].get("finish_reason") + if rs and rs != "null": + break + msg = json_data["choices"][0]["delta"]["content"] + yield msg + message += msg + + except: + print_exc() + raise Exception(response_data) + else: + try: + + message = ( + response.json()["choices"][0]["message"]["content"] + .replace("\n\n", "\n") + .strip() + ) + yield message + except: + raise Exception(response.text) + self.context.append({"role": "user", "content": query}) + self.context.append({"role": "assistant", "content": message}) + def translate(self, query): self.contextnum = int(self.config["附带上下文个数"]) - try: - temperature = float(self.config["Temperature"]) - except: - temperature = 0.3 - if self.config["使用自定义promt"]: message = [{"role": "user", "content": self.config["自定义promt"]}] else: @@ -85,52 +174,25 @@ class gptcommon(basetrans): message.append({"role": "user", "content": query}) usingstream = self.config["流式输出"] - data = dict( - model=self.config["model"], - messages=message, - # optional - max_tokens=self.config["max_tokens"], - n=1, - # stop=None, - top_p=self.config["top_p"], - temperature=temperature, - frequency_penalty=self.config["frequency_penalty"], - stream=usingstream, - ) - response = self.proxysession.post( - self.createurl(), - headers=self.createheaders(), - params=self.createparam(), - json=data, - stream=usingstream, - ) - if usingstream: - message = "" - for chunk in response.iter_lines(): - response_data = chunk.decode("utf-8").strip() - if not response_data: - continue - try: - json_data = json.loads(response_data[6:]) - rs = json_data["choices"][0]["finish_reason"] - if rs and rs != "null": - break - msg = json_data["choices"][0]["delta"]["content"] - yield msg - message += msg - except: - print_exc() - raise Exception(response_data) - + url = self.config["API接口地址"] + parseresponse = self.commonparseresponse + createheaders = self.createheaders + createdata = self.createdata + if url.endswith("/chat/completions"): + pass + elif url.endswith("/text-generation/generation") or 'dashscope.aliyuncs.com' in url: + # https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation + # 阿里云百炼 + parseresponse = self.aliparseresponse + createheaders = self.alicreateheaders + createdata = self.alicreatedata else: - try: - message = ( - response.json()["choices"][0]["message"]["content"] - .replace("\n\n", "\n") - .strip() - ) - yield message - except: - raise Exception(response.text) - self.context.append({"role": "user", "content": query}) - self.context.append({"role": "assistant", "content": message}) + url = self.checkv1(url) + "/chat/completions" + response = self.proxysession.post( + url, + headers=createheaders(), + params=self.createparam(), + json=createdata(message), + stream=usingstream, + ) + return parseresponse(query, response, usingstream) diff --git a/docs/zh/guochandamoxing.md b/docs/zh/guochandamoxing.md index 6e5ea2f8..81d4932a 100644 --- a/docs/zh/guochandamoxing.md +++ b/docs/zh/guochandamoxing.md @@ -1,6 +1,6 @@ ## 国产大模型如何使用ChatGPT兼容接口 -### 火山引擎(豆包大模型等) +### 字节跳动豆包大模型等 **API接口地址** `https://ark.cn-beijing.volces.com/api/v3` @@ -10,3 +10,8 @@ ![img](https://image.lunatranslator.xyz/zh/damoxing/doubao.png) + +### 阿里云百炼大模型 + +**API接口地址** `https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation` +