๐ ๊ณต๋ถํ๋ ์ง์ง์ํ์นด๋ ์ฒ์์ด์ง?
[Kaggle] ์๋ฃ์ฉ ๋ง์คํฌ ํ์ง ๋ชจ๋ธ Object Detection ๋ณธ๋ฌธ
[Kaggle] ์๋ฃ์ฉ ๋ง์คํฌ ํ์ง ๋ชจ๋ธ Object Detection
์ง์ง์ํ์นด 2022. 11. 9. 16:08<๋ณธ ๋ธ๋ก๊ทธ๋ pseudo-lab ๋์ Tutorial-Book ๋ธ๋ก๊ทธ๋ฅผ ์ฐธ๊ณ ํด์ ๊ณต๋ถํ๋ฉฐ ์์ฑํ์์ต๋๋ค>
https://pseudo-lab.github.io/Tutorial-Book/chapters/object-detection/Ch1-Object-Detection.html
1. ๊ฐ์ฒด ํ์ง ์๊ฐ — PseudoLab Tutorial Book
๊ฐ์ฒด ํ์ง(Object Detection)๋ ์ปดํจํฐ ๋น์ ๊ธฐ์ ์ ์ธ๋ถ ๋ถ์ผ์ค ํ๋๋ก์จ ์ฃผ์ด์ง ์ด๋ฏธ์ง๋ด ์ฌ์ฉ์๊ฐ ๊ด์ฌ ์๋ ๊ฐ์ฒด๋ฅผ ํ์งํ๋ ๊ธฐ์ ์ ๋๋ค. ์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ์ด ๊ทธ๋ฆผ 1-1 ์ข์ธก์ ์๋ ๊ฐ์์ง ์ฌ์ง์ ๊ฐ
pseudo-lab.github.io
โ ์๋ฃ์ฉ ๋ง์คํฌ ํ์ง ๋ชจ๋ธ
๋ง์คํฌ๋ฅผ ์ฐฉ์ฉํ ์ฌ๋, ์ฐฉ์ฉํ์ง ์์ ์ฌ๋ ๋๋ ๋ง์คํฌ๋ฅผ ๋ถ์ ์ ํ๊ฒ ์ฐฉ์ฉํ ์ฌ๋์ ๊ฐ์งํ๋ ๋ชจ๋ธ์ ๋ง๋ค๊ธฐ
๐๐จ ๋ฐ์ดํฐ ์ธํธ
3 ๊ฐ์ง ํด๋์ค ์ ์ํ๋ 853 ๊ฐ์ ์ด๋ฏธ์ง ์ PASCAL VOC ํ์์ ๊ฒฝ๊ณ ์์๊ฐ ํฌํจ
- ๋ง์คํฌ ํฌํจ
- ๋ง์คํฌ ์์ด
- ๋ง์คํฌ๋ฅผ ์๋ชป ์ฐฉ์ฉ
๐๐จ images์ annotations ํด๋
- images ํด๋์๋ ์ด๋ฏธ์ง ํ์ผ์ด 0๋ถํฐ 852
- annotations ํด๋์๋ xml ํ์ผ์ด 0๋ถํฐ 852
- annotations ํด๋ ์์ ์๋ xml ํ์ผ๋ค์ ๊ฐ๊ฐ์ ์ด๋ฏธ์ง ํ์ผ์ ์ ๋ณด
- ํด๋๋ช ๊ณผ ํ์ผ๋ช ์ด ๋์ค๋ฉฐ, ์ด๋ฏธ์ง ํฌ๊ธฐ ์ ๋ณด๊ฐ ํฌํจ๋์ด ์๋ ๊ฑธ ํ์ธ
- mask_weared_incorrect์ ๊ฒฝ์ฐ ๋ง์คํฌ๋ฅผ ์ ๋๋ก ์ฐ์ง ์์ ๊ฐ์ฒด์ ์ ๋ณด
- with_mask๋ ๋ง์คํฌ๋ฅผ ์ฐฉ์ฉํ๊ณ ์๋ ๊ฐ์ฒด ์์น ์ ๋ณด
- without_mask์ ๋ง์คํฌ๋ฅผ ์ฐ์ง ์์ ๊ฐ์ฒด์ ์ ๋ณด
- bndbox์์๋ xmin, ymin, xmax, ymax๊ฐ ์์๋๋ก ์์ ๋ฐ์ด๋ฉ ๋ฐ์ค ์์ญ์ ์ง์ ํ๋ ์ ๋ณด
- annotations ํด๋ ์์ ์๋ xml ํ์ผ๋ค์ ๊ฐ๊ฐ์ ์ด๋ฏธ์ง ํ์ผ์ ์ ๋ณด
๐ 1. ๋ฐ์ดํฐ ๊ฒ์ฆ
import os
import glob # ํ์ผ์ ๋ค๋ฃจ๋๋ฐ ๋๋ฆฌ ์ฐ์
import matplotlib.pyplot as plt # ์๊ฐํ
import matplotlib.image as mpimg
import matplotlib.patches as patches
# HTML๊ณผ XML ๋ฌธ์ ํ์ผ์ ํ์ฑ(Parsing), ์น์คํฌ๋ํ(Web Scraping)์ ์ ์ฉ
from bs4 import BeautifulSoup
from PIL import Image
import cv2
import numpy as np
import time
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations
import albumentations.pytorch
import os
import random
- glob ํจํค์ง๋ฅผ ์ด์ฉํด ๋ฐ์ดํฐ์ ์ ๋ถ๋ฌ์ด
- sorted ํจ์๋ฅผ ์ฌ์ฉํด img_list์ ์๋ ํ์ผ์ id ์์์ annot_list์ ์๋ ํ์ผ์ id ์์๊ฐ ๊ฐ๋๋ก ํจ
img_list = sorted(glob.glob('archive/images/*'))
annot_list = sorted(glob.glob('archive/annotations/*'))
๐ 2. ๋ฐ์ด๋ฉ ๋ฐ์ค ์๊ฐํ๋ฅผ ์ํ ํจ์๋ฅผ ์ ์
def generate_box(obj):
# xmin, ymin, xmax, ymax ๊ฐ์ ๋ฐํ
xmin = float(obj.find('xmin').text)
ymin = float(obj.find('ymin').text)
xmax = float(obj.find('xmax').text)
ymax = float(obj.find('ymax').text)
return [xmin, ymin, xmax, ymax]
def generate_label(obj):
# ๋ง์คํฌ ์ฐฉ์ฉ ์ฌ๋ถ๋ฅผ ์ธ๋จ๊ณ๋ก ๋๋ ์ 0, 1, 2 ๊ฐ์ ๋ฐํ
# with_mask์ ๊ฒฝ์ฐ 1
# mask_weared_incorrect์ ๊ฒฝ์ฐ 2
# without_mask๋ 0
if obj.find('name').text == "with_mask":
return 1
elif obj.find('name').text == "mask_weared_incorrect":
return 2
return 0
def generate_target(file):
# generate_box์ generate_label๋ฅผ ๊ฐ๊ฐ ํธ์ถํ์ฌ ๋ฐํ๋ ๊ฐ์ ๋์
๋๋ฆฌ์ ์ ์ฅํด ๋ฐํ
with open(file) as f:
data = f.read()
# annotations ํ์ผ์ ์๋ ๋ด์ฉ๋ค์ ๋ถ๋ฌ์ ํ๊ฒ์ ๋ฐ์ด๋ฉ ๋ฐ์ค์ ๋ผ๋ฒจ์ ์ถ์ถ
soup = BeautifulSoup(data, "html.parser")
objects = soup.find_all("object")
num_objs = len(objects)
boxes = []
labels = []
for i in objects:
boxes.append(generate_box(i))
labels.append(generate_label(i))
target = {}
target["boxes"] = boxes
target["labels"] = labels
return target
def plot_image(img_path, annotation):
# ์ด๋ฏธ์ง์ ๋ฐ์ด๋ฉ ๋ฐ์ค๋ฅผ ํจ๊ป ์๊ฐํ
img = mpimg.imread(img_path)
fig, ax = plt.subplots(1)
ax.imshow(img)
for idx in range(len(annotation["boxes"])):
xmin, ymin, xmax, ymax = annotation["boxes"][idx]
# ๋ง์คํฌ ์ฐฉ์ฉ์ ์ด๋ก์
# ๋ง์คํฌ๋ฅผ ์ฌ๋ฐ๋ฅด๊ฒ ์ฐฉ์ฉ ์ํ์ ์ ์ฃผํฉ์
# ๋ง์คํฌ๋ฅผ ์ฐฉ์ฉ ์ํ์ ์ ๋นจ๊ฐ์ ๋ฐ์ด๋ฉ ๋ฐ์ค
if annotation['labels'][idx] == 0 :
rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='r',facecolor='none')
elif annotation['labels'][idx] == 1 :
rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='g',facecolor='none')
else :
rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='orange',facecolor='none')
ax.add_patch(rect)
plt.show()
- ์ธ๋ฑ์ค ๊ฐ ํ์ธ
img_list.index('archive/images/maksssksksss307.png')
- generate_target ํจ์๋ฅผ ํ์ฉํด maksssksksss307.png ํ์ผ์ ํด๋นํ๋ ๋ฐ์ด๋ฉ ๋ฐ์ค ์ ๋ณด๋ฅผ bbox์ ์ ์ฅ
- plot_image ํจ์์ ๋ฐ์ด๋ฉ ๋ฐ์ค ์ ๋ณด์ ๋๋ถ์ด ํด๋น ์ด๋ฏธ์ง ํ์ผ ์ ๋ณด๋ ๋๊ฒจ์ฃผ์ด ์ด๋ฏธ์ง ์์ ๋ฐ์ด๋ฉ ๋ฐ์ค๋ฅผ ์๊ฐํ
- img_list[]์ annot_list[]์์ ์ซ์๋ maksssksksss307.png ํ์ผ์ ์์น๋ฅผ ๋ปํ๋ฏ๋ก ๊ฐ์ ์ซ์๊ฐ ๋ค์ด๊ฐ ์์
bbox = generate_target(annot_list[232])
plot_image(img_list[232], bbox)
๐ 3. ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ
- augmentation : ๋ฐ์ดํฐ ์ฆ๊ฐ๋ฒ
- torchvision.transforms๋ ํ์ดํ ์น์์ ๊ณต์์ ์ผ๋ก ์ ๊ณต
- ๋ฐ์ด๋ฉ ๋ฐ์ค ๋ณํ ๊ธฐ๋ฅ ์์
- albumentations๋ OpenCV์ ๊ฐ์ ์คํ ์์ค ์ปดํจํฐ ๋น์ ผ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ต์ ํ ํ์๊ธฐ์ ๋ค๋ฅธ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ณด๋ค ๋ ๋น ๋ฅธ ์ฒ๋ฆฌ ์๋ ๋ฐ ๊ธฐํ ๊ธฐ๋ฅ์ ์ ๊ณต
- ๊ฐ์ฒด ํ์ง์ฉ ์ด๋ฏธ์ง augmentation์ ์ด๋ฏธ์ง ๋ฟ๋ง ์๋๋ผ ๋ฐ์ด๋ฉ ๋ฐ์ค๊น์ง ๋ณํ์ ์ฃผ์ด์ผ ํจ
- torchvision.transforms๋ ํ์ดํ ์น์์ ๊ณต์์ ์ผ๋ก ์ ๊ณต
!git clone https://github.com/Pseudo-Lab/Tutorial-Book-Utils
!python Tutorial-Book-Utils/PL_data_loader.py --data FaceMaskDetection
!unzip -q Face\ Mask\ Detection.zip
!pip install --upgrade albumentations
- augmentation ๊ฒฐ๊ณผ๋ฌผ์ ์๊ฐํ
import os
import glob
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from bs4 import BeautifulSoup
def generate_box(obj):
# xmin, ymin, xmax, ymax ๊ฐ์ ๋ฐํ
xmin = float(obj.find('xmin').text)
ymin = float(obj.find('ymin').text)
xmax = float(obj.find('xmax').text)
ymax = float(obj.find('ymax').text)
return [xmin, ymin, xmax, ymax]
def generate_label(obj):
# ๋ง์คํฌ ์ฐฉ์ฉ ์ฌ๋ถ๋ฅผ ์ธ๋จ๊ณ๋ก ๋๋ ์ 0, 1, 2 ๊ฐ์ ๋ฐํ
# with_mask์ ๊ฒฝ์ฐ 1
# mask_weared_incorrect์ ๊ฒฝ์ฐ 2
# without_mask๋ 0
if obj.find('name').text == "with_mask":
return 1
elif obj.find('name').text == "mask_weared_incorrect":
return 2
return 0
def generate_target(file):
with open(file) as f:
data = f.read()
soup = BeautifulSoup(data, "html.parser")
objects = soup.find_all("object")
num_objs = len(objects)
boxes = []
labels = []
for i in objects:
boxes.append(generate_box(i))
labels.append(generate_label(i))
# ๋ฅ๋ฌ๋ ๋ชจ๋ธ ํ์ต์ ์ํ tensor๊ฐ์ ์ฐ์ฐ์ ์ค๋นํ๊ธฐ ์ํด torch.as_tensorํจ์๊ฐ ์ถ๊ฐ
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
target = {}
target["boxes"] = boxes
target["labels"] = labels
return target
def plot_image_from_output(img, annotation):
# ์ด๋ฏธ์ง์ ๋ฐ์ด๋ฉ ๋ฐ์ค๋ฅผ ํจ๊ป ์๊ฐํ
# torch.Tensor๋ก ๋ณํ๋ ์ด๋ฏธ์ง๋ฅผ ์๊ฐํ
# PyTorch์์๋ ์ด๋ฏธ์ง๋ฅผ [channels, height, width]๋ก ํํ
# matplotlib์์๋ [height, width, channels]๋ก ํํ
img = img.permute(1,2,0)
# ์ฑ๋ ์์๋ฅผ ๋ฐ๊ฟ์ฃผ๋ permuteํจ์๋ฅผ ํ์ฉํด matplotlib์ ์ฑ๋ ์์๋ก ๋ฐ๊ฟ
fig,ax = plt.subplots(1)
ax.imshow(img)
for idx in range(len(annotation["boxes"])):
xmin, ymin, xmax, ymax = annotation["boxes"][idx]
# ๋ง์คํฌ ์ฐฉ์ฉ์ ์ด๋ก์
# ๋ง์คํฌ๋ฅผ ์ฌ๋ฐ๋ฅด๊ฒ ์ฐฉ์ฉ ์ํ์ ์ ์ฃผํฉ์
# ๋ง์คํฌ๋ฅผ ์ฐฉ์ฉ ์ํ์ ์ ๋นจ๊ฐ์ ๋ฐ์ด๋ฉ ๋ฐ์ค
if annotation['labels'][idx] == 0 :
rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='r',facecolor='none')
elif annotation['labels'][idx] == 1 :
rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='g',facecolor='none')
else :
rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='orange',facecolor='none')
ax.add_patch(rect)
plt.show()
โ 1) torchvision.transforms์ ์ค์ต
from PIL import Image
import cv2
import numpy as np
import time
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
import albumentations
import albumentations.pytorch
from matplotlib import pyplot as plt
import os
import random
class TorchvisionMaskDataset(Dataset):
def __init__(self, path, transform=None):
self.path = path
self.imgs = list(sorted(os.listdir(self.path)))
self.transform = transform
def __len__(self):
return len(self.imgs)
def __getitem__(self, idx): # image๋ฅผ ๋ถ๋ฌ์จ ๋ค์ ๋ฐ์ดํฐ augmentation์ ์งํ
file_image = self.imgs[idx]
file_label = self.imgs[idx][:-3] + 'xml'
img_path = os.path.join(self.path, file_image)
if 'test' in self.path:
label_path = os.path.join("test_annotations/", file_label)
else:
label_path = os.path.join("annotations/", file_label)
img = Image.open(img_path).convert("RGB")
target = generate_target(label_path)
start_t = time.time()
if self.transform: # transform ํ๋ผ๋ฏธํฐ์ ์ ์ฅ๋ผ ์๋ augmentation ๊ท์น์ ๋ฐ๋ผ augmentation์ด ์ด๋ค์ง
img = self.transform(img)
total_time = (time.time() - start_t) # ์๊ฐ ์ธก์ ์ ์ํด timeํจ์๋ฅผ ์ฌ์ฉ
# ์ต์ข
์ ์ผ๋ก image, label, total_time์ ๋ฐํ
return img, target, total_time
- ์ด๋ฏธ์ง augmentation ์ค์ต์ ์งํ
- ์ด๋ฏธ์ง๋ฅผ (300, 300) ํฌ๊ธฐ๋ก ๋ง๋ ํ, 224 ํฌ๊ธฐ๋ก ์๋ฅด๊ธฐ
- ์ด๋ฏธ์ง์ ๋ฐ๊ธฐ(brightness), ๋๋น(contrast), ์ฑ๋(saturation), ์์กฐ(hue)๋ฅผ ๋ฌด์์๋ก ๋ฐ๊พธ๊ธฐ
- ์ด๋ฏธ์ง ์ข์ฐ ๋ฐ์ ์ ์ ์ฉํ ํ tensor๋ก ๋ณํํ๋ ์์ ์ ์งํ
torchvision_transform = transforms.Compose([
transforms.Resize((300, 300)),
transforms.RandomCrop(224),
transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
transforms.RandomHorizontalFlip(p = 1),
transforms.ToTensor(),
])
torchvision_dataset = TorchvisionMaskDataset(
path = 'images/',
transform = torchvision_transform
)
- transforms์์ ์ ๊ณตํ๋ Resize ํจ์๋ฅผ ํตํด ์ด๋ฏธ์ง ํฌ๊ธฐ๋ฅผ ์กฐ์
- RandomCrop ํจ์๋ฅผ ํตํด ์ด๋ฏธ์ง๋ฅผ ์๋ฅผ ์ ์์
- ColorJitter ํจ์๋ ๋ฐ๊ธฐ, ๋๋น, ์ฑ๋, ์์กฐ ๋ฑ์ ์์๋ก ๋ฐ๊พธ๋ ๊ธฐ๋ฅ
- RandomHorizontalFlip์ ์ ์ํ p์ ํ๋ฅ ๋ก ์ข์ฐ๋ฐ์ ์ ์ค์
- ๋ณ๊ฒฝ ์ ๊ณผ ๋ณ๊ฒฝ ํ์ ์ด๋ฏธ์ง๋ฅผ ๋น๊ต
only_totensor = transforms.Compose([transforms.ToTensor()])
torchvision_dataset_no_transform = TorchvisionMaskDataset(
path = 'images/',
transform = only_totensor
)
img, annot, transform_time = torchvision_dataset_no_transform[0]
print('transforms ์ ์ฉ ์ ')
plot_image_from_output(img, annot)
- ๋ณ๊ฒฝ ์ ์ ๋นํด ๋ณ๊ฒฝ ํ ์ด๋ฏธ์ง๋ ์์ ์ธ๊ธํ ๋ณํ๋ค์ด ์ ์ฉ๋จ
- ์ด๋ฏธ์ง ์์ฒด์ ์ธ ๋ณํ๋ ์ด๋ค์ก์ง๋ง ๋ฐ์ด๋ฉ ๋ฐ์ค๋ ๋ณํ๋ ์ด๋ฏธ์ง์์ ์์น๊ฐ ์ด๊ธ๋ ๊ฒ์ ํ์ธ
- torchvision.transform์์ ์ ๊ณตํ๋ augmentation์ ์ด๋ฏธ์ง ๊ฐ์ ๋ํ augmentation๋ง ์งํ์ด ๋๋ฉฐ,
- ๋ฐ์ด๋ฉ ๋ฐ์ค๋ ๊ฐ์ด ๋ณํ๋์ง ์์
- ์ด๋ฏธ์ง๊ฐ ๋ณํด๋ ๋ผ๋ฒจ๊ฐ์ด ๊ณ ์ ์ด์ง๋ง, ๊ฐ์ฒด ๊ฒ์ถ ๋ฌธ์ ์์๋ ์ด๋ฏธ์ง๊ฐ ๋ณํจ์ ๋ฐ๋ผ ๋ผ๋ฒจ ๊ฐ ๋ํ ํจ๊ป ๋ณํด์ผ ํจ
img, annot, transform_time = torchvision_dataset[0]
print('transforms ์ ์ฉ ํ')
plot_image_from_output(img, annot)
- torchvision_dataset์์ ์ด๋ฏธ์ง ๋ณํ์ ์์๋ ์๊ฐ์ ๊ณ์ฐํ๊ณ ๊ทธ๊ฒ์ 100๋ฒ ๋ฐ๋ณตํ ์๊ฐ
total_time = 0
for i in range(100):
sample, _, transform_time = torchvision_dataset[0]
total_time += transform_time
print("torchvision time: {} ms".format(total_time*10))
โ 2) Albumentations
- cv2 ๋ชจ๋์ ์ฌ์ฉํ์ฌ ์ด๋ฏธ์ง๋ฅผ ์ฝ๊ณ RGB๋ก ๋ฐ๊ฟ์ค
- ์ด๋ฏธ์ง ๋ณํ์ ์ค์ํ ํ ๊ฒฐ๊ณผ๊ฐ์ ๋ฐํ
class AlbumentationsDataset(Dataset):
def __init__(self, path, transform=None):
self.path = path
self.imgs = list(sorted(os.listdir(self.path)))
self.transform = transform
def __len__(self):
return len(self.imgs)
def __getitem__(self, idx):
file_image = self.imgs[idx]
file_label = self.imgs[idx][:-3] + 'xml'
img_path = os.path.join(self.path, file_image)
if 'test' in self.path:
label_path = os.path.join("test_annotations/", file_label)
else:
label_path = os.path.join("annotations/", file_label)
# Read an image with OpenCV
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
target = generate_target(label_path)
start_t = time.time()
if self.transform:
augmented = self.transform(image=image)
total_time = (time.time() - start_t)
image = augmented['image']
return image, target, total_time
- albumentations_transform ์๋ ๋น๊ต
# Same transform with torchvision_transform
albumentations_transform = albumentations.Compose([
albumentations.Resize(300, 300),
albumentations.RandomCrop(224, 224),
albumentations.ColorJitter(p=1),
albumentations.HorizontalFlip(p=1),
albumentations.pytorch.transforms.ToTensorV2()
])
# ๋ณ๊ฒฝ ์
img, annot, transform_time = torchvision_dataset_no_transform[0]
plot_image_from_output(img, annot)
# ๋ณ๊ฒฝ ํ
albumentation_dataset = AlbumentationsDataset(
path = 'images/',
transform = albumentations_transform
)
img, annot, transform_time = albumentation_dataset[0]
plot_image_from_output(img, annot)
- ์๋ ์ธก์ ์ ์ํด albumentation์ 100๋ฒ ์ ์ฉ ์ํจ ๋ค ์๊ฐ์ ์ธก์
total_time = 0
for i in range(100):
sample, _, transform_time = albumentation_dataset[0]
total_time += transform_time
print("albumentations time/sample: {} ms".format(total_time*10))
โ 3) ํ๋ฅ ๊ธฐ๋ฐ Augmentation ์กฐํฉ
- Albumentations์์ ์ ๊ณตํ๋ OneOf ํจ์
- list ์์ ์๋ augmentation ๊ธฐ๋ฅ ๋ค์ ์ฃผ์ด์ง ํ๋ฅ ๊ฐ์ ๊ธฐ๋ฐํ์ฌ ๊ฐ์ ธ์ด
- list ๊ฐ ์์ฒด์ ํ๋ฅ ๊ฐ๊ณผ ๋๋ถ์ด ํด๋น ํจ์์ ํ๋ฅ ๊ฐ์ ํจ๊ป ๊ณ ๋ คํ์ฌ ์คํ ์ฌ๋ถ๋ฅผ ๊ฒฐ์
- OneOf ํจ์๋ ๊ฐ๊ฐ ์ ํ๋ ํ๋ฅ ์ด 1
- ๊ฐ๊ฐ์ ํจ์ ๋ด๋ถ์ ์๋ 3๊ฐ์ albumentations ๊ธฐ๋ฅ๋ค ๋ํ ๊ฐ๊ฐ ํ๋ฅ ๊ฐ์ด 1๋ก ๋ถ์ฌ
- ์ค์ง์ ์ผ๋ก 1/3์ ํ๋ฅ ๋ก 3๊ฐ์ ๊ธฐ๋ฅ ์ค ํ๋๊ฐ ์ ํ๋์ด ์คํ๋๋ค๋ ๊ฒ์ ์ ์ ์์
- ํ๋ฅ ๊ฐ์ ์กฐ์ ํ์ฌ ๋ค์ํ augmentation์ด ๊ฐ๋ฅ
albumentations_transform_oneof = albumentations.Compose([
albumentations.Resize(300, 300),
albumentations.RandomCrop(224, 224),
albumentations.OneOf([
albumentations.HorizontalFlip(p=1),
albumentations.RandomRotate90(p=1),
albumentations.VerticalFlip(p=1)
], p=1),
albumentations.OneOf([
albumentations.MotionBlur(p=1),
albumentations.OpticalDistortion(p=1),
albumentations.GaussNoise(p=1)
], p=1),
albumentations.pytorch.ToTensorV2()
])
- albumentations_transform_oneof๋ฅผ ์ด๋ฏธ์ง์ 10๋ฒ ์ ์ฉํ ๊ฒฐ๊ณผ
albumentation_dataset_oneof = AlbumentationsDataset(
path = 'images/',
transform = albumentations_transform_oneof
)
num_samples = 10
fig, ax = plt.subplots(1, num_samples, figsize=(25, 5))
for i in range(num_samples):
ax[i].imshow(transforms.ToPILImage()(albumentation_dataset_oneof[0][0]))
ax[i].axis('off')
โ 4) ๋ฐ์ด๋ฉ ๋ฐ์ค Augmentation
- ๋ฐ์ด๋ฉ ๋ฐ์ค๋ฅผ ํจ๊ป ๋ณํ ์์ผ ์ฃผ์ง ์์ผ๋ฉด ๋ฐ์ด๋ฉ ๋ฐ์ค๊ฐ ์๋ฑํ ๊ณณ์ ํ์งํ๊ณ ์๊ธฐ ๋๋ฌธ์ ๋ชจ๋ธ ํ์ต์ด ์ ๋๋ก ์ด๋ค์ง์ง ์์
- Albumentations์์ ์ ๊ณตํ๋ Compose ํจ์์ ์๋ bbox_params ํ๋ผ๋ฏธํฐ๋ฅผ ํ์ฉํ๋ฉด ๋ฐ์ด๋ฉ ๋ฐ์ค augmentation์ด ๊ฐ๋ฅ
- ์๋ก์ด ๋ฐ์ดํฐ์
ํด๋์ค๋ฅผ ์์ฑ
- AlbumentationsDataset ํด๋์ค์ transform ๋ถ๋ถ์ ์์
- ์ด๋ฏธ์ง๋ฟ๋ง ์๋๋ผ ๋ฐ์ด๋ฉ ๋ฐ์ค๋ transform์ด ์งํ๋๊ธฐ ๋๋ฌธ์ ํ์ํ ์ ๋ ฅ๊ฐ, ์ถ๋ ฅ๊ฐ ์์ ์ ์งํ
class BboxAugmentationDataset(Dataset):
def __init__(self, path, transform=None):
self.path = path
self.imgs = list(sorted(os.listdir(self.path)))
self.transform = transform
def __len__(self):
return len(self.imgs)
def __getitem__(self, idx):
file_image = self.imgs[idx]
file_label = self.imgs[idx][:-3] + 'xml'
img_path = os.path.join(self.path, file_image)
if 'test' in self.path:
label_path = os.path.join("test_annotations/", file_label)
else:
label_path = os.path.join("annotations/", file_label)
image = cv2.imread(img_path)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
target = generate_target(label_path)
if self.transform:
transformed = self.transform(image = image, bboxes = target['boxes'], labels = target['labels'])
image = transformed['image']
target = {'boxes':transformed['bboxes'], 'labels':transformed['labels']}
return image, target
- albumentations.Compose ํจ์๋ฅผ ํ์ฉํด ๋ณํ์ ์ ์
- ๊ฐ์ฅ ๋จผ์ ์ข์ฐ๋ฐ์ ์ ์ค์ํ ๊ฒ์ด๋ฉฐ, ๊ทธ ์ดํ์ -90๋์์ 90๋ ์ฌ์ด์ ํ์ ์ ์งํ
- ๋ฐ์ด๋ฉ ๋ฐ์ค๋ ํจ๊ป ๋ณํ์ ์งํํด์ฃผ๊ธฐ ์ํด bbox_params ํ๋ผ๋ฏธํฐ์ albumentations.BboxParams ๊ฐ์ฒด๋ฅผ ์ ๋ ฅ
- Face Mask Detection ๋ฐ์ดํฐ์ ์ ๋ฐ์ด๋ฉ ๋ฐ์ค ํ๊ธฐ๋ฒ์ด xmin, ymin, xmax, ymax์ผ๋ก ๋ผ ์๊ณ , ์ด๊ฒ์ pascal_voc ํ๊ธฐ๋ฒ๊ณผ ๊ฐ์
- format ํ๋ผ๋ฏธํฐ์ pascal_voc์ ์ ๋ ฅ
- transform ์งํ ์ ๊ฐ์ฒด๋ณ ํด๋์ค ๊ฐ์ labels ํ๋ผ๋ฏธํฐ์ ์ ์ฅํด๋๊ธฐ ์ํด label_field์ labels๋ฅผ ์ ๋ ฅ
bbox_transform = albumentations.Compose(
[albumentations.HorizontalFlip(p=1),
albumentations.Rotate(p=1),
albumentations.pytorch.transforms.ToTensorV2()],
bbox_params=albumentations.BboxParams(format='pascal_voc', label_fields=['labels']),
)
- BboxAugmentationDataset ํด๋์ค๋ฅผ ํ์ฑํ ํ์ฌ augmentation ๊ฒฐ๊ณผ๋ฌผ์ ํ์ธ
- ์ฝ๋๋ฅผ ์คํํ ๋๋ง๋ค ์ด๋ฏธ์ง๊ฐ ๋ณํ๋์ด์ ์ถ๋ ฅ
- ๋ฐ์ด๋ฉ ๋ฐ์ค ๋ํ ์๋ง๊ฒ ๋ณํ๋์ด ๋ณํ๋ ์ด๋ฏธ์ง์ ์๋ ๋ง์คํฌ ์ฐฉ์ฉ ์ผ๊ตด๋ค์ ์ ํํ ํ์ง
bbox_transform_dataset = BboxAugmentationDataset(
path = 'images/',
transform = bbox_transform
)
img, annot = bbox_transform_dataset[0]
plot_image_from_output(img, annot)
๐ 4. ๋ฐ์ดํฐ ๋ถ๋ฆฌ
- ์ธ๊ณต์ง๋ฅ ๋ชจ๋ธ์ ๊ตฌ์ถํ๊ธฐ ์ํด์ ํ์ต์ฉ ๋ฐ์ดํฐ์ ์ํ ๋ฐ์ดํฐ๊ฐ ํ์
- ํ์ต์ฉ ๋ฐ์ดํฐ๋ ๋ชจ๋ธ ํ๋ จ ์ ์ฌ์ฉ
- ์ํ ๋ฐ์ดํฐ๋ ๋ชจ๋ธ ํ๊ฐ ์ ์ฌ์ฉ
- ์ํ ๋ฐ์ดํฐ๋ ํ์ต์ฉ ๋ฐ์ดํฐ์ ์ค๋ณต๋์ง ์์์ผ ํจ
print(len(os.listdir('annotations')))
print(len(os.listdir('images')))
- ์ผ๋ฐ์ ์ผ๋ก ํ์ต ๋ฐ์ดํฐ์ ์ํ ๋ฐ์ดํฐ์ ๋น์จ์ 7:3 (์ฌ๊ธฐ์๋ 8:2 ๋น์จ)
- ํด๋น ๋ฐ์ดํฐ๋ฅผ ๋ณ๋์ ํด๋๋ก ์ฎ๊ฒจ ์ฃผ๋๋ก ํจ
!mkdir test_images
!mkdir test_annotations
- images ํด๋์ annotations ํด๋์ ์๋ ํ์ผ ๊ฐ๊ฐ 170๊ฐ์ฉ์ ์๋ก ์์ฑํ ํด๋๋ก ์ฎ๊ธฐ๊ธฐ
- random ๋ชจ๋์ ์๋ sample ํจ์๋ฅผ ํ์ฉํด ๋ฌด์์๋ก ์ซ์๋ฅผ ์ถ์ถํ ํ ์ธ๋ฑ์ค๊ฐ์ผ๋ก ํ์ฉ
import random
random.seed(1234)
idx = random.sample(range(853), 170)
print(len(idx))
print(idx[:10])
- shutil ํจํค์ง๋ฅผ ํ์ฉํด 170๊ฐ์ ์ด๋ฏธ์ง์ 170๊ฐ์ ์ขํ ํ์ผ๋ค์ ๊ฐ๊ฐ test_imagesํด๋์ test_annotations ํด๋๋ก ์ฎ๊น
- ๊ฐ ํด๋๋ณ ํ์ผ ๊ฐ์๋ฅผ ํ์ธ
import numpy as np
import shutil
for img in np.array(sorted(os.listdir('images')))[idx]:
shutil.move('images/'+img, 'test_images/'+img)
for annot in np.array(sorted(os.listdir('annotations')))[idx]:
shutil.move('annotations/'+annot, 'test_annotations/'+annot)
print(len(os.listdir('annotations')))
print(len(os.listdir('images')))
print(len(os.listdir('test_annotations')))
print(len(os.listdir('test_images')))
- ๊ฐ์ฒด ํ์ง ๋ฌธ์ ์์๋ ๊ฐ ํด๋์ค ๋ณ๋ก ๋ช ๊ฐ์ ๊ฐ์ฒด๊ฐ ๋ฐ์ดํฐ์ ๋ด๋ถ์ ์กด์ฌํ๋์ง ํ์ธํ๋ ์์ ์ด ํ์
- ๋ฐ์ดํฐ์
๋ด๋ถ์ ์๋ ํด๋์ค๋ณ ๊ฐ์ฒด ์๋ฅผ ํ์ธ
- ํ์ต์ฉ ๋ฐ์ดํฐ์๋ 532๊ฐ์ 0๋ฒ ํด๋์ค, 2,691๊ฐ์ 1๋ฒ ํด๋์ค, 97๊ฐ์ 2๋ฒ ํด๋์ค๊ฐ ์์น
- ์ํ์ฉ ๋ฐ์ดํฐ์๋ 185๊ฐ์ 0๋ฒ ํด๋์ค, 541๊ฐ์ 1๋ฒ ํด๋์ค, 26๊ฐ์ 2๋ฒ ํด๋์ค๊ฐ ์์น
- ๋ฐ์ดํฐ์ ๋ณ๋ก 0,1,2 ๋น์จ์ด ์ ์ฌํ ๊ฒ์ ๋ณด์ ์ ์ ํ ๋ฐ์ดํฐ๊ฐ ๋๋์ด ์ง ๊ฒ์ ํ์ธ
from tqdm import tqdm
import pandas as pd
from collections import Counter
def get_num_objects_for_each_class(dataset):
# ๋ฐ์ดํฐ์
์ ์๋ ๋ชจ๋ ๋ฐ์ด๋ฉ ๋ฐ์ค์ ๋ผ๋ฒจ ๊ฐ์ total_labels์ ์ ์ฅ ํ
# Counter ํด๋์ค๋ฅผ ํ์ฉํด ๋ผ๋ฒจ๋ณ ๊ฐ์๋ฅผ ์ธ์ด ๋ฐํํ๋ ํจ์
total_labels = []
for img, annot in tqdm(dataset, position = 0, leave = True):
total_labels += [int(i) for i in annot['labels']]
return Counter(total_labels)
train_data = BboxAugmentationDataset(
path = 'images/'
)
test_data = BboxAugmentationDataset(
path = 'test_images/'
)
train_objects = get_num_objects_for_each_class(train_data)
test_objects = get_num_objects_for_each_class(test_data)
print('\n train ๋ฐ์ดํฐ์ ์๋ ๊ฐ์ฒด', train_objects)
print('\n test ๋ฐ์ดํฐ์ ์๋ ๊ฐ์ฒด', test_objects)
๐ 5. RetinaNet
- torchvision์์ ์ ๊ณตํ๋ one-stage ๋ชจ๋ธ์ธ RetinaNet์ ํ์ฉํด ์๋ฃ์ฉ ๋ง์คํฌ ๊ฒ์ถ ๋ชจ๋ธ์ ๊ตฌ์ถ
- ๋ฐ์ดํฐ ๋ถ๋ฌ์ค๊ธฐ
import os
import glob
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import matplotlib.patches as patches
from bs4 import BeautifulSoup
from PIL import Image
import cv2
import numpy as np
import time
import torch
import torchvision
from torch.utils.data import Dataset
from torchvision import transforms
from matplotlib import pyplot as plt
import os
def generate_box(obj):
xmin = float(obj.find('xmin').text)
ymin = float(obj.find('ymin').text)
xmax = float(obj.find('xmax').text)
ymax = float(obj.find('ymax').text)
return [xmin, ymin, xmax, ymax]
def generate_label(obj):
if obj.find('name').text == "with_mask":
return 1
elif obj.find('name').text == "mask_weared_incorrect":
return 2
return 0
def generate_target(file):
with open(file) as f:
data = f.read()
soup = BeautifulSoup(data, "html.parser")
objects = soup.find_all("object")
num_objs = len(objects)
boxes = []
labels = []
for i in objects:
boxes.append(generate_box(i))
labels.append(generate_label(i))
boxes = torch.as_tensor(boxes, dtype=torch.float32)
labels = torch.as_tensor(labels, dtype=torch.int64)
target = {}
target["boxes"] = boxes
target["labels"] = labels
return target
def plot_image_from_output(img, annotation):
img = img.cpu().permute(1,2,0)
rects = []
for idx in range(len(annotation["boxes"])):
xmin, ymin, xmax, ymax = annotation["boxes"][idx]
if annotation['labels'][idx] == 0 :
rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='r',facecolor='none')
elif annotation['labels'][idx] == 1 :
rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='g',facecolor='none')
else :
rect = patches.Rectangle((xmin,ymin),(xmax-xmin),(ymax-ymin),linewidth=1,edgecolor='orange',facecolor='none')
rects.append(rect)
return img, rects
class MaskDataset(Dataset):
def __init__(self, path, transform=None):
self.path = path
self.imgs = list(sorted(os.listdir(self.path)))
self.transform = transform
def __len__(self):
return len(self.imgs)
def __getitem__(self, idx):
file_image = self.imgs[idx]
file_label = self.imgs[idx][:-3] + 'xml'
img_path = os.path.join(self.path, file_image)
if 'test' in self.path:
label_path = os.path.join("test_annotations/", file_label)
else:
label_path = os.path.join("annotations/", file_label)
img = Image.open(img_path).convert("RGB")
target = generate_target(label_path)
to_tensor = torchvision.transforms.ToTensor()
if self.transform:
img, transform_target = self.transform(np.array(img), np.array(target['boxes']))
target['boxes'] = torch.as_tensor(transform_target)
# tensor๋ก ๋ณ๊ฒฝ
img = to_tensor(img)
return img, target
def collate_fn(batch):
return tuple(zip(*batch))
dataset = MaskDataset('images/')
test_dataset = MaskDataset('test_images/')
data_loader = torch.utils.data.DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
test_data_loader = torch.utils.data.DataLoader(test_dataset, batch_size=2, collate_fn=collate_fn)
!pip install torch==1.7.0+cu101 torchvision==0.8.1+cu101 torchaudio==0.7.0 -f https://download.pytorch.org/whl/torch_stable.html
- RetinaNet ๋ชจ๋ธ์ ๋ถ๋ฌ์ค๊ธฐ
- Face Mask Detection ๋ฐ์ดํฐ์ ์ 3๊ฐ์ ํด๋์ค๊ฐ ์กด์ฌํ๋ฏ๋ก num_classes ๋งค๊ฐ๋ณ์๋ฅผ 3์ผ๋ก ์ ์
- ์ ์ด ํ์ต์ ํ ๊ฒ์ด๊ธฐ ๋๋ฌธ์ backbone ๊ตฌ์กฐ๋ ์ฌ์ ํ์ต ๋ ๊ฐ์ค์น๋ฅผ, ๊ทธ ์ธ ๊ฐ์ค์น๋ ์ด๊ธฐํ ์ํ๋ก ๊ฐ์ ธ์ค๊ธฐ
- backbone์ ๊ฐ์ฒด ํ์ง ๋ฐ์ดํฐ์ ์ผ๋ก ์ ๋ช ํ COCO ๋ฐ์ดํฐ์ ์ ์ฌ์ ํ์ต๋จ
retina = torchvision.models.detection.retinanet_resnet50_fpn(num_classes = 3, pretrained=False, pretrained_backbone = True)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_epochs = 1
retina.to(device)
# parameters
params = [p for p in retina.parameters() if p.requires_grad] # gradient calculation์ด ํ์ํ params๋ง ์ถ์ถ
optimizer = torch.optim.SGD(params, lr=0.005,
momentum=0.9, weight_decay=0.0005)
len_dataloader = len(data_loader)
# epoch ๋น ์ฝ 4๋ถ ์์
for epoch in range(num_epochs):
start = time.time()
retina.train()
i = 0
epoch_loss = 0
for images, targets in data_loader:
images = list(image.to(device) for image in images)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
loss_dict = retina(images, targets)
losses = sum(loss for loss in loss_dict.values())
i += 1
optimizer.zero_grad()
losses.backward()
optimizer.step()
epoch_loss += losses
print(epoch_loss, f'time: {time.time() - start}')
torch.save(retina.state_dict(),f'retina_{num_epochs}.pt')
retina.load_state_dict(torch.load(f'retina_{num_epochs}.pt'))