Skip to content

API Reference

sam_utilities

SamGeo_apb(model_type='vit_h', automatic=True, device=None, checkpoint_dir=None, sam_kwargs=None, **kwargs)

Bases: SamGeo

Source code in apb_spatial_computer_vision/sam_utilities.py
102
103
104
105
106
107
108
109
110
111
112
113
114
def __init__(
    self,
    model_type="vit_h",
    automatic=True,
    device=None,
    checkpoint_dir=None,
    sam_kwargs=None,
    **kwargs,
):
    """Initialize the instance by delegating to the parent constructor.

    This method adds no state of its own: every argument is forwarded
    unchanged, in the same positional order, to ``super().__init__``.

    Args:
        model_type: Forwarded to the parent constructor. Defaults to "vit_h".
        automatic: Forwarded to the parent constructor. Defaults to True.
        device: Forwarded to the parent constructor. Defaults to None.
        checkpoint_dir: Forwarded to the parent constructor. Defaults to None.
        sam_kwargs: Forwarded to the parent constructor. Defaults to None.
        **kwargs: Extra keyword arguments forwarded to the parent constructor.
    """
    super().__init__(
        model_type, automatic, device, checkpoint_dir, sam_kwargs, **kwargs
    )

predict(point_coords=None, point_labels=None, boxes=None, point_crs=None, mask_input=None, multimask_output=True, return_logits=False, output=None, index=None, mask_multiplier=255, dtype='float32', return_results=False, **kwargs)

Predict masks for the given input prompts, using the currently set image.

Parameters:

Name Type Description Default
point_coords str | dict | list | ndarray

A Nx2 array of point prompts to the model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.

None
point_labels list | int | ndarray

A length N array of labels for the point prompts. 1 indicates a foreground point and 0 indicates a background point.

None
point_crs str

The coordinate reference system (CRS) of the point prompts.

None
boxes list | ndarray

A length 4 array given a box prompt to the model, in XYXY format.

None
mask_input ndarray

A low resolution mask input to the model, typically coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.

None
multimask_output bool

If true, the model will return three masks. For ambiguous input prompts (such as a single click), this will often produce better masks than a single prediction. If only a single mask is needed, the model's predicted quality score can be used to select the best mask. For non-ambiguous prompts, such as multiple input prompts, multimask_output=False can give better results.

True
return_logits bool

If true, returns un-thresholded masks logits instead of a binary mask.

False
output str

The path to the output image. Defaults to None.

None
index int

The index of the mask to save. Defaults to None, which will save the mask with the highest score.

None
mask_multiplier int

The mask multiplier for the output mask, which is usually a binary mask [0, 1].

255
dtype dtype

The data type of the output image. Defaults to np.float32.

'float32'
return_results bool

Whether to return the predicted masks, scores, and logits. Defaults to False.

