Plasticity of Mohr-Coulomb with apex-smoothing

This tutorial demonstrates how modern automatic algorithmic differentiation (AD) techniques can be used to define a complex constitutive model that would otherwise require extensive differentiation by hand. In particular, we implement the non-associative Mohr-Coulomb plasticity model with apex-smoothing and apply it to a slope stability problem for soil. We use the JAX package to define the constitutive relations, including the differentiation of certain terms, and the FEMExternalOperator class to incorporate this model into a weak formulation within UFL.

This tutorial is based on the limit analysis within a semi-definite programming framework, in which the plasticity model was replaced by the MFront/TFEL implementation of the Mohr-Coulomb elastoplastic model with apex smoothing.

Problem formulation

We solve a slope stability problem for a soil domain \(\Omega\) represented by the rectangle \([0; L] \times [0; H]\), with homogeneous Dirichlet boundary conditions \(\boldsymbol{u} = \boldsymbol{0}\) for the displacement field on the right side \(x = L\) and on the bottom side \(z = 0\). The loading consists of a gravitational body force \(\boldsymbol{q}=[0, -\gamma]^T\), with \(\gamma\) being the soil self-weight. The goal is to find the collapse load \(q_\text{lim}\), for which an analytical solution is known in the case of the standard Mohr-Coulomb model without smoothing, under the plane strain assumption and an associative plastic law [Chen and Liu, 1990]. We follow the same Mandel-Voigt notation as in the von Mises plasticity tutorial.
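As a quick illustration of the Mandel-Voigt convention used throughout, consider 4-component vectors in plane strain with the shear term scaled by \(\sqrt{2}\), as returned by the `epsilon` function below. The following NumPy sketch checks two properties for an arbitrary strain. First, a stiffness matrix of the same form as the tutorial's `C_elas` reproduces Hooke's law. Second, the dot product of Mandel vectors equals the double contraction \(\boldsymbol{\sigma} : \boldsymbol{\varepsilon}\) used in the weak form. The helper `to_mandel` is introduced only for this illustration.

```python
import numpy as np

# Material parameters copied from the tutorial (E in MPa).
E, nu = 6778.0, 0.25
lmbda = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))
mu = E / (2.0 * (1.0 + nu))


def to_mandel(t):
    """Hypothetical helper: symmetric 3x3 plane-strain tensor -> [t_xx, t_yy, t_zz, sqrt(2) t_xy]."""
    return np.array([t[0, 0], t[1, 1], t[2, 2], np.sqrt(2.0) * t[0, 1]])


# Elastic stiffness in Mandel-Voigt form (same layout as C_elas in the tutorial).
C_elas = np.array([
    [lmbda + 2 * mu, lmbda, lmbda, 0.0],
    [lmbda, lmbda + 2 * mu, lmbda, 0.0],
    [lmbda, lmbda, lmbda + 2 * mu, 0.0],
    [0.0, 0.0, 0.0, 2 * mu],
])

# An arbitrary plane-strain strain tensor (eps_zz = 0, no xz/yz shear).
eps = np.array([[1.0e-3, 2.0e-4, 0.0],
                [2.0e-4, -5.0e-4, 0.0],
                [0.0, 0.0, 0.0]])
# Tensor form of Hooke's law: sigma = lambda tr(eps) I + 2 mu eps.
sigma = lmbda * np.trace(eps) * np.eye(3) + 2.0 * mu * eps

sigma_m, eps_m = to_mandel(sigma), to_mandel(eps)
# The matrix-vector product reproduces the tensor law ...
print(np.allclose(C_elas @ eps_m, sigma_m))
# ... and the Mandel dot product equals the full contraction sigma_ij eps_ij.
print(np.isclose(sigma_m @ eps_m, np.tensordot(sigma, eps)))
```

The \(\sqrt{2}\) factor on the shear component is what allows a plain dot product to replace the tensor contraction.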

Let \(V\) be the function space of admissible displacement fields. The weak formulation of the problem then reads:

Find \(\boldsymbol{u} \in V\) such that

\[ F(\boldsymbol{u}; \boldsymbol{v}) = \int\limits_\Omega \boldsymbol{\sigma}(\boldsymbol{u}) \cdot \boldsymbol{\varepsilon}(\boldsymbol{v}) \, \mathrm{d}\boldsymbol{x} - \int\limits_\Omega \boldsymbol{q} \cdot \boldsymbol{v} \, \mathrm{d}\boldsymbol{x} = \boldsymbol{0}, \quad \forall \boldsymbol{v} \in V, \]

where \(\boldsymbol{\sigma}\) is an external operator representing the stress tensor.

Note

Although this tutorial implements the Mohr-Coulomb model, the approach is general enough to be adapted to a wide range of plasticity models that can be defined through a yield surface and a plastic potential.

Implementation

Preamble

from mpi4py import MPI
from petsc4py import PETSc

import jax
import jax.lax
import jax.numpy as jnp
import matplotlib.cm as cm
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from mpltools import annotation  # for slope markers
from solvers import LinearProblem
from utilities import find_cell_by_point

import basix
import ufl
from dolfinx import default_scalar_type, fem, mesh
from dolfinx_external_operator import (
    FEMExternalOperator,
    evaluate_external_operators,
    evaluate_operands,
    replace_external_operators,
)

jax.config.update("jax_enable_x64", True)

Here we define geometrical and material parameters of the problem as well as some useful constants.

E = 6778  # [MPa] Young modulus
nu = 0.25  # [-] Poisson ratio
c = 3.45  # [MPa] cohesion
phi = 30 * np.pi / 180  # [rad] friction angle
psi = 30 * np.pi / 180  # [rad] dilatancy angle
theta_T = 26 * np.pi / 180  # [rad] transition angle as defined by Abbo and Sloan
a = 0.26 * c / np.tan(phi)  # [MPa] tension cut-off parameter
L, H = (1.2, 1.0)
Nx, Ny = (50, 50)
gamma = 1.0
domain = mesh.create_rectangle(MPI.COMM_WORLD, [np.array([0, 0]), np.array([L, H])], [Nx, Ny])
k_u = 2
gdim = domain.topology.dim
V = fem.functionspace(domain, ("Lagrange", k_u, (gdim,)))


# Boundary conditions
def on_right(x):
    return np.isclose(x[0], L)


def on_bottom(x):
    return np.isclose(x[1], 0.0)


bottom_dofs = fem.locate_dofs_geometrical(V, on_bottom)
right_dofs = fem.locate_dofs_geometrical(V, on_right)

bcs = [
    fem.dirichletbc(np.array([0.0, 0.0], dtype=PETSc.ScalarType), bottom_dofs, V),
    fem.dirichletbc(np.array([0.0, 0.0], dtype=PETSc.ScalarType), right_dofs, V),
]


def epsilon(v):
    grad_v = ufl.grad(v)
    return ufl.as_vector(
        [
            grad_v[0, 0],
            grad_v[1, 1],
            0.0,
            np.sqrt(2.0) * 0.5 * (grad_v[0, 1] + grad_v[1, 0]),
        ]
    )


k_stress = 2 * (k_u - 1)

dx = ufl.Measure(
    "dx",
    domain=domain,
    metadata={"quadrature_degree": k_stress, "quadrature_scheme": "default"},
)

stress_dim = 2 * gdim
S_element = basix.ufl.quadrature_element(domain.topology.cell_name(), degree=k_stress, value_shape=(stress_dim,))
S = fem.functionspace(domain, S_element)


Du = fem.Function(V, name="Du")
u = fem.Function(V, name="Total_displacement")
du = fem.Function(V, name="du")
v = ufl.TestFunction(V)

sigma = FEMExternalOperator(epsilon(Du), function_space=S)
sigma_n = fem.Function(S, name="sigma_n")

Defining plasticity model and external operator

The constitutive model of the soil is described by a non-associative plasticity law without hardening, defined by the Mohr-Coulomb yield surface \(f\) and the plastic potential \(g\). Both quantities can be expressed through the following function \(h\):

\[\begin{align*} & h(\boldsymbol{\sigma}, \alpha) = \frac{I_1(\boldsymbol{\sigma})}{3}\sin\alpha + \sqrt{J_2(\boldsymbol{\sigma}) K^2(\theta, \alpha) + a^2(\alpha)\sin^2\alpha} - c\cos\alpha, \\ & f(\boldsymbol{\sigma}) = h(\boldsymbol{\sigma}, \phi), \\ & g(\boldsymbol{\sigma}) = h(\boldsymbol{\sigma}, \psi), \end{align*}\]

where \(\phi\) and \(\psi\) are the friction and dilatancy angles, \(c\) is the cohesion, \(I_1(\boldsymbol{\sigma}) = \mathrm{tr} \boldsymbol{\sigma}\) is the first invariant of the stress tensor, \(J_2(\boldsymbol{\sigma}) = \frac{1}{2}\boldsymbol{s} \cdot \boldsymbol{s}\) is the second invariant of the deviatoric stress \(\boldsymbol{s}\), \(\theta\) is the Lode angle and \(a(\alpha) = a \tan\phi / \tan\alpha\) is the tension cut-off parameter. The expression of the coefficient \(K(\theta, \alpha)\) can be found in the MFront/TFEL implementation of this plasticity model.
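For reference, the remaining quantities can be read directly from the JAX helper functions `theta`, `K`, `A`, `B`, `C` and `a_g` defined below. Here \(\theta\) is the Lode angle and \(J_3 = \det\boldsymbol{s}\) is the third invariant of the deviatoric stress; in the code, the arcsin argument is clipped to \([-1, 1]\). The angle \(\theta_T\) is the transition angle beyond which the Mohr-Coulomb corners are smoothed. This is a transcription of the tutorial's code, not of the MFront/TFEL sources. The scalar coefficients \(A\), \(B\), \(C\) below are unrelated to the stiffness matrix \(\boldsymbol{C}\).

\[\begin{align*} & \theta(\boldsymbol{s}) = \frac{1}{3}\arcsin\left(-\frac{3\sqrt{3}\,J_3(\boldsymbol{s})}{2\,J_2(\boldsymbol{s})^{3/2}}\right), \qquad a(\alpha) = a\,\frac{\tan\phi}{\tan\alpha}, \\ & K(\theta, \alpha) = \begin{cases} \cos\theta - \frac{1}{\sqrt{3}}\sin\alpha\sin\theta, & |\theta| \leq \theta_T, \\ A + B\sin 3\theta + C\sin^2 3\theta, & |\theta| > \theta_T, \end{cases} \end{align*}\]

with \(\varsigma = \operatorname{sgn}\theta\) (taken as \(+1\) for \(\theta = 0\)):

\[\begin{align*} & k_1 = \cos\theta_T - \frac{1}{\sqrt{3}}\sin\alpha\sin\theta_T, \quad k_2 = \varsigma\sin\theta_T + \frac{1}{\sqrt{3}}\sin\alpha\cos\theta_T, \quad k_3 = 18\cos^3 3\theta_T, \\ & C = \frac{-\cos 3\theta_T\, k_1 - 3\varsigma\sin 3\theta_T\, k_2}{k_3}, \quad B = \frac{\varsigma\sin 6\theta_T\, k_1 - 6\cos 6\theta_T\, k_2}{k_3}, \\ & A = \cos\theta_T - \frac{1}{\sqrt{3}}\varsigma\sin\alpha\sin\theta_T - B\varsigma\sin 3\theta_T - C\sin^2 3\theta_T. \end{align*}\]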

During plastic loading, the stress-strain state of the solid must satisfy the following system of nonlinear equations:

\[ \begin{cases} \boldsymbol{r}_{g}(\boldsymbol{\sigma}_{n+1}, \Delta\lambda) = \boldsymbol{\sigma}_{n+1} - \boldsymbol{\sigma}_n - \boldsymbol{C} \cdot \left(\Delta\boldsymbol{\varepsilon} - \Delta\lambda \frac{\mathrm{d} g}{\mathrm{d}\boldsymbol{\sigma}}(\boldsymbol{\sigma}_{n+1})\right) = \boldsymbol{0}, \\ r_f(\boldsymbol{\sigma}_{n+1}) = f(\boldsymbol{\sigma}_{n+1}) = 0, \end{cases} \tag{5} \]

where \(\Delta\) denotes the increment of a quantity between the current loading step \(n\) and the next loading step \(n + 1\), \(\boldsymbol{C}\) is the elastic stiffness matrix and \(\Delta\lambda\) is the increment of the plastic multiplier.

By introducing the residual vector \(\boldsymbol{r} = [\boldsymbol{r}_{g}^T, r_f]^T\) and its argument vector \(\boldsymbol{y}_{n+1} = [\boldsymbol{\sigma}_{n+1}^T, \Delta\lambda]^T\), we obtain the following nonlinear constitutive equation:

\[ \boldsymbol{r}(\boldsymbol{y}_{n+1}) = \boldsymbol{0}. \]

To solve this equation, we apply the Newton method and introduce the local Jacobian of the residual vector \(\boldsymbol{j} := \frac{\mathrm{d} \boldsymbol{r}}{\mathrm{d} \boldsymbol{y}}\). Thus, at each quadrature point in the plastic phase, we solve the following linear system at every Newton iteration \(k\). Here \(\boldsymbol{y}^{k}\) is the current iterate of \(\boldsymbol{y}_{n+1}\) and \(\boldsymbol{t}\) is the Newton correction:

\[ \begin{cases} \boldsymbol{j}(\boldsymbol{y}^{k})\,\boldsymbol{t} = - \boldsymbol{r}(\boldsymbol{y}^{k}), \\ \boldsymbol{y}^{k+1} = \boldsymbol{y}^{k} + \boldsymbol{t}. \end{cases} \]
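The Newton scheme above can be sketched on a toy two-equation residual whose root is \(\boldsymbol{y} = (1, 2)\). This residual is a stand-in chosen for illustration, not the Mohr-Coulomb residual \(\boldsymbol{r}\). The local Jacobian \(\boldsymbol{j}\) is obtained with `jax.jacfwd` instead of by hand:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def r(y):
    """Toy residual (illustrative assumption) with a root at y = (1, 2)."""
    return jnp.array([y[0] ** 2 + y[1] - 3.0,
                      y[0] + y[1] ** 2 - 5.0])


j = jax.jacfwd(r)  # local Jacobian dr/dy by forward-mode AD

y = jnp.array([1.5, 1.5])  # initial guess y^0
res0 = jnp.linalg.norm(r(y))
for k in range(20):
    res = r(y)
    if jnp.linalg.norm(res) / res0 < 1e-12:  # relative stopping criterion
        break
    t = jnp.linalg.solve(j(y), -res)  # j(y^k) t = -r(y^k)
    y = y + t                          # y^{k+1} = y^k + t
print(y, k)
```

In the tutorial, the same loop is written with `jax.lax.while_loop`, so that it can be compiled and differentiated.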

During elastic loading, we consider the trivial system of equations

\[ \begin{cases} \boldsymbol{\sigma}_{n+1} = \boldsymbol{\sigma}_n + \boldsymbol{C} \cdot \Delta\boldsymbol{\varepsilon}, \\ \Delta\lambda = 0. \end{cases} \tag{6} \]

The algorithm solving systems (5) and (6) is called the return-mapping procedure, and its solution defines the return-mapping correction of the stress tensor. By the implementation of the external operator \(\boldsymbol{\sigma}\), we mean the implementation of this algorithmic procedure.
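To make the elastic/plastic switch concrete, here is a hypothetical one-dimensional analogue. It uses perfect plasticity with \(f = |\sigma| - \sigma_y\), associative flow (\(g = f\)) and no hardening, so the plastic system can be solved in closed form. In the Mohr-Coulomb case, the same predictor/corrector structure requires the Newton iterations described above.

```python
import numpy as np


def return_mapping_1d(deps, sigma_n, E=1.0, sigma_y=1.0):
    """Hypothetical 1D return mapping: elastic predictor, plastic corrector.

    Yield function f = |sigma| - sigma_y, dg/dsigma = sign(sigma).
    Returns (sigma_{n+1}, dlambda).
    """
    sigma_trial = sigma_n + E * deps           # elastic predictor, as in system (6)
    f_trial = abs(sigma_trial) - sigma_y
    if f_trial <= 0.0:                         # elastic step: dlambda = 0
        return sigma_trial, 0.0
    # Plastic step, system (5): f(sigma_trial - E dlambda sign) = 0 gives dlambda = f_trial / E.
    dlambda = f_trial / E
    sigma = sigma_trial - E * dlambda * np.sign(sigma_trial)
    return sigma, dlambda


print(return_mapping_1d(0.5, 0.0))   # elastic: stress 0.5, no plastic flow
print(return_mapping_1d(2.0, 0.0))   # plastic: stress returned to 1.0, dlambda = 1.0
print(return_mapping_1d(-3.0, 0.5))  # plastic in compression: stress -1.0, dlambda = 1.5
```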

The automatic differentiation tools of the JAX library are used to compute three distinct derivatives:

  1. \(\frac{\mathrm{d} g}{\mathrm{d}\boldsymbol{\sigma}}\) - derivative of the plastic potential \(g\),

  2. \(j = \frac{\mathrm{d} \boldsymbol{r}}{\mathrm{d} \boldsymbol{y}}\) - derivative of the local residual \(\boldsymbol{r}\),

  3. \(\boldsymbol{C}_\text{tang} = \frac{\mathrm{d}\boldsymbol{\sigma}}{\mathrm{d}\boldsymbol{\varepsilon}}\) - derivative of the stress tensor, i.e. the consistent tangent moduli.
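A minimal check of what `jax.jacfwd` returns uses a toy quantity chosen for illustration: the second invariant \(J_2(\mathrm{dev}\,\boldsymbol{\sigma})\) in Mandel-Voigt notation. Its derivative with respect to \(\boldsymbol{\sigma}\) can be derived by hand as \(\mathrm{dev}\,\boldsymbol{\sigma}\), because the deviatoric projector is symmetric and idempotent. Forward-mode AD reproduces this result. In the tutorial, \(\frac{\mathrm{d} g}{\mathrm{d}\boldsymbol{\sigma}}\) is obtained with the same call applied to the full potential \(g\).

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Deviatoric projector in 4-component Mandel-Voigt notation (as in the tutorial).
dev = jnp.array([
    [2 / 3, -1 / 3, -1 / 3, 0.0],
    [-1 / 3, 2 / 3, -1 / 3, 0.0],
    [-1 / 3, -1 / 3, 2 / 3, 0.0],
    [0.0, 0.0, 0.0, 1.0],
])


def J2_of_sigma(sigma):
    s = dev @ sigma
    return 0.5 * jnp.vdot(s, s)


dJ2 = jax.jacfwd(J2_of_sigma)  # same call pattern as dgdsigma = jax.jacfwd(g)

sigma = jnp.array([1.0, -2.0, 0.5, 0.3])
# Hand derivation: dJ2/dsigma = dev^T dev sigma = dev sigma.
print(jnp.allclose(dJ2(sigma), dev @ sigma))
```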

Defining yield surface and plastic potential

First, we define auxiliary functions that help express the yield surface \(f\) and the plastic potential \(g\). In the following definitions, we use built-in JAX functions, in particular the conditional primitive jax.lax.cond. It is needed for the AD tool and just-in-time compilation to work correctly, because a Python if statement cannot branch on a traced (abstract) value. For more details, see the JAX documentation.
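The difference between a Python `if` and `jax.lax.cond` shows up as soon as a function is compiled. The sketch below uses illustrative functions that are not part of the tutorial. Under `jax.jit`, the argument is an abstract tracer, so Python branching fails. `jax.lax.cond`, in contrast, compiles and can be differentiated through the selected branch.

```python
import jax
import jax.numpy as jnp


def sign_python(x):
    # Python branching needs a concrete bool; under jax.jit, x is an abstract tracer.
    return -1.0 if x < 0.0 else 1.0


def sign_lax(x):
    # jax.lax.cond traces both branches and selects one at run time.
    return jax.lax.cond(x < 0.0, lambda x: -1.0, lambda x: 1.0, x)


def abs_lax(x):
    # AD differentiates through whichever branch is selected.
    return jax.lax.cond(x < 0.0, lambda x: -x, lambda x: x, x)


print(jax.jit(sign_lax)(-3.0), jax.jit(sign_lax)(2.0))  # -1.0 and 1.0
print(jax.grad(abs_lax)(-2.0))                           # -1.0

try:
    jax.jit(sign_python)(-3.0)
    python_if_failed = False
except jax.errors.ConcretizationTypeError:  # recent JAX raises the subclass TracerBoolConversionError
    python_if_failed = True
print(python_if_failed)
```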

def J3(s):
    return s[2] * (s[0] * s[1] - s[3] * s[3] / 2.0)


def J2(s):
    return 0.5 * jnp.vdot(s, s)


def theta(s):
    J2_ = J2(s)
    arg = -(3.0 * np.sqrt(3.0) * J3(s)) / (2.0 * jnp.sqrt(J2_ * J2_ * J2_))
    arg = jnp.clip(arg, -1.0, 1.0)
    theta = 1.0 / 3.0 * jnp.arcsin(arg)
    return theta


def sign(x):
    return jax.lax.cond(x < 0.0, lambda x: -1, lambda x: 1, x)


def coeff1(theta, angle):
    return np.cos(theta_T) - (1.0 / np.sqrt(3.0)) * np.sin(angle) * np.sin(theta_T)


def coeff2(theta, angle):
    return sign(theta) * np.sin(theta_T) + (1.0 / np.sqrt(3.0)) * np.sin(angle) * np.cos(theta_T)


coeff3 = 18.0 * np.cos(3.0 * theta_T) * np.cos(3.0 * theta_T) * np.cos(3.0 * theta_T)


def C(theta, angle):
    return (
        -np.cos(3.0 * theta_T) * coeff1(theta, angle) - 3.0 * sign(theta) * np.sin(3.0 * theta_T) * coeff2(theta, angle)
    ) / coeff3


def B(theta, angle):
    return (
        sign(theta) * np.sin(6.0 * theta_T) * coeff1(theta, angle) - 6.0 * np.cos(6.0 * theta_T) * coeff2(theta, angle)
    ) / coeff3


def A(theta, angle):
    return (
        -(1.0 / np.sqrt(3.0)) * np.sin(angle) * sign(theta) * np.sin(theta_T)
        - B(theta, angle) * sign(theta) * np.sin(3 * theta_T)
        - C(theta, angle) * np.sin(3.0 * theta_T) * np.sin(3.0 * theta_T)
        + np.cos(theta_T)
    )


def K(theta, angle):
    def K_false(theta):
        return jnp.cos(theta) - (1.0 / np.sqrt(3.0)) * np.sin(angle) * jnp.sin(theta)

    def K_true(theta):
        return (
            A(theta, angle)
            + B(theta, angle) * jnp.sin(3.0 * theta)
            + C(theta, angle) * jnp.sin(3.0 * theta) * jnp.sin(3.0 * theta)
        )

    return jax.lax.cond(jnp.abs(theta) > theta_T, K_true, K_false, theta)


def a_g(angle):
    return a * np.tan(phi) / np.tan(angle)


dev = np.array(
    [
        [2.0 / 3.0, -1.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, 2.0 / 3.0, -1.0 / 3.0, 0.0],
        [-1.0 / 3.0, -1.0 / 3.0, 2.0 / 3.0, 0.0],
        [0.0, 0.0, 0.0, 1.0],
    ],
    dtype=PETSc.ScalarType,
)
tr = np.array([1.0, 1.0, 1.0, 0.0], dtype=PETSc.ScalarType)


def surface(sigma_local, angle):
    s = dev @ sigma_local
    I1 = tr @ sigma_local
    theta_ = theta(s)
    return (
        (I1 / 3.0 * np.sin(angle))
        + jnp.sqrt(
            J2(s) * K(theta_, angle) * K(theta_, angle) + a_g(angle) * a_g(angle) * np.sin(angle) * np.sin(angle)
        )
        - c * np.cos(angle)
    )
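As a sanity check of the Lode-angle convention used by `theta`, the sketch below re-implements `J2`, `J3` and `theta` with NumPy, outside JAX and for illustration only. It then evaluates them for three simple stress states. In this convention, uniaxial tension and compression land on \(\theta = \mp\pi/6\), where the unsmoothed Mohr-Coulomb surface has its corners, and pure shear gives \(\theta = 0\). Since \(\theta_T = 26^\circ < 30^\circ\), states close to these corners fall in the smoothed branch \(|\theta| > \theta_T\) of `K`.

```python
import numpy as np

# Deviatoric projector and invariants in Mandel-Voigt notation, mirroring the
# JAX helpers J2, J3 and theta above (NumPy version for illustration only).
dev = np.array([
    [2 / 3, -1 / 3, -1 / 3, 0.0],
    [-1 / 3, 2 / 3, -1 / 3, 0.0],
    [-1 / 3, -1 / 3, 2 / 3, 0.0],
    [0.0, 0.0, 0.0, 1.0],
])


def J2(s):
    return 0.5 * np.dot(s, s)


def J3(s):
    # det(s) for a plane-strain deviator; s[3] stores sqrt(2) * s_xy.
    return s[2] * (s[0] * s[1] - s[3] * s[3] / 2.0)


def theta(s):
    arg = -(3.0 * np.sqrt(3.0) * J3(s)) / (2.0 * np.sqrt(J2(s) ** 3))
    return np.arcsin(np.clip(arg, -1.0, 1.0)) / 3.0


states = {
    "uniaxial tension": np.array([1.0, 0.0, 0.0, 0.0]),
    "uniaxial compression": np.array([-1.0, 0.0, 0.0, 0.0]),
    "pure shear": np.array([1.0, -1.0, 0.0, 0.0]),
}
for name, sigma in states.items():
    print(f"{name}: theta = {np.degrees(theta(dev @ sigma)):.1f} deg")
```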

By choosing the appropriate angle, we define the yield surface \(f\) and the plastic potential \(g\).

def f(sigma_local):
    return surface(sigma_local, phi)


def g(sigma_local):
    return surface(sigma_local, psi)


dgdsigma = jax.jacfwd(g)

Solving constitutive equations

In this section, we define the constitutive model by solving systems (5) and (6). They must be solved at each Gauss point, so we apply the Newton method, implement the whole algorithm locally and then vectorize the final result using jax.vmap.

In the following cell, we locally define the residual \(\boldsymbol{r}\) and its Jacobian drdy.

lmbda = E * nu / ((1.0 + nu) * (1.0 - 2.0 * nu))
mu = E / (2.0 * (1.0 + nu))
C_elas = np.array(
    [
        [lmbda + 2 * mu, lmbda, lmbda, 0],
        [lmbda, lmbda + 2 * mu, lmbda, 0],
        [lmbda, lmbda, lmbda + 2 * mu, 0],
        [0, 0, 0, 2 * mu],
    ],
    dtype=PETSc.ScalarType,
)
S_elas = np.linalg.inv(C_elas)
ZERO_VECTOR = np.zeros(stress_dim, dtype=PETSc.ScalarType)


def deps_p(sigma_local, dlambda, deps_local, sigma_n_local):
    sigma_elas_local = sigma_n_local + C_elas @ deps_local
    yielding = f(sigma_elas_local)

    def deps_p_elastic(sigma_local, dlambda):
        return ZERO_VECTOR

    def deps_p_plastic(sigma_local, dlambda):
        return dlambda * dgdsigma(sigma_local)

    return jax.lax.cond(yielding <= 0.0, deps_p_elastic, deps_p_plastic, sigma_local, dlambda)


def r_g(sigma_local, dlambda, deps_local, sigma_n_local):
    deps_p_local = deps_p(sigma_local, dlambda, deps_local, sigma_n_local)
    return sigma_local - sigma_n_local - C_elas @ (deps_local - deps_p_local)


def r_f(sigma_local, dlambda, deps_local, sigma_n_local):
    sigma_elas_local = sigma_n_local + C_elas @ deps_local
    yielding = f(sigma_elas_local)

    def r_f_elastic(sigma_local, dlambda):
        return dlambda

    def r_f_plastic(sigma_local, dlambda):
        return f(sigma_local)

    return jax.lax.cond(yielding <= 0.0, r_f_elastic, r_f_plastic, sigma_local, dlambda)


def r(y_local, deps_local, sigma_n_local):
    sigma_local = y_local[:stress_dim]
    dlambda_local = y_local[-1]

    res_g = r_g(sigma_local, dlambda_local, deps_local, sigma_n_local)
    res_f = r_f(sigma_local, dlambda_local, deps_local, sigma_n_local)

    res = jnp.c_["0,1,-1", res_g, res_f]  # concatenates an array and a scalar
    return res


drdy = jax.jacfwd(r)

Then we define the function return_mapping, which implements the return-mapping algorithm numerically via the Newton method.

Nitermax, tol = 200, 1e-10

ZERO_SCALAR = np.array([0.0])


def return_mapping(deps_local, sigma_n_local):
    """Performs the return-mapping procedure.

    It solves the elastoplastic constitutive equations numerically by applying
    the Newton method at a single Gauss point. The Newton loop is implemented
    via `jax.lax.while_loop`.

    The function returns `sigma_local` twice so that its values can be reused
    after differentiation: once we apply
    `jax.jacfwd(return_mapping, has_aux=True)`, the resulting function returns
    `(C_tang_local, (sigma_local, niter_total, yielding, norm_res, dlambda))`.

    Returns:
        sigma_local: The stress at the current Gauss point.
        niter_total: The total number of iterations.
        yielding: The value of the yield function at the elastic trial stress.
        norm_res: The norm of the residuals.
        dlambda: The value of the plastic multiplier.
    """
    niter = 0

    dlambda = ZERO_SCALAR
    sigma_local = sigma_n_local
    y_local = jnp.concatenate([sigma_local, dlambda])

    res = r(y_local, deps_local, sigma_n_local)
    norm_res0 = jnp.linalg.norm(res)

    def cond_fun(state):
        norm_res, niter, _ = state
        return jnp.logical_and(norm_res / norm_res0 > tol, niter < Nitermax)

    def body_fun(state):
        norm_res, niter, history = state

        y_local, deps_local, sigma_n_local, res = history

        j = drdy(y_local, deps_local, sigma_n_local)
        j_inv_vp = jnp.linalg.solve(j, -res)
        y_local = y_local + j_inv_vp

        res = r(y_local, deps_local, sigma_n_local)
        norm_res = jnp.linalg.norm(res)
        history = y_local, deps_local, sigma_n_local, res

        niter += 1

        return (norm_res, niter, history)

    history = (y_local, deps_local, sigma_n_local, res)

    norm_res, niter_total, history = jax.lax.while_loop(cond_fun, body_fun, (norm_res0, niter, history))

    y_local = history[0]
    sigma_local = y_local[:stress_dim]
    dlambda = y_local[-1]
    sigma_elas_local = C_elas @ deps_local
    yielding = f(sigma_n_local + sigma_elas_local)

    return sigma_local, (sigma_local, niter_total, yielding, norm_res, dlambda)

Consistent tangent stiffness matrix

Automatic differentiation can differentiate not only a mathematical expression but also a numerical algorithm. For instance, AD can compute the derivative of the output of the return-mapping function, the stress tensor \(\boldsymbol{\sigma}\), with respect to its input, the strain increment. This feature is very useful for the consistent tangent moduli \(\boldsymbol{C}_\text{tang}\), as there is no need to write an additional program computing the stress derivative.

JAX's AD tool can differentiate the function return_mapping, which is essentially a while loop. jax.jacfwd computes the Jacobian of its first output with respect to its first argument (deps_local), while the remaining outputs are passed through as auxiliary data. Thus, dsigma_ddeps returns both the consistent tangent moduli and the stress tensor, so no separate computation of the stress tensor is needed.

dsigma_ddeps = jax.jacfwd(return_mapping, has_aux=True)
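The same pattern can be tried on a toy problem chosen for illustration: a Newton solve of \(y^3 = x\) inside `jax.lax.while_loop`, returning the root twice like `return_mapping`. Forward-mode AD through the loop recovers the analytic derivative \(\mathrm{d}y/\mathrm{d}x = 1/(3y^2)\). The iteration count is carried along as auxiliary data. Forward mode is used on purpose, since JAX's reverse mode (`jax.grad`, `jax.jacrev`) does not support `lax.while_loop`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def cube_root(x):
    """Toy solver (illustrative assumption): y**3 - x = 0 by Newton in a while_loop.

    The root is returned twice, mimicking return_mapping: the first output is
    differentiated, the second (with the iteration count) is auxiliary data.
    """
    tol, max_iter = 1e-12, 50

    def cond_fun(state):
        y, niter = state
        not_converged = jnp.abs(y**3 - x) > tol * jnp.abs(x)  # relative residual
        return jnp.logical_and(not_converged, niter < max_iter)

    def body_fun(state):
        y, niter = state
        y = y - (y**3 - x) / (3.0 * y**2)  # Newton update
        return y, niter + 1

    y, niter = jax.lax.while_loop(cond_fun, body_fun, (jnp.ones_like(x), 0))
    return y, (y, niter)


# Forward-mode AD through the while loop; auxiliary outputs are passed through.
dcube_root = jax.jacfwd(cube_root, has_aux=True)
dydx, (y, niter) = dcube_root(8.0)
print(y, dydx, niter)  # dydx should match 1 / (3 y**2) = 1/12
```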

Defining external operator

Once we have defined the function dsigma_ddeps, which evaluates both the external operator and its derivative locally, we can simply vectorize it and define the final implementation of the external operator derivative.

Note

The function dsigma_ddeps, which contains a while_loop, is designed to be called at a single Gauss point, which is why we need to vectorize it over all points of the functional space S. For this purpose, we use the vmap function of JAX. It creates another while_loop, which terminates only when all mapped loops terminate. Find further details in this discussion.
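This behaviour can be seen on a toy loop, chosen for illustration, whose trip count differs per input. The batched loop keeps running until every lane's condition is false. Lanes that have already finished are left unchanged, so each lane still reports its own iteration count. This is consistent with the inner Newton summary listing different iteration counts for different Gauss points.

```python
import jax
import jax.numpy as jnp


def halvings(x):
    """Toy per-point loop (illustrative): halve x until it drops below 1."""

    def cond_fun(state):
        x, n = state
        return x >= 1.0

    def body_fun(state):
        x, n = state
        return x / 2.0, n + 1

    return jax.lax.while_loop(cond_fun, body_fun, (x, 0))


# One batched while_loop runs 7 times here; finished lanes keep their values.
x_final, n = jax.vmap(halvings)(jnp.array([0.5, 3.0, 100.0]))
print(n)        # per-lane iteration counts
print(x_final)  # per-lane final values
```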

dsigma_ddeps_vec = jax.jit(jax.vmap(dsigma_ddeps, in_axes=(0, 0)))


def C_tang_impl(deps):
    deps_ = deps.reshape((-1, stress_dim))
    sigma_n_ = sigma_n.x.array.reshape((-1, stress_dim))

    (C_tang_global, state) = dsigma_ddeps_vec(deps_, sigma_n_)
    sigma_global, niter, yielding, norm_res, dlambda = state

    unique_iters, counts = jnp.unique(niter, return_counts=True)

    print("\tInner Newton summary:")
    print(f"\t\tUnique number of iterations: {unique_iters}")
    print(f"\t\tCounts of unique number of iterations: {counts}")
    print(f"\t\tMaximum f: {jnp.max(yielding)}")
    print(f"\t\tMaximum residual: {jnp.max(norm_res)}")

    return C_tang_global.reshape(-1), sigma_global.reshape(-1)

Similarly to the von Mises example, we do not explicitly implement the evaluation of the external operator. Instead, we obtain its values while evaluating its derivative and then update the values of the operator in the main Newton loop.

def sigma_external(derivatives):
    if derivatives == (1,):
        return C_tang_impl
    else:
        raise NotImplementedError


sigma.external_function = sigma_external

Defining the forms

q = fem.Constant(domain, default_scalar_type((0, -gamma)))


def F_ext(v):
    return ufl.dot(q, v) * dx


u_hat = ufl.TrialFunction(V)
F = ufl.inner(epsilon(v), sigma) * dx - F_ext(v)
J = ufl.derivative(F, Du, u_hat)
J_expanded = ufl.algorithms.expand_derivatives(J)

F_replaced, F_external_operators = replace_external_operators(F)
J_replaced, J_external_operators = replace_external_operators(J_expanded)

F_form = fem.form(F_replaced)
J_form = fem.form(J_replaced)

Variables initialization and compilation

Before solving the problem, we have to initialize the values of the consistent tangent stiffness matrix, as they are required for assembling the system. During the first loading step we expect a purely elastic response, so it is enough to evaluate the constitutive equations at each Gauss point for a strain-free displacement field. The uniform field Du = 1 set below is a rigid translation, whose strain vanishes up to round-off. This initializes the consistent tangent moduli with the elastic ones.

Du.x.array[:] = 1.0
sigma_n.x.array[:] = 0.0

evaluated_operands = evaluate_operands(F_external_operators)
_ = evaluate_external_operators(J_external_operators, evaluated_operands)
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [15000]
		Maximum f: -2.2109628558449494
		Maximum residual: 0.0
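The logged value `Maximum f: -2.2109628558449494` can be checked by hand. For a numerically zero stress state, \(I_1 = J_2 = 0\) and \(a(\phi) = a\), so \(f\) reduces to its apex value \(\sqrt{a^2\sin^2\phi} - c\cos\phi\). With \(a = 0.26\,c/\tan\phi\), this equals \(-0.74\,c\cos\phi\). The sketch evaluates this with the tutorial's parameters. The remaining difference of order \(10^{-11}\) from the logged value presumably comes from round-off strains in the uniform displacement field.

```python
import numpy as np

# Parameters as defined in the tutorial.
c = 3.45                    # [MPa] cohesion
phi = 30 * np.pi / 180      # [rad] friction angle
a = 0.26 * c / np.tan(phi)  # [MPa] tension cut-off parameter

# With I1 = J2 = 0 only the apex-smoothing and cohesion terms of h remain.
f_zero_stress = np.sqrt(a**2 * np.sin(phi) ** 2) - c * np.cos(phi)
print(f_zero_stress, -0.74 * c * np.cos(phi))
```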

Solving the problem

Summing up, we apply the Newton method to solve the main weak problem. At each iteration of this outer Newton loop, we solve the elastoplastic constitutive equations at each Gauss point with a second (inner) Newton method. Thanks to the framework and the JAX library, the final interface is general enough to be applied to other plasticity models.
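The outer loop ramps the gravity multiplier `load`, so that \(\boldsymbol{q} = \mathrm{load}\,[0, -\gamma]^T\), using two `linspace` ranges. The first is coarse, and the second is a finer one presumably intended to resolve the approach to collapse. The snippet below reproduces that schedule with NumPy to make the step sizes explicit.

```python
import numpy as np

# Same schedule as in the loading loop: a coarse ramp followed by a finer one.
load_steps_1 = np.linspace(2, 21, 40)
load_steps_2 = np.linspace(21, 22.75, 20)[1:]  # drop 21, already the last coarse step
load_steps = np.concatenate([load_steps_1, load_steps_2])

print(len(load_steps))                                      # number of increments
print(np.diff(load_steps_1)[0], np.diff(load_steps_2)[0])  # about 0.487 and 0.092
```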

external_operator_problem = LinearProblem(J_replaced, -F_replaced, Du, bcs=bcs)
x_point = np.array([[0, H, 0]])
cells, points_on_process = find_cell_by_point(domain, x_point)
# parameters of the manual Newton method
max_iterations, relative_tolerance = 200, 1e-8

load_steps_1 = np.linspace(2, 21, 40)
load_steps_2 = np.linspace(21, 22.75, 20)[1:]
load_steps = np.concatenate([load_steps_1, load_steps_2])
num_increments = len(load_steps)
results = np.zeros((num_increments + 1, 2))
for i, load in enumerate(load_steps):
    q.value = load * np.array([0, -gamma])
    external_operator_problem.assemble_vector()

    residual_0 = external_operator_problem.b.norm()
    residual = residual_0
    Du.x.array[:] = 0

    if MPI.COMM_WORLD.rank == 0:
        print(f"Load increment #{i}, load: {load}, initial residual: {residual_0}")

    for iteration in range(0, max_iterations):
        if residual / residual_0 < relative_tolerance:
            break

        if MPI.COMM_WORLD.rank == 0:
            print(f"\tOuter Newton iteration #{iteration}")
        external_operator_problem.assemble_matrix()
        external_operator_problem.solve(du)

        Du.x.petsc_vec.axpy(1.0, du.x.petsc_vec)
        Du.x.scatter_forward()

        evaluated_operands = evaluate_operands(F_external_operators)
        ((_, sigma_new),) = evaluate_external_operators(J_external_operators, evaluated_operands)
        # Direct access to the external operator values
        sigma.ref_coefficient.x.array[:] = sigma_new

        external_operator_problem.assemble_vector()
        residual = external_operator_problem.b.norm()

        if MPI.COMM_WORLD.rank == 0:
            print(f"\tResidual: {residual}\n")

    u.x.petsc_vec.axpy(1.0, Du.x.petsc_vec)
    u.x.scatter_forward()

    sigma_n.x.array[:] = sigma.ref_coefficient.x.array

    if len(points_on_process) > 0:
        results[i + 1, :] = (-u.eval(points_on_process, cells)[0], load)

print(f"Slope stability factor: {-q.value[-1]*H/c}")
Load increment #0, load: 2.0, initial residual: 0.027573900703382316
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [15000]
		Maximum f: -0.2983932974714918
		Maximum residual: 0.0
	Residual: 7.890919449890783e-14

Load increment #1, load: 2.4871794871794872, initial residual: 0.006716719402111263
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3]
		Counts of unique number of iterations: [14999     1]
		Maximum f: 0.28795356169797115
		Maximum residual: 1.6069532994136943e-15
	Residual: 0.002218734438292365

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 4]
		Counts of unique number of iterations: [14999     1]
		Maximum f: 0.4490497372726332
		Maximum residual: 4.74491267375125e-16
	Residual: 8.901387764604577e-05

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 4]
		Counts of unique number of iterations: [14999     1]
		Maximum f: 0.4526047639992705
		Maximum residual: 5.66795677352112e-16
	Residual: 8.72870007368659e-08

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 4]
		Counts of unique number of iterations: [14999     1]
		Maximum f: 0.4526085284802486
		Maximum residual: 9.23344865209648e-16
	Residual: 8.981949346521205e-14

Load increment #2, load: 2.9743589743589745, initial residual: 0.0067167194021058516
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [14998     1     1]
		Maximum f: 0.959060060256649
		Maximum residual: 1.7759412808265995e-11
	Residual: 0.004802356196088396

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 4]
		Counts of unique number of iterations: [14998     2]
		Maximum f: 1.1348741666216449
		Maximum residual: 4.526797389797801e-16
	Residual: 9.211661510543003e-05

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 4]
		Counts of unique number of iterations: [14998     2]
		Maximum f: 1.1350924181953492
		Maximum residual: 9.278484370861693e-16
	Residual: 1.693176317199017e-08

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 4]
		Counts of unique number of iterations: [14998     2]
		Maximum f: 1.1350931611448787
		Maximum residual: 7.563076050675473e-16
	Residual: 6.291070770855705e-15

Load increment #3, load: 3.4615384615384617, initial residual: 0.0067167194021059504
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 4 5]
		Counts of unique number of iterations: [14996     3     1]
		Maximum f: 1.2094005593143735
		Maximum residual: 2.1599919653272044e-15
	Residual: 0.004577763989326799

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14995     1     3     1]
		Maximum f: 1.3971321181728666
		Maximum residual: 4.523450768197638e-12
	Residual: 0.0008492027216157221

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14995     1     3     1]
		Maximum f: 1.4248006636524875
		Maximum residual: 4.312760339536258e-14
	Residual: 5.8039785871004464e-06

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14995     1     3     1]
		Maximum f: 1.4248654126369007
		Maximum residual: 4.778100226722301e-14
	Residual: 1.7549459842413692e-10

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14995     1     3     1]
		Maximum f: 1.4248654164118402
		Maximum residual: 4.710855903239195e-14
	Residual: 6.115624526774481e-15

Load increment #4, load: 3.948717948717949, initial residual: 0.006716719402106029
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14992     2     4     2]
		Maximum f: 1.6622114419616891
		Maximum residual: 2.284009519320461e-11
	Residual: 0.005991367307375562

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 4 5 6]
		Counts of unique number of iterations: [14993     5     1     1]
		Maximum f: 1.9712452060954653
		Maximum residual: 1.9719683348635102e-14
	Residual: 0.0008411326742961785

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 4 5 6]
		Counts of unique number of iterations: [14993     5     1     1]
		Maximum f: 1.9669201227931743
		Maximum residual: 2.7998841883107362e-14
	Residual: 1.0331906341028766e-06

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 4 5 6]
		Counts of unique number of iterations: [14993     5     1     1]
		Maximum f: 1.9669904350357945
		Maximum residual: 2.542481932384309e-14
	Residual: 6.564572860183445e-12

Load increment #5, load: 4.435897435897436, initial residual: 0.006716719402089759
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14988     3     7     1     1]
		Maximum f: 2.1624454001146183
		Maximum residual: 9.216681243241093e-11
	Residual: 0.0063871311640514685

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14986     2    10     2]
		Maximum f: 2.2145902632058116
		Maximum residual: 9.457400774077234e-11
	Residual: 0.0018311090776529243

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14986     2    10     2]
		Maximum f: 2.247831186097726
		Maximum residual: 7.939428618327782e-11
	Residual: 6.1860511095659675e-06

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14986     2    10     2]
		Maximum f: 2.247847319546333
		Maximum residual: 7.941305609091945e-11
	Residual: 1.1402764366337778e-10

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14986     2    10     2]
		Maximum f: 2.247847321965589
		Maximum residual: 7.94128946420124e-11
	Residual: 6.311161812942645e-15

Load increment #6, load: 4.923076923076923, initial residual: 0.006716719402105907
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14980     4    13     3]
		Maximum f: 2.412340741822826
		Maximum residual: 1.9281690307661182e-11
	Residual: 0.007208784346729728

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14978     4    15     3]
		Maximum f: 2.6373092115923655
		Maximum residual: 3.947193023626819e-11
	Residual: 0.0010890919985447423

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14978     3    17     2]
		Maximum f: 2.657703882653005
		Maximum residual: 7.343964781767663e-11
	Residual: 1.2600498418665459e-05

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14978     3    17     2]
		Maximum f: 2.657821730601326
		Maximum residual: 7.345380461870128e-11
	Residual: 1.2853238681401728e-09

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14978     3    17     2]
		Maximum f: 2.6578217428611954
		Maximum residual: 7.345313605854967e-11
	Residual: 6.215559826041168e-15

Load increment #7, load: 5.410256410256411, initial residual: 0.006716719402105902
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14971     4    20     5]
		Maximum f: 2.8972284361943648
		Maximum residual: 1.5932847856711257e-10
	Residual: 0.007185163108330384

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14970     4    20     6]
		Maximum f: 3.074460151265438
		Maximum residual: 1.4841855612405854e-10
	Residual: 0.003635385508342228

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14969     4    21     6]
		Maximum f: 3.1068353261265877
		Maximum residual: 1.6963349703642361e-10
	Residual: 0.0006614649205114514

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14969     5    20     6]
		Maximum f: 3.111336688258294
		Maximum residual: 1.751536628153783e-10
	Residual: 2.3272859544284286e-06

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14969     4    21     6]
		Maximum f: 3.1113629314934976
		Maximum residual: 1.751781846454831e-10
	Residual: 1.6043351964469255e-11

Load increment #8, load: 5.897435897435898, initial residual: 0.00671671940209461
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14956    16    25     3]
		Maximum f: 3.3170712973542007
		Maximum residual: 1.73745936022647e-10
	Residual: 0.008492114295014659

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14957     9    27     7]
		Maximum f: 3.6710312133149317
		Maximum residual: 1.311813565498907e-10
	Residual: 0.005488364080386103

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14956     9    28     7]
		Maximum f: 3.6914352462087643
		Maximum residual: 2.349449759274743e-10
	Residual: 8.319437470582816e-05

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14956     9    28     7]
		Maximum f: 3.6914214031303962
		Maximum residual: 2.575476364059704e-10
	Residual: 1.404834945506049e-07

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14956     9    28     7]
		Maximum f: 3.6914216019955455
		Maximum residual: 2.575471385257117e-10
	Residual: 1.5064168133037005e-13

Load increment #9, load: 6.384615384615384, initial residual: 0.006716719402106336
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14936    22    35     7]
		Maximum f: 4.045554233015363
		Maximum residual: 3.987995294010693e-11
	Residual: 0.008836812867953047

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14935    21    38     5     1]
		Maximum f: 4.316638571781837
		Maximum residual: 3.0069052121590976e-10
	Residual: 0.010718894285201485

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14934    21    39     5     1]
		Maximum f: 4.355314382789103
		Maximum residual: 3.109200840857651e-10
	Residual: 0.0002441838471887533

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14934    21    39     5     1]
		Maximum f: 4.356975910983897
		Maximum residual: 3.170287487032596e-10
	Residual: 1.3661176406774217e-07

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14934    21    39     5     1]
		Maximum f: 4.35697629965423
		Maximum residual: 3.1703029219510236e-10
	Residual: 5.149272482380523e-14

Load increment #10, load: 6.871794871794871, initial residual: 0.0067167194021057934
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14915    33    44     8]
		Maximum f: 4.738994089491671
		Maximum residual: 2.1753344521763023e-10
	Residual: 0.008959963113172967

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14915    28    47    10]
		Maximum f: 4.994258511340817
		Maximum residual: 7.209303188255585e-11
	Residual: 0.011397174281186211

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14913    30    47    10]
		Maximum f: 5.018175040325893
		Maximum residual: 8.131241815887432e-11
	Residual: 0.00034539497630206546

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14913    30    47    10]
		Maximum f: 5.021313078557187
		Maximum residual: 8.056295300680725e-11
	Residual: 2.2738134262025453e-07

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14913    30    47    10]
		Maximum f: 5.021314344097126
		Maximum residual: 8.056312787730837e-11
	Residual: 1.3677781358889237e-13

Load increment #11, load: 7.358974358974359, initial residual: 0.006716719402106081
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14895     1    39    56     9]
		Maximum f: 5.425401278773759
		Maximum residual: 1.3371143159723552e-10
	Residual: 0.008873442816314499

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14892    39    60     9]
		Maximum f: 5.660279604155246
		Maximum residual: 3.049244742487409e-10
	Residual: 0.0015471905099869799

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14892    39    60     9]
		Maximum f: 5.6778106598755205
		Maximum residual: 3.202075274385731e-10
	Residual: 1.0838415950050226e-05

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14892    39    60     9]
		Maximum f: 5.677921603622375
		Maximum residual: 3.20125841523144e-10
	Residual: 3.9821681987266934e-10

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14892    39    60     9]
		Maximum f: 5.677921607038627
		Maximum residual: 3.2012576635724396e-10
	Residual: 6.768042292143486e-15

Load increment #12, load: 7.846153846153846, initial residual: 0.006716719402105904
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14876    46    70     8]
		Maximum f: 6.049350890460289
		Maximum residual: 6.874724052290696e-10
	Residual: 0.009096351314687754

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14872    45    74     9]
		Maximum f: 6.239529353869511
		Maximum residual: 4.087989408182388e-10
	Residual: 0.0016211998652745385

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14871    44    77     8]
		Maximum f: 6.25076803298351
		Maximum residual: 4.477617981359121e-10
	Residual: 0.0001560318306305226

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14871    44    77     8]
		Maximum f: 6.250264936882839
		Maximum residual: 4.6403048773060224e-10
	Residual: 1.0783414271764713e-07

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14871    44    77     8]
		Maximum f: 6.250266601223057
		Maximum residual: 4.6403632351628526e-10
	Residual: 2.7402163026108637e-14

Load increment #13, load: 8.333333333333332, initial residual: 0.006716719402105762
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14856    54    82     8]
		Maximum f: 6.585637001276497
		Maximum residual: 3.090222282294581e-10
	Residual: 0.009007878917959499

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14855    51    83    11]
		Maximum f: 6.755125612945211
		Maximum residual: 4.649418472520499e-10
	Residual: 0.0012749818287204335

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14853    53    83    11]
		Maximum f: 6.765747785088715
		Maximum residual: 4.664596923508608e-10
	Residual: 0.00030322429329111725

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14853    53    83    11]
		Maximum f: 6.76707137611937
		Maximum residual: 4.6559541898992416e-10
	Residual: 1.489896642179165e-07

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14853    53    83    11]
		Maximum f: 6.767071956213789
		Maximum residual: 4.655959223750781e-10
	Residual: 6.224191646409628e-14

Load increment #14, load: 8.820512820512821, initial residual: 0.006716719402106516
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14840    62    94     4]
		Maximum f: 7.113663502374457
		Maximum residual: 4.240201862913906e-10
	Residual: 0.008041984424249974

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14836    62    98     4]
		Maximum f: 7.25296041059803
		Maximum residual: 4.537294457934902e-10
	Residual: 0.0013241251663798853

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14834    64    98     4]
		Maximum f: 7.257375794857158
		Maximum residual: 6.022015489259446e-10
	Residual: 7.755831879993391e-05

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14834    64    98     4]
		Maximum f: 7.257098271755524
		Maximum residual: 6.121699829916764e-10
	Residual: 2.3409396873000134e-08

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14834    64    98     4]
		Maximum f: 7.257098617903164
		Maximum residual: 6.121838514040751e-10
	Residual: 7.329576817357213e-15

Load increment #15, load: 9.307692307692307, initial residual: 0.00671671940210589
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14823    72   100     5]
		Maximum f: 7.58041627942597
		Maximum residual: 4.3583780272924964e-10
	Residual: 0.006820691030113568

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14820    71   104     5]
		Maximum f: 7.6939920654095495
		Maximum residual: 4.272594756845077e-10
	Residual: 0.0004652039135065601

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14820    71   104     5]
		Maximum f: 7.6958712038617545
		Maximum residual: 4.500866160302736e-10
	Residual: 5.193518789243318e-07

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14820    71   104     5]
		Maximum f: 7.695876907322399
		Maximum residual: 4.5007848603537275e-10
	Residual: 7.659269634626732e-13

Load increment #16, load: 9.794871794871796, initial residual: 0.006716719402108467
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14808    90    97     5]
		Maximum f: 7.99050966817293
		Maximum residual: 5.013465405904206e-10
	Residual: 0.006701736965509444

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14805    87   103     5]
		Maximum f: 8.089550459713337
		Maximum residual: 5.140883706623874e-10
	Residual: 0.0008807439219532116

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14805    88   102     5]
		Maximum f: 8.091632210151511
		Maximum residual: 5.259422770664863e-10
	Residual: 2.8174946872392068e-06

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14805    88   102     5]
		Maximum f: 8.091634648275788
		Maximum residual: 5.259424452988232e-10
	Residual: 3.6270020933263514e-11

Load increment #17, load: 10.282051282051281, initial residual: 0.006716719402117765
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14795   112    85     8]
		Maximum f: 8.358501355606794
		Maximum residual: 5.712738175174098e-10
	Residual: 0.0063306127750015345

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14793   113    86     8]
		Maximum f: 8.437903718526334
		Maximum residual: 5.980822538472252e-10
	Residual: 0.00045101679349716207

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14793   113    86     8]
		Maximum f: 8.44003484251495
		Maximum residual: 5.98215613283621e-10
	Residual: 6.877522470581463e-07

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14793   113    86     8]
		Maximum f: 8.440036786591405
		Maximum residual: 5.982155862537371e-10
	Residual: 1.6053404660618207e-12

Load increment #18, load: 10.769230769230768, initial residual: 0.006716719402110631
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14780   135    78     7]
		Maximum f: 8.68456502127428
		Maximum residual: 5.552947981106763e-10
	Residual: 0.006321576427286919

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14779     1   131    82     7]
		Maximum f: 8.752245478545689
		Maximum residual: 5.661125049159691e-10
	Residual: 0.0006104928700975985

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14778     1   132    82     7]
		Maximum f: 8.7529350521491
		Maximum residual: 5.662446230508161e-10
	Residual: 0.00023235849305208804

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14778     1   132    82     7]
		Maximum f: 8.753156784434015
		Maximum residual: 5.664745903123164e-10
	Residual: 2.9780061029475124e-07

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14778     1   132    82     7]
		Maximum f: 8.753157009998258
		Maximum residual: 5.664766281087768e-10
	Residual: 3.8016844821466933e-13

Load increment #19, load: 11.256410256410255, initial residual: 0.006716719402108031
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14765   151    78     5     1]
		Maximum f: 8.958294424841467
		Maximum residual: 7.238628227290251e-10
	Residual: 0.008132000241176851

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14762   152    80     5     1]
		Maximum f: 8.998791074465823
		Maximum residual: 7.227711446385966e-10
	Residual: 0.001997392238883473

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14762   152    80     5     1]
		Maximum f: 9.005204370557509
		Maximum residual: 7.290349944368096e-10
	Residual: 2.52599622998144e-05

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14762   152    80     5     1]
		Maximum f: 9.005208512201829
		Maximum residual: 7.290424660974518e-10
	Residual: 2.4448542002402695e-09

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14762   152    80     5     1]
		Maximum f: 9.005208513681765
		Maximum residual: 7.290401961831911e-10
	Residual: 7.70071218194023e-15

Load increment #20, load: 11.743589743589743, initial residual: 0.006716719402105775
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14746   175    74     5]
		Maximum f: 9.133896594623621
		Maximum residual: 7.588024190511681e-10
	Residual: 0.008896728090045757

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14745   174    76     5]
		Maximum f: 9.192693190502789
		Maximum residual: 8.056791555890222e-10
	Residual: 0.0038934284784811137

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14745   174    76     5]
		Maximum f: 9.19757041344616
		Maximum residual: 8.358968531175626e-10
	Residual: 3.796847475201592e-05

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14745   174    76     5]
		Maximum f: 9.197584052522377
		Maximum residual: 8.360896363124227e-10
	Residual: 4.2398571608009855e-09

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14745   174    76     5]
		Maximum f: 9.197584055106221
		Maximum residual: 8.360761170549043e-10
	Residual: 7.937564969278366e-15

Load increment #21, load: 12.23076923076923, initial residual: 0.006716719402105997
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14728   199    68     4     1]
		Maximum f: 9.309318026977646
		Maximum residual: 7.745733179621374e-10
	Residual: 0.009216976204473903

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14727   196    71     4     2]
		Maximum f: 9.372210255858127
		Maximum residual: 6.614291898767247e-10
	Residual: 0.005677562280785258

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14728   196    70     4     2]
		Maximum f: 9.375041989619428
		Maximum residual: 6.643537217422253e-10
	Residual: 0.00019877343896580445

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14728   196    70     4     2]
		Maximum f: 9.37511531688497
		Maximum residual: 6.644318148463891e-10
	Residual: 7.633720770281903e-09

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14728   196    70     4     2]
		Maximum f: 9.375115318687154
		Maximum residual: 6.644304468516605e-10
	Residual: 7.98331332399425e-15

Load increment #22, load: 12.717948717948717, initial residual: 0.0067167194021062184
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14703   230    63     4]
		Maximum f: 9.631734472003261
		Maximum residual: 9.004495644739933e-10
	Residual: 0.010103758517920908

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14703   227    66     4]
		Maximum f: 10.266067965825984
		Maximum residual: 6.579958639939529e-10
	Residual: 0.005635378524062774

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14702   227    67     4]
		Maximum f: 10.294537007242775
		Maximum residual: 6.625034045463842e-10
	Residual: 0.0006453068279120127

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14703   226    67     4]
		Maximum f: 10.305089259960004
		Maximum residual: 6.634648277637736e-10
	Residual: 2.3771921489821883e-06

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14703   226    67     4]
		Maximum f: 10.305140011501196
		Maximum residual: 6.634647552212982e-10
	Residual: 3.204867000788119e-12

Load increment #23, load: 13.205128205128204, initial residual: 0.006716719402092674
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14669   269    58     3     1]
		Maximum f: 11.025940505751029
		Maximum residual: 1.0175534866991814e-09
	Residual: 0.010006180949859491

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14666   271    59     3     1]
		Maximum f: 12.212967525109589
		Maximum residual: 4.845631986426588e-10
	Residual: 0.0061385342969422834

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14665   272    59     3     1]
		Maximum f: 12.289198016073618
		Maximum residual: 4.876787289465884e-10
	Residual: 0.00044609546114351896

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14666   271    59     3     1]
		Maximum f: 12.292335568048466
		Maximum residual: 4.879192891750546e-10
	Residual: 2.9181610091638813e-05

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14666   271    59     3     1]
		Maximum f: 12.292564859874892
		Maximum residual: 4.87925770235352e-10
	Residual: 9.410641793846815e-10

	Outer Newton iteration #5
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [14666   271    59     3     1]
		Maximum f: 12.292564858474789
		Maximum residual: 4.879259072840858e-10
	Residual: 8.292064960245494e-15

Load increment #24, load: 13.692307692307692, initial residual: 0.006716719402105967
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14626     1   316    52     5]
		Maximum f: 13.752178237997418
		Maximum residual: 9.009687888385508e-10
	Residual: 0.009288088168082454

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14628     1   311    55     5]
		Maximum f: 14.526046465190406
		Maximum residual: 1.1118740115103616e-09
	Residual: 0.01748781673753544

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14628     1   310    56     5]
		Maximum f: 14.630545455974252
		Maximum residual: 5.60616468208064e-10
	Residual: 0.0011315237281224913

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14628     1   310    56     5]
		Maximum f: 14.644482371673819
		Maximum residual: 5.609460003864369e-10
	Residual: 1.407071390247005e-06

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14628     1   310    56     5]
		Maximum f: 14.644487569973952
		Maximum residual: 5.609465923503495e-10
	Residual: 3.956057118756387e-12

Load increment #25, load: 14.179487179487179, initial residual: 0.006716719402093062
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14585   358    52     5]
		Maximum f: 15.668897835243778
		Maximum residual: 5.453032629910085e-10
	Residual: 0.01020272827916171

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14585   358    53     4]
		Maximum f: 16.897476113500414
		Maximum residual: 9.014246075011788e-10
	Residual: 0.018204880744767896

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14586     1   356    53     4]
		Maximum f: 17.005942469911133
		Maximum residual: 9.296743108920492e-10
	Residual: 0.0017665245952441415

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14586     1   356    53     4]
		Maximum f: 17.010760706455805
		Maximum residual: 9.316588449987665e-10
	Residual: 2.4867469674712053e-06

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14586     1   356    53     4]
		Maximum f: 17.01075783454806
		Maximum residual: 9.316553741272862e-10
	Residual: 1.1558233726990843e-11

Load increment #26, load: 14.666666666666666, initial residual: 0.006716719402090535
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14537   407    54     2]
		Maximum f: 18.392943468550488
		Maximum residual: 4.929875170184392e-10
	Residual: 0.010739219690748508

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14537     1   409    51     2]
		Maximum f: 19.448449832009402
		Maximum residual: 8.562163022279948e-10
	Residual: 0.0217933127661377

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14537     1   408    52     2]
		Maximum f: 19.59381838166779
		Maximum residual: 6.971880626091294e-10
	Residual: 4.8782826483654025e-05

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14537     1   408    52     2]
		Maximum f: 19.593925369807053
		Maximum residual: 6.966564911098194e-10
	Residual: 6.694631074622378e-09

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14537     1   408    52     2]
		Maximum f: 19.593925341451143
		Maximum residual: 6.966587236269051e-10
	Residual: 8.901913864258566e-15

Load increment #27, load: 15.153846153846153, initial residual: 0.006716719402105616
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14480     2   463    53     2]
		Maximum f: 20.96674904419338
		Maximum residual: 5.524816866263788e-10
	Residual: 0.010417054367107544

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14484     3   457    53     3]
		Maximum f: 22.221433151368622
		Maximum residual: 1.2519566051086383e-09
	Residual: 0.0263233526704633

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14481     2   463    51     3]
		Maximum f: 22.41046695631673
		Maximum residual: 1.193688987740655e-09
	Residual: 0.0009187005626425985

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14481     3   463    50     3]
		Maximum f: 22.405318497174452
		Maximum residual: 1.1898570344001849e-09
	Residual: 6.865470701334167e-06

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14481     3   463    50     3]
		Maximum f: 22.405274561032304
		Maximum residual: 1.1898310708047024e-09
	Residual: 2.7254283617801256e-10

	Outer Newton iteration #5
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14481     3   463    50     3]
		Maximum f: 22.405274558653705
		Maximum residual: 1.1898327925901724e-09
	Residual: 9.24943173106498e-15

Load increment #28, load: 15.64102564102564, initial residual: 0.006716719402106811
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5 6]
		Counts of unique number of iterations: [14413     1   534    50     1     1]
		Maximum f: 24.305556090508148
		Maximum residual: 7.245119011666657e-10
	Residual: 0.010984385790507681

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 6]
		Counts of unique number of iterations: [14418     2   525    53     2]
		Maximum f: 25.37444834249622
		Maximum residual: 7.415107687445929e-10
	Residual: 0.02382245428860332

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 6]
		Counts of unique number of iterations: [14415     1   531    51     2]
		Maximum f: 25.524100278581628
		Maximum residual: 7.360521747922634e-10
	Residual: 0.0006321108883439714

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 6]
		Counts of unique number of iterations: [14415     1   531    51     2]
		Maximum f: 25.534937262325332
		Maximum residual: 7.426256324308136e-10
	Residual: 1.7407450784632632e-06

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 6]
		Counts of unique number of iterations: [14415     1   531    51     2]
		Maximum f: 25.53494655259069
		Maximum residual: 7.426369154619653e-10
	Residual: 1.4317163968530303e-11

Load increment #29, load: 16.128205128205128, initial residual: 0.006716719402092309
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14353   595    49     3]
		Maximum f: 27.388219241620916
		Maximum residual: 1.1171455360337583e-09
	Residual: 0.012201831023476982

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14356     1   588    52     3]
		Maximum f: 28.584558517236882
		Maximum residual: 8.951279335909449e-10
	Residual: 0.01978884561978989

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14353     1   592    51     3]
		Maximum f: 28.66830999695858
		Maximum residual: 9.048896009968985e-10
	Residual: 0.0009708591699704159

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14353   593    51     3]
		Maximum f: 28.666671779503833
		Maximum residual: 9.104507142514185e-10
	Residual: 8.220389578526364e-06

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14353   593    51     3]
		Maximum f: 28.666640183013357
		Maximum residual: 9.104666710763999e-10
	Residual: 3.342774079225019e-10

	Outer Newton iteration #5
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14353   593    51     3]
		Maximum f: 28.666640180949567
		Maximum residual: 9.104774313731228e-10
	Residual: 9.767688349230945e-15

Load increment #30, load: 16.615384615384613, initial residual: 0.006716719402105511
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14291   659    47     3]
		Maximum f: 30.3391927130098
		Maximum residual: 9.784813450454796e-10
	Residual: 0.01478105711605393

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14283   667    47     3]
		Maximum f: 31.862733719202538
		Maximum residual: 1.3808055558672501e-09
	Residual: 0.009732116539257358

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14281   670    46     3]
		Maximum f: 32.106479915367046
		Maximum residual: 1.325766598911644e-09
	Residual: 0.002454155058540014

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14281   669    47     3]
		Maximum f: 32.12478051590549
		Maximum residual: 1.3278569092003317e-09
	Residual: 1.6633409864762663e-06

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14281   669    47     3]
		Maximum f: 32.124794424982106
		Maximum residual: 1.3278662028463454e-09
	Residual: 5.307556350691181e-12

Load increment #31, load: 17.102564102564102, initial residual: 0.006716719402101599
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14206     1   743    46     4]
		Maximum f: 34.25259433815284
		Maximum residual: 6.978415824979157e-10
	Residual: 0.012973816245009347

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14197     2   752    44     5]
		Maximum f: 35.493859474729504
		Maximum residual: 9.938831281782743e-10
	Residual: 0.013085384417730949

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14194     1   757    43     5]
		Maximum f: 35.67214269181068
		Maximum residual: 1.083607682794471e-09
	Residual: 0.00047677811876405536

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14194     1   758    42     5]
		Maximum f: 35.68206047013779
		Maximum residual: 1.088269045877497e-09
	Residual: 5.414261404230504e-07

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14194     1   758    42     5]
		Maximum f: 35.682064702451214
		Maximum residual: 1.0882703100783492e-09
	Residual: 1.0062689861185195e-12

Load increment #32, load: 17.58974358974359, initial residual: 0.006716719402103525
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [14109     1   843    45     2]
		Maximum f: 37.82177871116257
		Maximum residual: 1.808789249190159e-09
	Residual: 0.013265121948815397

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14103   848    47     2]
		Maximum f: 39.33636640323676
		Maximum residual: 2.8117130862484794e-09
	Residual: 0.01734480664235657

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14102   850    46     2]
		Maximum f: 39.46098163507629
		Maximum residual: 2.9418443363735676e-09
	Residual: 0.002786063521334128

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14101   851    46     2]
		Maximum f: 39.494698402560275
		Maximum residual: 3.0329970201634423e-09
	Residual: 3.1504884435841264e-05

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14101   851    46     2]
		Maximum f: 39.494935809990665
		Maximum residual: 3.0337386051270403e-09
	Residual: 6.00411686489015e-09

	Outer Newton iteration #5
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [14101   851    46     2]
		Maximum f: 39.49493582681761
		Maximum residual: 3.033734001770251e-09
	Residual: 1.1368296277092335e-14

Load increment #33, load: 18.076923076923077, initial residual: 0.00671671940210664
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5 6]
		Counts of unique number of iterations: [14007     1   944    45     2     1]
		Maximum f: 41.909820887293506
		Maximum residual: 2.6755180982912257e-09
	Residual: 0.013883417990777544

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5 6]
		Counts of unique number of iterations: [13994     1   948    53     3     1]
		Maximum f: 43.30582535963698
		Maximum residual: 4.969424153365426e-10
	Residual: 0.013918974861119406

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [13991   953    51     4     1]
		Maximum f: 43.46183448724685
		Maximum residual: 3.4008510125659406e-10
	Residual: 0.0006913086148759544

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [13991   953    51     4     1]
		Maximum f: 43.47158167047849
		Maximum residual: 3.4270678108275825e-10
	Residual: 1.31120114864905e-06

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [13991   953    51     4     1]
		Maximum f: 43.471584920436754
		Maximum residual: 3.4270290917925527e-10
	Residual: 3.627261085127511e-12

Load increment #34, load: 18.564102564102562, initial residual: 0.006716719402095659
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [13902  1038    58     2]
		Maximum f: 45.91480430149897
		Maximum residual: 2.4304431415047483e-09
	Residual: 0.01381553512146383

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [13893  1048    58     1]
		Maximum f: 47.624646296211225
		Maximum residual: 1.2653203412513818e-09
	Residual: 0.01610609449294831

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [13889  1052    58     1]
		Maximum f: 47.781394618393875
		Maximum residual: 6.831015994437234e-10
	Residual: 0.0007718361866142185

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [13888  1053    58     1]
		Maximum f: 47.79394062638336
		Maximum residual: 6.903306999643218e-10
	Residual: 1.9206165864719544e-05

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [13888  1053    58     1]
		Maximum f: 47.793970756076355
		Maximum residual: 6.90310704929693e-10
	Residual: 4.1750719303444733e-10

	Outer Newton iteration #5
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [13888  1053    58     1]
		Maximum f: 47.79397075829508
		Maximum residual: 6.903112059680749e-10
	Residual: 1.1863605400930496e-14

Load increment #35, load: 19.05128205128205, initial residual: 0.006716719402106133
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5 6]
		Counts of unique number of iterations: [13785     1  1148    61     4     1]
		Maximum f: 50.14385631632916
		Maximum residual: 1.5157200017471618e-09
	Residual: 0.015919379757795126

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [13779  1154    62     4     1]
		Maximum f: 52.23835692667192
		Maximum residual: 1.6622881621025106e-09
	Residual: 0.016147744791352343

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [13779  1152    65     3     1]
		Maximum f: 52.44076497459371
		Maximum residual: 1.776333851293984e-09
	Residual: 0.0013833112793794212

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [13778  1153    65     3     1]
		Maximum f: 52.45168162180874
		Maximum residual: 1.7871770307961866e-09
	Residual: 4.661542284659411e-05

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [13778  1153    65     3     1]
		Maximum f: 52.45212784374142
		Maximum residual: 1.7878879274289336e-09
	Residual: 1.5900780213460463e-08

	Outer Newton iteration #5
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5 6]
		Counts of unique number of iterations: [13778  1153    65     3     1]
		Maximum f: 52.452127946536066
		Maximum residual: 1.7878846245888235e-09
	Residual: 1.2671445931207544e-14

Load increment #36, load: 19.538461538461537, initial residual: 0.006716719402105679
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [13669     4  1252    70     5]
		Maximum f: 55.567850084860176
		Maximum residual: 1.5158017729600783e-09
	Residual: 0.01648733299209133

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [13653     1  1267    72     7]
		Maximum f: 57.520674863876835
		Maximum residual: 2.0360913781816354e-09
	Residual: 0.009555160880257305

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [13649     1  1271    72     7]
		Maximum f: 57.732233406463564
		Maximum residual: 2.276939297308774e-09
	Residual: 0.0012033324423391873

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [13649     1  1271    72     7]
		Maximum f: 57.7498544942572
		Maximum residual: 2.31029178844454e-09
	Residual: 4.145697724897417e-06

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [13649     1  1271    72     7]
		Maximum f: 57.74986829134545
		Maximum residual: 2.3103093135202066e-09
	Residual: 4.7209344668098317e-11

Load increment #37, load: 20.025641025641026, initial residual: 0.00671671940196218
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [13524  1396    76     4]
		Maximum f: 61.06991012598089
		Maximum residual: 1.0419547980406697e-09
	Residual: 0.01542877271383975

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [13507     3  1391    94     5]
		Maximum f: 63.56468035108713
		Maximum residual: 2.297434134061737e-09
	Residual: 0.014602468246210643

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [13506     2  1393    94     5]
		Maximum f: 63.79127578757072
		Maximum residual: 2.4736336179417338e-09
	Residual: 0.00215200848924207

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [13505     1  1395    94     5]
		Maximum f: 63.80501943470964
		Maximum residual: 2.474690416253466e-09
	Residual: 9.969640240656925e-05

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [13505     1  1395    94     5]
		Maximum f: 63.8064502879536
		Maximum residual: 2.47528087992554e-09
	Residual: 9.21988530170634e-08

	Outer Newton iteration #5
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [13505     1  1395    94     5]
		Maximum f: 63.8064507336262
		Maximum residual: 2.475281634210043e-09
	Residual: 3.290309010559536e-14

Load increment #38, load: 20.51282051282051, initial residual: 0.006716719402106353
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [13393     1  1501   101     4]
		Maximum f: 67.75666077388834
		Maximum residual: 8.892870228150307e-09
	Residual: 0.01751742965341938

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [13365  1516   115     4]
		Maximum f: 70.58578065154649
		Maximum residual: 1.4855313178726592e-09
	Residual: 0.015748989462208554

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [13365  1516   115     4]
		Maximum f: 70.93025143785108
		Maximum residual: 1.5953787472205064e-09
	Residual: 0.000516986539502159

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [13364  1517   115     4]
		Maximum f: 70.93404931644474
		Maximum residual: 1.5963803993495142e-09
	Residual: 1.490332273344886e-06

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 3 4 5]
		Counts of unique number of iterations: [13364  1517   115     4]
		Maximum f: 70.93406547628874
		Maximum residual: 1.5963962464711063e-09
	Residual: 1.2099815503055078e-11

Load increment #39, load: 21.0, initial residual: 0.006716719402106207
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [13215     1  1653   124     7]
		Maximum f: 75.39348146497682
		Maximum residual: 2.7200912417826727e-09
	Residual: 0.01837162002521085

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [13194     1  1667   130     8]
		Maximum f: 79.1519989054704
		Maximum residual: 5.528808783550887e-09
	Residual: 0.017161964830256334

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5 6]
		Counts of unique number of iterations: [13191     3  1668   130     7     1]
		Maximum f: 79.58593312017605
		Maximum residual: 6.539158263480048e-09
	Residual: 0.0013369309025360297

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5 6]
		Counts of unique number of iterations: [13191     2  1669   130     7     1]
		Maximum f: 79.6099646724377
		Maximum residual: 6.541417192423762e-09
	Residual: 2.920964415354603e-06

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5 6]
		Counts of unique number of iterations: [13191     2  1669   130     7     1]
		Maximum f: 79.6099757401031
		Maximum residual: 6.541350490563634e-09
	Residual: 3.630192208671106e-11

Load increment #40, load: 21.092105263157894, initial residual: 0.0012698506902034947
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13165    21  1795    19]
		Maximum f: 16.08975786803034
		Maximum residual: 1.5615146502863196e-10
	Residual: 0.0031321320109657545

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13163    27  1806     4]
		Maximum f: 16.10841316642004
		Maximum residual: 2.3868757553000546e-10
	Residual: 0.00013097513646595602

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13163    27  1806     4]
		Maximum f: 16.111750859726392
		Maximum residual: 2.4548461882246814e-10
	Residual: 9.146392777926918e-08

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13163    27  1806     4]
		Maximum f: 16.111751899955934
		Maximum residual: 2.4548539806946936e-10
	Residual: 3.1013134252967853e-14

Load increment #41, load: 21.18421052631579, initial residual: 0.001269850690287741
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13125    31  1841     3]
		Maximum f: 16.28406346683637
		Maximum residual: 2.461194646909824e-10
	Residual: 0.0021467155268750926

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13122    28  1847     3]
		Maximum f: 16.49460668053909
		Maximum residual: 2.5563213017639554e-10
	Residual: 8.480982743983298e-05

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13122    28  1847     3]
		Maximum f: 16.496151369432415
		Maximum residual: 2.5313063525550913e-10
	Residual: 3.10333266608704e-08

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13122    28  1847     3]
		Maximum f: 16.496151886790926
		Maximum residual: 2.531276339288553e-10
	Residual: 6.23343484070837e-15

Load increment #42, load: 21.276315789473685, initial residual: 0.0012698506902872947
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13084    31  1883     2]
		Maximum f: 16.788419796333756
		Maximum residual: 1.9126257404353311e-10
	Residual: 0.002799391161568597

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13081    33  1884     2]
		Maximum f: 17.034440423380282
		Maximum residual: 1.792912373663637e-10
	Residual: 0.00043717722019198104

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13081    33  1884     2]
		Maximum f: 17.041643481924492
		Maximum residual: 2.335616771412466e-10
	Residual: 1.7721777126625083e-06

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13081    33  1884     2]
		Maximum f: 17.041660096410972
		Maximum residual: 2.3360684847395073e-10
	Residual: 1.127712020673048e-11

Load increment #43, load: 21.36842105263158, initial residual: 0.001269850690298948
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13049    40  1909     2]
		Maximum f: 17.278623558566498
		Maximum residual: 2.417327114256076e-10
	Residual: 0.0025934355703301436

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13049    36  1913     2]
		Maximum f: 17.490395192093285
		Maximum residual: 3.848908769086853e-10
	Residual: 0.0007190854199344023

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13049    36  1913     2]
		Maximum f: 17.496370578262987
		Maximum residual: 1.8229511923598908e-10
	Residual: 4.3775839130928303e-07

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13049    36  1913     2]
		Maximum f: 17.496373236474987
		Maximum residual: 1.8229718120265876e-10
	Residual: 4.709769915609418e-13

Load increment #44, load: 21.460526315789473, initial residual: 0.0012698506902881048
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13013    38  1947     2]
		Maximum f: 17.714696002614733
		Maximum residual: 3.262502708104543e-10
	Residual: 0.002701869108258954

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13011    39  1948     2]
		Maximum f: 17.965123605797647
		Maximum residual: 2.8815665912207747e-10
	Residual: 0.0004216453628278521

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13011    38  1949     2]
		Maximum f: 17.969582107517958
		Maximum residual: 2.777914905746301e-10
	Residual: 9.472278112814814e-08

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [13011    38  1949     2]
		Maximum f: 17.9695838688886
		Maximum residual: 2.77782565679447e-10
	Residual: 2.321107371458391e-14

Load increment #45, load: 21.55263157894737, initial residual: 0.001269850690288401
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12971    39  1989     1]
		Maximum f: 18.277479535195035
		Maximum residual: 3.720438162153813e-10
	Residual: 0.00302469032701064

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12972    35  1991     2]
		Maximum f: 18.49745998031136
		Maximum residual: 2.853722479672918e-10
	Residual: 0.0013588818808835665

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12972    35  1991     2]
		Maximum f: 18.52518575789379
		Maximum residual: 2.77602345620237e-10
	Residual: 1.9054350811334565e-06

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12972    35  1991     2]
		Maximum f: 18.525206077922945
		Maximum residual: 2.775283295079526e-10
	Residual: 9.808783920843275e-12

Load increment #46, load: 21.644736842105264, initial residual: 0.0012698506903203349
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12934    32  2033     1]
		Maximum f: 18.864563613425286
		Maximum residual: 3.9230424087875377e-10
	Residual: 0.003426519411707304

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12928    36  2033     3]
		Maximum f: 19.125341983196126
		Maximum residual: 2.912354520697611e-10
	Residual: 0.0013763884970385853

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12928    31  2038     3]
		Maximum f: 19.15555009153013
		Maximum residual: 2.711152605171914e-10
	Residual: 1.4515798861607068e-06

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12928    31  2038     3]
		Maximum f: 19.155563782593568
		Maximum residual: 2.71149885612548e-10
	Residual: 4.780042944604455e-12

Load increment #47, load: 21.736842105263158, initial residual: 0.0012698506902895194
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12893    37  2069     1]
		Maximum f: 19.45980628178125
		Maximum residual: 2.7960469915910013e-10
	Residual: 0.002792620725643187

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12892    28  2079     1]
		Maximum f: 19.79391716906716
		Maximum residual: 2.9288185154740486e-10
	Residual: 0.0008070300959354558

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12892    26  2081     1]
		Maximum f: 19.804971741904655
		Maximum residual: 2.089264864959735e-10
	Residual: 3.955579937957151e-07

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12892    26  2081     1]
		Maximum f: 19.80497255701194
		Maximum residual: 2.0892757204222218e-10
	Residual: 2.5017785799583337e-13

Load increment #48, load: 21.82894736842105, initial residual: 0.00126985069028734
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12863    34  2101     2]
		Maximum f: 20.19353675186485
		Maximum residual: 1.8707329317885239e-10
	Residual: 0.0027907416726949256

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12859    36  2103     2]
		Maximum f: 20.46175466364585
		Maximum residual: 2.202903608383735e-10
	Residual: 0.00032650169469668184

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12859    36  2103     2]
		Maximum f: 20.468858529025166
		Maximum residual: 3.00033130928293e-10
	Residual: 5.679180809519861e-07

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12859    36  2103     2]
		Maximum f: 20.468865168425893
		Maximum residual: 3.000075923151847e-10
	Residual: 1.6534173101173173e-12

Load increment #49, load: 21.92105263157895, initial residual: 0.0012698506902833243
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12825    38  2135     2]
		Maximum f: 20.949676029049108
		Maximum residual: 2.965073992040903e-10
	Residual: 0.0024192662306188294

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12819    42  2137     2]
		Maximum f: 21.18795309429991
		Maximum residual: 4.6180415548632207e-10
	Residual: 0.0010026943670315587

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12819    41  2138     2]
		Maximum f: 21.221498194568618
		Maximum residual: 4.001382613030954e-10
	Residual: 2.7592988389517836e-06

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12819    41  2138     2]
		Maximum f: 21.22153592915037
		Maximum residual: 4.0003089334006846e-10
	Residual: 2.686877794265977e-11

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12819    41  2138     2]
		Maximum f: 21.2215359297028
		Maximum residual: 4.0003249488406636e-10
	Residual: 5.855829771291049e-15

Load increment #50, load: 22.013157894736842, initial residual: 0.0012698506902866077
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12782    45  2171     2]
		Maximum f: 21.6783888807352
		Maximum residual: 3.0516638240229204e-10
	Residual: 0.0038564025789666486

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12781    35  2181     3]
		Maximum f: 22.046334883954977
		Maximum residual: 5.156255268887793e-10
	Residual: 0.000263528690247095

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12781    35  2181     3]
		Maximum f: 22.052212612207125
		Maximum residual: 4.3995335057610405e-10
	Residual: 8.555994573217932e-07

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12781    35  2181     3]
		Maximum f: 22.052220810250237
		Maximum residual: 4.3993835769178695e-10
	Residual: 5.245780582215293e-12

Load increment #51, load: 22.105263157894736, initial residual: 0.0012698506902737478
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12730    40  2227     3]
		Maximum f: 22.506144453055807
		Maximum residual: 5.734290511693288e-10
	Residual: 0.004752901407093473

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12728    37  2231     4]
		Maximum f: 23.044394452951778
		Maximum residual: 2.3786183255812576e-10
	Residual: 0.0003002935751645465

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12728    35  2233     4]
		Maximum f: 23.0574031547856
		Maximum residual: 3.260009899342512e-10
	Residual: 6.047314234185686e-07

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12728    35  2233     4]
		Maximum f: 23.05740815912893
		Maximum residual: 3.260497624681028e-10
	Residual: 2.162564289448801e-12

Load increment #52, load: 22.19736842105263, initial residual: 0.0012698506902888527
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12686    38  2273     3]
		Maximum f: 23.562035145481484
		Maximum residual: 2.299756058125252e-10
	Residual: 0.0032275242506623752

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12678    36  2283     3]
		Maximum f: 24.036513056181317
		Maximum residual: 3.7914450377492174e-10
	Residual: 0.000594742197505832

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12678    34  2285     3]
		Maximum f: 24.064533536431977
		Maximum residual: 3.9250710723058666e-10
	Residual: 1.0603079241373934e-06

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12678    34  2285     3]
		Maximum f: 24.064552585951535
		Maximum residual: 3.920740069372708e-10
	Residual: 4.401271521048271e-12

Load increment #53, load: 22.289473684210527, initial residual: 0.001269850690301699
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12633    27  2336     4]
		Maximum f: 24.65965926912075
		Maximum residual: 2.373931046539078e-10
	Residual: 0.003939452091339807

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12627    31  2337     5]
		Maximum f: 25.276296226131706
		Maximum residual: 5.537395407863762e-10
	Residual: 0.0005439274179921635

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12627    30  2337     6]
		Maximum f: 25.3100396642238
		Maximum residual: 2.3360705323087356e-10
	Residual: 7.356620363823689e-07

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12627    30  2337     6]
		Maximum f: 25.310059505680044
		Maximum residual: 2.3377635475062556e-10
	Residual: 1.7540184713923856e-12

Load increment #54, load: 22.38157894736842, initial residual: 0.00126985069028829
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12580    31  2386     3]
		Maximum f: 26.050593036706456
		Maximum residual: 3.3220624600759926e-10
	Residual: 0.005078785269065983

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12574    24  2396     6]
		Maximum f: 26.784959586831427
		Maximum residual: 2.776074662211976e-10
	Residual: 0.00043096273393746463

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12574    25  2394     7]
		Maximum f: 26.796588199799306
		Maximum residual: 3.6851234314393446e-10
	Residual: 1.889029200175346e-07

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12574    25  2394     7]
		Maximum f: 26.79659238819626
		Maximum residual: 3.6853165202234656e-10
	Residual: 8.096592386611406e-14

Load increment #55, load: 22.473684210526315, initial residual: 0.0012698506902876778
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12513    32  2452     3]
		Maximum f: 27.393803183785266
		Maximum residual: 2.0552153399826106e-10
	Residual: 0.00493451711037449

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12503    32  2453    12]
		Maximum f: 28.52937815358708
		Maximum residual: 8.502488721314188e-10
	Residual: 0.0015749942110383417

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12504    30  2454    12]
		Maximum f: 28.60246963353509
		Maximum residual: 6.776913284945542e-10
	Residual: 0.0003533644816955238

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12504    30  2454    12]
		Maximum f: 28.604182642110597
		Maximum residual: 6.848096602165291e-10
	Residual: 4.712912152432788e-08

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12504    30  2454    12]
		Maximum f: 28.604183173807137
		Maximum residual: 6.848112593055709e-10
	Residual: 1.5106086837247922e-14

Load increment #56, load: 22.56578947368421, initial residual: 0.001269850690287766
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12424    37  2528    11]
		Maximum f: 30.201375466645466
		Maximum residual: 4.4443679115520274e-10
	Residual: 0.005528665482217568

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12418    36  2532    14]
		Maximum f: 32.19777060875148
		Maximum residual: 3.3235182201821935e-10
	Residual: 0.0028584452408076744

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12419    33  2533    15]
		Maximum f: 32.42294215005556
		Maximum residual: 5.318898911869827e-10
	Residual: 0.0006593251595243586

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12419    33  2533    15]
		Maximum f: 32.43887954815178
		Maximum residual: 6.080086776067936e-10
	Residual: 7.544769289557309e-07

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12419    33  2533    15]
		Maximum f: 32.438896567273254
		Maximum residual: 6.081108255801936e-10
	Residual: 2.6061316682184895e-12

Load increment #57, load: 22.657894736842106, initial residual: 0.0012698506902981722
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12293    32  2656    19]
		Maximum f: 35.14425352228363
		Maximum residual: 3.496234195834213e-10
	Residual: 0.008357821604262744

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12292    24  2649    35]
		Maximum f: 40.29081844578413
		Maximum residual: 1.576990256040779e-09
	Residual: 0.007395964036149111

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12283    23  2660    34]
		Maximum f: 40.99539836523154
		Maximum residual: 1.6515983389399281e-09
	Residual: 0.0008981157590161343

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12282    23  2660    35]
		Maximum f: 41.08028918388952
		Maximum residual: 1.702001244213463e-09
	Residual: 5.490401017715508e-05

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12282    23  2660    35]
		Maximum f: 41.08065705097812
		Maximum residual: 1.7021775649282016e-09
	Residual: 2.5604401448233753e-08

	Outer Newton iteration #5
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12282    23  2660    35]
		Maximum f: 41.08065705549767
		Maximum residual: 1.702176330003804e-09
	Residual: 1.1554726810935663e-14

Load increment #58, load: 22.75, initial residual: 0.0012698506902867053
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4]
		Counts of unique number of iterations: [12119    26  2790    65]
		Maximum f: 50.14669383688169
		Maximum residual: 3.1994009491809425e-09
	Residual: 0.01125547255201783

	Outer Newton iteration #1
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [12102    22  2709   164     3]
		Maximum f: 75.49974796139144
		Maximum residual: 2.624166197076301e-09
	Residual: 0.01613850129093001

	Outer Newton iteration #2
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [12094    16  2352   521    17]
		Maximum f: 112.38886200504191
		Maximum residual: 3.439952059714869e-09
	Residual: 0.010057266896525136

	Outer Newton iteration #3
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [12087    18  2231   645    19]
		Maximum f: 126.1851339816193
		Maximum residual: 2.887067526737981e-09
	Residual: 0.0046928305823081854

	Outer Newton iteration #4
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [12084    19  2211   667    19]
		Maximum f: 130.4294891396616
		Maximum residual: 2.2795034419890342e-09
	Residual: 0.0036394006770381912

	Outer Newton iteration #5
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [12082    17  2208   673    20]
		Maximum f: 131.97346778200122
		Maximum residual: 1.7875952731366231e-09
	Residual: 0.000888617865595232

	Outer Newton iteration #6
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [12082    17  2207   674    20]
		Maximum f: 132.43275844109672
		Maximum residual: 5.047734345022004e-09
	Residual: 1.8959646289737513e-05

	Outer Newton iteration #7
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [12082    17  2207   674    20]
		Maximum f: 132.43735152957282
		Maximum residual: 5.044392878023876e-09
	Residual: 4.714585905734586e-09

	Outer Newton iteration #8
	Inner Newton summary:
		Unique number of iterations: [1 2 3 4 5]
		Counts of unique number of iterations: [12082    17  2207   674    20]
		Maximum f: 132.4373528199075
		Maximum residual: 5.04439014693748e-09
	Residual: 2.7591921386899504e-14

Slope stability factor: 6.594202898550725

Verification#

Critical load#

According to Chen and Liu [1990], the slope stability factor \(l_\text{lim}\) can be derived analytically for the standard Mohr-Coulomb plasticity model (without apex smoothing) under the plane strain assumption and associative plastic flow:

\[ l_\text{lim} = \gamma_\text{lim} H/c, \]

where \(\gamma_\text{lim}\) is the critical value of the soil self-weight, \(H\) is the slope height and \(c\) is the cohesion. In particular, for the rectangular slope with the friction angle \(\phi\) equal to \(30^\circ\), \(l_\text{lim} = 6.69\) [Chen and Liu, 1990]. Thus, after computing \(\gamma_\text{lim}\) from the formula above, we progressively increase the second component of the gravitational body force \(\boldsymbol{q}=[0, -\gamma]^T\) up to the critical value \(\gamma_\text{lim}^\text{num}\). This is the value at which the perfect plasticity plateau is reached on the load-displacement curve at the point \((0, H)\). We then compare \(\gamma_\text{lim}^\text{num}\) with the analytical value \(\gamma_\text{lim}\).
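The arithmetic behind this comparison is short. The following is a minimal sketch that assumes only the formula above. It inverts \(l_\text{lim} = \gamma_\text{lim} H / c\) and computes the relative gap between the analytical factor 6.69 and the numerical factor printed at the end of the loading loop above. The helper names and the values of `H_example` and `c_example` are hypothetical placeholders, not the tutorial's parameters.

```python
def critical_self_weight(l_lim, H, c):
    """Invert l_lim = gamma_lim * H / c for the critical self-weight gamma_lim."""
    return l_lim * c / H


def relative_error(numerical, analytical):
    """Relative deviation of a numerical value from an analytical reference."""
    return abs(numerical - analytical) / abs(analytical)


l_lim_analytical = 6.69  # Chen and Liu [1990], rectangular slope, phi = 30 deg
l_lim_numerical = 6.594202898550725  # "Slope stability factor" printed above

# Hypothetical slope height and cohesion, for illustration only
H_example, c_example = 1.0, 3.45
gamma_lim_example = critical_self_weight(l_lim_analytical, H_example, c_example)
err = relative_error(l_lim_numerical, l_lim_analytical)
print(f"gamma_lim = {gamma_lim_example:.4f}, relative error = {100 * err:.2f} %")
```

With the printed stability factor, the relative gap to the analytical value is about 1.4 %, independently of the placeholder values of \(H\) and \(c\).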

The load-displacement curve below confirms that the limit value \(\gamma_\text{lim}^\text{num}\) reached numerically is close to the analytical value \(\gamma_\text{lim}\).

if len(points_on_process) > 0:
    l_lim = 6.69
    gamma_lim = l_lim / H * c
    plt.plot(results[:, 0], results[:, 1], "o-", label=r"$\gamma$")
    plt.axhline(y=gamma_lim, color="r", linestyle="--", label=r"$\gamma_\text{lim}$")
    plt.xlabel(r"Displacement of the slope $u_x$ at $(0, H)$ [mm]")
    plt.ylabel(r"Soil self-weight $\gamma$ [MPa/mm]")
    plt.grid()
    plt.legend()
[Figure: soil self-weight \(\gamma\) versus the displacement \(u_x\) at \((0, H)\), with the analytical \(\gamma_\text{lim}\) shown as a dashed red line]

The slope profile reaching its stability limit:

try:
    import pyvista

    import dolfinx.plot

    pyvista.start_xvfb(0.1)  # virtual framebuffer for headless rendering

    # Interpolate the displacement into a P1 vector space for visualization
    W = fem.functionspace(domain, ("Lagrange", 1, (gdim,)))
    u_tmp = fem.Function(W, name="Displacement")
    u_tmp.interpolate(u)

    plotter = pyvista.Plotter(window_size=[600, 400])
    # Build the VTK grid from the dofs of W so that its nodes match u_tmp
    topology, cell_types, x = dolfinx.plot.vtk_mesh(W)
    grid = pyvista.UnstructuredGrid(topology, cell_types, x)
    # PyVista expects 3D vectors: pad the 2D displacement with a zero z-component
    vals = np.zeros((x.shape[0], 3))
    vals[:, :gdim] = u_tmp.x.array.reshape((x.shape[0], gdim))
    grid["u"] = vals
    warped = grid.warp_by_vector("u", factor=20)  # magnify the displacement 20 times
    plotter.add_text("Displacement field", font_size=11)
    plotter.add_mesh(warped, show_edges=False, show_scalar_bar=True)
    plotter.view_xy()
    plotter.show()
except ImportError:
    print("pyvista required for this plot")
[Figure: displacement field of the slope at its stability limit, warped by a factor of 20]

Yield surface#

We verify the implementation of the constitutive model by tracing the yield surface: we generate several stress paths and check whether the returned stresses remain within the Mohr-Coulomb yield surface. The stress tracing is performed in the Haigh-Westergaard coordinates \((\xi, \rho, \theta)\), which are defined as follows:

\[ \xi = \frac{1}{\sqrt{3}}I_1, \quad \rho = \sqrt{2J_2}, \quad \sin(3\theta) = -\frac{3\sqrt{3}}{2} \frac{J_3}{J_2^{3/2}}, \]

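The quantities in these definitions are as follows:

- \(I_1 = \operatorname{tr}\boldsymbol{\sigma}\) is the first invariant of the stress tensor.
- \(J_2(\boldsymbol{\sigma}) = \frac{1}{2}\boldsymbol{s}:\boldsymbol{s}\) and \(J_3(\boldsymbol{\sigma}) = \det(\boldsymbol{s})\) are the second and third invariants of its deviatoric part \(\boldsymbol{s}\).
- \(\xi\) is the hydrostatic coordinate and \(\rho\) is the radial (deviatoric) coordinate.
- The angle \(\theta \in [-\frac{\pi}{6}, \frac{\pi}{6}]\) is called the Lode or stress angle.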
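A minimal NumPy sketch of these definitions for a full \(3 \times 3\) stress tensor follows. It uses the same sign convention for \(\theta\) as the formula above (and as the Lode_angle function defined later). The helper name `haigh_westergaard` and the tolerance used to detect the hydrostatic axis are our own choices.

```python
import numpy as np


def haigh_westergaard(sigma):
    """Return (xi, rho, theta) of a symmetric 3x3 stress tensor.

    Uses the sign convention sin(3*theta) = -3*sqrt(3)/2 * J3 / J2**1.5 from the text.
    """
    sigma = np.asarray(sigma, dtype=float)
    I1 = np.trace(sigma)
    s = sigma - I1 / 3.0 * np.eye(3)  # deviatoric part
    J2 = 0.5 * np.tensordot(s, s)  # J2 = s:s / 2 (full double contraction)
    J3 = np.linalg.det(s)
    xi = I1 / np.sqrt(3.0)
    rho = np.sqrt(2.0 * J2)
    # The Lode angle is undefined on the hydrostatic axis (vanishing deviator)
    if rho <= 1e-12 * max(1.0, np.linalg.norm(sigma)):
        return xi, rho, np.nan
    arg = -1.5 * np.sqrt(3.0) * J3 / J2**1.5
    theta = np.arcsin(np.clip(arg, -1.0, 1.0)) / 3.0  # clip guards against round-off
    return xi, rho, theta


# approximately (3.4641, 1.4142, 0.0): xi = 2*sqrt(3), rho = sqrt(2), theta = 0
print(haigh_westergaard(np.diag([1.0, 2.0, 3.0])))
```

For instance, a deviatoric state proportional to \(\operatorname{diag}(2, -1, -1)\) has \(J_3 > 0\), so with this sign convention it is mapped to \(\theta = -\pi/6\).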

To generate the stress paths, we use the expression of the principal stresses in Haigh-Westergaard coordinates:

\[\begin{split} \begin{pmatrix} \sigma_{I} \\ \sigma_{II} \\ \sigma_{III} \end{pmatrix} = p \begin{pmatrix} 1 \\ 1 \\ 1 \end{pmatrix} + \frac{\rho}{\sqrt{2}} \begin{pmatrix} \cos\theta + \frac{\sin\theta}{\sqrt{3}} \\ -\frac{2\sin\theta}{\sqrt{3}} \\ \frac{\sin\theta}{\sqrt{3}} - \cos\theta \end{pmatrix}, \end{split}\]

where \(p = \xi/\sqrt{3}\) is the hydrostatic (mean) stress and \(\sigma_{I} \geq \sigma_{II} \geq \sigma_{III}\).
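The parametrization can be checked numerically. The sketch below evaluates the formula for several Lode angles in \([-\frac{\pi}{6}, \frac{\pi}{6}]\) and checks three properties: the mean stress equals \(p\), the norm of the deviator equals \(\rho\), and the ordering \(\sigma_{I} \geq \sigma_{II} \geq \sigma_{III}\) holds. The helper `principal_stresses` is hypothetical, and the values \(p = 0.1\) and \(\rho = 0.7\) mirror those used for the tracing below.

```python
import numpy as np


def principal_stresses(p, rho, theta):
    """Principal stresses (sigma_I, sigma_II, sigma_III) from (p, rho, theta).

    `theta` may be a scalar or an array; the last axis of the result holds the
    three principal stresses.
    """
    theta = np.asarray(theta, dtype=float)
    dev = (rho / np.sqrt(2.0)) * np.stack(
        [
            np.cos(theta) + np.sin(theta) / np.sqrt(3.0),
            -2.0 * np.sin(theta) / np.sqrt(3.0),
            np.sin(theta) / np.sqrt(3.0) - np.cos(theta),
        ],
        axis=-1,
    )
    return p + dev  # hydrostatic part plus deviatoric part


theta_demo = np.linspace(-np.pi / 6, np.pi / 6, 7)
sig_demo = principal_stresses(0.1, 0.7, theta_demo)  # shape (7, 3)
print(np.allclose(sig_demo.mean(axis=-1), 0.1))  # mean stress equals p
print(np.allclose(np.linalg.norm(sig_demo - 0.1, axis=-1), 0.7))  # deviator norm equals rho
print(np.all(np.diff(sig_demo, axis=-1) <= 1e-12))  # sigma_I >= sigma_II >= sigma_III
```

At the end points \(\theta = \mp\frac{\pi}{6}\), two of the principal stresses coincide, which is why the ordering check uses a small tolerance.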

Now we generate the loading paths by evaluating the principal stresses in Haigh-Westergaard coordinates for the Lode angle \(\theta\) varying from \(-\frac{\pi}{6}\) to \(\frac{\pi}{6}\), with \(\rho\) and \(p\) fixed.

N_angles = 50
N_loads = 9  # number of loadings or paths
eps = 0.00001
R = 0.7  # fix the value of rho
p = 0.1  # fix the hydrostatic (mean) stress
theta_1 = -np.pi / 6
theta_2 = np.pi / 6

theta_values = np.linspace(theta_1 + eps, theta_2 - eps, N_angles)
theta_returned = np.empty((N_loads, N_angles))
rho_returned = np.empty((N_loads, N_angles))
sigma_returned = np.empty((N_loads, N_angles, stress_dim))

# fix an increment of the stress path
dsigma_path = np.zeros((N_angles, stress_dim))
dsigma_path[:, 0] = (R / np.sqrt(2)) * (np.cos(theta_values) + np.sin(theta_values) / np.sqrt(3))
dsigma_path[:, 1] = (R / np.sqrt(2)) * (-2 * np.sin(theta_values) / np.sqrt(3))
dsigma_path[:, 2] = (R / np.sqrt(2)) * (np.sin(theta_values) / np.sqrt(3) - np.cos(theta_values))

sigma_n_local = np.zeros_like(dsigma_path)
sigma_n_local[:, 0] = p
sigma_n_local[:, 1] = p
sigma_n_local[:, 2] = p
derviatoric_axis = tr  # hydrostatic axis, used to project back onto the plane p = const

Then, we define the functions rho, Lode_angle and sigma_tracing. They evaluate, respectively, the coordinate \(\rho\), the Lode angle \(\theta\) and the corrected (or “returned”) stress tensor for a given stress state. We vectorize them with jax.vmap and compile them with jax.jit. The function sigma_tracing calls return_mapping, the JAX implementation of the constitutive model defined previously.

def rho(sigma_local):
    s = dev @ sigma_local
    return jnp.sqrt(2.0 * J2(s))


def Lode_angle(sigma_local):
    s = dev @ sigma_local
    arg = -(3.0 * jnp.sqrt(3.0) * J3(s)) / (2.0 * jnp.sqrt(J2(s) * J2(s) * J2(s)))
    arg = jnp.clip(arg, -1.0, 1.0)
    angle = 1.0 / 3.0 * jnp.arcsin(arg)
    return angle


def sigma_tracing(sigma_local, sigma_n_local):
    deps_elas = S_elas @ sigma_local
    sigma_corrected, state = return_mapping(deps_elas, sigma_n_local)
    yielding = state[2]
    return sigma_corrected, yielding


Lode_angle_v = jax.jit(jax.vmap(Lode_angle, in_axes=(0)))
rho_v = jax.jit(jax.vmap(rho, in_axes=(0)))
sigma_tracing_v = jax.jit(jax.vmap(sigma_tracing, in_axes=(0, 0)))

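For each stress path, we call the function sigma_tracing_v to get the corrected stress state. We then project it back onto the deviatoric plane with the fixed value of \(p\), where we evaluate its coordinates \((\rho, \theta)\). Each path starts from the stress state returned by the previous one, because sigma_n_local is updated at the end of the loop. Hence the load keeps increasing from one path to the next.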

for i in range(N_loads):
    print(f"Loading path#{i}")
    dsigma, yielding = sigma_tracing_v(dsigma_path, sigma_n_local)
    dp = dsigma @ tr / 3.0 - p
    dsigma -= np.outer(dp, derviatoric_axis)  # projection on the same deviatoric plane

    sigma_returned[i, :] = dsigma
    theta_returned[i, :] = Lode_angle_v(dsigma)
    rho_returned[i, :] = rho_v(dsigma)
    print(f"max f: {jnp.max(yielding)}\n")
    sigma_n_local[:] = dsigma
Loading path#0
max f: -2.005661796528811

Loading path#1
max f: -1.6474140052837287

Loading path#2
max f: -1.208026577509537

Loading path#3
max f: -0.7355419142072792

Loading path#4
max f: -0.24734105482689195

Loading path#5
max f: 0.24936204365846004

Loading path#6
max f: 0.5729074243253174

Loading path#7
max f: 0.6673099622873986

Loading path#8
max f: 0.6947128474362492
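The projection in the loop above only shifts each state along the hydrostatic axis. Its deviatoric part, and hence \(\rho\) and \(\theta\), is left unchanged. The following is a minimal sketch of that step. It assumes a four-component Mandel-Voigt stress vector \((\sigma_{xx}, \sigma_{yy}, \sigma_{zz}, \sqrt{2}\sigma_{xy})\), so that the trace vector is \([1, 1, 1, 0]\). The tutorial's own tr is defined earlier and is not reproduced here, and all names are hypothetical.

```python
import numpy as np

# Assumed Mandel-Voigt layout of a plane-strain stress vector: (xx, yy, zz, sqrt(2)*xy)
tr_mandel = np.array([1.0, 1.0, 1.0, 0.0])


def project_to_mean_stress(stresses, p_target):
    """Shift each row along the hydrostatic axis so that its mean stress equals p_target."""
    dp = stresses @ tr_mandel / 3.0 - p_target  # excess mean stress of each state
    return stresses - np.outer(dp, tr_mandel)  # remove it from the normal components only


rng = np.random.default_rng(0)
stress_states = rng.normal(size=(5, 4))  # five arbitrary stress states
projected = project_to_mean_stress(stress_states, 0.1)
print(projected @ tr_mandel / 3.0)  # every mean stress is now 0.1
```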

Then, by expressing the standard Mohr-Coulomb yield criterion, usually written in terms of principal stresses, in Haigh-Westergaard coordinates, we obtain the following equation:

\[ \frac{\rho}{\sqrt{6}}(\sqrt{3}\cos\theta + \sin\phi \sin\theta) - p\sin\phi - c\cos\phi= 0. \tag{7} \]

Thus, we can recover the radial coordinate \(\rho\) of the standard Mohr-Coulomb yield surface for given values of \(\theta\) and \(p\):

def MC_yield_surface(theta_, p):
    """Restores the coordinate `rho` satisfying the standard Mohr-Coulomb yield
    criterion."""
    rho = (np.sqrt(2) * (c * np.cos(phi) + p * np.sin(phi))) / (
        np.cos(theta_) - np.sin(phi) * np.sin(theta_) / np.sqrt(3)
    )
    return rho


rho_standard_MC = MC_yield_surface(theta_values, p)

Finally, we plot the yield surface:

colormap = cm.plasma
colors = colormap(np.linspace(0.0, 1.0, N_loads))

fig, ax = plt.subplots(subplot_kw={"projection": "polar"}, figsize=(8, 8))
# Mohr-Coulomb yield surface with apex smoothing
for i, color in enumerate(colors):
    rho_total = np.array([])
    theta_total = np.array([])
    for j in range(12):
        angles = j * np.pi / 3 - j % 2 * theta_returned[i] + (1 - j % 2) * theta_returned[i]
        theta_total = np.concatenate([theta_total, angles])
        rho_total = np.concatenate([rho_total, rho_returned[i]])

    ax.plot(theta_total, rho_total, ".", color=color)

# standard Mohr-Coulomb yield surface
theta_standard_MC_total = np.array([])
rho_standard_MC_total = np.array([])
for j in range(12):
    angles = j * np.pi / 3 - j % 2 * theta_values + (1 - j % 2) * theta_values
    theta_standard_MC_total = np.concatenate([theta_standard_MC_total, angles])
    rho_standard_MC_total = np.concatenate([rho_standard_MC_total, rho_standard_MC])
ax.plot(theta_standard_MC_total, rho_standard_MC_total, "-", color="black")
ax.set_yticklabels([])

norm = mcolors.Normalize(vmin=R, vmax=R * N_loads)  # first and last paths correspond to R and R * N_loads
sm = plt.cm.ScalarMappable(cmap=colormap, norm=norm)
sm.set_array([])
cbar = fig.colorbar(sm, ax=ax, orientation="vertical")
cbar.set_label(r"Magnitude of the stress path deviator, $\rho$ [MPa]")

plt.show()
[Figure: polar plot of the traced stress paths in the deviatoric plane, coloured by loading path, with the standard Mohr-Coulomb yield surface as a black contour]

Each colour represents one loading path. The points of the paths that remain in the elastic phase lie on circles. Once the loading reaches the elastic limit, the points start outlining the yield surface, which in the limit lies along the standard Mohr-Coulomb surface without smoothing (black contour).

Taylor test#

Here, we perform a Taylor test to check that the form \(F\) and its Jacobian \(J\) are consistent, i.e. that \(J\) is indeed the derivative of \(F\). The external operator \(\boldsymbol{\sigma}\) and its derivative \(\boldsymbol{C}_\text{tang}\) define \(F\) and \(J\), respectively. In particular, the test verifies that the program dsigma_ddeps_vec, obtained by JAX's AD, returns values of \(\boldsymbol{\sigma}\) and \(\boldsymbol{C}_\text{tang}\) that are consistent with each other.

To perform the test, we introduce the operators \(\mathcal{F}: V \rightarrow V^\prime\) and \(\mathcal{J}: V \rightarrow \mathcal{L}(V, V^\prime)\) defined as follows:

\[ \langle \mathcal{F}(\boldsymbol{u}), \boldsymbol{v} \rangle := F(\boldsymbol{u}; \boldsymbol{v}), \quad \forall \boldsymbol{v} \in V, \]
\[ \langle (\mathcal{J}(\boldsymbol{u}))(k\boldsymbol{\delta u}), \boldsymbol{v} \rangle := J(\boldsymbol{u}; k\boldsymbol{\delta u}, \boldsymbol{v}), \quad \forall \boldsymbol{v} \in V, \]

where \(V^\prime\) is the dual space of \(V\), \(\langle \cdot, \cdot \rangle\) is the \(V^\prime \times V\) duality pairing and \(\mathcal{L}(V, V^\prime)\) is the space of bounded linear operators from \(V\) to its dual.

Then, if we perturb the operator \(\mathcal{F}\) in the direction \(k \, \boldsymbol{\delta u} \in V\) for \(k > 0\), Taylor's theorem on Banach spaces implies the convergence rates below. The zeroth- and first-order Taylor remainders \(r_k^0\) and \(r_k^1\) converge to zero at the following mesh-independent rates in the dual space \(V^\prime\):

\[ \| r_k^0 \|_{V^\prime} := \| \mathcal{F}(\boldsymbol{u} + k \, \boldsymbol{\delta u}) - \mathcal{F}(\boldsymbol{u}) \|_{V^\prime} \longrightarrow 0 \text{ at } O(k), \tag{8} \]
\[ \| r_k^1 \|_{V^\prime} := \| \mathcal{F}(\boldsymbol{u} + k \, \boldsymbol{\delta u}) - \mathcal{F}(\boldsymbol{u}) - \, (\mathcal{J}(\boldsymbol{u}))(k\boldsymbol{\delta u}) \|_{V^\prime} \longrightarrow 0 \text{ at } O(k^2). \tag{9} \]

To compute the norm of an element \(f\) of the dual space \(V^\prime\), we apply the Riesz representation theorem. It states that there is a linear isometric isomorphism \(\mathcal{R} : V^\prime \to V\) that associates each linear functional \(f\) with a unique element \(\mathcal{R} f = \boldsymbol{u} \in V\). In practice, we work within a finite-dimensional subspace \(V_h \subset V\) equipped with the inner product \((\boldsymbol{u}, \boldsymbol{v}) \mapsto \int_\Omega \nabla\boldsymbol{u} : \nabla\boldsymbol{v} \, \mathrm{d}\boldsymbol{x}\). There, the Riesz map \(\mathcal{R}\) is represented by the matrix \(\mathsf{L}^{-1}\), the inverse of the discrete Laplace operator [Kirby, 2010]:

\[ \mathsf{L}_{ij} = \int\limits_\Omega \nabla\varphi_i : \nabla\varphi_j \, \mathrm{d}\boldsymbol{x}, \quad i,j = 1, \dots, \dim V_h, \]

where \(\{\varphi_i\}_{i=1}^{\dim V_h}\) is a basis of the space \(V_h\).

Let the Euclidean vectors \(\mathsf{r}_k^i \in \mathbb{R}^{\dim V_h}, \, i \in \{0,1\}\) represent the Taylor remainders from (8) and (9) in the finite-dimensional space. Then the dual norms are computed through the following formula [Kirby, 2010]:

\[ \| r_k^i \|^2_{V^\prime_h} = (\mathsf{r}_k^i)^T \mathsf{L}^{-1} \mathsf{r}_k^i, \quad i \in \{0,1\}. \tag{10} \]
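To make this formula concrete, here is a 1D analogue of our own, which is not part of the tutorial. It uses P1 elements on \((0, 1)\) with homogeneous Dirichlet conditions, where \(\mathsf{L}\) is the tridiagonal stiffness matrix and \(\mathsf{r}\) is the load vector of the functional \(v \mapsto \int_0^1 v \, \mathrm{d}x\). The exact squared dual norm of this functional is \(1/12\). Since the P1 Riesz representer is the nodal interpolant of \(x(1-x)/2\), the discrete value is \((1 - h^2)/12\). The helper `dual_norm` is hypothetical; it uses a dense solve instead of the PETSc LU solver.

```python
import numpy as np


def dual_norm(r, L_mat):
    """Return sqrt(r^T L^{-1} r), the discrete dual norm of a residual vector r."""
    y = np.linalg.solve(L_mat, r)  # Riesz representer: y = L^{-1} r
    return np.sqrt(r @ y)


n_el = 100  # number of P1 elements on (0, 1)
h = 1.0 / n_el
n_int = n_el - 1  # interior nodes only (homogeneous Dirichlet conditions)
# P1 stiffness matrix of the 1D Laplacian: L_ij = int phi_i' phi_j' dx
L_1d = (2.0 * np.eye(n_int) - np.eye(n_int, k=1) - np.eye(n_int, k=-1)) / h
# Load vector of the functional v -> int_0^1 v dx: r_i = int phi_i dx = h
r_1d = np.full(n_int, h)

norm_1d = dual_norm(r_1d, L_1d)
print(norm_1d**2, (1.0 - h**2) / 12.0)
```

The two printed numbers agree. They fall short of the continuous value \(1/12\) by \(h^2/12\), which is the energy of the interpolation error of the representer.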

In practice, the vectors \(\mathsf{r}_k^i\) are defined through the residual vector \(\mathsf{F} \in \mathbb{R}^{\dim V_h}\) and the Jacobian matrix \(\mathsf{J} \in \mathbb{R}^{\dim V_h\times\dim V_h}\)

\[ \mathsf{r}_k^0 = \mathsf{F}(\mathsf{u} + k \, \mathsf{\delta u}) - \mathsf{F}(\mathsf{u}) \in \mathbb{R}^{\dim V_h}, \tag{11} \]
\[ \mathsf{r}_k^1 = \mathsf{F}(\mathsf{u} + k \, \mathsf{\delta u}) - \mathsf{F}(\mathsf{u}) - \, \mathsf{J}(\mathsf{u}) \cdot k\mathsf{\delta u} \in \mathbb{R}^{\dim V_h}, \tag{12} \]

where \(\mathsf{u} \in \mathbb{R}^{\dim V_h}\) and \(\mathsf{\delta u} \in \mathbb{R}^{\dim V_h}\) represent the displacement fields \(\boldsymbol{u} \in V_h\) and \(\boldsymbol{\delta u} \in V_h\).
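The logic of the test can be tried on a small hypothetical example before applying it to the plasticity problem. In the sketch below, `toy_F` is a nonlinear map \(\mathbb{R}^n \to \mathbb{R}^n\) with a known Jacobian `toy_J`. Euclidean norms replace the dual norms (10) to keep it self-contained. The observed rates are fitted with `np.polyfit` as in the tutorial. A deliberately wrong Jacobian is included to show how an inconsistency manifests itself.

```python
import numpy as np


def toy_F(u):
    """Toy nonlinear 'residual' F: R^n -> R^n (hypothetical, for illustration)."""
    return u**3 + np.sin(u)


def toy_J(u):
    """Exact Jacobian of toy_F (diagonal because toy_F acts componentwise)."""
    return np.diag(3.0 * u**2 + np.cos(u))


u_toy = np.linspace(0.5, 1.5, 10)
du_toy = np.ones_like(u_toy)
k_toy = np.logspace(-5.0, -2.0, 4)


def remainders(jacobian):
    """Euclidean norms of the remainders (11) and (12) for each k in k_toy."""
    r0, r1 = [], []
    for k in k_toy:
        dF = toy_F(u_toy + k * du_toy) - toy_F(u_toy)
        r0.append(np.linalg.norm(dF))
        r1.append(np.linalg.norm(dF - jacobian @ (k * du_toy)))
    return np.array(r0), np.array(r1)


def rate(values):
    """Slope of log(values) against log(k): the observed convergence rate."""
    return np.polyfit(np.log(k_toy), np.log(values), 1)[0]


r0, r1 = remainders(toy_J(u_toy))
_, r1_bad = remainders(1.1 * toy_J(u_toy))  # deliberately wrong Jacobian (10 % error)
rate0, rate1, rate1_bad = rate(r0), rate(r1), rate(r1_bad)
print(f"r0 rate: {rate0:.2f}, r1 rate: {rate1:.2f}, r1 rate with wrong J: {rate1_bad:.2f}")
```

With the exact Jacobian, the observed rates are close to 1 and 2. The perturbed Jacobian drops the rate of \(r_k^1\) to about 1, which is the failure signature the Taylor test is designed to detect.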

We can now proceed with the implementation of the Taylor test. We start by assembling the Laplace operator and setting up a direct (LU) solver that applies the Riesz map \(\mathsf{L}^{-1}\).

L_form = fem.form(ufl.inner(ufl.grad(u_hat), ufl.grad(v)) * ufl.dx)
L = fem.petsc.assemble_matrix(L_form, bcs=bcs)
L.assemble()
Riesz_solver = PETSc.KSP().create(domain.comm)
Riesz_solver.setType("preonly")
Riesz_solver.getPC().setType("lu")
Riesz_solver.setOperators(L)
y = fem.Function(V, name="Riesz_representer_of_r")  # r - a Taylor remainder

Next, we reset the main variables of the plasticity problem, including the values of the external operators.

# Reset main variables to zero including the external operators values
sigma_n.x.array[:] = 0.0
sigma.ref_coefficient.x.array[:] = 0.0
J_external_operators[0].ref_coefficient.x.array[:] = 0.0
# Reset the values of the consistent tangent matrix to elastic moduli
Du.x.array[:] = 1.0
evaluated_operands = evaluate_operands(F_external_operators)
_ = evaluate_external_operators(J_external_operators, evaluated_operands)
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [15000]
		Maximum f: -2.2109628558449494
		Maximum residual: 0.0

Since the derivatives of the constitutive model differ between the elastic and plastic phases, we consider two initial states for the Taylor test. To this end, we first solve the problem for a load value that brings the state close to the onset of plastic deformation while remaining in the elastic phase.

i = 0
load = 2.0
q.value = load * np.array([0, -gamma])
external_operator_problem.assemble_vector()

residual_0 = external_operator_problem.b.norm()
residual = residual_0
Du.x.array[:] = 0

if MPI.COMM_WORLD.rank == 0:
    print(f"Load increment #{i}, load: {load}, initial residual: {residual_0}")

for iteration in range(0, max_iterations):
    if residual / residual_0 < relative_tolerance:
        break

    if MPI.COMM_WORLD.rank == 0:
        print(f"\tOuter Newton iteration #{iteration}")
    external_operator_problem.assemble_matrix()
    external_operator_problem.solve(du)

    Du.x.petsc_vec.axpy(1.0, du.x.petsc_vec)
    Du.x.scatter_forward()

    evaluated_operands = evaluate_operands(F_external_operators)
    ((_, sigma_new),) = evaluate_external_operators(J_external_operators, evaluated_operands)

    sigma.ref_coefficient.x.array[:] = sigma_new

    external_operator_problem.assemble_vector()
    residual = external_operator_problem.b.norm()

    if MPI.COMM_WORLD.rank == 0:
        print(f"\tResidual: {residual}\n")

sigma_n.x.array[:] = sigma.ref_coefficient.x.array

# Initial values of the displacement field and the stress state for the Taylor
# test
Du0 = np.copy(Du.x.array)
sigma_n0 = np.copy(sigma_n.x.array)
Load increment #0, load: 2.0, initial residual: 0.027573900703382316
	Outer Newton iteration #0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [15000]
		Maximum f: -0.2983932974714918
		Maximum residual: 0.0
	Residual: 7.890919449890783e-14

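If we take the stress state sigma_n0 computed in the cell above as the initial stress, applying the displacement increment Du0 on top of it drives part of the domain into plasticity. In that case, the Taylor test probes the plastic phase. With a zero initial stress, we stay in the elastic phase.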

Finally, we define the function perform_Taylor_test, which returns the norms of the Taylor remainders in the dual space, computed with (10), (11) and (12).

k_list = np.logspace(-2.0, -6.0, 5)[::-1]


def perform_Taylor_test(Du0, sigma_n0):
    # r0 = F(Du0 + k*δu) - F(Du0)
    # r1 = F(Du0 + k*δu) - F(Du0) - k*J(Du0)*δu
    Du.x.array[:] = Du0
    sigma_n.x.array[:] = sigma_n0
    evaluated_operands = evaluate_operands(F_external_operators)
    ((_, sigma_new),) = evaluate_external_operators(J_external_operators, evaluated_operands)
    sigma.ref_coefficient.x.array[:] = sigma_new

    F0 = fem.petsc.assemble_vector(F_form)  # F(Du0)
    F0.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
    fem.set_bc(F0, bcs)

    J0 = fem.petsc.assemble_matrix(J_form, bcs=bcs)
    J0.assemble()  # J(Du0)
    Ju = J0.createVecLeft()  # Ju = J0 @ u

    δu = fem.Function(V)
    δu.x.array[:] = Du0  # δu == Du0

    zero_order_remainder = np.zeros_like(k_list)
    first_order_remainder = np.zeros_like(k_list)

    for i, k in enumerate(k_list):
        Du.x.array[:] = Du0 + k * δu.x.array
        evaluated_operands = evaluate_operands(F_external_operators)
        ((_, sigma_new),) = evaluate_external_operators(J_external_operators, evaluated_operands)
        sigma.ref_coefficient.x.array[:] = sigma_new

        F_delta = fem.petsc.assemble_vector(F_form)  # F(Du0 + k*δu)
        F_delta.ghostUpdate(addv=PETSc.InsertMode.ADD, mode=PETSc.ScatterMode.REVERSE)
        fem.set_bc(F_delta, bcs)

        J0.mult(δu.x.petsc_vec, Ju)  # Ju = J(Du0)*δu
        Ju.scale(k)  # Ju = k*Ju

        r0 = F_delta - F0
        r1 = F_delta - F0 - Ju

        Riesz_solver.solve(r0, y.x.petsc_vec)  # y = L^{-1} r0
        y.x.scatter_forward()
        zero_order_remainder[i] = np.sqrt(r0.dot(y.x.petsc_vec))  # sqrt{r0^T L^{-1} r0}

        Riesz_solver.solve(r1, y.x.petsc_vec)  # y = L^{-1} r1
        y.x.scatter_forward()
        first_order_remainder[i] = np.sqrt(r1.dot(y.x.petsc_vec))  # sqrt{r1^T L^{-1} r1}

    return zero_order_remainder, first_order_remainder


print("Elastic phase")
zero_order_remainder_elastic, first_order_remainder_elastic = perform_Taylor_test(Du0, 0.0)
print("Plastic phase")
zero_order_remainder_plastic, first_order_remainder_plastic = perform_Taylor_test(Du0, sigma_n0)
Elastic phase
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [15000]
		Maximum f: -0.2983932974714918
		Maximum residual: 0.0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [15000]
		Maximum f: -0.29839091658838823
		Maximum residual: 0.0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [15000]
		Maximum f: -0.29836948862876245
		Maximum residual: 0.0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [15000]
		Maximum f: -0.2981552078749816
		Maximum residual: 0.0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [15000]
		Maximum f: -0.2960122846812845
		Maximum residual: 0.0
	Inner Newton summary:
		Unique number of iterations: [1]
		Counts of unique number of iterations: [15000]
		Maximum f: -0.2745715837020426
		Maximum residual: 0.0
Plastic phase
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [14989     5     6]
		Maximum f: 2.1523374160661706
		Maximum residual: 2.191449148030456e-10
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [14989     5     6]
		Maximum f: 2.15233990398981
		Maximum residual: 2.1914579313521028e-10
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [14989     5     6]
		Maximum f: 2.152362295304328
		Maximum residual: 2.1915488686906374e-10
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [14989     5     6]
		Maximum f: 2.152586208624277
		Maximum residual: 2.1924029027367768e-10
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [14989     5     6]
		Maximum f: 2.154825359292723
		Maximum residual: 2.200976411456201e-10
	Inner Newton summary:
		Unique number of iterations: [1 3 4]
		Counts of unique number of iterations: [14989     5     6]
		Maximum f: 2.1772186045752187
		Maximum residual: 2.2878007480990514e-10
fig, axs = plt.subplots(1, 2, figsize=(10, 5))

axs[0].loglog(k_list, zero_order_remainder_elastic, "o-", label=r"$\|r_k^0\|_{V^\prime}$")
axs[0].loglog(k_list, first_order_remainder_elastic, "o-", label=r"$\|r_k^1\|_{V^\prime}$")
annotation.slope_marker((2e-4, 5e-5), 1, ax=axs[0], poly_kwargs={"facecolor": "tab:blue"})
axs[0].text(0.5, -0.2, "(a) Elastic phase", transform=axs[0].transAxes, ha="center", va="top")

axs[1].loglog(k_list, zero_order_remainder_plastic, "o-", label=r"$\|r_k^0\|_{V^\prime}$")
annotation.slope_marker((2e-4, 5e-5), 1, ax=axs[1], poly_kwargs={"facecolor": "tab:blue"})
axs[1].loglog(k_list, first_order_remainder_plastic, "o-", label=r"$\|r_k^1\|_{V^\prime}$")
annotation.slope_marker((2e-4, 5e-13), 2, ax=axs[1], poly_kwargs={"facecolor": "tab:orange"})
axs[1].text(0.5, -0.2, "(b) Plastic phase", transform=axs[1].transAxes, ha="center", va="top")

for i in range(2):
    axs[i].set_xlabel("k")
    axs[i].set_ylabel("Taylor remainder norm")
    axs[i].legend()
    axs[i].grid()

plt.tight_layout()
plt.show()

first_order_rate = np.polyfit(np.log(k_list), np.log(zero_order_remainder_elastic), 1)[0]
second_order_rate = np.polyfit(np.log(k_list), np.log(first_order_remainder_elastic), 1)[0]
print(f"Elastic phase:\n\tthe 1st order rate = {first_order_rate:.2f}\n\tthe 2nd order rate = {second_order_rate:.2f}")
first_order_rate = np.polyfit(np.log(k_list), np.log(zero_order_remainder_plastic), 1)[0]
second_order_rate = np.polyfit(np.log(k_list[1:]), np.log(first_order_remainder_plastic[1:]), 1)[0]
print(f"Plastic phase:\n\tthe 1st order rate = {first_order_rate:.2f}\n\tthe 2nd order rate = {second_order_rate:.2f}")
[Figure: Taylor remainder norms versus \(k\) on log-log axes: (a) elastic phase, (b) plastic phase]
Elastic phase:
	the 1st order rate = 1.00
	the 2nd order rate = 0.00
Plastic phase:
	the 1st order rate = 1.00
	the 2nd order rate = 1.90

In the elastic phase (a), the zeroth-order Taylor remainder \(r_k^0\) converges at first order. The first-order remainder \(r_k^1\) stays at the level of machine precision because the Jacobian is constant, so its fitted rate of 0.00 carries no information. In the plastic phase (b), the zeroth-order remainder \(r_k^0\) again converges at first order, whereas the first-order remainder \(r_k^1\) converges at second order (fitted rate 1.90), as expected.

References#

[CL90] W. F. Chen and X. L. Liu. Limit Analysis in Soil Mechanics. Volume 52. Elsevier Science, 1990.

[Kir10] Robert C. Kirby. From Functional Analysis to Iterative Methods. SIAM Review, 52(2):269–293, 2010. doi:10.1137/070706914.