]> git.lizzy.rs Git - plan9front.git/blob - sys/src/cmd/upas/bayes/bayes.c
merge
[plan9front.git] / sys / src / cmd / upas / bayes / bayes.c
1 #include <u.h>
2 #include <libc.h>
3 #include <bio.h>
4 #include "regexp.h"
5 #include "hash.h"
6
/* Compile-time capacities for the fixed-size arrays used by this program. */
7 enum
8 {
9         MAXTAB = 256,	/* size of tab[] and of the per-table arrays in Word */
10         MAXBEST = 32,	/* size of best[]; upper bound accepted for -m */
11 };
12
/*
 * One box (category) hash named on the command line before "~".
 * Entries are filled in by main.
 */
13 typedef struct Table Table;
14 struct Table
15 {
16         char *file;	/* command-line argument the hash was read from */
17         Hash *hash;	/* word table returned by hread(file) */
18         int nmsg;	/* count of the "*nmsg*" entry; 1 if missing or zero */
19 };
20
/*
 * Per-word statistics for one word of the message being classified.
 * main fills one of these for each entry of the message hash and
 * hands it to noteword, which copies it into best[] if it ranks.
 */
21 typedef struct Word Word;
22 struct Word
23 {
24         Stringtab *s;   /* from hmsg */
25         int count[MAXTAB];      /* counts from each table */
26         double p[MAXTAB];       /* probabilities from each table */
27         double mp;      /* max probability */
28         int mi;         /* w.p[w.mi] = w.mp */
29 };
30
/* box hashes loaded from the arguments that precede "~" */
31 Table tab[MAXTAB];
/* number of entries of tab[] in use */
32 int ntab;
33
/* words retained for the current message; noteword inserts by descending mp */
34 Word best[MAXBEST];
/* maximum number of words to retain (15 unless changed with -m) */
35 int mbest;
/* number of entries of best[] in use; reset to 0 for each message */
36 int nbest;
37
/* set by -D: print the retained words after each result line */
38 int debug;
39
/*
 * Print the command-line synopsis on standard error and exit with
 * status "usage".  The synopsis lists every flag main accepts:
 * -D (debug), -k (print keywords) and -m maxword.
 */
