GeographicLib: NearestNeighbor.hpp Source File (original) (raw)

97

98 static const int version = 1;

99

100

101 static const int maxbucket =

102 (2 + ((4 * sizeof(dist_t)) / sizeof(int) >= 2 ?

103 (4 * sizeof(dist_t)) / sizeof(int) : 2));

104 public:

105

106

107

108

109

110

112

113

114

115

116

117

118

119

120

121

122

123

124

125

126

127

128

129

130

131

132

133

134

135

136

137

138

139

140

141

// Constructor taking the same arguments (and the same bucket default) as
// Initialize.
// NOTE(review): listing line 144 -- the constructor body -- is missing from
// this extraction; presumably it just calls Initialize(pts, dist, bucket).
// Confirm against the original header.
142 NearestNeighbor(const std::vector<pos_t>& pts, const distfun_t& dist,

143 int bucket = 4) {

145 }

146

147

148

149

150

151

152

153

154

155

156

157

158

159

160

161

162

// Build the vantage-point tree for pts, replacing any tree already held.
//
// pts: the points.  Only indices into pts are stored, and Search checks that
//   the pts it is given has the size used here.
// dist: the metric, called as dist(a, b).  dist_t must be a signed type
//   (enforced at compile time).
// bucket: maximum number of points stored in a leaf, in [0, maxbucket];
//   with 0 every node holds a single point.
// The tree is assembled in locals and swapped into the members only after
// init returns; the search statistics are then reset.
// NOTE(review): this listing has lost template arguments (`std::vector ids`,
//   `std::vector tree`, `std::numeric_limits::max()`) and listing lines 168
//   and 171; the notes below about those gaps are assumptions.
163 void Initialize(const std::vector<pos_t>& pts, const distfun_t& dist,

164 int bucket = 4) {

165 static_assert(std::numeric_limits<dist_t>::is_signed,

166 "dist_t must be a signed type");

167 if (!( 0 <= bucket && bucket <= maxbucket ))
// NOTE(review): listing line 168 is missing; presumably
// `throw GeographicLib::GeographicErr`, taking the message below.

169 ("bucket must lie in [0, 2 + 4*sizeof(dist_t)/sizeof(int)]");

170 if (pts.size() > size_t(std::numeric_limits::max()))
// NOTE(review): listing line 171 is missing; presumably it throws, because
// point indices are stored as int (the limit is presumably
// numeric_limits<int>::max()).

172

// ids[k] = (distance slot, point index k); init overwrites the distances.
173 std::vector ids(pts.size());

174 for (int k = int(ids.size()); k--;)

175 ids[k] = std::make_pair(dist_t(0), k);

// cost counts the distance evaluations made while building.
176 int cost = 0;

177 std::vector tree;

// Build over the whole range [0, size); size/2 is passed as the position of
// the initial vantage point.
178 init(pts, dist, bucket, tree, ids, cost,

179 0, int(ids.size()), int(ids.size()/2));

// Commit the new tree, then reset the statistics counters.
180 _tree.swap(tree);

181 _numpoints = int(pts.size());

182 _bucket = bucket;

183 _mc = _sc = 0;

184 _cost = cost; _c1 = _k = _cmax = 0;

185 _cmin = std::numeric_limits::max();

186 }

187

188

189

190

191

192

193

194

195

196

197

198

199

200

201

202

203

204

205

206

207

208

209

210

211

212

213

214

215

216

217

218

219

220

221

222

223

224

225

226

227

228

229

230

231

232

233

234

235

236

237

238

239

240

241

242

243

244

245

246

247

248

249

250

