Not logged in.  Login/Logout/Register | List snippets | | Create snippet | Upload image | Upload data

379
LINES

< > BotCompany Repo | #1001028 // Token prediction, multiple predictors (v3, including start trees, developing)

JavaX source code [tags: use-pretranspiled] - run with: x30.jar

Libraryless. Click here for Pure Java version (2021L/14K/49K).

1  
!752
2  
3  
static S corpusID = "#1001010";
4  
static int numSnippets = 300;
5  
static boolean showGUI = true;
6  
static int maxCharsGUI = 500000;
7  
static boolean allTokens = true;
8  
9  
static Collector collector;
10  
static L<F> files;
11  
static Map<F, Set<int>> predicted;
12  
13  
// a file to learn from
14  
static class F {
15  
  String id, name;
16  
  L<S> tok;
17  
}
18  
19  
// a predictor
20  
static abstract class P {
21  
  int seen;
22  
  S file;
23  
  
24  
  abstract S read(S file, L<S> tok);
25  
  abstract P derive(); // clone & reset counter for actual use
26  
  abstract P clear();
27  
  
28  
  void prepare(S file) {
29  
    if (!eq(file, this.file)) {
30  
      seen = 0;
31  
      this.file = file;
32  
    }
33  
  }
34  
}
35  
36  
// A prioritized list of predictors: the first non-null guess wins.
// Every member is still consulted on every call. The predictors in this file
// only learn inside read() (catch-up loop over the history), so stopping at
// the first hit would make lower-priority members miss the tail of each file.
static class Chain extends P {
  new L<P> list;

  *() {}
  *(L<P> *list) {}
  *(P... a) {
    // keep our own mutable list so that add() keeps working
    for (P p : a) list.add(p);
  }

  void add(P p) { list.add(p); }

  S read(S file, L<S> tok) {
    S result = null;
    for (P p : list) {
      S s = p.read(file, tok); // always called: lets p learn the new tokens
      if (result == null) result = s;
    }
    return result;
  }

  // derived copy of every member, same order
  P derive() {
    new Chain c;
    for (P p : list)
      c.add(p.derive());
    return c;
  }

  // untrained copy of every member, same order
  P clear() {
    new Chain c;
    for (P p : list)
      c.add(p.clear());
    return c;
  }
}
67  
  
68  
// n-gram predictor: looks up the last n tokens and returns the token that
// most recently followed that exact context
static class Tuples extends P {
  int n;
  // context of n tokens -> following token (later occurrences overwrite)
  Map<L<S>,S> map = new HashMap<L<S>,S>();

  *(int *n) {}

  S read(S file, L<S> tok) {
    prepare(file);

    // learn from every history token not consumed yet
    while (seen < tok.size()) {
      int idx = seen++; // index of the token being learned
      if (idx >= n)
        map.put(new ArrayList<S>(tok.subList(idx - n, idx)), tok.get(idx));
    }

    int size = tok.size();
    return size < n ? null : map.get(new ArrayList<S>(tok.subList(size - n, size)));
  }

  // slow: copies the whole table
  P oldDerive() {
    Tuples copy = new Tuples(n);
    copy.map.putAll(map);
    return copy; // copy.seen starts at 0, which is what we want
  }

  // fast: wraps the existing table in a DerivedHashMap (#1001027) instead of copying it
  P derive() {
    Tuples copy = new Tuples(n);
    copy.map = new DerivedHashMap<L<S>,S>(map);
    return copy;
  }

  P clear() {
    return new Tuples(n);
  }
}
109  
110  
!include #1001027 // DerivedHashMap
111  
112  
static class Node {
113  
  String token;
114  
  float count;
115  
  new L<Node> next;
116  
  
117  
  *() {} // for clone method
118  
  
119  
  *(S *token) {}
120  
  
121  
  Node find(S token) {
122  
    for (Node n : next)
123  
      if (n.token.equals(token))
124  
        ret n;
125  
    ret null;
126  
  }
127  
  
128  
  Node bestNext() {
129  
    float bestCount = 0f;
130  
    Node best = null;
131  
    for (Node n : next)
132  
      if (best == null || n.count > best.count) {
133  
        best = n;
134  
        bestCount = n.count;
135  
      }
136  
    ret best;
137  
  }
138  
}
139  
140  
// Follows the current file's tokens from its very start down a tree of all
// file prefixes seen so far and predicts the most frequent continuation.
static class StartTree extends P {
  Node tree = new Node("");
  Node node;      // position in tree for the current file
  boolean nonmod; // read-only mode (derived predictor): the tree is never changed
  boolean lost;   // read-only mode only: the history left the known tree

