feat: test for mysql and postgres
This commit is contained in:
parent
5fdf5d0eca
commit
ea02515468
64
scripts/test-sql.sh
Executable file
64
scripts/test-sql.sh
Executable file
@ -0,0 +1,64 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# stop_docker name
|
||||||
|
stop_docker() {
|
||||||
|
sudo docker stop $1 1>/dev/null 2>&1
|
||||||
|
sudo docker rm $1 1>/dev/null 2>&1
|
||||||
|
}
|
||||||
|
|
||||||
|
# start_mysql name password port
|
||||||
|
start_mysql() {
|
||||||
|
stop_docker $1
|
||||||
|
sudo docker run --rm --name $1 -e MYSQL_ROOT_PASSWORD=$2 -p $3:3306 -d mysql
|
||||||
|
}
|
||||||
|
|
||||||
|
# start_postgres name password port
|
||||||
|
start_postgres() {
|
||||||
|
stop_docker $1
|
||||||
|
sudo docker run --rm --name $1 -e POSTGRES_PASSWORD=$2 -p $3:5432 -d postgres
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
test_sqlite() {
|
||||||
|
rm -f test.db
|
||||||
|
./filebrowser -a 0.0.0.0 -d sqlite3://test.db
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
test_postgres() {
|
||||||
|
start_postgres test-postgres postgres 5433
|
||||||
|
sleep 30
|
||||||
|
./filebrowser -a 0.0.0.0 -d postgres://postgres:postgres@127.0.0.1:5433/postgres?sslmode=disable
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
test_mysql() {
|
||||||
|
start_mysql test-mysql root 3307
|
||||||
|
sleep 60
|
||||||
|
./filebrowser -a 0.0.0.0 -d 'mysql://root:root@127.0.0.1:3307/mysql'
|
||||||
|
}
|
||||||
|
|
||||||
|
help() {
|
||||||
|
echo "USAGE: $0 sqlite|mysql|postgres"
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if (( $# == 0 )); then
|
||||||
|
help
|
||||||
|
fi
|
||||||
|
|
||||||
|
case $1 in
|
||||||
|
sqlite)
|
||||||
|
test_sqlite
|
||||||
|
;;
|
||||||
|
mysql)
|
||||||
|
test_mysql
|
||||||
|
;;
|
||||||
|
postgres)
|
||||||
|
test_postgres
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
help
|
||||||
|
esac
|
||||||
|
|
||||||
|
|
||||||
@ -11,7 +11,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type authBackend struct {
|
type authBackend struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
dbType string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s authBackend) Get(t settings.AuthMethod) (auth.Auther, error) {
|
func (s authBackend) Get(t settings.AuthMethod) (auth.Auther, error) {
|
||||||
@ -38,5 +39,9 @@ func (s authBackend) Save(a auth.Auther) error {
|
|||||||
if checkError(err, "Fail to save auth.Auther") {
|
if checkError(err, "Fail to save auth.Auther") {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return SetSetting(s.db, "auther", string(val))
|
return SetSetting(s.db, s.dbType, "auther", string(val))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthBackend(db *sql.DB, dbType string) authBackend {
|
||||||
|
return authBackend{db: db, dbType: dbType}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -28,7 +28,7 @@ func cloneServer(server settings.Server) settings.Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s settingsBackend) GetServer() (*settings.Server, error) {
|
func (s settingsBackend) GetServer() (*settings.Server, error) {
|
||||||
sql := fmt.Sprintf("select key, value from %s", SettingsTable)
|
sql := fmt.Sprintf("select %s, value from %s", quoteName(s.dbType, "key"), quoteName(s.dbType, SettingsTable))
|
||||||
rows, err := s.db.Query(sql)
|
rows, err := s.db.Query(sql)
|
||||||
if checkError(err, "Fail to Query for GetServer") {
|
if checkError(err, "Fail to Query for GetServer") {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -93,7 +93,11 @@ func (s settingsBackend) SaveServer(ss *settings.Server) error {
|
|||||||
if checkError(err, "Fail to begin db transaction") {
|
if checkError(err, "Fail to begin db transaction") {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
sql := fmt.Sprintf("INSERT INTO \"%s\" (key, value) VALUES($1,$2)", SettingsTable)
|
table := quoteName(s.dbType, SettingsTable)
|
||||||
|
k := quoteName(s.dbType, "key")
|
||||||
|
p1 := placeHolder(s.dbType, 1)
|
||||||
|
p2 := placeHolder(s.dbType, 2)
|
||||||
|
sql := fmt.Sprintf("INSERT INTO %s (%s, value) VALUES(%s,%s)", table, k, p1, p2)
|
||||||
for i, field := range fields {
|
for i, field := range fields {
|
||||||
stmt, err := s.db.Prepare(sql)
|
stmt, err := s.db.Prepare(sql)
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
|
|||||||
@ -17,11 +17,12 @@ func init() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type settingsBackend struct {
|
type settingsBackend struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
dbType string
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitSettingsTable(db *sql.DB) error {
|
func InitSettingsTable(db *sql.DB, dbType string) error {
|
||||||
sql := fmt.Sprintf("create table if not exists \"%s\"(key text primary key, value text);", SettingsTable)
|
sql := fmt.Sprintf("create table if not exists %s (%s varchar(128) primary key, value text);", quoteName(dbType, SettingsTable), quoteName(dbType, "key"))
|
||||||
_, err := db.Exec(sql)
|
_, err := db.Exec(sql)
|
||||||
checkError(err, "Fail to create table settings")
|
checkError(err, "Fail to create table settings")
|
||||||
return err
|
return err
|
||||||
@ -136,7 +137,7 @@ func boolToString(b bool) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s settingsBackend) Get() (*settings.Settings, error) {
|
func (s settingsBackend) Get() (*settings.Settings, error) {
|
||||||
sql := fmt.Sprintf("select key, value from \"%s\";", SettingsTable)
|
sql := fmt.Sprintf("select %s, value from %s;", quoteName(s.dbType, "key"), quoteName(s.dbType, SettingsTable))
|
||||||
rows, err := s.db.Query(sql)
|
rows, err := s.db.Query(sql)
|
||||||
if checkError(err, "Fail to Query settings.Settings") {
|
if checkError(err, "Fail to Query settings.Settings") {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -196,13 +197,16 @@ func (s settingsBackend) Save(ss *settings.Settings) error {
|
|||||||
if checkError(err, "Fail to begin db transaction") {
|
if checkError(err, "Fail to begin db transaction") {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
table := quoteName(s.dbType, SettingsTable)
|
||||||
|
k := quoteName(s.dbType, "key")
|
||||||
|
p1 := placeHolder(s.dbType, 1)
|
||||||
|
p2 := placeHolder(s.dbType, 2)
|
||||||
for i, field := range fields {
|
for i, field := range fields {
|
||||||
exists := ContainKey(s.db, field)
|
exists := ContainKey(s.db, s.dbType, field)
|
||||||
sql := fmt.Sprintf("INSERT INTO \"%s\" (value, key) VALUES($1,$2);", SettingsTable)
|
sql := fmt.Sprintf("INSERT INTO %s (value, %s) VALUES(%s,%s);", table, k, p1, p2)
|
||||||
if exists {
|
if exists {
|
||||||
sql = fmt.Sprintf("UPDATE \"%s\" set value = $1 where key = $2;", SettingsTable)
|
sql = fmt.Sprintf("UPDATE %s set value = %s where %s = %s;", table, p1, k, p2)
|
||||||
}
|
}
|
||||||
fmt.Println(sql)
|
|
||||||
stmt, err := s.db.Prepare(sql)
|
stmt, err := s.db.Prepare(sql)
|
||||||
defer stmt.Close()
|
defer stmt.Close()
|
||||||
if checkError(err, "Fail to prepare statement") {
|
if checkError(err, "Fail to prepare statement") {
|
||||||
@ -274,31 +278,37 @@ func cloneSettings(s settings.Settings) settings.Settings {
|
|||||||
return s1
|
return s1
|
||||||
}
|
}
|
||||||
|
|
||||||
func SetSetting(db *sql.DB, key string, value string) error {
|
func SetSetting(db *sql.DB, dbType string, key string, value string) error {
|
||||||
sql := fmt.Sprintf("select count(key) from \"%s\" where key = '%s';", SettingsTable, key)
|
t := quoteName(dbType, SettingsTable)
|
||||||
|
k := quoteName(dbType, "key")
|
||||||
|
sql := fmt.Sprintf("select count(%s) from %s where %s = '%s';", k, t, k, key)
|
||||||
count := 0
|
count := 0
|
||||||
err := db.QueryRow(sql).Scan(&count)
|
err := db.QueryRow(sql).Scan(&count)
|
||||||
if checkError(err, "Fail to QueryRow for key="+key) {
|
if checkError(err, "Fail to QueryRow for key="+key) {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if count == 0 {
|
if count == 0 {
|
||||||
return addSetting(db, key, value)
|
return addSetting(db, dbType, key, value)
|
||||||
}
|
}
|
||||||
return updateSetting(db, key, value)
|
return updateSetting(db, dbType, key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSetting(db *sql.DB, key string) string {
|
func GetSetting(db *sql.DB, dbType string, key string) string {
|
||||||
sql := fmt.Sprintf("select value from \"%s\" where key = '%s';", SettingsTable, key)
|
sql := fmt.Sprintf("select value from %s where %s = '%s';", quoteName(dbType, SettingsTable), quoteName(dbType, "key"), key)
|
||||||
value := ""
|
value := ""
|
||||||
err := db.QueryRow(sql).Scan(&value)
|
err := db.QueryRow(sql).Scan(&value)
|
||||||
if checkError(err, "Fail to QueryRow for key "+key) {
|
if checkError(err, "") {
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
|
|
||||||
func addSetting(db *sql.DB, key string, value string) error {
|
func addSetting(db *sql.DB, dbType string, key string, value string) error {
|
||||||
sql := fmt.Sprintf("insert into \"%s\" (key, value) values($1, $2);", SettingsTable)
|
table := quoteName(dbType, SettingsTable)
|
||||||
|
k := quoteName(dbType, "key")
|
||||||
|
p1 := placeHolder(dbType, 1)
|
||||||
|
p2 := placeHolder(dbType, 2)
|
||||||
|
sql := fmt.Sprintf("insert into %s (%s, value) values(%s, %s);", table, k, p1, p2)
|
||||||
stmt, err := db.Prepare(sql)
|
stmt, err := db.Prepare(sql)
|
||||||
if checkError(err, "Fail to prepare sql") {
|
if checkError(err, "Fail to prepare sql") {
|
||||||
return err
|
return err
|
||||||
@ -308,8 +318,14 @@ func addSetting(db *sql.DB, key string, value string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func updateSetting(db *sql.DB, key string, value string) error {
|
func updateSetting(db *sql.DB, dbType string, key string, value string) error {
|
||||||
sql := fmt.Sprintf("update \"%s\" set value = $1 where key = $2;", SettingsTable)
|
sql := fmt.Sprintf(
|
||||||
|
"update %s set value = %s where %s = %s;",
|
||||||
|
quoteName(dbType, SettingsTable),
|
||||||
|
placeHolder(dbType, 1),
|
||||||
|
quoteName(dbType, "key"),
|
||||||
|
placeHolder(dbType, 2),
|
||||||
|
)
|
||||||
stmt, err := db.Prepare(sql)
|
stmt, err := db.Prepare(sql)
|
||||||
if checkError(err, "Fail to prepare sql") {
|
if checkError(err, "Fail to prepare sql") {
|
||||||
return err
|
return err
|
||||||
@ -320,23 +336,32 @@ func updateSetting(db *sql.DB, key string, value string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func HadSetting(db *sql.DB) bool {
|
func HadSetting(db *sql.DB) bool {
|
||||||
key := GetSetting(db, "Key")
|
dbType, err := GetDBType(db)
|
||||||
|
if checkError(err, "Fail to get db type") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
key := GetSetting(db, dbType, "Key")
|
||||||
if key == "" {
|
if key == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func ContainKey(db *sql.DB, key string) bool {
|
func ContainKey(db *sql.DB, dbType string, key string) bool {
|
||||||
sql := fmt.Sprintf("select value from \"%s\" where key = '%s';", SettingsTable, key)
|
sql := fmt.Sprintf("select value from %s where %s = '%s';", quoteName(dbType, SettingsTable), quoteName(dbType, "key"), key)
|
||||||
value := ""
|
value := ""
|
||||||
err := db.QueryRow(sql).Scan(&value)
|
err := db.QueryRow(sql).Scan(&value)
|
||||||
if checkError(err, "Fail to QueryRow for key "+key) {
|
if checkError(err, "") {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func HadSettingOfKey(db *sql.DB, key string) bool {
|
func HadSettingOfKey(db *sql.DB, dbType string, key string) bool {
|
||||||
return GetSetting(db, "Key") == key
|
return GetSetting(db, dbType, "Key") == key
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSettingsBackend(db *sql.DB, dbType string) settingsBackend {
|
||||||
|
InitSettingsTable(db, dbType)
|
||||||
|
return settingsBackend{db: db, dbType: dbType}
|
||||||
}
|
}
|
||||||
|
|||||||
@ -8,15 +8,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type shareBackend struct {
|
type shareBackend struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
dbType string
|
||||||
}
|
}
|
||||||
|
|
||||||
type linkRecord interface {
|
type linkRecord interface {
|
||||||
Scan(dest ...interface{}) error
|
Scan(dest ...interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitSharesTable(db *sql.DB) error {
|
func InitSharesTable(db *sql.DB, dbType string) error {
|
||||||
sql := fmt.Sprintf("create table if not exists \"%s\" (hash text, path text, userid integer, expire integer, passwordhash text, token text)", SharesTable)
|
sql := fmt.Sprintf("create table if not exists %s (hash text, path text, userid integer, expire integer, passwordhash text, token text)", quoteName(dbType, SharesTable))
|
||||||
_, err := db.Exec(sql)
|
_, err := db.Exec(sql)
|
||||||
checkError(err, "Fail to InitSharesTable")
|
checkError(err, "Fail to InitSharesTable")
|
||||||
return err
|
return err
|
||||||
@ -43,8 +44,8 @@ func parseLink(row linkRecord) (*share.Link, error) {
|
|||||||
return &link, nil
|
return &link, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func queryLinks(db *sql.DB, condition string) ([]*share.Link, error) {
|
func queryLinks(db *sql.DB, dbType string, condition string) ([]*share.Link, error) {
|
||||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\"", SharesTable)
|
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s", quoteName(dbType, SharesTable))
|
||||||
if len(condition) > 0 {
|
if len(condition) > 0 {
|
||||||
sql = sql + " where " + condition
|
sql = sql + " where " + condition
|
||||||
}
|
}
|
||||||
@ -64,37 +65,42 @@ func queryLinks(db *sql.DB, condition string) ([]*share.Link, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s shareBackend) All() ([]*share.Link, error) {
|
func (s shareBackend) All() ([]*share.Link, error) {
|
||||||
return queryLinks(s.db, "")
|
return queryLinks(s.db, s.dbType, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s shareBackend) FindByUserID(id uint) ([]*share.Link, error) {
|
func (s shareBackend) FindByUserID(id uint) ([]*share.Link, error) {
|
||||||
condition := fmt.Sprintf("userid=%d", id)
|
condition := fmt.Sprintf("userid=%d", id)
|
||||||
return queryLinks(s.db, condition)
|
return queryLinks(s.db, s.dbType, condition)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s shareBackend) GetByHash(hash string) (*share.Link, error) {
|
func (s shareBackend) GetByHash(hash string) (*share.Link, error) {
|
||||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\" where hash='%s'", SharesTable, hash)
|
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s where hash='%s'", quoteName(s.dbType, SharesTable), hash)
|
||||||
return parseLink(s.db.QueryRow(sql))
|
return parseLink(s.db.QueryRow(sql))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s shareBackend) GetPermanent(path string, id uint) (*share.Link, error) {
|
func (s shareBackend) GetPermanent(path string, id uint) (*share.Link, error) {
|
||||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from \"%s\" where path='%s' and userid=%d", SharesTable, path, id)
|
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s where path='%s' and userid=%d", quoteName(s.dbType, SharesTable), path, id)
|
||||||
return parseLink(s.db.QueryRow(sql))
|
return parseLink(s.db.QueryRow(sql))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s shareBackend) Gets(path string, id uint) ([]*share.Link, error) {
|
func (s shareBackend) Gets(path string, id uint) ([]*share.Link, error) {
|
||||||
condition := fmt.Sprintf("userid=%d and path='%s'", id, path)
|
condition := fmt.Sprintf("userid=%d and path='%s'", id, path)
|
||||||
return queryLinks(s.db, condition)
|
return queryLinks(s.db, s.dbType, condition)
|
||||||
}
|
}
|
||||||
func (s shareBackend) Save(l *share.Link) error {
|
func (s shareBackend) Save(l *share.Link) error {
|
||||||
sql := fmt.Sprintf("insert into \"%s\" (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", SharesTable, l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token)
|
sql := fmt.Sprintf("insert into %s (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", quoteName(s.dbType, SharesTable), l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token)
|
||||||
_, err := s.db.Exec(sql)
|
_, err := s.db.Exec(sql)
|
||||||
checkError(err, "Fail to Save share")
|
checkError(err, "Fail to Save share")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
func (s shareBackend) Delete(hash string) error {
|
func (s shareBackend) Delete(hash string) error {
|
||||||
sql := fmt.Sprintf("DELETE FROM \"%s\" WHERE hash='%s'", SharesTable, hash)
|
sql := fmt.Sprintf("DELETE FROM %s WHERE hash='%s'", quoteName(s.dbType, SharesTable), hash)
|
||||||
_, err := s.db.Exec(sql)
|
_, err := s.db.Exec(sql)
|
||||||
checkError(err, "Fail to Delete share")
|
checkError(err, "Fail to Delete share")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newShareBackend(db *sql.DB, dbType string) shareBackend {
|
||||||
|
InitSharesTable(db, dbType)
|
||||||
|
return shareBackend{db: db, dbType: dbType}
|
||||||
|
}
|
||||||
|
|||||||
@ -3,6 +3,9 @@ package sql
|
|||||||
import (
|
import (
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/filebrowser/filebrowser/v2/auth"
|
"github.com/filebrowser/filebrowser/v2/auth"
|
||||||
@ -15,6 +18,58 @@ import (
|
|||||||
_ "github.com/mattn/go-sqlite3"
|
_ "github.com/mattn/go-sqlite3"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
}
|
||||||
|
|
||||||
|
type DBConnectionRecord struct {
|
||||||
|
db *sql.DB
|
||||||
|
dbType string
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
dbRecords map[string]*DBConnectionRecord = map[string]*DBConnectionRecord{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetDBType used to get the driver type of a sql.DB
|
||||||
|
// It is based on existing dbRecords
|
||||||
|
// All sql.DB should opened by OpenDB
|
||||||
|
func GetDBType(db *sql.DB) (string, error) {
|
||||||
|
for _, record := range dbRecords {
|
||||||
|
if record.db == db {
|
||||||
|
return record.dbType, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", errors.New("No such database open by this module")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNameQuote(dbType string) string {
|
||||||
|
if dbType == "mysql" {
|
||||||
|
return "`"
|
||||||
|
}
|
||||||
|
return "\""
|
||||||
|
}
|
||||||
|
|
||||||
|
// for mysql, it is ``
|
||||||
|
// for postgres and sqlite, it is ""
|
||||||
|
func quoteName(dbType string, name string) string {
|
||||||
|
q := getNameQuote(dbType)
|
||||||
|
return q + name + q
|
||||||
|
}
|
||||||
|
|
||||||
|
// placeholder for sql stmt
|
||||||
|
// for postgres, it is $1, $2, $3...
|
||||||
|
// for mysql and sqlite3, it is ?,?,?...
|
||||||
|
func placeHolder(dbType string, index int) string {
|
||||||
|
if index <= 0 {
|
||||||
|
panic("the placeholder index should >= 1")
|
||||||
|
}
|
||||||
|
if dbType == "postgres" {
|
||||||
|
return fmt.Sprintf("$%d", index)
|
||||||
|
}
|
||||||
|
return "?"
|
||||||
|
}
|
||||||
|
|
||||||
func IsDBPath(path string) bool {
|
func IsDBPath(path string) bool {
|
||||||
prefixes := []string{"sqlite3", "postgres", "mysql"}
|
prefixes := []string{"sqlite3", "postgres", "mysql"}
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
@ -26,18 +81,93 @@ func IsDBPath(path string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func OpenDB(path string) (*sql.DB, error) {
|
func OpenDB(path string) (*sql.DB, error) {
|
||||||
|
if val, ok := dbRecords[path]; ok {
|
||||||
|
return val.db, nil
|
||||||
|
}
|
||||||
prefixes := []string{"sqlite3", "postgres", "mysql"}
|
prefixes := []string{"sqlite3", "postgres", "mysql"}
|
||||||
for _, prefix := range prefixes {
|
for _, prefix := range prefixes {
|
||||||
if strings.HasPrefix(path, prefix) {
|
if strings.HasPrefix(path, prefix) {
|
||||||
return connectDB(prefix, path)
|
db, err := connectDB(prefix, path)
|
||||||
|
if !checkError(err, "Fail to connect database "+path) {
|
||||||
|
dbRecords[path] = &DBConnectionRecord{db: db, dbType: prefix, path: path}
|
||||||
|
}
|
||||||
|
return db, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil, errors.New("Unsupported db scheme")
|
return nil, errors.New("Unsupported db scheme")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type DatabaseResource struct {
|
||||||
|
scheme string
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
host string
|
||||||
|
port int
|
||||||
|
database string
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseDatabasePath(path string) (*DatabaseResource, error) {
|
||||||
|
pattern := "^(([a-zA-Z0-9]+)://)?(([^:]+)(:(.*))?@)?([a-zA-Z0-9_.]+)(:([0-9]+))?(/([a-zA-Z0-9_-]+))?$"
|
||||||
|
reg, err := regexp.Compile(pattern)
|
||||||
|
if checkError(err, "Fail to compile regexp") {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
matches := reg.FindAllStringSubmatch(path, -1)
|
||||||
|
if matches == nil || len(matches) == 0 {
|
||||||
|
return nil, errors.New("Fail to parse database")
|
||||||
|
}
|
||||||
|
r := DatabaseResource{}
|
||||||
|
r.scheme = matches[0][2]
|
||||||
|
r.username = matches[0][4]
|
||||||
|
r.password = matches[0][6]
|
||||||
|
r.host = matches[0][7]
|
||||||
|
if len(matches[0][9]) > 0 {
|
||||||
|
port, err := strconv.Atoi(matches[0][9])
|
||||||
|
if !checkError(err, "Fail to parse port") {
|
||||||
|
r.port = port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.database = matches[0][11]
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mysql://user:password@host:port/db => mysql://user:password@tcp(host:port)/db
|
||||||
|
func transformMysqlPath(path string) (string, error) {
|
||||||
|
r, err := ParseDatabasePath(path)
|
||||||
|
if checkError(err, "Fail to parse database path") {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
scheme := r.scheme
|
||||||
|
if len(scheme) == 0 {
|
||||||
|
scheme = "mysql"
|
||||||
|
}
|
||||||
|
credential := ""
|
||||||
|
if len(r.username) > 0 && len(r.password) > 0 {
|
||||||
|
credential = r.username + ":" + r.password + "@"
|
||||||
|
} else if len(r.username) > 0 {
|
||||||
|
credential = r.username + "@"
|
||||||
|
}
|
||||||
|
host := r.host
|
||||||
|
port := r.port
|
||||||
|
if port == 0 {
|
||||||
|
port = 3306
|
||||||
|
}
|
||||||
|
if len(r.database) == 0 {
|
||||||
|
return "", errors.New("no database found in path")
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s://%stcp(%s:%d)/%s", scheme, credential, host, port, r.database), nil
|
||||||
|
}
|
||||||
|
|
||||||
func connectDB(dbType string, path string) (*sql.DB, error) {
|
func connectDB(dbType string, path string) (*sql.DB, error) {
|
||||||
if dbType == "sqlite3" && strings.HasPrefix(path, "sqlite3://") {
|
if dbType == "sqlite3" && strings.HasPrefix(path, "sqlite3://") {
|
||||||
path = strings.TrimPrefix(path, "sqlite3://")
|
path = strings.TrimPrefix(path, "sqlite3://")
|
||||||
|
} else if dbType == "mysql" && strings.HasPrefix(path, "mysql://") {
|
||||||
|
p, err := transformMysqlPath(path)
|
||||||
|
if checkError(err, "Fail to parse mysql path") {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
path = p
|
||||||
|
path = strings.TrimPrefix(path, "mysql://")
|
||||||
}
|
}
|
||||||
db, err := sql.Open(dbType, path)
|
db, err := sql.Open(dbType, path)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
@ -47,20 +177,13 @@ func connectDB(dbType string, path string) (*sql.DB, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func NewStorage(db *sql.DB) (*storage.Storage, error) {
|
func NewStorage(db *sql.DB) (*storage.Storage, error) {
|
||||||
|
dbType, err := GetDBType(db)
|
||||||
|
checkError(err, "Fail to get database type, maybe this sql.DB is not opened by OpenDB")
|
||||||
|
|
||||||
InitUserTable(db)
|
userStore := users.NewStorage(newUsersBackend(db, dbType))
|
||||||
InitSharesTable(db)
|
shareStore := share.NewStorage(newShareBackend(db, dbType))
|
||||||
InitSettingsTable(db)
|
settingsStore := settings.NewStorage(newSettingsBackend(db, dbType))
|
||||||
|
authStore := auth.NewStorage(newAuthBackend(db, dbType), userStore)
|
||||||
userStore := users.NewStorage(usersBackend{db: db})
|
|
||||||
shareStore := share.NewStorage(shareBackend{db: db})
|
|
||||||
settingsStore := settings.NewStorage(settingsBackend{db: db})
|
|
||||||
authStore := auth.NewStorage(authBackend{db: db}, userStore)
|
|
||||||
|
|
||||||
err := SetSetting(db, "version", "2")
|
|
||||||
if checkError(err, "Fail to set version") {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
storage := &storage.Storage{
|
storage := &storage.Storage{
|
||||||
Auth: authStore,
|
Auth: authStore,
|
||||||
|
|||||||
@ -14,7 +14,8 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type usersBackend struct {
|
type usersBackend struct {
|
||||||
db *sql.DB
|
db *sql.DB
|
||||||
|
dbType string
|
||||||
}
|
}
|
||||||
|
|
||||||
func PermFromString(s string) users.Permissions {
|
func PermFromString(s string) users.Permissions {
|
||||||
@ -114,13 +115,24 @@ func createAdminUser() users.User {
|
|||||||
return user
|
return user
|
||||||
}
|
}
|
||||||
|
|
||||||
func InitUserTable(db *sql.DB) error {
|
func InitUserTable(db *sql.DB, dbType string) error {
|
||||||
sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS \"%s\" (id integer primary key, username text, password text, scope text, locale text, lockpassword integer, viewmode text, perm text, commands text, sorting text, rules text, hidedotfiles integer, dateformat integer, singleclick integer);", UsersTable)
|
primaryKey := "integer primary key"
|
||||||
|
if dbType == "postgres" {
|
||||||
|
primaryKey = "serial primary key"
|
||||||
|
} else if dbType == "mysql" {
|
||||||
|
primaryKey = "int unsigned primary key auto_increment"
|
||||||
|
}
|
||||||
|
sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id %s, username text, password text, scope text, locale text, lockpassword integer, viewmode text, perm text, commands text, sorting text, rules text, hidedotfiles integer, dateformat integer, singleclick integer);", quoteName(dbType, UsersTable), primaryKey)
|
||||||
_, err := db.Exec(sql)
|
_, err := db.Exec(sql)
|
||||||
checkError(err, "Fail to create users table")
|
checkError(err, "Fail to create users table")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newUsersBackend(db *sql.DB, dbType string) usersBackend {
|
||||||
|
InitUserTable(db, dbType)
|
||||||
|
return usersBackend{db: db, dbType: dbType}
|
||||||
|
}
|
||||||
|
|
||||||
func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
|
func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
|
||||||
columns := []string{"id", "username", "password", "scope", "locale", "lockpassword", "viewmode", "perm", "commands", "sorting", "rules", "hidedotfiles", "dateformat", "singleclick"}
|
columns := []string{"id", "username", "password", "scope", "locale", "lockpassword", "viewmode", "perm", "commands", "sorting", "rules", "hidedotfiles", "dateformat", "singleclick"}
|
||||||
columnsStr := strings.Join(columns, ",")
|
columnsStr := strings.Join(columns, ",")
|
||||||
@ -150,7 +162,7 @@ func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
|
|||||||
dateformat := false
|
dateformat := false
|
||||||
singleclick := false
|
singleclick := false
|
||||||
user := users.User{}
|
user := users.User{}
|
||||||
sql := fmt.Sprintf("SELECT %s FROM \"%s\" WHERE %s", columnsStr, UsersTable, conditionStr)
|
sql := fmt.Sprintf("SELECT %s FROM %s WHERE %s", columnsStr, quoteName(s.dbType, UsersTable), conditionStr)
|
||||||
err := s.db.QueryRow(sql).Scan(&userID, &username, &password, &scope, &locale, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules, &hidedotfiles, &dateformat, &singleclick)
|
err := s.db.QueryRow(sql).Scan(&userID, &username, &password, &scope, &locale, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules, &hidedotfiles, &dateformat, &singleclick)
|
||||||
if checkError(err, "") {
|
if checkError(err, "") {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -173,7 +185,7 @@ func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s usersBackend) Gets() ([]*users.User, error) {
|
func (s usersBackend) Gets() ([]*users.User, error) {
|
||||||
sql := fmt.Sprintf("SELECT id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules FROM \"%s\"", UsersTable)
|
sql := fmt.Sprintf("SELECT id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules FROM %s", quoteName(s.dbType, UsersTable))
|
||||||
rows, err := s.db.Query(sql)
|
rows, err := s.db.Query(sql)
|
||||||
if checkError(err, "Fail to Query []*users.User") {
|
if checkError(err, "Fail to Query []*users.User") {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -217,8 +229,8 @@ func (s usersBackend) updateUser(id uint, user *users.User) error {
|
|||||||
lockpassword = 1
|
lockpassword = 1
|
||||||
}
|
}
|
||||||
sql := fmt.Sprintf(
|
sql := fmt.Sprintf(
|
||||||
"UPDATE \"%s\" SET username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' WHERE id=%d",
|
"UPDATE %s SET username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' WHERE id=%d",
|
||||||
UsersTable,
|
quoteName(s.dbType, UsersTable),
|
||||||
user.Username,
|
user.Username,
|
||||||
user.Password,
|
user.Password,
|
||||||
user.Scope,
|
user.Scope,
|
||||||
@ -259,7 +271,10 @@ func (s usersBackend) insertUser(user *users.User) error {
|
|||||||
}
|
}
|
||||||
columnStr := strings.Join(columns, ",")
|
columnStr := strings.Join(columns, ",")
|
||||||
specStr := strings.Join(specs, ",")
|
specStr := strings.Join(specs, ",")
|
||||||
sqlFormat := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES (%s)", UsersTable, columnStr, specStr)
|
sqlFormat := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoteName(s.dbType, UsersTable), columnStr, specStr)
|
||||||
|
if s.dbType == "postgres" {
|
||||||
|
sqlFormat = sqlFormat + " RETURNING id;"
|
||||||
|
}
|
||||||
sql := fmt.Sprintf(
|
sql := fmt.Sprintf(
|
||||||
sqlFormat,
|
sqlFormat,
|
||||||
user.Username,
|
user.Username,
|
||||||
@ -276,12 +291,19 @@ func (s usersBackend) insertUser(user *users.User) error {
|
|||||||
boolToString(user.DateFormat),
|
boolToString(user.DateFormat),
|
||||||
boolToString(user.SingleClick),
|
boolToString(user.SingleClick),
|
||||||
)
|
)
|
||||||
|
if s.dbType == "postgres" {
|
||||||
|
id := uint(0)
|
||||||
|
err := s.db.QueryRow(sql).Scan(&id)
|
||||||
|
if !checkError(err, "Fail to insert user") {
|
||||||
|
user.ID = id
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
res, err := s.db.Exec(sql)
|
res, err := s.db.Exec(sql)
|
||||||
if !checkError(err, "Fail to insert user") {
|
if !checkError(err, "Fail to insert user") {
|
||||||
id, err2 := res.LastInsertId()
|
id, err := res.LastInsertId()
|
||||||
if !checkError(err2, "Fail to fetch last insert id") {
|
checkError(err, "Fail to get last inserted id")
|
||||||
user.ID = uint(id)
|
user.ID = uint(id)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -296,14 +318,14 @@ func (s usersBackend) Save(user *users.User) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s usersBackend) DeleteByID(id uint) error {
|
func (s usersBackend) DeleteByID(id uint) error {
|
||||||
sql := fmt.Sprintf("delete from \"%s\" where id=%d", UsersTable, id)
|
sql := fmt.Sprintf("delete from %s where id=%d", quoteName(s.dbType, UsersTable), id)
|
||||||
_, err := s.db.Exec(sql)
|
_, err := s.db.Exec(sql)
|
||||||
checkError(err, "Fail to delete User by id")
|
checkError(err, "Fail to delete User by id")
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s usersBackend) DeleteByUsername(username string) error {
|
func (s usersBackend) DeleteByUsername(username string) error {
|
||||||
sql := fmt.Sprintf("delete from \"%s\" where username='%s'", UsersTable, username)
|
sql := fmt.Sprintf("delete from %s where username='%s'", quoteName(s.dbType, UsersTable), username)
|
||||||
_, err := s.db.Exec(sql)
|
_, err := s.db.Exec(sql)
|
||||||
checkError(err, "Fail to delete user by username")
|
checkError(err, "Fail to delete user by username")
|
||||||
return err
|
return err
|
||||||
@ -323,15 +345,15 @@ func (s usersBackend) Update(u *users.User, fields ...string) error {
|
|||||||
val := userField.Interface()
|
val := userField.Interface()
|
||||||
typeStr := reflect.TypeOf(val).Kind().String()
|
typeStr := reflect.TypeOf(val).Kind().String()
|
||||||
if typeStr == "string" {
|
if typeStr == "string" {
|
||||||
setItems = append(setItems, fmt.Sprintf("\"%s\"='%s'", field, val))
|
setItems = append(setItems, fmt.Sprintf("%s='%s'", field, val))
|
||||||
} else if typeStr == "bool" {
|
} else if typeStr == "bool" {
|
||||||
setItems = append(setItems, fmt.Sprintf("\"%s\"=%s", field, boolToString(val.(bool))))
|
setItems = append(setItems, fmt.Sprintf("%s=%s", field, boolToString(val.(bool))))
|
||||||
} else {
|
} else {
|
||||||
// TODO
|
// TODO
|
||||||
setItems = append(setItems, fmt.Sprintf("\"%s\"=%s", field, val))
|
setItems = append(setItems, fmt.Sprintf("%s=%s", field, val))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sql := fmt.Sprintf("UPDATE \"%s\" SET %s WHERE id=%d", UsersTable, strings.Join(setItems, ","), u.ID)
|
sql := fmt.Sprintf("UPDATE %s SET %s WHERE id=%d", UsersTable, strings.Join(setItems, ","), u.ID)
|
||||||
_, err := s.db.Exec(sql)
|
_, err := s.db.Exec(sql)
|
||||||
checkError(err, "Fail to update user")
|
checkError(err, "Fail to update user")
|
||||||
return err
|
return err
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user