// Search for the points of pts nearest to query.
//
// pts, dist: the point set and metric the tree was built with; only the
//   size of pts is checked against the tree.
// query: the point to search around.
// ind (output): resized to the number of points found and filled with their
//   indices into pts, closest first.
// k: maximum number of points to return.
// maxdist: a point is accepted only if its distance is <= maxdist.
// mindist: a point is accepted only if its distance is > mindist (strict);
//   the default -1 therefore admits points at distance 0.
// exhaustive: if false, the search stops as soon as k points have been
//   accepted, so they need not be the k nearest.
// tol: a node is examined only if its lower distance bound is <= tau - tol,
//   where tau is the current acceptance radius; in exhaustive mode the
//   search also stops once k points are held and tau <= tol.
// Returns the distance to the closest point found (the one in ind[0]), or -1
//   if none was found.
// Although const, a search that does any work (non-empty tree, k > 0,
//   maxdist > mindist) updates the mutable counters reported by Statistics.
// NOTE(review): this listing has lost template arguments (`std::vector&
//   ind`, `std::priority_queue results`, `std::priority_queue todo`); the
//   comments below assume ind holds ints and both queues hold (dist_t, int)
//   pairs with std::priority_queue's default max-heap ordering -- confirm.
251 dist_t Search(const std::vector<pos_t>& pts, const distfun_t& dist,

252 const pos_t& query,

253 std::vector& ind,

254 int k = 1,

255 dist_t maxdist = std::numeric_limits<dist_t>::max(),

256 dist_t mindist = -1,

257 bool exhaustive = true,

258 dist_t tol = 0) const {

259 if (_numpoints != int(pts.size()))
// NOTE(review): listing line 260 is missing here; presumably it throws
// GeographicLib::GeographicErr when the sizes differ.

// Best candidates so far; the top is the worst (largest distance) kept.
261 std::priority_queue results;

262 if (_numpoints > 0 && k > 0 && maxdist > mindist) {

263

// tau: acceptance radius -- maxdist until k points are held, then (in
// exhaustive mode) the k-th best distance found so far.
264 dist_t tau = maxdist;

265

266

267

// Nodes still to examine, keyed by -d, where d is a lower bound on the
// distance from query to any point under the node, so the most promising
// node is on top.  A key of +1 (d = -1) marks a node the query may lie
// inside; the root, which is the last element of _tree, starts that way.
268 std::priority_queue todo;

269 todo.push(std::make_pair(dist_t(1), int(_tree.size()) - 1));

// c counts distance evaluations in this search (for the statistics).
270 int c = 0;

271 while (!todo.empty()) {

272 int n = todo.top().second;

273 dist_t d = -todo.top().first;

274 todo.pop();

275 dist_t tau1 = tau - tol;

276

// tau may have shrunk since this node was queued, so re-test its bound.
277 if (!( n >= 0 && tau1 >= d )) continue;

278 const Node& current = _tree[n];

// dst is read after the loop only for non-leaf nodes, where the single
// pass below has set it to the query-to-vantage-point distance.
279 dist_t dst = 0;

// A leaf (index < 0) tests up to _bucket stored points, stopping at the
// first -1 terminator; any other node tests just its own point.
280 bool exitflag = false, leaf = current.index < 0;

281 for (int i = 0; i < (leaf ? _bucket : 1); ++i) {

282 int index = leaf ? current.leaves[i] : current.index;

283 if (index < 0) break;

284 dst = dist(pts[index], query);

285 ++c;

286

// Accept points in (mindist, tau], keeping only the k best.
287 if (dst > mindist && dst <= tau) {

288 if (int(results.size()) == k) results.pop();

289 results.push(std::make_pair(dst, index));

290 if (int(results.size()) == k) {

291 if (exhaustive)

292 tau = results.top().first;

293 else {

294 exitflag = true;

295 break;

296 }

297 if (tau <= tol) {

298 exitflag = true;

299 break;

300 }

301 }

302 }

303 }

304 if (exitflag) break;

305

// Non-leaf node: queue each child whose distance shell [lower, upper]
// around the vantage point might hold an acceptable point.  By the
// triangle inequality a point at distance r from the vantage point lies at
// least |dst - r| from query, so lower - dst (query inside the shell's
// inner radius) or dst - upper (query outside its outer radius) bounds the
// distance from below.  A child is skipped when dst + upper, an upper bound
// on those distances, is below mindist.
306 if (current.index < 0) continue;

307 tau1 = tau - tol;

308 for (int l = 0; l < 2; ++l) {

309 if (current.data.child[l] >= 0 &&

310 dst + current.data.upper[l] >= mindist) {

311 if (dst < current.data.lower[l]) {

312 d = current.data.lower[l] - dst;

313 if (tau1 >= d)

314 todo.push(std::make_pair(-d, current.data.child[l]));

315 } else if (dst > current.data.upper[l]) {

316 d = dst - current.data.upper[l];

317 if (tau1 >= d)

318 todo.push(std::make_pair(-d, current.data.child[l]));

319 } else

320 todo.push(std::make_pair(dist_t(1), current.data.child[l]));

321 }

322 }

323 }

// Update the statistics: Welford's running mean (_mc) and sum of squared
// deviations (_sc) of c, plus the total and the extremes.
324 ++_k;

325 _c1 += c;

326 double omc = _mc;

327 _mc += (c - omc) / _k;

328 _sc += (c - omc) * (c - _mc);

329 if (c > _cmax) _cmax = c;

330 if (c < _cmin) _cmin = c;

331 }

332

// Drain the max-heap from the back of ind so that ind ends up ordered by
// increasing distance; d picks up the distance of ind[0] (-1 if empty).
333 dist_t d = -1;

334 ind.resize(results.size());

335

336 for (int i = int(ind.size()); i--;) {

337 ind[i] = int(results.top().second);

338 if (i == 0) d = results.top().first;

339 results.pop();

340 }

341 return d;

342

343 }

