373 lines
9.9 KiB
Go
373 lines
9.9 KiB
Go
package sql
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"fmt"
|
|
"reflect"
|
|
"strings"
|
|
|
|
"github.com/filebrowser/filebrowser/v2/errors"
|
|
"github.com/filebrowser/filebrowser/v2/files"
|
|
"github.com/filebrowser/filebrowser/v2/rules"
|
|
"github.com/filebrowser/filebrowser/v2/users"
|
|
)
|
|
|
|
type usersBackend struct {
|
|
db *sql.DB
|
|
dbType string
|
|
}
|
|
|
|
func PermFromString(s string) users.Permissions {
|
|
var perm users.Permissions
|
|
if s == "" {
|
|
return perm
|
|
}
|
|
err := json.Unmarshal([]byte(s), &perm)
|
|
checkError(err, "Fail to parse perm from string")
|
|
return perm
|
|
}
|
|
|
|
func PermToString(perm users.Permissions) string {
|
|
data, err := json.Marshal(perm)
|
|
if checkError(err, "Fail to stringify users.Permissions") {
|
|
return ""
|
|
}
|
|
return string(data)
|
|
}
|
|
|
|
func CommandsFromString(s string) []string {
|
|
if s == "" {
|
|
return make([]string, 0)
|
|
}
|
|
var commands []string
|
|
err := json.Unmarshal([]byte(s), &commands)
|
|
checkError(err, "Fail to parse users Commands")
|
|
return commands
|
|
}
|
|
|
|
func CommandsToString(commands []string) string {
|
|
data, err := json.Marshal(commands)
|
|
if checkError(err, "Fail to stringify users commands") {
|
|
return ""
|
|
}
|
|
return string(data)
|
|
}
|
|
|
|
func SortingFromString(s string) files.Sorting {
|
|
if s == "" {
|
|
return files.Sorting{}
|
|
}
|
|
var sorting files.Sorting
|
|
err := json.Unmarshal([]byte(s), &sorting)
|
|
checkError(err, "Fail to parse Sorting from string")
|
|
return sorting
|
|
}
|
|
|
|
func SortingToString(sorting files.Sorting) string {
|
|
data, err := json.Marshal(sorting)
|
|
if checkError(err, "Fail to stringify files.Sorting") {
|
|
return ""
|
|
}
|
|
return string(data)
|
|
}
|
|
|
|
func rulesFromString(s string) []rules.Rule {
|
|
rules := make([]rules.Rule, 0)
|
|
if s == "" {
|
|
return rules
|
|
}
|
|
err := json.Unmarshal([]byte(s), &rules)
|
|
checkError(err, "Fail to parse Rules from string")
|
|
return rules
|
|
}
|
|
|
|
func RulesToString(rules []rules.Rule) string {
|
|
data, err := json.Marshal(rules)
|
|
if checkError(err, "Fail to stringify []rules.Rule") {
|
|
return ""
|
|
}
|
|
return string(data)
|
|
}
|
|
|
|
var adminUser = createAdminUser()
|
|
|
|
func createAdminUser() users.User {
|
|
userDefault := defaultSettings.Defaults
|
|
user := users.User{}
|
|
user.Username = "admin"
|
|
user.Password = "admin"
|
|
user.Scope = userDefault.Scope
|
|
user.LockPassword = false
|
|
user.ViewMode = userDefault.ViewMode
|
|
user.Perm = users.Permissions{
|
|
Admin: true,
|
|
Execute: true,
|
|
Create: true,
|
|
Rename: true,
|
|
Modify: true,
|
|
Delete: true,
|
|
Share: true,
|
|
Download: true,
|
|
}
|
|
user.Commands = userDefault.Commands
|
|
user.Sorting = userDefault.Sorting
|
|
return user
|
|
}
|
|
|
|
func InitUserTable(db *sql.DB, dbType string) error {
|
|
primaryKey := "integer primary key"
|
|
if dbType == "postgres" || dbType == "postgresql" {
|
|
primaryKey = "serial primary key"
|
|
} else if dbType == "mysql" {
|
|
primaryKey = "int unsigned primary key auto_increment"
|
|
}
|
|
sql := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (id %s, username text, password text, scope text, locale text, lockpassword integer, viewmode text, perm text, commands text, sorting text, rules text, hidedotfiles integer, dateformat integer, singleclick integer);", quoteName(dbType, UsersTable), primaryKey)
|
|
_, err := db.Exec(sql)
|
|
checkError(err, "Fail to create users table")
|
|
return err
|
|
}
|
|
|
|
func newUsersBackend(db *sql.DB, dbType string) usersBackend {
|
|
InitUserTable(db, dbType)
|
|
return usersBackend{db: db, dbType: dbType}
|
|
}
|
|
|
|
func (s usersBackend) IsPostgresql() bool {
|
|
return s.dbType == "postgres" || s.dbType == "postgresql"
|
|
}
|
|
|
|
func (s usersBackend) IsMysql() bool {
|
|
return s.dbType == "mysql"
|
|
}
|
|
|
|
func (s usersBackend) IsSqlite() bool {
|
|
return s.dbType == "sqlite3"
|
|
}
|
|
|
|
func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
|
|
columns := []string{"id", "username", "password", "scope", "locale", "lockpassword", "viewmode", "perm", "commands", "sorting", "rules", "hidedotfiles", "dateformat", "singleclick"}
|
|
columnsStr := strings.Join(columns, ",")
|
|
var conditionStr string
|
|
switch i.(type) {
|
|
case uint:
|
|
conditionStr = fmt.Sprintf("id=%v", i)
|
|
case int:
|
|
conditionStr = fmt.Sprintf("id=%v", i)
|
|
case string:
|
|
conditionStr = fmt.Sprintf("username='%v'", i)
|
|
default:
|
|
return nil, errors.ErrInvalidDataType
|
|
}
|
|
userID := uint(0)
|
|
username := ""
|
|
password := ""
|
|
scope := ""
|
|
locale := ""
|
|
lockpassword := false
|
|
var viewmode users.ViewMode = users.ListViewMode
|
|
perm := ""
|
|
commands := ""
|
|
sorting := ""
|
|
rules := ""
|
|
hidedotfiles := false
|
|
dateformat := false
|
|
singleclick := false
|
|
user := users.User{}
|
|
sql := fmt.Sprintf("SELECT %s FROM %s WHERE %s", columnsStr, quoteName(s.dbType, UsersTable), conditionStr)
|
|
err := s.db.QueryRow(sql).Scan(&userID, &username, &password, &scope, &locale, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules, &hidedotfiles, &dateformat, &singleclick)
|
|
if checkError(err, "") {
|
|
return nil, err
|
|
}
|
|
user.ID = userID
|
|
user.Username = username
|
|
user.Password = password
|
|
user.Scope = scope
|
|
user.Locale = locale
|
|
user.LockPassword = lockpassword
|
|
user.ViewMode = viewmode
|
|
user.Perm = PermFromString(perm)
|
|
user.Commands = CommandsFromString(commands)
|
|
user.Sorting = SortingFromString(sorting)
|
|
user.Rules = rulesFromString(rules)
|
|
user.HideDotfiles = hidedotfiles
|
|
user.DateFormat = dateformat
|
|
user.SingleClick = singleclick
|
|
return &user, nil
|
|
}
|
|
|
|
func (s usersBackend) Gets() ([]*users.User, error) {
|
|
sql := fmt.Sprintf("SELECT id, username, password, scope, lockpassword, viewmode, perm,commands,sorting,rules FROM %s", quoteName(s.dbType, UsersTable))
|
|
rows, err := s.db.Query(sql)
|
|
if checkError(err, "Fail to Query []*users.User") {
|
|
return nil, err
|
|
}
|
|
var users2 []*users.User = make([]*users.User, 0)
|
|
for rows.Next() {
|
|
id := 0
|
|
username := ""
|
|
password := ""
|
|
scope := ""
|
|
lockpassword := false
|
|
var viewmode users.ViewMode = "list"
|
|
perm := ""
|
|
commands := ""
|
|
sorting := ""
|
|
rules := ""
|
|
err := rows.Scan(&id, &username, &password, &scope, &lockpassword, &viewmode, &perm, &commands, &sorting, &rules)
|
|
if checkError(err, "Fail to parse record for user.User") {
|
|
continue
|
|
}
|
|
user := users.User{}
|
|
user.ID = uint(id)
|
|
user.Username = username
|
|
user.Password = password
|
|
user.Scope = scope
|
|
user.LockPassword = lockpassword
|
|
user.ViewMode = viewmode
|
|
user.Perm = PermFromString(perm)
|
|
user.Commands = CommandsFromString(commands)
|
|
user.Sorting = SortingFromString(sorting)
|
|
user.Rules = rulesFromString(rules)
|
|
|
|
users2 = append(users2, &user)
|
|
}
|
|
return users2, nil
|
|
}
|
|
|
|
func (s usersBackend) updateUser(id uint, user *users.User) error {
|
|
lockpassword := 0
|
|
if user.LockPassword {
|
|
lockpassword = 1
|
|
}
|
|
sql := fmt.Sprintf(
|
|
"UPDATE %s SET username='%s',password='%s',scope='%s',lockpassword=%d,viewmode='%s',perm='%s',commands='%s',sorting='%s',rules='%s' WHERE id=%d",
|
|
quoteName(s.dbType, UsersTable),
|
|
user.Username,
|
|
user.Password,
|
|
user.Scope,
|
|
lockpassword,
|
|
user.ViewMode,
|
|
PermToString(user.Perm),
|
|
CommandsToString(user.Commands),
|
|
SortingToString(user.Sorting),
|
|
RulesToString(user.Rules),
|
|
user.ID,
|
|
)
|
|
_, err := s.db.Exec(sql)
|
|
checkError(err, "Fail to update user")
|
|
return err
|
|
}
|
|
|
|
func (s usersBackend) insertUser(user *users.User) error {
|
|
columnSpec := [][]string{
|
|
{"username", "'%s'"},
|
|
{"password", "'%s'"},
|
|
{"scope", "'%s'"},
|
|
{"locale", "'%s'"},
|
|
{"lockpassword", "%s"},
|
|
{"viewmode", "'%s'"},
|
|
{"perm", "'%s'"},
|
|
{"commands", "'%s'"},
|
|
{"sorting", "'%s'"},
|
|
{"rules", "'%s'"},
|
|
{"hidedotfiles", "%s"},
|
|
{"dateformat", "%s"},
|
|
{"singleclick", "%s"},
|
|
}
|
|
columns := []string{}
|
|
specs := []string{}
|
|
for _, c := range columnSpec {
|
|
columns = append(columns, c[0])
|
|
specs = append(specs, c[1])
|
|
}
|
|
columnStr := strings.Join(columns, ",")
|
|
specStr := strings.Join(specs, ",")
|
|
sqlFormat := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoteName(s.dbType, UsersTable), columnStr, specStr)
|
|
if s.IsPostgresql() {
|
|
sqlFormat = sqlFormat + " RETURNING id;"
|
|
}
|
|
sql := fmt.Sprintf(
|
|
sqlFormat,
|
|
user.Username,
|
|
user.Password,
|
|
user.Scope,
|
|
user.Locale,
|
|
boolToString(user.LockPassword),
|
|
user.ViewMode,
|
|
PermToString(user.Perm),
|
|
CommandsToString(user.Commands),
|
|
SortingToString(user.Sorting),
|
|
RulesToString(user.Rules),
|
|
boolToString(user.HideDotfiles),
|
|
boolToString(user.DateFormat),
|
|
boolToString(user.SingleClick),
|
|
)
|
|
if s.IsPostgresql() {
|
|
id := uint(0)
|
|
err := s.db.QueryRow(sql).Scan(&id)
|
|
if !checkError(err, "Fail to insert user") {
|
|
user.ID = id
|
|
}
|
|
return err
|
|
}
|
|
res, err := s.db.Exec(sql)
|
|
if !checkError(err, "Fail to insert user") {
|
|
id, err := res.LastInsertId()
|
|
checkError(err, "Fail to get last inserted id")
|
|
user.ID = uint(id)
|
|
}
|
|
return err
|
|
}
|
|
|
|
func (s usersBackend) Save(user *users.User) error {
|
|
userOriginal, err := s.GetBy(user.Username)
|
|
checkError(err, "")
|
|
if userOriginal != nil {
|
|
return s.updateUser(user.ID, user)
|
|
}
|
|
return s.insertUser(user)
|
|
}
|
|
|
|
func (s usersBackend) DeleteByID(id uint) error {
|
|
sql := fmt.Sprintf("delete from %s where id=%d", quoteName(s.dbType, UsersTable), id)
|
|
_, err := s.db.Exec(sql)
|
|
checkError(err, "Fail to delete User by id")
|
|
return err
|
|
}
|
|
|
|
func (s usersBackend) DeleteByUsername(username string) error {
|
|
sql := fmt.Sprintf("delete from %s where username='%s'", quoteName(s.dbType, UsersTable), username)
|
|
_, err := s.db.Exec(sql)
|
|
checkError(err, "Fail to delete user by username")
|
|
return err
|
|
}
|
|
|
|
func (s usersBackend) Update(u *users.User, fields ...string) error {
|
|
if len(fields) == 0 {
|
|
return s.Save(u)
|
|
}
|
|
var setItems = []string{}
|
|
for _, field := range fields {
|
|
userField := reflect.ValueOf(u).Elem().FieldByName(field)
|
|
if !userField.IsValid() {
|
|
continue
|
|
}
|
|
field = strings.ToLower(field)
|
|
val := userField.Interface()
|
|
typeStr := reflect.TypeOf(val).Kind().String()
|
|
if typeStr == "string" {
|
|
setItems = append(setItems, fmt.Sprintf("%s='%s'", field, val))
|
|
} else if typeStr == "bool" {
|
|
setItems = append(setItems, fmt.Sprintf("%s=%s", field, boolToString(val.(bool))))
|
|
} else {
|
|
// TODO
|
|
setItems = append(setItems, fmt.Sprintf("%s=%s", field, val))
|
|
}
|
|
}
|
|
sql := fmt.Sprintf("UPDATE %s SET %s WHERE id=%d", UsersTable, strings.Join(setItems, ","), u.ID)
|
|
_, err := s.db.Exec(sql)
|
|
checkError(err, "Fail to update user")
|
|
return err
|
|
}
|