void
usage(void)
{
	fprint(2, "usage: bayes [-Dk] [-m maxword] boxhash ... ~ msghash ...\n");
	exits("usage");
}
46
47 void*
48 emalloc(int n)
49 {
50         void *v;
51
52         v = mallocz(n, 1);
53         if(v == nil)
54                 sysfatal("out of memory");
55         return v;
56 }
57
58 void
59 noteword(Word *w)
60 {
61         int i;
62
63         for(i=nbest-1; i>=0; i--)
64                 if(w->mp < best[i].mp)
65                         break;
66         i++;
67
68         if(i >= mbest)
69                 return;
70         if(nbest == mbest)
71                 nbest--;
72         if(i < nbest)
73                 memmove(&best[i+1], &best[i], (nbest-i)*sizeof(best[0]));
74         best[i] = *w;
75         nbest++;
76 }
77
78 Hash*
79 hread(char *s)
80 {
81         Hash *h;
82         Biobuf *b;
83
84         if((b = Bopenlock(s, OREAD)) == nil)
85                 sysfatal("open %s: %r", s);
86
87         h = emalloc(sizeof(Hash));
88         Breadhash(b, h, 1);
89         Bterm(b);
90         return h;
91 }
92
93 void
94 main(int argc, char **argv)
95 {
96         int i, j, a, mi, oi, tot, keywords;
97         double totp, p, xp[MAXTAB];
98         Hash *hmsg;
99         Word w;
100         Stringtab *s, *t;
101         Biobuf bout;
102
103         mbest = 15;
104         keywords = 0;
105         ARGBEGIN{
106         case 'D':
107                 debug = 1;
108                 break;
109         case 'k':
110                 keywords = 1;
111                 break;
112         case 'm':
113                 mbest = atoi(EARGF(usage()));
114                 if(mbest > MAXBEST)
115                         sysfatal("cannot keep more than %d words", MAXBEST);
116                 break;
117         default:
118                 usage();
119         }ARGEND
120
121         for(i=0; i<argc; i++)
122                 if(strcmp(argv[i], "~") == 0)
123                         break;
124
125         if(i > MAXTAB)
126                 sysfatal("cannot handle more than %d tables", MAXTAB);
127
128         if(i+1 >= argc)
129                 usage();
130
131         for(i=0; i<argc; i++){
132                 if(strcmp(argv[i], "~") == 0)
133                         break;
134                 tab[ntab].file = argv[i];
135                 tab[ntab].hash = hread(argv[i]);
136                 s = findstab(tab[ntab].hash, "*nmsg*", 6, 1);
137                 if(s == nil || s->count == 0)
138                         tab[ntab].nmsg = 1;
139                 else
140                         tab[ntab].nmsg = s->count;
141                 ntab++;
142         }
143
144         Binit(&bout, 1, OWRITE);
145
146         oi = ++i;
147         for(a=i; a<argc; a++){
148                 hmsg = hread(argv[a]);
149                 nbest = 0;
150                 for(s=hmsg->all; s; s=s->link){
151                         w.s = s;
152                         tot = 0;
153                         totp = 0.0;
154                         for(i=0; i<ntab; i++){
155                                 t = findstab(tab[i].hash, s->str, s->n, 0);
156                                 if(t == nil)
157                                         w.count[i] = 0;
158                                 else
159                                         w.count[i] = t->count;
160                                 tot += w.count[i];
161                                 p = w.count[i]/(double)tab[i].nmsg;
162                                 if(p >= 1.0)
163                                         p = 1.0;
164                                 w.p[i] = p;
165                                 totp += p;
166                         }
167
168                         if(tot < 5){            /* word does not appear enough; give to box 0 */
169                                 w.p[0] = 0.5;
170                                 for(i=1; i<ntab; i++)
171                                         w.p[i] = 0.1;
172                                 w.mp = 0.5;
173                                 w.mi = 0;
174                                 noteword(&w);
175                                 continue;
176                         }
177
178                         w.mp = 0.0;
179                         for(i=0; i<ntab; i++){
180                                 p = w.p[i];
181                                 p /= totp;
182                                 if(p < 0.01)
183                                         p = 0.01;
184                                 else if(p > 0.99)
185                                         p = 0.99;
186                                 if(p > w.mp){
187                                         w.mp = p;
188                                         w.mi = i;
189                                 }
190                                 w.p[i] = p;
191                         }
192                         noteword(&w);
193                 }
194
195                 totp = 0.0;
196                 for(i=0; i<ntab; i++){
197                         p = 1.0;
198                         for(j=0; j<nbest; j++)
199                                 p *= best[j].p[i];
200                         xp[i] = p;
201                         totp += p;
202                 }
203                 for(i=0; i<ntab; i++)
204                         xp[i] /= totp;
205                 mi = 0;
206                 for(i=1; i<ntab; i++)
207                         if(xp[i] > xp[mi])
208                                 mi = i;
209                 if(oi != argc-1)
210                         Bprint(&bout, "%s: ", argv[a]);
211                 Bprint(&bout, "%s %f", tab[mi].file, xp[mi]);
212                 if(keywords){
213                         for(i=0; i<nbest; i++){
214                                 Bprint(&bout, " ");
215                                 Bwrite(&bout, best[i].s->str, best[i].s->n);
216                                 Bprint(&bout, " %f", best[i].p[mi]);
217                         }
218                 }
219                 freehash(hmsg);
220                 Bprint(&bout, "\n");
221                 if(debug){
222                         for(i=0; i<nbest; i++){
223                                 Bwrite(&bout, best[i].s->str, best[i].s->n);
224                                 Bprint(&bout, " %f", best[i].p[mi]);
225                                 if(best[i].p[mi] < best[i].mp)
226                                         Bprint(&bout, " (%f %s)", best[i].mp, tab[best[i].mi].file);
227                                 Bprint(&bout, "\n");
228                         }
229                 }
230         }
231         Bterm(&bout);
232 }