Templates -- Meow  1.2.9
A C++ template library containing various interesting classes and functions
KD_Tree.h
Go to the documentation of this file.
1 #ifndef dsa_KD_Tree_H__
2 #define dsa_KD_Tree_H__
3 
4 #include "../utility.h"
5 #include "../math/utility.h"
6 
7 #include <cstdlib>
8 
9 #include <vector>
10 #include <algorithm>
11 #include <queue>
12 
13 namespace meow {
14 
/*!
 * @brief k-dimension tree supporting k-nearest-neighbour queries
 *
 * @tparam Vector point type. Coordinates are read with operator[] for
 *         indices 0 .. dimension-1; operator< is used to break ties and
 *         operator== is used by erase().
 * @tparam Scalar type of squared distances; it is constructed from 0 and
 *         combined with +, -, += and <.
 *
 * Points are added/removed with insert()/erase(). The tree structure is
 * rebuilt lazily: build() (also called by query()) rebuilds it only when
 * insert()/erase() was called since the last build.
 * The dimension must be positive: the split axis is chosen as
 * `depth % dimension`.
 */
template<class Vector, class Scalar>
class KD_Tree {
private:
  struct Node {
    Vector vector_;
    ssize_t lChild_;  // index into nodes_ of the left subtree, or kNIL_
    ssize_t rChild_;  // index into nodes_ of the right subtree, or kNIL_

    Node(Vector v, ssize_t l, ssize_t r): vector_(v), lChild_(l), rChild_(r) {
    }
  };
  typedef std::vector<Node> Nodes;

  // Strict weak order on node indices: by coordinate cmp_, with ties broken
  // by comparing the whole vectors.
  class Sorter {
  private:
    Nodes const* nodes_;
    size_t cmp_;
  public:
    Sorter(Nodes const* nodes, size_t cmp): nodes_(nodes), cmp_(cmp) {
    }
    bool operator()(size_t const& a, size_t const& b) const {
      Vector const& va = (*nodes_)[a].vector_;
      Vector const& vb = (*nodes_)[b].vector_;
      if (va[cmp_] != vb[cmp_]) {
        return (va[cmp_] < vb[cmp_]);
      }
      return (va < vb);
    }
  };

  // A candidate result: a node index plus its squared distance to the query.
  struct Answer {
    ssize_t index_;
    Scalar dist2_;
    Answer(ssize_t index, Scalar dist2): index_(index), dist2_(dist2) {
    }
  };

  // Orders answers by squared distance, so a std::priority_queue using it
  // keeps the farthest kept candidate on top. When cmpValue_ is true, equal
  // distances are broken by comparing the vectors themselves.
  class AnswerCompare {
  private:
    Nodes const* nodes_;
    bool cmpValue_;
  public:
    AnswerCompare(Nodes const* nodes, bool cmpValue):
      nodes_(nodes), cmpValue_(cmpValue) {
    }
    bool operator()(Answer const& a, Answer const& b) const {
      if (cmpValue_ && a.dist2_ == b.dist2_) {
        return ((*nodes_)[a.index_].vector_ < (*nodes_)[b.index_].vector_);
      }
      return (a.dist2_ < b.dist2_);
    }
  };
  typedef std::vector<Answer> AnswerV;
  typedef std::priority_queue<Answer, AnswerV, AnswerCompare> Answers;

  // "No node" marker for child links and the root. It is static so that the
  // tree stays copy-assignable.
  static const ssize_t kNIL_ = -1;

  Nodes nodes_;
  ssize_t root_;       // index of the root node, or kNIL_ for an empty tree
  bool needRebuild_;   // set by insert()/erase(), cleared by forceBuild()
  size_t dimension_;

  // x * x (same contract as the project's squ helper).
  template<class T>
  static T square(T const& x) {
    return x * x;
  }

  // Squared Euclidean distance over the first dimension_ coordinates.
  Scalar distance2(Vector const& v1, Vector const& v2) const {
    Scalar ret(0);
    for (size_t i = 0; i < dimension_; i++) {
      ret += square(v1[i] - v2[i]);
    }
    return ret;
  }

  /*
   * Recursive k-nearest search over the subtree rooted at `index`, which
   * sits at `depth`.
   *
   * dist2Vector[d] is the current lower bound, along axis d, on the squared
   * distance from v to the subtree's region; dist2Minimum is their sum.
   * `out` keeps at most nearestNumber best answers, ordered by
   * answerCompare, with the worst one on top. Requires nearestNumber > 0.
   * dist2Vector is restored before returning.
   */
  void query(Vector const& v,
             size_t nearestNumber,
             AnswerCompare const& answerCompare,
             ssize_t index,
             int depth,
             std::vector<Scalar>& dist2Vector,
             Scalar dist2Minimum,
             Answers *out) const {
    if (index == kNIL_) return;
    Node const& node = nodes_[index];
    size_t cmp = depth % dimension_;
    // Descend first into the side of the split plane that contains v.
    ssize_t this_side, that_side;
    if (!(node.vector_[cmp] < v[cmp])) {
      this_side = node.lChild_;
      that_side = node.rChild_;
    } else {
      this_side = node.rChild_;
      that_side = node.lChild_;
    }
    query(v, nearestNumber, answerCompare,
          this_side, depth + 1,
          dist2Vector, dist2Minimum,
          out);
    Answer my_ans(index, distance2(node.vector_, v));
    if (out->size() < nearestNumber || answerCompare(my_ans, out->top())) {
      out->push(my_ans);
      if (out->size() > nearestNumber) out->pop();
    }
    // Lower bound for the far side: replace this axis' contribution with
    // the squared distance from v to the split plane.
    Scalar dist2_old(dist2Vector[cmp]);
    dist2Vector[cmp] = square(node.vector_[cmp] - v[cmp]);
    Scalar dist2Minimum2(dist2Minimum + dist2Vector[cmp] - dist2_old);
    // Visit the far side only if it may still hold a better answer. Equal
    // distances are still visited so that vector tie-breaking can apply.
    if (out->size() < nearestNumber || !(out->top().dist2_ < dist2Minimum2)) {
      query(v, nearestNumber, answerCompare,
            that_side, depth + 1,
            dist2Vector, dist2Minimum2,
            out);
    }
    dist2Vector[cmp] = dist2_old;
  }

  /*
   * Builds the subtree over the nodes whose indices occupy positions
   * beg..end (inclusive) of orders[d], where each row orders[d]
   * (d < dimension_) lists them sorted by coordinate d.
   *
   * The median along axis depth % dimension_ becomes the subtree root. The
   * rows of the other axes are partitioned around it in place, preserving
   * their relative order, so both halves stay sorted. Row orders[dimension_]
   * is scratch space. Row orders[dimension_ + 1] records, per node index,
   * whether it falls in the left half (0, median included) or the right
   * half (1).
   *
   * Returns the subtree root index, or kNIL_ when the range is empty.
   */
  ssize_t build(ssize_t beg,
                ssize_t end,
                std::vector<size_t>* orders,
                int depth) {
    if (beg > end) return kNIL_;
    size_t const tmp_order = dimension_;
    size_t const which_side = dimension_ + 1;
    ssize_t const mid = (beg + end) / 2;
    size_t const cmp = depth % dimension_;
    size_t const median = orders[cmp][mid];
    for (ssize_t i = beg; i <= end; i++) {
      orders[which_side][orders[cmp][i]] = (i <= mid ? 0 : 1);
    }
    for (size_t i = 0; i < dimension_; i++) {
      if (i == cmp) continue;  // already sorted and split at mid
      size_t left = beg, right = mid + 1;
      for (ssize_t j = beg; j <= end; j++) {
        size_t ask = orders[i][j];
        if (ask == median) {
          orders[tmp_order][mid] = ask;
        } else if (orders[which_side][ask] == 1) {
          orders[tmp_order][right++] = ask;
        } else {
          orders[tmp_order][left++] = ask;
        }
      }
      std::copy(orders[tmp_order].begin() + beg,
                orders[tmp_order].begin() + end + 1,
                orders[i].begin() + beg);
    }
    nodes_[median].lChild_ = build(beg, mid - 1, orders, depth + 1);
    nodes_[median].rChild_ = build(mid + 1, end, orders, depth + 1);
    return static_cast<ssize_t>(median);
  }
public:
  //! Custom type: Vectors is std::vector<Vector>
  typedef typename std::vector<Vector> Vectors;

  //! Constructs an empty tree with dimension = 1.
  KD_Tree(): root_(kNIL_), needRebuild_(false), dimension_(1) {
  }

  //! Constructs an empty tree with the given (positive) dimension.
  KD_Tree(size_t dimension):
    root_(kNIL_), needRebuild_(false), dimension_(dimension) {
  }

  // The implicit destructor suffices: all storage is owned by std::vector.

  /*!
   * @brief Adds v to the set; the tree is rebuilt lazily.
   */
  void insert(Vector const& v) {
    nodes_.push_back(Node(v, kNIL_, kNIL_));
    needRebuild_ = true;
  }

  /*!
   * @brief Removes one stored vector equal to v.
   *
   * The match is swapped with the last stored vector before removal, so
   * storage order is not preserved. The tree is rebuilt lazily.
   * @return true if a matching vector was found and removed
   */
  bool erase(Vector const& v) {
    for (size_t i = 0, I = nodes_.size(); i < I; i++) {
      if (nodes_[i].vector_ == v) {
        if (i != I - 1) {
          std::swap(nodes_[i], nodes_[I - 1]);
        }
        nodes_.pop_back();
        needRebuild_ = true;
        return true;
      }
    }
    return false;
  }

  /*!
   * @brief Rebuilds the tree only if insert()/erase() was called since the
   *        last build.
   */
  void build() {
    if (needRebuild_) {
      forceBuild();
    }
  }

  /*!
   * @brief Unconditionally rebuilds the tree from the stored vectors.
   */
  void forceBuild() {
    // Rows 0..dimension_-1: node indices sorted by each coordinate.
    // Row dimension_: scratch. Row dimension_+1: side flags (see build()).
    std::vector<std::vector<size_t> > orders(
        dimension_ + 2, std::vector<size_t>(nodes_.size()));
    for (size_t j = 0; j < dimension_; j++) {
      for (size_t i = 0, I = nodes_.size(); i < I; i++) {
        orders[j][i] = i;
      }
      std::sort(orders[j].begin(), orders[j].end(), Sorter(&nodes_, j));
    }
    root_ = build(0, static_cast<ssize_t>(nodes_.size()) - 1, &orders[0], 0);
    needRebuild_ = false;
  }

  /*!
   * @brief Finds the stored vectors nearest to v.
   *
   * Distance is the squared Euclidean distance over the first `dimension`
   * coordinates. Calls build() first, so a pending rebuild happens even
   * though this method is const.
   *
   * @param v                  query point
   * @param nearestNumber      maximum number of vectors to return
   * @param compareWholeVector if true, equal distances are broken by
   *                           Vector::operator< (the smaller vector wins)
   * @return up to nearestNumber vectors, nearest first
   */
  Vectors query(Vector const& v,
                size_t nearestNumber,
                bool compareWholeVector) const {
    const_cast<KD_Tree*>(this)->build();
    if (nearestNumber == 0) return Vectors();  // the search needs a non-empty heap
    AnswerCompare answer_compare(&nodes_, compareWholeVector);
    Answers answer_set(answer_compare);
    std::vector<Scalar> tmp(dimension_, Scalar(0));
    query(v, nearestNumber,
          answer_compare,
          root_, 0,
          tmp, Scalar(0),
          &answer_set);
    // The heap pops the farthest first, so fill the result back to front.
    Vectors ret(answer_set.size());
    for (ssize_t i = static_cast<ssize_t>(answer_set.size()) - 1; i >= 0; i--) {
      ret[i] = nodes_[answer_set.top().index_].vector_;
      answer_set.pop();
    }
    return ret;
  }

  /*!
   * @brief Removes all stored vectors.
   */
  void clear() {
    root_ = kNIL_;
    nodes_.clear();
    needRebuild_ = false;
  }

  /*!
   * @brief Removes all stored vectors and sets a new (positive) dimension.
   */
  void reset(size_t dimension) {
    clear();
    dimension_ = dimension;
  }
};

template<class Vector, class Scalar>
const ssize_t KD_Tree<Vector, Scalar>::kNIL_;
300 
301 } // meow
302 
303 #endif // dsa_KD_Tree_H__
bool erase(Vector const &v)
將給定的Vector從set移除
Definition: KD_Tree.h:215
void clear()
清空所有資料
Definition: KD_Tree.h:286
k-dimension tree
Definition: KD_Tree.h:40
void forceBuild()
重新建樹
Definition: KD_Tree.h:240
KD_Tree()
constructor, with dimension = 1
Definition: KD_Tree.h:192
vector
Definition: Vector.h:19
Vectors query(Vector const &v, size_t nearestNumber, bool compareWholeVector) const
查找
Definition: KD_Tree.h:263
void insert(Vector const &v)
將給定的Vector加到set中
Definition: KD_Tree.h:207
std::vector< Vector > Vectors
Custom Type: Vectors is std::vector<Vector>
Definition: KD_Tree.h:189
KD_Tree(size_t dimension)
constructor, given dimension
Definition: KD_Tree.h:196
void reset(size_t dimension)
清空所有資料並重新給定維度
Definition: KD_Tree.h:295
T squ(T const &x)
x*x
Definition: utility.h:67
void build()
檢查至今是否有 insert/erase 被呼叫來決定是否 forceBuild()
Definition: KD_Tree.h:231
~KD_Tree()
destructor
Definition: KD_Tree.h:201