From f7efee56f0b65dbc029e8b285a7dfd5d9beb6dbd Mon Sep 17 00:00:00 2001 From: slawk0 Date: Tue, 28 Jan 2025 15:26:22 +0100 Subject: [PATCH] code refactor, added middleware, --- config/config.go | 2 +- database/connect.go | 29 ++++++++++++++++++++++ database/db.go | 19 -------------- go.mod | 12 +++++---- go.sum | 15 +++++++++++ handlers/auth.go | 40 ++++++++++++++++++------------ main.go | 35 ++++++++------------------ middleware/protected.go | 24 ++++++++++++++++++ models/models.go => model/model.go | 2 +- router/router.go | 21 ++++++++++++++++ 10 files changed, 133 insertions(+), 66 deletions(-) create mode 100644 database/connect.go create mode 100644 middleware/protected.go rename models/models.go => model/model.go (95%) create mode 100644 router/router.go diff --git a/config/config.go b/config/config.go index ddc79b7..e194e4a 100644 --- a/config/config.go +++ b/config/config.go @@ -12,4 +12,4 @@ const PASSWORD_REGEX = `^[A-Za-z0-9!@#$%^&*(),.?":{}|<>]+$` const USERNAME_REGEX = `^[a-zA-Z0-9_]+$` -const BCRYPT_COST = 16 +const BCRYPT_COST = 12 diff --git a/database/connect.go b/database/connect.go new file mode 100644 index 0000000..ab042f1 --- /dev/null +++ b/database/connect.go @@ -0,0 +1,29 @@ +package database + +import ( + "database/sql" + "fmt" + "github.com/joho/godotenv" + "log" + "os" +) + +var DB *sql.DB + +func Init() (*sql.DB, error) { + err := godotenv.Load(".env") + if err != nil { + log.Fatal("Error loading .env file") + } + password := os.Getenv("PG_PASSWORD") + host := os.Getenv("PG_HOST") + connStr := fmt.Sprintf("user=postgres host=%s dbname=relay password=%s sslmode=disable", host, password) + + DB, err = sql.Open("postgres", connStr) + if err != nil { + log.Fatal("Failed to connect to database", err) + } + fmt.Println("Successfully connected to database") + + return DB, nil +} diff --git a/database/db.go b/database/db.go index 27f462e..62f0934 100644 --- a/database/db.go +++ b/database/db.go @@ -3,29 +3,10 @@ package database import ( "database/sql" "fmt" - "github.com/joho/godotenv" _ "github.com/lib/pq" "log" - "os" ) -func InitDatabase() (*sql.DB, error) { - err := godotenv.Load(".env") - if err != nil { - log.Fatal("Error loading .env file") - } - password := os.Getenv("PG_PASSWORD") - host := os.Getenv("PG_HOST") - connStr := fmt.Sprintf("user=postgres host=%s dbname=relay password=%s sslmode=disable", host, password) - DB, err := sql.Open("postgres", connStr) - if err != nil { - log.Fatal(err) - } - - return DB, nil - -} - func GetUsers(db *sql.DB) ([]string, error) { query := `SELECT username FROM accounts;` diff --git a/go.mod b/go.mod index 0e43da8..c6cdae7 100644 --- a/go.mod +++ b/go.mod @@ -10,16 +10,18 @@ require ( ) require ( - github.com/andybalholm/brotli v1.1.0 // indirect + github.com/MicahParks/keyfunc/v2 v2.1.0 // indirect + github.com/andybalholm/brotli v1.1.1 // indirect + github.com/gofiber/contrib/jwt v1.0.10 // indirect github.com/golang-jwt/jwt/v5 v5.2.1 // indirect github.com/google/uuid v1.6.0 // indirect - github.com/klauspost/compress v1.17.9 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect + github.com/klauspost/compress v1.17.11 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect - github.com/rivo/uniseg v0.2.0 // indirect + github.com/rivo/uniseg v0.4.7 // indirect github.com/valyala/bytebufferpool v1.0.0 // indirect - github.com/valyala/fasthttp v1.51.0 // indirect + github.com/valyala/fasthttp v1.58.0 // indirect github.com/valyala/tcplisten v1.0.0 // indirect golang.org/x/sys v0.29.0 // indirect ) diff --git a/go.sum b/go.sum index 1f5110d..a832a39 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,11 @@ +github.com/MicahParks/keyfunc/v2 v2.1.0 h1:6ZXKb9Rp6qp1bDbJefnG7cTH8yMN1IC/4nf+GVjO99k= +github.com/MicahParks/keyfunc/v2 v2.1.0/go.mod h1:rW42fi+xgLJ2FRRXAfNx9ZA8WpD4OeE/yHVMteCkw9k= github.com/andybalholm/brotli v1.1.0 h1:eLKJA0d02Lf0mVpIDgYnqXcUn0GqVmEFny3VuID1U3M= github.com/andybalholm/brotli v1.1.0/go.mod h1:sms7XGricyQI9K10gOSf56VKKWS4oLer58Q+mhRPtnY= +github.com/andybalholm/brotli v1.1.1 h1:PR2pgnyFznKEugtsUo0xLdDop5SKXd5Qf5ysW+7XdTA= +github.com/andybalholm/brotli v1.1.1/go.mod h1:05ib4cKhjx3OQYUY22hTVd34Bc8upXjOLL2rKwwZBoA= +github.com/gofiber/contrib/jwt v1.0.10 h1:/ilGepl6i0Bntl0Zcd+lAzagY8BiS1+fEiAj32HMApk= +github.com/gofiber/contrib/jwt v1.0.10/go.mod h1:1qBENE6sZ6PPT4xIpBzx1VxeyROQO7sj48OlM1I9qdU= github.com/gofiber/fiber/v2 v2.52.6 h1:Rfp+ILPiYSvvVuIPvxrBns+HJp8qGLDnLJawAu27XVI= github.com/gofiber/fiber/v2 v2.52.6/go.mod h1:YEcBbO/FB+5M1IZNBP9FO3J9281zgPAreiI1oqg8nDw= github.com/golang-jwt/jwt/v5 v5.2.1 h1:OuVbFODueb089Lh128TAcimifWaLhJwVflnrgM17wHk= @@ -10,10 +16,14 @@ github.com/joho/godotenv v1.5.1 h1:7eLL/+HRGLY0ldzfGMeQkb7vMd0as4CfYvUVzLqw0N0= github.com/joho/godotenv v1.5.1/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/klauspost/compress v1.17.9 h1:6KIumPrER1LHsvBVuDa0r5xaG0Es51mhhB9BQB2qeMA= github.com/klauspost/compress v1.17.9/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw= +github.com/klauspost/compress v1.17.11 h1:In6xLpyWOi1+C7tXUUWv2ot1QvBjxevKAaI6IXrJmUc= +github.com/klauspost/compress v1.17.11/go.mod h1:pMDklpSncoRMuLFrf1W9Ss9KT+0rH90U12bZKk7uwG0= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= @@ -21,12 +31,17 @@ github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6T github.com/mattn/go-runewidth v0.0.16/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= +github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/valyala/bytebufferpool v1.0.0 h1:GqA5TC/0021Y/b9FG4Oi9Mr3q7XYx6KllzawFIhcdPw= github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= github.com/valyala/fasthttp v1.51.0 h1:8b30A5JlZ6C7AS81RsWjYMQmrZG6feChmgAolCl1SqA= github.com/valyala/fasthttp v1.51.0/go.mod h1:oI2XroL+lI7vdXyYoQk03bXBThfFl2cVdIA3Xl7cH8g= +github.com/valyala/fasthttp v1.58.0 h1:GGB2dWxSbEprU9j0iMJHgdKYJVDyjrOwF9RE59PbRuE= +github.com/valyala/fasthttp v1.58.0/go.mod h1:SYXvHHaFp7QZHGKSHmoMipInhrI5StHrhDTYVEjK/Kw= github.com/valyala/tcplisten v1.0.0 h1:rBHj/Xf+E1tRGZyWIWwJDiRY0zc1Js+CV5DqwacVSA8= github.com/valyala/tcplisten v1.0.0/go.mod h1:T0xQ8SeCZGxckz9qRXTfG43PvQ/mcWh7FwZEA7Ioqkc= +github.com/xyproto/randomstring v1.0.5/go.mod h1:rgmS5DeNXLivK7YprL0pY+lTuhNQW3iGxZ18UQApw/E= golang.org/x/crypto v0.14.0 h1:wBqGXzWJW6m1XrIKlAH0Hs1JJ7+9KBwnIO8v66Q9cHc= golang.org/x/crypto v0.14.0/go.mod h1:MVFd36DqK4CsrnJYDkBA3VC4m2GkXAM0PvzMCn4JQf4= golang.org/x/crypto v0.32.0 h1:euUpcYgM8WcP71gNpTqQCn6rC2t6ULUPiOzfWaXVVfc= diff --git a/handlers/auth.go b/handlers/auth.go index 5878c25..e86c029 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -9,13 +9,13 @@ import ( "relay-server/config" "relay-server/database" "relay-server/helpers" - "relay-server/models" + "relay-server/model" "time" ) func Signup(c *fiber.Ctx) error { - db, _ := database.InitDatabase() - u := new(models.SignupStruct) + db := database.DB + u := new(model.SignupStruct) if err := c.BodyParser(u); err != nil { return err } @@ -42,13 +42,14 @@ func Signup(c *fiber.Ctx) error { // Create password hash passwordHash, err := bcrypt.GenerateFromPassword([]byte(u.Password), config.BCRYPT_COST) if err != nil { - fmt.Printf("error hashing password: %v", err) + fmt.Printf("error hashing password: %v\n", err) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "internal server error"}) } // Insert username and password hash to database userId, err := database.InsertUser(db, u.Username, string(passwordHash)) if err != nil { + fmt.Printf("error inserting user: %v\n", err) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) } @@ -61,19 +62,20 @@ func Signup(c *fiber.Ctx) error { signedToken, err := token.SignedString([]byte(os.Getenv("JWT_SECRET"))) // Set token to cookies - cookie := new(fiber.Cookie) - cookie.Name = "token" - cookie.Value = signedToken - cookie.Expires = time.Now().Add(30 * 24 * time.Hour) - cookie.HTTPOnly = true + tokenCookie := new(fiber.Cookie) + tokenCookie.Name = "token" + tokenCookie.Value = signedToken + tokenCookie.Expires = time.Now().Add(30 * 24 * time.Hour) + //tokenCookie.HTTPOnly = true + c.Cookie(tokenCookie) // If everything went well sent username and user_id assigned by database return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Successfully signed up", "username": u.Username, "user_id": userId}) } func Login(c *fiber.Ctx) error { - db, _ := database.InitDatabase() - u := new(models.LoginStruct) + db := database.DB + u := new(model.LoginStruct) if err := c.BodyParser(u); err != nil { return err @@ -107,6 +109,7 @@ func Login(c *fiber.Ctx) error { userId, err := database.GetUserId(db, u.Username) if err != nil { + fmt.Printf("error getting user id: %v\n", err) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "internal server error"}) } // Generate token with user id and username @@ -118,11 +121,16 @@ func Login(c *fiber.Ctx) error { signedToken, err := token.SignedString([]byte(os.Getenv("JWT_SECRET"))) // Set token to cookies - cookie := new(fiber.Cookie) - cookie.Name = "token" - cookie.Value = signedToken - cookie.Expires = time.Now().Add(30 * 24 * time.Hour) - cookie.HTTPOnly = true + tokenCookie := new(fiber.Cookie) + tokenCookie.Name = "token" + tokenCookie.Value = signedToken + tokenCookie.Expires = time.Now().Add(30 * 24 * time.Hour) + //tokenCookie.HTTPOnly = true + c.Cookie(tokenCookie) return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Successfully logged in", "username": u.Username, "user_id": userId}) } + +//func ValidateToken(c *fiber.Ctx) error { +// +//} diff --git a/main.go b/main.go index a2a2786..f4fd297 100644 --- a/main.go +++ b/main.go @@ -1,39 +1,26 @@ package main import ( - "database/sql" "github.com/gofiber/fiber/v2" "log" "relay-server/database" - "relay-server/handlers" + "relay-server/router" ) func main() { app := fiber.New() - db, err := database.InitDatabase() + db, err := database.Init() if err != nil { - log.Fatal(err) + log.Fatal("Failed to initialize database") } - defer func(db *sql.DB) { - err := db.Close() - if err != nil { - log.Fatal(err) + + defer func() { + if err := db.Close(); err != nil { + log.Fatalf("Failed to close database connection: %v", err) } - }(db) + log.Println("Database connection closed") + }() - app.Get("/", func(c *fiber.Ctx) error { - return c.SendString("Hello, World!") - }) - - //app.Get("/users", func(c *fiber.Ctx) error { - // users, _ := database.GetUsers(db) - // return c.JSON(fiber.Map{"users": users}) - //}) - - app.Post("/api/auth/signup", handlers.Signup) - app.Post("/api/auth/login", handlers.Login) - err = app.Listen(":3000") - if err != nil { - return - } + router.SetupRoutes(app) + app.Listen(":3000") } diff --git a/middleware/protected.go b/middleware/protected.go new file mode 100644 index 0000000..c6c6103 --- /dev/null +++ b/middleware/protected.go @@ -0,0 +1,24 @@ +package middleware + +import ( + jwtware "github.com/gofiber/contrib/jwt" + "github.com/gofiber/fiber/v2" + "os" +) + +func Protected() fiber.Handler { + return jwtware.New(jwtware.Config{ + SigningKey: jwtware.SigningKey{Key: []byte(os.Getenv("JWT_SECRET"))}, + ErrorHandler: jwtError, + TokenLookup: "cookie:token", + }) +} + +func jwtError(c *fiber.Ctx, err error) error { + if err.Error() == "Missing or malformed JWT" { + return c.Status(fiber.StatusBadRequest). + JSON(fiber.Map{"error": "Missing or malformed token"}) + } + return c.Status(fiber.StatusUnauthorized). + JSON(fiber.Map{"error": "Invalid or expired token"}) +} diff --git a/models/models.go b/model/model.go similarity index 95% rename from models/models.go rename to model/model.go index 5fc7c30..d4998f6 100644 --- a/models/models.go +++ b/model/model.go @@ -1,4 +1,4 @@ -package models +package model type LoginStruct struct { Username string `json:"username" xml:"username" form:"username"` diff --git a/router/router.go b/router/router.go new file mode 100644 index 0000000..dcc7848 --- /dev/null +++ b/router/router.go @@ -0,0 +1,21 @@ +package router + +import ( + "github.com/gofiber/fiber/v2" + "github.com/gofiber/fiber/v2/middleware/logger" + "relay-server/handlers" + "relay-server/middleware" +) + +func SetupRoutes(app *fiber.App) { + app.Get("/", middleware.Protected(), func(c *fiber.Ctx) error { + return c.SendString("Hello, World!") + }) + api := app.Group("/api", logger.New()) + + // Auth group + auth := api.Group("/auth") + auth.Post("/signup", handlers.Signup) + auth.Post("/login", handlers.Login) + +}