You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

239 lines
6.1 KiB

11 months ago
  1. const { v4 } = require("uuid");
  2. const prisma = require("../utils/prisma");
  3. const { VALID_CHAT_MODE } = require("../utils/chats/stream");
  /**
   * Data-access helpers for the `embed_configs` table (through Prisma).
   *
   * Every user-supplied field written by `new`/`update` is passed through
   * `validatedCreationData` so only sanitized values reach the database.
   */
  const EmbedConfig = {
    // Keys `update` will accept from a request body; any other key is ignored.
    writable: [
      "enabled",
      "allowlist_domains",
      "allow_model_override",
      "allow_temperature_override",
      "allow_prompt_override",
      "max_chats_per_day",
      "max_chats_per_session",
      "chat_mode",
      "workspace_id",
    ],

    /**
     * Create a new, enabled embed config connected to a workspace.
     *
     * @param {object} data - Request payload. `workspace_id` selects the
     *   workspace to connect; the other fields are sanitized individually.
     * @param {number|string|null} [creatorId=null] - Id of the creating user,
     *   or null/undefined when there is none.
     * @returns {Promise<{embed: object|null, message: string|null}>} The created
     *   row, or `embed: null` plus the error message when creation fails.
     */
    new: async function (data, creatorId = null) {
      try {
        const embed = await prisma.embed_configs.create({
          data: {
            uuid: v4(),
            enabled: true,
            chat_mode: validatedCreationData(data?.chat_mode, "chat_mode"),
            allowlist_domains: validatedCreationData(
              data?.allowlist_domains,
              "allowlist_domains"
            ),
            allow_model_override: validatedCreationData(
              data?.allow_model_override,
              "allow_model_override"
            ),
            allow_temperature_override: validatedCreationData(
              data?.allow_temperature_override,
              "allow_temperature_override"
            ),
            allow_prompt_override: validatedCreationData(
              data?.allow_prompt_override,
              "allow_prompt_override"
            ),
            max_chats_per_day: validatedCreationData(
              data?.max_chats_per_day,
              "max_chats_per_day"
            ),
            max_chats_per_session: validatedCreationData(
              data?.max_chats_per_session,
              "max_chats_per_session"
            ),
            // `Number(null)` is 0 (never nullish), so `Number(x) ?? null` can
            // never yield null; test for a missing creator explicitly.
            createdBy: creatorId == null ? null : Number(creatorId),
            workspace: {
              connect: { id: Number(data?.workspace_id) },
            },
          },
        });
        return { embed, message: null };
      } catch (error) {
        console.error(error.message);
        return { embed: null, message: error.message };
      }
    },

    /**
     * Update an embed config with the writable fields present in `data`.
     *
     * @param {number|string|null} embedId - Id of the embed config to update.
     * @param {object} [data={}] - Fields to change; non-writable keys are ignored.
     * @returns {Promise<{success: boolean, error: string|null}>}
     * @throws {Error} When no `embedId` is provided.
     */
    update: async function (embedId = null, data = {}) {
      if (!embedId) throw new Error("No embed id provided for update");
      const validKeys = Object.keys(data ?? {}).filter((key) =>
        this.writable.includes(key)
      );
      // Report "nothing to update" through the same { success, error } shape
      // as every other return path of this method.
      if (validKeys.length === 0)
        return { success: false, error: "No valid fields to update!" };

      const updates = Object.fromEntries(
        validKeys.map((key) => [key, validatedCreationData(data[key], key)])
      );
      try {
        await prisma.embed_configs.update({
          where: { id: Number(embedId) },
          data: updates,
        });
        return { success: true, error: null };
      } catch (error) {
        console.error(error.message);
        return { success: false, error: error.message };
      }
    },

    /**
     * Fetch the first embed config matching a Prisma `where` clause.
     *
     * @param {object} [clause={}] - Prisma `where` filter.
     * @returns {Promise<object|null>} The row, or null when none matches or the
     *   query fails (the error is logged).
     */
    get: async function (clause = {}) {
      try {
        const embedConfig = await prisma.embed_configs.findFirst({
          where: clause,
        });
        return embedConfig || null;
      } catch (error) {
        console.error(error.message);
        return null;
      }
    },

    /**
     * Same as `get`, but also loads the related `workspace` record.
     *
     * @param {object} [clause={}] - Prisma `where` filter.
     * @returns {Promise<object|null>}
     */
    getWithWorkspace: async function (clause = {}) {
      try {
        const embedConfig = await prisma.embed_configs.findFirst({
          where: clause,
          include: {
            workspace: true,
          },
        });
        return embedConfig || null;
      } catch (error) {
        console.error(error.message);
        return null;
      }
    },

    /**
     * Delete the embed config matching a Prisma unique `where` clause.
     *
     * @param {object} [clause={}] - Prisma `where` filter.
     * @returns {Promise<boolean>} true on success, false if the delete failed.
     */
    delete: async function (clause = {}) {
      try {
        await prisma.embed_configs.delete({
          where: clause,
        });
        return true;
      } catch (error) {
        console.error(error.message);
        return false;
      }
    },

    /**
     * List embed configs matching a clause.
     *
     * @param {object} [clause={}] - Prisma `where` filter.
     * @param {number|null} [limit=null] - Maximum rows (`take`); null for no limit.
     * @param {object|null} [orderBy=null] - Prisma `orderBy`; null for default order.
     * @returns {Promise<object[]>} Matching rows, or [] if the query fails.
     */
    where: async function (clause = {}, limit = null, orderBy = null) {
      try {
        const results = await prisma.embed_configs.findMany({
          where: clause,
          ...(limit !== null ? { take: limit } : {}),
          ...(orderBy !== null ? { orderBy } : {}),
        });
        return results;
      } catch (error) {
        console.error(error.message);
        return [];
      }
    },

    /**
     * Like `where`, but each row also includes its `workspace` and a
     * `_count.embed_chats` total.
     *
     * @param {object} [clause={}] - Prisma `where` filter.
     * @param {number|null} [limit=null] - Maximum rows (`take`); null for no limit.
     * @param {object|null} [orderBy=null] - Prisma `orderBy`; null for default order.
     * @returns {Promise<object[]>} Matching rows, or [] if the query fails.
     */
    whereWithWorkspace: async function (
      clause = {},
      limit = null,
      orderBy = null
    ) {
      try {
        const results = await prisma.embed_configs.findMany({
          where: clause,
          include: {
            workspace: true,
            _count: {
              select: { embed_chats: true },
            },
          },
          ...(limit !== null ? { take: limit } : {}),
          ...(orderBy !== null ? { orderBy } : {}),
        });
        return results;
      } catch (error) {
        console.error(error.message);
        return [];
      }
    },

    /**
     * Decode an embed's stored allowlist.
     *
     * Returns null when no allowlist is stored (the check should be skipped).
     * Returns an empty array when the stored value cannot be decoded into an
     * array, so the caller still performs the check. This keeps a bad value
     * from allowing all requests.
     *
     * @param {{id?: *, allowlist_domains?: string|null}} embed
     * @returns {string[]|null}
     */
    parseAllowedHosts: function (embed) {
      if (!embed.allowlist_domains) return null;
      try {
        const hosts = JSON.parse(embed.allowlist_domains);
        // A non-array JSON value is treated exactly like a failed parse.
        if (Array.isArray(hosts)) return hosts;
      } catch {
        // Fall through to the shared failure handling below.
      }
      console.error(`Failed to parse allowlist_domains for Embed ${embed.id}!`);
      return [];
    },
  };
  // Fields that validatedCreationData only accepts as real booleans (else false).
  const BOOLEAN_KEYS = [
    "allow_model_override", "allow_temperature_override",
    "allow_prompt_override", "enabled",
  ];

  // Fields that validatedCreationData parses as positive numbers (else null).
  const NUMBER_KEYS = ["max_chats_per_day", "max_chats_per_session", "workspace_id"];
  // Matches an explicit http:// or https:// scheme (any case) at the start of an entry.
  const HTTP_SCHEME_PATTERN = /^https?:\/\//i;

  // Turn an allowlist value into a JSON array string of entries `URL` can parse.
  function sanitizeAllowlist(value) {
    if (!value || (Array.isArray(value) && value.length === 0)) return null;

    let entries;
    if (typeof value === "string") entries = value.split(",");
    else if (Array.isArray(value)) entries = value;
    else return null;

    const urls = [];
    for (const entry of entries) {
      if (typeof entry !== "string") continue;
      // Trim so "a.com, b.com" does not turn into the unparseable "https:// b.com".
      const trimmed = entry.trim();
      if (!trimmed) continue;
      const candidate = HTTP_SCHEME_PATTERN.test(trimmed)
        ? trimmed
        : `https://${trimmed}`;
      try {
        new URL(candidate);
        urls.push(candidate);
      } catch {
        // Deliberately drop entries that are not valid URLs instead of
        // rejecting the whole list.
      }
    }
    return JSON.stringify(urls);
  }

  // Accept finite numbers (or numeric strings) greater than zero; anything else is null.
  function sanitizePositiveNumber(value) {
    // Without this guard, booleans would coerce (true -> 1) and be accepted.
    if (typeof value !== "number" && typeof value !== "string") return null;
    const parsed = Number(value);
    return Number.isFinite(parsed) && parsed > 0 ? parsed : null;
  }

  /**
   * Validate one user-supplied embed-config field into its stored format.
   *
   * - `chat_mode`: kept only if listed in VALID_CHAT_MODE, otherwise "query".
   * - `allowlist_domains`: a comma-separated string or an array of domains/URLs.
   *   Entries are trimmed, and get an `https://` scheme when they have no
   *   http(s) scheme. Entries `URL` cannot parse are dropped. Returns a JSON
   *   array string, or null when no value was supplied.
   * - BOOLEAN_KEYS: only `true`/`false` are kept; anything else becomes false.
   * - NUMBER_KEYS: finite numbers or numeric strings above 0; otherwise null.
   * - Any other field: null.
   *
   * @param {*} value - Raw value from the request body.
   * @param {string} field - Column the value is destined for.
   * @returns {*} The sanitized value.
   */
  function validatedCreationData(value, field) {
    if (field === "chat_mode") {
      return value && VALID_CHAT_MODE.includes(value) ? value : "query";
    }
    if (field === "allowlist_domains") return sanitizeAllowlist(value);
    if (BOOLEAN_KEYS.includes(field)) {
      return typeof value === "boolean" ? value : false;
    }
    if (NUMBER_KEYS.includes(field)) return sanitizePositiveNumber(value);
    return null;
  }
  217. module.exports = { EmbedConfig };