]> git.lizzy.rs Git - plan9front.git/blob - sys/src/9/ip/esp.c
merge
[plan9front.git] / sys / src / 9 / ip / esp.c
1 /*
2  * Encapsulating Security Payload for IPsec for IPv4, rfc1827.
3  * extended to IPv6.
4  * rfc2104 defines hmac computation.
5  *      currently only implements tunnel mode.
6  * TODO: verify aes algorithms;
7  *      transport mode (host-to-host)
8  */
9 #include        "u.h"
10 #include        "../port/lib.h"
11 #include        "mem.h"
12 #include        "dat.h"
13 #include        "fns.h"
14 #include        "../port/error.h"
15
16 #include        "ip.h"
17 #include        "ipv6.h"
18 #include        <libsec.h>
19
20 #define BITS2BYTES(bi) (((bi) + BI2BY - 1) / BI2BY)
21 #define BYTES2BITS(by)  ((by) * BI2BY)
22
23 typedef struct Algorithm Algorithm;
24 typedef struct Esp4hdr Esp4hdr;
25 typedef struct Esp6hdr Esp6hdr;
26 typedef struct Espcb Espcb;
27 typedef struct Esphdr Esphdr;
28 typedef struct Esppriv Esppriv;
29 typedef struct Esptail Esptail;
30 typedef struct Userhdr Userhdr;
31
32 enum {
33         Encrypt,
34         Decrypt,
35
36         IP_ESPPROTO     = 50,   /* IP v4 and v6 protocol number */
37         Esp4hdrlen      = IP4HDR + 8,
38         Esp6hdrlen      = IP6HDR + 8,
39
40         Esptaillen      = 2,    /* does not include pad or auth data */
41         Userhdrlen      = 4,    /* user-visible header size - if enabled */
42
43         Desblk   = BITS2BYTES(64),
44         Des3keysz = BITS2BYTES(192),
45
46         Aesblk   = BITS2BYTES(128),
47         Aeskeysz = BITS2BYTES(128),
48 };
49
50 struct Esphdr
51 {
52         uchar   espspi[4];      /* Security parameter index */
53         uchar   espseq[4];      /* Sequence number */
54         uchar   payload[];
55 };
56
57 /*
58  * tunnel-mode (network-to-network, etc.) layout is:
59  * new IP hdrs | ESP hdr |
60  *       enc { orig IP hdrs | TCP/UDP hdr | user data | ESP trailer } | ESP ICV
61  *
62  * transport-mode (host-to-host) layout would be:
63  *      orig IP hdrs | ESP hdr |
64  *                      enc { TCP/UDP hdr | user data | ESP trailer } | ESP ICV
65  */
66 struct Esp4hdr
67 {
68         /* ipv4 header */
69         uchar   vihl;           /* Version and header length */
70         uchar   tos;            /* Type of service */
71         uchar   length[2];      /* packet length */
72         uchar   id[2];          /* Identification */
73         uchar   frag[2];        /* Fragment information */
74         uchar   Unused;
75         uchar   espproto;       /* Protocol */
76         uchar   espplen[2];     /* Header plus data length */
77         uchar   espsrc[4];      /* Ip source */
78         uchar   espdst[4];      /* Ip destination */
79
80         Esphdr;
81 };
82
83 /* tunnel-mode layout */
84 struct Esp6hdr
85 {
86         IPV6HDR;
87         Esphdr;
88 };
89
90 struct Esptail
91 {
92         uchar   pad;
93         uchar   nexthdr;
94 };
95
96 /* IP-version-dependent data */
97 typedef struct Versdep Versdep;
98 struct Versdep
99 {
100         ulong   version;
101         ulong   iphdrlen;
102         ulong   hdrlen;         /* iphdrlen + esp hdr len */
103         ulong   spi;
104         uchar   laddr[IPaddrlen];
105         uchar   raddr[IPaddrlen];
106 };
107
108 /* header as seen by the user */
109 struct Userhdr
110 {
111         uchar   nexthdr;        /* next protocol */
112         uchar   unused[3];
113 };
114
115 struct Esppriv
116 {
117         uvlong  in;
118         ulong   inerrors;
119 };
120
121 /*
122  *  protocol specific part of Conv
123  */
124 struct Espcb
125 {
126         int     incoming;
127         int     header;         /* user-level header */
128         ulong   spi;
129         ulong   seq;            /* last seq sent */
130         ulong   window;         /* for replay attacks */
131
132         char    *espalg;
133         void    *espstate;      /* other state for esp */
134         int     espivlen;       /* in bytes */
135         int     espblklen;
136         int     (*cipher)(Espcb*, uchar *buf, int len);
137
138         char    *ahalg;
139         void    *ahstate;       /* other state for esp */
140         int     ahlen;          /* auth data length in bytes */
141         int     ahblklen;
142         int     (*auth)(Espcb*, uchar *buf, int len, uchar *hash);
143         DigestState *ds;
144 };
145
146 struct Algorithm
147 {
148         char    *name;
149         int     keylen;         /* in bits */
150         void    (*init)(Espcb*, char* name, uchar *key, unsigned keylen);
151 };
152
153 static  Conv* convlookup(Proto *esp, ulong spi);
154 static  char *setalg(Espcb *ecb, char **f, int n, Algorithm *alg);
155 static  void espkick(void *x);
156
157 static  void nullespinit(Espcb*, char*, uchar *key, unsigned keylen);
158 static  void des3espinit(Espcb*, char*, uchar *key, unsigned keylen);
159 static  void aescbcespinit(Espcb*, char*, uchar *key, unsigned keylen);
160 static  void aesctrespinit(Espcb*, char*, uchar *key, unsigned keylen);
161 static  void desespinit(Espcb *ecb, char *name, uchar *k, unsigned n);
162
163 static  void nullahinit(Espcb*, char*, uchar *key, unsigned keylen);
164 static  void shaahinit(Espcb*, char*, uchar *key, unsigned keylen);
165 static  void md5ahinit(Espcb*, char*, uchar *key, unsigned keylen);
166
167 static Algorithm espalg[] =
168 {
169         "null",         0,      nullespinit,
170         "des3_cbc",     192,    des3espinit,    /* new rfc2451, des-ede3 */
171         "aes_128_cbc",  128,    aescbcespinit,  /* new rfc3602 */
172         "aes_ctr",      128,    aesctrespinit,  /* new rfc3686 */
173         "des_56_cbc",   64,     desespinit,     /* rfc2405, deprecated */
174         nil,            0,      nil,
175 };
176
177 static Algorithm ahalg[] =
178 {
179         "null",         0,      nullahinit,
180         "hmac_sha1_96", 128,    shaahinit,      /* rfc2404 */
181         "hmac_md5_96",  128,    md5ahinit,      /* rfc2403 */
182         nil,            0,      nil,
183 };
184
185 static char*
186 espconnect(Conv *c, char **argv, int argc)
187 {
188         char *p, *pp, *e = nil;
189         ulong spi;
190         Espcb *ecb = (Espcb*)c->ptcl;
191
192         switch(argc) {
193         default:
194                 e = "bad args to connect";
195                 break;
196         case 2:
197                 p = strchr(argv[1], '!');
198                 if(p == nil){
199                         e = "malformed address";
200                         break;
201                 }
202                 *p++ = 0;
203                 if (parseip(c->raddr, argv[1]) == -1) {
204                         e = Ebadip;
205                         break;
206                 }
207                 findlocalip(c->p->f, c->laddr, c->raddr);
208                 ecb->incoming = 0;
209                 ecb->seq = 0;
210                 if(strcmp(p, "*") == 0) {
211                         qlock(c->p);
212                         for(;;) {
213                                 spi = nrand(1<<16) + 256;
214                                 if(convlookup(c->p, spi) == nil)
215                                         break;
216                         }
217                         qunlock(c->p);
218                         ecb->spi = spi;
219                         ecb->incoming = 1;
220                         qhangup(c->wq, nil);
221                 } else {
222                         spi = strtoul(p, &pp, 10);
223                         if(pp == p) {
224                                 e = "malformed address";
225                                 break;
226                         }
227                         ecb->spi = spi;
228                         qhangup(c->rq, nil);
229                 }
230                 nullespinit(ecb, "null", nil, 0);
231                 nullahinit(ecb, "null", nil, 0);
232         }
233         Fsconnected(c, e);
234
235         return e;
236 }
237
238
239 static int
240 espstate(Conv *c, char *state, int n)
241 {
242         return snprint(state, n, "%s", c->inuse?"Open\n":"Closed\n");
243 }
244
245 static void
246 espcreate(Conv *c)
247 {
248         c->rq = qopen(64*1024, Qmsg, 0, 0);
249         c->wq = qopen(64*1024, Qkick, espkick, c);
250 }
251
252 static void
253 espclose(Conv *c)
254 {
255         Espcb *ecb;
256
257         qclose(c->rq);
258         qclose(c->wq);
259         qclose(c->eq);
260         ipmove(c->laddr, IPnoaddr);
261         ipmove(c->raddr, IPnoaddr);
262
263         ecb = (Espcb*)c->ptcl;
264         secfree(ecb->espstate);
265         secfree(ecb->ahstate);
266         memset(ecb, 0, sizeof(Espcb));
267 }
268
269 static int
270 pktipvers(Fs *f, Block **bpp)
271 {
272         if (*bpp == nil || BLEN(*bpp) == 0) {
273                 /* get enough to identify the IP version */
274                 *bpp = pullupblock(*bpp, IP4HDR);
275                 if(*bpp == nil) {
276                         netlog(f, Logesp, "esp: short packet\n");
277                         return 0;
278                 }
279         }
280         return (((Esp4hdr*)(*bpp)->rp)->vihl & 0xf0) == IP_VER4? V4: V6;
281 }
282
283 static void
284 getverslens(int version, Versdep *vp)
285 {
286         vp->version = version;
287         switch(vp->version) {
288         case V4:
289                 vp->iphdrlen = IP4HDR;
290                 vp->hdrlen   = Esp4hdrlen;
291                 break;
292         case V6:
293                 vp->iphdrlen = IP6HDR;
294                 vp->hdrlen   = Esp6hdrlen;
295                 break;
296         default:
297                 panic("esp: getverslens version %d wrong", version);
298         }
299 }
300
301 static void
302 getpktspiaddrs(uchar *pkt, Versdep *vp)
303 {
304         Esp4hdr *eh4;
305         Esp6hdr *eh6;
306
307         switch(vp->version) {
308         case V4:
309                 eh4 = (Esp4hdr*)pkt;
310                 v4tov6(vp->raddr, eh4->espsrc);
311                 v4tov6(vp->laddr, eh4->espdst);
312                 vp->spi = nhgetl(eh4->espspi);
313                 break;
314         case V6:
315                 eh6 = (Esp6hdr*)pkt;
316                 ipmove(vp->raddr, eh6->src);
317                 ipmove(vp->laddr, eh6->dst);
318                 vp->spi = nhgetl(eh6->espspi);
319                 break;
320         default:
321                 panic("esp: getpktspiaddrs vp->version %ld wrong", vp->version);
322         }
323 }
324
325 /*
326  * encapsulate next IP packet on x's write queue in IP/ESP packet
327  * and initiate output of the result.
328  */
329 static void
330 espkick(void *x)
331 {
332         int nexthdr, payload, pad, align;
333         uchar *auth;
334         Block *bp;
335         Conv *c = x;
336         Esp4hdr *eh4;
337         Esp6hdr *eh6;
338         Espcb *ecb;
339         Esptail *et;
340         Userhdr *uh;
341         Versdep vers;
342
343         getverslens(convipvers(c), &vers);
344         bp = qget(c->wq);
345         if(bp == nil)
346                 return;
347
348         qlock(c);
349         ecb = c->ptcl;
350
351         if(ecb->header) {
352                 /* make sure the message has a User header */
353                 bp = pullupblock(bp, Userhdrlen);
354                 if(bp == nil) {
355                         qunlock(c);
356                         return;
357                 }
358                 uh = (Userhdr*)bp->rp;
359                 nexthdr = uh->nexthdr;
360                 bp->rp += Userhdrlen;
361         } else {
362                 nexthdr = 0;    /* what should this be? */
363         }
364
365         payload = BLEN(bp) + ecb->espivlen;
366
367         /* Make space to fit ip header */
368         bp = padblock(bp, vers.hdrlen + ecb->espivlen);
369         getpktspiaddrs(bp->rp, &vers);
370
371         align = 4;
372         if(ecb->espblklen > align)
373                 align = ecb->espblklen;
374         if(align % ecb->ahblklen != 0)
375                 panic("espkick: ahblklen is important after all");
376         pad = (align-1) - (payload + Esptaillen-1)%align;
377
378         /*
379          * Make space for tail
380          * this is done by calling padblock with a negative size
381          * Padblock does not change bp->wp!
382          */
383         bp = padblock(bp, -(pad+Esptaillen+ecb->ahlen));
384         bp->wp += pad+Esptaillen+ecb->ahlen;
385
386         et = (Esptail*)(bp->rp + vers.hdrlen + payload + pad);
387
388         /* fill in tail */
389         et->pad = pad;
390         et->nexthdr = nexthdr;
391
392         /* encrypt the payload */
393         ecb->cipher(ecb, bp->rp + vers.hdrlen, payload + pad + Esptaillen);
394         auth = bp->rp + vers.hdrlen + payload + pad + Esptaillen;
395
396         /* fill in head; construct a new IP header and an ESP header */
397         if (vers.version == V4) {
398                 eh4 = (Esp4hdr *)bp->rp;
399                 eh4->vihl = IP_VER4;
400                 v6tov4(eh4->espsrc, c->laddr);
401                 v6tov4(eh4->espdst, c->raddr);
402                 eh4->espproto = IP_ESPPROTO;
403                 eh4->frag[0] = 0;
404                 eh4->frag[1] = 0;
405
406                 hnputl(eh4->espspi, ecb->spi);
407                 hnputl(eh4->espseq, ++ecb->seq);
408         } else {
409                 eh6 = (Esp6hdr *)bp->rp;
410                 eh6->vcf[0] = IP_VER6;
411                 ipmove(eh6->src, c->laddr);
412                 ipmove(eh6->dst, c->raddr);
413                 eh6->proto = IP_ESPPROTO;
414
415                 hnputl(eh6->espspi, ecb->spi);
416                 hnputl(eh6->espseq, ++ecb->seq);
417         }
418
419         /* compute secure hash */
420         ecb->auth(ecb, bp->rp + vers.iphdrlen, (vers.hdrlen - vers.iphdrlen) +
421                 payload + pad + Esptaillen, auth);
422
423         qunlock(c);
424         /* print("esp: pass down: %uld\n", BLEN(bp)); */
425         if (vers.version == V4)
426                 ipoput4(c->p->f, bp, 0, c->ttl, c->tos, c);
427         else
428                 ipoput6(c->p->f, bp, 0, c->ttl, c->tos, c);
429 }
430
431 /*
432  * decapsulate IP packet from IP/ESP packet in bp and
433  * pass the result up the spi's Conv's read queue.
434  */
435 void
436 espiput(Proto *esp, Ipifc*, Block *bp)
437 {
438         int payload, nexthdr;
439         uchar *auth, *espspi;
440         Conv *c;
441         Espcb *ecb;
442         Esptail *et;
443         Fs *f;
444         Userhdr *uh;
445         Versdep vers;
446
447         f = esp->f;
448
449         getverslens(pktipvers(f, &bp), &vers);
450
451         bp = pullupblock(bp, vers.hdrlen + Esptaillen);
452         if(bp == nil) {
453                 netlog(f, Logesp, "esp: short packet\n");
454                 return;
455         }
456         getpktspiaddrs(bp->rp, &vers);
457
458         qlock(esp);
459         /* Look for a conversation structure for this port */
460         c = convlookup(esp, vers.spi);
461         if(c == nil) {
462                 qunlock(esp);
463                 netlog(f, Logesp, "esp: no conv %I -> %I!%lud\n", vers.raddr,
464                         vers.laddr, vers.spi);
465                 icmpnoconv(f, bp);
466                 freeblist(bp);
467                 return;
468         }
469
470         qlock(c);
471         qunlock(esp);
472
473         ecb = c->ptcl;
474         /* too hard to do decryption/authentication on block lists */
475         if(bp->next)
476                 bp = concatblock(bp);
477
478         if(BLEN(bp) < vers.hdrlen + ecb->espivlen + Esptaillen + ecb->ahlen) {
479                 qunlock(c);
480                 netlog(f, Logesp, "esp: short block %I -> %I!%lud\n", vers.raddr,
481                         vers.laddr, vers.spi);
482                 freeb(bp);
483                 return;
484         }
485
486         auth = bp->wp - ecb->ahlen;
487         espspi = vers.version == V4?    ((Esp4hdr*)bp->rp)->espspi:
488                                         ((Esp6hdr*)bp->rp)->espspi;
489
490         /* compute secure hash and authenticate */
491         if(!ecb->auth(ecb, espspi, auth - espspi, auth)) {
492                 qunlock(c);
493 print("esp: bad auth %I -> %I!%ld\n", vers.raddr, vers.laddr, vers.spi);
494                 netlog(f, Logesp, "esp: bad auth %I -> %I!%lud\n", vers.raddr,
495                         vers.laddr, vers.spi);
496                 freeb(bp);
497                 return;
498         }
499
500         payload = BLEN(bp) - vers.hdrlen - ecb->ahlen;
501         if(payload <= 0 || payload % 4 != 0 || payload % ecb->espblklen != 0) {
502                 qunlock(c);
503                 netlog(f, Logesp, "esp: bad length %I -> %I!%lud payload=%d BLEN=%zd\n",
504                         vers.raddr, vers.laddr, vers.spi, payload, BLEN(bp));
505                 freeb(bp);
506                 return;
507         }
508
509         /* decrypt payload */
510         if(!ecb->cipher(ecb, bp->rp + vers.hdrlen, payload)) {
511                 qunlock(c);
512 print("esp: cipher failed %I -> %I!%ld: %s\n", vers.raddr, vers.laddr, vers.spi, up->errstr);
513                 netlog(f, Logesp, "esp: cipher failed %I -> %I!%lud: %s\n",
514                         vers.raddr, vers.laddr, vers.spi, up->errstr);
515                 freeb(bp);
516                 return;
517         }
518
519         payload -= Esptaillen;
520         et = (Esptail*)(bp->rp + vers.hdrlen + payload);
521         payload -= et->pad + ecb->espivlen;
522         nexthdr = et->nexthdr;
523         if(payload <= 0) {
524                 qunlock(c);
525                 netlog(f, Logesp, "esp: short packet after decrypt %I -> %I!%lud\n",
526                         vers.raddr, vers.laddr, vers.spi);
527                 freeb(bp);
528                 return;
529         }
530
531         /* trim packet */
532         bp->rp += vers.hdrlen + ecb->espivlen; /* toss original IP & ESP hdrs */
533         bp->wp = bp->rp + payload;
534         if(ecb->header) {
535                 /* assume Userhdrlen < Esp4hdrlen < Esp6hdrlen */
536                 bp->rp -= Userhdrlen;
537                 uh = (Userhdr*)bp->rp;
538                 memset(uh, 0, Userhdrlen);
539                 uh->nexthdr = nexthdr;
540         }
541
542         /* ingress filtering here? */
543
544         if(qfull(c->rq)){
545                 netlog(f, Logesp, "esp: qfull %I -> %I.%uld\n", vers.raddr,
546                         vers.laddr, vers.spi);
547                 freeblist(bp);
548         }else {
549 //              print("esp: pass up: %uld\n", BLEN(bp));
550                 qpass(c->rq, bp);       /* pass packet up the read queue */
551         }
552
553         qunlock(c);
554 }
555
556 char*
557 espctl(Conv *c, char **f, int n)
558 {
559         Espcb *ecb = c->ptcl;
560         char *e = nil;
561
562         if(strcmp(f[0], "esp") == 0)
563                 e = setalg(ecb, f, n, espalg);
564         else if(strcmp(f[0], "ah") == 0)
565                 e = setalg(ecb, f, n, ahalg);
566         else if(strcmp(f[0], "header") == 0)
567                 ecb->header = 1;
568         else if(strcmp(f[0], "noheader") == 0)
569                 ecb->header = 0;
570         else
571                 e = "unknown control request";
572         return e;
573 }
574
575 /* called from icmp(v6) for unreachable hosts, time exceeded, etc. */
576 void
577 espadvise(Proto *esp, Block *bp, char *msg)
578 {
579         Conv *c;
580         Versdep vers;
581
582         getverslens(pktipvers(esp->f, &bp), &vers);
583         getpktspiaddrs(bp->rp, &vers);
584
585         qlock(esp);
586         c = convlookup(esp, vers.spi);
587         if(c != nil && !c->ignoreadvice) {
588                 qhangup(c->rq, msg);
589                 qhangup(c->wq, msg);
590         }
591         qunlock(esp);
592         freeblist(bp);
593 }
594
595 int
596 espstats(Proto *esp, char *buf, int len)
597 {
598         Esppriv *upriv;
599
600         upriv = esp->priv;
601         return snprint(buf, len, "%llud %lud\n",
602                 upriv->in,
603                 upriv->inerrors);
604 }
605
606 static int
607 esplocal(Conv *c, char *buf, int len)
608 {
609         Espcb *ecb = c->ptcl;
610         int n;
611
612         qlock(c);
613         if(ecb->incoming)
614                 n = snprint(buf, len, "%I!%uld\n", c->laddr, ecb->spi);
615         else
616                 n = snprint(buf, len, "%I\n", c->laddr);
617         qunlock(c);
618         return n;
619 }
620
621 static int
622 espremote(Conv *c, char *buf, int len)
623 {
624         Espcb *ecb = c->ptcl;
625         int n;
626
627         qlock(c);
628         if(ecb->incoming)
629                 n = snprint(buf, len, "%I\n", c->raddr);
630         else
631                 n = snprint(buf, len, "%I!%uld\n", c->raddr, ecb->spi);
632         qunlock(c);
633         return n;
634 }
635
636 static  Conv*
637 convlookup(Proto *esp, ulong spi)
638 {
639         Conv *c, **p;
640         Espcb *ecb;
641
642         for(p=esp->conv; *p; p++){
643                 c = *p;
644                 ecb = c->ptcl;
645                 if(ecb->incoming && ecb->spi == spi)
646                         return c;
647         }
648         return nil;
649 }
650
651 static char *
652 setalg(Espcb *ecb, char **f, int n, Algorithm *alg)
653 {
654         uchar *key;
655         int c, nbyte, nchar;
656         uint i;
657
658         if(n < 2 || n > 3)
659                 return "bad format";
660         for(; alg->name; alg++)
661                 if(strcmp(f[1], alg->name) == 0)
662                         break;
663         if(alg->name == nil)
664                 return "unknown algorithm";
665
666         nbyte = (alg->keylen + 7) >> 3;
667         if (n == 2)
668                 nchar = 0;
669         else
670                 nchar = strlen(f[2]);
671         if(nchar != 2 * nbyte)                  /* TODO: maybe < is ok */
672                 return "key not required length";
673         /* convert hex digits from ascii, in place */
674         for(i=0; i<nchar; i++) {
675                 c = f[2][i];
676                 if(c >= '0' && c <= '9')
677                         f[2][i] -= '0';
678                 else if(c >= 'a' && c <= 'f')
679                         f[2][i] -= 'a'-10;
680                 else if(c >= 'A' && c <= 'F')
681                         f[2][i] -= 'A'-10;
682                 else
683                         return "non-hex character in key";
684         }
685         /* collapse hex digits into complete bytes in reverse order in key */
686         key = secalloc(nbyte);
687         for(i = 0; i < nchar && i/2 < nbyte; i++) {
688                 c = f[2][nchar-i-1];
689                 if(i&1)
690                         c <<= 4;
691                 key[i/2] |= c;
692         }
693         memset(f[2], 0, nchar);
694         alg->init(ecb, alg->name, key, alg->keylen);
695         secfree(key);
696         return nil;
697 }
698
699
700 /*
701  * null encryption
702  */
703
704 static int
705 nullcipher(Espcb*, uchar*, int)
706 {
707         return 1;
708 }
709
710 static void
711 nullespinit(Espcb *ecb, char *name, uchar*, unsigned)
712 {
713         ecb->espalg = name;
714         ecb->espblklen = 1;
715         ecb->espivlen = 0;
716         ecb->cipher = nullcipher;
717 }
718
719 static int
720 nullauth(Espcb*, uchar*, int, uchar*)
721 {
722         return 1;
723 }
724
725 static void
726 nullahinit(Espcb *ecb, char *name, uchar*, unsigned)
727 {
728         ecb->ahalg = name;
729         ecb->ahblklen = 1;
730         ecb->ahlen = 0;
731         ecb->auth = nullauth;
732 }
733
734
735 /*
736  * sha1
737  */
738
739 static void
740 seanq_hmac_sha1(uchar hash[SHA1dlen], uchar *t, long tlen, uchar *key, long klen)
741 {
742         int i;
743         uchar ipad[Hmacblksz+1], opad[Hmacblksz+1], innerhash[SHA1dlen];
744         DigestState *digest;
745
746         memset(ipad, 0x36, Hmacblksz);
747         memset(opad, 0x5c, Hmacblksz);
748         ipad[Hmacblksz] = opad[Hmacblksz] = 0;
749         for(i = 0; i < klen; i++){
750                 ipad[i] ^= key[i];
751                 opad[i] ^= key[i];
752         }
753         digest = sha1(ipad, Hmacblksz, nil, nil);
754         sha1(t, tlen, innerhash, digest);
755         digest = sha1(opad, Hmacblksz, nil, nil);
756         sha1(innerhash, SHA1dlen, hash, digest);
757 }
758
759 static int
760 shaauth(Espcb *ecb, uchar *t, int tlen, uchar *auth)
761 {
762         int r;
763         uchar hash[SHA1dlen];
764
765         memset(hash, 0, SHA1dlen);
766         seanq_hmac_sha1(hash, t, tlen, (uchar*)ecb->ahstate, BITS2BYTES(128));
767         r = memcmp(auth, hash, ecb->ahlen) == 0;
768         memmove(auth, hash, ecb->ahlen);
769         return r;
770 }
771
772 static void
773 shaahinit(Espcb *ecb, char *name, uchar *key, unsigned klen)
774 {
775         if(klen != 128)
776                 panic("shaahinit: bad keylen");
777         klen /= BI2BY;
778
779         ecb->ahalg = name;
780         ecb->ahblklen = 1;
781         ecb->ahlen = BITS2BYTES(96);
782         ecb->auth = shaauth;
783         ecb->ahstate = secalloc(klen);
784         memmove(ecb->ahstate, key, klen);
785 }
786
787
788 /*
789  * aes
790  */
791 static int
792 aescbccipher(Espcb *ecb, uchar *p, int n)       /* 128-bit blocks */
793 {
794         uchar tmp[AESbsize], q[AESbsize];
795         uchar *pp, *tp, *ip, *eip, *ep;
796         AESstate *ds = ecb->espstate;
797
798         ep = p + n;
799         if(ecb->incoming) {
800                 memmove(ds->ivec, p, AESbsize);
801                 p += AESbsize;
802                 while(p < ep){
803                         memmove(tmp, p, AESbsize);
804                         aes_decrypt(ds->dkey, ds->rounds, p, q);
805                         memmove(p, q, AESbsize);
806                         tp = tmp;
807                         ip = ds->ivec;
808                         for(eip = ip + AESbsize; ip < eip; ){
809                                 *p++ ^= *ip;
810                                 *ip++ = *tp++;
811                         }
812                 }
813         } else {
814                 memmove(p, ds->ivec, AESbsize);
815                 for(p += AESbsize; p < ep; p += AESbsize){
816                         pp = p;
817                         ip = ds->ivec;
818                         for(eip = ip + AESbsize; ip < eip; )
819                                 *pp++ ^= *ip++;
820                         aes_encrypt(ds->ekey, ds->rounds, p, q);
821                         memmove(ds->ivec, q, AESbsize);
822                         memmove(p, q, AESbsize);
823                 }
824         }
825         return 1;
826 }
827
828 static void
829 aescbcespinit(Espcb *ecb, char *name, uchar *k, unsigned n)
830 {
831         uchar key[Aeskeysz], ivec[Aeskeysz];
832
833         n = BITS2BYTES(n);
834         if(n > Aeskeysz)
835                 n = Aeskeysz;
836         memset(key, 0, sizeof(key));
837         memmove(key, k, n);
838         prng(ivec, Aeskeysz);
839         ecb->espalg = name;
840         ecb->espblklen = Aesblk;
841         ecb->espivlen = Aesblk;
842         ecb->cipher = aescbccipher;
843         ecb->espstate = secalloc(sizeof(AESstate));
844         setupAESstate(ecb->espstate, key, n /* keybytes */, ivec);
845         memset(ivec, 0, sizeof(ivec));
846         memset(key, 0, sizeof(key));
847 }
848
849 static int
850 aesctrcipher(Espcb *ecb, uchar *p, int n)       /* 128-bit blocks */
851 {
852         uchar tmp[AESbsize], q[AESbsize];
853         uchar *pp, *tp, *ip, *eip, *ep;
854         AESstate *ds = ecb->espstate;
855
856         ep = p + n;
857         if(ecb->incoming) {
858                 memmove(ds->ivec, p, AESbsize);
859                 p += AESbsize;
860                 while(p < ep){
861                         memmove(tmp, p, AESbsize);
862                         aes_decrypt(ds->dkey, ds->rounds, p, q);
863                         memmove(p, q, AESbsize);
864                         tp = tmp;
865                         ip = ds->ivec;
866                         for(eip = ip + AESbsize; ip < eip; ){
867                                 *p++ ^= *ip;
868                                 *ip++ = *tp++;
869                         }
870                 }
871         } else {
872                 memmove(p, ds->ivec, AESbsize);
873                 for(p += AESbsize; p < ep; p += AESbsize){
874                         pp = p;
875                         ip = ds->ivec;
876                         for(eip = ip + AESbsize; ip < eip; )
877                                 *pp++ ^= *ip++;
878                         aes_encrypt(ds->ekey, ds->rounds, p, q);
879                         memmove(ds->ivec, q, AESbsize);
880                         memmove(p, q, AESbsize);
881                 }
882         }
883         return 1;
884 }
885
886 static void
887 aesctrespinit(Espcb *ecb, char *name, uchar *k, unsigned n)
888 {
889         uchar key[Aesblk], ivec[Aesblk];
890
891         n = BITS2BYTES(n);
892         if(n > Aeskeysz)
893                 n = Aeskeysz;
894         memset(key, 0, sizeof(key));
895         memmove(key, k, n);
896         prng(ivec, Aesblk);
897         ecb->espalg = name;
898         ecb->espblklen = Aesblk;
899         ecb->espivlen = Aesblk;
900         ecb->cipher = aesctrcipher;
901         ecb->espstate = secalloc(sizeof(AESstate));
902         setupAESstate(ecb->espstate, key, n /* keybytes */, ivec);
903         memset(ivec, 0, sizeof(ivec));
904         memset(key, 0, sizeof(key));
905 }
906
907
908 /*
909  * md5
910  */
911
912 static void
913 seanq_hmac_md5(uchar hash[MD5dlen], uchar *t, long tlen, uchar *key, long klen)
914 {
915         int i;
916         uchar ipad[Hmacblksz+1], opad[Hmacblksz+1], innerhash[MD5dlen];
917         DigestState *digest;
918
919         memset(ipad, 0x36, Hmacblksz);
920         memset(opad, 0x5c, Hmacblksz);
921         ipad[Hmacblksz] = opad[Hmacblksz] = 0;
922         for(i = 0; i < klen; i++){
923                 ipad[i] ^= key[i];
924                 opad[i] ^= key[i];
925         }
926         digest = md5(ipad, Hmacblksz, nil, nil);
927         md5(t, tlen, innerhash, digest);
928         digest = md5(opad, Hmacblksz, nil, nil);
929         md5(innerhash, MD5dlen, hash, digest);
930 }
931
932 static int
933 md5auth(Espcb *ecb, uchar *t, int tlen, uchar *auth)
934 {
935         uchar hash[MD5dlen];
936         int r;
937
938         memset(hash, 0, MD5dlen);
939         seanq_hmac_md5(hash, t, tlen, (uchar*)ecb->ahstate, BITS2BYTES(128));
940         r = memcmp(auth, hash, ecb->ahlen) == 0;
941         memmove(auth, hash, ecb->ahlen);
942         return r;
943 }
944
945 static void
946 md5ahinit(Espcb *ecb, char *name, uchar *key, unsigned klen)
947 {
948         if(klen != 128)
949                 panic("md5ahinit: bad keylen");
950         klen = BITS2BYTES(klen);
951         ecb->ahalg = name;
952         ecb->ahblklen = 1;
953         ecb->ahlen = BITS2BYTES(96);
954         ecb->auth = md5auth;
955         ecb->ahstate = secalloc(klen);
956         memmove(ecb->ahstate, key, klen);
957 }
958
959
960 /*
961  * des, single and triple
962  */
963
964 static int
965 descipher(Espcb *ecb, uchar *p, int n)
966 {
967         DESstate *ds = ecb->espstate;
968
969         if(ecb->incoming) {
970                 memmove(ds->ivec, p, Desblk);
971                 desCBCdecrypt(p + Desblk, n - Desblk, ds);
972         } else {
973                 memmove(p, ds->ivec, Desblk);
974                 desCBCencrypt(p + Desblk, n - Desblk, ds);
975         }
976         return 1;
977 }
978
979 static int
980 des3cipher(Espcb *ecb, uchar *p, int n)
981 {
982         DES3state *ds = ecb->espstate;
983
984         if(ecb->incoming) {
985                 memmove(ds->ivec, p, Desblk);
986                 des3CBCdecrypt(p + Desblk, n - Desblk, ds);
987         } else {
988                 memmove(p, ds->ivec, Desblk);
989                 des3CBCencrypt(p + Desblk, n - Desblk, ds);
990         }
991         return 1;
992 }
993
994 static void
995 desespinit(Espcb *ecb, char *name, uchar *k, unsigned n)
996 {
997         uchar key[Desblk], ivec[Desblk];
998
999         n = BITS2BYTES(n);
1000         if(n > Desblk)
1001                 n = Desblk;
1002         memset(key, 0, sizeof(key));
1003         memmove(key, k, n);
1004         prng(ivec, Desblk);
1005         ecb->espalg = name;
1006         ecb->espblklen = Desblk;
1007         ecb->espivlen = Desblk;
1008
1009         ecb->cipher = descipher;
1010         ecb->espstate = secalloc(sizeof(DESstate));
1011         setupDESstate(ecb->espstate, key, ivec);
1012         memset(ivec, 0, sizeof(ivec));
1013         memset(key, 0, sizeof(key));
1014 }
1015
1016 static void
1017 des3espinit(Espcb *ecb, char *name, uchar *k, unsigned n)
1018 {
1019         uchar key[3][Desblk], ivec[Desblk];
1020
1021         n = BITS2BYTES(n);
1022         if(n > Des3keysz)
1023                 n = Des3keysz;
1024         memset(key, 0, sizeof(key));
1025         memmove(key, k, n);
1026         prng(ivec, Desblk);
1027         ecb->espalg = name;
1028         ecb->espblklen = Desblk;
1029         ecb->espivlen = Desblk;
1030
1031         ecb->cipher = des3cipher;
1032         ecb->espstate = secalloc(sizeof(DES3state));
1033         setupDES3state(ecb->espstate, key, ivec);
1034         memset(ivec, 0, sizeof(ivec));
1035         memset(key, 0, sizeof(key));
1036 }
1037
1038
1039 /*
1040  * interfacing to devip
1041  */
1042 void
1043 espinit(Fs *fs)
1044 {
1045         Proto *esp;
1046
1047         esp = smalloc(sizeof(Proto));
1048         esp->priv = smalloc(sizeof(Esppriv));
1049         esp->name = "esp";
1050         esp->connect = espconnect;
1051         esp->announce = nil;
1052         esp->ctl = espctl;
1053         esp->state = espstate;
1054         esp->create = espcreate;
1055         esp->close = espclose;
1056         esp->rcv = espiput;
1057         esp->advise = espadvise;
1058         esp->stats = espstats;
1059         esp->local = esplocal;
1060         esp->remote = espremote;
1061         esp->ipproto = IP_ESPPROTO;
1062         esp->nc = Nchans;
1063         esp->ptclsize = sizeof(Espcb);
1064
1065         Fsproto(fs, esp);
1066 }