【python】OpenCV—Tracking(10.3)——GOTURN-代码实现

时间:2024-10-28 13:26:48
# Import modules
import cv2, sys, os

# Frame counter; the tracking loop uses it to name the saved frame images.
num = 0

# Both GOTURN model files must be present in the current working directory.
if not (os.path.isfile('goturn.caffemodel') and os.path.isfile('goturn.prototxt')):
    errorMsg = '''
    Could not find GOTURN model in current directory.
    Please ensure goturn.caffemodel and goturn.prototxt are in the current directory
    '''

    # Report on stderr and exit with a non-zero status so that shells and
    # calling scripts can tell the run failed (a bare sys.exit() reports success).
    print(errorMsg, file=sys.stderr)
    sys.exit(1)

导入必要的库函数,判断 caffe 模型文件是否存在,模型下载地址见本博客最后的参考

# Create the GOTURN tracker
tracker = cv2.TrackerGOTURN_create()

# Open the input video
video = cv2.VideoCapture("tracking2.mkv")

# Abort with a non-zero exit status if the video could not be opened.
# sys.exit(<str>) writes the message to stderr and exits with status 1.
if not video.isOpened():
    sys.exit("Could not open video")

# Read the first frame; it is the frame the tracker gets initialised with
ok, frame = video.read()

if not ok:
    # Release the capture before bailing out so the file handle is not leaked.
    video.release()
    sys.exit("Cannot read video file")

创建跟踪器,读取视频第一帧

# Initial target region as (x, y, width, height); (x, y) is the top-left corner.
# An alternative region that can be used instead: (25, 65, 35, 90)
# To pick the region interactively, use: bbox = cv2.selectROI(frame, False)
bbox = 90, 85, 50, 105

# Hand the tracker the first frame together with the region it should follow.
ok = tracker.init(frame, bbox)

用第一帧初始化跟踪器,作为待跟踪的模板,bbox 配置是目标左上角的横纵坐标,以及目标的宽和高

# Track the target through the remaining frames. Each processed frame is
# annotated, shown in a window and saved as a numbered JPEG in the current
# directory. The try/finally makes sure the capture and the preview window are
# released however the loop ends (end of video, ESC, or an exception).
try:
    while True:
        num += 1
        # Read a new frame; stop when the video is exhausted
        ok, frame = video.read()
        if not ok:
            break

        # Start timer
        timer = cv2.getTickCount()

        # Update tracker
        ok, bbox = tracker.update(frame)

        # Calculate frames per second (FPS) for this update. Guard against a
        # zero tick delta so a coarse timer cannot raise ZeroDivisionError.
        elapsed_ticks = cv2.getTickCount() - timer
        fps = cv2.getTickFrequency() / elapsed_ticks if elapsed_ticks > 0 else 0.0

        if ok:
            # Tracking success: bbox is (x, y, width, height); convert it to
            # integer top-left / bottom-right corners for drawing.
            p1 = (int(bbox[0]), int(bbox[1]))
            p2 = (int(bbox[0] + bbox[2]), int(bbox[1] + bbox[3]))
            cv2.rectangle(frame, p1, p2, (0, 0, 255), 2, 1)
        else:
            # Tracking failure
            cv2.putText(frame, "Tracking failure detected", (100, 80),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.75, (0, 0, 255), 2)

        # Display tracker type on frame
        cv2.putText(frame, "GOTURN Tracker", (100, 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50, 170, 50), 2)

        # Display FPS on frame
        cv2.putText(frame, "FPS : " + str(int(fps)), (100, 50),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.75, (50, 170, 50), 2)

        # Display result
        cv2.imshow("Tracking", frame)

        # Save the annotated frame as 001.jpg, 002.jpg, ...
        cv2.imwrite(f"{num:03d}.jpg", frame)

        # Exit if ESC pressed
        k = cv2.waitKey(1) & 0xff
        if k == 27:
            break
finally:
    video.release()
    cv2.destroyAllWindows()

遍历视频中的每一帧,调用 ok, bbox = tracker.update(frame) 进行跟踪,在帧上绘制跟踪框、跟踪器名称和帧率,保存并显示每一帧结果,按 ESC 键退出