1
2
3
4
5 package asn1
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "math/big"
12 "reflect"
13 "slices"
14 "time"
15 "unicode/utf8"
16 )
17
18 var (
19 byte00Encoder encoder = byteEncoder(0x00)
20 byteFFEncoder encoder = byteEncoder(0xff)
21 )
22
23
// encoder is one piece of an ASN.1 encoding. Callers first ask for Len to
// size a destination buffer and then call Encode to fill it.
type encoder interface {
	// Encode writes the encoding to the start of dst. dst is expected to
	// hold at least Len() bytes.
	Encode(dst []byte)

	// Len reports how many bytes Encode writes.
	Len() int
}
30
// byteEncoder encodes as exactly one octet: its own value.
type byteEncoder byte

// Len always reports 1.
func (b byteEncoder) Len() int { return 1 }

// Encode stores the octet in dst[0] and leaves the rest of dst untouched.
func (b byteEncoder) Encode(dst []byte) {
	octet := byte(b)
	dst[0] = octet
}
40
// bytesEncoder encodes as a verbatim copy of its contents.
type bytesEncoder []byte

// Len reports the number of octets held.
func (b bytesEncoder) Len() int { return len(b) }

// Encode copies the octets to the front of dst. If dst is too short, some
// caller sized it wrongly, so Encode panics instead of truncating silently.
func (b bytesEncoder) Encode(dst []byte) {
	if written := copy(dst, b); written != len(b) {
		panic("internal error")
	}
}
52
// stringEncoder encodes as the raw bytes of its string, with no conversion.
type stringEncoder string

// Len reports the string's length in bytes.
func (s stringEncoder) Len() int { return len(s) }

