diff --git a/colorchecker.h b/colorchecker.h index c40e2de..8188cf6 100644 --- a/colorchecker.h +++ b/colorchecker.h @@ -112,7 +112,7 @@ std::unique_ptr> HighlightClosest(const Image& image) { return out; } -template +template int32_t OptimizeLut(const Image& image, Lut3d* lut) { static_assert(C == 3); @@ -153,3 +153,37 @@ int32_t OptimizeLut(const Image& image, Lut3d +int32_t OptimizeLut(const Image& image, Lut1d* lut) { + static_assert(C == 3); + + auto snapshot = *lut; + int32_t diff = 0; + + for (int32_t x = 0; x < LUT_X; ++x) { + auto& color = lut->at(x); + + std::cout << Coord<1>{{{{x}}}} << std::endl; + + for (int32_t c = 0; c < C; ++c) { + auto& channel = color.at(c); + + auto min = FindPossibleMinimum( + -UINT16_MAX, UINT16_MAX * 2, + [&image, &snapshot, x, c](int32_t val) { + auto test_lut = snapshot; + test_lut.at(x).at(c) = val; + return ScoreLut(image, test_lut); + }); + // Magic value of 8 is the number of points making up a square, so the number + // of points that control any given given LUT mapping. + auto new_value = Interpolate(channel, min, INT32_C(1), INT32_C(8)); + std::cout << "\tC" << c << ": " << channel << " -> " << new_value << " (interpolated from " << min << ")" << std::endl; + diff += AbsDiff(channel, new_value); + channel = new_value; + } + } + + return diff; +} diff --git a/image.h b/image.h index 8614b4f..d084742 100644 --- a/image.h +++ b/image.h @@ -37,14 +37,14 @@ template void Image::DrawXLine(const Coord<2>& coord, const Color& color, int32_t length) { auto& row = this->at(coord.at(1)); - for (int32_t x = coord.at(0); x < std::min(X, coord.at(0) + length); ++x) { + for (int32_t x = coord.at(0); x < std::min(X - 1, coord.at(0) + length); ++x) { row.at(x) = color; } } template void Image::DrawYLine(const Coord<2>& coord, const Color& color, int32_t length) { - for (int32_t y = coord.at(1); y <= std::min(Y, coord.at(1) + length); ++y) { + for (int32_t y = coord.at(1); y <= std::min(Y - 1, coord.at(1) + length); ++y) { SetPixel({{{{coord.at(0), y}}}}, color); } } diff --git a/lut.h b/lut.h index f8d12aa..7fb49b8 100644 --- a/lut.h +++ b/lut.h @@ -15,6 +15,10 @@ class LutBase { template std::unique_ptr> MapImage(const Image& in) const; + + protected: + static constexpr std::pair FindChannelRoot(int32_t value, int32_t points); + static constexpr int32_t BlockSize(int32_t points); }; template @@ -31,6 +35,59 @@ std::unique_ptr> LutBase::MapImage(const Image& in) cons return out; } +constexpr int32_t LutBase::BlockSize(int32_t points) { + return (kMaxColor + 1) / (points - 1); +} + +constexpr std::pair LutBase::FindChannelRoot(int32_t value, int32_t points) { + // points - 1 is the last point index. Since we're going to find the region + // around this point by adding to the root, we need to be at least 1 less + // than that. + int32_t index = std::min(points - 2, value / BlockSize(points)); + return std::make_pair(index, value - (index * BlockSize(points))); +} + + +template +class Lut1d : public Array, X>, public LutBase { + public: + static Lut1d Identity(); + + Color<3> MapColor(const Color<3>& in) const override; +}; + +typedef Lut1d<2> MinimalLut1d; + +template +Lut1d Lut1d::Identity() { + Lut1d ret; + + Color<3> color; + for (int32_t x = 0; x < X; ++x) { + color.at(0) = color.at(1) = color.at(2) = std::min(kMaxColor, BlockSize(X) * x); + ret.at(x) = color; + } + + return ret; +} + +template +Color<3> Lut1d::MapColor(const Color<3>& in) const { + Color<3> ret; + + for (int32_t c = 0; c < 3; ++c) { + const auto root_rem = FindChannelRoot(in.at(c), X); + const auto& root = root_rem.first; + const auto& rem = root_rem.second; + ret.at(c) = Interpolate( + this->at(root + 0).at(c), + this->at(root + 1).at(c), + rem, BlockSize(X)); + } + + return ret; +} + template class Lut3d : public Array, X>, Y>, Z>, public LutBase { @@ -42,12 +99,8 @@ class Lut3d : public Array, X>, Y>, Z>, public LutBase { private: // Return value is (root_indices, remainders) constexpr static std::pair, Coord<3>> FindRoot(const Color<3>& in); - constexpr static std::pair FindChannelRoot(int32_t value, int32_t points); - - constexpr static int32_t BlockSize(int32_t points); }; -// Minimum size LUT typedef Lut3d<2, 2, 2> MinimalLut3d; template @@ -116,17 +169,3 @@ constexpr std::pair, Coord<3>> Lut3d::FindRoot(const Color<3>& {{{{root_x.second, root_y.second, root_z.second}}}}, }; } - -template -constexpr std::pair Lut3d::FindChannelRoot(const int32_t value, const int32_t points) { - // points - 1 is the last point index. Since we're going to fidn the cube - // around this point by adding to the root, we need to be at least 1 less - // than that. - int32_t index = std::min(points - 2, value / BlockSize(points)); - return std::make_pair(index, value - (index * BlockSize(points))); -} - -template -constexpr int32_t Lut3d::BlockSize(int32_t points) { - return (kMaxColor + 1) / (points - 1); -} diff --git a/piphoto.cc b/piphoto.cc index c50fa03..1b559f4 100644 --- a/piphoto.cc +++ b/piphoto.cc @@ -9,12 +9,12 @@ int main() { auto image = PiRaw2::FromJpeg(ReadFile("test.jpg")); WriteFile("start.png", HighlightClosest(*image)->ToPng()); - auto lut = MinimalLut3d::Identity(); + auto lut = MinimalLut1d::Identity(); std::cout << "Initial error: " << ScoreLut(*image, lut) << std::endl; int32_t diff = 1; while (diff) { - diff = OptimizeLut<4>(*image, &lut); + diff = OptimizeLut(*image, &lut); std::cout << "diff=" << diff << " error=" << ScoreLut(*image, lut) << std::endl; WriteFile("inter.png", HighlightClosest(*lut.MapImage(*image))->ToPng()); }