]> git.lizzy.rs Git - plan9front.git/blob - sys/src/9/ip/esp.c
merge
[plan9front.git] / sys / src / 9 / ip / esp.c
1 /*
2  * Encapsulating Security Payload for IPsec for IPv4, rfc1827.
3  * extended to IPv6.
4  * rfc2104 defines hmac computation.
5  *      currently only implements tunnel mode.
6  * TODO: verify aes algorithms;
7  *      transport mode (host-to-host)
8  */
9 #include        "u.h"
10 #include        "../port/lib.h"
11 #include        "mem.h"
12 #include        "dat.h"
13 #include        "fns.h"
14 #include        "../port/error.h"
15
16 #include        "ip.h"
17 #include        "ipv6.h"
18 #include        "libsec.h"
19
20 #define BITS2BYTES(bi) (((bi) + BI2BY - 1) / BI2BY)
21 #define BYTES2BITS(by)  ((by) * BI2BY)
22
23 typedef struct Algorithm Algorithm;
24 typedef struct Esp4hdr Esp4hdr;
25 typedef struct Esp6hdr Esp6hdr;
26 typedef struct Espcb Espcb;
27 typedef struct Esphdr Esphdr;
28 typedef struct Esppriv Esppriv;
29 typedef struct Esptail Esptail;
30 typedef struct Userhdr Userhdr;
31
32 enum {
33         Encrypt,
34         Decrypt,
35
36         IP_ESPPROTO     = 50,   /* IP v4 and v6 protocol number */
37         Esp4hdrlen      = IP4HDR + 8,
38         Esp6hdrlen      = IP6HDR + 8,
39
40         Esptaillen      = 2,    /* does not include pad or auth data */
41         Userhdrlen      = 4,    /* user-visible header size - if enabled */
42
43         Desblk   = BITS2BYTES(64),
44         Des3keysz = BITS2BYTES(192),
45
46         Aesblk   = BITS2BYTES(128),
47         Aeskeysz = BITS2BYTES(128),
48 };
49
50 struct Esphdr
51 {
52         uchar   espspi[4];      /* Security parameter index */
53         uchar   espseq[4];      /* Sequence number */
54         uchar   payload[];
55 };
56
57 /*
58  * tunnel-mode (network-to-network, etc.) layout is:
59  * new IP hdrs | ESP hdr |
60  *       enc { orig IP hdrs | TCP/UDP hdr | user data | ESP trailer } | ESP ICV
61  *
62  * transport-mode (host-to-host) layout would be:
63  *      orig IP hdrs | ESP hdr |
64  *                      enc { TCP/UDP hdr | user data | ESP trailer } | ESP ICV
65  */
66 struct Esp4hdr
67 {
68         /* ipv4 header */
69         uchar   vihl;           /* Version and header length */
70         uchar   tos;            /* Type of service */
71         uchar   length[2];      /* packet length */
72         uchar   id[2];          /* Identification */
73         uchar   frag[2];        /* Fragment information */
74         uchar   Unused;
75         uchar   espproto;       /* Protocol */
76         uchar   espplen[2];     /* Header plus data length */
77         uchar   espsrc[4];      /* Ip source */
78         uchar   espdst[4];      /* Ip destination */
79
80         Esphdr;
81 };
82
83 /* tunnel-mode layout */
84 struct Esp6hdr
85 {
86         IPV6HDR;
87         Esphdr;
88 };
89
90 struct Esptail
91 {
92         uchar   pad;
93         uchar   nexthdr;
94 };
95
96 /* IP-version-dependent data */
97 typedef struct Versdep Versdep;
98 struct Versdep
99 {
100         ulong   version;
101         ulong   iphdrlen;
102         ulong   hdrlen;         /* iphdrlen + esp hdr len */
103         ulong   spi;
104         uchar   laddr[IPaddrlen];
105         uchar   raddr[IPaddrlen];
106 };
107
108 /* header as seen by the user */
109 struct Userhdr
110 {
111         uchar   nexthdr;        /* next protocol */
112         uchar   unused[3];
113 };
114
115 struct Esppriv
116 {
117         uvlong  in;
118         ulong   inerrors;
119 };
120
121 /*
122  *  protocol specific part of Conv
123  */
124 struct Espcb
125 {
126         int     incoming;
127         int     header;         /* user-level header */
128         ulong   spi;
129         ulong   seq;            /* last seq sent */
130         ulong   window;         /* for replay attacks */
131
132         char    *espalg;
133         void    *espstate;      /* other state for esp */
134         int     espivlen;       /* in bytes */
135         int     espblklen;
136         int     (*cipher)(Espcb*, uchar *buf, int len);
137
138         char    *ahalg;
139         void    *ahstate;       /* other state for esp */
140         int     ahlen;          /* auth data length in bytes */
141         int     ahblklen;
142         int     (*auth)(Espcb*, uchar *buf, int len, uchar *hash);
143         DigestState *ds;
144 };
145
146 struct Algorithm
147 {
148         char    *name;
149         int     keylen;         /* in bits */
150         void    (*init)(Espcb*, char* name, uchar *key, unsigned keylen);
151 };
152
153 static  Conv* convlookup(Proto *esp, ulong spi);
154 static  char *setalg(Espcb *ecb, char **f, int n, Algorithm *alg);
155 static  void espkick(void *x);
156
157 static  void nullespinit(Espcb*, char*, uchar *key, unsigned keylen);
158 static  void des3espinit(Espcb*, char*, uchar *key, unsigned keylen);
159 static  void aescbcespinit(Espcb*, char*, uchar *key, unsigned keylen);
160 static  void aesctrespinit(Espcb*, char*, uchar *key, unsigned keylen);
161 static  void desespinit(Espcb *ecb, char *name, uchar *k, unsigned n);
162
163 static  void nullahinit(Espcb*, char*, uchar *key, unsigned keylen);
164 static  void shaahinit(Espcb*, char*, uchar *key, unsigned keylen);
165 static  void aesahinit(Espcb*, char*, uchar *key, unsigned keylen);
166 static  void md5ahinit(Espcb*, char*, uchar *key, unsigned keylen);
167
168 static Algorithm espalg[] =
169 {
170         "null",         0,      nullespinit,
171         "des3_cbc",     192,    des3espinit,    /* new rfc2451, des-ede3 */
172         "aes_128_cbc",  128,    aescbcespinit,  /* new rfc3602 */
173         "aes_ctr",      128,    aesctrespinit,  /* new rfc3686 */
174         "des_56_cbc",   64,     desespinit,     /* rfc2405, deprecated */
175         /* rc4 was never required, was used in original bandt */
176 //      "rc4_128",      128,    rc4espinit,
177         nil,            0,      nil,
178 };
179
180 static Algorithm ahalg[] =
181 {
182         "null",         0,      nullahinit,
183         "hmac_sha1_96", 128,    shaahinit,      /* rfc2404 */
184         "aes_xcbc_mac_96", 128, aesahinit,      /* new rfc3566 */
185         "hmac_md5_96",  128,    md5ahinit,      /* rfc2403 */
186         nil,            0,      nil,
187 };
188
189 static char*
190 espconnect(Conv *c, char **argv, int argc)
191 {
192         char *p, *pp, *e = nil;
193         ulong spi;
194         Espcb *ecb = (Espcb*)c->ptcl;
195
196         switch(argc) {
197         default:
198                 e = "bad args to connect";
199                 break;
200         case 2:
201                 p = strchr(argv[1], '!');
202                 if(p == nil){
203                         e = "malformed address";
204                         break;
205                 }
206                 *p++ = 0;
207                 if (parseip(c->raddr, argv[1]) == -1) {
208                         e = Ebadip;
209                         break;
210                 }
211                 findlocalip(c->p->f, c->laddr, c->raddr);
212                 ecb->incoming = 0;
213                 ecb->seq = 0;
214                 if(strcmp(p, "*") == 0) {
215                         qlock(c->p);
216                         for(;;) {
217                                 spi = nrand(1<<16) + 256;
218                                 if(convlookup(c->p, spi) == nil)
219                                         break;
220                         }
221                         qunlock(c->p);
222                         ecb->spi = spi;
223                         ecb->incoming = 1;
224                         qhangup(c->wq, nil);
225                 } else {
226                         spi = strtoul(p, &pp, 10);
227                         if(pp == p) {
228                                 e = "malformed address";
229                                 break;
230                         }
231                         ecb->spi = spi;
232                         qhangup(c->rq, nil);
233                 }
234                 nullespinit(ecb, "null", nil, 0);
235                 nullahinit(ecb, "null", nil, 0);
236         }
237         Fsconnected(c, e);
238
239         return e;
240 }
241
242
243 static int
244 espstate(Conv *c, char *state, int n)
245 {
246         return snprint(state, n, "%s", c->inuse?"Open\n":"Closed\n");
247 }
248
249 static void
250 espcreate(Conv *c)
251 {
252         c->rq = qopen(64*1024, Qmsg, 0, 0);
253         c->wq = qopen(64*1024, Qkick, espkick, c);
254 }
255
256 static void
257 espclose(Conv *c)
258 {
259         Espcb *ecb;
260
261         qclose(c->rq);
262         qclose(c->wq);
263         qclose(c->eq);
264         ipmove(c->laddr, IPnoaddr);
265         ipmove(c->raddr, IPnoaddr);
266
267         ecb = (Espcb*)c->ptcl;
268         free(ecb->espstate);
269         free(ecb->ahstate);
270         memset(ecb, 0, sizeof(Espcb));
271 }
272
273 static int
274 convipvers(Conv *c)
275 {
276         if((memcmp(c->raddr, v4prefix, IPv4off) == 0 &&
277             memcmp(c->laddr, v4prefix, IPv4off) == 0) ||
278             ipcmp(c->raddr, IPnoaddr) == 0)
279                 return V4;
280         else
281                 return V6;
282 }
283
284 static int
285 pktipvers(Fs *f, Block **bpp)
286 {
287         if (*bpp == nil || BLEN(*bpp) == 0) {
288                 /* get enough to identify the IP version */
289                 *bpp = pullupblock(*bpp, IP4HDR);
290                 if(*bpp == nil) {
291                         netlog(f, Logesp, "esp: short packet\n");
292                         return 0;
293                 }
294         }
295         return (((Esp4hdr*)(*bpp)->rp)->vihl & 0xf0) == IP_VER4? V4: V6;
296 }
297
298 static void
299 getverslens(int version, Versdep *vp)
300 {
301         vp->version = version;
302         switch(vp->version) {
303         case V4:
304                 vp->iphdrlen = IP4HDR;
305                 vp->hdrlen   = Esp4hdrlen;
306                 break;
307         case V6:
308                 vp->iphdrlen = IP6HDR;
309                 vp->hdrlen   = Esp6hdrlen;
310                 break;
311         default:
312                 panic("esp: getverslens version %d wrong", version);
313         }
314 }
315
316 static void
317 getpktspiaddrs(uchar *pkt, Versdep *vp)
318 {
319         Esp4hdr *eh4;
320         Esp6hdr *eh6;
321
322         switch(vp->version) {
323         case V4:
324                 eh4 = (Esp4hdr*)pkt;
325                 v4tov6(vp->raddr, eh4->espsrc);
326                 v4tov6(vp->laddr, eh4->espdst);
327                 vp->spi = nhgetl(eh4->espspi);
328                 break;
329         case V6:
330                 eh6 = (Esp6hdr*)pkt;
331                 ipmove(vp->raddr, eh6->src);
332                 ipmove(vp->laddr, eh6->dst);
333                 vp->spi = nhgetl(eh6->espspi);
334                 break;
335         default:
336                 panic("esp: getpktspiaddrs vp->version %ld wrong", vp->version);
337         }
338 }
339
340 /*
341  * encapsulate next IP packet on x's write queue in IP/ESP packet
342  * and initiate output of the result.
343  */
344 static void
345 espkick(void *x)
346 {
347         int nexthdr, payload, pad, align;
348         uchar *auth;
349         Block *bp;
350         Conv *c = x;
351         Esp4hdr *eh4;
352         Esp6hdr *eh6;
353         Espcb *ecb;
354         Esptail *et;
355         Userhdr *uh;
356         Versdep vers;
357
358         getverslens(convipvers(c), &vers);
359         bp = qget(c->wq);
360         if(bp == nil)
361                 return;
362
363         qlock(c);
364         ecb = c->ptcl;
365
366         if(ecb->header) {
367                 /* make sure the message has a User header */
368                 bp = pullupblock(bp, Userhdrlen);
369                 if(bp == nil) {
370                         qunlock(c);
371                         return;
372                 }
373                 uh = (Userhdr*)bp->rp;
374                 nexthdr = uh->nexthdr;
375                 bp->rp += Userhdrlen;
376         } else {
377                 nexthdr = 0;    /* what should this be? */
378         }
379
380         payload = BLEN(bp) + ecb->espivlen;
381
382         /* Make space to fit ip header */
383         bp = padblock(bp, vers.hdrlen + ecb->espivlen);
384         getpktspiaddrs(bp->rp, &vers);
385
386         align = 4;
387         if(ecb->espblklen > align)
388                 align = ecb->espblklen;
389         if(align % ecb->ahblklen != 0)
390                 panic("espkick: ahblklen is important after all");
391         pad = (align-1) - (payload + Esptaillen-1)%align;
392
393         /*
394          * Make space for tail
395          * this is done by calling padblock with a negative size
396          * Padblock does not change bp->wp!
397          */
398         bp = padblock(bp, -(pad+Esptaillen+ecb->ahlen));
399         bp->wp += pad+Esptaillen+ecb->ahlen;
400
401         et = (Esptail*)(bp->rp + vers.hdrlen + payload + pad);
402
403         /* fill in tail */
404         et->pad = pad;
405         et->nexthdr = nexthdr;
406
407         /* encrypt the payload */
408         ecb->cipher(ecb, bp->rp + vers.hdrlen, payload + pad + Esptaillen);
409         auth = bp->rp + vers.hdrlen + payload + pad + Esptaillen;
410
411         /* fill in head; construct a new IP header and an ESP header */
412         if (vers.version == V4) {
413                 eh4 = (Esp4hdr *)bp->rp;
414                 eh4->vihl = IP_VER4;
415                 v6tov4(eh4->espsrc, c->laddr);
416                 v6tov4(eh4->espdst, c->raddr);
417                 eh4->espproto = IP_ESPPROTO;
418                 eh4->frag[0] = 0;
419                 eh4->frag[1] = 0;
420
421                 hnputl(eh4->espspi, ecb->spi);
422                 hnputl(eh4->espseq, ++ecb->seq);
423         } else {
424                 eh6 = (Esp6hdr *)bp->rp;
425                 eh6->vcf[0] = IP_VER6;
426                 ipmove(eh6->src, c->laddr);
427                 ipmove(eh6->dst, c->raddr);
428                 eh6->proto = IP_ESPPROTO;
429
430                 hnputl(eh6->espspi, ecb->spi);
431                 hnputl(eh6->espseq, ++ecb->seq);
432         }
433
434         /* compute secure hash */
435         ecb->auth(ecb, bp->rp + vers.iphdrlen, (vers.hdrlen - vers.iphdrlen) +
436                 payload + pad + Esptaillen, auth);
437
438         qunlock(c);
439         /* print("esp: pass down: %uld\n", BLEN(bp)); */
440         if (vers.version == V4)
441                 ipoput4(c->p->f, bp, 0, c->ttl, c->tos, c);
442         else
443                 ipoput6(c->p->f, bp, 0, c->ttl, c->tos, c);
444 }
445
446 /*
447  * decapsulate IP packet from IP/ESP packet in bp and
448  * pass the result up the spi's Conv's read queue.
449  */
450 void
451 espiput(Proto *esp, Ipifc*, Block *bp)
452 {
453         int payload, nexthdr;
454         uchar *auth, *espspi;
455         Conv *c;
456         Espcb *ecb;
457         Esptail *et;
458         Fs *f;
459         Userhdr *uh;
460         Versdep vers;
461
462         f = esp->f;
463
464         getverslens(pktipvers(f, &bp), &vers);
465
466         bp = pullupblock(bp, vers.hdrlen + Esptaillen);
467         if(bp == nil) {
468                 netlog(f, Logesp, "esp: short packet\n");
469                 return;
470         }
471         getpktspiaddrs(bp->rp, &vers);
472
473         qlock(esp);
474         /* Look for a conversation structure for this port */
475         c = convlookup(esp, vers.spi);
476         if(c == nil) {
477                 qunlock(esp);
478                 netlog(f, Logesp, "esp: no conv %I -> %I!%lud\n", vers.raddr,
479                         vers.laddr, vers.spi);
480                 icmpnoconv(f, bp);
481                 freeblist(bp);
482                 return;
483         }
484
485         qlock(c);
486         qunlock(esp);
487
488         ecb = c->ptcl;
489         /* too hard to do decryption/authentication on block lists */
490         if(bp->next)
491                 bp = concatblock(bp);
492
493         if(BLEN(bp) < vers.hdrlen + ecb->espivlen + Esptaillen + ecb->ahlen) {
494                 qunlock(c);
495                 netlog(f, Logesp, "esp: short block %I -> %I!%lud\n", vers.raddr,
496                         vers.laddr, vers.spi);
497                 freeb(bp);
498                 return;
499         }
500
501         auth = bp->wp - ecb->ahlen;
502         espspi = vers.version == V4?    ((Esp4hdr*)bp->rp)->espspi:
503                                         ((Esp6hdr*)bp->rp)->espspi;
504
505         /* compute secure hash and authenticate */
506         if(!ecb->auth(ecb, espspi, auth - espspi, auth)) {
507                 qunlock(c);
508 print("esp: bad auth %I -> %I!%ld\n", vers.raddr, vers.laddr, vers.spi);
509                 netlog(f, Logesp, "esp: bad auth %I -> %I!%lud\n", vers.raddr,
510                         vers.laddr, vers.spi);
511                 freeb(bp);
512                 return;
513         }
514
515         payload = BLEN(bp) - vers.hdrlen - ecb->ahlen;
516         if(payload <= 0 || payload % 4 != 0 || payload % ecb->espblklen != 0) {
517                 qunlock(c);
518                 netlog(f, Logesp, "esp: bad length %I -> %I!%lud payload=%d BLEN=%lud\n",
519                         vers.raddr, vers.laddr, vers.spi, payload, BLEN(bp));
520                 freeb(bp);
521                 return;
522         }
523
524         /* decrypt payload */
525         if(!ecb->cipher(ecb, bp->rp + vers.hdrlen, payload)) {
526                 qunlock(c);
527 print("esp: cipher failed %I -> %I!%ld: %s\n", vers.raddr, vers.laddr, vers.spi, up->errstr);
528                 netlog(f, Logesp, "esp: cipher failed %I -> %I!%lud: %s\n",
529                         vers.raddr, vers.laddr, vers.spi, up->errstr);
530                 freeb(bp);
531                 return;
532         }
533
534         payload -= Esptaillen;
535         et = (Esptail*)(bp->rp + vers.hdrlen + payload);
536         payload -= et->pad + ecb->espivlen;
537         nexthdr = et->nexthdr;
538         if(payload <= 0) {
539                 qunlock(c);
540                 netlog(f, Logesp, "esp: short packet after decrypt %I -> %I!%lud\n",
541                         vers.raddr, vers.laddr, vers.spi);
542                 freeb(bp);
543                 return;
544         }
545
546         /* trim packet */
547         bp->rp += vers.hdrlen + ecb->espivlen; /* toss original IP & ESP hdrs */
548         bp->wp = bp->rp + payload;
549         if(ecb->header) {
550                 /* assume Userhdrlen < Esp4hdrlen < Esp6hdrlen */
551                 bp->rp -= Userhdrlen;
552                 uh = (Userhdr*)bp->rp;
553                 memset(uh, 0, Userhdrlen);
554                 uh->nexthdr = nexthdr;
555         }
556
557         /* ingress filtering here? */
558
559         if(qfull(c->rq)){
560                 netlog(f, Logesp, "esp: qfull %I -> %I.%uld\n", vers.raddr,
561                         vers.laddr, vers.spi);
562                 freeblist(bp);
563         }else {
564 //              print("esp: pass up: %uld\n", BLEN(bp));
565                 qpass(c->rq, bp);       /* pass packet up the read queue */
566         }
567
568         qunlock(c);
569 }
570
571 char*
572 espctl(Conv *c, char **f, int n)
573 {
574         Espcb *ecb = c->ptcl;
575         char *e = nil;
576
577         if(strcmp(f[0], "esp") == 0)
578                 e = setalg(ecb, f, n, espalg);
579         else if(strcmp(f[0], "ah") == 0)
580                 e = setalg(ecb, f, n, ahalg);
581         else if(strcmp(f[0], "header") == 0)
582                 ecb->header = 1;
583         else if(strcmp(f[0], "noheader") == 0)
584                 ecb->header = 0;
585         else
586                 e = "unknown control request";
587         return e;
588 }
589
590 /* called from icmp(v6) for unreachable hosts, time exceeded, etc. */
591 void
592 espadvise(Proto *esp, Block *bp, char *msg)
593 {
594         Conv *c;
595         Versdep vers;
596
597         getverslens(pktipvers(esp->f, &bp), &vers);
598         getpktspiaddrs(bp->rp, &vers);
599
600         qlock(esp);
601         c = convlookup(esp, vers.spi);
602         if(c != nil) {
603                 qhangup(c->rq, msg);
604                 qhangup(c->wq, msg);
605         }
606         qunlock(esp);
607         freeblist(bp);
608 }
609
610 int
611 espstats(Proto *esp, char *buf, int len)
612 {
613         Esppriv *upriv;
614
615         upriv = esp->priv;
616         return snprint(buf, len, "%llud %lud\n",
617                 upriv->in,
618                 upriv->inerrors);
619 }
620
621 static int
622 esplocal(Conv *c, char *buf, int len)
623 {
624         Espcb *ecb = c->ptcl;
625         int n;
626
627         qlock(c);
628         if(ecb->incoming)
629                 n = snprint(buf, len, "%I!%uld\n", c->laddr, ecb->spi);
630         else
631                 n = snprint(buf, len, "%I\n", c->laddr);
632         qunlock(c);
633         return n;
634 }
635
636 static int
637 espremote(Conv *c, char *buf, int len)
638 {
639         Espcb *ecb = c->ptcl;
640         int n;
641
642         qlock(c);
643         if(ecb->incoming)
644                 n = snprint(buf, len, "%I\n", c->raddr);
645         else
646                 n = snprint(buf, len, "%I!%uld\n", c->raddr, ecb->spi);
647         qunlock(c);
648         return n;
649 }
650
651 static  Conv*
652 convlookup(Proto *esp, ulong spi)
653 {
654         Conv *c, **p;
655         Espcb *ecb;
656
657         for(p=esp->conv; *p; p++){
658                 c = *p;
659                 ecb = c->ptcl;
660                 if(ecb->incoming && ecb->spi == spi)
661                         return c;
662         }
663         return nil;
664 }
665
666 static char *
667 setalg(Espcb *ecb, char **f, int n, Algorithm *alg)
668 {
669         uchar *key;
670         int c, nbyte, nchar;
671         uint i;
672
673         if(n < 2 || n > 3)
674                 return "bad format";
675         for(; alg->name; alg++)
676                 if(strcmp(f[1], alg->name) == 0)
677                         break;
678         if(alg->name == nil)
679                 return "unknown algorithm";
680
681         nbyte = (alg->keylen + 7) >> 3;
682         if (n == 2)
683                 nchar = 0;
684         else
685                 nchar = strlen(f[2]);
686         if(nchar != 2 * nbyte)                  /* TODO: maybe < is ok */
687                 return "key not required length";
688         /* convert hex digits from ascii, in place */
689         for(i=0; i<nchar; i++) {
690                 c = f[2][i];
691                 if(c >= '0' && c <= '9')
692                         f[2][i] -= '0';
693                 else if(c >= 'a' && c <= 'f')
694                         f[2][i] -= 'a'-10;
695                 else if(c >= 'A' && c <= 'F')
696                         f[2][i] -= 'A'-10;
697                 else
698                         return "non-hex character in key";
699         }
700         /* collapse hex digits into complete bytes in reverse order in key */
701         key = smalloc(nbyte);
702         for(i = 0; i < nchar && i/2 < nbyte; i++) {
703                 c = f[2][nchar-i-1];
704                 if(i&1)
705                         c <<= 4;
706                 key[i/2] |= c;
707         }
708
709         alg->init(ecb, alg->name, key, alg->keylen);
710         free(key);
711         return nil;
712 }
713
714
715 /*
716  * null encryption
717  */
718
719 static int
720 nullcipher(Espcb*, uchar*, int)
721 {
722         return 1;
723 }
724
725 static void
726 nullespinit(Espcb *ecb, char *name, uchar*, unsigned)
727 {
728         ecb->espalg = name;
729         ecb->espblklen = 1;
730         ecb->espivlen = 0;
731         ecb->cipher = nullcipher;
732 }
733
734 static int
735 nullauth(Espcb*, uchar*, int, uchar*)
736 {
737         return 1;
738 }
739
740 static void
741 nullahinit(Espcb *ecb, char *name, uchar*, unsigned)
742 {
743         ecb->ahalg = name;
744         ecb->ahblklen = 1;
745         ecb->ahlen = 0;
746         ecb->auth = nullauth;
747 }
748
749
750 /*
751  * sha1
752  */
753
754 static void
755 seanq_hmac_sha1(uchar hash[SHA1dlen], uchar *t, long tlen, uchar *key, long klen)
756 {
757         int i;
758         uchar ipad[Hmacblksz+1], opad[Hmacblksz+1], innerhash[SHA1dlen];
759         DigestState *digest;
760
761         memset(ipad, 0x36, Hmacblksz);
762         memset(opad, 0x5c, Hmacblksz);
763         ipad[Hmacblksz] = opad[Hmacblksz] = 0;
764         for(i = 0; i < klen; i++){
765                 ipad[i] ^= key[i];
766                 opad[i] ^= key[i];
767         }
768         digest = sha1(ipad, Hmacblksz, nil, nil);
769         sha1(t, tlen, innerhash, digest);
770         digest = sha1(opad, Hmacblksz, nil, nil);
771         sha1(innerhash, SHA1dlen, hash, digest);
772 }
773
774 static int
775 shaauth(Espcb *ecb, uchar *t, int tlen, uchar *auth)
776 {
777         int r;
778         uchar hash[SHA1dlen];
779
780         memset(hash, 0, SHA1dlen);
781         seanq_hmac_sha1(hash, t, tlen, (uchar*)ecb->ahstate, BITS2BYTES(128));
782         r = memcmp(auth, hash, ecb->ahlen) == 0;
783         memmove(auth, hash, ecb->ahlen);
784         return r;
785 }
786
787 static void
788 shaahinit(Espcb *ecb, char *name, uchar *key, unsigned klen)
789 {
790         if(klen != 128)
791                 panic("shaahinit: bad keylen");
792         klen /= BI2BY;
793
794         ecb->ahalg = name;
795         ecb->ahblklen = 1;
796         ecb->ahlen = BITS2BYTES(96);
797         ecb->auth = shaauth;
798         ecb->ahstate = smalloc(klen);
799         memmove(ecb->ahstate, key, klen);
800 }
801
802
803 /*
804  * aes
805  */
806
807 /* ah_aes_xcbc_mac_96, rfc3566 */
808 static int
809 aesahauth(Espcb *ecb, uchar *t, int tlen, uchar *auth)
810 {
811         int r;
812         uchar hash[AESdlen];
813
814         memset(hash, 0, AESdlen);
815         ecb->ds = hmac_aes(t, tlen, (uchar*)ecb->ahstate, BITS2BYTES(96), hash,
816                 ecb->ds);
817         r = memcmp(auth, hash, ecb->ahlen) == 0;
818         memmove(auth, hash, ecb->ahlen);
819         return r;
820 }
821
822 static void
823 aesahinit(Espcb *ecb, char *name, uchar *key, unsigned klen)
824 {
825         if(klen != 128)
826                 panic("aesahinit: keylen not 128");
827         klen /= BI2BY;
828
829         ecb->ahalg = name;
830         ecb->ahblklen = 1;
831         ecb->ahlen = BITS2BYTES(96);
832         ecb->auth = aesahauth;
833         ecb->ahstate = smalloc(klen);
834         memmove(ecb->ahstate, key, klen);
835 }
836
837 static int
838 aescbccipher(Espcb *ecb, uchar *p, int n)       /* 128-bit blocks */
839 {
840         uchar tmp[AESbsize], q[AESbsize];
841         uchar *pp, *tp, *ip, *eip, *ep;
842         AESstate *ds = ecb->espstate;
843
844         ep = p + n;
845         if(ecb->incoming) {
846                 memmove(ds->ivec, p, AESbsize);
847                 p += AESbsize;
848                 while(p < ep){
849                         memmove(tmp, p, AESbsize);
850                         aes_decrypt(ds->dkey, ds->rounds, p, q);
851                         memmove(p, q, AESbsize);
852                         tp = tmp;
853                         ip = ds->ivec;
854                         for(eip = ip + AESbsize; ip < eip; ){
855                                 *p++ ^= *ip;
856                                 *ip++ = *tp++;
857                         }
858                 }
859         } else {
860                 memmove(p, ds->ivec, AESbsize);
861                 for(p += AESbsize; p < ep; p += AESbsize){
862                         pp = p;
863                         ip = ds->ivec;
864                         for(eip = ip + AESbsize; ip < eip; )
865                                 *pp++ ^= *ip++;
866                         aes_encrypt(ds->ekey, ds->rounds, p, q);
867                         memmove(ds->ivec, q, AESbsize);
868                         memmove(p, q, AESbsize);
869                 }
870         }
871         return 1;
872 }
873
874 static void
875 aescbcespinit(Espcb *ecb, char *name, uchar *k, unsigned n)
876 {
877         uchar key[Aeskeysz], ivec[Aeskeysz];
878         int i;
879
880         n = BITS2BYTES(n);
881         if(n > Aeskeysz)
882                 n = Aeskeysz;
883         memset(key, 0, sizeof(key));
884         memmove(key, k, n);
885         for(i = 0; i < Aeskeysz; i++)
886                 ivec[i] = nrand(256);
887         ecb->espalg = name;
888         ecb->espblklen = Aesblk;
889         ecb->espivlen = Aesblk;
890         ecb->cipher = aescbccipher;
891         ecb->espstate = smalloc(sizeof(AESstate));
892         setupAESstate(ecb->espstate, key, n /* keybytes */, ivec);
893 }
894
895 static int
896 aesctrcipher(Espcb *ecb, uchar *p, int n)       /* 128-bit blocks */
897 {
898         uchar tmp[AESbsize], q[AESbsize];
899         uchar *pp, *tp, *ip, *eip, *ep;
900         AESstate *ds = ecb->espstate;
901
902         ep = p + n;
903         if(ecb->incoming) {
904                 memmove(ds->ivec, p, AESbsize);
905                 p += AESbsize;
906                 while(p < ep){
907                         memmove(tmp, p, AESbsize);
908                         aes_decrypt(ds->dkey, ds->rounds, p, q);
909                         memmove(p, q, AESbsize);
910                         tp = tmp;
911                         ip = ds->ivec;
912                         for(eip = ip + AESbsize; ip < eip; ){
913                                 *p++ ^= *ip;
914                                 *ip++ = *tp++;
915                         }
916                 }
917         } else {
918                 memmove(p, ds->ivec, AESbsize);
919                 for(p += AESbsize; p < ep; p += AESbsize){
920                         pp = p;
921                         ip = ds->ivec;
922                         for(eip = ip + AESbsize; ip < eip; )
923                                 *pp++ ^= *ip++;
924                         aes_encrypt(ds->ekey, ds->rounds, p, q);
925                         memmove(ds->ivec, q, AESbsize);
926                         memmove(p, q, AESbsize);
927                 }
928         }
929         return 1;
930 }
931
932 static void
933 aesctrespinit(Espcb *ecb, char *name, uchar *k, unsigned n)
934 {
935         uchar key[Aesblk], ivec[Aesblk];
936         int i;
937
938         n = BITS2BYTES(n);
939         if(n > Aeskeysz)
940                 n = Aeskeysz;
941         memset(key, 0, sizeof(key));
942         memmove(key, k, n);
943         for(i = 0; i < Aesblk; i++)
944                 ivec[i] = nrand(256);
945         ecb->espalg = name;
946         ecb->espblklen = Aesblk;
947         ecb->espivlen = Aesblk;
948         ecb->cipher = aesctrcipher;
949         ecb->espstate = smalloc(sizeof(AESstate));
950         setupAESstate(ecb->espstate, key, n /* keybytes */, ivec);
951 }
952
953
954 /*
955  * md5
956  */
957
958 static void
959 seanq_hmac_md5(uchar hash[MD5dlen], uchar *t, long tlen, uchar *key, long klen)
960 {
961         int i;
962         uchar ipad[Hmacblksz+1], opad[Hmacblksz+1], innerhash[MD5dlen];
963         DigestState *digest;
964
965         memset(ipad, 0x36, Hmacblksz);
966         memset(opad, 0x5c, Hmacblksz);
967         ipad[Hmacblksz] = opad[Hmacblksz] = 0;
968         for(i = 0; i < klen; i++){
969                 ipad[i] ^= key[i];
970                 opad[i] ^= key[i];
971         }
972         digest = md5(ipad, Hmacblksz, nil, nil);
973         md5(t, tlen, innerhash, digest);
974         digest = md5(opad, Hmacblksz, nil, nil);
975         md5(innerhash, MD5dlen, hash, digest);
976 }
977
978 static int
979 md5auth(Espcb *ecb, uchar *t, int tlen, uchar *auth)
980 {
981         uchar hash[MD5dlen];
982         int r;
983
984         memset(hash, 0, MD5dlen);
985         seanq_hmac_md5(hash, t, tlen, (uchar*)ecb->ahstate, BITS2BYTES(128));
986         r = memcmp(auth, hash, ecb->ahlen) == 0;
987         memmove(auth, hash, ecb->ahlen);
988         return r;
989 }
990
991 static void
992 md5ahinit(Espcb *ecb, char *name, uchar *key, unsigned klen)
993 {
994         if(klen != 128)
995                 panic("md5ahinit: bad keylen");
996         klen = BITS2BYTES(klen);
997         ecb->ahalg = name;
998         ecb->ahblklen = 1;
999         ecb->ahlen = BITS2BYTES(96);
1000         ecb->auth = md5auth;
1001         ecb->ahstate = smalloc(klen);
1002         memmove(ecb->ahstate, key, klen);
1003 }
1004
1005
1006 /*
1007  * des, single and triple
1008  */
1009
1010 static int
1011 descipher(Espcb *ecb, uchar *p, int n)
1012 {
1013         DESstate *ds = ecb->espstate;
1014
1015         if(ecb->incoming) {
1016                 memmove(ds->ivec, p, Desblk);
1017                 desCBCdecrypt(p + Desblk, n - Desblk, ds);
1018         } else {
1019                 memmove(p, ds->ivec, Desblk);
1020                 desCBCencrypt(p + Desblk, n - Desblk, ds);
1021         }
1022         return 1;
1023 }
1024
1025 static int
1026 des3cipher(Espcb *ecb, uchar *p, int n)
1027 {
1028         DES3state *ds = ecb->espstate;
1029
1030         if(ecb->incoming) {
1031                 memmove(ds->ivec, p, Desblk);
1032                 des3CBCdecrypt(p + Desblk, n - Desblk, ds);
1033         } else {
1034                 memmove(p, ds->ivec, Desblk);
1035                 des3CBCencrypt(p + Desblk, n - Desblk, ds);
1036         }
1037         return 1;
1038 }
1039
1040 static void
1041 desespinit(Espcb *ecb, char *name, uchar *k, unsigned n)
1042 {
1043         uchar key[Desblk], ivec[Desblk];
1044         int i;
1045
1046         n = BITS2BYTES(n);
1047         if(n > Desblk)
1048                 n = Desblk;
1049         memset(key, 0, sizeof(key));
1050         memmove(key, k, n);
1051         for(i = 0; i < Desblk; i++)
1052                 ivec[i] = nrand(256);
1053         ecb->espalg = name;
1054         ecb->espblklen = Desblk;
1055         ecb->espivlen = Desblk;
1056
1057         ecb->cipher = descipher;
1058         ecb->espstate = smalloc(sizeof(DESstate));
1059         setupDESstate(ecb->espstate, key, ivec);
1060 }
1061
1062 static void
1063 des3espinit(Espcb *ecb, char *name, uchar *k, unsigned n)
1064 {
1065         uchar key[3][Desblk], ivec[Desblk];
1066         int i;
1067
1068         n = BITS2BYTES(n);
1069         if(n > Des3keysz)
1070                 n = Des3keysz;
1071         memset(key, 0, sizeof(key));
1072         memmove(key, k, n);
1073         for(i = 0; i < Desblk; i++)
1074                 ivec[i] = nrand(256);
1075         ecb->espalg = name;
1076         ecb->espblklen = Desblk;
1077         ecb->espivlen = Desblk;
1078
1079         ecb->cipher = des3cipher;
1080         ecb->espstate = smalloc(sizeof(DES3state));
1081         setupDES3state(ecb->espstate, key, ivec);
1082 }
1083
1084
1085 /*
1086  * interfacing to devip
1087  */
1088 void
1089 espinit(Fs *fs)
1090 {
1091         Proto *esp;
1092
1093         esp = smalloc(sizeof(Proto));
1094         esp->priv = smalloc(sizeof(Esppriv));
1095         esp->name = "esp";
1096         esp->connect = espconnect;
1097         esp->announce = nil;
1098         esp->ctl = espctl;
1099         esp->state = espstate;
1100         esp->create = espcreate;
1101         esp->close = espclose;
1102         esp->rcv = espiput;
1103         esp->advise = espadvise;
1104         esp->stats = espstats;
1105         esp->local = esplocal;
1106         esp->remote = espremote;
1107         esp->ipproto = IP_ESPPROTO;
1108         esp->nc = Nchans;
1109         esp->ptclsize = sizeof(Espcb);
1110
1111         Fsproto(fs, esp);
1112 }