You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

152 lines
4.4 KiB

11 months ago
  1. const Provider = require("./ai-provider.js");
  2. const InheritMultiple = require("./helpers/classes.js");
  3. const UnTooled = require("./helpers/untooled.js");
  4. const { ChatBedrockConverse } = require("@langchain/aws");
  5. const {
  6. HumanMessage,
  7. SystemMessage,
  8. AIMessage,
  9. } = require("@langchain/core/messages");
  10. /**
  11. * The agent provider for the AWS Bedrock provider.
  12. */
  13. class AWSBedrockProvider extends InheritMultiple([Provider, UnTooled]) {
  14. model;
  15. constructor(_config = {}) {
  16. super();
  17. const model = process.env.AWS_BEDROCK_LLM_MODEL_PREFERENCE ?? null;
  18. const client = new ChatBedrockConverse({
  19. region: process.env.AWS_BEDROCK_LLM_REGION,
  20. credentials: {
  21. accessKeyId: process.env.AWS_BEDROCK_LLM_ACCESS_KEY_ID,
  22. secretAccessKey: process.env.AWS_BEDROCK_LLM_ACCESS_KEY,
  23. // If we're using a session token, we need to pass it in as a credential
  24. // otherwise we must omit it so it does not conflict if using IAM auth
  25. ...(this.authMethod === "sessionToken"
  26. ? { sessionToken: process.env.AWS_BEDROCK_LLM_SESSION_TOKEN }
  27. : {}),
  28. },
  29. model,
  30. });
  31. this._client = client;
  32. this.model = model;
  33. this.verbose = true;
  34. }
  35. /**
  36. * Get the authentication method for the AWS Bedrock LLM.
  37. * There are only two valid values for this setting - anything else will default to "iam".
  38. * @returns {"iam"|"sessionToken"}
  39. */
  40. get authMethod() {
  41. const method = process.env.AWS_BEDROCK_LLM_CONNECTION_METHOD || "iam";
  42. if (!["iam", "sessionToken"].includes(method)) return "iam";
  43. return method;
  44. }
  45. get client() {
  46. return this._client;
  47. }
  48. // For streaming we use Langchain's wrapper to handle weird chunks
  49. // or otherwise absorb headaches that can arise from Ollama models
  50. #convertToLangchainPrototypes(chats = []) {
  51. const langchainChats = [];
  52. const roleToMessageMap = {
  53. system: SystemMessage,
  54. user: HumanMessage,
  55. assistant: AIMessage,
  56. };
  57. for (const chat of chats) {
  58. if (!roleToMessageMap.hasOwnProperty(chat.role)) continue;
  59. const MessageClass = roleToMessageMap[chat.role];
  60. langchainChats.push(new MessageClass({ content: chat.content }));
  61. }
  62. return langchainChats;
  63. }
  64. async #handleFunctionCallChat({ messages = [] }) {
  65. const response = await this.client
  66. .invoke(this.#convertToLangchainPrototypes(messages))
  67. .then((res) => res)
  68. .catch((e) => {
  69. console.error(e);
  70. return null;
  71. });
  72. return response?.content;
  73. }
  74. /**
  75. * Create a completion based on the received messages.
  76. *
  77. * @param messages A list of messages to send to the API.
  78. * @param functions
  79. * @returns The completion.
  80. */
  81. async complete(messages, functions = null) {
  82. try {
  83. let completion;
  84. if (functions.length > 0) {
  85. const { toolCall, text } = await this.functionCall(
  86. messages,
  87. functions,
  88. this.#handleFunctionCallChat.bind(this)
  89. );
  90. if (toolCall !== null) {
  91. this.providerLog(`Valid tool call found - running ${toolCall.name}.`);
  92. this.deduplicator.trackRun(toolCall.name, toolCall.arguments);
  93. return {
  94. result: null,
  95. functionCall: {
  96. name: toolCall.name,
  97. arguments: toolCall.arguments,
  98. },
  99. cost: 0,
  100. };
  101. }
  102. completion = { content: text };
  103. }
  104. if (!completion?.content) {
  105. this.providerLog(
  106. "Will assume chat completion without tool call inputs."
  107. );
  108. const response = await this.client.invoke(
  109. this.#convertToLangchainPrototypes(this.cleanMsgs(messages))
  110. );
  111. completion = response;
  112. }
  113. // The UnTooled class inherited Deduplicator is mostly useful to prevent the agent
  114. // from calling the exact same function over and over in a loop within a single chat exchange
  115. // _but_ we should enable it to call previously used tools in a new chat interaction.
  116. this.deduplicator.reset("runs");
  117. return {
  118. result: completion.content,
  119. cost: 0,
  120. };
  121. } catch (error) {
  122. throw error;
  123. }
  124. }
  125. /**
  126. * Get the cost of the completion.
  127. *
  128. * @param _usage The completion to get the cost for.
  129. * @returns The cost of the completion.
  130. * Stubbed since KoboldCPP has no cost basis.
  131. */
  132. getCost(_usage) {
  133. return 0;
  134. }
  135. }
  136. module.exports = AWSBedrockProvider;