
[Kaggle] Medical Mask Detection Model: Object Detection

징징알파카 2022. 11. 9. 16:08

<This post was written while studying and following pseudo-lab's Tutorial-Book>

https://pseudo-lab.github.io/Tutorial-Book/chapters/object-detection/Ch1-Object-Detection.html

 


 

 

⛄ Medical Mask Detection Model

Build a model that detects people wearing a mask, people not wearing one, and people wearing one incorrectly.

 

👁‍🗨 Dataset

Contains 853 images belonging to 3 classes, with bounding boxes in PASCAL VOC format

  • Wearing a mask
  • Not wearing a mask
  • Wearing a mask incorrectly

👁‍🗨 The images and annotations folders

  • The images folder holds image files numbered 0 through 852
  • The annotations folder holds xml files numbered 0 through 852
    • Each xml file in the annotations folder carries the metadata for the matching image file
      • The folder name and file name appear, along with the image size
      • mask_weared_incorrect marks objects wearing a mask incorrectly
      • with_mask gives the locations of objects wearing a mask
      • without_mask marks objects not wearing a mask
      • bndbox holds xmin, ymin, xmax, ymax in that order, which specifies the bounding-box region (a trimmed example is parsed in the sketch below)

annotations
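To make this concrete, here is a minimal sketch (the file name and coordinates are made up for illustration) that parses a trimmed PASCAL VOC annotation with BeautifulSoup, the same way the helper functions below do:

from bs4 import BeautifulSoup

# a trimmed, hypothetical annotation (file name and box values are illustrative)
sample_xml = """
<annotation>
    <folder>images</folder>
    <filename>maksssksksss0.png</filename>
    <size><width>512</width><height>366</height><depth>3</depth></size>
    <object>
        <name>with_mask</name>
        <bndbox><xmin>79</xmin><ymin>105</ymin><xmax>109</xmax><ymax>142</ymax></bndbox>
    </object>
</annotation>
"""

soup = BeautifulSoup(sample_xml, "html.parser")
for obj in soup.find_all("object"):
    # class name plus the four corner coordinates of the bounding box
    print(obj.find("name").text,
          [obj.find(tag).text for tag in ("xmin", "ymin", "xmax", "ymax")])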

 

๐ŸŒ 1. ๋ฐ์ดํ„ฐ ๊ฒ€์ฆ

import os
import glob                             # widely used for handling file paths
import matplotlib.pyplot as plt         # visualization
import matplotlib.image as mpimg        
import matplotlib.patches as patches    
# useful for parsing HTML and XML documents and for web scraping
from bs4 import BeautifulSoup           

from PIL import Image
import cv2
import numpy as np
import time
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations
import albumentations.pytorch
import random
  • Load the dataset with the glob package
  • sorted keeps the file ids in img_list and annot_list in the same order
img_list = sorted(glob.glob('archive/images/*'))
annot_list = sorted(glob.glob('archive/annotations/*'))
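A quick sanity check (a small sketch) that both lists contain the 853 entries and that entry i of img_list pairs with entry i of annot_list:

print(len(img_list), len(annot_list))   # both should print 853
for img_path, annot_path in zip(img_list[:3], annot_list[:3]):
    # the file stems should match pairwise after sorting
    print(img_path, '<->', annot_path)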

 

๐ŸŒ 2. ๋ฐ”์šด๋”ฉ ๋ฐ•์Šค ์‹œ๊ฐํ™”๋ฅผ ์œ„ํ•œ ํ•จ์ˆ˜๋ฅผ ์ •์˜

def generate_box(obj):
    # returns the xmin, ymin, xmax, ymax values

    xmin = float(obj.find('xmin').text)
    ymin = float(obj.find('ymin').text)
    xmax = float(obj.find('xmax').text)
    ymax = float(obj.find('ymax').text)
    
    return [xmin, ymin, xmax, ymax]

def generate_label(obj):
    # splits mask-wearing status into three levels and returns 0, 1, or 2
    # with_mask → 1
    # mask_weared_incorrect → 2
    # without_mask → 0

    if obj.find('name').text == "with_mask":
        return 1
    elif obj.find('name').text == "mask_weared_incorrect":
        return 2
    return 0

def generate_target(file): 
    # calls generate_box and generate_label and stores the returned values in a dictionary

    with open(file) as f:
        data = f.read()

        # read the contents of the annotation file and extract the targets' bounding boxes and labels
        soup = BeautifulSoup(data, "html.parser")
        objects = soup.find_all("object")

        num_objs = len(objects)

        boxes = []
        labels = []
        for i in objects:
            boxes.append(generate_box(i))
            labels.append(generate_label(i))
        
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        
        return target

def plot_image(img_path, annotation):
    # visualize the image together with its bounding boxes

    img = mpimg.imread(img_path)
    
    fig, ax = plt.subplots(1)
    ax.imshow(img)

    
    for idx in range(len(annotation["boxes"])):
        xmin, ymin, xmax, ymax = annotation["boxes"][idx]

        # green bounding box when a mask is worn
        # orange when the mask is worn incorrectly
        # red when no mask is worn
        if annotation['labels'][idx] == 0 :
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='r',facecolor='none')
        
        elif annotation['labels'][idx] == 1 :
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='g',facecolor='none')
            
        else :
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='orange',facecolor='none')

        ax.add_patch(rect)

    plt.show()

 

  • Check the index value
img_list.index('archive/images/maksssksksss307.png')
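# → 232; this index is used in the cells below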

 

  • Use generate_target to store the bounding-box info for maksssksksss307.png in bbox
  • Pass the image path to plot_image along with the bounding-box info to draw the boxes on top of the image
  • The number inside img_list[] and annot_list[] is the position of maksssksksss307.png, so the same number is used for both
bbox = generate_target(annot_list[232])
plot_image(img_list[232], bbox)

 

๐ŸŒ 3. ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ

  • augmentation: data augmentation
    • torchvision.transforms is provided officially by PyTorch
      • it has no feature for transforming bounding boxes
    • albumentations is optimized on top of open-source computer-vision libraries such as OpenCV, so it offers faster processing and extra features compared to other libraries
      • for object detection, augmentation has to transform not just the image but the bounding boxes as well
!git clone https://github.com/Pseudo-Lab/Tutorial-Book-Utils
!python Tutorial-Book-Utils/PL_data_loader.py --data FaceMaskDetection
!unzip -q Face\ Mask\ Detection.zip
!pip install --upgrade albumentations

 

  • Visualize the augmentation results
import os
import glob
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from bs4 import BeautifulSoup

def generate_box(obj):
    # returns the xmin, ymin, xmax, ymax values

    xmin = float(obj.find('xmin').text)
    ymin = float(obj.find('ymin').text)
    xmax = float(obj.find('xmax').text)
    ymax = float(obj.find('ymax').text)
    
    return [xmin, ymin, xmax, ymax]

def generate_label(obj):
    # splits mask-wearing status into three levels and returns 0, 1, or 2
    # with_mask → 1
    # mask_weared_incorrect → 2
    # without_mask → 0

    if obj.find('name').text == "with_mask":
        return 1
    elif obj.find('name').text == "mask_weared_incorrect":
        return 2
    return 0

def generate_target(file): 
    with open(file) as f:
        data = f.read()
        soup = BeautifulSoup(data, "html.parser")
        objects = soup.find_all("object")

        num_objs = len(objects)

        boxes = []
        labels = []
        for i in objects:
            boxes.append(generate_box(i))
            labels.append(generate_label(i))

        # torch.as_tensor is added here to prepare for tensor operations during model training
        boxes = torch.as_tensor(boxes, dtype=torch.float32) 
        labels = torch.as_tensor(labels, dtype=torch.int64) 
        
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        
        return target



def plot_image_from_output(img, annotation):
    # visualize the image together with its bounding boxes
    # visualizes an image that has been converted to torch.Tensor
        # PyTorch represents images as [channels, height, width]
        # matplotlib expects [height, width, channels]

    img = img.permute(1,2,0)
    # permute swaps the channel order into matplotlib's layout
    
    fig,ax = plt.subplots(1)
    ax.imshow(img)
    
    for idx in range(len(annotation["boxes"])):
        xmin, ymin, xmax, ymax = annotation["boxes"][idx]

        # green bounding box when a mask is worn
        # orange when the mask is worn incorrectly
        # red when no mask is worn

        if annotation['labels'][idx] == 0 :
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='r',facecolor='none')
        
        elif annotation['labels'][idx] == 1 :
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='g',facecolor='none')
            
        else :
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='orange',facecolor='none')

        ax.add_patch(rect)

    plt.show()

 

➕ 1) Hands-on with torchvision.transforms

from PIL import Image
import cv2
import numpy as np
import time
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations
import albumentations.pytorch
from matplotlib import pyplot as plt
import os
import random

class TorchvisionMaskDataset(Dataset):
    def __init__(self, path, transform=None):
        self.path = path
        self.imgs = list(sorted(os.listdir(self.path)))
        self.transform = transform
        
    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):         # load the image, then run data augmentation
        file_image = self.imgs[idx]
        file_label = self.imgs[idx][:-3] + 'xml'
        img_path = os.path.join(self.path, file_image)
        
        if 'test' in self.path:
            label_path = os.path.join("test_annotations/", file_label)
        else:
            label_path = os.path.join("annotations/", file_label)

        img = Image.open(img_path).convert("RGB")
        
        target = generate_target(label_path)
        
        start_t = time.time()
        if self.transform:                  # apply the augmentation rules stored in the transform parameter
            img = self.transform(img)

        total_time = (time.time() - start_t)        # time() is used to measure the elapsed time

        # finally return the image, label, and total_time
        return img, target, total_time

 

  • Run the image augmentation practice
    • resize the image to (300, 300), then crop it to size 224
    • randomly change the image's brightness, contrast, saturation, and hue
    • apply a horizontal flip, then convert the image to a tensor
torchvision_transform = transforms.Compose([
    transforms.Resize((300, 300)), 
    transforms.RandomCrop(224),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomHorizontalFlip(p = 1),
    transforms.ToTensor(),
])

torchvision_dataset = TorchvisionMaskDataset(
    path = 'images/',
    transform = torchvision_transform
)

 

  • The Resize function provided by transforms adjusts the image size
  • The RandomCrop function crops the image
  • The ColorJitter function randomly changes brightness, contrast, saturation, and hue
  • RandomHorizontalFlip flips the image left-right with the given probability p
  • Compare the images before and after the transform
only_totensor = transforms.Compose([transforms.ToTensor()])

torchvision_dataset_no_transform = TorchvisionMaskDataset(
    path = 'images/',
    transform = only_totensor
)

img, annot, transform_time = torchvision_dataset_no_transform[0]
print('before transforms')
plot_image_from_output(img, annot)

before transforms

 

  • After the transform, the image shows the changes mentioned above
  • The image itself changed, but the bounding boxes are now misplaced on the transformed image
  • The augmentations in torchvision.transforms only operate on the image values,
  • so the bounding boxes are not transformed along with it
    • In classification the label stays fixed when the image changes, but in object detection the label has to change together with the image
img, annot, transform_time = torchvision_dataset[0]

print('after transforms')
plot_image_from_output(img, annot)

after transforms

 

  • Compute the time torchvision_dataset spends transforming an image, summed over 100 repetitions
total_time = 0
for i in range(100):
  sample, _, transform_time = torchvision_dataset[0]
  total_time += transform_time

print("torchvision time/sample: {} ms".format(total_time*10))

 

➕ 2) Albumentations

  • Read the image with the cv2 module and convert it to RGB
  • Apply the image transform and return the result
class AlbumentationsDataset(Dataset):
    def __init__(self, path, transform=None):
        self.path = path
        self.imgs = list(sorted(os.listdir(self.path)))
        self.transform = transform
        
    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        file_image = self.imgs[idx]
        file_label = self.imgs[idx][:-3] + 'xml'
        img_path = os.path.join(self.path, file_image)

        if 'test' in self.path:
            label_path = os.path.join("test_annotations/", file_label)
        else:
            label_path = os.path.join("annotations/", file_label)
        
        # Read an image with OpenCV
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        target = generate_target(label_path)

        total_time = 0      # defined even when no transform is given
        if self.transform:
            start_t = time.time()
            augmented = self.transform(image=image)
            total_time = (time.time() - start_t)
            image = augmented['image']

        return image, target, total_time

 

  • Speed comparison with albumentations_transform
# Same transform with torchvision_transform
albumentations_transform = albumentations.Compose([
    albumentations.Resize(300, 300), 
    albumentations.RandomCrop(224, 224),
    albumentations.ColorJitter(p=1), 
    albumentations.HorizontalFlip(p=1), 
    albumentations.pytorch.transforms.ToTensorV2()
])
# before
img, annot, transform_time = torchvision_dataset_no_transform[0]
plot_image_from_output(img, annot)

before

 

# after
albumentation_dataset = AlbumentationsDataset(
    path = 'images/',
    transform = albumentations_transform
)

img, annot, transform_time = albumentation_dataset[0]
plot_image_from_output(img, annot)

after

 

  • To measure speed, apply the albumentations transform 100 times and record the time
total_time = 0
for i in range(100):
    sample, _, transform_time = albumentation_dataset[0]
    total_time += transform_time

print("albumentations time/sample: {} ms".format(total_time*10))

 

➕ 3) Probability-Based Augmentation Combinations

  • The OneOf function provided by Albumentations
  • picks one of the augmentations in the list based on the given probability values
    • whether a transform runs is decided by combining the probability of the list itself with the probability of each function inside it
  • here each OneOf block is given selection probability 1
  • the 3 albumentations transforms inside each block are also given probability 1 each
  • so effectively one of the 3 transforms is selected and executed with probability 1/3 (checked numerically after the code below)
  • adjusting the probability values enables diverse augmentations
albumentations_transform_oneof = albumentations.Compose([
    albumentations.Resize(300, 300), 
    albumentations.RandomCrop(224, 224),
    albumentations.OneOf([
                          albumentations.HorizontalFlip(p=1),
                          albumentations.RandomRotate90(p=1),
                          albumentations.VerticalFlip(p=1)            
    ], p=1),
    albumentations.OneOf([
                          albumentations.MotionBlur(p=1),
                          albumentations.OpticalDistortion(p=1),
                          albumentations.GaussNoise(p=1)                 
    ], p=1),
    albumentations.pytorch.ToTensorV2()
])
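As a quick check of the 1/3 claim: OneOf normalizes the probabilities of its children into weights and then fires with its own p, so each transform inside a block runs with probability

outer_p = 1.0
inner_p = [1.0, 1.0, 1.0]                   # HorizontalFlip, RandomRotate90, VerticalFlip
print(outer_p * inner_p[0] / sum(inner_p))  # 0.333... → each transform runs one third of the time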

 

  • The result of applying albumentations_transform_oneof to an image 10 times
albumentation_dataset_oneof = AlbumentationsDataset(
    path = 'images/',
    transform = albumentations_transform_oneof
)

num_samples = 10
fig, ax = plt.subplots(1, num_samples, figsize=(25, 5))
for i in range(num_samples):
  ax[i].imshow(transforms.ToPILImage()(albumentation_dataset_oneof[0][0]))
  ax[i].axis('off')

 

➕ 4) Bounding-Box Augmentation

  • If the bounding boxes are not transformed along with the image, they point at the wrong places and the model cannot train properly
  • Bounding-box augmentation is possible through the bbox_params parameter of the Compose function provided by Albumentations

 

  • Create a new dataset class
    • modify the transform part of the AlbumentationsDataset class
    • since not only the image but also the bounding boxes go through the transform, adjust the inputs and outputs accordingly
class BboxAugmentationDataset(Dataset):
    def __init__(self, path, transform=None):
        self.path = path
        self.imgs = list(sorted(os.listdir(self.path)))
        self.transform = transform
        
    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        file_image = self.imgs[idx]
        file_label = self.imgs[idx][:-3] + 'xml'
        img_path = os.path.join(self.path, file_image)

        if 'test' in self.path:
            label_path = os.path.join("test_annotations/", file_label)
        else:
            label_path = os.path.join("annotations/", file_label)
        
        image = cv2.imread(img_path)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        target = generate_target(label_path)

        if self.transform:
            transformed = self.transform(image = image, bboxes = target['boxes'], labels = target['labels'])
            image = transformed['image']
            target = {'boxes':transformed['bboxes'], 'labels':transformed['labels']}
        
            
        return image, target

 

  • Define the transform with the albumentations.Compose function
  • A horizontal flip is applied first, followed by a rotation between -90 and 90 degrees
  • To transform the bounding boxes as well, pass an albumentations.BboxParams object to the bbox_params parameter
  • The Face Mask Detection dataset stores its boxes as xmin, ymin, xmax, ymax, which is the same as the pascal_voc convention
  • so pascal_voc is passed to the format parameter
  • To keep each object's class value in the labels parameter during the transform, labels is passed to label_fields

 

bbox_transform = albumentations.Compose(
    [albumentations.HorizontalFlip(p=1),
     albumentations.Rotate(p=1),
     albumentations.pytorch.transforms.ToTensorV2()],
    bbox_params=albumentations.BboxParams(format='pascal_voc', label_fields=['labels']),
)
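Before wiring this into the dataset, a tiny sketch with a dummy black image and a made-up box (the values are hypothetical) to confirm that the coordinates move together with the pixels:

dummy_img = np.zeros((300, 300, 3), dtype=np.uint8)
out = bbox_transform(image=dummy_img, bboxes=[[50, 60, 120, 140]], labels=[1])
print(out['bboxes'])   # flipped and rotated pascal_voc coordinates, no longer [50, 60, 120, 140]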
  • Instantiate the BboxAugmentationDataset class and check the augmentation results
    • every run of the code outputs a differently transformed image
    • the bounding boxes are transformed accordingly and still accurately locate the masked faces in the transformed image

 

bbox_transform_dataset = BboxAugmentationDataset(
    path = 'images/',
    transform = bbox_transform
)

img, annot = bbox_transform_dataset[0]
plot_image_from_output(img, annot)

 

 

๐ŸŒ 4. ๋ฐ์ดํ„ฐ ๋ถ„๋ฆฌ

  • Building an AI model requires training data and test data
    • training data is used when training the model
    • test data is used when evaluating the model
    • test data must not overlap with the training data
print(len(os.listdir('annotations')))
print(len(os.listdir('images')))

 

 

  • A common split between training and test data is 7:3 (an 8:2 ratio is used here)
  • Move the test files into separate folders
!mkdir test_images
!mkdir test_annotations

 

  • Move 170 files each from the images and annotations folders into the newly created folders
  • Use the sample function from the random module to draw random numbers to use as indices
import random
random.seed(1234)
idx = random.sample(range(853), 170)
print(len(idx))
print(idx[:10])

 

  • Use the shutil package to move the 170 images and 170 annotation files into the test_images and test_annotations folders
  • Check the number of files in each folder
import numpy as np
import shutil

for img in np.array(sorted(os.listdir('images')))[idx]:
    shutil.move('images/'+img, 'test_images/'+img)

for annot in np.array(sorted(os.listdir('annotations')))[idx]:
    shutil.move('annotations/'+annot, 'test_annotations/'+annot)

 

print(len(os.listdir('annotations')))
print(len(os.listdir('images')))
print(len(os.listdir('test_annotations')))
print(len(os.listdir('test_images')))
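Since 170 of the 853 pairs were moved out, the four counts above should come to 683, 683, 170, and 170; a small assertion sketch can guard the split:

# 853 - 170 = 683 files should remain in each training folder
assert len(os.listdir('images')) == len(os.listdir('annotations')) == 683
assert len(os.listdir('test_images')) == len(os.listdir('test_annotations')) == 170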

 

  • For object detection problems, it is necessary to check how many objects of each class exist inside the dataset
  • Check the number of objects per class (a quick ratio check follows the output below)
    • the training data contains 532 objects of class 0, 2,691 of class 1, and 97 of class 2
    • the test data contains 185 objects of class 0, 541 of class 1, and 26 of class 2
    • the 0/1/2 ratios are similar across the two splits, so the data was divided appropriately
from tqdm import tqdm
import pandas as pd
from collections import Counter

def get_num_objects_for_each_class(dataset):
    # store the label of every bounding box in the dataset in total_labels,
    # then use Counter to count and return the number per label
    total_labels = []
    for img, annot in tqdm(dataset, position = 0, leave = True):
        total_labels += [int(i) for i in annot['labels']]

    return Counter(total_labels)

train_data =  BboxAugmentationDataset(
    path = 'images/'
)

test_data =  BboxAugmentationDataset(
    path = 'test_images/'
)

train_objects = get_num_objects_for_each_class(train_data)
test_objects = get_num_objects_for_each_class(test_data)

print('\n objects in the train data', train_objects)
print('\n objects in the test data', test_objects)
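To back up the "similar ratios" observation, a small sketch that turns the two Counters into per-class proportions:

def class_ratios(counter):
    # share of each class among all objects in the split
    total = sum(counter.values())
    return {label: round(count / total, 3) for label, count in counter.items()}

print(class_ratios(train_objects))   # roughly {0: 0.16, 1: 0.811, 2: 0.029}
print(class_ratios(test_objects))    # roughly {0: 0.246, 1: 0.719, 2: 0.035}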

 

๐ŸŒ 5. RetinaNet

  • Build the medical mask detection model with RetinaNet, the one-stage model provided by torchvision
  • Loading the data
import os
import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches
from bs4 import BeautifulSoup
from PIL import Image
import cv2
import numpy as np
import time
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
from matplotlib import pyplot as plt
import os

def generate_box(obj):
    
    xmin = float(obj.find('xmin').text)
    ymin = float(obj.find('ymin').text)
    xmax = float(obj.find('xmax').text)
    ymax = float(obj.find('ymax').text)
    
    return [xmin, ymin, xmax, ymax]

def generate_label(obj):

    if obj.find('name').text == "with_mask":

        return 1

    elif obj.find('name').text == "mask_weared_incorrect":

        return 2

    return 0

def generate_target(file): 
    with open(file) as f:
        data = f.read()
        soup = BeautifulSoup(data, "html.parser")
        objects = soup.find_all("object")

        num_objs = len(objects)

        boxes = []
        labels = []
        for i in objects:
            boxes.append(generate_box(i))
            labels.append(generate_label(i))

        boxes = torch.as_tensor(boxes, dtype=torch.float32) 
        labels = torch.as_tensor(labels, dtype=torch.int64) 
        
        target = {}
        target["boxes"] = boxes
        target["labels"] = labels
        
        return target

def plot_image_from_output(img, annotation):
    
    img = img.cpu().permute(1,2,0)
    
    rects = []

    for idx in range(len(annotation["boxes"])):
        xmin, ymin, xmax, ymax = annotation["boxes"][idx]

        if annotation['labels'][idx] == 0 :
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='r',facecolor='none')
        
        elif annotation['labels'][idx] == 1 :
            
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='g',facecolor='none')
            
        else :
        
            rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='orange',facecolor='none')

        rects.append(rect)

    return img, rects

class MaskDataset(Dataset):
    def __init__(self, path, transform=None):
        self.path = path
        self.imgs = list(sorted(os.listdir(self.path)))
        self.transform = transform
        
    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        file_image = self.imgs[idx]
        file_label = self.imgs[idx][:-3] + 'xml'
        img_path = os.path.join(self.path, file_image)
        
        if 'test' in self.path:
            label_path = os.path.join("test_annotations/", file_label)
        else:
            label_path = os.path.join("annotations/", file_label)

        img = Image.open(img_path).convert("RGB")
        target = generate_target(label_path)
        
        to_tensor = torchvision.transforms.ToTensor()

        if self.transform:
            img, transform_target = self.transform(np.array(img), np.array(target['boxes']))
            target['boxes'] = torch.as_tensor(transform_target)

        # convert to tensor
        img = to_tensor(img)


        return img, target

def collate_fn(batch):
    return tuple(zip(*batch))

dataset = MaskDataset('images/')
test_dataset = MaskDataset('test_images/')

data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, collate_fn=collate_fn)
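collate_fn is needed because each image holds a different number of boxes, so a default stacked batch is impossible; tuple(zip(*batch)) keeps the batch as parallel tuples of images and targets. A small sketch to see the effect:

images, targets = next(iter(data_loader))
print(len(images), images[0].shape)        # 4 image tensors, each [3, H, W] with its own size
print(targets[0]['boxes'].shape)           # [num_objects, 4], varies per image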

 

!pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html

 

 

  • Load the RetinaNet model
    • The Face Mask Detection dataset has 3 classes, so the num_classes parameter is set to 3
    • For transfer learning, the backbone loads pre-trained weights while the remaining weights start freshly initialized
      • pretrained_backbone=True loads a ResNet-50 backbone pre-trained on ImageNet; pretrained=False skips the full detector weights trained on COCO (the well-known object detection dataset), since num_classes here differs
retina = torchvision.models.detection.retinanet_resnet50_fpn(num_classes = 3, pretrained=False, pretrained_backbone = True)

device = 'cuda' if torch.cuda.is_available() else 'cpu'

num_epochs = 1
retina.to(device)
    
# parameters
params = [p for p in retina.parameters() if p.requires_grad] # keep only the params that need gradient calculation
optimizer = torch.optim.SGD(params, lr=0.005,
                                momentum=0.9, weight_decay=0.0005)

len_dataloader = len(data_loader)

# roughly 4 minutes per epoch
for epoch in range(num_epochs):
    start = time.time()
    retina.train()

    i = 0    
    epoch_loss = 0
    for images, targets in data_loader:
        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        loss_dict = retina(images, targets) 

        losses = sum(loss for loss in loss_dict.values()) 

        i += 1

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        
        epoch_loss += losses 
    print(epoch_loss, f'time: {time.time() - start}')
torch.save(retina.state_dict(),f'retina_{num_epochs}.pt')
retina.load_state_dict(torch.load(f'retina_{num_epochs}.pt'))
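The post stops at training; as a minimal inference sketch (the 0.5 score threshold is an assumption, not from the original), torchvision detection models in eval mode return one dict of boxes, labels, and scores per image, which can be drawn with the plot_image_from_output defined above:

retina.eval()
with torch.no_grad():
    test_images, test_targets = next(iter(test_data_loader))
    preds = retina([img.to(device) for img in test_images])

keep = preds[0]['scores'] > 0.5                       # assumed confidence threshold
pred_annot = {'boxes': preds[0]['boxes'][keep].cpu(),
              'labels': preds[0]['labels'][keep].cpu()}

img, rects = plot_image_from_output(test_images[0], pred_annot)
fig, ax = plt.subplots(1)
ax.imshow(img)
for rect in rects:
    ax.add_patch(rect)
plt.show()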

 

 

 

 

 

 
