1
0
Fork 0
mirror of https://github.com/documize/community.git synced 2025-07-18 20:59:43 +02:00

auth with cas

This commit is contained in:
Derek Chen 2019-08-09 13:44:03 +08:00
parent 8c99977fc9
commit 8c2df6178d
150 changed files with 43682 additions and 24175 deletions

25
Gopkg.lock generated
View file

@ -100,6 +100,14 @@
revision = "d523deb1b23d913de5bdada721a6071e71283618"
version = "v1.4.0"
[[projects]]
branch = "master"
digest = "1:1ba1d79f2810270045c328ae5d674321db34e3aae468eb4233883b473c5c0467"
name = "github.com/golang/glog"
packages = ["."]
pruneopts = "UT"
revision = "23def4e6c14b4da8ac2ed8007337bc5eb5007998"
[[projects]]
digest = "1:ffc060c551980d37ee9e428ef528ee2813137249ccebb0bfc412ef83071cac91"
name = "github.com/golang/protobuf"
@ -298,6 +306,14 @@
revision = "379148ca0225df7a432012b8df0355c2a2063ac0"
version = "v1.2"
[[projects]]
digest = "1:d095b21d330637ad0e1025231ef91023f64c65dda093a437001fe8becfb77099"
name = "gopkg.in/cas.v2"
packages = ["."]
pruneopts = "UT"
revision = "1b87d011d1fc0430cdbdfe3115c9843ec33d9da6"
version = "v2.1.0"
[[projects]]
digest = "1:e9a0fa7c2dfc90e0fae16be5825ad98074d8704f5fcebfdc289a8e8fb0f8e4b5"
name = "gopkg.in/ldap.v3"
@ -306,6 +322,14 @@
revision = "9f0d712775a0973b7824a1585a86a4ea1d5263d9"
version = "v3.0.3"
[[projects]]
digest = "1:4d2e5a73dc1500038e504a8d78b986630e3626dc027bc030ba5c75da257cdb96"
name = "gopkg.in/yaml.v2"
packages = ["."]
pruneopts = "UT"
revision = "51d6538a90f86fe93ac480b35f37b2be17fef232"
version = "v2.2.2"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
@ -334,6 +358,7 @@
"golang.org/x/oauth2",
"gopkg.in/alexcesaro/quotedprintable.v3",
"gopkg.in/andygrunwald/go-jira.v1",
"gopkg.in/cas.v2",
"gopkg.in/ldap.v3",
]
solver-name = "gps-cdcl"

165
domain/auth/cas/endpoint.go Normal file
View file

@ -0,0 +1,165 @@
package cas
import (
"database/sql"
"encoding/json"
"fmt"
"github.com/documize/community/core/env"
"github.com/documize/community/core/response"
"github.com/documize/community/core/secrets"
"github.com/documize/community/core/streamutil"
"github.com/documize/community/core/stringutil"
"github.com/documize/community/domain"
"github.com/documize/community/domain/auth"
"github.com/documize/community/domain/store"
usr "github.com/documize/community/domain/user"
ath "github.com/documize/community/model/auth"
"github.com/documize/community/model/user"
casv2 "gopkg.in/cas.v2"
"io/ioutil"
"net/http"
"net/url"
"strings"
)
// Handler contains the runtime information such as logging and database.
type Handler struct {
Runtime *env.Runtime
Store *store.Store
}
// Authenticate checks CAS authentication credentials.
func (h *Handler) Authenticate(w http.ResponseWriter, r *http.Request) {
method := "authenticate"
ctx := domain.GetRequestContext(r)
defer streamutil.Close(r.Body)
body, err := ioutil.ReadAll(r.Body)
if err != nil {
response.WriteBadRequestError(w, method, "Bad payload")
h.Runtime.Log.Error(method, err)
return
}
a := ath.CASAuthRequest{}
err = json.Unmarshal(body, &a)
if err != nil {
response.WriteBadRequestError(w, method, err.Error())
h.Runtime.Log.Error(method, err)
return
}
a.Ticket = strings.TrimSpace(a.Ticket)
org, err := h.Store.Organization.GetOrganizationByDomain("")
if err != nil {
response.WriteUnauthorizedError(w)
h.Runtime.Log.Error(method, err)
return
}
ctx.OrgID = org.RefID
// Fetch CAS auth provider config
ac := ath.CASConfig{}
err = json.Unmarshal([]byte(org.AuthConfig), &ac)
if err != nil {
response.WriteBadRequestError(w, method, "Unable to unmarshall Keycloak Public Key")
h.Runtime.Log.Error(method, err)
return
}
service := url.QueryEscape(ac.RedirectURL)
validateUrl := ac.URL + "/serviceValidate?ticket=" + a.Ticket + "&service="+ service;
resp, err := http.Get(validateUrl)
if err != nil {
response.WriteBadRequestError(w, method, "Unable to get service validate url")
h.Runtime.Log.Error(method, err)
return
}
defer streamutil.Close(resp.Body)
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
response.WriteBadRequestError(w, method, "Verity CAS ticket error")
h.Runtime.Log.Error(method, err)
return
}
userInfo, err := casv2.ParseServiceResponse(data)
if err != nil {
response.WriteBadRequestError(w, method, "can't parse user info")
h.Runtime.Log.Error(method, err)
return
}
h.Runtime.Log.Info("cas logon attempt " + userInfo.User)
u, err := h.Store.User.GetByDomain(ctx, a.Domain, userInfo.User)
if err != nil && err != sql.ErrNoRows {
response.WriteServerError(w, method, err)
h.Runtime.Log.Error(method, err)
return
}
// Create user account if not found
if err == sql.ErrNoRows {
h.Runtime.Log.Info("cas add user " + userInfo.User + " @ " + a.Domain)
u = user.User{}
u.Active = true
u.ViewUsers = false
u.Analytics = false
u.Admin = false
u.GlobalAdmin = false
u.Email = userInfo.User
u.Initials = stringutil.MakeInitials(userInfo.User, "")
u.Salt = secrets.GenerateSalt()
u.Password = secrets.GeneratePassword(secrets.GenerateRandomPassword(), u.Salt)
u, err = auth.AddExternalUser(ctx, h.Runtime, h.Store, u, true)
if err != nil {
response.WriteServerError(w, method, err)
h.Runtime.Log.Error(method, err)
return
}
}
// Password correct and active user
if userInfo.User != strings.TrimSpace(strings.ToLower(u.Email)) {
response.WriteUnauthorizedError(w)
return
}
// Attach user accounts and work out permissions.
usr.AttachUserAccounts(ctx, *h.Store, org.RefID, &u)
// No accounts signals data integrity problem
// so we reject login request.
if len(u.Accounts) == 0 {
response.WriteUnauthorizedError(w)
err = fmt.Errorf("no user accounts found for %s", u.Email)
h.Runtime.Log.Error(method, err)
return
}
// Abort login request if account is disabled.
for _, ac := range u.Accounts {
if ac.OrgID == org.RefID {
if ac.Active == false {
response.WriteUnauthorizedError(w)
err = fmt.Errorf("no ACTIVE user account found for %s", u.Email)
h.Runtime.Log.Error(method, err)
return
}
break
}
}
// Generate JWT token
authModel := ath.AuthenticationModel{}
authModel.Token = auth.GenerateJWT(h.Runtime, u.RefID, org.RefID, a.Domain)
authModel.User = u
response.WriteJSON(w, authModel)
return
}

File diff suppressed because one or more lines are too long

View file

@ -9,10 +9,13 @@
//
// https://documize.com
import { isPresent } from '@ember/utils';
import { reject, resolve } from 'rsvp';
import { inject as service } from '@ember/service';
import Base from 'ember-simple-auth/authenticators/base';
import netUtil from "../utils/net";
export default Base.extend({
ajax: service(),
@ -28,8 +31,16 @@ export default Base.extend({
return reject();
},
authenticate(){
return this.get('ajax').request('public/authenticate/cas' );
authenticate(data){
data.domain = netUtil.getSubdomain();
if (!isPresent(data.ticket)) {
return reject("data.ticket is empty");
}
return this.get('ajax').post('public/authenticate/cas', {
data: JSON.stringify(data),
contentType: 'json'
});
},
invalidate() {

View file

@ -0,0 +1,14 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
import Controller from '@ember/controller';
export default Controller.extend({});

View file

@ -0,0 +1,66 @@
// Copyright 2016 Documize Inc. <legal@documize.com>. All rights reserved.
//
// This software (Documize Community Edition) is licensed under
// GNU AGPL v3 http://www.gnu.org/licenses/agpl-3.0.en.html
//
// You can operate outside the AGPL restrictions by purchasing
// Documize Enterprise Edition and obtaining a commercial license
// by contacting <sales@documize.com>.
//
// https://documize.com
import { Promise as EmberPromise } from 'rsvp';
import { inject as service } from '@ember/service';
import Route from '@ember/routing/route';
export default Route.extend({
ajax: service(),
session: service(),
appMeta: service(),
localStorage: service(),
queryParams: {
mode: {
refreshModel: true
},
ticket : {
refreshModel: true,
replace : true,
}
},
message: '',
afterModel(model) {
return new EmberPromise((resolve) => {
let constants = this.get('constants');
if (this.get('appMeta.authProvider') !== constants.AuthProvider.CAS) {
resolve();
}
let ticket = model.ticket;
if (ticket === '') {
resolve();
}
let data = {ticket: ticket};
this.get("session").authenticate('authenticator:cas', data).then(() => {
this.transitionTo('folders');
}, (reject) => {
if (!_.isUndefined(reject.Error)) {
this.set('message', reject.Error);
} else {
this.set('message', reject);
}
this.set('mode', 'reject');
resolve();
});
})
},
model(params) {
return {
mode: this.get('mode'),
message: this.get('message'),
ticket: params.ticket
}
}
});

View file

@ -0,0 +1,13 @@
{{#if (is-equal model.mode "login")}}
<div class="sso-box">
<p>Authenticating with CAS...</p>
<img src="/assets/img/busy-gray.gif">
</div>
{{/if}}
{{#if (is-equal model.mode "reject")}}
<div class="sso-box">
<p>CAS authentication failure</p>
<p>{{model.message}}</p>
</div>
{{/if}}

View file

@ -12,7 +12,6 @@
import { inject as service } from '@ember/service';
import AuthProvider from '../../../mixins/auth';
import Controller from '@ember/controller';
import {Promise as EmberPromise} from "rsvp";
export default Controller.extend(AuthProvider, {
appMeta: service('app-meta'),
@ -73,14 +72,6 @@ export default Controller.extend(AuthProvider, {
// this.set('invalidCredentials', true);
// });
// }
},
loginWithCAS(){
// let config = this.get('config');
let url = 'https://sso.bangdao-tech.com/sso/login?service=' + encodeURIComponent('https://duty.bangdao-tech.com/');
window.location.replace(url);
}
}

View file

@ -41,6 +41,13 @@ export default Route.extend({
});
break;
case constants.AuthProvider.CAS: {
let config = JSON.parse(this.get('appMeta.authConfig'));
let url = config.url + '/login?service=' + encodeURIComponent(config.redirectUrl);
window.location.replace(url);
resolve();
break;
}
default:
this.set('showLogin', true);

View file

@ -27,11 +27,7 @@
{{input type="password" value=password id="authPassword" class="form-control" autocomplete="current-password"}}
{{/if}}
</div>
{{#if isAuthProviderCAS}}
{{ui/ui-button color=constants.Color.Green light=true label=constants.Label.SignIn onClick=(action "loginWithCAS")}}
{{else}}
{{ui/ui-button color=constants.Color.Green light=true label=constants.Label.SignIn onClick=(action "login")}}
{{/if}}
{{ui/ui-button color=constants.Color.Green light=true label=constants.Label.SignIn onClick=(action "login")}}
<div class="{{unless invalidCredentials "invisible"}} color-red-600 mt-3">Invalid credentials</div>

View file

@ -44,6 +44,7 @@ export default Route.extend(AuthenticatedRouteMixin, {
break;
case constants.AuthProvider.CAS:
data.authConfig = config;
break;
case constants.AuthProvider.Documize:
data.authConfig = '';
break;

View file

@ -160,6 +160,9 @@ export default Router.map(function () {
this.route('share', {
path: 'share/:id/:slug/:serial'
});
this.route('cas', {
path: 'cas'
});
}
);

28267
gui/package-lock.json generated Normal file

File diff suppressed because it is too large Load diff

15
model/auth/cas.go Normal file
View file

@ -0,0 +1,15 @@
package auth
// CASAuthRequest data received via Keycloak client library
type CASAuthRequest struct {
Ticket string `json:"ticket"`
Domain string `json:"domain"`
}
// CASConfig server configuration
type CASConfig struct {
URL string `json:"url"`
RedirectURL string `json"redirectUrl"`
}

View file

@ -17,6 +17,7 @@ import (
"github.com/documize/community/core/env"
"github.com/documize/community/domain/attachment"
"github.com/documize/community/domain/auth"
"github.com/documize/community/domain/auth/cas"
"github.com/documize/community/domain/auth/keycloak"
"github.com/documize/community/domain/auth/ldap"
"github.com/documize/community/domain/backup"
@ -66,6 +67,7 @@ func RegisterEndpoints(rt *env.Runtime, s *store.Store) {
setting := setting.Handler{Runtime: rt, Store: s}
category := category.Handler{Runtime: rt, Store: s}
keycloak := keycloak.Handler{Runtime: rt, Store: s}
cas := cas.Handler{Runtime:rt, Store: s}
template := template.Handler{Runtime: rt, Store: s, Indexer: indexer}
document := document.Handler{Runtime: rt, Store: s, Indexer: indexer}
attachment := attachment.Handler{Runtime: rt, Store: s, Indexer: indexer}
@ -93,6 +95,7 @@ func RegisterEndpoints(rt *env.Runtime, s *store.Store) {
AddPublic(rt, "authenticate/keycloak", []string{"POST", "OPTIONS"}, nil, keycloak.Authenticate)
AddPublic(rt, "authenticate/ldap", []string{"POST", "OPTIONS"}, nil, ldap.Authenticate)
AddPublic(rt, "authenticate/cas", []string{"POST", "OPTIONS"}, nil, cas.Authenticate)
AddPublic(rt, "authenticate", []string{"POST", "OPTIONS"}, nil, auth.Login)
AddPublic(rt, "validate", []string{"GET", "OPTIONS"}, nil, auth.ValidateToken)
AddPublic(rt, "forgot", []string{"POST", "OPTIONS"}, nil, user.ForgotPassword)

View file

@ -117,27 +117,6 @@ _, err := db.ExecContext(ctx, "sp_RunMe",
)
```
## Reading Output Parameters from a Stored Procedure with Resultset
To read output parameters from a stored procedure with resultset, make sure you read all the rows before reading the output parameters:
```go
sqltextcreate := `
CREATE PROCEDURE spwithoutputandrows
@bitparam BIT OUTPUT
AS BEGIN
SET @bitparam = 1
SELECT 'Row 1'
END
`
var bitout int64
rows, err := db.QueryContext(ctx, "spwithoutputandrows", sql.Named("bitparam", sql.Out{Dest: &bitout}))
var strrow string
for rows.Next() {
err = rows.Scan(&strrow)
}
fmt.Printf("bitparam is %d", bitout)
```
## Caveat for local temporary tables
Due to protocol limitations, temporary tables will only be allocated on the connection
@ -210,7 +189,7 @@ are supported:
* "cloud.google.com/go/civil".Date -> date
* "cloud.google.com/go/civil".DateTime -> datetime2
* "cloud.google.com/go/civil".Time -> time
* mssql.TVP -> Table Value Parameter (TDS version dependent)
* mssql.TVPType -> Table Value Parameter (TDS version dependent)
## Important Notes

View file

@ -1,287 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// bach splits a single script containing multiple batches separated by
// a keyword into multiple scripts.
package batch
import (
"bytes"
"fmt"
"strconv"
"strings"
"unicode"
)
// Split the provided SQL into multiple sql scripts based on a given
// separator, often "GO". It also allows escaping newlines with a
// backslash.
func Split(sql, separator string) []string {
if len(separator) == 0 || len(sql) < len(separator) {
return []string{sql}
}
l := &lexer{
Sql: sql,
Sep: separator,
At: 0,
}
state := stateWhitespace
for state != nil {
state = state(l)
}
l.AddCurrent(1)
return l.Batch
}
const debugPrintStateName = false
func printStateName(name string, l *lexer) {
if debugPrintStateName {
fmt.Printf("state %s At=%d\n", name, l.At)
}
}
func hasPrefixFold(s, sep string) bool {
if len(s) < len(sep) {
return false
}
return strings.EqualFold(s[:len(sep)], sep)
}
type lexer struct {
Sql string
Sep string
At int
Start int
Skip []int
Batch []string
}
func (l *lexer) Add(b string) {
if len(b) == 0 {
return
}
l.Batch = append(l.Batch, b)
}
func (l *lexer) Next() bool {
l.At++
return l.At < len(l.Sql)
}
func (l *lexer) AddCurrent(count int64) bool {
if count < 0 {
count = 0
}
if l.At >= len(l.Sql) {
l.At = len(l.Sql)
}
text := l.Sql[l.Start:l.At]
if len(l.Skip) > 0 {
buf := &bytes.Buffer{}
nextSkipIndex := 0
nextSkip := l.Skip[nextSkipIndex]
for i, r := range text {
if i == nextSkip {
nextSkipIndex++
if nextSkipIndex < len(l.Skip) {
nextSkip = l.Skip[nextSkipIndex]
}
continue
}
buf.WriteRune(r)
}
text = buf.String()
l.Skip = nil
}
// Limit the number of counts for sanity.
if count > 1000 {
count = 1000
}
for i := int64(0); i < count; i++ {
l.Add(text)
}
l.At += len(l.Sep)
l.Start = l.At
return (l.At < len(l.Sql))
}
type stateFn func(*lexer) stateFn
const (
lineComment = "--"
leftComment = "/*"
rightComment = "*/"
)
func stateSep(l *lexer) stateFn {
printStateName("sep", l)
if l.At+len(l.Sep) >= len(l.Sql) {
return nil
}
s := l.Sql[l.At+len(l.Sep):]
parseNumberStart := -1
loop:
for i, r := range s {
switch {
case r == '\n', r == '\r':
l.AddCurrent(1)
return stateWhitespace
case unicode.IsSpace(r):
case unicode.IsNumber(r):
parseNumberStart = i
break loop
}
}
if parseNumberStart < 0 {
return nil
}
parseNumberCount := 0
numLoop:
for i, r := range s[parseNumberStart:] {
switch {
case unicode.IsNumber(r):
parseNumberCount = i
default:
break numLoop
}
}
parseNumberEnd := parseNumberStart + parseNumberCount + 1
count, err := strconv.ParseInt(s[parseNumberStart:parseNumberEnd], 10, 64)
if err != nil {
return stateText
}
for _, r := range s[parseNumberEnd:] {
switch {
case r == '\n', r == '\r':
l.AddCurrent(count)
l.At += parseNumberEnd
l.Start = l.At
return stateWhitespace
case unicode.IsSpace(r):
default:
return stateText
}
}
return nil
}
func stateText(l *lexer) stateFn {
printStateName("text", l)
for {
ch := l.Sql[l.At]
switch {
case strings.HasPrefix(l.Sql[l.At:], lineComment):
l.At += len(lineComment)
return stateLineComment
case strings.HasPrefix(l.Sql[l.At:], leftComment):
l.At += len(leftComment)
return stateMultiComment
case ch == '\'':
l.At += 1
return stateString
case ch == '\r', ch == '\n':
l.At += 1
return stateWhitespace
default:
if l.Next() == false {
return nil
}
}
}
}
func stateWhitespace(l *lexer) stateFn {
printStateName("whitespace", l)
if l.At >= len(l.Sql) {
return nil
}
ch := l.Sql[l.At]
switch {
case unicode.IsSpace(rune(ch)):
l.At += 1
return stateWhitespace
case hasPrefixFold(l.Sql[l.At:], l.Sep):
return stateSep
default:
return stateText
}
}
func stateLineComment(l *lexer) stateFn {
printStateName("line-comment", l)
for {
if l.At >= len(l.Sql) {
return nil
}
ch := l.Sql[l.At]
switch {
case ch == '\r', ch == '\n':
l.At += 1
return stateWhitespace
default:
if l.Next() == false {
return nil
}
}
}
}
func stateMultiComment(l *lexer) stateFn {
printStateName("multi-line-comment", l)
for {
switch {
case strings.HasPrefix(l.Sql[l.At:], rightComment):
l.At += len(leftComment)
return stateWhitespace
default:
if l.Next() == false {
return nil
}
}
}
}
func stateString(l *lexer) stateFn {
printStateName("string", l)
for {
if l.At >= len(l.Sql) {
return nil
}
ch := l.Sql[l.At]
chNext := rune(-1)
if l.At+1 < len(l.Sql) {
chNext = rune(l.Sql[l.At+1])
}
switch {
case ch == '\\' && (chNext == '\r' || chNext == '\n'):
next := 2
l.Skip = append(l.Skip, l.At, l.At+1)
if chNext == '\r' && l.At+2 < len(l.Sql) && l.Sql[l.At+2] == '\n' {
l.Skip = append(l.Skip, l.At+2)
next = 3
}
l.At += next
case ch == '\'' && chNext == '\'':
l.At += 2
case ch == '\'' && chNext != '\'':
l.At += 1
return stateWhitespace
default:
if l.Next() == false {
return nil
}
}
}
}

View file

@ -1,12 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build gofuzz
package batch
func Fuzz(data []byte) int {
Split(string(data), "GO")
return 0
}

View file

@ -1,120 +0,0 @@
// Copyright 2017 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package batch
import (
"fmt"
"testing"
)
func TestBatchSplit(t *testing.T) {
type testItem struct {
Sql string
Expect []string
}
list := []testItem{
testItem{
Sql: `use DB
go
select 1
go
select 2
`,
Expect: []string{`use DB
`, `
select 1
`, `
select 2
`,
},
},
testItem{
Sql: `go
use DB go
`,
Expect: []string{`
use DB go
`,
},
},
testItem{
Sql: `select 'It''s go time'
go
select top 1 1`,
Expect: []string{`select 'It''s go time'
`, `
select top 1 1`,
},
},
testItem{
Sql: `select 1 /* go */
go
select top 1 1`,
Expect: []string{`select 1 /* go */
`, `
select top 1 1`,
},
},
testItem{
Sql: `select 1 -- go
go
select top 1 1`,
Expect: []string{`select 1 -- go
`, `
select top 1 1`,
},
},
testItem{Sql: `"0'"`, Expect: []string{`"0'"`}},
testItem{Sql: "0'", Expect: []string{"0'"}},
testItem{Sql: "--", Expect: []string{"--"}},
testItem{Sql: "GO", Expect: nil},
testItem{Sql: "/*", Expect: []string{"/*"}},
testItem{Sql: "gO\x01\x00O550655490663051008\n", Expect: []string{"\n"}},
testItem{Sql: "select 1;\nGO 2\nselect 2;", Expect: []string{"select 1;\n", "select 1;\n", "\nselect 2;"}},
testItem{Sql: "select 'hi\\\n-hello';", Expect: []string{"select 'hi-hello';"}},
testItem{Sql: "select 'hi\\\r\n-hello';", Expect: []string{"select 'hi-hello';"}},
testItem{Sql: "select 'hi\\\r-hello';", Expect: []string{"select 'hi-hello';"}},
testItem{Sql: "select 'hi\\\n\nhello';", Expect: []string{"select 'hi\nhello';"}},
}
index := -1
for i := range list {
if index >= 0 && index != i {
continue
}
sqltext := list[i].Sql
t.Run(fmt.Sprintf("index-%d", i), func(t *testing.T) {
ss := Split(sqltext, "go")
if len(ss) != len(list[i].Expect) {
t.Errorf("Test Item index %d; expect %d items, got %d %q", i, len(list[i].Expect), len(ss), ss)
return
}
for j := 0; j < len(ss); j++ {
if ss[j] != list[i].Expect[j] {
t.Errorf("Test Item index %d, batch index %d; expect <%s>, got <%s>", i, j, list[i].Expect[j], ss[j])
}
}
})
}
}
func TestHasPrefixFold(t *testing.T) {
list := []struct {
s, pre string
is bool
}{
{"h", "H", true},
{"h", "K", false},
{"go 5\n", "go", true},
}
for _, item := range list {
is := hasPrefixFold(item.s, item.pre)
if is != item.is {
t.Errorf("want (%q, %q)=%t got %t", item.s, item.pre, item.is, is)
}
}
}

View file

@ -1,284 +0,0 @@
package mssql
import (
"bytes"
"errors"
"testing"
)
type closableBuffer struct {
*bytes.Buffer
}
func (closableBuffer) Close() error {
return nil
}
type failBuffer struct {
}
func (failBuffer) Read([]byte) (int, error) {
return 0, errors.New("read failed")
}
func (failBuffer) Write([]byte) (int, error) {
return 0, errors.New("write failed")
}
func (failBuffer) Close() error {
return nil
}
func makeBuf(bufSize uint16, testData []byte) *tdsBuffer {
buffer := closableBuffer{bytes.NewBuffer(testData)}
return newTdsBuffer(bufSize, &buffer)
}
func TestStreamShorterThanHeader(t *testing.T) {
//buffer := closableBuffer{*bytes.NewBuffer([]byte{0xFF, 0xFF})}
//buffer := closableBuffer{*bytes.NewBuffer([]byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF})}
//tdsBuffer := newTdsBuffer(100, &buffer)
buffer := makeBuf(100, []byte{0xFF, 0xFF})
_, err := buffer.BeginRead()
if err == nil {
t.Fatal("BeginRead was expected to return error but it didn't")
} else {
t.Log("BeginRead failed as expected with error:", err.Error())
}
}
func TestInvalidLengthInHeaderTooLong(t *testing.T) {
buffer := makeBuf(8, []byte{0xFF, 0xFF, 0x0, 0x9, 0xff, 0xff, 0xff, 0xff})
_, err := buffer.BeginRead()
if err == nil {
t.Fatal("BeginRead was expected to return error but it didn't")
} else {
if err.Error() != "Invalid packet size, it is longer than buffer size" {
t.Fatal("BeginRead failed with incorrect error", err)
} else {
t.Log("BeginRead failed as expected with error:", err.Error())
}
}
}
func TestInvalidLengthInHeaderTooShort(t *testing.T) {
buffer := makeBuf(100, []byte{0xFF, 0xFF, 0x0, 0x1, 0xff, 0xff, 0xff, 0xff})
_, err := buffer.BeginRead()
if err == nil {
t.Fatal("BeginRead was expected to return error but it didn't")
} else {
t.Log("BeginRead failed as expected with error:", err.Error())
}
}
func TestInvalidLengthInHeaderLongerThanIncomingBuffer(t *testing.T) {
buffer := makeBuf(9, []byte{0xFF, 0xFF, 0x0, 0x9, 0xff, 0xff, 0xff, 0xff})
_, err := buffer.BeginRead()
if err == nil {
t.Fatal("BeginRead was expected to return error but it didn't")
} else {
t.Log("BeginRead failed as expected with error:", err.Error())
}
}
func TestBeginReadSucceeds(t *testing.T) {
buffer := makeBuf(9, []byte{0x01 /*id*/, 0xFF /*status*/, 0x0, 0x9 /*size*/, 0xff, 0xff, 0xff, 0xff, 0x02 /*test byte*/})
id, err := buffer.BeginRead()
if err != nil {
t.Fatal("BeginRead failed:", err.Error())
}
if id != 1 {
t.Fatalf("Expected id to be 1 but it is %d", id)
}
b, err := buffer.ReadByte()
if err != nil {
t.Fatal("ReadByte failed:", err.Error())
}
if b != 2 {
t.Fatalf("Expected read byte to be 2 but it is %d", b)
}
// should fail because no more bytes left
_, err = buffer.ReadByte()
if err == nil {
t.Fatal("ReadByte was expected to return error but it didn't")
} else {
t.Log("ReadByte failed as expected with error:", err.Error())
}
testBuf := []byte{0, 1, 2}
// should fail because no more bytes left
_, err = buffer.Read(testBuf)
if err == nil {
t.Fatal("Read was expected to return error but it didn't")
} else {
t.Log("Read failed as expected with error:", err.Error())
}
}
func TestReadByteFailsOnSecondPacket(t *testing.T) {
buffer := makeBuf(9, []byte{
0x01 /*id*/, 0x0 /*not final*/, 0x0, 0x9 /*size*/, 0xff, 0xff, 0xff, 0xff, 0x02, /*test byte*/
0x01 /*next id, this packet is invalid, it is too short*/})
_, err := buffer.BeginRead()
if err != nil {
t.Fatal("BeginRead failed:", err.Error())
}
_, err = buffer.ReadByte()
if err != nil {
t.Fatal("ReadByte failed:", err.Error())
}
_, err = buffer.ReadByte()
if err == nil {
t.Fatal("ReadByte was expected to return error but it didn't")
} else {
t.Log("ReadByte failed as expected with error:", err.Error())
}
t.Run("test byte() panic", func(t *testing.T) {
defer func() {
recover()
}()
buffer.byte()
t.Fatal("byte() should panic, but it didn't")
})
t.Run("test ReadFull() panic", func(t *testing.T) {
defer func() {
recover()
}()
buf := make([]byte, 10)
buffer.ReadFull(buf)
t.Fatal("ReadFull() should panic, but it didn't")
})
}
func TestReadFailsOnSecondPacket(t *testing.T) {
buffer := makeBuf(9, []byte{
0x01 /*id*/, 0x0 /*not final*/, 0x0, 0x9 /*size*/, 0xff, 0xff, 0xff, 0xff, 0x02, /*test byte*/
0x01 /*next id, this packet is invalid, it is too short*/})
_, err := buffer.BeginRead()
if err != nil {
t.Fatal("BeginRead failed:", err.Error())
}
testBuf := []byte{0}
_, err = buffer.Read(testBuf)
if err != nil {
t.Fatal("Read failed:", err.Error())
}
if testBuf[0] != 2 {
t.Fatal("Read returned invalid value")
}
_, err = buffer.Read(testBuf)
if err == nil {
t.Fatal("ReadByte was expected to return error but it didn't")
} else {
t.Log("ReadByte failed as expected with error:", err.Error())
}
}
func TestWrite(t *testing.T) {
memBuf := bytes.NewBuffer([]byte{})
buf := newTdsBuffer(11, closableBuffer{memBuf})
buf.BeginPacket(1, false)
err := buf.WriteByte(2)
if err != nil {
t.Fatal("WriteByte failed:", err.Error())
}
wrote, err := buf.Write([]byte{3, 4})
if err != nil {
t.Fatal("Write failed:", err.Error())
}
if wrote != 2 {
t.Fatalf("Write returned invalid value of written bytes %d", wrote)
}
err = buf.FinishPacket()
if err != nil {
t.Fatal("FinishPacket failed:", err.Error())
}
if bytes.Compare(memBuf.Bytes(), []byte{1, 1, 0, 11, 0, 0, 1, 0, 2, 3, 4}) != 0 {
t.Fatalf("Written buffer has invalid content: %v", memBuf.Bytes())
}
buf.BeginPacket(2, false)
wrote, err = buf.Write([]byte{3, 4, 5, 6})
if err != nil {
t.Fatal("Write failed:", err.Error())
}
if wrote != 4 {
t.Fatalf("Write returned invalid value of written bytes %d", wrote)
}
err = buf.FinishPacket()
if err != nil {
t.Fatal("FinishPacket failed:", err.Error())
}
expectedBuf := []byte{
1, 1, 0, 11, 0, 0, 1, 0, 2, 3, 4, // packet 1
2, 0, 0, 11, 0, 0, 1, 0, 3, 4, 5, // packet 2
2, 1, 0, 9, 0, 0, 2, 0, 6, // packet 3
}
if bytes.Compare(memBuf.Bytes(), expectedBuf) != 0 {
t.Fatalf("Written buffer has invalid content:\n got: %v\nwant: %v", memBuf.Bytes(), expectedBuf)
}
}
func TestWriteErrors(t *testing.T) {
// write should fail if underlying transport fails
buf := newTdsBuffer(uint16(headerSize)+1, failBuffer{})
buf.BeginPacket(1, false)
wrote, err := buf.Write([]byte{0, 0})
// may change from error to panic in future
if err == nil {
t.Fatal("Write should fail but it didn't")
}
if wrote != 1 {
t.Fatal("Should write 1 byte but it wrote ", wrote)
}
// writebyte should fail if underlying transport fails
buf = newTdsBuffer(uint16(headerSize)+1, failBuffer{})
buf.BeginPacket(1, false)
// first write should not fail because if fits in the buffer
err = buf.WriteByte(0)
if err != nil {
t.Fatal("First WriteByte should not fail because it should fit in the buffer, but it failed", err)
}
err = buf.WriteByte(0)
// may change from error to panic in future
if err == nil {
t.Fatal("Second WriteByte should fail but it didn't")
}
}
func TestWrite_BufferBounds(t *testing.T) {
memBuf := bytes.NewBuffer([]byte{})
buf := newTdsBuffer(11, closableBuffer{memBuf})
buf.BeginPacket(1, false)
// write bytes enough to complete a package
_, err := buf.Write([]byte{1, 1, 1})
if err != nil {
t.Fatal("Write failed:", err.Error())
}
err = buf.WriteByte(1)
if err != nil {
t.Fatal("WriteByte failed:", err.Error())
}
_, err = buf.Write([]byte{1, 1, 1})
if err != nil {
t.Fatal("Write failed:", err.Error())
}
err = buf.FinishPacket()
if err != nil {
t.Fatal("FinishPacket failed:", err.Error())
}
}

View file

@ -1,237 +0,0 @@
// +build go1.9
package mssql
import (
"context"
"database/sql"
"encoding/hex"
"math"
"reflect"
"strings"
"testing"
"time"
)
func TestBulkcopy(t *testing.T) {
// TDS level Bulk Insert is not supported on Azure SQL Server.
if dsn := makeConnStr(t); strings.HasSuffix(strings.Split(dsn.Host, ":")[0], ".database.windows.net") {
t.Skip("TDS level bulk copy is not supported on Azure SQL Server")
}
type testValue struct {
colname string
val interface{}
}
tableName := "#table_test"
geom, _ := hex.DecodeString("E6100000010C00000000000034400000000000004440")
bin, _ := hex.DecodeString("ba8b7782168d4033a299333aec17bd33")
testValues := []testValue{
{"test_nvarchar", "ab©ĎéⒻghïjklmnopqЯ☀tuvwxyz"},
{"test_varchar", "abcdefg"},
{"test_char", "abcdefg "},
{"test_nchar", "abcdefg "},
{"test_text", "abcdefg"},
{"test_ntext", "abcdefg"},
{"test_float", 1234.56},
{"test_floatn", 1234.56},
{"test_real", 1234.56},
{"test_realn", 1234.56},
{"test_bit", true},
{"test_bitn", nil},
{"test_smalldatetime", time.Date(2010, 11, 12, 13, 14, 0, 0, time.UTC)},
{"test_smalldatetimen", time.Date(2010, 11, 12, 13, 14, 0, 0, time.UTC)},
{"test_datetime", time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC)},
{"test_datetimen", time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC)},
{"test_datetimen_1", time.Date(4010, 11, 12, 13, 14, 15, 120000000, time.UTC)},
{"test_datetime2_1", time.Date(2010, 11, 12, 13, 14, 15, 0, time.UTC)},
{"test_datetime2_3", time.Date(2010, 11, 12, 13, 14, 15, 123000000, time.UTC)},
{"test_datetime2_7", time.Date(2010, 11, 12, 13, 14, 15, 123000000, time.UTC)},
{"test_date", time.Date(2010, 11, 12, 00, 00, 00, 0, time.UTC)},
{"test_tinyint", 255},
{"test_smallint", 32767},
{"test_smallintn", nil},
{"test_int", 2147483647},
{"test_bigint", 9223372036854775807},
{"test_bigintn", nil},
{"test_geom", geom},
{"test_uniqueidentifier", []byte{0x6F, 0x96, 0x19, 0xFF, 0x8B, 0x86, 0xD0, 0x11, 0xB4, 0x2D, 0x00, 0xC0, 0x4F, 0xC9, 0x64, 0xFF}},
// {"test_smallmoney", 1234.56},
// {"test_money", 1234.56},
{"test_decimal_18_0", 1234.0001},
{"test_decimal_9_2", 1234.560001},
{"test_decimal_20_0", 1234.0001},
{"test_numeric_30_10", 1234567.1234567},
{"test_varbinary", []byte("1")},
{"test_varbinary_16", bin},
{"test_varbinary_max", bin},
{"test_binary", []byte("1")},
{"test_binary_16", bin},
}
columns := make([]string, len(testValues))
for i, val := range testValues {
columns[i] = val.colname
}
values := make([]interface{}, len(testValues))
for i, val := range testValues {
values[i] = val.val
}
pool := open(t)
defer pool.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// Now that session resetting is supported, the use of the per session
// temp table requires the use of a dedicated connection from the connection
// pool.
conn, err := pool.Conn(ctx)
if err != nil {
t.Fatal("failed to pull connection from pool", err)
}
defer conn.Close()
err = setupTable(ctx, t, conn, tableName)
if err != nil {
t.Error("Setup table failed: ", err)
return
}
t.Log("Preparing copy in statement")
stmt, err := conn.PrepareContext(ctx, CopyIn(tableName, BulkOptions{}, columns...))
for i := 0; i < 10; i++ {
t.Logf("Executing copy in statement %d time with %d values", i+1, len(values))
_, err = stmt.Exec(values...)
if err != nil {
t.Error("AddRow failed: ", err.Error())
return
}
}
result, err := stmt.Exec()
if err != nil {
t.Fatal("bulkcopy failed: ", err.Error())
}
insertedRowCount, _ := result.RowsAffected()
if insertedRowCount == 0 {
t.Fatal("0 row inserted!")
}
//check that all rows are present
var rowCount int
err = conn.QueryRowContext(ctx, "select count(*) c from "+tableName).Scan(&rowCount)
if rowCount != 10 {
t.Errorf("unexpected row count %d", rowCount)
}
//data verification
rows, err := conn.QueryContext(ctx, "select "+strings.Join(columns, ",")+" from "+tableName)
if err != nil {
t.Fatal(err)
}
defer rows.Close()
for rows.Next() {
ptrs := make([]interface{}, len(columns))
container := make([]interface{}, len(columns))
for i, _ := range ptrs {
ptrs[i] = &container[i]
}
if err := rows.Scan(ptrs...); err != nil {
t.Fatal(err)
}
for i, c := range testValues {
if !compareValue(container[i], c.val) {
t.Errorf("columns %s : expected: %v, got: %v\n", c.colname, c.val, container[i])
}
}
}
if err := rows.Err(); err != nil {
t.Error(err)
}
}
func compareValue(a interface{}, expected interface{}) bool {
switch expected := expected.(type) {
case int:
return int64(expected) == a
case int32:
return int64(expected) == a
case int64:
return int64(expected) == a
case float64:
if got, ok := a.([]uint8); ok {
var nf sql.NullFloat64
nf.Scan(got)
a = nf.Float64
}
return math.Abs(expected-a.(float64)) < 0.0001
default:
return reflect.DeepEqual(expected, a)
}
}
func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName string) (err error) {
tablesql := `CREATE TABLE ` + tableName + ` (
[id] [int] IDENTITY(1,1) NOT NULL,
[test_nvarchar] [nvarchar](50) NULL,
[test_varchar] [varchar](50) NULL,
[test_char] [char](10) NULL,
[test_nchar] [nchar](10) NULL,
[test_text] [text] NULL,
[test_ntext] [ntext] NULL,
[test_float] [float] NOT NULL,
[test_floatn] [float] NULL,
[test_real] [real] NULL,
[test_realn] [real] NULL,
[test_bit] [bit] NOT NULL,
[test_bitn] [bit] NULL,
[test_smalldatetime] [smalldatetime] NOT NULL,
[test_smalldatetimen] [smalldatetime] NULL,
[test_datetime] [datetime] NOT NULL,
[test_datetimen] [datetime] NULL,
[test_datetimen_1] [datetime] NULL,
[test_datetime2_1] [datetime2](1) NULL,
[test_datetime2_3] [datetime2](3) NULL,
[test_datetime2_7] [datetime2](7) NULL,
[test_date] [date] NULL,
[test_smallmoney] [smallmoney] NULL,
[test_money] [money] NULL,
[test_tinyint] [tinyint] NULL,
[test_smallint] [smallint] NOT NULL,
[test_smallintn] [smallint] NULL,
[test_int] [int] NULL,
[test_bigint] [bigint] NOT NULL,
[test_bigintn] [bigint] NULL,
[test_geom] [geometry] NULL,
[test_geog] [geography] NULL,
[text_xml] [xml] NULL,
[test_uniqueidentifier] [uniqueidentifier] NULL,
[test_decimal_18_0] [decimal](18, 0) NULL,
[test_decimal_9_2] [decimal](9, 2) NULL,
[test_decimal_20_0] [decimal](20, 0) NULL,
[test_numeric_30_10] [decimal](30, 10) NULL,
[test_varbinary] VARBINARY NOT NULL,
[test_varbinary_16] VARBINARY(16) NOT NULL,
[test_varbinary_max] VARBINARY(max) NOT NULL,
[test_binary] BINARY NOT NULL,
[test_binary_16] BINARY(16) NOT NULL,
CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED
(
[id] ASC
)WITH (PAD_INDEX = OFF, STATISTICS_NORECOMPUTE = OFF, IGNORE_DUP_KEY = OFF, ALLOW_ROW_LOCKS = ON, ALLOW_PAGE_LOCKS = ON) ON [PRIMARY]
) ON [PRIMARY] TEXTIMAGE_ON [PRIMARY];`
_, err = conn.ExecContext(ctx, tablesql)
if err != nil {
t.Fatal("tablesql failed:", err)
}
return
}

View file

@ -1,116 +0,0 @@
// +build go1.10
package mssql_test
import (
"database/sql"
"flag"
"fmt"
"log"
"strings"
"unicode/utf8"
"github.com/denisenkom/go-mssqldb"
)
const (
createTestTable = `CREATE TABLE test_table(
id int IDENTITY(1,1) NOT NULL,
test_nvarchar nvarchar(50) NULL,
test_varchar varchar(50) NULL,
test_float float NULL,
test_datetime2_3 datetime2(3) NULL,
test_bitn bit NULL,
test_bigint bigint NOT NULL,
test_geom geometry NULL,
CONSTRAINT PK_table_test_id PRIMARY KEY CLUSTERED
(
id ASC
) ON [PRIMARY]);`
dropTestTable = "IF OBJECT_ID('test_table', 'U') IS NOT NULL DROP TABLE test_table;"
)
// This example shows how to perform bulk imports
func ExampleCopyIn() {
flag.Parse()
if *debug {
fmt.Printf(" password:%s\n", *password)
fmt.Printf(" port:%d\n", *port)
fmt.Printf(" server:%s\n", *server)
fmt.Printf(" user:%s\n", *user)
}
connString := makeConnURL().String()
if *debug {
fmt.Printf(" connString:%s\n", connString)
}
db, err := sql.Open("sqlserver", connString)
if err != nil {
log.Fatal("Open connection failed:", err.Error())
}
defer db.Close()
txn, err := db.Begin()
if err != nil {
log.Fatal(err)
}
// Create table
_, err = db.Exec(createTestTable)
if err != nil {
log.Fatal(err)
}
defer db.Exec(dropTestTable)
// mssqldb.CopyIn creates string to be consumed by Prepare
stmt, err := txn.Prepare(mssql.CopyIn("test_table", mssql.BulkOptions{}, "test_varchar", "test_nvarchar", "test_float", "test_bigint"))
if err != nil {
log.Fatal(err.Error())
}
for i := 0; i < 10; i++ {
_, err = stmt.Exec(generateString(0, 30), generateStringUnicode(0, 30), i, i)
if err != nil {
log.Fatal(err.Error())
}
}
result, err := stmt.Exec()
if err != nil {
log.Fatal(err)
}
err = stmt.Close()
if err != nil {
log.Fatal(err)
}
err = txn.Commit()
if err != nil {
log.Fatal(err)
}
rowCount, _ := result.RowsAffected()
log.Printf("%d row copied\n", rowCount)
log.Printf("bye\n")
}
func generateString(x int, n int) string {
letters := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
b := make([]byte, n)
for i := range b {
b[i] = letters[(x+i)%len(letters)]
}
return string(b)
}
func generateStringUnicode(x int, n int) string {
letters := []byte("ab©💾é?ghïjklmnopqЯ☀tuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
b := &strings.Builder{}
for i := 0; i < n; i++ {
r, sz := utf8.DecodeRune(letters[x%len(letters):])
x += sz
b.WriteRune(r)
}
return b.String()
}

View file

@ -1,121 +0,0 @@
// +build go1.10
package mssql_test
import (
"database/sql"
"flag"
"fmt"
"log"
"time"
"cloud.google.com/go/civil"
"github.com/denisenkom/go-mssqldb"
)
// This example shows how to insert and retrieve date and time types data
func ExampleDateTimeOffset() {
flag.Parse()
if *debug {
fmt.Printf(" password:%s\n", *password)
fmt.Printf(" port:%d\n", *port)
fmt.Printf(" server:%s\n", *server)
fmt.Printf(" user:%s\n", *user)
}
connString := makeConnURL().String()
if *debug {
fmt.Printf(" connString:%s\n", connString)
}
db, err := sql.Open("sqlserver", connString)
if err != nil {
log.Fatal("Open connection failed:", err.Error())
}
defer db.Close()
insertDateTime(db)
retrieveDateTime(db)
retrieveDateTimeOutParam(db)
}
func insertDateTime(db *sql.DB) {
_, err := db.Exec("CREATE TABLE datetimeTable (timeCol TIME, dateCol DATE, smalldatetimeCol SMALLDATETIME, datetimeCol DATETIME, datetime2Col DATETIME2, datetimeoffsetCol DATETIMEOFFSET)")
if err != nil {
log.Fatal(err)
}
stmt, err := db.Prepare("INSERT INTO datetimeTable VALUES(@p1, @p2, @p3, @p4, @p5, @p6)")
if err != nil {
log.Fatal(err)
}
tin, err := time.Parse(time.RFC3339, "2006-01-02T22:04:05.787-07:00")
if err != nil {
log.Fatal(err)
}
var timeCol civil.Time = civil.TimeOf(tin)
var dateCol civil.Date = civil.DateOf(tin)
var smalldatetimeCol string = "2006-01-02 22:04:00"
var datetimeCol mssql.DateTime1 = mssql.DateTime1(tin)
var datetime2Col civil.DateTime = civil.DateTimeOf(tin)
var datetimeoffsetCol mssql.DateTimeOffset = mssql.DateTimeOffset(tin)
_, err = stmt.Exec(timeCol, dateCol, smalldatetimeCol, datetimeCol, datetime2Col, datetimeoffsetCol)
if err != nil {
log.Fatal(err)
}
}
func retrieveDateTime(db *sql.DB) {
rows, err := db.Query("SELECT timeCol, dateCol, smalldatetimeCol, datetimeCol, datetime2Col, datetimeoffsetCol FROM datetimeTable")
if err != nil {
log.Fatal(err)
}
defer rows.Close()
var c1, c2, c3, c4, c5, c6 time.Time
for rows.Next() {
err = rows.Scan(&c1, &c2, &c3, &c4, &c5, &c6)
if err != nil {
log.Fatal(err)
}
fmt.Printf("c1: %+v; c2: %+v; c3: %+v; c4: %+v; c5: %+v; c6: %+v;\n", c1, c2, c3, c4, c5, c6)
}
}
func retrieveDateTimeOutParam(db *sql.DB) {
CreateProcSql := `
CREATE PROCEDURE OutDatetimeProc
@timeOutParam TIME OUTPUT,
@dateOutParam DATE OUTPUT,
@smalldatetimeOutParam SMALLDATETIME OUTPUT,
@datetimeOutParam DATETIME OUTPUT,
@datetime2OutParam DATETIME2 OUTPUT,
@datetimeoffsetOutParam DATETIMEOFFSET OUTPUT
AS
SET NOCOUNT ON
SET @timeOutParam = '22:04:05.7870015'
SET @dateOutParam = '2006-01-02'
SET @smalldatetimeOutParam = '2006-01-02 22:04:00'
SET @datetimeOutParam = '2006-01-02 22:04:05.787'
SET @datetime2OutParam = '2006-01-02 22:04:05.7870015'
SET @datetimeoffsetOutParam = '2006-01-02 22:04:05.7870015 -07:00'`
_, err := db.Exec(CreateProcSql)
if err != nil {
log.Fatal(err)
}
var (
timeOutParam, datetime2OutParam, datetimeoffsetOutParam mssql.DateTimeOffset
dateOutParam, datetimeOutParam mssql.DateTime1
smalldatetimeOutParam string
)
_, err = db.Exec("OutDatetimeProc",
sql.Named("timeOutParam", sql.Out{Dest: &timeOutParam}),
sql.Named("dateOutParam", sql.Out{Dest: &dateOutParam}),
sql.Named("smalldatetimeOutParam", sql.Out{Dest: &smalldatetimeOutParam}),
sql.Named("datetimeOutParam", sql.Out{Dest: &datetimeOutParam}),
sql.Named("datetime2OutParam", sql.Out{Dest: &datetime2OutParam}),
sql.Named("datetimeoffsetOutParam", sql.Out{Dest: &datetimeoffsetOutParam}))
if err != nil {
log.Fatal(err)
}
fmt.Printf("timeOutParam: %+v; dateOutParam: %+v; smalldatetimeOutParam: %s; datetimeOutParam: %+v; datetime2OutParam: %+v; datetimeoffsetOutParam: %+v;\n", time.Time(timeOutParam), time.Time(dateOutParam), smalldatetimeOutParam, time.Time(datetimeOutParam), time.Time(datetime2OutParam), time.Time(datetimeoffsetOutParam))
}

View file

@ -1,107 +0,0 @@
package mssql
import (
"math"
"testing"
)
func TestToString(t *testing.T) {
values := []struct {
dec Decimal
s string
}{
{Decimal{positive: true, prec: 10, scale: 0, integer: [4]uint32{1, 0, 0, 0}}, "1"},
{Decimal{positive: false, prec: 10, scale: 0, integer: [4]uint32{1, 0, 0, 0}}, "-1"},
{Decimal{positive: true, prec: 10, scale: 1, integer: [4]uint32{1, 0, 0, 0}}, "0.1"},
{Decimal{positive: true, prec: 10, scale: 2, integer: [4]uint32{1, 0, 0, 0}}, "0.01"},
{Decimal{positive: false, prec: 10, scale: 1, integer: [4]uint32{1, 0, 0, 0}}, "-0.1"},
{Decimal{positive: true, prec: 10, scale: 2, integer: [4]uint32{100, 0, 0, 0}}, "1.00"},
{Decimal{positive: false, prec: 10, scale: 2, integer: [4]uint32{100, 0, 0, 0}}, "-1.00"},
{Decimal{positive: true, prec: 30, scale: 0, integer: [4]uint32{0, 1, 0, 0}}, "4294967296"}, // 2^32
{Decimal{positive: true, prec: 30, scale: 0, integer: [4]uint32{0, 0, 1, 0}}, "18446744073709551616"}, // 2^64
{Decimal{positive: true, prec: 30, scale: 0, integer: [4]uint32{0, 1, 1, 0}}, "18446744078004518912"}, // 2^64+2^32
}
for _, v := range values {
if v.dec.String() != v.s {
t.Error("String values don't match ", v.dec.String(), v.s)
}
}
}
func TestToFloat64(t *testing.T) {
values := []struct {
dec Decimal
flt float64
}{
{Decimal{positive: true, prec: 1},
0.0},
{Decimal{positive: true, prec: 1, integer: [4]uint32{1}},
1.0},
{Decimal{positive: false, prec: 1, integer: [4]uint32{1}},
-1.0},
{Decimal{positive: true, prec: 1, scale: 1, integer: [4]uint32{5}},
0.5},
{Decimal{positive: true, prec: 38, integer: [4]uint32{0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff}},
3.402823669209385e+38},
{Decimal{positive: true, prec: 38, scale: 3, integer: [4]uint32{0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff}},
3.402823669209385e+35},
}
for _, v := range values {
if v.dec.ToFloat64() != v.flt {
t.Error("ToFloat values don't match ", v.dec.ToFloat64(), v.flt)
}
}
}
func TestFromFloat64(t *testing.T) {
values := []struct {
dec Decimal
flt float64
}{
{Decimal{positive: true, prec: 20},
0.0},
{Decimal{positive: true, prec: 20, integer: [4]uint32{1}},
1.0},
{Decimal{positive: false, prec: 20, integer: [4]uint32{1}},
-1.0},
{Decimal{positive: true, prec: 20, scale: 1, integer: [4]uint32{5}},
0.5},
{Decimal{positive: true, prec: 20, integer: [4]uint32{0, 0, 0xfffff000, 0xffffffff}},
3.402823669209384e+38},
//{Decimal{positive: true, prec: 20, scale: 3, integer: [4]uint32{0, 0, 0xfffff000, 0xffffffff}},
// 3.402823669209385e+35},
}
for _, v := range values {
decfromflt, err := Float64ToDecimal(v.flt)
if err == nil {
if decfromflt != v.dec {
t.Error("FromFloat values don't match ", decfromflt, v.dec)
}
} else {
t.Error("Float64ToDecimal failed with error:", err.Error())
}
}
_, err := Float64ToDecimal(math.NaN())
if err == nil {
t.Error("Expected to get error for conversion from NaN, but didn't")
}
_, err = Float64ToDecimal(math.Inf(1))
if err == nil {
t.Error("Expected to get error for conversion from positive infinity, but didn't")
}
_, err = Float64ToDecimal(math.Inf(-1))
if err == nil {
t.Error("Expected to get error for conversion from negative infinity, but didn't")
}
_, err = Float64ToDecimal(3.402823669209386e+38)
if err == nil {
t.Error("Expected to get error for conversion from too big number, but didn't")
}
_, err = Float64ToDecimal(-3.402823669209386e+38)
if err == nil {
t.Error("Expected to get error for conversion from too big number, but didn't")
}
}

View file

@ -1,70 +0,0 @@
# How to Handle Date and Time Types
SQL Server has six date and time datatypes: date, time, smalldatetime, datetime, datetime2 and datetimeoffset. Some of these datatypes may contain more information than others (for example, datetimeoffset is the only type that has time zone awareness), higher ranges, or larger precisions. In a Go application using the mssql driver, the data types used to hold these data must be chosen carefully so no data is lost.
## Inserting Date and Time Data
The following is a list of datatypes that can be used to insert data into a SQL Server date and/or time type column:
- string
- time.Time
- mssql.DateTime1
- mssql.DateTimeOffset
- "cloud.google.com/go/civil".Date
- "cloud.google.com/go/civil".Time
- "cloud.google.com/go/civil".DateTime
`time.Time` and `mssql.DateTimeOffset` contain the most information (time zone and over 7 digits precision). Designed to match the SQL Server `datetime` type, `mssql.DateTime1` does not have time zone information, only has up to 3 digits precision and they are rouded to increments of .000, .003 or .007 seconds when the data is passed to SQL Server. If you use `mssql.DateTime1` to hold time zone information or very precised time data (more than 3 decimal digits), you will see data lost when inserting into columns with types that can hold more information. For example:
```
// all these types have up to 7 digits precision points
// datetimeoffset can hold information about time zone
_, err := db.Exec("CREATE TABLE datetimeTable (timeCol TIME, datetime2Col DATETIME2, datetimeoffsetCol DATETIMEOFFSET)")
stmt, err := db.Prepare("INSERT INTO datetimeTable VALUES (@p1, @p2, @p3))
tin, err := time.Parse(time.RFC3339, "2006-01-02T22:04:05.7870015-07:00") // data containing 7 decimal digits and has time zone awareness
param := mssql.DateTime1(tin) // data is stored in mssql.DateTime1 type
_, err = stmt.Exec(param, param, param)
// result in database:
// timeCol: 22:04:05.7866667
// datetime2Col: 2006-01-02 22:04:05.7866667
// datetimeoffsetCol: 2006-01-02 22:04:05.7866667 +00:00
// precisions are lost in all columns. Also, time zone information is lost in datetimeoffsetCol
```
`"cloud.google.com/go/civil".DateTime` does not have time zone information. `"cloud.google.com/go/civil".Date` only has the date information, and `"cloud.google.com/go/civil".Time` only has the time information. `string` can also be used to insert data into date and time types columns, but you have to make sure the format is accepted by SQL Server.
## Retrieving Date and Time Data
The following is a list of datatypes that can be used to retrieved data from a SQL Server date and/or time type column:
- string
- sql.RawBytes
- time.Time
- mssql.DateTime1
- mssql.DateTiimeOffset
When using these data types to retrieve information from a date and/or time type column, you may end up with some extra unexpected information. For example, if you use Go type `time.Time` to retrieve information from a SQL Server `date` column:
```
var c2 time.Time
rows, err := db.Query("SELECT dateCol FROM datetimeTable") // dateCol has data `2006-01-02`
for rows.Next() {
err = rows.Scan(&c1)
fmr.Printf("c2: %+v")
// c2: 2006-01-02 00:00:00 +0000 UTC
// you get extra time and time zone information defaulty set to 0
}
```
## Output parameters with Date and Time Data
The following is a list of datatypes that can be used as buffer to hold a output parameter of SQL Server date and/or time type
- string
- time.Time
- mssql.DateTime1
- mssql.DateTimeOffset
The only type that can be used to retrieve an output of `smalldatetime` is `string`, otherwise you will get a `mssql: Error converting data type datetimeoffset/datetime1 to smalldatetime` error. Furthermore, `string` and `mssql.DateTime1` are the only types that can be used to retrieve output of `datetime` type, otherwise you will get a `mssql: Error converting data type datetimeoffset to datetime` error.
Similar to retrieving data from a result set, when retrieving data as a output parameter, you may end up with some extra unexpected information when the Go type you use contains more information than the data you retrieved from SQL Server.
## Example
[DateTime handling example](../datetimeoffset_example_test.go)

View file

@ -1,48 +0,0 @@
# How to perform bulk imports
To use the bulk imports feature in go-mssqldb, you need to import the sql and go-mssqldb packages.
```
import (
"database/sql"
"github.com/denisenkom/go-mssqldb"
)
```
The `mssql.CopyIn` function creates a string which can be prepared by passing it to `Prepare`. The string returned contains information such as the name of the table and columns to bulk import data into, and bulk options.
```
bulkImportStr := mssql.CopyIn("tablename", mssql.BulkOptions{}, "column1", "column2", "column3")
stmt, err := db.Prepare(bulkImportStr)
```
Bulk options can be specified using the `mssql.BulkOptions` type. The following is how the `BulkOptions` type is defined:
```
type BulkOptions struct {
CheckConstraints bool
FireTriggers bool
KeepNulls bool
KilobytesPerBatch int
RowsPerBatch int
Order []string
Tablock bool
}
```
The statement can be executed many times to copy data into the table specified.
```
for i := 0; i < 10; i++ {
_, err = stmt.Exec(col1Data[i], col2Data[i], col3Data[i])
}
```
After all the data is processed, call `Exec` once with no arguments to flush all the buffered data.
```
_, err = stmt.Exec()
```
## Example
[Bulk import example](../bulkimport_example_test.go)

View file

@ -1,32 +0,0 @@
# How to use the Connector object
A Connector holds information in a DSN and is ready to make a new connection at any time. Connector implements the database/sql/driver Connector interface so it can be passed to the database/sql `OpenDB` function. One property on the Connector is the `SessionInitSQL` field, which may be used to set any options that cannot be passed through a DSN string.
To use the Connector type, first you need to import the sql and go-mssqldb packages
```
import (
"database/sql"
"github.com/denisenkom/go-mssqldb"
)
```
Now you can create a Connector object by calling `NewConnector`, which creates a new connector from a DSN.
```
dsn := "sqlserver://username:password@hostname/instance?database=databasename"
connector, err := mssql.NewConnector(dsn)
```
You can set `connector.SessionInitSQL` for any options that cannot be passed through in the dsn string.
`connector.SessionInitSQL = "SET ANSI_NULLS ON"`
Open a database by passing connector to `sql.OpenDB`.
`db := sql.OpenDB(connector)`
The returned DB maintains its own pool of idle connections. Now you can use the `sql.DB` object for querying and executing queries.
## Example
[NewConnector example](../newconnector_example_test.go)

View file

@ -1,91 +0,0 @@
# How to use Table-Valued Parameters
Table-valued parameters are declared by using user-defined table types. You can use table-valued parameters to send multiple rows of data to a Transact-SQL statement or a routine, such as a stored procedure or function, without creating a temporary table or many parameters.
To make use of the TVP functionality, first you need to create a table type, and a procedure or function to receive data from the table-valued parameter.
```
createTVP = "CREATE TYPE LocationTableType AS TABLE (LocationName VARCHAR(50), CostRate INT)"
_, err = db.Exec(createTable)
createProc = `
CREATE PROCEDURE dbo.usp_InsertProductionLocation
@TVP LocationTableType READONLY
AS
SET NOCOUNT ON
INSERT INTO Location
(
Name,
CostRate,
Availability,
ModifiedDate)
SELECT *, 0,GETDATE()
FROM @TVP`
_, err = db.Exec(createProc)
```
In your go application, create a struct that corresponds to the table type you have created. Create a slice of these structs which contain the data you want to pass to the stored procedure.
```
type LocationTableTvp struct {
LocationName string
CostRate int64
}
locationTableTypeData := []LocationTableTvp{
{
LocationName: "Alberta",
CostRate: 0,
},
{
LocationName: "British Columbia",
CostRate: 1,
},
}
```
Create a `mssql.TVP` object, and pass the slice of structs into the `Value` member. Set `TypeName` to the table type name.
```
tvpType := mssql.TVP{
TypeName: "LocationTableType",
Value: locationTableTypeData,
}
```
Finally, execute the stored procedure and pass the `mssql.TVPType` object you have created as a parameter.
`_, err = db.Exec("exec dbo.usp_InsertProductionLocation @TVP;", sql.Named("TVP", tvpType))`
## Using Tags to Omit Fields in a Struct
Sometimes users may find it useful to include fields in the struct that do not have corresponding columns in the table type. The driver supports this feature by using tags. To omit a field from a struct, use the `json` or `tvp` tag key and the `"-"` tag value.
For example, the user wants to define a struct with two more fields: `LocationCountry` and `Currency`. However, the `LocationTableType` table type do not have these corresponding columns. The user can omit the two new fields from being read by using the `json` or `tvp` tag.
```
type LocationTableTvpDetailed struct {
LocationName string
LocationCountry string `tvp:"-"`
CostRate int64
Currency string `json:"-"`
}
```
The `tvp` tag is the highest priority. Therefore if there is a field with tag `json:"-" tvp:"any"`, the field is not omitted. The following struct demonstrates different scenarios of using the `json` and `tvp` tags.
```
type T struct {
F1 string `json:"f1" tvp:"f1"` // not omitted
F2 string `json:"-" tvp:"f2"` // tvp tag takes precedence; not omitted
F3 string `json:"f3" tvp:"-"` // tvp tag takes precedence; omitted
F4 string `json:"-" tvp:"-"` // omitted
F5 string `json:"f5"` // not omitted
F6 string `json:"-"` // omitted
F7 string `tvp:"f7"` // not omitted
F8 string `tvp:"-"` // omitted
}
```
## Example
[TVPType example](../tvp_example_test.go)

View file

@ -1,38 +0,0 @@
package mssql
import "fmt"
func ExampleError_1() {
// call a function that might return a mssql error
err := callUsingMSSQL()
type ErrorWithNumber interface {
SQLErrorNumber() int32
}
if errorWithNumber, ok := err.(ErrorWithNumber); ok {
if errorWithNumber.SQLErrorNumber() == 1205 {
fmt.Println("deadlock error")
}
}
}
func ExampleError_2() {
// call a function that might return a mssql error
err := callUsingMSSQL()
type SQLError interface {
SQLErrorNumber() int32
SQLErrorMessage() string
}
if sqlError, ok := err.(SQLError); ok {
if sqlError.SQLErrorNumber() == 1205 {
fmt.Println("deadlock error", sqlError.SQLErrorMessage())
}
}
}
func callUsingMSSQL() error {
return nil
}

View file

@ -1,111 +0,0 @@
package main
import (
"database/sql"
"flag"
"fmt"
"log"
"github.com/denisenkom/go-mssqldb"
)
var (
debug = flag.Bool("debug", true, "enable debugging")
password = flag.String("password", "osmtest", "the database password")
port *int = flag.Int("port", 1433, "the database port")
server = flag.String("server", "localhost", "the database server")
user = flag.String("user", "osmtest", "the database user")
database = flag.String("database", "bulktest", "the database name")
)
/*
CREATE TABLE test_table(
[id] [int] IDENTITY(1,1) NOT NULL,
[test_nvarchar] [nvarchar](50) NULL,
[test_varchar] [varchar](50) NULL,
[test_float] [float] NULL,
[test_datetime2_3] [datetime2](3) NULL,
[test_bitn] [bit] NULL,
[test_bigint] [bigint] NOT NULL,
[test_geom] [geometry] NULL,
CONSTRAINT [PK_table_test_id] PRIMARY KEY CLUSTERED
(
[id] ASC
) ON [PRIMARY]);
*/
func main() {
flag.Parse()
if *debug {
fmt.Printf(" password:%s\n", *password)
fmt.Printf(" port:%d\n", *port)
fmt.Printf(" server:%s\n", *server)
fmt.Printf(" user:%s\n", *user)
fmt.Printf(" database:%s\n", *database)
}
connString := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%d;database=%s", *server, *user, *password, *port, *database)
if *debug {
fmt.Printf("connString:%s\n", connString)
}
conn, err := sql.Open("mssql", connString)
if err != nil {
log.Fatal("Open connection failed:", err.Error())
}
defer conn.Close()
txn, err := conn.Begin()
if err != nil {
log.Fatal(err)
}
stmt, err := txn.Prepare(mssql.CopyIn("test_table", mssql.BulkOptions{}, "test_varchar", "test_nvarchar", "test_float", "test_bigint"))
if err != nil {
log.Fatal(err.Error())
}
for i := 0; i < 10; i++ {
_, err = stmt.Exec(generateString(0, 30), generateStringUnicode(0, 30), i, i)
if err != nil {
log.Fatal(err.Error())
}
}
result, err := stmt.Exec()
if err != nil {
log.Fatal(err)
}
err = stmt.Close()
if err != nil {
log.Fatal(err)
}
err = txn.Commit()
if err != nil {
log.Fatal(err)
}
rowCount, _ := result.RowsAffected()
log.Printf("%d row copied\n", rowCount)
log.Printf("bye\n")
}
func generateString(x int, n int) string {
letters := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
b := make([]byte, n)
for i := range b {
b[i] = letters[i%len(letters)]
}
return string(b)
}
func generateStringUnicode(x int, n int) string {
letters := "ab©💾é?ghïjklmnopqЯ☀tuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
b := make([]byte, n)
for i := range b {
b[i] = letters[i%len(letters)]
}
return string(b)
}

View file

@ -1,57 +0,0 @@
package main
import (
"database/sql"
"flag"
"fmt"
"log"
_ "github.com/denisenkom/go-mssqldb"
)
var (
debug = flag.Bool("debug", false, "enable debugging")
password = flag.String("password", "", "the database password")
port *int = flag.Int("port", 1433, "the database port")
server = flag.String("server", "", "the database server")
user = flag.String("user", "", "the database user")
)
func main() {
flag.Parse()
if *debug {
fmt.Printf(" password:%s\n", *password)
fmt.Printf(" port:%d\n", *port)
fmt.Printf(" server:%s\n", *server)
fmt.Printf(" user:%s\n", *user)
}
connString := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%d", *server, *user, *password, *port)
if *debug {
fmt.Printf(" connString:%s\n", connString)
}
conn, err := sql.Open("mssql", connString)
if err != nil {
log.Fatal("Open connection failed:", err.Error())
}
defer conn.Close()
stmt, err := conn.Prepare("select 1, 'abc'")
if err != nil {
log.Fatal("Prepare failed:", err.Error())
}
defer stmt.Close()
row := stmt.QueryRow()
var somenumber int64
var somechars string
err = row.Scan(&somenumber, &somechars)
if err != nil {
log.Fatal("Scan failed:", err.Error())
}
fmt.Printf("somenumber:%d\n", somenumber)
fmt.Printf("somechars:%s\n", somechars)
fmt.Printf("bye\n")
}

View file

@ -1,119 +0,0 @@
package main
import (
"bufio"
"database/sql"
"flag"
"fmt"
"io"
"os"
"time"
_ "github.com/denisenkom/go-mssqldb"
)
func main() {
var (
userid = flag.String("U", "", "login_id")
password = flag.String("P", "", "password")
server = flag.String("S", "localhost", "server_name[\\instance_name]")
database = flag.String("d", "", "db_name")
)
flag.Parse()
dsn := "server=" + *server + ";user id=" + *userid + ";password=" + *password + ";database=" + *database
db, err := sql.Open("mssql", dsn)
if err != nil {
fmt.Println("Cannot connect: ", err.Error())
return
}
err = db.Ping()
if err != nil {
fmt.Println("Cannot connect: ", err.Error())
return
}
defer db.Close()
r := bufio.NewReader(os.Stdin)
for {
_, err = os.Stdout.Write([]byte("> "))
if err != nil {
fmt.Println(err)
return
}
cmd, err := r.ReadString('\n')
if err != nil {
if err == io.EOF {
fmt.Println()
return
}
fmt.Println(err)
return
}
err = exec(db, cmd)
if err != nil {
fmt.Println(err)
}
}
}
func exec(db *sql.DB, cmd string) error {
rows, err := db.Query(cmd)
if err != nil {
return err
}
defer rows.Close()
cols, err := rows.Columns()
if err != nil {
return err
}
if cols == nil {
return nil
}
vals := make([]interface{}, len(cols))
for i := 0; i < len(cols); i++ {
vals[i] = new(interface{})
if i != 0 {
fmt.Print("\t")
}
fmt.Print(cols[i])
}
fmt.Println()
for rows.Next() {
err = rows.Scan(vals...)
if err != nil {
fmt.Println(err)
continue
}
for i := 0; i < len(vals); i++ {
if i != 0 {
fmt.Print("\t")
}
printValue(vals[i].(*interface{}))
}
fmt.Println()
}
if rows.Err() != nil {
return rows.Err()
}
return nil
}
func printValue(pval *interface{}) {
switch v := (*pval).(type) {
case nil:
fmt.Print("NULL")
case bool:
if v {
fmt.Print("1")
} else {
fmt.Print("0")
}
case []byte:
fmt.Print(string(v))
case time.Time:
fmt.Print(v.Format("2006-01-02 15:04:05.999"))
default:
fmt.Print(v)
}
}

View file

@ -1,151 +0,0 @@
package main
import (
"database/sql"
"flag"
"fmt"
"github.com/denisenkom/go-mssqldb"
"log"
)
var (
debug = flag.Bool("debug", false, "enable debugging")
password = flag.String("password", "", "the database password")
port = flag.Int("port", 1433, "the database port")
server = flag.String("server", "", "the database server")
user = flag.String("user", "", "the database user")
)
type TvpExample struct {
MessageWithoutAnyTag string
MessageWithJSONTag string `json:"message"`
MessageWithTVPTag string `tvp:"message"`
MessageJSONSkipWithTVPTag string `json:"-" tvp:"message"`
OmitFieldJSONTag string `json:"-"`
OmitFieldTVPTag string `json:"any" tvp:"-"`
OmitFieldTVPTag2 string `tvp:"-"`
}
const (
crateSchema = `create schema TestTVPSchema;`
dropSchema = `drop schema TestTVPSchema;`
createTVP = `
CREATE TYPE TestTVPSchema.exampleTVP AS TABLE
(
message1 NVARCHAR(100),
message2 NVARCHAR(100),
message3 NVARCHAR(100),
message4 NVARCHAR(100)
)`
dropTVP = `DROP TYPE TestTVPSchema.exampleTVP;`
procedureWithTVP = `
CREATE PROCEDURE ExecTVP
@param1 TestTVPSchema.exampleTVP READONLY
AS
BEGIN
SET NOCOUNT ON;
SELECT * FROM @param1;
END;
`
dropProcedure = `drop PROCEDURE ExecTVP`
execTvp = `exec ExecTVP @param1;`
)
func main() {
flag.Parse()
if *debug {
fmt.Printf(" password:%s\n", *password)
fmt.Printf(" port:%d\n", *port)
fmt.Printf(" server:%s\n", *server)
fmt.Printf(" user:%s\n", *user)
}
connString := fmt.Sprintf("server=%s;user id=%s;password=%s;port=%d", *server, *user, *password, *port)
if *debug {
fmt.Printf(" connString:%s\n", connString)
}
conn, err := sql.Open("sqlserver", connString)
if err != nil {
log.Fatal("Open connection failed:", err.Error())
}
defer conn.Close()
_, err = conn.Exec(crateSchema)
if err != nil {
log.Println(err)
return
}
defer conn.Exec(dropSchema)
_, err = conn.Exec(createTVP)
if err != nil {
log.Println(err)
return
}
defer conn.Exec(dropTVP)
_, err = conn.Exec(procedureWithTVP)
if err != nil {
log.Println(err)
return
}
defer conn.Exec(dropProcedure)
exampleData := []TvpExample{
{
MessageWithoutAnyTag: "Hello1",
MessageWithJSONTag: "Hello2",
MessageWithTVPTag: "Hello3",
MessageJSONSkipWithTVPTag: "Hello4",
OmitFieldJSONTag: "Hello5",
OmitFieldTVPTag: "Hello6",
OmitFieldTVPTag2: "Hello7",
},
{
MessageWithoutAnyTag: "World1",
MessageWithJSONTag: "World2",
MessageWithTVPTag: "World3",
MessageJSONSkipWithTVPTag: "World4",
OmitFieldJSONTag: "World5",
OmitFieldTVPTag: "World6",
OmitFieldTVPTag2: "World7",
},
}
tvpType := mssql.TVP{
TypeName: "TestTVPSchema.exampleTVP",
Value: exampleData,
}
rows, err := conn.Query(execTvp,
sql.Named("param1", tvpType),
)
if err != nil {
log.Println(err)
return
}
tvpResult := make([]TvpExample, 0)
for rows.Next() {
tvpExample := TvpExample{}
err = rows.Scan(&tvpExample.MessageWithoutAnyTag,
&tvpExample.MessageWithJSONTag,
&tvpExample.MessageWithTVPTag,
&tvpExample.MessageJSONSkipWithTVPTag,
)
if err != nil {
log.Println(err)
return
}
tvpResult = append(tvpResult, tvpExample)
}
fmt.Println(tvpResult)
}

View file

@ -1,10 +0,0 @@
module github.com/denisenkom/go-mssqldb
go 1.12
require (
cloud.google.com/go v0.37.4
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
gopkg.in/yaml.v2 v2.2.2 // indirect
)

View file

@ -1,168 +0,0 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.31.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.37.2 h1:4y4L7BdHenTfZL0HervofNTHh9Ad6mNX72cQvl+5eH0=
cloud.google.com/go v0.37.2/go.mod h1:H8IAquKe2L30IxoupDgqTaQvKSwF/c8prYHynGIWQbA=
git.apache.org/thrift.git v0.0.0-20180902110319-2566ecd5d999/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
git.apache.org/thrift.git v0.12.0/go.mod h1:fPE2ZNJGynbRyZ4dJvy6G277gSllfV2HJqblrnkyeyg=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo=
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
github.com/alecthomas/template v0.0.0-20160405071501-a0175ee3bccc/go.mod h1:LOuyumcjzFXgccqObfd/Ljyb9UuFJ6TxHnclSeseNhc=
github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0=
github.com/anmitsu/go-shlex v0.0.0-20161002113705-648efa622239/go.mod h1:2FmKhYUyUczH0OGQWaF5ceTx0UBShxjsH6f8oGKYe2c=
github.com/apache/thrift v0.12.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
github.com/bradfitz/go-smtpd v0.0.0-20170404230938-deb6d6237625/go.mod h1:HYsPBTaaSFSlLx/70C2HPIMNZpVV8+vt/A+FMnYP11g=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568/go.mod h1:xEzjJPgXI435gkrCt3MPfRiAkVrwSbHsst4LCFVfpJc=
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04=
github.com/gliderlabs/ssh v0.1.1/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0=
github.com/go-kit/kit v0.8.0/go.mod h1:xBxKIO96dXMWWy0MnWVtmwkA9/13aqxPnvrjFYMA2as=
github.com/go-logfmt/logfmt v0.3.0/go.mod h1:Qt1PoO58o5twSAckw1HlFXLmHsOX5/0LbT9GBnD5lWE=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
github.com/golang/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:tluoj9z5200jBnyusfRPU2LqT6J+DAorxEvtC7LHB+E=
github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/mock v1.2.0/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A=
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M=
github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ=
github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck=
github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs=
github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc=
github.com/googleapis/gax-go v2.0.0+incompatible/go.mod h1:SFVmujtThgffbyetf+mdk2eWhX2bMyUtNHzFKcPA9HY=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
github.com/gregjones/httpcache v0.0.0-20180305231024-9cad4c3443a7/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA=
github.com/grpc-ecosystem/grpc-gateway v1.5.0/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw=
github.com/grpc-ecosystem/grpc-gateway v1.6.2/go.mod h1:RSKVYQBd5MCa4OVpNdGskqpgL2+G+NZTnrVHpWWfpdw=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/jellevandenhooff/dkim v0.0.0-20150330215556-f50fe3d243e1/go.mod h1:E0B/fFc00Y+Rasa88328GlI/XbtyysCtTHZS8h7IrBU=
github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU=
github.com/julienschmidt/httprouter v1.2.0/go.mod h1:SYymIcj16QtmaHHD7aYtjjsJG7VTCxuUUipMqKk8s4w=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/logfmt v0.0.0-20140226030751-b84e30acd515/go.mod h1:+0opPa2QZZtGFBFZlji/RkVcI2GknAs/DXo4wKdlNEc=
github.com/kr/pty v1.1.3/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE=
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
github.com/openzipkin/zipkin-go v0.1.3/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
github.com/openzipkin/zipkin-go v0.1.6/go.mod h1:QgAqvLzwWbR/WpD4A3cGpPtJrZXNIiJc5AZX7/PBEpw=
github.com/pierrec/lz4 v2.0.5+incompatible/go.mod h1:pdkljMzZIN41W+lC3N2tnIh5sFi+IEE17M5jbnwPHcY=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/prometheus/client_golang v0.8.0/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.1/go.mod h1:7SWBe2y4D6OKWSNQJUaRYU/AaXPKyh/dDVn+NZz0KFw=
github.com/prometheus/client_golang v0.9.3-0.20190127221311-3c4408c8b829/go.mod h1:p2iRAGwDERtqlqzRXnrOVns+ignqQo//hLXqYxZYVNs=
github.com/prometheus/client_model v0.0.0-20180712105110-5c3871d89910/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/client_model v0.0.0-20190115171406-56726106282f/go.mod h1:MbSGuTsp3dbXC40dX6PRTWyKYBIrTGTE9sqQNg2J8bo=
github.com/prometheus/common v0.0.0-20180801064454-c7de2306084e/go.mod h1:daVV7qP5qjZbuso7PdcryaAu0sAZbrN9i7WWcTMWvro=
github.com/prometheus/common v0.2.0/go.mod h1:TNfzLD0ON7rHzMJeJkieUDPYmFC7Snx/y86RQel1bk4=
github.com/prometheus/procfs v0.0.0-20180725123919-05ee40e3a273/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.0-20181005140218-185b4288413d/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/prometheus/procfs v0.0.0-20190117184657-bf6a532e95b1/go.mod h1:c3At6R/oaqEKCNdg8wHV1ftS6bRYblBhIjjI8uT2IGk=
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
go.opencensus.io v0.18.0/go.mod h1:vKdFvxhtzZ9onBp9VKHK8z/sRpBMnKAsufL7wlDrCOA=
go.opencensus.io v0.19.1/go.mod h1:gug0GbSHa8Pafr0d2urOSgoXHZ6x/RUlaiT0d9pqb4A=
go.opencensus.io v0.19.2/go.mod h1:NO/8qkisMZLZ1FCsKNqtJPwc8/TaclWyY0B6wcYNg9M=
go4.org v0.0.0-20180809161055-417644f6feb5/go.mod h1:MkTOUMDaeVYJUOUsaDXIhWPZYa1yOyC1qaOBpL57BhE=
golang.org/x/build v0.0.0-20190314133821-5284462c4bec/go.mod h1:atTaCNAy0f16Ah5aV1gMSwgiKVHwu/JncqDpuRr7lS4=
golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20181030102418-4d3f4d9ffa16/go.mod h1:6SG95UA2DQfeDnfUPMdvaQW0Q7yPrPDi9nlGo2tz2b4=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c h1:Vj5n4GlwjmQteupaxJ9+0FNOmBrHfq7vN4btdGoDZgI=
golang.org/x/crypto v0.0.0-20190325154230-a5d413f7728c/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20181217174547-8f45f776aaf1/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU=
golang.org/x/lint v0.0.0-20190301231843-5614ed5bae6f/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181029044818-c44066c5c816/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181106065722-10aee1819953/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20181114220301-adae6a3d119a/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190108225652-1e06a53dbb7e/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190125091013-d26f9f9a57f3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181017192945-9dcd33a902f4/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20181203162652-d668ce993890/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/perf v0.0.0-20180704124530-6e6d33e29852/go.mod h1:JLpeXjPJfIyPr5TlbXLkXWLhP8nz10XfvxElABhCtcw=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181029174526-d69651ed3497/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181122145206-62eef0e2fa9b/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20181218192612-074acd46bca6/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/tools v0.0.0-20180828015842-6cd1fcedba52/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20181219222714-6e267b5cc78e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY=
golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
google.golang.org/api v0.0.0-20180910000450-7ca32eb868bf/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0=
google.golang.org/api v0.0.0-20181030000543-1d582fd0359e/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0=
google.golang.org/api v0.0.0-20181220000619-583d854617af/go.mod h1:4mhQ8q/RsB7i+udVvVy5NUi08OU8ZlA0gRVgrF7VFY0=
google.golang.org/api v0.2.0/go.mod h1:IfRCZScioGtypHNTlz3gFk67J8uePVW7uDTBzXuIkhU=
google.golang.org/api v0.3.0/go.mod h1:IuvZyQh8jgscv8qWfQ4ABd8m7hEudgBFM/EdhA3BnXw=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=
google.golang.org/appengine v1.2.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.3.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4=
google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20180831171423-11092d34479b/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20181029155118-b69ba1387ce2/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc=
google.golang.org/genproto v0.0.0-20181219182458-5a97ab628bfb/go.mod h1:7Ep/1NZk928CDR8SjdVbjWNpdIf6nzjE3BTgJDr2Atg=
google.golang.org/genproto v0.0.0-20190307195333-5fe7a883aa19/go.mod h1:VzzqZJRnGkLBvHegQrXjBqPurQTc5/KpmUdxsrq26oE=
google.golang.org/grpc v1.14.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw=
google.golang.org/grpc v1.16.0/go.mod h1:0JHn/cJsOMiMfNA9+DeHDlAU7KAAB5GDlYFpa9MZMio=
google.golang.org/grpc v1.17.0/go.mod h1:6QZJwpn2B+Zp71q/5VxRsJ6NXXVCE5NRUHRo+f3cWCs=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLkstjWtayDeSgw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7/go.mod h1:dt/ZhP58zS4L8KSrWDmTeBkI65Dw0HsyUHuEVlX15mw=
gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o=
honnef.co/go/tools v0.0.0-20180728063816-88497007e858/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20180920025451-e3ad64cb4ed3/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=
honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4=

View file

@ -112,7 +112,7 @@ func (c *Conn) CheckNamedValue(nv *driver.NamedValue) error {
*v = 0 // By default the return value should be zero.
c.returnStatus = v
return driver.ErrRemoveArgument
case TVP:
case TVPType:
return nil
default:
var err error
@ -162,27 +162,15 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) {
case sql.Out:
res, err = s.makeParam(val.Dest)
res.Flags = fByRevValue
case TVP:
case TVPType:
err = val.check()
if err != nil {
return
}
schema, name, errGetName := getSchemeAndName(val.TypeName)
if errGetName != nil {
return
}
res.ti.UdtInfo.TypeName = name
res.ti.UdtInfo.SchemaName = schema
res.ti.UdtInfo.TypeName = val.TVPTypeName
res.ti.UdtInfo.SchemaName = val.TVPScheme
res.ti.TypeId = typeTvp
columnStr, tvpFieldIndexes, errCalTypes := val.columnTypes()
if errCalTypes != nil {
err = errCalTypes
return
}
res.buffer, err = val.encode(schema, name, columnStr, tvpFieldIndexes)
if err != nil {
return
}
res.buffer, err = val.encode()
res.ti.Size = len(res.buffer)
default:

View file

@ -1,39 +0,0 @@
package mssql
import (
"context"
"testing"
)
func TestBadOpen(t *testing.T) {
drv := driverWithProcess(t)
_, err := drv.open(context.Background(), "port=bad")
if err == nil {
t.Fail()
}
}
func TestIsProc(t *testing.T) {
list := []struct {
s string
is bool
}{
{"proc", true},
{"select 1;", false},
{"select 1", false},
{"[proc 1]", true},
{"[proc\n1]", false},
{"schema.name", true},
{"[schema].[name]", true},
{"schema.[name]", true},
{"[schema].name", true},
{"schema.[proc name]", true},
}
for _, item := range list {
got := isProc(item.s)
if got != item.is {
t.Errorf("for %q, got %t want %t", item.s, got, item.is)
}
}
}

View file

@ -1,143 +0,0 @@
// +build go1.10
package mssql_test
import (
"context"
"database/sql"
"flag"
"fmt"
"log"
"net/url"
"strconv"
mssql "github.com/denisenkom/go-mssqldb"
)
var (
debug = flag.Bool("debug", false, "enable debugging")
password = flag.String("password", "", "the database password")
port *int = flag.Int("port", 1433, "the database port")
server = flag.String("server", "", "the database server")
user = flag.String("user", "", "the database user")
)
const (
createTableSql = "CREATE TABLE TestAnsiNull (bitcol bit, charcol char(1));"
dropTableSql = "IF OBJECT_ID('TestAnsiNull', 'U') IS NOT NULL DROP TABLE TestAnsiNull;"
insertQuery1 = "INSERT INTO TestAnsiNull VALUES (0, NULL);"
insertQuery2 = "INSERT INTO TestAnsiNull VALUES (1, 'a');"
selectNullFilter = "SELECT bitcol FROM TestAnsiNull WHERE charcol = NULL;"
selectNotNullFilter = "SELECT bitcol FROM TestAnsiNull WHERE charcol <> NULL;"
)
func makeConnURL() *url.URL {
return &url.URL{
Scheme: "sqlserver",
Host: *server + ":" + strconv.Itoa(*port),
User: url.UserPassword(*user, *password),
}
}
// This example shows the usage of Connector type
func ExampleConnector() {
flag.Parse()
if *debug {
fmt.Printf(" password:%s\n", *password)
fmt.Printf(" port:%d\n", *port)
fmt.Printf(" server:%s\n", *server)
fmt.Printf(" user:%s\n", *user)
}
connString := makeConnURL().String()
if *debug {
fmt.Printf(" connString:%s\n", connString)
}
// Create a new connector object by calling NewConnector
connector, err := mssql.NewConnector(connString)
if err != nil {
log.Println(err)
return
}
// Use SessionInitSql to set any options that cannot be set with the dsn string
// With ANSI_NULLS set to ON, compare NULL data with = NULL or <> NULL will return 0 rows
connector.SessionInitSQL = "SET ANSI_NULLS ON"
// Pass connector to sql.OpenDB to get a sql.DB object
db := sql.OpenDB(connector)
defer db.Close()
// Create and populate table
_, err = db.Exec(createTableSql)
if err != nil {
log.Println(err)
return
}
defer db.Exec(dropTableSql)
_, err = db.Exec(insertQuery1)
if err != nil {
log.Println(err)
return
}
_, err = db.Exec(insertQuery2)
if err != nil {
log.Println(err)
return
}
var bitval bool
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
// (*Row) Scan should return ErrNoRows since ANSI_NULLS is set to ON
err = db.QueryRowContext(ctx, selectNullFilter).Scan(&bitval)
if err != nil {
if err.Error() != "sql: no rows in result set" {
log.Println(err)
return
}
} else {
log.Println("Expects an ErrNoRows error. No error is returned")
return
}
// (*Row) Scan should return ErrNoRows since ANSI_NULLS is set to ON
err = db.QueryRowContext(ctx, selectNotNullFilter).Scan(&bitval)
if err != nil {
if err.Error() != "sql: no rows in result set" {
log.Println(err)
return
}
} else {
log.Println("Expects an ErrNoRows error. No error is returned")
return
}
// Set ANSI_NULLS to OFF
connector.SessionInitSQL = "SET ANSI_NULLS OFF"
// (*Row) Scan should copy data to bitval
err = db.QueryRowContext(ctx, selectNullFilter).Scan(&bitval)
if err != nil {
log.Println(err)
return
}
if bitval != false {
log.Println("Incorrect value retrieved.")
return
}
// (*Row) Scan should copy data to bitval
err = db.QueryRowContext(ctx, selectNotNullFilter).Scan(&bitval)
if err != nil {
log.Println(err)
return
}
if bitval != true {
log.Println("Incorrect value retrieved.")
return
}
}

View file

@ -1,76 +0,0 @@
// +build !windows
package mssql
import (
"encoding/hex"
"testing"
)
func TestLMOWFv1(t *testing.T) {
hash := lmHash("Password")
val := [21]byte{
0xe5, 0x2c, 0xac, 0x67, 0x41, 0x9a, 0x9a, 0x22,
0x4a, 0x3b, 0x10, 0x8f, 0x3f, 0xa6, 0xcb, 0x6d,
0, 0, 0, 0, 0,
}
if hash != val {
t.Errorf("got:\n%sexpected:\n%s", hex.Dump(hash[:]), hex.Dump(val[:]))
}
}
func TestNTLMOWFv1(t *testing.T) {
hash := ntlmHash("Password")
val := [21]byte{
0xa4, 0xf4, 0x9c, 0x40, 0x65, 0x10, 0xbd, 0xca, 0xb6, 0x82, 0x4e, 0xe7, 0xc3, 0x0f, 0xd8, 0x52,
0, 0, 0, 0, 0,
}
if hash != val {
t.Errorf("got:\n%sexpected:\n%s", hex.Dump(hash[:]), hex.Dump(val[:]))
}
}
func TestNTLMv1Response(t *testing.T) {
challenge := [8]byte{
0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
}
nt := ntResponse(challenge, "Password")
val := [24]byte{
0x67, 0xc4, 0x30, 0x11, 0xf3, 0x02, 0x98, 0xa2, 0xad, 0x35, 0xec, 0xe6, 0x4f, 0x16, 0x33, 0x1c,
0x44, 0xbd, 0xbe, 0xd9, 0x27, 0x84, 0x1f, 0x94,
}
if nt != val {
t.Errorf("got:\n%sexpected:\n%s", hex.Dump(nt[:]), hex.Dump(val[:]))
}
}
func TestLMv1Response(t *testing.T) {
challenge := [8]byte{
0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
}
nt := lmResponse(challenge, "Password")
val := [24]byte{
0x98, 0xde, 0xf7, 0xb8, 0x7f, 0x88, 0xaa, 0x5d, 0xaf, 0xe2, 0xdf, 0x77, 0x96, 0x88, 0xa1, 0x72,
0xde, 0xf1, 0x1c, 0x7d, 0x5c, 0xcd, 0xef, 0x13,
}
if nt != val {
t.Errorf("got:\n%sexpected:\n%s", hex.Dump(nt[:]), hex.Dump(val[:]))
}
}
func TestNTLMSessionResponse(t *testing.T) {
challenge := [8]byte{
0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef,
}
nonce := [8]byte{
0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa, 0xaa,
}
nt := ntlmSessionResponse(nonce, challenge, "Password")
val := [24]byte{
0x75, 0x37, 0xf8, 0x03, 0xae, 0x36, 0x71, 0x28, 0xca, 0x45, 0x82, 0x04, 0xbd, 0xe7, 0xca, 0xf8,
0x1e, 0x97, 0xed, 0x26, 0x83, 0x26, 0x72, 0x32,
}
if nt != val {
t.Errorf("got:\n%sexpected:\n%s", hex.Dump(nt[:]), hex.Dump(val[:]))
}
}

View file

@ -1,60 +0,0 @@
package mssql
import (
"testing"
)
func TestParseParams(t *testing.T) {
values := []struct {
s string
d string
n int
}{
{"select ?", "select @p1", 1},
{"select ?, ?", "select @p1, @p2", 2},
{"select ? -- ?", "select @p1 -- ?", 1},
{"select ? -- ?\n, ?", "select @p1 -- ?\n, @p2", 2},
{"select ? - ?", "select @p1 - @p2", 2},
{"select ? /* ? */, ?", "select @p1 /* ? */, @p2", 2},
{"select ? /* ? * ? */, ?", "select @p1 /* ? * ? */, @p2", 2},
{"select \"foo?\", [foo?], 'foo?', ?", "select \"foo?\", [foo?], 'foo?', @p1", 1},
{"select \"x\"\"y\", [x]]y], 'x''y', ?", "select \"x\"\"y\", [x]]y], 'x''y', @p1", 1},
{"select \"foo?\", ?", "select \"foo?\", @p1", 1},
{"select 'foo?', ?", "select 'foo?', @p1", 1},
{"select [foo?], ?", "select [foo?], @p1", 1},
{"select $1", "select @p1", 1},
{"select $1, $2", "select @p1, @p2", 2},
{"select $1, $1", "select @p1, @p1", 1},
{"select :1", "select @p1", 1},
{"select :1, :2", "select @p1, @p2", 2},
{"select :1, :1", "select @p1, @p1", 1},
{"select ?1", "select @p1", 1},
{"select ?1, ?2", "select @p1, @p2", 2},
{"select ?1, ?1", "select @p1, @p1", 1},
{"select $12", "select @p12", 12},
{"select ? /* ? /* ? */ ? */ ?", "select @p1 /* ? /* ? */ ? */ @p2", 2},
{"select ? /* ? / ? */ ?", "select @p1 /* ? / ? */ @p2", 2},
{"select $", "select $", 0},
{"select x::y", "select x:@y", 1},
{"select '", "select '", 0},
{"select \"", "select \"", 0},
{"select [", "select [", 0},
{"select []", "select []", 0},
{"select -", "select -", 0},
{"select /", "select /", 0},
{"select 1/1", "select 1/1", 0},
{"select /*", "select /*", 0},
{"select /**", "select /**", 0},
{"select /*/", "select /*/", 0},
}
for _, v := range values {
d, n := parseParams(v.s)
if d != v.d {
t.Errorf("Parse params don't match for %s, got %s but expected %s", v.s, d, v.d)
}
if n != v.n {
t.Errorf("Parse number of params don't match for %s, got %d but expected %d", v.s, n, v.n)
}
}
}

View file

@ -1,189 +0,0 @@
// +build go1.10
package mssql
import (
"context"
"database/sql"
"strings"
"testing"
"time"
"cloud.google.com/go/civil"
)
func TestSessionInitSQL(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})
d := &Driver{}
connector, err := d.OpenConnector(makeConnStr(t).String())
if err != nil {
t.Fatal("unable to open connector", err)
}
// Do not use these settings in your application
// unless you know what they do.
// Thes are for this unit test only.
//
// Sessions will be reset even if SessionInitSQL is not set.
connector.SessionInitSQL = `
SET XACT_ABORT ON; -- 16384
SET ANSI_NULLS ON; -- 32
SET ARITHIGNORE ON; -- 128
`
pool := sql.OpenDB(connector)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
var opt int32
err = pool.QueryRowContext(ctx, `
select Options = @@OPTIONS;
`).Scan(&opt)
if err != nil {
t.Fatal("failed to run query", err)
}
mask := int32(16384 | 128 | 32)
if opt&mask != mask {
t.Fatal("incorrect session settings", opt)
}
}
func TestParameterTypes(t *testing.T) {
checkConnStr(t)
pool, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
t.Fatal(err)
}
defer pool.Close()
tin, err := time.Parse(time.RFC3339, "2006-01-02T22:04:05-07:00")
if err != nil {
t.Fatal(err)
}
var nv, v, nvcm, vcm, dt1, dt2, tm, d, dto string
row := pool.QueryRow(`
select
nv = SQL_VARIANT_PROPERTY(@nv,'BaseType'),
v = SQL_VARIANT_PROPERTY(@v,'BaseType'),
@nvcm,
@vcm,
dt1 = SQL_VARIANT_PROPERTY(@dt1,'BaseType'),
dt2 = SQL_VARIANT_PROPERTY(@dt2,'BaseType'),
d = SQL_VARIANT_PROPERTY(@d,'BaseType'),
tm = SQL_VARIANT_PROPERTY(@tm,'BaseType'),
dto = SQL_VARIANT_PROPERTY(@dto,'BaseType')
;
`,
sql.Named("nv", "base type nvarchar"),
sql.Named("v", VarChar("base type varchar")),
sql.Named("nvcm", NVarCharMax(strings.Repeat("x", 5000))),
sql.Named("vcm", VarCharMax(strings.Repeat("x", 5000))),
sql.Named("dt1", DateTime1(tin)),
sql.Named("dt2", civil.DateTimeOf(tin)),
sql.Named("d", civil.DateOf(tin)),
sql.Named("tm", civil.TimeOf(tin)),
sql.Named("dto", DateTimeOffset(tin)),
)
err = row.Scan(&nv, &v, &nvcm, &vcm, &dt1, &dt2, &d, &tm, &dto)
if err != nil {
t.Fatal(err)
}
if nv != "nvarchar" {
t.Errorf(`want "nvarchar" got %q`, nv)
}
if v != "varchar" {
t.Errorf(`want "varchar" got %q`, v)
}
if nvcm != strings.Repeat("x", 5000) {
t.Errorf(`incorrect value returned for nvarchar(max): %q`, nvcm)
}
if vcm != strings.Repeat("x", 5000) {
t.Errorf(`incorrect value returned for varchar(max): %q`, vcm)
}
if dt1 != "datetime" {
t.Errorf(`want "datetime" got %q`, dt1)
}
if dt2 != "datetime2" {
t.Errorf(`want "datetime2" got %q`, dt2)
}
if d != "date" {
t.Errorf(`want "date" got %q`, d)
}
if tm != "time" {
t.Errorf(`want "time" got %q`, tm)
}
if dto != "datetimeoffset" {
t.Errorf(`want "datetimeoffset" got %q`, dto)
}
}
func TestParameterValues(t *testing.T) {
checkConnStr(t)
pool, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
t.Fatal(err)
}
defer pool.Close()
sin := "high five"
tin, err := time.Parse(time.RFC3339, "2006-01-02T22:04:05-07:00")
if err != nil {
t.Fatal(err)
}
var nv, v, tgo, dt1, dt2, tm, d, dto string
err = pool.QueryRow(`
select
nv = @nv,
v = @v,
tgo = @tgo,
dt1 = convert(nvarchar(200), @dt1, 121),
dt2 = convert(nvarchar(200), @dt2, 121),
d = convert(nvarchar(200), @d, 121),
tm = convert(nvarchar(200), @tm, 121),
dto = convert(nvarchar(200), @dto, 121)
;
`,
sql.Named("nv", sin),
sql.Named("v", sin),
sql.Named("tgo", tin),
sql.Named("dt1", DateTime1(tin)),
sql.Named("dt2", civil.DateTimeOf(tin)),
sql.Named("d", civil.DateOf(tin)),
sql.Named("tm", civil.TimeOf(tin)),
sql.Named("dto", DateTimeOffset(tin)),
).Scan(&nv, &v, &tgo, &dt1, &dt2, &d, &tm, &dto)
if err != nil {
t.Fatal(err)
}
if want := sin; nv != want {
t.Errorf(`want %q got %q`, want, nv)
}
if want := sin; v != want {
t.Errorf(`want %q got %q`, want, v)
}
if want := "2006-01-02T22:04:05-07:00"; tgo != want {
t.Errorf(`want %q got %q`, want, tgo)
}
if want := "2006-01-02 22:04:05.000"; dt1 != want {
t.Errorf(`want %q got %q`, want, dt1)
}
if want := "2006-01-02 22:04:05.0000000"; dt2 != want {
t.Errorf(`want %q got %q`, want, dt2)
}
if want := "2006-01-02"; d != want {
t.Errorf(`want %q got %q`, want, d)
}
if want := "22:04:05.0000000"; tm != want {
t.Errorf(`want %q got %q`, want, tm)
}
if want := "2006-01-02 22:04:05.0000000 -07:00"; dto != want {
t.Errorf(`want %q got %q`, want, dto)
}
}

View file

@ -1,924 +0,0 @@
// +build go1.9
package mssql
import (
"bytes"
"context"
"database/sql"
"fmt"
"regexp"
"testing"
"time"
)
func TestOutputParam(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})
db, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
t.Fatalf("failed to open driver sqlserver")
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
t.Run("sp with rows", func(t *testing.T) {
sqltextcreate := `
CREATE PROCEDURE spwithrows
@intparam INT = NULL OUTPUT
AS
BEGIN
-- return 2 rows
SELECT @intparam
union
SELECT 20
-- set output parameter value
SELECT @intparam = 10
END;
`
sqltextdrop := `DROP PROCEDURE spwithrows;`
sqltextrun := `spwithrows`
db.ExecContext(ctx, sqltextdrop)
_, err = db.ExecContext(ctx, sqltextcreate)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdrop)
if err != nil {
t.Error(err)
}
var intparam int = 5
rows, err := db.QueryContext(ctx, sqltextrun,
sql.Named("intparam", sql.Out{Dest: &intparam}),
)
if err != nil {
t.Error(err)
}
// reading first row
if !rows.Next() {
t.Error("Next returned false")
}
var rowval int
err = rows.Scan(&rowval)
if err != nil {
t.Error(err)
}
if rowval != 5 {
t.Errorf("expected 5, got %d", rowval)
}
// if uncommented would trigger race condition warning
//if intparam != 10 {
// t.Log("output parameter value is not yet 10, it is ", intparam)
//}
// reading second row
if !rows.Next() {
t.Error("Next returned false")
}
err = rows.Scan(&rowval)
if err != nil {
t.Error(err)
}
if rowval != 20 {
t.Errorf("expected 20, got %d", rowval)
}
if rows.Next() {
t.Error("Next returned true but should return false after last row was returned")
}
if intparam != 10 {
t.Errorf("expected 10, got %d", intparam)
}
})
t.Run("sp with no rows", func(t *testing.T) {
sqltextcreate := `
CREATE PROCEDURE abassign
@aid INT = 5,
@bid INT = NULL OUTPUT,
@cstr NVARCHAR(2000) = NULL OUTPUT,
@datetime datetime = NULL OUTPUT
AS
BEGIN
SELECT @bid = @aid, @cstr = 'OK', @datetime = '2010-01-01T00:00:00';
END;
`
sqltextdrop := `DROP PROCEDURE abassign;`
sqltextrun := `abassign`
db.ExecContext(ctx, sqltextdrop)
_, err = db.ExecContext(ctx, sqltextcreate)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdrop)
if err != nil {
t.Error(err)
}
t.Run("should work", func(t *testing.T) {
var bout int64
var cout string
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("aid", 5),
sql.Named("bid", sql.Out{Dest: &bout}),
sql.Named("cstr", sql.Out{Dest: &cout}),
)
if err != nil {
t.Error(err)
}
if bout != 5 {
t.Errorf("expected 5, got %d", bout)
}
if cout != "OK" {
t.Errorf("expected OK, got %s", cout)
}
})
t.Run("should work if aid is not passed", func(t *testing.T) {
var bout int64
var cout string
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("bid", sql.Out{Dest: &bout}),
sql.Named("cstr", sql.Out{Dest: &cout}),
)
if err != nil {
t.Error(err)
}
if bout != 5 {
t.Errorf("expected 5, got %d", bout)
}
if cout != "OK" {
t.Errorf("expected OK, got %s", cout)
}
})
t.Run("should work for DateTime1 parameter", func(t *testing.T) {
tin, err := time.Parse(time.RFC3339, "2006-01-02T22:04:05-07:00")
if err != nil {
t.Fatal(err)
}
expected, err := time.Parse(time.RFC3339, "2010-01-01T00:00:00-00:00")
if err != nil {
t.Fatal(err)
}
var datetime_param DateTime1
datetime_param = DateTime1(tin)
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("datetime", sql.Out{Dest: &datetime_param}),
)
if err != nil {
t.Error(err)
}
if time.Time(datetime_param).UTC() != expected.UTC() {
t.Errorf("Datetime returned '%v' does not match expected value '%v'",
time.Time(datetime_param).UTC(), expected.UTC())
}
})
t.Run("destination is not a pointer", func(t *testing.T) {
var int_out int64
var str_out string
// test when destination is not a pointer
_, actual := db.ExecContext(ctx, sqltextrun,
sql.Named("bid", sql.Out{Dest: int_out}),
sql.Named("cstr", sql.Out{Dest: &str_out}),
)
pattern := ".*destination not a pointer.*"
match, err := regexp.MatchString(pattern, actual.Error())
if err != nil {
t.Error(err)
}
if !match {
t.Errorf("Error '%v', does not match pattern '%v'.", actual, pattern)
}
})
t.Run("should convert int64 to int", func(t *testing.T) {
var bout int
var cout string
_, err := db.ExecContext(ctx, sqltextrun,
sql.Named("bid", sql.Out{Dest: &bout}),
sql.Named("cstr", sql.Out{Dest: &cout}),
)
if err != nil {
t.Error(err)
}
if bout != 5 {
t.Errorf("expected 5, got %d", bout)
}
})
t.Run("should fail if destination has invalid type", func(t *testing.T) {
// Error type should not be supported
var err_out Error
_, err := db.ExecContext(ctx, sqltextrun,
sql.Named("bid", sql.Out{Dest: &err_out}),
)
if err == nil {
t.Error("Expected to fail but it didn't")
}
// double inderection should not work
var out_out = sql.Out{Dest: &err_out}
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("bid", sql.Out{Dest: out_out}),
)
if err == nil {
t.Error("Expected to fail but it didn't")
}
})
t.Run("should fail if parameter has invalid type", func(t *testing.T) {
// passing invalid parameter type
var err_val Error
_, err = db.ExecContext(ctx, sqltextrun, err_val)
if err == nil {
t.Error("Expected to fail but it didn't")
}
})
t.Run("destination is a nil pointer", func(t *testing.T) {
var str_out string
// test when destination is nil pointer
_, actual := db.ExecContext(ctx, sqltextrun,
sql.Named("bid", sql.Out{Dest: nil}),
sql.Named("cstr", sql.Out{Dest: &str_out}),
)
pattern := ".*destination is a nil pointer.*"
match, err := regexp.MatchString(pattern, actual.Error())
if err != nil {
t.Error(err)
}
if !match {
t.Errorf("Error '%v', does not match pattern '%v'.", actual, pattern)
}
})
t.Run("destination is a nil pointer 2", func(t *testing.T) {
var int_ptr *int
_, actual := db.ExecContext(ctx, sqltextrun,
sql.Named("bid", sql.Out{Dest: int_ptr}),
)
pattern := ".*destination is a nil pointer.*"
match, err := regexp.MatchString(pattern, actual.Error())
if err != nil {
t.Error(err)
}
if !match {
t.Errorf("Error '%v', does not match pattern '%v'.", actual, pattern)
}
})
t.Run("pointer to a pointer", func(t *testing.T) {
var str_out *string
_, actual := db.ExecContext(ctx, sqltextrun,
sql.Named("cstr", sql.Out{Dest: &str_out}),
)
pattern := ".*destination is a pointer to a pointer.*"
match, err := regexp.MatchString(pattern, actual.Error())
if err != nil {
t.Error(err)
}
if !match {
t.Errorf("Error '%v', does not match pattern '%v'.", actual, pattern)
}
})
})
}
func TestOutputINOUTParam(t *testing.T) {
sqltextcreate := `
CREATE PROCEDURE abinout
@aid INT = 1,
@bid INT = 2 OUTPUT,
@cstr NVARCHAR(2000) = NULL OUTPUT,
@vout VARCHAR(2000) = NULL OUTPUT,
@nullint INT = NULL OUTPUT,
@nullfloat FLOAT = NULL OUTPUT,
@nullstr NVARCHAR(10) = NULL OUTPUT,
@nullbit BIT = NULL OUTPUT,
@varbin VARBINARY(10) = NULL OUTPUT
AS
BEGIN
SELECT
@bid = @aid + @bid,
@cstr = 'OK',
@Vout = 'DREAM'
;
END;
`
sqltextdrop := `DROP PROCEDURE abinout;`
sqltextrun := `abinout`
checkConnStr(t)
SetLogger(testLogger{t})
db, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
t.Fatalf("failed to open driver sqlserver")
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db.ExecContext(ctx, sqltextdrop)
_, err = db.ExecContext(ctx, sqltextcreate)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdrop)
t.Run("original test", func(t *testing.T) {
var bout int64 = 3
var cout string
var vout VarChar
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("aid", 5),
sql.Named("bid", sql.Out{Dest: &bout}),
sql.Named("cstr", sql.Out{Dest: &cout}),
sql.Named("vout", sql.Out{Dest: &vout}),
)
if err != nil {
t.Error(err)
}
if bout != 8 {
t.Errorf("expected 8, got %d", bout)
}
if cout != "OK" {
t.Errorf("expected OK, got %s", cout)
}
if string(vout) != "DREAM" {
t.Errorf("expected DREAM, got %s", vout)
}
})
t.Run("test null values returned into nullable", func(t *testing.T) {
var nullint sql.NullInt64
var nullfloat sql.NullFloat64
var nullstr sql.NullString
var nullbit sql.NullBool
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("nullint", sql.Out{Dest: &nullint}),
sql.Named("nullfloat", sql.Out{Dest: &nullfloat}),
sql.Named("nullstr", sql.Out{Dest: &nullstr}),
sql.Named("nullbit", sql.Out{Dest: &nullbit}),
)
if err != nil {
t.Error(err)
}
if nullint.Valid {
t.Errorf("expected NULL, got %v", nullint)
}
if nullfloat.Valid {
t.Errorf("expected NULL, got %v", nullfloat)
}
if nullstr.Valid {
t.Errorf("expected NULL, got %v", nullstr)
}
if nullbit.Valid {
t.Errorf("expected NULL, got %v", nullbit)
}
})
// Not yet supported
//t.Run("test null values returned into pointers", func(t *testing.T) {
// var nullint *int64
// var nullfloat *float64
// var nullstr *string
// var nullbit *bool
// _, err = db.ExecContext(ctx, sqltextrun,
// sql.Named("nullint", sql.Out{Dest: &nullint}),
// sql.Named("nullfloat", sql.Out{Dest: &nullfloat}),
// sql.Named("nullstr", sql.Out{Dest: &nullstr}),
// sql.Named("nullbit", sql.Out{Dest: &nullbit}),
// )
// if err != nil {
// t.Error(err)
// }
// if nullint != nil {
// t.Errorf("expected NULL, got %v", nullint)
// }
// if nullfloat != nil {
// t.Errorf("expected NULL, got %v", nullfloat)
// }
// if nullstr != nil {
// t.Errorf("expected NULL, got %v", nullstr)
// }
// if nullbit != nil {
// t.Errorf("expected NULL, got %v", nullbit)
// }
//})
t.Run("test non null values into nullable", func(t *testing.T) {
nullint := sql.NullInt64{10, true}
nullfloat := sql.NullFloat64{1.5, true}
nullstr := sql.NullString{"hello", true}
nullbit := sql.NullBool{true, true}
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("nullint", sql.Out{Dest: &nullint}),
sql.Named("nullfloat", sql.Out{Dest: &nullfloat}),
sql.Named("nullstr", sql.Out{Dest: &nullstr}),
sql.Named("nullbit", sql.Out{Dest: &nullbit}),
)
if err != nil {
t.Error(err)
}
if !nullint.Valid {
t.Error("expected non null value, but got null")
}
if nullint.Int64 != 10 {
t.Errorf("expected 10, got %d", nullint.Int64)
}
if !nullfloat.Valid {
t.Error("expected non null value, but got null")
}
if nullfloat.Float64 != 1.5 {
t.Errorf("expected 1.5, got %v", nullfloat.Float64)
}
if !nullstr.Valid {
t.Error("expected non null value, but got null")
}
if nullstr.String != "hello" {
t.Errorf("expected hello, got %s", nullstr.String)
}
})
t.Run("test return into byte[]", func(t *testing.T) {
cstr := []byte{1, 2, 3}
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("varbin", sql.Out{Dest: &cstr}),
)
if err != nil {
t.Error(err)
}
expected := []byte{1, 2, 3}
if bytes.Compare(cstr, expected) != 0 {
t.Errorf("expected [1,2,3], got %v", cstr)
}
})
t.Run("test int into string", func(t *testing.T) {
var str string
_, err = db.ExecContext(ctx, sqltextrun,
sql.Named("bid", sql.Out{Dest: &str}),
)
if err != nil {
t.Error(err)
}
if str != "1" {
t.Errorf("expected '1', got %v", str)
}
})
t.Run("typeless null for output parameter should return error", func(t *testing.T) {
var val interface{}
_, actual := db.ExecContext(ctx, sqltextrun,
sql.Named("bid", sql.Out{Dest: &val}),
)
if actual == nil {
t.Error("Expected to fail but didn't")
}
pattern := ".*MSSQL does not allow NULL value without type for OUTPUT parameters.*"
match, err := regexp.MatchString(pattern, actual.Error())
if err != nil {
t.Error(err)
}
if !match {
t.Errorf("Error '%v', does not match pattern '%v'.", actual, pattern)
}
})
}
// TestOutputParamWithRows tests reading output parameter before and after
// retrieving rows from the result set of a stored procedure. SQL Server sends output
// parameters after all the rows are returned. Therefore, if the output parameter
// is read before all the rows are retrieved, the value will be incorrect.
//
// Issue https://github.com/denisenkom/go-mssqldb/issues/378
func TestOutputParamWithRows(t *testing.T) {
sqltextcreate := `
CREATE PROCEDURE spwithoutputandrows
@bitparam BIT OUTPUT
AS BEGIN
SET @bitparam = 1
SELECT 'Row 1'
END
`
sqltextdrop := `DROP PROCEDURE spwithoutputandrows;`
sqltextrun := `spwithoutputandrows`
checkConnStr(t)
SetLogger(testLogger{t})
db, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
t.Fatalf("failed to open driver sqlserver")
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db.ExecContext(ctx, sqltextdrop)
_, err = db.ExecContext(ctx, sqltextcreate)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdrop)
t.Run("Retrieve output after reading rows", func(t *testing.T) {
var bitout int64 = 5
rows, err := db.QueryContext(ctx, sqltextrun, sql.Named("bitparam", sql.Out{Dest: &bitout}))
if err != nil {
t.Error(err)
} else {
defer rows.Close()
var strrow string
for rows.Next() {
err = rows.Scan(&strrow)
}
if bitout != 1 {
t.Errorf("expected 1, got %d", bitout)
}
}
})
t.Run("Retrieve output before reading rows", func(t *testing.T) {
var bitout int64 = 5
rows, err := db.QueryContext(ctx, sqltextrun, sql.Named("bitparam", sql.Out{Dest: &bitout}))
if err != nil {
t.Error(err)
} else {
defer rows.Close()
if bitout != 5 {
t.Errorf("expected 5, got %d", bitout)
}
}
})
}
// TestTLSServerReadClose tests writing to an encrypted database connection.
// Currently the database server will close the connection while the server is
// reading the TDS packets and before any of the data has been parsed.
//
// When two queries are sent in reverse order, they PASS, but if we send only
// a single ping (SELECT 1;) first, then the long query the query fails.
//
// The long query text is never parsed. In fact, you can comment out, return
// early, or have malformed sql in the long query text. Just the length matters.
// The error happens when sending the TDS Batch packet to SQL Server the server
// closes the connection..
//
// It appears the driver sends valid TDS packets. In fact, if prefixed with 4
// "SELECT 1;" TDS Batch queries then the long query works, but if zero or one
// "SELECT 1;" TDS Batch queries are send prior the long query fails to send.
//
// Lastly, this only manafests itself with an encrypted connection. This has been
// observed with SQL Server Azure, SQL Server 13.0.1742 on Windows, and SQL Server
// 14.0.900.75 on Linux. It also fails when using the "dev.boringcrypto" (a C based
// TLS crypto). I haven't found any knobs on SQL Server to expose the error message
// nor have I found a good way to decrypt the TDS stream. KeyLogWriter in the TLS
// config may help with that, but wireshark wasn't decrypting TDS based TLS streams
// even when using that.
//
// Issue https://github.com/denisenkom/go-mssqldb/issues/166
func TestTLSServerReadClose(t *testing.T) {
query := `
with
config_cte (config) as (
select *
from ( values
('_partition:{\"Fill\":{\"PatternType\":\"solid\",\"FgColor\":\"99ff99\"}}')
, ('_separation:{\"Fill\":{\"PatternType\":\"solid\",\"FgColor\":\"99ffff\"}}')
, ('Monthly Earnings:\$#,##0.00 ;(\$#,##0.00)')
, ('Weekly Earnings:\$#,##0.00 ;(\$#,##0.00)')
, ('Total Earnings:\$#,##0.00 ;(\$#,##0.00)')
, ('Average Earnings:\$#,##0.00 ;(\$#,##0.00)')
, ('Last Month Earning:#,##0.00 ;(#,##0.00)')
, ('Award:\$#,##0.00 ;(\$#,##0.00)')
, ('Amount:\$#,##0.00 ;(\$#,##0.00)')
, ('Grand Total:\$#,##0.00 ;(\$#,##0.00)')
, ('Total:\$#,##0.00 ;(\$#,##0.00)')
, ('Price Each:\$#,##0.00 ;(\$#,##0.00)')
, ('Hyperwallet:\$#,##0.00 ;(\$#,##0.00)')
, ('Credit/Debit:\$#,##0.00 ;(\$#,##0.00)')
, ('Earning:#,##0.00 ;(#,##0.00)')
, ('Change Earning:#,##0.00 ;(#,##0.00)')
, ('CheckAmount:#,##0.00 ;(#,##0.00)')
, ('Residual:#,##0.00 ;(#,##0.00)')
, ('Prev Residual:#,##0.00 ;(#,##0.00)')
, ('Team Bonuses:#,##0.00 ;(#,##0.00)')
, ('Change:#,##0.00 ;(#,##0.00)')
, ('Shipping Total:#,##0.00 ;(#,##0.00)')
, ('SubTotal:\$#,##0.00 ;(\$#,##0.00)')
, ('Total Diff:#,##0.00 ;(#,##0.00)')
, ('SubTotal Diff:#,##0.00 ;(#,##0.00)')
, ('Return Total:#,##0.00 ;(#,##0.00)')
, ('Return SubTotal:#,##0.00 ;(#,##0.00)')
, ('Return Total Diff:#,##0.00 ;(#,##0.00)')
, ('Return SubTotal Diff:#,##0.00 ;(#,##0.00)')
, ('Cancel Total:#,##0.00 ;(#,##0.00)')
, ('Cancel SubTotal:#,##0.00 ;(#,##0.00)')
, ('Cancel Total Diff:#,##0.00 ;(#,##0.00)')
, ('Cancel SubTotal Diff:#,##0.00 ;(#,##0.00)')
, ('Replacement Total:#,##0.00 ;(#,##0.00)')
, ('Replacement SubTotal:#,##0.00 ;(#,##0.00)')
, ('Replacement Total Diff:#,##0.00 ;(#,##0.00)')
, ('Replacement SubTotal Diff:#,##0.00 ;(#,##0.00)')
, ('Jan Residual:#,##0.00 ;(#,##0.00)')
, ('Jan Bonus:#,##0.00 ;(#,##0.00)')
, ('Jan Total:#,##0.00 ;(#,##0.00)')
, ('January Residual:#,##0.00 ;(#,##0.00)')
, ('Feb Residual:#,##0.00 ;(#,##0.00)')
, ('Feb Bonus:#,##0.00 ;(#,##0.00)')
, ('Feb Total:#,##0.00 ;(#,##0.00)')
, ('February Residual:#,##0.00 ;(#,##0.00)')
, ('Mar Residual:#,##0.00 ;(#,##0.00)')
, ('Mar Bonus:#,##0.00 ;(#,##0.00)')
, ('Mar Total:#,##0.00 ;(#,##0.00)')
, ('March Residual:#,##0.00 ;(#,##0.00)')
, ('Apr Residual:#,##0.00 ;(#,##0.00)')
, ('Apr Bonus:#,##0.00 ;(#,##0.00)')
, ('Apr Total:#,##0.00 ;(#,##0.00)')
, ('April Residual:#,##0.00 ;(#,##0.00)')
, ('May Residual:#,##0.00 ;(#,##0.00)')
, ('May Bonus:#,##0.00 ;(#,##0.00)')
, ('May Total:#,##0.00 ;(#,##0.00)')
, ('Jun Residual:#,##0.00 ;(#,##0.00)')
, ('Jun Bonus:#,##0.00 ;(#,##0.00)')
, ('Jun Total:#,##0.00 ;(#,##0.00)')
, ('June Residual:#,##0.00 ;(#,##0.00)')
, ('Jul Residual:#,##0.00 ;(#,##0.00)')
, ('Jul Bonus:#,##0.00 ;(#,##0.00)')
, ('Jul Total:#,##0.00 ;(#,##0.00)')
, ('July Residual:#,##0.00 ;(#,##0.00)')
, ('Aug Residual:#,##0.00 ;(#,##0.00)')
, ('Aug Bonus:#,##0.00 ;(#,##0.00)')
, ('Aug Total:#,##0.00 ;(#,##0.00)')
, ('August Residual:#,##0.00 ;(#,##0.00)')
, ('Sep Residual:#,##0.00 ;(#,##0.00)')
, ('Sep Bonus:#,##0.00 ;(#,##0.00)')
, ('Sep Total:#,##0.00 ;(#,##0.00)')
, ('September Residual:#,##0.00 ;(#,##0.00)')
, ('Oct Residual:#,##0.00 ;(#,##0.00)')
, ('Oct Bonus:#,##0.00 ;(#,##0.00)')
, ('Oct Total:#,##0.00 ;(#,##0.00)')
, ('October Residual:#,##0.00 ;(#,##0.00)')
, ('Nov Residual:#,##0.00 ;(#,##0.00)')
, ('Nov Bonus:#,##0.00 ;(#,##0.00)')
, ('Nov Total:#,##0.00 ;(#,##0.00)')
, ('November Residual:#,##0.00 ;(#,##0.00)')
, ('Dec Residual:#,##0.00 ;(#,##0.00)')
, ('Dec Bonus:#,##0.00 ;(#,##0.00)')
, ('Dec Total:#,##0.00 ;(#,##0.00)')
, ('December Residual:#,##0.00 ;(#,##0.00)')
, ('January Bonus:#,##0.00 ;(#,##0.00)')
, ('February Bonus:#,##0.00 ;(#,##0.00)')
, ('March Bonus:#,##0.00 ;(#,##0.00)')
, ('April Bonus:#,##0.00 ;(#,##0.00)')
, ('May Bonus:#,##0.00 ;(#,##0.00)')
, ('June Bonus:#,##0.00 ;(#,##0.00)')
, ('July Bonus:#,##0.00 ;(#,##0.00)')
, ('August Bonus:#,##0.00 ;(#,##0.00)')
, ('September Bonus:#,##0.00 ;(#,##0.00)')
, ('October Bonus:#,##0.00 ;(#,##0.00)')
, ('November Bonus:#,##0.00 ;(#,##0.00)')
, ('December Bonus:#,##0.00 ;(#,##0.00)')
, ('January Adj:#,##0.00 ;(#,##0.00)')
, ('February Adj:#,##0.00 ;(#,##0.00)')
, ('March Adj:#,##0.00 ;(#,##0.00)')
, ('April Adj:#,##0.00 ;(#,##0.00)')
, ('May Adj:#,##0.00 ;(#,##0.00)')
, ('June Adj:#,##0.00 ;(#,##0.00)')
, ('July Adj:#,##0.00 ;(#,##0.00)')
, ('August Adj:#,##0.00 ;(#,##0.00)')
, ('September Adj:#,##0.00 ;(#,##0.00)')
, ('October Adj:#,##0.00 ;(#,##0.00)')
, ('November Adj:#,##0.00 ;(#,##0.00)')
, ('December Adj:#,##0.00 ;(#,##0.00)')
, ('2016- 2015 YTD Dif:#,##0.00 ;(#,##0.00)')
, ('2017- 2016 YTD Dif:#,##0.00 ;(#,##0.00)')
, ('2018- 2017 YTD Dif:#,##0.00 ;(#,##0.00)')
, ('Dec to Jan Dif Residual:#,##0.00 ;(#,##0.00)')
, ('Jan to Feb Dif Residual:#,##0.00 ;(#,##0.00)')
, ('Feb to Mar Dif Residual:#,##0.00 ;(#,##0.00)')
, ('Mar to Apr Dif Residual:#,##0.00 ;(#,##0.00)')
, ('Apr to May Dif Residual:#,##0.00 ;(#,##0.00)')
, ('May to Jun Dif Residual:#,##0.00 ;(#,##0.00)')
, ('Jun to Jul Dif Residual:#,##0.00 ;(#,##0.00)')
, ('Jul to Aug Dif Residual:#,##0.00 ;(#,##0.00)')
, ('Aug to Sep Dif Residual:#,##0.00 ;(#,##0.00)')
, ('Sep to Oct Dif Residual:#,##0.00 ;(#,##0.00)')
, ('Oct to Nov Dif Residual:#,##0.00 ;(#,##0.00)')
, ('Nov to Dec Dif Residual:#,##0.00 ;(#,##0.00)')
, ('Dec to Jan Dif Bonus:#,##0.00 ;(#,##0.00)')
, ('Jan to Feb Dif Bonus:#,##0.00 ;(#,##0.00)')
, ('Feb to Mar Dif Bonus:#,##0.00 ;(#,##0.00)')
, ('Mar to Apr Dif Bonus:#,##0.00 ;(#,##0.00)')
, ('Apr to May Dif Bonus:#,##0.00 ;(#,##0.00)')
, ('May to Jun Dif Bonus:#,##0.00 ;(#,##0.00)')
, ('Jun to Jul Dif Bonus:#,##0.00 ;(#,##0.00)')
, ('Jul to Aug Dif Bonus:#,##0.00 ;(#,##0.00)')
, ('Aug to Sep Dif Bonus:#,##0.00 ;(#,##0.00)')
, ('Sep to Oct Dif Bonus:#,##0.00 ;(#,##0.00)')
, ('Oct to Nov Dif Bonus:#,##0.00 ;(#,##0.00)')
, ('Nov to Dec Dif Bonus:#,##0.00 ;(#,##0.00)')
, ('Dec to Jan Dif Total:#,##0.00 ;(#,##0.00)')
, ('Jan to Feb Dif Total:#,##0.00 ;(#,##0.00)')
, ('Feb to Mar Dif Total:#,##0.00 ;(#,##0.00)')
, ('Mar to Apr Dif Total:#,##0.00 ;(#,##0.00)')
, ('Apr to May Dif Total:#,##0.00 ;(#,##0.00)')
, ('May to Jun Dif Total:#,##0.00 ;(#,##0.00)')
, ('Jun to Jul Dif Total:#,##0.00 ;(#,##0.00)')
, ('Jul to Aug Dif Total:#,##0.00 ;(#,##0.00)')
, ('Aug to Sep Dif Total:#,##0.00 ;(#,##0.00)')
, ('Sep to Oct Dif Total:#,##0.00 ;(#,##0.00)')
, ('Oct to Nov Dif Total:#,##0.00 ;(#,##0.00)')
, ('Nov to Dec Dif Total:#,##0.00 ;(#,##0.00)')
, ('Jan Refund Cnt:#,##0 ;(#,##0)')
, ('Feb Refund Cnt:#,##0 ;(#,##0)')
, ('Mar Refund Cnt:#,##0 ;(#,##0)')
, ('Apr Refund Cnt:#,##0 ;(#,##0)')
, ('May Refund Cnt:#,##0 ;(#,##0)')
, ('Jun Refund Cnt:#,##0 ;(#,##0)')
, ('Jul Refund Cnt:#,##0 ;(#,##0)')
, ('Aug Refund Cnt:#,##0 ;(#,##0)')
, ('Sep Refund Cnt:#,##0 ;(#,##0)')
, ('Oct Refund Cnt:#,##0 ;(#,##0)')
, ('Nov Refund Cnt:#,##0 ;(#,##0)')
, ('Dec Refund Cnt:#,##0 ;(#,##0)')
, ('Jan Purchase Cnt:#,##0 ;(#,##0)')
, ('Feb Purchase Cnt:#,##0 ;(#,##0)')
, ('Mar Purchase Cnt:#,##0 ;(#,##0)')
, ('Apr Purchase Cnt:#,##0 ;(#,##0)')
, ('May Purchase Cnt:#,##0 ;(#,##0)')
, ('Jun Purchase Cnt:#,##0 ;(#,##0)')
, ('Jul Purchase Cnt:#,##0 ;(#,##0)')
, ('Aug Purchase Cnt:#,##0 ;(#,##0)')
, ('Sep Purchase Cnt:#,##0 ;(#,##0)')
, ('Oct Purchase Cnt:#,##0 ;(#,##0)')
, ('Nov Purchase Cnt:#,##0 ;(#,##0)')
, ('Dec Purchase Cnt:#,##0 ;(#,##0)')
, ('Jan Refund Amt:#,##0.00 ;(#,##0.00)')
, ('Feb Refund Amt:#,##0.00 ;(#,##0.00)')
, ('Mar Refund Amt:#,##0.00 ;(#,##0.00)')
, ('Apr Refund Amt:#,##0.00 ;(#,##0.00)')
, ('May Refund Amt:#,##0.00 ;(#,##0.00)')
, ('Jun Refund Amt:#,##0.00 ;(#,##0.00)')
, ('Jul Refund Amt:#,##0.00 ;(#,##0.00)')
, ('Aug Refund Amt:#,##0.00 ;(#,##0.00)')
, ('Sep Refund Amt:#,##0.00 ;(#,##0.00)')
, ('Oct Refund Amt:#,##0.00 ;(#,##0.00)')
, ('Nov Refund Amt:#,##0.00 ;(#,##0.00)')
, ('Dec Refund Amt:#,##0.00 ;(#,##0.00)')
, ('Jan Purchase Amt:#,##0.00 ;(#,##0.00)')
, ('Feb Purchase Amt:#,##0.00 ;(#,##0.00)')
, ('Mar Purchase Amt:#,##0.00 ;(#,##0.00)')
, ('Apr Purchase Amt:#,##0.00 ;(#,##0.00)')
, ('May Purchase Amt:#,##0.00 ;(#,##0.00)')
, ('Jun Purchase Amt:#,##0.00 ;(#,##0.00)')
, ('Jul Purchase Amt:#,##0.00 ;(#,##0.00)')
, ('Aug Purchase Amt:#,##0.00 ;(#,##0.00)')
, ('Sep Purchase Amt:#,##0.00 ;(#,##0.00)')
, ('Oct Purchase Amt:#,##0.00 ;(#,##0.00)')
, ('Nov Purchase Amt:#,##0.00 ;(#,##0.00)')
, ('Dec Purchase Amt:#,##0.00 ;(#,##0.00)')
) X(a))
select * from config_cte
`
t.Logf("query len (utf16 bytes)=%d, len/4096=%f\n", len(query)*2, float64(len(query)*2)/4096)
db := open(t)
defer db.Close()
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
type run struct {
name string
pings []int
pass bool
conn *sql.Conn
}
// Use separate Conns from the connection pool to ensure separation.
runs := []*run{
{name: "rev", pings: []int{4, 1}, pass: true},
{name: "forward", pings: []int{1}, pass: true},
}
for _, r := range runs {
var err error
r.conn, err = db.Conn(ctx)
if err != nil {
t.Fatal(err)
}
defer r.conn.Close()
}
for _, r := range runs {
for _, ping := range r.pings {
t.Run(fmt.Sprintf("%s-ping-%d", r.name, ping), func(t *testing.T) {
for i := 0; i < ping; i++ {
if err := r.conn.PingContext(ctx); err != nil {
if r.pass {
t.Error("failed to ping server", err)
} else {
t.Log("failed to ping server", err)
}
return
}
}
rows, err := r.conn.QueryContext(ctx, query)
if err != nil {
if r.pass {
t.Errorf("QueryContext: %+v", err)
} else {
t.Logf("QueryContext: %+v", err)
}
return
}
for rows.Next() {
// Nothing.
}
rows.Close()
})
}
}
}
func TestDateTimeParam19(t *testing.T) {
conn := open(t)
defer conn.Close()
// testing DateTime1, only supported on go 1.9
var emptydate time.Time
mindate1 := time.Date(1753, 1, 1, 0, 0, 0, 0, time.UTC)
maxdate1 := time.Date(9999, 12, 31, 23, 59, 59, 997000000, time.UTC)
testdates1 := []DateTime1{
DateTime1(mindate1),
DateTime1(maxdate1),
DateTime1(time.Date(1752, 12, 31, 23, 59, 59, 997000000, time.UTC)), // just a little below minimum date
DateTime1(time.Date(10000, 1, 1, 0, 0, 0, 0, time.UTC)), // just a little over maximum date
DateTime1(emptydate),
}
for _, test := range testdates1 {
t.Run(fmt.Sprintf("Test datetime for %v", test), func(t *testing.T) {
var res time.Time
expected := time.Time(test)
queryParamRoundTrip(conn, test, &res)
// clip value
if expected.Before(mindate1) {
expected = mindate1
}
if expected.After(maxdate1) {
expected = maxdate1
}
if expected.Sub(res) != 0 {
t.Errorf("expected: '%s', got: '%s' delta: %d", expected, res, expected.Sub(res))
}
})
}
}
func TestReturnStatus(t *testing.T) {
conn := open(t)
defer conn.Close()
_, err := conn.Exec("if object_id('retstatus') is not null drop proc retstatus;")
if err != nil {
t.Fatal(err)
}
_, err = conn.Exec("create proc retstatus as return 2;")
if err != nil {
t.Fatal(err)
}
var rs ReturnStatus
_, err = conn.Exec("retstatus", &rs)
conn.Exec("drop proc retstatus;")
if err != nil {
t.Fatal(err)
}
if rs != 2 {
t.Errorf("expected status=2, got %d", rs)
}
}

File diff suppressed because it is too large Load diff

View file

@ -1,522 +0,0 @@
package mssql
import (
"bytes"
"context"
"database/sql"
"encoding/hex"
"fmt"
"net/url"
"os"
"testing"
"time"
)
type MockTransport struct {
bytes.Buffer
}
func (t *MockTransport) Close() error {
return nil
}
func TestSendLogin(t *testing.T) {
memBuf := new(MockTransport)
buf := newTdsBuffer(1024, memBuf)
login := login{
TDSVersion: verTDS73,
PacketSize: 0x1000,
ClientProgVer: 0x01060100,
ClientPID: 100,
ClientTimeZone: -4 * 60,
ClientID: [6]byte{0x12, 0x34, 0x56, 0x78, 0x90, 0xab},
OptionFlags1: 0xe0,
OptionFlags3: 8,
HostName: "subdev1",
UserName: "test",
Password: "testpwd",
AppName: "appname",
ServerName: "servername",
CtlIntName: "library",
Language: "en",
Database: "database",
ClientLCID: 0x204,
AtchDBFile: "filepath",
}
err := sendLogin(buf, login)
if err != nil {
t.Error("sendLogin should succeed")
}
ref := []byte{
16, 1, 0, 222, 0, 0, 1, 0, 198 + 16, 0, 0, 0, 3, 0, 10, 115, 0, 16, 0, 0, 0, 1,
6, 1, 100, 0, 0, 0, 0, 0, 0, 0, 224, 0, 0, 8, 16, 255, 255, 255, 4, 2, 0,
0, 94, 0, 7, 0, 108, 0, 4, 0, 116, 0, 7, 0, 130, 0, 7, 0, 144, 0, 10, 0, 0,
0, 0, 0, 164, 0, 7, 0, 178, 0, 2, 0, 182, 0, 8, 0, 18, 52, 86, 120, 144, 171,
198, 0, 0, 0, 198, 0, 8, 0, 214, 0, 0, 0, 0, 0, 0, 0, 115, 0, 117, 0, 98,
0, 100, 0, 101, 0, 118, 0, 49, 0, 116, 0, 101, 0, 115, 0, 116, 0, 226, 165,
243, 165, 146, 165, 226, 165, 162, 165, 210, 165, 227, 165, 97, 0, 112,
0, 112, 0, 110, 0, 97, 0, 109, 0, 101, 0, 115, 0, 101, 0, 114, 0, 118, 0,
101, 0, 114, 0, 110, 0, 97, 0, 109, 0, 101, 0, 108, 0, 105, 0, 98, 0, 114,
0, 97, 0, 114, 0, 121, 0, 101, 0, 110, 0, 100, 0, 97, 0, 116, 0, 97, 0, 98,
0, 97, 0, 115, 0, 101, 0, 102, 0, 105, 0, 108, 0, 101, 0, 112, 0, 97, 0,
116, 0, 104, 0}
out := memBuf.Bytes()
if !bytes.Equal(ref, out) {
fmt.Println("Expected:")
fmt.Print(hex.Dump(ref))
fmt.Println("Returned:")
fmt.Print(hex.Dump(out))
t.Error("input output don't match")
}
}
func TestSendSqlBatch(t *testing.T) {
checkConnStr(t)
p, err := parseConnectParams(makeConnStr(t).String())
if err != nil {
t.Error("parseConnectParams failed:", err.Error())
return
}
conn, err := connect(context.Background(), nil, optionalLogger{testLogger{t}}, p)
if err != nil {
t.Error("Open connection failed:", err.Error())
return
}
defer conn.buf.transport.Close()
headers := []headerStruct{
{hdrtype: dataStmHdrTransDescr,
data: transDescrHdr{0, 1}.pack()},
}
err = sendSqlBatch72(conn.buf, "select 1", headers, true)
if err != nil {
t.Error("Sending sql batch failed", err.Error())
return
}
ch := make(chan tokenStruct, 5)
go processResponse(context.Background(), conn, ch, nil)
var lastRow []interface{}
loop:
for tok := range ch {
switch token := tok.(type) {
case doneStruct:
break loop
case []columnStruct:
conn.columns = token
case []interface{}:
lastRow = token
default:
fmt.Println("unknown token", tok)
}
}
if len(lastRow) == 0 {
t.Fatal("expected row but no row set")
}
switch value := lastRow[0].(type) {
case int32:
if value != 1 {
t.Error("Invalid value returned, should be 1", value)
return
}
}
}
func checkConnStr(t *testing.T) {
if len(os.Getenv("SQLSERVER_DSN")) > 0 {
return
}
if len(os.Getenv("HOST")) > 0 && len(os.Getenv("DATABASE")) > 0 {
return
}
t.Skip("no database connection string")
}
// makeConnStr returns a URL struct so it may be modified by various
// tests before used as a DSN.
func makeConnStr(t *testing.T) *url.URL {
dsn := os.Getenv("SQLSERVER_DSN")
if len(dsn) > 0 {
parsed, err := url.Parse(dsn)
if err != nil {
t.Fatal("unable to parse SQLSERVER_DSN as URL", err)
}
values := parsed.Query()
if values.Get("log") == "" {
values.Set("log", "127")
}
parsed.RawQuery = values.Encode()
return parsed
}
values := url.Values{}
values.Set("log", "127")
values.Set("database", os.Getenv("DATABASE"))
return &url.URL{
Scheme: "sqlserver",
Host: os.Getenv("HOST"),
Path: os.Getenv("INSTANCE"),
User: url.UserPassword(os.Getenv("SQLUSER"), os.Getenv("SQLPASSWORD")),
RawQuery: values.Encode(),
}
}
type testLogger struct {
t *testing.T
}
func (l testLogger) Printf(format string, v ...interface{}) {
l.t.Logf(format, v...)
}
func (l testLogger) Println(v ...interface{}) {
l.t.Log(v...)
}
func open(t *testing.T) *sql.DB {
checkConnStr(t)
SetLogger(testLogger{t})
conn, err := sql.Open("mssql", makeConnStr(t).String())
if err != nil {
t.Error("Open connection failed:", err.Error())
return nil
}
return conn
}
func TestConnect(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})
conn, err := sql.Open("mssql", makeConnStr(t).String())
if err != nil {
t.Error("Open connection failed:", err.Error())
return
}
defer conn.Close()
}
func simpleQuery(conn *sql.DB, t *testing.T) (stmt *sql.Stmt) {
stmt, err := conn.Prepare("select 1 as a")
if err != nil {
t.Error("Prepare failed:", err.Error())
return nil
}
return stmt
}
func checkSimpleQuery(rows *sql.Rows, t *testing.T) {
numrows := 0
for rows.Next() {
var val int
err := rows.Scan(&val)
if err != nil {
t.Error("Scan failed:", err.Error())
}
if val != 1 {
t.Error("query should return 1")
}
numrows++
}
if numrows != 1 {
t.Error("query should return 1 row, returned", numrows)
}
}
func TestQuery(t *testing.T) {
conn := open(t)
if conn == nil {
return
}
defer conn.Close()
stmt := simpleQuery(conn, t)
if stmt == nil {
return
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
t.Error("Query failed:", err.Error())
}
defer rows.Close()
columns, err := rows.Columns()
if err != nil {
t.Error("getting columns failed", err.Error())
}
if len(columns) != 1 && columns[0] != "a" {
t.Error("returned incorrect columns (expected ['a']):", columns)
}
checkSimpleQuery(rows, t)
}
func TestMultipleQueriesSequentialy(t *testing.T) {
conn := open(t)
defer conn.Close()
stmt, err := conn.Prepare("select 1 as a")
if err != nil {
t.Error("Prepare failed:", err.Error())
return
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
t.Error("Query failed:", err.Error())
return
}
defer rows.Close()
checkSimpleQuery(rows, t)
rows, err = stmt.Query()
if err != nil {
t.Error("Query failed:", err.Error())
return
}
defer rows.Close()
checkSimpleQuery(rows, t)
}
func TestMultipleQueryClose(t *testing.T) {
conn := open(t)
defer conn.Close()
stmt, err := conn.Prepare("select 1 as a")
if err != nil {
t.Error("Prepare failed:", err.Error())
return
}
defer stmt.Close()
rows, err := stmt.Query()
if err != nil {
t.Error("Query failed:", err.Error())
return
}
rows.Close()
rows, err = stmt.Query()
if err != nil {
t.Error("Query failed:", err.Error())
return
}
defer rows.Close()
checkSimpleQuery(rows, t)
}
func TestPing(t *testing.T) {
conn := open(t)
defer conn.Close()
conn.Ping()
}
func TestSecureWithInvalidHostName(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})
dsn := makeConnStr(t)
dsnParams := dsn.Query()
dsnParams.Set("encrypt", "true")
dsnParams.Set("TrustServerCertificate", "false")
dsnParams.Set("hostNameInCertificate", "foo.bar")
dsn.RawQuery = dsnParams.Encode()
conn, err := sql.Open("mssql", dsn.String())
if err != nil {
t.Fatal("Open connection failed:", err.Error())
}
defer conn.Close()
err = conn.Ping()
if err == nil {
t.Fatal("Connected to fake foo.bar server")
}
}
func TestSecureConnection(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})
dsn := makeConnStr(t)
dsnParams := dsn.Query()
dsnParams.Set("encrypt", "true")
dsnParams.Set("TrustServerCertificate", "true")
dsn.RawQuery = dsnParams.Encode()
conn, err := sql.Open("mssql", dsn.String())
if err != nil {
t.Fatal("Open connection failed:", err.Error())
}
defer conn.Close()
var msg string
err = conn.QueryRow("select 'secret'").Scan(&msg)
if err != nil {
t.Fatal("cannot scan value", err)
}
if msg != "secret" {
t.Fatal("expected secret, got: ", msg)
}
var secure bool
err = conn.QueryRow("select encrypt_option from sys.dm_exec_connections where session_id=@@SPID").Scan(&secure)
if err != nil {
t.Fatal("cannot scan value", err)
}
if !secure {
t.Fatal("connection is not encrypted")
}
}
func TestInvalidConnectionString(t *testing.T) {
connStrings := []string{
"log=invalid",
"port=invalid",
"packet size=invalid",
"connection timeout=invalid",
"dial timeout=invalid",
"keepalive=invalid",
"encrypt=invalid",
"trustservercertificate=invalid",
"failoverport=invalid",
// ODBC mode
"odbc:password={",
"odbc:password={somepass",
"odbc:password={somepass}}",
"odbc:password={some}pass",
}
for _, connStr := range connStrings {
_, err := parseConnectParams(connStr)
if err == nil {
t.Errorf("Connection expected to fail for connection string %s but it didn't", connStr)
continue
} else {
t.Logf("Connection failed for %s as expected with error %v", connStr, err)
}
}
}
func TestValidConnectionString(t *testing.T) {
type testStruct struct {
connStr string
check func(connectParams) bool
}
connStrings := []testStruct{
{"server=server\\instance;database=testdb;user id=tester;password=pwd", func(p connectParams) bool {
return p.host == "server" && p.instance == "instance" && p.user == "tester" && p.password == "pwd"
}},
{"server=.", func(p connectParams) bool { return p.host == "localhost" }},
{"server=(local)", func(p connectParams) bool { return p.host == "localhost" }},
{"ServerSPN=serverspn;Workstation ID=workstid", func(p connectParams) bool { return p.serverSPN == "serverspn" && p.workstation == "workstid" }},
{"failoverpartner=fopartner;failoverport=2000", func(p connectParams) bool { return p.failOverPartner == "fopartner" && p.failOverPort == 2000 }},
{"app name=appname;applicationintent=ReadOnly", func(p connectParams) bool { return p.appname == "appname" && (p.typeFlags&fReadOnlyIntent != 0) }},
{"encrypt=disable", func(p connectParams) bool { return p.disableEncryption }},
{"encrypt=true", func(p connectParams) bool { return p.encrypt && !p.disableEncryption }},
{"encrypt=false", func(p connectParams) bool { return !p.encrypt && !p.disableEncryption }},
{"trustservercertificate=true", func(p connectParams) bool { return p.trustServerCertificate }},
{"trustservercertificate=false", func(p connectParams) bool { return !p.trustServerCertificate }},
{"certificate=abc", func(p connectParams) bool { return p.certificate == "abc" }},
{"hostnameincertificate=abc", func(p connectParams) bool { return p.hostInCertificate == "abc" }},
{"connection timeout=3;dial timeout=4;keepalive=5", func(p connectParams) bool {
return p.conn_timeout == 3*time.Second && p.dial_timeout == 4*time.Second && p.keepAlive == 5*time.Second
}},
{"log=63", func(p connectParams) bool { return p.logFlags == 63 && p.port == 1433 }},
{"log=63;port=1000", func(p connectParams) bool { return p.logFlags == 63 && p.port == 1000 }},
{"log=64", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 4096 }},
{"log=64;packet size=0", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 512 }},
{"log=64;packet size=300", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 512 }},
{"log=64;packet size=8192", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 8192 }},
{"log=64;packet size=48000", func(p connectParams) bool { return p.logFlags == 64 && p.packetSize == 32767 }},
// those are supported currently, but maybe should not be
{"someparam", func(p connectParams) bool { return true }},
{";;=;", func(p connectParams) bool { return true }},
// ODBC mode
{"odbc:server=somehost;user id=someuser;password=somepass", func(p connectParams) bool {
return p.host == "somehost" && p.user == "someuser" && p.password == "somepass"
}},
{"odbc:server=somehost;user id=someuser;password=some{pass", func(p connectParams) bool {
return p.host == "somehost" && p.user == "someuser" && p.password == "some{pass"
}},
{"odbc:server={somehost};user id={someuser};password={somepass}", func(p connectParams) bool {
return p.host == "somehost" && p.user == "someuser" && p.password == "somepass"
}},
{"odbc:server={somehost};user id={someuser};password={some=pass}", func(p connectParams) bool {
return p.host == "somehost" && p.user == "someuser" && p.password == "some=pass"
}},
{"odbc:server={somehost};user id={someuser};password={some;pass}", func(p connectParams) bool {
return p.host == "somehost" && p.user == "someuser" && p.password == "some;pass"
}},
{"odbc:server={somehost};user id={someuser};password={some{pass}", func(p connectParams) bool {
return p.host == "somehost" && p.user == "someuser" && p.password == "some{pass"
}},
{"odbc:server={somehost};user id={someuser};password={some}}pass}", func(p connectParams) bool {
return p.host == "somehost" && p.user == "someuser" && p.password == "some}pass"
}},
{"odbc:server={somehost};user id={someuser};password={some{}}p=a;ss}", func(p connectParams) bool {
return p.host == "somehost" && p.user == "someuser" && p.password == "some{}p=a;ss"
}},
{"odbc: server = somehost; user id = someuser ; password = {some pass } ", func(p connectParams) bool {
return p.host == "somehost" && p.user == "someuser" && p.password == "some pass "
}},
// URL mode
{"sqlserver://somehost?connection+timeout=30", func(p connectParams) bool {
return p.host == "somehost" && p.port == 1433 && p.instance == "" && p.conn_timeout == 30*time.Second
}},
{"sqlserver://someuser@somehost?connection+timeout=30", func(p connectParams) bool {
return p.host == "somehost" && p.port == 1433 && p.instance == "" && p.user == "someuser" && p.password == "" && p.conn_timeout == 30*time.Second
}},
{"sqlserver://someuser:@somehost?connection+timeout=30", func(p connectParams) bool {
return p.host == "somehost" && p.port == 1433 && p.instance == "" && p.user == "someuser" && p.password == "" && p.conn_timeout == 30*time.Second
}},
{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost?connection+timeout=30", func(p connectParams) bool {
return p.host == "somehost" && p.port == 1433 && p.instance == "" && p.user == "someuser" && p.password == "foo:/\\!~@;bar" && p.conn_timeout == 30*time.Second
}},
{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434?connection+timeout=30", func(p connectParams) bool {
return p.host == "somehost" && p.port == 1434 && p.instance == "" && p.user == "someuser" && p.password == "foo:/\\!~@;bar" && p.conn_timeout == 30*time.Second
}},
{"sqlserver://someuser:foo%3A%2F%5C%21~%40;bar@somehost:1434/someinstance?connection+timeout=30", func(p connectParams) bool {
return p.host == "somehost" && p.port == 1434 && p.instance == "someinstance" && p.user == "someuser" && p.password == "foo:/\\!~@;bar" && p.conn_timeout == 30*time.Second
}},
}
for _, ts := range connStrings {
p, err := parseConnectParams(ts.connStr)
if err == nil {
t.Logf("Connection string was parsed successfully %s", ts.connStr)
} else {
t.Errorf("Connection string %s failed to parse with error %s", ts.connStr, err)
continue
}
if !ts.check(p) {
t.Errorf("Check failed on conn str %s", ts.connStr)
}
}
}
func TestBadConnect(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})
connURL := makeConnStr(t)
connURL.User = url.UserPassword("baduser", "badpwd")
badDSN := connURL.String()
conn, err := sql.Open("mssql", badDSN)
if err != nil {
t.Error("Open connection failed:", err.Error())
}
defer conn.Close()
err = conn.Ping()
if err == nil {
t.Error("Ping should fail for connection: ", badDSN)
}
}

View file

@ -1,114 +0,0 @@
// +build go1.10
package mssql_test
import (
"database/sql"
"flag"
"fmt"
"log"
mssql "github.com/denisenkom/go-mssqldb"
)
// This example shows how to use tvp type
func ExampleTVP() {
const (
createTable = "CREATE TABLE Location (Name VARCHAR(50), CostRate INT, Availability BIT, ModifiedDate DATETIME2)"
dropTable = "IF OBJECT_ID('Location', 'U') IS NOT NULL DROP TABLE Location"
createTVP = `CREATE TYPE LocationTableType AS TABLE
(LocationName VARCHAR(50),
CostRate INT)`
dropTVP = "IF type_id('LocationTableType') IS NOT NULL DROP TYPE LocationTableType"
createProc = `CREATE PROCEDURE dbo.usp_InsertProductionLocation
@TVP LocationTableType READONLY
AS
SET NOCOUNT ON
INSERT INTO Location
(
Name,
CostRate,
Availability,
ModifiedDate)
SELECT *, 0,GETDATE()
FROM @TVP`
dropProc = "IF OBJECT_ID('dbo.usp_InsertProductionLocation', 'P') IS NOT NULL DROP PROCEDURE dbo.usp_InsertProductionLocation"
execTvp = "exec dbo.usp_InsertProductionLocation @TVP;"
)
type LocationTableTvp struct {
LocationName string
LocationCountry string `tvp:"-"`
CostRate int64
Currency string `json:"-"`
}
flag.Parse()
if *debug {
fmt.Printf(" password:%s\n", *password)
fmt.Printf(" port:%d\n", *port)
fmt.Printf(" server:%s\n", *server)
fmt.Printf(" user:%s\n", *user)
}
connString := makeConnURL().String()
if *debug {
fmt.Printf(" connString:%s\n", connString)
}
db, err := sql.Open("sqlserver", connString)
if err != nil {
log.Fatal("Open connection failed:", err.Error())
}
defer db.Close()
_, err = db.Exec(createTable)
if err != nil {
log.Fatal(err)
}
_, err = db.Exec(createTVP)
if err != nil {
log.Fatal(err)
}
defer db.Exec(dropTVP)
_, err = db.Exec(createProc)
if err != nil {
log.Fatal(err)
}
defer db.Exec(dropProc)
locationTableTypeData := []LocationTableTvp{
{
LocationName: "Alberta",
LocationCountry: "Canada",
CostRate: 0,
Currency: "CAD",
},
{
LocationName: "British Columbia",
LocationCountry: "Canada",
CostRate: 1,
Currency: "CAD",
},
}
tvpType := mssql.TVP{
TypeName: "LocationTableType",
Value: locationTableTypeData,
}
_, err = db.Exec(execTvp, sql.Named("TVP", tvpType))
if err != nil {
log.Fatal(err)
} else {
for _, locationData := range locationTableTypeData {
fmt.Printf("Data for location %s, %s has been inserted.\n", locationData.LocationName, locationData.LocationCountry)
}
}
}

View file

@ -8,70 +8,56 @@ import (
"errors"
"fmt"
"reflect"
"strings"
"time"
)
const (
jsonTag = "json"
tvpTag = "tvp"
skipTagValue = "-"
sqlSeparator = "."
)
var (
ErrorEmptyTVPTypeName = errors.New("TypeName must not be empty")
ErrorTypeSlice = errors.New("TVP must be slice type")
ErrorTypeSliceIsEmpty = errors.New("TVP mustn't be null value")
ErrorSkip = errors.New("all fields mustn't skip")
ErrorObjectName = errors.New("wrong tvp name")
ErrorWrongTyping = errors.New("the number of elements in columnStr and tvpFieldIndexes do not align")
ErrorEmptyTVPName = errors.New("TVPTypeName must not be empty")
ErrorTVPTypeSlice = errors.New("TVPType must be slice type")
ErrorTVPTypeSliceIsEmpty = errors.New("TVPType mustn't be null value")
)
//TVP is driver type, which allows supporting Table Valued Parameters (TVP) in SQL Server
type TVP struct {
//TypeName mustn't be default value
TypeName string
//Value must be the slice, mustn't be nil
Value interface{}
//TVPType is driver type, which allows supporting Table Valued Parameters (TVP) in SQL Server
type TVPType struct {
//TVP param name, mustn't be default value
TVPTypeName string
//TVP scheme name
TVPScheme string
//TVP Value. Param must be the slice, mustn't be nil
TVPValue interface{}
}
func (tvp TVP) check() error {
if len(tvp.TypeName) == 0 {
return ErrorEmptyTVPTypeName
func (tvp TVPType) check() error {
if len(tvp.TVPTypeName) == 0 {
return ErrorEmptyTVPName
}
if !isProc(tvp.TypeName) {
return ErrorEmptyTVPTypeName
}
if sepCount := getCountSQLSeparators(tvp.TypeName); sepCount > 1 {
return ErrorObjectName
}
valueOf := reflect.ValueOf(tvp.Value)
valueOf := reflect.ValueOf(tvp.TVPValue)
if valueOf.Kind() != reflect.Slice {
return ErrorTypeSlice
return ErrorTVPTypeSlice
}
if valueOf.IsNil() {
return ErrorTypeSliceIsEmpty
return ErrorTVPTypeSliceIsEmpty
}
if reflect.TypeOf(tvp.Value).Elem().Kind() != reflect.Struct {
return ErrorTypeSlice
if reflect.TypeOf(tvp.TVPValue).Elem().Kind() != reflect.Struct {
return ErrorTVPTypeSlice
}
return nil
}
func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldIndexes []int) ([]byte, error) {
if len(columnStr) != len(tvpFieldIndexes) {
return nil, ErrorWrongTyping
}
preparedBuffer := make([]byte, 0, 20+(10*len(columnStr)))
buf := bytes.NewBuffer(preparedBuffer)
err := writeBVarChar(buf, "")
func (tvp TVPType) encode() ([]byte, error) {
columnStr, err := tvp.columnTypes()
if err != nil {
return nil, err
}
preparedBuffer := make([]byte, 0, 20+(10*len(columnStr)))
buf := bytes.NewBuffer(preparedBuffer)
err = writeBVarChar(buf, "")
if err != nil {
return nil, err
}
writeBVarChar(buf, tvp.TVPScheme)
writeBVarChar(buf, tvp.TVPTypeName)
writeBVarChar(buf, schema)
writeBVarChar(buf, name)
binary.Write(buf, binary.LittleEndian, uint16(len(columnStr)))
for i, column := range columnStr {
@ -80,9 +66,7 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
writeTypeInfo(buf, &columnStr[i].ti)
writeBVarChar(buf, "")
}
// The returned error is always nil
buf.WriteByte(_TVP_END_TOKEN)
conn := new(Conn)
conn.sess = new(tdsSession)
conn.sess.loginAck = loginAckStruct{TDSVersion: verTDS73}
@ -90,18 +74,18 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
c: conn,
}
val := reflect.ValueOf(tvp.Value)
val := reflect.ValueOf(tvp.TVPValue)
for i := 0; i < val.Len(); i++ {
refStr := reflect.ValueOf(val.Index(i).Interface())
buf.WriteByte(_TVP_ROW_TOKEN)
for columnStrIdx, fieldIdx := range tvpFieldIndexes {
field := refStr.Field(fieldIdx)
for j := 0; j < refStr.NumField(); j++ {
field := refStr.Field(j)
tvpVal := field.Interface()
valOf := reflect.ValueOf(tvpVal)
elemKind := field.Kind()
if elemKind == reflect.Ptr && valOf.IsNil() {
switch tvpVal.(type) {
case *bool, *time.Time, *int8, *int16, *int32, *int64, *float32, *float64, *int:
case *bool, *time.Time, *int8, *int16, *int32, *int64, *float32, *float64:
binary.Write(buf, binary.LittleEndian, uint8(0))
continue
default:
@ -122,44 +106,34 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd
if err != nil {
return nil, fmt.Errorf("failed to make tvp parameter row col: %s", err)
}
columnStr[columnStrIdx].ti.Writer(buf, param.ti, param.buffer)
columnStr[j].ti.Writer(buf, param.ti, param.buffer)
}
}
buf.WriteByte(_TVP_END_TOKEN)
return buf.Bytes(), nil
}
func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
val := reflect.ValueOf(tvp.Value)
func (tvp TVPType) columnTypes() ([]columnStruct, error) {
val := reflect.ValueOf(tvp.TVPValue)
var firstRow interface{}
if val.Len() != 0 {
firstRow = val.Index(0).Interface()
} else {
firstRow = reflect.New(reflect.TypeOf(tvp.Value).Elem()).Elem().Interface()
firstRow = reflect.New(reflect.TypeOf(tvp.TVPValue).Elem()).Elem().Interface()
}
tvpRow := reflect.TypeOf(firstRow)
columnCount := tvpRow.NumField()
defaultValues := make([]interface{}, 0, columnCount)
tvpFieldIndexes := make([]int, 0, columnCount)
for i := 0; i < columnCount; i++ {
field := tvpRow.Field(i)
tvpTagValue, isTvpTag := field.Tag.Lookup(tvpTag)
jsonTagValue, isJsonTag := field.Tag.Lookup(jsonTag)
if IsSkipField(tvpTagValue, isTvpTag, jsonTagValue, isJsonTag) {
continue
}
tvpFieldIndexes = append(tvpFieldIndexes, i)
if field.Type.Kind() == reflect.Ptr {
v := reflect.New(field.Type.Elem())
typeField := tvpRow.Field(i).Type
if typeField.Kind() == reflect.Ptr {
v := reflect.New(typeField.Elem())
defaultValues = append(defaultValues, v.Interface())
continue
}
defaultValues = append(defaultValues, reflect.Zero(field.Type).Interface())
}
if columnCount-len(tvpFieldIndexes) == columnCount {
return nil, nil, ErrorSkip
defaultValues = append(defaultValues, reflect.Zero(typeField).Interface())
}
conn := new(Conn)
@ -173,11 +147,11 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
for index, val := range defaultValues {
cval, err := convertInputParameter(val)
if err != nil {
return nil, nil, fmt.Errorf("failed to convert tvp parameter row %d col %d: %s", index, val, err)
return nil, fmt.Errorf("failed to convert tvp parameter row %d col %d: %s", index, val, err)
}
param, err := stmt.makeParam(cval)
if err != nil {
return nil, nil, err
return nil, err
}
column := columnStruct{
ti: param.ti,
@ -189,43 +163,5 @@ func (tvp TVP) columnTypes() ([]columnStruct, []int, error) {
columnConfiguration = append(columnConfiguration, column)
}
return columnConfiguration, tvpFieldIndexes, nil
}
func IsSkipField(tvpTagValue string, isTvpValue bool, jsonTagValue string, isJsonTagValue bool) bool {
if !isTvpValue && !isJsonTagValue {
return false
} else if isTvpValue && tvpTagValue != skipTagValue {
return false
} else if !isTvpValue && isJsonTagValue && jsonTagValue != skipTagValue {
return false
}
return true
}
func getSchemeAndName(tvpName string) (string, string, error) {
if len(tvpName) == 0 {
return "", "", ErrorEmptyTVPTypeName
}
splitVal := strings.Split(tvpName, ".")
if len(splitVal) > 2 {
return "", "", errors.New("wrong tvp name")
}
if len(splitVal) == 2 {
res := make([]string, 2)
for key, value := range splitVal {
tmp := strings.Replace(value, "[", "", -1)
tmp = strings.Replace(tmp, "]", "", -1)
res[key] = tmp
}
return res[0], res[1], nil
}
tmp := strings.Replace(splitVal[0], "[", "", -1)
tmp = strings.Replace(tmp, "]", "", -1)
return "", tmp, nil
}
func getCountSQLSeparators(str string) int {
return strings.Count(str, sqlSeparator)
return columnConfiguration, nil
}

View file

@ -1,751 +0,0 @@
// +build go1.9
package mssql
import (
"context"
"database/sql"
"log"
"reflect"
"testing"
"time"
)
const (
crateSchema = `create schema TestTVPSchema;`
dropSchema = `drop schema TestTVPSchema;`
createTVP = `
CREATE TYPE TestTVPSchema.exempleTVP AS TABLE
(
message NVARCHAR(100)
)`
dropTVP = `DROP TYPE TestTVPSchema.exempleTVP;`
procedureWithTVP = `
CREATE PROCEDURE ExecTVP
@param1 TestTVPSchema.exempleTVP READONLY
AS
BEGIN
SET NOCOUNT ON;
SELECT * FROM @param1;
END;
`
dropProcedure = `drop PROCEDURE ExecTVP`
execTvp = `exec ExecTVP @param1;`
)
type TvptableRow struct {
PBinary []byte `db:"p_binary"`
PVarchar string `db:"p_varchar"`
PVarcharNull *string `db:"p_varcharNull"`
PNvarchar string `db:"p_nvarchar"`
PNvarcharNull *string `db:"p_nvarcharNull"`
PID UniqueIdentifier `db:"p_id"`
PIDNull *UniqueIdentifier `db:"p_idNull"`
PVarbinary []byte `db:"p_varbinary"`
PTinyint int8 `db:"p_tinyint"`
PTinyintNull *int8 `db:"p_tinyintNull"`
PSmallint int16 `db:"p_smallint"`
PSmallintNull *int16 `db:"p_smallintNull"`
PInt int32 `db:"p_int"`
PIntNull *int32 `db:"p_intNull"`
PBigint int64 `db:"p_bigint"`
PBigintNull *int64 `db:"p_bigintNull"`
PBit bool `db:"p_bit"`
PBitNull *bool `db:"p_bitNull"`
PFloat32 float32 `db:"p_float32"`
PFloatNull32 *float32 `db:"p_floatNull32"`
PFloat64 float64 `db:"p_float64"`
PFloatNull64 *float64 `db:"p_floatNull64"`
DTime time.Time `db:"p_timeNull"`
DTimeNull *time.Time `db:"p_time"`
Pint int `db:"pInt"`
PintNull *int `db:"pIntNull"`
}
type TvptableRowWithSkipTag struct {
PBinary []byte `db:"p_binary"`
SkipPBinary []byte `json:"-"`
PVarchar string `db:"p_varchar"`
SkipPVarchar string `tvp:"-"`
PVarcharNull *string `db:"p_varcharNull"`
SkipPVarcharNull *string `json:"-" tvp:"-"`
PNvarchar string `db:"p_nvarchar"`
SkipPNvarchar string `json:"-"`
PNvarcharNull *string `db:"p_nvarcharNull"`
SkipPNvarcharNull *string `json:"-"`
PID UniqueIdentifier `db:"p_id"`
SkipPID UniqueIdentifier `json:"-"`
PIDNull *UniqueIdentifier `db:"p_idNull"`
SkipPIDNull *UniqueIdentifier `tvp:"-"`
PVarbinary []byte `db:"p_varbinary"`
SkipPVarbinary []byte `json:"-" tvp:"-"`
PTinyint int8 `db:"p_tinyint"`
SkipPTinyint int8 `tvp:"-"`
PTinyintNull *int8 `db:"p_tinyintNull"`
SkipPTinyintNull *int8 `tvp:"-" json:"any"`
PSmallint int16 `db:"p_smallint"`
SkipPSmallint int16 `json:"-"`
PSmallintNull *int16 `db:"p_smallintNull"`
SkipPSmallintNull *int16 `tvp:"-"`
PInt int32 `db:"p_int"`
SkipPInt int32 `json:"-"`
PIntNull *int32 `db:"p_intNull"`
SkipPIntNull *int32 `tvp:"-"`
PBigint int64 `db:"p_bigint"`
SkipPBigint int64 `tvp:"-"`
PBigintNull *int64 `db:"p_bigintNull"`
SkipPBigintNull *int64 `json:"any" tvp:"-"`
PBit bool `db:"p_bit"`
SkipPBit bool `json:"-"`
PBitNull *bool `db:"p_bitNull"`
SkipPBitNull *bool `json:"-"`
PFloat32 float32 `db:"p_float32"`
SkipPFloat32 float32 `tvp:"-"`
PFloatNull32 *float32 `db:"p_floatNull32"`
SkipPFloatNull32 *float32 `tvp:"-"`
PFloat64 float64 `db:"p_float64"`
SkipPFloat64 float64 `tvp:"-"`
PFloatNull64 *float64 `db:"p_floatNull64"`
SkipPFloatNull64 *float64 `tvp:"-"`
DTime time.Time `db:"p_timeNull"`
SkipDTime time.Time `tvp:"-"`
DTimeNull *time.Time `db:"p_time"`
SkipDTimeNull *time.Time `tvp:"-"`
Pint int `db:"p_int_null"`
SkipPint int `tvp:"-"`
PintNull *int `db:"p_int_"`
SkipPintNull *int `tvp:"-"`
}
func TestTVP(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})
db, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
t.Fatalf("failed to open driver sqlserver")
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sqltextcreatetable := `
CREATE TYPE tvptable AS TABLE
(
p_binary BINARY(3),
p_varchar VARCHAR(500),
p_varcharNull VARCHAR(500),
p_nvarchar NVARCHAR(100),
p_nvarcharNull NVARCHAR(100),
p_id UNIQUEIDENTIFIER,
p_idNull UNIQUEIDENTIFIER,
p_varbinary VARBINARY(MAX),
p_tinyint TINYINT,
p_tinyintNull TINYINT,
p_smallint SMALLINT,
p_smallintNull SMALLINT,
p_int INT,
p_intNull INT,
p_bigint BIGINT,
p_bigintNull BIGINT,
p_bit BIT,
p_bitNull BIT,
p_float32 FLOAT,
p_floatNull32 FLOAT,
p_float64 FLOAT,
p_floatNull64 FLOAT,
p_time datetime2,
p_timeNull datetime2,
pInt INT,
pIntNull INT
); `
sqltextdroptable := `DROP TYPE tvptable;`
sqltextcreatesp := `
CREATE PROCEDURE spwithtvp
@param1 tvptable READONLY,
@param2 tvptable READONLY,
@param3 NVARCHAR(10)
AS
BEGIN
SET NOCOUNT ON;
SELECT * FROM @param1;
SELECT * FROM @param2;
SELECT @param3;
END;`
sqltextdropsp := `DROP PROCEDURE spwithtvp;`
db.ExecContext(ctx, sqltextdropsp)
db.ExecContext(ctx, sqltextdroptable)
_, err = db.ExecContext(ctx, sqltextcreatetable)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdroptable)
_, err = db.ExecContext(ctx, sqltextcreatesp)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdropsp)
varcharNull := "aaa"
nvarchar := "bbb"
bytesMock := []byte("ddd")
i8 := int8(1)
i16 := int16(2)
i32 := int32(3)
i64 := int64(4)
i := int(5)
bFalse := false
floatValue64 := 0.123
floatValue32 := float32(-10.123)
timeNow := time.Now().UTC()
param1 := []TvptableRow{
{
PBinary: []byte("ccc"),
PVarchar: varcharNull,
PNvarchar: nvarchar,
PID: UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
PVarbinary: bytesMock,
PTinyint: i8,
PSmallint: i16,
PInt: i32,
PBigint: i64,
PBit: bFalse,
PFloat32: floatValue32,
PFloat64: floatValue64,
DTime: timeNow,
Pint: 355,
},
{
PBinary: []byte("www"),
PVarchar: "eee",
PNvarchar: "lll",
PID: UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
PVarbinary: []byte("zzz"),
PTinyint: 5,
PSmallint: 16000,
PInt: 20000000,
PBigint: 2000000020000000,
PBit: true,
PFloat32: -123.45,
PFloat64: -123.45,
DTime: time.Date(2001, 11, 16, 23, 59, 39, 0, time.UTC),
Pint: 455,
},
{
PBinary: nil,
PVarcharNull: &varcharNull,
PNvarcharNull: &nvarchar,
PIDNull: &UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
PTinyintNull: &i8,
PSmallintNull: &i16,
PIntNull: &i32,
PBigintNull: &i64,
PBitNull: &bFalse,
PFloatNull32: &floatValue32,
PFloatNull64: &floatValue64,
DTime: timeNow,
DTimeNull: &timeNow,
PintNull: &i,
},
{
PBinary: []byte("www"),
PVarchar: "eee",
PNvarchar: "lll",
PIDNull: &UniqueIdentifier{},
PVarbinary: []byte("zzz"),
PTinyint: 5,
PSmallint: 16000,
PInt: 20000000,
PBigint: 2000000020000000,
PBit: true,
PFloat64: 123.45,
DTime: time.Date(2001, 11, 16, 23, 59, 39, 0, time.UTC),
PVarcharNull: &varcharNull,
PNvarcharNull: &nvarchar,
PTinyintNull: &i8,
PSmallintNull: &i16,
PIntNull: &i32,
PBigintNull: &i64,
PBitNull: &bFalse,
PFloatNull32: &floatValue32,
PFloatNull64: &floatValue64,
DTimeNull: &timeNow,
PintNull: &i,
},
}
tvpType := TVP{
TypeName: "tvptable",
Value: param1,
}
tvpTypeEmpty := TVP{
TypeName: "tvptable",
Value: []TvptableRow{},
}
rows, err := db.QueryContext(ctx,
"exec spwithtvp @param1, @param2, @param3",
sql.Named("param1", tvpType),
sql.Named("param2", tvpTypeEmpty),
sql.Named("param3", "test"),
)
if err != nil {
t.Fatal(err)
}
var result1 []TvptableRow
for rows.Next() {
var val TvptableRow
err := rows.Scan(
&val.PBinary,
&val.PVarchar,
&val.PVarcharNull,
&val.PNvarchar,
&val.PNvarcharNull,
&val.PID,
&val.PIDNull,
&val.PVarbinary,
&val.PTinyint,
&val.PTinyintNull,
&val.PSmallint,
&val.PSmallintNull,
&val.PInt,
&val.PIntNull,
&val.PBigint,
&val.PBigintNull,
&val.PBit,
&val.PBitNull,
&val.PFloat32,
&val.PFloatNull32,
&val.PFloat64,
&val.PFloatNull64,
&val.DTime,
&val.DTimeNull,
&val.Pint,
&val.PintNull,
)
if err != nil {
t.Fatalf("scan failed with error: %s", err)
}
result1 = append(result1, val)
}
if !reflect.DeepEqual(param1, result1) {
t.Logf("expected: %+v", param1)
t.Logf("actual: %+v", result1)
t.Errorf("first resultset did not match param1")
}
if !rows.NextResultSet() {
t.Errorf("second resultset did not exist")
}
if rows.Next() {
t.Errorf("second resultset was not empty")
}
if !rows.NextResultSet() {
t.Errorf("third resultset did not exist")
}
if !rows.Next() {
t.Errorf("third resultset was empty")
}
var result3 string
if err := rows.Scan(&result3); err != nil {
t.Errorf("error scanning third result set: %s", err)
}
if result3 != "test" {
t.Errorf("third result set had wrong value expected: %s actual: %s", "test", result3)
}
}
func TestTVP_WithTag(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})
db, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
t.Fatalf("failed to open driver sqlserver")
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
sqltextcreatetable := `
CREATE TYPE tvptable AS TABLE
(
p_binary BINARY(3),
p_varchar VARCHAR(500),
p_varcharNull VARCHAR(500),
p_nvarchar NVARCHAR(100),
p_nvarcharNull NVARCHAR(100),
p_id UNIQUEIDENTIFIER,
p_idNull UNIQUEIDENTIFIER,
p_varbinary VARBINARY(MAX),
p_tinyint TINYINT,
p_tinyintNull TINYINT,
p_smallint SMALLINT,
p_smallintNull SMALLINT,
p_int INT,
p_intNull INT,
p_bigint BIGINT,
p_bigintNull BIGINT,
p_bit BIT,
p_bitNull BIT,
p_float32 FLOAT,
p_floatNull32 FLOAT,
p_float64 FLOAT,
p_floatNull64 FLOAT,
p_time datetime2,
p_timeNull datetime2,
pInt INT,
pIntNull INT
); `
sqltextdroptable := `DROP TYPE tvptable;`
sqltextcreatesp := `
CREATE PROCEDURE spwithtvp
@param1 tvptable READONLY,
@param2 tvptable READONLY,
@param3 NVARCHAR(10)
AS
BEGIN
SET NOCOUNT ON;
SELECT * FROM @param1;
SELECT * FROM @param2;
SELECT @param3;
END;`
sqltextdropsp := `DROP PROCEDURE spwithtvp;`
db.ExecContext(ctx, sqltextdropsp)
db.ExecContext(ctx, sqltextdroptable)
_, err = db.ExecContext(ctx, sqltextcreatetable)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdroptable)
_, err = db.ExecContext(ctx, sqltextcreatesp)
if err != nil {
t.Fatal(err)
}
defer db.ExecContext(ctx, sqltextdropsp)
varcharNull := "aaa"
nvarchar := "bbb"
bytesMock := []byte("ddd")
i8 := int8(1)
i16 := int16(2)
i32 := int32(3)
i64 := int64(4)
i := int(355)
bFalse := false
floatValue64 := 0.123
floatValue32 := float32(-10.123)
timeNow := time.Now().UTC()
param1 := []TvptableRowWithSkipTag{
{
PBinary: []byte("ccc"),
PVarchar: varcharNull,
PNvarchar: nvarchar,
PID: UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
PVarbinary: bytesMock,
PTinyint: i8,
PSmallint: i16,
PInt: i32,
PBigint: i64,
PBit: bFalse,
PFloat32: floatValue32,
PFloat64: floatValue64,
DTime: timeNow,
Pint: i,
PintNull: &i,
},
{
PBinary: []byte("www"),
PVarchar: "eee",
PNvarchar: "lll",
PID: UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
PVarbinary: []byte("zzz"),
PTinyint: 5,
PSmallint: 16000,
PInt: 20000000,
PBigint: 2000000020000000,
PBit: true,
PFloat32: -123.45,
PFloat64: -123.45,
DTime: time.Date(2001, 11, 16, 23, 59, 39, 0, time.UTC),
Pint: 3669,
PintNull: &i,
},
{
PBinary: nil,
PVarcharNull: &varcharNull,
PNvarcharNull: &nvarchar,
PIDNull: &UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF},
PTinyintNull: &i8,
PSmallintNull: &i16,
PIntNull: &i32,
PBigintNull: &i64,
PBitNull: &bFalse,
PFloatNull32: &floatValue32,
PFloatNull64: &floatValue64,
DTime: timeNow,
DTimeNull: &timeNow,
Pint: 969,
},
{
PBinary: []byte("www"),
PVarchar: "eee",
PNvarchar: "lll",
PIDNull: &UniqueIdentifier{},
PVarbinary: []byte("zzz"),
PTinyint: 5,
PSmallint: 16000,
PInt: 20000000,
PBigint: 2000000020000000,
PBit: true,
PFloat64: 123.45,
DTime: time.Date(2001, 11, 16, 23, 59, 39, 0, time.UTC),
PVarcharNull: &varcharNull,
PNvarcharNull: &nvarchar,
PTinyintNull: &i8,
PSmallintNull: &i16,
PIntNull: &i32,
PBigintNull: &i64,
PBitNull: &bFalse,
PFloatNull32: &floatValue32,
PFloatNull64: &floatValue64,
DTimeNull: &timeNow,
PintNull: &i,
},
}
tvpType := TVP{
TypeName: "tvptable",
Value: param1,
}
tvpTypeEmpty := TVP{
TypeName: "tvptable",
Value: []TvptableRowWithSkipTag{},
}
rows, err := db.QueryContext(ctx,
"exec spwithtvp @param1, @param2, @param3",
sql.Named("param1", tvpType),
sql.Named("param2", tvpTypeEmpty),
sql.Named("param3", "test"),
)
if err != nil {
t.Fatal(err)
}
var result1 []TvptableRowWithSkipTag
for rows.Next() {
var val TvptableRowWithSkipTag
err := rows.Scan(
&val.PBinary,
&val.PVarchar,
&val.PVarcharNull,
&val.PNvarchar,
&val.PNvarcharNull,
&val.PID,
&val.PIDNull,
&val.PVarbinary,
&val.PTinyint,
&val.PTinyintNull,
&val.PSmallint,
&val.PSmallintNull,
&val.PInt,
&val.PIntNull,
&val.PBigint,
&val.PBigintNull,
&val.PBit,
&val.PBitNull,
&val.PFloat32,
&val.PFloatNull32,
&val.PFloat64,
&val.PFloatNull64,
&val.DTime,
&val.DTimeNull,
&val.Pint,
&val.PintNull,
)
if err != nil {
t.Fatalf("scan failed with error: %s", err)
}
result1 = append(result1, val)
}
if !reflect.DeepEqual(param1, result1) {
t.Logf("expected: %+v", param1)
t.Logf("actual: %+v", result1)
t.Errorf("first resultset did not match param1")
}
if !rows.NextResultSet() {
t.Errorf("second resultset did not exist")
}
if rows.Next() {
t.Errorf("second resultset was not empty")
}
if !rows.NextResultSet() {
t.Errorf("third resultset did not exist")
}
if !rows.Next() {
t.Errorf("third resultset was empty")
}
var result3 string
if err := rows.Scan(&result3); err != nil {
t.Errorf("error scanning third result set: %s", err)
}
if result3 != "test" {
t.Errorf("third result set had wrong value expected: %s actual: %s", "test", result3)
}
}
type TvpExample struct {
Message string
}
func TestTVPSchema(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})
conn, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
log.Fatal("Open connection failed:", err.Error())
}
defer conn.Close()
_, err = conn.Exec(crateSchema)
if err != nil {
log.Println(err)
return
}
defer conn.Exec(dropSchema)
_, err = conn.Exec(createTVP)
if err != nil {
log.Println(err)
return
}
defer conn.Exec(dropTVP)
_, err = conn.Exec(procedureWithTVP)
if err != nil {
log.Println(err)
return
}
defer conn.Exec(dropProcedure)
exempleData := []TvpExample{
{
Message: "Hello",
},
{
Message: "World",
},
{
Message: "TVP",
},
}
tvpType := TVP{
TypeName: "exempleTVP",
Value: exempleData,
}
rows, err := conn.Query(execTvp,
sql.Named("param1", tvpType),
)
if err != nil {
log.Println(err)
return
}
tvpResult := make([]TvpExample, 0)
for rows.Next() {
tvpExemple := TvpExample{}
err = rows.Scan(&tvpExemple.Message)
if err != nil {
log.Println(err)
return
}
tvpResult = append(tvpResult, tvpExemple)
}
log.Println(tvpResult)
}
func TestTVPObject(t *testing.T) {
checkConnStr(t)
SetLogger(testLogger{t})
conn, err := sql.Open("sqlserver", makeConnStr(t).String())
if err != nil {
log.Fatal("Open connection failed:", err.Error())
}
defer conn.Close()
tests := []struct {
name string
tvp TVP
wantErr bool
}{
{
name: "empty name",
wantErr: true,
tvp: TVP{TypeName: ""},
},
{
name: "value is wrong type",
wantErr: true,
tvp: TVP{TypeName: "type", Value: "wrong type"},
},
{
name: "tvp type is wrong",
wantErr: true,
tvp: TVP{TypeName: "[type", Value: []TvpExample{{}}},
},
{
name: "tvp type is wrong",
wantErr: true,
tvp: TVP{TypeName: "[type", Value: []TestFieldsUnsupportedTypes{{}}},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := conn.Exec("somequery", tt.tvp)
if (err != nil) != tt.wantErr {
t.Errorf("TVP.encode() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}

View file

@ -1,578 +0,0 @@
// +build go1.9
package mssql
import (
"reflect"
"testing"
"time"
)
type TestFields struct {
PBinary []byte `tvp:"p_binary"`
PVarchar string `json:"p_varchar"`
PNvarchar *string `json:"p_nvarchar"`
TimeValue time.Time `echo:"-"`
TimeNullValue *time.Time
}
type TestFieldError struct {
ErrorValue []*byte
}
type TestFieldsUnsupportedTypes struct {
ErrorType TestFieldError
}
func TestTVPType_columnTypes(t *testing.T) {
type customTypeAllFieldsSkipOne struct {
SkipTest int `tvp:"-"`
}
type customTypeAllFieldsSkipMoreOne struct {
SkipTest int `tvp:"-"`
SkipTest1 int `json:"-"`
}
type skipWrongField struct {
SkipTest int
SkipTest1 []*byte `json:"skip_test" tvp:"-"`
}
type structType struct {
SkipTest int `json:"-" tvp:"test"`
SkipTest1 []*skipWrongField `json:"any" tvp:"tvp"`
}
type skipWithAnotherTagValue struct {
SkipTest int `json:"-" tvp:"test"`
}
type fields struct {
TVPName string
TVPValue interface{}
}
tests := []struct {
name string
fields fields
want []columnStruct
wantErr bool
}{
{
name: "Test Pass",
fields: fields{
TVPValue: []TestFields{TestFields{}},
},
},
{
name: "Value has wrong field type",
fields: fields{
TVPValue: []TestFieldError{TestFieldError{}},
},
wantErr: true,
},
{
name: "Value has wrong type",
fields: fields{
TVPValue: []TestFieldsUnsupportedTypes{},
},
wantErr: true,
},
{
name: "Value has wrong type",
fields: fields{
TVPValue: []structType{},
},
wantErr: true,
},
{
name: "CustomTag all fields are skip, single field",
fields: fields{
TVPValue: []customTypeAllFieldsSkipOne{},
},
wantErr: true,
},
{
name: "CustomTag all fields are skip, > 1 field",
fields: fields{
TVPValue: []customTypeAllFieldsSkipMoreOne{},
},
wantErr: true,
},
{
name: "CustomTag all fields are skip wrong field type",
fields: fields{
TVPValue: []skipWrongField{},
},
wantErr: false,
},
{
name: "CustomTag tag value is not -",
fields: fields{
TVPValue: []skipWithAnotherTagValue{},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tvp := TVP{
TypeName: tt.fields.TVPName,
Value: tt.fields.TVPValue,
}
_, _, err := tvp.columnTypes()
if (err != nil) != tt.wantErr {
t.Errorf("TVP.columnTypes() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func TestTVPType_check(t *testing.T) {
type fields struct {
TVPName string
TVPValue interface{}
}
var nullSlice []*string
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "TypeName is nil",
wantErr: true,
},
{
name: "Value is nil",
fields: fields{
TVPName: "Test",
TVPValue: nil,
},
wantErr: true,
},
{
name: "Value is nil",
fields: fields{
TVPName: "Test",
},
wantErr: true,
},
{
name: "Value isn't slice",
fields: fields{
TVPName: "Test",
TVPValue: "",
},
wantErr: true,
},
{
name: "Value isn't slice",
fields: fields{
TVPName: "Test",
TVPValue: 12345,
},
wantErr: true,
},
{
name: "Value isn't slice",
fields: fields{
TVPName: "Test",
TVPValue: nullSlice,
},
wantErr: true,
},
{
name: "Value isn't right",
fields: fields{
TVPName: "Test",
TVPValue: []*fields{},
},
wantErr: true,
},
{
name: "Value is right",
fields: fields{
TVPName: "Test",
TVPValue: []fields{},
},
wantErr: false,
},
{
name: "Value is right",
fields: fields{
TVPName: "Test",
TVPValue: []fields{},
},
wantErr: false,
},
{
name: "Value is right",
fields: fields{
TVPName: "[Test]",
TVPValue: []fields{},
},
wantErr: false,
},
{
name: "Value is right",
fields: fields{
TVPName: "[123].[Test]",
TVPValue: []fields{},
},
wantErr: false,
},
{
name: "TVP name is right",
fields: fields{
TVPName: "[123].Test",
TVPValue: []fields{},
},
wantErr: false,
},
{
name: "TVP name is right",
fields: fields{
TVPName: "123.[Test]",
TVPValue: []fields{},
},
wantErr: false,
},
{
name: "TVP name is wrong",
fields: fields{
TVPName: "123.[Test\n]",
TVPValue: []fields{},
},
wantErr: true,
},
{
name: "TVP name is wrong",
fields: fields{
TVPName: "123.[Test].456",
TVPValue: []fields{},
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tvp := TVP{
TypeName: tt.fields.TVPName,
Value: tt.fields.TVPValue,
}
if err := tvp.check(); (err != nil) != tt.wantErr {
t.Errorf("TVP.check() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
func BenchmarkTVPType_check(b *testing.B) {
type val struct {
Value string
}
tvp := TVP{
TypeName: "Test",
Value: []val{},
}
for i := 0; i < b.N; i++ {
err := tvp.check()
if err != nil {
b.Fail()
}
}
}
func BenchmarkColumnTypes(b *testing.B) {
type str struct {
bytes byte
bytesNull *byte
bytesSlice []byte
int8s int8
int8sNull *int8
uint8s uint8
uint8sNull *uint8
int16s int16
int16sNull *int16
uint16s uint16
uint16sNull *uint16
int32s int32
int32sNull *int32
uint32s uint32
uint32sNull *uint32
int64s int64
int64sNull *int64
uint64s uint64
uint64sNull *uint64
stringVal string
stringValNull *string
bools bool
boolsNull *bool
}
wal := make([]str, 100)
tvp := TVP{
TypeName: "Test",
Value: wal,
}
for i := 0; i < b.N; i++ {
_, _, err := tvp.columnTypes()
if err != nil {
b.Error(err)
}
}
}
func TestIsSkipField(t *testing.T) {
type args struct {
tvpTagValue string
isTvpValue bool
jsonTagValue string
isJsonTagValue bool
}
tests := []struct {
name string
args args
want bool
}{
{
name: "Empty tags",
want: false,
},
{
name: "tvp is skip",
want: true,
args: args{
isTvpValue: true,
tvpTagValue: skipTagValue,
},
},
{
name: "tvp is any",
want: false,
args: args{
isTvpValue: true,
tvpTagValue: "tvp",
},
},
{
name: "Json is skip",
want: true,
args: args{
isJsonTagValue: true,
jsonTagValue: skipTagValue,
},
},
{
name: "Json is any",
want: false,
args: args{
isJsonTagValue: true,
jsonTagValue: "any",
},
},
{
name: "Json is skip tvp is skip",
want: true,
args: args{
isJsonTagValue: true,
jsonTagValue: skipTagValue,
isTvpValue: true,
tvpTagValue: skipTagValue,
},
},
{
name: "Json is skip tvp is any",
want: false,
args: args{
isJsonTagValue: true,
jsonTagValue: skipTagValue,
isTvpValue: true,
tvpTagValue: "tvp",
},
},
{
name: "Json is any tvp is skip",
want: true,
args: args{
isJsonTagValue: true,
jsonTagValue: "json",
isTvpValue: true,
tvpTagValue: skipTagValue,
},
},
{
name: "Json is any tvp is skip",
want: false,
args: args{
isJsonTagValue: true,
jsonTagValue: "json",
isTvpValue: true,
tvpTagValue: "tvp",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := IsSkipField(tt.args.tvpTagValue, tt.args.isTvpValue, tt.args.jsonTagValue, tt.args.isJsonTagValue); got != tt.want {
t.Errorf("IsSkipField() = %v, schema %v", got, tt.want)
}
})
}
}
func Test_getSchemeAndName(t *testing.T) {
type args struct {
tvpName string
}
tests := []struct {
name string
args args
schema string
tvpName string
wantErr bool
}{
{
name: "Empty object name",
wantErr: true,
},
{
name: "Wrong object name",
wantErr: true,
args: args{
tvpName: "1.2.3",
},
},
{
name: "Schema+name",
wantErr: false,
args: args{
tvpName: "obj.tvp",
},
schema: "obj",
tvpName: "tvp",
},
{
name: "Schema+name",
wantErr: false,
args: args{
tvpName: "[obj].[tvp]",
},
schema: "obj",
tvpName: "tvp",
},
{
name: "only name",
wantErr: false,
args: args{
tvpName: "tvp",
},
schema: "",
tvpName: "tvp",
},
{
name: "only name",
wantErr: false,
args: args{
tvpName: "[tvp]",
},
schema: "",
tvpName: "tvp",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
schema, name, err := getSchemeAndName(tt.args.tvpName)
if (err != nil) != tt.wantErr {
t.Errorf("getSchemeAndName() error = %v, wantErr %v", err, tt.wantErr)
return
}
if schema != tt.schema {
t.Errorf("getSchemeAndName() schema = %v, schema %v", schema, tt.schema)
}
if name != tt.tvpName {
t.Errorf("getSchemeAndName() name = %v, schema %v", name, tt.tvpName)
}
})
}
}
func TestTVP_encode(t *testing.T) {
type fields struct {
TypeName string
Value interface{}
}
type args struct {
schema string
name string
columnStr []columnStruct
tvpFieldIndexes []int
}
tests := []struct {
name string
fields fields
args args
want []byte
wantErr bool
wantPanic bool
}{
{
name: "column and indexes are nil",
wantErr: true,
args: args{
tvpFieldIndexes: []int{1, 2},
},
},
{
name: "column and indexes are nil",
wantErr: true,
args: args{
tvpFieldIndexes: []int{1, 2},
columnStr: []columnStruct{columnStruct{}},
},
},
{
name: "column and indexes are nil",
wantErr: true,
args: args{
columnStr: []columnStruct{columnStruct{}},
},
},
{
name: "column and indexes are nil",
wantErr: true,
wantPanic: true,
args: args{
schema: string(make([]byte, 256)),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.wantPanic {
defer func() {
if r := recover(); r == nil {
t.Errorf("Want panic")
}
}()
}
tvp := TVP{
TypeName: tt.fields.TypeName,
Value: tt.fields.Value,
}
got, err := tvp.encode(tt.args.schema, tt.args.name, tt.args.columnStr, tt.args.tvpFieldIndexes)
if (err != nil) != tt.wantErr {
t.Errorf("TVP.encode() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("TVP.encode() = %v, want %v", got, tt.want)
}
})
}
}

View file

@ -73,10 +73,14 @@ const (
const _PLP_NULL = 0xFFFFFFFFFFFFFFFF
const _UNKNOWN_PLP_LEN = 0xFFFFFFFFFFFFFFFE
const _PLP_TERMINATOR = 0x00000000
const _TVP_NULL_TOKEN = 0xffff
// TVP COLUMN FLAGS
const _TVP_COLUMN_DEFAULT_FLAG = 0x200
const _TVP_END_TOKEN = 0x00
const _TVP_ROW_TOKEN = 0x01
const _TVP_ORDER_UNIQUE_TOKEN = 0x10
const _TVP_COLUMN_ORDERING_TOKEN = 0x11
// TYPE_INFO rule
// http://msdn.microsoft.com/en-us/library/dd358284.aspx

View file

@ -1,123 +0,0 @@
package mssql
import (
"reflect"
"testing"
"time"
)
func TestMakeGoLangScanType(t *testing.T) {
if (reflect.TypeOf(int64(0)) != makeGoLangScanType(typeInfo{TypeId: typeInt8})) {
t.Errorf("invalid type returned for typeDateTime")
}
if (reflect.TypeOf(float64(0)) != makeGoLangScanType(typeInfo{TypeId: typeFlt4})) {
t.Errorf("invalid type returned for typeDateTime")
}
if (reflect.TypeOf(float64(0)) != makeGoLangScanType(typeInfo{TypeId: typeFlt8})) {
t.Errorf("invalid type returned for typeDateTime")
}
if (reflect.TypeOf("") != makeGoLangScanType(typeInfo{TypeId: typeVarChar})) {
t.Errorf("invalid type returned for typeDateTime")
}
if (reflect.TypeOf(time.Time{}) != makeGoLangScanType(typeInfo{TypeId: typeDateTime})) {
t.Errorf("invalid type returned for typeDateTime")
}
if (reflect.TypeOf(time.Time{}) != makeGoLangScanType(typeInfo{TypeId: typeDateTim4})) {
t.Errorf("invalid type returned for typeDateTim4")
}
if (reflect.TypeOf(int64(0)) != makeGoLangScanType(typeInfo{TypeId: typeInt1})) {
t.Errorf("invalid type returned for typeInt1")
}
if (reflect.TypeOf(int64(0)) != makeGoLangScanType(typeInfo{TypeId: typeInt2})) {
t.Errorf("invalid type returned for typeInt2")
}
if (reflect.TypeOf(int64(0)) != makeGoLangScanType(typeInfo{TypeId: typeInt4})) {
t.Errorf("invalid type returned for typeInt4")
}
if (reflect.TypeOf(int64(0)) != makeGoLangScanType(typeInfo{TypeId: typeIntN, Size: 4})) {
t.Errorf("invalid type returned for typeIntN")
}
if (reflect.TypeOf([]byte{}) != makeGoLangScanType(typeInfo{TypeId: typeMoney, Size: 8})) {
t.Errorf("invalid type returned for typeIntN")
}
}
func TestMakeGoLangTypeName(t *testing.T) {
defer handlePanic(t)
tests := []struct {
typeName string
typeString string
typeID uint8
}{
{"typeDateTime", "DATETIME", typeDateTime},
{"typeDateTim4", "SMALLDATETIME", typeDateTim4},
{"typeBigBinary", "BINARY", typeBigBinary},
//TODO: Add other supported types
}
for _, tt := range tests {
if makeGoLangTypeName(typeInfo{TypeId: tt.typeID}) != tt.typeString {
t.Errorf("invalid type name returned for %s", tt.typeName)
}
}
}
func TestMakeGoLangTypeLength(t *testing.T) {
defer handlePanic(t)
tests := []struct {
typeName string
typeVarLen bool
typeLen int64
typeID uint8
}{
{"typeDateTime", false, 0, typeDateTime},
{"typeDateTim4", false, 0, typeDateTim4},
{"typeBigBinary", false, 0, typeBigBinary},
//TODO: Add other supported types
}
for _, tt := range tests {
n, v := makeGoLangTypeLength(typeInfo{TypeId: tt.typeID})
if v != tt.typeVarLen {
t.Errorf("invalid type length variability returned for %s", tt.typeName)
}
if n != tt.typeLen {
t.Errorf("invalid type length returned for %s", tt.typeName)
}
}
}
func TestMakeGoLangTypePrecisionScale(t *testing.T) {
defer handlePanic(t)
tests := []struct {
typeName string
typeID uint8
typeVarLen bool
typePrec int64
typeScale int64
}{
{"typeDateTime", typeDateTime, false, 0, 0},
{"typeDateTim4", typeDateTim4, false, 0, 0},
{"typeBigBinary", typeBigBinary, false, 0, 0},
//TODO: Add other supported types
}
for _, tt := range tests {
prec, scale, varLen := makeGoLangTypePrecisionScale(typeInfo{TypeId: tt.typeID})
if varLen != tt.typeVarLen {
t.Errorf("invalid type length variability returned for %s", tt.typeName)
}
if prec != tt.typePrec || scale != tt.typeScale {
t.Errorf("invalid type precision and/or scale returned for %s", tt.typeName)
}
}
}
func handlePanic(t *testing.T) {
if r := recover(); r != nil {
t.Errorf("recovered panic")
}
}

View file

@ -1,70 +0,0 @@
package mssql
import (
"bytes"
"database/sql"
"database/sql/driver"
"fmt"
"testing"
)
func TestUniqueIdentifier(t *testing.T) {
dbUUID := UniqueIdentifier{0x67, 0x45, 0x23, 0x01,
0xAB, 0x89,
0xEF, 0xCD,
0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF,
}
uuid := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}
t.Run("Scan", func(t *testing.T) {
t.Run("[]byte", func(t *testing.T) {
var sut UniqueIdentifier
if err := sut.Scan(dbUUID[:]); err != nil {
t.Fatal(err)
}
if sut != uuid {
t.Errorf("bytes not swapped correctly: got %q; want %q", sut, uuid)
}
})
t.Run("string", func(t *testing.T) {
var sut UniqueIdentifier
if err := sut.Scan(uuid.String()); err != nil {
t.Fatal(err)
}
if sut != uuid {
t.Errorf("string not scanned correctly: got %q; want %q", sut, uuid)
}
})
})
t.Run("Value", func(t *testing.T) {
sut := uuid
v, err := sut.Value()
if err != nil {
t.Fatal(err)
}
b, ok := v.([]byte)
if !ok {
t.Fatalf("(%T) is not []byte", v)
}
if !bytes.Equal(b, dbUUID[:]) {
t.Errorf("got %q; want %q", b, dbUUID)
}
})
}
func TestUniqueIdentifierString(t *testing.T) {
sut := UniqueIdentifier{0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF}
expected := "01234567-89AB-CDEF-0123-456789ABCDEF"
if actual := sut.String(); actual != expected {
t.Errorf("sut.String() = %s; want %s", sut, expected)
}
}
var _ fmt.Stringer = UniqueIdentifier{}
var _ sql.Scanner = &UniqueIdentifier{}
var _ driver.Valuer = UniqueIdentifier{}

View file

@ -27,7 +27,6 @@ Daniël van Eeden <git at myname.nl>
Dave Protasowski <dprotaso at gmail.com>
DisposaBoy <disposaboy at dby.me>
Egor Smolyakov <egorsmkv at gmail.com>
Erwan Martin <hello at erwan.io>
Evan Shaw <evan at vendhq.com>
Frederick Mayle <frederickmayle at gmail.com>
Gustavo Kristic <gkristic at gmail.com>
@ -35,15 +34,12 @@ Hajime Nakagami <nakagami at gmail.com>
Hanno Braun <mail at hannobraun.com>
Henri Yandell <flamefew at gmail.com>
Hirotaka Yamamoto <ymmt2005 at gmail.com>
Huyiguang <hyg at webterren.com>
ICHINOSE Shogo <shogo82148 at gmail.com>
Ilia Cimpoes <ichimpoesh at gmail.com>
INADA Naoki <songofacandy at gmail.com>
Jacek Szwec <szwec.jacek at gmail.com>
James Harr <james.harr at gmail.com>
Jeff Hodges <jeff at somethingsimilar.com>
Jeffrey Charles <jeffreycharles at gmail.com>
Jerome Meyer <jxmeyer at gmail.com>
Jian Zhen <zhenjl at gmail.com>
Joshua Prunier <joshua.prunier at gmail.com>
Julien Lefevre <julien.lefevr at gmail.com>
@ -73,14 +69,9 @@ Richard Wilkes <wilkes at me.com>
Robert Russell <robert at rrbrussell.com>
Runrioter Wung <runrioter at gmail.com>
Shuode Li <elemount at qq.com>
Simon J Mudd <sjmudd at pobox.com>
Soroush Pour <me at soroushjp.com>
Stan Putrya <root.vagner at gmail.com>
Stanley Gunawan <gunawan.stanley at gmail.com>
Steven Hartland <steven.hartland at multiplay.co.uk>
Thomas Wodarek <wodarekwebpage at gmail.com>
Tim Ruffles <timruffles at gmail.com>
Tom Jenkinson <tom at tjenkinson.me>
Xiangyu Hu <xiangyu.hu at outlook.com>
Xiaobing Jiang <s7v7nislands at gmail.com>
Xiuming Chen <cc at cxm.cc>
@ -90,12 +81,9 @@ Zhenye Xie <xiezhenye at gmail.com>
Barracuda Networks, Inc.
Counting Ltd.
Facebook Inc.
GitHub Inc.
Google Inc.
InfoSum Ltd.
Keybase Inc.
Multiplay Ltd.
Percona LLC
Pivotal Inc.
Stripe Inc.

View file

@ -40,7 +40,7 @@ A MySQL-Driver for Go's [database/sql](https://golang.org/pkg/database/sql/) pac
* Optional placeholder interpolation
## Requirements
* Go 1.9 or higher. We aim to support the 3 latest versions of Go.
* Go 1.7 or higher. We aim to support the 3 latest versions of Go.
* MySQL (4.1+), MariaDB, Percona Server, Google CloudSQL or Sphinx (2.2.3+)
---------------------------------------
@ -171,18 +171,13 @@ Unless you need the fallback behavior, please use `collation` instead.
```
Type: string
Valid Values: <name>
Default: utf8mb4_general_ci
Default: utf8_general_ci
```
Sets the collation used for client-server interaction on connection. In contrast to `charset`, `collation` does not issue additional queries. If the specified collation is unavailable on the target server, the connection will fail.
A list of valid charsets for a server is retrievable with `SHOW COLLATION`.
The default collation (`utf8mb4_general_ci`) is supported from MySQL 5.5. You should use an older collation (e.g. `utf8_general_ci`) for older MySQL.
Collations for charset "ucs2", "utf16", "utf16le", and "utf32" can not be used ([ref](https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset)).
##### `clientFoundRows`
```
@ -333,11 +328,11 @@ Timeout for establishing connections, aka dial timeout. The value must be a deci
```
Type: bool / string
Valid Values: true, false, skip-verify, preferred, <name>
Valid Values: true, false, skip-verify, <name>
Default: false
```
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side) or use `preferred` to use TLS only when advertised by the server. This is similar to `skip-verify`, but additionally allows a fallback to a connection which is not encrypted. Neither `skip-verify` nor `preferred` add any reliable security. You can use a custom TLS config after registering it with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
`tls=true` enables TLS / SSL encrypted connection to the server. Use `skip-verify` if you want to use a self-signed or invalid certificate (server side). Use a custom value registered with [`mysql.RegisterTLSConfig`](https://godoc.org/github.com/go-sql-driver/mysql#RegisterTLSConfig).
##### `writeTimeout`
@ -449,7 +444,7 @@ See the [godoc of Go-MySQL-Driver](https://godoc.org/github.com/go-sql-driver/my
### `time.Time` support
The default internal output type of MySQL `DATE` and `DATETIME` values is `[]byte` which allows you to scan the value into a `[]byte`, `string` or `sql.RawBytes` variable in your program.
However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical equivalent in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
However, many want to scan MySQL `DATE` and `DATETIME` values into `time.Time` variables, which is the logical opposite in Go to `DATE` and `DATETIME` in MySQL. You can do that by changing the internal output type from `[]byte` to `time.Time` with the DSN parameter `parseTime=true`. You can set the default [`time.Time` location](https://golang.org/pkg/time/#Location) with the `loc` DSN parameter.
**Caution:** As of Go 1.1, this makes `time.Time` the only variable type you can scan `DATE` and `DATETIME` values into. This breaks for example [`sql.RawBytes` support](https://github.com/go-sql-driver/mysql/wiki/Examples#rawbytes).

View file

@ -11,15 +11,9 @@
package mysql
import (
"context"
"net"
"google.golang.org/appengine/cloudsql"
)
func init() {
RegisterDialContext("cloudsql", func(_ context.Context, instance string) (net.Conn, error) {
// XXX: the cloudsql driver still does not export a Context-aware dialer.
return cloudsql.Dial(instance)
})
RegisterDial("cloudsql", cloudsql.Dial)
}

View file

@ -234,64 +234,64 @@ func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) erro
if err != nil {
return err
}
return mc.writeAuthSwitchPacket(enc)
return mc.writeAuthSwitchPacket(enc, false)
}
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, error) {
func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) {
switch plugin {
case "caching_sha2_password":
authResp := scrambleSHA256Password(authData, mc.cfg.Passwd)
return authResp, nil
return authResp, (authResp == nil), nil
case "mysql_old_password":
if !mc.cfg.AllowOldPasswords {
return nil, ErrOldPassword
return nil, false, ErrOldPassword
}
// Note: there are edge cases where this should work but doesn't;
// this is currently "wontfix":
// https://github.com/go-sql-driver/mysql/issues/184
authResp := append(scrambleOldPassword(authData[:8], mc.cfg.Passwd), 0)
return authResp, nil
authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd)
return authResp, true, nil
case "mysql_clear_password":
if !mc.cfg.AllowCleartextPasswords {
return nil, ErrCleartextPassword
return nil, false, ErrCleartextPassword
}
// http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html
// http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html
return append([]byte(mc.cfg.Passwd), 0), nil
return []byte(mc.cfg.Passwd), true, nil
case "mysql_native_password":
if !mc.cfg.AllowNativePasswords {
return nil, ErrNativePassword
return nil, false, ErrNativePassword
}
// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
// Native password authentication only need and will need 20-byte challenge.
authResp := scramblePassword(authData[:20], mc.cfg.Passwd)
return authResp, nil
return authResp, false, nil
case "sha256_password":
if len(mc.cfg.Passwd) == 0 {
return []byte{0}, nil
return nil, true, nil
}
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
// write cleartext auth packet
return append([]byte(mc.cfg.Passwd), 0), nil
return []byte(mc.cfg.Passwd), true, nil
}
pubKey := mc.cfg.pubKey
if pubKey == nil {
// request public key from server
return []byte{1}, nil
return []byte{1}, false, nil
}
// encrypted password
enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey)
return enc, err
return enc, false, err
default:
errLog.Print("unknown auth plugin:", plugin)
return nil, ErrUnknownPlugin
return nil, false, ErrUnknownPlugin
}
}
@ -315,11 +315,11 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
plugin = newPlugin
authResp, err := mc.auth(authData, plugin)
authResp, addNUL, err := mc.auth(authData, plugin)
if err != nil {
return err
}
if err = mc.writeAuthSwitchPacket(authResp); err != nil {
if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil {
return err
}
@ -352,7 +352,7 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
case cachingSha2PasswordPerformFullAuthentication:
if mc.cfg.tls != nil || mc.cfg.Net == "unix" {
// write cleartext auth packet
err = mc.writeAuthSwitchPacket(append([]byte(mc.cfg.Passwd), 0))
err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true)
if err != nil {
return err
}
@ -360,15 +360,13 @@ func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error {
pubKey := mc.cfg.pubKey
if pubKey == nil {
// request public key from server
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
return err
}
data := mc.buf.takeSmallBuffer(4 + 1)
data[4] = cachingSha2PasswordRequestPublicKey
mc.writePacket(data)
// parse public key
if data, err = mc.readPacket(); err != nil {
data, err := mc.readPacket()
if err != nil {
return err
}

File diff suppressed because it is too large Load diff

View file

@ -1,373 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"fmt"
"math"
"runtime"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
type TB testing.B
func (tb *TB) check(err error) {
if err != nil {
tb.Fatal(err)
}
}
func (tb *TB) checkDB(db *sql.DB, err error) *sql.DB {
tb.check(err)
return db
}
func (tb *TB) checkRows(rows *sql.Rows, err error) *sql.Rows {
tb.check(err)
return rows
}
func (tb *TB) checkStmt(stmt *sql.Stmt, err error) *sql.Stmt {
tb.check(err)
return stmt
}
func initDB(b *testing.B, queries ...string) *sql.DB {
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
for _, query := range queries {
if _, err := db.Exec(query); err != nil {
b.Fatalf("error on %q: %v", query, err)
}
}
return db
}
const concurrencyLevel = 10
func BenchmarkQuery(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
b.ReportAllocs()
db := initDB(b,
"DROP TABLE IF EXISTS foo",
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
`INSERT INTO foo VALUES (1, "one")`,
`INSERT INTO foo VALUES (2, "two")`,
)
db.SetMaxIdleConns(concurrencyLevel)
defer db.Close()
stmt := tb.checkStmt(db.Prepare("SELECT val FROM foo WHERE id=?"))
defer stmt.Close()
remain := int64(b.N)
var wg sync.WaitGroup
wg.Add(concurrencyLevel)
defer wg.Wait()
b.StartTimer()
for i := 0; i < concurrencyLevel; i++ {
go func() {
for {
if atomic.AddInt64(&remain, -1) < 0 {
wg.Done()
return
}
var got string
tb.check(stmt.QueryRow(1).Scan(&got))
if got != "one" {
b.Errorf("query = %q; want one", got)
wg.Done()
return
}
}
}()
}
}
func BenchmarkExec(b *testing.B) {
tb := (*TB)(b)
b.StopTimer()
b.ReportAllocs()
db := tb.checkDB(sql.Open("mysql", dsn))
db.SetMaxIdleConns(concurrencyLevel)
defer db.Close()
stmt := tb.checkStmt(db.Prepare("DO 1"))
defer stmt.Close()
remain := int64(b.N)
var wg sync.WaitGroup
wg.Add(concurrencyLevel)
defer wg.Wait()
b.StartTimer()
for i := 0; i < concurrencyLevel; i++ {
go func() {
for {
if atomic.AddInt64(&remain, -1) < 0 {
wg.Done()
return
}
if _, err := stmt.Exec(); err != nil {
b.Fatal(err.Error())
}
}
}()
}
}
// data, but no db writes
var roundtripSample []byte
func initRoundtripBenchmarks() ([]byte, int, int) {
if roundtripSample == nil {
roundtripSample = []byte(strings.Repeat("0123456789abcdef", 1024*1024))
}
return roundtripSample, 16, len(roundtripSample)
}
func BenchmarkRoundtripTxt(b *testing.B) {
b.StopTimer()
sample, min, max := initRoundtripBenchmarks()
sampleString := string(sample)
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
defer db.Close()
b.StartTimer()
var result string
for i := 0; i < b.N; i++ {
length := min + i
if length > max {
length = max
}
test := sampleString[0:length]
rows := tb.checkRows(db.Query(`SELECT "` + test + `"`))
if !rows.Next() {
rows.Close()
b.Fatalf("crashed")
}
err := rows.Scan(&result)
if err != nil {
rows.Close()
b.Fatalf("crashed")
}
if result != test {
rows.Close()
b.Errorf("mismatch")
}
rows.Close()
}
}
func BenchmarkRoundtripBin(b *testing.B) {
b.StopTimer()
sample, min, max := initRoundtripBenchmarks()
b.ReportAllocs()
tb := (*TB)(b)
db := tb.checkDB(sql.Open("mysql", dsn))
defer db.Close()
stmt := tb.checkStmt(db.Prepare("SELECT ?"))
defer stmt.Close()
b.StartTimer()
var result sql.RawBytes
for i := 0; i < b.N; i++ {
length := min + i
if length > max {
length = max
}
test := sample[0:length]
rows := tb.checkRows(stmt.Query(test))
if !rows.Next() {
rows.Close()
b.Fatalf("crashed")
}
err := rows.Scan(&result)
if err != nil {
rows.Close()
b.Fatalf("crashed")
}
if !bytes.Equal(result, test) {
rows.Close()
b.Errorf("mismatch")
}
rows.Close()
}
}
func BenchmarkInterpolation(b *testing.B) {
mc := &mysqlConn{
cfg: &Config{
InterpolateParams: true,
Loc: time.UTC,
},
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
buf: newBuffer(nil),
}
args := []driver.Value{
int64(42424242),
float64(math.Pi),
false,
time.Unix(1423411542, 807015000),
[]byte("bytes containing special chars ' \" \a \x00"),
"string containing special chars ' \" \a \x00",
}
q := "SELECT ?, ?, ?, ?, ?, ?"
b.ReportAllocs()
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, err := mc.interpolateParams(q, args)
if err != nil {
b.Fatal(err)
}
}
}
func benchmarkQueryContext(b *testing.B, db *sql.DB, p int) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
tb := (*TB)(b)
stmt := tb.checkStmt(db.PrepareContext(ctx, "SELECT val FROM foo WHERE id=?"))
defer stmt.Close()
b.SetParallelism(p)
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
var got string
for pb.Next() {
tb.check(stmt.QueryRow(1).Scan(&got))
if got != "one" {
b.Fatalf("query = %q; want one", got)
}
}
})
}
func BenchmarkQueryContext(b *testing.B) {
db := initDB(b,
"DROP TABLE IF EXISTS foo",
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
`INSERT INTO foo VALUES (1, "one")`,
`INSERT INTO foo VALUES (2, "two")`,
)
defer db.Close()
for _, p := range []int{1, 2, 3, 4} {
b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
benchmarkQueryContext(b, db, p)
})
}
}
func benchmarkExecContext(b *testing.B, db *sql.DB, p int) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
db.SetMaxIdleConns(p * runtime.GOMAXPROCS(0))
tb := (*TB)(b)
stmt := tb.checkStmt(db.PrepareContext(ctx, "DO 1"))
defer stmt.Close()
b.SetParallelism(p)
b.ReportAllocs()
b.ResetTimer()
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
if _, err := stmt.ExecContext(ctx); err != nil {
b.Fatal(err)
}
}
})
}
func BenchmarkExecContext(b *testing.B) {
db := initDB(b,
"DROP TABLE IF EXISTS foo",
"CREATE TABLE foo (id INT PRIMARY KEY, val CHAR(50))",
`INSERT INTO foo VALUES (1, "one")`,
`INSERT INTO foo VALUES (2, "two")`,
)
defer db.Close()
for _, p := range []int{1, 2, 3, 4} {
b.Run(fmt.Sprintf("%d", p), func(b *testing.B) {
benchmarkQueryContext(b, db, p)
})
}
}
// BenchmarkQueryRawBytes benchmarks fetching 100 blobs using sql.RawBytes.
// "size=" means size of each blobs.
func BenchmarkQueryRawBytes(b *testing.B) {
var sizes []int = []int{100, 1000, 2000, 4000, 8000, 12000, 16000, 32000, 64000, 256000}
db := initDB(b,
"DROP TABLE IF EXISTS bench_rawbytes",
"CREATE TABLE bench_rawbytes (id INT PRIMARY KEY, val LONGBLOB)",
)
defer db.Close()
blob := make([]byte, sizes[len(sizes)-1])
for i := range blob {
blob[i] = 42
}
for i := 0; i < 100; i++ {
_, err := db.Exec("INSERT INTO bench_rawbytes VALUES (?, ?)", i, blob)
if err != nil {
b.Fatal(err)
}
}
for _, s := range sizes {
b.Run(fmt.Sprintf("size=%v", s), func(b *testing.B) {
db.SetMaxIdleConns(0)
db.SetMaxIdleConns(1)
b.ReportAllocs()
b.ResetTimer()
for j := 0; j < b.N; j++ {
rows, err := db.Query("SELECT LEFT(val, ?) as v FROM bench_rawbytes", s)
if err != nil {
b.Fatal(err)
}
nrows := 0
for rows.Next() {
var buf sql.RawBytes
err := rows.Scan(&buf)
if err != nil {
b.Fatal(err)
}
if len(buf) != s {
b.Fatalf("size mismatch: expected %v, got %v", s, len(buf))
}
nrows++
}
rows.Close()
if nrows != 100 {
b.Fatalf("numbers of rows mismatch: expected %v, got %v", 100, nrows)
}
}
})
}
}

View file

@ -15,69 +15,47 @@ import (
)
const defaultBufSize = 4096
const maxCachedBufSize = 256 * 1024
// A buffer which is used for both reading and writing.
// This is possible since communication on each connection is synchronous.
// In other words, we can't write and read simultaneously on the same connection.
// The buffer is similar to bufio.Reader / Writer but zero-copy-ish
// Also highly optimized for this particular use case.
// This buffer is backed by two byte slices in a double-buffering scheme
type buffer struct {
buf []byte // buf is a byte buffer who's length and capacity are equal.
buf []byte
nc net.Conn
idx int
length int
timeout time.Duration
dbuf [2][]byte // dbuf is an array with the two byte slices that back this buffer
flipcnt uint // flipccnt is the current buffer counter for double-buffering
}
// newBuffer allocates and returns a new buffer.
func newBuffer(nc net.Conn) buffer {
fg := make([]byte, defaultBufSize)
var b [defaultBufSize]byte
return buffer{
buf: fg,
nc: nc,
dbuf: [2][]byte{fg, nil},
buf: b[:],
nc: nc,
}
}
// flip replaces the active buffer with the background buffer
// this is a delayed flip that simply increases the buffer counter;
// the actual flip will be performed the next time we call `buffer.fill`
func (b *buffer) flip() {
b.flipcnt += 1
}
// fill reads into the buffer until at least _need_ bytes are in it
func (b *buffer) fill(need int) error {
n := b.length
// fill data into its double-buffering target: if we've called
// flip on this buffer, we'll be copying to the background buffer,
// and then filling it with network data; otherwise we'll just move
// the contents of the current buffer to the front before filling it
dest := b.dbuf[b.flipcnt&1]
// grow buffer if necessary to fit the whole packet.
if need > len(dest) {
// move existing data to the beginning
if n > 0 && b.idx > 0 {
copy(b.buf[0:n], b.buf[b.idx:])
}
// grow buffer if necessary
// TODO: let the buffer shrink again at some point
// Maybe keep the org buf slice and swap back?
if need > len(b.buf) {
// Round up to the next multiple of the default size
dest = make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
// if the allocated buffer is not too large, move it to backing storage
// to prevent extra allocations on applications that perform large reads
if len(dest) <= maxCachedBufSize {
b.dbuf[b.flipcnt&1] = dest
}
newBuf := make([]byte, ((need/defaultBufSize)+1)*defaultBufSize)
copy(newBuf, b.buf)
b.buf = newBuf
}
// if we're filling the fg buffer, move the existing data to the start of it.
// if we're filling the bg buffer, copy over the data
if n > 0 {
copy(dest[:n], b.buf[b.idx:])
}
b.buf = dest
b.idx = 0
for {
@ -127,56 +105,43 @@ func (b *buffer) readNext(need int) ([]byte, error) {
return b.buf[offset:b.idx], nil
}
// takeBuffer returns a buffer with the requested size.
// returns a buffer with the requested size.
// If possible, a slice from the existing buffer is returned.
// Otherwise a bigger buffer is made.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeBuffer(length int) ([]byte, error) {
func (b *buffer) takeBuffer(length int) []byte {
if b.length > 0 {
return nil, ErrBusyBuffer
return nil
}
// test (cheap) general case first
if length <= cap(b.buf) {
return b.buf[:length], nil
if length <= defaultBufSize || length <= cap(b.buf) {
return b.buf[:length]
}
if length < maxPacketSize {
b.buf = make([]byte, length)
return b.buf, nil
return b.buf
}
// buffer is larger than we want to store.
return make([]byte, length), nil
return make([]byte, length)
}
// takeSmallBuffer is shortcut which can be used if length is
// known to be smaller than defaultBufSize.
// shortcut which can be used if the requested buffer is guaranteed to be
// smaller than defaultBufSize
// Only one buffer (total) can be used at a time.
func (b *buffer) takeSmallBuffer(length int) ([]byte, error) {
func (b *buffer) takeSmallBuffer(length int) []byte {
if b.length > 0 {
return nil, ErrBusyBuffer
return nil
}
return b.buf[:length], nil
return b.buf[:length]
}
// takeCompleteBuffer returns the complete existing buffer.
// This can be used if the necessary buffer size is unknown.
// cap and len of the returned buffer will be equal.
// Only one buffer (total) can be used at a time.
func (b *buffer) takeCompleteBuffer() ([]byte, error) {
func (b *buffer) takeCompleteBuffer() []byte {
if b.length > 0 {
return nil, ErrBusyBuffer
return nil
}
return b.buf, nil
}
// store stores buf, an updated buffer, if its suitable to do so.
func (b *buffer) store(buf []byte) error {
if b.length > 0 {
return ErrBusyBuffer
} else if cap(buf) <= maxPacketSize && cap(buf) > cap(b.buf) {
b.buf = buf[:cap(buf)]
}
return nil
return b.buf
}

View file

@ -8,190 +8,183 @@
package mysql
const defaultCollation = "utf8mb4_general_ci"
const defaultCollation = "utf8_general_ci"
const binaryCollation = "binary"
// A list of available collations mapped to the internal ID.
// To update this map use the following MySQL query:
// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS WHERE ID<256 ORDER BY ID
//
// Handshake packet have only 1 byte for collation_id. So we can't use collations with ID > 255.
//
// ucs2, utf16, and utf32 can't be used for connection charset.
// https://dev.mysql.com/doc/refman/5.7/en/charset-connection.html#charset-connection-impermissible-client-charset
// They are commented out to reduce this map.
// SELECT COLLATION_NAME, ID FROM information_schema.COLLATIONS
var collations = map[string]byte{
"big5_chinese_ci": 1,
"latin2_czech_cs": 2,
"dec8_swedish_ci": 3,
"cp850_general_ci": 4,
"latin1_german1_ci": 5,
"hp8_english_ci": 6,
"koi8r_general_ci": 7,
"latin1_swedish_ci": 8,
"latin2_general_ci": 9,
"swe7_swedish_ci": 10,
"ascii_general_ci": 11,
"ujis_japanese_ci": 12,
"sjis_japanese_ci": 13,
"cp1251_bulgarian_ci": 14,
"latin1_danish_ci": 15,
"hebrew_general_ci": 16,
"tis620_thai_ci": 18,
"euckr_korean_ci": 19,
"latin7_estonian_cs": 20,
"latin2_hungarian_ci": 21,
"koi8u_general_ci": 22,
"cp1251_ukrainian_ci": 23,
"gb2312_chinese_ci": 24,
"greek_general_ci": 25,
"cp1250_general_ci": 26,
"latin2_croatian_ci": 27,
"gbk_chinese_ci": 28,
"cp1257_lithuanian_ci": 29,
"latin5_turkish_ci": 30,
"latin1_german2_ci": 31,
"armscii8_general_ci": 32,
"utf8_general_ci": 33,
"cp1250_czech_cs": 34,
//"ucs2_general_ci": 35,
"cp866_general_ci": 36,
"keybcs2_general_ci": 37,
"macce_general_ci": 38,
"macroman_general_ci": 39,
"cp852_general_ci": 40,
"latin7_general_ci": 41,
"latin7_general_cs": 42,
"macce_bin": 43,
"cp1250_croatian_ci": 44,
"utf8mb4_general_ci": 45,
"utf8mb4_bin": 46,
"latin1_bin": 47,
"latin1_general_ci": 48,
"latin1_general_cs": 49,
"cp1251_bin": 50,
"cp1251_general_ci": 51,
"cp1251_general_cs": 52,
"macroman_bin": 53,
//"utf16_general_ci": 54,
//"utf16_bin": 55,
//"utf16le_general_ci": 56,
"cp1256_general_ci": 57,
"cp1257_bin": 58,
"cp1257_general_ci": 59,
//"utf32_general_ci": 60,
//"utf32_bin": 61,
//"utf16le_bin": 62,
"binary": 63,
"armscii8_bin": 64,
"ascii_bin": 65,
"cp1250_bin": 66,
"cp1256_bin": 67,
"cp866_bin": 68,
"dec8_bin": 69,
"greek_bin": 70,
"hebrew_bin": 71,
"hp8_bin": 72,
"keybcs2_bin": 73,
"koi8r_bin": 74,
"koi8u_bin": 75,
"utf8_tolower_ci": 76,
"latin2_bin": 77,
"latin5_bin": 78,
"latin7_bin": 79,
"cp850_bin": 80,
"cp852_bin": 81,
"swe7_bin": 82,
"utf8_bin": 83,
"big5_bin": 84,
"euckr_bin": 85,
"gb2312_bin": 86,
"gbk_bin": 87,
"sjis_bin": 88,
"tis620_bin": 89,
//"ucs2_bin": 90,
"ujis_bin": 91,
"geostd8_general_ci": 92,
"geostd8_bin": 93,
"latin1_spanish_ci": 94,
"cp932_japanese_ci": 95,
"cp932_bin": 96,
"eucjpms_japanese_ci": 97,
"eucjpms_bin": 98,
"cp1250_polish_ci": 99,
//"utf16_unicode_ci": 101,
//"utf16_icelandic_ci": 102,
//"utf16_latvian_ci": 103,
//"utf16_romanian_ci": 104,
//"utf16_slovenian_ci": 105,
//"utf16_polish_ci": 106,
//"utf16_estonian_ci": 107,
//"utf16_spanish_ci": 108,
//"utf16_swedish_ci": 109,
//"utf16_turkish_ci": 110,
//"utf16_czech_ci": 111,
//"utf16_danish_ci": 112,
//"utf16_lithuanian_ci": 113,
//"utf16_slovak_ci": 114,
//"utf16_spanish2_ci": 115,
//"utf16_roman_ci": 116,
//"utf16_persian_ci": 117,
//"utf16_esperanto_ci": 118,
//"utf16_hungarian_ci": 119,
//"utf16_sinhala_ci": 120,
//"utf16_german2_ci": 121,
//"utf16_croatian_ci": 122,
//"utf16_unicode_520_ci": 123,
//"utf16_vietnamese_ci": 124,
//"ucs2_unicode_ci": 128,
//"ucs2_icelandic_ci": 129,
//"ucs2_latvian_ci": 130,
//"ucs2_romanian_ci": 131,
//"ucs2_slovenian_ci": 132,
//"ucs2_polish_ci": 133,
//"ucs2_estonian_ci": 134,
//"ucs2_spanish_ci": 135,
//"ucs2_swedish_ci": 136,
//"ucs2_turkish_ci": 137,
//"ucs2_czech_ci": 138,
//"ucs2_danish_ci": 139,
//"ucs2_lithuanian_ci": 140,
//"ucs2_slovak_ci": 141,
//"ucs2_spanish2_ci": 142,
//"ucs2_roman_ci": 143,
//"ucs2_persian_ci": 144,
//"ucs2_esperanto_ci": 145,
//"ucs2_hungarian_ci": 146,
//"ucs2_sinhala_ci": 147,
//"ucs2_german2_ci": 148,
//"ucs2_croatian_ci": 149,
//"ucs2_unicode_520_ci": 150,
//"ucs2_vietnamese_ci": 151,
//"ucs2_general_mysql500_ci": 159,
//"utf32_unicode_ci": 160,
//"utf32_icelandic_ci": 161,
//"utf32_latvian_ci": 162,
//"utf32_romanian_ci": 163,
//"utf32_slovenian_ci": 164,
//"utf32_polish_ci": 165,
//"utf32_estonian_ci": 166,
//"utf32_spanish_ci": 167,
//"utf32_swedish_ci": 168,
//"utf32_turkish_ci": 169,
//"utf32_czech_ci": 170,
//"utf32_danish_ci": 171,
//"utf32_lithuanian_ci": 172,
//"utf32_slovak_ci": 173,
//"utf32_spanish2_ci": 174,
//"utf32_roman_ci": 175,
//"utf32_persian_ci": 176,
//"utf32_esperanto_ci": 177,
//"utf32_hungarian_ci": 178,
//"utf32_sinhala_ci": 179,
//"utf32_german2_ci": 180,
//"utf32_croatian_ci": 181,
//"utf32_unicode_520_ci": 182,
//"utf32_vietnamese_ci": 183,
"big5_chinese_ci": 1,
"latin2_czech_cs": 2,
"dec8_swedish_ci": 3,
"cp850_general_ci": 4,
"latin1_german1_ci": 5,
"hp8_english_ci": 6,
"koi8r_general_ci": 7,
"latin1_swedish_ci": 8,
"latin2_general_ci": 9,
"swe7_swedish_ci": 10,
"ascii_general_ci": 11,
"ujis_japanese_ci": 12,
"sjis_japanese_ci": 13,
"cp1251_bulgarian_ci": 14,
"latin1_danish_ci": 15,
"hebrew_general_ci": 16,
"tis620_thai_ci": 18,
"euckr_korean_ci": 19,
"latin7_estonian_cs": 20,
"latin2_hungarian_ci": 21,
"koi8u_general_ci": 22,
"cp1251_ukrainian_ci": 23,
"gb2312_chinese_ci": 24,
"greek_general_ci": 25,
"cp1250_general_ci": 26,
"latin2_croatian_ci": 27,
"gbk_chinese_ci": 28,
"cp1257_lithuanian_ci": 29,
"latin5_turkish_ci": 30,
"latin1_german2_ci": 31,
"armscii8_general_ci": 32,
"utf8_general_ci": 33,
"cp1250_czech_cs": 34,
"ucs2_general_ci": 35,
"cp866_general_ci": 36,
"keybcs2_general_ci": 37,
"macce_general_ci": 38,
"macroman_general_ci": 39,
"cp852_general_ci": 40,
"latin7_general_ci": 41,
"latin7_general_cs": 42,
"macce_bin": 43,
"cp1250_croatian_ci": 44,
"utf8mb4_general_ci": 45,
"utf8mb4_bin": 46,
"latin1_bin": 47,
"latin1_general_ci": 48,
"latin1_general_cs": 49,
"cp1251_bin": 50,
"cp1251_general_ci": 51,
"cp1251_general_cs": 52,
"macroman_bin": 53,
"utf16_general_ci": 54,
"utf16_bin": 55,
"utf16le_general_ci": 56,
"cp1256_general_ci": 57,
"cp1257_bin": 58,
"cp1257_general_ci": 59,
"utf32_general_ci": 60,
"utf32_bin": 61,
"utf16le_bin": 62,
"binary": 63,
"armscii8_bin": 64,
"ascii_bin": 65,
"cp1250_bin": 66,
"cp1256_bin": 67,
"cp866_bin": 68,
"dec8_bin": 69,
"greek_bin": 70,
"hebrew_bin": 71,
"hp8_bin": 72,
"keybcs2_bin": 73,
"koi8r_bin": 74,
"koi8u_bin": 75,
"latin2_bin": 77,
"latin5_bin": 78,
"latin7_bin": 79,
"cp850_bin": 80,
"cp852_bin": 81,
"swe7_bin": 82,
"utf8_bin": 83,
"big5_bin": 84,
"euckr_bin": 85,
"gb2312_bin": 86,
"gbk_bin": 87,
"sjis_bin": 88,
"tis620_bin": 89,
"ucs2_bin": 90,
"ujis_bin": 91,
"geostd8_general_ci": 92,
"geostd8_bin": 93,
"latin1_spanish_ci": 94,
"cp932_japanese_ci": 95,
"cp932_bin": 96,
"eucjpms_japanese_ci": 97,
"eucjpms_bin": 98,
"cp1250_polish_ci": 99,
"utf16_unicode_ci": 101,
"utf16_icelandic_ci": 102,
"utf16_latvian_ci": 103,
"utf16_romanian_ci": 104,
"utf16_slovenian_ci": 105,
"utf16_polish_ci": 106,
"utf16_estonian_ci": 107,
"utf16_spanish_ci": 108,
"utf16_swedish_ci": 109,
"utf16_turkish_ci": 110,
"utf16_czech_ci": 111,
"utf16_danish_ci": 112,
"utf16_lithuanian_ci": 113,
"utf16_slovak_ci": 114,
"utf16_spanish2_ci": 115,
"utf16_roman_ci": 116,
"utf16_persian_ci": 117,
"utf16_esperanto_ci": 118,
"utf16_hungarian_ci": 119,
"utf16_sinhala_ci": 120,
"utf16_german2_ci": 121,
"utf16_croatian_ci": 122,
"utf16_unicode_520_ci": 123,
"utf16_vietnamese_ci": 124,
"ucs2_unicode_ci": 128,
"ucs2_icelandic_ci": 129,
"ucs2_latvian_ci": 130,
"ucs2_romanian_ci": 131,
"ucs2_slovenian_ci": 132,
"ucs2_polish_ci": 133,
"ucs2_estonian_ci": 134,
"ucs2_spanish_ci": 135,
"ucs2_swedish_ci": 136,
"ucs2_turkish_ci": 137,
"ucs2_czech_ci": 138,
"ucs2_danish_ci": 139,
"ucs2_lithuanian_ci": 140,
"ucs2_slovak_ci": 141,
"ucs2_spanish2_ci": 142,
"ucs2_roman_ci": 143,
"ucs2_persian_ci": 144,
"ucs2_esperanto_ci": 145,
"ucs2_hungarian_ci": 146,
"ucs2_sinhala_ci": 147,
"ucs2_german2_ci": 148,
"ucs2_croatian_ci": 149,
"ucs2_unicode_520_ci": 150,
"ucs2_vietnamese_ci": 151,
"ucs2_general_mysql500_ci": 159,
"utf32_unicode_ci": 160,
"utf32_icelandic_ci": 161,
"utf32_latvian_ci": 162,
"utf32_romanian_ci": 163,
"utf32_slovenian_ci": 164,
"utf32_polish_ci": 165,
"utf32_estonian_ci": 166,
"utf32_spanish_ci": 167,
"utf32_swedish_ci": 168,
"utf32_turkish_ci": 169,
"utf32_czech_ci": 170,
"utf32_danish_ci": 171,
"utf32_lithuanian_ci": 172,
"utf32_slovak_ci": 173,
"utf32_spanish2_ci": 174,
"utf32_roman_ci": 175,
"utf32_persian_ci": 176,
"utf32_esperanto_ci": 177,
"utf32_hungarian_ci": 178,
"utf32_sinhala_ci": 179,
"utf32_german2_ci": 180,
"utf32_croatian_ci": 181,
"utf32_unicode_520_ci": 182,
"utf32_vietnamese_ci": 183,
"utf8_unicode_ci": 192,
"utf8_icelandic_ci": 193,
"utf8_latvian_ci": 194,
@ -241,25 +234,18 @@ var collations = map[string]byte{
"utf8mb4_croatian_ci": 245,
"utf8mb4_unicode_520_ci": 246,
"utf8mb4_vietnamese_ci": 247,
"gb18030_chinese_ci": 248,
"gb18030_bin": 249,
"gb18030_unicode_520_ci": 250,
"utf8mb4_0900_ai_ci": 255,
}
// A blacklist of collations which is unsafe to interpolate parameters.
// These multibyte encodings may contains 0x5c (`\`) in their trailing bytes.
var unsafeCollations = map[string]bool{
"big5_chinese_ci": true,
"sjis_japanese_ci": true,
"gbk_chinese_ci": true,
"big5_bin": true,
"gb2312_bin": true,
"gbk_bin": true,
"sjis_bin": true,
"cp932_japanese_ci": true,
"cp932_bin": true,
"gb18030_chinese_ci": true,
"gb18030_bin": true,
"gb18030_unicode_520_ci": true,
"big5_chinese_ci": true,
"sjis_japanese_ci": true,
"gbk_chinese_ci": true,
"big5_bin": true,
"gb2312_bin": true,
"gbk_bin": true,
"sjis_bin": true,
"cp932_japanese_ci": true,
"cp932_bin": true,
}

View file

@ -1,53 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build !windows,!appengine
package mysql
import (
"errors"
"io"
"net"
"syscall"
)
var errUnexpectedRead = errors.New("unexpected read from socket")
func connCheck(c net.Conn) error {
var (
n int
err error
buff [1]byte
)
sconn, ok := c.(syscall.Conn)
if !ok {
return nil
}
rc, err := sconn.SyscallConn()
if err != nil {
return err
}
rerr := rc.Read(func(fd uintptr) bool {
n, err = syscall.Read(int(fd), buff[:])
return true
})
switch {
case rerr != nil:
return rerr
case n == 0 && err == nil:
return io.EOF
case n > 0:
return errUnexpectedRead
case err == syscall.EAGAIN || err == syscall.EWOULDBLOCK:
return nil
default:
return err
}
}

View file

@ -1,17 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2019 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build windows appengine
package mysql
import "net"
func connCheck(c net.Conn) error {
return nil
}

View file

@ -1,38 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build go1.10,!windows
package mysql
import (
"testing"
"time"
)
func TestStaleConnectionChecks(t *testing.T) {
runTests(t, dsn, func(dbt *DBTest) {
dbt.mustExec("SET @@SESSION.wait_timeout = 2")
if err := dbt.db.Ping(); err != nil {
dbt.Fatal(err)
}
// wait for MySQL to close our connection
time.Sleep(3 * time.Second)
tx, err := dbt.db.Begin()
if err != nil {
dbt.Fatal(err)
}
if err := tx.Rollback(); err != nil {
dbt.Fatal(err)
}
})
}

View file

@ -9,8 +9,6 @@
package mysql
import (
"context"
"database/sql"
"database/sql/driver"
"io"
"net"
@ -19,10 +17,19 @@ import (
"time"
)
// a copy of context.Context for Go 1.7 and earlier
type mysqlContext interface {
Done() <-chan struct{}
Err() error
// defined in context.Context, but not used in this driver:
// Deadline() (deadline time.Time, ok bool)
// Value(key interface{}) interface{}
}
type mysqlConn struct {
buf buffer
netConn net.Conn
rawConn net.Conn // underlying connection when netConn is TLS connection.
affectedRows uint64
insertId uint64
cfg *Config
@ -33,11 +40,10 @@ type mysqlConn struct {
status statusFlag
sequence uint8
parseTime bool
reset bool // set when the Go SQL package calls ResetSession
// for context support (Go 1.8+)
watching bool
watcher chan<- context.Context
watcher chan<- mysqlContext
closech chan struct{}
finished chan<- struct{}
canceled atomicError // set non-nil if conn is canceled
@ -184,10 +190,10 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
return "", driver.ErrSkip
}
buf, err := mc.buf.takeCompleteBuffer()
if err != nil {
buf := mc.buf.takeCompleteBuffer()
if buf == nil {
// can not take the buffer. Something must be wrong with the connection
errLog.Print(err)
errLog.Print(ErrBusyBuffer)
return "", ErrInvalidConn
}
buf = buf[:0]
@ -213,9 +219,6 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin
switch v := arg.(type) {
case int64:
buf = strconv.AppendInt(buf, v, 10)
case uint64:
// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
buf = strconv.AppendUint(buf, v, 10)
case float64:
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
case bool:
@ -456,194 +459,3 @@ func (mc *mysqlConn) finish() {
case <-mc.closech:
}
}
// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
if err = mc.watchCancel(ctx); err != nil {
return
}
defer mc.finish()
if err = mc.writeCommandPacket(comPing); err != nil {
return mc.markBadConn(err)
}
return mc.readResultOK()
}
// BeginTx implements driver.ConnBeginTx interface
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()
if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
level, err := mapIsolationLevel(opts.Isolation)
if err != nil {
return nil, err
}
err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
if err != nil {
return nil, err
}
}
return mc.begin(opts.ReadOnly)
}
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
rows, err := mc.query(query, dargs)
if err != nil {
mc.finish()
return nil, err
}
rows.finish = mc.finish
return rows, err
}
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()
return mc.Exec(query, dargs)
}
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
stmt, err := mc.Prepare(query)
mc.finish()
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
stmt.Close()
return nil, ctx.Err()
}
return stmt, nil
}
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := stmt.mc.watchCancel(ctx); err != nil {
return nil, err
}
rows, err := stmt.query(dargs)
if err != nil {
stmt.mc.finish()
return nil, err
}
rows.finish = stmt.mc.finish
return rows, err
}
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := stmt.mc.watchCancel(ctx); err != nil {
return nil, err
}
defer stmt.mc.finish()
return stmt.Exec(dargs)
}
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
if mc.watching {
// Reach here if canceled,
// so the connection is already invalid
mc.cleanup()
return nil
}
// When ctx is already cancelled, don't watch it.
if err := ctx.Err(); err != nil {
return err
}
// When ctx is not cancellable, don't watch it.
if ctx.Done() == nil {
return nil
}
// When watcher is not alive, can't watch it.
if mc.watcher == nil {
return nil
}
mc.watching = true
mc.watcher <- ctx
return nil
}
func (mc *mysqlConn) startWatcher() {
watcher := make(chan context.Context, 1)
mc.watcher = watcher
finished := make(chan struct{})
mc.finished = finished
go func() {
for {
var ctx context.Context
select {
case ctx = <-watcher:
case <-mc.closech:
return
}
select {
case <-ctx.Done():
mc.cancel(ctx.Err())
case <-finished:
case <-mc.closech:
return
}
}
}()
}
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = converter{}.ConvertValue(nv.Value)
return
}
// ResetSession implements driver.SessionResetter.
// (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.IsSet() {
return driver.ErrBadConn
}
mc.reset = true
return nil
}

View file

@ -0,0 +1,208 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build go1.8
package mysql
import (
"context"
"database/sql"
"database/sql/driver"
)
// Ping implements driver.Pinger interface
func (mc *mysqlConn) Ping(ctx context.Context) (err error) {
if mc.closed.IsSet() {
errLog.Print(ErrInvalidConn)
return driver.ErrBadConn
}
if err = mc.watchCancel(ctx); err != nil {
return
}
defer mc.finish()
if err = mc.writeCommandPacket(comPing); err != nil {
return
}
return mc.readResultOK()
}
// BeginTx implements driver.ConnBeginTx interface
func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) {
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()
if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault {
level, err := mapIsolationLevel(opts.Isolation)
if err != nil {
return nil, err
}
err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level)
if err != nil {
return nil, err
}
}
return mc.begin(opts.ReadOnly)
}
func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
rows, err := mc.query(query, dargs)
if err != nil {
mc.finish()
return nil, err
}
rows.finish = mc.finish
return rows, err
}
func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()
return mc.Exec(query, dargs)
}
func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
stmt, err := mc.Prepare(query)
mc.finish()
if err != nil {
return nil, err
}
select {
default:
case <-ctx.Done():
stmt.Close()
return nil, ctx.Err()
}
return stmt, nil
}
func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := stmt.mc.watchCancel(ctx); err != nil {
return nil, err
}
rows, err := stmt.query(dargs)
if err != nil {
stmt.mc.finish()
return nil, err
}
rows.finish = stmt.mc.finish
return rows, err
}
func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
dargs, err := namedValueToValue(args)
if err != nil {
return nil, err
}
if err := stmt.mc.watchCancel(ctx); err != nil {
return nil, err
}
defer stmt.mc.finish()
return stmt.Exec(dargs)
}
func (mc *mysqlConn) watchCancel(ctx context.Context) error {
if mc.watching {
// Reach here if canceled,
// so the connection is already invalid
mc.cleanup()
return nil
}
if ctx.Done() == nil {
return nil
}
mc.watching = true
select {
default:
case <-ctx.Done():
return ctx.Err()
}
if mc.watcher == nil {
return nil
}
mc.watcher <- ctx
return nil
}
func (mc *mysqlConn) startWatcher() {
watcher := make(chan mysqlContext, 1)
mc.watcher = watcher
finished := make(chan struct{})
mc.finished = finished
go func() {
for {
var ctx mysqlContext
select {
case ctx = <-watcher:
case <-mc.closech:
return
}
select {
case <-ctx.Done():
mc.cancel(ctx.Err())
case <-finished:
case <-mc.closech:
return
}
}
}()
}
func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = converter{}.ConvertValue(nv.Value)
return
}
// ResetSession implements driver.SessionResetter.
// (From Go 1.10)
func (mc *mysqlConn) ResetSession(ctx context.Context) error {
if mc.closed.IsSet() {
return driver.ErrBadConn
}
return nil
}

View file

@ -1,175 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"context"
"database/sql/driver"
"errors"
"net"
"testing"
)
func TestInterpolateParams(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42), "gopher"})
if err != nil {
t.Errorf("Expected err=nil, got %#v", err)
return
}
expected := `SELECT 42+'gopher'`
if q != expected {
t.Errorf("Expected: %q\nGot: %q", expected, q)
}
}
func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}
q, err := mc.interpolateParams("SELECT ?+?", []driver.Value{int64(42)})
if err != driver.ErrSkip {
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
}
}
// We don't support placeholder in string literal for now.
// https://github.com/go-sql-driver/mysql/pull/490
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}
q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
if err != driver.ErrSkip {
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
}
}
func TestInterpolateParamsUint64(t *testing.T) {
mc := &mysqlConn{
buf: newBuffer(nil),
maxAllowedPacket: maxPacketSize,
cfg: &Config{
InterpolateParams: true,
},
}
q, err := mc.interpolateParams("SELECT ?", []driver.Value{uint64(42)})
if err != nil {
t.Errorf("Expected err=nil, got err=%#v, q=%#v", err, q)
}
if q != "SELECT 42" {
t.Errorf("Expected uint64 interpolation to work, got q=%#v", q)
}
}
func TestCheckNamedValue(t *testing.T) {
value := driver.NamedValue{Value: ^uint64(0)}
x := &mysqlConn{}
err := x.CheckNamedValue(&value)
if err != nil {
t.Fatal("uint64 high-bit not convertible", err)
}
if value.Value != ^uint64(0) {
t.Fatalf("uint64 high-bit converted, got %#v %T", value.Value, value.Value)
}
}
// TestCleanCancel tests passed context is cancelled at start.
// No packet should be sent. Connection should keep current status.
func TestCleanCancel(t *testing.T) {
mc := &mysqlConn{
closech: make(chan struct{}),
}
mc.startWatcher()
defer mc.cleanup()
ctx, cancel := context.WithCancel(context.Background())
cancel()
for i := 0; i < 3; i++ { // Repeat same behavior
err := mc.Ping(ctx)
if err != context.Canceled {
t.Errorf("expected context.Canceled, got %#v", err)
}
if mc.closed.IsSet() {
t.Error("expected mc is not closed, closed actually")
}
if mc.watching {
t.Error("expected watching is false, but true")
}
}
}
func TestPingMarkBadConnection(t *testing.T) {
nc := badConnection{err: errors.New("boom")}
ms := &mysqlConn{
netConn: nc,
buf: newBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
}
err := ms.Ping(context.Background())
if err != driver.ErrBadConn {
t.Errorf("expected driver.ErrBadConn, got %#v", err)
}
}
func TestPingErrInvalidConn(t *testing.T) {
nc := badConnection{err: errors.New("failed to write"), n: 10}
ms := &mysqlConn{
netConn: nc,
buf: newBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
}
err := ms.Ping(context.Background())
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %#v", err)
}
}
type badConnection struct {
n int
err error
net.Conn
}
func (bc badConnection) Write(b []byte) (n int, err error) {
return bc.n, bc.err
}
func (bc badConnection) Close() error {
return nil
}

View file

@ -1,143 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"context"
"database/sql/driver"
"net"
)
type connector struct {
cfg *Config // immutable private copy.
}
// Connect implements driver.Connector interface.
// Connect returns a connection to the database.
func (c *connector) Connect(ctx context.Context) (driver.Conn, error) {
var err error
// New mysqlConn
mc := &mysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
cfg: c.cfg,
}
mc.parseTime = mc.cfg.ParseTime
// Connect to Server
dialsLock.RLock()
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
mc.netConn, err = dial(ctx, mc.cfg.Addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.DialContext(ctx, mc.cfg.Net, mc.cfg.Addr)
}
if err != nil {
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
errLog.Print("net.Error from Dial()': ", nerr.Error())
return nil, driver.ErrBadConn
}
return nil, err
}
// Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.TCPConn); ok {
if err := tc.SetKeepAlive(true); err != nil {
// Don't send COM_QUIT before handshake.
mc.netConn.Close()
mc.netConn = nil
return nil, err
}
}
// Call startWatcher for context support (From Go 1.8)
mc.startWatcher()
if err := mc.watchCancel(ctx); err != nil {
return nil, err
}
defer mc.finish()
mc.buf = newBuffer(mc.netConn)
// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout
// Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket()
if err != nil {
mc.cleanup()
return nil, err
}
if plugin == "" {
plugin = defaultAuthPlugin
}
// Send Client Authentication Packet
authResp, err := mc.auth(authData, plugin)
if err != nil {
// try the default auth plugin, if using the requested plugin failed
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
plugin = defaultAuthPlugin
authResp, err = mc.auth(authData, plugin)
if err != nil {
mc.cleanup()
return nil, err
}
}
if err = mc.writeHandshakeResponsePacket(authResp, plugin); err != nil {
mc.cleanup()
return nil, err
}
// Handle response to auth packet, switch methods if possible
if err = mc.handleAuthResult(authData, plugin); err != nil {
// Authentication failed and MySQL has already closed the connection
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
// Do not send COM_QUIT, just cleanup and return the error.
mc.cleanup()
return nil, err
}
if mc.cfg.MaxAllowedPacket > 0 {
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
} else {
// Get max allowed packet size
maxap, err := mc.getSystemVar("max_allowed_packet")
if err != nil {
mc.Close()
return nil, err
}
mc.maxAllowedPacket = stringToInt(maxap) - 1
}
if mc.maxAllowedPacket < maxPacketSize {
mc.maxWriteSize = mc.maxAllowedPacket
}
// Handle DSN Params
err = mc.handleParams()
if err != nil {
mc.Close()
return nil, err
}
return mc, nil
}
// Driver implements driver.Connector interface.
// Driver returns &MySQLDriver{}.
func (c *connector) Driver() driver.Driver {
return &MySQLDriver{}
}

View file

@ -17,67 +17,151 @@
package mysql
import (
"context"
"database/sql"
"database/sql/driver"
"net"
"sync"
)
// watcher interface is used for context support (From Go 1.8)
type watcher interface {
startWatcher()
}
// MySQLDriver is exported to make the driver directly accessible.
// In general the driver is used via the database/sql package.
type MySQLDriver struct{}
// DialFunc is a function which can be used to establish the network connection.
// Custom dial functions must be registered with RegisterDial
//
// Deprecated: users should register a DialContextFunc instead
type DialFunc func(addr string) (net.Conn, error)
// DialContextFunc is a function which can be used to establish the network connection.
// Custom dial functions must be registered with RegisterDialContext
type DialContextFunc func(ctx context.Context, addr string) (net.Conn, error)
var (
dialsLock sync.RWMutex
dials map[string]DialContextFunc
dials map[string]DialFunc
)
// RegisterDialContext registers a custom dial function. It can then be used by the
// network address mynet(addr), where mynet is the registered new network.
// The current context for the connection and its address is passed to the dial function.
func RegisterDialContext(net string, dial DialContextFunc) {
dialsLock.Lock()
defer dialsLock.Unlock()
if dials == nil {
dials = make(map[string]DialContextFunc)
}
dials[net] = dial
}
// RegisterDial registers a custom dial function. It can then be used by the
// network address mynet(addr), where mynet is the registered new network.
// addr is passed as a parameter to the dial function.
//
// Deprecated: users should call RegisterDialContext instead
func RegisterDial(network string, dial DialFunc) {
RegisterDialContext(network, func(_ context.Context, addr string) (net.Conn, error) {
return dial(addr)
})
func RegisterDial(net string, dial DialFunc) {
dialsLock.Lock()
defer dialsLock.Unlock()
if dials == nil {
dials = make(map[string]DialFunc)
}
dials[net] = dial
}
// Open new Connection.
// See https://github.com/go-sql-driver/mysql#dsn-data-source-name for how
// the DSN string is formatted
// the DSN string is formated
func (d MySQLDriver) Open(dsn string) (driver.Conn, error) {
cfg, err := ParseDSN(dsn)
var err error
// New mysqlConn
mc := &mysqlConn{
maxAllowedPacket: maxPacketSize,
maxWriteSize: maxPacketSize - 1,
closech: make(chan struct{}),
}
mc.cfg, err = ParseDSN(dsn)
if err != nil {
return nil, err
}
c := &connector{
cfg: cfg,
mc.parseTime = mc.cfg.ParseTime
// Connect to Server
dialsLock.RLock()
dial, ok := dials[mc.cfg.Net]
dialsLock.RUnlock()
if ok {
mc.netConn, err = dial(mc.cfg.Addr)
} else {
nd := net.Dialer{Timeout: mc.cfg.Timeout}
mc.netConn, err = nd.Dial(mc.cfg.Net, mc.cfg.Addr)
}
return c.Connect(context.Background())
if err != nil {
return nil, err
}
// Enable TCP Keepalives on TCP connections
if tc, ok := mc.netConn.(*net.TCPConn); ok {
if err := tc.SetKeepAlive(true); err != nil {
// Don't send COM_QUIT before handshake.
mc.netConn.Close()
mc.netConn = nil
return nil, err
}
}
// Call startWatcher for context support (From Go 1.8)
if s, ok := interface{}(mc).(watcher); ok {
s.startWatcher()
}
mc.buf = newBuffer(mc.netConn)
// Set I/O timeouts
mc.buf.timeout = mc.cfg.ReadTimeout
mc.writeTimeout = mc.cfg.WriteTimeout
// Reading Handshake Initialization Packet
authData, plugin, err := mc.readHandshakePacket()
if err != nil {
mc.cleanup()
return nil, err
}
// Send Client Authentication Packet
authResp, addNUL, err := mc.auth(authData, plugin)
if err != nil {
// try the default auth plugin, if using the requested plugin failed
errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error())
plugin = defaultAuthPlugin
authResp, addNUL, err = mc.auth(authData, plugin)
if err != nil {
mc.cleanup()
return nil, err
}
}
if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil {
mc.cleanup()
return nil, err
}
// Handle response to auth packet, switch methods if possible
if err = mc.handleAuthResult(authData, plugin); err != nil {
// Authentication failed and MySQL has already closed the connection
// (https://dev.mysql.com/doc/internals/en/authentication-fails.html).
// Do not send COM_QUIT, just cleanup and return the error.
mc.cleanup()
return nil, err
}
if mc.cfg.MaxAllowedPacket > 0 {
mc.maxAllowedPacket = mc.cfg.MaxAllowedPacket
} else {
// Get max allowed packet size
maxap, err := mc.getSystemVar("max_allowed_packet")
if err != nil {
mc.Close()
return nil, err
}
mc.maxAllowedPacket = stringToInt(maxap) - 1
}
if mc.maxAllowedPacket < maxPacketSize {
mc.maxWriteSize = mc.maxAllowedPacket
}
// Handle DSN Params
err = mc.handleParams()
if err != nil {
mc.Close()
return nil, err
}
return mc, nil
}
func init() {

View file

@ -1,37 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build go1.10
package mysql
import (
"database/sql/driver"
)
// NewConnector returns new driver.Connector.
func NewConnector(cfg *Config) (driver.Connector, error) {
cfg = cfg.Clone()
// normalize the contents of cfg so calls to NewConnector have the same
// behavior as MySQLDriver.OpenConnector
if err := cfg.normalize(); err != nil {
return nil, err
}
return &connector{cfg: cfg}, nil
}
// OpenConnector implements driver.DriverContext.
func (d MySQLDriver) OpenConnector(dsn string) (driver.Connector, error) {
cfg, err := ParseDSN(dsn)
if err != nil {
return nil, err
}
return &connector{
cfg: cfg,
}, nil
}

View file

@ -1,137 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build go1.10
package mysql
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"net"
"testing"
"time"
)
var _ driver.DriverContext = &MySQLDriver{}
type dialCtxKey struct{}
func TestConnectorObeysDialTimeouts(t *testing.T) {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}
RegisterDialContext("dialctxtest", func(ctx context.Context, addr string) (net.Conn, error) {
var d net.Dialer
if !ctx.Value(dialCtxKey{}).(bool) {
return nil, fmt.Errorf("test error: query context is not propagated to our dialer")
}
return d.DialContext(ctx, prot, addr)
})
db, err := sql.Open("mysql", fmt.Sprintf("%s:%s@dialctxtest(%s)/%s?timeout=30s", user, pass, addr, dbname))
if err != nil {
t.Fatalf("error connecting: %s", err.Error())
}
defer db.Close()
ctx := context.WithValue(context.Background(), dialCtxKey{}, true)
_, err = db.ExecContext(ctx, "DO 1")
if err != nil {
t.Fatal(err)
}
}
func configForTests(t *testing.T) *Config {
if !available {
t.Skipf("MySQL server not running on %s", netAddr)
}
mycnf := NewConfig()
mycnf.User = user
mycnf.Passwd = pass
mycnf.Addr = addr
mycnf.Net = prot
mycnf.DBName = dbname
return mycnf
}
func TestNewConnector(t *testing.T) {
mycnf := configForTests(t)
conn, err := NewConnector(mycnf)
if err != nil {
t.Fatal(err)
}
db := sql.OpenDB(conn)
defer db.Close()
if err := db.Ping(); err != nil {
t.Fatal(err)
}
}
type slowConnection struct {
net.Conn
slowdown time.Duration
}
func (sc *slowConnection) Read(b []byte) (int, error) {
time.Sleep(sc.slowdown)
return sc.Conn.Read(b)
}
type connectorHijack struct {
driver.Connector
connErr error
}
func (cw *connectorHijack) Connect(ctx context.Context) (driver.Conn, error) {
var conn driver.Conn
conn, cw.connErr = cw.Connector.Connect(ctx)
return conn, cw.connErr
}
func TestConnectorTimeoutsDuringOpen(t *testing.T) {
RegisterDialContext("slowconn", func(ctx context.Context, addr string) (net.Conn, error) {
var d net.Dialer
conn, err := d.DialContext(ctx, prot, addr)
if err != nil {
return nil, err
}
return &slowConnection{Conn: conn, slowdown: 100 * time.Millisecond}, nil
})
mycnf := configForTests(t)
mycnf.Net = "slowconn"
conn, err := NewConnector(mycnf)
if err != nil {
t.Fatal(err)
}
hijack := &connectorHijack{Connector: conn}
db := sql.OpenDB(hijack)
defer db.Close()
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, err = db.ExecContext(ctx, "DO 1")
if err != context.DeadlineExceeded {
t.Fatalf("ExecContext should have timed out")
}
if hijack.connErr != context.DeadlineExceeded {
t.Fatalf("(*Connector).Connect should have timed out")
}
}

File diff suppressed because it is too large Load diff

View file

@ -14,7 +14,6 @@ import (
"crypto/tls"
"errors"
"fmt"
"math/big"
"net"
"net/url"
"sort"
@ -73,26 +72,6 @@ func NewConfig() *Config {
}
}
func (cfg *Config) Clone() *Config {
cp := *cfg
if cp.tls != nil {
cp.tls = cfg.tls.Clone()
}
if len(cp.Params) > 0 {
cp.Params = make(map[string]string, len(cfg.Params))
for k, v := range cfg.Params {
cp.Params[k] = v
}
}
if cfg.pubKey != nil {
cp.pubKey = &rsa.PublicKey{
N: new(big.Int).Set(cfg.pubKey.N),
E: cfg.pubKey.E,
}
}
return &cp
}
func (cfg *Config) normalize() error {
if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
return errInvalidDSNUnsafeCollation
@ -113,35 +92,17 @@ func (cfg *Config) normalize() error {
default:
return errors.New("default addr for network '" + cfg.Net + "' unknown")
}
} else if cfg.Net == "tcp" {
cfg.Addr = ensureHavePort(cfg.Addr)
}
switch cfg.TLSConfig {
case "false", "":
// don't set anything
case "true":
cfg.tls = &tls.Config{}
case "skip-verify", "preferred":
cfg.tls = &tls.Config{InsecureSkipVerify: true}
default:
cfg.tls = getTLSConfigClone(cfg.TLSConfig)
if cfg.tls == nil {
return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
}
}
if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
host, _, err := net.SplitHostPort(cfg.Addr)
if err == nil {
cfg.tls.ServerName = host
}
}
if cfg.ServerPubKey != "" {
cfg.pubKey = getServerPubKey(cfg.ServerPubKey)
if cfg.pubKey == nil {
return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey)
if cfg.tls != nil {
if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
host, _, err := net.SplitHostPort(cfg.Addr)
if err == nil {
cfg.tls.ServerName = host
}
}
}
@ -570,7 +531,13 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if err != nil {
return fmt.Errorf("invalid value for server pub key name: %v", err)
}
cfg.ServerPubKey = name
if pubKey := getServerPubKey(name); pubKey != nil {
cfg.ServerPubKey = name
cfg.pubKey = pubKey
} else {
return errors.New("invalid value / unknown server pub key name: " + name)
}
// Strict mode
case "strict":
@ -589,17 +556,25 @@ func parseDSNParams(cfg *Config, params string) (err error) {
if isBool {
if boolValue {
cfg.TLSConfig = "true"
cfg.tls = &tls.Config{}
} else {
cfg.TLSConfig = "false"
}
} else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
} else if vl := strings.ToLower(value); vl == "skip-verify" {
cfg.TLSConfig = vl
cfg.tls = &tls.Config{InsecureSkipVerify: true}
} else {
name, err := url.QueryUnescape(value)
if err != nil {
return fmt.Errorf("invalid value for TLS config name: %v", err)
}
cfg.TLSConfig = name
if tlsConfig := getTLSConfigClone(name); tlsConfig != nil {
cfg.TLSConfig = name
cfg.tls = tlsConfig
} else {
return errors.New("invalid value / unknown config name: " + name)
}
}
// I/O write Timeout

View file

@ -1,415 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"crypto/tls"
"fmt"
"net/url"
"reflect"
"testing"
"time"
)
var testDSNs = []struct {
in string
out *Config
}{{
"username:password@protocol(address)/dbname?param=value",
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true",
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true},
}, {
"username:password@protocol(address)/dbname?param=value&columnsWithAlias=true&multiStatements=true",
&Config{User: "username", Passwd: "password", Net: "protocol", Addr: "address", DBName: "dbname", Params: map[string]string{"param": "value"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, ColumnsWithAlias: true, MultiStatements: true},
}, {
"user@unix(/path/to/socket)/dbname?charset=utf8",
&Config{User: "user", Net: "unix", Addr: "/path/to/socket", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"user:password@tcp(localhost:5555)/dbname?charset=utf8&tls=true",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "true"},
}, {
"user:password@tcp(localhost:5555)/dbname?charset=utf8mb4,utf8&tls=skip-verify",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "localhost:5555", DBName: "dbname", Params: map[string]string{"charset": "utf8mb4,utf8"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true, TLSConfig: "skip-verify"},
}, {
"user:password@/dbname?loc=UTC&timeout=30s&readTimeout=1s&writeTimeout=1s&allowAllFiles=1&clientFoundRows=true&allowOldPasswords=TRUE&collation=utf8mb4_unicode_ci&maxAllowedPacket=16777216&tls=false&allowCleartextPasswords=true&parseTime=true&rejectReadOnly=true",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_unicode_ci", Loc: time.UTC, TLSConfig: "false", AllowCleartextPasswords: true, AllowNativePasswords: true, Timeout: 30 * time.Second, ReadTimeout: time.Second, WriteTimeout: time.Second, AllowAllFiles: true, AllowOldPasswords: true, ClientFoundRows: true, MaxAllowedPacket: 16777216, ParseTime: true, RejectReadOnly: true},
}, {
"user:password@/dbname?allowNativePasswords=false&maxAllowedPacket=0",
&Config{User: "user", Passwd: "password", Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: 0, AllowNativePasswords: false},
}, {
"user:p@ss(word)@tcp([de:ad:be:ef::ca:fe]:80)/dbname?loc=Local",
&Config{User: "user", Passwd: "p@ss(word)", Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:80", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.Local, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"/dbname",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"@/",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"/",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"user:p@/ssword@/",
&Config{User: "user", Passwd: "p@/ssword", Net: "tcp", Addr: "127.0.0.1:3306", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"unix/?arg=%2Fsome%2Fpath.ext",
&Config{Net: "unix", Addr: "/tmp/mysql.sock", Params: map[string]string{"arg": "/some/path.ext"}, Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"tcp(127.0.0.1)/dbname",
&Config{Net: "tcp", Addr: "127.0.0.1:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
}, {
"tcp(de:ad:be:ef::ca:fe)/dbname",
&Config{Net: "tcp", Addr: "[de:ad:be:ef::ca:fe]:3306", DBName: "dbname", Collation: "utf8mb4_general_ci", Loc: time.UTC, MaxAllowedPacket: defaultMaxAllowedPacket, AllowNativePasswords: true},
},
}
func TestDSNParser(t *testing.T) {
for i, tst := range testDSNs {
cfg, err := ParseDSN(tst.in)
if err != nil {
t.Error(err.Error())
}
// pointer not static
cfg.tls = nil
if !reflect.DeepEqual(cfg, tst.out) {
t.Errorf("%d. ParseDSN(%q) mismatch:\ngot %+v\nwant %+v", i, tst.in, cfg, tst.out)
}
}
}
func TestDSNParserInvalid(t *testing.T) {
var invalidDSNs = []string{
"@net(addr/", // no closing brace
"@tcp(/", // no closing brace
"tcp(/", // no closing brace
"(/", // no closing brace
"net(addr)//", // unescaped
"User:pass@tcp(1.2.3.4:3306)", // no trailing slash
"net()/", // unknown default addr
//"/dbname?arg=/some/unescaped/path",
}
for i, tst := range invalidDSNs {
if _, err := ParseDSN(tst); err == nil {
t.Errorf("invalid DSN #%d. (%s) didn't error!", i, tst)
}
}
}
func TestDSNReformat(t *testing.T) {
for i, tst := range testDSNs {
dsn1 := tst.in
cfg1, err := ParseDSN(dsn1)
if err != nil {
t.Error(err.Error())
continue
}
cfg1.tls = nil // pointer not static
res1 := fmt.Sprintf("%+v", cfg1)
dsn2 := cfg1.FormatDSN()
cfg2, err := ParseDSN(dsn2)
if err != nil {
t.Error(err.Error())
continue
}
cfg2.tls = nil // pointer not static
res2 := fmt.Sprintf("%+v", cfg2)
if res1 != res2 {
t.Errorf("%d. %q does not match %q", i, res2, res1)
}
}
}
func TestDSNServerPubKey(t *testing.T) {
baseDSN := "User:password@tcp(localhost:5555)/dbname?serverPubKey="
RegisterServerPubKey("testKey", testPubKeyRSA)
defer DeregisterServerPubKey("testKey")
tst := baseDSN + "testKey"
cfg, err := ParseDSN(tst)
if err != nil {
t.Error(err.Error())
}
if cfg.ServerPubKey != "testKey" {
t.Errorf("unexpected cfg.ServerPubKey value: %v", cfg.ServerPubKey)
}
if cfg.pubKey != testPubKeyRSA {
t.Error("pub key pointer doesn't match")
}
// Key is missing
tst = baseDSN + "invalid_name"
cfg, err = ParseDSN(tst)
if err == nil {
t.Errorf("invalid name in DSN (%s) but did not error. Got config: %#v", tst, cfg)
}
}
func TestDSNServerPubKeyQueryEscape(t *testing.T) {
const name = "&%!:"
dsn := "User:password@tcp(localhost:5555)/dbname?serverPubKey=" + url.QueryEscape(name)
RegisterServerPubKey(name, testPubKeyRSA)
defer DeregisterServerPubKey(name)
cfg, err := ParseDSN(dsn)
if err != nil {
t.Error(err.Error())
}
if cfg.pubKey != testPubKeyRSA {
t.Error("pub key pointer doesn't match")
}
}
func TestDSNWithCustomTLS(t *testing.T) {
baseDSN := "User:password@tcp(localhost:5555)/dbname?tls="
tlsCfg := tls.Config{}
RegisterTLSConfig("utils_test", &tlsCfg)
defer DeregisterTLSConfig("utils_test")
// Custom TLS is missing
tst := baseDSN + "invalid_tls"
cfg, err := ParseDSN(tst)
if err == nil {
t.Errorf("invalid custom TLS in DSN (%s) but did not error. Got config: %#v", tst, cfg)
}
tst = baseDSN + "utils_test"
// Custom TLS with a server name
name := "foohost"
tlsCfg.ServerName = name
cfg, err = ParseDSN(tst)
if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, tst)
}
// Custom TLS without a server name
name = "localhost"
tlsCfg.ServerName = ""
cfg, err = ParseDSN(tst)
if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
t.Errorf("did not get the correct ServerName (%s) parsing DSN (%s).", name, tst)
} else if tlsCfg.ServerName != "" {
t.Errorf("tlsCfg was mutated ServerName (%s) should be empty parsing DSN (%s).", name, tst)
}
}
func TestDSNTLSConfig(t *testing.T) {
expectedServerName := "example.com"
dsn := "tcp(example.com:1234)/?tls=true"
cfg, err := ParseDSN(dsn)
if err != nil {
t.Error(err.Error())
}
if cfg.tls == nil {
t.Error("cfg.tls should not be nil")
}
if cfg.tls.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
}
dsn = "tcp(example.com)/?tls=true"
cfg, err = ParseDSN(dsn)
if err != nil {
t.Error(err.Error())
}
if cfg.tls == nil {
t.Error("cfg.tls should not be nil")
}
if cfg.tls.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host without port)", expectedServerName, cfg.tls.ServerName)
}
}
func TestDSNWithCustomTLSQueryEscape(t *testing.T) {
const configKey = "&%!:"
dsn := "User:password@tcp(localhost:5555)/dbname?tls=" + url.QueryEscape(configKey)
name := "foohost"
tlsCfg := tls.Config{ServerName: name}
RegisterTLSConfig(configKey, &tlsCfg)
defer DeregisterTLSConfig(configKey)
cfg, err := ParseDSN(dsn)
if err != nil {
t.Error(err.Error())
} else if cfg.tls.ServerName != name {
t.Errorf("did not get the correct TLS ServerName (%s) parsing DSN (%s).", name, dsn)
}
}
func TestDSNUnsafeCollation(t *testing.T) {
_, err := ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=true")
if err != errInvalidDSNUnsafeCollation {
t.Errorf("expected %v, got %v", errInvalidDSNUnsafeCollation, err)
}
_, err = ParseDSN("/dbname?collation=gbk_chinese_ci&interpolateParams=false")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=gbk_chinese_ci")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=ascii_bin&interpolateParams=true")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=latin1_german1_ci&interpolateParams=true")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=utf8_general_ci&interpolateParams=true")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
_, err = ParseDSN("/dbname?collation=utf8mb4_general_ci&interpolateParams=true")
if err != nil {
t.Errorf("expected %v, got %v", nil, err)
}
}
func TestParamsAreSorted(t *testing.T) {
expected := "/dbname?interpolateParams=true&foobar=baz&quux=loo"
cfg := NewConfig()
cfg.DBName = "dbname"
cfg.InterpolateParams = true
cfg.Params = map[string]string{
"quux": "loo",
"foobar": "baz",
}
actual := cfg.FormatDSN()
if actual != expected {
t.Errorf("generic Config.Params were not sorted: want %#v, got %#v", expected, actual)
}
}
func TestCloneConfig(t *testing.T) {
RegisterServerPubKey("testKey", testPubKeyRSA)
defer DeregisterServerPubKey("testKey")
expectedServerName := "example.com"
dsn := "tcp(example.com:1234)/?tls=true&foobar=baz&serverPubKey=testKey"
cfg, err := ParseDSN(dsn)
if err != nil {
t.Fatal(err.Error())
}
cfg2 := cfg.Clone()
if cfg == cfg2 {
t.Errorf("Config.Clone did not create a separate config struct")
}
if cfg2.tls.ServerName != expectedServerName {
t.Errorf("cfg.tls.ServerName should be %q, got %q (host with port)", expectedServerName, cfg.tls.ServerName)
}
cfg2.tls.ServerName = "example2.com"
if cfg.tls.ServerName == cfg2.tls.ServerName {
t.Errorf("changed cfg.tls.Server name should not propagate to original Config")
}
if _, ok := cfg2.Params["foobar"]; !ok {
t.Errorf("cloned Config is missing custom params")
}
delete(cfg2.Params, "foobar")
if _, ok := cfg.Params["foobar"]; !ok {
t.Errorf("custom params in cloned Config should not propagate to original Config")
}
if !reflect.DeepEqual(cfg.pubKey, cfg2.pubKey) {
t.Errorf("public key in Config should be identical")
}
}
func TestNormalizeTLSConfig(t *testing.T) {
tt := []struct {
tlsConfig string
want *tls.Config
}{
{"", nil},
{"false", nil},
{"true", &tls.Config{ServerName: "myserver"}},
{"skip-verify", &tls.Config{InsecureSkipVerify: true}},
{"preferred", &tls.Config{InsecureSkipVerify: true}},
{"test_tls_config", &tls.Config{ServerName: "myServerName"}},
}
RegisterTLSConfig("test_tls_config", &tls.Config{ServerName: "myServerName"})
defer func() { DeregisterTLSConfig("test_tls_config") }()
for _, tc := range tt {
t.Run(tc.tlsConfig, func(t *testing.T) {
cfg := &Config{
Addr: "myserver:3306",
TLSConfig: tc.tlsConfig,
}
cfg.normalize()
if cfg.tls == nil {
if tc.want != nil {
t.Fatal("wanted a tls config but got nil instead")
}
return
}
if cfg.tls.ServerName != tc.want.ServerName {
t.Errorf("tls.ServerName doesn't match (want: '%s', got: '%s')",
tc.want.ServerName, cfg.tls.ServerName)
}
if cfg.tls.InsecureSkipVerify != tc.want.InsecureSkipVerify {
t.Errorf("tls.InsecureSkipVerify doesn't match (want: %T, got :%T)",
tc.want.InsecureSkipVerify, cfg.tls.InsecureSkipVerify)
}
})
}
}
func BenchmarkParseDSN(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
for _, tst := range testDSNs {
if _, err := ParseDSN(tst.in); err != nil {
b.Error(err.Error())
}
}
}
}

View file

@ -1,42 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"log"
"testing"
)
func TestErrorsSetLogger(t *testing.T) {
previous := errLog
defer func() {
errLog = previous
}()
// set up logger
const expected = "prefix: test\n"
buffer := bytes.NewBuffer(make([]byte, 0, 64))
logger := log.New(buffer, "prefix: ", 0)
// print
SetLogger(logger)
errLog.Print("test")
// check result
if actual := buffer.String(); actual != expected {
t.Errorf("expected %q, got %q", expected, actual)
}
}
func TestErrorsStrictIgnoreNotes(t *testing.T) {
runTests(t, dsn+"&sql_notes=false", func(dbt *DBTest) {
dbt.mustExec("DROP TABLE IF EXISTS does_not_exist")
})
}

View file

@ -51,7 +51,7 @@ func (mc *mysqlConn) readPacket() ([]byte, error) {
mc.sequence++
// packets with length 0 terminate a previous packet which is a
// multiple of (2^24)-1 bytes long
// multiple of (2^24)1 bytes long
if pktLen == 0 {
// there was no previous packet
if prevData == nil {
@ -96,35 +96,6 @@ func (mc *mysqlConn) writePacket(data []byte) error {
return ErrPktTooLarge
}
// Perform a stale connection check. We only perform this check for
// the first query on a connection that has been checked out of the
// connection pool: a fresh connection from the pool is more likely
// to be stale, and it has not performed any previous writes that
// could cause data corruption, so it's safe to return ErrBadConn
// if the check fails.
if mc.reset {
mc.reset = false
conn := mc.netConn
if mc.rawConn != nil {
conn = mc.rawConn
}
var err error
// If this connection has a ReadTimeout which we've been setting on
// reads, reset it to its default value before we attempt a non-blocking
// read, otherwise the scheduler will just time us out before we can read
if mc.cfg.ReadTimeout != 0 {
err = conn.SetReadDeadline(time.Time{})
}
if err == nil {
err = connCheck(conn)
}
if err != nil {
errLog.Print("closing bad idle connection: ", err)
mc.Close()
return driver.ErrBadConn
}
}
for {
var size int
if pktLen >= maxPacketSize {
@ -183,15 +154,15 @@ func (mc *mysqlConn) writePacket(data []byte) error {
// Handshake Initialization Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err error) {
data, err = mc.readPacket()
func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) {
data, err := mc.readPacket()
if err != nil {
// for init we can rewrite this to ErrBadConn for sql.Driver to retry, since
// in connection initialization we don't risk retrying non-idempotent actions.
if err == ErrInvalidConn {
return nil, "", driver.ErrBadConn
}
return
return nil, "", err
}
if data[0] == iERR {
@ -223,14 +194,11 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
return nil, "", ErrOldProtocol
}
if mc.flags&clientSSL == 0 && mc.cfg.tls != nil {
if mc.cfg.TLSConfig == "preferred" {
mc.cfg.tls = nil
} else {
return nil, "", ErrNoTLS
}
return nil, "", ErrNoTLS
}
pos += 2
plugin := ""
if len(data) > pos {
// character set [1 byte]
// status flags [2 bytes]
@ -268,6 +236,8 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
return b[:], plugin, nil
}
plugin = defaultAuthPlugin
// make a memory safe copy of the cipher slice
var b [8]byte
copy(b[:], authData)
@ -276,7 +246,7 @@ func (mc *mysqlConn) readHandshakePacket() (data []byte, plugin string, err erro
// Client Authentication Packet
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string) error {
func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error {
// Adjust client flags based on server support
clientFlags := clientProtocol41 |
clientSecureConn |
@ -302,8 +272,7 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// encode length of the auth plugin data
var authRespLEIBuf [9]byte
authRespLen := len(authResp)
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(authRespLen))
authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp)))
if len(authRespLEI) > 1 {
// if the length can not be written in 1 byte, it must be written as a
// length encoded integer
@ -311,6 +280,9 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
}
pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1
if addNUL {
pktLen++
}
// To specify a db name
if n := len(mc.cfg.DBName); n > 0 {
@ -319,10 +291,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
}
// Calculate packet length and get buffer with that size
data, err := mc.buf.takeSmallBuffer(pktLen + 4)
if err != nil {
data := mc.buf.takeSmallBuffer(pktLen + 4)
if data == nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(err)
errLog.Print(ErrBusyBuffer)
return errBadConnNoWrite
}
@ -361,7 +333,6 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
if err := tlsConn.Handshake(); err != nil {
return err
}
mc.rawConn = mc.netConn
mc.netConn = tlsConn
mc.buf.nc = tlsConn
}
@ -382,6 +353,10 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
// Auth Data [length encoded integer]
pos += copy(data[pos:], authRespLEI)
pos += copy(data[pos:], authResp)
if addNUL {
data[pos] = 0x00
pos++
}
// Databasename [null terminated string]
if len(mc.cfg.DBName) > 0 {
@ -392,24 +367,30 @@ func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, plugin string
pos += copy(data[pos:], plugin)
data[pos] = 0x00
pos++
// Send Auth packet
return mc.writePacket(data[:pos])
return mc.writePacket(data)
}
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte) error {
func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error {
pktLen := 4 + len(authData)
data, err := mc.buf.takeSmallBuffer(pktLen)
if err != nil {
if addNUL {
pktLen++
}
data := mc.buf.takeSmallBuffer(pktLen)
if data == nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(err)
errLog.Print(ErrBusyBuffer)
return errBadConnNoWrite
}
// Add the auth data [EOF]
copy(data[4:], authData)
if addNUL {
data[pktLen-1] = 0x00
}
return mc.writePacket(data)
}
@ -421,10 +402,10 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error {
// Reset Packet Sequence
mc.sequence = 0
data, err := mc.buf.takeSmallBuffer(4 + 1)
if err != nil {
data := mc.buf.takeSmallBuffer(4 + 1)
if data == nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(err)
errLog.Print(ErrBusyBuffer)
return errBadConnNoWrite
}
@ -440,10 +421,10 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error {
mc.sequence = 0
pktLen := 1 + len(arg)
data, err := mc.buf.takeBuffer(pktLen + 4)
if err != nil {
data := mc.buf.takeBuffer(pktLen + 4)
if data == nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(err)
errLog.Print(ErrBusyBuffer)
return errBadConnNoWrite
}
@ -461,10 +442,10 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error {
// Reset Packet Sequence
mc.sequence = 0
data, err := mc.buf.takeSmallBuffer(4 + 1 + 4)
if err != nil {
data := mc.buf.takeSmallBuffer(4 + 1 + 4)
if data == nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(err)
errLog.Print(ErrBusyBuffer)
return errBadConnNoWrite
}
@ -501,7 +482,7 @@ func (mc *mysqlConn) readAuthResult() ([]byte, string, error) {
return data[1:], "", err
case iEOF:
if len(data) == 1 {
if len(data) < 1 {
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest
return nil, "mysql_old_password", nil
}
@ -917,7 +898,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
const minPktLen = 4 + 1 + 4 + 1 + 4
mc := stmt.mc
// Determine threshold dynamically to avoid packet size shortage.
// Determine threshould dynamically to avoid packet size shortage.
longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1)
if longDataSize < 64 {
longDataSize = 64
@ -927,17 +908,15 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
mc.sequence = 0
var data []byte
var err error
if len(args) == 0 {
data, err = mc.buf.takeBuffer(minPktLen)
data = mc.buf.takeBuffer(minPktLen)
} else {
data, err = mc.buf.takeCompleteBuffer()
// In this case the len(data) == cap(data) which is used to optimise the flow below.
data = mc.buf.takeCompleteBuffer()
}
if err != nil {
if data == nil {
// cannot take the buffer. Something must be wrong with the connection
errLog.Print(err)
errLog.Print(ErrBusyBuffer)
return errBadConnNoWrite
}
@ -963,7 +942,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
pos := minPktLen
var nullMask []byte
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= cap(data) {
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
// buffer has to be extended but we don't know by how much so
// we depend on append after all data with known sizes fit.
// We stop at that because we deal with a lot of columns here
@ -972,11 +951,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
copy(tmp[:pos], data[:pos])
data = tmp
nullMask = data[pos : pos+maskLen]
// No need to clean nullMask as make ensures that.
pos += maskLen
} else {
nullMask = data[pos : pos+maskLen]
for i := range nullMask {
for i := 0; i < maskLen; i++ {
nullMask[i] = 0
}
pos += maskLen
@ -1021,22 +999,6 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
)
}
case uint64:
paramTypes[i+i] = byte(fieldTypeLongLong)
paramTypes[i+i+1] = 0x80 // type is unsigned
if cap(paramValues)-len(paramValues)-8 >= 0 {
paramValues = paramValues[:len(paramValues)+8]
binary.LittleEndian.PutUint64(
paramValues[len(paramValues)-8:],
uint64(v),
)
} else {
paramValues = append(paramValues,
uint64ToBytes(uint64(v))...,
)
}
case float64:
paramTypes[i+i] = byte(fieldTypeDouble)
paramTypes[i+i+1] = 0x00
@ -1129,10 +1091,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
// In that case we must build the data packet with the new values buffer
if valuesCap != cap(paramValues) {
data = append(data[:pos], paramValues...)
if err = mc.buf.store(data); err != nil {
errLog.Print(err)
return errBadConnNoWrite
}
mc.buf.buf = data
}
pos += len(paramValues)
@ -1302,7 +1261,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
rows.rs.columns[i].decimals,
)
}
dest[i], err = formatBinaryTime(data[pos:pos+int(num)], dstlen)
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true)
case rows.mc.parseTime:
dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc)
default:
@ -1322,7 +1281,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error {
)
}
}
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen)
dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, false)
}
if err == nil {

View file

@ -1,336 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"errors"
"net"
"testing"
"time"
)
var (
errConnClosed = errors.New("connection is closed")
errConnTooManyReads = errors.New("too many reads")
errConnTooManyWrites = errors.New("too many writes")
)
// struct to mock a net.Conn for testing purposes
type mockConn struct {
laddr net.Addr
raddr net.Addr
data []byte
written []byte
queuedReplies [][]byte
closed bool
read int
reads int
writes int
maxReads int
maxWrites int
}
func (m *mockConn) Read(b []byte) (n int, err error) {
if m.closed {
return 0, errConnClosed
}
m.reads++
if m.maxReads > 0 && m.reads > m.maxReads {
return 0, errConnTooManyReads
}
n = copy(b, m.data)
m.read += n
m.data = m.data[n:]
return
}
func (m *mockConn) Write(b []byte) (n int, err error) {
if m.closed {
return 0, errConnClosed
}
m.writes++
if m.maxWrites > 0 && m.writes > m.maxWrites {
return 0, errConnTooManyWrites
}
n = len(b)
m.written = append(m.written, b...)
if n > 0 && len(m.queuedReplies) > 0 {
m.data = m.queuedReplies[0]
m.queuedReplies = m.queuedReplies[1:]
}
return
}
func (m *mockConn) Close() error {
m.closed = true
return nil
}
func (m *mockConn) LocalAddr() net.Addr {
return m.laddr
}
func (m *mockConn) RemoteAddr() net.Addr {
return m.raddr
}
func (m *mockConn) SetDeadline(t time.Time) error {
return nil
}
func (m *mockConn) SetReadDeadline(t time.Time) error {
return nil
}
func (m *mockConn) SetWriteDeadline(t time.Time) error {
return nil
}
// make sure mockConn implements the net.Conn interface
var _ net.Conn = new(mockConn)
func newRWMockConn(sequence uint8) (*mockConn, *mysqlConn) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
cfg: NewConfig(),
netConn: conn,
closech: make(chan struct{}),
maxAllowedPacket: defaultMaxAllowedPacket,
sequence: sequence,
}
return conn, mc
}
func TestReadPacketSingleByte(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
}
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
conn.maxReads = 1
packet, err := mc.readPacket()
if err != nil {
t.Fatal(err)
}
if len(packet) != 1 {
t.Fatalf("unexpected packet length: expected %d, got %d", 1, len(packet))
}
if packet[0] != 0xff {
t.Fatalf("unexpected packet content: expected %x, got %x", 0xff, packet[0])
}
}
func TestReadPacketWrongSequenceID(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
}
// too low sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x00, 0xff}
conn.maxReads = 1
mc.sequence = 1
_, err := mc.readPacket()
if err != ErrPktSync {
t.Errorf("expected ErrPktSync, got %v", err)
}
// reset
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
// too high sequence id
conn.data = []byte{0x01, 0x00, 0x00, 0x42, 0xff}
_, err = mc.readPacket()
if err != ErrPktSyncMul {
t.Errorf("expected ErrPktSyncMul, got %v", err)
}
}
func TestReadPacketSplit(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
}
data := make([]byte, maxPacketSize*2+4*3)
const pkt2ofs = maxPacketSize + 4
const pkt3ofs = 2 * (maxPacketSize + 4)
// case 1: payload has length maxPacketSize
data = data[:pkt2ofs+4]
// 1st packet has maxPacketSize length and sequence id 0
// ff ff ff 00 ...
data[0] = 0xff
data[1] = 0xff
data[2] = 0xff
// mark the payload start and end of 1st packet so that we can check if the
// content was correctly appended
data[4] = 0x11
data[maxPacketSize+3] = 0x22
// 2nd packet has payload length 0 and squence id 1
// 00 00 00 01
data[pkt2ofs+3] = 0x01
conn.data = data
conn.maxReads = 3
packet, err := mc.readPacket()
if err != nil {
t.Fatal(err)
}
if len(packet) != maxPacketSize {
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize, len(packet))
}
if packet[0] != 0x11 {
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
}
if packet[maxPacketSize-1] != 0x22 {
t.Fatalf("unexpected payload end: expected %x, got %x", 0x22, packet[maxPacketSize-1])
}
// case 2: payload has length which is a multiple of maxPacketSize
data = data[:cap(data)]
// 2nd packet now has maxPacketSize length
data[pkt2ofs] = 0xff
data[pkt2ofs+1] = 0xff
data[pkt2ofs+2] = 0xff
// mark the payload start and end of the 2nd packet
data[pkt2ofs+4] = 0x33
data[pkt2ofs+maxPacketSize+3] = 0x44
// 3rd packet has payload length 0 and squence id 2
// 00 00 00 02
data[pkt3ofs+3] = 0x02
conn.data = data
conn.reads = 0
conn.maxReads = 5
mc.sequence = 0
packet, err = mc.readPacket()
if err != nil {
t.Fatal(err)
}
if len(packet) != 2*maxPacketSize {
t.Fatalf("unexpected packet length: expected %d, got %d", 2*maxPacketSize, len(packet))
}
if packet[0] != 0x11 {
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
}
if packet[2*maxPacketSize-1] != 0x44 {
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[2*maxPacketSize-1])
}
// case 3: payload has a length larger maxPacketSize, which is not an exact
// multiple of it
data = data[:pkt2ofs+4+42]
data[pkt2ofs] = 0x2a
data[pkt2ofs+1] = 0x00
data[pkt2ofs+2] = 0x00
data[pkt2ofs+4+41] = 0x44
conn.data = data
conn.reads = 0
conn.maxReads = 4
mc.sequence = 0
packet, err = mc.readPacket()
if err != nil {
t.Fatal(err)
}
if len(packet) != maxPacketSize+42 {
t.Fatalf("unexpected packet length: expected %d, got %d", maxPacketSize+42, len(packet))
}
if packet[0] != 0x11 {
t.Fatalf("unexpected payload start: expected %x, got %x", 0x11, packet[0])
}
if packet[maxPacketSize+41] != 0x44 {
t.Fatalf("unexpected payload end: expected %x, got %x", 0x44, packet[maxPacketSize+41])
}
}
func TestReadPacketFail(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
closech: make(chan struct{}),
}
// illegal empty (stand-alone) packet
conn.data = []byte{0x00, 0x00, 0x00, 0x00}
conn.maxReads = 1
_, err := mc.readPacket()
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %v", err)
}
// reset
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
// fail to read header
conn.closed = true
_, err = mc.readPacket()
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %v", err)
}
// reset
conn.closed = false
conn.reads = 0
mc.sequence = 0
mc.buf = newBuffer(conn)
// fail to read body
conn.maxReads = 1
_, err = mc.readPacket()
if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %v", err)
}
}
// https://github.com/go-sql-driver/mysql/pull/801
// not-NUL terminated plugin_name in init packet
func TestRegression801(t *testing.T) {
conn := new(mockConn)
mc := &mysqlConn{
buf: newBuffer(conn),
cfg: new(Config),
sequence: 42,
closech: make(chan struct{}),
}
conn.data = []byte{72, 0, 0, 42, 10, 53, 46, 53, 46, 56, 0, 165, 0, 0, 0,
60, 70, 63, 58, 68, 104, 34, 97, 0, 223, 247, 33, 2, 0, 15, 128, 21, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 98, 120, 114, 47, 85, 75, 109, 99, 51, 77,
50, 64, 0, 109, 121, 115, 113, 108, 95, 110, 97, 116, 105, 118, 101, 95,
112, 97, 115, 115, 119, 111, 114, 100}
conn.maxReads = 1
authData, pluginName, err := mc.readHandshakePacket()
if err != nil {
t.Fatalf("got error: %v", err)
}
if pluginName != "mysql_native_password" {
t.Errorf("expected plugin name 'mysql_native_password', got '%s'", pluginName)
}
expectedAuthData := []byte{60, 70, 63, 58, 68, 104, 34, 97, 98, 120, 114,
47, 85, 75, 109, 99, 51, 77, 50, 64}
if !bytes.Equal(authData, expectedAuthData) {
t.Errorf("expected authData '%v', got '%v'", expectedAuthData, authData)
}
}

View file

@ -111,13 +111,6 @@ func (rows *mysqlRows) Close() (err error) {
return err
}
// flip the buffer for this connection if we need to drain it.
// note that for a successful query (i.e. one where rows.next()
// has been called until it returns false), `rows.mc` will be nil
// by the time the user calls `(*Rows).Close`, so we won't reach this
// see: https://github.com/golang/go/commit/651ddbdb5056ded455f47f9c494c67b389622a47
mc.buf.flip()
// Remove unread packets from stream
if !rows.rs.done {
err = mc.readUntilEOF()

View file

@ -13,6 +13,7 @@ import (
"fmt"
"io"
"reflect"
"strconv"
)
type mysqlStmt struct {
@ -163,8 +164,14 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) {
}
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return rv.Int(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return rv.Uint(), nil
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
return int64(rv.Uint()), nil
case reflect.Uint64:
u64 := rv.Uint()
if u64 >= 1<<63 {
return strconv.FormatUint(u64, 10), nil
}
return int64(u64), nil
case reflect.Float32, reflect.Float64:
return rv.Float(), nil
case reflect.Bool:

View file

@ -1,126 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"testing"
)
func TestConvertDerivedString(t *testing.T) {
type derived string
output, err := converter{}.ConvertValue(derived("value"))
if err != nil {
t.Fatal("Derived string type not convertible", err)
}
if output != "value" {
t.Fatalf("Derived string type not converted, got %#v %T", output, output)
}
}
func TestConvertDerivedByteSlice(t *testing.T) {
type derived []uint8
output, err := converter{}.ConvertValue(derived("value"))
if err != nil {
t.Fatal("Byte slice not convertible", err)
}
if bytes.Compare(output.([]byte), []byte("value")) != 0 {
t.Fatalf("Byte slice not converted, got %#v %T", output, output)
}
}
func TestConvertDerivedUnsupportedSlice(t *testing.T) {
type derived []int
_, err := converter{}.ConvertValue(derived{1})
if err == nil || err.Error() != "unsupported type mysql.derived, a slice of int" {
t.Fatal("Unexpected error", err)
}
}
func TestConvertDerivedBool(t *testing.T) {
type derived bool
output, err := converter{}.ConvertValue(derived(true))
if err != nil {
t.Fatal("Derived bool type not convertible", err)
}
if output != true {
t.Fatalf("Derived bool type not converted, got %#v %T", output, output)
}
}
func TestConvertPointer(t *testing.T) {
str := "value"
output, err := converter{}.ConvertValue(&str)
if err != nil {
t.Fatal("Pointer type not convertible", err)
}
if output != "value" {
t.Fatalf("Pointer type not converted, got %#v %T", output, output)
}
}
func TestConvertSignedIntegers(t *testing.T) {
values := []interface{}{
int8(-42),
int16(-42),
int32(-42),
int64(-42),
int(-42),
}
for _, value := range values {
output, err := converter{}.ConvertValue(value)
if err != nil {
t.Fatalf("%T type not convertible %s", value, err)
}
if output != int64(-42) {
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
}
}
}
func TestConvertUnsignedIntegers(t *testing.T) {
values := []interface{}{
uint8(42),
uint16(42),
uint32(42),
uint64(42),
uint(42),
}
for _, value := range values {
output, err := converter{}.ConvertValue(value)
if err != nil {
t.Fatalf("%T type not convertible %s", value, err)
}
if output != uint64(42) {
t.Fatalf("%T type not converted, got %#v %T", value, output, output)
}
}
output, err := converter{}.ConvertValue(^uint64(0))
if err != nil {
t.Fatal("uint64 high-bit not convertible", err)
}
if output != ^uint64(0) {
t.Fatalf("uint64 high-bit converted, got %#v %T", output, output)
}
}

View file

@ -10,13 +10,10 @@ package mysql
import (
"crypto/tls"
"database/sql"
"database/sql/driver"
"encoding/binary"
"errors"
"fmt"
"io"
"strconv"
"strings"
"sync"
"sync/atomic"
@ -56,7 +53,7 @@ var (
// db, err := sql.Open("mysql", "user@tcp(localhost:3306)/test?tls=custom")
//
func RegisterTLSConfig(key string, config *tls.Config) error {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" || strings.ToLower(key) == "preferred" {
if _, isBool := readBool(key); isBool || strings.ToLower(key) == "skip-verify" {
return fmt.Errorf("key '%s' is reserved", key)
}
@ -82,7 +79,7 @@ func DeregisterTLSConfig(key string) {
func getTLSConfigClone(key string) (config *tls.Config) {
tlsConfigLock.RLock()
if v, ok := tlsConfigRegistry[key]; ok {
config = v.Clone()
config = cloneTLSConfig(v)
}
tlsConfigLock.RUnlock()
return
@ -230,104 +227,87 @@ var zeroDateTime = []byte("0000-00-00 00:00:00.000000")
const digits01 = "0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789"
const digits10 = "0000000000111111111122222222223333333333444444444455555555556666666666777777777788888888889999999999"
func appendMicrosecs(dst, src []byte, decimals int) []byte {
if decimals <= 0 {
return dst
}
if len(src) == 0 {
return append(dst, ".000000"[:decimals+1]...)
}
microsecs := binary.LittleEndian.Uint32(src[:4])
p1 := byte(microsecs / 10000)
microsecs -= 10000 * uint32(p1)
p2 := byte(microsecs / 100)
microsecs -= 100 * uint32(p2)
p3 := byte(microsecs)
switch decimals {
default:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2], digits01[p2],
digits10[p3], digits01[p3],
)
case 1:
return append(dst, '.',
digits10[p1],
)
case 2:
return append(dst, '.',
digits10[p1], digits01[p1],
)
case 3:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2],
)
case 4:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2], digits01[p2],
)
case 5:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2], digits01[p2],
digits10[p3],
)
}
}
func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) {
func formatBinaryDateTime(src []byte, length uint8, justTime bool) (driver.Value, error) {
// length expects the deterministic length of the zero value,
// negative time and 100+ hours are automatically added if needed
if len(src) == 0 {
if justTime {
return zeroDateTime[11 : 11+length], nil
}
return zeroDateTime[:length], nil
}
var dst []byte // return value
var p1, p2, p3 byte // current digit pair
switch length {
case 10, 19, 21, 22, 23, 24, 25, 26:
default:
t := "DATE"
if length > 10 {
t += "TIME"
var dst []byte // return value
var pt, p1, p2, p3 byte // current digit pair
var zOffs byte // offset of value in zeroDateTime
if justTime {
switch length {
case
8, // time (can be up to 10 when negative and 100+ hours)
10, 11, 12, 13, 14, 15: // time with fractional seconds
default:
return nil, fmt.Errorf("illegal TIME length %d", length)
}
return nil, fmt.Errorf("illegal %s length %d", t, length)
}
switch len(src) {
case 4, 7, 11:
default:
t := "DATE"
if length > 10 {
t += "TIME"
switch len(src) {
case 8, 12:
default:
return nil, fmt.Errorf("invalid TIME packet length %d", len(src))
}
return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
// +2 to enable negative time and 100+ hours
dst = make([]byte, 0, length+2)
if src[0] == 1 {
dst = append(dst, '-')
}
if src[1] != 0 {
hour := uint16(src[1])*24 + uint16(src[5])
pt = byte(hour / 100)
p1 = byte(hour - 100*uint16(pt))
dst = append(dst, digits01[pt])
} else {
p1 = src[5]
}
zOffs = 11
src = src[6:]
} else {
switch length {
case 10, 19, 21, 22, 23, 24, 25, 26:
default:
t := "DATE"
if length > 10 {
t += "TIME"
}
return nil, fmt.Errorf("illegal %s length %d", t, length)
}
switch len(src) {
case 4, 7, 11:
default:
t := "DATE"
if length > 10 {
t += "TIME"
}
return nil, fmt.Errorf("illegal %s packet length %d", t, len(src))
}
dst = make([]byte, 0, length)
// start with the date
year := binary.LittleEndian.Uint16(src[:2])
pt = byte(year / 100)
p1 = byte(year - 100*uint16(pt))
p2, p3 = src[2], src[3]
dst = append(dst,
digits10[pt], digits01[pt],
digits10[p1], digits01[p1], '-',
digits10[p2], digits01[p2], '-',
digits10[p3], digits01[p3],
)
if length == 10 {
return dst, nil
}
if len(src) == 4 {
return append(dst, zeroDateTime[10:length]...), nil
}
dst = append(dst, ' ')
p1 = src[4] // hour
src = src[5:]
}
dst = make([]byte, 0, length)
// start with the date
year := binary.LittleEndian.Uint16(src[:2])
pt := year / 100
p1 = byte(year - 100*uint16(pt))
p2, p3 = src[2], src[3]
dst = append(dst,
digits10[pt], digits01[pt],
digits10[p1], digits01[p1], '-',
digits10[p2], digits01[p2], '-',
digits10[p3], digits01[p3],
)
if length == 10 {
return dst, nil
}
if len(src) == 4 {
return append(dst, zeroDateTime[10:length]...), nil
}
dst = append(dst, ' ')
p1 = src[4] // hour
src = src[5:]
// p1 is 2-digit hour, src is after hour
p2, p3 = src[0], src[1]
dst = append(dst,
@ -335,49 +315,51 @@ func formatBinaryDateTime(src []byte, length uint8) (driver.Value, error) {
digits10[p2], digits01[p2], ':',
digits10[p3], digits01[p3],
)
return appendMicrosecs(dst, src[2:], int(length)-20), nil
}
func formatBinaryTime(src []byte, length uint8) (driver.Value, error) {
// length expects the deterministic length of the zero value,
// negative time and 100+ hours are automatically added if needed
if length <= byte(len(dst)) {
return dst, nil
}
src = src[2:]
if len(src) == 0 {
return zeroDateTime[11 : 11+length], nil
return append(dst, zeroDateTime[19:zOffs+length]...), nil
}
var dst []byte // return value
switch length {
case
8, // time (can be up to 10 when negative and 100+ hours)
10, 11, 12, 13, 14, 15: // time with fractional seconds
microsecs := binary.LittleEndian.Uint32(src[:4])
p1 = byte(microsecs / 10000)
microsecs -= 10000 * uint32(p1)
p2 = byte(microsecs / 100)
microsecs -= 100 * uint32(p2)
p3 = byte(microsecs)
switch decimals := zOffs + length - 20; decimals {
default:
return nil, fmt.Errorf("illegal TIME length %d", length)
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2], digits01[p2],
digits10[p3], digits01[p3],
), nil
case 1:
return append(dst, '.',
digits10[p1],
), nil
case 2:
return append(dst, '.',
digits10[p1], digits01[p1],
), nil
case 3:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2],
), nil
case 4:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2], digits01[p2],
), nil
case 5:
return append(dst, '.',
digits10[p1], digits01[p1],
digits10[p2], digits01[p2],
digits10[p3],
), nil
}
switch len(src) {
case 8, 12:
default:
return nil, fmt.Errorf("invalid TIME packet length %d", len(src))
}
// +2 to enable negative time and 100+ hours
dst = make([]byte, 0, length+2)
if src[0] == 1 {
dst = append(dst, '-')
}
days := binary.LittleEndian.Uint32(src[1:5])
hours := int64(days)*24 + int64(src[5])
if hours >= 100 {
dst = strconv.AppendInt(dst, hours, 10)
} else {
dst = append(dst, digits10[hours], digits01[hours])
}
min, sec := src[6], src[7]
dst = append(dst, ':',
digits10[min], digits01[min], ':',
digits10[sec], digits01[sec],
)
return appendMicrosecs(dst, src[8:], int(length)-9), nil
}
/******************************************************************************
@ -684,7 +666,7 @@ type atomicBool struct {
value uint32
}
// IsSet returns whether the current boolean value is true
// IsSet returns wether the current boolean value is true
func (ab *atomicBool) IsSet() bool {
return atomic.LoadUint32(&ab.value) > 0
}
@ -698,7 +680,7 @@ func (ab *atomicBool) Set(value bool) {
}
}
// TrySet sets the value of the bool and returns whether the value changed
// TrySet sets the value of the bool and returns wether the value changed
func (ab *atomicBool) TrySet(value bool) bool {
if value {
return atomic.SwapUint32(&ab.value, 1) == 0
@ -726,30 +708,3 @@ func (ae *atomicError) Value() error {
}
return nil
}
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
// TODO: support the use of Named Parameters #561
return nil, errors.New("mysql: driver does not support the use of Named Parameters")
}
dargs[n] = param.Value
}
return dargs, nil
}
func mapIsolationLevel(level driver.IsolationLevel) (string, error) {
switch sql.IsolationLevel(level) {
case sql.LevelRepeatableRead:
return "REPEATABLE READ", nil
case sql.LevelReadCommitted:
return "READ COMMITTED", nil
case sql.LevelReadUncommitted:
return "READ UNCOMMITTED", nil
case sql.LevelSerializable:
return "SERIALIZABLE", nil
default:
return "", fmt.Errorf("mysql: unsupported isolation level: %v", level)
}
}

40
vendor/github.com/go-sql-driver/mysql/utils_go17.go generated vendored Normal file
View file

@ -0,0 +1,40 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build go1.7
// +build !go1.8
package mysql
import "crypto/tls"
func cloneTLSConfig(c *tls.Config) *tls.Config {
return &tls.Config{
Rand: c.Rand,
Time: c.Time,
Certificates: c.Certificates,
NameToCertificate: c.NameToCertificate,
GetCertificate: c.GetCertificate,
RootCAs: c.RootCAs,
NextProtos: c.NextProtos,
ServerName: c.ServerName,
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
SessionTicketsDisabled: c.SessionTicketsDisabled,
SessionTicketKey: c.SessionTicketKey,
ClientSessionCache: c.ClientSessionCache,
MinVersion: c.MinVersion,
MaxVersion: c.MaxVersion,
CurvePreferences: c.CurvePreferences,
DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled,
Renegotiation: c.Renegotiation,
}
}

50
vendor/github.com/go-sql-driver/mysql/utils_go18.go generated vendored Normal file
View file

@ -0,0 +1,50 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// +build go1.8
package mysql
import (
"crypto/tls"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
)
func cloneTLSConfig(c *tls.Config) *tls.Config {
return c.Clone()
}
func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) {
dargs := make([]driver.Value, len(named))
for n, param := range named {
if len(param.Name) > 0 {
// TODO: support the use of Named Parameters #561
return nil, errors.New("mysql: driver does not support the use of Named Parameters")
}
dargs[n] = param.Value
}
return dargs, nil
}
func mapIsolationLevel(level driver.IsolationLevel) (string, error) {
switch sql.IsolationLevel(level) {
case sql.LevelRepeatableRead:
return "REPEATABLE READ", nil
case sql.LevelReadCommitted:
return "READ COMMITTED", nil
case sql.LevelReadUncommitted:
return "READ UNCOMMITTED", nil
case sql.LevelSerializable:
return "SERIALIZABLE", nil
default:
return "", fmt.Errorf("mysql: unsupported isolation level: %v", level)
}
}

View file

@ -1,334 +0,0 @@
// Go MySQL Driver - A MySQL-Driver for Go's database/sql package
//
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
package mysql
import (
"bytes"
"database/sql"
"database/sql/driver"
"encoding/binary"
"testing"
"time"
)
func TestScanNullTime(t *testing.T) {
var scanTests = []struct {
in interface{}
error bool
valid bool
time time.Time
}{
{tDate, false, true, tDate},
{sDate, false, true, tDate},
{[]byte(sDate), false, true, tDate},
{tDateTime, false, true, tDateTime},
{sDateTime, false, true, tDateTime},
{[]byte(sDateTime), false, true, tDateTime},
{tDate0, false, true, tDate0},
{sDate0, false, true, tDate0},
{[]byte(sDate0), false, true, tDate0},
{sDateTime0, false, true, tDate0},
{[]byte(sDateTime0), false, true, tDate0},
{"", true, false, tDate0},
{"1234", true, false, tDate0},
{0, true, false, tDate0},
}
var nt = NullTime{}
var err error
for _, tst := range scanTests {
err = nt.Scan(tst.in)
if (err != nil) != tst.error {
t.Errorf("%v: expected error status %t, got %t", tst.in, tst.error, (err != nil))
}
if nt.Valid != tst.valid {
t.Errorf("%v: expected valid status %t, got %t", tst.in, tst.valid, nt.Valid)
}
if nt.Time != tst.time {
t.Errorf("%v: expected time %v, got %v", tst.in, tst.time, nt.Time)
}
}
}
func TestLengthEncodedInteger(t *testing.T) {
var integerTests = []struct {
num uint64
encoded []byte
}{
{0x0000000000000000, []byte{0x00}},
{0x0000000000000012, []byte{0x12}},
{0x00000000000000fa, []byte{0xfa}},
{0x0000000000000100, []byte{0xfc, 0x00, 0x01}},
{0x0000000000001234, []byte{0xfc, 0x34, 0x12}},
{0x000000000000ffff, []byte{0xfc, 0xff, 0xff}},
{0x0000000000010000, []byte{0xfd, 0x00, 0x00, 0x01}},
{0x0000000000123456, []byte{0xfd, 0x56, 0x34, 0x12}},
{0x0000000000ffffff, []byte{0xfd, 0xff, 0xff, 0xff}},
{0x0000000001000000, []byte{0xfe, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00}},
{0x123456789abcdef0, []byte{0xfe, 0xf0, 0xde, 0xbc, 0x9a, 0x78, 0x56, 0x34, 0x12}},
{0xffffffffffffffff, []byte{0xfe, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}},
}
for _, tst := range integerTests {
num, isNull, numLen := readLengthEncodedInteger(tst.encoded)
if isNull {
t.Errorf("%x: expected %d, got NULL", tst.encoded, tst.num)
}
if num != tst.num {
t.Errorf("%x: expected %d, got %d", tst.encoded, tst.num, num)
}
if numLen != len(tst.encoded) {
t.Errorf("%x: expected size %d, got %d", tst.encoded, len(tst.encoded), numLen)
}
encoded := appendLengthEncodedInteger(nil, num)
if !bytes.Equal(encoded, tst.encoded) {
t.Errorf("%v: expected %x, got %x", num, tst.encoded, encoded)
}
}
}
func TestFormatBinaryDateTime(t *testing.T) {
rawDate := [11]byte{}
binary.LittleEndian.PutUint16(rawDate[:2], 1978) // years
rawDate[2] = 12 // months
rawDate[3] = 30 // days
rawDate[4] = 15 // hours
rawDate[5] = 46 // minutes
rawDate[6] = 23 // seconds
binary.LittleEndian.PutUint32(rawDate[7:], 987654) // microseconds
expect := func(expected string, inlen, outlen uint8) {
actual, _ := formatBinaryDateTime(rawDate[:inlen], outlen)
bytes, ok := actual.([]byte)
if !ok {
t.Errorf("formatBinaryDateTime must return []byte, was %T", actual)
}
if string(bytes) != expected {
t.Errorf(
"expected %q, got %q for length in %d, out %d",
expected, actual, inlen, outlen,
)
}
}
expect("0000-00-00", 0, 10)
expect("0000-00-00 00:00:00", 0, 19)
expect("1978-12-30", 4, 10)
expect("1978-12-30 15:46:23", 7, 19)
expect("1978-12-30 15:46:23.987654", 11, 26)
}
func TestFormatBinaryTime(t *testing.T) {
expect := func(expected string, src []byte, outlen uint8) {
actual, _ := formatBinaryTime(src, outlen)
bytes, ok := actual.([]byte)
if !ok {
t.Errorf("formatBinaryDateTime must return []byte, was %T", actual)
}
if string(bytes) != expected {
t.Errorf(
"expected %q, got %q for src=%q and outlen=%d",
expected, actual, src, outlen)
}
}
// binary format:
// sign (0: positive, 1: negative), days(4), hours, minutes, seconds, micro(4)
// Zeros
expect("00:00:00", []byte{}, 8)
expect("00:00:00.0", []byte{}, 10)
expect("00:00:00.000000", []byte{}, 15)
// Without micro(4)
expect("12:34:56", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 8)
expect("-12:34:56", []byte{1, 0, 0, 0, 0, 12, 34, 56}, 8)
expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56}, 11)
expect("24:34:56", []byte{0, 1, 0, 0, 0, 0, 34, 56}, 8)
expect("-99:34:56", []byte{1, 4, 0, 0, 0, 3, 34, 56}, 8)
expect("103079215103:34:56", []byte{0, 255, 255, 255, 255, 23, 34, 56}, 8)
// With micro(4)
expect("12:34:56.00", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 11)
expect("12:34:56.000099", []byte{0, 0, 0, 0, 0, 12, 34, 56, 99, 0, 0, 0}, 15)
}
func TestEscapeBackslash(t *testing.T) {
expect := func(expected, value string) {
actual := string(escapeBytesBackslash([]byte{}, []byte(value)))
if actual != expected {
t.Errorf(
"expected %s, got %s",
expected, actual,
)
}
actual = string(escapeStringBackslash([]byte{}, value))
if actual != expected {
t.Errorf(
"expected %s, got %s",
expected, actual,
)
}
}
expect("foo\\0bar", "foo\x00bar")
expect("foo\\nbar", "foo\nbar")
expect("foo\\rbar", "foo\rbar")
expect("foo\\Zbar", "foo\x1abar")
expect("foo\\\"bar", "foo\"bar")
expect("foo\\\\bar", "foo\\bar")
expect("foo\\'bar", "foo'bar")
}
func TestEscapeQuotes(t *testing.T) {
expect := func(expected, value string) {
actual := string(escapeBytesQuotes([]byte{}, []byte(value)))
if actual != expected {
t.Errorf(
"expected %s, got %s",
expected, actual,
)
}
actual = string(escapeStringQuotes([]byte{}, value))
if actual != expected {
t.Errorf(
"expected %s, got %s",
expected, actual,
)
}
}
expect("foo\x00bar", "foo\x00bar") // not affected
expect("foo\nbar", "foo\nbar") // not affected
expect("foo\rbar", "foo\rbar") // not affected
expect("foo\x1abar", "foo\x1abar") // not affected
expect("foo''bar", "foo'bar") // affected
expect("foo\"bar", "foo\"bar") // not affected
}
func TestAtomicBool(t *testing.T) {
var ab atomicBool
if ab.IsSet() {
t.Fatal("Expected value to be false")
}
ab.Set(true)
if ab.value != 1 {
t.Fatal("Set(true) did not set value to 1")
}
if !ab.IsSet() {
t.Fatal("Expected value to be true")
}
ab.Set(true)
if !ab.IsSet() {
t.Fatal("Expected value to be true")
}
ab.Set(false)
if ab.value != 0 {
t.Fatal("Set(false) did not set value to 0")
}
if ab.IsSet() {
t.Fatal("Expected value to be false")
}
ab.Set(false)
if ab.IsSet() {
t.Fatal("Expected value to be false")
}
if ab.TrySet(false) {
t.Fatal("Expected TrySet(false) to fail")
}
if !ab.TrySet(true) {
t.Fatal("Expected TrySet(true) to succeed")
}
if !ab.IsSet() {
t.Fatal("Expected value to be true")
}
ab.Set(true)
if !ab.IsSet() {
t.Fatal("Expected value to be true")
}
if ab.TrySet(true) {
t.Fatal("Expected TrySet(true) to fail")
}
if !ab.TrySet(false) {
t.Fatal("Expected TrySet(false) to succeed")
}
if ab.IsSet() {
t.Fatal("Expected value to be false")
}
ab._noCopy.Lock() // we've "tested" it ¯\_(ツ)_/¯
}
func TestAtomicError(t *testing.T) {
var ae atomicError
if ae.Value() != nil {
t.Fatal("Expected value to be nil")
}
ae.Set(ErrMalformPkt)
if v := ae.Value(); v != ErrMalformPkt {
if v == nil {
t.Fatal("Value is still nil")
}
t.Fatal("Error did not match")
}
ae.Set(ErrPktSync)
if ae.Value() == ErrMalformPkt {
t.Fatal("Error still matches old error")
}
if v := ae.Value(); v != ErrPktSync {
t.Fatal("Error did not match")
}
}
func TestIsolationLevelMapping(t *testing.T) {
data := []struct {
level driver.IsolationLevel
expected string
}{
{
level: driver.IsolationLevel(sql.LevelReadCommitted),
expected: "READ COMMITTED",
},
{
level: driver.IsolationLevel(sql.LevelRepeatableRead),
expected: "REPEATABLE READ",
},
{
level: driver.IsolationLevel(sql.LevelReadUncommitted),
expected: "READ UNCOMMITTED",
},
{
level: driver.IsolationLevel(sql.LevelSerializable),
expected: "SERIALIZABLE",
},
}
for i, td := range data {
if actual, err := mapIsolationLevel(td.level); actual != td.expected || err != nil {
t.Fatal(i, td.expected, actual, err)
}
}
// check unsupported mapping
expectedErr := "mysql: unsupported isolation level: 7"
actual, err := mapIsolationLevel(driver.IsolationLevel(sql.LevelLinearizable))
if actual != "" || err == nil {
t.Fatal("Expected error on unsupported isolation level")
}
if err.Error() != expectedErr {
t.Fatalf("Expected error to be %q, got %q", expectedErr, err)
}
}

191
vendor/github.com/golang/glog/LICENSE generated vendored Normal file
View file

@ -0,0 +1,191 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction, and
distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by the copyright
owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all other entities
that control, are controlled by, or are under common control with that entity.
For the purposes of this definition, "control" means (i) the power, direct or
indirect, to cause the direction or management of such entity, whether by
contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity exercising
permissions granted by this License.
"Source" form shall mean the preferred form for making modifications, including
but not limited to software source code, documentation source, and configuration
files.
"Object" form shall mean any form resulting from mechanical transformation or
translation of a Source form, including but not limited to compiled object code,
generated documentation, and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or Object form, made
available under the License, as indicated by a copyright notice that is included
in or attached to the work (an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object form, that
is based on (or derived from) the Work and for which the editorial revisions,
annotations, elaborations, or other modifications represent, as a whole, an
original work of authorship. For the purposes of this License, Derivative Works
shall not include works that remain separable from, or merely link (or bind by
name) to the interfaces of, the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including the original version
of the Work and any modifications or additions to that Work or Derivative Works
thereof, that is intentionally submitted to Licensor for inclusion in the Work
by the copyright owner or by an individual or Legal Entity authorized to submit
on behalf of the copyright owner. For the purposes of this definition,
"submitted" means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems, and
issue tracking systems that are managed by, or on behalf of, the Licensor for
the purpose of discussing and improving the Work, but excluding communication
that is conspicuously marked or otherwise designated in writing by the copyright
owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf
of whom a Contribution has been received by Licensor and subsequently
incorporated within the Work.
2. Grant of Copyright License.
Subject to the terms and conditions of this License, each Contributor hereby
grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the Work and such
Derivative Works in Source or Object form.
3. Grant of Patent License.
Subject to the terms and conditions of this License, each Contributor hereby
grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
irrevocable (except as stated in this section) patent license to make, have
made, use, offer to sell, sell, import, and otherwise transfer the Work, where
such license applies only to those patent claims licensable by such Contributor
that are necessarily infringed by their Contribution(s) alone or by combination
of their Contribution(s) with the Work to which such Contribution(s) was
submitted. If You institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work or a
Contribution incorporated within the Work constitutes direct or contributory
patent infringement, then any patent licenses granted to You under this License
for that Work shall terminate as of the date such litigation is filed.
4. Redistribution.
You may reproduce and distribute copies of the Work or Derivative Works thereof
in any medium, with or without modifications, and in Source or Object form,
provided that You meet the following conditions:
You must give any other recipients of the Work or Derivative Works a copy of
this License; and
You must cause any modified files to carry prominent notices stating that You
changed the files; and
You must retain, in the Source form of any Derivative Works that You distribute,
all copyright, patent, trademark, and attribution notices from the Source form
of the Work, excluding those notices that do not pertain to any part of the
Derivative Works; and
If the Work includes a "NOTICE" text file as part of its distribution, then any
Derivative Works that You distribute must include a readable copy of the
attribution notices contained within such NOTICE file, excluding those notices
that do not pertain to any part of the Derivative Works, in at least one of the
following places: within a NOTICE text file distributed as part of the
Derivative Works; within the Source form or documentation, if provided along
with the Derivative Works; or, within a display generated by the Derivative
Works, if and wherever such third-party notices normally appear. The contents of
the NOTICE file are for informational purposes only and do not modify the
License. You may add Your own attribution notices within Derivative Works that
You distribute, alongside or as an addendum to the NOTICE text from the Work,
provided that such additional attribution notices cannot be construed as
modifying the License.
You may add Your own copyright statement to Your modifications and may provide
additional or different license terms and conditions for use, reproduction, or
distribution of Your modifications, or for any such Derivative Works as a whole,
provided Your use, reproduction, and distribution of the Work otherwise complies
with the conditions stated in this License.
5. Submission of Contributions.
Unless You explicitly state otherwise, any Contribution intentionally submitted
for inclusion in the Work by You to the Licensor shall be under the terms and
conditions of this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify the terms of
any separate license agreement you may have executed with Licensor regarding
such Contributions.
6. Trademarks.
This License does not grant permission to use the trade names, trademarks,
service marks, or product names of the Licensor, except as required for
reasonable and customary use in describing the origin of the Work and
reproducing the content of the NOTICE file.
7. Disclaimer of Warranty.
Unless required by applicable law or agreed to in writing, Licensor provides the
Work (and each Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied,
including, without limitation, any warranties or conditions of TITLE,
NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are
solely responsible for determining the appropriateness of using or
redistributing the Work and assume any risks associated with Your exercise of
permissions under this License.
8. Limitation of Liability.
In no event and under no legal theory, whether in tort (including negligence),
contract, or otherwise, unless required by applicable law (such as deliberate
and grossly negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special, incidental,
or consequential damages of any character arising as a result of this License or
out of the use or inability to use the Work (including but not limited to
damages for loss of goodwill, work stoppage, computer failure or malfunction, or
any and all other commercial damages or losses), even if such Contributor has
been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability.
While redistributing the Work or Derivative Works thereof, You may choose to
offer, and charge a fee for, acceptance of support, warranty, indemnity, or
other liability obligations and/or rights consistent with this License. However,
in accepting such obligations, You may act only on Your own behalf and on Your
sole responsibility, not on behalf of any other Contributor, and only if You
agree to indemnify, defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason of your
accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work
To apply the Apache License to your work, attach the following boilerplate
notice, with the fields enclosed by brackets "[]" replaced with your own
identifying information. (Don't include the brackets!) The text should be
enclosed in the appropriate comment syntax for the file format. We also
recommend that a file or class name and description of purpose be included on
the same "printed page" as the copyright notice for easier identification within
third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

44
vendor/github.com/golang/glog/README generated vendored Normal file
View file

@ -0,0 +1,44 @@
glog
====
Leveled execution logs for Go.
This is an efficient pure Go implementation of leveled logs in the
manner of the open source C++ package
https://github.com/google/glog
By binding methods to booleans it is possible to use the log package
without paying the expense of evaluating the arguments to the log.
Through the -vmodule flag, the package also provides fine-grained
control over logging at the file level.
The comment from glog.go introduces the ideas:
Package glog implements logging analogous to the Google-internal
C++ INFO/ERROR/V setup. It provides functions Info, Warning,
Error, Fatal, plus formatting variants such as Infof. It
also provides V-style logging controlled by the -v and
-vmodule=file=2 flags.
Basic examples:
glog.Info("Prepare to repel boarders")
glog.Fatalf("Initialization failed: %s", err)
See the documentation for the V function for an explanation
of these examples:
if glog.V(2) {
glog.Info("Starting transaction...")
}
glog.V(2).Infoln("Processed", nItems, "elements")
The repository contains an open source version of the log package
used inside Google. The master copy of the source lives inside
Google, not here. The code in this repo is for export only and is not itself
under development. Feature requests will be ignored.
Send bug reports to golang-nuts@googlegroups.com.

1180
vendor/github.com/golang/glog/glog.go generated vendored Normal file

File diff suppressed because it is too large Load diff

124
vendor/github.com/golang/glog/glog_file.go generated vendored Normal file
View file

@ -0,0 +1,124 @@
// Go support for leveled logs, analogous to https://code.google.com/p/google-glog/
//
// Copyright 2013 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// File I/O for logs.
package glog
import (
"errors"
"flag"
"fmt"
"os"
"os/user"
"path/filepath"
"strings"
"sync"
"time"
)
// MaxSize is the maximum size of a log file in bytes.
var MaxSize uint64 = 1024 * 1024 * 1800
// logDirs lists the candidate directories for new log files.
var logDirs []string
// If non-empty, overrides the choice of directory in which to write logs.
// See createLogDirs for the full list of possible destinations.
var logDir = flag.String("log_dir", "", "If non-empty, write log files in this directory")
func createLogDirs() {
if *logDir != "" {
logDirs = append(logDirs, *logDir)
}
logDirs = append(logDirs, os.TempDir())
}
var (
pid = os.Getpid()
program = filepath.Base(os.Args[0])
host = "unknownhost"
userName = "unknownuser"
)
func init() {
h, err := os.Hostname()
if err == nil {
host = shortHostname(h)
}
current, err := user.Current()
if err == nil {
userName = current.Username
}
// Sanitize userName since it may contain filepath separators on Windows.
userName = strings.Replace(userName, `\`, "_", -1)
}
// shortHostname returns its argument, truncating at the first period.
// For instance, given "www.google.com" it returns "www".
func shortHostname(hostname string) string {
if i := strings.Index(hostname, "."); i >= 0 {
return hostname[:i]
}
return hostname
}
// logName returns a new log file name containing tag, with start time t, and
// the name for the symlink for tag.
func logName(tag string, t time.Time) (name, link string) {
name = fmt.Sprintf("%s.%s.%s.log.%s.%04d%02d%02d-%02d%02d%02d.%d",
program,
host,
userName,
tag,
t.Year(),
t.Month(),
t.Day(),
t.Hour(),
t.Minute(),
t.Second(),
pid)
return name, program + "." + tag
}
var onceLogDirs sync.Once
// create creates a new log file and returns the file and its filename, which
// contains tag ("INFO", "FATAL", etc.) and t. If the file is created
// successfully, create also attempts to update the symlink for that tag, ignoring
// errors.
func create(tag string, t time.Time) (f *os.File, filename string, err error) {
onceLogDirs.Do(createLogDirs)
if len(logDirs) == 0 {
return nil, "", errors.New("log: no log dirs")
}
name, link := logName(tag, t)
var lastErr error
for _, dir := range logDirs {
fname := filepath.Join(dir, name)
f, err := os.Create(fname)
if err == nil {
symlink := filepath.Join(dir, link)
os.Remove(symlink) // ignore err
os.Symlink(name, symlink) // ignore err
return f, fname, nil
}
lastErr = err
}
return nil, "", fmt.Errorf("log: cannot create log: %v", lastErr)
}

2
vendor/github.com/lib/pq/README.md generated vendored
View file

@ -10,7 +10,7 @@
## Docs
For detailed documentation and basic usage examples, please see the package
documentation at <https://godoc.org/github.com/lib/pq>.
documentation at <http://godoc.org/github.com/lib/pq>.
## Tests

1311
vendor/github.com/lib/pq/array_test.go generated vendored

File diff suppressed because it is too large Load diff

View file

@ -1,434 +0,0 @@
package pq
import (
"bufio"
"bytes"
"context"
"database/sql"
"database/sql/driver"
"io"
"math/rand"
"net"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/lib/pq/oid"
)
var (
selectStringQuery = "SELECT '" + strings.Repeat("0123456789", 10) + "'"
selectSeriesQuery = "SELECT generate_series(1, 100)"
)
func BenchmarkSelectString(b *testing.B) {
var result string
benchQuery(b, selectStringQuery, &result)
}
func BenchmarkSelectSeries(b *testing.B) {
var result int
benchQuery(b, selectSeriesQuery, &result)
}
func benchQuery(b *testing.B, query string, result interface{}) {
b.StopTimer()
db := openTestConn(b)
defer db.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
benchQueryLoop(b, db, query, result)
}
}
func benchQueryLoop(b *testing.B, db *sql.DB, query string, result interface{}) {
rows, err := db.Query(query)
if err != nil {
b.Fatal(err)
}
defer rows.Close()
for rows.Next() {
err = rows.Scan(result)
if err != nil {
b.Fatal("failed to scan", err)
}
}
}
// reading from circularConn yields content[:prefixLen] once, followed by
// content[prefixLen:] over and over again. It never returns EOF.
type circularConn struct {
content string
prefixLen int
pos int
net.Conn // for all other net.Conn methods that will never be called
}
func (r *circularConn) Read(b []byte) (n int, err error) {
n = copy(b, r.content[r.pos:])
r.pos += n
if r.pos >= len(r.content) {
r.pos = r.prefixLen
}
return
}
func (r *circularConn) Write(b []byte) (n int, err error) { return len(b), nil }
func (r *circularConn) Close() error { return nil }
func fakeConn(content string, prefixLen int) *conn {
c := &circularConn{content: content, prefixLen: prefixLen}
return &conn{buf: bufio.NewReader(c), c: c}
}
// This benchmark is meant to be the same as BenchmarkSelectString, but takes
// out some of the factors this package can't control. The numbers are less noisy,
// but also the costs of network communication aren't accurately represented.
func BenchmarkMockSelectString(b *testing.B) {
b.StopTimer()
// taken from a recorded run of BenchmarkSelectString
// See: http://www.postgresql.org/docs/current/static/protocol-message-formats.html
const response = "1\x00\x00\x00\x04" +
"t\x00\x00\x00\x06\x00\x00" +
"T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" +
"Z\x00\x00\x00\x05I" +
"2\x00\x00\x00\x04" +
"D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" +
"C\x00\x00\x00\rSELECT 1\x00" +
"Z\x00\x00\x00\x05I" +
"3\x00\x00\x00\x04" +
"Z\x00\x00\x00\x05I"
c := fakeConn(response, 0)
b.StartTimer()
for i := 0; i < b.N; i++ {
benchMockQuery(b, c, selectStringQuery)
}
}
var seriesRowData = func() string {
var buf bytes.Buffer
for i := 1; i <= 100; i++ {
digits := byte(2)
if i >= 100 {
digits = 3
} else if i < 10 {
digits = 1
}
buf.WriteString("D\x00\x00\x00")
buf.WriteByte(10 + digits)
buf.WriteString("\x00\x01\x00\x00\x00")
buf.WriteByte(digits)
buf.WriteString(strconv.Itoa(i))
}
return buf.String()
}()
func BenchmarkMockSelectSeries(b *testing.B) {
b.StopTimer()
var response = "1\x00\x00\x00\x04" +
"t\x00\x00\x00\x06\x00\x00" +
"T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" +
"Z\x00\x00\x00\x05I" +
"2\x00\x00\x00\x04" +
seriesRowData +
"C\x00\x00\x00\x0fSELECT 100\x00" +
"Z\x00\x00\x00\x05I" +
"3\x00\x00\x00\x04" +
"Z\x00\x00\x00\x05I"
c := fakeConn(response, 0)
b.StartTimer()
for i := 0; i < b.N; i++ {
benchMockQuery(b, c, selectSeriesQuery)
}
}
func benchMockQuery(b *testing.B, c *conn, query string) {
stmt, err := c.Prepare(query)
if err != nil {
b.Fatal(err)
}
defer stmt.Close()
rows, err := stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil)
if err != nil {
b.Fatal(err)
}
defer rows.Close()
var dest [1]driver.Value
for {
if err := rows.Next(dest[:]); err != nil {
if err == io.EOF {
break
}
b.Fatal(err)
}
}
}
func BenchmarkPreparedSelectString(b *testing.B) {
var result string
benchPreparedQuery(b, selectStringQuery, &result)
}
func BenchmarkPreparedSelectSeries(b *testing.B) {
var result int
benchPreparedQuery(b, selectSeriesQuery, &result)
}
func benchPreparedQuery(b *testing.B, query string, result interface{}) {
b.StopTimer()
db := openTestConn(b)
defer db.Close()
stmt, err := db.Prepare(query)
if err != nil {
b.Fatal(err)
}
defer stmt.Close()
b.StartTimer()
for i := 0; i < b.N; i++ {
benchPreparedQueryLoop(b, db, stmt, result)
}
}
func benchPreparedQueryLoop(b *testing.B, db *sql.DB, stmt *sql.Stmt, result interface{}) {
rows, err := stmt.Query()
if err != nil {
b.Fatal(err)
}
if !rows.Next() {
rows.Close()
b.Fatal("no rows")
}
defer rows.Close()
for rows.Next() {
err = rows.Scan(&result)
if err != nil {
b.Fatal("failed to scan")
}
}
}
// See the comment for BenchmarkMockSelectString.
func BenchmarkMockPreparedSelectString(b *testing.B) {
b.StopTimer()
const parseResponse = "1\x00\x00\x00\x04" +
"t\x00\x00\x00\x06\x00\x00" +
"T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" +
"Z\x00\x00\x00\x05I"
const responses = parseResponse +
"2\x00\x00\x00\x04" +
"D\x00\x00\x00n\x00\x01\x00\x00\x00d0123456789012345678901234567890123456789012345678901234567890123456789012345678901234567890123456789" +
"C\x00\x00\x00\rSELECT 1\x00" +
"Z\x00\x00\x00\x05I"
c := fakeConn(responses, len(parseResponse))
stmt, err := c.Prepare(selectStringQuery)
if err != nil {
b.Fatal(err)
}
b.StartTimer()
for i := 0; i < b.N; i++ {
benchPreparedMockQuery(b, c, stmt)
}
}
func BenchmarkMockPreparedSelectSeries(b *testing.B) {
b.StopTimer()
const parseResponse = "1\x00\x00\x00\x04" +
"t\x00\x00\x00\x06\x00\x00" +
"T\x00\x00\x00!\x00\x01?column?\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\xc1\xff\xfe\xff\xff\xff\xff\x00\x00" +
"Z\x00\x00\x00\x05I"
var responses = parseResponse +
"2\x00\x00\x00\x04" +
seriesRowData +
"C\x00\x00\x00\x0fSELECT 100\x00" +
"Z\x00\x00\x00\x05I"
c := fakeConn(responses, len(parseResponse))
stmt, err := c.Prepare(selectSeriesQuery)
if err != nil {
b.Fatal(err)
}
b.StartTimer()
for i := 0; i < b.N; i++ {
benchPreparedMockQuery(b, c, stmt)
}
}
func benchPreparedMockQuery(b *testing.B, c *conn, stmt driver.Stmt) {
rows, err := stmt.(driver.StmtQueryContext).QueryContext(context.Background(), nil)
if err != nil {
b.Fatal(err)
}
defer rows.Close()
var dest [1]driver.Value
for {
if err := rows.Next(dest[:]); err != nil {
if err == io.EOF {
break
}
b.Fatal(err)
}
}
}
func BenchmarkEncodeInt64(b *testing.B) {
for i := 0; i < b.N; i++ {
encode(&parameterStatus{}, int64(1234), oid.T_int8)
}
}
func BenchmarkEncodeFloat64(b *testing.B) {
for i := 0; i < b.N; i++ {
encode(&parameterStatus{}, 3.14159, oid.T_float8)
}
}
var testByteString = []byte("abcdefghijklmnopqrstuvwxyz")
func BenchmarkEncodeByteaHex(b *testing.B) {
for i := 0; i < b.N; i++ {
encode(&parameterStatus{serverVersion: 90000}, testByteString, oid.T_bytea)
}
}
func BenchmarkEncodeByteaEscape(b *testing.B) {
for i := 0; i < b.N; i++ {
encode(&parameterStatus{serverVersion: 84000}, testByteString, oid.T_bytea)
}
}
func BenchmarkEncodeBool(b *testing.B) {
for i := 0; i < b.N; i++ {
encode(&parameterStatus{}, true, oid.T_bool)
}
}
var testTimestamptz = time.Date(2001, time.January, 1, 0, 0, 0, 0, time.Local)
func BenchmarkEncodeTimestamptz(b *testing.B) {
for i := 0; i < b.N; i++ {
encode(&parameterStatus{}, testTimestamptz, oid.T_timestamptz)
}
}
var testIntBytes = []byte("1234")
func BenchmarkDecodeInt64(b *testing.B) {
for i := 0; i < b.N; i++ {
decode(&parameterStatus{}, testIntBytes, oid.T_int8, formatText)
}
}
var testFloatBytes = []byte("3.14159")
func BenchmarkDecodeFloat64(b *testing.B) {
for i := 0; i < b.N; i++ {
decode(&parameterStatus{}, testFloatBytes, oid.T_float8, formatText)
}
}
var testBoolBytes = []byte{'t'}
func BenchmarkDecodeBool(b *testing.B) {
for i := 0; i < b.N; i++ {
decode(&parameterStatus{}, testBoolBytes, oid.T_bool, formatText)
}
}
func TestDecodeBool(t *testing.T) {
db := openTestConn(t)
rows, err := db.Query("select true")
if err != nil {
t.Fatal(err)
}
rows.Close()
}
var testTimestamptzBytes = []byte("2013-09-17 22:15:32.360754-07")
func BenchmarkDecodeTimestamptz(b *testing.B) {
for i := 0; i < b.N; i++ {
decode(&parameterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText)
}
}
func BenchmarkDecodeTimestamptzMultiThread(b *testing.B) {
oldProcs := runtime.GOMAXPROCS(0)
defer runtime.GOMAXPROCS(oldProcs)
runtime.GOMAXPROCS(runtime.NumCPU())
globalLocationCache = newLocationCache()
f := func(wg *sync.WaitGroup, loops int) {
defer wg.Done()
for i := 0; i < loops; i++ {
decode(&parameterStatus{}, testTimestamptzBytes, oid.T_timestamptz, formatText)
}
}
wg := &sync.WaitGroup{}
b.ResetTimer()
for j := 0; j < 10; j++ {
wg.Add(1)
go f(wg, b.N/10)
}
wg.Wait()
}
func BenchmarkLocationCache(b *testing.B) {
globalLocationCache = newLocationCache()
for i := 0; i < b.N; i++ {
globalLocationCache.getLocation(rand.Intn(10000))
}
}
func BenchmarkLocationCacheMultiThread(b *testing.B) {
oldProcs := runtime.GOMAXPROCS(0)
defer runtime.GOMAXPROCS(oldProcs)
runtime.GOMAXPROCS(runtime.NumCPU())
globalLocationCache = newLocationCache()
f := func(wg *sync.WaitGroup, loops int) {
defer wg.Done()
for i := 0; i < loops; i++ {
globalLocationCache.getLocation(rand.Intn(10000))
}
}
wg := &sync.WaitGroup{}
b.ResetTimer()
for j := 0; j < 10; j++ {
wg.Add(1)
go f(wg, b.N/10)
}
wg.Wait()
}
// Stress test the performance of parsing results from the wire.
func BenchmarkResultParsing(b *testing.B) {
b.StopTimer()
db := openTestConn(b)
defer db.Close()
_, err := db.Exec("BEGIN")
if err != nil {
b.Fatal(err)
}
b.StartTimer()
for i := 0; i < b.N; i++ {
res, err := db.Query("SELECT generate_series(1, 50000)")
if err != nil {
b.Fatal(err)
}
res.Close()
}
}

View file

@ -1,3 +0,0 @@
This directory contains certificates and private keys for testing some
SSL-related functionality in Travis. Do NOT use these certificates for
anything other than testing.

View file

@ -1,15 +0,0 @@
-----BEGIN RSA PRIVATE KEY-----
MIICWwIBAAKBgQDjjAaacFRR0TQ0gznNolkPBe2N2A400JL0CU3ujHhVSST4POA0
WAKy55RYwejlu9Gv9lTBQLGQcHkNNVScjxbpwvCS5mRJOMF2+EdmxFtKtqlDzsi+
bE0rlJc8VbzR0G63U66JXEtrhkC+wa4eZM6crocKaeXIIRK+rh32Rd8WpwIDAQAB
AoGAM5dM6/kp9P700i8qjOgRPym96Zoh5nGfz/rIE5z/r36NBkdvIg8OVZfR96nH
b0b9TOMR5lsPp0sI9yivTWvX6qyvLJRWy2vvx17hXK9NxXUNTAm0PYZUTvCtcPeX
RnJpzQKNZQPkFzF0uXBc4CtPK2Vz0+FGvAelrhYAxnw1dIkCQQD+9qaW5QhXjsjb
Nl85CmXgxPmGROcgLQCO+omfrjf9UXrituU9Dz6auym5lDGEdMFnkzfr+wpasEy9
mf5ZZOhDAkEA5HjXfVGaCtpydOt6hDon/uZsyssCK2lQ7NSuE3vP+sUsYMzIpEoy
t3VWXqKbo+g9KNDTP4WEliqp1aiSIylzzQJANPeqzihQnlgEdD4MdD4rwhFJwVIp
Le8Lcais1KaN7StzOwxB/XhgSibd2TbnPpw+3bSg5n5lvUdo+e62/31OHwJAU1jS
I+F09KikQIr28u3UUWT2IzTT4cpVv1AHAQyV3sG3YsjSGT0IK20eyP9BEBZU2WL0
7aNjrvR5aHxKc5FXsQJABsFtyGpgI5X4xufkJZVZ+Mklz2n7iXa+XPatMAHFxAtb
EEMt60rngwMjXAzBSC6OYuYogRRAY3UCacNC5VhLYQ==
-----END RSA PRIVATE KEY-----

View file

@ -1,27 +0,0 @@
-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEA14pMhfsXpTyP4HIRKc4/sB8/fcbuf6f8Ais1RwimPZDfXFYU
lADHbdHS4mGVd7jjpmYx+R8hfWLhJ9qUN2FK6mNToGG4nLul4ue3ptgPBQTHKeLq
SSt/3hUAphhwUMcM3pr5Wpaw4ZQGxm1KITu0D6VtkoY0sk7XDqcZwHcLe4fIkt5C
/4bSt5qk1BUjyq2laSG4zn5my4Vdue2LLQmNlOQEHnLs79B2kBVapPeRS+nOTp1d
mnAXnNjpc4PqPWGZps2skUBaiHflTiqOPRPz+ThvgWuKlcoOB6tv2rSM2f+qeAOq
x8LPb2SS09iD1a/xIxinLnsXC+d98fqoQaMEVwIDAQABAoIBAF3ZoihUhJ82F4+r
Gz4QyDpv4L1reT2sb1aiabhcU8ZK5nbWJG+tRyjSS/i2dNaEcttpdCj9HR/zhgZM
bm0OuAgG58rVwgS80CZUruq++Qs+YVojq8/gWPTiQD4SNhV2Fmx3HkwLgUk3oxuT
SsvdqzGE3okGVrutCIcgy126eA147VPMoej1Bb3fO6npqK0pFPhZfAc0YoqJuM+k
obRm5pAnGUipyLCFXjA9HYPKwYZw2RtfdA3CiImHeanSdqS+ctrC9y8BV40Th7gZ
haXdKUNdjmIxV695QQ1mkGqpKLZFqhzKioGQ2/Ly2d1iaKN9fZltTusu8unepWJ2
tlT9qMECgYEA9uHaF1t2CqE+AJvWTihHhPIIuLxoOQXYea1qvxfcH/UMtaLKzCNm
lQ5pqCGsPvp+10f36yttO1ZehIvlVNXuJsjt0zJmPtIolNuJY76yeussfQ9jHheB
5uPEzCFlHzxYbBUyqgWaF6W74okRGzEGJXjYSP0yHPPdU4ep2q3bGiUCgYEA34Af
wBSuQSK7uLxArWHvQhyuvi43ZGXls6oRGl+Ysj54s8BP6XGkq9hEJ6G4yxgyV+BR
DUOs5X8/TLT8POuIMYvKTQthQyCk0eLv2FLdESDuuKx0kBVY3s8lK3/z5HhrdOiN
VMNZU+xDKgKc3hN9ypkk8vcZe6EtH7Y14e0rVcsCgYBTgxi8F/M5K0wG9rAqphNz
VFBA9XKn/2M33cKjO5X5tXIEKzpAjaUQvNxexG04rJGljzG8+mar0M6ONahw5yD1
O7i/XWgazgpuOEkkVYiYbd8RutfDgR4vFVMn3hAP3eDnRtBplRWH9Ec3HTiNIys6
F8PKBOQjyRZQQC7jyzW3hQKBgACe5HeuFwXLSOYsb6mLmhR+6+VPT4wR1F95W27N
USk9jyxAnngxfpmTkiziABdgS9N+pfr5cyN4BP77ia/Jn6kzkC5Cl9SN5KdIkA3z
vPVtN/x/ThuQU5zaymmig1ThGLtMYggYOslG4LDfLPxY5YKIhle+Y+259twdr2yf
Mf2dAoGAaGv3tWMgnIdGRk6EQL/yb9PKHo7ShN+tKNlGaK7WwzBdKs+Fe8jkgcr7
pz4Ne887CmxejdISzOCcdT+Zm9Bx6I/uZwWOtDvWpIgIxVX9a9URj/+D1MxTE/y4
d6H+c89yDY62I2+drMpdjCd3EtCaTlxpTbRS+s1eAHMH7aEkcCE=
-----END RSA PRIVATE KEY-----

293
vendor/github.com/lib/pq/conn.go generated vendored
View file

@ -2,9 +2,7 @@ package pq
import (
"bufio"
"context"
"crypto/md5"
"crypto/sha256"
"database/sql"
"database/sql/driver"
"encoding/binary"
@ -22,7 +20,6 @@ import (
"unicode"
"github.com/lib/pq/oid"
"github.com/lib/pq/scram"
)
// Common error types
@ -92,24 +89,13 @@ type Dialer interface {
DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
}
type DialerContext interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
type defaultDialer struct{}
type defaultDialer struct {
d net.Dialer
func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) {
return net.Dial(ntw, addr)
}
func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
return d.d.Dial(network, address)
}
func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
return d.DialContext(ctx, network, address)
}
func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return d.d.DialContext(ctx, network, address)
func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout(ntw, addr, timeout)
}
type conn struct {
@ -258,35 +244,90 @@ func (cn *conn) writeBuf(b byte) *writeBuf {
}
}
// Open opens a new connection to the database. dsn is a connection string.
// Open opens a new connection to the database. name is a connection string.
// Most users should only use it through database/sql package from the standard
// library.
func Open(dsn string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, dsn)
func Open(name string) (_ driver.Conn, err error) {
return DialOpen(defaultDialer{}, name)
}
// DialOpen opens a new connection to the database using a dialer.
func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
c, err := NewConnector(dsn)
if err != nil {
return nil, err
}
c.dialer = d
return c.open(context.Background())
}
func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
// Handle any panics during connection initialization. Note that we
// specifically do *not* want to use errRecover(), as that would turn any
// connection errors into ErrBadConns, hiding the real error message from
// the user.
defer errRecoverNoErrBadConn(&err)
o := c.opts
o := make(values)
cn = &conn{
// A number of defaults are applied here, in this order:
//
// * Very low precedence defaults applied in every situation
// * Environment variables
// * Explicitly passed connection information
o["host"] = "localhost"
o["port"] = "5432"
// N.B.: Extra float digits should be set to 3, but that breaks
// Postgres 8.4 and older, where the max is 2.
o["extra_float_digits"] = "2"
for k, v := range parseEnviron(os.Environ()) {
o[k] = v
}
if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
name, err = ParseURL(name)
if err != nil {
return nil, err
}
}
if err := parseOpts(name, o); err != nil {
return nil, err
}
// Use the "fallback" application name if necessary
if fallback, ok := o["fallback_application_name"]; ok {
if _, ok := o["application_name"]; !ok {
o["application_name"] = fallback
}
}
// We can't work with any client_encoding other than UTF-8 currently.
// However, we have historically allowed the user to set it to UTF-8
// explicitly, and there's no reason to break such programs, so allow that.
// Note that the "options" setting could also set client_encoding, but
// parsing its value is not worth it. Instead, we always explicitly send
// client_encoding as a separate run-time parameter, which should override
// anything set in options.
if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
return nil, errors.New("client_encoding must be absent or 'UTF8'")
}
o["client_encoding"] = "UTF8"
// DateStyle needs a similar treatment.
if datestyle, ok := o["datestyle"]; ok {
if datestyle != "ISO, MDY" {
panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v",
"ISO, MDY", datestyle))
}
} else {
o["datestyle"] = "ISO, MDY"
}
// If a user is not provided by any other means, the last
// resort is to use the current operating system provided user
// name.
if _, ok := o["user"]; !ok {
u, err := userCurrent()
if err != nil {
return nil, err
}
o["user"] = u
}
cn := &conn{
opts: o,
dialer: c.dialer,
dialer: d,
}
err = cn.handleDriverSettings(o)
if err != nil {
@ -294,16 +335,13 @@ func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
}
cn.handlePgpass(o)
cn.c, err = dial(ctx, c.dialer, o)
cn.c, err = dial(d, o)
if err != nil {
return nil, err
}
err = cn.ssl(o)
if err != nil {
if cn.c != nil {
cn.c.Close()
}
return nil, err
}
@ -326,10 +364,10 @@ func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
return cn, err
}
func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
network, address := network(o)
func dial(d Dialer, o values) (net.Conn, error) {
ntw, addr := network(o)
// SSL is not necessary or supported over UNIX domain sockets
if network == "unix" {
if ntw == "unix" {
o["sslmode"] = "disable"
}
@ -340,30 +378,19 @@ func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
}
duration := time.Duration(seconds) * time.Second
// connect_timeout should apply to the entire connection establishment
// procedure, so we both use a timeout for the TCP connection
// establishment and set a deadline for doing the initial handshake.
// The deadline is then reset after startup() is done.
deadline := time.Now().Add(duration)
var conn net.Conn
if dctx, ok := d.(DialerContext); ok {
ctx, cancel := context.WithTimeout(ctx, duration)
defer cancel()
conn, err = dctx.DialContext(ctx, network, address)
} else {
conn, err = d.DialTimeout(network, address, duration)
}
conn, err := d.DialTimeout(ntw, addr, duration)
if err != nil {
return nil, err
}
err = conn.SetDeadline(deadline)
return conn, err
}
if dctx, ok := d.(DialerContext); ok {
return dctx.DialContext(ctx, network, address)
}
return d.Dial(network, address)
return d.Dial(ntw, addr)
}
func network(o values) (string, string) {
@ -677,7 +704,7 @@ func (cn *conn) simpleQuery(q string) (res *rows, err error) {
// res might be non-nil here if we received a previous
// CommandComplete, but that's fine; just overwrite it
res = &rows{cn: cn}
res.rowsHeader = parsePortalRowDescribe(r)
res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
// To work around a bug in QueryRow in Go 1.2 and earlier, wait
// until the first DataRow has been received.
@ -834,15 +861,17 @@ func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
cn.readParseResponse()
cn.readBindResponse()
rows := &rows{cn: cn}
rows.rowsHeader = cn.readPortalDescribeResponse()
rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
cn.postExecuteWorkaround()
return rows, nil
}
st := cn.prepareTo(query, "")
st.exec(args)
return &rows{
cn: cn,
rowsHeader: st.rowsHeader,
cn: cn,
colNames: st.colNames,
colTyps: st.colTyps,
colFmts: st.colFmts,
}, nil
}
@ -963,6 +992,7 @@ func (cn *conn) recv() (t byte, r *readBuf) {
if err != nil {
panic(err)
}
switch t {
case 'E':
panic(parseError(r))
@ -1133,55 +1163,6 @@ func (cn *conn) auth(r *readBuf, o values) {
if r.int32() != 0 {
errorf("unexpected authentication response: %q", t)
}
case 10:
sc := scram.NewClient(sha256.New, o["user"], o["password"])
sc.Step(nil)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
}
scOut := sc.Out()
w := cn.writeBuf('p')
w.string("SCRAM-SHA-256")
w.int32(len(scOut))
w.bytes(scOut)
cn.send(w)
t, r := cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
}
if r.int32() != 11 {
errorf("unexpected authentication response: %q", t)
}
nextStep := r.next(len(*r))
sc.Step(nextStep)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
}
scOut = sc.Out()
w = cn.writeBuf('p')
w.bytes(scOut)
cn.send(w)
t, r = cn.recv()
if t != 'R' {
errorf("unexpected password response: %q", t)
}
if r.int32() != 12 {
errorf("unexpected authentication response: %q", t)
}
nextStep = r.next(len(*r))
sc.Step(nextStep)
if sc.Err() != nil {
errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
}
default:
errorf("unknown authentication response: %d", code)
}
@ -1199,10 +1180,12 @@ var colFmtDataAllBinary = []byte{0, 1, 0, 1}
var colFmtDataAllText = []byte{0, 0}
type stmt struct {
cn *conn
name string
rowsHeader
cn *conn
name string
colNames []string
colFmts []format
colFmtData []byte
colTyps []fieldDesc
paramTyps []oid.Oid
closed bool
}
@ -1248,8 +1231,10 @@ func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
st.exec(v)
return &rows{
cn: st.cn,
rowsHeader: st.rowsHeader,
cn: st.cn,
colNames: st.colNames,
colTyps: st.colTyps,
colFmts: st.colFmts,
}, nil
}
@ -1359,22 +1344,16 @@ func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
return driver.RowsAffected(n), commandTag
}
type rowsHeader struct {
type rows struct {
cn *conn
finish func()
colNames []string
colTyps []fieldDesc
colFmts []format
}
type rows struct {
cn *conn
finish func()
rowsHeader
done bool
rb readBuf
result driver.Result
tag string
next *rowsHeader
done bool
rb readBuf
result driver.Result
tag string
}
func (rs *rows) Close() error {
@ -1461,8 +1440,7 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
}
return
case 'T':
next := parsePortalRowDescribe(&rs.rb)
rs.next = &next
rs.colNames, rs.colFmts, rs.colTyps = parsePortalRowDescribe(&rs.rb)
return io.EOF
default:
errorf("unexpected message after execute: %q", t)
@ -1471,16 +1449,10 @@ func (rs *rows) Next(dest []driver.Value) (err error) {
}
func (rs *rows) HasNextResultSet() bool {
hasNext := rs.next != nil && !rs.done
return hasNext
return !rs.done
}
func (rs *rows) NextResultSet() error {
if rs.next == nil {
return io.EOF
}
rs.rowsHeader = *rs.next
rs.next = nil
return nil
}
@ -1503,39 +1475,6 @@ func QuoteIdentifier(name string) string {
return `"` + strings.Replace(name, `"`, `""`, -1) + `"`
}
// QuoteLiteral quotes a 'literal' (e.g. a parameter, often used to pass literal
// to DDL and other statements that do not accept parameters) to be used as part
// of an SQL statement. For example:
//
// exp_date := pq.QuoteLiteral("2023-01-05 15:00:00Z")
// err := db.Exec(fmt.Sprintf("CREATE ROLE my_user VALID UNTIL %s", exp_date))
//
// Any single quotes in name will be escaped. Any backslashes (i.e. "\") will be
// replaced by two backslashes (i.e. "\\") and the C-style escape identifier
// that PostgreSQL provides ('E') will be prepended to the string.
func QuoteLiteral(literal string) string {
// This follows the PostgreSQL internal algorithm for handling quoted literals
// from libpq, which can be found in the "PQEscapeStringInternal" function,
// which is found in the libpq/fe-exec.c source file:
// https://git.postgresql.org/gitweb/?p=postgresql.git;a=blob;f=src/interfaces/libpq/fe-exec.c
//
// substitute any single-quotes (') with two single-quotes ('')
literal = strings.Replace(literal, `'`, `''`, -1)
// determine if the string has any backslashes (\) in it.
// if it does, replace any backslashes (\) with two backslashes (\\)
// then, we need to wrap the entire string with a PostgreSQL
// C-style escape. Per how "PQEscapeStringInternal" handles this case, we
// also add a space before the "E"
if strings.Contains(literal, `\`) {
literal = strings.Replace(literal, `\`, `\\`, -1)
literal = ` E'` + literal + `'`
} else {
// otherwise, we can just wrap the literal with a pair of single quotes
literal = `'` + literal + `'`
}
return literal
}
func md5s(s string) string {
h := md5.New()
h.Write([]byte(s))
@ -1691,13 +1630,13 @@ func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames [
}
}
func (cn *conn) readPortalDescribeResponse() rowsHeader {
func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []fieldDesc) {
t, r := cn.recv1()
switch t {
case 'T':
return parsePortalRowDescribe(r)
case 'n':
return rowsHeader{}
return nil, nil, nil
case 'E':
err := parseError(r)
cn.readReadyForQuery()
@ -1803,11 +1742,11 @@ func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDe
return
}
func parsePortalRowDescribe(r *readBuf) rowsHeader {
func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []fieldDesc) {
n := r.int16()
colNames := make([]string, n)
colFmts := make([]format, n)
colTyps := make([]fieldDesc, n)
colNames = make([]string, n)
colFmts = make([]format, n)
colTyps = make([]fieldDesc, n)
for i := range colNames {
colNames[i] = r.string()
r.next(6)
@ -1816,11 +1755,7 @@ func parsePortalRowDescribe(r *readBuf) rowsHeader {
colTyps[i].Mod = r.int32()
colFmts[i] = format(r.int16())
}
return rowsHeader{
colNames: colNames,
colFmts: colFmts,
colTyps: colTyps,
}
return
}
// parseEnviron tries to mimic some of libpq's environment handling

View file

@ -1,3 +1,5 @@
// +build go1.8
package pq
import (
@ -7,7 +9,6 @@ import (
"fmt"
"io"
"io/ioutil"
"time"
)
// Implement the "QueryerContext" interface
@ -75,32 +76,13 @@ func (cn *conn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx,
return tx, nil
}
func (cn *conn) Ping(ctx context.Context) error {
if finish := cn.watchCancel(ctx); finish != nil {
defer finish()
}
rows, err := cn.simpleQuery("SELECT 'lib/pq ping test';")
if err != nil {
return driver.ErrBadConn // https://golang.org/pkg/database/sql/driver/#Pinger
}
rows.Close()
return nil
}
func (cn *conn) watchCancel(ctx context.Context) func() {
if done := ctx.Done(); done != nil {
finished := make(chan struct{})
go func() {
select {
case <-done:
// At this point the function level context is canceled,
// so it must not be used for the additional network
// request to cancel the query.
// Create a new context to pass into the dial.
ctxCancel, cancel := context.WithTimeout(context.Background(), time.Second*10)
defer cancel()
_ = cn.cancel(ctxCancel)
_ = cn.cancel()
finished <- struct{}{}
case <-finished:
}
@ -115,8 +97,8 @@ func (cn *conn) watchCancel(ctx context.Context) func() {
return nil
}
func (cn *conn) cancel(ctx context.Context) error {
c, err := dial(ctx, cn.dialer, cn.opts)
func (cn *conn) cancel() error {
c, err := dial(cn.dialer, cn.opts)
if err != nil {
return err
}

1741
vendor/github.com/lib/pq/conn_test.go generated vendored

File diff suppressed because it is too large Load diff

View file

@ -1,12 +1,10 @@
// +build go1.10
package pq
import (
"context"
"database/sql/driver"
"errors"
"fmt"
"os"
"strings"
)
// Connector represents a fixed configuration for the pq driver with a given
@ -16,95 +14,30 @@ import (
//
// See https://golang.org/pkg/database/sql/driver/#Connector.
// See https://golang.org/pkg/database/sql/#OpenDB.
type Connector struct {
opts values
dialer Dialer
type connector struct {
name string
}
// Connect returns a connection to the database using the fixed configuration
// of this Connector. Context is not used.
func (c *Connector) Connect(ctx context.Context) (driver.Conn, error) {
return c.open(ctx)
func (c *connector) Connect(_ context.Context) (driver.Conn, error) {
return (&Driver{}).Open(c.name)
}
// Driver returnst the underlying driver of this Connector.
func (c *Connector) Driver() driver.Driver {
func (c *connector) Driver() driver.Driver {
return &Driver{}
}
var _ driver.Connector = &connector{}
// NewConnector returns a connector for the pq driver in a fixed configuration
// with the given dsn. The returned connector can be used to create any number
// with the given name. The returned connector can be used to create any number
// of equivalent Conn's. The returned connector is intended to be used with
// database/sql.OpenDB.
//
// See https://golang.org/pkg/database/sql/driver/#Connector.
// See https://golang.org/pkg/database/sql/#OpenDB.
func NewConnector(dsn string) (*Connector, error) {
var err error
o := make(values)
// A number of defaults are applied here, in this order:
//
// * Very low precedence defaults applied in every situation
// * Environment variables
// * Explicitly passed connection information
o["host"] = "localhost"
o["port"] = "5432"
// N.B.: Extra float digits should be set to 3, but that breaks
// Postgres 8.4 and older, where the max is 2.
o["extra_float_digits"] = "2"
for k, v := range parseEnviron(os.Environ()) {
o[k] = v
}
if strings.HasPrefix(dsn, "postgres://") || strings.HasPrefix(dsn, "postgresql://") {
dsn, err = ParseURL(dsn)
if err != nil {
return nil, err
}
}
if err := parseOpts(dsn, o); err != nil {
return nil, err
}
// Use the "fallback" application name if necessary
if fallback, ok := o["fallback_application_name"]; ok {
if _, ok := o["application_name"]; !ok {
o["application_name"] = fallback
}
}
// We can't work with any client_encoding other than UTF-8 currently.
// However, we have historically allowed the user to set it to UTF-8
// explicitly, and there's no reason to break such programs, so allow that.
// Note that the "options" setting could also set client_encoding, but
// parsing its value is not worth it. Instead, we always explicitly send
// client_encoding as a separate run-time parameter, which should override
// anything set in options.
if enc, ok := o["client_encoding"]; ok && !isUTF8(enc) {
return nil, errors.New("client_encoding must be absent or 'UTF8'")
}
o["client_encoding"] = "UTF8"
// DateStyle needs a similar treatment.
if datestyle, ok := o["datestyle"]; ok {
if datestyle != "ISO, MDY" {
return nil, fmt.Errorf("setting datestyle must be absent or %v; got %v", "ISO, MDY", datestyle)
}
} else {
o["datestyle"] = "ISO, MDY"
}
// If a user is not provided by any other means, the last
// resort is to use the current operating system provided user
// name.
if _, ok := o["user"]; !ok {
u, err := userCurrent()
if err != nil {
return nil, err
}
o["user"] = u
}
return &Connector{opts: o, dialer: defaultDialer{}}, nil
func NewConnector(name string) (driver.Connector, error) {
return &connector{name: name}, nil
}

View file

@ -1,33 +0,0 @@
// +build go1.10
package pq_test
import (
"database/sql"
"fmt"
"github.com/lib/pq"
)
func ExampleNewConnector() {
name := ""
connector, err := pq.NewConnector(name)
if err != nil {
fmt.Println(err)
return
}
db := sql.OpenDB(connector)
if err != nil {
fmt.Println(err)
return
}
defer db.Close()
// Use the DB
txn, err := db.Begin()
if err != nil {
fmt.Println(err)
return
}
txn.Rollback()
}

View file

@ -1,67 +0,0 @@
// +build go1.10
package pq
import (
"context"
"database/sql"
"database/sql/driver"
"testing"
)
func TestNewConnector_WorksWithOpenDB(t *testing.T) {
name := ""
c, err := NewConnector(name)
if err != nil {
t.Fatal(err)
}
db := sql.OpenDB(c)
defer db.Close()
// database/sql might not call our Open at all unless we do something with
// the connection
txn, err := db.Begin()
if err != nil {
t.Fatal(err)
}
txn.Rollback()
}
func TestNewConnector_Connect(t *testing.T) {
name := ""
c, err := NewConnector(name)
if err != nil {
t.Fatal(err)
}
db, err := c.Connect(context.Background())
if err != nil {
t.Fatal(err)
}
defer db.Close()
// database/sql might not call our Open at all unless we do something with
// the connection
txn, err := db.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{})
if err != nil {
t.Fatal(err)
}
txn.Rollback()
}
func TestNewConnector_Driver(t *testing.T) {
name := ""
c, err := NewConnector(name)
if err != nil {
t.Fatal(err)
}
db, err := c.Driver().Open(name)
if err != nil {
t.Fatal(err)
}
defer db.Close()
// database/sql might not call our Open at all unless we do something with
// the connection
txn, err := db.(driver.ConnBeginTx).BeginTx(context.Background(), driver.TxOptions{})
if err != nil {
t.Fatal(err)
}
txn.Rollback()
}

468
vendor/github.com/lib/pq/copy_test.go generated vendored
View file

@ -1,468 +0,0 @@
package pq
import (
"bytes"
"database/sql"
"database/sql/driver"
"net"
"strings"
"testing"
)
func TestCopyInStmt(t *testing.T) {
stmt := CopyIn("table name")
if stmt != `COPY "table name" () FROM STDIN` {
t.Fatal(stmt)
}
stmt = CopyIn("table name", "column 1", "column 2")
if stmt != `COPY "table name" ("column 1", "column 2") FROM STDIN` {
t.Fatal(stmt)
}
stmt = CopyIn(`table " name """`, `co"lumn""`)
if stmt != `COPY "table "" name """"""" ("co""lumn""""") FROM STDIN` {
t.Fatal(stmt)
}
}
func TestCopyInSchemaStmt(t *testing.T) {
stmt := CopyInSchema("schema name", "table name")
if stmt != `COPY "schema name"."table name" () FROM STDIN` {
t.Fatal(stmt)
}
stmt = CopyInSchema("schema name", "table name", "column 1", "column 2")
if stmt != `COPY "schema name"."table name" ("column 1", "column 2") FROM STDIN` {
t.Fatal(stmt)
}
stmt = CopyInSchema(`schema " name """`, `table " name """`, `co"lumn""`)
if stmt != `COPY "schema "" name """"""".`+
`"table "" name """"""" ("co""lumn""""") FROM STDIN` {
t.Fatal(stmt)
}
}
func TestCopyInMultipleValues(t *testing.T) {
db := openTestConn(t)
defer db.Close()
txn, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer txn.Rollback()
_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
if err != nil {
t.Fatal(err)
}
stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
if err != nil {
t.Fatal(err)
}
longString := strings.Repeat("#", 500)
for i := 0; i < 500; i++ {
_, err = stmt.Exec(int64(i), longString)
if err != nil {
t.Fatal(err)
}
}
_, err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
var num int
err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
if err != nil {
t.Fatal(err)
}
if num != 500 {
t.Fatalf("expected 500 items, not %d", num)
}
}
func TestCopyInRaiseStmtTrigger(t *testing.T) {
db := openTestConn(t)
defer db.Close()
if getServerVersion(t, db) < 90000 {
var exists int
err := db.QueryRow("SELECT 1 FROM pg_language WHERE lanname = 'plpgsql'").Scan(&exists)
if err == sql.ErrNoRows {
t.Skip("language PL/PgSQL does not exist; skipping TestCopyInRaiseStmtTrigger")
} else if err != nil {
t.Fatal(err)
}
}
txn, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer txn.Rollback()
_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
if err != nil {
t.Fatal(err)
}
_, err = txn.Exec(`
CREATE OR REPLACE FUNCTION pg_temp.temptest()
RETURNS trigger AS
$BODY$ begin
raise notice 'Hello world';
return new;
end $BODY$
LANGUAGE plpgsql`)
if err != nil {
t.Fatal(err)
}
_, err = txn.Exec(`
CREATE TRIGGER temptest_trigger
BEFORE INSERT
ON temp
FOR EACH ROW
EXECUTE PROCEDURE pg_temp.temptest()`)
if err != nil {
t.Fatal(err)
}
stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
if err != nil {
t.Fatal(err)
}
longString := strings.Repeat("#", 500)
_, err = stmt.Exec(int64(1), longString)
if err != nil {
t.Fatal(err)
}
_, err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
var num int
err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
if err != nil {
t.Fatal(err)
}
if num != 1 {
t.Fatalf("expected 1 items, not %d", num)
}
}
func TestCopyInTypes(t *testing.T) {
db := openTestConn(t)
defer db.Close()
txn, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer txn.Rollback()
_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER, text VARCHAR, blob BYTEA, nothing VARCHAR)")
if err != nil {
t.Fatal(err)
}
stmt, err := txn.Prepare(CopyIn("temp", "num", "text", "blob", "nothing"))
if err != nil {
t.Fatal(err)
}
_, err = stmt.Exec(int64(1234567890), "Héllö\n ☃!\r\t\\", []byte{0, 255, 9, 10, 13}, nil)
if err != nil {
t.Fatal(err)
}
_, err = stmt.Exec()
if err != nil {
t.Fatal(err)
}
err = stmt.Close()
if err != nil {
t.Fatal(err)
}
var num int
var text string
var blob []byte
var nothing sql.NullString
err = txn.QueryRow("SELECT * FROM temp").Scan(&num, &text, &blob, &nothing)
if err != nil {
t.Fatal(err)
}
if num != 1234567890 {
t.Fatal("unexpected result", num)
}
if text != "Héllö\n ☃!\r\t\\" {
t.Fatal("unexpected result", text)
}
if !bytes.Equal(blob, []byte{0, 255, 9, 10, 13}) {
t.Fatal("unexpected result", blob)
}
if nothing.Valid {
t.Fatal("unexpected result", nothing.String)
}
}
func TestCopyInWrongType(t *testing.T) {
db := openTestConn(t)
defer db.Close()
txn, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer txn.Rollback()
_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
if err != nil {
t.Fatal(err)
}
stmt, err := txn.Prepare(CopyIn("temp", "num"))
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
_, err = stmt.Exec("Héllö\n ☃!\r\t\\")
if err != nil {
t.Fatal(err)
}
_, err = stmt.Exec()
if err == nil {
t.Fatal("expected error")
}
if pge := err.(*Error); pge.Code.Name() != "invalid_text_representation" {
t.Fatalf("expected 'invalid input syntax for integer' error, got %s (%+v)", pge.Code.Name(), pge)
}
}
func TestCopyOutsideOfTxnError(t *testing.T) {
db := openTestConn(t)
defer db.Close()
_, err := db.Prepare(CopyIn("temp", "num"))
if err == nil {
t.Fatal("COPY outside of transaction did not return an error")
}
if err != errCopyNotSupportedOutsideTxn {
t.Fatalf("expected %s, got %s", err, err.Error())
}
}
func TestCopyInBinaryError(t *testing.T) {
db := openTestConn(t)
defer db.Close()
txn, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer txn.Rollback()
_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
if err != nil {
t.Fatal(err)
}
_, err = txn.Prepare("COPY temp (num) FROM STDIN WITH binary")
if err != errBinaryCopyNotSupported {
t.Fatalf("expected %s, got %+v", errBinaryCopyNotSupported, err)
}
// check that the protocol is in a valid state
err = txn.Rollback()
if err != nil {
t.Fatal(err)
}
}
func TestCopyFromError(t *testing.T) {
db := openTestConn(t)
defer db.Close()
txn, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer txn.Rollback()
_, err = txn.Exec("CREATE TEMP TABLE temp (num INTEGER)")
if err != nil {
t.Fatal(err)
}
_, err = txn.Prepare("COPY temp (num) TO STDOUT")
if err != errCopyToNotSupported {
t.Fatalf("expected %s, got %+v", errCopyToNotSupported, err)
}
// check that the protocol is in a valid state
err = txn.Rollback()
if err != nil {
t.Fatal(err)
}
}
func TestCopySyntaxError(t *testing.T) {
db := openTestConn(t)
defer db.Close()
txn, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer txn.Rollback()
_, err = txn.Prepare("COPY ")
if err == nil {
t.Fatal("expected error")
}
if pge := err.(*Error); pge.Code.Name() != "syntax_error" {
t.Fatalf("expected syntax error, got %s (%+v)", pge.Code.Name(), pge)
}
// check that the protocol is in a valid state
err = txn.Rollback()
if err != nil {
t.Fatal(err)
}
}
// Tests for connection errors in copyin.resploop()
func TestCopyRespLoopConnectionError(t *testing.T) {
db := openTestConn(t)
defer db.Close()
txn, err := db.Begin()
if err != nil {
t.Fatal(err)
}
defer txn.Rollback()
var pid int
err = txn.QueryRow("SELECT pg_backend_pid()").Scan(&pid)
if err != nil {
t.Fatal(err)
}
_, err = txn.Exec("CREATE TEMP TABLE temp (a int)")
if err != nil {
t.Fatal(err)
}
stmt, err := txn.Prepare(CopyIn("temp", "a"))
if err != nil {
t.Fatal(err)
}
defer stmt.Close()
_, err = db.Exec("SELECT pg_terminate_backend($1)", pid)
if err != nil {
t.Fatal(err)
}
if getServerVersion(t, db) < 90500 {
// We have to try and send something over, since postgres before
// version 9.5 won't process SIGTERMs while it's waiting for
// CopyData/CopyEnd messages; see tcop/postgres.c.
_, err = stmt.Exec(1)
if err != nil {
t.Fatal(err)
}
}
_, err = stmt.Exec()
if err == nil {
t.Fatalf("expected error")
}
switch pge := err.(type) {
case *Error:
if pge.Code.Name() != "admin_shutdown" {
t.Fatalf("expected admin_shutdown, got %s", pge.Code.Name())
}
case *net.OpError:
// ignore
default:
if err == driver.ErrBadConn {
// likely an EPIPE
} else {
t.Fatalf("unexpected error, got %+#v", err)
}
}
_ = stmt.Close()
}
func BenchmarkCopyIn(b *testing.B) {
db := openTestConn(b)
defer db.Close()
txn, err := db.Begin()
if err != nil {
b.Fatal(err)
}
defer txn.Rollback()
_, err = txn.Exec("CREATE TEMP TABLE temp (a int, b varchar)")
if err != nil {
b.Fatal(err)
}
stmt, err := txn.Prepare(CopyIn("temp", "a", "b"))
if err != nil {
b.Fatal(err)
}
for i := 0; i < b.N; i++ {
_, err = stmt.Exec(int64(i), "hello world!")
if err != nil {
b.Fatal(err)
}
}
_, err = stmt.Exec()
if err != nil {
b.Fatal(err)
}
err = stmt.Close()
if err != nil {
b.Fatal(err)
}
var num int
err = txn.QueryRow("SELECT COUNT(*) FROM temp").Scan(&num)
if err != nil {
b.Fatal(err)
}
if num != b.N {
b.Fatalf("expected %d items, not %d", b.N, num)
}
}

2
vendor/github.com/lib/pq/doc.go generated vendored
View file

@ -239,7 +239,7 @@ for more information). Note that the channel name will be truncated to 63
bytes by the PostgreSQL server.
You can find a complete, working example of Listener usage at
https://godoc.org/github.com/lib/pq/example/listen.
http://godoc.org/github.com/lib/pq/example/listen.
*/
package pq

Some files were not shown because too many files have changed in this diff Show more