1
0
Fork 0
mirror of https://github.com/documize/community.git synced 2025-07-19 13:19:43 +02:00

Bump Go deps

This commit is contained in:
Harvey Kandola 2024-02-19 11:54:27 -05:00
parent f2ba294be8
commit acb59e1b43
91 changed files with 9004 additions and 513 deletions

View file

@ -1,5 +1,19 @@
# Changelog
## 1.6.0
### Changed
* Go.mod updated to Go 1.17
* Azure SDK for Go dependencies updated
### Features
* Added `ActiveDirectoryAzCli` and `ActiveDirectoryDeviceCode` authentication types to `azuread` package
* Always Encrypted encryption and decryption with 2 hour key cache (#116)
* 'pfx', 'MSSQL_CERTIFICATE_STORE', and 'AZURE_KEY_VAULT' encryption key providers
* TDS8 can now be used for connections by setting encrypt="strict"
## 1.5.0
### Features

View file

@ -1,4 +1,4 @@
# A pure Go MSSQL driver for Go's database/sql package
# Microsoft's official Go MSSQL driver
[![Go Reference](https://pkg.go.dev/badge/github.com/microsoft/go-mssqldb.svg)](https://pkg.go.dev/github.com/microsoft/go-mssqldb)
[![Build status](https://ci.appveyor.com/api/projects/status/jrln8cs62wj9i0a2?svg=true)](https://ci.appveyor.com/project/microsoft/go-mssqldb)
@ -7,7 +7,7 @@
## Install
Requires Go 1.10 or above.
Requires Go 1.17 or above.
Install with `go install github.com/microsoft/go-mssqldb@latest`.
@ -25,9 +25,10 @@ Other supported formats are listed below.
* `connection timeout` - in seconds (default is 0 for no timeout), set to 0 for no timeout. Recommended to set to 0 and use context to manage query and connection timeouts.
* `dial timeout` - in seconds (default is 15 times the number of registered protocols), set to 0 for no timeout.
* `encrypt`
* `strict` - Data sent between client and server is encrypted E2E using [TDS8](https://learn.microsoft.com/en-us/sql/relational-databases/security/networking/tds-8?view=sql-server-ver16).
* `disable` - Data send between client and server is not encrypted.
* `false` - Data sent between client and server is not encrypted beyond the login packet. (Default)
* `true` - Data sent between client and server is encrypted.
* `false`/`optional`/`no`/`0`/`f` - Data sent between client and server is not encrypted beyond the login packet. (Default)
* `true`/`mandatory`/`yes`/`1`/`t` - Data sent between client and server is encrypted.
* `app name` - The application name (default is go-mssqldb)
* `authenticator` - Can be used to specify use of a registered authentication provider. (e.g. ntlm, winsspi (on windows) or krb5 (on linux))
@ -56,13 +57,14 @@ Other supported formats are listed below.
* `TrustServerCertificate`
* false - Server certificate is checked. Default is false if encrypt is specified.
* true - Server certificate is not checked. Default is true if encrypt is not specified. If trust server certificate is true, driver accepts any certificate presented by the server and any host name in that certificate. In this mode, TLS is susceptible to man-in-the-middle attacks. This should be used only for testing.
* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates.
* `certificate` - The file that contains the public key certificate of the CA that signed the SQL Server certificate. The specified certificate overrides the go platform specific CA certificates. Currently, certificates of PEM type are supported.
* `hostNameInCertificate` - Specifies the Common Name (CN) in the server certificate. Default value is the server host.
* `tlsmin` - Specifies the minimum TLS version for negotiating encryption with the server. Recognized values are `1.0`, `1.1`, `1.2`, `1.3`. If not set to a recognized value the default value for the `tls` package will be used. The default is currently `1.2`.
* `ServerSPN` - The kerberos SPN (Service Principal Name) for the server. Default is MSSQLSvc/host:port.
* `Workstation ID` - The workstation name (default is the host name)
* `ApplicationIntent` - Can be given the value `ReadOnly` to initiate a read-only connection to an Availability Group listener. The `database` must be specified when connecting with `Application Intent` set to `ReadOnly`.
* `protocol` - forces use of a protocol. Make sure the corresponding package is imported.
* `columnencryption` or `column encryption setting` - a boolean value indicating whether Always Encrypted should be enabled on the connection.
### Connection parameters for namedpipe package
* `pipe` - If set, no Browser query is made and named pipe used will be `\\<host>\pipe\<pipe>`
@ -216,6 +218,8 @@ The credential type is determined by the new `fedauth` connection string paramet
* `resource id=<resource id>` - optional resource id of user-assigned managed identity. If empty, system-assigned managed identity or user id are used (if both user id and resource id are provided, resource id will be used)
* `fedauth=ActiveDirectoryInteractive` - authenticates using credentials acquired from an external web browser. Only suitable for use with human interaction.
* `applicationclientid=<application id>` - This guid identifies an Azure Active Directory enterprise application that the AAD admin has approved for accessing Azure SQL database resources in the tenant. This driver does not have an associated application id of its own.
* `fedauth=ActiveDirectoryDeviceCode` - prints a message to stdout giving the user a URL and code to authenticate. Connection continues after user completes the login separately.
* `fedauth=ActiveDirectoryAzCli` - reuses local authentication the user already performed using Azure CLI.
```go
@ -377,8 +381,63 @@ db.QueryContext(ctx, `select * from t2 where user_name = @p1;`, mssql.VarChar(na
// Note: Mismatched data types on table and parameter may cause long running queries
```
## Using Always Encrypted
The protocol and cryptography details for AE are [detailed elsewhere](https://learn.microsoft.com/sql/relational-databases/security/encryption/always-encrypted-database-engine?view=sql-server-ver16).
### Enablement
To enable AE on a connection, set the `ColumnEncryption` value to true on a config or pass `columnencryption=true` in the connection string.
Decryption and encryption won't succeed, however, without also including a decryption key provider. To avoid code size impacts on non-AE applications, key providers are not included by default.
Include the local certificate providers:
```go
import (
"github.com/microsoft/go-mssqldb/aecmk/localcert"
)
```
You can also instantiate a key provider directly in code and hand it to a `Connector` instance.
```go
c := mssql.NewConnectorConfig(myconfig)
c.RegisterCekProvider(providerName, MyProviderType{})
```
### Decryption
If the correct key provider is included in your application, decryption of encrypted cells happens automatically with no extra server round trips.
### Encryption
Encryption of parameters passed to `Exec` and `Query` variants requires an extra round trip per query to fetch the encryption metadata. If the error returned by a query attempt indicates a type mismatch between the parameter and the destination table, most likely your input type is not a strict match for the SQL Server data type of the destination. You may be using a Go `string` when you need to use one of the driver-specific aliases like `VarChar` or `NVarCharMax`.
*** NOTE *** - Currently `char` and `varchar` types do not include a collation parameter component so can't be used for inserting encrypted values. Also, using a nullable sql package type like `sql.NullableInt32` to pass a `NULL` value for an encrypted column will not work unless the encrypted column type is `nvarchar`.
https://github.com/microsoft/go-mssqldb/issues/129
https://github.com/microsoft/go-mssqldb/issues/130
### Local certificate AE key provider
Key provider configuration is managed separately without any properties in the connection string.
The `pfx` provider exposes its instance as the variable `PfxKeyProvider`. You can give it passwords for certificates using `SetCertificatePassword(pathToCertificate, path)`. Use an empty string or `"*"` as the path to use the same password for all certificates.
The `MSSQL_CERTIFICATE_STORE` provider exposes its instance as the variable `WindowsCertificateStoreKeyProvider`.
Both providers can be constrained to an allowed list of encryption key paths by appending paths to `provider.AllowedLocations`.
### Azure Key Vault (AZURE_KEY_VAULT) key provider
Import this provider using `github.com/microsoft/go-mssqldb/aecmk/akv`
Constrain the provider to an allowed list of key vaults by appending vault host strings like "mykeyvault.vault.azure.net" to `akv.KeyProvider.AllowedLocations`.
## Important Notes
* [LastInsertId](https://golang.org/pkg/database/sql/#Result.LastInsertId) should
not be used with this driver (or SQL Server) due to how the TDS protocol
works. Please use the [OUTPUT Clause](https://docs.microsoft.com/en-us/sql/t-sql/queries/output-clause-transact-sql)
@ -409,6 +468,9 @@ db.QueryContext(ctx, `select * from t2 where user_name = @p1;`, mssql.VarChar(na
* A `namedpipe` package to support connections using named pipes (np:) on Windows
* A `sharedmemory` package to support connections using shared memory (lpc:) on Windows
* Dedicated Administrator Connection (DAC) is supported using `admin` protocol
* Always Encrypted
- `MSSQL_CERTIFICATE_STORE` provider on Windows
- `pfx` provider on Linux and Windows
## Tests
@ -449,6 +511,7 @@ To fix SQL Server 2008 R2 issue, install SQL Server 2008 R2 Service Pack 2.
To fix SQL Server 2008 issue, install Microsoft SQL Server 2008 Service Pack 3 and Cumulative update package 3 for SQL Server 2008 SP3.
More information: <http://support.microsoft.com/kb/2653857>
* Bulk copy does not yet support encrypting column values using Always Encrypted. Tracked in [#127](https://github.com/microsoft/go-mssqldb/issues/127)
# Contributing
This project is a fork of [https://github.com/denisenkom/go-mssqldb](https://github.com/denisenkom/go-mssqldb) and welcomes new and previous contributors. For more informaton on contributing to this project, please see [Contributing](./CONTRIBUTING.md).

View file

@ -0,0 +1,112 @@
package aecmk
import (
"fmt"
"sync"
"time"
)
const (
CertificateStoreKeyProvider = "MSSQL_CERTIFICATE_STORE"
CspKeyProvider = "MSSQL_CSP_PROVIDER"
CngKeyProvider = "MSSQL_CNG_STORE"
AzureKeyVaultKeyProvider = "AZURE_KEY_VAULT"
JavaKeyProvider = "MSSQL_JAVA_KEYSTORE"
KeyEncryptionAlgorithm = "RSA_OAEP"
)
// ColumnEncryptionKeyLifetime is the default lifetime of decrypted Column Encryption Keys in the global cache.
// The default is 2 hours
var ColumnEncryptionKeyLifetime time.Duration = 2 * time.Hour
type cekCacheEntry struct {
Expiry time.Time
Key []byte
}
type cekCache map[string]cekCacheEntry
type CekProvider struct {
Provider ColumnEncryptionKeyProvider
decryptedKeys cekCache
mutex sync.Mutex
}
func NewCekProvider(provider ColumnEncryptionKeyProvider) *CekProvider {
return &CekProvider{Provider: provider, decryptedKeys: make(cekCache), mutex: sync.Mutex{}}
}
func (cp *CekProvider) GetDecryptedKey(keyPath string, encryptedBytes []byte) (decryptedKey []byte, err error) {
cp.mutex.Lock()
ev, cachedKey := cp.decryptedKeys[keyPath]
if cachedKey {
if ev.Expiry.Before(time.Now()) {
delete(cp.decryptedKeys, keyPath)
cachedKey = false
} else {
decryptedKey = ev.Key
}
}
// decrypting a key can take a while, so let multiple callers race
// Key providers can choose to optimize their own concurrency.
// For example - there's probably minimal value in serializing access to a local certificate,
// but there'd be high value in having a queue of waiters for decrypting a key stored in the cloud.
cp.mutex.Unlock()
if !cachedKey {
decryptedKey = cp.Provider.DecryptColumnEncryptionKey(keyPath, KeyEncryptionAlgorithm, encryptedBytes)
}
if !cachedKey {
duration := cp.Provider.KeyLifetime()
if duration == nil {
duration = &ColumnEncryptionKeyLifetime
}
expiry := time.Now().Add(*duration)
cp.mutex.Lock()
cp.decryptedKeys[keyPath] = cekCacheEntry{Expiry: expiry, Key: decryptedKey}
cp.mutex.Unlock()
}
return
}
// no synchronization on this map. Providers register during init.
type ColumnEncryptionKeyProviderMap map[string]*CekProvider
var globalCekProviderFactoryMap = ColumnEncryptionKeyProviderMap{}
// ColumnEncryptionKeyProvider is the interface for decrypting and encrypting column encryption keys.
// It is similar to .Net https://learn.microsoft.com/dotnet/api/microsoft.data.sqlclient.sqlcolumnencryptionkeystoreprovider.
type ColumnEncryptionKeyProvider interface {
// DecryptColumnEncryptionKey decrypts the specified encrypted value of a column encryption key.
// The encrypted value is expected to be encrypted using the column master key with the specified key path and using the specified algorithm.
DecryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, encryptedCek []byte) []byte
// EncryptColumnEncryptionKey encrypts a column encryption key using the column master key with the specified key path and using the specified algorithm.
EncryptColumnEncryptionKey(masterKeyPath string, encryptionAlgorithm string, cek []byte) []byte
// SignColumnMasterKeyMetadata digitally signs the column master key metadata with the column master key
// referenced by the masterKeyPath parameter. The input values used to generate the signature should be the
// specified values of the masterKeyPath and allowEnclaveComputations parameters. May return an empty slice if not supported.
SignColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) []byte
// VerifyColumnMasterKeyMetadata verifies the specified signature is valid for the column master key
// with the specified key path and the specified enclave behavior. Return nil if not supported.
VerifyColumnMasterKeyMetadata(masterKeyPath string, allowEnclaveComputations bool) *bool
// KeyLifetime is an optional Duration. Keys fetched by this provider will be discarded after their lifetime expires.
// If it returns nil, the keys will expire based on the value of ColumnEncryptionKeyLifetime.
// If it returns zero, the keys will not be cached.
KeyLifetime() *time.Duration
}
func RegisterCekProvider(name string, provider ColumnEncryptionKeyProvider) error {
_, ok := globalCekProviderFactoryMap[name]
if ok {
return fmt.Errorf("CEK provider %s is already registered", name)
}
globalCekProviderFactoryMap[name] = &CekProvider{Provider: provider, decryptedKeys: cekCache{}, mutex: sync.Mutex{}}
return nil
}
func GetGlobalCekProviders() (providers ColumnEncryptionKeyProviderMap) {
providers = make(ColumnEncryptionKeyProviderMap)
for i, p := range globalCekProviderFactoryMap {
providers[i] = p
}
return
}

View file

@ -11,52 +11,29 @@ environment:
SQLUSER: sa
SQLPASSWORD: Password12!
DATABASE: test
GOVERSION: 113
GOVERSION: 117
COLUMNENCRYPTION:
APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
RACE: -race -cpu 4
TAGS:
matrix:
- GOVERSION: 110
SQLINSTANCE: SQL2017
- GOVERSION: 111
SQLINSTANCE: SQL2017
- GOVERSION: 112
SQLINSTANCE: SQL2017
- SQLINSTANCE: SQL2017
- SQLINSTANCE: SQL2016
- SQLINSTANCE: SQL2014
- SQLINSTANCE: SQL2012SP1
- SQLINSTANCE: SQL2008R2SP2
# Go 1.14+ and SQL2019 are available on the Visual Studio 2019 image only
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 114
- GOVERSION: 118
SQLINSTANCE: SQL2017
- GOVERSION: 120
RACE:
SQLINSTANCE: SQL2019
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 115
SQLINSTANCE: SQL2019
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 115
SQLINSTANCE: SQL2017
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 116
SQLINSTANCE: SQL2017
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 117
SQLINSTANCE: SQL2017
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 118
SQLINSTANCE: SQL2017
COLUMNENCRYPTION: 1
# Cover 32bit and named pipes protocol
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 118-x86
- GOVERSION: 119-x86
SQLINSTANCE: SQL2017
GOARCH: 386
RACE:
PROTOCOL: np
TAGS: -tags np
# Cover SSPI and lpc protocol
- APPVEYOR_BUILD_WORKER_IMAGE: Visual Studio 2019
GOVERSION: 118
- GOVERSION: 120
RACE:
SQLINSTANCE: SQL2019
PROTOCOL: lpc
TAGS: -tags sm
@ -67,9 +44,6 @@ install:
- set PATH=%GOPATH%\bin;%GOROOT%\bin;%PATH%
- go version
- go env
- go get -u github.com/golang-sql/civil
- go get -u github.com/golang-sql/sqlexp
- go get -u golang.org/x/crypto/md4
build_script:
- go build

View file

@ -250,6 +250,10 @@ func (b *Bulk) createColMetadata() []byte {
buf.WriteByte(byte(tokenColMetadata)) // token
binary.Write(buf, binary.LittleEndian, uint16(len(b.bulkColumns))) // column count
// TODO: Write a valid CEK table if any parameters have cekTableEntry values
if b.cn.sess.alwaysEncrypted {
binary.Write(buf, binary.LittleEndian, uint16(0))
}
for i, col := range b.bulkColumns {
if b.cn.sess.loginAck.TDSVersion >= verTDS72 {

View file

@ -0,0 +1,40 @@
package mssql
const (
CertificateStoreKeyProvider = "MSSQL_CERTIFICATE_STORE"
CspKeyProvider = "MSSQL_CSP_PROVIDER"
CngKeyProvider = "MSSQL_CNG_STORE"
AzureKeyVaultKeyProvider = "AZURE_KEY_VAULT"
JavaKeyProvider = "MSSQL_JAVA_KEYSTORE"
KeyEncryptionAlgorithm = "RSA_OAEP"
)
// cek ==> Column Encryption Key
// Every row of an encrypted table has an associated list of keys used to decrypt its columns
type cekTable struct {
entries []cekTableEntry
}
type encryptionKeyInfo struct {
encryptedKey []byte
databaseID int
cekID int
cekVersion int
cekMdVersion []byte
keyPath string
keyStoreName string
algorithmName string
}
type cekTableEntry struct {
databaseID int
keyId int
keyVersion int
mdVersion []byte
valueCount int
cekValues []encryptionKeyInfo
}
func newCekTable(size uint16) cekTable {
return cekTable{entries: make([]cekTableEntry, size)}
}

292
vendor/github.com/microsoft/go-mssqldb/encrypt.go generated vendored Normal file
View file

@ -0,0 +1,292 @@
package mssql
import (
"context"
"database/sql/driver"
"encoding/binary"
"fmt"
"io"
"strings"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys"
)
type ColumnEncryptionType int
var (
ColumnEncryptionPlainText ColumnEncryptionType = 0
ColumnEncryptionDeterministic ColumnEncryptionType = 1
ColumnEncryptionRandomized ColumnEncryptionType = 2
)
type cekData struct {
ordinal int
database_id int
id int
version int
metadataVersion []byte
encryptedValue []byte
cmkStoreName string
cmkPath string
algorithm string
//byEnclave bool
//cmkSignature string
decryptedValue []byte
}
type parameterEncData struct {
ordinal int
name string
algorithm int
encType ColumnEncryptionType
cekOrdinal int
ruleVersion int
}
type paramMapEntry struct {
cek *cekData
p *parameterEncData
}
// when Always Encrypted is turned on, we have to ask the server for metadata about how to encrypt input parameters.
// This function stores the relevant encryption parameters in a copy of the args so they can be
// encrypted just before being sent to the server
func (s *Stmt) encryptArgs(ctx context.Context, args []namedValue) (encryptedArgs []namedValue, err error) {
q := Stmt{c: s.c,
paramCount: s.paramCount,
query: "sp_describe_parameter_encryption",
skipEncryption: true,
}
oldouts := s.c.outs
s.c.clearOuts()
newArgs, err := s.prepareEncryptionQuery(isProc(s.query), s.query, args)
if err != nil {
return
}
// TODO: Consider not using recursion.
rows, err := q.queryContext(ctx, newArgs)
if err != nil {
s.c.outs = oldouts
return
}
cekInfo, paramsInfo, err := processDescribeParameterEncryption(rows)
rows.Close()
s.c.outs = oldouts
if err != nil {
return
}
if len(cekInfo) == 0 {
return args, nil
}
err = s.decryptCek(cekInfo)
if err != nil {
return
}
paramMap := make(map[string]paramMapEntry)
for _, p := range paramsInfo {
if p.encType == ColumnEncryptionPlainText {
paramMap[p.name] = paramMapEntry{nil, p}
} else {
paramMap[p.name] = paramMapEntry{cekInfo[p.cekOrdinal-1], p}
}
}
encryptedArgs = make([]namedValue, len(args))
for i, a := range args {
encryptedArgs[i] = a
name := ""
if len(a.Name) > 0 {
name = "@" + a.Name
} else {
name = fmt.Sprintf("@p%d", a.Ordinal)
}
info := paramMap[name]
if info.p.encType == ColumnEncryptionPlainText || a.Value == nil {
continue
}
encryptedArgs[i].encrypt = getEncryptor(info)
}
return encryptedArgs, nil
}
// returns the arguments to sp_describe_parameter_encryption
// sp_describe_parameter_encryption
// [ @tsql = ] N'Transact-SQL_batch' ,
// [ @params = ] N'parameters'
// [ ;]
func (s *Stmt) prepareEncryptionQuery(isProc bool, q string, args []namedValue) (newArgs []namedValue, err error) {
newArgs = make([]namedValue, 2)
if isProc {
newArgs[0] = namedValue{Name: "tsql", Ordinal: 0, Value: buildStoredProcedureStatementForColumnEncryption(q, args)}
} else {
newArgs[0] = namedValue{Name: "tsql", Ordinal: 0, Value: q}
}
params, err := s.buildParametersForColumnEncryption(args)
if err != nil {
return
}
newArgs[1] = namedValue{Name: "params", Ordinal: 1, Value: params}
return
}
func (s *Stmt) buildParametersForColumnEncryption(args []namedValue) (parameters string, err error) {
_, decls, err := s.makeRPCParams(args, false)
if err != nil {
return
}
parameters = strings.Join(decls, ", ")
return
}
func (s *Stmt) decryptCek(cekInfo []*cekData) error {
for _, info := range cekInfo {
kp, ok := s.c.sess.aeSettings.keyProviders[info.cmkStoreName]
if !ok {
return fmt.Errorf("No provider found for key store %s", info.cmkStoreName)
}
dk, err := kp.GetDecryptedKey(info.cmkPath, info.encryptedValue)
if err != nil {
return err
}
info.decryptedValue = dk
}
return nil
}
func getEncryptor(info paramMapEntry) valueEncryptor {
k := keys.NewAeadAes256CbcHmac256(info.cek.decryptedValue)
alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encryption.From(byte(info.p.encType)), byte(info.cek.version))
// Metadata to append to an encrypted parameter. Doesn't include original typeinfo
// https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/619c43b6-9495-4a58-9e49-a4950db245b3
// ParamCipherInfo = TYPE_INFO
// EncryptionAlgo (byte)
// [AlgoName] (b_varchar) unused, no custom algorithm
// EncryptionType (byte)
// DatabaseId (ulong)
// CekId (ulong)
// CekVersion (ulong)
// CekMDVersion (ulonglong) - really a byte array
// NormVersion (byte)
// algo+ enctype+ dbid+ keyid+ keyver+ normversion
metadataLen := 1 + 1 + 4 + 4 + 4 + 1
metadataLen += len(info.cek.metadataVersion)
metadata := make([]byte, metadataLen)
offset := 0
// AEAD_AES_256_CBC_HMAC_SHA256
metadata[offset] = byte(info.p.algorithm)
offset++
metadata[offset] = byte(info.p.encType)
offset++
binary.LittleEndian.PutUint32(metadata[offset:], uint32(info.cek.database_id))
offset += 4
binary.LittleEndian.PutUint32(metadata[offset:], uint32(info.cek.id))
offset += 4
binary.LittleEndian.PutUint32(metadata[offset:], uint32(info.cek.version))
offset += 4
copy(metadata[offset:], info.cek.metadataVersion)
offset += len(info.cek.metadataVersion)
metadata[offset] = byte(info.p.ruleVersion)
return func(b []byte) ([]byte, []byte, error) {
encryptedData, err := alg.Encrypt(b)
if err != nil {
return nil, nil, err
}
return encryptedData, metadata, nil
}
}
// Based on the .Net implementation at https://github.com/dotnet/SqlClient/blob/2b31810ce69b88d707450e2059ee8fbde63f774f/src/Microsoft.Data.SqlClient/netcore/src/Microsoft/Data/SqlClient/SqlCommand.cs#L6040
func buildStoredProcedureStatementForColumnEncryption(sproc string, args []namedValue) string {
b := new(strings.Builder)
_, _ = b.WriteString("EXEC ")
q := TSQLQuoter{}
sproc = q.ID(sproc)
b.WriteString(sproc)
// Unlike ADO.Net, go-mssqldb doesn't support ReturnValue named parameters
first := true
for _, a := range args {
if !first {
b.WriteRune(',')
}
first = false
b.WriteRune(' ')
name := a.Name
if len(name) == 0 {
name = fmt.Sprintf("@p%d", a.Ordinal)
}
appendPrefixedParameterName(b, name)
if len(a.Name) > 0 {
b.WriteRune('=')
appendPrefixedParameterName(b, a.Name)
}
if isOutputValue(a.Value) {
b.WriteString(" OUTPUT")
}
}
return b.String()
}
func appendPrefixedParameterName(b *strings.Builder, p string) {
if len(p) > 0 {
if p[0] != '@' {
b.WriteRune('@')
}
b.WriteString(p)
}
}
func processDescribeParameterEncryption(rows driver.Rows) (cekInfo []*cekData, paramInfo []*parameterEncData, err error) {
cekInfo = make([]*cekData, 0)
values := make([]driver.Value, 9)
qerr := rows.Next(values)
for qerr == nil {
cekInfo = append(cekInfo, &cekData{ordinal: int(values[0].(int64)),
database_id: int(values[1].(int64)),
id: int(values[2].(int64)),
version: int(values[3].(int64)),
metadataVersion: values[4].([]byte),
encryptedValue: values[5].([]byte),
cmkStoreName: values[6].(string),
cmkPath: values[7].(string),
algorithm: values[8].(string),
})
qerr = rows.Next(values)
}
if len(cekInfo) == 0 || qerr != io.EOF {
if qerr != io.EOF {
err = qerr
}
// No encryption needed
return
}
r := rows.(driver.RowsNextResultSet)
err = r.NextResultSet()
if err != nil {
return
}
paramInfo = make([]*parameterEncData, 0)
qerr = rows.Next(values[:6])
for qerr == nil {
paramInfo = append(paramInfo, &parameterEncData{ordinal: int(values[0].(int64)),
name: values[1].(string),
algorithm: int(values[2].(int64)),
encType: ColumnEncryptionType(values[3].(int64)),
cekOrdinal: int(values[4].(int64)),
ruleVersion: int(values[5].(int64)),
})
qerr = rows.Next(values[:6])
}
if len(paramInfo) == 0 || qerr != io.EOF {
if qerr != io.EOF {
err = qerr
} else {
err = fmt.Errorf("No parameter encryption rows were returned from sp_describe_parameter_encryption")
}
}
return
}

View file

@ -0,0 +1,20 @@
Copyright (c) 2021 Swisscom (Switzerland) Ltd
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

View file

@ -0,0 +1,120 @@
package algorithms
import (
"crypto/rand"
"crypto/subtle"
"fmt"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys"
)
// https://tools.ietf.org/html/draft-mcgrew-aead-aes-cbc-hmac-sha2-05
// https://winprotocoldoc.blob.core.windows.net/productionwindowsarchives/MS-TDS/%5bMS-TDS%5d.pdf
var _ Algorithm = &AeadAes256CbcHmac256Algorithm{}
type AeadAes256CbcHmac256Algorithm struct {
algorithmVersion byte
deterministic bool
blockSizeBytes int
keySizeBytes int
minimumCipherTextLengthBytesNoAuthTag int
minimumCipherTextLengthBytesWithAuthTag int
cek keys.AeadAes256CbcHmac256
version []byte
versionSize []byte
}
func NewAeadAes256CbcHmac256Algorithm(key keys.AeadAes256CbcHmac256, encType encryption.Type, algorithmVersion byte) AeadAes256CbcHmac256Algorithm {
const keySizeBytes = 256 / 8
const blockSizeBytes = 16
const minimumCipherTextLengthBytesNoAuthTag = 1 + 2*blockSizeBytes
const minimumCipherTextLengthBytesWithAuthTag = minimumCipherTextLengthBytesNoAuthTag + keySizeBytes
a := AeadAes256CbcHmac256Algorithm{
algorithmVersion: algorithmVersion,
deterministic: encType.Deterministic,
blockSizeBytes: blockSizeBytes,
keySizeBytes: keySizeBytes,
cek: key,
minimumCipherTextLengthBytesNoAuthTag: minimumCipherTextLengthBytesNoAuthTag,
minimumCipherTextLengthBytesWithAuthTag: minimumCipherTextLengthBytesWithAuthTag,
version: []byte{0x01},
versionSize: []byte{1},
}
a.version[0] = algorithmVersion
return a
}
func (a *AeadAes256CbcHmac256Algorithm) Encrypt(cleartext []byte) ([]byte, error) {
buf := make([]byte, 0)
var iv []byte
if a.deterministic {
iv = crypto.Sha256Hmac(cleartext, a.cek.IvKey())
if len(iv) > a.blockSizeBytes {
iv = iv[:a.blockSizeBytes]
}
} else {
iv = make([]byte, a.blockSizeBytes)
_, err := rand.Read(iv)
if err != nil {
panic(err)
}
}
buf = append(buf, a.algorithmVersion)
aescdbc := crypto.NewAESCbcPKCS5(a.cek.EncryptionKey(), iv)
ciphertext := aescdbc.Encrypt(cleartext)
authTag := a.prepareAuthTag(iv, ciphertext)
buf = append(buf, authTag...)
buf = append(buf, iv...)
buf = append(buf, ciphertext...)
return buf, nil
}
func (a *AeadAes256CbcHmac256Algorithm) Decrypt(ciphertext []byte) ([]byte, error) {
// This algorithm always has the auth tag!
minimumCiphertextLength := a.minimumCipherTextLengthBytesWithAuthTag
if len(ciphertext) < minimumCiphertextLength {
return nil, fmt.Errorf("invalid ciphertext length: at least %v bytes expected", minimumCiphertextLength)
}
idx := 0
if ciphertext[idx] != a.algorithmVersion {
return nil, fmt.Errorf("invalid algorithm version used: %v found but %v expected", ciphertext[idx],
a.algorithmVersion)
}
idx++
authTag := ciphertext[idx : idx+a.keySizeBytes]
idx += a.keySizeBytes
iv := ciphertext[idx : idx+a.blockSizeBytes]
idx += len(iv)
realCiphertext := ciphertext[idx:]
ourAuthTag := a.prepareAuthTag(iv, realCiphertext)
// bytes.Compare is subject to timing attacks
if subtle.ConstantTimeCompare(ourAuthTag, authTag) != 1 {
return nil, fmt.Errorf("invalid auth tag")
}
// decrypt
aescdbc := crypto.NewAESCbcPKCS5(a.cek.EncryptionKey(), iv)
cleartext := aescdbc.Decrypt(realCiphertext)
return cleartext, nil
}
func (a *AeadAes256CbcHmac256Algorithm) prepareAuthTag(iv []byte, ciphertext []byte) []byte {
var input = make([]byte, 0)
input = append(input, a.algorithmVersion)
input = append(input, iv...)
input = append(input, ciphertext...)
input = append(input, a.versionSize...)
return crypto.Sha256Hmac(input, a.cek.MacKey())
}

View file

@ -0,0 +1,6 @@
package algorithms
type Algorithm interface {
Encrypt([]byte) ([]byte, error)
Decrypt([]byte) ([]byte, error)
}

View file

@ -0,0 +1,69 @@
package crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"fmt"
)
// Inspired by: https://gist.github.com/hothero/7d085573f5cb7cdb5801d7adcf66dcf3
type AESCbcPKCS5 struct {
key []byte
iv []byte
block cipher.Block
}
func NewAESCbcPKCS5(key []byte, iv []byte) AESCbcPKCS5 {
a := AESCbcPKCS5{
key: key,
iv: iv,
block: nil,
}
a.initCipher()
return a
}
func (a AESCbcPKCS5) Encrypt(cleartext []byte) (cipherText []byte) {
if a.block == nil {
a.initCipher()
}
blockMode := cipher.NewCBCEncrypter(a.block, a.iv)
paddedCleartext := PKCS5Padding(cleartext, blockMode.BlockSize())
cipherText = make([]byte, len(paddedCleartext))
blockMode.CryptBlocks(cipherText, paddedCleartext)
return
}
func (a AESCbcPKCS5) Decrypt(ciphertext []byte) []byte {
if a.block == nil {
a.initCipher()
}
blockMode := cipher.NewCBCDecrypter(a.block, a.iv)
var cleartext = make([]byte, len(ciphertext))
blockMode.CryptBlocks(cleartext, ciphertext)
return PKCS5Trim(cleartext)
}
func PKCS5Padding(inArr []byte, blockSize int) []byte {
padding := blockSize - len(inArr)%blockSize
padText := bytes.Repeat([]byte{byte(padding)}, padding)
return append(inArr, padText...)
}
func PKCS5Trim(inArr []byte) []byte {
padding := inArr[len(inArr)-1]
return inArr[:len(inArr)-int(padding)]
}
func (a *AESCbcPKCS5) initCipher() {
block, err := aes.NewCipher(a.key)
if err != nil {
panic(fmt.Errorf("unable to create cipher: %v", err))
}
a.block = block
}

View file

@ -0,0 +1,12 @@
package crypto
import (
"crypto/hmac"
"crypto/sha256"
)
func Sha256Hmac(input []byte, key []byte) []byte {
sha256Hmac := hmac.New(sha256.New, key)
sha256Hmac.Write(input)
return sha256Hmac.Sum(nil)
}

View file

@ -0,0 +1,37 @@
package encryption
type Type struct {
Deterministic bool
Name string
Value byte
}
var Plaintext = Type{
Deterministic: false,
Name: "Plaintext",
Value: 0,
}
var Deterministic = Type{
Deterministic: true,
Name: "Deterministic",
Value: 1,
}
var Randomized = Type{
Deterministic: false,
Name: "Randomized",
Value: 2,
}
func From(encType byte) Type {
switch encType {
case 0:
return Plaintext
case 1:
return Deterministic
case 2:
return Randomized
}
return Plaintext
}

View file

@ -0,0 +1,51 @@
package keys
import (
"fmt"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/crypto"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/utils"
)
var _ Key = &AeadAes256CbcHmac256{}
type AeadAes256CbcHmac256 struct {
rootKey []byte
encryptionKey []byte
macKey []byte
ivKey []byte
}
func NewAeadAes256CbcHmac256(rootKey []byte) AeadAes256CbcHmac256 {
const keySize = 256
const encryptionKeySaltFormat = "Microsoft SQL Server cell encryption key with encryption algorithm:%v and key length:%v"
const macKeySaltFormat = "Microsoft SQL Server cell MAC key with encryption algorithm:%v and key length:%v"
const ivKeySaltFormat = "Microsoft SQL Server cell IV key with encryption algorithm:%v and key length:%v"
const algorithmName = "AEAD_AES_256_CBC_HMAC_SHA256"
encryptionKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(encryptionKeySaltFormat, algorithmName, keySize))
macKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(macKeySaltFormat, algorithmName, keySize))
ivKeySalt := utils.ProcessUTF16LE(fmt.Sprintf(ivKeySaltFormat, algorithmName, keySize))
return AeadAes256CbcHmac256{
rootKey: rootKey,
encryptionKey: crypto.Sha256Hmac(encryptionKeySalt, rootKey),
macKey: crypto.Sha256Hmac(macKeySalt, rootKey),
ivKey: crypto.Sha256Hmac(ivKeySalt, rootKey)}
}
func (a AeadAes256CbcHmac256) IvKey() []byte {
return a.ivKey
}
func (a AeadAes256CbcHmac256) MacKey() []byte {
return a.macKey
}
func (a AeadAes256CbcHmac256) EncryptionKey() []byte {
return a.encryptionKey
}
func (a AeadAes256CbcHmac256) RootKey() []byte {
return a.rootKey
}

View file

@ -0,0 +1,5 @@
package keys
type Key interface {
RootKey() []byte
}

View file

@ -0,0 +1,18 @@
package utils
import (
"encoding/binary"
"unicode/utf16"
)
func ConvertUTF16ToLittleEndianBytes(u []uint16) []byte {
b := make([]byte, 2*len(u))
for index, value := range u {
binary.LittleEndian.PutUint16(b[index*2:], value)
}
return b
}
func ProcessUTF16LE(inputString string) []byte {
return ConvertUTF16ToLittleEndianBytes(utf16.Encode([]rune(inputString)))
}

View file

@ -21,10 +21,17 @@ type (
BrowserMsg byte
)
const (
DsnTypeURL = 1
DsnTypeOdbc = 2
DsnTypeAdo = 3
)
const (
EncryptionOff = 0
EncryptionRequired = 1
EncryptionDisabled = 3
EncryptionStrict = 4
)
const (
@ -44,6 +51,34 @@ const (
BrowserDAC BrowserMsg = 0x0f
)
const (
Database = "database"
Encrypt = "encrypt"
Password = "password"
ChangePassword = "change password"
UserID = "user id"
Port = "port"
TrustServerCertificate = "trustservercertificate"
Certificate = "certificate"
TLSMin = "tlsmin"
PacketSize = "packet size"
LogParam = "log"
ConnectionTimeout = "connection timeout"
HostNameInCertificate = "hostnameincertificate"
KeepAlive = "keepalive"
ServerSpn = "serverspn"
WorkstationID = "workstation id"
AppName = "app name"
ApplicationIntent = "applicationintent"
FailoverPartner = "failoverpartner"
FailOverPort = "failoverport"
DisableRetry = "disableretry"
Server = "server"
Protocol = "protocol"
DialTimeout = "dial timeout"
Pipe = "pipe"
)
type Config struct {
Port uint64
Host string
@ -88,6 +123,10 @@ type Config struct {
ProtocolParameters map[string]interface{}
// BrowserMsg is the message identifier to fetch instance data from SQL browser
BrowserMessage BrowserMsg
// ChangePassword is used to set the login's password during login. Ignored for non-SQL authentication.
ChangePassword string
//ColumnEncryption is true if the application needs to decrypt or encrypt Always Encrypted values
ColumnEncryption bool
}
// Build a tls.Config object from the supplied certificate.
@ -128,24 +167,26 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e
trustServerCert := false
var encryption Encryption = EncryptionOff
encrypt, ok := params["encrypt"]
encrypt, ok := params[Encrypt]
if ok {
if strings.EqualFold(encrypt, "DISABLE") {
encrypt = strings.ToLower(encrypt)
switch encrypt {
case "mandatory", "yes", "1", "t", "true":
encryption = EncryptionRequired
case "disable":
encryption = EncryptionDisabled
} else {
e, err := strconv.ParseBool(encrypt)
if err != nil {
f := "invalid encrypt '%s': %s"
return encryption, nil, fmt.Errorf(f, encrypt, err.Error())
}
if e {
encryption = EncryptionRequired
}
case "strict":
encryption = EncryptionStrict
case "optional", "no", "0", "f", "false":
encryption = EncryptionOff
default:
f := "invalid encrypt '%s'"
return encryption, nil, fmt.Errorf(f, encrypt)
}
} else {
trustServerCert = true
}
trust, ok := params["trustservercertificate"]
trust, ok := params[TrustServerCertificate]
if ok {
var err error
trustServerCert, err = strconv.ParseBool(trust)
@ -154,9 +195,12 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e
return encryption, nil, fmt.Errorf(f, trust, err.Error())
}
}
certificate := params["certificate"]
certificate := params[Certificate]
if encryption != EncryptionDisabled {
tlsMin := params["tlsmin"]
tlsMin := params[TLSMin]
if encrypt == "strict" {
trustServerCert = false
}
tlsConfig, err := SetupTLS(certificate, trustServerCert, host, tlsMin)
if err != nil {
return encryption, nil, fmt.Errorf("failed to setup TLS: %w", err)
@ -168,6 +212,38 @@ func parseTLS(params map[string]string, host string) (Encryption, *tls.Config, e
var skipSetup = errors.New("skip setting up TLS")
func getDsnType(dsn string) int {
if strings.HasPrefix(dsn, "sqlserver://") {
return DsnTypeURL
}
if strings.HasPrefix(dsn, "odbc:") {
return DsnTypeOdbc
}
return DsnTypeAdo
}
func getDsnParams(dsn string) (map[string]string, error) {
var params map[string]string
var err error
switch getDsnType(dsn) {
case DsnTypeOdbc:
params, err = splitConnectionStringOdbc(dsn[len("odbc:"):])
if err != nil {
return params, err
}
case DsnTypeURL:
params, err = splitConnectionStringURL(dsn)
if err != nil {
return params, err
}
default:
params = splitConnectionString(dsn)
}
return params, nil
}
func Parse(dsn string) (Config, error) {
p := Config{
ProtocolParameters: map[string]interface{}{},
@ -176,23 +252,14 @@ func Parse(dsn string) (Config, error) {
var params map[string]string
var err error
if strings.HasPrefix(dsn, "odbc:") {
params, err = splitConnectionStringOdbc(dsn[len("odbc:"):])
if err != nil {
return p, err
}
} else if strings.HasPrefix(dsn, "sqlserver://") {
params, err = splitConnectionStringURL(dsn)
if err != nil {
return p, err
}
} else {
params = splitConnectionString(dsn)
}
params, err = getDsnParams(dsn)
if err != nil {
return p, err
}
p.Parameters = params
strlog, ok := params["log"]
strlog, ok := params[LogParam]
if ok {
flags, err := strconv.ParseUint(strlog, 10, 64)
if err != nil {
@ -201,12 +268,12 @@ func Parse(dsn string) (Config, error) {
p.LogFlags = Log(flags)
}
p.Database = params["database"]
p.User = params["user id"]
p.Password = params["password"]
p.Database = params[Database]
p.User = params[UserID]
p.Password = params[Password]
p.ChangePassword = params[ChangePassword]
p.Port = 0
strport, ok := params["port"]
strport, ok := params[Port]
if ok {
var err error
p.Port, err = strconv.ParseUint(strport, 10, 16)
@ -217,7 +284,7 @@ func Parse(dsn string) (Config, error) {
}
// https://docs.microsoft.com/en-us/sql/database-engine/configure-windows/configure-the-network-packet-size-server-configuration-option\
strpsize, ok := params["packet size"]
strpsize, ok := params[PacketSize]
if ok {
var err error
psize, err := strconv.ParseUint(strpsize, 0, 16)
@ -242,7 +309,7 @@ func Parse(dsn string) (Config, error) {
//
// Do not set a connection timeout. Use Context to manage such things.
// Default to zero, but still allow it to be set.
if strconntimeout, ok := params["connection timeout"]; ok {
if strconntimeout, ok := params[ConnectionTimeout]; ok {
timeout, err := strconv.ParseUint(strconntimeout, 10, 64)
if err != nil {
f := "invalid connection timeout '%v': %v"
@ -254,7 +321,7 @@ func Parse(dsn string) (Config, error) {
// default keep alive should be 30 seconds according to spec:
// https://msdn.microsoft.com/en-us/library/dd341108.aspx
p.KeepAlive = 30 * time.Second
if keepAlive, ok := params["keepalive"]; ok {
if keepAlive, ok := params[KeepAlive]; ok {
timeout, err := strconv.ParseUint(keepAlive, 10, 64)
if err != nil {
f := "invalid keepAlive value '%s': %s"
@ -263,12 +330,12 @@ func Parse(dsn string) (Config, error) {
p.KeepAlive = time.Duration(timeout) * time.Second
}
serverSPN, ok := params["serverspn"]
serverSPN, ok := params[ServerSpn]
if ok {
p.ServerSPN = serverSPN
} // If not set by the app, ServerSPN will be set by the successful dialer.
workstation, ok := params["workstation id"]
workstation, ok := params[WorkstationID]
if ok {
p.Workstation = workstation
} else {
@ -278,13 +345,13 @@ func Parse(dsn string) (Config, error) {
}
}
appname, ok := params["app name"]
appname, ok := params[AppName]
if !ok {
appname = "go-mssqldb"
}
p.AppName = appname
appintent, ok := params["applicationintent"]
appintent, ok := params[ApplicationIntent]
if ok {
if appintent == "ReadOnly" {
if p.Database == "" {
@ -294,12 +361,12 @@ func Parse(dsn string) (Config, error) {
}
}
failOverPartner, ok := params["failoverpartner"]
failOverPartner, ok := params[FailoverPartner]
if ok {
p.FailOverPartner = failOverPartner
}
failOverPort, ok := params["failoverport"]
failOverPort, ok := params[FailOverPort]
if ok {
var err error
p.FailOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
@ -309,7 +376,7 @@ func Parse(dsn string) (Config, error) {
}
}
disableRetry, ok := params["disableretry"]
disableRetry, ok := params[DisableRetry]
if ok {
var err error
p.DisableRetry, err = strconv.ParseBool(disableRetry)
@ -321,8 +388,8 @@ func Parse(dsn string) (Config, error) {
p.DisableRetry = disableRetryDefault
}
server := params["server"]
protocol, ok := params["protocol"]
server := params[Server]
protocol, ok := params[Protocol]
for _, parser := range ProtocolParsers {
if (!ok && !parser.Hidden()) || parser.Protocol() == protocol {
@ -348,7 +415,7 @@ func Parse(dsn string) (Config, error) {
f = 1
}
p.DialTimeout = time.Duration(15*f) * time.Second
if strdialtimeout, ok := params["dial timeout"]; ok {
if strdialtimeout, ok := params[DialTimeout]; ok {
timeout, err := strconv.ParseUint(strdialtimeout, 10, 64)
if err != nil {
f := "invalid dial timeout '%v': %v"
@ -358,7 +425,7 @@ func Parse(dsn string) (Config, error) {
p.DialTimeout = time.Duration(timeout) * time.Second
}
hostInCertificate, ok := params["hostnameincertificate"]
hostInCertificate, ok := params[HostNameInCertificate]
if ok {
p.HostInCertificateProvided = true
} else {
@ -371,6 +438,19 @@ func Parse(dsn string) (Config, error) {
return p, err
}
if c, ok := params["columnencryption"]; ok {
columnEncryption, err := strconv.ParseBool(c)
if err != nil {
if strings.EqualFold(c, "Enabled") {
columnEncryption = true
} else if strings.EqualFold(c, "Disabled") {
columnEncryption = false
} else {
return p, fmt.Errorf("invalid columnencryption '%v' : %v", columnEncryption, err.Error())
}
}
p.ColumnEncryption = columnEncryption
}
return p, nil
}
@ -379,10 +459,10 @@ func Parse(dsn string) (Config, error) {
func (p Config) URL() *url.URL {
q := url.Values{}
if p.Database != "" {
q.Add("database", p.Database)
q.Add(Database, p.Database)
}
if p.LogFlags != 0 {
q.Add("log", strconv.FormatUint(uint64(p.LogFlags), 10))
q.Add(LogParam, strconv.FormatUint(uint64(p.LogFlags), 10))
}
host := p.Host
protocol := ""
@ -397,8 +477,8 @@ func (p Config) URL() *url.URL {
if p.Port > 0 {
host = fmt.Sprintf("%s:%d", host, p.Port)
}
q.Add("disableRetry", fmt.Sprintf("%t", p.DisableRetry))
protocolParam, ok := p.Parameters["protocol"]
q.Add(DisableRetry, fmt.Sprintf("%t", p.DisableRetry))
protocolParam, ok := p.Parameters[Protocol]
if ok {
if protocol != "" && protocolParam != protocol {
panic("Mismatched protocol parameters!")
@ -406,11 +486,11 @@ func (p Config) URL() *url.URL {
protocol = protocolParam
}
if protocol != "" {
q.Add("protocol", protocol)
q.Add(Protocol, protocol)
}
pipe, ok := p.Parameters["pipe"]
pipe, ok := p.Parameters[Pipe]
if ok {
q.Add("pipe", pipe)
q.Add(Pipe, pipe)
}
res := url.URL{
Scheme: "sqlserver",
@ -420,7 +500,17 @@ func (p Config) URL() *url.URL {
if p.Instance != "" {
res.Path = p.Instance
}
q.Add("dial timeout", strconv.FormatFloat(float64(p.DialTimeout.Seconds()), 'f', 0, 64))
q.Add(DialTimeout, strconv.FormatFloat(float64(p.DialTimeout.Seconds()), 'f', 0, 64))
switch p.Encryption {
case EncryptionDisabled:
q.Add(Encrypt, "DISABLE")
case EncryptionRequired:
q.Add(Encrypt, "true")
}
if p.ColumnEncryption {
q.Add("columnencryption", "true")
}
if len(q) > 0 {
res.RawQuery = q.Encode()
}
@ -428,15 +518,17 @@ func (p Config) URL() *url.URL {
return &res
}
// ADO connection string keywords at https://github.com/dotnet/SqlClient/blob/main/src/Microsoft.Data.SqlClient/src/Microsoft/Data/Common/DbConnectionStringCommon.cs
var adoSynonyms = map[string]string{
"application name": "app name",
"data source": "server",
"address": "server",
"network address": "server",
"addr": "server",
"user": "user id",
"uid": "user id",
"initial catalog": "database",
"application name": AppName,
"data source": Server,
"address": Server,
"network address": Server,
"addr": Server,
"user": UserID,
"uid": UserID,
"initial catalog": Database,
"column encryption setting": "columnencryption",
}
func splitConnectionString(dsn string) (res map[string]string) {
@ -460,18 +552,18 @@ func splitConnectionString(dsn string) (res map[string]string) {
name = synonym
}
// "server" in ADO can include a protocol and a port.
if name == "server" {
if name == Server {
for _, parser := range ProtocolParsers {
prot := parser.Protocol() + ":"
if strings.HasPrefix(value, prot) {
res["protocol"] = parser.Protocol()
res[Protocol] = parser.Protocol()
}
value = strings.TrimPrefix(value, prot)
}
serverParts := strings.Split(value, ",")
if len(serverParts) == 2 && len(serverParts[1]) > 0 {
value = serverParts[0]
res["port"] = serverParts[1]
res[Port] = serverParts[1]
}
}
res[name] = value
@ -493,10 +585,10 @@ func splitConnectionStringURL(dsn string) (map[string]string, error) {
}
if u.User != nil {
res["user id"] = u.User.Username()
res[UserID] = u.User.Username()
p, exists := u.User.Password()
if exists {
res["password"] = p
res[Password] = p
}
}
@ -506,13 +598,13 @@ func splitConnectionStringURL(dsn string) (map[string]string, error) {
}
if len(u.Path) > 0 {
res["server"] = host + "\\" + u.Path[1:]
res[Server] = host + "\\" + u.Path[1:]
} else {
res["server"] = host
res[Server] = host
}
if len(port) > 0 {
res["port"] = port
res[Port] = port
}
query := u.Query()

View file

@ -17,6 +17,7 @@ import (
"unicode"
"github.com/golang-sql/sqlexp"
"github.com/microsoft/go-mssqldb/aecmk"
"github.com/microsoft/go-mssqldb/internal/querytext"
"github.com/microsoft/go-mssqldb/msdsn"
)
@ -69,10 +70,7 @@ func (d *Driver) OpenConnector(dsn string) (*Connector, error) {
return nil, err
}
return &Connector{
params: params,
driver: d,
}, nil
return newConnector(params, d), nil
}
func (d *Driver) Open(dsn string) (driver.Conn, error) {
@ -122,10 +120,8 @@ func NewConnector(dsn string) (*Connector, error) {
if err != nil {
return nil, err
}
c := &Connector{
params: params,
driver: driverInstanceNoProcess,
}
c := newConnector(params, driverInstanceNoProcess)
return c, nil
}
@ -146,9 +142,14 @@ func NewConnectorWithAccessTokenProvider(dsn string, tokenProvider func(ctx cont
// NewConnectorConfig creates a new Connector for a DSN Config struct.
// The returned connector may be used with sql.OpenDB.
func NewConnectorConfig(config msdsn.Config) *Connector {
return newConnector(config, driverInstanceNoProcess)
}
func newConnector(config msdsn.Config, driver *Driver) *Connector {
return &Connector{
params: config,
driver: driverInstanceNoProcess,
params: config,
driver: driver,
keyProviders: make(aecmk.ColumnEncryptionKeyProviderMap),
}
}
@ -199,6 +200,8 @@ type Connector struct {
//
// If Dialer is not set, normal net dialers are used.
Dialer Dialer
keyProviders aecmk.ColumnEncryptionKeyProviderMap
}
type Dialer interface {
@ -219,6 +222,11 @@ func (c *Connector) getDialer(p *msdsn.Config) Dialer {
return createDialer(p)
}
// RegisterCekProvider associates the given provider with the named key store. If an entry of the given name already exists, that entry is overwritten
func (c *Connector) RegisterCekProvider(name string, provider aecmk.ColumnEncryptionKeyProvider) {
c.keyProviders[name] = aecmk.NewCekProvider(provider)
}
type Conn struct {
connector *Connector
sess *tdsSession
@ -403,7 +411,7 @@ func (d *Driver) open(ctx context.Context, dsn string) (*Conn, error) {
if err != nil {
return nil, err
}
c := &Connector{params: params}
c := newConnector(params, nil)
return d.connect(ctx, c, params)
}
@ -445,10 +453,11 @@ func (c *Conn) Close() error {
}
type Stmt struct {
c *Conn
query string
paramCount int
notifSub *queryNotifSub
c *Conn
query string
paramCount int
notifSub *queryNotifSub
skipEncryption bool
}
type queryNotifSub struct {
@ -472,7 +481,7 @@ func (c *Conn) prepareContext(ctx context.Context, query string) (*Stmt, error)
if c.processQueryText {
query, paramCount = querytext.ParseParams(query)
}
return &Stmt{c, query, paramCount, nil}, nil
return &Stmt{c, query, paramCount, nil, false}, nil
}
func (s *Stmt) Close() error {
@ -654,16 +663,38 @@ func (s *Stmt) makeRPCParams(args []namedValue, isProc bool) ([]param, []string,
if isOutputValue(val.Value) {
output = outputSuffix
}
decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(params[i+offset].ti), output)
tiDecl := params[i+offset].ti
if val.encrypt != nil {
// Encrypted parameters have a few requirements:
// 1. Copy original typeinfo to a block after the data
// 2. Set the parameter type to varbinary(max)
// 3. Append the crypto metadata bytes
params[i+offset].tiOriginal = params[i+offset].ti
params[i+offset].Flags |= fEncrypted
encryptedBytes, metadata, err := val.encrypt(params[i+offset].buffer)
if err != nil {
return nil, nil, err
}
params[i+offset].cipherInfo = metadata
params[i+offset].ti.TypeId = typeBigVarBin
params[i+offset].buffer = encryptedBytes
params[i+offset].ti.Size = 0
}
decls[i] = fmt.Sprintf("%s %s%s", name, makeDecl(tiDecl), output)
}
return params, decls, nil
}
// Encrypts the input bytes. Returns the encrypted bytes followed by the encryption metadata to append to the packet.
type valueEncryptor func(bytes []byte) ([]byte, []byte, error)
type namedValue struct {
Name string
Ordinal int
Value driver.Value
encrypt valueEncryptor
}
func convertOldArgs(args []driver.Value) []namedValue {
@ -677,6 +708,10 @@ func convertOldArgs(args []driver.Value) []namedValue {
return list
}
func (s *Stmt) doEncryption() bool {
return !s.skipEncryption && s.c.sess.alwaysEncrypted
}
func (s *Stmt) Query(args []driver.Value) (driver.Rows, error) {
defer s.c.clearOuts()
@ -687,6 +722,12 @@ func (s *Stmt) queryContext(ctx context.Context, args []namedValue) (rows driver
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
if s.doEncryption() && len(args) > 0 {
args, err = s.encryptArgs(ctx, args)
}
if err != nil {
return nil, err
}
if err = s.sendQuery(ctx, args); err != nil {
return nil, s.c.checkBadConn(ctx, err, true)
}
@ -754,6 +795,12 @@ func (s *Stmt) exec(ctx context.Context, args []namedValue) (res driver.Result,
if !s.c.connectionGood {
return nil, driver.ErrBadConn
}
if s.doEncryption() && len(args) > 0 {
args, err = s.encryptArgs(ctx, args)
}
if err != nil {
return nil, err
}
if err = s.sendQuery(ctx, args); err != nil {
return nil, s.c.checkBadConn(ctx, err, true)
}
@ -872,7 +919,7 @@ func (rc *Rows) NextResultSet() error {
// the value type that can be used to scan types into. For example, the database
// column type "bigint" this should return "reflect.TypeOf(int64(0))".
func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
return makeGoLangScanType(r.cols[index].ti)
return makeGoLangScanType(r.cols[index].originalTypeInfo())
}
// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
@ -881,7 +928,7 @@ func (r *Rows) ColumnTypeScanType(index int) reflect.Type {
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
// "TIMESTAMP".
func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
return makeGoLangTypeName(r.cols[index].ti)
return makeGoLangTypeName(r.cols[index].originalTypeInfo())
}
// RowsColumnTypeLength may be implemented by Rows. It should return the length
@ -897,7 +944,7 @@ func (r *Rows) ColumnTypeDatabaseTypeName(index int) string {
// int (0, false)
// bytea(30) (30, true)
func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
return makeGoLangTypeLength(r.cols[index].ti)
return makeGoLangTypeLength(r.cols[index].originalTypeInfo())
}
// It should return
@ -908,7 +955,7 @@ func (r *Rows) ColumnTypeLength(index int) (int64, bool) {
// int (0, 0, false)
// decimal (math.MaxInt64, math.MaxInt64, true)
func (r *Rows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
return makeGoLangTypePrecisionScale(r.cols[index].ti)
return makeGoLangTypePrecisionScale(r.cols[index].originalTypeInfo())
}
// The nullable value should
@ -974,12 +1021,20 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) {
res.ti.TypeId = typeIntN
res.ti.Size = 8
res.buffer = []byte{}
case byte:
res.ti.TypeId = typeIntN
res.buffer = []byte{val}
res.ti.Size = 1
case float64:
res.ti.TypeId = typeFltN
res.ti.Size = 8
res.buffer = make([]byte, 8)
binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val))
case float32:
res.ti.TypeId = typeFltN
res.ti.Size = 4
res.buffer = make([]byte, 4)
binary.LittleEndian.PutUint32(res.buffer, math.Float32bits(val))
case sql.NullFloat64:
// only null values should be getting here
res.ti.TypeId = typeFltN
@ -1043,7 +1098,7 @@ func (c *Conn) Ping(ctx context.Context) error {
if !c.connectionGood {
return driver.ErrBadConn
}
stmt := &Stmt{c, `select 1;`, 0, nil}
stmt := &Stmt{c, `select 1;`, 0, nil, true}
_, err := stmt.ExecContext(ctx, nil)
return err
}
@ -1108,7 +1163,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
}
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
list[i] = namedValueFromDriverNamedValue(nv)
}
return s.queryContext(ctx, list)
}
@ -1121,11 +1176,15 @@ func (s *Stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
}
list := make([]namedValue, len(args))
for i, nv := range args {
list[i] = namedValue(nv)
list[i] = namedValueFromDriverNamedValue(nv)
}
return s.exec(ctx, list)
}
func namedValueFromDriverNamedValue(v driver.NamedValue) namedValue {
return namedValue{Name: v.Name, Ordinal: v.Ordinal, Value: v.Value, encrypt: nil}
}
// Rowsq implements the sqlexp messages model for Query and QueryContext
// Theory: We could also implement the non-experimental model this way
type Rowsq struct {
@ -1316,7 +1375,7 @@ scan:
// the value type that can be used to scan types into. For example, the database
// column type "bigint" this should return "reflect.TypeOf(int64(0))".
func (r *Rowsq) ColumnTypeScanType(index int) reflect.Type {
return makeGoLangScanType(r.cols[index].ti)
return makeGoLangScanType(r.cols[index].originalTypeInfo())
}
// RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
@ -1325,7 +1384,7 @@ func (r *Rowsq) ColumnTypeScanType(index int) reflect.Type {
// "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
// "TIMESTAMP".
func (r *Rowsq) ColumnTypeDatabaseTypeName(index int) string {
return makeGoLangTypeName(r.cols[index].ti)
return makeGoLangTypeName(r.cols[index].originalTypeInfo())
}
// RowsColumnTypeLength may be implemented by Rows. It should return the length
@ -1341,7 +1400,7 @@ func (r *Rowsq) ColumnTypeDatabaseTypeName(index int) string {
// int (0, false)
// bytea(30) (30, true)
func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) {
return makeGoLangTypeLength(r.cols[index].ti)
return makeGoLangTypeLength(r.cols[index].originalTypeInfo())
}
// It should return
@ -1352,7 +1411,7 @@ func (r *Rowsq) ColumnTypeLength(index int) (int64, bool) {
// int (0, 0, false)
// decimal (math.MaxInt64, math.MaxInt64, true)
func (r *Rowsq) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
return makeGoLangTypePrecisionScale(r.cols[index].ti)
return makeGoLangTypePrecisionScale(r.cols[index].originalTypeInfo())
}
// The nullable value should

View file

@ -29,12 +29,18 @@ type MssqlStmt = Stmt // Deprecated: users should transition to th
var _ driver.NamedValueChecker = &Conn{}
// VarChar parameter types.
// VarChar is used to encode a string parameter as VarChar instead of a sized NVarChar
type VarChar string
// NVarCharMax is used to encode a string parameter as NVarChar(max) instead of a sized NVarChar
type NVarCharMax string
// VarCharMax is used to encode a string parameter as VarChar(max) instead of a sized NVarChar
type VarCharMax string
// NChar is used to encode a string parameter as NChar instead of a sized NVarChar
type NChar string
// DateTime1 encodes parameters to original DateTime SQL types.
type DateTime1 time.Time
@ -45,12 +51,16 @@ func convertInputParameter(val interface{}) (interface{}, error) {
switch v := val.(type) {
case int, int16, int32, int64, int8:
return val, nil
case byte:
return val, nil
case VarChar:
return val, nil
case NVarCharMax:
return val, nil
case VarCharMax:
return val, nil
case NChar:
return val, nil
case DateTime1:
return val, nil
case DateTimeOffset:
@ -61,8 +71,10 @@ func convertInputParameter(val interface{}) (interface{}, error) {
return val, nil
case civil.Time:
return val, nil
// case *apd.Decimal:
// return nil
// case *apd.Decimal:
// return nil
case float32:
return val, nil
default:
return driver.DefaultParameterConverter.ConvertValue(v)
}
@ -144,6 +156,10 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) {
res.ti.TypeId = typeNVarChar
res.buffer = str2ucs2(string(val))
res.ti.Size = 0 // currently zero forces nvarchar(max)
case NChar:
res.ti.TypeId = typeNChar
res.buffer = str2ucs2(string(val))
res.ti.Size = len(res.buffer)
case DateTime1:
t := time.Time(val)
res.ti.TypeId = typeDateTimeN

40
vendor/github.com/microsoft/go-mssqldb/quoter.go generated vendored Normal file
View file

@ -0,0 +1,40 @@
package mssql
import (
"strings"
)
// TSQLQuoter implements sqlexp.Quoter
type TSQLQuoter struct {
}
// ID quotes identifiers such as schema, table, or column names.
// This implementation handles multi-part names.
func (TSQLQuoter) ID(name string) string {
return "[" + strings.Replace(name, "]", "]]", -1) + "]"
}
// Value quotes database values such as string or []byte types as strings
// that are suitable and safe to embed in SQL text. The returned value
// of a string will include all surrounding quotes.
//
// If a value type is not supported it must panic.
func (TSQLQuoter) Value(v interface{}) string {
switch v := v.(type) {
default:
panic("unsupported value")
case string:
return sqlString(v)
case VarChar:
return sqlString(string(v))
case VarCharMax:
return sqlString(string(v))
case NVarCharMax:
return sqlString(string(v))
}
}
func sqlString(v string) string {
return "'" + strings.Replace(string(v), "'", "''", -1) + "'"
}

View file

@ -13,13 +13,16 @@ type procId struct {
const (
fByRevValue = 1
fDefaultValue = 2
fEncrypted = 8
)
type param struct {
Name string
Flags uint8
ti typeInfo
buffer []byte
Name string
Flags uint8
ti typeInfo
buffer []byte
tiOriginal typeInfo
cipherInfo []byte
}
var (
@ -78,6 +81,15 @@ func sendRpc(buf *tdsBuffer, headers []headerStruct, proc procId, flags uint16,
if err != nil {
return
}
if (param.Flags & fEncrypted) == fEncrypted {
err = writeTypeInfo(buf, &param.tiOriginal)
if err != nil {
return
}
if _, err = buf.Write(param.cipherInfo); err != nil {
return
}
}
}
return buf.FinishPacket()
}

View file

@ -15,6 +15,7 @@ import (
"unicode/utf16"
"unicode/utf8"
"github.com/microsoft/go-mssqldb/aecmk"
"github.com/microsoft/go-mssqldb/integratedauth"
"github.com/microsoft/go-mssqldb/msdsn"
)
@ -102,6 +103,7 @@ const (
verTDS73 = verTDS73A
verTDS73B = 0x730B0003
verTDS74 = 0x74000004
verTDS80 = 0x08000000
)
// packet types
@ -143,6 +145,7 @@ const (
encryptOn = 1 // Encryption is available and on.
encryptNotSup = 2 // Encryption is not available.
encryptReq = 3 // Encryption is required.
encryptStrict = 4
)
const (
@ -157,16 +160,23 @@ const (
)
type tdsSession struct {
buf *tdsBuffer
loginAck loginAckStruct
database string
partner string
columns []columnStruct
tranid uint64
logFlags uint64
logger ContextLogger
routedServer string
routedPort uint16
buf *tdsBuffer
loginAck loginAckStruct
database string
partner string
columns []columnStruct
tranid uint64
logFlags uint64
logger ContextLogger
routedServer string
routedPort uint16
alwaysEncrypted bool
aeSettings *alwaysEncryptedSettings
}
type alwaysEncryptedSettings struct {
enclaveType string
keyProviders aecmk.ColumnEncryptionKeyProviderMap
}
const (
@ -178,10 +188,26 @@ const (
)
type columnStruct struct {
UserType uint32
Flags uint16
ColName string
ti typeInfo
UserType uint32
Flags uint16
ColName string
ti typeInfo
cryptoMeta *cryptoMetadata
}
func (c columnStruct) isEncrypted() bool {
return isEncryptedFlag(c.Flags)
}
func isEncryptedFlag(flags uint16) bool {
return colFlagEncrypted == (flags & colFlagEncrypted)
}
func (c columnStruct) originalTypeInfo() typeInfo {
if c.isEncrypted() {
return c.cryptoMeta.typeInfo
}
return c.ti
}
type keySlice []uint8
@ -577,7 +603,7 @@ func sendLogin(w *tdsBuffer, login *login) error {
language := str2ucs2(login.Language)
database := str2ucs2(login.Database)
atchdbfile := str2ucs2(login.AtchDBFile)
changepassword := str2ucs2(login.ChangePassword)
changepassword := manglePassword(login.ChangePassword)
featureExt := login.FeatureExt.toBytes()
hdr := loginHeader{
@ -638,6 +664,9 @@ func sendLogin(w *tdsBuffer, login *login) error {
offset += hdr.ExtensionLength // DWORD
featureExtOffset = uint32(offset)
}
if len(changepassword) > 0 {
hdr.OptionFlags3 |= fChangePassword
}
hdr.Length = uint32(offset) + uint32(featureExtLen)
var err error
@ -977,6 +1006,8 @@ func preparePreloginFields(p msdsn.Config, fe *featureExtFedAuth) map[uint8][]by
encrypt = encryptOn
case msdsn.EncryptionOff:
encrypt = encryptOff
case msdsn.EncryptionStrict:
encrypt = encryptStrict
}
v := getDriverVersion(driverVersion)
fields := map[uint8][]byte{
@ -1023,6 +1054,12 @@ func interpretPreloginResponse(p msdsn.Config, fe *featureExtFedAuth, fields map
}
func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger ContextLogger, auth integratedauth.IntegratedAuthenticator, fe *featureExtFedAuth, packetSize uint32) (l *login, err error) {
var TDSVersion uint32
if p.Encryption == msdsn.EncryptionStrict {
TDSVersion = verTDS80
} else {
TDSVersion = verTDS74
}
var typeFlags uint8
if p.ReadOnlyIntent {
typeFlags |= fReadOnlyIntent
@ -1035,17 +1072,21 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont
serverName = p.Host
}
l = &login{
TDSVersion: verTDS74,
PacketSize: packetSize,
Database: p.Database,
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
OptionFlags1: fUseDB | fSetLang,
HostName: p.Workstation,
ServerName: serverName,
AppName: p.AppName,
TypeFlags: typeFlags,
CtlIntName: "go-mssqldb",
ClientProgVer: getDriverVersion(driverVersion),
TDSVersion: TDSVersion,
PacketSize: packetSize,
Database: p.Database,
OptionFlags2: fODBC, // to get unlimited TEXTSIZE
OptionFlags1: fUseDB | fSetLang,
HostName: p.Workstation,
ServerName: serverName,
AppName: p.AppName,
TypeFlags: typeFlags,
CtlIntName: "go-mssqldb",
ClientProgVer: getDriverVersion(driverVersion),
ChangePassword: p.ChangePassword,
}
if p.ColumnEncryption {
_ = l.FeatureExt.Add(&featureExtColumnEncryption{})
}
switch {
case fe.FedAuthLibrary == FedAuthLibrarySecurityToken:
@ -1061,14 +1102,14 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont
return nil, err
}
l.FeatureExt.Add(fe)
_ = l.FeatureExt.Add(fe)
case fe.FedAuthLibrary == FedAuthLibraryADAL:
if uint64(p.LogFlags)&logDebug != 0 {
logger.Log(ctx, msdsn.LogDebug, "Starting federated authentication using ADAL")
}
l.FeatureExt.Add(fe)
_ = l.FeatureExt.Add(fe)
case auth != nil:
if uint64(p.LogFlags)&logDebug != 0 {
@ -1092,8 +1133,29 @@ func prepareLogin(ctx context.Context, c *Connector, p msdsn.Config, logger Cont
return l, nil
}
func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) {
func getTLSConn(conn *timeoutConn, p msdsn.Config, alpnSeq string) (tlsConn *tls.Conn, err error) {
var config *tls.Config
if pc := p.TLSConfig; pc != nil {
config = pc
}
if config == nil {
config, err = msdsn.SetupTLS("", false, p.Host, "")
if err != nil {
return nil, err
}
}
//Set ALPN Sequence
config.NextProtos = []string{alpnSeq}
tlsConn = tls.Client(conn.c, config)
err = tlsConn.Handshake()
if err != nil {
return nil, fmt.Errorf("TLS Handshake failed: %w", err)
}
return tlsConn, nil
}
func connect(ctx context.Context, c *Connector, logger ContextLogger, p msdsn.Config) (res *tdsSession, err error) {
isTransportEncrypted := false
// if instance is specified use instance resolution service
if len(p.Instance) > 0 && p.Port != 0 && uint64(p.LogFlags)&logDebug != 0 {
// both instance name and port specified
@ -1133,14 +1195,25 @@ initiate_connection:
}
toconn := newTimeoutConn(conn, p.ConnTimeout)
outbuf := newTdsBuffer(packetSize, toconn)
if p.Encryption == msdsn.EncryptionStrict {
outbuf.transport, err = getTLSConn(toconn, p, "tds/8.0")
if err != nil {
return nil, err
}
isTransportEncrypted = true
}
sess := tdsSession{
buf: outbuf,
logger: logger,
logFlags: uint64(p.LogFlags),
buf: outbuf,
logger: logger,
logFlags: uint64(p.LogFlags),
aeSettings: &alwaysEncryptedSettings{keyProviders: aecmk.GetGlobalCekProviders()},
}
for i, p := range c.keyProviders {
sess.aeSettings.keyProviders[i] = p
}
fedAuth := &featureExtFedAuth{
FedAuthLibrary: FedAuthLibraryReserved,
}
@ -1166,42 +1239,47 @@ initiate_connection:
return nil, err
}
if encrypt != encryptNotSup {
var config *tls.Config
if pc := p.TLSConfig; pc != nil {
config = pc
if config.DynamicRecordSizingDisabled == false {
config = config.Clone()
//We need not perform TLS handshake if the communication channel is already encrypted (encrypt=strict)
if !isTransportEncrypted {
if encrypt != encryptNotSup {
var config *tls.Config
if pc := p.TLSConfig; pc != nil {
config = pc
if !config.DynamicRecordSizingDisabled {
config = config.Clone()
// fix for https://github.com/microsoft/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
config.DynamicRecordSizingDisabled = true
// fix for https://github.com/microsoft/go-mssqldb/issues/166
// Go implementation of TLS payload size heuristic algorithm splits single TDS package to multiple TCP segments,
// while SQL Server seems to expect one TCP segment per encrypted TDS package.
// Setting DynamicRecordSizingDisabled to true disables that algorithm and uses 16384 bytes per TLS package
config.DynamicRecordSizingDisabled = true
}
}
}
if config == nil {
config, err = msdsn.SetupTLS("", false, p.Host, "")
if config == nil {
config, err = msdsn.SetupTLS("", false, p.Host, "")
if err != nil {
return nil, err
}
}
// setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream
handshakeConn := tlsHandshakeConn{buf: outbuf}
passthrough := passthroughConn{c: &handshakeConn}
tlsConn := tls.Client(&passthrough, config)
err = tlsConn.Handshake()
passthrough.c = toconn
outbuf.transport = tlsConn
if err != nil {
return nil, err
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
}
if encrypt == encryptOff {
outbuf.afterFirst = func() {
outbuf.transport = toconn
}
}
}
// setting up connection handler which will allow wrapping of TLS handshake packets inside TDS stream
handshakeConn := tlsHandshakeConn{buf: outbuf}
passthrough := passthroughConn{c: &handshakeConn}
tlsConn := tls.Client(&passthrough, config)
err = tlsConn.Handshake()
passthrough.c = toconn
outbuf.transport = tlsConn
if err != nil {
return nil, fmt.Errorf("TLS Handshake failed: %v", err)
}
if encrypt == encryptOff {
outbuf.afterFirst = func() {
outbuf.transport = toconn
}
}
}
auth, err := integratedauth.GetIntegratedAuthenticator(p)
@ -1288,6 +1366,18 @@ initiate_connection:
case loginAckStruct:
sess.loginAck = token
loginAck = true
case featureExtAck:
for _, v := range token {
switch v := v.(type) {
case colAckStruct:
if v.Version <= 2 && v.Version > 0 {
sess.alwaysEncrypted = true
if len(v.EnclaveType) > 0 {
sess.aeSettings.enclaveType = string(v.EnclaveType)
}
}
}
}
case doneStruct:
if token.isError() {
tokenErr := token.getError()
@ -1317,3 +1407,21 @@ initiate_connection:
}
return &sess, nil
}
type featureExtColumnEncryption struct {
}
func (f *featureExtColumnEncryption) featureID() byte {
return featExtCOLUMNENCRYPTION
}
func (f *featureExtColumnEncryption) toBytes() []byte {
/*
1 = The client supports column encryption without enclave computations.
2 = The client SHOULD<25> support column encryption when encrypted data require enclave computations.
3 = The client SHOULD<26> support column encryption when encrypted data require enclave computations
with the additional ability to cache column encryption keys that are to be sent to the enclave
and the ability to retry queries when the keys sent by the client do not match what is needed for the query to run.
*/
return []byte{0x01}
}

View file

@ -1,6 +1,7 @@
package mssql
import (
"bytes"
"context"
"encoding/binary"
"fmt"
@ -10,7 +11,11 @@ import (
"strconv"
"github.com/golang-sql/sqlexp"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption"
"github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys"
"github.com/microsoft/go-mssqldb/msdsn"
"golang.org/x/text/encoding/unicode"
)
//go:generate go run golang.org/x/tools/cmd/stringer -type token
@ -92,10 +97,15 @@ const (
fedAuthInfoSPN = 0x02
)
const (
cipherAlgCustom = 0x00
)
// COLMETADATA flags
// https://msdn.microsoft.com/en-us/library/dd357363.aspx
const (
colFlagNullable = 1
colFlagNullable = 1
colFlagEncrypted = 0x0800
// TODO implement more flags
)
@ -533,7 +543,14 @@ type fedAuthAckStruct struct {
Signature []byte
}
func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} {
type colAckStruct struct {
Version int
EnclaveType string
}
type featureExtAck map[byte]interface{}
func parseFeatureExtAck(r *tdsBuffer) featureExtAck {
ack := map[byte]interface{}{}
for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() {
@ -555,7 +572,21 @@ func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} {
length -= 32
}
ack[feature] = fedAuthAck
case featExtCOLUMNENCRYPTION:
colAck := colAckStruct{Version: int(r.byte())}
length--
if length > 0 {
// enclave type is sent as utf16 le
enclaveLength := r.byte() * 2
length--
enclaveBytes := make([]byte, enclaveLength)
r.ReadFull(enclaveBytes)
// if the enclave type is malformed we'll just ignore it
colAck.EnclaveType, _ = ucs22str(enclaveBytes)
length -= uint32(enclaveLength)
}
ack[feature] = colAck
}
// Skip unprocessed bytes
@ -568,34 +599,265 @@ func parseFeatureExtAck(r *tdsBuffer) map[byte]interface{} {
}
// http://msdn.microsoft.com/en-us/library/dd357363.aspx
func parseColMetadata72(r *tdsBuffer) (columns []columnStruct) {
func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) {
count := r.uint16()
if count == 0xffff {
// no metadata is sent
return nil
}
columns = make([]columnStruct, count)
var cekTable *cekTable
if s.alwaysEncrypted {
// column encryption key list
cekTable = readCekTable(r)
}
for i := range columns {
column := &columns[i]
column.UserType = r.uint32()
column.Flags = r.uint16()
baseTi := getBaseTypeInfo(r, true)
typeInfo := readTypeInfo(r, baseTi.TypeId, column.cryptoMeta)
typeInfo.UserType = baseTi.UserType
typeInfo.Flags = baseTi.Flags
typeInfo.TypeId = baseTi.TypeId
column.Flags = baseTi.Flags
column.UserType = baseTi.UserType
column.ti = typeInfo
if column.isEncrypted() && s.alwaysEncrypted {
// Read Crypto Metadata
cryptoMeta := parseCryptoMetadata(r, cekTable)
cryptoMeta.typeInfo.Flags = baseTi.Flags
column.cryptoMeta = &cryptoMeta
} else {
column.cryptoMeta = nil
}
// parsing TYPE_INFO structure
column.ti = readTypeInfo(r)
column.ColName = r.BVarChar()
}
return columns
}
// http://msdn.microsoft.com/en-us/library/dd357254.aspx
func parseRow(r *tdsBuffer, columns []columnStruct, row []interface{}) {
for i, column := range columns {
row[i] = column.ti.Reader(&column.ti, r)
func getBaseTypeInfo(r *tdsBuffer, parseFlags bool) typeInfo {
userType := r.uint32()
flags := uint16(0)
if parseFlags {
flags = r.uint16()
}
tId := r.byte()
return typeInfo{
UserType: userType,
Flags: flags,
TypeId: tId}
}
type cryptoMetadata struct {
entry *cekTableEntry
ordinal uint16
algorithmId byte
algorithmName *string
encType byte
normRuleVer byte
typeInfo typeInfo
}
func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable) cryptoMetadata {
ordinal := uint16(0)
if cekTable != nil {
ordinal = r.uint16()
}
typeInfo := getBaseTypeInfo(r, false)
ti := readTypeInfo(r, typeInfo.TypeId, nil)
ti.UserType = typeInfo.UserType
ti.Flags = typeInfo.Flags
ti.TypeId = typeInfo.TypeId
algorithmId := r.byte()
var algName *string = nil
if algorithmId == cipherAlgCustom {
// Read the name when a custom algorithm is used
nameLen := int(r.byte())
var algNameUtf16 = make([]byte, nameLen*2)
r.ReadFull(algNameUtf16)
algNameBytes, _ := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder().Bytes(algNameUtf16)
mAlgName := string(algNameBytes)
algName = &mAlgName
}
encType := r.byte()
normRuleVer := r.byte()
var entry *cekTableEntry = nil
if cekTable != nil {
if int(ordinal) > len(cekTable.entries)-1 {
panic(fmt.Errorf("invalid ordinal, cekTable only has %d entries", len(cekTable.entries)))
}
entry = &cekTable.entries[ordinal]
}
return cryptoMetadata{
entry: entry,
ordinal: ordinal,
algorithmId: algorithmId,
algorithmName: algName,
encType: encType,
normRuleVer: normRuleVer,
typeInfo: ti,
}
}
func readCekTable(r *tdsBuffer) *cekTable {
tableSize := r.uint16()
var cekTable *cekTable = nil
if tableSize != 0 {
mCekTable := newCekTable(tableSize)
for i := uint16(0); i < tableSize; i++ {
mCekTable.entries[i] = readCekTableEntry(r)
}
cekTable = &mCekTable
}
return cekTable
}
func readCekTableEntry(r *tdsBuffer) cekTableEntry {
databaseId := r.int32()
cekID := r.int32()
cekVersion := r.int32()
var cekMdVersion = make([]byte, 8)
_, err := r.Read(cekMdVersion)
if err != nil {
panic("unable to read cekMdVersion")
}
cekValueCount := uint(r.byte())
// not using ucs22str because we already know the data is utf16
enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM)
utf16dec := enc.NewDecoder()
cekValues := make([]encryptionKeyInfo, cekValueCount)
for i := uint(0); i < cekValueCount; i++ {
encryptedCekLength := r.uint16()
encryptedCek := make([]byte, encryptedCekLength)
r.ReadFull(encryptedCek)
keyStoreLength := r.byte()
keyStoreNameUtf16 := make([]byte, keyStoreLength*2)
r.ReadFull(keyStoreNameUtf16)
keyStoreName, _ := utf16dec.Bytes(keyStoreNameUtf16)
keyPathLength := r.uint16()
keyPathUtf16 := make([]byte, keyPathLength*2)
r.ReadFull(keyPathUtf16)
keyPath, _ := utf16dec.Bytes(keyPathUtf16)
algLength := r.byte()
algNameUtf16 := make([]byte, algLength*2)
r.ReadFull(algNameUtf16)
algName, _ := utf16dec.Bytes(algNameUtf16)
cekValues[i] = encryptionKeyInfo{
encryptedKey: encryptedCek,
databaseID: int(databaseId),
cekID: int(cekID),
cekVersion: int(cekVersion),
cekMdVersion: cekMdVersion,
keyPath: string(keyPath),
keyStoreName: string(keyStoreName),
algorithmName: string(algName),
}
}
return cekTableEntry{
databaseID: int(databaseId),
keyId: int(cekID),
keyVersion: int(cekVersion),
mdVersion: cekMdVersion,
valueCount: int(cekValueCount),
cekValues: cekValues,
}
}
// http://msdn.microsoft.com/en-us/library/dd357254.aspx
func parseRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) {
for i, column := range columns {
columnContent := column.ti.Reader(&column.ti, r, nil)
if columnContent == nil {
row[i] = columnContent
continue
}
if column.isEncrypted() {
buffer := decryptColumn(column, s, columnContent)
// Decrypt
row[i] = column.cryptoMeta.typeInfo.Reader(&column.cryptoMeta.typeInfo, &buffer, column.cryptoMeta)
} else {
row[i] = columnContent
}
}
}
type RWCBuffer struct {
buffer *bytes.Reader
}
func (R RWCBuffer) Read(p []byte) (n int, err error) {
return R.buffer.Read(p)
}
func (R RWCBuffer) Write(p []byte) (n int, err error) {
return 0, nil
}
func (R RWCBuffer) Close() error {
return nil
}
func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{}) tdsBuffer {
encType := encryption.From(column.cryptoMeta.encType)
cekValue := column.cryptoMeta.entry.cekValues[column.cryptoMeta.ordinal]
if (s.logFlags & uint64(msdsn.LogDebug)) == uint64(msdsn.LogDebug) {
s.logger.Log(context.Background(), msdsn.LogDebug, fmt.Sprintf("Decrypting column %s. Key path: %s, Key store:%s, Algo: %s", column.ColName, cekValue.keyPath, cekValue.keyStoreName, cekValue.algorithmName))
}
cekProvider, ok := s.aeSettings.keyProviders[cekValue.keyStoreName]
if !ok {
panic(fmt.Errorf("Unable to find provider %s to decrypt CEK", cekValue.keyStoreName))
}
cek, err := cekProvider.GetDecryptedKey(cekValue.keyPath, column.cryptoMeta.entry.cekValues[0].encryptedKey)
if err != nil {
panic(err)
}
k := keys.NewAeadAes256CbcHmac256(cek)
alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encType, byte(cekValue.cekVersion))
d, err := alg.Decrypt(columnContent.([]byte))
if err != nil {
panic(err)
}
// Decrypt returns a minimum of 8 bytes so truncate to the actual data size
if column.cryptoMeta.typeInfo.Size > 0 && column.cryptoMeta.typeInfo.Size < len(d) {
d = d[:column.cryptoMeta.typeInfo.Size]
}
var newBuff []byte
newBuff = append(newBuff, d...)
rwc := RWCBuffer{
buffer: bytes.NewReader(newBuff),
}
column.cryptoMeta.typeInfo.Buffer = d
buffer := tdsBuffer{rpos: 0, rsize: len(newBuff), rbuf: newBuff, transport: rwc}
return buffer
}
// http://msdn.microsoft.com/en-us/library/dd304783.aspx
func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) {
func parseNbcRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) {
bitlen := (len(columns) + 7) / 8
pres := make([]byte, bitlen)
r.ReadFull(pres)
@ -604,7 +866,15 @@ func parseNbcRow(r *tdsBuffer, columns []columnStruct, row []interface{}) {
row[i] = nil
continue
}
row[i] = col.ti.Reader(&col.ti, r)
columnContent := col.ti.Reader(&col.ti, r, nil)
if col.isEncrypted() {
buffer := decryptColumn(col, s, columnContent)
// Decrypt
row[i] = col.cryptoMeta.typeInfo.Reader(&col.cryptoMeta.typeInfo, &buffer, col.cryptoMeta)
} else {
row[i] = columnContent
}
}
}
@ -637,7 +907,7 @@ func parseInfo(r *tdsBuffer) (res Error) {
}
// https://msdn.microsoft.com/en-us/library/dd303881.aspx
func parseReturnValue(r *tdsBuffer) (nv namedValue) {
func parseReturnValue(r *tdsBuffer, s *tdsSession) (nv namedValue) {
/*
ParamOrdinal
ParamName
@ -648,13 +918,21 @@ func parseReturnValue(r *tdsBuffer) (nv namedValue) {
CryptoMetadata
Value
*/
r.uint16()
nv.Name = r.BVarChar()
r.byte()
r.uint32() // UserType (uint16 prior to 7.2)
r.uint16()
ti := readTypeInfo(r)
nv.Value = ti.Reader(&ti, r)
_ = r.uint16() // ParamOrdinal
nv.Name = r.BVarChar() // ParamName
_ = r.byte() // Status
ti := getBaseTypeInfo(r, true) // UserType + Flags + TypeInfo
var cryptoMetadata *cryptoMetadata = nil
if s.alwaysEncrypted && (ti.Flags&fEncrypted) == fEncrypted {
cm := parseCryptoMetadata(r, nil) // CryptoMetadata
cryptoMetadata = &cm
}
ti2 := readTypeInfo(r, ti.TypeId, cryptoMetadata)
nv.Value = ti2.Reader(&ti2, r, cryptoMetadata)
return
}
@ -664,6 +942,17 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
if sess.logFlags&logErrors != 0 {
sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Intercepted panic %v", err))
}
if outs.msgq != nil {
var derr error
switch e := err.(type) {
case error:
derr = e
default:
derr = fmt.Errorf("Unhandled session error %v", e)
}
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgError{Error: derr})
}
ch <- err
}
close(ch)
@ -760,7 +1049,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
ch <- done
if done.Status&doneCount != 0 {
if sess.logFlags&logRows != 0 {
sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d row(s) affected)", done.RowCount))
sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(Rows affected: %d)", done.RowCount))
}
if (colsReceived || done.CurCmd != cmdSelect) && outs.msgq != nil {
@ -781,7 +1070,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
return
}
case tokenColMetadata:
columns = parseColMetadata72(sess.buf)
columns = parseColMetadata72(sess.buf, sess)
ch <- columns
colsReceived = true
if outs.msgq != nil {
@ -790,11 +1079,11 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
case tokenRow:
row := make([]interface{}, len(columns))
parseRow(sess.buf, columns, row)
parseRow(sess.buf, sess, columns, row)
ch <- row
case tokenNbcRow:
row := make([]interface{}, len(columns))
parseNbcRow(sess.buf, columns, row)
parseNbcRow(sess.buf, sess, columns, row)
ch <- row
case tokenEnvChange:
processEnvChg(ctx, sess)
@ -822,7 +1111,7 @@ func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenS
_ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNotice{Message: info})
}
case tokenReturnValue:
nv := parseReturnValue(sess.buf)
nv := parseReturnValue(sess.buf, sess)
if len(nv.Name) > 0 {
name := nv.Name[1:] // Remove the leading "@".
if ov, has := outs.params[name]; has {

View file

@ -89,6 +89,8 @@ const (
// http://msdn.microsoft.com/en-us/library/dd358284.aspx
type typeInfo struct {
TypeId uint8
UserType uint32
Flags uint16
Size int
Scale uint8
Prec uint8
@ -96,7 +98,7 @@ type typeInfo struct {
Collation cp.Collation
UdtInfo udtInfo
XmlInfo xmlInfo
Reader func(ti *typeInfo, r *tdsBuffer) (res interface{})
Reader func(ti *typeInfo, r *tdsBuffer, cryptoMeta *cryptoMetadata) (res interface{})
Writer func(w io.Writer, ti typeInfo, buf []byte) (err error)
}
@ -119,9 +121,9 @@ type xmlInfo struct {
XmlSchemaCollection string
}
func readTypeInfo(r *tdsBuffer) (res typeInfo) {
res.TypeId = r.byte()
switch res.TypeId {
func readTypeInfo(r *tdsBuffer, typeId byte, c *cryptoMetadata) (res typeInfo) {
res.TypeId = typeId
switch typeId {
case typeNull, typeInt1, typeBit, typeInt2, typeInt4, typeDateTim4,
typeFlt4, typeMoney, typeDateTime, typeFlt8, typeMoney4, typeInt8:
// those are fixed length types
@ -140,7 +142,7 @@ func readTypeInfo(r *tdsBuffer) (res typeInfo) {
res.Reader = readFixedType
res.Buffer = make([]byte, res.Size)
default: // all others are VARLENTYPE
readVarLen(&res, r)
readVarLen(&res, r, c)
}
return
}
@ -315,7 +317,7 @@ func decodeDateTime(buf []byte) time.Time {
0, 0, secs, ns, time.UTC)
}
func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} {
func readFixedType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} {
r.ReadFull(ti.Buffer)
buf := ti.Buffer
switch ti.TypeId {
@ -349,8 +351,13 @@ func readFixedType(ti *typeInfo, r *tdsBuffer) interface{} {
panic("shoulnd't get here")
}
func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.byte()
func readByteLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} {
var size byte
if c != nil {
size = byte(r.rsize)
} else {
size = r.byte()
}
if size == 0 {
return nil
}
@ -433,7 +440,7 @@ func readByteLenType(ti *typeInfo, r *tdsBuffer) interface{} {
default:
badStreamPanicf("Invalid typeid")
}
panic("shoulnd't get here")
panic("shouldn't get here")
}
func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
@ -448,8 +455,13 @@ func writeByteLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
return
}
func readShortLenType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.uint16()
func readShortLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} {
var size uint16
if c != nil {
size = uint16(r.rsize)
} else {
size = r.uint16()
}
if size == 0xffff {
return nil
}
@ -491,7 +503,7 @@ func writeShortLenType(w io.Writer, ti typeInfo, buf []byte) (err error) {
return
}
func readLongLenType(ti *typeInfo, r *tdsBuffer) interface{} {
func readLongLenType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} {
// information about this format can be found here:
// http://msdn.microsoft.com/en-us/library/dd304783.aspx
// and here:
@ -566,7 +578,7 @@ func writeCollation(w io.Writer, col cp.Collation) (err error) {
// reads variant value
// http://msdn.microsoft.com/en-us/library/dd303302.aspx
func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} {
func readVariantType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} {
size := r.int32()
if size == 0 {
return nil
@ -658,41 +670,47 @@ func readVariantType(ti *typeInfo, r *tdsBuffer) interface{} {
// partially length prefixed stream
// http://msdn.microsoft.com/en-us/library/dd340469.aspx
func readPLPType(ti *typeInfo, r *tdsBuffer) interface{} {
size := r.uint64()
var buf *bytes.Buffer
switch size {
case _PLP_NULL:
// null
return nil
case _UNKNOWN_PLP_LEN:
// size unknown
buf = bytes.NewBuffer(make([]byte, 0, 1000))
default:
buf = bytes.NewBuffer(make([]byte, 0, size))
}
for {
chunksize := r.uint32()
if chunksize == 0 {
break
func readPLPType(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) interface{} {
var bytesToDecode []byte
if c == nil {
size := r.uint64()
var buf *bytes.Buffer
switch size {
case _PLP_NULL:
// null
return nil
case _UNKNOWN_PLP_LEN:
// size unknown
buf = bytes.NewBuffer(make([]byte, 0, 1000))
default:
buf = bytes.NewBuffer(make([]byte, 0, size))
}
if _, err := io.CopyN(buf, r, int64(chunksize)); err != nil {
badStreamPanicf("Reading PLP type failed: %s", err.Error())
for {
chunksize := r.uint32()
if chunksize == 0 {
break
}
if _, err := io.CopyN(buf, r, int64(chunksize)); err != nil {
badStreamPanicf("Reading PLP type failed: %s", err.Error())
}
}
bytesToDecode = buf.Bytes()
} else {
bytesToDecode = r.rbuf
}
switch ti.TypeId {
case typeXml:
return decodeXml(*ti, buf.Bytes())
return decodeXml(*ti, bytesToDecode)
case typeBigVarChar, typeBigChar, typeText:
return decodeChar(ti.Collation, buf.Bytes())
return decodeChar(ti.Collation, bytesToDecode)
case typeBigVarBin, typeBigBinary, typeImage:
return buf.Bytes()
return bytesToDecode
case typeNVarChar, typeNChar, typeNText:
return decodeNChar(buf.Bytes())
return decodeNChar(bytesToDecode)
case typeUdt:
return decodeUdt(*ti, buf.Bytes())
return decodeUdt(*ti, bytesToDecode)
}
panic("shoulnd't get here")
panic("shouldn't get here")
}
func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) {
@ -719,7 +737,7 @@ func writePLPType(w io.Writer, ti typeInfo, buf []byte) (err error) {
}
}
func readVarLen(ti *typeInfo, r *tdsBuffer) {
func readVarLen(ti *typeInfo, r *tdsBuffer, c *cryptoMetadata) {
switch ti.TypeId {
case typeDateN:
ti.Size = 3

View file

@ -4,7 +4,7 @@ import "fmt"
// Update this variable with the release tag before pushing the tag
// This value is written to the prelogin and login7 packets during a new connection
const driverVersion = "v1.5.0"
const driverVersion = "v1.6.0"
func getDriverVersion(ver string) uint32 {
var majorVersion uint32