BotCompany Repo | #1001037 // Token prediction, multiple predictors (v5, developing)

JavaX source code [tags: use-pretranspiled] - run with: x30.jar

Libraryless. A pure Java version is available (2911L/22K/67K).

!747
!actionListener {

m {
  static S corpusID = 
    //"#1001034"; // one small snippet
    "#1001006"; // snippets DB
  static int numSnippets = 100;
  static boolean showGUI = true;
  static int maxCharsGUI = 500000;
  static boolean allTokens = true;
  
  static Collector collector;
  static L<F> files;
  static Map<F, Set<int>> predicted;
  
  // a file to learn from
  static class F {
    String id, name;
    L<S> tok;
  }
  
  // a predictor
  static abstract class P {
    int seen;
    S file;

    // basic function - predict next token    
    abstract S read(S file, L<S> tok);
    
    // advanced function - predict rest of token starting with t
    S complete(S file, L<S> tok, S t) { return null; }
    
    abstract P derive(); // clone with trained data
    abstract P clear(); // clone without trained data
    
    void prepare(S file) {
      if (!eq(file, this.file)) {
        seen = 0;
        this.file = file;
      }
    }
  }

  static class Chain extends P {
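    // tries each predictor in order and returns the first non-null prediction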
    new L<P> list;
    
    *() {}
    *(L<P> *list) {}
    *(P... a) { list = asList(a); }
    
    void add(P p) { list.add(p); }
    
    S read(S file, L<S> tok) {
      for (P p : list) {
        S s = p.read(file, tok);
        if (s != null) return s;
      }
      return null;
    }
    
    P derive() {
      new Chain c;
      for (P p : list)
        c.add(p.derive());
      return c;
    }
    
    P clear() {
      new Chain c;
      for (P p : list)
        c.add(p.clear());
      return c;
    }
  }
    
  static class Tuples extends P {
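    // n-gram predictor: maps each sequence of n tokens seen so far to the token that followed it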
    Map<L<S>,S> map = new HashMap<L<S>,S>();
    int n;

    *(int *n) {}
    
    S read(S file, L<S> tok) {
      prepare(file);
      
      while (tok.size() > seen) {
        ++seen;
        if (seen > n)
          map.put(new ArrayList<S>(tok.subList(seen-n-1, seen-1)), tok.get(seen-1));
      }
      
      if (tok.size() >= n)
        return map.get(new ArrayList<S>(tok.subList(tok.size()-n, tok.size())));
        
      return null;
    }
    
    P derive() {
      Tuples t = new Tuples(n);
      t.map = new DerivedHashMap<L<S>,S>(map);
      return t;
    }
    
    P clear() {
      return new Tuples(n);
    }
  }
  
  static Map<S, S> makeMapPrefix(L<S> tok1, L<S> tok2) {
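    // maps tokens of tok1's prefix onto the corresponding tokens of tok2;
    // returns null if no consistent mapping exists (javaTok keeps code tokens at odd indices)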
    if (tok1.size() < tok2.size()) return null;
    
    new Map<S, S> map;
    for (int i = 1; i < tok2.size(); i += 2) {
      S t1 = tok1.get(i), t2 = tok2.get(i);
      if (!t1.equals(t2)) {
        S v = map.get(t1);
        if (v == null)
          map.put(t1, t2);
        else if (!v.equals(t2))
          return null; // match fail
      }
    }
    
    // match succeeds
    return map;
  }
  
  !include #1001041 // Pattern
  
  !include #1001027 // DerivedHashMap
  
  !include #1001036 // LastWordToLower
  
  static class Node {
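    // node of a token trie; count = how often this continuation was seen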
    String token;
    float count;
    new L<Node> next;
    
    *() {} // for clone method
    
    *(S *token) {}
    
    Node find(S token) {
      for (Node n : next)
        if (n.token.equals(token))
          ret n;
      ret null;
    }
    
    Node bestNext() {
      Node best = null;
      for (Node n : next)
        if (best == null || n.count > best.count)
          best = n;
      ret best;
    }
  }
  
  static class StartTree extends P {
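    // trie predictor: indexes every file from its first token on and
    // predicts the most frequently seen continuation of the current prefix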
    Node tree = new Node("");
    Node node;
    boolean nonmod;
    
    S read(S file, L<S> tok) {
      if (!eq(file, this.file)) {
        seen = 0;
        this.file = file;
        node = tree;
      }
      
      if (!nonmod) while (tok.size() > seen) {
        S t = tok.get(seen++);
        Node child = node.find(t);
        if (child == null)
          node.next.add(child = new Node(t));
        child.count++;
        node = child;
      }
      
      Node n = node.bestNext();
      ret n != null ? n.token : null;
    }
    
    // hack: the derived predictor shares the learned tree but doesn't learn further
    P derive() {
      //return (P) main.clone(this);
      new StartTree p;
      p.nonmod = true;
      p.tree = tree;
      return p;
    }
    
    P clear() {
      return new StartTree;
    }
  }
  
  p {
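    // build the corpus, train & score the predictor chains, then show the GUI for the winner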
    files = makeCorpus();
    print("Files in corpus: " + files.size());
    
    print("Learning...");
    collector = new Collector;
    test(new Chain(new Tuples(8), new Tuples(6), new Tuples(4), new Tuples(2), new Tuples(1), new StartTree));
    
    //test(new Patterns(6));
    //test(new Chain(new Patterns(9), new LastWordToLower));
    test(new Chain(new Patterns(9), new Patterns(7), new Patterns(5), new LastWordToLower));

    print("Learning done.");
    printVMSize();
    if (collector.winner != null && showGUI)
      window();
  }
  
  static int points = 0, total = 0;
  
  // train & evaluate a predictor
  static void test(P p) {
    int lastPercent = 0;
    predicted = new HashMap;
    points = 0;
    total = 0;
    for (int ii = 0; ii < files.size(); ii++) {
      F f = files.get(ii);
      
      testFile(p, f);
      
      int percent = roundUpTo(10, (int) (ii*100L/files.size()));
      if (percent > lastPercent) {
        print("Learning " + percent + "% done.");
        lastPercent = percent;
      }
    }
    double score = points*100.0/total;
    collector.add(p, score);
  }
  
  static void testFile(P p, F f) {
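    // replay the file token by token, asking for a prediction before revealing each token;
    // correct predictions are scored by character length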
    new TreeSet<int> pred;
    new L<S> history;
    for (int i = allTokens ? 0 : 1; i < f.tok.size(); i += allTokens ? 1 : 2) {
      S t = f.tok.get(i);
      S x = p.read(f.name, history);
      boolean correct = t.equals(x);
      total += t.length();
      if (correct) {
        pred.add(i);
        points += t.length();
      }
      history.add(t);
    }
    p.read(f.name, history); // feed last token, ignore output
    predicted.put(f, pred);
  }
  
  !include #1000989 // SnippetDB
  
  static L<F> makeCorpus() ctex {
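    // pick the corpus source: a single snippet, a zip of files, or a mysqldump of the snippets DB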
    S name = getSnippetTitle(corpusID);
    S s = loadSnippet(corpusID);
    if (s.length() != 0)
      return makeCorpus_single(s);
    else if (name.toLowerCase().indexOf(".zip") >= 0)
      return makeCorpus_zip();
    else
      return makeCorpus_mysqldump();
  }
  
  static L<F> makeCorpus_single(S text) ctex {
    new L<F> files;
    new F f;
    f.id = corpusID;
    f.name = getSnippetTitle(corpusID);
    f.tok = internAll(javaTok(text));
    files.add(f);
    return files;
  }
  
  static L<F> makeCorpus_zip() ctex {
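    // read up to numSnippets text files from the zip referenced by corpusID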
    new L<F> files;
    ZipFile zipFile = new ZipFile(loadLibrary(corpusID));
    Enumeration entries = zipFile.entries();

    while (entries.hasMoreElements() && files.size() < numSnippets) {
      ZipEntry entry = (ZipEntry) entries.nextElement(); 
      if (entry.isDirectory()) continue;
      //System.out.println("File found: " + entry.getName());

      InputStream fin = zipFile.getInputStream(entry);
      // TODO: try to skip binary files?
      
      InputStreamReader reader = new InputStreamReader(fin, "UTF-8");
      new StringBuilder builder;
      BufferedReader bufferedReader = new BufferedReader(reader);
      String line;
      while ((line = bufferedReader.readLine()) != null)
        builder.append(line).append('\n');
      fin.close();
      S text = builder.toString();
      
      new F f;
      f.name = entry.getName();
      f.tok = internAll(javaTok(text));
      files.add(f);
    }
    
    zipFile.close();
    return files;
  }
  
  static L<F> makeCorpus_mysqldump() {
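    // load the newest numSnippets snippets from the mysqldump, ordered by creation time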
    new L<F> files;
    SnippetDB db = new SnippetDB(corpusID);
    List<List<S>> rows = db.rowsOrderedBy("sn_created");
    for (int i = Math.max(0, rows.size()-numSnippets); i < rows.size(); i++) {
      new F f;
      f.id = db.getField(rows.get(i), "sn_id");
      f.name = db.getField(rows.get(i), "sn_title");
      S text = db.getField(rows.get(i), "sn_text");
      f.tok = internAll(javaTok(text));
      files.add(f);
    }
    return files;
  }

  static class Collector {
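    // keeps the best-scoring predictor and the per-file prediction sets it produced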
    P winner;
    double bestScore = -1;
    Map<F, Set<int>> predicted;

    void add(P p, double score) {
      if (winner == null || score > bestScore) {
        winner = p;
        bestScore = score;
        //S name = shorten(structure(p), 100);
        S name = p.getClass().getName();
        print("New best score: " + formatDouble(score, 2) + "% (" + name + ")");
        predicted = main.predicted;
      }
    }
  }
  
  static void window() {
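    // shows each corpus file in a JTextPane, coloring correctly predicted tokens green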
    //final P p = collector.winner.clear();

    JFrame jf = new JFrame("Predicted = green");
    Container cp = jf.getContentPane();

    final JButton btnNext = new JButton("Next");
    
    final JTextPane pane = new JTextPane();
    //pane.setFont(loadFont("#1000993", 24));
    
    JScrollPane scrollPane = new JScrollPane(pane);
    cp.add(scrollPane, BorderLayout.CENTER);
    
    class X {
      int ii;
      
      void y() ctex {
        ii = ii == 0 ? files.size()-1 : ii-1;
        F f = files.get(ii);
        //testFile(p, f);
        Set<int> pred = collector.predicted.get(f);
        
        StyledDocument doc = new DefaultStyledDocument();

        L<S> tok = f.tok;
        int i = tok.size(), len = 0;
        while (len <= maxCharsGUI && i > 0) {
          --i;
          len += tok.get(i).length();
        }
        
        for (; i < tok.size(); i++) {
          if (tok.get(i).length() == 0) continue;
          boolean green = pred.contains(i);
          SimpleAttributeSet set = new SimpleAttributeSet();
          StyleConstants.setForeground(set, green ? Color.green : Color.gray);
          doc.insertString(doc.getLength(), tok.get(i), set);
        }
        
        pane.setDocument(doc);
        double score = getScore(pred, tok);
        btnNext.setText(f.name + " (" + (ii+1) + "/" + files.size() + ") - " + (int) score + " %");
      }
    }
    final new X x;
    
    btnNext.addActionListener(actionListener {
      x.y();
    });
    cp.add(btnNext, BorderLayout.NORTH);

    x.y();
    
    jf.setBounds(100, 100, 600, 600);
    jf.setVisible(true);    
  }
  
  !include #1001032 // clone function
  
  static double getScore(Set<int> pred, L<S> tok) {
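    // character-weighted percentage of tokens in tok that were predicted correctly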
    int total = 0, score = 0;
    for (int i = 0; i < tok.size(); i++) {
      int n = tok.get(i).length();
      total += n;
      if (pred.contains(i))
        score += n;
    }
    ret score*100.0/total;
  }
}

Author comment

Began life as a copy of #1001033


Travelled to 15 computer(s): aoiabmzegqzx, bhatertpkbcr, cbybwowwnfue, cfunsshuasjs, gwrvuhgaqvyk, ishqpsrjomds, lpdgvwnxivlt, mqqgnosmbjvj, onxytkatvevr, pyentgdyhuwx, pzhvpgtvlbxg, teubizvjbppd, tslmcundralx, tvejysmllsmz, vouqrxazstgt


Snippet ID: #1001037
Snippet name: Token prediction, multiple predictors (v5, developing)
Eternal ID of this version: #1001037/1
Text MD5: b286849fecdd14a17f356c50a3b92756
Transpilation MD5: d3e70f9fff633fe7154db05fbb2e2d3e
Author: stefan
Category:
Type: JavaX source code
Public (visible to everyone): Yes
Archived (hidden from active list): No
Created/modified: 2015-09-16 21:44:38
Source code size: 10641 bytes / 419 lines
Pitched / IR pitched: No / Yes
Views / Downloads: 620 / 627