Invoke Google Vertex AI Agent from Salesforce Lightning Web Component

Invoke Google Vertex AI Agent from Salesforce Lightning Web Component

In this blog post, I use a Salesforce Lightning Web Component to invoke a Google Vertex AI Agent.

Please check the following for the prerequisites.

Sample Code:

Apex Controller:

/**
 * Apex controller that relays chat messages from a Lightning Web Component to a
 * Google Vertex AI (Dialogflow CX) agent via the detectIntent REST endpoint.
 *
 * Authentication uses a Google service account whose JSON key is stored in the
 * Static Resource named 'GoogleJSON'. A signed JWT is exchanged for an OAuth
 * access token; the token and its issue time are handed back to the client so
 * that later calls can reuse it until it is close to expiry.
 */
public class GoogleVertexAgentHandler {
    
    // Epoch milliseconds (as a String) at which the most recent access token was requested.
    public static String LastTokenIssued;
    public static FINAL String strRegionId = '<YOUR_REGION>';
    public static FINAL String strProjectId = '<YOUR_PROJECT_ID>';
    public static FINAL String strAgentId = '<YOUR_AGENT_ID>';
    
    // Tokens are requested with a one hour lifetime; refresh a little earlier than that.
    private static final Integer TOKEN_REFRESH_AFTER_MINUTES = 55;
    private static final String DEFAULT_ERROR_MESSAGE = 'Some error occurred. Please try again later.';
    private static final String TOKEN_ERROR = 'error';
    
    /**
     * Sends one user message to the agent and returns the agent's reply.
     *
     * @param strAccessToken     access token cached by the client; blank on the first call
     * @param strLastTokenIssued epoch milliseconds at which the cached token was issued
     * @param strSessionId       conversation session id; blank starts a new session
     * @param strMessage         text typed by the user
     * @return map that always contains 'Message' and, when they changed during this
     *         call, 'AccessToken', 'LastTokenIssued' and 'SessionId'. An empty string
     *         for any of the latter three tells the client to discard its cached value.
     */
    @AuraEnabled
    public static Map < String, Object > sendRequestToVertexAI( 
        String strAccessToken, 
        String strLastTokenIssued, 
        String strSessionId, 
        String strMessage 
    ) {
        
        Map < String, Object > resultMap = new Map < String, Object >(); 
        resultMap.put( 'Message', DEFAULT_ERROR_MESSAGE );
        
        if ( isTokenRefreshNeeded( strAccessToken, strLastTokenIssued ) ) {
            
            // The token itself is deliberately not written to the debug log.
            strAccessToken = getAccessToken();
            resultMap.put( 'AccessToken', strAccessToken );
            System.debug( 'LastTokenIssued is:' + LastTokenIssued );
            resultMap.put( 'LastTokenIssued', LastTokenIssued );
            
        }
        
        // A blank token (e.g. token response without 'access_token') is as unusable as 'error'.
        if ( String.isBlank( strAccessToken ) || strAccessToken == TOKEN_ERROR ) {
            
            resultMap.put( 'AccessToken', '' );
            resultMap.put( 'LastTokenIssued', '' );
            return resultMap;
            
        }
        
        if ( String.isBlank( strSessionId ) ) {
            
            // Generating UUID for the Session Id
            strSessionId = UUID.randomUUID().toString();
            resultMap.put( 'SessionId', strSessionId );
            
        }
        
        // The session id comes from the client, so encode it before placing it in the URL path.
        String strEndpoint = 'https://' + strRegionId + 
            '-dialogflow.googleapis.com/v3/projects/' +
            strProjectId + '/locations/' + 
            strRegionId + '/agents/' + 
            strAgentId + '/sessions/' + 
            EncodingUtil.urlEncode( strSessionId, 'UTF-8' ) + ':detectIntent';
        
        // Serialize the body instead of concatenating strings so that quotes,
        // backslashes and newlines in the user's message cannot break or alter the JSON.
        Map < String, Object > requestBody = new Map < String, Object > {
            'queryInput' => new Map < String, Object > {
                'language_code' => 'en',
                'text' => new Map < String, Object > {
                    'text' => ( strMessage == null ? '' : strMessage )
                }
            }
        };
        
        HttpRequest req = new HttpRequest();
        req.setEndpoint( strEndpoint );
        req.setMethod( 'POST' );
        req.setHeader( 'Content-Type', 'application/json' );
        req.setHeader( 'X-Goog-User-Project', strProjectId );
        req.setHeader( 'Authorization', 'Bearer ' + strAccessToken );
        req.setBody( JSON.serialize( requestBody ) );
        
        try {
            
            HttpResponse response = new Http().send( req );
            System.debug( 'Response is ' + response.getBody() );
            
            if ( response.getStatusCode() == 200 ) {
                
                applyAgentReply( response.getBody(), resultMap );
                
            } else {
                
                System.debug( 'detectIntent failed with status ' + response.getStatusCode() );
                
                if ( response.getStatusCode() == 401 ) {
                    
                    // The token was rejected; make the client drop it so the next call refreshes.
                    resultMap.put( 'AccessToken', '' );
                    resultMap.put( 'LastTokenIssued', '' );
                    
                }
                
            }
            
        } catch ( Exception e ) {
            
            // Callout or parsing failure: the default error message stays in the result.
            System.debug( 'Exception in sendRequestToVertexAI: ' + e.getMessage() );
            
        }
        
        return resultMap;
        
    }
    
