#include<bits/stdc++.h>
using namespace std;
// Maximum supported length of each input string. Arrays carry 100 slots of
// slack so 1-based indexing and the sentinel Ns[lens + 1] stay in bounds.
const int N=1e6;
typedef unsigned long long ll;
// Input strings, NUL-terminated, stored 0-based (s[0] is the first character).
char s[N+100],t[N+100];
// Polynomial hash tables, every value reduced modulo HASH_MOD:
//   P[k]  = BASE^k
//   Ps[i] = hash of s[1..i]                      (1-based positions)
//   Ns[i] = hash of reverse(s[i..lens]), i.e. s read from lens back to i
//   Pt[i] = hash of t[1..i]
ll P[N+100],Ps[N+100],Ns[N+100],Pt[N+100];
int lens,lent;
// After init(): sum[i] = f(1) + ... + f(i) with sum[0] = 0, where f(j) is the
// longest k <= min(j, lent) such that t[1..k] equals s[j], s[j-1], ..., s[j-k+1].
long long sum[N+100];

// Hashes are taken modulo the Mersenne prime 2^61 - 1 with a per-run random
// base. Natural unsigned 64-bit overflow is unsafe for ANY odd base: the
// Thue-Morse word of length 2048 and its complement get the same hash
// mod 2^64, which turns into false substring matches.
constexpr ll HASH_MOD=(1ULL<<61)-1;

// (a * b) mod HASH_MOD for a, b < HASH_MOD; folds the high bits back in using
// 2^61 == 1 (mod HASH_MOD).
static ll mulMod(ll a,ll b){
    const unsigned __int128 prod=static_cast<unsigned __int128>(a)*b;
    const ll folded=static_cast<ll>(prod&HASH_MOD)+static_cast<ll>(prod>>61);
    return folded>=HASH_MOD?folded-HASH_MOD:folded;
}
// (a + b) mod HASH_MOD for a, b < HASH_MOD.
static ll addMod(ll a,ll b){
    const ll r=a+b;
    return r>=HASH_MOD?r-HASH_MOD:r;
}
// (a - b) mod HASH_MOD for a, b < HASH_MOD.
static ll subMod(ll a,ll b){return a>=b?a-b:a+HASH_MOD-b;}

// Random base in [256, HASH_MOD - 2]. Two different strings of equal length m
// hash equally for at most m - 1 of the ~2^61 possible bases (roots of a
// nonzero polynomial of degree < m), so a false match is vanishingly unlikely.
static const ll BASE=[]{
    std::mt19937_64 rng(static_cast<ll>(
        std::chrono::steady_clock::now().time_since_epoch().count()));
    return std::uniform_int_distribution<ll>(256,HASH_MOD-2)(rng);
}();

// Hash of s[l..r] (1-based, inclusive), read left to right.
ll getPs(int l,int r){return subMod(Ps[r],mulMod(Ps[l-1],P[r-l+1]));}
// Hash of t[l..r] (1-based, inclusive), read left to right.
ll getPt(int l,int r){return subMod(Pt[r],mulMod(Pt[l-1],P[r-l+1]));}
// Hash of reverse(s[l..r]) (1-based, inclusive); directly comparable with
// getPs/getPt values of the same length.
ll getNs(int l,int r){return subMod(Ns[l],mulMod(Ns[r+1],P[r-l+1]));}

// Reads s and t (two whitespace-separated tokens of at most N characters each)
// from `in` and rebuilds P, Ps, Ns, Pt, lens, lent and sum. `in` defaults to
// stdin. All tables and sentinels are reset, so init may be called repeatedly.
void init(std::FILE* in=stdin){
    // Bounded "%<N>s" conversions: a plain "%s" overruns s/t on oversized input.
    char fmt[32];
    std::snprintf(fmt,sizeof fmt,"%%%ds%%%ds",N,N);
    const int got=std::fscanf(in,fmt,s,t);
    if(got<1)s[0]='\0';  // nothing read: treat as empty
    if(got<2)t[0]='\0';
    lens=static_cast<int>(std::strlen(s));
    lent=static_cast<int>(std::strlen(t));

    const auto code=[](char c){return static_cast<ll>(static_cast<unsigned char>(c));};
    P[0]=1;
    for(int i=1;i<=N;++i)P[i]=mulMod(P[i-1],BASE);
    Ps[0]=0;
    for(int i=1;i<=lens;++i)Ps[i]=addMod(mulMod(Ps[i-1],BASE),code(s[i-1]));
    Ns[lens+1]=0;  // sentinel; a previous call with a longer s may have set it
    for(int i=lens;i>=1;--i)Ns[i]=addMod(mulMod(Ns[i+1],BASE),code(s[i-1]));
    Pt[0]=0;
    for(int i=1;i<=lent;++i)Pt[i]=addMod(mulMod(Pt[i-1],BASE),code(t[i-1]));

    for(int i=1;i<=lens;++i){
        // Longest k <= min(i, lent) with t[1..k] == reverse(s[i-k+1..i]).
        // The property is monotone in k, so binary search applies; k = 1 is
        // never probed by the search and is validated by the first-char test.
        int lo=1,hi=std::min(i,lent);
        while(lo<hi){
            const int mid=(lo+hi+1)/2;
            if(getPt(1,mid)==getNs(i-mid+1,i))lo=mid;
            else hi=mid-1;
        }
        sum[i]=(s[i-1]==t[0])?lo:0;
    }
    sum[0]=0;
    for(int i=1;i<=lens;++i)sum[i]+=sum[i-1];
}
int main(){
init();
long long ans=0;
for(int i=1;i<=lens;++i){
int l=0,r=min(i-1,lens-i);
while(l<r){
int mid=(l+r+1)/2;
if(getPs(i-mid,i)==getNs(i,i+mid))l=mid;
else r=mid-1;
}
if(i-l-1>0)ans+=sum[i-1]-sum[i-l-2];
else ans+=sum[i-1]-sum[i-l-1];
}
for(int i=2;i<=lens;++i){
if(s[i-1]!=s[i-2])continue;
int l=0,r=min(i-2,lens-i);
while(l<r){
int mid=(l+r+1)/2;
if(getPs(i-1-mid,i-1)==getNs(i,i+mid))l=mid;
else r=mid-1;
}
if(i-1-l-2>=0)ans=ans+sum[i-2]-sum[i-1-l-2];
else ans=ans+sum[i-2]-sum[i-1-l-1];
}
cout<<ans<<endl;
}