From ea0251546850fe749fa225ddf42b30fb7581fdda Mon Sep 17 00:00:00 2001 From: "face.wsl" Date: Sat, 26 Nov 2022 17:39:46 +0800 Subject: [PATCH] feat: test for mysql and postgres --- scripts/test-sql.sh | 64 +++++++++++++++++ storage/sql/auth.go | 9 ++- storage/sql/server.go | 8 ++- storage/sql/settings.go | 75 +++++++++++++------- storage/sql/share.go | 30 ++++---- storage/sql/sql.go | 151 ++++++++++++++++++++++++++++++++++++---- storage/sql/users.go | 58 ++++++++++----- 7 files changed, 322 insertions(+), 73 deletions(-) create mode 100755 scripts/test-sql.sh diff --git a/scripts/test-sql.sh b/scripts/test-sql.sh new file mode 100755 index 00000000..5dbc1068 --- /dev/null +++ b/scripts/test-sql.sh @@ -0,0 +1,64 @@ +#!/bin/bash + +# stop_docker name +stop_docker() { + sudo docker stop $1 1>/dev/null 2>&1 + sudo docker rm $1 1>/dev/null 2>&1 +} + +# start_mysql name password port +start_mysql() { + stop_docker $1 + sudo docker run --rm --name $1 -e MYSQL_ROOT_PASSWORD=$2 -p $3:3306 -d mysql +} + +# start_postgres name password port +start_postgres() { + stop_docker $1 + sudo docker run --rm --name $1 -e POSTGRES_PASSWORD=$2 -p $3:5432 -d postgres +} + + +test_sqlite() { + rm -f test.db + ./filebrowser -a 0.0.0.0 -d sqlite3://test.db +} + + +test_postgres() { + start_postgres test-postgres postgres 5433 + sleep 30 + ./filebrowser -a 0.0.0.0 -d postgres://postgres:postgres@127.0.0.1:5433/postgres?sslmode=disable +} + + +test_mysql() { + start_mysql test-mysql root 3307 + sleep 60 + ./filebrowser -a 0.0.0.0 -d 'mysql://root:root@127.0.0.1:3307/mysql' +} + +help() { + echo "USAGE: $0 sqlite|mysql|postgres" + exit 1 +} + +if (( $# == 0 )); then + help +fi + +case $1 in + sqlite) + test_sqlite + ;; + mysql) + test_mysql + ;; + postgres) + test_postgres + ;; + *) + help +esac + + diff --git a/storage/sql/auth.go b/storage/sql/auth.go index 1eb7e5b2..c04ba32a 100644 --- a/storage/sql/auth.go +++ b/storage/sql/auth.go @@ -11,7 +11,8 @@ import ( ) type authBackend struct { - db *sql.DB + db *sql.DB + dbType string } func (s authBackend) Get(t settings.AuthMethod) (auth.Auther, error) { @@ -38,5 +39,9 @@ func (s authBackend) Save(a auth.Auther) error { if checkError(err, "Fail to save auth.Auther") { return err } - return SetSetting(s.db, "auther", string(val)) + return SetSetting(s.db, s.dbType, "auther", string(val)) +} + +func newAuthBackend(db *sql.DB, dbType string) authBackend { + return authBackend{db: db, dbType: dbType} } diff --git a/storage/sql/server.go b/storage/sql/server.go index 9672364b..121619c0 100644 --- a/storage/sql/server.go +++ b/storage/sql/server.go @@ -28,7 +28,7 @@ func cloneServer(server settings.Server) settings.Server { } func (s settingsBackend) GetServer() (*settings.Server, error) { - sql := fmt.Sprintf("select key, value from %s", SettingsTable) + sql := fmt.Sprintf("select %s, value from %s", quoteName(s.dbType, "key"), quoteName(s.dbType, SettingsTable)) rows, err := s.db.Query(sql) if checkError(err, "Fail to Query for GetServer") { return nil, err @@ -93,7 +93,11 @@ func (s settingsBackend) SaveServer(ss *settings.Server) error { if checkError(err, "Fail to begin db transaction") { return err } - sql := fmt.Sprintf("INSERT INTO \"%s\" (key, value) VALUES($1,$2)", SettingsTable) + table := quoteName(s.dbType, SettingsTable) + k := quoteName(s.dbType, "key") + p1 := placeHolder(s.dbType, 1) + p2 := placeHolder(s.dbType, 2) + sql := fmt.Sprintf("INSERT INTO %s (%s, value) VALUES(%s,%s)", table, k, p1, p2) for i, field := range fields { stmt, err := s.db.Prepare(sql) defer stmt.Close() diff --git a/storage/sql/settings.go b/storage/sql/settings.go index 5e1fe5d8..3e2a6e70 100644 --- a/storage/sql/settings.go +++ b/storage/sql/settings.go @@ -17,11 +17,12 @@ func init() { } type settingsBackend struct { - db *sql.DB + db *sql.DB + dbType string } -func InitSettingsTable(db *sql.DB) error { - sql := fmt.Sprintf("create table if not exists \"%s\"(key text primary key, value text);", SettingsTable) +func InitSettingsTable(db *sql.DB, dbType string) error { + sql := fmt.Sprintf("create table if not exists %s (%s varchar(128) primary key, value text);", quoteName(dbType, SettingsTable), quoteName(dbType, "key")) _, err := db.Exec(sql) checkError(err, "Fail to create table settings") return err @@ -136,7 +137,7 @@ func boolToString(b bool) string { } func (s settingsBackend) Get() (*settings.Settings, error) { - sql := fmt.Sprintf("select key, value from \"%s\";", SettingsTable) + sql := fmt.Sprintf("select %s, value from %s;", quoteName(s.dbType, "key"), quoteName(s.dbType, SettingsTable)) rows, err := s.db.Query(sql) if checkError(err, "Fail to Query settings.Settings") { return nil, err @@ -196,13 +197,16 @@ func (s settingsBackend) Save(ss *settings.Settings) error { if checkError(err, "Fail to begin db transaction") { return err } + table := quoteName(s.dbType, SettingsTable) + k := quoteName(s.dbType, "key") + p1 := placeHolder(s.dbType, 1) + p2 := placeHolder(s.dbType, 2) for i, field := range fields { - exists := ContainKey(s.db, field) - sql := fmt.Sprintf("INSERT INTO \"%s\" (value, key) VALUES($1,$2);", SettingsTable) + exists := ContainKey(s.db, s.dbType, field) + sql := fmt.Sprintf("INSERT INTO %s (value, %s) VALUES(%s,%s);", table, k, p1, p2) if exists { - sql = fmt.Sprintf("UPDATE \"%s\" set value = $1 where key = $2;", SettingsTable) + sql = fmt.Sprintf("UPDATE %s set value = %s where %s = %s;", table, p1, k, p2) } - fmt.Println(sql) stmt, err := s.db.Prepare(sql) defer stmt.Close() if checkError(err, "Fail to prepare statement") { @@ -274,31 +278,37 @@ func cloneSettings(s settings.Settings) settings.Settings { return s1 } -func SetSetting(db *sql.DB, key string, value string) error { - sql := fmt.Sprintf("select count(key) from \"%s\" where key = '%s';", SettingsTable, key) +func SetSetting(db *sql.DB, dbType string, key string, value string) error { + t := quoteName(dbType, SettingsTable) + k := quoteName(dbType, "key") + sql := fmt.Sprintf("select count(%s) from %s where %s = '%s';", k, t, k, key) count := 0 err := db.QueryRow(sql).Scan(&count) if checkError(err, "Fail to QueryRow for key="+key) { return err } if count == 0 { - return addSetting(db, key, value) + return addSetting(db, dbType, key, value) } - return updateSetting(db, key, value) + return updateSetting(db, dbType, key, value) } -func GetSetting(db *sql.DB, key string) string { - sql := fmt.Sprintf("select value from \"%s\" where key = '%s';", SettingsTable, key) +func GetSetting(db *sql.DB, dbType string, key string) string { + sql := fmt.Sprintf("select value from %s where %s = '%s';", quoteName(dbType, SettingsTable), quoteName(dbType, "key"), key) value := "" err := db.QueryRow(sql).Scan(&value) - if checkError(err, "Fail to QueryRow for key "+key) { + if checkError(err, "") { return value } return value } -func addSetting(db *sql.DB, key string, value string) error { - sql := fmt.Sprintf("insert into \"%s\" (key, value) values($1, $2);", SettingsTable) +func addSetting(db *sql.DB, dbType string, key string, value string) error { + table := quoteName(dbType, SettingsTable) + k := quoteName(dbType, "key") + p1 := placeHolder(dbType, 1) + p2 := placeHolder(dbType, 2) + sql := fmt.Sprintf("insert into %s (%s, value) values(%s, %s);", table, k, p1, p2) stmt, err := db.Prepare(sql) if checkError(err, "Fail to prepare sql") { return err @@ -308,8 +318,14 @@ func addSetting(db *sql.DB, key string, value string) error { return err } -func updateSetting(db *sql.DB, key string, value string) error { - sql := fmt.Sprintf("update \"%s\" set value = $1 where key = $2;", SettingsTable) +func updateSetting(db *sql.DB, dbType string, key string, value string) error { + sql := fmt.Sprintf( + "update %s set value = %s where %s = %s;", + quoteName(dbType, SettingsTable), + placeHolder(dbType, 1), + quoteName(dbType, "key"), + placeHolder(dbType, 2), + ) stmt, err := db.Prepare(sql) if checkError(err, "Fail to prepare sql") { return err @@ -320,23 +336,32 @@ func updateSetting(db *sql.DB, key string, value string) error { } func HadSetting(db *sql.DB) bool { - key := GetSetting(db, "Key") + dbType, err := GetDBType(db) + if checkError(err, "Fail to get db type") { + return false + } + key := GetSetting(db, dbType, "Key") if key == "" { return false } return true } -func ContainKey(db *sql.DB, key string) bool { - sql := fmt.Sprintf("select value from \"%s\" where key = '%s';", SettingsTable, key) +func ContainKey(db *sql.DB, dbType string, key string) bool { + sql := fmt.Sprintf("select value from %s where %s = '%s';", quoteName(dbType, SettingsTable), quoteName(dbType, "key"), key) value := "" err := db.QueryRow(sql).Scan(&value) - if checkError(err, "Fail to QueryRow for key "+key) { + if checkError(err, "") { return false } return true } -func HadSettingOfKey(db *sql.DB, key string) bool { - return GetSetting(db, "Key") == key +func HadSettingOfKey(db *sql.DB, dbType string, key string) bool { + return GetSetting(db, dbType, "Key") == key +} + +func newSettingsBackend(db *sql.DB, dbType string) settingsBackend { + InitSettingsTable(db, dbType) + return settingsBackend{db: db, dbType: dbType} } diff --git a/storage/sql/share.go b/storage/sql/share.go index d9610da9..b286647b 100644 --- a/storage/sql/share.go +++ b/storage/sql/share.go @@ -8,15 +8,16 @@ import ( ) type shareBackend struct { - db *sql.DB + db *sql.DB + dbType string } type linkRecord interface { Scan(dest ...interface{}) error } -func InitSharesTable(db *sql.DB) error { - sql := fmt.Sprintf("create table if not exists \"%s\" (hash text, path text, userid integer, expire integer, passwordhash text, token text)", SharesTable) +func InitSharesTable(db *sql.DB, dbType string) error { + sql := fmt.Sprintf("create table if not exists %s (hash text, path text, userid integer, expire integer, passwordhash text, token text)", quoteName(dbType, SharesTable)) _, err := db.Exec(sql) checkError(err, "Fail to InitSharesTable") return err @@ -43,8 +44,8 @@ func parseLink(row linkRecord) (*share.Link, error) { return &link, nil } -func queryLinks(db *sql.DB, condition string) ([]*share.Link, error) { - sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\"", SharesTable) +func queryLinks(db *sql.DB, dbType string, condition string) ([]*share.Link, error) { + sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s", quoteName(dbType, SharesTable)) if len(condition) > 0 { sql = sql + " where " + condition } @@ -64,37 +65,42 @@ func queryLinks(db *sql.DB, condition string) ([]*share.Link, error) { } func (s shareBackend) All() ([]*share.Link, error) { - return queryLinks(s.db, "") + return queryLinks(s.db, s.dbType, "") } func (s shareBackend) FindByUserID(id uint) ([]*share.Link, error) { condition := fmt.Sprintf("userid=%d", id) - return queryLinks(s.db, condition) + return queryLinks(s.db, s.dbType, condition) } func (s shareBackend) GetByHash(hash string) (*share.Link, error) { - sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\" where hash='%s'", SharesTable, hash) + sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s where hash='%s'", quoteName(s.dbType, SharesTable), hash) return parseLink(s.db.QueryRow(sql)) } func (s shareBackend) GetPermanent(path string, id uint) (*share.Link, error) { - sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\" where path='%s' and userid=%d", SharesTable, path, id) + sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s where path='%s' and userid=%d", quoteName(s.dbType, SharesTable), path, id) return parseLink(s.db.QueryRow(sql)) } func (s shareBackend) Gets(path string, id uint) ([]*share.Link, error) { condition := fmt.Sprintf("userid=%d and path='%s'", id, path) - return queryLinks(s.db, condition) + return queryLinks(s.db, s.dbType, condition) } func (s shareBackend) Save(l *share.Link) error { - sql := fmt.Sprintf("insert into \"%s\" (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", SharesTable, l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token) + sql := fmt.Sprintf("insert into %s (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", quoteName(s.dbType, SharesTable), l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token) _, err := s.db.Exec(sql) checkError(err, "Fail to Save share") return err } func (s shareBackend) Delete(hash string) error { - sql := fmt.Sprintf("DELETE FROM \"%s\" WHERE hash='%s'", SharesTable, hash) + sql := fmt.Sprintf("DELETE FROM %s WHERE hash='%s'", quoteName(s.dbType, SharesTable), hash) _, err := s.db.Exec(sql) checkError(err, "Fail to Delete share") return err } + +func newShareBackend(db *sql.DB, dbType string) shareBackend { + InitSharesTable(db, dbType) + return shareBackend{db: db, dbType: dbType} +} diff --git a/storage/sql/sql.go b/storage/sql/sql.go index 8a5cabc6..aaaf72f4 100644 --- a/storage/sql/sql.go +++ b/storage/sql/sql.go @@ -3,6 +3,9 @@ package sql import ( "database/sql" "errors" + "fmt" + "regexp" + "strconv" "strings" "github.com/filebrowser/filebrowser/v2/auth" @@ -15,6 +18,58 @@ import ( _ "github.com/mattn/go-sqlite3" ) +func init() { +} + +type DBConnectionRecord struct { + db *sql.DB + dbType string + path string +} + +var ( + dbRecords map[string]*DBConnectionRecord = map[string]*DBConnectionRecord{} +) + +// GetDBType used to get the driver type of a sql.DB +// It is based on existing dbRecords +// All sql.DB should opened by OpenDB +func GetDBType(db *sql.DB) (string, error) { + for _, record := range dbRecords { + if record.db == db { + return record.dbType, nil + } + } + return "", errors.New("No such database open by this module") +} + +func getNameQuote(dbType string) string { + if dbType == "mysql" { + return "`" + } + return "\"" +} + +// for mysql, it is `` +// for postgres and sqlite, it is "" +func quoteName(dbType string, name string) string { + q := getNameQuote(dbType) + return q + name + q +} + +// placeholder for sql stmt +// for postgres, it is $1, $2, $3... +// for mysql and sqlite3, it is ?,?,?... +func placeHolder(dbType string, index int) string { + if index <= 0 { + panic("the placeholder index should >= 1") + } + if dbType == "postgres" { + return fmt.Sprintf("$%d", index) + } + return "?" +} + func IsDBPath(path string) bool { prefixes := []string{"sqlite3", "postgres", "mysql"} for _, prefix := range prefixes { @@ -26,18 +81,93 @@ func IsDBPath(path string) bool { } func OpenDB(path string) (*sql.DB, error) { + if val, ok := dbRecords[path]; ok { + return val.db, nil + } prefixes := []string{"sqlite3", "postgres", "mysql"} for _, prefix := range prefixes { if strings.HasPrefix(path, prefix) { - return connectDB(prefix, path) + db, err := connectDB(prefix, path) + if !checkError(err, "Fail to connect database "+path) { + dbRecords[path] = &DBConnectionRecord{db: db, dbType: prefix, path: path} + } + return db, err } } return nil, errors.New("Unsupported db scheme") } +type DatabaseResource struct { + scheme string + username string + password string + host string + port int + database string +} + +func ParseDatabasePath(path string) (*DatabaseResource, error) { + pattern := "^(([a-zA-Z0-9]+)://)?(([^:]+)(:(.*))?@)?([a-zA-Z0-9_.]+)(:([0-9]+))?(/([a-zA-Z0-9_-]+))?$" + reg, err := regexp.Compile(pattern) + if checkError(err, "Fail to compile regexp") { + return nil, err + } + matches := reg.FindAllStringSubmatch(path, -1) + if matches == nil || len(matches) == 0 { + return nil, errors.New("Fail to parse database") + } + r := DatabaseResource{} + r.scheme = matches[0][2] + r.username = matches[0][4] + r.password = matches[0][6] + r.host = matches[0][7] + if len(matches[0][9]) > 0 { + port, err := strconv.Atoi(matches[0][9]) + if !checkError(err, "Fail to parse port") { + r.port = port + } + } + r.database = matches[0][11] + return &r, nil +} + +// mysql://user:password@host:port/db => mysql://user:password@tcp(host:port)/db +func transformMysqlPath(path string) (string, error) { + r, err := ParseDatabasePath(path) + if checkError(err, "Fail to parse database path") { + return "", err + } + scheme := r.scheme + if len(scheme) == 0 { + scheme = "mysql" + } + credential := "" + if len(r.username) > 0 && len(r.password) > 0 { + credential = r.username + ":" + r.password + "@" + } else if len(r.username) > 0 { + credential = r.username + "@" + } + host := r.host + port := r.port + if port == 0 { + port = 3306 + } + if len(r.database) == 0 { + return "", errors.New("no database found in path") + } + return fmt.Sprintf("%s://%stcp(%s:%d)/%s", scheme, credential, host, port, r.database), nil +} + func connectDB(dbType string, path string) (*sql.DB, error) { if dbType == "sqlite3" && strings.HasPrefix(path, "sqlite3://") { path = strings.TrimPrefix(path, "sqlite3://") + } else if dbType == "mysql" && strings.HasPrefix(path, "mysql://") { + p, err := transformMysqlPath(path) + if checkError(err, "Fail to parse mysql path") { + return nil, err + } + path = p + path = strings.TrimPrefix(path, "mysql://") } db, err := sql.Open(dbType, path) if err == nil { @@ -47,20 +177,13 @@ func connectDB(dbType string, path string) (*sql.DB, error) { } func NewStorage(db *sql.DB) (*storage.Storage, error) { + dbType, err := GetDBType(db) + checkError(err, "Fail to get database type, maybe this sql.DB is not opened by OpenDB") - InitUserTable(db) - InitSharesTable(db) - InitSettingsTable(db) - - userStore := users.NewStorage(usersBackend{db: db}) - shareStore := share.NewStorage(shareBackend{db: db}) - settingsStore := settings.NewStorage(settingsBackend{db: db}) - authStore := auth.NewStorage(authBackend{db: db}, userStore) - - err := SetSetting(db, "version", "2") - if checkError(err, "Fail to set version") { - return nil, err - } + userStore := users.NewStorage(newUsersBackend(db, dbType)) + shareStore := share.NewStorage(newShareBackend(db, dbType)) + settingsStore := settings.NewStorage(newSettingsBackend(db, dbType)) + authStore := auth.NewStorage(newAuthBackend(db, dbType), userStore) storage := &storage.Storage{ Auth: authStore, diff --git a/storage/sql/users.go b/storage/sql/users.go index 88f3d5d4..1fd4b56e 100644 --- a/storage/sql/users.go +++ b/storage/sql/users.go @@ -14,7 +14,8 @@ import ( ) type usersBackend struct { - db *sql.DB + db *sql.DB + dbType string } func PermFromString(s string) users.Permissions { @@ -114,13 +115,24 @@ func createAdminUser() users.User { return user } -func InitUserTable(db *sql.DB) error { - sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS \"%s\" (id integer primary key, username text, password text, scope text, locale text, lockpassword integer, viewmode text, perm text, commands text, sorting text, rules text, hidedotfiles integer, dateformat integer, singleclick integer);", UsersTable) +func InitUserTable(db *sql.DB, dbType string) error { + primaryKey := "integer primary key" + if dbType == "postgres" { + primaryKey = "serial primary key" + } else if dbType == "mysql" { + primaryKey = "int unsigned primary key auto_increment" + } + sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id %s, username text, password text, scope text, locale text, lockpassword integer, viewmode text, perm text, commands text, sorting text, rules text, hidedotfiles integer, dateformat integer, singleclick integer);", quoteName(dbType, UsersTable), primaryKey) _, err := db.Exec(sql) checkError(err, "Fail to create users table") return err } +func newUsersBackend(db *sql.DB, dbType string) usersBackend { + InitUserTable(db, dbType) + return usersBackend{db: db, dbType: dbType} +} + func (s usersBackend) GetBy(i interface{}) (*users.User, error) { columns := []string{"id", "username", "password", "scope", "locale", "lockpassword", "viewmode", "perm", "commands", "sorting", "rules", "hidedotfiles", "dateformat", "singleclick"} columnsStr := strings.Join(columns, ",") @@ -150,7 +162,7 @@ func (s usersBackend) GetBy(i interface{}) (*users.User, error) { dateformat := false singleclick := false user := users.User{} - sql := fmt.Sprintf("SELECT %s FROM \"%s\" WHERE %s", columnsStr, UsersTable, conditionStr) + sql := fmt.Sprintf("SELECT %s FROM %s WHERE %s", columnsStr, quoteName(s.dbType, UsersTable), conditionStr) err := s.db.QueryRow(sql).Scan(&userID, &username, &password, &scope, &locale, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules, &hidedotfiles, &dateformat, &singleclick) if checkError(err, "") { return nil, err @@ -173,7 +185,7 @@ func (s usersBackend) GetBy(i interface{}) (*users.User, error) { } func (s usersBackend) Gets() ([]*users.User, error) { - sql := fmt.Sprintf("SELECT id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules FROM \"%s\"", UsersTable) + sql := fmt.Sprintf("SELECT id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules FROM %s", quoteName(s.dbType, UsersTable)) rows, err := s.db.Query(sql) if checkError(err, "Fail to Query []*users.User") { return nil, err @@ -217,8 +229,8 @@ func (s usersBackend) updateUser(id uint, user *users.User) error { lockpassword = 1 } sql := fmt.Sprintf( - "UPDATE \"%s\" SET username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' WHERE id=%d", - UsersTable, + "UPDATE %s SET username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' WHERE id=%d", + quoteName(s.dbType, UsersTable), user.Username, user.Password, user.Scope, @@ -259,7 +271,10 @@ func (s usersBackend) insertUser(user *users.User) error { } columnStr := strings.Join(columns, ",") specStr := strings.Join(specs, ",") - sqlFormat := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)", UsersTable, columnStr, specStr) + sqlFormat := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoteName(s.dbType, UsersTable), columnStr, specStr) + if s.dbType == "postgres" { + sqlFormat = sqlFormat + " RETURNING id;" + } sql := fmt.Sprintf( sqlFormat, user.Username, @@ -276,12 +291,19 @@ func (s usersBackend) insertUser(user *users.User) error { boolToString(user.DateFormat), boolToString(user.SingleClick), ) + if s.dbType == "postgres" { + id := uint(0) + err := s.db.QueryRow(sql).Scan(&id) + if !checkError(err, "Fail to insert user") { + user.ID = id + } + return err + } res, err := s.db.Exec(sql) if !checkError(err, "Fail to insert user") { - id, err2 := res.LastInsertId() - if !checkError(err2, "Fail to fetch last insert id") { - user.ID = uint(id) - } + id, err := res.LastInsertId() + checkError(err, "Fail to get last inserted id") + user.ID = uint(id) } return err } @@ -296,14 +318,14 @@ func (s usersBackend) Save(user *users.User) error { } func (s usersBackend) DeleteByID(id uint) error { - sql := fmt.Sprintf("delete from \"%s\" where id=%d", UsersTable, id) + sql := fmt.Sprintf("delete from %s where id=%d", quoteName(s.dbType, UsersTable), id) _, err := s.db.Exec(sql) checkError(err, "Fail to delete User by id") return err } func (s usersBackend) DeleteByUsername(username string) error { - sql := fmt.Sprintf("delete from \"%s\" where username='%s'", UsersTable, username) + sql := fmt.Sprintf("delete from %s where username='%s'", quoteName(s.dbType, UsersTable), username) _, err := s.db.Exec(sql) checkError(err, "Fail to delete user by username") return err @@ -323,15 +345,15 @@ func (s usersBackend) Update(u *users.User, fields ...string) error { val := userField.Interface() typeStr := reflect.TypeOf(val).Kind().String() if typeStr == "string" { - setItems = append(setItems, fmt.Sprintf("\"%s\"='%s'", field, val)) + setItems = append(setItems, fmt.Sprintf("%s='%s'", field, val)) } else if typeStr == "bool" { - setItems = append(setItems, fmt.Sprintf("\"%s\"=%s", field, boolToString(val.(bool)))) + setItems = append(setItems, fmt.Sprintf("%s=%s", field, boolToString(val.(bool)))) } else { // TODO - setItems = append(setItems, fmt.Sprintf("\"%s\"=%s", field, val)) + setItems = append(setItems, fmt.Sprintf("%s=%s", field, val)) } } - sql := fmt.Sprintf("UPDATE \"%s\" SET %s WHERE id=%d", UsersTable, strings.Join(setItems, ","), u.ID) + sql := fmt.Sprintf("UPDATE %s SET %s WHERE id=%d", UsersTable, strings.Join(setItems, ","), u.ID) _, err := s.db.Exec(sql) checkError(err, "Fail to update user") return err