    /**
     * Decides whether a new access token must be fetched: when no token or issue
     * time is cached, when the issue time is not a number, or when the cached
     * token is older than TOKEN_REFRESH_AFTER_MINUTES.
     */
    private static Boolean isTokenRefreshNeeded( 
        String strAccessToken, 
        String strLastTokenIssued 
    ) {
        
        if ( String.isBlank( strAccessToken ) || String.isBlank( strLastTokenIssued ) ) {
            return true;
        }
        
        try {
            
            Long issuedAtMillis = Long.valueOf( strLastTokenIssued.trim() );
            Long ageInMinutes = ( DateTime.now().getTime() - issuedAtMillis ) / 60000;
            return ageInMinutes > TOKEN_REFRESH_AFTER_MINUTES;
            
        } catch ( Exception e ) {
            
            // The value is client supplied; anything unparsable is treated as expired.
            return true;
            
        }
        
    }
    
    /**
     * Reads the detectIntent response and stores the agent's first text reply in
     * resultMap under 'Message'. When the first response message carries no text
     * and the matched event is END_SESSION, a goodbye message is stored and
     * 'SessionId' is cleared instead.
     */
    private static void applyAgentReply( 
        String strResponseBody, 
        Map < String, Object > resultMap 
    ) {
        
        Map < String, Object > responseMap = 
            ( Map < String, Object > ) JSON.deserializeUntyped( strResponseBody );
        Map < String, Object > queryResultMap = 
            ( Map < String, Object > ) responseMap.get( 'queryResult' );
        
        if ( queryResultMap == null || !queryResultMap.containsKey( 'responseMessages' ) ) {
            return;
        }
        
        List < Object > responseMessagesList = 
            ( List < Object > ) queryResultMap.get( 'responseMessages' );
        Map < String, Object > firstMessageMap = null;
        
        // Guard against an empty (or null) list before reading the first element.
        if ( responseMessagesList != null && !responseMessagesList.isEmpty() ) {
            firstMessageMap = ( Map < String, Object > ) responseMessagesList.get( 0 );
        }
        
        if ( firstMessageMap != null && firstMessageMap.containsKey( 'text' ) ) {
            
            Map < String, Object > textMap = 
                ( Map < String, Object > ) firstMessageMap.get( 'text' );
            List < Object > textList = 
                textMap == null ? null : ( List < Object > ) textMap.get( 'text' );
            
            if ( textList != null && !textList.isEmpty() && textList.get( 0 ) != null ) {
                
                System.debug( 'Response is ' + textList.get( 0 ).toString() );
                resultMap.put( 'Message', textList.get( 0 ).toString() );
                
            }
            return;
            
        }
        
        System.debug( 'No Text' ); 
        Map < String, Object > matchMap = 
            ( Map < String, Object > ) queryResultMap.get( 'match' );
        
        if ( matchMap != null && ( String ) matchMap.get( 'event' ) == 'END_SESSION' ) {
            
            resultMap.put( 
                'Message', 
                'Thank you for chatting. Your session is ended'
            );
            resultMap.put( 'SessionId', '' );
            
        }
        
    }
    
