#
# Copyright (C) 2019 The Android Open Source Project
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Layout flag passed as the last GENERATE_PROPOSALS operand in both examples
# in this file. False selects NHWC (the default layout of the shapes below);
# AddNchw(...) presumably generates an extra variant with this flag set and
# the listed operands transposed to NCHW -- confirm against the test generator.
layout = BoolScalar("layout", False) # NHWC

# Example 1: one batch, a 2x2 feature map and 2 anchors.
# Assumed operand layout (per the NNAPI GENERATE_PROPOSALS spec -- confirm):
# scores are [batches, height, width, numAnchors] and bboxDeltas are
# [batches, height, width, numAnchors * 4]. Here numAnchors = 2 (anchors is
# {2, 4}), so the last bboxDeltas dimension is 2 * 4 = 8.
model = Model()
i1 = Input("scores", "TENSOR_FLOAT32", "{1, 2, 2, 2}") # scores
i2 = Input("bboxDeltas", "TENSOR_FLOAT32", "{1, 2, 2, 8}") # bounding box deltas
i3 = Input("anchors", "TENSOR_FLOAT32", "{2, 4}") # anchors
i4 = Input("imageInfo", "TENSOR_FLOAT32", "{1, 2}") # image info
# 8 candidate boxes go in (1 * 2 * 2 * 2 scores); 4 come out.
o1 = Output("scoresOut", "TENSOR_FLOAT32", "{4}") # scores out
o2 = Output("roiOut", "TENSOR_FLOAT32", "{4, 4}") # roi out
o3 = Output("batchSplit", "TENSOR_INT32", "{4}") # batch split out
# Scalar operands after imageInfo, assumed per the NNAPI spec (confirm):
# height stride 4.0, width stride 4.0, pre-NMS top-N -1 and post-NMS top-N -1
# (non-positive presumably means "keep all"), IoU threshold 0.30,
# min box size 1.0, then the shared layout flag.
model = model.Operation("GENERATE_PROPOSALS",
    i1, i2, i3, i4, 4.0, 4.0, -1, -1, 0.30, 1.0, layout).To(o1, o2, o3)

# Quantized type for each operand, presumably as (type, scale, zeroPoint).
# o3 (batchSplit) is not listed, so it presumably stays TENSOR_INT32.
# With scale 0.01 / zeroPoint -28, the int8 scores cover [-1.0, 1.55], which
# includes every input and output score below.
quant8_signed = DataTypeConverter().Identify({
    i1: ("TENSOR_QUANT8_ASYMM_SIGNED", 0.01, -28),
    i2: ("TENSOR_QUANT8_ASYMM_SIGNED", 0.05, 0),
    i3: ("TENSOR_QUANT16_SYMM", 0.125, 0),
    i4: ("TENSOR_QUANT16_ASYMM", 0.125, 0),
    o1: ("TENSOR_QUANT8_ASYMM_SIGNED", 0.01, -28),
    o2: ("TENSOR_QUANT16_ASYMM", 0.125, 0)
})

input0 = {
    i1: [   # scores
        0.8, 0.9, 0.85, 0.85,
        0.75, 0.8, 0.9, 0.95
    ],
    i2: [   # bounding box deltas
        0.5, 0.1, 0.1, 0.1, 0.5, 0.1, 0.5, 0.1,
        -0.25, 0.1, -0.1, -0.1, -0.25, 0.1, 0.2, 0.1,
        0.4, -0.1, -0.2, 0.2, 0.4, -0.1, -0.2, 0.2,
        -0.2, -0.2, 0.2, 0.2, -0.2, -0.2, 0.2, 0.2
    ],
    # Two anchors of 4 coordinates each; presumably (x1, y1, x2, y2).
    i3: [0, 1, 4, 3, 1, 0, 3, 4],    # anchors
    # Presumably (height, width) of the single image.
    i4: [32, 32],  # image info
}

