wavelet tree

圧縮検索で使われる技術wavelet treeをテンプレートライブラリとして書いてみました。
→を参考にしてみました。高速かつ省メモリで文字列を扱うデータ構造「wavelet tree」
元となる記事が大変興味深かったのだけど、どうもサンプルコードが複雑すぎるのと、僕の解釈が悪いのか、記事中の説明がコードとつじつまが合わないところがあったので、自分で実装してみたしだい。
記事中ではハフマンコード化の話があるのだけど、あくまでそれは最適な圧縮率を出すための理論にしか過ぎなくて、
頻度の順番で文字をソートしておいて、文字ごとにその文字を1にしたビット列を格納していったほうが素直だろう。(元記事中は該当文字を0としたが1としたほうが操作しやすいと思う)
たとえば、文字列T = "abccbbabca"があったときその頻度は'b','c','a'の順番になる。このとき各文字ごとにビット列を作っていけばいい。
該当する文字を1、それ以外を0とする。すでに使用した文字はビット列に含まない。
b:0100110100->0100110100
c:0X11XX0X10->011010
a:1XXXXX1XX1->111
以上でビット配列ができた。

wavelet treeで文字に対してrank,selectを行うには、各ビット列をrankは上から下へ、selectは下から上へたどって、葉のときは1を、それ以外は0をrank,selectしてあげる。


たとえば、rank(3,'a')とする。rankでは上から下つまり根から葉の方へたどってあげる。
ちょっとrankの定義がよろしくないのでrank(n,1)=less(n+1,1)となるような関数を用意してあげる。(これはrankでは0から数えてn番目を含んでn+1個のビットを検査する定義のため)

rank(3,'a')->less(4,'a') rankは上から下へとたどる
(0100110100).less(4,0)=3 前から4つのビットの0を調べる
(011010).less(3,0)=1 前から3つのビットの0を調べる
(111).less(1,1)=1 前から1つのビットの1を調べる

葉までたどり着いたので最終的な結果1がrank(3,'a')の答えとなる。

select(1,'a')もやってみよう。今度は下から上つまり葉から根の方へたどってあげる。

select(1,'a') selectは下から上へとたどる
(111).select(1,1)=1 (0から数えて)1番目の1のインデックスを調べる
(011010).select(1,0)=3 (0から数えて)1番目の0のインデックスを調べる
(0100110100).select(3,0)=6 (0から数えて)3番目の0のインデックスを調べる

根まで数えたので最終的な結果6がselect(1,'a')の答えとなる。



2008/11/29 修正 struct pair_type{T ch;T count;};→struct pair_type{T ch;size_t count;};


ビット列
bit_vector.hpp

#ifndef EDL_BIT_VECTOR_HPP
#define EDL_BIT_VECTOR_HPP

#include <iostream>
#include <sstream>

#include <bitset>

namespace edl{

    template<class T>
    class basic_bit_vector{
    public:
        typedef T block_t;
        static const size_t BLOCK_SIZE = sizeof(block_t)*8;
        static const int    BLOCK_SHIFT = sizeof(block_t)/4*5;
        static const int    BLOCK_MASK  = (1<<BLOCK_SHIFT)-1;

    public:
        basic_bit_vector(size_t sz):sz_(sz),bsz_((sz + BLOCK_SIZE-1)/BLOCK_SIZE),a_(0){
            a_ = new block_t[bsz_];//
            memset(a_,0,sizeof(block_t)*bsz_);
        }

        basic_bit_vector(const basic_bit_vector& rhs):sz_(rhs.sz_),bsz_(rhs.bsz_),a_(0){
            a_ = new block_t[bsz_];//
            memcpy(a_,rhs.a_,sizeof(block_t)*bsz_);         
        }

        template<size_t Bits>
        basic_bit_vector(const std::bitset<Bits>& rhs):sz_(Bits),bsz_((Bits + BLOCK_SIZE-1)/BLOCK_SIZE),a_(0){
            a_ = new block_t[bsz_];//
            memset(a_,0,sizeof(block_t)*bsz_);
            for(size_t i=0;i<Bits;i++){
                if(rhs[i])this->set1(i);
                else      this->set0(i);
            }
        }

