196 lines
4.9 KiB
Go
196 lines
4.9 KiB
Go
package sql
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"regexp"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/filebrowser/filebrowser/v2/auth"
|
|
"github.com/filebrowser/filebrowser/v2/settings"
|
|
"github.com/filebrowser/filebrowser/v2/share"
|
|
"github.com/filebrowser/filebrowser/v2/storage"
|
|
"github.com/filebrowser/filebrowser/v2/users"
|
|
_ "github.com/go-sql-driver/mysql"
|
|
_ "github.com/lib/pq"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// init is intentionally empty. The database drivers this package supports are
// pulled in through the blank imports (mysql, pq, go-sqlite3), which register
// themselves with database/sql, so no further setup is done here.
func init() {}
|
|
|
|
// DBConnectionRecord is what OpenDB remembers about each database it opened:
// the *sql.DB handle, the driver type that was passed to sql.Open for it, and
// the path OpenDB was called with (which is also its key in dbRecords).
type DBConnectionRecord struct {
	db           *sql.DB
	dbType, path string
}
|
|
|
|
var (
|
|
dbRecords map[string]*DBConnectionRecord = map[string]*DBConnectionRecord{}
|
|
)
|
|
|
|
// GetDBType used to get the driver type of a sql.DB
|
|
// It is based on existing dbRecords
|
|
// All sql.DB should opened by OpenDB
|
|
func GetDBType(db *sql.DB) (string, error) {
|
|
for _, record := range dbRecords {
|
|
if record.db == db {
|
|
return record.dbType, nil
|
|
}
|
|
}
|
|
return "", errors.New("No such database open by this module")
|
|
}
|
|
|
|
// getNameQuote returns the character used to quote SQL identifiers for the
// given driver type: a backtick for "mysql", a double quote for anything else.
func getNameQuote(dbType string) string {
	switch dbType {
	case "mysql":
		return "`"
	default:
		return "\""
	}
}
|
|
|
|
// for mysql, it is “
|
|
// for postgres and sqlite, it is ""
|
|
func quoteName(dbType string, name string) string {
|
|
q := getNameQuote(dbType)
|
|
return q + name + q
|
|
}
|
|
|
|
// placeHolder returns the bind-parameter marker for the index-th (1-based)
// argument of a SQL statement:
//   - for postgres, it is $1, $2, $3...
//   - for mysql and sqlite3, it is ?,?,?...
//
// A non-positive index is a programming error and causes a panic.
func placeHolder(dbType string, index int) string {
	if index < 1 {
		panic("the placeholder index should >= 1")
	}
	switch dbType {
	case "postgres", "postgresql":
		return fmt.Sprintf("$%d", index)
	default:
		return "?"
	}
}
|
|
|
|
// IsDBPath reports whether path is a database URL of the form
// "<scheme>://...", where scheme is one of sqlite3, postgres, postgresql or
// mysql.
func IsDBPath(path string) bool {
	// The supported schemes contain no ':', so the first "://" in path is the
	// only place a supported scheme can end.
	sep := strings.Index(path, "://")
	if sep < 0 {
		return false
	}
	switch path[:sep] {
	case "sqlite3", "postgres", "postgresql", "mysql":
		return true
	}
	return false
}
|
|
|
|
func OpenDB(path string) (*sql.DB, error) {
|
|
if val, ok := dbRecords[path]; ok {
|
|
return val.db, nil
|
|
}
|
|
prefixes := []string{"sqlite3", "postgres", "postgresql", "mysql"}
|
|
for _, prefix := range prefixes {
|
|
if strings.HasPrefix(path, prefix) {
|
|
db, err := connectDB(prefix, path)
|
|
if !checkError(err, "Fail to connect database "+path) {
|
|
dbRecords[path] = &DBConnectionRecord{db: db, dbType: prefix, path: path}
|
|
}
|
|
return db, err
|
|
}
|
|
}
|
|
return nil, errors.New("Unsupported db scheme")
|
|
}
|
|
|
|
// DatabaseResource holds the components of a database path of the form
//
//	[scheme://][username[:password]@]host[:port][/database]
//
// as parsed by ParseDatabasePath. Components missing from the path are left
// at their zero value.
type DatabaseResource struct {
	scheme   string // e.g. "mysql"; empty when the path has no "scheme://" prefix
	username string
	password string // only ever set together with a username
	host     string // host name or IPv4 address; always non-empty
	port     int    // 0 when the path has no ":port" part
	database string
}

