feat: test for mysql and postgres

This commit is contained in:
face.wsl 2022-11-26 17:39:46 +08:00
parent 5fdf5d0eca
commit ea02515468
7 changed files with 322 additions and 73 deletions

64
scripts/test-sql.sh Executable file
View File

@ -0,0 +1,64 @@
#!/bin/bash
# stop_docker name
stop_docker() {
sudo docker stop $1 1>/dev/null 2>&1
sudo docker rm $1 1>/dev/null 2>&1
}
# start_mysql name password port
start_mysql() {
stop_docker $1
sudo docker run --rm --name $1 -e MYSQL_ROOT_PASSWORD=$2 -p $3:3306 -d mysql
}
# start_postgres name password port
start_postgres() {
stop_docker $1
sudo docker run --rm --name $1 -e POSTGRES_PASSWORD=$2 -p $3:5432 -d postgres
}
test_sqlite() {
rm -f test.db
./filebrowser -a 0.0.0.0 -d sqlite3://test.db
}
test_postgres() {
start_postgres test-postgres postgres 5433
sleep 30
./filebrowser -a 0.0.0.0 -d postgres://postgres:postgres@127.0.0.1:5433/postgres?sslmode=disable
}
test_mysql() {
start_mysql test-mysql root 3307
sleep 60
./filebrowser -a 0.0.0.0 -d 'mysql://root:root@127.0.0.1:3307/mysql'
}
help() {
echo "USAGE: $0 sqlite|mysql|postgres"
exit 1
}
if (( $# == 0 )); then
help
fi
case $1 in
sqlite)
test_sqlite
;;
mysql)
test_mysql
;;
postgres)
test_postgres
;;
*)
help
esac

View File

