Uses 911K of libraries. Click here for Pure Java version (18263L/98K).
!7

cmodule TheoryMaker > DynConvo {
  /*
  Plan:
  1. measurable features (fields of object)
  2. labels (words the user throws in)
  3. make theories (random connectors between features and labels)
  4. check theories

  Interaction loop:
  1. show a random line
  2. user types keyword
  3. assign keyword to line
  4. check if prediction was correct

  Basic theory making
  -------------------

  For any label X:
    test theory (for every M: M has label X)
    test theory (for every M: M doesn't have label X)

  For any feature F:
    for every seen value V of F:
      for every label X:
        test theory (for every M: msg M's feature F has value V => msg has label X)
        test theory (for every M: msg M's feature F has value V => msg doesn't have label X)
  */

  // A candidate rule "lhs => rhs" (or "lhs <=> rhs" in bidiMode) about messages.
  // examples is filled by checkTheory_noTrigger with testTheoryOnMsg's result for every message.
  srecord Theory(BasicLogicRule statement) {
    new PosNeg<Msg> examples;
    //bool iff; // <=> instead of only =>
    toString { ret str(statement.lhs instanceof MPTrue ? "Every message is " + statement.rhs
      : bidiMode ? statement.lhs + " <=> " + statement.rhs : statement); }
  }

  // propositions about a message. check returns null if unknown
  asclass MsgProp { abstract Bool check(Msg msg); }

  // Proposition that holds for every message (LHS of "every message is X" theories).
  srecord MPTrue() > MsgProp {
    Bool check(Msg msg) { true; }
    toString { ret "always"; }
  }

  // Message was trained with this label (value comes from msg2label_new; presumably null if untrained).
  record HasLabel(S label) > MsgProp {
    Bool check(Msg msg) { ret msg2label_new.get(msg, label); }
    toString { ret label; }
  }

  // Negation of HasLabel for the same label.
  record DoesntHaveLabel(S label) > MsgProp {
    Bool check(Msg msg) { ret not(msg2label_new.get(msg, label)); }
    toString { ret "not " + label; }
  }

  // The message's extracted feature equals the given value.
  record FeatureValueIs(S feature, O value) > MsgProp {
    Bool check(Msg msg) { ret eq(getMsgFeature(msg, feature), value); }
    toString { ret feature + "=" + value; }
  }

  // A label together with every theory that targets it, ordered by theoryScore (highest first).
  class Label {
    S name;

    *() {}
    *(S *name) {}

    TreeSetWithDuplicates<Theory> bestTheories = new(reverseComparatorFromCalculatedField theoryScore());

    int score() { ret theoryScore(first(bestTheories)); }
    Theory bestTheory() { ret first(bestTheories); }
  }

  switchable double minAdjustedScoreToDisplay = 50; // hide predictions whose adjusted confidence is lower
  switchable bool autoNext = false; // show another random message after each training input
  static bool bidiMode = true; // treat all theories as bidirectional

  L<Msg> msgs; // full dialog
  L<Msg> shownMsgs;
  transient Map<Msg, Map<S, O>> msg2features = AutoMap<>(lambda1 calcMsgFeatures);
  new LinkedHashSet<Theory> theories;
  S analysisText;
  transient JTable theoryTable, labelsTable, trainedExamplesTable;
  transient JTabbedPane tabs;
  transient SingleComponentPanel scpPredictions;
  transient new Map<S, Label> labelsByName;

  new Set<S> allLabels;
  transient new L<IVF1<S>> onNewLabel; // invoked for each label the first time it is seen
  new DoubleKeyedMap<Msg, S, Bool> msg2label_new; // (message, label) -> trained as has/doesn't have
  transient new Map<S, FeatureExtractor<Msg>> featureExtractors;
  transient Q thinkQ;

  sinterface FeatureEnv<A> {
    A mainObject();
    O getFeature(S name);
  }

  sinterface FeatureExtractor<A> {
    O get(FeatureEnv<A> env);
  }

  start {
    thinkQ = dm_startQ("Thought Queue");
    // initialization runs on the thought queue, like the button actions (see rInThinkQ)
    thinkQ.add(r {
      // legacy + after deletion cleaning
      setField(allLabels := asTreeSet(msg2label_new.bKeys()));
      updateLabelsByName();

      onNewLabel.add(lbl -> change());

      makeTheoriesAboutLabels();
      makeTheoriesAboutFeaturesAndLabels();

      // extractors must exist before callFAllOnAll below,
      // because makeTheoriesAboutFeaturesAndLabels iterates them
      for (S field : fields(Msg))
        featureExtractors.put(field, env -> getOpt(env.mainObject(), field));

      makeTextExtractors("text");

      callFAllOnAll(onNewLabel, allLabels);

      msg2labelUpdated();
      if (empty(msgs))
        setField(msgs := mainCruddieLog());
      showRandomMsg();
    });
  }

  // Registers a listener that creates the two "every message is (not) X" theories per new label.
  void makeTheoriesAboutLabels {
    // For any label X:
    onNewLabel.add(lbl -> {
      // test theory (for every M: M has label X)
      addTheory(new Theory(BasicLogicRule(new MPTrue, new HasLabel(lbl))));
      // test theory (for every M: M doesn't have label X)
      addTheory(new Theory(BasicLogicRule(new MPTrue, new DoesntHaveLabel(lbl))));
    });
  }

  // Registers a listener that creates "feature F = V => (not) X" theories per new label,
  // using the feature values seen on messages already trained with that label.
  void makeTheoriesAboutFeaturesAndLabels {
    // for every label X:
    onNewLabel.add(lbl -> {
      // For any feature F:
      for (S feature : keys(featureExtractors))
        // for every seen value V of F:
        for (O value : possibleValuesOfFeatureRelatedToLabel(feature, lbl))
          for (O rhs : ll(new HasLabel(lbl), new DoesntHaveLabel(lbl)))
            // test theory (for every M: msg M's feature F has value V => msg has/doesn't have label x))
            addTheory(new Theory(BasicLogicRule(
              new FeatureValueIs(feature, value), rhs)));
    });
  }

  // Values worth testing regardless of data: both booleans for boolean Msg fields.
  Set possibleValuesOfFeature(S feature) {
    if (isBoolField(Msg, feature))
      ret litset(false, true);
    ret litset();
  }

  Set possibleValuesOfFeatureRelatedToLabel(S feature, S label) {
    Set set = possibleValuesOfFeature(feature);
    fOr (Msg msg : getMsgsRelatedToLabel(label))
      set.add(getMsgFeature(msg, feature));
    ret set;
  }

  // returns AutoMap with no realized entries
  Map<S, O> calcMsgFeatures(Msg msg) {
    new Var<FeatureEnv<Msg>> env;
    AutoMap<S, O> map = new(feature -> featureExtractors.get(feature).get(env!));
    env.set(new FeatureEnv<Msg> {
      Msg mainObject() { ret msg; }
      O getFeature(S feature) { ret map.get(feature); }
    });
    ret map;
  }

  // Displays the messages; for a single message also shows its analysis and predictions.
  void showMsgs(L<Msg> l) {
    setField(shownMsgs := l);
    setMsgs(l);
    if (l(shownMsgs) == 1) {
      Msg msg = first(shownMsgs);
      setField(analysisText := joinWithEmptyLines(
        "Trained Labels: " + or2(renderBoolMap(getMsgLabels(msg)), "-"),
        "Features:\n" + formatColonProperties_quoteStringValues(
          msg2features.get(msg))
      ));
      setSCPComponent(scpPredictions,
        scrollableStackWithSpacing(map(predictionsForMsg(msg), p ->
          withSideMargin(jLabelWithButtons(iround(p.adjustedConfidence) + "%: " + p.predictedLabel(),
            "Right", rThread { acceptPrediction(p) },
            "Wrong", rThread { rejectPrediction(p) })))));
    } else setField(analysisText := "");
  }

  void updatePredictions() {
    showMsgs(shownMsgs);
  }

  srecord Prediction(S label, bool plus, double adjustedConfidence) {
    toString {
      ret predictedLabel() + " (confidence: " + iround(adjustedConfidence) + "%)";
    }

    S predictedLabel() {
      ret (plus ? "" : "not ") + label;
    }
  }

  // One prediction per label, derived from that label's best theory.
  L<Prediction> predictionsForMsg(Msg msg) {
    // positive labels first, then "not"s. sort by score in each group
    new L<Prediction> out;
    for (Label label : values(labelsByName)) {
      Theory t = label.bestTheory(), continue if null;
      Bool lhs = evalTheoryLHS(t, msg), continue if null;
      // a one-directional theory (lhs => rhs) says nothing when lhs is false
      if (!bidiMode && !lhs) continue;
      bool prediction = t.statement.rhs instanceof DoesntHaveLabel ? !lhs : lhs;
      double conf = threeB1BScore(t.examples), adjusted = adjustConfidence(conf);
      if (adjusted < minAdjustedScoreToDisplay) continue;
      out.add(new Prediction(label.name, prediction, adjusted));
    }
    ret sortedByCalculatedFieldDesc(out, p -> pair(p.plus, p.adjustedConfidence));
  }

  // go from range 50-100 to 0-100 (might look better)
  double adjustConfidence(double x) {
    ret max(0, (x-50)*2);
  }

  void showRandomMsg {
    showMsgs(randomElementAsList(msgs));
  }

  void acceptPrediction(Prediction p) {
    if (p != null) sendInput2(p.predictedLabel());
  }

  void rejectPrediction(Prediction p) {
    if (p != null) sendInput2(cloneWithFlippedBoolField plus(p).predictedLabel());
  }

  // Treats input as a label for the single shown message; "not X" records the message as lacking X.
  @Override
  void sendInput2(S s) {
    if (l(shownMsgs) != 1) return;
    Msg shown = first(shownMsgs);
    new Matches m;
    bool hasLabel;
    S label;
    if "not ..." {
      hasLabel = false;
      label = cleanLabel(m.rest());
    } else {
      hasLabel = true;
      label = cleanLabel(s);
    }
    doubleKeyedMapPutVerbose(+msg2label_new, shown, label, hasLabel);
    msg2labelUpdated(label);
    if (autoNext) showRandomMsg();
    change();
  }

  Map<S, Bool> getMsgLabels(Msg msg) {
    ret msg2label_new.getA(msg);
  }
  Set<Msg> getMsgsRelatedToLabel(S label) { ret msg2label_new.asForB(label); }

  // Called after `label` was trained on a message.
  void msg2labelUpdated(S label) {
    // Register new labels first: for a brand-new label this creates its theories,
    // so they get checked against the example that was just entered.
    msg2labelUpdated();
    L<Theory> toCheck = cloneList(labelByName(label).bestTheories); // copy: checking re-sorts the set
    for (Theory t : toCheck)
      checkTheory_noTrigger(t);
    // refresh tables/predictions once instead of once per theory
    if (!toCheck.isEmpty()) theoriesChanged();
  }

  // Fires onNewLabel for labels not seen before and refreshes the trained-examples table.
  void msg2labelUpdated() {
    callFAllOnAll(onNewLabel, addAll_returnNew(allLabels, msg2label_new.bKeys()));
    updateTrainedExamplesTable();
  }

  void updateTrainedExamplesTable {
    dataToTable_uneditable(trainedExamplesTable, map(msg2label_new.map1, (msg, map) ->
      litorderedmap(
        "Message" := (msg.fromUser ? "User" : "Bot") + ": " + msg.text,
        "Labels" := renderBoolMap(map))));
  }

  JComponent mainPart() {
    ret jhsplit(jvsplit(
      jCenteredSection("Focused Message", super.mainPart()),
      jhsplit(
        jCenteredSection("Message Analysis", dm_textArea analysisText()),
        jCenteredSection("Predictions", scpPredictions = singleComponentPanel())
      )),
      with(r updateTabs, tabs = jtabs(
        "", with(r updateLabelsTable, labelsTable = sexyTable()),
        "", with(r updateTheoryTable, tableWithSearcher2_returnPanel(theoryTable = sexyTable())),
        "", with(r updateTrainedExamplesTable, tableWithSearcher2_returnPanel(trainedExamplesTable = sexyTable()))
      )));
  }

  void updateTabs {
    setTabTitles(tabs,
      firstLetterToUpper(nLabels(labelsByName)),
      firstLetterToUpper(nTheories(theories)),
      n2(msg2label_new.aKeys(), "Trained Example"));
  }

  void updateTheoryTable {
    L<Theory> sorted = sortedByCalculatedFieldDesc(theories, t ->
      t.examples == null ? null : t.examples.score());
    dataToTable_uneditable(theoryTable, map(sorted, t -> litorderedmap(
      "Score" := renderTheoryScore(t),
      "Theory" := str(t))));
  }

  Map<S, Theory> labelsToBestTheoryMap() {
    Map<S, L<Theory>> map = multiMapToMap(multiMapIndex targetLabelOfTheory(theories));
    ret mapValues(map,
      theories -> highestBy theoryScore(theories));
  }

  void updateLabelsTable {
    L<Label> sorted = sortedByCalculatedFieldDesc(values(labelsByName), l -> l.score());
    dataToTable_uneditable(labelsTable, map(sorted, label -> {
      Cl<Theory> bestTheories = label.bestTheories.tiedForFirst();
      ret litorderedmap(
        "Label" := label.name,
        "Prediction Confidence" := renderTheoryScore(first(bestTheories)),
        "Best Theory" := empty(bestTheories) ? "" :
          (l(bestTheories) > 1 ? "[+" + (l(bestTheories)-1) + "] " : "") + first(bestTheories));
    }));
  }

  S renderTheoryScore(Theory t) {
    //ret renderPosNegCounts(t.examples);
    ret t == null || t.examples.isEmpty() ? "" : iround(adjustConfidence(threeB1BScore(t.examples))) + "%"
      + " / " + renderPosNegScoreAndCount(t.examples);
  }

  // Sort key for theories; -100 for a missing theory.
  int theoryScore(Theory t) {
    ret t == null ? -100 : t.examples.score();
  }

  // Refreshes all theory-dependent views and marks the module as changed.
  void theoriesChanged {
    updateTheoryTable();
    updateLabelsTable();
    updateTabs();
    updatePredictions();
    change();
  }

  visual
    withCenteredButtons(super,
      "Show random msg", rInThinkQ(r showRandomMsg),
      jPopDownButton_noText(flattenObjectArray(
        "Check theories", rInThinkQ(r checkAllTheories),
        "Clear theories", rInThinkQ(r clearTheories),
        "Update predictions", rInThinkQ(r updatePredictions),
        dm_importAndExportAllDataMenuItems())));

  Runnable rInThinkQ(Runnable r) { ret rInQ(thinkQ, r); }

  void addTheory(Theory theory) {
    if (theories.add(theory)) {
      print("New theory: " + theory);
      addTheoryToCollectors(theory);
      theoriesChanged();
    }
  }

  void clearTheories {
    // also drop them from the per-label collectors; otherwise predictions
    // and the labels table keep using the cleared theories
    for (Theory t : theories)
      removeTheoryFromCollectors(t);
    theories.clear();
    theoriesChanged();
  }

  // Three-valued evaluation of a proposition tree (And / Not / MsgProp); null means unknown.
  Bool checkMsgProp(O prop, Msg msg) {
    if (prop cast And) ret andMsgProps(prop.a, prop.b, msg);
    if (prop cast Not) ret not(checkMsgProp(prop.a, msg));
    ret ((MsgProp) prop).check(msg);
  }

  // Kleene AND: false if either side is false, otherwise null if either side is unknown, else true.
  // (A plain && would throw a NullPointerException when unboxing an unknown (null) side.)
  Bool andMsgProps(O a, O b, Msg msg) {
    Bool left = checkMsgProp(a, msg);
    if (isFalse(left)) ret false;
    Bool right = checkMsgProp(b, msg);
    if (isFalse(right)) ret false;
    ret left == null || right == null ? null : Boolean.TRUE;
  }

  Bool evalTheoryLHS(Theory theory, Msg msg) {
    ret theory == null ? null
      : checkMsgProp(theory.statement.lhs, msg);
  }

  // Does the theory hold for this message? null if either side is unknown.
  Bool testTheoryOnMsg(Theory theory, Msg msg) {
    Bool lhs = evalTheoryLHS(theory, msg);
    Bool rhs = checkMsgProp(theory.statement.rhs, msg);
    if (lhs == null || rhs == null) null;
    if (bidiMode)
      ret eq(lhs, rhs);
    else
      ret isTrue(rhs) || isFalse(lhs);
  }

  void checkAllTheories {
    for (Theory theory : theories)
      checkTheory_noTrigger(theory);
    theoriesChanged();
  }

  void checkTheory(Theory theory) {
    checkTheory_noTrigger(theory);
    theoriesChanged();
  }

  // Re-evaluates the theory on all messages without refreshing the UI.
  void checkTheory_noTrigger(Theory theory) {
    new PosNeg<Msg> pn;
    for (Msg msg : msgs)
      pn.add(msg, testTheoryOnMsg(theory, msg));
    if (!eq(theory.examples, pn)) {
      // the collectors are sorted by score, so remove before the score changes and re-add after
      removeTheoryFromCollectors(theory);
      theory.examples = pn;
      addTheoryToCollectors(theory);
      change();
    }
  }

  S cleanLabel(S label) { ret upper(label); }

  // Label named on the theory's RHS, or null if the RHS is not a label proposition.
  S targetLabelOfTheory(Theory theory) {
    O o = theory.statement.rhs;
    if (o cast HasLabel) ret o.label;
    if (o cast DoesntHaveLabel) ret o.label;
    null;
  }

  void addTheoryToCollectors(Theory theory) {
    S lbl = targetLabelOfTheory(theory);
    if (lbl != null)
      labelByName(lbl).bestTheories.add(theory);
  }

  void removeTheoryFromCollectors(Theory theory) {
    S lbl = targetLabelOfTheory(theory);
    if (lbl != null)
      labelByName(lbl).bestTheories.remove(theory);
  }

  Label labelByName(S name) {
    ret getOrCreate(labelsByName, name, () -> new Label(name));
  }

  // Rebuilds the transient labelsByName index from allLabels and theories.
  void updateLabelsByName() {
    for (S lbl : allLabels)
      labelByName(lbl);
    for (Theory t : theories)
      addTheoryToCollectors(t);
  }

  O getMsgFeature(Msg msg, S feature) {
    ret msg2features.get(msg).get(feature);
  }

  // Registers every text extractor as a feature computed from the given text feature.
  void makeTextExtractors(S textFeature) {
    for (WithName<IF1<S, O>> f : textExtractors()) {
      IF1<S, O> theFunction = f!;
      featureExtractors.put(f.name, env -> theFunction.get((S) env.getFeature(textFeature)));
    }
  }

  L<WithName<IF1<S, O>>> textExtractors() {
    new L<WithName<IF1<S, O>>> extractors;
    extractors.add(WithName<>("number of words", lambda1 numberOfWords));
    extractors.add(WithName<>("number of characters", lambda1 l));
    for (char c : characters("\"', .-_"))
      extractors.add(WithName<>("contains " + quote(c), s -> contains(s, c)));
    ret extractors;
  }
}
download show line numbers debug dex old transpilations
Travelled to 7 computer(s): bhatertpkbcr, mqqgnosmbjvj, pyentgdyhuwx, pzhvpgtvlbxg, tvejysmllsmz, vouqrxazstgt, xrpafgyirdlv
No comments. add comment
Snippet ID: | #1027773 |
Snippet name: | Auto Classifier v1[learning message classifier] |
Eternal ID of this version: | #1027773/179 |
Text MD5: | 970ed7539dfbe1b678b0fc42a7f08fda |
Transpilation MD5: | 567e1aa73e99525599ccc46d760825dd |
Author: | stefan |
Category: | javax / a.i. |
Type: | JavaX source code (Dynamic Module) |
Public (visible to everyone): | Yes |
Archived (hidden from active list): | No |
Created/modified: | 2020-05-07 14:04:38 |
Source code size: | 14798 bytes / 463 lines |
Pitched / IR pitched: | No / No |
Views / Downloads: | 251 / 5691 |
Version history: | 178 change(s) |
Referenced in: | [show references] |