1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144
| import os import numpy as np import matplotlib.pyplot as plt import cv2 import glob import json
# Module-level list of [x, y] click positions. on_click appends to it on a
# left mouse click; get_mask reads it as positive point prompts when
# click_coords=True (and rebinds it to an empty list otherwise).
coords = []
def show_mask(mask, ax, random_color=False):
    """Overlay *mask* on *ax* as a semi-transparent RGBA image.

    Args:
        mask: Array whose last two dimensions are (height, width). Its values
            scale the overlay colour, so a boolean mask paints its True pixels.
        ax: Axes-like object exposing ``imshow``.
        random_color: If True, use a random RGB colour instead of the default
            blue (30, 144, 255). The alpha is 0.6 in both cases.
    """
    if not random_color:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast the (H, W, 1) mask against the (1, 1, 4) colour -> (H, W, 4).
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Scatter positive (label 1) and negative (label 0) prompt points on *ax*.

    Positive points are drawn as green stars and negative points as red stars,
    each with a white outline.

    Args:
        coords: (N, 2) array of x, y positions.
        labels: Length-N array of 0/1 labels aligned with *coords*.
        ax: Axes-like object exposing ``scatter``.
        marker_size: Marker area passed as ``s`` to ``scatter``.
    """
    shared_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    # Positives first, then negatives.
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        chosen = coords[labels == wanted_label]
        ax.scatter(chosen[:, 0], chosen[:, 1], color=colour, **shared_style)
def show_box(box, ax):
    """Draw *box* (x0, y0, x1, y1) on *ax* as an unfilled green rectangle."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)
def on_click(event):
    """Matplotlib ``button_press_event`` callback that records left-clicks.

    Left button (1): print the click position and append ``[x, y]`` (data
    coordinates) to the module-level ``coords`` list. Right button (3): only
    print a message. Left-clicks outside the axes, where matplotlib reports
    ``xdata``/``ydata`` as None, are ignored. Previously they crashed the
    ``:.2f`` formatting.
    """
    # coords is only mutated in place, so no ``global`` declaration is needed.
    if event.button == 1:
        x, y = event.xdata, event.ydata
        if x is None or y is None:
            # Click landed outside the image axes; there is no point to record.
            return
        # Message reads "left mouse click".
        print(f"鼠标左键点击:x={x:.2f}, y={y:.2f}")
        coords.append([x, y])
    elif event.button == 3:
        # Message reads "right mouse click".
        print("鼠标右键点击")
# Loaded SAM predictors keyed by (checkpoint, model_type, device), so the
# large model is built once per process instead of once per image.
_PREDICTOR_CACHE = {}


def _get_predictor(sam_checkpoint, model_type, device):
    """Return a cached ``SamPredictor`` for the given checkpoint/model/device."""
    key = (sam_checkpoint, model_type, device)
    predictor = _PREDICTOR_CACHE.get(key)
    if predictor is None:
        # Imported lazily so importing this module does not require
        # segment_anything to be installed.
        from segment_anything import SamPredictor, sam_model_registry
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        sam.to(device=device)
        predictor = SamPredictor(sam)
        _PREDICTOR_CACHE[key] = predictor
    return predictor


def _collect_clicks(image):
    """Show *image* with on_click attached so left-clicks are appended to coords."""
    fig, ax = plt.subplots()
    fig.set_size_inches(30, 20)
    ax.imshow(image)
    fig.canvas.mpl_connect('button_press_event', on_click)
    plt.show()


def _choose_mask_interactively(image, masks, scores):
    """Plot each candidate mask with its own score and read the chosen index from stdin."""
    plt.figure(figsize=(60, 20))
    n_masks = len(masks)
    for k, (candidate, score) in enumerate(zip(masks, scores)):
        plt.subplot(1, n_masks, k + 1)
        plt.imshow(image)
        show_mask(candidate, plt.gca())
        plt.title(f"Mask {k}, Score: {score:.3f}", fontsize=18)
    plt.show()
    return int(input())


