Black Lives Matter. Support the Equal Justice Initiative.

Source file src/database/sql/fakedb_test.go

Documentation: database/sql

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sql
     6  
     7  import (
     8  	"context"
     9  	"database/sql/driver"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"reflect"
    14  	"sort"
    15  	"strconv"
    16  	"strings"
    17  	"sync"
    18  	"testing"
    19  	"time"
    20  )
    21  
    22  // fakeDriver is a fake database that implements Go's driver.Driver
    23  // interface, just for testing.
    24  //
    25  // It speaks a query language that's semantically similar to but
    26  // syntactically different and simpler than SQL.  The syntax is as
    27  // follows:
    28  //
    29  //   WIPE
    30  //   CREATE|<tablename>|<col>=<type>,<col>=<type>,...
    31  //     where types are: "string", [u]int{8,16,32,64}, "bool"
    32  //   INSERT|<tablename>|col=val,col2=val2,col3=?
    33  //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
    34  //   SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
    35  //
    36  // Any of these can be preceded by PANIC|<method>|, to cause the
    37  // named method on fakeStmt to panic.
    38  //
    39  // Any of these can be proceeded by WAIT|<duration>|, to cause the
    40  // named method on fakeStmt to sleep for the specified duration.
    41  //
    42  // Multiple of these can be combined when separated with a semicolon.
    43  //
    44  // When opening a fakeDriver's database, it starts empty with no
    45  // tables. All tables and data are stored in memory only.
    46  type fakeDriver struct {
    47  	mu         sync.Mutex // guards 3 following fields
    48  	openCount  int        // conn opens
    49  	closeCount int        // conn closes
    50  	waitCh     chan struct{}
    51  	waitingCh  chan struct{}
    52  	dbs        map[string]*fakeDB
    53  }
    54  
    55  type fakeConnector struct {
    56  	name string
    57  
    58  	waiter func(context.Context)
    59  	closed bool
    60  }
    61  
    62  func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
    63  	conn, err := fdriver.Open(c.name)
    64  	conn.(*fakeConn).waiter = c.waiter
    65  	return conn, err
    66  }
    67  
    68  func (c *fakeConnector) Driver() driver.Driver {
    69  	return fdriver
    70  }
    71  
    72  func (c *fakeConnector) Close() error {
    73  	if c.closed {
    74  		return errors.New("fakedb: connector is closed")
    75  	}
    76  	c.closed = true
    77  	return nil
    78  }
    79  
    80  type fakeDriverCtx struct {
    81  	fakeDriver
    82  }
    83  
    84  var _ driver.DriverContext = &fakeDriverCtx{}
    85  
    86  func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
    87  	return &fakeConnector{name: name}, nil
    88  }
    89  
    90  type fakeDB struct {
    91  	name string
    92  
    93  	mu       sync.Mutex
    94  	tables   map[string]*table
    95  	badConn  bool
    96  	allowAny bool
    97  }
    98  
    99  type table struct {
   100  	mu      sync.Mutex
   101  	colname []string
   102  	coltype []string
   103  	rows    []*row
   104  }
   105  
   106  func (t *table) columnIndex(name string) int {
   107  	for n, nname := range t.colname {
   108  		if name == nname {
   109  			return n
   110  		}
   111  	}
   112  	return -1
   113  }
   114  
   115  type row struct {
   116  	cols []interface{} // must be same size as its table colname + coltype
   117  }
   118  
   119  type memToucher interface {
   120  	// touchMem reads & writes some memory, to help find data races.
   121  	touchMem()
   122  }
   123  
   124  type fakeConn struct {
   125  	db *fakeDB // where to return ourselves to
   126  
   127  	currTx *fakeTx
   128  
   129  	// Every operation writes to line to enable the race detector
   130  	// check for data races.
   131  	line int64
   132  
   133  	// Stats for tests:
   134  	mu          sync.Mutex
   135  	stmtsMade   int
   136  	stmtsClosed int
   137  	numPrepare  int
   138  
   139  	// bad connection tests; see isBad()
   140  	bad       bool
   141  	stickyBad bool
   142  
   143  	skipDirtySession bool // tests that use Conn should set this to true.
   144  
   145  	// dirtySession tests ResetSession, true if a query has executed
   146  	// until ResetSession is called.
   147  	dirtySession bool
   148  
   149  	// The waiter is called before each query. May be used in place of the "WAIT"
   150  	// directive.
   151  	waiter func(context.Context)
   152  }
   153  
   154  func (c *fakeConn) touchMem() {
   155  	c.line++
   156  }
   157  
   158  func (c *fakeConn) incrStat(v *int) {
   159  	c.mu.Lock()
   160  	*v++
   161  	c.mu.Unlock()
   162  }
   163  
   164  type fakeTx struct {
   165  	c *fakeConn
   166  }
   167  
   168  type boundCol struct {
   169  	Column      string
   170  	Placeholder string
   171  	Ordinal     int
   172  }
   173  
   174  type fakeStmt struct {
   175  	memToucher
   176  	c *fakeConn
   177  	q string // just for debugging
   178  
   179  	cmd   string
   180  	table string
   181  	panic string
   182  	wait  time.Duration
   183  
   184  	next *fakeStmt // used for returning multiple results.
   185  
   186  	closed bool
   187  
   188  	colName      []string      // used by CREATE, INSERT, SELECT (selected columns)
   189  	colType      []string      // used by CREATE
   190  	colValue     []interface{} // used by INSERT (mix of strings and "?" for bound params)
   191  	placeholders int           // used by INSERT/SELECT: number of ? params
   192  
   193  	whereCol []boundCol // used by SELECT (all placeholders)
   194  
   195  	placeholderConverter []driver.ValueConverter // used by INSERT
   196  }
   197  
   198  var fdriver driver.Driver = &fakeDriver{}
   199  
   200  func init() {
   201  	Register("test", fdriver)
   202  }
   203  
   204  func contains(list []string, y string) bool {
   205  	for _, x := range list {
   206  		if x == y {
   207  			return true
   208  		}
   209  	}
   210  	return false
   211  }
   212  
   213  type Dummy struct {
   214  	driver.Driver
   215  }
   216  
   217  func TestDrivers(t *testing.T) {
   218  	unregisterAllDrivers()
   219  	Register("test", fdriver)
   220  	Register("invalid", Dummy{})
   221  	all := Drivers()
   222  	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
   223  		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
   224  	}
   225  }
   226  
   227  // hook to simulate connection failures
   228  var hookOpenErr struct {
   229  	sync.Mutex
   230  	fn func() error
   231  }
   232  
   233  func setHookOpenErr(fn func() error) {
   234  	hookOpenErr.Lock()
   235  	defer hookOpenErr.Unlock()
   236  	hookOpenErr.fn = fn
   237  }
   238  
   239  // Supports dsn forms:
   240  //    <dbname>
   241  //    <dbname>;<opts>  (only currently supported option is `badConn`,
   242  //                      which causes driver.ErrBadConn to be returned on
   243  //                      every other conn.Begin())
   244  func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
   245  	hookOpenErr.Lock()
   246  	fn := hookOpenErr.fn
   247  	hookOpenErr.Unlock()
   248  	if fn != nil {
   249  		if err := fn(); err != nil {
   250  			return nil, err
   251  		}
   252  	}
   253  	parts := strings.Split(dsn, ";")
   254  	if len(parts) < 1 {
   255  		return nil, errors.New("fakedb: no database name")
   256  	}
   257  	name := parts[0]
   258  
   259  	db := d.getDB(name)
   260  
   261  	d.mu.Lock()
   262  	d.openCount++
   263  	d.mu.Unlock()
   264  	conn := &fakeConn{db: db}
   265  
   266  	if len(parts) >= 2 && parts[1] == "badConn" {
   267  		conn.bad = true
   268  	}
   269  	if d.waitCh != nil {
   270  		d.waitingCh <- struct{}{}
   271  		<-d.waitCh
   272  		d.waitCh = nil
   273  		d.waitingCh = nil
   274  	}
   275  	return conn, nil
   276  }
   277  
   278  func (d *fakeDriver) getDB(name string) *fakeDB {
   279  	d.mu.Lock()
   280  	defer d.mu.Unlock()
   281  	if d.dbs == nil {
   282  		d.dbs = make(map[string]*fakeDB)
   283  	}
   284  	db, ok := d.dbs[name]
   285  	if !ok {
   286  		db = &fakeDB{name: name}
   287  		d.dbs[name] = db
   288  	}
   289  	return db
   290  }
   291  
   292  func (db *fakeDB) wipe() {
   293  	db.mu.Lock()
   294  	defer db.mu.Unlock()
   295  	db.tables = nil
   296  }
   297  
   298  func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
   299  	db.mu.Lock()
   300  	defer db.mu.Unlock()
   301  	if db.tables == nil {
   302  		db.tables = make(map[string]*table)
   303  	}
   304  	if _, exist := db.tables[name]; exist {
   305  		return fmt.Errorf("fakedb: table %q already exists", name)
   306  	}
   307  	if len(columnNames) != len(columnTypes) {
   308  		return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
   309  			name, len(columnNames), len(columnTypes))
   310  	}
   311  	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
   312  	return nil
   313  }
   314  
   315  // must be called with db.mu lock held
   316  func (db *fakeDB) table(table string) (*table, bool) {
   317  	if db.tables == nil {
   318  		return nil, false
   319  	}
   320  	t, ok := db.tables[table]
   321  	return t, ok
   322  }
   323  
   324  func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
   325  	db.mu.Lock()
   326  	defer db.mu.Unlock()
   327  	t, ok := db.table(table)
   328  	if !ok {
   329  		return
   330  	}
   331  	for n, cname := range t.colname {
   332  		if cname == column {
   333  			return t.coltype[n], true
   334  		}
   335  	}
   336  	return "", false
   337  }
   338  
   339  func (c *fakeConn) isBad() bool {
   340  	if c.stickyBad {
   341  		return true
   342  	} else if c.bad {
   343  		if c.db == nil {
   344  			return false
   345  		}
   346  		// alternate between bad conn and not bad conn
   347  		c.db.badConn = !c.db.badConn
   348  		return c.db.badConn
   349  	} else {
   350  		return false
   351  	}
   352  }
   353  
   354  func (c *fakeConn) isDirtyAndMark() bool {
   355  	if c.skipDirtySession {
   356  		return false
   357  	}
   358  	if c.currTx != nil {
   359  		c.dirtySession = true
   360  		return false
   361  	}
   362  	if c.dirtySession {
   363  		return true
   364  	}
   365  	c.dirtySession = true
   366  	return false
   367  }
   368  
   369  func (c *fakeConn) Begin() (driver.Tx, error) {
   370  	if c.isBad() {
   371  		return nil, driver.ErrBadConn
   372  	}
   373  	if c.currTx != nil {
   374  		return nil, errors.New("fakedb: already in a transaction")
   375  	}
   376  	c.touchMem()
   377  	c.currTx = &fakeTx{c: c}
   378  	return c.currTx, nil
   379  }
   380  
   381  var hookPostCloseConn struct {
   382  	sync.Mutex
   383  	fn func(*fakeConn, error)
   384  }
   385  
   386  func setHookpostCloseConn(fn func(*fakeConn, error)) {
   387  	hookPostCloseConn.Lock()
   388  	defer hookPostCloseConn.Unlock()
   389  	hookPostCloseConn.fn = fn
   390  }
   391  
   392  var testStrictClose *testing.T
   393  
   394  // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
   395  // fails to close. If nil, the check is disabled.
   396  func setStrictFakeConnClose(t *testing.T) {
   397  	testStrictClose = t
   398  }
   399  
   400  func (c *fakeConn) ResetSession(ctx context.Context) error {
   401  	c.dirtySession = false
   402  	c.currTx = nil
   403  	if c.isBad() {
   404  		return driver.ErrBadConn
   405  	}
   406  	return nil
   407  }
   408  
   409  var _ driver.Validator = (*fakeConn)(nil)
   410  
   411  func (c *fakeConn) IsValid() bool {
   412  	return !c.isBad()
   413  }
   414  
   415  func (c *fakeConn) Close() (err error) {
   416  	drv := fdriver.(*fakeDriver)
   417  	defer func() {
   418  		if err != nil && testStrictClose != nil {
   419  			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
   420  		}
   421  		hookPostCloseConn.Lock()
   422  		fn := hookPostCloseConn.fn
   423  		hookPostCloseConn.Unlock()
   424  		if fn != nil {
   425  			fn(c, err)
   426  		}
   427  		if err == nil {
   428  			drv.mu.Lock()
   429  			drv.closeCount++
   430  			drv.mu.Unlock()
   431  		}
   432  	}()
   433  	c.touchMem()
   434  	if c.currTx != nil {
   435  		return errors.New("fakedb: can't close fakeConn; in a Transaction")
   436  	}
   437  	if c.db == nil {
   438  		return errors.New("fakedb: can't close fakeConn; already closed")
   439  	}
   440  	if c.stmtsMade > c.stmtsClosed {
   441  		return errors.New("fakedb: can't close; dangling statement(s)")
   442  	}
   443  	c.db = nil
   444  	return nil
   445  }
   446  
   447  func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
   448  	for _, arg := range args {
   449  		switch arg.Value.(type) {
   450  		case int64, float64, bool, nil, []byte, string, time.Time:
   451  		default:
   452  			if !allowAny {
   453  				return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
   454  			}
   455  		}
   456  	}
   457  	return nil
   458  }
   459  
   460  func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
   461  	// Ensure that ExecContext is called if available.
   462  	panic("ExecContext was not called.")
   463  }
   464  
   465  func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   466  	// This is an optional interface, but it's implemented here
   467  	// just to check that all the args are of the proper types.
   468  	// ErrSkip is returned so the caller acts as if we didn't
   469  	// implement this at all.
   470  	err := checkSubsetTypes(c.db.allowAny, args)
   471  	if err != nil {
   472  		return nil, err
   473  	}
   474  	return nil, driver.ErrSkip
   475  }
   476  
   477  func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
   478  	// Ensure that ExecContext is called if available.
   479  	panic("QueryContext was not called.")
   480  }
   481  
   482  func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   483  	// This is an optional interface, but it's implemented here
   484  	// just to check that all the args are of the proper types.
   485  	// ErrSkip is returned so the caller acts as if we didn't
   486  	// implement this at all.
   487  	err := checkSubsetTypes(c.db.allowAny, args)
   488  	if err != nil {
   489  		return nil, err
   490  	}
   491  	return nil, driver.ErrSkip
   492  }
   493  
   494  func errf(msg string, args ...interface{}) error {
   495  	return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
   496  }
   497  
   498  // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
   499  // (note that where columns must always contain ? marks,
   500  //  just a limitation for fakedb)
   501  func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   502  	if len(parts) != 3 {
   503  		stmt.Close()
   504  		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
   505  	}
   506  	stmt.table = parts[0]
   507  
   508  	stmt.colName = strings.Split(parts[1], ",")
   509  	for n, colspec := range strings.Split(parts[2], ",") {
   510  		if colspec == "" {
   511  			continue
   512  		}
   513  		nameVal := strings.Split(colspec, "=")
   514  		if len(nameVal) != 2 {
   515  			stmt.Close()
   516  			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   517  		}
   518  		column, value := nameVal[0], nameVal[1]
   519  		_, ok := c.db.columnType(stmt.table, column)
   520  		if !ok {
   521  			stmt.Close()
   522  			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
   523  		}
   524  		if !strings.HasPrefix(value, "?") {
   525  			stmt.Close()
   526  			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
   527  				stmt.table, column)
   528  		}
   529  		stmt.placeholders++
   530  		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
   531  	}
   532  	return stmt, nil
   533  }
   534  
   535  // parts are table|col=type,col2=type2
   536  func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   537  	if len(parts) != 2 {
   538  		stmt.Close()
   539  		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
   540  	}
   541  	stmt.table = parts[0]
   542  	for n, colspec := range strings.Split(parts[1], ",") {
   543  		nameType := strings.Split(colspec, "=")
   544  		if len(nameType) != 2 {
   545  			stmt.Close()
   546  			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   547  		}
   548  		stmt.colName = append(stmt.colName, nameType[0])
   549  		stmt.colType = append(stmt.colType, nameType[1])
   550  	}
   551  	return stmt, nil
   552  }
   553  
   554  // parts are table|col=?,col2=val
   555  func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   556  	if len(parts) != 2 {
   557  		stmt.Close()
   558  		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
   559  	}
   560  	stmt.table = parts[0]
   561  	for n, colspec := range strings.Split(parts[1], ",") {
   562  		nameVal := strings.Split(colspec, "=")
   563  		if len(nameVal) != 2 {
   564  			stmt.Close()
   565  			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   566  		}
   567  		column, value := nameVal[0], nameVal[1]
   568  		ctype, ok := c.db.columnType(stmt.table, column)
   569  		if !ok {
   570  			stmt.Close()
   571  			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
   572  		}
   573  		stmt.colName = append(stmt.colName, column)
   574  
   575  		if !strings.HasPrefix(value, "?") {
   576  			var subsetVal interface{}
   577  			// Convert to driver subset type
   578  			switch ctype {
   579  			case "string":
   580  				subsetVal = []byte(value)
   581  			case "blob":
   582  				subsetVal = []byte(value)
   583  			case "int32":
   584  				i, err := strconv.Atoi(value)
   585  				if err != nil {
   586  					stmt.Close()
   587  					return nil, errf("invalid conversion to int32 from %q", value)
   588  				}
   589  				subsetVal = int64(i) // int64 is a subset type, but not int32
   590  			case "table": // For testing cursor reads.
   591  				c.skipDirtySession = true
   592  				vparts := strings.Split(value, "!")
   593  
   594  				substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
   595  				if err != nil {
   596  					return nil, err
   597  				}
   598  				cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
   599  				substmt.Close()
   600  				if err != nil {
   601  					return nil, err
   602  				}
   603  				subsetVal = cursor
   604  			default:
   605  				stmt.Close()
   606  				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
   607  			}
   608  			stmt.colValue = append(stmt.colValue, subsetVal)
   609  		} else {
   610  			stmt.placeholders++
   611  			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
   612  			stmt.colValue = append(stmt.colValue, value)
   613  		}
   614  	}
   615  	return stmt, nil
   616  }
   617  
   618  // hook to simulate broken connections
   619  var hookPrepareBadConn func() bool
   620  
   621  func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
   622  	panic("use PrepareContext")
   623  }
   624  
   625  func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   626  	c.numPrepare++
   627  	if c.db == nil {
   628  		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
   629  	}
   630  
   631  	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
   632  		return nil, driver.ErrBadConn
   633  	}
   634  
   635  	c.touchMem()
   636  	var firstStmt, prev *fakeStmt
   637  	for _, query := range strings.Split(query, ";") {
   638  		parts := strings.Split(query, "|")
   639  		if len(parts) < 1 {
   640  			return nil, errf("empty query")
   641  		}
   642  		stmt := &fakeStmt{q: query, c: c, memToucher: c}
   643  		if firstStmt == nil {
   644  			firstStmt = stmt
   645  		}
   646  		if len(parts) >= 3 {
   647  			switch parts[0] {
   648  			case "PANIC":
   649  				stmt.panic = parts[1]
   650  				parts = parts[2:]
   651  			case "WAIT":
   652  				wait, err := time.ParseDuration(parts[1])
   653  				if err != nil {
   654  					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
   655  				}
   656  				parts = parts[2:]
   657  				stmt.wait = wait
   658  			}
   659  		}
   660  		cmd := parts[0]
   661  		stmt.cmd = cmd
   662  		parts = parts[1:]
   663  
   664  		if c.waiter != nil {
   665  			c.waiter(ctx)
   666  		}
   667  
   668  		if stmt.wait > 0 {
   669  			wait := time.NewTimer(stmt.wait)
   670  			select {
   671  			case <-wait.C:
   672  			case <-ctx.Done():
   673  				wait.Stop()
   674  				return nil, ctx.Err()
   675  			}
   676  		}
   677  
   678  		c.incrStat(&c.stmtsMade)
   679  		var err error
   680  		switch cmd {
   681  		case "WIPE":
   682  			// Nothing
   683  		case "SELECT":
   684  			stmt, err = c.prepareSelect(stmt, parts)
   685  		case "CREATE":
   686  			stmt, err = c.prepareCreate(stmt, parts)
   687  		case "INSERT":
   688  			stmt, err = c.prepareInsert(ctx, stmt, parts)
   689  		case "NOSERT":
   690  			// Do all the prep-work like for an INSERT but don't actually insert the row.
   691  			// Used for some of the concurrent tests.
   692  			stmt, err = c.prepareInsert(ctx, stmt, parts)
   693  		default:
   694  			stmt.Close()
   695  			return nil, errf("unsupported command type %q", cmd)
   696  		}
   697  		if err != nil {
   698  			return nil, err
   699  		}
   700  		if prev != nil {
   701  			prev.next = stmt
   702  		}
   703  		prev = stmt
   704  	}
   705  	return firstStmt, nil
   706  }
   707  
   708  func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
   709  	if s.panic == "ColumnConverter" {
   710  		panic(s.panic)
   711  	}
   712  	if len(s.placeholderConverter) == 0 {
   713  		return driver.DefaultParameterConverter
   714  	}
   715  	return s.placeholderConverter[idx]
   716  }
   717  
   718  func (s *fakeStmt) Close() error {
   719  	if s.panic == "Close" {
   720  		panic(s.panic)
   721  	}
   722  	if s.c == nil {
   723  		panic("nil conn in fakeStmt.Close")
   724  	}
   725  	if s.c.db == nil {
   726  		panic("in fakeStmt.Close, conn's db is nil (already closed)")
   727  	}
   728  	s.touchMem()
   729  	if !s.closed {
   730  		s.c.incrStat(&s.c.stmtsClosed)
   731  		s.closed = true
   732  	}
   733  	if s.next != nil {
   734  		s.next.Close()
   735  	}
   736  	return nil
   737  }
   738  
   739  var errClosed = errors.New("fakedb: statement has been closed")
   740  
   741  // hook to simulate broken connections
   742  var hookExecBadConn func() bool
   743  
   744  func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
   745  	panic("Using ExecContext")
   746  }
   747  
   748  var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
   749  
   750  func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   751  	if s.panic == "Exec" {
   752  		panic(s.panic)
   753  	}
   754  	if s.closed {
   755  		return nil, errClosed
   756  	}
   757  
   758  	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
   759  		return nil, driver.ErrBadConn
   760  	}
   761  	if s.c.isDirtyAndMark() {
   762  		return nil, errFakeConnSessionDirty
   763  	}
   764  
   765  	err := checkSubsetTypes(s.c.db.allowAny, args)
   766  	if err != nil {
   767  		return nil, err
   768  	}
   769  	s.touchMem()
   770  
   771  	if s.wait > 0 {
   772  		time.Sleep(s.wait)
   773  	}
   774  
   775  	select {
   776  	default:
   777  	case <-ctx.Done():
   778  		return nil, ctx.Err()
   779  	}
   780  
   781  	db := s.c.db
   782  	switch s.cmd {
   783  	case "WIPE":
   784  		db.wipe()
   785  		return driver.ResultNoRows, nil
   786  	case "CREATE":
   787  		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
   788  			return nil, err
   789  		}
   790  		return driver.ResultNoRows, nil
   791  	case "INSERT":
   792  		return s.execInsert(args, true)
   793  	case "NOSERT":
   794  		// Do all the prep-work like for an INSERT but don't actually insert the row.
   795  		// Used for some of the concurrent tests.
   796  		return s.execInsert(args, false)
   797  	}
   798  	return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
   799  }
   800  
   801  // When doInsert is true, add the row to the table.
   802  // When doInsert is false do prep-work and error checking, but don't
   803  // actually add the row to the table.
   804  func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
   805  	db := s.c.db
   806  	if len(args) != s.placeholders {
   807  		panic("error in pkg db; should only get here if size is correct")
   808  	}
   809  	db.mu.Lock()
   810  	t, ok := db.table(s.table)
   811  	db.mu.Unlock()
   812  	if !ok {
   813  		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   814  	}
   815  
   816  	t.mu.Lock()
   817  	defer t.mu.Unlock()
   818  
   819  	var cols []interface{}
   820  	if doInsert {
   821  		cols = make([]interface{}, len(t.colname))
   822  	}
   823  	argPos := 0
   824  	for n, colname := range s.colName {
   825  		colidx := t.columnIndex(colname)
   826  		if colidx == -1 {
   827  			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
   828  		}
   829  		var val interface{}
   830  		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
   831  			if strvalue == "?" {
   832  				val = args[argPos].Value
   833  			} else {
   834  				// Assign value from argument placeholder name.
   835  				for _, a := range args {
   836  					if a.Name == strvalue[1:] {
   837  						val = a.Value
   838  						break
   839  					}
   840  				}
   841  			}
   842  			argPos++
   843  		} else {
   844  			val = s.colValue[n]
   845  		}
   846  		if doInsert {
   847  			cols[colidx] = val
   848  		}
   849  	}
   850  
   851  	if doInsert {
   852  		t.rows = append(t.rows, &row{cols: cols})
   853  	}
   854  	return driver.RowsAffected(1), nil
   855  }
   856  
   857  // hook to simulate broken connections
   858  var hookQueryBadConn func() bool
   859  
   860  func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
   861  	panic("Use QueryContext")
   862  }
   863  
   864  func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   865  	if s.panic == "Query" {
   866  		panic(s.panic)
   867  	}
   868  	if s.closed {
   869  		return nil, errClosed
   870  	}
   871  
   872  	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
   873  		return nil, driver.ErrBadConn
   874  	}
   875  	if s.c.isDirtyAndMark() {
   876  		return nil, errFakeConnSessionDirty
   877  	}
   878  
   879  	err := checkSubsetTypes(s.c.db.allowAny, args)
   880  	if err != nil {
   881  		return nil, err
   882  	}
   883  
   884  	s.touchMem()
   885  	db := s.c.db
   886  	if len(args) != s.placeholders {
   887  		panic("error in pkg db; should only get here if size is correct")
   888  	}
   889  
   890  	setMRows := make([][]*row, 0, 1)
   891  	setColumns := make([][]string, 0, 1)
   892  	setColType := make([][]string, 0, 1)
   893  
   894  	for {
   895  		db.mu.Lock()
   896  		t, ok := db.table(s.table)
   897  		db.mu.Unlock()
   898  		if !ok {
   899  			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   900  		}
   901  
   902  		if s.table == "magicquery" {
   903  			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
   904  				if args[0].Value == "sleep" {
   905  					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
   906  				}
   907  			}
   908  		}
   909  		if s.table == "tx_status" && s.colName[0] == "tx_status" {
   910  			txStatus := "autocommit"
   911  			if s.c.currTx != nil {
   912  				txStatus = "transaction"
   913  			}
   914  			cursor := &rowsCursor{
   915  				parentMem: s.c,
   916  				posRow:    -1,
   917  				rows: [][]*row{
   918  					{
   919  						{
   920  							cols: []interface{}{
   921  								txStatus,
   922  							},
   923  						},
   924  					},
   925  				},
   926  				cols: [][]string{
   927  					{
   928  						"tx_status",
   929  					},
   930  				},
   931  				colType: [][]string{
   932  					{
   933  						"string",
   934  					},
   935  				},
   936  				errPos: -1,
   937  			}
   938  			return cursor, nil
   939  		}
   940  
   941  		t.mu.Lock()
   942  
   943  		colIdx := make(map[string]int) // select column name -> column index in table
   944  		for _, name := range s.colName {
   945  			idx := t.columnIndex(name)
   946  			if idx == -1 {
   947  				t.mu.Unlock()
   948  				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
   949  			}
   950  			colIdx[name] = idx
   951  		}
   952  
   953  		mrows := []*row{}
   954  	rows:
   955  		for _, trow := range t.rows {
   956  			// Process the where clause, skipping non-match rows. This is lazy
   957  			// and just uses fmt.Sprintf("%v") to test equality. Good enough
   958  			// for test code.
   959  			for _, wcol := range s.whereCol {
   960  				idx := t.columnIndex(wcol.Column)
   961  				if idx == -1 {
   962  					t.mu.Unlock()
   963  					return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
   964  				}
   965  				tcol := trow.cols[idx]
   966  				if bs, ok := tcol.([]byte); ok {
   967  					// lazy hack to avoid sprintf %v on a []byte
   968  					tcol = string(bs)
   969  				}
   970  				var argValue interface{}
   971  				if wcol.Placeholder == "?" {
   972  					argValue = args[wcol.Ordinal-1].Value
   973  				} else {
   974  					// Assign arg value from placeholder name.
   975  					for _, a := range args {
   976  						if a.Name == wcol.Placeholder[1:] {
   977  							argValue = a.Value
   978  							break
   979  						}
   980  					}
   981  				}
   982  				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
   983  					continue rows
   984  				}
   985  			}
   986  			mrow := &row{cols: make([]interface{}, len(s.colName))}
   987  			for seli, name := range s.colName {
   988  				mrow.cols[seli] = trow.cols[colIdx[name]]
   989  			}
   990  			mrows = append(mrows, mrow)
   991  		}
   992  
   993  		var colType []string
   994  		for _, column := range s.colName {
   995  			colType = append(colType, t.coltype[t.columnIndex(column)])
   996  		}
   997  
   998  		t.mu.Unlock()
   999  
  1000  		setMRows = append(setMRows, mrows)
  1001  		setColumns = append(setColumns, s.colName)
  1002  		setColType = append(setColType, colType)
  1003  
  1004  		if s.next == nil {
  1005  			break
  1006  		}
  1007  		s = s.next
  1008  	}
  1009  
  1010  	cursor := &rowsCursor{
  1011  		parentMem: s.c,
  1012  		posRow:    -1,
  1013  		rows:      setMRows,
  1014  		cols:      setColumns,
  1015  		colType:   setColType,
  1016  		errPos:    -1,
  1017  	}
  1018  	return cursor, nil
  1019  }
  1020  
  1021  func (s *fakeStmt) NumInput() int {
  1022  	if s.panic == "NumInput" {
  1023  		panic(s.panic)
  1024  	}
  1025  	return s.placeholders
  1026  }
  1027  
  1028  // hook to simulate broken connections
  1029  var hookCommitBadConn func() bool
  1030  
  1031  func (tx *fakeTx) Commit() error {
  1032  	tx.c.currTx = nil
  1033  	if hookCommitBadConn != nil && hookCommitBadConn() {
  1034  		return driver.ErrBadConn
  1035  	}
  1036  	tx.c.touchMem()
  1037  	return nil
  1038  }
  1039  
  1040  // hook to simulate broken connections
  1041  var hookRollbackBadConn func() bool
  1042  
  1043  func (tx *fakeTx) Rollback() error {
  1044  	tx.c.currTx = nil
  1045  	if hookRollbackBadConn != nil && hookRollbackBadConn() {
  1046  		return driver.ErrBadConn
  1047  	}
  1048  	tx.c.touchMem()
  1049  	return nil
  1050  }
  1051  
  1052  type rowsCursor struct {
  1053  	parentMem memToucher
  1054  	cols      [][]string
  1055  	colType   [][]string
  1056  	posSet    int
  1057  	posRow    int
  1058  	rows      [][]*row
  1059  	closed    bool
  1060  
  1061  	// errPos and err are for making Next return early with error.
  1062  	errPos int
  1063  	err    error
  1064  
  1065  	// a clone of slices to give out to clients, indexed by the
  1066  	// original slice's first byte address.  we clone them
  1067  	// just so we're able to corrupt them on close.
  1068  	bytesClone map[*byte][]byte
  1069  
  1070  	// Every operation writes to line to enable the race detector
  1071  	// check for data races.
  1072  	// This is separate from the fakeConn.line to allow for drivers that
  1073  	// can start multiple queries on the same transaction at the same time.
  1074  	line int64
  1075  }
  1076  
  1077  func (rc *rowsCursor) touchMem() {
  1078  	rc.parentMem.touchMem()
  1079  	rc.line++
  1080  }
  1081  
  1082  func (rc *rowsCursor) Close() error {
  1083  	rc.touchMem()
  1084  	rc.parentMem.touchMem()
  1085  	rc.closed = true
  1086  	return nil
  1087  }
  1088  
  1089  func (rc *rowsCursor) Columns() []string {
  1090  	return rc.cols[rc.posSet]
  1091  }
  1092  
  1093  func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
  1094  	return colTypeToReflectType(rc.colType[rc.posSet][index])
  1095  }
  1096  
  1097  var rowsCursorNextHook func(dest []driver.Value) error
  1098  
  1099  func (rc *rowsCursor) Next(dest []driver.Value) error {
  1100  	if rowsCursorNextHook != nil {
  1101  		return rowsCursorNextHook(dest)
  1102  	}
  1103  
  1104  	if rc.closed {
  1105  		return errors.New("fakedb: cursor is closed")
  1106  	}
  1107  	rc.touchMem()
  1108  	rc.posRow++
  1109  	if rc.posRow == rc.errPos {
  1110  		return rc.err
  1111  	}
  1112  	if rc.posRow >= len(rc.rows[rc.posSet]) {
  1113  		return io.EOF // per interface spec
  1114  	}
  1115  	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
  1116  		// TODO(bradfitz): convert to subset types? naah, I
  1117  		// think the subset types should only be input to
  1118  		// driver, but the sql package should be able to handle
  1119  		// a wider range of types coming out of drivers. all
  1120  		// for ease of drivers, and to prevent drivers from
  1121  		// messing up conversions or doing them differently.
  1122  		dest[i] = v
  1123  
  1124  		if bs, ok := v.([]byte); ok {
  1125  			if rc.bytesClone == nil {
  1126  				rc.bytesClone = make(map[*byte][]byte)
  1127  			}
  1128  			clone, ok := rc.bytesClone[&bs[0]]
  1129  			if !ok {
  1130  				clone = make([]byte, len(bs))
  1131  				copy(clone, bs)
  1132  				rc.bytesClone[&bs[0]] = clone
  1133  			}
  1134  			dest[i] = clone
  1135  		}
  1136  	}
  1137  	return nil
  1138  }
  1139  
  1140  func (rc *rowsCursor) HasNextResultSet() bool {
  1141  	rc.touchMem()
  1142  	return rc.posSet < len(rc.rows)-1
  1143  }
  1144  
  1145  func (rc *rowsCursor) NextResultSet() error {
  1146  	rc.touchMem()
  1147  	if rc.HasNextResultSet() {
  1148  		rc.posSet++
  1149  		rc.posRow = -1
  1150  		return nil
  1151  	}
  1152  	return io.EOF // Per interface spec.
  1153  }
  1154  
  1155  // fakeDriverString is like driver.String, but indirects pointers like
  1156  // DefaultValueConverter.
  1157  //
  1158  // This could be surprising behavior to retroactively apply to
  1159  // driver.String now that Go1 is out, but this is convenient for
  1160  // our TestPointerParamsAndScans.
  1161  //
  1162  type fakeDriverString struct{}
  1163  
  1164  func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
  1165  	switch c := v.(type) {
  1166  	case string, []byte:
  1167  		return v, nil
  1168  	case *string:
  1169  		if c == nil {
  1170  			return nil, nil
  1171  		}
  1172  		return *c, nil
  1173  	}
  1174  	return fmt.Sprintf("%v", v), nil
  1175  }
  1176  
  1177  type anyTypeConverter struct{}
  1178  
  1179  func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) {
  1180  	return v, nil
  1181  }
  1182  
  1183  func converterForType(typ string) driver.ValueConverter {
  1184  	switch typ {
  1185  	case "bool":
  1186  		return driver.Bool
  1187  	case "nullbool":
  1188  		return driver.Null{Converter: driver.Bool}
  1189  	case "byte", "int16":
  1190  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1191  	case "int32":
  1192  		return driver.Int32
  1193  	case "nullbyte", "nullint32", "nullint16":
  1194  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1195  	case "string":
  1196  		return driver.NotNull{Converter: fakeDriverString{}}
  1197  	case "nullstring":
  1198  		return driver.Null{Converter: fakeDriverString{}}
  1199  	case "int64":
  1200  		// TODO(coopernurse): add type-specific converter
  1201  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1202  	case "nullint64":
  1203  		// TODO(coopernurse): add type-specific converter
  1204  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1205  	case "float64":
  1206  		// TODO(coopernurse): add type-specific converter
  1207  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1208  	case "nullfloat64":
  1209  		// TODO(coopernurse): add type-specific converter
  1210  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1211  	case "datetime":
  1212  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1213  	case "nulldatetime":
  1214  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1215  	case "any":
  1216  		return anyTypeConverter{}
  1217  	}
  1218  	panic("invalid fakedb column type of " + typ)
  1219  }
  1220  
  1221  func colTypeToReflectType(typ string) reflect.Type {
  1222  	switch typ {
  1223  	case "bool":
  1224  		return reflect.TypeOf(false)
  1225  	case "nullbool":
  1226  		return reflect.TypeOf(NullBool{})
  1227  	case "int16":
  1228  		return reflect.TypeOf(int16(0))
  1229  	case "nullint16":
  1230  		return reflect.TypeOf(NullInt16{})
  1231  	case "int32":
  1232  		return reflect.TypeOf(int32(0))
  1233  	case "nullint32":
  1234  		return reflect.TypeOf(NullInt32{})
  1235  	case "string":
  1236  		return reflect.TypeOf("")
  1237  	case "nullstring":
  1238  		return reflect.TypeOf(NullString{})
  1239  	case "int64":
  1240  		return reflect.TypeOf(int64(0))
  1241  	case "nullint64":
  1242  		return reflect.TypeOf(NullInt64{})
  1243  	case "float64":
  1244  		return reflect.TypeOf(float64(0))
  1245  	case "nullfloat64":
  1246  		return reflect.TypeOf(NullFloat64{})
  1247  	case "datetime":
  1248  		return reflect.TypeOf(time.Time{})
  1249  	case "any":
  1250  		return reflect.TypeOf(new(interface{})).Elem()
  1251  	}
  1252  	panic("invalid fakedb column type of " + typ)
  1253  }
  1254  

View as plain text