False
Source code in apb_spatial_computer_vision/sam_utilities.py
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
def predict(
    self,
    point_coords=None,
    point_labels=None,
    boxes=None,
    point_crs=None,
    mask_input=None,
    multimask_output=True,
    return_logits=False,
    output=None,
    index=None,
    mask_multiplier=255,
    dtype="float32",
    return_results=False,
    **kwargs,
):
    """Predict masks for the given input prompts, using the currently set image.

    The resulting masks, scores and logits are always stored on the instance
    as ``self.masks``, ``self.scores`` and ``self.logits``; the prepared box
    prompt is stored as ``self.boxes``.

    Args:
        point_coords (str | dict | list | np.ndarray, optional): A Nx2 array of point prompts to the
            model. Each point is in (X,Y) in pixels. It can be a path to a vector file, a GeoJSON
            dictionary, a list of coordinates [lon, lat], or a numpy array. Defaults to None.
            If the instance has a ``point_coords`` attribute, that attribute takes
            precedence over this argument.
        point_labels (list | int | np.ndarray, optional): A length N array of labels for the
            point prompts. 1 indicates a foreground point and 0 indicates a background point.
            When omitted, every point is labelled 1; a single int (or one-element list) is
            repeated for every point. If the instance has a ``point_labels`` attribute, that
            attribute takes precedence over this argument.
        boxes (str | dict | list | np.ndarray, optional): A length 4 array given a box prompt to the
            model, in XYXY format, or a list of such boxes. A path to a vector file or a
            GeoJSON dictionary is converted to the bounds of its geometries.
        point_crs (str, optional): The coordinate reference system (CRS) of the point prompts.
            When given, point and box prompts are converted to image coordinates.
        mask_input (np.ndarray, optional): A low resolution mask input to the model, typically
            coming from a previous prediction iteration. Has form 1xHxW, where for SAM, H=W=256.
        multimask_output (bool, optional): If true, the model will return three masks.
            For ambiguous input prompts (such as a single click), this will often
            produce better masks than a single prediction. If only a single
            mask is needed, the model's predicted quality score can be used
            to select the best mask. For non-ambiguous prompts, such as multiple
            input prompts, multimask_output=False can give better results.
        return_logits (bool, optional): If true, returns un-thresholded masks logits
            instead of a binary mask.
        output (str, optional): The path to the output image. Defaults to None.
        index (int, optional): The index of the mask to save. Defaults to None,
            which will save the mask with the highest score.
        mask_multiplier (int, optional): The mask multiplier for the output mask, which is usually a binary mask [0, 1].
        dtype (np.dtype, optional): The data type of the output image. Defaults to np.float32.
        return_results (bool, optional): Whether to return the predicted masks, scores, and logits. Defaults to False.
        **kwargs: Extra keyword arguments forwarded to the routine that writes ``output``.

    Returns:
        tuple | None: ``(masks, scores, logits)`` when ``return_results`` is True,
        otherwise None.

    Raises:
        ValueError: If ``point_labels`` is a list whose length does not match
            ``point_coords`` and the mismatch cannot be reconciled.
    """
    out_of_bounds = []

    # Box prompts supplied as a vector file or GeoJSON dict are reduced to the
    # bounds of their geometries.
    if isinstance(boxes, str):
        gdf = gpd.read_file(boxes)
        if gdf.crs is not None:
            gdf = gdf.to_crs("epsg:4326")
        boxes = gdf.geometry.bounds.values.tolist()
    elif isinstance(boxes, dict):
        import json

        geojson = json.dumps(boxes)
        gdf = gpd.read_file(geojson, driver="GeoJSON")
        boxes = gdf.geometry.bounds.values.tolist()

    # Point prompts: file path -> GeoJSON -> coordinate list.
    if isinstance(point_coords, str):
        point_coords = vector_to_geojson(point_coords)

    if isinstance(point_coords, dict):
        point_coords = geojson_to_coords(point_coords)

    # Prompts already stored on the instance override the arguments.
    if hasattr(self, "point_coords"):
        point_coords = self.point_coords

    if hasattr(self, "point_labels"):
        point_labels = self.point_labels

    if (point_crs is not None) and (point_coords is not None):
        point_coords, out_of_bounds = coords_to_xy(
            self.source, point_coords, point_crs, return_out_of_bounds=True
        )

    if isinstance(point_coords, list):
        point_coords = np.array(point_coords)

    if point_coords is not None:
        if point_labels is None:
            point_labels = [1] * len(point_coords)
        elif isinstance(point_labels, int):
            point_labels = [point_labels] * len(point_coords)

    if isinstance(point_labels, list):
        if len(point_labels) != len(point_coords):
            if len(point_labels) == 1:
                point_labels = point_labels * len(point_coords)
            elif len(out_of_bounds) > 0:
                # Drop the labels of points discarded during the CRS conversion
                # so labels and coordinates line up again.
                print(f"Removing {len(out_of_bounds)} out-of-bound points.")
                point_labels = [
                    label
                    for i, label in enumerate(point_labels)
                    if i not in out_of_bounds
                ]
            else:
                raise ValueError(
                    "The length of point_labels must be equal to the length of point_coords."
                )
        point_labels = np.array(point_labels)

    predictor = self.predictor

    input_boxes = None
    if isinstance(boxes, list) and (point_crs is not None):
        coords = bbox_to_xy(self.source, boxes, point_crs)
        input_boxes = np.array(coords)
        if isinstance(coords[0], int):
            # A flat list of ints is one box; add the leading batch axis.
            input_boxes = input_boxes[None, :]
        else:
            input_boxes = torch.tensor(input_boxes, device=self.device)

            input_boxes = predictor.transform.apply_boxes_torch(
                input_boxes, self.image.shape[:2]
            )
    elif isinstance(boxes, list) and (point_crs is None):
        input_boxes = np.array(boxes)
        if isinstance(boxes[0], int):
            input_boxes = input_boxes[None, :]

    self.boxes = input_boxes

    # A flat list of four numbers is a single box whether the numbers are
    # ints or floats. The int case is already reshaped as a single box above,
    # so it must take the single-prompt path here as well.
    single_prompt = (
        boxes is None
        or (len(boxes) == 1)
        or (len(boxes) == 4 and isinstance(boxes[0], (int, float)))
    )

    if single_prompt:
        if isinstance(boxes, list) and isinstance(boxes[0], list):
            boxes = boxes[0]
        if isinstance(input_boxes, torch.Tensor):
            # predictor.predict is given a NumPy array here, so tensors are
            # moved back to the CPU first.
            box_prompt = np.array(input_boxes.cpu())
        else:
            box_prompt = input_boxes
        masks, scores, logits = predictor.predict(
            point_coords,
            point_labels,
            box_prompt,
            mask_input,
            multimask_output,
            return_logits,
        )
    else:
        # NOTE(review): this batched path does not forward mask_input,
        # multimask_output or return_logits; multimask_output is fixed to True.
        masks, scores, logits = predictor.predict_torch(
            point_coords=point_coords,
            point_labels=point_labels,
            boxes=input_boxes,
            multimask_output=True,
        )

    self.masks = masks
    self.scores = scores
    self.logits = logits

    if output is not None:
        if boxes is None or (not isinstance(boxes[0], list)):
            self.save_prediction(output, index, mask_multiplier, dtype, **kwargs)
        else:
            self.tensor_to_numpy(
                index, output, mask_multiplier, dtype, save_args=kwargs
            )

    if return_results:
        return masks, scores, logits

