package mssql import ( "bytes" "context" "encoding/binary" "fmt" "io" "io/ioutil" "net" "strconv" "github.com/golang-sql/sqlexp" "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/algorithms" "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/encryption" "github.com/microsoft/go-mssqldb/internal/github.com/swisscom/mssql-always-encrypted/pkg/keys" "github.com/microsoft/go-mssqldb/msdsn" "golang.org/x/text/encoding/unicode" ) //go:generate go run golang.org/x/tools/cmd/stringer -type token type token byte // token ids const ( tokenReturnStatus token = 121 // 0x79 tokenColMetadata token = 129 // 0x81 tokenOrder token = 169 // 0xA9 tokenError token = 170 // 0xAA tokenInfo token = 171 // 0xAB tokenReturnValue token = 0xAC tokenLoginAck token = 173 // 0xad tokenFeatureExtAck token = 174 // 0xae tokenRow token = 209 // 0xd1 tokenNbcRow token = 210 // 0xd2 tokenEnvChange token = 227 // 0xE3 tokenSSPI token = 237 // 0xED tokenFedAuthInfo token = 238 // 0xEE tokenDone token = 253 // 0xFD tokenDoneProc token = 254 tokenDoneInProc token = 255 ) // done flags // https://msdn.microsoft.com/en-us/library/dd340421.aspx const ( doneFinal = 0 doneMore = 1 doneError = 2 doneInxact = 4 doneCount = 0x10 doneAttn = 0x20 doneSrvError = 0x100 ) // CurCmd values in done (undocumented) const ( cmdSelect = 0xc1 // cmdInsert = 0xc3 // cmdDelete = 0xc4 // cmdUpdate = 0xc5 // cmdAbort = 0xd2 // cmdBeginXaxt = 0xd4 // cmdEndXact = 0xd5 // cmdBulkInsert = 0xf0 // cmdOpenCursor = 0x20 // cmdMerge = 0x117 ) // ENVCHANGE types // http://msdn.microsoft.com/en-us/library/dd303449.aspx const ( envTypDatabase = 1 envTypLanguage = 2 envTypCharset = 3 envTypPacketSize = 4 envSortId = 5 envSortFlags = 6 envSqlCollation = 7 envTypBeginTran = 8 envTypCommitTran = 9 envTypRollbackTran = 10 envEnlistDTC = 11 envDefectTran = 12 envDatabaseMirrorPartner = 13 envPromoteTran = 15 envTranMgrAddr = 16 envTranEnded = 17 envResetConnAck = 18 envStartedInstanceName = 19 envRouting = 20 ) const ( fedAuthInfoSTSURL = 0x01 fedAuthInfoSPN = 0x02 ) const ( cipherAlgCustom = 0x00 ) // COLMETADATA flags // https://msdn.microsoft.com/en-us/library/dd357363.aspx const ( colFlagNullable = 1 colFlagEncrypted = 0x0800 // TODO implement more flags ) // interface for all tokens type tokenStruct interface{} type orderStruct struct { ColIds []uint16 } type doneStruct struct { Status uint16 CurCmd uint16 RowCount uint64 errors []Error } func (d doneStruct) isError() bool { return d.Status&doneError != 0 || len(d.errors) > 0 } func (d doneStruct) getError() Error { n := len(d.errors) if n == 0 { return Error{Message: "Request failed but didn't provide reason"} } err := d.errors[n-1] // should this return the most severe error? err.All = make([]Error, n) copy(err.All, d.errors) return err } type doneInProcStruct doneStruct // ENVCHANGE stream // http://msdn.microsoft.com/en-us/library/dd303449.aspx func processEnvChg(ctx context.Context, sess *tdsSession) { size := sess.buf.uint16() r := &io.LimitedReader{R: sess.buf, N: int64(size)} for { var err error var envtype uint8 err = binary.Read(r, binary.LittleEndian, &envtype) if err == io.EOF { return } if err != nil { badStreamPanic(err) } switch envtype { case envTypDatabase: sess.database, err = readBVarChar(r) if err != nil { badStreamPanic(err) } _, err = readBVarChar(r) if err != nil { badStreamPanic(err) } case envTypLanguage: // currently ignored // new value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envTypCharset: // currently ignored // new value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envTypPacketSize: packetsize, err := readBVarChar(r) if err != nil { badStreamPanic(err) } _, err = readBVarChar(r) if err != nil { badStreamPanic(err) } packetsizei, err := strconv.Atoi(packetsize) if err != nil { badStreamPanicf("Invalid Packet size value returned from server (%s): %s", packetsize, err.Error()) } sess.buf.ResizeBuffer(packetsizei) case envSortId: // currently ignored // new value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envSortFlags: // currently ignored // new value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envSqlCollation: // currently ignored var collationSize uint8 err = binary.Read(r, binary.LittleEndian, &collationSize) if err != nil { badStreamPanic(err) } // SQL Collation data should contain 5 bytes in length if collationSize != 5 { badStreamPanicf("Invalid SQL Collation size value returned from server: %d", collationSize) } // 4 bytes, contains: LCID ColFlags Version var info uint32 err = binary.Read(r, binary.LittleEndian, &info) if err != nil { badStreamPanic(err) } // 1 byte, contains: sortID var sortID uint8 err = binary.Read(r, binary.LittleEndian, &sortID) if err != nil { badStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envTypBeginTran: tranid, err := readBVarByte(r) if len(tranid) != 8 { badStreamPanicf("invalid size of transaction identifier: %d", len(tranid)) } sess.tranid = binary.LittleEndian.Uint64(tranid) if err != nil { badStreamPanic(err) } if sess.logFlags&logTransaction != 0 { sess.logger.Log(ctx, msdsn.LogTransaction, fmt.Sprintf("BEGIN TRANSACTION %x", sess.tranid)) } _, err = readBVarByte(r) if err != nil { badStreamPanic(err) } case envTypCommitTran, envTypRollbackTran: _, err = readBVarByte(r) if err != nil { badStreamPanic(err) } _, err = readBVarByte(r) if err != nil { badStreamPanic(err) } if sess.logFlags&logTransaction != 0 { if envtype == envTypCommitTran { sess.logger.Log(ctx, msdsn.LogTransaction, fmt.Sprintf("COMMIT TRANSACTION %x", sess.tranid)) } else { sess.logger.Log(ctx, msdsn.LogTransaction, fmt.Sprintf("ROLLBACK TRANSACTION %x", sess.tranid)) } } sess.tranid = 0 case envEnlistDTC: // currently ignored // new value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // old value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envDefectTran: // currently ignored // new value if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envDatabaseMirrorPartner: sess.partner, err = readBVarChar(r) if err != nil { badStreamPanic(err) } _, err = readBVarChar(r) if err != nil { badStreamPanic(err) } case envPromoteTran: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // dtc token // spec says it should be L_VARBYTE, so this code might be wrong if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envTranMgrAddr: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // XACT_MANAGER_ADDRESS = B_VARBYTE if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envTranEnded: // currently ignored // old value, B_VARBYTE if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envResetConnAck: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envStartedInstanceName: // currently ignored // old value, should be 0 if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } // instance name if _, err = readBVarChar(r); err != nil { badStreamPanic(err) } case envRouting: // RoutingData message is: // ValueLength USHORT // Protocol (TCP = 0) BYTE // ProtocolProperty (new port) USHORT // AlternateServer US_VARCHAR _, err := readUshort(r) if err != nil { badStreamPanic(err) } protocol, err := readByte(r) if err != nil || protocol != 0 { badStreamPanic(err) } newPort, err := readUshort(r) if err != nil { badStreamPanic(err) } newServer, err := readUsVarChar(r) if err != nil { badStreamPanic(err) } // consume the OLDVALUE = %x00 %x00 _, err = readUshort(r) if err != nil { badStreamPanic(err) } sess.routedServer = newServer sess.routedPort = newPort default: // ignore rest of records because we don't know how to skip those if sess.logFlags&logDebug != 0 { sess.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("WARN: Unknown ENVCHANGE record detected with type id = %d", envtype)) } return } } } // http://msdn.microsoft.com/en-us/library/dd358180.aspx func parseReturnStatus(r *tdsBuffer) ReturnStatus { return ReturnStatus(r.int32()) } func parseOrder(r *tdsBuffer) (res orderStruct) { len := int(r.uint16()) res.ColIds = make([]uint16, len/2) for i := 0; i < len/2; i++ { res.ColIds[i] = r.uint16() } return res } // https://msdn.microsoft.com/en-us/library/dd340421.aspx func parseDone(r *tdsBuffer) (res doneStruct) { res.Status = r.uint16() res.CurCmd = r.uint16() res.RowCount = r.uint64() return res } // https://msdn.microsoft.com/en-us/library/dd340553.aspx func parseDoneInProc(r *tdsBuffer) (res doneInProcStruct) { res.Status = r.uint16() res.CurCmd = r.uint16() res.RowCount = r.uint64() return res } type sspiMsg []byte func parseSSPIMsg(r *tdsBuffer) sspiMsg { size := r.uint16() buf := make([]byte, size) r.ReadFull(buf) return sspiMsg(buf) } type fedAuthInfoStruct struct { STSURL string ServerSPN string } type fedAuthInfoOpt struct { fedAuthInfoID byte dataLength, dataOffset uint32 } func parseFedAuthInfo(r *tdsBuffer) fedAuthInfoStruct { size := r.uint32() var STSURL, SPN string var err error // Each fedAuthInfoOpt is one byte to indicate the info ID, // then a four byte offset and a four byte length. count := r.uint32() offset := uint32(4) opts := make([]fedAuthInfoOpt, count) for i := uint32(0); i < count; i++ { fedAuthInfoID := r.byte() dataLength := r.uint32() dataOffset := r.uint32() offset += 1 + 4 + 4 opts[i] = fedAuthInfoOpt{ fedAuthInfoID: fedAuthInfoID, dataLength: dataLength, dataOffset: dataOffset, } } data := make([]byte, size-offset) r.ReadFull(data) for i := uint32(0); i < count; i++ { if opts[i].dataOffset < offset { badStreamPanicf("Fed auth info opt stated data offset %d is before data begins in packet at %d", opts[i].dataOffset, offset) // returns via panic } if opts[i].dataOffset+opts[i].dataLength > size { badStreamPanicf("Fed auth info opt stated data length %d added to stated offset exceeds size of packet %d", opts[i].dataOffset+opts[i].dataLength, size) // returns via panic } optData := data[opts[i].dataOffset-offset : opts[i].dataOffset-offset+opts[i].dataLength] switch opts[i].fedAuthInfoID { case fedAuthInfoSTSURL: STSURL, err = ucs22str(optData) case fedAuthInfoSPN: SPN, err = ucs22str(optData) default: err = fmt.Errorf("unexpected fed auth info opt ID %d", int(opts[i].fedAuthInfoID)) } if err != nil { badStreamPanic(err) } } return fedAuthInfoStruct{ STSURL: STSURL, ServerSPN: SPN, } } type loginAckStruct struct { Interface uint8 TDSVersion uint32 ProgName string ProgVer uint32 } func parseLoginAck(r *tdsBuffer) loginAckStruct { size := r.uint16() buf := make([]byte, size) r.ReadFull(buf) var res loginAckStruct res.Interface = buf[0] res.TDSVersion = binary.BigEndian.Uint32(buf[1:]) prognamelen := buf[1+4] var err error if res.ProgName, err = ucs22str(buf[1+4+1 : 1+4+1+prognamelen*2]); err != nil { badStreamPanic(err) } res.ProgVer = binary.BigEndian.Uint32(buf[size-4:]) return res } // https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-tds/2eb82f8e-11f0-46dc-b42d-27302fa4701a type fedAuthAckStruct struct { Nonce []byte Signature []byte } type colAckStruct struct { Version int EnclaveType string } type featureExtAck map[byte]interface{} func parseFeatureExtAck(r *tdsBuffer) featureExtAck { ack := map[byte]interface{}{} for feature := r.byte(); feature != featExtTERMINATOR; feature = r.byte() { length := r.uint32() switch feature { case featExtFEDAUTH: // In theory we need to know the federated authentication library to // know how to parse, but the alternatives provide compatible structures. fedAuthAck := fedAuthAckStruct{} if length >= 32 { fedAuthAck.Nonce = make([]byte, 32) r.ReadFull(fedAuthAck.Nonce) length -= 32 } if length >= 32 { fedAuthAck.Signature = make([]byte, 32) r.ReadFull(fedAuthAck.Signature) length -= 32 } ack[feature] = fedAuthAck case featExtCOLUMNENCRYPTION: colAck := colAckStruct{Version: int(r.byte())} length-- if length > 0 { // enclave type is sent as utf16 le enclaveLength := r.byte() * 2 length-- enclaveBytes := make([]byte, enclaveLength) r.ReadFull(enclaveBytes) // if the enclave type is malformed we'll just ignore it colAck.EnclaveType, _ = ucs22str(enclaveBytes) length -= uint32(enclaveLength) } ack[feature] = colAck } // Skip unprocessed bytes if length > 0 { io.CopyN(ioutil.Discard, r, int64(length)) } } return ack } // http://msdn.microsoft.com/en-us/library/dd357363.aspx func parseColMetadata72(r *tdsBuffer, s *tdsSession) (columns []columnStruct) { count := r.uint16() if count == 0xffff { // no metadata is sent return nil } columns = make([]columnStruct, count) var cekTable *cekTable if s.alwaysEncrypted { // column encryption key list cekTable = readCekTable(r) } for i := range columns { column := &columns[i] baseTi := getBaseTypeInfo(r, true) typeInfo := readTypeInfo(r, baseTi.TypeId, column.cryptoMeta) typeInfo.UserType = baseTi.UserType typeInfo.Flags = baseTi.Flags typeInfo.TypeId = baseTi.TypeId column.Flags = baseTi.Flags column.UserType = baseTi.UserType column.ti = typeInfo if column.isEncrypted() && s.alwaysEncrypted { // Read Crypto Metadata cryptoMeta := parseCryptoMetadata(r, cekTable) cryptoMeta.typeInfo.Flags = baseTi.Flags column.cryptoMeta = &cryptoMeta } else { column.cryptoMeta = nil } column.ColName = r.BVarChar() } return columns } func getBaseTypeInfo(r *tdsBuffer, parseFlags bool) typeInfo { userType := r.uint32() flags := uint16(0) if parseFlags { flags = r.uint16() } tId := r.byte() return typeInfo{ UserType: userType, Flags: flags, TypeId: tId} } type cryptoMetadata struct { entry *cekTableEntry ordinal uint16 algorithmId byte algorithmName *string encType byte normRuleVer byte typeInfo typeInfo } func parseCryptoMetadata(r *tdsBuffer, cekTable *cekTable) cryptoMetadata { ordinal := uint16(0) if cekTable != nil { ordinal = r.uint16() } typeInfo := getBaseTypeInfo(r, false) ti := readTypeInfo(r, typeInfo.TypeId, nil) ti.UserType = typeInfo.UserType ti.Flags = typeInfo.Flags ti.TypeId = typeInfo.TypeId algorithmId := r.byte() var algName *string = nil if algorithmId == cipherAlgCustom { // Read the name when a custom algorithm is used nameLen := int(r.byte()) var algNameUtf16 = make([]byte, nameLen*2) r.ReadFull(algNameUtf16) algNameBytes, _ := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM).NewDecoder().Bytes(algNameUtf16) mAlgName := string(algNameBytes) algName = &mAlgName } encType := r.byte() normRuleVer := r.byte() var entry *cekTableEntry = nil if cekTable != nil { if int(ordinal) > len(cekTable.entries)-1 { panic(fmt.Errorf("invalid ordinal, cekTable only has %d entries", len(cekTable.entries))) } entry = &cekTable.entries[ordinal] } return cryptoMetadata{ entry: entry, ordinal: ordinal, algorithmId: algorithmId, algorithmName: algName, encType: encType, normRuleVer: normRuleVer, typeInfo: ti, } } func readCekTable(r *tdsBuffer) *cekTable { tableSize := r.uint16() var cekTable *cekTable = nil if tableSize != 0 { mCekTable := newCekTable(tableSize) for i := uint16(0); i < tableSize; i++ { mCekTable.entries[i] = readCekTableEntry(r) } cekTable = &mCekTable } return cekTable } func readCekTableEntry(r *tdsBuffer) cekTableEntry { databaseId := r.int32() cekID := r.int32() cekVersion := r.int32() var cekMdVersion = make([]byte, 8) _, err := r.Read(cekMdVersion) if err != nil { panic("unable to read cekMdVersion") } cekValueCount := uint(r.byte()) // not using ucs22str because we already know the data is utf16 enc := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) utf16dec := enc.NewDecoder() cekValues := make([]encryptionKeyInfo, cekValueCount) for i := uint(0); i < cekValueCount; i++ { encryptedCekLength := r.uint16() encryptedCek := make([]byte, encryptedCekLength) r.ReadFull(encryptedCek) keyStoreLength := r.byte() keyStoreNameUtf16 := make([]byte, keyStoreLength*2) r.ReadFull(keyStoreNameUtf16) keyStoreName, _ := utf16dec.Bytes(keyStoreNameUtf16) keyPathLength := r.uint16() keyPathUtf16 := make([]byte, keyPathLength*2) r.ReadFull(keyPathUtf16) keyPath, _ := utf16dec.Bytes(keyPathUtf16) algLength := r.byte() algNameUtf16 := make([]byte, algLength*2) r.ReadFull(algNameUtf16) algName, _ := utf16dec.Bytes(algNameUtf16) cekValues[i] = encryptionKeyInfo{ encryptedKey: encryptedCek, databaseID: int(databaseId), cekID: int(cekID), cekVersion: int(cekVersion), cekMdVersion: cekMdVersion, keyPath: string(keyPath), keyStoreName: string(keyStoreName), algorithmName: string(algName), } } return cekTableEntry{ databaseID: int(databaseId), keyId: int(cekID), keyVersion: int(cekVersion), mdVersion: cekMdVersion, valueCount: int(cekValueCount), cekValues: cekValues, } } // http://msdn.microsoft.com/en-us/library/dd357254.aspx func parseRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) { for i, column := range columns { columnContent := column.ti.Reader(&column.ti, r, nil) if columnContent == nil { row[i] = columnContent continue } if column.isEncrypted() { buffer := decryptColumn(column, s, columnContent) // Decrypt row[i] = column.cryptoMeta.typeInfo.Reader(&column.cryptoMeta.typeInfo, &buffer, column.cryptoMeta) } else { row[i] = columnContent } } } type RWCBuffer struct { buffer *bytes.Reader } func (R RWCBuffer) Read(p []byte) (n int, err error) { return R.buffer.Read(p) } func (R RWCBuffer) Write(p []byte) (n int, err error) { return 0, nil } func (R RWCBuffer) Close() error { return nil } func decryptColumn(column columnStruct, s *tdsSession, columnContent interface{}) tdsBuffer { encType := encryption.From(column.cryptoMeta.encType) cekValue := column.cryptoMeta.entry.cekValues[column.cryptoMeta.ordinal] if (s.logFlags & uint64(msdsn.LogDebug)) == uint64(msdsn.LogDebug) { s.logger.Log(context.Background(), msdsn.LogDebug, fmt.Sprintf("Decrypting column %s. Key path: %s, Key store:%s, Algo: %s", column.ColName, cekValue.keyPath, cekValue.keyStoreName, cekValue.algorithmName)) } cekProvider, ok := s.aeSettings.keyProviders[cekValue.keyStoreName] if !ok { panic(fmt.Errorf("Unable to find provider %s to decrypt CEK", cekValue.keyStoreName)) } cek, err := cekProvider.GetDecryptedKey(cekValue.keyPath, column.cryptoMeta.entry.cekValues[0].encryptedKey) if err != nil { panic(err) } k := keys.NewAeadAes256CbcHmac256(cek) alg := algorithms.NewAeadAes256CbcHmac256Algorithm(k, encType, byte(cekValue.cekVersion)) d, err := alg.Decrypt(columnContent.([]byte)) if err != nil { panic(err) } // Decrypt returns a minimum of 8 bytes so truncate to the actual data size if column.cryptoMeta.typeInfo.Size > 0 && column.cryptoMeta.typeInfo.Size < len(d) { d = d[:column.cryptoMeta.typeInfo.Size] } var newBuff []byte newBuff = append(newBuff, d...) rwc := RWCBuffer{ buffer: bytes.NewReader(newBuff), } column.cryptoMeta.typeInfo.Buffer = d buffer := tdsBuffer{rpos: 0, rsize: len(newBuff), rbuf: newBuff, transport: rwc} return buffer } // http://msdn.microsoft.com/en-us/library/dd304783.aspx func parseNbcRow(r *tdsBuffer, s *tdsSession, columns []columnStruct, row []interface{}) { bitlen := (len(columns) + 7) / 8 pres := make([]byte, bitlen) r.ReadFull(pres) for i, col := range columns { if pres[i/8]&(1<<(uint(i)%8)) != 0 { row[i] = nil continue } columnContent := col.ti.Reader(&col.ti, r, nil) if col.isEncrypted() { buffer := decryptColumn(col, s, columnContent) // Decrypt row[i] = col.cryptoMeta.typeInfo.Reader(&col.cryptoMeta.typeInfo, &buffer, col.cryptoMeta) } else { row[i] = columnContent } } } // http://msdn.microsoft.com/en-us/library/dd304156.aspx func parseError72(r *tdsBuffer) (res Error) { length := r.uint16() _ = length // ignore length res.Number = r.int32() res.State = r.byte() res.Class = r.byte() res.Message = r.UsVarChar() res.ServerName = r.BVarChar() res.ProcName = r.BVarChar() res.LineNo = r.int32() return } // http://msdn.microsoft.com/en-us/library/dd304156.aspx func parseInfo(r *tdsBuffer) (res Error) { length := r.uint16() _ = length // ignore length res.Number = r.int32() res.State = r.byte() res.Class = r.byte() res.Message = r.UsVarChar() res.ServerName = r.BVarChar() res.ProcName = r.BVarChar() res.LineNo = r.int32() return } // https://msdn.microsoft.com/en-us/library/dd303881.aspx func parseReturnValue(r *tdsBuffer, s *tdsSession) (nv namedValue) { /* ParamOrdinal ParamName Status UserType Flags TypeInfo CryptoMetadata Value */ _ = r.uint16() // ParamOrdinal nv.Name = r.BVarChar() // ParamName _ = r.byte() // Status ti := getBaseTypeInfo(r, true) // UserType + Flags + TypeInfo var cryptoMetadata *cryptoMetadata = nil if s.alwaysEncrypted && (ti.Flags&fEncrypted) == fEncrypted { cm := parseCryptoMetadata(r, nil) // CryptoMetadata cryptoMetadata = &cm } ti2 := readTypeInfo(r, ti.TypeId, cryptoMetadata) nv.Value = ti2.Reader(&ti2, r, cryptoMetadata) return } func processSingleResponse(ctx context.Context, sess *tdsSession, ch chan tokenStruct, outs outputs) { defer func() { if err := recover(); err != nil { if sess.logFlags&logErrors != 0 { sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("Intercepted panic %v", err)) } if outs.msgq != nil { var derr error switch e := err.(type) { case error: derr = e default: derr = fmt.Errorf("Unhandled session error %v", e) } _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgError{Error: derr}) } ch <- err } close(ch) }() colsReceived := false packet_type, err := sess.buf.BeginRead() if err != nil { if sess.logFlags&logErrors != 0 { sess.logger.Log(ctx, msdsn.LogErrors, fmt.Sprintf("BeginRead failed %v", err)) } switch e := err.(type) { case *net.OpError: err = e default: // the named pipe provider returns a raw win32 error so fake an OpError err = &net.OpError{Op: "Read", Err: err} } ch <- err return } if packet_type != packReply { badStreamPanic(fmt.Errorf("unexpected packet type in reply: got %v, expected %v", packet_type, packReply)) } var columns []columnStruct errs := make([]Error, 0, 5) for tokens := 0; ; tokens += 1 { token := token(sess.buf.byte()) if sess.logFlags&logDebug != 0 { sess.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("got token %v", token)) } switch token { case tokenSSPI: ch <- parseSSPIMsg(sess.buf) return case tokenFedAuthInfo: ch <- parseFedAuthInfo(sess.buf) return case tokenReturnStatus: returnStatus := parseReturnStatus(sess.buf) ch <- returnStatus case tokenLoginAck: loginAck := parseLoginAck(sess.buf) ch <- loginAck case tokenFeatureExtAck: featureExtAck := parseFeatureExtAck(sess.buf) ch <- featureExtAck case tokenOrder: order := parseOrder(sess.buf) ch <- order case tokenDoneInProc: done := parseDoneInProc(sess.buf) ch <- done if done.Status&doneCount != 0 { if sess.logFlags&logRows != 0 { sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(%d rows affected)", done.RowCount)) } if (colsReceived || done.CurCmd != cmdSelect) && outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)}) } } if outs.msgq != nil { // For now we ignore ctx->Done errors that ReturnMessageEnqueue might return // It's not clear how to handle them correctly here, and data/sql seems // to set Rows.Err correctly when ctx expires already _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) } colsReceived = false if done.Status&doneMore == 0 { // Rows marks the request as done when seeing this done token. We queue another result set message // so the app calls NextResultSet again which will return false. if outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) } return } case tokenDone, tokenDoneProc: done := parseDone(sess.buf) done.errors = errs if outs.msgq != nil { errs = make([]Error, 0, 5) } if sess.logFlags&logDebug != 0 { sess.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("got DONE or DONEPROC status=%d", done.Status)) } if done.Status&doneSrvError != 0 { ch <- ServerError{done.getError()} if outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) } return } ch <- done if done.Status&doneCount != 0 { if sess.logFlags&logRows != 0 { sess.logger.Log(ctx, msdsn.LogRows, fmt.Sprintf("(Rows affected: %d)", done.RowCount)) } if (colsReceived || done.CurCmd != cmdSelect) && outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgRowsAffected{Count: int64(done.RowCount)}) } } colsReceived = false if outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) } if done.Status&doneMore == 0 { // Rows marks the request as done when seeing this done token. We queue another result set message // so the app calls NextResultSet again which will return false. if outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNextResultSet{}) } return } case tokenColMetadata: columns = parseColMetadata72(sess.buf, sess) ch <- columns colsReceived = true if outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNext{}) } case tokenRow: row := make([]interface{}, len(columns)) parseRow(sess.buf, sess, columns, row) ch <- row case tokenNbcRow: row := make([]interface{}, len(columns)) parseNbcRow(sess.buf, sess, columns, row) ch <- row case tokenEnvChange: processEnvChg(ctx, sess) case tokenError: err := parseError72(sess.buf) if sess.logFlags&logDebug != 0 { sess.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("got ERROR %d %s", err.Number, err.Message)) } errs = append(errs, err) if sess.logFlags&logErrors != 0 { sess.logger.Log(ctx, msdsn.LogErrors, err.Message) } if outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgError{Error: err}) } case tokenInfo: info := parseInfo(sess.buf) if sess.logFlags&logDebug != 0 { sess.logger.Log(ctx, msdsn.LogDebug, fmt.Sprintf("got INFO %d %s", info.Number, info.Message)) } if sess.logFlags&logMessages != 0 { sess.logger.Log(ctx, msdsn.LogMessages, info.Message) } if outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(ctx, outs.msgq, sqlexp.MsgNotice{Message: info}) } case tokenReturnValue: nv := parseReturnValue(sess.buf, sess) if len(nv.Name) > 0 { name := nv.Name[1:] // Remove the leading "@". if ov, has := outs.params[name]; has { err = scanIntoOut(name, nv.Value, ov) if err != nil { fmt.Println("scan error", err) ch <- err } } } default: badStreamPanic(fmt.Errorf("unknown token type returned: %v", token)) } } } type tokenProcessor struct { tokChan chan tokenStruct ctx context.Context sess *tdsSession outs outputs lastRow []interface{} rowCount int64 firstError error // whether to skip sending attention when ctx is done noAttn bool } func startReading(sess *tdsSession, ctx context.Context, outs outputs) *tokenProcessor { tokChan := make(chan tokenStruct, 5) go processSingleResponse(ctx, sess, tokChan, outs) return &tokenProcessor{ tokChan: tokChan, ctx: ctx, sess: sess, outs: outs, } } func (t *tokenProcessor) iterateResponse() error { for { tok, err := t.nextToken() if err == nil { if tok == nil { return t.firstError } else { switch token := tok.(type) { case []columnStruct: t.sess.columns = token case []interface{}: t.lastRow = token case doneInProcStruct: if token.Status&doneCount != 0 { t.rowCount += int64(token.RowCount) } case doneStruct: if token.Status&doneCount != 0 { t.rowCount += int64(token.RowCount) } if token.isError() && t.firstError == nil { t.firstError = token.getError() } case ReturnStatus: if t.outs.returnStatus != nil { *t.outs.returnStatus = token } /*case error: if resultError == nil { resultError = token }*/ } } } else { return err } } } func (t tokenProcessor) nextToken() (tokenStruct, error) { // we do this separate non-blocking check on token channel to // prioritize it over cancellation channel select { case tok, more := <-t.tokChan: err, more := tok.(error) if more { // this is an error and not a token return nil, err } else { return tok, nil } default: // there are no tokens on the channel, will need to wait } select { case tok, more := <-t.tokChan: if more { err, ok := tok.(error) if ok { // this is an error and not a token return nil, err } else { return tok, nil } } else { // completed reading response return nil, nil } case <-t.ctx.Done(): // It seems the Message function on t.outs.msgq doesn't get the Done if it comes here instead if t.outs.msgq != nil { _ = sqlexp.ReturnMessageEnqueue(t.ctx, t.outs.msgq, sqlexp.MsgNextResultSet{}) } if t.noAttn { return nil, t.ctx.Err() } if err := sendAttention(t.sess.buf); err != nil { // unable to send attention, current connection is bad // notify caller and close channel return nil, err } // now the server should send cancellation confirmation // it is possible that we already received full response // just before we sent cancellation request // in this case current response would not contain confirmation // and we would need to read one more response // first lets finish reading current response and look // for confirmation in it if readCancelConfirmation(t.tokChan) { // we got confirmation in current response return nil, t.ctx.Err() } // we did not get cancellation confirmation in the current response // read one more response, it must be there t.tokChan = make(chan tokenStruct, 5) go processSingleResponse(t.ctx, t.sess, t.tokChan, t.outs) if readCancelConfirmation(t.tokChan) { return nil, t.ctx.Err() } // we did not get cancellation confirmation, something is not // right, this connection is not usable anymore return nil, ServerError{Error{Message: "did not get cancellation confirmation from the server"}} } } func readCancelConfirmation(tokChan chan tokenStruct) bool { for tok := range tokChan { switch tok := tok.(type) { default: // just skip token case doneStruct: if tok.Status&doneAttn != 0 { // got cancellation confirmation, exit return true } } } return false }