output0 = {
    # Surviving boxes' scores, in descending order.
    o1: [0.95, 0.9, 0.85, 0.8],  # scores out
    # NOTE(review): o2 is quantized with scale 0.125, so these coordinates
    # (e.g. 4.3785973) are not exactly representable. The generator presumably
    # rounds them and compares with a tolerance -- confirm.
    o2: [   # roi out
        4.3785973,  2.7571943 , 6.8214025,  7.642805,
        1.3512788,  0.18965816, 4.648721 ,  4.610342,
        3.1903253,  1.2951627 , 6.8096747,  3.1048374,
        1.9812691,  3.1571944 , 3.6187308,  8.042806
    ],
    # Batch index of each output ROI; only batch 0 exists in this example.
    o3: [0, 0, 0, 0]
}

# Build the test case from input0/output0 (presumably attached to the most
# recently created Model -- confirm) and add the NCHW layout variant.
# includeDefault=False presumably means only the quant8_signed variation is
# emitted, not the float32 original.
Example((input0, output0)).AddNchw(i1, i2, layout).AddVariations(quant8_signed, includeDefault=False)

#######################################################

# Example 2: two batches, a 4x4 feature map and 4 anchors.
# Assumed operand layout (per the NNAPI GENERATE_PROPOSALS spec -- confirm):
# scores are [batches, height, width, numAnchors] and bboxDeltas are
# [batches, height, width, numAnchors * 4]. Here numAnchors = 4 (anchors is
# {4, 4}), so the last bboxDeltas dimension is 4 * 4 = 16, and each batch
# has 4 * 4 * 4 = 64 candidate boxes.
model = Model()
i1 = Input("scores", "TENSOR_FLOAT32", "{2, 4, 4, 4}") # scores
i2 = Input("bboxDeltas", "TENSOR_FLOAT32", "{2, 4, 4, 16}") # bounding box deltas
i3 = Input("anchors", "TENSOR_FLOAT32", "{4, 4}") # anchors
i4 = Input("imageInfo", "TENSOR_FLOAT32", "{2, 2}") # image info
# 30 ROIs in total across both batches (16 from batch 0, 14 from batch 1;
# see o3 below).
o1 = Output("scoresOut", "TENSOR_FLOAT32", "{30}") # scores out
o2 = Output("roiOut", "TENSOR_FLOAT32", "{30, 4}") # roi out
o3 = Output("batchSplit", "TENSOR_INT32", "{30}") # batch split out
# Scalar operands after imageInfo, assumed per the NNAPI spec (confirm):
# height stride 10.0, width stride 10.0, pre-NMS top-N 32, post-NMS top-N 16,
# IoU threshold 0.20, min box size 1.0, then the shared layout flag.
model = model.Operation("GENERATE_PROPOSALS",
    i1, i2, i3, i4, 10.0, 10.0, 32, 16, 0.20, 1.0, layout).To(o1, o2, o3)

# Quantized type for each operand, presumably as (type, scale, zeroPoint).
# o3 (batchSplit) is not listed, so it presumably stays TENSOR_INT32.
# With scale 0.005 / zeroPoint -128, the int8 scores cover [0.0, 1.275],
# which includes every score below (0.0 through 0.945).
quant8_signed = DataTypeConverter().Identify({
    i1: ("TENSOR_QUANT8_ASYMM_SIGNED", 0.005, -128),
    i2: ("TENSOR_QUANT8_ASYMM_SIGNED", 0.1, 0),
    i3: ("TENSOR_QUANT16_SYMM", 0.125, 0),
    i4: ("TENSOR_QUANT16_ASYMM", 0.125, 0),
    o1: ("TENSOR_QUANT8_ASYMM_SIGNED", 0.005, -128),
    o2: ("TENSOR_QUANT16_ASYMM", 0.125, 0)
})

