In Part 1, we walked through the SAM2 fine-tuning training pipeline in detail.

SAM2 Model Fine-tuning Training, Validation and Prediction (Part 1) - CSDN Blog

This article focuses on how to implement model validation and prediction.

It covers:

1. Model validation: evaluating model performance when ground-truth masks are available

2. Zero-shot prediction: unsupervised prediction on entirely new images

I. Model Validation Pipeline

The validation stage runs on an annotated test set and quantifies model performance with IoU (Intersection over Union) and CPA (Correct Pixel Accuracy). The key steps are:

1. Data Preparation and Loading

# Configure paths
data_dir = "dataset"
images_dir = os.path.join(data_dir, "images")
masks_dir = os.path.join(data_dir, "masks")
train_csv = os.path.join(data_dir, "train.csv")

model_cfg = r"configs\sam2.1\sam2.1_hiera_l.yaml"
checkpoint_path = r"checkpoints\sam2.1_hiera_large.pt"
finetuned_weights = r"weights\best.pt"  # weights saved after fine-tuning
  • The dataset layout matches the training set: images/ and masks/ directories plus a train.csv index file
  • Load the SAM2 config file, the base checkpoint, and the fine-tuned weights saved during training
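Before validating, it is worth confirming that every (ImageId, MaskId) pair listed in train.csv actually exists under images/ and masks/. A minimal, stdlib-only sketch (check_dataset is a hypothetical helper, not part of the SAM2 code; the column names match those used in the complete code later in this article):

```python
import csv
import os
import tempfile

def check_dataset(data_dir, index_csv="train.csv"):
    """Return paths from the index that are missing on disk."""
    missing = []
    with open(os.path.join(data_dir, index_csv), newline="") as f:
        for row in csv.DictReader(f):
            for subdir, key in (("images", "ImageId"), ("masks", "MaskId")):
                path = os.path.join(data_dir, subdir, row[key])
                if not os.path.isfile(path):
                    missing.append(path)
    return missing

# Self-contained demo on a throwaway dataset with one valid pair.
with tempfile.TemporaryDirectory() as d:
    os.makedirs(os.path.join(d, "images"))
    os.makedirs(os.path.join(d, "masks"))
    open(os.path.join(d, "images", "0.jpg"), "w").close()
    open(os.path.join(d, "masks", "0.png"), "w").close()
    with open(os.path.join(d, "train.csv"), "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["ImageId", "MaskId"])
        writer.writerow(["0.jpg", "0.png"])
    result = check_dataset(d)

print(result)  # an empty list: nothing is missing
```

Running this once before validation avoids confusing FileNotFoundError failures mid-run.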

2. Reading and Resizing the Image

def read_image(image_path, mask_path):
    img = cv2.imread(image_path)[..., ::-1]
    mask = cv2.imread(mask_path, 0)
    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
    img = cv2.resize(img, (int(img.shape[1]*r), int(img.shape[0]*r)))
    mask = cv2.resize(mask, (int(mask.shape[1]*r), int(mask.shape[0]*r)), interpolation=cv2.INTER_NEAREST)
    return img, mask
  • Masks must be resized with INTER_NEAREST interpolation; other modes blend labels and produce invalid class values
  • The image and mask must be scaled by the same factor to stay spatially aligned
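The INTER_NEAREST point can be demonstrated without OpenCV at all. In this numpy-only sketch, block averaging (what bilinear-style interpolation does at a label boundary) invents a class value that never existed in the mask, while nearest-neighbour picking keeps only existing labels:

```python
import numpy as np

# Toy label mask containing only classes 0 and 2; class 1 does not exist.
mask = np.array([[0, 2, 2, 2]] * 4, dtype=np.uint8)

# Averaging 2x2 blocks mimics bilinear interpolation at a label boundary:
# it blends 0 and 2 into the invalid class value 1.
linear = mask.reshape(2, 2, 2, 2).mean(axis=(1, 3))

# Picking one pixel per block mimics nearest-neighbour interpolation:
# only values already present in the mask can appear.
nearest = mask[::2, ::2]

print(1.0 in np.unique(linear))          # True: an invalid class was created
print(set(np.unique(nearest).tolist()))  # {0, 2}: only real classes survive
```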

3. Prompt Point Generation

def get_points(mask, num_points=30):
    coords = np.argwhere(mask > 0)  # coordinates of all foreground pixels
    points = []
    for _ in range(num_points):
        yx = coords[np.random.randint(len(coords))]  # pick a random foreground pixel
        points.append([[yx[1], yx[0]]])  # convert (y, x) to (x, y)
    return np.array(points)
  • Sampling from the ground-truth mask region matches real annotation scenarios more closely, but it requires a mask file.
  • Use more points for complex scenes (e.g. 50-100) and fewer for simple ones (10-20)
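A quick sanity check of get_points on a toy mask (the toy array is ours; the function is the one above): the result is a (num_points, 1, 2) array of (x, y) prompts, and every prompt lands on a foreground pixel:

```python
import numpy as np

def get_points(mask, num_points=30):
    coords = np.argwhere(mask > 0)
    points = []
    for _ in range(num_points):
        yx = coords[np.random.randint(len(coords))]
        points.append([[yx[1], yx[0]]])
    return np.array(points)

# Toy 8x8 mask with a 3x3 foreground square.
toy = np.zeros((8, 8), dtype=np.uint8)
toy[2:5, 3:6] = 1

pts = get_points(toy, num_points=10)
print(pts.shape)                                # (10, 1, 2)
print(all(toy[y, x] == 1 for [[x, y]] in pts))  # True: all points hit foreground
```

The extra middle dimension matters: predict() expects batched points, one (1, 2) prompt per mask.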

4. Model Loading and Inference

# Load the fine-tuned model
sam2_model = build_sam2(model_cfg, checkpoint_path, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.load_state_dict(torch.load(finetuned_weights))

# Run prediction
with torch.no_grad():
    predictor.set_image(image)  # extract image features
    masks, scores, _ = predictor.predict(
        point_coords=input_points,
        point_labels=np.ones([input_points.shape[0], 1])  # treat every point as a foreground prompt
    )
  • Setting every entry of point_labels to 1 marks all points as belonging to the target object

5. Post-processing the Results

# Sort masks and remove duplicates
np_masks = np.array(masks[:, 0])  # take the first output channel
sorted_masks = np_masks[np.argsort(scores[:, 0])][::-1]  # sort by confidence, descending

seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)

for i in range(sorted_masks.shape[0]):
    pred = sorted_masks[i].astype(bool)
    # skip masks that overlap already-claimed pixels by more than 15%
    if (pred & occupancy_mask).sum() / (pred.sum() + 1e-6) > 0.15:
        continue
    seg_map[pred & ~occupancy_mask] = i + 1  # assign a unique label
    occupancy_mask |= pred  # update the claimed region
  • occupancy_mask prevents predicted masks from overlapping
  • Each value marks a distinct instance, which suits instance-segmentation tasks
  • Name the loop variable pred rather than mask, so it does not shadow the ground-truth mask that the metric computation still needs later
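The greedy filter can be exercised on toy boolean masks. In this sketch (combine_masks is our own wrapper name for the loop above), two disjoint masks each receive a label, while an exact duplicate of the first is rejected because it overlaps already-claimed pixels by far more than 15%:

```python
import numpy as np

def combine_masks(sorted_masks, overlap_thresh=0.15):
    # Assign labels from highest to lowest confidence, skipping masks
    # that overlap already-claimed pixels by more than the threshold.
    seg_map = np.zeros(sorted_masks[0].shape, dtype=np.uint8)
    occupancy = np.zeros(sorted_masks[0].shape, dtype=bool)
    for i, m in enumerate(sorted_masks):
        m = m.astype(bool)
        if m.sum() == 0 or (m & occupancy).sum() / m.sum() > overlap_thresh:
            continue
        seg_map[m & ~occupancy] = i + 1
        occupancy |= m
    return seg_map

a = np.zeros((6, 6), dtype=bool); a[:3, :3] = True   # top-left square
b = np.zeros((6, 6), dtype=bool); b[3:, 3:] = True   # bottom-right square
dup = a.copy()                                       # duplicate of the first

seg = combine_masks(np.stack([a, b, dup]))
print(np.unique(seg))  # [0 1 2]: background plus two instances; duplicate dropped
```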

6. Computing Evaluation Metrics

def compute_metrics(pred_mask, true_mask):
    # binarize
    pred_binary = (pred_mask > 0).astype(np.uint8)
    true_binary = (true_mask > 0).astype(np.uint8)
    
    # IoU
    intersection = np.logical_and(pred_binary, true_binary).sum()
    union = np.logical_or(pred_binary, true_binary).sum()
    iou = intersection / union if union != 0 else 0
    
    # CPA (pixel-level accuracy)
    correct_pixels = (pred_binary == true_binary).sum()
    cpa = correct_pixels / true_binary.size
    return iou, cpa
  • IoU: measures overlap between the predicted and true regions; > 0.7 is usually considered good
  • CPA: reflects overall pixel classification accuracy, but is sensitive to class imbalance
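A hand-checkable toy case makes both metrics concrete: the prediction below covers the true 2x2 foreground square plus two spurious pixels, so IoU = 4/6 ≈ 0.667 and CPA = 14/16 = 0.875 (compute_metrics is the function from above):

```python
import numpy as np

def compute_metrics(pred_mask, true_mask):
    pred_binary = (pred_mask > 0).astype(np.uint8)
    true_binary = (true_mask > 0).astype(np.uint8)
    intersection = np.logical_and(pred_binary, true_binary).sum()
    union = np.logical_or(pred_binary, true_binary).sum()
    iou = intersection / union if union != 0 else 0
    cpa = (pred_binary == true_binary).sum() / true_binary.size
    return iou, cpa

true = np.zeros((4, 4), dtype=np.uint8)
true[:2, :2] = 1                # true foreground: 4 pixels
pred = true.copy()
pred[2, 0] = 1; pred[3, 0] = 1  # prediction adds 2 false-positive pixels

iou, cpa = compute_metrics(pred, true)
print(round(iou, 3))  # 0.667  (intersection 4 / union 6)
print(cpa)            # 0.875  (14 of 16 pixels classified correctly)
```

This toy case also shows why CPA alone can mislead: with 12 of 16 pixels being background, a prediction of all zeros would still score CPA = 0.75 while IoU dropped to 0.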

7. Visualizing the Results

Compute and print the metrics, and display the ground-truth mask next to the predicted segmentation for a direct visual comparison.

# Compute metrics
iou, cpa = compute_metrics(seg_map, mask)

# Print results
print(f"IoU: {iou:.4f}")
print(f"CPA (Correct Pixel Accuracy): {cpa:.4f}")

# ==== Visualization ====
plt.figure(figsize=(18, 6))
plt.subplot(1, 3, 1)
plt.title('Test Image')
plt.imshow(image)
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title('Original Mask')
plt.imshow(mask, cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title('Predicted Segmentation')
plt.imshow(seg_map, cmap='jet')
plt.axis('off')

plt.tight_layout()
plt.show()

8. Sample Output

IoU: 0.8190
CPA (Correct Pixel Accuracy): 0.9802

9. Complete Code

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Configure paths
data_dir = "dataset"
images_dir = os.path.join(data_dir, "images")
masks_dir = os.path.join(data_dir, "masks")
train_csv = os.path.join(data_dir, "train.csv")

model_cfg = r"configs\sam2.1\sam2.1_hiera_l.yaml"
checkpoint_path = r"checkpoints\sam2.1_hiera_large.pt"
finetuned_weights = r"weights\best.pt"  # weights saved after fine-tuning

# Load one test image and its mask
import pandas as pd, random
df = pd.read_csv(train_csv)
entry = df.sample(1).iloc[0]  # pick a random image

image_path = os.path.join(images_dir, entry["ImageId"])
mask_path = os.path.join(masks_dir, entry["MaskId"])

def read_image(image_path, mask_path):
    img = cv2.imread(image_path)[..., ::-1]
    mask = cv2.imread(mask_path, 0)
    r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
    img = cv2.resize(img, (int(img.shape[1]*r), int(img.shape[0]*r)))
    mask = cv2.resize(mask, (int(mask.shape[1]*r), int(mask.shape[0]*r)), interpolation=cv2.INTER_NEAREST)
    return img, mask

def get_points(mask, num_points=30):
    coords = np.argwhere(mask > 0)
    points = []
    for _ in range(num_points):
        yx = np.array(coords[np.random.randint(len(coords))])
        points.append([[yx[1], yx[0]]])
    return np.array(points)

image, mask = read_image(image_path, mask_path)
input_points = get_points(mask)

# Load the model and the fine-tuned weights
sam2_model = build_sam2(model_cfg, checkpoint_path, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.load_state_dict(torch.load(finetuned_weights))

# Inference
with torch.no_grad():
    predictor.set_image(image)
    masks, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=np.ones([input_points.shape[0], 1])
    )

# Post-process the results
np_masks = np.array(masks[:, 0])
np_scores = scores[:, 0]
sorted_masks = np_masks[np.argsort(np_scores)][::-1]

seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)

# Note: the loop variable must not be called `mask`, or it would shadow
# the ground-truth mask loaded above, which compute_metrics() needs below.
for i in range(sorted_masks.shape[0]):
    pred = sorted_masks[i]
    if (pred * occupancy_mask).sum() / (pred.sum() + 1e-6) > 0.15:
        continue
    pred_bool = pred.astype(bool)
    pred_bool[occupancy_mask] = False
    seg_map[pred_bool] = i + 1
    occupancy_mask[pred_bool] = True

# Compute IoU and CPA
def compute_metrics(pred_mask, true_mask):
    pred_binary = (pred_mask > 0).astype(np.uint8)
    true_binary = (true_mask > 0).astype(np.uint8)

    intersection = np.logical_and(pred_binary, true_binary).sum()
    union = np.logical_or(pred_binary, true_binary).sum()
    iou = intersection / union if union != 0 else 0

    correct_pixels = (pred_binary == true_binary).sum()
    total_pixels = pred_binary.size
    cpa = correct_pixels / total_pixels
    return iou, cpa

# Compute metrics
iou, cpa = compute_metrics(seg_map, mask)

# Print results
print(f"IoU: {iou:.4f}")
print(f"CPA (Correct Pixel Accuracy): {cpa:.4f}")

# Visualization
plt.figure(figsize=(18, 6))
plt.subplot(1, 3, 1)
plt.title('Test Image')
plt.imshow(image)
plt.axis('off')

plt.subplot(1, 3, 2)
plt.title('Original Mask')
plt.imshow(mask, cmap='gray')
plt.axis('off')

plt.subplot(1, 3, 3)
plt.title('Predicted Segmentation')
plt.imshow(seg_map, cmap='jet')
plt.axis('off')

plt.tight_layout()
plt.show()

II. Zero-shot Prediction Pipeline

When handling new images without annotations, the main differences in the pipeline are:

1. Changes to Prompt Point Generation

# Fully random sampling
def sample_random_points(image, num_points=30):
    h, w, _ = image.shape
    points = []
    for _ in range(num_points):
        x = random.randint(0, w - 1)
        y = random.randint(0, h - 1)
        points.append([[x, y]])
    return np.array(points)
  • Comparison of the two prompt-generation methods:

    • Validation stage
      np.argwhere(mask > 0) collects the coordinates of every foreground pixel, and the required number of points is sampled from them. This simulates precise human annotation: every prompt lands on the target object.

    • Prediction stage
      Random coordinates are drawn directly from the image dimensions (x = randint(0, w - 1), y = randint(0, h - 1)) with no prior information. This simulates the unknown scenes of real applications, but may produce many wasted prompts.
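The difference can be quantified: the expected fraction of uniformly random prompts that hit the object equals the object's area fraction, whereas mask-based sampling hits every time. A numpy-only sketch on a toy scene (fixed seed, hypothetical sizes):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy scene: a 50x50 object inside a 100x100 image (area fraction 0.25).
h, w = 100, 100
fg = np.zeros((h, w), dtype=bool)
fg[25:75, 25:75] = True

# Prediction stage: points drawn uniformly over the whole image.
xs = rng.integers(0, w, size=10_000)
ys = rng.integers(0, h, size=10_000)
hit_rate = fg[ys, xs].mean()
print(abs(hit_rate - 0.25) < 0.03)  # True: hit rate tracks the area fraction

# Validation stage: points sampled from the mask always hit the object.
coords = np.argwhere(fg)
picks = coords[rng.integers(0, len(coords), size=100)]
print(fg[picks[:, 0], picks[:, 1]].all())  # True
```

So for small objects, random prompting wastes most of its points, which is one reason zero-shot results need more prompts or qualitative review.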

2. Evaluation Without Labels

Without ground-truth masks, the segmentation results have to be assessed qualitatively by a human.

3. Prediction Demo

This workflow applies to predicting on images that have no mask labels.

4. Complete Code

import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Configure paths
images_dir = r"images/test2"
model_cfg = r"configs/sam2.1/sam2.1_hiera_l.yaml"
base_ckpt = r"checkpoints/sam2.1_hiera_large.pt"
finetuned_weights = r"weights/best.pt"

# Pick a random image
image_files = [f for f in os.listdir(images_dir) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
entry = random.choice(image_files)
image_path = os.path.join(images_dir, entry)

# Read and resize the image
def read_image(path):
    image = cv2.imread(path)[..., ::-1]
    r = np.min([1024 / image.shape[1], 1024 / image.shape[0]])
    image = cv2.resize(image, (int(image.shape[1]*r), int(image.shape[0]*r)))
    return image

image = read_image(image_path)

# Randomly sample prompt points
def sample_random_points(image, num_points=30):
    h, w, _ = image.shape
    points = []
    for _ in range(num_points):
        x = random.randint(0, w - 1)
        y = random.randint(0, h - 1)
        points.append([[x, y]])
    return np.array(points)

input_points = sample_random_points(image, num_points=30)
input_labels = np.ones((input_points.shape[0], 1))

# Build the model and load the fine-tuned weights
model = build_sam2(model_cfg, base_ckpt, device="cuda", apply_postprocessing=True)
model.load_state_dict(torch.load(finetuned_weights, map_location="cuda"))
predictor = SAM2ImagePredictor(model)

# Run prediction
with torch.no_grad():
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_labels
    )

# Post-process: combine the predicted masks into one segmentation map
np_masks = np.array(masks[:, 0])
np_scores = scores if scores.ndim == 1 else scores[:, 0]
sorted_masks = np_masks[np.argsort(np_scores)][::-1]

seg_map = np.zeros(sorted_masks[0].shape, dtype=np.uint8)
occupancy_mask = np.zeros_like(seg_map, dtype=bool)

for i in range(sorted_masks.shape[0]):
    mask = sorted_masks[i]
    if (mask * occupancy_mask).sum() / (mask.sum() + 1e-6) > 0.15:
        continue
    mask_bool = mask.astype(bool)
    mask_bool[occupancy_mask] = False
    seg_map[mask_bool] = i + 1
    occupancy_mask[mask_bool] = True

# Visualization
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(image)
plt.axis('off')

plt.subplot(1, 2, 2)
plt.title("Predicted Segmentation (random prompt)")
plt.imshow(seg_map, cmap='jet')
plt.axis('off')

plt.suptitle(f"Image: {entry}", fontsize=14)
plt.tight_layout()
plt.show()
