# EigenCAM for YOLO5 — Advanced AI explainability with pytorch-gradcam (original) (raw)

import io
import warnings

# Silence warnings before the heavy third-party imports below are loaded.
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

import cv2
import numpy as np
import requests
import torch
import torchvision.transforms as transforms
from PIL import Image
from pytorch_grad_cam import EigenCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, scale_cam_image

# One random colour (three values in [0, 255)) per class id.
# NOTE(review): 80 rows presumably matches the COCO label set of the
# pretrained model — confirm if a custom model is used.
COLORS = np.random.uniform(0, 255, size=(80, 3))


def parse_detections(results, conf_threshold=0.2):
    """Extract boxes, colours and class names from a detection result.

    Args:
        results: Object whose ``results.pandas().xyxy[0]`` is a table with
            the columns ``xmin``, ``ymin``, ``xmax``, ``ymax``,
            ``confidence``, ``class`` and ``name`` (one row per detection).
        conf_threshold: Detections whose confidence is strictly below this
            value are dropped. Defaults to 0.2, the previously hard-coded
            cut-off.

    Returns:
        A tuple ``(boxes, colors, names)`` of three equally long lists:
        ``boxes`` holds ``(xmin, ymin, xmax, ymax)`` integer tuples,
        ``colors`` holds the matching rows of ``COLORS`` and ``names`` the
        class names, in table order.
    """
    detections = results.pandas().xyxy[0]
    boxes, colors, names = [], [], []

    # Row-wise records avoid positional lookups into label-keyed columns,
    # so a non-default table index cannot cause a KeyError.
    for row in detections.to_dict("records"):
        if row["confidence"] < conf_threshold:
            continue

        # int() truncates the coordinates, as the drawing code needs ints.
        box = tuple(int(row[key]) for key in ("xmin", "ymin", "xmax", "ymax"))
        category = int(row["class"])

        boxes.append(box)
        # Wrap around so class ids beyond the palette reuse a colour
        # instead of raising IndexError.
        colors.append(COLORS[category % len(COLORS)])
        names.append(row["name"])

    return boxes, colors, names

def draw_detections(boxes, colors, names, img):
    """Draw labelled bounding boxes onto an image.

    Args:
        boxes: Iterable of ``(xmin, ymin, xmax, ymax)`` integer tuples.
        colors: Iterable of colours, one per box, passed to OpenCV as-is.
        names: Iterable of label strings, one per box.
        img: Image array the boxes are drawn on. It is modified in place;
            pass a copy to keep the original untouched.

    Returns:
        The same ``img`` object, after drawing.
    """
    for (xmin, ymin, xmax, ymax), color, name in zip(boxes, colors, names):
        cv2.rectangle(img, (xmin, ymin), (xmax, ymax), color, 2)

        # putText anchors the text at its bottom-left corner. Keep that
        # anchor low enough that the label stays inside the image when the
        # box touches the top edge; otherwise it sits 5 px above the box.
        (_, text_height), _ = cv2.getTextSize(
            name, cv2.FONT_HERSHEY_SIMPLEX, 0.8, 2)
        label_y = max(ymin - 5, text_height + 5)

        cv2.putText(img, name, (xmin, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, color, 2,
                    lineType=cv2.LINE_AA)
    return img

# Download the sample image and prepare two views of it:
#   rgb_img - the resized 640x640 pixel array, kept for detection and drawing
#   tensor  - the same image scaled to [0, 1] as a batched tensor
image_url = "https://upload.wikimedia.org/wikipedia/commons/f/f1/Puppies_%284984818141%29.jpg"
response = requests.get(
    image_url,
    # NOTE(review): Wikimedia's servers may reject the default
    # python-requests User-Agent — confirm against their User-Agent policy.
    headers={"User-Agent": "eigencam-yolov5-demo/1.0 (python-requests)"},
    timeout=30,  # never hang forever on a stalled connection
)
# Fail loudly on an HTTP error instead of handing an error page to PIL.
response.raise_for_status()
with Image.open(io.BytesIO(response.content)) as pil_image:
    # Force three channels so greyscale or RGBA sources behave the same.
    img = np.array(pil_image.convert("RGB"))
img = cv2.resize(img, (640, 640))
rgb_img = img.copy()
img = np.float32(img) / 255
transform = transforms.ToTensor()
tensor = transform(img).unsqueeze(0)

# Load the pretrained 'yolov5s' model through torch.hub, switch it to
# evaluation mode and keep it on the CPU.
model = torch.hub.load('ultralytics/yolov5', 'yolov5s', pretrained=True)
model.eval()
model.cpu()

# Second-to-last entry of the innermost module sequence.
# NOTE(review): presumably the layer later handed to EigenCAM — confirm.
target_layers = [model.model.model.model[-2]]

# Run the detector on the resized image, parse its output and draw the
# detections onto a copy so rgb_img itself stays unmodified.
results = model([rgb_img])
boxes, colors, names = parse_detections(results)
detections = draw_detections(boxes, colors, names, rgb_img.copy())

# Bare expression: presumably displayed as the cell output when this runs
# in a notebook; as a plain script the resulting image is discarded.
Image.fromarray(detections)