properly release tensors and free pointers to avoid potential crash (#494)

This commit is contained in:
chaihahaha 2024-01-17 00:14:46 +08:00 committed by GitHub
parent c4e55060f4
commit ca647fe151
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -159,6 +159,8 @@ class TS(basetrans):
input_ids_len = n_tokens.value
input_ids_py = [token_ids[i] for i in range(input_ids_len)]
input_ids_py += [1] # add EOS token to notify the end of sentence and prevent repetition
self.splib.free_ptr(token_ids)
return input_ids_py
def decode_from_ids(self, output_ids_py):
@ -171,7 +173,10 @@ class TS(basetrans):
output_len,
ctypes.byref(decoded_str)
)
return decoded_str.value.decode("utf8")
decoded_str_py = decoded_str.value.decode("utf8")
self.splib.free_ptr(decoded_str)
return decoded_str_py
def run_session(self, input_ids_py):
input_ids_len = len(input_ids_py)
@ -200,6 +205,9 @@ class TS(basetrans):
output_ids_py = []
for i in range(output_len.value):
output_ids_py.append(output_ids[i])
self.ortmtlib.release_ort_tensor(input_ids_tensor)
self.ortmtlib.free_ptr(output_ids)
return output_ids_py
def translate(self, content):
@ -207,3 +215,12 @@ class TS(basetrans):
output_ids_py = self.run_session(input_ids_py)
translated = self.decode_from_ids(output_ids_py)
return translated
def __del__(self):
self.ortmtlib.release_ort_tensor(self.max_length_tensor)
self.ortmtlib.release_ort_tensor(self.min_length_tensor)
self.ortmtlib.release_ort_tensor(self.num_beams_tensor)
self.ortmtlib.release_ort_tensor(self.num_return_sequences_tensor)
self.ortmtlib.release_ort_tensor(self.length_penalty_tensor)
self.ortmtlib.release_ort_tensor(self.repetition_penalty_tensor)
self.ortmtlib.release_all_globals()