# Copyright 2013 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Image processing utility functions."""


import copy
import io
import logging
import math
import matplotlib
from matplotlib import pylab
import matplotlib.pyplot
import os
import sys

import capture_request_utils
import colour
import error_util
import noise_model_constants
import numpy
from PIL import Image
from PIL import ImageCms


_CMAP_BLUE = ('black', 'blue', 'lightblue')
_CMAP_GREEN = ('black', 'green', 'lightgreen')
_CMAP_RED = ('black', 'red', 'lightcoral')
_CMAP_SIZE = 6  # 6 inches
_NUM_RAW_CHANNELS = 4  # R, Gr, Gb, B

LENS_SHADING_MAP_ON = 1

# The matrix is from JFIF spec
DEFAULT_YUV_TO_RGB_CCM = numpy.matrix([[1.000, 0.000, 1.402],
                                       [1.000, -0.344, -0.714],
                                       [1.000, 1.772, 0.000]])

DEFAULT_YUV_OFFSETS = numpy.array([0, 128, 128], dtype=numpy.uint8)
MAX_LUT_SIZE = 65536
DEFAULT_GAMMA_LUT = numpy.array([
    math.floor((MAX_LUT_SIZE-1) * math.pow(i/(MAX_LUT_SIZE-1), 1/2.2) + 0.5)
    for i in range(MAX_LUT_SIZE)])
RGB2GRAY_WEIGHTS = (0.299, 0.587, 0.114)
TEST_IMG_DIR = os.path.join(os.environ['CAMERA_ITS_TOP'], 'test_images')

# Expected adapted primaries in ICC profile per color space
EXPECTED_RX_P3 = 0.682
EXPECTED_RY_P3 = 0.319
EXPECTED_GX_P3 = 0.285
EXPECTED_GY_P3 = 0.675
EXPECTED_BX_P3 = 0.156
EXPECTED_BY_P3 = 0.066

EXPECTED_RX_SRGB = 0.648
EXPECTED_RY_SRGB = 0.331
EXPECTED_GX_SRGB = 0.321
EXPECTED_GY_SRGB = 0.598
EXPECTED_BX_SRGB = 0.156
EXPECTED_BY_SRGB = 0.066

# Chosen empirically - tolerance for the point in triangle test for colorspace
# chromaticities
COLORSPACE_TRIANGLE_AREA_TOL = 0.00028


def plot_lsc_maps(lsc_maps, plot_name, test_name_with_log_path):
  """Plot the lens shading correction maps.

  Args:
    lsc_maps: 4D np array; r, gr, gb, b lens shading correction maps.
    plot_name: str; identifier for maps ('full_scale' or 'metadata').
    test_name_with_log_path: str; test name with log_path location.

  Returns:
    None, but generates and saves plots.
  """
  aspect_ratio = lsc_maps[:, :, 0].shape[1] / lsc_maps[:, :, 0].shape[0]
  plot_w = 1 + aspect_ratio * _CMAP_SIZE  # add 1 for heatmap legend
  matplotlib.pyplot.figure(plot_name, figsize=(plot_w, _CMAP_SIZE))
  pylab.suptitle(plot_name)

  pylab.subplot(2, 2, 1)  # 2x2 top left
  pylab.title('R')
  cmap = matplotlib.colors.LinearSegmentedColormap.from_list('', _CMAP_RED)
  matplotlib.pyplot.pcolormesh(lsc_maps[:, :, 0], cmap=cmap)
  matplotlib.pyplot.colorbar()

  pylab.subplot(2, 2, 2)  # 2x2 top right
  pylab.title('Gr')
  cmap = matplotlib.colors.LinearSegmentedColormap.from_list('', _CMAP_GREEN)
  matplotlib.pyplot.pcolormesh(lsc_maps[:, :, 1], cmap=cmap)
  matplotlib.pyplot.colorbar()

  pylab.subplot(2, 2, 3)  # 2x2 bottom left
  pylab.title('Gb')
  cmap = matplotlib.colors.LinearSegmentedColormap.from_list('', _CMAP_GREEN)
  matplotlib.pyplot.pcolormesh(lsc_maps[:, :, 2], cmap=cmap)
  matplotlib.pyplot.colorbar()

  pylab.subplot(2, 2, 4)  # 2x2 bottom right
  pylab.title('B')
  cmap = matplotlib.colors.LinearSegmentedColormap.from_list('', _CMAP_BLUE)
  matplotlib.pyplot.pcolormesh(lsc_maps[:, :, 3], cmap=cmap)
  matplotlib.pyplot.colorbar()

  matplotlib.pyplot.savefig(f'{test_name_with_log_path}_{plot_name}_cmaps.png')


def capture_scene_image(cam, props, name_with_log_path):
  """Take a picture of the scene on test FAIL."""
  req = capture_request_utils.auto_capture_request()
  img = convert_capture_to_rgb_image(
      cam.do_capture(req, cam.CAP_YUV), props=props)
  write_image(img, f'{name_with_log_path}_scene.jpg', True)


def convert_image_to_uint8(image):
  image = image*255
  return image.astype(numpy.uint8)


def assert_props_is_not_none(props):
  if not props:
    raise AssertionError('props is None')


def assert_capture_width_and_height(cap, width, height):
  if cap['width'] != width or cap['height'] != height:
    raise AssertionError(
        'Unexpected capture WxH size, expected [{}x{}], actual [{}x{}]'.format(
            width, height, cap['width'], cap['height']
        )
    )


def apply_lens_shading_map(color_plane, black_level, white_level, lsc_map):
  """Apply the lens shading map to the color plane.

  Args:
    color_plane: 2D np array for color plane with values [0.0, 1.0].
    black_level: float; black level for the color plane.
    white_level: int; full scale for the color plane.
    lsc_map: 2D np array lens shading matching size of color_plane.

  Returns:
    color_plane with lsc applied.
  """
  logging.debug('color plane pre-lsc min, max: %.4f, %.4f',
                numpy.min(color_plane), numpy.max(color_plane))
  color_plane = (numpy.multiply((color_plane * white_level - black_level),
                                lsc_map)
                 + black_level) / white_level
  logging.debug('color plane post-lsc min, max: %.4f, %.4f',
                numpy.min(color_plane), numpy.max(color_plane))
  return color_plane


def populate_lens_shading_map(img_shape, lsc_map):
  """Helper function to create LSC coeifficients for RAW image.

  Args:
    img_shape: tuple; RAW image shape.
    lsc_map: 2D low resolution array with lens shading map values.

  Returns:
    value for lens shading map at point (x, y) in the image.
  """
  img_w, img_h = img_shape[1], img_shape[0]
  map_w, map_h = lsc_map.shape[1], lsc_map.shape[0]

  x, y = numpy.meshgrid(numpy.arange(img_w), numpy.arange(img_h))

  # (u,v) is lsc map location, values [0, map_w-1], [0, map_h-1]
  # Vectorized calculations
  u = x * (map_w - 1) / (img_w - 1)
  v = y * (map_h - 1) / (img_h - 1)
  u_min = numpy.floor(u).astype(int)
  v_min = numpy.floor(v).astype(int)
  u_frac = u - u_min
  v_frac = v - v_min
  u_max = numpy.where(u_frac > 0, u_min + 1, u_min)
  v_max = numpy.where(v_frac > 0, v_min + 1, v_min)

  # Gather LSC values, handling edge cases (optional)
  lsc_tl = lsc_map[(v_min, u_min)]
  lsc_tr = lsc_map[(v_min, u_max)]
  lsc_bl = lsc_map[(v_max, u_min)]
  lsc_br = lsc_map[(v_max, u_max)]

  # Bilinear interpolation (vectorized)
  lsc_t = lsc_tl * (1 - u_frac) + lsc_tr * u_frac
  lsc_b = lsc_bl * (1 - u_frac) + lsc_br * u_frac

  return lsc_t * (1 - v_frac) + lsc_b * v_frac


