Source code for rastervision.core.data.raster_transformer.rgb_class_transformer
from typing import TYPE_CHECKING, Literal
from pydantic.types import PositiveInt as PosInt
import numpy as np
from rastervision.core.data.raster_transformer import (RasterTransformer,
ReclassTransformer)
from rastervision.core.data.utils import (color_to_triple, color_to_integer,
rgb_to_int_array)
if TYPE_CHECKING:
from rastervision.core.data.class_config import ClassConfig
class RGBClassTransformer(RasterTransformer):
    """Maps RGB values to class IDs. Can also do the reverse.

    ``transform()`` / ``rgb_to_class()`` convert RGB colors to class IDs and
    ``class_to_rgb()`` converts class IDs back to RGB colors.
    """

    def __init__(self, class_config: 'ClassConfig'):
        """Build the color -> class ID and class ID -> color lookup tables.

        Args:
            class_config: Class configuration that supplies the color of each
                class. ``ensure_null_class()`` is called on it, which may
                modify it (presumably by adding a null class if it is
                missing -- confirm against ``ClassConfig``).
        """
        class_config.ensure_null_class()
        # Class ID assigned to pixels whose color matches no configured class.
        self.null_class_id = class_config.null_class_id

        color_to_class = class_config.get_color_to_class_id()

        # Keys are the integer encoding of each color, as produced by
        # color_to_integer(); values are the corresponding class IDs.
        self.rgb_int_to_class = {
            color_to_integer(col): class_id
            for col, class_id in color_to_class.items()
        }
        self.rgb_int_to_class_tf = ReclassTransformer(self.rgb_int_to_class)

        # Colors given as strings are converted to triples; anything else is
        # assumed to already be a triple and is used as-is.
        class_to_color_triple = {
            class_id: color_to_triple(col) if isinstance(col, str) else col
            for col, class_id in color_to_class.items()
        }

        # i-th row of this array is the color-triple of the i-th class.
        # Rows are placed by class ID rather than by rank among the sorted
        # IDs, so the table stays aligned with the IDs even if they have
        # gaps. IDs that have no color are filled with black.
        num_rows = max(class_to_color_triple, default=-1) + 1
        self.class_to_rgb_arr = np.array(
            [
                class_to_color_triple.get(class_id, (0, 0, 0))
                for class_id in range(num_rows)
            ],
            dtype=np.uint8)

    def transform(self, chip: np.ndarray) -> np.ndarray:
        """Transform RGB array to array of class IDs.

        This is the RGB -> class ID direction only; use ``class_to_rgb()``
        for the reverse mapping.

        Args:
            chip: Numpy array of shape (H, W, 3).

        Returns:
            A uint8 array of class IDs. Pixels whose color does not match
            any configured class are assigned the null class ID.
        """
        return self.rgb_to_class(chip)

    def rgb_to_class(self, array_rgb: np.ndarray) -> np.ndarray:
        """Convert an RGB array to a uint8 array of class IDs.

        Args:
            array_rgb: RGB array, passed to ``rgb_to_int_array()`` to encode
                each pixel's color as a single integer.

        Returns:
            A uint8 array of class IDs. Pixels whose color is not one of the
            configured class colors are assigned ``self.null_class_id``.
        """
        array_int = rgb_to_int_array(array_rgb)

        # Record which pixels carry a known color *before* reclassifying, in
        # case the reclass step overwrites array_int in place.
        known_colors = list(self.rgb_int_to_class)
        is_known = np.isin(array_int, known_colors)

        array_class_id = self.rgb_int_to_class_tf.transform(array_int)
        # astype() returns a new array, so the assignment below cannot leak
        # into array_int or any buffer owned by the reclass transformer.
        array_class_id = array_class_id.astype(np.uint8)

        # The reclass mapping only covers configured colors. Without this
        # step, any other pixel would end up with whatever the uint8 cast
        # makes of its leftover value instead of the null class.
        # reshape() tolerates the reclass step adding/removing singleton
        # dimensions.
        unknown = ~is_known.reshape(array_class_id.shape)
        array_class_id[unknown] = self.null_class_id
        return array_class_id

    def class_to_rgb(self, class_labels: np.ndarray) -> np.ndarray:
        """Convert an array of class IDs to the corresponding colors.

        Args:
            class_labels: Integer array of class IDs. Each ID is used as a
                row index into ``self.class_to_rgb_arr``.

        Returns:
            A uint8 array with the shape of ``class_labels`` plus one
            trailing axis holding the color of each label.
        """
        return self.class_to_rgb_arr[class_labels]

    def get_out_channels(self, in_channels: PosInt) -> Literal[1]:
        """Return the number of output channels, which is always 1.

        Args:
            in_channels: Number of channels in the input. Must be 3.

        Raises:
            ValueError: If ``in_channels`` is not 3.
        """
        if in_channels != 3:
            raise ValueError(
                'RGBClassTransformer only accepts 3-channel inputs.')
        return 1