fix pull/501

This commit is contained in:
恍兮惚兮 2024-01-26 22:00:48 +08:00
parent c182cbbf99
commit 52515993aa
5 changed files with 38 additions and 30 deletions

View File

@ -140,6 +140,8 @@ class CURLoption(c_int):
CURLOPT_SHARE=CURLOPTTYPE_OBJECTPOINT+100 CURLOPT_SHARE=CURLOPTTYPE_OBJECTPOINT+100
CURLOPT_ACCEPT_ENCODING=CURLOPTTYPE_STRINGPOINT+102 CURLOPT_ACCEPT_ENCODING=CURLOPTTYPE_STRINGPOINT+102
CURLOPT_CONNECT_ONLY=CURLOPTTYPE_LONG+141 CURLOPT_CONNECT_ONLY=CURLOPTTYPE_LONG+141
CURLOPT_TIMEOUT_MS=CURLOPTTYPE_LONG+155
CURLOPT_CONNECTTIMEOUT_MS=CURLOPTTYPE_LONG+156
class CURLINFO(c_int): class CURLINFO(c_int):
CURLINFO_STRING =0x100000 CURLINFO_STRING =0x100000
CURLINFO_LONG =0x200000 CURLINFO_LONG =0x200000
@ -280,15 +282,13 @@ class AutoCURLHandle(CURL):
class CURLException(Exception): class CURLException(Exception):
def __init__(self,code) -> None: def __init__(self,code) -> None:
self.errorcode=code self.errorcode=code.value
if isinstance(code,CURLcode): if isinstance(code,CURLcode):
error=curl_easy_strerror(code).decode('utf8') error=curl_easy_strerror(code).decode('utf8')
for _ in dir(CURLcode): for _ in dir(CURLcode):
if _.startswith('CURLE_') and code.value==getattr(CURLcode,_): if _.startswith('CURLE_') and code.value==getattr(CURLcode,_):
error=str(code.value)+' '+_+' : '+error error=str(code.value)+' '+_+' : '+error
break break
elif isinstance(code,str):
error=code
else: else:
error='' raise Exception("not a valid CURLException")
super().__init__(error) super().__init__(error)

View File

@ -58,7 +58,17 @@ class Session(Sessionbase):
def _getmembyte(self,mem): def _getmembyte(self,mem):
return cast(mem.memory,POINTER(c_char))[:mem.size] return cast(mem.memory,POINTER(c_char))[:mem.size]
def request_impl(self, def request_impl(self,
method,scheme,server,port,param,url,headers,cookies,dataptr,datalen,proxy,stream,verify): method,scheme,server,port,param,url,headers,cookies,dataptr,datalen,proxy,stream,verify,timeout):
try:
_= self.request_impl_1(method,scheme,server,port,param,url,headers,cookies,dataptr,datalen,proxy,stream,verify,timeout)
return _
except CURLException as e:
if e.errorcode==CURLcode.CURLE_OPERATION_TIMEDOUT:
raise Timeout(e)
else:
raise e
def request_impl_1(self,
method,scheme,server,port,param,url,headers,cookies,dataptr,datalen,proxy,stream,verify,timeout):
if self._status==0: if self._status==0:
curl=self.curl curl=self.curl
@ -73,6 +83,9 @@ class Session(Sessionbase):
if cookies: if cookies:
cookie=self._parsecookie(cookies) cookie=self._parsecookie(cookies)
curl_easy_setopt(curl, CURLoption.CURLOPT_COOKIE, cookie.encode('utf8')); curl_easy_setopt(curl, CURLoption.CURLOPT_COOKIE, cookie.encode('utf8'));
if timeout:
curl_easy_setopt(curl, CURLoption.CURLOPT_TIMEOUT_MS, timeout);
curl_easy_setopt(curl, CURLoption.CURLOPT_CONNECTTIMEOUT_MS, timeout);
curl_easy_setopt(curl,CURLoption.CURLOPT_ACCEPT_ENCODING, headers['Accept-Encoding'].encode('utf8')) curl_easy_setopt(curl,CURLoption.CURLOPT_ACCEPT_ENCODING, headers['Accept-Encoding'].encode('utf8'))
curl_easy_setopt(curl,CURLoption.CURLOPT_CUSTOMREQUEST,method.upper().encode('utf8')) curl_easy_setopt(curl,CURLoption.CURLOPT_CUSTOMREQUEST,method.upper().encode('utf8'))

View File

