Merge b88010b779 into e167c3e1ef
This commit is contained in:
commit
2b7c79b60c
@ -45,7 +45,7 @@ func init() {
|
|||||||
persistent := rootCmd.PersistentFlags()
|
persistent := rootCmd.PersistentFlags()
|
||||||
|
|
||||||
persistent.StringVarP(&cfgFile, "config", "c", "", "config file path")
|
persistent.StringVarP(&cfgFile, "config", "c", "", "config file path")
|
||||||
persistent.StringP("database", "d", "./filebrowser.db", "database path")
|
persistent.StringP("database", "d", "./filebrowser.db", "database path. Possible choices:\n\tbolt: ./bolt.db\n\tsqlite3: sqlite3://test.db\n\tpostgresql: postgres://user:password@192.168.1.10:5432/postgres?sslmode=disable\n\tmysql: mysql://user:password@192.168.1.10:3306/root\n")
|
||||||
flags.Bool("noauth", false, "use the noauth auther when using quick setup")
|
flags.Bool("noauth", false, "use the noauth auther when using quick setup")
|
||||||
flags.String("username", "admin", "username for the first user when using quick config")
|
flags.String("username", "admin", "username for the first user when using quick config")
|
||||||
flags.String("password", "", "hashed password for the first user when using quick config (default \"admin\")")
|
flags.String("password", "", "hashed password for the first user when using quick config (default \"admin\")")
|
||||||
|
|||||||
61
cmd/utils.go
61
cmd/utils.go
@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/filebrowser/filebrowser/v2/settings"
|
"github.com/filebrowser/filebrowser/v2/settings"
|
||||||
"github.com/filebrowser/filebrowser/v2/storage"
|
"github.com/filebrowser/filebrowser/v2/storage"
|
||||||
"github.com/filebrowser/filebrowser/v2/storage/bolt"
|
"github.com/filebrowser/filebrowser/v2/storage/bolt"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/storage/sql"
|
||||||
)
|
)
|
||||||
|
|
||||||
func checkErr(err error) {
|
func checkErr(err error) {
|
||||||
@ -82,27 +83,51 @@ func dbExists(path string) (bool, error) {
|
|||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Closeable interface {
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func openBoltDB(path string, cfg pythonConfig) (pythonData, Closeable) {
|
||||||
|
data := pythonData{hadDB: true}
|
||||||
|
exists, err := dbExists(path)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
} else if exists && cfg.noDB {
|
||||||
|
log.Fatal(path + " already exists")
|
||||||
|
} else if !exists && !cfg.noDB && !cfg.allowNoDB {
|
||||||
|
log.Fatal(path + " does not exist. Please run 'filebrowser config init' first.")
|
||||||
|
}
|
||||||
|
|
||||||
|
data.hadDB = exists
|
||||||
|
db, err := storm.Open(path)
|
||||||
|
checkErr(err)
|
||||||
|
data.store, err = bolt.NewStorage(db)
|
||||||
|
checkErr(err)
|
||||||
|
return data, db
|
||||||
|
}
|
||||||
|
|
||||||
|
func openDB(path string, cfg pythonConfig) (pythonData, Closeable) {
|
||||||
|
if sql.IsDBPath(path) {
|
||||||
|
data := pythonData{hadDB: false}
|
||||||
|
db, err := sql.OpenDB(path)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Fail to open database " + path)
|
||||||
|
}
|
||||||
|
data.store, err = sql.NewStorage(db)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal("Fail to create database storage for " + path)
|
||||||
|
}
|
||||||
|
data.hadDB = sql.HadSetting(db)
|
||||||
|
return data, db
|
||||||
|
}
|
||||||
|
return openBoltDB(path, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
func python(fn pythonFunc, cfg pythonConfig) cobraFunc {
|
func python(fn pythonFunc, cfg pythonConfig) cobraFunc {
|
||||||
return func(cmd *cobra.Command, args []string) {
|
return func(cmd *cobra.Command, args []string) {
|
||||||
data := pythonData{hadDB: true}
|
data, db := openDB(getParam(cmd.Flags(), "database"), cfg)
|
||||||
|
|
||||||
path := getParam(cmd.Flags(), "database")
|
|
||||||
exists, err := dbExists(path)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
} else if exists && cfg.noDB {
|
|
||||||
log.Fatal(path + " already exists")
|
|
||||||
} else if !exists && !cfg.noDB && !cfg.allowNoDB {
|
|
||||||
log.Fatal(path + " does not exist. Please run 'filebrowser config init' first.")
|
|
||||||
}
|
|
||||||
|
|
||||||
data.hadDB = exists
|
|
||||||
db, err := storm.Open(path)
|
|
||||||
checkErr(err)
|
|
||||||
defer db.Close()
|
defer db.Close()
|
||||||
data.store, err = bolt.NewStorage(db)
|
|
||||||
checkErr(err)
|
|
||||||
fn(cmd, args, data)
|
fn(cmd, args, data)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
3
go.mod
3
go.mod
@ -7,11 +7,14 @@ require (
|
|||||||
github.com/disintegration/imaging v1.6.2
|
github.com/disintegration/imaging v1.6.2
|
||||||
github.com/dsoprea/go-exif/v3 v3.0.0-20201216222538-db167117f483
|
github.com/dsoprea/go-exif/v3 v3.0.0-20201216222538-db167117f483
|
||||||
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568
|
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568
|
||||||
|
github.com/go-sql-driver/mysql v1.6.0
|
||||||
github.com/golang-jwt/jwt/v4 v4.4.3
|
github.com/golang-jwt/jwt/v4 v4.4.3
|
||||||
github.com/gorilla/mux v1.8.0
|
github.com/gorilla/mux v1.8.0
|
||||||
github.com/gorilla/websocket v1.5.0
|
github.com/gorilla/websocket v1.5.0
|
||||||
|
github.com/lib/pq v1.10.9
|
||||||
github.com/maruel/natural v1.1.0
|
github.com/maruel/natural v1.1.0
|
||||||
github.com/marusama/semaphore/v2 v2.5.0
|
github.com/marusama/semaphore/v2 v2.5.0
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.17
|
||||||
github.com/mholt/archiver/v3 v3.5.1
|
github.com/mholt/archiver/v3 v3.5.1
|
||||||
github.com/mitchellh/go-homedir v1.1.0
|
github.com/mitchellh/go-homedir v1.1.0
|
||||||
github.com/pelletier/go-toml/v2 v2.0.6
|
github.com/pelletier/go-toml/v2 v2.0.6
|
||||||
|
|||||||
6
go.sum
6
go.sum
@ -95,6 +95,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2
|
|||||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||||
|
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
|
||||||
|
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||||
github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU=
|
github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU=
|
||||||
github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||||
github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI=
|
github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI=
|
||||||
@ -196,6 +198,8 @@ github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3x
|
|||||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||||
|
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||||
|
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||||
@ -204,6 +208,8 @@ github.com/maruel/natural v1.1.0 h1:2z1NgP/Vae+gYrtC0VuvrTJ6U35OuyUqDdfluLqMWuQ=
|
|||||||
github.com/maruel/natural v1.1.0/go.mod h1:eFVhYCcUOfZFxXoDZam8Ktya72wa79fNC3lc/leA0DQ=
|
github.com/maruel/natural v1.1.0/go.mod h1:eFVhYCcUOfZFxXoDZam8Ktya72wa79fNC3lc/leA0DQ=
|
||||||
github.com/marusama/semaphore/v2 v2.5.0 h1:o/1QJD9DBYOWRnDhPwDVAXQn6mQYD0gZaS1Tpx6DJGM=
|
github.com/marusama/semaphore/v2 v2.5.0 h1:o/1QJD9DBYOWRnDhPwDVAXQn6mQYD0gZaS1Tpx6DJGM=
|
||||||
github.com/marusama/semaphore/v2 v2.5.0/go.mod h1:z9nMiNUekt/LTpTUQdpp+4sJeYqUGpwMHfW0Z8V8fnQ=
|
github.com/marusama/semaphore/v2 v2.5.0/go.mod h1:z9nMiNUekt/LTpTUQdpp+4sJeYqUGpwMHfW0Z8V8fnQ=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||||
github.com/mholt/archiver/v3 v3.5.1 h1:rDjOBX9JSF5BvoJGvjqK479aL70qh9DIpZCl+k7Clwo=
|
github.com/mholt/archiver/v3 v3.5.1 h1:rDjOBX9JSF5BvoJGvjqK479aL70qh9DIpZCl+k7Clwo=
|
||||||
github.com/mholt/archiver/v3 v3.5.1/go.mod h1:e3dqJ7H78uzsRSEACH1joayhuSyhnonssnDhppzS1L4=
|
github.com/mholt/archiver/v3 v3.5.1/go.mod h1:e3dqJ7H78uzsRSEACH1joayhuSyhnonssnDhppzS1L4=
|
||||||
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
|
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
|
||||||
|
|||||||
64
scripts/test-sql.sh
Executable file
64
scripts/test-sql.sh
Executable file
@ -0,0 +1,64 @@
|
|||||||
|
#!/bin/bash
|
||||||
|
|
||||||
|
# stop_docker name
|
||||||
|
stop_docker() {
|
||||||
|
sudo docker stop $1 1>/dev/null 2>&1
|
||||||
|
sudo docker rm $1 1>/dev/null 2>&1
|
||||||
|
}
|
||||||
|
|
||||||
|
# start_mysql name password port
|
||||||
|
start_mysql() {
|
||||||
|
stop_docker $1
|
||||||
|
sudo docker run --rm --name $1 -e MYSQL_ROOT_PASSWORD=$2 -p $3:3306 -d mysql
|
||||||
|
}
|
||||||
|
|
||||||
|
# start_postgres name password port
|
||||||
|
start_postgres() {
|
||||||
|
stop_docker $1
|
||||||
|
sudo docker run --rm --name $1 -e POSTGRES_PASSWORD=$2 -p $3:5432 -d postgres
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
test_sqlite() {
|
||||||
|
rm -f test.db
|
||||||
|
./filebrowser -a 0.0.0.0 -d sqlite3://test.db
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
test_postgres() {
|
||||||
|
start_postgres test-postgres postgres 5433
|
||||||
|
sleep 30
|
||||||
|
./filebrowser -a 0.0.0.0 -d postgres://postgres:postgres@127.0.0.1:5433/postgres?sslmode=disable
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
test_mysql() {
|
||||||
|
start_mysql test-mysql root 3307
|
||||||
|
sleep 60
|
||||||
|
./filebrowser -a 0.0.0.0 -d 'mysql://root:root@127.0.0.1:3307/mysql'
|
||||||
|
}
|
||||||
|
|
||||||
|
help() {
|
||||||
|
echo "USAGE: $0 sqlite|mysql|postgres"
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
if (( $# == 0 )); then
|
||||||
|
help
|
||||||
|
fi
|
||||||
|
|
||||||
|
case $1 in
|
||||||
|
sqlite)
|
||||||
|
test_sqlite
|
||||||
|
;;
|
||||||
|
mysql)
|
||||||
|
test_mysql
|
||||||
|
;;
|
||||||
|
postgres)
|
||||||
|
test_postgres
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
help
|
||||||
|
esac
|
||||||
|
|
||||||
|
|
||||||
47
storage/sql/auth.go
Normal file
47
storage/sql/auth.go
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
package sql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/filebrowser/filebrowser/v2/auth"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/errors"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/settings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type authBackend struct {
|
||||||
|
db *sql.DB
|
||||||
|
dbType string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s authBackend) Get(t settings.AuthMethod) (auth.Auther, error) {
|
||||||
|
var auther auth.Auther
|
||||||
|
|
||||||
|
switch t {
|
||||||
|
case auth.MethodJSONAuth:
|
||||||
|
auther = &auth.JSONAuth{}
|
||||||
|
case auth.MethodProxyAuth:
|
||||||
|
auther = &auth.ProxyAuth{}
|
||||||
|
case auth.MethodHookAuth:
|
||||||
|
auther = &auth.HookAuth{}
|
||||||
|
case auth.MethodNoAuth:
|
||||||
|
auther = &auth.NoAuth{}
|
||||||
|
default:
|
||||||
|
fmt.Println("ERROR: unknown auth method " + t)
|
||||||
|
return nil, errors.ErrInvalidAuthMethod
|
||||||
|
}
|
||||||
|
return auther, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s authBackend) Save(a auth.Auther) error {
|
||||||
|
val, err := json.Marshal(a)
|
||||||
|
if checkError(err, "Fail to save auth.Auther") {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return SetSetting(s.db, s.dbType, "auther", string(val))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAuthBackend(db *sql.DB, dbType string) authBackend {
|
||||||
|
return authBackend{db: db, dbType: dbType}
|
||||||
|
}
|
||||||
21
storage/sql/config.go
Normal file
21
storage/sql/config.go
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
package sql
|
||||||
|
|
||||||
|
import "os"
|
||||||
|
|
||||||
|
var SettingsTable = "fb_settings"
|
||||||
|
var UsersTable = "fb_users"
|
||||||
|
var SharesTable = "fb_shares"
|
||||||
|
|
||||||
|
func getEnv(key string, defaultValue string) string {
|
||||||
|
val := os.Getenv(key)
|
||||||
|
if len(val) == 0 {
|
||||||
|
return defaultValue
|
||||||
|
}
|
||||||
|
return val
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
SettingsTable = getEnv("FILEBROWSER_SETTINGS_TABLE", SettingsTable)
|
||||||
|
UsersTable = getEnv("FILEBROWSER_USERS_TABLE", UsersTable)
|
||||||
|
SharesTable = getEnv("FILEBROWSER_SHARES_TABLE", SharesTable)
|
||||||
|
}
|
||||||
160
storage/sql/server.go
Normal file
160
storage/sql/server.go
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
package sql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/filebrowser/filebrowser/v2/settings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var defaultServer = settings.Server{
|
||||||
|
Port: "8080",
|
||||||
|
Log: "stdout",
|
||||||
|
EnableThumbnails: false,
|
||||||
|
ResizePreview: false,
|
||||||
|
EnableExec: false,
|
||||||
|
TypeDetectionByHeader: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneServer(server settings.Server) settings.Server {
|
||||||
|
data, err := json.Marshal(server)
|
||||||
|
s := settings.Server{}
|
||||||
|
if checkError(err, "Fail to clone settings.Server") {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
err = json.Unmarshal(data, &s)
|
||||||
|
checkError(err, "Fail to decode for settings.Server")
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s settingsBackend) GetServer() (*settings.Server, error) {
|
||||||
|
sql := fmt.Sprintf("select %s, value from %s", quoteName(s.dbType, "key"), quoteName(s.dbType, SettingsTable))
|
||||||
|
rows, err := s.db.Query(sql)
|
||||||
|
if checkError(err, "Fail to Query for GetServer") {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
server := cloneServer(defaultServer)
|
||||||
|
key := ""
|
||||||
|
value := ""
|
||||||
|
|
||||||
|
for rows.Next() {
|
||||||
|
err = rows.Scan(&key, &value)
|
||||||
|
if checkError(err, "Fail to query settings.Settings") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if key == "Root" {
|
||||||
|
server.Root = value
|
||||||
|
} else if key == "BaseURL" {
|
||||||
|
server.BaseURL = value
|
||||||
|
} else if key == "Socket" {
|
||||||
|
server.Socket = value
|
||||||
|
} else if key == "TLSKey" {
|
||||||
|
server.TLSKey = value
|
||||||
|
} else if key == "TLSCert" {
|
||||||
|
server.TLSCert = value
|
||||||
|
} else if key == "Port" {
|
||||||
|
server.Port = value
|
||||||
|
} else if key == "Address" {
|
||||||
|
server.Address = value
|
||||||
|
} else if key == "Log" {
|
||||||
|
server.Log = value
|
||||||
|
} else if key == "EnableThumbnails" {
|
||||||
|
server.EnableThumbnails = boolFromString(value)
|
||||||
|
} else if key == "ResizePreview" {
|
||||||
|
server.ResizePreview = boolFromString(value)
|
||||||
|
} else if key == "EnableExec" {
|
||||||
|
server.EnableExec = boolFromString(value)
|
||||||
|
} else if key == "TypeDetectionByHeader" {
|
||||||
|
server.TypeDetectionByHeader = boolFromString(value)
|
||||||
|
} else if key == "AuthHook" {
|
||||||
|
server.AuthHook = value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &server, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s settingsBackend) SaveServer(ss *settings.Server) error {
|
||||||
|
fields := []string{"Root", "BaseURL", "Socket", "TLSKey", "TLSCert", "Port", "Address", "Log", "EnableThumbnails", "ResizePreview", "EnableExec", "TypeDetectionByHeader", "AuthHook"}
|
||||||
|
values := []string{
|
||||||
|
ss.Root,
|
||||||
|
ss.BaseURL,
|
||||||
|
ss.Socket,
|
||||||
|
ss.TLSKey,
|
||||||
|
ss.TLSCert,
|
||||||
|
ss.Port,
|
||||||
|
ss.Address,
|
||||||
|
ss.Log,
|
||||||
|
boolToString(ss.EnableThumbnails),
|
||||||
|
boolToString(ss.ResizePreview),
|
||||||
|
boolToString(ss.EnableExec),
|
||||||
|
boolToString(ss.TypeDetectionByHeader),
|
||||||
|
ss.AuthHook}
|
||||||
|
tx, err := s.db.Begin()
|
||||||
|
if checkError(err, "Fail to begin db transaction") {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
table := quoteName(s.dbType, SettingsTable)
|
||||||
|
keyName := quoteName(s.dbType, "key")
|
||||||
|
p1 := placeHolder(s.dbType, 1)
|
||||||
|
p2 := placeHolder(s.dbType, 2)
|
||||||
|
|
||||||
|
Insert := func(key string, value string) bool {
|
||||||
|
insertSql := fmt.Sprintf("INSERT INTO %s (%s, value) VALUES(%s,%s)", table, keyName, p1, p2)
|
||||||
|
stmt, err := s.db.Prepare(insertSql)
|
||||||
|
defer stmt.Close()
|
||||||
|
if checkError(err, "Fail to prepare statement") {
|
||||||
|
tx.Rollback()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, err = stmt.Exec(key, value)
|
||||||
|
if checkError(err, "Fail to insert field "+key+" of settings.Server") {
|
||||||
|
tx.Rollback()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
Update := func(key string, value string) bool {
|
||||||
|
updateSql := fmt.Sprintf("UPDATE %s SET value=%s WHERE %s=%s", table, p1, keyName, p2)
|
||||||
|
stmt, err := s.db.Prepare(updateSql)
|
||||||
|
defer stmt.Close()
|
||||||
|
if checkError(err, "Fail to prepare statement") {
|
||||||
|
tx.Rollback()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_, err = stmt.Exec(value, key)
|
||||||
|
if checkError(err, "Fail to update field "+key+" of settings.Server") {
|
||||||
|
tx.Rollback()
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
Exist := func(key string) bool {
|
||||||
|
querySql := fmt.Sprintf("SELECT count(*) FROM %s WHERE %s=%s", table, keyName, p1)
|
||||||
|
row := s.db.QueryRow(querySql, key)
|
||||||
|
count := 0
|
||||||
|
err := row.Scan(&count)
|
||||||
|
if checkError(err, "Fail to Query "+key+" for GetServer") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return count == 1
|
||||||
|
}
|
||||||
|
InsertOrUpdate := func(key string, value string) bool {
|
||||||
|
if Exist(key) {
|
||||||
|
return Update(key, value)
|
||||||
|
} else {
|
||||||
|
return Insert(key, value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, field := range fields {
|
||||||
|
if !InsertOrUpdate(field, values[i]) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = tx.Commit()
|
||||||
|
if checkError(err, "Fail to commit") {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
368
storage/sql/settings.go
Normal file
368
storage/sql/settings.go
Normal file
@ -0,0 +1,368 @@
|
|||||||
|
package sql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/filebrowser/filebrowser/v2/auth"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/files"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/rules"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/settings"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/users"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
}
|
||||||
|
|
||||||
|
type settingsBackend struct {
|
||||||
|
db *sql.DB
|
||||||
|
dbType string
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitSettingsTable(db *sql.DB, dbType string) error {
|
||||||
|
sql := fmt.Sprintf("create table if not exists %s (%s varchar(128) primary key, value text);", quoteName(dbType, SettingsTable), quoteName(dbType, "key"))
|
||||||
|
_, err := db.Exec(sql)
|
||||||
|
checkError(err, "Fail to create table settings")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func bytesToString(data []byte) string {
|
||||||
|
return base64.RawStdEncoding.EncodeToString(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func bytesFromString(s string) ([]byte, error) {
|
||||||
|
return base64.RawStdEncoding.DecodeString(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
func userDefaultsFromString(s string) settings.UserDefaults {
|
||||||
|
if s == "" {
|
||||||
|
return settings.UserDefaults{}
|
||||||
|
}
|
||||||
|
userDefaults := settings.UserDefaults{}
|
||||||
|
err := json.Unmarshal([]byte(s), &userDefaults)
|
||||||
|
checkError(err, "Fail to parse settings.UserDefaults")
|
||||||
|
return userDefaults
|
||||||
|
}
|
||||||
|
|
||||||
|
func userDefaultsToString(d settings.UserDefaults) string {
|
||||||
|
data, err := json.Marshal(d)
|
||||||
|
if checkError(err, "Fail to stringify settings.UserDefaults") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func brandingFromString(s string) settings.Branding {
|
||||||
|
if s == "" {
|
||||||
|
return settings.Branding{}
|
||||||
|
}
|
||||||
|
branding := settings.Branding{}
|
||||||
|
err := json.Unmarshal([]byte(s), &branding)
|
||||||
|
checkError(err, "Fail to parse settings.Branding")
|
||||||
|
return branding
|
||||||
|
}
|
||||||
|
|
||||||
|
func brandingToString(s settings.Branding) string {
|
||||||
|
data, err := json.Marshal(s)
|
||||||
|
if checkError(err, "Fail to jsonify settings.Branding") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func commandsToString(c map[string][]string) string {
|
||||||
|
data, err := json.Marshal(c)
|
||||||
|
if checkError(err, "Fail to jsonify commands") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func commandsFromString(s string) map[string][]string {
|
||||||
|
c := make(map[string][]string)
|
||||||
|
if s == "" {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
err := json.Unmarshal([]byte(s), &c)
|
||||||
|
checkError(err, "Fail to parse commands")
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringsFromString(s string) []string {
|
||||||
|
c := make([]string, 0)
|
||||||
|
if s == "" {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
err := json.Unmarshal([]byte(s), &c)
|
||||||
|
checkError(err, "Fail to parse []string")
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringsToString(c []string) string {
|
||||||
|
data, err := json.Marshal(c)
|
||||||
|
if checkError(err, "Fail to jsonify strings") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolToInt(b bool) int {
|
||||||
|
if b {
|
||||||
|
return 1
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolFromInt(i int) bool {
|
||||||
|
if i == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolFromString(s string) bool {
|
||||||
|
if s == "0" || s == "" || s == "f" || s == "F" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolToString(b bool) string {
|
||||||
|
if b {
|
||||||
|
return "1"
|
||||||
|
}
|
||||||
|
return "0"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s settingsBackend) Get() (*settings.Settings, error) {
|
||||||
|
sql := fmt.Sprintf("select %s, value from %s;", quoteName(s.dbType, "key"), quoteName(s.dbType, SettingsTable))
|
||||||
|
rows, err := s.db.Query(sql)
|
||||||
|
if checkError(err, "Fail to Query settings.Settings") {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
key := ""
|
||||||
|
value := ""
|
||||||
|
settings1 := cloneSettings(defaultSettings)
|
||||||
|
for rows.Next() {
|
||||||
|
err = rows.Scan(&key, &value)
|
||||||
|
checkError(err, "Fail to query settings.Settings")
|
||||||
|
if key == "Key" {
|
||||||
|
val, err := bytesFromString(value)
|
||||||
|
if !checkError(err, "Fail to parse []byte from string") {
|
||||||
|
settings1.Key = val
|
||||||
|
}
|
||||||
|
} else if key == "Signup" {
|
||||||
|
settings1.Signup = boolFromString(value)
|
||||||
|
} else if key == "CreateUserDir" {
|
||||||
|
settings1.CreateUserDir = boolFromString(value)
|
||||||
|
} else if key == "UserHomeBasePath" {
|
||||||
|
settings1.UserHomeBasePath = value
|
||||||
|
} else if key == "Defaults" {
|
||||||
|
settings1.Defaults = userDefaultsFromString(value)
|
||||||
|
} else if key == "AuthMethod" {
|
||||||
|
settings1.AuthMethod = settings.AuthMethod(value)
|
||||||
|
} else if key == "Branding" {
|
||||||
|
settings1.Branding = brandingFromString(value)
|
||||||
|
} else if key == "Commands" {
|
||||||
|
settings1.Commands = commandsFromString(value)
|
||||||
|
} else if key == "Shell" {
|
||||||
|
settings1.Shell = stringsFromString(value)
|
||||||
|
} else if key == "Rules" {
|
||||||
|
settings1.Rules = rulesFromString(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(settings1.Key) == 0 {
|
||||||
|
fmt.Println("The tables may not exist. Please run 'filebrowser config init' first")
|
||||||
|
return &settings1, nil
|
||||||
|
}
|
||||||
|
return &settings1, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s settingsBackend) Save(ss *settings.Settings) error {
|
||||||
|
fields := []string{"Key", "Signup", "CreateUserDir", "UserHomeBasePath", "Defaults", "AuthMethod", "Branding", "Commands", "Shell", "Rules"}
|
||||||
|
values := []string{
|
||||||
|
bytesToString(ss.Key),
|
||||||
|
boolToString(ss.Signup),
|
||||||
|
boolToString(ss.CreateUserDir),
|
||||||
|
ss.UserHomeBasePath,
|
||||||
|
userDefaultsToString(ss.Defaults),
|
||||||
|
string(ss.AuthMethod),
|
||||||
|
brandingToString(ss.Branding),
|
||||||
|
commandsToString(ss.Commands),
|
||||||
|
stringsToString(ss.Shell),
|
||||||
|
RulesToString(ss.Rules),
|
||||||
|
}
|
||||||
|
tx, err := s.db.Begin()
|
||||||
|
if checkError(err, "Fail to begin db transaction") {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
table := quoteName(s.dbType, SettingsTable)
|
||||||
|
k := quoteName(s.dbType, "key")
|
||||||
|
p1 := placeHolder(s.dbType, 1)
|
||||||
|
p2 := placeHolder(s.dbType, 2)
|
||||||
|
for i, field := range fields {
|
||||||
|
exists := ContainKey(s.db, s.dbType, field)
|
||||||
|
sql := fmt.Sprintf("INSERT INTO %s (value, %s) VALUES(%s,%s);", table, k, p1, p2)
|
||||||
|
if exists {
|
||||||
|
sql = fmt.Sprintf("UPDATE %s set value = %s where %s = %s;", table, p1, k, p2)
|
||||||
|
}
|
||||||
|
stmt, err := s.db.Prepare(sql)
|
||||||
|
defer stmt.Close()
|
||||||
|
if checkError(err, "Fail to prepare statement") {
|
||||||
|
tx.Rollback()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
_, err = stmt.Exec(values[i], field)
|
||||||
|
if checkError(err, "Fail to insert field "+field+" of settings") {
|
||||||
|
tx.Rollback()
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
err = tx.Commit()
|
||||||
|
if checkError(err, "Fail to commit") {
|
||||||
|
tx.Rollback()
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
var defaultSettings = settings.Settings{
|
||||||
|
Key: []byte(""),
|
||||||
|
Signup: false,
|
||||||
|
CreateUserDir: false,
|
||||||
|
UserHomeBasePath: "/users",
|
||||||
|
Defaults: settings.UserDefaults{
|
||||||
|
Scope: ".",
|
||||||
|
Locale: "en",
|
||||||
|
ViewMode: "mosaic",
|
||||||
|
SingleClick: false,
|
||||||
|
Sorting: files.Sorting{
|
||||||
|
By: "",
|
||||||
|
Asc: false,
|
||||||
|
},
|
||||||
|
Perm: users.Permissions{
|
||||||
|
Admin: false,
|
||||||
|
Execute: true,
|
||||||
|
Create: true,
|
||||||
|
Rename: true,
|
||||||
|
Modify: true,
|
||||||
|
Delete: true,
|
||||||
|
Share: true,
|
||||||
|
Download: true,
|
||||||
|
},
|
||||||
|
Commands: make([]string, 0),
|
||||||
|
HideDotfiles: false,
|
||||||
|
DateFormat: false,
|
||||||
|
},
|
||||||
|
AuthMethod: auth.MethodJSONAuth,
|
||||||
|
Branding: settings.Branding{
|
||||||
|
Name: "",
|
||||||
|
DisableExternal: false,
|
||||||
|
Files: "",
|
||||||
|
Theme: "",
|
||||||
|
Color: "",
|
||||||
|
},
|
||||||
|
Commands: make(map[string][]string),
|
||||||
|
Shell: make([]string, 0),
|
||||||
|
Rules: make([]rules.Rule, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneSettings(s settings.Settings) settings.Settings {
|
||||||
|
data, err := json.Marshal(s)
|
||||||
|
s1 := settings.Settings{}
|
||||||
|
if checkError(err, "Fail to clone settings.Settings") {
|
||||||
|
return s1
|
||||||
|
}
|
||||||
|
json.Unmarshal(data, &s1)
|
||||||
|
return s1
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetSetting(db *sql.DB, dbType string, key string, value string) error {
|
||||||
|
t := quoteName(dbType, SettingsTable)
|
||||||
|
k := quoteName(dbType, "key")
|
||||||
|
sql := fmt.Sprintf("select count(%s) from %s where %s = '%s';", k, t, k, key)
|
||||||
|
count := 0
|
||||||
|
err := db.QueryRow(sql).Scan(&count)
|
||||||
|
if checkError(err, "Fail to QueryRow for key="+key) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if count == 0 {
|
||||||
|
return addSetting(db, dbType, key, value)
|
||||||
|
}
|
||||||
|
return updateSetting(db, dbType, key, value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetSetting(db *sql.DB, dbType string, key string) string {
|
||||||
|
sql := fmt.Sprintf("select value from %s where %s = '%s';", quoteName(dbType, SettingsTable), quoteName(dbType, "key"), key)
|
||||||
|
value := ""
|
||||||
|
err := db.QueryRow(sql).Scan(&value)
|
||||||
|
if checkError(err, "") {
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
return value
|
||||||
|
}
|
||||||
|
|
||||||
|
func addSetting(db *sql.DB, dbType string, key string, value string) error {
|
||||||
|
table := quoteName(dbType, SettingsTable)
|
||||||
|
k := quoteName(dbType, "key")
|
||||||
|
p1 := placeHolder(dbType, 1)
|
||||||
|
p2 := placeHolder(dbType, 2)
|
||||||
|
sql := fmt.Sprintf("insert into %s (%s, value) values(%s, %s);", table, k, p1, p2)
|
||||||
|
stmt, err := db.Prepare(sql)
|
||||||
|
if checkError(err, "Fail to prepare sql") {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = stmt.Exec(key, value)
|
||||||
|
checkError(err, "Fail to add settings")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func updateSetting(db *sql.DB, dbType string, key string, value string) error {
|
||||||
|
sql := fmt.Sprintf(
|
||||||
|
"update %s set value = %s where %s = %s;",
|
||||||
|
quoteName(dbType, SettingsTable),
|
||||||
|
placeHolder(dbType, 1),
|
||||||
|
quoteName(dbType, "key"),
|
||||||
|
placeHolder(dbType, 2),
|
||||||
|
)
|
||||||
|
stmt, err := db.Prepare(sql)
|
||||||
|
if checkError(err, "Fail to prepare sql") {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = stmt.Exec(key, value)
|
||||||
|
checkError(err, "Fail to updateSetting")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func HadSetting(db *sql.DB) bool {
|
||||||
|
dbType, err := GetDBType(db)
|
||||||
|
if checkError(err, "Fail to get db type") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
key := GetSetting(db, dbType, "Key")
|
||||||
|
if key == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func ContainKey(db *sql.DB, dbType string, key string) bool {
|
||||||
|
sql := fmt.Sprintf("select value from %s where %s = '%s';", quoteName(dbType, SettingsTable), quoteName(dbType, "key"), key)
|
||||||
|
value := ""
|
||||||
|
err := db.QueryRow(sql).Scan(&value)
|
||||||
|
if checkError(err, "") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func HadSettingOfKey(db *sql.DB, dbType string, key string) bool {
|
||||||
|
return GetSetting(db, dbType, "Key") == key
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSettingsBackend(db *sql.DB, dbType string) settingsBackend {
|
||||||
|
InitSettingsTable(db, dbType)
|
||||||
|
return settingsBackend{db: db, dbType: dbType}
|
||||||
|
}
|
||||||
106
storage/sql/share.go
Normal file
106
storage/sql/share.go
Normal file
@ -0,0 +1,106 @@
|
|||||||
|
package sql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/filebrowser/filebrowser/v2/share"
|
||||||
|
)
|
||||||
|
|
||||||
|
type shareBackend struct {
|
||||||
|
db *sql.DB
|
||||||
|
dbType string
|
||||||
|
}
|
||||||
|
|
||||||
|
type linkRecord interface {
|
||||||
|
Scan(dest ...interface{}) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitSharesTable(db *sql.DB, dbType string) error {
|
||||||
|
sql := fmt.Sprintf("create table if not exists %s (hash text, path text, userid integer, expire integer, passwordhash text, token text)", quoteName(dbType, SharesTable))
|
||||||
|
_, err := db.Exec(sql)
|
||||||
|
checkError(err, "Fail to InitSharesTable")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseLink(row linkRecord) (*share.Link, error) {
|
||||||
|
path := ""
|
||||||
|
hash := ""
|
||||||
|
userid := uint(0)
|
||||||
|
expire := int64(0)
|
||||||
|
passwordhash := ""
|
||||||
|
token := ""
|
||||||
|
err := row.Scan(&path, &hash, &userid, &expire, &passwordhash, &token)
|
||||||
|
if checkError(err, "Fail to parse record for share.Link") {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
link := share.Link{}
|
||||||
|
link.Path = path
|
||||||
|
link.Hash = hash
|
||||||
|
link.UserID = userid
|
||||||
|
link.Expire = expire
|
||||||
|
link.PasswordHash = passwordhash
|
||||||
|
link.Token = token
|
||||||
|
return &link, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func queryLinks(db *sql.DB, dbType string, condition string) ([]*share.Link, error) {
|
||||||
|
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s", quoteName(dbType, SharesTable))
|
||||||
|
if len(condition) > 0 {
|
||||||
|
sql = sql + " where " + condition
|
||||||
|
}
|
||||||
|
rows, err := db.Query(sql)
|
||||||
|
if checkError(err, "Fail to Query links") {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var links []*share.Link = []*share.Link{}
|
||||||
|
for rows.Next() {
|
||||||
|
link, err := parseLink(rows)
|
||||||
|
if checkError(err, "Fail to parse record for share.Link") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
links = append(links, link)
|
||||||
|
}
|
||||||
|
return links, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s shareBackend) All() ([]*share.Link, error) {
|
||||||
|
return queryLinks(s.db, s.dbType, "")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s shareBackend) FindByUserID(id uint) ([]*share.Link, error) {
|
||||||
|
condition := fmt.Sprintf("userid=%d", id)
|
||||||
|
return queryLinks(s.db, s.dbType, condition)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s shareBackend) GetByHash(hash string) (*share.Link, error) {
|
||||||
|
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s where hash='%s'", quoteName(s.dbType, SharesTable), hash)
|
||||||
|
return parseLink(s.db.QueryRow(sql))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s shareBackend) GetPermanent(path string, id uint) (*share.Link, error) {
|
||||||
|
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s where path='%s' and userid=%d", quoteName(s.dbType, SharesTable), path, id)
|
||||||
|
return parseLink(s.db.QueryRow(sql))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s shareBackend) Gets(path string, id uint) ([]*share.Link, error) {
|
||||||
|
condition := fmt.Sprintf("userid=%d and path='%s'", id, path)
|
||||||
|
return queryLinks(s.db, s.dbType, condition)
|
||||||
|
}
|
||||||
|
func (s shareBackend) Save(l *share.Link) error {
|
||||||
|
sql := fmt.Sprintf("insert into %s (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", quoteName(s.dbType, SharesTable), l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token)
|
||||||
|
_, err := s.db.Exec(sql)
|
||||||
|
checkError(err, "Fail to Save share")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
func (s shareBackend) Delete(hash string) error {
|
||||||
|
sql := fmt.Sprintf("DELETE FROM %s WHERE hash='%s'", quoteName(s.dbType, SharesTable), hash)
|
||||||
|
_, err := s.db.Exec(sql)
|
||||||
|
checkError(err, "Fail to Delete share")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func newShareBackend(db *sql.DB, dbType string) shareBackend {
|
||||||
|
InitSharesTable(db, dbType)
|
||||||
|
return shareBackend{db: db, dbType: dbType}
|
||||||
|
}
|
||||||
195
storage/sql/sql.go
Normal file
195
storage/sql/sql.go
Normal file
@ -0,0 +1,195 @@
|
|||||||
|
package sql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/filebrowser/filebrowser/v2/auth"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/settings"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/share"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/storage"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/users"
|
||||||
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
}
|
||||||
|
|
||||||
|
type DBConnectionRecord struct {
|
||||||
|
db *sql.DB
|
||||||
|
dbType string
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
dbRecords map[string]*DBConnectionRecord = map[string]*DBConnectionRecord{}
|
||||||
|
)
|
||||||
|
|
||||||
|
// GetDBType used to get the driver type of a sql.DB
|
||||||
|
// It is based on existing dbRecords
|
||||||
|
// All sql.DB should opened by OpenDB
|
||||||
|
func GetDBType(db *sql.DB) (string, error) {
|
||||||
|
for _, record := range dbRecords {
|
||||||
|
if record.db == db {
|
||||||
|
return record.dbType, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return "", errors.New("No such database open by this module")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNameQuote(dbType string) string {
|
||||||
|
if dbType == "mysql" {
|
||||||
|
return "`"
|
||||||
|
}
|
||||||
|
return "\""
|
||||||
|
}
|
||||||
|
|
||||||
|
// for mysql, it is “
|
||||||
|
// for postgres and sqlite, it is ""
|
||||||
|
func quoteName(dbType string, name string) string {
|
||||||
|
q := getNameQuote(dbType)
|
||||||
|
return q + name + q
|
||||||
|
}
|
||||||
|
|
||||||
|
// placeholder for sql stmt
|
||||||
|
// for postgres, it is $1, $2, $3...
|
||||||
|
// for mysql and sqlite3, it is ?,?,?...
|
||||||
|
func placeHolder(dbType string, index int) string {
|
||||||
|
if index <= 0 {
|
||||||
|
panic("the placeholder index should >= 1")
|
||||||
|
}
|
||||||
|
if dbType == "postgres" || dbType == "postgresql" {
|
||||||
|
return fmt.Sprintf("$%d", index)
|
||||||
|
}
|
||||||
|
return "?"
|
||||||
|
}
|
||||||
|
|
||||||
|
func IsDBPath(path string) bool {
|
||||||
|
prefixes := []string{"sqlite3", "postgres", "postgresql", "mysql"}
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
if strings.HasPrefix(path, prefix+"://") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func OpenDB(path string) (*sql.DB, error) {
|
||||||
|
if val, ok := dbRecords[path]; ok {
|
||||||
|
return val.db, nil
|
||||||
|
}
|
||||||
|
prefixes := []string{"sqlite3", "postgres", "postgresql", "mysql"}
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
if strings.HasPrefix(path, prefix) {
|
||||||
|
db, err := connectDB(prefix, path)
|
||||||
|
if !checkError(err, "Fail to connect database "+path) {
|
||||||
|
dbRecords[path] = &DBConnectionRecord{db: db, dbType: prefix, path: path}
|
||||||
|
}
|
||||||
|
return db, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, errors.New("Unsupported db scheme")
|
||||||
|
}
|
||||||
|
|
||||||
|
type DatabaseResource struct {
|
||||||
|
scheme string
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
host string
|
||||||
|
port int
|
||||||
|
database string
|
||||||
|
}
|
||||||
|
|
||||||
|
func ParseDatabasePath(path string) (*DatabaseResource, error) {
|
||||||
|
pattern := "^(([a-zA-Z0-9]+)://)?(([^:]+)(:(.*))?@)?([a-zA-Z0-9_.]+)(:([0-9]+))?(/([a-zA-Z0-9_-]+))?$"
|
||||||
|
reg, err := regexp.Compile(pattern)
|
||||||
|
if checkError(err, "Fail to compile regexp") {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
matches := reg.FindAllStringSubmatch(path, -1)
|
||||||
|
if matches == nil || len(matches) == 0 {
|
||||||
|
return nil, errors.New("Fail to parse database")
|
||||||
|
}
|
||||||
|
r := DatabaseResource{}
|
||||||
|
r.scheme = matches[0][2]
|
||||||
|
r.username = matches[0][4]
|
||||||
|
r.password = matches[0][6]
|
||||||
|
r.host = matches[0][7]
|
||||||
|
if len(matches[0][9]) > 0 {
|
||||||
|
port, err := strconv.Atoi(matches[0][9])
|
||||||
|
if !checkError(err, "Fail to parse port") {
|
||||||
|
r.port = port
|
||||||
|
}
|
||||||
|
}
|
||||||
|
r.database = matches[0][11]
|
||||||
|
return &r, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// mysql://user:password@host:port/db => mysql://user:password@tcp(host:port)/db
|
||||||
|
func transformMysqlPath(path string) (string, error) {
|
||||||
|
r, err := ParseDatabasePath(path)
|
||||||
|
if checkError(err, "Fail to parse database path") {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
scheme := r.scheme
|
||||||
|
if len(scheme) == 0 {
|
||||||
|
scheme = "mysql"
|
||||||
|
}
|
||||||
|
credential := ""
|
||||||
|
if len(r.username) > 0 && len(r.password) > 0 {
|
||||||
|
credential = r.username + ":" + r.password + "@"
|
||||||
|
} else if len(r.username) > 0 {
|
||||||
|
credential = r.username + "@"
|
||||||
|
}
|
||||||
|
host := r.host
|
||||||
|
port := r.port
|
||||||
|
if port == 0 {
|
||||||
|
port = 3306
|
||||||
|
}
|
||||||
|
if len(r.database) == 0 {
|
||||||
|
return "", errors.New("no database found in path")
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s://%stcp(%s:%d)/%s", scheme, credential, host, port, r.database), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectDB(dbType string, path string) (*sql.DB, error) {
|
||||||
|
if dbType == "sqlite3" && strings.HasPrefix(path, "sqlite3://") {
|
||||||
|
path = strings.TrimPrefix(path, "sqlite3://")
|
||||||
|
} else if dbType == "mysql" && strings.HasPrefix(path, "mysql://") {
|
||||||
|
p, err := transformMysqlPath(path)
|
||||||
|
if checkError(err, "Fail to parse mysql path") {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
path = p
|
||||||
|
path = strings.TrimPrefix(path, "mysql://")
|
||||||
|
}
|
||||||
|
db, err := sql.Open(dbType, path)
|
||||||
|
if err == nil {
|
||||||
|
return db, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewStorage(db *sql.DB) (*storage.Storage, error) {
|
||||||
|
dbType, err := GetDBType(db)
|
||||||
|
checkError(err, "Fail to get database type, maybe this sql.DB is not opened by OpenDB")
|
||||||
|
|
||||||
|
userStore := users.NewStorage(newUsersBackend(db, dbType))
|
||||||
|
shareStore := share.NewStorage(newShareBackend(db, dbType))
|
||||||
|
settingsStore := settings.NewStorage(newSettingsBackend(db, dbType))
|
||||||
|
authStore := auth.NewStorage(newAuthBackend(db, dbType), userStore)
|
||||||
|
|
||||||
|
storage := &storage.Storage{
|
||||||
|
Auth: authStore,
|
||||||
|
Users: userStore,
|
||||||
|
Share: shareStore,
|
||||||
|
Settings: settingsStore,
|
||||||
|
}
|
||||||
|
return storage, nil
|
||||||
|
}
|
||||||
372
storage/sql/users.go
Normal file
372
storage/sql/users.go
Normal file
@ -0,0 +1,372 @@
|
|||||||
|
package sql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/filebrowser/filebrowser/v2/errors"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/files"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/rules"
|
||||||
|
"github.com/filebrowser/filebrowser/v2/users"
|
||||||
|
)
|
||||||
|
|
||||||
|
type usersBackend struct {
|
||||||
|
db *sql.DB
|
||||||
|
dbType string
|
||||||
|
}
|
||||||
|
|
||||||
|
func PermFromString(s string) users.Permissions {
|
||||||
|
var perm users.Permissions
|
||||||
|
if s == "" {
|
||||||
|
return perm
|
||||||
|
}
|
||||||
|
err := json.Unmarshal([]byte(s), &perm)
|
||||||
|
checkError(err, "Fail to parse perm from string")
|
||||||
|
return perm
|
||||||
|
}
|
||||||
|
|
||||||
|
func PermToString(perm users.Permissions) string {
|
||||||
|
data, err := json.Marshal(perm)
|
||||||
|
if checkError(err, "Fail to stringify users.Permissions") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func CommandsFromString(s string) []string {
|
||||||
|
if s == "" {
|
||||||
|
return make([]string, 0)
|
||||||
|
}
|
||||||
|
var commands []string
|
||||||
|
err := json.Unmarshal([]byte(s), &commands)
|
||||||
|
checkError(err, "Fail to parse users Commands")
|
||||||
|
return commands
|
||||||
|
}
|
||||||
|
|
||||||
|
func CommandsToString(commands []string) string {
|
||||||
|
data, err := json.Marshal(commands)
|
||||||
|
if checkError(err, "Fail to stringify users commands") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SortingFromString(s string) files.Sorting {
|
||||||
|
if s == "" {
|
||||||
|
return files.Sorting{}
|
||||||
|
}
|
||||||
|
var sorting files.Sorting
|
||||||
|
err := json.Unmarshal([]byte(s), &sorting)
|
||||||
|
checkError(err, "Fail to parse Sorting from string")
|
||||||
|
return sorting
|
||||||
|
}
|
||||||
|
|
||||||
|
func SortingToString(sorting files.Sorting) string {
|
||||||
|
data, err := json.Marshal(sorting)
|
||||||
|
if checkError(err, "Fail to stringify files.Sorting") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func rulesFromString(s string) []rules.Rule {
|
||||||
|
rules := make([]rules.Rule, 0)
|
||||||
|
if s == "" {
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
err := json.Unmarshal([]byte(s), &rules)
|
||||||
|
checkError(err, "Fail to parse Rules from string")
|
||||||
|
return rules
|
||||||
|
}
|
||||||
|
|
||||||
|
func RulesToString(rules []rules.Rule) string {
|
||||||
|
data, err := json.Marshal(rules)
|
||||||
|
if checkError(err, "Fail to stringify []rules.Rule") {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return string(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
var adminUser = createAdminUser()
|
||||||
|
|
||||||
|
func createAdminUser() users.User {
|
||||||
|
userDefault := defaultSettings.Defaults
|
||||||
|
user := users.User{}
|
||||||
|
user.Username = "admin"
|
||||||
|
user.Password = "admin"
|
||||||
|
user.Scope = userDefault.Scope
|
||||||
|
user.LockPassword = false
|
||||||
|
user.ViewMode = userDefault.ViewMode
|
||||||
|
user.Perm = users.Permissions{
|
||||||
|
Admin: true,
|
||||||
|
Execute: true,
|
||||||
|
Create: true,
|
||||||
|
Rename: true,
|
||||||
|
Modify: true,
|
||||||
|
Delete: true,
|
||||||
|
Share: true,
|
||||||
|
Download: true,
|
||||||
|
}
|
||||||
|
user.Commands = userDefault.Commands
|
||||||
|
user.Sorting = userDefault.Sorting
|
||||||
|
return user
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitUserTable(db *sql.DB, dbType string) error {
|
||||||
|
primaryKey := "integer primary key"
|
||||||
|
if dbType == "postgres" || dbType == "postgresql" {
|
||||||
|
primaryKey = "serial primary key"
|
||||||
|
} else if dbType == "mysql" {
|
||||||
|
primaryKey = "int unsigned primary key auto_increment"
|
||||||
|
}
|
||||||
|
sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id %s, username text, password text, scope text, locale text, lockpassword integer, viewmode text, perm text, commands text, sorting text, rules text, hidedotfiles integer, dateformat integer, singleclick integer);", quoteName(dbType, UsersTable), primaryKey)
|
||||||
|
_, err := db.Exec(sql)
|
||||||
|
checkError(err, "Fail to create users table")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUsersBackend(db *sql.DB, dbType string) usersBackend {
|
||||||
|
InitUserTable(db, dbType)
|
||||||
|
return usersBackend{db: db, dbType: dbType}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s usersBackend) IsPostgresql() bool {
|
||||||
|
return s.dbType == "postgres" || s.dbType == "postgresql"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s usersBackend) IsMysql() bool {
|
||||||
|
return s.dbType == "mysql"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s usersBackend) IsSqlite() bool {
|
||||||
|
return s.dbType == "sqlite3"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
|
||||||
|
columns := []string{"id", "username", "password", "scope", "locale", "lockpassword", "viewmode", "perm", "commands", "sorting", "rules", "hidedotfiles", "dateformat", "singleclick"}
|
||||||
|
columnsStr := strings.Join(columns, ",")
|
||||||
|
var conditionStr string
|
||||||
|
switch i.(type) {
|
||||||
|
case uint:
|
||||||
|
conditionStr = fmt.Sprintf("id=%v", i)
|
||||||
|
case int:
|
||||||
|
conditionStr = fmt.Sprintf("id=%v", i)
|
||||||
|
case string:
|
||||||
|
conditionStr = fmt.Sprintf("username='%v'", i)
|
||||||
|
default:
|
||||||
|
return nil, errors.ErrInvalidDataType
|
||||||
|
}
|
||||||
|
userID := uint(0)
|
||||||
|
username := ""
|
||||||
|
password := ""
|
||||||
|
scope := ""
|
||||||
|
locale := ""
|
||||||
|
lockpassword := false
|
||||||
|
var viewmode users.ViewMode = users.ListViewMode
|
||||||
|
perm := ""
|
||||||
|
commands := ""
|
||||||
|
sorting := ""
|
||||||
|
rules := ""
|
||||||
|
hidedotfiles := false
|
||||||
|
dateformat := false
|
||||||
|
singleclick := false
|
||||||
|
user := users.User{}
|
||||||
|
sql := fmt.Sprintf("SELECT %s FROM %s WHERE %s", columnsStr, quoteName(s.dbType, UsersTable), conditionStr)
|
||||||
|
err := s.db.QueryRow(sql).Scan(&userID, &username, &password, &scope, &locale, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules, &hidedotfiles, &dateformat, &singleclick)
|
||||||
|
if checkError(err, "") {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
user.ID = userID
|
||||||
|
user.Username = username
|
||||||
|
user.Password = password
|
||||||
|
user.Scope = scope
|
||||||
|
user.Locale = locale
|
||||||
|
user.LockPassword = lockpassword
|
||||||
|
user.ViewMode = viewmode
|
||||||
|
user.Perm = PermFromString(perm)
|
||||||
|
user.Commands = CommandsFromString(commands)
|
||||||
|
user.Sorting = SortingFromString(sorting)
|
||||||
|
user.Rules = rulesFromString(rules)
|
||||||
|
user.HideDotfiles = hidedotfiles
|
||||||
|
user.DateFormat = dateformat
|
||||||
|
user.SingleClick = singleclick
|
||||||
|
return &user, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s usersBackend) Gets() ([]*users.User, error) {
|
||||||
|
sql := fmt.Sprintf("SELECT id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules FROM %s", quoteName(s.dbType, UsersTable))
|
||||||
|
rows, err := s.db.Query(sql)
|
||||||
|
if checkError(err, "Fail to Query []*users.User") {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var users2 []*users.User = make([]*users.User, 0)
|
||||||
|
for rows.Next() {
|
||||||
|
id := 0
|
||||||
|
username := ""
|
||||||
|
password := ""
|
||||||
|
scope := ""
|
||||||
|
lockpassword := false
|
||||||
|
var viewmode users.ViewMode = "list"
|
||||||
|
perm := ""
|
||||||
|
commands := ""
|
||||||
|
sorting := ""
|
||||||
|
rules := ""
|
||||||
|
err := rows.Scan(&id, &username, &password, &scope, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules)
|
||||||
|
if checkError(err, "Fail to parse record for user.User") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
user := users.User{}
|
||||||
|
user.ID = uint(id)
|
||||||
|
user.Username = username
|
||||||
|
user.Password = password
|
||||||
|
user.Scope = scope
|
||||||
|
user.LockPassword = lockpassword
|
||||||
|
user.ViewMode = viewmode
|
||||||
|
user.Perm = PermFromString(perm)
|
||||||
|
user.Commands = CommandsFromString(commands)
|
||||||
|
user.Sorting = SortingFromString(sorting)
|
||||||
|
user.Rules = rulesFromString(rules)
|
||||||
|
|
||||||
|
users2 = append(users2, &user)
|
||||||
|
}
|
||||||
|
return users2, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s usersBackend) updateUser(id uint, user *users.User) error {
|
||||||
|
lockpassword := 0
|
||||||
|
if user.LockPassword {
|
||||||
|
lockpassword = 1
|
||||||
|
}
|
||||||
|
sql := fmt.Sprintf(
|
||||||
|
"UPDATE %s SET username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' WHERE id=%d",
|
||||||
|
quoteName(s.dbType, UsersTable),
|
||||||
|
user.Username,
|
||||||
|
user.Password,
|
||||||
|
user.Scope,
|
||||||
|
lockpassword,
|
||||||
|
user.ViewMode,
|
||||||
|
PermToString(user.Perm),
|
||||||
|
CommandsToString(user.Commands),
|
||||||
|
SortingToString(user.Sorting),
|
||||||
|
RulesToString(user.Rules),
|
||||||
|
user.ID,
|
||||||
|
)
|
||||||
|
_, err := s.db.Exec(sql)
|
||||||
|
checkError(err, "Fail to update user")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s usersBackend) insertUser(user *users.User) error {
|
||||||
|
columnSpec := [][]string{
|
||||||
|
{"username", "'%s'"},
|
||||||
|
{"password", "'%s'"},
|
||||||
|
{"scope", "'%s'"},
|
||||||
|
{"locale", "'%s'"},
|
||||||
|
{"lockpassword", "%s"},
|
||||||
|
{"viewmode", "'%s'"},
|
||||||
|
{"perm", "'%s'"},
|
||||||
|
{"commands", "'%s'"},
|
||||||
|
{"sorting", "'%s'"},
|
||||||
|
{"rules", "'%s'"},
|
||||||
|
{"hidedotfiles", "%s"},
|
||||||
|
{"dateformat", "%s"},
|
||||||
|
{"singleclick", "%s"},
|
||||||
|
}
|
||||||
|
columns := []string{}
|
||||||
|
specs := []string{}
|
||||||
|
for _, c := range columnSpec {
|
||||||
|
columns = append(columns, c[0])
|
||||||
|
specs = append(specs, c[1])
|
||||||
|
}
|
||||||
|
columnStr := strings.Join(columns, ",")
|
||||||
|
specStr := strings.Join(specs, ",")
|
||||||
|
sqlFormat := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoteName(s.dbType, UsersTable), columnStr, specStr)
|
||||||
|
if s.IsPostgresql() {
|
||||||
|
sqlFormat = sqlFormat + " RETURNING id;"
|
||||||
|
}
|
||||||
|
sql := fmt.Sprintf(
|
||||||
|
sqlFormat,
|
||||||
|
user.Username,
|
||||||
|
user.Password,
|
||||||
|
user.Scope,
|
||||||
|
user.Locale,
|
||||||
|
boolToString(user.LockPassword),
|
||||||
|
user.ViewMode,
|
||||||
|
PermToString(user.Perm),
|
||||||
|
CommandsToString(user.Commands),
|
||||||
|
SortingToString(user.Sorting),
|
||||||
|
RulesToString(user.Rules),
|
||||||
|
boolToString(user.HideDotfiles),
|
||||||
|
boolToString(user.DateFormat),
|
||||||
|
boolToString(user.SingleClick),
|
||||||
|
)
|
||||||
|
if s.IsPostgresql() {
|
||||||
|
id := uint(0)
|
||||||
|
err := s.db.QueryRow(sql).Scan(&id)
|
||||||
|
if !checkError(err, "Fail to insert user") {
|
||||||
|
user.ID = id
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
res, err := s.db.Exec(sql)
|
||||||
|
if !checkError(err, "Fail to insert user") {
|
||||||
|
id, err := res.LastInsertId()
|
||||||
|
checkError(err, "Fail to get last inserted id")
|
||||||
|
user.ID = uint(id)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s usersBackend) Save(user *users.User) error {
|
||||||
|
userOriginal, err := s.GetBy(user.Username)
|
||||||
|
checkError(err, "")
|
||||||
|
if userOriginal != nil {
|
||||||
|
return s.updateUser(user.ID, user)
|
||||||
|
}
|
||||||
|
return s.insertUser(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s usersBackend) DeleteByID(id uint) error {
|
||||||
|
sql := fmt.Sprintf("delete from %s where id=%d", quoteName(s.dbType, UsersTable), id)
|
||||||
|
_, err := s.db.Exec(sql)
|
||||||
|
checkError(err, "Fail to delete User by id")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s usersBackend) DeleteByUsername(username string) error {
|
||||||
|
sql := fmt.Sprintf("delete from %s where username='%s'", quoteName(s.dbType, UsersTable), username)
|
||||||
|
_, err := s.db.Exec(sql)
|
||||||
|
checkError(err, "Fail to delete user by username")
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s usersBackend) Update(u *users.User, fields ...string) error {
|
||||||
|
if len(fields) == 0 {
|
||||||
|
return s.Save(u)
|
||||||
|
}
|
||||||
|
var setItems = []string{}
|
||||||
|
for _, field := range fields {
|
||||||
|
userField := reflect.ValueOf(u).Elem().FieldByName(field)
|
||||||
|
if !userField.IsValid() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
field = strings.ToLower(field)
|
||||||
|
val := userField.Interface()
|
||||||
|
typeStr := reflect.TypeOf(val).Kind().String()
|
||||||
|
if typeStr == "string" {
|
||||||
|
setItems = append(setItems, fmt.Sprintf("%s='%s'", field, val))
|
||||||
|
} else if typeStr == "bool" {
|
||||||
|
setItems = append(setItems, fmt.Sprintf("%s=%s", field, boolToString(val.(bool))))
|
||||||
|
} else {
|
||||||
|
// TODO
|
||||||
|
setItems = append(setItems, fmt.Sprintf("%s=%s", field, val))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sql := fmt.Sprintf("UPDATE %s SET %s WHERE id=%d", UsersTable, strings.Join(setItems, ","), u.ID)
|
||||||
|
_, err := s.db.Exec(sql)
|
||||||
|
checkError(err, "Fail to update user")
|
||||||
|
return err
|
||||||
|
}
|
||||||
71
storage/sql/utils.go
Normal file
71
storage/sql/utils.go
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
package sql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"log"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getRuntimeFunctionName(frame uint) string {
|
||||||
|
pc := make([]uintptr, 1)
|
||||||
|
count := runtime.Callers(int(frame)+2, pc)
|
||||||
|
if count == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
f := runtime.FuncForPC(pc[0])
|
||||||
|
return f.Name()
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkError(err error, message string) bool {
|
||||||
|
if err != nil {
|
||||||
|
if len(message) > 0 {
|
||||||
|
funcname := filepath.Base(getRuntimeFunctionName(1))
|
||||||
|
log.Printf("ERROR [%s]: %s\n", funcname, err.Error())
|
||||||
|
log.Printf("ERROR [%s]: %s\n", funcname, message)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func logFunction() {
|
||||||
|
funcname := getRuntimeFunctionName(1)
|
||||||
|
log.Printf("%s is running\n", funcname)
|
||||||
|
}
|
||||||
|
|
||||||
|
func reverse(list []string) []string {
|
||||||
|
var output []string
|
||||||
|
for i := len(list) - 1; i >= 0; i-- {
|
||||||
|
output = append(output, list[i])
|
||||||
|
}
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|
||||||
|
func logBacktrace() {
|
||||||
|
funcs := make([]string, 0)
|
||||||
|
for _, i := range []int{1, 2, 3} {
|
||||||
|
p := filepath.Base(getRuntimeFunctionName(uint(i)))
|
||||||
|
if len(p) > 0 {
|
||||||
|
funcs = append(funcs, p)
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
funcs = reverse(funcs)
|
||||||
|
log.Printf("%s\n", strings.Join(funcs, " -> "))
|
||||||
|
}
|
||||||
|
|
||||||
|
func LogBacktrace() {
|
||||||
|
funcs := make([]string, 0)
|
||||||
|
for _, i := range []int{1, 2, 3} {
|
||||||
|
p := filepath.Base(getRuntimeFunctionName(uint(i)))
|
||||||
|
if len(p) > 0 {
|
||||||
|
funcs = append(funcs, p)
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
funcs = reverse(funcs)
|
||||||
|
log.Printf("%s\n", strings.Join(funcs, " -> "))
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue
Block a user