// ArrayFire example: machine_learning/geneticalgorithm.cpp

#include <arrayfire.h>

#include <climits>
#include <cmath>
#include <cstdio>
#include <cstring>

using namespace af;

// Fraction of the population that selectFittest keeps as parents by default.
static constexpr float DefaultTopFittest = 0.5f;

/// Looks up the objective-function value of every sample.
///
/// @param searchSpace  n x n array holding the objective function.
/// @param sampleX      x coordinate of each sample (row index).
/// @param sampleY      y coordinate of each sample (column index).
/// @param n            side length of the square search space.
/// @return             one objective value per sample, in sample order.
array update(const array& searchSpace, const array& sampleX,
             const array& sampleY, const int n) {
    // ArrayFire arrays are column-major: element (x, y) of an n x n array
    // is at linear index y * n + x.
    return searchSpace(sampleY * n + sampleX);
}

/// Selects the indices of the fittest samples.
///
/// Samples are ranked by objective value (higher is fitter). The function
/// returns the indices of the best `topFit * nSamples` of them, ordered from
/// least to most fit. If sampleZ does not hold more samples than that, every
/// index is returned.
///
/// @param sampleZ   objective value of each sample.
/// @param nSamples  nominal population size used to size the selection.
/// @param topFit    fraction of the population to keep.
/// @return          indices into sampleZ of the selected samples.
array selectFittest(const array& sampleZ, const int nSamples,
                    const float topFit = DefaultTopFittest) {
    array indices, values;
    // Default ascending sort: the fittest samples end up at the tail.
    sort(values, indices, sampleZ);

    const int topFitElem = static_cast<int>(topFit * nSamples);
    const int n          = static_cast<int>(sampleZ.elements());

    return (n > topFitElem) ? indices(seq(n - topFitElem, n - 1)) : indices;
}

/// Replaces the population with the next generation.
///
/// The fittest samples are split into two equally sized parent groups.
/// Parent i of the first group is paired with parent i of the second. Each
/// pair produces two children by single-point bitwise crossover. Every child
/// is then copied and one random bit of the copy is flipped (mutation).
/// The new population is the children followed by the mutants, and sampleZ
/// is re-evaluated for it.
///
/// NOTE(review): coordinates are treated as `bits = floor(log2(n))`-bit
/// integers. Crossover and mutation only keep them inside [0, n) when n is a
/// power of two (main uses n = 32), and the selection must hold at least two
/// samples. Confirm before changing n or nSamples.
void reproduce(array& searchSpace, array& sampleX, array& sampleY,
               array& sampleZ, const int nSamples, const int n) {
    array selection = selectFittest(sampleZ, nSamples);

    array parentsX = sampleX(selection);
    array parentsY = sampleY(selection);
    const int bits = static_cast<int>(std::log2(n));

    // Divide the selection into two halves of equal size so the pairwise
    // operations below always see matching lengths.
    const int half  = static_cast<int>(parentsX.elements() / 2);
    array parentsX1 = parentsX.rows(0, half - 1);
    array parentsX2 = parentsX.rows(half, 2 * half - 1);
    array parentsY1 = parentsY.rows(0, half - 1);
    array parentsY2 = parentsY.rows(half, 2 * half - 1);

    // One crossover point per pair: bits below it come from one parent,
    // bits at or above it from the other.
    array crossover = randu(half, u32) % bits;
    array lowermask = (1 << crossover) - 1;
    array uppermask = INT_MAX - lowermask;

    // Create children from parents.
    array childrenX1 = (parentsX1 & uppermask) + (parentsX2 & lowermask);
    array childrenY1 = (parentsY1 & uppermask) + (parentsY2 & lowermask);
    array childrenX2 = (parentsX2 & uppermask) + (parentsX1 & lowermask);
    array childrenY2 = (parentsY2 & uppermask) + (parentsY1 & lowermask);

    // Join the children into the new generation.
    sampleX = join(0, childrenX1, childrenX2);
    sampleY = join(0, childrenY1, childrenY2);

    // Mutation: every child gets a copy with one random bit flipped.
    array mutantX = sampleX ^ (1 << (randu(sampleX.elements(), u32) % bits));
    array mutantY = sampleY ^ (1 << (randu(sampleY.elements(), u32) % bits));

    // Add the mutants to the population and re-evaluate fitness.
    sampleX = join(0, sampleX, mutantX);
    sampleY = join(0, sampleY, mutantY);

    sampleZ = update(searchSpace, sampleX, sampleY, n);
}

/// Seeds the population with nSamples random individuals.
///
/// Each coordinate is a random u32 value reduced modulo n, so it lies in
/// [0, n). The fitness of every individual is then read from searchSpace.
void initSamples(array& searchSpace, array& sampleX, array& sampleY,
                 array& sampleZ, const int nSamples, const int n) {
    // One random coordinate per individual, folded into the grid range.
    const auto randomCoordinate = [nSamples, n]() {
        return randu(nSamples, u32) % n;
    };

    sampleX = randomCoordinate();  // x is drawn before y
    sampleY = randomCoordinate();

    sampleZ = update(searchSpace, sampleX, sampleY, n);
}