344

345

346

347

349

350

351

352

353

354

355

356

357

358

359

360

361

362

363

364

365

366

367 void Save(std::ostream& os, bool bin = true) const {

368 int realspec = std::numeric_limits<dist_t>::digits *

369 (std::numeric_limits<dist_t>::is_integer ? -1 : 1);

370 if (bin) {

371 char id[] = "NearestNeighbor_";

372 os.write(id, 16);

373 int buf[6];

374 buf[0] = version;

375 buf[1] = realspec;

376 buf[2] = _bucket;

377 buf[3] = _numpoints;

378 buf[4] = int(_tree.size());

379 buf[5] = _cost;

380 os.write(reinterpret_cast<const char *>(buf), 6 * sizeof(int));

381 for (int i = 0; i < int(_tree.size()); ++i) {

382 const Node& node = _tree[i];

383 os.write(reinterpret_cast<const char *>(&node.index), sizeof(int));

384 if (node.index >= 0) {

385 os.write(reinterpret_cast<const char *>(node.data.lower),

386 2 * sizeof(dist_t));

387 os.write(reinterpret_cast<const char *>(node.data.upper),

388 2 * sizeof(dist_t));

389 os.write(reinterpret_cast<const char *>(node.data.child),

390 2 * sizeof(int));

391 } else {

392 os.write(reinterpret_cast<const char *>(node.leaves),

393 _bucket * sizeof(int));

394 }

395 }

396 } else {

397 std::stringstream ostring;

398

399

400 if (!std::numeric_limits<dist_t>::is_integer) {

401 static const int prec

402 = int(std::ceil(std::numeric_limits<dist_t>::digits *

403 std::log10(2.0) + 1));

404 ostring.precision(prec);

405 }

406 ostring << version << " " << realspec << " " << _bucket << " "

407 << _numpoints << " " << _tree.size() << " " << _cost;

408 for (int i = 0; i < int(_tree.size()); ++i) {

409 const Node& node = _tree[i];

410 ostring << "\n" << node.index;

411 if (node.index >= 0) {

412 for (int l = 0; l < 2; ++l)

413 ostring << " " << node.data.lower[l] << " " << node.data.upper[l]

414 << " " << node.data.child[l];

415 } else {

416 for (int l = 0; l < _bucket; ++l)

417 ostring << " " << node.leaves[l];

418 }

419 }

420 os << ostring.str();

421 }

422 }

423

424

425

426

427

428

429

430

431

432

433

434

435

436

437

438

439

440

441

442

443

444