        ~basic_bit_vector(){
            delete[] a_;
        }

        basic_bit_vector& operator=(const basic_bit_vector& rhs){
            this->swap(basic_bit_vector(rhs));
            return *this;
        }

        //-------------------------------------
        void resize(size_t sz){
            if(sz_>=sz)return;
            basic_bit_vector tmp(sz);
            block_t* pb = tmp.get_ptr();
            memcpy(tmp.get_ptr(),get_ptr(),sizeof(block_t)*bsz_);
            this->swap(tmp);
        }
        //-------------------------------------
        int get(size_t pos)const{
            size_t div = get_div(pos);
            size_t rem = get_rem(pos);
            return (a_[div]>>rem)&1; 
        }
        void set(size_t pos, int bit){
            if(bit)set1(pos);
            else   set0(pos);
        }

        void set1(size_t pos){
            if(pos>sz_)resize(pos+1);
            size_t div = get_div(pos);
            size_t rem = get_rem(pos);
            set1(div,rem);
        }
        void set0(size_t pos){
            if(pos>sz_)resize(pos+1);
            size_t div = get_div(pos);
            size_t rem = get_rem(pos);
            set0(div,rem);
        }
        size_t graw_size(size_t pos){
            size_t n = sz_;
            while(pos>n){
                n*=2;
            }
            return n;
        }        

        int operator[](size_t pos)const{return get(pos);}

        size_t size()const{return sz_;}
        size_t block_size()const{return bsz_;}
        size_t allocation_size()const{return bsz_*sizeof(block_t);}

        const block_t* get_ptr()const{return a_;}
        block_t* get_ptr(){return a_;}

        block_t& get_block(size_t pos){
            return a_[pos];
        }   

        const block_t& get_block(size_t pos)const{
            return a_[pos];
        }   

    public:
        void swap(basic_bit_vector& rhs){
            std::swap(sz_,rhs.sz_);
            std::swap(bsz_,rhs.bsz_);
            std::swap(a_,rhs.a_);
        }

        basic_bit_vector& inv(){
            size_t bsz = bsz_;
            for(size_t i = 0;i<bsz;i++){
                a_[i] = ~a_[i];
            }
            return *this;
        }
    protected:
        static size_t get_div(size_t pos){
            //return pos/BLOCK_SIZE;
            return pos>>BLOCK_SHIFT;
        }
        static size_t get_rem(size_t pos){
            //return pos%BLOCK_SIZE;
            return pos&BLOCK_MASK;
        }

        void set1(size_t div, size_t rem){
            a_[div] |= 1<<rem;
        }
        void set0(size_t div, size_t rem){
            a_[div] &= ~(1<<rem);
        }


    protected:
        size_t sz_; //size of bits 
        size_t bsz_;//size of blocks
        block_t* a_;//
    };

    template<class T>
    basic_bit_vector<T> operator~(const basic_bit_vector<T>& rhs){
        return basic_bit_vector<T>(rhs).inv();
    }


    template<class T, class CharT, class Traits>
    std::basic_ostream<CharT, Traits>& operator<<(std::basic_ostream<CharT, Traits>& os, const basic_bit_vector<T>& rhs){
        std::basic_stringstream<CharT, Traits> s;

        s.flags(os.flags());
        s.imbue(os.getloc());
        s.precision(os.precision());

        size_t sz = rhs.size();
        for(size_t i = 0;i<sz;i++){
            if(rhs[i])s << '1';
            else      s << '0';
        }

        return os<<s.str();
    }

    template class basic_bit_vector<unsigned int>;
    typedef        basic_bit_vector<unsigned int> bit_vector;
}

#endif

rank,selectの実装
bit_vector_rs.hpp

#ifndef EDL_BIT_VECTOR_RS_HPP
#define EDL_BIT_VECTOR_RS_HPP

#include "bit_vector.hpp"

namespace edl{

