Not logged in. Login/Logout/Register | List snippets | Create snippet | Upload image | Upload data

379
LINES

< > BotCompany Repo | #1001028 // Token prediction, multiple predictors (v3, including start trees, developing)

JavaX source code [tags: use-pretranspiled] - run with: x30.jar

Libraryless. Click here for Pure Java version (2021L/14K/49K).

!752

// --- Configuration ---
static S corpusID = "#1001010";    // snippet holding the corpus: a .zip library or a SnippetDB dump
static int numSnippets = 300;      // upper bound on the number of files loaded
static boolean allTokens = true;   // false: only every second token (odd indices) is predicted
static boolean showGUI = true;     // open the result viewer when a winner exists
static int maxCharsGUI = 500000;   // roughly how many trailing characters per file the viewer shows

// --- Shared state ---
static Collector collector;         // best predictor found so far
static L<F> files;                 // the loaded corpus
static Map<F, Set<int>> predicted; // per file: indices of tokens the predictor under test got right

// A file to learn from: its identity plus its token list.
static class F {
  String id;    // snippet ID (only set for the SnippetDB corpus)
  String name;  // snippet title or zip entry name
  L<S> tok;     // interned javaTok() tokens of the file's text
}

// A token predictor. read() receives the token history of the current file
// and returns its guess for the next token, or null to abstain.
static abstract class P {
  int seen;   // number of history tokens a subclass has already processed for the current file
  S file;     // key of the file currently being read
  
  abstract S read(S file, L<S> tok);
  abstract P derive(); // clone & reset counter for actual use
  abstract P clear();  // fresh, untrained predictor of the same configuration
  
  // Starts over (seen = 0) whenever a different file key is passed in.
  void prepare(S file) {
    if (eq(file, this.file)) return;
    this.file = file;
    seen = 0;
  }
}

// A predictor that asks its members in order and answers with the first
// non-null prediction. Members after the one that answers are not called
// on that step; they catch up on the history the next time they are asked.
static class Chain extends P {
  new L<P> list;
  
  *() {}
  *(L<P> *list) {} // uses the given list directly (no copy)
  *(P... a) { list = asList(a); }
  
  // Appends a member to the end of the chain.
  void add(P p) { list.add(p); }
  
  S read(S file, L<S> tok) {
    S prediction = null;
    for (int k = 0; prediction == null && k < list.size(); k++)
      prediction = list.get(k).read(file, tok);
    return prediction;
  }
  
  // Chain of the members' derived copies, in the same order.
  P derive() {
    Chain copy = new Chain();
    for (P member : list) copy.add(member.derive());
    return copy;
  }
  
  // Chain of fresh, untrained copies of the members, in the same order.
  P clear() {
    Chain fresh = new Chain();
    for (P member : list) fresh.add(member.clear());
    return fresh;
  }
}
  
// n-gram predictor: for every run of n consecutive history tokens it remembers
// the token that followed (a later occurrence overwrites an earlier one) and
// predicts the successor of the last n tokens of the current history.
static class Tuples extends P {
  Map<L<S>,S> map = new HashMap<L<S>,S>();
  int n; // context length

  *(int *n) {
  }
  
  S read(S file, L<S> tok) {
    prepare(file);
    
    // Learn from each token appended since the previous call: the n tokens
    // before index end-1 form the key, the token at end-1 is the value.
    while (seen < tok.size()) {
      int end = ++seen;
      if (end <= n) continue;
      L<S> context = new ArrayList<S>(tok.subList(end - n - 1, end - 1));
      map.put(context, tok.get(end - 1));
    }
    
    // Too little history for a full context.
    if (tok.size() < n) return null;
    int size = tok.size();
    return map.get(new ArrayList<S>(tok.subList(size - n, size)));
  }
  
  // Full copy of the learned map into a new predictor (slow).
  P oldDerive() {
    Tuples copy = new Tuples(n);
    copy.map.putAll(map);
    return copy; // copy.seen is 0, which is what a new predictor needs
  }
  
  // New predictor layered on this one's map instead of copying it (fast).
  // NOTE(review): relies on DerivedHashMap (#1001027) reading through to the parent map.
  P derive() {
    Tuples child = new Tuples(n);
    child.map = new DerivedHashMap<L<S>,S>(map);
    return child;
  }
  
