wavelet tree
圧縮検索で使われる技術wavelet treeをテンプレートライブラリとして書いてみました。
→を参考にしてみました。高速かつ省メモリで文字列を扱うデータ構造「wavelet tree」
元となる記事が大変興味深かったのだけど、どうもサンプルコードが複雑すぎるのと、僕の解釈が悪いのか、記事中の説明がコードとつじつまが合わないところがあったので、自分で実装してみたしだい。
記事中ではハフマンコード化の話があるのだけど、あくまでそれは最適な圧縮率を出すための理論にしか過ぎなくて、
頻度の順番で文字をソートしておいて、文字ごとにその文字を1にしたビット列を格納していったほうが素直だろう。(元記事中は該当文字を0としたが1としたほうが操作しやすいと思う)
たとえば、文字列T = "abccbbabca"があったときその頻度は'b','c','a'の順番になる。このとき各文字ごとにビット列を作っていけばいい。
該当する文字を1、それ以外を0とする。すでに使用した文字はビット列に含まない。
b:0100110100->0100110100
c:0X11XX0X10->011010
a:1XXXXX1XX1->111
以上でビット配列ができた。
wavelet treeで文字に対してrank,selectを行うには、各ビット列をrankは上から下へ、selectは下から上へたどって、葉のときは1を、それ以外は0をrank,selectしてあげる。
たとえば、rank(3,'a')とする。rankでは上から下つまり根から葉の方へたどってあげる。
ちょっとrankの定義がよろしくないのでrank(n,1)=less(n+1,1)となるような関数を用意してあげる。(これはrankでは0から数えてn番目を含んでn+1個のビットを検査する定義のため)
rank(3,'a')->less(4,'a') | rankは上から下へとたどる |
(0100110100).less(4,0)=3 | 前から4つのビットの0を調べる |
(011010).less(3,0)=1 | 前から3つのビットの0を調べる |
(111).less(1,1)=1 | 前から1つのビットの1を調べる |
葉までたどり着いたので最終的な結果1がrank(3,'a')の答えとなる。
select(1,'a')もやってみよう。今度は下から上つまり葉から根の方へたどってあげる。
select(1,'a') | selectは下から上へとたどる |
(111).select(1,1)=1 | (0から数えて)1番目の1のインデックスを調べる |
(011010).select(1,0)=3 | (0から数えて)1番目の0のインデックスを調べる |
(0100110100).select(3,0)=6 | (0から数えて)3番目の0のインデックスを調べる |
根まで数えたので最終的な結果6がselect(1,'a')の答えとなる。
2008/11/29 修正 struct pair_type{T ch;T count;};→struct pair_type{T ch;size_t count;};
ビット列
bit_vector.hpp
#ifndef EDL_BIT_VECTOR_HPP #define EDL_BIT_VECTOR_HPP #include <iostream> #include <sstream> #include <bitset> namespace edl{ template<class T> class basic_bit_vector{ public: typedef T block_t; static const size_t BLOCK_SIZE = sizeof(block_t)*8; static const int BLOCK_SHIFT = sizeof(block_t)/4*5; static const int BLOCK_MASK = (1<<BLOCK_SHIFT)-1; public: basic_bit_vector(size_t sz):sz_(sz),bsz_((sz + BLOCK_SIZE-1)/BLOCK_SIZE),a_(0){ a_ = new block_t[bsz_];// memset(a_,0,sizeof(block_t)*bsz_); } basic_bit_vector(const basic_bit_vector& rhs):sz_(rhs.sz_),bsz_(rhs.bsz_),a_(0){ a_ = new block_t[bsz_];// memcpy(a_,rhs.a_,sizeof(block_t)*bsz_); } template<size_t Bits> basic_bit_vector(const std::bitset<Bits>& rhs):sz_(Bits),bsz_((Bits + BLOCK_SIZE-1)/BLOCK_SIZE),a_(0){ a_ = new block_t[bsz_];// memset(a_,0,sizeof(block_t)*bsz_); for(size_t i=0;i<Bits;i++){ if(rhs[i])this->set1(i); else this->set0(i); } } ~basic_bit_vector(){ delete[] a_; } basic_bit_vector& operator=(const basic_bit_vector& rhs){ this->swap(basic_bit_vector(rhs)); return *this; } //------------------------------------- void resize(size_t sz){ if(sz_>=sz)return; basic_bit_vector tmp(sz); block_t* pb = tmp.get_ptr(); memcpy(tmp.get_ptr(),get_ptr(),sizeof(block_t)*bsz_); this->swap(tmp); } //------------------------------------- int get(size_t pos)const{ size_t div = get_div(pos); size_t rem = get_rem(pos); return (a_[div]>>rem)&1; } void set(size_t pos, int bit){ if(bit)set1(pos); else set0(pos); } void set1(size_t pos){ if(pos>sz_)resize(pos+1); size_t div = get_div(pos); size_t rem = get_rem(pos); set1(div,rem); } void set0(size_t pos){ if(pos>sz_)resize(pos+1); size_t div = get_div(pos); size_t rem = get_rem(pos); set0(div,rem); } size_t graw_size(size_t pos){ size_t n = sz_; while(pos>n){ n*=2; } return n; } int operator[](size_t pos)const{return get(pos);} size_t size()const{return sz_;} size_t block_size()const{return bsz_;} size_t allocation_size()const{return bsz_*sizeof(block_t);} const block_t* get_ptr()const{return a_;} block_t* get_ptr(){return a_;} block_t& get_block(size_t pos){ return a_[pos]; } const block_t& get_block(size_t pos)const{ return a_[pos]; } public: void swap(basic_bit_vector& rhs){ std::swap(sz_,rhs.sz_); std::swap(bsz_,rhs.bsz_); std::swap(a_,rhs.a_); } basic_bit_vector& inv(){ size_t bsz = bsz_; for(size_t i = 0;i<bsz;i++){ a_[i] = ~a_[i]; } return *this; } protected: static size_t get_div(size_t pos){ //return pos/BLOCK_SIZE; return pos>>BLOCK_SHIFT; } static size_t get_rem(size_t pos){ //return pos%BLOCK_SIZE; return pos&BLOCK_MASK; } void set1(size_t div, size_t rem){ a_[div] |= 1<<rem; } void set0(size_t div, size_t rem){ a_[div] &= ~(1<<rem); } protected: size_t sz_; //size of bits size_t bsz_;//size of blocks block_t* a_;// }; template<class T> basic_bit_vector<T> operator~(const basic_bit_vector<T>& rhs){ return basic_bit_vector<T>(rhs).inv(); } template<class T, class CharT, class Traits> std::basic_ostream<CharT, Traits>& operator<<(std::basic_ostream<CharT, Traits>& os, const basic_bit_vector<T>& rhs){ std::basic_stringstream<CharT, Traits> s; s.flags(os.flags()); s.imbue(os.getloc()); s.precision(os.precision()); size_t sz = rhs.size(); for(size_t i = 0;i<sz;i++){ if(rhs[i])s << '1'; else s << '0'; } return os<<s.str(); } template class basic_bit_vector<unsigned int>; typedef basic_bit_vector<unsigned int> bit_vector; } #endif
rank,selectの実装
bit_vector_rs.hpp
#ifndef EDL_BIT_VECTOR_RS_HPP #define EDL_BIT_VECTOR_RS_HPP #include "bit_vector.hpp" namespace edl{ template<class T> class basic_bit_vector_rs:public basic_bit_vector<T>{ public: typedef basic_bit_vector<T> base_type; public: basic_bit_vector_rs(const basic_bit_vector<T>& rhs):bit_vector(rhs){ c_ = new size_t[cache_size()]; recache(); } basic_bit_vector_rs(const basic_bit_vector_rs<T>& rhs):bit_vector(rhs){ c_ = new size_t[cache_size()]; recache(); } ~basic_bit_vector_rs(){ delete[] c_; } basic_bit_vector_rs& operator=(const basic_bit_vector_rs<T>& rhs){ this->swap(basic_bit_vector_rs<T>(rhs)); return *this; } void swap(basic_bit_vector_rs& rhs){ base_type::swap(rhs); std::swap(c_,rhs.c_); } size_t cache_size()const{ return block_size()+1; } //------------------------------------------------- size_t size1()const{ return c_[block_size()]; } size_t size0()const{ return size() -size1(); } //------------------------------------------------- void recache(){ c_[0]=0; recache(0); } size_t rank(size_t pos, int bit){ if(bit)return rank1(pos); else return rank0(pos); } size_t rank1(size_t pos)const{ return less1(pos+1); } size_t rank0(size_t pos)const{ return less0(pos+1); } size_t less(size_t pos, int bit)const{ if(bit)return less1(pos); else return less0(pos); } //---------------------------- size_t less1(size_t pos)const{ return less1_fast(pos); } size_t less0(size_t pos)const{ return pos-less1(pos); } //---------------------------- void set(size_t pos, int bit){ if(bit)set1(pos); else set0(pos); } void set1(size_t pos){ size_t div = get_div(pos); size_t rem = get_rem(pos); block_t old = a_[div]; base_type::set1(div,rem); if(old != a_[div])recache(div); } void set0(size_t pos){ size_t div = get_div(pos); size_t rem = get_rem(pos); block_t old = a_[div]; base_type::set0(div,rem); if(old != a_[div])recache(div); } void set1_raw(size_t pos){ size_t div = get_div(pos); size_t rem = get_rem(pos); base_type::set1(div,rem); } void set0_raw(size_t pos){ size_t div = get_div(pos); size_t rem = get_rem(pos); base_type::set0(div,rem); } //----------------------------- size_t select(size_t pos, int bit)const{ if(bit)return select1(pos); else return select0(pos); } size_t select1(size_t pos)const{ if(pos>=size1())return size(); size_t a = 0; size_t b = block_size(); while(a+1<b){ size_t m = (a+b)>>1; size_t c = c_[m]; if(c <= pos){a = m;}else{b=m;} } size_t bsz = c_[a]; if(pos<bsz||pos-bsz>=sizeof(block_t)*8)return size(); int rem = select_block(get_block(a),pos-bsz); if(rem<0)return size(); return a*BLOCK_SIZE + rem; } size_t select0(size_t pos)const{ if(pos>=size0())return size(); size_t a = 0; size_t b = block_size(); while(a+1<b){ size_t m = (a+b)>>1; size_t c = m*BLOCK_SIZE-c_[m]; if(c <= pos){a = m;}else{b=m;} } size_t bsz = a*BLOCK_SIZE-c_[a]; if(pos<bsz||pos-bsz>=sizeof(block_t)*8)return size(); int rem = select_block(~get_block(a),pos-bsz); if(rem<0)return size(); return a*BLOCK_SIZE + rem; } protected: void recache(size_t begin){ size_t r = c_[begin]; size_t bsz = block_size(); for(size_t i = begin;i<bsz;i++){ r += pop_count(get_block(i)); c_[i+1] = r; } } size_t less1_fast(size_t pos)const{//no pos size_t div = get_div(pos); size_t rem = get_rem(pos); return c_[div] + pop_count(get_block(div) & ((1<<rem)-1)); } size_t less1_basic(size_t pos)const{ size_t div = get_div(pos); size_t rem = get_rem(pos); size_t r = 0; for(size_t i = 0;i<div;i++){ r += pop_count(get_block(i)); } return r + pop_count(get_block(div) & ((1<<rem)-1)); } static int pop_count_byte(int p){ assert(0x0<=p&&p<=0xFF); static const unsigned char popCountArray[] = { 0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4,1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5, 1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6, 1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6, 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7, 1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6, 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7, 2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7, 3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,4,5,5,6,5,6,6,7,5,6,6,7,6,7,7,8 }; return popCountArray[p]; } static int pop_count(block_t r){ int nRet = 0; for(int i=0;i<sizeof(block_t);i++){ nRet += pop_count_byte((r >> (i*8)) & 0xFF); } return nRet; } //------------------------------------------------- static int select_block(block_t r, int p){ //if(pop_count(r)<=p)return -1; size_t nRet = 0; int i = 0; int b; while(i<sizeof(block_t)){ b = (r >> (i*8)) & 0xFF; int t = pop_count_byte(b); if (p < t) break; p -= t; i++; } if(!(0<=p&&p<=8))return -1; nRet += 8*i; nRet += select_byte(b,p); return nRet; } static int select_byte(unsigned char byte, int p){ assert(0<=p&&p<=8); int a = 0; int b = 8; for(int i = 0;i<3;i++){ int m = (a+b)>>1; int c = pop_count_byte(byte&((1<<m)-1)); if(c <= p){a = m;}else{b=m;} } return a; } protected: size_t* c_; }; template class basic_bit_vector_rs<unsigned int>; typedef basic_bit_vector_rs<bit_vector::block_t> bit_vector_rs; } #endif
wavelet tree
wavelet_tree.hpp
#ifndef EDL_WAVELET_TREE_HPP #define EDL_WAVELET_TREE_HPP #include "bit_vector_rs.hpp" #include <utility> #include <functional> #include <algorithm> #include <iostream> namespace edl{ template<class T> class basic_wavelet_tree{ public: typedef T char_type; static const size_t TSZ = 1<<(sizeof(T)*8); basic_wavelet_tree(const T* s, size_t sz):sz_(sz){ build(s, sz); } ~basic_wavelet_tree(){ destroy(); } size_t rank(size_t pos, T c)const{ return less(pos+1,c); } size_t less(size_t pos, T c)const{ int o = c2o_[c]; if(!bits_[o])return sz_; size_t next = pos; for(int i = 0;i<o;i++){ next = bits_[i]->less0(next); } next = bits_[o]->less1(next); return next; } size_t select(size_t pos, T c)const{ int o = c2o_[c]; if(!bits_[o])return sz_; size_t next = bits_[o]->select1(pos); for(int i = o-1;i>=0;i--){ next = bits_[i]->select0(next); } return next; } void debug_print(){ T o2c[TSZ]; for(size_t i = 0;i<TSZ;i++){ o2c[c2o_[i]]=i; } for(size_t i = 0;i<TSZ;i++){ if(bits_[i]){ std::cout <<o2c[i] <<":"<<*bits_[i] << std::endl; } } } protected: void build(const T* s, size_t sz){ T o2c[TSZ]; // sort_order(o2c, s, sz); std::auto_ptr<bit_vector_rs> bits[TSZ]; bit_vector bUsed(TSZ); size_t bsz = sz; for(size_t o=0;o<TSZ;o++){ if(bsz==0)break; T c = o2c[o];// bits[o].reset(new bit_vector_rs(bsz)); { size_t count=0; size_t used =0; for(size_t i=0;i<sz;i++){ if(bUsed.get(s[i]) == 0){ if(s[i]==c){ bits[o]->set1_raw(i-used); count++; } }else{ used++; } } bsz -= count; } bits[o]->recache(); bUsed.set1(c); } //----------------------- for(size_t o=0;o<TSZ;o++){ bits_[o]=bits[o].release(); c2o_[o2c[o]]=o; } } void destroy(){ for(size_t i = 0;i<TSZ;i++){ if(bits_[i])delete bits_[i]; bits_[i]=0; } } static void sort_order(T o2c[TSZ], const T* s, size_t sz){ struct pair_type{T ch;size_t count;}; struct pair_sorter:public std::binary_function<pair_type,pair_type,bool>{ bool operator()(const pair_type& r, const pair_type& l)const{ return r.count > l.count; } }; pair_type count[TSZ]; memset(count,0,sizeof(pair_type)*TSZ); for(T i=0;i<TSZ;i++){ count[i].ch = i; } //------------------------- for(size_t i = 0;i<sz;i++){ count[s[i]].count++; } //------------------------- std::sort(count,count+TSZ,pair_sorter()); //------------------------- for(T i=0;i<TSZ;i++){ //c2o[count[i].ch] = i; o2c[i] = count[i].ch; } } protected: T c2o_[TSZ];//character->order bit_vector_rs* bits_[TSZ]; size_t sz_; }; template class basic_wavelet_tree<char>; typedef basic_wavelet_tree<char> wavelet_tree; } #endif
テスト
bit_test.cpp
// bit_test.cpp : コンソール アプリケーションのエントリ ポイントを定義します。 // #include "stdafx.h" #include "type.h" #include <iostream> #include <sstream> #include "edl/bit_vector.hpp" #include "edl/bit_vector_rs.hpp" #include "edl/wavelet_tree.hpp" int _tmain(int argc, _TCHAR* argv[]) { std::stringstream ss; std::cout << ss.str() << std::endl; edl::bit_vector bv(50); bv.set(8,1); bv.set(9,1); bv.set(37,1); bv.set(33,1); edl::bit_vector_rs bvr(bv); edl::bit_vector_rs bvr2(bvr); std::cout << bvr << std::endl; std::cout << bvr2 << std::endl; bvr.set(40,1); bvr2 = bvr; std::cout << bvr << std::endl; std::cout << bvr2 << std::endl; const char* s = "abccbbabca"; edl::wavelet_tree wt(s,strlen(s)); std::cout << wt.rank(5,'b') << std::endl; std::cout << wt.rank(5,'c') << std::endl; std::cout << wt.rank(5,'a') << std::endl; std::cout << wt.select(1,'b') << std::endl; std::cout << wt.select(1,'c') << std::endl; std::cout << wt.select(1,'a') << std::endl; wt.debug_print(); return 0; }