Uses 911K of libraries. Click here for Pure Java version (19135L/103K).
1 | !7 |
2 | |
3 | cmodule AutoClassifier > DynConvo { |
4 | // THEORY BUILDING BLOCKS (Theory + MsgProp + subclasses) |
5 | |
6 | srecord Theory(BasicLogicRule statement) { |
7 | new PosNeg<Msg> examples; |
8 | //bool iff; // <=> instead of only => |
9 | toString { ret str(statement.lhs instanceof MPTrue ? "Every message is " + statement.rhs |
10 | : bidiMode ? statement.lhs + " <=> " + statement.rhs : statement); } |
11 | } |
12 | |
// Base class for propositions about a message.
asclass MsgProp {
  // Evaluates the proposition for msg. Returns null if the answer is unknown.
  abstract Bool check(Msg msg);
}
15 | |
16 | srecord MPTrue() > MsgProp { |
17 | Bool check(Msg msg) { true; } |
18 | toString { ret "always"; } |
19 | } |
20 | |
21 | record HasLabel(S label) > MsgProp { |
22 | Bool check(Msg msg) { ret msg2label_new.get(msg, label); } |
23 | toString { ret label; } |
24 | } |
25 | |
26 | record DoesntHaveLabel(S label) > MsgProp { |
27 | Bool check(Msg msg) { ret not(msg2label_new.get(msg, label)); } |
28 | toString { ret "not " + label; } |
29 | } |
30 | |
31 | record FeatureValueIs(S feature, O value) > MsgProp { |
32 | Bool check(Msg msg) { ret eq(getMsgFeature(msg, feature), value); } |
33 | toString { ret feature + "=" + value; } |
34 | } |
35 | |
36 | // LABEL class (with best theories) |
37 | |
// A label name plus the theories that target it.
class Label {
  S name;

  // ordered through a reversed comparator over theoryScore
  TreeSetWithDuplicates<Theory> bestTheories = new(reverseComparatorFromCalculatedField theoryScore());

  *() {}
  *(S *name) {}

  // First element of bestTheories (null-ness is handled by callers).
  Theory bestTheory() { ret first(bestTheories); }

  // Score of the first theory in bestTheories.
  double score() { ret theoryScore(bestTheory()); }
}
49 | |
50 | // FEATURE base classes (FeatureEnv + FeatureExtractor) |
51 | |
52 | sinterface FeatureEnv<A> { |
53 | A mainObject(); |
54 | O getFeature(S name); |
55 | } |
56 | |
// Computes one feature value from a feature environment.
sinterface FeatureExtractor<A> {
  O get(FeatureEnv<A> env);
}
60 | |
61 | // PREDICTION class (output of classifier) |
62 | |
63 | srecord Prediction(S label, bool plus, double adjustedConfidence) { |
64 | toString { |
65 | ret predictedLabel() + " (confidence: " + iround(adjustedConfidence) + "%)"; |
66 | } |
67 | |
68 | S predictedLabel() { |
69 | ret (plus ? "" : "not ") + label; |
70 | } |
71 | } |
72 | |
73 | // DATA (backend) |
74 | |
75 | sbool bidiMode = true; // treat all theories as bidirectional |
76 | L<Msg> msgs; // all messages (order not used yet) |
77 | transient Map<Msg, Map<S, O>> msg2features = AutoMap<>(lambda1 calcMsgFeatures); |
78 | new Set<S> allLabels; |
79 | transient new Map<S, Label> labelsByName; |
80 | new LinkedHashSet<Theory> theories; |
81 | transient Q thinkQ; |
82 | transient new L<IVF1<S>> onNewLabel; |
83 | new DoubleKeyedMap<Msg, S, Bool> msg2label_new; |
84 | transient new Map<S, FeatureExtractor<Msg>> featureExtractors; |
85 | |
86 | // DATA (GUI) |
87 | |
88 | switchable double minAdjustedScoreToDisplay = 50; |
89 | switchable bool autoNext = false; |
90 | L<Msg> shownMsgs; |
91 | S analysisText; |
92 | transient JTable theoryTable, labelsTable, trainedExamplesTable, objectsTable; |
93 | transient JTabbedPane tabs; |
94 | transient SingleComponentPanel scpPredictions; |
95 | |
96 | // START CODE |
97 | |
// Module start: creates the think queue and schedules the whole initialization on it.
start {
  thinkQ = dm_startQ("Thought Queue");
  thinkQ.add(r {
    // legacy + after deletion cleaning
    setField(allLabels := asTreeSet(msg2label_new.bKeys()));
    updateLabelsByName();

    // register label listeners: persistence first, then the theory generators
    onNewLabel.add(lbl -> change());
    makeTheoriesAboutLabels();
    makeTheoriesAboutFeaturesAndLabels();

    // one extractor per field of Msg, then the text-derived features
    for (S fieldName : fields(Msg)) {
      featureExtractors.put(fieldName, env -> getOpt(env.mainObject(), fieldName));
    }
    makeTextExtractors("text");

    // announce all known labels to the listeners registered above
    callFAllOnAll(onNewLabel, allLabels);

    msg2labelUpdated();
    checkAllTheories();
    //showRandomMsg();
  });
}
122 | |
123 | // THEORY MAKING |
124 | |
// Registers a listener that, for any label X, adds the two unconditional theories
// "every message has X" and "every message doesn't have X" (in that order).
void makeTheoriesAboutLabels {
  onNewLabel.add(lbl -> {
    for (O rhs : ll(new HasLabel(lbl), new DoesntHaveLabel(lbl)))
      addTheory(new Theory(BasicLogicRule(new MPTrue, rhs)));
  });
}
134 | |
// Registers a listener that, for any label X, adds theories of the form
// "feature F has value V => message has / doesn't have X" for every feature
// and every candidate value of that feature.
void makeTheoriesAboutFeaturesAndLabels {
  onNewLabel.add(lbl -> {
    for (S feature : keys(featureExtractors)) {
      Set candidateValues = possibleValuesOfFeatureRelatedToLabel(feature, lbl);
      for (O value : candidateValues) {
        for (O rhs : ll(new HasLabel(lbl), new DoesntHaveLabel(lbl))) {
          addTheory(new Theory(BasicLogicRule(
            new FeatureValueIs(feature, value), rhs)));
        }
      }
    }
  });
}
148 | |
149 | // THEORY MAKING (helper functions) |
150 | |
151 | Set possibleValuesOfFeature(S feature) { |
152 | if (isBoolField(Msg, feature)) |
153 | ret litset(false, true); |
154 | ret litset(); |
155 | } |
156 | |
157 | Set possibleValuesOfFeatureRelatedToLabel(S feature, S label) { |
158 | Set set = possibleValuesOfFeature(feature); |
159 | fOr (Msg msg : getMsgsRelatedToLabel(label)) |
160 | set.add(getMsgFeature(msg, feature)); |
161 | ret set; |
162 | } |
163 | |
164 | // CALCULATE FEATURES |
165 | |
166 | O getMsgFeature(Msg msg, S feature) { |
167 | ret msg2features.get(msg).get(feature); |
168 | } |
169 | |
170 | // returns AutoMap with no realized entries |
171 | Map<S, O> calcMsgFeatures(Msg msg) { |
172 | new Var<FeatureEnv<Msg>> env; |
173 | AutoMap<S, O> map = new(feature -> featureExtractors.get(feature).get(env!)); |
174 | env.set(new FeatureEnv<Msg> { |
175 | Msg mainObject() { ret msg; } |
176 | O getFeature(S feature) { ret map.get(feature); } |
177 | }); |
178 | ret map; |
179 | } |
180 | |
181 | // GUI: Show messages |
182 | |
183 | void showMsgs(L<Msg> l) { |
184 | setField(shownMsgs := l); |
185 | setMsgs(l); |
186 | if (l(shownMsgs) == 1) { |
187 | Msg msg = first(shownMsgs); |
188 | setField(analysisText := joinWithEmptyLines( |
189 | "Trained Labels: " + or2(renderBoolMap(getMsgLabels(msg)), "-"), |
190 | "Features:\n" + formatColonProperties_quoteStringValues( |
191 | msg2features.get(msg)) |
192 | )); |
193 | setSCPComponent(scpPredictions, |
194 | scrollableStackWithSpacing(map(predictionsForMsg(msg), p -> |
195 | withSideMargin(jLabelWithButtons(iround(p.adjustedConfidence) + "%: " + p.predictedLabel(), |
196 | "Right", rThread { acceptPrediction(p) }, |
197 | "Wrong", rThread { rejectPrediction(p) }))))); |
198 | } else setField(analysisText := ""); |
199 | } |
200 | |
201 | void updatePredictions() { |
202 | showMsgs(shownMsgs); |
203 | } |
204 | |
// Shows a randomly picked message from msgs.
void showRandomMsg {
  L<Msg> pick = randomElementAsList(msgs);
  showMsgs(pick);
}
208 | |
// Shows the message before the first shown one, treating msgs as cyclic.
void showPrevMsg {
  Msg current = first(shownMsgs);
  showMsgs(llNonNulls(prevInCyclicList(msgs, current)));
}
212 | |
// Shows the message after the first shown one, treating msgs as cyclic.
void showNextMsg {
  Msg current = first(shownMsgs);
  showMsgs(llNonNulls(nextInCyclicList(msgs, current)));
}
216 | |
217 | // CALCULATE PREDICTIONS FOR MESSAGE |
218 | |
219 | L<Prediction> predictionsForMsg(Msg msg) { |
220 | // positive labels first, then "not"s. sort by score in each group |
221 | new L<Prediction> out; |
222 | for (Label label : values(labelsByName)) { |
223 | Theory t = label.bestTheory(), continue if null; |
224 | Bool lhs = evalTheoryLHS(t, msg), continue if null; |
225 | bool prediction = t.statement.rhs instanceof DoesntHaveLabel ? !lhs : lhs; |
226 | double conf = threeB1BScore(t.examples), adjusted = adjustConfidence(conf); |
227 | //if (adjusted < minAdjustedScoreToDisplay) continue; |
228 | out.add(new Prediction(label.name, prediction, adjusted)); |
229 | } |
230 | ret sortedByCalculatedFieldDesc(out, p -> /*pair(p.plus,*/ p.adjustedConfidence/*)*/); |
231 | } |
232 | |
233 | // go from range 50-100 to 0-100 (looks better/more intuitive) |
234 | double adjustConfidence(double x) { |
235 | ret max(0, (x-50)*2); |
236 | } |
237 | |
238 | // rough reverse function of adjustConfidence |
239 | double unadjustConfidence(double x) { |
240 | ret x/2+50; |
241 | } |
242 | |
243 | // GUI: Enter labels |
244 | |
245 | void acceptPrediction(Prediction p) { |
246 | if (p != null) sendInput2(p.predictedLabel()); |
247 | } |
248 | |
249 | void rejectPrediction(Prediction p) { |
250 | if (p != null) sendInput2(cloneWithFlippedBoolField plus(p).predictedLabel()); |
251 | } |
252 | |
253 | @Override |
254 | void sendInput2(S s) { |
255 | // treat input as a label |
256 | if (l(shownMsgs) == 1) { |
257 | Msg shown = first(shownMsgs); |
258 | new Matches m; |
259 | if "not ..." { |
260 | S label = cleanLabel(m.rest()); |
261 | doubleKeyedMapPutVerbose(+msg2label_new, shown, label, false); |
262 | msg2labelUpdated(label); |
263 | if (autoNext) showRandomMsg(); |
264 | } else { |
265 | S label = cleanLabel(s); |
266 | doubleKeyedMapPutVerbose(+msg2label_new, shown, label, true); |
267 | msg2labelUpdated(label); |
268 | if (autoNext) showRandomMsg(); |
269 | } |
270 | change(); |
271 | } |
272 | } |
273 | |
274 | // MESSAGE LABEL HANDLING |
275 | |
276 | Map<S, Bool> getMsgLabels(Msg msg) { |
277 | ret msg2label_new.getA(msg); |
278 | } |
279 | |
280 | Set<Msg> getMsgsRelatedToLabel(S label) { ret msg2label_new.asForB(label); } |
281 | |
282 | void msg2labelUpdated(S label) { |
283 | for (Theory t : cloneList(labelByName(label).bestTheories)) |
284 | checkTheory(t); |
285 | msg2labelUpdated(); |
286 | } |
287 | |
288 | void msg2labelUpdated() { |
289 | callFAllOnAll(onNewLabel, addAll_returnNew(allLabels, msg2label_new.bKeys())); |
290 | updateTrainedExamplesTable(); |
291 | } |
292 | |
293 | // QUERY: get all labels + best theory each |
294 | |
295 | Map<S, Theory> labelsToBestTheoryMap() { |
296 | Map<S, L<Theory>> map = multiMapToMap(multiMapIndex targetLabelOfTheory(theories)); |
297 | ret mapValues(map, theories -> highestBy theoryScore(theories)); |
298 | } |
299 | |
300 | // GUI: Main layout |
301 | |
302 | visual |
303 | withCenteredButtons(super, |
304 | "<", rInThinkQ(r showPrevMsg), |
305 | "Show random msg", rInThinkQ(r showRandomMsg), |
306 | ">", rInThinkQ(r showNextMsg), |
307 | jPopDownButton_noText(flattenObjectArray( |
308 | "Check theories", rInThinkQ(r checkAllTheories), |
309 | "Forget bad theories", rInThinkQ(r { forgetBadTheories(0) }), |
310 | "Forget all theories", rInThinkQ(r clearTheories), |
311 | "Update predictions", rInThinkQ(r updatePredictions), |
312 | dm_importAndExportAllDataMenuItems(), |
313 | "Upgrade to v4", rThreadEnter upgradeMe))); |
314 | |
315 | JComponent mainPart() { |
316 | ret jhsplit(jvsplit( |
317 | jCenteredSection("Focused Message", super.mainPart()), |
318 | jhsplit( |
319 | jCenteredSection("Message Analysis", dm_textArea analysisText()), |
320 | jCenteredSection("Predictions", scpPredictions = singleComponentPanel()) |
321 | )), |
322 | with(r updateTabs, tabs = jtabs( |
323 | "", with(r updateObjectsTable, withRightAlignedButtons( |
324 | objectsTable = sexyTable(), |
325 | "Import messages...", rThreadEnter importMsgs)), |
326 | "", with(r updateLabelsTable, labelsTable = sexyTable()), |
327 | "", with(r updateTheoryTable, tableWithSearcher2_returnPanel(theoryTable = sexyTable())), |
328 | "", with(r updateTrainedExamplesTable, tableWithSearcher2_returnPanel(trainedExamplesTable = sexyTable())) |
329 | ))); |
330 | } |
331 | |
332 | // GUI: Update tables & tabs |
333 | |
334 | void updateTrainedExamplesTable { |
335 | dataToTable_uneditable(trainedExamplesTable, map(msg2label_new.map1, (msg, map) -> |
336 | litorderedmap( |
337 | "Message" := (msg.fromUser ? "User" : "Bot") + ": " + msg.text, |
338 | "Labels" := renderBoolMap(map)))); |
339 | } |
340 | |
341 | void updateTabs { |
342 | setTabTitles(tabs, |
343 | firstLetterToUpper(nMessages(msgs)), |
344 | firstLetterToUpper(nLabels(labelsByName)), |
345 | firstLetterToUpper(nTheories(theories)), |
346 | n2(msg2label_new.aKeys(), "Trained Example")); |
347 | } |
348 | |
// Fills the theory table, highest-scoring theories first.
void updateTheoryTable {
  L<Theory> ranked = sortedByCalculatedFieldDesc theoryScore(theories);
  dataToTable_uneditable(theoryTable, map(ranked, theory -> litorderedmap(
    "Score" := renderTheoryScore(theory),
    "Theory" := str(theory))));
}
355 | |
// Fills the objects table with the text of every message.
void updateObjectsTable enter {
  dataToTable_uneditable_ifHasTable(objectsTable,
    map(msgs, m -> litorderedmap("Text" := m.text)));
}
361 | |
362 | void updateLabelsTable enter { |
363 | L<Label> sorted = sortedByCalculatedFieldDesc(values(labelsByName), l -> l.score()); |
364 | dataToTable_uneditable_ifHasTable(labelsTable, map(sorted, label -> { |
365 | Cl<Theory> bestTheories = label.bestTheories.tiedForFirst(); |
366 | ret litorderedmap( |
367 | "Label" := label.name, |
368 | "Prediction Confidence" := renderTheoryScore(first(bestTheories)), |
369 | "Best Theory" := empty(bestTheories) ? "" : |
370 | (l(bestTheories) > 1 ? "[+" + (l(bestTheories)-1) + "] " : "") + first(bestTheories)); |
371 | })); |
372 | } |
373 | |
// Refreshes everything that depends on the set of theories or their scores,
// then marks the module as changed.
void theoriesChanged {
  // tables and tab titles
  updateTheoryTable();
  updateLabelsTable();
  updateTabs();

  // predictions for the shown message
  updatePredictions();

  change();
}
381 | |
382 | // THEORY SCORING |
383 | |
384 | S renderTheoryScore(Theory t) { |
385 | //ret renderPosNegCounts(t.examples); |
386 | ret t == null || t.examples.isEmpty() ? "" : iround(theoryScore(t)) + "%" |
387 | + " / " + renderPosNegScore2(t.examples); |
388 | } |
389 | |
390 | // adjusted + 3b1b |
391 | double theoryScore(Theory t) { |
392 | ret t == null ? -100 : adjustConfidence(threeB1BScore(t.examples)); |
393 | } |
394 | |
395 | // QUEUE HELPER |
396 | |
397 | Runnable rInThinkQ(Runnable r) { ret rInQ(thinkQ, r); } |
398 | |
399 | // ADD + REMOVE + CLEAN UP THEORIES |
400 | |
401 | void addTheory(Theory theory) { |
402 | if (theories.add(theory)) { |
403 | addTheoryToCollectors(theory); |
404 | theoriesChanged(); |
405 | } |
406 | } |
407 | |
// Forgets all theories.
// Each theory is also taken out of its label's bestTheories; otherwise the
// labels table and the predictions would keep using theories that were removed
// from the theories set.
void clearTheories {
  for (Theory t : cloneList(theories))
    removeTheoryFromCollectors(t);
  theories.clear();
  theoriesChanged();
}
409 | |
410 | // theories with exaclty minScore will go too |
411 | void forgetBadTheories(double minScore) { |
412 | if (removeElementsThat(theories, t -> theoryScore(t) <= minScore)) |
413 | theoriesChanged(); |
414 | } |
415 | |
416 | // CHECK PROPOSITIONS + THEORIES |
417 | |
418 | Bool checkMsgProp(O prop, Msg msg) { |
419 | if (prop cast And) ret checkMsgProp(prop.a, msg) && checkMsgProp(prop.b, msg); |
420 | if (prop cast Not) ret not(checkMsgProp(prop.a, msg)); |
421 | ret ((MsgProp) prop).check(msg); |
422 | } |
423 | |
424 | Bool evalTheoryLHS(Theory theory, Msg msg) { |
425 | ret theory == null ? null |
426 | : checkMsgProp(theory.statement.lhs, msg); |
427 | } |
428 | |
429 | Bool testTheoryOnMsg(Theory theory, Msg msg) { |
430 | Bool lhs = evalTheoryLHS(theory, msg); |
431 | Bool rhs = checkMsgProp(theory.statement.rhs, msg); |
432 | if (lhs == null || rhs == null) null; |
433 | if (bidiMode) |
434 | ret eq(lhs, rhs); |
435 | else |
436 | ret isTrue(rhs) || isFalse(lhs); |
437 | } |
438 | |
// Re-checks every theory, then triggers a single refresh.
void checkAllTheories {
  for (Theory t : theories) {
    checkTheory_noTrigger(t);
  }
  theoriesChanged();
}
444 | |
445 | void checkTheory(Theory theory) { |
446 | checkTheory_noTrigger(theory); |
447 | theoriesChanged(); |
448 | } |
449 | |
450 | void checkTheory_noTrigger(Theory theory) { |
451 | new PosNeg<Msg> pn; |
452 | for (Msg msg : msgs) |
453 | pn.add(msg, testTheoryOnMsg(theory, msg)); |
454 | if (!eq(theory.examples, pn)) { |
455 | removeTheoryFromCollectors(theory); |
456 | theory.examples = pn; |
457 | addTheoryToCollectors(theory); |
458 | change(); |
459 | } |
460 | } |
461 | |
462 | S targetLabelOfTheory(Theory theory) { |
463 | O o = theory.statement.rhs; |
464 | if (o cast HasLabel) ret o.label; |
465 | if (o cast DoesntHaveLabel) ret o.label; |
466 | null; |
467 | } |
468 | |
469 | // CANONICALIZE LABELS |
470 | |
471 | S cleanLabel(S label) { ret upper(label); } |
472 | |
473 | // THEORY + LABEL UPDATES |
474 | |
475 | void addTheoryToCollectors(Theory theory) { |
476 | S lbl = targetLabelOfTheory(theory); |
477 | if (lbl != null) |
478 | labelByName(lbl).bestTheories.add(theory); |
479 | } |
480 | |
481 | void removeTheoryFromCollectors(Theory theory) { |
482 | S lbl = targetLabelOfTheory(theory); |
483 | if (lbl != null) |
484 | labelByName(lbl).bestTheories.remove(theory); |
485 | } |
486 | |
487 | Label labelByName(S name) { |
488 | ret getOrCreate(labelsByName, name, () -> new Label(name)); |
489 | } |
490 | |
491 | void updateLabelsByName() { |
492 | for (S lbl : allLabels) |
493 | labelByName(lbl); |
494 | for (Theory t : theories) |
495 | addTheoryToCollectors(t); |
496 | } |
497 | |
498 | // MAKE FEATURE EXTRACTORS |
499 | |
500 | void makeTextExtractors(S textFeature) { |
501 | for (WithName<IF1<S, O>> f : textExtractors()) { |
502 | IF1<S, O> theFunction = f!; |
503 | featureExtractors.put(f.name, env -> theFunction.get((S) env.getFeature(textFeature))); |
504 | } |
505 | } |
506 | |
507 | L<WithName<IF1<S, O>>> textExtractors() { |
508 | new L<WithName<IF1<S, O>>> l; |
509 | l.add(WithName<>("number of words", lambda1 numberOfWords)); |
510 | l.add(WithName<>("number of characters", lambda1 l)); |
511 | for (char c : characters("\"', .-_")) |
512 | l.add(WithName<>("contains " + quote(c), s -> contains(s, c))); |
513 | /*for (S word : concatAsCISet(lambdaMap words(collect text(msgs)))) |
514 | l.add(WithName<>("contains word " + quote(word), s -> containsWord(s, word)));*/ |
515 | ret l; |
516 | } |
517 | |
518 | // GUI: Import messages dialog |
519 | |
// Opens a dialog for importing messages, one per line. Lines whose text is
// already present in msgs are skipped. Afterwards the objects table and the
// tab titles are refreshed (the message count in the tab title would
// otherwise stay stale until the next theory change) and a random message is shown.
void importMsgs {
  inputMultiLineText("Messages to import (one per line)", voidfunc(S text) {
    Cl<S> toImport = listMinusSet(asOrderedSet(tlft(text)), collectAsSet text(msgs));
    if (msgs == null) msgs = ll();
    for (S line : toImport)
      msgs.add(new Msg(true, line));
    change();
    infoBox(nMessages(toImport) + " imported");
    updateObjectsTable();
    updateTabs();
    showRandomMsg();
  });
}
532 | |
// Exports this module's structure to the default file, then switches the
// module over to the library ID of the next version.
void upgradeMe {
  dm_exportStructureToDefaultFile(this);

  dm_changeModuleLibID(this, "#1028063/AutoClassifier");
}
537 | } |
Began life as a copy of #1028055
download show line numbers debug dex old transpilations
Travelled to 7 computer(s): bhatertpkbcr, mqqgnosmbjvj, pyentgdyhuwx, pzhvpgtvlbxg, tvejysmllsmz, vouqrxazstgt, xrpafgyirdlv
No comments. add comment
Snippet ID: | #1028058 |
Snippet name: | Auto Classifier v3 [learning message classifier, dev.] |
Eternal ID of this version: | #1028058/14 |
Text MD5: | 82abc889936952c3f24d06a57b99218c |
Transpilation MD5: | 1c8f6f456387cf063caada2da795c7a0 |
Author: | stefan |
Category: | javax / a.i. |
Type: | JavaX source code (Dynamic Module) |
Public (visible to everyone): | Yes |
Archived (hidden from active list): | No |
Created/modified: | 2020-05-15 14:21:33 |
Source code size: | 16718 bytes / 537 lines |
Pitched / IR pitched: | No / No |
Views / Downloads: | 199 / 696 |
Version history: | 13 change(s) |
Referenced in: | [show references] |