]> git.lizzy.rs Git - plan9front.git/blob - sys/src/9/ip/esp.c
set router R-flag when sendra is active for neighbor advertisement
[plan9front.git] / sys / src / 9 / ip / esp.c
1 /*
2  * Encapsulating Security Payload for IPsec for IPv4, rfc1827.
3  * extended to IPv6.
4  * rfc2104 defines hmac computation.
5  *      currently only implements tunnel mode.
6  * TODO: verify aes algorithms;
7  *      transport mode (host-to-host)
8  */
9 #include        "u.h"
10 #include        "../port/lib.h"
11 #include        "mem.h"
12 #include        "dat.h"
13 #include        "fns.h"
14 #include        "../port/error.h"
15
16 #include        "ip.h"
17 #include        "ipv6.h"
18 #include        <libsec.h>
19
20 #define BITS2BYTES(bi) (((bi) + BI2BY - 1) / BI2BY)
21 #define BYTES2BITS(by)  ((by) * BI2BY)
22
23 typedef struct Algorithm Algorithm;
24 typedef struct Esp4hdr Esp4hdr;
25 typedef struct Esp6hdr Esp6hdr;
26 typedef struct Espcb Espcb;
27 typedef struct Esphdr Esphdr;
28 typedef struct Esppriv Esppriv;
29 typedef struct Esptail Esptail;
30 typedef struct Userhdr Userhdr;
31
32 enum {
33         Encrypt,
34         Decrypt,
35
36         IP_ESPPROTO     = 50,   /* IP v4 and v6 protocol number */
37         Esp4hdrlen      = IP4HDR + 8,
38         Esp6hdrlen      = IP6HDR + 8,
39
40         Esptaillen      = 2,    /* does not include pad or auth data */
41         Userhdrlen      = 4,    /* user-visible header size - if enabled */
42
43         Desblk   = BITS2BYTES(64),
44         Des3keysz = BITS2BYTES(192),
45
46         Aesblk   = BITS2BYTES(128),
47         Aeskeysz = BITS2BYTES(128),
48 };
49
50 struct Esphdr
51 {
52         uchar   espspi[4];      /* Security parameter index */
53         uchar   espseq[4];      /* Sequence number */
54         uchar   payload[];
55 };
56
57 /*
58  * tunnel-mode (network-to-network, etc.) layout is:
59  * new IP hdrs | ESP hdr |
60  *       enc { orig IP hdrs | TCP/UDP hdr | user data | ESP trailer } | ESP ICV
61  *
62  * transport-mode (host-to-host) layout would be:
63  *      orig IP hdrs | ESP hdr |
64  *                      enc { TCP/UDP hdr | user data | ESP trailer } | ESP ICV
65  */
66 struct Esp4hdr
67 {
68         /* ipv4 header */
69         uchar   vihl;           /* Version and header length */
70         uchar   tos;            /* Type of service */
71         uchar   length[2];      /* packet length */
72         uchar   id[2];          /* Identification */
73         uchar   frag[2];        /* Fragment information */
74         uchar   Unused;
75         uchar   espproto;       /* Protocol */
76         uchar   espplen[2];     /* Header plus data length */
77         uchar   espsrc[4];      /* Ip source */
78         uchar   espdst[4];      /* Ip destination */
79
80         Esphdr;
81 };
82
83 /* tunnel-mode layout */
84 struct Esp6hdr
85 {
86         IPV6HDR;
87         Esphdr;
88 };
89
90 struct Esptail
91 {
92         uchar   pad;
93         uchar   nexthdr;
94 };
95
96 /* IP-version-dependent data */
97 typedef struct Versdep Versdep;
98 struct Versdep
99 {
100         ulong   version;
101         ulong   iphdrlen;
102         ulong   hdrlen;         /* iphdrlen + esp hdr len */
103         ulong   spi;
104         uchar   laddr[IPaddrlen];
105         uchar   raddr[IPaddrlen];
106 };
107
108 /* header as seen by the user */
109 struct Userhdr
110 {
111         uchar   nexthdr;        /* next protocol */
112         uchar   unused[3];
113 };
114
115 struct Esppriv
116 {
117         uvlong  in;
118         ulong   inerrors;
119 };
120
121 /*
122  *  protocol specific part of Conv
123  */
124 struct Espcb
125 {
126         int     incoming;
127         int     header;         /* user-level header */
128         ulong   spi;
129         ulong   seq;            /* last seq sent */
130         ulong   window;         /* for replay attacks */
131
132         char    *espalg;
133         void    *espstate;      /* other state for esp */
134         int     espivlen;       /* in bytes */
135         int     espblklen;
136         int     (*cipher)(Espcb*, uchar *buf, int len);
137
138         char    *ahalg;
139         void    *ahstate;       /* other state for esp */
140         int     ahlen;          /* auth data length in bytes */
141         int     ahblklen;
142         int     (*auth)(Espcb*, uchar *buf, int len, uchar *hash);
143         DigestState *ds;
144 };
145
146 struct Algorithm
147 {
148         char    *name;
149         int     keylen;         /* in bits */
150         void    (*init)(Espcb*, char* name, uchar *key, unsigned keylen);
151 };
152
153 static  Conv* convlookup(Proto *esp, ulong spi);
154 static  char *setalg(Espcb *ecb, char **f, int n, Algorithm *alg);
155 static  void espkick(void *x);
156
157 static  void nullespinit(Espcb*, char*, uchar *key, unsigned keylen);
158 static  void des3espinit(Espcb*, char*, uchar *key, unsigned keylen);
159 static  void aescbcespinit(Espcb*, char*, uchar *key, unsigned keylen);
160 static  void aesctrespinit(Espcb*, char*, uchar *key, unsigned keylen);
161 static  void desespinit(Espcb *ecb, char *name, uchar *k, unsigned n);
162
163 static  void nullahinit(Espcb*, char*, uchar *key, unsigned keylen);
164 static  void shaahinit(Espcb*, char*, uchar *key, unsigned keylen);
165 static  void md5ahinit(Espcb*, char*, uchar *key, unsigned keylen);
166
167 static Algorithm espalg[] =
168 {
169         "null",         0,      nullespinit,
170         "des3_cbc",     192,    des3espinit,    /* new rfc2451, des-ede3 */
171         "aes_128_cbc",  128,    aescbcespinit,  /* new rfc3602 */
172         "aes_ctr",      128,    aesctrespinit,  /* new rfc3686 */
173         "des_56_cbc",   64,     desespinit,     /* rfc2405, deprecated */
174         nil,            0,      nil,
175 };
176
177 static Algorithm ahalg[] =
178 {
179         "null",         0,      nullahinit,
180         "hmac_sha1_96", 128,    shaahinit,      /* rfc2404 */
181         "hmac_md5_96",  128,    md5ahinit,      /* rfc2403 */
182         nil,            0,      nil,
183 };
184
185 static char*
186 espconnect(Conv *c, char **argv, int argc)
187 {
188         char *p, *pp, *e = nil;
189         ulong spi;
190         Espcb *ecb = (Espcb*)c->ptcl;
191
192         switch(argc) {
193         default:
194                 e = "bad args to connect";
195                 break;
196         case 2:
197                 p = strchr(argv[1], '!');
198                 if(p == nil){
199                         e = "malformed address";
200                         break;
201                 }
202                 *p++ = 0;
203                 if (parseip(c->raddr, argv[1]) == -1) {
204                         e = Ebadip;
205                         break;
206                 }
207                 findlocalip(c->p->f, c->laddr, c->raddr);
208                 ecb->incoming = 0;
209                 ecb->seq = 0;
210                 if(strcmp(p, "*") == 0) {
211                         qlock(c->p);
212                         for(;;) {
213                                 spi = nrand(1<<16) + 256;
214                                 if(convlookup(c->p, spi) == nil)
215                                         break;
216                         }
217                         qunlock(c->p);
218                         ecb->spi = spi;
219                         ecb->incoming = 1;
220                         qhangup(c->wq, nil);
221                 } else {
222                         spi = strtoul(p, &pp, 10);
223                         if(pp == p) {
224                                 e = "malformed address";
225                                 break;
226                         }
227                         ecb->spi = spi;
228                         qhangup(c->rq, nil);
229                 }
230                 nullespinit(ecb, "null", nil, 0);
231                 nullahinit(ecb, "null", nil, 0);
232         }
233         Fsconnected(c, e);
234
235         return e;
236 }
237
238
239 static int
240 espstate(Conv *c, char *state, int n)
241 {
242         return snprint(state, n, "%s", c->inuse?"Open\n":"Closed\n");
243 }
244
245 static void
246 espcreate(Conv *c)
247 {
248         c->rq = qopen(64*1024, Qmsg, 0, 0);
249         c->wq = qopen(64*1024, Qkick, espkick, c);
250 }
251
252 static void
253 espclose(Conv *c)
254 {
255         Espcb *ecb;
256
257         qclose(c->rq);
258         qclose(c->wq);
259         qclose(c->eq);
260         ipmove(c->laddr, IPnoaddr);
261         ipmove(c->raddr, IPnoaddr);
262
263         ecb = (Espcb*)c->ptcl;
264         secfree(ecb->espstate);
265         secfree(ecb->ahstate);
266         memset(ecb, 0, sizeof(Espcb));
267 }
268
269 static int
270 convipvers(Conv *c)
271 {
272         if((memcmp(c->raddr, v4prefix, IPv4off) == 0 &&
273             memcmp(c->laddr, v4prefix, IPv4off) == 0) ||
274             ipcmp(c->raddr, IPnoaddr) == 0)
275                 return V4;
276         else
277                 return V6;
278 }
279
280 static int
281 pktipvers(Fs *f, Block **bpp)
282 {
283         if (*bpp == nil || BLEN(*bpp) == 0) {
284                 /* get enough to identify the IP version */
285                 *bpp = pullupblock(*bpp, IP4HDR);
286                 if(*bpp == nil) {
287                         netlog(f, Logesp, "esp: short packet\n");
288                         return 0;
289                 }
290         }
291         return (((Esp4hdr*)(*bpp)->rp)->vihl & 0xf0) == IP_VER4? V4: V6;
292 }
293
294 static void
295 getverslens(int version, Versdep *vp)
296 {
297         vp->version = version;
298         switch(vp->version) {
299         case V4:
300                 vp->iphdrlen = IP4HDR;
301                 vp->hdrlen   = Esp4hdrlen;
302                 break;
303         case V6:
304                 vp->iphdrlen = IP6HDR;
305                 vp->hdrlen   = Esp6hdrlen;
306                 break;
307         default:
308                 panic("esp: getverslens version %d wrong", version);
309         }
310 }
311
312 static void
313 getpktspiaddrs(uchar *pkt, Versdep *vp)
314 {
315         Esp4hdr *eh4;
316         Esp6hdr *eh6;
317
318         switch(vp->version) {
319         case V4:
320                 eh4 = (Esp4hdr*)pkt;
321                 v4tov6(vp->raddr, eh4->espsrc);
322                 v4tov6(vp->laddr, eh4->espdst);
323                 vp->spi = nhgetl(eh4->espspi);
324                 break;
325         case V6:
326                 eh6 = (Esp6hdr*)pkt;
327                 ipmove(vp->raddr, eh6->src);
328                 ipmove(vp->laddr, eh6->dst);
329                 vp->spi = nhgetl(eh6->espspi);
330                 break;
331         default:
332                 panic("esp: getpktspiaddrs vp->version %ld wrong", vp->version);
333         }
334 }
335
336 /*
337  * encapsulate next IP packet on x's write queue in IP/ESP packet
338  * and initiate output of the result.
339  */
340 static void
341 espkick(void *x)
342 {
343         int nexthdr, payload, pad, align;
344         uchar *auth;
345         Block *bp;
346         Conv *c = x;
347         Esp4hdr *eh4;
348         Esp6hdr *eh6;
349         Espcb *ecb;
350         Esptail *et;
351         Userhdr *uh;
352         Versdep vers;
353
354         getverslens(convipvers(c), &vers);
355         bp = qget(c->wq);
356         if(bp == nil)
357                 return;
358
359         qlock(c);
360         ecb = c->ptcl;
361
362         if(ecb->header) {
363                 /* make sure the message has a User header */
364                 bp = pullupblock(bp, Userhdrlen);
365                 if(bp == nil) {
366                         qunlock(c);
367                         return;
368                 }
369                 uh = (Userhdr*)bp->rp;
370                 nexthdr = uh->nexthdr;
371                 bp->rp += Userhdrlen;
372         } else {
373                 nexthdr = 0;    /* what should this be? */
374         }
375
376         payload = BLEN(bp) + ecb->espivlen;
377
378         /* Make space to fit ip header */
379         bp = padblock(bp, vers.hdrlen + ecb->espivlen);
380         getpktspiaddrs(bp->rp, &vers);
381
382         align = 4;
383         if(ecb->espblklen > align)
384                 align = ecb->espblklen;
385         if(align % ecb->ahblklen != 0)
386                 panic("espkick: ahblklen is important after all");
387         pad = (align-1) - (payload + Esptaillen-1)%align;
388
389         /*
390          * Make space for tail
391          * this is done by calling padblock with a negative size
392          * Padblock does not change bp->wp!
393          */
394         bp = padblock(bp, -(pad+Esptaillen+ecb->ahlen));
395         bp->wp += pad+Esptaillen+ecb->ahlen;
396
397         et = (Esptail*)(bp->rp + vers.hdrlen + payload + pad);
398
399         /* fill in tail */
400         et->pad = pad;
401         et->nexthdr = nexthdr;
402
403         /* encrypt the payload */
404         ecb->cipher(ecb, bp->rp + vers.hdrlen, payload + pad + Esptaillen);
405         auth = bp->rp + vers.hdrlen + payload + pad + Esptaillen;
406
407         /* fill in head; construct a new IP header and an ESP header */
408         if (vers.version == V4) {
409                 eh4 = (Esp4hdr *)bp->rp;
410                 eh4->vihl = IP_VER4;
411                 v6tov4(eh4->espsrc, c->laddr);
412                 v6tov4(eh4->espdst, c->raddr);
413                 eh4->espproto = IP_ESPPROTO;
414                 eh4->frag[0] = 0;
415                 eh4->frag[1] = 0;
416
417                 hnputl(eh4->espspi, ecb->spi);
418                 hnputl(eh4->espseq, ++ecb->seq);
419         } else {
420                 eh6 = (Esp6hdr *)bp->rp;
421                 eh6->vcf[0] = IP_VER6;
422                 ipmove(eh6->src, c->laddr);
423                 ipmove(eh6->dst, c->raddr);
424                 eh6->proto = IP_ESPPROTO;
425
426                 hnputl(eh6->espspi, ecb->spi);
427                 hnputl(eh6->espseq, ++ecb->seq);
428         }
429
430         /* compute secure hash */
431         ecb->auth(ecb, bp->rp + vers.iphdrlen, (vers.hdrlen - vers.iphdrlen) +
432                 payload + pad + Esptaillen, auth);
433
434         qunlock(c);
435         /* print("esp: pass down: %uld\n", BLEN(bp)); */
436         if (vers.version == V4)
437                 ipoput4(c->p->f, bp, 0, c->ttl, c->tos, c);
438         else
439                 ipoput6(c->p->f, bp, 0, c->ttl, c->tos, c);
440 }
441
442 /*
443  * decapsulate IP packet from IP/ESP packet in bp and
444  * pass the result up the spi's Conv's read queue.
445  */
446 void
447 espiput(Proto *esp, Ipifc*, Block *bp)
448 {
449         int payload, nexthdr;
450         uchar *auth, *espspi;
451         Conv *c;
452         Espcb *ecb;
453         Esptail *et;
454         Fs *f;
455         Userhdr *uh;
456         Versdep vers;
457
458         f = esp->f;
459
460         getverslens(pktipvers(f, &bp), &vers);
461
462         bp = pullupblock(bp, vers.hdrlen + Esptaillen);
463         if(bp == nil) {
464                 netlog(f, Logesp, "esp: short packet\n");
465                 return;
466         }
467         getpktspiaddrs(bp->rp, &vers);
468
469         qlock(esp);
470         /* Look for a conversation structure for this port */
471         c = convlookup(esp, vers.spi);
472         if(c == nil) {
473                 qunlock(esp);
474                 netlog(f, Logesp, "esp: no conv %I -> %I!%lud\n", vers.raddr,
475                         vers.laddr, vers.spi);
476                 icmpnoconv(f, bp);
477                 freeblist(bp);
478                 return;
479         }
480
481         qlock(c);
482         qunlock(esp);
483
484         ecb = c->ptcl;
485         /* too hard to do decryption/authentication on block lists */
486         if(bp->next)
487                 bp = concatblock(bp);
488
489         if(BLEN(bp) < vers.hdrlen + ecb->espivlen + Esptaillen + ecb->ahlen) {
490                 qunlock(c);
491                 netlog(f, Logesp, "esp: short block %I -> %I!%lud\n", vers.raddr,
492                         vers.laddr, vers.spi);
493                 freeb(bp);
494                 return;
495         }
496
497         auth = bp->wp - ecb->ahlen;
498         espspi = vers.version == V4?    ((Esp4hdr*)bp->rp)->espspi:
499                                         ((Esp6hdr*)bp->rp)->espspi;
500
501         /* compute secure hash and authenticate */
502         if(!ecb->auth(ecb, espspi, auth - espspi, auth)) {
503                 qunlock(c);
504 print("esp: bad auth %I -> %I!%ld\n", vers.raddr, vers.laddr, vers.spi);
505                 netlog(f, Logesp, "esp: bad auth %I -> %I!%lud\n", vers.raddr,
506                         vers.laddr, vers.spi);
507                 freeb(bp);
508                 return;
509         }
510
511         payload = BLEN(bp) - vers.hdrlen - ecb->ahlen;
512         if(payload <= 0 || payload % 4 != 0 || payload % ecb->espblklen != 0) {
513                 qunlock(c);
514                 netlog(f, Logesp, "esp: bad length %I -> %I!%lud payload=%d BLEN=%zd\n",
515                         vers.raddr, vers.laddr, vers.spi, payload, BLEN(bp));
516                 freeb(bp);
517                 return;
518         }
519
520         /* decrypt payload */
521         if(!ecb->cipher(ecb, bp->rp + vers.hdrlen, payload)) {
522                 qunlock(c);
523 print("esp: cipher failed %I -> %I!%ld: %s\n", vers.raddr, vers.laddr, vers.spi, up->errstr);
524                 netlog(f, Logesp, "esp: cipher failed %I -> %I!%lud: %s\n",
525                         vers.raddr, vers.laddr, vers.spi, up->errstr);
526                 freeb(bp);
527                 return;
528         }
529
530         payload -= Esptaillen;
531         et = (Esptail*)(bp->rp + vers.hdrlen + payload);
532         payload -= et->pad + ecb->espivlen;
533         nexthdr = et->nexthdr;
534         if(payload <= 0) {
535                 qunlock(c);
536                 netlog(f, Logesp, "esp: short packet after decrypt %I -> %I!%lud\n",
537                         vers.raddr, vers.laddr, vers.spi);
538                 freeb(bp);
539                 return;
540         }
541
542         /* trim packet */
543         bp->rp += vers.hdrlen + ecb->espivlen; /* toss original IP & ESP hdrs */
544         bp->wp = bp->rp + payload;
545         if(ecb->header) {
546                 /* assume Userhdrlen < Esp4hdrlen < Esp6hdrlen */
547                 bp->rp -= Userhdrlen;
548                 uh = (Userhdr*)bp->rp;
549                 memset(uh, 0, Userhdrlen);
550                 uh->nexthdr = nexthdr;
551         }
552
553         /* ingress filtering here? */
554
555         if(qfull(c->rq)){
556                 netlog(f, Logesp, "esp: qfull %I -> %I.%uld\n", vers.raddr,
557                         vers.laddr, vers.spi);
558                 freeblist(bp);
559         }else {
560 //              print("esp: pass up: %uld\n", BLEN(bp));
561                 qpass(c->rq, bp);       /* pass packet up the read queue */
562         }
563
564         qunlock(c);
565 }
566
567 char*
568 espctl(Conv *c, char **f, int n)
569 {
570         Espcb *ecb = c->ptcl;
571         char *e = nil;
572
573         if(strcmp(f[0], "esp") == 0)
574                 e = setalg(ecb, f, n, espalg);
575         else if(strcmp(f[0], "ah") == 0)
576                 e = setalg(ecb, f, n, ahalg);
577         else if(strcmp(f[0], "header") == 0)
578                 ecb->header = 1;
579         else if(strcmp(f[0], "noheader") == 0)
580                 ecb->header = 0;
581         else
582                 e = "unknown control request";
583         return e;
584 }
585
586 /* called from icmp(v6) for unreachable hosts, time exceeded, etc. */
587 void
588 espadvise(Proto *esp, Block *bp, char *msg)
589 {
590         Conv *c;
591         Versdep vers;
592
593         getverslens(pktipvers(esp->f, &bp), &vers);
594         getpktspiaddrs(bp->rp, &vers);
595
596         qlock(esp);
597         c = convlookup(esp, vers.spi);
598         if(c != nil && !c->ignoreadvice) {
599                 qhangup(c->rq, msg);
600                 qhangup(c->wq, msg);
601         }
602         qunlock(esp);
603         freeblist(bp);
604 }
605
606 int
607 espstats(Proto *esp, char *buf, int len)
608 {
609         Esppriv *upriv;
610
611         upriv = esp->priv;
612         return snprint(buf, len, "%llud %lud\n",
613                 upriv->in,
614                 upriv->inerrors);
615 }
616
617 static int
618 esplocal(Conv *c, char *buf, int len)
619 {
620         Espcb *ecb = c->ptcl;
621         int n;
622
623         qlock(c);
624         if(ecb->incoming)
625                 n = snprint(buf, len, "%I!%uld\n", c->laddr, ecb->spi);
626         else
627                 n = snprint(buf, len, "%I\n", c->laddr);
628         qunlock(c);
629         return n;
630 }
631
632 static int
633 espremote(Conv *c, char *buf, int len)
634 {
635         Espcb *ecb = c->ptcl;
636         int n;
637
638         qlock(c);
639         if(ecb->incoming)
640                 n = snprint(buf, len, "%I\n", c->raddr);
641         else
642                 n = snprint(buf, len, "%I!%uld\n", c->raddr, ecb->spi);
643         qunlock(c);
644         return n;
645 }
646
647 static  Conv*
648 convlookup(Proto *esp, ulong spi)
649 {
650         Conv *c, **p;
651         Espcb *ecb;
652
653         for(p=esp->conv; *p; p++){
654                 c = *p;
655                 ecb = c->ptcl;
656                 if(ecb->incoming && ecb->spi == spi)
657                         return c;
658         }
659         return nil;
660 }
661
662 static char *
663 setalg(Espcb *ecb, char **f, int n, Algorithm *alg)
664 {
665         uchar *key;
666         int c, nbyte, nchar;
667         uint i;
668
669         if(n < 2 || n > 3)
670                 return "bad format";
671         for(; alg->name; alg++)
672                 if(strcmp(f[1], alg->name) == 0)
673                         break;
674         if(alg->name == nil)
675                 return "unknown algorithm";
676
677         nbyte = (alg->keylen + 7) >> 3;
678         if (n == 2)
679                 nchar = 0;
680         else
681                 nchar = strlen(f[2]);
682         if(nchar != 2 * nbyte)                  /* TODO: maybe < is ok */
683                 return "key not required length";
684         /* convert hex digits from ascii, in place */
685         for(i=0; i<nchar; i++) {
686                 c = f[2][i];
687                 if(c >= '0' && c <= '9')
688                         f[2][i] -= '0';
689                 else if(c >= 'a' && c <= 'f')
690                         f[2][i] -= 'a'-10;
691                 else if(c >= 'A' && c <= 'F')
692                         f[2][i] -= 'A'-10;
693                 else
694                         return "non-hex character in key";
695         }
696         /* collapse hex digits into complete bytes in reverse order in key */
697         key = secalloc(nbyte);
698         for(i = 0; i < nchar && i/2 < nbyte; i++) {
699                 c = f[2][nchar-i-1];
700                 if(i&1)
701                         c <<= 4;
702                 key[i/2] |= c;
703         }
704         memset(f[2], 0, nchar);
705         alg->init(ecb, alg->name, key, alg->keylen);
706         secfree(key);
707         return nil;
708 }
709
710
711 /*
712  * null encryption
713  */
714
715 static int
716 nullcipher(Espcb*, uchar*, int)
717 {
718         return 1;
719 }
720
721 static void
722 nullespinit(Espcb *ecb, char *name, uchar*, unsigned)
723 {
724         ecb->espalg = name;
725         ecb->espblklen = 1;
726         ecb->espivlen = 0;
727         ecb->cipher = nullcipher;
728 }
729
730 static int
731 nullauth(Espcb*, uchar*, int, uchar*)
732 {
733         return 1;
734 }
735
736 static void
737 nullahinit(Espcb *ecb, char *name, uchar*, unsigned)
738 {
739         ecb->ahalg = name;
740         ecb->ahblklen = 1;
741         ecb->ahlen = 0;
742         ecb->auth = nullauth;
743 }
744
745
746 /*
747  * sha1
748  */
749
750 static void
751 seanq_hmac_sha1(uchar hash[SHA1dlen], uchar *t, long tlen, uchar *key, long klen)
752 {
753         int i;
754         uchar ipad[Hmacblksz+1], opad[Hmacblksz+1], innerhash[SHA1dlen];
755         DigestState *digest;
756
757         memset(ipad, 0x36, Hmacblksz);
758         memset(opad, 0x5c, Hmacblksz);
759         ipad[Hmacblksz] = opad[Hmacblksz] = 0;
760         for(i = 0; i < klen; i++){
761                 ipad[i] ^= key[i];
762                 opad[i] ^= key[i];
763         }
764         digest = sha1(ipad, Hmacblksz, nil, nil);
765         sha1(t, tlen, innerhash, digest);
766         digest = sha1(opad, Hmacblksz, nil, nil);
767         sha1(innerhash, SHA1dlen, hash, digest);
768 }
769
770 static int
771 shaauth(Espcb *ecb, uchar *t, int tlen, uchar *auth)
772 {
773         int r;
774         uchar hash[SHA1dlen];
775
776         memset(hash, 0, SHA1dlen);
777         seanq_hmac_sha1(hash, t, tlen, (uchar*)ecb->ahstate, BITS2BYTES(128));
778         r = memcmp(auth, hash, ecb->ahlen) == 0;
779         memmove(auth, hash, ecb->ahlen);
780         return r;
781 }
782
783 static void
784 shaahinit(Espcb *ecb, char *name, uchar *key, unsigned klen)
785 {
786         if(klen != 128)
787                 panic("shaahinit: bad keylen");
788         klen /= BI2BY;
789
790         ecb->ahalg = name;
791         ecb->ahblklen = 1;
792         ecb->ahlen = BITS2BYTES(96);
793         ecb->auth = shaauth;
794         ecb->ahstate = secalloc(klen);
795         memmove(ecb->ahstate, key, klen);
796 }
797
798
799 /*
800  * aes
801  */
802 static int
803 aescbccipher(Espcb *ecb, uchar *p, int n)       /* 128-bit blocks */
804 {
805         uchar tmp[AESbsize], q[AESbsize];
806         uchar *pp, *tp, *ip, *eip, *ep;
807         AESstate *ds = ecb->espstate;
808
809         ep = p + n;
810         if(ecb->incoming) {
811                 memmove(ds->ivec, p, AESbsize);
812                 p += AESbsize;
813                 while(p < ep){
814                         memmove(tmp, p, AESbsize);
815                         aes_decrypt(ds->dkey, ds->rounds, p, q);
816                         memmove(p, q, AESbsize);
817                         tp = tmp;
818                         ip = ds->ivec;
819                         for(eip = ip + AESbsize; ip < eip; ){
820                                 *p++ ^= *ip;
821                                 *ip++ = *tp++;
822                         }
823                 }
824         } else {
825                 memmove(p, ds->ivec, AESbsize);
826                 for(p += AESbsize; p < ep; p += AESbsize){
827                         pp = p;
828                         ip = ds->ivec;
829                         for(eip = ip + AESbsize; ip < eip; )
830                                 *pp++ ^= *ip++;
831                         aes_encrypt(ds->ekey, ds->rounds, p, q);
832                         memmove(ds->ivec, q, AESbsize);
833                         memmove(p, q, AESbsize);
834                 }
835         }
836         return 1;
837 }
838
839 static void
840 aescbcespinit(Espcb *ecb, char *name, uchar *k, unsigned n)
841 {
842         uchar key[Aeskeysz], ivec[Aeskeysz];
843
844         n = BITS2BYTES(n);
845         if(n > Aeskeysz)
846                 n = Aeskeysz;
847         memset(key, 0, sizeof(key));
848         memmove(key, k, n);
849         prng(ivec, Aeskeysz);
850         ecb->espalg = name;
851         ecb->espblklen = Aesblk;
852         ecb->espivlen = Aesblk;
853         ecb->cipher = aescbccipher;
854         ecb->espstate = secalloc(sizeof(AESstate));
855         setupAESstate(ecb->espstate, key, n /* keybytes */, ivec);
856         memset(ivec, 0, sizeof(ivec));
857         memset(key, 0, sizeof(key));
858 }
859
860 static int
861 aesctrcipher(Espcb *ecb, uchar *p, int n)       /* 128-bit blocks */
862 {
863         uchar tmp[AESbsize], q[AESbsize];
864         uchar *pp, *tp, *ip, *eip, *ep;
865         AESstate *ds = ecb->espstate;
866
867         ep = p + n;
868         if(ecb->incoming) {
869                 memmove(ds->ivec, p, AESbsize);
870                 p += AESbsize;
871                 while(p < ep){
872                         memmove(tmp, p, AESbsize);
873                         aes_decrypt(ds->dkey, ds->rounds, p, q);
874                         memmove(p, q, AESbsize);
875                         tp = tmp;
876                         ip = ds->ivec;
877                         for(eip = ip + AESbsize; ip < eip; ){
878                                 *p++ ^= *ip;
879                                 *ip++ = *tp++;
880                         }
881                 }
882         } else {
883                 memmove(p, ds->ivec, AESbsize);
884                 for(p += AESbsize; p < ep; p += AESbsize){
885                         pp = p;
886                         ip = ds->ivec;
887                         for(eip = ip + AESbsize; ip < eip; )
888                                 *pp++ ^= *ip++;
889                         aes_encrypt(ds->ekey, ds->rounds, p, q);
890                         memmove(ds->ivec, q, AESbsize);
891                         memmove(p, q, AESbsize);
892                 }
893         }
894         return 1;
895 }
896
897 static void
898 aesctrespinit(Espcb *ecb, char *name, uchar *k, unsigned n)
899 {
900         uchar key[Aesblk], ivec[Aesblk];
901
902         n = BITS2BYTES(n);
903         if(n > Aeskeysz)
904                 n = Aeskeysz;
905         memset(key, 0, sizeof(key));
906         memmove(key, k, n);
907         prng(ivec, Aesblk);
908         ecb->espalg = name;
909         ecb->espblklen = Aesblk;
910         ecb->espivlen = Aesblk;
911         ecb->cipher = aesctrcipher;
912         ecb->espstate = secalloc(sizeof(AESstate));
913         setupAESstate(ecb->espstate, key, n /* keybytes */, ivec);
914         memset(ivec, 0, sizeof(ivec));
915         memset(key, 0, sizeof(key));
916 }
917
918
919 /*
920  * md5
921  */
922
923 static void
924 seanq_hmac_md5(uchar hash[MD5dlen], uchar *t, long tlen, uchar *key, long klen)
925 {
926         int i;
927         uchar ipad[Hmacblksz+1], opad[Hmacblksz+1], innerhash[MD5dlen];
928         DigestState *digest;
929
930         memset(ipad, 0x36, Hmacblksz);
931         memset(opad, 0x5c, Hmacblksz);
932         ipad[Hmacblksz] = opad[Hmacblksz] = 0;
933         for(i = 0; i < klen; i++){
934                 ipad[i] ^= key[i];
935                 opad[i] ^= key[i];
936         }
937         digest = md5(ipad, Hmacblksz, nil, nil);
938         md5(t, tlen, innerhash, digest);
939         digest = md5(opad, Hmacblksz, nil, nil);
940         md5(innerhash, MD5dlen, hash, digest);
941 }
942
943 static int
944 md5auth(Espcb *ecb, uchar *t, int tlen, uchar *auth)
945 {
946         uchar hash[MD5dlen];
947         int r;
948
949         memset(hash, 0, MD5dlen);
950         seanq_hmac_md5(hash, t, tlen, (uchar*)ecb->ahstate, BITS2BYTES(128));
951         r = memcmp(auth, hash, ecb->ahlen) == 0;
952         memmove(auth, hash, ecb->ahlen);
953         return r;
954 }
955
956 static void
957 md5ahinit(Espcb *ecb, char *name, uchar *key, unsigned klen)
958 {
959         if(klen != 128)
960                 panic("md5ahinit: bad keylen");
961         klen = BITS2BYTES(klen);
962         ecb->ahalg = name;
963         ecb->ahblklen = 1;
964         ecb->ahlen = BITS2BYTES(96);
965         ecb->auth = md5auth;
966         ecb->ahstate = secalloc(klen);
967         memmove(ecb->ahstate, key, klen);
968 }
969
970
971 /*
972  * des, single and triple
973  */
974
975 static int
976 descipher(Espcb *ecb, uchar *p, int n)
977 {
978         DESstate *ds = ecb->espstate;
979
980         if(ecb->incoming) {
981                 memmove(ds->ivec, p, Desblk);
982                 desCBCdecrypt(p + Desblk, n - Desblk, ds);
983         } else {
984                 memmove(p, ds->ivec, Desblk);
985                 desCBCencrypt(p + Desblk, n - Desblk, ds);
986         }
987         return 1;
988 }
989
990 static int
991 des3cipher(Espcb *ecb, uchar *p, int n)
992 {
993         DES3state *ds = ecb->espstate;
994
995         if(ecb->incoming) {
996                 memmove(ds->ivec, p, Desblk);
997                 des3CBCdecrypt(p + Desblk, n - Desblk, ds);
998         } else {
999                 memmove(p, ds->ivec, Desblk);
1000                 des3CBCencrypt(p + Desblk, n - Desblk, ds);
1001         }
1002         return 1;
1003 }
1004
1005 static void
1006 desespinit(Espcb *ecb, char *name, uchar *k, unsigned n)
1007 {
1008         uchar key[Desblk], ivec[Desblk];
1009
1010         n = BITS2BYTES(n);
1011         if(n > Desblk)
1012                 n = Desblk;
1013         memset(key, 0, sizeof(key));
1014         memmove(key, k, n);
1015         prng(ivec, Desblk);
1016         ecb->espalg = name;
1017         ecb->espblklen = Desblk;
1018         ecb->espivlen = Desblk;
1019
1020         ecb->cipher = descipher;
1021         ecb->espstate = secalloc(sizeof(DESstate));
1022         setupDESstate(ecb->espstate, key, ivec);
1023         memset(ivec, 0, sizeof(ivec));
1024         memset(key, 0, sizeof(key));
1025 }
1026
1027 static void
1028 des3espinit(Espcb *ecb, char *name, uchar *k, unsigned n)
1029 {
1030         uchar key[3][Desblk], ivec[Desblk];
1031
1032         n = BITS2BYTES(n);
1033         if(n > Des3keysz)
1034                 n = Des3keysz;
1035         memset(key, 0, sizeof(key));
1036         memmove(key, k, n);
1037         prng(ivec, Desblk);
1038         ecb->espalg = name;
1039         ecb->espblklen = Desblk;
1040         ecb->espivlen = Desblk;
1041
1042         ecb->cipher = des3cipher;
1043         ecb->espstate = secalloc(sizeof(DES3state));
1044         setupDES3state(ecb->espstate, key, ivec);
1045         memset(ivec, 0, sizeof(ivec));
1046         memset(key, 0, sizeof(key));
1047 }
1048
1049
1050 /*
1051  * interfacing to devip
1052  */
1053 void
1054 espinit(Fs *fs)
1055 {
1056         Proto *esp;
1057
1058         esp = smalloc(sizeof(Proto));
1059         esp->priv = smalloc(sizeof(Esppriv));
1060         esp->name = "esp";
1061         esp->connect = espconnect;
1062         esp->announce = nil;
1063         esp->ctl = espctl;
1064         esp->state = espstate;
1065         esp->create = espcreate;
1066         esp->close = espclose;
1067         esp->rcv = espiput;
1068         esp->advise = espadvise;
1069         esp->stats = espstats;
1070         esp->local = esplocal;
1071         esp->remote = espremote;
1072         esp->ipproto = IP_ESPPROTO;
1073         esp->nc = Nchans;
1074         esp->ptclsize = sizeof(Espcb);
1075
1076         Fsproto(fs, esp);
1077 }