input0 = {
    # 128 scores = 2 * 4 * 4 * 4, in NHWC order; the first 64 belong to batch 0.
    i1: [   # scores
        0.885, 0.21 , 0.78 , 0.57 ,
        0.795, 0.66 , 0.915, 0.615,
        0.27 , 0.69 , 0.645, 0.945,
        0.465, 0.345, 0.855, 0.555,
        0.48 , 0.6  , 0.735, 0.63 ,
        0.495, 0.03 , 0.12 , 0.225,
        0.24 , 0.285, 0.51 , 0.315,
        0.435, 0.255, 0.585, 0.06 ,
        0.9  , 0.75 , 0.18 , 0.45 ,
        0.36 , 0.09 , 0.405, 0.15 ,
        0.   , 0.195, 0.075, 0.81 ,
        0.87 , 0.93 , 0.39 , 0.165,
        0.825, 0.525, 0.765, 0.105,
        0.54 , 0.705, 0.675, 0.3  ,
        0.42 , 0.045, 0.33 , 0.015,
        0.84 , 0.135, 0.72 , 0.375,
        0.495, 0.315, 0.195, 0.24 ,
        0.21 , 0.54 , 0.78 , 0.72 ,
        0.045, 0.93 , 0.27 , 0.735,
        0.135, 0.09 , 0.81 , 0.705,
        0.39 , 0.885, 0.42 , 0.945,
        0.9  , 0.225, 0.75 , 0.3  ,
        0.375, 0.63 , 0.825, 0.675,
        0.015, 0.48 , 0.645, 0.615,
        0.33 , 0.465, 0.66 , 0.6  ,
        0.075, 0.84 , 0.285, 0.57 ,
        0.585, 0.165, 0.06 , 0.36 ,
        0.795, 0.855, 0.105, 0.45 ,
        0.   , 0.87 , 0.525, 0.255,
        0.69 , 0.555, 0.15 , 0.345,
        0.03 , 0.915, 0.405, 0.435,
        0.765, 0.12 , 0.51 , 0.18
    ],
    # 512 deltas = 2 * 4 * 4 * 16; one row of 16 per spatial position.
    # All values lie in [-2.8, 2.4], inside the int8 range for scale 0.1.
    i2: [   # bounding box deltas
    -1.9,  0.4,  1.4,  0.5, -1.5, -0.2,  0.3,  1.2,  0. , -0.6,  0.4, -1.3,  0.8,  0.9, -0.2,  0.8,
    -0.2,  0. ,  0.4,  0.1, -0.2, -1.6, -0.6, -0.1, -1. ,  0.6,  0.5, -0.2, -1.7, -1.4,  0.5, -0.1,
    -1.5,  1.3, -0.7, -0.9,  0.9,  0.2, -0.2,  0. , -0.7,  0.3, -0.4, -0.3, -0.5, -0.3,  1. , -0.7,
     1.2, -0.3,  0. ,  0.3, -0.7,  1. , -0.2, -0.6, -1.3,  0. ,  0.3,  0.1,  0.4,  0.2,  2.4,  0. ,
     0.1,  0. ,  0.7, -0.9,  0.1, -0.4,  0.3, -0.3, -0.7,  0.1,  0.7,  0. , -0.3,  1.6,  0. ,  1.1,
     0.4, -0.7, -0.9,  0. ,  0. ,  0.4, -0.6,  0.4, -1.9, -1.2,  0. , -0.3,  0.2,  0. ,  0.1,  0.8,
     0. ,  0.9, -1.7,  0.3,  0.7, -0.7,  0.7,  1.2, -0.4, -0.1, -0.6,  0.6, -0.4, -0.2,  0.3, -0.5,
     0. ,  1. , -0.1, -0.3, -0.8,  0.1, -1.2, -2.4,  0.1,  1.4,  0.4,  0.1, -1.1,  0.4, -0.4, -0.2,
     0.1,  0. ,  0.7,  0.1, -1.3,  0.1, -0.4, -0.2,  0.2,  0.1, -0.8,  0. , -1.4,  2. , -0.6, -0.5,
     0. ,  1. , -1.4, -1.1,  0.6, -0.7,  0.4,  1.1, -1.1,  1.6, -0.3,  0. , -0.7,  0.3, -1.3,  0. ,
     0. ,  0. , -0.3,  0. , -1.1, -1.5,  0.9, -1.4, -0.7,  0.1, -1.4,  0.9,  0.1,  0.2, -0.1, -1.7,
     0.2, -0.3, -0.9,  1.1,  0.1,  1. ,  1. , -0.9,  0.7,  0. , -0.3,  0.2, -0.8, -0.5,  0.6, -1.2,
     1. ,  0.6,  0. , -1.6,  0.1, -1.2,  0.7,  0.8,  0.5, -0.2, -0.8, -1.3, -0.3,  0. ,  0. ,  0.3,
    -0.6, -0.3,  1.3,  0.1,  2.2,  1.2, -1.1,  0.1,  1.2,  1.2,  1.3, -0.9,  0.1, -0.5,  0.1, -0.7,
    -1.3,  1.3,  0.1,  2. ,  0. ,  0.2,  0.6,  0. , -0.1, -0.4, -0.5,  0.1, -0.6, -0.3,  0.2, -0.4,
    -0.4, -0.7, -1.8,  0.4, -0.7,  0.4,  1.4, -0.3,  0.8,  0. ,  0.4, -0.1, -1. ,  0.2,  0.5, -0.6,
    -1.1,  0.2,  1.6, -0.2, -0.4, -0.9,  0. ,  0.3,  0. ,  0.3, -0.3,  0.3,  0.3,  1.9,  0.3, -0.5,
    -0.8, -1.3, -0.8,  0.2,  0.2, -0.4, -0.3,  0.6,  0.2, -0.2,  1.2,  0. ,  0. , -0.3,  0.3, -1.5,
    -1. , -0.3, -0.7, -0.3, -0.4, -1. , -0.6, -0.7, -0.2,  0.6, -0.3,  0.5, -0.2,  0.3, -0.5, -1.7,
     0. , -0.7, -0.1, -1.5, -0.9,  0.6,  0.3, -0.1,  0.2,  0.5,  0.6, -0.8, -0.3,  0.6,  0.9, -0.3,
     0.1, -1.7, -1.5,  0. , -0.1, -0.3,  0.7, -0.3, -0.4,  0. , -0.4, -0.3,  0.1,  1.1,  1.8, -0.9,
     0.6,  0.5,  0.2, -0.7,  0.2,  0.1,  1.2,  2.2,  0.3,  0.6,  0.4,  0.1,  0.2,  0. , -1.1, -0.2,
    -0.7,  0. , -1.2,  0.6, -0.6, -0.2, -0.4,  0. ,  0.7, -1.2,  0.8,  0. , -0.3,  0.2,  0.6, -1. ,
    -0.1, -0.1,  0. , -0.4, -0.2,  0.4, -1.4,  0.3,  0.1,  1.3, -0.2, -0.7,  0.6,  0.7,  0.6,  0.1,
    -0.4,  0.1, -0.2, -0.8,  0. , -1.3,  1.2,  1.4,  1.1,  0.5,  0.3,  0. ,  0.1, -0.4,  0.5, -0.1,
    -0.5,  0.3, -0.7,  0.9, -0.1, -0.4,  0.2, -0.8,  1. ,  1. ,  0.1,  0.1, -0.2,  0. , -0.4, -0.3,
    -0.8,  0.7, -0.9, -0.3, -0.3, -2.8,  1. ,  1.4,  0. , -2.6,  1.1, -1.1,  0.5,  0.1, -0.4, -1.5,
     0. ,  0.3, -0.3, -0.2,  0.7, -0.8, -0.1,  0.5,  0.7,  1.4, -1.2, -1. , -0.6,  0.2,  1.1, -0.9,
     0.7, -0.4,  0. ,  0. , -0.2, -0.2,  0.1,  0. ,  0. , -0.7, -0.7, -1.4, -0.9, -0.5, -0.6,  0.4,
     0.3,  0. ,  0.9, -0.2,  0.7,  1.2,  0.5,  0.8, -0.5,  1. ,  0.2, -0.5,  1.3, -0.5,  0.3,  1.2,
    -0.3, -0.1,  1.3,  0.2,  0.6, -1.4, -0.1, -0.2, -0.4, -0.9,  1.2, -0.9, -0.2, -1.2, -1. , -0.2,
    -1.6,  2.1, -0.6, -0.2, -0.3,  0.5,  0.9, -0.4,  0. , -0.1,  0.1, -0.6, -1. , -0.7,  0.2, -0.2
    ],
    # Four anchors, one per row; presumably (x1, y1, x2, y2).
    i3: [    # anchors
        0, 6, 16, 10,
        6, 0, 10, 16,
        3, 5, 13, 11,
        5, 3, 11, 13
    ],
    # Per-batch image size, presumably (height, width): batch 0 is 64x64,
    # batch 1 is 32x32.
    i4: [64, 64, 32, 32],  # image info
}

