// +build go1.9 package mssql import ( "bytes" "encoding/binary" "errors" "fmt" "reflect" "strings" "time" ) const ( jsonTag = "json" tvpTag = "tvp" skipTagValue = "-" sqlSeparator = "." ) var ( ErrorEmptyTVPTypeName = errors.New("TypeName must not be empty") ErrorTypeSlice = errors.New("TVP must be slice type") ErrorTypeSliceIsEmpty = errors.New("TVP mustn't be null value") ErrorSkip = errors.New("all fields mustn't skip") ErrorObjectName = errors.New("wrong tvp name") ErrorWrongTyping = errors.New("the number of elements in columnStr and tvpFieldIndexes do not align") ) //TVP is driver type, which allows supporting Table Valued Parameters (TVP) in SQL Server type TVP struct { //TypeName mustn't be default value TypeName string //Value must be the slice, mustn't be nil Value interface{} } func (tvp TVP) check() error { if len(tvp.TypeName) == 0 { return ErrorEmptyTVPTypeName } if !isProc(tvp.TypeName) { return ErrorEmptyTVPTypeName } if sepCount := getCountSQLSeparators(tvp.TypeName); sepCount > 1 { return ErrorObjectName } valueOf := reflect.ValueOf(tvp.Value) if valueOf.Kind() != reflect.Slice { return ErrorTypeSlice } if valueOf.IsNil() { return ErrorTypeSliceIsEmpty } if reflect.TypeOf(tvp.Value).Elem().Kind() != reflect.Struct { return ErrorTypeSlice } return nil } func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldIndexes []int) ([]byte, error) { if len(columnStr) != len(tvpFieldIndexes) { return nil, ErrorWrongTyping } preparedBuffer := make([]byte, 0, 20+(10*len(columnStr))) buf := bytes.NewBuffer(preparedBuffer) err := writeBVarChar(buf, "") if err != nil { return nil, err } writeBVarChar(buf, schema) writeBVarChar(buf, name) binary.Write(buf, binary.LittleEndian, uint16(len(columnStr))) for i, column := range columnStr { binary.Write(buf, binary.LittleEndian, uint32(column.UserType)) binary.Write(buf, binary.LittleEndian, uint16(column.Flags)) writeTypeInfo(buf, &columnStr[i].ti) writeBVarChar(buf, "") } // The returned error is always nil buf.WriteByte(_TVP_END_TOKEN) conn := new(Conn) conn.sess = new(tdsSession) conn.sess.loginAck = loginAckStruct{TDSVersion: verTDS73} stmt := &Stmt{ c: conn, } val := reflect.ValueOf(tvp.Value) for i := 0; i < val.Len(); i++ { refStr := reflect.ValueOf(val.Index(i).Interface()) buf.WriteByte(_TVP_ROW_TOKEN) for columnStrIdx, fieldIdx := range tvpFieldIndexes { field := refStr.Field(fieldIdx) tvpVal := field.Interface() valOf := reflect.ValueOf(tvpVal) elemKind := field.Kind() if elemKind == reflect.Ptr && valOf.IsNil() { switch tvpVal.(type) { case *bool, *time.Time, *int8, *int16, *int32, *int64, *float32, *float64, *int: binary.Write(buf, binary.LittleEndian, uint8(0)) continue default: binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL)) continue } } if elemKind == reflect.Slice && valOf.IsNil() { binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL)) continue } cval, err := convertInputParameter(tvpVal) if err != nil { return nil, fmt.Errorf("failed to convert tvp parameter row col: %s", err) } param, err := stmt.makeParam(cval) if err != nil { return nil, fmt.Errorf("failed to make tvp parameter row col: %s", err) } columnStr[columnStrIdx].ti.Writer(buf, param.ti, param.buffer) } } buf.WriteByte(_TVP_END_TOKEN) return buf.Bytes(), nil } func (tvp TVP) columnTypes() ([]columnStruct, []int, error) { val := reflect.ValueOf(tvp.Value) var firstRow interface{} if val.Len() != 0 { firstRow = val.Index(0).Interface() } else { firstRow = reflect.New(reflect.TypeOf(tvp.Value).Elem()).Elem().Interface() } tvpRow := reflect.TypeOf(firstRow) columnCount := tvpRow.NumField() defaultValues := make([]interface{}, 0, columnCount) tvpFieldIndexes := make([]int, 0, columnCount) for i := 0; i < columnCount; i++ { field := tvpRow.Field(i) tvpTagValue, isTvpTag := field.Tag.Lookup(tvpTag) jsonTagValue, isJsonTag := field.Tag.Lookup(jsonTag) if IsSkipField(tvpTagValue, isTvpTag, jsonTagValue, isJsonTag) { continue } tvpFieldIndexes = append(tvpFieldIndexes, i) if field.Type.Kind() == reflect.Ptr { v := reflect.New(field.Type.Elem()) defaultValues = append(defaultValues, v.Interface()) continue } defaultValues = append(defaultValues, reflect.Zero(field.Type).Interface()) } if columnCount-len(tvpFieldIndexes) == columnCount { return nil, nil, ErrorSkip } conn := new(Conn) conn.sess = new(tdsSession) conn.sess.loginAck = loginAckStruct{TDSVersion: verTDS73} stmt := &Stmt{ c: conn, } columnConfiguration := make([]columnStruct, 0, columnCount) for index, val := range defaultValues { cval, err := convertInputParameter(val) if err != nil { return nil, nil, fmt.Errorf("failed to convert tvp parameter row %d col %d: %s", index, val, err) } param, err := stmt.makeParam(cval) if err != nil { return nil, nil, err } column := columnStruct{ ti: param.ti, } switch param.ti.TypeId { case typeNVarChar, typeBigVarBin: column.ti.Size = 0 } columnConfiguration = append(columnConfiguration, column) } return columnConfiguration, tvpFieldIndexes, nil } func IsSkipField(tvpTagValue string, isTvpValue bool, jsonTagValue string, isJsonTagValue bool) bool { if !isTvpValue && !isJsonTagValue { return false } else if isTvpValue && tvpTagValue != skipTagValue { return false } else if !isTvpValue && isJsonTagValue && jsonTagValue != skipTagValue { return false } return true } func getSchemeAndName(tvpName string) (string, string, error) { if len(tvpName) == 0 { return "", "", ErrorEmptyTVPTypeName } splitVal := strings.Split(tvpName, ".") if len(splitVal) > 2 { return "", "", errors.New("wrong tvp name") } if len(splitVal) == 2 { res := make([]string, 2) for key, value := range splitVal { tmp := strings.Replace(value, "[", "", -1) tmp = strings.Replace(tmp, "]", "", -1) res[key] = tmp } return res[0], res[1], nil } tmp := strings.Replace(splitVal[0], "[", "", -1) tmp = strings.Replace(tmp, "]", "", -1) return "", tmp, nil } func getCountSQLSeparators(str string) int { return strings.Count(str, sqlSeparator) }