How to use mmpretrain (3)
mmpretrain study
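In parts (1) and (2), I studied how to modify config files, prepare datasets, and set up the related settings.
This time, let's look at how to run inference with models that already exist.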
Inference with existing models
API
The APIs needed for inference are summarized below.
API | Description
--- | ---
list_models | List available model names in MMPreTrain.
get_model | Get a model from a model name or model config.
inference_model | Run inference with a model using the corresponding inferencer. It is a shortcut for a quick start; for advanced usage, use the inferencers below directly.
ImageClassificationInferencer | Perform image classification on the given image.
ImageRetrievalInferencer | Perform image-to-image retrieval from the given image on a given image set.
ImageCaptionInferencer | Generate a caption for the given image.
VisualQuestionAnsweringInferencer | Answer a question about the given image.
VisualGroundingInferencer | Locate an object described in text on the given image.
TextToImageRetrievalInferencer | Perform text-to-image retrieval from the given description on a given image set.
ImageToTextRetrievalInferencer | Perform image-to-text retrieval from the given image on a series of texts.
NLVRInferencer | Perform Natural Language for Visual Reasoning on a given image pair and text.
FeatureExtractor | Extract features from image files using a vision backbone.
List available models
To see which models mmpretrain provides, use list_models as shown below.
You can also filter the names with a wildcard pattern using *.
>>> from mmpretrain import list_models
>>> list_models()
['barlowtwins_resnet50_8xb256-coslr-300e_in1k',
'beit-base-p16_beit-in21k-pre_3rdparty_in1k',
...]
>>> from mmpretrain import list_models
>>> list_models("*convnext-b*21k")
['convnext-base_3rdparty_in21k',
'convnext-base_in21k-pre-3rdparty_in1k-384px',
'convnext-base_in21k-pre_3rdparty_in1k']
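The two outputs above suggest that the pattern is matched shell-style and that names with any trailing suffix also match: the '*convnext-b*21k' query also returns names ending in in1k-384px. Below is a minimal sketch of that filtering with Python's fnmatch, assuming list_models behaves roughly like this. filter_model_names is a hypothetical helper, not part of mmpretrain, and the name list contains only names that appear in this post.

```python
from fnmatch import fnmatchcase


def filter_model_names(names, pattern):
    """Hypothetical helper: keep names that match a shell-style pattern.

    A trailing '*' is appended so that names with any suffix also match,
    mirroring how '*convnext-b*21k' also returned '...in1k-384px' above.
    """
    return sorted(n for n in names if fnmatchcase(n, pattern + "*"))


# Model names taken from the outputs shown in this post (not the full list).
names = [
    "barlowtwins_resnet50_8xb256-coslr-300e_in1k",
    "beit-base-p16_beit-in21k-pre_3rdparty_in1k",
    "convnext-base_3rdparty_in21k",
    "convnext-base_in21k-pre-3rdparty_in1k-384px",
    "convnext-base_in21k-pre_3rdparty_in1k",
    "resnet50_8xb32_in1k",
]
print(filter_model_names(names, "*convnext-b*21k"))
```

With the implicit trailing wildcard, a plain prefix such as "resnet50" also works as a pattern.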
The list of models available for a specific task can be checked as follows.
>>> from mmpretrain import ImageCaptionInferencer
>>> ImageCaptionInferencer.list_models()
['blip-base_3rdparty_caption',
'blip2-opt2.7b_3rdparty-zeroshot_caption',
'flamingo_3rdparty-zeroshot_caption',
'ofa-base_3rdparty-finetuned_caption']
Get a model
A model can be obtained as follows.
>>> from mmpretrain import get_model
>>> # Get the model without loading pre-trained weights.
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k")
>>> # Get the model and load the default checkpoint.
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", pretrained=True)
>>> # Get the model and load the specified checkpoint.
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", pretrained="your_local_checkpoint_path")
>>> # Get the model with extra initialization arguments, for example, modify num_classes in the head.
>>> model = get_model("convnext-base_in21k-pre_3rdparty_in1k", head=dict(num_classes=10))
>>> # Another example: remove the neck and head, and output from stages 1, 2, 3 of the backbone.
>>> model_headless = get_model("resnet18_8xb32_in1k", head=None, neck=None, backbone=dict(out_indices=(1, 2, 3)))
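The head=..., neck=None and backbone=dict(...) arguments read like overrides merged into the model config: a nested dict updates only the keys it mentions, while None drops a component entirely. Below is a minimal sketch of such a recursive merge, assuming get_model applies the extra arguments roughly this way. merge_model_cfg and the simplified config are hypothetical and are not copied from mmpretrain.

```python
import copy


def merge_model_cfg(base, overrides):
    """Hypothetical helper: recursively merge keyword overrides into a config.

    A dict override updates the matching sub-dict key by key; any other
    value (including None) replaces the entry outright.
    """
    merged = copy.deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = merge_model_cfg(merged[key], value)
        else:
            merged[key] = value
    return merged


# Simplified, illustrative classifier config (hypothetical, not from mmpretrain).
base_cfg = dict(
    backbone=dict(type="ResNet", depth=18, out_indices=(3,)),
    neck=dict(type="GlobalAveragePooling"),
    head=dict(type="LinearClsHead", num_classes=1000, in_channels=512),
)

# Only num_classes changes; type and in_channels are kept.
ten_class = merge_model_cfg(base_cfg, dict(head=dict(num_classes=10)))
# Neck and head are dropped; only backbone.out_indices changes.
headless = merge_model_cfg(
    base_cfg, dict(head=None, neck=None, backbone=dict(out_indices=(1, 2, 3)))
)
print(ten_class["head"])
print(headless)
```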
The model obtained above is a regular PyTorch module.
>>> import torch
>>> from mmpretrain import get_model
>>> model = get_model('convnext-base_in21k-pre_3rdparty_in1k', pretrained=True)
>>> x = torch.rand((1, 3, 224, 224))
>>> y = model(x)
>>> print(type(y), y.shape)
<class 'torch.Tensor'> torch.Size([1, 1000])
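The [1, 1000] output presumably holds one raw score (logit) per ImageNet class, not probabilities. Below is a minimal sketch of turning one such row into probabilities and top-k predictions. It uses plain Python so that it runs without PyTorch, and the toy logits are made up. With a real tensor you would typically call torch.softmax and torch.topk instead.

```python
import math


def softmax(logits):
    """Convert raw scores (logits) into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def topk(probs, k):
    """Return (index, probability) pairs for the k largest probabilities."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in order[:k]]


# Toy 5-class logits standing in for one row of the [1, 1000] output above.
logits = [1.0, 3.0, 0.5, 2.0, -1.0]
probs = softmax(logits)
print(topk(probs, 2))
```

Softmax preserves the ordering of the logits, so the top-k classes are the same whether you rank logits or probabilities.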
Inference on given images
To run inference on an image with a pre-trained ResNet-50 model, use the inference_model function as shown below.
>>> from mmpretrain import inference_model
>>> image = 'https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG'
>>> # If you have no graphical interface, please set `show=False`
>>> result = inference_model('resnet50_8xb32_in1k', image, show=True)
>>> print(result['pred_class'])
sea snake
To run inference on multiple samples, use an inferencer such as ImageClassificationInferencer directly, as shown below.
>>> from mmpretrain import ImageClassificationInferencer
>>> image = 'https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG'
>>> inferencer = ImageClassificationInferencer('resnet50_8xb32_in1k')
>>> # Note that the inferencer output is a list of result even if the input is a single sample.
>>> result = inferencer(image)[0]
>>> print(result['pred_class'])
sea snake
>>>
>>> # You can also use it for multiple images.
>>> image_list = ['demo/demo.JPEG', 'demo/bird.JPEG'] * 16
>>> results = inferencer(image_list, batch_size=8)
>>> print(len(results))
32
>>> print(results[1]['pred_class'])
house finch, linnet, Carpodacus mexicanus
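With 32 inputs and batch_size=8, the inputs are presumably processed in 4 batches of 8, and the results come back as one flat list in input order. That is why results[1] (house finch) corresponds to the second entry, demo/bird.JPEG. Below is a minimal sketch of that chunking, assuming simple consecutive batching. chunked is a hypothetical helper, not an mmpretrain function.

```python
def chunked(items, batch_size):
    """Hypothetical helper: split a list into consecutive batches."""
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]


image_list = ["demo/demo.JPEG", "demo/bird.JPEG"] * 16
batches = chunked(image_list, 8)

# Flattening the batches restores the original order, so the i-th
# result lines up with the i-th input image.
flat = [path for batch in batches for path in batch]
print(len(batches), [len(b) for b in batches])
print(flat[1])
```

If the input length is not a multiple of batch_size, the last batch is simply shorter.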
The prediction result is a dictionary like the following.
{
"pred_label": 65,
"pred_score": 0.6649366617202759,
"pred_class":"sea snake",
"pred_scores": array([..., 0.6649366617202759, ...], dtype=float32)
}
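These fields appear to be related. pred_scores holds one score per class, pred_label is the index of the highest score, pred_score is that score, and pred_class is the class name at that index: in the output above, label 65 with score 0.6649... maps to "sea snake". Below is a minimal sketch that rebuilds the top-1 fields from a plain list of scores. to_result and the toy data are hypothetical.

```python
def to_result(pred_scores, class_names):
    """Hypothetical helper: derive the top-1 fields from a score vector."""
    label = max(range(len(pred_scores)), key=lambda i: pred_scores[i])
    return {
        "pred_label": label,
        "pred_score": pred_scores[label],
        "pred_class": class_names[label],
        "pred_scores": pred_scores,
    }


# Made-up 3-class scores and names, for illustration only.
scores = [0.1, 0.7, 0.2]
names = ["cat", "sea snake", "dog"]
result = to_result(scores, names)
print(result["pred_label"], result["pred_class"], result["pred_score"])
```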
To use your own config file and checkpoint file, do the following.
>>> from mmpretrain import ImageClassificationInferencer
>>> image = 'https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG'
>>> config = 'configs/resnet/resnet50_8xb32_in1k.py'
>>> checkpoint = 'https://download.openmmlab.com/mmclassification/v0/resnet/resnet50_8xb32_in1k_20210831-ea4938fc.pth'
>>> inferencer = ImageClassificationInferencer(model=config, pretrained=checkpoint, device='cuda')
>>> result = inferencer(image)[0]
>>> print(result['pred_class'])
sea snake
Inference by a Gradio demo
This is a GUI demo tool provided by mmpretrain.
Installation: pip install -U gradio
I couldn't get it to run... I'll look into the cause later.
Extract Features From Image
model.extract_feat and FeatureExtractor are used to extract image features directly.
The input to model.extract_feat is a torch.Tensor, while the input to FeatureExtractor is an image (for example, a file path or URL).
Here is an example:
>>> from mmpretrain import FeatureExtractor, get_model
>>> model = get_model('resnet50_8xb32_in1k', backbone=dict(out_indices=(0, 1, 2, 3)))
>>> extractor = FeatureExtractor(model)
>>> features = extractor('https://github.com/open-mmlab/mmpretrain/raw/main/demo/demo.JPEG')[0]
>>> features[0].shape, features[1].shape, features[2].shape, features[3].shape
(torch.Size([256]), torch.Size([512]), torch.Size([1024]), torch.Size([2048]))
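The printed shapes follow from the ResNet-50 architecture. Each stage's output channel count is its base width times the bottleneck expansion of 4. The features are 1-D because, in this config, the neck (global average pooling) presumably collapses the spatial dimensions. Below is a minimal arithmetic sketch, assuming a 224x224 input and the standard ResNet-50 layout. resnet50_stage_shapes is a hypothetical helper.

```python
def resnet50_stage_shapes(input_size=224):
    """Hypothetical helper: (channels, height, width) of each ResNet-50 stage.

    Assumes the standard design: a stride-4 stem (stride-2 conv + stride-2
    max-pool), stage base widths 64/128/256/512, bottleneck expansion 4,
    and stride 2 at the start of stages 2-4.
    """
    expansion = 4
    size = input_size // 4  # spatial size after the stem
    shapes = []
    for stage, width in enumerate([64, 128, 256, 512]):
        if stage > 0:
            size //= 2  # stages 2-4 halve the spatial resolution
        shapes.append((width * expansion, size, size))
    return shapes


shapes = resnet50_stage_shapes()
print(shapes)
# Global average pooling collapses each (C, H, W) map to a C-dim vector.
print([c for c, _, _ in shapes])
```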
Reference
https://mmpretrain.readthedocs.io/en/latest/user_guides/inference.html