output0 = {
    # Scores are grouped by batch (16 for batch 0, then 14 for batch 1) and
    # are descending within each group.
    o1: [  # scores out
        0.945, 0.93 , 0.915, 0.9  , 0.87 , 0.84 , 0.81, 0.795, 0.78, 0.765, 0.75, 0.735,
        0.72 , 0.705, 0.69 , 0.675, 0.945, 0.915, 0.9 , 0.885, 0.87, 0.84 , 0.81, 0.78,
        0.735, 0.72 , 0.63 , 0.6  , 0.585, 0.54
    ],
    # One ROI per row, in the same order as o1. Batch 0 rows stay within
    # [0, 64] and batch 1 rows within [0, 32], which is consistent with
    # clipping to each batch's imageInfo (presumably -- confirm).
    # NOTE(review): o2 is quantized with scale 0.125, so most of these values
    # are not exactly representable. The generator presumably rounds them and
    # compares with a tolerance -- confirm.
    o2: [   # roi out
        16.845154 ,  2.5170734, 33.154846 ,  7.4829264,
        32.96344  , 40.747444 , 43.836563 , 47.252556 ,
         0.       ,  9.143808 , 16.243607 , 14.056192 ,
         0.       , 25.789658 , 25.710022 , 30.210342 ,
        37.947445 , 20.791668 , 44.452557 , 32.80833  ,
        30.277609 , 32.21635  , 32.92239  , 38.18365  ,
        25.885489 , 29.086582 , 31.314512 , 30.913418 ,
         2.8654022,  5.789658 , 26.734598 , 10.210342 ,
         0.5408764,  3.5824041, 15.459124 ,  5.217595 ,
        10.753355 , 35.982403 , 15.246645 , 37.617596 ,
         1.4593601, 23.050154 ,  4.1406403, 36.149845 ,
         0.       , 15.6      , 11.068764 , 21.6      ,
        38.54088  , 35.28549  , 53.45912  , 40.71451  ,
        26.134256 , 48.358635 , 27.465742 , 64.       ,
        29.96254  ,  3.1999998, 33.23746  , 19.2      ,
        11.653517 , 43.980293 , 48.34648  , 46.41971  ,
         0.       , 26.967152 , 26.748941 , 31.032848 ,
        28.590324 ,  9.050154 , 32.       , 22.149847 ,
        17.828777 , 19.00683  , 32.       , 20.99317  ,
         3.5724945,  7.273454 , 11.627505 , 19.126545 ,
         4.989658 , 26.8      ,  9.410341 , 32.       ,
        15.157195 , 18.00537  , 20.042807 , 25.194632 ,
        30.889404 ,  9.652013 , 32.       , 12.347987 ,
         3.399414 ,  3.8000002, 32.       ,  9.8      ,
        24.980408 , 10.086582 , 28.61959  , 11.913418 ,
        13.950423 ,  3.884349 , 22.049576 ,  6.115651 ,
        24.259361 ,  6.8      , 26.94064  , 22.8      ,
         3.6538367, 19.475813 , 13.546164 , 28.524187 ,
        11.947443 , 29.318363 , 18.452557 , 32.       ,
        17.318363 ,  0.       , 20.281635 , 16.17695
    ],
    # Batch index of each output ROI. Batch 0 keeps 16 ROIs, which matches the
    # post-NMS top-N operand if the scalar meanings above are correct. Batch 1
    # keeps only 14.
    o3: [
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
    ]
}

# Build the test case from input0/output0 (presumably attached to the most
# recently created Model -- confirm) and add the NCHW layout variant.
# includeDefault=False presumably means only the quant8_signed variation is
# emitted, not the float32 original.
Example((input0, output0)).AddNchw(i1, i2, layout).AddVariations(quant8_signed, includeDefault=False)