    /**
     * Exchanges a signed service-account JWT for a Google OAuth access token.
     * On success LastTokenIssued is set to the time the request was started.
     *
     * @return the access token, or the literal 'error' when it could not be obtained
     */
    public static String getAccessToken() {
        
        String scopes = 'https://www.googleapis.com/auth/cloud-platform';
        
        try {
            
            // Queried inside the try block so a missing resource yields 'error' rather than an exception.
            StaticResource objSR = [ 
                SELECT Body 
                FROM StaticResource 
                WHERE Name = 'GoogleJSON'
            ];
            
            // Parse the JSON service account file
            Map < String, Object > serviceAccount = 
                ( Map < String, Object > ) JSON.deserializeUntyped( objSR.Body.toString() );
            String clientEmail = ( String ) serviceAccount.get( 'client_email' );
            String privateKey = ( String ) serviceAccount.get( 'private_key' );
            
            // Captured before the callout so the cached age never under-states the token's real age.
            Long requestedAtMillis = DateTime.now().getTime();
            
            // Generate the JWT
            String jwt = generateJwt( clientEmail, scopes, privateKey );
            
            if ( String.isBlank( jwt ) ) {
                return TOKEN_ERROR;
            }
            
            // Make the HTTP request to get the access token
            HttpRequest req = new HttpRequest();
            req.setEndpoint( 'https://oauth2.googleapis.com/token' );
            req.setMethod( 'POST' );
            req.setHeader( 'Content-Type', 'application/x-www-form-urlencoded' );
            req.setBody( 
                'grant_type=' + 
                EncodingUtil.urlEncode( 'urn:ietf:params:oauth:grant-type:jwt-bearer', 'UTF-8' ) +
                '&assertion=' +
                EncodingUtil.urlEncode( jwt, 'UTF-8' ) 
            );
            
            HttpResponse res = new Http().send( req );
            
            // Parse the response and return the access token
            if ( res.getStatusCode() == 200 ) {
                
                Map < String, Object > response = 
                    ( Map < String, Object > ) JSON.deserializeUntyped( res.getBody() );
                String strToken = ( String ) response.get( 'access_token' );
                
                if ( String.isBlank( strToken ) ) {
                    
                    System.debug( 'Token response did not contain an access_token' );
                    return TOKEN_ERROR;
                    
                }
                
                LastTokenIssued = String.valueOf( requestedAtMillis );
                return strToken;
                
            }
            
            System.debug( 
                'Error getting access token: ' + 
                res.getStatusCode() + ' ' + res.getBody() 
            );
            return TOKEN_ERROR;
            
        } catch ( Exception e ) {
            
            System.debug( 
                'Exception in getAccessToken: ' + 
                e.getMessage() 
            );
            return TOKEN_ERROR;
            
        }
        
    }
    