  S read(S file, L<S> tok) {
    // new file, or very first call (even with a null file key): restart at the root
    if (node == null || !eq(file, this.file)) {
      seen = 0;
      this.file = file;
      node = tree;
      lost = false;
    }

    while (tok.size() > seen) {
      S t = tok.get(seen++);
      if (lost) continue;
      Node child = node.find(t);
      if (child == null) {
        if (nonmod) {
          // unknown prefix: a read-only tree has nothing more to offer
          lost = true;
          continue;
        }
        node.next.add(child = new Node(t));
      }
      if (!nonmod) child.count++;
      node = child; // read-only mode still walks the tree, it just doesn't count
    }

    if (lost) ret null;
    Node n = node.bestNext();
    ret n != null ? n.token : null;
  }

  // it's a hack - the derived predictor shares the tree but never changes it
  P derive() {
    new StartTree p;
    p.nonmod = true;
    p.tree = tree;
    return p;
  }

  P clear() {
    return new StartTree;
  }
}
178  
179  
p {
180  
  files = makeCorpus();
181  
  print("Files in corpus: " + files.size());
182  
  
183  
  print("Learning...");
184  
  collector = new Collector;
185  
  //test(new Tuples(1));
186  
  test(new StartTree);
187  
  test(new Chain(new Tuples(4), new Tuples(3), new Tuples(2), new Tuples(1), new StartTree));
188  
189  
  print("Learning done.");
190  
  printVMSize();
191  
  if (collector.winner != null && showGUI)
192  
    window();
193  
}
194  
195  
static int points = 0, total = 0;
196  
197  
// train & evaluate a predictor
198  
static void test(P p) {
199  
  int lastPercent = 0;
200  
  predicted = new HashMap;
201  
  points = 0;
202  
  total = 0;
203  
  for (int ii = 0; ii < files.size(); ii++) {
204  
    F f = files.get(ii);
205  
    
206  
    testFile(p, f);
207  
    
208  
    int percent = roundUpTo(10, (int) (ii*100L/files.size()));
209  
    if (percent > lastPercent) {
210  
      print("Learning " + percent + "% done.");
211  
      lastPercent = percent;
212  
    }
213  
  }
214  
  double score = points*100.0/total;
215  
  collector.add(p, score);
216  
}
217  
218  
// Run p over f token by token: p sees the history and guesses the next token.
// Adds the lengths of all visited tokens to total, of correct guesses to points,
// and stores the indices of correct guesses in predicted[f].
static void testFile(P p, F f) {
  new TreeSet<int> pred;
  new L<S> history;
  // Predictors reset their position only when the file key changes. For
  // mysqldump corpora f.name is a snippet title, which need not be unique, so
  // prefer the snippet id (presumably unique) when there is one.
  S key = f.id != null ? f.id : f.name;
  int start = allTokens ? 0 : 1, step = allTokens ? 1 : 2;
  for (int i = start; i < f.tok.size(); i += step) {
    S t = f.tok.get(i);
    S guess = p.read(key, history);
    total += t.length();
    if (t.equals(guess)) {
      pred.add(i);
      points += t.length();
    }
    history.add(t);
  }
  predicted.put(f, pred);
}
234  
235  
!include #1000989 // SnippetDB
236  
237  
static L<F> makeCorpus() {
238  
  S name = getSnippetTitle(corpusID);
239  
  if (name.toLowerCase().indexOf(".zip") >= 0)
240  
    return makeCorpus_zip();
241  
  else
242  
    return makeCorpus_mysqldump();
243  
}
244  
245  
// Read up to numSnippets non-directory entries of the corpus zip as UTF-8
// text and tokenize them. The archive is closed even if reading fails.
static L<F> makeCorpus_zip() ctex {
  new L<F> files;
  ZipFile zipFile = new ZipFile(loadLibrary(corpusID));
  try {
    Enumeration<? extends ZipEntry> entries = zipFile.entries();
    while (entries.hasMoreElements() && files.size() < numSnippets) {
      ZipEntry entry = entries.nextElement();
      if (entry.isDirectory()) continue;
      // TODO: try to skip binary files?

      new F f;
      f.name = entry.getName();
      f.tok = internAll(javaTok(makeCorpus_zip_readEntry(zipFile, entry)));
      files.add(f);
    }
  } finally {
    zipFile.close();
  }
  return files;
}

