From ca2debca31038ca7ffdf787741cbfee96ec411d1 Mon Sep 17 00:00:00 2001 From: slawk0 Date: Tue, 4 Feb 2025 19:53:43 +0100 Subject: [PATCH] code refactor, added contact post route --- .idea/dataSources.xml | 17 ++++ .idea/vcs.xml | 6 ++ database/auth.go | 29 +++---- database/contacts.go | 187 ++++++++++++++++++++++++++++++++++++++-- handlers/auth.go | 42 +++++---- handlers/contacts.go | 65 ++++++++++++-- middleware/protected.go | 2 +- model/model.go | 18 +++- router/router.go | 5 +- 9 files changed, 319 insertions(+), 52 deletions(-) create mode 100644 .idea/dataSources.xml create mode 100644 .idea/vcs.xml diff --git a/.idea/dataSources.xml b/.idea/dataSources.xml new file mode 100644 index 0000000..321d912 --- /dev/null +++ b/.idea/dataSources.xml @@ -0,0 +1,17 @@ + + + + + postgresql + true + org.postgresql.Driver + jdbc:postgresql://192.168.0.47:5432/relay + + + + + + $ProjectFileDir$ + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..94a25f7 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,6 @@ + + + + + + \ No newline at end of file diff --git a/database/auth.go b/database/auth.go index 5e2a64f..d7bf66d 100644 --- a/database/auth.go +++ b/database/auth.go @@ -3,6 +3,7 @@ package database import ( "database/sql" "fmt" + "github.com/google/uuid" _ "github.com/lib/pq" "log" ) @@ -34,7 +35,7 @@ func CheckUserExists(db *sql.DB, username string) (bool, error) { err := db.QueryRow(query, username).Scan(&count) if err != nil { - return false, fmt.Errorf("error checking username exists: %v", err) + return false, fmt.Errorf("error checking username exists: %w", err) } return count > 0, err @@ -47,15 +48,15 @@ func InsertUser(db *sql.DB, username string, passwordHash string) (string, error RETURNING user_id; ` - var userId string - err := db.QueryRow(query, username, passwordHash).Scan(&userId) + var userID string + err := db.QueryRow(query, username, passwordHash).Scan(&userID) if err != nil { - return "", fmt.Errorf("error inserting user: %v", err) + return "", fmt.Errorf("error inserting user: %w", err) } - log.Printf("Inserted user: %v", username) + fmt.Printf("Inserted user: %w", username) - return userId, err + return userID, err } func GetPasswordHash(db *sql.DB, username string) (string, error) { @@ -67,24 +68,22 @@ func GetPasswordHash(db *sql.DB, username string) (string, error) { var passwordHash string err := db.QueryRow(query, username).Scan(&passwordHash) if err != nil { - fmt.Printf("error getting password hash: %v\n", err) - return "", fmt.Errorf("error getting password hash: %v", err) + return "", fmt.Errorf("error getting password hash: %w", err) } return passwordHash, err } -func GetUserId(db *sql.DB, username string) (string, error) { +func GetUserID(db *sql.DB, username string) (uuid.UUID, error) { query := ` SELECT user_id FROM Accounts - WHERE LOWER(username) = $1; + WHERE LOWER(username) = LOWER($1); ` - var dbUsername string - err := db.QueryRow(query, username).Scan(&dbUsername) + var userID uuid.UUID + err := db.QueryRow(query, username).Scan(&userID) if err != nil { - fmt.Printf("Error getting password: %v\n", err) - return "", fmt.Errorf("error getting user id: %v", err) + return uuid.Nil, fmt.Errorf("error getting user id: %w", err) } - return dbUsername, err + return userID, nil } diff --git a/database/contacts.go b/database/contacts.go index 9e5c67a..5b8cc8a 100644 --- a/database/contacts.go +++ b/database/contacts.go @@ -5,7 +5,7 @@ import ( "errors" "fmt" "github.com/google/uuid" - "log" + "relay-server/model" ) func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (string, error) { @@ -58,8 +58,6 @@ func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (stri if rowsAffected == 0 { return "", fmt.Errorf("no matching membership found with conversation id: %s, user id: %s", conversationID, userID) } - - log.Printf("Successfully removed user %s from group %s", userID, conversationID) } else { // Handle direct conversation res, err := db.Exec( @@ -68,21 +66,198 @@ func DeleteContact(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (stri conversationID, ) if err != nil { - log.Printf("Error deleting contact: %v", err) return "", fmt.Errorf("error deleting contact: %w", err) } rowsAffected, err := res.RowsAffected() if err != nil { - log.Printf("Error checking contact deletion: %v", err) return "", fmt.Errorf("error checking contact deletion: %w", err) } if rowsAffected == 0 { return fmt.Sprintf("no matching contact found with user id: %s, conversation id: %s", userID, conversationID), nil } - log.Printf("Successfully deleted contact for user %s in conversation %s", userID, conversationID) + fmt.Printf("Successfully deleted contact for user %s in conversation %s", userID, conversationID) } return "", nil } + +func InsertContact(db *sql.DB, userID uuid.UUID, contactID uuid.UUID, contactUsername string) (*model.Contact, error) { + isSelfContact := userID == contactID + var contact model.Contact + var conversationID uuid.UUID + + if isSelfContact { + findSelfConversationQuery := ` + SELECT c.conversation_id + FROM Conversations c + JOIN Memberships m ON c.conversation_id = m.conversation_id + WHERE c.conversation_type = 'direct' + AND m.user_id = $1 + AND ( + SELECT COUNT(*) + FROM Memberships + WHERE conversation_id = c.conversation_id + ) = 1 + LIMIT 1; + ` + + err := db.QueryRow(findSelfConversationQuery, userID).Scan(&conversationID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + conversationID = uuid.Nil + } + return nil, fmt.Errorf("error finding existing conversation: %w", err) + } + + // if conversation for themselves don't exist + if conversationID == uuid.Nil { + createConversationQuery := ` + INSERT INTO Conversations (conversation_type) + VALUES ('direct') + RETURNING conversation_id; + ` + err := db.QueryRow(createConversationQuery).Scan(&conversationID) + if err != nil { + return nil, fmt.Errorf("error creating conversation for self-contact: %w", err) + } + + createMembershipQuery := ` + INSERT INTO Memberships (conversation_id, user_id) + VALUES ($1, $2) + ON CONFLICT (conversation_id, user_id) DO NOTHING; + ` + _, err = db.Exec(createMembershipQuery, conversationID, userID) + + } else { + // For regular contacts, check if a conversation already exists between the two users + findConversationQuery := ` + SELECT c.conversation_id + FROM Conversations c + JOIN Memberships m1 ON c.conversation_id = m1.conversation_id + JOIN Memberships m2 ON c.conversation_id = m2.conversation_id + WHERE c.conversation_type = 'direct' + AND ( + (m1.user_id = $1 AND m2.user_id = $2) + OR + (m1.user_id = $2 AND m2.user_id = $1) + ) + LIMIT 1; + ` + err := db.QueryRow(findConversationQuery, userID, contactID).Scan(&conversationID) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + conversationID = uuid.Nil + } + return nil, fmt.Errorf("error finding existing conversation: %w", err) + } + if conversationID != uuid.Nil { + // Create a new conversation between user and contact + createConversationQuery := ` + INSERT INTO Conversations (conversation_type) + VALUES ('direct') + RETURNING conversation_id; + ` + err := db.QueryRow(createConversationQuery).Scan(&conversationID) + if err != nil { + return nil, fmt.Errorf("error creating conversation: %w", err) + } + + createMembershipQuery := ` + INSERT INTO Memberships (conversation_id, user_id) + VALUES ($1, $2), ($1, $3) + ON CONFLICT (conversation_id, user_id) DO NOTHING; + ` + + res, err := db.Exec(createMembershipQuery, conversationID, userID, contactID) + if err != nil { + return nil, fmt.Errorf("error creating membership: %w", err) + } + rowsAffected, err := res.RowsAffected() + if err != nil { + return nil, fmt.Errorf("error checking membership creation: %w", err) + } + if rowsAffected == 0 { + return nil, fmt.Errorf("error creating membership %w", err) + } + } + } + + } + insertedContact, err := InsertContactByID(db, contactID, conversationID) + if err != nil || insertedContact.UserID == uuid.Nil { + return nil, fmt.Errorf("error inserting contact by id: %w", err) + } + + latestMessage, err := GetLatestMessage(db, conversationID) + if err != nil { + return nil, fmt.Errorf("error getting latest message: %w", err) + } + + contact = model.Contact{ + ID: insertedContact.ID, + ConversationID: insertedContact.ConversationID, + UserID: insertedContact.UserID, + Username: contactUsername, + Type: "direct", + LastMessageID: latestMessage.LastMessageID, + LastMessage: latestMessage.LastMessage, + LastMessageTime: latestMessage.LastMessageTime, + LastMessageSender: latestMessage.LastMessageSender, + } + return &contact, nil +} + +func InsertContactByID(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (*model.Contact, error) { + // First check if contact already exists + checkQuery := ` + SELECT contact_id, conversation_id, user_id + FROM Contacts + WHERE user_id = $1 AND conversation_id = $2 + ` + var contact model.Contact + err := db.QueryRow(checkQuery, userID, conversationID).Scan(&contact.ID, &contact.ConversationID, &contact.UserID) + if err == nil { + return nil, fmt.Errorf("contact already exists") + } else if !errors.Is(err, sql.ErrNoRows) { + return nil, fmt.Errorf("error checking contact existence: %w", err) + } + + insertQuery := ` + INSERT INTO Contacts (user_id, conversation_id) + VALUES($1, $2) + RETURNING contact_id, conversation_id, user_id + ` + + err = db.QueryRow(insertQuery, userID, conversationID).Scan(&contact.ID, &contact.ConversationID, &contact.UserID) + if err != nil { + return nil, fmt.Errorf("error inserting contact: %w", err) + } + fmt.Printf("Successfully inserted contact by id: %v", conversationID) + + return &contact, nil +} + +func GetLatestMessage(db *sql.DB, conversationId uuid.UUID) (*model.Contact, error) { + var latestMessage model.Contact + + query := ` + SELECT DISTINCT ON (m.conversation_id) + m.message_id AS last_message_id, + m.content AS last_message, + m.sent_at AS last_message_time, + a.username AS last_message_sender + FROM Messages m + JOIN Accounts a ON m.user_id = a.user_id + WHERE m.conversation_id = $1 + ORDER BY m.conversation_id, m.sent_at DESC + LIMIT 1; + ` + + err := db.QueryRow(query, conversationId).Scan(&latestMessage.LastMessageID, &latestMessage.LastMessage, &latestMessage.LastMessageTime, &latestMessage.LastMessageSender) + if err != nil { + return nil, fmt.Errorf("error getting latest message: %w", err) + } + return &latestMessage, nil +} diff --git a/handlers/auth.go b/handlers/auth.go index 3c2b67a..c4df198 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -1,10 +1,10 @@ package handlers import ( - "fmt" "github.com/gofiber/fiber/v2" "github.com/golang-jwt/jwt/v5" "golang.org/x/crypto/bcrypt" + "log" "os" "relay-server/config" "relay-server/database" @@ -49,20 +49,20 @@ func Signup(c *fiber.Ctx) error { // Create password hash passwordHash, err := bcrypt.GenerateFromPassword([]byte(u.Password), config.BCRYPT_COST) if err != nil { - fmt.Printf("error hashing password: %v\n", err) + log.Printf("error hashing password: %w\n", err) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "internal server error"}) } // Insert username and password hash to database - userId, err := database.InsertUser(db, u.Username, string(passwordHash)) + userID, err := database.InsertUser(db, u.Username, string(passwordHash)) if err != nil { - fmt.Printf("error inserting user: %v\n", err) + log.Print(err) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) } // Generate token with user id and username token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "user_id": userId, + "user_id": userID, "username": u.Username, }) // Sign token @@ -77,7 +77,7 @@ func Signup(c *fiber.Ctx) error { c.Cookie(tokenCookie) // If everything went well sent username and user_id assigned by database - return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Successfully signed up", "username": u.Username, "user_id": userId}) + return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Successfully signed up", "username": u.Username, "user_id": userID}) } func Login(c *fiber.Ctx) error { @@ -117,19 +117,22 @@ func Login(c *fiber.Ctx) error { } // Verifies password matching - passwordHash, _ := database.GetPasswordHash(db, u.Username) + passwordHash, err := database.GetPasswordHash(db, u.Username) + if err != nil { + log.Printf("error getting password: %w\n", err) + } if bcrypt.CompareHashAndPassword([]byte(passwordHash), []byte(u.Password)) != nil { return c.Status(fiber.StatusUnauthorized).JSON(fiber.Map{"error": "invalid password"}) } - userId, err := database.GetUserId(db, u.Username) + userID, err := database.GetUserID(db, u.Username) if err != nil { - fmt.Printf("error getting user id: %v\n", err) + log.Print(err) return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "internal server error"}) } // Generate token with user id and username token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ - "user_id": userId, + "user_id": userID, "username": u.Username, }) // Sign token @@ -140,19 +143,20 @@ func Login(c *fiber.Ctx) error { tokenCookie.Name = "token" tokenCookie.Value = signedToken tokenCookie.Expires = time.Now().Add(30 * 24 * time.Hour) - //tokenCookie.HTTPOnly = true c.Cookie(tokenCookie) - return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Successfully logged in", "username": u.Username, "user_id": userId}) + return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Successfully logged in", "username": u.Username, "user_id": userID}) } func ValidateToken(c *fiber.Ctx) error { - username := c.Locals("username") - userId := c.Locals("user_id") + username := c.Locals("username").(string) + userID := c.Locals("userID").(string) - if userId == nil || username == nil { - fmt.Println("userId or username is nil") - return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid token"}) - } - return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "authorized", "username": c.Locals("username").(string), "user_id": c.Locals("userId").(string)}) + //log.Printf("userID: %v, username: %v", userID, username) + //if userID == "" || username == "" { + // log.Printf("userID or username is empty %v", c.Locals("username")) + // return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "invalid token"}) + //} + + return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "authorized", "username": username, "user_id": userID}) } diff --git a/handlers/contacts.go b/handlers/contacts.go index ff7614e..f280155 100644 --- a/handlers/contacts.go +++ b/handlers/contacts.go @@ -1,17 +1,21 @@ package handlers import ( + "database/sql" + "errors" + "fmt" "github.com/gofiber/fiber/v2" "github.com/google/uuid" "log" "relay-server/database" + "relay-server/helpers" ) func DeleteContact(c *fiber.Ctx) error { type params struct { - ContactId uuid.UUID `params:"contact_id"` - ConversationId uuid.UUID `params:"conversation_id"` + ContactID uuid.UUID `params:"contact_id"` + ConversationID uuid.UUID `params:"conversation_id"` } p := new(params) @@ -21,22 +25,69 @@ func DeleteContact(c *fiber.Ctx) error { db := database.DB - if p.ContactId == uuid.Nil { + if p.ContactID == uuid.Nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "contact_id is empty"}) } - if p.ConversationId == uuid.Nil { + if p.ConversationID == uuid.Nil { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "conversation_id is empty"}) } - msg, err := database.DeleteContact(db, p.ContactId, p.ConversationId) + msg, err := database.DeleteContact(db, p.ContactID, p.ConversationID) if err != nil { log.Println(err) - return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Failed to delete contact"}) + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"error": "Internal server error"}) } if msg != "" { return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": msg}) } - log.Println("Contact deleted") + fmt.Println("Contact deleted") return c.Status(fiber.StatusOK).JSON(fiber.Map{"message": "Contact deleted"}) } + +func InsertContact(c *fiber.Ctx) error { + userID, err := uuid.Parse(c.Locals("userID").(string)) + if err != nil { + log.Println(err) + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Internal server error"}) + } + type params struct { + ContactUsername string `params:"contact_username"` + } + + p := new(params) + + if err := c.ParamsParser(p); err != nil { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"error": "Invalid params"}) + } + + db := database.DB + + if p.ContactUsername == "" { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "contact_username is empty"}) + } + + if !helpers.IsValidUsername(p.ContactUsername) { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "username is invalid"}) + } + + contactID, err := database.GetUserID(db, p.ContactUsername) + + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{"message": "user does not exist"}) + } + log.Println(err) + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Internal server error"}) + } + + newContacts, err := database.InsertContact(db, userID, contactID, p.ContactUsername) + if err != nil { + log.Println(err) + return c.Status(fiber.StatusInternalServerError).JSON(fiber.Map{"message": "Internal server error"}) + } + + log.Println("Contact added") + return c.Status(fiber.StatusOK).JSON(newContacts) + +} diff --git a/middleware/protected.go b/middleware/protected.go index 2065a4c..32e0158 100644 --- a/middleware/protected.go +++ b/middleware/protected.go @@ -18,7 +18,7 @@ func Protected() fiber.Handler { user := c.Locals("user").(*jwt.Token) claims := user.Claims.(*model.UserClaims) - c.Locals("userId", claims.UserId) + c.Locals("userID", claims.UserID) c.Locals("username", claims.Username) return c.Next() }, diff --git a/model/model.go b/model/model.go index c60c6ac..5d3ac96 100644 --- a/model/model.go +++ b/model/model.go @@ -1,9 +1,23 @@ package model -import "github.com/golang-jwt/jwt/v5" +import ( + "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" +) type UserClaims struct { Username string `json:"username"` - UserId string `json:"user_id"` + UserID string `json:"user_id"` jwt.RegisteredClaims } +type Contact struct { + ID int `json:"contact_id"` + ConversationID uuid.UUID `json:"conversation_id"` + UserID uuid.UUID `json:"user_id"` + Username string `json:"username"` + Type string `json:"type"` + LastMessageID int `json:"last_message_id"` + LastMessage string `json:"last_message"` + LastMessageTime string `json:"last_message_time"` + LastMessageSender string `json:"last_message_sender"` +} diff --git a/router/router.go b/router/router.go index 550b180..35a931a 100644 --- a/router/router.go +++ b/router/router.go @@ -15,12 +15,13 @@ func SetupRoutes(app *fiber.App) { chat := api.Group("/chat", middleware.Protected(), logger.New()) // Auth group - auth := api.Group("/auth", middleware.Protected(), handlers.ValidateToken) + auth := api.Group("/auth", logger.New()) auth.Post("/signup", handlers.Signup) auth.Post("/login", handlers.Login) - auth.Get("/validate", handlers.ValidateToken) + auth.Get("/validate", middleware.Protected(), handlers.ValidateToken) // Contacts group contacts := chat.Group("/contacts", middleware.Protected(), logger.New()) contacts.Delete("/:contact_id/:conversation_id", handlers.DeleteContact) + contacts.Post("/:contact_username", handlers.InsertContact) }