    /**
     * Builds and signs (RS256) the JWT used as the OAuth assertion.
     *
     * @param clientEmail service account e-mail, used as the issuer
     * @param scopes      space separated OAuth scopes being requested
     * @param privateKey  PEM encoded PKCS#8 private key of the service account
     * @return the compact JWT, or null when it could not be created
     */
    private static String generateJwt( 
        String clientEmail, 
        String scopes, 
        String privateKey 
    ) {
        
        try {
            
            // Read the clock once so 'iat' and 'exp' are exactly one hour apart.
            Long issuedAtSeconds = System.now().getTime() / 1000;
            
            // Build the JWT payload
            Map < String, Object > payload = new Map < String, Object >();
            payload.put( 'iss', clientEmail );
            payload.put( 'scope', scopes );
            payload.put( 'aud', 'https://oauth2.googleapis.com/token' );
            payload.put( 'iat', issuedAtSeconds );
            // Setting Token to expire in an hour
            payload.put( 'exp', issuedAtSeconds + 3600 ); 
            
            // Encode the JWT header and payload (JWT segments are base64url, not plain Base64)
            String header = base64UrlEncode( 
                Blob.valueOf( 
                    JSON.serialize(
                        new Map < String, String > { 'alg' => 'RS256', 'typ' => 'JWT' } 
                    ) 
                ) 
            );
            String encodedPayload = base64UrlEncode( 
                Blob.valueOf( JSON.serialize( payload ) ) 
            );
            
            // Strip the PEM armour and every whitespace character (\n, \r, spaces)
            // so that only the Base64 key material is decoded.
            String keyMaterial = privateKey
                .replace( '-----BEGIN PRIVATE KEY-----', '' )
                .replace( '-----END PRIVATE KEY-----', '' )
                .replaceAll( '\\s', '' );
            
            // Sign the JWT
            String unsignedToken = header + '.' + encodedPayload;
            Blob signature = Crypto.sign( 
                'RSA-SHA256', 
                Blob.valueOf( unsignedToken ), 
                EncodingUtil.base64Decode( keyMaterial )
            );
            
            // Construct the final JWT
            return unsignedToken + '.' + base64UrlEncode( signature );
            
        } catch ( Exception e ) {
            
            System.debug( 
                'Exception in generateJwt: ' + 
                e.getMessage() 
            );
            return null;
            
        }
        
    }
    
    /**
     * Encodes a Blob as unpadded base64url: '+' becomes '-', '/' becomes '_'
     * and the trailing '=' padding is removed.
     */
    private static String base64UrlEncode( Blob input ) {
        
        return EncodingUtil.base64Encode( input )
            .replace( '+', '-' )
            .replace( '/', '_' )
            .replace( '=', '' );
        
    }
    
}

Lightning Web Component:

HTML:

<!-- Chat panel: scrollable message history on top, message input and Send button below. -->
<template>
    <div class="slds-align_absolute-center">
        <div style="width: 450px;">
            <lightning-card title="Messaging">
                <!-- Displaying spinner when sending and receiving messages -->
                <template lwc:if={showSpinner}>
                    <lightning-spinner 
                        alternative-text="Loading" 
                        size="large">
                    </lightning-spinner>
                </template>
                <div class="slds-p-around_medium">
                    <div class="slds-grid slds-grid_vertical">
                        <!-- Message history; the component scrolls this element to the newest message -->
                        <div
                            class="scrollableArea" 
                            style="height: 300px; overflow-x: hidden; overflow-y: auto;">
                            <template for:each={messages} for:item="message">
                                <div 
                                    key={message.id} 
                                    class={message.senderClass} 
                                    style={message.senderStyle}>
                                    <div class="slds-p-around_small">
                                        <p>{message.role}: {message.text}</p>
                                        <div class="slds-text-body_small">
                                            {message.timestamp}
                                        </div>
                                    </div>
                                </div>
                            </template>
                        </div>
                        <div class="slds-grid slds-gutters">
                            <div class="slds-col slds-size_2-of-3">
                                <!-- label is required even when hidden; it is what assistive technology announces -->
                                <lightning-input type="text"
                                                label="Message"
                                                value={strMessage}
                                                variant="label-hidden"
                                                onchange={handleInputChange}
                                                placeholder="Type a message...">
                                </lightning-input>
                            </div>
                            <div class="slds-col slds-size_1-of-3">
                                <lightning-button label="Send" onclick={sendMessage}>
                                </lightning-button>
                            </div>
                        </div>
                    </div>
                </div>
            </lightning-card>
        </div>
    </div>
</template>

JavaScript:

import { LightningElement } from 'lwc';
import sendRequestToVertexAI from '@salesforce/apex/GoogleVertexAgentHandler.sendRequestToVertexAI';

/**
 * Chat UI for a Google Vertex AI agent.
 *
 * Keeps the conversation history plus the access token, token issue time and
 * session id returned by the Apex controller, and sends them back on every
 * call so the server can reuse the token and continue the same session.
 */
export default class GoogleVertexAgent extends LightningElement {
    
