This commit is contained in:
wwt 2023-12-24 21:57:48 +08:00
parent b3ee816c42
commit bbe9fd8789
2 changed files with 19 additions and 7 deletions

View File

@ -50,7 +50,7 @@ func getNameQuote(dbType string) string {
return "\""
}
// for mysql, it is ``
// for mysql, it is
// for postgres and sqlite, it is ""
func quoteName(dbType string, name string) string {
q := getNameQuote(dbType)
@ -64,14 +64,14 @@ func placeHolder(dbType string, index int) string {
if index <= 0 {
panic("the placeholder index should >= 1")
}
if dbType == "postgres" {
if dbType == "postgres" || dbType == "postgresql" {
return fmt.Sprintf("$%d", index)
}
return "?"
}
func IsDBPath(path string) bool {
prefixes := []string{"sqlite3", "postgres", "mysql"}
prefixes := []string{"sqlite3", "postgres", "postgresql", "mysql"}
for _, prefix := range prefixes {
if strings.HasPrefix(path, prefix+"://") {
return true
@ -84,7 +84,7 @@ func OpenDB(path string) (*sql.DB, error) {
if val, ok := dbRecords[path]; ok {
return val.db, nil
}
prefixes := []string{"sqlite3", "postgres", "mysql"}
prefixes := []string{"sqlite3", "postgres", "postgresql", "mysql"}
for _, prefix := range prefixes {
if strings.HasPrefix(path, prefix) {
db, err := connectDB(prefix, path)

View File

@ -117,7 +117,7 @@ func createAdminUser() users.User {
func InitUserTable(db *sql.DB, dbType string) error {
primaryKey := "integer primary key"
if dbType == "postgres" {
if dbType == "postgres" || dbType == "postgresql" {
primaryKey = "serial primary key"
} else if dbType == "mysql" {
primaryKey = "int unsigned primary key auto_increment"
@ -133,6 +133,18 @@ func newUsersBackend(db *sql.DB, dbType string) usersBackend {
return usersBackend{db: db, dbType: dbType}
}
func (s usersBackend) IsPostgresql() bool {
return s.dbType == "postgres" || s.dbType == "postgresql"
}
func (s usersBackend) IsMysql() bool {
return s.dbType == "mysql"
}
func (s usersBackend) IsSqlite() bool {
return s.dbType == "sqlite3"
}
func (s usersBackend) GetBy(i interface{}) (*users.User, error) {
columns := []string{"id", "username", "password", "scope", "locale", "lockpassword", "viewmode", "perm", "commands", "sorting", "rules", "hidedotfiles", "dateformat", "singleclick"}
columnsStr := strings.Join(columns, ",")
@ -272,7 +284,7 @@ func (s usersBackend) insertUser(user *users.User) error {
columnStr := strings.Join(columns, ",")
specStr := strings.Join(specs, ",")
sqlFormat := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", quoteName(s.dbType, UsersTable), columnStr, specStr)
if s.dbType == "postgres" {
if s.IsPostgresql() {
sqlFormat = sqlFormat + " RETURNING id;"
}
sql := fmt.Sprintf(
@ -291,7 +303,7 @@ func (s usersBackend) insertUser(user *users.User) error {
boolToString(user.DateFormat),
boolToString(user.SingleClick),
)
if s.dbType == "postgres" {
if s.IsPostgresql() {
id := uint(0)
err := s.db.QueryRow(sql).Scan(&id)
if !checkError(err, "Fail to insert user") {