void init(array& searchSpace, array& searchSpaceXDisplay,

array& sampleZ, const int nSamples, const int n) {

searchSpace = range(dim4(n / 2, n / 2), 0) + range(dim4(n / 2, n / 2), 1);

searchSpace = join(0, searchSpace, flip(searchSpace, 0));

searchSpace = join(1, searchSpace, flip(searchSpace, 1));

searchSpaceXDisplay = iota(dim4(n, 1), dim4(1, n));

searchSpaceYDisplay = iota(dim4(1, n), dim4(n, 1));

initSamples(searchSpace, sampleX, sampleY, sampleZ, nSamples, n);

}

void reproducePrint(float& currentMax, array& searchSpace, array& sampleX,

array& sampleY, array& sampleZ, const float trueMax,

const int nSamples, const int n) {

if (currentMax < trueMax * 0.99) {

float maximum = max(sampleZ);

array whereM = where(sampleZ == maximum);

if (maximum < trueMax * 0.99) {

printf("Current max at ");

} else {

printf("\nMax found at ");

}

printf("(%d,%d): %f (trueMax %f)\n",

sampleX(whereM).scalar(),

sampleY(whereM).scalar(), maximum, trueMax);

currentMax = maximum;

reproduce(searchSpace, sampleX, sampleY, sampleZ, nSamples, n);

}

}

void geneticSearch(bool console, const int nSamples, const int n) {

array searchSpaceXDisplay;

array searchSpaceYDisplay;

init(searchSpace, searchSpaceXDisplay, searchSpaceYDisplay, sampleX,

sampleY, sampleZ, nSamples, n);

float trueMax = max(searchSpace);

float maximum = -trueMax;

if (!console) {

af::Window win(1600, 800, "Arrayfire Genetic Algorithm Search Demo");

win.grid(1, 2);

do {

reproducePrint(maximum, searchSpace, sampleX, sampleY, sampleZ,

trueMax, nSamples, n);

win(0, 0).setAxesTitles("IdX", "IdY", "Search Space");

win(0, 1).setAxesTitles("IdX", "IdY", "Search Space");

win(0, 0).surface(searchSpaceXDisplay, searchSpaceYDisplay,

searchSpace);

win(0, 1).scatter(sampleX.as(f32), sampleY.as(f32), sampleZ.as(f32),

win.show();

} while (!win.close());

} else {

do {

reproducePrint(maximum, searchSpace, sampleX, sampleY, sampleZ,

trueMax, nSamples, n);

} while (maximum < trueMax * 0.99);

}

}

int main(int argc, char** argv) {

bool console = false;

const int n = 32;

const int nSamples = 16;

if (argc > 2 || (argc == 2 && strcmp(argv[1], "-"))) {

printf("usage: %s [-]\n", argv[0]);

return -1;

} else if (argc == 2 && argv[1][0] == '-') {

console = true;

}

try {

printf(

"** ArrayFire Genetic Algorithm Search Demo **\n\n"

"Search for trueMax in a search space where the objective "

"function is defined as :\n\n"

"SS(x ,y) = min(x, n - (x + 1)) + min(y, n - (y + 1))\n\n"

"(x, y) belongs to RxR; R = [0, n); n = %d\n\n",

n);

if (!console) {

printf(

"The left figure shows the objective function.\n"

"The right figure shows current generation's "

"parameters and function values.\n\n");

}

geneticSearch(console, nSamples, n);

return 0;

}

Window object to render af::arrays.

A multi-dimensional data container.

const array as(dtype type) const

Casts the array into another data type.

dim_t elements() const

Get the total number of elements across all dimensions of the array.

Generic object that represents size and shape.

An ArrayFire exception class.

virtual const char * what() const

Returns an error message for the exception in a string format.

seq is used to create sequences for indexing af::array

u32

32-bit unsigned integral values

f32

32-bit floating point values

array::array_proxy rows(int first, int last)

Returns a reference to sequence of rows.

AFAPI array iota(const dim4 &dims, const dim4 &tile_dims=dim4(1), const dtype ty=f32)

C++ Interface to generate an array with [0, n-1] values modified to specified dimensions and tiling.

AFAPI array range(const dim4 &dims, const int seq_dim=-1, const dtype ty=f32)

C++ Interface to generate an array with [0, n-1] values along the seq_dim dimension and tiled across the other dimensions of shape dims.

AFAPI array flip(const array &in, const unsigned dim)

C++ Interface to flip an array.

AFAPI array join(const int dim, const array &first, const array &second)

C++ Interface to join 2 arrays along a dimension.

AFAPI array randu(const dim4 &dims, const dtype ty, randomEngine &r)

C++ Interface to create an array of random numbers uniformly distributed.

AFAPI void setSeed(const unsigned long long seed)

C++ Interface to set the seed of the default random number generator.

AFAPI array where(const array &in)

C++ Interface to locate the indices of the non-zero values in an array.

AFAPI array log2(const array &in)

C++ Interface to evaluate the base 2 logarithm.