// read one zip entry as UTF-8 text; every line is re-joined with '\n'
// (so line breaks are normalized and the text always ends with '\n' if non-empty)
private static S makeCorpus_zip_readEntry(ZipFile zipFile, ZipEntry entry) throws IOException {
  BufferedReader reader = new BufferedReader(
    new InputStreamReader(zipFile.getInputStream(entry), "UTF-8"));
  try {
    new StringBuilder builder;
    String line;
    while ((line = reader.readLine()) != null)
      builder.append(line).append('\n');
    return builder.toString();
  } finally {
    reader.close(); // also closes the entry stream
  }
}
276  
277  
// load up to numSnippets snippets from a mysqldump corpus, in sn_created order
static L<F> makeCorpus_mysqldump() {
  new L<F> files;
  SnippetDB db = new SnippetDB(corpusID);
  List<List<S>> rows = db.rowsOrderedBy("sn_created");
  int count = Math.min(rows.size(), numSnippets);
  // one row per iteration (the old extra ++i skipped every second snippet)
  for (int i = 0; i < count; i++) {
    List<S> row = rows.get(i);
    new F f;
    f.id = db.getField(row, "sn_id");
    f.name = db.getField(row, "sn_title");
    f.tok = internAll(javaTok(db.getField(row, "sn_text")));
    files.add(f);
  }
  return files;
}
292  
293  
static class Collector {
294  
  P winner;
295  
  double bestScore = -1;
296  
  Map<F, Set<int>> predicted;
297  
298  
  void add(P p, double score) {
299  
    if (winner == null || score > bestScore) {
300  
      winner = p;
301  
      bestScore = score;
302  
      //S name = shorten(structure(p), 100);
303  
      S name = p.getClass().getName();
304  
      print("New best score: " + formatDouble(score, 2) + "% (" + name + ")");
305  
      predicted = main.predicted;
306  
    }
307  
  }
308  
}
309  
310  
// Show the corpus files with correctly predicted tokens in green (others gray).
// The button steps through the files backwards, starting with the last one,
// and shows file name, position and per-file score.
static void window() {
  if (files.isEmpty()) return; // nothing to show (files.get(-1) would throw)

  JFrame jf = new JFrame("Predicted = green");
  Container cp = jf.getContentPane();

  final JButton btnNext = new JButton("Next");

  final JTextPane pane = new JTextPane();
  //pane.setFont(loadFont("#1000993", 24));

  JScrollPane scrollPane = new JScrollPane(pane);
  cp.add(scrollPane, BorderLayout.CENTER);

  class X {
    int ii;

    void y() ctex {
      ii = ii == 0 ? files.size()-1 : ii-1;
      F f = files.get(ii);
      Set<int> pred = collector.predicted.get(f);
      if (pred == null) pred = Collections.emptySet(); // untested file: show all gray

      StyledDocument doc = new DefaultStyledDocument();

      // only render the tail of the file, about maxCharsGUI characters
      L<S> tok = f.tok;
      int i = tok.size(), len = 0;
      while (len <= maxCharsGUI && i > 0) {
        --i;
        len += tok.get(i).length();
      }

      for (; i < tok.size(); i++) {
        if (tok.get(i).length() == 0) continue;
        boolean green = pred.contains(i);
        SimpleAttributeSet set = new SimpleAttributeSet();
        StyleConstants.setForeground(set, green ? Color.green : Color.gray);
        doc.insertString(doc.getLength(), tok.get(i), set);
      }

      pane.setDocument(doc);
      double score = getScore(pred, tok);
      btnNext.setText(f.name + " (" + (ii+1) + "/" + files.size() + ") - " + (int) score + " %");
    }
  }
  final new X x;

  btnNext.addActionListener(actionListener {
    x.y();
  });
  cp.add(btnNext, BorderLayout.NORTH);

  x.y();

  jf.setBounds(100, 100, 600, 600);
  jf.setVisible(true);
}
367  
368  
!include #1001032 // clone function
369  
370  
// Percentage (0-100) of the characters of tok that belong to tokens whose
// index is in pred. Returns 0 instead of NaN when tok has no characters.
static double getScore(Set<int> pred, L<S> tok) {
  int total = 0, score = 0;
  for (int i = 0; i < tok.size(); i++) {
    int n = tok.get(i).length();
    total += n;
    if (pred.contains(i))
      score += n;
  }
  ret total == 0 ? 0 : score*100.0/total;
}

Author comment

Began life as a copy of #1001025

download  show line numbers  debug dex  old transpilations   

Travelled to 16 computer(s): aoiabmzegqzx, bhatertpkbcr, cbybwowwnfue, cfunsshuasjs, ddnzoavkxhuk, gwrvuhgaqvyk, ishqpsrjomds, lpdgvwnxivlt, mqqgnosmbjvj, onxytkatvevr, pyentgdyhuwx, pzhvpgtvlbxg, teubizvjbppd, tslmcundralx, tvejysmllsmz, vouqrxazstgt

No comments. add comment

Snippet ID: #1001028
Snippet name: Token prediction, multiple predictors (v3, including start trees, developing)
Eternal ID of this version: #1001028/1
Text MD5: 6a1e871348685598ee089e9642400ab0
Transpilation MD5: e27ea70ea3a336e2c3da51a230a32af6
Author: stefan
Category:
Type: JavaX source code
Public (visible to everyone): Yes
Archived (hidden from active list): No
Created/modified: 2016-06-15 14:32:50
Source code size: 8674 bytes / 379 lines
Pitched / IR pitched: No / No
Views / Downloads: 766 / 1085
Referenced in: [show references]