  // Fresh, untrained predictor with the same context length.
  P clear() {
    return new Tuples(n);
  }
}

!include #1001027 // DerivedHashMap

// Node of the StartTree prefix tree: one token plus the continuations observed after it.
static class Node {
  String token;
  float count;       // incremented by StartTree each time learning passes through this node
  new L<Node> next;  // children, in insertion order
  
  *() {} // for clone method
  
  *(S *token) {}
  
  // Returns the child whose token equals the given one, or null (linear scan).
  Node find(S token) {
    for (Node n : next)
      if (n.token.equals(token))
        ret n;
    ret null;
  }
  
  // Returns the child with the highest count (the earliest-added child wins
  // ties), or null when this node has no children.
  Node bestNext() {
    Node best = null;
    for (Node n : next)
      if (best == null || n.count > best.count)
        best = n;
    ret best;
  }
}

// Predicts from a prefix tree of file beginnings. Each root path spells the
// leading tokens of learned files, and each node's count says how often
// learning passed through it. The prediction is the most frequent
// continuation of the current file's tokens so far.
static class StartTree extends P {
  Node tree = new Node("");
  Node node;       // position in the tree for the current file; null once a frozen tree has no matching path
  boolean nonmod;  // frozen copy: follow the tree, but never add nodes or counts
  
  S read(S file, L<S> tok) {
    if (!eq(file, this.file)) {
      seen = 0;
      this.file = file;
      node = tree;
    }
    
    // Walk down over the tokens added since the last call. Frozen copies walk
    // too, so their predictions follow the file's prefix instead of always
    // coming from the root.
    while (tok.size() > seen) {
      S t = tok.get(seen++);
      if (node == null) continue; // already off every known path
      Node child = node.find(t);
      if (child == null) {
        if (nonmod) { node = null; continue; } // unseen prefix: nothing to predict from
        node.next.add(child = new Node(t));
      }
      if (!nonmod) child.count++;
      node = child;
    }
    
    if (node == null) ret null;
    Node n = node.bestNext();
    ret n != null ? n.token : null;
  }
  
  // Shares the (read-only) tree; the derived predictor doesn't learn.
  P derive() {
    //return (P) main.clone(this);
    new StartTree p;
    p.nonmod = true;
    p.tree = tree;
    return p;
  }
  
  // Fresh, empty, learning predictor.
  P clear() {
    return new StartTree;
  }
}

p {
  files = makeCorpus();
  print("Files in corpus: " + files.size());
  
  print("Learning...");
  collector = new Collector;
  //test(new Tuples(1));
  test(new StartTree);
  test(new Chain(new Tuples(4), new Tuples(3), new Tuples(2), new Tuples(1), new StartTree));

  print("Learning done.");
  printVMSize();
  if (collector.winner != null && showGUI)
    window();
}

// Running totals for the predictor under test: characters of correctly
// predicted tokens (points) and of all predicted-over tokens (total).
// Reset in test(), accumulated in testFile().
static int points = 0, total = 0;

// Trains & evaluates a predictor over the whole corpus. Each file is read
// token by token, with a prediction made before the token is added to the
// history. The score (percent of characters predicted correctly) is handed
// to the collector.
static void test(P p) {
  int lastPercent = 0;
  predicted = new HashMap;
  points = 0;
  total = 0;
  int numFiles = files.size();
  for (int ii = 0; ii < numFiles; ii++) {
    testFile(p, files.get(ii));
    
    // Progress message whenever the rounded percentage increases.
    int percent = roundUpTo(10, (int) (ii*100L/numFiles));
    if (percent > lastPercent) {
      print("Learning " + percent + "% done.");
      lastPercent = percent;
    }
  }
  // Avoid 0/0 = NaN when the corpus has no files or no characters.
  double score = total == 0 ? 0.0 : points*100.0/total;
  collector.add(p, score);
}

// Number of testFile() calls so far; makes every call's file key unique.
static int testFile_visits;