    strMessage;
    strSessionId;
    strAccessToken;
    strLastTokenIssued;
    messages = [];
    messageId = 0;
    showSpinner = false;
    // Set whenever a message is appended; consumed by renderedCallback.
    scrollPending = false;

    connectedCallback() {
        
        this.messages = [ 
            this.buildMessage( 'Agent', 'Hello! How can I help you?' ) 
        ];

    }

    /**
     * Scrolls the history to the bottom once newly appended messages are in the
     * DOM. Doing it here rather than right after updating `messages` matters:
     * at that earlier point the new element has not been rendered, so
     * scrollHeight would not include it yet.
     */
    renderedCallback() {

        if ( !this.scrollPending ) {
            return;
        }

        this.scrollPending = false;
        const scrollArea = this.template.querySelector( '.scrollableArea' );

        if ( scrollArea ) {
            scrollArea.scrollTop = scrollArea.scrollHeight;
        }

    }

    handleInputChange( event ) {
        this.strMessage = event.target.value;
    }

    /**
     * Creates a chat entry for the given role ('You' or 'Agent') with a fresh
     * id and timestamp and the styling used for that side of the conversation.
     */
    buildMessage( role, text ) {

        const isUser = role === 'You';

        return {
            role: role,
            id: this.messageId++,
            text: text,
            timestamp: new Date( Date.now() ).toUTCString(),
            senderClass: isUser 
                ? 'slds-text-align_right slds-text-color_inverse' 
                : 'slds-text-align_left',
            senderStyle: isUser 
                ? 'background:#16325c;' 
                : 'background:white;'
        };

    }

    /** Appends an entry to the history and requests a scroll to the bottom. */
    appendMessage( role, text ) {

        this.messages = [ ...this.messages, this.buildMessage( role, text ) ];
        this.scrollPending = true;

    }

    sendMessage() {

        const outgoingText = this.strMessage;

        // Nothing to send. Return before touching the spinner: switching it on
        // here would leave it showing forever because no request is made.
        if ( !outgoingText || outgoingText.trim().length === 0 ) {
            return;
        }

        this.showSpinner = true;
        this.appendMessage( 'You', outgoingText );

        // Invoking the Apex method to send the message to the Google Agent
        sendRequestToVertexAI( { 
            strAccessToken : this.strAccessToken, 
            strLastTokenIssued : this.strLastTokenIssued, 
            strSessionId : this.strSessionId, 
            strMessage : outgoingText 
        } )    
        .then( result => {  
            
            const response = result || {};
            console.log( 'result is', JSON.stringify( response ) );

            if ( response.Message ) {
                this.appendMessage( 'Agent', response.Message );
            } 

            // Compare against undefined rather than testing truthiness: the
            // server sends an empty string to say "discard the cached value"
            // (session ended, token could not be obtained).
            if ( response.LastTokenIssued !== undefined ) {
                this.strLastTokenIssued = response.LastTokenIssued;
            }
            
            if ( response.AccessToken !== undefined ) {
                this.strAccessToken = response.AccessToken;
            }

            if ( response.SessionId !== undefined ) {
                this.strSessionId = response.SessionId;
            }

            this.showSpinner = false;

        } )  
        .catch( error => {  
            
            this.showSpinner = false;
            console.error( 'Error Occured', JSON.stringify( error ) );
            // Give the user visible feedback instead of failing silently.
            this.appendMessage( 
                'Agent', 
                'Some error occurred. Please try again later.' 
            );

        } );  

        this.strMessage = '';

    }

}

js-meta.xml:

<?xml version="1.0" encoding="UTF-8"?>
<!-- Metadata for the Lightning Web Component bundle. -->
<LightningComponentBundle xmlns="http://soap.sforce.com/2006/04/metadata">
    <!-- Salesforce API version this bundle is declared against -->
    <apiVersion>63.0</apiVersion>
    <!-- Marks the component as exposed so it can be used with the targets listed below -->
    <isExposed>true</isExposed>
    <targets>
        <!-- Only the lightning__Tab target is declared -->
        <target>lightning__Tab</target>
    </targets>
</LightningComponentBundle>

Output:

Leave a Reply