diff --git a/apps/remix/app/components/general/billing-plans.tsx b/apps/remix/app/components/general/billing-plans.tsx new file mode 100644 index 000000000..2113da8fc --- /dev/null +++ b/apps/remix/app/components/general/billing-plans.tsx @@ -0,0 +1,138 @@ +import { useState } from 'react'; + +import type { MessageDescriptor } from '@lingui/core'; +import { msg } from '@lingui/core/macro'; +import { useLingui } from '@lingui/react'; +import { Trans } from '@lingui/react/macro'; +import { AnimatePresence, motion } from 'framer-motion'; + +import type { PriceIntervals } from '@documenso/ee/server-only/stripe/get-prices-by-interval'; +import { useIsMounted } from '@documenso/lib/client-only/hooks/use-is-mounted'; +import { toHumanPrice } from '@documenso/lib/universal/stripe/to-human-price'; +import { trpc } from '@documenso/trpc/react'; +import { Button } from '@documenso/ui/primitives/button'; +import { Card, CardContent, CardTitle } from '@documenso/ui/primitives/card'; +import { Tabs, TabsList, TabsTrigger } from '@documenso/ui/primitives/tabs'; +import { useToast } from '@documenso/ui/primitives/use-toast'; + +type Interval = keyof PriceIntervals; + +const INTERVALS: Interval[] = ['day', 'week', 'month', 'year']; + +// eslint-disable-next-line @typescript-eslint/consistent-type-assertions +const isInterval = (value: unknown): value is Interval => INTERVALS.includes(value as Interval); + +const FRIENDLY_INTERVALS: Record = { + day: msg`Daily`, + week: msg`Weekly`, + month: msg`Monthly`, + year: msg`Yearly`, +}; + +const MotionCard = motion(Card); + +export type BillingPlansProps = { + prices: PriceIntervals; +}; + +export const BillingPlans = ({ prices }: BillingPlansProps) => { + const { _ } = useLingui(); + const { toast } = useToast(); + + const isMounted = useIsMounted(); + + const [interval, setInterval] = useState('month'); + const [checkoutSessionPriceId, setCheckoutSessionPriceId] = useState(null); + + const { mutateAsync: createCheckoutSession } = trpc.profile.createCheckoutSession.useMutation(); + + const onSubscribeClick = async (priceId: string) => { + try { + setCheckoutSessionPriceId(priceId); + + const url = await createCheckoutSession({ priceId }); + + if (!url) { + throw new Error('Unable to create session'); + } + + window.open(url); + } catch (_err) { + toast({ + title: _(msg`Something went wrong`), + description: _(msg`An error occurred while trying to create a checkout session.`), + variant: 'destructive', + }); + } finally { + setCheckoutSessionPriceId(null); + } + }; + + return ( +
+ isInterval(value) && setInterval(value)}> + + {INTERVALS.map( + (interval) => + prices[interval].length > 0 && ( + + {_(FRIENDLY_INTERVALS[interval])} + + ), + )} + + + +
+ + {prices[interval].map((price) => ( + + + {price.product.name} + +
+ ${toHumanPrice(price.unit_amount ?? 0)} {price.currency.toUpperCase()}{' '} + per {interval} +
+ +
+ {price.product.description} +
+ + {price.product.features && price.product.features.length > 0 && ( +
+
Includes:
+ +
    + {price.product.features.map((feature, index) => ( +
  • + {feature.name} +
  • + ))} +
+
+ )} + +
+ + + + + ))} + +
+
+ ); +}; diff --git a/apps/remix/app/components/general/billing-portal-button.tsx b/apps/remix/app/components/general/billing-portal-button.tsx new file mode 100644 index 000000000..ea8735954 --- /dev/null +++ b/apps/remix/app/components/general/billing-portal-button.tsx @@ -0,0 +1,48 @@ +import { msg } from '@lingui/core/macro'; +import { useLingui } from '@lingui/react'; +import { Trans } from '@lingui/react/macro'; + +import { trpc } from '@documenso/trpc/react'; +import { Button } from '@documenso/ui/primitives/button'; +import { useToast } from '@documenso/ui/primitives/use-toast'; + +export type BillingPortalButtonProps = { + buttonProps?: React.ComponentProps; + children?: React.ReactNode; +}; + +export const BillingPortalButton = ({ buttonProps, children }: BillingPortalButtonProps) => { + const { _ } = useLingui(); + const { toast } = useToast(); + + const { mutateAsync: createBillingPortal, isPending } = + trpc.profile.createBillingPortal.useMutation({ + onSuccess: (sessionUrl) => { + window.open(sessionUrl, '_blank'); + }, + onError: (err) => { + let description = _( + msg`We are unable to proceed to the billing portal at this time. Please try again, or contact support.`, + ); + + if (err.message === 'CUSTOMER_NOT_FOUND') { + description = _( + msg`You do not currently have a customer record, this should not happen. Please contact support for assistance.`, + ); + } + + toast({ + title: _(msg`Something went wrong`), + description, + variant: 'destructive', + duration: 10000, + }); + }, + }); + + return ( + + ); +}; diff --git a/apps/remix/app/routes/_authenticated+/settings+/billing.tsx b/apps/remix/app/routes/_authenticated+/settings+/billing.tsx new file mode 100644 index 000000000..ba719e06d --- /dev/null +++ b/apps/remix/app/routes/_authenticated+/settings+/billing.tsx @@ -0,0 +1,156 @@ +import { Trans, useLingui } from '@lingui/react/macro'; +import { SubscriptionStatus } from '@prisma/client'; +import { redirect } from 'react-router'; +import { match } from 'ts-pattern'; + +import { getSession } from '@documenso/auth/server/lib/utils/get-session'; +import { getStripeCustomerByUser } from '@documenso/ee/server-only/stripe/get-customer'; +import { getPricesByInterval } from '@documenso/ee/server-only/stripe/get-prices-by-interval'; +import { getPrimaryAccountPlanPrices } from '@documenso/ee/server-only/stripe/get-primary-account-plan-prices'; +import { getProductByPriceId } from '@documenso/ee/server-only/stripe/get-product-by-price-id'; +import { IS_BILLING_ENABLED } from '@documenso/lib/constants/app'; +import { STRIPE_PLAN_TYPE } from '@documenso/lib/constants/billing'; +import { type Stripe } from '@documenso/lib/server-only/stripe'; +import { getSubscriptionsByUserId } from '@documenso/lib/server-only/subscription/get-subscriptions-by-user-id'; + +import { BillingPlans } from '~/components/general/billing-plans'; +import { BillingPortalButton } from '~/components/general/billing-portal-button'; +import { appMetaTags } from '~/utils/meta'; + +import type { Route } from './+types/billing'; + +export function meta() { + return appMetaTags('Billing'); +} + +export async function loader({ request }: Route.LoaderArgs) { + const { user } = await getSession(request); + + // Redirect if subscriptions are not enabled. + if (!IS_BILLING_ENABLED()) { + throw redirect('/settings/profile'); + } + + if (!user.customerId) { + await getStripeCustomerByUser(user).then((result) => result.user); + } + + const [subscriptions, prices, primaryAccountPlanPrices] = await Promise.all([ + getSubscriptionsByUserId({ userId: user.id }), + getPricesByInterval({ plans: [STRIPE_PLAN_TYPE.REGULAR, STRIPE_PLAN_TYPE.PLATFORM] }), + getPrimaryAccountPlanPrices(), + ]); + + const primaryAccountPlanPriceIds = primaryAccountPlanPrices.map(({ id }) => id); + + let subscriptionProduct: Stripe.Product | null = null; + + const primaryAccountPlanSubscriptions = subscriptions.filter(({ priceId }) => + primaryAccountPlanPriceIds.includes(priceId), + ); + + const subscription = + primaryAccountPlanSubscriptions.find(({ status }) => status === SubscriptionStatus.ACTIVE) ?? + primaryAccountPlanSubscriptions[0]; + + if (subscription?.priceId) { + subscriptionProduct = await getProductByPriceId({ priceId: subscription.priceId }).catch( + () => null, + ); + } + + const isMissingOrInactiveOrFreePlan = + !subscription || subscription.status === SubscriptionStatus.INACTIVE; + + return { + prices, + subscription, + subscriptionProductName: subscriptionProduct?.name, + isMissingOrInactiveOrFreePlan, + }; +} + +export default function TeamsSettingBillingPage({ loaderData }: Route.ComponentProps) { + const { prices, subscription, subscriptionProductName, isMissingOrInactiveOrFreePlan } = + loaderData; + + const { i18n } = useLingui(); + + return ( +
+
+
+

+ Billing +

+ +
+ {isMissingOrInactiveOrFreePlan && ( +

+ + You are currently on the Free Plan. + +

+ )} + + {/* Todo: Translation */} + {!isMissingOrInactiveOrFreePlan && + match(subscription.status) + .with('ACTIVE', () => ( +

+ {subscriptionProductName ? ( + + You are currently subscribed to{' '} + {subscriptionProductName} + + ) : ( + You currently have an active plan + )} + + {subscription.periodEnd && ( + + {' '} + which is set to{' '} + {subscription.cancelAtPeriodEnd ? ( + + end on{' '} + + {i18n.date(subscription.periodEnd)}. + + + ) : ( + + automatically renew on{' '} + + {i18n.date(subscription.periodEnd)}. + + + )} + + )} +

+ )) + .with('PAST_DUE', () => ( +

+ + Your current plan is past due. Please update your payment information. + +

+ )) + .otherwise(() => null)} +
+
+ + {isMissingOrInactiveOrFreePlan && ( + + Manage billing + + )} +
+ +
+ + {isMissingOrInactiveOrFreePlan ? : } +
+ ); +} diff --git a/packages/auth/server/lib/session/session.ts b/packages/auth/server/lib/session/session.ts index 35497f5cf..cd690be64 100644 --- a/packages/auth/server/lib/session/session.ts +++ b/packages/auth/server/lib/session/session.ts @@ -23,6 +23,7 @@ export type SessionUser = Pick< | 'roles' | 'signature' | 'url' + | 'customerId' >; export type SessionValidationResult = @@ -99,6 +100,7 @@ export const validateSessionToken = async (token: string): Promise { * * Will create a Stripe customer and update the relevant user if one does not exist. */ -export const getStripeCustomerByUser = async (user: User) => { +export const getStripeCustomerByUser = async ( + user: Pick, +) => { if (user.customerId) { const stripeCustomer = await getStripeCustomerById(user.customerId); diff --git a/packages/lib/server-only/team/create-team.ts b/packages/lib/server-only/team/create-team.ts index 2f3975010..210187c5c 100644 --- a/packages/lib/server-only/team/create-team.ts +++ b/packages/lib/server-only/team/create-team.ts @@ -205,7 +205,7 @@ export const createTeamFromPendingTeam = async ({ pendingTeamId, subscription, }: CreateTeamFromPendingTeamOptions) => { - return await prisma.$transaction(async (tx) => { + const createdTeam = await prisma.$transaction(async (tx) => { const pendingTeam = await tx.teamPending.findUniqueOrThrow({ where: { id: pendingTeamId, @@ -249,19 +249,21 @@ export const createTeamFromPendingTeam = async ({ mapStripeSubscriptionToPrismaUpsertAction(subscription, undefined, team.id), ); - // Attach the team ID to the subscription metadata for sanity reasons. - await stripe.subscriptions - .update(subscription.id, { - metadata: { - teamId: team.id.toString(), - }, - }) - .catch((e) => { - console.error(e); - // Non-critical error, but we want to log it so we can rectify it. - // Todo: Teams - Alert us. - }); - return team; }); + + // Attach the team ID to the subscription metadata for sanity reasons. + await stripe.subscriptions + .update(subscription.id, { + metadata: { + teamId: createdTeam.id.toString(), + }, + }) + .catch((e) => { + console.error(e); + // Non-critical error, but we want to log it so we can rectify it. + // Todo: Teams - Alert us. + }); + + return createdTeam; }; diff --git a/packages/lib/server-only/user/create-billing-portal.ts b/packages/lib/server-only/user/create-billing-portal.ts new file mode 100644 index 000000000..4e00ecc79 --- /dev/null +++ b/packages/lib/server-only/user/create-billing-portal.ts @@ -0,0 +1,22 @@ +import type { User } from '@prisma/client'; + +import { getStripeCustomerByUser } from '@documenso/ee/server-only/stripe/get-customer'; +import { getPortalSession } from '@documenso/ee/server-only/stripe/get-portal-session'; +import { IS_BILLING_ENABLED, NEXT_PUBLIC_WEBAPP_URL } from '@documenso/lib/constants/app'; + +export type CreateBillingPortalOptions = { + user: Pick; +}; + +export const createBillingPortal = async ({ user }: CreateBillingPortalOptions) => { + if (!IS_BILLING_ENABLED()) { + throw new Error('Billing is not enabled'); + } + + const { stripeCustomer } = await getStripeCustomerByUser(user); + + return getPortalSession({ + customerId: stripeCustomer.id, + returnUrl: `${NEXT_PUBLIC_WEBAPP_URL()}/settings/billing`, + }); +}; diff --git a/packages/lib/server-only/user/create-checkout-session.ts b/packages/lib/server-only/user/create-checkout-session.ts new file mode 100644 index 000000000..3b6a45d47 --- /dev/null +++ b/packages/lib/server-only/user/create-checkout-session.ts @@ -0,0 +1,39 @@ +import type { User } from '@prisma/client'; + +import { getCheckoutSession } from '@documenso/ee/server-only/stripe/get-checkout-session'; +import { getStripeCustomerByUser } from '@documenso/ee/server-only/stripe/get-customer'; +import { getPortalSession } from '@documenso/ee/server-only/stripe/get-portal-session'; +import { NEXT_PUBLIC_WEBAPP_URL } from '@documenso/lib/constants/app'; + +import { getSubscriptionsByUserId } from '../subscription/get-subscriptions-by-user-id'; + +export type CreateCheckoutSession = { + user: Pick; + priceId: string; +}; + +export const createCheckoutSession = async ({ user, priceId }: CreateCheckoutSession) => { + const { stripeCustomer } = await getStripeCustomerByUser(user); + + const existingSubscriptions = await getSubscriptionsByUserId({ userId: user.id }); + + const foundSubscription = existingSubscriptions.find( + (subscription) => + subscription.priceId === priceId && + subscription.periodEnd && + subscription.periodEnd >= new Date(), + ); + + if (foundSubscription) { + return getPortalSession({ + customerId: stripeCustomer.id, + returnUrl: `${NEXT_PUBLIC_WEBAPP_URL()}/settings/billing`, + }); + } + + return getCheckoutSession({ + customerId: stripeCustomer.id, + priceId, + returnUrl: `${NEXT_PUBLIC_WEBAPP_URL()}/settings/billing`, + }); +}; diff --git a/packages/trpc/server/profile-router/router.ts b/packages/trpc/server/profile-router/router.ts index e5f5f4a7c..fc1e978c7 100644 --- a/packages/trpc/server/profile-router/router.ts +++ b/packages/trpc/server/profile-router/router.ts @@ -4,6 +4,8 @@ import { IS_BILLING_ENABLED } from '@documenso/lib/constants/app'; import { AppError } from '@documenso/lib/errors/app-error'; import { setAvatarImage } from '@documenso/lib/server-only/profile/set-avatar-image'; import { getSubscriptionsByUserId } from '@documenso/lib/server-only/subscription/get-subscriptions-by-user-id'; +import { createBillingPortal } from '@documenso/lib/server-only/user/create-billing-portal'; +import { createCheckoutSession } from '@documenso/lib/server-only/user/create-checkout-session'; import { deleteUser } from '@documenso/lib/server-only/user/delete-user'; import { findUserSecurityAuditLogs } from '@documenso/lib/server-only/user/find-user-security-audit-logs'; import { getUserById } from '@documenso/lib/server-only/user/get-user-by-id'; @@ -12,6 +14,7 @@ import { updatePublicProfile } from '@documenso/lib/server-only/user/update-publ import { adminProcedure, authenticatedProcedure, router } from '../trpc'; import { + ZCreateCheckoutSessionRequestSchema, ZFindUserSecurityAuditLogsSchema, ZRetrieveUserByIdQuerySchema, ZSetProfileImageMutationSchema, @@ -35,6 +38,31 @@ export const profileRouter = router({ return await getUserById({ id }); }), + createBillingPortal: authenticatedProcedure.mutation(async ({ ctx }) => { + return await createBillingPortal({ + user: { + id: ctx.user.id, + customerId: ctx.user.customerId, + email: ctx.user.email, + name: ctx.user.name, + }, + }); + }), + + createCheckoutSession: authenticatedProcedure + .input(ZCreateCheckoutSessionRequestSchema) + .mutation(async ({ ctx, input }) => { + return await createCheckoutSession({ + user: { + id: ctx.user.id, + customerId: ctx.user.customerId, + email: ctx.user.email, + name: ctx.user.name, + }, + priceId: input.priceId, + }); + }), + updateProfile: authenticatedProcedure .input(ZUpdateProfileMutationSchema) .mutation(async ({ input, ctx }) => { diff --git a/packages/trpc/server/profile-router/schema.ts b/packages/trpc/server/profile-router/schema.ts index b02451d55..1d607437f 100644 --- a/packages/trpc/server/profile-router/schema.ts +++ b/packages/trpc/server/profile-router/schema.ts @@ -15,6 +15,10 @@ export const ZRetrieveUserByIdQuerySchema = z.object({ export type TRetrieveUserByIdQuerySchema = z.infer; +export const ZCreateCheckoutSessionRequestSchema = z.object({ + priceId: z.string().min(1), +}); + export const ZUpdateProfileMutationSchema = z.object({ name: z.string().min(1), signature: z.string(),