Merge b88010b779 into e167c3e1ef
This commit is contained in:
commit
2b7c79b60c
@ -45,7 +45,7 @@ func init() {
|
||||
persistent := rootCmd.PersistentFlags()
|
||||
|
||||
persistent.StringVarP(&cfgFile, "config", "c", "", "config file path")
|
||||
persistent.StringP("database", "d", "./filebrowser.db", "database path")
|
||||
persistent.StringP("database", "d", "./filebrowser.db", "database path. Possible choices:\n\tbolt: ./bolt.db\n\tsqlite3: sqlite3://test.db\n\tpostgresql: postgres://user:password@192.168.1.10:5432/postgres?sslmode=disable\n\tmysql: mysql://user:password@192.168.1.10:3306/root\n")
|
||||
flags.Bool("noauth", false, "use the noauth auther when using quick setup")
|
||||
flags.String("username", "admin", "username for the first user when using quick config")
|
||||
flags.String("password", "", "hashed password for the first user when using quick config (default \"admin\")")
|
||||
|
||||
61
cmd/utils.go
61
cmd/utils.go
@ -17,6 +17,7 @@ import (
|
||||
"github.com/filebrowser/filebrowser/v2/settings"
|
||||
"github.com/filebrowser/filebrowser/v2/storage"
|
||||
"github.com/filebrowser/filebrowser/v2/storage/bolt"
|
||||
"github.com/filebrowser/filebrowser/v2/storage/sql"
|
||||
)
|
||||
|
||||
func checkErr(err error) {
|
||||
@ -82,27 +83,51 @@ func dbExists(path string) (bool, error) {
|
||||
return false, err
|
||||
}
|
||||
|
||||
type Closeable interface {
|
||||
Close() error
|
||||
}
|
||||
|
||||
func openBoltDB(path string, cfg pythonConfig) (pythonData, Closeable) {
|
||||
data := pythonData{hadDB: true}
|
||||
exists, err := dbExists(path)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if exists && cfg.noDB {
|
||||
log.Fatal(path + " already exists")
|
||||
} else if !exists && !cfg.noDB && !cfg.allowNoDB {
|
||||
log.Fatal(path + " does not exist. Please run 'filebrowser config init' first.")
|
||||
}
|
||||
|
||||
data.hadDB = exists
|
||||
db, err := storm.Open(path)
|
||||
checkErr(err)
|
||||
data.store, err = bolt.NewStorage(db)
|
||||
checkErr(err)
|
||||
return data, db
|
||||
}
|
||||
|
||||
func openDB(path string, cfg pythonConfig) (pythonData, Closeable) {
|
||||
if sql.IsDBPath(path) {
|
||||
data := pythonData{hadDB: false}
|
||||
db, err := sql.OpenDB(path)
|
||||
if err != nil {
|
||||
log.Fatal("Fail to open database " + path)
|
||||
}
|
||||
data.store, err = sql.NewStorage(db)
|
||||
if err != nil {
|
||||
log.Fatal("Fail to create database storage for " + path)
|
||||
}
|
||||
data.hadDB = sql.HadSetting(db)
|
||||
return data, db
|
||||
}
|
||||
return openBoltDB(path, cfg)
|
||||
}
|
||||
|
||||
func python(fn pythonFunc, cfg pythonConfig) cobraFunc {
|
||||
return func(cmd *cobra.Command, args []string) {
|
||||
data := pythonData{hadDB: true}
|
||||
|
||||
path := getParam(cmd.Flags(), "database")
|
||||
exists, err := dbExists(path)
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
} else if exists && cfg.noDB {
|
||||
log.Fatal(path + " already exists")
|
||||
} else if !exists && !cfg.noDB && !cfg.allowNoDB {
|
||||
log.Fatal(path + " does not exist. Please run 'filebrowser config init' first.")
|
||||
}
|
||||
|
||||
data.hadDB = exists
|
||||
db, err := storm.Open(path)
|
||||
checkErr(err)
|
||||
data, db := openDB(getParam(cmd.Flags(), "database"), cfg)
|
||||
defer db.Close()
|
||||
data.store, err = bolt.NewStorage(db)
|
||||
checkErr(err)
|
||||
fn(cmd, args, data)
|
||||
}
|
||||
}
|
||||
|
||||
3
go.mod
3
go.mod
@ -7,11 +7,14 @@ require (
|
||||
github.com/disintegration/imaging v1.6.2
|
||||
github.com/dsoprea/go-exif/v3 v3.0.0-20201216222538-db167117f483
|
||||
github.com/flynn/go-shlex v0.0.0-20150515145356-3f9db97f8568
|
||||
github.com/go-sql-driver/mysql v1.6.0
|
||||
github.com/golang-jwt/jwt/v4 v4.4.3
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/maruel/natural v1.1.0
|
||||
github.com/marusama/semaphore/v2 v2.5.0
|
||||
github.com/mattn/go-sqlite3 v1.14.17
|
||||
github.com/mholt/archiver/v3 v3.5.1
|
||||
github.com/mitchellh/go-homedir v1.1.0
|
||||
github.com/pelletier/go-toml/v2 v2.0.6
|
||||
|
||||
6
go.sum
6
go.sum
@ -95,6 +95,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2
|
||||
github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8=
|
||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
|
||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
|
||||
github.com/go-sql-driver/mysql v1.6.0 h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE=
|
||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/golang-jwt/jwt/v4 v4.4.3 h1:Hxl6lhQFj4AnOX6MLrsCb/+7tCj7DxP7VA+2rDIq5AU=
|
||||
github.com/golang-jwt/jwt/v4 v4.4.3/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI=
|
||||
@ -196,6 +198,8 @@ github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3x
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
|
||||
github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0/go.mod h1:zJYVVT2jmtg6P3p1VtQj7WsuWi/y4VnjVBn7F8KPB3I=
|
||||
github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY=
|
||||
@ -204,6 +208,8 @@ github.com/maruel/natural v1.1.0 h1:2z1NgP/Vae+gYrtC0VuvrTJ6U35OuyUqDdfluLqMWuQ=
|
||||
github.com/maruel/natural v1.1.0/go.mod h1:eFVhYCcUOfZFxXoDZam8Ktya72wa79fNC3lc/leA0DQ=
|
||||
github.com/marusama/semaphore/v2 v2.5.0 h1:o/1QJD9DBYOWRnDhPwDVAXQn6mQYD0gZaS1Tpx6DJGM=
|
||||
github.com/marusama/semaphore/v2 v2.5.0/go.mod h1:z9nMiNUekt/LTpTUQdpp+4sJeYqUGpwMHfW0Z8V8fnQ=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/mholt/archiver/v3 v3.5.1 h1:rDjOBX9JSF5BvoJGvjqK479aL70qh9DIpZCl+k7Clwo=
|
||||
github.com/mholt/archiver/v3 v3.5.1/go.mod h1:e3dqJ7H78uzsRSEACH1joayhuSyhnonssnDhppzS1L4=
|
||||
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
|
||||
|
||||
64
scripts/test-sql.sh
Executable file
64
scripts/test-sql.sh
Executable file
@ -0,0 +1,64 @@
|
||||
#!/bin/bash
|
||||
|
||||
# stop_docker name
|
||||
stop_docker() {
|
||||
sudo docker stop $1 1>/dev/null 2>&1
|
||||
sudo docker rm $1 1>/dev/null 2>&1
|
||||
}
|
||||
|
||||
# start_mysql name password port
|
||||
start_mysql() {
|
||||
stop_docker $1
|
||||
sudo docker run --rm --name $1 -e MYSQL_ROOT_PASSWORD=$2 -p $3:3306 -d mysql
|
||||
}
|
||||
|
||||
# start_postgres name password port
|
||||
start_postgres() {
|
||||
stop_docker $1
|
||||
sudo docker run --rm --name $1 -e POSTGRES_PASSWORD=$2 -p $3:5432 -d postgres
|
||||
}
|
||||
|
||||
|
||||
test_sqlite() {
|
||||
rm -f test.db
|
||||
./filebrowser -a 0.0.0.0 -d sqlite3://test.db
|
||||
}
|
||||
|
||||
|
||||
test_postgres() {
|
||||
start_postgres test-postgres postgres 5433
|
||||
sleep 30
|
||||
./filebrowser -a 0.0.0.0 -d postgres://postgres:postgres@127.0.0.1:5433/postgres?sslmode=disable
|
||||
}
|
||||
|
||||
|
||||
test_mysql() {
|
||||
start_mysql test-mysql root 3307
|
||||
sleep 60
|
||||
./filebrowser -a 0.0.0.0 -d 'mysql://root:root@127.0.0.1:3307/mysql'
|
||||
}
|
||||
|
||||
help() {
|
||||
echo "USAGE: $0 sqlite|mysql|postgres"
|
||||
exit 1
|
||||
}
|
||||
|
||||
if (( $# == 0 )); then
|
||||
help
|
||||
fi
|
||||
|
||||
case $1 in
|
||||
sqlite)
|
||||
test_sqlite
|
||||
;;
|
||||
mysql)
|
||||
test_mysql
|
||||
;;
|
||||
postgres)
|
||||
test_postgres
|
||||
;;
|
||||
*)
|
||||
help
|
||||
esac
|
||||
|
||||
|
||||
47
storage/sql/auth.go
Normal file
47
storage/sql/auth.go
Normal file
@ -0,0 +1,47 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/filebrowser/filebrowser/v2/auth"
|
||||
"github.com/filebrowser/filebrowser/v2/errors"
|
||||
"github.com/filebrowser/filebrowser/v2/settings"
|
||||
)
|
||||
|
||||
type authBackend struct {
|
||||
db *sql.DB
|
||||
dbType string
|
||||
}
|
||||
|
||||
func (s authBackend) Get(t settings.AuthMethod) (auth.Auther, error) {
|
||||
var auther auth.Auther
|
||||
|
||||
switch t {
|
||||
case auth.MethodJSONAuth:
|
||||
auther = &auth.JSONAuth{}
|
||||
case auth.MethodProxyAuth:
|
||||
auther = &auth.ProxyAuth{}
|
||||
case auth.MethodHookAuth:
|
||||
auther = &auth.HookAuth{}
|
||||
case auth.MethodNoAuth:
|
||||
auther = &auth.NoAuth{}
|
||||
default:
|
||||
fmt.Println("ERROR: unknown auth method " + t)
|
||||
return nil, errors.ErrInvalidAuthMethod
|
||||
}
|
||||
return auther, nil
|
||||
}
|
||||
|
||||
func (s authBackend) Save(a auth.Auther) error {
|
||||
val, err := json.Marshal(a)
|
||||
if checkError(err, "Fail to save auth.Auther") {
|
||||
return err
|
||||
}
|
||||
return SetSetting(s.db, s.dbType, "auther", string(val))
|
||||
}
|
||||
|
||||
func newAuthBackend(db *sql.DB, dbType string) authBackend {
|
||||
return authBackend{db: db, dbType: dbType}
|
||||
}
|
||||
21
storage/sql/config.go
Normal file
21
storage/sql/config.go
Normal file
@ -0,0 +1,21 @@
|
||||
package sql
|
||||
|
||||
import "os"
|
||||
|
||||
var SettingsTable = "fb_settings"
|
||||
var UsersTable = "fb_users"
|
||||
var SharesTable = "fb_shares"
|
||||
|
||||
func getEnv(key string, defaultValue string) string {
|
||||
val := os.Getenv(key)
|
||||
if len(val) == 0 {
|
||||
return defaultValue
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func init() {
|
||||
SettingsTable = getEnv("FILEBROWSER_SETTINGS_TABLE", SettingsTable)
|
||||
UsersTable = getEnv("FILEBROWSER_USERS_TABLE", UsersTable)
|
||||
SharesTable = getEnv("FILEBROWSER_SHARES_TABLE", SharesTable)
|
||||
}
|
||||
160
storage/sql/server.go
Normal file
160
storage/sql/server.go
Normal file
@ -0,0 +1,160 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/filebrowser/filebrowser/v2/settings"
|
||||
)
|
||||
|
||||
var defaultServer = settings.Server{
|
||||
Port: "8080",
|
||||
Log: "stdout",
|
||||
EnableThumbnails: false,
|
||||
ResizePreview: false,
|
||||
EnableExec: false,
|
||||
TypeDetectionByHeader: false,
|
||||
}
|
||||
|
||||
func cloneServer(server settings.Server) settings.Server {
|
||||
data, err := json.Marshal(server)
|
||||
s := settings.Server{}
|
||||
if checkError(err, "Fail to clone settings.Server") {
|
||||
return s
|
||||
}
|
||||
err = json.Unmarshal(data, &s)
|
||||
checkError(err, "Fail to decode for settings.Server")
|
||||
return s
|
||||
}
|
||||
|
||||
func (s settingsBackend) GetServer() (*settings.Server, error) {
|
||||
sql := fmt.Sprintf("select %s, value from %s", quoteName(s.dbType, "key"), quoteName(s.dbType, SettingsTable))
|
||||
rows, err := s.db.Query(sql)
|
||||
if checkError(err, "Fail to Query for GetServer") {
|
||||
return nil, err
|
||||
}
|
||||
server := cloneServer(defaultServer)
|
||||
key := ""
|
||||
value := ""
|
||||
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&key, &value)
|
||||
if checkError(err, "Fail to query settings.Settings") {
|
||||
continue
|
||||
}
|
||||
if key == "Root" {
|
||||
server.Root = value
|
||||
} else if key == "BaseURL" {
|
||||
server.BaseURL = value
|
||||
} else if key == "Socket" {
|
||||
server.Socket = value
|
||||
} else if key == "TLSKey" {
|
||||
server.TLSKey = value
|
||||
} else if key == "TLSCert" {
|
||||
server.TLSCert = value
|
||||
} else if key == "Port" {
|
||||
server.Port = value
|
||||
} else if key == "Address" {
|
||||
server.Address = value
|
||||
} else if key == "Log" {
|
||||
server.Log = value
|
||||
} else if key == "EnableThumbnails" {
|
||||
server.EnableThumbnails = boolFromString(value)
|
||||
} else if key == "ResizePreview" {
|
||||
server.ResizePreview = boolFromString(value)
|
||||
} else if key == "EnableExec" {
|
||||
server.EnableExec = boolFromString(value)
|
||||
} else if key == "TypeDetectionByHeader" {
|
||||
server.TypeDetectionByHeader = boolFromString(value)
|
||||
} else if key == "AuthHook" {
|
||||
server.AuthHook = value
|
||||
}
|
||||
}
|
||||
return &server, nil
|
||||
}
|
||||
|
||||
func (s settingsBackend) SaveServer(ss *settings.Server) error {
|
||||
fields := []string{"Root", "BaseURL", "Socket", "TLSKey", "TLSCert", "Port", "Address", "Log", "EnableThumbnails", "ResizePreview", "EnableExec", "TypeDetectionByHeader", "AuthHook"}
|
||||
values := []string{
|
||||
ss.Root,
|
||||
ss.BaseURL,
|
||||
ss.Socket,
|
||||
ss.TLSKey,
|
||||
ss.TLSCert,
|
||||
ss.Port,
|
||||
ss.Address,
|
||||
ss.Log,
|
||||
boolToString(ss.EnableThumbnails),
|
||||
boolToString(ss.ResizePreview),
|
||||
boolToString(ss.EnableExec),
|
||||
boolToString(ss.TypeDetectionByHeader),
|
||||
ss.AuthHook}
|
||||
tx, err := s.db.Begin()
|
||||
if checkError(err, "Fail to begin db transaction") {
|
||||
return err
|
||||
}
|
||||
table := quoteName(s.dbType, SettingsTable)
|
||||
keyName := quoteName(s.dbType, "key")
|
||||
p1 := placeHolder(s.dbType, 1)
|
||||
p2 := placeHolder(s.dbType, 2)
|
||||
|
||||
Insert := func(key string, value string) bool {
|
||||
insertSql := fmt.Sprintf("INSERT INTO %s (%s, value) VALUES(%s,%s)", table, keyName, p1, p2)
|
||||
stmt, err := s.db.Prepare(insertSql)
|
||||
defer stmt.Close()
|
||||
if checkError(err, "Fail to prepare statement") {
|
||||
tx.Rollback()
|
||||
return false
|
||||
}
|
||||
_, err = stmt.Exec(key, value)
|
||||
if checkError(err, "Fail to insert field "+key+" of settings.Server") {
|
||||
tx.Rollback()
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
Update := func(key string, value string) bool {
|
||||
updateSql := fmt.Sprintf("UPDATE %s SET value=%s WHERE %s=%s", table, p1, keyName, p2)
|
||||
stmt, err := s.db.Prepare(updateSql)
|
||||
defer stmt.Close()
|
||||
if checkError(err, "Fail to prepare statement") {
|
||||
tx.Rollback()
|
||||
return false
|
||||
}
|
||||
_, err = stmt.Exec(value, key)
|
||||
if checkError(err, "Fail to update field "+key+" of settings.Server") {
|
||||
tx.Rollback()
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
Exist := func(key string) bool {
|
||||
querySql := fmt.Sprintf("SELECT count(*) FROM %s WHERE %s=%s", table, keyName, p1)
|
||||
row := s.db.QueryRow(querySql, key)
|
||||
count := 0
|
||||
err := row.Scan(&count)
|
||||
if checkError(err, "Fail to Query "+key+" for GetServer") {
|
||||
return false
|
||||
}
|
||||
return count == 1
|
||||
}
|
||||
InsertOrUpdate := func(key string, value string) bool {
|
||||
if Exist(key) {
|
||||
return Update(key, value)
|
||||
} else {
|
||||
return Insert(key, value)
|
||||
}
|
||||
}
|
||||
|
||||
for i, field := range fields {
|
||||
if !InsertOrUpdate(field, values[i]) {
|
||||
break
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if checkError(err, "Fail to commit") {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
368
storage/sql/settings.go
Normal file
368
storage/sql/settings.go
Normal file
@ -0,0 +1,368 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/filebrowser/filebrowser/v2/auth"
|
||||
"github.com/filebrowser/filebrowser/v2/files"
|
||||
"github.com/filebrowser/filebrowser/v2/rules"
|
||||
"github.com/filebrowser/filebrowser/v2/settings"
|
||||
"github.com/filebrowser/filebrowser/v2/users"
|
||||
)
|
||||
|
||||
func init() {
|
||||
}
|
||||
|
||||
type settingsBackend struct {
|
||||
db *sql.DB
|
||||
dbType string
|
||||
}
|
||||
|
||||
func InitSettingsTable(db *sql.DB, dbType string) error {
|
||||
sql := fmt.Sprintf("create table if not exists %s (%s varchar(128) primary key, value text);", quoteName(dbType, SettingsTable), quoteName(dbType, "key"))
|
||||
_, err := db.Exec(sql)
|
||||
checkError(err, "Fail to create table settings")
|
||||
return err
|
||||
}
|
||||
|
||||
func bytesToString(data []byte) string {
|
||||
return base64.RawStdEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
func bytesFromString(s string) ([]byte, error) {
|
||||
return base64.RawStdEncoding.DecodeString(s)
|
||||
}
|
||||
|
||||
func userDefaultsFromString(s string) settings.UserDefaults {
|
||||
if s == "" {
|
||||
return settings.UserDefaults{}
|
||||
}
|
||||
userDefaults := settings.UserDefaults{}
|
||||
err := json.Unmarshal([]byte(s), &userDefaults)
|
||||
checkError(err, "Fail to parse settings.UserDefaults")
|
||||
return userDefaults
|
||||
}
|
||||
|
||||
func userDefaultsToString(d settings.UserDefaults) string {
|
||||
data, err := json.Marshal(d)
|
||||
if checkError(err, "Fail to stringify settings.UserDefaults") {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func brandingFromString(s string) settings.Branding {
|
||||
if s == "" {
|
||||
return settings.Branding{}
|
||||
}
|
||||
branding := settings.Branding{}
|
||||
err := json.Unmarshal([]byte(s), &branding)
|
||||
checkError(err, "Fail to parse settings.Branding")
|
||||
return branding
|
||||
}
|
||||
|
||||
func brandingToString(s settings.Branding) string {
|
||||
data, err := json.Marshal(s)
|
||||
if checkError(err, "Fail to jsonify settings.Branding") {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func commandsToString(c map[string][]string) string {
|
||||
data, err := json.Marshal(c)
|
||||
if checkError(err, "Fail to jsonify commands") {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func commandsFromString(s string) map[string][]string {
|
||||
c := make(map[string][]string)
|
||||
if s == "" {
|
||||
return c
|
||||
}
|
||||
err := json.Unmarshal([]byte(s), &c)
|
||||
checkError(err, "Fail to parse commands")
|
||||
return c
|
||||
}
|
||||
|
||||
func stringsFromString(s string) []string {
|
||||
c := make([]string, 0)
|
||||
if s == "" {
|
||||
return c
|
||||
}
|
||||
err := json.Unmarshal([]byte(s), &c)
|
||||
checkError(err, "Fail to parse []string")
|
||||
return c
|
||||
}
|
||||
|
||||
func stringsToString(c []string) string {
|
||||
data, err := json.Marshal(c)
|
||||
if checkError(err, "Fail to jsonify strings") {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func boolToInt(b bool) int {
|
||||
if b {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func boolFromInt(i int) bool {
|
||||
if i == 0 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func boolFromString(s string) bool {
|
||||
if s == "0" || s == "" || s == "f" || s == "F" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func boolToString(b bool) string {
|
||||
if b {
|
||||
return "1"
|
||||
}
|
||||
return "0"
|
||||
}
|
||||
|
||||
func (s settingsBackend) Get() (*settings.Settings, error) {
|
||||
sql := fmt.Sprintf("select %s, value from %s;", quoteName(s.dbType, "key"), quoteName(s.dbType, SettingsTable))
|
||||
rows, err := s.db.Query(sql)
|
||||
if checkError(err, "Fail to Query settings.Settings") {
|
||||
return nil, err
|
||||
}
|
||||
key := ""
|
||||
value := ""
|
||||
settings1 := cloneSettings(defaultSettings)
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&key, &value)
|
||||
checkError(err, "Fail to query settings.Settings")
|
||||
if key == "Key" {
|
||||
val, err := bytesFromString(value)
|
||||
if !checkError(err, "Fail to parse []byte from string") {
|
||||
settings1.Key = val
|
||||
}
|
||||
} else if key == "Signup" {
|
||||
settings1.Signup = boolFromString(value)
|
||||
} else if key == "CreateUserDir" {
|
||||
settings1.CreateUserDir = boolFromString(value)
|
||||
} else if key == "UserHomeBasePath" {
|
||||
settings1.UserHomeBasePath = value
|
||||
} else if key == "Defaults" {
|
||||
settings1.Defaults = userDefaultsFromString(value)
|
||||
} else if key == "AuthMethod" {
|
||||
settings1.AuthMethod = settings.AuthMethod(value)
|
||||
} else if key == "Branding" {
|
||||
settings1.Branding = brandingFromString(value)
|
||||
} else if key == "Commands" {
|
||||
settings1.Commands = commandsFromString(value)
|
||||
} else if key == "Shell" {
|
||||
settings1.Shell = stringsFromString(value)
|
||||
} else if key == "Rules" {
|
||||
settings1.Rules = rulesFromString(value)
|
||||
}
|
||||
}
|
||||
if len(settings1.Key) == 0 {
|
||||
fmt.Println("The tables may not exist. Please run 'filebrowser config init' first")
|
||||
return &settings1, nil
|
||||
}
|
||||
return &settings1, nil
|
||||
}
|
||||
|
||||
func (s settingsBackend) Save(ss *settings.Settings) error {
|
||||
fields := []string{"Key", "Signup", "CreateUserDir", "UserHomeBasePath", "Defaults", "AuthMethod", "Branding", "Commands", "Shell", "Rules"}
|
||||
values := []string{
|
||||
bytesToString(ss.Key),
|
||||
boolToString(ss.Signup),
|
||||
boolToString(ss.CreateUserDir),
|
||||
ss.UserHomeBasePath,
|
||||
userDefaultsToString(ss.Defaults),
|
||||
string(ss.AuthMethod),
|
||||
brandingToString(ss.Branding),
|
||||
commandsToString(ss.Commands),
|
||||
stringsToString(ss.Shell),
|
||||
RulesToString(ss.Rules),
|
||||
}
|
||||
tx, err := s.db.Begin()
|
||||
if checkError(err, "Fail to begin db transaction") {
|
||||
return err
|
||||
}
|
||||
table := quoteName(s.dbType, SettingsTable)
|
||||
k := quoteName(s.dbType, "key")
|
||||
p1 := placeHolder(s.dbType, 1)
|
||||
p2 := placeHolder(s.dbType, 2)
|
||||
for i, field := range fields {
|
||||
exists := ContainKey(s.db, s.dbType, field)
|
||||
sql := fmt.Sprintf("INSERT INTO %s (value, %s) VALUES(%s,%s);", table, k, p1, p2)
|
||||
if exists {
|
||||
sql = fmt.Sprintf("UPDATE %s set value = %s where %s = %s;", table, p1, k, p2)
|
||||
}
|
||||
stmt, err := s.db.Prepare(sql)
|
||||
defer stmt.Close()
|
||||
if checkError(err, "Fail to prepare statement") {
|
||||
tx.Rollback()
|
||||
break
|
||||
}
|
||||
_, err = stmt.Exec(values[i], field)
|
||||
if checkError(err, "Fail to insert field "+field+" of settings") {
|
||||
tx.Rollback()
|
||||
break
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if checkError(err, "Fail to commit") {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
var defaultSettings = settings.Settings{
|
||||
Key: []byte(""),
|
||||
Signup: false,
|
||||
CreateUserDir: false,
|
||||
UserHomeBasePath: "/users",
|
||||
Defaults: settings.UserDefaults{
|
||||
Scope: ".",
|
||||
Locale: "en",
|
||||
ViewMode: "mosaic",
|
||||
SingleClick: false,
|
||||
Sorting: files.Sorting{
|
||||
By: "",
|
||||
Asc: false,
|
||||
},
|
||||
Perm: users.Permissions{
|
||||
Admin: false,
|
||||
Execute: true,
|
||||
Create: true,
|
||||
Rename: true,
|
||||
Modify: true,
|
||||
Delete: true,
|
||||
Share: true,
|
||||
Download: true,
|
||||
},
|
||||
Commands: make([]string, 0),
|
||||
HideDotfiles: false,
|
||||
DateFormat: false,
|
||||
},
|
||||
AuthMethod: auth.MethodJSONAuth,
|
||||
Branding: settings.Branding{
|
||||
Name: "",
|
||||
DisableExternal: false,
|
||||
Files: "",
|
||||
Theme: "",
|
||||
Color: "",
|
||||
},
|
||||
Commands: make(map[string][]string),
|
||||
Shell: make([]string, 0),
|
||||
Rules: make([]rules.Rule, 0),
|
||||
}
|
||||
|
||||
func cloneSettings(s settings.Settings) settings.Settings {
|
||||
data, err := json.Marshal(s)
|
||||
s1 := settings.Settings{}
|
||||
if checkError(err, "Fail to clone settings.Settings") {
|
||||
return s1
|
||||
}
|
||||
json.Unmarshal(data, &s1)
|
||||
return s1
|
||||
}
|
||||
|
||||
func SetSetting(db *sql.DB, dbType string, key string, value string) error {
|
||||
t := quoteName(dbType, SettingsTable)
|
||||
k := quoteName(dbType, "key")
|
||||
sql := fmt.Sprintf("select count(%s) from %s where %s = '%s';", k, t, k, key)
|
||||
count := 0
|
||||
err := db.QueryRow(sql).Scan(&count)
|
||||
if checkError(err, "Fail to QueryRow for key="+key) {
|
||||
return err
|
||||
}
|
||||
if count == 0 {
|
||||
return addSetting(db, dbType, key, value)
|
||||
}
|
||||
return updateSetting(db, dbType, key, value)
|
||||
}
|
||||
|
||||
func GetSetting(db *sql.DB, dbType string, key string) string {
|
||||
sql := fmt.Sprintf("select value from %s where %s = '%s';", quoteName(dbType, SettingsTable), quoteName(dbType, "key"), key)
|
||||
value := ""
|
||||
err := db.QueryRow(sql).Scan(&value)
|
||||
if checkError(err, "") {
|
||||
return value
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
func addSetting(db *sql.DB, dbType string, key string, value string) error {
|
||||
table := quoteName(dbType, SettingsTable)
|
||||
k := quoteName(dbType, "key")
|
||||
p1 := placeHolder(dbType, 1)
|
||||
p2 := placeHolder(dbType, 2)
|
||||
sql := fmt.Sprintf("insert into %s (%s, value) values(%s, %s);", table, k, p1, p2)
|
||||
stmt, err := db.Prepare(sql)
|
||||
if checkError(err, "Fail to prepare sql") {
|
||||
return err
|
||||
}
|
||||
_, err = stmt.Exec(key, value)
|
||||
checkError(err, "Fail to add settings")
|
||||
return err
|
||||
}
|
||||
|
||||
func updateSetting(db *sql.DB, dbType string, key string, value string) error {
|
||||
sql := fmt.Sprintf(
|
||||
"update %s set value = %s where %s = %s;",
|
||||
quoteName(dbType, SettingsTable),
|
||||
placeHolder(dbType, 1),
|
||||
quoteName(dbType, "key"),
|
||||
placeHolder(dbType, 2),
|
||||
)
|
||||
stmt, err := db.Prepare(sql)
|
||||
if checkError(err, "Fail to prepare sql") {
|
||||
return err
|
||||
}
|
||||
_, err = stmt.Exec(key, value)
|
||||
checkError(err, "Fail to updateSetting")
|
||||
return err
|
||||
}
|
||||
|
||||
func HadSetting(db *sql.DB) bool {
|
||||
dbType, err := GetDBType(db)
|
||||
if checkError(err, "Fail to get db type") {
|
||||
return false
|
||||
}
|
||||
key := GetSetting(db, dbType, "Key")
|
||||
if key == "" {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func ContainKey(db *sql.DB, dbType string, key string) bool {
|
||||
sql := fmt.Sprintf("select value from %s where %s = '%s';", quoteName(dbType, SettingsTable), quoteName(dbType, "key"), key)
|
||||
value := ""
|
||||
err := db.QueryRow(sql).Scan(&value)
|
||||
if checkError(err, "") {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func HadSettingOfKey(db *sql.DB, dbType string, key string) bool {
|
||||
return GetSetting(db, dbType, "Key") == key
|
||||
}
|
||||
|
||||
func newSettingsBackend(db *sql.DB, dbType string) settingsBackend {
|
||||
InitSettingsTable(db, dbType)
|
||||
return settingsBackend{db: db, dbType: dbType}
|
||||
}
|
||||
106
storage/sql/share.go
Normal file
106
storage/sql/share.go
Normal file
@ -0,0 +1,106 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/filebrowser/filebrowser/v2/share"
|
||||
)
|
||||
|
||||
type shareBackend struct {
|
||||
db *sql.DB
|
||||
dbType string
|
||||
}
|
||||
|
||||
type linkRecord interface {
|
||||
Scan(dest ...interface{}) error
|
||||
}
|
||||
|
||||
func InitSharesTable(db *sql.DB, dbType string) error {
|
||||
sql := fmt.Sprintf("create table if not exists %s (hash text, path text, userid integer, expire integer, passwordhash text, token text)", quoteName(dbType, SharesTable))
|
||||
_, err := db.Exec(sql)
|
||||
checkError(err, "Fail to InitSharesTable")
|
||||
return err
|
||||
}
|
||||
|
||||
func parseLink(row linkRecord) (*share.Link, error) {
|
||||
path := ""
|
||||
hash := ""
|
||||
userid := uint(0)
|
||||
expire := int64(0)
|
||||
passwordhash := ""
|
||||
token := ""
|
||||
err := row.Scan(&path, &hash, &userid, &expire, &passwordhash, &token)
|
||||
if checkError(err, "Fail to parse record for share.Link") {
|
||||
return nil, err
|
||||
}
|
||||
link := share.Link{}
|
||||
link.Path = path
|
||||
link.Hash = hash
|
||||
link.UserID = userid
|
||||
link.Expire = expire
|
||||
link.PasswordHash = passwordhash
|
||||
link.Token = token
|
||||
return &link, nil
|
||||
}
|
||||
|
||||
func queryLinks(db *sql.DB, dbType string, condition string) ([]*share.Link, error) {
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s", quoteName(dbType, SharesTable))
|
||||
if len(condition) > 0 {
|
||||
sql = sql + " where " + condition
|
||||
}
|
||||
rows, err := db.Query(sql)
|
||||
if checkError(err, "Fail to Query links") {
|
||||
return nil, err
|
||||
}
|
||||
var links []*share.Link = []*share.Link{}
|
||||
for rows.Next() {
|
||||
link, err := parseLink(rows)
|
||||
if checkError(err, "Fail to parse record for share.Link") {
|
||||
continue
|
||||
}
|
||||
links = append(links, link)
|
||||
}
|
||||
return links, nil
|
||||
}
|
||||
|
||||
func (s shareBackend) All() ([]*share.Link, error) {
|
||||
return queryLinks(s.db, s.dbType, "")
|
||||
}
|
||||
|
||||
func (s shareBackend) FindByUserID(id uint) ([]*share.Link, error) {
|
||||
condition := fmt.Sprintf("userid=%d", id)
|
||||
return queryLinks(s.db, s.dbType, condition)
|
||||
}
|
||||
|
||||
func (s shareBackend) GetByHash(hash string) (*share.Link, error) {
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s where hash='%s'", quoteName(s.dbType, SharesTable), hash)
|
||||
return parseLink(s.db.QueryRow(sql))
|
||||
}
|
||||
|
||||
func (s shareBackend) GetPermanent(path string, id uint) (*share.Link, error) {
|
||||
sql := fmt.Sprintf("select hash, path, userid, expire, passwordhash, token from %s where path='%s' and userid=%d", quoteName(s.dbType, SharesTable), path, id)
|
||||
return parseLink(s.db.QueryRow(sql))
|
||||
}
|
||||
|
||||
func (s shareBackend) Gets(path string, id uint) ([]*share.Link, error) {
|
||||
condition := fmt.Sprintf("userid=%d and path='%s'", id, path)
|
||||
return queryLinks(s.db, s.dbType, condition)
|
||||
}
|
||||
func (s shareBackend) Save(l *share.Link) error {
|
||||
sql := fmt.Sprintf("insert into %s (hash, path, userid, expire, passwordhash, token) values('%s', '%s', %d, %d, '%s', '%s')", quoteName(s.dbType, SharesTable), l.Hash, l.Path, l.UserID, l.Expire, l.PasswordHash, l.Token)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to Save share")
|
||||
return err
|
||||
}
|
||||
func (s shareBackend) Delete(hash string) error {
|
||||
sql := fmt.Sprintf("DELETE FROM %s WHERE hash='%s'", quoteName(s.dbType, SharesTable), hash)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to Delete share")
|
||||
return err
|
||||
}
|
||||
|
||||
func newShareBackend(db *sql.DB, dbType string) shareBackend {
|
||||
InitSharesTable(db, dbType)
|
||||
return shareBackend{db: db, dbType: dbType}
|
||||
}
|
||||
195
storage/sql/sql.go
Normal file
195
storage/sql/sql.go
Normal file
@ -0,0 +1,195 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/filebrowser/filebrowser/v2/auth"
|
||||
"github.com/filebrowser/filebrowser/v2/settings"
|
||||
"github.com/filebrowser/filebrowser/v2/share"
|
||||
"github.com/filebrowser/filebrowser/v2/storage"
|
||||
"github.com/filebrowser/filebrowser/v2/users"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
_ "github.com/lib/pq"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
)
|
||||
|
||||
func init() {
|
||||
}
|
||||
|
||||
type DBConnectionRecord struct {
|
||||
db *sql.DB
|
||||
dbType string
|
||||
path string
|
||||
}
|
||||
|
||||
var (
|
||||
dbRecords map[string]*DBConnectionRecord = map[string]*DBConnectionRecord{}
|
||||
)
|
||||
|
||||
// GetDBType used to get the driver type of a sql.DB
|
||||
// It is based on existing dbRecords
|
||||
// All sql.DB should opened by OpenDB
|
||||
func GetDBType(db *sql.DB) (string, error) {
|
||||
for _, record := range dbRecords {
|
||||
if record.db == db {
|
||||
return record.dbType, nil
|
||||
}
|
||||
}
|
||||
return "", errors.New("No such database open by this module")
|
||||
}
|
||||
|
||||
func getNameQuote(dbType string) string {
|
||||
if dbType == "mysql" {
|
||||
return "`"
|
||||
}
|
||||
return "\""
|
||||
}
|
||||
|
||||
// for mysql, it is “
|
||||
// for postgres and sqlite, it is ""
|
||||
func quoteName(dbType string, name string) string {
|
||||
q := getNameQuote(dbType)
|
||||
return q + name + q
|
||||
}
|
||||
|
||||
// placeholder for sql stmt
|
||||
// for postgres, it is $1, $2, $3...
|
||||
// for mysql and sqlite3, it is ?,?,?...
|
||||
func placeHolder(dbType string, index int) string {
|
||||
if index <= 0 {
|
||||
panic("the placeholder index should >= 1")
|
||||
}
|
||||
if dbType == "postgres" || dbType == "postgresql" {
|
||||
return fmt.Sprintf("$%d", index)
|
||||
}
|
||||
return "?"
|
||||
}
|
||||
|
||||
func IsDBPath(path string) bool {
|
||||
prefixes := []string{"sqlite3", "postgres", "postgresql", "mysql"}
|
||||
for _, prefix := range prefixes {
|
||||
if strings.HasPrefix(path, prefix+"://") {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func OpenDB(path string) (*sql.DB, error) {
|
||||
if val, ok := dbRecords[path]; ok {
|
||||
return val.db, nil
|
||||
}
|
||||
prefixes := []string{"sqlite3", "postgres", "postgresql", "mysql"}
|
||||
for _, prefix := range prefixes {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
db, err := connectDB(prefix, path)
|
||||
if !checkError(err, "Fail to connect database "+path) {
|
||||
dbRecords[path] = &DBConnectionRecord{db: db, dbType: prefix, path: path}
|
||||
}
|
||||
return db, err
|
||||
}
|
||||
}
|
||||
return nil, errors.New("Unsupported db scheme")
|
||||
}
|
||||
|
||||
type DatabaseResource struct {
|
||||
scheme string
|
||||
username string
|
||||
password string
|
||||
host string
|
||||
port int
|
||||
database string
|
||||
}
|
||||
|
||||
func ParseDatabasePath(path string) (*DatabaseResource, error) {
|
||||
pattern := "^(([a-zA-Z0-9]+)://)?(([^:]+)(:(.*))?@)?([a-zA-Z0-9_.]+)(:([0-9]+))?(/([a-zA-Z0-9_-]+))?$"
|
||||
reg, err := regexp.Compile(pattern)
|
||||
if checkError(err, "Fail to compile regexp") {
|
||||
return nil, err
|
||||
}
|
||||
matches := reg.FindAllStringSubmatch(path, -1)
|
||||
if matches == nil || len(matches) == 0 {
|
||||
return nil, errors.New("Fail to parse database")
|
||||
}
|
||||
r := DatabaseResource{}
|
||||
r.scheme = matches[0][2]
|
||||
r.username = matches[0][4]
|
||||
r.password = matches[0][6]
|
||||
r.host = matches[0][7]
|
||||
if len(matches[0][9]) > 0 {
|
||||
port, err := strconv.Atoi(matches[0][9])
|
||||
if !checkError(err, "Fail to parse port") {
|
||||
r.port = port
|
||||
}
|
||||
}
|
||||
r.database = matches[0][11]
|
||||
return &r, nil
|
||||
}
|
||||
|
||||
// mysql://user:password@host:port/db => mysql://user:password@tcp(host:port)/db
|
||||
func transformMysqlPath(path string) (string, error) {
|
||||
r, err := ParseDatabasePath(path)
|
||||
if checkError(err, "Fail to parse database path") {
|
||||
return "", err
|
||||
}
|
||||
scheme := r.scheme
|
||||
if len(scheme) == 0 {
|
||||
scheme = "mysql"
|
||||
}
|
||||
credential := ""
|
||||
if len(r.username) > 0 && len(r.password) > 0 {
|
||||
credential = r.username + ":" + r.password + "@"
|
||||
} else if len(r.username) > 0 {
|
||||
credential = r.username + "@"
|
||||
}
|
||||
host := r.host
|
||||
port := r.port
|
||||
if port == 0 {
|
||||
port = 3306
|
||||
}
|
||||
if len(r.database) == 0 {
|
||||
return "", errors.New("no database found in path")
|
||||
}
|
||||
return fmt.Sprintf("%s://%stcp(%s:%d)/%s", scheme, credential, host, port, r.database), nil
|
||||
}
|
||||
|
||||
func connectDB(dbType string, path string) (*sql.DB, error) {
|
||||
if dbType == "sqlite3" && strings.HasPrefix(path, "sqlite3://") {
|
||||
path = strings.TrimPrefix(path, "sqlite3://")
|
||||
} else if dbType == "mysql" && strings.HasPrefix(path, "mysql://") {
|
||||
p, err := transformMysqlPath(path)
|
||||
if checkError(err, "Fail to parse mysql path") {
|
||||
return nil, err
|
||||
}
|
||||
path = p
|
||||
path = strings.TrimPrefix(path, "mysql://")
|
||||
}
|
||||
db, err := sql.Open(dbType, path)
|
||||
if err == nil {
|
||||
return db, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func NewStorage(db *sql.DB) (*storage.Storage, error) {
|
||||
dbType, err := GetDBType(db)
|
||||
checkError(err, "Fail to get database type, maybe this sql.DB is not opened by OpenDB")
|
||||
|
||||
userStore := users.NewStorage(newUsersBackend(db, dbType))
|
||||
shareStore := share.NewStorage(newShareBackend(db, dbType))
|
||||
settingsStore := settings.NewStorage(newSettingsBackend(db, dbType))
|
||||
authStore := auth.NewStorage(newAuthBackend(db, dbType), userStore)
|
||||
|
||||
storage := &storage.Storage{
|
||||
Auth: authStore,
|
||||
Users: userStore,
|
||||
Share: shareStore,
|
||||
Settings: settingsStore,
|
||||
}
|
||||
return storage, nil
|
||||
}
|
||||
372
storage/sql/users.go
Normal file
372
storage/sql/users.go
Normal file
@ -0,0 +1,372 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/filebrowser/filebrowser/v2/errors"
|
||||
"github.com/filebrowser/filebrowser/v2/files"
|
||||
"github.com/filebrowser/filebrowser/v2/rules"
|
||||
"github.com/filebrowser/filebrowser/v2/users"
|
||||
)
|
||||
|
||||
type usersBackend struct {
|
||||
db *sql.DB
|
||||
dbType string
|
||||
}
|
||||
|
||||
func PermFromString(s string) users.Permissions {
|
||||
var perm users.Permissions
|
||||
if s == "" {
|
||||
return perm
|
||||
}
|
||||
err := json.Unmarshal([]byte(s), &perm)
|
||||
checkError(err, "Fail to parse perm from string")
|
||||
return perm
|
||||
}
|
||||
|
||||
func PermToString(perm users.Permissions) string {
|
||||
data, err := json.Marshal(perm)
|
||||
if checkError(err, "Fail to stringify users.Permissions") {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func CommandsFromString(s string) []string {
|
||||
if s == "" {
|
||||
return make([]string, 0)
|
||||
}
|
||||
var commands []string
|
||||
err := json.Unmarshal([]byte(s), &commands)
|
||||
checkError(err, "Fail to parse users Commands")
|
||||
return commands
|
||||
}
|
||||
|
||||
func CommandsToString(commands []string) string {
|
||||
data, err := json.Marshal(commands)
|
||||
if checkError(err, "Fail to stringify users commands") {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func SortingFromString(s string) files.Sorting {
|
||||
if s == "" {
|
||||
return files.Sorting{}
|
||||
}
|
||||
var sorting files.Sorting
|
||||
err := json.Unmarshal([]byte(s), &sorting)
|
||||
checkError(err, "Fail to parse Sorting from string")
|
||||
return sorting
|
||||
}
|
||||
|
||||
func SortingToString(sorting files.Sorting) string {
|
||||
data, err := json.Marshal(sorting)
|
||||
if checkError(err, "Fail to stringify files.Sorting") {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
func rulesFromString(s string) []rules.Rule {
|
||||
rules := make([]rules.Rule, 0)
|
||||
if s == "" {
|
||||
return rules
|
||||
}
|
||||
err := json.Unmarshal([]byte(s), &rules)
|
||||
checkError(err, "Fail to parse Rules from string")
|
||||
return rules
|
||||
}
|
||||
|
||||
func RulesToString(rules []rules.Rule) string {
|
||||
data, err := json.Marshal(rules)
|
||||
if checkError(err, "Fail to stringify []rules.Rule") {
|
||||
return ""
|
||||
}
|
||||
return string(data)
|
||||
}
|
||||
|
||||
var adminUser = createAdminUser()
|
||||
|
||||
func createAdminUser() users.User {
|
||||
userDefault := defaultSettings.Defaults
|
||||
user := users.User{}
|
||||
user.Username = "admin"
|
||||
user.Password = "admin"
|
||||
user.Scope = userDefault.Scope
|
||||
user.LockPassword = false
|
||||
user.ViewMode = userDefault.ViewMode
|
||||
user.Perm = users.Permissions{
|
||||
Admin: true,
|
||||
Execute: true,
|
||||
Create: true,
|
||||
Rename: true,
|
||||
Modify: true,
|
||||
Delete: true,
|
||||
Share: true,
|
||||
Download: true,
|
||||
}
|
||||
user.Commands = userDefault.Commands
|
||||
user.Sorting = userDefault.Sorting
|
||||
return user
|
||||
}
|
||||
|
||||
func InitUserTable(db *sql.DB, dbType string) error {
|
||||
primaryKey := "integer primary key"
|
||||
if dbType == "postgres" || dbType == "postgresql" {
|
||||
primaryKey = "serial primary key"
|
||||
} else if dbType == "mysql" {
|
||||
primaryKey = "int unsigned primary key auto_increment"
|
||||
}
|
||||
sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id %s, username text, password text, scope text, locale text, lockpassword integer, viewmode text, perm text, commands text, sorting text, rules text, hidedotfiles integer, dateformat integer, singleclick integer);", quoteName(dbType, UsersTable), primaryKey)
|
||||
_, err := db.Exec(sql)
|
||||
checkError(err, "Fail to create users table")
|
||||
return err
|
||||
}
|
||||
|
||||
func newUsersBackend(db *sql.DB, dbType string) usersBackend {
|
||||
InitUserTable(db, dbType)
|
||||
return usersBackend{db: db, dbType: dbType}
|
||||
}
|
||||
|
||||
func (s usersBackend) IsPostgresql() bool {
|
||||
return s.dbType == "postgres" || s.dbType == "postgresql"
|
||||
}
|
||||
|
||||
func (s usersBackend) IsMysql() bool {
|
||||
return s.dbType == "mysql"
|
||||
}
|
||||
|
||||
func (s usersBackend) IsSqlite() bool {
|
||||
return s.dbType == "sqlite3"
|
||||
}
|
||||
|
||||
func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
|
||||
columns := []string{"id", "username", "password", "scope", "locale", "lockpassword", "viewmode", "perm", "commands", "sorting", "rules", "hidedotfiles", "dateformat", "singleclick"}
|
||||
columnsStr := strings.Join(columns, ",")
|
||||
var conditionStr string
|
||||
switch i.(type) {
|
||||
case uint:
|
||||
conditionStr = fmt.Sprintf("id=%v", i)
|
||||
case int:
|
||||
conditionStr = fmt.Sprintf("id=%v", i)
|
||||
case string:
|
||||
conditionStr = fmt.Sprintf("username='%v'", i)
|
||||
default:
|
||||
return nil, errors.ErrInvalidDataType
|
||||
}
|
||||
userID := uint(0)
|
||||
username := ""
|
||||
password := ""
|
||||
scope := ""
|
||||
locale := ""
|
||||
lockpassword := false
|
||||
var viewmode users.ViewMode = users.ListViewMode
|
||||
perm := ""
|
||||
commands := ""
|
||||
sorting := ""
|
||||
rules := ""
|
||||
hidedotfiles := false
|
||||
dateformat := false
|
||||
singleclick := false
|
||||
user := users.User{}
|
||||
sql := fmt.Sprintf("SELECT %s FROM %s WHERE %s", columnsStr, quoteName(s.dbType, UsersTable), conditionStr)
|
||||
err := s.db.QueryRow(sql).Scan(&userID, &username, &password, &scope, &locale, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules, &hidedotfiles, &dateformat, &singleclick)
|
||||
if checkError(err, "") {
|
||||
return nil, err
|
||||
}
|
||||
user.ID = userID
|
||||
user.Username = username
|
||||
user.Password = password
|
||||
user.Scope = scope
|
||||
user.Locale = locale
|
||||
user.LockPassword = lockpassword
|
||||
user.ViewMode = viewmode
|
||||
user.Perm = PermFromString(perm)
|
||||
user.Commands = CommandsFromString(commands)
|
||||
user.Sorting = SortingFromString(sorting)
|
||||
user.Rules = rulesFromString(rules)
|
||||
user.HideDotfiles = hidedotfiles
|
||||
user.DateFormat = dateformat
|
||||
user.SingleClick = singleclick
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s usersBackend) Gets() ([]*users.User, error) {
|
||||
sql := fmt.Sprintf("SELECT id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules FROM %s", quoteName(s.dbType, UsersTable))
|
||||
rows, err := s.db.Query(sql)
|
||||
if checkError(err, "Fail to Query []*users.User") {
|
||||
return nil, err
|
||||
}
|
||||
var users2 []*users.User = make([]*users.User, 0)
|
||||
for rows.Next() {
|
||||
id := 0
|
||||
username := ""
|
||||
password := ""
|
||||
scope := ""
|
||||
lockpassword := false
|
||||
var viewmode users.ViewMode = "list"
|
||||
perm := ""
|
||||
commands := ""
|
||||
sorting := ""
|
||||
rules := ""
|
||||
err := rows.Scan(&id, &username, &password, &scope, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules)
|
||||
if checkError(err, "Fail to parse record for user.User") {
|
||||
continue
|
||||
}
|
||||
user := users.User{}
|
||||
user.ID = uint(id)
|
||||
user.Username = username
|
||||
user.Password = password
|
||||
user.Scope = scope
|
||||
user.LockPassword = lockpassword
|
||||
user.ViewMode = viewmode
|
||||
user.Perm = PermFromString(perm)
|
||||
user.Commands = CommandsFromString(commands)
|
||||
user.Sorting = SortingFromString(sorting)
|
||||
user.Rules = rulesFromString(rules)
|
||||
|
||||
users2 = append(users2, &user)
|
||||
}
|
||||
return users2, nil
|
||||
}
|
||||
|
||||
func (s usersBackend) updateUser(id uint, user *users.User) error {
|
||||
lockpassword := 0
|
||||
if user.LockPassword {
|
||||
lockpassword = 1
|
||||
}
|
||||
sql := fmt.Sprintf(
|
||||
"UPDATE %s SET username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' WHERE id=%d",
|
||||
quoteName(s.dbType, UsersTable),
|
||||
user.Username,
|
||||
user.Password,
|
||||
user.Scope,
|
||||
lockpassword,
|
||||
user.ViewMode,
|
||||
PermToString(user.Perm),
|
||||
CommandsToString(user.Commands),
|
||||
SortingToString(user.Sorting),
|
||||
RulesToString(user.Rules),
|
||||
user.ID,
|
||||
)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to update user")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s usersBackend) insertUser(user *users.User) error {
|
||||
columnSpec := [][]string{
|
||||
{"username", "'%s'"},
|
||||
{"password", "'%s'"},
|
||||
{"scope", "'%s'"},
|
||||
{"locale", "'%s'"},
|
||||
{"lockpassword", "%s"},
|
||||
{"viewmode", "'%s'"},
|
||||
{"perm", "'%s'"},
|
||||
{"commands", "'%s'"},
|
||||
{"sorting", "'%s'"},
|
||||
{"rules", "'%s'"},
|
||||
{"hidedotfiles", "%s"},
|
||||
{"dateformat", "%s"},
|
||||
{"singleclick", "%s"},
|
||||
}
|
||||
columns := []string{}
|
||||
specs := []string{}
|
||||
for _, c := range columnSpec {
|
||||
columns = append(columns, c[0])
|
||||
specs = append(specs, c[1])
|
||||
}
|
||||
columnStr := strings.Join(columns, ",")
|
||||
specStr := strings.Join(specs, ",")
|
||||
sqlFormat := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoteName(s.dbType, UsersTable), columnStr, specStr)
|
||||
if s.IsPostgresql() {
|
||||
sqlFormat = sqlFormat + " RETURNING id;"
|
||||
}
|
||||
sql := fmt.Sprintf(
|
||||
sqlFormat,
|
||||
user.Username,
|
||||
user.Password,
|
||||
user.Scope,
|
||||
user.Locale,
|
||||
boolToString(user.LockPassword),
|
||||
user.ViewMode,
|
||||
PermToString(user.Perm),
|
||||
CommandsToString(user.Commands),
|
||||
SortingToString(user.Sorting),
|
||||
RulesToString(user.Rules),
|
||||
boolToString(user.HideDotfiles),
|
||||
boolToString(user.DateFormat),
|
||||
boolToString(user.SingleClick),
|
||||
)
|
||||
if s.IsPostgresql() {
|
||||
id := uint(0)
|
||||
err := s.db.QueryRow(sql).Scan(&id)
|
||||
if !checkError(err, "Fail to insert user") {
|
||||
user.ID = id
|
||||
}
|
||||
return err
|
||||
}
|
||||
res, err := s.db.Exec(sql)
|
||||
if !checkError(err, "Fail to insert user") {
|
||||
id, err := res.LastInsertId()
|
||||
checkError(err, "Fail to get last inserted id")
|
||||
user.ID = uint(id)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (s usersBackend) Save(user *users.User) error {
|
||||
userOriginal, err := s.GetBy(user.Username)
|
||||
checkError(err, "")
|
||||
if userOriginal != nil {
|
||||
return s.updateUser(user.ID, user)
|
||||
}
|
||||
return s.insertUser(user)
|
||||
}
|
||||
|
||||
func (s usersBackend) DeleteByID(id uint) error {
|
||||
sql := fmt.Sprintf("delete from %s where id=%d", quoteName(s.dbType, UsersTable), id)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to delete User by id")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s usersBackend) DeleteByUsername(username string) error {
|
||||
sql := fmt.Sprintf("delete from %s where username='%s'", quoteName(s.dbType, UsersTable), username)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to delete user by username")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s usersBackend) Update(u *users.User, fields ...string) error {
|
||||
if len(fields) == 0 {
|
||||
return s.Save(u)
|
||||
}
|
||||
var setItems = []string{}
|
||||
for _, field := range fields {
|
||||
userField := reflect.ValueOf(u).Elem().FieldByName(field)
|
||||
if !userField.IsValid() {
|
||||
continue
|
||||
}
|
||||
field = strings.ToLower(field)
|
||||
val := userField.Interface()
|
||||
typeStr := reflect.TypeOf(val).Kind().String()
|
||||
if typeStr == "string" {
|
||||
setItems = append(setItems, fmt.Sprintf("%s='%s'", field, val))
|
||||
} else if typeStr == "bool" {
|
||||
setItems = append(setItems, fmt.Sprintf("%s=%s", field, boolToString(val.(bool))))
|
||||
} else {
|
||||
// TODO
|
||||
setItems = append(setItems, fmt.Sprintf("%s=%s", field, val))
|
||||
}
|
||||
}
|
||||
sql := fmt.Sprintf("UPDATE %s SET %s WHERE id=%d", UsersTable, strings.Join(setItems, ","), u.ID)
|
||||
_, err := s.db.Exec(sql)
|
||||
checkError(err, "Fail to update user")
|
||||
return err
|
||||
}
|
||||
71
storage/sql/utils.go
Normal file
71
storage/sql/utils.go
Normal file
@ -0,0 +1,71 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"log"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func getRuntimeFunctionName(frame uint) string {
|
||||
pc := make([]uintptr, 1)
|
||||
count := runtime.Callers(int(frame)+2, pc)
|
||||
if count == 0 {
|
||||
return ""
|
||||
}
|
||||
f := runtime.FuncForPC(pc[0])
|
||||
return f.Name()
|
||||
}
|
||||
|
||||
func checkError(err error, message string) bool {
|
||||
if err != nil {
|
||||
if len(message) > 0 {
|
||||
funcname := filepath.Base(getRuntimeFunctionName(1))
|
||||
log.Printf("ERROR [%s]: %s\n", funcname, err.Error())
|
||||
log.Printf("ERROR [%s]: %s\n", funcname, message)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func logFunction() {
|
||||
funcname := getRuntimeFunctionName(1)
|
||||
log.Printf("%s is running\n", funcname)
|
||||
}
|
||||
|
||||
func reverse(list []string) []string {
|
||||
var output []string
|
||||
for i := len(list) - 1; i >= 0; i-- {
|
||||
output = append(output, list[i])
|
||||
}
|
||||
return output
|
||||
}
|
||||
|
||||
func logBacktrace() {
|
||||
funcs := make([]string, 0)
|
||||
for _, i := range []int{1, 2, 3} {
|
||||
p := filepath.Base(getRuntimeFunctionName(uint(i)))
|
||||
if len(p) > 0 {
|
||||
funcs = append(funcs, p)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
funcs = reverse(funcs)
|
||||
log.Printf("%s\n", strings.Join(funcs, " -> "))
|
||||
}
|
||||
|
||||
func LogBacktrace() {
|
||||
funcs := make([]string, 0)
|
||||
for _, i := range []int{1, 2, 3} {
|
||||
p := filepath.Base(getRuntimeFunctionName(uint(i)))
|
||||
if len(p) > 0 {
|
||||
funcs = append(funcs, p)
|
||||
} else {
|
||||
break
|
||||
}
|
||||
}
|
||||
funcs = reverse(funcs)
|
||||
log.Printf("%s\n", strings.Join(funcs, " -> "))
|
||||
}
|
||||
Loading…
Reference in New Issue
Block a user