Bugfixes to make gradient descent finally incrementally sane

This commit is contained in:
Ian Gulliver
2017-08-10 21:28:27 -07:00
parent e30eb932c9
commit a8413180bd
2 changed files with 12 additions and 6 deletions

View File

@@ -116,6 +116,7 @@ template <uint32_t P, uint32_t LUT_X, uint32_t LUT_Y, uint32_t LUT_Z, uint32_t I
uint32_t OptimizeLut(const Image<IMG_X, IMG_Y, C>& image, Lut3d<LUT_X, LUT_Y, LUT_Z>* lut) { uint32_t OptimizeLut(const Image<IMG_X, IMG_Y, C>& image, Lut3d<LUT_X, LUT_Y, LUT_Z>* lut) {
static_assert(C == 3); static_assert(C == 3);
auto snapshot = *lut;
uint32_t diff = 0; uint32_t diff = 0;
for (uint32_t x = 0; x < LUT_X; ++x) { for (uint32_t x = 0; x < LUT_X; ++x) {
@@ -130,16 +131,21 @@ uint32_t OptimizeLut(const Image<IMG_X, IMG_Y, C>& image, Lut3d<LUT_X, LUT_Y, LU
std::cout << Coord<3>{{{x, y, z}}} << std::endl; std::cout << Coord<3>{{{x, y, z}}} << std::endl;
for (uint32_t c = 0; c < C; ++c) { for (uint32_t c = 0; c < C; ++c) {
auto& channel = color.at(c);
auto min = FindPossibleMinimum<uint32_t, uint32_t, 4>( auto min = FindPossibleMinimum<uint32_t, uint32_t, 4>(
0, UINT16_MAX, 0, UINT16_MAX,
[&image, &lut, x, y, z, c](uint32_t val) { [&image, &snapshot, x, y, z, c](uint32_t val) {
auto test_lut = *lut; auto test_lut = snapshot;
test_lut.at(x).at(y).at(z).at(c) = val; test_lut.at(x).at(y).at(z).at(c) = val;
return ScoreImage(*test_lut.MapImage(image)); return ScoreImage(*test_lut.MapImage(image));
}); });
std::cout << "\tC" << c << ": " << color.at(c) << " -> " << min << std::endl; // Magic value of 8 is the number of points making up a square, so the number
diff += AbsDiff(color.at(c), min); // of points that control any given given LUT mapping.
color.at(c) = min; auto new_value = Interpolate(channel, min, UINT32_C(1), UINT32_C(8));
std::cout << "\tC" << c << ": " << channel << " -> " << new_value << " (interpolated from " << min << ")" << std::endl;
diff += AbsDiff(channel, new_value);
channel = new_value;
} }
} }
} }

View File

@@ -10,6 +10,6 @@ constexpr T Interpolate(T val0, T val1, T mul, T div) {
if (val1 > val0) { if (val1 > val0) {
return val0 + ((mul * (val1 - val0)) / div); return val0 + ((mul * (val1 - val0)) / div);
} else { } else {
return val0 - (((div - mul) * (val0 - val1)) / div); return val0 - ((mul * (val0 - val1)) / div);
} }
} }