Not logged in.  Login/Logout/Register | List snippets | | Create snippet | Upload image | Upload data

600
LINES

< > BotCompany Repo | #1028066 // Auto Classifier v5 [learning message classifier]

JavaX source code (Dynamic Module) [tags: use-pretranspiled] - run with: Stefan's OS

Uses 911K of libraries. Click here for Pure Java version (20530L/109K).

!7

cmodule AutoClassifier > DynConvo {
  // Learning message classifier. The user assigns labels ("X" or "not X")
  // to messages; the module generates simple theories of the form
  // "<proposition about message> => / <=> message has (or doesn't have) label X",
  // checks them against all messages and uses the best-scoring theory per
  // label to predict labels for the currently shown message.

  // THEORY BUILDING BLOCKS (Theory + MsgProp + subclasses)
  
  // A candidate rule about messages. statement.lhs / statement.rhs are
  // propositions (MsgProp leaves, possibly combined with And / Not).
  // examples holds the per-message results of the last check
  // (filled by checkTheory_noTrigger).
  srecord Theory(BasicLogicRule statement) {
    new PosNeg<Msg> examples;
    //bool iff; // <=> instead of only =>
    toString { ret str(statement.lhs instanceof MPTrue ? "Every message is " + statement.rhs
      : bidiMode ? statement.lhs + " <=> " + statement.rhs : statement); }
  }

  // propositions about a message. check returns null if unknown
  asclass MsgProp { abstract Bool check(Msg msg); }

  // Proposition that holds for every message (used as lhs of "every message is ..." theories).
  static transformable record MPTrue() > MsgProp {
    Bool check(Msg msg) { true; }
    toString { ret "always"; }
  }
  
  // Looks up the trained value for (msg, label) in msg2label_new:
  // true = labelled X, false = labelled "not X", null = no entry.
  transformable record HasLabel(S label) > MsgProp {
    Bool check(Msg msg) { ret msg2label_new.get(msg, label); }
    toString { ret label; }
  }
  
  // Negation of HasLabel (null stays null via not()).
  transformable record DoesntHaveLabel(S label) > MsgProp {
    Bool check(Msg msg) { ret not(msg2label_new.get(msg, label)); }
    toString { ret "not " + label; }
  }

  // True iff the message's computed feature equals value (compared with eq()).
  transformable record FeatureValueIs(S feature, O value) > MsgProp {
    Bool check(Msg msg) { ret eq(getMsgFeature(msg, feature), value); }
    toString { ret feature + "=" + value; }
  }
  
  // feature extracts the text from msg
  transformable record MMOMatch(S feature, S pattern) > MsgProp {
    Bool check(Msg msg) { ret mmo_match2(pattern, (S) getMsgFeature(msg, feature)); }
    toString { ret renderFunctionCall("MMOMatch", pattern, feature); }
  }
  
  // LABEL class (with best theories)

  // A label plus every theory whose rhs targets it, kept ordered by
  // theoryScore descending (so first() is the best theory).
  class Label {
    S name;

    *() {}
    *(S *name) {}
    
    TreeSetWithDuplicates<Theory> bestTheories = new(reverseComparatorFromCalculatedField theoryScore());

    // score of the best theory, or theoryScore(null) = -100 when there is none
    double score() { ret theoryScore(first(bestTheories)); }
    Theory bestTheory() { ret first(bestTheories); }
  }
  
  // FEATURE base classes (FeatureEnv + FeatureExtractor)
  
  // What a feature extractor sees: the object under inspection plus access
  // to other (lazily computed) features of the same object.
  sinterface FeatureEnv<A> {
    A mainObject();
    O getFeature(S name);
  }

  // Computes one named feature value from a FeatureEnv.
  sinterface FeatureExtractor<A> {
    O get(FeatureEnv<A> env);
  }
  
  // PREDICTION class (output of classifier)
  
  // plus = predicts the message HAS the label (false = "not label").
  // adjustedConfidence is in percent, as produced by adjustConfidence.
  srecord Prediction(S label, bool plus, double adjustedConfidence) {
    toString {
      ret predictedLabel() + " (confidence: " + iround(adjustedConfidence) + "%)";
    }

    S predictedLabel() {
      ret (plus ? "" : "not ") + label;
    }
  }

  // DATA (backend)
  
  sbool bidiMode = true; // treat all theories as bidirectional
  L<Msg> msgs; // all messages (order not used yet)
  // lazily computed feature maps per message (see calcMsgFeatures)
  transient Map<Msg, MapSO> msg2features = new AutoMap<Msg, MapSO>(lambda1 calcMsgFeatures);
  Set<S> allLabels = syncTreeSet();
  transient new Map<S, Label> labelsByName;
  new LinkedHashSet<Theory> theories;
  transient Q thinkQ;
  // listeners called once for every label that appears for the first time
  transient new L<IVF1<S>> onNewLabel;
  // (message, label) -> true (has label) / false (doesn't have label)
  new DoubleKeyedMap<Msg, S, Bool> msg2label_new;
  transient new Map<S, FeatureExtractor<Msg>> featureExtractors;
  
  // DATA (GUI)
  
  switchable double minAdjustedScoreToDisplay = 50;
  switchable bool autoNext = false;
  L<Msg> shownMsgs;
  S analysisText, labelsForMsgText;
  transient JTable theoryTable, labelsTable, trainedExamplesTable, objectsTable;
  transient JTabbedPane tabs;
  transient SingleComponentPanel scpPredictions;
  S labelForDeepThought;
  transient JComboBox cbLabelForDeepThought;
  
  // START CODE

  start {
    thinkQ = dm_startQ("Thought Queue");
    thinkQ.add(r {
      // legacy + after deletion cleaning
      setField(allLabels := asSyncTreeSet(msg2label_new.bKeys()));
      updateLabelsByName();
      
      onNewLabel.add(lbl -> change());
  
      // register theory generators (they run when onNewLabel fires below)
      makeTheoriesAboutLabels();
      makeTheoriesAboutFeaturesAndLabels();
      
      // every field of Msg is a feature
      for (S field : fields(Msg))
        featureExtractors.put(field, env -> getOpt(env.mainObject(), field));
  
      makeTextExtractors("text");
  
      // generate theories for all already-known labels
      callFAllOnAll(onNewLabel, allLabels);
      onNewLabel.add(lbl -> setComboBoxItems(cbLabelForDeepThought, allLabels));
      
      msg2labelUpdated();
      updatePredictions();
      checkAllTheories();
      //showRandomMsg();
    });
  }
  
  // THEORY MAKING

  // Registers a generator adding "every message is X" and
  // "every message is not X" for each new label X.
  void makeTheoriesAboutLabels {
    // For any label X:
    onNewLabel.add(lbl -> {
      // test theory (for every M: M has label X)
      addTheory(new Theory(BasicLogicRule(new MPTrue, new HasLabel(lbl))));
      // test theory (for every M: M doesn't have label X)
      addTheory(new Theory(BasicLogicRule(new MPTrue, new DoesntHaveLabel(lbl))));
    });
  }

  // Registers a generator adding "F=V => (not) X" theories for each new
  // label X, every feature F and every value V seen on messages trained for X.
  void makeTheoriesAboutFeaturesAndLabels {
    // for every label X:
    onNewLabel.add(lbl -> {
      // For any feature F:
      for (S feature : keys(featureExtractors))
        // for every seen value V of F:
        for (O value : possibleValuesOfFeatureRelatedToLabel(feature, lbl))
          for (O rhs : ll(new HasLabel(lbl), new DoesntHaveLabel(lbl)))
            // test theory (for every M: msg M's feature F has value V => msg has/doesn't have label x))
            addTheory(new Theory(BasicLogicRule(
              new FeatureValueIs(feature, value), rhs)));
    });
  }
  
  // THEORY MAKING (helper functions)

  // Values a feature can take a priori: {false, true} for boolean Msg
  // fields, nothing for other features. Returns a fresh mutable set.
  Set possibleValuesOfFeature(S feature) {
    if (isBoolField(Msg, feature))
      ret litset(false, true);
    ret litset();
  }

  // A priori values plus the feature's value on every message that has a
  // trained entry (true or false) for label.
  Set possibleValuesOfFeatureRelatedToLabel(S feature, S label) {
    Set set = possibleValuesOfFeature(feature);
    fOr (Msg msg : getMsgsRelatedToLabel(label))
      set.add(getMsgFeature(msg, feature));
    ret set;
  }
  
  // CALCULATE FEATURES

  O getMsgFeature(Msg msg, S feature) {
    ret msg2features.get(msg).get(feature);
  }
  
  // returns AutoMap with no realized entries
  // (features are computed on first access through the registered extractors;
  // extractors can read other features of the same msg via env.getFeature)
  Map<S, O> calcMsgFeatures(Msg msg) {
    new Var<FeatureEnv<Msg>> env;
    AutoMap<S, O> map = new AutoMap<S, O>(feature -> featureExtractors.get(feature).get(env!));
    env.set(new FeatureEnv<Msg> {
      Msg mainObject() { ret msg; }
      O getFeature(S feature) { ret map.get(feature); }
    });
    ret map;    
  }
  
  // GUI: Show messages

  // Shows the given messages. For exactly one message, also renders its
  // trained labels, its features and a pair of buttons per prediction.
  void showMsgs(L<Msg> l) {
    setField(shownMsgs := l);
    setMsgs(l);
    if (l(shownMsgs) == 1) {
      Msg msg = first(shownMsgs);
      setField(labelsForMsgText := or2(renderBoolMap(getMsgLabels(msg)), "-"));
      setField(analysisText := joinWithEmptyLines(
        "Trained Labels: " + labelsForMsgText,
        "Features:\n" + formatColonProperties_quoteStringValues(
msg2features.get(msg))
      ));
      setSCPComponent(scpPredictions,
        scrollableStackWithSpacing(map(predictionsForMsg(msg), p -> {
          S percent = iround(p.adjustedConfidence) + "%";
          S neg = "not " + p.label;
          Bool knownValue = msg2label_new.get(msg, p.label);
          embedded S strong(S html) { ret b(html, style := "font-size: 18; color: #008000"); }
          // known: already trained -> plain label; otherwise a button that
          // sends the label as input (i.e. trains it)
          embedded JComponent makeButton(bool known, bool predicted, S label) {
            S html = predicted ? jlabel_centerHTML(joinWithBR(
              strong(htmlencode(label)), percent))
              : label;
            S toolTip = predicted ? "Predicted with " + percent + " confidence" + stringIf(!known, ". Click to confirm") 
              : !known ? "Click to set this label for message" : "";
            if (known) ret setTooltip(toolTip, jcenteredlabel(html));
            JButton btn = setTooltip(toolTip, jbutton(html, rThread { sendInput2(label) }));
            ret predicted ? btn : jfullcenter(btn);
          }
          
          // left: "label", right: "not label"
          ret withSideMargin(jhgridWithSpacing(
            makeButton(isTrue(knownValue), p.plus, p.label),
            makeButton(isFalse(knownValue), !p.plus, neg)
          ));
        })));
    } else setFields(analysisText := "", labelsForMsgText := "");
  }

  // Re-renders the current message(s), recomputing predictions.
  void updatePredictions() {
    showMsgs(shownMsgs);
  }
  
  void showRandomMsg {
    showMsgs(randomElementAsList(msgs));
  }
  
  void showPrevMsg {
    showMsgs(llNonNulls(prevInCyclicList(msgs, first(shownMsgs))));
  }

  void showNextMsg {
    showMsgs(llNonNulls(nextInCyclicList(msgs, first(shownMsgs))));
  }

  // CALCULATE PREDICTIONS FOR MESSAGE

  // One prediction per label that has a best theory whose lhs is known for
  // msg, sorted by adjusted confidence (highest first).
  L<Prediction> predictionsForMsg(Msg msg) {
    // positive labels first, then "not"s. sort by score in each group
    new L<Prediction> out;
    for (Label label : values(labelsByName)) {
      Theory t = label.bestTheory(), continue if null;
      Bool lhs = evalTheoryLHS(t, msg), continue if null;
      // theories are read bidirectionally here: lhs true => rhs holds
      bool prediction = t.statement.rhs instanceof DoesntHaveLabel ? !lhs : lhs;
      double conf = threeB1BScore(t.examples), adjusted = adjustConfidence(conf);
      //if (adjusted < minAdjustedScoreToDisplay) continue;
      out.add(new Prediction(label.name, prediction, adjusted));
    }
    ret sortedByCalculatedFieldDesc(out, p -> /*pair(p.plus,*/ p.adjustedConfidence/*)*/);
  }

  // go from range 50-100 to 0-100 (looks better/more intuitive)
  // (values below 50 are clamped to 0)
  double adjustConfidence(double x) {
    ret max(0, (x-50)*2);
  }
  
  // rough reverse function of adjustConfidence
  double unadjustConfidence(double x) {
    ret x/2+50;
  }
  
  // GUI: Enter labels
  
  void acceptPrediction(Prediction p) {
    if (p != null) sendInput2(p.predictedLabel());
  }

  void rejectPrediction(Prediction p) {
    if (p != null) sendInput2(cloneWithFlippedBoolField plus(p).predictedLabel());
  }

  // Treats input as a label for the single shown message:
  // "not X" stores false for X, anything else stores true.
  // Does nothing unless exactly one message is shown.
  @Override
  void sendInput2(S s) {
    if (l(shownMsgs) != 1) ret;
    Msg shown = first(shownMsgs);
    bool hasLabel = true;
    S labelText = s;
    new Matches m;
    if "not ..." {
      hasLabel = false;
      labelText = m.rest();
    }
    S label = cleanLabel(labelText);
    doubleKeyedMapPutVerbose(+msg2label_new, shown, label, hasLabel);
    msg2labelUpdated(label);
    if (autoNext) showRandomMsg();
    change();
  }
  
  // MESSAGE LABEL HANDLING

  // label -> true/false for every label trained on msg
  Map<S, Bool> getMsgLabels(Msg msg) {
    ret msg2label_new.getA(msg);
  }
  
  // all messages that have a trained entry (true or false) for label
  Set<Msg> getMsgsRelatedToLabel(S label) { ret msg2label_new.asForB(label); }

  // Call after training data for label changed: re-checks all theories
  // about that label, then registers any new labels.
  void msg2labelUpdated(S label) {
    L<Theory> toCheck = cloneList(labelByName(label).bestTheories);
    // re-check without per-theory GUI refresh; refresh once below
    for (Theory t : toCheck)
      checkTheory_noTrigger(t);
    if (!empty(toCheck))
      theoriesChanged();
    msg2labelUpdated();
  }

  // Fires onNewLabel for labels in msg2label_new not yet in allLabels,
  // then refreshes the trained examples table.
  void msg2labelUpdated() {
    callFAllOnAll(onNewLabel, addAll_returnNew(allLabels, msg2label_new.bKeys()));
    updateTrainedExamplesTable();
  }
  
  // QUERY: get all labels + best theory each
  
  Map<S, Theory> labelsToBestTheoryMap() {
    Map<S, L<Theory>> map = multiMapToMap(multiMapIndex targetLabelOfTheory(theories));
    ret mapValues(map, theories -> highestBy theoryScore(theories));
  }
  
  // msg -> true (has label) / false (doesn't have label)
  Map<Msg, Bool> examplesForLabel(S label) {
    ret msg2label_new.getB(label);
  }

  // GUI: Main layout

  visual
    withCenteredButtons(super,
      "<", rInThinkQ(r showPrevMsg),
      "Show random msg", rInThinkQ(r showRandomMsg),
      ">", rInThinkQ(r showNextMsg),
      jPopDownButton_noText(flattenObjectArray(
        "Check theories", rInThinkQ(r checkAllTheories),
        "Forget bad theories", rInThinkQ(r { forgetBadTheories(0) }),
        "Forget all theories", rInThinkQ(r clearTheories),
        "Update predictions", rInThinkQ(r updatePredictions),
        dm_importAndExportAllDataMenuItems())));

  // Left: focused message + predictions. Right: tabs with messages,
  // labels, theories and trained examples.
  JComponent mainPart() {
    tablePopupMenuItemsThreaded_top(labelsTable = sexyTable(),
      "Copy examples to clipboard", rEnter {
        copyTextToClipboard_lineCountInfoBox(collectAsLines text(keysWithValueTrue(examplesForLabel((S) selectedTableCell(labelsTable, "Label")))))
      },
      "Copy counterexamples to clipboard", rEnter {
        copyTextToClipboard_lineCountInfoBox(collectAsLines text(keysWithValueFalse(examplesForLabel((S) selectedTableCell(labelsTable, "Label")))))
      });
    ret jhsplit(jvsplit(
      jCenteredSection("Focused Message",
        centerAndSouthWithMargin(super.mainPart(),
        jCenteredSection("Labels assigned to message", dm_centeredLabel labelsForMsgText()))),
      //jhsplit(
        jCenteredSection("Predictions for message (green)", scpPredictions = singleComponentPanel()),
        /*jCenteredSection("Deep Thought", northAndCenterWithMargin(
          withLabel("Think about label:", cbLabelForDeepThought = dm_comboBox labelForDeepThought(allLabels)),
          jcenteredlabel("TODO")))
      )*/),
      with(r updateTabs, tabs = jtabs(
        "", with(r updateObjectsTable, withRightAlignedButtons(
          tablePopupMenuItemsThreaded(
            onDoubleClickOrEnter(rThread showSelectedObject,
              objectsTable = sexyTable()),
            "Delete", r deleteSelectedMessages),
          "Import messages...", rThreadEnter importMsgs)),
        "", with(r updateLabelsTable, labelsTable),
        "", with(r updateTheoryTable, tableWithSearcher2_returnPanel(theoryTable = sexyTable())),
        "", with(r updateTrainedExamplesTable, tableWithSearcher2_returnPanel(trainedExamplesTable = sexyTable()))
      )));
  }
  
  // GUI: Update tables & tabs

  void updateTrainedExamplesTable {
    dataToTable_uneditable(trainedExamplesTable, map(msg2label_new.map1, (msg, map) ->
      litorderedmap(
        "Message" := (msg.fromUser ? "User" : "Bot") + ": " + msg.text,
        "Labels" := renderBoolMap(map))));
  }

  // tab titles carry the current counts
  void updateTabs {
    setTabTitles(tabs,
      firstLetterToUpper(nMessages(msgs)),
      firstLetterToUpper(nLabels(labelsByName)),
      firstLetterToUpper(nTheories(theories)),
      n2(msg2label_new.aKeys(), "Trained Example"));
  }

  void updateTheoryTable {
    L<Theory> sorted = sortedByCalculatedFieldDesc theoryScore(theories);
    dataToTable_uneditable(theoryTable, map(sorted, t -> litorderedmap(
      "Score" := renderTheoryScore(t),
      "Theory" := str(t))));
  }

  void updateObjectsTable enter {
    dataToTable_uneditable_ifHasTable(objectsTable, map(msgs, msg ->
      litorderedmap("Text" := msg.text)
    ));
  }

  // one row per label, best-scoring labels first; "[+n]" marks n more
  // theories tied with the displayed best theory
  void updateLabelsTable enter {
    L<Label> sorted = sortedByCalculatedFieldDesc(values(labelsByName), l -> l.score());
    dataToTable_uneditable_ifHasTable(labelsTable, map(sorted, label -> {
      Cl<Theory> bestTheories = label.bestTheories.tiedForFirst();
      Map<Msg, Bool> examples = examplesForLabel(label.name);
      ret litorderedmap(
        "Label" := label.name,
        "Examples/Counterexamples" := countKeysWithValue(true, examples) + "/" + countKeysWithValue(false, examples),
        "Prediction Confidence" := renderTheoryScore(first(bestTheories)),
        "Best Theory" := empty(bestTheories) ? "" :
          (l(bestTheories) > 1 ? "[+" + (l(bestTheories)-1) + "] " : "") +  first(bestTheories));
    }));
  }
  
  // refresh everything that depends on theories and mark module as changed
  void theoriesChanged {
    updateTheoryTable();
    updateLabelsTable();
    updateTabs();
    updatePredictions();
    change();
  }

  // THEORY SCORING
  
  // "" for null/unchecked theories, else "<score>% / <pos-neg summary>"
  S renderTheoryScore(Theory t) {
    //ret renderPosNegCounts(t.examples);
    ret t == null || t.examples.isEmpty() ? "" : iround(theoryScore(t)) + "%"
      + " / " + renderPosNegScore2(t.examples);
  }

  // adjusted + 3b1b
  // (null theory scores -100 so it sorts below every real theory)
  double theoryScore(Theory t) {
    ret t == null ? -100 : adjustConfidence(threeB1BScore(t.examples));
  }
  
  // QUEUE HELPER

  Runnable rInThinkQ(Runnable r) { ret rInQ(thinkQ, r); }
  
  // ADD + REMOVE + CLEAN UP THEORIES

  // Adds theory unless an equal one exists (theories is a set).
  // Note: the new theory is not checked here.
  void addTheory(Theory theory) {
    if (theories.add(theory)) {
      addTheoryToCollectors(theory);
      theoriesChanged();
    }
  }

  // Forgets all theories, including their entries in the per-label
  // collectors (otherwise predictions would keep using them).
  void clearTheories {
    for (Theory t : theories)
      removeTheoryFromCollectors(t);
    theories.clear();
    theoriesChanged();
  }
  
  // theories with exactly minScore will go too
  // Removed theories are also dropped from the per-label collectors.
  void forgetBadTheories(double minScore) {
    new L<Theory> bad;
    for (Theory t : theories)
      if (theoryScore(t) <= minScore)
        bad.add(t);
    if (empty(bad)) ret;
    for (Theory t : bad) {
      // remove from collector first, while its sort key is still valid
      removeTheoryFromCollectors(t);
      theories.remove(t);
    }
    theoriesChanged();
  }
  
  // CHECK PROPOSITIONS + THEORIES

  // Evaluates a proposition (And, Not or MsgProp) on msg using
  // three-valued logic: true, false, or null = unknown.
  Bool checkMsgProp(O prop, Msg msg) {
    if (prop cast And) {
      // Kleene AND: a definite false wins, otherwise any unknown makes the
      // result unknown. (Plain && would auto-unbox a null Bool and throw.)
      Bool a = checkMsgProp(prop.a, msg);
      if (isFalse(a)) ret false;
      Bool b = checkMsgProp(prop.b, msg);
      if (isFalse(b)) ret false;
      if (a == null || b == null) ret null;
      ret true;
    }
    if (prop cast Not) ret not(checkMsgProp(prop.a, msg));
    ret ((MsgProp) prop).check(msg);
  }

  // null if theory is null or its lhs is unknown for msg
  Bool evalTheoryLHS(Theory theory, Msg msg) {
    ret theory == null ? null
      : checkMsgProp(theory.statement.lhs, msg);
  }

  // Does the theory hold on msg? null if either side is unknown.
  // bidiMode: lhs <=> rhs, otherwise lhs => rhs.
  Bool testTheoryOnMsg(Theory theory, Msg msg) {
    Bool lhs = evalTheoryLHS(theory, msg);
    Bool rhs = checkMsgProp(theory.statement.rhs, msg);
    if (lhs == null || rhs == null) null;
    if (bidiMode)
      ret eq(lhs, rhs);
    else
      ret isTrue(rhs) || isFalse(lhs);
  }

  void checkAllTheories {
    for (Theory theory : theories)
      checkTheory_noTrigger(theory);
    theoriesChanged();
  }

  void checkTheory(Theory theory) {
    checkTheory_noTrigger(theory);
    theoriesChanged();
  }

  // Re-tests theory on all messages. If the results changed, the theory is
  // re-inserted into its label collector (which is sorted by score).
  void checkTheory_noTrigger(Theory theory) {
    new PosNeg<Msg> pn;
    fOr (Msg msg : msgs)
      pn.add(msg, testTheoryOnMsg(theory, msg));
    if (!eq(theory.examples, pn)) {
      removeTheoryFromCollectors(theory);
      theory.examples = pn;
      addTheoryToCollectors(theory);
      change();
    }
  }
  
  // label named in the theory's rhs, or null if rhs is not a label proposition
  S targetLabelOfTheory(Theory theory) {
    O o = theory.statement.rhs;
    if (o cast HasLabel) ret o.label;
    if (o cast DoesntHaveLabel) ret o.label;
    null;
  }

  // CANONICALIZE LABELS

  S cleanLabel(S label) { ret upper(label); }
  
  // THEORY + LABEL UPDATES
  
  void addTheoryToCollectors(Theory theory) {
    S lbl = targetLabelOfTheory(theory);
    if (lbl != null)
      labelByName(lbl).bestTheories.add(theory);
  }

  void removeTheoryFromCollectors(Theory theory) {
    S lbl = targetLabelOfTheory(theory);
    if (lbl != null)
      labelByName(lbl).bestTheories.remove(theory);
  }

  // get or create the Label object for name
  Label labelByName(S name) {
    ret getOrCreate(labelsByName, name, () -> new Label(name));
  }

  // Rebuilds the transient labelsByName from allLabels and theories.
  // Called once at start; calling it again would add theories twice.
  void updateLabelsByName() {
    for (S lbl : allLabels)
      labelByName(lbl);
    for (Theory t : theories)
      addTheoryToCollectors(t);
  }

  // MAKE FEATURE EXTRACTORS

  // Registers one feature per text extractor, each applied to the value
  // of the feature named textFeature.
  void makeTextExtractors(S textFeature) {
    for (WithName<IF1<S, O>> f : textExtractors()) {
      IF1<S, O> theFunction = f!;
      featureExtractors.put(f.name, env -> theFunction.get((S) env.getFeature(textFeature)));
    }
  }

  L<WithName<IF1<S, O>>> textExtractors() {
    new L<WithName<IF1<S, O>>> l;
    l.add(WithName<>("number of words", lambda1 numberOfWords));
    l.add(WithName<>("number of characters", lambda1 l));
    for (char c : characters("\"', .-_"))
      l.add(WithName<>("contains " + quote(c), s -> contains(s, c)));
    /*for (S word : concatAsCISet(lambdaMap words(collect text(msgs))))
      l.add(WithName<>("contains word " + quote(word), s -> containsWord(s, word)));*/
    ret l;
  }
  
  // GUI: Import messages dialog, warn on delete, other stuff

  // Imports one message per line (as user messages), skipping texts that
  // are already present.
  void importMsgs {
    inputMultiLineText("Messages to import (one per line)", voidfunc(S text) {
      Cl<S> toImport = listMinusSet(asOrderedSet(tlft(text)), collectAsSet text(msgs));
      if (msgs == null) msgs = ll();
      for (S line : toImport)
        msgs.add(new Msg(true, line));
      change();
      infoBox(nMessages(toImport) + " imported");
      updateObjectsTable();
      showRandomMsg();
    });
  }
  
  bool warnOnDelete() { true; }
  
  void showSelectedObject enter {
    showMsgs(llNotNulls(get(msgs, selectedRow(objectsTable))));
  }
  
  // Deletes the selected messages together with their cached features and
  // training entries.
  void deleteSelectedMessages {
    Set<Msg> toDelete = asSet(getMulti(msgs, selectedRows(objectsTable)));
    removeFromCollection(msgs, toDelete);
    removeAll(msg2features, toDelete);
    msg2label_new.removeAllA(toDelete);
    change();
    updateObjectsTable();
    showRandomMsg();
  }
  
  // DEEP THOUGHT (make more complex theories for label)
  
  void deepThought(S label) {
  }
}

Author comment

Began life as a copy of #1028063

download  show line numbers  debug dex  old transpilations   

Travelled to 7 computer(s): bhatertpkbcr, mqqgnosmbjvj, pyentgdyhuwx, pzhvpgtvlbxg, tvejysmllsmz, vouqrxazstgt, xrpafgyirdlv

No comments. add comment

Snippet ID: #1028066
Snippet name: Auto Classifier v5 [learning message classifier]
Eternal ID of this version: #1028066/30
Text MD5: 28da243219e7f859a308c41d5f12514d
Transpilation MD5: d994e645303cfdfd02142b8f06886da7
Author: stefan
Category: javax / a.i.
Type: JavaX source code (Dynamic Module)
Public (visible to everyone): Yes
Archived (hidden from active list): No
Created/modified: 2020-06-07 19:27:39
Source code size: 19856 bytes / 600 lines
Pitched / IR pitched: No / No
Views / Downloads: 205 / 3359
Version history: 29 change(s)
Referenced in: [show references]