Uses 911K of libraries. Click here for Pure Java version (19059L/103K).
1 | !7 |
2 | |
3 | cmodule AutoClassifier > DynConvo { |
4 | srecord Theory(BasicLogicRule statement) { |
5 | new PosNeg<Msg> examples; |
6 | //bool iff; // <=> instead of only => |
7 | toString { ret str(statement.lhs instanceof MPTrue ? "Every message is " + statement.rhs |
8 | : bidiMode ? statement.lhs + " <=> " + statement.rhs : statement); } |
9 | } |
10 | |
// Base class of propositions about a message.
asclass MsgProp {
  // Evaluates the proposition for msg; null means "unknown".
  abstract Bool check(Msg msg);
}
13 | |
14 | srecord MPTrue() > MsgProp { |
15 | Bool check(Msg msg) { true; } |
16 | toString { ret "always"; } |
17 | } |
18 | |
19 | record HasLabel(S label) > MsgProp { |
20 | Bool check(Msg msg) { ret msg2label_new.get(msg, label); } |
21 | toString { ret label; } |
22 | } |
23 | |
24 | record DoesntHaveLabel(S label) > MsgProp { |
25 | Bool check(Msg msg) { ret not(msg2label_new.get(msg, label)); } |
26 | toString { ret "not " + label; } |
27 | } |
28 | |
29 | record FeatureValueIs(S feature, O value) > MsgProp { |
30 | Bool check(Msg msg) { ret eq(getMsgFeature(msg, feature), value); } |
31 | toString { ret feature + "=" + value; } |
32 | } |
33 | |
// A label name together with the theories that target it, ordered by theoryScore
// (reverse comparator).
class Label {
  S name;
  TreeSetWithDuplicates<Theory> bestTheories = new(reverseComparatorFromCalculatedField theoryScore());

  *() {}
  *(S *name) {}

  // First theory in bestTheories.
  Theory bestTheory() {
    ret first(bestTheories);
  }

  // theoryScore of the first theory in bestTheories.
  int score() {
    ret theoryScore(bestTheory());
  }
}
45 | |
46 | switchable double minAdjustedScoreToDisplay = 50; |
47 | switchable bool autoNext = false; |
48 | static bool bidiMode = true; // treat all theories as bidirectional |
49 | |
50 | L<Msg> msgs; // full dialog |
51 | L<Msg> shownMsgs; |
52 | transient Map<Msg, Map<S, O>> msg2features = AutoMap<>(lambda1 calcMsgFeatures); |
53 | new LinkedHashSet<Theory> theories; |
54 | S analysisText; |
55 | transient JTable theoryTable, labelsTable, trainedExamplesTable, objectsTable; |
56 | transient JTabbedPane tabs; |
57 | transient SingleComponentPanel scpPredictions; |
58 | transient new Map<S, Label> labelsByName; |
59 | |
60 | new Set<S> allLabels; |
61 | transient new L<IVF1<S>> onNewLabel; |
62 | new DoubleKeyedMap<Msg, S, Bool> msg2label_new; |
63 | transient new Map<S, FeatureExtractor<Msg>> featureExtractors; |
64 | transient Q thinkQ; |
65 | |
66 | sinterface FeatureEnv<A> { |
67 | A mainObject(); |
68 | O getFeature(S name); |
69 | } |
70 | |
// Computes one feature value from a FeatureEnv.
sinterface FeatureExtractor<A> {
  O get(FeatureEnv<A> env);
}
74 | |
// Module start: creates the "Thought Queue" and schedules the one-time
// initialisation below on it.
75 | start { |
76 | thinkQ = dm_startQ("Thought Queue"); |
77 | thinkQ.add(r { |
78 | // legacy + after deletion cleaning |
// rebuild allLabels from the label keys actually present in the training data
79 | setField(allLabels := asTreeSet(msg2label_new.bKeys())); |
80 | updateLabelsByName(); |
81 | |
// first listener: every label notification triggers change()
82 | onNewLabel.add(lbl -> change()); |
83 | |
// these two only register further onNewLabel listeners (theory generators)
84 | makeTheoriesAboutLabels(); |
85 | makeTheoriesAboutFeaturesAndLabels(); |
86 | |
// one extractor per Msg field, reading that field from the message via getOpt
87 | for (S field : fields(Msg)) |
88 | featureExtractors.put(field, env -> getOpt(env.mainObject(), field)); |
89 | |
// derived extractors computed from the "text" feature
90 | makeTextExtractors("text"); |
91 | |
// call every onNewLabel listener with every label in allLabels
// (done after the extractors are registered, since the feature theory generator iterates them)
92 | callFAllOnAll(onNewLabel, allLabels); |
93 | |
94 | msg2labelUpdated(); |
95 | //showRandomMsg(); |
96 | }); |
97 | } |
98 | |
// Registers an onNewLabel listener that, for label X, proposes the two
// unconditional theories "every message has X" and "every message doesn't have X".
void makeTheoriesAboutLabels {
  onNewLabel.add(lbl -> {
    for (O rhs : ll(new HasLabel(lbl), new DoesntHaveLabel(lbl)))
      addTheory(new Theory(BasicLogicRule(new MPTrue, rhs)));
  });
}
108 | |
// Registers an onNewLabel listener that, for label X, every registered feature F
// and every candidate value V of F, proposes "F=V => has X" and "F=V => doesn't have X".
void makeTheoriesAboutFeaturesAndLabels {
  onNewLabel.add(lbl -> {
    for (S feature : keys(featureExtractors)) {
      Set candidateValues = possibleValuesOfFeatureRelatedToLabel(feature, lbl);
      for (O value : candidateValues) {
        addTheory(new Theory(BasicLogicRule(
          new FeatureValueIs(feature, value), new HasLabel(lbl))));
        addTheory(new Theory(BasicLogicRule(
          new FeatureValueIs(feature, value), new DoesntHaveLabel(lbl))));
      }
    }
  });
}
122 | |
123 | Set possibleValuesOfFeature(S feature) { |
124 | if (isBoolField(Msg, feature)) |
125 | ret litset(false, true); |
126 | ret litset(); |
127 | } |
128 | |
129 | Set possibleValuesOfFeatureRelatedToLabel(S feature, S label) { |
130 | Set set = possibleValuesOfFeature(feature); |
131 | fOr (Msg msg : getMsgsRelatedToLabel(label)) |
132 | set.add(getMsgFeature(msg, feature)); |
133 | ret set; |
134 | } |
135 | |
// Builds the feature map for one message.
136 | // returns AutoMap with no realized entries |
137 | Map<S, O> calcMsgFeatures(Msg msg) { |
// env is held in a Var because the map's lambda and the env refer to each other:
// the Var is declared first and filled in after the map exists.
138 | new Var<FeatureEnv<Msg>> env; |
// each feature is looked up in featureExtractors and run against env
// NOTE(review): a feature name without a registered extractor would hit a null here - confirm callers only ask for registered names
139 | AutoMap<S, O> map = new(feature -> featureExtractors.get(feature).get(env!)); |
140 | env.set(new FeatureEnv<Msg> { |
141 | Msg mainObject() { ret msg; } |
// extractors can read other features of the same message through the same map
142 | O getFeature(S feature) { ret map.get(feature); } |
143 | }); |
144 | ret map; |
145 | } |
146 | |
147 | void showMsgs(L<Msg> l) { |
148 | setField(shownMsgs := l); |
149 | setMsgs(l); |
150 | if (l(shownMsgs) == 1) { |
151 | Msg msg = first(shownMsgs); |
152 | setField(analysisText := joinWithEmptyLines( |
153 | "Trained Labels: " + or2(renderBoolMap(getMsgLabels(msg)), "-"), |
154 | "Features:\n" + formatColonProperties_quoteStringValues( |
155 | msg2features.get(msg)) |
156 | )); |
157 | setSCPComponent(scpPredictions, |
158 | scrollableStackWithSpacing(map(predictionsForMsg(msg), p -> |
159 | withSideMargin(jLabelWithButtons(iround(p.adjustedConfidence) + "%: " + p.predictedLabel(), |
160 | "Right", rThread { acceptPrediction(p) }, |
161 | "Wrong", rThread { rejectPrediction(p) }))))); |
162 | } else setField(analysisText := ""); |
163 | } |
164 | |
165 | void updatePredictions() { |
166 | showMsgs(shownMsgs); |
167 | } |
168 | |
169 | srecord Prediction(S label, bool plus, double adjustedConfidence) { |
170 | toString { |
171 | ret predictedLabel() + " (confidence: " + iround(adjustedConfidence) + "%)"; |
172 | } |
173 | |
174 | S predictedLabel() { |
175 | ret (plus ? "" : "not ") + label; |
176 | } |
177 | } |
178 | |
179 | L<Prediction> predictionsForMsg(Msg msg) { |
180 | // positive labels first, then "not"s. sort by score in each group |
181 | new L<Prediction> out; |
182 | for (Label label : values(labelsByName)) { |
183 | Theory t = label.bestTheory(), continue if null; |
184 | Bool lhs = evalTheoryLHS(t, msg), continue if null; |
185 | bool prediction = t.statement.rhs instanceof DoesntHaveLabel ? !lhs : lhs; |
186 | double conf = threeB1BScore(t.examples), adjusted = adjustConfidence(conf); |
187 | //if (adjusted < minAdjustedScoreToDisplay) continue; |
188 | out.add(new Prediction(label.name, prediction, adjusted)); |
189 | } |
190 | ret sortedByCalculatedFieldDesc(out, p -> /*pair(p.plus,*/ p.adjustedConfidence/*)*/); |
191 | } |
192 | |
193 | // go from range 50-100 to 0-100 (might look better) |
194 | double adjustConfidence(double x) { |
195 | ret max(0, (x-50)*2); |
196 | } |
197 | |
// Shows one randomly chosen message from msgs.
void showRandomMsg {
  L<Msg> pick = randomElementAsList(msgs);
  showMsgs(pick);
}
201 | |
// Shows the message before the first shown one, treating msgs as cyclic.
void showPrevMsg {
  Msg current = first(shownMsgs);
  showMsgs(llNonNulls(prevInCyclicList(msgs, current)));
}
205 | |
// Shows the message after the first shown one, treating msgs as cyclic.
void showNextMsg {
  Msg current = first(shownMsgs);
  showMsgs(llNonNulls(nextInCyclicList(msgs, current)));
}
209 | |
210 | void acceptPrediction(Prediction p) { |
211 | if (p != null) sendInput2(p.predictedLabel()); |
212 | } |
213 | |
214 | void rejectPrediction(Prediction p) { |
215 | if (p != null) sendInput2(cloneWithFlippedBoolField plus(p).predictedLabel()); |
216 | } |
217 | |
218 | @Override |
219 | void sendInput2(S s) { |
220 | // treat input as a label |
221 | if (l(shownMsgs) == 1) { |
222 | Msg shown = first(shownMsgs); |
223 | new Matches m; |
224 | if "not ..." { |
225 | S label = cleanLabel(m.rest()); |
226 | doubleKeyedMapPutVerbose(+msg2label_new, shown, label, false); |
227 | msg2labelUpdated(label); |
228 | if (autoNext) showRandomMsg(); |
229 | } else { |
230 | S label = cleanLabel(s); |
231 | doubleKeyedMapPutVerbose(+msg2label_new, shown, label, true); |
232 | msg2labelUpdated(label); |
233 | if (autoNext) showRandomMsg(); |
234 | } |
235 | change(); |
236 | } |
237 | } |
238 | |
239 | Map<S, Bool> getMsgLabels(Msg msg) { |
240 | ret msg2label_new.getA(msg); |
241 | } |
242 | Set<Msg> getMsgsRelatedToLabel(S label) { ret msg2label_new.asForB(label); } |
243 | |
244 | void msg2labelUpdated(S label) { |
245 | for (Theory t : cloneList(labelByName(label).bestTheories)) |
246 | checkTheory(t); |
247 | msg2labelUpdated(); |
248 | } |
249 | |
250 | void msg2labelUpdated() { |
251 | callFAllOnAll(onNewLabel, addAll_returnNew(allLabels, msg2label_new.bKeys())); |
252 | updateTrainedExamplesTable(); |
253 | } |
254 | |
255 | void updateTrainedExamplesTable { |
256 | dataToTable_uneditable(trainedExamplesTable, map(msg2label_new.map1, (msg, map) -> |
257 | litorderedmap( |
258 | "Message" := (msg.fromUser ? "User" : "Bot") + ": " + msg.text, |
259 | "Labels" := renderBoolMap(map)))); |
260 | } |
261 | |
// Builds the main UI: a horizontal split whose first half is a vertical split
// ("Focused Message" over "Message Analysis" | "Predictions") and whose second
// half is the tab pane. Also assigns the component fields
// (scpPredictions, tabs and the four tables) as a side effect.
262 | JComponent mainPart() { |
263 | ret jhsplit(jvsplit( |
264 | jCenteredSection("Focused Message", super.mainPart()), |
265 | jhsplit( |
266 | jCenteredSection("Message Analysis", dm_textArea analysisText()), |
267 | jCenteredSection("Predictions", scpPredictions = singleComponentPanel()) |
268 | )), |
// tab titles are empty here; updateTabs sets them via setTabTitles
// NOTE(review): with(r x, component) presumably runs x and returns the component - confirm
269 | with(r updateTabs, tabs = jtabs( |
270 | "", with(r updateObjectsTable, withRightAlignedButtons( |
271 | objectsTable = sexyTable(), |
272 | "Import messages...", rThreadEnter importMsgs)), |
273 | "", with(r updateLabelsTable, labelsTable = sexyTable()), |
274 | "", with(r updateTheoryTable, tableWithSearcher2_returnPanel(theoryTable = sexyTable())), |
275 | "", with(r updateTrainedExamplesTable, tableWithSearcher2_returnPanel(trainedExamplesTable = sexyTable())) |
276 | ))); |
277 | } |
278 | |
279 | void updateTabs { |
280 | setTabTitles(tabs, |
281 | firstLetterToUpper(nMessages(msgs)), |
282 | firstLetterToUpper(nLabels(labelsByName)), |
283 | firstLetterToUpper(nTheories(theories)), |
284 | n2(msg2label_new.aKeys(), "Trained Example")); |
285 | } |
286 | |
287 | void updateTheoryTable { |
288 | L<Theory> sorted = sortedByCalculatedFieldDesc(theories, t -> |
289 | t.examples == null ? null : t.examples.score()); |
290 | dataToTable_uneditable(theoryTable, map(sorted, t -> litorderedmap( |
291 | "Score" := renderTheoryScore(t), |
292 | "Theory" := str(t)))); |
293 | } |
294 | |
295 | Map<S, Theory> labelsToBestTheoryMap() { |
296 | Map<S, L<Theory>> map = multiMapToMap(multiMapIndex targetLabelOfTheory(theories)); |
297 | ret mapValues(map, |
298 | theories -> highestBy theoryScore(theories)); |
299 | } |
300 | |
// Fills the messages table (if it exists) with one "Text" row per message.
void updateObjectsTable enter {
  dataToTable_uneditable_ifHasTable(objectsTable,
    map(msgs, msg -> litorderedmap("Text" := msg.text)));
}
306 | |
307 | void updateLabelsTable enter { |
308 | L<Label> sorted = sortedByCalculatedFieldDesc(values(labelsByName), l -> l.score()); |
309 | dataToTable_uneditable_ifHasTable(labelsTable, map(sorted, label -> { |
310 | Cl<Theory> bestTheories = label.bestTheories.tiedForFirst(); |
311 | ret litorderedmap( |
312 | "Label" := label.name, |
313 | "Prediction Confidence" := renderTheoryScore(first(bestTheories)), |
314 | "Best Theory" := empty(bestTheories) ? "" : |
315 | (l(bestTheories) > 1 ? "[+" + (l(bestTheories)-1) + "] " : "") + first(bestTheories)); |
316 | })); |
317 | } |
318 | |
319 | S renderTheoryScore(Theory t) { |
320 | //ret renderPosNegCounts(t.examples); |
321 | ret t == null || t.examples.isEmpty() ? "" : iround(adjustConfidence(threeB1BScore(t.examples))) + "%" |
322 | + " / " + renderPosNegScoreAndCount(t.examples); |
323 | } |
324 | |
325 | int theoryScore(Theory t) { |
326 | ret t == null ? -100 : t.examples.score(); |
327 | } |
328 | |
// Refreshes everything that displays theories (theory table, labels table,
// tab titles, predictions) and then calls change().
void theoriesChanged {
  // tables
  updateTheoryTable();
  updateLabelsTable();

  // counts + current message
  updateTabs();
  updatePredictions();

  change();
}
336 | |
// Module UI: the superclass visual plus a centered button row.
// Navigation and theory actions are wrapped with rInThinkQ (run on thinkQ).
337 | visual |
338 | withCenteredButtons(super, |
339 | "<", rInThinkQ(r showPrevMsg), |
340 | "Show random msg", rInThinkQ(r showRandomMsg), |
341 | ">", rInThinkQ(r showNextMsg), |
// pop-down menu with maintenance actions
342 | jPopDownButton_noText(flattenObjectArray( |
343 | "Check theories", rInThinkQ(r checkAllTheories), |
344 | "Clear theories", rInThinkQ(r clearTheories), |
345 | "Update predictions", rInThinkQ(r updatePredictions), |
346 | dm_importAndExportAllDataMenuItems(), |
// unlike the entries above, this one is not routed through thinkQ
347 | "Upgrade to v3", rThreadEnter upgradeMe))); |
348 | |
349 | Runnable rInThinkQ(Runnable r) { ret rInQ(thinkQ, r); } |
350 | |
351 | void addTheory(Theory theory) { |
352 | if (theories.add(theory)) { |
353 | print("New theory: " + theory); |
354 | addTheoryToCollectors(theory); |
355 | theoriesChanged(); |
356 | } |
357 | } |
358 | |
// Removes all theories, including their entries in the per-label
// bestTheories collections (otherwise labels would keep predicting from
// theories that no longer exist), then refreshes the UI.
void clearTheories {
  for (Theory t : theories)
    removeTheoryFromCollectors(t);
  theories.clear();
  theoriesChanged();
}
360 | |
361 | Bool checkMsgProp(O prop, Msg msg) { |
362 | if (prop cast And) ret checkMsgProp(prop.a, msg) && checkMsgProp(prop.b, msg); |
363 | if (prop cast Not) ret not(checkMsgProp(prop.a, msg)); |
364 | ret ((MsgProp) prop).check(msg); |
365 | } |
366 | |
367 | Bool evalTheoryLHS(Theory theory, Msg msg) { |
368 | ret theory == null ? null |
369 | : checkMsgProp(theory.statement.lhs, msg); |
370 | } |
371 | |
372 | Bool testTheoryOnMsg(Theory theory, Msg msg) { |
373 | Bool lhs = evalTheoryLHS(theory, msg); |
374 | Bool rhs = checkMsgProp(theory.statement.rhs, msg); |
375 | if (lhs == null || rhs == null) null; |
376 | if (bidiMode) |
377 | ret eq(lhs, rhs); |
378 | else |
379 | ret isTrue(rhs) || isFalse(lhs); |
380 | } |
381 | |
// Re-evaluates every theory, then refreshes the UI once.
void checkAllTheories {
  for (Theory t : theories) {
    checkTheory_noTrigger(t);
  }
  theoriesChanged();
}
387 | |
388 | void checkTheory(Theory theory) { |
389 | checkTheory_noTrigger(theory); |
390 | theoriesChanged(); |
391 | } |
392 | |
393 | void checkTheory_noTrigger(Theory theory) { |
394 | new PosNeg<Msg> pn; |
395 | for (Msg msg : msgs) |
396 | pn.add(msg, testTheoryOnMsg(theory, msg)); |
397 | if (!eq(theory.examples, pn)) { |
398 | removeTheoryFromCollectors(theory); |
399 | theory.examples = pn; |
400 | addTheoryToCollectors(theory); |
401 | change(); |
402 | } |
403 | } |
404 | |
405 | S cleanLabel(S label) { ret upper(label); } |
406 | |
407 | S targetLabelOfTheory(Theory theory) { |
408 | O o = theory.statement.rhs; |
409 | if (o cast HasLabel) ret o.label; |
410 | if (o cast DoesntHaveLabel) ret o.label; |
411 | null; |
412 | } |
413 | |
414 | void addTheoryToCollectors(Theory theory) { |
415 | S lbl = targetLabelOfTheory(theory); |
416 | if (lbl != null) |
417 | labelByName(lbl).bestTheories.add(theory); |
418 | } |
419 | |
420 | void removeTheoryFromCollectors(Theory theory) { |
421 | S lbl = targetLabelOfTheory(theory); |
422 | if (lbl != null) |
423 | labelByName(lbl).bestTheories.remove(theory); |
424 | } |
425 | |
426 | Label labelByName(S name) { |
427 | ret getOrCreate(labelsByName, name, () -> new Label(name)); |
428 | } |
429 | |
430 | void updateLabelsByName() { |
431 | for (S lbl : allLabels) |
432 | labelByName(lbl); |
433 | for (Theory t : theories) |
434 | addTheoryToCollectors(t); |
435 | } |
436 | |
437 | O getMsgFeature(Msg msg, S feature) { |
438 | ret msg2features.get(msg).get(feature); |
439 | } |
440 | |
441 | void makeTextExtractors(S textFeature) { |
442 | for (WithName<IF1<S, O>> f : textExtractors()) { |
443 | IF1<S, O> theFunction = f!; |
444 | featureExtractors.put(f.name, env -> theFunction.get((S) env.getFeature(textFeature))); |
445 | } |
446 | } |
447 | |
448 | L<WithName<IF1<S, O>>> textExtractors() { |
449 | new L<WithName<IF1<S, O>>> l; |
450 | l.add(WithName<>("number of words", lambda1 numberOfWords)); |
451 | l.add(WithName<>("number of characters", lambda1 l)); |
452 | for (char c : characters("\"', .-_")) |
453 | l.add(WithName<>("contains " + quote(c), s -> contains(s, c))); |
454 | /*for (S word : concatAsCISet(lambdaMap words(collect text(msgs)))) |
455 | l.add(WithName<>("contains word " + quote(word), s -> containsWord(s, word)));*/ |
456 | ret l; |
457 | } |
458 | |
// Asks the user for text and imports each line as a new message.
459 | void importMsgs { |
460 | inputMultiLineText("Messages to import (one per line)", voidfunc(S text) { |
// distinct input lines, minus texts that already exist in msgs
// NOTE(review): msgs is read here before the null check below - presumably the collect helper tolerates null; confirm
461 | Cl<S> toImport = listMinusSet(asOrderedSet(tlft(text)), collectAsSet text(msgs)); |
462 | if (msgs == null) msgs = ll(); |
463 | for (S line : toImport) |
// first constructor argument is true for every imported line
// NOTE(review): presumably the fromUser flag read in updateTrainedExamplesTable - confirm against Msg
464 | msgs.add(new Msg(true, line)); |
465 | change(); |
466 | infoBox(nMessages(toImport) + " imported"); |
467 | updateObjectsTable(); |
468 | showRandomMsg(); |
469 | }); |
470 | } |
471 | |
// "Upgrade to v3" action: exports this module's structure to the default
// file, then switches the module's library ID to #1028058/AutoClassifier.
void upgradeMe {
  // export first, then switch
  dm_exportStructureToDefaultFile(this);

  dm_changeModuleLibID(this, "#1028058/AutoClassifier");
}
476 | } |
download show line numbers debug dex old transpilations
Travelled to 8 computer(s): bhatertpkbcr, mqqgnosmbjvj, pyentgdyhuwx, pzhvpgtvlbxg, qsqiayxyrbia, tvejysmllsmz, vouqrxazstgt, xrpafgyirdlv
No comments. add comment
Snippet ID: | #1028055 |
Snippet name: | Auto Classifier v2 [learning message classifier] |
Eternal ID of this version: | #1028055/20 |
Text MD5: | 76e4933e577f6b608884073002df95f2 |
Transpilation MD5: | 0d6903e8aa8d8b104bde8f271eebc20a |
Author: | stefan |
Category: | |
Type: | JavaX source code (Dynamic Module) |
Public (visible to everyone): | Yes |
Archived (hidden from active list): | No |
Created/modified: | 2020-05-15 13:12:27 |
Source code size: | 15445 bytes / 476 lines |
Pitched / IR pitched: | No / No |
Views / Downloads: | 199 / 514 |
Version history: | 19 change(s) |
Referenced in: | [show references] |