본문 바로가기
Two/영상처리

mmpretrain 적용 - CIFAR10

by DH Father 2024. 1. 8.

mmpretrain 적용

CIFAR10



mmpretrain 적용 - CIFAR10

mmpretrain 스터디를 했으니, 실제 데이터를 적용해보았다.

 

Dataset

적용할 데이터셋은 CIFAR10 데이터셋이다. 

CIFAR10 데이터셋에 대해 정리한 내용은 아래와 같다. ( https://www.cs.toronto.edu/~kriz/cifar.html )

총 이미지수 60,000장
이미지 사이즈 32 x32
classes 10
class 당 이미지 수  6,000
train image 수 50,000
test image 수 10,000
classes name airport, automobile, bird, cat, deer, dog, frog, horse, ship, truch

 

checkpoints 파일 다운로드

mmpretrain documents 나 github에 보면 다양한 모델들을 지원하고 있다. 

아래는 documentation에 있는 Model Zoo의 내용이다. 

나는 resnet18_8xb16_cifar10 모델을 다운로드 했다. 그러면 .pth 파일이 다운로드 된다. 

mmpretrain 작업폴더에 checkpoints 폴더를 하나 생성해주고, 그곳에 다운로드 받은 파일을 넣어준다. 

 

Config 파일 생성 및 수정

mmpretrian/config_work 폴더를 생성해준다. 이름은 각자 알아서.

아까 .pth 파일을 다운로드 옆에 Config가 있다. 

이걸 클릭해주면 해당 모델의 github 파일로 이동한다. 아래와 같이 있을것이다. 

_base_ = [
    '../_base_/models/resnet18_cifar.py', '../_base_/datasets/cifar10_bs16.py',
    '../_base_/schedules/cifar10_bs128.py', '../_base_/default_runtime.py'
]

해당 파일들을 내용을 하나의 파일을 생성(resnet18_b16x8_cifar10_config.py) 해서 다 작성해주고, 필요한 부분을 수정해준다. 

각각의 해당 파일에 해당하는 부분을 찾아서 수정해줘도 되지만 그러면 나중에 복잡해지고 헷갈린다. 

아래는 위의 4개 파일을 전체 합친 Config 파일(config_work/ resnet18_b16x8_cifar10_config.py)이다. 

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='ResNet_CIFAR',
        depth=18,
        num_stages=4,
        out_indices=(3, ),
        style='pytorch'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=10,
        in_channels=512,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
    ))


# dataset settings
dataset_type = 'CIFAR10'
data_preprocessor = dict(
    num_classes=10,
    # RGB format normalization parameters
    mean=[125.307, 122.961, 113.8575],
    std=[51.5865, 50.847, 51.255],
    # loaded images are already RGB format
    to_rgb=False)

train_pipeline = [
    dict(type='RandomCrop', crop_size=32, padding=4),
    dict(type='RandomFlip', prob=0.5, direction='horizontal'),
    dict(type='PackInputs'),
]

test_pipeline = [
    dict(type='PackInputs'),
]

train_dataloader = dict(
    batch_size=16,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_root='data/cifar10',
        split='train',
        pipeline=train_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=True),
)

val_dataloader = dict(
    batch_size=16,
    num_workers=2,
    dataset=dict(
        type=dataset_type,
        data_root='data/cifar10/',
        split='test',
        pipeline=test_pipeline),
    sampler=dict(type='DefaultSampler', shuffle=False),
)
val_evaluator = dict(type='Accuracy', topk=(1, ))

test_dataloader = val_dataloader
test_evaluator = val_evaluator


# optimizer
optim_wrapper = dict(
    optimizer=dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001))
# learning policy
param_scheduler = dict(
    type='MultiStepLR', by_epoch=True, milestones=[100, 150], gamma=0.1)

# train, val, test setting
train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=1)
val_cfg = dict()
test_cfg = dict()

