1 #ifndef dsa_KD_Tree_H__
2 #define dsa_KD_Tree_H__
4 #include "../utility.h"
5 #include "../math/utility.h"
39 template<
class Vector,
class Scalar>
47 Node(
Vector v, ssize_t l, ssize_t r): vector_(v), lChild_(l), rChild_(r){
50 typedef std::vector<Node> Nodes;
57 Sorter(Nodes
const* nodes,
size_t cmp):
58 nodes_(nodes), cmp_(cmp){
60 bool operator()(
size_t const& a,
size_t const& b)
const{
61 if ((*nodes_)[a].vector_[cmp_] != (*nodes_)[b].vector_[cmp_]) {
62 return ((*nodes_)[a].vector_[cmp_] < (*nodes_)[b].vector_[cmp_]);
64 return ((*nodes_)[a].vector_ < (*nodes_)[b].vector_);
71 Answer(ssize_t index, Scalar dist2):
72 index_(index), dist2_(dist2) {
74 Answer(Answer
const& answer2):
75 index_(answer2.index_), dist2_(answer2.dist2_) {
83 AnswerCompare(Nodes
const* nodes,
bool cmpValue):
84 nodes_(nodes), cmpValue_(cmpValue) {
86 bool operator()(Answer
const& a, Answer
const& b)
const {
87 if (cmpValue_ ==
true && a.dist2_ == b.dist2_) {
88 return ((*nodes_)[a.index_].vector_ < (*nodes_)[b.index_].vector_);
90 return (a.dist2_ < b.dist2_);
93 typedef std::vector<Answer> AnswerV;
94 typedef std::priority_queue<Answer, AnswerV, AnswerCompare> Answers;
103 Scalar distance2(
Vector const& v1,
Vector const& v2)
const {
105 for(
size_t i = 0; i < dimension_; i++){
106 ret +=
squ(v1[i] - v2[i]);
111 void query(
Vector const& v,
112 size_t nearestNumber,
113 AnswerCompare
const& answerCompare,
116 std::vector<Scalar>& dist2Vector,
118 Answers *out)
const {
119 if (index == kNIL_) return ;
120 size_t cmp = depth % dimension_;
121 ssize_t this_side, that_side;
122 if (!(nodes_[index].vector_[cmp] < v[cmp])) {
123 this_side = nodes_[index].lChild_;
124 that_side = nodes_[index].rChild_;
126 this_side = nodes_[index].rChild_;
127 that_side = nodes_[index].lChild_;
129 query(v, nearestNumber, answerCompare,
130 this_side, depth + 1,
131 dist2Vector, dist2Minimum,
133 Answer my_ans(index, distance2(nodes_[index].vector_, v));
134 if (out->size() < nearestNumber || answerCompare(my_ans, out->top())) {
136 if (out->size() > nearestNumber) out->pop();
138 Scalar dist2_old(dist2Vector[cmp]);
139 dist2Vector[cmp] =
squ(nodes_[index].vector_[cmp] - v[cmp]);
140 Scalar dist2Minimum2(dist2Minimum + dist2Vector[cmp] - dist2_old);
141 if (out->size() < nearestNumber || !(out->top().dist2_ < dist2Minimum)) {
142 query(v, nearestNumber, answerCompare,
143 that_side, depth + 1,
144 dist2Vector, dist2Minimum2,
147 dist2Vector[cmp] = dist2_old;
149 ssize_t
build(ssize_t beg,
151 std::vector<size_t>* orders,
153 if (beg > end)
return kNIL_;
154 size_t tmp_order = dimension_;
155 size_t which_side = dimension_ + 1;
156 ssize_t mid = (beg + end) / 2;
157 size_t cmp = depth % dimension_;
158 for (ssize_t i = beg; i <= mid; i++) {
159 orders[which_side][orders[cmp][i]] = 0;
161 for (ssize_t i = mid + 1; i <= end; i++) {
162 orders[which_side][orders[cmp][i]] = 1;
164 for (
size_t i = 0; i < dimension_; i++) {
165 if (i == cmp)
continue;
166 size_t left = beg, right = mid + 1;
167 for (
int j = beg; j <= end; j++) {
168 size_t ask = orders[i][j];
169 if(ask == orders[cmp][mid]) {
170 orders[tmp_order][mid] = ask;
172 else if(orders[which_side][ask] == 1) {
173 orders[tmp_order][right++] = ask;
176 orders[tmp_order][left++] = ask;
179 for (
int j = beg; j <= end; j++) {
180 orders[i][j] = orders[tmp_order][j];
183 nodes_[orders[cmp][mid]].lChild_ =
build(beg, mid - 1, orders, depth + 1);
184 nodes_[orders[cmp][mid]].rChild_ =
build(mid + 1, end, orders, depth + 1);
185 return orders[cmp][mid];
192 KD_Tree(): kNIL_(-1), root_(kNIL_), needRebuild_(false), dimension_(1) {
197 kNIL_(-1), root_(kNIL_), needRebuild_(false), dimension_(dimension) {
208 nodes_.push_back(Node(v, kNIL_, kNIL_));
216 for (
size_t i = 0, I = nodes_.size(); i < I; i++) {
217 if (nodes_[i] == v) {
219 std::swap(nodes_[i], nodes_[I - 1]);
241 std::vector<size_t> *orders =
new std::vector<size_t>[dimension_ + 2];
242 for (
size_t j = 0; j < dimension_ + 2; j++) {
243 orders[j].resize(nodes_.size());
245 for (
size_t j = 0; j < dimension_; j++) {
246 for (
size_t i = 0, I = nodes_.size(); i < I; i++) {
249 std::sort(orders[j].begin(), orders[j].end(), Sorter(&nodes_, j));
251 root_ =
build(0, (ssize_t)nodes_.size() - 1, orders, 0);
253 needRebuild_ =
false;
264 size_t nearestNumber,
265 bool compareWholeVector)
const {
267 AnswerCompare answer_compare(&nodes_, compareWholeVector);
268 Answers answer_set(answer_compare);
269 std::vector<Scalar> tmp(dimension_, 0);
270 query(v, nearestNumber,
275 Vectors ret(answer_set.size());
276 for (
int i = (ssize_t)answer_set.size() - 1; i >= 0; i--) {
277 ret[i] = nodes_[answer_set.top().index_].vector_;
289 needRebuild_ =
false;
297 dimension_ = dimension;
303 #endif // dsa_KD_Tree_H__
bool erase(Vector const &v)
將給定的Vector從set移除
KD_Tree()
default constructor; creates an empty tree with dimension = 1
Vectors query(Vector const &v, size_t nearestNumber, bool compareWholeVector) const
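查找與 v 最接近的最多 nearestNumber 個 Vector（以各維差值平方和，即歐氏距離平方，作為距離）。回傳結果依距離由近到遠排列；若 compareWholeVector 為 true，距離相同者再以整個 Vector 的大小比較決定先後。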
void insert(Vector const &v)
將給定的Vector加到set中
std::vector< Vector > Vectors
Custom Type: Vectors is std::vector<Vector>
KD_Tree(size_t dimension)
constructor, given dimension
void reset(size_t dimension)
清空所有資料並重新給定維度
void build()
檢查自上次建樹以來是否曾呼叫過 insert/erase，以此決定是否需要重新建樹（rebuild()）