@ -11,7 +11,8 @@ import (
) )
type authBackend struct { type authBackend struct {
db *sql.DB db *sql.DB
dbType string
} }
func (s authBackend) Get(t settings.AuthMethod) (auth.Auther, error) { func (s authBackend) Get(t settings.AuthMethod) (auth.Auther, error) {
@ -38,5 +39,9 @@ func (s authBackend) Save(a auth.Auther) error {
if checkError(err, "Fail to save auth.Auther") { if checkError(err, "Fail to save auth.Auther") {
return err return err
} }
return SetSetting(s.db, "auther", string(val)) return SetSetting(s.db, s.dbType, "auther", string(val))
}
func newAuthBackend(db *sql.DB, dbType string) authBackend {
return authBackend{db: db, dbType: dbType}
} }

View File

@ -28,7 +28,7 @@ func cloneServer(server settings.Server) settings.Server {
} }
func (s settingsBackend) GetServer() (*settings.Server, error) { func (s settingsBackend) GetServer() (*settings.Server, error) {
sql := fmt.Sprintf("select key, value from %s", SettingsTable) sql := fmt.Sprintf("select %s, value from %s", quoteName(s.dbType, "key"), quoteName(s.dbType, SettingsTable))
rows, err := s.db.Query(sql) rows, err := s.db.Query(sql)
if checkError(err, "Fail to Query for GetServer") { if checkError(err, "Fail to Query for GetServer") {
return nil, err return nil, err
@ -93,7 +93,11 @@ func (s settingsBackend) SaveServer(ss *settings.Server) error {
if checkError(err, "Fail to begin db transaction") { if checkError(err, "Fail to begin db transaction") {
return err return err
} }
sql := fmt.Sprintf("INSERT INTO \"%s\" (key, value) VALUES($1,$2)", SettingsTable) table := quoteName(s.dbType, SettingsTable)
k := quoteName(s.dbType, "key")
p1 := placeHolder(s.dbType, 1)
p2 := placeHolder(s.dbType, 2)
sql := fmt.Sprintf("INSERT INTO %s (%s, value) VALUES(%s,%s)", table, k, p1, p2)
for i, field := range fields { for i, field := range fields {
stmt, err := s.db.Prepare(sql) stmt, err := s.db.Prepare(sql)
defer stmt.Close() defer stmt.Close()

View File

@ -17,11 +17,12 @@ func init() {
} }
type settingsBackend struct { type settingsBackend struct {
db *sql.DB db *sql.DB
dbType string
} }
func InitSettingsTable(db *sql.DB) error { func InitSettingsTable(db *sql.DB, dbType string) error {
sql := fmt.Sprintf("create table if not exists \"%s\"(key text primary key, value text);", SettingsTable) sql := fmt.Sprintf("create table if not exists %s (%s varchar(128) primary key, value text);", quoteName(dbType, SettingsTable), quoteName(dbType, "key"))
_, err := db.Exec(sql) _, err := db.Exec(sql)
checkError(err, "Fail to create table settings") checkError(err, "Fail to create table settings")
return err return err
@ -136,7 +137,7 @@ func boolToString(b bool) string {
} }
func (s settingsBackend) Get() (*settings.Settings, error) { func (s settingsBackend) Get() (*settings.Settings, error) {
sql := fmt.Sprintf("select key, value from \"%s\";", SettingsTable) sql := fmt.Sprintf("select %s, value from %s;", quoteName(s.dbType, "key"), quoteName(s.dbType, SettingsTable))
rows, err := s.db.Query(sql) rows, err := s.db.Query(sql)
if checkError(err, "Fail to Query settings.Settings") { if checkError(err, "Fail to Query settings.Settings") {
return nil, err return nil, err
@ -196,13 +197,16 @@ func (s settingsBackend) Save(ss *settings.Settings) error {
if checkError(err, "Fail to begin db transaction") { if checkError(err, "Fail to begin db transaction") {
return err return err
} }
table := quoteName(s.dbType, SettingsTable)
k := quoteName(s.dbType, "key")
p1 := placeHolder(s.dbType, 1)
p2 := placeHolder(s.dbType, 2)
for i, field := range fields { for i, field := range fields {
exists := ContainKey(s.db, field) exists := ContainKey(s.db, s.dbType, field)
sql := fmt.Sprintf("INSERT INTO \"%s\" (value, key) VALUES($1,$2);", SettingsTable) sql := fmt.Sprintf("INSERT INTO %s (value, %s) VALUES(%s,%s);", table, k, p1, p2)
if exists { if exists {
sql = fmt.Sprintf("UPDATE \"%s\" set value = $1 where key = $2;", SettingsTable) sql = fmt.Sprintf("UPDATE %s set value = %s where %s = %s;", table, p1, k, p2)
} }
fmt.Println(sql)
stmt, err := s.db.Prepare(sql) stmt, err := s.db.Prepare(sql)
defer stmt.Close() defer stmt.Close()
if checkError(err, "Fail to prepare statement") { if checkError(err, "Fail to prepare statement") {
@ -274,31 +278,37 @@ func cloneSettings(s settings.Settings) settings.Settings {
return s1 return s1
} }
func SetSetting(db *sql.DB, key string, value string) error { func SetSetting(db *sql.DB, dbType string, key string, value string) error {
sql := fmt.Sprintf("select count(key) from \"%s\" where key = '%s';", SettingsTable, key) t := quoteName(dbType, SettingsTable)
k := quoteName(dbType, "key")
sql := fmt.Sprintf("select count(%s) from %s where %s = '%s';", k, t, k, key)
count := 0 count := 0
err := db.QueryRow(sql).Scan(&count) err := db.QueryRow(sql).Scan(&count)
if checkError(err, "Fail to QueryRow for key="+key) { if checkError(err, "Fail to QueryRow for key="+key) {
return err return err
} }
if count == 0 { if count == 0 {
return addSetting(db, key, value) return addSetting(db, dbType, key, value)
} }
return updateSetting(db, key, value) return updateSetting(db, dbType, key, value)
} }
func GetSetting(db *sql.DB, key string) string { func GetSetting(db *sql.DB, dbType string, key string) string {
sql := fmt.Sprintf("select value from \"%s\" where key = '%s';", SettingsTable, key) sql := fmt.Sprintf("select value from %s where %s = '%s';", quoteName(dbType, SettingsTable), quoteName(dbType, "key"), key)
value := "" value := ""
err := db.QueryRow(sql).Scan(&value) err := db.QueryRow(sql).Scan(&value)
if checkError(err, "Fail to QueryRow for key "+key) { if checkError(err, "") {
return value return value
} }
return value return value
} }
func addSetting(db *sql.DB, key string, value string) error { func addSetting(db *sql.DB, dbType string, key string, value string) error {
sql := fmt.Sprintf("insert into \"%s\" (key, value) values($1, $2);", SettingsTable) table := quoteName(dbType, SettingsTable)
k := quoteName(dbType, "key")
p1 := placeHolder(dbType, 1)
p2 := placeHolder(dbType, 2)
sql := fmt.Sprintf("insert into %s (%s, value) values(%s, %s);", table, k, p1, p2)
stmt, err := db.Prepare(sql) stmt, err := db.Prepare(sql)
if checkError(err, "Fail to prepare sql") { if checkError(err, "Fail to prepare sql") {
return err return err
@ -308,8 +318,14 @@ func addSetting(db *sql.DB, key string, value string) error {
return err return err
} }
func updateSetting(db *sql.DB, key string, value string) error { func updateSetting(db *sql.DB, dbType string, key string, value string) error {
sql := fmt.Sprintf("update \"%s\" set value = $1 where key = $2;", SettingsTable) sql := fmt.Sprintf(
"update %s set value = %s where %s = %s;",
quoteName(dbType, SettingsTable),
placeHolder(dbType, 1),
quoteName(dbType, "key"),
placeHolder(dbType, 2),
)
stmt, err := db.Prepare(sql) stmt, err := db.Prepare(sql)
if checkError(err, "Fail to prepare sql") { if checkError(err, "Fail to prepare sql") {
return err return err
@ -320,23 +336,32 @@ func updateSetting(db *sql.DB, key string, value string) error {
} }
func HadSetting(db *sql.DB) bool { func HadSetting(db *sql.DB) bool {
key := GetSetting(db, "Key") dbType, err := GetDBType(db)
if checkError(err, "Fail to get db type") {
return false
}
key := GetSetting(db, dbType, "Key")
if key == "" { if key == "" {
return false return false
} }
return true return true
} }
func ContainKey(db *sql.DB, key string) bool { func ContainKey(db *sql.DB, dbType string, key string) bool {
sql := fmt.Sprintf("select value from \"%s\" where key = '%s';", SettingsTable, key) sql := fmt.Sprintf("select value from %s where %s = '%s';", quoteName(dbType, SettingsTable), quoteName(dbType, "key"), key)
value := "" value := ""
err := db.QueryRow(sql).Scan(&value) err := db.QueryRow(sql).Scan(&value)
if checkError(err, "Fail to QueryRow for key "+key) { if checkError(err, "") {
return false return false
} }
return true return true
} }
func HadSettingOfKey(db *sql.DB, key string) bool { func HadSettingOfKey(db *sql.DB, dbType string, key string) bool {
return GetSetting(db, "Key") == key return GetSetting(db, dbType, "Key") == key
}
func newSettingsBackend(db *sql.DB, dbType string) settingsBackend {
InitSettingsTable(db, dbType)
return settingsBackend{db: db, dbType: dbType}
} }

View File

@ -8,15 +8,16 @@ import (
) )
type shareBackend struct { type shareBackend struct {
db *sql.DB db *sql.DB
dbType string
} }
type linkRecord interface { type linkRecord interface {
Scan(dest ...interface{}) error Scan(dest ...interface{}) error
} }
func InitSharesTable(db *sql.DB) error { func InitSharesTable(db *sql.DB, dbType string) error {
sql := fmt.Sprintf("create table if not exists \"%s\" (hash text, path text, userid integer, expire integer, passwordhash text, token text)", SharesTable) sql := fmt.Sprintf("create table if not exists %s (hash text, path text, userid integer, expire integer, passwordhash text, token text)", quoteName(dbType, SharesTable))
_, err := db.Exec(sql) _, err := db.Exec(sql)
checkError(err, "Fail to InitSharesTable") checkError(err, "Fail to InitSharesTable")
return err return err
@ -43,8 +44,8 @@ func parseLink(row linkRecord) (*share.Link, error) {
return &link, nil return &link, nil
} }
func queryLinks(db *sql.DB, condition string) ([]*share.Link, error) { func queryLinks(db *sql.DB, dbType string, condition string) ([]*share.Link, error) {
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\"", SharesTable) sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s", quoteName(dbType, SharesTable))
if len(condition) > 0 { if len(condition) > 0 {
sql = sql + " where " + condition sql = sql + " where " + condition
} }
@ -64,37 +65,42 @@ func queryLinks(db *sql.DB, condition string) ([]*share.Link, error) {
} }
func (s shareBackend) All() ([]*share.Link, error) { func (s shareBackend) All() ([]*share.Link, error) {
return queryLinks(s.db, "") return queryLinks(s.db, s.dbType, "")
} }
func (s shareBackend) FindByUserID(id uint) ([]*share.Link, error) { func (s shareBackend) FindByUserID(id uint) ([]*share.Link, error) {
condition := fmt.Sprintf("userid=%d", id) condition := fmt.Sprintf("userid=%d", id)
return queryLinks(s.db, condition) return queryLinks(s.db, s.dbType, condition)
} }
func (s shareBackend) GetByHash(hash string) (*share.Link, error) { func (s shareBackend) GetByHash(hash string) (*share.Link, error) {
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\" where hash='%s'", SharesTable, hash) sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s where hash='%s'", quoteName(s.dbType, SharesTable), hash)
return parseLink(s.db.QueryRow(sql)) return parseLink(s.db.QueryRow(sql))
} }
func (s shareBackend) GetPermanent(path string, id uint) (*share.Link, error) { func (s shareBackend) GetPermanent(path string, id uint) (*share.Link, error) {
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\" where path='%s' and userid=%d", SharesTable, path, id) sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s where path='%s' and userid=%d", quoteName(s.dbType, SharesTable), path, id)
return parseLink(s.db.QueryRow(sql)) return parseLink(s.db.QueryRow(sql))
} }
func (s shareBackend) Gets(path string, id uint) ([]*share.Link, error) { func (s shareBackend) Gets(path string, id uint) ([]*share.Link, error) {
condition := fmt.Sprintf("userid=%d and path='%s'", id, path) condition := fmt.Sprintf("userid=%d and path='%s'", id, path)
return queryLinks(s.db, condition) return queryLinks(s.db, s.dbType, condition)
} }
func (s shareBackend) Save(l *share.Link) error { func (s shareBackend) Save(l *share.Link) error {
sql := fmt.Sprintf("insert into \"%s\" (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", SharesTable, l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token) sql := fmt.Sprintf("insert into %s (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", quoteName(s.dbType, SharesTable), l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token)
_, err := s.db.Exec(sql) _, err := s.db.Exec(sql)
checkError(err, "Fail to Save share") checkError(err, "Fail to Save share")
return err return err
} }
func (s shareBackend) Delete(hash string) error { func (s shareBackend) Delete(hash string) error {
sql := fmt.Sprintf("DELETE FROM \"%s\" WHERE hash='%s'", SharesTable, hash) sql := fmt.Sprintf("DELETE FROM %s WHERE hash='%s'", quoteName(s.dbType, SharesTable), hash)
_, err := s.db.Exec(sql) _, err := s.db.Exec(sql)
checkError(err, "Fail to Delete share") checkError(err, "Fail to Delete share")
return err return err
} }
func newShareBackend(db *sql.DB, dbType string) shareBackend {
InitSharesTable(db, dbType)
return shareBackend{db: db, dbType: dbType}
}

View File

@ -3,6 +3,9 @@ package sql
import ( import (
"database/sql" "database/sql"
"errors" "errors"
"fmt"
"regexp"
"strconv"
"strings" "strings"
"github.com/filebrowser/filebrowser/v2/auth" "github.com/filebrowser/filebrowser/v2/auth"
@ -15,6 +18,58 @@ import (
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
) )
func init() {
}
type DBConnectionRecord struct {
db *sql.DB
dbType string
path string
}
var (
dbRecords map[string]*DBConnectionRecord = map[string]*DBConnectionRecord{}
)
// GetDBType used to get the driver type of a sql.DB
// It is based on existing dbRecords
// All sql.DB should opened by OpenDB
func GetDBType(db *sql.DB) (string, error) {
for _, record := range dbRecords {
if record.db == db {
return record.dbType, nil
}
}
return "", errors.New("No such database open by this module")
}
func getNameQuote(dbType string) string {
if dbType == "mysql" {
return "`"
}
return "\""
}
// for mysql, it is ``
// for postgres and sqlite, it is ""
func quoteName(dbType string, name string) string {
q := getNameQuote(dbType)
return q + name + q
}
// placeholder for sql stmt
// for postgres, it is $1, $2, $3...
// for mysql and sqlite3, it is ?,?,?...
func placeHolder(dbType string, index int) string {
if index <= 0 {
panic("the placeholder index should >= 1")
}
if dbType == "postgres" {
return fmt.Sprintf("$%d", index)
}
return "?"
}
func IsDBPath(path string) bool { func IsDBPath(path string) bool {
prefixes := []string{"sqlite3", "postgres", "mysql"} prefixes := []string{"sqlite3", "postgres", "mysql"}
for _, prefix := range prefixes { for _, prefix := range prefixes {
@ -26,18 +81,93 @@ func IsDBPath(path string) bool {
} }
func OpenDB(path string) (*sql.DB, error) { func OpenDB(path string) (*sql.DB, error) {
if val, ok := dbRecords[path]; ok {
return val.db, nil
}
prefixes := []string{"sqlite3", "postgres", "mysql"} prefixes := []string{"sqlite3", "postgres", "mysql"}
for _, prefix := range prefixes { for _, prefix := range prefixes {
if strings.HasPrefix(path, prefix) { if strings.HasPrefix(path, prefix) {
return connectDB(prefix, path) db, err := connectDB(prefix, path)
if !checkError(err, "Fail to connect database "+path) {
dbRecords[path] = &DBConnectionRecord{db: db, dbType: prefix, path: path}
}
return db, err
} }
} }
return nil, errors.New("Unsupported db scheme") return nil, errors.New("Unsupported db scheme")
} }
type DatabaseResource struct {
scheme string
username string
password string
host string
port int
database string
}
func ParseDatabasePath(path string) (*DatabaseResource, error) {
pattern := "^(([a-zA-Z0-9]+)://)?(([^:]+)(:(.*))?@)?([a-zA-Z0-9_.]+)(:([0-9]+))?(/([a-zA-Z0-9_-]+))?$"
reg, err := regexp.Compile(pattern)
if checkError(err, "Fail to compile regexp") {
return nil, err
}
matches := reg.FindAllStringSubmatch(path, -1)
if matches == nil || len(matches) == 0 {
return nil, errors.New("Fail to parse database")
}
r := DatabaseResource{}
r.scheme = matches[0][2]
r.username = matches[0][4]
r.password = matches[0][6]
r.host = matches[0][7]
if len(matches[0][9]) > 0 {
port, err := strconv.Atoi(matches[0][9])
if !checkError(err, "Fail to parse port") {
r.port = port
}
}
r.database = matches[0][11]
return &r, nil
}
// mysql://user:password@host:port/db => mysql://user:password@tcp(host:port)/db
func transformMysqlPath(path string) (string, error) {
r, err := ParseDatabasePath(path)
if checkError(err, "Fail to parse database path") {
return "", err
}
scheme := r.scheme
if len(scheme) == 0 {
scheme = "mysql"
}
credential := ""
if len(r.username) > 0 && len(r.password) > 0 {
credential = r.username + ":" + r.password + "@"
} else if len(r.username) > 0 {
credential = r.username + "@"
}
host := r.host
port := r.port
if port == 0 {
port = 3306
}
if len(r.database) == 0 {
return "", errors.New("no database found in path")
}
return fmt.Sprintf("%s://%stcp(%s:%d)/%s", scheme, credential, host, port, r.database), nil
}
func connectDB(dbType string, path string) (*sql.DB, error) { func connectDB(dbType string, path string) (*sql.DB, error) {
if dbType == "sqlite3" && strings.HasPrefix(path, "sqlite3://") { if dbType == "sqlite3" && strings.HasPrefix(path, "sqlite3://") {
path = strings.TrimPrefix(path, "sqlite3://") path = strings.TrimPrefix(path, "sqlite3://")
} else if dbType == "mysql" && strings.HasPrefix(path, "mysql://") {
p, err := transformMysqlPath(path)
if checkError(err, "Fail to parse mysql path") {
return nil, err
}
path = p
path = strings.TrimPrefix(path, "mysql://")
} }
db, err := sql.Open(dbType, path) db, err := sql.Open(dbType, path)
if err == nil { if err == nil {
@ -47,20 +177,13 @@ func connectDB(dbType string, path string) (*sql.DB, error) {
} }
func NewStorage(db *sql.DB) (*storage.Storage, error) { func NewStorage(db *sql.DB) (*storage.Storage, error) {
dbType, err := GetDBType(db)
checkError(err, "Fail to get database type, maybe this sql.DB is not opened by OpenDB")
InitUserTable(db) userStore := users.NewStorage(newUsersBackend(db, dbType))
InitSharesTable(db) shareStore := share.NewStorage(newShareBackend(db, dbType))
InitSettingsTable(db) settingsStore := settings.NewStorage(newSettingsBackend(db, dbType))
authStore := auth.NewStorage(newAuthBackend(db, dbType), userStore)
userStore := users.NewStorage(usersBackend{db: db})
shareStore := share.NewStorage(shareBackend{db: db})
settingsStore := settings.NewStorage(settingsBackend{db: db})
authStore := auth.NewStorage(authBackend{db: db}, userStore)
err := SetSetting(db, "version", "2")
if checkError(err, "Fail to set version") {
return nil, err
}
storage := &storage.Storage{ storage := &storage.Storage{
Auth: authStore, Auth: authStore,

View File

@ -14,7 +14,8 @@ import (
) )
type usersBackend struct { type usersBackend struct {
db *sql.DB db *sql.DB
dbType string
} }
func PermFromString(s string) users.Permissions { func PermFromString(s string) users.Permissions {
@ -114,13 +115,24 @@ func createAdminUser() users.User {
return user return user
} }
func InitUserTable(db *sql.DB) error { func InitUserTable(db *sql.DB, dbType string) error {
sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS \"%s\" (id integer primary key, username text, password text, scope text, locale text, lockpassword integer, viewmode text, perm text, commands text, sorting text, rules text, hidedotfiles integer, dateformat integer, singleclick integer);", UsersTable) primaryKey := "integer primary key"
if dbType == "postgres" {
primaryKey = "serial primary key"
} else if dbType == "mysql" {
primaryKey = "int unsigned primary key auto_increment"
}
sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id %s, username text, password text, scope text, locale text, lockpassword integer, viewmode text, perm text, commands text, sorting text, rules text, hidedotfiles integer, dateformat integer, singleclick integer);", quoteName(dbType, UsersTable), primaryKey)
_, err := db.Exec(sql) _, err := db.Exec(sql)
checkError(err, "Fail to create users table") checkError(err, "Fail to create users table")
return err return err
} }
func newUsersBackend(db *sql.DB, dbType string) usersBackend {
InitUserTable(db, dbType)
return usersBackend{db: db, dbType: dbType}
}
func (s usersBackend) GetBy(i interface{}) (*users.User, error) { func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
columns := []string{"id", "username", "password", "scope", "locale", "lockpassword", "viewmode", "perm", "commands", "sorting", "rules", "hidedotfiles", "dateformat", "singleclick"} columns := []string{"id", "username", "password", "scope", "locale", "lockpassword", "viewmode", "perm", "commands", "sorting", "rules", "hidedotfiles", "dateformat", "singleclick"}
columnsStr := strings.Join(columns, ",") columnsStr := strings.Join(columns, ",")
@ -150,7 +162,7 @@ func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
dateformat := false dateformat := false
singleclick := false singleclick := false
user := users.User{} user := users.User{}
sql := fmt.Sprintf("SELECT %s FROM \"%s\" WHERE %s", columnsStr, UsersTable, conditionStr) sql := fmt.Sprintf("SELECT %s FROM %s WHERE %s", columnsStr, quoteName(s.dbType, UsersTable), conditionStr)
err := s.db.QueryRow(sql).Scan(&userID, &username, &password, &scope, &locale, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules, &hidedotfiles, &dateformat, &singleclick) err := s.db.QueryRow(sql).Scan(&userID, &username, &password, &scope, &locale, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules, &hidedotfiles, &dateformat, &singleclick)
if checkError(err, "") { if checkError(err, "") {
return nil, err return nil, err
@ -173,7 +185,7 @@ func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
} }
func (s usersBackend) Gets() ([]*users.User, error) { func (s usersBackend) Gets() ([]*users.User, error) {
sql := fmt.Sprintf("SELECT id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules FROM \"%s\"", UsersTable) sql := fmt.Sprintf("SELECT id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules FROM %s", quoteName(s.dbType, UsersTable))
rows, err := s.db.Query(sql) rows, err := s.db.Query(sql)
if checkError(err, "Fail to Query []*users.User") { if checkError(err, "Fail to Query []*users.User") {
return nil, err return nil, err
@ -217,8 +229,8 @@ func (s usersBackend) updateUser(id uint, user *users.User) error {
lockpassword = 1 lockpassword = 1
} }
sql := fmt.Sprintf( sql := fmt.Sprintf(
"UPDATE \"%s\" SET username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' WHERE id=%d", "UPDATE %s SET username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' WHERE id=%d",
UsersTable, quoteName(s.dbType, UsersTable),
user.Username, user.Username,
user.Password, user.Password,
user.Scope, user.Scope,
@ -259,7 +271,10 @@ func (s usersBackend) insertUser(user *users.User) error {
} }
columnStr := strings.Join(columns, ",") columnStr := strings.Join(columns, ",")
specStr := strings.Join(specs, ",") specStr := strings.Join(specs, ",")
sqlFormat := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)", UsersTable, columnStr, specStr) sqlFormat := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoteName(s.dbType, UsersTable), columnStr, specStr)
if s.dbType == "postgres" {
sqlFormat = sqlFormat + " RETURNING id;"
}
sql := fmt.Sprintf( sql := fmt.Sprintf(
sqlFormat, sqlFormat,
user.Username, user.Username,
@ -276,12 +291,19 @@ func (s usersBackend) insertUser(user *users.User) error {
boolToString(user.DateFormat), boolToString(user.DateFormat),
boolToString(user.SingleClick), boolToString(user.SingleClick),
) )
if s.dbType == "postgres" {
id := uint(0)
err := s.db.QueryRow(sql).Scan(&id)
if !checkError(err, "Fail to insert user") {
user.ID = id
}
return err
}
res, err := s.db.Exec(sql) res, err := s.db.Exec(sql)
if !checkError(err, "Fail to insert user") { if !checkError(err, "Fail to insert user") {
id, err2 := res.LastInsertId() id, err := res.LastInsertId()
if !checkError(err2, "Fail to fetch last insert id") { checkError(err, "Fail to get last inserted id")
user.ID = uint(id) user.ID = uint(id)
}
} }
return err return err
} }
@ -296,14 +318,14 @@ func (s usersBackend) Save(user *users.User) error {
} }
func (s usersBackend) DeleteByID(id uint) error { func (s usersBackend) DeleteByID(id uint) error {
sql := fmt.Sprintf("delete from \"%s\" where id=%d", UsersTable, id) sql := fmt.Sprintf("delete from %s where id=%d", quoteName(s.dbType, UsersTable), id)
_, err := s.db.Exec(sql) _, err := s.db.Exec(sql)
checkError(err, "Fail to delete User by id") checkError(err, "Fail to delete User by id")
return err return err
} }
func (s usersBackend) DeleteByUsername(username string) error { func (s usersBackend) DeleteByUsername(username string) error {
sql := fmt.Sprintf("delete from \"%s\" where username='%s'", UsersTable, username) sql := fmt.Sprintf("delete from %s where username='%s'", quoteName(s.dbType, UsersTable), username)
_, err := s.db.Exec(sql) _, err := s.db.Exec(sql)
checkError(err, "Fail to delete user by username") checkError(err, "Fail to delete user by username")
return err return err
@ -323,15 +345,15 @@ func (s usersBackend) Update(u *users.User, fields ...string) error {
val := userField.Interface() val := userField.Interface()
typeStr := reflect.TypeOf(val).Kind().String() typeStr := reflect.TypeOf(val).Kind().String()
if typeStr == "string" { if typeStr == "string" {
setItems = append(setItems, fmt.Sprintf("\"%s\"='%s'", field, val)) setItems = append(setItems, fmt.Sprintf("%s='%s'", field, val))
} else if typeStr == "bool" { } else if typeStr == "bool" {
setItems = append(setItems, fmt.Sprintf("\"%s\"=%s", field, boolToString(val.(bool)))) setItems = append(setItems, fmt.Sprintf("%s=%s", field, boolToString(val.(bool))))
} else { } else {
// TODO // TODO
setItems = append(setItems, fmt.Sprintf("\"%s\"=%s", field, val)) setItems = append(setItems, fmt.Sprintf("%s=%s", field, val))
} }
} }
sql := fmt.Sprintf("UPDATE \"%s\" SET %s WHERE id=%d", UsersTable, strings.Join(setItems, ","), u.ID) sql := fmt.Sprintf("UPDATE %s SET %s WHERE id=%d", UsersTable, strings.Join(setItems, ","), u.ID)
_, err := s.db.Exec(sql) _, err := s.db.Exec(sql)
checkError(err, "Fail to update user") checkError(err, "Fail to update user")
return err return err