diff --git a/database/contacts.go b/database/contacts.go index 66f402e..cecc899 100644 --- a/database/contacts.go +++ b/database/contacts.go @@ -313,7 +313,7 @@ func GetContacts(db *sql.DB, userID uuid.UUID) ([]*model.Contact, error) { } if err = rows.Err(); err != nil { - return nil, helpers.NewError(helpers.ErrInternal, "Failed to process contacts", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process contacts: %w", err)) } return contacts, nil @@ -342,7 +342,7 @@ func ContactSuggestion(db *sql.DB, contactUsername string) ([]string, error) { suggestions = append(suggestions, suggestion) } if err = rows.Err(); err != nil { - return nil, helpers.NewError(helpers.ErrInternal, "Error processing suggestions", err) + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process suggestions")) } return suggestions, nil } diff --git a/database/messages.go b/database/messages.go new file mode 100644 index 0000000..0ca4a75 --- /dev/null +++ b/database/messages.go @@ -0,0 +1,109 @@ +package database + +import ( + "database/sql" + "fmt" + "github.com/google/uuid" + "relay-server/helpers" + "relay-server/model" +) + +func GetMessages(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID, limit int, cursor int) ([]*model.Message, error) { + _, err := checkMembership(db, userID, conversationID) + if err != nil { + return nil, err + } + var query string + var messages []*model.Message + + if cursor != 0 { + query = ` + SELECT + m.message_id, + m.content AS message, + m.sent_at, + m.attachment_urls, + a.username AS sender + FROM Messages m + JOIN Accounts a ON m.user_id = a.user_id + WHERE m.conversation_id = $1 + AND m.message_id < $2 + ORDER BY m.message_id DESC + LIMIT $3; + ` + rows, err := db.Query(query, conversationID, cursor, limit) + if err != nil { + return nil, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to get messages: %w", err)) + } + + for rows.Next() { + message := &model.Message{} + err = rows.Scan(&message.MessageID, &message.Message, &message.SentAt, &message.AttachmentUrl, &message.Sender) + if err != nil { + return nil, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to scan message: %w", err)) + } + messages = append(messages, message) + } + if err = rows.Err(); err != nil { + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process messages: %w", err)) + } + + } else { + query = ` + SELECT + m.message_id, + m.content AS message, + m.sent_at, + m.attachment_urls, + a.username AS sender + FROM Messages m + JOIN Accounts a ON m.user_id = a.user_id + WHERE m.conversation_id = $1 + ORDER BY m.message_id DESC + LIMIT $2; + ` + rows, err := db.Query(query, conversationID, cursor, limit) + if err != nil { + return nil, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to get messages: %w", err)) + } + + for rows.Next() { + message := &model.Message{} + err = rows.Scan(&message.MessageID, &message.Message, &message.SentAt, &message.AttachmentUrl, &message.Sender) + if err != nil { + return nil, helpers.NewError(helpers.ErrInternal, "Failed to get messages", fmt.Errorf("failed to scan message: %w", err)) + } + messages = append(messages, message) + } + if err = rows.Err(); err != nil { + return nil, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to process messages: %w", err)) + } + } + if cursor != 0 { + // Reverse first messages + for i, j := 0, len(messages)-1; i < j; i, j = i+1, j-1 { + messages[i], messages[j] = messages[j], messages[i] + } + } + return messages, nil +} + +func checkMembership(db *sql.DB, userID uuid.UUID, conversationID uuid.UUID) (bool, error) { + query := ` + SELECT EXISTS ( + SELECT 1 + FROM Memberships + WHERE user_id = $1 + AND conversation_id = $2 + ) AS is_member; + ` + var isMember bool + err := db.QueryRow(query, userID, conversationID).Scan(&isMember) + if err != nil { + return false, helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to check membership: %w", err)) + } + if !isMember { + return false, helpers.NewError(helpers.ErrForbidden, "You are member of the conversation", nil) + } + return isMember, nil +} diff --git a/handlers/auth.go b/handlers/auth.go index 1e64b2c..ec4b580 100644 --- a/handlers/auth.go +++ b/handlers/auth.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/gofiber/fiber/v2" "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" "golang.org/x/crypto/bcrypt" "os" "relay-server/config" @@ -158,12 +159,13 @@ func Login(c *fiber.Ctx) error { func ValidateToken(c *fiber.Ctx) error { username, ok := c.Locals("username").(string) if !ok { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid token: missing username", nil) + return helpers.NewError(helpers.ErrInvalidInput, "Invalid token: missing username", fmt.Errorf("missing username: %v", c.Locals("username"))) } - userID, ok := c.Locals("userID").(string) + userIDVal := c.Locals("userID") + userID, ok := userIDVal.(uuid.UUID) if !ok { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid token: missing user ID", nil) + return helpers.NewError(helpers.ErrUnauthorized, "unauthorized", fmt.Errorf("missing/invalid userID type: %T, value: %v\n", userIDVal, userIDVal)) } return c.Status(fiber.StatusOK).JSON(fiber.Map{ diff --git a/handlers/contacts.go b/handlers/contacts.go index 29e8e12..55cdec5 100644 --- a/handlers/contacts.go +++ b/handlers/contacts.go @@ -35,19 +35,10 @@ func DeleteContact(c *fiber.Ctx) error { func InsertContact(c *fiber.Ctx) error { type params struct { - ContactUsername string `params:"contact_username"` + ContactUsername string `params:"contactUsername"` } - userIDVal := c.Locals("userID") - userIDStr, ok := userIDVal.(string) - if !ok { - return helpers.NewError(helpers.ErrInternal, "Internal server error", nil) - } - - userID, err := uuid.Parse(userIDStr) - if err != nil { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid user ID format", err) - } + userID := c.Locals("userID").(uuid.UUID) p := new(params) if err := c.ParamsParser(p); err != nil { @@ -76,16 +67,7 @@ func InsertContact(c *fiber.Ctx) error { } func GetContacts(c *fiber.Ctx) error { - userIDVal := c.Locals("userID") - userIDStr, ok := userIDVal.(string) - if !ok { - return helpers.NewError(helpers.ErrInternal, "Invalid user session", nil) - } - - userID, err := uuid.Parse(userIDStr) - if err != nil { - return helpers.NewError(helpers.ErrInvalidInput, "Invalid user ID format", err) - } + userID := c.Locals("userID").(uuid.UUID) contacts, err := database.GetContacts(database.DB, userID) if err != nil { diff --git a/handlers/messages.go b/handlers/messages.go new file mode 100644 index 0000000..1e16749 --- /dev/null +++ b/handlers/messages.go @@ -0,0 +1,41 @@ +package handlers + +import ( + "github.com/gofiber/fiber/v2" + "github.com/google/uuid" + "relay-server/database" + "relay-server/helpers" +) + +func GetMessages(c *fiber.Ctx) error { + type params struct { + conversationID uuid.UUID `params:"conversationID"` + } + type query struct { + limit int `query:"limit"` + cursor int `query:"cursor"` + } + userID := c.Locals("userID").(uuid.UUID) + + p := new(params) + q := new(query) + if err := c.ParamsParser(p); err != nil { + return helpers.NewError(helpers.ErrInvalidInput, "Invalid params", err) + } + if err := c.QueryParser(q); err != nil { + return helpers.NewError(helpers.ErrInvalidInput, "Invalid query", err) + } + if p.conversationID == uuid.Nil { + return helpers.NewError(helpers.ErrInvalidInput, "conversation ID is empty", nil) + } + + messages, err := database.GetMessages(database.DB, userID, p.conversationID, q.limit, q.cursor) + if err != nil { + return err + } + if len(messages) == 0 { + return c.JSON(fiber.Map{"message": "No messages found"}) + } + + return c.JSON(fiber.Map{"messages": messages}) +} diff --git a/helpers/errorHandler.go b/helpers/errorHandler.go index e8a5a69..bb7a650 100644 --- a/helpers/errorHandler.go +++ b/helpers/errorHandler.go @@ -11,6 +11,8 @@ const ( ErrInternal ErrorCode = "INTERNAL_ERROR" ErrNotFound ErrorCode = "NOT_FOUND" ErrInvalidInput ErrorCode = "INVALID_INPUT" + ErrForbidden ErrorCode = "FORBIDDEN" + ErrUnauthorized ErrorCode = "UNAUTHORIZED" ) type Error struct { @@ -40,6 +42,10 @@ func NewError(code ErrorCode, userMsg string, internalErr error) *Error { statusCode = fiber.StatusBadRequest case ErrInternal: statusCode = fiber.StatusInternalServerError + case ErrForbidden: + statusCode = fiber.StatusForbidden + case ErrUnauthorized: + statusCode = fiber.StatusUnauthorized } return &Error{ diff --git a/middleware/protected.go b/middleware/protected.go index 32e0158..e60cb86 100644 --- a/middleware/protected.go +++ b/middleware/protected.go @@ -1,10 +1,13 @@ package middleware import ( + "fmt" jwtware "github.com/gofiber/contrib/jwt" "github.com/gofiber/fiber/v2" "github.com/golang-jwt/jwt/v5" + "github.com/google/uuid" "os" + "relay-server/helpers" "relay-server/model" ) @@ -17,8 +20,11 @@ func Protected() fiber.Handler { SuccessHandler: func(c *fiber.Ctx) error { user := c.Locals("user").(*jwt.Token) claims := user.Claims.(*model.UserClaims) - - c.Locals("userID", claims.UserID) + userID, err := uuid.Parse(claims.UserID) + if err != nil { + return helpers.NewError(helpers.ErrInternal, "internal server error", fmt.Errorf("failed to parse user ID: %w", err)) + } + c.Locals("userID", userID) c.Locals("username", claims.Username) return c.Next() }, diff --git a/model/model.go b/model/model.go index a347330..2933d8b 100644 --- a/model/model.go +++ b/model/model.go @@ -25,3 +25,11 @@ type Contact struct { type ContactSuggestion struct { Username string `json:"username"` } + +type Message struct { + MessageID int `json:"message_id"` + Message string `json:"message"` + SentAt string `json:"sent_at"` + Sender string `json:"sender"` + AttachmentUrl string `json:"attachment_url"` +} diff --git a/router/router.go b/router/router.go index 905d40b..8795aa3 100644 --- a/router/router.go +++ b/router/router.go @@ -21,9 +21,13 @@ func SetupRoutes(app *fiber.App) { auth.Get("/validate", middleware.Protected(), handlers.ValidateToken) // Contacts group - contacts := chat.Group("/contacts", middleware.Protected(), logger.New()) - contacts.Delete("/:contact_id/:conversation_id", handlers.DeleteContact) - contacts.Post("/:contact_username", handlers.InsertContact) + contacts := chat.Group("/contact", middleware.Protected(), logger.New()) + contacts.Delete("/:contactID/:conversation_id", handlers.DeleteContact) + contacts.Post("/:contactUsername", handlers.InsertContact) contacts.Get("/", handlers.GetContacts) contacts.Get("/suggestions/:contactUsername", handlers.GetContactSuggestions) + + // Messages group + messages := chat.Group("/messages", middleware.Protected(), logger.New()) + messages.Get("/:conversationID", handlers.GetMessages) }