// databasePathRegexp is the layout accepted by ParseDatabasePath, compiled
// once at package initialisation rather than on every call.
// Capture groups: 2 scheme, 4 username, 6 password, 7 host, 9 port, 11 database.
// Hyphens are allowed in host names ("my-db.example.com", "mysql-server"),
// just as in ordinary DNS names. The password group is greedy, so a password
// may itself contain '@'.
var databasePathRegexp = regexp.MustCompile(`^(([a-zA-Z0-9]+)://)?(([^:]+)(:(.*))?@)?([a-zA-Z0-9_.-]+)(:([0-9]+))?(/([a-zA-Z0-9_-]+))?$`)

// ParseDatabasePath splits path into its scheme, credentials, host, port and
// database name.
//
// It returns an error if path does not match the expected layout, or if the
// port is not a valid TCP port number (larger than 65535 or too large to
// parse). A port of 0 is accepted and is indistinguishable from "no port".
func ParseDatabasePath(path string) (*DatabaseResource, error) {
	m := databasePathRegexp.FindStringSubmatch(path)
	if m == nil {
		return nil, errors.New("Fail to parse database")
	}

	r := &DatabaseResource{
		scheme:   m[2],
		username: m[4],
		password: m[6],
		host:     m[7],
		database: m[11],
	}

	if m[9] != "" {
		port, err := strconv.Atoi(m[9])
		if err != nil {
			// Only possible on overflow, since the pattern allows digits only.
			return nil, fmt.Errorf("invalid port in database path: %w", err)
		}
		if port > 65535 {
			return nil, fmt.Errorf("port %d out of range in database path", port)
		}
		r.port = port
	}
	return r, nil
}
|
|
|
|
// mysql://user:password@host:port/db => mysql://user:password@tcp(host:port)/db
|
|
func transformMysqlPath(path string) (string, error) {
|
|
r, err := ParseDatabasePath(path)
|
|
if checkError(err, "Fail to parse database path") {
|
|
return "", err
|
|
}
|
|
scheme := r.scheme
|
|
if len(scheme) == 0 {
|
|
scheme = "mysql"
|
|
}
|
|
credential := ""
|
|
if len(r.username) > 0 && len(r.password) > 0 {
|
|
credential = r.username + ":" + r.password + "@"
|
|
} else if len(r.username) > 0 {
|
|
credential = r.username + "@"
|
|
}
|
|
host := r.host
|
|
port := r.port
|
|
if port == 0 {
|
|
port = 3306
|
|
}
|
|
if len(r.database) == 0 {
|
|
return "", errors.New("no database found in path")
|
|
}
|
|
return fmt.Sprintf("%s://%stcp(%s:%d)/%s", scheme, credential, host, port, r.database), nil
|
|
}
|
|
|
|
func connectDB(dbType string, path string) (*sql.DB, error) {
|
|
if dbType == "sqlite3" && strings.HasPrefix(path, "sqlite3://") {
|
|
path = strings.TrimPrefix(path, "sqlite3://")
|
|
} else if dbType == "mysql" && strings.HasPrefix(path, "mysql://") {
|
|
p, err := transformMysqlPath(path)
|
|
if checkError(err, "Fail to parse mysql path") {
|
|
return nil, err
|
|
}
|
|
path = p
|
|
path = strings.TrimPrefix(path, "mysql://")
|
|
}
|
|
db, err := sql.Open(dbType, path)
|
|
if err == nil {
|
|
return db, nil
|
|
}
|
|
return nil, err
|
|
}
|
|
|
|
func NewStorage(db *sql.DB) (*storage.Storage, error) {
|
|
dbType, err := GetDBType(db)
|
|
checkError(err, "Fail to get database type, maybe this sql.DB is not opened by OpenDB")
|
|
|
|
userStore := users.NewStorage(newUsersBackend(db, dbType))
|
|
shareStore := share.NewStorage(newShareBackend(db, dbType))
|
|
settingsStore := settings.NewStorage(newSettingsBackend(db, dbType))
|
|
authStore := auth.NewStorage(newAuthBackend(db, dbType), userStore)
|
|
|
|
storage := &storage.Storage{
|
|
Auth: authStore,
|
|
Users: userStore,
|
|
Share: shareStore,
|
|
Settings: settingsStore,
|
|
}
|
|
return storage, nil
|
|
}
|