Uses 911K of libraries. Click here for Pure Java version (19059L/103K).
1 | !7 |
2 | |
3 | cmodule AutoClassifier > DynConvo {
|
// A candidate rule (statement.lhs => statement.rhs) about messages,
// plus the result of testing that rule against the messages.
srecord Theory(BasicLogicRule statement) {
  // per-message test outcomes; replaced by checkTheory_noTrigger
  new PosNeg<Msg> examples;
  //bool iff; // <=> instead of only =>

  // Human-readable form; rules whose left side is MPTrue read as
  // "Every message is ...", and bidiMode renders "<=>".
  toString {
    if (statement.lhs instanceof MPTrue)
      ret str("Every message is " + statement.rhs);
    if (bidiMode)
      ret str(statement.lhs + " <=> " + statement.rhs);
    ret str(statement);
  }
}
10 | |
// A proposition about a message.
// check() answers true/false, or null when the answer is unknown.
asclass MsgProp {
  abstract Bool check(Msg msg);
}
|
13 | |
// The proposition that holds for every message.
srecord MPTrue() > MsgProp {
  Bool check(Msg msg) {
    ret true;
  }

  toString {
    ret "always";
  }
}
18 | |
// Proposition "the message was trained with this label".
record HasLabel(S label) > MsgProp {
  // Returns whatever msg2label_new stores for (msg, label);
  // a missing entry comes back as null (= unknown).
  Bool check(Msg msg) {
    Bool trained = msg2label_new.get(msg, label);
    ret trained;
  }

  toString {
    ret label;
  }
}
23 | |
// Proposition "the message was trained as NOT having this label".
record DoesntHaveLabel(S label) > MsgProp {
  // Returns the negation of the value stored in msg2label_new for
  // (msg, label), or null when nothing is stored (= unknown).
  Bool check(Msg msg) {
    Bool trained = msg2label_new.get(msg, label);
    // Propagate "unknown" explicitly rather than handing a null Bool to not().
    if (trained == null) ret null;
    ret isFalse(trained);
  }

  toString {
    ret "not " + label;
  }
}
28 | |
// Proposition "feature F of the message equals a given value".
record FeatureValueIs(S feature, O value) > MsgProp {
  // Compares the message's current feature value with the expected one using eq().
  Bool check(Msg msg) {
    O actual = getMsgFeature(msg, feature);
    ret eq(actual, value);
  }

  toString {
    ret feature + "=" + value;
  }
}
33 | |
// A label name together with the theories that target it,
// kept ordered by theoryScore (reverse comparator).
class Label {
  S name;

  *() {}
  *(S *name) {}

  TreeSetWithDuplicates<Theory> bestTheories = new(reverseComparatorFromCalculatedField theoryScore());

  // First theory in bestTheories (null when there is none).
  Theory bestTheory() {
    ret first(bestTheories);
  }

  // Score of the best theory; theoryScore() maps a null theory to -100.
  int score() {
    ret theoryScore(bestTheory());
  }
}
45 | |
// --- settings ---
// NOTE: minAdjustedScoreToDisplay is only referenced from commented-out code in predictionsForMsg.
switchable double minAdjustedScoreToDisplay = 50;
// when true, sendInput2 jumps to a random message after a label was entered
switchable bool autoNext = false;
static bool bidiMode = true; // treat all theories as bidirectional

// --- messages and derived data ---
L<Msg> msgs; // full dialog
L<Msg> shownMsgs; // what showMsgs() was last called with
// feature maps per message, created on demand through calcMsgFeatures
transient Map<Msg, Map<S, O>> msg2features = AutoMap<>(lambda1 calcMsgFeatures);
new LinkedHashSet<Theory> theories;
S analysisText;

// --- UI components (assigned in mainPart) ---
transient JTable theoryTable;
transient JTable labelsTable;
transient JTable trainedExamplesTable;
transient JTable objectsTable;
transient JTabbedPane tabs;
transient SingleComponentPanel scpPredictions;

// --- labels ---
// filled lazily through labelByName()
transient new Map<S, Label> labelsByName;

new Set<S> allLabels;
// listeners invoked with each label name (see callFAllOnAll usages)
transient new L<IVF1<S>> onNewLabel;
// (message, label) -> trained true/false
new DoubleKeyedMap<Msg, S, Bool> msg2label_new;
transient new Map<S, FeatureExtractor<Msg>> featureExtractors;
// queue created in start; UI actions are routed into it via rInThinkQ
transient Q thinkQ;
65 | |
// Environment handed to a FeatureExtractor: gives access to the
// object being analyzed and to its other features by name.
sinterface FeatureEnv<A> {
  // the object whose features are being computed
  A mainObject();

  // value of another feature of the same object
  O getFeature(S name);
}
70 | |
// Computes one feature value from a FeatureEnv.
sinterface FeatureExtractor<A> {
  O get(FeatureEnv<A> env);
}
74 | |
// Module start: creates the thought queue and schedules the
// initialization below on it. The order of the steps matters:
// listeners and feature extractors are registered before the
// listeners are fired for the known labels.
start {
  thinkQ = dm_startQ("Thought Queue");
  thinkQ.add(r {
    // legacy + after deletion cleaning
    setField(allLabels := asTreeSet(msg2label_new.bKeys()));
    updateLabelsByName();

    // persist whenever a label is announced
    onNewLabel.add(newLabel -> change());

    // register the theory generators (they are label listeners, too)
    makeTheoriesAboutLabels();
    makeTheoriesAboutFeaturesAndLabels();

    // one extractor per field of Msg, reading the field from the main object
    for (S fieldName : fields(Msg)) {
      featureExtractors.put(fieldName, env -> getOpt(env.mainObject(), fieldName));
    }

    // extractors derived from the "text" feature
    makeTextExtractors("text");

    // fire all listeners for every known label
    callFAllOnAll(onNewLabel, allLabels);

    msg2labelUpdated();
    //showRandomMsg();
  });
}
98 | |
// Registers a label listener that proposes, for any label X, the two
// unconditional theories "every message has X" and
// "every message doesn't have X" (in that order).
void makeTheoriesAboutLabels {
  onNewLabel.add(lbl -> {
    for (O rhs : ll(new HasLabel(lbl), new DoesntHaveLabel(lbl))) {
      addTheory(new Theory(BasicLogicRule(new MPTrue, rhs)));
    }
  });
}
108 | |
// Registers a label listener that proposes, for any label X, every
// registered feature F and every candidate value V of F, the theories
// "F = V => has X" and "F = V => doesn't have X".
void makeTheoriesAboutFeaturesAndLabels {
  onNewLabel.add(lbl -> {
    for (S feature : keys(featureExtractors)) {
      // candidate values come from possibleValuesOfFeatureRelatedToLabel
      for (O value : possibleValuesOfFeatureRelatedToLabel(feature, lbl)) {
        for (O rhs : ll(new HasLabel(lbl), new DoesntHaveLabel(lbl))) {
          O lhs = new FeatureValueIs(feature, value);
          addTheory(new Theory(BasicLogicRule(lhs, rhs)));
        }
      }
    }
  });
}
122 | |
// Values a feature can take independent of any data:
// {false, true} for boolean fields of Msg, otherwise an empty set.
Set possibleValuesOfFeature(S feature) {
  bool isBool = isBoolField(Msg, feature);
  if (!isBool) ret litset();
  ret litset(false, true);
}
128 | |
// possibleValuesOfFeature(feature) extended by the feature values
// actually observed on the messages related to the label.
Set possibleValuesOfFeatureRelatedToLabel(S feature, S label) {
  Set values = possibleValuesOfFeature(feature);
  fOr (Msg related : getMsgsRelatedToLabel(label)) {
    values.add(getMsgFeature(related, feature));
  }
  ret values;
}
135 | |
// Builds the feature map for one message.
// returns AutoMap with no realized entries: each feature is computed
// by its registered extractor when the map is asked for it.
Map<S, O> calcMsgFeatures(Msg msg) {
  // The map and the environment refer to each other, so the
  // environment is passed to the map's function through a Var.
  new Var<FeatureEnv<Msg>> envHolder;
  AutoMap<S, O> features = new(feature -> featureExtractors.get(feature).get(envHolder!));
  envHolder.set(new FeatureEnv<Msg> {
    Msg mainObject() { ret msg; }

    // other features are resolved through the same map
    O getFeature(S feature) { ret features.get(feature); }
  });
  ret features;
}
146 | |
// Shows the given messages. When exactly one message is shown, also
// fills the analysis text and the predictions panel for it;
// otherwise the analysis text is emptied.
void showMsgs(L<Msg> l) {
  setField(shownMsgs := l);
  setMsgs(l);

  if (l(shownMsgs) != 1) {
    setField(analysisText := "");
    return;
  }

  Msg msg = first(shownMsgs);
  S labelsPart = "Trained Labels: " + or2(renderBoolMap(getMsgLabels(msg)), "-");
  S featuresPart = "Features:\n" + formatColonProperties_quoteStringValues(
    msg2features.get(msg));
  setField(analysisText := joinWithEmptyLines(labelsPart, featuresPart));

  // one row per prediction with Right/Wrong buttons
  setSCPComponent(scpPredictions,
    scrollableStackWithSpacing(map(predictionsForMsg(msg), p ->
      withSideMargin(jLabelWithButtons(iround(p.adjustedConfidence) + "%: " + p.predictedLabel(),
        "Right", rThread { acceptPrediction(p) },
        "Wrong", rThread { rejectPrediction(p) })))));
}
164 | |
// Refreshes analysis text and predictions by re-showing the
// currently shown messages.
void updatePredictions() {
  L<Msg> current = shownMsgs;
  showMsgs(current);
}
168 | |
// A predicted label for a message.
// plus = true means "has the label", false means "not <label>".
srecord Prediction(S label, bool plus, double adjustedConfidence) {
  // The label as it would be typed by the user ("X" or "not X").
  S predictedLabel() {
    S prefix = plus ? "" : "not ";
    ret prefix + label;
  }

  toString {
    S confidence = " (confidence: " + iround(adjustedConfidence) + "%)";
    ret predictedLabel() + confidence;
  }
}
178 | |
// Makes one prediction per label from that label's best theory.
// Labels without a best theory, or whose theory's left side is
// unknown for this message, are skipped.
// Result is sorted by adjusted confidence, highest first.
// (Grouping positives before "not"s and filtering by
// minAdjustedScoreToDisplay are both currently disabled.)
L<Prediction> predictionsForMsg(Msg msg) {
  new L<Prediction> predictions;
  for (Label label : values(labelsByName)) {
    Theory t = label.bestTheory();
    if (t == null) continue;

    Bool lhs = evalTheoryLHS(t, msg);
    if (lhs == null) continue;

    // a "doesn't have label" theory predicts the opposite of its left side
    bool negated = t.statement.rhs instanceof DoesntHaveLabel;
    bool prediction = negated ? !lhs : lhs;

    double conf = threeB1BScore(t.examples);
    double adjusted = adjustConfidence(conf);
    //if (adjusted < minAdjustedScoreToDisplay) continue;
    predictions.add(new Prediction(label.name, prediction, adjusted));
  }
  ret sortedByCalculatedFieldDesc(predictions, p -> /*pair(p.plus,*/ p.adjustedConfidence/*)*/);
}
192 | |
// Stretches a confidence from the range 50-100 to 0-100 (might look
// better); anything that would come out below 0 is returned as 0.
double adjustConfidence(double x) {
  double stretched = (x-50)*2;
  ret max(0, stretched);
}
197 | |
// Shows one randomly chosen message of the dialog.
void showRandomMsg {
  L<Msg> picked = randomElementAsList(msgs);
  showMsgs(picked);
}
201 | |
// Shows the message before the first shown one (cyclic over msgs).
void showPrevMsg {
  Msg current = first(shownMsgs);
  Msg previous = prevInCyclicList(msgs, current);
  showMsgs(llNonNulls(previous));
}
205 | |
// Shows the message after the first shown one (cyclic over msgs).
void showNextMsg {
  Msg current = first(shownMsgs);
  Msg next = nextInCyclicList(msgs, current);
  showMsgs(llNonNulls(next));
}
209 | |
// "Right" button: feeds the predicted label back in as user input.
void acceptPrediction(Prediction p) {
  if (p == null) return;
  sendInput2(p.predictedLabel());
}
213 | |
// "Wrong" button: feeds the opposite of the predicted label
// (same label, flipped plus) back in as user input.
void rejectPrediction(Prediction p) {
  if (p == null) return;
  sendInput2(cloneWithFlippedBoolField plus(p).predictedLabel());
}
217 | |
// Treats user input as a label for the single shown message:
// "not X" trains label X as false, anything else trains it as true.
// Does nothing unless exactly one message is shown.
@Override
void sendInput2(S s) {
  if (l(shownMsgs) != 1) return;

  Msg shown = first(shownMsgs);
  new Matches m;
  S label;
  bool value;
  if "not ..." {
    label = cleanLabel(m.rest());
    value = false;
  } else {
    label = cleanLabel(s);
    value = true;
  }

  doubleKeyedMapPutVerbose(+msg2label_new, shown, label, value);
  msg2labelUpdated(label);
  if (autoNext) showRandomMsg();
  change();
}
238 | |
// Trained label values (label -> true/false) of one message.
Map<S, Bool> getMsgLabels(Msg msg) {
  Map<S, Bool> trained = msg2label_new.getA(msg);
  ret trained;
}
// Messages that have a trained value (true or false) for the label.
Set<Msg> getMsgsRelatedToLabel(S label) {
  ret msg2label_new.asForB(label);
}
|
243 | |
// Called after a trained value for the given label changed:
// re-checks the theories collected for that label, then runs the
// general update.
void msg2labelUpdated(S label) {
  // iterate over a copy, since checkTheory re-inserts theories into bestTheories
  L<Theory> toCheck = cloneList(labelByName(label).bestTheories);
  for (Theory t : toCheck) {
    checkTheory(t);
  }
  msg2labelUpdated();
}
249 | |
// General update after msg2label_new changed: adds the map's labels
// to allLabels, fires the onNewLabel listeners for the ones returned
// by addAll_returnNew, and refreshes the trained examples table.
void msg2labelUpdated() {
  callFAllOnAll(onNewLabel,
    addAll_returnNew(allLabels, msg2label_new.bKeys()));
  updateTrainedExamplesTable();
}
254 | |
// Fills the trained examples table: one row per message in
// msg2label_new with the speaker-prefixed text and its labels.
void updateTrainedExamplesTable {
  dataToTable_uneditable(trainedExamplesTable, map(msg2label_new.map1, (msg, labels) -> {
    S speaker = msg.fromUser ? "User" : "Bot";
    ret litorderedmap(
      "Message" := speaker + ": " + msg.text,
      "Labels" := renderBoolMap(labels));
  }));
}
261 | |
// Builds the main UI: on one side the focused message above the
// analysis text and predictions panel, on the other side the tabs
// (messages, labels, theories, trained examples).
// Tab titles start empty; updateTabs is attached to the tab pane via with().
JComponent mainPart() {
  ret jhsplit(
    jvsplit(
      jCenteredSection("Focused Message", super.mainPart()),
      jhsplit(
        jCenteredSection("Message Analysis", dm_textArea analysisText()),
        jCenteredSection("Predictions", scpPredictions = singleComponentPanel())
      )),
    with(r updateTabs, tabs = jtabs(
      // tab 1: messages, with import button
      "", with(r updateObjectsTable, withRightAlignedButtons(
        objectsTable = sexyTable(),
        "Import messages...", rThreadEnter importMsgs)),
      // tab 2: labels
      "", with(r updateLabelsTable, labelsTable = sexyTable()),
      // tab 3: theories (searchable)
      "", with(r updateTheoryTable, tableWithSearcher2_returnPanel(theoryTable = sexyTable())),
      // tab 4: trained examples (searchable)
      "", with(r updateTrainedExamplesTable, tableWithSearcher2_returnPanel(trainedExamplesTable = sexyTable()))
    )));
}
278 | |
// Sets the four tab titles to the current counts of messages,
// labels, theories and trained examples.
void updateTabs {
  S messagesTitle = firstLetterToUpper(nMessages(msgs));
  S labelsTitle = firstLetterToUpper(nLabels(labelsByName));
  S theoriesTitle = firstLetterToUpper(nTheories(theories));
  S examplesTitle = n2(msg2label_new.aKeys(), "Trained Example");
  setTabTitles(tabs, messagesTitle, labelsTitle, theoriesTitle, examplesTitle);
}
286 | |
// Fills the theory table with all theories, ordered by the score of
// their examples (descending); theories without examples use null as
// sort key.
void updateTheoryTable {
  L<Theory> ranked = sortedByCalculatedFieldDesc(theories,
    theory -> theory.examples == null ? null : theory.examples.score());
  dataToTable_uneditable(theoryTable, map(ranked, theory ->
    litorderedmap(
      "Score" := renderTheoryScore(theory),
      "Theory" := str(theory))));
}
294 | |
// Groups all theories by their target label and keeps, per label,
// the one with the highest theoryScore.
Map<S, Theory> labelsToBestTheoryMap() {
  Map<S, L<Theory>> byLabel = multiMapToMap(multiMapIndex targetLabelOfTheory(theories));
  ret mapValues(byLabel,
    candidates -> highestBy theoryScore(candidates));
}
300 | |
// Fills the messages table (if it exists) with the text of every message.
void updateObjectsTable enter {
  dataToTable_uneditable_ifHasTable(objectsTable,
    map(msgs, msg -> litorderedmap("Text" := msg.text)));
}
306 | |
// Fills the labels table (if it exists), labels ordered by score
// (descending). "Best Theory" shows the first of the theories tied
// for first place, prefixed with "[+n] " when n more are tied.
void updateLabelsTable enter {
  L<Label> ranked = sortedByCalculatedFieldDesc(values(labelsByName), lbl -> lbl.score());
  dataToTable_uneditable_ifHasTable(labelsTable, map(ranked, label -> {
    Cl<Theory> tied = label.bestTheories.tiedForFirst();
    Theory top = first(tied);

    S bestTheoryText = "";
    if (!empty(tied)) {
      S others = l(tied) > 1 ? "[+" + (l(tied)-1) + "] " : "";
      bestTheoryText = others + top;
    }

    ret litorderedmap(
      "Label" := label.name,
      "Prediction Confidence" := renderTheoryScore(top),
      "Best Theory" := bestTheoryText);
  }));
}
318 | |
// Renders a theory's score as "<adjusted confidence>% / <pos/neg score and count>".
// Returns "" for a null theory, or a theory whose examples are
// missing or empty (updateTheoryTable also treats examples == null
// as possible, so it is guarded here too).
S renderTheoryScore(Theory t) {
  //ret renderPosNegCounts(t.examples);
  if (t == null || t.examples == null || t.examples.isEmpty()) ret "";
  S confidence = iround(adjustConfidence(threeB1BScore(t.examples))) + "%";
  ret confidence + " / " + renderPosNegScoreAndCount(t.examples);
}
324 | |
// Score used to rank theories (see Label.bestTheories).
// A null theory, or a theory without examples, scores -100;
// updateTheoryTable also treats examples == null as possible,
// so that case is guarded instead of throwing.
int theoryScore(Theory t) {
  if (t == null || t.examples == null) ret -100;
  ret t.examples.score();
}
328 | |
// To be called after theories were added, removed or re-scored:
// refreshes all dependent views and marks the module as changed.
void theoriesChanged {
  // tables
  updateTheoryTable();
  updateLabelsTable();
  // tab titles carry the counts
  updateTabs();
  // predictions depend on each label's best theory
  updatePredictions();
  change();
}
336 | |
// Module visualization: the inherited UI plus a button row.
// All button actions except the data menu items and the upgrade are
// routed into the thought queue via rInThinkQ.
visual
  withCenteredButtons(super,
    // navigation
    "<", rInThinkQ(r showPrevMsg),
    "Show random msg", rInThinkQ(r showRandomMsg),
    ">", rInThinkQ(r showNextMsg),
    // pop-down menu
    jPopDownButton_noText(flattenObjectArray(
      "Check theories", rInThinkQ(r checkAllTheories),
      "Clear theories", rInThinkQ(r clearTheories),
      "Update predictions", rInThinkQ(r updatePredictions),
      dm_importAndExportAllDataMenuItems(),
      "Upgrade to v3", rThreadEnter upgradeMe)));
348 | |
// Wraps a runnable so that it is handed to thinkQ via rInQ.
Runnable rInThinkQ(Runnable r) {
  ret rInQ(thinkQ, r);
}
|
350 | |
// Adds a theory unless it is already in the set. A new theory is
// logged, registered with its label's collector and announced via
// theoriesChanged().
void addTheory(Theory theory) {
  if (!theories.add(theory)) return; // already known

  print("New theory: " + theory);
  addTheoryToCollectors(theory);
  theoriesChanged();
}
358 | |
// Removes all theories.
// The theories are also taken out of the per-label collectors
// (Label.bestTheories); otherwise the labels table and the
// predictions would keep using theories that no longer exist.
void clearTheories {
  for (Theory t : theories)
    removeTheoryFromCollectors(t);
  theories.clear();
  theoriesChanged();
}
|
360 | |
// Evaluates a proposition (And, Not or a MsgProp) on a message.
// Returns true/false, or null when the result is unknown.
// And/Not use three-valued logic so that an unknown (null) operand
// yields an answer instead of failing when the null Bool is unboxed:
//   And: false if either side is false, else null if either side is
//        unknown, else true. The right side is not evaluated when the
//        left side is false.
//   Not: unknown stays unknown.
Bool checkMsgProp(O prop, Msg msg) {
  if (prop instanceof And) {
    And conj = (And) prop;
    Bool a = checkMsgProp(conj.a, msg);
    if (isFalse(a)) ret false;
    Bool b = checkMsgProp(conj.b, msg);
    if (isFalse(b)) ret false;
    if (a == null || b == null) ret null;
    ret true;
  }
  if (prop instanceof Not) {
    Bool inner = checkMsgProp(((Not) prop).a, msg);
    if (inner == null) ret null;
    ret isFalse(inner);
  }
  ret ((MsgProp) prop).check(msg);
}
366 | |
// Evaluates the left side of a theory's rule on a message.
// Returns null for a null theory (and whatever checkMsgProp returns otherwise).
Bool evalTheoryLHS(Theory theory, Msg msg) {
  if (theory == null) ret null;
  ret checkMsgProp(theory.statement.lhs, msg);
}
371 | |
// Tests whether a message confirms a theory.
// Returns null when either side of the rule is unknown for the
// message. In bidiMode the rule holds when both sides are equal;
// otherwise it is a plain implication (rhs true or lhs false).
Bool testTheoryOnMsg(Theory theory, Msg msg) {
  Bool lhs = evalTheoryLHS(theory, msg);
  Bool rhs = checkMsgProp(theory.statement.rhs, msg);
  if (lhs == null || rhs == null) ret null;

  if (bidiMode) {
    ret eq(lhs, rhs);
  }
  ret isTrue(rhs) || isFalse(lhs);
}
381 | |
// Re-tests every theory, then announces the change once at the end.
void checkAllTheories {
  for (Theory t : theories) {
    checkTheory_noTrigger(t);
  }
  theoriesChanged();
}
387 | |
// Re-tests a single theory and announces the change.
void checkTheory(Theory theory) {
  checkTheory_noTrigger(theory);
  // unlike checkTheory_noTrigger, refresh the views right away
  theoriesChanged();
}
392 | |
// Tests a theory against all messages and stores the outcome in
// theory.examples, without calling theoriesChanged().
// msgs may still be null (importMsgs handles that case), so the
// loop uses the null-safe fOr.
void checkTheory_noTrigger(Theory theory) {
  new PosNeg<Msg> pn;
  fOr (Msg msg : msgs)
    pn.add(msg, testTheoryOnMsg(theory, msg));

  if (!eq(theory.examples, pn)) {
    // bestTheories is ordered by the theory's score, so the theory is
    // taken out before its examples change and re-inserted afterwards
    removeTheoryFromCollectors(theory);
    theory.examples = pn;
    addTheoryToCollectors(theory);
    change();
  }
}
404 | |
// Normalizes a label name by passing it through upper().
S cleanLabel(S label) {
  S normalized = upper(label);
  ret normalized;
}
|
406 | |
// The label a theory makes a statement about: taken from the right
// side of its rule when that is a HasLabel or DoesntHaveLabel,
// null otherwise.
S targetLabelOfTheory(Theory theory) {
  O rhs = theory.statement.rhs;
  if (rhs instanceof HasLabel)
    ret ((HasLabel) rhs).label;
  if (rhs instanceof DoesntHaveLabel)
    ret ((DoesntHaveLabel) rhs).label;
  ret null;
}
413 | |
// Registers the theory in the bestTheories of its target label
// (no-op for theories without a target label).
void addTheoryToCollectors(Theory theory) {
  S target = targetLabelOfTheory(theory);
  if (target == null) return;
  labelByName(target).bestTheories.add(theory);
}
419 | |
// Removes the theory from the bestTheories of its target label
// (no-op for theories without a target label).
void removeTheoryFromCollectors(Theory theory) {
  S target = targetLabelOfTheory(theory);
  if (target == null) return;
  labelByName(target).bestTheories.remove(theory);
}
425 | |
// Looks up the Label object for a name in labelsByName, creating
// it on first use.
Label labelByName(S name) {
  Label label = getOrCreate(labelsByName, name, () -> new Label(name));
  ret label;
}
429 | |
// Makes sure every known label has a Label object, then registers
// all theories with their labels' collectors.
void updateLabelsByName() {
  for (S lbl : allLabels) {
    labelByName(lbl); // called for its create-if-missing effect
  }
  for (Theory t : theories) {
    addTheoryToCollectors(t);
  }
}
436 | |
// Value of one feature of a message, looked up through msg2features.
O getMsgFeature(Msg msg, S feature) {
  Map<S, O> features = msg2features.get(msg);
  ret features.get(feature);
}
440 | |
// Registers one feature extractor per entry of textExtractors().
// Each applies its function to the value of the given text feature
// (cast to a string) and is stored under the entry's name.
void makeTextExtractors(S textFeature) {
  for (WithName<IF1<S, O>> named : textExtractors()) {
    IF1<S, O> fn = named!;
    featureExtractors.put(named.name,
      env -> fn.get((S) env.getFeature(textFeature)));
  }
}
447 | |
// The named text functions used as features: word count, character
// count, and one "contains <char>" test per character of "\"', .-_".
L<WithName<IF1<S, O>>> textExtractors() {
  new L<WithName<IF1<S, O>>> extractors;
  extractors.add(WithName<>("number of words", lambda1 numberOfWords));
  extractors.add(WithName<>("number of characters", lambda1 l));
  for (char c : characters("\"', .-_")) {
    extractors.add(WithName<>("contains " + quote(c), s -> contains(s, c)));
  }
  // disabled: one "contains word" feature per word occurring in msgs
  /*for (S word : concatAsCISet(lambdaMap words(collect text(msgs))))
    l.add(WithName<>("contains word " + quote(word), s -> containsWord(s, word)));*/
  ret extractors;
}
458 | |
// Asks for a multi-line text and adds every line whose text is not
// yet among the messages as a new message, then reports the count,
// refreshes the messages table and shows a random message.
void importMsgs {
  inputMultiLineText("Messages to import (one per line)", voidfunc(S text) {
    // drop duplicates within the input and lines that already exist as messages
    Cl<S> toImport = listMinusSet(asOrderedSet(tlft(text)), collectAsSet text(msgs));
    if (msgs == null) msgs = ll();
    for (S line : toImport) {
      msgs.add(new Msg(true, line));
    }
    change();
    infoBox(nMessages(toImport) + " imported");
    updateObjectsTable();
    showRandomMsg();
  });
}
471 | |
// "Upgrade to v3": exports this module's structure to the default
// file, then switches the module to library ID #1028058/AutoClassifier.
void upgradeMe {
  // export first, then switch
  dm_exportStructureToDefaultFile(this);
  dm_changeModuleLibID(this, "#1028058/AutoClassifier");
}
476 | } |
download show line numbers debug dex old transpilations
Travelled to 8 computer(s): bhatertpkbcr, mqqgnosmbjvj, pyentgdyhuwx, pzhvpgtvlbxg, qsqiayxyrbia, tvejysmllsmz, vouqrxazstgt, xrpafgyirdlv
No comments. add comment
| Snippet ID: | #1028055 |
| Snippet name: | Auto Classifier v2 [learning message classifier] |
| Eternal ID of this version: | #1028055/20 |
| Text MD5: | 76e4933e577f6b608884073002df95f2 |
| Transpilation MD5: | 0d6903e8aa8d8b104bde8f271eebc20a |
| Author: | stefan |
| Category: | |
| Type: | JavaX source code (Dynamic Module) |
| Public (visible to everyone): | Yes |
| Archived (hidden from active list): | No |
| Created/modified: | 2020-05-15 13:12:27 |
| Source code size: | 15445 bytes / 476 lines |
| Pitched / IR pitched: | No / No |
| Views / Downloads: | 434 / 792 |
| Version history: | 19 change(s) |
| Referenced in: | [show references] |