]> git.lizzy.rs Git - plan9front.git/blob - sys/src/liboventi/server.c
libsec: use tsmemcmp() when comparing hashes, use mpfield() for ecc, use mptober...
[plan9front.git] / sys / src / liboventi / server.c
1 #include <u.h>
2 #include <libc.h>
3 #include <oventi.h>
4 #include "session.h"
5
/* Error strings passed to vtSetError by the server-side code in this file. */
static char EAuthState[] = "bad authentication state";	/* hello arrived in the wrong auth state */
static char ENotServer[] = "not a server session";	/* vtExport on a session with no server vtbl */
static char EVersion[] = "incorrect version number";	/* client version != vtGetVersion(z) */
static char EProtocolBotch[] = "venti protocol botch";	/* malformed, trailing-garbage, or unknown message */
10
11 VtSession *
12 vtServerAlloc(VtServerVtbl *vtbl)
13 {
14         VtSession *z = vtAlloc();
15         z->vtbl = vtMemAlloc(sizeof(VtServerVtbl));
16         setmalloctag(z->vtbl, getcallerpc(&vtbl));
17         *z->vtbl = *vtbl;
18         return z;
19 }
20
21 static int
22 srvHello(VtSession *z, char *version, char *uid, int , uchar *, int , uchar *, int )
23 {
24         vtLock(z->lk);
25         if(z->auth.state != VtAuthHello) {
26                 vtSetError(EAuthState);
27                 goto Err;
28         }
29         if(strcmp(version, vtGetVersion(z)) != 0) {
30                 vtSetError(EVersion);
31                 goto Err;
32         }
33         vtMemFree(z->uid);
34         z->uid = vtStrDup(uid);
35         z->auth.state = VtAuthOK;
36         vtUnlock(z->lk);
37         return 1;
38 Err:
39         z->auth.state = VtAuthFailed;
40         vtUnlock(z->lk);
41         return 0;
42 }
43
44
45 static int
46 dispatchHello(VtSession *z, Packet **pkt)
47 {
48         char *version, *uid;
49         uchar *crypto, *codec;
50         uchar buf[10];
51         int ncrypto, ncodec, cryptoStrength;
52         int ret;
53         Packet *p;
54
55         p = *pkt;
56
57         version = nil;  
58         uid = nil;
59         crypto = nil;
60         codec = nil;
61
62         ret = 0;
63         if(!vtGetString(p, &version))
64                 goto Err;
65         if(!vtGetString(p, &uid))
66                 goto Err;
67         if(!packetConsume(p, buf, 2))
68                 goto Err;
69         cryptoStrength = buf[0];
70         ncrypto = buf[1];
71         crypto = vtMemAlloc(ncrypto);
72         if(!packetConsume(p, crypto, ncrypto))
73                 goto Err;
74
75         if(!packetConsume(p, buf, 1))
76                 goto Err;
77         ncodec = buf[0];
78         codec = vtMemAlloc(ncodec);
79         if(!packetConsume(p, codec, ncodec))
80                 goto Err;
81
82         if(packetSize(p) != 0) {
83                 vtSetError(EProtocolBotch);
84                 goto Err;
85         }
86         if(!srvHello(z, version, uid, cryptoStrength, crypto, ncrypto, codec, ncodec)) {
87                 packetFree(p);
88                 *pkt = nil;
89         } else {
90                 if(!vtAddString(p, vtGetSid(z)))
91                         goto Err;
92                 buf[0] = vtGetCrypto(z);
93                 buf[1] = vtGetCodec(z);
94                 packetAppend(p, buf, 2);
95         }
96         ret = 1;
97 Err:
98         vtMemFree(version);
99         vtMemFree(uid);
100         vtMemFree(crypto);
101         vtMemFree(codec);
102         return ret;
103 }
104
105 static int
106 dispatchRead(VtSession *z, Packet **pkt)
107 {
108         Packet *p;
109         int type, n;
110         uchar score[VtScoreSize], buf[4];
111
112         p = *pkt;
113         if(!packetConsume(p, score, VtScoreSize))
114                 return 0;
115         if(!packetConsume(p, buf, 4))
116                 return 0;
117         type = buf[0];
118         n = (buf[2]<<8) | buf[3];
119         if(packetSize(p) != 0) {
120                 vtSetError(EProtocolBotch);
121                 return 0;
122         }
123         packetFree(p);
124         *pkt = (*z->vtbl->read)(z, score, type, n);
125         return 1;
126 }
127
128 static int
129 dispatchWrite(VtSession *z, Packet **pkt)
130 {
131         Packet *p;
132         int type;
133         uchar score[VtScoreSize], buf[4];
134
135         p = *pkt;
136         if(!packetConsume(p, buf, 4))
137                 return 0;
138         type = buf[0];
139         if(!(z->vtbl->write)(z, score, type, p)) {
140                 *pkt = 0;
141         } else {
142                 *pkt = packetAlloc();
143                 packetAppend(*pkt, score, VtScoreSize);
144         }
145         return 1;
146 }
147
148 static int
149 dispatchSync(VtSession *z, Packet **pkt)
150 {
151         (z->vtbl->sync)(z);
152         if(packetSize(*pkt) != 0) {
153                 vtSetError(EProtocolBotch);
154                 return 0;
155         }
156         return 1;
157 }
158
159 int
160 vtExport(VtSession *z)
161 {
162         Packet *p;
163         uchar buf[10], *hdr;
164         int op, tid, clean;
165
166         if(z->vtbl == nil) {
167                 vtSetError(ENotServer);
168                 return 0;
169         }
170
171         /* fork off slave */
172         switch(rfork(RFNOWAIT|RFMEM|RFPROC)){
173         case -1:
174                 vtOSError();
175                 return 0;
176         case 0:
177                 break;
178         default:
179                 return 1;
180         }
181
182         
183         p = nil;
184         clean = 0;
185         vtAttach();
186         if(!vtConnect(z, nil))
187                 goto Exit;
188
189         vtDebug(z, "server connected!\n");
190 if(0)   vtSetDebug(z, 1);
191
192         for(;;) {
193                 p = vtRecvPacket(z);
194                 if(p == nil) {
195                         break;
196                 }
197                 vtDebug(z, "server recv: ");
198                 vtDebugMesg(z, p, "\n");
199
200                 if(!packetConsume(p, buf, 2)) {
201                         vtSetError(EProtocolBotch);
202                         break;
203                 }
204                 op = buf[0];
205                 tid = buf[1];
206                 switch(op) {
207                 default:
208                         vtSetError(EProtocolBotch);
209                         goto Exit;
210                 case VtQPing:
211                         break;
212                 case VtQGoodbye:
213                         clean = 1;
214                         goto Exit;
215                 case VtQHello:
216                         if(!dispatchHello(z, &p))
217                                 goto Exit;
218                         break;
219                 case VtQRead:
220                         if(!dispatchRead(z, &p))
221                                 goto Exit;
222                         break;
223                 case VtQWrite:
224                         if(!dispatchWrite(z, &p))
225                                 goto Exit;
226                         break;
227                 case VtQSync:
228                         if(!dispatchSync(z, &p))
229                                 goto Exit;
230                         break;
231                 }
232                 if(p != nil) {
233                         hdr = packetHeader(p, 2);
234                         hdr[0] = op+1;
235                         hdr[1] = tid;
236                 } else {
237                         p = packetAlloc();
238                         hdr = packetHeader(p, 2);
239                         hdr[0] = VtRError;
240                         hdr[1] = tid;
241                         if(!vtAddString(p, vtGetError()))
242                                 goto Exit;
243                 }
244
245                 vtDebug(z, "server send: ");
246                 vtDebugMesg(z, p, "\n");
247
248                 if(!vtSendPacket(z, p)) {
249                         p = nil;
250                         goto Exit;
251                 }
252         }
253 Exit:
254         if(p != nil)
255                 packetFree(p);
256         if(z->vtbl->closing)
257                 z->vtbl->closing(z, clean);
258         vtClose(z);
259         vtFree(z);
260         vtDetach();
261
262         exits(0);
263         return 0;       /* never gets here */
264 }
265