Commit d3717850 authored by Yifan Zhao

Updated YOLO definition to adapt to what NVDLA can do

parent c1eee8c6
--- a/__init__.py
+++ b/__init__.py
@@ -1,3 +1,3 @@
-from .dataset import CompressedATRDataset, DefaultTransforms
+from .dataset import CompressedATRDataset
 from .loss import RegionLoss
 from .model import TinyYoloPL
--- a/dataset.py
+++ b/dataset.py
@@ -3,7 +3,6 @@ import tarfile
 from pathlib import Path
 from typing import Union
 
-import imgaug.augmenters as iaa
 import numpy as np
 import torch
 import torchvision.transforms as transforms
@@ -16,13 +15,11 @@ PathLike = Union[Path, str]
 class CompressedATRDataset(Dataset):
-    loaded_size = 640, 480
-    output_size = 640, 640
+    image_size = 640, 480
 
     def __init__(
         self,
         tar_dataset: PathLike,
-        transform=None,
         n_boxes: int = 10,
         extract_to: PathLike = None,
     ):
@@ -40,17 +37,15 @@ class CompressedATRDataset(Dataset):
         tar.close()
         self.image_files = list(self.temp_dir.glob("**/*.png"))
         logger.info("Loaded %d images", len(self.image_files))
-        self.transform = transform
         self.n_ret_boxes = n_boxes
 
     def __getitem__(self, index):
         image_path = self.image_files[index]
         image = Image.open(image_path).convert("RGB")
-        assert image.size == self.loaded_size
+        assert image.size == self.image_size
         image_np = np.array(image)
         boxes_np = self._read_boxes(image_path.with_suffix(".txt"), image)
-        if self.transform is not None:
-            image_np, boxes_np = self.transform(np.array(image), boxes_np)
+        image_np, boxes_np = self._bbox_transform(image_np, boxes_np)
         image_tensor = transforms.ToTensor()(image_np)
         boxes_tensor = torch.tensor(boxes_np, dtype=torch.float)
         return image_tensor, boxes_tensor, image_path.as_posix()
@@ -68,26 +63,15 @@ class CompressedATRDataset(Dataset):
     def __len__(self):
         return len(self.image_files)
 
-
-class DefaultTransforms:
-    def __init__(self):
-        self.augmentations = iaa.PadToAspectRatio(
-            1.0, position="center-center"
-        ).to_deterministic()
-
-    def __call__(self, img, boxes):
+    @classmethod
+    def _bbox_transform(cls, image: np.ndarray, boxes: np.ndarray):
         # Convert xywh to xyxy
         boxes = np.array(boxes)
         boxes[:, 1:] = xywh2xyxy_np(boxes[:, 1:])
 
         # Convert bounding boxes to imgaug
         bounding_boxes = BoundingBoxesOnImage(
-            [BoundingBox(*box[1:], label=box[0]) for box in boxes], shape=img.shape
+            [BoundingBox(*box[1:], label=box[0]) for box in boxes], shape=image.shape
         )
 
-        # Apply augmentations
-        img, bounding_boxes = self.augmentations(
-            image=img, bounding_boxes=bounding_boxes
-        )
-
         # Clip out of image boxes
@@ -109,7 +93,7 @@ class DefaultTransforms:
             boxes[box_idx, 3] = x2 - x1
             boxes[box_idx, 4] = y2 - y1
 
-        return img, boxes
+        return image, boxes
 
 
 def xywh2xyxy_np(x: np.ndarray):
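With DefaultTransforms gone, the dataset no longer pads images to a square 640x640 (the deleted output_size); every sample stays at 640x480, and _bbox_transform only converts the boxes to corner coordinates, clips them to the image, and writes the width/height back. The body of xywh2xyxy_np is collapsed in the diff; for reference, here is a sketch of the usual center-xywh to corner-xyxy conversion that its name and call site imply (a hypothetical reconstruction, not the file's actual body):

import numpy as np

def xywh2xyxy_np(x: np.ndarray) -> np.ndarray:
    # Sketch: (center_x, center_y, w, h) -> (x1, y1, x2, y2)
    y = np.zeros_like(x)
    y[..., 0] = x[..., 0] - x[..., 2] / 2  # x1 = cx - w/2
    y[..., 1] = x[..., 1] - x[..., 3] / 2  # y1 = cy - h/2
    y[..., 2] = x[..., 0] + x[..., 2] / 2  # x2 = cx + w/2
    y[..., 3] = x[..., 1] + x[..., 3] / 2  # y2 = cy + h/2
    return y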
--- a/model.py
+++ b/model.py
@@ -8,7 +8,7 @@ import torch
 import torch.nn as nn
 from torch.utils.data.dataloader import DataLoader
 
-from .dataset import CompressedATRDataset, DefaultTransforms
+from .dataset import CompressedATRDataset
 from .loss import RegionLoss
 
 DEFAULT_ANCHORS = [
@@ -67,7 +67,7 @@ class TinyYoloV2(lnn.module.Darknet):
             ('9_convbatch', lnn.layer.Conv2dBatchReLU(128, 256, 3, 1, 1, momentum=momentum)),
             ('10_max', nn.MaxPool2d(2, 2)),
             ('11_convbatch', lnn.layer.Conv2dBatchReLU(256, 384, 3, 1, 1, momentum=momentum)),
-            ('12_max', lnn.layer.PaddedMaxPool2d(2, 1, (0, 1, 0, 1))),
+            ('12_max', nn.MaxPool2d(3, 1, padding=1)),
             ('13_convbatch', lnn.layer.Conv2dBatchReLU(384, 384, 3, 1, 1, momentum=momentum)),
             ('14_convbatch', lnn.layer.Conv2dBatchReLU(384, 384, 3, 1, 1, momentum=momentum)),
             ('15_conv', nn.Conv2d(384, len(self.anchors)*(5+self.num_classes), 1, 1, 0)),
@@ -125,5 +125,5 @@ class TinyYoloPL(pl.LightningModule):
         return loss
 
     def test_dataloader(self):
-        dataset = CompressedATRDataset(self.dataset_path, DefaultTransforms())
+        dataset = CompressedATRDataset(self.dataset_path)
         return DataLoader(dataset, batch_size=16)
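The 12_max swap in TinyYoloV2 is the architectural substance of the commit: lightnet's stride-1 PaddedMaxPool2d(2, 1, (0, 1, 0, 1)) pads only the right and bottom edges, and such asymmetric padding is presumably what NVDLA cannot map, while nn.MaxPool2d(3, 1, padding=1) uses a symmetric pad it can. Both layers preserve the spatial size at stride 1, since (H + 1 - 2) + 1 = H for the old window and (H + 2 - 3) + 1 = H for the new one, though the 3x3 window computes a slightly different function. A minimal shape check (a sketch; approximating the old layer with an explicit zero pad is an assumption, as lightnet's actual padding mode may differ):

import torch
import torch.nn as nn

# Approximation of the old layer: pad right/bottom by 1, then 2x2 max pool at
# stride 1 (zero padding is assumed; lightnet's padding mode may differ).
old_like = nn.Sequential(nn.ZeroPad2d((0, 1, 0, 1)), nn.MaxPool2d(2, 1))

# The replacement: 3x3 window, stride 1, symmetric padding of 1.
new_pool = nn.MaxPool2d(3, 1, padding=1)

x = torch.randn(1, 384, 20, 20)  # 384 channels, matching 11_convbatch's output
assert old_like(x).shape == new_pool(x).shape == x.shape  # both preserve HxW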