// Runs predictor p over one file. Before each token it asks p for a
// prediction from the history so far and scores it by the token's length.
// It then appends the real token to the history. The indices of correctly
// predicted tokens are stored in `predicted`.
static void testFile(P p, F f) {
  // Predictors reset their per-file state only when the file key changes.
  // f.name alone is not unique (SnippetDB titles can repeat). Two consecutive
  // files with the same title would make the predictor continue with stale
  // state against a fresh, empty history.
  S fileKey = f.name + " #" + (++testFile_visits);
  new TreeSet<int> pred;
  new L<S> history;
  int start = allTokens ? 0 : 1, step = allTokens ? 1 : 2;
  for (int i = start; i < f.tok.size(); i += step) {
    S t = f.tok.get(i);
    S x = p.read(fileKey, history);
    total += t.length();
    if (t.equals(x)) {
      pred.add(i);
      points += t.length();
    }
    history.add(t);
  }
  predicted.put(f, pred);
}

!include #1000989 // SnippetDB

// Loads the corpus named by corpusID. A snippet whose title contains ".zip"
// (case-insensitively) is read as a zip archive; anything else is read as a
// SnippetDB dump.
static L<F> makeCorpus() {
  S name = getSnippetTitle(corpusID);
  // Locale.ROOT keeps the check locale-independent (in a Turkish locale,
  // "ZIP".toLowerCase() yields "zıp", which would not match).
  boolean isZip = name != null
    && name.toLowerCase(java.util.Locale.ROOT).indexOf(".zip") >= 0;
  return isZip ? makeCorpus_zip() : makeCorpus_mysqldump();
}

// Builds the corpus from the zip library corpusID: up to numSnippets
// non-directory entries, each read as UTF-8 text and tokenized.
// The archive is closed even if reading fails.
static L<F> makeCorpus_zip() ctex {
  new L<F> files;
  ZipFile zipFile = new ZipFile(loadLibrary(corpusID));
  try {
    Enumeration<? extends ZipEntry> entries = zipFile.entries();
    while (entries.hasMoreElements() && files.size() < numSnippets) {
      ZipEntry entry = entries.nextElement();
      if (entry.isDirectory()) continue;
      // TODO: try to skip binary files?
      
      new F f;
      f.name = entry.getName();
      f.tok = internAll(javaTok(makeCorpus_zip_readEntry(zipFile, entry)));
      files.add(f);
    }
  } finally {
    zipFile.close();
  }
  return files;
}

// Reads one zip entry as UTF-8 text. Line endings are normalized: every line,
// including the last, is terminated by '\n'. The entry stream is always closed.
static S makeCorpus_zip_readEntry(ZipFile zipFile, ZipEntry entry) throws IOException {
  InputStream in = zipFile.getInputStream(entry);
  try {
    BufferedReader reader = new BufferedReader(new InputStreamReader(in, "UTF-8"));
    new StringBuilder builder;
    String line;
    while ((line = reader.readLine()) != null)
      builder.append(line).append('\n');
    return builder.toString();
  } finally {
    in.close();
  }
}

// Builds the corpus from the SnippetDB dump corpusID: the first numSnippets
// rows ordered by sn_created, each tokenized from its sn_text.
static L<F> makeCorpus_mysqldump() {
  new L<F> files;
  SnippetDB db = new SnippetDB(corpusID);
  List<List<S>> rows = db.rowsOrderedBy("sn_created");
  int count = Math.min(rows.size(), numSnippets);
  // One increment per row. A second `++i` in the body used to skip every
  // other row and load only about half of numSnippets.
  for (int i = 0; i < count; i++) {
    List<S> row = rows.get(i);
    new F f;
    f.id = db.getField(row, "sn_id");
    f.name = db.getField(row, "sn_title");
    S text = db.getField(row, "sn_text");
    f.tok = internAll(javaTok(text));
    files.add(f);
  }
  return files;
}

// Remembers the best-scoring predictor so far, together with the per-file
// correctness sets recorded while it was being tested.
static class Collector {
  P winner;
  double bestScore = -1;
  Map<F, Set<int>> predicted;

  // Offers a predictor with its score (percent). It becomes the winner if it
  // is the first one offered or strictly beats the current best.
  void add(P p, double score) {
    // Written as !(>) rather than <= so that NaN behaves exactly as before.
    if (winner != null && !(score > bestScore)) return;
    predicted = main.predicted;
    bestScore = score;
    winner = p;
    S name = p.getClass().getName();
    print("New best score: " + formatDouble(score, 2) + "% (" + name + ")");
  }
}

