feat: add multi subscription support (#734)

## Description

Previously we assumed that there can only be 1 subscription per user.
However, that will soon no longer the case with the introduction of the
Teams subscription.

This PR will apply the required migrations to support multiple
subscriptions.

## Changes Made

- Updated the Prisma schema to allow for multiple `Subscriptions` per
`User`
- Added a Stripe `customerId` field to the `User` model
- Updated relevant billing sections to support multiple subscriptions

## Testing Performed

- Tested running the Prisma migration on a demo database created on the
main branch

Will require a lot of additional testing.

## Checklist

- [ ] I have tested these changes locally and they work as expected.
- [ ] I have added/updated tests that prove the effectiveness of these
changes.
- [X] I have followed the project's coding style guidelines.

## Additional Notes

Added the following custom SQL statement to the migration:

> DELETE FROM "Subscription" WHERE "planId" IS NULL OR "priceId" IS
NULL;

Prior to deployment this will require changes to Stripe products:
- Adding `type` meta attribute

---------

Co-authored-by: Lucas Smith <me@lucasjamessmith.me>
This commit is contained in:
David Nguyen
2023-12-14 15:22:54 +11:00
committed by GitHub
parent 6d34ebd91b
commit 88534fa1c6
28 changed files with 288 additions and 366 deletions

View File

@ -77,14 +77,10 @@ NEXT_PRIVATE_MAILCHANNELS_DKIM_PRIVATE_KEY=
NEXT_PRIVATE_STRIPE_API_KEY=
NEXT_PRIVATE_STRIPE_WEBHOOK_SECRET=
NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_MONTHLY_PRICE_ID=
NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_YEARLY_PRICE_ID=
NEXT_PUBLIC_STRIPE_FREE_PLAN_ID=
# [[FEATURES]]
# OPTIONAL: Leave blank to disable PostHog and feature flags.
NEXT_PUBLIC_POSTHOG_KEY=""
# OPTIONAL: Defines the host to use for PostHog.
NEXT_PUBLIC_POSTHOG_HOST="https://eu.posthog.com"
# OPTIONAL: Leave blank to disable billing.
NEXT_PUBLIC_FEATURE_BILLING_ENABLED=

View File

@ -6,8 +6,6 @@ declare namespace NodeJS {
NEXT_PRIVATE_DATABASE_URL: string;
NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_MONTHLY_PRICE_ID: string;
NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_YEARLY_PRICE_ID: string;
NEXT_PUBLIC_STRIPE_FREE_PLAN_ID?: string;
NEXT_PRIVATE_STRIPE_API_KEY: string;
NEXT_PRIVATE_STRIPE_WEBHOOK_SECRET: string;

View File

@ -1,160 +0,0 @@
'use client';
import React, { useEffect, useState } from 'react';
import { useSearchParams } from 'next/navigation';
import { zodResolver } from '@hookform/resolvers/zod';
import { Info } from 'lucide-react';
import { usePlausible } from 'next-plausible';
import { useForm } from 'react-hook-form';
import { z } from 'zod';
import { useAnalytics } from '@documenso/lib/client-only/hooks/use-analytics';
import { cn } from '@documenso/ui/lib/utils';
import { Button } from '@documenso/ui/primitives/button';
import {
Dialog,
DialogContent,
DialogDescription,
DialogHeader,
DialogTitle,
DialogTrigger,
} from '@documenso/ui/primitives/dialog';
import { Input } from '@documenso/ui/primitives/input';
import { Label } from '@documenso/ui/primitives/label';
import { useToast } from '@documenso/ui/primitives/use-toast';
import { claimPlan } from '~/api/claim-plan/fetcher';
import { FormErrorMessage } from '../form/form-error-message';
export const ZClaimPlanDialogFormSchema = z.object({
name: z.string().trim().min(3, { message: 'Please enter a valid name.' }),
email: z.string().email(),
});
export type TClaimPlanDialogFormSchema = z.infer<typeof ZClaimPlanDialogFormSchema>;
export type ClaimPlanDialogProps = {
className?: string;
planId: string;
children: React.ReactNode;
};
export const ClaimPlanDialog = ({ className, planId, children }: ClaimPlanDialogProps) => {
const params = useSearchParams();
const analytics = useAnalytics();
const event = usePlausible();
const { toast } = useToast();
const [open, setOpen] = useState(() => params?.get('cancelled') === 'true');
const {
register,
handleSubmit,
formState: { errors, isSubmitting },
reset,
} = useForm<TClaimPlanDialogFormSchema>({
defaultValues: {
name: params?.get('name') ?? '',
email: params?.get('email') ?? '',
},
resolver: zodResolver(ZClaimPlanDialogFormSchema),
});
const onFormSubmit = async ({ name, email }: TClaimPlanDialogFormSchema) => {
try {
const delay = new Promise<void>((resolve) => {
setTimeout(resolve, 1000);
});
const [redirectUrl] = await Promise.all([
claimPlan({ name, email, planId, signatureText: name, signatureDataUrl: null }),
delay,
]);
event('claim-plan-pricing');
analytics.capture('Marketing: Claim plan', { planId, email });
window.location.href = redirectUrl;
} catch (error) {
event('claim-plan-failed');
analytics.capture('Marketing: Claim plan failure', { planId, email });
toast({
title: 'Something went wrong',
description: error instanceof Error ? error.message : 'Please try again later.',
variant: 'destructive',
});
}
};
useEffect(() => {
if (!isSubmitting && !open) {
reset();
}
}, [open]);
return (
<Dialog open={open} onOpenChange={(value) => !isSubmitting && setOpen(value)}>
<DialogTrigger asChild>{children}</DialogTrigger>
<DialogContent>
<DialogHeader>
<DialogTitle>Claim your plan</DialogTitle>
<DialogDescription className="mt-4">
We're almost there! Please enter your email address and name to claim your plan.
</DialogDescription>
</DialogHeader>
<form onSubmit={handleSubmit(onFormSubmit)}>
<fieldset disabled={isSubmitting} className={cn('flex flex-col gap-y-4', className)}>
{params?.get('cancelled') === 'true' && (
<div className="rounded-lg border border-yellow-400 bg-yellow-50 p-4">
<div className="flex">
<div className="flex-shrink-0">
<Info className="h-5 w-5 text-yellow-400" />
</div>
<div className="ml-3">
<p className="text-sm leading-5 text-yellow-700">
You have cancelled the payment process. If you didn't mean to do this, please
try again.
</p>
</div>
</div>
</div>
)}
<div>
<Label className="text-muted-foreground">Name</Label>
<Input type="text" className="mt-2" {...register('name')} autoFocus />
<FormErrorMessage className="mt-1" error={errors.name} />
</div>
<div>
<Label className="text-muted-foreground">Email</Label>
<Input type="email" className="mt-2" {...register('email')} />
<FormErrorMessage className="mt-1" error={errors.email} />
</div>
<Button type="submit" size="lg" loading={isSubmitting}>
Claim the early adopters Plan (
{/* eslint-disable-next-line turbo/no-undeclared-env-vars */}
{planId === process.env.NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_MONTHLY_PRICE_ID
? 'Monthly'
: 'Yearly'}
)
</Button>
</fieldset>
</form>
</DialogContent>
</Dialog>
);
};

View File

@ -1,9 +1,9 @@
'use client';
import { HTMLAttributes, useState } from 'react';
import type { HTMLAttributes } from 'react';
import { useState } from 'react';
import Link from 'next/link';
import { useSearchParams } from 'next/navigation';
import { AnimatePresence, motion } from 'framer-motion';
import { usePlausible } from 'next-plausible';
@ -16,14 +16,9 @@ export type PricingTableProps = HTMLAttributes<HTMLDivElement>;
const SELECTED_PLAN_BAR_LAYOUT_ID = 'selected-plan-bar';
export const PricingTable = ({ className, ...props }: PricingTableProps) => {
const params = useSearchParams();
const event = usePlausible();
const [period, setPeriod] = useState<'MONTHLY' | 'YEARLY'>(() =>
params?.get('planId') === process.env.NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_YEARLY_PRICE_ID
? 'YEARLY'
: 'MONTHLY',
);
const [period, setPeriod] = useState<'MONTHLY' | 'YEARLY'>('MONTHLY');
return (
<div className={cn('', className)} {...props}>

View File

@ -6,8 +6,6 @@ declare namespace NodeJS {
NEXT_PRIVATE_DATABASE_URL: string;
NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_MONTHLY_PRICE_ID: string;
NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_YEARLY_PRICE_ID: string;
NEXT_PUBLIC_STRIPE_FREE_PLAN_ID?: string;
NEXT_PRIVATE_STRIPE_API_KEY: string;
NEXT_PRIVATE_STRIPE_WEBHOOK_SECRET: string;

View File

@ -8,7 +8,7 @@ import { Edit, Loader } from 'lucide-react';
import { useDebouncedValue } from '@documenso/lib/client-only/hooks/use-debounced-value';
import { useUpdateSearchParams } from '@documenso/lib/client-only/hooks/use-update-search-params';
import { Document, Role, Subscription } from '@documenso/prisma/client';
import type { Document, Role, Subscription } from '@documenso/prisma/client';
import { Button } from '@documenso/ui/primitives/button';
import { DataTable } from '@documenso/ui/primitives/data-table';
import { DataTablePagination } from '@documenso/ui/primitives/data-table-pagination';
@ -19,7 +19,7 @@ type UserData = {
name: string | null;
email: string;
roles: Role[];
Subscription?: SubscriptionLite | null;
Subscription?: SubscriptionLite[] | null;
Document: DocumentLite[];
};
@ -35,9 +35,16 @@ type UsersDataTableProps = {
totalPages: number;
perPage: number;
page: number;
individualPriceIds: string[];
};
export const UsersDataTable = ({ users, totalPages, perPage, page }: UsersDataTableProps) => {
export const UsersDataTable = ({
users,
totalPages,
perPage,
page,
individualPriceIds,
}: UsersDataTableProps) => {
const [isPending, startTransition] = useTransition();
const updateSearchParams = useUpdateSearchParams();
const [searchString, setSearchString] = useState('');
@ -100,7 +107,13 @@ export const UsersDataTable = ({ users, totalPages, perPage, page }: UsersDataTa
{
header: 'Subscription',
accessorKey: 'subscription',
cell: ({ row }) => row.original.Subscription?.status ?? 'NONE',
cell: ({ row }) => {
const foundIndividualSubscription = (row.original.Subscription ?? []).find((sub) =>
individualPriceIds.includes(sub.priceId),
);
return foundIndividualSubscription?.status ?? 'NONE';
},
},
{
header: 'Documents',

View File

@ -1,3 +1,5 @@
import { getPricesByType } from '@documenso/ee/server-only/stripe/get-prices-by-type';
import { UsersDataTable } from './data-table-users';
import { search } from './fetch-users.actions';
@ -14,12 +16,23 @@ export default async function AdminManageUsers({ searchParams = {} }: AdminManag
const perPage = Number(searchParams.perPage) || 10;
const searchString = searchParams.search || '';
const { users, totalPages } = await search(searchString, page, perPage);
const [{ users, totalPages }, individualPrices] = await Promise.all([
search(searchString, page, perPage),
getPricesByType('individual'),
]);
const individualPriceIds = individualPrices.map((price) => price.id);
return (
<div>
<h2 className="text-4xl font-semibold">Manage users</h2>
<UsersDataTable users={users} totalPages={totalPages} page={page} perPage={perPage} />
<UsersDataTable
users={users}
individualPriceIds={individualPriceIds}
totalPages={totalPages}
page={page}
perPage={perPage}
/>
</div>
);
}

View File

@ -4,7 +4,7 @@ import { useState } from 'react';
import { AnimatePresence, motion } from 'framer-motion';
import { PriceIntervals } from '@documenso/ee/server-only/stripe/get-prices-by-interval';
import type { PriceIntervals } from '@documenso/ee/server-only/stripe/get-prices-by-interval';
import { useIsMounted } from '@documenso/lib/client-only/hooks/use-is-mounted';
import { toHumanPrice } from '@documenso/lib/universal/stripe/to-human-price';
import { Button } from '@documenso/ui/primitives/button';

View File

@ -1,46 +1,13 @@
'use server';
import {
getStripeCustomerByEmail,
getStripeCustomerById,
} from '@documenso/ee/server-only/stripe/get-customer';
import { getStripeCustomerByUser } from '@documenso/ee/server-only/stripe/get-customer';
import { getPortalSession } from '@documenso/ee/server-only/stripe/get-portal-session';
import { getRequiredServerComponentSession } from '@documenso/lib/next-auth/get-server-component-session';
import type { Stripe } from '@documenso/lib/server-only/stripe';
import { stripe } from '@documenso/lib/server-only/stripe';
import { getSubscriptionByUserId } from '@documenso/lib/server-only/subscription/get-subscription-by-user-id';
export const createBillingPortal = async () => {
const { user } = await getRequiredServerComponentSession();
const existingSubscription = await getSubscriptionByUserId({ userId: user.id });
let stripeCustomer: Stripe.Customer | null = null;
// Find the Stripe customer for the current user subscription.
if (existingSubscription) {
stripeCustomer = await getStripeCustomerById(existingSubscription.customerId);
if (!stripeCustomer) {
throw new Error('Missing Stripe customer for subscription');
}
}
// Fallback to check if a Stripe customer already exists for the current user email.
if (!stripeCustomer) {
stripeCustomer = await getStripeCustomerByEmail(user.email);
}
// Create a Stripe customer if it does not exist for the current user.
if (!stripeCustomer) {
stripeCustomer = await stripe.customers.create({
name: user.name ?? undefined,
email: user.email,
metadata: {
userId: user.id,
},
});
}
const { stripeCustomer } = await getStripeCustomerByUser(user);
return getPortalSession({
customerId: stripeCustomer.id,

View File

@ -1,55 +1,36 @@
'use server';
import { createCustomer } from '@documenso/ee/server-only/stripe/create-customer';
import { getCheckoutSession } from '@documenso/ee/server-only/stripe/get-checkout-session';
import {
getStripeCustomerByEmail,
getStripeCustomerById,
} from '@documenso/ee/server-only/stripe/get-customer';
import { getStripeCustomerByUser } from '@documenso/ee/server-only/stripe/get-customer';
import { getPortalSession } from '@documenso/ee/server-only/stripe/get-portal-session';
import { getRequiredServerComponentSession } from '@documenso/lib/next-auth/get-server-component-session';
import type { Stripe } from '@documenso/lib/server-only/stripe';
import { getSubscriptionByUserId } from '@documenso/lib/server-only/subscription/get-subscription-by-user-id';
import { getSubscriptionsByUserId } from '@documenso/lib/server-only/subscription/get-subscriptions-by-user-id';
export type CreateCheckoutOptions = {
priceId: string;
};
export const createCheckout = async ({ priceId }: CreateCheckoutOptions) => {
const { user } = await getRequiredServerComponentSession();
const session = await getRequiredServerComponentSession();
const existingSubscription = await getSubscriptionByUserId({ userId: user.id });
const { user, stripeCustomer } = await getStripeCustomerByUser(session.user);
let stripeCustomer: Stripe.Customer | null = null;
const existingSubscriptions = await getSubscriptionsByUserId({ userId: user.id });
// Find the Stripe customer for the current user subscription.
if (existingSubscription?.periodEnd && existingSubscription.periodEnd >= new Date()) {
stripeCustomer = await getStripeCustomerById(existingSubscription.customerId);
if (!stripeCustomer) {
throw new Error('Missing Stripe customer for subscription');
}
const foundSubscription = existingSubscriptions.find(
(subscription) =>
subscription.priceId === priceId &&
subscription.periodEnd &&
subscription.periodEnd >= new Date(),
);
if (foundSubscription) {
return getPortalSession({
customerId: stripeCustomer.id,
returnUrl: `${process.env.NEXT_PUBLIC_WEBAPP_URL}/settings/billing`,
});
}
// Fallback to check if a Stripe customer already exists for the current user email.
if (!stripeCustomer) {
stripeCustomer = await getStripeCustomerByEmail(user.email);
}
// Create a Stripe customer if it does not exist for the current user.
if (!stripeCustomer) {
await createCustomer({
user,
});
stripeCustomer = await getStripeCustomerByEmail(user.email);
}
return getCheckoutSession({
customerId: stripeCustomer.id,
priceId,

View File

@ -2,12 +2,15 @@ import { redirect } from 'next/navigation';
import { match } from 'ts-pattern';
import { getStripeCustomerByUser } from '@documenso/ee/server-only/stripe/get-customer';
import { getPricesByInterval } from '@documenso/ee/server-only/stripe/get-prices-by-interval';
import { getPricesByType } from '@documenso/ee/server-only/stripe/get-prices-by-type';
import { getProductByPriceId } from '@documenso/ee/server-only/stripe/get-product-by-price-id';
import { getRequiredServerComponentSession } from '@documenso/lib/next-auth/get-server-component-session';
import { getServerComponentFlag } from '@documenso/lib/server-only/feature-flags/get-server-component-feature-flag';
import type { Stripe } from '@documenso/lib/server-only/stripe';
import { getSubscriptionByUserId } from '@documenso/lib/server-only/subscription/get-subscription-by-user-id';
import { type Stripe } from '@documenso/lib/server-only/stripe';
import { getSubscriptionsByUserId } from '@documenso/lib/server-only/subscription/get-subscriptions-by-user-id';
import { SubscriptionStatus } from '@documenso/prisma/client';
import { LocaleDate } from '~/components/formatter/locale-date';
@ -15,7 +18,7 @@ import { BillingPlans } from './billing-plans';
import { BillingPortalButton } from './billing-portal-button';
export default async function BillingSettingsPage() {
const { user } = await getRequiredServerComponentSession();
let { user } = await getRequiredServerComponentSession();
const isBillingEnabled = await getServerComponentFlag('app_billing');
@ -24,20 +27,36 @@ export default async function BillingSettingsPage() {
redirect('/settings/profile');
}
const [subscription, prices] = await Promise.all([
getSubscriptionByUserId({ userId: user.id }),
getPricesByInterval(),
if (!user.customerId) {
user = await getStripeCustomerByUser(user).then((result) => result.user);
}
const [subscriptions, prices, individualPrices] = await Promise.all([
getSubscriptionsByUserId({ userId: user.id }),
getPricesByInterval({ type: 'individual' }),
getPricesByType('individual'),
]);
const individualPriceIds = individualPrices.map(({ id }) => id);
let subscriptionProduct: Stripe.Product | null = null;
const individualUserSubscriptions = subscriptions.filter(({ priceId }) =>
individualPriceIds.includes(priceId),
);
const subscription =
individualUserSubscriptions.find(({ status }) => status === SubscriptionStatus.ACTIVE) ??
individualUserSubscriptions[0];
if (subscription?.priceId) {
subscriptionProduct = await getProductByPriceId({ priceId: subscription.priceId }).catch(
() => null,
);
}
const isMissingOrInactiveOrFreePlan = !subscription || subscription.status === 'INACTIVE';
const isMissingOrInactiveOrFreePlan =
!subscription || subscription.status === SubscriptionStatus.INACTIVE;
return (
<div>

View File

@ -1,10 +1,10 @@
import { DateTime } from 'luxon';
import { stripe } from '@documenso/lib/server-only/stripe';
import { getFlag } from '@documenso/lib/universal/get-feature-flag';
import { prisma } from '@documenso/prisma';
import { SubscriptionStatus } from '@documenso/prisma/client';
import { getPricesByType } from '../stripe/get-prices-by-type';
import { FREE_PLAN_LIMITS, SELFHOSTED_PLAN_LIMITS } from './constants';
import { ERROR_CODES } from './errors';
import { ZLimitsSchema } from './schema';
@ -43,23 +43,29 @@ export const getServerLimits = async ({ email }: GetServerLimitsOptions) => {
let quota = structuredClone(FREE_PLAN_LIMITS);
let remaining = structuredClone(FREE_PLAN_LIMITS);
// Since we store details and allow for past due plans we need to check if the subscription is active.
if (user.Subscription?.status !== SubscriptionStatus.INACTIVE && user.Subscription?.priceId) {
const { product } = await stripe.prices
.retrieve(user.Subscription.priceId, {
expand: ['product'],
})
.catch((err) => {
console.error(err);
throw err;
});
const activeSubscriptions = user.Subscription.filter(
({ status }) => status === SubscriptionStatus.ACTIVE,
);
if (typeof product === 'string') {
throw new Error(ERROR_CODES.SUBSCRIPTION_FETCH_FAILED);
if (activeSubscriptions.length > 0) {
const individualPrices = await getPricesByType('individual');
for (const subscription of activeSubscriptions) {
const price = individualPrices.find((price) => price.id === subscription.priceId);
if (!price || typeof price.product === 'string' || price.product.deleted) {
continue;
}
const currentQuota = ZLimitsSchema.parse(
'metadata' in price.product ? price.product.metadata : {},
);
// Use the subscription with the highest quota.
if (currentQuota.documents > quota.documents && currentQuota.recipients > quota.recipients) {
quota = currentQuota;
remaining = structuredClone(quota);
}
}
quota = ZLimitsSchema.parse('metadata' in product ? product.metadata : {});
remaining = structuredClone(quota);
}
const documents = await prisma.document.count({

View File

@ -1,31 +0,0 @@
import { stripe } from '@documenso/lib/server-only/stripe';
import { getSubscriptionByUserId } from '@documenso/lib/server-only/subscription/get-subscription-by-user-id';
import { prisma } from '@documenso/prisma';
import { User } from '@documenso/prisma/client';
export type CreateCustomerOptions = {
user: User;
};
export const createCustomer = async ({ user }: CreateCustomerOptions) => {
const existingSubscription = await getSubscriptionByUserId({ userId: user.id });
if (existingSubscription) {
throw new Error('User already has a subscription');
}
const customer = await stripe.customers.create({
name: user.name ?? undefined,
email: user.email,
metadata: {
userId: user.id,
},
});
return await prisma.subscription.create({
data: {
userId: user.id,
customerId: customer.id,
},
});
};

View File

@ -1,4 +1,8 @@
import { stripe } from '@documenso/lib/server-only/stripe';
import { prisma } from '@documenso/prisma';
import type { User } from '@documenso/prisma/client';
import { onSubscriptionUpdated } from './webhook/on-subscription-updated';
export const getStripeCustomerByEmail = async (email: string) => {
const foundStripeCustomers = await stripe.customers.list({
@ -17,3 +21,74 @@ export const getStripeCustomerById = async (stripeCustomerId: string) => {
return null;
}
};
/**
* Get a stripe customer by user.
*
* Will create a Stripe customer and update the relevant user if one does not exist.
*/
export const getStripeCustomerByUser = async (user: User) => {
if (user.customerId) {
const stripeCustomer = await getStripeCustomerById(user.customerId);
if (!stripeCustomer) {
throw new Error('Missing Stripe customer');
}
return {
user,
stripeCustomer,
};
}
let stripeCustomer = await getStripeCustomerByEmail(user.email);
const isSyncRequired = Boolean(stripeCustomer && !stripeCustomer.deleted);
if (!stripeCustomer) {
stripeCustomer = await stripe.customers.create({
name: user.name ?? undefined,
email: user.email,
metadata: {
userId: user.id,
},
});
}
const updatedUser = await prisma.user.update({
where: {
id: user.id,
},
data: {
customerId: stripeCustomer.id,
},
});
// Sync subscriptions if the customer already exists for back filling the DB
// and local development.
if (isSyncRequired) {
await syncStripeCustomerSubscriptions(user.id, stripeCustomer.id).catch((e) => {
console.error(e);
});
}
return {
user: updatedUser,
stripeCustomer,
};
};
const syncStripeCustomerSubscriptions = async (userId: number, stripeCustomerId: string) => {
const stripeSubscriptions = await stripe.subscriptions.list({
customer: stripeCustomerId,
});
await Promise.all(
stripeSubscriptions.data.map(async (subscription) =>
onSubscriptionUpdated({
userId,
subscription,
}),
),
);
};

View File

@ -1,4 +1,4 @@
import Stripe from 'stripe';
import type Stripe from 'stripe';
import { stripe } from '@documenso/lib/server-only/stripe';
@ -7,7 +7,14 @@ type PriceWithProduct = Stripe.Price & { product: Stripe.Product };
export type PriceIntervals = Record<Stripe.Price.Recurring.Interval, PriceWithProduct[]>;
export const getPricesByInterval = async () => {
export type GetPricesByIntervalOptions = {
/**
* Filter products by their meta 'type' attribute.
*/
type?: 'individual';
};
export const getPricesByInterval = async ({ type }: GetPricesByIntervalOptions = {}) => {
let { data: prices } = await stripe.prices.search({
query: `active:'true' type:'recurring'`,
expand: ['data.product'],
@ -19,8 +26,10 @@ export const getPricesByInterval = async () => {
// eslint-disable-next-line @typescript-eslint/consistent-type-assertions
const product = price.product as Stripe.Product;
const filter = !type || product.metadata?.type === type;
// Filter out prices for products that are not active.
return product.active;
return product.active && filter;
});
const intervals: PriceIntervals = {

View File

@ -0,0 +1,11 @@
import { stripe } from '@documenso/lib/server-only/stripe';
export const getPricesByType = async (type: 'individual') => {
const { data: prices } = await stripe.prices.search({
query: `metadata['type']:'${type}' type:'recurring'`,
expand: ['data.product'],
limit: 100,
});
return prices;
};

View File

@ -75,23 +75,23 @@ export const stripeWebhookHandler = async (
// Finally, attempt to get the user ID from the subscription within the database.
if (!userId && customerId) {
const result = await prisma.subscription.findFirst({
const result = await prisma.user.findFirst({
select: {
userId: true,
id: true,
},
where: {
customerId,
},
});
if (!result?.userId) {
if (!result?.id) {
return res.status(500).json({
success: false,
message: 'User not found',
});
}
userId = result.userId;
userId = result.id;
}
const subscriptionId =
@ -124,23 +124,23 @@ export const stripeWebhookHandler = async (
? subscription.customer
: subscription.customer.id;
const result = await prisma.subscription.findFirst({
const result = await prisma.user.findFirst({
select: {
userId: true,
id: true,
},
where: {
customerId,
},
});
if (!result?.userId) {
if (!result?.id) {
return res.status(500).json({
success: false,
message: 'User not found',
});
}
await onSubscriptionUpdated({ userId: result.userId, subscription });
await onSubscriptionUpdated({ userId: result.id, subscription });
return res.status(200).json({
success: true,
@ -182,23 +182,23 @@ export const stripeWebhookHandler = async (
});
}
const result = await prisma.subscription.findFirst({
const result = await prisma.user.findFirst({
select: {
userId: true,
id: true,
},
where: {
customerId,
},
});
if (!result?.userId) {
if (!result?.id) {
return res.status(500).json({
success: false,
message: 'User not found',
});
}
await onSubscriptionUpdated({ userId: result.userId, subscription });
await onSubscriptionUpdated({ userId: result.id, subscription });
return res.status(200).json({
success: true,
@ -233,23 +233,23 @@ export const stripeWebhookHandler = async (
});
}
const result = await prisma.subscription.findFirst({
const result = await prisma.user.findFirst({
select: {
userId: true,
id: true,
},
where: {
customerId,
},
});
if (!result?.userId) {
if (!result?.id) {
return res.status(500).json({
success: false,
message: 'User not found',
});
}
await onSubscriptionUpdated({ userId: result.userId, subscription });
await onSubscriptionUpdated({ userId: result.id, subscription });
return res.status(200).json({
success: true,

View File

@ -1,4 +1,4 @@
import { Stripe } from '@documenso/lib/server-only/stripe';
import type { Stripe } from '@documenso/lib/server-only/stripe';
import { prisma } from '@documenso/prisma';
import { SubscriptionStatus } from '@documenso/prisma/client';
@ -7,12 +7,9 @@ export type OnSubscriptionDeletedOptions = {
};
export const onSubscriptionDeleted = async ({ subscription }: OnSubscriptionDeletedOptions) => {
const customerId =
typeof subscription.customer === 'string' ? subscription.customer : subscription.customer?.id;
await prisma.subscription.update({
where: {
customerId,
planId: subscription.id,
},
data: {
status: SubscriptionStatus.INACTIVE,

View File

@ -1,6 +1,6 @@
import { match } from 'ts-pattern';
import { Stripe } from '@documenso/lib/server-only/stripe';
import type { Stripe } from '@documenso/lib/server-only/stripe';
import { prisma } from '@documenso/prisma';
import { SubscriptionStatus } from '@documenso/prisma/client';
@ -13,9 +13,6 @@ export const onSubscriptionUpdated = async ({
userId,
subscription,
}: OnSubscriptionUpdatedOptions) => {
const customerId =
typeof subscription.customer === 'string' ? subscription.customer : subscription.customer?.id;
const status = match(subscription.status)
.with('active', () => SubscriptionStatus.ACTIVE)
.with('past_due', () => SubscriptionStatus.PAST_DUE)
@ -23,22 +20,22 @@ export const onSubscriptionUpdated = async ({
await prisma.subscription.upsert({
where: {
customerId,
planId: subscription.id,
},
create: {
customerId,
status: status,
planId: subscription.id,
priceId: subscription.items.data[0].price.id,
periodEnd: new Date(subscription.current_period_end * 1000),
userId,
cancelAtPeriodEnd: subscription.cancel_at_period_end,
},
update: {
customerId,
status: status,
planId: subscription.id,
priceId: subscription.items.data[0].price.id,
periodEnd: new Date(subscription.current_period_end * 1000),
cancelAtPeriodEnd: subscription.cancel_at_period_end,
},
});
};

View File

@ -9,7 +9,9 @@ export const getUsersWithSubscriptionsCount = async () => {
return await prisma.user.count({
where: {
Subscription: {
status: SubscriptionStatus.ACTIVE,
some: {
status: SubscriptionStatus.ACTIVE,
},
},
},
});

View File

@ -1,15 +0,0 @@
'use server';
import { prisma } from '@documenso/prisma';
export type GetSubscriptionByUserIdOptions = {
userId: number;
};
export const getSubscriptionByUserId = async ({ userId }: GetSubscriptionByUserIdOptions) => {
return await prisma.subscription.findFirst({
where: {
userId,
},
});
};

View File

@ -0,0 +1,15 @@
'use server';
import { prisma } from '@documenso/prisma';
export type GetSubscriptionsByUserIdOptions = {
userId: number;
};
export const getSubscriptionsByUserId = async ({ userId }: GetSubscriptionsByUserIdOptions) => {
return await prisma.subscription.findMany({
where: {
userId,
},
});
};

View File

@ -1,9 +1,11 @@
import { hash } from 'bcrypt';
import { getStripeCustomerByUser } from '@documenso/ee/server-only/stripe/get-customer';
import { prisma } from '@documenso/prisma';
import { IdentityProvider } from '@documenso/prisma/client';
import { SALT_ROUNDS } from '../../constants/auth';
import { getFlag } from '../../universal/get-feature-flag';
export interface CreateUserOptions {
name: string;
@ -13,6 +15,8 @@ export interface CreateUserOptions {
}
export const createUser = async ({ name, email, password, signature }: CreateUserOptions) => {
const isBillingEnabled = await getFlag('app_billing');
const hashedPassword = await hash(password, SALT_ROUNDS);
const userExists = await prisma.user.findFirst({
@ -25,7 +29,7 @@ export const createUser = async ({ name, email, password, signature }: CreateUse
throw new Error('User already exists');
}
return await prisma.user.create({
let user = await prisma.user.create({
data: {
name,
email: email.toLowerCase(),
@ -34,4 +38,15 @@ export const createUser = async ({ name, email, password, signature }: CreateUse
identityProvider: IdentityProvider.DOCUMENSO,
},
});
if (isBillingEnabled) {
try {
const stripeSession = await getStripeCustomerByUser(user);
user = stripeSession.user;
} catch (e) {
console.error(e);
}
}
return user;
};

View File

@ -0,0 +1,31 @@
/*
Warnings:
- A unique constraint covering the columns `[planId]` on the table `Subscription` will be added. If there are existing duplicate values, this will fail.
- A unique constraint covering the columns `[customerId]` on the table `User` will be added. If there are existing duplicate values, this will fail.
- Made the column `planId` on table `Subscription` required. This step will fail if there are existing NULL values in that column.
- Made the column `priceId` on table `Subscription` required. This step will fail if there are existing NULL values in that column.
*/
-- Custom migration statement
DELETE FROM "Subscription" WHERE "planId" IS NULL OR "priceId" IS NULL;
-- DropIndex
DROP INDEX "Subscription_customerId_key";
-- DropIndex
DROP INDEX "Subscription_userId_key";
-- AlterTable
ALTER TABLE "Subscription" ALTER COLUMN "planId" SET NOT NULL,
ALTER COLUMN "priceId" SET NOT NULL;
-- AlterTable
ALTER TABLE "User" ADD COLUMN "customerId" TEXT;
ALTER TABLE "Subscription" DROP COLUMN "customerId";
-- CreateIndex
CREATE UNIQUE INDEX "Subscription_planId_key" ON "Subscription"("planId");
-- CreateIndex
CREATE UNIQUE INDEX "User_customerId_key" ON "User"("customerId");

View File

@ -21,6 +21,7 @@ enum Role {
model User {
id Int @id @default(autoincrement())
name String?
customerId String? @unique
email String @unique
emailVerified DateTime?
password String?
@ -34,7 +35,7 @@ model User {
accounts Account[]
sessions Session[]
Document Document[]
Subscription Subscription?
Subscription Subscription[]
PasswordResetToken PasswordResetToken[]
twoFactorSecret String?
twoFactorEnabled Boolean @default(false)
@ -72,18 +73,16 @@ enum SubscriptionStatus {
model Subscription {
id Int @id @default(autoincrement())
status SubscriptionStatus @default(INACTIVE)
planId String?
priceId String?
customerId String
planId String @unique
priceId String
periodEnd DateTime?
userId Int @unique
userId Int
createdAt DateTime @default(now())
updatedAt DateTime @updatedAt
cancelAtPeriodEnd Boolean @default(false)
User User @relation(fields: [userId], references: [id], onDelete: Cascade)
@@unique([customerId])
@@index([userId])
}

View File

@ -10,8 +10,6 @@ declare namespace NodeJS {
NEXT_PRIVATE_ENCRYPTION_KEY: string;
NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_MONTHLY_PRICE_ID: string;
NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_YEARLY_PRICE_ID: string;
NEXT_PUBLIC_STRIPE_FREE_PLAN_ID?: string;
NEXT_PRIVATE_STRIPE_API_KEY: string;
NEXT_PRIVATE_STRIPE_WEBHOOK_SECRET: string;

View File

@ -67,14 +67,10 @@ services:
sync: false
- key: NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_MONTHLY_PRICE_ID
sync: false
- key: NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_YEARLY_PRICE_ID
sync: false
# Features
- key: NEXT_PUBLIC_POSTHOG_KEY
sync: false
- key: NEXT_PUBLIC_POSTHOG_HOST
value: 'https://eu.posthog.com'
- key: NEXT_PUBLIC_FEATURE_BILLING_ENABLED
sync: false

View File

@ -40,11 +40,8 @@
"NEXT_PUBLIC_WEBAPP_URL",
"NEXT_PUBLIC_MARKETING_URL",
"NEXT_PUBLIC_POSTHOG_KEY",
"NEXT_PUBLIC_POSTHOG_HOST",
"NEXT_PUBLIC_FEATURE_BILLING_ENABLED",
"NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_YEARLY_PRICE_ID",
"NEXT_PUBLIC_STRIPE_COMMUNITY_PLAN_MONTHLY_PRICE_ID",
"NEXT_PUBLIC_STRIPE_FREE_PLAN_ID",
"NEXT_PRIVATE_DATABASE_URL",
"NEXT_PRIVATE_DIRECT_DATABASE_URL",
"NEXT_PRIVATE_GOOGLE_CLIENT_ID",