@ -3,6 +3,8 @@ from collections.abc import Callable, Mapping, MutableMapping
from collections import OrderedDict from collections import OrderedDict
from urllib.parse import urlencode,urlsplit from urllib.parse import urlencode,urlsplit
from functools import partial,partialmethod from functools import partial,partialmethod
class Timeout(Exception):
pass
class CaseInsensitiveDict(MutableMapping): class CaseInsensitiveDict(MutableMapping):
def __init__(self, data=None, **kwargs): def __init__(self, data=None, **kwargs):
@ -190,8 +192,6 @@ class Sessionbase:
def request(self, def request(self,
method, url, params=None, data=None, headers=None,proxies=None, json=None,cookies=None, files=None, method, url, params=None, data=None, headers=None,proxies=None, json=None,cookies=None, files=None,
auth=None, timeout=None, allow_redirects=True, hooks=None, stream=None, verify=False, cert=None, ): auth=None, timeout=None, allow_redirects=True, hooks=None, stream=None, verify=False, cert=None, ):
# 0 means infinity, cite: WinHttpSetTimeouts
timeout = timeout or 0
_h=self.headers.copy() _h=self.headers.copy()
if headers: if headers:
_h.update(headers) _h.update(headers)
@ -204,7 +204,8 @@ class Sessionbase:
scheme,server,port,param,url=self._parseurl(url,params) scheme,server,port,param,url=self._parseurl(url,params)
headers,dataptr,datalen=self._parsedata(data,headers,json) headers,dataptr,datalen=self._parsedata(data,headers,json)
proxy= proxies.get(scheme,None) if proxies else None proxy= proxies.get(scheme,None) if proxies else None
if timeout:
timeout = int(timeout * 1000) # convert to milliseconds
_= self.request_impl(method,scheme,server,port,param,url,headers,cookies,dataptr,datalen,proxy,stream,verify,timeout) _= self.request_impl(method,scheme,server,port,param,url,headers,cookies,dataptr,datalen,proxy,stream,verify,timeout)
if _.status_code==301: if _.status_code==301:

View File

@ -60,10 +60,18 @@ class Session(Sessionbase):
if verify==False: if verify==False:
dwFlags=DWORD(SECURITY_FLAG_IGNORE_ALL_CERT_ERRORS) dwFlags=DWORD(SECURITY_FLAG_IGNORE_ALL_CERT_ERRORS)
WinHttpSetOption(curl,WINHTTP_OPTION_SECURITY_FLAGS, pointer(dwFlags),sizeof(dwFlags)) WinHttpSetOption(curl,WINHTTP_OPTION_SECURITY_FLAGS, pointer(dwFlags),sizeof(dwFlags))
# 30s is the default timeout on Windows
def request_impl(self, def request_impl(self,
method,scheme,server,port,param,url,headers,cookies,dataptr,datalen,proxy,stream,verify,timeout=30): method,scheme,server,port,param,url,headers,cookies,dataptr,datalen,proxy,stream,verify,timeout):
try:
_= self.request_impl_1(method,scheme,server,port,param,url,headers,cookies,dataptr,datalen,proxy,stream,verify,timeout)
return _
except WinhttpException as e:
if e.errorcode==WinhttpException.ERROR_WINHTTP_TIMEOUT:
raise Timeout(e)
else:
raise e
def request_impl_1(self,
method,scheme,server,port,param,url,headers,cookies,dataptr,datalen,proxy,stream,verify,timeout):
headers=self._parseheader(headers,cookies) headers=self._parseheader(headers,cookies)
flag=WINHTTP_FLAG_SECURE if scheme=='https' else 0 flag=WINHTTP_FLAG_SECURE if scheme=='https' else 0
#print(server,port,param,dataptr) #print(server,port,param,dataptr)
@ -73,18 +81,8 @@ class Session(Sessionbase):
if hConnect==0: if hConnect==0:
raise WinhttpException(GetLastError()) raise WinhttpException(GetLastError())
hRequest=AutoWinHttpHandle(WinHttpOpenRequest( hConnect ,method,param,None,WINHTTP_NO_REFERER,WINHTTP_DEFAULT_ACCEPT_TYPES,flag) ) hRequest=AutoWinHttpHandle(WinHttpOpenRequest( hConnect ,method,param,None,WINHTTP_NO_REFERER,WINHTTP_DEFAULT_ACCEPT_TYPES,flag) )
timeout = timeout * 1000 # convert to milliseconds if timeout:
'''
WINHTTPAPI BOOL WinHttpSetTimeouts(
[in] HINTERNET hInternet,
[in] int nResolveTimeout,
[in] int nConnectTimeout,
[in] int nSendTimeout,
[in] int nReceiveTimeout
);
'''
WinHttpSetTimeouts(hRequest, timeout, timeout, timeout, timeout) WinHttpSetTimeouts(hRequest, timeout, timeout, timeout, timeout)
if hRequest==0: if hRequest==0:
raise WinhttpException(GetLastError()) raise WinhttpException(GetLastError())
self._set_verify(hRequest,verify) self._set_verify(hRequest,verify)

View File

@ -1,4 +1,3 @@
import winhttp
from traceback import print_exc from traceback import print_exc
from translator.basetranslator import basetrans from translator.basetranslator import basetrans
import requests import requests
@ -92,12 +91,9 @@ class TS(basetrans):
stream=False, stream=False,
) )
output = self.session.post(self.api_url + "/chat/completions", timeout=self.timeout, json=data).json() output = self.session.post(self.api_url + "/chat/completions", timeout=self.timeout, json=data).json()
except winhttp.WinhttpException as e: except requests.Timeout as e:
code = e.errorcode
if code == winhttp.WinhttpException.ERROR_WINHTTP_TIMEOUT:
raise ValueError(f"连接到Sakura API超时{self.api_url},当前最大连接时间为: {self.timeout},请尝试修改参数。") raise ValueError(f"连接到Sakura API超时{self.api_url},当前最大连接时间为: {self.timeout},请尝试修改参数。")
else:
raise ValueError(f"连接到Sakura API网络错误{self.api_url},错误代码: {code}")
except Exception as e: except Exception as e:
print(e) print(e)
raise ValueError(f"无法连接到Sakura API{self.api_url}请检查你的API链接是否正确填写以及API后端是否成功启动。") raise ValueError(f"无法连接到Sakura API{self.api_url}请检查你的API链接是否正确填写以及API后端是否成功启动。")