Source file src/crypto/subtle/constant_time_test.go

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package subtle
     6  
     7  import (
     8  	"testing"
     9  	"testing/quick"
    10  )
    11  
    12  type TestConstantTimeCompareStruct struct {
    13  	a, b []byte
    14  	out  int
    15  }
    16  
    17  var testConstantTimeCompareData = []TestConstantTimeCompareStruct{
    18  	{[]byte{}, []byte{}, 1},
    19  	{[]byte{0x11}, []byte{0x11}, 1},
    20  	{[]byte{0x12}, []byte{0x11}, 0},
    21  	{[]byte{0x11}, []byte{0x11, 0x12}, 0},
    22  	{[]byte{0x11, 0x12}, []byte{0x11}, 0},
    23  }
    24  
    25  func TestConstantTimeCompare(t *testing.T) {
    26  	for i, test := range testConstantTimeCompareData {
    27  		if r := ConstantTimeCompare(test.a, test.b); r != test.out {
    28  			t.Errorf("#%d bad result (got %x, want %x)", i, r, test.out)
    29  		}
    30  	}
    31  }
    32  
    33  type TestConstantTimeByteEqStruct struct {
    34  	a, b uint8
    35  	out  int
    36  }
    37  
    38  var testConstandTimeByteEqData = []TestConstantTimeByteEqStruct{
    39  	{0, 0, 1},
    40  	{0, 1, 0},
    41  	{1, 0, 0},
    42  	{0xff, 0xff, 1},
    43  	{0xff, 0xfe, 0},
    44  }
    45  
    46  func byteEq(a, b uint8) int {
    47  	if a == b {
    48  		return 1
    49  	}
    50  	return 0
    51  }
    52  
    53  func TestConstantTimeByteEq(t *testing.T) {
    54  	for i, test := range testConstandTimeByteEqData {
    55  		if r := ConstantTimeByteEq(test.a, test.b); r != test.out {
    56  			t.Errorf("#%d bad result (got %x, want %x)", i, r, test.out)
    57  		}
    58  	}
    59  	err := quick.CheckEqual(ConstantTimeByteEq, byteEq, nil)
    60  	if err != nil {
    61  		t.Error(err)
    62  	}
    63  }
    64  
    65  func eq(a, b int32) int {
    66  	if a == b {
    67  		return 1
    68  	}
    69  	return 0
    70  }
    71  
    72  func TestConstantTimeEq(t *testing.T) {
    73  	err := quick.CheckEqual(ConstantTimeEq, eq, nil)
    74  	if err != nil {
    75  		t.Error(err)
    76  	}
    77  }
    78  
    79  func makeCopy(v int, x, y []byte) []byte {
    80  	if len(x) > len(y) {
    81  		x = x[:len(y)]
    82  	} else {
    83  		y = y[:len(x)]
    84  	}
    85  	if v == 1 {
    86  		copy(x, y)
    87  	}
    88  	return x
    89  }
    90  
    91  func constantTimeCopyWrapper(v int, x, y []byte) []byte {
    92  	if len(x) > len(y) {
    93  		x = x[:len(y)]
    94  	} else {
    95  		y = y[:len(x)]
    96  	}
    97  	v &= 1
    98  	ConstantTimeCopy(v, x, y)
    99  	return x
   100  }
   101  
   102  func TestConstantTimeCopy(t *testing.T) {
   103  	err := quick.CheckEqual(constantTimeCopyWrapper, makeCopy, nil)
   104  	if err != nil {
   105  		t.Error(err)
   106  	}
   107  }
   108  
   109  var lessOrEqTests = []struct {
   110  	x, y, result int
   111  }{
   112  	{0, 0, 1},
   113  	{1, 0, 0},
   114  	{0, 1, 1},
   115  	{10, 20, 1},
   116  	{20, 10, 0},
   117  	{10, 10, 1},
   118  }
   119  
   120  func TestConstantTimeLessOrEq(t *testing.T) {
   121  	for i, test := range lessOrEqTests {
   122  		result := ConstantTimeLessOrEq(test.x, test.y)
   123  		if result != test.result {
   124  			t.Errorf("#%d: %d <= %d gave %d, expected %d", i, test.x, test.y, result, test.result)
   125  		}
   126  	}
   127  }
   128  
   129  var benchmarkGlobal uint8
   130  
   131  func BenchmarkConstantTimeSelect(b *testing.B) {
   132  	x := int(benchmarkGlobal)
   133  	var y, z int
   134  
   135  	for range b.N {
   136  		y, z, x = ConstantTimeSelect(x, y, z), y, z
   137  	}
   138  
   139  	benchmarkGlobal = uint8(x)
   140  }
   141  
   142  func BenchmarkConstantTimeByteEq(b *testing.B) {
   143  	var x, y uint8
   144  
   145  	for i := 0; i < b.N; i++ {
   146  		x, y = uint8(ConstantTimeByteEq(x, y)), x
   147  	}
   148  
   149  	benchmarkGlobal = x
   150  }
   151  
   152  func BenchmarkConstantTimeEq(b *testing.B) {
   153  	var x, y int
   154  
   155  	for i := 0; i < b.N; i++ {
   156  		x, y = ConstantTimeEq(int32(x), int32(y)), x
   157  	}
   158  
   159  	benchmarkGlobal = uint8(x)
   160  }
   161  
   162  func BenchmarkConstantTimeLessOrEq(b *testing.B) {
   163  	var x, y int
   164  
   165  	for i := 0; i < b.N; i++ {
   166  		x, y = ConstantTimeLessOrEq(x, y), x
   167  	}
   168  
   169  	benchmarkGlobal = uint8(x)
   170  }
   171  

View as plain text