diff --git a/domain/backup/endpoint.go b/domain/backup/endpoint.go index 0aabcecc..4e23181c 100644 --- a/domain/backup/endpoint.go +++ b/domain/backup/endpoint.go @@ -206,5 +206,8 @@ func (h *Handler) Restore(w http.ResponseWriter, r *http.Request) { h.Runtime.Log.Infof("Restore remapped %d UserID values", len(rh.MapUserID)) h.Runtime.Log.Info("Restore completed") + h.Runtime.Log.Info("Building search index") + go h.Indexer.Rebuild(ctx) + response.WriteEmpty(w) } diff --git a/vendor/github.com/go-sql-driver/mysql/AUTHORS b/vendor/github.com/go-sql-driver/mysql/AUTHORS index 73ff68fb..bfe74c4e 100644 --- a/vendor/github.com/go-sql-driver/mysql/AUTHORS +++ b/vendor/github.com/go-sql-driver/mysql/AUTHORS @@ -27,6 +27,7 @@ Daniël van Eeden Dave Protasowski DisposaBoy Egor Smolyakov +Erwan Martin Evan Shaw Frederick Mayle Gustavo Kristic @@ -34,12 +35,15 @@ Hajime Nakagami Hanno Braun Henri Yandell Hirotaka Yamamoto +Huyiguang ICHINOSE Shogo +Ilia Cimpoes INADA Naoki Jacek Szwec James Harr Jeff Hodges Jeffrey Charles +Jerome Meyer Jian Zhen Joshua Prunier Julien Lefevre @@ -69,9 +73,14 @@ Richard Wilkes Robert Russell Runrioter Wung Shuode Li +Simon J Mudd Soroush Pour Stan Putrya Stanley Gunawan +Steven Hartland +Thomas Wodarek +Tim Ruffles +Tom Jenkinson Xiangyu Hu Xiaobing Jiang Xiuming Chen @@ -81,9 +90,12 @@ Zhenye Xie Barracuda Networks, Inc. Counting Ltd. +Facebook Inc. +GitHub Inc. Google Inc. InfoSum Ltd. Keybase Inc. +Multiplay Ltd. Percona LLC Pivotal Inc. Stripe Inc. diff --git a/vendor/github.com/go-sql-driver/mysql/README.md b/vendor/github.com/go-sql-driver/mysql/README.md index 2e9b07ee..c6adf1d6 100644 --- a/vendor/github.com/go-sql-driver/mysql/README.md +++ b/vendor/github.com/go-sql-driver/mysql/README.md @@ -40,7 +40,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac * Optional placeholder interpolation ## Requirements - * Go 1.7 or higher. We aim to support the 3 latest versions of Go. + * Go 1.9 or higher. We aim to support the 3 latest versions of Go. * MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+) --------------------------------------- @@ -171,13 +171,18 @@ Unless you need the fallback behavior, please use `collation` instead. ``` Type: string Valid Values: -Default: utf8_general_ci +Default: utf8mb4_general_ci ``` Sets the collation used for client-server interaction on connection. In contrast to `charset`, `collation` does not issue additional queries. If the specified collation is unavailable on the target server, the connection will fail. A list of valid charsets for a server is retrievable with `SHOW COLLATION`. +The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You should use an older collation (e.g. `utf8_general_ci`) for older MySQL. + +Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)). + + ##### `clientFoundRows` ``` @@ -328,11 +333,11 @@ Timeout for establishing connections, aka dial timeout. The value must be a deci ``` Type: bool / string -Valid Values: true, false, skip-verify, +Valid Values: true, false, skip-verify, preferred, Default: false ``` -`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). +`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side) or use `preferred` to use TLS only when advertised by the server. This is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Neither `skip-verify` nor `preferred` add any reliable security. You can use a custom TLS config after registering it with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig). ##### `writeTimeout` @@ -444,7 +449,7 @@ See the [godoc of Go-MySQL-Driver](https://godoc.org/github.com/go-sql-driver/my ### `time.Time` support The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your program. -However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter. +However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical equivalent in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter. **Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes). diff --git a/vendor/github.com/go-sql-driver/mysql/appengine.go b/vendor/github.com/go-sql-driver/mysql/appengine.go index be41f2ee..914e6623 100644 --- a/vendor/github.com/go-sql-driver/mysql/appengine.go +++ b/vendor/github.com/go-sql-driver/mysql/appengine.go @@ -11,9 +11,15 @@ package mysql import ( + "context" + "net" + "google.golang.org/appengine/cloudsql" ) func init() { - RegisterDial("cloudsql", cloudsql.Dial) + RegisterDialContext("cloudsql", func(_ context.Context, instance string) (net.Conn, error) { + // XXX: the cloudsql driver still does not export a Context-aware dialer. + return cloudsql.Dial(instance) + }) } diff --git a/vendor/github.com/go-sql-driver/mysql/auth.go b/vendor/github.com/go-sql-driver/mysql/auth.go index 0b59f52e..fec7040d 100644 --- a/vendor/github.com/go-sql-driver/mysql/auth.go +++ b/vendor/github.com/go-sql-driver/mysql/auth.go @@ -234,64 +234,64 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro if err != nil { return err } - return mc.writeAuthSwitchPacket(enc, false) + return mc.writeAuthSwitchPacket(enc) } -func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) { +func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) { switch plugin { case "caching_sha2_password": authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) - return authResp, (authResp == nil), nil + return authResp, nil case "mysql_old_password": if !mc.cfg.AllowOldPasswords { - return nil, false, ErrOldPassword + return nil, ErrOldPassword } // Note: there are edge cases where this should work but doesn't; // this is currently "wontfix": // https://github.com/go-sql-driver/mysql/issues/184 - authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd) - return authResp, true, nil + authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0) + return authResp, nil case "mysql_clear_password": if !mc.cfg.AllowCleartextPasswords { - return nil, false, ErrCleartextPassword + return nil, ErrCleartextPassword } // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - return []byte(mc.cfg.Passwd), true, nil + return append([]byte(mc.cfg.Passwd), 0), nil case "mysql_native_password": if !mc.cfg.AllowNativePasswords { - return nil, false, ErrNativePassword + return nil, ErrNativePassword } // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html // Native password authentication only need and will need 20-byte challenge. authResp := scramblePassword(authData[:20], mc.cfg.Passwd) - return authResp, false, nil + return authResp, nil case "sha256_password": if len(mc.cfg.Passwd) == 0 { - return nil, true, nil + return []byte{0}, nil } if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - return []byte(mc.cfg.Passwd), true, nil + return append([]byte(mc.cfg.Passwd), 0), nil } pubKey := mc.cfg.pubKey if pubKey == nil { // request public key from server - return []byte{1}, false, nil + return []byte{1}, nil } // encrypted password enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) - return enc, false, err + return enc, err default: errLog.Print("unknown auth plugin:", plugin) - return nil, false, ErrUnknownPlugin + return nil, ErrUnknownPlugin } } @@ -315,11 +315,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { plugin = newPlugin - authResp, addNUL, err := mc.auth(authData, plugin) + authResp, err := mc.auth(authData, plugin) if err != nil { return err } - if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil { + if err = mc.writeAuthSwitchPacket(authResp); err != nil { return err } @@ -352,7 +352,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { case cachingSha2PasswordPerformFullAuthentication: if mc.cfg.tls != nil || mc.cfg.Net == "unix" { // write cleartext auth packet - err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true) + err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0)) if err != nil { return err } @@ -360,13 +360,15 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { pubKey := mc.cfg.pubKey if pubKey == nil { // request public key from server - data := mc.buf.takeSmallBuffer(4 + 1) + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { + return err + } data[4] = cachingSha2PasswordRequestPublicKey mc.writePacket(data) // parse public key - data, err := mc.readPacket() - if err != nil { + if data, err = mc.readPacket(); err != nil { return err } diff --git a/vendor/github.com/go-sql-driver/mysql/auth_test.go b/vendor/github.com/go-sql-driver/mysql/auth_test.go new file mode 100644 index 00000000..1920ef39 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/auth_test.go @@ -0,0 +1,1330 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "encoding/pem" + "fmt" + "testing" +) + +var testPubKey = []byte("-----BEGIN PUBLIC KEY-----\n" + + "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAol0Z8G8U+25Btxk/g/fm\n" + + "UAW/wEKjQCTjkibDE4B+qkuWeiumg6miIRhtilU6m9BFmLQSy1ltYQuu4k17A4tQ\n" + + "rIPpOQYZges/qsDFkZh3wyK5jL5WEFVdOasf6wsfszExnPmcZS4axxoYJfiuilrN\n" + + "hnwinBAqfi3S0sw5MpSI4Zl1AbOrHG4zDI62Gti2PKiMGyYDZTS9xPrBLbN95Kby\n" + + "FFclQLEzA9RJcS1nHFsWtRgHjGPhhjCQxEm9NQ1nePFhCfBfApyfH1VM2VCOQum6\n" + + "Ci9bMuHWjTjckC84mzF99kOxOWVU7mwS6gnJqBzpuz8t3zq8/iQ2y7QrmZV+jTJP\n" + + "WQIDAQAB\n" + + "-----END PUBLIC KEY-----\n") + +var testPubKeyRSA *rsa.PublicKey + +func init() { + block, _ := pem.Decode(testPubKey) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + panic(err) + } + testPubKeyRSA = pub.(*rsa.PublicKey) +} + +func TestScrambleOldPass(t *testing.T) { + scramble := []byte{9, 8, 7, 6, 5, 4, 3, 2} + vectors := []struct { + pass string + out string + }{ + {" pass", "47575c5a435b4251"}, + {"pass ", "47575c5a435b4251"}, + {"123\t456", "575c47505b5b5559"}, + {"C0mpl!ca ted#PASS123", "5d5d554849584a45"}, + } + for _, tuple := range vectors { + ours := scrambleOldPassword(scramble, tuple.pass) + if tuple.out != fmt.Sprintf("%x", ours) { + t.Errorf("Failed old password %q", tuple.pass) + } + } +} + +func TestScrambleSHA256Pass(t *testing.T) { + scramble := []byte{10, 47, 74, 111, 75, 73, 34, 48, 88, 76, 114, 74, 37, 13, 3, 80, 82, 2, 23, 21} + vectors := []struct { + pass string + out string + }{ + {"secret", "f490e76f66d9d86665ce54d98c78d0acfe2fb0b08b423da807144873d30b312c"}, + {"secret2", "abc3934a012cf342e876071c8ee202de51785b430258a7a0138bc79c4d800bc6"}, + } + for _, tuple := range vectors { + ours := scrambleSHA256Password(scramble, tuple.pass) + if tuple.out != fmt.Sprintf("%x", ours) { + t.Errorf("Failed SHA256 password %q", tuple.pass) + } + } +} + +func TestAuthFastCachingSHA256PasswordCached(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, + 22, 41, 84, 32, 123, 43, 118} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{102, 32, 5, 35, 143, 161, 140, 241, 171, 232, 56, + 139, 43, 14, 107, 196, 249, 170, 147, 60, 220, 204, 120, 178, 214, 15, + 184, 150, 26, 61, 57, 235} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 3, // Fast Auth Success + 7, 0, 0, 3, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastCachingSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + + authData := []byte{90, 105, 74, 126, 30, 48, 37, 56, 3, 23, 115, 127, 69, + 22, 41, 84, 32, 123, 43, 118} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + if writtenAuthRespLen != 0 { + t.Fatalf("unexpected written auth response (%d bytes): %v", + writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastCachingSHA256PasswordFullRSA(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, + 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, + 110, 40, 139, 124, 41} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 4, // Perform Full Authentication + } + conn.queuedReplies = [][]byte{ + // pub key response + append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...), + + // OK + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{1, 0, 0, 3, 2, 0, 1, 0, 5}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastCachingSHA256PasswordFullRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, + 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, + 110, 40, 139, 124, 41} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 4, // Perform Full Authentication + } + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastCachingSHA256PasswordFullSecure(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "caching_sha2_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // Hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{171, 201, 138, 146, 89, 159, 11, 170, 0, 67, 165, + 49, 175, 94, 218, 68, 177, 109, 110, 86, 34, 33, 44, 190, 67, 240, 70, + 110, 40, 139, 124, 41} + if writtenAuthRespLen != 32 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 2, 0, 0, 2, 1, 4, // Perform Full Authentication + } + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.Equal(conn.written, []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastCleartextPasswordNotAllowed(t *testing.T) { + _, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_clear_password" + + // Send Client Authentication Packet + _, err := mc.auth(authData, plugin) + if err != ErrCleartextPassword { + t.Errorf("expected ErrCleartextPassword, got %v", err) + } +} + +func TestAuthFastCleartextPassword(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.AllowCleartextPasswords = true + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_clear_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastCleartextPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + mc.cfg.AllowCleartextPasswords = true + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_clear_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastNativePasswordNotAllowed(t *testing.T) { + _, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.AllowNativePasswords = false + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_native_password" + + // Send Client Authentication Packet + _, err := mc.auth(authData, plugin) + if err != ErrNativePassword { + t.Errorf("expected ErrNativePassword, got %v", err) + } +} + +func TestAuthFastNativePassword(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_native_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{53, 177, 140, 159, 251, 189, 127, 53, 109, 252, + 172, 50, 211, 192, 240, 164, 26, 48, 207, 45} + if writtenAuthRespLen != 20 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastNativePasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + + authData := []byte{70, 114, 92, 94, 1, 38, 11, 116, 63, 114, 23, 101, 126, + 103, 26, 95, 81, 17, 24, 21} + plugin := "mysql_native_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + if writtenAuthRespLen != 0 { + t.Fatalf("unexpected written auth response (%d bytes): %v", + writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response + conn.data = []byte{ + 7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0, // OK + } + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{0} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response (pub key response) + conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...) + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastSHA256PasswordRSA(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{1} + if writtenAuthRespLen != 1 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response (pub key response) + conn.data = append([]byte{byte(1 + len(testPubKey)), 1, 0, 2, 1}, testPubKey...) + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.HasPrefix(conn.written, []byte{0, 1, 0, 3}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthFastSHA256PasswordRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // auth response (OK) + conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } +} + +func TestAuthFastSHA256PasswordSecure(t *testing.T) { + conn, mc := newRWMockConn(1) + mc.cfg.User = "root" + mc.cfg.Passwd = "secret" + + // hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + authData := []byte{6, 81, 96, 114, 14, 42, 50, 30, 76, 47, 1, 95, 126, 81, + 62, 94, 83, 80, 52, 85} + plugin := "sha256_password" + + // send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + t.Fatal(err) + } + + // unset TLS config to prevent the actual establishment of a TLS wrapper + mc.cfg.tls = nil + + err = mc.writeHandshakeResponsePacket(authResp, plugin) + if err != nil { + t.Fatal(err) + } + + // check written auth response + authRespStart := 4 + 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + authRespEnd := authRespStart + 1 + len(authResp) + writtenAuthRespLen := conn.written[authRespStart] + writtenAuthResp := conn.written[authRespStart+1 : authRespEnd] + expectedAuthResp := []byte{115, 101, 99, 114, 101, 116, 0} + if writtenAuthRespLen != 7 || !bytes.Equal(writtenAuthResp, expectedAuthResp) { + t.Fatalf("unexpected written auth response (%d bytes): %v", writtenAuthRespLen, writtenAuthResp) + } + conn.written = nil + + // auth response (OK) + conn.data = []byte{7, 0, 0, 2, 0, 0, 0, 2, 0, 0, 0} + conn.maxReads = 1 + + // Handle response to auth packet + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + if !bytes.Equal(conn.written, []byte{}) { + t.Errorf("unexpected written data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordCached(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{ + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, // OK + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + } + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{0, 0, 0, 3} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordFullRSA(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + conn.queuedReplies = [][]byte{ + // Perform Full Authentication + {2, 0, 0, 4, 1, 4}, + + // Pub Key Response + append([]byte{byte(1 + len(testPubKey)), 1, 0, 6, 1}, testPubKey...), + + // OK + {7, 0, 0, 8, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 4 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + + // 2. Packet: Pub Key Request + 1, 0, 0, 5, 2, + + // 3. Packet: Encrypted Password + 0, 1, 0, 7, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordFullRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + conn.queuedReplies = [][]byte{ + // Perform Full Authentication + {2, 0, 0, 4, 1, 4}, + + // OK + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + + // 2. Packet: Encrypted Password + 0, 1, 0, 5, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCachingSHA256PasswordFullSecure(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // Hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 99, 97, 99, 104, 105, 110, 103, 95, + 115, 104, 97, 50, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 101, + 11, 26, 18, 94, 97, 22, 72, 2, 46, 70, 106, 29, 55, 45, 94, 76, 90, 84, + 50, 0} + + // auth response + conn.queuedReplies = [][]byte{ + {2, 0, 0, 4, 1, 4}, // Perform Full Authentication + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, // OK + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{ + // 1. Packet: Hash + 32, 0, 0, 3, 129, 93, 132, 95, 114, 48, 79, 215, 128, 62, 193, 118, 128, + 54, 75, 208, 159, 252, 227, 215, 129, 15, 242, 97, 19, 159, 31, 20, 58, + 153, 9, 130, + + // 2. Packet: Cleartext password + 7, 0, 0, 5, 115, 101, 99, 114, 101, 116, 0, + } + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCleartextPasswordNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, + 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + conn.maxReads = 1 + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrCleartextPassword { + t.Errorf("expected ErrCleartextPassword, got %v", err) + } +} + +func TestAuthSwitchCleartextPassword(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowCleartextPasswords = true + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, + 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchCleartextPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowCleartextPasswords = true + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{22, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 99, 108, + 101, 97, 114, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{1, 0, 0, 3, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchNativePasswordNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowNativePasswords = false + + conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, + 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, + 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, + 31, 0} + conn.maxReads = 1 + authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, + 48, 31, 89, 39, 55, 31} + plugin := "caching_sha2_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrNativePassword { + t.Errorf("expected ErrNativePassword, got %v", err) + } +} + +func TestAuthSwitchNativePassword(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowNativePasswords = true + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, + 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, + 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, + 31, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, + 48, 31, 89, 39, 55, 31} + plugin := "caching_sha2_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{20, 0, 0, 3, 202, 41, 195, 164, 34, 226, 49, 103, + 21, 211, 167, 199, 227, 116, 8, 48, 57, 71, 149, 146} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchNativePasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowNativePasswords = true + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{44, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 110, 97, + 116, 105, 118, 101, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 96, + 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, 48, 31, 89, 39, 55, + 31, 0} + + // auth response + conn.queuedReplies = [][]byte{{7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{96, 71, 63, 8, 1, 58, 75, 12, 69, 95, 66, 60, 117, 31, + 48, 31, 89, 39, 55, 31} + plugin := "caching_sha2_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{0, 0, 0, 3} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchOldPasswordNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, + 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, + 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + conn.maxReads = 1 + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrOldPassword { + t.Errorf("expected ErrOldPassword, got %v", err) + } +} + +// Same to TestAuthSwitchOldPasswordNotAllowed, but use OldAuthSwitch request. +func TestOldAuthSwitchNotAllowed(t *testing.T) { + conn, mc := newRWMockConn(2) + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + conn.maxReads = 1 + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + err := mc.handleAuthResult(authData, plugin) + if err != ErrOldPassword { + t.Errorf("expected ErrOldPassword, got %v", err) + } +} + +func TestAuthSwitchOldPassword(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, + 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, + 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +// Same to TestAuthSwitchOldPassword, but use OldAuthSwitch request. +func TestOldAuthSwitch(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "secret" + + // OldAuthSwitch request + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{9, 0, 0, 3, 86, 83, 83, 79, 74, 78, 65, 66, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} +func TestAuthSwitchOldPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{41, 0, 0, 2, 254, 109, 121, 115, 113, 108, 95, 111, 108, + 100, 95, 112, 97, 115, 115, 119, 111, 114, 100, 0, 95, 84, 103, 43, 61, + 49, 123, 61, 91, 50, 40, 113, 35, 84, 96, 101, 92, 123, 121, 107, 0} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{1, 0, 0, 3, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +// Same to TestAuthSwitchOldPasswordEmpty, but use OldAuthSwitch request. +func TestOldAuthSwitchPasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.AllowOldPasswords = true + mc.cfg.Passwd = "" + + // OldAuthSwitch request. + conn.data = []byte{1, 0, 0, 2, 0xfe} + + // auth response + conn.queuedReplies = [][]byte{{8, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0, 0}} + conn.maxReads = 2 + + authData := []byte{95, 84, 103, 43, 61, 49, 123, 61, 91, 50, 40, 113, 35, + 84, 96, 101, 92, 123, 121, 107} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReply := []byte{1, 0, 0, 3, 0} + if !bytes.Equal(conn.written, expectedReply) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordEmpty(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "" + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Empty Password + 1, 0, 0, 3, 0, + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordRSA(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // Pub Key Response + append([]byte{byte(1 + len(testPubKey)), 1, 0, 4, 1}, testPubKey...), + + // OK + {7, 0, 0, 6, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 3 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Pub Key Request + 1, 0, 0, 3, 1, + + // 2. Packet: Encrypted Password + 0, 1, 0, 5, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordRSAWithKey(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + mc.cfg.pubKey = testPubKeyRSA + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Encrypted Password + 0, 1, 0, 3, // [changing bytes] + } + if !bytes.HasPrefix(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} + +func TestAuthSwitchSHA256PasswordSecure(t *testing.T) { + conn, mc := newRWMockConn(2) + mc.cfg.Passwd = "secret" + + // Hack to make the caching_sha2_password plugin believe that the connection + // is secure + mc.cfg.tls = &tls.Config{InsecureSkipVerify: true} + + // auth switch request + conn.data = []byte{38, 0, 0, 2, 254, 115, 104, 97, 50, 53, 54, 95, 112, 97, + 115, 115, 119, 111, 114, 100, 0, 78, 82, 62, 40, 100, 1, 59, 31, 44, 69, + 33, 112, 8, 81, 51, 96, 65, 82, 16, 114, 0} + + conn.queuedReplies = [][]byte{ + // OK + {7, 0, 0, 4, 0, 0, 0, 2, 0, 0, 0}, + } + conn.maxReads = 2 + + authData := []byte{123, 87, 15, 84, 20, 58, 37, 121, 91, 117, 51, 24, 19, + 47, 43, 9, 41, 112, 67, 110} + plugin := "mysql_native_password" + + if err := mc.handleAuthResult(authData, plugin); err != nil { + t.Errorf("got error: %v", err) + } + + expectedReplyPrefix := []byte{ + // 1. Packet: Cleartext Password + 7, 0, 0, 3, 115, 101, 99, 114, 101, 116, 0, + } + if !bytes.Equal(conn.written, expectedReplyPrefix) { + t.Errorf("got unexpected data: %v", conn.written) + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/benchmark_test.go b/vendor/github.com/go-sql-driver/mysql/benchmark_test.go new file mode 100644 index 00000000..3e25a3bf --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/benchmark_test.go @@ -0,0 +1,373 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "fmt" + "math" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +type TB testing.B + +func (tb *TB) check(err error) { + if err != nil { + tb.Fatal(err) + } +} + +func (tb *TB) checkDB(db *sql.DB, err error) *sql.DB { + tb.check(err) + return db +} + +func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows { + tb.check(err) + return rows +} + +func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt { + tb.check(err) + return stmt +} + +func initDB(b *testing.B, queries ...string) *sql.DB { + tb := (*TB)(b) + db := tb.checkDB(sql.Open("mysql", dsn)) + for _, query := range queries { + if _, err := db.Exec(query); err != nil { + b.Fatalf("error on %q: %v", query, err) + } + } + return db +} + +const concurrencyLevel = 10 + +func BenchmarkQuery(b *testing.B) { + tb := (*TB)(b) + b.StopTimer() + b.ReportAllocs() + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + db.SetMaxIdleConns(concurrencyLevel) + defer db.Close() + + stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?")) + defer stmt.Close() + + remain := int64(b.N) + var wg sync.WaitGroup + wg.Add(concurrencyLevel) + defer wg.Wait() + b.StartTimer() + + for i := 0; i < concurrencyLevel; i++ { + go func() { + for { + if atomic.AddInt64(&remain, -1) < 0 { + wg.Done() + return + } + + var got string + tb.check(stmt.QueryRow(1).Scan(&got)) + if got != "one" { + b.Errorf("query = %q; want one", got) + wg.Done() + return + } + } + }() + } +} + +func BenchmarkExec(b *testing.B) { + tb := (*TB)(b) + b.StopTimer() + b.ReportAllocs() + db := tb.checkDB(sql.Open("mysql", dsn)) + db.SetMaxIdleConns(concurrencyLevel) + defer db.Close() + + stmt := tb.checkStmt(db.Prepare("DO 1")) + defer stmt.Close() + + remain := int64(b.N) + var wg sync.WaitGroup + wg.Add(concurrencyLevel) + defer wg.Wait() + b.StartTimer() + + for i := 0; i < concurrencyLevel; i++ { + go func() { + for { + if atomic.AddInt64(&remain, -1) < 0 { + wg.Done() + return + } + + if _, err := stmt.Exec(); err != nil { + b.Fatal(err.Error()) + } + } + }() + } +} + +// data, but no db writes +var roundtripSample []byte + +func initRoundtripBenchmarks() ([]byte, int, int) { + if roundtripSample == nil { + roundtripSample = []byte(strings.Repeat("0123456789abcdef", 1024*1024)) + } + return roundtripSample, 16, len(roundtripSample) +} + +func BenchmarkRoundtripTxt(b *testing.B) { + b.StopTimer() + sample, min, max := initRoundtripBenchmarks() + sampleString := string(sample) + b.ReportAllocs() + tb := (*TB)(b) + db := tb.checkDB(sql.Open("mysql", dsn)) + defer db.Close() + b.StartTimer() + var result string + for i := 0; i < b.N; i++ { + length := min + i + if length > max { + length = max + } + test := sampleString[0:length] + rows := tb.checkRows(db.Query(`SELECT "` + test + `"`)) + if !rows.Next() { + rows.Close() + b.Fatalf("crashed") + } + err := rows.Scan(&result) + if err != nil { + rows.Close() + b.Fatalf("crashed") + } + if result != test { + rows.Close() + b.Errorf("mismatch") + } + rows.Close() + } +} + +func BenchmarkRoundtripBin(b *testing.B) { + b.StopTimer() + sample, min, max := initRoundtripBenchmarks() + b.ReportAllocs() + tb := (*TB)(b) + db := tb.checkDB(sql.Open("mysql", dsn)) + defer db.Close() + stmt := tb.checkStmt(db.Prepare("SELECT ?")) + defer stmt.Close() + b.StartTimer() + var result sql.RawBytes + for i := 0; i < b.N; i++ { + length := min + i + if length > max { + length = max + } + test := sample[0:length] + rows := tb.checkRows(stmt.Query(test)) + if !rows.Next() { + rows.Close() + b.Fatalf("crashed") + } + err := rows.Scan(&result) + if err != nil { + rows.Close() + b.Fatalf("crashed") + } + if !bytes.Equal(result, test) { + rows.Close() + b.Errorf("mismatch") + } + rows.Close() + } +} + +func BenchmarkInterpolation(b *testing.B) { + mc := &mysqlConn{ + cfg: &Config{ + InterpolateParams: true, + Loc: time.UTC, + }, + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + buf: newBuffer(nil), + } + + args := []driver.Value{ + int64(42424242), + float64(math.Pi), + false, + time.Unix(1423411542, 807015000), + []byte("bytes containing special chars ' \" \a \x00"), + "string containing special chars ' \" \a \x00", + } + q := "SELECT ?, ?, ?, ?, ?, ?" + + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, err := mc.interpolateParams(q, args) + if err != nil { + b.Fatal(err) + } + } +} + +func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0)) + + tb := (*TB)(b) + stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?")) + defer stmt.Close() + + b.SetParallelism(p) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + var got string + for pb.Next() { + tb.check(stmt.QueryRow(1).Scan(&got)) + if got != "one" { + b.Fatalf("query = %q; want one", got) + } + } + }) +} + +func BenchmarkQueryContext(b *testing.B) { + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + defer db.Close() + for _, p := range []int{1, 2, 3, 4} { + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + benchmarkQueryContext(b, db, p) + }) + } +} + +func benchmarkExecContext(b *testing.B, db *sql.DB, p int) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0)) + + tb := (*TB)(b) + stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1")) + defer stmt.Close() + + b.SetParallelism(p) + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if _, err := stmt.ExecContext(ctx); err != nil { + b.Fatal(err) + } + } + }) +} + +func BenchmarkExecContext(b *testing.B) { + db := initDB(b, + "DROP TABLE IF EXISTS foo", + "CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))", + `INSERT INTO foo VALUES (1, "one")`, + `INSERT INTO foo VALUES (2, "two")`, + ) + defer db.Close() + for _, p := range []int{1, 2, 3, 4} { + b.Run(fmt.Sprintf("%d", p), func(b *testing.B) { + benchmarkQueryContext(b, db, p) + }) + } +} + +// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes. +// "size=" means size of each blobs. +func BenchmarkQueryRawBytes(b *testing.B) { + var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000} + db := initDB(b, + "DROP TABLE IF EXISTS bench_rawbytes", + "CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)", + ) + defer db.Close() + + blob := make([]byte, sizes[len(sizes)-1]) + for i := range blob { + blob[i] = 42 + } + for i := 0; i < 100; i++ { + _, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob) + if err != nil { + b.Fatal(err) + } + } + + for _, s := range sizes { + b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) { + db.SetMaxIdleConns(0) + db.SetMaxIdleConns(1) + b.ReportAllocs() + b.ResetTimer() + + for j := 0; j < b.N; j++ { + rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s) + if err != nil { + b.Fatal(err) + } + nrows := 0 + for rows.Next() { + var buf sql.RawBytes + err := rows.Scan(&buf) + if err != nil { + b.Fatal(err) + } + if len(buf) != s { + b.Fatalf("size mismatch: expected %v, got %v", s, len(buf)) + } + nrows++ + } + rows.Close() + if nrows != 100 { + b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows) + } + } + }) + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/buffer.go b/vendor/github.com/go-sql-driver/mysql/buffer.go index eb4748bf..0774c5c8 100644 --- a/vendor/github.com/go-sql-driver/mysql/buffer.go +++ b/vendor/github.com/go-sql-driver/mysql/buffer.go @@ -15,47 +15,69 @@ import ( ) const defaultBufSize = 4096 +const maxCachedBufSize = 256 * 1024 // A buffer which is used for both reading and writing. // This is possible since communication on each connection is synchronous. // In other words, we can't write and read simultaneously on the same connection. // The buffer is similar to bufio.Reader / Writer but zero-copy-ish // Also highly optimized for this particular use case. +// This buffer is backed by two byte slices in a double-buffering scheme type buffer struct { - buf []byte + buf []byte // buf is a byte buffer who's length and capacity are equal. nc net.Conn idx int length int timeout time.Duration + dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer + flipcnt uint // flipccnt is the current buffer counter for double-buffering } +// newBuffer allocates and returns a new buffer. func newBuffer(nc net.Conn) buffer { - var b [defaultBufSize]byte + fg := make([]byte, defaultBufSize) return buffer{ - buf: b[:], - nc: nc, + buf: fg, + nc: nc, + dbuf: [2][]byte{fg, nil}, } } +// flip replaces the active buffer with the background buffer +// this is a delayed flip that simply increases the buffer counter; +// the actual flip will be performed the next time we call `buffer.fill` +func (b *buffer) flip() { + b.flipcnt += 1 +} + // fill reads into the buffer until at least _need_ bytes are in it func (b *buffer) fill(need int) error { n := b.length + // fill data into its double-buffering target: if we've called + // flip on this buffer, we'll be copying to the background buffer, + // and then filling it with network data; otherwise we'll just move + // the contents of the current buffer to the front before filling it + dest := b.dbuf[b.flipcnt&1] - // move existing data to the beginning - if n > 0 && b.idx > 0 { - copy(b.buf[0:n], b.buf[b.idx:]) - } - - // grow buffer if necessary - // TODO: let the buffer shrink again at some point - // Maybe keep the org buf slice and swap back? - if need > len(b.buf) { + // grow buffer if necessary to fit the whole packet. + if need > len(dest) { // Round up to the next multiple of the default size - newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) - copy(newBuf, b.buf) - b.buf = newBuf + dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize) + + // if the allocated buffer is not too large, move it to backing storage + // to prevent extra allocations on applications that perform large reads + if len(dest) <= maxCachedBufSize { + b.dbuf[b.flipcnt&1] = dest + } } + // if we're filling the fg buffer, move the existing data to the start of it. + // if we're filling the bg buffer, copy over the data + if n > 0 { + copy(dest[:n], b.buf[b.idx:]) + } + + b.buf = dest b.idx = 0 for { @@ -105,43 +127,56 @@ func (b *buffer) readNext(need int) ([]byte, error) { return b.buf[offset:b.idx], nil } -// returns a buffer with the requested size. +// takeBuffer returns a buffer with the requested size. // If possible, a slice from the existing buffer is returned. // Otherwise a bigger buffer is made. // Only one buffer (total) can be used at a time. -func (b *buffer) takeBuffer(length int) []byte { +func (b *buffer) takeBuffer(length int) ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } // test (cheap) general case first - if length <= defaultBufSize || length <= cap(b.buf) { - return b.buf[:length] + if length <= cap(b.buf) { + return b.buf[:length], nil } if length < maxPacketSize { b.buf = make([]byte, length) - return b.buf + return b.buf, nil } - return make([]byte, length) + + // buffer is larger than we want to store. + return make([]byte, length), nil } -// shortcut which can be used if the requested buffer is guaranteed to be -// smaller than defaultBufSize +// takeSmallBuffer is shortcut which can be used if length is +// known to be smaller than defaultBufSize. // Only one buffer (total) can be used at a time. -func (b *buffer) takeSmallBuffer(length int) []byte { +func (b *buffer) takeSmallBuffer(length int) ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } - return b.buf[:length] + return b.buf[:length], nil } // takeCompleteBuffer returns the complete existing buffer. // This can be used if the necessary buffer size is unknown. +// cap and len of the returned buffer will be equal. // Only one buffer (total) can be used at a time. -func (b *buffer) takeCompleteBuffer() []byte { +func (b *buffer) takeCompleteBuffer() ([]byte, error) { if b.length > 0 { - return nil + return nil, ErrBusyBuffer } - return b.buf + return b.buf, nil +} + +// store stores buf, an updated buffer, if its suitable to do so. +func (b *buffer) store(buf []byte) error { + if b.length > 0 { + return ErrBusyBuffer + } else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) { + b.buf = buf[:cap(buf)] + } + return nil } diff --git a/vendor/github.com/go-sql-driver/mysql/collations.go b/vendor/github.com/go-sql-driver/mysql/collations.go index 136c9e4d..8d2b5567 100644 --- a/vendor/github.com/go-sql-driver/mysql/collations.go +++ b/vendor/github.com/go-sql-driver/mysql/collations.go @@ -8,183 +8,190 @@ package mysql -const defaultCollation = "utf8_general_ci" +const defaultCollation = "utf8mb4_general_ci" const binaryCollation = "binary" // A list of available collations mapped to the internal ID. // To update this map use the following MySQL query: -// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS +// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID +// +// Handshake packet have only 1 byte for collation_id. So we can't use collations with ID > 255. +// +// ucs2, utf16, and utf32 can't be used for connection charset. +// https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset +// They are commented out to reduce this map. var collations = map[string]byte{ - "big5_chinese_ci": 1, - "latin2_czech_cs": 2, - "dec8_swedish_ci": 3, - "cp850_general_ci": 4, - "latin1_german1_ci": 5, - "hp8_english_ci": 6, - "koi8r_general_ci": 7, - "latin1_swedish_ci": 8, - "latin2_general_ci": 9, - "swe7_swedish_ci": 10, - "ascii_general_ci": 11, - "ujis_japanese_ci": 12, - "sjis_japanese_ci": 13, - "cp1251_bulgarian_ci": 14, - "latin1_danish_ci": 15, - "hebrew_general_ci": 16, - "tis620_thai_ci": 18, - "euckr_korean_ci": 19, - "latin7_estonian_cs": 20, - "latin2_hungarian_ci": 21, - "koi8u_general_ci": 22, - "cp1251_ukrainian_ci": 23, - "gb2312_chinese_ci": 24, - "greek_general_ci": 25, - "cp1250_general_ci": 26, - "latin2_croatian_ci": 27, - "gbk_chinese_ci": 28, - "cp1257_lithuanian_ci": 29, - "latin5_turkish_ci": 30, - "latin1_german2_ci": 31, - "armscii8_general_ci": 32, - "utf8_general_ci": 33, - "cp1250_czech_cs": 34, - "ucs2_general_ci": 35, - "cp866_general_ci": 36, - "keybcs2_general_ci": 37, - "macce_general_ci": 38, - "macroman_general_ci": 39, - "cp852_general_ci": 40, - "latin7_general_ci": 41, - "latin7_general_cs": 42, - "macce_bin": 43, - "cp1250_croatian_ci": 44, - "utf8mb4_general_ci": 45, - "utf8mb4_bin": 46, - "latin1_bin": 47, - "latin1_general_ci": 48, - "latin1_general_cs": 49, - "cp1251_bin": 50, - "cp1251_general_ci": 51, - "cp1251_general_cs": 52, - "macroman_bin": 53, - "utf16_general_ci": 54, - "utf16_bin": 55, - "utf16le_general_ci": 56, - "cp1256_general_ci": 57, - "cp1257_bin": 58, - "cp1257_general_ci": 59, - "utf32_general_ci": 60, - "utf32_bin": 61, - "utf16le_bin": 62, - "binary": 63, - "armscii8_bin": 64, - "ascii_bin": 65, - "cp1250_bin": 66, - "cp1256_bin": 67, - "cp866_bin": 68, - "dec8_bin": 69, - "greek_bin": 70, - "hebrew_bin": 71, - "hp8_bin": 72, - "keybcs2_bin": 73, - "koi8r_bin": 74, - "koi8u_bin": 75, - "latin2_bin": 77, - "latin5_bin": 78, - "latin7_bin": 79, - "cp850_bin": 80, - "cp852_bin": 81, - "swe7_bin": 82, - "utf8_bin": 83, - "big5_bin": 84, - "euckr_bin": 85, - "gb2312_bin": 86, - "gbk_bin": 87, - "sjis_bin": 88, - "tis620_bin": 89, - "ucs2_bin": 90, - "ujis_bin": 91, - "geostd8_general_ci": 92, - "geostd8_bin": 93, - "latin1_spanish_ci": 94, - "cp932_japanese_ci": 95, - "cp932_bin": 96, - "eucjpms_japanese_ci": 97, - "eucjpms_bin": 98, - "cp1250_polish_ci": 99, - "utf16_unicode_ci": 101, - "utf16_icelandic_ci": 102, - "utf16_latvian_ci": 103, - "utf16_romanian_ci": 104, - "utf16_slovenian_ci": 105, - "utf16_polish_ci": 106, - "utf16_estonian_ci": 107, - "utf16_spanish_ci": 108, - "utf16_swedish_ci": 109, - "utf16_turkish_ci": 110, - "utf16_czech_ci": 111, - "utf16_danish_ci": 112, - "utf16_lithuanian_ci": 113, - "utf16_slovak_ci": 114, - "utf16_spanish2_ci": 115, - "utf16_roman_ci": 116, - "utf16_persian_ci": 117, - "utf16_esperanto_ci": 118, - "utf16_hungarian_ci": 119, - "utf16_sinhala_ci": 120, - "utf16_german2_ci": 121, - "utf16_croatian_ci": 122, - "utf16_unicode_520_ci": 123, - "utf16_vietnamese_ci": 124, - "ucs2_unicode_ci": 128, - "ucs2_icelandic_ci": 129, - "ucs2_latvian_ci": 130, - "ucs2_romanian_ci": 131, - "ucs2_slovenian_ci": 132, - "ucs2_polish_ci": 133, - "ucs2_estonian_ci": 134, - "ucs2_spanish_ci": 135, - "ucs2_swedish_ci": 136, - "ucs2_turkish_ci": 137, - "ucs2_czech_ci": 138, - "ucs2_danish_ci": 139, - "ucs2_lithuanian_ci": 140, - "ucs2_slovak_ci": 141, - "ucs2_spanish2_ci": 142, - "ucs2_roman_ci": 143, - "ucs2_persian_ci": 144, - "ucs2_esperanto_ci": 145, - "ucs2_hungarian_ci": 146, - "ucs2_sinhala_ci": 147, - "ucs2_german2_ci": 148, - "ucs2_croatian_ci": 149, - "ucs2_unicode_520_ci": 150, - "ucs2_vietnamese_ci": 151, - "ucs2_general_mysql500_ci": 159, - "utf32_unicode_ci": 160, - "utf32_icelandic_ci": 161, - "utf32_latvian_ci": 162, - "utf32_romanian_ci": 163, - "utf32_slovenian_ci": 164, - "utf32_polish_ci": 165, - "utf32_estonian_ci": 166, - "utf32_spanish_ci": 167, - "utf32_swedish_ci": 168, - "utf32_turkish_ci": 169, - "utf32_czech_ci": 170, - "utf32_danish_ci": 171, - "utf32_lithuanian_ci": 172, - "utf32_slovak_ci": 173, - "utf32_spanish2_ci": 174, - "utf32_roman_ci": 175, - "utf32_persian_ci": 176, - "utf32_esperanto_ci": 177, - "utf32_hungarian_ci": 178, - "utf32_sinhala_ci": 179, - "utf32_german2_ci": 180, - "utf32_croatian_ci": 181, - "utf32_unicode_520_ci": 182, - "utf32_vietnamese_ci": 183, + "big5_chinese_ci": 1, + "latin2_czech_cs": 2, + "dec8_swedish_ci": 3, + "cp850_general_ci": 4, + "latin1_german1_ci": 5, + "hp8_english_ci": 6, + "koi8r_general_ci": 7, + "latin1_swedish_ci": 8, + "latin2_general_ci": 9, + "swe7_swedish_ci": 10, + "ascii_general_ci": 11, + "ujis_japanese_ci": 12, + "sjis_japanese_ci": 13, + "cp1251_bulgarian_ci": 14, + "latin1_danish_ci": 15, + "hebrew_general_ci": 16, + "tis620_thai_ci": 18, + "euckr_korean_ci": 19, + "latin7_estonian_cs": 20, + "latin2_hungarian_ci": 21, + "koi8u_general_ci": 22, + "cp1251_ukrainian_ci": 23, + "gb2312_chinese_ci": 24, + "greek_general_ci": 25, + "cp1250_general_ci": 26, + "latin2_croatian_ci": 27, + "gbk_chinese_ci": 28, + "cp1257_lithuanian_ci": 29, + "latin5_turkish_ci": 30, + "latin1_german2_ci": 31, + "armscii8_general_ci": 32, + "utf8_general_ci": 33, + "cp1250_czech_cs": 34, + //"ucs2_general_ci": 35, + "cp866_general_ci": 36, + "keybcs2_general_ci": 37, + "macce_general_ci": 38, + "macroman_general_ci": 39, + "cp852_general_ci": 40, + "latin7_general_ci": 41, + "latin7_general_cs": 42, + "macce_bin": 43, + "cp1250_croatian_ci": 44, + "utf8mb4_general_ci": 45, + "utf8mb4_bin": 46, + "latin1_bin": 47, + "latin1_general_ci": 48, + "latin1_general_cs": 49, + "cp1251_bin": 50, + "cp1251_general_ci": 51, + "cp1251_general_cs": 52, + "macroman_bin": 53, + //"utf16_general_ci": 54, + //"utf16_bin": 55, + //"utf16le_general_ci": 56, + "cp1256_general_ci": 57, + "cp1257_bin": 58, + "cp1257_general_ci": 59, + //"utf32_general_ci": 60, + //"utf32_bin": 61, + //"utf16le_bin": 62, + "binary": 63, + "armscii8_bin": 64, + "ascii_bin": 65, + "cp1250_bin": 66, + "cp1256_bin": 67, + "cp866_bin": 68, + "dec8_bin": 69, + "greek_bin": 70, + "hebrew_bin": 71, + "hp8_bin": 72, + "keybcs2_bin": 73, + "koi8r_bin": 74, + "koi8u_bin": 75, + "utf8_tolower_ci": 76, + "latin2_bin": 77, + "latin5_bin": 78, + "latin7_bin": 79, + "cp850_bin": 80, + "cp852_bin": 81, + "swe7_bin": 82, + "utf8_bin": 83, + "big5_bin": 84, + "euckr_bin": 85, + "gb2312_bin": 86, + "gbk_bin": 87, + "sjis_bin": 88, + "tis620_bin": 89, + //"ucs2_bin": 90, + "ujis_bin": 91, + "geostd8_general_ci": 92, + "geostd8_bin": 93, + "latin1_spanish_ci": 94, + "cp932_japanese_ci": 95, + "cp932_bin": 96, + "eucjpms_japanese_ci": 97, + "eucjpms_bin": 98, + "cp1250_polish_ci": 99, + //"utf16_unicode_ci": 101, + //"utf16_icelandic_ci": 102, + //"utf16_latvian_ci": 103, + //"utf16_romanian_ci": 104, + //"utf16_slovenian_ci": 105, + //"utf16_polish_ci": 106, + //"utf16_estonian_ci": 107, + //"utf16_spanish_ci": 108, + //"utf16_swedish_ci": 109, + //"utf16_turkish_ci": 110, + //"utf16_czech_ci": 111, + //"utf16_danish_ci": 112, + //"utf16_lithuanian_ci": 113, + //"utf16_slovak_ci": 114, + //"utf16_spanish2_ci": 115, + //"utf16_roman_ci": 116, + //"utf16_persian_ci": 117, + //"utf16_esperanto_ci": 118, + //"utf16_hungarian_ci": 119, + //"utf16_sinhala_ci": 120, + //"utf16_german2_ci": 121, + //"utf16_croatian_ci": 122, + //"utf16_unicode_520_ci": 123, + //"utf16_vietnamese_ci": 124, + //"ucs2_unicode_ci": 128, + //"ucs2_icelandic_ci": 129, + //"ucs2_latvian_ci": 130, + //"ucs2_romanian_ci": 131, + //"ucs2_slovenian_ci": 132, + //"ucs2_polish_ci": 133, + //"ucs2_estonian_ci": 134, + //"ucs2_spanish_ci": 135, + //"ucs2_swedish_ci": 136, + //"ucs2_turkish_ci": 137, + //"ucs2_czech_ci": 138, + //"ucs2_danish_ci": 139, + //"ucs2_lithuanian_ci": 140, + //"ucs2_slovak_ci": 141, + //"ucs2_spanish2_ci": 142, + //"ucs2_roman_ci": 143, + //"ucs2_persian_ci": 144, + //"ucs2_esperanto_ci": 145, + //"ucs2_hungarian_ci": 146, + //"ucs2_sinhala_ci": 147, + //"ucs2_german2_ci": 148, + //"ucs2_croatian_ci": 149, + //"ucs2_unicode_520_ci": 150, + //"ucs2_vietnamese_ci": 151, + //"ucs2_general_mysql500_ci": 159, + //"utf32_unicode_ci": 160, + //"utf32_icelandic_ci": 161, + //"utf32_latvian_ci": 162, + //"utf32_romanian_ci": 163, + //"utf32_slovenian_ci": 164, + //"utf32_polish_ci": 165, + //"utf32_estonian_ci": 166, + //"utf32_spanish_ci": 167, + //"utf32_swedish_ci": 168, + //"utf32_turkish_ci": 169, + //"utf32_czech_ci": 170, + //"utf32_danish_ci": 171, + //"utf32_lithuanian_ci": 172, + //"utf32_slovak_ci": 173, + //"utf32_spanish2_ci": 174, + //"utf32_roman_ci": 175, + //"utf32_persian_ci": 176, + //"utf32_esperanto_ci": 177, + //"utf32_hungarian_ci": 178, + //"utf32_sinhala_ci": 179, + //"utf32_german2_ci": 180, + //"utf32_croatian_ci": 181, + //"utf32_unicode_520_ci": 182, + //"utf32_vietnamese_ci": 183, "utf8_unicode_ci": 192, "utf8_icelandic_ci": 193, "utf8_latvian_ci": 194, @@ -234,18 +241,25 @@ var collations = map[string]byte{ "utf8mb4_croatian_ci": 245, "utf8mb4_unicode_520_ci": 246, "utf8mb4_vietnamese_ci": 247, + "gb18030_chinese_ci": 248, + "gb18030_bin": 249, + "gb18030_unicode_520_ci": 250, + "utf8mb4_0900_ai_ci": 255, } // A blacklist of collations which is unsafe to interpolate parameters. // These multibyte encodings may contains 0x5c (`\`) in their trailing bytes. var unsafeCollations = map[string]bool{ - "big5_chinese_ci": true, - "sjis_japanese_ci": true, - "gbk_chinese_ci": true, - "big5_bin": true, - "gb2312_bin": true, - "gbk_bin": true, - "sjis_bin": true, - "cp932_japanese_ci": true, - "cp932_bin": true, + "big5_chinese_ci": true, + "sjis_japanese_ci": true, + "gbk_chinese_ci": true, + "big5_bin": true, + "gb2312_bin": true, + "gbk_bin": true, + "sjis_bin": true, + "cp932_japanese_ci": true, + "cp932_bin": true, + "gb18030_chinese_ci": true, + "gb18030_bin": true, + "gb18030_unicode_520_ci": true, } diff --git a/vendor/github.com/go-sql-driver/mysql/conncheck.go b/vendor/github.com/go-sql-driver/mysql/conncheck.go new file mode 100644 index 00000000..cc47aa55 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/conncheck.go @@ -0,0 +1,53 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build !windows,!appengine + +package mysql + +import ( + "errors" + "io" + "net" + "syscall" +) + +var errUnexpectedRead = errors.New("unexpected read from socket") + +func connCheck(c net.Conn) error { + var ( + n int + err error + buff [1]byte + ) + + sconn, ok := c.(syscall.Conn) + if !ok { + return nil + } + rc, err := sconn.SyscallConn() + if err != nil { + return err + } + rerr := rc.Read(func(fd uintptr) bool { + n, err = syscall.Read(int(fd), buff[:]) + return true + }) + switch { + case rerr != nil: + return rerr + case n == 0 && err == nil: + return io.EOF + case n > 0: + return errUnexpectedRead + case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK: + return nil + default: + return err + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/conncheck_dummy.go b/vendor/github.com/go-sql-driver/mysql/conncheck_dummy.go new file mode 100644 index 00000000..fd01f64c --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/conncheck_dummy.go @@ -0,0 +1,17 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build windows appengine + +package mysql + +import "net" + +func connCheck(c net.Conn) error { + return nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/conncheck_test.go b/vendor/github.com/go-sql-driver/mysql/conncheck_test.go new file mode 100644 index 00000000..b7234b0f --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/conncheck_test.go @@ -0,0 +1,38 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.10,!windows + +package mysql + +import ( + "testing" + "time" +) + +func TestStaleConnectionChecks(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("SET @@SESSION.wait_timeout = 2") + + if err := dbt.db.Ping(); err != nil { + dbt.Fatal(err) + } + + // wait for MySQL to close our connection + time.Sleep(3 * time.Second) + + tx, err := dbt.db.Begin() + if err != nil { + dbt.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + dbt.Fatal(err) + } + }) +} diff --git a/vendor/github.com/go-sql-driver/mysql/connection.go b/vendor/github.com/go-sql-driver/mysql/connection.go index e5706141..565a5480 100644 --- a/vendor/github.com/go-sql-driver/mysql/connection.go +++ b/vendor/github.com/go-sql-driver/mysql/connection.go @@ -9,6 +9,8 @@ package mysql import ( + "context" + "database/sql" "database/sql/driver" "io" "net" @@ -17,19 +19,10 @@ import ( "time" ) -// a copy of context.Context for Go 1.7 and earlier -type mysqlContext interface { - Done() <-chan struct{} - Err() error - - // defined in context.Context, but not used in this driver: - // Deadline() (deadline time.Time, ok bool) - // Value(key interface{}) interface{} -} - type mysqlConn struct { buf buffer netConn net.Conn + rawConn net.Conn // underlying connection when netConn is TLS connection. affectedRows uint64 insertId uint64 cfg *Config @@ -40,10 +33,11 @@ type mysqlConn struct { status statusFlag sequence uint8 parseTime bool + reset bool // set when the Go SQL package calls ResetSession // for context support (Go 1.8+) watching bool - watcher chan<- mysqlContext + watcher chan<- context.Context closech chan struct{} finished chan<- struct{} canceled atomicError // set non-nil if conn is canceled @@ -190,10 +184,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin return "", driver.ErrSkip } - buf := mc.buf.takeCompleteBuffer() - if buf == nil { + buf, err := mc.buf.takeCompleteBuffer() + if err != nil { // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return "", ErrInvalidConn } buf = buf[:0] @@ -219,6 +213,9 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin switch v := arg.(type) { case int64: buf = strconv.AppendInt(buf, v, 10) + case uint64: + // Handle uint64 explicitly because our custom ConvertValue emits unsigned values + buf = strconv.AppendUint(buf, v, 10) case float64: buf = strconv.AppendFloat(buf, v, 'g', -1, 64) case bool: @@ -459,3 +456,194 @@ func (mc *mysqlConn) finish() { case <-mc.closech: } } + +// Ping implements driver.Pinger interface +func (mc *mysqlConn) Ping(ctx context.Context) (err error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + if err = mc.watchCancel(ctx); err != nil { + return + } + defer mc.finish() + + if err = mc.writeCommandPacket(comPing); err != nil { + return mc.markBadConn(err) + } + + return mc.readResultOK() +} + +// BeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { + level, err := mapIsolationLevel(opts.Isolation) + if err != nil { + return nil, err + } + err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) + if err != nil { + return nil, err + } + } + + return mc.begin(opts.ReadOnly) +} + +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := mc.query(query, dargs) + if err != nil { + mc.finish() + return nil, err + } + rows.finish = mc.finish + return rows, err +} + +func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + return mc.Exec(query, dargs) +} + +func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + stmt, err := mc.Prepare(query) + mc.finish() + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + stmt.Close() + return nil, ctx.Err() + } + return stmt, nil +} + +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := stmt.query(dargs) + if err != nil { + stmt.mc.finish() + return nil, err + } + rows.finish = stmt.mc.finish + return rows, err +} + +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + defer stmt.mc.finish() + + return stmt.Exec(dargs) +} + +func (mc *mysqlConn) watchCancel(ctx context.Context) error { + if mc.watching { + // Reach here if canceled, + // so the connection is already invalid + mc.cleanup() + return nil + } + // When ctx is already cancelled, don't watch it. + if err := ctx.Err(); err != nil { + return err + } + // When ctx is not cancellable, don't watch it. + if ctx.Done() == nil { + return nil + } + // When watcher is not alive, can't watch it. + if mc.watcher == nil { + return nil + } + + mc.watching = true + mc.watcher <- ctx + return nil +} + +func (mc *mysqlConn) startWatcher() { + watcher := make(chan context.Context, 1) + mc.watcher = watcher + finished := make(chan struct{}) + mc.finished = finished + go func() { + for { + var ctx context.Context + select { + case ctx = <-watcher: + case <-mc.closech: + return + } + + select { + case <-ctx.Done(): + mc.cancel(ctx.Err()) + case <-finished: + case <-mc.closech: + return + } + } + }() +} + +func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} + +// ResetSession implements driver.SessionResetter. +// (From Go 1.10) +func (mc *mysqlConn) ResetSession(ctx context.Context) error { + if mc.closed.IsSet() { + return driver.ErrBadConn + } + mc.reset = true + return nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/connection_go18.go b/vendor/github.com/go-sql-driver/mysql/connection_go18.go deleted file mode 100644 index 62796bfc..00000000 --- a/vendor/github.com/go-sql-driver/mysql/connection_go18.go +++ /dev/null @@ -1,208 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// +build go1.8 - -package mysql - -import ( - "context" - "database/sql" - "database/sql/driver" -) - -// Ping implements driver.Pinger interface -func (mc *mysqlConn) Ping(ctx context.Context) (err error) { - if mc.closed.IsSet() { - errLog.Print(ErrInvalidConn) - return driver.ErrBadConn - } - - if err = mc.watchCancel(ctx); err != nil { - return - } - defer mc.finish() - - if err = mc.writeCommandPacket(comPing); err != nil { - return - } - - return mc.readResultOK() -} - -// BeginTx implements driver.ConnBeginTx interface -func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { - if err := mc.watchCancel(ctx); err != nil { - return nil, err - } - defer mc.finish() - - if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { - level, err := mapIsolationLevel(opts.Isolation) - if err != nil { - return nil, err - } - err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) - if err != nil { - return nil, err - } - } - - return mc.begin(opts.ReadOnly) -} - -func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { - dargs, err := namedValueToValue(args) - if err != nil { - return nil, err - } - - if err := mc.watchCancel(ctx); err != nil { - return nil, err - } - - rows, err := mc.query(query, dargs) - if err != nil { - mc.finish() - return nil, err - } - rows.finish = mc.finish - return rows, err -} - -func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { - dargs, err := namedValueToValue(args) - if err != nil { - return nil, err - } - - if err := mc.watchCancel(ctx); err != nil { - return nil, err - } - defer mc.finish() - - return mc.Exec(query, dargs) -} - -func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { - if err := mc.watchCancel(ctx); err != nil { - return nil, err - } - - stmt, err := mc.Prepare(query) - mc.finish() - if err != nil { - return nil, err - } - - select { - default: - case <-ctx.Done(): - stmt.Close() - return nil, ctx.Err() - } - return stmt, nil -} - -func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { - dargs, err := namedValueToValue(args) - if err != nil { - return nil, err - } - - if err := stmt.mc.watchCancel(ctx); err != nil { - return nil, err - } - - rows, err := stmt.query(dargs) - if err != nil { - stmt.mc.finish() - return nil, err - } - rows.finish = stmt.mc.finish - return rows, err -} - -func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { - dargs, err := namedValueToValue(args) - if err != nil { - return nil, err - } - - if err := stmt.mc.watchCancel(ctx); err != nil { - return nil, err - } - defer stmt.mc.finish() - - return stmt.Exec(dargs) -} - -func (mc *mysqlConn) watchCancel(ctx context.Context) error { - if mc.watching { - // Reach here if canceled, - // so the connection is already invalid - mc.cleanup() - return nil - } - if ctx.Done() == nil { - return nil - } - - mc.watching = true - select { - default: - case <-ctx.Done(): - return ctx.Err() - } - if mc.watcher == nil { - return nil - } - - mc.watcher <- ctx - - return nil -} - -func (mc *mysqlConn) startWatcher() { - watcher := make(chan mysqlContext, 1) - mc.watcher = watcher - finished := make(chan struct{}) - mc.finished = finished - go func() { - for { - var ctx mysqlContext - select { - case ctx = <-watcher: - case <-mc.closech: - return - } - - select { - case <-ctx.Done(): - mc.cancel(ctx.Err()) - case <-finished: - case <-mc.closech: - return - } - } - }() -} - -func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { - nv.Value, err = converter{}.ConvertValue(nv.Value) - return -} - -// ResetSession implements driver.SessionResetter. -// (From Go 1.10) -func (mc *mysqlConn) ResetSession(ctx context.Context) error { - if mc.closed.IsSet() { - return driver.ErrBadConn - } - return nil -} diff --git a/vendor/github.com/go-sql-driver/mysql/connection_test.go b/vendor/github.com/go-sql-driver/mysql/connection_test.go new file mode 100644 index 00000000..19c17ff8 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/connection_test.go @@ -0,0 +1,175 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "context" + "database/sql/driver" + "errors" + "net" + "testing" +) + +func TestInterpolateParams(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"}) + if err != nil { + t.Errorf("Expected err=nil, got %#v", err) + return + } + expected := `SELECT 42+'gopher'` + if q != expected { + t.Errorf("Expected: %q\nGot: %q", expected, q) + } +} + +func TestInterpolateParamsTooManyPlaceholders(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)}) + if err != driver.ErrSkip { + t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q) + } +} + +// We don't support placeholder in string literal for now. +// https://github.com/go-sql-driver/mysql/pull/490 +func TestInterpolateParamsPlaceholderInString(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)}) + // When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42` + if err != driver.ErrSkip { + t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q) + } +} + +func TestInterpolateParamsUint64(t *testing.T) { + mc := &mysqlConn{ + buf: newBuffer(nil), + maxAllowedPacket: maxPacketSize, + cfg: &Config{ + InterpolateParams: true, + }, + } + + q, err := mc.interpolateParams("SELECT ?", []driver.Value{uint64(42)}) + if err != nil { + t.Errorf("Expected err=nil, got err=%#v, q=%#v", err, q) + } + if q != "SELECT 42" { + t.Errorf("Expected uint64 interpolation to work, got q=%#v", q) + } +} + +func TestCheckNamedValue(t *testing.T) { + value := driver.NamedValue{Value: ^uint64(0)} + x := &mysqlConn{} + err := x.CheckNamedValue(&value) + + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if value.Value != ^uint64(0) { + t.Fatalf("uint64 high-bit converted, got %#v %T", value.Value, value.Value) + } +} + +// TestCleanCancel tests passed context is cancelled at start. +// No packet should be sent. Connection should keep current status. +func TestCleanCancel(t *testing.T) { + mc := &mysqlConn{ + closech: make(chan struct{}), + } + mc.startWatcher() + defer mc.cleanup() + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + for i := 0; i < 3; i++ { // Repeat same behavior + err := mc.Ping(ctx) + if err != context.Canceled { + t.Errorf("expected context.Canceled, got %#v", err) + } + + if mc.closed.IsSet() { + t.Error("expected mc is not closed, closed actually") + } + + if mc.watching { + t.Error("expected watching is false, but true") + } + } +} + +func TestPingMarkBadConnection(t *testing.T) { + nc := badConnection{err: errors.New("boom")} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + } + + err := ms.Ping(context.Background()) + + if err != driver.ErrBadConn { + t.Errorf("expected driver.ErrBadConn, got %#v", err) + } +} + +func TestPingErrInvalidConn(t *testing.T) { + nc := badConnection{err: errors.New("failed to write"), n: 10} + ms := &mysqlConn{ + netConn: nc, + buf: newBuffer(nc), + maxAllowedPacket: defaultMaxAllowedPacket, + closech: make(chan struct{}), + } + + err := ms.Ping(context.Background()) + + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %#v", err) + } +} + +type badConnection struct { + n int + err error + net.Conn +} + +func (bc badConnection) Write(b []byte) (n int, err error) { + return bc.n, bc.err +} + +func (bc badConnection) Close() error { + return nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/connector.go b/vendor/github.com/go-sql-driver/mysql/connector.go new file mode 100644 index 00000000..5aaaba43 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/connector.go @@ -0,0 +1,143 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "context" + "database/sql/driver" + "net" +) + +type connector struct { + cfg *Config // immutable private copy. +} + +// Connect implements driver.Connector interface. +// Connect returns a connection to the database. +func (c *connector) Connect(ctx context.Context) (driver.Conn, error) { + var err error + + // New mysqlConn + mc := &mysqlConn{ + maxAllowedPacket: maxPacketSize, + maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), + cfg: c.cfg, + } + mc.parseTime = mc.cfg.ParseTime + + // Connect to Server + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { + mc.netConn, err = dial(ctx, mc.cfg.Addr) + } else { + nd := net.Dialer{Timeout: mc.cfg.Timeout} + mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr) + } + + if err != nil { + if nerr, ok := err.(net.Error); ok && nerr.Temporary() { + errLog.Print("net.Error from Dial()': ", nerr.Error()) + return nil, driver.ErrBadConn + } + return nil, err + } + + // Enable TCP Keepalives on TCP connections + if tc, ok := mc.netConn.(*net.TCPConn); ok { + if err := tc.SetKeepAlive(true); err != nil { + // Don't send COM_QUIT before handshake. + mc.netConn.Close() + mc.netConn = nil + return nil, err + } + } + + // Call startWatcher for context support (From Go 1.8) + mc.startWatcher() + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + mc.buf = newBuffer(mc.netConn) + + // Set I/O timeouts + mc.buf.timeout = mc.cfg.ReadTimeout + mc.writeTimeout = mc.cfg.WriteTimeout + + // Reading Handshake Initialization Packet + authData, plugin, err := mc.readHandshakePacket() + if err != nil { + mc.cleanup() + return nil, err + } + + if plugin == "" { + plugin = defaultAuthPlugin + } + + // Send Client Authentication Packet + authResp, err := mc.auth(authData, plugin) + if err != nil { + // try the default auth plugin, if using the requested plugin failed + errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) + plugin = defaultAuthPlugin + authResp, err = mc.auth(authData, plugin) + if err != nil { + mc.cleanup() + return nil, err + } + } + if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil { + mc.cleanup() + return nil, err + } + + // Handle response to auth packet, switch methods if possible + if err = mc.handleAuthResult(authData, plugin); err != nil { + // Authentication failed and MySQL has already closed the connection + // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). + // Do not send COM_QUIT, just cleanup and return the error. + mc.cleanup() + return nil, err + } + + if mc.cfg.MaxAllowedPacket > 0 { + mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket + } else { + // Get max allowed packet size + maxap, err := mc.getSystemVar("max_allowed_packet") + if err != nil { + mc.Close() + return nil, err + } + mc.maxAllowedPacket = stringToInt(maxap) - 1 + } + if mc.maxAllowedPacket < maxPacketSize { + mc.maxWriteSize = mc.maxAllowedPacket + } + + // Handle DSN Params + err = mc.handleParams() + if err != nil { + mc.Close() + return nil, err + } + + return mc, nil +} + +// Driver implements driver.Connector interface. +// Driver returns &MySQLDriver{}. +func (c *connector) Driver() driver.Driver { + return &MySQLDriver{} +} diff --git a/vendor/github.com/go-sql-driver/mysql/driver.go b/vendor/github.com/go-sql-driver/mysql/driver.go index 1a75a16e..1f9decf8 100644 --- a/vendor/github.com/go-sql-driver/mysql/driver.go +++ b/vendor/github.com/go-sql-driver/mysql/driver.go @@ -17,151 +17,67 @@ package mysql import ( + "context" "database/sql" "database/sql/driver" "net" "sync" ) -// watcher interface is used for context support (From Go 1.8) -type watcher interface { - startWatcher() -} - // MySQLDriver is exported to make the driver directly accessible. // In general the driver is used via the database/sql package. type MySQLDriver struct{} // DialFunc is a function which can be used to establish the network connection. // Custom dial functions must be registered with RegisterDial +// +// Deprecated: users should register a DialContextFunc instead type DialFunc func(addr string) (net.Conn, error) +// DialContextFunc is a function which can be used to establish the network connection. +// Custom dial functions must be registered with RegisterDialContext +type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error) + var ( dialsLock sync.RWMutex - dials map[string]DialFunc + dials map[string]DialContextFunc ) -// RegisterDial registers a custom dial function. It can then be used by the +// RegisterDialContext registers a custom dial function. It can then be used by the // network address mynet(addr), where mynet is the registered new network. -// addr is passed as a parameter to the dial function. -func RegisterDial(net string, dial DialFunc) { +// The current context for the connection and its address is passed to the dial function. +func RegisterDialContext(net string, dial DialContextFunc) { dialsLock.Lock() defer dialsLock.Unlock() if dials == nil { - dials = make(map[string]DialFunc) + dials = make(map[string]DialContextFunc) } dials[net] = dial } +// RegisterDial registers a custom dial function. It can then be used by the +// network address mynet(addr), where mynet is the registered new network. +// addr is passed as a parameter to the dial function. +// +// Deprecated: users should call RegisterDialContext instead +func RegisterDial(network string, dial DialFunc) { + RegisterDialContext(network, func(_ context.Context, addr string) (net.Conn, error) { + return dial(addr) + }) +} + // Open new Connection. // See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how -// the DSN string is formated +// the DSN string is formatted func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { - var err error - - // New mysqlConn - mc := &mysqlConn{ - maxAllowedPacket: maxPacketSize, - maxWriteSize: maxPacketSize - 1, - closech: make(chan struct{}), - } - mc.cfg, err = ParseDSN(dsn) + cfg, err := ParseDSN(dsn) if err != nil { return nil, err } - mc.parseTime = mc.cfg.ParseTime - - // Connect to Server - dialsLock.RLock() - dial, ok := dials[mc.cfg.Net] - dialsLock.RUnlock() - if ok { - mc.netConn, err = dial(mc.cfg.Addr) - } else { - nd := net.Dialer{Timeout: mc.cfg.Timeout} - mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr) + c := &connector{ + cfg: cfg, } - if err != nil { - return nil, err - } - - // Enable TCP Keepalives on TCP connections - if tc, ok := mc.netConn.(*net.TCPConn); ok { - if err := tc.SetKeepAlive(true); err != nil { - // Don't send COM_QUIT before handshake. - mc.netConn.Close() - mc.netConn = nil - return nil, err - } - } - - // Call startWatcher for context support (From Go 1.8) - if s, ok := interface{}(mc).(watcher); ok { - s.startWatcher() - } - - mc.buf = newBuffer(mc.netConn) - - // Set I/O timeouts - mc.buf.timeout = mc.cfg.ReadTimeout - mc.writeTimeout = mc.cfg.WriteTimeout - - // Reading Handshake Initialization Packet - authData, plugin, err := mc.readHandshakePacket() - if err != nil { - mc.cleanup() - return nil, err - } - - // Send Client Authentication Packet - authResp, addNUL, err := mc.auth(authData, plugin) - if err != nil { - // try the default auth plugin, if using the requested plugin failed - errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) - plugin = defaultAuthPlugin - authResp, addNUL, err = mc.auth(authData, plugin) - if err != nil { - mc.cleanup() - return nil, err - } - } - if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { - mc.cleanup() - return nil, err - } - - // Handle response to auth packet, switch methods if possible - if err = mc.handleAuthResult(authData, plugin); err != nil { - // Authentication failed and MySQL has already closed the connection - // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). - // Do not send COM_QUIT, just cleanup and return the error. - mc.cleanup() - return nil, err - } - - if mc.cfg.MaxAllowedPacket > 0 { - mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket - } else { - // Get max allowed packet size - maxap, err := mc.getSystemVar("max_allowed_packet") - if err != nil { - mc.Close() - return nil, err - } - mc.maxAllowedPacket = stringToInt(maxap) - 1 - } - if mc.maxAllowedPacket < maxPacketSize { - mc.maxWriteSize = mc.maxAllowedPacket - } - - // Handle DSN Params - err = mc.handleParams() - if err != nil { - mc.Close() - return nil, err - } - - return mc, nil + return c.Connect(context.Background()) } func init() { diff --git a/vendor/github.com/go-sql-driver/mysql/driver_go110.go b/vendor/github.com/go-sql-driver/mysql/driver_go110.go new file mode 100644 index 00000000..eb5a8fe9 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/driver_go110.go @@ -0,0 +1,37 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.10 + +package mysql + +import ( + "database/sql/driver" +) + +// NewConnector returns new driver.Connector. +func NewConnector(cfg *Config) (driver.Connector, error) { + cfg = cfg.Clone() + // normalize the contents of cfg so calls to NewConnector have the same + // behavior as MySQLDriver.OpenConnector + if err := cfg.normalize(); err != nil { + return nil, err + } + return &connector{cfg: cfg}, nil +} + +// OpenConnector implements driver.DriverContext. +func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) { + cfg, err := ParseDSN(dsn) + if err != nil { + return nil, err + } + return &connector{ + cfg: cfg, + }, nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/driver_go110_test.go b/vendor/github.com/go-sql-driver/mysql/driver_go110_test.go new file mode 100644 index 00000000..19a0e595 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/driver_go110_test.go @@ -0,0 +1,137 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.10 + +package mysql + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "net" + "testing" + "time" +) + +var _ driver.DriverContext = &MySQLDriver{} + +type dialCtxKey struct{} + +func TestConnectorObeysDialTimeouts(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + RegisterDialContext("dialctxtest", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + if !ctx.Value(dialCtxKey{}).(bool) { + return nil, fmt.Errorf("test error: query context is not propagated to our dialer") + } + return d.DialContext(ctx, prot, addr) + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + ctx := context.WithValue(context.Background(), dialCtxKey{}, true) + + _, err = db.ExecContext(ctx, "DO 1") + if err != nil { + t.Fatal(err) + } +} + +func configForTests(t *testing.T) *Config { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + mycnf := NewConfig() + mycnf.User = user + mycnf.Passwd = pass + mycnf.Addr = addr + mycnf.Net = prot + mycnf.DBName = dbname + return mycnf +} + +func TestNewConnector(t *testing.T) { + mycnf := configForTests(t) + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + db := sql.OpenDB(conn) + defer db.Close() + + if err := db.Ping(); err != nil { + t.Fatal(err) + } +} + +type slowConnection struct { + net.Conn + slowdown time.Duration +} + +func (sc *slowConnection) Read(b []byte) (int, error) { + time.Sleep(sc.slowdown) + return sc.Conn.Read(b) +} + +type connectorHijack struct { + driver.Connector + connErr error +} + +func (cw *connectorHijack) Connect(ctx context.Context) (driver.Conn, error) { + var conn driver.Conn + conn, cw.connErr = cw.Connector.Connect(ctx) + return conn, cw.connErr +} + +func TestConnectorTimeoutsDuringOpen(t *testing.T) { + RegisterDialContext("slowconn", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + conn, err := d.DialContext(ctx, prot, addr) + if err != nil { + return nil, err + } + return &slowConnection{Conn: conn, slowdown: 100 * time.Millisecond}, nil + }) + + mycnf := configForTests(t) + mycnf.Net = "slowconn" + + conn, err := NewConnector(mycnf) + if err != nil { + t.Fatal(err) + } + + hijack := &connectorHijack{Connector: conn} + + db := sql.OpenDB(hijack) + defer db.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) + defer cancel() + + _, err = db.ExecContext(ctx, "DO 1") + if err != context.DeadlineExceeded { + t.Fatalf("ExecContext should have timed out") + } + if hijack.connErr != context.DeadlineExceeded { + t.Fatalf("(*Connector).Connect should have timed out") + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/driver_test.go b/vendor/github.com/go-sql-driver/mysql/driver_test.go new file mode 100644 index 00000000..3dee1bab --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/driver_test.go @@ -0,0 +1,2996 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "context" + "crypto/tls" + "database/sql" + "database/sql/driver" + "fmt" + "io" + "io/ioutil" + "log" + "math" + "net" + "net/url" + "os" + "reflect" + "strings" + "sync" + "sync/atomic" + "testing" + "time" +) + +// Ensure that all the driver interfaces are implemented +var ( + _ driver.Rows = &binaryRows{} + _ driver.Rows = &textRows{} +) + +var ( + user string + pass string + prot string + addr string + dbname string + dsn string + netAddr string + available bool +) + +var ( + tDate = time.Date(2012, 6, 14, 0, 0, 0, 0, time.UTC) + sDate = "2012-06-14" + tDateTime = time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC) + sDateTime = "2011-11-20 21:27:37" + tDate0 = time.Time{} + sDate0 = "0000-00-00" + sDateTime0 = "0000-00-00 00:00:00" +) + +// See https://github.com/go-sql-driver/mysql/wiki/Testing +func init() { + // get environment variables + env := func(key, defaultValue string) string { + if value := os.Getenv(key); value != "" { + return value + } + return defaultValue + } + user = env("MYSQL_TEST_USER", "root") + pass = env("MYSQL_TEST_PASS", "") + prot = env("MYSQL_TEST_PROT", "tcp") + addr = env("MYSQL_TEST_ADDR", "localhost:3306") + dbname = env("MYSQL_TEST_DBNAME", "gotest") + netAddr = fmt.Sprintf("%s(%s)", prot, addr) + dsn = fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, pass, netAddr, dbname) + c, err := net.Dial(prot, addr) + if err == nil { + available = true + c.Close() + } +} + +type DBTest struct { + *testing.T + db *sql.DB +} + +type netErrorMock struct { + temporary bool + timeout bool +} + +func (e netErrorMock) Temporary() bool { + return e.temporary +} + +func (e netErrorMock) Timeout() bool { + return e.timeout +} + +func (e netErrorMock) Error() string { + return fmt.Sprintf("mock net error. Temporary: %v, Timeout %v", e.temporary, e.timeout) +} + +func runTestsWithMultiStatement(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + dsn += "&multiStatements=true" + var db *sql.DB + if _, err := ParseDSN(dsn); err != errInvalidDSNUnsafeCollation { + db, err = sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + } + + dbt := &DBTest{t, db} + for _, test := range tests { + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") + } +} + +func runTests(t *testing.T, dsn string, tests ...func(dbt *DBTest)) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + db, err := sql.Open("mysql", dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + db.Exec("DROP TABLE IF EXISTS test") + + dsn2 := dsn + "&interpolateParams=true" + var db2 *sql.DB + if _, err := ParseDSN(dsn2); err != errInvalidDSNUnsafeCollation { + db2, err = sql.Open("mysql", dsn2) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db2.Close() + } + + dsn3 := dsn + "&multiStatements=true" + var db3 *sql.DB + if _, err := ParseDSN(dsn3); err != errInvalidDSNUnsafeCollation { + db3, err = sql.Open("mysql", dsn3) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db3.Close() + } + + dbt := &DBTest{t, db} + dbt2 := &DBTest{t, db2} + dbt3 := &DBTest{t, db3} + for _, test := range tests { + test(dbt) + dbt.db.Exec("DROP TABLE IF EXISTS test") + if db2 != nil { + test(dbt2) + dbt2.db.Exec("DROP TABLE IF EXISTS test") + } + if db3 != nil { + test(dbt3) + dbt3.db.Exec("DROP TABLE IF EXISTS test") + } + } +} + +func (dbt *DBTest) fail(method, query string, err error) { + if len(query) > 300 { + query = "[query too large to print]" + } + dbt.Fatalf("error on %s %s: %s", method, query, err.Error()) +} + +func (dbt *DBTest) mustExec(query string, args ...interface{}) (res sql.Result) { + res, err := dbt.db.Exec(query, args...) + if err != nil { + dbt.fail("exec", query, err) + } + return res +} + +func (dbt *DBTest) mustQuery(query string, args ...interface{}) (rows *sql.Rows) { + rows, err := dbt.db.Query(query, args...) + if err != nil { + dbt.fail("query", query, err) + } + return rows +} + +func maybeSkip(t *testing.T, err error, skipErrno uint16) { + mySQLErr, ok := err.(*MySQLError) + if !ok { + return + } + + if mySQLErr.Number == skipErrno { + t.Skipf("skipping test for error: %v", err) + } +} + +func TestEmptyQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // just a comment, no query + rows := dbt.mustQuery("--") + defer rows.Close() + // will hang before #255 + if rows.Next() { + dbt.Errorf("next on rows must be false") + } + }) +} + +func TestCRUD(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + + // Test for unexpected data + var out bool + rows := dbt.mustQuery("SELECT * FROM test") + if rows.Next() { + dbt.Error("unexpected data in empty table") + } + rows.Close() + + // Create Data + res := dbt.mustExec("INSERT INTO test VALUES (1)") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + id, err := res.LastInsertId() + if err != nil { + dbt.Fatalf("res.LastInsertId() returned error: %s", err.Error()) + } + if id != 0 { + dbt.Fatalf("expected InsertId 0, got %d", id) + } + + // Read + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if true != out { + dbt.Errorf("true != %t", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + rows.Close() + + // Update + res = dbt.mustExec("UPDATE test SET value = ? WHERE value = ?", false, true) + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Check Update + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if false != out { + dbt.Errorf("false != %t", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + rows.Close() + + // Delete + res = dbt.mustExec("DELETE FROM test WHERE value = ?", false) + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Check for unexpected rows + res = dbt.mustExec("DELETE FROM test") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 0 { + dbt.Fatalf("expected 0 affected row, got %d", count) + } + }) +} + +func TestMultiQuery(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE `test` (`id` int(11) NOT NULL, `value` int(11) NOT NULL) ") + + // Create Data + res := dbt.mustExec("INSERT INTO test VALUES (1, 1)") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Update + res = dbt.mustExec("UPDATE test SET value = 3 WHERE id = 1; UPDATE test SET value = 4 WHERE id = 1; UPDATE test SET value = 5 WHERE id = 1;") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 1 { + dbt.Fatalf("expected 1 affected row, got %d", count) + } + + // Read + var out int + rows := dbt.mustQuery("SELECT value FROM test WHERE id=1;") + if rows.Next() { + rows.Scan(&out) + if 5 != out { + dbt.Errorf("5 != %d", out) + } + + if rows.Next() { + dbt.Error("unexpected data") + } + } else { + dbt.Error("no data") + } + rows.Close() + + }) +} + +func TestInt(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [5]string{"TINYINT", "SMALLINT", "MEDIUMINT", "INT", "BIGINT"} + in := int64(42) + var out int64 + var rows *sql.Rows + + // SIGNED + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + + // UNSIGNED ZEROFILL + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + " ZEROFILL)") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s ZEROFILL: %d != %d", v, in, out) + } + } else { + dbt.Errorf("%s ZEROFILL: no data", v) + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat32(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + in := float32(42.23) + var out float32 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %g != %g", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat64(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + var expected float64 = 42.23 + var out float64 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (42.23)") + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestFloat64Placeholder(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [2]string{"FLOAT", "DOUBLE"} + var expected float64 = 42.23 + var out float64 + var rows *sql.Rows + for _, v := range types { + dbt.mustExec("CREATE TABLE test (id int, value " + v + ")") + dbt.mustExec("INSERT INTO test VALUES (1, 42.23)") + rows = dbt.mustQuery("SELECT value FROM test WHERE id = ?", 1) + if rows.Next() { + rows.Scan(&out) + if expected != out { + dbt.Errorf("%s: %g != %g", v, expected, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + dbt.mustExec("DROP TABLE IF EXISTS test") + } + }) +} + +func TestString(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + types := [6]string{"CHAR(255)", "VARCHAR(255)", "TINYTEXT", "TEXT", "MEDIUMTEXT", "LONGTEXT"} + in := "κόσμε üöäßñóùéàâÿœ'îë Árvíztűrő いろはにほへとちりぬるを イロハニホヘト דג סקרן чащах น่าฟังเอย" + var out string + var rows *sql.Rows + + for _, v := range types { + dbt.mustExec("CREATE TABLE test (value " + v + ") CHARACTER SET utf8") + + dbt.mustExec("INSERT INTO test VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Errorf("%s: %s != %s", v, in, out) + } + } else { + dbt.Errorf("%s: no data", v) + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + } + + // BLOB + dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") + + id := 2 + in = "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + + "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet. " + + "Lorem ipsum dolor sit amet, consetetur sadipscing elitr, " + + "sed diam nonumy eirmod tempor invidunt ut labore et dolore magna aliquyam erat, " + + "sed diam voluptua. At vero eos et accusam et justo duo dolores et ea rebum. " + + "Stet clita kasd gubergren, no sea takimata sanctus est Lorem ipsum dolor sit amet." + dbt.mustExec("INSERT INTO test VALUES (?, ?)", id, in) + + err := dbt.db.QueryRow("SELECT value FROM test WHERE id = ?", id).Scan(&out) + if err != nil { + dbt.Fatalf("Error on BLOB-Query: %s", err.Error()) + } else if out != in { + dbt.Errorf("BLOB: %s != %s", in, out) + } + }) +} + +func TestRawBytes(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + v1 := []byte("aaa") + v2 := []byte("bbb") + rows := dbt.mustQuery("SELECT ?, ?", v1, v2) + defer rows.Close() + if rows.Next() { + var o1, o2 sql.RawBytes + if err := rows.Scan(&o1, &o2); err != nil { + dbt.Errorf("Got error: %v", err) + } + if !bytes.Equal(v1, o1) { + dbt.Errorf("expected %v, got %v", v1, o1) + } + if !bytes.Equal(v2, o2) { + dbt.Errorf("expected %v, got %v", v2, o2) + } + // https://github.com/go-sql-driver/mysql/issues/765 + // Appending to RawBytes shouldn't overwrite next RawBytes. + o1 = append(o1, "xyzzy"...) + if !bytes.Equal(v2, o2) { + dbt.Errorf("expected %v, got %v", v2, o2) + } + } else { + dbt.Errorf("no data") + } + }) +} + +type testValuer struct { + value string +} + +func (tv testValuer) Value() (driver.Value, error) { + return tv.value, nil +} + +func TestValuer(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + in := testValuer{"a_value"} + var out string + var rows *sql.Rows + + dbt.mustExec("CREATE TABLE test (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO test VALUES (?)", in) + rows = dbt.mustQuery("SELECT value FROM test") + if rows.Next() { + rows.Scan(&out) + if in.value != out { + dbt.Errorf("Valuer: %v != %s", in, out) + } + } else { + dbt.Errorf("Valuer: no data") + } + rows.Close() + + dbt.mustExec("DROP TABLE IF EXISTS test") + }) +} + +type testValuerWithValidation struct { + value string +} + +func (tv testValuerWithValidation) Value() (driver.Value, error) { + if len(tv.value) == 0 { + return nil, fmt.Errorf("Invalid string valuer. Value must not be empty") + } + + return tv.value, nil +} + +func TestValuerWithValidation(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + in := testValuerWithValidation{"a_value"} + var out string + var rows *sql.Rows + + dbt.mustExec("CREATE TABLE testValuer (value VARCHAR(255)) CHARACTER SET utf8") + dbt.mustExec("INSERT INTO testValuer VALUES (?)", in) + + rows = dbt.mustQuery("SELECT value FROM testValuer") + defer rows.Close() + + if rows.Next() { + rows.Scan(&out) + if in.value != out { + dbt.Errorf("Valuer: %v != %s", in, out) + } + } else { + dbt.Errorf("Valuer: no data") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", testValuerWithValidation{""}); err == nil { + dbt.Errorf("Failed to check valuer error") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", nil); err != nil { + dbt.Errorf("Failed to check nil") + } + + if _, err := dbt.db.Exec("INSERT INTO testValuer VALUES (?)", map[string]bool{}); err == nil { + dbt.Errorf("Failed to check not valuer") + } + + dbt.mustExec("DROP TABLE IF EXISTS testValuer") + }) +} + +type timeTests struct { + dbtype string + tlayout string + tests []timeTest +} + +type timeTest struct { + s string // leading "!": do not use t as value in queries + t time.Time +} + +type timeMode byte + +func (t timeMode) String() string { + switch t { + case binaryString: + return "binary:string" + case binaryTime: + return "binary:time.Time" + case textString: + return "text:string" + } + panic("unsupported timeMode") +} + +func (t timeMode) Binary() bool { + switch t { + case binaryString, binaryTime: + return true + } + return false +} + +const ( + binaryString timeMode = iota + binaryTime + textString +) + +func (t timeTest) genQuery(dbtype string, mode timeMode) string { + var inner string + if mode.Binary() { + inner = "?" + } else { + inner = `"%s"` + } + return `SELECT cast(` + inner + ` as ` + dbtype + `)` +} + +func (t timeTest) run(dbt *DBTest, dbtype, tlayout string, mode timeMode) { + var rows *sql.Rows + query := t.genQuery(dbtype, mode) + switch mode { + case binaryString: + rows = dbt.mustQuery(query, t.s) + case binaryTime: + rows = dbt.mustQuery(query, t.t) + case textString: + query = fmt.Sprintf(query, t.s) + rows = dbt.mustQuery(query) + default: + panic("unsupported mode") + } + defer rows.Close() + var err error + if !rows.Next() { + err = rows.Err() + if err == nil { + err = fmt.Errorf("no data") + } + dbt.Errorf("%s [%s]: %s", dbtype, mode, err) + return + } + var dst interface{} + err = rows.Scan(&dst) + if err != nil { + dbt.Errorf("%s [%s]: %s", dbtype, mode, err) + return + } + switch val := dst.(type) { + case []uint8: + str := string(val) + if str == t.s { + return + } + if mode.Binary() && dbtype == "DATETIME" && len(str) == 26 && str[:19] == t.s { + // a fix mainly for TravisCI: + // accept full microsecond resolution in result for DATETIME columns + // where the binary protocol was used + return + } + dbt.Errorf("%s [%s] to string: expected %q, got %q", + dbtype, mode, + t.s, str, + ) + case time.Time: + if val == t.t { + return + } + dbt.Errorf("%s [%s] to string: expected %q, got %q", + dbtype, mode, + t.s, val.Format(tlayout), + ) + default: + fmt.Printf("%#v\n", []interface{}{dbtype, tlayout, mode, t.s, t.t}) + dbt.Errorf("%s [%s]: unhandled type %T (is '%v')", + dbtype, mode, + val, val, + ) + } +} + +func TestDateTime(t *testing.T) { + afterTime := func(t time.Time, d string) time.Time { + dur, err := time.ParseDuration(d) + if err != nil { + panic(err) + } + return t.Add(dur) + } + // NOTE: MySQL rounds DATETIME(x) up - but that's not included in the tests + format := "2006-01-02 15:04:05.999999" + t0 := time.Time{} + tstr0 := "0000-00-00 00:00:00.000000" + testcases := []timeTests{ + {"DATE", format[:10], []timeTest{ + {t: time.Date(2011, 11, 20, 0, 0, 0, 0, time.UTC)}, + {t: t0, s: tstr0[:10]}, + }}, + {"DATETIME", format[:19], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, + {t: t0, s: tstr0[:19]}, + }}, + {"DATETIME(0)", format[:21], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 0, time.UTC)}, + {t: t0, s: tstr0[:19]}, + }}, + {"DATETIME(1)", format[:21], []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 100000000, time.UTC)}, + {t: t0, s: tstr0[:21]}, + }}, + {"DATETIME(6)", format, []timeTest{ + {t: time.Date(2011, 11, 20, 21, 27, 37, 123456000, time.UTC)}, + {t: t0, s: tstr0}, + }}, + {"TIME", format[11:19], []timeTest{ + {t: afterTime(t0, "12345s")}, + {s: "!-12:34:56"}, + {s: "!-838:59:59"}, + {s: "!838:59:59"}, + {t: t0, s: tstr0[11:19]}, + }}, + {"TIME(0)", format[11:19], []timeTest{ + {t: afterTime(t0, "12345s")}, + {s: "!-12:34:56"}, + {s: "!-838:59:59"}, + {s: "!838:59:59"}, + {t: t0, s: tstr0[11:19]}, + }}, + {"TIME(1)", format[11:21], []timeTest{ + {t: afterTime(t0, "12345600ms")}, + {s: "!-12:34:56.7"}, + {s: "!-838:59:58.9"}, + {s: "!838:59:58.9"}, + {t: t0, s: tstr0[11:21]}, + }}, + {"TIME(6)", format[11:], []timeTest{ + {t: afterTime(t0, "1234567890123000ns")}, + {s: "!-12:34:56.789012"}, + {s: "!-838:59:58.999999"}, + {s: "!838:59:58.999999"}, + {t: t0, s: tstr0[11:]}, + }}, + } + dsns := []string{ + dsn + "&parseTime=true", + dsn + "&parseTime=false", + } + for _, testdsn := range dsns { + runTests(t, testdsn, func(dbt *DBTest) { + microsecsSupported := false + zeroDateSupported := false + var rows *sql.Rows + var err error + rows, err = dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`) + if err == nil { + rows.Scan(µsecsSupported) + rows.Close() + } + rows, err = dbt.db.Query(`SELECT cast("0000-00-00" as DATE) = "0000-00-00"`) + if err == nil { + rows.Scan(&zeroDateSupported) + rows.Close() + } + for _, setups := range testcases { + if t := setups.dbtype; !microsecsSupported && t[len(t)-1:] == ")" { + // skip fractional second tests if unsupported by server + continue + } + for _, setup := range setups.tests { + allowBinTime := true + if setup.s == "" { + // fill time string wherever Go can reliable produce it + setup.s = setup.t.Format(setups.tlayout) + } else if setup.s[0] == '!' { + // skip tests using setup.t as source in queries + allowBinTime = false + // fix setup.s - remove the "!" + setup.s = setup.s[1:] + } + if !zeroDateSupported && setup.s == tstr0[:len(setup.s)] { + // skip disallowed 0000-00-00 date + continue + } + setup.run(dbt, setups.dbtype, setups.tlayout, textString) + setup.run(dbt, setups.dbtype, setups.tlayout, binaryString) + if allowBinTime { + setup.run(dbt, setups.dbtype, setups.tlayout, binaryTime) + } + } + } + }) + } +} + +func TestTimestampMicros(t *testing.T) { + format := "2006-01-02 15:04:05.999999" + f0 := format[:19] + f1 := format[:21] + f6 := format[:26] + runTests(t, dsn, func(dbt *DBTest) { + // check if microseconds are supported. + // Do not use timestamp(x) for that check - before 5.5.6, x would mean display width + // and not precision. + // Se last paragraph at http://dev.mysql.com/doc/refman/5.6/en/fractional-seconds.html + microsecsSupported := false + if rows, err := dbt.db.Query(`SELECT cast("00:00:00.1" as TIME(1)) = "00:00:00.1"`); err == nil { + rows.Scan(µsecsSupported) + rows.Close() + } + if !microsecsSupported { + // skip test + return + } + _, err := dbt.db.Exec(` + CREATE TABLE test ( + value0 TIMESTAMP NOT NULL DEFAULT '` + f0 + `', + value1 TIMESTAMP(1) NOT NULL DEFAULT '` + f1 + `', + value6 TIMESTAMP(6) NOT NULL DEFAULT '` + f6 + `' + )`, + ) + if err != nil { + dbt.Error(err) + } + defer dbt.mustExec("DROP TABLE IF EXISTS test") + dbt.mustExec("INSERT INTO test SET value0=?, value1=?, value6=?", f0, f1, f6) + var res0, res1, res6 string + rows := dbt.mustQuery("SELECT * FROM test") + defer rows.Close() + if !rows.Next() { + dbt.Errorf("test contained no selectable values") + } + err = rows.Scan(&res0, &res1, &res6) + if err != nil { + dbt.Error(err) + } + if res0 != f0 { + dbt.Errorf("expected %q, got %q", f0, res0) + } + if res1 != f1 { + dbt.Errorf("expected %q, got %q", f1, res1) + } + if res6 != f6 { + dbt.Errorf("expected %q, got %q", f6, res6) + } + }) +} + +func TestNULL(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + nullStmt, err := dbt.db.Prepare("SELECT NULL") + if err != nil { + dbt.Fatal(err) + } + defer nullStmt.Close() + + nonNullStmt, err := dbt.db.Prepare("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + defer nonNullStmt.Close() + + // NullBool + var nb sql.NullBool + // Invalid + if err = nullStmt.QueryRow().Scan(&nb); err != nil { + dbt.Fatal(err) + } + if nb.Valid { + dbt.Error("valid NullBool which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&nb); err != nil { + dbt.Fatal(err) + } + if !nb.Valid { + dbt.Error("invalid NullBool which should be valid") + } else if nb.Bool != true { + dbt.Errorf("Unexpected NullBool value: %t (should be true)", nb.Bool) + } + + // NullFloat64 + var nf sql.NullFloat64 + // Invalid + if err = nullStmt.QueryRow().Scan(&nf); err != nil { + dbt.Fatal(err) + } + if nf.Valid { + dbt.Error("valid NullFloat64 which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&nf); err != nil { + dbt.Fatal(err) + } + if !nf.Valid { + dbt.Error("invalid NullFloat64 which should be valid") + } else if nf.Float64 != float64(1) { + dbt.Errorf("unexpected NullFloat64 value: %f (should be 1.0)", nf.Float64) + } + + // NullInt64 + var ni sql.NullInt64 + // Invalid + if err = nullStmt.QueryRow().Scan(&ni); err != nil { + dbt.Fatal(err) + } + if ni.Valid { + dbt.Error("valid NullInt64 which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&ni); err != nil { + dbt.Fatal(err) + } + if !ni.Valid { + dbt.Error("invalid NullInt64 which should be valid") + } else if ni.Int64 != int64(1) { + dbt.Errorf("unexpected NullInt64 value: %d (should be 1)", ni.Int64) + } + + // NullString + var ns sql.NullString + // Invalid + if err = nullStmt.QueryRow().Scan(&ns); err != nil { + dbt.Fatal(err) + } + if ns.Valid { + dbt.Error("valid NullString which should be invalid") + } + // Valid + if err = nonNullStmt.QueryRow().Scan(&ns); err != nil { + dbt.Fatal(err) + } + if !ns.Valid { + dbt.Error("invalid NullString which should be valid") + } else if ns.String != `1` { + dbt.Error("unexpected NullString value:" + ns.String + " (should be `1`)") + } + + // nil-bytes + var b []byte + // Read nil + if err = nullStmt.QueryRow().Scan(&b); err != nil { + dbt.Fatal(err) + } + if b != nil { + dbt.Error("non-nil []byte which should be nil") + } + // Read non-nil + if err = nonNullStmt.QueryRow().Scan(&b); err != nil { + dbt.Fatal(err) + } + if b == nil { + dbt.Error("nil []byte which should be non-nil") + } + // Insert nil + b = nil + success := false + if err = dbt.db.QueryRow("SELECT ? IS NULL", b).Scan(&success); err != nil { + dbt.Fatal(err) + } + if !success { + dbt.Error("inserting []byte(nil) as NULL failed") + } + // Check input==output with input==nil + b = nil + if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + dbt.Fatal(err) + } + if b != nil { + dbt.Error("non-nil echo from nil input") + } + // Check input==output with input!=nil + b = []byte("") + if err = dbt.db.QueryRow("SELECT ?", b).Scan(&b); err != nil { + dbt.Fatal(err) + } + if b == nil { + dbt.Error("nil echo from non-nil input") + } + + // Insert NULL + dbt.mustExec("CREATE TABLE test (dummmy1 int, value int, dummy2 int)") + + dbt.mustExec("INSERT INTO test VALUES (?, ?, ?)", 1, nil, 2) + + var out interface{} + rows := dbt.mustQuery("SELECT * FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if out != nil { + dbt.Errorf("%v != nil", out) + } + } else { + dbt.Error("no data") + } + }) +} + +func TestUint64(t *testing.T) { + const ( + u0 = uint64(0) + uall = ^u0 + uhigh = uall >> 1 + utop = ^uhigh + s0 = int64(0) + sall = ^s0 + shigh = int64(uhigh) + stop = ^shigh + ) + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare(`SELECT ?, ?, ? ,?, ?, ?, ?, ?`) + if err != nil { + dbt.Fatal(err) + } + defer stmt.Close() + row := stmt.QueryRow( + u0, uhigh, utop, uall, + s0, shigh, stop, sall, + ) + + var ua, ub, uc, ud uint64 + var sa, sb, sc, sd int64 + + err = row.Scan(&ua, &ub, &uc, &ud, &sa, &sb, &sc, &sd) + if err != nil { + dbt.Fatal(err) + } + switch { + case ua != u0, + ub != uhigh, + uc != utop, + ud != uall, + sa != s0, + sb != shigh, + sc != stop, + sd != sall: + dbt.Fatal("unexpected result value") + } + }) +} + +func TestLongData(t *testing.T) { + runTests(t, dsn+"&maxAllowedPacket=0", func(dbt *DBTest) { + var maxAllowedPacketSize int + err := dbt.db.QueryRow("select @@max_allowed_packet").Scan(&maxAllowedPacketSize) + if err != nil { + dbt.Fatal(err) + } + maxAllowedPacketSize-- + + // don't get too ambitious + if maxAllowedPacketSize > 1<<25 { + maxAllowedPacketSize = 1 << 25 + } + + dbt.mustExec("CREATE TABLE test (value LONGBLOB)") + + in := strings.Repeat(`a`, maxAllowedPacketSize+1) + var out string + var rows *sql.Rows + + // Long text data + const nonDataQueryLen = 28 // length query w/o value + inS := in[:maxAllowedPacketSize-nonDataQueryLen] + dbt.mustExec("INSERT INTO test VALUES('" + inS + "')") + rows = dbt.mustQuery("SELECT value FROM test") + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if inS != out { + dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(inS), len(out)) + } + if rows.Next() { + dbt.Error("LONGBLOB: unexpexted row") + } + } else { + dbt.Fatalf("LONGBLOB: no data") + } + + // Empty table + dbt.mustExec("TRUNCATE TABLE test") + + // Long binary data + dbt.mustExec("INSERT INTO test VALUES(?)", in) + rows = dbt.mustQuery("SELECT value FROM test WHERE 1=?", 1) + defer rows.Close() + if rows.Next() { + rows.Scan(&out) + if in != out { + dbt.Fatalf("LONGBLOB: length in: %d, length out: %d", len(in), len(out)) + } + if rows.Next() { + dbt.Error("LONGBLOB: unexpexted row") + } + } else { + if err = rows.Err(); err != nil { + dbt.Fatalf("LONGBLOB: no data (err: %s)", err.Error()) + } else { + dbt.Fatal("LONGBLOB: no data (err: )") + } + } + }) +} + +func TestLoadData(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + verifyLoadDataResult := func() { + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + dbt.Fatal(err.Error()) + } + + i := 0 + values := [4]string{ + "a string", + "a string containing a \t", + "a string containing a \n", + "a string containing both \t\n", + } + + var id int + var value string + + for rows.Next() { + i++ + err = rows.Scan(&id, &value) + if err != nil { + dbt.Fatal(err.Error()) + } + if i != id { + dbt.Fatalf("%d != %d", i, id) + } + if values[i-1] != value { + dbt.Fatalf("%q != %q", values[i-1], value) + } + } + err = rows.Err() + if err != nil { + dbt.Fatal(err.Error()) + } + + if i != 4 { + dbt.Fatalf("rows count mismatch. Got %d, want 4", i) + } + } + + dbt.db.Exec("DROP TABLE IF EXISTS test") + dbt.mustExec("CREATE TABLE test (id INT NOT NULL PRIMARY KEY, value TEXT NOT NULL) CHARACTER SET utf8") + + // Local File + file, err := ioutil.TempFile("", "gotest") + defer os.Remove(file.Name()) + if err != nil { + dbt.Fatal(err) + } + RegisterLocalFile(file.Name()) + + // Try first with empty file + dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) + var count int + err = dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&count) + if err != nil { + dbt.Fatal(err.Error()) + } + if count != 0 { + dbt.Fatalf("unexpected row count: got %d, want 0", count) + } + + // Then fille File with data and try to load it + file.WriteString("1\ta string\n2\ta string containing a \\t\n3\ta string containing a \\n\n4\ta string containing both \\t\\n\n") + file.Close() + dbt.mustExec(fmt.Sprintf("LOAD DATA LOCAL INFILE %q INTO TABLE test", file.Name())) + verifyLoadDataResult() + + // Try with non-existing file + _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'doesnotexist' INTO TABLE test") + if err == nil { + dbt.Fatal("load non-existent file didn't fail") + } else if err.Error() != "local file 'doesnotexist' is not registered" { + dbt.Fatal(err.Error()) + } + + // Empty table + dbt.mustExec("TRUNCATE TABLE test") + + // Reader + RegisterReaderHandler("test", func() io.Reader { + file, err = os.Open(file.Name()) + if err != nil { + dbt.Fatal(err) + } + return file + }) + dbt.mustExec("LOAD DATA LOCAL INFILE 'Reader::test' INTO TABLE test") + verifyLoadDataResult() + // negative test + _, err = dbt.db.Exec("LOAD DATA LOCAL INFILE 'Reader::doesnotexist' INTO TABLE test") + if err == nil { + dbt.Fatal("load non-existent Reader didn't fail") + } else if err.Error() != "Reader 'doesnotexist' is not registered" { + dbt.Fatal(err.Error()) + } + }) +} + +func TestFoundRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 affected rows, got %d", count) + } + }) + runTests(t, dsn+"&clientFoundRows=true", func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id INT NOT NULL ,data INT NOT NULL)") + dbt.mustExec("INSERT INTO test (id, data) VALUES (0, 0),(0, 0),(1, 0),(1, 0),(1, 1)") + + res := dbt.mustExec("UPDATE test SET data = 1 WHERE id = 0") + count, err := res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 2 { + dbt.Fatalf("Expected 2 matched rows, got %d", count) + } + res = dbt.mustExec("UPDATE test SET data = 1 WHERE id = 1") + count, err = res.RowsAffected() + if err != nil { + dbt.Fatalf("res.RowsAffected() returned error: %s", err.Error()) + } + if count != 3 { + dbt.Fatalf("Expected 3 matched rows, got %d", count) + } + }) +} + +func TestTLS(t *testing.T) { + tlsTestReq := func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + if err == ErrNoTLS { + dbt.Skip("server does not support TLS") + } else { + dbt.Fatalf("error on Ping: %s", err.Error()) + } + } + + rows := dbt.mustQuery("SHOW STATUS LIKE 'Ssl_cipher'") + defer rows.Close() + + var variable, value *sql.RawBytes + for rows.Next() { + if err := rows.Scan(&variable, &value); err != nil { + dbt.Fatal(err.Error()) + } + + if (*value == nil) || (len(*value) == 0) { + dbt.Fatalf("no Cipher") + } else { + dbt.Logf("Cipher: %s", *value) + } + } + } + tlsTestOpt := func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.Fatalf("error on Ping: %s", err.Error()) + } + } + + runTests(t, dsn+"&tls=preferred", tlsTestOpt) + runTests(t, dsn+"&tls=skip-verify", tlsTestReq) + + // Verify that registering / using a custom cfg works + RegisterTLSConfig("custom-skip-verify", &tls.Config{ + InsecureSkipVerify: true, + }) + runTests(t, dsn+"&tls=custom-skip-verify", tlsTestReq) +} + +func TestReuseClosedConnection(t *testing.T) { + // this test does not use sql.database, it uses the driver directly + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + md := &MySQLDriver{} + conn, err := md.Open(dsn) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + stmt, err := conn.Prepare("DO 1") + if err != nil { + t.Fatalf("error preparing statement: %s", err.Error()) + } + _, err = stmt.Exec(nil) + if err != nil { + t.Fatalf("error executing statement: %s", err.Error()) + } + err = conn.Close() + if err != nil { + t.Fatalf("error closing connection: %s", err.Error()) + } + + defer func() { + if err := recover(); err != nil { + t.Errorf("panic after reusing a closed connection: %v", err) + } + }() + _, err = stmt.Exec(nil) + if err != nil && err != driver.ErrBadConn { + t.Errorf("unexpected error '%s', expected '%s'", + err.Error(), driver.ErrBadConn.Error()) + } +} + +func TestCharset(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + mustSetCharset := func(charsetParam, expected string) { + runTests(t, dsn+"&"+charsetParam, func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT @@character_set_connection") + defer rows.Close() + + if !rows.Next() { + dbt.Fatalf("error getting connection charset: %s", rows.Err()) + } + + var got string + rows.Scan(&got) + + if got != expected { + dbt.Fatalf("expected connection charset %s but got %s", expected, got) + } + }) + } + + // non utf8 test + mustSetCharset("charset=ascii", "ascii") + + // when the first charset is invalid, use the second + mustSetCharset("charset=none,utf8", "utf8") + + // when the first charset is valid, use it + mustSetCharset("charset=ascii,utf8", "ascii") + mustSetCharset("charset=utf8,ascii", "utf8") +} + +func TestFailingCharset(t *testing.T) { + runTests(t, dsn+"&charset=none", func(dbt *DBTest) { + // run query to really establish connection... + _, err := dbt.db.Exec("SELECT 1") + if err == nil { + dbt.db.Close() + t.Fatalf("connection must not succeed without a valid charset") + } + }) +} + +func TestCollation(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + defaultCollation := "utf8mb4_general_ci" + testCollations := []string{ + "", // do not set + defaultCollation, // driver default + "latin1_general_ci", + "binary", + "utf8_unicode_ci", + "cp1257_bin", + } + + for _, collation := range testCollations { + var expected, tdsn string + if collation != "" { + tdsn = dsn + "&collation=" + collation + expected = collation + } else { + tdsn = dsn + expected = defaultCollation + } + + runTests(t, tdsn, func(dbt *DBTest) { + var got string + if err := dbt.db.QueryRow("SELECT @@collation_connection").Scan(&got); err != nil { + dbt.Fatal(err) + } + + if got != expected { + dbt.Fatalf("expected connection collation %s but got %s", expected, got) + } + }) + } +} + +func TestColumnsWithAlias(t *testing.T) { + runTests(t, dsn+"&columnsWithAlias=true", func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT 1 AS A") + defer rows.Close() + cols, _ := rows.Columns() + if len(cols) != 1 { + t.Fatalf("expected 1 column, got %d", len(cols)) + } + if cols[0] != "A" { + t.Fatalf("expected column name \"A\", got \"%s\"", cols[0]) + } + + rows = dbt.mustQuery("SELECT * FROM (SELECT 1 AS one) AS A") + defer rows.Close() + cols, _ = rows.Columns() + if len(cols) != 1 { + t.Fatalf("expected 1 column, got %d", len(cols)) + } + if cols[0] != "A.one" { + t.Fatalf("expected column name \"A.one\", got \"%s\"", cols[0]) + } + }) +} + +func TestRawBytesResultExceedsBuffer(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // defaultBufSize from buffer.go + expected := strings.Repeat("abc", defaultBufSize) + + rows := dbt.mustQuery("SELECT '" + expected + "'") + defer rows.Close() + if !rows.Next() { + dbt.Error("expected result, got none") + } + var result sql.RawBytes + rows.Scan(&result) + if expected != string(result) { + dbt.Error("result did not match expected value") + } + }) +} + +func TestTimezoneConversion(t *testing.T) { + zones := []string{"UTC", "US/Central", "US/Pacific", "Local"} + + // Regression test for timezone handling + tzTest := func(dbt *DBTest) { + // Create table + dbt.mustExec("CREATE TABLE test (ts TIMESTAMP)") + + // Insert local time into database (should be converted) + usCentral, _ := time.LoadLocation("US/Central") + reftime := time.Date(2014, 05, 30, 18, 03, 17, 0, time.UTC).In(usCentral) + dbt.mustExec("INSERT INTO test VALUE (?)", reftime) + + // Retrieve time from DB + rows := dbt.mustQuery("SELECT ts FROM test") + defer rows.Close() + if !rows.Next() { + dbt.Fatal("did not get any rows out") + } + + var dbTime time.Time + err := rows.Scan(&dbTime) + if err != nil { + dbt.Fatal("Err", err) + } + + // Check that dates match + if reftime.Unix() != dbTime.Unix() { + dbt.Errorf("times do not match.\n") + dbt.Errorf(" Now(%v)=%v\n", usCentral, reftime) + dbt.Errorf(" Now(UTC)=%v\n", dbTime) + } + } + + for _, tz := range zones { + runTests(t, dsn+"&parseTime=true&loc="+url.QueryEscape(tz), tzTest) + } +} + +// Special cases + +func TestRowsClose(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + rows, err := dbt.db.Query("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + + err = rows.Close() + if err != nil { + dbt.Fatal(err) + } + + if rows.Next() { + dbt.Fatal("unexpected row after rows.Close()") + } + + err = rows.Err() + if err != nil { + dbt.Fatal(err) + } + }) +} + +// dangling statements +// http://code.google.com/p/go/issues/detail?id=3865 +func TestCloseStmtBeforeRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare("SELECT 1") + if err != nil { + dbt.Fatal(err) + } + + rows, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows.Close() + + err = stmt.Close() + if err != nil { + dbt.Fatal(err) + } + + if !rows.Next() { + dbt.Fatal("getting row failed") + } else { + err = rows.Err() + if err != nil { + dbt.Fatal(err) + } + + var out bool + err = rows.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + }) +} + +// It is valid to have multiple Rows for the same Stmt +// http://code.google.com/p/go/issues/detail?id=3734 +func TestStmtMultiRows(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + stmt, err := dbt.db.Prepare("SELECT 1 UNION SELECT 0") + if err != nil { + dbt.Fatal(err) + } + + rows1, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows1.Close() + + rows2, err := stmt.Query() + if err != nil { + stmt.Close() + dbt.Fatal(err) + } + defer rows2.Close() + + var out bool + + // 1 + if !rows1.Next() { + dbt.Fatal("first rows1.Next failed") + } else { + err = rows1.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows1.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + + if !rows2.Next() { + dbt.Fatal("first rows2.Next failed") + } else { + err = rows2.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows2.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != true { + dbt.Errorf("true != %t", out) + } + } + + // 2 + if !rows1.Next() { + dbt.Fatal("second rows1.Next failed") + } else { + err = rows1.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows1.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != false { + dbt.Errorf("false != %t", out) + } + + if rows1.Next() { + dbt.Fatal("unexpected row on rows1") + } + err = rows1.Close() + if err != nil { + dbt.Fatal(err) + } + } + + if !rows2.Next() { + dbt.Fatal("second rows2.Next failed") + } else { + err = rows2.Err() + if err != nil { + dbt.Fatal(err) + } + + err = rows2.Scan(&out) + if err != nil { + dbt.Fatalf("error on rows.Scan(): %s", err.Error()) + } + if out != false { + dbt.Errorf("false != %t", out) + } + + if rows2.Next() { + dbt.Fatal("unexpected row on rows2") + } + err = rows2.Close() + if err != nil { + dbt.Fatal(err) + } + } + }) +} + +// Regression test for +// * more than 32 NULL parameters (issue 209) +// * more parameters than fit into the buffer (issue 201) +// * parameters * 64 > max_allowed_packet (issue 734) +func TestPreparedManyCols(t *testing.T) { + numParams := 65535 + runTests(t, dsn, func(dbt *DBTest) { + query := "SELECT ?" + strings.Repeat(",?", numParams-1) + stmt, err := dbt.db.Prepare(query) + if err != nil { + dbt.Fatal(err) + } + defer stmt.Close() + + // create more parameters than fit into the buffer + // which will take nil-values + params := make([]interface{}, numParams) + rows, err := stmt.Query(params...) + if err != nil { + dbt.Fatal(err) + } + rows.Close() + + // Create 0byte string which we can't send via STMT_LONG_DATA. + for i := 0; i < numParams; i++ { + params[i] = "" + } + rows, err = stmt.Query(params...) + if err != nil { + dbt.Fatal(err) + } + rows.Close() + }) +} + +func TestConcurrent(t *testing.T) { + if enabled, _ := readBool(os.Getenv("MYSQL_TEST_CONCURRENT")); !enabled { + t.Skip("MYSQL_TEST_CONCURRENT env var not set") + } + + runTests(t, dsn, func(dbt *DBTest) { + var max int + err := dbt.db.QueryRow("SELECT @@max_connections").Scan(&max) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + dbt.Logf("testing up to %d concurrent connections \r\n", max) + + var remaining, succeeded int32 = int32(max), 0 + + var wg sync.WaitGroup + wg.Add(max) + + var fatalError string + var once sync.Once + fatalf := func(s string, vals ...interface{}) { + once.Do(func() { + fatalError = fmt.Sprintf(s, vals...) + }) + } + + for i := 0; i < max; i++ { + go func(id int) { + defer wg.Done() + + tx, err := dbt.db.Begin() + atomic.AddInt32(&remaining, -1) + + if err != nil { + if err.Error() != "Error 1040: Too many connections" { + fatalf("error on conn %d: %s", id, err.Error()) + } + return + } + + // keep the connection busy until all connections are open + for remaining > 0 { + if _, err = tx.Exec("DO 1"); err != nil { + fatalf("error on conn %d: %s", id, err.Error()) + return + } + } + + if err = tx.Commit(); err != nil { + fatalf("error on conn %d: %s", id, err.Error()) + return + } + + // everything went fine with this connection + atomic.AddInt32(&succeeded, 1) + }(i) + } + + // wait until all conections are open + wg.Wait() + + if fatalError != "" { + dbt.Fatal(fatalError) + } + + dbt.Logf("reached %d concurrent connections\r\n", succeeded) + }) +} + +func testDialError(t *testing.T, dialErr error, expectErr error) { + RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { + return nil, dialErr + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + _, err = db.Exec("DO 1") + if err != expectErr { + t.Fatalf("was expecting %s. Got: %s", dialErr, err) + } +} + +func TestDialUnknownError(t *testing.T) { + testErr := fmt.Errorf("test") + testDialError(t, testErr, testErr) +} + +func TestDialNonRetryableNetErr(t *testing.T) { + testErr := netErrorMock{} + testDialError(t, testErr, testErr) +} + +func TestDialTemporaryNetErr(t *testing.T) { + testErr := netErrorMock{temporary: true} + testDialError(t, testErr, driver.ErrBadConn) +} + +// Tests custom dial functions +func TestCustomDial(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + // our custom dial function which justs wraps net.Dial here + RegisterDialContext("mydial", func(ctx context.Context, addr string) (net.Conn, error) { + var d net.Dialer + return d.DialContext(ctx, prot, addr) + }) + + db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@mydial(%s)/%s?timeout=30s", user, pass, addr, dbname)) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + if _, err = db.Exec("DO 1"); err != nil { + t.Fatalf("connection failed: %s", err.Error()) + } +} + +func TestSQLInjection(t *testing.T) { + createTest := func(arg string) func(dbt *DBTest) { + return func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + dbt.mustExec("INSERT INTO test VALUES (?)", 1) + + var v int + // NULL can't be equal to anything, the idea here is to inject query so it returns row + // This test verifies that escapeQuotes and escapeBackslash are working properly + err := dbt.db.QueryRow("SELECT v FROM test WHERE NULL = ?", arg).Scan(&v) + if err == sql.ErrNoRows { + return // success, sql injection failed + } else if err == nil { + dbt.Errorf("sql injection successful with arg: %s", arg) + } else { + dbt.Errorf("error running query with arg: %s; err: %s", arg, err.Error()) + } + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", + } + for _, testdsn := range dsns { + runTests(t, testdsn, createTest("1 OR 1=1")) + runTests(t, testdsn, createTest("' OR '1'='1")) + } +} + +// Test if inserted data is correctly retrieved after being escaped +func TestInsertRetrieveEscapedData(t *testing.T) { + testData := func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v VARCHAR(255))") + + // All sequences that are escaped by escapeQuotes and escapeBackslash + v := "foo \x00\n\r\x1a\"'\\" + dbt.mustExec("INSERT INTO test VALUES (?)", v) + + var out string + err := dbt.db.QueryRow("SELECT v FROM test").Scan(&out) + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + if out != v { + dbt.Errorf("%q != %q", out, v) + } + } + + dsns := []string{ + dsn, + dsn + "&sql_mode='NO_BACKSLASH_ESCAPES'", + } + for _, testdsn := range dsns { + runTests(t, testdsn, testData) + } +} + +func TestUnixSocketAuthFail(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Save the current logger so we can restore it. + oldLogger := errLog + + // Set a new logger so we can capture its output. + buffer := bytes.NewBuffer(make([]byte, 0, 64)) + newLogger := log.New(buffer, "prefix: ", 0) + SetLogger(newLogger) + + // Restore the logger. + defer SetLogger(oldLogger) + + // Make a new DSN that uses the MySQL socket file and a bad password, which + // we can make by simply appending any character to the real password. + badPass := pass + "x" + socket := "" + if prot == "unix" { + socket = addr + } else { + // Get socket file from MySQL. + err := dbt.db.QueryRow("SELECT @@socket").Scan(&socket) + if err != nil { + t.Fatalf("error on SELECT @@socket: %s", err.Error()) + } + } + t.Logf("socket: %s", socket) + badDSN := fmt.Sprintf("%s:%s@unix(%s)/%s?timeout=30s", user, badPass, socket, dbname) + db, err := sql.Open("mysql", badDSN) + if err != nil { + t.Fatalf("error connecting: %s", err.Error()) + } + defer db.Close() + + // Connect to MySQL for real. This will cause an auth failure. + err = db.Ping() + if err == nil { + t.Error("expected Ping() to return an error") + } + + // The driver should not log anything. + if actual := buffer.String(); actual != "" { + t.Errorf("expected no output, got %q", actual) + } + }) +} + +// See Issue #422 +func TestInterruptBySignal(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + dbt.mustExec(` + DROP PROCEDURE IF EXISTS test_signal; + CREATE PROCEDURE test_signal(ret INT) + BEGIN + SELECT ret; + SIGNAL SQLSTATE + '45001' + SET + MESSAGE_TEXT = "an error", + MYSQL_ERRNO = 45001; + END + `) + defer dbt.mustExec("DROP PROCEDURE test_signal") + + var val int + + // text protocol + rows, err := dbt.db.Query("CALL test_signal(42)") + if err != nil { + dbt.Fatalf("error on text query: %s", err.Error()) + } + for rows.Next() { + if err := rows.Scan(&val); err != nil { + dbt.Error(err) + } else if val != 42 { + dbt.Errorf("expected val to be 42") + } + } + rows.Close() + + // binary protocol + rows, err = dbt.db.Query("CALL test_signal(?)", 42) + if err != nil { + dbt.Fatalf("error on binary query: %s", err.Error()) + } + for rows.Next() { + if err := rows.Scan(&val); err != nil { + dbt.Error(err) + } else if val != 42 { + dbt.Errorf("expected val to be 42") + } + } + rows.Close() + }) +} + +func TestColumnsReusesSlice(t *testing.T) { + rows := mysqlRows{ + rs: resultSet{ + columns: []mysqlField{ + { + tableName: "test", + name: "A", + }, + { + tableName: "test", + name: "B", + }, + }, + }, + } + + allocs := testing.AllocsPerRun(1, func() { + cols := rows.Columns() + + if len(cols) != 2 { + t.Fatalf("expected 2 columns, got %d", len(cols)) + } + }) + + if allocs != 0 { + t.Fatalf("expected 0 allocations, got %d", int(allocs)) + } + + if rows.rs.columnNames == nil { + t.Fatalf("expected columnNames to be set, got nil") + } +} + +func TestRejectReadOnly(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + // Set the session to read-only. We didn't set the `rejectReadOnly` + // option, so any writes after this should fail. + _, err := dbt.db.Exec("SET SESSION TRANSACTION READ ONLY") + // Error 1193: Unknown system variable 'TRANSACTION' => skip test, + // MySQL server version is too old + maybeSkip(t, err, 1193) + if _, err := dbt.db.Exec("DROP TABLE test"); err == nil { + t.Fatalf("writing to DB in read-only session without " + + "rejectReadOnly did not error") + } + // Set the session back to read-write so runTests() can properly clean + // up the table `test`. + dbt.mustExec("SET SESSION TRANSACTION READ WRITE") + }) + + // Enable the `rejectReadOnly` option. + runTests(t, dsn+"&rejectReadOnly=true", func(dbt *DBTest) { + // Create Table + dbt.mustExec("CREATE TABLE test (value BOOL)") + // Set the session to read only. Any writes after this should error on + // a driver.ErrBadConn, and cause `database/sql` to initiate a new + // connection. + dbt.mustExec("SET SESSION TRANSACTION READ ONLY") + // This would error, but `database/sql` should automatically retry on a + // new connection which is not read-only, and eventually succeed. + dbt.mustExec("DROP TABLE test") + }) +} + +func TestPing(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + if err := dbt.db.Ping(); err != nil { + dbt.fail("Ping", "Ping", err) + } + }) +} + +// See Issue #799 +func TestEmptyPassword(t *testing.T) { + if !available { + t.Skipf("MySQL server not running on %s", netAddr) + } + + dsn := fmt.Sprintf("%s:%s@%s/%s?timeout=30s", user, "", netAddr, dbname) + db, err := sql.Open("mysql", dsn) + if err == nil { + defer db.Close() + err = db.Ping() + } + + if pass == "" { + if err != nil { + t.Fatal(err.Error()) + } + } else { + if err == nil { + t.Fatal("expected authentication error") + } + if !strings.HasPrefix(err.Error(), "Error 1045") { + t.Fatal(err.Error()) + } + } +} + +// static interface implementation checks of mysqlConn +var ( + _ driver.ConnBeginTx = &mysqlConn{} + _ driver.ConnPrepareContext = &mysqlConn{} + _ driver.ExecerContext = &mysqlConn{} + _ driver.Pinger = &mysqlConn{} + _ driver.QueryerContext = &mysqlConn{} +) + +// static interface implementation checks of mysqlStmt +var ( + _ driver.StmtExecContext = &mysqlStmt{} + _ driver.StmtQueryContext = &mysqlStmt{} +) + +// Ensure that all the driver interfaces are implemented +var ( + // _ driver.RowsColumnTypeLength = &binaryRows{} + // _ driver.RowsColumnTypeLength = &textRows{} + _ driver.RowsColumnTypeDatabaseTypeName = &binaryRows{} + _ driver.RowsColumnTypeDatabaseTypeName = &textRows{} + _ driver.RowsColumnTypeNullable = &binaryRows{} + _ driver.RowsColumnTypeNullable = &textRows{} + _ driver.RowsColumnTypePrecisionScale = &binaryRows{} + _ driver.RowsColumnTypePrecisionScale = &textRows{} + _ driver.RowsColumnTypeScanType = &binaryRows{} + _ driver.RowsColumnTypeScanType = &textRows{} + _ driver.RowsNextResultSet = &binaryRows{} + _ driver.RowsNextResultSet = &textRows{} +) + +func TestMultiResultSet(t *testing.T) { + type result struct { + values [][]int + columns []string + } + + // checkRows is a helper test function to validate rows containing 3 result + // sets with specific values and columns. The basic query would look like this: + // + // SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + // SELECT 0 UNION SELECT 1; + // SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + // + // to distinguish test cases the first string argument is put in front of + // every error or fatal message. + checkRows := func(desc string, rows *sql.Rows, dbt *DBTest) { + expected := []result{ + { + values: [][]int{{1, 2}, {3, 4}}, + columns: []string{"col1", "col2"}, + }, + { + values: [][]int{{1, 2, 3}, {4, 5, 6}}, + columns: []string{"col1", "col2", "col3"}, + }, + } + + var res1 result + for rows.Next() { + var res [2]int + if err := rows.Scan(&res[0], &res[1]); err != nil { + dbt.Fatal(err) + } + res1.values = append(res1.values, res[:]) + } + + cols, err := rows.Columns() + if err != nil { + dbt.Fatal(desc, err) + } + res1.columns = cols + + if !reflect.DeepEqual(expected[0], res1) { + dbt.Error(desc, "want =", expected[0], "got =", res1) + } + + if !rows.NextResultSet() { + dbt.Fatal(desc, "expected next result set") + } + + // ignoring one result set + + if !rows.NextResultSet() { + dbt.Fatal(desc, "expected next result set") + } + + var res2 result + cols, err = rows.Columns() + if err != nil { + dbt.Fatal(desc, err) + } + res2.columns = cols + + for rows.Next() { + var res [3]int + if err := rows.Scan(&res[0], &res[1], &res[2]); err != nil { + dbt.Fatal(desc, err) + } + res2.values = append(res2.values, res[:]) + } + + if !reflect.DeepEqual(expected[1], res2) { + dbt.Error(desc, "want =", expected[1], "got =", res2) + } + + if rows.NextResultSet() { + dbt.Error(desc, "unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error(desc, err) + } + } + + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery(`DO 1; + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + DO 1; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6;`) + defer rows.Close() + checkRows("query: ", rows, dbt) + }) + + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + queries := []string{ + ` + DROP PROCEDURE IF EXISTS test_mrss; + CREATE PROCEDURE test_mrss() + BEGIN + DO 1; + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + DO 1; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + END + `, + ` + DROP PROCEDURE IF EXISTS test_mrss; + CREATE PROCEDURE test_mrss() + BEGIN + SELECT 1 AS col1, 2 AS col2 UNION SELECT 3, 4; + SELECT 0 UNION SELECT 1; + SELECT 1 AS col1, 2 AS col2, 3 AS col3 UNION SELECT 4, 5, 6; + END + `, + } + + defer dbt.mustExec("DROP PROCEDURE IF EXISTS test_mrss") + + for i, query := range queries { + dbt.mustExec(query) + + stmt, err := dbt.db.Prepare("CALL test_mrss()") + if err != nil { + dbt.Fatalf("%v (i=%d)", err, i) + } + defer stmt.Close() + + for j := 0; j < 2; j++ { + rows, err := stmt.Query() + if err != nil { + dbt.Fatalf("%v (i=%d) (j=%d)", err, i, j) + } + checkRows(fmt.Sprintf("prepared stmt query (i=%d) (j=%d): ", i, j), rows, dbt) + } + } + }) +} + +func TestMultiResultSetNoSelect(t *testing.T) { + runTestsWithMultiStatement(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery("DO 1; DO 2;") + defer rows.Close() + + if rows.Next() { + dbt.Error("unexpected row") + } + + if rows.NextResultSet() { + dbt.Error("unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error("expected nil; got ", err) + } + }) +} + +// tests if rows are set in a proper state if some results were ignored before +// calling rows.NextResultSet. +func TestSkipResults(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + rows := dbt.mustQuery("SELECT 1, 2") + defer rows.Close() + + if !rows.Next() { + dbt.Error("expected row") + } + + if rows.NextResultSet() { + dbt.Error("unexpected next result set") + } + + if err := rows.Err(); err != nil { + dbt.Error("expected nil; got ", err) + } + }) +} + +func TestPingContext(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if err := dbt.db.PingContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelExec(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query to be done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + + // Context is already canceled, so error should come before execution. + if _, err := dbt.db.ExecContext(ctx, "INSERT INTO test VALUES (1)"); err == nil { + dbt.Error("expected error") + } else if err.Error() != "context canceled" { + dbt.Fatalf("unexpected error: %s", err) + } + + // The second insert query will fail, so the table has no changes. + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query to be done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + + // Context is already canceled, so error should come before execution. + if _, err := dbt.db.QueryContext(ctx, "INSERT INTO test VALUES (1)"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + + // The second insert query will fail, so the table has no changes. + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelQueryRow(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + dbt.mustExec("INSERT INTO test VALUES (1), (2), (3)") + ctx, cancel := context.WithCancel(context.Background()) + + rows, err := dbt.db.QueryContext(ctx, "SELECT v FROM test") + if err != nil { + dbt.Fatalf("%s", err.Error()) + } + + // the first row will be succeed. + var v int + if !rows.Next() { + dbt.Fatalf("unexpected end") + } + if err := rows.Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + + cancel() + // make sure the driver receives the cancel request. + time.Sleep(100 * time.Millisecond) + + if rows.Next() { + dbt.Errorf("expected end, but not") + } + if err := rows.Err(); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelPrepare(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + if _, err := dbt.db.PrepareContext(ctx, "SELECT 1"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextCancelStmtExec(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + if err != nil { + dbt.Fatalf("unexpected error: %v", err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := stmt.ExecContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query to be done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelStmtQuery(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + stmt, err := dbt.db.PrepareContext(ctx, "INSERT INTO test VALUES (SLEEP(1))") + if err != nil { + dbt.Fatalf("unexpected error: %v", err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(250*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := stmt.QueryContext(ctx); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Wait for the INSERT query has done. + time.Sleep(time.Second) + + // Check how many times the query is executed. + var v int + if err := dbt.db.QueryRow("SELECT COUNT(*) FROM test").Scan(&v); err != nil { + dbt.Fatalf("%s", err.Error()) + } + if v != 1 { // TODO: need to kill the query, and v should be 0. + dbt.Skipf("[WARN] expected val to be 1, got %d", v) + } + }) +} + +func TestContextCancelBegin(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + tx, err := dbt.db.BeginTx(ctx, nil) + if err != nil { + dbt.Fatal(err) + } + + // Delay execution for just a bit until db.ExecContext has begun. + defer time.AfterFunc(100*time.Millisecond, cancel).Stop() + + // This query will be canceled. + startTime := time.Now() + if _, err := tx.ExecContext(ctx, "INSERT INTO test VALUES (SLEEP(1))"); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + if d := time.Since(startTime); d > 500*time.Millisecond { + dbt.Errorf("too long execution time: %s", d) + } + + // Transaction is canceled, so expect an error. + switch err := tx.Commit(); err { + case sql.ErrTxDone: + // because the transaction has already been rollbacked. + // the database/sql package watches ctx + // and rollbacks when ctx is canceled. + case context.Canceled: + // the database/sql package rollbacks on another goroutine, + // so the transaction may not be rollbacked depending on goroutine scheduling. + default: + dbt.Errorf("expected sql.ErrTxDone or context.Canceled, got %v", err) + } + + // Context is canceled, so cannot begin a transaction. + if _, err := dbt.db.BeginTx(ctx, nil); err != context.Canceled { + dbt.Errorf("expected context.Canceled, got %v", err) + } + }) +} + +func TestContextBeginIsolationLevel(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx1, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelRepeatableRead, + }) + if err != nil { + dbt.Fatal(err) + } + + tx2, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ + Isolation: sql.LevelReadCommitted, + }) + if err != nil { + dbt.Fatal(err) + } + + _, err = tx1.ExecContext(ctx, "INSERT INTO test VALUES (1)") + if err != nil { + dbt.Fatal(err) + } + + var v int + row := tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + if err := row.Scan(&v); err != nil { + dbt.Fatal(err) + } + // Because writer transaction wasn't commited yet, it should be available + if v != 0 { + dbt.Errorf("expected val to be 0, got %d", v) + } + + err = tx1.Commit() + if err != nil { + dbt.Fatal(err) + } + + row = tx2.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + if err := row.Scan(&v); err != nil { + dbt.Fatal(err) + } + // Data written by writer transaction is already commited, it should be selectable + if v != 1 { + dbt.Errorf("expected val to be 1, got %d", v) + } + tx2.Commit() + }) +} + +func TestContextBeginReadOnly(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (v INTEGER)") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + tx, err := dbt.db.BeginTx(ctx, &sql.TxOptions{ + ReadOnly: true, + }) + if _, ok := err.(*MySQLError); ok { + dbt.Skip("It seems that your MySQL does not support READ ONLY transactions") + return + } else if err != nil { + dbt.Fatal(err) + } + + // INSERT queries fail in a READ ONLY transaction. + _, err = tx.ExecContext(ctx, "INSERT INTO test VALUES (1)") + if _, ok := err.(*MySQLError); !ok { + dbt.Errorf("expected MySQLError, got %v", err) + } + + // SELECT queries can be executed. + var v int + row := tx.QueryRowContext(ctx, "SELECT COUNT(*) FROM test") + if err := row.Scan(&v); err != nil { + dbt.Fatal(err) + } + if v != 0 { + dbt.Errorf("expected val to be 0, got %d", v) + } + + if err := tx.Commit(); err != nil { + dbt.Fatal(err) + } + }) +} + +func TestRowsColumnTypes(t *testing.T) { + niNULL := sql.NullInt64{Int64: 0, Valid: false} + ni0 := sql.NullInt64{Int64: 0, Valid: true} + ni1 := sql.NullInt64{Int64: 1, Valid: true} + ni42 := sql.NullInt64{Int64: 42, Valid: true} + nfNULL := sql.NullFloat64{Float64: 0.0, Valid: false} + nf0 := sql.NullFloat64{Float64: 0.0, Valid: true} + nf1337 := sql.NullFloat64{Float64: 13.37, Valid: true} + nt0 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 0, time.UTC), Valid: true} + nt1 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 100000000, time.UTC), Valid: true} + nt2 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 110000000, time.UTC), Valid: true} + nt6 := NullTime{Time: time.Date(2006, 01, 02, 15, 04, 05, 111111000, time.UTC), Valid: true} + nd1 := NullTime{Time: time.Date(2006, 01, 02, 0, 0, 0, 0, time.UTC), Valid: true} + nd2 := NullTime{Time: time.Date(2006, 03, 04, 0, 0, 0, 0, time.UTC), Valid: true} + ndNULL := NullTime{Time: time.Time{}, Valid: false} + rbNULL := sql.RawBytes(nil) + rb0 := sql.RawBytes("0") + rb42 := sql.RawBytes("42") + rbTest := sql.RawBytes("Test") + rb0pad4 := sql.RawBytes("0\x00\x00\x00") // BINARY right-pads values with 0x00 + rbx0 := sql.RawBytes("\x00") + rbx42 := sql.RawBytes("\x42") + + var columns = []struct { + name string + fieldType string // type used when creating table schema + databaseTypeName string // actual type used by MySQL + scanType reflect.Type + nullable bool + precision int64 // 0 if not ok + scale int64 + valuesIn [3]string + valuesOut [3]interface{} + }{ + {"bit8null", "BIT(8)", "BIT", scanTypeRawBytes, true, 0, 0, [3]string{"0x0", "NULL", "0x42"}, [3]interface{}{rbx0, rbNULL, rbx42}}, + {"boolnull", "BOOL", "TINYINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "true", "0"}, [3]interface{}{niNULL, ni1, ni0}}, + {"bool", "BOOL NOT NULL", "TINYINT", scanTypeInt8, false, 0, 0, [3]string{"1", "0", "FALSE"}, [3]interface{}{int8(1), int8(0), int8(0)}}, + {"intnull", "INTEGER", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"smallint", "SMALLINT NOT NULL", "SMALLINT", scanTypeInt16, false, 0, 0, [3]string{"0", "-32768", "32767"}, [3]interface{}{int16(0), int16(-32768), int16(32767)}}, + {"smallintnull", "SMALLINT", "SMALLINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"int3null", "INT(3)", "INT", scanTypeNullInt, true, 0, 0, [3]string{"0", "NULL", "42"}, [3]interface{}{ni0, niNULL, ni42}}, + {"int7", "INT(7) NOT NULL", "INT", scanTypeInt32, false, 0, 0, [3]string{"0", "-1337", "42"}, [3]interface{}{int32(0), int32(-1337), int32(42)}}, + {"mediumintnull", "MEDIUMINT", "MEDIUMINT", scanTypeNullInt, true, 0, 0, [3]string{"0", "42", "NULL"}, [3]interface{}{ni0, ni42, niNULL}}, + {"bigint", "BIGINT NOT NULL", "BIGINT", scanTypeInt64, false, 0, 0, [3]string{"0", "65535", "-42"}, [3]interface{}{int64(0), int64(65535), int64(-42)}}, + {"bigintnull", "BIGINT", "BIGINT", scanTypeNullInt, true, 0, 0, [3]string{"NULL", "1", "42"}, [3]interface{}{niNULL, ni1, ni42}}, + {"tinyuint", "TINYINT UNSIGNED NOT NULL", "TINYINT", scanTypeUint8, false, 0, 0, [3]string{"0", "255", "42"}, [3]interface{}{uint8(0), uint8(255), uint8(42)}}, + {"smalluint", "SMALLINT UNSIGNED NOT NULL", "SMALLINT", scanTypeUint16, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint16(0), uint16(65535), uint16(42)}}, + {"biguint", "BIGINT UNSIGNED NOT NULL", "BIGINT", scanTypeUint64, false, 0, 0, [3]string{"0", "65535", "42"}, [3]interface{}{uint64(0), uint64(65535), uint64(42)}}, + {"uint13", "INT(13) UNSIGNED NOT NULL", "INT", scanTypeUint32, false, 0, 0, [3]string{"0", "1337", "42"}, [3]interface{}{uint32(0), uint32(1337), uint32(42)}}, + {"float", "FLOAT NOT NULL", "FLOAT", scanTypeFloat32, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float32(0), float32(42), float32(13.37)}}, + {"floatnull", "FLOAT", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"float74null", "FLOAT(7,4)", "FLOAT", scanTypeNullFloat, true, math.MaxInt64, 4, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"double", "DOUBLE NOT NULL", "DOUBLE", scanTypeFloat64, false, math.MaxInt64, math.MaxInt64, [3]string{"0", "42", "13.37"}, [3]interface{}{float64(0), float64(42), float64(13.37)}}, + {"doublenull", "DOUBLE", "DOUBLE", scanTypeNullFloat, true, math.MaxInt64, math.MaxInt64, [3]string{"0", "NULL", "13.37"}, [3]interface{}{nf0, nfNULL, nf1337}}, + {"decimal1", "DECIMAL(10,6) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 10, 6, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), sql.RawBytes("13.370000"), sql.RawBytes("1234.123456")}}, + {"decimal1null", "DECIMAL(10,6)", "DECIMAL", scanTypeRawBytes, true, 10, 6, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.000000"), rbNULL, sql.RawBytes("1234.123456")}}, + {"decimal2", "DECIMAL(8,4) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 8, 4, [3]string{"0", "13.37", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), sql.RawBytes("13.3700"), sql.RawBytes("1234.1235")}}, + {"decimal2null", "DECIMAL(8,4)", "DECIMAL", scanTypeRawBytes, true, 8, 4, [3]string{"0", "NULL", "1234.123456"}, [3]interface{}{sql.RawBytes("0.0000"), rbNULL, sql.RawBytes("1234.1235")}}, + {"decimal3", "DECIMAL(5,0) NOT NULL", "DECIMAL", scanTypeRawBytes, false, 5, 0, [3]string{"0", "13.37", "-12345.123456"}, [3]interface{}{rb0, sql.RawBytes("13"), sql.RawBytes("-12345")}}, + {"decimal3null", "DECIMAL(5,0)", "DECIMAL", scanTypeRawBytes, true, 5, 0, [3]string{"0", "NULL", "-12345.123456"}, [3]interface{}{rb0, rbNULL, sql.RawBytes("-12345")}}, + {"char25null", "CHAR(25)", "CHAR", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"varchar42", "VARCHAR(42) NOT NULL", "VARCHAR", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"binary4null", "BINARY(4)", "BINARY", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0pad4, rbNULL, rbTest}}, + {"varbinary42", "VARBINARY(42) NOT NULL", "VARBINARY", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"tinyblobnull", "TINYBLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"tinytextnull", "TINYTEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"blobnull", "BLOB", "BLOB", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"textnull", "TEXT", "TEXT", scanTypeRawBytes, true, 0, 0, [3]string{"0", "NULL", "'Test'"}, [3]interface{}{rb0, rbNULL, rbTest}}, + {"mediumblob", "MEDIUMBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"mediumtext", "MEDIUMTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"longblob", "LONGBLOB NOT NULL", "BLOB", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"longtext", "LONGTEXT NOT NULL", "TEXT", scanTypeRawBytes, false, 0, 0, [3]string{"0", "'Test'", "42"}, [3]interface{}{rb0, rbTest, rb42}}, + {"datetime", "DATETIME", "DATETIME", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt0, nt0}}, + {"datetime2", "DATETIME(2)", "DATETIME", scanTypeNullTime, true, 2, 2, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt2}}, + {"datetime6", "DATETIME(6)", "DATETIME", scanTypeNullTime, true, 6, 6, [3]string{"'2006-01-02 15:04:05'", "'2006-01-02 15:04:05.1'", "'2006-01-02 15:04:05.111111'"}, [3]interface{}{nt0, nt1, nt6}}, + {"date", "DATE", "DATE", scanTypeNullTime, true, 0, 0, [3]string{"'2006-01-02'", "NULL", "'2006-03-04'"}, [3]interface{}{nd1, ndNULL, nd2}}, + {"year", "YEAR NOT NULL", "YEAR", scanTypeUint16, false, 0, 0, [3]string{"2006", "2000", "1994"}, [3]interface{}{uint16(2006), uint16(2000), uint16(1994)}}, + } + + schema := "" + values1 := "" + values2 := "" + values3 := "" + for _, column := range columns { + schema += fmt.Sprintf("`%s` %s, ", column.name, column.fieldType) + values1 += column.valuesIn[0] + ", " + values2 += column.valuesIn[1] + ", " + values3 += column.valuesIn[2] + ", " + } + schema = schema[:len(schema)-2] + values1 = values1[:len(values1)-2] + values2 = values2[:len(values2)-2] + values3 = values3[:len(values3)-2] + + dsns := []string{ + dsn + "&parseTime=true", + dsn + "&parseTime=false", + } + for _, testdsn := range dsns { + runTests(t, testdsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (" + schema + ")") + dbt.mustExec("INSERT INTO test VALUES (" + values1 + "), (" + values2 + "), (" + values3 + ")") + + rows, err := dbt.db.Query("SELECT * FROM test") + if err != nil { + t.Fatalf("Query: %v", err) + } + + tt, err := rows.ColumnTypes() + if err != nil { + t.Fatalf("ColumnTypes: %v", err) + } + + if len(tt) != len(columns) { + t.Fatalf("unexpected number of columns: expected %d, got %d", len(columns), len(tt)) + } + + types := make([]reflect.Type, len(tt)) + for i, tp := range tt { + column := columns[i] + + // Name + name := tp.Name() + if name != column.name { + t.Errorf("column name mismatch %s != %s", name, column.name) + continue + } + + // DatabaseTypeName + databaseTypeName := tp.DatabaseTypeName() + if databaseTypeName != column.databaseTypeName { + t.Errorf("databasetypename name mismatch for column %q: %s != %s", name, databaseTypeName, column.databaseTypeName) + continue + } + + // ScanType + scanType := tp.ScanType() + if scanType != column.scanType { + if scanType == nil { + t.Errorf("scantype is null for column %q", name) + } else { + t.Errorf("scantype mismatch for column %q: %s != %s", name, scanType.Name(), column.scanType.Name()) + } + continue + } + types[i] = scanType + + // Nullable + nullable, ok := tp.Nullable() + if !ok { + t.Errorf("nullable not ok %q", name) + continue + } + if nullable != column.nullable { + t.Errorf("nullable mismatch for column %q: %t != %t", name, nullable, column.nullable) + } + + // Length + // length, ok := tp.Length() + // if length != column.length { + // if !ok { + // t.Errorf("length not ok for column %q", name) + // } else { + // t.Errorf("length mismatch for column %q: %d != %d", name, length, column.length) + // } + // continue + // } + + // Precision and Scale + precision, scale, ok := tp.DecimalSize() + if precision != column.precision { + if !ok { + t.Errorf("precision not ok for column %q", name) + } else { + t.Errorf("precision mismatch for column %q: %d != %d", name, precision, column.precision) + } + continue + } + if scale != column.scale { + if !ok { + t.Errorf("scale not ok for column %q", name) + } else { + t.Errorf("scale mismatch for column %q: %d != %d", name, scale, column.scale) + } + continue + } + } + + values := make([]interface{}, len(tt)) + for i := range values { + values[i] = reflect.New(types[i]).Interface() + } + i := 0 + for rows.Next() { + err = rows.Scan(values...) + if err != nil { + t.Fatalf("failed to scan values in %v", err) + } + for j := range values { + value := reflect.ValueOf(values[j]).Elem().Interface() + if !reflect.DeepEqual(value, columns[j].valuesOut[i]) { + if columns[j].scanType == scanTypeRawBytes { + t.Errorf("row %d, column %d: %v != %v", i, j, string(value.(sql.RawBytes)), string(columns[j].valuesOut[i].(sql.RawBytes))) + } else { + t.Errorf("row %d, column %d: %v != %v", i, j, value, columns[j].valuesOut[i]) + } + } + } + i++ + } + if i != 3 { + t.Errorf("expected 3 rows, got %d", i) + } + + if err := rows.Close(); err != nil { + t.Errorf("error closing rows: %s", err) + } + }) + } +} + +func TestValuerWithValueReceiverGivenNilValue(t *testing.T) { + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (value VARCHAR(255))") + dbt.db.Exec("INSERT INTO test VALUES (?)", (*testValuer)(nil)) + // This test will panic on the INSERT if ConvertValue() does not check for typed nil before calling Value() + }) +} + +// TestRawBytesAreNotModified checks for a race condition that arises when a query context +// is canceled while a user is calling rows.Scan. This is a more stringent test than the one +// proposed in https://github.com/golang/go/issues/23519. Here we're explicitly using +// `sql.RawBytes` to check the contents of our internal buffers are not modified after an implicit +// call to `Rows.Close`, so Context cancellation should **not** invalidate the backing buffers. +func TestRawBytesAreNotModified(t *testing.T) { + const blob = "abcdefghijklmnop" + const contextRaceIterations = 20 + const blobSize = defaultBufSize * 3 / 4 // Second row overwrites first row. + const insertRows = 4 + + var sqlBlobs = [2]string{ + strings.Repeat(blob, blobSize/len(blob)), + strings.Repeat(strings.ToUpper(blob), blobSize/len(blob)), + } + + runTests(t, dsn, func(dbt *DBTest) { + dbt.mustExec("CREATE TABLE test (id int, value BLOB) CHARACTER SET utf8") + for i := 0; i < insertRows; i++ { + dbt.mustExec("INSERT INTO test VALUES (?, ?)", i+1, sqlBlobs[i&1]) + } + + for i := 0; i < contextRaceIterations; i++ { + func() { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + rows, err := dbt.db.QueryContext(ctx, `SELECT id, value FROM test`) + if err != nil { + t.Fatal(err) + } + + var b int + var raw sql.RawBytes + for rows.Next() { + if err := rows.Scan(&b, &raw); err != nil { + t.Fatal(err) + } + + before := string(raw) + // Ensure cancelling the query does not corrupt the contents of `raw` + cancel() + time.Sleep(time.Microsecond * 100) + after := string(raw) + + if before != after { + t.Fatalf("the backing storage for sql.RawBytes has been modified (i=%v)", i) + } + } + rows.Close() + }() + } + }) +} diff --git a/vendor/github.com/go-sql-driver/mysql/dsn.go b/vendor/github.com/go-sql-driver/mysql/dsn.go index be014bab..1d9b4ab0 100644 --- a/vendor/github.com/go-sql-driver/mysql/dsn.go +++ b/vendor/github.com/go-sql-driver/mysql/dsn.go @@ -14,6 +14,7 @@ import ( "crypto/tls" "errors" "fmt" + "math/big" "net" "net/url" "sort" @@ -72,6 +73,26 @@ func NewConfig() *Config { } } +func (cfg *Config) Clone() *Config { + cp := *cfg + if cp.tls != nil { + cp.tls = cfg.tls.Clone() + } + if len(cp.Params) > 0 { + cp.Params = make(map[string]string, len(cfg.Params)) + for k, v := range cfg.Params { + cp.Params[k] = v + } + } + if cfg.pubKey != nil { + cp.pubKey = &rsa.PublicKey{ + N: new(big.Int).Set(cfg.pubKey.N), + E: cfg.pubKey.E, + } + } + return &cp +} + func (cfg *Config) normalize() error { if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { return errInvalidDSNUnsafeCollation @@ -92,17 +113,35 @@ func (cfg *Config) normalize() error { default: return errors.New("default addr for network '" + cfg.Net + "' unknown") } - } else if cfg.Net == "tcp" { cfg.Addr = ensureHavePort(cfg.Addr) } - if cfg.tls != nil { - if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { - host, _, err := net.SplitHostPort(cfg.Addr) - if err == nil { - cfg.tls.ServerName = host - } + switch cfg.TLSConfig { + case "false", "": + // don't set anything + case "true": + cfg.tls = &tls.Config{} + case "skip-verify", "preferred": + cfg.tls = &tls.Config{InsecureSkipVerify: true} + default: + cfg.tls = getTLSConfigClone(cfg.TLSConfig) + if cfg.tls == nil { + return errors.New("invalid value / unknown config name: " + cfg.TLSConfig) + } + } + + if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + cfg.tls.ServerName = host + } + } + + if cfg.ServerPubKey != "" { + cfg.pubKey = getServerPubKey(cfg.ServerPubKey) + if cfg.pubKey == nil { + return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey) } } @@ -531,13 +570,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { if err != nil { return fmt.Errorf("invalid value for server pub key name: %v", err) } - - if pubKey := getServerPubKey(name); pubKey != nil { - cfg.ServerPubKey = name - cfg.pubKey = pubKey - } else { - return errors.New("invalid value / unknown server pub key name: " + name) - } + cfg.ServerPubKey = name // Strict mode case "strict": @@ -556,25 +589,17 @@ func parseDSNParams(cfg *Config, params string) (err error) { if isBool { if boolValue { cfg.TLSConfig = "true" - cfg.tls = &tls.Config{} } else { cfg.TLSConfig = "false" } - } else if vl := strings.ToLower(value); vl == "skip-verify" { + } else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" { cfg.TLSConfig = vl - cfg.tls = &tls.Config{InsecureSkipVerify: true} } else { name, err := url.QueryUnescape(value) if err != nil { return fmt.Errorf("invalid value for TLS config name: %v", err) } - - if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { - cfg.TLSConfig = name - cfg.tls = tlsConfig - } else { - return errors.New("invalid value / unknown config name: " + name) - } + cfg.TLSConfig = name } // I/O write Timeout diff --git a/vendor/github.com/go-sql-driver/mysql/dsn_test.go b/vendor/github.com/go-sql-driver/mysql/dsn_test.go new file mode 100644 index 00000000..50dc2932 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/dsn_test.go @@ -0,0 +1,415 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/tls" + "fmt" + "net/url" + "reflect" + "testing" + "time" +) + +var testDSNs = []struct { + in string + out *Config +}{{ + "username:password@protocol(address)/dbname?param=value", + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true", + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true}, +}, { + "username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true", + &Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true, MultiStatements: true}, +}, { + "user@unix(/path/to/socket)/dbname?charset=utf8", + &Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "true"}, +}, { + "user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"}, +}, { + "user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true}, +}, { + "user:password@/dbname?allowNativePasswords=false&maxAllowedPacket=0", + &Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false}, +}, { + "user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local", + &Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "/dbname", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "@/", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "/", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "user:p@/ssword@/", + &Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "unix/?arg=%2Fsome%2Fpath.ext", + &Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "tcp(127.0.0.1)/dbname", + &Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, { + "tcp(de:ad:be:ef::ca:fe)/dbname", + &Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true}, +}, +} + +func TestDSNParser(t *testing.T) { + for i, tst := range testDSNs { + cfg, err := ParseDSN(tst.in) + if err != nil { + t.Error(err.Error()) + } + + // pointer not static + cfg.tls = nil + + if !reflect.DeepEqual(cfg, tst.out) { + t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out) + } + } +} + +func TestDSNParserInvalid(t *testing.T) { + var invalidDSNs = []string{ + "@net(addr/", // no closing brace + "@tcp(/", // no closing brace + "tcp(/", // no closing brace + "(/", // no closing brace + "net(addr)//", // unescaped + "User:pass@tcp(1.2.3.4:3306)", // no trailing slash + "net()/", // unknown default addr + //"/dbname?arg=/some/unescaped/path", + } + + for i, tst := range invalidDSNs { + if _, err := ParseDSN(tst); err == nil { + t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst) + } + } +} + +func TestDSNReformat(t *testing.T) { + for i, tst := range testDSNs { + dsn1 := tst.in + cfg1, err := ParseDSN(dsn1) + if err != nil { + t.Error(err.Error()) + continue + } + cfg1.tls = nil // pointer not static + res1 := fmt.Sprintf("%+v", cfg1) + + dsn2 := cfg1.FormatDSN() + cfg2, err := ParseDSN(dsn2) + if err != nil { + t.Error(err.Error()) + continue + } + cfg2.tls = nil // pointer not static + res2 := fmt.Sprintf("%+v", cfg2) + + if res1 != res2 { + t.Errorf("%d. %q does not match %q", i, res2, res1) + } + } +} + +func TestDSNServerPubKey(t *testing.T) { + baseDSN := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + + RegisterServerPubKey("testKey", testPubKeyRSA) + defer DeregisterServerPubKey("testKey") + + tst := baseDSN + "testKey" + cfg, err := ParseDSN(tst) + if err != nil { + t.Error(err.Error()) + } + + if cfg.ServerPubKey != "testKey" { + t.Errorf("unexpected cfg.ServerPubKey value: %v", cfg.ServerPubKey) + } + if cfg.pubKey != testPubKeyRSA { + t.Error("pub key pointer doesn't match") + } + + // Key is missing + tst = baseDSN + "invalid_name" + cfg, err = ParseDSN(tst) + if err == nil { + t.Errorf("invalid name in DSN (%s) but did not error. Got config: %#v", tst, cfg) + } +} + +func TestDSNServerPubKeyQueryEscape(t *testing.T) { + const name = "&%!:" + dsn := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + url.QueryEscape(name) + + RegisterServerPubKey(name, testPubKeyRSA) + defer DeregisterServerPubKey(name) + + cfg, err := ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + + if cfg.pubKey != testPubKeyRSA { + t.Error("pub key pointer doesn't match") + } +} + +func TestDSNWithCustomTLS(t *testing.T) { + baseDSN := "User:password@tcp(localhost:5555)/dbname?tls=" + tlsCfg := tls.Config{} + + RegisterTLSConfig("utils_test", &tlsCfg) + defer DeregisterTLSConfig("utils_test") + + // Custom TLS is missing + tst := baseDSN + "invalid_tls" + cfg, err := ParseDSN(tst) + if err == nil { + t.Errorf("invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg) + } + + tst = baseDSN + "utils_test" + + // Custom TLS with a server name + name := "foohost" + tlsCfg.ServerName = name + cfg, err = ParseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst) + } + + // Custom TLS without a server name + name = "localhost" + tlsCfg.ServerName = "" + cfg, err = ParseDSN(tst) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst) + } else if tlsCfg.ServerName != "" { + t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst) + } +} + +func TestDSNTLSConfig(t *testing.T) { + expectedServerName := "example.com" + dsn := "tcp(example.com:1234)/?tls=true" + + cfg, err := ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + } + + dsn = "tcp(example.com)/?tls=true" + cfg, err = ParseDSN(dsn) + if err != nil { + t.Error(err.Error()) + } + if cfg.tls == nil { + t.Error("cfg.tls should not be nil") + } + if cfg.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName) + } +} + +func TestDSNWithCustomTLSQueryEscape(t *testing.T) { + const configKey = "&%!:" + dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey) + name := "foohost" + tlsCfg := tls.Config{ServerName: name} + + RegisterTLSConfig(configKey, &tlsCfg) + defer DeregisterTLSConfig(configKey) + + cfg, err := ParseDSN(dsn) + + if err != nil { + t.Error(err.Error()) + } else if cfg.tls.ServerName != name { + t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn) + } +} + +func TestDSNUnsafeCollation(t *testing.T) { + _, err := ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true") + if err != errInvalidDSNUnsafeCollation { + t.Errorf("expected %v, got %v", errInvalidDSNUnsafeCollation, err) + } + + _, err = ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=gbk_chinese_ci") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=ascii_bin&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } + + _, err = ParseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true") + if err != nil { + t.Errorf("expected %v, got %v", nil, err) + } +} + +func TestParamsAreSorted(t *testing.T) { + expected := "/dbname?interpolateParams=true&foobar=baz&quux=loo" + cfg := NewConfig() + cfg.DBName = "dbname" + cfg.InterpolateParams = true + cfg.Params = map[string]string{ + "quux": "loo", + "foobar": "baz", + } + actual := cfg.FormatDSN() + if actual != expected { + t.Errorf("generic Config.Params were not sorted: want %#v, got %#v", expected, actual) + } +} + +func TestCloneConfig(t *testing.T) { + RegisterServerPubKey("testKey", testPubKeyRSA) + defer DeregisterServerPubKey("testKey") + + expectedServerName := "example.com" + dsn := "tcp(example.com:1234)/?tls=true&foobar=baz&serverPubKey=testKey" + cfg, err := ParseDSN(dsn) + if err != nil { + t.Fatal(err.Error()) + } + + cfg2 := cfg.Clone() + if cfg == cfg2 { + t.Errorf("Config.Clone did not create a separate config struct") + } + + if cfg2.tls.ServerName != expectedServerName { + t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName) + } + + cfg2.tls.ServerName = "example2.com" + if cfg.tls.ServerName == cfg2.tls.ServerName { + t.Errorf("changed cfg.tls.Server name should not propagate to original Config") + } + + if _, ok := cfg2.Params["foobar"]; !ok { + t.Errorf("cloned Config is missing custom params") + } + + delete(cfg2.Params, "foobar") + + if _, ok := cfg.Params["foobar"]; !ok { + t.Errorf("custom params in cloned Config should not propagate to original Config") + } + + if !reflect.DeepEqual(cfg.pubKey, cfg2.pubKey) { + t.Errorf("public key in Config should be identical") + } +} + +func TestNormalizeTLSConfig(t *testing.T) { + tt := []struct { + tlsConfig string + want *tls.Config + }{ + {"", nil}, + {"false", nil}, + {"true", &tls.Config{ServerName: "myserver"}}, + {"skip-verify", &tls.Config{InsecureSkipVerify: true}}, + {"preferred", &tls.Config{InsecureSkipVerify: true}}, + {"test_tls_config", &tls.Config{ServerName: "myServerName"}}, + } + + RegisterTLSConfig("test_tls_config", &tls.Config{ServerName: "myServerName"}) + defer func() { DeregisterTLSConfig("test_tls_config") }() + + for _, tc := range tt { + t.Run(tc.tlsConfig, func(t *testing.T) { + cfg := &Config{ + Addr: "myserver:3306", + TLSConfig: tc.tlsConfig, + } + + cfg.normalize() + + if cfg.tls == nil { + if tc.want != nil { + t.Fatal("wanted a tls config but got nil instead") + } + return + } + + if cfg.tls.ServerName != tc.want.ServerName { + t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')", + tc.want.ServerName, cfg.tls.ServerName) + } + if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify { + t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)", + tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify) + } + }) + } +} + +func BenchmarkParseDSN(b *testing.B) { + b.ReportAllocs() + + for i := 0; i < b.N; i++ { + for _, tst := range testDSNs { + if _, err := ParseDSN(tst.in); err != nil { + b.Error(err.Error()) + } + } + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/errors_test.go b/vendor/github.com/go-sql-driver/mysql/errors_test.go new file mode 100644 index 00000000..96f9126d --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/errors_test.go @@ -0,0 +1,42 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "log" + "testing" +) + +func TestErrorsSetLogger(t *testing.T) { + previous := errLog + defer func() { + errLog = previous + }() + + // set up logger + const expected = "prefix: test\n" + buffer := bytes.NewBuffer(make([]byte, 0, 64)) + logger := log.New(buffer, "prefix: ", 0) + + // print + SetLogger(logger) + errLog.Print("test") + + // check result + if actual := buffer.String(); actual != expected { + t.Errorf("expected %q, got %q", expected, actual) + } +} + +func TestErrorsStrictIgnoreNotes(t *testing.T) { + runTests(t, dsn+"&sql_notes=false", func(dbt *DBTest) { + dbt.mustExec("DROP TABLE IF EXISTS does_not_exist") + }) +} diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go index d873a97b..30b3352c 100644 --- a/vendor/github.com/go-sql-driver/mysql/packets.go +++ b/vendor/github.com/go-sql-driver/mysql/packets.go @@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { mc.sequence++ // packets with length 0 terminate a previous packet which is a - // multiple of (2^24)−1 bytes long + // multiple of (2^24)-1 bytes long if pktLen == 0 { // there was no previous packet if prevData == nil { @@ -96,6 +96,35 @@ func (mc *mysqlConn) writePacket(data []byte) error { return ErrPktTooLarge } + // Perform a stale connection check. We only perform this check for + // the first query on a connection that has been checked out of the + // connection pool: a fresh connection from the pool is more likely + // to be stale, and it has not performed any previous writes that + // could cause data corruption, so it's safe to return ErrBadConn + // if the check fails. + if mc.reset { + mc.reset = false + conn := mc.netConn + if mc.rawConn != nil { + conn = mc.rawConn + } + var err error + // If this connection has a ReadTimeout which we've been setting on + // reads, reset it to its default value before we attempt a non-blocking + // read, otherwise the scheduler will just time us out before we can read + if mc.cfg.ReadTimeout != 0 { + err = conn.SetReadDeadline(time.Time{}) + } + if err == nil { + err = connCheck(conn) + } + if err != nil { + errLog.Print("closing bad idle connection: ", err) + mc.Close() + return driver.ErrBadConn + } + } + for { var size int if pktLen >= maxPacketSize { @@ -154,15 +183,15 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { - data, err := mc.readPacket() +func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) { + data, err = mc.readPacket() if err != nil { // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since // in connection initialization we don't risk retrying non-idempotent actions. if err == ErrInvalidConn { return nil, "", driver.ErrBadConn } - return nil, "", err + return } if data[0] == iERR { @@ -194,11 +223,14 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { return nil, "", ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - return nil, "", ErrNoTLS + if mc.cfg.TLSConfig == "preferred" { + mc.cfg.tls = nil + } else { + return nil, "", ErrNoTLS + } } pos += 2 - plugin := "" if len(data) > pos { // character set [1 byte] // status flags [2 bytes] @@ -236,8 +268,6 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { return b[:], plugin, nil } - plugin = defaultAuthPlugin - // make a memory safe copy of the cipher slice var b [8]byte copy(b[:], authData) @@ -246,7 +276,7 @@ func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -272,7 +302,8 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, // encode length of the auth plugin data var authRespLEIBuf [9]byte - authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) + authRespLen := len(authResp) + authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen)) if len(authRespLEI) > 1 { // if the length can not be written in 1 byte, it must be written as a // length encoded integer @@ -280,9 +311,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, } pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 - if addNUL { - pktLen++ - } // To specify a db name if n := len(mc.cfg.DBName); n > 0 { @@ -291,10 +319,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, } // Calculate packet length and get buffer with that size - data := mc.buf.takeSmallBuffer(pktLen + 4) - if data == nil { + data, err := mc.buf.takeSmallBuffer(pktLen + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -333,6 +361,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, if err := tlsConn.Handshake(); err != nil { return err } + mc.rawConn = mc.netConn mc.netConn = tlsConn mc.buf.nc = tlsConn } @@ -353,10 +382,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, // Auth Data [length encoded integer] pos += copy(data[pos:], authRespLEI) pos += copy(data[pos:], authResp) - if addNUL { - data[pos] = 0x00 - pos++ - } // Databasename [null terminated string] if len(mc.cfg.DBName) > 0 { @@ -367,30 +392,24 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, pos += copy(data[pos:], plugin) data[pos] = 0x00 + pos++ // Send Auth packet - return mc.writePacket(data) + return mc.writePacket(data[:pos]) } // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error { +func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error { pktLen := 4 + len(authData) - if addNUL { - pktLen++ - } - data := mc.buf.takeSmallBuffer(pktLen) - if data == nil { + data, err := mc.buf.takeSmallBuffer(pktLen) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } // Add the auth data [EOF] copy(data[4:], authData) - if addNUL { - data[pktLen-1] = 0x00 - } - return mc.writePacket(data) } @@ -402,10 +421,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1) - if data == nil { + data, err := mc.buf.takeSmallBuffer(4 + 1) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -421,10 +440,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { mc.sequence = 0 pktLen := 1 + len(arg) - data := mc.buf.takeBuffer(pktLen + 4) - if data == nil { + data, err := mc.buf.takeBuffer(pktLen + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -442,10 +461,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { // Reset Packet Sequence mc.sequence = 0 - data := mc.buf.takeSmallBuffer(4 + 1 + 4) - if data == nil { + data, err := mc.buf.takeSmallBuffer(4 + 1 + 4) + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -482,7 +501,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { return data[1:], "", err case iEOF: - if len(data) < 1 { + if len(data) == 1 { // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest return nil, "mysql_old_password", nil } @@ -898,7 +917,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc - // Determine threshould dynamically to avoid packet size shortage. + // Determine threshold dynamically to avoid packet size shortage. longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) if longDataSize < 64 { longDataSize = 64 @@ -908,15 +927,17 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { mc.sequence = 0 var data []byte + var err error if len(args) == 0 { - data = mc.buf.takeBuffer(minPktLen) + data, err = mc.buf.takeBuffer(minPktLen) } else { - data = mc.buf.takeCompleteBuffer() + data, err = mc.buf.takeCompleteBuffer() + // In this case the len(data) == cap(data) which is used to optimise the flow below. } - if data == nil { + if err != nil { // cannot take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) + errLog.Print(err) return errBadConnNoWrite } @@ -942,7 +963,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { pos := minPktLen var nullMask []byte - if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) { + if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) { // buffer has to be extended but we don't know by how much so // we depend on append after all data with known sizes fit. // We stop at that because we deal with a lot of columns here @@ -951,10 +972,11 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { copy(tmp[:pos], data[:pos]) data = tmp nullMask = data[pos : pos+maskLen] + // No need to clean nullMask as make ensures that. pos += maskLen } else { nullMask = data[pos : pos+maskLen] - for i := 0; i < maskLen; i++ { + for i := range nullMask { nullMask[i] = 0 } pos += maskLen @@ -999,6 +1021,22 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { ) } + case uint64: + paramTypes[i+i] = byte(fieldTypeLongLong) + paramTypes[i+i+1] = 0x80 // type is unsigned + + if cap(paramValues)-len(paramValues)-8 >= 0 { + paramValues = paramValues[:len(paramValues)+8] + binary.LittleEndian.PutUint64( + paramValues[len(paramValues)-8:], + uint64(v), + ) + } else { + paramValues = append(paramValues, + uint64ToBytes(uint64(v))..., + ) + } + case float64: paramTypes[i+i] = byte(fieldTypeDouble) paramTypes[i+i+1] = 0x00 @@ -1091,7 +1129,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // In that case we must build the data packet with the new values buffer if valuesCap != cap(paramValues) { data = append(data[:pos], paramValues...) - mc.buf.buf = data + if err = mc.buf.store(data); err != nil { + errLog.Print(err) + return errBadConnNoWrite + } } pos += len(paramValues) @@ -1261,7 +1302,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { rows.rs.columns[i].decimals, ) } - dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) + dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen) case rows.mc.parseTime: dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) default: @@ -1281,7 +1322,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { ) } } - dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false) + dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen) } if err == nil { diff --git a/vendor/github.com/go-sql-driver/mysql/packets_test.go b/vendor/github.com/go-sql-driver/mysql/packets_test.go new file mode 100644 index 00000000..b61e4dbf --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/packets_test.go @@ -0,0 +1,336 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "errors" + "net" + "testing" + "time" +) + +var ( + errConnClosed = errors.New("connection is closed") + errConnTooManyReads = errors.New("too many reads") + errConnTooManyWrites = errors.New("too many writes") +) + +// struct to mock a net.Conn for testing purposes +type mockConn struct { + laddr net.Addr + raddr net.Addr + data []byte + written []byte + queuedReplies [][]byte + closed bool + read int + reads int + writes int + maxReads int + maxWrites int +} + +func (m *mockConn) Read(b []byte) (n int, err error) { + if m.closed { + return 0, errConnClosed + } + + m.reads++ + if m.maxReads > 0 && m.reads > m.maxReads { + return 0, errConnTooManyReads + } + + n = copy(b, m.data) + m.read += n + m.data = m.data[n:] + return +} +func (m *mockConn) Write(b []byte) (n int, err error) { + if m.closed { + return 0, errConnClosed + } + + m.writes++ + if m.maxWrites > 0 && m.writes > m.maxWrites { + return 0, errConnTooManyWrites + } + + n = len(b) + m.written = append(m.written, b...) + + if n > 0 && len(m.queuedReplies) > 0 { + m.data = m.queuedReplies[0] + m.queuedReplies = m.queuedReplies[1:] + } + return +} +func (m *mockConn) Close() error { + m.closed = true + return nil +} +func (m *mockConn) LocalAddr() net.Addr { + return m.laddr +} +func (m *mockConn) RemoteAddr() net.Addr { + return m.raddr +} +func (m *mockConn) SetDeadline(t time.Time) error { + return nil +} +func (m *mockConn) SetReadDeadline(t time.Time) error { + return nil +} +func (m *mockConn) SetWriteDeadline(t time.Time) error { + return nil +} + +// make sure mockConn implements the net.Conn interface +var _ net.Conn = new(mockConn) + +func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + cfg: NewConfig(), + netConn: conn, + closech: make(chan struct{}), + maxAllowedPacket: defaultMaxAllowedPacket, + sequence: sequence, + } + return conn, mc +} + +func TestReadPacketSingleByte(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + } + + conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} + conn.maxReads = 1 + packet, err := mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != 1 { + t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet)) + } + if packet[0] != 0xff { + t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0]) + } +} + +func TestReadPacketWrongSequenceID(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + } + + // too low sequence id + conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff} + conn.maxReads = 1 + mc.sequence = 1 + _, err := mc.readPacket() + if err != ErrPktSync { + t.Errorf("expected ErrPktSync, got %v", err) + } + + // reset + conn.reads = 0 + mc.sequence = 0 + mc.buf = newBuffer(conn) + + // too high sequence id + conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff} + _, err = mc.readPacket() + if err != ErrPktSyncMul { + t.Errorf("expected ErrPktSyncMul, got %v", err) + } +} + +func TestReadPacketSplit(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + } + + data := make([]byte, maxPacketSize*2+4*3) + const pkt2ofs = maxPacketSize + 4 + const pkt3ofs = 2 * (maxPacketSize + 4) + + // case 1: payload has length maxPacketSize + data = data[:pkt2ofs+4] + + // 1st packet has maxPacketSize length and sequence id 0 + // ff ff ff 00 ... + data[0] = 0xff + data[1] = 0xff + data[2] = 0xff + + // mark the payload start and end of 1st packet so that we can check if the + // content was correctly appended + data[4] = 0x11 + data[maxPacketSize+3] = 0x22 + + // 2nd packet has payload length 0 and squence id 1 + // 00 00 00 01 + data[pkt2ofs+3] = 0x01 + + conn.data = data + conn.maxReads = 3 + packet, err := mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != maxPacketSize { + t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(packet)) + } + if packet[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) + } + if packet[maxPacketSize-1] != 0x22 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1]) + } + + // case 2: payload has length which is a multiple of maxPacketSize + data = data[:cap(data)] + + // 2nd packet now has maxPacketSize length + data[pkt2ofs] = 0xff + data[pkt2ofs+1] = 0xff + data[pkt2ofs+2] = 0xff + + // mark the payload start and end of the 2nd packet + data[pkt2ofs+4] = 0x33 + data[pkt2ofs+maxPacketSize+3] = 0x44 + + // 3rd packet has payload length 0 and squence id 2 + // 00 00 00 02 + data[pkt3ofs+3] = 0x02 + + conn.data = data + conn.reads = 0 + conn.maxReads = 5 + mc.sequence = 0 + packet, err = mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != 2*maxPacketSize { + t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet)) + } + if packet[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) + } + if packet[2*maxPacketSize-1] != 0x44 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1]) + } + + // case 3: payload has a length larger maxPacketSize, which is not an exact + // multiple of it + data = data[:pkt2ofs+4+42] + data[pkt2ofs] = 0x2a + data[pkt2ofs+1] = 0x00 + data[pkt2ofs+2] = 0x00 + data[pkt2ofs+4+41] = 0x44 + + conn.data = data + conn.reads = 0 + conn.maxReads = 4 + mc.sequence = 0 + packet, err = mc.readPacket() + if err != nil { + t.Fatal(err) + } + if len(packet) != maxPacketSize+42 { + t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet)) + } + if packet[0] != 0x11 { + t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0]) + } + if packet[maxPacketSize+41] != 0x44 { + t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41]) + } +} + +func TestReadPacketFail(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + closech: make(chan struct{}), + } + + // illegal empty (stand-alone) packet + conn.data = []byte{0x00, 0x00, 0x00, 0x00} + conn.maxReads = 1 + _, err := mc.readPacket() + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } + + // reset + conn.reads = 0 + mc.sequence = 0 + mc.buf = newBuffer(conn) + + // fail to read header + conn.closed = true + _, err = mc.readPacket() + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } + + // reset + conn.closed = false + conn.reads = 0 + mc.sequence = 0 + mc.buf = newBuffer(conn) + + // fail to read body + conn.maxReads = 1 + _, err = mc.readPacket() + if err != ErrInvalidConn { + t.Errorf("expected ErrInvalidConn, got %v", err) + } +} + +// https://github.com/go-sql-driver/mysql/pull/801 +// not-NUL terminated plugin_name in init packet +func TestRegression801(t *testing.T) { + conn := new(mockConn) + mc := &mysqlConn{ + buf: newBuffer(conn), + cfg: new(Config), + sequence: 42, + closech: make(chan struct{}), + } + + conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0, + 60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 120, 114, 47, 85, 75, 109, 99, 51, 77, + 50, 64, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95, + 112, 97, 115, 115, 119, 111, 114, 100} + conn.maxReads = 1 + + authData, pluginName, err := mc.readHandshakePacket() + if err != nil { + t.Fatalf("got error: %v", err) + } + + if pluginName != "mysql_native_password" { + t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName) + } + + expectedAuthData := []byte{60, 70, 63, 58, 68, 104, 34, 97, 98, 120, 114, + 47, 85, 75, 109, 99, 51, 77, 50, 64} + if !bytes.Equal(authData, expectedAuthData) { + t.Errorf("expected authData '%v', got '%v'", expectedAuthData, authData) + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/rows.go b/vendor/github.com/go-sql-driver/mysql/rows.go index d3b1e282..888bdb5f 100644 --- a/vendor/github.com/go-sql-driver/mysql/rows.go +++ b/vendor/github.com/go-sql-driver/mysql/rows.go @@ -111,6 +111,13 @@ func (rows *mysqlRows) Close() (err error) { return err } + // flip the buffer for this connection if we need to drain it. + // note that for a successful query (i.e. one where rows.next() + // has been called until it returns false), `rows.mc` will be nil + // by the time the user calls `(*Rows).Close`, so we won't reach this + // see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47 + mc.buf.flip() + // Remove unread packets from stream if !rows.rs.done { err = mc.readUntilEOF() diff --git a/vendor/github.com/go-sql-driver/mysql/statement.go b/vendor/github.com/go-sql-driver/mysql/statement.go index ce7fe4cd..f7e37093 100644 --- a/vendor/github.com/go-sql-driver/mysql/statement.go +++ b/vendor/github.com/go-sql-driver/mysql/statement.go @@ -13,7 +13,6 @@ import ( "fmt" "io" "reflect" - "strconv" ) type mysqlStmt struct { @@ -164,14 +163,8 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { } case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return rv.Int(), nil - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: - return int64(rv.Uint()), nil - case reflect.Uint64: - u64 := rv.Uint() - if u64 >= 1<<63 { - return strconv.FormatUint(u64, 10), nil - } - return int64(u64), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return rv.Uint(), nil case reflect.Float32, reflect.Float64: return rv.Float(), nil case reflect.Bool: diff --git a/vendor/github.com/go-sql-driver/mysql/statement_test.go b/vendor/github.com/go-sql-driver/mysql/statement_test.go new file mode 100644 index 00000000..4b9914f8 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/statement_test.go @@ -0,0 +1,126 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "testing" +) + +func TestConvertDerivedString(t *testing.T) { + type derived string + + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Derived string type not convertible", err) + } + + if output != "value" { + t.Fatalf("Derived string type not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedByteSlice(t *testing.T) { + type derived []uint8 + + output, err := converter{}.ConvertValue(derived("value")) + if err != nil { + t.Fatal("Byte slice not convertible", err) + } + + if bytes.Compare(output.([]byte), []byte("value")) != 0 { + t.Fatalf("Byte slice not converted, got %#v %T", output, output) + } +} + +func TestConvertDerivedUnsupportedSlice(t *testing.T) { + type derived []int + + _, err := converter{}.ConvertValue(derived{1}) + if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" { + t.Fatal("Unexpected error", err) + } +} + +func TestConvertDerivedBool(t *testing.T) { + type derived bool + + output, err := converter{}.ConvertValue(derived(true)) + if err != nil { + t.Fatal("Derived bool type not convertible", err) + } + + if output != true { + t.Fatalf("Derived bool type not converted, got %#v %T", output, output) + } +} + +func TestConvertPointer(t *testing.T) { + str := "value" + + output, err := converter{}.ConvertValue(&str) + if err != nil { + t.Fatal("Pointer type not convertible", err) + } + + if output != "value" { + t.Fatalf("Pointer type not converted, got %#v %T", output, output) + } +} + +func TestConvertSignedIntegers(t *testing.T) { + values := []interface{}{ + int8(-42), + int16(-42), + int32(-42), + int64(-42), + int(-42), + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != int64(-42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } +} + +func TestConvertUnsignedIntegers(t *testing.T) { + values := []interface{}{ + uint8(42), + uint16(42), + uint32(42), + uint64(42), + uint(42), + } + + for _, value := range values { + output, err := converter{}.ConvertValue(value) + if err != nil { + t.Fatalf("%T type not convertible %s", value, err) + } + + if output != uint64(42) { + t.Fatalf("%T type not converted, got %#v %T", value, output, output) + } + } + + output, err := converter{}.ConvertValue(^uint64(0)) + if err != nil { + t.Fatal("uint64 high-bit not convertible", err) + } + + if output != ^uint64(0) { + t.Fatalf("uint64 high-bit converted, got %#v %T", output, output) + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/utils.go b/vendor/github.com/go-sql-driver/mysql/utils.go index 84d595b6..cfa10e9c 100644 --- a/vendor/github.com/go-sql-driver/mysql/utils.go +++ b/vendor/github.com/go-sql-driver/mysql/utils.go @@ -10,10 +10,13 @@ package mysql import ( "crypto/tls" + "database/sql" "database/sql/driver" "encoding/binary" + "errors" "fmt" "io" + "strconv" "strings" "sync" "sync/atomic" @@ -53,7 +56,7 @@ var ( // db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom") // func RegisterTLSConfig(key string, config *tls.Config) error { - if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" { + if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" { return fmt.Errorf("key '%s' is reserved", key) } @@ -79,7 +82,7 @@ func DeregisterTLSConfig(key string) { func getTLSConfigClone(key string) (config *tls.Config) { tlsConfigLock.RLock() if v, ok := tlsConfigRegistry[key]; ok { - config = cloneTLSConfig(v) + config = v.Clone() } tlsConfigLock.RUnlock() return @@ -227,87 +230,104 @@ var zeroDateTime = []byte("0000-00-00 00:00:00.000000") const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999" -func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) { +func appendMicrosecs(dst, src []byte, decimals int) []byte { + if decimals <= 0 { + return dst + } + if len(src) == 0 { + return append(dst, ".000000"[:decimals+1]...) + } + + microsecs := binary.LittleEndian.Uint32(src[:4]) + p1 := byte(microsecs / 10000) + microsecs -= 10000 * uint32(p1) + p2 := byte(microsecs / 100) + microsecs -= 100 * uint32(p2) + p3 := byte(microsecs) + + switch decimals { + default: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + digits10[p3], digits01[p3], + ) + case 1: + return append(dst, '.', + digits10[p1], + ) + case 2: + return append(dst, '.', + digits10[p1], digits01[p1], + ) + case 3: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], + ) + case 4: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + ) + case 5: + return append(dst, '.', + digits10[p1], digits01[p1], + digits10[p2], digits01[p2], + digits10[p3], + ) + } +} + +func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) { // length expects the deterministic length of the zero value, // negative time and 100+ hours are automatically added if needed if len(src) == 0 { - if justTime { - return zeroDateTime[11 : 11+length], nil - } return zeroDateTime[:length], nil } - var dst []byte // return value - var pt, p1, p2, p3 byte // current digit pair - var zOffs byte // offset of value in zeroDateTime - if justTime { - switch length { - case - 8, // time (can be up to 10 when negative and 100+ hours) - 10, 11, 12, 13, 14, 15: // time with fractional seconds - default: - return nil, fmt.Errorf("illegal TIME length %d", length) + var dst []byte // return value + var p1, p2, p3 byte // current digit pair + + switch length { + case 10, 19, 21, 22, 23, 24, 25, 26: + default: + t := "DATE" + if length > 10 { + t += "TIME" } - switch len(src) { - case 8, 12: - default: - return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) - } - // +2 to enable negative time and 100+ hours - dst = make([]byte, 0, length+2) - if src[0] == 1 { - dst = append(dst, '-') - } - if src[1] != 0 { - hour := uint16(src[1])*24 + uint16(src[5]) - pt = byte(hour / 100) - p1 = byte(hour - 100*uint16(pt)) - dst = append(dst, digits01[pt]) - } else { - p1 = src[5] - } - zOffs = 11 - src = src[6:] - } else { - switch length { - case 10, 19, 21, 22, 23, 24, 25, 26: - default: - t := "DATE" - if length > 10 { - t += "TIME" - } - return nil, fmt.Errorf("illegal %s length %d", t, length) - } - switch len(src) { - case 4, 7, 11: - default: - t := "DATE" - if length > 10 { - t += "TIME" - } - return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) - } - dst = make([]byte, 0, length) - // start with the date - year := binary.LittleEndian.Uint16(src[:2]) - pt = byte(year / 100) - p1 = byte(year - 100*uint16(pt)) - p2, p3 = src[2], src[3] - dst = append(dst, - digits10[pt], digits01[pt], - digits10[p1], digits01[p1], '-', - digits10[p2], digits01[p2], '-', - digits10[p3], digits01[p3], - ) - if length == 10 { - return dst, nil - } - if len(src) == 4 { - return append(dst, zeroDateTime[10:length]...), nil - } - dst = append(dst, ' ') - p1 = src[4] // hour - src = src[5:] + return nil, fmt.Errorf("illegal %s length %d", t, length) } + switch len(src) { + case 4, 7, 11: + default: + t := "DATE" + if length > 10 { + t += "TIME" + } + return nil, fmt.Errorf("illegal %s packet length %d", t, len(src)) + } + dst = make([]byte, 0, length) + // start with the date + year := binary.LittleEndian.Uint16(src[:2]) + pt := year / 100 + p1 = byte(year - 100*uint16(pt)) + p2, p3 = src[2], src[3] + dst = append(dst, + digits10[pt], digits01[pt], + digits10[p1], digits01[p1], '-', + digits10[p2], digits01[p2], '-', + digits10[p3], digits01[p3], + ) + if length == 10 { + return dst, nil + } + if len(src) == 4 { + return append(dst, zeroDateTime[10:length]...), nil + } + dst = append(dst, ' ') + p1 = src[4] // hour + src = src[5:] + // p1 is 2-digit hour, src is after hour p2, p3 = src[0], src[1] dst = append(dst, @@ -315,51 +335,49 @@ func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value digits10[p2], digits01[p2], ':', digits10[p3], digits01[p3], ) - if length <= byte(len(dst)) { - return dst, nil - } - src = src[2:] + return appendMicrosecs(dst, src[2:], int(length)-20), nil +} + +func formatBinaryTime(src []byte, length uint8) (driver.Value, error) { + // length expects the deterministic length of the zero value, + // negative time and 100+ hours are automatically added if needed if len(src) == 0 { - return append(dst, zeroDateTime[19:zOffs+length]...), nil + return zeroDateTime[11 : 11+length], nil } - microsecs := binary.LittleEndian.Uint32(src[:4]) - p1 = byte(microsecs / 10000) - microsecs -= 10000 * uint32(p1) - p2 = byte(microsecs / 100) - microsecs -= 100 * uint32(p2) - p3 = byte(microsecs) - switch decimals := zOffs + length - 20; decimals { + var dst []byte // return value + + switch length { + case + 8, // time (can be up to 10 when negative and 100+ hours) + 10, 11, 12, 13, 14, 15: // time with fractional seconds default: - return append(dst, '.', - digits10[p1], digits01[p1], - digits10[p2], digits01[p2], - digits10[p3], digits01[p3], - ), nil - case 1: - return append(dst, '.', - digits10[p1], - ), nil - case 2: - return append(dst, '.', - digits10[p1], digits01[p1], - ), nil - case 3: - return append(dst, '.', - digits10[p1], digits01[p1], - digits10[p2], - ), nil - case 4: - return append(dst, '.', - digits10[p1], digits01[p1], - digits10[p2], digits01[p2], - ), nil - case 5: - return append(dst, '.', - digits10[p1], digits01[p1], - digits10[p2], digits01[p2], - digits10[p3], - ), nil + return nil, fmt.Errorf("illegal TIME length %d", length) } + switch len(src) { + case 8, 12: + default: + return nil, fmt.Errorf("invalid TIME packet length %d", len(src)) + } + // +2 to enable negative time and 100+ hours + dst = make([]byte, 0, length+2) + if src[0] == 1 { + dst = append(dst, '-') + } + days := binary.LittleEndian.Uint32(src[1:5]) + hours := int64(days)*24 + int64(src[5]) + + if hours >= 100 { + dst = strconv.AppendInt(dst, hours, 10) + } else { + dst = append(dst, digits10[hours], digits01[hours]) + } + + min, sec := src[6], src[7] + dst = append(dst, ':', + digits10[min], digits01[min], ':', + digits10[sec], digits01[sec], + ) + return appendMicrosecs(dst, src[8:], int(length)-9), nil } /****************************************************************************** @@ -666,7 +684,7 @@ type atomicBool struct { value uint32 } -// IsSet returns wether the current boolean value is true +// IsSet returns whether the current boolean value is true func (ab *atomicBool) IsSet() bool { return atomic.LoadUint32(&ab.value) > 0 } @@ -680,7 +698,7 @@ func (ab *atomicBool) Set(value bool) { } } -// TrySet sets the value of the bool and returns wether the value changed +// TrySet sets the value of the bool and returns whether the value changed func (ab *atomicBool) TrySet(value bool) bool { if value { return atomic.SwapUint32(&ab.value, 1) == 0 @@ -708,3 +726,30 @@ func (ae *atomicError) Value() error { } return nil } + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + // TODO: support the use of Named Parameters #561 + return nil, errors.New("mysql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} + +func mapIsolationLevel(level driver.IsolationLevel) (string, error) { + switch sql.IsolationLevel(level) { + case sql.LevelRepeatableRead: + return "REPEATABLE READ", nil + case sql.LevelReadCommitted: + return "READ COMMITTED", nil + case sql.LevelReadUncommitted: + return "READ UNCOMMITTED", nil + case sql.LevelSerializable: + return "SERIALIZABLE", nil + default: + return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/utils_go17.go b/vendor/github.com/go-sql-driver/mysql/utils_go17.go deleted file mode 100644 index f5956345..00000000 --- a/vendor/github.com/go-sql-driver/mysql/utils_go17.go +++ /dev/null @@ -1,40 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// +build go1.7 -// +build !go1.8 - -package mysql - -import "crypto/tls" - -func cloneTLSConfig(c *tls.Config) *tls.Config { - return &tls.Config{ - Rand: c.Rand, - Time: c.Time, - Certificates: c.Certificates, - NameToCertificate: c.NameToCertificate, - GetCertificate: c.GetCertificate, - RootCAs: c.RootCAs, - NextProtos: c.NextProtos, - ServerName: c.ServerName, - ClientAuth: c.ClientAuth, - ClientCAs: c.ClientCAs, - InsecureSkipVerify: c.InsecureSkipVerify, - CipherSuites: c.CipherSuites, - PreferServerCipherSuites: c.PreferServerCipherSuites, - SessionTicketsDisabled: c.SessionTicketsDisabled, - SessionTicketKey: c.SessionTicketKey, - ClientSessionCache: c.ClientSessionCache, - MinVersion: c.MinVersion, - MaxVersion: c.MaxVersion, - CurvePreferences: c.CurvePreferences, - DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, - Renegotiation: c.Renegotiation, - } -} diff --git a/vendor/github.com/go-sql-driver/mysql/utils_go18.go b/vendor/github.com/go-sql-driver/mysql/utils_go18.go deleted file mode 100644 index c35c2a6a..00000000 --- a/vendor/github.com/go-sql-driver/mysql/utils_go18.go +++ /dev/null @@ -1,50 +0,0 @@ -// Go MySQL Driver - A MySQL-Driver for Go's database/sql package -// -// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. -// -// This Source Code Form is subject to the terms of the Mozilla Public -// License, v. 2.0. If a copy of the MPL was not distributed with this file, -// You can obtain one at http://mozilla.org/MPL/2.0/. - -// +build go1.8 - -package mysql - -import ( - "crypto/tls" - "database/sql" - "database/sql/driver" - "errors" - "fmt" -) - -func cloneTLSConfig(c *tls.Config) *tls.Config { - return c.Clone() -} - -func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { - dargs := make([]driver.Value, len(named)) - for n, param := range named { - if len(param.Name) > 0 { - // TODO: support the use of Named Parameters #561 - return nil, errors.New("mysql: driver does not support the use of Named Parameters") - } - dargs[n] = param.Value - } - return dargs, nil -} - -func mapIsolationLevel(level driver.IsolationLevel) (string, error) { - switch sql.IsolationLevel(level) { - case sql.LevelRepeatableRead: - return "REPEATABLE READ", nil - case sql.LevelReadCommitted: - return "READ COMMITTED", nil - case sql.LevelReadUncommitted: - return "READ UNCOMMITTED", nil - case sql.LevelSerializable: - return "SERIALIZABLE", nil - default: - return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) - } -} diff --git a/vendor/github.com/go-sql-driver/mysql/utils_test.go b/vendor/github.com/go-sql-driver/mysql/utils_test.go new file mode 100644 index 00000000..8951a7a8 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/utils_test.go @@ -0,0 +1,334 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/binary" + "testing" + "time" +) + +func TestScanNullTime(t *testing.T) { + var scanTests = []struct { + in interface{} + error bool + valid bool + time time.Time + }{ + {tDate, false, true, tDate}, + {sDate, false, true, tDate}, + {[]byte(sDate), false, true, tDate}, + {tDateTime, false, true, tDateTime}, + {sDateTime, false, true, tDateTime}, + {[]byte(sDateTime), false, true, tDateTime}, + {tDate0, false, true, tDate0}, + {sDate0, false, true, tDate0}, + {[]byte(sDate0), false, true, tDate0}, + {sDateTime0, false, true, tDate0}, + {[]byte(sDateTime0), false, true, tDate0}, + {"", true, false, tDate0}, + {"1234", true, false, tDate0}, + {0, true, false, tDate0}, + } + + var nt = NullTime{} + var err error + + for _, tst := range scanTests { + err = nt.Scan(tst.in) + if (err != nil) != tst.error { + t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil)) + } + if nt.Valid != tst.valid { + t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid) + } + if nt.Time != tst.time { + t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time) + } + } +} + +func TestLengthEncodedInteger(t *testing.T) { + var integerTests = []struct { + num uint64 + encoded []byte + }{ + {0x0000000000000000, []byte{0x00}}, + {0x0000000000000012, []byte{0x12}}, + {0x00000000000000fa, []byte{0xfa}}, + {0x0000000000000100, []byte{0xfc, 0x00, 0x01}}, + {0x0000000000001234, []byte{0xfc, 0x34, 0x12}}, + {0x000000000000ffff, []byte{0xfc, 0xff, 0xff}}, + {0x0000000000010000, []byte{0xfd, 0x00, 0x00, 0x01}}, + {0x0000000000123456, []byte{0xfd, 0x56, 0x34, 0x12}}, + {0x0000000000ffffff, []byte{0xfd, 0xff, 0xff, 0xff}}, + {0x0000000001000000, []byte{0xfe, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}}, + {0x123456789abcdef0, []byte{0xfe, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12}}, + {0xffffffffffffffff, []byte{0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}}, + } + + for _, tst := range integerTests { + num, isNull, numLen := readLengthEncodedInteger(tst.encoded) + if isNull { + t.Errorf("%x: expected %d, got NULL", tst.encoded, tst.num) + } + if num != tst.num { + t.Errorf("%x: expected %d, got %d", tst.encoded, tst.num, num) + } + if numLen != len(tst.encoded) { + t.Errorf("%x: expected size %d, got %d", tst.encoded, len(tst.encoded), numLen) + } + encoded := appendLengthEncodedInteger(nil, num) + if !bytes.Equal(encoded, tst.encoded) { + t.Errorf("%v: expected %x, got %x", num, tst.encoded, encoded) + } + } +} + +func TestFormatBinaryDateTime(t *testing.T) { + rawDate := [11]byte{} + binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years + rawDate[2] = 12 // months + rawDate[3] = 30 // days + rawDate[4] = 15 // hours + rawDate[5] = 46 // minutes + rawDate[6] = 23 // seconds + binary.LittleEndian.PutUint32(rawDate[7:], 987654) // microseconds + expect := func(expected string, inlen, outlen uint8) { + actual, _ := formatBinaryDateTime(rawDate[:inlen], outlen) + bytes, ok := actual.([]byte) + if !ok { + t.Errorf("formatBinaryDateTime must return []byte, was %T", actual) + } + if string(bytes) != expected { + t.Errorf( + "expected %q, got %q for length in %d, out %d", + expected, actual, inlen, outlen, + ) + } + } + expect("0000-00-00", 0, 10) + expect("0000-00-00 00:00:00", 0, 19) + expect("1978-12-30", 4, 10) + expect("1978-12-30 15:46:23", 7, 19) + expect("1978-12-30 15:46:23.987654", 11, 26) +} + +func TestFormatBinaryTime(t *testing.T) { + expect := func(expected string, src []byte, outlen uint8) { + actual, _ := formatBinaryTime(src, outlen) + bytes, ok := actual.([]byte) + if !ok { + t.Errorf("formatBinaryDateTime must return []byte, was %T", actual) + } + if string(bytes) != expected { + t.Errorf( + "expected %q, got %q for src=%q and outlen=%d", + expected, actual, src, outlen) + } + } + + // binary format: + // sign (0: positive, 1: negative), days(4), hours, minutes, seconds, micro(4) + + // Zeros + expect("00:00:00", []byte{}, 8) + expect("00:00:00.0", []byte{}, 10) + expect("00:00:00.000000", []byte{}, 15) + + // Without micro(4) + expect("12:34:56", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 8) + expect("-12:34:56", []byte{1, 0, 0, 0, 0, 12, 34, 56}, 8) + expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 11) + expect("24:34:56", []byte{0, 1, 0, 0, 0, 0, 34, 56}, 8) + expect("-99:34:56", []byte{1, 4, 0, 0, 0, 3, 34, 56}, 8) + expect("103079215103:34:56", []byte{0, 255, 255, 255, 255, 23, 34, 56}, 8) + + // With micro(4) + expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 11) + expect("12:34:56.000099", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 15) +} + +func TestEscapeBackslash(t *testing.T) { + expect := func(expected, value string) { + actual := string(escapeBytesBackslash([]byte{}, []byte(value))) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + + actual = string(escapeStringBackslash([]byte{}, value)) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + } + + expect("foo\\0bar", "foo\x00bar") + expect("foo\\nbar", "foo\nbar") + expect("foo\\rbar", "foo\rbar") + expect("foo\\Zbar", "foo\x1abar") + expect("foo\\\"bar", "foo\"bar") + expect("foo\\\\bar", "foo\\bar") + expect("foo\\'bar", "foo'bar") +} + +func TestEscapeQuotes(t *testing.T) { + expect := func(expected, value string) { + actual := string(escapeBytesQuotes([]byte{}, []byte(value))) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + + actual = string(escapeStringQuotes([]byte{}, value)) + if actual != expected { + t.Errorf( + "expected %s, got %s", + expected, actual, + ) + } + } + + expect("foo\x00bar", "foo\x00bar") // not affected + expect("foo\nbar", "foo\nbar") // not affected + expect("foo\rbar", "foo\rbar") // not affected + expect("foo\x1abar", "foo\x1abar") // not affected + expect("foo''bar", "foo'bar") // affected + expect("foo\"bar", "foo\"bar") // not affected +} + +func TestAtomicBool(t *testing.T) { + var ab atomicBool + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + + ab.Set(true) + if ab.value != 1 { + t.Fatal("Set(true) did not set value to 1") + } + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + + ab.Set(true) + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + + ab.Set(false) + if ab.value != 0 { + t.Fatal("Set(false) did not set value to 0") + } + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + + ab.Set(false) + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + if ab.TrySet(false) { + t.Fatal("Expected TrySet(false) to fail") + } + if !ab.TrySet(true) { + t.Fatal("Expected TrySet(true) to succeed") + } + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + + ab.Set(true) + if !ab.IsSet() { + t.Fatal("Expected value to be true") + } + if ab.TrySet(true) { + t.Fatal("Expected TrySet(true) to fail") + } + if !ab.TrySet(false) { + t.Fatal("Expected TrySet(false) to succeed") + } + if ab.IsSet() { + t.Fatal("Expected value to be false") + } + + ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯ +} + +func TestAtomicError(t *testing.T) { + var ae atomicError + if ae.Value() != nil { + t.Fatal("Expected value to be nil") + } + + ae.Set(ErrMalformPkt) + if v := ae.Value(); v != ErrMalformPkt { + if v == nil { + t.Fatal("Value is still nil") + } + t.Fatal("Error did not match") + } + ae.Set(ErrPktSync) + if ae.Value() == ErrMalformPkt { + t.Fatal("Error still matches old error") + } + if v := ae.Value(); v != ErrPktSync { + t.Fatal("Error did not match") + } +} + +func TestIsolationLevelMapping(t *testing.T) { + data := []struct { + level driver.IsolationLevel + expected string + }{ + { + level: driver.IsolationLevel(sql.LevelReadCommitted), + expected: "READ COMMITTED", + }, + { + level: driver.IsolationLevel(sql.LevelRepeatableRead), + expected: "REPEATABLE READ", + }, + { + level: driver.IsolationLevel(sql.LevelReadUncommitted), + expected: "READ UNCOMMITTED", + }, + { + level: driver.IsolationLevel(sql.LevelSerializable), + expected: "SERIALIZABLE", + }, + } + + for i, td := range data { + if actual, err := mapIsolationLevel(td.level); actual != td.expected || err != nil { + t.Fatal(i, td.expected, actual, err) + } + } + + // check unsupported mapping + expectedErr := "mysql: unsupported isolation level: 7" + actual, err := mapIsolationLevel(driver.IsolationLevel(sql.LevelLinearizable)) + if actual != "" || err == nil { + t.Fatal("Expected error on unsupported isolation level") + } + if err.Error() != expectedErr { + t.Fatalf("Expected error to be %q, got %q", expectedErr, err) + } +}