Training PointRend Instance Segmentation on Your Own COCO Dataset

Tags: detectron2  pytorch  PointRend

       I've recently been working on applying PointRend, the algorithm proposed by Kaiming He's team, to my own image segmentation task. Facebook's open-source detection framework Detectron2 has been available for a while, but I've only just started using it, and implementing PointRend is how I'm getting to know the framework. This post focuses on how to train on your own COCO-format dataset.

1. Install Detectron2

       There are plenty of installation guides online, and I've also written my own post on installing and testing Detectron2; feel free to refer to it. Once the installation succeeds, the next step is preparing your dataset.
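For reference, assuming PyTorch and a matching CUDA toolkit are already installed, the install is typically just the following (check detectron2's official instructions for the combination matching your environment; versions drift over time):

git clone https://github.com/facebookresearch/detectron2.git
python -m pip install -e detectron2

After this, the PointRend project lives under detectron2/projects/PointRend.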

2. Prepare Your Own COCO Dataset

     ① I used labelme to annotate my images and export a JSON file per image. That part isn't the focus here, so I'll skip the details.

     ② Split the JSON files into train and val folders, then convert each folder into a single train.json and val.json. You can use the conversion script below, or download one from GitHub.

# -*- coding: utf-8 -*-
"""
Created on Fri May 15 14:27:14 2020

@author: Administrator
"""
import argparse
import json
import matplotlib.pyplot as plt
import skimage.io as io
import cv2
from labelme import utils
import numpy as np
import glob
import PIL.Image
import PIL.ImageDraw  # used by polygons_to_mask below; missing from the original script
 
class MyEncoder(json.JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        elif isinstance(obj, np.floating):
            return float(obj)
        elif isinstance(obj, np.ndarray):
            return obj.tolist()
        else:
            return super(MyEncoder, self).default(obj)
 
class labelme2coco(object):
    def __init__(self, labelme_json=[], save_json_path='./train.json'):
        '''
        :param labelme_json: list of paths to all the labelme JSON files
        :param save_json_path: where to save the generated COCO JSON
        '''
        self.labelme_json = labelme_json
        self.save_json_path = save_json_path
        self.images = []
        self.categories = []
        self.annotations = []
        # self.data_coco = {}
        self.label = []
        self.annID = 1
        self.height = 0
        self.width = 0
 
        self.save_json()
 
    def data_transfer(self):

        for num, json_file in enumerate(self.labelme_json):
            with open(json_file, 'r') as fp:
                data = json.load(fp)  # load the labelme JSON file
                self.images.append(self.image(data, num))
                for shapes in data['shapes']:
                    label = shapes['label']
                    if label not in self.label:
                        self.categories.append(self.categorie(label))
                        self.label.append(label)
                    points = shapes['points']  # rectangle annotations give only two points and would need to be expanded to four
                    #points.append([points[0][0],points[1][1]])
                    #points.append([points[1][0],points[0][1]])
                    self.annotations.append(self.annotation(points, label, num))
                    self.annID += 1
 
    def image(self, data, num):
        image = {}
        img = utils.img_b64_to_arr(data['imageData'])  # decode the image embedded in the labelme JSON
        # img = io.imread(data['imagePath'])  # or open the image via its path
        # img = cv2.imread(data['imagePath'], 0)
        height, width = img.shape[:2]
        img = None
        image['height'] = height
        image['width'] = width
        image['id'] = num + 1
        #image['file_name'] = data['imagePath'].split('/')[-1]

        image['file_name'] = data['imagePath']  # adjust this to match your actual image file names
        self.height = height
        self.width = width

        return image
 
    def categorie(self, label):
        categorie = {}
        categorie['supercategory'] = 'Cancer'
        categorie['id'] = len(self.label) + 1  # 0 is reserved for the background
        categorie['name'] = label
        return categorie
 
    def annotation(self, points, label, num):
        annotation = {}
        annotation['segmentation'] = [list(np.asarray(points).flatten())]
        annotation['iscrowd'] = 0
        annotation['image_id'] = num + 1
        # annotation['bbox'] = str(self.getbbox(points))  # saving as a str because a list raised a JSON error (reason unknown)
        # list(map(int, a[1:-1].split(',')))  with a = annotation['bbox'] converts it back to a list
        annotation['bbox'] = list(map(float, self.getbbox(points)))
        annotation['area'] = annotation['bbox'][2] * annotation['bbox'][3]
        # annotation['category_id'] = self.getcatid(label)
        annotation['category_id'] = self.getcatid(label)  # note: the original script hard-coded this to 1; see the category_id fix below
        annotation['id'] = self.annID
        return annotation
 
    def getcatid(self, label):
        for categorie in self.categories:
            if label == categorie['name']:
                return categorie['id']
        return 1
 
    def getbbox(self, points):
        # img = np.zeros([self.height, self.width], np.uint8)
        # cv2.polylines(img, [np.asarray(points)], True, 1, lineType=cv2.LINE_AA)  # draw the polygon outline
        # cv2.fillPoly(img, [np.asarray(points)], 1)  # fill the polygon interior with 1
        polygons = points
 
        mask = self.polygons_to_mask([self.height, self.width], polygons)
        return self.mask2box(mask)
 
    def mask2box(self, mask):
        '''Compute the bounding box from a mask.
        mask: [h, w] array of 0s and 1s
        1 marks the object; the min/max row and column indices of the 1s
        give the top-left and bottom-right corners of the box.
        '''
        index = np.argwhere(mask == 1)
        rows = index[:, 0]
        cols = index[:, 1]
        # top-left corner
        left_top_r = np.min(rows)  # y
        left_top_c = np.min(cols)  # x

        # bottom-right corner
        right_bottom_r = np.max(rows)
        right_bottom_c = np.max(cols)

        # return [(left_top_r, left_top_c), (right_bottom_r, right_bottom_c)]
        # return [(left_top_c, left_top_r), (right_bottom_c, right_bottom_r)]
        # return [left_top_c, left_top_r, right_bottom_c, right_bottom_r]  # [x1, y1, x2, y2]
        return [left_top_c, left_top_r, right_bottom_c - left_top_c,
                right_bottom_r - left_top_r]  # [x, y, w, h], COCO's bbox format
 
    def polygons_to_mask(self, img_shape, polygons):
        mask = np.zeros(img_shape, dtype=np.uint8)
        mask = PIL.Image.fromarray(mask)
        xy = list(map(tuple, polygons))
        PIL.ImageDraw.Draw(mask).polygon(xy=xy, outline=1, fill=1)
        mask = np.array(mask, dtype=bool)
        return mask
 
    def data2coco(self):
        data_coco = {}
        data_coco['images'] = self.images
        data_coco['categories'] = self.categories
        data_coco['annotations'] = self.annotations
        return data_coco
 
    def save_json(self):
        self.data_transfer()
        self.data_coco = self.data2coco()
        # write the COCO-format JSON to disk
        json.dump(self.data_coco, open(self.save_json_path, 'w'), indent=4, cls=MyEncoder)  # indent=4 for readable output
 
 
labelme_json = glob.glob(r'I:/timage/val/json/*.json')
# labelme_json=['./Annotations/*.json']
 
labelme2coco(labelme_json, 'I:/image//val/footdeep_val.json')

I ran this script myself and the conversion results were all correct. One thing to watch out for: while working on another project I had renamed my images, so the imagePath recorded in the JSON files no longer matched the image file names, and training kept erroring out. Let this be a reminder to always keep an untouched copy of your original annotated data. Another important point: labelme does not annotate the background, so in the code above you must change annotation['category_id'] = self.getcatid(label) to annotation['category_id'] = self.getcatid(label) - 1. That way category_id in the output JSON starts from 0, which avoids errors at test time caused by the unannotated background. This is a real pitfall.
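For clarity, here is what the modified annotation() method looks like with that one-line fix applied (just restating the change described above):

    def annotation(self, points, label, num):
        annotation = {}
        annotation['segmentation'] = [list(np.asarray(points).flatten())]
        annotation['iscrowd'] = 0
        annotation['image_id'] = num + 1
        annotation['bbox'] = list(map(float, self.getbbox(points)))
        annotation['area'] = annotation['bbox'][2] * annotation['bbox'][3]
        # subtract 1 so category_id starts from 0: labelme never annotates
        # the background, so there is no background category in the JSON
        annotation['category_id'] = self.getcatid(label) - 1
        annotation['id'] = self.annID
        return annotation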

The final directory structure looks like this:

footdata
    --foot
       footdeep_train.json
       footdeep_val.json
    train
       PNG images
    val
       PNG images

My dataset has only one class, so it's a fairly simple case.
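Before moving on to training, it's worth sanity-checking the converted file with pycocotools (a quick verification sketch; the file name follows my layout above):

from pycocotools.coco import COCO

coco = COCO('footdata/foot/footdeep_val.json')
print(coco.dataset['categories'])            # check category ids and names
img_ids = coco.getImgIds()
print(len(img_ids), 'images')
ann_ids = coco.getAnnIds(imgIds=img_ids[0])  # annotations of the first image
for ann in coco.loadAnns(ann_ids):
    print(ann['category_id'], ann['bbox'], ann['area'])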

3. Train with the API

Training is basically the same as with any other Detectron2 model. Here's the code, which I saved as train_foot.py:

# -*- coding: utf-8 -*-
"""
Created on Thu Jul  2 16:26:02 2020

@author: CTZN
"""
import os
import torch

import detectron2.utils.comm as comm
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.config import get_cfg
from detectron2.data import MetadataCatalog
from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, launch
from detectron2.evaluation import (
    CityscapesEvaluator,
    COCOEvaluator,
    DatasetEvaluators,
    LVISEvaluator,
    verify_results,
)

from point_rend import add_pointrend_config

from detectron2.data import DatasetCatalog
from detectron2.data.datasets.coco import load_coco_json
import pycocotools
#Declare the class names; keep them consistent with your annotations
CLASS_NAMES =["__background__","foot"]

DATASET_ROOT = '/root/detectron2/projects/PointRend/footdata'
ANN_ROOT = os.path.join(DATASET_ROOT, 'foot')

TRAIN_PATH = os.path.join(DATASET_ROOT, 'train')
VAL_PATH = os.path.join(DATASET_ROOT, 'val')

TRAIN_JSON = os.path.join(ANN_ROOT, 'footdeep_train.json')
#VAL_JSON = os.path.join(ANN_ROOT, 'val.json')
VAL_JSON = os.path.join(ANN_ROOT, 'footdeep_val.json')

# Declare the dataset splits
PREDEFINED_SPLITS_DATASET = {
    "coco_my_train": (TRAIN_PATH, TRAIN_JSON),
    "coco_my_val": (VAL_PATH, VAL_JSON),
}

#Register the dataset (this step makes the custom dataset known to Detectron2)
def register_dataset():
    """
    purpose: register all splits of dataset with PREDEFINED_SPLITS_DATASET
    """
    for key, (image_root, json_file) in PREDEFINED_SPLITS_DATASET.items():
        register_dataset_instances(name=key,
                                   json_file=json_file,
                                   image_root=image_root)


def register_dataset_instances(name, json_file, image_root):
    """
    purpose: register dataset to DatasetCatalog,
             register metadata to MetadataCatalog and set attribute
    """
    DatasetCatalog.register(name, lambda: load_coco_json(json_file, image_root, name))
    MetadataCatalog.get(name).set(json_file=json_file,
                                  image_root=image_root,
                                  evaluator_type="coco")

# Register dataset and metadata directly
def plain_register_dataset():
    # training set
    DatasetCatalog.register("coco_my_train", lambda: load_coco_json(TRAIN_JSON, TRAIN_PATH))
    MetadataCatalog.get("coco_my_train").set(thing_classes=CLASS_NAMES,  # optional, but note that non-ASCII class names won't display; avoid Chinese labels
                                                    evaluator_type='coco', # evaluation type
                                                    json_file=TRAIN_JSON,
                                                    image_root=TRAIN_PATH)


    # validation/test set
    DatasetCatalog.register("coco_my_val", lambda: load_coco_json(VAL_JSON, VAL_PATH))
    MetadataCatalog.get("coco_my_val").set(thing_classes=CLASS_NAMES, # optional, but note that non-ASCII class names won't display; avoid Chinese labels
                                                evaluator_type='coco', # evaluation type
                                                json_file=VAL_JSON,
                                                image_root=VAL_PATH)


class Trainer(DefaultTrainer):
    """
    We use the "DefaultTrainer" which contains a number pre-defined logic for
    standard training workflow. They may not work for you, especially if you
    are working on a new research project. In that case you can use the cleaner
    "SimpleTrainer", or write your own training loop.
    """

    @classmethod
    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
        """
        Create evaluator(s) for a given dataset.
        This uses the special metadata "evaluator_type" associated with each builtin dataset.
        For your own dataset, you can simply create an evaluator manually in your
        script and do not have to worry about the hacky if-else logic here.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluator_list = []
        evaluator_type = MetadataCatalog.get(dataset_name).evaluator_type
        if evaluator_type == "lvis":
            return LVISEvaluator(dataset_name, cfg, True, output_folder)
        if evaluator_type == "coco":
            return COCOEvaluator(dataset_name, cfg, True, output_folder)
        if evaluator_type == "cityscapes":
            assert (
                torch.cuda.device_count() >= comm.get_rank()
            ), "CityscapesEvaluator currently do not work with multiple machines."
            return CityscapesEvaluator(dataset_name)
        if len(evaluator_list) == 0:
            raise NotImplementedError(
                "no Evaluator for the dataset {} with the type {}".format(
                    dataset_name, evaluator_type
                )
            )
        if len(evaluator_list) == 1:
            return evaluator_list[0]
        return DatasetEvaluators(evaluator_list)
    
def setup(args):
    """
    Create configs and perform basic setups.
    """
    cfg = get_cfg()
    add_pointrend_config(cfg)
    args.config_file = "/root/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml"
    cfg.merge_from_file(args.config_file)   # override settings from the config file
    cfg.merge_from_list(args.opts)          # override settings from CLI arguments

    # Override config parameters
    cfg.DATASETS.TRAIN = ("coco_my_train",) # name of the training dataset
    cfg.DATASETS.TEST = ("coco_my_val",)
    cfg.DATALOADER.NUM_WORKERS = 4  # number of data-loading workers

#    cfg.INPUT.CROP.ENABLED = True
    cfg.INPUT.MAX_SIZE_TRAIN = 640 # maximum input size for training images
    cfg.INPUT.MAX_SIZE_TEST = 640 # maximum input size at test time
    cfg.INPUT.MIN_SIZE_TRAIN = (512, 768) # minimum input size for training; can be set up for multi-scale training
    cfg.INPUT.MIN_SIZE_TEST = 640
    # cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING has two modes, 'choice' and 'range':
    # 'range' samples the short edge uniformly from 512-768
    # 'choice' restricts the short edge to a fixed set of sizes, i.e. exactly 512 or 768
    cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING = 'range'

    cfg.MODEL.RETINANET.NUM_CLASSES = 2  # number of classes + 1 (background); note that for the R-CNN-based PointRend config it is cfg.MODEL.ROI_HEADS.NUM_CLASSES that actually takes effect
    #cfg.MODEL.WEIGHTS="/home/yourstorePath/.pth"
    cfg.MODEL.WEIGHTS = "/root/detectron2/projects/PointRend/models/model_final_3c3198.pkl"    # pretrained weights
    cfg.SOLVER.IMS_PER_BATCH = 4  # batch size; iters_in_one_epoch = dataset_imgs / batch_size

    # Compute the iterations per epoch from the total number of training images and the batch size
    # 9000 is the total number of training images; set it to match your dataset
    ITERS_IN_ONE_EPOCH = int(9000 / cfg.SOLVER.IMS_PER_BATCH)

    # Maximum number of iterations
    cfg.SOLVER.MAX_ITER = (ITERS_IN_ONE_EPOCH * 12) - 1 # 12 epochs
    # Base learning rate
    cfg.SOLVER.BASE_LR = 0.002
    # Optimizer momentum
    cfg.SOLVER.MOMENTUM = 0.9
    # Weight decay
    cfg.SOLVER.WEIGHT_DECAY = 0.0001
    cfg.SOLVER.WEIGHT_DECAY_NORM = 0.0
    # Learning-rate decay factor
    cfg.SOLVER.GAMMA = 0.1
    # Iterations at which the learning rate is decayed
    cfg.SOLVER.STEPS = (7000,)
    # Warmup: the learning rate ramps up gradually to the base LR before training proper
    cfg.SOLVER.WARMUP_FACTOR = 1.0 / 1000
    # Number of warmup iterations
    cfg.SOLVER.WARMUP_ITERS = 1000

    cfg.SOLVER.WARMUP_METHOD = "linear"
    # Save a checkpoint once per epoch
    cfg.SOLVER.CHECKPOINT_PERIOD = ITERS_IN_ONE_EPOCH - 1

    # Run an evaluation every EVAL_PERIOD iterations
    cfg.TEST.EVAL_PERIOD = ITERS_IN_ONE_EPOCH
    #cfg.TEST.EVAL_PERIOD = 100

    #cfg.merge_from_file(args.config_file)
    #cfg.merge_from_list(args.opts)
    cfg.freeze()
    default_setup(cfg, args)
    return cfg
    
def main(args):
    plain_register_dataset()
    cfg = setup(args)
    if args.eval_only:
        model = Trainer.build_model(cfg)
        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
            cfg.MODEL.WEIGHTS, resume=args.resume
        )
        res = Trainer.test(cfg, model)
        if comm.is_main_process():
            verify_results(cfg, res)
        return res

    trainer = Trainer(cfg)
    trainer.resume_or_load(resume=args.resume)
    return trainer.train()

if __name__ == "__main__":
    args = default_argument_parser().parse_args()
    print("Command Line Args:", args)
    launch(
        main,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )

I adapted this from another blogger's post on Detectron2's training API, which helped me solve quite a few problems and is worth consulting if you get stuck. If your dataset was annotated with labelme (and converted with the category_id - 1 fix above), change CLASS_NAMES = ["__background__", "foot"] to CLASS_NAMES = ["foot"] so that testing doesn't error out.
 

The script above defines three ways of registering a dataset:
    register_dataset()
    register_dataset_instances()
    plain_register_dataset()
In my experiments only the last one worked; with the first one I kept getting errors because my dataset's classes were never registered, possibly a detectron2 version issue. I recommend using plain_register_dataset().
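Whichever variant you use, you can confirm the registration took effect before launching a long training run by loading the dataset once yourself, e.g. right after calling plain_register_dataset():

from detectron2.data import DatasetCatalog, MetadataCatalog

dataset_dicts = DatasetCatalog.get("coco_my_train")   # triggers load_coco_json
print(len(dataset_dicts), "training samples")
print(MetadataCatalog.get("coco_my_train").thing_classes)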

Then simply run python train_foot.py.
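Since train_foot.py uses detectron2's default_argument_parser, the usual launch flags are available too; for example (single-GPU machine assumed, and output/model_final.pth is just the default checkpoint location):

python train_foot.py --num-gpus 1
python train_foot.py --num-gpus 1 --eval-only MODEL.WEIGHTS output/model_final.pth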

Training output:

[07/03 17:20:32 d2.data.build]: Using training sampler TrainingSampler
[07/03 17:20:32 fvcore.common.checkpoint]: Loading checkpoint from /root/detectron2/projects/PointRend/models/model_final_3c3198.pkl
[07/03 17:20:32 fvcore.common.checkpoint]: Reading a file from 'Detectron2 Model Zoo'
[07/03 17:20:33 d2.engine.train_loop]: Starting training from iteration 0
====> <torch.utils.data.dataloader.DataLoader object at 0x7f7983d24b00>
[07/03 17:21:04 d2.utils.events]:  eta: 11:56:26  iter: 19  total_loss: 0.538  loss_cls: 0.529  loss_box_reg: 0.000  loss_mask: 0.000  loss_mask_point: 0.000  loss_rpn_cls: 0.007  loss_rpn_loc: 0.003  time: 1.5912  data_time: 0.0224  lr: 0.000040  max_mem: 2850M
[07/03 17:21:37 d2.utils.events]:  eta: 12:05:00  iter: 39  total_loss: 0.156  loss_cls: 0.148  loss_box_reg: 0.000  loss_mask: 0.000  loss_mask_point: 0.000  loss_rpn_cls: 0.006  loss_rpn_loc: 0.003  time: 1.6116  data_time: 0.0156  lr: 0.000080  max_mem: 2866M
[07/03 17:22:09 d2.utils.events]:  eta: 12:05:53  iter: 59  total_loss: 0.054  loss_cls: 0.044  loss_box_reg: 0.000  loss_mask: 0.000  loss_mask_point: 0.000  loss_rpn_cls: 0.007  loss_rpn_loc: 0.002  time: 1.6162  data_time: 0.0148  lr: 0.000120  max_mem: 2866M
[07/03 17:22:42 d2.utils.events]:  eta: 12:07:30  iter: 79  total_loss: 0.025  loss_cls: 0.016  loss_box_reg: 0.000  loss_mask: 0.000  loss_mask_point: 0.000  loss_rpn_cls: 0.006  loss_rpn_loc: 0.003  time: 1.6198  data_time: 0.0152  lr: 0.000160  max_mem: 2866M
[07/03 17:23:15 d2.utils.events]:  eta: 12:08:44  iter: 99  total_loss: 0.016  loss_cls: 0.009  loss_box_reg: 0.000  loss_mask: 0.000  loss_mask_point: 0.000  loss_rpn_cls: 0.005  loss_rpn_loc: 0.003  time: 1.6235  data_time: 0.0148  lr: 0.000200  max_mem: 2866M
[07/03 17:23:47 d2.utils.events]:  eta: 12:07:41  iter: 119  total_loss: 0.016  loss_cls: 0.007  loss_box_reg: 0.000  loss_mask: 0.000  loss_mask_point: 0.000  loss_rpn_cls: 0.005  loss_rpn_loc: 0.003  time: 1.6227  data_time: 0.0152  lr: 0.000240  max_mem: 2866M
[07/03 17:24:19 d2.utils.events]:  eta: 12:07:03  iter: 139  total_loss: 0.011  loss_cls: 0.004  loss_box_reg: 0.000  loss_mask: 0.000  loss_mask_point: 0.000  loss_rpn_cls: 0.004  loss_rpn_loc: 0.002  time: 1.6233  data_time: 0.0149  lr: 0.000280  max_mem: 2866M
[07/03 17:24:52 d2.utils.events]:  eta: 12:06:25  iter: 159  total_loss: 0.011  loss_cls: 0.003  loss_box_reg: 0.000  loss_mask: 0.000

4. Notes and Common Errors

Error 1: CUDA error: device-side assert triggered

This usually means the number of classes in your dataset doesn't match the number of classes the network outputs; check and fix the class-count settings.
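For a one-class labelme dataset (with the category_id - 1 fix from section 2), a consistent set of class counts would look something like the sketch below; which key actually takes effect depends on the architecture in your config file, and for the R-CNN-based PointRend config it is ROI_HEADS:

CLASS_NAMES = ["foot"]                  # one foreground class, no explicit background
cfg.MODEL.ROI_HEADS.NUM_CLASSES = 1     # used by R-CNN-style heads (PointRend configs)
cfg.MODEL.POINT_HEAD.NUM_CLASSES = 1    # PointRend's point head must match
cfg.MODEL.RETINANET.NUM_CLASSES = 1     # only relevant for RetinaNet configs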

Error 2: KeyError: 'Non-existent config key: MODEL.ROI_MASK_HEAD.FC_DIM'

This means ROI_MASK_HEAD.FC_DIM is not registered among detectron2's config keys, so you need to register it. If you've registered it and still get the same kind of error for other keys, PointRend's add_pointrend_config was not imported; import it and call it during setup. Since some readers haven't got this working, I'm pasting my add_pointrend_config below. And if add_pointrend_config is present in your PointRend project but the error persists, register the missing keys in detectron2/detectron2/config/defaults.py as well; that resolves it.

from detectron2.config import CfgNode as CN


def add_pointrend_config(cfg):
    """
    Add config for PointRend.
    """
    # We retry random cropping until no single category in semantic segmentation GT occupies more
    # than `SINGLE_CATEGORY_MAX_AREA` part of the crop.
    cfg.INPUT.CROP.SINGLE_CATEGORY_MAX_AREA = 1.0
    # Color augmentation from SSD paper for semantic segmentation model during training.
    cfg.INPUT.COLOR_AUG_SSD = False

    # Names of the input feature maps to be used by a coarse mask head.
    cfg.MODEL.ROI_MASK_HEAD.IN_FEATURES = ("p2",)
    cfg.MODEL.ROI_MASK_HEAD.FC_DIM = 1024      # this is usually the key that triggers the error; register here whichever key is reported missing
    cfg.MODEL.ROI_MASK_HEAD.NUM_FC = 2
    # The side size of a coarse mask head prediction.
    cfg.MODEL.ROI_MASK_HEAD.OUTPUT_SIDE_RESOLUTION = 7
    # True if point head is used.
    cfg.MODEL.ROI_MASK_HEAD.POINT_HEAD_ON = False

    cfg.MODEL.POINT_HEAD = CN()
    cfg.MODEL.POINT_HEAD.NAME = "StandardPointHead"
    cfg.MODEL.POINT_HEAD.NUM_CLASSES = 3
    # Names of the input feature maps to be used by a mask point head.
    cfg.MODEL.POINT_HEAD.IN_FEATURES = ("p2",)
    # Number of points sampled during training for a mask point head.
    cfg.MODEL.POINT_HEAD.TRAIN_NUM_POINTS = 14 * 14
    # Oversampling parameter for PointRend point sampling during training. Parameter `k` in the
    # original paper.
    cfg.MODEL.POINT_HEAD.OVERSAMPLE_RATIO = 3
    # Importance sampling parameter for PointRend point sampling during training. Parameter `beta` in
    # the original paper.
    cfg.MODEL.POINT_HEAD.IMPORTANCE_SAMPLE_RATIO = 0.75
    # Number of subdivision steps during inference.
    cfg.MODEL.POINT_HEAD.SUBDIVISION_STEPS = 5
    # Maximum number of points selected at each subdivision step (N).
    cfg.MODEL.POINT_HEAD.SUBDIVISION_NUM_POINTS = 28 * 28
    cfg.MODEL.POINT_HEAD.FC_DIM = 256
    cfg.MODEL.POINT_HEAD.NUM_FC = 3
    cfg.MODEL.POINT_HEAD.CLS_AGNOSTIC_MASK = False
    # If True, then coarse prediction features are used as input for each layer in PointRend's MLP.
    cfg.MODEL.POINT_HEAD.COARSE_PRED_EACH_LAYER = True
    cfg.MODEL.POINT_HEAD.COARSE_SEM_SEG_HEAD_NAME = "SemSegFPNHead"

Error 3: NotImplementedError: Caught NotImplementedError in DataLoader worker process 0.

Traceback (most recent call last):
  File "train_net.py", line 182, in <module>
    args=(args,),
  File "/root/detectron2/detectron2/engine/launch.py", line 54, in launch
    daemon=False,
  File "/root/anaconda3/envs/PointRend/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 171, in spawn
    while not spawn_context.join():
  File "/root/anaconda3/envs/PointRend/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 118, in join
    raise Exception(msg)
Exception:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/root/anaconda3/envs/PointRend/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/root/detectron2/detectron2/engine/launch.py", line 89, in _distributed_worker
    main_func(*args)
  File "/root/detectron2/projects/PointRend/train_net.py", line 170, in main
    return trainer.train()
  File "/root/detectron2/detectron2/engine/defaults.py", line 401, in train
    super().train(self.start_iter, self.max_iter)
  File "/root/detectron2/detectron2/engine/train_loop.py", line 132, in train
    self.run_step()
  File "/root/detectron2/detectron2/engine/train_loop.py", line 209, in run_step
    data = next(self._data_loader_iter)
  File "/root/detectron2/detectron2/data/common.py", line 141, in __iter__
    for d in self.dataset:
  File "/root/anaconda3/envs/PointRend/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 345, in __next__
    data = self._next_data()
  File "/root/anaconda3/envs/PointRend/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 856, in _next_data
    return self._process_data(data)
  File "/root/anaconda3/envs/PointRend/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 881, in _process_data
    data.reraise()
  File "/root/anaconda3/envs/PointRend/lib/python3.6/site-packages/torch/_utils.py", line 394, in reraise
    raise self.exc_type(msg)
NotImplementedError: Caught NotImplementedError in DataLoader worker process 0.

The fix is to comment out cfg.INPUT.CROP.ENABLED = True in setup(), or set it to False: this dataset doesn't need crop augmentation. See detectron2/config/defaults.py for the related options.

Update

       Some readers asked in the comments about a test demo for PointRend. I ran mine from detectron2's demo.py, so I'll share two versions of the code here for reference.

Option 1:

This one is adapted directly from the stock demo: create a demo.py inside the PointRend project and copy the code below into it.

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import argparse
import glob
import multiprocessing as mp
import os
import time
import cv2
import tqdm

from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from detectron2.utils.logger import setup_logger

from predictor import VisualizationDemo
from point_rend import add_pointrend_config
# constants
WINDOW_NAME = "COCO detections"


def setup_cfg(args):
    # load config from file and command-line arguments
    cfg = get_cfg()
    add_pointrend_config(cfg)
    cfg.merge_from_file(args.config_file)
    cfg.merge_from_list(args.opts)
    # Set score_threshold for builtin models
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = args.confidence_threshold
    cfg.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = args.confidence_threshold
    cfg.freeze()
    return cfg


def get_parser():
    parser = argparse.ArgumentParser(description="Detectron2 demo for builtin models")
    parser.add_argument(
        "--config-file",
        default="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
        metavar="FILE",
        help="path to config file",
    )
    parser.add_argument("--webcam", action="store_true", help="Take inputs from webcam.")
    parser.add_argument("--video-input", help="Path to video file.")
    parser.add_argument(
        "--input",
        nargs="+",
        help="A list of space separated input images; "
        "or a single glob pattern such as 'directory/*.jpg'",
    )
    parser.add_argument(
        "--output",
        help="A file or directory to save output visualizations. "
        "If not given, will show output in an OpenCV window.",
    )

    parser.add_argument(
        "--confidence-threshold",
        type=float,
        default=0.5,
        help="Minimum score for instance predictions to be shown",
    )
    parser.add_argument(
        "--opts",
        help="Modify config options using the command-line 'KEY VALUE' pairs",
        default=[],
        nargs=argparse.REMAINDER,
    )
    return parser


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    args = get_parser().parse_args()
    setup_logger(name="fvcore")
    logger = setup_logger()
    logger.info("Arguments: " + str(args))

    cfg = setup_cfg(args)
    
    demo = VisualizationDemo(cfg)

    if args.input:
        if len(args.input) == 1:
            args.input = glob.glob(os.path.expanduser(args.input[0]))
            assert args.input, "The input path(s) was not found"
        for path in tqdm.tqdm(args.input, disable=not args.output):
            # use PIL, to be consistent with evaluation
            img = read_image(path, format="BGR")
            start_time = time.time()
            predictions, visualized_output = demo.run_on_image(img)
            logger.info(
                "{}: {} in {:.2f}s".format(
                    path,
                    "detected {} instances".format(len(predictions["instances"]))
                    if "instances" in predictions
                    else "finished",
                    time.time() - start_time,
                )
            )

            if args.output:
                if os.path.isdir(args.output):
                    assert os.path.isdir(args.output), args.output
                    out_filename = os.path.join(args.output, os.path.basename(path))
                else:
                    assert len(args.input) == 1, "Please specify a directory with args.output"
                    out_filename = args.output
                visualized_output.save(out_filename)
            else:
                cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
                cv2.imshow(WINDOW_NAME, visualized_output.get_image()[:, :, ::-1])
                if cv2.waitKey(0) == 27:
                    break  # esc to quit
    elif args.webcam:
        assert args.input is None, "Cannot have both --input and --webcam!"
        cam = cv2.VideoCapture(0)
        for vis in tqdm.tqdm(demo.run_on_video(cam)):
            cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
            cv2.imshow(WINDOW_NAME, vis)
            if cv2.waitKey(1) == 27:
                break  # esc to quit
        cv2.destroyAllWindows()
    elif args.video_input:
        video = cv2.VideoCapture(args.video_input)
        width = int(video.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(video.get(cv2.CAP_PROP_FRAME_HEIGHT))
        frames_per_second = video.get(cv2.CAP_PROP_FPS)
        num_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        basename = os.path.basename(args.video_input)

        if args.output:
            if os.path.isdir(args.output):
                output_fname = os.path.join(args.output, basename)
                output_fname = os.path.splitext(output_fname)[0] + ".mkv"
            else:
                output_fname = args.output
            assert not os.path.isfile(output_fname), output_fname
            output_file = cv2.VideoWriter(
                filename=output_fname,
                # some installation of opencv may not support x264 (due to its license),
                # you can try other format (e.g. MPEG)
                fourcc=cv2.VideoWriter_fourcc(*"x264"),
                fps=float(frames_per_second),
                frameSize=(width, height),
                isColor=True,
            )
        assert os.path.isfile(args.video_input)
        for vis_frame in tqdm.tqdm(demo.run_on_video(video), total=num_frames):
            if args.output:
                output_file.write(vis_frame)
            else:
                cv2.namedWindow(basename, cv2.WINDOW_NORMAL)
                cv2.imshow(basename, vis_frame)
                if cv2.waitKey(1) == 27:
                    break  # esc to quit
        video.release()
        if args.output:
            output_file.release()
        else:
            cv2.destroyAllWindows()

Run the test from the command line. Note that --config-file must point to the yaml config file, not the weights; the weights are passed via --opts:

python demo.py --config-file configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml --input test/1.jpg --output test/out/ --opts MODEL.WEIGHTS models/model_final_3c3198.pkl

Option 2:

The above is just basic testing. If you want to test your own model without going through the command line, again create a demo.py, but use the code below:

import os
import cv2
from detectron2.config import get_cfg
from detectron2.data.detection_utils import read_image
from predictor import VisualizationDemo
from point_rend import add_pointrend_config
# constants
WINDOW_NAME = "COCO detections"


def setup_cfg():
    # load config from file and command-line arguments
    cfg = get_cfg()
    add_pointrend_config(cfg)
    cfg.merge_from_file("/root/detectron2/projects/PointRend/configs/InstanceSegmentation/pointrend_rcnn_R_50_FPN_3x_coco.yaml")
    # Set score_threshold for builtin models
    cfg.MODEL.RETINANET.SCORE_THRESH_TEST = 0.5
    cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.5

    cfg.MODEL.WEIGHTS = os.path.join(cfg.OUTPUT_DIR, "model_final.pth")  # path and file name of your trained model
    cfg.freeze()
    return cfg

if __name__ == '__main__':
    cfg = setup_cfg()
    detectron_out = VisualizationDemo(cfg)
    path = r'/root/detectron2/testimage/1.jpg'    # path to the test image
    outpath = r'/root/detectron2/testimage/out/'  # directory where the result is saved
    img = read_image(path, format="BGR")
    predictions, visualized_output = detectron_out.run_on_image(img)
    out_img = visualized_output.get_image()[:, :, ::-1]  # BGR image, in case you want to display it with cv2
    out_path = os.path.join(outpath, os.path.basename(path))
    visualized_output.save(out_path)

Then just run python demo.py. It's essentially the first script with the argument parsing stripped out. One caveat: when testing a model you trained yourself, the output won't show the label names, because the dataset's class metadata isn't registered in this script. It's a simple fix, so I'll leave the details to you; you just need to register your data's labels.
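As a hint for that last point: registering the metadata under the dataset name that cfg.DATASETS.TEST points to, before constructing VisualizationDemo, is enough for the visualizer to pick up the class names. A minimal sketch, assuming the one-class setup from training:

from detectron2.data import MetadataCatalog

# register the class names under the name used at test time
MetadataCatalog.get("coco_my_val").set(thing_classes=["foot"])

# then inside setup_cfg(), before cfg.freeze():
#     cfg.DATASETS.TEST = ("coco_my_val",)
# VisualizationDemo looks up MetadataCatalog.get(cfg.DATASETS.TEST[0]).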

 

Copyright notice: this is an original post by the blogger, released under the CC 4.0 BY-SA license. Please include the original source link and this notice when reposting.
Original link: https://blog.csdn.net/qq_33047753/article/details/107022058