def unpack_lsc_map_from_metadata(metadata):
  """Get lens shading correction map from metadata and turn into 3D array.

  Args:
    metadata: dict; metadata from RAW capture.

  Returns:
    3D numpy array of lens shading maps.
  """
  lsc_metadata = metadata['android.statistics.lensShadingCorrectionMap']
  lsc_map_w, lsc_map_h = lsc_metadata['width'], lsc_metadata['height']
  lsc_map = lsc_metadata['map']
  logging.debug(
      'lensShadingCorrectionMap (H, W): (%d, %d)', lsc_map_h, lsc_map_w
  )
  return numpy.array(lsc_map).reshape(lsc_map_h, lsc_map_w, _NUM_RAW_CHANNELS)


def convert_raw_capture_to_rgb_image(cap_raw, props, raw_fmt,
                                     log_path_with_name):
  """Convert a RAW captured image object to a RGB image.

  Args:
    cap_raw: A RAW capture object as returned by its_session_utils.do_capture.
    props: camera properties object (of static values).
    raw_fmt: string of type 'raw', 'raw10', 'raw12'.
    log_path_with_name: string with test name and save location.

  Returns:
    RGB float-3 image array, with pixel values in [0.0, 1.0].
  """
  shading_mode = cap_raw['metadata']['android.shading.mode']
  lens_shading_map_mode = cap_raw[
      'metadata'].get('android.statistics.lensShadingMapMode')
  lens_shading_applied = props['android.sensor.info.lensShadingApplied']
  control_af_mode = cap_raw['metadata']['android.control.afMode']
  focus_distance = cap_raw['metadata']['android.lens.focusDistance']
  logging.debug('%s capture AF mode: %s', raw_fmt, control_af_mode)
  logging.debug('%s capture focus distance: %s', raw_fmt, focus_distance)
  logging.debug('%s capture shading mode: %d', raw_fmt, shading_mode)
  logging.debug('lensShadingMapApplied: %r', lens_shading_applied)
  logging.debug('lensShadingMapMode: %s', lens_shading_map_mode)

  # Split RAW to RGB conversion in 2 to allow LSC application (if needed).
  r, gr, gb, b = convert_capture_to_planes(cap_raw, props=props)

  # get from metadata, upsample, and apply
  if lens_shading_map_mode == LENS_SHADING_MAP_ON:
    logging.debug('Applying lens shading map')
    plot_name_stem_with_log_path = f'{log_path_with_name}_{raw_fmt}'
    black_levels = get_black_levels(props, cap_raw)
    white_level = int(props['android.sensor.info.whiteLevel'])
    lsc_maps = unpack_lsc_map_from_metadata(cap_raw['metadata'])
    plot_lsc_maps(lsc_maps, 'metadata', plot_name_stem_with_log_path)
    lsc_map_fs_r = populate_lens_shading_map(r.shape, lsc_maps[:, :, 0])
    lsc_map_fs_gr = populate_lens_shading_map(gr.shape, lsc_maps[:, :, 1])
    lsc_map_fs_gb = populate_lens_shading_map(gb.shape, lsc_maps[:, :, 2])
    lsc_map_fs_b = populate_lens_shading_map(b.shape, lsc_maps[:, :, 3])
    plot_lsc_maps(
        numpy.dstack((lsc_map_fs_r, lsc_map_fs_gr, lsc_map_fs_gb,
                      lsc_map_fs_b)),
        'fullscale', plot_name_stem_with_log_path
    )
    r = apply_lens_shading_map(
        r[:, :, 0], black_levels[0], white_level, lsc_map_fs_r
    )
    gr = apply_lens_shading_map(
        gr[:, :, 0], black_levels[1], white_level, lsc_map_fs_gr
    )
    gb = apply_lens_shading_map(
        gb[:, :, 0], black_levels[2], white_level, lsc_map_fs_gb
    )
    b = apply_lens_shading_map(
        b[:, :, 0], black_levels[3], white_level, lsc_map_fs_b
    )
  img = convert_raw_to_rgb_image(r, gr, gb, b, props, cap_raw['metadata'])
  return img


