转载

Java实现双数组Trie树(DoubleArrayTrie,DAT)

传统的Trie实现简单,但是占用的空间实在是难以接受,特别是当字符集不仅限于英文26个字符的时候,膨胀起来的空间根本无法接受。 双数组Trie就是优化了空间的Trie树,原理本文就不讲了,请参考论文 An Efficient Implementation of Trie Structures,本程序的编写也是参考这篇论文的。 关于几点论文没有提及的细节和与论文不一致的实现: 1.对于插入字符串,如果有一个字符串是另一个字符串的前缀的话,我是将结束符也作为一条边,产生一个新的结点,这个新结点的Base我置为0。 所以一个字符串结束也有2种情况:一个是Base值为负,存储剩余字符(可能只有一个结束符)到Tail数组;另一个是Base为0。 所以在查询的时候要考虑一下这两种情况。 2.对于第一种冲突(论文中的Case 3),可能要将Tail中的字符串取出一部分,作为边放到索引中。论文是使用将尾串左移的方式,我的方式是直接修改Base值,而不是移动尾串。 下面是java实现的代码,可以处理相同字符串插入、前缀字符串插入等情况。
  1. /*
  2.  * Name:   Double Array Trie
  3.  * Author: Yaguang Ding
  4.  * Mail: dingyaguang117@gmail.com
  5.  * Blog: blog.csdn.net/dingyaguang117
  6.  * Date:   2012/5/21
  7.  * Note: a word may end in either of these two cases:
  8.  * 1. Base[cur_p] == pos  ( pos<0 and Tail[-pos] == 'END_CHAR' )
  9.  * 2. Check[Base[cur_p] + Code('END_CHAR')] ==  cur_p
  10.  */
  11. import java.util.ArrayList;
  12. import java.util.HashMap;
  13. import java.util.Map;
  14. import java.util.Arrays;
  /**
   * Double-array trie (DAT) with a tail array for compressed single-branch suffixes.
   *
   * <p>Node layout (a node is an index into {@code Base}/{@code Check}; the root is index 1):
   * <ul>
   *   <li>{@code Check[n]} is the index of the parent of node {@code n}.</li>
   *   <li>{@code Base[n] > 0}: internal node; the child reached by character code {@code c}
   *       lives at index {@code Base[n] + c}.</li>
   *   <li>{@code Base[n] < 0}: the remainder of the word (including the terminator) is stored
   *       in {@code Tail} starting at index {@code -Base[n]}.</li>
   *   <li>{@code Base[n] == 0 && Check[n] != 0}: terminator node, i.e. a word ends at its parent.</li>
   *   <li>{@code Base[n] == 0 && Check[n] == 0}: free slot.</li>
   * </ul>
   *
   * <p>Character codes are assigned on demand by {@link #Insert(String)}: the terminator has
   * code 1, {@code 'a'..'z'} have codes 2..27, any other character gets the next unused code.
   * Words must not contain the terminator character {@code '\0'}.
   */
  public class DoubleArrayTrie {
      /** Terminator appended to every inserted word. */
      final char END_CHAR = '\0';
      /** Initial length of the Base, Check and Tail arrays. */
      final int DEFAULT_LEN = 1024;
      /** Character code reserved for {@link #END_CHAR}. */
      private final int END_CODE = 1;

      int Base[]  = new int [DEFAULT_LEN];
      int Check[] = new int [DEFAULT_LEN];
      char Tail[] = new char [DEFAULT_LEN];
      /** Next unused index in {@code Tail}; index 0 is never used so that -Pos is always negative. */
      int Pos = 1;
      /** Character -> code (codes start at 1). */
      Map<Character ,Integer> CharMap = new HashMap<Character,Integer>();
      /** Code -> character; index 0 is a placeholder because codes start at 1. */
      ArrayList<Character> CharList = new ArrayList<Character>();

      /** Creates an empty trie with the terminator and the letters a-z pre-registered. */
      public DoubleArrayTrie()
      {
          Base[1] = 1;
          CharMap.put(END_CHAR, END_CODE);
          CharList.add(END_CHAR);   // placeholder for the unused code 0
          CharList.add(END_CHAR);   // code 1
          for (int i = 0; i < 26; ++i)
          {
              CharMap.put((char) ('a' + i), CharMap.size() + 1);
              CharList.add((char) ('a' + i));
          }
      }

      /** Doubles the length of Base and Check (they always have the same length). */
      private void Extend_Array()
      {
          Base = Arrays.copyOf(Base, Base.length * 2);
          Check = Arrays.copyOf(Check, Check.length * 2);
      }

      /** Grows Base/Check until {@code index} is a valid position in both. */
      private void EnsureArrayCapacity(int index)
      {
          while (index >= Base.length)
          {
              Extend_Array();
          }
      }

      /** Doubles the length of Tail. */
      private void Extend_Tail()
      {
          Tail = Arrays.copyOf(Tail, Tail.length * 2);
      }

      /** Returns the code of {@code c}, registering a new code if the character is unknown. */
      private int GetCharCode(char c)
      {
          Integer code = CharMap.get(c);
          if (code == null)
          {
              code = CharMap.size() + 1;
              CharMap.put(c, code);
              CharList.add(c);
          }
          return code;
      }

      /** Returns the code of {@code c}, or 0 if the character was never registered (read-only). */
      private int LookupCharCode(char c)
      {
          Integer code = CharMap.get(c);
          return code == null ? 0 : code;
      }

      /** True if {@code index} is a valid slot whose parent is {@code parent}. */
      private boolean IsChild(int index, int parent)
      {
          return index > 0 && index < Check.length && Check[index] == parent;
      }

      /**
       * Copies {@code s[p..]} to Tail starting at {@code Pos}, growing Tail as needed.
       *
       * @return the first unused Tail index after the copy (the caller stores it in Pos)
       */
      private int CopyToTailArray(String s, int p)
      {
          int next = Pos;
          while (s.length() - p + 1 > Tail.length - Pos)
          {
              Extend_Tail();
          }
          for (int i = p; i < s.length(); ++i)
          {
              Tail[next++] = s.charAt(i);
          }
          return next;
      }

      /**
       * Finds the smallest base {@code b >= 1} such that every slot {@code b + code} is free,
       * growing Base/Check as needed.
       */
      private int x_check(int... codes)
      {
          for (int base = 1; ; ++base)
          {
              boolean allFree = true;
              for (int code : codes)
              {
                  int slot = base + code;
                  EnsureArrayCapacity(slot);
                  if (Base[slot] != 0 || Check[slot] != 0)
                  {
                      allFree = false;
                      break;
                  }
              }
              if (allFree) return base;
          }
      }

      /** Returns the character codes of all children of node {@code p}, in ascending code order. */
      private ArrayList<Integer> GetChildList(int p)
      {
          ArrayList<Integer> ret = new ArrayList<Integer>();
          if (Base[p] <= 0) return ret;   // tail, terminator and free nodes have no children
          for (int code = 1; code <= CharMap.size(); ++code)
          {
              int slot = Base[p] + code;
              if (slot >= Check.length) break;
              if (Check[slot] == p)
              {
                  ret.add(code);
              }
          }
          return ret;
      }

      /** True if the Tail content starting at {@code start} begins with {@code s2}. */
      private boolean TailContainString(int start, String s2)
      {
          if (start + s2.length() > Tail.length) return false;   // cannot match past the array end
          for (int i = 0; i < s2.length(); ++i)
          {
              if (s2.charAt(i) != Tail[start + i]) return false;
          }
          return true;
      }

      /** True if the Tail content starting at {@code start} is exactly {@code s2} plus the terminator. */
      private boolean TailMatchString(int start, String s2)
      {
          return TailContainString(start, s2 + END_CHAR);
      }

      /** Reads the Tail content starting at {@code start} up to (excluding) the terminator. */
      private String ReadTail(int start)
      {
          StringBuilder sb = new StringBuilder();
          for (int i = start; i < Tail.length && Tail[i] != END_CHAR; ++i)
          {
              sb.append(Tail[i]);
          }
          return sb.toString();
      }

      /**
       * Turns the free slot {@code node} into a leaf below {@code parent} for character
       * {@code s[i]}: a terminator becomes a Base==0 node, any other character gets the rest
       * of the word ({@code s[i+1..]}, terminator included) stored in Tail.
       */
      private void AttachLeaf(int node, int parent, String s, int i)
      {
          Check[node] = parent;
          if (s.charAt(i) == END_CHAR)
          {
              // Nothing is written to Tail for a terminator, so it must not point into Tail.
              Base[node] = 0;
          }
          else
          {
              Base[node] = -Pos;
              Pos = CopyToTailArray(s, i + 1);
          }
      }

      /**
       * Moves all children of {@code parent} to a new base that also has a free slot for
       * {@code newCode}, and re-parents the grandchildren to the moved nodes.
       */
      private void Relocate(int parent, int newCode)
      {
          ArrayList<Integer> children = GetChildList(parent);
          int[] codes = new int[children.size() + 1];
          for (int k = 0; k < children.size(); ++k)
          {
              codes[k] = children.get(k);
          }
          codes[children.size()] = newCode;

          int origin_base = Base[parent];
          int avail_base = x_check(codes);
          Base[parent] = avail_base;
          for (int k = 0; k < children.size(); ++k)
          {
              int from = origin_base + codes[k];
              int to = avail_base + codes[k];
              Base[to] = Base[from];
              Check[to] = Check[from];
              // GetChildList returns nothing unless Base[from] > 0, i.e. the node has children.
              for (int grandChild : GetChildList(from))
              {
                  Check[Base[from] + grandChild] = to;
              }
              Base[from] = 0;
              Check[from] = 0;
          }
      }

      /**
       * Inserts a word. Inserting a word that is already present is a no-op.
       *
       * @param s the word; must not be null and must not contain {@code '\0'}
       * @throws IllegalArgumentException if {@code s} is null or contains the terminator
       */
      public void Insert(String s) throws Exception
      {
          if (s == null)
          {
              throw new IllegalArgumentException("word must not be null");
          }
          if (s.indexOf(END_CHAR) >= 0)
          {
              throw new IllegalArgumentException("word must not contain the terminator character \\0");
          }
          s += END_CHAR;
          int pre_p = 1;
          for (int i = 0; i < s.length(); ++i)
          {
              int code = GetCharCode(s.charAt(i));
              int cur_p = Base[pre_p] + code;
              EnsureArrayCapacity(cur_p);

              // Free slot: hang the rest of the word below pre_p.
              if (Base[cur_p] == 0 && Check[cur_p] == 0)
              {
                  AttachLeaf(cur_p, pre_p, s, i);
                  return;
              }

              // Conflict 2: the slot belongs to another parent, so move pre_p's children.
              if (Check[cur_p] != pre_p)
              {
                  Relocate(pre_p, code);
                  AttachLeaf(Base[pre_p] + code, pre_p, s, i);
                  return;
              }

              // From here on cur_p is an existing child of pre_p.
              if (Base[cur_p] > 0)
              {
                  pre_p = cur_p;
                  continue;
              }
              if (Base[cur_p] == 0)
              {
                  return;   // terminator node already present: duplicate word
              }

              // Conflict 1: cur_p holds a compressed suffix in Tail that must be split.
              // A tail node is never a terminator node, so s[i] != END_CHAR and i+1 is valid.
              int head = -Base[cur_p];
              char tailChar = Tail[head];
              char nextChar = s.charAt(i + 1);
              if (tailChar == END_CHAR && nextChar == END_CHAR)
              {
                  return;   // duplicate word
              }
              int nextCode = GetCharCode(nextChar);
              if (tailChar == nextChar)
              {
                  // Shared character: pull it out of the tail into the index and keep descending.
                  int avail_base = x_check(nextCode);
                  Base[cur_p] = avail_base;
                  Check[avail_base + nextCode] = cur_p;
                  Base[avail_base + nextCode] = -(head + 1);
                  pre_p = cur_p;
                  continue;
              }

              // The characters differ (one of them may be the terminator): branch here.
              int tailCode = GetCharCode(tailChar);
              int avail_base = x_check(nextCode, tailCode);
              Base[cur_p] = avail_base;

              Check[avail_base + tailCode] = cur_p;
              Base[avail_base + tailCode] = (tailChar == END_CHAR) ? 0 : -(head + 1);

              Check[avail_base + nextCode] = cur_p;
              if (nextChar == END_CHAR)
              {
                  Base[avail_base + nextCode] = 0;
              }
              else
              {
                  Base[avail_base + nextCode] = -Pos;
                  Pos = CopyToTailArray(s, i + 2);
              }
              return;
          }
      }

      /**
       * Tests whether {@code word} was inserted. This is a pure query: unknown characters
       * are not registered.
       *
       * @return true if the exact word is stored; false otherwise (also for null)
       */
      public boolean Exists(String word)
      {
          if (word == null || word.indexOf(END_CHAR) >= 0) return false;
          int node = 1;
          for (int i = 0; i < word.length(); ++i)
          {
              int code = LookupCharCode(word.charAt(i));
              if (code == 0) return false;
              int next = Base[node] + code;
              if (!IsChild(next, node)) return false;
              if (Base[next] < 0)
              {
                  return TailMatchString(-Base[next], word.substring(i + 1));
              }
              node = next;
          }
          // The word ends exactly at an internal node: it is stored iff a terminator child exists.
          return IsChild(Base[node] + END_CODE, node);
      }

      /** Result of {@link #Find(String)}: the deepest matched node and the text leading to it. */
      class FindStruct
      {
          /** Index of the matched node, or -1 if the prefix is not in the trie. */
          int p;
          /** The part of the query consumed up to and including node {@code p}. */
          String prefix="";
      }

      /**
       * Walks the trie along {@code word}, treating it as a prefix.
       * If the walk ends inside a tail, {@code p} is the tail node and {@code prefix}
       * is the text up to that node (the tail holds the rest).
       */
      private FindStruct Find(String word)
      {
          FindStruct fs = new FindStruct();
          int node = 1;
          for (int i = 0; i < word.length(); ++i)
          {
              fs.prefix = word.substring(0, i + 1);
              int code = LookupCharCode(word.charAt(i));
              int next = (code == 0) ? -1 : Base[node] + code;
              if (!IsChild(next, node))
              {
                  fs.p = -1;
                  return fs;
              }
              if (Base[next] < 0)
              {
                  fs.p = TailContainString(-Base[next], word.substring(i + 1)) ? next : -1;
                  return fs;
              }
              node = next;
          }
          fs.p = node;   // the root for an empty prefix
          return fs;
      }

      /**
       * Collects every word suffix stored below node {@code index}.
       *
       * @param index a node index
       * @return the suffixes, without terminator characters; {@code [""]} for a terminator node
       */
      public ArrayList<String> GetAllChildWord(int index)
      {
          ArrayList<String> result = new ArrayList<String>();
          if (Base[index] == 0)
          {
              result.add("");
              return result;
          }
          if (Base[index] < 0)
          {
              result.add(ReadTail(-Base[index]));
              return result;
          }
          for (int code : GetChildList(index))
          {
              // The terminator edge contributes no visible character.
              String edge = (code == END_CODE) ? "" : String.valueOf(CharList.get(code));
              for (String suffix : GetAllChildWord(Base[index] + code))
              {
                  result.add(edge + suffix);
              }
          }
          return result;
      }

      /**
       * Returns every stored word that starts with {@code word}.
       *
       * @param word the prefix; an empty prefix matches every stored word
       * @return the matching words (empty if none, or if {@code word} is null)
       */
      public ArrayList<String> FindAllWords(String word)
      {
          ArrayList<String> result = new ArrayList<String>();
          if (word == null || word.indexOf(END_CHAR) >= 0) return result;
          FindStruct fs = Find(word);
          if (fs.p == -1) return result;
          for (String suffix : GetAllChildWord(fs.p))
          {
              result.add(fs.prefix + suffix);
          }
          return result;
      }
  }
