From 9d4c1ddfa14d33ce3b78c02421fb76a93e5ca2d6 Mon Sep 17 00:00:00 2001 From: techknowlogick Date: Tue, 3 Jul 2018 17:58:31 -0400 Subject: [PATCH] Dep upgrade mysql lib (#4161) * update gopkg file to add sql dep --- Gopkg.lock | 4 +- Gopkg.toml | 4 + vendor/github.com/go-sql-driver/mysql/AUTHORS | 34 ++ .../go-sql-driver/mysql/appengine.go | 2 +- vendor/github.com/go-sql-driver/mysql/auth.go | 420 +++++++++++++++++ .../github.com/go-sql-driver/mysql/buffer.go | 12 +- .../go-sql-driver/mysql/collations.go | 1 + .../go-sql-driver/mysql/connection.go | 152 ++++-- .../go-sql-driver/mysql/connection_go18.go | 208 +++++++++ .../github.com/go-sql-driver/mysql/const.go | 25 +- .../github.com/go-sql-driver/mysql/driver.go | 81 ++-- vendor/github.com/go-sql-driver/mysql/dsn.go | 159 +++++-- .../github.com/go-sql-driver/mysql/errors.go | 79 +--- .../github.com/go-sql-driver/mysql/fields.go | 194 ++++++++ .../github.com/go-sql-driver/mysql/infile.go | 6 +- .../github.com/go-sql-driver/mysql/packets.go | 441 ++++++++++-------- vendor/github.com/go-sql-driver/mysql/rows.go | 174 +++++-- .../go-sql-driver/mysql/statement.go | 123 +++-- .../go-sql-driver/mysql/transaction.go | 4 +- .../github.com/go-sql-driver/mysql/utils.go | 214 ++++----- .../go-sql-driver/mysql/utils_go17.go | 40 ++ .../go-sql-driver/mysql/utils_go18.go | 50 ++ 22 files changed, 1813 insertions(+), 614 deletions(-) create mode 100644 vendor/github.com/go-sql-driver/mysql/auth.go create mode 100644 vendor/github.com/go-sql-driver/mysql/connection_go18.go create mode 100644 vendor/github.com/go-sql-driver/mysql/fields.go create mode 100644 vendor/github.com/go-sql-driver/mysql/utils_go17.go create mode 100644 vendor/github.com/go-sql-driver/mysql/utils_go18.go diff --git a/Gopkg.lock b/Gopkg.lock index 6551354a09..d5b2408bba 100644 --- a/Gopkg.lock +++ b/Gopkg.lock @@ -294,7 +294,7 @@ [[projects]] name = "github.com/go-sql-driver/mysql" packages = ["."] - revision = "ce924a41eea897745442daaa1739089b0f3f561d" + revision = "d523deb1b23d913de5bdada721a6071e71283618" [[projects]] name = "github.com/go-xorm/builder" @@ -873,6 +873,6 @@ [solve-meta] analyzer-name = "dep" analyzer-version = 1 - inputs-digest = "036b8c882671cf8d2c5e2fdbe53b1bdfbd39f7ebd7765bd50276c7c4ecf16687" + inputs-digest = "96c83a3502bd50c5ca8e4d9b4145172267630270e587c79b7253156725eeb9b8" solver-name = "gps-cdcl" solver-version = 1 diff --git a/Gopkg.toml b/Gopkg.toml index 1019888c01..42523550fd 100644 --- a/Gopkg.toml +++ b/Gopkg.toml @@ -40,6 +40,10 @@ ignored = ["google.golang.org/appengine*"] #version = "0.6.5" revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03" +[[override]] + name = "github.com/go-sql-driver/mysql" + revision = "d523deb1b23d913de5bdada721a6071e71283618" + [[override]] name = "github.com/gorilla/mux" revision = "757bef944d0f21880861c2dd9c871ca543023cba" diff --git a/vendor/github.com/go-sql-driver/mysql/AUTHORS b/vendor/github.com/go-sql-driver/mysql/AUTHORS index 148ea93aa7..73ff68fbcf 100644 --- a/vendor/github.com/go-sql-driver/mysql/AUTHORS +++ b/vendor/github.com/go-sql-driver/mysql/AUTHORS @@ -12,34 +12,63 @@ # Individual Persons Aaron Hopkins +Achille Roussel +Alexey Palazhchenko +Andrew Reid Arne Hormann +Asta Xie +Bulat Gaifullin Carlos Nieto Chris Moos +Craig Wilson +Daniel Montoya Daniel Nichter Daniël van Eeden +Dave Protasowski DisposaBoy +Egor Smolyakov +Evan Shaw Frederick Mayle Gustavo Kristic +Hajime Nakagami Hanno Braun Henri Yandell Hirotaka Yamamoto +ICHINOSE Shogo INADA Naoki +Jacek Szwec James Harr +Jeff Hodges +Jeffrey Charles Jian Zhen Joshua Prunier Julien Lefevre Julien Schmidt +Justin Li +Justin Nuß Kamil Dziedzic Kevin Malachowski +Kieron Woodhouse Lennart Rudolph Leonardo YongUk Kim +Linh Tran Tuan +Lion Yang Luca Looz Lucas Liu Luke Scott +Maciej Zimnoch Michael Woolnough Nicola Peduzzi +Olivier Mengué +oscarzhao Paul Bonser +Peter Schultz +Rebecca Chin +Reed Allman +Richard Wilkes +Robert Russell Runrioter Wung +Shuode Li Soroush Pour Stan Putrya Stanley Gunawan @@ -51,5 +80,10 @@ Zhenye Xie # Organizations Barracuda Networks, Inc. +Counting Ltd. Google Inc. +InfoSum Ltd. +Keybase Inc. +Percona LLC +Pivotal Inc. Stripe Inc. diff --git a/vendor/github.com/go-sql-driver/mysql/appengine.go b/vendor/github.com/go-sql-driver/mysql/appengine.go index 565614eef7..be41f2ee6d 100644 --- a/vendor/github.com/go-sql-driver/mysql/appengine.go +++ b/vendor/github.com/go-sql-driver/mysql/appengine.go @@ -11,7 +11,7 @@ package mysql import ( - "appengine/cloudsql" + "google.golang.org/appengine/cloudsql" ) func init() { diff --git a/vendor/github.com/go-sql-driver/mysql/auth.go b/vendor/github.com/go-sql-driver/mysql/auth.go new file mode 100644 index 0000000000..0b59f52ee7 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/auth.go @@ -0,0 +1,420 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/sha256" + "crypto/x509" + "encoding/pem" + "sync" +) + +// server pub keys registry +var ( + serverPubKeyLock sync.RWMutex + serverPubKeyRegistry map[string]*rsa.PublicKey +) + +// RegisterServerPubKey registers a server RSA public key which can be used to +// send data in a secure manner to the server without receiving the public key +// in a potentially insecure way from the server first. +// Registered keys can afterwards be used adding serverPubKey= to the DSN. +// +// Note: The provided rsa.PublicKey instance is exclusively owned by the driver +// after registering it and may not be modified. +// +// data, err := ioutil.ReadFile("mykey.pem") +// if err != nil { +// log.Fatal(err) +// } +// +// block, _ := pem.Decode(data) +// if block == nil || block.Type != "PUBLIC KEY" { +// log.Fatal("failed to decode PEM block containing public key") +// } +// +// pub, err := x509.ParsePKIXPublicKey(block.Bytes) +// if err != nil { +// log.Fatal(err) +// } +// +// if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { +// mysql.RegisterServerPubKey("mykey", rsaPubKey) +// } else { +// log.Fatal("not a RSA public key") +// } +// +func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry == nil { + serverPubKeyRegistry = make(map[string]*rsa.PublicKey) + } + + serverPubKeyRegistry[name] = pubKey + serverPubKeyLock.Unlock() +} + +// DeregisterServerPubKey removes the public key registered with the given name. +func DeregisterServerPubKey(name string) { + serverPubKeyLock.Lock() + if serverPubKeyRegistry != nil { + delete(serverPubKeyRegistry, name) + } + serverPubKeyLock.Unlock() +} + +func getServerPubKey(name string) (pubKey *rsa.PublicKey) { + serverPubKeyLock.RLock() + if v, ok := serverPubKeyRegistry[name]; ok { + pubKey = v + } + serverPubKeyLock.RUnlock() + return +} + +// Hash password using pre 4.1 (old password) method +// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c +type myRnd struct { + seed1, seed2 uint32 +} + +const myRndMaxVal = 0x3FFFFFFF + +// Pseudo random number generator +func newMyRnd(seed1, seed2 uint32) *myRnd { + return &myRnd{ + seed1: seed1 % myRndMaxVal, + seed2: seed2 % myRndMaxVal, + } +} + +// Tested to be equivalent to MariaDB's floating point variant +// http://play.golang.org/p/QHvhd4qved +// http://play.golang.org/p/RG0q4ElWDx +func (r *myRnd) NextByte() byte { + r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal + r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal + + return byte(uint64(r.seed1) * 31 / myRndMaxVal) +} + +// Generate binary hash from byte string using insecure pre 4.1 method +func pwHash(password []byte) (result [2]uint32) { + var add uint32 = 7 + var tmp uint32 + + result[0] = 1345345333 + result[1] = 0x12345671 + + for _, c := range password { + // skip spaces and tabs in password + if c == ' ' || c == '\t' { + continue + } + + tmp = uint32(c) + result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) + result[1] += (result[1] << 8) ^ result[0] + add += tmp + } + + // Remove sign bit (1<<31)-1) + result[0] &= 0x7FFFFFFF + result[1] &= 0x7FFFFFFF + + return +} + +// Hash password using insecure pre 4.1 method +func scrambleOldPassword(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + scramble = scramble[:8] + + hashPw := pwHash([]byte(password)) + hashSc := pwHash(scramble) + + r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) + + var out [8]byte + for i := range out { + out[i] = r.NextByte() + 64 + } + + mask := r.NextByte() + for i := range out { + out[i] ^= mask + } + + return out[:] +} + +// Hash password using 4.1+ method (SHA1) +func scramblePassword(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // stage1Hash = SHA1(password) + crypt := sha1.New() + crypt.Write([]byte(password)) + stage1 := crypt.Sum(nil) + + // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) + // inner Hash + crypt.Reset() + crypt.Write(stage1) + hash := crypt.Sum(nil) + + // outer Hash + crypt.Reset() + crypt.Write(scramble) + crypt.Write(hash) + scramble = crypt.Sum(nil) + + // token = scrambleHash XOR stage1Hash + for i := range scramble { + scramble[i] ^= stage1[i] + } + return scramble +} + +// Hash password using MySQL 8+ method (SHA256) +func scrambleSHA256Password(scramble []byte, password string) []byte { + if len(password) == 0 { + return nil + } + + // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) + + crypt := sha256.New() + crypt.Write([]byte(password)) + message1 := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1) + message1Hash := crypt.Sum(nil) + + crypt.Reset() + crypt.Write(message1Hash) + crypt.Write(scramble) + message2 := crypt.Sum(nil) + + for i := range message1 { + message1[i] ^= message2[i] + } + + return message1 +} + +func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { + plain := make([]byte, len(password)+1) + copy(plain, password) + for i := range plain { + j := i % len(seed) + plain[i] ^= seed[j] + } + sha1 := sha1.New() + return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) +} + +func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { + enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) + if err != nil { + return err + } + return mc.writeAuthSwitchPacket(enc, false) +} + +func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) { + switch plugin { + case "caching_sha2_password": + authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) + return authResp, (authResp == nil), nil + + case "mysql_old_password": + if !mc.cfg.AllowOldPasswords { + return nil, false, ErrOldPassword + } + // Note: there are edge cases where this should work but doesn't; + // this is currently "wontfix": + // https://github.com/go-sql-driver/mysql/issues/184 + authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd) + return authResp, true, nil + + case "mysql_clear_password": + if !mc.cfg.AllowCleartextPasswords { + return nil, false, ErrCleartextPassword + } + // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html + // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html + return []byte(mc.cfg.Passwd), true, nil + + case "mysql_native_password": + if !mc.cfg.AllowNativePasswords { + return nil, false, ErrNativePassword + } + // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html + // Native password authentication only need and will need 20-byte challenge. + authResp := scramblePassword(authData[:20], mc.cfg.Passwd) + return authResp, false, nil + + case "sha256_password": + if len(mc.cfg.Passwd) == 0 { + return nil, true, nil + } + if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + // write cleartext auth packet + return []byte(mc.cfg.Passwd), true, nil + } + + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + return []byte{1}, false, nil + } + + // encrypted password + enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) + return enc, false, err + + default: + errLog.Print("unknown auth plugin:", plugin) + return nil, false, ErrUnknownPlugin + } +} + +func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { + // Read Result Packet + authData, newPlugin, err := mc.readAuthResult() + if err != nil { + return err + } + + // handle auth plugin switch, if requested + if newPlugin != "" { + // If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is + // sent and we have to keep using the cipher sent in the init packet. + if authData == nil { + authData = oldAuthData + } else { + // copy data from read buffer to owned slice + copy(oldAuthData, authData) + } + + plugin = newPlugin + + authResp, addNUL, err := mc.auth(authData, plugin) + if err != nil { + return err + } + if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil { + return err + } + + // Read Result Packet + authData, newPlugin, err = mc.readAuthResult() + if err != nil { + return err + } + + // Do not allow to change the auth plugin more than once + if newPlugin != "" { + return ErrMalformPkt + } + } + + switch plugin { + + // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ + case "caching_sha2_password": + switch len(authData) { + case 0: + return nil // auth successful + case 1: + switch authData[0] { + case cachingSha2PasswordFastAuthSuccess: + if err = mc.readResultOK(); err == nil { + return nil // auth successful + } + + case cachingSha2PasswordPerformFullAuthentication: + if mc.cfg.tls != nil || mc.cfg.Net == "unix" { + // write cleartext auth packet + err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true) + if err != nil { + return err + } + } else { + pubKey := mc.cfg.pubKey + if pubKey == nil { + // request public key from server + data := mc.buf.takeSmallBuffer(4 + 1) + data[4] = cachingSha2PasswordRequestPublicKey + mc.writePacket(data) + + // parse public key + data, err := mc.readPacket() + if err != nil { + return err + } + + block, _ := pem.Decode(data[1:]) + pkix, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + pubKey = pkix.(*rsa.PublicKey) + } + + // send encrypted password + err = mc.sendEncryptedPassword(oldAuthData, pubKey) + if err != nil { + return err + } + } + return mc.readResultOK() + + default: + return ErrMalformPkt + } + default: + return ErrMalformPkt + } + + case "sha256_password": + switch len(authData) { + case 0: + return nil // auth successful + default: + block, _ := pem.Decode(authData) + pub, err := x509.ParsePKIXPublicKey(block.Bytes) + if err != nil { + return err + } + + // send encrypted password + err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey)) + if err != nil { + return err + } + return mc.readResultOK() + } + + default: + return nil // auth successful + } + + return err +} diff --git a/vendor/github.com/go-sql-driver/mysql/buffer.go b/vendor/github.com/go-sql-driver/mysql/buffer.go index 2001feacd3..eb4748bf44 100644 --- a/vendor/github.com/go-sql-driver/mysql/buffer.go +++ b/vendor/github.com/go-sql-driver/mysql/buffer.go @@ -130,18 +130,18 @@ func (b *buffer) takeBuffer(length int) []byte { // smaller than defaultBufSize // Only one buffer (total) can be used at a time. func (b *buffer) takeSmallBuffer(length int) []byte { - if b.length == 0 { - return b.buf[:length] + if b.length > 0 { + return nil } - return nil + return b.buf[:length] } // takeCompleteBuffer returns the complete existing buffer. // This can be used if the necessary buffer size is unknown. // Only one buffer (total) can be used at a time. func (b *buffer) takeCompleteBuffer() []byte { - if b.length == 0 { - return b.buf + if b.length > 0 { + return nil } - return nil + return b.buf } diff --git a/vendor/github.com/go-sql-driver/mysql/collations.go b/vendor/github.com/go-sql-driver/mysql/collations.go index 82079cfb93..136c9e4d1e 100644 --- a/vendor/github.com/go-sql-driver/mysql/collations.go +++ b/vendor/github.com/go-sql-driver/mysql/collations.go @@ -9,6 +9,7 @@ package mysql const defaultCollation = "utf8_general_ci" +const binaryCollation = "binary" // A list of available collations mapped to the internal ID. // To update this map use the following MySQL query: diff --git a/vendor/github.com/go-sql-driver/mysql/connection.go b/vendor/github.com/go-sql-driver/mysql/connection.go index d82c728f3b..e57061412b 100644 --- a/vendor/github.com/go-sql-driver/mysql/connection.go +++ b/vendor/github.com/go-sql-driver/mysql/connection.go @@ -10,12 +10,23 @@ package mysql import ( "database/sql/driver" + "io" "net" "strconv" "strings" "time" ) +// a copy of context.Context for Go 1.7 and earlier +type mysqlContext interface { + Done() <-chan struct{} + Err() error + + // defined in context.Context, but not used in this driver: + // Deadline() (deadline time.Time, ok bool) + // Value(key interface{}) interface{} +} + type mysqlConn struct { buf buffer netConn net.Conn @@ -29,7 +40,14 @@ type mysqlConn struct { status statusFlag sequence uint8 parseTime bool - strict bool + + // for context support (Go 1.8+) + watching bool + watcher chan<- mysqlContext + closech chan struct{} + finished chan<- struct{} + canceled atomicError // set non-nil if conn is canceled + closed atomicBool // set when conn is closed, before closech is closed } // Handles parameters set in DSN after the connection is established @@ -62,22 +80,41 @@ func (mc *mysqlConn) handleParams() (err error) { return } +func (mc *mysqlConn) markBadConn(err error) error { + if mc == nil { + return err + } + if err != errBadConnNoWrite { + return err + } + return driver.ErrBadConn +} + func (mc *mysqlConn) Begin() (driver.Tx, error) { - if mc.netConn == nil { + return mc.begin(false) +} + +func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } - err := mc.exec("START TRANSACTION") + var q string + if readOnly { + q = "START TRANSACTION READ ONLY" + } else { + q = "START TRANSACTION" + } + err := mc.exec(q) if err == nil { return &mysqlTx{mc}, err } - - return nil, err + return nil, mc.markBadConn(err) } func (mc *mysqlConn) Close() (err error) { // Makes Close idempotent - if mc.netConn != nil { + if !mc.closed.IsSet() { err = mc.writeCommandPacket(comQuit) } @@ -91,26 +128,39 @@ func (mc *mysqlConn) Close() (err error) { // is called before auth or on auth failure because MySQL will have already // closed the network connection. func (mc *mysqlConn) cleanup() { - // Makes cleanup idempotent - if mc.netConn != nil { - if err := mc.netConn.Close(); err != nil { - errLog.Print(err) - } - mc.netConn = nil + if !mc.closed.TrySet(true) { + return } - mc.cfg = nil - mc.buf.nc = nil + + // Makes cleanup idempotent + close(mc.closech) + if mc.netConn == nil { + return + } + if err := mc.netConn.Close(); err != nil { + errLog.Print(err) + } +} + +func (mc *mysqlConn) error() error { + if mc.closed.IsSet() { + if err := mc.canceled.Value(); err != nil { + return err + } + return ErrInvalidConn + } + return nil } func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { - if mc.netConn == nil { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := mc.writeCommandPacketStr(comStmtPrepare, query) if err != nil { - return nil, err + return nil, mc.markBadConn(err) } stmt := &mysqlStmt{ @@ -144,7 +194,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin if buf == nil { // can not take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return "", driver.ErrBadConn + return "", ErrInvalidConn } buf = buf[:0] argPos := 0 @@ -257,7 +307,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin } func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { - if mc.netConn == nil { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -271,7 +321,6 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err return nil, err } query = prepared - args = nil } mc.affectedRows = 0 mc.insertId = 0 @@ -283,32 +332,43 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err insertId: int64(mc.insertId), }, err } - return nil, err + return nil, mc.markBadConn(err) } // Internal function to execute commands func (mc *mysqlConn) exec(query string) error { // Send command - err := mc.writeCommandPacketStr(comQuery, query) - if err != nil { - return err + if err := mc.writeCommandPacketStr(comQuery, query); err != nil { + return mc.markBadConn(err) } // Read Result resLen, err := mc.readResultSetHeaderPacket() - if err == nil && resLen > 0 { - if err = mc.readUntilEOF(); err != nil { + if err != nil { + return err + } + + if resLen > 0 { + // columns + if err := mc.readUntilEOF(); err != nil { return err } - err = mc.readUntilEOF() + // rows + if err := mc.readUntilEOF(); err != nil { + return err + } } - return err + return mc.discardResults() } func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { - if mc.netConn == nil { + return mc.query(query, args) +} + +func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { + if mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } @@ -322,7 +382,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro return nil, err } query = prepared - args = nil } // Send command err := mc.writeCommandPacketStr(comQuery, query) @@ -335,15 +394,22 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro rows.mc = mc if resLen == 0 { - // no columns, no more data - return emptyRows{}, nil + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err + } } + // Columns - rows.columns, err = mc.readColumns(resLen) + rows.rs.columns, err = mc.readColumns(resLen) return rows, err } } - return nil, err + return nil, mc.markBadConn(err) } // Gets the value of the given MySQL System Variable @@ -359,7 +425,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { if err == nil { rows := new(textRows) rows.mc = mc - rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}} + rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} if resLen > 0 { // Columns @@ -375,3 +441,21 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { } return nil, err } + +// finish is called when the query has canceled. +func (mc *mysqlConn) cancel(err error) { + mc.canceled.Set(err) + mc.cleanup() +} + +// finish is called when the query has succeeded. +func (mc *mysqlConn) finish() { + if !mc.watching || mc.finished == nil { + return + } + select { + case mc.finished <- struct{}{}: + mc.watching = false + case <-mc.closech: + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/connection_go18.go b/vendor/github.com/go-sql-driver/mysql/connection_go18.go new file mode 100644 index 0000000000..62796bfcea --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/connection_go18.go @@ -0,0 +1,208 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.8 + +package mysql + +import ( + "context" + "database/sql" + "database/sql/driver" +) + +// Ping implements driver.Pinger interface +func (mc *mysqlConn) Ping(ctx context.Context) (err error) { + if mc.closed.IsSet() { + errLog.Print(ErrInvalidConn) + return driver.ErrBadConn + } + + if err = mc.watchCancel(ctx); err != nil { + return + } + defer mc.finish() + + if err = mc.writeCommandPacket(comPing); err != nil { + return + } + + return mc.readResultOK() +} + +// BeginTx implements driver.ConnBeginTx interface +func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { + level, err := mapIsolationLevel(opts.Isolation) + if err != nil { + return nil, err + } + err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) + if err != nil { + return nil, err + } + } + + return mc.begin(opts.ReadOnly) +} + +func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := mc.query(query, dargs) + if err != nil { + mc.finish() + return nil, err + } + rows.finish = mc.finish + return rows, err +} + +func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + defer mc.finish() + + return mc.Exec(query, dargs) +} + +func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { + if err := mc.watchCancel(ctx); err != nil { + return nil, err + } + + stmt, err := mc.Prepare(query) + mc.finish() + if err != nil { + return nil, err + } + + select { + default: + case <-ctx.Done(): + stmt.Close() + return nil, ctx.Err() + } + return stmt, nil +} + +func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + + rows, err := stmt.query(dargs) + if err != nil { + stmt.mc.finish() + return nil, err + } + rows.finish = stmt.mc.finish + return rows, err +} + +func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { + dargs, err := namedValueToValue(args) + if err != nil { + return nil, err + } + + if err := stmt.mc.watchCancel(ctx); err != nil { + return nil, err + } + defer stmt.mc.finish() + + return stmt.Exec(dargs) +} + +func (mc *mysqlConn) watchCancel(ctx context.Context) error { + if mc.watching { + // Reach here if canceled, + // so the connection is already invalid + mc.cleanup() + return nil + } + if ctx.Done() == nil { + return nil + } + + mc.watching = true + select { + default: + case <-ctx.Done(): + return ctx.Err() + } + if mc.watcher == nil { + return nil + } + + mc.watcher <- ctx + + return nil +} + +func (mc *mysqlConn) startWatcher() { + watcher := make(chan mysqlContext, 1) + mc.watcher = watcher + finished := make(chan struct{}) + mc.finished = finished + go func() { + for { + var ctx mysqlContext + select { + case ctx = <-watcher: + case <-mc.closech: + return + } + + select { + case <-ctx.Done(): + mc.cancel(ctx.Err()) + case <-finished: + case <-mc.closech: + return + } + } + }() +} + +func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { + nv.Value, err = converter{}.ConvertValue(nv.Value) + return +} + +// ResetSession implements driver.SessionResetter. +// (From Go 1.10) +func (mc *mysqlConn) ResetSession(ctx context.Context) error { + if mc.closed.IsSet() { + return driver.ErrBadConn + } + return nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/const.go b/vendor/github.com/go-sql-driver/mysql/const.go index 88cfff3fd8..b1e6b85efc 100644 --- a/vendor/github.com/go-sql-driver/mysql/const.go +++ b/vendor/github.com/go-sql-driver/mysql/const.go @@ -9,7 +9,9 @@ package mysql const ( - minProtocolVersion byte = 10 + defaultAuthPlugin = "mysql_native_password" + defaultMaxAllowedPacket = 4 << 20 // 4 MiB + minProtocolVersion = 10 maxPacketSize = 1<<24 - 1 timeFormat = "2006-01-02 15:04:05.999999" ) @@ -18,10 +20,11 @@ const ( // http://dev.mysql.com/doc/internals/en/client-server-protocol.html const ( - iOK byte = 0x00 - iLocalInFile byte = 0xfb - iEOF byte = 0xfe - iERR byte = 0xff + iOK byte = 0x00 + iAuthMoreData byte = 0x01 + iLocalInFile byte = 0xfb + iEOF byte = 0xfe + iERR byte = 0xff ) // https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags @@ -87,8 +90,10 @@ const ( ) // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType +type fieldType byte + const ( - fieldTypeDecimal byte = iota + fieldTypeDecimal fieldType = iota fieldTypeTiny fieldTypeShort fieldTypeLong @@ -107,7 +112,7 @@ const ( fieldTypeBit ) const ( - fieldTypeJSON byte = iota + 0xf5 + fieldTypeJSON fieldType = iota + 0xf5 fieldTypeNewDecimal fieldTypeEnum fieldTypeSet @@ -161,3 +166,9 @@ const ( statusInTransReadonly statusSessionStateChanged ) + +const ( + cachingSha2PasswordRequestPublicKey = 2 + cachingSha2PasswordFastAuthSuccess = 3 + cachingSha2PasswordPerformFullAuthentication = 4 +) diff --git a/vendor/github.com/go-sql-driver/mysql/driver.go b/vendor/github.com/go-sql-driver/mysql/driver.go index f5ee4728e5..1a75a16ecf 100644 --- a/vendor/github.com/go-sql-driver/mysql/driver.go +++ b/vendor/github.com/go-sql-driver/mysql/driver.go @@ -4,7 +4,7 @@ // License, v. 2.0. If a copy of the MPL was not distributed with this file, // You can obtain one at http://mozilla.org/MPL/2.0/. -// Package mysql provides a MySQL driver for Go's database/sql package +// Package mysql provides a MySQL driver for Go's database/sql package. // // The driver should be used via the database/sql package: // @@ -20,8 +20,14 @@ import ( "database/sql" "database/sql/driver" "net" + "sync" ) +// watcher interface is used for context support (From Go 1.8) +type watcher interface { + startWatcher() +} + // MySQLDriver is exported to make the driver directly accessible. // In general the driver is used via the database/sql package. type MySQLDriver struct{} @@ -30,12 +36,17 @@ type MySQLDriver struct{} // Custom dial functions must be registered with RegisterDial type DialFunc func(addr string) (net.Conn, error) -var dials map[string]DialFunc +var ( + dialsLock sync.RWMutex + dials map[string]DialFunc +) // RegisterDial registers a custom dial function. It can then be used by the // network address mynet(addr), where mynet is the registered new network. // addr is passed as a parameter to the dial function. func RegisterDial(net string, dial DialFunc) { + dialsLock.Lock() + defer dialsLock.Unlock() if dials == nil { dials = make(map[string]DialFunc) } @@ -52,16 +63,19 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc := &mysqlConn{ maxAllowedPacket: maxPacketSize, maxWriteSize: maxPacketSize - 1, + closech: make(chan struct{}), } mc.cfg, err = ParseDSN(dsn) if err != nil { return nil, err } mc.parseTime = mc.cfg.ParseTime - mc.strict = mc.cfg.Strict // Connect to Server - if dial, ok := dials[mc.cfg.Net]; ok { + dialsLock.RLock() + dial, ok := dials[mc.cfg.Net] + dialsLock.RUnlock() + if ok { mc.netConn, err = dial(mc.cfg.Addr) } else { nd := net.Dialer{Timeout: mc.cfg.Timeout} @@ -81,6 +95,11 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { } } + // Call startWatcher for context support (From Go 1.8) + if s, ok := interface{}(mc).(watcher); ok { + s.startWatcher() + } + mc.buf = newBuffer(mc.netConn) // Set I/O timeouts @@ -88,20 +107,31 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { mc.writeTimeout = mc.cfg.WriteTimeout // Reading Handshake Initialization Packet - cipher, err := mc.readInitPacket() + authData, plugin, err := mc.readHandshakePacket() if err != nil { mc.cleanup() return nil, err } // Send Client Authentication Packet - if err = mc.writeAuthPacket(cipher); err != nil { + authResp, addNUL, err := mc.auth(authData, plugin) + if err != nil { + // try the default auth plugin, if using the requested plugin failed + errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) + plugin = defaultAuthPlugin + authResp, addNUL, err = mc.auth(authData, plugin) + if err != nil { + mc.cleanup() + return nil, err + } + } + if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { mc.cleanup() return nil, err } // Handle response to auth packet, switch methods if possible - if err = handleAuthResult(mc); err != nil { + if err = mc.handleAuthResult(authData, plugin); err != nil { // Authentication failed and MySQL has already closed the connection // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). // Do not send COM_QUIT, just cleanup and return the error. @@ -134,43 +164,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { return mc, nil } -func handleAuthResult(mc *mysqlConn) error { - // Read Result Packet - cipher, err := mc.readResultOK() - if err == nil { - return nil // auth successful - } - - if mc.cfg == nil { - return err // auth failed and retry not possible - } - - // Retry auth if configured to do so. - if mc.cfg.AllowOldPasswords && err == ErrOldPassword { - // Retry with old authentication method. Note: there are edge cases - // where this should work but doesn't; this is currently "wontfix": - // https://github.com/go-sql-driver/mysql/issues/184 - if err = mc.writeOldAuthPacket(cipher); err != nil { - return err - } - _, err = mc.readResultOK() - } else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword { - // Retry with clear text password for - // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html - // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html - if err = mc.writeClearAuthPacket(); err != nil { - return err - } - _, err = mc.readResultOK() - } else if mc.cfg.AllowNativePasswords && err == ErrNativePassword { - if err = mc.writeNativeAuthPacket(cipher); err != nil { - return err - } - _, err = mc.readResultOK() - } - return err -} - func init() { sql.Register("mysql", &MySQLDriver{}) } diff --git a/vendor/github.com/go-sql-driver/mysql/dsn.go b/vendor/github.com/go-sql-driver/mysql/dsn.go index ac00dcedd5..be014babe3 100644 --- a/vendor/github.com/go-sql-driver/mysql/dsn.go +++ b/vendor/github.com/go-sql-driver/mysql/dsn.go @@ -10,11 +10,13 @@ package mysql import ( "bytes" + "crypto/rsa" "crypto/tls" "errors" "fmt" "net" "net/url" + "sort" "strconv" "strings" "time" @@ -27,7 +29,9 @@ var ( errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") ) -// Config is a configuration parsed from a DSN string +// Config is a configuration parsed from a DSN string. +// If a new Config is created instead of being parsed from a DSN string, +// the NewConfig function should be used, which sets default values. type Config struct { User string // Username Passwd string // Password (requires User) @@ -38,6 +42,8 @@ type Config struct { Collation string // Connection collation Loc *time.Location // Location for time.Time values MaxAllowedPacket int // Max packet size allowed + ServerPubKey string // Server public key name + pubKey *rsa.PublicKey // Server public key TLSConfig string // TLS configuration name tls *tls.Config // TLS configuration Timeout time.Duration // Dial timeout @@ -53,7 +59,54 @@ type Config struct { InterpolateParams bool // Interpolate placeholders into query string MultiStatements bool // Allow multiple statements in one query ParseTime bool // Parse time values to time.Time - Strict bool // Return warnings as errors + RejectReadOnly bool // Reject read-only connections +} + +// NewConfig creates a new Config and sets default values. +func NewConfig() *Config { + return &Config{ + Collation: defaultCollation, + Loc: time.UTC, + MaxAllowedPacket: defaultMaxAllowedPacket, + AllowNativePasswords: true, + } +} + +func (cfg *Config) normalize() error { + if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { + return errInvalidDSNUnsafeCollation + } + + // Set default network if empty + if cfg.Net == "" { + cfg.Net = "tcp" + } + + // Set default address if empty + if cfg.Addr == "" { + switch cfg.Net { + case "tcp": + cfg.Addr = "127.0.0.1:3306" + case "unix": + cfg.Addr = "/tmp/mysql.sock" + default: + return errors.New("default addr for network '" + cfg.Net + "' unknown") + } + + } else if cfg.Net == "tcp" { + cfg.Addr = ensureHavePort(cfg.Addr) + } + + if cfg.tls != nil { + if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { + host, _, err := net.SplitHostPort(cfg.Addr) + if err == nil { + cfg.tls.ServerName = host + } + } + } + + return nil } // FormatDSN formats the given Config into a DSN string which can be passed to @@ -102,12 +155,12 @@ func (cfg *Config) FormatDSN() string { } } - if cfg.AllowNativePasswords { + if !cfg.AllowNativePasswords { if hasParam { - buf.WriteString("&allowNativePasswords=true") + buf.WriteString("&allowNativePasswords=false") } else { hasParam = true - buf.WriteString("?allowNativePasswords=true") + buf.WriteString("?allowNativePasswords=false") } } @@ -195,15 +248,25 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(cfg.ReadTimeout.String()) } - if cfg.Strict { + if cfg.RejectReadOnly { if hasParam { - buf.WriteString("&strict=true") + buf.WriteString("&rejectReadOnly=true") } else { hasParam = true - buf.WriteString("?strict=true") + buf.WriteString("?rejectReadOnly=true") } } + if len(cfg.ServerPubKey) > 0 { + if hasParam { + buf.WriteString("&serverPubKey=") + } else { + hasParam = true + buf.WriteString("?serverPubKey=") + } + buf.WriteString(url.QueryEscape(cfg.ServerPubKey)) + } + if cfg.Timeout > 0 { if hasParam { buf.WriteString("&timeout=") @@ -234,7 +297,7 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(cfg.WriteTimeout.String()) } - if cfg.MaxAllowedPacket > 0 { + if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { if hasParam { buf.WriteString("&maxAllowedPacket=") } else { @@ -247,7 +310,12 @@ func (cfg *Config) FormatDSN() string { // other params if cfg.Params != nil { - for param, value := range cfg.Params { + var params []string + for param := range cfg.Params { + params = append(params, param) + } + sort.Strings(params) + for _, param := range params { if hasParam { buf.WriteByte('&') } else { @@ -257,7 +325,7 @@ func (cfg *Config) FormatDSN() string { buf.WriteString(param) buf.WriteByte('=') - buf.WriteString(url.QueryEscape(value)) + buf.WriteString(url.QueryEscape(cfg.Params[param])) } } @@ -267,10 +335,7 @@ func (cfg *Config) FormatDSN() string { // ParseDSN parses the DSN string to a Config func ParseDSN(dsn string) (cfg *Config, err error) { // New config with some default values - cfg = &Config{ - Loc: time.UTC, - Collation: defaultCollation, - } + cfg = NewConfig() // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] // Find the last '/' (since the password or the net addr might contain a '/') @@ -338,28 +403,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) { return nil, errInvalidDSNNoSlash } - if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { - return nil, errInvalidDSNUnsafeCollation + if err = cfg.normalize(); err != nil { + return nil, err } - - // Set default network if empty - if cfg.Net == "" { - cfg.Net = "tcp" - } - - // Set default address if empty - if cfg.Addr == "" { - switch cfg.Net { - case "tcp": - cfg.Addr = "127.0.0.1:3306" - case "unix": - cfg.Addr = "/tmp/mysql.sock" - default: - return nil, errors.New("default addr for network '" + cfg.Net + "' unknown") - } - - } - return } @@ -374,7 +420,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { // cfg params switch value := param[1]; param[0] { - // Disable INFILE whitelist / enable all files case "allowAllFiles": var isBool bool @@ -472,14 +517,32 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } - // Strict mode - case "strict": + // Reject read-only connections + case "rejectReadOnly": var isBool bool - cfg.Strict, isBool = readBool(value) + cfg.RejectReadOnly, isBool = readBool(value) if !isBool { return errors.New("invalid bool value: " + value) } + // Server public key + case "serverPubKey": + name, err := url.QueryUnescape(value) + if err != nil { + return fmt.Errorf("invalid value for server pub key name: %v", err) + } + + if pubKey := getServerPubKey(name); pubKey != nil { + cfg.ServerPubKey = name + cfg.pubKey = pubKey + } else { + return errors.New("invalid value / unknown server pub key name: " + name) + } + + // Strict mode + case "strict": + panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") + // Dial Timeout case "timeout": cfg.Timeout, err = time.ParseDuration(value) @@ -506,14 +569,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { return fmt.Errorf("invalid value for TLS config name: %v", err) } - if tlsConfig, ok := tlsConfigRegister[name]; ok { - if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { - host, _, err := net.SplitHostPort(cfg.Addr) - if err == nil { - tlsConfig.ServerName = host - } - } - + if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { cfg.TLSConfig = name cfg.tls = tlsConfig } else { @@ -546,3 +602,10 @@ func parseDSNParams(cfg *Config, params string) (err error) { return } + +func ensureHavePort(addr string) string { + if _, _, err := net.SplitHostPort(addr); err != nil { + return net.JoinHostPort(addr, "3306") + } + return addr +} diff --git a/vendor/github.com/go-sql-driver/mysql/errors.go b/vendor/github.com/go-sql-driver/mysql/errors.go index 857854e147..760782ff2f 100644 --- a/vendor/github.com/go-sql-driver/mysql/errors.go +++ b/vendor/github.com/go-sql-driver/mysql/errors.go @@ -9,10 +9,8 @@ package mysql import ( - "database/sql/driver" "errors" "fmt" - "io" "log" "os" ) @@ -31,6 +29,12 @@ var ( ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") ErrBusyBuffer = errors.New("busy buffer") + + // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. + // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn + // to trigger a resend. + // See https://github.com/go-sql-driver/mysql/pull/302 + errBadConnNoWrite = errors.New("bad connection") ) var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) @@ -59,74 +63,3 @@ type MySQLError struct { func (me *MySQLError) Error() string { return fmt.Sprintf("Error %d: %s", me.Number, me.Message) } - -// MySQLWarnings is an error type which represents a group of one or more MySQL -// warnings -type MySQLWarnings []MySQLWarning - -func (mws MySQLWarnings) Error() string { - var msg string - for i, warning := range mws { - if i > 0 { - msg += "\r\n" - } - msg += fmt.Sprintf( - "%s %s: %s", - warning.Level, - warning.Code, - warning.Message, - ) - } - return msg -} - -// MySQLWarning is an error type which represents a single MySQL warning. -// Warnings are returned in groups only. See MySQLWarnings -type MySQLWarning struct { - Level string - Code string - Message string -} - -func (mc *mysqlConn) getWarnings() (err error) { - rows, err := mc.Query("SHOW WARNINGS", nil) - if err != nil { - return - } - - var warnings = MySQLWarnings{} - var values = make([]driver.Value, 3) - - for { - err = rows.Next(values) - switch err { - case nil: - warning := MySQLWarning{} - - if raw, ok := values[0].([]byte); ok { - warning.Level = string(raw) - } else { - warning.Level = fmt.Sprintf("%s", values[0]) - } - if raw, ok := values[1].([]byte); ok { - warning.Code = string(raw) - } else { - warning.Code = fmt.Sprintf("%s", values[1]) - } - if raw, ok := values[2].([]byte); ok { - warning.Message = string(raw) - } else { - warning.Message = fmt.Sprintf("%s", values[0]) - } - - warnings = append(warnings, warning) - - case io.EOF: - return warnings - - default: - rows.Close() - return - } - } -} diff --git a/vendor/github.com/go-sql-driver/mysql/fields.go b/vendor/github.com/go-sql-driver/mysql/fields.go new file mode 100644 index 0000000000..e1e2ece4b1 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/fields.go @@ -0,0 +1,194 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +package mysql + +import ( + "database/sql" + "reflect" +) + +func (mf *mysqlField) typeDatabaseName() string { + switch mf.fieldType { + case fieldTypeBit: + return "BIT" + case fieldTypeBLOB: + if mf.charSet != collations[binaryCollation] { + return "TEXT" + } + return "BLOB" + case fieldTypeDate: + return "DATE" + case fieldTypeDateTime: + return "DATETIME" + case fieldTypeDecimal: + return "DECIMAL" + case fieldTypeDouble: + return "DOUBLE" + case fieldTypeEnum: + return "ENUM" + case fieldTypeFloat: + return "FLOAT" + case fieldTypeGeometry: + return "GEOMETRY" + case fieldTypeInt24: + return "MEDIUMINT" + case fieldTypeJSON: + return "JSON" + case fieldTypeLong: + return "INT" + case fieldTypeLongBLOB: + if mf.charSet != collations[binaryCollation] { + return "LONGTEXT" + } + return "LONGBLOB" + case fieldTypeLongLong: + return "BIGINT" + case fieldTypeMediumBLOB: + if mf.charSet != collations[binaryCollation] { + return "MEDIUMTEXT" + } + return "MEDIUMBLOB" + case fieldTypeNewDate: + return "DATE" + case fieldTypeNewDecimal: + return "DECIMAL" + case fieldTypeNULL: + return "NULL" + case fieldTypeSet: + return "SET" + case fieldTypeShort: + return "SMALLINT" + case fieldTypeString: + if mf.charSet == collations[binaryCollation] { + return "BINARY" + } + return "CHAR" + case fieldTypeTime: + return "TIME" + case fieldTypeTimestamp: + return "TIMESTAMP" + case fieldTypeTiny: + return "TINYINT" + case fieldTypeTinyBLOB: + if mf.charSet != collations[binaryCollation] { + return "TINYTEXT" + } + return "TINYBLOB" + case fieldTypeVarChar: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeVarString: + if mf.charSet == collations[binaryCollation] { + return "VARBINARY" + } + return "VARCHAR" + case fieldTypeYear: + return "YEAR" + default: + return "" + } +} + +var ( + scanTypeFloat32 = reflect.TypeOf(float32(0)) + scanTypeFloat64 = reflect.TypeOf(float64(0)) + scanTypeInt8 = reflect.TypeOf(int8(0)) + scanTypeInt16 = reflect.TypeOf(int16(0)) + scanTypeInt32 = reflect.TypeOf(int32(0)) + scanTypeInt64 = reflect.TypeOf(int64(0)) + scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) + scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) + scanTypeNullTime = reflect.TypeOf(NullTime{}) + scanTypeUint8 = reflect.TypeOf(uint8(0)) + scanTypeUint16 = reflect.TypeOf(uint16(0)) + scanTypeUint32 = reflect.TypeOf(uint32(0)) + scanTypeUint64 = reflect.TypeOf(uint64(0)) + scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) + scanTypeUnknown = reflect.TypeOf(new(interface{})) +) + +type mysqlField struct { + tableName string + name string + length uint32 + flags fieldFlag + fieldType fieldType + decimals byte + charSet uint8 +} + +func (mf *mysqlField) scanType() reflect.Type { + switch mf.fieldType { + case fieldTypeTiny: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint8 + } + return scanTypeInt8 + } + return scanTypeNullInt + + case fieldTypeShort, fieldTypeYear: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint16 + } + return scanTypeInt16 + } + return scanTypeNullInt + + case fieldTypeInt24, fieldTypeLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint32 + } + return scanTypeInt32 + } + return scanTypeNullInt + + case fieldTypeLongLong: + if mf.flags&flagNotNULL != 0 { + if mf.flags&flagUnsigned != 0 { + return scanTypeUint64 + } + return scanTypeInt64 + } + return scanTypeNullInt + + case fieldTypeFloat: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat32 + } + return scanTypeNullFloat + + case fieldTypeDouble: + if mf.flags&flagNotNULL != 0 { + return scanTypeFloat64 + } + return scanTypeNullFloat + + case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, + fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, + fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, + fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, + fieldTypeTime: + return scanTypeRawBytes + + case fieldTypeDate, fieldTypeNewDate, + fieldTypeTimestamp, fieldTypeDateTime: + // NullTime is always returned for more consistent behavior as it can + // handle both cases of parseTime regardless if the field is nullable. + return scanTypeNullTime + + default: + return scanTypeUnknown + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/infile.go b/vendor/github.com/go-sql-driver/mysql/infile.go index 547357cfa7..273cb0ba50 100644 --- a/vendor/github.com/go-sql-driver/mysql/infile.go +++ b/vendor/github.com/go-sql-driver/mysql/infile.go @@ -147,7 +147,8 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { } // send content packets - if err == nil { + // if packetSize == 0, the Reader contains no data + if err == nil && packetSize > 0 { data := make([]byte, 4+packetSize) var n int for err == nil { @@ -173,8 +174,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { // read OK packet if err == nil { - _, err = mc.readResultOK() - return err + return mc.readResultOK() } mc.readPacket() diff --git a/vendor/github.com/go-sql-driver/mysql/packets.go b/vendor/github.com/go-sql-driver/mysql/packets.go index 9160beb090..d873a97b2f 100644 --- a/vendor/github.com/go-sql-driver/mysql/packets.go +++ b/vendor/github.com/go-sql-driver/mysql/packets.go @@ -25,26 +25,23 @@ import ( // Read packet to buffer 'data' func (mc *mysqlConn) readPacket() ([]byte, error) { - var payload []byte + var prevData []byte for { - // Read packet header + // read packet header data, err := mc.buf.readNext(4) if err != nil { + if cerr := mc.canceled.Value(); cerr != nil { + return nil, cerr + } errLog.Print(err) mc.Close() - return nil, driver.ErrBadConn + return nil, ErrInvalidConn } - // Packet Length [24 bit] + // packet length [24 bit] pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) - if pktLen < 1 { - errLog.Print(ErrMalformPkt) - mc.Close() - return nil, driver.ErrBadConn - } - - // Check Packet Sync [8 bit] + // check packet sync [8 bit] if data[3] != mc.sequence { if data[3] > mc.sequence { return nil, ErrPktSyncMul @@ -53,26 +50,41 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { } mc.sequence++ - // Read packet body [pktLen bytes] + // packets with length 0 terminate a previous packet which is a + // multiple of (2^24)−1 bytes long + if pktLen == 0 { + // there was no previous packet + if prevData == nil { + errLog.Print(ErrMalformPkt) + mc.Close() + return nil, ErrInvalidConn + } + + return prevData, nil + } + + // read packet body [pktLen bytes] data, err = mc.buf.readNext(pktLen) if err != nil { + if cerr := mc.canceled.Value(); cerr != nil { + return nil, cerr + } errLog.Print(err) mc.Close() - return nil, driver.ErrBadConn + return nil, ErrInvalidConn } - isLastPacket := (pktLen < maxPacketSize) + // return data if this was the last packet + if pktLen < maxPacketSize { + // zero allocations for non-split packets + if prevData == nil { + return data, nil + } - // Zero allocations for non-splitting packets - if isLastPacket && payload == nil { - return data, nil + return append(prevData, data...), nil } - payload = append(payload, data...) - - if isLastPacket { - return payload, nil - } + prevData = append(prevData, data...) } } @@ -119,33 +131,47 @@ func (mc *mysqlConn) writePacket(data []byte) error { // Handle error if err == nil { // n != len(data) + mc.cleanup() errLog.Print(ErrMalformPkt) } else { + if cerr := mc.canceled.Value(); cerr != nil { + return cerr + } + if n == 0 && pktLen == len(data)-4 { + // only for the first loop iteration when nothing was written yet + return errBadConnNoWrite + } + mc.cleanup() errLog.Print(err) } - return driver.ErrBadConn + return ErrInvalidConn } } /****************************************************************************** -* Initialisation Process * +* Initialization Process * ******************************************************************************/ // Handshake Initialization Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake -func (mc *mysqlConn) readInitPacket() ([]byte, error) { +func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { data, err := mc.readPacket() if err != nil { - return nil, err + // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since + // in connection initialization we don't risk retrying non-idempotent actions. + if err == ErrInvalidConn { + return nil, "", driver.ErrBadConn + } + return nil, "", err } if data[0] == iERR { - return nil, mc.handleErrorPacket(data) + return nil, "", mc.handleErrorPacket(data) } // protocol version [1 byte] if data[0] < minProtocolVersion { - return nil, fmt.Errorf( + return nil, "", fmt.Errorf( "unsupported protocol version %d. Version %d or higher is required", data[0], minProtocolVersion, @@ -157,7 +183,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 // first part of the password cipher [8 bytes] - cipher := data[pos : pos+8] + authData := data[pos : pos+8] // (filler) always 0x00 [1 byte] pos += 8 + 1 @@ -165,13 +191,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // capability flags (lower 2 bytes) [2 bytes] mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) if mc.flags&clientProtocol41 == 0 { - return nil, ErrOldProtocol + return nil, "", ErrOldProtocol } if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { - return nil, ErrNoTLS + return nil, "", ErrNoTLS } pos += 2 + plugin := "" if len(data) > pos { // character set [1 byte] // status flags [2 bytes] @@ -192,32 +219,34 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { // // The official Python library uses the fixed length 12 // which seems to work but technically could have a hidden bug. - cipher = append(cipher, data[pos:pos+12]...) + authData = append(authData, data[pos:pos+12]...) + pos += 13 - // TODO: Verify string termination // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) // \NUL otherwise - // - //if data[len(data)-1] == 0 { - // return - //} - //return ErrMalformPkt + if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { + plugin = string(data[pos : pos+end]) + } else { + plugin = string(data[pos:]) + } // make a memory safe copy of the cipher slice var b [20]byte - copy(b[:], cipher) - return b[:], nil + copy(b[:], authData) + return b[:], plugin, nil } + plugin = defaultAuthPlugin + // make a memory safe copy of the cipher slice var b [8]byte - copy(b[:], cipher) - return b[:], nil + copy(b[:], authData) + return b[:], plugin, nil } // Client Authentication Packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse -func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { +func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error { // Adjust client flags based on server support clientFlags := clientProtocol41 | clientSecureConn | @@ -241,10 +270,19 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { clientFlags |= clientMultiStatements } - // User Password - scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) + // encode length of the auth plugin data + var authRespLEIBuf [9]byte + authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) + if len(authRespLEI) > 1 { + // if the length can not be written in 1 byte, it must be written as a + // length encoded integer + clientFlags |= clientPluginAuthLenEncClientData + } - pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 + pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 + if addNUL { + pktLen++ + } // To specify a db name if n := len(mc.cfg.DBName); n > 0 { @@ -255,9 +293,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { // Calculate packet length and get buffer with that size data := mc.buf.takeSmallBuffer(pktLen + 4) if data == nil { - // can not take the buffer. Something must be wrong with the connection + // cannot take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // ClientFlags [32 bit] @@ -312,9 +350,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { data[pos] = 0x00 pos++ - // ScrambleBuffer [length encoded integer] - data[pos] = byte(len(scrambleBuff)) - pos += 1 + copy(data[pos+1:], scrambleBuff) + // Auth Data [length encoded integer] + pos += copy(data[pos:], authRespLEI) + pos += copy(data[pos:], authResp) + if addNUL { + data[pos] = 0x00 + pos++ + } // Databasename [null terminated string] if len(mc.cfg.DBName) > 0 { @@ -323,72 +365,32 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { pos++ } - // Assume native client during response - pos += copy(data[pos:], "mysql_native_password") + pos += copy(data[pos:], plugin) data[pos] = 0x00 // Send Auth packet return mc.writePacket(data) } -// Client old authentication packet // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { - // User password - scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) - - // Calculate the packet length and add a tailing 0 - pktLen := len(scrambleBuff) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) +func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error { + pktLen := 4 + len(authData) + if addNUL { + pktLen++ + } + data := mc.buf.takeSmallBuffer(pktLen) if data == nil { - // can not take the buffer. Something must be wrong with the connection + // cannot take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } - // Add the scrambled password [null terminated string] - copy(data[4:], scrambleBuff) - data[4+pktLen-1] = 0x00 - - return mc.writePacket(data) -} - -// Client clear text authentication packet -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeClearAuthPacket() error { - // Calculate the packet length and add a tailing 0 - pktLen := len(mc.cfg.Passwd) + 1 - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + // Add the auth data [EOF] + copy(data[4:], authData) + if addNUL { + data[pktLen-1] = 0x00 } - // Add the clear password [null terminated string] - copy(data[4:], mc.cfg.Passwd) - data[4+pktLen-1] = 0x00 - - return mc.writePacket(data) -} - -// Native password authentication method -// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse -func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { - scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) - - // Calculate the packet length and add a tailing 0 - pktLen := len(scrambleBuff) - data := mc.buf.takeSmallBuffer(4 + pktLen) - if data == nil { - // can not take the buffer. Something must be wrong with the connection - errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn - } - - // Add the scramble - copy(data[4:], scrambleBuff) - return mc.writePacket(data) } @@ -402,9 +404,9 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { data := mc.buf.takeSmallBuffer(4 + 1) if data == nil { - // can not take the buffer. Something must be wrong with the connection + // cannot take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // Add command byte @@ -421,9 +423,9 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { pktLen := 1 + len(arg) data := mc.buf.takeBuffer(pktLen + 4) if data == nil { - // can not take the buffer. Something must be wrong with the connection + // cannot take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // Add command byte @@ -442,9 +444,9 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { data := mc.buf.takeSmallBuffer(4 + 1 + 4) if data == nil { - // can not take the buffer. Something must be wrong with the connection + // cannot take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // Add command byte @@ -464,43 +466,50 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { * Result Packets * ******************************************************************************/ -// Returns error if Packet is not an 'Result OK'-Packet -func (mc *mysqlConn) readResultOK() ([]byte, error) { +func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { data, err := mc.readPacket() - if err == nil { - // packet indicator - switch data[0] { - - case iOK: - return nil, mc.handleOkPacket(data) - - case iEOF: - if len(data) > 1 { - pluginEndIndex := bytes.IndexByte(data, 0x00) - plugin := string(data[1:pluginEndIndex]) - cipher := data[pluginEndIndex+1 : len(data)-1] - - if plugin == "mysql_old_password" { - // using old_passwords - return cipher, ErrOldPassword - } else if plugin == "mysql_clear_password" { - // using clear text password - return cipher, ErrCleartextPassword - } else if plugin == "mysql_native_password" { - // using mysql default authentication method - return cipher, ErrNativePassword - } else { - return cipher, ErrUnknownPlugin - } - } else { - return nil, ErrOldPassword - } - - default: // Error otherwise - return nil, mc.handleErrorPacket(data) - } + if err != nil { + return nil, "", err } - return nil, err + + // packet indicator + switch data[0] { + + case iOK: + return nil, "", mc.handleOkPacket(data) + + case iAuthMoreData: + return data[1:], "", err + + case iEOF: + if len(data) < 1 { + // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest + return nil, "mysql_old_password", nil + } + pluginEndIndex := bytes.IndexByte(data, 0x00) + if pluginEndIndex < 0 { + return nil, "", ErrMalformPkt + } + plugin := string(data[1:pluginEndIndex]) + authData := data[pluginEndIndex+1:] + return authData, plugin, nil + + default: // Error otherwise + return nil, "", mc.handleErrorPacket(data) + } +} + +// Returns error if Packet is not an 'Result OK'-Packet +func (mc *mysqlConn) readResultOK() error { + data, err := mc.readPacket() + if err != nil { + return err + } + + if data[0] == iOK { + return mc.handleOkPacket(data) + } + return mc.handleErrorPacket(data) } // Result Set Header Packet @@ -543,6 +552,22 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { // Error Number [16 bit uint] errno := binary.LittleEndian.Uint16(data[1:3]) + // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION + // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) + if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { + // Oops; we are connected to a read-only connection, and won't be able + // to issue any write statements. Since RejectReadOnly is configured, + // we throw away this connection hoping this one would have write + // permission. This is specifically for a possible race condition + // during failover (e.g. on AWS Aurora). See README.md for more. + // + // We explicitly close the connection before returning + // driver.ErrBadConn to ensure that `database/sql` purges this + // connection and initiates a new one for next statement next time. + mc.Close() + return driver.ErrBadConn + } + pos := 3 // SQL State [optional: # + 5bytes string] @@ -577,19 +602,12 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { // server_status [2 bytes] mc.status = readStatus(data[1+n+m : 1+n+m+2]) - if err := mc.discardResults(); err != nil { - return err - } - - // warning count [2 bytes] - if !mc.strict { + if mc.status&statusMoreResultsExists != 0 { return nil } - pos := 1 + n + m + 2 - if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { - return mc.getWarnings() - } + // warning count [2 bytes] + return nil } @@ -661,14 +679,21 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { if err != nil { return nil, err } + pos += n // Filler [uint8] + pos++ + // Charset [charset, collation uint8] + columns[i].charSet = data[pos] + pos += 2 + // Length [uint32] - pos += n + 1 + 2 + 4 + columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) + pos += 4 // Field type [uint8] - columns[i].fieldType = data[pos] + columns[i].fieldType = fieldType(data[pos]) pos++ // Flags [uint16] @@ -691,6 +716,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { func (rows *textRows) readRow(dest []driver.Value) error { mc := rows.mc + if rows.rs.done { + return io.EOF + } + data, err := mc.readPacket() if err != nil { return err @@ -700,10 +729,10 @@ func (rows *textRows) readRow(dest []driver.Value) error { if data[0] == iEOF && len(data) == 5 { // server_status [2 bytes] rows.mc.status = readStatus(data[3:]) - if err := rows.mc.discardResults(); err != nil { - return err + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil } - rows.mc = nil return io.EOF } if data[0] == iERR { @@ -725,7 +754,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { if !mc.parseTime { continue } else { - switch rows.columns[i].fieldType { + switch rows.rs.columns[i].fieldType { case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeDate, fieldTypeNewDate: dest[i], err = parseDateTime( @@ -797,14 +826,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { // Reserved [8 bit] // Warning count [16 bit uint] - if !stmt.mc.strict { - return columnCount, nil - } - // Check for warnings count > 0, only available in MySQL > 4.1 - if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { - return columnCount, stmt.mc.getWarnings() - } return columnCount, nil } return 0, err @@ -821,7 +843,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { // 2 bytes paramID const dataOffset = 1 + 4 + 2 - // Can not use the write buffer since + // Cannot use the write buffer since // a) the buffer is too small // b) it is in use data := make([]byte, 4+1+4+2+len(arg)) @@ -876,6 +898,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { const minPktLen = 4 + 1 + 4 + 1 + 4 mc := stmt.mc + // Determine threshould dynamically to avoid packet size shortage. + longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) + if longDataSize < 64 { + longDataSize = 64 + } + // Reset packet-sequence mc.sequence = 0 @@ -887,9 +915,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { data = mc.buf.takeCompleteBuffer() } if data == nil { - // can not take the buffer. Something must be wrong with the connection + // cannot take the buffer. Something must be wrong with the connection errLog.Print(ErrBusyBuffer) - return driver.ErrBadConn + return errBadConnNoWrite } // command [1 byte] @@ -948,7 +976,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // build NULL-bitmap if arg == nil { nullMask[i/8] |= 1 << (uint(i) & 7) - paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 continue } @@ -956,7 +984,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // cache types and values switch v := arg.(type) { case int64: - paramTypes[i+i] = fieldTypeLongLong + paramTypes[i+i] = byte(fieldTypeLongLong) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { @@ -972,7 +1000,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case float64: - paramTypes[i+i] = fieldTypeDouble + paramTypes[i+i] = byte(fieldTypeDouble) paramTypes[i+i+1] = 0x00 if cap(paramValues)-len(paramValues)-8 >= 0 { @@ -988,7 +1016,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case bool: - paramTypes[i+i] = fieldTypeTiny + paramTypes[i+i] = byte(fieldTypeTiny) paramTypes[i+i+1] = 0x00 if v { @@ -1000,10 +1028,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { case []byte: // Common case (non-nil value) first if v != nil { - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) @@ -1018,14 +1046,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { // Handle []byte(nil) as a NULL value nullMask[i/8] |= 1 << (uint(i) & 7) - paramTypes[i+i] = fieldTypeNULL + paramTypes[i+i] = byte(fieldTypeNULL) paramTypes[i+i+1] = 0x00 case string: - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { + if len(v) < longDataSize { paramValues = appendLengthEncodedInteger(paramValues, uint64(len(v)), ) @@ -1037,23 +1065,25 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { } case time.Time: - paramTypes[i+i] = fieldTypeString + paramTypes[i+i] = byte(fieldTypeString) paramTypes[i+i+1] = 0x00 - var val []byte + var a [64]byte + var b = a[:0] + if v.IsZero() { - val = []byte("0000-00-00") + b = append(b, "0000-00-00"...) } else { - val = []byte(v.In(mc.cfg.Loc).Format(timeFormat)) + b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) } paramValues = appendLengthEncodedInteger(paramValues, - uint64(len(val)), + uint64(len(b)), ) - paramValues = append(paramValues, val...) + paramValues = append(paramValues, b...) default: - return fmt.Errorf("can not convert type: %T", arg) + return fmt.Errorf("cannot convert type: %T", arg) } } @@ -1086,8 +1116,6 @@ func (mc *mysqlConn) discardResults() error { if err := mc.readUntilEOF(); err != nil { return err } - } else { - mc.status &^= statusMoreResultsExists } } return nil @@ -1105,16 +1133,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // EOF Packet if data[0] == iEOF && len(data) == 5 { rows.mc.status = readStatus(data[3:]) - if err := rows.mc.discardResults(); err != nil { - return err + rows.rs.done = true + if !rows.HasNextResultSet() { + rows.mc = nil } - rows.mc = nil return io.EOF } + mc := rows.mc rows.mc = nil // Error otherwise - return rows.mc.handleErrorPacket(data) + return mc.handleErrorPacket(data) } // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] @@ -1130,14 +1159,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { } // Convert to byte-coded string - switch rows.columns[i].fieldType { + switch rows.rs.columns[i].fieldType { case fieldTypeNULL: dest[i] = nil continue // Numeric Types case fieldTypeTiny: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(data[pos]) } else { dest[i] = int64(int8(data[pos])) @@ -1146,7 +1175,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeShort, fieldTypeYear: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) } else { dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) @@ -1155,7 +1184,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeInt24, fieldTypeLong: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) } else { dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) @@ -1164,7 +1193,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeLongLong: - if rows.columns[i].flags&flagUnsigned != 0 { + if rows.rs.columns[i].flags&flagUnsigned != 0 { val := binary.LittleEndian.Uint64(data[pos : pos+8]) if val > math.MaxInt64 { dest[i] = uint64ToString(val) @@ -1178,7 +1207,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { continue case fieldTypeFloat: - dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) + dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) pos += 4 continue @@ -1218,10 +1247,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { case isNull: dest[i] = nil continue - case rows.columns[i].fieldType == fieldTypeTime: + case rows.rs.columns[i].fieldType == fieldTypeTime: // database/sql does not support an equivalent to TIME, return a string var dstlen uint8 - switch decimals := rows.columns[i].decimals; decimals { + switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 8 case 1, 2, 3, 4, 5, 6: @@ -1229,7 +1258,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { default: return fmt.Errorf( "protocol error, illegal decimals value %d", - rows.columns[i].decimals, + rows.rs.columns[i].decimals, ) } dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) @@ -1237,10 +1266,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) default: var dstlen uint8 - if rows.columns[i].fieldType == fieldTypeDate { + if rows.rs.columns[i].fieldType == fieldTypeDate { dstlen = 10 } else { - switch decimals := rows.columns[i].decimals; decimals { + switch decimals := rows.rs.columns[i].decimals; decimals { case 0x00, 0x1f: dstlen = 19 case 1, 2, 3, 4, 5, 6: @@ -1248,7 +1277,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { default: return fmt.Errorf( "protocol error, illegal decimals value %d", - rows.columns[i].decimals, + rows.rs.columns[i].decimals, ) } } @@ -1264,7 +1293,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { // Please report if this happens! default: - return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) + return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) } } diff --git a/vendor/github.com/go-sql-driver/mysql/rows.go b/vendor/github.com/go-sql-driver/mysql/rows.go index c08255eeec..d3b1e28221 100644 --- a/vendor/github.com/go-sql-driver/mysql/rows.go +++ b/vendor/github.com/go-sql-driver/mysql/rows.go @@ -11,19 +11,20 @@ package mysql import ( "database/sql/driver" "io" + "math" + "reflect" ) -type mysqlField struct { - tableName string - name string - flags fieldFlag - fieldType byte - decimals byte +type resultSet struct { + columns []mysqlField + columnNames []string + done bool } type mysqlRows struct { - mc *mysqlConn - columns []mysqlField + mc *mysqlConn + rs resultSet + finish func() } type binaryRows struct { @@ -34,37 +35,86 @@ type textRows struct { mysqlRows } -type emptyRows struct{} - func (rows *mysqlRows) Columns() []string { - columns := make([]string, len(rows.columns)) + if rows.rs.columnNames != nil { + return rows.rs.columnNames + } + + columns := make([]string, len(rows.rs.columns)) if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { for i := range columns { - if tableName := rows.columns[i].tableName; len(tableName) > 0 { - columns[i] = tableName + "." + rows.columns[i].name + if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 { + columns[i] = tableName + "." + rows.rs.columns[i].name } else { - columns[i] = rows.columns[i].name + columns[i] = rows.rs.columns[i].name } } } else { for i := range columns { - columns[i] = rows.columns[i].name + columns[i] = rows.rs.columns[i].name } } + + rows.rs.columnNames = columns return columns } -func (rows *mysqlRows) Close() error { +func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { + return rows.rs.columns[i].typeDatabaseName() +} + +// func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { +// return int64(rows.rs.columns[i].length), true +// } + +func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { + return rows.rs.columns[i].flags&flagNotNULL == 0, true +} + +func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { + column := rows.rs.columns[i] + decimals := int64(column.decimals) + + switch column.fieldType { + case fieldTypeDecimal, fieldTypeNewDecimal: + if decimals > 0 { + return int64(column.length) - 2, decimals, true + } + return int64(column.length) - 1, decimals, true + case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime: + return decimals, decimals, true + case fieldTypeFloat, fieldTypeDouble: + if decimals == 0x1f { + return math.MaxInt64, math.MaxInt64, true + } + return math.MaxInt64, decimals, true + } + + return 0, 0, false +} + +func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { + return rows.rs.columns[i].scanType() +} + +func (rows *mysqlRows) Close() (err error) { + if f := rows.finish; f != nil { + f() + rows.finish = nil + } + mc := rows.mc if mc == nil { return nil } - if mc.netConn == nil { - return ErrInvalidConn + if err := mc.error(); err != nil { + return err } // Remove unread packets from stream - err := mc.readUntilEOF() + if !rows.rs.done { + err = mc.readUntilEOF() + } if err == nil { if err = mc.discardResults(); err != nil { return err @@ -75,10 +125,66 @@ func (rows *mysqlRows) Close() error { return err } +func (rows *mysqlRows) HasNextResultSet() (b bool) { + if rows.mc == nil { + return false + } + return rows.mc.status&statusMoreResultsExists != 0 +} + +func (rows *mysqlRows) nextResultSet() (int, error) { + if rows.mc == nil { + return 0, io.EOF + } + if err := rows.mc.error(); err != nil { + return 0, err + } + + // Remove unread packets from stream + if !rows.rs.done { + if err := rows.mc.readUntilEOF(); err != nil { + return 0, err + } + rows.rs.done = true + } + + if !rows.HasNextResultSet() { + rows.mc = nil + return 0, io.EOF + } + rows.rs = resultSet{} + return rows.mc.readResultSetHeaderPacket() +} + +func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { + for { + resLen, err := rows.nextResultSet() + if err != nil { + return 0, err + } + + if resLen > 0 { + return resLen, nil + } + + rows.rs.done = true + } +} + +func (rows *binaryRows) NextResultSet() error { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err +} + func (rows *binaryRows) Next(dest []driver.Value) error { if mc := rows.mc; mc != nil { - if mc.netConn == nil { - return ErrInvalidConn + if err := mc.error(); err != nil { + return err } // Fetch next row from stream @@ -87,10 +193,20 @@ func (rows *binaryRows) Next(dest []driver.Value) error { return io.EOF } +func (rows *textRows) NextResultSet() (err error) { + resLen, err := rows.nextNotEmptyResultSet() + if err != nil { + return err + } + + rows.rs.columns, err = rows.mc.readColumns(resLen) + return err +} + func (rows *textRows) Next(dest []driver.Value) error { if mc := rows.mc; mc != nil { - if mc.netConn == nil { - return ErrInvalidConn + if err := mc.error(); err != nil { + return err } // Fetch next row from stream @@ -98,15 +214,3 @@ func (rows *textRows) Next(dest []driver.Value) error { } return io.EOF } - -func (rows emptyRows) Columns() []string { - return nil -} - -func (rows emptyRows) Close() error { - return nil -} - -func (rows emptyRows) Next(dest []driver.Value) error { - return io.EOF -} diff --git a/vendor/github.com/go-sql-driver/mysql/statement.go b/vendor/github.com/go-sql-driver/mysql/statement.go index ead9a6bf47..ce7fe4cd04 100644 --- a/vendor/github.com/go-sql-driver/mysql/statement.go +++ b/vendor/github.com/go-sql-driver/mysql/statement.go @@ -11,6 +11,7 @@ package mysql import ( "database/sql/driver" "fmt" + "io" "reflect" "strconv" ) @@ -19,12 +20,14 @@ type mysqlStmt struct { mc *mysqlConn id uint32 paramCount int - columns []mysqlField // cached from the first query } func (stmt *mysqlStmt) Close() error { - if stmt.mc == nil || stmt.mc.netConn == nil { - errLog.Print(ErrInvalidConn) + if stmt.mc == nil || stmt.mc.closed.IsSet() { + // driver.Stmt.Close can be called more than once, thus this function + // has to be idempotent. + // See also Issue #450 and golang/go#16019. + //errLog.Print(ErrInvalidConn) return driver.ErrBadConn } @@ -42,14 +45,14 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { } func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { - if stmt.mc.netConn == nil { + if stmt.mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := stmt.writeExecutePacket(args) if err != nil { - return nil, err + return nil, stmt.mc.markBadConn(err) } mc := stmt.mc @@ -59,37 +62,45 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { // Read Result resLen, err := mc.readResultSetHeaderPacket() - if err == nil { - if resLen > 0 { - // Columns - err = mc.readUntilEOF() - if err != nil { - return nil, err - } + if err != nil { + return nil, err + } - // Rows - err = mc.readUntilEOF() + if resLen > 0 { + // Columns + if err = mc.readUntilEOF(); err != nil { + return nil, err } - if err == nil { - return &mysqlResult{ - affectedRows: int64(mc.affectedRows), - insertId: int64(mc.insertId), - }, nil + + // Rows + if err := mc.readUntilEOF(); err != nil { + return nil, err } } - return nil, err + if err := mc.discardResults(); err != nil { + return nil, err + } + + return &mysqlResult{ + affectedRows: int64(mc.affectedRows), + insertId: int64(mc.insertId), + }, nil } func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { - if stmt.mc.netConn == nil { + return stmt.query(args) +} + +func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { + if stmt.mc.closed.IsSet() { errLog.Print(ErrInvalidConn) return nil, driver.ErrBadConn } // Send command err := stmt.writeExecutePacket(args) if err != nil { - return nil, err + return nil, stmt.mc.markBadConn(err) } mc := stmt.mc @@ -104,14 +115,15 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { if resLen > 0 { rows.mc = mc - // Columns - // If not cached, read them and cache them - if stmt.columns == nil { - rows.columns, err = mc.readColumns(resLen) - stmt.columns = rows.columns - } else { - rows.columns = stmt.columns - err = mc.readUntilEOF() + rows.rs.columns, err = mc.readColumns(resLen) + } else { + rows.rs.done = true + + switch err := rows.NextResultSet(); err { + case nil, io.EOF: + return rows, nil + default: + return nil, err } } @@ -120,19 +132,36 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { type converter struct{} +// ConvertValue mirrors the reference/default converter in database/sql/driver +// with _one_ exception. We support uint64 with their high bit and the default +// implementation does not. This function should be kept in sync with +// database/sql/driver defaultConverter.ConvertValue() except for that +// deliberate difference. func (c converter) ConvertValue(v interface{}) (driver.Value, error) { if driver.IsValue(v) { return v, nil } + if vr, ok := v.(driver.Valuer); ok { + sv, err := callValuerValue(vr) + if err != nil { + return nil, err + } + if !driver.IsValue(sv) { + return nil, fmt.Errorf("non-Value type %T returned from Value", sv) + } + return sv, nil + } + rv := reflect.ValueOf(v) switch rv.Kind() { case reflect.Ptr: // indirect pointers if rv.IsNil() { return nil, nil + } else { + return c.ConvertValue(rv.Elem().Interface()) } - return c.ConvertValue(rv.Elem().Interface()) case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return rv.Int(), nil case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: @@ -145,6 +174,38 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { return int64(u64), nil case reflect.Float32, reflect.Float64: return rv.Float(), nil + case reflect.Bool: + return rv.Bool(), nil + case reflect.Slice: + ek := rv.Type().Elem().Kind() + if ek == reflect.Uint8 { + return rv.Bytes(), nil + } + return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) + case reflect.String: + return rv.String(), nil } return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) } + +var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() + +// callValuerValue returns vr.Value(), with one exception: +// If vr.Value is an auto-generated method on a pointer type and the +// pointer is nil, it would panic at runtime in the panicwrap +// method. Treat it like nil instead. +// +// This is so people can implement driver.Value on value types and +// still use nil pointers to those types to mean nil/NULL, just like +// string/*string. +// +// This is an exact copy of the same-named unexported function from the +// database/sql package. +func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { + if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && + rv.IsNil() && + rv.Type().Elem().Implements(valuerReflectType) { + return nil, nil + } + return vr.Value() +} diff --git a/vendor/github.com/go-sql-driver/mysql/transaction.go b/vendor/github.com/go-sql-driver/mysql/transaction.go index 33c749b35c..417d72793b 100644 --- a/vendor/github.com/go-sql-driver/mysql/transaction.go +++ b/vendor/github.com/go-sql-driver/mysql/transaction.go @@ -13,7 +13,7 @@ type mysqlTx struct { } func (tx *mysqlTx) Commit() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.closed.IsSet() { return ErrInvalidConn } err = tx.mc.exec("COMMIT") @@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { } func (tx *mysqlTx) Rollback() (err error) { - if tx.mc == nil || tx.mc.netConn == nil { + if tx.mc == nil || tx.mc.closed.IsSet() { return ErrInvalidConn } err = tx.mc.exec("ROLLBACK") diff --git a/vendor/github.com/go-sql-driver/mysql/utils.go b/vendor/github.com/go-sql-driver/mysql/utils.go index d523b7ffdf..84d595b6ba 100644 --- a/vendor/github.com/go-sql-driver/mysql/utils.go +++ b/vendor/github.com/go-sql-driver/mysql/utils.go @@ -9,23 +9,29 @@ package mysql import ( - "crypto/sha1" "crypto/tls" "database/sql/driver" "encoding/binary" "fmt" "io" "strings" + "sync" + "sync/atomic" "time" ) +// Registry for custom tls.Configs var ( - tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs + tlsConfigLock sync.RWMutex + tlsConfigRegistry map[string]*tls.Config ) // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. // Use the key as a value in the DSN where tls=value. // +// Note: The provided tls.Config is exclusively owned by the driver after +// registering it. +// // rootCertPool := x509.NewCertPool() // pem, err := ioutil.ReadFile("/path/ca-cert.pem") // if err != nil { @@ -51,19 +57,32 @@ func RegisterTLSConfig(key string, config *tls.Config) error { return fmt.Errorf("key '%s' is reserved", key) } - if tlsConfigRegister == nil { - tlsConfigRegister = make(map[string]*tls.Config) + tlsConfigLock.Lock() + if tlsConfigRegistry == nil { + tlsConfigRegistry = make(map[string]*tls.Config) } - tlsConfigRegister[key] = config + tlsConfigRegistry[key] = config + tlsConfigLock.Unlock() return nil } // DeregisterTLSConfig removes the tls.Config associated with key. func DeregisterTLSConfig(key string) { - if tlsConfigRegister != nil { - delete(tlsConfigRegister, key) + tlsConfigLock.Lock() + if tlsConfigRegistry != nil { + delete(tlsConfigRegistry, key) } + tlsConfigLock.Unlock() +} + +func getTLSConfigClone(key string) (config *tls.Config) { + tlsConfigLock.RLock() + if v, ok := tlsConfigRegistry[key]; ok { + config = cloneTLSConfig(v) + } + tlsConfigLock.RUnlock() + return } // Returns the bool value of the input. @@ -80,119 +99,6 @@ func readBool(input string) (value bool, valid bool) { return } -/****************************************************************************** -* Authentication * -******************************************************************************/ - -// Encrypt password using 4.1+ method -func scramblePassword(scramble, password []byte) []byte { - if len(password) == 0 { - return nil - } - - // stage1Hash = SHA1(password) - crypt := sha1.New() - crypt.Write(password) - stage1 := crypt.Sum(nil) - - // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) - // inner Hash - crypt.Reset() - crypt.Write(stage1) - hash := crypt.Sum(nil) - - // outer Hash - crypt.Reset() - crypt.Write(scramble) - crypt.Write(hash) - scramble = crypt.Sum(nil) - - // token = scrambleHash XOR stage1Hash - for i := range scramble { - scramble[i] ^= stage1[i] - } - return scramble -} - -// Encrypt password using pre 4.1 (old password) method -// https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c -type myRnd struct { - seed1, seed2 uint32 -} - -const myRndMaxVal = 0x3FFFFFFF - -// Pseudo random number generator -func newMyRnd(seed1, seed2 uint32) *myRnd { - return &myRnd{ - seed1: seed1 % myRndMaxVal, - seed2: seed2 % myRndMaxVal, - } -} - -// Tested to be equivalent to MariaDB's floating point variant -// http://play.golang.org/p/QHvhd4qved -// http://play.golang.org/p/RG0q4ElWDx -func (r *myRnd) NextByte() byte { - r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal - r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal - - return byte(uint64(r.seed1) * 31 / myRndMaxVal) -} - -// Generate binary hash from byte string using insecure pre 4.1 method -func pwHash(password []byte) (result [2]uint32) { - var add uint32 = 7 - var tmp uint32 - - result[0] = 1345345333 - result[1] = 0x12345671 - - for _, c := range password { - // skip spaces and tabs in password - if c == ' ' || c == '\t' { - continue - } - - tmp = uint32(c) - result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) - result[1] += (result[1] << 8) ^ result[0] - add += tmp - } - - // Remove sign bit (1<<31)-1) - result[0] &= 0x7FFFFFFF - result[1] &= 0x7FFFFFFF - - return -} - -// Encrypt password using insecure pre 4.1 method -func scrambleOldPassword(scramble, password []byte) []byte { - if len(password) == 0 { - return nil - } - - scramble = scramble[:8] - - hashPw := pwHash(password) - hashSc := pwHash(scramble) - - r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) - - var out [8]byte - for i := range out { - out[i] = r.NextByte() + 64 - } - - mask := r.NextByte() - for i := range out { - out[i] ^= mask - } - - return out[:] -} - /****************************************************************************** * Time related utils * ******************************************************************************/ @@ -519,7 +425,7 @@ func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { // Check data length if len(b) >= n { - return b[n-int(num) : n], false, n, nil + return b[n-int(num) : n : n], false, n, nil } return nil, false, n, io.EOF } @@ -548,8 +454,8 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { if len(b) == 0 { return 0, true, 1 } - switch b[0] { + switch b[0] { // 251: NULL case 0xfb: return 0, true, 1 @@ -738,3 +644,67 @@ func escapeStringQuotes(buf []byte, v string) []byte { return buf[:pos] } + +/****************************************************************************** +* Sync utils * +******************************************************************************/ + +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://github.com/golang/go/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct{} + +// Lock is a no-op used by -copylocks checker from `go vet`. +func (*noCopy) Lock() {} + +// atomicBool is a wrapper around uint32 for usage as a boolean value with +// atomic access. +type atomicBool struct { + _noCopy noCopy + value uint32 +} + +// IsSet returns wether the current boolean value is true +func (ab *atomicBool) IsSet() bool { + return atomic.LoadUint32(&ab.value) > 0 +} + +// Set sets the value of the bool regardless of the previous value +func (ab *atomicBool) Set(value bool) { + if value { + atomic.StoreUint32(&ab.value, 1) + } else { + atomic.StoreUint32(&ab.value, 0) + } +} + +// TrySet sets the value of the bool and returns wether the value changed +func (ab *atomicBool) TrySet(value bool) bool { + if value { + return atomic.SwapUint32(&ab.value, 1) == 0 + } + return atomic.SwapUint32(&ab.value, 0) > 0 +} + +// atomicError is a wrapper for atomically accessed error values +type atomicError struct { + _noCopy noCopy + value atomic.Value +} + +// Set sets the error value regardless of the previous value. +// The value must not be nil +func (ae *atomicError) Set(value error) { + ae.value.Store(value) +} + +// Value returns the current error value +func (ae *atomicError) Value() error { + if v := ae.value.Load(); v != nil { + // this will panic if the value doesn't implement the error interface + return v.(error) + } + return nil +} diff --git a/vendor/github.com/go-sql-driver/mysql/utils_go17.go b/vendor/github.com/go-sql-driver/mysql/utils_go17.go new file mode 100644 index 0000000000..f595634567 --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/utils_go17.go @@ -0,0 +1,40 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.7 +// +build !go1.8 + +package mysql + +import "crypto/tls" + +func cloneTLSConfig(c *tls.Config) *tls.Config { + return &tls.Config{ + Rand: c.Rand, + Time: c.Time, + Certificates: c.Certificates, + NameToCertificate: c.NameToCertificate, + GetCertificate: c.GetCertificate, + RootCAs: c.RootCAs, + NextProtos: c.NextProtos, + ServerName: c.ServerName, + ClientAuth: c.ClientAuth, + ClientCAs: c.ClientCAs, + InsecureSkipVerify: c.InsecureSkipVerify, + CipherSuites: c.CipherSuites, + PreferServerCipherSuites: c.PreferServerCipherSuites, + SessionTicketsDisabled: c.SessionTicketsDisabled, + SessionTicketKey: c.SessionTicketKey, + ClientSessionCache: c.ClientSessionCache, + MinVersion: c.MinVersion, + MaxVersion: c.MaxVersion, + CurvePreferences: c.CurvePreferences, + DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, + Renegotiation: c.Renegotiation, + } +} diff --git a/vendor/github.com/go-sql-driver/mysql/utils_go18.go b/vendor/github.com/go-sql-driver/mysql/utils_go18.go new file mode 100644 index 0000000000..c35c2a6aab --- /dev/null +++ b/vendor/github.com/go-sql-driver/mysql/utils_go18.go @@ -0,0 +1,50 @@ +// Go MySQL Driver - A MySQL-Driver for Go's database/sql package +// +// Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. +// +// This Source Code Form is subject to the terms of the Mozilla Public +// License, v. 2.0. If a copy of the MPL was not distributed with this file, +// You can obtain one at http://mozilla.org/MPL/2.0/. + +// +build go1.8 + +package mysql + +import ( + "crypto/tls" + "database/sql" + "database/sql/driver" + "errors" + "fmt" +) + +func cloneTLSConfig(c *tls.Config) *tls.Config { + return c.Clone() +} + +func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { + dargs := make([]driver.Value, len(named)) + for n, param := range named { + if len(param.Name) > 0 { + // TODO: support the use of Named Parameters #561 + return nil, errors.New("mysql: driver does not support the use of Named Parameters") + } + dargs[n] = param.Value + } + return dargs, nil +} + +func mapIsolationLevel(level driver.IsolationLevel) (string, error) { + switch sql.IsolationLevel(level) { + case sql.LevelRepeatableRead: + return "REPEATABLE READ", nil + case sql.LevelReadCommitted: + return "READ COMMITTED", nil + case sql.LevelReadUncommitted: + return "READ UNCOMMITTED", nil + case sql.LevelSerializable: + return "SERIALIZABLE", nil + default: + return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) + } +}