def convert_capture_to_rgb_image(cap,
                                 props=None,
                                 apply_ccm_raw_to_rgb=True):
  """Convert a captured image object to a RGB image.

  Args:
     cap: A capture object as returned by its_session_utils.do_capture.
     props: (Optional) camera properties object (of static values);
            required for processing raw images.
     apply_ccm_raw_to_rgb: (Optional) boolean to apply color correction matrix.

  Returns:
        RGB float-3 image array, with pixel values in [0.0, 1.0].
  """
  w = cap['width']
  h = cap['height']
  if cap['format'] == 'raw10' or cap['format'] == 'raw10QuadBayer':
    assert_props_is_not_none(props)
    is_quad_bayer = cap['format'] == 'raw10QuadBayer'
    cap = unpack_raw10_capture(cap, is_quad_bayer)

  if cap['format'] == 'raw12':
    assert_props_is_not_none(props)
    cap = unpack_raw12_capture(cap)

  if cap['format'] == 'yuv':
    y = cap['data'][0: w * h]
    u = cap['data'][w * h: w * h * 5//4]
    v = cap['data'][w * h * 5//4: w * h * 6//4]
    return convert_yuv420_planar_to_rgb_image(y, u, v, w, h)
  elif cap['format'] == 'jpeg' or cap['format'] == 'jpeg_r':
    return decompress_jpeg_to_rgb_image(cap['data'])
  elif (cap['format'] in ('raw', 'rawQuadBayer') or
        cap['format'] in noise_model_constants.VALID_RAW_STATS_FORMATS):
    assert_props_is_not_none(props)
    r, gr, gb, b = convert_capture_to_planes(cap, props)
    return convert_raw_to_rgb_image(
        r, gr, gb, b, props, cap['metadata'], apply_ccm_raw_to_rgb)
  elif cap['format'] == 'y8':
    y = cap['data'][0: w * h]
    return convert_y8_to_rgb_image(y, w, h)
  else:
    raise error_util.CameraItsError(f"Invalid format {cap['format']}")


def unpack_raw10_capture(cap, is_quad_bayer=False):
  """Unpack a raw-10 capture to a raw-16 capture.

  Args:
    cap: A raw-10 capture object.
    is_quad_bayer: Boolean flag for Bayer or Quad Bayer capture.

  Returns:
    New capture object with raw-16 data.
  """
  # Data is packed as 4x10b pixels in 5 bytes, with the first 4 bytes holding
  # the MSBs of the pixels, and the 5th byte holding 4x2b LSBs.
  w, h = cap['width'], cap['height']
  if w % 4 != 0:
    raise error_util.CameraItsError('Invalid raw-10 buffer width')
  cap = copy.deepcopy(cap)
  cap['data'] = unpack_raw10_image(cap['data'].reshape(h, w * 5 // 4))
  cap['format'] = 'rawQuadBayer' if is_quad_bayer else 'raw'
  return cap


def unpack_raw10_image(img):
  """Unpack a raw-10 image to a raw-16 image.

  Output image will have the 10 LSBs filled in each 16b word, and the 6 MSBs
  will be set to zero.

  Args:
    img: A raw-10 image, as a uint8 numpy array.

  Returns:
    Image as a uint16 numpy array, with all row padding stripped.
  """
  if img.shape[1] % 5 != 0:
    raise error_util.CameraItsError('Invalid raw-10 buffer width')
  w = img.shape[1] * 4 // 5
  h = img.shape[0]
  # Cut out the 4x8b MSBs and shift to bits [9:2] in 16b words.
  msbs = numpy.delete(img, numpy.s_[4::5], 1)
  msbs = msbs.astype(numpy.uint16)
  msbs = numpy.left_shift(msbs, 2)
  msbs = msbs.reshape(h, w)
  # Cut out the 4x2b LSBs and put each in bits [1:0] of their own 8b words.
  lsbs = img[::, 4::5].reshape(h, w // 4)
  lsbs = numpy.right_shift(
      numpy.packbits(numpy.unpackbits(lsbs).reshape((h, w // 4, 4, 2)), 3), 6)
  # Pair the LSB bits group to 0th pixel instead of 3rd pixel
  lsbs = lsbs.reshape(h, w // 4, 4)[:, :, ::-1]
  lsbs = lsbs.reshape(h, w)
  # Fuse the MSBs and LSBs back together
  img16 = numpy.bitwise_or(msbs, lsbs).reshape(h, w)
  return img16


def unpack_raw12_capture(cap):
  """Unpack a raw-12 capture to a raw-16 capture.

  Args:
    cap: A raw-12 capture object.

  Returns:
     New capture object with raw-16 data.
  """
  # Data is packed as 4x10b pixels in 5 bytes, with the first 4 bytes holding
  # the MSBs of the pixels, and the 5th byte holding 4x2b LSBs.
  w, h = cap['width'], cap['height']
  if w % 2 != 0:
    raise error_util.CameraItsError('Invalid raw-12 buffer width')
  cap = copy.deepcopy(cap)
  cap['data'] = unpack_raw12_image(cap['data'].reshape(h, w * 3 // 2))
  cap['format'] = 'raw'
  return cap


def unpack_raw12_image(img):
  """Unpack a raw-12 image to a raw-16 image.

  Output image will have the 12 LSBs filled in each 16b word, and the 4 MSBs
  will be set to zero.

  Args:
   img: A raw-12 image, as a uint8 numpy array.

  Returns:
    Image as a uint16 numpy array, with all row padding stripped.
  """
  if img.shape[1] % 3 != 0:
    raise error_util.CameraItsError('Invalid raw-12 buffer width')
  w = img.shape[1] * 2 // 3
  h = img.shape[0]
  # Cut out the 2x8b MSBs and shift to bits [11:4] in 16b words.
  msbs = numpy.delete(img, numpy.s_[2::3], 1)
  msbs = msbs.astype(numpy.uint16)
  msbs = numpy.left_shift(msbs, 4)
  msbs = msbs.reshape(h, w)
  # Cut out the 2x4b LSBs and put each in bits [3:0] of their own 8b words.
  lsbs = img[::, 2::3].reshape(h, w // 2)
  lsbs = numpy.right_shift(
      numpy.packbits(numpy.unpackbits(lsbs).reshape((h, w // 2, 2, 4)), 3), 4)
  # Pair the LSB bits group to pixel 0 instead of pixel 1
  lsbs = lsbs.reshape(h, w // 2, 2)[:, :, ::-1]
  lsbs = lsbs.reshape(h, w)
  # Fuse the MSBs and LSBs back together
  img16 = numpy.bitwise_or(msbs, lsbs).reshape(h, w)
  return img16


def convert_yuv420_planar_to_rgb_image(y_plane, u_plane, v_plane,
                                       w, h,
                                       ccm_yuv_to_rgb=DEFAULT_YUV_TO_RGB_CCM,
                                       yuv_off=DEFAULT_YUV_OFFSETS):
  """Convert a YUV420 8-bit planar image to an RGB image.

  Args:
    y_plane: The packed 8-bit Y plane.
    u_plane: The packed 8-bit U plane.
    v_plane: The packed 8-bit V plane.
    w: The width of the image.
    h: The height of the image.
    ccm_yuv_to_rgb: (Optional) the 3x3 CCM to convert from YUV to RGB.
    yuv_off: (Optional) offsets to subtract from each of Y,U,V values.

  Returns:
    RGB float-3 image array, with pixel values in [0.0, 1.0].
  """
  y = numpy.subtract(y_plane, yuv_off[0])
  u = numpy.subtract(u_plane, yuv_off[1]).view(numpy.int8)
  v = numpy.subtract(v_plane, yuv_off[2]).view(numpy.int8)
  u = u.reshape(h // 2, w // 2).repeat(2, axis=1).repeat(2, axis=0)
  v = v.reshape(h // 2, w // 2).repeat(2, axis=1).repeat(2, axis=0)
  yuv = numpy.dstack([y, u.reshape(w * h), v.reshape(w * h)])
  flt = numpy.empty([h, w, 3], dtype=numpy.float32)
  flt.reshape(w * h * 3)[:] = yuv.reshape(h * w * 3)[:]
  flt = numpy.dot(flt.reshape(w * h, 3), ccm_yuv_to_rgb.T).clip(0, 255)
  rgb = numpy.empty([h, w, 3], dtype=numpy.uint8)
  rgb.reshape(w * h * 3)[:] = flt.reshape(w * h * 3)[:]
  return rgb.astype(numpy.float32) / 255.0


def decompress_jpeg_to_rgb_image(jpeg_buffer):
  """Decompress a JPEG-compressed image, returning as an RGB image.

  Args:
    jpeg_buffer: The JPEG stream.

  Returns:
     A numpy array for the RGB image, with pixels in [0,1].
  """
  img = Image.open(io.BytesIO(jpeg_buffer))
  w = img.size[0]
  h = img.size[1]
  return numpy.array(img).reshape((h, w, 3)) / 255.0


def decompress_jpeg_to_yuv_image(jpeg_buffer):
  """Decompress a JPEG-compressed image, returning as a YUV image.

  Args:
    jpeg_buffer: The JPEG stream.

  Returns:
     A numpy array for the YUV image, with pixels in [0,1].
  """
  img = Image.open(io.BytesIO(jpeg_buffer))
  img = img.convert('YCbCr')
  w = img.size[0]
  h = img.size[1]
  return numpy.array(img).reshape((h, w, 3)) / 255.0


def extract_luma_from_patch(cap, patch_x, patch_y, patch_w, patch_h):
  """Extract luma from capture."""
  y, _, _ = convert_capture_to_planes(cap)
  patch = get_image_patch(y, patch_x, patch_y, patch_w, patch_h)
  luma = compute_image_means(patch)[0]
  return luma


def convert_image_to_numpy_array(image_path):
  """Converts image at image_path to numpy array and returns the array.

  Args:
    image_path: file path

  Returns:
    numpy array
  """
  if not os.path.exists(image_path):
    raise AssertionError(f'{image_path} does not exist.')
  image = Image.open(image_path)
  return numpy.array(image)


def _convert_quad_bayer_img_to_bayer_channels(quad_bayer_img, props=None):
  """Convert a quad Bayer image to the Bayer image channels.

  Args:
      quad_bayer_img: The quad Bayer image.
      props: The camera properties.

  Returns:
      A list of reordered standard Bayer channels of the Bayer image.
  """
  height, width, num_channels = quad_bayer_img.shape

  if num_channels != noise_model_constants.NUM_QUAD_BAYER_CHANNELS:
    raise AssertionError(
        'The number of channels in the quad Bayer image must be '
        f'{noise_model_constants.NUM_QUAD_BAYER_CHANNELS}.'
    )
  quad_bayer_cfa_order = get_canonical_cfa_order(props, is_quad_bayer=True)

  # Bayer channels are in the order of R, Gr, Gb and B.
  bayer_channels = []
  for ch in range(4):
    channel_img = numpy.zeros(shape=(height, width), dtype='<f')
    # Average every four quad Bayer channels into a standard Bayer channel.
    for i in quad_bayer_cfa_order[4 * ch: 4 * (ch + 1)]:
      channel_img[:, :] += quad_bayer_img[:, :, i]
    bayer_channels.append(channel_img / 4)
  return bayer_channels


def subsample(image, num_channels=4):
  """Subsamples the image to separate its color channels.

  Args:
    image:        2-D numpy array of raw image.
    num_channels: The number of channels in the image.

  Returns:
    3-D numpy image with each channel separated.
  """
  if num_channels not in noise_model_constants.VALID_NUM_CHANNELS:
    raise error_util.CameraItsError(
        f'Invalid number of channels {num_channels}, which should be in '
        f'{noise_model_constants.VALID_NUM_CHANNELS}.'
    )

  size_h, size_v = image.shape[1], image.shape[0]

  # Subsample step size, which is the horizontal or vertical pixel interval
  # between two adjacent pixels of the same channel.
  stride = int(numpy.sqrt(num_channels))
  subsample_img = lambda img, i, h, v, s: img[i // s: v: s, i % s: h: s]
  channel_img = numpy.empty((
      image.shape[0] // stride,
      image.shape[1] // stride,
      num_channels,
  ))

  for i in range(num_channels):
    sub_img = subsample_img(image, i, size_h, size_v, stride)
    channel_img[:, :, i] = sub_img

  return channel_img


def convert_capture_to_planes(cap, props=None):
  """Convert a captured image object to separate image planes.

  Decompose an image into multiple images, corresponding to different planes.

  For YUV420 captures ("yuv"):
        Returns Y,U,V planes, where the Y plane is full-res and the U,V planes
        are each 1/2 x 1/2 of the full res.

    For standard Bayer or quad Bayer captures ("raw", "raw10", "raw12",
    "rawQuadBayer", "rawStats", "rawQuadBayerStats", "raw10QuadBayer",
    "raw10Stats", "raw10QuadBayerStats"):
        Returns planes in the order R, Gr, Gb, B, regardless of the Bayer
        pattern layout.
        For full-res raw images ("raw", "rawQuadBayer", "raw10",
        "raw10QuadBayer", "raw12"), each plane is 1/2 x 1/2 of the full res.
        For standard Bayer stats images, the mean image is returned.
        For quad Bayer stats images, the average mean image is returned.

    For JPEG captures ("jpeg"):
        Returns R,G,B full-res planes.

  Args:
    cap: A capture object as returned by its_session_utils.do_capture.
    props: (Optional) camera properties object (of static values);
            required for processing raw images.

  Returns:
    A tuple of float numpy arrays (one per plane), consisting of pixel values
    in the range [0.0, 1.0].
  """
  w = cap['width']
  h = cap['height']
  if cap['format'] in ('raw10', 'raw10QuadBayer'):
    assert_props_is_not_none(props)
    is_quad_bayer = cap['format'] == 'raw10QuadBayer'
    cap = unpack_raw10_capture(cap, is_quad_bayer)

  if cap['format'] == 'raw12':
    assert_props_is_not_none(props)
    cap = unpack_raw12_capture(cap)
  if cap['format'] == 'yuv':
    y = cap['data'][0:w * h]
    u = cap['data'][w * h:w * h * 5 // 4]
    v = cap['data'][w * h * 5 // 4:w * h * 6 // 4]
    return ((y.astype(numpy.float32) / 255.0).reshape(h, w, 1),
            (u.astype(numpy.float32) / 255.0).reshape(h // 2, w // 2, 1),
            (v.astype(numpy.float32) / 255.0).reshape(h // 2, w // 2, 1))
  elif cap['format'] == 'jpeg':
    rgb = decompress_jpeg_to_rgb_image(cap['data']).reshape(w * h * 3)
    return (rgb[::3].reshape(h, w, 1), rgb[1::3].reshape(h, w, 1),
            rgb[2::3].reshape(h, w, 1))
  elif cap['format'] in ('raw', 'rawQuadBayer'):
    assert_props_is_not_none(props)
    is_quad_bayer = 'QuadBayer' in cap['format']
    white_level = get_white_level(props, cap['metadata'])
    img = numpy.ndarray(
        shape=(h * w,), dtype='<u2', buffer=cap['data'][0:w * h * 2])
    img = img.astype(numpy.float32).reshape(h, w) / white_level
    if is_quad_bayer:
      pixel_array_size = props.get(
          'android.sensor.info.pixelArraySizeMaximumResolution'
      )
      active_array_size = props.get(
          'android.sensor.info.preCorrectionActiveArraySizeMaximumResolution'
      )
    else:
      pixel_array_size = props.get('android.sensor.info.pixelArraySize')
      active_array_size = props.get(
          'android.sensor.info.preCorrectionActiveArraySize'
      )
    # Crop the raw image to the active array region.
    if pixel_array_size and active_array_size:
      # Note that the Rect class is defined such that the left,top values
      # are "inside" while the right,bottom values are "outside"; that is,
      # it's inclusive of the top,left sides only. So, the width is
      # computed as right-left, rather than right-left+1, etc.
      wfull = pixel_array_size['width']
      hfull = pixel_array_size['height']
      xcrop = active_array_size['left']
      ycrop = active_array_size['top']
      wcrop = active_array_size['right'] - xcrop
      hcrop = active_array_size['bottom'] - ycrop
      if not wfull >= wcrop >= 0:
        raise AssertionError(f'wcrop: {wcrop} not in wfull: {wfull}')
      if not hfull >= hcrop >= 0:
        raise AssertionError(f'hcrop: {hcrop} not in hfull: {hfull}')
      if not wfull - wcrop >= xcrop >= 0:
        raise AssertionError(f'xcrop: {xcrop} not in wfull-crop: {wfull-wcrop}')
      if not hfull - hcrop >= ycrop >= 0:
        raise AssertionError(f'ycrop: {ycrop} not in hfull-crop: {hfull-hcrop}')
      if w == wfull and h == hfull:
        # Crop needed; extract the center region.
        img = img[ycrop:ycrop + hcrop, xcrop:xcrop + wcrop]
        w = wcrop
        h = hcrop
      elif w == wcrop and h == hcrop:
        logging.debug('Image is already cropped. No cropping needed.')
      else:
        raise error_util.CameraItsError('Invalid image size metadata')

    idxs = get_canonical_cfa_order(props, is_quad_bayer)
    if is_quad_bayer:
      # Subsample image array based on the color map.
      quad_bayer_img = subsample(
          img, noise_model_constants.NUM_QUAD_BAYER_CHANNELS
      )
      bayer_channels = _convert_quad_bayer_img_to_bayer_channels(
          quad_bayer_img, props
      )
      return bayer_channels
    else:
      # Separate the image planes.
      imgs = [
          img[::2].reshape(w * h // 2)[::2].reshape(h // 2, w // 2, 1),
          img[::2].reshape(w * h // 2)[1::2].reshape(h // 2, w // 2, 1),
          img[1::2].reshape(w * h // 2)[::2].reshape(h // 2, w // 2, 1),
          img[1::2].reshape(w * h // 2)[1::2].reshape(h // 2, w // 2, 1),
      ]
      return [imgs[i] for i in idxs]
  elif cap['format'] in (
      'rawStats',
      'raw10Stats',
      'rawQuadBayerStats',
      'raw10QuadBayerStats',
  ):
    assert_props_is_not_none(props)
    is_quad_bayer = 'QuadBayer' in cap['format']
    white_level = get_white_level(props, cap['metadata'])
    if is_quad_bayer:
      num_channels = noise_model_constants.NUM_QUAD_BAYER_CHANNELS
    else:
      num_channels = noise_model_constants.NUM_BAYER_CHANNELS
    mean_image, _ = unpack_rawstats_capture(cap, num_channels)
    if is_quad_bayer:
      bayer_channels = _convert_quad_bayer_img_to_bayer_channels(
          mean_image, props
      )
      bayer_channels = [
          bayer_channels[i] / white_level for i in range(len(bayer_channels))
      ]
      return bayer_channels
    else:
      # Standard Bayer canonical color channel indices.
      idxs = get_canonical_cfa_order(props, is_quad_bayer=False)
      # Normalizes the range to [0, 1] without subtracting the black level.
      return [mean_image[:, :, i] / white_level for i in idxs]
  else:
    raise error_util.CameraItsError(f"Invalid format {cap['format']}")


def downscale_image(img, f):
  """Shrink an image by a given integer factor.

  This function computes output pixel values by averaging over rectangular
  regions of the input image; it doesn't skip or sample pixels, and all input
  image pixels are evenly weighted.

  If the downscaling factor doesn't cleanly divide the width and/or height,
  then the remaining pixels on the right or bottom edge are discarded prior
  to the downscaling.

  Args:
    img: The input image as an ndarray.
    f: The downscaling factor, which should be an integer.

  Returns:
    The new (downscaled) image, as an ndarray.
  """
  h, w, chans = img.shape
  f = int(f)
  assert f >= 1
  h = (h//f)*f
  w = (w//f)*f
  img = img[0:h:, 0:w:, ::]
  chs = []
  for i in range(chans):
    ch = img.reshape(h*w*chans)[i::chans].reshape(h, w)
    ch = ch.reshape(h, w//f, f).mean(2).reshape(h, w//f)
    ch = ch.T.reshape(w//f, h//f, f).mean(2).T.reshape(h//f, w//f)
    chs.append(ch.reshape(h*w//(f*f)))
  img = numpy.vstack(chs).T.reshape(h//f, w//f, chans)
  return img


def convert_raw_to_rgb_image(r_plane, gr_plane, gb_plane, b_plane, props,
                             cap_res, apply_ccm_raw_to_rgb=True):
  """Convert a Bayer raw-16 image to an RGB image.

  Includes some extremely rudimentary demosaicking and color processing
  operations; the output of this function shouldn't be used for any image
  quality analysis.

  Args:
   r_plane:
   gr_plane:
   gb_plane:
   b_plane: Numpy arrays for each color plane
            in the Bayer image, with pixels in the [0.0, 1.0] range.
   props: Camera properties object.
   cap_res: Capture result (metadata) object.
   apply_ccm_raw_to_rgb: (Optional) boolean to apply color correction matrix.

  Returns:
   RGB float-3 image array, with pixel values in [0.0, 1.0]
  """
  # Values required for the RAW to RGB conversion.
  assert_props_is_not_none(props)
  white_level = get_white_level(props, cap_res)
  gains = cap_res['android.colorCorrection.gains']
  ccm = cap_res['android.colorCorrection.transform']

  # Reorder black levels and gains to R,Gr,Gb,B, to match the order
  # of the planes.
  black_levels = get_black_levels(props, cap_res, is_quad_bayer=False)
  logging.debug('dynamic black levels: %s', black_levels)
  gains = get_gains_in_canonical_order(props, gains)

  # Convert CCM from rational to float, as numpy arrays.
  ccm = numpy.array(capture_request_utils.rational_to_float(ccm)).reshape(3, 3)

  # Need to scale the image back to the full [0,1] range after subtracting
  # the black level from each pixel.
  scale = white_level / (white_level - max(black_levels))

  # Three-channel black levels, normalized to [0,1] by white_level.
  black_levels = numpy.array(
      [b / white_level for b in [black_levels[i] for i in [0, 1, 3]]])

  # Three-channel gains.
  gains = numpy.array([gains[i] for i in [0, 1, 3]])

  h, w = r_plane.shape[:2]
  img = numpy.dstack([r_plane, (gr_plane + gb_plane) / 2.0, b_plane])
  img = (((img.reshape(h, w, 3) - black_levels) * scale) * gains).clip(0.0, 1.0)
  if apply_ccm_raw_to_rgb:
    img = numpy.dot(
        img.reshape(w * h, 3), ccm.T).reshape((h, w, 3)).clip(0.0, 1.0)
  return img


def convert_y8_to_rgb_image(y_plane, w, h):
  """Convert a Y 8-bit image to an RGB image.

  Args:
    y_plane: The packed 8-bit Y plane.
    w: The width of the image.
    h: The height of the image.

  Returns:
    RGB float-3 image array, with pixel values in [0.0, 1.0].
  """
  y3 = numpy.dstack([y_plane, y_plane, y_plane])
  rgb = numpy.empty([h, w, 3], dtype=numpy.uint8)
  rgb.reshape(w * h * 3)[:] = y3.reshape(w * h * 3)[:]
  return rgb.astype(numpy.float32) / 255.0


def write_rgb_uint8_image(img, file_name):
  """Save a uint8 numpy array image to a file.

  Supported formats: PNG, JPEG, and others; see PIL docs for more.

  Args:
   img: numpy image array data.
   file_name: path of file to save to; the extension specifies the format.
  """
  if img.dtype != 'uint8':
    raise AssertionError(f'Incorrect input type: {img.dtype}! Expected: uint8')
  else:
    Image.fromarray(img, 'RGB').save(file_name)


def write_image(img, fname, apply_gamma=False, is_yuv=False):
  """Save a float-3 numpy array image to a file.

  Supported formats: PNG, JPEG, and others; see PIL docs for more.

  Image can be 3-channel, which is interpreted as RGB or YUV, or can be
  1-channel, which is greyscale.

  Can optionally specify that the image should be gamma-encoded prior to
  writing it out; this should be done if the image contains linear pixel
  values, to make the image look "normal".

  Args:
   img: Numpy image array data.
   fname: Path of file to save to; the extension specifies the format.
   apply_gamma: (Optional) apply gamma to the image prior to writing it.
   is_yuv: Whether the image is in YUV format.
  """
  if apply_gamma:
    img = apply_lut_to_image(img, DEFAULT_GAMMA_LUT)
  (h, w, chans) = img.shape
  if chans == 3:
    if not is_yuv:
      Image.fromarray((img * 255.0).astype(numpy.uint8), 'RGB').save(fname)
    else:
      Image.fromarray((img * 255.0).astype(numpy.uint8), 'YCbCr').save(fname)
  elif chans == 1:
    img3 = (img * 255.0).astype(numpy.uint8).repeat(3).reshape(h, w, 3)
    Image.fromarray(img3, 'RGB').save(fname)
  else:
    raise error_util.CameraItsError('Unsupported image type')


def read_image(fname):
  """Read image function to match write_image() above."""
  return Image.open(fname)


def apply_lut_to_image(img, lut):
  """Applies a LUT to every pixel in a float image array.

  Internally converts to a 16b integer image, since the LUT can work with up
  to 16b->16b mappings (i.e. values in the range [0,65535]). The lut can also
  have fewer than 65536 entries, however it must be sized as a power of 2
  (and for smaller luts, the scale must match the bitdepth).

  For a 16b lut of 65536 entries, the operation performed is:

  lut[r * 65535] / 65535 -> r'
  lut[g * 65535] / 65535 -> g'
  lut[b * 65535] / 65535 -> b'

  For a 10b lut of 1024 entries, the operation becomes:

  lut[r * 1023] / 1023 -> r'
  lut[g * 1023] / 1023 -> g'
  lut[b * 1023] / 1023 -> b'

  Args:
    img: Numpy float image array, with pixel values in [0,1].
    lut: Numpy table encoding a LUT, mapping 16b integer values.

  Returns:
    Float image array after applying LUT to each pixel.
  """
  n = len(lut)
  if n <= 0 or n > MAX_LUT_SIZE or (n & (n - 1)) != 0:
    raise error_util.CameraItsError(f'Invalid arg LUT size: {n}')
  m = float(n - 1)
  return (lut[(img * m).astype(numpy.uint16)] / m).astype(numpy.float32)


def get_gains_in_canonical_order(props, gains):
  """Reorders the gains tuple to the canonical R,Gr,Gb,B order.

  Args:
    props: Camera properties object.
    gains: List of 4 values, in R,G_even,G_odd,B order.

  Returns:
    List of gains values, in R,Gr,Gb,B order.
  """
  cfa_pat = props['android.sensor.info.colorFilterArrangement']
  if cfa_pat in [0, 1]:
    # RGGB or GRBG, so G_even is Gr
    return gains
  elif cfa_pat in [2, 3]:
    # GBRG or BGGR, so G_even is Gb
    return [gains[0], gains[2], gains[1], gains[3]]
  else:
    raise error_util.CameraItsError('Not supported')


def get_white_level(props, cap_metadata=None):
  """Gets white level to use for a given capture.

  Uses a dynamic value from the capture result if available, else falls back
  to the static global value in the camera characteristics.

  Args:
    props: The camera properties object.
    cap_metadata: A capture results metadata object.

  Returns:
    Float white level value.
  """
  if (cap_metadata is not None and
      'android.sensor.dynamicWhiteLevel' in cap_metadata and
      cap_metadata['android.sensor.dynamicWhiteLevel'] is not None):
    white_level = cap_metadata['android.sensor.dynamicWhiteLevel']
    logging.debug('dynamic white level: %.2f', white_level)
  else:
    white_level = props['android.sensor.info.whiteLevel']
    logging.debug('white level: %.2f', white_level)
  return float(white_level)


def get_black_levels(props, cap=None, is_quad_bayer=False):
  """Gets black levels to use for a given capture.

  Uses a dynamic value from the capture result if available, else falls back
  to the static global value in the camera characteristics.

  Args:
    props: The camera properties object.
    cap: A capture object.
    is_quad_bayer: Boolean flag for Bayer or Quad Bayer capture.

  Returns:
    A list of black level values reordered in canonical order.
  """
  if (cap is not None and
      'android.sensor.dynamicBlackLevel' in cap and
      cap['android.sensor.dynamicBlackLevel'] is not None):
    black_levels = cap['android.sensor.dynamicBlackLevel']
  else:
    black_levels = props['android.sensor.blackLevelPattern']

  idxs = get_canonical_cfa_order(props, is_quad_bayer)
  if is_quad_bayer:
    ordered_black_levels = [black_levels[i // 4] for i in idxs]
  else:
    ordered_black_levels = [black_levels[i] for i in idxs]
  return ordered_black_levels


def get_canonical_cfa_order(props, is_quad_bayer=False):
  """Returns a list of channel indices according to color filter arrangement.

  Color filter arrangement index is a integer ranging from 0 to 3, which maps
  the color filter arrangement in the following way.
    0: R, Gr, Gb, B,
    1: Gr, R, B, Gb,
    2: Gb, B, R, Gr,
    3: B, Gb, Gr, R.

  This function return a list of channel indices that can be used to reorder
  the stats data as the canonical order:
    (1) For standard Bayer: R, Gr, Gb, B.
    (2) For quad Bayer: R0, R1, R2, R3,
                        Gr0, Gr1, Gr2, Gr3,
                        Gb0, Gb1, Gb2, Gb3,
                        B0, B1, B2, B3.

  Args:
    props: Camera properties object.
    is_quad_bayer: Boolean flag for Bayer or Quad Bayer capture.

  Returns:
    A list of channel indices with values ranging from:
      (1) [0, 3] for standard Bayer,
      (2) [0, 15] for quad Bayer.
  """
  cfa_pat = props['android.sensor.info.colorFilterArrangement']
  if not 0 <= cfa_pat < 4:
    raise error_util.CameraItsError('Not supported')

  channel_indices = []
  if is_quad_bayer:
    color_map = noise_model_constants.QUAD_BAYER_COLOR_FILTER_MAP[cfa_pat]
    for ch in noise_model_constants.BAYER_COLORS:
      channel_indices.extend(color_map[ch])
  else:
    color_map = noise_model_constants.BAYER_COLOR_FILTER_MAP[cfa_pat]
    channel_indices = [
        color_map[ch] for ch in noise_model_constants.BAYER_COLORS
    ]
  return channel_indices


def unpack_rawstats_capture(cap, num_channels=4):
  """Unpacks a stats image capture to the mean and variance images.

  Args:
    cap: A capture object as returned by its_session_utils.do_capture.
    num_channels: The number of color channels in the stats image capture, which
      can be one of noise_model_constants.VALID_NUM_CHANNELS.

  Returns:
    Tuple (mean_image var_image) of float-4 images, with non-normalized
    pixel values computed from the RAW10/RAW16 images on the device
  """
  if cap['format'] not in noise_model_constants.VALID_RAW_STATS_FORMATS:
    raise AssertionError(f"Unsupported stats format: {cap['format']}")

  if num_channels not in noise_model_constants.VALID_NUM_CHANNELS:
    raise AssertionError(
        f'Unsupported number of channels {num_channels}, which should be in'
        f' {noise_model_constants.VALID_NUM_CHANNELS}.'
    )

  w = cap['width']
  h = cap['height']
  img = numpy.ndarray(
      shape=(2 * h * w * num_channels,), dtype='<f', buffer=cap['data']
  )
  analysis_image = img.reshape((2, h, w, num_channels))
  mean_image = analysis_image[0, :, :, :].reshape(h, w, num_channels)
  var_image = analysis_image[1, :, :, :].reshape(h, w, num_channels)
  return mean_image, var_image


def get_image_patch(img, xnorm, ynorm, wnorm, hnorm):
  """Get a patch (tile) of an image.

  Args:
   img: Numpy float image array, with pixel values in [0,1].
   xnorm:
   ynorm:
   wnorm:
   hnorm: Normalized (in [0,1]) coords for the tile.

  Returns:
     Numpy float image array of the patch.
  """
  hfull = img.shape[0]
  wfull = img.shape[1]
  xtile = int(math.ceil(xnorm * wfull))
  ytile = int(math.ceil(ynorm * hfull))
  wtile = int(math.floor(wnorm * wfull))
  htile = int(math.floor(hnorm * hfull))
  if len(img.shape) == 2:
    return img[ytile:ytile + htile, xtile:xtile + wtile].copy()
  else:
    return img[ytile:ytile + htile, xtile:xtile + wtile, :].copy()


def compute_image_means(img):
  """Calculate the mean of each color channel in the image.

  Args:
    img: Numpy float image array, with pixel values in [0,1].

  Returns:
     A list of mean values, one per color channel in the image.
  """
  means = []
  chans = img.shape[2]
  for i in range(chans):
    means.append(numpy.mean(img[:, :, i], dtype=numpy.float64))
  return means


def compute_image_variances(img):
  """Calculate the variance of each color channel in the image.

  Args:
    img: Numpy float image array, with pixel values in [0,1].

  Returns:
    A list of variance values, one per color channel in the image.
  """
  variances = []
  chans = img.shape[2]
  for i in range(chans):
    variances.append(numpy.var(img[:, :, i], dtype=numpy.float64))
  return variances


def compute_image_sharpness(img):
  """Calculate the sharpness of input image.

  Args:
    img: numpy float RGB/luma image array, with pixel values in [0,1].

  Returns:
    Sharpness estimation value based on the average of gradient magnitude.
    Larger value means the image is sharper.
  """
  chans = img.shape[2]
  if chans != 1 and chans != 3:
    raise AssertionError(f'Not RGB or MONO image! depth: {chans}')
  if chans == 1:
    luma = img[:, :, 0]
  else:
    luma = convert_rgb_to_grayscale(img)
  gy, gx = numpy.gradient(luma)
  return numpy.average(numpy.sqrt(gy*gy + gx*gx))


def compute_image_max_gradients(img):
  """Calculate the maximum gradient of each color channel in the image.

  Args:
    img: Numpy float image array, with pixel values in [0,1].

  Returns:
    A list of gradient max values, one per color channel in the image.
  """
  grads = []
  chans = img.shape[2]
  for i in range(chans):
    grads.append(numpy.amax(numpy.gradient(img[:, :, i])))
  return grads


def compute_image_snrs(img):
  """Calculate the SNR (dB) of each color channel in the image.

  Args:
    img: Numpy float image array, with pixel values in [0,1].

  Returns:
    A list of SNR values in dB, one per color channel in the image.
  """
  means = compute_image_means(img)
  variances = compute_image_variances(img)
  std_devs = [math.sqrt(v) for v in variances]
  snrs = [20 * math.log10(m/s) for m, s in zip(means, std_devs)]
  return snrs


def convert_rgb_to_grayscale(img):
  """Convert a 3-D array RGB image to grayscale image.

  Args:
    img: numpy 3-D array RGB image of type [0.0, 1.0] float or [0, 255] uint8.

  Returns:
    2-D grayscale image of same type as input.
  """
  chans = img.shape[2]
  if chans != 3:
    raise AssertionError(f'Not an RGB image! Depth: {chans}')
  img_gray = numpy.dot(img[..., :3], RGB2GRAY_WEIGHTS)
  if img.dtype == 'uint8':
    return img_gray.round().astype(numpy.uint8)
  else:
    return img_gray


def normalize_img(img):
  """Normalize the image values to between 0 and 1.

  Args:
    img: 2-D numpy array of image values
  Returns:
    Normalized image
  """
  return (img - numpy.amin(img))/(numpy.amax(img) - numpy.amin(img))


def rotate_img_per_argv(img):
  """Rotate an image 180 degrees if "rotate" is in argv.

  Args:
    img: 2-D numpy array of image values
  Returns:
    Rotated image
  """
  img_out = img
  if 'rotate180' in sys.argv:
    img_out = numpy.fliplr(numpy.flipud(img_out))
  return img_out


def compute_image_rms_difference_1d(rgb_x, rgb_y):
  """Calculate the RMS difference between 2 RBG images as 1D arrays.

  Args:
    rgb_x: image array
    rgb_y: image array

  Returns:
    rms_diff
  """
  len_rgb_x = len(rgb_x)
  len_rgb_y = len(rgb_y)
  if len_rgb_y != len_rgb_x:
    raise AssertionError('RGB images have different number of planes! '
                         f'x: {len_rgb_x}, y: {len_rgb_y}')
  return math.sqrt(sum([pow(rgb_x[i] - rgb_y[i], 2.0)
                        for i in range(len_rgb_x)]) / len_rgb_x)


def compute_image_rms_difference_3d(rgb_x, rgb_y):
  """Calculate the RMS difference between 2 RBG images as 3D arrays.

  Args:
    rgb_x: image array in the form of w * h * channels
    rgb_y: image array in the form of w * h * channels

  Returns:
    rms_diff
  """
  shape_rgb_x = numpy.shape(rgb_x)
  shape_rgb_y = numpy.shape(rgb_y)
  if shape_rgb_y != shape_rgb_x:
    raise AssertionError('RGB images have different number of planes! '
                         f'x: {shape_rgb_x}, y: {shape_rgb_y}')
  if len(shape_rgb_x) != 3:
    raise AssertionError(f'RGB images dimension {len(shape_rgb_x)} is not 3!')

  mean_square_sum = 0.0
  for i in range(shape_rgb_x[0]):
    for j in range(shape_rgb_x[1]):
      for k in range(shape_rgb_x[2]):
        mean_square_sum += pow(float(rgb_x[i][j][k]) - float(rgb_y[i][j][k]),
                               2.0)
  return (math.sqrt(mean_square_sum /
                    (shape_rgb_x[0] * shape_rgb_x[1] * shape_rgb_x[2])))


def compute_image_sad(img_x, img_y):
  """Calculate the sum of absolute differences between 2 images.

  Args:
    img_x: image array in the form of w * h * channels
    img_y: image array in the form of w * h * channels

  Returns:
    sad
  """
  img_x = img_x[:, :, 1:].ravel()
  img_y = img_y[:, :, 1:].ravel()
  return numpy.sum(numpy.abs(numpy.subtract(img_x, img_y, dtype=float)))


def get_img(buffer):
  """Return a PIL.Image of the capture buffer.

  Args:
    buffer: data field from the capture result.

  Returns:
    A PIL.Image
  """
  return Image.open(io.BytesIO(buffer))


def jpeg_has_icc_profile(jpeg_img):
  """Checks if a jpeg PIL.Image has an icc profile attached.

  Args:
    jpeg_img: The PIL.Image.

  Returns:
    True if an icc profile is present, False otherwise.
  """
  return jpeg_img.info.get('icc_profile') is not None


def get_primary_chromaticity(primary):
  """Given an ImageCms primary, returns just the xy chromaticity coordinates.

  Args:
    primary: The primary from the ImageCms profile.

  Returns:
    (float, float): The xy chromaticity coordinates of the primary.
  """
  ((_, _, _), (x, y, _)) = primary
  return x, y


def is_jpeg_icc_profile_correct(jpeg_img, color_space, icc_profile_path=None):
  """Compare a jpeg's icc profile to a color space's expected parameters.

  Args:
    jpeg_img: The PIL.Image.
    color_space: 'DISPLAY_P3' or 'SRGB'
    icc_profile_path: Optional path to an icc file to be created with the
        raw contents.

  Returns:
    True if the icc profile matches expectations, False otherwise.
  """
  icc = jpeg_img.info.get('icc_profile')
  f = io.BytesIO(icc)
  icc_profile = ImageCms.getOpenProfile(f)

  if icc_profile_path is not None:
    raw_icc_bytes = f.getvalue()
    f = open(icc_profile_path, 'wb')
    f.write(raw_icc_bytes)
    f.close()

  cms_profile = icc_profile.profile
  (rx, ry) = get_primary_chromaticity(cms_profile.red_primary)
  (gx, gy) = get_primary_chromaticity(cms_profile.green_primary)
  (bx, by) = get_primary_chromaticity(cms_profile.blue_primary)

  if color_space == 'DISPLAY_P3':
    # Expected primaries based on Apple's Display P3 primaries
    expected_rx = EXPECTED_RX_P3
    expected_ry = EXPECTED_RY_P3
    expected_gx = EXPECTED_GX_P3
    expected_gy = EXPECTED_GY_P3
    expected_bx = EXPECTED_BX_P3
    expected_by = EXPECTED_BY_P3
  elif color_space == 'SRGB':
    # Expected primaries based on Pixel sRGB profile
    expected_rx = EXPECTED_RX_SRGB
    expected_ry = EXPECTED_RY_SRGB
    expected_gx = EXPECTED_GX_SRGB
    expected_gy = EXPECTED_GY_SRGB
    expected_bx = EXPECTED_BX_SRGB
    expected_by = EXPECTED_BY_SRGB
  else:
    # Unsupported color space for comparison
    return False

  cmp_values = [
      [rx, expected_rx],
      [ry, expected_ry],
      [gx, expected_gx],
      [gy, expected_gy],
      [bx, expected_bx],
      [by, expected_by]
  ]

  for (actual, expected) in cmp_values:
    if not math.isclose(actual, expected, abs_tol=0.001):
      # Values significantly differ
      return False

  return True


def area_of_triangle(x1, y1, x2, y2, x3, y3):
  """Calculates the area of a triangle formed by three points.

  Args:
    x1 (float): The x-coordinate of the first point.
    y1 (float): The y-coordinate of the first point.
    x2 (float): The x-coordinate of the second point.
    y2 (float): The y-coordinate of the second point.
    x3 (float): The x-coordinate of the third point.
    y3 (float): The y-coordinate of the third point.

  Returns:
    float: The area of the triangle.
  """
  area = abs((x1 * (y2 - y3) + x2 * (y3 - y1) + x3 * (y1 - y2)) / 2.0)
  return area


def point_in_triangle(x1, y1, x2, y2, x3, y3, xp, yp, abs_tol):
  """Checks if the point (xp, yp) is inside the triangle.

  Args:
    x1 (float): The x-coordinate of the first point.
    y1 (float): The y-coordinate of the first point.
    x2 (float): The x-coordinate of the second point.
    y2 (float): The y-coordinate of the second point.
    x3 (float): The x-coordinate of the third point.
    y3 (float): The y-coordinate of the third point.
    xp (float): The x-coordinate of the point to check.
    yp (float): The y-coordinate of the point to check.
    abs_tol (float): Absolute tolerance amount.

  Returns:
    bool: True if the point is inside the triangle, False otherwise.
  """
  a = area_of_triangle(x1, y1, x2, y2, x3, y3)
  a1 = area_of_triangle(xp, yp, x2, y2, x3, y3)
  a2 = area_of_triangle(x1, y1, xp, yp, x3, y3)
  a3 = area_of_triangle(x1, y1, x2, y2, xp, yp)
  return math.isclose(a, (a1 + a2 + a3), abs_tol=abs_tol)


def distance(p, q):
  """Returns the Euclidean distance from point p to point q.

  Args:
    p: an Iterable of numbers
    q: an Iterable of numbers
  """
  return math.sqrt(sum((px - qx) ** 2.0 for px, qx in zip(p, q)))


def p3_img_has_wide_gamut(wide_img):
  """Check if a DISPLAY_P3 image contains wide gamut pixels.

  Given a DISPLAY_P3 image that should have a wider gamut than SRGB, checks all
  pixel values to see if any reside outside the SRGB gamut. This is done by
  converting to CIE xy chromaticities using a Bradford chromatic adaptation for
  consistency with ICC profiles.

  Args:
    wide_img: The PIL.Image in the DISPLAY_P3 color space.

  Returns:
    True if the gamut of wide_img is greater than that of SRGB.
    False otherwise.
  """
  w = wide_img.size[0]
  h = wide_img.size[1]
  wide_arr = numpy.array(wide_img)

  img_arr = colour.RGB_to_XYZ(
      wide_arr / 255.0,
      colour.models.rgb.datasets.display_p3.RGB_COLOURSPACE_DISPLAY_P3.whitepoint,
      colour.models.rgb.datasets.display_p3.RGB_COLOURSPACE_DISPLAY_P3.whitepoint,
      colour.models.rgb.datasets.display_p3.RGB_COLOURSPACE_DISPLAY_P3.matrix_RGB_to_XYZ,
      'Bradford', lambda x: colour.eotf(x, 'sRGB'))

  xy_arr = colour.XYZ_to_xy(img_arr)

  srgb_colorspace = colour.models.RGB_COLOURSPACE_sRGB
  srgb_primaries = srgb_colorspace.primaries

  for y in range(h):
    for x in range(w):
      # Check if the pixel chromaticity is inside or outside the SRGB gamut.
      # This check is not guaranteed not to emit false positives / negatives,
      # however the probability of either on an arbitrary DISPLAY_P3 camera
      # capture is exceedingly unlikely.
      if not point_in_triangle(*srgb_primaries.reshape(6),
                               xy_arr[y][x][0], xy_arr[y][x][1],
                               COLORSPACE_TRIANGLE_AREA_TOL):
        return True

  return False


def convert_image_coords_to_sensor_coords(
    aa_width, aa_height, coords, img_width, img_height):
  """Transform image coordinates to sensor coordinate system.

  Calculate the difference between sensor active array and image aspect ratio.
  Taking the difference into account, figure out if the width or height has been
  cropped. Using this information, transform the image coordinates to sensor
  coordinates.

  Args:
    aa_width: int; active array width.
    aa_height: int; active array height.
    coords: coordinates; a pair of (x, y) coordinates from image.
    img_width: int; width of image.
    img_height: int; height of image.
  Returns:
    sensor_coords: coordinates; corresponding coordinates on
      sensor coordinate system.
  """
  # TODO: b/330382627 - find out if distortion correction is ON/OFF
  aa_aspect_ratio = aa_width / aa_height
  image_aspect_ratio = img_width / img_height
  if aa_aspect_ratio >= image_aspect_ratio:
    # If aa aspect ratio is greater than image aspect ratio, then
    # sensor width is being cropped
    aspect_ratio_multiplication_factor = aa_height / img_height
    crop_width = img_width * aspect_ratio_multiplication_factor
    buffer = (aa_width - crop_width) / 2
    sensor_coords = (coords[0] * aspect_ratio_multiplication_factor + buffer,
                     coords[1] * aspect_ratio_multiplication_factor)
  else:
    # If aa aspect ratio is less than image aspect ratio, then
    # sensor height is being cropped
    aspect_ratio_multiplication_factor = aa_width / img_width
    crop_height = img_height * aspect_ratio_multiplication_factor
    buffer = (aa_height - crop_height) / 2
    sensor_coords = (coords[0] * aspect_ratio_multiplication_factor,
                     coords[1] * aspect_ratio_multiplication_factor + buffer)
  logging.debug('Sensor coordinates: %s', sensor_coords)
  return sensor_coords