// Replace the tree with one previously written by Save.
//
// is: input stream positioned at the saved data.
// bin: must match the format Save used (binary by default, otherwise text).
// The header (format version, dist_t signature, bucket size, number of
// points, number of nodes, setup cost) is validated before any node is read.
// Each node is checked with Node::Check.  The members are replaced only after
// every node has been read and accepted, and the statistics are then reset.
// NOTE(review): listing lines 452, 462, 465, 468, 470, 473, 475, 497, 502
//   and 509 are missing from this extraction.  Each is presumably the
//   GeographicLib::GeographicErr throw (or its message) for the check just
//   before it.  Template arguments are also stripped (`std::vector tree`,
//   `std::numeric_limits::max()`).
445 void Load(std::istream& is, bool bin = true) {

446 int version1, realspec, bucket, numpoints, treesize, cost;

447 if (bin) {

// Binary: 16-byte ID tag, then six raw ints.
448 char id[17];

449 is.read(id, 16);

450 id[16] = '\0';

451 if (!(std::strcmp(id, "NearestNeighbor_") == 0))

453 is.read(reinterpret_cast<char *>(&version1), sizeof(int));

454 is.read(reinterpret_cast<char *>(&realspec), sizeof(int));

455 is.read(reinterpret_cast<char *>(&bucket), sizeof(int));

456 is.read(reinterpret_cast<char *>(&numpoints), sizeof(int));

457 is.read(reinterpret_cast<char *>(&treesize), sizeof(int));

458 is.read(reinterpret_cast<char *>(&cost), sizeof(int));
// NOTE(review): these binary reads are never checked for stream failure.
// On a short read, the header ints (declared without initializers) keep
// indeterminate values until the range checks below; consider testing `is`
// after reading.

459 } else {

460 if (!( is >> version1 >> realspec >> bucket >> numpoints >> treesize

461 >> cost ))

463 }

// Header validation: same version, same dist_t signature as Save computes,
// bucket within [0, maxbucket], at most one node per point, and a
// non-negative setup cost.
464 if (!( version1 == version ))

466 if (!( realspec == std::numeric_limits<dist_t>::digits *

467 (std::numeric_limits<dist_t>::is_integer ? -1 : 1) ))

469 if (!( 0 <= bucket && bucket <= maxbucket ))

471 if (!( 0 <= treesize && treesize <= numpoints ))

472 throw

474 if (!( 0 <= cost ))

476 std::vector tree;

477 tree.reserve(treesize);

478 for (int i = 0; i < treesize; ++i) {

479 Node node;

480 if (bin) {

481 is.read(reinterpret_cast<char *>(&node.index), sizeof(int));

482 if (node.index >= 0) {

// Non-leaf node: bounds and child indices, in the order Save wrote them.
483 is.read(reinterpret_cast<char *>(node.data.lower),

484 2 * sizeof(dist_t));

485 is.read(reinterpret_cast<char *>(node.data.upper),

486 2 * sizeof(dist_t));

487 is.read(reinterpret_cast<char *>(node.data.child),

488 2 * sizeof(int));

489 } else {

// Leaf: bucket entries, with the unused tail zeroed as Node::Check expects.
490 is.read(reinterpret_cast<char *>(node.leaves),

491 bucket * sizeof(int));

492 for (int l = bucket; l < maxbucket; ++l)

493 node.leaves[l] = 0;

494 }

495 } else {

496 if (!( is >> node.index ))

498 if (node.index >= 0) {

// Text form interleaves lower, upper and child for each of the two shells.
499 for (int l = 0; l < 2; ++l) {

500 if (!( is >> node.data.lower[l] >> node.data.upper[l]

501 >> node.data.child[l] ))

503 }

504 } else {

505

506

507 for (int l = 0; l < bucket; ++l) {

508 if (!( is >> node.leaves[l] ))

510 }

511 for (int l = bucket; l < maxbucket; ++l)

512 node.leaves[l] = 0;

513 }

514 }

// Reject inconsistent node data before accepting the node.
515 node.Check(numpoints, treesize, bucket);

516 tree.push_back(node);

517 }

// Commit, then reset the statistics counters.
518 _tree.swap(tree);

519 _numpoints = numpoints;

520 _bucket = bucket;

521 _mc = _sc = 0;

522 _cost = cost; _c1 = _k = _cmax = 0;

523 _cmin = std::numeric_limits::max();

524 }

525

526

527

528

529

530

531

532

533

535 { t.Save(os, false); return os; }

536

537

538

539

540

541

542

543

544

546 { t.Load(is, false); return is; }

547

548

549

550

551

552

554 std::swap(_numpoints, t._numpoints);

557 _tree.swap(t._tree);

564 }

565

566

567

568

569

570

571

572

573

574

575

576