// Encode copies the string's bytes to the front of dst. If dst is too short,
// some caller sized it wrongly, so Encode panics instead of truncating.
func (s stringEncoder) Encode(dst []byte) {
	if written := copy(dst, s); written != len(s) {
		panic("internal error")
	}
}
64
65 type multiEncoder []encoder
66
67 func (m multiEncoder) Len() int {
68 var size int
69 for _, e := range m {
70 size += e.Len()
71 }
72 return size
73 }
74
75 func (m multiEncoder) Encode(dst []byte) {
76 var off int
77 for _, e := range m {
78 e.Encode(dst[off:])
79 off += e.Len()
80 }
81 }
82
83 type setEncoder []encoder
84
85 func (s setEncoder) Len() int {
86 var size int
87 for _, e := range s {
88 size += e.Len()
89 }
90 return size
91 }
92
93 func (s setEncoder) Encode(dst []byte) {
94
95
96
97
98
99
100
101
102 l := make([][]byte, len(s))
103 for i, e := range s {
104 l[i] = make([]byte, e.Len())
105 e.Encode(l[i])
106 }
107
108
109
110
111
112
113
114 slices.SortFunc(l, bytes.Compare)
115
116 var off int
117 for _, b := range l {
118 copy(dst[off:], b)
119 off += len(b)
120 }
121 }
122
123 type taggedEncoder struct {
124
125
126 scratch [8]byte
127 tag encoder
128 body encoder
129 }
130
131 func (t *taggedEncoder) Len() int {
132 return t.tag.Len() + t.body.Len()
133 }
134
135 func (t *taggedEncoder) Encode(dst []byte) {
136 t.tag.Encode(dst)
137 t.body.Encode(dst[t.tag.Len():])
138 }
139
140 type int64Encoder int64
141
142 func (i int64Encoder) Len() int {
143 n := 1
144
145 for i > 127 {
146 n++
147 i >>= 8
148 }
149
150 for i < -128 {
151 n++
152 i >>= 8
153 }
154
155 return n
156 }
157
158 func (i int64Encoder) Encode(dst []byte) {
159 n := i.Len()
160
161 for j := 0; j < n; j++ {
162 dst[j] = byte(i >> uint((n-1-j)*8))
163 }
164 }
165
// base128IntLength reports how many 7-bit base-128 digits it takes to write
// n. Zero takes one digit. A negative n yields 0, because the digit-counting
// loop never runs.
func base128IntLength(n int64) int {
	if n == 0 {
		return 1
	}
	digits := 0
	for n > 0 {
		digits++
		n >>= 7
	}
	return digits
}
178
179 func appendBase128Int(dst []byte, n int64) []byte {
180 l := base128IntLength(n)
181
182 for i := l - 1; i >= 0; i-- {
183 o := byte(n >> uint(i*7))
184 o &= 0x7f
185 if i != 0 {
186 o |= 0x80
187 }
188
189 dst = append(dst, o)
190 }
191
192 return dst
193 }
194
195 func makeBigInt(n *big.Int) (encoder, error) {
196 if n == nil {
197 return nil, StructuralError{"empty integer"}
198 }
199
200 if n.Sign() < 0 {
201
202
203
204
205 nMinus1 := new(big.Int).Neg(n)
206 nMinus1.Sub(nMinus1, bigOne)
207 bytes := nMinus1.Bytes()
208 for i := range bytes {
209 bytes[i] ^= 0xff
210 }
211 if len(bytes) == 0 || bytes[0]&0x80 == 0 {
212 return multiEncoder([]encoder{byteFFEncoder, bytesEncoder(bytes)}), nil
213 }
214 return bytesEncoder(bytes), nil
215 } else if n.Sign() == 0 {
216
217 return byte00Encoder, nil
218 } else {
219 bytes := n.Bytes()
220 if len(bytes) > 0 && bytes[0]&0x80 != 0 {
221
222
223 return multiEncoder([]encoder{byte00Encoder, bytesEncoder(bytes)}), nil
224 }
225 return bytesEncoder(bytes), nil
226 }
227 }
228
229 func appendLength(dst []byte, i int) []byte {
230 n := lengthLength(i)
231
232 for ; n > 0; n-- {
233 dst = append(dst, byte(i>>uint((n-1)*8)))
234 }
235
236 return dst
237 }
238
// lengthLength reports how many octets appendLength needs to write i: one
// octet, plus one more for each extra byte of magnitude above 255.
func lengthLength(i int) (numBytes int) {
	for numBytes = 1; i > 255; numBytes++ {
		i >>= 8
	}
	return numBytes
}
247
248 func appendTagAndLength(dst []byte, t tagAndLength) []byte {
249 b := uint8(t.class) << 6
250 if t.isCompound {
251 b |= 0x20
252 }
253 if t.tag >= 31 {
254 b |= 0x1f
255 dst = append(dst, b)
256 dst = appendBase128Int(dst, int64(t.tag))
257 } else {
258 b |= uint8(t.tag)
259 dst = append(dst, b)
260 }
261
262 if t.length >= 128 {
263 l := lengthLength(t.length)
264 dst = append(dst, 0x80|byte(l))
265 dst = appendLength(dst, t.length)
266 } else {
267 dst = append(dst, byte(t.length))
268 }
269
270 return dst
271 }
272
273 type bitStringEncoder BitString
274
275 func (b bitStringEncoder) Len() int {
276 return len(b.Bytes) + 1
277 }
278
279 func (b bitStringEncoder) Encode(dst []byte) {
280 dst[0] = byte((8 - b.BitLength%8) % 8)
281 if copy(dst[1:], b.Bytes) != len(b.Bytes) {
282 panic("internal error")
283 }
284 }
285
286 type oidEncoder []int
287
288 func (oid oidEncoder) Len() int {
289 l := base128IntLength(int64(oid[0]*40 + oid[1]))
290 for i := 2; i < len(oid); i++ {
291 l += base128IntLength(int64(oid[i]))
292 }
293 return l
294 }
295
296 func (oid oidEncoder) Encode(dst []byte) {
297 dst = appendBase128Int(dst[:0], int64(oid[0]*40+oid[1]))
298 for i := 2; i < len(oid); i++ {
299 dst = appendBase128Int(dst, int64(oid[i]))
300 }
301 }
302
303 func makeObjectIdentifier(oid []int) (e encoder, err error) {
304 if len(oid) < 2 || oid[0] > 2 || (oid[0] < 2 && oid[1] >= 40) {
305 return nil, StructuralError{"invalid object identifier"}
306 }
307
308 return oidEncoder(oid), nil
309 }
310
311 func makePrintableString(s string) (e encoder, err error) {
312 for i := 0; i < len(s); i++ {
313
314
315
316
317
318
319 if !isPrintable(s[i], allowAsterisk, rejectAmpersand) {
320 return nil, StructuralError{"PrintableString contains invalid character"}
321 }
322 }
323
324 return stringEncoder(s), nil
325 }
326
327 func makeIA5String(s string) (e encoder, err error) {
328 for i := 0; i < len(s); i++ {
329 if s[i] > 127 {
330 return nil, StructuralError{"IA5String contains invalid character"}
331 }
332 }
333
334 return stringEncoder(s), nil
335 }
336
337 func makeNumericString(s string) (e encoder, err error) {
338 for i := 0; i < len(s); i++ {
339 if !isNumeric(s[i]) {
340 return nil, StructuralError{"NumericString contains invalid character"}
341 }
342 }
343
344 return stringEncoder(s), nil
345 }
346
347 func makeUTF8String(s string) encoder {
348 return stringEncoder(s)
349 }
350
// appendTwoDigits appends the last two decimal digits of v, zero-padded
// (7 -> "07", 123 -> "23").
func appendTwoDigits(dst []byte, v int) []byte {
	tens, ones := (v/10)%10, v%10
	return append(dst, byte('0'+tens), byte('0'+ones))
}
354
// appendFourDigits appends the last four decimal digits of v, zero-padded
// (7 -> "0007", 12345 -> "2345").
func appendFourDigits(dst []byte, v int) []byte {
	var digits [4]byte
	for i := len(digits) - 1; i >= 0; i-- {
		digits[i] = byte('0' + v%10)
		v /= 10
	}
	return append(dst, digits[:]...)
}
362
// outsideUTCRange reports whether t's year lies outside 1950–2049, the span
// of years a two-digit UTCTime can represent.
func outsideUTCRange(t time.Time) bool {
	y := t.Year()
	return !(1950 <= y && y < 2050)
}
367
368 func makeUTCTime(t time.Time) (e encoder, err error) {
369 dst := make([]byte, 0, 18)
370
371 dst, err = appendUTCTime(dst, t)
372 if err != nil {
373 return nil, err
374 }
375
376 return bytesEncoder(dst), nil
377 }
378
379 func makeGeneralizedTime(t time.Time) (e encoder, err error) {
380 dst := make([]byte, 0, 20)
381
382 dst, err = appendGeneralizedTime(dst, t)
383 if err != nil {
384 return nil, err
385 }
386
387 return bytesEncoder(dst), nil
388 }
389
390 func appendUTCTime(dst []byte, t time.Time) (ret []byte, err error) {
391 year := t.Year()
392
393 switch {
394 case 1950 <= year && year < 2000:
395 dst = appendTwoDigits(dst, year-1900)
396 case 2000 <= year && year < 2050:
397 dst = appendTwoDigits(dst, year-2000)
398 default:
399 return nil, StructuralError{"cannot represent time as UTCTime"}
400 }
401
402 return appendTimeCommon(dst, t), nil
403 }
404
405 func appendGeneralizedTime(dst []byte, t time.Time) (ret []byte, err error) {
406 year := t.Year()
407 if year < 0 || year > 9999 {
408 return nil, StructuralError{"cannot represent time as GeneralizedTime"}
409 }
410
411 dst = appendFourDigits(dst, year)
412
413 return appendTimeCommon(dst, t), nil
414 }
415
416 func appendTimeCommon(dst []byte, t time.Time) []byte {
417 _, month, day := t.Date()
418
419 dst = appendTwoDigits(dst, int(month))
420 dst = appendTwoDigits(dst, day)
421
422 hour, min, sec := t.Clock()
423
424 dst = appendTwoDigits(dst, hour)
425 dst = appendTwoDigits(dst, min)
426 dst = appendTwoDigits(dst, sec)
427
428 _, offset := t.Zone()
429
430 switch {
431 case offset/60 == 0:
432 return append(dst, 'Z')
433 case offset > 0:
434 dst = append(dst, '+')
435 case offset < 0:
436 dst = append(dst, '-')
437 }
438
439 offsetMinutes := offset / 60
440 if offsetMinutes < 0 {
441 offsetMinutes = -offsetMinutes
442 }
443
444 dst = appendTwoDigits(dst, offsetMinutes/60)
445 dst = appendTwoDigits(dst, offsetMinutes%60)
446
447 return dst
448 }
449
450 func stripTagAndLength(in []byte) []byte {
451 _, offset, err := parseTagAndLength(in, 0)
452 if err != nil {
453 return in
454 }
455 return in[offset:]
456 }
457
458 func makeBody(value reflect.Value, params fieldParameters) (e encoder, err error) {
459 switch value.Type() {
460 case flagType:
461 return bytesEncoder(nil), nil
462 case timeType:
463 t, _ := reflect.TypeAssert[time.Time](value)
464 if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
465 return makeGeneralizedTime(t)
466 }
467 return makeUTCTime(t)
468 case bitStringType:
469 v, _ := reflect.TypeAssert[BitString](value)
470 return bitStringEncoder(v), nil
471 case objectIdentifierType:
472 v, _ := reflect.TypeAssert[ObjectIdentifier](value)
473 return makeObjectIdentifier(v)
474 case bigIntType:
475 v, _ := reflect.TypeAssert[*big.Int](value)
476 return makeBigInt(v)
477 }
478
479 switch v := value; v.Kind() {
480 case reflect.Bool:
481 if v.Bool() {
482 return byteFFEncoder, nil
483 }
484 return byte00Encoder, nil
485 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
486 return int64Encoder(v.Int()), nil
487 case reflect.Struct:
488 t := v.Type()
489
490 for i := 0; i < t.NumField(); i++ {
491 if !t.Field(i).IsExported() {
492 return nil, StructuralError{"struct contains unexported fields"}
493 }
494 }
495
496 startingField := 0
497
498 n := t.NumField()
499 if n == 0 {
500 return bytesEncoder(nil), nil
501 }
502
503
504
505 if t.Field(0).Type == rawContentsType {
506 s := v.Field(0)
507 if s.Len() > 0 {
508 bytes := s.Bytes()
509
513 return bytesEncoder(stripTagAndLength(bytes)), nil
514 }
515
516 startingField = 1
517 }
518
519 switch n1 := n - startingField; n1 {
520 case 0:
521 return bytesEncoder(nil), nil
522 case 1:
523 return makeField(v.Field(startingField), parseFieldParameters(t.Field(startingField).Tag.Get("asn1")))
524 default:
525 m := make([]encoder, n1)
526 for i := 0; i < n1; i++ {
527 m[i], err = makeField(v.Field(i+startingField), parseFieldParameters(t.Field(i+startingField).Tag.Get("asn1")))
528 if err != nil {
529 return nil, err
530 }
531 }
532
533 return multiEncoder(m), nil
534 }
535 case reflect.Slice:
536 sliceType := v.Type()
537 if sliceType.Elem().Kind() == reflect.Uint8 {
538 return bytesEncoder(v.Bytes()), nil
539 }
540
541 var fp fieldParameters
542
543 switch l := v.Len(); l {
544 case 0:
545 return bytesEncoder(nil), nil
546 case 1:
547 return makeField(v.Index(0), fp)
548 default:
549 m := make([]encoder, l)
550
551 for i := 0; i < l; i++ {
552 m[i], err = makeField(v.Index(i), fp)
553 if err != nil {
554 return nil, err
555 }
556 }
557
558 if params.set {
559 return setEncoder(m), nil
560 }
561 return multiEncoder(m), nil
562 }
563 case reflect.String:
564 switch params.stringType {
565 case TagIA5String:
566 return makeIA5String(v.String())
567 case TagPrintableString:
568 return makePrintableString(v.String())
569 case TagNumericString:
570 return makeNumericString(v.String())
571 default:
572 return makeUTF8String(v.String()), nil
573 }
574 }
575
576 return nil, StructuralError{"unknown Go type"}
577 }
578
// makeField builds the complete encoder for one value: the tag-and-length
// header plus the content from makeBody. It applies the per-field parameters:
// omitempty/optional/default omission, explicit string and time types, set
// tagging, and application/private/context-specific tagging, either
// explicit or implicit.
//
// An omitted field yields an empty encoder (zero bytes, no header).
func makeField(v reflect.Value, params fieldParameters) (e encoder, err error) {
	// An invalid Value comes from a nil interface (e.g. Marshal(nil)).
	if !v.IsValid() {
		return nil, fmt.Errorf("asn1: cannot marshal nil value")
	}

	// For an empty interface (any), encode the dynamic value it holds.
	if v.Kind() == reflect.Interface && v.Type().NumMethod() == 0 {
		return makeField(v.Elem(), params)
	}

	// omitempty: an empty slice is left out entirely.
	if v.Kind() == reflect.Slice && v.Len() == 0 && params.omitEmpty {
		return bytesEncoder(nil), nil
	}

	// optional with a default: omit the field when it equals the default.
	// SetInt is used, so canHaveDefaultValue presumably admits only integer
	// kinds; confirm against its definition.
	if params.optional && params.defaultValue != nil && canHaveDefaultValue(v.Kind()) {
		defaultValue := reflect.New(v.Type()).Elem()
		defaultValue.SetInt(*params.defaultValue)

		if reflect.DeepEqual(v.Interface(), defaultValue.Interface()) {
			return bytesEncoder(nil), nil
		}
	}

	// optional without a default: omit the field when it holds its type's
	// zero value.
	if params.optional && params.defaultValue == nil {
		if reflect.DeepEqual(v.Interface(), reflect.Zero(v.Type()).Interface()) {
			return bytesEncoder(nil), nil
		}
	}

	// A RawValue is emitted as-is. FullBytes, when present, is taken as the
	// complete encoding. Otherwise a header is built from Class, Tag and
	// IsCompound, and Bytes becomes the content.
	if v.Type() == rawValueType {
		rv, _ := reflect.TypeAssert[RawValue](v)
		if len(rv.FullBytes) != 0 {
			return bytesEncoder(rv.FullBytes), nil
		}

		t := new(taggedEncoder)

		t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{rv.Class, rv.Tag, len(rv.Bytes), rv.IsCompound}))
		t.body = bytesEncoder(rv.Bytes)

		return t, nil
	}

	// Look up the universal tag for v's type. A type getUniversalType does
	// not recognize, or one it reports as matchAny, cannot be marshalled.
	matchAny, tag, isCompound, ok := getUniversalType(v.Type())
	if !ok || matchAny {
		return nil, StructuralError{fmt.Sprintf("unknown Go type: %v", v.Type())}
	}

	// These checks assume getUniversalType reports TagUTCTime for time
	// values and TagPrintableString for strings (TODO confirm). An explicit
	// time or string type on any other kind of member is rejected.
	if params.timeType != 0 && tag != TagUTCTime {
		return nil, StructuralError{"explicit time type given to non-time member"}
	}

	if params.stringType != 0 && tag != TagPrintableString {
		return nil, StructuralError{"explicit string type given to non-string member"}
	}

	switch tag {
	case TagPrintableString:
		if params.stringType == 0 {
			// No explicit string type: keep PrintableString if every rune
			// is an ASCII byte accepted by isPrintable (rejecting '*' and
			// '&'). Otherwise fall back to UTF8String, which requires s to
			// be valid UTF-8.
			for _, r := range v.String() {
				if r >= utf8.RuneSelf || !isPrintable(byte(r), rejectAsterisk, rejectAmpersand) {
					if !utf8.ValidString(v.String()) {
						return nil, errors.New("asn1: string not valid UTF-8")
					}
					tag = TagUTF8String
					// Leaves the range loop, not the switch.
					break
				}
			}
		} else {
			tag = params.stringType
		}
	case TagUTCTime:
		// Use GeneralizedTime when it was requested or when the year does
		// not fit UTCTime's two-digit range.
		t, _ := reflect.TypeAssert[time.Time](v)
		if params.timeType == TagGeneralizedTime || outsideUTCRange(t) {
			tag = TagGeneralizedTime
		}
	}

	// Only a SEQUENCE may be re-tagged as a SET.
	if params.set {
		if tag != TagSequence {
			return nil, StructuralError{"non sequence tagged as set"}
		}
		tag = TagSet
	}

	// A type whose universal tag is already SET gets set semantics as well,
	// so makeBody sorts its elements (setEncoder).
	if tag == TagSet && !params.set {
		params.set = true
	}

	t := new(taggedEncoder)

	t.body, err = makeBody(v, params)
	if err != nil {
		return nil, err
	}

	bodyLen := t.body.Len()

	// A custom tag number selects the class: application, private, or
	// (by default) context-specific.
	class := ClassUniversal
	if params.tag != nil {
		if params.application {
			class = ClassApplication
		} else if params.private {
			class = ClassPrivate
		} else {
			class = ClassContextSpecific
		}

		// Explicit tagging keeps the universal header around the body and
		// wraps it in an outer, always-constructed header that carries the
		// requested class and tag number.
		if params.explicit {
			t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{ClassUniversal, tag, bodyLen, isCompound}))

			tt := new(taggedEncoder)

			tt.body = t

			tt.tag = bytesEncoder(appendTagAndLength(tt.scratch[:0], tagAndLength{
				class:      class,
				tag:        *params.tag,
				length:     bodyLen + t.tag.Len(),
				isCompound: true,
			}))

			return tt, nil
		}

		// Implicit tagging: the requested tag number replaces the universal one.
		tag = *params.tag
	}

	t.tag = bytesEncoder(appendTagAndLength(t.scratch[:0], tagAndLength{class, tag, bodyLen, isCompound}))

	return t, nil
}
722
723
724
725
726
727
728
729
730
731
732
733
734
735 func Marshal(val any) ([]byte, error) {
736 return MarshalWithParams(val, "")
737 }
738
739
740
741 func MarshalWithParams(val any, params string) ([]byte, error) {
742 e, err := makeField(reflect.ValueOf(val), parseFieldParameters(params))
743 if err != nil {
744 return nil, err
745 }
746 b := make([]byte, e.Len())
747 e.Encode(b)
748 return b, nil
749 }
750
View as plain text