def get_mask(image, mask_id=1, click_coords=False, choose_mask=False, box=None,
             sam_checkpoint="sam_vit_h_4b8939.pth", model_type="vit_h", device="cuda"):
    """Segment *image* with SAM and return one candidate mask as a 0/255 array.

    Args:
        image: BGR image, e.g. as returned by ``cv2.imread``. It is converted
            to RGB before being passed to SAM.
        mask_id: Index of the SAM candidate mask to return. Ignored when
            ``choose_mask`` is True.
        click_coords: If True, display the image and use the left-clicks
            recorded by ``on_click`` as positive point prompts. The shared
            ``coords`` list is cleared afterwards. If False, ``coords`` is
            reset and only the box prompt is used.
        choose_mask: If True, plot all candidate masks and read the index to
            return from stdin.
        box: Optional box prompt of 4 values (x0, y0, x1, y1). Presumably in
            pixel coordinates of *image* (TODO confirm against the box source).
        sam_checkpoint, model_type, device: SAM model selection. The defaults
            match the previously hard-coded values.

    Returns:
        Integer array with 255 inside the mask and 0 elsewhere. The mask is
        replicated into a trailing 3-channel axis, so the shape is (H, W, 3)
        assuming SAM returns (H, W) masks.

    Raises:
        ValueError: If there is neither a box nor any clicked point to
            prompt SAM with.
    """
    global coords
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    if click_coords:
        _collect_clicks(image)
    else:
        # Drop any stale clicks so only the box prompt is used.
        coords = []

    if coords:
        input_point = np.array(coords)
        input_label = np.ones(len(coords), dtype=int)  # every click is a positive prompt
    else:
        input_point = None
        input_label = None
    # SAM accepts a (1, 4) box. A missing box must stay None rather than be
    # indexed, which previously crashed the point-only use case.
    input_box = None if box is None else np.asarray(box)[None, :]
    if input_point is None and input_box is None:
        raise ValueError("get_mask needs a box or at least one clicked point")

    predictor = _get_predictor(sam_checkpoint, model_type, device)
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        multimask_output=True,
    )

    if choose_mask:
        mask_id = _choose_mask_interactively(image, masks, scores)

    mask = np.tile(np.expand_dims(masks[mask_id], axis=-1), 3)
    mask_data = np.where(mask, 255, 0)
    if click_coords:
        # Reset the shared click buffer for the next image.
        coords.clear()
    return mask_data
if __name__ == '__main__':
    obj = 'shirt'
    color_path = f'/Data/VolumeDeformData/{obj}/data/'
    mask_path = f'/Data/VolumeDeformData/{obj}/mask/'
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(mask_path, exist_ok=True)

    img_paths = []
    for extension in ["jpg", "png", "jpeg"]:
        img_paths += glob.glob(os.path.join(color_path, "*.{}".format(extension)))

    # Load the box file up front so it is closed before the slow SAM loop.
    json_path = 'GroundingDINO-main/box.json'
    with open(json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # NOTE(review): halving assumes the data directory holds exactly two
    # images per frame (a .color.png plus presumably a depth image). Confirm.
    for i in range(len(img_paths) // 2):
        frame = str(i).zfill(6)
        img_name = f'frame-{frame}.color.png'
        img = cv2.imread(os.path.join(color_path, img_name))
        if img is None:
            # cv2.imread returns None for missing/unreadable files; skip the
            # frame instead of crashing inside cv2.cvtColor.
            print(f'skipping unreadable image: {img_name}')
            continue
        box_key = f'{obj}_{frame}'
        # Boxes are stored as bracketed, comma-separated strings, presumably
        # "[x0, y0, x1, y1]". The outer brackets are stripped before splitting.
        box = np.array(list(map(float, data[box_key][1:-1].split(','))))
        mask = get_mask(img, mask_id=2, click_coords=False, choose_mask=False, box=box)
        cv2.imwrite(os.path.join(mask_path, frame + '.png'), mask)
        print(img_name)
|