577

578

579

580 void Statistics(int& setupcost, int& numsearches, int& searchcost,

581 int& mincost, int& maxcost,

582 double& mean, double& sd) const {

583 setupcost = _cost; numsearches = _k; searchcost = _c1;

584 mincost = _cmin; maxcost = _cmax;

585 mean = _mc; sd = std::sqrt(_sc / (_k - 1));

586 }

587

588

589

590

591

593 _mc = _sc = 0;

594 _c1 = _k = _cmax = 0;

595 _cmin = std::numeric_limits::max();

596 }

597

598 private:

599

600

601 typedef std::pair<dist_t, int> item;

602

// One node of the tree, stored in _tree and referred to by its position.
// index >= 0: the node holds point `index` (its vantage point) and uses
//   data.  child[l] is the _tree position of subtree l (-1 if absent), and
//   every point in subtree l lies at a distance in [lower[l], upper[l]] from
//   the vantage point.
// index < 0: a leaf whose points are listed in leaves.
// data and leaves share storage through an anonymous union, so only the
// member selected by index is meaningful.
603 class Node {

604 public:

605 struct bounds {

606 dist_t lower[2], upper[2];

607 int child[2];

608 };

// Leaf point indices: a run of valid indices, then -1 terminators up to the
// bucket size; entries from the bucket size to maxbucket are 0.
609 union {

610 bounds data;

611 int leaves[maxbucket];

612 };

613 int index;

614

// Starts with index -1, zero bounds and no children (child -1).
615 Node()

616 : index(-1)

617 {

618 for (int i = 0; i < 2; ++i) {

619 data.lower[i] = data.upper[i] = 0;

620 data.child[i] = -1;

621 }

622 }

623

624

// Validate a node read from external data against the number of points, the
// number of tree nodes and the bucket size.  Throws
// GeographicLib::GeographicErr when the node is inconsistent.
// NOTE(review): listing lines 627, 631, 635 and 649 (presumably the other
//   throw statements) are missing from this extraction.
625 void Check(int numpoints, int treesize, int bucket) const {

626 if (!( -1 <= index && index < numpoints ))

628 if (index >= 0) {

// Children must be -1 or valid node positions; shells must be non-negative,
// ordered, and non-overlapping.
629 if (!( -1 <= data.child[0] && data.child[0] < treesize &&

630 -1 <= data.child[1] && data.child[1] < treesize ))

632 if (!( 0 <= data.lower[0] && data.lower[0] <= data.upper[0] &&

633 data.upper[0] <= data.lower[1] &&

634 data.lower[1] <= data.upper[1] ))

636 } else {

637

638

// Leaf: the first entry must be a valid point index.  Valid indices may
// continue until the first -1, and every entry after a -1 must also be -1.
639 bool start = true;

640 for (int l = 0; l < bucket; ++l) {

641 if (!( (start ?

642 ((l == 0 ? 0 : -1) <= leaves[l] && leaves[l] < numpoints) :

643 leaves[l] == -1) ))

644 throw GeographicLib::GeographicErr("Bad leaf data");

645 start = leaves[l] >= 0;

646 }

// Entries beyond the bucket size must be zero.
647 for (int l = bucket; l < maxbucket; ++l) {

648 if (leaves[l] != 0)

650 }

651 }

652 }

653

654#if defined(GEOGRAPHICLIB_HAVE_BOOST_SERIALIZATION) && \

655 GEOGRAPHICLIB_HAVE_BOOST_SERIALIZATION

656 friend class boost::serialization::access;
// NOTE(review): the `template` lines below have lost their parameter list
// (presumably `<class Archive>`).
// Only the live member of the union is archived: the whole maxbucket-element
// leaves array for a leaf, otherwise the bounds and children.

657 template

658 void save(Archive& ar, const unsigned int) const {

659 ar & boost::serialization::make_nvp("index", index);

660 if (index < 0)

661 ar & boost::serialization::make_nvp("leaves", leaves);

662 else

663 ar & boost::serialization::make_nvp("lower", data.lower)

664 & boost::serialization::make_nvp("upper", data.upper)

665 & boost::serialization::make_nvp("child", data.child);

666 }

667 template

668 void load(Archive& ar, const unsigned int) {

669 ar & boost::serialization::make_nvp("index", index);

670 if (index < 0)

671 ar & boost::serialization::make_nvp("leaves", leaves);

672 else

673 ar & boost::serialization::make_nvp("lower", data.lower)

674 & boost::serialization::make_nvp("upper", data.upper)

675 & boost::serialization::make_nvp("child", data.child);

676 }

677 template

678 void serialize(Archive& ar, const unsigned int file_version)

679 { boost::serialization::split_member(ar, *this, file_version); }

680#endif

681 };

