diff options
author | cinap_lenrek <cinap_lenrek@felloff.net> | 2015-12-25 17:05:05 +0100 |
---|---|---|
committer | cinap_lenrek <cinap_lenrek@felloff.net> | 2015-12-25 17:05:05 +0100 |
commit | 39f18c9d88f52a22373790dec5721fa3521d3f00 (patch) | |
tree | bfca6b71de83fa380f3b581c7e9726d955218d08 /sys | |
parent | 4a6ab355c1af789f7ddb4edbf4d82d17efd9d2bf (diff) |
libsec: implement TLS-PSK for tlsClient()/tlsServer()
Diffstat (limited to 'sys')
-rw-r--r-- | sys/include/libsec.h | 3 | ||||
-rw-r--r-- | sys/man/2/pushtls | 4 | ||||
-rw-r--r-- | sys/src/libsec/port/tlshand.c | 392 |
3 files changed, 301 insertions, 98 deletions
diff --git a/sys/include/libsec.h b/sys/include/libsec.h index 0b3ba44ac..86cb342e5 100644 --- a/sys/include/libsec.h +++ b/sys/include/libsec.h @@ -412,8 +412,10 @@ typedef struct TLSconn{ char dir[40]; /* connection directory */ uchar *cert; /* certificate (local on input, remote on output) */ uchar *sessionID; + uchar *psk; int certlen; int sessionIDlen; + int psklen; int (*trace)(char*fmt, ...); PEMChain*chain; /* optional extra certificate evidence for servers to present */ char *sessionType; @@ -421,6 +423,7 @@ typedef struct TLSconn{ int sessionKeylen; char *sessionConst; char *serverName; + char *pskID; } TLSconn; /* tlshand.c */ diff --git a/sys/man/2/pushtls b/sys/man/2/pushtls index dfa01e4dd..ed3f293dd 100644 --- a/sys/man/2/pushtls +++ b/sys/man/2/pushtls @@ -100,7 +100,8 @@ typedef struct TLSconn { char dir[40]; /* OUT connection directory */ uchar *cert; /* IN/OUT certificate */ uchar *sessionID; /* IN/OUT session ID */ - int certlen, sessionIDlen; + uchar *psk; /* opt IN pre-shared key */ + int certlen, sessionIDlen, psklen; int (*trace)(char*fmt, ...); PEMChain *chain; char *sessionType; /* opt IN session type */ @@ -108,6 +109,7 @@ typedef struct TLSconn { int sessionKeylen; /* opt IN session key length */ char *sessionConst; /* opt IN session constant */ char *serverName; /* opt IN server name */ + char *pskID; /* opt IN pre-shared key ID */ } TLSconn; .EE .PP diff --git a/sys/src/libsec/port/tlshand.c b/sys/src/libsec/port/tlshand.c index c2599b999..1a37ef94b 100644 --- a/sys/src/libsec/port/tlshand.c +++ b/sys/src/libsec/port/tlshand.c @@ -135,9 +135,11 @@ typedef struct Msg{ Bytes **cas; } certificateRequest; struct { + Bytes *pskid; Bytes *key; } clientKeyExchange; struct { + Bytes *pskid; Bytes *dh_p; Bytes *dh_g; Bytes *dh_Ys; @@ -159,6 +161,8 @@ typedef struct TlsSec{ int ok; // <0 killed; == 0 in progress; >0 reusable RSApub *rsapub; AuthRpc *rpc; // factotum for rsa private key + uchar *psk; // pre-shared key + int psklen; uchar sec[MasterSecretSize]; // master secret uchar crandom[RandomSize]; // client random uchar srandom[RandomSize]; // server random @@ -223,6 +227,7 @@ enum { EInternalError = 80, EUserCanceled = 90, ENoRenegotiation = 100, + EUnknownPSKidentity = 115, EMax = 256 }; @@ -274,6 +279,16 @@ enum { TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = 0XC013, TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = 0XC014, TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = 0xC027, + + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305 = 0xCCA8, + TLS_DHE_RSA_WITH_CHACHA20_POLY1305 = 0xCCAA, + + GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305 = 0xCC13, + GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305 = 0xCC15, + + TLS_PSK_WITH_CHACHA20_POLY1305 = 0xCCAB, + TLS_PSK_WITH_AES_128_CBC_SHA256 = 0x00AE, + TLS_PSK_WITH_AES_128_CBC_SHA = 0x008C, }; // compression methods @@ -283,10 +298,12 @@ enum { }; static Algs cipherAlgs[] = { - {"ccpoly96_aead", "clear", 2*(32+12), 0xCCA8}, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (IETF) - {"ccpoly96_aead", "clear", 2*(32+12), 0xCCAA}, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (IETF) - {"ccpoly64_aead", "clear", 2*32, 0xCC13}, // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (draft) - {"ccpoly64_aead", "clear", 2*32, 0xCC15}, // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 (draft) + {"ccpoly96_aead", "clear", 2*(32+12), TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305}, + {"ccpoly96_aead", "clear", 2*(32+12), TLS_DHE_RSA_WITH_CHACHA20_POLY1305}, + + {"ccpoly64_aead", "clear", 2*32, GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305}, + {"ccpoly64_aead", "clear", 2*32, GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305}, + {"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256}, {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}, {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA}, @@ -299,6 +316,11 @@ static Algs cipherAlgs[] = { {"aes_256_cbc", "sha1", 2*(32+16+SHA1dlen), TLS_RSA_WITH_AES_256_CBC_SHA}, {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA}, {"3des_ede_cbc","sha1", 2*(4*8+SHA1dlen), TLS_RSA_WITH_3DES_EDE_CBC_SHA}, + + // PSK cipher suits + {"ccpoly96_aead", "clear", 2*(32+12), TLS_PSK_WITH_CHACHA20_POLY1305}, + {"aes_128_cbc", "sha256", 2*(16+16+SHA2_256dlen), TLS_PSK_WITH_AES_128_CBC_SHA256}, + {"aes_128_cbc", "sha1", 2*(16+16+SHA1dlen), TLS_PSK_WITH_AES_128_CBC_SHA}, }; static uchar compressors[] = { @@ -327,8 +349,15 @@ static int sigalgs[] = { 0x0201, /* SHA1 RSA */ }; -static TlsConnection *tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chain); -static TlsConnection *tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, int (*trace)(char*fmt, ...)); +static TlsConnection *tlsServer2(int ctl, int hand, + uchar *cert, int certlen, + char *pskid, uchar *psk, int psklen, + int (*trace)(char*fmt, ...), PEMChain *chain); +static TlsConnection *tlsClient2(int ctl, int hand, + uchar *csid, int ncsid, + uchar *cert, int certlen, + char *pskid, uchar *psk, int psklen, + uchar *ext, int extlen, int (*trace)(char*fmt, ...)); static void msgClear(Msg *m); static char* msgPrint(char *buf, int n, Msg *m); static int msgRecv(TlsConnection *c, Msg *m); @@ -340,15 +369,17 @@ static int finishedMatch(TlsConnection *c, Finished *f); static void tlsConnectionFree(TlsConnection *c); static int setAlgs(TlsConnection *c, int a); -static int okCipher(Ints *cv); +static int okCipher(Ints *cv, int ispsk); static int okCompression(Bytes *cv); static int initCiphers(void); -static Ints* makeciphers(void); +static Ints* makeciphers(int ispsk); static TlsSec* tlsSecInits(int cvers, uchar *csid, int ncsid, uchar *crandom, uchar *ssid, int *nssid, uchar *srandom); static int tlsSecRSAs(TlsSec *sec, int vers, Bytes *epm); +static int tlsSecPSKs(TlsSec *sec, int vers); static TlsSec* tlsSecInitc(int cvers, uchar *crandom); static Bytes* tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers); +static int tlsSecPSKc(TlsSec *sec, uchar *srandom, int vers); static Bytes* tlsSecDHEc(TlsSec *sec, uchar *srandom, int vers, Bytes *p, Bytes *g, Bytes *Ys); static Bytes* tlsSecECDHEc(TlsSec *sec, uchar *srandom, int vers, int curve, Bytes *Ys); static int tlsSecFinished(TlsSec *sec, HandshakeHash hsh, uchar *fin, int nfin, int isclient); @@ -424,7 +455,10 @@ tlsServer(int fd, TLSconn *conn) return -1; } fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion); - tls = tlsServer2(ctl, hand, conn->cert, conn->certlen, conn->trace, conn->chain); + tls = tlsServer2(ctl, hand, + conn->cert, conn->certlen, + conn->pskID, conn->psk, conn->psklen, + conn->trace, conn->chain); snprint(dname, sizeof(dname), "#a/tls/%s/data", buf); data = open(dname, ORDWR); close(hand); @@ -435,7 +469,7 @@ tlsServer(int fd, TLSconn *conn) return -1; } free(conn->cert); - conn->cert = 0; // client certificates are not yet implemented + conn->cert = nil; // client certificates are not yet implemented conn->certlen = 0; conn->sessionIDlen = tls->sid->len; conn->sessionID = emalloc(conn->sessionIDlen); @@ -561,7 +595,10 @@ tlsClient(int fd, TLSconn *conn) } fprint(ctl, "fd %d 0x%x", fd, ProtocolVersion); ext = tlsClientExtensions(conn, &n); - tls = tlsClient2(ctl, hand, conn->sessionID, conn->sessionIDlen, conn->cert, conn->certlen, + tls = tlsClient2(ctl, hand, + conn->sessionID, conn->sessionIDlen, + conn->cert, conn->certlen, + conn->pskID, conn->psk, conn->psklen, ext, n, conn->trace); free(ext); close(hand); @@ -570,9 +607,14 @@ tlsClient(int fd, TLSconn *conn) close(data); return -1; } - conn->certlen = tls->cert->len; - conn->cert = emalloc(conn->certlen); - memcpy(conn->cert, tls->cert->data, conn->certlen); + if(tls->cert != nil){ + conn->certlen = tls->cert->len; + conn->cert = emalloc(conn->certlen); + memcpy(conn->cert, tls->cert->data, conn->certlen); + } else { + conn->certlen = 0; + conn->cert = nil; + } conn->sessionIDlen = tls->sid->len; conn->sessionID = emalloc(conn->sessionIDlen); memcpy(conn->sessionID, tls->sid->data, conn->sessionIDlen); @@ -603,7 +645,10 @@ countchain(PEMChain *p) } static TlsConnection * -tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, ...), PEMChain *chp) +tlsServer2(int ctl, int hand, + uchar *cert, int certlen, + char *pskid, uchar *psk, int psklen, + int (*trace)(char*fmt, ...), PEMChain *chp) { TlsConnection *c; Msg m; @@ -641,7 +686,7 @@ tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, . } memmove(c->crandom, m.u.clientHello.random, RandomSize); - cipher = okCipher(m.u.clientHello.ciphers); + cipher = okCipher(m.u.clientHello.ciphers, psklen > 0); if(cipher < 0) { // reply with EInsufficientSecurity if we know that's the case if(cipher == -2) @@ -662,21 +707,27 @@ tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, . csid = m.u.clientHello.sid; if(trace) - trace(" cipher %d, compressor %d, csidlen %d\n", cipher, compressor, csid->len); + trace(" cipher %x, compressor %x, csidlen %d\n", cipher, compressor, csid->len); c->sec = tlsSecInits(c->clientVersion, csid->data, csid->len, c->crandom, sid, &nsid, c->srandom); if(c->sec == nil){ tlsError(c, EHandshakeFailure, "can't initialize security: %r"); goto Err; } - c->sec->rpc = factotum_rsa_open(cert, certlen); - if(c->sec->rpc == nil){ - tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r"); - goto Err; + if(psklen > 0){ + c->sec->psk = psk; + c->sec->psklen = psklen; } - c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0); - if(c->sec->rsapub == nil){ - tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate"); - goto Err; + if(certlen > 0){ + c->sec->rpc = factotum_rsa_open(cert, certlen); + if(c->sec->rpc == nil){ + tlsError(c, EHandshakeFailure, "factotum_rsa_open: %r"); + goto Err; + } + c->sec->rsapub = X509toRSApub(cert, certlen, nil, 0); + if(c->sec->rsapub == nil){ + tlsError(c, EHandshakeFailure, "invalid X509/rsa certificate"); + goto Err; + } } msgClear(&m); @@ -691,16 +742,18 @@ tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, . goto Err; msgClear(&m); - m.tag = HCertificate; - numcerts = countchain(chp); - m.u.certificate.ncert = 1 + numcerts; - m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*)); - m.u.certificate.certs[0] = makebytes(cert, certlen); - for (i = 0; i < numcerts && chp; i++, chp = chp->next) - m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen); - if(!msgSend(c, &m, AQueue)) - goto Err; - msgClear(&m); + if(certlen > 0){ + m.tag = HCertificate; + numcerts = countchain(chp); + m.u.certificate.ncert = 1 + numcerts; + m.u.certificate.certs = emalloc(m.u.certificate.ncert * sizeof(Bytes*)); + m.u.certificate.certs[0] = makebytes(cert, certlen); + for (i = 0; i < numcerts && chp; i++, chp = chp->next) + m.u.certificate.certs[i+1] = makebytes(chp->pem, chp->pemlen); + if(!msgSend(c, &m, AQueue)) + goto Err; + msgClear(&m); + } m.tag = HServerHelloDone; if(!msgSend(c, &m, AFlush)) @@ -713,10 +766,29 @@ tlsServer2(int ctl, int hand, uchar *cert, int certlen, int (*trace)(char*fmt, . tlsError(c, EUnexpectedMessage, "expected a client key exchange"); goto Err; } - if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){ - tlsError(c, EHandshakeFailure, "couldn't set secrets: %r"); + if(pskid != nil){ + if(m.u.clientKeyExchange.pskid == nil + || m.u.clientKeyExchange.pskid->len != strlen(pskid) + || memcmp(pskid, m.u.clientKeyExchange.pskid->data, m.u.clientKeyExchange.pskid->len) != 0){ + tlsError(c, EUnknownPSKidentity, "unknown or missing pskid"); + goto Err; + } + } + if(certlen > 0){ + if(tlsSecRSAs(c->sec, c->version, m.u.clientKeyExchange.key) < 0){ + tlsError(c, EHandshakeFailure, "couldn't set secrets: %r"); + goto Err; + } + } else if(psklen > 0){ + if(tlsSecPSKs(c->sec, c->version) < 0){ + tlsError(c, EHandshakeFailure, "couldn't set secrets: %r"); + goto Err; + } + } else { + tlsError(c, EInternalError, "no psk or certificate"); goto Err; } + setSecrets(c->sec, kd, c->nsecret); if(trace) trace("tls secrets\n"); @@ -786,7 +858,8 @@ isDHE(int tlsid) case TLS_DHE_RSA_WITH_AES_128_CBC_SHA: case TLS_DHE_RSA_WITH_AES_256_CBC_SHA: case TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA: - case 0xCCAA: case 0xCC15: // TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + case TLS_DHE_RSA_WITH_CHACHA20_POLY1305: + case GOOGLE_DHE_RSA_WITH_CHACHA20_POLY1305: return 1; } return 0; @@ -799,7 +872,20 @@ isECDHE(int tlsid) case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256: case TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA: case TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: - case 0xCCA8: case 0xCC13: // TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + case TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305: + case GOOGLE_ECDHE_RSA_WITH_CHACHA20_POLY1305: + return 1; + } + return 0; +} + +static int +isPSK(int tlsid) +{ + switch(tlsid){ + case TLS_PSK_WITH_CHACHA20_POLY1305: + case TLS_PSK_WITH_AES_128_CBC_SHA256: + case TLS_PSK_WITH_AES_128_CBC_SHA: return 1; } return 0; @@ -980,8 +1066,18 @@ verifyDHparams(TlsConnection *c, Bytes *par, Bytes *sig, int sigalg) RSApub *pk; char *err; - if(sig == nil || sig->len <= 0) + if(par == nil || par->len <= 0) + return "no dh parameters"; + + if(sig == nil || sig->len <= 0){ + if(c->sec->psklen > 0) + return nil; + return "no signature"; + } + + if(c->cert == nil) + return "no certificate"; pk = X509toRSApub(c->cert->data, c->cert->len, nil, 0); if(pk == nil) @@ -1015,7 +1111,11 @@ verifyDHparams(TlsConnection *c, Bytes *par, Bytes *sig, int sigalg) } static TlsConnection * -tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, uchar *ext, int extlen, +tlsClient2(int ctl, int hand, + uchar *csid, int ncsid, + uchar *cert, int certlen, + char *pskid, uchar *psk, int psklen, + uchar *ext, int extlen, int (*trace)(char*fmt, ...)) { TlsConnection *c; @@ -1036,17 +1136,24 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, c->trace = trace; c->isClient = 1; c->clientVersion = c->version; + c->cert = nil; c->sec = tlsSecInitc(c->clientVersion, c->crandom); if(c->sec == nil) goto Err; + + if(psklen > 0){ + c->sec->psk = psk; + c->sec->psklen = psklen; + } + /* client hello */ memset(&m, 0, sizeof(m)); m.tag = HClientHello; m.u.clientHello.version = c->clientVersion; memmove(m.u.clientHello.random, c->crandom, RandomSize); m.u.clientHello.sid = makebytes(csid, ncsid); - m.u.clientHello.ciphers = makeciphers(); + m.u.clientHello.ciphers = makeciphers(psklen > 0); m.u.clientHello.compressors = makebytes(compressors,sizeof(compressors)); m.u.clientHello.extensions = makebytes(ext, extlen); if(!msgSend(c, &m, AFlush)) @@ -1071,7 +1178,7 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, goto Err; } cipher = m.u.serverHello.cipher; - if(!setAlgs(c, cipher)) { + if((psklen > 0) != isPSK(cipher) || !setAlgs(c, cipher)) { tlsError(c, EIllegalParameter, "invalid cipher suite"); goto Err; } @@ -1081,48 +1188,47 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, } msgClear(&m); - /* certificate */ - if(!msgRecv(c, &m) || m.tag != HCertificate) { - tlsError(c, EUnexpectedMessage, "expected a certificate"); - goto Err; - } - if(m.u.certificate.ncert < 1) { - tlsError(c, EIllegalParameter, "runt certificate"); - goto Err; - } - c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len); - msgClear(&m); - - /* server key exchange */ dhx = isDHE(cipher) || isECDHE(cipher); if(!msgRecv(c, &m)) goto Err; - if(m.tag == HServerKeyExchange) { - char *err; - - if(!dhx){ - tlsError(c, EUnexpectedMessage, "got an server key exchange"); + if(m.tag == HCertificate){ + if(m.u.certificate.ncert < 1) { + tlsError(c, EIllegalParameter, "runt certificate"); goto Err; } - err = verifyDHparams(c, - m.u.serverKeyExchange.dh_parameters, - m.u.serverKeyExchange.dh_signature, - m.u.serverKeyExchange.sigalg); - if(err != nil){ - tlsError(c, EBadCertificate, "can't verify dh parameters: %s", err); + c->cert = makebytes(m.u.certificate.certs[0]->data, m.u.certificate.certs[0]->len); + msgClear(&m); + if(!msgRecv(c, &m)) + goto Err; + } else if(psklen == 0) { + tlsError(c, EUnexpectedMessage, "expected a certificate"); + goto Err; + } + if(m.tag == HServerKeyExchange) { + if(dhx){ + char *err = verifyDHparams(c, + m.u.serverKeyExchange.dh_parameters, + m.u.serverKeyExchange.dh_signature, + m.u.serverKeyExchange.sigalg); + if(err != nil){ + tlsError(c, EBadCertificate, "can't verify dh parameters: %s", err); + goto Err; + } + if(isECDHE(cipher)) + epm = tlsSecECDHEc(c->sec, c->srandom, c->version, + m.u.serverKeyExchange.curve, + m.u.serverKeyExchange.dh_Ys); + else + epm = tlsSecDHEc(c->sec, c->srandom, c->version, + m.u.serverKeyExchange.dh_p, + m.u.serverKeyExchange.dh_g, + m.u.serverKeyExchange.dh_Ys); + if(epm == nil) + goto Badcert; + } else if(psklen == 0){ + tlsError(c, EUnexpectedMessage, "got an server key exchange"); goto Err; } - if(isECDHE(cipher)) - epm = tlsSecECDHEc(c->sec, c->srandom, c->version, - m.u.serverKeyExchange.curve, - m.u.serverKeyExchange.dh_Ys); - else - epm = tlsSecDHEc(c->sec, c->srandom, c->version, - m.u.serverKeyExchange.dh_p, - m.u.serverKeyExchange.dh_g, - m.u.serverKeyExchange.dh_Ys); - if(epm == nil) - goto Badcert; msgClear(&m); if(!msgRecv(c, &m)) goto Err; @@ -1146,14 +1252,22 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, } msgClear(&m); - if(!dhx) - epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom, - c->cert->data, c->cert->len, c->version); - - if(epm == nil){ - Badcert: - tlsError(c, EBadCertificate, "bad certificate: %r"); - goto Err; + if(!dhx){ + if(c->cert != nil){ + epm = tlsSecRSAc(c->sec, c->sid->data, c->sid->len, c->srandom, + c->cert->data, c->cert->len, c->version); + if(epm == nil){ + Badcert: + tlsError(c, EBadCertificate, "bad certificate: %r"); + goto Err; + } + } else if(psklen > 0) { + if(tlsSecPSKc(c->sec, c->srandom, c->version) < 0) + goto Badcert; + } else { + tlsError(c, EInternalError, "no psk or certificate"); + goto Err; + } } setSecrets(c->sec, kd, c->nsecret); @@ -1182,12 +1296,13 @@ tlsClient2(int ctl, int hand, uchar *csid, int ncsid, uchar *cert, int certlen, /* client key exchange */ m.tag = HClientKeyExchange; + if(psklen > 0){ + if(pskid == nil) + pskid = ""; + m.u.clientKeyExchange.pskid = makebytes((uchar*)pskid, strlen(pskid)); + } m.u.clientKeyExchange.key = epm; epm = nil; - if(m.u.clientKeyExchange.key == nil) { - tlsError(c, EHandshakeFailure, "can't set secret: %r"); - goto Err; - } if(!msgSend(c, &m, AFlush)) goto Err; @@ -1423,8 +1538,17 @@ msgSend(TlsConnection *c, Msg *m, int act) p += 2; memmove(p, m->u.certificateVerify.signature->data, m->u.certificateVerify.signature->len); p += m->u.certificateVerify.signature->len; - break; + break; case HClientKeyExchange: + if(m->u.clientKeyExchange.pskid != nil){ + n = m->u.clientKeyExchange.pskid->len; + put16(p, n); + p += 2; + memmove(p, m->u.clientKeyExchange.pskid->data, n); + p += n; + } + if(m->u.clientKeyExchange.key == nil) + break; n = m->u.clientKeyExchange.key->len; if(c->version != SSL3Version){ if(isECDHE(c->cipher)) @@ -1737,6 +1861,18 @@ msgRecv(TlsConnection *c, Msg *m) case HServerHelloDone: break; case HServerKeyExchange: + if(isPSK(c->cipher)){ + if(n < 2) + goto Short; + nn = get16(p); + p += 2, n -= 2; + if(nn > n) + goto Short; + m->u.serverKeyExchange.pskid = makebytes(p, nn); + p += nn, n -= nn; + if(n == 0) + break; + } if(n < 2) goto Short; s = p; @@ -1805,6 +1941,18 @@ msgRecv(TlsConnection *c, Msg *m) * this message depends upon the encryption selected * assume rsa. */ + if(isPSK(c->cipher)){ + if(n < 2) + goto Short; + nn = get16(p); + p += 2, n -= 2; + if(nn > n) + goto Short; + m->u.clientKeyExchange.pskid = makebytes(p, nn); + p += nn, n -= nn; + if(n == 0) + break; + } if(c->version == SSL3Version) nn = n; else{ @@ -1883,6 +2031,7 @@ msgClear(Msg *m) case HServerHelloDone: break; case HServerKeyExchange: + freebytes(m->u.serverKeyExchange.pskid); freebytes(m->u.serverKeyExchange.dh_p); freebytes(m->u.serverKeyExchange.dh_g); freebytes(m->u.serverKeyExchange.dh_Ys); @@ -1890,6 +2039,7 @@ msgClear(Msg *m) freebytes(m->u.serverKeyExchange.dh_signature); break; case HClientKeyExchange: + freebytes(m->u.clientKeyExchange.pskid); freebytes(m->u.clientKeyExchange.key); break; case HFinished: @@ -1998,6 +2148,10 @@ msgPrint(char *buf, int n, Msg *m) break; case HServerKeyExchange: bs = seprint(bs, be, "HServerKeyExchange\n"); + if(m->u.serverKeyExchange.pskid != nil) + bs = bytesPrint(bs, be, "\tpskid: ", m->u.serverKeyExchange.pskid, "\n"); + if(m->u.serverKeyExchange.dh_parameters == nil) + break; if(m->u.serverKeyExchange.curve != 0){ bs = seprint(bs, be, "\tcurve: %.4x\n", m->u.serverKeyExchange.curve); } else { @@ -2012,7 +2166,10 @@ msgPrint(char *buf, int n, Msg *m) break; case HClientKeyExchange: bs = seprint(bs, be, "HClientKeyExchange\n"); - bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n"); + if(m->u.clientKeyExchange.pskid != nil) + bs = bytesPrint(bs, be, "\tpskid: ", m->u.clientKeyExchange.pskid, "\n"); + if(m->u.clientKeyExchange.key != nil) + bs = bytesPrint(bs, be, "\tkey: ", m->u.clientKeyExchange.key, "\n"); break; case HFinished: bs = seprint(bs, be, "HFinished\n"); @@ -2137,7 +2294,7 @@ setAlgs(TlsConnection *c, int a) } static int -okCipher(Ints *cv) +okCipher(Ints *cv, int ispsk) { int weak, i, j, c; @@ -2148,6 +2305,8 @@ okCipher(Ints *cv) weak = 0; else weak &= weakCipher[c]; + if(isPSK(c) != ispsk) + continue; if(isDHE(c) || isECDHE(c)) continue; /* TODO: not implemented for server */ for(j = 0; j < nelem(cipherAlgs); j++) @@ -2243,17 +2402,17 @@ initCiphers(void) } static Ints* -makeciphers(void) +makeciphers(int ispsk) { Ints *is; int i, j; is = newints(nciphers); j = 0; - for(i = 0; i < nelem(cipherAlgs); i++){ - if(cipherAlgs[i].ok) + for(i = 0; i < nelem(cipherAlgs); i++) + if(cipherAlgs[i].ok && isPSK(cipherAlgs[i].tlsid) == ispsk) is->data[j++] = cipherAlgs[i].tlsid; - } + is->len = j; return is; } @@ -2489,6 +2648,17 @@ Err: return -1; } +static int +tlsSecPSKs(TlsSec *sec, int vers) +{ + if(setVers(sec, vers) < 0){ + sec->ok = -1; + return -1; + } + setMasterSecret(sec, newbytes(sec->psklen)); + return 0; +} + static TlsSec* tlsSecInitc(int cvers, uchar *crandom) { @@ -2500,6 +2670,18 @@ tlsSecInitc(int cvers, uchar *crandom) return sec; } +static int +tlsSecPSKc(TlsSec *sec, uchar *srandom, int vers) +{ + memmove(sec->srandom, srandom, RandomSize); + if(setVers(sec, vers) < 0){ + sec->ok = -1; + return -1; + } + setMasterSecret(sec, newbytes(sec->psklen)); + return 0; +} + static Bytes* tlsSecRSAc(TlsSec *sec, uchar *sid, int nsid, uchar *srandom, uchar *cert, int ncert, int vers) { @@ -2608,6 +2790,22 @@ setSecrets(TlsSec *sec, uchar *kd, int nkd) static void setMasterSecret(TlsSec *sec, Bytes *pm) { + if(sec->psklen > 0){ + Bytes *opm = pm; + uchar *p; + + /* concatenate psk to pre-master secret */ + pm = newbytes(4 + opm->len + sec->psklen); + p = pm->data; + put16(p, opm->len), p += 2; + memmove(p, opm->data, opm->len), p += opm->len; + put16(p, sec->psklen), p += 2; + memmove(p, sec->psk, sec->psklen); + + memset(opm->data, 0, opm->len); + freebytes(opm); + } + (*sec->prf)(sec->sec, MasterSecretSize, pm->data, pm->len, "master secret", sec->crandom, RandomSize, sec->srandom, RandomSize); |