!752 !include #1003048 // load log & tags static new HashSet train; static new HashSet test; static new L learners; static S tag; // only focusing on one tag // a function that tags messages sclass Function { L tag(SlackMsg msg) { ret emptyList(); } } static class Learner { void learnAll(L l) { for (SlackMsg msg : l) learn(msg); } void learn(SlackMsg msg) {} // a learner can make multiple functions Function nextFunction() { ret null; } } static class DetectWord extends Function { S word; *() {} *(S *word) {} L tag(SlackMsg msg) { // rough bool b = containsIgnoreCase(msg.text, word); /*if (eq(word, "bot")) print("[debug] " + b + " " + msg.text);*/ ret b ? litlist(tag) : emptyList(); } } static class EndsWith extends Function { S s; *() {} *(S *s) {} L tag(SlackMsg msg) { bool b = endsWithIgnoreCase(msg.text, s); ret b ? litlist(tag) : emptyList(); } } static class LTryPlusWord extends Learner { S functionClass = /*"DetectWord"*/"EndsWith"; new HashSet allWords; Iterator it; void learn(SlackMsg msg) { if (msg.tags.contains(tag)) allWords.addAll(codeTokensOnly(nlTok(dropPunctuation(msg.text)))); } Function nextFunction() { if (it == null) it = allWords.iterator(); if (!it.hasNext()) ret null; ret (Function) newInstance("main$" + functionClass, it.next()); } } static int errors(SlackMsg msg, L tags) { int n = 0; for (S tag : msg.tags) if (!tags.contains(tag)) ++n; for (S tag : tags) if (!msg.tags.contains(tag)) ++n; ret n; } static L getErrors(SlackMsg msg, L tags) { new L l; for (S tag : msg.tags) if (!tags.contains(tag)) l.add("-" + tag); for (S tag : tags) if (!msg.tags.contains(tag)) l.add("+" + tag); ret l; } static void printTest() { for (SlackMsg msg : test) print(structure(msg.tags) + " " + msg.text); } p { init(); tag = "bot greeting"; filterTag(tag); makeTrain(tag, 10); makeTest(tag, 5); new Map scores; test.addAll(train); // test it all print("Test size: " + l(test)); printTest(); // make learners learners.add(new LTryPlusWord); L lTrain = asList(train); for (Learner l : learners) pcall { l.learnAll(lTrain); Function f; while ((f = l.nextFunction()) != null) { int score = 0; for (SlackMsg msg : test) score -= errors(msg, f.tag(msg)); print("Score " + score + " for " + structure(f)); scores.put(f, score); } } int baseScore = -countTags(test); print("Base score: " + baseScore); Function best = highest(scores); if (best == null) print("No winner"); else { int score = scores.get(best); print("Best score: " + score + ", winner: " + structure(best)); printMistakes(best); } } static void printMistakes(Function f) { for (SlackMsg msg : test) { L l = getErrors(msg, f.tag(msg)); if (nempty(l)) print("[mistake] " + msg.text + " " + structure(l)); } } static int countTags(Collection msgs) { int n = 0; for (SlackMsg msg : msgs) n += l(msg.tags); ret n; } static void makeTrain(S tag, int negFieldSize) { makeTagField(train, tag, negFieldSize); } static void makeTest(S tag, int negFieldSize) { makeTagField(test, tag, negFieldSize); } static void makeTagField(HashSet dest, S tag, int negFieldSize) { L plus = msgsByTag.get(tag); // add plus dest.addAll(plus); // add neg fields of all plus messages for (SlackMsg m : plus) dest.addAll(makeNegField(m, negFieldSize)); } static Collection makeNegField(SlackMsg msg, int size) { new HashSet field; if (msg == null) ret field; int i1 = msg.index-size, i2 = msg.index+size; for (int i = i1; i <= i2; i++) { SlackMsg m = msgsByIndex.get(i); if (m != null) field.add(m); } ret field; }