git.lizzy.rs Git - plan9front.git/blob - sys/src/liboventi/rpc.c
libsec: use tsmemcmp() when comparing hashes, use mpfield() for ecc, use mptober...
[plan9front.git] / sys / src / liboventi / rpc.c
1 #include <u.h>
2 #include <libc.h>
3 #include <oventi.h>
4 #include "session.h"
5
6 struct {
7         int version;
8         char *s;
9 } vtVersions[] = {
10         VtVersion02, "02",
11         0, 0,
12 };
13
14 static char EBigString[] = "string too long";
15 static char EBigPacket[] = "packet too long";
16 static char ENullString[] = "missing string";
17 static char EBadVersion[] = "bad format in version string";
18
19 static Packet *vtRPC(VtSession *z, int op, Packet *p);
20
21
22 VtSession *
23 vtAlloc(void)
24 {
25         VtSession *z;
26
27         z = vtMemAllocZ(sizeof(VtSession));
28         z->lk = vtLockAlloc();
29 //      z->inHash = vtSha1Alloc();
30         z->inLock = vtLockAlloc();
31         z->part = packetAlloc();
32 //      z->outHash = vtSha1Alloc();
33         z->outLock = vtLockAlloc();
34         z->fd = -1;
35         z->uid = vtStrDup("anonymous");
36         z->sid = vtStrDup("anonymous");
37         return z;
38 }
39
40 void
41 vtReset(VtSession *z)
42 {
43         vtLock(z->lk);
44         z->cstate = VtStateAlloc;
45         if(z->fd >= 0){
46                 vtFdClose(z->fd);
47                 z->fd = -1;
48         }
49         vtUnlock(z->lk);
50 }
51
52 int
53 vtConnected(VtSession *z)
54 {
55         return z->cstate == VtStateConnected;
56 }
57
58 void
59 vtDisconnect(VtSession *z, int error)
60 {
61         Packet *p;
62         uchar *b;
63
64 vtDebug(z, "vtDisconnect\n");
65         vtLock(z->lk);
66         if(z->cstate == VtStateConnected && !error && z->vtbl == nil) {
67                 /* clean shutdown */
68                 p = packetAlloc();
69                 b = packetHeader(p, 2);
70                 b[0] = VtQGoodbye;
71                 b[1] = 0;
72                 vtSendPacket(z, p);
73         }
74         if(z->fd >= 0)
75                 vtFdClose(z->fd);
76         z->fd = -1;
77         z->cstate = VtStateClosed;
78         vtUnlock(z->lk);
79 }
80
81 void
82 vtClose(VtSession *z)
83 {
84         vtDisconnect(z, 0);
85 }
86
87 void
88 vtFree(VtSession *z)
89 {
90         if(z == nil)
91                 return;
92         vtLockFree(z->lk);
93         vtSha1Free(z->inHash);
94         vtLockFree(z->inLock);
95         packetFree(z->part);
96         vtSha1Free(z->outHash);
97         vtLockFree(z->outLock);
98         vtMemFree(z->uid);
99         vtMemFree(z->sid);
100         vtMemFree(z->vtbl);
101
102         memset(z, 0, sizeof(VtSession));
103         z->fd = -1;
104
105         vtMemFree(z);
106 }
107
108 char *
109 vtGetUid(VtSession *s)
110 {
111         return s->uid;
112 }
113
114 char *
115 vtGetSid(VtSession *z)
116 {
117         return z->sid;
118 }
119
120 int
121 vtSetDebug(VtSession *z, int debug)
122 {
123         int old;
124         vtLock(z->lk);
125         old = z->debug;
126         z->debug = debug;
127         vtUnlock(z->lk);
128         return old;
129 }
130
131 int
132 vtSetFd(VtSession *z, int fd)
133 {
134         vtLock(z->lk);
135         if(z->cstate != VtStateAlloc) {
136                 vtSetError("bad state");
137                 vtUnlock(z->lk);
138                 return 0;
139         }
140         if(z->fd >= 0)
141                 vtFdClose(z->fd);
142         z->fd = fd;
143         vtUnlock(z->lk);
144         return 1;
145 }
146
147 int
148 vtGetFd(VtSession *z)
149 {
150         return z->fd;
151 }
152
153 int
154 vtSetCryptoStrength(VtSession *z, int c)
155 {
156         if(z->cstate != VtStateAlloc) {
157                 vtSetError("bad state");
158                 return 0;
159         }
160         if(c != VtCryptoStrengthNone) {
161                 vtSetError("not supported yet");
162                 return 0;
163         }
164         return 1;
165 }
166
167 int
168 vtGetCryptoStrength(VtSession *s)
169 {
170         return s->cryptoStrength;
171 }
172
173 int
174 vtSetCompression(VtSession *z, int fd)
175 {
176         vtLock(z->lk);
177         if(z->cstate != VtStateAlloc) {
178                 vtSetError("bad state");
179                 vtUnlock(z->lk);
180                 return 0;
181         }
182         z->fd = fd;
183         vtUnlock(z->lk);
184         return 1;
185 }
186
187 int
188 vtGetCompression(VtSession *s)
189 {
190         return s->compression;
191 }
192
193 int
194 vtGetCrypto(VtSession *s)
195 {
196         return s->crypto;
197 }
198
199 int
200 vtGetCodec(VtSession *s)
201 {
202         return s->codec;
203 }
204
205 char *
206 vtGetVersion(VtSession *z)
207 {
208         int v, i;
209         
210         v = z->version;
211         if(v == 0)
212                 return "unknown";
213         for(i=0; vtVersions[i].version; i++)
214                 if(vtVersions[i].version == v)
215                         return vtVersions[i].s;
216         assert(0);
217         return 0;
218 }
219
220 /* hold z->inLock */
221 static int
222 vtVersionRead(VtSession *z, char *prefix, int *ret)
223 {
224         char c;
225         char buf[VtMaxStringSize];
226         char *q, *p, *pp;
227         int i;
228
229         q = prefix;
230         p = buf;
231         for(;;) {
232                 if(p >= buf + sizeof(buf)) {
233                         vtSetError(EBadVersion);
234                         return 0;
235                 }
236                 if(!vtFdReadFully(z->fd, (uchar*)&c, 1))
237                         return 0;
238                 if(z->inHash)
239                         vtSha1Update(z->inHash, (uchar*)&c, 1);
240                 if(c == '\n') {
241                         *p = 0;
242                         break;
243                 }
244                 if(c < ' ' || *q && c != *q) {
245                         vtSetError(EBadVersion);
246                         return 0;
247                 }
248                 *p++ = c;
249                 if(*q)
250                         q++;
251         }
252                 
253         vtDebug(z, "version string in: %s\n", buf);
254
255         p = buf + strlen(prefix);
256         for(;;) {
257                 for(pp=p; *pp && *pp != ':'  && *pp != '-'; pp++)
258                         ;
259                 for(i=0; vtVersions[i].version; i++) {
260                         if(strlen(vtVersions[i].s) != pp-p)
261                                 continue;
262                         if(memcmp(vtVersions[i].s, p, pp-p) == 0) {
263                                 *ret = vtVersions[i].version;
264                                 return 1;
265                         }
266                 }
267                 p = pp;
268                 if(*p != ':')
269                         return 0;
270                 p++;
271         }       
272 }
273
274 Packet*
275 vtRecvPacket(VtSession *z)
276 {
277         uchar buf[10], *b;
278         int n;
279         Packet *p;
280         int size, len;
281
282         if(z->cstate != VtStateConnected) {
283                 vtSetError("session not connected");
284                 return 0;
285         }
286
287         vtLock(z->inLock);
288         p = z->part;
289         /* get enough for head size */
290         size = packetSize(p);
291         while(size < 2) {
292                 b = packetTrailer(p, MaxFragSize);
293                 assert(b != nil);
294                 n = vtFdRead(z->fd, b, MaxFragSize);
295                 if(n <= 0)
296                         goto Err;
297                 size += n;
298                 packetTrim(p, 0, size);
299         }
300
301         if(!packetConsume(p, buf, 2))
302                 goto Err;
303         len = (buf[0] << 8) | buf[1];
304         size -= 2;
305
306         while(size < len) {
307                 n = len - size;
308                 if(n > MaxFragSize)
309                         n = MaxFragSize;
310                 b = packetTrailer(p, n);
311                 if(!vtFdReadFully(z->fd, b, n))
312                         goto Err;
313                 size += n;
314         }
315         p = packetSplit(p, len);
316         vtUnlock(z->inLock);
317         return p;
318 Err:    
319         vtUnlock(z->inLock);
320         return nil;     
321 }
322
323 int
324 vtSendPacket(VtSession *z, Packet *p)
325 {
326         IOchunk ioc;
327         int n;
328         uchar buf[2];
329         
330         /* add framing */
331         n = packetSize(p);
332         if(n >= (1<<16)) {
333                 vtSetError(EBigPacket);
334                 packetFree(p);
335                 return 0;
336         }
337         buf[0] = n>>8;
338         buf[1] = n;
339         packetPrefix(p, buf, 2);
340
341         for(;;) {
342                 n = packetFragments(p, &ioc, 1, 0);
343                 if(n == 0)
344                         break;
345                 if(!vtFdWrite(z->fd, ioc.addr, ioc.len)) {
346                         packetFree(p);
347                         return 0;
348                 }
349                 packetConsume(p, nil, n);
350         }
351         packetFree(p);
352         return 1;
353 }
354
355
356 int
357 vtGetString(Packet *p, char **ret)
358 {
359         uchar buf[2];
360         int n;
361         char *s;
362
363         if(!packetConsume(p, buf, 2))
364                 return 0;
365         n = (buf[0]<<8) + buf[1];
366         if(n > VtMaxStringSize) {
367                 vtSetError(EBigString);
368                 return 0;
369         }
370         s = vtMemAlloc(n+1);
371         setmalloctag(s, getcallerpc(&p));
372         if(!packetConsume(p, (uchar*)s, n)) {
373                 vtMemFree(s);
374                 return 0;
375         }
376         s[n] = 0;
377         *ret = s;
378         return 1;
379 }
380
381 int
382 vtAddString(Packet *p, char *s)
383 {
384         uchar buf[2];
385         int n;
386
387         if(s == nil) {
388                 vtSetError(ENullString);
389                 return 0;
390         }
391         n = strlen(s);
392         if(n > VtMaxStringSize) {
393                 vtSetError(EBigString);
394                 return 0;
395         }
396         buf[0] = n>>8;
397         buf[1] = n;
398         packetAppend(p, buf, 2);
399         packetAppend(p, (uchar*)s, n);
400         return 1;
401 }
402
403 int
404 vtConnect(VtSession *z, char *password)
405 {
406         char buf[VtMaxStringSize], *p, *ep, *prefix;
407         int i;
408
409         USED(password);
410         vtLock(z->lk);
411         if(z->cstate != VtStateAlloc) {
412                 vtSetError("bad session state");
413                 vtUnlock(z->lk);
414                 return 0;
415         }
416         if(z->fd < 0){
417                 vtSetError("%s", z->fderror);
418                 vtUnlock(z->lk);
419                 return 0;
420         }
421
422         /* be a little anal */
423         vtLock(z->inLock);
424         vtLock(z->outLock);
425
426         prefix = "venti-";
427         p = buf;
428         ep = buf + sizeof(buf);
429         p = seprint(p, ep, "%s", prefix);
430         p += strlen(p);
431         for(i=0; vtVersions[i].version; i++) {
432                 if(i != 0)
433                         *p++ = ':';
434                 p = seprint(p, ep, "%s", vtVersions[i].s);
435         }
436         p = seprint(p, ep, "-libventi\n");
437         assert(p-buf < sizeof(buf));
438         if(z->outHash)
439                 vtSha1Update(z->outHash, (uchar*)buf, p-buf);
440         if(!vtFdWrite(z->fd, (uchar*)buf, p-buf))
441                 goto Err;
442         
443         vtDebug(z, "version string out: %s", buf);
444
445         if(!vtVersionRead(z, prefix, &z->version))
446                 goto Err;
447                 
448         vtDebug(z, "version = %d: %s\n", z->version, vtGetVersion(z));
449
450         vtUnlock(z->inLock);
451         vtUnlock(z->outLock);
452         z->cstate = VtStateConnected;
453         vtUnlock(z->lk);
454
455         if(z->vtbl)
456                 return 1;
457
458         if(!vtHello(z))
459                 goto Err;
460         return 1;       
461 Err:
462         if(z->fd >= 0)
463                 vtFdClose(z->fd);
464         z->fd = -1;
465         vtUnlock(z->inLock);
466         vtUnlock(z->outLock);
467         z->cstate = VtStateClosed;
468         vtUnlock(z->lk);
469         return 0;       
470 }
471