Libraryless. Click here for Pure Java version (2287L/15K/51K).
1 | !752 |
2 | |
3 | !include #1003048 // load log & tags |
4 | |
// Message sets used for training and evaluation; filled by makeTrain / makeTest.
static HashSet<SlackMsg> train = new HashSet<SlackMsg>();
static HashSet<SlackMsg> test = new HashSet<SlackMsg>();

// Learners to try; each one proposes candidate tagging functions.
static L<Learner> learners = new ArrayList<Learner>();

// The single tag this program is trying to learn (only focusing on one tag).
static S tag;
11 | |
// Base type for message taggers: maps a message to the list of tags it assigns.
// The default implementation assigns no tags at all.
static class Function {
  L<S> tag(SlackMsg msg) {
    return emptyList();
  }
}
16 | |
// Base type for learners. A learner is fed training messages via learn()/learnAll(),
// then asked repeatedly for candidate tagging functions via nextFunction().
static class Learner {
  // Passes each message of the list to learn(), in list order.
  void learnAll(L<SlackMsg> l) {
    Iterator<SlackMsg> iter = l.iterator();
    while (iter.hasNext())
      learn(iter.next());
  }

  // Absorbs one training message; the base learner ignores it.
  void learn(SlackMsg msg) {}

  // A learner can make multiple functions: each call yields the next candidate,
  // and null signals that there are no more. The base learner yields none.
  Function nextFunction() {
    return null;
  }
}
28 | |
// Tagger that assigns the global `tag` when the message text contains `word`,
// ignoring case. This is a rough check: presumably containsIgnoreCase is a plain
// substring test, so "bot" would also match inside longer words — confirm.
static class DetectWord extends Function {
  S word;

  *() {}
  *(S *word) {}

  L<S> tag(SlackMsg msg) {
    if (!containsIgnoreCase(msg.text, word))
      return emptyList();
    return litlist(tag);
  }
}
45 | |
// Tagger that assigns the global `tag` when the message text ends with `s`,
// ignoring case; otherwise assigns nothing.
static class EndsWith extends Function {
  S s;

  *() {}
  *(S *s) {}

  L<S> tag(SlackMsg msg) {
    if (endsWithIgnoreCase(msg.text, s))
      return litlist(tag);
    return emptyList();
  }
}
57 | |
// Learner that collects every word occurring in positively-tagged training
// messages, then proposes one candidate Function per collected word.
static class LTryPlusWord extends Learner {
  // Simple name of the Function subclass instantiated per word (constructed
  // reflectively with the word as its single argument).
  S functionClass = /*"DetectWord"*/"EndsWith";
  new HashSet<S> allWords;
  // Lazily created on the first nextFunction() call; learn() should not be
  // called after that, or the set would change under the iterator.
  Iterator<S> it;

  void learn(SlackMsg msg) {
    if (!msg.tags.contains(tag)) return;
    S cleaned = dropPunctuation(msg.text);
    allWords.addAll(codeTokensOnly(nlTok(cleaned)));
  }

