1
2
3
4
5
6
7 package tls
8
9 import (
10 "bytes"
11 "context"
12 "crypto/cipher"
13 "crypto/subtle"
14 "crypto/x509"
15 "errors"
16 "fmt"
17 "hash"
18 "internal/godebug"
19 "io"
20 "net"
21 "sync"
22 "sync/atomic"
23 "time"
24 )
25
26
27
// Conn is a TLS connection layered over an underlying net.Conn. The methods
// in this file implement its record layer: reading, decrypting and
// dispatching incoming records, and fragmenting, encrypting and writing
// outgoing ones.
type Conn struct {
	// conn is the underlying transport that records are read from and
	// written to.
	conn     net.Conn
	isClient bool
	// handshakeFn runs the handshake; handshakeContext calls it while
	// holding handshakeMutex and c.in.
	handshakeFn func(context.Context) error
	// quic is non-nil when a QUIC transport carries the handshake instead of
	// TLS records on conn; several record-layer paths reject or reroute
	// traffic in that mode.
	quic *quicState

	// isHandshakeComplete reports whether a handshake has succeeded. It is
	// read without locks (for example as the fast path of handshakeContext)
	// and cleared by handleRenegotiation before a new handshake.
	isHandshakeComplete atomic.Bool

	// handshakeMutex serializes handshakes (handshakeContext,
	// handleRenegotiation). ConnectionState also takes it, presumably so the
	// handshake state below is read consistently.
	handshakeMutex sync.Mutex
	handshakeErr   error   // sticky result of the last handshake attempt
	vers           uint16  // negotiated protocol version; 0 before one is chosen
	haveVers       bool    // whether vers has been negotiated yet
	config         *Config // settings consulted here: Renegotiation, DynamicRecordSizingDisabled, rand

	// handshakes counts successful handshakes, including renegotiations.
	handshakes int
	// The fields below (through clientProtocol) hold handshake results and
	// session state. Apart from cipherSuite (looked up with
	// cipherSuiteTLS13ByID in handleKeyUpdate) and closeNotifyErr /
	// closeNotifySent, they are not used in this part of the file;
	// presumably the handshake code fills them in and connectionStateLocked
	// reports them — confirm against those files.
	extMasterSecret  bool
	didResume        bool
	didHRR           bool
	cipherSuite      uint16
	curveID          CurveID
	peerSigAlg       SignatureScheme
	ocspResponse     []byte
	scts             [][]byte
	peerCertificates []*x509.Certificate

	verifiedChains [][]*x509.Certificate

	serverName string

	secureRenegotiation bool

	// ekm exports keying material (signature: label, context, length).
	ekm func(label string, context []byte, length int) ([]byte, error)

	resumptionSecret []byte
	echAccepted      bool

	ticketKeys []ticketKey

	clientFinishedIsFirst bool

	// closeNotifyErr caches the result of sending close_notify.
	closeNotifyErr error

	// closeNotifySent is set once closeNotify has run; Write then fails
	// with errShutdown.
	closeNotifySent bool

	clientFinished [12]byte
	serverFinished [12]byte

	clientProtocol string

	// in and out hold the record-layer state for reading and writing; each
	// has its own mutex.
	in, out  halfConn
	rawInput bytes.Buffer // raw bytes read from conn, not yet consumed as records
	input    bytes.Reader // decrypted application data awaiting Read; aliases rawInput's memory
	hand     bytes.Buffer // buffered handshake-message bytes awaiting processing
	// buffering makes write append to sendBuf instead of writing to conn;
	// flush sends sendBuf and clears buffering.
	buffering bool
	sendBuf   []byte

	// bytesSent and packetsSent drive dynamic record sizing in
	// maxPayloadSizeForWrite.
	bytesSent   int64
	packetsSent int64

	// retryCount counts consecutive records or post-handshake messages that
	// did not advance the connection. readRecordOrCCS resets it;
	// retryReadRecord and handlePostHandshakeMessage cap it at
	// maxUselessRecords.
	retryCount int

	// activeCall interlocks Write with Close: bit 0 is set by Close, and
	// each in-flight Write adds 2.
	activeCall atomic.Int32

	// tmp is scratch space; sendAlertLocked builds alert records in it.
	tmp [16]byte
}
125
126
127
128
129
130
131 func (c *Conn) LocalAddr() net.Addr {
132 return c.conn.LocalAddr()
133 }
134
135
136 func (c *Conn) RemoteAddr() net.Addr {
137 return c.conn.RemoteAddr()
138 }
139
140
141
142
143 func (c *Conn) SetDeadline(t time.Time) error {
144 return c.conn.SetDeadline(t)
145 }
146
147
148
149 func (c *Conn) SetReadDeadline(t time.Time) error {
150 return c.conn.SetReadDeadline(t)
151 }
152
153
154
155
156 func (c *Conn) SetWriteDeadline(t time.Time) error {
157 return c.conn.SetWriteDeadline(t)
158 }
159
160
161
162
163 func (c *Conn) NetConn() net.Conn {
164 return c.conn
165 }
166
167
168
// halfConn holds the record-layer state for one direction of a Conn:
// c.in for reading and c.out for writing.
type halfConn struct {
	// The embedded mutex is held by Conn methods (Read, Write, sendAlert,
	// handshakeContext, ...) while they use this half of the connection.
	sync.Mutex

	err     error     // sticky error, recorded through setErrorLocked
	version uint16    // protocol version, set by prepareCipherSpec
	cipher  any       // active cipher: nil, cipher.Stream, aead or cbcMode
	mac     hash.Hash // MAC for non-AEAD suites; nil when no separate MAC is used
	seq     [8]byte   // 64-bit big-endian record sequence number (see incSeq)

	// scratchBuf lets AEAD additional data and MAC input be built without
	// allocating: 8 bytes of sequence number plus 5 header-derived bytes.
	scratchBuf [13]byte

	nextCipher any       // pending cipher, activated by changeCipherSpec
	nextMac    hash.Hash // pending MAC, activated by changeCipherSpec

	level         QUICEncryptionLevel // encryption level recorded by setTrafficSecret
	trafficSecret []byte              // current TLS 1.3 traffic secret (setTrafficSecret, handleKeyUpdate)
}
186
// permanentError wraps a net.Error so that it never reports itself as
// temporary, while still exposing the wrapped error's message, its timeout
// status and (through Unwrap) its identity.
type permanentError struct {
	err net.Error
}

// Error returns the wrapped error's message unchanged.
func (e *permanentError) Error() string {
	return e.err.Error()
}

// Unwrap exposes the wrapped error to errors.Is and errors.As.
func (e *permanentError) Unwrap() error {
	return e.err
}

// Timeout reports whether the wrapped error was a timeout.
func (e *permanentError) Timeout() bool {
	return e.err.Timeout()
}