测  试
  1. import java.io.BufferedReader;
  2. import java.io.FileInputStream;
  3. import java.io.IOException;
  4. import java.io.InputStream;
  5. import java.io.InputStreamReader;
  6. import java.util.ArrayList;
  7. import java.util.Scanner;
  8. import javax.xml.crypto.Data;
  9. public class Main {
  10.     public static void main(String[] args) throws Exception {
  11.         ArrayList<String> words = new ArrayList<String>();
  12.         BufferedReader reader = new BufferedReader(new InputStreamReader(new FileInputStream("E:/兔子的试验学习中心[课内]/ACM大赛/ACM第四届校赛/E命令提示/words3.dic")));
  13.         String s;
  14.         int num = 0;
  15.         while((s=reader.readLine()) != null)
  16.         {
  17.             words.add(s);
  18.             num ++;
  19.         }
  20.         DoubleArrayTrie dat = new DoubleArrayTrie();
  21.         for(String word: words)
  22.         {
  23.             dat.Insert(word);
  24.         }
  25.         System.out.println(dat.Base.length);
  26.         System.out.println(dat.Tail.length);
  27.         Scanner sc = new Scanner(System.in);
  28.         while(sc.hasNext())
  29.         {
  30.             String word = sc.next();
  31.             System.out.println(dat.Exists(word));
  32.             System.out.println(dat.FindAllWords(word));
  33.         }
  34.     }
  35. }
下面是测试结果,构造6W英文单词的DAT,大概需要20秒 我增长数组的时候是每次长度增加到2倍,初始1024 Base和Check数组的长度为131072 Tail的长度为262144 Java实现双数组Trie树(DoubleArrayTrie,DAT) 原文地址:Java实现双数组Trie树(DoubleArrayTrie,DAT)  
正文到此结束
Loading...