SAM2模型微调训练、验证和预测(Part2)
本文详细介绍了SAM2模型的验证与预测流程。在模型验证部分,文章阐述了如何使用带标注的测试集评估模型性能,包括数据准备、提示点生成、模型推理、结果后处理等关键步骤和IoU和CPA两种评估指标的计算方法。在零样本预测部分,说明了如何对无标注新图像进行预测,包括随机提示点生成和人工定性评估方法。文章提供了完整的代码实现和可视化示例,展示了模型在验证和预测场景下的应用效果。这些流程为SAM2模型的实际应
在Part1中,我们详细介绍了SAM2模型的微调训练流程。
SAM2模型微调训练、验证和预测(Part1)_sam2微调-CSDN博客
本文将重点讲解模型验证和预测的实现方法。
包含以下内容:
1. 模型验证:在有掩码标签的情况下评估模型性能
2. 零样本预测:对全新图像进行无监督预测
一、模型验证流程
验证阶段使用带标注的测试集数据,通过计算IoU(交并比)和CPA(正确像素准确率)量化模型性能。以下是关键步骤解析:
1. 数据准备与加载
# 配置路径
data_dir = "dataset"
images_dir = os.path.join(data_dir, "images")
masks_dir = os.path.join(data_dir, "masks")
train_csv = os.path.join(data_dir, "train.csv")
model_cfg = r"configs\sam2.1\sam2.1_hiera_l.yaml"
checkpoint_path = r"checkpoints\sam2.1_hiera_large.pt"
finetuned_weights = r"weights\best.pt" # 训练后的模型文件名
- 数据集要求与训练集一致,包含
images/
、masks/
目录和train.csv
索引文件 - 读取sam2必要配置文件和模型训练保存的模型文件
2. 读取图像并调整分辨率
def read_image(image_path, mask_path):
img = cv2.imread(image_path)[..., ::-1]
mask = cv2.imread(mask_path, 0)
r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
img = cv2.resize(img, (int(img.shape[1]*r), int(img.shape[0]*r)))
mask = cv2.resize(mask, (int(mask.shape[1]*r), int(mask.shape[0]*r)), interpolation=cv2.INTER_NEAREST)
return img, mask
- 掩码必须使用
INTER_NEAREST
插值,避免插值产生无效类别值 - 图像与掩码需同步缩放,确保空间对齐
3. 提示点生成
def get_points(mask, num_points=30):
coords = np.argwhere(mask > 0) # 获取所有前景像素坐标
points = []
for _ in range(num_points):
yx = coords[np.random.randint(len(coords))] # 随机选择前景点
points.append([[yx[1], yx[0]]]) # 转换为(x,y)格式
return np.array(points)
- 真实掩码区域采样更符合实际应用场景,但是必须依赖掩码文件。
- 复杂场景可以增加点数(如50-100),简单场景可减少(10-20)
4. 模型加载与推理
# 加载微调后的模型
sam2_model = build_sam2(model_cfg, checkpoint_path, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.load_state_dict(torch.load(finetuned_weights))
# 执行预测
with torch.no_grad():
predictor.set_image(image) # 提取图像特征
masks, scores, _ = predictor.predict(
point_coords=input_points,
point_labels=np.ones([input_points.shape[0], 1]) # 所有点设为前景提示
)
point_labels
全设为1表示这些点均属于目标物体
5. 结果后处理
# 掩码排序与去重
np_masks = np.array(masks[:, 0]) # 取第一个输出通道
sorted_masks = np_masks[np.argsort(scores[:, 0])][::-1] # 按置信度降序排列
seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)
for i in range(sorted_masks.shape[0]):
mask = sorted_masks[i]
# 过滤与已存在区域重叠超过15%的掩码
if (mask * occupancy_mask).sum() / mask.sum() > 0.15:
continue
seg_map[mask & ~occupancy_mask] = i + 1 # 分配唯一标签
occupancy_mask |= mask # 更新已占用区域
- 通过
occupancy_mask
避免预测掩码的重叠 - 不同数值代表不同实例,适用于实例分割任务
6. 评估指标计算
def compute_metrics(pred_mask, true_mask):
# 二值化处理
pred_binary = (pred_mask > 0).astype(np.uint8)
true_binary = (true_mask > 0).astype(np.uint8)
# 计算IoU
intersection = np.logical_and(pred_binary, true_binary).sum()
union = np.logical_or(pred_binary, true_binary).sum()
iou = intersection / union if union != 0 else 0
# 计算CPA(像素级准确率)
correct_pixels = (pred_binary == true_binary).sum()
cpa = correct_pixels / true_binary.size
return iou, cpa
- IoU:衡量分割区域重叠度,>0.7通常认为效果良好
- CPA:反映整体像素分类准确率,但对类别不平衡敏感
7. 结果可视化
计算并输出评价指标,同时输出训练集的掩码和预测结果,以进行直观对比。
# 计算指标
iou, cpa = compute_metrics(seg_map, mask)
# 打印结果
print(f"IoU: {iou:.4f}")
print(f"CPA (Correct Pixel Accuracy): {cpa:.4f}")
# ==== 可视化 ====
plt.figure(figsize=(18, 6))
plt.subplot(1, 3, 1)
plt.title('Test Image')
plt.imshow(image)
plt.axis('off')
plt.subplot(1, 3, 2)
plt.title('Original Mask')
plt.imshow(mask, cmap='gray')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.title('Predicted Segmentation')
plt.imshow(seg_map, cmap='jet')
plt.axis('off')
plt.tight_layout()
plt.show()
8. 输出结果演示
IoU: 0.8190
CPA (Correct Pixel Accuracy): 0.9802
9. 完整代码
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# 配置路径
data_dir = "dataset"
images_dir = os.path.join(data_dir, "images")
masks_dir = os.path.join(data_dir, "masks")
train_csv = os.path.join(data_dir, "train.csv")
model_cfg = r"configs\sam2.1\sam2.1_hiera_l.yaml"
checkpoint_path = r"checkpoints\sam2.1_hiera_large.pt"
finetuned_weights = r"weights\best.pt" # 训练后的模型文件名
# 加载一张测试图像和掩膜
import pandas as pd, random
df = pd.read_csv(train_csv)
entry = df.sample(1).iloc[0] # 随机选一张图片
image_path = os.path.join(images_dir, entry["ImageId"])
mask_path = os.path.join(masks_dir, entry["MaskId"])
def read_image(image_path, mask_path):
img = cv2.imread(image_path)[..., ::-1]
mask = cv2.imread(mask_path, 0)
r = np.min([1024 / img.shape[1], 1024 / img.shape[0]])
img = cv2.resize(img, (int(img.shape[1]*r), int(img.shape[0]*r)))
mask = cv2.resize(mask, (int(mask.shape[1]*r), int(mask.shape[0]*r)), interpolation=cv2.INTER_NEAREST)
return img, mask
def get_points(mask, num_points=30):
coords = np.argwhere(mask > 0)
points = []
for _ in range(num_points):
yx = np.array(coords[np.random.randint(len(coords))])
points.append([[yx[1], yx[0]]])
return np.array(points)
image, mask = read_image(image_path, mask_path)
input_points = get_points(mask)
# 加载模型并加载训练权重
sam2_model = build_sam2(model_cfg, checkpoint_path, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.load_state_dict(torch.load(finetuned_weights))
# 推理
with torch.no_grad():
predictor.set_image(image)
masks, scores, logits = predictor.predict(
point_coords=input_points,
point_labels=np.ones([input_points.shape[0], 1])
)
# 处理结果
np_masks = np.array(masks[:, 0])
np_scores = scores[:, 0]
sorted_masks = np_masks[np.argsort(np_scores)][::-1]
seg_map = np.zeros_like(sorted_masks[0], dtype=np.uint8)
occupancy_mask = np.zeros_like(sorted_masks[0], dtype=bool)
for i in range(sorted_masks.shape[0]):
mask = sorted_masks[i]
if (mask * occupancy_mask).sum() / mask.sum() > 0.15:
continue
mask_bool = mask.astype(bool)
mask_bool[occupancy_mask] = False
seg_map[mask_bool] = i + 1
occupancy_mask[mask_bool] = True
# 计算 IoU 和 CPA
def compute_metrics(pred_mask, true_mask):
pred_binary = (pred_mask > 0).astype(np.uint8)
true_binary = (true_mask > 0).astype(np.uint8)
intersection = np.logical_and(pred_binary, true_binary).sum()
union = np.logical_or(pred_binary, true_binary).sum()
iou = intersection / union if union != 0 else 0
correct_pixels = (pred_binary == true_binary).sum()
total_pixels = pred_binary.size
cpa = correct_pixels / total_pixels
return iou, cpa
# 计算指标
iou, cpa = compute_metrics(seg_map, mask)
# 打印结果
print(f"IoU: {iou:.4f}")
print(f"CPA (Correct Pixel Accuracy): {cpa:.4f}")
# 可视化
plt.figure(figsize=(18, 6))
plt.subplot(1, 3, 1)
plt.title('Test Image')
plt.imshow(image)
plt.axis('off')
plt.subplot(1, 3, 2)
plt.title('Original Mask')
plt.imshow(mask, cmap='gray')
plt.axis('off')
plt.subplot(1, 3, 3)
plt.title('Predicted Segmentation')
plt.imshow(seg_map, cmap='jet')
plt.axis('off')
plt.tight_layout()
plt.show()
二、零样本预测流程
当处理没有标注的新图像时,主要流程差异在于:
1. 提示点生成变化
# 完全随机采样
def sample_random_points(image, num_points=30):
h, w, _ = image.shape
points = []
for _ in range(num_points):
x = random.randint(0, w - 1)
y = random.randint(0, h - 1)
points.append([[x, y]])
return np.array(points)
-
生成提示点的方法的对比:
-
验证阶段
通过np.argwhere(mask>0)
获取掩码中所有前景像素坐标,再从中随机选取指定数量的点。这种方法模拟了人工精准标注的场景,所有提示点均位于目标物体上。 -
预测阶段
直接根据图像尺寸生成随机坐标(x=randint(0,w), y=randint(0,h))
,不依赖任何先验信息。这种方法模拟真实应用中的未知场景,但可能产生大量无效提示。
-
2. 无标签评估
由于缺少真实掩码,此时需要人工定性评估分割结果。
3. 预测结果演示
适用于对没有掩码标签的图像的预测。
4. 完整代码
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
import cv2
import torch
import numpy as np
import matplotlib.pyplot as plt
import random
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
#配置路径
images_dir = r"images/test2"
model_cfg = r"configs/sam2.1/sam2.1_hiera_l.yaml"
base_ckpt = r"checkpoints/sam2.1_hiera_large.pt"
finetuned_weights = r"weights/best.pt"
#随机选择一张图像
image_files = [f for f in os.listdir(images_dir) if f.lower().endswith(('.jpg', '.png', '.jpeg'))]
entry = random.choice(image_files)
image_path = os.path.join(images_dir, entry)
#图像读取和缩放
def read_image(path):
image = cv2.imread(path)[..., ::-1]
r = np.min([1024 / image.shape[1], 1024 / image.shape[0]])
image = cv2.resize(image, (int(image.shape[1]*r), int(image.shape[0]*r)))
return image
image = read_image(image_path)
#随机采样提示点
def sample_random_points(image, num_points=30):
h, w, _ = image.shape
points = []
for _ in range(num_points):
x = random.randint(0, w - 1)
y = random.randint(0, h - 1)
points.append([[x, y]])
return np.array(points)
input_points = sample_random_points(image, num_points=30)
input_labels = np.ones((input_points.shape[0], 1))
#加载模型并加载微调权重
model = build_sam2(model_cfg, base_ckpt, device="cuda", apply_postprocessing=True)
model.load_state_dict(torch.load(finetuned_weights, map_location="cuda"))
predictor = SAM2ImagePredictor(model)
#执行预测
with torch.no_grad():
predictor.set_image(image)
masks, scores, _ = predictor.predict(
point_coords=input_points,
point_labels=input_labels
)
#后处理多掩膜组合为一张图
np_masks = np.array(masks[:, 0])
np_scores = scores if scores.ndim == 1 else scores[:, 0]
sorted_masks = np_masks[np.argsort(np_scores)][::-1]
seg_map = np.zeros(sorted_masks[0].shape, dtype=np.uint8)
occupancy_mask = np.zeros_like(seg_map, dtype=bool)
for i in range(sorted_masks.shape[0]):
mask = sorted_masks[i]
if (mask * occupancy_mask).sum() / (mask.sum() + 1e-6) > 0.15:
continue
mask_bool = mask.astype(bool)
mask_bool[occupancy_mask] = False
seg_map[mask_bool] = i + 1
occupancy_mask[mask_bool] = True
#可视化
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.title("Original Image")
plt.imshow(image)
plt.axis('off')
plt.subplot(1, 2, 2)
plt.title("Predicted Segmentation (random prompt)")
plt.imshow(seg_map, cmap='jet')
plt.axis('off')
plt.suptitle(f"Image: {entry}", fontsize=14)
plt.tight_layout()
plt.show()

欢迎来到由智源人工智能研究院发起的Triton中文社区,这里是一个汇聚了AI开发者、数据科学家、机器学习爱好者以及业界专家的活力平台。我们致力于成为业内领先的Triton技术交流与应用分享的殿堂,为推动人工智能技术的普及与深化应用贡献力量。
更多推荐
所有评论(0)