// Temporary always reports false: once recorded, the error is sticky.
func (e *permanentError) Temporary() bool {
	return false
}
195
196 func (hc *halfConn) setErrorLocked(err error) error {
197 if e, ok := err.(net.Error); ok {
198 hc.err = &permanentError{err: e}
199 } else {
200 hc.err = err
201 }
202 return hc.err
203 }
204
205
206
207 func (hc *halfConn) prepareCipherSpec(version uint16, cipher any, mac hash.Hash) {
208 hc.version = version
209 hc.nextCipher = cipher
210 hc.nextMac = mac
211 }
212
213
214
215 func (hc *halfConn) changeCipherSpec() error {
216 if hc.nextCipher == nil || hc.version == VersionTLS13 {
217 return alertInternalError
218 }
219 hc.cipher = hc.nextCipher
220 hc.mac = hc.nextMac
221 hc.nextCipher = nil
222 hc.nextMac = nil
223 clear(hc.seq[:])
224 return nil
225 }
226
227
228
229
230 func (hc *halfConn) setTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
231 hc.trafficSecret = secret
232 hc.level = level
233 key, iv := suite.trafficKey(secret)
234 hc.cipher = suite.aead(key, iv)
235 clear(hc.seq[:])
236 }
237
238
239 func (hc *halfConn) incSeq() {
240 for i := 7; i >= 0; i-- {
241 hc.seq[i]++
242 if hc.seq[i] != 0 {
243 return
244 }
245 }
246
247
248
249
250 panic("TLS: sequence number wraparound")
251 }
252
253
254
255
256 func (hc *halfConn) explicitNonceLen() int {
257 if hc.cipher == nil {
258 return 0
259 }
260
261 switch c := hc.cipher.(type) {
262 case cipher.Stream:
263 return 0
264 case aead:
265 return c.explicitNonceLen()
266 case cbcMode:
267
268 if hc.version >= VersionTLS11 {
269 return c.BlockSize()
270 }
271 return 0
272 default:
273 panic("unknown cipher type")
274 }
275 }
276
277
278
279
// extractPadding inspects the CBC padding at the end of a decrypted payload
// in constant time. It returns:
//   - toRemove, the number of trailing bytes to strip, including the
//     padding-length byte itself;
//   - good, which is 255 if the padding is well formed and 0 otherwise.
//
// It must not branch on secret data. Timing differences between good and
// bad padding would act as a padding oracle. On bad padding toRemove is 1
// (only the length byte). The caller then still runs its MAC check over a
// similar amount of data, and rejects the record by combining good with the
// MAC result.
func extractPadding(payload []byte) (toRemove int, good byte) {
	if len(payload) < 1 {
		return 0, 0
	}

	paddingLen := payload[len(payload)-1]
	t := uint(len(payload)-1) - uint(paddingLen)
	// good is 0xff when paddingLen <= len(payload)-1 (t does not wrap, so
	// bit 31 of ^t is set), and 0x00 when the claimed padding is too long.
	good = byte(int32(^t) >> 31)

	// At most 255 padding bytes plus the length byte itself.
	toCheck := 256
	// The payload length is public, so branching on it leaks nothing.
	if toCheck > len(payload) {
		toCheck = len(payload)
	}

	for i := 0; i < toCheck; i++ {
		t := uint(paddingLen) - uint(i)
		// mask is 0xff when i <= paddingLen (this byte is padding) and 0x00
		// otherwise.
		mask := byte(int32(^t) >> 31)
		b := payload[len(payload)-1-i]
		// Clear bits of good wherever a padding byte differs from paddingLen.
		good &^= mask&paddingLen ^ mask&b
	}

	// Collapse good so it stays 0xff only if all eight bits are still set:
	// AND the bits together into the top bit, then sign-extend that bit
	// across the byte.
	good &= good << 4
	good &= good << 2
	good &= good << 1
	good = uint8(int8(good) >> 7)

	// On bad padding, zero paddingLen so that only the length byte is
	// removed.
	paddingLen &= good

	toRemove = int(paddingLen) + 1
	return
}
326
// roundUp returns a rounded up to the nearest multiple of b, for b > 0.
// Negative values of a round towards zero, e.g. roundUp(-5, 4) == -4.
func roundUp(a, b int) int {
	shortfall := (b - a%b) % b
	return a + shortfall
}
330
331
// cbcMode is a cipher.BlockMode whose IV can be replaced. The record layer
// needs this to install the explicit per-record IV used from TLS 1.1 on.
type cbcMode interface {
	// SetIV replaces the chaining IV used by the next CryptBlocks call.
	SetIV(iv []byte)
	cipher.BlockMode
}
336
337
338
// decrypt authenticates and decrypts one record in place and returns its
// plaintext and content type.
//
// record must hold a complete record: the recordHeaderLen-byte header
// followed by the payload. The buffer is modified:
//   - the payload is decrypted in place, and the returned plaintext aliases
//     it;
//   - for MAC-based suites, the header's length bytes are rewritten to the
//     plaintext length.
//
// Errors are alert values (alertBadRecordMAC, alertUnexpectedMessage,
// alertRecordOverflow), which the caller sends to the peer. On success the
// read sequence number is incremented.
func (hc *halfConn) decrypt(record []byte) ([]byte, recordType, error) {
	var plaintext []byte
	typ := recordType(record[0])
	payload := record[recordHeaderLen:]

	// In TLS 1.3, ChangeCipherSpec records are passed through without
	// decryption or authentication, and without advancing the sequence
	// number; readRecordOrCCS then ignores them.
	if hc.version == VersionTLS13 && typ == recordTypeChangeCipherSpec {
		return payload, typ, nil
	}

	// Defaults for ciphers without CBC padding: no padding, and padding
	// counted as good.
	paddingGood := byte(255)
	paddingLen := 0

	explicitNonceLen := hc.explicitNonceLen()

	if hc.cipher != nil {
		switch c := hc.cipher.(type) {
		case cipher.Stream:
			// Decrypt in place; the MAC is checked below.
			c.XORKeyStream(payload, payload)
		case aead:
			if len(payload) < explicitNonceLen {
				return nil, 0, alertBadRecordMAC
			}
			nonce := payload[:explicitNonceLen]
			// Suites without an explicit nonce pass the sequence number as
			// the per-record nonce input.
			if len(nonce) == 0 {
				nonce = hc.seq[:]
			}
			payload = payload[explicitNonceLen:]

			// Additional data is the record header in TLS 1.3. Earlier
			// versions use seq || type || version || plaintext length.
			var additionalData []byte
			if hc.version == VersionTLS13 {
				additionalData = record[:recordHeaderLen]
			} else {
				additionalData = append(hc.scratchBuf[:0], hc.seq[:]...)
				additionalData = append(additionalData, record[:3]...)
				n := len(payload) - c.Overhead()
				additionalData = append(additionalData, byte(n>>8), byte(n))
			}

			var err error
			plaintext, err = c.Open(payload[:0], nonce, payload, additionalData)
			if err != nil {
				return nil, 0, alertBadRecordMAC
			}
		case cbcMode:
			blockSize := c.BlockSize()
			// The ciphertext must be whole blocks. It must also be large
			// enough for the explicit IV, the MAC and at least one padding
			// byte.
			minPayload := explicitNonceLen + roundUp(hc.mac.Size()+1, blockSize)
			if len(payload)%blockSize != 0 || len(payload) < minPayload {
				return nil, 0, alertBadRecordMAC
			}

			if explicitNonceLen > 0 {
				c.SetIV(payload[:explicitNonceLen])
				payload = payload[explicitNonceLen:]
			}
			c.CryptBlocks(payload, payload)

			// Padding is checked in constant time, and a bad result is only
			// acted on together with the MAC check below. Padding failures
			// and MAC failures therefore look the same, which defends
			// against CBC padding-oracle attacks such as Lucky13.
			//
			// The bytes after the MAC are also handed to tls10MAC below,
			// presumably so that the MAC work does not depend on the secret
			// padding length. tls10MAC is not shown here — confirm.
			paddingLen, paddingGood = extractPadding(payload)
		default:
			panic("unknown cipher type")
		}

		if hc.version == VersionTLS13 {
			// Protected TLS 1.3 records always carry the application_data
			// outer type; the real type is inside the plaintext.
			if typ != recordTypeApplicationData {
				return nil, 0, alertUnexpectedMessage
			}
			// +1 allows for the inner content-type byte.
			if len(plaintext) > maxPlaintext+1 {
				return nil, 0, alertRecordOverflow
			}
			// Strip zero padding and take the content type from the last
			// non-zero byte. A non-empty all-zero plaintext is rejected; an
			// empty one leaves typ as application_data.
			for i := len(plaintext) - 1; i >= 0; i-- {
				if plaintext[i] != 0 {
					typ = recordType(plaintext[i])
					plaintext = plaintext[:i]
					break
				}
				if i == 0 {
					return nil, 0, alertUnexpectedMessage
				}
			}
		}
	} else {
		plaintext = payload
	}

	if hc.mac != nil {
		macSize := hc.mac.Size()
		if len(payload) < macSize {
			return nil, 0, alertBadRecordMAC
		}

		// n is the plaintext length. If bad padding made it negative, clamp
		// it to 0 without branching.
		n := len(payload) - macSize - paddingLen
		n = subtle.ConstantTimeSelect(int(uint32(n)>>31), 0, n)
		// The MAC covers a header that carries the plaintext length, so
		// rewrite the length bytes before computing it.
		record[3] = byte(n >> 8)
		record[4] = byte(n)
		remoteMAC := payload[n : n+macSize]
		localMAC := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload[:n], payload[n+macSize:])

		// Check the MAC and the padding together, in constant time, so that
		// bad padding cannot be told apart from a bad MAC (see
		// extractPadding).
		macAndPaddingGood := subtle.ConstantTimeCompare(localMAC, remoteMAC) & int(paddingGood)
		if macAndPaddingGood != 1 {
			return nil, 0, alertBadRecordMAC
		}

		plaintext = payload[:n]
	}

	hc.incSeq()
	return plaintext, typ, nil
}
462
463
464
465
// sliceForAppend extends in by n bytes. It returns both the extended slice
// (head) and the newly added n-byte region (tail).
//
// If in has enough spare capacity, head shares in's backing array, and the
// bytes in tail are whatever was already there (they are not zeroed).
// Otherwise a new array of exactly the needed size is allocated and in's
// contents are copied into it.
func sliceForAppend(in []byte, n int) (head, tail []byte) {
	total := len(in) + n
	if cap(in) < total {
		head = make([]byte, total)
		copy(head, in)
	} else {
		head = in[:total]
	}
	tail = head[len(in):]
	return head, tail
}
476
477
478
// encrypt protects payload and appends the result to record. record must
// already hold the recordHeaderLen-byte header, with the plaintext length
// in its length field. It returns the extended record with the length field
// updated to the ciphertext length.
//
// rand supplies randomness for explicit IVs/nonces that are not derived
// from the sequence number. On success the sequence number is incremented.
// If no cipher is installed, the payload is appended unchanged; the header
// and the sequence number are then left untouched.
func (hc *halfConn) encrypt(record, payload []byte, rand io.Reader) ([]byte, error) {
	if hc.cipher == nil {
		return append(record, payload...), nil
	}

	var explicitNonce []byte
	if explicitNonceLen := hc.explicitNonceLen(); explicitNonceLen > 0 {
		record, explicitNonce = sliceForAppend(record, explicitNonceLen)
		if _, isCBC := hc.cipher.(cbcMode); !isCBC && explicitNonceLen < 16 {
			// AEADs with a short explicit nonce use the sequence number,
			// which is unique per record under a given key. Random nonces
			// this short could presumably collide.
			copy(explicitNonce, hc.seq[:])
		} else {
			// CBC IVs (and long nonces) are random. CBC in particular needs
			// IVs the attacker cannot predict.
			if _, err := io.ReadFull(rand, explicitNonce); err != nil {
				return nil, err
			}
		}
	}

	var dst []byte
	switch c := hc.cipher.(type) {
	case cipher.Stream:
		// MAC the plaintext, then encrypt payload and MAC together.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		record, dst = sliceForAppend(record, len(payload)+len(mac))
		c.XORKeyStream(dst[:len(payload)], payload)
		c.XORKeyStream(dst[len(payload):], mac)
	case aead:
		nonce := explicitNonce
		if len(nonce) == 0 {
			nonce = hc.seq[:]
		}

		if hc.version == VersionTLS13 {
			// TLS 1.3 hides the real content type. Append the payload and
			// then the type byte, and relabel the record as
			// application_data.
			record = append(record, payload...)

			record = append(record, record[0])
			record[0] = byte(recordTypeApplicationData)

			// The header is the additional data, so it must already carry
			// the final ciphertext length before sealing.
			n := len(payload) + 1 + c.Overhead()
			record[3] = byte(n >> 8)
			record[4] = byte(n)

			// Seal the inner plaintext in place, right after the header.
			record = c.Seal(record[:recordHeaderLen],
				nonce, record[recordHeaderLen:], record[:recordHeaderLen])
		} else {
			// Additional data: seq || header. The header still holds the
			// plaintext length here.
			additionalData := append(hc.scratchBuf[:0], hc.seq[:]...)
			additionalData = append(additionalData, record[:recordHeaderLen]...)
			record = c.Seal(record, nonce, payload, additionalData)
		}
	case cbcMode:
		// MAC the plaintext, then pad to a whole number of blocks. Every
		// padding byte, including the final length byte, holds
		// paddingLen-1. Finally encrypt everything.
		mac := tls10MAC(hc.mac, hc.scratchBuf[:0], hc.seq[:], record[:recordHeaderLen], payload, nil)
		blockSize := c.BlockSize()
		plaintextLen := len(payload) + len(mac)
		paddingLen := blockSize - plaintextLen%blockSize
		record, dst = sliceForAppend(record, plaintextLen+paddingLen)
		copy(dst, payload)
		copy(dst[len(payload):], mac)
		for i := plaintextLen; i < len(dst); i++ {
			dst[i] = byte(paddingLen - 1)
		}
		if len(explicitNonce) > 0 {
			c.SetIV(explicitNonce)
		}
		c.CryptBlocks(dst, dst)
	default:
		panic("unknown cipher type")
	}

	// Write the final ciphertext length (including any explicit nonce)
	// into the header.
	n := len(record) - recordHeaderLen
	record[3] = byte(n >> 8)
	record[4] = byte(n)
	hc.incSeq()

	return record, nil
}
563
564
// RecordHeaderError is returned when a record header received from the peer
// is malformed or does not look like TLS at all.
type RecordHeaderError struct {
	// Msg describes the problem with the record header.
	Msg string

	// RecordHeader holds the first bytes of the offending record, copied
	// from the connection's raw input buffer.
	RecordHeader [5]byte

	// Conn is the underlying connection. In this file it is set only when
	// the very first record did not look like a TLS handshake; all other
	// header errors leave it nil.
	Conn net.Conn
}

