728x90
반응형

Video Inference 수행

  • mmdetection의 video_demo.py 대로 Video inference 수행시 image 처리 시간이 상대적으로 오래 걸림.
  • 이미지 처리 로직을 변경하여 적용
!wget -O /content/data/John_Wick_small.mp4 https://github.com/chulminkw/DLCV/blob/master/data/video/John_Wick_small.mp4?raw=true
from mmdet.apis import init_detector, inference_detector
import mmcv

config_file = '/content/mmdetection/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py'
checkpoint_file = '/content/mmdetection/checkpoints/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
model = init_detector(config_file, checkpoint_file, device='cuda:0')
# https://github.com/open-mmlab/mmdetection/blob/master/demo/video_demo.py 대로 video detection 수행. 

import cv2

video_reader = mmcv.VideoReader('/content/data/John_Wick_small.mp4')
video_writer = None
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
video_writer = cv2.VideoWriter('/content/data/John_Wick_small_out1.mp4', fourcc, video_reader.fps,(video_reader.width, video_reader.height))

for frame in mmcv.track_iter_progress(video_reader):
  result = inference_detector(model, frame)
  frame = model.show_result(frame, result, score_thr=0.4)

  video_writer.write(frame)

if video_writer:
        video_writer.release()

Custom된 frame처리 로직을 적용하여 Video Inference 수행.

  • 기존에 사용한 get_detected_img()를 그대로 사용함.
# model과 원본 이미지 array, filtering할 기준 class confidence score를 인자로 가지는 inference 시각화용 함수 생성. 
import numpy as np

# 0부터 순차적으로 클래스 매핑된 label 적용. 
labels_to_names_seq = {0:'person',1:'bicycle',2:'car',3:'motorbike',4:'aeroplane',5:'bus',6:'train',7:'truck',8:'boat',9:'traffic light',10:'fire hydrant',
                        11:'stop sign',12:'parking meter',13:'bench',14:'bird',15:'cat',16:'dog',17:'horse',18:'sheep',19:'cow',20:'elephant',
                        21:'bear',22:'zebra',23:'giraffe',24:'backpack',25:'umbrella',26:'handbag',27:'tie',28:'suitcase',29:'frisbee',30:'skis',
                        31:'snowboard',32:'sports ball',33:'kite',34:'baseball bat',35:'baseball glove',36:'skateboard',37:'surfboard',38:'tennis racket',39:'bottle',40:'wine glass',
                        41:'cup',42:'fork',43:'knife',44:'spoon',45:'bowl',46:'banana',47:'apple',48:'sandwich',49:'orange',50:'broccoli',
                        51:'carrot',52:'hot dog',53:'pizza',54:'donut',55:'cake',56:'chair',57:'sofa',58:'pottedplant',59:'bed',60:'diningtable',
                        61:'toilet',62:'tvmonitor',63:'laptop',64:'mouse',65:'remote',66:'keyboard',67:'cell phone',68:'microwave',69:'oven',70:'toaster',
                        71:'sink',72:'refrigerator',73:'book',74:'clock',75:'vase',76:'scissors',77:'teddy bear',78:'hair drier',79:'toothbrush' }

def get_detected_img(model, img_array,  score_threshold=0.3, is_print=True):
  # 인자로 들어온 image_array를 복사. 
  draw_img = img_array.copy()
  bbox_color=(0, 255, 0)
  text_color=(0, 0, 255)

  # model과 image array를 입력 인자로 inference detection 수행하고 결과를 results로 받음. 
  # results는 80개의 2차원 array(shape=(오브젝트갯수, 5))를 가지는 list. 
  results = inference_detector(model, img_array)

  # 80개의 array원소를 가지는 results 리스트를 loop를 돌면서 개별 2차원 array들을 추출하고 이를 기반으로 이미지 시각화 
  # results 리스트의 위치 index가 바로 COCO 매핑된 Class id. 여기서는 result_ind가 class id
  # 개별 2차원 array에 오브젝트별 좌표와 class confidence score 값을 가짐. 
  for result_ind, result in enumerate(results):
    # 개별 2차원 array의 row size가 0 이면 해당 Class id로 값이 없으므로 다음 loop로 진행. 
    if len(result) == 0:
      continue
    
    # 2차원 array에서 5번째 컬럼에 해당하는 값이 score threshold이며 이 값이 함수 인자로 들어온 score_threshold 보다 낮은 경우는 제외. 
    result_filtered = result[np.where(result[:, 4] > score_threshold)]
    
    # 해당 클래스 별로 Detect된 여러개의 오브젝트 정보가 2차원 array에 담겨 있으며, 이 2차원 array를 row수만큼 iteration해서 개별 오브젝트의 좌표값 추출. 
    for i in range(len(result_filtered)):
      # 좌상단, 우하단 좌표 추출. 
      left = int(result_filtered[i, 0])
      top = int(result_filtered[i, 1])
      right = int(result_filtered[i, 2])
      bottom = int(result_filtered[i, 3])
      caption = "{}: {:.4f}".format(labels_to_names_seq[result_ind], result_filtered[i, 4])
      cv2.rectangle(draw_img, (left, top), (right, bottom), color=bbox_color, thickness=2)
      cv2.putText(draw_img, caption, (int(left), int(top - 7)), cv2.FONT_HERSHEY_SIMPLEX, 0.37, text_color, 1)
      if is_print:
        print(caption)

  return draw_img
