// Tuples

// Halide tutorial lesson 13: Tuples

// This lesson describes how to write Funcs that evaluate to multiple
// values.

// On Linux, you can compile and run it like so:
// g++ lesson_13*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.so> -lHalide -lpthread -ldl -o lesson_13 -std=c++17
// LD_LIBRARY_PATH=<path/to/libHalide.so> ./lesson_13

// On OS X:
// g++ lesson_13*.cpp -g -I <path/to/Halide.h> -L <path/to/libHalide.dylib> -lHalide -o lesson_13 -std=c++17
// DYLD_LIBRARY_PATH=<path/to/libHalide.dylib> ./lesson_13

// If you have the entire Halide source tree, you can also build it by
// running:
//    make tutorial_lesson_13_tuples
// in a shell with the current directory at the top of the halide
// source tree.

#include "Halide.h"

#include <algorithm>
#include <cassert>
#include <cmath>
#include <stdio.h>

using namespace Halide;

int main(int argc, char **argv) {

    // So far Funcs (such as the one below) have evaluated to a single
    // scalar value for each point in their domain.
    Func single_valued;
    Var x, y;
    single_valued(x, y) = x + y;

    // One way to write a Func that returns a collection of values is
    // to add an additional dimension that indexes that
    // collection. This is how we typically deal with color. For
    // example, the Func below represents a collection of three values
    // for every x, y coordinate indexed by c.
    Func color_image;
    Var c;
    color_image(x, y, c) = select(c == 0, 245,  // Red value
                                  c == 1, 42,   // Green value
                                  132);         // Blue value

    // Since this pattern appears quite often, Halide provides
    // syntactic sugar to write the code above as the following,
    // using the "mux" function.
    // color_image(x, y, c) = mux(c, {245, 42, 132});

    // This method is often convenient because it makes it easy to
    // operate on this Func in a way that treats each item in the
    // collection equally:
    Func brighter;
    brighter(x, y, c) = color_image(x, y, c) + 10;

    // However this method is also inconvenient for three reasons.
    //
    // 1) Funcs are defined over an infinite domain, so users of this
    // Func can for example access color_image(x, y, -17), which is
    // not a meaningful value and is probably indicative of a bug.
    //
    // 2) It requires a select, which can impact performance if not
    // bounded and unrolled:
    // brighter.bound(c, 0, 3).unroll(c);
    //
    // 3) With this method, all values in the collection must have the
    // same type. While the above two issues are merely inconvenient,
    // this one is a hard limitation that makes it impossible to
    // express certain things in this way.

    // It is also possible to represent a collection of values as a
    // collection of Funcs:
    Func func_array[3];
    func_array[0](x, y) = x + y;
    func_array[1](x, y) = sin(x);
    func_array[2](x, y) = cos(y);

    // This method avoids the three problems above, but introduces a
    // new annoyance. Because these are separate Funcs, it is
    // difficult to schedule them so that they are all computed
    // together inside a single loop over x, y.

    // A third alternative is to define a Func as evaluating to a
    // Tuple instead of an Expr. A Tuple is a fixed-size collection of
    // Exprs. Each Expr in a Tuple may have a different type. The
    // following function evaluates to an integer value (x+y), and a
    // floating point value (sin(x*y)).
    Func multi_valued;
    multi_valued(x, y) = Tuple(x + y, sin(x * y));

    // Realizing a tuple-valued Func returns a collection of
    // Buffers. We call this a Realization. It's equivalent to a
    // std::vector of Buffer objects:
    {
        Realization r = multi_valued.realize({80, 60});
        assert(r.size() == 2);
        Buffer<int> im0 = r[0];
        Buffer<float> im1 = r[1];
        assert(im0(30, 40) == 30 + 40);
        // Halide's sin and the host libm's sinf are not guaranteed to
        // agree bit-for-bit (particularly for a large argument such as
        // 1200), so compare with a small tolerance instead of exact ==.
        assert(std::fabs(im1(30, 40) - sinf(30 * 40)) < 1e-5f);
    }

    // All Tuple elements are evaluated together over the same domain
    // in the same loop nest, but stored in distinct allocations. The
    // equivalent C++ code to the above is shown below. The first
    // extent passed to realize() (80) is the width, i.e. the range of
    // x, and the second (60) is the height, i.e. the range of y. x is
    // the innermost, unit-stride dimension.
    {
        int multi_valued_0[80 * 60];
        float multi_valued_1[80 * 60];
        for (int y = 0; y < 60; y++) {
            for (int x = 0; x < 80; x++) {
                multi_valued_0[x + 80 * y] = x + y;
                multi_valued_1[x + 80 * y] = sinf(x * y);
            }
        }
    }

    // When compiling ahead-of-time, a Tuple-valued Func evaluates
    // into multiple distinct output halide_buffer_t structs. These appear in
    // order at the end of the function signature:
    // int multi_valued(...input buffers and params...,
    //                  halide_buffer_t *output_1, halide_buffer_t *output_2);

    // You can construct a Tuple by passing multiple Exprs to the
    // Tuple constructor as we did above. Perhaps more elegantly, you
    // can also take advantage of initializer lists and just
    // enclose your Exprs in braces:
    Func multi_valued_2;
    multi_valued_2(x, y) = {x + y, sin(x * y)};

    // Calls to a multi-valued Func cannot be treated as Exprs, so the
    // following is not allowed:
    // Func consumer;
    // consumer(x, y) = multi_valued_2(x, y) + 10;

    // Instead you must index a Tuple with square brackets to retrieve
    // the individual Exprs:
    Expr integer_part = multi_valued_2(x, y)[0];
    Expr floating_part = multi_valued_2(x, y)[1];
    Func consumer;
    consumer(x, y) = {integer_part + 10, floating_part + 10.0f};

    // Tuple reductions.
    {
        // Tuples are particularly useful in reductions, as they allow
        // the reduction to maintain complex state as it walks along
        // its domain. The simplest example is an argmax.

        // First we create a Buffer to take the argmax over.
        Func input_func;
        input_func(x) = sin(x);
        Buffer<float> input = input_func.realize({100});

        // Then we define a 2-valued Tuple which tracks the index of
        // the maximum value and the value itself.
        Func arg_max;

        // Pure definition.
        arg_max() = {0, input(0)};

        // Update definition.
        RDom r(1, 99);
        Expr old_index = arg_max()[0];
        Expr old_max = arg_max()[1];
        Expr new_index = select(old_max < input(r), r, old_index);
        Expr new_max = max(input(r), old_max);
        arg_max() = {new_index, new_max};

        // The equivalent C++ is:
        int arg_max_0 = 0;
        float arg_max_1 = input(0);
        for (int r = 1; r < 100; r++) {
            int old_index = arg_max_0;
            float old_max = arg_max_1;
            int new_index = old_max < input(r) ? r : old_index;
            float new_max = std::max(input(r), old_max);
            // In a tuple update definition, all loads and computation
            // are done before any stores, so that all Tuple elements
            // are updated atomically with respect to recursive calls
            // to the same Func.
            arg_max_0 = new_index;
            arg_max_1 = new_max;
        }

        // Let's verify that the Halide and C++ found the same maximum
        // value and index. Exact float equality is fine here: both
        // sides select one of the same stored input values rather
        // than computing a new one.
        {
            Realization r = arg_max.realize();
            Buffer<int> r0 = r[0];
            Buffer<float> r1 = r[1];
            assert(arg_max_0 == r0(0));
            assert(arg_max_1 == r1(0));
        }

        // Halide provides argmax and argmin as built-in reductions
        // similar to sum, product, maximum, and minimum. They return
        // a Tuple consisting of the point in the reduction domain
        // corresponding to that value, and the value itself. In the
        // case of ties they return the first value found. We'll use
        // one of these in the following section.
    }

    // Tuples for user-defined types.
    {
        // Tuples can also be a convenient way to represent compound
        // objects such as complex numbers. Defining an object that
        // can be converted to and from a Tuple is one way to extend
        // Halide's type system with user-defined types.
        struct Complex {
            Expr real, imag;

            // Construct from a Tuple
            Complex(Tuple t)
                : real(t[0]), imag(t[1]) {
            }

            // Construct from a pair of Exprs
            Complex(Expr r, Expr i)
                : real(r), imag(i) {
            }

            // Construct from a call to a Func by treating it as a Tuple
            Complex(FuncRef t)
                : Complex(Tuple(t)) {
            }

            // Convert to a Tuple
            operator Tuple() const {
                return {real, imag};
            }

            // Complex addition
            Complex operator+(const Complex &other) const {
                return {real + other.real, imag + other.imag};
            }

            // Complex multiplication
            Complex operator*(const Complex &other) const {
                return {real * other.real - imag * other.imag,
                        real * other.imag + imag * other.real};
            }

            // Complex magnitude, squared for efficiency
            Expr magnitude_squared() const {
                return real * real + imag * imag;
            }

            // Other complex operators would go here. The above are
            // sufficient for this example.
        };

        // Let's use the Complex struct to compute a Mandelbrot set.
        Func mandelbrot;

        // The initial complex value corresponding to an x, y coordinate
        // in our Func.
        Complex initial(x / 15.0f - 2.5f, y / 6.0f - 2.0f);

        // Pure definition.
        Var t;
        mandelbrot(x, y, t) = Complex(0.0f, 0.0f);

        // We'll use an update definition to take 12 steps.
        RDom r(1, 12);
        Complex current = mandelbrot(x, y, r - 1);

        // The following line uses the complex multiplication and
        // addition we defined above.
        mandelbrot(x, y, r) = current * current + initial;

        // We'll use another tuple reduction to compute the iteration
        // number where the value first escapes a circle of radius 4.
        // This can be expressed as an argmin of a boolean - we want
        // the index of the first time the given boolean expression is
        // false (we consider false to be less than true). The argmax
        // would return the index of the first time the expression is
        // true.

        Expr escape_condition = Complex(mandelbrot(x, y, r)).magnitude_squared() < 16.0f;
        Tuple first_escape = argmin(escape_condition);

        // We only want the index, not the value, but argmin returns
        // both, so we'll index the argmin Tuple expression using
        // square brackets to get the Expr representing the index.
        Func escape;
        escape(x, y) = first_escape[0];

        // Realize the pipeline and print the result as ascii art.
        // The index comes from r, so it lies in [1, 12]; code has 13
        // characters, so every index is a valid position in it.
        Buffer<int> result = escape.realize({61, 25});
        const char *code = " .:-~*={}&%#@";
        for (int y = 0; y < result.height(); y++) {
            for (int x = 0; x < result.width(); x++) {
                printf("%c", code[result(x, y)]);
            }
            printf("\n");
        }
    }

    printf("Success!\n");

    return 0;
}