682

// Boost.Serialization support for the whole tree.  It is compiled only when
// GEOGRAPHICLIB_HAVE_BOOST_SERIALIZATION is defined and nonzero.
// NOTE(review): the `template` keywords have lost their parameter lists
//   (presumably `template<class Archive>`).  Listing lines 703, 711 and 717
//   (presumably throw statements) are missing from this extraction.
683#if defined(GEOGRAPHICLIB_HAVE_BOOST_SERIALIZATION) && \

684 GEOGRAPHICLIB_HAVE_BOOST_SERIALIZATION

685 friend class boost::serialization::access;

// Archive the format version, dist_t signature, bucket size, point count,
// setup cost and nodes.  Local copies are used for the computed values.
686 template void save(Archive& ar, const unsigned) const {

687 int realspec = std::numeric_limits<dist_t>::digits *

688 (std::numeric_limits<dist_t>::is_integer ? -1 : 1);

689

690

691 int version1 = version;

692 ar & boost::serialization::make_nvp("version", version1)

693 & boost::serialization::make_nvp("realspec", realspec)

694 & boost::serialization::make_nvp("bucket", _bucket)

695 & boost::serialization::make_nvp("numpoints", _numpoints)

696 & boost::serialization::make_nvp("cost", _cost)

697 & boost::serialization::make_nvp("tree", _tree);

698 }

// Read and validate in the same order as save.  The members are replaced
// only after every node passes Node::Check, and the statistics are reset.
699 template void load(Archive& ar, const unsigned) {

700 int version1, realspec, bucket, numpoints, cost;

701 ar & boost::serialization::make_nvp("version", version1);

702 if (version1 != version)
// NOTE(review): listing line 703 (presumably the throw for a version
// mismatch) is missing here.

704 std::vector tree;

705 ar & boost::serialization::make_nvp("realspec", realspec);

706 if (!( realspec == std::numeric_limits<dist_t>::digits *

707 (std::numeric_limits<dist_t>::is_integer ? -1 : 1) ))

708 throw GeographicLib::GeographicErr("Different dist_t types");

709 ar & boost::serialization::make_nvp("bucket", bucket);

710 if (!( 0 <= bucket && bucket <= maxbucket ))

712 ar & boost::serialization::make_nvp("numpoints", numpoints)

713 & boost::serialization::make_nvp("cost", cost)

714 & boost::serialization::make_nvp("tree", tree);
// NOTE(review): unlike the stream-based Load, which requires 0 <= cost, this
// path does not visibly validate cost.

715 if (!( 0 <= int(tree.size()) && int(tree.size()) <= numpoints ))

716 throw

718 for (int i = 0; i < int(tree.size()); ++i)

719 tree[i].Check(numpoints, int(tree.size()), bucket);

720 _tree.swap(tree);

721 _numpoints = numpoints;

722 _bucket = bucket;

723 _mc = _sc = 0;

724 _cost = cost; _c1 = _k = _cmax = 0;

725 _cmin = std::numeric_limits::max();

726 }

727 template

728 void serialize(Archive& ar, const unsigned int file_version)

729 { boost::serialization::split_member(ar, *this, file_version); }

730#endif

731

// Tree state: number of points indexed, leaf bucket size, and the number of
// distance evaluations spent building the tree.
732 int _numpoints, _bucket, _cost;

// Tree nodes.  When built by init, children are appended before their
// parent, so the root is the last element (Search starts there).
// NOTE(review): the template argument is lost in this listing; presumably
// std::vector<Node>.
733 std::vector _tree;

734

