From c538fc9eb1398114c36c97390d5dce58bff79dc6 Mon Sep 17 00:00:00 2001 From: HarveyKandola Date: Tue, 25 Jun 2019 15:33:51 +0100 Subject: [PATCH] Update PostgreSQL driver library --- vendor/github.com/lib/pq/README.md | 2 +- vendor/github.com/lib/pq/array_test.go | 1311 +++++++++++++ vendor/github.com/lib/pq/bench_test.go | 434 ++++ vendor/github.com/lib/pq/certs/README | 3 + vendor/github.com/lib/pq/certs/postgresql.key | 15 + vendor/github.com/lib/pq/certs/server.key | 27 + vendor/github.com/lib/pq/conn.go | 295 +-- vendor/github.com/lib/pq/conn_go18.go | 28 +- vendor/github.com/lib/pq/conn_test.go | 1741 +++++++++++++++++ vendor/github.com/lib/pq/connector.go | 91 +- .../lib/pq/connector_example_test.go | 33 + vendor/github.com/lib/pq/connector_test.go | 67 + vendor/github.com/lib/pq/copy_test.go | 468 +++++ vendor/github.com/lib/pq/doc.go | 2 +- vendor/github.com/lib/pq/encode.go | 9 +- vendor/github.com/lib/pq/encode_test.go | 766 ++++++++ .../github.com/lib/pq/example/listen/doc.go | 98 + vendor/github.com/lib/pq/go18_test.go | 319 +++ vendor/github.com/lib/pq/go19_test.go | 69 + vendor/github.com/lib/pq/hstore/hstore.go | 118 ++ .../github.com/lib/pq/hstore/hstore_test.go | 148 ++ vendor/github.com/lib/pq/issues_test.go | 26 + vendor/github.com/lib/pq/notify_test.go | 570 ++++++ vendor/github.com/lib/pq/rows_test.go | 218 +++ vendor/github.com/lib/pq/scram/scram.go | 264 +++ vendor/github.com/lib/pq/ssl.go | 8 +- vendor/github.com/lib/pq/ssl_go1.7.go | 14 - vendor/github.com/lib/pq/ssl_renegotiation.go | 8 - vendor/github.com/lib/pq/ssl_test.go | 279 +++ vendor/github.com/lib/pq/url_test.go | 66 + vendor/github.com/lib/pq/uuid_test.go | 46 + 31 files changed, 7381 insertions(+), 162 deletions(-) create mode 100644 vendor/github.com/lib/pq/array_test.go create mode 100644 vendor/github.com/lib/pq/bench_test.go create mode 100644 vendor/github.com/lib/pq/certs/README create mode 100644 vendor/github.com/lib/pq/certs/postgresql.key create mode 100644 vendor/github.com/lib/pq/certs/server.key create mode 100644 vendor/github.com/lib/pq/conn_test.go create mode 100644 vendor/github.com/lib/pq/connector_example_test.go create mode 100644 vendor/github.com/lib/pq/connector_test.go create mode 100644 vendor/github.com/lib/pq/copy_test.go create mode 100644 vendor/github.com/lib/pq/encode_test.go create mode 100644 vendor/github.com/lib/pq/example/listen/doc.go create mode 100644 vendor/github.com/lib/pq/go18_test.go create mode 100644 vendor/github.com/lib/pq/go19_test.go create mode 100644 vendor/github.com/lib/pq/hstore/hstore.go create mode 100644 vendor/github.com/lib/pq/hstore/hstore_test.go create mode 100644 vendor/github.com/lib/pq/issues_test.go create mode 100644 vendor/github.com/lib/pq/notify_test.go create mode 100644 vendor/github.com/lib/pq/rows_test.go create mode 100644 vendor/github.com/lib/pq/scram/scram.go delete mode 100644 vendor/github.com/lib/pq/ssl_go1.7.go delete mode 100644 vendor/github.com/lib/pq/ssl_renegotiation.go create mode 100644 vendor/github.com/lib/pq/ssl_test.go create mode 100644 vendor/github.com/lib/pq/url_test.go create mode 100644 vendor/github.com/lib/pq/uuid_test.go diff --git a/vendor/github.com/lib/pq/README.md b/vendor/github.com/lib/pq/README.md index d71f3c2c..385fe735 100644 --- a/vendor/github.com/lib/pq/README.md +++ b/vendor/github.com/lib/pq/README.md @@ -10,7 +10,7 @@ ## Docs For detailed documentation and basic usage examples, please see the package -documentation at . +documentation at . ## Tests diff --git a/vendor/github.com/lib/pq/array_test.go b/vendor/github.com/lib/pq/array_test.go new file mode 100644 index 00000000..f724bcd8 --- /dev/null +++ b/vendor/github.com/lib/pq/array_test.go @@ -0,0 +1,1311 @@ +package pq + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "math/rand" + "reflect" + "strings" + "testing" +) + +func TestParseArray(t *testing.T) { + for _, tt := range []struct { + input string + delim string + dims []int + elems [][]byte + }{ + {`{}`, `,`, nil, [][]byte{}}, + {`{NULL}`, `,`, []int{1}, [][]byte{nil}}, + {`{a}`, `,`, []int{1}, [][]byte{{'a'}}}, + {`{a,b}`, `,`, []int{2}, [][]byte{{'a'}, {'b'}}}, + {`{{a,b}}`, `,`, []int{1, 2}, [][]byte{{'a'}, {'b'}}}, + {`{{a},{b}}`, `,`, []int{2, 1}, [][]byte{{'a'}, {'b'}}}, + {`{{{a,b},{c,d},{e,f}}}`, `,`, []int{1, 3, 2}, [][]byte{ + {'a'}, {'b'}, {'c'}, {'d'}, {'e'}, {'f'}, + }}, + {`{""}`, `,`, []int{1}, [][]byte{{}}}, + {`{","}`, `,`, []int{1}, [][]byte{{','}}}, + {`{",",","}`, `,`, []int{2}, [][]byte{{','}, {','}}}, + {`{{",",","}}`, `,`, []int{1, 2}, [][]byte{{','}, {','}}}, + {`{{","},{","}}`, `,`, []int{2, 1}, [][]byte{{','}, {','}}}, + {`{{{",",","},{",",","},{",",","}}}`, `,`, []int{1, 3, 2}, [][]byte{ + {','}, {','}, {','}, {','}, {','}, {','}, + }}, + {`{"\"}"}`, `,`, []int{1}, [][]byte{{'"', '}'}}}, + {`{"\"","\""}`, `,`, []int{2}, [][]byte{{'"'}, {'"'}}}, + {`{{"\"","\""}}`, `,`, []int{1, 2}, [][]byte{{'"'}, {'"'}}}, + {`{{"\""},{"\""}}`, `,`, []int{2, 1}, [][]byte{{'"'}, {'"'}}}, + {`{{{"\"","\""},{"\"","\""},{"\"","\""}}}`, `,`, []int{1, 3, 2}, [][]byte{ + {'"'}, {'"'}, {'"'}, {'"'}, {'"'}, {'"'}, + }}, + {`{axyzb}`, `xyz`, []int{2}, [][]byte{{'a'}, {'b'}}}, + } { + dims, elems, err := parseArray([]byte(tt.input), []byte(tt.delim)) + + if err != nil { + t.Fatalf("Expected no error for %q, got %q", tt.input, err) + } + if !reflect.DeepEqual(dims, tt.dims) { + t.Errorf("Expected %v dimensions for %q, got %v", tt.dims, tt.input, dims) + } + if !reflect.DeepEqual(elems, tt.elems) { + t.Errorf("Expected %v elements for %q, got %v", tt.elems, tt.input, elems) + } + } +} + +func TestParseArrayError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "expected '{' at offset 0"}, + {`x`, "expected '{' at offset 0"}, + {`}`, "expected '{' at offset 0"}, + {`{`, "expected '}' at offset 1"}, + {`{{}`, "expected '}' at offset 3"}, + {`{}}`, "unexpected '}' at offset 2"}, + {`{,}`, "unexpected ',' at offset 1"}, + {`{,x}`, "unexpected ',' at offset 1"}, + {`{x,}`, "unexpected '}' at offset 3"}, + {`{x,{`, "unexpected '{' at offset 3"}, + {`{x},`, "unexpected ',' at offset 3"}, + {`{x}}`, "unexpected '}' at offset 3"}, + {`{{x}`, "expected '}' at offset 4"}, + {`{""x}`, "unexpected 'x' at offset 3"}, + {`{{a},{b,c}}`, "multidimensional arrays must have elements with matching dimensions"}, + } { + _, _, err := parseArray([]byte(tt.input), []byte{','}) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + } +} + +func TestArrayScanner(t *testing.T) { + var s sql.Scanner = Array(&[]bool{}) + if _, ok := s.(*BoolArray); !ok { + t.Errorf("Expected *BoolArray, got %T", s) + } + + s = Array(&[]float64{}) + if _, ok := s.(*Float64Array); !ok { + t.Errorf("Expected *Float64Array, got %T", s) + } + + s = Array(&[]int64{}) + if _, ok := s.(*Int64Array); !ok { + t.Errorf("Expected *Int64Array, got %T", s) + } + + s = Array(&[]string{}) + if _, ok := s.(*StringArray); !ok { + t.Errorf("Expected *StringArray, got %T", s) + } + + for _, tt := range []interface{}{ + &[]sql.Scanner{}, + &[][]bool{}, + &[][]float64{}, + &[][]int64{}, + &[][]string{}, + } { + s = Array(tt) + if _, ok := s.(GenericArray); !ok { + t.Errorf("Expected GenericArray for %T, got %T", tt, s) + } + } +} + +func TestArrayValuer(t *testing.T) { + var v driver.Valuer = Array([]bool{}) + if _, ok := v.(*BoolArray); !ok { + t.Errorf("Expected *BoolArray, got %T", v) + } + + v = Array([]float64{}) + if _, ok := v.(*Float64Array); !ok { + t.Errorf("Expected *Float64Array, got %T", v) + } + + v = Array([]int64{}) + if _, ok := v.(*Int64Array); !ok { + t.Errorf("Expected *Int64Array, got %T", v) + } + + v = Array([]string{}) + if _, ok := v.(*StringArray); !ok { + t.Errorf("Expected *StringArray, got %T", v) + } + + for _, tt := range []interface{}{ + nil, + []driver.Value{}, + [][]bool{}, + [][]float64{}, + [][]int64{}, + [][]string{}, + } { + v = Array(tt) + if _, ok := v.(GenericArray); !ok { + t.Errorf("Expected GenericArray for %T, got %T", tt, v) + } + } +} + +func TestBoolArrayScanUnsupported(t *testing.T) { + var arr BoolArray + err := arr.Scan(1) + + if err == nil { + t.Fatal("Expected error when scanning from int") + } + if !strings.Contains(err.Error(), "int to BoolArray") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestBoolArrayScanEmpty(t *testing.T) { + var arr BoolArray + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestBoolArrayScanNil(t *testing.T) { + arr := BoolArray{true, true, true} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var BoolArrayStringTests = []struct { + str string + arr BoolArray +}{ + {`{}`, BoolArray{}}, + {`{t}`, BoolArray{true}}, + {`{f,t}`, BoolArray{false, true}}, +} + +func TestBoolArrayScanBytes(t *testing.T) { + for _, tt := range BoolArrayStringTests { + bytes := []byte(tt.str) + arr := BoolArray{true, true, true} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkBoolArrayScanBytes(b *testing.B) { + var a BoolArray + var x interface{} = []byte(`{t,f,t,f,t,f,t,f,t,f}`) + + for i := 0; i < b.N; i++ { + a = BoolArray{} + a.Scan(x) + } +} + +func TestBoolArrayScanString(t *testing.T) { + for _, tt := range BoolArrayStringTests { + arr := BoolArray{true, true, true} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestBoolArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{t},{f}}`, "cannot convert ARRAY[2][1] to BoolArray"}, + {`{NULL}`, `could not parse boolean array index 0: invalid boolean ""`}, + {`{a}`, `could not parse boolean array index 0: invalid boolean "a"`}, + {`{t,b}`, `could not parse boolean array index 1: invalid boolean "b"`}, + {`{t,f,cd}`, `could not parse boolean array index 2: invalid boolean "cd"`}, + } { + arr := BoolArray{true, true, true} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, BoolArray{true, true, true}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestBoolArrayValue(t *testing.T) { + result, err := BoolArray(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = BoolArray([]bool{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = BoolArray([]bool{false, true, false}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{f,t,f}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkBoolArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]bool, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Intn(2) == 0 + } + a := BoolArray(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestByteaArrayScanUnsupported(t *testing.T) { + var arr ByteaArray + err := arr.Scan(1) + + if err == nil { + t.Fatal("Expected error when scanning from int") + } + if !strings.Contains(err.Error(), "int to ByteaArray") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestByteaArrayScanEmpty(t *testing.T) { + var arr ByteaArray + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestByteaArrayScanNil(t *testing.T) { + arr := ByteaArray{{2}, {6}, {0, 0}} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var ByteaArrayStringTests = []struct { + str string + arr ByteaArray +}{ + {`{}`, ByteaArray{}}, + {`{NULL}`, ByteaArray{nil}}, + {`{"\\xfeff"}`, ByteaArray{{'\xFE', '\xFF'}}}, + {`{"\\xdead","\\xbeef"}`, ByteaArray{{'\xDE', '\xAD'}, {'\xBE', '\xEF'}}}, +} + +func TestByteaArrayScanBytes(t *testing.T) { + for _, tt := range ByteaArrayStringTests { + bytes := []byte(tt.str) + arr := ByteaArray{{2}, {6}, {0, 0}} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkByteaArrayScanBytes(b *testing.B) { + var a ByteaArray + var x interface{} = []byte(`{"\\xfe","\\xff","\\xdead","\\xbeef","\\xfe","\\xff","\\xdead","\\xbeef","\\xfe","\\xff"}`) + + for i := 0; i < b.N; i++ { + a = ByteaArray{} + a.Scan(x) + } +} + +func TestByteaArrayScanString(t *testing.T) { + for _, tt := range ByteaArrayStringTests { + arr := ByteaArray{{2}, {6}, {0, 0}} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestByteaArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{"\\xfeff"},{"\\xbeef"}}`, "cannot convert ARRAY[2][1] to ByteaArray"}, + {`{"\\abc"}`, "could not parse bytea array index 0: could not parse bytea value"}, + } { + arr := ByteaArray{{2}, {6}, {0, 0}} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, ByteaArray{{2}, {6}, {0, 0}}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestByteaArrayValue(t *testing.T) { + result, err := ByteaArray(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = ByteaArray([][]byte{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = ByteaArray([][]byte{{'\xDE', '\xAD', '\xBE', '\xEF'}, {'\xFE', '\xFF'}, {}}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{"\\xdeadbeef","\\xfeff","\\x"}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkByteaArrayValue(b *testing.B) { + rand.Seed(1) + x := make([][]byte, 10) + for i := 0; i < len(x); i++ { + x[i] = make([]byte, len(x)) + for j := 0; j < len(x); j++ { + x[i][j] = byte(rand.Int()) + } + } + a := ByteaArray(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestFloat64ArrayScanUnsupported(t *testing.T) { + var arr Float64Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Float64Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestFloat64ArrayScanEmpty(t *testing.T) { + var arr Float64Array + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestFloat64ArrayScanNil(t *testing.T) { + arr := Float64Array{5, 5, 5} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var Float64ArrayStringTests = []struct { + str string + arr Float64Array +}{ + {`{}`, Float64Array{}}, + {`{1.2}`, Float64Array{1.2}}, + {`{3.456,7.89}`, Float64Array{3.456, 7.89}}, + {`{3,1,2}`, Float64Array{3, 1, 2}}, +} + +func TestFloat64ArrayScanBytes(t *testing.T) { + for _, tt := range Float64ArrayStringTests { + bytes := []byte(tt.str) + arr := Float64Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkFloat64ArrayScanBytes(b *testing.B) { + var a Float64Array + var x interface{} = []byte(`{1.2,3.4,5.6,7.8,9.01,2.34,5.67,8.90,1.234,5.678}`) + + for i := 0; i < b.N; i++ { + a = Float64Array{} + a.Scan(x) + } +} + +func TestFloat64ArrayScanString(t *testing.T) { + for _, tt := range Float64ArrayStringTests { + arr := Float64Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestFloat64ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5.6},{7.8}}`, "cannot convert ARRAY[2][1] to Float64Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5.6,a}`, "parsing array element index 1:"}, + {`{5.6,7.8,a}`, "parsing array element index 2:"}, + } { + arr := Float64Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Float64Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestFloat64ArrayValue(t *testing.T) { + result, err := Float64Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Float64Array([]float64{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Float64Array([]float64{1.2, 3.4, 5.6}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1.2,3.4,5.6}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkFloat64ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]float64, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.NormFloat64() + } + a := Float64Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestInt64ArrayScanUnsupported(t *testing.T) { + var arr Int64Array + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to Int64Array") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestInt64ArrayScanEmpty(t *testing.T) { + var arr Int64Array + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestInt64ArrayScanNil(t *testing.T) { + arr := Int64Array{5, 5, 5} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var Int64ArrayStringTests = []struct { + str string + arr Int64Array +}{ + {`{}`, Int64Array{}}, + {`{12}`, Int64Array{12}}, + {`{345,678}`, Int64Array{345, 678}}, +} + +func TestInt64ArrayScanBytes(t *testing.T) { + for _, tt := range Int64ArrayStringTests { + bytes := []byte(tt.str) + arr := Int64Array{5, 5, 5} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkInt64ArrayScanBytes(b *testing.B) { + var a Int64Array + var x interface{} = []byte(`{1,2,3,4,5,6,7,8,9,0}`) + + for i := 0; i < b.N; i++ { + a = Int64Array{} + a.Scan(x) + } +} + +func TestInt64ArrayScanString(t *testing.T) { + for _, tt := range Int64ArrayStringTests { + arr := Int64Array{5, 5, 5} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestInt64ArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{5},{6}}`, "cannot convert ARRAY[2][1] to Int64Array"}, + {`{NULL}`, "parsing array element index 0:"}, + {`{a}`, "parsing array element index 0:"}, + {`{5,a}`, "parsing array element index 1:"}, + {`{5,6,a}`, "parsing array element index 2:"}, + } { + arr := Int64Array{5, 5, 5} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, Int64Array{5, 5, 5}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestInt64ArrayValue(t *testing.T) { + result, err := Int64Array(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = Int64Array([]int64{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = Int64Array([]int64{1, 2, 3}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{1,2,3}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkInt64ArrayValue(b *testing.B) { + rand.Seed(1) + x := make([]int64, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Int63() + } + a := Int64Array(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestStringArrayScanUnsupported(t *testing.T) { + var arr StringArray + err := arr.Scan(true) + + if err == nil { + t.Fatal("Expected error when scanning from bool") + } + if !strings.Contains(err.Error(), "bool to StringArray") { + t.Errorf("Expected type to be mentioned when scanning, got %q", err) + } +} + +func TestStringArrayScanEmpty(t *testing.T) { + var arr StringArray + err := arr.Scan(`{}`) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr == nil || len(arr) != 0 { + t.Errorf("Expected empty, got %#v", arr) + } +} + +func TestStringArrayScanNil(t *testing.T) { + arr := StringArray{"x", "x", "x"} + err := arr.Scan(nil) + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if arr != nil { + t.Errorf("Expected nil, got %+v", arr) + } +} + +var StringArrayStringTests = []struct { + str string + arr StringArray +}{ + {`{}`, StringArray{}}, + {`{t}`, StringArray{"t"}}, + {`{f,1}`, StringArray{"f", "1"}}, + {`{"a\\b","c d",","}`, StringArray{"a\\b", "c d", ","}}, +} + +func TestStringArrayScanBytes(t *testing.T) { + for _, tt := range StringArrayStringTests { + bytes := []byte(tt.str) + arr := StringArray{"x", "x", "x"} + err := arr.Scan(bytes) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", bytes, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, bytes, arr) + } + } +} + +func BenchmarkStringArrayScanBytes(b *testing.B) { + var a StringArray + var x interface{} = []byte(`{a,b,c,d,e,f,g,h,i,j}`) + var y interface{} = []byte(`{"\a","\b","\c","\d","\e","\f","\g","\h","\i","\j"}`) + + for i := 0; i < b.N; i++ { + a = StringArray{} + a.Scan(x) + a = StringArray{} + a.Scan(y) + } +} + +func TestStringArrayScanString(t *testing.T) { + for _, tt := range StringArrayStringTests { + arr := StringArray{"x", "x", "x"} + err := arr.Scan(tt.str) + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.str, err) + } + if !reflect.DeepEqual(arr, tt.arr) { + t.Errorf("Expected %+v for %q, got %+v", tt.arr, tt.str, arr) + } + } +} + +func TestStringArrayScanError(t *testing.T) { + for _, tt := range []struct { + input, err string + }{ + {``, "unable to parse array"}, + {`{`, "unable to parse array"}, + {`{{a},{b}}`, "cannot convert ARRAY[2][1] to StringArray"}, + {`{NULL}`, "parsing array element index 0: cannot convert nil to string"}, + {`{a,NULL}`, "parsing array element index 1: cannot convert nil to string"}, + {`{a,b,NULL}`, "parsing array element index 2: cannot convert nil to string"}, + } { + arr := StringArray{"x", "x", "x"} + err := arr.Scan(tt.input) + + if err == nil { + t.Fatalf("Expected error for %q, got none", tt.input) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for %q, got %q", tt.err, tt.input, err) + } + if !reflect.DeepEqual(arr, StringArray{"x", "x", "x"}) { + t.Errorf("Expected destination not to change for %q, got %+v", tt.input, arr) + } + } +} + +func TestStringArrayValue(t *testing.T) { + result, err := StringArray(nil).Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + result, err = StringArray([]string{}).Value() + + if err != nil { + t.Fatalf("Expected no error for empty, got %v", err) + } + if expected := `{}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected empty, got %q", result) + } + + result, err = StringArray([]string{`a`, `\b`, `c"`, `d,e`}).Value() + + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if expected := `{"a","\\b","c\"","d,e"}`; !reflect.DeepEqual(result, expected) { + t.Errorf("Expected %q, got %q", expected, result) + } +} + +func BenchmarkStringArrayValue(b *testing.B) { + x := make([]string, 10) + for i := 0; i < len(x); i++ { + x[i] = strings.Repeat(`abc"def\ghi`, 5) + } + a := StringArray(x) + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestGenericArrayScanUnsupported(t *testing.T) { + var s string + var ss []string + var nsa [1]sql.NullString + + for _, tt := range []struct { + src, dest interface{} + err string + }{ + {nil, nil, "destination is not a pointer to array or slice"}, + {nil, true, "destination bool is not a pointer to array or slice"}, + {nil, &s, "destination *string is not a pointer to array or slice"}, + {nil, ss, "destination []string is not a pointer to array or slice"}, + {nil, &nsa, " to [1]sql.NullString"}, + {true, &ss, "bool to []string"}, + {`{{x}}`, &ss, "multidimensional ARRAY[1][1] is not implemented"}, + {`{{x},{x}}`, &ss, "multidimensional ARRAY[2][1] is not implemented"}, + {`{x}`, &ss, "scanning to string is not implemented"}, + } { + err := GenericArray{tt.dest}.Scan(tt.src) + + if err == nil { + t.Fatalf("Expected error for [%#v %#v]", tt.src, tt.dest) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for [%#v %#v], got %q", tt.err, tt.src, tt.dest, err) + } + } +} + +func TestGenericArrayScanScannerArrayBytes(t *testing.T) { + src, expected, nsa := []byte(`{NULL,abc,"\""}`), + [3]sql.NullString{{}, {String: `abc`, Valid: true}, {String: `"`, Valid: true}}, + [3]sql.NullString{{String: ``, Valid: true}, {}, {}} + + if err := (GenericArray{&nsa}).Scan(src); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(nsa, expected) { + t.Errorf("Expected %v, got %v", expected, nsa) + } +} + +func TestGenericArrayScanScannerArrayString(t *testing.T) { + src, expected, nsa := `{NULL,"\"",xyz}`, + [3]sql.NullString{{}, {String: `"`, Valid: true}, {String: `xyz`, Valid: true}}, + [3]sql.NullString{{String: ``, Valid: true}, {}, {}} + + if err := (GenericArray{&nsa}).Scan(src); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(nsa, expected) { + t.Errorf("Expected %v, got %v", expected, nsa) + } +} + +func TestGenericArrayScanScannerSliceEmpty(t *testing.T) { + var nss []sql.NullString + + if err := (GenericArray{&nss}).Scan(`{}`); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if nss == nil || len(nss) != 0 { + t.Errorf("Expected empty, got %#v", nss) + } +} + +func TestGenericArrayScanScannerSliceNil(t *testing.T) { + nss := []sql.NullString{{String: ``, Valid: true}, {}} + + if err := (GenericArray{&nss}).Scan(nil); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if nss != nil { + t.Errorf("Expected nil, got %+v", nss) + } +} + +func TestGenericArrayScanScannerSliceBytes(t *testing.T) { + src, expected, nss := []byte(`{NULL,abc,"\""}`), + []sql.NullString{{}, {String: `abc`, Valid: true}, {String: `"`, Valid: true}}, + []sql.NullString{{String: ``, Valid: true}, {}, {}, {}, {}} + + if err := (GenericArray{&nss}).Scan(src); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(nss, expected) { + t.Errorf("Expected %v, got %v", expected, nss) + } +} + +func BenchmarkGenericArrayScanScannerSliceBytes(b *testing.B) { + var a GenericArray + var x interface{} = []byte(`{a,b,c,d,e,f,g,h,i,j}`) + var y interface{} = []byte(`{"\a","\b","\c","\d","\e","\f","\g","\h","\i","\j"}`) + + for i := 0; i < b.N; i++ { + a = GenericArray{new([]sql.NullString)} + a.Scan(x) + a = GenericArray{new([]sql.NullString)} + a.Scan(y) + } +} + +func TestGenericArrayScanScannerSliceString(t *testing.T) { + src, expected, nss := `{NULL,"\"",xyz}`, + []sql.NullString{{}, {String: `"`, Valid: true}, {String: `xyz`, Valid: true}}, + []sql.NullString{{String: ``, Valid: true}, {}, {}} + + if err := (GenericArray{&nss}).Scan(src); err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(nss, expected) { + t.Errorf("Expected %v, got %v", expected, nss) + } +} + +type TildeNullInt64 struct{ sql.NullInt64 } + +func (TildeNullInt64) ArrayDelimiter() string { return "~" } + +func TestGenericArrayScanDelimiter(t *testing.T) { + src, expected, tnis := `{12~NULL~76}`, + []TildeNullInt64{{sql.NullInt64{Int64: 12, Valid: true}}, {}, {sql.NullInt64{Int64: 76, Valid: true}}}, + []TildeNullInt64{{sql.NullInt64{Int64: 0, Valid: true}}, {}} + + if err := (GenericArray{&tnis}).Scan(src); err != nil { + t.Fatalf("Expected no error for %#v, got %v", src, err) + } + if !reflect.DeepEqual(tnis, expected) { + t.Errorf("Expected %v for %#v, got %v", expected, src, tnis) + } +} + +func TestGenericArrayScanErrors(t *testing.T) { + var sa [1]string + var nis []sql.NullInt64 + var pss *[]string + + for _, tt := range []struct { + src, dest interface{} + err string + }{ + {nil, pss, "destination *[]string is nil"}, + {`{`, &sa, "unable to parse"}, + {`{}`, &sa, "cannot convert ARRAY[0] to [1]string"}, + {`{x,x}`, &sa, "cannot convert ARRAY[2] to [1]string"}, + {`{x}`, &nis, `parsing array element index 0: converting`}, + } { + err := GenericArray{tt.dest}.Scan(tt.src) + + if err == nil { + t.Fatalf("Expected error for [%#v %#v]", tt.src, tt.dest) + } + if !strings.Contains(err.Error(), tt.err) { + t.Errorf("Expected error to contain %q for [%#v %#v], got %q", tt.err, tt.src, tt.dest, err) + } + } +} + +func TestGenericArrayValueUnsupported(t *testing.T) { + _, err := GenericArray{true}.Value() + + if err == nil { + t.Fatal("Expected error for bool") + } + if !strings.Contains(err.Error(), "bool to array") { + t.Errorf("Expected type to be mentioned, got %q", err) + } +} + +type ByteArrayValuer [1]byte +type ByteSliceValuer []byte +type FuncArrayValuer struct { + delimiter func() string + value func() (driver.Value, error) +} + +func (a ByteArrayValuer) Value() (driver.Value, error) { return a[:], nil } +func (b ByteSliceValuer) Value() (driver.Value, error) { return []byte(b), nil } +func (f FuncArrayValuer) ArrayDelimiter() string { return f.delimiter() } +func (f FuncArrayValuer) Value() (driver.Value, error) { return f.value() } + +func TestGenericArrayValue(t *testing.T) { + result, err := GenericArray{nil}.Value() + + if err != nil { + t.Fatalf("Expected no error for nil, got %v", err) + } + if result != nil { + t.Errorf("Expected nil, got %q", result) + } + + for _, tt := range []interface{}{ + []bool(nil), + [][]int(nil), + []*int(nil), + []sql.NullString(nil), + } { + result, err := GenericArray{tt}.Value() + + if err != nil { + t.Fatalf("Expected no error for %#v, got %v", tt, err) + } + if result != nil { + t.Errorf("Expected nil for %#v, got %q", tt, result) + } + } + + Tilde := func(v driver.Value) FuncArrayValuer { + return FuncArrayValuer{ + func() string { return "~" }, + func() (driver.Value, error) { return v, nil }} + } + + for _, tt := range []struct { + result string + input interface{} + }{ + {`{}`, []bool{}}, + {`{true}`, []bool{true}}, + {`{true,false}`, []bool{true, false}}, + {`{true,false}`, [2]bool{true, false}}, + + {`{}`, [][]int{{}}}, + {`{}`, [][]int{{}, {}}}, + {`{{1}}`, [][]int{{1}}}, + {`{{1},{2}}`, [][]int{{1}, {2}}}, + {`{{1,2},{3,4}}`, [][]int{{1, 2}, {3, 4}}}, + {`{{1,2},{3,4}}`, [2][2]int{{1, 2}, {3, 4}}}, + + {`{"a","\\b","c\"","d,e"}`, []string{`a`, `\b`, `c"`, `d,e`}}, + {`{"a","\\b","c\"","d,e"}`, [][]byte{{'a'}, {'\\', 'b'}, {'c', '"'}, {'d', ',', 'e'}}}, + + {`{NULL}`, []*int{nil}}, + {`{0,NULL}`, []*int{new(int), nil}}, + + {`{NULL}`, []sql.NullString{{}}}, + {`{"\"",NULL}`, []sql.NullString{{String: `"`, Valid: true}, {}}}, + + {`{"a","b"}`, []ByteArrayValuer{{'a'}, {'b'}}}, + {`{{"a","b"},{"c","d"}}`, [][]ByteArrayValuer{{{'a'}, {'b'}}, {{'c'}, {'d'}}}}, + + {`{"e","f"}`, []ByteSliceValuer{{'e'}, {'f'}}}, + {`{{"e","f"},{"g","h"}}`, [][]ByteSliceValuer{{{'e'}, {'f'}}, {{'g'}, {'h'}}}}, + + {`{1~2}`, []FuncArrayValuer{Tilde(int64(1)), Tilde(int64(2))}}, + {`{{1~2}~{3~4}}`, [][]FuncArrayValuer{{Tilde(int64(1)), Tilde(int64(2))}, {Tilde(int64(3)), Tilde(int64(4))}}}, + } { + result, err := GenericArray{tt.input}.Value() + + if err != nil { + t.Fatalf("Expected no error for %q, got %v", tt.input, err) + } + if !reflect.DeepEqual(result, tt.result) { + t.Errorf("Expected %q for %q, got %q", tt.result, tt.input, result) + } + } +} + +func TestGenericArrayValueErrors(t *testing.T) { + v := []interface{}{func() {}} + if _, err := (GenericArray{v}).Value(); err == nil { + t.Errorf("Expected error for %q, got nil", v) + } + + v = []interface{}{nil, func() {}} + if _, err := (GenericArray{v}).Value(); err == nil { + t.Errorf("Expected error for %q, got nil", v) + } +} + +func BenchmarkGenericArrayValueBools(b *testing.B) { + rand.Seed(1) + x := make([]bool, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Intn(2) == 0 + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func BenchmarkGenericArrayValueFloat64s(b *testing.B) { + rand.Seed(1) + x := make([]float64, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.NormFloat64() + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func BenchmarkGenericArrayValueInt64s(b *testing.B) { + rand.Seed(1) + x := make([]int64, 10) + for i := 0; i < len(x); i++ { + x[i] = rand.Int63() + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func BenchmarkGenericArrayValueByteSlices(b *testing.B) { + x := make([][]byte, 10) + for i := 0; i < len(x); i++ { + x[i] = bytes.Repeat([]byte(`abc"def\ghi`), 5) + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func BenchmarkGenericArrayValueStrings(b *testing.B) { + x := make([]string, 10) + for i := 0; i < len(x); i++ { + x[i] = strings.Repeat(`abc"def\ghi`, 5) + } + a := GenericArray{x} + + for i := 0; i < b.N; i++ { + a.Value() + } +} + +func TestArrayScanBackend(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + for _, tt := range []struct { + s string + d sql.Scanner + e interface{} + }{ + {`ARRAY[true, false]`, new(BoolArray), &BoolArray{true, false}}, + {`ARRAY[E'\\xdead', E'\\xbeef']`, new(ByteaArray), &ByteaArray{{'\xDE', '\xAD'}, {'\xBE', '\xEF'}}}, + {`ARRAY[1.2, 3.4]`, new(Float64Array), &Float64Array{1.2, 3.4}}, + {`ARRAY[1, 2, 3]`, new(Int64Array), &Int64Array{1, 2, 3}}, + {`ARRAY['a', E'\\b', 'c"', 'd,e']`, new(StringArray), &StringArray{`a`, `\b`, `c"`, `d,e`}}, + } { + err := db.QueryRow(`SELECT ` + tt.s).Scan(tt.d) + if err != nil { + t.Errorf("Expected no error when scanning %s into %T, got %v", tt.s, tt.d, err) + } + if !reflect.DeepEqual(tt.d, tt.e) { + t.Errorf("Expected %v when scanning %s into %T, got %v", tt.e, tt.s, tt.d, tt.d) + } + } +} + +func TestArrayValueBackend(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + for _, tt := range []struct { + s string + v driver.Valuer + }{ + {`ARRAY[true, false]`, BoolArray{true, false}}, + {`ARRAY[E'\\xdead', E'\\xbeef']`, ByteaArray{{'\xDE', '\xAD'}, {'\xBE', '\xEF'}}}, + {`ARRAY[1.2, 3.4]`, Float64Array{1.2, 3.4}}, + {`ARRAY[1, 2, 3]`, Int64Array{1, 2, 3}}, + {`ARRAY['a', E'\\b', 'c"', 'd,e']`, StringArray{`a`, `\b`, `c"`, `d,e`}}, + } { + var x int + err := db.QueryRow(`SELECT 1 WHERE `+tt.s+` <> $1`, tt.v).Scan(&x) + if err != sql.ErrNoRows { + t.Errorf("Expected %v to equal %s, got %v", tt.v, tt.s, err) + } + } +} diff --git a/vendor/github.com/lib/pq/bench_test.go b/vendor/github.com/lib/pq/bench_test.go new file mode 100644 index 00000000..b3754980 --- /dev/null +++ b/vendor/github.com/lib/pq/bench_test.go @@ -0,0 +1,434 @@ +package pq + +import ( + "bufio" + "bytes" + "context" + "database/sql" + "database/sql/driver" + "io" + "math/rand" + "net" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/lib/pq/oid" +) + +var ( + selectStringQuery = "SELECT '" + strings.Repeat("0123456789", 10) + "'" + selectSeriesQuery = "SELECT generate_series(1, 100)" +) + +func BenchmarkSelectString(b *testing.B) { + var result string + benchQuery(b, selectStringQuery, &result) +} + +func BenchmarkSelectSeries(b *testing.B) { + var result int + benchQuery(b, selectSeriesQuery, &result) +} + +func benchQuery(b *testing.B, query string, result interface{}) { + b.StopTimer() + db := openTestConn(b) + defer db.Close() + b.StartTimer() + + for i := 0; i < b.N; i++ { + benchQueryLoop(b, db, query, result) + } +} + +func benchQueryLoop(b *testing.B, db *sql.DB, query string, result interface{}) { + rows, err := db.Query(query) + if err != nil { + b.Fatal(err) + } + defer rows.Close() + for rows.Next() { + err = rows.Scan(result) + if err != nil { + b.Fatal("failed to scan", err) + } + } +} + +// reading from circularConn yields content[:prefixLen] once, followed by +// content[prefixLen:] over and over again. It never returns EOF. +type circularConn struct { + content string + prefixLen int + pos int + net.Conn // for all other net.Conn methods that will never be called +} + +func (r *circularConn) Read(b []byte) (n int, err error) { + n = copy(b, r.content[r.pos:]) + r.pos += n + if r.pos >= len(r.content) { + r.pos = r.prefixLen + } + return +} + +func (r *circularConn) Write(b []byte) (n int, err error) { return len(b), nil } + +func (r *circularConn) Close() error { return nil } + +func fakeConn(content string, prefixLen int) *conn { + c := &circularConn{content: content, prefixLen: prefixLen} + return &conn{buf: bufio.NewReader(c), c: c} +} + +// This benchmark is meant to be the same as BenchmarkSelectString, but takes +// out some of the factors this package can't control. The numbers are less noisy, +// but also the costs of network communication aren't accurately represented. +func BenchmarkMockSelectString(b *testing.B) { + b.StopTimer() + // taken from a recorded run of BenchmarkSelectString + // See: http://www.postgresql.org/docs/current/static/protocol-message-formats.html + const response = "1\x00\x00\x00\x04" + + "t\x00\x00\x00\x06\x00\x00" + + "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + + "Z\x00\x00\x00\x05I" + + "2\x00\x00\x00\x04" + + "D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + + "C\x00\x00\x00\rSELECT 1\x00" + + "Z\x00\x00\x00\x05I" + + "3\x00\x00\x00\x04" + + "Z\x00\x00\x00\x05I" + c := fakeConn(response, 0) + b.StartTimer() + + for i := 0; i < b.N; i++ { + benchMockQuery(b, c, selectStringQuery) + } +} + +var seriesRowData = func() string { + var buf bytes.Buffer + for i := 1; i <= 100; i++ { + digits := byte(2) + if i >= 100 { + digits = 3 + } else if i < 10 { + digits = 1 + } + buf.WriteString("D\x00\x00\x00") + buf.WriteByte(10 + digits) + buf.WriteString("\x00\x01\x00\x00\x00") + buf.WriteByte(digits) + buf.WriteString(strconv.Itoa(i)) + } + return buf.String() +}() + +func BenchmarkMockSelectSeries(b *testing.B) { + b.StopTimer() + var response = "1\x00\x00\x00\x04" + + "t\x00\x00\x00\x06\x00\x00" + + "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + + "Z\x00\x00\x00\x05I" + + "2\x00\x00\x00\x04" + + seriesRowData + + "C\x00\x00\x00\x0fSELECT 100\x00" + + "Z\x00\x00\x00\x05I" + + "3\x00\x00\x00\x04" + + "Z\x00\x00\x00\x05I" + c := fakeConn(response, 0) + b.StartTimer() + + for i := 0; i < b.N; i++ { + benchMockQuery(b, c, selectSeriesQuery) + } +} + +func benchMockQuery(b *testing.B, c *conn, query string) { + stmt, err := c.Prepare(query) + if err != nil { + b.Fatal(err) + } + defer stmt.Close() + rows, err := stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil) + if err != nil { + b.Fatal(err) + } + defer rows.Close() + var dest [1]driver.Value + for { + if err := rows.Next(dest[:]); err != nil { + if err == io.EOF { + break + } + b.Fatal(err) + } + } +} + +func BenchmarkPreparedSelectString(b *testing.B) { + var result string + benchPreparedQuery(b, selectStringQuery, &result) +} + +func BenchmarkPreparedSelectSeries(b *testing.B) { + var result int + benchPreparedQuery(b, selectSeriesQuery, &result) +} + +func benchPreparedQuery(b *testing.B, query string, result interface{}) { + b.StopTimer() + db := openTestConn(b) + defer db.Close() + stmt, err := db.Prepare(query) + if err != nil { + b.Fatal(err) + } + defer stmt.Close() + b.StartTimer() + + for i := 0; i < b.N; i++ { + benchPreparedQueryLoop(b, db, stmt, result) + } +} + +func benchPreparedQueryLoop(b *testing.B, db *sql.DB, stmt *sql.Stmt, result interface{}) { + rows, err := stmt.Query() + if err != nil { + b.Fatal(err) + } + if !rows.Next() { + rows.Close() + b.Fatal("no rows") + } + defer rows.Close() + for rows.Next() { + err = rows.Scan(&result) + if err != nil { + b.Fatal("failed to scan") + } + } +} + +// See the comment for BenchmarkMockSelectString. +func BenchmarkMockPreparedSelectString(b *testing.B) { + b.StopTimer() + const parseResponse = "1\x00\x00\x00\x04" + + "t\x00\x00\x00\x06\x00\x00" + + "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + + "Z\x00\x00\x00\x05I" + const responses = parseResponse + + "2\x00\x00\x00\x04" + + "D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" + + "C\x00\x00\x00\rSELECT 1\x00" + + "Z\x00\x00\x00\x05I" + c := fakeConn(responses, len(parseResponse)) + + stmt, err := c.Prepare(selectStringQuery) + if err != nil { + b.Fatal(err) + } + b.StartTimer() + + for i := 0; i < b.N; i++ { + benchPreparedMockQuery(b, c, stmt) + } +} + +func BenchmarkMockPreparedSelectSeries(b *testing.B) { + b.StopTimer() + const parseResponse = "1\x00\x00\x00\x04" + + "t\x00\x00\x00\x06\x00\x00" + + "T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" + + "Z\x00\x00\x00\x05I" + var responses = parseResponse + + "2\x00\x00\x00\x04" + + seriesRowData + + "C\x00\x00\x00\x0fSELECT 100\x00" + + "Z\x00\x00\x00\x05I" + c := fakeConn(responses, len(parseResponse)) + + stmt, err := c.Prepare(selectSeriesQuery) + if err != nil { + b.Fatal(err) + } + b.StartTimer() + + for i := 0; i < b.N; i++ { + benchPreparedMockQuery(b, c, stmt) + } +} + +func benchPreparedMockQuery(b *testing.B, c *conn, stmt driver.Stmt) { + rows, err := stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil) + if err != nil { + b.Fatal(err) + } + defer rows.Close() + var dest [1]driver.Value + for { + if err := rows.Next(dest[:]); err != nil { + if err == io.EOF { + break + } + b.Fatal(err) + } + } +} + +func BenchmarkEncodeInt64(b *testing.B) { + for i := 0; i < b.N; i++ { + encode(¶meterStatus{}, int64(1234), oid.T_int8) + } +} + +func BenchmarkEncodeFloat64(b *testing.B) { + for i := 0; i < b.N; i++ { + encode(¶meterStatus{}, 3.14159, oid.T_float8) + } +} + +var testByteString = []byte("abcdefghijklmnopqrstuvwxyz") + +func BenchmarkEncodeByteaHex(b *testing.B) { + for i := 0; i < b.N; i++ { + encode(¶meterStatus{serverVersion: 90000}, testByteString, oid.T_bytea) + } +} +func BenchmarkEncodeByteaEscape(b *testing.B) { + for i := 0; i < b.N; i++ { + encode(¶meterStatus{serverVersion: 84000}, testByteString, oid.T_bytea) + } +} + +func BenchmarkEncodeBool(b *testing.B) { + for i := 0; i < b.N; i++ { + encode(¶meterStatus{}, true, oid.T_bool) + } +} + +var testTimestamptz = time.Date(2001, time.January, 1, 0, 0, 0, 0, time.Local) + +func BenchmarkEncodeTimestamptz(b *testing.B) { + for i := 0; i < b.N; i++ { + encode(¶meterStatus{}, testTimestamptz, oid.T_timestamptz) + } +} + +var testIntBytes = []byte("1234") + +func BenchmarkDecodeInt64(b *testing.B) { + for i := 0; i < b.N; i++ { + decode(¶meterStatus{}, testIntBytes, oid.T_int8, formatText) + } +} + +var testFloatBytes = []byte("3.14159") + +func BenchmarkDecodeFloat64(b *testing.B) { + for i := 0; i < b.N; i++ { + decode(¶meterStatus{}, testFloatBytes, oid.T_float8, formatText) + } +} + +var testBoolBytes = []byte{'t'} + +func BenchmarkDecodeBool(b *testing.B) { + for i := 0; i < b.N; i++ { + decode(¶meterStatus{}, testBoolBytes, oid.T_bool, formatText) + } +} + +func TestDecodeBool(t *testing.T) { + db := openTestConn(t) + rows, err := db.Query("select true") + if err != nil { + t.Fatal(err) + } + rows.Close() +} + +var testTimestamptzBytes = []byte("2013-09-17 22:15:32.360754-07") + +func BenchmarkDecodeTimestamptz(b *testing.B) { + for i := 0; i < b.N; i++ { + decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText) + } +} + +func BenchmarkDecodeTimestamptzMultiThread(b *testing.B) { + oldProcs := runtime.GOMAXPROCS(0) + defer runtime.GOMAXPROCS(oldProcs) + runtime.GOMAXPROCS(runtime.NumCPU()) + globalLocationCache = newLocationCache() + + f := func(wg *sync.WaitGroup, loops int) { + defer wg.Done() + for i := 0; i < loops; i++ { + decode(¶meterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText) + } + } + + wg := &sync.WaitGroup{} + b.ResetTimer() + for j := 0; j < 10; j++ { + wg.Add(1) + go f(wg, b.N/10) + } + wg.Wait() +} + +func BenchmarkLocationCache(b *testing.B) { + globalLocationCache = newLocationCache() + for i := 0; i < b.N; i++ { + globalLocationCache.getLocation(rand.Intn(10000)) + } +} + +func BenchmarkLocationCacheMultiThread(b *testing.B) { + oldProcs := runtime.GOMAXPROCS(0) + defer runtime.GOMAXPROCS(oldProcs) + runtime.GOMAXPROCS(runtime.NumCPU()) + globalLocationCache = newLocationCache() + + f := func(wg *sync.WaitGroup, loops int) { + defer wg.Done() + for i := 0; i < loops; i++ { + globalLocationCache.getLocation(rand.Intn(10000)) + } + } + + wg := &sync.WaitGroup{} + b.ResetTimer() + for j := 0; j < 10; j++ { + wg.Add(1) + go f(wg, b.N/10) + } + wg.Wait() +} + +// Stress test the performance of parsing results from the wire. +func BenchmarkResultParsing(b *testing.B) { + b.StopTimer() + + db := openTestConn(b) + defer db.Close() + _, err := db.Exec("BEGIN") + if err != nil { + b.Fatal(err) + } + + b.StartTimer() + for i := 0; i < b.N; i++ { + res, err := db.Query("SELECT generate_series(1, 50000)") + if err != nil { + b.Fatal(err) + } + res.Close() + } +} diff --git a/vendor/github.com/lib/pq/certs/README b/vendor/github.com/lib/pq/certs/README new file mode 100644 index 00000000..24ab7b25 --- /dev/null +++ b/vendor/github.com/lib/pq/certs/README @@ -0,0 +1,3 @@ +This directory contains certificates and private keys for testing some +SSL-related functionality in Travis. Do NOT use these certificates for +anything other than testing. diff --git a/vendor/github.com/lib/pq/certs/postgresql.key b/vendor/github.com/lib/pq/certs/postgresql.key new file mode 100644 index 00000000..eb8b20be --- /dev/null +++ b/vendor/github.com/lib/pq/certs/postgresql.key @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICWwIBAAKBgQDjjAaacFRR0TQ0gznNolkPBe2N2A400JL0CU3ujHhVSST4POA0 +WAKy55RYwejlu9Gv9lTBQLGQcHkNNVScjxbpwvCS5mRJOMF2+EdmxFtKtqlDzsi+ +bE0rlJc8VbzR0G63U66JXEtrhkC+wa4eZM6crocKaeXIIRK+rh32Rd8WpwIDAQAB +AoGAM5dM6/kp9P700i8qjOgRPym96Zoh5nGfz/rIE5z/r36NBkdvIg8OVZfR96nH +b0b9TOMR5lsPp0sI9yivTWvX6qyvLJRWy2vvx17hXK9NxXUNTAm0PYZUTvCtcPeX +RnJpzQKNZQPkFzF0uXBc4CtPK2Vz0+FGvAelrhYAxnw1dIkCQQD+9qaW5QhXjsjb +Nl85CmXgxPmGROcgLQCO+omfrjf9UXrituU9Dz6auym5lDGEdMFnkzfr+wpasEy9 +mf5ZZOhDAkEA5HjXfVGaCtpydOt6hDon/uZsyssCK2lQ7NSuE3vP+sUsYMzIpEoy +t3VWXqKbo+g9KNDTP4WEliqp1aiSIylzzQJANPeqzihQnlgEdD4MdD4rwhFJwVIp +Le8Lcais1KaN7StzOwxB/XhgSibd2TbnPpw+3bSg5n5lvUdo+e62/31OHwJAU1jS +I+F09KikQIr28u3UUWT2IzTT4cpVv1AHAQyV3sG3YsjSGT0IK20eyP9BEBZU2WL0 +7aNjrvR5aHxKc5FXsQJABsFtyGpgI5X4xufkJZVZ+Mklz2n7iXa+XPatMAHFxAtb +EEMt60rngwMjXAzBSC6OYuYogRRAY3UCacNC5VhLYQ== +-----END RSA PRIVATE KEY----- diff --git a/vendor/github.com/lib/pq/certs/server.key b/vendor/github.com/lib/pq/certs/server.key new file mode 100644 index 00000000..bd7b019b --- /dev/null +++ b/vendor/github.com/lib/pq/certs/server.key @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEA14pMhfsXpTyP4HIRKc4/sB8/fcbuf6f8Ais1RwimPZDfXFYU +lADHbdHS4mGVd7jjpmYx+R8hfWLhJ9qUN2FK6mNToGG4nLul4ue3ptgPBQTHKeLq +SSt/3hUAphhwUMcM3pr5Wpaw4ZQGxm1KITu0D6VtkoY0sk7XDqcZwHcLe4fIkt5C +/4bSt5qk1BUjyq2laSG4zn5my4Vdue2LLQmNlOQEHnLs79B2kBVapPeRS+nOTp1d +mnAXnNjpc4PqPWGZps2skUBaiHflTiqOPRPz+ThvgWuKlcoOB6tv2rSM2f+qeAOq +x8LPb2SS09iD1a/xIxinLnsXC+d98fqoQaMEVwIDAQABAoIBAF3ZoihUhJ82F4+r +Gz4QyDpv4L1reT2sb1aiabhcU8ZK5nbWJG+tRyjSS/i2dNaEcttpdCj9HR/zhgZM +bm0OuAgG58rVwgS80CZUruq++Qs+YVojq8/gWPTiQD4SNhV2Fmx3HkwLgUk3oxuT +SsvdqzGE3okGVrutCIcgy126eA147VPMoej1Bb3fO6npqK0pFPhZfAc0YoqJuM+k +obRm5pAnGUipyLCFXjA9HYPKwYZw2RtfdA3CiImHeanSdqS+ctrC9y8BV40Th7gZ +haXdKUNdjmIxV695QQ1mkGqpKLZFqhzKioGQ2/Ly2d1iaKN9fZltTusu8unepWJ2 +tlT9qMECgYEA9uHaF1t2CqE+AJvWTihHhPIIuLxoOQXYea1qvxfcH/UMtaLKzCNm +lQ5pqCGsPvp+10f36yttO1ZehIvlVNXuJsjt0zJmPtIolNuJY76yeussfQ9jHheB +5uPEzCFlHzxYbBUyqgWaF6W74okRGzEGJXjYSP0yHPPdU4ep2q3bGiUCgYEA34Af +wBSuQSK7uLxArWHvQhyuvi43ZGXls6oRGl+Ysj54s8BP6XGkq9hEJ6G4yxgyV+BR +DUOs5X8/TLT8POuIMYvKTQthQyCk0eLv2FLdESDuuKx0kBVY3s8lK3/z5HhrdOiN +VMNZU+xDKgKc3hN9ypkk8vcZe6EtH7Y14e0rVcsCgYBTgxi8F/M5K0wG9rAqphNz +VFBA9XKn/2M33cKjO5X5tXIEKzpAjaUQvNxexG04rJGljzG8+mar0M6ONahw5yD1 +O7i/XWgazgpuOEkkVYiYbd8RutfDgR4vFVMn3hAP3eDnRtBplRWH9Ec3HTiNIys6 +F8PKBOQjyRZQQC7jyzW3hQKBgACe5HeuFwXLSOYsb6mLmhR+6+VPT4wR1F95W27N +USk9jyxAnngxfpmTkiziABdgS9N+pfr5cyN4BP77ia/Jn6kzkC5Cl9SN5KdIkA3z +vPVtN/x/ThuQU5zaymmig1ThGLtMYggYOslG4LDfLPxY5YKIhle+Y+259twdr2yf +Mf2dAoGAaGv3tWMgnIdGRk6EQL/yb9PKHo7ShN+tKNlGaK7WwzBdKs+Fe8jkgcr7 +pz4Ne887CmxejdISzOCcdT+Zm9Bx6I/uZwWOtDvWpIgIxVX9a9URj/+D1MxTE/y4 +d6H+c89yDY62I2+drMpdjCd3EtCaTlxpTbRS+s1eAHMH7aEkcCE= +-----END RSA PRIVATE KEY----- diff --git a/vendor/github.com/lib/pq/conn.go b/vendor/github.com/lib/pq/conn.go index 43c8df29..012c8c7c 100644 --- a/vendor/github.com/lib/pq/conn.go +++ b/vendor/github.com/lib/pq/conn.go @@ -2,7 +2,9 @@ package pq import ( "bufio" + "context" "crypto/md5" + "crypto/sha256" "database/sql" "database/sql/driver" "encoding/binary" @@ -20,6 +22,7 @@ import ( "unicode" "github.com/lib/pq/oid" + "github.com/lib/pq/scram" ) // Common error types @@ -89,13 +92,24 @@ type Dialer interface { DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) } -type defaultDialer struct{} - -func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) { - return net.Dial(ntw, addr) +type DialerContext interface { + DialContext(ctx context.Context, network, address string) (net.Conn, error) } -func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout(ntw, addr, timeout) + +type defaultDialer struct { + d net.Dialer +} + +func (d defaultDialer) Dial(network, address string) (net.Conn, error) { + return d.d.Dial(network, address) +} +func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + return d.DialContext(ctx, network, address) +} +func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) { + return d.d.DialContext(ctx, network, address) } type conn struct { @@ -244,90 +258,35 @@ func (cn *conn) writeBuf(b byte) *writeBuf { } } -// Open opens a new connection to the database. name is a connection string. +// Open opens a new connection to the database. dsn is a connection string. // Most users should only use it through database/sql package from the standard // library. -func Open(name string) (_ driver.Conn, err error) { - return DialOpen(defaultDialer{}, name) +func Open(dsn string) (_ driver.Conn, err error) { + return DialOpen(defaultDialer{}, dsn) } // DialOpen opens a new connection to the database using a dialer. -func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { +func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) { + c, err := NewConnector(dsn) + if err != nil { + return nil, err + } + c.dialer = d + return c.open(context.Background()) +} + +func (c *Connector) open(ctx context.Context) (cn *conn, err error) { // Handle any panics during connection initialization. Note that we // specifically do *not* want to use errRecover(), as that would turn any // connection errors into ErrBadConns, hiding the real error message from // the user. defer errRecoverNoErrBadConn(&err) - o := make(values) + o := c.opts - // A number of defaults are applied here, in this order: - // - // * Very low precedence defaults applied in every situation - // * Environment variables - // * Explicitly passed connection information - o["host"] = "localhost" - o["port"] = "5432" - // N.B.: Extra float digits should be set to 3, but that breaks - // Postgres 8.4 and older, where the max is 2. - o["extra_float_digits"] = "2" - for k, v := range parseEnviron(os.Environ()) { - o[k] = v - } - - if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") { - name, err = ParseURL(name) - if err != nil { - return nil, err - } - } - - if err := parseOpts(name, o); err != nil { - return nil, err - } - - // Use the "fallback" application name if necessary - if fallback, ok := o["fallback_application_name"]; ok { - if _, ok := o["application_name"]; !ok { - o["application_name"] = fallback - } - } - - // We can't work with any client_encoding other than UTF-8 currently. - // However, we have historically allowed the user to set it to UTF-8 - // explicitly, and there's no reason to break such programs, so allow that. - // Note that the "options" setting could also set client_encoding, but - // parsing its value is not worth it. Instead, we always explicitly send - // client_encoding as a separate run-time parameter, which should override - // anything set in options. - if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { - return nil, errors.New("client_encoding must be absent or 'UTF8'") - } - o["client_encoding"] = "UTF8" - // DateStyle needs a similar treatment. - if datestyle, ok := o["datestyle"]; ok { - if datestyle != "ISO, MDY" { - panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v", - "ISO, MDY", datestyle)) - } - } else { - o["datestyle"] = "ISO, MDY" - } - - // If a user is not provided by any other means, the last - // resort is to use the current operating system provided user - // name. - if _, ok := o["user"]; !ok { - u, err := userCurrent() - if err != nil { - return nil, err - } - o["user"] = u - } - - cn := &conn{ + cn = &conn{ opts: o, - dialer: d, + dialer: c.dialer, } err = cn.handleDriverSettings(o) if err != nil { @@ -335,13 +294,16 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { } cn.handlePgpass(o) - cn.c, err = dial(d, o) + cn.c, err = dial(ctx, c.dialer, o) if err != nil { return nil, err } err = cn.ssl(o) if err != nil { + if cn.c != nil { + cn.c.Close() + } return nil, err } @@ -364,10 +326,10 @@ func DialOpen(d Dialer, name string) (_ driver.Conn, err error) { return cn, err } -func dial(d Dialer, o values) (net.Conn, error) { - ntw, addr := network(o) +func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) { + network, address := network(o) // SSL is not necessary or supported over UNIX domain sockets - if ntw == "unix" { + if network == "unix" { o["sslmode"] = "disable" } @@ -378,19 +340,30 @@ func dial(d Dialer, o values) (net.Conn, error) { return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err) } duration := time.Duration(seconds) * time.Second + // connect_timeout should apply to the entire connection establishment // procedure, so we both use a timeout for the TCP connection // establishment and set a deadline for doing the initial handshake. // The deadline is then reset after startup() is done. deadline := time.Now().Add(duration) - conn, err := d.DialTimeout(ntw, addr, duration) + var conn net.Conn + if dctx, ok := d.(DialerContext); ok { + ctx, cancel := context.WithTimeout(ctx, duration) + defer cancel() + conn, err = dctx.DialContext(ctx, network, address) + } else { + conn, err = d.DialTimeout(network, address, duration) + } if err != nil { return nil, err } err = conn.SetDeadline(deadline) return conn, err } - return d.Dial(ntw, addr) + if dctx, ok := d.(DialerContext); ok { + return dctx.DialContext(ctx, network, address) + } + return d.Dial(network, address) } func network(o values) (string, string) { @@ -704,7 +677,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) { // res might be non-nil here if we received a previous // CommandComplete, but that's fine; just overwrite it res = &rows{cn: cn} - res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r) + res.rowsHeader = parsePortalRowDescribe(r) // To work around a bug in QueryRow in Go 1.2 and earlier, wait // until the first DataRow has been received. @@ -861,17 +834,15 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) { cn.readParseResponse() cn.readBindResponse() rows := &rows{cn: cn} - rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse() + rows.rowsHeader = cn.readPortalDescribeResponse() cn.postExecuteWorkaround() return rows, nil } st := cn.prepareTo(query, "") st.exec(args) return &rows{ - cn: cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, + cn: cn, + rowsHeader: st.rowsHeader, }, nil } @@ -992,7 +963,6 @@ func (cn *conn) recv() (t byte, r *readBuf) { if err != nil { panic(err) } - switch t { case 'E': panic(parseError(r)) @@ -1163,6 +1133,55 @@ func (cn *conn) auth(r *readBuf, o values) { if r.int32() != 0 { errorf("unexpected authentication response: %q", t) } + case 10: + sc := scram.NewClient(sha256.New, o["user"], o["password"]) + sc.Step(nil) + if sc.Err() != nil { + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + } + scOut := sc.Out() + + w := cn.writeBuf('p') + w.string("SCRAM-SHA-256") + w.int32(len(scOut)) + w.bytes(scOut) + cn.send(w) + + t, r := cn.recv() + if t != 'R' { + errorf("unexpected password response: %q", t) + } + + if r.int32() != 11 { + errorf("unexpected authentication response: %q", t) + } + + nextStep := r.next(len(*r)) + sc.Step(nextStep) + if sc.Err() != nil { + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + } + + scOut = sc.Out() + w = cn.writeBuf('p') + w.bytes(scOut) + cn.send(w) + + t, r = cn.recv() + if t != 'R' { + errorf("unexpected password response: %q", t) + } + + if r.int32() != 12 { + errorf("unexpected authentication response: %q", t) + } + + nextStep = r.next(len(*r)) + sc.Step(nextStep) + if sc.Err() != nil { + errorf("SCRAM-SHA-256 error: %s", sc.Err().Error()) + } + default: errorf("unknown authentication response: %d", code) } @@ -1180,12 +1199,10 @@ var colFmtDataAllBinary = []byte{0, 1, 0, 1} var colFmtDataAllText = []byte{0, 0} type stmt struct { - cn *conn - name string - colNames []string - colFmts []format + cn *conn + name string + rowsHeader colFmtData []byte - colTyps []fieldDesc paramTyps []oid.Oid closed bool } @@ -1231,10 +1248,8 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) { st.exec(v) return &rows{ - cn: st.cn, - colNames: st.colNames, - colTyps: st.colTyps, - colFmts: st.colFmts, + cn: st.cn, + rowsHeader: st.rowsHeader, }, nil } @@ -1344,16 +1359,22 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) { return driver.RowsAffected(n), commandTag } -type rows struct { - cn *conn - finish func() +type rowsHeader struct { colNames []string colTyps []fieldDesc colFmts []format - done bool - rb readBuf - result driver.Result - tag string +} + +type rows struct { + cn *conn + finish func() + rowsHeader + done bool + rb readBuf + result driver.Result + tag string + + next *rowsHeader } func (rs *rows) Close() error { @@ -1440,7 +1461,8 @@ func (rs *rows) Next(dest []driver.Value) (err error) { } return case 'T': - rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb) + next := parsePortalRowDescribe(&rs.rb) + rs.next = &next return io.EOF default: errorf("unexpected message after execute: %q", t) @@ -1449,10 +1471,16 @@ func (rs *rows) Next(dest []driver.Value) (err error) { } func (rs *rows) HasNextResultSet() bool { - return !rs.done + hasNext := rs.next != nil && !rs.done + return hasNext } func (rs *rows) NextResultSet() error { + if rs.next == nil { + return io.EOF + } + rs.rowsHeader = *rs.next + rs.next = nil return nil } @@ -1475,6 +1503,39 @@ func QuoteIdentifier(name string) string { return `"` + strings.Replace(name, `"`, `""`, -1) + `"` } +// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal +// to DDL and other statements that do not accept parameters) to be used as part +// of an SQL statement. For example: +// +// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z") +// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date)) +// +// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be +// replaced by two backslashes (i.e. "\\") and the C-style escape identifier +// that PostgreSQL provides ('E') will be prepended to the string. +func QuoteLiteral(literal string) string { + // This follows the PostgreSQL internal algorithm for handling quoted literals + // from libpq, which can be found in the "PQEscapeStringInternal" function, + // which is found in the libpq/fe-exec.c source file: + // https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c + // + // substitute any single-quotes (') with two single-quotes ('') + literal = strings.Replace(literal, `'`, `''`, -1) + // determine if the string has any backslashes (\) in it. + // if it does, replace any backslashes (\) with two backslashes (\\) + // then, we need to wrap the entire string with a PostgreSQL + // C-style escape. Per how "PQEscapeStringInternal" handles this case, we + // also add a space before the "E" + if strings.Contains(literal, `\`) { + literal = strings.Replace(literal, `\`, `\\`, -1) + literal = ` E'` + literal + `'` + } else { + // otherwise, we can just wrap the literal with a pair of single quotes + literal = `'` + literal + `'` + } + return literal +} + func md5s(s string) string { h := md5.New() h.Write([]byte(s)) @@ -1630,13 +1691,13 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [ } } -func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) { +func (cn *conn) readPortalDescribeResponse() rowsHeader { t, r := cn.recv1() switch t { case 'T': return parsePortalRowDescribe(r) case 'n': - return nil, nil, nil + return rowsHeader{} case 'E': err := parseError(r) cn.readReadyForQuery() @@ -1742,11 +1803,11 @@ func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDe return } -func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) { +func parsePortalRowDescribe(r *readBuf) rowsHeader { n := r.int16() - colNames = make([]string, n) - colFmts = make([]format, n) - colTyps = make([]fieldDesc, n) + colNames := make([]string, n) + colFmts := make([]format, n) + colTyps := make([]fieldDesc, n) for i := range colNames { colNames[i] = r.string() r.next(6) @@ -1755,7 +1816,11 @@ func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, co colTyps[i].Mod = r.int32() colFmts[i] = format(r.int16()) } - return + return rowsHeader{ + colNames: colNames, + colFmts: colFmts, + colTyps: colTyps, + } } // parseEnviron tries to mimic some of libpq's environment handling diff --git a/vendor/github.com/lib/pq/conn_go18.go b/vendor/github.com/lib/pq/conn_go18.go index a5254f2b..0fdd06a6 100644 --- a/vendor/github.com/lib/pq/conn_go18.go +++ b/vendor/github.com/lib/pq/conn_go18.go @@ -1,5 +1,3 @@ -// +build go1.8 - package pq import ( @@ -9,6 +7,7 @@ import ( "fmt" "io" "io/ioutil" + "time" ) // Implement the "QueryerContext" interface @@ -76,13 +75,32 @@ func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, return tx, nil } +func (cn *conn) Ping(ctx context.Context) error { + if finish := cn.watchCancel(ctx); finish != nil { + defer finish() + } + rows, err := cn.simpleQuery("SELECT 'lib/pq ping test';") + if err != nil { + return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger + } + rows.Close() + return nil +} + func (cn *conn) watchCancel(ctx context.Context) func() { if done := ctx.Done(); done != nil { finished := make(chan struct{}) go func() { select { case <-done: - _ = cn.cancel() + // At this point the function level context is canceled, + // so it must not be used for the additional network + // request to cancel the query. + // Create a new context to pass into the dial. + ctxCancel, cancel := context.WithTimeout(context.Background(), time.Second*10) + defer cancel() + + _ = cn.cancel(ctxCancel) finished <- struct{}{} case <-finished: } @@ -97,8 +115,8 @@ func (cn *conn) watchCancel(ctx context.Context) func() { return nil } -func (cn *conn) cancel() error { - c, err := dial(cn.dialer, cn.opts) +func (cn *conn) cancel(ctx context.Context) error { + c, err := dial(ctx, cn.dialer, cn.opts) if err != nil { return err } diff --git a/vendor/github.com/lib/pq/conn_test.go b/vendor/github.com/lib/pq/conn_test.go new file mode 100644 index 00000000..f44bbdea --- /dev/null +++ b/vendor/github.com/lib/pq/conn_test.go @@ -0,0 +1,1741 @@ +package pq + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "io" + "net" + "os" + "reflect" + "strings" + "testing" + "time" +) + +type Fatalistic interface { + Fatal(args ...interface{}) +} + +func forceBinaryParameters() bool { + bp := os.Getenv("PQTEST_BINARY_PARAMETERS") + if bp == "yes" { + return true + } else if bp == "" || bp == "no" { + return false + } else { + panic("unexpected value for PQTEST_BINARY_PARAMETERS") + } +} + +func testConninfo(conninfo string) string { + defaultTo := func(envvar string, value string) { + if os.Getenv(envvar) == "" { + os.Setenv(envvar, value) + } + } + defaultTo("PGDATABASE", "pqgotest") + defaultTo("PGSSLMODE", "disable") + defaultTo("PGCONNECT_TIMEOUT", "20") + + if forceBinaryParameters() && + !strings.HasPrefix(conninfo, "postgres://") && + !strings.HasPrefix(conninfo, "postgresql://") { + conninfo = conninfo + " binary_parameters=yes" + } + return conninfo +} + +func openTestConnConninfo(conninfo string) (*sql.DB, error) { + return sql.Open("postgres", testConninfo(conninfo)) +} + +func openTestConn(t Fatalistic) *sql.DB { + conn, err := openTestConnConninfo("") + if err != nil { + t.Fatal(err) + } + + return conn +} + +func getServerVersion(t *testing.T, db *sql.DB) int { + var version int + err := db.QueryRow("SHOW server_version_num").Scan(&version) + if err != nil { + t.Fatal(err) + } + return version +} + +func TestReconnect(t *testing.T) { + db1 := openTestConn(t) + defer db1.Close() + tx, err := db1.Begin() + if err != nil { + t.Fatal(err) + } + var pid1 int + err = tx.QueryRow("SELECT pg_backend_pid()").Scan(&pid1) + if err != nil { + t.Fatal(err) + } + db2 := openTestConn(t) + defer db2.Close() + _, err = db2.Exec("SELECT pg_terminate_backend($1)", pid1) + if err != nil { + t.Fatal(err) + } + // The rollback will probably "fail" because we just killed + // its connection above + _ = tx.Rollback() + + const expected int = 42 + var result int + err = db1.QueryRow(fmt.Sprintf("SELECT %d", expected)).Scan(&result) + if err != nil { + t.Fatal(err) + } + if result != expected { + t.Errorf("got %v; expected %v", result, expected) + } +} + +func TestCommitInFailedTransaction(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + rows, err := txn.Query("SELECT error") + if err == nil { + rows.Close() + t.Fatal("expected failure") + } + err = txn.Commit() + if err != ErrInFailedTransaction { + t.Fatalf("expected ErrInFailedTransaction; got %#v", err) + } +} + +func TestOpenURL(t *testing.T) { + testURL := func(url string) { + db, err := openTestConnConninfo(url) + if err != nil { + t.Fatal(err) + } + defer db.Close() + // database/sql might not call our Open at all unless we do something with + // the connection + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + txn.Rollback() + } + testURL("postgres://") + testURL("postgresql://") +} + +const pgpassFile = "/tmp/pqgotest_pgpass" + +func TestPgpass(t *testing.T) { + if os.Getenv("TRAVIS") != "true" { + t.Skip("not running under Travis, skipping pgpass tests") + } + + testAssert := func(conninfo string, expected string, reason string) { + conn, err := openTestConnConninfo(conninfo) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + txn, err := conn.Begin() + if err != nil { + if expected != "fail" { + t.Fatalf(reason, err) + } + return + } + rows, err := txn.Query("SELECT USER") + if err != nil { + txn.Rollback() + if expected != "fail" { + t.Fatalf(reason, err) + } + } else { + rows.Close() + if expected != "ok" { + t.Fatalf(reason, err) + } + } + txn.Rollback() + } + testAssert("", "ok", "missing .pgpass, unexpected error %#v") + os.Setenv("PGPASSFILE", pgpassFile) + testAssert("host=/tmp", "fail", ", unexpected error %#v") + os.Remove(pgpassFile) + pgpass, err := os.OpenFile(pgpassFile, os.O_RDWR|os.O_CREATE, 0644) + if err != nil { + t.Fatalf("Unexpected error writing pgpass file %#v", err) + } + _, err = pgpass.WriteString(`# comment +server:5432:some_db:some_user:pass_A +*:5432:some_db:some_user:pass_B +localhost:*:*:*:pass_C +*:*:*:*:pass_fallback +`) + if err != nil { + t.Fatalf("Unexpected error writing pgpass file %#v", err) + } + pgpass.Close() + + assertPassword := func(extra values, expected string) { + o := values{ + "host": "localhost", + "sslmode": "disable", + "connect_timeout": "20", + "user": "majid", + "port": "5432", + "extra_float_digits": "2", + "dbname": "pqgotest", + "client_encoding": "UTF8", + "datestyle": "ISO, MDY", + } + for k, v := range extra { + o[k] = v + } + (&conn{}).handlePgpass(o) + if pw := o["password"]; pw != expected { + t.Fatalf("For %v expected %s got %s", extra, expected, pw) + } + } + // wrong permissions for the pgpass file means it should be ignored + assertPassword(values{"host": "example.com", "user": "foo"}, "") + // fix the permissions and check if it has taken effect + os.Chmod(pgpassFile, 0600) + assertPassword(values{"host": "server", "dbname": "some_db", "user": "some_user"}, "pass_A") + assertPassword(values{"host": "example.com", "user": "foo"}, "pass_fallback") + assertPassword(values{"host": "example.com", "dbname": "some_db", "user": "some_user"}, "pass_B") + // localhost also matches the default "" and UNIX sockets + assertPassword(values{"host": "", "user": "some_user"}, "pass_C") + assertPassword(values{"host": "/tmp", "user": "some_user"}, "pass_C") + // cleanup + os.Remove(pgpassFile) + os.Setenv("PGPASSFILE", "") +} + +func TestExec(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + _, err := db.Exec("CREATE TEMP TABLE temp (a int)") + if err != nil { + t.Fatal(err) + } + + r, err := db.Exec("INSERT INTO temp VALUES (1)") + if err != nil { + t.Fatal(err) + } + + if n, _ := r.RowsAffected(); n != 1 { + t.Fatalf("expected 1 row affected, not %d", n) + } + + r, err = db.Exec("INSERT INTO temp VALUES ($1), ($2), ($3)", 1, 2, 3) + if err != nil { + t.Fatal(err) + } + + if n, _ := r.RowsAffected(); n != 3 { + t.Fatalf("expected 3 rows affected, not %d", n) + } + + // SELECT doesn't send the number of returned rows in the command tag + // before 9.0 + if getServerVersion(t, db) >= 90000 { + r, err = db.Exec("SELECT g FROM generate_series(1, 2) g") + if err != nil { + t.Fatal(err) + } + if n, _ := r.RowsAffected(); n != 2 { + t.Fatalf("expected 2 rows affected, not %d", n) + } + + r, err = db.Exec("SELECT g FROM generate_series(1, $1) g", 3) + if err != nil { + t.Fatal(err) + } + if n, _ := r.RowsAffected(); n != 3 { + t.Fatalf("expected 3 rows affected, not %d", n) + } + } +} + +func TestStatment(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + st, err := db.Prepare("SELECT 1") + if err != nil { + t.Fatal(err) + } + + st1, err := db.Prepare("SELECT 2") + if err != nil { + t.Fatal(err) + } + + r, err := st.Query() + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + t.Fatal("expected row") + } + + var i int + err = r.Scan(&i) + if err != nil { + t.Fatal(err) + } + + if i != 1 { + t.Fatalf("expected 1, got %d", i) + } + + // st1 + + r1, err := st1.Query() + if err != nil { + t.Fatal(err) + } + defer r1.Close() + + if !r1.Next() { + if r.Err() != nil { + t.Fatal(r1.Err()) + } + t.Fatal("expected row") + } + + err = r1.Scan(&i) + if err != nil { + t.Fatal(err) + } + + if i != 2 { + t.Fatalf("expected 2, got %d", i) + } +} + +func TestRowsCloseBeforeDone(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + r, err := db.Query("SELECT 1") + if err != nil { + t.Fatal(err) + } + + err = r.Close() + if err != nil { + t.Fatal(err) + } + + if r.Next() { + t.Fatal("unexpected row") + } + + if r.Err() != nil { + t.Fatal(r.Err()) + } +} + +func TestParameterCountMismatch(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + var notused int + err := db.QueryRow("SELECT false", 1).Scan(¬used) + if err == nil { + t.Fatal("expected err") + } + // make sure we clean up correctly + err = db.QueryRow("SELECT 1").Scan(¬used) + if err != nil { + t.Fatal(err) + } + + err = db.QueryRow("SELECT $1").Scan(¬used) + if err == nil { + t.Fatal("expected err") + } + // make sure we clean up correctly + err = db.QueryRow("SELECT 1").Scan(¬used) + if err != nil { + t.Fatal(err) + } +} + +// Test that EmptyQueryResponses are handled correctly. +func TestEmptyQuery(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + res, err := db.Exec("") + if err != nil { + t.Fatal(err) + } + if _, err := res.RowsAffected(); err != errNoRowsAffected { + t.Fatalf("expected %s, got %v", errNoRowsAffected, err) + } + if _, err := res.LastInsertId(); err != errNoLastInsertID { + t.Fatalf("expected %s, got %v", errNoLastInsertID, err) + } + rows, err := db.Query("") + if err != nil { + t.Fatal(err) + } + cols, err := rows.Columns() + if err != nil { + t.Fatal(err) + } + if len(cols) != 0 { + t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols)) + } + if rows.Next() { + t.Fatal("unexpected row") + } + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + + stmt, err := db.Prepare("") + if err != nil { + t.Fatal(err) + } + res, err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + if _, err := res.RowsAffected(); err != errNoRowsAffected { + t.Fatalf("expected %s, got %v", errNoRowsAffected, err) + } + if _, err := res.LastInsertId(); err != errNoLastInsertID { + t.Fatalf("expected %s, got %v", errNoLastInsertID, err) + } + rows, err = stmt.Query() + if err != nil { + t.Fatal(err) + } + cols, err = rows.Columns() + if err != nil { + t.Fatal(err) + } + if len(cols) != 0 { + t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols)) + } + if rows.Next() { + t.Fatal("unexpected row") + } + if rows.Err() != nil { + t.Fatal(rows.Err()) + } +} + +// Test that rows.Columns() is correct even if there are no result rows. +func TestEmptyResultSetColumns(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar WHERE FALSE") + if err != nil { + t.Fatal(err) + } + cols, err := rows.Columns() + if err != nil { + t.Fatal(err) + } + if len(cols) != 2 { + t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols)) + } + if rows.Next() { + t.Fatal("unexpected row") + } + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + if cols[0] != "a" || cols[1] != "bar" { + t.Fatalf("unexpected Columns result %v", cols) + } + + stmt, err := db.Prepare("SELECT $1::int AS a, text 'bar' AS bar WHERE FALSE") + if err != nil { + t.Fatal(err) + } + rows, err = stmt.Query(1) + if err != nil { + t.Fatal(err) + } + cols, err = rows.Columns() + if err != nil { + t.Fatal(err) + } + if len(cols) != 2 { + t.Fatalf("unexpected number of columns %d in response to an empty query", len(cols)) + } + if rows.Next() { + t.Fatal("unexpected row") + } + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + if cols[0] != "a" || cols[1] != "bar" { + t.Fatalf("unexpected Columns result %v", cols) + } + +} + +func TestEncodeDecode(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + q := ` + SELECT + E'\\000\\001\\002'::bytea, + 'foobar'::text, + NULL::integer, + '2000-1-1 01:02:03.04-7'::timestamptz, + 0::boolean, + 123, + -321, + 3.14::float8 + WHERE + E'\\000\\001\\002'::bytea = $1 + AND 'foobar'::text = $2 + AND $3::integer is NULL + ` + // AND '2000-1-1 12:00:00.000000-7'::timestamp = $3 + + exp1 := []byte{0, 1, 2} + exp2 := "foobar" + + r, err := db.Query(q, exp1, exp2, nil) + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + if r.Err() != nil { + t.Fatal(r.Err()) + } + t.Fatal("expected row") + } + + var got1 []byte + var got2 string + var got3 = sql.NullInt64{Valid: true} + var got4 time.Time + var got5, got6, got7, got8 interface{} + + err = r.Scan(&got1, &got2, &got3, &got4, &got5, &got6, &got7, &got8) + if err != nil { + t.Fatal(err) + } + + if !reflect.DeepEqual(exp1, got1) { + t.Errorf("expected %q byte: %q", exp1, got1) + } + + if !reflect.DeepEqual(exp2, got2) { + t.Errorf("expected %q byte: %q", exp2, got2) + } + + if got3.Valid { + t.Fatal("expected invalid") + } + + if got4.Year() != 2000 { + t.Fatal("wrong year") + } + + if got5 != false { + t.Fatalf("expected false, got %q", got5) + } + + if got6 != int64(123) { + t.Fatalf("expected 123, got %d", got6) + } + + if got7 != int64(-321) { + t.Fatalf("expected -321, got %d", got7) + } + + if got8 != float64(3.14) { + t.Fatalf("expected 3.14, got %f", got8) + } +} + +func TestNoData(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + st, err := db.Prepare("SELECT 1 WHERE true = false") + if err != nil { + t.Fatal(err) + } + defer st.Close() + + r, err := st.Query() + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if r.Next() { + if r.Err() != nil { + t.Fatal(r.Err()) + } + t.Fatal("unexpected row") + } + + _, err = db.Query("SELECT * FROM nonexistenttable WHERE age=$1", 20) + if err == nil { + t.Fatal("Should have raised an error on non existent table") + } + + _, err = db.Query("SELECT * FROM nonexistenttable") + if err == nil { + t.Fatal("Should have raised an error on non existent table") + } +} + +func TestErrorDuringStartup(t *testing.T) { + // Don't use the normal connection setup, this is intended to + // blow up in the startup packet from a non-existent user. + db, err := openTestConnConninfo("user=thisuserreallydoesntexist") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + _, err = db.Begin() + if err == nil { + t.Fatal("expected error") + } + + e, ok := err.(*Error) + if !ok { + t.Fatalf("expected Error, got %#v", err) + } else if e.Code.Name() != "invalid_authorization_specification" && e.Code.Name() != "invalid_password" { + t.Fatalf("expected invalid_authorization_specification or invalid_password, got %s (%+v)", e.Code.Name(), err) + } +} + +type testConn struct { + closed bool + net.Conn +} + +func (c *testConn) Close() error { + c.closed = true + return c.Conn.Close() +} + +type testDialer struct { + conns []*testConn +} + +func (d *testDialer) Dial(ntw, addr string) (net.Conn, error) { + c, err := net.Dial(ntw, addr) + if err != nil { + return nil, err + } + tc := &testConn{Conn: c} + d.conns = append(d.conns, tc) + return tc, nil +} + +func (d *testDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) { + c, err := net.DialTimeout(ntw, addr, timeout) + if err != nil { + return nil, err + } + tc := &testConn{Conn: c} + d.conns = append(d.conns, tc) + return tc, nil +} + +func TestErrorDuringStartupClosesConn(t *testing.T) { + // Don't use the normal connection setup, this is intended to + // blow up in the startup packet from a non-existent user. + var d testDialer + c, err := DialOpen(&d, testConninfo("user=thisuserreallydoesntexist")) + if err == nil { + c.Close() + t.Fatal("expected dial error") + } + if len(d.conns) != 1 { + t.Fatalf("got len(d.conns) = %d, want = %d", len(d.conns), 1) + } + if !d.conns[0].closed { + t.Error("connection leaked") + } +} + +func TestBadConn(t *testing.T) { + var err error + + cn := conn{} + func() { + defer cn.errRecover(&err) + panic(io.EOF) + }() + if err != driver.ErrBadConn { + t.Fatalf("expected driver.ErrBadConn, got: %#v", err) + } + if !cn.bad { + t.Fatalf("expected cn.bad") + } + + cn = conn{} + func() { + defer cn.errRecover(&err) + e := &Error{Severity: Efatal} + panic(e) + }() + if err != driver.ErrBadConn { + t.Fatalf("expected driver.ErrBadConn, got: %#v", err) + } + if !cn.bad { + t.Fatalf("expected cn.bad") + } +} + +// TestCloseBadConn tests that the underlying connection can be closed with +// Close after an error. +func TestCloseBadConn(t *testing.T) { + nc, err := net.Dial("tcp", "localhost:5432") + if err != nil { + t.Fatal(err) + } + cn := conn{c: nc} + func() { + defer cn.errRecover(&err) + panic(io.EOF) + }() + // Verify we can write before closing. + if _, err := nc.Write(nil); err != nil { + t.Fatal(err) + } + // First close should close the connection. + if err := cn.Close(); err != nil { + t.Fatal(err) + } + + // During the Go 1.9 cycle, https://github.com/golang/go/commit/3792db5 + // changed this error from + // + // net.errClosing = errors.New("use of closed network connection") + // + // to + // + // internal/poll.ErrClosing = errors.New("use of closed file or network connection") + const errClosing = "use of closed" + + // Verify write after closing fails. + if _, err := nc.Write(nil); err == nil { + t.Fatal("expected error") + } else if !strings.Contains(err.Error(), errClosing) { + t.Fatalf("expected %s error, got %s", errClosing, err) + } + // Verify second close fails. + if err := cn.Close(); err == nil { + t.Fatal("expected error") + } else if !strings.Contains(err.Error(), errClosing) { + t.Fatalf("expected %s error, got %s", errClosing, err) + } +} + +func TestErrorOnExec(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + + _, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)") + if err != nil { + t.Fatal(err) + } + + _, err = txn.Exec("INSERT INTO foo VALUES (0), (0)") + if err == nil { + t.Fatal("Should have raised error") + } + + e, ok := err.(*Error) + if !ok { + t.Fatalf("expected Error, got %#v", err) + } else if e.Code.Name() != "unique_violation" { + t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err) + } +} + +func TestErrorOnQuery(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + + _, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)") + if err != nil { + t.Fatal(err) + } + + _, err = txn.Query("INSERT INTO foo VALUES (0), (0)") + if err == nil { + t.Fatal("Should have raised error") + } + + e, ok := err.(*Error) + if !ok { + t.Fatalf("expected Error, got %#v", err) + } else if e.Code.Name() != "unique_violation" { + t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err) + } +} + +func TestErrorOnQueryRowSimpleQuery(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + + _, err = txn.Exec("CREATE TEMPORARY TABLE foo(f1 int PRIMARY KEY)") + if err != nil { + t.Fatal(err) + } + + var v int + err = txn.QueryRow("INSERT INTO foo VALUES (0), (0)").Scan(&v) + if err == nil { + t.Fatal("Should have raised error") + } + + e, ok := err.(*Error) + if !ok { + t.Fatalf("expected Error, got %#v", err) + } else if e.Code.Name() != "unique_violation" { + t.Fatalf("expected unique_violation, got %s (%+v)", e.Code.Name(), err) + } +} + +// Test the QueryRow bug workarounds in stmt.exec() and simpleQuery() +func TestQueryRowBugWorkaround(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + // stmt.exec() + _, err := db.Exec("CREATE TEMP TABLE notnulltemp (a varchar(10) not null)") + if err != nil { + t.Fatal(err) + } + + var a string + err = db.QueryRow("INSERT INTO notnulltemp(a) values($1) RETURNING a", nil).Scan(&a) + if err == sql.ErrNoRows { + t.Fatalf("expected constraint violation error; got: %v", err) + } + pge, ok := err.(*Error) + if !ok { + t.Fatalf("expected *Error; got: %#v", err) + } + if pge.Code.Name() != "not_null_violation" { + t.Fatalf("expected not_null_violation; got: %s (%+v)", pge.Code.Name(), err) + } + + // Test workaround in simpleQuery() + tx, err := db.Begin() + if err != nil { + t.Fatalf("unexpected error %s in Begin", err) + } + defer tx.Rollback() + + _, err = tx.Exec("SET LOCAL check_function_bodies TO FALSE") + if err != nil { + t.Fatalf("could not disable check_function_bodies: %s", err) + } + _, err = tx.Exec(` +CREATE OR REPLACE FUNCTION bad_function() +RETURNS integer +-- hack to prevent the function from being inlined +SET check_function_bodies TO TRUE +AS $$ + SELECT text 'bad' +$$ LANGUAGE sql`) + if err != nil { + t.Fatalf("could not create function: %s", err) + } + + err = tx.QueryRow("SELECT * FROM bad_function()").Scan(&a) + if err == nil { + t.Fatalf("expected error") + } + pge, ok = err.(*Error) + if !ok { + t.Fatalf("expected *Error; got: %#v", err) + } + if pge.Code.Name() != "invalid_function_definition" { + t.Fatalf("expected invalid_function_definition; got: %s (%+v)", pge.Code.Name(), err) + } + + err = tx.Rollback() + if err != nil { + t.Fatalf("unexpected error %s in Rollback", err) + } + + // Also test that simpleQuery()'s workaround works when the query fails + // after a row has been received. + rows, err := db.Query(` +select + (select generate_series(1, ss.i)) +from (select gs.i + from generate_series(1, 2) gs(i) + order by gs.i limit 2) ss`) + if err != nil { + t.Fatalf("query failed: %s", err) + } + if !rows.Next() { + t.Fatalf("expected at least one result row; got %s", rows.Err()) + } + var i int + err = rows.Scan(&i) + if err != nil { + t.Fatalf("rows.Scan() failed: %s", err) + } + if i != 1 { + t.Fatalf("unexpected value for i: %d", i) + } + if rows.Next() { + t.Fatalf("unexpected row") + } + pge, ok = rows.Err().(*Error) + if !ok { + t.Fatalf("expected *Error; got: %#v", err) + } + if pge.Code.Name() != "cardinality_violation" { + t.Fatalf("expected cardinality_violation; got: %s (%+v)", pge.Code.Name(), rows.Err()) + } +} + +func TestSimpleQuery(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + r, err := db.Query("select 1") + if err != nil { + t.Fatal(err) + } + defer r.Close() + + if !r.Next() { + t.Fatal("expected row") + } +} + +func TestBindError(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + _, err := db.Exec("create temp table test (i integer)") + if err != nil { + t.Fatal(err) + } + + _, err = db.Query("select * from test where i=$1", "hhh") + if err == nil { + t.Fatal("expected an error") + } + + // Should not get error here + r, err := db.Query("select * from test where i=$1", 1) + if err != nil { + t.Fatal(err) + } + defer r.Close() +} + +func TestParseErrorInExtendedQuery(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + _, err := db.Query("PARSE_ERROR $1", 1) + pqErr, _ := err.(*Error) + // Expecting a syntax error. + if err == nil || pqErr == nil || pqErr.Code != "42601" { + t.Fatalf("expected syntax error, got %s", err) + } + + rows, err := db.Query("SELECT 1") + if err != nil { + t.Fatal(err) + } + rows.Close() +} + +// TestReturning tests that an INSERT query using the RETURNING clause returns a row. +func TestReturning(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + _, err := db.Exec("CREATE TEMP TABLE distributors (did integer default 0, dname text)") + if err != nil { + t.Fatal(err) + } + + rows, err := db.Query("INSERT INTO distributors (did, dname) VALUES (DEFAULT, 'XYZ Widgets') " + + "RETURNING did;") + if err != nil { + t.Fatal(err) + } + if !rows.Next() { + t.Fatal("no rows") + } + var did int + err = rows.Scan(&did) + if err != nil { + t.Fatal(err) + } + if did != 0 { + t.Fatalf("bad value for did: got %d, want %d", did, 0) + } + + if rows.Next() { + t.Fatal("unexpected next row") + } + err = rows.Err() + if err != nil { + t.Fatal(err) + } +} + +func TestIssue186(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + // Exec() a query which returns results + _, err := db.Exec("VALUES (1), (2), (3)") + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("VALUES ($1), ($2), ($3)", 1, 2, 3) + if err != nil { + t.Fatal(err) + } + + // Query() a query which doesn't return any results + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + + rows, err := txn.Query("CREATE TEMP TABLE foo(f1 int)") + if err != nil { + t.Fatal(err) + } + if err = rows.Close(); err != nil { + t.Fatal(err) + } + + // small trick to get NoData from a parameterized query + _, err = txn.Exec("CREATE RULE nodata AS ON INSERT TO foo DO INSTEAD NOTHING") + if err != nil { + t.Fatal(err) + } + rows, err = txn.Query("INSERT INTO foo VALUES ($1)", 1) + if err != nil { + t.Fatal(err) + } + if err = rows.Close(); err != nil { + t.Fatal(err) + } +} + +func TestIssue196(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + row := db.QueryRow("SELECT float4 '0.10000122' = $1, float8 '35.03554004971999' = $2", + float32(0.10000122), float64(35.03554004971999)) + + var float4match, float8match bool + err := row.Scan(&float4match, &float8match) + if err != nil { + t.Fatal(err) + } + if !float4match { + t.Errorf("Expected float4 fidelity to be maintained; got no match") + } + if !float8match { + t.Errorf("Expected float8 fidelity to be maintained; got no match") + } +} + +// Test that any CommandComplete messages sent before the query results are +// ignored. +func TestIssue282(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + var searchPath string + err := db.QueryRow(` + SET LOCAL search_path TO pg_catalog; + SET LOCAL search_path TO pg_catalog; + SHOW search_path`).Scan(&searchPath) + if err != nil { + t.Fatal(err) + } + if searchPath != "pg_catalog" { + t.Fatalf("unexpected search_path %s", searchPath) + } +} + +func TestReadFloatPrecision(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + row := db.QueryRow("SELECT float4 '0.10000122', float8 '35.03554004971999', float4 '1.2'") + var float4val float32 + var float8val float64 + var float4val2 float64 + err := row.Scan(&float4val, &float8val, &float4val2) + if err != nil { + t.Fatal(err) + } + if float4val != float32(0.10000122) { + t.Errorf("Expected float4 fidelity to be maintained; got no match") + } + if float8val != float64(35.03554004971999) { + t.Errorf("Expected float8 fidelity to be maintained; got no match") + } + if float4val2 != float64(1.2) { + t.Errorf("Expected float4 fidelity into a float64 to be maintained; got no match") + } +} + +func TestXactMultiStmt(t *testing.T) { + // minified test case based on bug reports from + // pico303@gmail.com and rangelspam@gmail.com + t.Skip("Skipping failing test") + db := openTestConn(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Commit() + + rows, err := tx.Query("select 1") + if err != nil { + t.Fatal(err) + } + + if rows.Next() { + var val int32 + if err = rows.Scan(&val); err != nil { + t.Fatal(err) + } + } else { + t.Fatal("Expected at least one row in first query in xact") + } + + rows2, err := tx.Query("select 2") + if err != nil { + t.Fatal(err) + } + + if rows2.Next() { + var val2 int32 + if err := rows2.Scan(&val2); err != nil { + t.Fatal(err) + } + } else { + t.Fatal("Expected at least one row in second query in xact") + } + + if err = rows.Err(); err != nil { + t.Fatal(err) + } + + if err = rows2.Err(); err != nil { + t.Fatal(err) + } + + if err = tx.Commit(); err != nil { + t.Fatal(err) + } +} + +var envParseTests = []struct { + Expected map[string]string + Env []string +}{ + { + Env: []string{"PGDATABASE=hello", "PGUSER=goodbye"}, + Expected: map[string]string{"dbname": "hello", "user": "goodbye"}, + }, + { + Env: []string{"PGDATESTYLE=ISO, MDY"}, + Expected: map[string]string{"datestyle": "ISO, MDY"}, + }, + { + Env: []string{"PGCONNECT_TIMEOUT=30"}, + Expected: map[string]string{"connect_timeout": "30"}, + }, +} + +func TestParseEnviron(t *testing.T) { + for i, tt := range envParseTests { + results := parseEnviron(tt.Env) + if !reflect.DeepEqual(tt.Expected, results) { + t.Errorf("%d: Expected: %#v Got: %#v", i, tt.Expected, results) + } + } +} + +func TestParseComplete(t *testing.T) { + tpc := func(commandTag string, command string, affectedRows int64, shouldFail bool) { + defer func() { + if p := recover(); p != nil { + if !shouldFail { + t.Error(p) + } + } + }() + cn := &conn{} + res, c := cn.parseComplete(commandTag) + if c != command { + t.Errorf("Expected %v, got %v", command, c) + } + n, err := res.RowsAffected() + if err != nil { + t.Fatal(err) + } + if n != affectedRows { + t.Errorf("Expected %d, got %d", affectedRows, n) + } + } + + tpc("ALTER TABLE", "ALTER TABLE", 0, false) + tpc("INSERT 0 1", "INSERT", 1, false) + tpc("UPDATE 100", "UPDATE", 100, false) + tpc("SELECT 100", "SELECT", 100, false) + tpc("FETCH 100", "FETCH", 100, false) + // allow COPY (and others) without row count + tpc("COPY", "COPY", 0, false) + // don't fail on command tags we don't recognize + tpc("UNKNOWNCOMMANDTAG", "UNKNOWNCOMMANDTAG", 0, false) + + // failure cases + tpc("INSERT 1", "", 0, true) // missing oid + tpc("UPDATE 0 1", "", 0, true) // too many numbers + tpc("SELECT foo", "", 0, true) // invalid row count +} + +// Test interface conformance. +var ( + _ driver.ExecerContext = (*conn)(nil) + _ driver.QueryerContext = (*conn)(nil) +) + +func TestNullAfterNonNull(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + r, err := db.Query("SELECT 9::integer UNION SELECT NULL::integer") + if err != nil { + t.Fatal(err) + } + + var n sql.NullInt64 + + if !r.Next() { + if r.Err() != nil { + t.Fatal(err) + } + t.Fatal("expected row") + } + + if err := r.Scan(&n); err != nil { + t.Fatal(err) + } + + if n.Int64 != 9 { + t.Fatalf("expected 2, not %d", n.Int64) + } + + if !r.Next() { + if r.Err() != nil { + t.Fatal(err) + } + t.Fatal("expected row") + } + + if err := r.Scan(&n); err != nil { + t.Fatal(err) + } + + if n.Valid { + t.Fatal("expected n to be invalid") + } + + if n.Int64 != 0 { + t.Fatalf("expected n to 2, not %d", n.Int64) + } +} + +func Test64BitErrorChecking(t *testing.T) { + defer func() { + if err := recover(); err != nil { + t.Fatal("panic due to 0xFFFFFFFF != -1 " + + "when int is 64 bits") + } + }() + + db := openTestConn(t) + defer db.Close() + + r, err := db.Query(`SELECT * +FROM (VALUES (0::integer, NULL::text), (1, 'test string')) AS t;`) + + if err != nil { + t.Fatal(err) + } + + defer r.Close() + + for r.Next() { + } +} + +func TestCommit(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + _, err := db.Exec("CREATE TEMP TABLE temp (a int)") + if err != nil { + t.Fatal(err) + } + sqlInsert := "INSERT INTO temp VALUES (1)" + sqlSelect := "SELECT * FROM temp" + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + _, err = tx.Exec(sqlInsert) + if err != nil { + t.Fatal(err) + } + err = tx.Commit() + if err != nil { + t.Fatal(err) + } + var i int + err = db.QueryRow(sqlSelect).Scan(&i) + if err != nil { + t.Fatal(err) + } + if i != 1 { + t.Fatalf("expected 1, got %d", i) + } +} + +func TestErrorClass(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + _, err := db.Query("SELECT int 'notint'") + if err == nil { + t.Fatal("expected error") + } + pge, ok := err.(*Error) + if !ok { + t.Fatalf("expected *pq.Error, got %#+v", err) + } + if pge.Code.Class() != "22" { + t.Fatalf("expected class 28, got %v", pge.Code.Class()) + } + if pge.Code.Class().Name() != "data_exception" { + t.Fatalf("expected data_exception, got %v", pge.Code.Class().Name()) + } +} + +func TestParseOpts(t *testing.T) { + tests := []struct { + in string + expected values + valid bool + }{ + {"dbname=hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, + {"dbname=hello user=goodbye ", values{"dbname": "hello", "user": "goodbye"}, true}, + {"dbname = hello user=goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, + {"dbname=hello user =goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, + {"dbname=hello user= goodbye", values{"dbname": "hello", "user": "goodbye"}, true}, + {"host=localhost password='correct horse battery staple'", values{"host": "localhost", "password": "correct horse battery staple"}, true}, + {"dbname=データベース password=パスワード", values{"dbname": "データベース", "password": "パスワード"}, true}, + {"dbname=hello user=''", values{"dbname": "hello", "user": ""}, true}, + {"user='' dbname=hello", values{"dbname": "hello", "user": ""}, true}, + // The last option value is an empty string if there's no non-whitespace after its = + {"dbname=hello user= ", values{"dbname": "hello", "user": ""}, true}, + + // The parser ignores spaces after = and interprets the next set of non-whitespace characters as the value. + {"user= password=foo", values{"user": "password=foo"}, true}, + + // Backslash escapes next char + {`user=a\ \'\\b`, values{"user": `a '\b`}, true}, + {`user='a \'b'`, values{"user": `a 'b`}, true}, + + // Incomplete escape + {`user=x\`, values{}, false}, + + // No '=' after the key + {"postgre://marko@internet", values{}, false}, + {"dbname user=goodbye", values{}, false}, + {"user=foo blah", values{}, false}, + {"user=foo blah ", values{}, false}, + + // Unterminated quoted value + {"dbname=hello user='unterminated", values{}, false}, + } + + for _, test := range tests { + o := make(values) + err := parseOpts(test.in, o) + + switch { + case err != nil && test.valid: + t.Errorf("%q got unexpected error: %s", test.in, err) + case err == nil && test.valid && !reflect.DeepEqual(test.expected, o): + t.Errorf("%q got: %#v want: %#v", test.in, o, test.expected) + case err == nil && !test.valid: + t.Errorf("%q expected an error", test.in) + } + } +} + +func TestRuntimeParameters(t *testing.T) { + tests := []struct { + conninfo string + param string + expected string + success bool + }{ + // invalid parameter + {"DOESNOTEXIST=foo", "", "", false}, + // we can only work with a specific value for these two + {"client_encoding=SQL_ASCII", "", "", false}, + {"datestyle='ISO, YDM'", "", "", false}, + // "options" should work exactly as it does in libpq + {"options='-c search_path=pqgotest'", "search_path", "pqgotest", true}, + // pq should override client_encoding in this case + {"options='-c client_encoding=SQL_ASCII'", "client_encoding", "UTF8", true}, + // allow client_encoding to be set explicitly + {"client_encoding=UTF8", "client_encoding", "UTF8", true}, + // test a runtime parameter not supported by libpq + {"work_mem='139kB'", "work_mem", "139kB", true}, + // test fallback_application_name + {"application_name=foo fallback_application_name=bar", "application_name", "foo", true}, + {"application_name='' fallback_application_name=bar", "application_name", "", true}, + {"fallback_application_name=bar", "application_name", "bar", true}, + } + + for _, test := range tests { + db, err := openTestConnConninfo(test.conninfo) + if err != nil { + t.Fatal(err) + } + + // application_name didn't exist before 9.0 + if test.param == "application_name" && getServerVersion(t, db) < 90000 { + db.Close() + continue + } + + tryGetParameterValue := func() (value string, success bool) { + defer db.Close() + row := db.QueryRow("SELECT current_setting($1)", test.param) + err = row.Scan(&value) + if err != nil { + return "", false + } + return value, true + } + + value, success := tryGetParameterValue() + if success != test.success && !test.success { + t.Fatalf("%v: unexpected error: %v", test.conninfo, err) + } + if success != test.success { + t.Fatalf("unexpected outcome %v (was expecting %v) for conninfo \"%s\"", + success, test.success, test.conninfo) + } + if value != test.expected { + t.Fatalf("bad value for %s: got %s, want %s with conninfo \"%s\"", + test.param, value, test.expected, test.conninfo) + } + } +} + +func TestIsUTF8(t *testing.T) { + var cases = []struct { + name string + want bool + }{ + {"unicode", true}, + {"utf-8", true}, + {"utf_8", true}, + {"UTF-8", true}, + {"UTF8", true}, + {"utf8", true}, + {"u n ic_ode", true}, + {"ut_f%8", true}, + {"ubf8", false}, + {"punycode", false}, + } + + for _, test := range cases { + if g := isUTF8(test.name); g != test.want { + t.Errorf("isUTF8(%q) = %v want %v", test.name, g, test.want) + } + } +} + +func TestQuoteIdentifier(t *testing.T) { + var cases = []struct { + input string + want string + }{ + {`foo`, `"foo"`}, + {`foo bar baz`, `"foo bar baz"`}, + {`foo"bar`, `"foo""bar"`}, + {"foo\x00bar", `"foo"`}, + {"\x00foo", `""`}, + } + + for _, test := range cases { + got := QuoteIdentifier(test.input) + if got != test.want { + t.Errorf("QuoteIdentifier(%q) = %v want %v", test.input, got, test.want) + } + } +} + +func TestQuoteLiteral(t *testing.T) { + var cases = []struct { + input string + want string + }{ + {`foo`, `'foo'`}, + {`foo bar baz`, `'foo bar baz'`}, + {`foo'bar`, `'foo''bar'`}, + {`foo\bar`, ` E'foo\\bar'`}, + {`foo\ba'r`, ` E'foo\\ba''r'`}, + {`foo"bar`, `'foo"bar'`}, + {`foo\x00bar`, ` E'foo\\x00bar'`}, + {`\x00foo`, ` E'\\x00foo'`}, + {`'`, `''''`}, + {`''`, `''''''`}, + {`\`, ` E'\\'`}, + {`'abc'; DROP TABLE users;`, `'''abc''; DROP TABLE users;'`}, + {`\'`, ` E'\\'''`}, + {`E'\''`, ` E'E''\\'''''`}, + {`e'\''`, ` E'e''\\'''''`}, + {`E'\'abc\'; DROP TABLE users;'`, ` E'E''\\''abc\\''; DROP TABLE users;'''`}, + {`e'\'abc\'; DROP TABLE users;'`, ` E'e''\\''abc\\''; DROP TABLE users;'''`}, + } + + for _, test := range cases { + got := QuoteLiteral(test.input) + if got != test.want { + t.Errorf("QuoteLiteral(%q) = %v want %v", test.input, got, test.want) + } + } +} + +func TestRowsResultTag(t *testing.T) { + type ResultTag interface { + Result() driver.Result + Tag() string + } + + tests := []struct { + query string + tag string + ra int64 + }{ + { + query: "CREATE TEMP TABLE temp (a int)", + tag: "CREATE TABLE", + }, + { + query: "INSERT INTO temp VALUES (1), (2)", + tag: "INSERT", + ra: 2, + }, + { + query: "SELECT 1", + }, + // A SELECT anywhere should take precedent. + { + query: "SELECT 1; INSERT INTO temp VALUES (1), (2)", + }, + { + query: "INSERT INTO temp VALUES (1), (2); SELECT 1", + }, + // Multiple statements that don't return rows should return the last tag. + { + query: "CREATE TEMP TABLE t (a int); DROP TABLE t", + tag: "DROP TABLE", + }, + // Ensure a rows-returning query in any position among various tags-returing + // statements will prefer the rows. + { + query: "SELECT 1; CREATE TEMP TABLE t (a int); DROP TABLE t", + }, + { + query: "CREATE TEMP TABLE t (a int); SELECT 1; DROP TABLE t", + }, + { + query: "CREATE TEMP TABLE t (a int); DROP TABLE t; SELECT 1", + }, + // Verify that an no-results query doesn't set the tag. + { + query: "CREATE TEMP TABLE t (a int); SELECT 1 WHERE FALSE; DROP TABLE t;", + }, + } + + // If this is the only test run, this will correct the connection string. + openTestConn(t).Close() + + conn, err := Open("") + if err != nil { + t.Fatal(err) + } + defer conn.Close() + q := conn.(driver.QueryerContext) + + for _, test := range tests { + if rows, err := q.QueryContext(context.Background(), test.query, nil); err != nil { + t.Fatalf("%s: %s", test.query, err) + } else { + r := rows.(ResultTag) + if tag := r.Tag(); tag != test.tag { + t.Fatalf("%s: unexpected tag %q", test.query, tag) + } + res := r.Result() + if ra, _ := res.RowsAffected(); ra != test.ra { + t.Fatalf("%s: unexpected rows affected: %d", test.query, ra) + } + rows.Close() + } + } +} + +// TestQuickClose tests that closing a query early allows a subsequent query to work. +func TestQuickClose(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + rows, err := tx.Query("SELECT 1; SELECT 2;") + if err != nil { + t.Fatal(err) + } + if err := rows.Close(); err != nil { + t.Fatal(err) + } + + var id int + if err := tx.QueryRow("SELECT 3").Scan(&id); err != nil { + t.Fatal(err) + } + if id != 3 { + t.Fatalf("unexpected %d", id) + } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } +} + +func TestMultipleResult(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query(` + begin; + select * from information_schema.tables limit 1; + select * from information_schema.columns limit 2; + commit; + `) + if err != nil { + t.Fatal(err) + } + type set struct { + cols []string + rowCount int + } + buf := []*set{} + for { + cols, err := rows.Columns() + if err != nil { + t.Fatal(err) + } + s := &set{ + cols: cols, + } + buf = append(buf, s) + + for rows.Next() { + s.rowCount++ + } + if !rows.NextResultSet() { + break + } + } + if len(buf) != 2 { + t.Fatalf("got %d sets, expected 2", len(buf)) + } + if len(buf[0].cols) == len(buf[1].cols) || len(buf[1].cols) == 0 { + t.Fatal("invalid cols size, expected different column count and greater then zero") + } + if buf[0].rowCount != 1 || buf[1].rowCount != 2 { + t.Fatal("incorrect number of rows returned") + } +} diff --git a/vendor/github.com/lib/pq/connector.go b/vendor/github.com/lib/pq/connector.go index 9e66eb5d..2f8ced67 100644 --- a/vendor/github.com/lib/pq/connector.go +++ b/vendor/github.com/lib/pq/connector.go @@ -1,10 +1,12 @@ -// +build go1.10 - package pq import ( "context" "database/sql/driver" + "errors" + "fmt" + "os" + "strings" ) // Connector represents a fixed configuration for the pq driver with a given @@ -14,30 +16,95 @@ import ( // // See https://golang.org/pkg/database/sql/driver/#Connector. // See https://golang.org/pkg/database/sql/#OpenDB. -type connector struct { - name string +type Connector struct { + opts values + dialer Dialer } // Connect returns a connection to the database using the fixed configuration // of this Connector. Context is not used. -func (c *connector) Connect(_ context.Context) (driver.Conn, error) { - return (&Driver{}).Open(c.name) +func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) { + return c.open(ctx) } // Driver returnst the underlying driver of this Connector. -func (c *connector) Driver() driver.Driver { +func (c *Connector) Driver() driver.Driver { return &Driver{} } -var _ driver.Connector = &connector{} - // NewConnector returns a connector for the pq driver in a fixed configuration -// with the given name. The returned connector can be used to create any number +// with the given dsn. The returned connector can be used to create any number // of equivalent Conn's. The returned connector is intended to be used with // database/sql.OpenDB. // // See https://golang.org/pkg/database/sql/driver/#Connector. // See https://golang.org/pkg/database/sql/#OpenDB. -func NewConnector(name string) (driver.Connector, error) { - return &connector{name: name}, nil +func NewConnector(dsn string) (*Connector, error) { + var err error + o := make(values) + + // A number of defaults are applied here, in this order: + // + // * Very low precedence defaults applied in every situation + // * Environment variables + // * Explicitly passed connection information + o["host"] = "localhost" + o["port"] = "5432" + // N.B.: Extra float digits should be set to 3, but that breaks + // Postgres 8.4 and older, where the max is 2. + o["extra_float_digits"] = "2" + for k, v := range parseEnviron(os.Environ()) { + o[k] = v + } + + if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") { + dsn, err = ParseURL(dsn) + if err != nil { + return nil, err + } + } + + if err := parseOpts(dsn, o); err != nil { + return nil, err + } + + // Use the "fallback" application name if necessary + if fallback, ok := o["fallback_application_name"]; ok { + if _, ok := o["application_name"]; !ok { + o["application_name"] = fallback + } + } + + // We can't work with any client_encoding other than UTF-8 currently. + // However, we have historically allowed the user to set it to UTF-8 + // explicitly, and there's no reason to break such programs, so allow that. + // Note that the "options" setting could also set client_encoding, but + // parsing its value is not worth it. Instead, we always explicitly send + // client_encoding as a separate run-time parameter, which should override + // anything set in options. + if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) { + return nil, errors.New("client_encoding must be absent or 'UTF8'") + } + o["client_encoding"] = "UTF8" + // DateStyle needs a similar treatment. + if datestyle, ok := o["datestyle"]; ok { + if datestyle != "ISO, MDY" { + return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle) + } + } else { + o["datestyle"] = "ISO, MDY" + } + + // If a user is not provided by any other means, the last + // resort is to use the current operating system provided user + // name. + if _, ok := o["user"]; !ok { + u, err := userCurrent() + if err != nil { + return nil, err + } + o["user"] = u + } + + return &Connector{opts: o, dialer: defaultDialer{}}, nil } diff --git a/vendor/github.com/lib/pq/connector_example_test.go b/vendor/github.com/lib/pq/connector_example_test.go new file mode 100644 index 00000000..5b66cf4b --- /dev/null +++ b/vendor/github.com/lib/pq/connector_example_test.go @@ -0,0 +1,33 @@ +// +build go1.10 + +package pq_test + +import ( + "database/sql" + "fmt" + + "github.com/lib/pq" +) + +func ExampleNewConnector() { + name := "" + connector, err := pq.NewConnector(name) + if err != nil { + fmt.Println(err) + return + } + db := sql.OpenDB(connector) + if err != nil { + fmt.Println(err) + return + } + defer db.Close() + + // Use the DB + txn, err := db.Begin() + if err != nil { + fmt.Println(err) + return + } + txn.Rollback() +} diff --git a/vendor/github.com/lib/pq/connector_test.go b/vendor/github.com/lib/pq/connector_test.go new file mode 100644 index 00000000..3d2c67b0 --- /dev/null +++ b/vendor/github.com/lib/pq/connector_test.go @@ -0,0 +1,67 @@ +// +build go1.10 + +package pq + +import ( + "context" + "database/sql" + "database/sql/driver" + "testing" +) + +func TestNewConnector_WorksWithOpenDB(t *testing.T) { + name := "" + c, err := NewConnector(name) + if err != nil { + t.Fatal(err) + } + db := sql.OpenDB(c) + defer db.Close() + // database/sql might not call our Open at all unless we do something with + // the connection + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + txn.Rollback() +} + +func TestNewConnector_Connect(t *testing.T) { + name := "" + c, err := NewConnector(name) + if err != nil { + t.Fatal(err) + } + db, err := c.Connect(context.Background()) + if err != nil { + t.Fatal(err) + } + defer db.Close() + // database/sql might not call our Open at all unless we do something with + // the connection + txn, err := db.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) + if err != nil { + t.Fatal(err) + } + txn.Rollback() +} + +func TestNewConnector_Driver(t *testing.T) { + name := "" + c, err := NewConnector(name) + if err != nil { + t.Fatal(err) + } + db, err := c.Driver().Open(name) + if err != nil { + t.Fatal(err) + } + defer db.Close() + // database/sql might not call our Open at all unless we do something with + // the connection + txn, err := db.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{}) + if err != nil { + t.Fatal(err) + } + txn.Rollback() +} diff --git a/vendor/github.com/lib/pq/copy_test.go b/vendor/github.com/lib/pq/copy_test.go new file mode 100644 index 00000000..a888a894 --- /dev/null +++ b/vendor/github.com/lib/pq/copy_test.go @@ -0,0 +1,468 @@ +package pq + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "net" + "strings" + "testing" +) + +func TestCopyInStmt(t *testing.T) { + stmt := CopyIn("table name") + if stmt != `COPY "table name" () FROM STDIN` { + t.Fatal(stmt) + } + + stmt = CopyIn("table name", "column 1", "column 2") + if stmt != `COPY "table name" ("column 1", "column 2") FROM STDIN` { + t.Fatal(stmt) + } + + stmt = CopyIn(`table " name """`, `co"lumn""`) + if stmt != `COPY "table "" name """"""" ("co""lumn""""") FROM STDIN` { + t.Fatal(stmt) + } +} + +func TestCopyInSchemaStmt(t *testing.T) { + stmt := CopyInSchema("schema name", "table name") + if stmt != `COPY "schema name"."table name" () FROM STDIN` { + t.Fatal(stmt) + } + + stmt = CopyInSchema("schema name", "table name", "column 1", "column 2") + if stmt != `COPY "schema name"."table name" ("column 1", "column 2") FROM STDIN` { + t.Fatal(stmt) + } + + stmt = CopyInSchema(`schema " name """`, `table " name """`, `co"lumn""`) + if stmt != `COPY "schema "" name """"""".`+ + `"table "" name """"""" ("co""lumn""""") FROM STDIN` { + t.Fatal(stmt) + } +} + +func TestCopyInMultipleValues(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + + _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") + if err != nil { + t.Fatal(err) + } + + stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) + if err != nil { + t.Fatal(err) + } + + longString := strings.Repeat("#", 500) + + for i := 0; i < 500; i++ { + _, err = stmt.Exec(int64(i), longString) + if err != nil { + t.Fatal(err) + } + } + + _, err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + + var num int + err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) + if err != nil { + t.Fatal(err) + } + + if num != 500 { + t.Fatalf("expected 500 items, not %d", num) + } +} + +func TestCopyInRaiseStmtTrigger(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + if getServerVersion(t, db) < 90000 { + var exists int + err := db.QueryRow("SELECT 1 FROM pg_language WHERE lanname = 'plpgsql'").Scan(&exists) + if err == sql.ErrNoRows { + t.Skip("language PL/PgSQL does not exist; skipping TestCopyInRaiseStmtTrigger") + } else if err != nil { + t.Fatal(err) + } + } + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + + _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") + if err != nil { + t.Fatal(err) + } + + _, err = txn.Exec(` + CREATE OR REPLACE FUNCTION pg_temp.temptest() + RETURNS trigger AS + $BODY$ begin + raise notice 'Hello world'; + return new; + end $BODY$ + LANGUAGE plpgsql`) + if err != nil { + t.Fatal(err) + } + + _, err = txn.Exec(` + CREATE TRIGGER temptest_trigger + BEFORE INSERT + ON temp + FOR EACH ROW + EXECUTE PROCEDURE pg_temp.temptest()`) + if err != nil { + t.Fatal(err) + } + + stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) + if err != nil { + t.Fatal(err) + } + + longString := strings.Repeat("#", 500) + + _, err = stmt.Exec(int64(1), longString) + if err != nil { + t.Fatal(err) + } + + _, err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + + var num int + err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) + if err != nil { + t.Fatal(err) + } + + if num != 1 { + t.Fatalf("expected 1 items, not %d", num) + } +} + +func TestCopyInTypes(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + + _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER, text VARCHAR, blob BYTEA, nothing VARCHAR)") + if err != nil { + t.Fatal(err) + } + + stmt, err := txn.Prepare(CopyIn("temp", "num", "text", "blob", "nothing")) + if err != nil { + t.Fatal(err) + } + + _, err = stmt.Exec(int64(1234567890), "Héllö\n ☃!\r\t\\", []byte{0, 255, 9, 10, 13}, nil) + if err != nil { + t.Fatal(err) + } + + _, err = stmt.Exec() + if err != nil { + t.Fatal(err) + } + + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + + var num int + var text string + var blob []byte + var nothing sql.NullString + + err = txn.QueryRow("SELECT * FROM temp").Scan(&num, &text, &blob, ¬hing) + if err != nil { + t.Fatal(err) + } + + if num != 1234567890 { + t.Fatal("unexpected result", num) + } + if text != "Héllö\n ☃!\r\t\\" { + t.Fatal("unexpected result", text) + } + if !bytes.Equal(blob, []byte{0, 255, 9, 10, 13}) { + t.Fatal("unexpected result", blob) + } + if nothing.Valid { + t.Fatal("unexpected result", nothing.String) + } +} + +func TestCopyInWrongType(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + + _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") + if err != nil { + t.Fatal(err) + } + + stmt, err := txn.Prepare(CopyIn("temp", "num")) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + _, err = stmt.Exec("Héllö\n ☃!\r\t\\") + if err != nil { + t.Fatal(err) + } + + _, err = stmt.Exec() + if err == nil { + t.Fatal("expected error") + } + if pge := err.(*Error); pge.Code.Name() != "invalid_text_representation" { + t.Fatalf("expected 'invalid input syntax for integer' error, got %s (%+v)", pge.Code.Name(), pge) + } +} + +func TestCopyOutsideOfTxnError(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + _, err := db.Prepare(CopyIn("temp", "num")) + if err == nil { + t.Fatal("COPY outside of transaction did not return an error") + } + if err != errCopyNotSupportedOutsideTxn { + t.Fatalf("expected %s, got %s", err, err.Error()) + } +} + +func TestCopyInBinaryError(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + + _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") + if err != nil { + t.Fatal(err) + } + _, err = txn.Prepare("COPY temp (num) FROM STDIN WITH binary") + if err != errBinaryCopyNotSupported { + t.Fatalf("expected %s, got %+v", errBinaryCopyNotSupported, err) + } + // check that the protocol is in a valid state + err = txn.Rollback() + if err != nil { + t.Fatal(err) + } +} + +func TestCopyFromError(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + + _, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)") + if err != nil { + t.Fatal(err) + } + _, err = txn.Prepare("COPY temp (num) TO STDOUT") + if err != errCopyToNotSupported { + t.Fatalf("expected %s, got %+v", errCopyToNotSupported, err) + } + // check that the protocol is in a valid state + err = txn.Rollback() + if err != nil { + t.Fatal(err) + } +} + +func TestCopySyntaxError(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + + _, err = txn.Prepare("COPY ") + if err == nil { + t.Fatal("expected error") + } + if pge := err.(*Error); pge.Code.Name() != "syntax_error" { + t.Fatalf("expected syntax error, got %s (%+v)", pge.Code.Name(), pge) + } + // check that the protocol is in a valid state + err = txn.Rollback() + if err != nil { + t.Fatal(err) + } +} + +// Tests for connection errors in copyin.resploop() +func TestCopyRespLoopConnectionError(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + + var pid int + err = txn.QueryRow("SELECT pg_backend_pid()").Scan(&pid) + if err != nil { + t.Fatal(err) + } + + _, err = txn.Exec("CREATE TEMP TABLE temp (a int)") + if err != nil { + t.Fatal(err) + } + + stmt, err := txn.Prepare(CopyIn("temp", "a")) + if err != nil { + t.Fatal(err) + } + defer stmt.Close() + + _, err = db.Exec("SELECT pg_terminate_backend($1)", pid) + if err != nil { + t.Fatal(err) + } + + if getServerVersion(t, db) < 90500 { + // We have to try and send something over, since postgres before + // version 9.5 won't process SIGTERMs while it's waiting for + // CopyData/CopyEnd messages; see tcop/postgres.c. + _, err = stmt.Exec(1) + if err != nil { + t.Fatal(err) + } + } + _, err = stmt.Exec() + if err == nil { + t.Fatalf("expected error") + } + switch pge := err.(type) { + case *Error: + if pge.Code.Name() != "admin_shutdown" { + t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name()) + } + case *net.OpError: + // ignore + default: + if err == driver.ErrBadConn { + // likely an EPIPE + } else { + t.Fatalf("unexpected error, got %+#v", err) + } + } + + _ = stmt.Close() +} + +func BenchmarkCopyIn(b *testing.B) { + db := openTestConn(b) + defer db.Close() + + txn, err := db.Begin() + if err != nil { + b.Fatal(err) + } + defer txn.Rollback() + + _, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)") + if err != nil { + b.Fatal(err) + } + + stmt, err := txn.Prepare(CopyIn("temp", "a", "b")) + if err != nil { + b.Fatal(err) + } + + for i := 0; i < b.N; i++ { + _, err = stmt.Exec(int64(i), "hello world!") + if err != nil { + b.Fatal(err) + } + } + + _, err = stmt.Exec() + if err != nil { + b.Fatal(err) + } + + err = stmt.Close() + if err != nil { + b.Fatal(err) + } + + var num int + err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num) + if err != nil { + b.Fatal(err) + } + + if num != b.N { + b.Fatalf("expected %d items, not %d", b.N, num) + } +} diff --git a/vendor/github.com/lib/pq/doc.go b/vendor/github.com/lib/pq/doc.go index a1b02971..2a60054e 100644 --- a/vendor/github.com/lib/pq/doc.go +++ b/vendor/github.com/lib/pq/doc.go @@ -239,7 +239,7 @@ for more information). Note that the channel name will be truncated to 63 bytes by the PostgreSQL server. You can find a complete, working example of Listener usage at -http://godoc.org/github.com/lib/pq/example/listen. +https://godoc.org/github.com/lib/pq/example/listen. */ package pq diff --git a/vendor/github.com/lib/pq/encode.go b/vendor/github.com/lib/pq/encode.go index 3b0d365f..a6902fae 100644 --- a/vendor/github.com/lib/pq/encode.go +++ b/vendor/github.com/lib/pq/encode.go @@ -117,11 +117,10 @@ func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interfa } return i case oid.T_float4, oid.T_float8: - bits := 64 - if typ == oid.T_float4 { - bits = 32 - } - f, err := strconv.ParseFloat(string(s), bits) + // We always use 64 bit parsing, regardless of whether the input text is for + // a float4 or float8, because clients expect float64s for all float datatypes + // and returning a 32-bit parsed float64 produces lossy results. + f, err := strconv.ParseFloat(string(s), 64) if err != nil { errorf("%s", err) } diff --git a/vendor/github.com/lib/pq/encode_test.go b/vendor/github.com/lib/pq/encode_test.go new file mode 100644 index 00000000..d58798a4 --- /dev/null +++ b/vendor/github.com/lib/pq/encode_test.go @@ -0,0 +1,766 @@ +package pq + +import ( + "bytes" + "database/sql" + "fmt" + "regexp" + "testing" + "time" + + "github.com/lib/pq/oid" +) + +func TestScanTimestamp(t *testing.T) { + var nt NullTime + tn := time.Now() + nt.Scan(tn) + if !nt.Valid { + t.Errorf("Expected Valid=false") + } + if nt.Time != tn { + t.Errorf("Time value mismatch") + } +} + +func TestScanNilTimestamp(t *testing.T) { + var nt NullTime + nt.Scan(nil) + if nt.Valid { + t.Errorf("Expected Valid=false") + } +} + +var timeTests = []struct { + str string + timeval time.Time +}{ + {"22001-02-03", time.Date(22001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))}, + {"2001-02-03", time.Date(2001, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))}, + {"0001-12-31 BC", time.Date(0, time.December, 31, 0, 0, 0, 0, time.FixedZone("", 0))}, + {"2001-02-03 BC", time.Date(-2000, time.February, 3, 0, 0, 0, 0, time.FixedZone("", 0))}, + {"2001-02-03 04:05:06", time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))}, + {"2001-02-03 04:05:06.000001", time.Date(2001, time.February, 3, 4, 5, 6, 1000, time.FixedZone("", 0))}, + {"2001-02-03 04:05:06.00001", time.Date(2001, time.February, 3, 4, 5, 6, 10000, time.FixedZone("", 0))}, + {"2001-02-03 04:05:06.0001", time.Date(2001, time.February, 3, 4, 5, 6, 100000, time.FixedZone("", 0))}, + {"2001-02-03 04:05:06.001", time.Date(2001, time.February, 3, 4, 5, 6, 1000000, time.FixedZone("", 0))}, + {"2001-02-03 04:05:06.01", time.Date(2001, time.February, 3, 4, 5, 6, 10000000, time.FixedZone("", 0))}, + {"2001-02-03 04:05:06.1", time.Date(2001, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))}, + {"2001-02-03 04:05:06.12", time.Date(2001, time.February, 3, 4, 5, 6, 120000000, time.FixedZone("", 0))}, + {"2001-02-03 04:05:06.123", time.Date(2001, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, + {"2001-02-03 04:05:06.1234", time.Date(2001, time.February, 3, 4, 5, 6, 123400000, time.FixedZone("", 0))}, + {"2001-02-03 04:05:06.12345", time.Date(2001, time.February, 3, 4, 5, 6, 123450000, time.FixedZone("", 0))}, + {"2001-02-03 04:05:06.123456", time.Date(2001, time.February, 3, 4, 5, 6, 123456000, time.FixedZone("", 0))}, + {"2001-02-03 04:05:06.123-07", time.Date(2001, time.February, 3, 4, 5, 6, 123000000, + time.FixedZone("", -7*60*60))}, + {"2001-02-03 04:05:06-07", time.Date(2001, time.February, 3, 4, 5, 6, 0, + time.FixedZone("", -7*60*60))}, + {"2001-02-03 04:05:06-07:42", time.Date(2001, time.February, 3, 4, 5, 6, 0, + time.FixedZone("", -(7*60*60+42*60)))}, + {"2001-02-03 04:05:06-07:30:09", time.Date(2001, time.February, 3, 4, 5, 6, 0, + time.FixedZone("", -(7*60*60+30*60+9)))}, + {"2001-02-03 04:05:06+07", time.Date(2001, time.February, 3, 4, 5, 6, 0, + time.FixedZone("", 7*60*60))}, + {"0011-02-03 04:05:06 BC", time.Date(-10, time.February, 3, 4, 5, 6, 0, time.FixedZone("", 0))}, + {"0011-02-03 04:05:06.123 BC", time.Date(-10, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, + {"0011-02-03 04:05:06.123-07 BC", time.Date(-10, time.February, 3, 4, 5, 6, 123000000, + time.FixedZone("", -7*60*60))}, + {"0001-02-03 04:05:06.123", time.Date(1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, + {"0001-02-03 04:05:06.123 BC", time.Date(1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0)).AddDate(-1, 0, 0)}, + {"0001-02-03 04:05:06.123 BC", time.Date(0, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, + {"0002-02-03 04:05:06.123 BC", time.Date(0, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0)).AddDate(-1, 0, 0)}, + {"0002-02-03 04:05:06.123 BC", time.Date(-1, time.February, 3, 4, 5, 6, 123000000, time.FixedZone("", 0))}, + {"12345-02-03 04:05:06.1", time.Date(12345, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))}, + {"123456-02-03 04:05:06.1", time.Date(123456, time.February, 3, 4, 5, 6, 100000000, time.FixedZone("", 0))}, +} + +// Test that parsing the string results in the expected value. +func TestParseTs(t *testing.T) { + for i, tt := range timeTests { + val, err := ParseTimestamp(nil, tt.str) + if err != nil { + t.Errorf("%d: got error: %v", i, err) + } else if val.String() != tt.timeval.String() { + t.Errorf("%d: expected to parse %q into %q; got %q", + i, tt.str, tt.timeval, val) + } + } +} + +var timeErrorTests = []string{ + "BC", + " BC", + "2001", + "2001-2-03", + "2001-02-3", + "2001-02-03 ", + "2001-02-03 B", + "2001-02-03 04", + "2001-02-03 04:", + "2001-02-03 04:05", + "2001-02-03 04:05 B", + "2001-02-03 04:05 BC", + "2001-02-03 04:05:", + "2001-02-03 04:05:6", + "2001-02-03 04:05:06 B", + "2001-02-03 04:05:06BC", + "2001-02-03 04:05:06.123 B", +} + +// Test that parsing the string results in an error. +func TestParseTsErrors(t *testing.T) { + for i, tt := range timeErrorTests { + _, err := ParseTimestamp(nil, tt) + if err == nil { + t.Errorf("%d: expected an error from parsing: %v", i, tt) + } + } +} + +// Now test that sending the value into the database and parsing it back +// returns the same time.Time value. +func TestEncodeAndParseTs(t *testing.T) { + db, err := openTestConnConninfo("timezone='Etc/UTC'") + if err != nil { + t.Fatal(err) + } + defer db.Close() + + for i, tt := range timeTests { + var dbstr string + err = db.QueryRow("SELECT ($1::timestamptz)::text", tt.timeval).Scan(&dbstr) + if err != nil { + t.Errorf("%d: could not send value %q to the database: %s", i, tt.timeval, err) + continue + } + + val, err := ParseTimestamp(nil, dbstr) + if err != nil { + t.Errorf("%d: could not parse value %q: %s", i, dbstr, err) + continue + } + val = val.In(tt.timeval.Location()) + if val.String() != tt.timeval.String() { + t.Errorf("%d: expected to parse %q into %q; got %q", i, dbstr, tt.timeval, val) + } + } +} + +var formatTimeTests = []struct { + time time.Time + expected string +}{ + {time.Time{}, "0001-01-01 00:00:00Z"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "2001-02-03 04:05:06.123456789Z"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "2001-02-03 04:05:06.123456789+02:00"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "2001-02-03 04:05:06.123456789-06:00"}, + {time.Date(2001, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "2001-02-03 04:05:06-07:30:09"}, + + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z"}, + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00"}, + {time.Date(1, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00"}, + + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 0)), "0001-02-03 04:05:06.123456789Z BC"}, + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", 2*60*60)), "0001-02-03 04:05:06.123456789+02:00 BC"}, + {time.Date(0, time.February, 3, 4, 5, 6, 123456789, time.FixedZone("", -6*60*60)), "0001-02-03 04:05:06.123456789-06:00 BC"}, + + {time.Date(1, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09"}, + {time.Date(0, time.February, 3, 4, 5, 6, 0, time.FixedZone("", -(7*60*60+30*60+9))), "0001-02-03 04:05:06-07:30:09 BC"}, +} + +func TestFormatTs(t *testing.T) { + for i, tt := range formatTimeTests { + val := string(formatTs(tt.time)) + if val != tt.expected { + t.Errorf("%d: incorrect time format %q, want %q", i, val, tt.expected) + } + } +} + +func TestFormatTsBackend(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + var str string + err := db.QueryRow("SELECT '2001-02-03T04:05:06.007-08:09:10'::time::text").Scan(&str) + if err == nil { + t.Fatalf("PostgreSQL is accepting an ISO timestamp input for time") + } + + for i, tt := range formatTimeTests { + for _, typ := range []string{"date", "time", "timetz", "timestamp", "timestamptz"} { + err = db.QueryRow("SELECT $1::"+typ+"::text", tt.time).Scan(&str) + if err != nil { + t.Errorf("%d: incorrect time format for %v on the backend: %v", i, typ, err) + } + } + } +} + +func TestTimestampWithTimeZone(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + tx, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + // try several different locations, all included in Go's zoneinfo.zip + for _, locName := range []string{ + "UTC", + "America/Chicago", + "America/New_York", + "Australia/Darwin", + "Australia/Perth", + } { + loc, err := time.LoadLocation(locName) + if err != nil { + t.Logf("Could not load time zone %s - skipping", locName) + continue + } + + // Postgres timestamps have a resolution of 1 microsecond, so don't + // use the full range of the Nanosecond argument + refTime := time.Date(2012, 11, 6, 10, 23, 42, 123456000, loc) + + for _, pgTimeZone := range []string{"US/Eastern", "Australia/Darwin"} { + // Switch Postgres's timezone to test different output timestamp formats + _, err = tx.Exec(fmt.Sprintf("set time zone '%s'", pgTimeZone)) + if err != nil { + t.Fatal(err) + } + + var gotTime time.Time + row := tx.QueryRow("select $1::timestamp with time zone", refTime) + err = row.Scan(&gotTime) + if err != nil { + t.Fatal(err) + } + + if !refTime.Equal(gotTime) { + t.Errorf("timestamps not equal: %s != %s", refTime, gotTime) + } + + // check that the time zone is set correctly based on TimeZone + pgLoc, err := time.LoadLocation(pgTimeZone) + if err != nil { + t.Logf("Could not load time zone %s - skipping", pgLoc) + continue + } + translated := refTime.In(pgLoc) + if translated.String() != gotTime.String() { + t.Errorf("timestamps not equal: %s != %s", translated, gotTime) + } + } + } +} + +func TestTimestampWithOutTimezone(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + test := func(ts, pgts string) { + r, err := db.Query("SELECT $1::timestamp", pgts) + if err != nil { + t.Fatalf("Could not run query: %v", err) + } + + if !r.Next() { + t.Fatal("Expected at least one row") + } + + var result time.Time + err = r.Scan(&result) + if err != nil { + t.Fatalf("Did not expect error scanning row: %v", err) + } + + expected, err := time.Parse(time.RFC3339, ts) + if err != nil { + t.Fatalf("Could not parse test time literal: %v", err) + } + + if !result.Equal(expected) { + t.Fatalf("Expected time to match %v: got mismatch %v", + expected, result) + } + + if r.Next() { + t.Fatal("Expected only one row") + } + } + + test("2000-01-01T00:00:00Z", "2000-01-01T00:00:00") + + // Test higher precision time + test("2013-01-04T20:14:58.80033Z", "2013-01-04 20:14:58.80033") +} + +func TestInfinityTimestamp(t *testing.T) { + db := openTestConn(t) + defer db.Close() + var err error + var resultT time.Time + + expectedErrorStrRegexp := regexp.MustCompile( + `^sql: Scan error on column index 0(, name "timestamp(tz)?"|): unsupported`) + + type testCases []struct { + Query string + Param string + ExpectedErrorStrRegexp *regexp.Regexp + ExpectedVal interface{} + } + tc := testCases{ + {"SELECT $1::timestamp", "-infinity", expectedErrorStrRegexp, "-infinity"}, + {"SELECT $1::timestamptz", "-infinity", expectedErrorStrRegexp, "-infinity"}, + {"SELECT $1::timestamp", "infinity", expectedErrorStrRegexp, "infinity"}, + {"SELECT $1::timestamptz", "infinity", expectedErrorStrRegexp, "infinity"}, + } + // try to assert []byte to time.Time + for _, q := range tc { + err = db.QueryRow(q.Query, q.Param).Scan(&resultT) + if !q.ExpectedErrorStrRegexp.MatchString(err.Error()) { + t.Errorf("Scanning -/+infinity, expected error to match regexp %q, got %q", + q.ExpectedErrorStrRegexp, err) + } + } + // yield []byte + for _, q := range tc { + var resultI interface{} + err = db.QueryRow(q.Query, q.Param).Scan(&resultI) + if err != nil { + t.Errorf("Scanning -/+infinity, expected no error, got %q", err) + } + result, ok := resultI.([]byte) + if !ok { + t.Errorf("Scanning -/+infinity, expected []byte, got %#v", resultI) + } + if string(result) != q.ExpectedVal { + t.Errorf("Scanning -/+infinity, expected %q, got %q", q.ExpectedVal, result) + } + } + + y1500 := time.Date(1500, time.January, 1, 0, 0, 0, 0, time.UTC) + y2500 := time.Date(2500, time.January, 1, 0, 0, 0, 0, time.UTC) + EnableInfinityTs(y1500, y2500) + + err = db.QueryRow("SELECT $1::timestamp", "infinity").Scan(&resultT) + if err != nil { + t.Errorf("Scanning infinity, expected no error, got %q", err) + } + if !resultT.Equal(y2500) { + t.Errorf("Scanning infinity, expected %q, got %q", y2500, resultT) + } + + err = db.QueryRow("SELECT $1::timestamptz", "infinity").Scan(&resultT) + if err != nil { + t.Errorf("Scanning infinity, expected no error, got %q", err) + } + if !resultT.Equal(y2500) { + t.Errorf("Scanning Infinity, expected time %q, got %q", y2500, resultT.String()) + } + + err = db.QueryRow("SELECT $1::timestamp", "-infinity").Scan(&resultT) + if err != nil { + t.Errorf("Scanning -infinity, expected no error, got %q", err) + } + if !resultT.Equal(y1500) { + t.Errorf("Scanning -infinity, expected time %q, got %q", y1500, resultT.String()) + } + + err = db.QueryRow("SELECT $1::timestamptz", "-infinity").Scan(&resultT) + if err != nil { + t.Errorf("Scanning -infinity, expected no error, got %q", err) + } + if !resultT.Equal(y1500) { + t.Errorf("Scanning -infinity, expected time %q, got %q", y1500, resultT.String()) + } + + ym1500 := time.Date(-1500, time.January, 1, 0, 0, 0, 0, time.UTC) + y11500 := time.Date(11500, time.January, 1, 0, 0, 0, 0, time.UTC) + var s string + err = db.QueryRow("SELECT $1::timestamp::text", ym1500).Scan(&s) + if err != nil { + t.Errorf("Encoding -infinity, expected no error, got %q", err) + } + if s != "-infinity" { + t.Errorf("Encoding -infinity, expected %q, got %q", "-infinity", s) + } + err = db.QueryRow("SELECT $1::timestamptz::text", ym1500).Scan(&s) + if err != nil { + t.Errorf("Encoding -infinity, expected no error, got %q", err) + } + if s != "-infinity" { + t.Errorf("Encoding -infinity, expected %q, got %q", "-infinity", s) + } + + err = db.QueryRow("SELECT $1::timestamp::text", y11500).Scan(&s) + if err != nil { + t.Errorf("Encoding infinity, expected no error, got %q", err) + } + if s != "infinity" { + t.Errorf("Encoding infinity, expected %q, got %q", "infinity", s) + } + err = db.QueryRow("SELECT $1::timestamptz::text", y11500).Scan(&s) + if err != nil { + t.Errorf("Encoding infinity, expected no error, got %q", err) + } + if s != "infinity" { + t.Errorf("Encoding infinity, expected %q, got %q", "infinity", s) + } + + disableInfinityTs() + + var panicErrorString string + func() { + defer func() { + panicErrorString, _ = recover().(string) + }() + EnableInfinityTs(y2500, y1500) + }() + if panicErrorString != infinityTsNegativeMustBeSmaller { + t.Errorf("Expected error, %q, got %q", infinityTsNegativeMustBeSmaller, panicErrorString) + } +} + +func TestStringWithNul(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + hello0world := string("hello\x00world") + _, err := db.Query("SELECT $1::text", &hello0world) + if err == nil { + t.Fatal("Postgres accepts a string with nul in it; " + + "injection attacks may be plausible") + } +} + +func TestByteSliceToText(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + b := []byte("hello world") + row := db.QueryRow("SELECT $1::text", b) + + var result []byte + err := row.Scan(&result) + if err != nil { + t.Fatal(err) + } + + if string(result) != string(b) { + t.Fatalf("expected %v but got %v", b, result) + } +} + +func TestStringToBytea(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + b := "hello world" + row := db.QueryRow("SELECT $1::bytea", b) + + var result []byte + err := row.Scan(&result) + if err != nil { + t.Fatal(err) + } + + if !bytes.Equal(result, []byte(b)) { + t.Fatalf("expected %v but got %v", b, result) + } +} + +func TestTextByteSliceToUUID(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + b := []byte("a0eebc99-9c0b-4ef8-bb6d-6bb9bd380a11") + row := db.QueryRow("SELECT $1::uuid", b) + + var result string + err := row.Scan(&result) + if forceBinaryParameters() { + pqErr := err.(*Error) + if pqErr == nil { + t.Errorf("Expected to get error") + } else if pqErr.Code != "22P03" { + t.Fatalf("Expected to get invalid binary encoding error (22P03), got %s", pqErr.Code) + } + } else { + if err != nil { + t.Fatal(err) + } + + if result != string(b) { + t.Fatalf("expected %v but got %v", b, result) + } + } +} + +func TestBinaryByteSlicetoUUID(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + b := []byte{'\xa0', '\xee', '\xbc', '\x99', + '\x9c', '\x0b', + '\x4e', '\xf8', + '\xbb', '\x00', '\x6b', + '\xb9', '\xbd', '\x38', '\x0a', '\x11'} + row := db.QueryRow("SELECT $1::uuid", b) + + var result string + err := row.Scan(&result) + if forceBinaryParameters() { + if err != nil { + t.Fatal(err) + } + + if result != string("a0eebc99-9c0b-4ef8-bb00-6bb9bd380a11") { + t.Fatalf("expected %v but got %v", b, result) + } + } else { + pqErr := err.(*Error) + if pqErr == nil { + t.Errorf("Expected to get error") + } else if pqErr.Code != "22021" { + t.Fatalf("Expected to get invalid byte sequence for encoding error (22021), got %s", pqErr.Code) + } + } +} + +func TestStringToUUID(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + s := "a0eebc99-9c0b-4ef8-bb00-6bb9bd380a11" + row := db.QueryRow("SELECT $1::uuid", s) + + var result string + err := row.Scan(&result) + if err != nil { + t.Fatal(err) + } + + if result != s { + t.Fatalf("expected %v but got %v", s, result) + } +} + +func TestTextByteSliceToInt(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + expected := 12345678 + b := []byte(fmt.Sprintf("%d", expected)) + row := db.QueryRow("SELECT $1::int", b) + + var result int + err := row.Scan(&result) + if forceBinaryParameters() { + pqErr := err.(*Error) + if pqErr == nil { + t.Errorf("Expected to get error") + } else if pqErr.Code != "22P03" { + t.Fatalf("Expected to get invalid binary encoding error (22P03), got %s", pqErr.Code) + } + } else { + if err != nil { + t.Fatal(err) + } + if result != expected { + t.Fatalf("expected %v but got %v", expected, result) + } + } +} + +func TestBinaryByteSliceToInt(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + expected := 12345678 + b := []byte{'\x00', '\xbc', '\x61', '\x4e'} + row := db.QueryRow("SELECT $1::int", b) + + var result int + err := row.Scan(&result) + if forceBinaryParameters() { + if err != nil { + t.Fatal(err) + } + if result != expected { + t.Fatalf("expected %v but got %v", expected, result) + } + } else { + pqErr := err.(*Error) + if pqErr == nil { + t.Errorf("Expected to get error") + } else if pqErr.Code != "22021" { + t.Fatalf("Expected to get invalid byte sequence for encoding error (22021), got %s", pqErr.Code) + } + } +} + +func TestTextDecodeIntoString(t *testing.T) { + input := []byte("hello world") + want := string(input) + for _, typ := range []oid.Oid{oid.T_char, oid.T_varchar, oid.T_text} { + got := decode(¶meterStatus{}, input, typ, formatText) + if got != want { + t.Errorf("invalid string decoding output for %T(%+v), got %v but expected %v", typ, typ, got, want) + } + } +} + +func TestByteaOutputFormatEncoding(t *testing.T) { + input := []byte("\\x\x00\x01\x02\xFF\xFEabcdefg0123") + want := []byte("\\x5c78000102fffe6162636465666730313233") + got := encode(¶meterStatus{serverVersion: 90000}, input, oid.T_bytea) + if !bytes.Equal(want, got) { + t.Errorf("invalid hex bytea output, got %v but expected %v", got, want) + } + + want = []byte("\\\\x\\000\\001\\002\\377\\376abcdefg0123") + got = encode(¶meterStatus{serverVersion: 84000}, input, oid.T_bytea) + if !bytes.Equal(want, got) { + t.Errorf("invalid escape bytea output, got %v but expected %v", got, want) + } +} + +func TestByteaOutputFormats(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + if getServerVersion(t, db) < 90000 { + // skip + return + } + + testByteaOutputFormat := func(f string, usePrepared bool) { + expectedData := []byte("\x5c\x78\x00\xff\x61\x62\x63\x01\x08") + sqlQuery := "SELECT decode('5c7800ff6162630108', 'hex')" + + var data []byte + + // use a txn to avoid relying on getting the same connection + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + defer txn.Rollback() + + _, err = txn.Exec("SET LOCAL bytea_output TO " + f) + if err != nil { + t.Fatal(err) + } + var rows *sql.Rows + var stmt *sql.Stmt + if usePrepared { + stmt, err = txn.Prepare(sqlQuery) + if err != nil { + t.Fatal(err) + } + rows, err = stmt.Query() + } else { + // use Query; QueryRow would hide the actual error + rows, err = txn.Query(sqlQuery) + } + if err != nil { + t.Fatal(err) + } + if !rows.Next() { + if rows.Err() != nil { + t.Fatal(rows.Err()) + } + t.Fatal("shouldn't happen") + } + err = rows.Scan(&data) + if err != nil { + t.Fatal(err) + } + err = rows.Close() + if err != nil { + t.Fatal(err) + } + if stmt != nil { + err = stmt.Close() + if err != nil { + t.Fatal(err) + } + } + if !bytes.Equal(data, expectedData) { + t.Errorf("unexpected bytea value %v for format %s; expected %v", data, f, expectedData) + } + } + + testByteaOutputFormat("hex", false) + testByteaOutputFormat("escape", false) + testByteaOutputFormat("hex", true) + testByteaOutputFormat("escape", true) +} + +func TestAppendEncodedText(t *testing.T) { + var buf []byte + + buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, int64(10)) + buf = append(buf, '\t') + buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, 42.0000000001) + buf = append(buf, '\t') + buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, "hello\tworld") + buf = append(buf, '\t') + buf = appendEncodedText(¶meterStatus{serverVersion: 90000}, buf, []byte{0, 128, 255}) + + if string(buf) != "10\t42.0000000001\thello\\tworld\t\\\\x0080ff" { + t.Fatal(string(buf)) + } +} + +func TestAppendEscapedText(t *testing.T) { + if esc := appendEscapedText(nil, "hallo\tescape"); string(esc) != "hallo\\tescape" { + t.Fatal(string(esc)) + } + if esc := appendEscapedText(nil, "hallo\\tescape\n"); string(esc) != "hallo\\\\tescape\\n" { + t.Fatal(string(esc)) + } + if esc := appendEscapedText(nil, "\n\r\t\f"); string(esc) != "\\n\\r\\t\f" { + t.Fatal(string(esc)) + } +} + +func TestAppendEscapedTextExistingBuffer(t *testing.T) { + buf := []byte("123\t") + if esc := appendEscapedText(buf, "hallo\tescape"); string(esc) != "123\thallo\\tescape" { + t.Fatal(string(esc)) + } + buf = []byte("123\t") + if esc := appendEscapedText(buf, "hallo\\tescape\n"); string(esc) != "123\thallo\\\\tescape\\n" { + t.Fatal(string(esc)) + } + buf = []byte("123\t") + if esc := appendEscapedText(buf, "\n\r\t\f"); string(esc) != "123\t\\n\\r\\t\f" { + t.Fatal(string(esc)) + } +} + +func BenchmarkAppendEscapedText(b *testing.B) { + longString := "" + for i := 0; i < 100; i++ { + longString += "123456789\n" + } + for i := 0; i < b.N; i++ { + appendEscapedText(nil, longString) + } +} + +func BenchmarkAppendEscapedTextNoEscape(b *testing.B) { + longString := "" + for i := 0; i < 100; i++ { + longString += "1234567890" + } + for i := 0; i < b.N; i++ { + appendEscapedText(nil, longString) + } +} diff --git a/vendor/github.com/lib/pq/example/listen/doc.go b/vendor/github.com/lib/pq/example/listen/doc.go new file mode 100644 index 00000000..91e2ddba --- /dev/null +++ b/vendor/github.com/lib/pq/example/listen/doc.go @@ -0,0 +1,98 @@ +/* + +Package listen is a self-contained Go program which uses the LISTEN / NOTIFY +mechanism to avoid polling the database while waiting for more work to arrive. + + // + // You can see the program in action by defining a function similar to + // the following: + // + // CREATE OR REPLACE FUNCTION public.get_work() + // RETURNS bigint + // LANGUAGE sql + // AS $$ + // SELECT CASE WHEN random() >= 0.2 THEN int8 '1' END + // $$ + // ; + + package main + + import ( + "database/sql" + "fmt" + "time" + + "github.com/lib/pq" + ) + + func doWork(db *sql.DB, work int64) { + // work here + } + + func getWork(db *sql.DB) { + for { + // get work from the database here + var work sql.NullInt64 + err := db.QueryRow("SELECT get_work()").Scan(&work) + if err != nil { + fmt.Println("call to get_work() failed: ", err) + time.Sleep(10 * time.Second) + continue + } + if !work.Valid { + // no more work to do + fmt.Println("ran out of work") + return + } + + fmt.Println("starting work on ", work.Int64) + go doWork(db, work.Int64) + } + } + + func waitForNotification(l *pq.Listener) { + select { + case <-l.Notify: + fmt.Println("received notification, new work available") + case <-time.After(90 * time.Second): + go l.Ping() + // Check if there's more work available, just in case it takes + // a while for the Listener to notice connection loss and + // reconnect. + fmt.Println("received no work for 90 seconds, checking for new work") + } + } + + func main() { + var conninfo string = "" + + db, err := sql.Open("postgres", conninfo) + if err != nil { + panic(err) + } + + reportProblem := func(ev pq.ListenerEventType, err error) { + if err != nil { + fmt.Println(err.Error()) + } + } + + minReconn := 10 * time.Second + maxReconn := time.Minute + listener := pq.NewListener(conninfo, minReconn, maxReconn, reportProblem) + err = listener.Listen("getwork") + if err != nil { + panic(err) + } + + fmt.Println("entering main loop") + for { + // process all available work before waiting for notifications + getWork(db) + waitForNotification(listener) + } + } + + +*/ +package listen diff --git a/vendor/github.com/lib/pq/go18_test.go b/vendor/github.com/lib/pq/go18_test.go new file mode 100644 index 00000000..72cd71fe --- /dev/null +++ b/vendor/github.com/lib/pq/go18_test.go @@ -0,0 +1,319 @@ +package pq + +import ( + "context" + "database/sql" + "runtime" + "strings" + "testing" + "time" +) + +func TestMultipleSimpleQuery(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query("select 1; set time zone default; select 2; select 3") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + var i int + for rows.Next() { + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 1 { + t.Fatalf("expected 1, got %d", i) + } + } + if !rows.NextResultSet() { + t.Fatal("expected more result sets", rows.Err()) + } + for rows.Next() { + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 2 { + t.Fatalf("expected 2, got %d", i) + } + } + + // Make sure that if we ignore a result we can still query. + + rows, err = db.Query("select 4; select 5") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + for rows.Next() { + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 4 { + t.Fatalf("expected 4, got %d", i) + } + } + if !rows.NextResultSet() { + t.Fatal("expected more result sets", rows.Err()) + } + for rows.Next() { + if err := rows.Scan(&i); err != nil { + t.Fatal(err) + } + if i != 5 { + t.Fatalf("expected 5, got %d", i) + } + } + if rows.NextResultSet() { + t.Fatal("unexpected result set") + } +} + +const contextRaceIterations = 100 + +func TestContextCancelExec(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(time.Millisecond*10, cancel).Stop() + + // Not canceled until after the exec has started. + if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil { + t.Fatal("expected error") + } else if err.Error() != "pq: canceling statement due to user request" { + t.Fatalf("unexpected error: %s", err) + } + + // Context is already canceled, so error should come before execution. + if _, err := db.ExecContext(ctx, "select pg_sleep(1)"); err == nil { + t.Fatal("expected error") + } else if err.Error() != "context canceled" { + t.Fatalf("unexpected error: %s", err) + } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if _, err := db.ExecContext(ctx, "select 1"); err != nil { + t.Fatal(err) + } + }() + + if _, err := db.Exec("select 1"); err != nil { + t.Fatal(err) + } + } +} + +func TestContextCancelQuery(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.QueryContext has begun. + defer time.AfterFunc(time.Millisecond*10, cancel).Stop() + + // Not canceled until after the exec has started. + if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil { + t.Fatal("expected error") + } else if err.Error() != "pq: canceling statement due to user request" { + t.Fatalf("unexpected error: %s", err) + } + + // Context is already canceled, so error should come before execution. + if _, err := db.QueryContext(ctx, "select pg_sleep(1)"); err == nil { + t.Fatal("expected error") + } else if err.Error() != "context canceled" { + t.Fatalf("unexpected error: %s", err) + } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + rows, err := db.QueryContext(ctx, "select 1") + cancel() + if err != nil { + t.Fatal(err) + } else if err := rows.Close(); err != nil { + t.Fatal(err) + } + }() + + if rows, err := db.Query("select 1"); err != nil { + t.Fatal(err) + } else if err := rows.Close(); err != nil { + t.Fatal(err) + } + } +} + +// TestIssue617 tests that a failed query in QueryContext doesn't lead to a +// goroutine leak. +func TestIssue617(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + const N = 10 + + numGoroutineStart := runtime.NumGoroutine() + for i := 0; i < N; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + _, err := db.QueryContext(ctx, `SELECT * FROM DOESNOTEXIST`) + pqErr, _ := err.(*Error) + // Expecting "pq: relation \"doesnotexist\" does not exist" error. + if err == nil || pqErr == nil || pqErr.Code != "42P01" { + t.Fatalf("expected undefined table error, got %v", err) + } + }() + } + numGoroutineFinish := runtime.NumGoroutine() + + // We use N/2 and not N because the GC and other actors may increase or + // decrease the number of goroutines. + if numGoroutineFinish-numGoroutineStart >= N/2 { + t.Errorf("goroutine leak detected, was %d, now %d", numGoroutineStart, numGoroutineFinish) + } +} + +func TestContextCancelBegin(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + ctx, cancel := context.WithCancel(context.Background()) + tx, err := db.BeginTx(ctx, nil) + if err != nil { + t.Fatal(err) + } + + // Delay execution for just a bit until tx.Exec has begun. + defer time.AfterFunc(time.Millisecond*10, cancel).Stop() + + // Not canceled until after the exec has started. + if _, err := tx.Exec("select pg_sleep(1)"); err == nil { + t.Fatal("expected error") + } else if err.Error() != "pq: canceling statement due to user request" { + t.Fatalf("unexpected error: %s", err) + } + + // Transaction is canceled, so expect an error. + if _, err := tx.Query("select pg_sleep(1)"); err == nil { + t.Fatal("expected error") + } else if err != sql.ErrTxDone { + t.Fatalf("unexpected error: %s", err) + } + + // Context is canceled, so cannot begin a transaction. + if _, err := db.BeginTx(ctx, nil); err == nil { + t.Fatal("expected error") + } else if err.Error() != "context canceled" { + t.Fatalf("unexpected error: %s", err) + } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + tx, err := db.BeginTx(ctx, nil) + cancel() + if err != nil { + t.Fatal(err) + } else if err := tx.Rollback(); err != nil && + err.Error() != "pq: canceling statement due to user request" && + err != sql.ErrTxDone { + t.Fatal(err) + } + }() + + if tx, err := db.Begin(); err != nil { + t.Fatal(err) + } else if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + } +} + +func TestTxOptions(t *testing.T) { + db := openTestConn(t) + defer db.Close() + ctx := context.Background() + + tests := []struct { + level sql.IsolationLevel + isolation string + }{ + { + level: sql.LevelDefault, + isolation: "", + }, + { + level: sql.LevelReadUncommitted, + isolation: "read uncommitted", + }, + { + level: sql.LevelReadCommitted, + isolation: "read committed", + }, + { + level: sql.LevelRepeatableRead, + isolation: "repeatable read", + }, + { + level: sql.LevelSerializable, + isolation: "serializable", + }, + } + + for _, test := range tests { + for _, ro := range []bool{true, false} { + tx, err := db.BeginTx(ctx, &sql.TxOptions{ + Isolation: test.level, + ReadOnly: ro, + }) + if err != nil { + t.Fatal(err) + } + + var isolation string + err = tx.QueryRow("select current_setting('transaction_isolation')").Scan(&isolation) + if err != nil { + t.Fatal(err) + } + + if test.isolation != "" && isolation != test.isolation { + t.Errorf("wrong isolation level: %s != %s", isolation, test.isolation) + } + + var isRO string + err = tx.QueryRow("select current_setting('transaction_read_only')").Scan(&isRO) + if err != nil { + t.Fatal(err) + } + + if ro != (isRO == "on") { + t.Errorf("read/[write,only] not set: %t != %s for level %s", + ro, isRO, test.isolation) + } + + tx.Rollback() + } + } + + _, err := db.BeginTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelLinearizable, + }) + if err == nil { + t.Fatal("expected LevelLinearizable to fail") + } + if !strings.Contains(err.Error(), "isolation level not supported") { + t.Errorf("Expected error to mention isolation level, got %q", err) + } +} diff --git a/vendor/github.com/lib/pq/go19_test.go b/vendor/github.com/lib/pq/go19_test.go new file mode 100644 index 00000000..1949249d --- /dev/null +++ b/vendor/github.com/lib/pq/go19_test.go @@ -0,0 +1,69 @@ +// +build go1.9 + +package pq + +import ( + "context" + "database/sql" + "database/sql/driver" + "reflect" + "testing" +) + +func TestPing(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + db := openTestConn(t) + defer db.Close() + + if _, ok := reflect.TypeOf(db).MethodByName("Conn"); !ok { + t.Skipf("Conn method undefined on type %T, skipping test (requires at least go1.9)", db) + } + + if err := db.PingContext(ctx); err != nil { + t.Fatal("expected Ping to succeed") + } + defer cancel() + + // grab a connection + conn, err := db.Conn(ctx) + if err != nil { + t.Fatal(err) + } + + // start a transaction and read backend pid of our connection + tx, err := conn.BeginTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelDefault, + ReadOnly: true, + }) + if err != nil { + t.Fatal(err) + } + + rows, err := tx.Query("SELECT pg_backend_pid()") + if err != nil { + t.Fatal(err) + } + defer rows.Close() + + // read the pid from result + var pid int + for rows.Next() { + if err := rows.Scan(&pid); err != nil { + t.Fatal(err) + } + } + if rows.Err() != nil { + t.Fatal(err) + } + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + // kill the process which handles our connection and test if the ping fails + if _, err := db.Exec("SELECT pg_terminate_backend($1)", pid); err != nil { + t.Fatal(err) + } + if err := conn.PingContext(ctx); err != driver.ErrBadConn { + t.Fatalf("expected error %s, instead got %s", driver.ErrBadConn, err) + } +} diff --git a/vendor/github.com/lib/pq/hstore/hstore.go b/vendor/github.com/lib/pq/hstore/hstore.go new file mode 100644 index 00000000..f1470db1 --- /dev/null +++ b/vendor/github.com/lib/pq/hstore/hstore.go @@ -0,0 +1,118 @@ +package hstore + +import ( + "database/sql" + "database/sql/driver" + "strings" +) + +// Hstore is a wrapper for transferring Hstore values back and forth easily. +type Hstore struct { + Map map[string]sql.NullString +} + +// escapes and quotes hstore keys/values +// s should be a sql.NullString or string +func hQuote(s interface{}) string { + var str string + switch v := s.(type) { + case sql.NullString: + if !v.Valid { + return "NULL" + } + str = v.String + case string: + str = v + default: + panic("not a string or sql.NullString") + } + + str = strings.Replace(str, "\\", "\\\\", -1) + return `"` + strings.Replace(str, "\"", "\\\"", -1) + `"` +} + +// Scan implements the Scanner interface. +// +// Note h.Map is reallocated before the scan to clear existing values. If the +// hstore column's database value is NULL, then h.Map is set to nil instead. +func (h *Hstore) Scan(value interface{}) error { + if value == nil { + h.Map = nil + return nil + } + h.Map = make(map[string]sql.NullString) + var b byte + pair := [][]byte{{}, {}} + pi := 0 + inQuote := false + didQuote := false + sawSlash := false + bindex := 0 + for bindex, b = range value.([]byte) { + if sawSlash { + pair[pi] = append(pair[pi], b) + sawSlash = false + continue + } + + switch b { + case '\\': + sawSlash = true + continue + case '"': + inQuote = !inQuote + if !didQuote { + didQuote = true + } + continue + default: + if !inQuote { + switch b { + case ' ', '\t', '\n', '\r': + continue + case '=': + continue + case '>': + pi = 1 + didQuote = false + continue + case ',': + s := string(pair[1]) + if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { + h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} + } else { + h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} + } + pair[0] = []byte{} + pair[1] = []byte{} + pi = 0 + continue + } + } + } + pair[pi] = append(pair[pi], b) + } + if bindex > 0 { + s := string(pair[1]) + if !didQuote && len(s) == 4 && strings.ToLower(s) == "null" { + h.Map[string(pair[0])] = sql.NullString{String: "", Valid: false} + } else { + h.Map[string(pair[0])] = sql.NullString{String: string(pair[1]), Valid: true} + } + } + return nil +} + +// Value implements the driver Valuer interface. Note if h.Map is nil, the +// database column value will be set to NULL. +func (h Hstore) Value() (driver.Value, error) { + if h.Map == nil { + return nil, nil + } + parts := []string{} + for key, val := range h.Map { + thispart := hQuote(key) + "=>" + hQuote(val) + parts = append(parts, thispart) + } + return []byte(strings.Join(parts, ",")), nil +} diff --git a/vendor/github.com/lib/pq/hstore/hstore_test.go b/vendor/github.com/lib/pq/hstore/hstore_test.go new file mode 100644 index 00000000..1c9f2bd4 --- /dev/null +++ b/vendor/github.com/lib/pq/hstore/hstore_test.go @@ -0,0 +1,148 @@ +package hstore + +import ( + "database/sql" + "os" + "testing" + + _ "github.com/lib/pq" +) + +type Fatalistic interface { + Fatal(args ...interface{}) +} + +func openTestConn(t Fatalistic) *sql.DB { + datname := os.Getenv("PGDATABASE") + sslmode := os.Getenv("PGSSLMODE") + + if datname == "" { + os.Setenv("PGDATABASE", "pqgotest") + } + + if sslmode == "" { + os.Setenv("PGSSLMODE", "disable") + } + + conn, err := sql.Open("postgres", "") + if err != nil { + t.Fatal(err) + } + + return conn +} + +func TestHstore(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + // quitely create hstore if it doesn't exist + _, err := db.Exec("CREATE EXTENSION IF NOT EXISTS hstore") + if err != nil { + t.Skipf("Skipping hstore tests - hstore extension create failed: %s", err.Error()) + } + + hs := Hstore{} + + // test for null-valued hstores + err = db.QueryRow("SELECT NULL::hstore").Scan(&hs) + if err != nil { + t.Fatal(err) + } + if hs.Map != nil { + t.Fatalf("expected null map") + } + + err = db.QueryRow("SELECT $1::hstore", hs).Scan(&hs) + if err != nil { + t.Fatalf("re-query null map failed: %s", err.Error()) + } + if hs.Map != nil { + t.Fatalf("expected null map") + } + + // test for empty hstores + err = db.QueryRow("SELECT ''::hstore").Scan(&hs) + if err != nil { + t.Fatal(err) + } + if hs.Map == nil { + t.Fatalf("expected empty map, got null map") + } + if len(hs.Map) != 0 { + t.Fatalf("expected empty map, got len(map)=%d", len(hs.Map)) + } + + err = db.QueryRow("SELECT $1::hstore", hs).Scan(&hs) + if err != nil { + t.Fatalf("re-query empty map failed: %s", err.Error()) + } + if hs.Map == nil { + t.Fatalf("expected empty map, got null map") + } + if len(hs.Map) != 0 { + t.Fatalf("expected empty map, got len(map)=%d", len(hs.Map)) + } + + // a few example maps to test out + hsOnePair := Hstore{ + Map: map[string]sql.NullString{ + "key1": {String: "value1", Valid: true}, + }, + } + + hsThreePairs := Hstore{ + Map: map[string]sql.NullString{ + "key1": {String: "value1", Valid: true}, + "key2": {String: "value2", Valid: true}, + "key3": {String: "value3", Valid: true}, + }, + } + + hsSmorgasbord := Hstore{ + Map: map[string]sql.NullString{ + "nullstring": {String: "NULL", Valid: true}, + "actuallynull": {String: "", Valid: false}, + "NULL": {String: "NULL string key", Valid: true}, + "withbracket": {String: "value>42", Valid: true}, + "withequal": {String: "value=42", Valid: true}, + `"withquotes1"`: {String: `this "should" be fine`, Valid: true}, + `"withquotes"2"`: {String: `this "should\" also be fine`, Valid: true}, + "embedded1": {String: "value1=>x1", Valid: true}, + "embedded2": {String: `"value2"=>x2`, Valid: true}, + "withnewlines": {String: "\n\nvalue\t=>2", Valid: true}, + "<>": {String: `this, "should,\" also, => be fine`, Valid: true}, + }, + } + + // test encoding in query params, then decoding during Scan + testBidirectional := func(h Hstore) { + err = db.QueryRow("SELECT $1::hstore", h).Scan(&hs) + if err != nil { + t.Fatalf("re-query %d-pair map failed: %s", len(h.Map), err.Error()) + } + if hs.Map == nil { + t.Fatalf("expected %d-pair map, got null map", len(h.Map)) + } + if len(hs.Map) != len(h.Map) { + t.Fatalf("expected %d-pair map, got len(map)=%d", len(h.Map), len(hs.Map)) + } + + for key, val := range hs.Map { + otherval, found := h.Map[key] + if !found { + t.Fatalf(" key '%v' not found in %d-pair map", key, len(h.Map)) + } + if otherval.Valid != val.Valid { + t.Fatalf(" value %v <> %v in %d-pair map", otherval, val, len(h.Map)) + } + if otherval.String != val.String { + t.Fatalf(" value '%v' <> '%v' in %d-pair map", otherval.String, val.String, len(h.Map)) + } + } + } + + testBidirectional(hsOnePair) + testBidirectional(hsThreePairs) + testBidirectional(hsSmorgasbord) +} diff --git a/vendor/github.com/lib/pq/issues_test.go b/vendor/github.com/lib/pq/issues_test.go new file mode 100644 index 00000000..3a330a0a --- /dev/null +++ b/vendor/github.com/lib/pq/issues_test.go @@ -0,0 +1,26 @@ +package pq + +import "testing" + +func TestIssue494(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + query := `CREATE TEMP TABLE t (i INT PRIMARY KEY)` + if _, err := db.Exec(query); err != nil { + t.Fatal(err) + } + + txn, err := db.Begin() + if err != nil { + t.Fatal(err) + } + + if _, err := txn.Prepare(CopyIn("t", "i")); err != nil { + t.Fatal(err) + } + + if _, err := txn.Query("SELECT 1"); err == nil { + t.Fatal("expected error") + } +} diff --git a/vendor/github.com/lib/pq/notify_test.go b/vendor/github.com/lib/pq/notify_test.go new file mode 100644 index 00000000..075666dd --- /dev/null +++ b/vendor/github.com/lib/pq/notify_test.go @@ -0,0 +1,570 @@ +package pq + +import ( + "errors" + "fmt" + "io" + "os" + "runtime" + "sync" + "testing" + "time" +) + +var errNilNotification = errors.New("nil notification") + +func expectNotification(t *testing.T, ch <-chan *Notification, relname string, extra string) error { + select { + case n := <-ch: + if n == nil { + return errNilNotification + } + if n.Channel != relname || n.Extra != extra { + return fmt.Errorf("unexpected notification %v", n) + } + return nil + case <-time.After(1500 * time.Millisecond): + return fmt.Errorf("timeout") + } +} + +func expectNoNotification(t *testing.T, ch <-chan *Notification) error { + select { + case n := <-ch: + return fmt.Errorf("unexpected notification %v", n) + case <-time.After(100 * time.Millisecond): + return nil + } +} + +func expectEvent(t *testing.T, eventch <-chan ListenerEventType, et ListenerEventType) error { + select { + case e := <-eventch: + if e != et { + return fmt.Errorf("unexpected event %v", e) + } + return nil + case <-time.After(1500 * time.Millisecond): + panic("expectEvent timeout") + } +} + +func expectNoEvent(t *testing.T, eventch <-chan ListenerEventType) error { + select { + case e := <-eventch: + return fmt.Errorf("unexpected event %v", e) + case <-time.After(100 * time.Millisecond): + return nil + } +} + +func newTestListenerConn(t *testing.T) (*ListenerConn, <-chan *Notification) { + datname := os.Getenv("PGDATABASE") + sslmode := os.Getenv("PGSSLMODE") + + if datname == "" { + os.Setenv("PGDATABASE", "pqgotest") + } + + if sslmode == "" { + os.Setenv("PGSSLMODE", "disable") + } + + notificationChan := make(chan *Notification) + l, err := NewListenerConn("", notificationChan) + if err != nil { + t.Fatal(err) + } + + return l, notificationChan +} + +func TestNewListenerConn(t *testing.T) { + l, _ := newTestListenerConn(t) + + defer l.Close() +} + +func TestConnListen(t *testing.T) { + l, channel := newTestListenerConn(t) + + defer l.Close() + + db := openTestConn(t) + defer db.Close() + + ok, err := l.Listen("notify_test") + if !ok || err != nil { + t.Fatal(err) + } + + _, err = db.Exec("NOTIFY notify_test") + if err != nil { + t.Fatal(err) + } + + err = expectNotification(t, channel, "notify_test", "") + if err != nil { + t.Fatal(err) + } +} + +func TestConnUnlisten(t *testing.T) { + l, channel := newTestListenerConn(t) + + defer l.Close() + + db := openTestConn(t) + defer db.Close() + + ok, err := l.Listen("notify_test") + if !ok || err != nil { + t.Fatal(err) + } + + _, err = db.Exec("NOTIFY notify_test") + if err != nil { + t.Fatal(err) + } + + err = expectNotification(t, channel, "notify_test", "") + if err != nil { + t.Fatal(err) + } + + ok, err = l.Unlisten("notify_test") + if !ok || err != nil { + t.Fatal(err) + } + + _, err = db.Exec("NOTIFY notify_test") + if err != nil { + t.Fatal(err) + } + + err = expectNoNotification(t, channel) + if err != nil { + t.Fatal(err) + } +} + +func TestConnUnlistenAll(t *testing.T) { + l, channel := newTestListenerConn(t) + + defer l.Close() + + db := openTestConn(t) + defer db.Close() + + ok, err := l.Listen("notify_test") + if !ok || err != nil { + t.Fatal(err) + } + + _, err = db.Exec("NOTIFY notify_test") + if err != nil { + t.Fatal(err) + } + + err = expectNotification(t, channel, "notify_test", "") + if err != nil { + t.Fatal(err) + } + + ok, err = l.UnlistenAll() + if !ok || err != nil { + t.Fatal(err) + } + + _, err = db.Exec("NOTIFY notify_test") + if err != nil { + t.Fatal(err) + } + + err = expectNoNotification(t, channel) + if err != nil { + t.Fatal(err) + } +} + +func TestConnClose(t *testing.T) { + l, _ := newTestListenerConn(t) + defer l.Close() + + err := l.Close() + if err != nil { + t.Fatal(err) + } + err = l.Close() + if err != errListenerConnClosed { + t.Fatalf("expected errListenerConnClosed; got %v", err) + } +} + +func TestConnPing(t *testing.T) { + l, _ := newTestListenerConn(t) + defer l.Close() + err := l.Ping() + if err != nil { + t.Fatal(err) + } + err = l.Close() + if err != nil { + t.Fatal(err) + } + err = l.Ping() + if err != errListenerConnClosed { + t.Fatalf("expected errListenerConnClosed; got %v", err) + } +} + +// Test for deadlock where a query fails while another one is queued +func TestConnExecDeadlock(t *testing.T) { + l, _ := newTestListenerConn(t) + defer l.Close() + + var wg sync.WaitGroup + wg.Add(2) + + go func() { + l.ExecSimpleQuery("SELECT pg_sleep(60)") + wg.Done() + }() + runtime.Gosched() + go func() { + l.ExecSimpleQuery("SELECT 1") + wg.Done() + }() + // give the two goroutines some time to get into position + runtime.Gosched() + // calls Close on the net.Conn; equivalent to a network failure + l.Close() + + defer time.AfterFunc(10*time.Second, func() { + panic("timed out") + }).Stop() + wg.Wait() +} + +// Test for ListenerConn being closed while a slow query is executing +func TestListenerConnCloseWhileQueryIsExecuting(t *testing.T) { + l, _ := newTestListenerConn(t) + defer l.Close() + + var wg sync.WaitGroup + wg.Add(1) + + go func() { + sent, err := l.ExecSimpleQuery("SELECT pg_sleep(60)") + if sent { + panic("expected sent=false") + } + // could be any of a number of errors + if err == nil { + panic("expected error") + } + wg.Done() + }() + // give the above goroutine some time to get into position + runtime.Gosched() + err := l.Close() + if err != nil { + t.Fatal(err) + } + + defer time.AfterFunc(10*time.Second, func() { + panic("timed out") + }).Stop() + wg.Wait() +} + +func TestNotifyExtra(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + if getServerVersion(t, db) < 90000 { + t.Skip("skipping NOTIFY payload test since the server does not appear to support it") + } + + l, channel := newTestListenerConn(t) + defer l.Close() + + ok, err := l.Listen("notify_test") + if !ok || err != nil { + t.Fatal(err) + } + + _, err = db.Exec("NOTIFY notify_test, 'something'") + if err != nil { + t.Fatal(err) + } + + err = expectNotification(t, channel, "notify_test", "something") + if err != nil { + t.Fatal(err) + } +} + +// create a new test listener and also set the timeouts +func newTestListenerTimeout(t *testing.T, min time.Duration, max time.Duration) (*Listener, <-chan ListenerEventType) { + datname := os.Getenv("PGDATABASE") + sslmode := os.Getenv("PGSSLMODE") + + if datname == "" { + os.Setenv("PGDATABASE", "pqgotest") + } + + if sslmode == "" { + os.Setenv("PGSSLMODE", "disable") + } + + eventch := make(chan ListenerEventType, 16) + l := NewListener("", min, max, func(t ListenerEventType, err error) { eventch <- t }) + err := expectEvent(t, eventch, ListenerEventConnected) + if err != nil { + t.Fatal(err) + } + return l, eventch +} + +func newTestListener(t *testing.T) (*Listener, <-chan ListenerEventType) { + return newTestListenerTimeout(t, time.Hour, time.Hour) +} + +func TestListenerListen(t *testing.T) { + l, _ := newTestListener(t) + defer l.Close() + + db := openTestConn(t) + defer db.Close() + + err := l.Listen("notify_listen_test") + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("NOTIFY notify_listen_test") + if err != nil { + t.Fatal(err) + } + + err = expectNotification(t, l.Notify, "notify_listen_test", "") + if err != nil { + t.Fatal(err) + } +} + +func TestListenerUnlisten(t *testing.T) { + l, _ := newTestListener(t) + defer l.Close() + + db := openTestConn(t) + defer db.Close() + + err := l.Listen("notify_listen_test") + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("NOTIFY notify_listen_test") + if err != nil { + t.Fatal(err) + } + + err = l.Unlisten("notify_listen_test") + if err != nil { + t.Fatal(err) + } + + err = expectNotification(t, l.Notify, "notify_listen_test", "") + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("NOTIFY notify_listen_test") + if err != nil { + t.Fatal(err) + } + + err = expectNoNotification(t, l.Notify) + if err != nil { + t.Fatal(err) + } +} + +func TestListenerUnlistenAll(t *testing.T) { + l, _ := newTestListener(t) + defer l.Close() + + db := openTestConn(t) + defer db.Close() + + err := l.Listen("notify_listen_test") + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("NOTIFY notify_listen_test") + if err != nil { + t.Fatal(err) + } + + err = l.UnlistenAll() + if err != nil { + t.Fatal(err) + } + + err = expectNotification(t, l.Notify, "notify_listen_test", "") + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("NOTIFY notify_listen_test") + if err != nil { + t.Fatal(err) + } + + err = expectNoNotification(t, l.Notify) + if err != nil { + t.Fatal(err) + } +} + +func TestListenerFailedQuery(t *testing.T) { + l, eventch := newTestListener(t) + defer l.Close() + + db := openTestConn(t) + defer db.Close() + + err := l.Listen("notify_listen_test") + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("NOTIFY notify_listen_test") + if err != nil { + t.Fatal(err) + } + + err = expectNotification(t, l.Notify, "notify_listen_test", "") + if err != nil { + t.Fatal(err) + } + + // shouldn't cause a disconnect + ok, err := l.cn.ExecSimpleQuery("SELECT error") + if !ok { + t.Fatalf("could not send query to server: %v", err) + } + _, ok = err.(PGError) + if !ok { + t.Fatalf("unexpected error %v", err) + } + err = expectNoEvent(t, eventch) + if err != nil { + t.Fatal(err) + } + + // should still work + _, err = db.Exec("NOTIFY notify_listen_test") + if err != nil { + t.Fatal(err) + } + + err = expectNotification(t, l.Notify, "notify_listen_test", "") + if err != nil { + t.Fatal(err) + } +} + +func TestListenerReconnect(t *testing.T) { + l, eventch := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) + defer l.Close() + + db := openTestConn(t) + defer db.Close() + + err := l.Listen("notify_listen_test") + if err != nil { + t.Fatal(err) + } + + _, err = db.Exec("NOTIFY notify_listen_test") + if err != nil { + t.Fatal(err) + } + + err = expectNotification(t, l.Notify, "notify_listen_test", "") + if err != nil { + t.Fatal(err) + } + + // kill the connection and make sure it comes back up + ok, err := l.cn.ExecSimpleQuery("SELECT pg_terminate_backend(pg_backend_pid())") + if ok { + t.Fatalf("could not kill the connection: %v", err) + } + if err != io.EOF { + t.Fatalf("unexpected error %v", err) + } + err = expectEvent(t, eventch, ListenerEventDisconnected) + if err != nil { + t.Fatal(err) + } + err = expectEvent(t, eventch, ListenerEventReconnected) + if err != nil { + t.Fatal(err) + } + + // should still work + _, err = db.Exec("NOTIFY notify_listen_test") + if err != nil { + t.Fatal(err) + } + + // should get nil after Reconnected + err = expectNotification(t, l.Notify, "", "") + if err != errNilNotification { + t.Fatal(err) + } + + err = expectNotification(t, l.Notify, "notify_listen_test", "") + if err != nil { + t.Fatal(err) + } +} + +func TestListenerClose(t *testing.T) { + l, _ := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) + defer l.Close() + + err := l.Close() + if err != nil { + t.Fatal(err) + } + err = l.Close() + if err != errListenerClosed { + t.Fatalf("expected errListenerClosed; got %v", err) + } +} + +func TestListenerPing(t *testing.T) { + l, _ := newTestListenerTimeout(t, 20*time.Millisecond, time.Hour) + defer l.Close() + + err := l.Ping() + if err != nil { + t.Fatal(err) + } + + err = l.Close() + if err != nil { + t.Fatal(err) + } + + err = l.Ping() + if err != errListenerClosed { + t.Fatalf("expected errListenerClosed; got %v", err) + } +} diff --git a/vendor/github.com/lib/pq/rows_test.go b/vendor/github.com/lib/pq/rows_test.go new file mode 100644 index 00000000..b3420a29 --- /dev/null +++ b/vendor/github.com/lib/pq/rows_test.go @@ -0,0 +1,218 @@ +package pq + +import ( + "math" + "reflect" + "testing" + + "github.com/lib/pq/oid" +) + +func TestDataTypeName(t *testing.T) { + tts := []struct { + typ oid.Oid + name string + }{ + {oid.T_int8, "INT8"}, + {oid.T_int4, "INT4"}, + {oid.T_int2, "INT2"}, + {oid.T_varchar, "VARCHAR"}, + {oid.T_text, "TEXT"}, + {oid.T_bool, "BOOL"}, + {oid.T_numeric, "NUMERIC"}, + {oid.T_date, "DATE"}, + {oid.T_time, "TIME"}, + {oid.T_timetz, "TIMETZ"}, + {oid.T_timestamp, "TIMESTAMP"}, + {oid.T_timestamptz, "TIMESTAMPTZ"}, + {oid.T_bytea, "BYTEA"}, + } + + for i, tt := range tts { + dt := fieldDesc{OID: tt.typ} + if name := dt.Name(); name != tt.name { + t.Errorf("(%d) got: %s want: %s", i, name, tt.name) + } + } +} + +func TestDataType(t *testing.T) { + tts := []struct { + typ oid.Oid + kind reflect.Kind + }{ + {oid.T_int8, reflect.Int64}, + {oid.T_int4, reflect.Int32}, + {oid.T_int2, reflect.Int16}, + {oid.T_varchar, reflect.String}, + {oid.T_text, reflect.String}, + {oid.T_bool, reflect.Bool}, + {oid.T_date, reflect.Struct}, + {oid.T_time, reflect.Struct}, + {oid.T_timetz, reflect.Struct}, + {oid.T_timestamp, reflect.Struct}, + {oid.T_timestamptz, reflect.Struct}, + {oid.T_bytea, reflect.Slice}, + } + + for i, tt := range tts { + dt := fieldDesc{OID: tt.typ} + if kind := dt.Type().Kind(); kind != tt.kind { + t.Errorf("(%d) got: %s want: %s", i, kind, tt.kind) + } + } +} + +func TestDataTypeLength(t *testing.T) { + tts := []struct { + typ oid.Oid + len int + mod int + length int64 + ok bool + }{ + {oid.T_int4, 0, -1, 0, false}, + {oid.T_varchar, 65535, 9, 5, true}, + {oid.T_text, 65535, -1, math.MaxInt64, true}, + {oid.T_bytea, 65535, -1, math.MaxInt64, true}, + } + + for i, tt := range tts { + dt := fieldDesc{OID: tt.typ, Len: tt.len, Mod: tt.mod} + if l, k := dt.Length(); k != tt.ok || l != tt.length { + t.Errorf("(%d) got: %d, %t want: %d, %t", i, l, k, tt.length, tt.ok) + } + } +} + +func TestDataTypePrecisionScale(t *testing.T) { + tts := []struct { + typ oid.Oid + mod int + precision, scale int64 + ok bool + }{ + {oid.T_int4, -1, 0, 0, false}, + {oid.T_numeric, 589830, 9, 2, true}, + {oid.T_text, -1, 0, 0, false}, + } + + for i, tt := range tts { + dt := fieldDesc{OID: tt.typ, Mod: tt.mod} + p, s, k := dt.PrecisionScale() + if k != tt.ok { + t.Errorf("(%d) got: %t want: %t", i, k, tt.ok) + } + if p != tt.precision { + t.Errorf("(%d) wrong precision got: %d want: %d", i, p, tt.precision) + } + if s != tt.scale { + t.Errorf("(%d) wrong scale got: %d want: %d", i, s, tt.scale) + } + } +} + +func TestRowsColumnTypes(t *testing.T) { + columnTypesTests := []struct { + Name string + TypeName string + Length struct { + Len int64 + OK bool + } + DecimalSize struct { + Precision int64 + Scale int64 + OK bool + } + ScanType reflect.Type + }{ + { + Name: "a", + TypeName: "INT4", + Length: struct { + Len int64 + OK bool + }{ + Len: 0, + OK: false, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(int32(0)), + }, { + Name: "bar", + TypeName: "TEXT", + Length: struct { + Len int64 + OK bool + }{ + Len: math.MaxInt64, + OK: true, + }, + DecimalSize: struct { + Precision int64 + Scale int64 + OK bool + }{ + Precision: 0, + Scale: 0, + OK: false, + }, + ScanType: reflect.TypeOf(""), + }, + } + + db := openTestConn(t) + defer db.Close() + + rows, err := db.Query("SELECT 1 AS a, text 'bar' AS bar, 1.28::numeric(9, 2) AS dec") + if err != nil { + t.Fatal(err) + } + + columns, err := rows.ColumnTypes() + if err != nil { + t.Fatal(err) + } + if len(columns) != 3 { + t.Errorf("expected 3 columns found %d", len(columns)) + } + + for i, tt := range columnTypesTests { + c := columns[i] + if c.Name() != tt.Name { + t.Errorf("(%d) got: %s, want: %s", i, c.Name(), tt.Name) + } + if c.DatabaseTypeName() != tt.TypeName { + t.Errorf("(%d) got: %s, want: %s", i, c.DatabaseTypeName(), tt.TypeName) + } + l, ok := c.Length() + if l != tt.Length.Len { + t.Errorf("(%d) got: %d, want: %d", i, l, tt.Length.Len) + } + if ok != tt.Length.OK { + t.Errorf("(%d) got: %t, want: %t", i, ok, tt.Length.OK) + } + p, s, ok := c.DecimalSize() + if p != tt.DecimalSize.Precision { + t.Errorf("(%d) got: %d, want: %d", i, p, tt.DecimalSize.Precision) + } + if s != tt.DecimalSize.Scale { + t.Errorf("(%d) got: %d, want: %d", i, s, tt.DecimalSize.Scale) + } + if ok != tt.DecimalSize.OK { + t.Errorf("(%d) got: %t, want: %t", i, ok, tt.DecimalSize.OK) + } + if c.ScanType() != tt.ScanType { + t.Errorf("(%d) got: %v, want: %v", i, c.ScanType(), tt.ScanType) + } + } +} diff --git a/vendor/github.com/lib/pq/scram/scram.go b/vendor/github.com/lib/pq/scram/scram.go new file mode 100644 index 00000000..5d0358f8 --- /dev/null +++ b/vendor/github.com/lib/pq/scram/scram.go @@ -0,0 +1,264 @@ +// Copyright (c) 2014 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Pacakage scram implements a SCRAM-{SHA-1,etc} client per RFC5802. +// +// http://tools.ietf.org/html/rfc5802 +// +package scram + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "encoding/base64" + "fmt" + "hash" + "strconv" + "strings" +) + +// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc). +// +// A Client may be used within a SASL conversation with logic resembling: +// +// var in []byte +// var client = scram.NewClient(sha1.New, user, pass) +// for client.Step(in) { +// out := client.Out() +// // send out to server +// in := serverOut +// } +// if client.Err() != nil { +// // auth failed +// } +// +type Client struct { + newHash func() hash.Hash + + user string + pass string + step int + out bytes.Buffer + err error + + clientNonce []byte + serverNonce []byte + saltedPass []byte + authMsg bytes.Buffer +} + +// NewClient returns a new SCRAM-* client with the provided hash algorithm. +// +// For SCRAM-SHA-256, for example, use: +// +// client := scram.NewClient(sha256.New, user, pass) +// +func NewClient(newHash func() hash.Hash, user, pass string) *Client { + c := &Client{ + newHash: newHash, + user: user, + pass: pass, + } + c.out.Grow(256) + c.authMsg.Grow(256) + return c +} + +// Out returns the data to be sent to the server in the current step. +func (c *Client) Out() []byte { + if c.out.Len() == 0 { + return nil + } + return c.out.Bytes() +} + +// Err returns the error that ocurred, or nil if there were no errors. +func (c *Client) Err() error { + return c.err +} + +// SetNonce sets the client nonce to the provided value. +// If not set, the nonce is generated automatically out of crypto/rand on the first step. +func (c *Client) SetNonce(nonce []byte) { + c.clientNonce = nonce +} + +var escaper = strings.NewReplacer("=", "=3D", ",", "=2C") + +// Step processes the incoming data from the server and makes the +// next round of data for the server available via Client.Out. +// Step returns false if there are no errors and more data is +// still expected. +func (c *Client) Step(in []byte) bool { + c.out.Reset() + if c.step > 2 || c.err != nil { + return false + } + c.step++ + switch c.step { + case 1: + c.err = c.step1(in) + case 2: + c.err = c.step2(in) + case 3: + c.err = c.step3(in) + } + return c.step > 2 || c.err != nil +} + +func (c *Client) step1(in []byte) error { + if len(c.clientNonce) == 0 { + const nonceLen = 16 + buf := make([]byte, nonceLen+b64.EncodedLen(nonceLen)) + if _, err := rand.Read(buf[:nonceLen]); err != nil { + return fmt.Errorf("cannot read random SCRAM-SHA-256 nonce from operating system: %v", err) + } + c.clientNonce = buf[nonceLen:] + b64.Encode(c.clientNonce, buf[:nonceLen]) + } + c.authMsg.WriteString("n=") + escaper.WriteString(&c.authMsg, c.user) + c.authMsg.WriteString(",r=") + c.authMsg.Write(c.clientNonce) + + c.out.WriteString("n,,") + c.out.Write(c.authMsg.Bytes()) + return nil +} + +var b64 = base64.StdEncoding + +func (c *Client) step2(in []byte) error { + c.authMsg.WriteByte(',') + c.authMsg.Write(in) + + fields := bytes.Split(in, []byte(",")) + if len(fields) != 3 { + return fmt.Errorf("expected 3 fields in first SCRAM-SHA-256 server message, got %d: %q", len(fields), in) + } + if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 nonce: %q", fields[0]) + } + if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 salt: %q", fields[1]) + } + if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2]) + } + + c.serverNonce = fields[0][2:] + if !bytes.HasPrefix(c.serverNonce, c.clientNonce) { + return fmt.Errorf("server SCRAM-SHA-256 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce) + } + + salt := make([]byte, b64.DecodedLen(len(fields[1][2:]))) + n, err := b64.Decode(salt, fields[1][2:]) + if err != nil { + return fmt.Errorf("cannot decode SCRAM-SHA-256 salt sent by server: %q", fields[1]) + } + salt = salt[:n] + iterCount, err := strconv.Atoi(string(fields[2][2:])) + if err != nil { + return fmt.Errorf("server sent an invalid SCRAM-SHA-256 iteration count: %q", fields[2]) + } + c.saltPassword(salt, iterCount) + + c.authMsg.WriteString(",c=biws,r=") + c.authMsg.Write(c.serverNonce) + + c.out.WriteString("c=biws,r=") + c.out.Write(c.serverNonce) + c.out.WriteString(",p=") + c.out.Write(c.clientProof()) + return nil +} + +func (c *Client) step3(in []byte) error { + var isv, ise bool + var fields = bytes.Split(in, []byte(",")) + if len(fields) == 1 { + isv = bytes.HasPrefix(fields[0], []byte("v=")) + ise = bytes.HasPrefix(fields[0], []byte("e=")) + } + if ise { + return fmt.Errorf("SCRAM-SHA-256 authentication error: %s", fields[0][2:]) + } else if !isv { + return fmt.Errorf("unsupported SCRAM-SHA-256 final message from server: %q", in) + } + if !bytes.Equal(c.serverSignature(), fields[0][2:]) { + return fmt.Errorf("cannot authenticate SCRAM-SHA-256 server signature: %q", fields[0][2:]) + } + return nil +} + +func (c *Client) saltPassword(salt []byte, iterCount int) { + mac := hmac.New(c.newHash, []byte(c.pass)) + mac.Write(salt) + mac.Write([]byte{0, 0, 0, 1}) + ui := mac.Sum(nil) + hi := make([]byte, len(ui)) + copy(hi, ui) + for i := 1; i < iterCount; i++ { + mac.Reset() + mac.Write(ui) + mac.Sum(ui[:0]) + for j, b := range ui { + hi[j] ^= b + } + } + c.saltedPass = hi +} + +func (c *Client) clientProof() []byte { + mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Client Key")) + clientKey := mac.Sum(nil) + hash := c.newHash() + hash.Write(clientKey) + storedKey := hash.Sum(nil) + mac = hmac.New(c.newHash, storedKey) + mac.Write(c.authMsg.Bytes()) + clientProof := mac.Sum(nil) + for i, b := range clientKey { + clientProof[i] ^= b + } + clientProof64 := make([]byte, b64.EncodedLen(len(clientProof))) + b64.Encode(clientProof64, clientProof) + return clientProof64 +} + +func (c *Client) serverSignature() []byte { + mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Server Key")) + serverKey := mac.Sum(nil) + + mac = hmac.New(c.newHash, serverKey) + mac.Write(c.authMsg.Bytes()) + serverSignature := mac.Sum(nil) + + encoded := make([]byte, b64.EncodedLen(len(serverSignature))) + b64.Encode(encoded, serverSignature) + return encoded +} diff --git a/vendor/github.com/lib/pq/ssl.go b/vendor/github.com/lib/pq/ssl.go index e1a326a0..d9020845 100644 --- a/vendor/github.com/lib/pq/ssl.go +++ b/vendor/github.com/lib/pq/ssl.go @@ -58,7 +58,13 @@ func ssl(o values) (func(net.Conn) (net.Conn, error), error) { if err != nil { return nil, err } - sslRenegotiation(&tlsConf) + + // Accept renegotiation requests initiated by the backend. + // + // Renegotiation was deprecated then removed from PostgreSQL 9.5, but + // the default configuration of older versions has it enabled. Redshift + // also initiates renegotiations and cannot be reconfigured. + tlsConf.Renegotiation = tls.RenegotiateFreelyAsClient return func(conn net.Conn) (net.Conn, error) { client := tls.Client(conn, &tlsConf) diff --git a/vendor/github.com/lib/pq/ssl_go1.7.go b/vendor/github.com/lib/pq/ssl_go1.7.go deleted file mode 100644 index d7ba43b3..00000000 --- a/vendor/github.com/lib/pq/ssl_go1.7.go +++ /dev/null @@ -1,14 +0,0 @@ -// +build go1.7 - -package pq - -import "crypto/tls" - -// Accept renegotiation requests initiated by the backend. -// -// Renegotiation was deprecated then removed from PostgreSQL 9.5, but -// the default configuration of older versions has it enabled. Redshift -// also initiates renegotiations and cannot be reconfigured. -func sslRenegotiation(conf *tls.Config) { - conf.Renegotiation = tls.RenegotiateFreelyAsClient -} diff --git a/vendor/github.com/lib/pq/ssl_renegotiation.go b/vendor/github.com/lib/pq/ssl_renegotiation.go deleted file mode 100644 index 85ed5e43..00000000 --- a/vendor/github.com/lib/pq/ssl_renegotiation.go +++ /dev/null @@ -1,8 +0,0 @@ -// +build !go1.7 - -package pq - -import "crypto/tls" - -// Renegotiation is not supported by crypto/tls until Go 1.7. -func sslRenegotiation(*tls.Config) {} diff --git a/vendor/github.com/lib/pq/ssl_test.go b/vendor/github.com/lib/pq/ssl_test.go new file mode 100644 index 00000000..3eafbfd2 --- /dev/null +++ b/vendor/github.com/lib/pq/ssl_test.go @@ -0,0 +1,279 @@ +package pq + +// This file contains SSL tests + +import ( + _ "crypto/sha256" + "crypto/x509" + "database/sql" + "os" + "path/filepath" + "testing" +) + +func maybeSkipSSLTests(t *testing.T) { + // Require some special variables for testing certificates + if os.Getenv("PQSSLCERTTEST_PATH") == "" { + t.Skip("PQSSLCERTTEST_PATH not set, skipping SSL tests") + } + + value := os.Getenv("PQGOSSLTESTS") + if value == "" || value == "0" { + t.Skip("PQGOSSLTESTS not enabled, skipping SSL tests") + } else if value != "1" { + t.Fatalf("unexpected value %q for PQGOSSLTESTS", value) + } +} + +func openSSLConn(t *testing.T, conninfo string) (*sql.DB, error) { + db, err := openTestConnConninfo(conninfo) + if err != nil { + // should never fail + t.Fatal(err) + } + // Do something with the connection to see whether it's working or not. + tx, err := db.Begin() + if err == nil { + return db, tx.Rollback() + } + _ = db.Close() + return nil, err +} + +func checkSSLSetup(t *testing.T, conninfo string) { + _, err := openSSLConn(t, conninfo) + if pge, ok := err.(*Error); ok { + if pge.Code.Name() != "invalid_authorization_specification" { + t.Fatalf("unexpected error code '%s'", pge.Code.Name()) + } + } else { + t.Fatalf("expected %T, got %v", (*Error)(nil), err) + } +} + +// Connect over SSL and run a simple query to test the basics +func TestSSLConnection(t *testing.T) { + maybeSkipSSLTests(t) + // Environment sanity check: should fail without SSL + checkSSLSetup(t, "sslmode=disable user=pqgossltest") + + db, err := openSSLConn(t, "sslmode=require user=pqgossltest") + if err != nil { + t.Fatal(err) + } + rows, err := db.Query("SELECT 1") + if err != nil { + t.Fatal(err) + } + rows.Close() +} + +// Test sslmode=verify-full +func TestSSLVerifyFull(t *testing.T) { + maybeSkipSSLTests(t) + // Environment sanity check: should fail without SSL + checkSSLSetup(t, "sslmode=disable user=pqgossltest") + + // Not OK according to the system CA + _, err := openSSLConn(t, "host=postgres sslmode=verify-full user=pqgossltest") + if err == nil { + t.Fatal("expected error") + } + _, ok := err.(x509.UnknownAuthorityError) + if !ok { + t.Fatalf("expected x509.UnknownAuthorityError, got %#+v", err) + } + + rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") + rootCert := "sslrootcert=" + rootCertPath + " " + // No match on Common Name + _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-full user=pqgossltest") + if err == nil { + t.Fatal("expected error") + } + _, ok = err.(x509.HostnameError) + if !ok { + t.Fatalf("expected x509.HostnameError, got %#+v", err) + } + // OK + _, err = openSSLConn(t, rootCert+"host=postgres sslmode=verify-full user=pqgossltest") + if err != nil { + t.Fatal(err) + } +} + +// Test sslmode=require sslrootcert=rootCertPath +func TestSSLRequireWithRootCert(t *testing.T) { + maybeSkipSSLTests(t) + // Environment sanity check: should fail without SSL + checkSSLSetup(t, "sslmode=disable user=pqgossltest") + + bogusRootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "bogus_root.crt") + bogusRootCert := "sslrootcert=" + bogusRootCertPath + " " + + // Not OK according to the bogus CA + _, err := openSSLConn(t, bogusRootCert+"host=postgres sslmode=require user=pqgossltest") + if err == nil { + t.Fatal("expected error") + } + _, ok := err.(x509.UnknownAuthorityError) + if !ok { + t.Fatalf("expected x509.UnknownAuthorityError, got %s, %#+v", err, err) + } + + nonExistentCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "non_existent.crt") + nonExistentCert := "sslrootcert=" + nonExistentCertPath + " " + + // No match on Common Name, but that's OK because we're not validating anything. + _, err = openSSLConn(t, nonExistentCert+"host=127.0.0.1 sslmode=require user=pqgossltest") + if err != nil { + t.Fatal(err) + } + + rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") + rootCert := "sslrootcert=" + rootCertPath + " " + + // No match on Common Name, but that's OK because we're not validating the CN. + _, err = openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=require user=pqgossltest") + if err != nil { + t.Fatal(err) + } + // Everything OK + _, err = openSSLConn(t, rootCert+"host=postgres sslmode=require user=pqgossltest") + if err != nil { + t.Fatal(err) + } +} + +// Test sslmode=verify-ca +func TestSSLVerifyCA(t *testing.T) { + maybeSkipSSLTests(t) + // Environment sanity check: should fail without SSL + checkSSLSetup(t, "sslmode=disable user=pqgossltest") + + // Not OK according to the system CA + { + _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest") + if _, ok := err.(x509.UnknownAuthorityError); !ok { + t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err) + } + } + + // Still not OK according to the system CA; empty sslrootcert is treated as unspecified. + { + _, err := openSSLConn(t, "host=postgres sslmode=verify-ca user=pqgossltest sslrootcert=''") + if _, ok := err.(x509.UnknownAuthorityError); !ok { + t.Fatalf("expected %T, got %#+v", x509.UnknownAuthorityError{}, err) + } + } + + rootCertPath := filepath.Join(os.Getenv("PQSSLCERTTEST_PATH"), "root.crt") + rootCert := "sslrootcert=" + rootCertPath + " " + // No match on Common Name, but that's OK + if _, err := openSSLConn(t, rootCert+"host=127.0.0.1 sslmode=verify-ca user=pqgossltest"); err != nil { + t.Fatal(err) + } + // Everything OK + if _, err := openSSLConn(t, rootCert+"host=postgres sslmode=verify-ca user=pqgossltest"); err != nil { + t.Fatal(err) + } +} + +// Authenticate over SSL using client certificates +func TestSSLClientCertificates(t *testing.T) { + maybeSkipSSLTests(t) + // Environment sanity check: should fail without SSL + checkSSLSetup(t, "sslmode=disable user=pqgossltest") + + const baseinfo = "sslmode=require user=pqgosslcert" + + // Certificate not specified, should fail + { + _, err := openSSLConn(t, baseinfo) + if pge, ok := err.(*Error); ok { + if pge.Code.Name() != "invalid_authorization_specification" { + t.Fatalf("unexpected error code '%s'", pge.Code.Name()) + } + } else { + t.Fatalf("expected %T, got %v", (*Error)(nil), err) + } + } + + // Empty certificate specified, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert=''") + if pge, ok := err.(*Error); ok { + if pge.Code.Name() != "invalid_authorization_specification" { + t.Fatalf("unexpected error code '%s'", pge.Code.Name()) + } + } else { + t.Fatalf("expected %T, got %v", (*Error)(nil), err) + } + } + + // Non-existent certificate specified, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert=/tmp/filedoesnotexist") + if pge, ok := err.(*Error); ok { + if pge.Code.Name() != "invalid_authorization_specification" { + t.Fatalf("unexpected error code '%s'", pge.Code.Name()) + } + } else { + t.Fatalf("expected %T, got %v", (*Error)(nil), err) + } + } + + certpath, ok := os.LookupEnv("PQSSLCERTTEST_PATH") + if !ok { + t.Fatalf("PQSSLCERTTEST_PATH not present in environment") + } + + sslcert := filepath.Join(certpath, "postgresql.crt") + + // Cert present, key not specified, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert) + if _, ok := err.(*os.PathError); !ok { + t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) + } + } + + // Cert present, empty key specified, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=''") + if _, ok := err.(*os.PathError); !ok { + t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) + } + } + + // Cert present, non-existent key, should fail + { + _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey=/tmp/filedoesnotexist") + if _, ok := err.(*os.PathError); !ok { + t.Fatalf("expected %T, got %#+v", (*os.PathError)(nil), err) + } + } + + // Key has wrong permissions (passing the cert as the key), should fail + if _, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslcert); err != ErrSSLKeyHasWorldPermissions { + t.Fatalf("expected %s, got %#+v", ErrSSLKeyHasWorldPermissions, err) + } + + sslkey := filepath.Join(certpath, "postgresql.key") + + // Should work + if db, err := openSSLConn(t, baseinfo+" sslcert="+sslcert+" sslkey="+sslkey); err != nil { + t.Fatal(err) + } else { + rows, err := db.Query("SELECT 1") + if err != nil { + t.Fatal(err) + } + if err := rows.Close(); err != nil { + t.Fatal(err) + } + if err := db.Close(); err != nil { + t.Fatal(err) + } + } +} diff --git a/vendor/github.com/lib/pq/url_test.go b/vendor/github.com/lib/pq/url_test.go new file mode 100644 index 00000000..4ff0ce03 --- /dev/null +++ b/vendor/github.com/lib/pq/url_test.go @@ -0,0 +1,66 @@ +package pq + +import ( + "testing" +) + +func TestSimpleParseURL(t *testing.T) { + expected := "host=hostname.remote" + str, err := ParseURL("postgres://hostname.remote") + if err != nil { + t.Fatal(err) + } + + if str != expected { + t.Fatalf("unexpected result from ParseURL:\n+ %v\n- %v", str, expected) + } +} + +func TestIPv6LoopbackParseURL(t *testing.T) { + expected := "host=::1 port=1234" + str, err := ParseURL("postgres://[::1]:1234") + if err != nil { + t.Fatal(err) + } + + if str != expected { + t.Fatalf("unexpected result from ParseURL:\n+ %v\n- %v", str, expected) + } +} + +func TestFullParseURL(t *testing.T) { + expected := `dbname=database host=hostname.remote password=top\ secret port=1234 user=username` + str, err := ParseURL("postgres://username:top%20secret@hostname.remote:1234/database") + if err != nil { + t.Fatal(err) + } + + if str != expected { + t.Fatalf("unexpected result from ParseURL:\n+ %s\n- %s", str, expected) + } +} + +func TestInvalidProtocolParseURL(t *testing.T) { + _, err := ParseURL("http://hostname.remote") + switch err { + case nil: + t.Fatal("Expected an error from parsing invalid protocol") + default: + msg := "invalid connection protocol: http" + if err.Error() != msg { + t.Fatalf("Unexpected error message:\n+ %s\n- %s", + err.Error(), msg) + } + } +} + +func TestMinimalURL(t *testing.T) { + cs, err := ParseURL("postgres://") + if err != nil { + t.Fatal(err) + } + + if cs != "" { + t.Fatalf("expected blank connection string, got: %q", cs) + } +} diff --git a/vendor/github.com/lib/pq/uuid_test.go b/vendor/github.com/lib/pq/uuid_test.go new file mode 100644 index 00000000..8ecee2fd --- /dev/null +++ b/vendor/github.com/lib/pq/uuid_test.go @@ -0,0 +1,46 @@ +package pq + +import ( + "reflect" + "strings" + "testing" +) + +func TestDecodeUUIDBinaryError(t *testing.T) { + t.Parallel() + _, err := decodeUUIDBinary([]byte{0x12, 0x34}) + + if err == nil { + t.Fatal("Expected error, got none") + } + if !strings.HasPrefix(err.Error(), "pq:") { + t.Errorf("Expected error to start with %q, got %q", "pq:", err.Error()) + } + if !strings.Contains(err.Error(), "bad length: 2") { + t.Errorf("Expected error to contain length, got %q", err.Error()) + } +} + +func BenchmarkDecodeUUIDBinary(b *testing.B) { + x := []byte{0x03, 0xa3, 0x52, 0x2f, 0x89, 0x28, 0x49, 0x87, 0x84, 0xd6, 0x93, 0x7b, 0x36, 0xec, 0x27, 0x6f} + + for i := 0; i < b.N; i++ { + decodeUUIDBinary(x) + } +} + +func TestDecodeUUIDBackend(t *testing.T) { + db := openTestConn(t) + defer db.Close() + + var s = "a0ecc91d-a13f-4fe4-9fce-7e09777cc70a" + var scanned interface{} + + err := db.QueryRow(`SELECT $1::uuid`, s).Scan(&scanned) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !reflect.DeepEqual(scanned, []byte(s)) { + t.Errorf("Expected []byte(%q), got %T(%q)", s, scanned, scanned) + } +}