
Training and Deploying a Small Hand Detection Model

Goal: build a hand detection model that is small enough, by experimenting with object detection algorithms on a hand dataset.

Below are the resources I found and the approaches I tried.

1. Calling MediaPipe directly

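        The MediaPipe API can be called directly for hand detection and hand landmark detection. The results are good, and no training is needed: you just run inference. The code comes from https://www.youtube.com/watch?v=x4eeX7WJIuA. Use this approach if it fits your needs. See also: 2022.3.3 Python-opencv-mediapipe - Jianshu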

import cv2
import mediapipe as mp
import time

cap = cv2.VideoCapture(0)
mpHands = mp.solutions.hands
hands = mpHands.Hands(min_detection_confidence=0.5)
mpDraw = mp.solutions.drawing_utils
handLmsStyle = mpDraw.DrawingSpec(color=(0, 0, 255), thickness=5)
handConsStyle = mpDraw.DrawingSpec(color=(0, 255, 0), thickness=10)
pTime = 0
cTime = 0
#https://www.youtube.com/watch?v=x4eeX7WJIuA
while True:
    ret, img = cap.read()
    if ret:
        imgRGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        result = hands.process(imgRGB)
        # print(result.multi_hand_landmarks)
        imgHeight = img.shape[0]
        imgWidth = img.shape[1]

        if result.multi_hand_landmarks:
            for handLms in result.multi_hand_landmarks:
                mpDraw.draw_landmarks(img, handLms, mpHands.HAND_CONNECTIONS, handLmsStyle, handConsStyle)
                for i, lm in enumerate(handLms.landmark):
                    xPos = int(lm.x * imgWidth)
                    yPos = int(lm.y * imgHeight)
                    zPos = lm.z  # relative depth (wrist as origin), not a pixel value; int() would truncate it to 0
                    cv2.putText(img, str(i), (xPos - 25, yPos + 5), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (0, 0, 255), 2)
                    if i == 4:
                        cv2.circle(img, (xPos, yPos), 20, (0, 0, 255), cv2.FILLED)
                    print(i, xPos, yPos, zPos)
        cTime = time.time()
        fps = 1/(cTime - pTime)
        pTime = cTime
        cv2.putText(img, f"FPS : {int(fps)}", (30, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
        cv2.imshow('img', img)

    if cv2.waitKey(1) == ord('q'):
        break

# Release the camera and close the preview window
cap.release()
cv2.destroyAllWindows()
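MediaPipe returns 21 landmarks per hand rather than a bounding box. If a detection-style box is needed, as in the rest of this post, one option is to take the extent of the landmarks. The helper below is a minimal sketch of that idea; `landmarks_to_bbox` and its `margin` padding are hypothetical names, not part of the MediaPipe API.

```python
def landmarks_to_bbox(landmarks, img_width, img_height, margin=0.1):
    """Turn normalized (x, y) landmarks into a pixel box (xmin, ymin, xmax, ymax).

    landmarks: iterable of (x, y) pairs, normally in [0, 1].
    margin: fraction of the box size added on each side (hypothetical default).
    """
    xs = [x for x, _ in landmarks]
    ys = [y for _, y in landmarks]
    xmin, xmax = min(xs), max(xs)
    ymin, ymax = min(ys), max(ys)
    # Pad the tight box, then clamp it to the image in normalized units
    pad_x = (xmax - xmin) * margin
    pad_y = (ymax - ymin) * margin
    xmin = max(0.0, xmin - pad_x)
    ymin = max(0.0, ymin - pad_y)
    xmax = min(1.0, xmax + pad_x)
    ymax = min(1.0, ymax + pad_y)
    # Convert normalized coordinates to pixel coordinates
    return (int(xmin * img_width), int(ymin * img_height),
            int(xmax * img_width), int(ymax * img_height))
```

Inside the loop above it could be called as `landmarks_to_bbox([(lm.x, lm.y) for lm in handLms.landmark], imgWidth, imgHeight)`, and the result drawn with `cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)`. The clamping is there because landmark coordinates can fall slightly outside [0, 1] when part of the hand leaves the frame.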
2. YOLOv3

Reference code:

Eric.Lee2021 / yolo_v3 · GitCode

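This repo contains object detection models and datasets for several categories, and it is easy to get running. I ran predict.py: the initial results were not great, and the hand was missed in many frames. A screenshot of my run is shown below. The main model is over 60 MB, which does not meet the small-model requirement, so I abandoned this approach. I did keep its hand dataset, though; take a look if you need one.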

Besides YOLOv3, the repo also includes tiny models.

  

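3. Yolo-FastestV2

The author provides a comparison of related network models (table below), which makes accuracy and model size easy to compare, so you can pick whatever fits your needs.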

Network                Model Size  mAP (VOC 2007)  FLOPs
Tiny YOLOv2            60.5 MB     57.1%           6.97 BFLOPs
Tiny YOLOv3            33.4 MB     58.4%           5.52 BFLOPs
YOLO Nano              4.0 MB      69.1%           4.51 BFLOPs
MobileNetv2-SSD-Lite   13.8 MB     68.6%           &Bflops
MobileNetV2-YOLOv3     11.52 MB    70.20%          2.02 BFLOPs
Pelee-SSD              21.68 MB    70.09%          2.40 BFLOPs
Yolo Fastest           1.3 MB      61.02%          0.23 BFLOPs
Yolo Fastest-XL        3.5 MB      69.43%          0.70 BFLOPs
MobileNetv2-Yolo-Lite  8.0 MB      73.26%          1.80 BFLOPs

Code: GitHub - dog-qiuqiu/Yolo-FastestV2: Based on Yolo's low-power, ultra-lightweight universal target detection algorithm, the parameter is only 250k, and the speed of the smart phone mobile terminal can reach ~300fps+

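This repo provides a fairly complete pipeline for training small models and deploying them on mobile. However, it is set up for the COCO dataset and provides no hand dataset, so the data has to be reorganized. I used the hand dataset from section 2 and reorganized it following the author's documentation: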

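1) Split the data into train and val sets at a 4:1 ratio. In practice you only need two files, train.txt and val.txt, listing the image paths; the label paths have to be provided as well. The author's documentation explains the details.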
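A minimal sketch of the 4:1 split, assuming (check the Yolo-FastestV2 README for the exact format and how labels are located) that train.txt and val.txt each hold one image path per line. The function names are hypothetical.

```python
import os
import random


def split_train_val(image_paths, train_ratio=0.8, seed=0):
    """Shuffle image paths and split them; train_ratio=0.8 gives the 4:1 split."""
    paths = sorted(image_paths)          # sort first so the result is reproducible
    random.Random(seed).shuffle(paths)   # local RNG, leaves the global state alone
    n_train = int(len(paths) * train_ratio)
    return paths[:n_train], paths[n_train:]


def write_list(paths, list_file):
    """Write one image path per line (the assumed train.txt / val.txt format)."""
    with open(list_file, "w") as f:
        for p in paths:
            f.write(p + "\n")


def list_images(image_dir, exts=(".jpg", ".jpeg", ".png")):
    """Collect absolute paths of the images in image_dir."""
    return [os.path.abspath(os.path.join(image_dir, name))
            for name in os.listdir(image_dir)
            if name.lower().endswith(exts)]
```

Typical use: `train, val = split_train_val(list_images("./images"))`, then `write_list(train, "train.txt")` and `write_list(val, "val.txt")`; the `./images` directory here is a placeholder.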

2) Create the .data and .names files

Create these two files, starting from copies of the COCO ones:

Then, as the author describes, generate anchors for this dataset and put them into the config file:

python3 genanchors.py --traintxt ./train.txt
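genanchors.py does this for you. For intuition: the common approach to anchor generation (introduced with YOLOv2) clusters the labeled box widths and heights with k-means, using 1 - IoU as the distance so that large and small boxes are treated fairly. The sketch below shows that idea in plain Python. It is not the repo's genanchors.py, and it assumes the (w, h) pairs are already in pixels at the network input size (352 x 352 in hand.data); k=6 would give the six pairs used there.

```python
def iou_wh(box, anchor):
    """IoU of two boxes described only by (w, h), as if they shared the same center."""
    inter = min(box[0], anchor[0]) * min(box[1], anchor[1])
    union = box[0] * box[1] + anchor[0] * anchor[1] - inter
    return inter / union


def kmeans_anchors(boxes, k, iters=100):
    """k-means on (w, h) pairs with distance 1 - IoU; returns k anchors sorted by area."""
    # Deterministic init: spread the initial centers over the boxes sorted by area
    by_area = sorted(boxes, key=lambda b: b[0] * b[1])
    centers = [tuple(by_area[i * len(by_area) // k]) for i in range(k)]
    for _ in range(iters):
        # Assignment step: each box joins the anchor with the highest IoU (lowest 1 - IoU)
        clusters = [[] for _ in range(k)]
        for b in boxes:
            best = max(range(k), key=lambda j: iou_wh(b, centers[j]))
            clusters[best].append(b)
        # Update step: mean width and height per cluster (keep the old center if empty)
        new_centers = [
            (sum(b[0] for b in c) / len(c), sum(b[1] for b in c) / len(c)) if c else centers[j]
            for j, c in enumerate(clusters)
        ]
        if new_centers == centers:
            break  # assignments are stable
        centers = new_centers
    return sorted(centers, key=lambda a: a[0] * a[1])


def format_anchors(anchors):
    """Format anchors the way hand.data writes them: 'w,h, w,h, ...'."""
    return ", ".join(f"{w:.2f},{h:.2f}" for w, h in anchors)
```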

Also fill in the dataset paths in the config file; this part is straightforward.

hand.data

[name]
model_name=hand

[train-configure]
epochs=300
steps=150,250
batch_size=128
subdivisions=1
learning_rate=0.001

[model-configure]
pre_weights=None
classes=1
width=352
height=352
anchor_num=3
anchors=5.07,9.50, 8.77,12.67, 18.48,23.87, 32.59,37.93, 54.71,59.43, 90.83,109.99
[data-configure]
train=./train/anno/train.txt
val=./train/anno/val.txt
names=./data/hand.names

hand.names only needs a single line: hand.
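Before training, it can help to sanity-check hand.data. The sketch below parses it with Python's configparser (the file uses INI-style [section] headers and key=value lines) and checks that the number of anchor pairs equals anchor_num times the number of detection heads. `check_data_file` is a hypothetical helper, and num_heads=2 is an assumption inferred from the six pairs listed in the file above.

```python
import configparser


def check_data_file(text, num_heads=2):
    """Parse a Yolo-FastestV2-style .data file and sanity-check its anchor settings.

    num_heads=2 is an assumption: it matches six (w, h) pairs for anchor_num=3.
    """
    cfg = configparser.ConfigParser()
    cfg.read_string(text)
    model = cfg["model-configure"]
    values = [float(v) for v in model["anchors"].replace(" ", "").split(",")]
    pairs = list(zip(values[0::2], values[1::2]))
    expected = model.getint("anchor_num") * num_heads
    if len(values) % 2 or len(pairs) != expected:
        raise ValueError(f"expected {expected} anchor pairs, got {len(values) / 2}")
    return {
        "classes": model.getint("classes"),
        "input_size": (model.getint("width"), model.getint("height")),
        "anchors": pairs,
        "train": cfg["data-configure"]["train"],
    }
```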

3) Training

Run train.py.

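4) Testing

After the initial training run, the mAP is just over 30%. Results on a few test images look acceptable, and the model is slightly over 1000 KB, which roughly meets the requirement. Because the requirements changed, I will optimize the model later.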
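The ~1000 KB figure is about what the parameter count predicts. Stored model size is roughly parameters times bytes per parameter; using the ~250k parameters quoted in the Yolo-FastestV2 description above, float32 weights come to about 1 MB, and 8-bit weights would take about a quarter of that. This back-of-the-envelope sketch ignores file format overhead, so treat it as an estimate.

```python
def model_size_bytes(num_params, bits_per_param=32):
    """Rough size of the stored weights: parameters x bits per parameter / 8."""
    return num_params * bits_per_param // 8


num_params = 250_000  # figure quoted in the Yolo-FastestV2 description
fp32 = model_size_bytes(num_params)      # float32 weights, in bytes
int8 = model_size_bytes(num_params, 8)   # 8-bit quantized weights, in bytes
print(f"float32: {fp32 / 1024:.1f} KiB, int8: {int8 / 1024:.1f} KiB")
```

This prints 976.6 KiB for float32 and 244.1 KiB for int8, which is why 8-bit quantization is the natural lever when a model has to fit under 1 MB.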

5) Mobile deployment

The author documents this clearly; I will add my own notes later.

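4. Hand detection with TensorFlow

Then a new requirement came in: the model must be quantized, built with TensorFlow 1.15, and smaller than 1 MB. Everything above is PyTorch-based, so I set those approaches aside for now.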

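Some research showed that the TensorFlow Object Detection API meets these requirements directly.

Reference: models/research/object_detection at master · tensorflow/models · GitHub

That page covers essentially the whole workflow, but in practice I still ran into a few other problems, recorded below: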

1. Install TensorFlow and the API

This experiment requires TensorFlow 1.x, so make sure the environment versions match.

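2. Prepare the pipeline.config file

Look for the config of the model you need under ~/models/research/object_detection/samples/configs. Because the model must be smaller than 1 MB and quantized, I chose ssd_mobilenet_v2_coco.config. All training and evaluation parameters are set in this file, and the dataset paths go in it as well.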
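Besides num_classes (1 here) and the input_path entries of train_input_reader and eval_input_reader, the pipeline config also needs a label_map_path pointing to a label map file. A minimal sketch that generates one, assuming a single class named hand; `make_label_map` and the file name hand_label_map.pbtxt are hypothetical. IDs start at 1 because 0 is reserved for the background class, which is why the TFRecord script adds 1 to each YOLO class id.

```python
def make_label_map(class_names):
    """Build label map text (.pbtxt) for the TF Object Detection API; ids start at 1."""
    items = []
    for idx, name in enumerate(class_names, start=1):
        items.append("item {\n  id: %d\n  name: '%s'\n}\n" % (idx, name))
    return "".join(items)
```

For example, `open("hand_label_map.pbtxt", "w").write(make_label_map(["hand"]))` writes a one-item label map that label_map_path can point to.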

3. Prepare the dataset

The dataset is the hand dataset from section 2. The TensorFlow code expects the data in TFRecord format.

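1) Split into training and test sets

I wrote a quick script that reads and moves the files, splitting the data into train, eval and test sets at 4:1:1.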

# Split the dataset: move some images, and the label files with the same names, into a val folder
import os
import shutil

root = os.path.expanduser('~/datasets_TVCOCO_hand_train/anno/')  # os.listdir does not expand '~'
filepath = os.path.join(root, 'images')
dis = os.path.join(root, 'val', 'images')
labelpath = os.path.join(root, 'labels')
labeldis = os.path.join(root, 'val', 'labels')
num_val = 5000  # number of images to move

os.makedirs(dis, exist_ok=True)
os.makedirs(labeldis, exist_ok=True)

# Step 1: move the first num_val images into val/images
for name in os.listdir(filepath)[:num_val]:
    shutil.move(os.path.join(filepath, name), dis)

# Step 2: labels must match the image names; move the label of every image now in val/images
for name in os.listdir(dis):
    label_file = os.path.join(labelpath, os.path.splitext(name)[0] + '.txt')
    if os.path.exists(label_file):
        shutil.move(label_file, labeldis)
    

2) Convert to TFRecord

Using models/using_your_own_dataset.md at master · tensorflow/models · GitHub on the TensorFlow site

together with the code in ~/models/research/object_detection/dataset_tools, the conversion script can be put together as follows:

import hashlib
import io
import logging
import os
from glob import glob

import contextlib2
import cv2
import PIL.Image
import tensorflow.compat.v1 as tf
from object_detection.dataset_tools import tf_record_creation_util
from object_detection.utils import dataset_util


flags = tf.app.flags
flags.DEFINE_string('data_dir', '', 'Root directory of the hand dataset (contains train/ and val/).')
flags.DEFINE_string('output_path', '', 'Output directory for the TFRecord shards')
flags.DEFINE_integer('num_shards', 10, 'Number of TFRecord shards (unused: main() uses 8 for train, 4 for val)')
FLAGS = flags.FLAGS


def create_tf_example(example):
    # TODO(user): Populate the following variables from your example.
    height = example['height'] # Image height
    width = example['width'] # Image width
    #print(height,width)
    filename = example['filename'].encode('utf8') # Filename of the image. Empty if image is not from file
    encoded_image_data = example['image'] # Encoded image bytes
    image_format = 'jpeg'.encode('utf8') # b'jpeg' or b'png'
    key = example['key'].encode('utf8')
    xmins = example['xmins'] # List of normalized left x coordinates in bounding box (1 per box)
    xmaxs = example['xmaxs'] # List of normalized right x coordinates in bounding box
                 # (1 per box)
    ymins = example['ymins'] # List of normalized top y coordinates in bounding box (1 per box)
    ymaxs = example['ymaxs'] # List of normalized bottom y coordinates in bounding box
                 # (1 per box)
    classes_text = [class_name.encode('utf8') for class_name in example['class_names']] # List of string class name of bounding box (1 per box)
    classes = example['classes'] # List of integer class id of bounding box (1 per box)

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/key/sha256': dataset_util.bytes_feature(key),
        'image/encoded': dataset_util.bytes_feature(encoded_image_data),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example

def create_tf_record(output_filename,
                     num_shards,
                     data_dir,
                     data_type='train'):
    """Writes the images and labels of one split into sharded TFRecord files.

    Args:
        output_filename: Output directory; shards are named <data_type>.record-*.
        num_shards: Number of shards for the output file.
        data_dir: Dataset root; images are read from <data_dir>/<data_type>/images/*.jpg
            and labels from <data_dir>/<data_type>/labels/<name>.txt.
        data_type: Split name, 'train' or 'val'.
    """
    with contextlib2.ExitStack() as tf_record_close_stack:
        output_tfrecords = tf_record_creation_util.open_sharded_output_tfrecords(
            tf_record_close_stack, os.path.join(output_filename, "%s.record"%data_type), num_shards)
        image_dir = os.path.join(data_dir, data_type, "images/",  "*.jpg")
        image_list = glob(image_dir)
        skipped = 0
        for idx, imagename in enumerate(image_list):
            if idx % 100 == 0:
                logging.info('On image %d of %d', idx, len(image_list))
            label_path = os.path.join(data_dir, data_type, 'labels',
                                      os.path.splitext(os.path.basename(imagename))[0] + '.txt')

            if not os.path.exists(label_path):
                logging.warning('Could not find %s, ignoring example.', label_path)
                continue
            img = cv2.imread(imagename, cv2.IMREAD_COLOR)
            with tf.gfile.GFile(imagename, 'rb') as fid:
                encoded_jpg = fid.read()
            encoded_jpg_io = io.BytesIO(encoded_jpg)
            image = PIL.Image.open(encoded_jpg_io)
            # if image.format != 'JPEG':
                # raise ValueError('Image format not JPEG')
            key = hashlib.sha256(encoded_jpg).hexdigest()
            with open(label_path) as f:
                lines = f.readlines()
                xmins = []
                xmaxs = []
                ymins = []
                ymaxs = []
                class_names = []
                classes = []
                height = img.shape[0]
                width = img.shape[1]
                #print(height,width)
                for line in lines:
                    # Label line: class_id x_center y_center w h. The values are divided by the
                    # image size here, i.e. treated as pixels; drop the division if your labels
                    # are already normalized (standard YOLO format)
                    class_id, x_center, y_center, w, h = map(float, line.split())
                    xmin = (x_center - w / 2) / width
                    xmax = (x_center + w / 2) / width
                    ymin = (y_center - h / 2) / height
                    ymax = (y_center + h / 2) / height
                    # Skip boxes that extend beyond the image (they make TF raise an error,
                    # see tensorflow/models issue 5474); validate before appending so the
                    # invalid box is really excluded
                    if xmin < 0 or xmax > 1 or ymin < 0 or ymax > 1:
                        print(f"[WARNING] Skipping box {line.strip()} in {label_path} "
                              f"(h={height}, w={width}): xmin={xmin}, xmax={xmax}, ymin={ymin}, ymax={ymax}")
                        skipped += 1
                        continue
                    classes.append(int(class_id) + 1)  # label map ids start at 1
                    class_names.append("hand")
                    xmins.append(xmin)
                    xmaxs.append(xmax)
                    ymins.append(ymin)
                    ymaxs.append(ymax)

                print("skipped", skipped)
                example = {'height': img.shape[0],
                           'width': img.shape[1],
                           'filename': os.path.basename(imagename),
                           'key': key,
                           'image': encoded_jpg,
                           'xmins': xmins,
                           'xmaxs': xmaxs,
                           'ymins': ymins,
                           'ymaxs': ymaxs,
                           'class_names': class_names,
                           'classes': classes}

                try:
                    tf_example = create_tf_example(example)
                    if tf_example:
                        shard_idx = idx % num_shards
                        output_tfrecords[shard_idx].write(tf_example.SerializeToString())
                except ValueError:
                    logging.warning('Invalid example: %s, ignoring.', label_path)



def main(_):
    data_dir = FLAGS.data_dir
    logging.info('Reading from hand train dataset.')
    create_tf_record(FLAGS.output_path, num_shards=8, data_dir=data_dir, data_type='train')
    logging.info('train dataset done')

    logging.info('Reading from hand val dataset.')
    create_tf_record(FLAGS.output_path, num_shards=4, data_dir=data_dir, data_type='val')
    logging.info('val dataset done')



if __name__ == '__main__':
  tf.app.run()
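main() writes 8 train shards and 4 val shards, and each shard file gets a numbered suffix. The sketch below reproduces the naming that tf_record_creation_util.open_sharded_output_tfrecords is assumed to use ('{base}-{index:05d}-of-{count:05d}'; double-check against your copy of that file), plus a glob pattern that can go into input_path in pipeline.config, the way the sample configs reference sharded records. Both function names are hypothetical.

```python
def shard_filenames(base_path, num_shards):
    """Expected shard file names, e.g. train.record-00000-of-00008 (assumed format)."""
    return ["{}-{:05d}-of-{:05d}".format(base_path, idx, num_shards)
            for idx in range(num_shards)]


def shard_glob(base_path, num_shards):
    """Glob pattern matching every shard, for input_path in pipeline.config."""
    return "{}-?????-of-{:05d}".format(base_path, num_shards)
```

For the train split above, `shard_glob("<output_path>/train.record", 8)` gives `<output_path>/train.record-?????-of-00008`, where `<output_path>` stands for the value passed to --output_path.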

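3) Remove invalid data

Some labels are invalid: their bounding boxes extend beyond the image, which makes the program fail, so these entries have to be filtered out.

Related issue: https://github.com/tensorflow/models/issues/5474

This code is already part of the script in 2); here it is again on its own: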

# Excerpt from create_tf_record() in 2): lines, width, height, label_path, skipped and the
# xmins/xmaxs/ymins/ymaxs/classes/class_names lists come from the surrounding code
for line in lines:
    class_id, x_center, y_center, w, h = map(float, line.split())
    xmin = (x_center - w / 2) / width
    xmax = (x_center + w / 2) / width
    ymin = (y_center - h / 2) / height
    ymax = (y_center + h / 2) / height
    # Validate before appending so that invalid boxes are actually excluded
    if xmin < 0 or xmax > 1 or ymin < 0 or ymax > 1:
        print(f"[WARNING] Skipping box {line.strip()} in {label_path} (h={height}, w={width})")
        skipped += 1
        continue
    classes.append(int(class_id) + 1)  # label map ids start at 1
    class_names.append("hand")
    xmins.append(xmin)
    xmaxs.append(xmax)
    ymins.append(ymin)
    ymaxs.append(ymax)

print("skipped", skipped)

The screening found more than 200 invalid entries.
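To check a single label file in isolation, the same rule can be written as a pure function. This is a minimal sketch; `parse_label_lines` is a hypothetical helper, and like the script above it assumes pixel-unit labels (standard YOLO labels are already normalized, in which case the division by the image size should be dropped).

```python
def parse_label_lines(lines, width, height):
    """Convert label lines to normalized corner boxes, dropping boxes outside the image.

    Each line is 'class_id x_center y_center w h' (assumed pixel units).
    Returns (boxes, skipped), each box being (class_id + 1, xmin, ymin, xmax, ymax).
    """
    boxes, skipped = [], 0
    for line in lines:
        if not line.strip():
            continue  # ignore blank lines
        class_id, xc, yc, w, h = map(float, line.split())
        xmin, xmax = (xc - w / 2) / width, (xc + w / 2) / width
        ymin, ymax = (yc - h / 2) / height, (yc + h / 2) / height
        # Validate before keeping the box, so an invalid box never reaches the output
        if xmin < 0 or ymin < 0 or xmax > 1 or ymax > 1:
            skipped += 1
            continue
        boxes.append((int(class_id) + 1, xmin, ymin, xmax, ymax))
    return boxes, skipped
```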

4. Training and evaluation

I will expand on this part later. The training command I used:

python model_main.py --pipeline_config_path=/home/mi/cjl/models/research/object_detection/hand.config --model_dir=/home/mi/cjl/models/ --num_train_steps=20000 --sample_1_of_n_eval_examples=1 --alsologtostderr

5. Model conversion

First convert the ckpt checkpoint into a .pb graph with the script that ships with the API:

python object_detection/export_tflite_ssd_graph.py --pipeline_config_path=~/research/object_detection/hand.config --trained_checkpoint_prefix=~/models/model.ckpt-10000 --output_directory=~/models/save/ --add_postprocessing_op=true

Note that you should not add a file extension to the checkpoint prefix: model.ckpt-10000 refers to the generated .index, .meta and .data-00000-of-00001 files together.
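Since the flag takes a prefix rather than a file, a quick pre-flight check can confirm that the files behind it exist before running the export. A minimal sketch; the helper names are hypothetical, and it assumes a single data shard (data-00000-of-00001), as in this run.

```python
import os


def checkpoint_files(prefix):
    """Files that make up a TF1 checkpoint referenced by `prefix` (no extension)."""
    return [prefix + ".index", prefix + ".meta", prefix + ".data-00000-of-00001"]


def missing_checkpoint_files(prefix):
    """Return the expected files that do not exist; an empty list means the prefix is usable."""
    prefix = os.path.expanduser(prefix)
    return [p for p in checkpoint_files(prefix) if not os.path.exists(p)]
```

For example, `missing_checkpoint_files("~/models/model.ckpt-10000")` should return an empty list before export_tflite_ssd_graph.py is run with that prefix.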

This generates tflite_graph.pb and tflite_graph.pbtxt. For lightweight mobile use, the graph still has to be converted to TFLite:

import tensorflow as tf
import numpy as np


if __name__ == "__main__":
    # convert uint8 model
    path_to_frozen_graphdef_pb = 'tflite_graph.pb'
    converter = tf.lite.TFLiteConverter.from_frozen_graph(path_to_frozen_graphdef_pb,
                                                          ["normalized_input_image_tensor"],
                                                          [
                                                            'BoxPredictor_0/BoxEncodingPredictor/act_quant/FakeQuantWithMinMaxVars',
                                                            'BoxPredictor_1/BoxEncodingPredictor/act_quant/FakeQuantWithMinMaxVars',
                                                            'BoxPredictor_2/BoxEncodingPredictor/act_quant/FakeQuantWithMinMaxVars',
                                                            'BoxPredictor_3/BoxEncodingPredictor/act_quant/FakeQuantWithMinMaxVars',
                                                            'BoxPredictor_4/BoxEncodingPredictor/act_quant/FakeQuantWithMinMaxVars',
                                                            'BoxPredictor_5/BoxEncodingPredictor/act_quant/FakeQuantWithMinMaxVars',
                                                            'BoxPredictor_0/ClassPredictor/act_quant/FakeQuantWithMinMaxVars',
                                                            'BoxPredictor_1/ClassPredictor/act_quant/FakeQuantWithMinMaxVars',
                                                            'BoxPredictor_2/ClassPredictor/act_quant/FakeQuantWithMinMaxVars',
                                                            'BoxPredictor_3/ClassPredictor/act_quant/FakeQuantWithMinMaxVars',
                                                            'BoxPredictor_4/ClassPredictor/act_quant/FakeQuantWithMinMaxVars',
                                                            'BoxPredictor_5/ClassPredictor/act_quant/FakeQuantWithMinMaxVars'
                                                            
                                                            # 'hand_landmark/handness_identity',
                                                          ],
                                                          input_shapes={"normalized_input_image_tensor":[1, 300, 300, 3]})
    converter.inference_type = tf.lite.constants.QUANTIZED_UINT8
    converter.optimizations = ['DEFAULT']  # the tf.lite docs spell this [tf.lite.Optimize.DEFAULT]
    converter.quantized_input_stats = {"normalized_input_image_tensor": (128, 128)}  # (mean, std)
    converter.allow_custom_ops = True
    converter.default_ranges_stats = (0, 255)
    converter.change_concat_input_ranges = False
    # # converter.post_training_quantize = True
    # # converter.representative_dataset = representative_dataset
    #
    tflite_model = converter.convert()
    with open("ssd.tflite", 'wb') as f:
        f.write(tflite_model)
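converter.quantized_input_stats gives (mean, std) for the uint8 input, and the TFLite converter documents the mapping as real_value = (quantized_value - mean) / std. With (128, 128), pixel values 0 to 255 map to -1.0 to 0.9921875, roughly the [-1, 1] range that SSD MobileNet preprocessing produces. The sketch below just spells out that arithmetic; the function names are hypothetical.

```python
def dequantize_input(q, mean=128.0, std=128.0):
    """Real value the model sees for a uint8 input code q: (q - mean) / std."""
    return (q - mean) / std


def quantize_input(real, mean=128.0, std=128.0):
    """Inverse mapping: real value -> nearest uint8 code, clamped to [0, 255]."""
    return max(0, min(255, round(real * std + mean)))
```

A pixel value of 0 becomes -1.0, 128 becomes 0.0 and 255 becomes 0.9921875, so a camera frame can be fed to the uint8 model as raw pixels.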

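This produces a TFLite model of only about 600 KB, which meets all the requirements.

I will keep working on improving the results.

5. Related learning resources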

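Evaluation metrics:
目标检测评价指标 (Object detection evaluation metrics) - Zhihu

Algorithm introductions:

你一定从未看过如此通俗易懂的YOLO系列(从v1到v5)模型解读 (上) (An easy-to-follow walkthrough of the YOLO series, v1 to v5, part 1) - Zhihu

GitHub - luanshiyinyang/YOLO: an introduction to the YOLO object detection algorithms.
极市开发者平台 (Jishi developer platform for computer vision algorithm development and deployment)
ShowMeAI knowledge community

【深度学习】目标检测算法 YOLO 最耐心细致的讲解 ([Deep learning] A patient, detailed explanation of the YOLO object detection algorithm) - frank909's blog, CSDN

Other:

GitHub - dog-qiuqiu/YOLOv5_NCNN: mobile deployment of YOLOv5s, YOLOv4-tiny, MobileNetV2-YOLOv3-nano, Simple-Pose and Yolact on iOS and Android with the NCNN framework.

GitHub - AlexeyAB/darknet: YOLOv4 / Scaled-YOLOv4 / YOLO - Neural Networks for Object Detection (Windows and Linux version of Darknet)

YOLO: Real-Time Object Detection

10 | 手部实时检测器 (Real-time hand detector) · GitBook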

Reposted from www.mshxw.com (please credit the source when reposting).
Article URL: https://www.mshxw.com/it/1018037.html