train是训练集,val是训练过程中的测试集,是为了让你在边训练边看到训练的结果,及时判断学习状态。test就是训练模型结束后,用于评价模型结果的测试集。只有train就可以训练,val不是必须的,比例也可以设置很小。

验证数据集可以理解为训练数据集的一块

制作图书馆数据集代码如下:

### Data Format for Semantic Segmentation

The raw data will be processed by generator shell scripts. There will be two subdirs('train' & 'val')

```

train or val dir {

image: contains the images for train or val.

label: contains the label png files(mode='P') for train or val.

mask: contains the mask png files(mode='P') for train or val.

}

```

"""
 -*- coding: utf-8 -*-
 author: Hao Hu
 @date   2022/1/20 11:02 AM
"""
import cv2
import numpy as np
from matplotlib import pyplot as plt
import os.path as osp
import os
from tqdm import tqdm
from PIL import Image
import PIL
from concurrent.futures import ThreadPoolExecutor
def grab_cut(img_path):
    """使用了grab_cut算法获得物体和背景轮廓"""
    img_ori = cv2.imread(img_path)
    # 将img二值化
    retVal, image = cv2.threshold(img_ori, 50, 100, cv2.THRESH_BINARY)
    mask = np.zeros(image.shape[:2], np.uint8)
    bgdModel = np.zeros((1, 65), np.float64)
    fgdModel = np.zeros((1, 65), np.float64)
    ix = int(img_ori.shape[0] / 22)
    iy = int(img_ori.shape[1] / 20)
    w = iy * 20
    h = ix * 22
    rect = (ix, iy, int(w), int(h))
    # cv2.rectangle(img, (ix*2, iy*3), (int(w*0.9), int(h*0.9)), (0, 255, 0), 2)
    # 默认几个点作为物体和背景像素点
    # (ix*15,iy*26),(ix*21,iy*15),(ix*21,iy*10)为背景像素点
    cv2.circle(mask, (ix*15, iy*26), 15, [0,0,0], -1)
    cv2.grabCut(image, mask, rect, bgdModel, fgdModel, 5, cv2.GC_INIT_WITH_RECT)
    mask2 = np.where((mask == 2) | (mask == 0), 0, 1).astype('uint8')
    mask2[ix * 21, iy * 19] = 1
    #plt.imshow(mask2), plt.colorbar(), plt.show()
    img = image * mask2[:, :, np.newaxis]

    return img,image,mask2,img_ori

def get_mask_box(mask):

    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    contours = list(contours)
    contours.sort(key=lambda x: cv2.contourArea(x), reverse=True)
    cnt = cv2.approxPolyDP(contours[0], epsilon=100, closed=True)
    cnt = cv2.minAreaRect(cnt)
    box = np.int0(cv2.boxPoints(cnt))
    return mask, box


def imwrite_the_label_img(ori_folder,end_folder_path,img_NAME):
    img_path = osp.join(ori_folder,img_NAME)
    img,image,mask,img_ori = grab_cut(img_path)
    _, box=get_mask_box(mask)
    re = cv2.drawContours(image.copy(), [box], 0, (0, 255, 0), -1)

    end_path = osp.join(end_folder_path, img_NAME[:-2]+'.png')
    cv2.imwrite((end_path), re)
    # 将图片转为model = P
    re = PIL.Image.open(end_path)
    re = re.convert('P')
    re.save(end_path)




if __name__ == '__main__':
    ori_folder = '/cloud_disk/users/huh/dataset/lib_dataset/train/image'
    img_list = os.listdir(ori_folder)
    end_folder_path = '/cloud_disk/users/huh/dataset/lib_dataset/train/label'
    executor = ThreadPoolExecutor(max_workers=100)  # 最大线程数量
    for img_NAME in tqdm(img_list):
        executor.map(imwrite_the_label_img(ori_folder,end_folder_path,img_NAME))

Logo

魔乐社区(Modelers.cn) 是一个中立、公益的人工智能社区,提供人工智能工具、模型、数据的托管、展示与应用协同服务,为人工智能开发及爱好者搭建开放的学习交流平台。社区通过理事会方式运作,由全产业链共同建设、共同运营、共同享有,推动国产AI生态繁荣发展。

更多推荐