GeographicLib: NearestNeighbor.hpp Source File (original) (raw)
97
98 static const int version = 1;
99
100
101 static const int maxbucket =
102 (2 + ((4 * sizeof(dist_t)) / sizeof(int) >= 2 ?
103 (4 * sizeof(dist_t)) / sizeof(int) : 2));
104 public:
105
106
107
108
109
110
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142 NearestNeighbor(const std::vector<pos_t>& pts, const distfun_t& dist,
143 int bucket = 4) {
145 }
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163 void Initialize(const std::vector<pos_t>& pts, const distfun_t& dist,
164 int bucket = 4) {
165 static_assert(std::numeric_limits<dist_t>::is_signed,
166 "dist_t must be a signed type");
167 if (!( 0 <= bucket && bucket <= maxbucket ))
169 ("bucket must lie in [0, 2 + 4*sizeof(dist_t)/sizeof(int)]");
170 if (pts.size() > size_t(std::numeric_limits::max()))
172
173 std::vector ids(pts.size());
174 for (int k = int(ids.size()); k--;)
175 ids[k] = std::make_pair(dist_t(0), k);
176 int cost = 0;
177 std::vector tree;
178 init(pts, dist, bucket, tree, ids, cost,
179 0, int(ids.size()), int(ids.size()/2));
180 _tree.swap(tree);
181 _numpoints = int(pts.size());
182 _bucket = bucket;
183 _mc = _sc = 0;
184 _cost = cost; _c1 = _k = _cmax = 0;
185 _cmin = std::numeric_limits::max();
186 }
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251 dist_t Search(const std::vector<pos_t>& pts, const distfun_t& dist,
252 const pos_t& query,
253 std::vector& ind,
254 int k = 1,
255 dist_t maxdist = std::numeric_limits<dist_t>::max(),
256 dist_t mindist = -1,
257 bool exhaustive = true,
258 dist_t tol = 0) const {
259 if (_numpoints != int(pts.size()))
261 std::priority_queue results;
262 if (_numpoints > 0 && k > 0 && maxdist > mindist) {
263
264 dist_t tau = maxdist;
265
266
267
268 std::priority_queue todo;
269 todo.push(std::make_pair(dist_t(1), int(_tree.size()) - 1));
270 int c = 0;
271 while (!todo.empty()) {
272 int n = todo.top().second;
273 dist_t d = -todo.top().first;
274 todo.pop();
275 dist_t tau1 = tau - tol;
276
277 if (!( n >= 0 && tau1 >= d )) continue;
278 const Node& current = _tree[n];
279 dist_t dst = 0;
280 bool exitflag = false, leaf = current.index < 0;
281 for (int i = 0; i < (leaf ? _bucket : 1); ++i) {
282 int index = leaf ? current.leaves[i] : current.index;
283 if (index < 0) break;
284 dst = dist(pts[index], query);
285 ++c;
286
287 if (dst > mindist && dst <= tau) {
288 if (int(results.size()) == k) results.pop();
289 results.push(std::make_pair(dst, index));
290 if (int(results.size()) == k) {
291 if (exhaustive)
292 tau = results.top().first;
293 else {
294 exitflag = true;
295 break;
296 }
297 if (tau <= tol) {
298 exitflag = true;
299 break;
300 }
301 }
302 }
303 }
304 if (exitflag) break;
305
306 if (current.index < 0) continue;
307 tau1 = tau - tol;
308 for (int l = 0; l < 2; ++l) {
309 if (current.data.child[l] >= 0 &&
310 dst + current.data.upper[l] >= mindist) {
311 if (dst < current.data.lower[l]) {
312 d = current.data.lower[l] - dst;
313 if (tau1 >= d)
314 todo.push(std::make_pair(-d, current.data.child[l]));
315 } else if (dst > current.data.upper[l]) {
316 d = dst - current.data.upper[l];
317 if (tau1 >= d)
318 todo.push(std::make_pair(-d, current.data.child[l]));
319 } else
320 todo.push(std::make_pair(dist_t(1), current.data.child[l]));
321 }
322 }
323 }
324 ++_k;
325 _c1 += c;
326 double omc = _mc;
327 _mc += (c - omc) / _k;
328 _sc += (c - omc) * (c - _mc);
329 if (c > _cmax) _cmax = c;
330 if (c < _cmin) _cmin = c;
331 }
332
333 dist_t d = -1;
334 ind.resize(results.size());
335
336 for (int i = int(ind.size()); i--;) {
337 ind[i] = int(results.top().second);
338 if (i == 0) d = results.top().first;
339 results.pop();
340 }
341 return d;
342
343 }
344
345
346
347
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367 void Save(std::ostream& os, bool bin = true) const {
368 int realspec = std::numeric_limits<dist_t>::digits *
369 (std::numeric_limits<dist_t>::is_integer ? -1 : 1);
370 if (bin) {
371 char id[] = "NearestNeighbor_";
372 os.write(id, 16);
373 int buf[6];
374 buf[0] = version;
375 buf[1] = realspec;
376 buf[2] = _bucket;
377 buf[3] = _numpoints;
378 buf[4] = int(_tree.size());
379 buf[5] = _cost;
380 os.write(reinterpret_cast<const char *>(buf), 6 * sizeof(int));
381 for (int i = 0; i < int(_tree.size()); ++i) {
382 const Node& node = _tree[i];
383 os.write(reinterpret_cast<const char *>(&node.index), sizeof(int));
384 if (node.index >= 0) {
385 os.write(reinterpret_cast<const char *>(node.data.lower),
386 2 * sizeof(dist_t));
387 os.write(reinterpret_cast<const char *>(node.data.upper),
388 2 * sizeof(dist_t));
389 os.write(reinterpret_cast<const char *>(node.data.child),
390 2 * sizeof(int));
391 } else {
392 os.write(reinterpret_cast<const char *>(node.leaves),
393 _bucket * sizeof(int));
394 }
395 }
396 } else {
397 std::stringstream ostring;
398
399
400 if (!std::numeric_limits<dist_t>::is_integer) {
401 static const int prec
402 = int(std::ceil(std::numeric_limits<dist_t>::digits *
403 std::log10(2.0) + 1));
404 ostring.precision(prec);
405 }
406 ostring << version << " " << realspec << " " << _bucket << " "
407 << _numpoints << " " << _tree.size() << " " << _cost;
408 for (int i = 0; i < int(_tree.size()); ++i) {
409 const Node& node = _tree[i];
410 ostring << "\n" << node.index;
411 if (node.index >= 0) {
412 for (int l = 0; l < 2; ++l)
413 ostring << " " << node.data.lower[l] << " " << node.data.upper[l]
414 << " " << node.data.child[l];
415 } else {
416 for (int l = 0; l < _bucket; ++l)
417 ostring << " " << node.leaves[l];
418 }
419 }
420 os << ostring.str();
421 }
422 }
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445 void Load(std::istream& is, bool bin = true) {
446 int version1, realspec, bucket, numpoints, treesize, cost;
447 if (bin) {
448 char id[17];
449 is.read(id, 16);
450 id[16] = '\0';
451 if (!(std::strcmp(id, "NearestNeighbor_") == 0))
453 is.read(reinterpret_cast<char *>(&version1), sizeof(int));
454 is.read(reinterpret_cast<char *>(&realspec), sizeof(int));
455 is.read(reinterpret_cast<char *>(&bucket), sizeof(int));
456 is.read(reinterpret_cast<char *>(&numpoints), sizeof(int));
457 is.read(reinterpret_cast<char *>(&treesize), sizeof(int));
458 is.read(reinterpret_cast<char *>(&cost), sizeof(int));
459 } else {
460 if (!( is >> version1 >> realspec >> bucket >> numpoints >> treesize
461 >> cost ))
463 }
464 if (!( version1 == version ))
466 if (!( realspec == std::numeric_limits<dist_t>::digits *
467 (std::numeric_limits<dist_t>::is_integer ? -1 : 1) ))
469 if (!( 0 <= bucket && bucket <= maxbucket ))
471 if (!( 0 <= treesize && treesize <= numpoints ))
472 throw
474 if (!( 0 <= cost ))
476 std::vector tree;
477 tree.reserve(treesize);
478 for (int i = 0; i < treesize; ++i) {
479 Node node;
480 if (bin) {
481 is.read(reinterpret_cast<char *>(&node.index), sizeof(int));
482 if (node.index >= 0) {
483 is.read(reinterpret_cast<char *>(node.data.lower),
484 2 * sizeof(dist_t));
485 is.read(reinterpret_cast<char *>(node.data.upper),
486 2 * sizeof(dist_t));
487 is.read(reinterpret_cast<char *>(node.data.child),
488 2 * sizeof(int));
489 } else {
490 is.read(reinterpret_cast<char *>(node.leaves),
491 bucket * sizeof(int));
492 for (int l = bucket; l < maxbucket; ++l)
493 node.leaves[l] = 0;
494 }
495 } else {
496 if (!( is >> node.index ))
498 if (node.index >= 0) {
499 for (int l = 0; l < 2; ++l) {
500 if (!( is >> node.data.lower[l] >> node.data.upper[l]
501 >> node.data.child[l] ))
503 }
504 } else {
505
506
507 for (int l = 0; l < bucket; ++l) {
508 if (!( is >> node.leaves[l] ))
510 }
511 for (int l = bucket; l < maxbucket; ++l)
512 node.leaves[l] = 0;
513 }
514 }
515 node.Check(numpoints, treesize, bucket);
516 tree.push_back(node);
517 }
518 _tree.swap(tree);
519 _numpoints = numpoints;
520 _bucket = bucket;
521 _mc = _sc = 0;
522 _cost = cost; _c1 = _k = _cmax = 0;
523 _cmin = std::numeric_limits::max();
524 }
525
526
527
528
529
530
531
532
533
535 { t.Save(os, false); return os; }
536
537
538
539
540
541
542
543
544
546 { t.Load(is, false); return is; }
547
548
549
550
551
552
554 std::swap(_numpoints, t._numpoints);
557 _tree.swap(t._tree);
564 }
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580 void Statistics(int& setupcost, int& numsearches, int& searchcost,
581 int& mincost, int& maxcost,
582 double& mean, double& sd) const {
583 setupcost = _cost; numsearches = _k; searchcost = _c1;
584 mincost = _cmin; maxcost = _cmax;
585 mean = _mc; sd = std::sqrt(_sc / (_k - 1));
586 }
587
588
589
590
591
593 _mc = _sc = 0;
594 _c1 = _k = _cmax = 0;
595 _cmin = std::numeric_limits::max();
596 }
597
598 private:
599
600
601 typedef std::pair<dist_t, int> item;
602
603 class Node {
604 public:
605 struct bounds {
606 dist_t lower[2], upper[2];
607 int child[2];
608 };
609 union {
610 bounds data;
611 int leaves[maxbucket];
612 };
613 int index;
614
615 Node()
616 : index(-1)
617 {
618 for (int i = 0; i < 2; ++i) {
619 data.lower[i] = data.upper[i] = 0;
620 data.child[i] = -1;
621 }
622 }
623
624
625 void Check(int numpoints, int treesize, int bucket) const {
626 if (!( -1 <= index && index < numpoints ))
628 if (index >= 0) {
629 if (!( -1 <= data.child[0] && data.child[0] < treesize &&
630 -1 <= data.child[1] && data.child[1] < treesize ))
632 if (!( 0 <= data.lower[0] && data.lower[0] <= data.upper[0] &&
633 data.upper[0] <= data.lower[1] &&
634 data.lower[1] <= data.upper[1] ))
636 } else {
637
638
639 bool start = true;
640 for (int l = 0; l < bucket; ++l) {
641 if (!( (start ?
642 ((l == 0 ? 0 : -1) <= leaves[l] && leaves[l] < numpoints) :
643 leaves[l] == -1) ))
644 throw GeographicLib::GeographicErr("Bad leaf data");
645 start = leaves[l] >= 0;
646 }
647 for (int l = bucket; l < maxbucket; ++l) {
648 if (leaves[l] != 0)
650 }
651 }
652 }
653
654#if defined(GEOGRAPHICLIB_HAVE_BOOST_SERIALIZATION) && \
655 GEOGRAPHICLIB_HAVE_BOOST_SERIALIZATION
656 friend class boost::serialization::access;
657 template
658 void save(Archive& ar, const unsigned int) const {
659 ar & boost::serialization::make_nvp("index", index);
660 if (index < 0)
661 ar & boost::serialization::make_nvp("leaves", leaves);
662 else
663 ar & boost::serialization::make_nvp("lower", data.lower)
664 & boost::serialization::make_nvp("upper", data.upper)
665 & boost::serialization::make_nvp("child", data.child);
666 }
667 template
668 void load(Archive& ar, const unsigned int) {
669 ar & boost::serialization::make_nvp("index", index);
670 if (index < 0)
671 ar & boost::serialization::make_nvp("leaves", leaves);
672 else
673 ar & boost::serialization::make_nvp("lower", data.lower)
674 & boost::serialization::make_nvp("upper", data.upper)
675 & boost::serialization::make_nvp("child", data.child);
676 }
677 template
678 void serialize(Archive& ar, const unsigned int file_version)
679 { boost::serialization::split_member(ar, *this, file_version); }
680#endif
681 };
682
#if defined(GEOGRAPHICLIB_HAVE_BOOST_SERIALIZATION) && \
  GEOGRAPHICLIB_HAVE_BOOST_SERIALIZATION