import time

def do_detected_video(model, input_path, output_path, score_threshold, do_print=True):
    
    cap = cv2.VideoCapture(input_path)

    codec = cv2.VideoWriter_fourcc(*'XVID')

    vid_size = (round(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),round(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    vid_fps = cap.get(cv2.CAP_PROP_FPS)

    vid_writer = cv2.VideoWriter(output_path, codec, vid_fps, vid_size) 

    frame_cnt = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    print('총 Frame 갯수:', frame_cnt)
    btime = time.time()
    while True:
        hasFrame, img_frame = cap.read()
        if not hasFrame:
            print('더 이상 처리할 frame이 없습니다.')
            break
        stime = time.time()
        img_frame = get_detected_img(model, img_frame,  score_threshold=score_threshold, is_print=False)
        if do_print:
          print('frame별 detection 수행 시간:', round(time.time() - stime, 4))
        vid_writer.write(img_frame)
    # end of while loop

    vid_writer.release()
    cap.release()

    print('최종 detection 완료 수행 시간:', round(time.time() - btime, 4))
do_detected_video(model, '/content/data/John_Wick_small.mp4', '/content/data/John_Wick_small_out2.mp4', score_threshold=0.4, do_print=True)

frame별 detection 수행 시간: 0.304
frame별 detection 수행 시간: 0.3139
frame별 detection 수행 시간: 0.2853
frame별 detection 수행 시간: 0.2918
frame별 detection 수행 시간: 0.2905
frame별 detection 수행 시간: 0.299
frame별 detection 수행 시간: 0.3006
frame별 detection 수행 시간: 0.2887
frame별 detection 수행 시간: 0.295
frame별 detection 수행 시간: 0.2932
frame별 detection 수행 시간: 0.2991
frame별 detection 수행 시간: 0.303
frame별 detection 수행 시간: 0.2913
frame별 detection 수행 시간: 0.2953
frame별 detection 수행 시간: 0.3003
frame별 detection 수행 시간: 0.3022
frame별 detection 수행 시간: 0.3018
frame별 detection 수행 시간: 0.2954
frame별 detection 수행 시간: 0.2975
frame별 detection 수행 시간: 0.3034
frame별 detection 수행 시간: 0.292
frame별 detection 수행 시간: 0.288
frame별 detection 수행 시간: 0.2939
frame별 detection 수행 시간: 0.2911
frame별 detection 수행 시간: 0.2969
frame별 detection 수행 시간: 0.2903
frame별 detection 수행 시간: 0.2912
frame별 detection 수행 시간: 0.2906
frame별 detection 수행 시간: 0.2947
frame별 detection 수행 시간: 0.2956
frame별 detection 수행 시간: 0.2936
frame별 detection 수행 시간: 0.2939
frame별 detection 수행 시간: 0.2925
frame별 detection 수행 시간: 0.2939
frame별 detection 수행 시간: 0.3
frame별 detection 수행 시간: 0.2862
frame별 detection 수행 시간: 0.2961
frame별 detection 수행 시간: 0.2958
frame별 detection 수행 시간: 0.2943
frame별 detection 수행 시간: 0.2931
frame별 detection 수행 시간: 0.2856
frame별 detection 수행 시간: 0.2911
frame별 detection 수행 시간: 0.2967
frame별 detection 수행 시간: 0.2944
frame별 detection 수행 시간: 0.2933
frame별 detection 수행 시간: 0.2886
frame별 detection 수행 시간: 0.2926
frame별 detection 수행 시간: 0.2991
frame별 detection 수행 시간: 0.2963
frame별 detection 수행 시간: 0.2885
frame별 detection 수행 시간: 0.2917
frame별 detection 수행 시간: 0.292
frame별 detection 수행 시간: 0.3013
frame별 detection 수행 시간: 0.2963
frame별 detection 수행 시간: 0.2903
더 이상 처리할 frame이 없습니다.
최종 detection 완료 수행 시간: 17.738

 

 

 

 

반응형

+ Recent posts