  Function nextFunction() {
    if (it == null) it = allWords.iterator();
    if (it.hasNext())
      return (Function) newInstance("main$" + functionClass, it.next());
    return null;
  }
}
75 | |
// Number of tagging mistakes for one message: tags the message has but `tags`
// lacks (misses), plus tags in `tags` the message lacks (false positives).
static int errors(SlackMsg msg, L<S> tags) {
  int missing = 0, extra = 0;
  for (S t : msg.tags) {
    if (!tags.contains(t)) missing++;
  }
  for (S t : tags) {
    if (!msg.tags.contains(t)) extra++;
  }
  return missing + extra;
}
86 | |
// Describes the tagging mistakes for one message: "-t" for each tag t the message
// has but `tags` lacks, followed by "+t" for each tag in `tags` the message lacks.
static L<S> getErrors(SlackMsg msg, L<S> tags) {
  L<S> out = new ArrayList<S>();
  for (S t : msg.tags) {
    if (!tags.contains(t)) out.add("-" + t);
  }
  for (S t : tags) {
    if (!msg.tags.contains(t)) out.add("+" + t);
  }
  return out;
}
97 | |
// Prints every test message as "<tags structure> <text>", one per line.
static void printTest() {
  for (SlackMsg m : test) {
    S line = structure(m.tags) + " " + m.text;
    print(line);
  }
}
102 | |
// Entry point: builds train/test sets for the tag "bot greeting", lets each learner
// propose candidate tagging functions, scores every candidate on the test set
// (score = minus the number of tagging mistakes), and reports the best one.
p {
  init();

  tag = "bot greeting";
  filterTag(tag);
  makeTrain(tag, 10);
  makeTest(tag, 5);

  new Map<Function, Int> scores;

  test.addAll(train); // test it all (evaluation deliberately includes training data)
  print("Test size: " + l(test));
  printTest();

  // make learners
  learners.add(new LTryPlusWord);

  L<SlackMsg> lTrain = asList(train);
  for (Learner l : learners) pcall {
    l.learnAll(lTrain);

    // Score every candidate the learner produces until it returns null.
    Function f;
    while ((f = l.nextFunction()) != null) {
      int score = 0;
      for (SlackMsg msg : test)
        score -= errors(msg, f.tag(msg));
      print("Score " + score + " for " + structure(f));
      scores.put(f, score);
    }
  }

  // Baseline: a tagger that assigns no tags misses each expected tag exactly once,
  // so its score is minus the total number of tags in the test set.
  int baseScore = -countTags(test);
  print("Base score: " + baseScore);

  Function best = highest(scores);
  if (best == null)
    print("No winner");
  else {
    int score = scores.get(best);
    print("Best score: " + score + ", winner: " + structure(best));
    // Previously baseScore was printed but never compared; flag a "winner" that
    // does no better than tagging nothing at all.
    if (score <= baseScore)
      print("Winner is no better than the base score (tagging nothing).");
    printMistakes(best);
  }
}
146 | |
// Prints each test message that `f` tags incorrectly, together with its
// mistake list as produced by getErrors.
static void printMistakes(Function f) {
  for (SlackMsg m : test) {
    L<S> errs = getErrors(m, f.tag(m));
    if (errs.isEmpty()) continue;
    print("[mistake] " + m.text + " " + structure(errs));
  }
}
154 | |
// Total number of tag entries across all given messages.
static int countTags(Collection<SlackMsg> msgs) {
  int total = 0;
  for (Iterator<SlackMsg> i = msgs.iterator(); i.hasNext(); )
    total += l(i.next().tags);
  return total;
}
161 | |
// Adds to the global training set all messages carrying `tag`, plus the messages
// within `negFieldSize` index positions of each of them (see makeTagField).
static void makeTrain(S tag, int negFieldSize) {
  HashSet<SlackMsg> target = train;
  makeTagField(target, tag, negFieldSize);
}
165 | |
// Adds to the global test set all messages carrying `tag`, plus the messages
// within `negFieldSize` index positions of each of them (see makeTagField).
static void makeTest(S tag, int negFieldSize) {
  HashSet<SlackMsg> target = test;
  makeTagField(target, tag, negFieldSize);
}
169 | |
// Adds to `dest` every message tagged `tag` ("plus" examples) and, for each of
// them, the surrounding messages returned by makeNegField (the "neg field").
// If no message carries `tag`, prints a note and leaves `dest` unchanged
// (previously msgsByTag.get returning null caused a NullPointerException).
static void makeTagField(HashSet<SlackMsg> dest, S tag, int negFieldSize) {
  L<SlackMsg> plus = msgsByTag.get(tag);
  if (plus == null) {
    print("makeTagField: no messages with tag " + tag);
    return;
  }

  // add plus
  dest.addAll(plus);

  // add neg fields of all plus messages
  for (SlackMsg m : plus)
    dest.addAll(makeNegField(m, negFieldSize));
}
180 | |
// Collects the messages whose index lies in [msg.index - size, msg.index + size]
// (msg itself included when it is present in msgsByIndex). Indices with no message
// are skipped. A null msg yields an empty set.
// NOTE(review): msgsByIndex.get(i) is called with possibly negative i; this assumes
// msgsByIndex is a map keyed by index rather than a list — confirm.
static Collection<SlackMsg> makeNegField(SlackMsg msg, int size) {
  HashSet<SlackMsg> field = new HashSet<SlackMsg>();
  if (msg != null) {
    int from = msg.index - size;
    int to = msg.index + size;
    for (int i = from; i <= to; i++) {
      SlackMsg neighbour = msgsByIndex.get(i);
      if (neighbour == null) continue;
      field.add(neighbour);
    }
  }
  return field;
}
download show line numbers debug dex old transpilations
Travelled to 13 computer(s): aoiabmzegqzx, bhatertpkbcr, cbybwowwnfue, cfunsshuasjs, gwrvuhgaqvyk, ishqpsrjomds, lpdgvwnxivlt, mqqgnosmbjvj, pyentgdyhuwx, pzhvpgtvlbxg, tslmcundralx, tvejysmllsmz, vouqrxazstgt
No comments. add comment
Snippet ID: | #1003049 |
Snippet name: | Learn Bot Greetings |
Eternal ID of this version: | #1003049/1 |
Text MD5: | d5f2602e676235d576801f2f6adfa493 |
Transpilation MD5: | 175488ffacc9c1c2ac0ea8290d9d4b81 |
Author: | stefan |
Category: | eleu / nl |
Type: | JavaX source code |
Public (visible to everyone): | Yes |
Archived (hidden from active list): | No |
Created/modified: | 2016-04-26 17:08:52 |
Source code size: | 4248 bytes / 191 lines |
Pitched / IR pitched: | No / No |
Views / Downloads: | 591 / 667 |
Referenced in: | [show references] |