You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

159 lines
3.9 KiB

11 months ago
  1. const OpenAI = require("openai");
  2. const Provider = require("./ai-provider.js");
  3. const { RetryError } = require("../error.js");
  4. /**
  5. * The agent provider for the OpenAI API.
  6. * By default, the model is set to 'gpt-3.5-turbo'.
  7. */
  8. class OpenAIProvider extends Provider {
  9. model;
  10. static COST_PER_TOKEN = {
  11. "gpt-3.5-turbo": {
  12. input: 0.0015,
  13. output: 0.002,
  14. },
  15. "gpt-3.5-turbo-16k": {
  16. input: 0.003,
  17. output: 0.004,
  18. },
  19. "gpt-4": {
  20. input: 0.03,
  21. output: 0.06,
  22. },
  23. "gpt-4-turbo": {
  24. input: 0.01,
  25. output: 0.03,
  26. },
  27. "gpt-4o": {
  28. input: 0.005,
  29. output: 0.015,
  30. },
  31. "gpt-4-32k": {
  32. input: 0.06,
  33. output: 0.12,
  34. },
  35. "gpt-4o-mini": {
  36. input: 0.00015,
  37. output: 0.0006,
  38. },
  39. };
  40. constructor(config = {}) {
  41. const {
  42. options = {
  43. apiKey: process.env.OPEN_AI_KEY,
  44. maxRetries: 3,
  45. },
  46. model = "gpt-4o",
  47. } = config;
  48. const client = new OpenAI(options);
  49. super(client);
  50. this.model = model;
  51. }
  52. /**
  53. * Create a completion based on the received messages.
  54. *
  55. * @param messages A list of messages to send to the OpenAI API.
  56. * @param functions
  57. * @returns The completion.
  58. */
  59. async complete(messages, functions = null) {
  60. try {
  61. const response = await this.client.chat.completions.create({
  62. model: this.model,
  63. // stream: true,
  64. messages,
  65. ...(Array.isArray(functions) && functions?.length > 0
  66. ? { functions }
  67. : {}),
  68. });
  69. // Right now, we only support one completion,
  70. // so we just take the first one in the list
  71. const completion = response.choices[0].message;
  72. const cost = this.getCost(response.usage);
  73. // treat function calls
  74. if (completion.function_call) {
  75. let functionArgs = {};
  76. try {
  77. functionArgs = JSON.parse(completion.function_call.arguments);
  78. } catch (error) {
  79. // call the complete function again in case it gets a json error
  80. return this.complete(
  81. [
  82. ...messages,
  83. {
  84. role: "function",
  85. name: completion.function_call.name,
  86. function_call: completion.function_call,
  87. content: error?.message,
  88. },
  89. ],
  90. functions
  91. );
  92. }
  93. // console.log(completion, { functionArgs })
  94. return {
  95. result: null,
  96. functionCall: {
  97. name: completion.function_call.name,
  98. arguments: functionArgs,
  99. },
  100. cost,
  101. };
  102. }
  103. return {
  104. result: completion.content,
  105. cost,
  106. };
  107. } catch (error) {
  108. // If invalid Auth error we need to abort because no amount of waiting
  109. // will make auth better.
  110. if (error instanceof OpenAI.AuthenticationError) throw error;
  111. if (
  112. error instanceof OpenAI.RateLimitError ||
  113. error instanceof OpenAI.InternalServerError ||
  114. error instanceof OpenAI.APIError // Also will catch AuthenticationError!!!
  115. ) {
  116. throw new RetryError(error.message);
  117. }
  118. throw error;
  119. }
  120. }
  121. /**
  122. * Get the cost of the completion.
  123. *
  124. * @param usage The completion to get the cost for.
  125. * @returns The cost of the completion.
  126. */
  127. getCost(usage) {
  128. if (!usage) {
  129. return Number.NaN;
  130. }
  131. // regex to remove the version number from the model
  132. const modelBase = this.model.replace(/-(\d{4})$/, "");
  133. if (!(modelBase in OpenAIProvider.COST_PER_TOKEN)) {
  134. return Number.NaN;
  135. }
  136. const costPerToken = OpenAIProvider.COST_PER_TOKEN?.[modelBase];
  137. const inputCost = (usage.prompt_tokens / 1000) * costPerToken.input;
  138. const outputCost = (usage.completion_tokens / 1000) * costPerToken.output;
  139. return inputCost + outputCost;
  140. }
  141. }
  142. module.exports = OpenAIProvider;