    template<class T>
    class basic_bit_vector_rs:public basic_bit_vector<T>{
    public:
        typedef basic_bit_vector<T> base_type;
    public:
        basic_bit_vector_rs(const basic_bit_vector<T>& rhs):bit_vector(rhs){
            c_ = new size_t[cache_size()];
            recache();
        }
        basic_bit_vector_rs(const basic_bit_vector_rs<T>& rhs):bit_vector(rhs){
            c_ = new size_t[cache_size()];
            recache();
        }
        ~basic_bit_vector_rs(){
            delete[] c_;
        }
        
        basic_bit_vector_rs& operator=(const basic_bit_vector_rs<T>& rhs){
            this->swap(basic_bit_vector_rs<T>(rhs));
            return *this;
        }

        void swap(basic_bit_vector_rs& rhs){
            base_type::swap(rhs);
            std::swap(c_,rhs.c_);
        }

        size_t cache_size()const{
            return block_size()+1;
        }

        //-------------------------------------------------
        size_t size1()const{
            return c_[block_size()];
        }
        size_t size0()const{
            return size() -size1();
        }       
        //-------------------------------------------------
        void recache(){
            c_[0]=0;
            recache(0);
        }

        size_t rank(size_t pos, int bit){
            if(bit)return rank1(pos);
            else   return rank0(pos);
        }

        size_t rank1(size_t pos)const{
            return less1(pos+1);
        }

        size_t rank0(size_t pos)const{
            return less0(pos+1);
        }

        size_t less(size_t pos, int bit)const{
            if(bit)return less1(pos);
            else   return less0(pos);
        }
        //----------------------------
        size_t less1(size_t pos)const{
            return less1_fast(pos);
        }
        size_t less0(size_t pos)const{
            return pos-less1(pos);
        }
        //----------------------------
        void set(size_t pos, int bit){
            if(bit)set1(pos);
            else   set0(pos);
        }

        void set1(size_t pos){
            size_t div = get_div(pos);
            size_t rem = get_rem(pos);

            block_t old = a_[div];
            base_type::set1(div,rem);
            if(old != a_[div])recache(div);
        }
        void set0(size_t pos){
            size_t div = get_div(pos);
            size_t rem = get_rem(pos);

            block_t old = a_[div];
            base_type::set0(div,rem);
            if(old != a_[div])recache(div);
        }

        void set1_raw(size_t pos){
            size_t div = get_div(pos);
            size_t rem = get_rem(pos);
            base_type::set1(div,rem);
        }
        void set0_raw(size_t pos){
            size_t div = get_div(pos);
            size_t rem = get_rem(pos);
            base_type::set0(div,rem);
        }
        //-----------------------------
        size_t select(size_t pos, int bit)const{
            if(bit)return select1(pos);
            else   return select0(pos);
        }

        size_t select1(size_t pos)const{
            if(pos>=size1())return size();
            size_t a = 0;
            size_t b = block_size();
            while(a+1<b){
                size_t m = (a+b)>>1;
                size_t c = c_[m];
                if(c <= pos){a = m;}else{b=m;}
            }
            size_t bsz = c_[a];
            if(pos<bsz||pos-bsz>=sizeof(block_t)*8)return size();
            int rem = select_block(get_block(a),pos-bsz);
            if(rem<0)return size();
            return a*BLOCK_SIZE + rem;          
        }

        size_t select0(size_t pos)const{
            if(pos>=size0())return size();
            size_t a = 0;
            size_t b = block_size();
            while(a+1<b){
                size_t m = (a+b)>>1;
                size_t c = m*BLOCK_SIZE-c_[m];
                if(c <= pos){a = m;}else{b=m;}
            }
            size_t bsz = a*BLOCK_SIZE-c_[a];
            if(pos<bsz||pos-bsz>=sizeof(block_t)*8)return size();
            int rem = select_block(~get_block(a),pos-bsz);
            if(rem<0)return size();
            return a*BLOCK_SIZE + rem;          
        }
    protected:
        void recache(size_t begin){
            size_t r = c_[begin];
            size_t bsz = block_size();
            for(size_t i = begin;i<bsz;i++){
                r += pop_count(get_block(i));
                c_[i+1] = r;
            }
        }

