diff --git a/composables/ws.ts b/composables/ws.ts index 0501a05..01db40f 100644 --- a/composables/ws.ts +++ b/composables/ws.ts @@ -1,12 +1,19 @@ +import type { NuxtError } from "#app"; + export type WebSocketCallback = (message: string) => void; +export type WebSocketErrorHandler = (error: NuxtError) => void; export class WebSocketHandler { private listeners: Array = []; + private outQueue: Array = []; private inQueue: Array = []; + private ws: WebSocket | undefined = undefined; private connected: boolean = false; + private errorHandler: WebSocketErrorHandler | undefined = undefined; + constructor(route: string) { if (import.meta.server) return; const isSecure = location.protocol === "https:"; @@ -22,6 +29,18 @@ export class WebSocketHandler { this.ws.onmessage = (e) => { const message = e.data; + switch (message) { + case "unauthenticated": + const error = createError({ + statusCode: 403, + statusMessage: "Unable to connect to websocket - unauthenticated", + }); + if (this.errorHandler) { + return this.errorHandler(error); + } else { + throw error; + } + } if (this.listeners.length == 0) { this.inQueue.push(message); return; @@ -33,6 +52,10 @@ export class WebSocketHandler { }; } + error(handler: WebSocketErrorHandler) { + this.errorHandler = handler; + } + listen(callback: WebSocketCallback) { this.listeners.push(callback); } diff --git a/package.json b/package.json index 6910905..1ac0d8b 100644 --- a/package.json +++ b/package.json @@ -17,6 +17,7 @@ "@prisma/client": "5.20.0", "axios": "^1.7.7", "bcryptjs": "^2.4.3", + "cookie-es": "^1.2.2", "fast-fuzzy": "^1.12.0", "file-type-mime": "^0.4.3", "jdenticon": "^3.3.0", diff --git a/server/api/v1/task/index.get.ts b/server/api/v1/task/index.get.ts index b5da1b8..5716c6e 100644 --- a/server/api/v1/task/index.get.ts +++ b/server/api/v1/task/index.get.ts @@ -2,26 +2,30 @@ import { H3Event } from "h3"; import session from "~/server/internal/session"; import { v4 as uuidv4 } from "uuid"; import taskHandler, { TaskMessage } from "~/server/internal/tasks"; +import { parse as parseCookies } from "cookie-es"; // TODO add web socket sessions for horizontal scaling // ID to admin const adminSocketSessions: { [key: string]: boolean } = {}; export default defineWebSocketHandler({ - open(peer) { - const dummyEvent = { - node: { - req: { - headers: peer.request?.headers, - }, - }, - } as unknown as H3Event; - const userId = session.getUserId(dummyEvent); + async open(peer) { + const cookies = peer.request?.headers?.get("Cookie"); + if (!cookies) { + peer.send("unauthenticated"); + return; + } + + const parsedCookies = parseCookies(cookies); + const token = parsedCookies[session.getDropTokenCookie()]; + + const userId = await session.getUserIdRaw(token); if (!userId) { peer.send("unauthenticated"); return; } - const admin = session.getAdminUser(dummyEvent); + + const admin = session.getAdminUser(token); adminSocketSessions[peer.id] = admin !== undefined; const rtMsg: TaskMessage = { diff --git a/server/internal/session/index.ts b/server/internal/session/index.ts index 51c04ff..ec9cf96 100644 --- a/server/internal/session/index.ts +++ b/server/internal/session/index.ts @@ -43,6 +43,10 @@ export class SessionHandler { return token; } + getDropTokenCookie() { + return dropTokenCookie; + } + async getSession(h3: H3Event) { const token = this.getSessionToken(h3); if (!token) return undefined;