Invoke Google Vertex AI Agent from Salesforce Apex

Invoke Google Vertex AI Agent from Salesforce Apex

Prerequisites:

1. Remote Site Settings:

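In Salesforce Setup, add Remote Site Settings for the Dialogflow endpoint (for example, https://us-central1-dialogflow.googleapis.com) and for the Google OAuth token endpoint (https://oauth2.googleapis.com). Make sure these remote site settings are active; otherwise the Apex callouts to these endpoints will fail.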

Here, us-central1 is the region ID (also called the location ID). If your agent uses the global location, use 'global' instead of 'us-central1'.

2. Service Account in Google IAM:

a. Go to Service Accounts in Google IAM & Admin.

b. Create a new Service Account and grant it the following roles when you create it:

i. Dialogflow API Admin

ii. Service Usage Consumer

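c. On the service account's Keys tab, create a new key of type JSON. The JSON key file is downloaded to your computer. Keep it secure, because anyone holding it can act as the service account.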

3. Google Vertex AI Agent:

Create an Agent in Google Vertex AI. Go to https://console.cloud.google.com/gen-app-builder/engines, select the Project and click Create App.

4. Static Resource in Salesforce with Service Account JSON:

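Create a Static Resource named GoogleJSON in Salesforce Setup and upload the Google Service Account's JSON key file to it. The sample code looks up the resource by this exact name. Note that static resources are not a secret store. For production, consider keeping the key in a more protected location.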

The code makes a POST request to the following endpoint, so note your agent's REGION_ID, PROJECT_ID and AGENT_ID. The SESSION_ID is generated in the code for each request.

https://REGION_ID-dialogflow.googleapis.com/v3/projects/PROJECT_ID/locations/REGION_ID/agents/AGENT_ID/sessions/SESSION_ID:detectIntent

Sample Apex Code:

/**
 * Sends a text message to a Google Vertex AI (Dialogflow CX) agent and returns
 * the agent's first text reply.
 *
 * Authentication uses a Google service account:
 * 1. The service account's JSON key is read from the Static Resource named 'GoogleJSON'.
 * 2. A signed JWT is exchanged at Google's OAuth 2.0 token endpoint for an access token.
 * 3. That token is sent as a Bearer token to the detectIntent endpoint.
 *
 * Every failure path is logged with System.debug. The public methods then
 * return a fallback value (a default message, or null) instead of throwing.
 */
public class GoogleVertexAgentHandler {

    public static FINAL String strRegionId = '<YOUR_REGION>';
    public static FINAL String strProjectId = '<YOUR_PROJECT_ID>';
    public static FINAL String strAgentId = '<YOUR_AGENT_ID>';

    // Returned whenever the agent cannot be reached or gives no text reply.
    private static final String DEFAULT_RESPONSE = 'Some error occurred. Please try again later.';
    private static final String TOKEN_ENDPOINT = 'https://oauth2.googleapis.com/token';
    private static final String SCOPES = 'https://www.googleapis.com/auth/cloud-platform';
    // Apex callouts time out after 10 seconds by default; generative agents can take longer.
    private static final Integer CALLOUT_TIMEOUT_MS = 60000;
    // Requested lifetime of the JWT assertion (Google allows at most one hour).
    private static final Long TOKEN_LIFETIME_SECONDS = 3600;

    /**
     * Sends strMessage to the agent in a brand-new session (a random UUID), so
     * every call starts a fresh conversation.
     *
     * @param strMessage the user's text
     * @return the agent's first text reply, or a default error message on any failure
     */
    public static String sendRequestToVertexAI( String strMessage ) {
        return sendRequestToVertexAI( strMessage, UUID.randomUUID().toString() );
    }

    /**
     * Sends strMessage to the agent within the given session. Reusing the same
     * session ID across calls lets the agent keep conversational context.
     *
     * @param strMessage   the user's text
     * @param strSessionId session identifier used in the endpoint path. If it is
     *                     blank, a random UUID is used. It is placed in the URL
     *                     unencoded, so it should be URL-safe; Dialogflow also
     *                     limits its length (36 characters per Google's docs —
     *                     confirm for your agent).
     * @return the agent's first text reply, or a default error message on any failure
     */
    public static String sendRequestToVertexAI( String strMessage, String strSessionId ) {

        String strAccessToken = getAccessToken();
        if ( String.isBlank( strAccessToken ) ) {
            System.debug( 'No access token available; skipping the Dialogflow callout.' );
            return DEFAULT_RESPONSE;
        }

        if ( String.isBlank( strSessionId ) ) {
            strSessionId = UUID.randomUUID().toString();
        }

        HttpRequest req = new HttpRequest();
        req.setEndpoint( buildDetectIntentEndpoint( strSessionId ) );
        req.setMethod( 'POST' );
        req.setTimeout( CALLOUT_TIMEOUT_MS );
        req.setHeader( 'Content-Type', 'application/json' );
        req.setHeader( 'X-Goog-User-Project', strProjectId );
        req.setHeader( 'Authorization', 'Bearer ' + strAccessToken );
        req.setBody( buildDetectIntentBody( strMessage ) );

        try {

            HttpResponse response = new Http().send( req );
            System.debug( 'Response is ' + response.getBody() );

            if ( response.getStatusCode() != 200 ) {
                System.debug(
                    'detectIntent failed: ' + response.getStatusCode() + ' ' + response.getStatus()
                );
                return DEFAULT_RESPONSE;
            }

            String strReply = extractFirstText( response.getBody() );
            return strReply == null ? DEFAULT_RESPONSE : strReply;

        } catch ( CalloutException e ) {
            System.debug( 'Callout to Dialogflow failed: ' + e.getMessage() );
        } catch ( JSONException e ) {
            System.debug( 'Could not parse Dialogflow response: ' + e.getMessage() );
        } catch ( TypeException e ) {
            System.debug( 'Unexpected Dialogflow response shape: ' + e.getMessage() );
        }
        return DEFAULT_RESPONSE;
    }

    // Builds the regional detectIntent URL for the configured project/agent and the given session.
    private static String buildDetectIntentEndpoint( String strSessionId ) {
        return 'https://' + strRegionId + '-dialogflow.googleapis.com/v3/projects/' +
            strProjectId + '/locations/' + strRegionId +
            '/agents/' + strAgentId +
            '/sessions/' + strSessionId + ':detectIntent';
    }

    // Serializes the detectIntent request body. JSON.serialize escapes quotes,
    // backslashes and newlines in the user's text, which plain string
    // concatenation did not.
    private static String buildDetectIntentBody( String strMessage ) {
        return JSON.serialize( new Map < String, Object > {
            'queryInput' => new Map < String, Object > {
                'languageCode' => 'en',
                'text' => new Map < String, Object > { 'text' => strMessage }
            }
        } );
    }

    // Returns the first string in queryResult.responseMessages[*].text.text, or
    // null when no such text is present. Throws JSONException or TypeException
    // on a malformed body; the caller handles both.
    private static String extractFirstText( String strBody ) {

        Map < String, Object > responseMap =
            ( Map < String, Object > ) JSON.deserializeUntyped( strBody );
        Map < String, Object > queryResultMap =
            ( Map < String, Object > ) responseMap.get( 'queryResult' );
        if ( queryResultMap == null ) {
            return null;
        }

        List < Object > responseMessagesList =
            ( List < Object > ) queryResultMap.get( 'responseMessages' );
        if ( responseMessagesList == null ) {
            return null;
        }

        // Messages that are not text (e.g. custom payloads) have no 'text' key; skip them.
        for ( Object objMessage : responseMessagesList ) {
            Map < String, Object > messageMap = ( Map < String, Object > ) objMessage;
            if ( messageMap == null ) {
                continue;
            }
            Map < String, Object > textMap = ( Map < String, Object > ) messageMap.get( 'text' );
            if ( textMap == null ) {
                continue;
            }
            List < Object > textList = ( List < Object > ) textMap.get( 'text' );
            if ( textList != null && !textList.isEmpty() ) {
                String strText = String.valueOf( textList.get( 0 ) );
                System.debug( 'Response is ' + strText );
                return strText;
            }
        }

        System.debug( 'No Text' );
        return null;
    }

    /**
     * Obtains an OAuth 2.0 access token for the service account stored in the
     * 'GoogleJSON' Static Resource, using the JWT-bearer grant.
     *
     * @return the access token, or null if the key is missing or invalid, or the token request fails
     */
    public static String getAccessToken() {

        try {

            // Kept inside the try so a missing resource is reported like any other failure.
            List < StaticResource > resources = [
                SELECT Body
                FROM StaticResource
                WHERE Name = 'GoogleJSON'
                LIMIT 1
            ];
            if ( resources.isEmpty() ) {
                System.debug( 'Static Resource GoogleJSON not found.' );
                return null;
            }

            // Parse the JSON service account file
            Map < String, Object > serviceAccount =
                ( Map < String, Object > ) JSON.deserializeUntyped( resources[ 0 ].Body.toString() );
            String clientEmail = ( String ) serviceAccount.get( 'client_email' );
            String privateKey = ( String ) serviceAccount.get( 'private_key' );
            if ( String.isBlank( clientEmail ) || String.isBlank( privateKey ) ) {
                System.debug( 'Service account JSON lacks client_email or private_key.' );
                return null;
            }

            String jwt = generateJwt( clientEmail, SCOPES, privateKey );
            if ( jwt == null ) {
                return null;
            }

            // Exchange the signed JWT for an access token
            HttpRequest req = new HttpRequest();
            req.setEndpoint( TOKEN_ENDPOINT );
            req.setMethod( 'POST' );
            req.setHeader( 'Content-Type', 'application/x-www-form-urlencoded' );
            req.setBody(
                'grant_type=urn:ietf:params:oauth:grant-type:jwt-bearer&assertion=' +
                EncodingUtil.urlEncode( jwt, 'UTF-8' )
            );

            HttpResponse res = new Http().send( req );

            if ( res.getStatusCode() == 200 ) {
                Map < String, Object > response =
                    ( Map < String, Object > ) JSON.deserializeUntyped( res.getBody() );
                return ( String ) response.get( 'access_token' );
            }

            System.debug( 'Error getting access token: ' + res.getStatusCode() + ' ' + res.getBody() );
            return null;

        } catch ( Exception e ) {
            System.debug( 'Exception in getAccessToken: ' + e.getMessage() );
            return null;
        }
    }

    // Builds and RS256-signs the JWT assertion for Google's token endpoint.
    // Returns null (after logging) if signing fails, e.g. because of a malformed key.
    private static String generateJwt( String clientEmail, String scopes, String privateKey ) {

        try {

            // One clock reading so that exp is exactly iat + lifetime.
            Long issuedAt = System.now().getTime() / 1000;

            Map < String, Object > payload = new Map < String, Object > {
                'iss' => clientEmail,
                'scope' => scopes,
                'aud' => TOKEN_ENDPOINT,
                'iat' => issuedAt,
                'exp' => issuedAt + TOKEN_LIFETIME_SECONDS
            };

            String header = base64UrlEncode( Blob.valueOf(
                JSON.serialize( new Map < String, String > { 'alg' => 'RS256', 'typ' => 'JWT' } )
            ) );
            String encodedPayload = base64UrlEncode( Blob.valueOf( JSON.serialize( payload ) ) );

            String unsignedToken = header + '.' + encodedPayload;
            Blob signature = Crypto.sign(
                'RSA-SHA256',
                Blob.valueOf( unsignedToken ),
                decodePrivateKey( privateKey )
            );

            return unsignedToken + '.' + base64UrlEncode( signature );

        } catch ( Exception e ) {
            System.debug( 'Exception in generateJwt: ' + e.getMessage() );
            return null;
        }
    }

    // Strips the PEM armor and all whitespace (\n, \r, spaces) from a PKCS#8
    // "BEGIN PRIVATE KEY" block and returns the DER bytes Crypto.sign expects.
    private static Blob decodePrivateKey( String privateKey ) {
        String strBase64 = privateKey
            .replace( '-----BEGIN PRIVATE KEY-----', '' )
            .replace( '-----END PRIVATE KEY-----', '' )
            .replaceAll( '\\s', '' );
        return EncodingUtil.base64Decode( strBase64 );
    }

    // JWTs require unpadded base64url (RFC 7515), not standard base64.
    private static String base64UrlEncode( Blob input ) {
        return EncodingUtil.base64Encode( input )
            .replace( '+', '-' )
            .replace( '/', '_' )
            .replace( '=', '' );
    }
}

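Sample Apex code to test quickly (run it as anonymous Apex, for example from the Developer Console's Execute Anonymous window, and check the debug log):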

// Quick manual check: ask the agent one question and print its reply to the debug log.
System.debug (
    GoogleVertexAgentHandler.sendRequestToVertexAI( 'Top Restaurants' )
);

Leave a Reply