// Search statistics, mutable so that the const Search can update them.  They
// are: running mean (_mc) and sum of squared deviations (_sc) of the
// per-search distance-evaluation count, total evaluations (_c1), number of
// recorded searches (_k), and smallest/largest per-search counts (_cmin,
// _cmax).
735 mutable double _mc, _sc;

736 mutable int _c1, _k, _cmin, _cmax;

737

// Recursively build the subtree for ids[l, u) and append it to tree.
//
// ids holds (distance, point index) pairs; the distance slots are scratch
// space overwritten here.  bucket is the leaf capacity (0 means single-point
// nodes).  cost is incremented once per distance evaluation.  vp is the
// position in ids of the point intended as this node's vantage point.
// Returns the position in tree of the new node, or -1 for an empty range.
// Children are appended before their parent, so the root of a complete
// build is the last node.
// NOTE(review): template arguments are stripped in this listing
//   (`std::vector& tree`, `std::vector& ids`, `std::vector::iterator`), and
//   listing line 750 is missing.  As shown, the local i is never used and
//   ids[l] becomes the vantage point regardless of vp.  Presumably line 750
//   swaps ids[l] with ids[i]; confirm against the original header.
738 int init(const std::vector<pos_t>& pts, const distfun_t& dist, int bucket,

739 std::vector& tree, std::vector& ids, int& cost,

740 int l, int u, int vp) {

741

742 if (u == l)

743 return -1;

744 Node node;

745

// Split when the range holds more points than one node may: bucket points
// for a leaf, or a single point when bucket is 0.
746 if (u - l > (bucket == 0 ? 1 : bucket)) {

747

748

749 int i = vp;

751

// m divides the remaining points ids[l+1, u) into an inner half [l+1, m) and
// an outer half [m, u) by distance from the vantage point.
752 int m = (u + l + 1) / 2;

753

// Distance from the vantage point ids[l] to every other point in the range.
754 for (int k = l + 1; k < u; ++k) {

755 ids[k].first = dist(pts[ids[l].second], pts[ids[k].second]);

756 ++cost;

757 }

758

// Partial sort: afterwards no element of [l+1, m) exceeds ids[m], and no
// element of [m, u) is less than it.
759 std::nth_element(ids.begin() + l + 1,

760 ids.begin() + m,

761 ids.begin() + u);

762 node.index = ids[l].second;

// Inner subtree (absent when it would be empty): record its distance range,
// and pass its farthest point down as its vantage point.
763 if (m > l + 1) {

764 typename std::vector::iterator

765 t = std::min_element(ids.begin() + l + 1, ids.begin() + m);

766 node.data.lower[0] = t->first;

767 t = std::max_element(ids.begin() + l + 1, ids.begin() + m);

768 node.data.upper[0] = t->first;

769

770

771 node.data.child[0] = init(pts, dist, bucket, tree, ids, cost,

772 l + 1, m, int(t - ids.begin()));

773 }

// Outer subtree (never empty here).  After nth_element, ids[m] holds its
// smallest distance; its farthest point is passed down as its vantage point.
774 typename std::vector::iterator

775 t = std::max_element(ids.begin() + m, ids.begin() + u);

776 node.data.lower[1] = ids[m].first;

777 node.data.upper[1] = t->first;

778

779 node.data.child[1] = init(pts, dist, bucket, tree, ids, cost,

780 m, u, int(t - ids.begin()));

781 } else {

// Terminal node: a single-point node when bucket is 0, otherwise a leaf.
782 if (bucket == 0)

783 node.index = ids[l].second;

784 else {

785 node.index = -1;

786

787

// Store the points ordered by (distance slot, index), pad with -1 up to
// bucket, and zero the rest of the array (the layout Node::Check expects).
788 std::sort(ids.begin() + l, ids.begin() + u);

789 for (int i = l; i < u; ++i)

790 node.leaves[i-l] = ids[i].second;

791 for (int i = u - l; i < bucket; ++i)

792 node.leaves[i] = -1;

793 for (int i = bucket; i < maxbucket; ++i)

794 node.leaves[i] = 0;

795 }

796 }

797

798 tree.push_back(node);

799 return int(tree.size()) - 1;

800 }

801

802 };

807

808

809

810

811

812

813

814

815

816

817

818 template<typename dist_t, typename pos_t, class distfun_t>

823

824}