Source code for rastervision.core.data.raster_transformer.rgb_class_transformer

from typing import TYPE_CHECKING, List, Optional

import numpy as np

from rastervision.core.data.raster_transformer import (RasterTransformer,
                                                       ReclassTransformer)
from rastervision.core.data.utils import (color_to_triple, color_to_integer,
                                          rgb_to_int_array)

if TYPE_CHECKING:
    from rastervision.core.data.class_config import ClassConfig


class RGBClassTransformer(RasterTransformer):
    """Maps RGB values to class IDs. Can also do the reverse.

    Colors are looked up in the color-to-class-ID mapping of the given
    ``ClassConfig``. Colors that are not assigned to any class are mapped to
    the config's null class ID.
    """

    def __init__(self, class_config: 'ClassConfig'):
        """Constructor.

        Args:
            class_config (ClassConfig): Class definitions, including the color
                of each class. ``class_config.ensure_null_class()`` is called
                on it, so it may be modified in place (presumably a null class
                is added if missing — confirm in ClassConfig).
        """
        class_config.ensure_null_class()
        self.null_class_id = class_config.null_class_id
        color_to_class = class_config.get_color_to_class_id()
        # Integer-encoded RGB color -> class ID.
        self.rgb_int_to_class = {
            color_to_integer(col): class_id
            for col, class_id in color_to_class.items()
        }
        self.rgb_int_to_class_tf = ReclassTransformer(self.rgb_int_to_class)

        class_to_color_triple = {
            class_id: color_to_triple(col)
            for col, class_id in color_to_class.items()
        }
        # i-th row of this array is the color-triple of the i-th class.
        # NOTE(review): row index equals class ID only if class IDs are
        # exactly 0..N-1 with no gaps — confirm against ClassConfig.
        self.class_to_rgb_arr = np.array(
            [
                class_to_color_triple[c]
                for c in sorted(class_to_color_triple.keys())
            ],
            dtype=np.uint8)

    def transform(self,
                  chip: np.ndarray,
                  channel_order: Optional[List[int]] = None) -> np.ndarray:
        """Transform an RGB array to an array of class IDs.

        Only the RGB -> class ID direction is performed here; use
        ``class_to_rgb`` for the reverse mapping.

        Args:
            chip (np.ndarray): Numpy array of shape (H, W, 3).
            channel_order (Optional[List[int]], optional): List of indices of
                channels that were extracted from the raw imagery. Not used by
                this transformer. Defaults to None.

        Returns:
            np.ndarray: An array of class IDs.
        """
        return self.rgb_to_class(chip)

    def rgb_to_class(self, array_rgb: np.ndarray) -> np.ndarray:
        """Convert an RGB array to an array of class IDs.

        Pixels whose color is not assigned to any class are mapped to
        ``self.null_class_id``.

        Args:
            array_rgb (np.ndarray): RGB array, e.g. of shape (H, W, 3).

        Returns:
            np.ndarray: uint8 array of class IDs.
        """
        array_int = rgb_to_int_array(array_rgb)
        # Build the "known color" mask BEFORE reclassifying, for two reasons.
        # First, the reclass step may modify its input. Second, a class ID
        # written by it could coincide with the integer code of an unmapped
        # color. Without this mask, unmapped colors would keep their integer
        # code, and the uint8 cast below would truncate it into an arbitrary
        # class ID.
        known_colors = np.array(list(self.rgb_int_to_class.keys()))
        is_known = np.isin(array_int, known_colors)
        array_class_id = self.rgb_int_to_class_tf.transform(array_int)
        array_class_id = np.where(is_known, array_class_id,
                                  self.null_class_id)
        return array_class_id.astype(np.uint8)

    def class_to_rgb(self, class_labels: np.ndarray) -> np.ndarray:
        """Convert an array of class IDs to an RGB array.

        Args:
            class_labels (np.ndarray): Integer array of class IDs.

        Returns:
            np.ndarray: uint8 array of shape ``class_labels.shape + (3,)``
            holding the color-triple of each pixel's class.
        """
        return self.class_to_rgb_arr[class_labels]