feat: test for mysql and postgres
This commit is contained in:
parent
5fdf5d0eca
commit
ea02515468
64
scripts/test-sql.sh
Executable file
64
scripts/test-sql.sh
Executable file
@ -0,0 +1,64 @@
|
||||
#!/bin/bash
|
||||
|
||||
# stop_docker name
|
||||
stop_docker() {
|
||||
sudo docker stop $1 1>/dev/null 2>&1
|
||||
sudo docker rm $1 1>/dev/null 2>&1
|
||||
}
|
||||
|
||||
# start_mysql name password port
|
||||
start_mysql() {
|
||||
stop_docker $1
|
||||
sudo docker run --rm --name $1 -e MYSQL_ROOT_PASSWORD=$2 -p $3:3306 -d mysql
|
||||
}
|
||||
|
||||
# start_postgres name password port
|
||||
start_postgres() {
|
||||
stop_docker $1
|
||||
sudo docker run --rm --name $1 -e POSTGRES_PASSWORD=$2 -p $3:5432 -d postgres
|
||||
}
|
||||
|
||||
|
||||
test_sqlite() {
|
||||
rm -f test.db
|
||||
./filebrowser -a 0.0.0.0 -d sqlite3://test.db
|
||||
}
|
||||
|
||||
|
||||
test_postgres() {
|
||||
start_postgres test-postgres postgres 5433
|
||||
sleep 30
|
||||
./filebrowser -a 0.0.0.0 -d postgres://postgres:postgres@127.0.0.1:5433/postgres?sslmode=disable
|
||||
}
|
||||
|
||||
|
||||
test_mysql() {
|
||||
start_mysql test-mysql root 3307
|
||||
sleep 60
|
||||
./filebrowser -a 0.0.0.0 -d 'mysql://root:root@127.0.0.1:3307/mysql'
|
||||
}
|
||||
|
||||
help() {
|
||||
echo "USAGE: $0 sqlite|mysql|postgres"
|
||||
exit 1
|
||||
}
|
||||
|
||||
if (( $# == 0 )); then
|
||||
help
|
||||
fi
|
||||
|
||||
case $1 in
|
||||
sqlite)
|
||||
test_sqlite
|
||||
;;
|
||||
mysql)
|
||||
test_mysql
|
||||
;;
|
||||
postgres)
|
||||
test_postgres
|
||||
;;
|
||||
*)
|
||||
help
|
||||
esac
|
||||
|
||||
|
||||
@ -11,7 +11,8 @@ import (
|
||||
)
|
||||
|
||||
type authBackend struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
dbType string
|
||||
}
|
||||
|
||||
func (s authBackend) Get(t settings.AuthMethod) (auth.Auther, error) {
|
||||
@ -38,5 +39,9 @@ func (s authBackend) Save(a auth.Auther) error {
|
||||
if checkError(err, "Fail to save auth.Auther") {
|
||||
return err
|
||||
}
|
||||
return SetSetting(s.db, "auther", string(val))
|
||||
return SetSetting(s.db, s.dbType, "auther", string(val))
|
||||
}
|
||||
|
||||
func newAuthBackend(db *sql.DB, dbType string) authBackend {
|
||||
return authBackend{db: db, dbType: dbType}
|
||||
}
|
||||
|
||||
@ -28,7 +28,7 @@ func cloneServer(server settings.Server) settings.Server {
|
||||
}
|
||||
|
||||
func (s settingsBackend) GetServer() (*settings.Server, error) {
|
||||
sql := fmt.Sprintf("select key, value from %s", SettingsTable)
|
||||
sql := fmt.Sprintf("select %s, value from %s", quoteName(s.dbType, "key"), quoteName(s.dbType, SettingsTable))
|
||||
rows, err := s.db.Query(sql)
|
||||
if checkError(err, "Fail to Query for GetServer") {
|
||||
return nil, err
|
||||
@ -93,7 +93,11 @@ func (s settingsBackend) SaveServer(ss *settings.Server) error {
|
||||
if checkError(err, "Fail to begin db transaction") {
|
||||
return err
|
||||
}
|
||||
sql := fmt.Sprintf("INSERT INTO \"%s\" (key, value) VALUES($1,$2)", SettingsTable)
|
||||
table := quoteName(s.dbType, SettingsTable)
|
||||
k := quoteName(s.dbType, "key")
|
||||
p1 := placeHolder(s.dbType, 1)
|
||||
p2 := placeHolder(s.dbType, 2)
|
||||
sql := fmt.Sprintf("INSERT INTO %s (%s, value) VALUES(%s,%s)", table, k, p1, p2)
|
||||
for i, field := range fields {
|
||||
stmt, err := s.db.Prepare(sql)
|
||||
defer stmt.Close()
|
||||
|
||||
@ -17,11 +17,12 @@ func init() {
|
||||
}
|
||||
|
||||
type settingsBackend struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
dbType string
|
||||
}
|
||||
|
||||
func InitSettingsTable(db *sql.DB) error {
|
||||
sql := fmt.Sprintf("create table if not exists \"%s\"(key text primary key, value text);", SettingsTable)
|
||||
func InitSettingsTable(db *sql.DB, dbType string) error {
|
||||
sql := fmt.Sprintf("create table if not exists %s (%s varchar(128) primary key, value text);", quoteName(dbType, SettingsTable), quoteName(dbType, "key"))
|
||||
_, err := db.Exec(sql)
|
||||
checkError(err, "Fail to create table settings")
|
||||
return err
|
||||
@ -136,7 +137,7 @@ func boolToString(b bool) string {
|
||||
}
|
||||
|
||||
func (s settingsBackend) Get() (*settings.Settings, error) {
|
||||
sql := fmt.Sprintf("select key, value from \"%s\";", SettingsTable)
|
||||
sql := fmt.Sprintf("select %s, value from %s;", quoteName(s.dbType, "key"), quoteName(s.dbType, SettingsTable))
|
||||
rows, err := s.db.Query(sql)
|
||||
if checkError(err, "Fail to Query settings.Settings") {
|
||||
return nil, err
|
||||
@ -196,13 +197,16 @@ func (s settingsBackend) Save(ss *settings.Settings) error {
|
||||
if checkError(err, "Fail to begin db transaction") {
|
||||
return err
|
||||
}
|
||||
table := quoteName(s.dbType, SettingsTable)
|
||||
k := quoteName(s.dbType, "key")
|
||||
p1 := placeHolder(s.dbType, 1)
|
||||
p2 := placeHolder(s.dbType, 2)
|
||||
for i, field := range fields {
|
||||
exists := ContainKey(s.db, field)
|
||||
sql := fmt.Sprintf("INSERT INTO \"%s\" (value, key) VALUES($1,$2);", SettingsTable)
|
||||
exists := ContainKey(s.db, s.dbType, field)
|
||||
sql := fmt.Sprintf("INSERT INTO %s (value, %s) VALUES(%s,%s);", table, k, p1, p2)
|
||||
if exists {
|
||||
sql = fmt.Sprintf("UPDATE \"%s\" set value = $1 where key = $2;", SettingsTable)
|
||||
sql = fmt.Sprintf("UPDATE %s set value = %s where %s = %s;", table, p1, k, p2)
|
||||
}
|
||||
fmt.Println(sql)
|
||||
stmt, err := s.db.Prepare(sql)
|
||||
defer stmt.Close()
|
||||
if checkError(err, "Fail to prepare statement") {
|
||||
@ -274,31 +278,37 @@ func cloneSettings(s settings.Settings) settings.Settings {
|
||||
return s1
|
||||
}
|
||||
|
||||
func SetSetting(db *sql.DB, key string, value string) error {
|
||||
sql := fmt.Sprintf("select count(key) from \"%s\" where key = '%s';", SettingsTable, key)
|
||||
func SetSetting(db *sql.DB, dbType string, key string, value string) error {
|
||||
t := quoteName(dbType, SettingsTable)
|
||||
k := quoteName(dbType, "key")
|
||||
sql := fmt.Sprintf("select count(%s) from %s where %s = '%s';", k, t, k, key)
|
||||
count := 0
|
||||
err := db.QueryRow(sql).Scan(&count)
|
||||
if checkError(err, "Fail to QueryRow for key="+key) {
|
||||
return err
|
||||
}
|
||||
if count == 0 {
|
||||
return addSetting(db, key, value)
|
||||
return addSetting(db, dbType, key, value)
|
||||
}
|
||||
return updateSetting(db, key, value)
|
||||
return updateSetting(db, dbType, key, value)
|
||||
}
|
||||
|
||||
func GetSetting(db *sql.DB, key string) string {
|
||||
sql := fmt.Sprintf("select value from \"%s\" where key = '%s';", SettingsTable, key)
|
||||
func GetSetting(db *sql.DB, dbType string, key string) string {
|
||||
sql := fmt.Sprintf("select value from %s where %s = '%s';", quoteName(dbType, SettingsTable), quoteName(dbType, "key"), key)
|
||||
value := ""
|
||||
err := db.QueryRow(sql).Scan(&value)
|
||||
if checkError(err, "Fail to QueryRow for key "+key) {
|
||||
if checkError(err, "") {
|
||||
return value
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func addSetting(db *sql.DB, key string, value string) error {
|
||||
sql := fmt.Sprintf("insert into \"%s\" (key, value) values($1, $2);", SettingsTable)
|
||||
func addSetting(db *sql.DB, dbType string, key string, value string) error {
|
||||
table := quoteName(dbType, SettingsTable)
|
||||
k := quoteName(dbType, "key")
|
||||
p1 := placeHolder(dbType, 1)
|
||||
p2 := placeHolder(dbType, 2)
|
||||
sql := fmt.Sprintf("insert into %s (%s, value) values(%s, %s);", table, k, p1, p2)
|
||||
stmt, err := db.Prepare(sql)
|
||||
if checkError(err, "Fail to prepare sql") {
|
||||
return err
|
||||
@ -308,8 +318,14 @@ func addSetting(db *sql.DB, key string, value string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func updateSetting(db *sql.DB, key string, value string) error {
|
||||
sql := fmt.Sprintf("update \"%s\" set value = $1 where key = $2;", SettingsTable)
|
||||
func updateSetting(db *sql.DB, dbType string, key string, value string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"update %s set value = %s where %s = %s;",
|
||||
quoteName(dbType, SettingsTable),
|
||||
placeHolder(dbType, 1),
|
||||
quoteName(dbType, "key"),
|
||||
placeHolder(dbType, 2),
|
||||
)
|
||||
stmt, err := db.Prepare(sql)
|
||||
if checkError(err, "Fail to prepare sql") {
|
||||
return err
|
||||
@ -320,23 +336,32 @@ func updateSetting(db *sql.DB, key string, value string) error {
|
||||
}
|
||||
|
||||
func HadSetting(db *sql.DB) bool {
|
||||
key := GetSetting(db, "Key")
|
||||
dbType, err := GetDBType(db)
|
||||
if checkError(err, "Fail to get db type") {
|
||||
return false
|
||||
}
|
||||
key := GetSetting(db, dbType, "Key")
|
||||
if key == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func ContainKey(db *sql.DB, key string) bool {
|
||||
sql := fmt.Sprintf("select value from \"%s\" where key = '%s';", SettingsTable, key)
|
||||
func ContainKey(db *sql.DB, dbType string, key string) bool {
|
||||
sql := fmt.Sprintf("select value from %s where %s = '%s';", quoteName(dbType, SettingsTable), quoteName(dbType, "key"), key)
|
||||
value := ""
|
||||
err := db.QueryRow(sql).Scan(&value)
|
||||
if checkError(err, "Fail to QueryRow for key "+key) {
|
||||
if checkError(err, "") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func HadSettingOfKey(db *sql.DB, key string) bool {
|
||||
return GetSetting(db, "Key") == key
|
||||
func HadSettingOfKey(db *sql.DB, dbType string, key string) bool {
|
||||
return GetSetting(db, dbType, "Key") == key
|
||||
}
|
||||
|
||||
func newSettingsBackend(db *sql.DB, dbType string) settingsBackend {
|
||||
InitSettingsTable(db, dbType)
|
||||
return settingsBackend{db: db, dbType: dbType}
|
||||
}
|
||||
|
||||
@ -8,15 +8,16 @@ import (
|
||||
)
|
||||
|
||||
type shareBackend struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
dbType string
|
||||
}
|
||||
|
||||
type linkRecord interface {
|
||||
Scan(dest ...interface{}) error
|
||||
}
|
||||
|
||||
func InitSharesTable(db *sql.DB) error {
|
||||
sql := fmt.Sprintf("create table if not exists \"%s\" (hash text, path text, userid integer, expire integer, passwordhash text, token text)", SharesTable)
|
||||
func InitSharesTable(db *sql.DB, dbType string) error {
|
||||
sql := fmt.Sprintf("create table if not exists %s (hash text, path text, userid integer, expire integer, passwordhash text, token text)", quoteName(dbType, SharesTable))
|
||||
_, err := db.Exec(sql)
|
||||
checkError(err, "Fail to InitSharesTable")
|
||||
return err
|
||||
@ -43,8 +44,8 @@ func parseLink(row linkRecord) (*share.Link, error) {
|
||||
return &link, nil
|
||||
}
|
||||
|
||||
func queryLinks(db *sql.DB, condition string) ([]*share.Link, error) {
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\"", SharesTable)
|
||||
func queryLinks(db *sql.DB, dbType string, condition string) ([]*share.Link, error) {
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s", quoteName(dbType, SharesTable))
|
||||
if len(condition) > 0 {
|
||||
sql = sql + " where " + condition
|
||||
}
|
||||
@ -64,37 +65,42 @@ func queryLinks(db *sql.DB, condition string) ([]*share.Link, error) {
|
||||
}
|
||||
|
||||
func (s shareBackend) All() ([]*share.Link, error) {
|
||||
return queryLinks(s.db, "")
|
||||
return queryLinks(s.db, s.dbType, "")
|
||||
}
|
||||
|
||||
func (s shareBackend) FindByUserID(id uint) ([]*share.Link, error) {
|
||||
condition := fmt.Sprintf("userid=%d", id)
|
||||
return queryLinks(s.db, condition)
|
||||
return queryLinks(s.db, s.dbType, condition)
|
||||
}
|
||||
|
||||
func (s shareBackend) GetByHash(hash string) (*share.Link, error) {
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\" where hash='%s'", SharesTable, hash)
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s where hash='%s'", quoteName(s.dbType, SharesTable), hash)
|
||||
return parseLink(s.db.QueryRow(sql))
|
||||
}
|
||||
|
||||
func (s shareBackend) GetPermanent(path string, id uint) (*share.Link, error) {
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\" where path='%s' and userid=%d", SharesTable, path, id)
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s where path='%s' and userid=%d", quoteName(s.dbType, SharesTable), path, id)
|
||||
return parseLink(s.db.QueryRow(sql))
|
||||
}
|
||||
|
||||
func (s shareBackend) Gets(path string, id uint) ([]*share.Link, error) {
|
||||
condition := fmt.Sprintf("userid=%d and path='%s'", id, path)
|
||||
return queryLinks(s.db, condition)
|
||||
return queryLinks(s.db, s.dbType, condition)
|
||||
}
|
||||
func (s shareBackend) Save(l *share.Link) error {
|
||||
sql := fmt.Sprintf("insert into \"%s\" (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", SharesTable, l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token)
|
||||
sql := fmt.Sprintf("insert into %s (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", quoteName(s.dbType, SharesTable), l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to Save share")
|
||||
return err
|
||||
}
|
||||
func (s shareBackend) Delete(hash string) error {
|
||||
sql := fmt.Sprintf("DELETE FROM \"%s\" WHERE hash='%s'", SharesTable, hash)
|
||||
sql := fmt.Sprintf("DELETE FROM %s WHERE hash='%s'", quoteName(s.dbType, SharesTable), hash)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to Delete share")
|
||||
return err
|
||||
}
|
||||
|
||||
func newShareBackend(db *sql.DB, dbType string) shareBackend {
|
||||
InitSharesTable(db, dbType)
|
||||
return shareBackend{db: db, dbType: dbType}
|
||||
}
|
||||
|
||||
@ -3,6 +3,9 @@ package sql
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/filebrowser/filebrowser/v2/auth"
|
||||
@ -15,6 +18,58 @@ import (
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
}
|
||||
|
||||
type DBConnectionRecord struct {
|
||||
db *sql.DB
|
||||
dbType string
|
||||
path string
|
||||
}
|
||||
|
||||
var (
|
||||
dbRecords map[string]*DBConnectionRecord = map[string]*DBConnectionRecord{}
|
||||
)
|
||||
|
||||
// GetDBType used to get the driver type of a sql.DB
|
||||
// It is based on existing dbRecords
|
||||
// All sql.DB should opened by OpenDB
|
||||
func GetDBType(db *sql.DB) (string, error) {
|
||||
for _, record := range dbRecords {
|
||||
if record.db == db {
|
||||
return record.dbType, nil
|
||||
}
|
||||
}
|
||||
return "", errors.New("No such database open by this module")
|
||||
}
|
||||
|
||||
func getNameQuote(dbType string) string {
|
||||
if dbType == "mysql" {
|
||||
return "`"
|
||||
}
|
||||
return "\""
|
||||
}
|
||||
|
||||
// for mysql, it is ``
|
||||
// for postgres and sqlite, it is ""
|
||||
func quoteName(dbType string, name string) string {
|
||||
q := getNameQuote(dbType)
|
||||
return q + name + q
|
||||
}
|
||||
|
||||
// placeholder for sql stmt
|
||||
// for postgres, it is $1, $2, $3...
|
||||
// for mysql and sqlite3, it is ?,?,?...
|
||||
func placeHolder(dbType string, index int) string {
|
||||
if index <= 0 {
|
||||
panic("the placeholder index should >= 1")
|
||||
}
|
||||
if dbType == "postgres" {
|
||||
return fmt.Sprintf("$%d", index)
|
||||
}
|
||||
return "?"
|
||||
}
|
||||
|
||||
func IsDBPath(path string) bool {
|
||||
prefixes := []string{"sqlite3", "postgres", "mysql"}
|
||||
for _, prefix := range prefixes {
|
||||
@ -26,18 +81,93 @@ func IsDBPath(path string) bool {
|
||||
}
|
||||
|
||||
func OpenDB(path string) (*sql.DB, error) {
|
||||
if val, ok := dbRecords[path]; ok {
|
||||
return val.db, nil
|
||||
}
|
||||
prefixes := []string{"sqlite3", "postgres", "mysql"}
|
||||
for _, prefix := range prefixes {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
return connectDB(prefix, path)
|
||||
db, err := connectDB(prefix, path)
|
||||
if !checkError(err, "Fail to connect database "+path) {
|
||||
dbRecords[path] = &DBConnectionRecord{db: db, dbType: prefix, path: path}
|
||||
}
|
||||
return db, err
|
||||
}
|
||||
}
|
||||
return nil, errors.New("Unsupported db scheme")
|
||||
}
|
||||
|
||||
type DatabaseResource struct {
|
||||
scheme string
|
||||
username string
|
||||
password string
|
||||
host string
|
||||
port int
|
||||
database string
|
||||
}
|
||||
|
||||
func ParseDatabasePath(path string) (*DatabaseResource, error) {
|
||||
pattern := "^(([a-zA-Z0-9]+)://)?(([^:]+)(:(.*))?@)?([a-zA-Z0-9_.]+)(:([0-9]+))?(/([a-zA-Z0-9_-]+))?$"
|
||||
reg, err := regexp.Compile(pattern)
|
||||
if checkError(err, "Fail to compile regexp") {
|
||||
return nil, err
|
||||
}
|
||||
matches := reg.FindAllStringSubmatch(path, -1)
|
||||
if matches == nil || len(matches) == 0 {
|
||||
return nil, errors.New("Fail to parse database")
|
||||
}
|
||||
r := DatabaseResource{}
|
||||
r.scheme = matches[0][2]
|
||||
r.username = matches[0][4]
|
||||
r.password = matches[0][6]
|
||||
r.host = matches[0][7]
|
||||
if len(matches[0][9]) > 0 {
|
||||
port, err := strconv.Atoi(matches[0][9])
|
||||
if !checkError(err, "Fail to parse port") {
|
||||
r.port = port
|
||||
}
|
||||
}
|
||||
r.database = matches[0][11]
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
// mysql://user:password@host:port/db => mysql://user:password@tcp(host:port)/db
|
||||
func transformMysqlPath(path string) (string, error) {
|
||||
r, err := ParseDatabasePath(path)
|
||||
if checkError(err, "Fail to parse database path") {
|
||||
return "", err
|
||||
}
|
||||
scheme := r.scheme
|
||||
if len(scheme) == 0 {
|
||||
scheme = "mysql"
|
||||
}
|
||||
credential := ""
|
||||
if len(r.username) > 0 && len(r.password) > 0 {
|
||||
credential = r.username + ":" + r.password + "@"
|
||||
} else if len(r.username) > 0 {
|
||||
credential = r.username + "@"
|
||||
}
|
||||
host := r.host
|
||||
port := r.port
|
||||
if port == 0 {
|
||||
port = 3306
|
||||
}
|
||||
if len(r.database) == 0 {
|
||||
return "", errors.New("no database found in path")
|
||||
}
|
||||
return fmt.Sprintf("%s://%stcp(%s:%d)/%s", scheme, credential, host, port, r.database), nil
|
||||
}
|
||||
|
||||
func connectDB(dbType string, path string) (*sql.DB, error) {
|
||||
if dbType == "sqlite3" && strings.HasPrefix(path, "sqlite3://") {
|
||||
path = strings.TrimPrefix(path, "sqlite3://")
|
||||
} else if dbType == "mysql" && strings.HasPrefix(path, "mysql://") {
|
||||
p, err := transformMysqlPath(path)
|
||||
if checkError(err, "Fail to parse mysql path") {
|
||||
return nil, err
|
||||
}
|
||||
path = p
|
||||
path = strings.TrimPrefix(path, "mysql://")
|
||||
}
|
||||
db, err := sql.Open(dbType, path)
|
||||
if err == nil {
|
||||
@ -47,20 +177,13 @@ func connectDB(dbType string, path string) (*sql.DB, error) {
|
||||
}
|
||||
|
||||
func NewStorage(db *sql.DB) (*storage.Storage, error) {
|
||||
dbType, err := GetDBType(db)
|
||||
checkError(err, "Fail to get database type, maybe this sql.DB is not opened by OpenDB")
|
||||
|
||||
InitUserTable(db)
|
||||
InitSharesTable(db)
|
||||
InitSettingsTable(db)
|
||||
|
||||
userStore := users.NewStorage(usersBackend{db: db})
|
||||
shareStore := share.NewStorage(shareBackend{db: db})
|
||||
settingsStore := settings.NewStorage(settingsBackend{db: db})
|
||||
authStore := auth.NewStorage(authBackend{db: db}, userStore)
|
||||
|
||||
err := SetSetting(db, "version", "2")
|
||||
if checkError(err, "Fail to set version") {
|
||||
return nil, err
|
||||
}
|
||||
userStore := users.NewStorage(newUsersBackend(db, dbType))
|
||||
shareStore := share.NewStorage(newShareBackend(db, dbType))
|
||||
settingsStore := settings.NewStorage(newSettingsBackend(db, dbType))
|
||||
authStore := auth.NewStorage(newAuthBackend(db, dbType), userStore)
|
||||
|
||||
storage := &storage.Storage{
|
||||
Auth: authStore,
|
||||
|
||||
@ -14,7 +14,8 @@ import (
|
||||
)
|
||||
|
||||
type usersBackend struct {
|
||||
db *sql.DB
|
||||
db *sql.DB
|
||||
dbType string
|
||||
}
|
||||
|
||||
func PermFromString(s string) users.Permissions {
|
||||
@ -114,13 +115,24 @@ func createAdminUser() users.User {
|
||||
return user
|
||||
}
|
||||
|
||||
func InitUserTable(db *sql.DB) error {
|
||||
sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS \"%s\" (id integer primary key, username text, password text, scope text, locale text, lockpassword integer, viewmode text, perm text, commands text, sorting text, rules text, hidedotfiles integer, dateformat integer, singleclick integer);", UsersTable)
|
||||
func InitUserTable(db *sql.DB, dbType string) error {
|
||||
primaryKey := "integer primary key"
|
||||
if dbType == "postgres" {
|
||||
primaryKey = "serial primary key"
|
||||
} else if dbType == "mysql" {
|
||||
primaryKey = "int unsigned primary key auto_increment"
|
||||
}
|
||||
sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id %s, username text, password text, scope text, locale text, lockpassword integer, viewmode text, perm text, commands text, sorting text, rules text, hidedotfiles integer, dateformat integer, singleclick integer);", quoteName(dbType, UsersTable), primaryKey)
|
||||
_, err := db.Exec(sql)
|
||||
checkError(err, "Fail to create users table")
|
||||
return err
|
||||
}
|
||||
|
||||
func newUsersBackend(db *sql.DB, dbType string) usersBackend {
|
||||
InitUserTable(db, dbType)
|
||||
return usersBackend{db: db, dbType: dbType}
|
||||
}
|
||||
|
||||
func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
|
||||
columns := []string{"id", "username", "password", "scope", "locale", "lockpassword", "viewmode", "perm", "commands", "sorting", "rules", "hidedotfiles", "dateformat", "singleclick"}
|
||||
columnsStr := strings.Join(columns, ",")
|
||||
@ -150,7 +162,7 @@ func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
|
||||
dateformat := false
|
||||
singleclick := false
|
||||
user := users.User{}
|
||||
sql := fmt.Sprintf("SELECT %s FROM \"%s\" WHERE %s", columnsStr, UsersTable, conditionStr)
|
||||
sql := fmt.Sprintf("SELECT %s FROM %s WHERE %s", columnsStr, quoteName(s.dbType, UsersTable), conditionStr)
|
||||
err := s.db.QueryRow(sql).Scan(&userID, &username, &password, &scope, &locale, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules, &hidedotfiles, &dateformat, &singleclick)
|
||||
if checkError(err, "") {
|
||||
return nil, err
|
||||
@ -173,7 +185,7 @@ func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
|
||||
}
|
||||
|
||||
func (s usersBackend) Gets() ([]*users.User, error) {
|
||||
sql := fmt.Sprintf("SELECT id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules FROM \"%s\"", UsersTable)
|
||||
sql := fmt.Sprintf("SELECT id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules FROM %s", quoteName(s.dbType, UsersTable))
|
||||
rows, err := s.db.Query(sql)
|
||||
if checkError(err, "Fail to Query []*users.User") {
|
||||
return nil, err
|
||||
@ -217,8 +229,8 @@ func (s usersBackend) updateUser(id uint, user *users.User) error {
|
||||
lockpassword = 1
|
||||
}
|
||||
sql := fmt.Sprintf(
|
||||
"UPDATE \"%s\" SET username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' WHERE id=%d",
|
||||
UsersTable,
|
||||
"UPDATE %s SET username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' WHERE id=%d",
|
||||
quoteName(s.dbType, UsersTable),
|
||||
user.Username,
|
||||
user.Password,
|
||||
user.Scope,
|
||||
@ -259,7 +271,10 @@ func (s usersBackend) insertUser(user *users.User) error {
|
||||
}
|
||||
columnStr := strings.Join(columns, ",")
|
||||
specStr := strings.Join(specs, ",")
|
||||
sqlFormat := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)", UsersTable, columnStr, specStr)
|
||||
sqlFormat := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoteName(s.dbType, UsersTable), columnStr, specStr)
|
||||
if s.dbType == "postgres" {
|
||||
sqlFormat = sqlFormat + " RETURNING id;"
|
||||
}
|
||||
sql := fmt.Sprintf(
|
||||
sqlFormat,
|
||||
user.Username,
|
||||
@ -276,12 +291,19 @@ func (s usersBackend) insertUser(user *users.User) error {
|
||||
boolToString(user.DateFormat),
|
||||
boolToString(user.SingleClick),
|
||||
)
|
||||
if s.dbType == "postgres" {
|
||||
id := uint(0)
|
||||
err := s.db.QueryRow(sql).Scan(&id)
|
||||
if !checkError(err, "Fail to insert user") {
|
||||
user.ID = id
|
||||
}
|
||||
return err
|
||||
}
|
||||
res, err := s.db.Exec(sql)
|
||||
if !checkError(err, "Fail to insert user") {
|
||||
id, err2 := res.LastInsertId()
|
||||
if !checkError(err2, "Fail to fetch last insert id") {
|
||||
user.ID = uint(id)
|
||||
}
|
||||
id, err := res.LastInsertId()
|
||||
checkError(err, "Fail to get last inserted id")
|
||||
user.ID = uint(id)
|
||||
}
|
||||
return err
|
||||
}
|
||||
@ -296,14 +318,14 @@ func (s usersBackend) Save(user *users.User) error {
|
||||
}
|
||||
|
||||
func (s usersBackend) DeleteByID(id uint) error {
|
||||
sql := fmt.Sprintf("delete from \"%s\" where id=%d", UsersTable, id)
|
||||
sql := fmt.Sprintf("delete from %s where id=%d", quoteName(s.dbType, UsersTable), id)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to delete User by id")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s usersBackend) DeleteByUsername(username string) error {
|
||||
sql := fmt.Sprintf("delete from \"%s\" where username='%s'", UsersTable, username)
|
||||
sql := fmt.Sprintf("delete from %s where username='%s'", quoteName(s.dbType, UsersTable), username)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to delete user by username")
|
||||
return err
|
||||
@ -323,15 +345,15 @@ func (s usersBackend) Update(u *users.User, fields ...string) error {
|
||||
val := userField.Interface()
|
||||
typeStr := reflect.TypeOf(val).Kind().String()
|
||||
if typeStr == "string" {
|
||||
setItems = append(setItems, fmt.Sprintf("\"%s\"='%s'", field, val))
|
||||
setItems = append(setItems, fmt.Sprintf("%s='%s'", field, val))
|
||||
} else if typeStr == "bool" {
|
||||
setItems = append(setItems, fmt.Sprintf("\"%s\"=%s", field, boolToString(val.(bool))))
|
||||
setItems = append(setItems, fmt.Sprintf("%s=%s", field, boolToString(val.(bool))))
|
||||
} else {
|
||||
// TODO
|
||||
setItems = append(setItems, fmt.Sprintf("\"%s\"=%s", field, val))
|
||||
setItems = append(setItems, fmt.Sprintf("%s=%s", field, val))
|
||||
}
|
||||
}
|
||||
sql := fmt.Sprintf("UPDATE \"%s\" SET %s WHERE id=%d", UsersTable, strings.Join(setItems, ","), u.ID)
|
||||
sql := fmt.Sprintf("UPDATE %s SET %s WHERE id=%d", UsersTable, strings.Join(setItems, ","), u.ID)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to update user")
|
||||
return err
|
||||
|
||||
Loading…
Reference in New Issue
Block a user