// Opens a Swing viewer for the winning predictor's results. It shows one file
// at a time, with correctly predicted tokens in green and all others in gray.
// The button steps backwards through the corpus (the first view is the last
// file). Its label shows the file name, position and the per-file score.
// NOTE(review): the frame is built on the calling thread, not the Swing
// event dispatch thread - confirm this is acceptable here.
static void window() {
  //final P p = collector.winner.clear();

  JFrame jf = new JFrame("Predicted = green");
  Container cp = jf.getContentPane();

  final JButton btnNext = new JButton("Next");
  
  final JTextPane pane = new JTextPane();
  //pane.setFont(loadFont("#1000993", 24));
  
  JScrollPane scrollPane = new JScrollPane(pane);
  cp.add(scrollPane, BorderLayout.CENTER);
  
  class X {
    int ii; // index of the file currently shown
    
    // Moves to the previous file (wrapping from 0 to the last one) and redraws.
    void y() ctex {
      ii = ii == 0 ? files.size()-1 : ii-1;
      F f = files.get(ii);
      //testFile(p, f);
      Set<int> pred = collector.predicted.get(f);
      
      StyledDocument doc = new DefaultStyledDocument();

      // Only the tail of the file is rendered. Walk back from the end until
      // more than maxCharsGUI characters are covered (or the start is reached).
      L<S> tok = f.tok;
      int i = tok.size(), len = 0;
      while (len <= maxCharsGUI && i > 0) {
        --i;
        len += tok.get(i).length();
      }
      
      // Empty tokens are skipped; the rest are colored by prediction success.
      for (; i < tok.size(); i++) {
        if (tok.get(i).length() == 0) continue;
        boolean green = pred.contains(i);
        SimpleAttributeSet set = new SimpleAttributeSet();
        StyleConstants.setForeground(set, green ? Color.green : Color.gray);
        doc.insertString(doc.getLength(), tok.get(i), set);
      }
      
      pane.setDocument(doc);
      // The score is computed over the whole file, not just the rendered tail.
      double score = getScore(pred, tok);
      btnNext.setText(f.name + " (" + (ii+1) + "/" + files.size() + ") - " + (int) score + " %");
    }
  }
  final new X x;
  
  btnNext.addActionListener(actionListener {
    x.y();
  });
  cp.add(btnNext, BorderLayout.NORTH);

  x.y();
  
  jf.setBounds(100, 100, 600, 600);
  jf.setVisible(true);    
}

!include #1001032 // clone function

// Percentage (0-100) of a file's characters that belong to correctly
// predicted tokens. It is the summed length of the tokens whose index is in
// pred, divided by the total length of all tokens. Returns 0 for a file with
// no characters (instead of the NaN produced by 0/0).
static double getScore(Set<int> pred, L<S> tok) {
  long total = 0, score = 0;
  for (int i = 0; i < tok.size(); i++) {
    int n = tok.get(i).length();
    total += n;
    if (pred.contains(i))
      score += n;
  }
  ret total == 0 ? 0.0 : score*100.0/total;
}

Author comment

Began life as a copy of #1001025

download | show line numbers | debug dex | old transpilations

Travelled to 16 computer(s): aoiabmzegqzx, bhatertpkbcr, cbybwowwnfue, cfunsshuasjs, ddnzoavkxhuk, gwrvuhgaqvyk, ishqpsrjomds, lpdgvwnxivlt, mqqgnosmbjvj, onxytkatvevr, pyentgdyhuwx, pzhvpgtvlbxg, teubizvjbppd, tslmcundralx, tvejysmllsmz, vouqrxazstgt

No comments. Add comment

Snippet ID: #1001028
Snippet name: Token prediction, multiple predictors (v3, including start trees, developing)
Eternal ID of this version: #1001028/1
Text MD5: 6a1e871348685598ee089e9642400ab0
Transpilation MD5: e27ea70ea3a336e2c3da51a230a32af6
Author: stefan
Category:
Type: JavaX source code
Public (visible to everyone): Yes
Archived (hidden from active list): No
Created/modified: 2016-06-15 14:32:50
Source code size: 8674 bytes / 379 lines
Pitched / IR pitched: No / No
Views / Downloads: 692 / 993
Referenced in: [show references]