Plasticity of Mohr-Coulomb with apex-smoothing

This tutorial demonstrates how modern automatic algorithmic differentiation (AD) techniques may be used to define a complex constitutive model that would otherwise require a lot of differentiation by hand. In particular, we implement the non-associative Mohr-Coulomb plasticity model with apex smoothing and apply it to a slope stability problem for soil. We use the JAX package to define the constitutive relations, including the differentiation of certain terms, and the FEMExternalOperator class to incorporate this model into a weak formulation within UFL.

The tutorial is based on the limit analysis within the semi-definite programming framework, where the plasticity model was replaced by the MFront/TFEL implementation of the Mohr-Coulomb elastoplastic model with apex smoothing.

Problem formulation

We solve a slope stability problem of a soil domain \(\Omega\) represented by a rectangle \([0; L] \times [0; H]\) with homogeneous Dirichlet boundary conditions for the displacement field, \(\boldsymbol{u} = \boldsymbol{0}\), on the right side \(x = L\) and on the bottom side \(z = 0\). The loading consists of a gravitational body force \(\boldsymbol{q}=[0, -\gamma]^T\), with \(\gamma\) being the soil self-weight. The goal is to find the collapse load \(q_\text{lim}\), for which an analytical solution is known for the standard Mohr-Coulomb model without smoothing under the plane strain assumption and an associative plastic flow rule [Chen and Liu, 1990]. Here we follow the same Mandel-Voigt notation as in the von Mises plasticity tutorial.

If \(V\) is a function space of admissible displacement fields, then we can write out a weak formulation of the problem:

Find \(\boldsymbol{u} \in V\) such that

\[ F(\boldsymbol{u}; \boldsymbol{v}) = \int\limits_\Omega \boldsymbol{\sigma}(\boldsymbol{u}) \cdot \boldsymbol{\varepsilon}(\boldsymbol{v}) \, \mathrm{d}\boldsymbol{x} - \int\limits_\Omega \boldsymbol{q} \cdot \boldsymbol{v} \, \mathrm{d}\boldsymbol{x} = \boldsymbol{0}, \quad \forall \boldsymbol{v} \in V, \]

where \(\boldsymbol{\sigma}\) is an external operator representing the stress tensor.
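As a quick sanity check of the Mandel-Voigt convention (the factor √2 on the shear component in epsilon below), the following NumPy sketch, which is not part of the original tutorial, verifies that the plain dot product of two plane-strain tensors stored as [xx, yy, zz, √2·xy] equals their double contraction. The helpers to_mandel and random_plane_tensor are hypothetical.

```python
import numpy as np


def to_mandel(t):
    """Hypothetical helper: map a symmetric 3x3 plane-strain tensor (zero xz and
    yz components) to the Mandel vector [t_xx, t_yy, t_zz, sqrt(2) * t_xy]."""
    return np.array([t[0, 0], t[1, 1], t[2, 2], np.sqrt(2.0) * t[0, 1]])


rng = np.random.default_rng(0)


def random_plane_tensor():
    """Hypothetical helper: random symmetric tensor with zero xz and yz shear."""
    t = np.zeros((3, 3))
    t[0, 0], t[1, 1], t[2, 2], t[0, 1] = rng.standard_normal(4)
    t[1, 0] = t[0, 1]
    return t


sigma_t = random_plane_tensor()
eps_t = random_plane_tensor()

double_contraction = np.tensordot(sigma_t, eps_t, axes=2)  # sigma : eps
mandel_dot = to_mandel(sigma_t) @ to_mandel(eps_t)  # sigma . eps in Mandel form
print(np.isclose(double_contraction, mandel_dot))  # True
```

The √2 factor is what lets the weak form use an ordinary dot product \(\boldsymbol{\sigma} \cdot \boldsymbol{\varepsilon}\) instead of a tensor double contraction.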

Note

Although the tutorial shows the implementation of the Mohr-Coulomb model, it is general enough to be adapted to a wide range of plasticity models that may be defined through a yield surface and a plastic potential.

Implementation

Preamble

from mpi4py import MPI
from petsc4py import PETSc

import jax
import jax.lax
import jax.numpy as jnp
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from mpltools import annotation  # for slope markers
from solvers import PETScNonlinearProblem, PETScNonlinearSolver
from utilities import find_cell_by_point

import basix
import ufl
from dolfinx import default_scalar_type, fem, mesh
from dolfinx_external_operator import (
    FEMExternalOperator,
    evaluate_external_operators,
    evaluate_operands,
    replace_external_operators,
)

jax.config.update("jax_enable_x64", True)

Here we define geometrical and material parameters of the problem as well as some useful constants.

E = 6778  # [MPa] Young modulus
nu = 0.25  # [-] Poisson ratio
c = 3.45  # [MPa] cohesion
phi = 30 * np.pi / 180  # [rad] friction angle
psi = 30 * np.pi / 180  # [rad] dilatancy angle
theta_T = 26 * np.pi / 180  # [rad] transition angle as defined by Abbo and Sloan
a = 0.26 * c / np.tan(phi)  # [MPa] tension cut-off parameter
L, H = (1.2, 1.0)
Nx, Ny = (25, 25)
gamma = 1.0
domain = mesh.create_rectangle(MPI.COMM_WORLD, [np.array([0, 0]), np.array([L, H])], [Nx, Ny])
k_u = 2
gdim = domain.topology.dim
V = fem.functionspace(domain, ("Lagrange", k_u, (gdim,)))


# Boundary conditions
def on_right(x):
    return np.isclose(x[0], L)


def on_bottom(x):
    return np.isclose(x[1], 0.0)


bottom_dofs = fem.locate_dofs_geometrical(V, on_bottom)
right_dofs = fem.locate_dofs_geometrical(V, on_right)

bcs = [
    fem.dirichletbc(np.array([0.0, 0.0], dtype=PETSc.ScalarType), bottom_dofs, V),
    fem.dirichletbc(np.array([0.0, 0.0], dtype=PETSc.ScalarType), right_dofs, V),
]


def epsilon(v):
    grad_v = ufl.grad(v)
    return ufl.as_vector(
        [
            grad_v[0, 0],
            grad_v[1, 1],
            0.0,
            np.sqrt(2.0) * 0.5 * (grad_v[0, 1] + grad_v[1, 0]),
        ]
    )


k_stress = 2 * (k_u - 1)

dx = ufl.Measure(
    "dx",
    domain=domain,
    metadata={"quadrature_degree": k_stress, "quadrature_scheme": "default"},
)

stress_dim = 2 * gdim
S_element = basix.ufl.quadrature_element(domain.topology.cell_name(), degree=k_stress, value_shape=(stress_dim,))
S = fem.functionspace(domain, S_element)


Du = fem.Function(V, name="Du")
u = fem.Function(V, name="Total_displacement")
du = fem.Function(V, name="du")
v = ufl.TestFunction(V)

sigma = FEMExternalOperator(epsilon(Du), function_space=S)
sigma_n = fem.Function(S, name="sigma_n")

Defining the plasticity model and the external operator

The constitutive model of the soil is described by a non-associative plasticity law without hardening, defined by the Mohr-Coulomb yield surface \(f\) and the plastic potential \(g\). Both quantities may be expressed through the following function \(h\):

\[\begin{align*} & h(\boldsymbol{\sigma}, \alpha) = \frac{I_1(\boldsymbol{\sigma})}{3}\sin\alpha + \sqrt{J_2(\boldsymbol{\sigma}) K^2(\alpha) + a^2(\alpha)\sin^2\alpha} - c\cos\alpha, \\ & f(\boldsymbol{\sigma}) = h(\boldsymbol{\sigma}, \phi), \\ & g(\boldsymbol{\sigma}) = h(\boldsymbol{\sigma}, \psi), \end{align*}\]

where \(\phi\) and \(\psi\) are the friction and dilatancy angles, \(c\) is the cohesion, \(I_1(\boldsymbol{\sigma}) = \mathrm{tr} \boldsymbol{\sigma}\) is the first invariant of the stress tensor and \(J_2(\boldsymbol{\sigma}) = \frac{1}{2}\boldsymbol{s} \cdot \boldsymbol{s}\) is the second invariant of the deviatoric part \(\boldsymbol{s}\) of the stress tensor. The coefficient \(K(\alpha)\) also depends on the Lode angle \(\theta(\boldsymbol{\sigma})\) (see the functions theta and K below), and its expression may be found in the MFront/TFEL implementation of this plastic model.
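The invariants above, together with the third invariant J3 that enters the Lode angle, can be checked numerically. The following NumPy sketch assumes the same Mandel vectors and deviatoric projector dev as the code below, and compares the vector formulas for I1, J2 and J3 with their tensor definitions (trace, ½ s:s and det s). The helper from_mandel and the sample stress state are hypothetical; the J3 formula is only valid in plane strain, where the xz and yz shear components vanish.

```python
import numpy as np

# Deviatoric projector and trace vector in Mandel notation, as in the tutorial
dev = np.array(
    [
        [2.0 / 3.0, -1.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, 2.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, -1.0 / 3.0, 2.0 / 3.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)
tr = np.array([1.0, 1.0, 1.0, 0.0])


def J2(s):
    return 0.5 * s @ s


def J3(s):
    # Same formula as the tutorial: det(s) when only the xy shear is non-zero
    return s[2] * (s[0] * s[1] - s[3] * s[3] / 2.0)


def from_mandel(v):
    """Hypothetical helper: rebuild the 3x3 tensor from [v_xx, v_yy, v_zz, sqrt(2) v_xy]."""
    xy = v[3] / np.sqrt(2.0)
    return np.array([[v[0], xy, 0.0], [xy, v[1], 0.0], [0.0, 0.0, v[2]]])


sigma = np.array([-3.0, -1.0, -5.0, np.sqrt(2.0)])  # arbitrary plane-strain stress
s = dev @ sigma  # deviatoric part, here approximately [0, 2, -2, sqrt(2)]
s_t = from_mandel(s)

print(np.isclose(tr @ sigma, np.trace(from_mandel(sigma))))  # I1: True
print(np.isclose(J2(s), 0.5 * np.tensordot(s_t, s_t, axes=2)))  # J2: True
print(np.isclose(J3(s), np.linalg.det(s_t)))  # J3: True
```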

During plastic loading, the stress-strain state of the solid must satisfy the following system of nonlinear equations:

\[\begin{split} \begin{cases} \boldsymbol{r}_{g}(\boldsymbol{\sigma}_{n+1}, \Delta\lambda) = \boldsymbol{\sigma}_{n+1} - \boldsymbol{\sigma}_n - \boldsymbol{C} \cdot (\Delta\boldsymbol{\varepsilon} - \Delta\lambda \frac{\mathrm{d} g}{\mathrm{d}\boldsymbol{\sigma}}(\boldsymbol{\sigma}_{n+1})) = \boldsymbol{0}, \\ r_f(\boldsymbol{\sigma}_{n+1}) = f(\boldsymbol{\sigma}_{n+1}) = 0, \end{cases}\end{split} \tag{5}\]

where \(\Delta\) denotes the increment of a quantity between the current loading step \(n\) and the next loading step \(n + 1\).

By introducing the residual vector \(\boldsymbol{r} = [\boldsymbol{r}_{g}^T, r_f]^T\) and its argument vector \(\boldsymbol{y}_{n+1} = [\boldsymbol{\sigma}_{n+1}^T, \Delta\lambda]^T\), we obtain the following nonlinear constitutive equation:

\[ \boldsymbol{r}(\boldsymbol{y}_{n+1}) = \boldsymbol{0}. \]

To solve this equation, we apply the Newton method and introduce the local Jacobian of the residual vector \(\boldsymbol{j} := \frac{\mathrm{d} \boldsymbol{r}}{\mathrm{d} \boldsymbol{y}}\). Thus, in the plastic phase, we solve the following linear system at each quadrature point, where the superscript \(k\) denotes the Newton iteration:

\[\begin{split} \begin{cases} \boldsymbol{j}(\boldsymbol{y}^{k}_{n+1})\boldsymbol{t} = - \boldsymbol{r}(\boldsymbol{y}^{k}_{n+1}), \\ \boldsymbol{y}^{k+1}_{n+1} = \boldsymbol{y}^{k}_{n+1} + \boldsymbol{t}. \end{cases} \end{split}\]
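The local Newton step can be sketched on a toy two-unknown residual (hypothetical, not the Mohr-Coulomb one). As in the tutorial, the Jacobian is obtained with jax.jacfwd and the linear system is solved with jnp.linalg.solve; the plain Python loop here stands in for the jax.lax.while_loop used later.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def r_toy(y):
    """Hypothetical residual with a root at y = (1, 2)."""
    return jnp.array([y[0] ** 2 + y[1] ** 2 - 5.0, y[0] - y[1] + 1.0])


j_toy = jax.jacfwd(r_toy)  # local Jacobian dr/dy by forward-mode AD

y = jnp.array([2.0, 3.0])  # initial guess
for k in range(20):
    t = jnp.linalg.solve(j_toy(y), -r_toy(y))  # solve j t = -r at the current iterate
    y = y + t  # Newton update
    if jnp.linalg.norm(r_toy(y)) < 1e-12:
        break

print(y)  # close to [1. 2.]
```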

During elastic loading, we consider the trivial system of equations:

\[\begin{split} \begin{cases} \boldsymbol{\sigma}_{n+1} = \boldsymbol{\sigma}_n + \boldsymbol{C} \cdot \Delta\boldsymbol{\varepsilon}, \\ \Delta\lambda = 0. \end{cases} \end{split} \tag{6}\]

The algorithm solving the systems (5) and (6) is called the return-mapping procedure, and its solution defines the return-mapping correction of the stress tensor. Implementing the external operator \(\boldsymbol{\sigma}\) therefore means implementing this algorithmic procedure.
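For intuition, the following hypothetical sketch reduces the return-mapping idea to one dimension, assuming perfect plasticity with the associative yield function f(σ) = |σ| - σ_y (an assumption chosen for illustration, not the Mohr-Coulomb model). The elastic predictor is checked against f; if it is admissible, the elastic system (6) applies, otherwise the stress is returned to the yield surface. In 1D this correction has a closed form, so no Newton loop is needed.

```python
def return_mapping_1d(deps, sigma_n, E=1.0, sigma_y=1.0):
    """Hypothetical 1D return mapping for perfect plasticity, f = |sigma| - sigma_y."""
    sigma_trial = sigma_n + E * deps  # elastic predictor
    f_trial = abs(sigma_trial) - sigma_y
    if f_trial <= 0.0:
        return sigma_trial, 0.0  # elastic step, as in system (6)
    # Plastic step: f(sigma_{n+1}) = 0 gives the multiplier in closed form
    dlambda = f_trial / E
    direction = 1.0 if sigma_trial > 0.0 else -1.0  # dg/dsigma with g = f
    sigma = sigma_trial - E * dlambda * direction  # plastic corrector
    return sigma, dlambda


print(return_mapping_1d(0.5, 0.0))  # (0.5, 0.0): elastic
print(return_mapping_1d(2.0, 0.0))  # (1.0, 1.0): returned to the yield surface
```

In the Mohr-Coulomb case, the plastic corrector has no closed form, which is why the tutorial solves system (5) with a local Newton method.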

The automatic differentiation tools of the JAX library are applied to calculate the three distinct derivatives:

  1. \(\frac{\mathrm{d} g}{\mathrm{d}\boldsymbol{\sigma}}\) - derivative of the plastic potential \(g\),

  2. \(j = \frac{\mathrm{d} \boldsymbol{r}}{\mathrm{d} \boldsymbol{y}}\) - derivative of the local residual \(\boldsymbol{r}\),

  3. \(\boldsymbol{C}_\text{tang} = \frac{\mathrm{d}\boldsymbol{\sigma}}{\mathrm{d}\boldsymbol{\varepsilon}}\) - stress tensor derivative or consistent tangent moduli.
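To illustrate the first of these derivatives on a simpler case, the following sketch (an assumed von Mises-type potential, not the Mohr-Coulomb one) differentiates g(σ) = √J2 with jax.jacfwd and compares the result with the derivative worked out by hand, s / (2√J2). The function g_vm is hypothetical.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Deviatoric projector in Mandel notation, as in the tutorial
dev = jnp.array(
    [
        [2.0 / 3.0, -1.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, 2.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, -1.0 / 3.0, 2.0 / 3.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ]
)


def g_vm(sigma):
    """Hypothetical von Mises-type potential g = sqrt(J2) with J2 = 0.5 s.s."""
    s = dev @ sigma
    return jnp.sqrt(0.5 * jnp.vdot(s, s))


dg_ad = jax.jacfwd(g_vm)  # dg/dsigma by forward-mode AD

sigma = jnp.array([1.0, -2.0, 0.5, 0.3])
s = dev @ sigma
# By hand: dJ2/dsigma = dev^T s = s, because dev is a symmetric projector
dg_by_hand = s / (2.0 * g_vm(sigma))
print(jnp.allclose(dg_ad(sigma), dg_by_hand))  # True
```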

Defining yield surface and plastic potential

First of all, we define supplementary functions that help us express the yield surface \(f\) and the plastic potential \(g\). In the following definitions, we use built-in functions of the JAX package, in particular the conditional primitive jax.lax.cond. Unlike a Python if statement, it can branch on traced values, which is necessary for the AD tool and for just-in-time compilation to work correctly. For more details, please visit the JAX documentation.
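A minimal illustration of why jax.lax.cond is needed, using two hypothetical sign functions similar to the sign helper below: under jax.jit, the Python version fails because the comparison x < 0.0 is a traced value without a concrete boolean, whereas the jax.lax.cond version compiles.

```python
import jax


def sign_python(x):
    # Python branching on the value of x: cannot be traced by jax.jit
    return -1.0 if x < 0.0 else 1.0


def sign_lax(x):
    # Same logic with the traceable conditional primitive jax.lax.cond
    return jax.lax.cond(x < 0.0, lambda x: -1.0, lambda x: 1.0, x)


print(jax.jit(sign_lax)(-2.0))  # -1.0

try:
    jax.jit(sign_python)(-2.0)
    python_if_works = True
except jax.errors.ConcretizationTypeError:
    # Raised because the traced value of x < 0.0 has no concrete boolean value
    python_if_works = False

print(python_if_works)  # False
```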

def J3(s):
    return s[2] * (s[0] * s[1] - s[3] * s[3] / 2.0)


def J2(s):
    return 0.5 * jnp.vdot(s, s)


def theta(s):
    J2_ = J2(s)
    arg = -(3.0 * np.sqrt(3.0) * J3(s)) / (2.0 * jnp.sqrt(J2_ * J2_ * J2_))
    arg = jnp.clip(arg, -1.0, 1.0)
    theta = 1.0 / 3.0 * jnp.arcsin(arg)
    return theta


def sign(x):
    return jax.lax.cond(x < 0.0, lambda x: -1, lambda x: 1, x)


def coeff1(theta, angle):
    return np.cos(theta_T) - (1.0 / np.sqrt(3.0)) * np.sin(angle) * np.sin(theta_T)


def coeff2(theta, angle):
    return sign(theta) * np.sin(theta_T) + (1.0 / np.sqrt(3.0)) * np.sin(angle) * np.cos(theta_T)


coeff3 = 18.0 * np.cos(3.0 * theta_T) * np.cos(3.0 * theta_T) * np.cos(3.0 * theta_T)


def C(theta, angle):
    return (
        -np.cos(3.0 * theta_T) * coeff1(theta, angle) - 3.0 * sign(theta) * np.sin(3.0 * theta_T) * coeff2(theta, angle)
    ) / coeff3


def B(theta, angle):
    return (
        sign(theta) * np.sin(6.0 * theta_T) * coeff1(theta, angle) - 6.0 * np.cos(6.0 * theta_T) * coeff2(theta, angle)
    ) / coeff3


def A(theta, angle):
    return (
        -(1.0 / np.sqrt(3.0)) * np.sin(angle) * sign(theta) * np.sin(theta_T)
        - B(theta, angle) * sign(theta) * np.sin(3 * theta_T)
        - C(theta, angle) * np.sin(3.0 * theta_T) * np.sin(3.0 * theta_T)
        + np.cos(theta_T)
    )


def K(theta, angle):
    def K_false(theta):
        return jnp.cos(theta) - (1.0 / np.sqrt(3.0)) * np.sin(angle) * jnp.sin(theta)

    def K_true(theta):
        return (
            A(theta, angle)
            + B(theta, angle) * jnp.sin(3.0 * theta)
            + C(theta, angle) * jnp.sin(3.0 * theta) * jnp.sin(3.0 * theta)
        )

    return jax.lax.cond(jnp.abs(theta) > theta_T, K_true, K_false, theta)


def a_g(angle):
    return a * np.tan(phi) / np.tan(angle)


dev = np.array(
    [
        [2.0 / 3.0, -1.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, 2.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, -1.0 / 3.0, 2.0 / 3.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ],
    dtype=PETSc.ScalarType,
)
tr = np.array([1.0, 1.0, 1.0, 0.0], dtype=PETSc.ScalarType)


def surface(sigma_local, angle):
    s = dev @ sigma_local
    I1 = tr @ sigma_local
    theta_ = theta(s)
    return (
        (I1 / 3.0 * np.sin(angle))
        + jnp.sqrt(
            J2(s) * K(theta_, angle) * K(theta_, angle) + a_g(angle) * a_g(angle) * np.sin(angle) * np.sin(angle)
        )
        - c * np.cos(angle)
    )

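By choosing the appropriate angle, \(\phi\) or \(\psi\), we define the yield surface \(f\) and the plastic potential \(g\), respectively. The derivative \(\frac{\mathrm{d} g}{\mathrm{d}\boldsymbol{\sigma}}\) is then obtained with jax.jacfwd.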

def f(sigma_local):
    return surface(sigma_local, phi)


def g(sigma_local):
    return surface(sigma_local, psi)


dgdsigma = jax.jacfwd(g)

Solving constitutive equations

In this section, we define the constitutive model by solving the systems (5) and (6). They must be solved at each Gauss point, so we apply the Newton method, implement the whole algorithm locally, and then vectorize the final result using jax.vmap.

In the following cell, we define the local residual \(\boldsymbol{r}\) and its Jacobian drdy.

lmbda = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))
mu = E / (2.0 * (1.0 + nu))
C_elas = np.array(
    [
        [lmbda + 2 * mu, lmbda, lmbda, 0],
        [lmbda, lmbda + 2 * mu, lmbda, 0],
        [lmbda, lmbda, lmbda + 2 * mu, 0],
        [0, 0, 0, 2 * mu],
    ],
    dtype=PETSc.ScalarType,
)
S_elas = np.linalg.inv(C_elas)
ZERO_VECTOR = np.zeros(stress_dim, dtype=PETSc.ScalarType)


def deps_p(sigma_local, dlambda, deps_local, sigma_n_local):
    sigma_elas_local = sigma_n_local + C_elas @ deps_local
    yielding = f(sigma_elas_local)

    def deps_p_elastic(sigma_local, dlambda):
        return ZERO_VECTOR

    def deps_p_plastic(sigma_local, dlambda):
        return dlambda * dgdsigma(sigma_local)

    return jax.lax.cond(yielding <= 0.0, deps_p_elastic, deps_p_plastic, sigma_local, dlambda)


def r_g(sigma_local, dlambda, deps_local, sigma_n_local):
    deps_p_local = deps_p(sigma_local, dlambda, deps_local, sigma_n_local)
    return sigma_local - sigma_n_local - C_elas @ (deps_local - deps_p_local)


def r_f(sigma_local, dlambda, deps_local, sigma_n_local):
    sigma_elas_local = sigma_n_local + C_elas @ deps_local
    yielding = f(sigma_elas_local)

    def r_f_elastic(sigma_local, dlambda):
        return dlambda

    def r_f_plastic(sigma_local, dlambda):
        return f(sigma_local)

    return jax.lax.cond(yielding <= 0.0, r_f_elastic, r_f_plastic, sigma_local, dlambda)


def r(y_local, deps_local, sigma_n_local):
    sigma_local = y_local[:stress_dim]
    dlambda_local = y_local[-1]

    res_g = r_g(sigma_local, dlambda_local, deps_local, sigma_n_local)
    res_f = r_f(sigma_local, dlambda_local, deps_local, sigma_n_local)

    res = jnp.c_["0,1,-1", res_g, res_f]  # concatenates an array and a scalar
    return res


drdy = jax.jacfwd(r)

Then we define the function return_mapping that implements the return-mapping algorithm numerically via the Newton method.

Nitermax, tol = 200, 1e-8

ZERO_SCALAR = np.array([0.0])


def return_mapping(deps_local, sigma_n_local):
    """Performs the return-mapping procedure.

    It solves the elastoplastic constitutive equations numerically by applying
    the Newton method at a single Gauss point. The Newton loop is implemented
    via `jax.lax.while_loop`.

    The function returns `sigma_local` twice so that its value can be reused
    after differentiation: once we apply
    `jax.jacfwd(return_mapping, has_aux=True)`, the resulting function
    returns
    `(C_tang_local, (sigma_local, niter_total, yielding, norm_res, dlambda))`.

    Returns:
        sigma_local: The stress at the current Gauss point.
        niter_total: The total number of iterations.
        yielding: The value of the yield function at the elastic trial stress.
        norm_res: The norm of the residuals.
        dlambda: The value of the plastic multiplier.
    """
    niter = 0

    dlambda = ZERO_SCALAR
    sigma_local = sigma_n_local
    y_local = jnp.concatenate([sigma_local, dlambda])

    res = r(y_local, deps_local, sigma_n_local)
    norm_res0 = jnp.linalg.norm(res)

    def cond_fun(state):
        norm_res, niter, _ = state
        return jnp.logical_and(norm_res / norm_res0 > tol, niter < Nitermax)

    def body_fun(state):
        norm_res, niter, history = state

        y_local, deps_local, sigma_n_local, res = history

        j = drdy(y_local, deps_local, sigma_n_local)
        j_inv_vp = jnp.linalg.solve(j, -res)
        y_local = y_local + j_inv_vp

        res = r(y_local, deps_local, sigma_n_local)
        norm_res = jnp.linalg.norm(res)
        history = y_local, deps_local, sigma_n_local, res

        niter += 1

        return (norm_res, niter, history)

    history = (y_local, deps_local, sigma_n_local, res)

    norm_res, niter_total, history = jax.lax.while_loop(cond_fun, body_fun, (norm_res0, niter, history))

    # The loop state is (norm_res, niter, history); the unknowns y are history[0]
    y_local = history[0]
    sigma_local = y_local[:stress_dim]
    dlambda = y_local[-1]
    sigma_elas_local = sigma_n_local + C_elas @ deps_local  # elastic trial stress
    yielding = f(sigma_elas_local)

    return sigma_local, (sigma_local, niter_total, yielding, norm_res, dlambda)

Consistent tangent stiffness matrix

Automatic differentiation can compute the derivative not only of a mathematical expression but also of a numerical algorithm. For instance, AD can calculate the derivative of the output of the return-mapping procedure, the stress tensor \(\boldsymbol{\sigma}\), with respect to its input, the strain increment \(\Delta\boldsymbol{\varepsilon}\). In the context of the consistent tangent moduli \(\boldsymbol{C}_\text{tang}\), this feature is very useful, as there is no need to write an additional program computing the stress derivative.

JAX’s AD tool permits taking the derivative of the function return_mapping, which is essentially a while loop. With has_aux=True, jax.jacfwd differentiates the first output (the stress) with respect to the first argument (the strain increment deps_local), and the remaining outputs are returned as auxiliary data. Thus, dsigma_ddeps returns both the consistent tangent moduli and the stress tensor, so no separate computation of the stress tensor is needed.
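A minimal sketch of this idea, independent of the tutorial: Newton's method for y² = x written with jax.lax.while_loop, in the same (output, aux) style as return_mapping. Forward-mode differentiation through the loop recovers dy/dx = 1/(2√x) without any derivative code for the solver. The function newton_sqrt is hypothetical. Note that jax.lax.while_loop supports forward-mode but not reverse-mode differentiation, which is one reason to use jax.jacfwd here.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def newton_sqrt(x):
    """Hypothetical example: solve y**2 - x = 0 by Newton's method in a while_loop.
    Returns (y, aux) so that it can be used with has_aux=True."""

    def cond_fun(state):
        y, niter = state
        return jnp.logical_and(jnp.abs(y * y - x) > 1e-12, niter < 50)

    def body_fun(state):
        y, niter = state
        return y - (y * y - x) / (2.0 * y), niter + 1

    y, niter = jax.lax.while_loop(cond_fun, body_fun, (x, 0))
    return y, (y, niter)


# Differentiate the first output w.r.t. the input; the rest is auxiliary data
dy_dx = jax.jacfwd(newton_sqrt, has_aux=True)

deriv, (y, niter) = dy_dx(4.0)
print(y, deriv)  # approximately 2.0 and 0.25 = 1 / (2 * sqrt(4))
```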

dsigma_ddeps = jax.jacfwd(return_mapping, has_aux=True)

Defining external operator

Once we have defined the function dsigma_ddeps, which evaluates both the external operator and its derivative locally, we can simply vectorize it to obtain the final implementation of the external operator derivative.

Note

The function dsigma_ddeps, which contains a while_loop, is designed to be called at a single Gauss point, which is why we need to vectorize it over all points of our function space S. For this purpose, we use the vmap function of JAX. It creates another while_loop, which terminates only when all mapped loops have terminated. Further details can be found in this discussion.
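The following sketch, a hypothetical per-point loop unrelated to plasticity, illustrates this behaviour: each element of the batch needs a different number of iterations, yet after jax.vmap every element still receives its own result, because elements whose loop has finished are effectively frozen while the others keep iterating.

```python
import jax
import jax.numpy as jnp


def count_halvings(x):
    """Hypothetical per-point loop: halve x until it drops below 1."""

    def cond_fun(state):
        x, n = state
        return x >= 1.0

    def body_fun(state):
        x, n = state
        return x / 2.0, n + 1

    return jax.lax.while_loop(cond_fun, body_fun, (x, 0))


batched = jax.jit(jax.vmap(count_halvings))
x_final, n = batched(jnp.array([0.5, 3.0, 100.0]))
print(n)  # [0 2 7]
```

The batched loop runs for 7 iterations in total, the maximum over the batch, which is why a few slowly converging Gauss points can dominate the cost of the vectorized return mapping.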

dsigma_ddeps_vec = jax.jit(jax.vmap(dsigma_ddeps, in_axes=(0, 0)))


def C_tang_impl(deps):
    deps_ = deps.reshape((-1, stress_dim))
    sigma_n_ = sigma_n.x.array.reshape((-1, stress_dim))

    (C_tang_global, state) = dsigma_ddeps_vec(deps_, sigma_n_)
    sigma_global, niter, yielding, norm_res, dlambda = state

    unique_iters, counts = jnp.unique(niter, return_counts=True)

    if MPI.COMM_WORLD.rank == 0:
        print("\tInner Newton summary:")
        print(f"\t\tUnique number of iterations: {unique_iters}")
        print(f"\t\tCounts of unique number of iterations: {counts}")
        print(f"\t\tMaximum f: {jnp.max(yielding)}")
        print(f"\t\tMaximum residual: {jnp.max(norm_res)}")

    return C_tang_global.reshape(-1), sigma_global.reshape(-1)

Similarly to the von Mises example, we do not explicitly implement the evaluation of the external operator. Instead, we obtain its values while evaluating its derivative and then update the values of the operator in the main Newton loop.

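def sigma_external(derivatives):
    # Only the first derivative (the consistent tangent) is implemented; it also
    # returns the stress values, so no separate evaluation is needed.
    if derivatives == (1,):
        return C_tang_impl
    else:
        raise NotImplementedError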


sigma.external_function = sigma_external

Defining the forms

q = fem.Constant(domain, default_scalar_type((0, -gamma)))


def F_ext(v):
    return ufl.dot(q, v) * dx


u_hat = ufl.TrialFunction(V)
F = ufl.inner(epsilon(v), sigma) * dx - F_ext(v)
J = ufl.derivative(F, Du, u_hat)
J_expanded = ufl.algorithms.expand_derivatives(J)

F_replaced, F_external_operators = replace_external_operators(F)
J_replaced, J_external_operators = replace_external_operators(J_expanded)

F_form = fem.form(F_replaced)
J_form = fem.form(J_replaced)

Variable initialization and compilation

Before solving the problem, we have to initialize the values of the consistent tangent stiffness matrix, as they are required to assemble the system. During the first loading step, we expect a purely elastic response. Setting Du to a constant field yields a (numerically) zero strain increment at every Gauss point, so the constitutive equations return an elastic response everywhere and the consistent tangent moduli are initialized with the elastic ones. This first call also triggers the just-in-time compilation of dsigma_ddeps_vec.

Du.x.array[:] = 1.0
sigma_n.x.array[:] = 0.0

evaluated_operands = evaluate_operands(F_external_operators)
_ = evaluate_external_operators(J_external_operators, evaluated_operands)
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -2.2109628558533108
		Maximum residual: 0.0

Solving the problem

Similarly to the von Mises tutorial, we use a Newton solver, but this time we rely on SNES, the nonlinear solver from the PETSc library. We implemented the class PETScNonlinearProblem, which calls an additional routine, external_callback, at each SNES iteration before the vector and matrix assembly.

def constitutive_update():
    evaluated_operands = evaluate_operands(F_external_operators)
    ((_, sigma_new),) = evaluate_external_operators(J_external_operators, evaluated_operands)
    # Direct access to the external operator values
    sigma.ref_coefficient.x.array[:] = sigma_new


problem = PETScNonlinearProblem(Du, F_replaced, J_replaced, bcs=bcs, external_callback=constitutive_update)

petsc_options = {
    "snes_type": "vinewtonrsls",
    "snes_linesearch_type": "basic",
    "ksp_type": "preonly",
    "pc_type": "lu",
    "pc_factor_mat_solver_type": "mumps",
    "snes_atol": 1.0e-8,
    "snes_rtol": 1.0e-8,
    "snes_max_it": 100,
    "snes_monitor": "",
}


solver = PETScNonlinearSolver(domain.comm, problem, petsc_options=petsc_options)  # PETSc.SNES wrapper

Having defined the nonlinear problem and the Newton solver, we are ready to compute the solution.

load_steps_1 = np.linspace(2, 22.9, 50)
load_steps_2 = np.array([22.96, 22.99])
load_steps = np.concatenate([load_steps_1, load_steps_2])
num_increments = len(load_steps)
results = np.zeros((num_increments + 1, 2))

x_point = np.array([[0, H, 0]])
cells, points_on_process = find_cell_by_point(domain, x_point)

for i, load in enumerate(load_steps):
    q.value = load * np.array([0, -gamma])

    if MPI.COMM_WORLD.rank == 0:
        print(f"Load increment #{i}, load: {load}")

    solver.solve(Du)

    u.x.petsc_vec.axpy(1.0, Du.x.petsc_vec)
    u.x.scatter_forward()

    sigma_n.x.array[:] = sigma.ref_coefficient.x.array

    if len(points_on_process) > 0:
        results[i + 1, :] = (-u.eval(points_on_process, cells)[0], load)

print(f"Slope stability factor: {-q.value[-1] * H / c}")
Load increment #0, load: 2.0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -2.2109628558533108
		Maximum residual: 0.0
  0 SNES Function norm 1.195147420875e+05
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -0.754479661417915
		Maximum residual: 0.0
  1 SNES Function norm 4.904068111920e-10
Load increment #1, load: 2.426530612244898
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 1.1914324474182885
		Maximum residual: 5.037138905646796e-10
  0 SNES Function norm 4.756853160422e-02
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -0.8092635980827132
		Maximum residual: 2.355154659777929e-16
  1 SNES Function norm 3.033810662369e-02
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -0.35079912068642827
		Maximum residual: 2.3592239273284576e-16
  2 SNES Function norm 2.545132354954e-15
Load increment #2, load: 2.853061224489796
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3749    1]
		Maximum f: 0.06115928116524216
		Maximum residual: 1.3031087567683006e-15
  0 SNES Function norm 9.132432640845e-04
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3749    1]
		Maximum f: 0.09230862184731814
		Maximum residual: 8.894860566000958e-14
  1 SNES Function norm 8.054297427597e-06
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3749    1]
		Maximum f: 0.09244639358766138
		Maximum residual: 9.002690704602798e-14
  2 SNES Function norm 3.456153779038e-10
Load increment #3, load: 3.279591836734694
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3749    1]
		Maximum f: 0.4484421652174473
		Maximum residual: 1.1048989468455901e-13
  0 SNES Function norm 5.392112070247e-03
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 0.6457397002297021
		Maximum residual: 5.6451654665433484e-09
  1 SNES Function norm 9.222044508532e-04
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 0.6671698219110791
		Maximum residual: 5.567801930278813e-09
  2 SNES Function norm 2.306874769950e-06
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 0.6671780558585327
		Maximum residual: 5.564613164610517e-09
  3 SNES Function norm 8.267706072392e-12
Load increment #4, load: 3.706122448979592
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 0.6729159475629167
		Maximum residual: 1.961914019852753e-09
  0 SNES Function norm 7.708642137450e-03
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 0.8182731983589551
		Maximum residual: 3.0518095327792377e-09
  1 SNES Function norm 1.306670391138e-04
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 0.8184517040414163
		Maximum residual: 3.2657514304459696e-09
  2 SNES Function norm 1.677202587141e-08
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 0.8184520904141563
		Maximum residual: 3.265797852507986e-09
  3 SNES Function norm 2.705484852041e-15
Load increment #5, load: 4.13265306122449
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3747    3]
		Maximum f: 0.8189820439289797
		Maximum residual: 1.7557436048981662e-09
  0 SNES Function norm 6.653604188183e-03
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3747    3]
		Maximum f: 0.9125840840028117
		Maximum residual: 2.8909160594637384e-09
  1 SNES Function norm 2.564375080375e-05
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3747    3]
		Maximum f: 0.9124710788798271
		Maximum residual: 2.9132798401008463e-09
  2 SNES Function norm 1.145551698424e-09
Load increment #6, load: 4.559183673469388
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3744    5    1]
		Maximum f: 0.9120013616400899
		Maximum residual: 2.094614500284725e-09
  0 SNES Function norm 4.373771847428e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3744    5    1]
		Maximum f: 1.0834716869126564
		Maximum residual: 7.0973731893345355e-09
  1 SNES Function norm 9.084544888110e-05
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3744    5    1]
		Maximum f: 1.085363428884213
		Maximum residual: 7.0939696589608045e-09
  2 SNES Function norm 1.607752517055e-08
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3744    5    1]
		Maximum f: 1.085363929078086
		Maximum residual: 7.0939525461859185e-09
  3 SNES Function norm 2.893023612745e-15
Load increment #7, load: 4.985714285714286
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3743    6    1]
		Maximum f: 1.0817281941295938
		Maximum residual: 6.913824480896462e-09
  0 SNES Function norm 7.149688794666e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3743    5    2]
		Maximum f: 1.3124494265785436
		Maximum residual: 3.8054853560241645e-09
  1 SNES Function norm 2.058508967714e-04
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3743    5    2]
		Maximum f: 1.3140149963446661
		Maximum residual: 3.9450923105992785e-09
  2 SNES Function norm 7.527865371142e-08
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3743    5    2]
		Maximum f: 1.3140158007718141
		Maximum residual: 3.945117943084998e-09
  3 SNES Function norm 1.321747796382e-14
Load increment #8, load: 5.412244897959184
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3740    8    1    1]
		Maximum f: 1.3088422717205206
		Maximum residual: 7.666114760291095e-09
  0 SNES Function norm 7.696175706972e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3739    8    2    1]
		Maximum f: 1.5216973741715996
		Maximum residual: 1.4413462093993158e-08
  1 SNES Function norm 5.990449202790e-04
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3739    7    3    1]
		Maximum f: 1.52923347206826
		Maximum residual: 1.4268738732832307e-08
  2 SNES Function norm 4.936076057041e-07
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3739    7    3    1]
		Maximum f: 1.5292375934479288
		Maximum residual: 1.4268663859345346e-08
  3 SNES Function norm 4.633378480028e-13
Load increment #9, load: 5.838775510204082
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3736   13    1]
		Maximum f: 1.524300040267685
		Maximum residual: 7.688161900452914e-09
  0 SNES Function norm 9.589791815981e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3735   11    4]
		Maximum f: 1.7689821236208938
		Maximum residual: 9.505123556988733e-09
  1 SNES Function norm 3.862729558146e-04
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3735   11    4]
		Maximum f: 1.7742284095842664
		Maximum residual: 1.0807696682576373e-08
  2 SNES Function norm 4.301313419599e-07
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3735   11    4]
		Maximum f: 1.774233748804027
		Maximum residual: 1.080953903115202e-08
  3 SNES Function norm 5.347991298968e-13
Load increment #10, load: 6.265306122448979
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3731   16    3]
		Maximum f: 1.7693631620667802
		Maximum residual: 1.0598585889381775e-08
  0 SNES Function norm 1.143203715428e-02
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3730   15    5]
		Maximum f: 1.931443346652911
		Maximum residual: 1.482621385021867e-08
  1 SNES Function norm 1.584942427511e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3730   14    6]
		Maximum f: 1.9459812815294124
		Maximum residual: 1.485312251369592e-08
  2 SNES Function norm 3.025107792013e-06
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3730   14    6]
		Maximum f: 1.946002601666399
		Maximum residual: 1.486022416821949e-08
  3 SNES Function norm 1.580316621700e-11
Load increment #11, load: 6.691836734693877
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3724   22    2    2]
		Maximum f: 1.942723296295474
		Maximum residual: 1.1070828265789418e-08
  0 SNES Function norm 1.107822857802e-02
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3723   18    6    3]
		Maximum f: 2.1223921642834136
		Maximum residual: 1.7660578820083674e-08
  1 SNES Function norm 5.672791184133e-04
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3723   18    6    3]
		Maximum f: 2.1279911785405763
		Maximum residual: 1.8046844781872237e-08
  2 SNES Function norm 8.341174881919e-07
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3723   18    6    3]
		Maximum f: 2.1279993282057226
		Maximum residual: 1.8047264082681924e-08
  3 SNES Function norm 1.851653688183e-12
Load increment #12, load: 7.118367346938775
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3719   27    4]
		Maximum f: 2.125608307652833
		Maximum residual: 1.2834185599199991e-08
  0 SNES Function norm 1.211225124343e-02
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3716   24    9    1]
		Maximum f: 2.328908273477739
		Maximum residual: 2.8296922960994398e-08
  1 SNES Function norm 1.578810096150e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3715   24    9    2]
		Maximum f: 2.3415246200100595
		Maximum residual: 1.8299013720853134e-08
  2 SNES Function norm 1.892066413845e-04
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3715   24    9    2]
		Maximum f: 2.342448287387987
		Maximum residual: 1.8318806537801995e-08
  3 SNES Function norm 3.305365337132e-08
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3715   24    9    2]
		Maximum f: 2.342448401872072
		Maximum residual: 1.831880985770971e-08
  4 SNES Function norm 3.440614709517e-15
Load increment #13, load: 7.544897959183673
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3711   34    5]
		Maximum f: 2.340162751888442
		Maximum residual: 2.591648292442109e-08
  0 SNES Function norm 1.196820342248e-02
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3710   28   12]
		Maximum f: 2.5259332779377996
		Maximum residual: 2.286849002086255e-08
  1 SNES Function norm 2.743993581195e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3710   27   13]
		Maximum f: 2.542296026957439
		Maximum residual: 2.1308369948440555e-08
  2 SNES Function norm 7.305829738870e-06
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3710   27   13]
		Maximum f: 2.5423337762226796
		Maximum residual: 2.1310229835455936e-08
  3 SNES Function norm 7.631495540733e-11
Load increment #14, load: 7.971428571428571
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3705   42    2    1]
		Maximum f: 2.540201201791897
		Maximum residual: 1.4293190722188166e-08
  0 SNES Function norm 1.079588899770e-02
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3704   34   11    1]
		Maximum f: 2.728556224371425
		Maximum residual: 2.2875959146396816e-08
  1 SNES Function norm 1.040283872579e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3704   34   11    1]
		Maximum f: 2.736308284868541
		Maximum residual: 2.3842123839694838e-08
  2 SNES Function norm 1.400314054692e-06
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3704   34   11    1]
		Maximum f: 2.736317418830786
		Maximum residual: 2.384323208026476e-08
  3 SNES Function norm 2.772559908433e-12
Load increment #15, load: 8.397959183673468
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3699   49    2]
		Maximum f: 2.734686952289237
		Maximum residual: 7.159754550722705e-09
  0 SNES Function norm 1.091345771982e-02
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3698   41   11]
		Maximum f: 2.8973654944029046
		Maximum residual: 1.994779810376391e-08
  1 SNES Function norm 7.408397972280e-04
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3698   41   11]
		Maximum f: 2.902394333239562
		Maximum residual: 2.063408185711428e-08
  2 SNES Function norm 7.993577406574e-07
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3698   41   11]
		Maximum f: 2.9023986347998174
		Maximum residual: 2.0634617112803917e-08
  3 SNES Function norm 8.010628299507e-13
Load increment #16, load: 8.824489795918367
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3693    1   54    2]
		Maximum f: 2.9013013029667207
		Maximum residual: 7.082318791308725e-09
  0 SNES Function norm 1.131724631643e-02
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3692   53    5]
		Maximum f: 3.056207361207066
		Maximum residual: 1.9590593185362122e-08
  1 SNES Function norm 1.008710940736e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3692   51    7]
		Maximum f: 3.061188684009411
		Maximum residual: 2.028090520073796e-08
  2 SNES Function norm 1.574785763060e-06
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3692   51    7]
		Maximum f: 3.0611926039834336
		Maximum residual: 2.0281446806363532e-08
  3 SNES Function norm 3.672588087470e-12
Load increment #17, load: 9.251020408163264
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3688   61    1]
		Maximum f: 3.0606359799274014
		Maximum residual: 4.539286335393457e-09
  0 SNES Function norm 1.039626377063e-02
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3687   57    6]
		Maximum f: 3.1932576381579962
		Maximum residual: 1.3834417265902114e-08
  1 SNES Function norm 5.727495097146e-04
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3687   57    6]
		Maximum f: 3.1964061857745567
		Maximum residual: 1.4192616841316228e-08
  2 SNES Function norm 4.430536444255e-07
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3687   57    6]
		Maximum f: 3.1964077098125476
		Maximum residual: 1.419278853455254e-08
  3 SNES Function norm 2.390403421211e-13
Load increment #18, load: 9.677551020408163
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3682   67    1]
		Maximum f: 3.196279234543098
		Maximum residual: 3.1280966022035733e-09
  0 SNES Function norm 1.077705186543e-02
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3682   63    5]
		Maximum f: 3.327384679063542
		Maximum residual: 3.0069783902278564e-08
  1 SNES Function norm 2.242198150756e-04
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3682   63    5]
		Maximum f: 3.3291427526630772
		Maximum residual: 3.1822788885684694e-08
  2 SNES Function norm 8.066505840336e-08
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3682   63    5]
		Maximum f: 3.3291431795177835
		Maximum residual: 3.1823235983648236e-08
  3 SNES Function norm 1.182557473735e-14
Load increment #19, load: 10.10408163265306
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3677    3   68    1    1]
		Maximum f: 3.329381053831558
		Maximum residual: 2.4015415015508287e-08
  0 SNES Function norm 1.009301544127e-02
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3676   69    4    1]
		Maximum f: 3.446488877319539
		Maximum residual: 2.1903745747420033e-08
  1 SNES Function norm 2.451203768064e-04
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3676   69    4    1]
		Maximum f: 3.4472552731640014
		Maximum residual: 1.875396568078783e-08
  2 SNES Function norm 6.772470633284e-08
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3676   69    4    1]
		Maximum f: 3.4472558177841717
		Maximum residual: 1.8754306955626633e-08
  3 SNES Function norm 1.168343569053e-14
Load increment #20, load: 10.530612244897958
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3673   75    2]
		Maximum f: 3.447745774192725
		Maximum residual: 1.230092644610275e-08
  0 SNES Function norm 8.769338038872e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3672   74    4]
		Maximum f: 3.553612234013514
		Maximum residual: 2.073176861393257e-08
  1 SNES Function norm 1.398871589577e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3672   74    4]
		Maximum f: 3.5575500941879707
		Maximum residual: 2.0880012214871582e-08
  2 SNES Function norm 1.379070835902e-06
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3672   74    4]
		Maximum f: 3.557553374057776
		Maximum residual: 2.0880046002651022e-08
  3 SNES Function norm 2.265930547242e-12
Load increment #21, load: 10.957142857142856
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3669   79    2]
		Maximum f: 3.558222423855676
		Maximum residual: 1.3134273313277192e-09
  0 SNES Function norm 8.691186079801e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3669    2   75    4]
		Maximum f: 3.65759394259212
		Maximum residual: 3.595458105323198e-08
  1 SNES Function norm 3.411985921927e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3669    1   76    4]
		Maximum f: 3.660080990277342
		Maximum residual: 2.902156845935725e-09
  2 SNES Function norm 2.666288752862e-06
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3669    1   76    4]
		Maximum f: 3.660081030314147
		Maximum residual: 2.9021542407501645e-09
  3 SNES Function norm 1.419466613335e-11
Load increment #22, load: 11.383673469387753
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3664    3   82    1]
		Maximum f: 3.6608897267290836
		Maximum residual: 2.63979341044774e-08
  0 SNES Function norm 8.784785795509e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3662    2   84    2]
		Maximum f: 3.751648638769242
		Maximum residual: 4.677943397443282e-08
  1 SNES Function norm 2.334921659320e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3662    1   85    2]
		Maximum f: 3.7550666405164494
		Maximum residual: 4.724383372331403e-08
  2 SNES Function norm 9.229833616516e-06
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3662    1   85    2]
		Maximum f: 3.7550698944933427
		Maximum residual: 4.724239354767479e-08
  3 SNES Function norm 1.329653469918e-10
Load increment #23, load: 11.810204081632651
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3660    4   84    2]
		Maximum f: 3.7559609444110147
		Maximum residual: 3.3906255791760624e-08
  0 SNES Function norm 7.703871014509e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3659    3   85    3]
		Maximum f: 3.838298203047153
		Maximum residual: 2.1412612385054286e-08
  1 SNES Function norm 7.732693973152e-04
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3659    2   86    3]
		Maximum f: 3.839826825119904
		Maximum residual: 2.1816110741095973e-08
  2 SNES Function norm 4.461456405469e-06
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3659    2   86    3]
		Maximum f: 3.8398290986266983
		Maximum residual: 2.1815665852642728e-08
  3 SNES Function norm 6.876437668555e-11
Load increment #24, load: 12.23673469387755
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3653    3   91    3]
		Maximum f: 3.840756821401302
		Maximum residual: 4.2057862496062974e-08
  0 SNES Function norm 1.146900093023e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3652    1   94    3]
		Maximum f: 3.9276871850440958
		Maximum residual: 5.796849704214347e-08
  1 SNES Function norm 2.495998296202e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3653   94    3]
		Maximum f: 3.9311993982793285
		Maximum residual: 5.724728667093015e-08
  2 SNES Function norm 1.623104017187e-05
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [3653   94    3]
		Maximum f: 3.931205210614484
		Maximum residual: 5.724503173211357e-08
  3 SNES Function norm 3.768453794922e-10
Load increment #25, load: 12.663265306122447
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3644    4   99    2    1]
		Maximum f: 3.932143966018882
		Maximum residual: 5.0623511299290145e-09
  0 SNES Function norm 8.353184330108e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3646  101    2    1]
		Maximum f: 4.327729961142113
		Maximum residual: 3.233389059775927e-08
  1 SNES Function norm 8.334150687845e-03
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3646  101    2    1]
		Maximum f: 4.311288264135284
		Maximum residual: 3.2397402145143064e-08
  2 SNES Function norm 4.637932353007e-06
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [3646  101    2    1]
		Maximum f: 4.311303131281377
		Maximum residual: 3.239702352192351e-08
  3 SNES Function norm 2.609546572361e-11
Load increment #26, load: 13.089795918367345
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3639    1  108    2]
		Maximum f: 4.3131762380165455
		Maximum residual: 2.9576102833246317e-08
  0 SNES Function norm 1.194001516472e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3637    2  108    3]
		Maximum f: 4.867253928816288
		Maximum residual: 1.768669204481932e-08
  1 SNES Function norm 6.098877544132e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3638    2  107    3]
		Maximum f: 4.959680353880241
		Maximum residual: 1.66216449532847e-08
  2 SNES Function norm 1.108269377350e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3638    2  107    3]
		Maximum f: 4.956079107449407
		Maximum residual: 1.6609216120894577e-08
  3 SNES Function norm 3.073549374397e-07
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3638    2  107    3]
		Maximum f: 4.956077715188027
		Maximum residual: 1.6609212730558424e-08
  4 SNES Function norm 1.137146500737e-13
Load increment #27, load: 13.516326530612243
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3628    5  116    1]
		Maximum f: 4.957344021789547
		Maximum residual: 1.811707927971261e-08
  0 SNES Function norm 1.240366718545e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3628    3  116    3]
		Maximum f: 5.683701149470998
		Maximum residual: 1.295827894694929e-08
  1 SNES Function norm 2.161926288287e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3627    4  115    4]
		Maximum f: 5.720335933291519
		Maximum residual: 1.2955191558802049e-08
  2 SNES Function norm 1.004507493487e-04
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3627    4  115    4]
		Maximum f: 5.722494932268212
		Maximum residual: 1.2857682746470161e-08
  3 SNES Function norm 3.137931183155e-08
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3627    4  115    4]
		Maximum f: 5.7224950713856
		Maximum residual: 1.2857665297063343e-08
  4 SNES Function norm 4.702458609396e-15
Load increment #28, load: 13.942857142857141
	Inner Newton summary:
		Unique number of iterations: [1 2 3]
		Counts of unique number of iterations: [3620    3  127]
		Maximum f: 5.723571310910765
		Maximum residual: 8.619153716275755e-09
  0 SNES Function norm 1.195591532823e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3616    2  129    3]
		Maximum f: 6.645080068812803
		Maximum residual: 1.7387329320585557e-08
  1 SNES Function norm 4.874621421367e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3616    1  130    3]
		Maximum f: 6.81177212166552
		Maximum residual: 1.8920521346362546e-08
  2 SNES Function norm 6.244285417544e-05
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3616    1  130    3]
		Maximum f: 6.812691928883735
		Maximum residual: 1.8927764677838143e-08
  3 SNES Function norm 8.457653354988e-09
Load increment #29, load: 14.36938775510204
	Inner Newton summary:
		Unique number of iterations: [1 2 3]
		Counts of unique number of iterations: [3608    4  138]
		Maximum f: 6.813577582394805
		Maximum residual: 2.63905120203199e-08
  0 SNES Function norm 1.319402690197e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3]
		Counts of unique number of iterations: [3607    2  141]
		Maximum f: 7.596225865539354
		Maximum residual: 2.7316580560848143e-08
  1 SNES Function norm 6.000358681716e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3]
		Counts of unique number of iterations: [3607    2  141]
		Maximum f: 7.662614238370983
		Maximum residual: 3.1755276033684446e-08
  2 SNES Function norm 2.231686172689e-05
	Inner Newton summary:
		Unique number of iterations: [1 2 3]
		Counts of unique number of iterations: [3607    2  141]
		Maximum f: 7.662669562454688
		Maximum residual: 3.176597798977612e-08
  3 SNES Function norm 8.253916069700e-10
Load increment #30, load: 14.795918367346937
	Inner Newton summary:
		Unique number of iterations: [1 2 3]
		Counts of unique number of iterations: [3595    6  149]
		Maximum f: 7.663139576851206
		Maximum residual: 2.2367382946246526e-08
  0 SNES Function norm 1.503558807032e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3]
		Counts of unique number of iterations: [3594    3  153]
		Maximum f: 8.454917434393025
		Maximum residual: 2.0862183272009163e-08
  1 SNES Function norm 6.182460517857e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3]
		Counts of unique number of iterations: [3594    3  153]
		Maximum f: 8.561832222792974
		Maximum residual: 2.6117714665962503e-08
  2 SNES Function norm 5.709638472829e-05
	Inner Newton summary:
		Unique number of iterations: [1 2 3]
		Counts of unique number of iterations: [3594    3  153]
		Maximum f: 8.562225052308646
		Maximum residual: 2.613695352113217e-08
  3 SNES Function norm 6.605695473685e-09
Load increment #31, load: 15.222448979591835
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3582    6  161    1]
		Maximum f: 8.56243815525612
		Maximum residual: 5.366137699670671e-08
  0 SNES Function norm 1.356272783579e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3578    7  162    2    1]
		Maximum f: 9.462706125067776
		Maximum residual: 2.9615099837545648e-08
  1 SNES Function norm 1.108959098908e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3578    6  163    2    1]
		Maximum f: 9.531525798162859
		Maximum residual: 3.6424348445667524e-08
  2 SNES Function norm 2.776717567076e-05
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3578    6  163    2    1]
		Maximum f: 9.531403525835538
		Maximum residual: 3.667264595718097e-08
  3 SNES Function norm 1.547137141417e-09
Load increment #32, load: 15.648979591836733
	Inner Newton summary:
		Unique number of iterations: [1 2 3 5]
		Counts of unique number of iterations: [3558    6  185    1]
		Maximum f: 9.531507663600417
		Maximum residual: 2.2943694576787332e-08
  0 SNES Function norm 1.690023171389e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3555    2  188    4    1]
		Maximum f: 11.208475801671002
		Maximum residual: 1.3925564911763202e-08
  1 SNES Function norm 1.136486356816e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3555    1  189    4    1]
		Maximum f: 11.328808304184566
		Maximum residual: 2.0711840643078828e-08
  2 SNES Function norm 4.043685669761e-05
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3555    1  189    4    1]
		Maximum f: 11.32899628921991
		Maximum residual: 2.0738739035798078e-08
  3 SNES Function norm 3.107154046850e-09
Load increment #33, load: 16.07551020408163
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3542    2  204    2]
		Maximum f: 11.32923800881529
		Maximum residual: 1.1473379158000754e-08
  0 SNES Function norm 1.445355802854e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3540    4  198    8]
		Maximum f: 12.39001686490581
		Maximum residual: 3.481616605399822e-08
  1 SNES Function norm 8.639360130456e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3540    6  196    8]
		Maximum f: 12.473904687662117
		Maximum residual: 5.45588762887105e-08
  2 SNES Function norm 1.258589840397e-05
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3540    6  196    8]
		Maximum f: 12.473963542525938
		Maximum residual: 5.447864304164713e-08
  3 SNES Function norm 2.119208342218e-10
Load increment #34, load: 16.502040816326527
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3522    7  219    2]
		Maximum f: 12.4740402231657
		Maximum residual: 3.0852740898079846e-08
  0 SNES Function norm 1.409596865654e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3522   10  210    8]
		Maximum f: 13.621030460089989
		Maximum residual: 2.6838150963321442e-08
  1 SNES Function norm 3.751124853547e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3520   10  212    8]
		Maximum f: 13.672897613109855
		Maximum residual: 3.5242904493329216e-08
  2 SNES Function norm 1.294558689652e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3521    9  212    8]
		Maximum f: 13.683042197974691
		Maximum residual: 3.036042905899164e-08
  3 SNES Function norm 5.415828097312e-05
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3521    9  212    8]
		Maximum f: 13.683257893898288
		Maximum residual: 3.031424367973327e-08
  4 SNES Function norm 1.090266480654e-09
Load increment #35, load: 16.928571428571427
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3507   13  229    1]
		Maximum f: 13.683249738187323
		Maximum residual: 1.8274114501768337e-08
  0 SNES Function norm 1.773988907407e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3499    4  238    9]
		Maximum f: 15.181050137106139
		Maximum residual: 5.335973133942075e-08
  1 SNES Function norm 4.041295762343e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3499    4  238    9]
		Maximum f: 15.293687390994892
		Maximum residual: 2.382814576282647e-08
  2 SNES Function norm 2.084752657077e-05
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3499    4  238    9]
		Maximum f: 15.293641206172301
		Maximum residual: 2.3835954034747172e-08
  3 SNES Function norm 5.396665264123e-10
Load increment #36, load: 17.355102040816327
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3480   11  258    1]
		Maximum f: 15.293621570476732
		Maximum residual: 4.37332654479395e-08
  0 SNES Function norm 1.733879266120e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3478    7  254   11]
		Maximum f: 16.65829634458384
		Maximum residual: 4.011336823622539e-08
  1 SNES Function norm 1.499749094517e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3478    5  256   11]
		Maximum f: 16.75338047972726
		Maximum residual: 1.1986830756764936e-08
  2 SNES Function norm 3.007459439803e-05
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3478    5  256   11]
		Maximum f: 16.753739177854943
		Maximum residual: 1.1976577925699163e-08
  3 SNES Function norm 1.425970891118e-09
Load increment #37, load: 17.781632653061223
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3464   11  273    1    1]
		Maximum f: 16.753695524790903
		Maximum residual: 5.451325421502184e-08
  0 SNES Function norm 1.861829821566e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3461   11  271    5    2]
		Maximum f: 18.047630053481406
		Maximum residual: 6.100411449593522e-08
  1 SNES Function norm 2.630504081379e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3460    8  275    5    2]
		Maximum f: 18.10943969874471
		Maximum residual: 7.997603551936617e-08
  2 SNES Function norm 3.124568213985e-05
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3460    8  275    5    2]
		Maximum f: 18.10995427013048
		Maximum residual: 8.032064111940249e-08
  3 SNES Function norm 2.300728030970e-09
Load increment #38, load: 18.20816326530612
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3435    8  305    2]
		Maximum f: 18.109866894786588
		Maximum residual: 3.543844086751272e-08
  0 SNES Function norm 1.800669384053e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3427    8  309    6]
		Maximum f: 19.682636417527245
		Maximum residual: 5.462722511283731e-08
  1 SNES Function norm 4.274620709122e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3427    7  310    6]
		Maximum f: 19.816059584913745
		Maximum residual: 7.360482233952666e-08
  2 SNES Function norm 2.487642596236e-05
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3427    7  310    6]
		Maximum f: 19.81669976929875
		Maximum residual: 7.313229349288892e-08
  3 SNES Function norm 1.096984769008e-09
Load increment #39, load: 18.63469387755102
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3407   15  325    3]
		Maximum f: 19.8166086033827
		Maximum residual: 3.294772783286022e-08
  0 SNES Function norm 1.931304808040e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3406    7  328    9]
		Maximum f: 21.38900486104476
		Maximum residual: 4.958933695819563e-08
  1 SNES Function norm 5.092741146638e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3406    7  328    9]
		Maximum f: 21.48012370240422
		Maximum residual: 5.538558616703825e-08
  2 SNES Function norm 2.203884275214e-05
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3406    7  328    9]
		Maximum f: 21.480128176158352
		Maximum residual: 5.5395373537815275e-08
  3 SNES Function norm 1.122075583647e-09
Load increment #40, load: 19.061224489795915
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3385   13  350    2]
		Maximum f: 21.480018052307
		Maximum residual: 3.935054785198907e-08
  0 SNES Function norm 1.807185057934e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3377   14  351    8]
		Maximum f: 23.101864004083076
		Maximum residual: 7.489325207128068e-08
  1 SNES Function norm 4.096957992682e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3376    8  358    8]
		Maximum f: 23.18123342797577
		Maximum residual: 5.437624690129319e-08
  2 SNES Function norm 3.801885713100e-04
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3376    7  359    8]
		Maximum f: 23.1859425453265
		Maximum residual: 5.475290492409997e-08
  3 SNES Function norm 2.086410188907e-08
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3376    7  359    8]
		Maximum f: 23.18594259385986
		Maximum residual: 5.475291128931586e-08
  4 SNES Function norm 6.450273874450e-15
Load increment #41, load: 19.487755102040815
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3356   11  381    2]
		Maximum f: 23.185822012668602
		Maximum residual: 6.081588439714694e-08
  0 SNES Function norm 2.318768587420e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3352    6  380   11    1]
		Maximum f: 25.497168088942164
		Maximum residual: 4.983963708704957e-08
  1 SNES Function norm 5.854691572877e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3351    6  382   10    1]
		Maximum f: 25.56790342173726
		Maximum residual: 7.710589796499596e-08
  2 SNES Function norm 1.857969440167e-04
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3351    6  382   10    1]
		Maximum f: 25.570064621334627
		Maximum residual: 7.630891326853144e-08
  3 SNES Function norm 1.324273259487e-07
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3351    6  382   10    1]
		Maximum f: 25.570065383153636
		Maximum residual: 7.630882125441558e-08
  4 SNES Function norm 4.504251787951e-14
Load increment #42, load: 19.91428571428571
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3324   11  413    1    1]
		Maximum f: 25.569961053877687
		Maximum residual: 5.606058386075585e-08
  0 SNES Function norm 1.991950861898e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3320    6  413   10    1]
		Maximum f: 27.485399219692333
		Maximum residual: 4.017158963102894e-08
  1 SNES Function norm 6.102183392930e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3320    8  411   10    1]
		Maximum f: 27.60556017048504
		Maximum residual: 3.686313760675295e-08
  2 SNES Function norm 5.558752205573e-05
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3320    8  411   10    1]
		Maximum f: 27.605715757299023
		Maximum residual: 3.684919381912581e-08
  3 SNES Function norm 7.661646680116e-09
Load increment #43, load: 20.34081632653061
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3291   14  444    1]
		Maximum f: 27.605586022819814
		Maximum residual: 5.775545097253557e-08
  0 SNES Function norm 2.099035542902e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3288   11  443    8]
		Maximum f: 29.843867326127402
		Maximum residual: 7.139781842063572e-08
  1 SNES Function norm 4.733736371878e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3289   11  442    8]
		Maximum f: 29.96980534076419
		Maximum residual: 6.912248136139325e-08
  2 SNES Function norm 5.313476427749e-04
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3289   11  442    8]
		Maximum f: 29.97403243340989
		Maximum residual: 6.694396134702839e-08
  3 SNES Function norm 1.094543341777e-07
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3289   11  442    8]
		Maximum f: 29.974032616432332
		Maximum residual: 6.69429857059461e-08
  4 SNES Function norm 1.327982652808e-14
Load increment #44, load: 20.767346938775507
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3259   20  469    2]
		Maximum f: 29.973884886574275
		Maximum residual: 6.014568481100905e-08
  0 SNES Function norm 2.246453731859e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3253   15  472   10]
		Maximum f: 32.902056673948024
		Maximum residual: 4.978738430260414e-08
  1 SNES Function norm 4.783821924576e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3254    9  477   10]
		Maximum f: 33.07703435009291
		Maximum residual: 5.3298264850318704e-08
  2 SNES Function norm 5.291518412056e-04
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3254    9  477   10]
		Maximum f: 33.08102796048976
		Maximum residual: 5.3375796291569515e-08
  3 SNES Function norm 4.903824792371e-08
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3254    9  477   10]
		Maximum f: 33.08102797676841
		Maximum residual: 5.337579780233577e-08
  4 SNES Function norm 8.116657537757e-15
Load increment #45, load: 21.193877551020407
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3222   14  512    1    1]
		Maximum f: 33.08088303623554
		Maximum residual: 4.586197404437217e-08
  0 SNES Function norm 2.150367183099e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3214   14  509   12    1]
		Maximum f: 36.42903952633926
		Maximum residual: 6.06036328661853e-08
  1 SNES Function norm 8.415327791298e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3212   14  512   11    1]
		Maximum f: 36.59087082197445
		Maximum residual: 7.033600267395731e-08
  2 SNES Function norm 2.624314374597e-04
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3212   14  512   11    1]
		Maximum f: 36.59283926162488
		Maximum residual: 7.04186726369224e-08
  3 SNES Function norm 2.176387424467e-07
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3212   14  512   11    1]
		Maximum f: 36.592840384998034
		Maximum residual: 7.041868807929974e-08
  4 SNES Function norm 1.349726752827e-13
Load increment #46, load: 21.620408163265303
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3178   21  549    1    1]
		Maximum f: 36.59268395743397
		Maximum residual: 2.3037653494723446e-08
  0 SNES Function norm 2.628890154187e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3175   15  548   12]
		Maximum f: 40.880163997825704
		Maximum residual: 4.861051765330541e-08
  1 SNES Function norm 9.704135791355e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3174   13  551   12]
		Maximum f: 41.096880445879776
		Maximum residual: 5.8885975211709047e-08
  2 SNES Function norm 3.940138059941e-04
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3174   13  551   12]
		Maximum f: 41.09933447019447
		Maximum residual: 5.8958579625387315e-08
  3 SNES Function norm 3.404476109117e-07
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3174   13  551   12]
		Maximum f: 41.099337319965144
		Maximum residual: 5.89586393073194e-08
  4 SNES Function norm 3.270753863591e-13
Load increment #47, load: 22.046938775510203
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3138   21  588    3]
		Maximum f: 41.099162405545904
		Maximum residual: 1.5081275145365286e-07
  0 SNES Function norm 2.587384322553e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3129   27  582   12]
		Maximum f: 46.39966838688648
		Maximum residual: 6.262153862432692e-08
  1 SNES Function norm 6.838178297473e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3127   21  589   12    1]
		Maximum f: 46.935923909893745
		Maximum residual: 1.0299805512008026e-07
  2 SNES Function norm 1.362075338948e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3127   19  591   12    1]
		Maximum f: 46.972122873919
		Maximum residual: 8.753638625454746e-08
  3 SNES Function norm 8.572437203236e-06
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3127   19  591   12    1]
		Maximum f: 46.97221015126677
		Maximum residual: 8.753995914456126e-08
  4 SNES Function norm 1.261711528208e-10
Load increment #48, load: 22.4734693877551
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3086   30  627    7]
		Maximum f: 46.97203139227426
		Maximum residual: 6.743331377951034e-08
  0 SNES Function norm 2.944549916639e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3072   26  633   19]
		Maximum f: 54.79142280645232
		Maximum residual: 1.280752743493727e-07
  1 SNES Function norm 1.035617501699e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3072   25  634   19]
		Maximum f: 55.92681732819181
		Maximum residual: 1.9443039041972241e-07
  2 SNES Function norm 1.664047906161e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3072   23  637   18]
		Maximum f: 55.97127030749179
		Maximum residual: 1.9966759690432078e-07
  3 SNES Function norm 6.382437508103e-06
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [3072   23  637   18]
		Maximum f: 55.97133470774749
		Maximum residual: 1.9967645383223433e-07
  4 SNES Function norm 1.046914159393e-10
Load increment #49, load: 22.9
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [3006   36  697   10    1]
		Maximum f: 55.97111640595315
		Maximum residual: 9.181422069428905e-08
  0 SNES Function norm 3.451445654693e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [2979   20  730   20    1]
		Maximum f: 71.24902596408256
		Maximum residual: 1.7736360583511463e-07
  1 SNES Function norm 3.791796477487e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [2971   20  726   32    1]
		Maximum f: 78.5108334052504
		Maximum residual: 3.422498510147269e-07
  2 SNES Function norm 6.547900362151e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [2973   18  725   33    1]
		Maximum f: 79.67788407053536
		Maximum residual: 5.53015502199453e-07
  3 SNES Function norm 1.411257682783e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [2973   18  725   33    1]
		Maximum f: 79.77364035795087
		Maximum residual: 5.717286862449747e-07
  4 SNES Function norm 1.114569889288e-04
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [2973   18  725   33    1]
		Maximum f: 79.775181015846
		Maximum residual: 5.719374277704249e-07
  5 SNES Function norm 1.307185600199e-08
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [2973   18  725   33    1]
		Maximum f: 79.77518100703846
		Maximum residual: 5.719374195648558e-07
  6 SNES Function norm 1.520730377443e-14
Load increment #50, load: 22.96
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [2875   27  823   25]
		Maximum f: 79.77483948791341
		Maximum residual: 7.720997784225645e-08
  0 SNES Function norm 4.510532966358e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [2980  163  590   17]
		Maximum f: 18.040683536862442
		Maximum residual: 3.262301343503065e-08
  1 SNES Function norm 3.813905199849e-01
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [2956  143  646    5]
		Maximum f: 21.19212798427792
		Maximum residual: 3.221171668014536e-08
  2 SNES Function norm 2.731881915507e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [2939  126  676    9]
		Maximum f: 27.35877462534762
		Maximum residual: 3.970759483320902e-08
  3 SNES Function norm 9.518819283938e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [2937  103  700   10]
		Maximum f: 32.93695984912134
		Maximum residual: 7.979458128360009e-08
  4 SNES Function norm 2.440368481887e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [2935  101  703   11]
		Maximum f: 34.04873608851252
		Maximum residual: 2.491242365454654e-08
  5 SNES Function norm 4.704398167594e-04
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [2935  101  703   11]
		Maximum f: 34.14222362721051
		Maximum residual: 2.4406673263822645e-08
  6 SNES Function norm 1.117573929913e-06
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [2935  101  703   11]
		Maximum f: 34.14235387239224
		Maximum residual: 2.4406105961311278e-08
  7 SNES Function norm 5.486653804204e-12
Load increment #51, load: 22.99
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [2923  112  712    3]
		Maximum f: 34.14219556832628
		Maximum residual: 6.430010808567842e-08
  0 SNES Function norm 1.458357152583e-02
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [2925  161  662    2]
		Maximum f: 29.88026032183153
		Maximum residual: 2.927566721427952e-08
  1 SNES Function norm 5.932855542944e-03
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [2926  157  665    2]
		Maximum f: 31.321078373323793
		Maximum residual: 3.3343378048205675e-08
  2 SNES Function norm 6.233209758953e-04
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [2926  157  665    2]
		Maximum f: 31.363803938675638
		Maximum residual: 3.320235793407725e-08
  3 SNES Function norm 4.410931140236e-07
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [2926  157  665    2]
		Maximum f: 31.36387692081376
		Maximum residual: 3.320264076036999e-08
  4 SNES Function norm 1.805097926217e-12
Slope stability factor: 6.663768115942029

Note

We demonstrated here the use of PETSc.SNES together with external operators through the PETScNonlinearProblem and PETScNonlinearSolver classes. If you are familiar with the original DOLFINx NonlinearProblem class, you may instead use NonlinearProblemWithCallback, covered in the von Mises tutorial.

Verification#

Critical load#

According to Chen and Liu [1990], the slope stability factor \(l_\text{lim}\) can be derived analytically for the standard Mohr-Coulomb plasticity model (without apex smoothing) under the plane strain assumption and an associative plastic flow:

\[ l_\text{lim} = \gamma_\text{lim} H/c, \]

where \(\gamma_\text{lim}\) is the corresponding limit value of the soil self-weight. In particular, for the rectangular slope with the friction angle \(\phi\) equal to \(30^\circ\), \(l_\text{lim} = 6.69\) [Chen and Liu, 1990]. Thus, we can compute \(\gamma_\text{lim}\) from the formula above, progressively increase the second component of the gravitational body force \(\boldsymbol{q}=[0, -\gamma]^T\) up to the critical value \(\gamma_\text{lim}^\text{num}\), at which the perfect plasticity plateau is reached on the loading-displacement curve at the point \((0, H)\), and then compare \(\gamma_\text{lim}^\text{num}\) against the analytical value \(\gamma_\text{lim}\).

The loading-displacement curve in the figure below confirms that the limit load reached numerically, \(\gamma_\text{lim}^\text{num}\), is close to \(\gamma_\text{lim}\).

if len(points_on_process) > 0:
    l_lim = 6.69
    gamma_lim = l_lim / H * c
    plt.plot(results[:, 0], results[:, 1], "o-", label=r"$\gamma$")
    plt.axhline(y=gamma_lim, color="r", linestyle="--", label=r"$\gamma_\text{lim}$")
    plt.xlabel(r"Displacement of the slope $u_x$ at $(0, H)$ [mm]")
    plt.ylabel(r"Soil self-weight $\gamma$ [MPa/mm]")
    plt.grid()
    plt.legend()
[Figure: soil self-weight \(\gamma\) versus the displacement \(u_x\) at \((0, H)\); the dashed red line marks the analytical limit \(\gamma_\text{lim}\).]
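As a quick arithmetic check of this comparison, the sketch below inverts \(l_\text{lim} = \gamma_\text{lim} H/c\) and computes the relative difference between the analytical factor 6.69 and the slope stability factor printed at the end of the loading loop. The helper names and the values `c_example` and `H_example` are hypothetical, chosen only to illustrate the inversion; since the stability factor is dimensionless, the comparison itself does not depend on them.

```python
def critical_self_weight(l_lim: float, c: float, H: float) -> float:
    """Invert l_lim = gamma_lim * H / c for the limit soil self-weight."""
    return l_lim * c / H


def relative_difference(value: float, reference: float) -> float:
    """Relative difference of a value with respect to a nonzero reference."""
    return abs(value - reference) / abs(reference)


l_lim = 6.69  # analytical stability factor for phi = 30 degrees (Chen and Liu, 1990)
l_lim_num = 6.663768115942029  # "Slope stability factor" printed by the loading loop

print(f"Relative difference: {relative_difference(l_lim_num, l_lim):.2%}")

# Hypothetical cohesion and slope height, for illustration only
c_example, H_example = 2.0, 1.0
print(f"gamma_lim = {critical_self_weight(l_lim, c_example, H_example)}")
```

With the values from this run, the numerical factor lies within about 0.4% of the analytical one.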

The slope profile reaching its stability limit:

try:
    import pyvista

    import dolfinx.plot

    # Start a virtual framebuffer so that rendering works in headless environments
    pyvista.start_xvfb()

    # Interpolate the displacement into a vector-valued P1 space for plotting
    W = fem.functionspace(domain, ("Lagrange", 1, (gdim,)))
    u_tmp = fem.Function(W, name="Displacement")
    u_tmp.interpolate(u)

    plotter = pyvista.Plotter(window_size=[600, 400])
    # Build the VTK grid from the P1 space so that grid points follow its DOF ordering
    topology, cell_types, x = dolfinx.plot.vtk_mesh(W)
    grid = pyvista.UnstructuredGrid(topology, cell_types, x)
    # PyVista expects 3-component vectors: pad the 2D displacement with zeros
    vals = np.zeros((x.shape[0], 3))
    vals[:, : len(u_tmp)] = u_tmp.x.array.reshape((x.shape[0], len(u_tmp)))
    grid["u"] = vals
    # Magnify the displacement 20 times to make the failure mechanism visible
    warped = grid.warp_by_vector("u", factor=20)
    plotter.add_text("Displacement field", font_size=11)
    plotter.add_mesh(warped, show_edges=False, show_scalar_bar=True)
    plotter.view_xy()
    plotter.show()
except ImportError:
    print("pyvista required for this plot")
[Figure: displacement field of the slope at its stability limit, magnified 20 times.]

Yield surface#

We verify the implementation of the constitutive model by tracing the yield surface: we generate several stress paths and check whether they remain within the Mohr-Coulomb yield surface. The stress tracing is performed in the Haigh-Westergaard coordinates \((\xi, \rho, \theta)\), defined as follows:

\[ \xi = \frac{1}{\sqrt{3}}I_1, \quad \rho = \sqrt{2J_2}, \quad \sin(3\theta) = -\frac{3\sqrt{3}}{2} \frac{J_3}{J_2^{3/2}}, \]

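where \(I_1(\boldsymbol{\sigma}) = \operatorname{tr}\boldsymbol{\sigma}\) is the first invariant of the stress tensor, \(J_2(\boldsymbol{\sigma}) = \frac{1}{2}\boldsymbol{s} : \boldsymbol{s}\) and \(J_3(\boldsymbol{\sigma}) = \det(\boldsymbol{s})\) are the second and third invariants of its deviatoric part \(\boldsymbol{s}\), \(\xi\) is the hydrostatic coordinate, \(\rho\) is the radial coordinate and the angle \(\theta \in [-\frac{\pi}{6}, \frac{\pi}{6}]\) is called the Lode or stress angle.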
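These coordinates can be evaluated directly from a full \(3 \times 3\) stress tensor. The NumPy sketch below is an illustration independent of the tutorial's JAX helpers (it does not reuse `J2`, `J3` or `dev`), and the function name `haigh_westergaard` is hypothetical. It clips the argument of \(\arcsin\) to guard against round-off and, as a convention chosen here, returns \(\theta = 0\) for a purely hydrostatic state, where the Lode angle is undefined.

```python
import numpy as np


def haigh_westergaard(sigma: np.ndarray) -> tuple[float, float, float]:
    """Return (xi, rho, theta) of a symmetric 3x3 stress tensor.

    xi = I1 / sqrt(3), rho = sqrt(2 J2) and
    sin(3 theta) = -(3 sqrt(3) / 2) J3 / J2^(3/2), theta in [-pi/6, pi/6].
    """
    I1 = np.trace(sigma)
    s = sigma - I1 / 3.0 * np.eye(3)  # deviatoric part
    J2 = 0.5 * np.tensordot(s, s)  # J2 = s:s / 2 (full double contraction)
    J3 = np.linalg.det(s)
    xi = I1 / np.sqrt(3.0)
    rho = np.sqrt(2.0 * J2)
    if np.isclose(J2, 0.0):  # hydrostatic state: Lode angle undefined
        return xi, rho, 0.0
    arg = -1.5 * np.sqrt(3.0) * J3 / J2**1.5
    theta = np.arcsin(np.clip(arg, -1.0, 1.0)) / 3.0  # clip guards round-off
    return xi, rho, theta


xi, rho, theta = haigh_westergaard(np.diag([2.0, -1.0, -1.0]))
print(xi, rho, theta)
```

For example, `np.diag([2.0, -1.0, -1.0])` is a pure deviator with \(\rho = \sqrt{6}\) and, with the sign convention above, \(\theta = -\pi/6\).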

To generate the stress paths, we use the expression of the principal stresses in Haigh-Westergaard coordinates:

\[\begin{split} \begin{pmatrix} \sigma_{I} \\ \sigma_{II} \\ \sigma_{III} \end{pmatrix} = p \begin{pmatrix} 1 \\ 1 \\ 1 \end{pmatrix} + \frac{\rho}{\sqrt{2}} \begin{pmatrix} \cos\theta + \frac{\sin\theta}{\sqrt{3}} \\ -\frac{2\sin\theta}{\sqrt{3}} \\ \frac{\sin\theta}{\sqrt{3}} - \cos\theta \end{pmatrix}, \end{split}\]

where \(p = \xi/\sqrt{3}\) is the hydrostatic (mean) stress and \(\sigma_{I} \geq \sigma_{II} \geq \sigma_{III}\).
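A short NumPy sketch of this formula, with a few sanity checks: the deviatoric part must be traceless, its Euclidean norm must equal \(\rho = \sqrt{2J_2}\), and the components must come out ordered for \(\theta \in [-\frac{\pi}{6}, \frac{\pi}{6}]\) (up to round-off at the endpoints, where two principal stresses coincide). The function name and the sample values of `p` and `rho` are illustrative assumptions.

```python
import numpy as np


def principal_stresses(p: float, rho: float, theta: float) -> np.ndarray:
    """Principal stresses (sigma_I, sigma_II, sigma_III) from the hydrostatic
    stress p, the radial coordinate rho and the Lode angle theta."""
    deviator = (rho / np.sqrt(2.0)) * np.array(
        [
            np.cos(theta) + np.sin(theta) / np.sqrt(3.0),
            -2.0 * np.sin(theta) / np.sqrt(3.0),
            np.sin(theta) / np.sqrt(3.0) - np.cos(theta),
        ]
    )
    return p * np.ones(3) + deviator


# Sanity checks on a sweep of Lode angles (illustrative values of p and rho)
p, rho, tol = 0.1, 0.7, 1e-12
for theta in np.linspace(-np.pi / 6, np.pi / 6, 7):
    sig = principal_stresses(p, rho, theta)
    assert np.isclose(sig.mean(), p)  # the deviatoric part is traceless
    assert np.isclose(np.linalg.norm(sig - p), rho)  # |s| = sqrt(2 J2) = rho
    assert sig[0] >= sig[1] - tol and sig[1] >= sig[2] - tol  # ordering
```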

Now we generate the loading paths by evaluating the principal stresses in Haigh-Westergaard coordinates, with the Lode angle \(\theta\) varying from \(-\frac{\pi}{6}\) to \(\frac{\pi}{6}\) (excluding the endpoints by the small tolerance eps) for fixed \(\rho\) and \(p\).

N_angles = 50
N_loads = 9  # number of loadings or paths
eps = 0.00001
R = 0.7  # fix the values of rho
p = 0.1  # fix the hydrostatic (mean) stress
theta_1 = -np.pi / 6
theta_2 = np.pi / 6

theta_values = np.linspace(theta_1 + eps, theta_2 - eps, N_angles)
theta_returned = np.empty((N_loads, N_angles))
rho_returned = np.empty((N_loads, N_angles))
sigma_returned = np.empty((N_loads, N_angles, stress_dim))

# fix an increment of the stress path
dsigma_path = np.zeros((N_angles, stress_dim))
dsigma_path[:, 0] = (R / np.sqrt(2)) * (np.cos(theta_values) + np.sin(theta_values) / np.sqrt(3))
dsigma_path[:, 1] = (R / np.sqrt(2)) * (-2 * np.sin(theta_values) / np.sqrt(3))
dsigma_path[:, 2] = (R / np.sqrt(2)) * (np.sin(theta_values) / np.sqrt(3) - np.cos(theta_values))

sigma_n_local = np.zeros_like(dsigma_path)
sigma_n_local[:, 0] = p
sigma_n_local[:, 1] = p
sigma_n_local[:, 2] = p
derviatoric_axis = tr  # hydrostatic axis (identity tensor in Mandel-Voigt notation)

Then we define and vectorize the functions rho, Lode_angle and sigma_tracing, which evaluate, respectively, the coordinates \(\rho\) and \(\theta\) and the corrected (or “returned”) stress tensor for a given stress state. sigma_tracing converts the stress increment into an elastic strain increment through the elastic compliance S_elas and passes it to return_mapping, the JAX implementation of the constitutive model defined earlier.

def rho(sigma_local):
    s = dev @ sigma_local
    return jnp.sqrt(2.0 * J2(s))


def Lode_angle(sigma_local):
    s = dev @ sigma_local
    arg = -(3.0 * jnp.sqrt(3.0) * J3(s)) / (2.0 * jnp.sqrt(J2(s) * J2(s) * J2(s)))
    arg = jnp.clip(arg, -1.0, 1.0)
    angle = 1.0 / 3.0 * jnp.arcsin(arg)
    return angle


def sigma_tracing(sigma_local, sigma_n_local):
    deps_elas = S_elas @ sigma_local
    sigma_corrected, state = return_mapping(deps_elas, sigma_n_local)
    yielding = state[2]
    return sigma_corrected, yielding


Lode_angle_v = jax.jit(jax.vmap(Lode_angle, in_axes=(0)))
rho_v = jax.jit(jax.vmap(rho, in_axes=(0)))
sigma_tracing_v = jax.jit(jax.vmap(sigma_tracing, in_axes=(0, 0)))
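In the cell above, `jax.vmap` turns functions written for a single stress state into functions acting on a whole batch of states, and `jax.jit` compiles them. For readers who want to check the invariants without JAX, the sketch below computes the same quantities for a batch using plain NumPy broadcasting. It assumes the plane-strain Mandel-Voigt layout \([\sigma_{xx}, \sigma_{yy}, \sigma_{zz}, \sqrt{2}\sigma_{xy}]\) for the four stress components; this layout and all helper names are assumptions, not the tutorial's own `dev`, `J2` and `J3`.

```python
import numpy as np

SQRT2 = np.sqrt(2.0)
TR = np.array([1.0, 1.0, 1.0, 0.0])  # identity tensor in the assumed Mandel-Voigt layout


def deviator(sigma: np.ndarray) -> np.ndarray:
    """Deviatoric part of a batch of Mandel-Voigt vectors of shape (N, 4)."""
    mean = sigma @ TR / 3.0  # mean stress of each row, shape (N,)
    return sigma - np.outer(mean, TR)


def rho_batch(sigma: np.ndarray) -> np.ndarray:
    s = deviator(sigma)
    J2 = 0.5 * np.sum(s * s, axis=1)  # Mandel scaling makes s.s equal to s:s
    return np.sqrt(2.0 * J2)


def lode_angle_batch(sigma: np.ndarray) -> np.ndarray:
    s = deviator(sigma)
    J2 = 0.5 * np.sum(s * s, axis=1)
    s_xy = s[:, 3] / SQRT2  # undo the Mandel sqrt(2) factor on the shear term
    # det of [[s_xx, s_xy, 0], [s_xy, s_yy, 0], [0, 0, s_zz]] (plane strain)
    J3 = s[:, 2] * (s[:, 0] * s[:, 1] - s_xy**2)
    arg = -1.5 * np.sqrt(3.0) * J3 / np.sqrt(J2**3)
    return np.arcsin(np.clip(arg, -1.0, 1.0)) / 3.0


stresses = np.array([[1.0, 2.0, 3.0, 0.0], [2.0, -1.0, -1.0, 0.0]])
print(rho_batch(stresses))  # radial coordinate of each state
print(lode_angle_batch(stresses))  # Lode angle of each state
```

Because of the \(\sqrt{2}\) factor on the shear component, the plain dot product of two Mandel-Voigt vectors equals the tensor contraction \(\boldsymbol{s} : \boldsymbol{s}\), which is why \(J_2\) needs no extra weighting here.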

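For each stress path, we call sigma_tracing_v to get the corrected stress states and then project them onto the deviatoric plane with the fixed value of \(p\), where they are described by the coordinates \((\rho, \theta)\). The projected states serve as the initial stresses of the next path.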

for i in range(N_loads):
    print(f"Loading path#{i}")
    dsigma, yielding = sigma_tracing_v(dsigma_path, sigma_n_local)
    dp = dsigma @ tr / 3.0 - p
    dsigma -= np.outer(dp, derviatoric_axis)  # projection on the same deviatoric plane

    sigma_returned[i, :] = dsigma
    theta_returned[i, :] = Lode_angle_v(dsigma)
    rho_returned[i, :] = rho_v(dsigma)
    print(f"max f: {jnp.max(yielding)}\n")
    sigma_n_local[:] = dsigma
Loading path#0
max f: -2.005661796528811

Loading path#1
max f: -1.6474140052837287

Loading path#2
max f: -1.208026577509537

Loading path#3
max f: -0.7355419142072792

Loading path#4
max f: -0.24734105482689195

Loading path#5
max f: 0.24936204365846004

Loading path#6
max f: 0.5729074243253174

Loading path#7
max f: 0.6673099640326816

Loading path#8
max f: 0.6947128480833946
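The projection used in the loop above shifts each returned stress along the hydrostatic axis so that its mean stress equals \(p\) again, leaving the deviatoric part (and thus \(\rho\) and \(\theta\)) untouched. A minimal NumPy sketch of that step, again assuming the Mandel-Voigt layout \([\sigma_{xx}, \sigma_{yy}, \sigma_{zz}, \sqrt{2}\sigma_{xy}]\); the function name is hypothetical:

```python
import numpy as np

TR = np.array([1.0, 1.0, 1.0, 0.0])  # hydrostatic axis in the assumed Mandel-Voigt layout


def project_to_mean_stress(sigma: np.ndarray, p: float) -> np.ndarray:
    """Shift a batch of stresses (N, 4) along the hydrostatic axis so that
    every row has mean stress p, keeping the deviatoric part unchanged."""
    dp = sigma @ TR / 3.0 - p  # excess mean stress of each row
    return sigma - np.outer(dp, TR)


sigma = np.array([[0.5, 0.2, -0.1, 0.3], [1.0, 1.0, 1.0, 0.0]])
projected = project_to_mean_stress(sigma, p=0.1)
print(projected @ TR / 3.0)  # mean stress of each projected row
```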

Then, starting from the expression of the standard Mohr-Coulomb yield surface in principal stresses, we can obtain an analogous expression in Haigh-Westergaard coordinates, which leads to the following equation:

\[ \frac{\rho}{\sqrt{6}}(\sqrt{3}\cos\theta + \sin\phi \sin\theta) - p\sin\phi - c\cos\phi = 0. \tag{7} \]

Thus, we recover the radial coordinate \(\rho\) of the standard Mohr-Coulomb yield surface for given \(\theta\) and \(p\):

def MC_yield_surface(theta_, p):
    """Restores the coordinate `rho` satisfying the standard Mohr-Coulomb yield
    criterion."""
    rho = (np.sqrt(2) * (c * np.cos(phi) + p * np.sin(phi))) / (
        np.cos(theta_) - np.sin(phi) * np.sin(theta_) / np.sqrt(3)
    )
    return rho


rho_standard_MC = MC_yield_surface(theta_values, p)

Finally, we plot the yield surface:

colormap = cm.plasma
colors = colormap(np.linspace(0.0, 1.0, N_loads))

fig, ax = plt.subplots(subplot_kw={"projection": "polar"}, figsize=(8, 8))
# Mohr-Coulomb yield surface with apex smoothing
for i, color in enumerate(colors):
    rho_total = np.array([])
    theta_total = np.array([])
    for j in range(12):
        angles = j * np.pi / 3 - j % 2 * theta_returned[i] + (1 - j % 2) * theta_returned[i]
        theta_total = np.concatenate([theta_total, angles])
        rho_total = np.concatenate([rho_total, rho_returned[i]])

    ax.plot(theta_total, rho_total, ".", color=color)

# standard Mohr-Coulomb yield surface
theta_standard_MC_total = np.array([])
rho_standard_MC_total = np.array([])
for j in range(12):
    angles = j * np.pi / 3 - j % 2 * theta_values + (1 - j % 2) * theta_values
    theta_standard_MC_total = np.concatenate([theta_standard_MC_total, angles])
    rho_standard_MC_total = np.concatenate([rho_standard_MC_total, rho_standard_MC])
ax.plot(theta_standard_MC_total, rho_standard_MC_total, "-", color="black")
ax.set_yticklabels([])

norm = mcolors.Normalize(vmin=R, vmax=R * N_loads)  # rho of the first and last trial paths
sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, ax=ax, orientation="vertical")
cbar.set_label(r"Magnitude of the stress path deviator, $\rho$ [MPa]")

plt.show()
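[Figure: returned stress states of each loading path in the deviatoric plane (polar coordinates \(\theta\), \(\rho\)), coloured by path, with the standard Mohr-Coulomb yield surface in black.]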

Each colour represents one loading path. During the elastic phase, the stress states of a path form a circle in the deviatoric plane (paths #0 to #4 above remain elastic, as their maximum value of f is negative). Once the loading reaches the elastic limit, these circles deform and start outlining the yield surface, which in the limit lies along the standard Mohr-Coulomb surface without smoothing (black contour).

Taylor test#

Here, we perform a Taylor test to check that the Jacobian \(J\) is consistent with the residual form \(F\), that is, that \(F\) and \(J\) provide correct zeroth- and first-order approximations of the residual. In particular, the test verifies that the function dsigma_ddeps_vec, obtained with JAX's AD, returns correct values of the external operator \(\boldsymbol{\sigma}\) and its derivative \(\boldsymbol{C}_\text{tang}\), which define \(F\) and \(J\), respectively.

To perform the test, we introduce the operators \(\mathcal{F}: V \rightarrow V^\prime\) and \(\mathcal{J}: V \rightarrow \mathcal{L}(V, V^\prime)\) defined as follows:

\[ \langle \mathcal{F}(\boldsymbol{u}), \boldsymbol{v} \rangle := F(\boldsymbol{u}; \boldsymbol{v}), \quad \forall \boldsymbol{v} \in V, \]
\[ \langle (\mathcal{J}(\boldsymbol{u}))(k\boldsymbol{\delta u}), \boldsymbol{v} \rangle := J(\boldsymbol{u}; k\boldsymbol{\delta u}, \boldsymbol{v}), \quad \forall \boldsymbol{v} \in V, \]

where \(V^\prime\) is the dual space of \(V\), \(\langle \cdot, \cdot \rangle\) is the \(V^\prime \times V\) duality pairing and \(\mathcal{L}(V, V^\prime)\) is the space of bounded linear operators from \(V\) to its dual.

Then, if we perturb the operator \(\mathcal{F}\) in the direction \(k \, \boldsymbol{\delta u} \in V\) for \(k > 0\), Taylor's theorem on Banach spaces implies that the zeroth- and first-order Taylor remainders \(r_k^0\) and \(r_k^1\) converge at the following mesh-independent rates in the dual space \(V^\prime\):

\[ \| r_k^0 \|_{V^\prime} := \| \mathcal{F}(\boldsymbol{u} + k \, \boldsymbol{\delta u}) - \mathcal{F}(\boldsymbol{u}) \|_{V^\prime} \longrightarrow 0 \text{ at } O(k), \tag{8} \]
\[ \| r_k^1 \|_{V^\prime} := \| \mathcal{F}(\boldsymbol{u} + k \, \boldsymbol{\delta u}) - \mathcal{F}(\boldsymbol{u}) - \, (\mathcal{J}(\boldsymbol{u}))(k\boldsymbol{\delta u}) \|_{V^\prime} \longrightarrow 0 \text{ at } O(k^2). \tag{9} \]
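To see what rates (8) and (9) look like in practice, the sketch below runs the same test on a small, hypothetical nonlinear map \(F(w) = A w + w^3\) (cube taken componentwise) in \(\mathbb{R}^3\), whose Jacobian \(J(w) = A + 3\,\mathrm{diag}(w^2)\) is known exactly. It uses the Euclidean norm in place of the dual norm and estimates the observed rates from consecutive values of \(k\); rates near 1 for \(r_k^0\) and near 2 for \(r_k^1\) indicate a consistent Jacobian.

```python
import numpy as np

# Small, hypothetical test problem in R^3 with a known exact Jacobian
A = np.array([[2.0, 1.0, 0.0], [0.0, 3.0, 1.0], [1.0, 0.0, 4.0]])
u = np.array([1.0, -0.5, 2.0])  # expansion point
du = np.array([0.3, 1.0, -0.7])  # perturbation direction


def F(w: np.ndarray) -> np.ndarray:
    """Toy nonlinear residual F(w) = A w + w**3 (componentwise cube)."""
    return A @ w + w**3


def J(w: np.ndarray) -> np.ndarray:
    """Exact Jacobian of F at w."""
    return A + np.diag(3.0 * w**2)


def observed_rates(remainders: np.ndarray, ks: np.ndarray) -> np.ndarray:
    """Slopes of log(remainder) against log(k) between consecutive values of k."""
    return np.diff(np.log(remainders)) / np.diff(np.log(ks))


k_list = np.logspace(-6.0, -2.0, 5)  # k = 1e-6, 1e-5, ..., 1e-2
r0 = np.array([np.linalg.norm(F(u + k * du) - F(u)) for k in k_list])
r1 = np.array([np.linalg.norm(F(u + k * du) - F(u) - k * J(u) @ du) for k in k_list])

print("zeroth-order rates:", observed_rates(r0, k_list))
print("first-order rates:", observed_rates(r1, k_list))
```

Replacing `J` by `A` alone (dropping the term \(3\,\mathrm{diag}(w^2)\)) makes the first-order rates fall to about 1, which is the kind of inconsistency the Taylor test is designed to detect.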

To compute the norm of an element \(f \in V^\prime\), we apply the Riesz representation theorem, which states that there is a linear isometric isomorphism \(\mathcal{R} : V^\prime \to V\) associating each linear functional \(f\) with a unique element \(\mathcal{R} f = \boldsymbol{u} \in V\). In practice, within a finite-dimensional subspace \(V_h \subset V\) equipped with the inner product \(\int_\Omega \nabla\boldsymbol{u} \cdot \nabla\boldsymbol{v} \, \mathrm{d}x\), the Riesz map \(\mathcal{R}\) is represented by the matrix \(\mathsf{L}^{-1}\), the inverse of the matrix of the Laplace operator [Kirby, 2010]:

\[ \mathsf{L}_{ij} = \int\limits_\Omega \nabla\varphi_i \cdot \nabla\varphi_j \, \mathrm{d} x , \quad i,j = 1, \dots, \dim V_h, \]

where \(\{\varphi_i\}_{i=1}^{\dim V_h}\) is a basis of the space \(V_h\).

If the Euclidean vectors \(\mathsf{r}_k^i \in \mathbb{R}^{\dim V_h}, \, i \in \{0,1\}\), represent the Taylor remainders (8) and (9) in the finite-dimensional space, then their dual norms are computed with the following formula [Kirby, 2010]:

\[ \| r_k^i \|^2_{V^\prime_h} = (\mathsf{r}_k^i)^T \mathsf{L}^{-1} \mathsf{r}_k^i, \quad i \in \{0,1\}. \tag{10} \]

In practice, the vectors \(\mathsf{r}_k^i\) are defined through the residual vector \(\mathsf{F} \in \mathbb{R}^{\dim V_h}\) and the Jacobian matrix \(\mathsf{J} \in \mathbb{R}^{\dim V_h\times\dim V_h}\):

\[ \mathsf{r}_k^0 = \mathsf{F}(\mathsf{u} + k \, \mathsf{\delta u}) - \mathsf{F}(\mathsf{u}) \in \mathbb{R}^{\dim V_h}, \tag{11} \]
\[ \mathsf{r}_k^1 = \mathsf{F}(\mathsf{u} + k \, \mathsf{\delta u}) - \mathsf{F}(\mathsf{u}) - \, \mathsf{J}(\mathsf{u}) \cdot k\mathsf{\delta u} \in \mathbb{R}^{\dim V_h}, \tag{12} \]

where \(\mathsf{u}, \mathsf{\delta u} \in \mathbb{R}^{\dim V_h}\) are the coefficient vectors of the displacement fields \(\boldsymbol{u} \in V_h\) and \(\boldsymbol{\delta u} \in V_h\).
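Formula (10) can be tried out on a hypothetical 1D analogue: piecewise-linear elements on a uniform mesh of \((0, 1)\) with homogeneous Dirichlet conditions, where \(\mathsf{L}\) is the tridiagonal stiffness matrix. The sketch solves \(\mathsf{L}\mathsf{y} = \mathsf{r}\) (the discrete Riesz map) and returns \(\sqrt{\mathsf{r}^T\mathsf{y}}\), which is what the `Riesz_solver` in the implementation that follows does for the 2D problem, here with a dense solver. All names are illustrative.

```python
import numpy as np


def laplacian_1d(n_interior: int) -> np.ndarray:
    """P1 stiffness matrix L_ij = int phi_i' phi_j' dx on a uniform mesh of (0, 1)
    with n_interior interior nodes and homogeneous Dirichlet conditions."""
    h = 1.0 / (n_interior + 1)
    main = 2.0 * np.ones(n_interior)
    off = -np.ones(n_interior - 1)
    return (np.diag(main) + np.diag(off, 1) + np.diag(off, -1)) / h


def dual_norm(r: np.ndarray, L: np.ndarray) -> float:
    """Dual norm sqrt(r^T L^{-1} r) via the discrete Riesz representer y = L^{-1} r."""
    y = np.linalg.solve(L, r)  # apply the Riesz map
    return float(np.sqrt(r @ y))


L = laplacian_1d(9)
w = np.sin(np.pi * np.linspace(0.1, 0.9, 9))  # nodal values of a sample function
# The dual norm of the functional L w equals the energy norm sqrt(w^T L w) of w
print(dual_norm(L @ w, L), np.sqrt(w @ L @ w))
```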

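We can now proceed with the implementation of the Taylor test. We start by assembling the matrix \(\mathsf{L}\) of the Laplace operator and setting up a direct solver that applies its inverse, that is, the Riesz map.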

# Matrix L of the Laplace operator; L^{-1} represents the Riesz map in (10)
L_form = fem.form(ufl.inner(ufl.grad(u_hat), ufl.grad(v)) * ufl.dx)
L = fem.petsc.assemble_matrix(L_form, bcs=bcs)
L.assemble()
# Direct solver (a single LU application) computing y = L^{-1} r
Riesz_solver = PETSc.KSP().create(domain.comm)
Riesz_solver.setType("preonly")
Riesz_solver.getPC().setType("lu")
Riesz_solver.setOperators(L)
y = fem.Function(V, name="Riesz_representer_of_r")  # r - a Taylor remainder

Next, we reset the main variables of the plasticity problem.

# Reset main variables to zero including the external operators values
sigma_n.x.array[:] = 0.0
sigma.ref_coefficient.x.array[:] = 0.0
J_external_operators[0].ref_coefficient.x.array[:] = 0.0
# Reset the values of the consistent tangent matrix to elastic moduli
Du.x.array[:] = 1.0
evaluated_operands = evaluate_operands(F_external_operators)
_ = evaluate_external_operators(J_external_operators, evaluated_operands)
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -2.2109628558533108
		Maximum residual: 0.0

Since the derivatives of the constitutive model differ between the elastic and plastic phases, we consider two initial states for the Taylor test. To this end, we first solve the problem for a single load value, which brings the state close to the onset of plastic deformation while keeping it in the elastic phase.

i = 0
load = 2.0
q.value = load * np.array([0, -gamma])
Du.x.array[:] = 1e-8

if MPI.COMM_WORLD.rank == 0:
    print(f"Load increment #{i}, load: {load}")

solver.solve(Du)

u.x.petsc_vec.axpy(1.0, Du.x.petsc_vec)
u.x.scatter_forward()

sigma_n.x.array[:] = sigma.ref_coefficient.x.array

Du0 = np.copy(Du.x.array)
sigma_n0 = np.copy(sigma_n.x.array)
Load increment #0, load: 2.0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -2.210962855861672
		Maximum residual: 0.0
  0 SNES Function norm 5.506494297258e-02
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -0.754479660018216
		Maximum residual: 0.0
  1 SNES Function norm 2.938553015478e-14

If we start from the initial stress state sigma_n0 computed in the cell above, the Taylor test is performed in the plastic phase; with a zero initial stress, it stays in the elastic one.

Finally, we define the function perform_Taylor_test, which returns the dual norms (10) of the Taylor remainders (11) and (12).

k_list = np.logspace(-2.0, -6.0, 5)[::-1]


def perform_Taylor_test(Du0, sigma_n0):
    # r0 = F(Du0 + k*δu) - F(Du0)
    # r1 = F(Du0 + k*δu) - F(Du0) - k*J(Du0)*δu
    Du.x.array[:] = Du0
    sigma_n.x.array[:] = sigma_n0
    evaluated_operands = evaluate_operands(F_external_operators)
    ((_, sigma_new),) = evaluate_external_operators(J_external_operators, evaluated_operands)
    sigma.ref_coefficient.x.array[:] = sigma_new

    F0 = fem.petsc.assemble_vector(F_form)  # F(Du0)
    F0.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
    fem.set_bc(F0, bcs)

    J0 = fem.petsc.assemble_matrix(J_form, bcs=bcs)
    J0.assemble()  # J(Du0)
    Ju = J0.createVecLeft()  # Ju = J0 @ u

    δu = fem.Function(V)
    δu.x.array[:] = Du0  # δu == Du0

    zero_order_remainder = np.zeros_like(k_list)
    first_order_remainder = np.zeros_like(k_list)

    for i, k in enumerate(k_list):
        Du.x.array[:] = Du0 + k * δu.x.array
        evaluated_operands = evaluate_operands(F_external_operators)
        ((_, sigma_new),) = evaluate_external_operators(J_external_operators, evaluated_operands)
        sigma.ref_coefficient.x.array[:] = sigma_new

        F_delta = fem.petsc.assemble_vector(F_form)  # F(Du0 + k*δu)
        F_delta.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
        fem.set_bc(F_delta, bcs)

        J0.mult(δu.x.petsc_vec, Ju)  # Ju = J(Du0)*δu
        Ju.scale(k)  # Ju = k*Ju

        r0 = F_delta - F0
        r1 = F_delta - F0 - Ju

        Riesz_solver.solve(r0, y.x.petsc_vec)  # y = L^{-1} r0
        y.x.scatter_forward()
        zero_order_remainder[i] = np.sqrt(r0.dot(y.x.petsc_vec))  # sqrt{r0^T L^{-1} r0}

        Riesz_solver.solve(r1, y.x.petsc_vec)  # y = L^{-1} r1
        y.x.scatter_forward()
        first_order_remainder[i] = np.sqrt(r1.dot(y.x.petsc_vec))  # sqrt{r1^T L^{-1} r1}

    return zero_order_remainder, first_order_remainder


print("Elastic phase")
zero_order_remainder_elastic, first_order_remainder_elastic = perform_Taylor_test(Du0, 0.0)
print("Plastic phase")
zero_order_remainder_plastic, first_order_remainder_plastic = perform_Taylor_test(Du0, sigma_n0)
Elastic phase
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -0.754479660018216
		Maximum residual: 0.0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -0.7544777931845448
		Maximum residual: 0.0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -0.7544609916686915
		Maximum residual: 0.0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -0.7542929752409684
		Maximum residual: 0.0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -0.7526126841444585
		Maximum residual: 0.0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [3750]
		Maximum f: -0.7357971890466226
		Maximum residual: 0.0
Plastic phase
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 1.1914324503846294
		Maximum residual: 5.03712891275251e-10
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 1.191434439616415
		Maximum residual: 5.037156356780114e-10
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 1.1914523427045762
		Maximum residual: 5.037359065504066e-10
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 1.1916313737948552
		Maximum residual: 5.039395901812625e-10
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 1.1934217055527925
		Maximum residual: 5.059807641206864e-10
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [3748    2]
		Maximum f: 1.211327098969814
		Maximum residual: 5.266651628781508e-10
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].loglog(k_list, zero_order_remainder_elastic, "o-", label=r"$\|r_k^0\|_{V^\prime}$")
axs[0].loglog(k_list, first_order_remainder_elastic, "o-", label=r"$\|r_k^1\|_{V^\prime}$")
annotation.slope_marker((2e-4, 5e-5), 1, ax=axs[0], poly_kwargs={"facecolor": "tab:blue"})
axs[0].text(0.5, -0.2, "(a) Elastic phase", transform=axs[0].transAxes, ha="center", va="top")

axs[1].loglog(k_list, zero_order_remainder_plastic, "o-", label=r"$\|r_k^0\|_{V^\prime}$")
annotation.slope_marker((2e-4, 5e-5), 1, ax=axs[1], poly_kwargs={"facecolor": "tab:blue"})
axs[1].loglog(k_list, first_order_remainder_plastic, "o-", label=r"$\|r_k^1\|_{V^\prime}$")
annotation.slope_marker((2e-4, 5e-13), 2, ax=axs[1], poly_kwargs={"facecolor": "tab:orange"})
axs[1].text(0.5, -0.2, "(b) Plastic phase", transform=axs[1].transAxes, ha="center", va="top")

for i in range(2):
    axs[i].set_xlabel("k")
    axs[i].set_ylabel("Taylor remainder norm")
    axs[i].legend()
    axs[i].grid()

plt.tight_layout()
plt.show()

first_order_rate = np.polyfit(np.log(k_list), np.log(zero_order_remainder_elastic), 1)[0]
second_order_rate = np.polyfit(np.log(k_list), np.log(first_order_remainder_elastic), 1)[0]
print(f"Elastic phase:\n\tthe 1st order rate = {first_order_rate:.2f}\n\tthe 2nd order rate = {second_order_rate:.2f}")
first_order_rate = np.polyfit(np.log(k_list), np.log(zero_order_remainder_plastic), 1)[0]
second_order_rate = np.polyfit(np.log(k_list[1:]), np.log(first_order_remainder_plastic[1:]), 1)[0]
print(f"Plastic phase:\n\tthe 1st order rate = {first_order_rate:.2f}\n\tthe 2nd order rate = {second_order_rate:.2f}")
Figure: Taylor remainder norms \(\|r_k^0\|_{V^\prime}\) and \(\|r_k^1\|_{V^\prime}\) versus \(k\) for (a) the elastic phase and (b) the plastic phase.
Elastic phase:
	the 1st order rate = 1.00
	the 2nd order rate = 0.00
Plastic phase:
	the 1st order rate = 1.00
	the 2nd order rate = 1.93

In the elastic phase (a), the zeroth-order Taylor remainder \(r_k^0\) converges at first order, whereas the first-order remainder \(r_k^1\) stays at the level of machine precision because the Jacobian is constant; its fitted rate of 0.00 therefore carries no information. In the plastic phase (b), the zeroth-order remainder \(r_k^0\) again converges at first order, and the first-order remainder \(r_k^1\) converges at second order (fitted rate 1.93, excluding the smallest \(k\)), as expected for a consistent Jacobian.
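The elastic-phase behaviour can be reproduced with a hypothetical linear residual \(F(u) = Ku - b\). This NumPy toy uses a random SPD matrix K in place of the tutorial's operator and the Euclidean norm in place of the dual norm, for brevity. Because the Jacobian is exactly K, the first-order remainder contains only round-off, so a log-log fit of it is meaningless; the sketch compares it with \(r_k^0\) instead.

```python
import numpy as np

# Hypothetical linear residual F(u) = K u - b: a stand-in for the elastic phase,
# where the Jacobian K does not depend on u.
rng = np.random.default_rng(0)
n = 20
A = rng.standard_normal((n, n))
K = A @ A.T + n * np.eye(n)  # assumed SPD "stiffness" matrix
b = rng.standard_normal(n)


def F(u):
    return K @ u - b


u0 = rng.standard_normal(n)
du = u0.copy()
k_list = np.logspace(-6.0, -2.0, 5)  # same k values as the tutorial

F0 = F(u0)
r0 = np.array([np.linalg.norm(F(u0 + k * du) - F0) for k in k_list])
r1 = np.array([np.linalg.norm(F(u0 + k * du) - F0 - k * (K @ du)) for k in k_list])

rate0 = np.polyfit(np.log(k_list), np.log(r0), 1)[0]
ratio = np.max(r1 / r0)  # r1 holds only round-off, so it is tiny relative to r0
print(f"zero-order rate = {rate0:.2f}, max r1/r0 = {ratio:.1e}")
```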

References

[CL90] W. F. Chen and X. L. Liu. Limit Analysis in Soil Mechanics. Volume 52. Elsevier Science, 1990.

[Kir10] Robert C. Kirby. From Functional Analysis to Iterative Methods. SIAM Review, 52(2):269–293, 2010. doi:10.1137/070706914.