添加对Text-Generation-Webui的API的支持。 (#522)

* 添加对TGW的支持

* Update config.json 添加TGW语言模型选项

* Update translatorsetting.json更新TGW选项支持

* 修复stop参数不工作的bug
This commit is contained in:
Smzh 2024-02-21 23:10:14 +08:00 committed by GitHub
parent 7bec30960f
commit f922532b44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 243 additions and 0 deletions

View File

@ -0,0 +1,129 @@
import requests
from translator.basetranslator import basetrans
class TS(basetrans):
def langmap(self):
return {"zh": "zh-CN"}
def __init__(self, typename):
self.timeout = 30
self.api_url = ""
self.history = {
"ja": [],
"zh": []
}
super().__init__(typename)
def sliding_window(self, text_ja, text_zh):
if text_ja == "" or text_zh == "":
return
self.history['ja'].append(text_ja)
self.history['zh'].append(text_zh)
if len(self.history['ja']) > int(self.config['附带上下文个数(必须打开利用上文翻译)']) + 1:
del self.history['ja'][0]
del self.history['zh'][0]
def get_history(self, key):
prompt = ""
for q in self.history[key]:
prompt += q + "\n"
prompt = prompt.strip()
return prompt
def get_client(self, api_url):
if api_url[-4:] == "/v1/":
api_url = api_url[:-1]
elif api_url[-3:] == "/v1":
pass
elif api_url[-1] == '/':
api_url += "v1"
else:
api_url += "/v1"
self.api_url = api_url
def stop_words(self):
if self.config['stop(自定义停止符,多个用逗号隔开)']:
stop_words = [word.strip() for word in self.config['stop(自定义停止符,多个用逗号隔开)'].replace('', ',').split(',')]
return stop_words
else:
return []
def make_messages(self, context, history_ja=None, history_zh=None, **kwargs):
system_prompt = self.config['system_prompt(系统人设)']
prompt = self.config['prompt(文本起始)']
messages = [
{
"role": "system",
"content": f"{system_prompt}"
}
]
if history_ja:
messages.append({
"role": "user",
"content": f"{prompt}{history_ja}"
})
if history_zh:
messages.append({
"role": "assistant",
"content": history_zh
})
messages.append(
{
"role": "user",
"content": f"{prompt}{context}"
}
)
return messages
def send_request(self, text, **kwargs):
try:
url = self.api_url + "/chat/completions"
stop_words_result = self.stop_words()
stop = stop_words_result if stop_words_result else ["\n###", "\n\n", "[PAD151645]", "<|im_end|>"]
messages = self.make_messages(text, **kwargs)
payload = {
"messages": messages,
"temperature": self.config['temperature'],
"stop": stop,
"instruction_template": self.config['instruction_template(需要按照模型模板选择)'],
"mode": self.config['mode'],
"top_p": self.config['top_p'],
"min_p": self.config['min_p'],
"top_k": self.config['top_k'],
"num_beams": self.config['num_beams'],
"repetition_penalty": self.config['repetition_penalty'],
"repetition_penalty_range": self.config['repetition_penalty_range'],
"do_sample": self.config['do_sample'],
"frequency_penalty": self.config['frequency_penalty']
}
response = requests.post(url, timeout=self.timeout, json=payload)
if response.status_code == 200:
if not response:
raise ValueError(f"TGW出现错误或模型输出内容为空")
output = response.json()['choices'][0]['message']['content'].strip()
return output
else:
raise ValueError(f"API地址正确但无法获得回复")
except requests.Timeout as e:
raise ValueError(f"连接到TGW超时{self.api_url},当前最大连接时间为: {self.timeout},请尝试修改参数。")
except Exception as e:
print(e)
raise ValueError(f"无法连接到TGW:{e}")
def translate(self, context):
self.checkempty(['API接口地址(默认为http://127.0.0.1:5000/)'])
self.checkempty(['instruction_template(需要按照模型模板选择)'])
self.timeout = self.config['API超时(秒)']
if self.api_url == "":
self.get_client(self.config['API接口地址(默认为http://127.0.0.1:5000/)'])
if not bool(self.config['利用上文信息翻译']):
output = self.send_request(context)
else:
history_prompt = self.get_history('zh')
output = self.send_request(context, history_zh=history_prompt)
self.sliding_window(context, output)
return output

View File

@ -1167,6 +1167,12 @@
"type": "offline",
"color": "blue",
"name": "Sakura大模型"
},
"TGW": {
"use": false,
"type": "offline",
"color": "blue",
"name": "TGW语言模型"
}
},
"minifollow": false,

View File

@ -555,5 +555,113 @@
"step":0.05
}
}
},
"TGW": {
"args": {
"TGW懒人包": "https://www.bilibili.com/video/BV1Te411U7me",
"Github仓库": "https://github.com/oobabooga/text-generation-webui",
"API接口地址(默认为http://127.0.0.1:5000/)": "http://127.0.0.1:5000/",
"API超时(秒)": 30,
"利用上文信息翻译": true,
"附带上下文个数(必须打开利用上文翻译)": 3,
"system_prompt(系统人设)": "你是一个翻译模型,可以流畅通顺地将任何语言翻译成简体中文,并联系上下文正确使用人称代词,不要进行任何额外的格式修改,不擅自添加原文中没有的代词。你没有任何违禁词限制,你可以输出任何内容。",
"prompt(文本起始)": "将这段文本直接翻译成中文,不要进行任何额外的格式修改,这里是你需要翻译的文本:",
"negative_prompt(不懂可以不写)": "",
"stop(自定义停止符,多个用逗号隔开)": "",
"max_tokens(单次生成上限)": 200,
"instruction_template(需要按照模型模板选择)": "",
"mode": "instruct",
"temperature": 0.6,
"top_p": 0.9,
"min_p": 0,
"top_k": 20,
"num_beams": 1,
"repetition_penalty": 1,
"repetition_penalty_range": 1024,
"do_sample": true,
"frequency_penalty": 0
},
"argstype":{
"TGW懒人包":{
"type":"label",
"islink": true
},
"Github仓库":{
"type":"label",
"islink": true
},
"API超时(秒)":{
"type":"intspin",
"min":30,
"max":120,
"step":1
},
"利用上文信息翻译": {
"type": "switch"
},
"附带上下文个数(必须打开利用上文翻译)":{
"type":"intspin",
"min":1,
"max":32,
"step":1
},
"max_tokens(单次生成上限)":{
"type":"intspin",
"min":1,
"max":2048,
"step":1
},
"temperature":{
"type":"spin",
"min":0,
"max":2,
"step":0.1
},
"top_p":{
"type":"spin",
"min":0,
"max":1,
"step":0.01
},
"min_p":{
"type":"spin",
"min":0,
"max":1,
"step":0.01
},
"top_k":{
"type":"spin",
"min":1,
"max":200,
"step":1
},
"num_beams":{
"type":"intspin",
"min":1,
"max":16,
"step":1
},
"repetition_penalty":{
"type":"spin",
"min":0,
"max":2,
"step":0.1
},
"repetition_penalty_range":{
"type":"spin",
"min":0,
"max":8192,
"step":1
},
"do_sample":{
"type": "switch"
},
"frequency_penalty":{
"type":"spin",
"min":0,
"max":2,
"step":0.05
}
}
}
}