// Error implements the error interface by prefixing Msg with "tls: ".
func (e RecordHeaderError) Error() string {
	const prefix = "tls: "
	return prefix + e.Msg
}
579
580 func (c *Conn) newRecordHeaderError(conn net.Conn, msg string) (err RecordHeaderError) {
581 err.Msg = msg
582 err.Conn = conn
583 copy(err.RecordHeader[:], c.rawInput.Bytes())
584 return err
585 }
586
587 func (c *Conn) readRecord() error {
588 return c.readRecordOrCCS(false)
589 }
590
591 func (c *Conn) readChangeCipherSpec() error {
592 return c.readRecordOrCCS(true)
593 }
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
// readRecordOrCCS reads one record from the underlying connection, decrypts
// it and dispatches it by type:
//   - handshake data is appended to c.hand;
//   - application data is exposed through c.input (only after the
//     handshake);
//   - alerts become errors, or are skipped if they are warnings;
//   - a ChangeCipherSpec activates the pending read cipher, but only when
//     expectChangeCipherSpec is true. TLS 1.3 CCS records are ignored.
//
// Ignored records are retried through retryReadRecord, which limits how
// many may be skipped in a row. Callers in this file hold c.in. Every error
// except a temporary net error is recorded as the read side's sticky error.
func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
	if c.in.err != nil {
		return c.in.err
	}
	handshakeComplete := c.isHandshakeComplete.Load()

	// c.input aliases memory owned by c.rawInput (see c.input.Reset(data)
	// below), so refuse to read another record while application data is
	// still pending.
	if c.input.Len() != 0 {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with pending application data"))
	}
	c.input.Reset(nil)

	if c.quic != nil {
		return c.in.setErrorLocked(errors.New("tls: internal error: attempted to read record with QUIC transport"))
	}

	// Read the record header.
	if err := c.readFromUntil(c.conn, recordHeaderLen); err != nil {
		// An EOF exactly on a record boundary is reported as a clean io.EOF,
		// even without a close_notify. An EOF in the middle of a header
		// stays io.ErrUnexpectedEOF.
		if err == io.ErrUnexpectedEOF && c.rawInput.Len() == 0 {
			err = io.EOF
		}
		// Temporary network errors are returned without being made sticky,
		// so the read can be retried.
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}
	hdr := c.rawInput.Bytes()[:recordHeaderLen]
	typ := recordType(hdr[0])

	// No TLS record type is 0x80. Before the handshake completes, this byte
	// presumably indicates an SSLv2-style ClientHello, whose two-byte
	// length starts with the high bit set.
	if !handshakeComplete && typ == 0x80 {
		c.sendAlert(alertProtocolVersion)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, "unsupported SSLv2 handshake received"))
	}

	vers := uint16(hdr[1])<<8 | uint16(hdr[2])
	expectedVers := c.vers
	if expectedVers == VersionTLS13 {
		// TLS 1.3 records carry the TLS 1.2 version number on the wire;
		// writeRecordLocked does the same when sending.
		expectedVers = VersionTLS12
	}
	n := int(hdr[3])<<8 | int(hdr[4])
	if c.haveVers && vers != expectedVers {
		c.sendAlert(alertProtocolVersion)
		msg := fmt.Sprintf("received record with version %x when expecting version %x", vers, expectedVers)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	if !c.haveVers {
		// No version has been negotiated yet, so bail out before reading a
		// body if the input does not look like TLS. The first record must
		// be a handshake or an alert, and a record version of 0x1000 or
		// more is implausible.
		if (typ != recordTypeAlert && typ != recordTypeHandshake) || vers >= 0x1000 {
			return c.in.setErrorLocked(c.newRecordHeaderError(c.conn, "first record does not look like a TLS handshake"))
		}
	}
	// Reject oversized records before reading their bodies. && binds more
	// tightly than ||, so the TLS 1.3 limit applies only when c.vers is
	// TLS 1.3.
	if c.vers == VersionTLS13 && n > maxCiphertextTLS13 || n > maxCiphertext {
		c.sendAlert(alertRecordOverflow)
		msg := fmt.Sprintf("oversized record received with length %d", n)
		return c.in.setErrorLocked(c.newRecordHeaderError(nil, msg))
	}
	// Read the record body.
	if err := c.readFromUntil(c.conn, recordHeaderLen+n); err != nil {
		if e, ok := err.(net.Error); !ok || !e.Temporary() {
			c.in.setErrorLocked(err)
		}
		return err
	}

	// Consume the record from rawInput and decrypt it in place.
	record := c.rawInput.Next(recordHeaderLen + n)
	data, typ, err := c.in.decrypt(record)
	if err != nil {
		return c.in.setErrorLocked(c.sendAlert(err.(alert)))
	}
	if len(data) > maxPlaintext {
		return c.in.setErrorLocked(c.sendAlert(alertRecordOverflow))
	}

	// Application data must never arrive unencrypted.
	if c.in.cipher == nil && typ == recordTypeApplicationData {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	// A non-empty record of any type other than alert or CCS makes progress,
	// so reset the count of consecutively ignored records.
	if typ != recordTypeAlert && typ != recordTypeChangeCipherSpec && len(data) > 0 {
		c.retryCount = 0
	}

	// In TLS 1.3, a partially received handshake message may not be
	// interleaved with records of other types.
	if c.vers == VersionTLS13 && typ != recordTypeHandshake && c.hand.Len() > 0 {
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
	}

	switch typ {
	default:
		return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))

	case recordTypeAlert:
		// Defensive: QUIC connections were already rejected above.
		if c.quic != nil {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if len(data) != 2 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// close_notify: the peer has finished sending; report a clean EOF.
		if alert(data[1]) == alertCloseNotify {
			return c.in.setErrorLocked(io.EOF)
		}
		if c.vers == VersionTLS13 {
			// In TLS 1.3 the alert level byte is not consulted. Every alert
			// is fatal except user_canceled, which is dropped and retried
			// like a TLS 1.2 warning.
			if alert(data[1]) == alertUserCanceled {
				return c.retryReadRecord(expectChangeCipherSpec)
			}
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		}
		switch data[0] {
		case alertLevelWarning:
			// Drop the warning and read the next record.
			return c.retryReadRecord(expectChangeCipherSpec)
		case alertLevelError:
			return c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])})
		default:
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}

	case recordTypeChangeCipherSpec:
		if len(data) != 1 || data[0] != 1 {
			return c.in.setErrorLocked(c.sendAlert(alertDecodeError))
		}
		// A handshake message may not span a ChangeCipherSpec.
		if c.hand.Len() > 0 {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// TLS 1.3 ChangeCipherSpec records are ignored; retryReadRecord
		// limits how many may be skipped.
		if c.vers == VersionTLS13 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		if !expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		if err := c.in.changeCipherSpec(); err != nil {
			return c.in.setErrorLocked(c.sendAlert(err.(alert)))
		}

	case recordTypeApplicationData:
		if !handshakeComplete || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		// Empty application-data records are skipped, up to the
		// maxUselessRecords limit enforced by retryReadRecord.
		if len(data) == 0 {
			return c.retryReadRecord(expectChangeCipherSpec)
		}
		// data aliases c.rawInput's buffer (from the Next call above). This
		// is safe because the check at the top of this function refuses to
		// touch rawInput again until c.input has been drained.
		c.input.Reset(data)

	case recordTypeHandshake:
		// Empty handshake records are invalid, and a handshake record
		// cannot stand in for an expected ChangeCipherSpec.
		if len(data) == 0 || expectChangeCipherSpec {
			return c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
		}
		c.hand.Write(data)
	}

	return nil
}
791
792
793
794 func (c *Conn) retryReadRecord(expectChangeCipherSpec bool) error {
795 c.retryCount++
796 if c.retryCount > maxUselessRecords {
797 c.sendAlert(alertUnexpectedMessage)
798 return c.in.setErrorLocked(errors.New("tls: too many ignored records"))
799 }
800 return c.readRecordOrCCS(expectChangeCipherSpec)
801 }
802
803
804
805
// atLeastReader wraps R and reports io.EOF once at least N bytes in total
// have been read through it, even if R has more data. If R runs out first,
// the shortfall is reported as io.ErrUnexpectedEOF. readFromUntil uses it
// to stop bytes.Buffer.ReadFrom, which reads until EOF, once enough bytes
// are buffered.
type atLeastReader struct {
	R io.Reader
	N int64 // bytes still required; goes negative if R over-delivers
}