friend class boost::serialization::access;
// Archive the header fields followed by the tree.
template<class Archive> void save(Archive& ar, const unsigned) const {
  int realspec = std::numeric_limits<dist_t>::digits *
    (std::numeric_limits<dist_t>::is_integer ? -1 : 1);
  // Archive a local copy rather than the static constant itself
  // (presumably so that no out-of-class definition of version is needed).
  int version1 = version;
  ar & boost::serialization::make_nvp("version", version1)
    & boost::serialization::make_nvp("realspec", realspec)
    & boost::serialization::make_nvp("bucket", _bucket)
    & boost::serialization::make_nvp("numpoints", _numpoints)
    & boost::serialization::make_nvp("cost", _cost)
    & boost::serialization::make_nvp("tree", _tree);
}
// Read and validate an archive written by save.  Throws
// GeographicLib::GeographicErr on any mismatch; the object is modified only
// after every check has passed.
template<class Archive> void load(Archive& ar, const unsigned) {
  int version1, realspec, bucket, numpoints, cost;
  ar & boost::serialization::make_nvp("version", version1);
  if (version1 != version)
    throw GeographicLib::GeographicErr("Incompatible version");
  std::vector<Node> tree;
  ar & boost::serialization::make_nvp("realspec", realspec);
  if (!( realspec == std::numeric_limits<dist_t>::digits *
         (std::numeric_limits<dist_t>::is_integer ? -1 : 1) ))
    throw GeographicLib::GeographicErr("Different dist_t types");
  ar & boost::serialization::make_nvp("bucket", bucket);
  if (!( 0 <= bucket && bucket <= maxbucket ))
    throw GeographicLib::GeographicErr("Bad bucket size");
  ar & boost::serialization::make_nvp("numpoints", numpoints)
    & boost::serialization::make_nvp("cost", cost)
    & boost::serialization::make_nvp("tree", tree);
  if (!( 0 <= int(tree.size()) && int(tree.size()) <= numpoints ))
    throw
      GeographicLib::GeographicErr("Bad number of points or tree size");
  for (int i = 0; i < int(tree.size()); ++i)
    tree[i].Check(numpoints, int(tree.size()), bucket);
  _tree.swap(tree);
  _numpoints = numpoints;
  _bucket = bucket;
  _mc = _sc = 0;
  _cost = cost; _c1 = _k = _cmax = 0;
  _cmin = std::numeric_limits<int>::max();
}
template<class Archive>
void serialize(Archive& ar, const unsigned int file_version)
{ boost::serialization::split_member(ar, *this, file_version); }
#endif
731
732 int _numpoints, _bucket, _cost;
733 std::vector _tree;
734
735 mutable double _mc, _sc;
736 mutable int _c1, _k, _cmin, _cmax;
737
738 int init(const std::vector<pos_t>& pts, const distfun_t& dist, int bucket,
739 std::vector& tree, std::vector& ids, int& cost,
740 int l, int u, int vp) {
741
742 if (u == l)
743 return -1;
744 Node node;
745
746 if (u - l > (bucket == 0 ? 1 : bucket)) {
747
748
749 int i = vp;
751
752 int m = (u + l + 1) / 2;
753
754 for (int k = l + 1; k < u; ++k) {
755 ids[k].first = dist(pts[ids[l].second], pts[ids[k].second]);
756 ++cost;
757 }
758
759 std::nth_element(ids.begin() + l + 1,
760 ids.begin() + m,
761 ids.begin() + u);
762 node.index = ids[l].second;
763 if (m > l + 1) {
764 typename std::vector::iterator
765 t = std::min_element(ids.begin() + l + 1, ids.begin() + m);
766 node.data.lower[0] = t->first;
767 t = std::max_element(ids.begin() + l + 1, ids.begin() + m);
768 node.data.upper[0] = t->first;
769
770
771 node.data.child[0] = init(pts, dist, bucket, tree, ids, cost,
772 l + 1, m, int(t - ids.begin()));
773 }
774 typename std::vector::iterator
775 t = std::max_element(ids.begin() + m, ids.begin() + u);
776 node.data.lower[1] = ids[m].first;
777 node.data.upper[1] = t->first;
778
779 node.data.child[1] = init(pts, dist, bucket, tree, ids, cost,
780 m, u, int(t - ids.begin()));
781 } else {
782 if (bucket == 0)
783 node.index = ids[l].second;
784 else {
785 node.index = -1;
786
787
788 std::sort(ids.begin() + l, ids.begin() + u);
789 for (int i = l; i < u; ++i)
790 node.leaves[i-l] = ids[i].second;
791 for (int i = u - l; i < bucket; ++i)
792 node.leaves[i] = -1;
793 for (int i = bucket; i < maxbucket; ++i)
794 node.leaves[i] = 0;
795 }
796 }
797
798 tree.push_back(node);
799 return int(tree.size()) - 1;
800 }
801
802 };
807
808
809
810
811
812
813
814
815
816
817
818 template<typename dist_t, typename pos_t, class distfun_t>
823
824}