# NOTE: `auto_scale_lr` is for automatically scaling LR
# based on the actual training batch size.
auto_scale_lr = dict(base_batch_size=128)


# defaults to use registries in mmpretrain
default_scope = 'mmpretrain'

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
    timer=dict(type='IterTimerHook'),

    # print log every 100 iterations.
    logger=dict(type='LoggerHook', interval=100),

    # enable the parameter scheduler.
    param_scheduler=dict(type='ParamSchedulerHook'),

    # save checkpoint per epoch.
    checkpoint=dict(type='CheckpointHook', interval=1),

    # set sampler seed in distributed evrionment.
    sampler_seed=dict(type='DistSamplerSeedHook'),

    # validation results visualization, set True to enable it.
    visualization=dict(type='VisualizationHook', enable=False),
)

# configure environment
env_cfg = dict(
    # whether to enable cudnn benchmark
    cudnn_benchmark=False,

    # set multi process parameters
    mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),

    # set distributed parameters
    dist_cfg=dict(backend='nccl'),
)

# set visualizer
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(type='UniversalVisualizer', vis_backends=vis_backends)

# set log level
log_level = 'INFO'

# load from which checkpoint
load_from = None

# whether to resume training from the loaded checkpoint
resume = False

# Defaults to use random seed and disable `deterministic`
randomness = dict(seed=None, deterministic=False)

 

간략하게 살펴보면 

현재 모델의 경우, mmpretrain 에서 지원하는 CIFAR10 모델과 그에 따른 config 파일이라 수정없이 사용했다.

 

하지만 custom 데이터를 사용한다면 아래의 부분을 중점으로 변경이 필요하다. 

 

model settings 의 경우,

backbone, neck, head로 구성되어있고, 

model.type 은 'ImageClassifier' 이다. 

head.num_classes=10 이다. custom 데이터셋을 사용하는 경우, classes에 맞게 수정해준다.

 

datas settings 의 경우,

dataset_type = 'CIRAR10' , num_classess = 10 이 부분을 필요에 맞게 수정해줘야한다. 

train_dataloader, val_dataloader 에서는 

batch_size, num_workers, data_root 부분을 수정해 주어야한다. 

 

optimizer, learning policy 의 경우,

필요에 따라 type, lr, momentum, weight_decay 부분을 수정한다. 

 

train, val, test settings 의 경우,

max_epochs, val_interval 부분을 수정한다. 

 

Train

일단 큰 수정없이 아래와 같은 명령어로 학습을 시켜줬다. 

python tools/train.py config_work/resnet18_b16x8_cifar10_config.py

 

결과 dir 지정하지 않았다면, work_dirs/모델명 으로 결과에 대한 파일들이 생성된다.

아래는 vscode의 실행 결과 화면이다. 

accuracy/top1 : 94.83의 결과가 나왔다. 홈페이지와 같은 결과이다.  

 

Test

아래의 명령어와 같이 test를 실행해주었다. 

옵션은 문서를 참조하자. 

# python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} [ARGS]

python tools\test.py work_dirs\resnet18_b16x8_cifar10_config\20240104_130732\vis_data\config.py work_dirs\resnet18_b16x8_cifar10_config\20240104_130732\epoch_200.pth

 

visual attention network tiny 적용

연습으로 다른 네트워크 모델을 적용해 보았다.

 

정리

1. Check Point 파일 다운로드

2. 실행 Config 파일 생성 및 수정

3. Train 실행

4. Test 실행

 

 

Reference

https://mmpretrain.readthedocs.io/en/latest/index.html

'Two > 영상처리' 카테고리의 다른 글

mmdetection - tutorial (1)  (0) 2024.01.11
mmdetection 개요 및 설치  (0) 2024.01.08
mmpretrain 사용법 (4)  (0) 2023.12.31
mmpretrain 사용법 (3)  (1) 2023.12.30
mmpretrain 사용법 (2)  (0) 2023.12.29