        size_t less1_fast(size_t pos)const{//no pos
            size_t div = get_div(pos);
            size_t rem = get_rem(pos);
            return c_[div] + pop_count(get_block(div) & ((1<<rem)-1));
        }

        size_t less1_basic(size_t pos)const{
            size_t div = get_div(pos);
            size_t rem = get_rem(pos);
            size_t r = 0;
            for(size_t i = 0;i<div;i++){
                r += pop_count(get_block(i));
            }
            return r + pop_count(get_block(div) & ((1<<rem)-1));
        }
        static int pop_count_byte(int p){
            assert(0x0<=p&&p<=0xFF);
            static const unsigned char popCountArray[] = {
                0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4,1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,
                1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
                1,2,2,3,2,3,3,4,2,3,3,4,3,4,4,5,2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,
                2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
                2,3,3,4,3,4,4,5,3,4,4,5,4,5,5,6,3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,
                3,4,4,5,4,5,5,6,4,5,5,6,5,6,6,7,4,5,5,6,5,6,6,7,5,6,6,7,6,7,7,8
            };
            return popCountArray[p];
        }

        static int pop_count(block_t r){
            int nRet = 0;
            for(int i=0;i<sizeof(block_t);i++){
                nRet += pop_count_byte((r >> (i*8)) & 0xFF);
            }
            return nRet;
        }
        //-------------------------------------------------

        static int select_block(block_t r, int p){
            //if(pop_count(r)<=p)return -1;

            size_t nRet = 0;
            int i = 0;
            int b;
            while(i<sizeof(block_t)){
                b = (r >> (i*8)) & 0xFF;
                int t = pop_count_byte(b);
                if (p < t) break;
                p -= t;
                i++;
            }
            if(!(0<=p&&p<=8))return -1;
            nRet += 8*i;
            nRet += select_byte(b,p);
            return nRet;
        }
        static int select_byte(unsigned char byte, int p){
            assert(0<=p&&p<=8);
            
            int a = 0;
            int b = 8;
            for(int i = 0;i<3;i++){
                int m = (a+b)>>1;
                int c = pop_count_byte(byte&((1<<m)-1));
                if(c <= p){a = m;}else{b=m;}
            }
            
            return a;
        }

    protected:
        size_t* c_;     
    };

    template class basic_bit_vector_rs<unsigned int>;
    typedef basic_bit_vector_rs<bit_vector::block_t> bit_vector_rs;

}

#endif

wavelet tree
wavelet_tree.hpp

#ifndef EDL_WAVELET_TREE_HPP
#define EDL_WAVELET_TREE_HPP

#include "bit_vector_rs.hpp"

#include <utility>
#include <functional>
#include <algorithm>
#include <iostream>


namespace edl{

    template<class T>
    class basic_wavelet_tree{
    public:
        typedef T char_type;
        static const size_t TSZ = 1<<(sizeof(T)*8);

        basic_wavelet_tree(const T* s, size_t sz):sz_(sz){
            build(s, sz);
        }
        ~basic_wavelet_tree(){
            destroy();
        }

        size_t rank(size_t pos, T c)const{
            return less(pos+1,c);
        }

        size_t less(size_t pos, T c)const{
            int o = c2o_[c];
            if(!bits_[o])return sz_;
            size_t next = pos;
            for(int i = 0;i<o;i++){
                next = bits_[i]->less0(next);
            }
            next = bits_[o]->less1(next);
            return next;
        }

        size_t select(size_t pos, T c)const{
            int o = c2o_[c];
            if(!bits_[o])return sz_;
            size_t next = bits_[o]->select1(pos);
            for(int i = o-1;i>=0;i--){
                next = bits_[i]->select0(next);
            }
            return next;
        }

