Source code for rastervision.core.data.raster_transformer.reclass_transformer
from typing import TYPE_CHECKING, Dict, List, Optional
from rastervision.core.data.raster_transformer import RasterTransformer
if TYPE_CHECKING:
import numpy as np
class ReclassTransformer(RasterTransformer):
    """Raster transformer that remaps class IDs in a label raster."""

    def __init__(self, mapping: Dict[int, int]):
        """Create a ReclassTransformer.

        Args:
            mapping: Dictionary whose keys are the class IDs to look for and
                whose values are the IDs that replace them.
        """
        self.mapping = mapping

    def transform(self,
                  chip: 'np.ndarray',
                  channel_order: Optional[List[int]] = None):
        """Reclassify a label chip according to ``self.mapping``.

        Every source value is located before any replacement is written, so
        an entry's output is never re-mapped by another entry (for example,
        ``{1: 2, 2: 3}`` turns 1 into 2, not into 3). The chip is modified in
        place and the same object is returned.

        Args:
            chip: ndarray of shape [height, width, channels]. It is assumed
                to already have the channel_order applied to it if
                channel_order is set, i.e. channels should be equal to
                len(channel_order).
            channel_order: list of indices of channels that were extracted
                from the raw imagery. Not used by this transformer.

        Returns:
            The input chip, with mapped values replaced; shape
            [height, width, channels].
        """
        # Take all the masks up front, against the untouched chip, so that
        # the writes below cannot influence which pixels later entries match.
        replacements = [(chip == source_id, target_id)
                        for source_id, target_id in self.mapping.items()]

        for where, target_id in replacements:
            chip[where] = target_id

        return chip