raster_to_vector(source, output=None, simplify_tolerance=None, dst_crs=None, **kwargs)

Vectorize a raster dataset.

Parameters:

Name Type Description Default
source str

The path to the tiff file.

required
output str

The path to the vector file.

None
simplify_tolerance float

The maximum allowed geometry displacement. The higher this value, the smaller the number of vertices in the resulting geometry.

None
Source code in apb_spatial_computer_vision/sam_utilities.py
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
def raster_to_vector(source, output=None, simplify_tolerance=None, dst_crs=None, **kwargs):
    """Vectorize a raster dataset.

    Every connected region of non-zero pixels becomes one feature carrying its
    pixel value in a ``value`` property.

    Args:
        source (str): The path to the tiff file.
        output (str, optional): The path to the vector file. When None, nothing
            is written to disk. Defaults to None.
        simplify_tolerance (float, optional): The maximum allowed geometry displacement.
            The higher this value, the smaller the number of vertices in the resulting geometry.
        dst_crs (optional): Target CRS passed to ``GeoDataFrame.to_crs``. When None,
            the geometries keep the CRS of the source raster. Defaults to None.
        **kwargs: Accepted for interface compatibility; currently unused.

    Returns:
        The ``wkt`` attribute of the geometry produced by ``gpd.tools.collect``
        over all resulting geometries.
    """
    from rasterio import features

    with rasterio.open(source) as src:
        band = src.read()
        # Read the CRS while the dataset is still open; the ``with`` block
        # closes it on exit, so it must not be queried afterwards.
        src_crs = src.crs

        mask = band != 0
        # Consume the shapes iterator here so that all work tied to the
        # dataset is finished before the ``with`` block exits.
        fc = [
            {"geometry": shapely.geometry.shape(shape), "properties": {"value": value}}
            for shape, value in features.shapes(
                band, mask=mask, transform=src.transform
            )
        ]

    if simplify_tolerance is not None:
        for feature in fc:
            feature["geometry"] = feature["geometry"].simplify(
                tolerance=simplify_tolerance
            )

    gdf = gpd.GeoDataFrame.from_features(fc)
    if src_crs is not None:
        gdf.set_crs(crs=src_crs, inplace=True)

    if dst_crs is not None:
        gdf = gdf.to_crs(dst_crs)

    if output is not None:
        gdf.to_file(output)

    collected = gpd.tools.collect(gdf.geometry)
    return collected.wkt

set_image(image, image_format='RGB')

Set the input image as a numpy array.

Parameters:

Name Type Description Default
image ndarray

The input image as a numpy array.

required
image_format str

The image format, can be RGB or BGR. Defaults to "RGB".

'RGB'
Source code in apb_spatial_computer_vision/sam_utilities.py
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
def set_image(self, image, image_format="RGB"):
    """Set the input image from a file path, URL, or numpy array.

    The image is stored on ``self.image`` and then handed to
    ``self.predictor.set_image``. For path/URL input the path is also stored
    on ``self.source``.

    Args:
        image (str | np.ndarray): A local path, an ``http`` URL (downloaded
            first), or the image as a numpy array.
        image_format (str, optional): The image format, can be RGB or BGR. Defaults to "RGB".

    Raises:
        ValueError: If the path does not exist, or if ``image`` is neither a
            string nor a numpy array.
    """
    if isinstance(image, str):
        if image.startswith("http"):
            image = download_file(image)

        if not os.path.exists(image):
            raise ValueError(f"Input path {image} does not exist.")

        self.source = image
        bands = Ortophoto(image).raster.ReadAsArray()
        # Keep the first three bands and move the band axis last.
        # NOTE(review): assumes a band-first array with at least three bands.
        pixels = np.transpose(bands[:3, :, :], (1, 2, 0))
        # NOTE(review): this swaps the first and third channels; confirm the
        # raster's band order really is BGR.
        self.image = cv2.cvtColor(pixels, cv2.COLOR_BGR2RGB)
    elif isinstance(image, np.ndarray):
        # Store the array so the predictor call below (and later code reading
        # self.image) sees this image rather than a missing or stale one.
        self.image = image
    else:
        raise ValueError("Input image must be either a path or a numpy array.")

    try:
        self.predictor.set_image(self.image, image_format=image_format)
    except torch.OutOfMemoryError:
        # Stay best-effort (do not raise), but make the failure visible:
        # the predictor call did not complete.
        import warnings

        warnings.warn(
            "Out of memory while setting the image on the predictor; "
            "the predictor call did not complete.",
            RuntimeWarning,
            stacklevel=2,
        )