        void debug_print(){
            T o2c[TSZ];
            for(size_t i = 0;i<TSZ;i++){
                o2c[c2o_[i]]=i;
            }

            for(size_t i = 0;i<TSZ;i++){
                if(bits_[i]){
                    std::cout <<o2c[i] <<":"<<*bits_[i] << std::endl;
                }
            }
        }
    protected:
        void build(const T* s, size_t sz){
            T o2c[TSZ];   //
            sort_order(o2c, s, sz);

            std::auto_ptr<bit_vector_rs> bits[TSZ];

            bit_vector bUsed(TSZ);
            size_t bsz = sz;
            for(size_t o=0;o<TSZ;o++){
                if(bsz==0)break;
                T c = o2c[o];//
                bits[o].reset(new bit_vector_rs(bsz));
                {
                    size_t count=0;
                    size_t used =0;
                    for(size_t i=0;i<sz;i++){
                        if(bUsed.get(s[i]) == 0){
                            if(s[i]==c){
                                bits[o]->set1_raw(i-used);
                                count++;
                            }
                        }else{
                            used++;
                        }
                    }
                    bsz -= count;
                }
                bits[o]->recache();
                bUsed.set1(c);
            }
            //-----------------------
            for(size_t o=0;o<TSZ;o++){
                bits_[o]=bits[o].release();
                c2o_[o2c[o]]=o;
            }

        }

        void destroy(){
            for(size_t i = 0;i<TSZ;i++){
                if(bits_[i])delete bits_[i];
                bits_[i]=0;
            }
        }

        static void sort_order(T o2c[TSZ], const T* s, size_t sz){
            struct pair_type{T ch;size_t count;};
            struct pair_sorter:public std::binary_function<pair_type,pair_type,bool>{
                bool operator()(const pair_type& r, const pair_type& l)const{
                    return r.count > l.count;
                }
            };
            
            pair_type count[TSZ];
            memset(count,0,sizeof(pair_type)*TSZ);
            for(T i=0;i<TSZ;i++){
                count[i].ch = i;
            }
            
            //-------------------------
            for(size_t i = 0;i<sz;i++){
                count[s[i]].count++;
            }
            //-------------------------
            std::sort(count,count+TSZ,pair_sorter());
            //-------------------------
            for(T i=0;i<TSZ;i++){
                //c2o[count[i].ch] = i;
                o2c[i] = count[i].ch;
            }
        }
    protected:
        T c2o_[TSZ];//character->order
        bit_vector_rs* bits_[TSZ];  
        size_t sz_;
    };


    template class basic_wavelet_tree<char>;
    typedef basic_wavelet_tree<char> wavelet_tree;

}

#endif

テスト
bit_test.cpp

// bit_test.cpp : コンソール アプリケーションのエントリ ポイントを定義します。
//

#include "stdafx.h"

#include "type.h"

#include <iostream>
#include <sstream>

#include "edl/bit_vector.hpp"
#include "edl/bit_vector_rs.hpp"
#include "edl/wavelet_tree.hpp"



int _tmain(int argc, _TCHAR* argv[])
{

    std::stringstream ss;


    std::cout << ss.str() << std::endl;


    edl::bit_vector bv(50);

    bv.set(8,1);
    bv.set(9,1);
    bv.set(37,1);
    bv.set(33,1);

    edl::bit_vector_rs bvr(bv);
    edl::bit_vector_rs bvr2(bvr);

    std::cout << bvr << std::endl;
    std::cout << bvr2 << std::endl;

    bvr.set(40,1);
    bvr2 = bvr;

    std::cout << bvr << std::endl;
    std::cout << bvr2 << std::endl;

    const char* s = "abccbbabca";
    edl::wavelet_tree wt(s,strlen(s));
    std::cout << wt.rank(5,'b') << std::endl;
    std::cout << wt.rank(5,'c') << std::endl;
    std::cout << wt.rank(5,'a') << std::endl;

    std::cout << wt.select(1,'b') << std::endl;
    std::cout << wt.select(1,'c') << std::endl;
    std::cout << wt.select(1,'a') << std::endl;

    wt.debug_print();

    
    return 0;
}