func (r *atLeastReader) Read(p []byte) (int, error) {
	if r.N <= 0 {
		return 0, io.EOF
	}
	n, err := r.R.Read(p)
	r.N -= int64(n)
	switch {
	case r.N > 0 && err == io.EOF:
		// R ended before the required amount was read.
		return n, io.ErrUnexpectedEOF
	case r.N <= 0 && err == nil:
		// Requirement met: tell the caller to stop reading.
		return n, io.EOF
	default:
		return n, err
	}
}
825
826
827
828 func (c *Conn) readFromUntil(r io.Reader, n int) error {
829 if c.rawInput.Len() >= n {
830 return nil
831 }
832 needs := n - c.rawInput.Len()
833
834
835
836 c.rawInput.Grow(needs + bytes.MinRead)
837 _, err := c.rawInput.ReadFrom(&atLeastReader{r, int64(needs)})
838 return err
839 }
840
841
842 func (c *Conn) sendAlertLocked(err alert) error {
843 if c.quic != nil {
844 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
845 }
846
847 switch err {
848 case alertNoRenegotiation, alertCloseNotify:
849 c.tmp[0] = alertLevelWarning
850 default:
851 c.tmp[0] = alertLevelError
852 }
853 c.tmp[1] = byte(err)
854
855 _, writeErr := c.writeRecordLocked(recordTypeAlert, c.tmp[0:2])
856 if err == alertCloseNotify {
857
858 return writeErr
859 }
860
861 return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err})
862 }
863
864
865 func (c *Conn) sendAlert(err alert) error {
866 c.out.Lock()
867 defer c.out.Unlock()
868 return c.sendAlertLocked(err)
869 }
870
const (
	// tcpMSSEstimate is the per-record byte budget that
	// maxPayloadSizeForWrite starts from when sizing small early records.
	// The budget includes the record header, explicit nonce and cipher
	// overhead. 1208 = 1280 - 40 - 32, presumably the IPv6 minimum MTU minus
	// an IPv6 header and a TCP header with timestamps.
	tcpMSSEstimate = 1208

	// recordSizeBoostThreshold is the number of bytes written to the
	// underlying connection after which maxPayloadSizeForWrite always allows
	// full-size (maxPlaintext) records.
	recordSizeBoostThreshold = 128 << 10
)
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
// maxPayloadSizeForWrite returns the largest plaintext payload to place in
// the next record of type typ.
//
// It returns maxPlaintext in any of these cases:
//   - the record is not application data;
//   - DynamicRecordSizingDisabled is set;
//   - at least recordSizeBoostThreshold bytes have already been sent.
//
// Otherwise records start out sized so that one record fits in roughly
// tcpMSSEstimate bytes on the wire, and grow linearly: the k-th such call
// (k starting at 1) allows k times that payload, capped at maxPlaintext.
// Small early records presumably let the peer start processing sooner than
// a full-size record would.
//
// Every call that reaches the dynamic path increments c.packetsSent.
func (c *Conn) maxPayloadSizeForWrite(typ recordType) int {
	if c.config.DynamicRecordSizingDisabled || typ != recordTypeApplicationData {
		return maxPlaintext
	}

	if c.bytesSent >= recordSizeBoostThreshold {
		return maxPlaintext
	}

	// Subtract the per-record overhead from the segment budget.
	payloadBytes := tcpMSSEstimate - recordHeaderLen - c.out.explicitNonceLen()
	if c.out.cipher != nil {
		switch ciph := c.out.cipher.(type) {
		case cipher.Stream:
			payloadBytes -= c.out.mac.Size()
		case cipher.AEAD:
			payloadBytes -= ciph.Overhead()
		case cbcMode:
			blockSize := ciph.BlockSize()
			// Round down to a whole number of blocks (the mask assumes
			// blockSize is a power of two), leaving room for at least one
			// padding byte.
			payloadBytes = (payloadBytes & ^(blockSize - 1)) - 1
			// The MAC is appended before padding, so it comes straight out
			// of the payload budget.
			payloadBytes -= c.out.mac.Size()
		default:
			panic("unknown cipher type")
		}
	}
	if c.vers == VersionTLS13 {
		// One byte for the encrypted inner content type.
		payloadBytes--
	}

	// Grow the allowed size in an arithmetic progression up to the maximum.
	pkt := c.packetsSent
	c.packetsSent++
	if pkt > 1000 {
		// Past this point the product would exceed maxPlaintext anyway;
		// returning early also avoids overflow in the multiply below.
		return maxPlaintext
	}

	n := payloadBytes * int(pkt+1)
	if n > maxPlaintext {
		n = maxPlaintext
	}
	return n
}
947
948 func (c *Conn) write(data []byte) (int, error) {
949 if c.buffering {
950 c.sendBuf = append(c.sendBuf, data...)
951 return len(data), nil
952 }
953
954 n, err := c.conn.Write(data)
955 c.bytesSent += int64(n)
956 return n, err
957 }
958
959 func (c *Conn) flush() (int, error) {
960 if len(c.sendBuf) == 0 {
961 return 0, nil
962 }
963
964 n, err := c.conn.Write(c.sendBuf)
965 c.bytesSent += int64(n)
966 c.sendBuf = nil
967 c.buffering = false
968 return n, err
969 }
970
971
// outBufPool recycles the record-assembly buffers used by
// writeRecordLocked. It stores *[]byte rather than []byte, so that putting
// a buffer back does not allocate a new slice header.
var outBufPool = sync.Pool{
	New: func() any {
		var buf []byte
		return &buf
	},
}
977
978
979
// writeRecordLocked writes data as one or more records of type typ and
// returns how many bytes of data were written. The caller must hold c.out.
//
// data is split into chunks of at most maxPayloadSizeForWrite(typ) bytes.
// Each chunk gets a header, is encrypted with the current write state, and
// is handed to c.write (which may only buffer it). After a ChangeCipherSpec
// record (before TLS 1.3), the pending write cipher is activated.
//
// Over QUIC only handshake data is accepted. It is passed to the QUIC layer
// at the current write level instead of being framed as records.
func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
	if c.quic != nil {
		if typ != recordTypeHandshake {
			return 0, errors.New("tls: internal error: sending non-handshake message to QUIC transport")
		}
		c.quicWriteCryptoData(c.out.level, data)
		if !c.buffering {
			if _, err := c.flush(); err != nil {
				return 0, err
			}
		}
		return len(data), nil
	}

	outBufPtr := outBufPool.Get().(*[]byte)
	outBuf := *outBufPtr
	defer func() {
		// Return the possibly grown buffer to the pool through the pointer
		// obtained from Get. Putting &outBuf instead would presumably make
		// the local slice header escape to the heap, costing an allocation
		// per call.
		*outBufPtr = outBuf
		outBufPool.Put(outBufPtr)
	}()

	var n int
	for len(data) > 0 {
		m := len(data)
		if maxPayload := c.maxPayloadSizeForWrite(typ); m > maxPayload {
			m = maxPayload
		}

		_, outBuf = sliceForAppend(outBuf[:0], recordHeaderLen)
		outBuf[0] = byte(typ)
		vers := c.vers
		if vers == 0 {
			// No version negotiated yet: label the record TLS 1.0,
			// presumably for compatibility with peers that reject higher
			// record versions in the first flight.
			vers = VersionTLS10
		} else if vers == VersionTLS13 {
			// TLS 1.3 records carry the TLS 1.2 version number on the wire;
			// readRecordOrCCS expects the same.
			vers = VersionTLS12
		}
		outBuf[1] = byte(vers >> 8)
		outBuf[2] = byte(vers)
		// Plaintext length; when a cipher is active, encrypt replaces it
		// with the ciphertext length.
		outBuf[3] = byte(m >> 8)
		outBuf[4] = byte(m)

		var err error
		outBuf, err = c.out.encrypt(outBuf, data[:m], c.config.rand())
		if err != nil {
			return n, err
		}
		if _, err := c.write(outBuf); err != nil {
			return n, err
		}
		n += m
		data = data[m:]
	}

	if typ == recordTypeChangeCipherSpec && c.vers != VersionTLS13 {
		if err := c.out.changeCipherSpec(); err != nil {
			return n, c.sendAlertLocked(err.(alert))
		}
	}

	return n, nil
}
1050
1051
1052
1053
1054 func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
1055 c.out.Lock()
1056 defer c.out.Unlock()
1057
1058 data, err := msg.marshal()
1059 if err != nil {
1060 return 0, err
1061 }
1062 if transcript != nil {
1063 transcript.Write(data)
1064 }
1065
1066 return c.writeRecordLocked(recordTypeHandshake, data)
1067 }
1068
1069
1070
1071 func (c *Conn) writeChangeCipherRecord() error {
1072 c.out.Lock()
1073 defer c.out.Unlock()
1074 _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
1075 return err
1076 }
1077
1078
1079 func (c *Conn) readHandshakeBytes(n int) error {
1080 if c.quic != nil {
1081 return c.quicReadHandshakeBytes(n)
1082 }
1083 for c.hand.Len() < n {
1084 if err := c.readRecord(); err != nil {
1085 return err
1086 }
1087 }
1088 return nil
1089 }
1090
1091
1092
1093
1094 func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
1095 if err := c.readHandshakeBytes(4); err != nil {
1096 return nil, err
1097 }
1098 data := c.hand.Bytes()
1099
1100 maxHandshakeSize := maxHandshake
1101
1102
1103
1104 if c.haveVers && data[0] == typeCertificate {
1105
1106
1107
1108 maxHandshakeSize = maxHandshakeCertificateMsg
1109 }
1110
1111 n := int(data[1])<<16 | int(data[2])<<8 | int(data[3])
1112 if n > maxHandshakeSize {
1113 c.sendAlertLocked(alertInternalError)
1114 return nil, c.in.setErrorLocked(fmt.Errorf("tls: handshake message of length %d bytes exceeds maximum of %d bytes", n, maxHandshakeSize))
1115 }
1116 if err := c.readHandshakeBytes(4 + n); err != nil {
1117 return nil, err
1118 }
1119 data = c.hand.Next(4 + n)
1120 return c.unmarshalHandshakeMessage(data, transcript)
1121 }
1122
1123 func (c *Conn) unmarshalHandshakeMessage(data []byte, transcript transcriptHash) (handshakeMessage, error) {
1124 var m handshakeMessage
1125 switch data[0] {
1126 case typeHelloRequest:
1127 m = new(helloRequestMsg)
1128 case typeClientHello:
1129 m = new(clientHelloMsg)
1130 case typeServerHello:
1131 m = new(serverHelloMsg)
1132 case typeNewSessionTicket:
1133 if c.vers == VersionTLS13 {
1134 m = new(newSessionTicketMsgTLS13)
1135 } else {
1136 m = new(newSessionTicketMsg)
1137 }
1138 case typeCertificate:
1139 if c.vers == VersionTLS13 {
1140 m = new(certificateMsgTLS13)
1141 } else {
1142 m = new(certificateMsg)
1143 }
1144 case typeCertificateRequest:
1145 if c.vers == VersionTLS13 {
1146 m = new(certificateRequestMsgTLS13)
1147 } else {
1148 m = &certificateRequestMsg{
1149 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1150 }
1151 }
1152 case typeCertificateStatus:
1153 m = new(certificateStatusMsg)
1154 case typeServerKeyExchange:
1155 m = new(serverKeyExchangeMsg)
1156 case typeServerHelloDone:
1157 m = new(serverHelloDoneMsg)
1158 case typeClientKeyExchange:
1159 m = new(clientKeyExchangeMsg)
1160 case typeCertificateVerify:
1161 m = &certificateVerifyMsg{
1162 hasSignatureAlgorithm: c.vers >= VersionTLS12,
1163 }
1164 case typeFinished:
1165 m = new(finishedMsg)
1166 case typeEncryptedExtensions:
1167 m = new(encryptedExtensionsMsg)
1168 case typeEndOfEarlyData:
1169 m = new(endOfEarlyDataMsg)
1170 case typeKeyUpdate:
1171 m = new(keyUpdateMsg)
1172 default:
1173 return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
1174 }
1175
1176
1177
1178
1179 data = append([]byte(nil), data...)
1180
1181 if !m.unmarshal(data) {
1182 return nil, c.in.setErrorLocked(c.sendAlert(alertDecodeError))
1183 }
1184
1185 if transcript != nil {
1186 transcript.Write(data)
1187 }
1188
1189 return m, nil
1190 }
1191
// errShutdown is returned by Write after close_notify has been sent.
var errShutdown = errors.New("tls: protocol is shutdown")
1195
1196
1197
1198
1199
1200
1201
// Write writes b to the connection as TLS application data, running the
// handshake first if it has not completed. It returns how many bytes of b
// were written.
//
// Write interlocks with Close through c.activeCall. Each Write adds 2 while
// it runs; once Close has set bit 0, Write fails with net.ErrClosed.
//
// Errors from writing records are recorded as the write side's sticky
// error (c.out.err), so later Writes fail with the same error. After
// close_notify has been sent (see closeNotify), Write returns errShutdown.
func (c *Conn) Write(b []byte) (int, error) {
	// Register this call so that Close can see a Write is in flight.
	for {
		x := c.activeCall.Load()
		if x&1 != 0 {
			return 0, net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x+2) {
			break
		}
	}
	defer c.activeCall.Add(-2)

	if err := c.Handshake(); err != nil {
		return 0, err
	}

	c.out.Lock()
	defer c.out.Unlock()

	if err := c.out.err; err != nil {
		return 0, err
	}

	if !c.isHandshakeComplete.Load() {
		return 0, alertInternalError
	}

	if c.closeNotifySent {
		return 0, errShutdown
	}

	// With TLS 1.0 and a CBC cipher, records carry no explicit IV
	// (explicitNonceLen is 0 below TLS 1.1); the IV comes from the cipher's
	// chaining state. Sending the first byte in its own record first
	// randomizes the IV for the rest of the data. This is presumably the
	// 1/n-1 record split used against the BEAST predictable-IV attack.
	var m int
	if len(b) > 1 && c.vers == VersionTLS10 {
		if _, ok := c.out.cipher.(cipher.BlockMode); ok {
			n, err := c.writeRecordLocked(recordTypeApplicationData, b[:1])
			if err != nil {
				return n, c.out.setErrorLocked(err)
			}
			m, b = 1, b[1:]
		}
	}

	n, err := c.writeRecordLocked(recordTypeApplicationData, b)
	return n + m, c.out.setErrorLocked(err)
}
1257
1258
1259 func (c *Conn) handleRenegotiation() error {
1260 if c.vers == VersionTLS13 {
1261 return errors.New("tls: internal error: unexpected renegotiation")
1262 }
1263
1264 msg, err := c.readHandshake(nil)
1265 if err != nil {
1266 return err
1267 }
1268
1269 helloReq, ok := msg.(*helloRequestMsg)
1270 if !ok {
1271 c.sendAlert(alertUnexpectedMessage)
1272 return unexpectedMessageError(helloReq, msg)
1273 }
1274
1275 if !c.isClient {
1276 return c.sendAlert(alertNoRenegotiation)
1277 }
1278
1279 switch c.config.Renegotiation {
1280 case RenegotiateNever:
1281 return c.sendAlert(alertNoRenegotiation)
1282 case RenegotiateOnceAsClient:
1283 if c.handshakes > 1 {
1284 return c.sendAlert(alertNoRenegotiation)
1285 }
1286 case RenegotiateFreelyAsClient:
1287
1288 default:
1289 c.sendAlert(alertInternalError)
1290 return errors.New("tls: unknown Renegotiation value")
1291 }
1292
1293 c.handshakeMutex.Lock()
1294 defer c.handshakeMutex.Unlock()
1295
1296 c.isHandshakeComplete.Store(false)
1297 if c.handshakeErr = c.clientHandshake(context.Background()); c.handshakeErr == nil {
1298 c.handshakes++
1299 }
1300 return c.handshakeErr
1301 }
1302
1303
1304
1305 func (c *Conn) handlePostHandshakeMessage() error {
1306 if c.vers != VersionTLS13 {
1307 return c.handleRenegotiation()
1308 }
1309
1310 msg, err := c.readHandshake(nil)
1311 if err != nil {
1312 return err
1313 }
1314 c.retryCount++
1315 if c.retryCount > maxUselessRecords {
1316 c.sendAlert(alertUnexpectedMessage)
1317 return c.in.setErrorLocked(errors.New("tls: too many non-advancing records"))
1318 }
1319
1320 switch msg := msg.(type) {
1321 case *newSessionTicketMsgTLS13:
1322 return c.handleNewSessionTicket(msg)
1323 case *keyUpdateMsg:
1324 return c.handleKeyUpdate(msg)
1325 }
1326
1327
1328
1329
1330 c.sendAlert(alertUnexpectedMessage)
1331 return fmt.Errorf("tls: received unexpected handshake message of type %T", msg)
1332 }
1333
1334 func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
1335 if c.quic != nil {
1336 c.sendAlert(alertUnexpectedMessage)
1337 return c.in.setErrorLocked(errors.New("tls: received unexpected key update message"))
1338 }
1339
1340 cipherSuite := cipherSuiteTLS13ByID(c.cipherSuite)
1341 if cipherSuite == nil {
1342 return c.in.setErrorLocked(c.sendAlert(alertInternalError))
1343 }
1344
1345 if keyUpdate.updateRequested {
1346 c.out.Lock()
1347 defer c.out.Unlock()
1348
1349 msg := &keyUpdateMsg{}
1350 msgBytes, err := msg.marshal()
1351 if err != nil {
1352 return err
1353 }
1354 _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
1355 if err != nil {
1356
1357 c.out.setErrorLocked(err)
1358 return nil
1359 }
1360
1361 newSecret := cipherSuite.nextTrafficSecret(c.out.trafficSecret)
1362 c.setWriteTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret)
1363 }
1364
1365 newSecret := cipherSuite.nextTrafficSecret(c.in.trafficSecret)
1366 if err := c.setReadTrafficSecret(cipherSuite, QUICEncryptionLevelInitial, newSecret); err != nil {
1367 return err
1368 }
1369
1370 return nil
1371 }
1372
1373
1374
1375
1376
1377
1378
// Read reads decrypted application data into b, running the handshake first
// if needed. It blocks until at least one byte is available or an error
// occurs. Post-handshake handshake messages that arrive while it waits are
// processed along the way: renegotiation requests, and TLS 1.3 session
// tickets and key updates.
func (c *Conn) Read(b []byte) (int, error) {
	if err := c.Handshake(); err != nil {
		return 0, err
	}
	if len(b) == 0 {
		// Checked after Handshake, so that Read(nil) still drives the
		// handshake.
		return 0, nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	for c.input.Len() == 0 {
		if err := c.readRecord(); err != nil {
			return 0, err
		}
		for c.hand.Len() > 0 {
			if err := c.handlePostHandshakeMessage(); err != nil {
				return 0, err
			}
		}
	}

	// c.input is non-empty here, so this Read cannot fail.
	n, _ := c.input.Read(b)

	// If the next record already buffered after this data is an alert,
	// process it now. A close_notify arriving right behind the data is then
	// reported as (n, io.EOF) in this call rather than on the next one.
	// Only the first raw byte (the record type) is checked; readRecord may
	// still wait for the rest of the alert record to arrive.
	if n != 0 && c.input.Len() == 0 && c.rawInput.Len() > 0 &&
		recordType(c.rawInput.Bytes()[0]) == recordTypeAlert {
		if err := c.readRecord(); err != nil {
			return n, err
		}
	}

	return n, nil
}
1421
1422
// Close closes the connection. A second Close returns net.ErrClosed.
//
// If the handshake has completed and no Write is in flight, Close first
// tries to send a close_notify alert, with a 5-second write deadline set in
// closeNotify. A failure to send it is returned wrapped, but only if closing
// the underlying connection succeeds.
//
// If a Write is in flight, Close skips close_notify and just closes the
// underlying connection, presumably unblocking that Write.
func (c *Conn) Close() error {
	// Interlock with Write: set bit 0 of activeCall, or fail if it is
	// already set because Close ran before.
	var x int32
	for {
		x = c.activeCall.Load()
		if x&1 != 0 {
			return net.ErrClosed
		}
		if c.activeCall.CompareAndSwap(x, x|1) {
			break
		}
	}
	if x != 0 {
		// At least one Write is in flight (each adds 2). Treat this Close
		// as a request to abort it: sending close_notify would have to wait
		// for the c.out lock held by that Write, so just close the
		// underlying connection.
		return c.conn.Close()
	}

	var alertErr error
	if c.isHandshakeComplete.Load() {
		if err := c.closeNotify(); err != nil {
			alertErr = fmt.Errorf("tls: failed to send closeNotify alert (but connection was closed anyway): %w", err)
		}
	}

	if err := c.conn.Close(); err != nil {
		return err
	}
	return alertErr
}
1457
1458 var errEarlyCloseWrite = errors.New("tls: CloseWrite called before handshake complete")
1459
1460
1461
1462
1463 func (c *Conn) CloseWrite() error {
1464 if !c.isHandshakeComplete.Load() {
1465 return errEarlyCloseWrite
1466 }
1467
1468 return c.closeNotify()
1469 }
1470
1471 func (c *Conn) closeNotify() error {
1472 c.out.Lock()
1473 defer c.out.Unlock()
1474
1475 if !c.closeNotifySent {
1476
1477 c.SetWriteDeadline(time.Now().Add(time.Second * 5))
1478 c.closeNotifyErr = c.sendAlertLocked(alertCloseNotify)
1479 c.closeNotifySent = true
1480
1481 c.SetWriteDeadline(time.Now())
1482 }
1483 return c.closeNotifyErr
1484 }
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499 func (c *Conn) Handshake() error {
1500 return c.HandshakeContext(context.Background())
1501 }
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513 func (c *Conn) HandshakeContext(ctx context.Context) error {
1514
1515
1516 return c.handshakeContext(ctx)
1517 }
1518
// handshakeContext runs the handshake if it has not already completed, and
// returns the sticky result.
//
// On a non-QUIC connection with a cancellable ctx: if ctx is cancelled
// before this function returns, the underlying connection is closed and
// ctx.Err() is returned.
//
// For QUIC connections, the derived context and its cancel function are
// handed to c.quic instead, and the QUIC layer is told when the handshake
// completes or fails.
func (c *Conn) handshakeContext(ctx context.Context) (ret error) {
	// Fast path: skip context setup and locking once a handshake has
	// succeeded.
	if c.isHandshakeComplete.Load() {
		return nil
	}

	handshakeCtx, cancel := context.WithCancel(ctx)
	// Deferred before the AfterFunc cleanup below. Defers run LIFO, so that
	// cleanup runs first, while handshakeCtx is still live.
	defer cancel()

	if c.quic != nil {
		c.quic.ctx = handshakeCtx
		c.quic.cancel = cancel
	} else if ctx.Done() != nil {
		// Only a cancellable ctx needs a watcher. Close the underlying
		// connection if ctx is cancelled while we are still running; this
		// unblocks any pending I/O.
		stop := context.AfterFunc(ctx, func() {
			_ = c.conn.Close()
		})
		defer func() {
			if !stop() {
				// The AfterFunc already started, meaning ctx was cancelled
				// and the connection closed. Report the context's error
				// instead of the resulting I/O error.
				ret = ctx.Err()
			}
		}()
	}

	c.handshakeMutex.Lock()
	defer c.handshakeMutex.Unlock()

	// Re-check under the lock: another goroutine may have finished (or
	// failed) the handshake while we waited. Failures are sticky.
	if err := c.handshakeErr; err != nil {
		return err
	}
	if c.isHandshakeComplete.Load() {
		return nil
	}

	c.in.Lock()
	defer c.in.Unlock()

	c.handshakeErr = c.handshakeFn(handshakeCtx)
	if c.handshakeErr == nil {
		c.handshakes++
	} else {
		// Try to flush any buffered output, such as an alert, left behind
		// by the failed handshake. The flush result is deliberately ignored.
		c.flush()
	}

	// Sanity checks: handshakeFn must mark the handshake complete exactly
	// when it returns nil.
	if c.handshakeErr == nil && !c.isHandshakeComplete.Load() {
		c.handshakeErr = errors.New("tls: internal error: handshake should have had a result")
	}
	if c.handshakeErr != nil && c.isHandshakeComplete.Load() {
		panic("tls: internal error: handshake returned an error but is marked successful")
	}

	if c.quic != nil {
		if c.handshakeErr == nil {
			c.quicHandshakeComplete()
			// The application-level read secret is handed to the QUIC
			// layer only now that the handshake is complete.
			// NOTE(review): this early return skips closing
			// c.quic.blockedc/signalc below and leaves handshakeErr nil —
			// confirm that is intended.
			if err := c.quicSetReadSecret(QUICEncryptionLevelApplication, c.cipherSuite, c.in.trafficSecret); err != nil {
				return err
			}
		} else {
			c.out.Lock()
			a, ok := errors.AsType[alert](c.out.err)
			if !ok {
				a = alertInternalError
			}
			c.out.Unlock()

			// Wrap both the handshake error and the alert that was sent
			// (or alertInternalError if none was). "%.0w" wraps the alert
			// while printing zero characters of it, so the message text is
			// just the handshake error.
			c.handshakeErr = fmt.Errorf("%w%.0w", c.handshakeErr, AlertError(a))
		}
		close(c.quic.blockedc)
		close(c.quic.signalc)
	}

	return c.handshakeErr
}
1606
1607
1608 func (c *Conn) ConnectionState() ConnectionState {
1609 c.handshakeMutex.Lock()
1610 defer c.handshakeMutex.Unlock()
1611 return c.connectionStateLocked()
1612 }
1613
1614 var tlsunsafeekm = godebug.New("tlsunsafeekm")
1615
1616 func (c *Conn) connectionStateLocked() ConnectionState {
1617 var state ConnectionState
1618 state.HandshakeComplete = c.isHandshakeComplete.Load()
1619 state.Version = c.vers
1620 state.NegotiatedProtocol = c.clientProtocol
1621 state.DidResume = c.didResume
1622 state.HelloRetryRequest = c.didHRR
1623 state.testingOnlyPeerSignatureAlgorithm = c.peerSigAlg
1624 state.CurveID = c.curveID
1625 state.NegotiatedProtocolIsMutual = true
1626 state.ServerName = c.serverName
1627 state.CipherSuite = c.cipherSuite
1628 state.PeerCertificates = c.peerCertificates
1629 state.VerifiedChains = c.verifiedChains
1630 state.SignedCertificateTimestamps = c.scts
1631 state.OCSPResponse = c.ocspResponse
1632 if (!c.didResume || c.extMasterSecret) && c.vers != VersionTLS13 {
1633 if c.clientFinishedIsFirst {
1634 state.TLSUnique = c.clientFinished[:]
1635 } else {
1636 state.TLSUnique = c.serverFinished[:]
1637 }
1638 }
1639 if c.config.Renegotiation != RenegotiateNever {
1640 state.ekm = noEKMBecauseRenegotiation
1641 } else if c.vers != VersionTLS13 && !c.extMasterSecret {
1642 state.ekm = func(label string, context []byte, length int) ([]byte, error) {
1643 if tlsunsafeekm.Value() == "1" {
1644 tlsunsafeekm.IncNonDefault()
1645 return c.ekm(label, context, length)
1646 }
1647 return noEKMBecauseNoEMS(label, context, length)
1648 }
1649 } else {
1650 state.ekm = c.ekm
1651 }
1652 state.ECHAccepted = c.echAccepted
1653 return state
1654 }
1655
1656
1657
1658 func (c *Conn) OCSPResponse() []byte {
1659 c.handshakeMutex.Lock()
1660 defer c.handshakeMutex.Unlock()
1661
1662 return c.ocspResponse
1663 }
1664
1665
1666
1667
1668 func (c *Conn) VerifyHostname(host string) error {
1669 c.handshakeMutex.Lock()
1670 defer c.handshakeMutex.Unlock()
1671 if !c.isClient {
1672 return errors.New("tls: VerifyHostname called on TLS server connection")
1673 }
1674 if !c.isHandshakeComplete.Load() {
1675 return errors.New("tls: handshake has not yet been performed")
1676 }
1677 if len(c.verifiedChains) == 0 {
1678 return errors.New("tls: handshake did not verify certificate chain")
1679 }
1680 return c.peerCertificates[0].VerifyHostname(host)
1681 }
1682
1683
1684
1685
1686 func (c *Conn) setReadTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) error {
1687
1688
1689
1690 if c.hand.Len() != 0 {
1691 c.sendAlert(alertUnexpectedMessage)
1692 return errors.New("tls: handshake buffer not empty before setting read traffic secret")
1693 }
1694 c.in.setTrafficSecret(suite, level, secret)
1695 return nil
1696 }
1697
1698
1699
1700
1701 func (c *Conn) setWriteTrafficSecret(suite *cipherSuiteTLS13, level QUICEncryptionLevel, secret []byte) {
1702 c.out.setTrafficSecret(suite, level, secret)
1703 }
1704
View as plain text