HYJK-Vison-ocr_vehicle_cert.../ocr_vehicle_certificate_vino_noconfig/infer_engine.py

71 lines
2.4 KiB
Python
Raw Permalink Normal View History

2024-12-05 14:24:03 +08:00
import os
import traceback
from pathlib import Path
import numpy as np
from openvino.runtime import Core
from cryptography.fernet import Fernet
def model_decrypt(encryt_file, key):
    """Read a Fernet-encrypted file and return its decrypted bytes.

    Args:
        encryt_file: Path of the encrypted file; it is opened in binary mode.
        key: Fernet key used to build the cipher.

    Returns:
        The decrypted payload as ``bytes``.
    """
    # Read the whole ciphertext first; the cipher is only built afterwards,
    # so a missing file is reported before any key problem.
    with open(encryt_file, 'rb') as handle:
        token = handle.read()
    cipher = Fernet(key)
    return cipher.decrypt(token)
class OpenVINOInferSession:
    """Compile a model for the OpenVINO CPU device and run inference on it.

    The model is loaded either from a single model file (when ``model_path``
    is a file) or from a directory holding the Fernet-encrypted pair
    ``__model__.encrypted`` / ``__params__.encrypted``.
    """

    # NOTE(security): this decryption key is embedded in source, so anyone
    # holding this file can decrypt the shipped models. Consider loading it
    # from the environment or a secrets store instead.
    _MODEL_KEY = b'QO_XTswXrn-GWc3hnFzOmM6c5MC2stRZzYYTSeKX3Wk='

    def __init__(self, model_path=None, config=None, infer_num_threads=-1):
        """Load the model, compile it for CPU and create an infer request.

        Args:
            model_path: Path to a model file, or to a directory containing
                ``__model__.encrypted`` and ``__params__.encrypted``.
                Ignored when ``config`` is given.
            config: Optional mapping. When given, ``config["model_path"]``
                replaces ``model_path`` and
                ``config.get("inference_num_threads", -1)`` replaces
                ``infer_num_threads``.
            infer_num_threads: CPU inference thread count. ``-1`` keeps the
                OpenVINO default; values outside ``1..os.cpu_count()`` are
                ignored.

        Raises:
            TypeError: If no model path was supplied.
        """
        if config is not None:
            model_path = config["model_path"]
            infer_num_threads = config.get("inference_num_threads", -1)
        if model_path is None:
            # Fail with a clear message instead of Path(None)'s opaque error.
            raise TypeError(
                "model_path must be given directly or via config['model_path']"
            )

        core = Core()
        if Path(model_path).is_file():
            model = core.read_model(model_path)
        else:
            # A non-file path is treated as a directory of encrypted parts.
            model_data = model_decrypt(
                os.path.join(model_path, '__model__.encrypted'), self._MODEL_KEY
            )
            params_data = model_decrypt(
                os.path.join(model_path, '__params__.encrypted'), self._MODEL_KEY
            )
            model = core.read_model(model_data, params_data)

        # os.cpu_count() may return None; the helper tolerates that.
        if self._valid_num_threads(infer_num_threads, os.cpu_count()):
            core.set_property(
                "CPU", {"INFERENCE_NUM_THREADS": str(infer_num_threads)}
            )
        compiled_model = core.compile_model(model=model, device_name="CPU")
        self.session = compiled_model.create_infer_request()

    @staticmethod
    def _valid_num_threads(num_threads, cpu_count):
        """Return True if ``num_threads`` is an explicit, usable thread count.

        Values below 1 (including the ``-1`` "default" sentinel) are rejected.
        When ``cpu_count`` is None (undeterminable), only the lower bound is
        enforced; otherwise the count must not exceed ``cpu_count``.
        """
        if num_threads < 1:
            return False
        return cpu_count is None or num_threads <= cpu_count

    def __call__(self, input_content: np.ndarray) -> list:
        """Run synchronous inference on a single input.

        Args:
            input_content: The model's single input.

        Returns:
            A list of the values of the mapping returned by
            ``self.session.infer``, in that mapping's order.

        Raises:
            OpenVINOError: If inference fails; the message is the formatted
                traceback and the original exception is chained as the cause.
        """
        try:
            outputs = self.session.infer(inputs=[input_content])
            return list(outputs.values())
        except Exception as e:
            error_info = traceback.format_exc()
            raise OpenVINOError(error_info) from e

    def get_input_size(self):
        """Return the input's spatial size as ``(width, height)``.

        Assumes the input tensor shape is laid out as ``[n, c, h, w]``.
        """
        input_shape = self.session.get_input_tensor().get_shape()
        return (input_shape[3], input_shape[2])

    @staticmethod
    def _verify_model(model_path):
        """Check that ``model_path`` exists and is a regular file.

        Raises:
            FileNotFoundError: If the path does not exist.
            FileExistsError: If the path exists but is not a file.
        """
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"{model_path} does not exist.")
        if not model_path.is_file():
            raise FileExistsError(f"{model_path} is not a file.")
class OpenVINOError(Exception):
    """Error raised when an OpenVINO inference call fails.

    ``OpenVINOInferSession.__call__`` raises it with the formatted traceback
    of the underlying failure as the message.
    """